pull/1000/head
Filip Michalsky 1 month ago
parent b04e60ca17
commit 2d7dfca4a4

@ -12,7 +12,7 @@ and implements browser automation through natural language commands.
import asyncio import asyncio
import json import json
import os import os
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, Optional
from dotenv import load_dotenv from dotenv import load_dotenv
from loguru import logger from loguru import logger
@ -75,7 +75,8 @@ class StagehandAgent(SwarmsAgent):
project_id=browserbase_project_id project_id=browserbase_project_id
or os.getenv("BROWSERBASE_PROJECT_ID"), or os.getenv("BROWSERBASE_PROJECT_ID"),
model_name=model_name, model_name=model_name,
model_api_key=model_api_key or os.getenv("OPENAI_API_KEY"), model_api_key=model_api_key
or os.getenv("OPENAI_API_KEY"),
) )
self.stagehand = None self.stagehand = None
self._initialized = False self._initialized = False
@ -86,7 +87,9 @@ class StagehandAgent(SwarmsAgent):
self.stagehand = Stagehand(self.stagehand_config) self.stagehand = Stagehand(self.stagehand_config)
await self.stagehand.init() await self.stagehand.init()
self._initialized = True self._initialized = True
logger.info(f"Stagehand initialized for {self.agent_name}") logger.info(
f"Stagehand initialized for {self.agent_name}"
)
async def _close_stagehand(self): async def _close_stagehand(self):
"""Close Stagehand instance.""" """Close Stagehand instance."""
@ -112,9 +115,7 @@ class StagehandAgent(SwarmsAgent):
""" """
return asyncio.run(self._async_run(task, *args, **kwargs)) return asyncio.run(self._async_run(task, *args, **kwargs))
async def _async_run( async def _async_run(self, task: str, *args, **kwargs) -> str:
self, task: str, *args, **kwargs
) -> str:
"""Async implementation of run method.""" """Async implementation of run method."""
try: try:
await self._init_stagehand() await self._init_stagehand()
@ -183,9 +184,13 @@ class StagehandAgent(SwarmsAgent):
elif "search" in task.lower(): elif "search" in task.lower():
# Perform search action # Perform search action
search_query = task.split("search for")[-1].strip().strip("'\"") search_query = (
task.split("search for")[-1].strip().strip("'\"")
)
# First, find the search box # First, find the search box
search_box = await page.observe("find the search input field") search_box = await page.observe(
"find the search input field"
)
if search_box: if search_box:
# Click on search box and type # Click on search box and type
await page.act(f"click on {search_box[0]}") await page.act(f"click on {search_box[0]}")
@ -198,7 +203,10 @@ class StagehandAgent(SwarmsAgent):
# Perform observation # Perform observation
observation = await page.observe(task) observation = await page.observe(task)
result["data"]["observation"] = [ result["data"]["observation"] = [
{"description": obs.description, "selector": obs.selector} {
"description": obs.description,
"selector": obs.selector,
}
for obs in observation for obs in observation
] ]
result["action"] = "observe" result["action"] = "observe"

@ -13,11 +13,10 @@ automation tasks.
import asyncio import asyncio
import json import json
import os import os
from typing import Any, Dict, List, Optional, Union from typing import Optional
from dotenv import load_dotenv from dotenv import load_dotenv
from loguru import logger from loguru import logger
from pydantic import BaseModel, Field
from swarms import Agent from swarms import Agent
from swarms.tools.base_tool import BaseTool from swarms.tools.base_tool import BaseTool
@ -51,9 +50,11 @@ class BrowserState:
config = StagehandConfig( config = StagehandConfig(
env=env, env=env,
api_key=api_key or os.getenv("BROWSERBASE_API_KEY"), api_key=api_key or os.getenv("BROWSERBASE_API_KEY"),
project_id=project_id or os.getenv("BROWSERBASE_PROJECT_ID"), project_id=project_id
or os.getenv("BROWSERBASE_PROJECT_ID"),
model_name=model_name, model_name=model_name,
model_api_key=model_api_key or os.getenv("OPENAI_API_KEY"), model_api_key=model_api_key
or os.getenv("OPENAI_API_KEY"),
) )
self._stagehand = Stagehand(config) self._stagehand = Stagehand(config)
await self._stagehand.init() await self._stagehand.init()
@ -63,7 +64,9 @@ class BrowserState:
async def get_page(self): async def get_page(self):
"""Get the current page instance.""" """Get the current page instance."""
if not self._initialized: if not self._initialized:
raise RuntimeError("Browser not initialized. Call init_browser first.") raise RuntimeError(
"Browser not initialized. Call init_browser first."
)
return self._stagehand.page return self._stagehand.page
async def close(self): async def close(self):
@ -202,11 +205,13 @@ class ObserveTool(BaseTool):
# Format observations for readability # Format observations for readability
result = [] result = []
for obs in observations: for obs in observations:
result.append({ result.append(
"description": obs.description, {
"selector": obs.selector, "description": obs.description,
"method": obs.method "selector": obs.selector,
}) "method": obs.method,
}
)
return json.dumps(result, indent=2) return json.dumps(result, indent=2)
except Exception as e: except Exception as e:

@ -22,9 +22,7 @@ Features:
- Prompt templates for common tasks - Prompt templates for common tasks
""" """
import asyncio from typing import List
import os
from typing import List, Optional
from dotenv import load_dotenv from dotenv import load_dotenv
from loguru import logger from loguru import logger
@ -107,9 +105,18 @@ class MultiSessionBrowserSwarm:
# Create specialized agents for different tasks # Create specialized agents for different tasks
agent_roles = [ agent_roles = [
("DataExtractor", "You specialize in extracting structured data from websites."), (
("FormFiller", "You specialize in filling out forms and interacting with web applications."), "DataExtractor",
("WebMonitor", "You specialize in monitoring websites for changes and capturing screenshots."), "You specialize in extracting structured data from websites.",
),
(
"FormFiller",
"You specialize in filling out forms and interacting with web applications.",
),
(
"WebMonitor",
"You specialize in monitoring websites for changes and capturing screenshots.",
),
] ]
for i in range(min(num_agents, len(agent_roles))): for i in range(min(num_agents, len(agent_roles))):
@ -143,7 +150,9 @@ Always create your own session for tasks to work independently from other agents
agent_idx = i % len(self.agents) agent_idx = i % len(self.agents)
agent = self.agents[agent_idx] agent = self.agents[agent_idx]
logger.info(f"Assigning task to {agent.agent_name}: {task}") logger.info(
f"Assigning task to {agent.agent_name}: {task}"
)
result = agent.run(task) result = agent.run(task)
results.append(result) results.append(result)
@ -155,7 +164,9 @@ if __name__ == "__main__":
print("=" * 70) print("=" * 70)
print("Stagehand MCP Server Integration Examples") print("Stagehand MCP Server Integration Examples")
print("=" * 70) print("=" * 70)
print("\nMake sure the Stagehand MCP server is running on http://localhost:3000/sse") print(
"\nMake sure the Stagehand MCP server is running on http://localhost:3000/sse"
)
print("Run: cd stagehand-mcp-server && npm start\n") print("Run: cd stagehand-mcp-server && npm start\n")
# Example 1: Single agent with MCP tools # Example 1: Single agent with MCP tools

@ -13,14 +13,10 @@ Use cases:
4. Data aggregation from multiple sources 4. Data aggregation from multiple sources
""" """
import asyncio
import json
import os
from datetime import datetime from datetime import datetime
from typing import Any, Dict, List from typing import Dict, List, Optional
from dotenv import load_dotenv from dotenv import load_dotenv
from loguru import logger
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from swarms import Agent, SequentialWorkflow, ConcurrentWorkflow from swarms import Agent, SequentialWorkflow, ConcurrentWorkflow
@ -33,19 +29,29 @@ load_dotenv()
# Pydantic models for structured data # Pydantic models for structured data
class ProductInfo(BaseModel): class ProductInfo(BaseModel):
"""Product information schema.""" """Product information schema."""
name: str = Field(..., description="Product name") name: str = Field(..., description="Product name")
price: float = Field(..., description="Product price") price: float = Field(..., description="Product price")
availability: str = Field(..., description="Availability status") availability: str = Field(..., description="Availability status")
url: str = Field(..., description="Product URL") url: str = Field(..., description="Product URL")
screenshot_path: Optional[str] = Field(None, description="Screenshot file path") screenshot_path: Optional[str] = Field(
None, description="Screenshot file path"
)
class MarketAnalysis(BaseModel): class MarketAnalysis(BaseModel):
"""Market analysis report schema.""" """Market analysis report schema."""
timestamp: datetime = Field(default_factory=datetime.now) timestamp: datetime = Field(default_factory=datetime.now)
products: List[ProductInfo] = Field(..., description="List of products analyzed") products: List[ProductInfo] = Field(
price_range: Dict[str, float] = Field(..., description="Min and max prices") ..., description="List of products analyzed"
recommendations: List[str] = Field(..., description="Analysis recommendations") )
price_range: Dict[str, float] = Field(
..., description="Min and max prices"
)
recommendations: List[str] = Field(
..., description="Analysis recommendations"
)
# Specialized browser agents # Specialized browser agents
@ -54,9 +60,7 @@ class ProductScraperAgent(StagehandAgent):
def __init__(self, site_name: str, *args, **kwargs): def __init__(self, site_name: str, *args, **kwargs):
super().__init__( super().__init__(
agent_name=f"ProductScraper_{site_name}", agent_name=f"ProductScraper_{site_name}", *args, **kwargs
*args,
**kwargs
) )
self.site_name = site_name self.site_name = site_name
@ -66,9 +70,7 @@ class PriceMonitorAgent(StagehandAgent):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__( super().__init__(
agent_name="PriceMonitorAgent", agent_name="PriceMonitorAgent", *args, **kwargs
*args,
**kwargs
) )
@ -146,10 +148,16 @@ def create_competitive_analysis_workflow():
) )
# Create agent rearrange for flexible routing # Create agent rearrange for flexible routing
workflow_pattern = "company_researcher -> social_media_agent -> report_compiler" workflow_pattern = (
"company_researcher -> social_media_agent -> report_compiler"
)
competitive_workflow = AgentRearrange( competitive_workflow = AgentRearrange(
agents=[company_researcher, social_media_agent, report_compiler], agents=[
company_researcher,
social_media_agent,
report_compiler,
],
flow=workflow_pattern, flow=workflow_pattern,
verbose=True, verbose=True,
) )
@ -257,7 +265,11 @@ def create_news_aggregation_workflow():
) )
analysis_workflow = SequentialWorkflow( analysis_workflow = SequentialWorkflow(
agents=[scraping_workflow, sentiment_analyzer, trend_identifier], agents=[
scraping_workflow,
sentiment_analyzer,
trend_identifier,
],
max_loops=1, max_loops=1,
verbose=True, verbose=True,
) )
@ -350,7 +362,7 @@ if __name__ == "__main__":
for agent in price_workflow.agents: for agent in price_workflow.agents:
if isinstance(agent, StagehandAgent): if isinstance(agent, StagehandAgent):
agent.cleanup() agent.cleanup()
elif hasattr(agent, 'agents'): # For nested workflows elif hasattr(agent, "agents"): # For nested workflows
for sub_agent in agent.agents: for sub_agent in agent.agents:
if isinstance(sub_agent, StagehandAgent): if isinstance(sub_agent, StagehandAgent):
sub_agent.cleanup() sub_agent.cleanup()

@ -6,13 +6,9 @@ This module contains tests for the Stagehand browser automation
integration with the Swarms framework. integration with the Swarms framework.
""" """
import asyncio
import json import json
import pytest import pytest
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, patch
from swarms import Agent
from swarms.tools.base_tool import BaseTool
# Mock Stagehand classes # Mock Stagehand classes
@ -56,10 +52,15 @@ class MockStagehand:
class TestStagehandAgent: class TestStagehandAgent:
"""Test the StagehandAgent wrapper class.""" """Test the StagehandAgent wrapper class."""
@patch('examples.stagehand.stagehand_wrapper_agent.Stagehand', MockStagehand) @patch(
"examples.stagehand.stagehand_wrapper_agent.Stagehand",
MockStagehand,
)
def test_agent_initialization(self): def test_agent_initialization(self):
"""Test that StagehandAgent initializes correctly.""" """Test that StagehandAgent initializes correctly."""
from examples.stagehand.stagehand_wrapper_agent import StagehandAgent from examples.stagehand.stagehand_wrapper_agent import (
StagehandAgent,
)
agent = StagehandAgent( agent = StagehandAgent(
agent_name="TestAgent", agent_name="TestAgent",
@ -72,10 +73,15 @@ class TestStagehandAgent:
assert agent.stagehand_config.model_name == "gpt-4o-mini" assert agent.stagehand_config.model_name == "gpt-4o-mini"
assert not agent._initialized assert not agent._initialized
@patch('examples.stagehand.stagehand_wrapper_agent.Stagehand', MockStagehand) @patch(
"examples.stagehand.stagehand_wrapper_agent.Stagehand",
MockStagehand,
)
def test_navigation_task(self): def test_navigation_task(self):
"""Test navigation and extraction task.""" """Test navigation and extraction task."""
from examples.stagehand.stagehand_wrapper_agent import StagehandAgent from examples.stagehand.stagehand_wrapper_agent import (
StagehandAgent,
)
agent = StagehandAgent( agent = StagehandAgent(
agent_name="TestAgent", agent_name="TestAgent",
@ -83,19 +89,29 @@ class TestStagehandAgent:
env="LOCAL", env="LOCAL",
) )
result = agent.run("Navigate to example.com and extract the main content") result = agent.run(
"Navigate to example.com and extract the main content"
)
# Parse result # Parse result
result_data = json.loads(result) result_data = json.loads(result)
assert result_data["status"] == "completed" assert result_data["status"] == "completed"
assert "navigated_to" in result_data["data"] assert "navigated_to" in result_data["data"]
assert result_data["data"]["navigated_to"] == "https://example.com" assert (
result_data["data"]["navigated_to"]
== "https://example.com"
)
assert "extracted" in result_data["data"] assert "extracted" in result_data["data"]
@patch('examples.stagehand.stagehand_wrapper_agent.Stagehand', MockStagehand) @patch(
"examples.stagehand.stagehand_wrapper_agent.Stagehand",
MockStagehand,
)
def test_search_task(self): def test_search_task(self):
"""Test search functionality.""" """Test search functionality."""
from examples.stagehand.stagehand_wrapper_agent import StagehandAgent from examples.stagehand.stagehand_wrapper_agent import (
StagehandAgent,
)
agent = StagehandAgent( agent = StagehandAgent(
agent_name="TestAgent", agent_name="TestAgent",
@ -103,17 +119,24 @@ class TestStagehandAgent:
env="LOCAL", env="LOCAL",
) )
result = agent.run("Go to google.com and search for 'test query'") result = agent.run(
"Go to google.com and search for 'test query'"
)
result_data = json.loads(result) result_data = json.loads(result)
assert result_data["status"] == "completed" assert result_data["status"] == "completed"
assert result_data["data"]["search_query"] == "test query" assert result_data["data"]["search_query"] == "test query"
assert result_data["action"] == "search" assert result_data["action"] == "search"
@patch('examples.stagehand.stagehand_wrapper_agent.Stagehand', MockStagehand) @patch(
"examples.stagehand.stagehand_wrapper_agent.Stagehand",
MockStagehand,
)
def test_cleanup(self): def test_cleanup(self):
"""Test that cleanup properly closes browser.""" """Test that cleanup properly closes browser."""
from examples.stagehand.stagehand_wrapper_agent import StagehandAgent from examples.stagehand.stagehand_wrapper_agent import (
StagehandAgent,
)
agent = StagehandAgent( agent = StagehandAgent(
agent_name="TestAgent", agent_name="TestAgent",
@ -137,23 +160,29 @@ class TestStagehandAgent:
class TestStagehandTools: class TestStagehandTools:
"""Test individual Stagehand tools.""" """Test individual Stagehand tools."""
@patch('examples.stagehand.stagehand_tools_agent.browser_state') @patch("examples.stagehand.stagehand_tools_agent.browser_state")
async def test_navigate_tool(self, mock_browser_state): async def test_navigate_tool(self, mock_browser_state):
"""Test NavigateTool functionality.""" """Test NavigateTool functionality."""
from examples.stagehand.stagehand_tools_agent import NavigateTool from examples.stagehand.stagehand_tools_agent import (
NavigateTool,
)
# Setup mock # Setup mock
mock_page = AsyncMock() mock_page = AsyncMock()
mock_browser_state.get_page = AsyncMock(return_value=mock_page) mock_browser_state.get_page = AsyncMock(
return_value=mock_page
)
mock_browser_state.init_browser = AsyncMock() mock_browser_state.init_browser = AsyncMock()
tool = NavigateTool() tool = NavigateTool()
result = await tool._async_run("https://example.com") result = await tool._async_run("https://example.com")
assert "Successfully navigated to https://example.com" in result assert (
"Successfully navigated to https://example.com" in result
)
mock_page.goto.assert_called_once_with("https://example.com") mock_page.goto.assert_called_once_with("https://example.com")
@patch('examples.stagehand.stagehand_tools_agent.browser_state') @patch("examples.stagehand.stagehand_tools_agent.browser_state")
async def test_act_tool(self, mock_browser_state): async def test_act_tool(self, mock_browser_state):
"""Test ActTool functionality.""" """Test ActTool functionality."""
from examples.stagehand.stagehand_tools_agent import ActTool from examples.stagehand.stagehand_tools_agent import ActTool
@ -161,7 +190,9 @@ class TestStagehandTools:
# Setup mock # Setup mock
mock_page = AsyncMock() mock_page = AsyncMock()
mock_page.act = AsyncMock(return_value="Action completed") mock_page.act = AsyncMock(return_value="Action completed")
mock_browser_state.get_page = AsyncMock(return_value=mock_page) mock_browser_state.get_page = AsyncMock(
return_value=mock_page
)
mock_browser_state.init_browser = AsyncMock() mock_browser_state.init_browser = AsyncMock()
tool = ActTool() tool = ActTool()
@ -171,15 +202,24 @@ class TestStagehandTools:
assert "click the button" in result assert "click the button" in result
mock_page.act.assert_called_once_with("click the button") mock_page.act.assert_called_once_with("click the button")
@patch('examples.stagehand.stagehand_tools_agent.browser_state') @patch("examples.stagehand.stagehand_tools_agent.browser_state")
async def test_extract_tool(self, mock_browser_state): async def test_extract_tool(self, mock_browser_state):
"""Test ExtractTool functionality.""" """Test ExtractTool functionality."""
from examples.stagehand.stagehand_tools_agent import ExtractTool from examples.stagehand.stagehand_tools_agent import (
ExtractTool,
)
# Setup mock # Setup mock
mock_page = AsyncMock() mock_page = AsyncMock()
mock_page.extract = AsyncMock(return_value={"title": "Test Page", "content": "Test content"}) mock_page.extract = AsyncMock(
mock_browser_state.get_page = AsyncMock(return_value=mock_page) return_value={
"title": "Test Page",
"content": "Test content",
}
)
mock_browser_state.get_page = AsyncMock(
return_value=mock_page
)
mock_browser_state.init_browser = AsyncMock() mock_browser_state.init_browser = AsyncMock()
tool = ExtractTool() tool = ExtractTool()
@ -190,10 +230,12 @@ class TestStagehandTools:
assert parsed_result["title"] == "Test Page" assert parsed_result["title"] == "Test Page"
assert parsed_result["content"] == "Test content" assert parsed_result["content"] == "Test content"
@patch('examples.stagehand.stagehand_tools_agent.browser_state') @patch("examples.stagehand.stagehand_tools_agent.browser_state")
async def test_observe_tool(self, mock_browser_state): async def test_observe_tool(self, mock_browser_state):
"""Test ObserveTool functionality.""" """Test ObserveTool functionality."""
from examples.stagehand.stagehand_tools_agent import ObserveTool from examples.stagehand.stagehand_tools_agent import (
ObserveTool,
)
# Setup mock # Setup mock
mock_page = AsyncMock() mock_page = AsyncMock()
@ -202,7 +244,9 @@ class TestStagehandTools:
MockObserveResult("Submit button", "#submit"), MockObserveResult("Submit button", "#submit"),
] ]
mock_page.observe = AsyncMock(return_value=mock_observations) mock_page.observe = AsyncMock(return_value=mock_observations)
mock_browser_state.get_page = AsyncMock(return_value=mock_page) mock_browser_state.get_page = AsyncMock(
return_value=mock_page
)
mock_browser_state.init_browser = AsyncMock() mock_browser_state.init_browser = AsyncMock()
tool = ObserveTool() tool = ObserveTool()
@ -221,7 +265,9 @@ class TestStagehandMCP:
def test_mcp_agent_initialization(self): def test_mcp_agent_initialization(self):
"""Test that MCP agent initializes with correct parameters.""" """Test that MCP agent initializes with correct parameters."""
from examples.stagehand.stagehand_mcp_agent import StagehandMCPAgent from examples.stagehand.stagehand_mcp_agent import (
StagehandMCPAgent,
)
mcp_agent = StagehandMCPAgent( mcp_agent = StagehandMCPAgent(
agent_name="TestMCPAgent", agent_name="TestMCPAgent",
@ -235,7 +281,9 @@ class TestStagehandMCP:
def test_multi_session_swarm_creation(self): def test_multi_session_swarm_creation(self):
"""Test multi-session browser swarm creation.""" """Test multi-session browser swarm creation."""
from examples.stagehand.stagehand_mcp_agent import MultiSessionBrowserSwarm from examples.stagehand.stagehand_mcp_agent import (
MultiSessionBrowserSwarm,
)
swarm = MultiSessionBrowserSwarm( swarm = MultiSessionBrowserSwarm(
mcp_server_url="http://localhost:3000/sse", mcp_server_url="http://localhost:3000/sse",
@ -247,10 +295,12 @@ class TestStagehandMCP:
assert swarm.agents[1].agent_name == "FormFiller_1" assert swarm.agents[1].agent_name == "FormFiller_1"
assert swarm.agents[2].agent_name == "WebMonitor_2" assert swarm.agents[2].agent_name == "WebMonitor_2"
@patch('swarms.Agent.run') @patch("swarms.Agent.run")
def test_task_distribution(self, mock_run): def test_task_distribution(self, mock_run):
"""Test task distribution among swarm agents.""" """Test task distribution among swarm agents."""
from examples.stagehand.stagehand_mcp_agent import MultiSessionBrowserSwarm from examples.stagehand.stagehand_mcp_agent import (
MultiSessionBrowserSwarm,
)
mock_run.return_value = "Task completed" mock_run.return_value = "Task completed"
@ -268,55 +318,80 @@ class TestStagehandMCP:
class TestMultiAgentWorkflows: class TestMultiAgentWorkflows:
"""Test multi-agent workflow configurations.""" """Test multi-agent workflow configurations."""
@patch('examples.stagehand.stagehand_wrapper_agent.Stagehand', MockStagehand) @patch(
"examples.stagehand.stagehand_wrapper_agent.Stagehand",
MockStagehand,
)
def test_price_comparison_workflow_creation(self): def test_price_comparison_workflow_creation(self):
"""Test creation of price comparison workflow.""" """Test creation of price comparison workflow."""
from examples.stagehand.stagehand_multi_agent_workflow import create_price_comparison_workflow from examples.stagehand.stagehand_multi_agent_workflow import (
create_price_comparison_workflow,
)
workflow = create_price_comparison_workflow() workflow = create_price_comparison_workflow()
# Should be a SequentialWorkflow with 2 agents # Should be a SequentialWorkflow with 2 agents
assert len(workflow.agents) == 2 assert len(workflow.agents) == 2
# First agent should be a ConcurrentWorkflow # First agent should be a ConcurrentWorkflow
assert hasattr(workflow.agents[0], 'agents') assert hasattr(workflow.agents[0], "agents")
# Second agent should be the analysis agent # Second agent should be the analysis agent
assert workflow.agents[1].agent_name == "PriceAnalysisAgent" assert workflow.agents[1].agent_name == "PriceAnalysisAgent"
@patch('examples.stagehand.stagehand_wrapper_agent.Stagehand', MockStagehand) @patch(
"examples.stagehand.stagehand_wrapper_agent.Stagehand",
MockStagehand,
)
def test_competitive_analysis_workflow_creation(self): def test_competitive_analysis_workflow_creation(self):
"""Test creation of competitive analysis workflow.""" """Test creation of competitive analysis workflow."""
from examples.stagehand.stagehand_multi_agent_workflow import create_competitive_analysis_workflow from examples.stagehand.stagehand_multi_agent_workflow import (
create_competitive_analysis_workflow,
)
workflow = create_competitive_analysis_workflow() workflow = create_competitive_analysis_workflow()
# Should have 3 agents in the rearrange pattern # Should have 3 agents in the rearrange pattern
assert len(workflow.agents) == 3 assert len(workflow.agents) == 3
assert workflow.flow == "company_researcher -> social_media_agent -> report_compiler" assert (
workflow.flow
== "company_researcher -> social_media_agent -> report_compiler"
)
@patch('examples.stagehand.stagehand_wrapper_agent.Stagehand', MockStagehand) @patch(
"examples.stagehand.stagehand_wrapper_agent.Stagehand",
MockStagehand,
)
def test_automated_testing_workflow_creation(self): def test_automated_testing_workflow_creation(self):
"""Test creation of automated testing workflow.""" """Test creation of automated testing workflow."""
from examples.stagehand.stagehand_multi_agent_workflow import create_automated_testing_workflow from examples.stagehand.stagehand_multi_agent_workflow import (
create_automated_testing_workflow,
)
workflow = create_automated_testing_workflow() workflow = create_automated_testing_workflow()
# Should be a SequentialWorkflow # Should be a SequentialWorkflow
assert len(workflow.agents) == 2 assert len(workflow.agents) == 2
# First should be concurrent testing # First should be concurrent testing
assert hasattr(workflow.agents[0], 'agents') assert hasattr(workflow.agents[0], "agents")
assert len(workflow.agents[0].agents) == 3 # UI, Form, Accessibility testers assert (
len(workflow.agents[0].agents) == 3
@patch('examples.stagehand.stagehand_wrapper_agent.Stagehand', MockStagehand) ) # UI, Form, Accessibility testers
@patch(
"examples.stagehand.stagehand_wrapper_agent.Stagehand",
MockStagehand,
)
def test_news_aggregation_workflow_creation(self): def test_news_aggregation_workflow_creation(self):
"""Test creation of news aggregation workflow.""" """Test creation of news aggregation workflow."""
from examples.stagehand.stagehand_multi_agent_workflow import create_news_aggregation_workflow from examples.stagehand.stagehand_multi_agent_workflow import (
create_news_aggregation_workflow,
)
workflow = create_news_aggregation_workflow() workflow = create_news_aggregation_workflow()
# Should be a SequentialWorkflow with 3 stages # Should be a SequentialWorkflow with 3 stages
assert len(workflow.agents) == 3 assert len(workflow.agents) == 3
# First stage should be concurrent scrapers # First stage should be concurrent scrapers
assert hasattr(workflow.agents[0], 'agents') assert hasattr(workflow.agents[0], "agents")
assert len(workflow.agents[0].agents) == 3 # 3 news sources assert len(workflow.agents[0].agents) == 3 # 3 news sources
@ -325,10 +400,15 @@ class TestIntegration:
"""End-to-end integration tests.""" """End-to-end integration tests."""
@pytest.mark.asyncio @pytest.mark.asyncio
@patch('examples.stagehand.stagehand_wrapper_agent.Stagehand', MockStagehand) @patch(
"examples.stagehand.stagehand_wrapper_agent.Stagehand",
MockStagehand,
)
async def test_full_browser_automation_flow(self): async def test_full_browser_automation_flow(self):
"""Test a complete browser automation flow.""" """Test a complete browser automation flow."""
from examples.stagehand.stagehand_wrapper_agent import StagehandAgent from examples.stagehand.stagehand_wrapper_agent import (
StagehandAgent,
)
agent = StagehandAgent( agent = StagehandAgent(
agent_name="IntegrationTestAgent", agent_name="IntegrationTestAgent",

@ -0,0 +1,304 @@
"""
Simple tests for Stagehand Integration with Swarms
=================================================
These tests verify the basic structure and functionality of the
Stagehand integration without requiring external dependencies.
"""
import json
import pytest
from unittest.mock import MagicMock
class TestStagehandIntegrationStructure:
"""Test that integration files have correct structure."""
def test_examples_directory_exists(self):
"""Test that examples directory structure is correct."""
import os
base_path = "examples/stagehand"
assert os.path.exists(base_path)
expected_files = [
"1_stagehand_wrapper_agent.py",
"2_stagehand_tools_agent.py",
"3_stagehand_mcp_agent.py",
"4_stagehand_multi_agent_workflow.py",
"README.md",
"requirements.txt",
]
for file in expected_files:
file_path = os.path.join(base_path, file)
assert os.path.exists(file_path), f"Missing file: {file}"
def test_wrapper_agent_imports(self):
"""Test that wrapper agent has correct imports."""
with open(
"examples/stagehand/1_stagehand_wrapper_agent.py", "r"
) as f:
content = f.read()
# Check for required imports
assert "from swarms import Agent" in content
assert "import asyncio" in content
assert "import json" in content
assert "class StagehandAgent" in content
def test_tools_agent_imports(self):
"""Test that tools agent has correct imports."""
with open(
"examples/stagehand/2_stagehand_tools_agent.py", "r"
) as f:
content = f.read()
# Check for required imports
assert (
"from swarms.tools.base_tool import BaseTool" in content
)
assert "class NavigateTool" in content
assert "class ActTool" in content
assert "class ExtractTool" in content
def test_mcp_agent_imports(self):
"""Test that MCP agent has correct imports."""
with open(
"examples/stagehand/3_stagehand_mcp_agent.py", "r"
) as f:
content = f.read()
# Check for required imports
assert "from swarms import Agent" in content
assert "class StagehandMCPAgent" in content
assert "mcp_url" in content
def test_workflow_agent_imports(self):
"""Test that workflow agent has correct imports."""
with open(
"examples/stagehand/4_stagehand_multi_agent_workflow.py",
"r",
) as f:
content = f.read()
# Check for required imports
assert (
"from swarms import Agent, SequentialWorkflow, ConcurrentWorkflow"
in content
)
assert (
"from swarms.structs.agent_rearrange import AgentRearrange"
in content
)
class TestStagehandMockIntegration:
"""Test Stagehand integration with mocked dependencies."""
def test_mock_stagehand_initialization(self):
"""Test that Stagehand can be mocked and initialized."""
# Setup mock without importing actual stagehand
mock_stagehand = MagicMock()
mock_instance = MagicMock()
mock_instance.init = MagicMock()
mock_stagehand.return_value = mock_instance
# Mock config creation
config = MagicMock()
stagehand_instance = mock_stagehand(config)
# Verify mock works
assert stagehand_instance is not None
assert hasattr(stagehand_instance, "init")
def test_json_serialization(self):
"""Test JSON serialization for agent responses."""
# Test data that would come from browser automation
test_data = {
"task": "Navigate to example.com",
"status": "completed",
"data": {
"navigated_to": "https://example.com",
"extracted": ["item1", "item2"],
"action": "navigate",
},
}
# Test serialization
json_result = json.dumps(test_data, indent=2)
assert isinstance(json_result, str)
# Test deserialization
parsed_data = json.loads(json_result)
assert parsed_data["task"] == "Navigate to example.com"
assert parsed_data["status"] == "completed"
assert len(parsed_data["data"]["extracted"]) == 2
def test_url_extraction_logic(self):
"""Test URL extraction logic from task strings."""
import re
# Test cases
test_cases = [
(
"Navigate to https://example.com",
["https://example.com"],
),
("Go to google.com and search", ["google.com"]),
(
"Visit https://github.com/repo",
["https://github.com/repo"],
),
("Open example.org", ["example.org"]),
]
url_pattern = r"https?://[^\s]+"
domain_pattern = r"(\w+\.\w+)"
for task, expected in test_cases:
# Extract full URLs
urls = re.findall(url_pattern, task)
# If no full URLs, extract domains
if not urls:
domains = re.findall(domain_pattern, task)
if domains:
urls = domains
assert (
len(urls) > 0
), f"Failed to extract URL from: {task}"
assert (
urls[0] in expected
), f"Expected {expected}, got {urls}"
class TestSwarmsPatternsCompliance:
"""Test compliance with Swarms framework patterns."""
def test_agent_inheritance_pattern(self):
"""Test that wrapper agent follows Swarms Agent inheritance pattern."""
# Read the wrapper agent file
with open(
"examples/stagehand/1_stagehand_wrapper_agent.py", "r"
) as f:
content = f.read()
# Check inheritance pattern
assert "class StagehandAgent(SwarmsAgent):" in content
assert "def run(self, task: str" in content
assert "return" in content
def test_tools_pattern(self):
"""Test that tools follow Swarms BaseTool pattern."""
# Read the tools agent file
with open(
"examples/stagehand/2_stagehand_tools_agent.py", "r"
) as f:
content = f.read()
# Check tool pattern
assert "class NavigateTool(BaseTool):" in content
assert "def run(self," in content
assert "name=" in content
assert "description=" in content
def test_mcp_integration_pattern(self):
"""Test MCP integration follows Swarms pattern."""
# Read the MCP agent file
with open(
"examples/stagehand/3_stagehand_mcp_agent.py", "r"
) as f:
content = f.read()
# Check MCP pattern
assert "mcp_url=" in content
assert "Agent(" in content
def test_workflow_patterns(self):
"""Test workflow patterns are properly used."""
# Read the workflow file
with open(
"examples/stagehand/4_stagehand_multi_agent_workflow.py",
"r",
) as f:
content = f.read()
# Check workflow patterns
assert "SequentialWorkflow" in content
assert "ConcurrentWorkflow" in content
assert "AgentRearrange" in content
class TestDocumentationAndExamples:
"""Test documentation and example completeness."""
def test_readme_completeness(self):
"""Test that README contains essential information."""
with open("examples/stagehand/README.md", "r") as f:
content = f.read()
required_sections = [
"# Stagehand Browser Automation Integration",
"## Overview",
"## Examples",
"## Setup",
"## Use Cases",
"## Best Practices",
]
for section in required_sections:
assert section in content, f"Missing section: {section}"
def test_requirements_file(self):
"""Test that requirements file has necessary dependencies."""
with open("examples/stagehand/requirements.txt", "r") as f:
content = f.read()
required_deps = [
"swarms",
"stagehand",
"python-dotenv",
"pydantic",
"loguru",
]
for dep in required_deps:
assert dep in content, f"Missing dependency: {dep}"
def test_example_files_have_docstrings(self):
"""Test that example files have proper docstrings."""
example_files = [
"examples/stagehand/1_stagehand_wrapper_agent.py",
"examples/stagehand/2_stagehand_tools_agent.py",
"examples/stagehand/3_stagehand_mcp_agent.py",
"examples/stagehand/4_stagehand_multi_agent_workflow.py",
]
for file_path in example_files:
with open(file_path, "r") as f:
content = f.read()
# Check for module docstring
assert (
'"""' in content[:500]
), f"Missing docstring in {file_path}"
# Check for main execution block
assert (
'if __name__ == "__main__":' in content
), f"Missing main block in {file_path}"
if __name__ == "__main__":
pytest.main([__file__, "-v"])
Loading…
Cancel
Save