From 2d7dfca4a4c1980d7f5cee9a0d101b78214d4486 Mon Sep 17 00:00:00 2001 From: Filip Michalsky Date: Wed, 30 Jul 2025 21:51:50 -0400 Subject: [PATCH] clean up --- .../stagehand/1_stagehand_wrapper_agent.py | 28 +- examples/stagehand/2_stagehand_tools_agent.py | 51 +-- examples/stagehand/3_stagehand_mcp_agent.py | 75 +++-- .../4_stagehand_multi_agent_workflow.py | 144 ++++---- tests/stagehand/test_stagehand_integration.py | 310 +++++++++++------- tests/stagehand/test_stagehand_simple.py | 304 +++++++++++++++++ 6 files changed, 666 insertions(+), 246 deletions(-) create mode 100644 tests/stagehand/test_stagehand_simple.py diff --git a/examples/stagehand/1_stagehand_wrapper_agent.py b/examples/stagehand/1_stagehand_wrapper_agent.py index 158549ff..c4a04906 100644 --- a/examples/stagehand/1_stagehand_wrapper_agent.py +++ b/examples/stagehand/1_stagehand_wrapper_agent.py @@ -12,7 +12,7 @@ and implements browser automation through natural language commands. import asyncio import json import os -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, Optional from dotenv import load_dotenv from loguru import logger @@ -75,7 +75,8 @@ class StagehandAgent(SwarmsAgent): project_id=browserbase_project_id or os.getenv("BROWSERBASE_PROJECT_ID"), model_name=model_name, - model_api_key=model_api_key or os.getenv("OPENAI_API_KEY"), + model_api_key=model_api_key + or os.getenv("OPENAI_API_KEY"), ) self.stagehand = None self._initialized = False @@ -86,7 +87,9 @@ class StagehandAgent(SwarmsAgent): self.stagehand = Stagehand(self.stagehand_config) await self.stagehand.init() self._initialized = True - logger.info(f"Stagehand initialized for {self.agent_name}") + logger.info( + f"Stagehand initialized for {self.agent_name}" + ) async def _close_stagehand(self): """Close Stagehand instance.""" @@ -112,9 +115,7 @@ class StagehandAgent(SwarmsAgent): """ return asyncio.run(self._async_run(task, *args, **kwargs)) - async def _async_run( - self, task: str, *args, **kwargs - ) -> str: + async def _async_run(self, task: str, *args, **kwargs) -> str: """Async implementation of run method.""" try: await self._init_stagehand() @@ -183,9 +184,13 @@ class StagehandAgent(SwarmsAgent): elif "search" in task.lower(): # Perform search action - search_query = task.split("search for")[-1].strip().strip("'\"") + search_query = ( + task.split("search for")[-1].strip().strip("'\"") + ) # First, find the search box - search_box = await page.observe("find the search input field") + search_box = await page.observe( + "find the search input field" + ) if search_box: # Click on search box and type await page.act(f"click on {search_box[0]}") @@ -198,7 +203,10 @@ class StagehandAgent(SwarmsAgent): # Perform observation observation = await page.observe(task) result["data"]["observation"] = [ - {"description": obs.description, "selector": obs.selector} + { + "description": obs.description, + "selector": obs.selector, + } for obs in observation ] result["action"] = "observe" @@ -254,4 +262,4 @@ if __name__ == "__main__": print(result3) # Clean up - browser_agent.cleanup() \ No newline at end of file + browser_agent.cleanup() diff --git a/examples/stagehand/2_stagehand_tools_agent.py b/examples/stagehand/2_stagehand_tools_agent.py index 15cae06b..f4931ceb 100644 --- a/examples/stagehand/2_stagehand_tools_agent.py +++ b/examples/stagehand/2_stagehand_tools_agent.py @@ -3,7 +3,7 @@ Stagehand Tools for Swarms Agent ================================= This example demonstrates how to create Stagehand browser automation tools -that can be used by a standard Swarms Agent. Each Stagehand method (act, +that can be used by a standard Swarms Agent. Each Stagehand method (act, extract, observe) becomes a separate tool that the agent can use. This approach gives the agent more fine-grained control over browser @@ -13,11 +13,10 @@ automation tasks. import asyncio import json import os -from typing import Any, Dict, List, Optional, Union +from typing import Optional from dotenv import load_dotenv from loguru import logger -from pydantic import BaseModel, Field from swarms import Agent from swarms.tools.base_tool import BaseTool @@ -51,9 +50,11 @@ class BrowserState: config = StagehandConfig( env=env, api_key=api_key or os.getenv("BROWSERBASE_API_KEY"), - project_id=project_id or os.getenv("BROWSERBASE_PROJECT_ID"), + project_id=project_id + or os.getenv("BROWSERBASE_PROJECT_ID"), model_name=model_name, - model_api_key=model_api_key or os.getenv("OPENAI_API_KEY"), + model_api_key=model_api_key + or os.getenv("OPENAI_API_KEY"), ) self._stagehand = Stagehand(config) await self._stagehand.init() @@ -63,7 +64,9 @@ class BrowserState: async def get_page(self): """Get the current page instance.""" if not self._initialized: - raise RuntimeError("Browser not initialized. Call init_browser first.") + raise RuntimeError( + "Browser not initialized. Call init_browser first." + ) return self._stagehand.page async def close(self): @@ -96,11 +99,11 @@ class NavigateTool(BaseTool): try: await browser_state.init_browser() page = await browser_state.get_page() - + # Ensure URL has protocol if not url.startswith(("http://", "https://")): url = f"https://{url}" - + await page.goto(url) return f"Successfully navigated to {url}" except Exception as e: @@ -130,7 +133,7 @@ class ActTool(BaseTool): try: await browser_state.init_browser() page = await browser_state.get_page() - + result = await page.act(action) return f"Action performed: {action}. Result: {result}" except Exception as e: @@ -160,9 +163,9 @@ class ExtractTool(BaseTool): try: await browser_state.init_browser() page = await browser_state.get_page() - + extracted = await page.extract(query) - + # Convert to JSON string for agent consumption if isinstance(extracted, (dict, list)): return json.dumps(extracted, indent=2) @@ -196,18 +199,20 @@ class ObserveTool(BaseTool): try: await browser_state.init_browser() page = await browser_state.get_page() - + observations = await page.observe(query) - + # Format observations for readability result = [] for obs in observations: - result.append({ - "description": obs.description, - "selector": obs.selector, - "method": obs.method - }) - + result.append( + { + "description": obs.description, + "selector": obs.selector, + "method": obs.method, + } + ) + return json.dumps(result, indent=2) except Exception as e: logger.error(f"Observation error: {str(e)}") @@ -232,15 +237,15 @@ class ScreenshotTool(BaseTool): try: await browser_state.init_browser() page = await browser_state.get_page() - + # Ensure .png extension if not filename.endswith(".png"): filename += ".png" - + # Get the underlying Playwright page playwright_page = page.page await playwright_page.screenshot(path=filename) - + return f"Screenshot saved to {filename}" except Exception as e: logger.error(f"Screenshot error: {str(e)}") @@ -329,4 +334,4 @@ if __name__ == "__main__": print(result3) # Clean up - browser_agent.run("Close the browser") \ No newline at end of file + browser_agent.run("Close the browser") diff --git a/examples/stagehand/3_stagehand_mcp_agent.py b/examples/stagehand/3_stagehand_mcp_agent.py index 4483a497..64688490 100644 --- a/examples/stagehand/3_stagehand_mcp_agent.py +++ b/examples/stagehand/3_stagehand_mcp_agent.py @@ -22,9 +22,7 @@ Features: - Prompt templates for common tasks """ -import asyncio -import os -from typing import List, Optional +from typing import List from dotenv import load_dotenv from loguru import logger @@ -104,14 +102,23 @@ class MultiSessionBrowserSwarm: num_agents: Number of agents to create """ self.agents = [] - + # Create specialized agents for different tasks agent_roles = [ - ("DataExtractor", "You specialize in extracting structured data from websites."), - ("FormFiller", "You specialize in filling out forms and interacting with web applications."), - ("WebMonitor", "You specialize in monitoring websites for changes and capturing screenshots."), + ( + "DataExtractor", + "You specialize in extracting structured data from websites.", + ), + ( + "FormFiller", + "You specialize in filling out forms and interacting with web applications.", + ), + ( + "WebMonitor", + "You specialize in monitoring websites for changes and capturing screenshots.", + ), ] - + for i in range(min(num_agents, len(agent_roles))): name, specialization = agent_roles[i] agent = Agent( @@ -137,16 +144,18 @@ Always create your own session for tasks to work independently from other agents def distribute_tasks(self, tasks: List[str]) -> List[str]: """Distribute tasks among agents.""" results = [] - + # Distribute tasks round-robin among agents for i, task in enumerate(tasks): agent_idx = i % len(self.agents) agent = self.agents[agent_idx] - - logger.info(f"Assigning task to {agent.agent_name}: {task}") + + logger.info( + f"Assigning task to {agent.agent_name}: {task}" + ) result = agent.run(task) results.append(result) - + return results @@ -155,18 +164,20 @@ if __name__ == "__main__": print("=" * 70) print("Stagehand MCP Server Integration Examples") print("=" * 70) - print("\nMake sure the Stagehand MCP server is running on http://localhost:3000/sse") + print( + "\nMake sure the Stagehand MCP server is running on http://localhost:3000/sse" + ) print("Run: cd stagehand-mcp-server && npm start\n") - + # Example 1: Single agent with MCP tools print("\nExample 1: Single Agent with MCP Tools") print("-" * 40) - + mcp_agent = StagehandMCPAgent( agent_name="WebResearchAgent", mcp_server_url="http://localhost:3000/sse", ) - + # Research task using MCP tools result1 = mcp_agent.run( """Navigate to news.ycombinator.com and extract the following: @@ -176,18 +187,18 @@ if __name__ == "__main__": Then take a screenshot of the page.""" ) print(f"Result: {result1}") - + print("\n" + "=" * 70 + "\n") - + # Example 2: Multi-session parallel browsing print("Example 2: Multi-Session Parallel Browsing") print("-" * 40) - + parallel_agent = StagehandMCPAgent( agent_name="ParallelBrowserAgent", mcp_server_url="http://localhost:3000/sse", ) - + result2 = parallel_agent.run( """Create 3 browser sessions and perform these tasks in parallel: 1. Session 1: Go to github.com/trending and extract the top 3 trending repositories @@ -197,44 +208,44 @@ if __name__ == "__main__": After extracting data from all sessions, close them.""" ) print(f"Result: {result2}") - + print("\n" + "=" * 70 + "\n") - + # Example 3: Multi-agent browser swarm print("Example 3: Multi-Agent Browser Swarm") print("-" * 40) - + # Create a swarm of specialized browser agents browser_swarm = MultiSessionBrowserSwarm( mcp_server_url="http://localhost:3000/sse", num_agents=3, ) - + # Define tasks for the swarm swarm_tasks = [ "Create a session, navigate to python.org, and extract information about the latest Python version and its key features", "Create a session, go to npmjs.com, search for 'stagehand', and extract information about the package including version and description", "Create a session, visit playwright.dev, and extract the main features and benefits listed on the homepage", ] - + print("Distributing tasks to browser swarm...") swarm_results = browser_swarm.distribute_tasks(swarm_tasks) - + for i, result in enumerate(swarm_results): print(f"\nTask {i+1} Result: {result}") - + print("\n" + "=" * 70 + "\n") - + # Example 4: Complex workflow with session management print("Example 4: Complex Multi-Page Workflow") print("-" * 40) - + workflow_agent = StagehandMCPAgent( agent_name="WorkflowAgent", mcp_server_url="http://localhost:3000/sse", max_loops=2, # Allow more complex reasoning ) - + result4 = workflow_agent.run( """Perform a comprehensive analysis of AI frameworks: 1. Create a new session @@ -246,7 +257,7 @@ if __name__ == "__main__": 7. Close the session when done""" ) print(f"Result: {result4}") - + print("\n" + "=" * 70) print("All examples completed!") - print("=" * 70) \ No newline at end of file + print("=" * 70) diff --git a/examples/stagehand/4_stagehand_multi_agent_workflow.py b/examples/stagehand/4_stagehand_multi_agent_workflow.py index 31bb1d21..4f8f8433 100644 --- a/examples/stagehand/4_stagehand_multi_agent_workflow.py +++ b/examples/stagehand/4_stagehand_multi_agent_workflow.py @@ -13,14 +13,10 @@ Use cases: 4. Data aggregation from multiple sources """ -import asyncio -import json -import os from datetime import datetime -from typing import Any, Dict, List +from typing import Dict, List, Optional from dotenv import load_dotenv -from loguru import logger from pydantic import BaseModel, Field from swarms import Agent, SequentialWorkflow, ConcurrentWorkflow @@ -33,42 +29,48 @@ load_dotenv() # Pydantic models for structured data class ProductInfo(BaseModel): """Product information schema.""" + name: str = Field(..., description="Product name") price: float = Field(..., description="Product price") availability: str = Field(..., description="Availability status") url: str = Field(..., description="Product URL") - screenshot_path: Optional[str] = Field(None, description="Screenshot file path") + screenshot_path: Optional[str] = Field( + None, description="Screenshot file path" + ) class MarketAnalysis(BaseModel): """Market analysis report schema.""" + timestamp: datetime = Field(default_factory=datetime.now) - products: List[ProductInfo] = Field(..., description="List of products analyzed") - price_range: Dict[str, float] = Field(..., description="Min and max prices") - recommendations: List[str] = Field(..., description="Analysis recommendations") + products: List[ProductInfo] = Field( + ..., description="List of products analyzed" + ) + price_range: Dict[str, float] = Field( + ..., description="Min and max prices" + ) + recommendations: List[str] = Field( + ..., description="Analysis recommendations" + ) # Specialized browser agents class ProductScraperAgent(StagehandAgent): """Specialized agent for scraping product information.""" - + def __init__(self, site_name: str, *args, **kwargs): super().__init__( - agent_name=f"ProductScraper_{site_name}", - *args, - **kwargs + agent_name=f"ProductScraper_{site_name}", *args, **kwargs ) self.site_name = site_name class PriceMonitorAgent(StagehandAgent): """Specialized agent for monitoring price changes.""" - + def __init__(self, *args, **kwargs): super().__init__( - agent_name="PriceMonitorAgent", - *args, - **kwargs + agent_name="PriceMonitorAgent", *args, **kwargs ) @@ -77,20 +79,20 @@ def create_price_comparison_workflow(): """ Create a workflow that compares prices across multiple e-commerce sites. """ - + # Create specialized agents for different sites amazon_agent = StagehandAgent( agent_name="AmazonScraperAgent", model_name="gpt-4o-mini", env="LOCAL", ) - + ebay_agent = StagehandAgent( agent_name="EbayScraperAgent", model_name="gpt-4o-mini", env="LOCAL", ) - + analysis_agent = Agent( agent_name="PriceAnalysisAgent", model_name="gpt-4o-mini", @@ -98,21 +100,21 @@ def create_price_comparison_workflow(): and provide insights on the best deals, price trends, and recommendations. Focus on value for money and highlight any significant price differences.""", ) - + # Create concurrent workflow for parallel scraping scraping_workflow = ConcurrentWorkflow( agents=[amazon_agent, ebay_agent], max_loops=1, verbose=True, ) - + # Create sequential workflow: scrape -> analyze full_workflow = SequentialWorkflow( agents=[scraping_workflow, analysis_agent], max_loops=1, verbose=True, ) - + return full_workflow @@ -121,21 +123,21 @@ def create_competitive_analysis_workflow(): """ Create a workflow for competitive analysis across multiple company websites. """ - + # Agent for extracting company information company_researcher = StagehandAgent( agent_name="CompanyResearchAgent", model_name="gpt-4o-mini", env="LOCAL", ) - + # Agent for analyzing social media presence social_media_agent = StagehandAgent( agent_name="SocialMediaAnalysisAgent", model_name="gpt-4o-mini", env="LOCAL", ) - + # Agent for compiling competitive analysis report report_compiler = Agent( agent_name="CompetitiveAnalysisReporter", @@ -144,16 +146,22 @@ def create_competitive_analysis_workflow(): based on company information and social media presence data. Identify strengths, weaknesses, and market positioning for each company.""", ) - + # Create agent rearrange for flexible routing - workflow_pattern = "company_researcher -> social_media_agent -> report_compiler" - + workflow_pattern = ( + "company_researcher -> social_media_agent -> report_compiler" + ) + competitive_workflow = AgentRearrange( - agents=[company_researcher, social_media_agent, report_compiler], + agents=[ + company_researcher, + social_media_agent, + report_compiler, + ], flow=workflow_pattern, verbose=True, ) - + return competitive_workflow @@ -162,28 +170,28 @@ def create_automated_testing_workflow(): """ Create a workflow for automated web application testing. """ - + # Agent for UI testing ui_tester = StagehandAgent( agent_name="UITestingAgent", model_name="gpt-4o-mini", env="LOCAL", ) - + # Agent for form validation testing form_tester = StagehandAgent( agent_name="FormValidationAgent", model_name="gpt-4o-mini", env="LOCAL", ) - + # Agent for accessibility testing accessibility_tester = StagehandAgent( agent_name="AccessibilityTestingAgent", model_name="gpt-4o-mini", env="LOCAL", ) - + # Agent for compiling test results test_reporter = Agent( agent_name="TestReportCompiler", @@ -192,20 +200,20 @@ def create_automated_testing_workflow(): UI, form validation, and accessibility testing into a comprehensive report. Highlight any failures, warnings, and provide recommendations for fixes.""", ) - + # Concurrent testing followed by report generation testing_workflow = ConcurrentWorkflow( agents=[ui_tester, form_tester, accessibility_tester], max_loops=1, verbose=True, ) - + full_test_workflow = SequentialWorkflow( agents=[testing_workflow, test_reporter], max_loops=1, verbose=True, ) - + return full_test_workflow @@ -214,7 +222,7 @@ def create_news_aggregation_workflow(): """ Create a workflow for news aggregation and sentiment analysis. """ - + # Multiple news scraper agents news_scrapers = [] news_sites = [ @@ -222,7 +230,7 @@ def create_news_aggregation_workflow(): ("HackerNews", "https://news.ycombinator.com"), ("Reddit", "https://reddit.com/r/technology"), ] - + for site_name, url in news_sites: scraper = StagehandAgent( agent_name=f"{site_name}Scraper", @@ -230,7 +238,7 @@ def create_news_aggregation_workflow(): env="LOCAL", ) news_scrapers.append(scraper) - + # Sentiment analysis agent sentiment_analyzer = Agent( agent_name="SentimentAnalyzer", @@ -239,7 +247,7 @@ def create_news_aggregation_workflow(): to determine overall sentiment (positive, negative, neutral) and identify key themes and trends in the technology sector.""", ) - + # Trend identification agent trend_identifier = Agent( agent_name="TrendIdentifier", @@ -248,20 +256,24 @@ def create_news_aggregation_workflow(): data, identify emerging trends, hot topics, and potential market movements in the technology sector.""", ) - + # Create workflow: parallel scraping -> sentiment analysis -> trend identification scraping_workflow = ConcurrentWorkflow( agents=news_scrapers, max_loops=1, verbose=True, ) - + analysis_workflow = SequentialWorkflow( - agents=[scraping_workflow, sentiment_analyzer, trend_identifier], + agents=[ + scraping_workflow, + sentiment_analyzer, + trend_identifier, + ], max_loops=1, verbose=True, ) - + return analysis_workflow @@ -270,13 +282,13 @@ if __name__ == "__main__": print("=" * 70) print("Stagehand Multi-Agent Workflow Examples") print("=" * 70) - + # Example 1: Price Comparison print("\nExample 1: E-commerce Price Comparison") print("-" * 40) - + price_workflow = create_price_comparison_workflow() - + # Search for a specific product across multiple sites price_result = price_workflow.run( """Search for 'iPhone 15 Pro Max 256GB' on: @@ -286,15 +298,15 @@ if __name__ == "__main__": Compare the prices and provide recommendations on where to buy.""" ) print(f"Price Comparison Result:\n{price_result}") - + print("\n" + "=" * 70 + "\n") - + # Example 2: Competitive Analysis print("Example 2: Competitive Analysis") print("-" * 40) - + competitive_workflow = create_competitive_analysis_workflow() - + competitive_result = competitive_workflow.run( """Analyze these three AI companies: 1. OpenAI - visit openai.com and extract mission, products, and recent announcements @@ -305,15 +317,15 @@ if __name__ == "__main__": Compile a competitive analysis report comparing their market positioning.""" ) print(f"Competitive Analysis Result:\n{competitive_result}") - + print("\n" + "=" * 70 + "\n") - + # Example 3: Automated Testing print("Example 3: Automated Web Testing") print("-" * 40) - + testing_workflow = create_automated_testing_workflow() - + test_result = testing_workflow.run( """Test the website example.com: 1. UI Testing: Check if all main navigation links work, images load, and layout is responsive @@ -323,15 +335,15 @@ if __name__ == "__main__": Take screenshots of any issues found and compile a comprehensive test report.""" ) print(f"Test Results:\n{test_result}") - + print("\n" + "=" * 70 + "\n") - + # Example 4: News Aggregation print("Example 4: Tech News Aggregation and Analysis") print("-" * 40) - + news_workflow = create_news_aggregation_workflow() - + news_result = news_workflow.run( """For each news source: 1. TechCrunch: Extract the top 5 headlines about AI or machine learning @@ -341,19 +353,19 @@ if __name__ == "__main__": Analyze sentiment and identify emerging trends in AI technology.""" ) print(f"News Analysis Result:\n{news_result}") - + # Cleanup all browser instances print("\n" + "=" * 70) print("Cleaning up browser instances...") - + # Clean up agents for agent in price_workflow.agents: if isinstance(agent, StagehandAgent): agent.cleanup() - elif hasattr(agent, 'agents'): # For nested workflows + elif hasattr(agent, "agents"): # For nested workflows for sub_agent in agent.agents: if isinstance(sub_agent, StagehandAgent): sub_agent.cleanup() - + print("All workflows completed!") - print("=" * 70) \ No newline at end of file + print("=" * 70) diff --git a/tests/stagehand/test_stagehand_integration.py b/tests/stagehand/test_stagehand_integration.py index c1e913ae..d2048d11 100644 --- a/tests/stagehand/test_stagehand_integration.py +++ b/tests/stagehand/test_stagehand_integration.py @@ -6,13 +6,9 @@ This module contains tests for the Stagehand browser automation integration with the Swarms framework. """ -import asyncio import json import pytest -from unittest.mock import AsyncMock, MagicMock, patch - -from swarms import Agent -from swarms.tools.base_tool import BaseTool +from unittest.mock import AsyncMock, patch # Mock Stagehand classes @@ -26,13 +22,13 @@ class MockObserveResult: class MockStagehandPage: async def goto(self, url): return None - + async def act(self, action): return f"Performed action: {action}" - + async def extract(self, query): return {"extracted": query, "data": ["item1", "item2"]} - + async def observe(self, query): return [ MockObserveResult("Search box", "#search-input"), @@ -44,10 +40,10 @@ class MockStagehand: def __init__(self, config): self.config = config self.page = MockStagehandPage() - + async def init(self): pass - + async def close(self): pass @@ -55,79 +51,106 @@ class MockStagehand: # Test StagehandAgent wrapper class TestStagehandAgent: """Test the StagehandAgent wrapper class.""" - - @patch('examples.stagehand.stagehand_wrapper_agent.Stagehand', MockStagehand) + + @patch( + "examples.stagehand.stagehand_wrapper_agent.Stagehand", + MockStagehand, + ) def test_agent_initialization(self): """Test that StagehandAgent initializes correctly.""" - from examples.stagehand.stagehand_wrapper_agent import StagehandAgent - + from examples.stagehand.stagehand_wrapper_agent import ( + StagehandAgent, + ) + agent = StagehandAgent( agent_name="TestAgent", model_name="gpt-4o-mini", env="LOCAL", ) - + assert agent.agent_name == "TestAgent" assert agent.stagehand_config.env == "LOCAL" assert agent.stagehand_config.model_name == "gpt-4o-mini" assert not agent._initialized - - @patch('examples.stagehand.stagehand_wrapper_agent.Stagehand', MockStagehand) + + @patch( + "examples.stagehand.stagehand_wrapper_agent.Stagehand", + MockStagehand, + ) def test_navigation_task(self): """Test navigation and extraction task.""" - from examples.stagehand.stagehand_wrapper_agent import StagehandAgent - + from examples.stagehand.stagehand_wrapper_agent import ( + StagehandAgent, + ) + agent = StagehandAgent( agent_name="TestAgent", model_name="gpt-4o-mini", env="LOCAL", ) - - result = agent.run("Navigate to example.com and extract the main content") - + + result = agent.run( + "Navigate to example.com and extract the main content" + ) + # Parse result result_data = json.loads(result) assert result_data["status"] == "completed" assert "navigated_to" in result_data["data"] - assert result_data["data"]["navigated_to"] == "https://example.com" + assert ( + result_data["data"]["navigated_to"] + == "https://example.com" + ) assert "extracted" in result_data["data"] - - @patch('examples.stagehand.stagehand_wrapper_agent.Stagehand', MockStagehand) + + @patch( + "examples.stagehand.stagehand_wrapper_agent.Stagehand", + MockStagehand, + ) def test_search_task(self): """Test search functionality.""" - from examples.stagehand.stagehand_wrapper_agent import StagehandAgent - + from examples.stagehand.stagehand_wrapper_agent import ( + StagehandAgent, + ) + agent = StagehandAgent( agent_name="TestAgent", model_name="gpt-4o-mini", env="LOCAL", ) - - result = agent.run("Go to google.com and search for 'test query'") - + + result = agent.run( + "Go to google.com and search for 'test query'" + ) + result_data = json.loads(result) assert result_data["status"] == "completed" assert result_data["data"]["search_query"] == "test query" assert result_data["action"] == "search" - - @patch('examples.stagehand.stagehand_wrapper_agent.Stagehand', MockStagehand) + + @patch( + "examples.stagehand.stagehand_wrapper_agent.Stagehand", + MockStagehand, + ) def test_cleanup(self): """Test that cleanup properly closes browser.""" - from examples.stagehand.stagehand_wrapper_agent import StagehandAgent - + from examples.stagehand.stagehand_wrapper_agent import ( + StagehandAgent, + ) + agent = StagehandAgent( agent_name="TestAgent", model_name="gpt-4o-mini", env="LOCAL", ) - + # Initialize the agent agent.run("Navigate to example.com") assert agent._initialized - + # Cleanup agent.cleanup() - + # After cleanup, should be able to run again result = agent.run("Navigate to example.com") assert result is not None @@ -136,65 +159,84 @@ class TestStagehandAgent: # Test Stagehand Tools class TestStagehandTools: """Test individual Stagehand tools.""" - - @patch('examples.stagehand.stagehand_tools_agent.browser_state') + + @patch("examples.stagehand.stagehand_tools_agent.browser_state") async def test_navigate_tool(self, mock_browser_state): """Test NavigateTool functionality.""" - from examples.stagehand.stagehand_tools_agent import NavigateTool - + from examples.stagehand.stagehand_tools_agent import ( + NavigateTool, + ) + # Setup mock mock_page = AsyncMock() - mock_browser_state.get_page = AsyncMock(return_value=mock_page) + mock_browser_state.get_page = AsyncMock( + return_value=mock_page + ) mock_browser_state.init_browser = AsyncMock() - + tool = NavigateTool() result = await tool._async_run("https://example.com") - - assert "Successfully navigated to https://example.com" in result + + assert ( + "Successfully navigated to https://example.com" in result + ) mock_page.goto.assert_called_once_with("https://example.com") - - @patch('examples.stagehand.stagehand_tools_agent.browser_state') + + @patch("examples.stagehand.stagehand_tools_agent.browser_state") async def test_act_tool(self, mock_browser_state): """Test ActTool functionality.""" from examples.stagehand.stagehand_tools_agent import ActTool - + # Setup mock mock_page = AsyncMock() mock_page.act = AsyncMock(return_value="Action completed") - mock_browser_state.get_page = AsyncMock(return_value=mock_page) + mock_browser_state.get_page = AsyncMock( + return_value=mock_page + ) mock_browser_state.init_browser = AsyncMock() - + tool = ActTool() result = await tool._async_run("click the button") - + assert "Action performed" in result assert "click the button" in result mock_page.act.assert_called_once_with("click the button") - - @patch('examples.stagehand.stagehand_tools_agent.browser_state') + + @patch("examples.stagehand.stagehand_tools_agent.browser_state") async def test_extract_tool(self, mock_browser_state): """Test ExtractTool functionality.""" - from examples.stagehand.stagehand_tools_agent import ExtractTool - + from examples.stagehand.stagehand_tools_agent import ( + ExtractTool, + ) + # Setup mock mock_page = AsyncMock() - mock_page.extract = AsyncMock(return_value={"title": "Test Page", "content": "Test content"}) - mock_browser_state.get_page = AsyncMock(return_value=mock_page) + mock_page.extract = AsyncMock( + return_value={ + "title": "Test Page", + "content": "Test content", + } + ) + mock_browser_state.get_page = AsyncMock( + return_value=mock_page + ) mock_browser_state.init_browser = AsyncMock() - + tool = ExtractTool() result = await tool._async_run("extract the page title") - + # Result should be JSON string parsed_result = json.loads(result) assert parsed_result["title"] == "Test Page" assert parsed_result["content"] == "Test content" - - @patch('examples.stagehand.stagehand_tools_agent.browser_state') + + @patch("examples.stagehand.stagehand_tools_agent.browser_state") async def test_observe_tool(self, mock_browser_state): """Test ObserveTool functionality.""" - from examples.stagehand.stagehand_tools_agent import ObserveTool - + from examples.stagehand.stagehand_tools_agent import ( + ObserveTool, + ) + # Setup mock mock_page = AsyncMock() mock_observations = [ @@ -202,12 +244,14 @@ class TestStagehandTools: MockObserveResult("Submit button", "#submit"), ] mock_page.observe = AsyncMock(return_value=mock_observations) - mock_browser_state.get_page = AsyncMock(return_value=mock_page) + mock_browser_state.get_page = AsyncMock( + return_value=mock_page + ) mock_browser_state.init_browser = AsyncMock() - + tool = ObserveTool() result = await tool._async_run("find the search box") - + # Result should be JSON string parsed_result = json.loads(result) assert len(parsed_result) == 2 @@ -218,47 +262,53 @@ class TestStagehandTools: # Test MCP integration class TestStagehandMCP: """Test Stagehand MCP server integration.""" - + def test_mcp_agent_initialization(self): """Test that MCP agent initializes with correct parameters.""" - from examples.stagehand.stagehand_mcp_agent import StagehandMCPAgent - + from examples.stagehand.stagehand_mcp_agent import ( + StagehandMCPAgent, + ) + mcp_agent = StagehandMCPAgent( agent_name="TestMCPAgent", mcp_server_url="http://localhost:3000/sse", model_name="gpt-4o-mini", ) - + assert mcp_agent.agent.agent_name == "TestMCPAgent" assert mcp_agent.agent.mcp_url == "http://localhost:3000/sse" assert mcp_agent.agent.model_name == "gpt-4o-mini" - + def test_multi_session_swarm_creation(self): """Test multi-session browser swarm creation.""" - from examples.stagehand.stagehand_mcp_agent import MultiSessionBrowserSwarm - + from examples.stagehand.stagehand_mcp_agent import ( + MultiSessionBrowserSwarm, + ) + swarm = MultiSessionBrowserSwarm( mcp_server_url="http://localhost:3000/sse", num_agents=3, ) - + assert len(swarm.agents) == 3 assert swarm.agents[0].agent_name == "DataExtractor_0" assert swarm.agents[1].agent_name == "FormFiller_1" assert swarm.agents[2].agent_name == "WebMonitor_2" - - @patch('swarms.Agent.run') + + @patch("swarms.Agent.run") def test_task_distribution(self, mock_run): """Test task distribution among swarm agents.""" - from examples.stagehand.stagehand_mcp_agent import MultiSessionBrowserSwarm - + from examples.stagehand.stagehand_mcp_agent import ( + MultiSessionBrowserSwarm, + ) + mock_run.return_value = "Task completed" - + swarm = MultiSessionBrowserSwarm(num_agents=2) tasks = ["Task 1", "Task 2", "Task 3"] - + results = swarm.distribute_tasks(tasks) - + assert len(results) == 3 assert all(result == "Task completed" for result in results) assert mock_run.call_count == 3 @@ -267,90 +317,120 @@ class TestStagehandMCP: # Test multi-agent workflows class TestMultiAgentWorkflows: """Test multi-agent workflow configurations.""" - - @patch('examples.stagehand.stagehand_wrapper_agent.Stagehand', MockStagehand) + + @patch( + "examples.stagehand.stagehand_wrapper_agent.Stagehand", + MockStagehand, + ) def test_price_comparison_workflow_creation(self): """Test creation of price comparison workflow.""" - from examples.stagehand.stagehand_multi_agent_workflow import create_price_comparison_workflow - + from examples.stagehand.stagehand_multi_agent_workflow import ( + create_price_comparison_workflow, + ) + workflow = create_price_comparison_workflow() - + # Should be a SequentialWorkflow with 2 agents assert len(workflow.agents) == 2 # First agent should be a ConcurrentWorkflow - assert hasattr(workflow.agents[0], 'agents') + assert hasattr(workflow.agents[0], "agents") # Second agent should be the analysis agent assert workflow.agents[1].agent_name == "PriceAnalysisAgent" - - @patch('examples.stagehand.stagehand_wrapper_agent.Stagehand', MockStagehand) + + @patch( + "examples.stagehand.stagehand_wrapper_agent.Stagehand", + MockStagehand, + ) def test_competitive_analysis_workflow_creation(self): """Test creation of competitive analysis workflow.""" - from examples.stagehand.stagehand_multi_agent_workflow import create_competitive_analysis_workflow - + from examples.stagehand.stagehand_multi_agent_workflow import ( + create_competitive_analysis_workflow, + ) + workflow = create_competitive_analysis_workflow() - + # Should have 3 agents in the rearrange pattern assert len(workflow.agents) == 3 - assert workflow.flow == "company_researcher -> social_media_agent -> report_compiler" - - @patch('examples.stagehand.stagehand_wrapper_agent.Stagehand', MockStagehand) + assert ( + workflow.flow + == "company_researcher -> social_media_agent -> report_compiler" + ) + + @patch( + "examples.stagehand.stagehand_wrapper_agent.Stagehand", + MockStagehand, + ) def test_automated_testing_workflow_creation(self): """Test creation of automated testing workflow.""" - from examples.stagehand.stagehand_multi_agent_workflow import create_automated_testing_workflow - + from examples.stagehand.stagehand_multi_agent_workflow import ( + create_automated_testing_workflow, + ) + workflow = create_automated_testing_workflow() - + # Should be a SequentialWorkflow assert len(workflow.agents) == 2 # First should be concurrent testing - assert hasattr(workflow.agents[0], 'agents') - assert len(workflow.agents[0].agents) == 3 # UI, Form, Accessibility testers - - @patch('examples.stagehand.stagehand_wrapper_agent.Stagehand', MockStagehand) + assert hasattr(workflow.agents[0], "agents") + assert ( + len(workflow.agents[0].agents) == 3 + ) # UI, Form, Accessibility testers + + @patch( + "examples.stagehand.stagehand_wrapper_agent.Stagehand", + MockStagehand, + ) def test_news_aggregation_workflow_creation(self): """Test creation of news aggregation workflow.""" - from examples.stagehand.stagehand_multi_agent_workflow import create_news_aggregation_workflow - + from examples.stagehand.stagehand_multi_agent_workflow import ( + create_news_aggregation_workflow, + ) + workflow = create_news_aggregation_workflow() - + # Should be a SequentialWorkflow with 3 stages assert len(workflow.agents) == 3 # First stage should be concurrent scrapers - assert hasattr(workflow.agents[0], 'agents') + assert hasattr(workflow.agents[0], "agents") assert len(workflow.agents[0].agents) == 3 # 3 news sources # Integration tests class TestIntegration: """End-to-end integration tests.""" - + @pytest.mark.asyncio - @patch('examples.stagehand.stagehand_wrapper_agent.Stagehand', MockStagehand) + @patch( + "examples.stagehand.stagehand_wrapper_agent.Stagehand", + MockStagehand, + ) async def test_full_browser_automation_flow(self): """Test a complete browser automation flow.""" - from examples.stagehand.stagehand_wrapper_agent import StagehandAgent - + from examples.stagehand.stagehand_wrapper_agent import ( + StagehandAgent, + ) + agent = StagehandAgent( agent_name="IntegrationTestAgent", model_name="gpt-4o-mini", env="LOCAL", ) - + # Test navigation nav_result = agent.run("Navigate to example.com") assert "navigated_to" in nav_result - + # Test extraction extract_result = agent.run("Extract all text from the page") assert "extracted" in extract_result - + # Test observation observe_result = agent.run("Find all buttons on the page") assert "observation" in observe_result - + # Cleanup agent.cleanup() if __name__ == "__main__": - pytest.main([__file__, "-v"]) \ No newline at end of file + pytest.main([__file__, "-v"]) diff --git a/tests/stagehand/test_stagehand_simple.py b/tests/stagehand/test_stagehand_simple.py new file mode 100644 index 00000000..9a220f1c --- /dev/null +++ b/tests/stagehand/test_stagehand_simple.py @@ -0,0 +1,304 @@ +""" +Simple tests for Stagehand Integration with Swarms +================================================= + +These tests verify the basic structure and functionality of the +Stagehand integration without requiring external dependencies. +""" + +import json +import pytest +from unittest.mock import MagicMock + + +class TestStagehandIntegrationStructure: + """Test that integration files have correct structure.""" + + def test_examples_directory_exists(self): + """Test that examples directory structure is correct.""" + import os + + base_path = "examples/stagehand" + assert os.path.exists(base_path) + + expected_files = [ + "1_stagehand_wrapper_agent.py", + "2_stagehand_tools_agent.py", + "3_stagehand_mcp_agent.py", + "4_stagehand_multi_agent_workflow.py", + "README.md", + "requirements.txt", + ] + + for file in expected_files: + file_path = os.path.join(base_path, file) + assert os.path.exists(file_path), f"Missing file: {file}" + + def test_wrapper_agent_imports(self): + """Test that wrapper agent has correct imports.""" + with open( + "examples/stagehand/1_stagehand_wrapper_agent.py", "r" + ) as f: + content = f.read() + + # Check for required imports + assert "from swarms import Agent" in content + assert "import asyncio" in content + assert "import json" in content + assert "class StagehandAgent" in content + + def test_tools_agent_imports(self): + """Test that tools agent has correct imports.""" + with open( + "examples/stagehand/2_stagehand_tools_agent.py", "r" + ) as f: + content = f.read() + + # Check for required imports + assert ( + "from swarms.tools.base_tool import BaseTool" in content + ) + assert "class NavigateTool" in content + assert "class ActTool" in content + assert "class ExtractTool" in content + + def test_mcp_agent_imports(self): + """Test that MCP agent has correct imports.""" + with open( + "examples/stagehand/3_stagehand_mcp_agent.py", "r" + ) as f: + content = f.read() + + # Check for required imports + assert "from swarms import Agent" in content + assert "class StagehandMCPAgent" in content + assert "mcp_url" in content + + def test_workflow_agent_imports(self): + """Test that workflow agent has correct imports.""" + with open( + "examples/stagehand/4_stagehand_multi_agent_workflow.py", + "r", + ) as f: + content = f.read() + + # Check for required imports + assert ( + "from swarms import Agent, SequentialWorkflow, ConcurrentWorkflow" + in content + ) + assert ( + "from swarms.structs.agent_rearrange import AgentRearrange" + in content + ) + + +class TestStagehandMockIntegration: + """Test Stagehand integration with mocked dependencies.""" + + def test_mock_stagehand_initialization(self): + """Test that Stagehand can be mocked and initialized.""" + + # Setup mock without importing actual stagehand + mock_stagehand = MagicMock() + mock_instance = MagicMock() + mock_instance.init = MagicMock() + mock_stagehand.return_value = mock_instance + + # Mock config creation + config = MagicMock() + stagehand_instance = mock_stagehand(config) + + # Verify mock works + assert stagehand_instance is not None + assert hasattr(stagehand_instance, "init") + + def test_json_serialization(self): + """Test JSON serialization for agent responses.""" + + # Test data that would come from browser automation + test_data = { + "task": "Navigate to example.com", + "status": "completed", + "data": { + "navigated_to": "https://example.com", + "extracted": ["item1", "item2"], + "action": "navigate", + }, + } + + # Test serialization + json_result = json.dumps(test_data, indent=2) + assert isinstance(json_result, str) + + # Test deserialization + parsed_data = json.loads(json_result) + assert parsed_data["task"] == "Navigate to example.com" + assert parsed_data["status"] == "completed" + assert len(parsed_data["data"]["extracted"]) == 2 + + def test_url_extraction_logic(self): + """Test URL extraction logic from task strings.""" + import re + + # Test cases + test_cases = [ + ( + "Navigate to https://example.com", + ["https://example.com"], + ), + ("Go to google.com and search", ["google.com"]), + ( + "Visit https://github.com/repo", + ["https://github.com/repo"], + ), + ("Open example.org", ["example.org"]), + ] + + url_pattern = r"https?://[^\s]+" + domain_pattern = r"(\w+\.\w+)" + + for task, expected in test_cases: + # Extract full URLs + urls = re.findall(url_pattern, task) + + # If no full URLs, extract domains + if not urls: + domains = re.findall(domain_pattern, task) + if domains: + urls = domains + + assert ( + len(urls) > 0 + ), f"Failed to extract URL from: {task}" + assert ( + urls[0] in expected + ), f"Expected {expected}, got {urls}" + + +class TestSwarmsPatternsCompliance: + """Test compliance with Swarms framework patterns.""" + + def test_agent_inheritance_pattern(self): + """Test that wrapper agent follows Swarms Agent inheritance pattern.""" + + # Read the wrapper agent file + with open( + "examples/stagehand/1_stagehand_wrapper_agent.py", "r" + ) as f: + content = f.read() + + # Check inheritance pattern + assert "class StagehandAgent(SwarmsAgent):" in content + assert "def run(self, task: str" in content + assert "return" in content + + def test_tools_pattern(self): + """Test that tools follow Swarms BaseTool pattern.""" + + # Read the tools agent file + with open( + "examples/stagehand/2_stagehand_tools_agent.py", "r" + ) as f: + content = f.read() + + # Check tool pattern + assert "class NavigateTool(BaseTool):" in content + assert "def run(self," in content + assert "name=" in content + assert "description=" in content + + def test_mcp_integration_pattern(self): + """Test MCP integration follows Swarms pattern.""" + + # Read the MCP agent file + with open( + "examples/stagehand/3_stagehand_mcp_agent.py", "r" + ) as f: + content = f.read() + + # Check MCP pattern + assert "mcp_url=" in content + assert "Agent(" in content + + def test_workflow_patterns(self): + """Test workflow patterns are properly used.""" + + # Read the workflow file + with open( + "examples/stagehand/4_stagehand_multi_agent_workflow.py", + "r", + ) as f: + content = f.read() + + # Check workflow patterns + assert "SequentialWorkflow" in content + assert "ConcurrentWorkflow" in content + assert "AgentRearrange" in content + + +class TestDocumentationAndExamples: + """Test documentation and example completeness.""" + + def test_readme_completeness(self): + """Test that README contains essential information.""" + + with open("examples/stagehand/README.md", "r") as f: + content = f.read() + + required_sections = [ + "# Stagehand Browser Automation Integration", + "## Overview", + "## Examples", + "## Setup", + "## Use Cases", + "## Best Practices", + ] + + for section in required_sections: + assert section in content, f"Missing section: {section}" + + def test_requirements_file(self): + """Test that requirements file has necessary dependencies.""" + + with open("examples/stagehand/requirements.txt", "r") as f: + content = f.read() + + required_deps = [ + "swarms", + "stagehand", + "python-dotenv", + "pydantic", + "loguru", + ] + + for dep in required_deps: + assert dep in content, f"Missing dependency: {dep}" + + def test_example_files_have_docstrings(self): + """Test that example files have proper docstrings.""" + + example_files = [ + "examples/stagehand/1_stagehand_wrapper_agent.py", + "examples/stagehand/2_stagehand_tools_agent.py", + "examples/stagehand/3_stagehand_mcp_agent.py", + "examples/stagehand/4_stagehand_multi_agent_workflow.py", + ] + + for file_path in example_files: + with open(file_path, "r") as f: + content = f.read() + + # Check for module docstring + assert ( + '"""' in content[:500] + ), f"Missing docstring in {file_path}" + + # Check for main execution block + assert ( + 'if __name__ == "__main__":' in content + ), f"Missing main block in {file_path}" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])