pull/1000/head
Filip Michalsky 1 month ago
parent b04e60ca17
commit 2d7dfca4a4

@ -12,7 +12,7 @@ and implements browser automation through natural language commands.
import asyncio
import json
import os
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, Optional
from dotenv import load_dotenv
from loguru import logger
@ -75,7 +75,8 @@ class StagehandAgent(SwarmsAgent):
project_id=browserbase_project_id
or os.getenv("BROWSERBASE_PROJECT_ID"),
model_name=model_name,
model_api_key=model_api_key or os.getenv("OPENAI_API_KEY"),
model_api_key=model_api_key
or os.getenv("OPENAI_API_KEY"),
)
self.stagehand = None
self._initialized = False
@ -86,7 +87,9 @@ class StagehandAgent(SwarmsAgent):
self.stagehand = Stagehand(self.stagehand_config)
await self.stagehand.init()
self._initialized = True
logger.info(f"Stagehand initialized for {self.agent_name}")
logger.info(
f"Stagehand initialized for {self.agent_name}"
)
async def _close_stagehand(self):
"""Close Stagehand instance."""
@ -112,9 +115,7 @@ class StagehandAgent(SwarmsAgent):
"""
return asyncio.run(self._async_run(task, *args, **kwargs))
async def _async_run(
self, task: str, *args, **kwargs
) -> str:
async def _async_run(self, task: str, *args, **kwargs) -> str:
"""Async implementation of run method."""
try:
await self._init_stagehand()
@ -183,9 +184,13 @@ class StagehandAgent(SwarmsAgent):
elif "search" in task.lower():
# Perform search action
search_query = task.split("search for")[-1].strip().strip("'\"")
search_query = (
task.split("search for")[-1].strip().strip("'\"")
)
# First, find the search box
search_box = await page.observe("find the search input field")
search_box = await page.observe(
"find the search input field"
)
if search_box:
# Click on search box and type
await page.act(f"click on {search_box[0]}")
@ -198,7 +203,10 @@ class StagehandAgent(SwarmsAgent):
# Perform observation
observation = await page.observe(task)
result["data"]["observation"] = [
{"description": obs.description, "selector": obs.selector}
{
"description": obs.description,
"selector": obs.selector,
}
for obs in observation
]
result["action"] = "observe"
@ -254,4 +262,4 @@ if __name__ == "__main__":
print(result3)
# Clean up
browser_agent.cleanup()
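Editorial note: for orientation, the wrapper agent in the file above is driven synchronously. A minimal usage sketch, assuming the module path and constructor arguments that appear elsewhere in this diff (env="LOCAL" for a local browser, cleanup() to close it):

# Editorial sketch (not part of the diff): minimal driver for the StagehandAgent wrapper.
# Module path and argument names are taken from the hunks above; treat them as assumptions.
from examples.stagehand.stagehand_wrapper_agent import StagehandAgent

agent = StagehandAgent(
    agent_name="DemoBrowserAgent",
    model_name="gpt-4o-mini",
    env="LOCAL",  # run a local browser instead of Browserbase
)
try:
    # run() wraps the async _async_run() with asyncio.run() and returns a JSON string
    result = agent.run("Navigate to example.com and extract the main content")
    print(result)
finally:
    agent.cleanup()  # closes the underlying Stagehand browser instance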

@ -3,7 +3,7 @@ Stagehand Tools for Swarms Agent
=================================
This example demonstrates how to create Stagehand browser automation tools
that can be used by a standard Swarms Agent. Each Stagehand method (act,
extract, observe) becomes a separate tool that the agent can use.
This approach gives the agent more fine-grained control over browser
@ -13,11 +13,10 @@ automation tasks.
import asyncio
import json
import os
from typing import Any, Dict, List, Optional, Union
from typing import Optional
from dotenv import load_dotenv
from loguru import logger
from pydantic import BaseModel, Field
from swarms import Agent
from swarms.tools.base_tool import BaseTool
@ -51,9 +50,11 @@ class BrowserState:
config = StagehandConfig(
env=env,
api_key=api_key or os.getenv("BROWSERBASE_API_KEY"),
project_id=project_id or os.getenv("BROWSERBASE_PROJECT_ID"),
project_id=project_id
or os.getenv("BROWSERBASE_PROJECT_ID"),
model_name=model_name,
model_api_key=model_api_key or os.getenv("OPENAI_API_KEY"),
model_api_key=model_api_key
or os.getenv("OPENAI_API_KEY"),
)
self._stagehand = Stagehand(config)
await self._stagehand.init()
@ -63,7 +64,9 @@ class BrowserState:
async def get_page(self):
"""Get the current page instance."""
if not self._initialized:
raise RuntimeError("Browser not initialized. Call init_browser first.")
raise RuntimeError(
"Browser not initialized. Call init_browser first."
)
return self._stagehand.page
async def close(self):
@ -96,11 +99,11 @@ class NavigateTool(BaseTool):
try:
await browser_state.init_browser()
page = await browser_state.get_page()
# Ensure URL has protocol
if not url.startswith(("http://", "https://")):
url = f"https://{url}"
await page.goto(url)
return f"Successfully navigated to {url}"
except Exception as e:
@ -130,7 +133,7 @@ class ActTool(BaseTool):
try:
await browser_state.init_browser()
page = await browser_state.get_page()
result = await page.act(action)
return f"Action performed: {action}. Result: {result}"
except Exception as e:
@ -160,9 +163,9 @@ class ExtractTool(BaseTool):
try:
await browser_state.init_browser()
page = await browser_state.get_page()
extracted = await page.extract(query)
# Convert to JSON string for agent consumption
if isinstance(extracted, (dict, list)):
return json.dumps(extracted, indent=2)
@ -196,18 +199,20 @@ class ObserveTool(BaseTool):
try:
await browser_state.init_browser()
page = await browser_state.get_page()
observations = await page.observe(query)
# Format observations for readability
result = []
for obs in observations:
result.append({
"description": obs.description,
"selector": obs.selector,
"method": obs.method
})
result.append(
{
"description": obs.description,
"selector": obs.selector,
"method": obs.method,
}
)
return json.dumps(result, indent=2)
except Exception as e:
logger.error(f"Observation error: {str(e)}")
@ -232,15 +237,15 @@ class ScreenshotTool(BaseTool):
try:
await browser_state.init_browser()
page = await browser_state.get_page()
# Ensure .png extension
if not filename.endswith(".png"):
filename += ".png"
# Get the underlying Playwright page
playwright_page = page.page
await playwright_page.screenshot(path=filename)
return f"Screenshot saved to {filename}"
except Exception as e:
logger.error(f"Screenshot error: {str(e)}")
@ -329,4 +334,4 @@ if __name__ == "__main__":
print(result3)
# Clean up
browser_agent.run("Close the browser")
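Editorial note: to connect the docstring at the top of this file to concrete code, the tools defined above are meant to be handed to a plain Swarms Agent. A minimal wiring sketch, assuming the tool classes and module path shown in this diff (the exact Agent keyword arguments are an assumption):

# Editorial sketch (not part of the diff): giving a standard Swarms Agent the browser tools.
from swarms import Agent
from examples.stagehand.stagehand_tools_agent import (  # module path assumed from the diff
    NavigateTool,
    ActTool,
    ExtractTool,
    ObserveTool,
    ScreenshotTool,
)

browser_agent = Agent(
    agent_name="BrowserAutomationAgent",
    model_name="gpt-4o-mini",
    tools=[NavigateTool(), ActTool(), ExtractTool(), ObserveTool(), ScreenshotTool()],
    max_loops=1,
)

# The agent decides per step which browser tool to invoke.
print(browser_agent.run("Go to news.ycombinator.com and extract the top headline"))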

@ -22,9 +22,7 @@ Features:
- Prompt templates for common tasks
"""
import asyncio
import os
from typing import List, Optional
from typing import List
from dotenv import load_dotenv
from loguru import logger
@ -104,14 +102,23 @@ class MultiSessionBrowserSwarm:
num_agents: Number of agents to create
"""
self.agents = []
# Create specialized agents for different tasks
agent_roles = [
("DataExtractor", "You specialize in extracting structured data from websites."),
("FormFiller", "You specialize in filling out forms and interacting with web applications."),
("WebMonitor", "You specialize in monitoring websites for changes and capturing screenshots."),
(
"DataExtractor",
"You specialize in extracting structured data from websites.",
),
(
"FormFiller",
"You specialize in filling out forms and interacting with web applications.",
),
(
"WebMonitor",
"You specialize in monitoring websites for changes and capturing screenshots.",
),
]
for i in range(min(num_agents, len(agent_roles))):
name, specialization = agent_roles[i]
agent = Agent(
@ -137,16 +144,18 @@ Always create your own session for tasks to work independently from other agents
def distribute_tasks(self, tasks: List[str]) -> List[str]:
"""Distribute tasks among agents."""
results = []
# Distribute tasks round-robin among agents
for i, task in enumerate(tasks):
agent_idx = i % len(self.agents)
agent = self.agents[agent_idx]
logger.info(f"Assigning task to {agent.agent_name}: {task}")
logger.info(
f"Assigning task to {agent.agent_name}: {task}"
)
result = agent.run(task)
results.append(result)
return results
@ -155,18 +164,20 @@ if __name__ == "__main__":
print("=" * 70)
print("Stagehand MCP Server Integration Examples")
print("=" * 70)
print("\nMake sure the Stagehand MCP server is running on http://localhost:3000/sse")
print(
"\nMake sure the Stagehand MCP server is running on http://localhost:3000/sse"
)
print("Run: cd stagehand-mcp-server && npm start\n")
# Example 1: Single agent with MCP tools
print("\nExample 1: Single Agent with MCP Tools")
print("-" * 40)
mcp_agent = StagehandMCPAgent(
agent_name="WebResearchAgent",
mcp_server_url="http://localhost:3000/sse",
)
# Research task using MCP tools
result1 = mcp_agent.run(
"""Navigate to news.ycombinator.com and extract the following:
@ -176,18 +187,18 @@ if __name__ == "__main__":
Then take a screenshot of the page."""
)
print(f"Result: {result1}")
print("\n" + "=" * 70 + "\n")
# Example 2: Multi-session parallel browsing
print("Example 2: Multi-Session Parallel Browsing")
print("-" * 40)
parallel_agent = StagehandMCPAgent(
agent_name="ParallelBrowserAgent",
mcp_server_url="http://localhost:3000/sse",
)
result2 = parallel_agent.run(
"""Create 3 browser sessions and perform these tasks in parallel:
1. Session 1: Go to github.com/trending and extract the top 3 trending repositories
@ -197,44 +208,44 @@ if __name__ == "__main__":
After extracting data from all sessions, close them."""
)
print(f"Result: {result2}")
print("\n" + "=" * 70 + "\n")
# Example 3: Multi-agent browser swarm
print("Example 3: Multi-Agent Browser Swarm")
print("-" * 40)
# Create a swarm of specialized browser agents
browser_swarm = MultiSessionBrowserSwarm(
mcp_server_url="http://localhost:3000/sse",
num_agents=3,
)
# Define tasks for the swarm
swarm_tasks = [
"Create a session, navigate to python.org, and extract information about the latest Python version and its key features",
"Create a session, go to npmjs.com, search for 'stagehand', and extract information about the package including version and description",
"Create a session, visit playwright.dev, and extract the main features and benefits listed on the homepage",
]
print("Distributing tasks to browser swarm...")
swarm_results = browser_swarm.distribute_tasks(swarm_tasks)
for i, result in enumerate(swarm_results):
print(f"\nTask {i+1} Result: {result}")
print("\n" + "=" * 70 + "\n")
# Example 4: Complex workflow with session management
print("Example 4: Complex Multi-Page Workflow")
print("-" * 40)
workflow_agent = StagehandMCPAgent(
agent_name="WorkflowAgent",
mcp_server_url="http://localhost:3000/sse",
max_loops=2, # Allow more complex reasoning
)
result4 = workflow_agent.run(
"""Perform a comprehensive analysis of AI frameworks:
1. Create a new session
@ -246,7 +257,7 @@ if __name__ == "__main__":
7. Close the session when done"""
)
print(f"Result: {result4}")
print("\n" + "=" * 70)
print("All examples completed!")
print("=" * 70)
print("=" * 70)

@ -13,14 +13,10 @@ Use cases:
4. Data aggregation from multiple sources
"""
import asyncio
import json
import os
from datetime import datetime
from typing import Any, Dict, List
from typing import Dict, List, Optional
from dotenv import load_dotenv
from loguru import logger
from pydantic import BaseModel, Field
from swarms import Agent, SequentialWorkflow, ConcurrentWorkflow
@ -33,42 +29,48 @@ load_dotenv()
# Pydantic models for structured data
class ProductInfo(BaseModel):
"""Product information schema."""
name: str = Field(..., description="Product name")
price: float = Field(..., description="Product price")
availability: str = Field(..., description="Availability status")
url: str = Field(..., description="Product URL")
screenshot_path: Optional[str] = Field(None, description="Screenshot file path")
screenshot_path: Optional[str] = Field(
None, description="Screenshot file path"
)
class MarketAnalysis(BaseModel):
"""Market analysis report schema."""
timestamp: datetime = Field(default_factory=datetime.now)
products: List[ProductInfo] = Field(..., description="List of products analyzed")
price_range: Dict[str, float] = Field(..., description="Min and max prices")
recommendations: List[str] = Field(..., description="Analysis recommendations")
products: List[ProductInfo] = Field(
..., description="List of products analyzed"
)
price_range: Dict[str, float] = Field(
..., description="Min and max prices"
)
recommendations: List[str] = Field(
..., description="Analysis recommendations"
)
# Specialized browser agents
class ProductScraperAgent(StagehandAgent):
"""Specialized agent for scraping product information."""
def __init__(self, site_name: str, *args, **kwargs):
super().__init__(
agent_name=f"ProductScraper_{site_name}",
*args,
**kwargs
agent_name=f"ProductScraper_{site_name}", *args, **kwargs
)
self.site_name = site_name
class PriceMonitorAgent(StagehandAgent):
"""Specialized agent for monitoring price changes."""
def __init__(self, *args, **kwargs):
super().__init__(
agent_name="PriceMonitorAgent",
*args,
**kwargs
agent_name="PriceMonitorAgent", *args, **kwargs
)
@ -77,20 +79,20 @@ def create_price_comparison_workflow():
"""
Create a workflow that compares prices across multiple e-commerce sites.
"""
# Create specialized agents for different sites
amazon_agent = StagehandAgent(
agent_name="AmazonScraperAgent",
model_name="gpt-4o-mini",
env="LOCAL",
)
ebay_agent = StagehandAgent(
agent_name="EbayScraperAgent",
model_name="gpt-4o-mini",
env="LOCAL",
)
analysis_agent = Agent(
agent_name="PriceAnalysisAgent",
model_name="gpt-4o-mini",
@ -98,21 +100,21 @@ def create_price_comparison_workflow():
and provide insights on the best deals, price trends, and recommendations.
Focus on value for money and highlight any significant price differences.""",
)
# Create concurrent workflow for parallel scraping
scraping_workflow = ConcurrentWorkflow(
agents=[amazon_agent, ebay_agent],
max_loops=1,
verbose=True,
)
# Create sequential workflow: scrape -> analyze
full_workflow = SequentialWorkflow(
agents=[scraping_workflow, analysis_agent],
max_loops=1,
verbose=True,
)
return full_workflow
@ -121,21 +123,21 @@ def create_competitive_analysis_workflow():
"""
Create a workflow for competitive analysis across multiple company websites.
"""
# Agent for extracting company information
company_researcher = StagehandAgent(
agent_name="CompanyResearchAgent",
model_name="gpt-4o-mini",
env="LOCAL",
)
# Agent for analyzing social media presence
social_media_agent = StagehandAgent(
agent_name="SocialMediaAnalysisAgent",
model_name="gpt-4o-mini",
env="LOCAL",
)
# Agent for compiling competitive analysis report
report_compiler = Agent(
agent_name="CompetitiveAnalysisReporter",
@ -144,16 +146,22 @@ def create_competitive_analysis_workflow():
based on company information and social media presence data. Identify strengths,
weaknesses, and market positioning for each company.""",
)
# Create agent rearrange for flexible routing
workflow_pattern = "company_researcher -> social_media_agent -> report_compiler"
workflow_pattern = (
"company_researcher -> social_media_agent -> report_compiler"
)
competitive_workflow = AgentRearrange(
agents=[company_researcher, social_media_agent, report_compiler],
agents=[
company_researcher,
social_media_agent,
report_compiler,
],
flow=workflow_pattern,
verbose=True,
)
return competitive_workflow
@ -162,28 +170,28 @@ def create_automated_testing_workflow():
"""
Create a workflow for automated web application testing.
"""
# Agent for UI testing
ui_tester = StagehandAgent(
agent_name="UITestingAgent",
model_name="gpt-4o-mini",
env="LOCAL",
)
# Agent for form validation testing
form_tester = StagehandAgent(
agent_name="FormValidationAgent",
model_name="gpt-4o-mini",
env="LOCAL",
)
# Agent for accessibility testing
accessibility_tester = StagehandAgent(
agent_name="AccessibilityTestingAgent",
model_name="gpt-4o-mini",
env="LOCAL",
)
# Agent for compiling test results
test_reporter = Agent(
agent_name="TestReportCompiler",
@ -192,20 +200,20 @@ def create_automated_testing_workflow():
UI, form validation, and accessibility testing into a comprehensive report.
Highlight any failures, warnings, and provide recommendations for fixes.""",
)
# Concurrent testing followed by report generation
testing_workflow = ConcurrentWorkflow(
agents=[ui_tester, form_tester, accessibility_tester],
max_loops=1,
verbose=True,
)
full_test_workflow = SequentialWorkflow(
agents=[testing_workflow, test_reporter],
max_loops=1,
verbose=True,
)
return full_test_workflow
@ -214,7 +222,7 @@ def create_news_aggregation_workflow():
"""
Create a workflow for news aggregation and sentiment analysis.
"""
# Multiple news scraper agents
news_scrapers = []
news_sites = [
@ -222,7 +230,7 @@ def create_news_aggregation_workflow():
("HackerNews", "https://news.ycombinator.com"),
("Reddit", "https://reddit.com/r/technology"),
]
for site_name, url in news_sites:
scraper = StagehandAgent(
agent_name=f"{site_name}Scraper",
@ -230,7 +238,7 @@ def create_news_aggregation_workflow():
env="LOCAL",
)
news_scrapers.append(scraper)
# Sentiment analysis agent
sentiment_analyzer = Agent(
agent_name="SentimentAnalyzer",
@ -239,7 +247,7 @@ def create_news_aggregation_workflow():
to determine overall sentiment (positive, negative, neutral) and identify key themes
and trends in the technology sector.""",
)
# Trend identification agent
trend_identifier = Agent(
agent_name="TrendIdentifier",
@ -248,20 +256,24 @@ def create_news_aggregation_workflow():
data, identify emerging trends, hot topics, and potential market movements in the
technology sector.""",
)
# Create workflow: parallel scraping -> sentiment analysis -> trend identification
scraping_workflow = ConcurrentWorkflow(
agents=news_scrapers,
max_loops=1,
verbose=True,
)
analysis_workflow = SequentialWorkflow(
agents=[scraping_workflow, sentiment_analyzer, trend_identifier],
agents=[
scraping_workflow,
sentiment_analyzer,
trend_identifier,
],
max_loops=1,
verbose=True,
)
return analysis_workflow
@ -270,13 +282,13 @@ if __name__ == "__main__":
print("=" * 70)
print("Stagehand Multi-Agent Workflow Examples")
print("=" * 70)
# Example 1: Price Comparison
print("\nExample 1: E-commerce Price Comparison")
print("-" * 40)
price_workflow = create_price_comparison_workflow()
# Search for a specific product across multiple sites
price_result = price_workflow.run(
"""Search for 'iPhone 15 Pro Max 256GB' on:
@ -286,15 +298,15 @@ if __name__ == "__main__":
Compare the prices and provide recommendations on where to buy."""
)
print(f"Price Comparison Result:\n{price_result}")
print("\n" + "=" * 70 + "\n")
# Example 2: Competitive Analysis
print("Example 2: Competitive Analysis")
print("-" * 40)
competitive_workflow = create_competitive_analysis_workflow()
competitive_result = competitive_workflow.run(
"""Analyze these three AI companies:
1. OpenAI - visit openai.com and extract mission, products, and recent announcements
@ -305,15 +317,15 @@ if __name__ == "__main__":
Compile a competitive analysis report comparing their market positioning."""
)
print(f"Competitive Analysis Result:\n{competitive_result}")
print("\n" + "=" * 70 + "\n")
# Example 3: Automated Testing
print("Example 3: Automated Web Testing")
print("-" * 40)
testing_workflow = create_automated_testing_workflow()
test_result = testing_workflow.run(
"""Test the website example.com:
1. UI Testing: Check if all main navigation links work, images load, and layout is responsive
@ -323,15 +335,15 @@ if __name__ == "__main__":
Take screenshots of any issues found and compile a comprehensive test report."""
)
print(f"Test Results:\n{test_result}")
print("\n" + "=" * 70 + "\n")
# Example 4: News Aggregation
print("Example 4: Tech News Aggregation and Analysis")
print("-" * 40)
news_workflow = create_news_aggregation_workflow()
news_result = news_workflow.run(
"""For each news source:
1. TechCrunch: Extract the top 5 headlines about AI or machine learning
@ -341,19 +353,19 @@ if __name__ == "__main__":
Analyze sentiment and identify emerging trends in AI technology."""
)
print(f"News Analysis Result:\n{news_result}")
# Cleanup all browser instances
print("\n" + "=" * 70)
print("Cleaning up browser instances...")
# Clean up agents
for agent in price_workflow.agents:
if isinstance(agent, StagehandAgent):
agent.cleanup()
elif hasattr(agent, 'agents'): # For nested workflows
elif hasattr(agent, "agents"): # For nested workflows
for sub_agent in agent.agents:
if isinstance(sub_agent, StagehandAgent):
sub_agent.cleanup()
print("All workflows completed!")
print("=" * 70)
print("=" * 70)

@ -6,13 +6,9 @@ This module contains tests for the Stagehand browser automation
integration with the Swarms framework.
"""
import asyncio
import json
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from swarms import Agent
from swarms.tools.base_tool import BaseTool
from unittest.mock import AsyncMock, patch
# Mock Stagehand classes
@ -26,13 +22,13 @@ class MockObserveResult:
class MockStagehandPage:
async def goto(self, url):
return None
async def act(self, action):
return f"Performed action: {action}"
async def extract(self, query):
return {"extracted": query, "data": ["item1", "item2"]}
async def observe(self, query):
return [
MockObserveResult("Search box", "#search-input"),
@ -44,10 +40,10 @@ class MockStagehand:
def __init__(self, config):
self.config = config
self.page = MockStagehandPage()
async def init(self):
pass
async def close(self):
pass
@ -55,79 +51,106 @@ class MockStagehand:
# Test StagehandAgent wrapper
class TestStagehandAgent:
"""Test the StagehandAgent wrapper class."""
@patch('examples.stagehand.stagehand_wrapper_agent.Stagehand', MockStagehand)
@patch(
"examples.stagehand.stagehand_wrapper_agent.Stagehand",
MockStagehand,
)
def test_agent_initialization(self):
"""Test that StagehandAgent initializes correctly."""
from examples.stagehand.stagehand_wrapper_agent import StagehandAgent
from examples.stagehand.stagehand_wrapper_agent import (
StagehandAgent,
)
agent = StagehandAgent(
agent_name="TestAgent",
model_name="gpt-4o-mini",
env="LOCAL",
)
assert agent.agent_name == "TestAgent"
assert agent.stagehand_config.env == "LOCAL"
assert agent.stagehand_config.model_name == "gpt-4o-mini"
assert not agent._initialized
@patch('examples.stagehand.stagehand_wrapper_agent.Stagehand', MockStagehand)
@patch(
"examples.stagehand.stagehand_wrapper_agent.Stagehand",
MockStagehand,
)
def test_navigation_task(self):
"""Test navigation and extraction task."""
from examples.stagehand.stagehand_wrapper_agent import StagehandAgent
from examples.stagehand.stagehand_wrapper_agent import (
StagehandAgent,
)
agent = StagehandAgent(
agent_name="TestAgent",
model_name="gpt-4o-mini",
env="LOCAL",
)
result = agent.run("Navigate to example.com and extract the main content")
result = agent.run(
"Navigate to example.com and extract the main content"
)
# Parse result
result_data = json.loads(result)
assert result_data["status"] == "completed"
assert "navigated_to" in result_data["data"]
assert result_data["data"]["navigated_to"] == "https://example.com"
assert (
result_data["data"]["navigated_to"]
== "https://example.com"
)
assert "extracted" in result_data["data"]
@patch('examples.stagehand.stagehand_wrapper_agent.Stagehand', MockStagehand)
@patch(
"examples.stagehand.stagehand_wrapper_agent.Stagehand",
MockStagehand,
)
def test_search_task(self):
"""Test search functionality."""
from examples.stagehand.stagehand_wrapper_agent import StagehandAgent
from examples.stagehand.stagehand_wrapper_agent import (
StagehandAgent,
)
agent = StagehandAgent(
agent_name="TestAgent",
model_name="gpt-4o-mini",
env="LOCAL",
)
result = agent.run("Go to google.com and search for 'test query'")
result = agent.run(
"Go to google.com and search for 'test query'"
)
result_data = json.loads(result)
assert result_data["status"] == "completed"
assert result_data["data"]["search_query"] == "test query"
assert result_data["action"] == "search"
@patch('examples.stagehand.stagehand_wrapper_agent.Stagehand', MockStagehand)
@patch(
"examples.stagehand.stagehand_wrapper_agent.Stagehand",
MockStagehand,
)
def test_cleanup(self):
"""Test that cleanup properly closes browser."""
from examples.stagehand.stagehand_wrapper_agent import StagehandAgent
from examples.stagehand.stagehand_wrapper_agent import (
StagehandAgent,
)
agent = StagehandAgent(
agent_name="TestAgent",
model_name="gpt-4o-mini",
env="LOCAL",
)
# Initialize the agent
agent.run("Navigate to example.com")
assert agent._initialized
# Cleanup
agent.cleanup()
# After cleanup, should be able to run again
result = agent.run("Navigate to example.com")
assert result is not None
@ -136,65 +159,84 @@ class TestStagehandAgent:
# Test Stagehand Tools
class TestStagehandTools:
"""Test individual Stagehand tools."""
@patch('examples.stagehand.stagehand_tools_agent.browser_state')
@patch("examples.stagehand.stagehand_tools_agent.browser_state")
async def test_navigate_tool(self, mock_browser_state):
"""Test NavigateTool functionality."""
from examples.stagehand.stagehand_tools_agent import NavigateTool
from examples.stagehand.stagehand_tools_agent import (
NavigateTool,
)
# Setup mock
mock_page = AsyncMock()
mock_browser_state.get_page = AsyncMock(return_value=mock_page)
mock_browser_state.get_page = AsyncMock(
return_value=mock_page
)
mock_browser_state.init_browser = AsyncMock()
tool = NavigateTool()
result = await tool._async_run("https://example.com")
assert "Successfully navigated to https://example.com" in result
assert (
"Successfully navigated to https://example.com" in result
)
mock_page.goto.assert_called_once_with("https://example.com")
@patch('examples.stagehand.stagehand_tools_agent.browser_state')
@patch("examples.stagehand.stagehand_tools_agent.browser_state")
async def test_act_tool(self, mock_browser_state):
"""Test ActTool functionality."""
from examples.stagehand.stagehand_tools_agent import ActTool
# Setup mock
mock_page = AsyncMock()
mock_page.act = AsyncMock(return_value="Action completed")
mock_browser_state.get_page = AsyncMock(return_value=mock_page)
mock_browser_state.get_page = AsyncMock(
return_value=mock_page
)
mock_browser_state.init_browser = AsyncMock()
tool = ActTool()
result = await tool._async_run("click the button")
assert "Action performed" in result
assert "click the button" in result
mock_page.act.assert_called_once_with("click the button")
@patch('examples.stagehand.stagehand_tools_agent.browser_state')
@patch("examples.stagehand.stagehand_tools_agent.browser_state")
async def test_extract_tool(self, mock_browser_state):
"""Test ExtractTool functionality."""
from examples.stagehand.stagehand_tools_agent import ExtractTool
from examples.stagehand.stagehand_tools_agent import (
ExtractTool,
)
# Setup mock
mock_page = AsyncMock()
mock_page.extract = AsyncMock(return_value={"title": "Test Page", "content": "Test content"})
mock_browser_state.get_page = AsyncMock(return_value=mock_page)
mock_page.extract = AsyncMock(
return_value={
"title": "Test Page",
"content": "Test content",
}
)
mock_browser_state.get_page = AsyncMock(
return_value=mock_page
)
mock_browser_state.init_browser = AsyncMock()
tool = ExtractTool()
result = await tool._async_run("extract the page title")
# Result should be JSON string
parsed_result = json.loads(result)
assert parsed_result["title"] == "Test Page"
assert parsed_result["content"] == "Test content"
@patch('examples.stagehand.stagehand_tools_agent.browser_state')
@patch("examples.stagehand.stagehand_tools_agent.browser_state")
async def test_observe_tool(self, mock_browser_state):
"""Test ObserveTool functionality."""
from examples.stagehand.stagehand_tools_agent import ObserveTool
from examples.stagehand.stagehand_tools_agent import (
ObserveTool,
)
# Setup mock
mock_page = AsyncMock()
mock_observations = [
@ -202,12 +244,14 @@ class TestStagehandTools:
MockObserveResult("Submit button", "#submit"),
]
mock_page.observe = AsyncMock(return_value=mock_observations)
mock_browser_state.get_page = AsyncMock(return_value=mock_page)
mock_browser_state.get_page = AsyncMock(
return_value=mock_page
)
mock_browser_state.init_browser = AsyncMock()
tool = ObserveTool()
result = await tool._async_run("find the search box")
# Result should be JSON string
parsed_result = json.loads(result)
assert len(parsed_result) == 2
@ -218,47 +262,53 @@ class TestStagehandTools:
# Test MCP integration
class TestStagehandMCP:
"""Test Stagehand MCP server integration."""
def test_mcp_agent_initialization(self):
"""Test that MCP agent initializes with correct parameters."""
from examples.stagehand.stagehand_mcp_agent import StagehandMCPAgent
from examples.stagehand.stagehand_mcp_agent import (
StagehandMCPAgent,
)
mcp_agent = StagehandMCPAgent(
agent_name="TestMCPAgent",
mcp_server_url="http://localhost:3000/sse",
model_name="gpt-4o-mini",
)
assert mcp_agent.agent.agent_name == "TestMCPAgent"
assert mcp_agent.agent.mcp_url == "http://localhost:3000/sse"
assert mcp_agent.agent.model_name == "gpt-4o-mini"
def test_multi_session_swarm_creation(self):
"""Test multi-session browser swarm creation."""
from examples.stagehand.stagehand_mcp_agent import MultiSessionBrowserSwarm
from examples.stagehand.stagehand_mcp_agent import (
MultiSessionBrowserSwarm,
)
swarm = MultiSessionBrowserSwarm(
mcp_server_url="http://localhost:3000/sse",
num_agents=3,
)
assert len(swarm.agents) == 3
assert swarm.agents[0].agent_name == "DataExtractor_0"
assert swarm.agents[1].agent_name == "FormFiller_1"
assert swarm.agents[2].agent_name == "WebMonitor_2"
@patch('swarms.Agent.run')
@patch("swarms.Agent.run")
def test_task_distribution(self, mock_run):
"""Test task distribution among swarm agents."""
from examples.stagehand.stagehand_mcp_agent import MultiSessionBrowserSwarm
from examples.stagehand.stagehand_mcp_agent import (
MultiSessionBrowserSwarm,
)
mock_run.return_value = "Task completed"
swarm = MultiSessionBrowserSwarm(num_agents=2)
tasks = ["Task 1", "Task 2", "Task 3"]
results = swarm.distribute_tasks(tasks)
assert len(results) == 3
assert all(result == "Task completed" for result in results)
assert mock_run.call_count == 3
@ -267,90 +317,120 @@ class TestStagehandMCP:
# Test multi-agent workflows
class TestMultiAgentWorkflows:
"""Test multi-agent workflow configurations."""
@patch('examples.stagehand.stagehand_wrapper_agent.Stagehand', MockStagehand)
@patch(
"examples.stagehand.stagehand_wrapper_agent.Stagehand",
MockStagehand,
)
def test_price_comparison_workflow_creation(self):
"""Test creation of price comparison workflow."""
from examples.stagehand.stagehand_multi_agent_workflow import create_price_comparison_workflow
from examples.stagehand.stagehand_multi_agent_workflow import (
create_price_comparison_workflow,
)
workflow = create_price_comparison_workflow()
# Should be a SequentialWorkflow with 2 agents
assert len(workflow.agents) == 2
# First agent should be a ConcurrentWorkflow
assert hasattr(workflow.agents[0], 'agents')
assert hasattr(workflow.agents[0], "agents")
# Second agent should be the analysis agent
assert workflow.agents[1].agent_name == "PriceAnalysisAgent"
@patch('examples.stagehand.stagehand_wrapper_agent.Stagehand', MockStagehand)
@patch(
"examples.stagehand.stagehand_wrapper_agent.Stagehand",
MockStagehand,
)
def test_competitive_analysis_workflow_creation(self):
"""Test creation of competitive analysis workflow."""
from examples.stagehand.stagehand_multi_agent_workflow import create_competitive_analysis_workflow
from examples.stagehand.stagehand_multi_agent_workflow import (
create_competitive_analysis_workflow,
)
workflow = create_competitive_analysis_workflow()
# Should have 3 agents in the rearrange pattern
assert len(workflow.agents) == 3
assert workflow.flow == "company_researcher -> social_media_agent -> report_compiler"
@patch('examples.stagehand.stagehand_wrapper_agent.Stagehand', MockStagehand)
assert (
workflow.flow
== "company_researcher -> social_media_agent -> report_compiler"
)
@patch(
"examples.stagehand.stagehand_wrapper_agent.Stagehand",
MockStagehand,
)
def test_automated_testing_workflow_creation(self):
"""Test creation of automated testing workflow."""
from examples.stagehand.stagehand_multi_agent_workflow import create_automated_testing_workflow
from examples.stagehand.stagehand_multi_agent_workflow import (
create_automated_testing_workflow,
)
workflow = create_automated_testing_workflow()
# Should be a SequentialWorkflow
assert len(workflow.agents) == 2
# First should be concurrent testing
assert hasattr(workflow.agents[0], 'agents')
assert len(workflow.agents[0].agents) == 3 # UI, Form, Accessibility testers
@patch('examples.stagehand.stagehand_wrapper_agent.Stagehand', MockStagehand)
assert hasattr(workflow.agents[0], "agents")
assert (
len(workflow.agents[0].agents) == 3
) # UI, Form, Accessibility testers
@patch(
"examples.stagehand.stagehand_wrapper_agent.Stagehand",
MockStagehand,
)
def test_news_aggregation_workflow_creation(self):
"""Test creation of news aggregation workflow."""
from examples.stagehand.stagehand_multi_agent_workflow import create_news_aggregation_workflow
from examples.stagehand.stagehand_multi_agent_workflow import (
create_news_aggregation_workflow,
)
workflow = create_news_aggregation_workflow()
# Should be a SequentialWorkflow with 3 stages
assert len(workflow.agents) == 3
# First stage should be concurrent scrapers
assert hasattr(workflow.agents[0], 'agents')
assert hasattr(workflow.agents[0], "agents")
assert len(workflow.agents[0].agents) == 3 # 3 news sources
# Integration tests
class TestIntegration:
"""End-to-end integration tests."""
@pytest.mark.asyncio
@patch('examples.stagehand.stagehand_wrapper_agent.Stagehand', MockStagehand)
@patch(
"examples.stagehand.stagehand_wrapper_agent.Stagehand",
MockStagehand,
)
async def test_full_browser_automation_flow(self):
"""Test a complete browser automation flow."""
from examples.stagehand.stagehand_wrapper_agent import StagehandAgent
from examples.stagehand.stagehand_wrapper_agent import (
StagehandAgent,
)
agent = StagehandAgent(
agent_name="IntegrationTestAgent",
model_name="gpt-4o-mini",
env="LOCAL",
)
# Test navigation
nav_result = agent.run("Navigate to example.com")
assert "navigated_to" in nav_result
# Test extraction
extract_result = agent.run("Extract all text from the page")
assert "extracted" in extract_result
# Test observation
observe_result = agent.run("Find all buttons on the page")
assert "observation" in observe_result
# Cleanup
agent.cleanup()
if __name__ == "__main__":
pytest.main([__file__, "-v"])

@ -0,0 +1,304 @@
"""
Simple tests for Stagehand Integration with Swarms
=================================================
These tests verify the basic structure and functionality of the
Stagehand integration without requiring external dependencies.
"""
import json
import pytest
from unittest.mock import MagicMock
class TestStagehandIntegrationStructure:
"""Test that integration files have correct structure."""
def test_examples_directory_exists(self):
"""Test that examples directory structure is correct."""
import os
base_path = "examples/stagehand"
assert os.path.exists(base_path)
expected_files = [
"1_stagehand_wrapper_agent.py",
"2_stagehand_tools_agent.py",
"3_stagehand_mcp_agent.py",
"4_stagehand_multi_agent_workflow.py",
"README.md",
"requirements.txt",
]
for file in expected_files:
file_path = os.path.join(base_path, file)
assert os.path.exists(file_path), f"Missing file: {file}"
def test_wrapper_agent_imports(self):
"""Test that wrapper agent has correct imports."""
with open(
"examples/stagehand/1_stagehand_wrapper_agent.py", "r"
) as f:
content = f.read()
# Check for required imports
assert "from swarms import Agent" in content
assert "import asyncio" in content
assert "import json" in content
assert "class StagehandAgent" in content
def test_tools_agent_imports(self):
"""Test that tools agent has correct imports."""
with open(
"examples/stagehand/2_stagehand_tools_agent.py", "r"
) as f:
content = f.read()
# Check for required imports
assert (
"from swarms.tools.base_tool import BaseTool" in content
)
assert "class NavigateTool" in content
assert "class ActTool" in content
assert "class ExtractTool" in content
def test_mcp_agent_imports(self):
"""Test that MCP agent has correct imports."""
with open(
"examples/stagehand/3_stagehand_mcp_agent.py", "r"
) as f:
content = f.read()
# Check for required imports
assert "from swarms import Agent" in content
assert "class StagehandMCPAgent" in content
assert "mcp_url" in content
def test_workflow_agent_imports(self):
"""Test that workflow agent has correct imports."""
with open(
"examples/stagehand/4_stagehand_multi_agent_workflow.py",
"r",
) as f:
content = f.read()
# Check for required imports
assert (
"from swarms import Agent, SequentialWorkflow, ConcurrentWorkflow"
in content
)
assert (
"from swarms.structs.agent_rearrange import AgentRearrange"
in content
)
class TestStagehandMockIntegration:
"""Test Stagehand integration with mocked dependencies."""
def test_mock_stagehand_initialization(self):
"""Test that Stagehand can be mocked and initialized."""
# Setup mock without importing actual stagehand
mock_stagehand = MagicMock()
mock_instance = MagicMock()
mock_instance.init = MagicMock()
mock_stagehand.return_value = mock_instance
# Mock config creation
config = MagicMock()
stagehand_instance = mock_stagehand(config)
# Verify mock works
assert stagehand_instance is not None
assert hasattr(stagehand_instance, "init")
def test_json_serialization(self):
"""Test JSON serialization for agent responses."""
# Test data that would come from browser automation
test_data = {
"task": "Navigate to example.com",
"status": "completed",
"data": {
"navigated_to": "https://example.com",
"extracted": ["item1", "item2"],
"action": "navigate",
},
}
# Test serialization
json_result = json.dumps(test_data, indent=2)
assert isinstance(json_result, str)
# Test deserialization
parsed_data = json.loads(json_result)
assert parsed_data["task"] == "Navigate to example.com"
assert parsed_data["status"] == "completed"
assert len(parsed_data["data"]["extracted"]) == 2
def test_url_extraction_logic(self):
"""Test URL extraction logic from task strings."""
import re
# Test cases
test_cases = [
(
"Navigate to https://example.com",
["https://example.com"],
),
("Go to google.com and search", ["google.com"]),
(
"Visit https://github.com/repo",
["https://github.com/repo"],
),
("Open example.org", ["example.org"]),
]
url_pattern = r"https?://[^\s]+"
domain_pattern = r"(\w+\.\w+)"
for task, expected in test_cases:
# Extract full URLs
urls = re.findall(url_pattern, task)
# If no full URLs, extract domains
if not urls:
domains = re.findall(domain_pattern, task)
if domains:
urls = domains
assert (
len(urls) > 0
), f"Failed to extract URL from: {task}"
assert (
urls[0] in expected
), f"Expected {expected}, got {urls}"
class TestSwarmsPatternsCompliance:
"""Test compliance with Swarms framework patterns."""
def test_agent_inheritance_pattern(self):
"""Test that wrapper agent follows Swarms Agent inheritance pattern."""
# Read the wrapper agent file
with open(
"examples/stagehand/1_stagehand_wrapper_agent.py", "r"
) as f:
content = f.read()
# Check inheritance pattern
assert "class StagehandAgent(SwarmsAgent):" in content
assert "def run(self, task: str" in content
assert "return" in content
def test_tools_pattern(self):
"""Test that tools follow Swarms BaseTool pattern."""
# Read the tools agent file
with open(
"examples/stagehand/2_stagehand_tools_agent.py", "r"
) as f:
content = f.read()
# Check tool pattern
assert "class NavigateTool(BaseTool):" in content
assert "def run(self," in content
assert "name=" in content
assert "description=" in content
def test_mcp_integration_pattern(self):
"""Test MCP integration follows Swarms pattern."""
# Read the MCP agent file
with open(
"examples/stagehand/3_stagehand_mcp_agent.py", "r"
) as f:
content = f.read()
# Check MCP pattern
assert "mcp_url=" in content
assert "Agent(" in content
def test_workflow_patterns(self):
"""Test workflow patterns are properly used."""
# Read the workflow file
with open(
"examples/stagehand/4_stagehand_multi_agent_workflow.py",
"r",
) as f:
content = f.read()
# Check workflow patterns
assert "SequentialWorkflow" in content
assert "ConcurrentWorkflow" in content
assert "AgentRearrange" in content
class TestDocumentationAndExamples:
"""Test documentation and example completeness."""
def test_readme_completeness(self):
"""Test that README contains essential information."""
with open("examples/stagehand/README.md", "r") as f:
content = f.read()
required_sections = [
"# Stagehand Browser Automation Integration",
"## Overview",
"## Examples",
"## Setup",
"## Use Cases",
"## Best Practices",
]
for section in required_sections:
assert section in content, f"Missing section: {section}"
def test_requirements_file(self):
"""Test that requirements file has necessary dependencies."""
with open("examples/stagehand/requirements.txt", "r") as f:
content = f.read()
required_deps = [
"swarms",
"stagehand",
"python-dotenv",
"pydantic",
"loguru",
]
for dep in required_deps:
assert dep in content, f"Missing dependency: {dep}"
def test_example_files_have_docstrings(self):
"""Test that example files have proper docstrings."""
example_files = [
"examples/stagehand/1_stagehand_wrapper_agent.py",
"examples/stagehand/2_stagehand_tools_agent.py",
"examples/stagehand/3_stagehand_mcp_agent.py",
"examples/stagehand/4_stagehand_multi_agent_workflow.py",
]
for file_path in example_files:
with open(file_path, "r") as f:
content = f.read()
# Check for module docstring
assert (
'"""' in content[:500]
), f"Missing docstring in {file_path}"
# Check for main execution block
assert (
'if __name__ == "__main__":' in content
), f"Missing main block in {file_path}"
if __name__ == "__main__":
pytest.main([__file__, "-v"])