diff --git a/README.md b/README.md
index 74d0428b..b9e64eae 100644
--- a/README.md
+++ b/README.md
@@ -676,6 +676,81 @@ This architecture is perfect for financial analysis, strategic planning, researc
 
 ---
 
+### Agent Orchestration Protocol (AOP)
+
+The **Agent Orchestration Protocol (AOP)** is a powerful framework for deploying and managing agents as distributed services. AOP enables agents to be discovered, managed, and executed through a standardized protocol, making it well suited to building scalable multi-agent systems. [Learn more about AOP](https://docs.swarms.world/en/latest/swarms/structs/aop/)
+
+```python
+from swarms import Agent
+from swarms.structs.aop import AOP
+
+# Create specialized agents
+research_agent = Agent(
+    agent_name="Research-Agent",
+    agent_description="Expert in research and data collection",
+    model_name="anthropic/claude-sonnet-4-5",
+    max_loops=1,
+    tags=["research", "data-collection", "analysis"],
+    capabilities=["web-search", "data-gathering", "report-generation"],
+    role="researcher"
+)
+
+analysis_agent = Agent(
+    agent_name="Analysis-Agent",
+    agent_description="Expert in data analysis and insights",
+    model_name="anthropic/claude-sonnet-4-5",
+    max_loops=1,
+    tags=["analysis", "data-processing", "insights"],
+    capabilities=["statistical-analysis", "pattern-recognition", "visualization"],
+    role="analyst"
+)
+
+# Create AOP server
+deployer = AOP(
+    server_name="ResearchCluster",
+    port=8000,
+    verbose=True
+)
+
+# Add agents to the server
+deployer.add_agent(
+    agent=research_agent,
+    tool_name="research_tool",
+    tool_description="Research and data collection tool",
+    timeout=30,
+    max_retries=3
+)
+
+deployer.add_agent(
+    agent=analysis_agent,
+    tool_name="analysis_tool",
+    tool_description="Data analysis and insights tool",
+    timeout=30,
+    max_retries=3
+)
+
+# List all registered agents
+print("Registered agents:", deployer.list_agents())
+
+# Start the AOP server
+deployer.run()
+```
+
+AOP provides:
+
+| Feature | Description |
+|-------------------------------|--------------------------------------------------------------------------|
+| **Distributed Agent Deployment** | Deploy agents as independent services |
+| **Agent Discovery** | Built-in discovery tools for finding and connecting to agents |
+| **Standardized Protocol** | MCP-compatible interface for seamless integration |
+| **Dynamic Management** | Add, remove, and manage agents at runtime |
+| **Scalable Architecture** | Support for multiple agent clusters and load balancing |
+| **Enterprise Integration** | Easy integration with existing systems and workflows |
+
+Perfect for deploying large-scale multi-agent systems. [Read the complete AOP documentation](https://docs.swarms.world/en/latest/swarms/structs/aop/)
+
+---
+
 ## Documentation
 
 Documentation is located here at: [docs.swarms.world](https://docs.swarms.world)
@@ -722,6 +797,7 @@ Explore comprehensive examples and tutorials to learn how to use Swarms effectiv
 | **Multi-Agent Architecture** | Agents as Tools | Using agents as tools in workflows | [Agents as Tools](https://docs.swarms.world/en/latest/swarms/examples/agents_as_tools/) |
 | **Multi-Agent Architecture** | Aggregate Responses | Combining multiple agent outputs | [Aggregate Examples](https://docs.swarms.world/en/latest/swarms/examples/aggregate/) |
 | **Multi-Agent Architecture** | Interactive GroupChat | Real-time agent interactions | [Interactive GroupChat](https://docs.swarms.world/en/latest/swarms/examples/igc_example/) |
+| **Deployment Solutions** | Agent Orchestration Protocol (AOP) | Deploy agents as distributed services with discovery and management | [AOP Reference](https://docs.swarms.world/en/latest/swarms/structs/aop/) |
 | **Applications** | Advanced Research System | Multi-agent research system inspired by Anthropic's research methodology | [AdvancedResearch](https://github.com/The-Swarm-Corporation/AdvancedResearch) |
 | **Applications** | Hospital Simulation | Healthcare simulation system using multi-agent architecture | [HospitalSim](https://github.com/The-Swarm-Corporation/HospitalSim) |
 | **Applications** | Browser Agents | Web automation with agents | [Browser Agents](https://docs.swarms.world/en/latest/swarms/examples/swarms_of_browser_agents/) |
diff --git a/moa_seq_example.py b/moa_seq_example.py
new file mode 100644
index 00000000..a06cd78d
--- /dev/null
+++ b/moa_seq_example.py
@@ -0,0 +1,21 @@
+from swarms.structs.self_moa_seq import SelfMoASeq
+
+# Example usage
+if __name__ == "__main__":
+
+    # Initialize
+    moa_seq = SelfMoASeq(
+        model_name="gpt-4o-mini",
+        temperature=0.7,
+        window_size=6,
+        verbose=True,
+        num_samples=4,
+    )
+
+    # Run
+    task = (
+        "Describe an effective treatment plan for a patient with a broken rib. "
+        "Include immediate care, pain management, expected recovery timeline, and potential complications to watch for."
+    )
+
+    result = moa_seq.run(task)
diff --git a/swarms/schemas/dynamic_swarm.py b/swarms/schemas/dynamic_swarm.py
new file mode 100644
index 00000000..fde33f32
--- /dev/null
+++ b/swarms/schemas/dynamic_swarm.py
@@ -0,0 +1,38 @@
+from pydantic import BaseModel
+from swarms.tools.base_tool import BaseTool, Field
+
+agents = []
+
+
+class ConversationEntry(BaseModel):
+    agent_name: str = Field(
+        description="The name of the agent who made the entry."
+    )
+    message: str = Field(description="The message sent by the agent.")
+
+
+class LeaveConversation(BaseModel):
+    agent_name: str = Field(
+        description="The name of the agent who left the conversation."
+    )
+
+
+class JoinGroupChat(BaseModel):
+    agent_name: str = Field(
+        description="The name of the agent who joined the conversation."
+    )
+    group_chat_name: str = Field(
+        description="The name of the group chat."
+    )
+    initial_message: str = Field(
+        description="The initial message sent by the agent."
+    )
+
+
+conversation_entry = BaseTool().base_model_to_dict(ConversationEntry)
+leave_conversation = BaseTool().base_model_to_dict(LeaveConversation)
+join_group_chat = BaseTool().base_model_to_dict(JoinGroupChat)
+
+print(conversation_entry)
+print(leave_conversation)
+print(join_group_chat)
diff --git a/swarms/structs/collaborative_utils.py b/swarms/structs/collaborative_utils.py
new file mode 100644
index 00000000..c761e340
--- /dev/null
+++ b/swarms/structs/collaborative_utils.py
@@ -0,0 +1,77 @@
+import traceback
+
+from loguru import logger
+
+from swarms.structs.deep_discussion import one_on_one_debate
+from swarms.structs.omni_agent_types import AgentListType, AgentType
+
+
+def talk_to_agent(
+    current_agent: AgentType,
+    agents: AgentListType,
+    task: str,
+    agent_name: str,
+    max_loops: int = 1,
+    output_type: str = "str-all-except-first",
+):
+    """
+    Initiate a one-on-one debate between the current agent and a named target agent
+    from a provided list, using a specified task or message as the debate topic.
+
+    This function searches through the provided list of agents for an agent whose
+    'agent_name' matches the specified `agent_name`. If found, it runs a turn-based
+    debate between `current_agent` and the target agent, using the `one_on_one_debate`
+    utility for a specified number of conversational loops.
+
+    Args:
+        current_agent (AgentType): The agent initiating the debate.
+        agents (AgentListType): The list of agent objects (must have 'agent_name' attributes).
+        task (str): The task, question, or message that serves as the debate topic.
+        agent_name (str): The name of the agent to engage in the debate with (must match 'agent_name').
+        max_loops (int, optional): Number of debate turns per agent. Defaults to 1.
+        output_type (str, optional): The format for the debate's output as returned by
+            `one_on_one_debate`. Defaults to "str-all-except-first".
+
+    Returns:
+        list: The formatted conversation history generated by `one_on_one_debate`.
+
+    Raises:
+        ValueError: If no agent with the specified name exists in the agent list.
+        Exception: Propagates any error encountered during debate setup or execution.
+
+    Example:
+        >>> talk_to_agent(
+        ...     current_agent=alice,
+        ...     agents=[alice, bob],
+        ...     task="Summarize and critique the given proposal.",
+        ...     agent_name="bob",
+        ...     max_loops=2
+        ... )
+    """
+    try:
+        target_agent = None
+        for agent in agents:
+            if (
+                hasattr(agent, "agent_name")
+                and agent.agent_name == agent_name
+            ):
+                target_agent = agent
+                break
+
+        if target_agent is None:
+            raise ValueError(
+                f"Agent '{agent_name}' not found in agent list."
+            )
+
+        # Initiate a one-on-one debate between the current agent and the target agent.
+        return one_on_one_debate(
+            max_loops=max_loops,
+            agents=[current_agent, target_agent],
+            task=task,
+            output_type=output_type,
+        )
+    except Exception as e:
+        logger.error(
+            f"Error talking to agent: {e} Traceback: {traceback.format_exc()}"
+        )
+        raise e
diff --git a/swarms/structs/self_moa_seq.py b/swarms/structs/self_moa_seq.py
new file mode 100644
index 00000000..0744332e
--- /dev/null
+++ b/swarms/structs/self_moa_seq.py
@@ -0,0 +1,397 @@
+from datetime import datetime
+from functools import wraps
+from typing import Any, Dict, List, Optional
+
+from loguru import logger
+from tenacity import (
+    before_sleep_log,
+    retry,
+    retry_if_exception_type,
+    stop_after_attempt,
+    wait_exponential,
+)
+
+from swarms.structs.agent import Agent
+
+
+def retry_with_instance_config(func):
+    """
+    Decorator that applies retry configuration using instance variables.
+    This allows the retry decorator to access instance configuration.
+    """
+
+    @wraps(func)
+    def wrapper(self, *args, **kwargs):
+        # Create retry decorator with instance configuration
+        retry_decorator = retry(
+            stop=stop_after_attempt(self.max_retries + 1),
+            wait=wait_exponential(
+                multiplier=self.retry_backoff_multiplier,
+                min=self.retry_delay,
+                max=self.retry_max_delay,
+            ),
+            retry=retry_if_exception_type((Exception,)),
+            before_sleep=before_sleep_log(logger, "WARNING"),
+        )
+
+        # Apply the retry decorator to the function
+        retried_func = retry_decorator(func)
+        return retried_func(self, *args, **kwargs)
+
+    return wrapper
+
+
+class SelfMoASeq:
+    """
+    Self-MoA-Seq: Sequential Self-Mixture of Agents
+
+    An ensemble method that generates multiple outputs from a single
+    high-performing model and aggregates them sequentially using a
+    sliding window approach. This addresses context length constraints
+    while maintaining the effectiveness of in-model diversity.
+
+    Architecture:
+    - Phase 1: Generate initial samples from the proposer model
+    - Phase 2: Aggregate outputs using sliding window with synthesized bias
+    - Phase 3: Iterate until all samples are processed
+    """
+
+    def __init__(
+        self,
+        name: str = "SelfMoASeq",
+        description: str = "Self-MoA-Seq: Sequential Self-Mixture of Agents",
+        model_name: str = "gpt-4o-mini",
+        temperature: float = 0.7,
+        window_size: int = 6,
+        reserved_slots: int = 3,
+        max_iterations: int = 10,
+        max_tokens: int = 2000,
+        num_samples: int = 30,
+        enable_logging: bool = True,
+        log_level: str = "INFO",
+        verbose: bool = True,
+        proposer_model_name: Optional[str] = None,
+        aggregator_model_name: Optional[str] = None,
+        max_retries: int = 3,
+        retry_delay: float = 1.0,
+        retry_backoff_multiplier: float = 2.0,
+        retry_max_delay: float = 60.0,
+    ):
+        # Validate parameters
+        if window_size < 2:
+            raise ValueError("window_size must be at least 2")
+        if reserved_slots >= window_size:
+            raise ValueError(
+                "reserved_slots must be less than window_size"
+            )
+        if not 0 <= temperature <= 2:
+            raise ValueError("temperature must be between 0 and 2")
+        if max_iterations < 1:
+            raise ValueError("max_iterations must be at least 1")
+        if num_samples < 2:
+            raise ValueError("num_samples must be at least 2")
+        if max_retries < 0:
+            raise ValueError("max_retries must be non-negative")
+        if retry_delay < 0:
+            raise ValueError("retry_delay must be non-negative")
+        if retry_backoff_multiplier < 1:
+            raise ValueError("retry_backoff_multiplier must be >= 1")
+        if retry_max_delay < retry_delay:
+            raise ValueError("retry_max_delay must be >= retry_delay")
+
+        # Store parameters
+        self.model_name = model_name
+        self.temperature = temperature
+        self.window_size = window_size
+        self.reserved_slots = reserved_slots
+        self.max_iterations = max_iterations
+        self.max_tokens = max_tokens
+        self.num_samples = num_samples
+        self.enable_logging = enable_logging
+        self.log_level = log_level
+        self.verbose = verbose
+
+        # Retry configuration
+        self.max_retries = max_retries
+        self.retry_delay = retry_delay
+        self.retry_backoff_multiplier = retry_backoff_multiplier
+        self.retry_max_delay = retry_max_delay
+
+        # Allow model overrides
+        proposer_model = proposer_model_name or self.model_name
+        aggregator_model = aggregator_model_name or self.model_name
+
+        # Setup logging
+        logger.info(
+            f"Initializing Self-MoA-Seq with model: {self.model_name}"
+        )
+
+        # Initialize proposer agent (generates multiple samples)
+        self.proposer = Agent(
+            agent_name="SelfMoASeq-Proposer",
+            system_prompt=(
+                "You are a sample generator. Generate diverse, high-quality responses "
+                "to the given task. Vary your approach while maintaining quality."
+            ),
+            model_name=proposer_model,
+            temperature=self.temperature,
+            max_loops=1,
+            verbose=self.verbose,
+        )
+
+        # Initialize aggregator agent (synthesizes outputs)
+        self.aggregator = Agent(
+            agent_name="SelfMoASeq-Aggregator",
+            system_prompt=(
+                "You are an expert synthesizer. Analyze the provided responses and "
+                "synthesize them into a single, high-quality output. Consider the "
+                "strengths of each response and combine them effectively. Pay special "
+                "attention to any highlighted best response, as it provides high-quality guidance."
+            ),
+            model_name=aggregator_model,
+            temperature=0.0,  # Deterministic aggregation
+            max_loops=1,
+            verbose=self.verbose,
+        )
+
+        # Metrics tracking
+        self.metrics: Dict[str, Any] = {
+            "total_samples_generated": 0,
+            "total_aggregations": 0,
+            "total_tokens_used": 0,
+            "execution_time_seconds": 0,
+        }
+
+        logger.info("Self-MoA-Seq initialization complete")
+
+    @retry_with_instance_config
+    def _generate_samples(
+        self, task: str, num_samples: int
+    ) -> List[str]:
+        """
+        Generate multiple samples from the proposer model.
+
+        Args:
+            task: The task description
+            num_samples: Number of samples to generate
+
+        Returns:
+            List of generated samples
+        """
+        logger.info(f"Generating {num_samples} samples for task")
+        samples = []
+
+        try:
+            for i in range(num_samples):
+                logger.debug(f"Generating sample {i+1}/{num_samples}")
+                sample = self.proposer.run(task)
+                samples.append(sample)
+                self.metrics["total_samples_generated"] += 1
+
+            logger.success(
+                f"Successfully generated {len(samples)} samples"
+            )
+            return samples
+
+        except Exception as e:
+            logger.error(f"Error during sample generation: {str(e)}")
+            raise
+
+    def _format_aggregation_prompt(
+        self,
+        task: str,
+        samples: List[str],
+        best_so_far: Optional[str] = None,
+    ) -> str:
+        """
+        Format the aggregation prompt with sliding window.
+
+        Args:
+            task: Original task
+            samples: List of samples to aggregate
+            best_so_far: Previously synthesized best output
+
+        Returns:
+            Formatted aggregation prompt
+        """
+        prompt = f"Original Task:\n{task}\n\n"
+
+        if best_so_far:
+            prompt += f"Current Best Response (synthesized from previous iterations):\n{best_so_far}\n\n"
+
+        prompt += "Candidate Responses to Synthesize:\n"
+        for i, sample in enumerate(samples, 1):
+            prompt += f"\n[Response {i}]:\n{sample}\n"
+
+        prompt += (
+            "\nProvide a comprehensive synthesis that combines the strengths of "
+            "all responses while maintaining coherence and quality."
+        )
+
+        return prompt
+
+    @retry_with_instance_config
+    def _aggregate_window(
+        self,
+        task: str,
+        window_samples: List[str],
+        best_so_far: Optional[str] = None,
+    ) -> str:
+        """
+        Aggregate a window of samples.
+
+        Args:
+            task: Original task
+            window_samples: Samples in current window
+            best_so_far: Best aggregation so far
+
+        Returns:
+            Synthesized output
+        """
+        logger.debug(
+            f"Aggregating window of {len(window_samples)} samples"
+        )
+
+        try:
+            prompt = self._format_aggregation_prompt(
+                task,
+                window_samples,
+                best_so_far,
+            )
+
+            aggregated = self.aggregator.run(prompt)
+            self.metrics["total_aggregations"] += 1
+
+            logger.debug("Window aggregation complete")
+            return aggregated
+
+        except Exception as e:
+            logger.error(f"Error during window aggregation: {str(e)}")
+            raise
+
+    @retry_with_instance_config
+    def run(
+        self,
+        task: str,
+    ) -> Dict[str, Any]:
+        """
+        Execute Self-MoA-Seq on the given task.
+
+        This method implements the sequential aggregation algorithm:
+        1. Generate num_samples from the proposer model
+        2. Use sliding window to aggregate in chunks
+        3. Progressively synthesize outputs, biasing aggregator toward best
+        4. Return final synthesized output
+
+        Args:
+            task: The task to process
+
+        Returns:
+            Dictionary containing:
+            - final_output: The best synthesized response
+            - all_samples: List of generated samples
+            - aggregation_steps: Number of aggregation iterations
+            - metrics: Performance metrics
+        """
+        logger.info(
+            f"Starting Self-MoA-Seq run with {self.num_samples} samples"
+        )
+        start_time = datetime.now()
+
+        try:
+            # Validate input
+            if not task or not isinstance(task, str):
+                raise ValueError("task must be a non-empty string")
+
+            # Phase 1: Generate samples
+            logger.info("Phase 1: Generating initial samples")
+            samples = self._generate_samples(task, self.num_samples)
+
+            # Phase 2: Sequential aggregation with sliding window
+            logger.info("Phase 2: Sequential aggregation")
+            best_output = samples[0]
+            aggregation_step = 0
+
+            # Process samples in windows
+            remaining_samples = samples[1:]
+
+            while remaining_samples:
+                aggregation_step += 1
+                logger.info(
+                    f"Aggregation iteration {aggregation_step}, "
+                    f"remaining samples: {len(remaining_samples)}"
+                )
+
+                # Create window: reserved slots + new samples
+                window_size = min(
+                    self.window_size - self.reserved_slots,
+                    len(remaining_samples),
+                )
+                current_window = remaining_samples[:window_size]
+                remaining_samples = remaining_samples[window_size:]
+
+                # Aggregate with bias toward best output
+                window_with_best = [best_output] + current_window
+                best_output = self._aggregate_window(
+                    task,
+                    window_with_best,
+                    best_output,
+                )
+
+                if aggregation_step >= self.max_iterations:
+                    logger.warning(
+                        f"Reached max aggregation iterations ({self.max_iterations})"
+                    )
+                    break
+
+            # Calculate metrics
+            elapsed = (datetime.now() - start_time).total_seconds()
+            self.metrics["execution_time_seconds"] = elapsed
+
+            result = {
+                "final_output": best_output,
+                "all_samples": samples,
+                "aggregation_steps": aggregation_step,
+                "metrics": self.metrics.copy(),
+                "task": task,
+                "timestamp": datetime.now().isoformat(),
+            }
+
+            logger.success(
+                f"Self-MoA-Seq completed in {elapsed:.2f}s "
+                f"with {aggregation_step} aggregation iterations"
+            )
+
+            if self.verbose:
+                self._log_summary(result)
+
+            return result
+
+        except Exception as e:
+            logger.error(f"Fatal error in Self-MoA-Seq.run: {str(e)}")
+            raise
+
+    def _log_summary(self, result: Dict[str, Any]) -> None:
+        """Log execution summary."""
+        logger.info("=" * 60)
+        logger.info("Self-MoA-Seq Execution Summary")
+        logger.info("=" * 60)
+        logger.info(
+            f"Total samples generated: {self.metrics['total_samples_generated']}"
+        )
+        logger.info(
+            f"Aggregation iterations: {result['aggregation_steps']}"
+        )
+        logger.info(
+            f"Execution time: {self.metrics['execution_time_seconds']:.2f}s"
+        )
+        logger.info(
+            f"Final output length: {len(result['final_output'])} chars"
+        )
+        logger.info("=" * 60)
+
+    def get_metrics(self) -> Dict[str, Any]:
+        """Get current performance metrics."""
+        return self.metrics.copy()
diff --git a/tests/structs/test_self_moa_seq.py b/tests/structs/test_self_moa_seq.py
new file mode 100644
index 00000000..f3c50e97
--- /dev/null
+++ b/tests/structs/test_self_moa_seq.py
@@ -0,0 +1,709 @@
+from unittest.mock import Mock, patch
+
+import pytest
+
+from swarms.structs.self_moa_seq import SelfMoASeq
+
+# ============================================================================
+# Fixtures
+# ============================================================================
+
+
+@pytest.fixture
+def basic_seq():
+    """Create a basic SelfMoASeq instance for testing."""
+    return SelfMoASeq(
+        num_samples=3,
+ window_size=4, + reserved_slots=2, + max_iterations=5, + verbose=False, + enable_logging=False, + ) + + +@pytest.fixture +def custom_retry_seq(): + """Create a SelfMoASeq instance with custom retry parameters.""" + return SelfMoASeq( + num_samples=2, + max_retries=5, + retry_delay=0.5, + retry_backoff_multiplier=1.5, + retry_max_delay=10.0, + verbose=False, + enable_logging=False, + ) + + +@pytest.fixture +def mock_agents(): + """Create mock agents for testing.""" + proposer = Mock() + aggregator = Mock() + return proposer, aggregator + + +# ============================================================================ +# Initialization and Parameter Validation Tests +# ============================================================================ + + +def test_default_initialization(): + """Test that SelfMoASeq initializes with default parameters.""" + seq = SelfMoASeq() + + assert seq.model_name == "gpt-4o-mini" + assert seq.temperature == 0.7 + assert seq.window_size == 6 + assert seq.reserved_slots == 3 + assert seq.max_iterations == 10 + assert seq.max_tokens == 2000 + assert seq.num_samples == 30 + assert seq.enable_logging is True + assert seq.log_level == "INFO" + assert seq.verbose is True + assert seq.max_retries == 3 + assert seq.retry_delay == 1.0 + assert seq.retry_backoff_multiplier == 2.0 + assert seq.retry_max_delay == 60.0 + + +def test_custom_initialization(): + """Test initialization with custom parameters.""" + seq = SelfMoASeq( + model_name="custom-model", + temperature=0.5, + window_size=8, + reserved_slots=2, + max_iterations=15, + max_tokens=3000, + num_samples=20, + enable_logging=False, + log_level="DEBUG", + verbose=False, + proposer_model_name="proposer-model", + aggregator_model_name="aggregator-model", + max_retries=5, + retry_delay=2.0, + retry_backoff_multiplier=3.0, + retry_max_delay=120.0, + ) + + assert seq.model_name == "custom-model" + assert seq.temperature == 0.5 + assert seq.window_size == 8 + assert seq.reserved_slots == 2 + assert seq.max_iterations == 15 + assert seq.max_tokens == 3000 + assert seq.num_samples == 20 + assert seq.enable_logging is False + assert seq.log_level == "DEBUG" + assert seq.verbose is False + assert seq.max_retries == 5 + assert seq.retry_delay == 2.0 + assert seq.retry_backoff_multiplier == 3.0 + assert seq.retry_max_delay == 120.0 + + +def test_window_size_validation(): + """Test window_size parameter validation.""" + # Valid window_size + seq = SelfMoASeq(window_size=2) + assert seq.window_size == 2 + + # Invalid window_size + with pytest.raises( + ValueError, match="window_size must be at least 2" + ): + SelfMoASeq(window_size=1) + + +def test_reserved_slots_validation(): + """Test reserved_slots parameter validation.""" + # Valid reserved_slots + seq = SelfMoASeq(window_size=6, reserved_slots=3) + assert seq.reserved_slots == 3 + + # Invalid reserved_slots (>= window_size) + with pytest.raises( + ValueError, + match="reserved_slots must be less than window_size", + ): + SelfMoASeq(window_size=6, reserved_slots=6) + + +def test_temperature_validation(): + """Test temperature parameter validation.""" + # Valid temperature + seq = SelfMoASeq(temperature=1.5) + assert seq.temperature == 1.5 + + # Invalid temperature (too high) + with pytest.raises( + ValueError, match="temperature must be between 0 and 2" + ): + SelfMoASeq(temperature=2.5) + + # Invalid temperature (negative) + with pytest.raises( + ValueError, match="temperature must be between 0 and 2" + ): + SelfMoASeq(temperature=-0.1) + + +def 
test_max_iterations_validation(): + """Test max_iterations parameter validation.""" + # Valid max_iterations + seq = SelfMoASeq(max_iterations=5) + assert seq.max_iterations == 5 + + # Invalid max_iterations + with pytest.raises( + ValueError, match="max_iterations must be at least 1" + ): + SelfMoASeq(max_iterations=0) + + +def test_num_samples_validation(): + """Test num_samples parameter validation.""" + # Valid num_samples + seq = SelfMoASeq(num_samples=5) + assert seq.num_samples == 5 + + # Invalid num_samples + with pytest.raises( + ValueError, match="num_samples must be at least 2" + ): + SelfMoASeq(num_samples=1) + + +def test_retry_parameters_validation(): + """Test retry parameters validation.""" + # Valid retry parameters + seq = SelfMoASeq( + max_retries=5, + retry_delay=2.0, + retry_backoff_multiplier=1.5, + retry_max_delay=10.0, + ) + assert seq.max_retries == 5 + assert seq.retry_delay == 2.0 + assert seq.retry_backoff_multiplier == 1.5 + assert seq.retry_max_delay == 10.0 + + # Invalid max_retries + with pytest.raises( + ValueError, match="max_retries must be non-negative" + ): + SelfMoASeq(max_retries=-1) + + # Invalid retry_delay + with pytest.raises( + ValueError, match="retry_delay must be non-negative" + ): + SelfMoASeq(retry_delay=-1.0) + + # Invalid retry_backoff_multiplier + with pytest.raises( + ValueError, match="retry_backoff_multiplier must be >= 1" + ): + SelfMoASeq(retry_backoff_multiplier=0.5) + + # Invalid retry_max_delay + with pytest.raises( + ValueError, match="retry_max_delay must be >= retry_delay" + ): + SelfMoASeq(retry_delay=10.0, retry_max_delay=5.0) + + +# ============================================================================ +# Retry Functionality Tests +# ============================================================================ + + +def test_retry_decorator_property(basic_seq): + """Test that retry decorator property works correctly.""" + decorator = basic_seq.retry_decorator + assert callable(decorator) + + +def test_get_retry_decorator(basic_seq): + """Test _get_retry_decorator method.""" + decorator = basic_seq._get_retry_decorator() + assert callable(decorator) + + +def test_retry_configuration_inheritance(custom_retry_seq): + """Test that retry configuration is properly inherited.""" + assert custom_retry_seq.max_retries == 5 + assert custom_retry_seq.retry_delay == 0.5 + assert custom_retry_seq.retry_backoff_multiplier == 1.5 + assert custom_retry_seq.retry_max_delay == 10.0 + + +def test_retry_functionality_with_mock(custom_retry_seq, mock_agents): + """Test retry functionality with mocked agents.""" + proposer, aggregator = mock_agents + + # Configure mock to fail first time, succeed second time + proposer.run.side_effect = [ + Exception("Simulated failure"), + "Sample 1", + "Sample 2", + ] + + aggregator.run.side_effect = [ + Exception("Simulated aggregation failure"), + "Aggregated result", + ] + + # Patch the agents + with patch.object( + custom_retry_seq, "proposer", proposer + ), patch.object(custom_retry_seq, "aggregator", aggregator): + + # This should succeed after retries + result = custom_retry_seq.run("Test task") + + assert result is not None + assert "final_output" in result + assert "all_samples" in result + assert "metrics" in result + + +# ============================================================================ +# Core Methods Tests +# ============================================================================ + + +def test_generate_samples_success(basic_seq, mock_agents): + """Test successful sample 
generation.""" + proposer, _ = mock_agents + proposer.run.return_value = "Generated sample" + + with patch.object(basic_seq, "proposer", proposer): + samples = basic_seq._generate_samples("Test task", 3) + + assert len(samples) == 3 + assert all(sample == "Generated sample" for sample in samples) + assert basic_seq.metrics["total_samples_generated"] == 3 + + +def test_generate_samples_with_retry(basic_seq, mock_agents): + """Test sample generation with retry on failure.""" + proposer, _ = mock_agents + proposer.run.side_effect = [ + Exception("First failure"), + Exception("Second failure"), + "Sample 1", + "Sample 2", + "Sample 3", + ] + + with patch.object(basic_seq, "proposer", proposer): + samples = basic_seq._generate_samples("Test task", 3) + + assert len(samples) == 3 + assert basic_seq.metrics["total_samples_generated"] == 3 + + +def test_format_aggregation_prompt(basic_seq): + """Test aggregation prompt formatting.""" + task = "Test task" + samples = ["Sample 1", "Sample 2", "Sample 3"] + best_so_far = "Best response" + + prompt = basic_seq._format_aggregation_prompt( + task, samples, best_so_far + ) + + assert "Original Task:" in prompt + assert task in prompt + assert "Current Best Response" in prompt + assert best_so_far in prompt + assert "Candidate Responses to Synthesize" in prompt + assert "Sample 1" in prompt + assert "Sample 2" in prompt + assert "Sample 3" in prompt + + +def test_format_aggregation_prompt_no_best(basic_seq): + """Test aggregation prompt formatting without best_so_far.""" + task = "Test task" + samples = ["Sample 1", "Sample 2"] + + prompt = basic_seq._format_aggregation_prompt(task, samples) + + assert "Original Task:" in prompt + assert task in prompt + assert "Current Best Response" not in prompt + assert "Candidate Responses to Synthesize" in prompt + assert "Sample 1" in prompt + assert "Sample 2" in prompt + + +def test_aggregate_window_success(basic_seq, mock_agents): + """Test successful window aggregation.""" + _, aggregator = mock_agents + aggregator.run.return_value = "Aggregated result" + + with patch.object(basic_seq, "aggregator", aggregator): + result = basic_seq._aggregate_window( + "Test task", ["Sample 1", "Sample 2"], "Best so far" + ) + + assert result == "Aggregated result" + assert basic_seq.metrics["total_aggregations"] == 1 + + +def test_aggregate_window_with_retry(basic_seq, mock_agents): + """Test window aggregation with retry on failure.""" + _, aggregator = mock_agents + aggregator.run.side_effect = [ + Exception("First failure"), + Exception("Second failure"), + "Aggregated result", + ] + + with patch.object(basic_seq, "aggregator", aggregator): + result = basic_seq._aggregate_window( + "Test task", ["Sample 1", "Sample 2"] + ) + + assert result == "Aggregated result" + assert basic_seq.metrics["total_aggregations"] == 1 + + +def test_run_method_success(basic_seq, mock_agents): + """Test successful run method execution.""" + proposer, aggregator = mock_agents + + # Configure mocks + proposer.run.return_value = "Generated sample" + aggregator.run.return_value = "Aggregated result" + + with patch.object(basic_seq, "proposer", proposer), patch.object( + basic_seq, "aggregator", aggregator + ): + + result = basic_seq.run("Test task") + + assert isinstance(result, dict) + assert "final_output" in result + assert "all_samples" in result + assert "aggregation_steps" in result + assert "metrics" in result + assert "task" in result + assert "timestamp" in result + + assert result["task"] == "Test task" + assert 
len(result["all_samples"]) == 3 + assert result["final_output"] == "Aggregated result" + + +def test_run_method_with_retry(basic_seq, mock_agents): + """Test run method with retry on failure.""" + proposer, aggregator = mock_agents + + # Configure mocks to fail first time, succeed second time + proposer.run.side_effect = [ + Exception("First failure"), + "Sample 1", + "Sample 2", + "Sample 3", + ] + + aggregator.run.side_effect = [ + Exception("Aggregation failure"), + "Final result", + ] + + with patch.object(basic_seq, "proposer", proposer), patch.object( + basic_seq, "aggregator", aggregator + ): + + result = basic_seq.run("Test task") + + assert result is not None + assert result["final_output"] == "Final result" + + +# ============================================================================ +# Error Handling and Edge Cases Tests +# ============================================================================ + + +def test_run_invalid_task(basic_seq): + """Test run method with invalid task input.""" + with pytest.raises( + ValueError, match="task must be a non-empty string" + ): + basic_seq.run("") + + with pytest.raises( + ValueError, match="task must be a non-empty string" + ): + basic_seq.run(None) + + +def test_run_max_iterations_reached(basic_seq, mock_agents): + """Test run method when max iterations are reached.""" + proposer, aggregator = mock_agents + + # Configure mocks + proposer.run.return_value = "Generated sample" + aggregator.run.return_value = "Aggregated result" + + # Set max_iterations to 1 to trigger the warning + basic_seq.max_iterations = 1 + + with patch.object(basic_seq, "proposer", proposer), patch.object( + basic_seq, "aggregator", aggregator + ): + + result = basic_seq.run("Test task") + + assert result is not None + assert result["aggregation_steps"] <= 1 + + +def test_generate_samples_exception_propagation( + basic_seq, mock_agents +): + """Test that exceptions in sample generation are properly propagated.""" + proposer, _ = mock_agents + proposer.run.side_effect = Exception("Persistent failure") + + with patch.object(basic_seq, "proposer", proposer): + with pytest.raises(Exception, match="Persistent failure"): + basic_seq._generate_samples("Test task", 3) + + +def test_aggregate_window_exception_propagation( + basic_seq, mock_agents +): + """Test that exceptions in window aggregation are properly propagated.""" + _, aggregator = mock_agents + aggregator.run.side_effect = Exception( + "Persistent aggregation failure" + ) + + with patch.object(basic_seq, "aggregator", aggregator): + with pytest.raises( + Exception, match="Persistent aggregation failure" + ): + basic_seq._aggregate_window( + "Test task", ["Sample 1", "Sample 2"] + ) + + +def test_run_exception_propagation(basic_seq, mock_agents): + """Test that exceptions in run method are properly propagated.""" + proposer, _ = mock_agents + proposer.run.side_effect = Exception("Persistent run failure") + + with patch.object(basic_seq, "proposer", proposer): + with pytest.raises(Exception, match="Persistent run failure"): + basic_seq.run("Test task") + + +# ============================================================================ +# Metrics and Logging Tests +# ============================================================================ + + +def test_metrics_initialization(basic_seq): + """Test that metrics are properly initialized.""" + metrics = basic_seq.get_metrics() + + assert isinstance(metrics, dict) + assert "total_samples_generated" in metrics + assert "total_aggregations" in metrics + assert 
"total_tokens_used" in metrics + assert "execution_time_seconds" in metrics + + assert metrics["total_samples_generated"] == 0 + assert metrics["total_aggregations"] == 0 + assert metrics["total_tokens_used"] == 0 + assert metrics["execution_time_seconds"] == 0 + + +def test_metrics_tracking(basic_seq, mock_agents): + """Test that metrics are properly tracked during execution.""" + proposer, aggregator = mock_agents + + proposer.run.return_value = "Generated sample" + aggregator.run.return_value = "Aggregated result" + + with patch.object(basic_seq, "proposer", proposer), patch.object( + basic_seq, "aggregator", aggregator + ): + + result = basic_seq.run("Test task") + + metrics = result["metrics"] + assert metrics["total_samples_generated"] == 3 + assert metrics["total_aggregations"] >= 1 + assert metrics["execution_time_seconds"] > 0 + + +def test_log_summary(basic_seq): + """Test _log_summary method.""" + result = { + "final_output": "Test output", + "aggregation_steps": 2, + "metrics": { + "total_samples_generated": 3, + "execution_time_seconds": 1.5, + }, + } + + # This should not raise an exception + basic_seq._log_summary(result) + + +def test_get_metrics_returns_copy(basic_seq): + """Test that get_metrics returns a copy of metrics.""" + metrics1 = basic_seq.get_metrics() + metrics2 = basic_seq.get_metrics() + + # Should be different objects + assert metrics1 is not metrics2 + + # But should have same content + assert metrics1 == metrics2 + + +# ============================================================================ +# Integration Tests +# ============================================================================ + + +def test_full_integration_small_samples(): + """Test full integration with small number of samples.""" + seq = SelfMoASeq( + num_samples=2, + window_size=3, + reserved_slots=1, + max_iterations=2, + verbose=False, + enable_logging=False, + ) + + proposer, aggregator = Mock(), Mock() + proposer.run.return_value = "Generated sample" + aggregator.run.return_value = "Aggregated result" + + with patch.object(seq, "proposer", proposer), patch.object( + seq, "aggregator", aggregator + ): + + result = seq.run("Integration test task") + + assert result is not None + assert result["task"] == "Integration test task" + assert len(result["all_samples"]) == 2 + assert result["final_output"] == "Aggregated result" + assert result["aggregation_steps"] >= 0 + + +def test_model_name_overrides(): + """Test that model name overrides work correctly.""" + seq = SelfMoASeq( + model_name="base-model", + proposer_model_name="proposer-model", + aggregator_model_name="aggregator-model", + verbose=False, + enable_logging=False, + ) + + # The agents should be initialized with the override names + assert seq.proposer.model_name == "proposer-model" + assert seq.aggregator.model_name == "aggregator-model" + + +def test_temperature_settings(): + """Test that temperature settings are applied correctly.""" + seq = SelfMoASeq( + temperature=0.5, verbose=False, enable_logging=False + ) + + assert seq.proposer.temperature == 0.5 + assert ( + seq.aggregator.temperature == 0.0 + ) # Deterministic aggregation + + +# ============================================================================ +# Performance and Edge Case Tests +# ============================================================================ + + +def test_minimum_valid_configuration(): + """Test with minimum valid configuration.""" + seq = SelfMoASeq( + window_size=2, + reserved_slots=1, + max_iterations=1, + num_samples=2, + 
verbose=False, + enable_logging=False, + ) + + assert seq.window_size == 2 + assert seq.reserved_slots == 1 + assert seq.max_iterations == 1 + assert seq.num_samples == 2 + + +def test_zero_retries(): + """Test with zero retries (should still work but not retry).""" + seq = SelfMoASeq( + max_retries=0, + num_samples=2, + verbose=False, + enable_logging=False, + ) + + assert seq.max_retries == 0 + + proposer, aggregator = Mock(), Mock() + proposer.run.return_value = "Generated sample" + aggregator.run.return_value = "Aggregated result" + + with patch.object(seq, "proposer", proposer), patch.object( + seq, "aggregator", aggregator + ): + + result = seq.run("Test task") + assert result is not None + + +def test_large_configuration(): + """Test with large configuration values.""" + seq = SelfMoASeq( + window_size=20, + reserved_slots=5, + max_iterations=50, + num_samples=100, + max_retries=10, + retry_delay=5.0, + retry_backoff_multiplier=3.0, + retry_max_delay=300.0, + verbose=False, + enable_logging=False, + ) + + assert seq.window_size == 20 + assert seq.reserved_slots == 5 + assert seq.max_iterations == 50 + assert seq.num_samples == 100 + assert seq.max_retries == 10 + assert seq.retry_delay == 5.0 + assert seq.retry_backoff_multiplier == 3.0 + assert seq.retry_max_delay == 300.0 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])