parent c623216b70
commit ebe5f52988
@@ -0,0 +1,21 @@
from swarms.structs.self_moa_seq import SelfMoASeq

# Example usage
if __name__ == "__main__":

    # Initialize
    moa_seq = SelfMoASeq(
        model_name="gpt-4o-mini",
        temperature=0.7,
        window_size=6,
        verbose=True,
        num_samples=4,
    )

    # Run
    task = (
        "Describe an effective treatment plan for a patient with a broken rib. "
        "Include immediate care, pain management, expected recovery timeline, and potential complications to watch for."
    )

    result = moa_seq.run(task)
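
    # `run` returns a dict (see SelfMoASeq.run for the full schema); a
    # minimal sketch of reading it, assuming those documented keys:
    print(result["final_output"])
    print(result["metrics"])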
@@ -0,0 +1,38 @@
from pydantic import BaseModel
from swarms.tools.base_tool import BaseTool, Field

agents = []


class ConversationEntry(BaseModel):
    agent_name: str = Field(
        description="The name of the agent who made the entry."
    )
    message: str = Field(description="The message sent by the agent.")


class LeaveConversation(BaseModel):
    agent_name: str = Field(
        description="The name of the agent who left the conversation."
    )


class JoinGroupChat(BaseModel):
    agent_name: str = Field(
        description="The name of the agent who joined the conversation."
    )
    group_chat_name: str = Field(
        description="The name of the group chat."
    )
    initial_message: str = Field(
        description="The initial message sent by the agent."
    )


conversation_entry = BaseTool().base_model_to_dict(ConversationEntry)
leave_conversation = BaseTool().base_model_to_dict(LeaveConversation)
join_group_chat = BaseTool().base_model_to_dict(JoinGroupChat)

print(conversation_entry)
print(leave_conversation)
print(join_group_chat)
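
# A minimal validation sketch (plain pydantic; assumes pydantic v2's
# model_dump, nothing swarms-specific): the BaseModel classes above can
# also validate incoming payloads before being turned into tool schemas.
entry = ConversationEntry(agent_name="alice", message="Hello, team.")
print(entry.model_dump())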
@@ -0,0 +1,77 @@
import traceback

from loguru import logger

from swarms.structs.deep_discussion import one_on_one_debate
from swarms.structs.omni_agent_types import AgentListType, AgentType


def talk_to_agent(
    current_agent: AgentType,
    agents: AgentListType,
    task: str,
    agent_name: str,
    max_loops: int = 1,
    output_type: str = "str-all-except-first",
):
    """
    Initiate a one-on-one debate between the current agent and a named target agent
    from a provided list, using a specified task or message as the debate topic.

    This function searches through the provided list of agents for an agent whose
    'agent_name' matches the specified `agent_name`. If found, it runs a turn-based
    debate between `current_agent` and the target agent, using the `one_on_one_debate`
    utility for a specified number of conversational loops.

    Args:
        current_agent (AgentType): The agent initiating the debate.
        agents (AgentListType): The list of agent objects (must have 'agent_name' attributes).
        task (str): The task, question, or message that serves as the debate topic.
        agent_name (str): The name of the agent to engage in the debate with (must match 'agent_name').
        max_loops (int, optional): Number of debate turns per agent. Defaults to 1.
        output_type (str, optional): The format for the debate's output as returned by
            `one_on_one_debate`. Defaults to "str-all-except-first".

    Returns:
        list: The formatted conversation history generated by `one_on_one_debate`.

    Raises:
        ValueError: If no agent with the specified name exists in the agent list.
        Exception: Propagates any error encountered during debate setup or execution.

    Example:
        >>> talk_to_agent(
        ...     current_agent=alice,
        ...     agents=[alice, bob],
        ...     task="Summarize and critique the given proposal.",
        ...     agent_name="bob",
        ...     max_loops=2
        ... )
    """
    try:
        target_agent = None
        for agent in agents:
            if (
                hasattr(agent, "agent_name")
                and agent.agent_name == agent_name
            ):
                target_agent = agent
                break

        if target_agent is None:
            raise ValueError(
                f"Agent '{agent_name}' not found in agent list."
            )

        # Initiate a one-on-one debate between the current agent and the target agent.
        return one_on_one_debate(
            max_loops=max_loops,
            agents=[current_agent, target_agent],
            task=task,
            output_type=output_type,
        )
    except Exception as e:
        logger.error(
            f"Error talking to agent: {e} Traceback: {traceback.format_exc()}"
        )
        raise e
@@ -0,0 +1,397 @@
from datetime import datetime
from functools import wraps
from typing import Any, Dict, List, Optional

from loguru import logger
from tenacity import (
    before_sleep_log,
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from swarms.structs.agent import Agent


def retry_with_instance_config(func):
    """
    Decorator that applies retry configuration using instance variables.
    This allows the retry decorator to access instance configuration.
    """

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        # Create retry decorator with instance configuration
        retry_decorator = retry(
            stop=stop_after_attempt(self.max_retries + 1),
            wait=wait_exponential(
                multiplier=self.retry_backoff_multiplier,
                min=self.retry_delay,
                max=self.retry_max_delay,
            ),
            retry=retry_if_exception_type((Exception,)),
            before_sleep=before_sleep_log(logger, "WARNING"),
            # Re-raise the original exception instead of tenacity's
            # RetryError once retries are exhausted, so callers see the
            # real error type (the tests rely on this).
            reraise=True,
        )

        # Apply the retry decorator to the function
        retried_func = retry_decorator(func)
        return retried_func(self, *args, **kwargs)

    return wrapper
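
# A minimal sketch of how the decorator picks up per-instance settings
# (hypothetical `Flaky` class, for illustration only). Each call builds a
# fresh tenacity policy from the instance, so two objects with different
# retry settings retry differently:
#
#     class Flaky:
#         max_retries = 2
#         retry_delay = 0.1
#         retry_backoff_multiplier = 2.0
#         retry_max_delay = 1.0
#
#         @retry_with_instance_config
#         def call(self):
#             ...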


class SelfMoASeq:
    """
    Self-MoA-Seq: Sequential Self-Mixture of Agents

    An ensemble method that generates multiple outputs from a single
    high-performing model and aggregates them sequentially using a
    sliding window approach. This addresses context length constraints
    while maintaining the effectiveness of in-model diversity.

    Architecture:
    - Phase 1: Generate initial samples from the proposer model
    - Phase 2: Aggregate outputs with a sliding window, biasing the
      aggregator toward the best synthesis so far
    - Phase 3: Iterate until all samples are processed
    """

    def __init__(
        self,
        name: str = "SelfMoASeq",
        description: str = "Self-MoA-Seq: Sequential Self-Mixture of Agents",
        model_name: str = "gpt-4o-mini",
        temperature: float = 0.7,
        window_size: int = 6,
        reserved_slots: int = 3,
        max_iterations: int = 10,
        max_tokens: int = 2000,
        num_samples: int = 30,
        enable_logging: bool = True,
        log_level: str = "INFO",
        verbose: bool = True,
        proposer_model_name: Optional[str] = None,
        aggregator_model_name: Optional[str] = None,
        max_retries: int = 3,
        retry_delay: float = 1.0,
        retry_backoff_multiplier: float = 2.0,
        retry_max_delay: float = 60.0,
    ):
        # Validate parameters
        if window_size < 2:
            raise ValueError("window_size must be at least 2")
        if reserved_slots >= window_size:
            raise ValueError(
                "reserved_slots must be less than window_size"
            )
        if not 0 <= temperature <= 2:
            raise ValueError("temperature must be between 0 and 2")
        if max_iterations < 1:
            raise ValueError("max_iterations must be at least 1")
        if num_samples < 2:
            raise ValueError("num_samples must be at least 2")
        if max_retries < 0:
            raise ValueError("max_retries must be non-negative")
        if retry_delay < 0:
            raise ValueError("retry_delay must be non-negative")
        if retry_backoff_multiplier < 1:
            raise ValueError("retry_backoff_multiplier must be >= 1")
        if retry_max_delay < retry_delay:
            raise ValueError("retry_max_delay must be >= retry_delay")

        # Store parameters (name and description were previously
        # accepted but silently dropped)
        self.name = name
        self.description = description
        self.model_name = model_name
        self.temperature = temperature
        self.window_size = window_size
        self.reserved_slots = reserved_slots
        self.max_iterations = max_iterations
        self.max_tokens = max_tokens
        self.num_samples = num_samples
        self.enable_logging = enable_logging
        self.log_level = log_level
        self.verbose = verbose

        # Retry configuration
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.retry_backoff_multiplier = retry_backoff_multiplier
        self.retry_max_delay = retry_max_delay

        # Allow model overrides
        proposer_model = proposer_model_name or self.model_name
        aggregator_model = aggregator_model_name or self.model_name

        # Setup logging
        logger.info(
            f"Initializing Self-MoA-Seq with model: {self.model_name}"
        )

        # Initialize proposer agent (generates multiple samples)
        self.proposer = Agent(
            agent_name="SelfMoASeq-Proposer",
            system_prompt=(
                "You are a sample generator. Generate diverse, high-quality responses "
                "to the given task. Vary your approach while maintaining quality."
            ),
            model_name=proposer_model,
            temperature=self.temperature,
            max_loops=1,
            verbose=self.verbose,
        )

        # Initialize aggregator agent (synthesizes outputs)
        self.aggregator = Agent(
            agent_name="SelfMoASeq-Aggregator",
            system_prompt=(
                "You are an expert synthesizer. Analyze the provided responses and "
                "synthesize them into a single, high-quality output. Consider the "
                "strengths of each response and combine them effectively. Pay special "
                "attention to any highlighted best response, as it provides high-quality guidance."
            ),
            model_name=aggregator_model,
            temperature=0.0,  # Deterministic aggregation
            max_loops=1,
            verbose=self.verbose,
        )

        # Metrics tracking
        self.metrics: Dict[str, Any] = {
            "total_samples_generated": 0,
            "total_aggregations": 0,
            "total_tokens_used": 0,
            "execution_time_seconds": 0,
        }

        logger.info("Self-MoA-Seq initialization complete")

    @retry_with_instance_config
    def _generate_samples(
        self, task: str, num_samples: int
    ) -> List[str]:
        """
        Generate multiple samples from the proposer model.

        Args:
            task: The task description
            num_samples: Number of samples to generate

        Returns:
            List of generated samples
        """
        logger.info(f"Generating {num_samples} samples for task")
        samples = []

        try:
            for i in range(num_samples):
                logger.debug(f"Generating sample {i+1}/{num_samples}")
                sample = self.proposer.run(task)
                samples.append(sample)
                self.metrics["total_samples_generated"] += 1

            logger.success(
                f"Successfully generated {len(samples)} samples"
            )
            return samples

        except Exception as e:
            logger.error(f"Error during sample generation: {str(e)}")
            raise

    def _format_aggregation_prompt(
        self,
        task: str,
        samples: List[str],
        best_so_far: Optional[str] = None,
    ) -> str:
        """
        Format the aggregation prompt with sliding window.

        Args:
            task: Original task
            samples: List of samples to aggregate
            best_so_far: Previously synthesized best output

        Returns:
            Formatted aggregation prompt
        """
        prompt = f"Original Task:\n{task}\n\n"

        if best_so_far:
            prompt += f"Current Best Response (synthesized from previous iterations):\n{best_so_far}\n\n"

        prompt += "Candidate Responses to Synthesize:\n"
        for i, sample in enumerate(samples, 1):
            prompt += f"\n[Response {i}]:\n{sample}\n"

        prompt += (
            "\nProvide a comprehensive synthesis that combines the strengths of "
            "all responses while maintaining coherence and quality."
        )

        return prompt

    @retry_with_instance_config
    def _aggregate_window(
        self,
        task: str,
        window_samples: List[str],
        best_so_far: Optional[str] = None,
    ) -> str:
        """
        Aggregate a window of samples.

        Args:
            task: Original task
            window_samples: Samples in current window
            best_so_far: Best aggregation so far

        Returns:
            Synthesized output
        """
        logger.debug(
            f"Aggregating window of {len(window_samples)} samples"
        )

        try:
            prompt = self._format_aggregation_prompt(
                task,
                window_samples,
                best_so_far,
            )

            aggregated = self.aggregator.run(prompt)
            self.metrics["total_aggregations"] += 1

            logger.debug("Window aggregation complete")
            return aggregated

        except Exception as e:
            logger.error(f"Error during window aggregation: {str(e)}")
            raise

    @retry_with_instance_config
    def run(
        self,
        task: str,
    ) -> Dict[str, Any]:
        """
        Execute Self-MoA-Seq on the given task.

        This method implements the sequential aggregation algorithm:
        1. Generate num_samples outputs from the proposer model
        2. Use a sliding window to aggregate the outputs in chunks
        3. Progressively synthesize, biasing the aggregator toward the best output so far
        4. Return the final synthesized output

        Args:
            task: The task to process

        Returns:
            Dictionary containing:
                - final_output: The best synthesized response
                - all_samples: List of generated samples
                - aggregation_steps: Number of aggregation iterations
                - metrics: Performance metrics
        """
        logger.info(
            f"Starting Self-MoA-Seq run with {self.num_samples} samples"
        )
        start_time = datetime.now()

        try:
            # Validate input
            if not task or not isinstance(task, str):
                raise ValueError("task must be a non-empty string")

            # Phase 1: Generate samples
            logger.info("Phase 1: Generating initial samples")
            samples = self._generate_samples(task, self.num_samples)

            # Phase 2: Sequential aggregation with sliding window
            logger.info("Phase 2: Sequential aggregation")
            best_output = samples[0]
            aggregation_step = 0

            # Process samples in windows
            remaining_samples = samples[1:]

            while remaining_samples:
                aggregation_step += 1
                logger.info(
                    f"Aggregation iteration {aggregation_step}, "
                    f"remaining samples: {len(remaining_samples)}"
                )

                # Create window: reserved slots + new samples
                window_size = min(
                    self.window_size - self.reserved_slots,
                    len(remaining_samples),
                )
                current_window = remaining_samples[:window_size]
                remaining_samples = remaining_samples[window_size:]
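                # Note: the min() above shrinks the window on the final
                # pass, when fewer than window_size - reserved_slots new
                # samples remain.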

                # Aggregate with bias toward best output
                window_with_best = [best_output] + current_window
                best_output = self._aggregate_window(
                    task,
                    window_with_best,
                    best_output,
                )

                if aggregation_step >= self.max_iterations:
                    logger.warning(
                        f"Reached max aggregation iterations ({self.max_iterations})"
                    )
                    break

            # Calculate metrics
            elapsed = (datetime.now() - start_time).total_seconds()
            self.metrics["execution_time_seconds"] = elapsed

            result = {
                "final_output": best_output,
                "all_samples": samples,
                "aggregation_steps": aggregation_step,
                "metrics": self.metrics.copy(),
                "task": task,
                "timestamp": datetime.now().isoformat(),
            }

            logger.success(
                f"Self-MoA-Seq completed in {elapsed:.2f}s "
                f"with {aggregation_step} aggregation iterations"
            )

            if self.verbose:
                self._log_summary(result)

            return result

        except Exception as e:
            logger.error(f"Fatal error in Self-MoA-Seq.run: {str(e)}")
            raise

    def _log_summary(self, result: Dict[str, Any]) -> None:
        """Log execution summary."""
        logger.info("=" * 60)
        logger.info("Self-MoA-Seq Execution Summary")
        logger.info("=" * 60)
        logger.info(
            f"Total samples generated: {self.metrics['total_samples_generated']}"
        )
        logger.info(
            f"Aggregation iterations: {result['aggregation_steps']}"
        )
        logger.info(
            f"Execution time: {self.metrics['execution_time_seconds']:.2f}s"
        )
        logger.info(
            f"Final output length: {len(result['final_output'])} chars"
        )
        logger.info("=" * 60)

    def get_metrics(self) -> Dict[str, Any]:
        """Get current performance metrics."""
        return self.metrics.copy()
@@ -0,0 +1,709 @@
from unittest.mock import Mock, patch

import pytest

from swarms.structs.self_moa_seq import SelfMoASeq

# ============================================================================
# Fixtures
# ============================================================================


@pytest.fixture
def basic_seq():
    """Create a basic SelfMoASeq instance for testing."""
    return SelfMoASeq(
        num_samples=3,
        window_size=4,
        reserved_slots=2,
        max_iterations=5,
        verbose=False,
        enable_logging=False,
    )


@pytest.fixture
def custom_retry_seq():
    """Create a SelfMoASeq instance with custom retry parameters."""
    return SelfMoASeq(
        num_samples=2,
        max_retries=5,
        retry_delay=0.5,
        retry_backoff_multiplier=1.5,
        retry_max_delay=10.0,
        verbose=False,
        enable_logging=False,
    )


@pytest.fixture
def mock_agents():
    """Create mock agents for testing."""
    proposer = Mock()
    aggregator = Mock()
    return proposer, aggregator


# ============================================================================
# Initialization and Parameter Validation Tests
# ============================================================================


def test_default_initialization():
    """Test that SelfMoASeq initializes with default parameters."""
    seq = SelfMoASeq()

    assert seq.model_name == "gpt-4o-mini"
    assert seq.temperature == 0.7
    assert seq.window_size == 6
    assert seq.reserved_slots == 3
    assert seq.max_iterations == 10
    assert seq.max_tokens == 2000
    assert seq.num_samples == 30
    assert seq.enable_logging is True
    assert seq.log_level == "INFO"
    assert seq.verbose is True
    assert seq.max_retries == 3
    assert seq.retry_delay == 1.0
    assert seq.retry_backoff_multiplier == 2.0
    assert seq.retry_max_delay == 60.0


def test_custom_initialization():
    """Test initialization with custom parameters."""
    seq = SelfMoASeq(
        model_name="custom-model",
        temperature=0.5,
        window_size=8,
        reserved_slots=2,
        max_iterations=15,
        max_tokens=3000,
        num_samples=20,
        enable_logging=False,
        log_level="DEBUG",
        verbose=False,
        proposer_model_name="proposer-model",
        aggregator_model_name="aggregator-model",
        max_retries=5,
        retry_delay=2.0,
        retry_backoff_multiplier=3.0,
        retry_max_delay=120.0,
    )

    assert seq.model_name == "custom-model"
    assert seq.temperature == 0.5
    assert seq.window_size == 8
    assert seq.reserved_slots == 2
    assert seq.max_iterations == 15
    assert seq.max_tokens == 3000
    assert seq.num_samples == 20
    assert seq.enable_logging is False
    assert seq.log_level == "DEBUG"
    assert seq.verbose is False
    assert seq.max_retries == 5
    assert seq.retry_delay == 2.0
    assert seq.retry_backoff_multiplier == 3.0
    assert seq.retry_max_delay == 120.0


def test_window_size_validation():
    """Test window_size parameter validation."""
    # Valid window_size
    seq = SelfMoASeq(window_size=2)
    assert seq.window_size == 2

    # Invalid window_size
    with pytest.raises(
        ValueError, match="window_size must be at least 2"
    ):
        SelfMoASeq(window_size=1)


def test_reserved_slots_validation():
    """Test reserved_slots parameter validation."""
    # Valid reserved_slots
    seq = SelfMoASeq(window_size=6, reserved_slots=3)
    assert seq.reserved_slots == 3

    # Invalid reserved_slots (>= window_size)
    with pytest.raises(
        ValueError,
        match="reserved_slots must be less than window_size",
    ):
        SelfMoASeq(window_size=6, reserved_slots=6)


def test_temperature_validation():
    """Test temperature parameter validation."""
    # Valid temperature
    seq = SelfMoASeq(temperature=1.5)
    assert seq.temperature == 1.5

    # Invalid temperature (too high)
    with pytest.raises(
        ValueError, match="temperature must be between 0 and 2"
    ):
        SelfMoASeq(temperature=2.5)

    # Invalid temperature (negative)
    with pytest.raises(
        ValueError, match="temperature must be between 0 and 2"
    ):
        SelfMoASeq(temperature=-0.1)


def test_max_iterations_validation():
    """Test max_iterations parameter validation."""
    # Valid max_iterations
    seq = SelfMoASeq(max_iterations=5)
    assert seq.max_iterations == 5

    # Invalid max_iterations
    with pytest.raises(
        ValueError, match="max_iterations must be at least 1"
    ):
        SelfMoASeq(max_iterations=0)


def test_num_samples_validation():
    """Test num_samples parameter validation."""
    # Valid num_samples
    seq = SelfMoASeq(num_samples=5)
    assert seq.num_samples == 5

    # Invalid num_samples
    with pytest.raises(
        ValueError, match="num_samples must be at least 2"
    ):
        SelfMoASeq(num_samples=1)


def test_retry_parameters_validation():
    """Test retry parameters validation."""
    # Valid retry parameters
    seq = SelfMoASeq(
        max_retries=5,
        retry_delay=2.0,
        retry_backoff_multiplier=1.5,
        retry_max_delay=10.0,
    )
    assert seq.max_retries == 5
    assert seq.retry_delay == 2.0
    assert seq.retry_backoff_multiplier == 1.5
    assert seq.retry_max_delay == 10.0

    # Invalid max_retries
    with pytest.raises(
        ValueError, match="max_retries must be non-negative"
    ):
        SelfMoASeq(max_retries=-1)

    # Invalid retry_delay
    with pytest.raises(
        ValueError, match="retry_delay must be non-negative"
    ):
        SelfMoASeq(retry_delay=-1.0)

    # Invalid retry_backoff_multiplier
    with pytest.raises(
        ValueError, match="retry_backoff_multiplier must be >= 1"
    ):
        SelfMoASeq(retry_backoff_multiplier=0.5)

    # Invalid retry_max_delay
    with pytest.raises(
        ValueError, match="retry_max_delay must be >= retry_delay"
    ):
        SelfMoASeq(retry_delay=10.0, retry_max_delay=5.0)


# ============================================================================
# Retry Functionality Tests
# ============================================================================


def test_retry_wrapped_methods_callable(basic_seq):
    """Test that the retry-wrapped methods are present and callable."""
    # SelfMoASeq wires retry through the retry_with_instance_config
    # decorator; there is no separate retry_decorator attribute on the
    # instance, so we check the decorated methods directly.
    assert callable(basic_seq._generate_samples)
    assert callable(basic_seq._aggregate_window)
    assert callable(basic_seq.run)


def test_retry_wrapper_preserves_metadata(basic_seq):
    """Test that functools.wraps preserves the wrapped method names."""
    assert basic_seq._generate_samples.__name__ == "_generate_samples"
    assert basic_seq.run.__name__ == "run"


def test_retry_configuration_inheritance(custom_retry_seq):
    """Test that retry configuration is properly inherited."""
    assert custom_retry_seq.max_retries == 5
    assert custom_retry_seq.retry_delay == 0.5
    assert custom_retry_seq.retry_backoff_multiplier == 1.5
    assert custom_retry_seq.retry_max_delay == 10.0


def test_retry_functionality_with_mock(custom_retry_seq, mock_agents):
    """Test retry functionality with mocked agents."""
    proposer, aggregator = mock_agents

    # Configure mock to fail first time, succeed second time
    proposer.run.side_effect = [
        Exception("Simulated failure"),
        "Sample 1",
        "Sample 2",
    ]

    aggregator.run.side_effect = [
        Exception("Simulated aggregation failure"),
        "Aggregated result",
    ]

    # Patch the agents
    with patch.object(
        custom_retry_seq, "proposer", proposer
    ), patch.object(custom_retry_seq, "aggregator", aggregator):

        # This should succeed after retries
        result = custom_retry_seq.run("Test task")

        assert result is not None
        assert "final_output" in result
        assert "all_samples" in result
        assert "metrics" in result


# ============================================================================
# Core Methods Tests
# ============================================================================


def test_generate_samples_success(basic_seq, mock_agents):
    """Test successful sample generation."""
    proposer, _ = mock_agents
    proposer.run.return_value = "Generated sample"

    with patch.object(basic_seq, "proposer", proposer):
        samples = basic_seq._generate_samples("Test task", 3)

        assert len(samples) == 3
        assert all(sample == "Generated sample" for sample in samples)
        assert basic_seq.metrics["total_samples_generated"] == 3


def test_generate_samples_with_retry(basic_seq, mock_agents):
    """Test sample generation with retry on failure."""
    proposer, _ = mock_agents
    proposer.run.side_effect = [
        Exception("First failure"),
        Exception("Second failure"),
        "Sample 1",
        "Sample 2",
        "Sample 3",
    ]

    with patch.object(basic_seq, "proposer", proposer):
        samples = basic_seq._generate_samples("Test task", 3)

        assert len(samples) == 3
        assert basic_seq.metrics["total_samples_generated"] == 3


def test_format_aggregation_prompt(basic_seq):
    """Test aggregation prompt formatting."""
    task = "Test task"
    samples = ["Sample 1", "Sample 2", "Sample 3"]
    best_so_far = "Best response"

    prompt = basic_seq._format_aggregation_prompt(
        task, samples, best_so_far
    )

    assert "Original Task:" in prompt
    assert task in prompt
    assert "Current Best Response" in prompt
    assert best_so_far in prompt
    assert "Candidate Responses to Synthesize" in prompt
    assert "Sample 1" in prompt
    assert "Sample 2" in prompt
    assert "Sample 3" in prompt


def test_format_aggregation_prompt_no_best(basic_seq):
    """Test aggregation prompt formatting without best_so_far."""
    task = "Test task"
    samples = ["Sample 1", "Sample 2"]

    prompt = basic_seq._format_aggregation_prompt(task, samples)

    assert "Original Task:" in prompt
    assert task in prompt
    assert "Current Best Response" not in prompt
    assert "Candidate Responses to Synthesize" in prompt
    assert "Sample 1" in prompt
    assert "Sample 2" in prompt


def test_aggregate_window_success(basic_seq, mock_agents):
    """Test successful window aggregation."""
    _, aggregator = mock_agents
    aggregator.run.return_value = "Aggregated result"

    with patch.object(basic_seq, "aggregator", aggregator):
        result = basic_seq._aggregate_window(
            "Test task", ["Sample 1", "Sample 2"], "Best so far"
        )

        assert result == "Aggregated result"
        assert basic_seq.metrics["total_aggregations"] == 1


def test_aggregate_window_with_retry(basic_seq, mock_agents):
    """Test window aggregation with retry on failure."""
    _, aggregator = mock_agents
    aggregator.run.side_effect = [
        Exception("First failure"),
        Exception("Second failure"),
        "Aggregated result",
    ]

    with patch.object(basic_seq, "aggregator", aggregator):
        result = basic_seq._aggregate_window(
            "Test task", ["Sample 1", "Sample 2"]
        )

        assert result == "Aggregated result"
        assert basic_seq.metrics["total_aggregations"] == 1


def test_run_method_success(basic_seq, mock_agents):
    """Test successful run method execution."""
    proposer, aggregator = mock_agents

    # Configure mocks
    proposer.run.return_value = "Generated sample"
    aggregator.run.return_value = "Aggregated result"

    with patch.object(basic_seq, "proposer", proposer), patch.object(
        basic_seq, "aggregator", aggregator
    ):

        result = basic_seq.run("Test task")

        assert isinstance(result, dict)
        assert "final_output" in result
        assert "all_samples" in result
        assert "aggregation_steps" in result
        assert "metrics" in result
        assert "task" in result
        assert "timestamp" in result

        assert result["task"] == "Test task"
        assert len(result["all_samples"]) == 3
        assert result["final_output"] == "Aggregated result"


def test_run_method_with_retry(basic_seq, mock_agents):
    """Test run method with retry on failure."""
    proposer, aggregator = mock_agents

    # Configure mocks to fail first time, succeed second time
    proposer.run.side_effect = [
        Exception("First failure"),
        "Sample 1",
        "Sample 2",
        "Sample 3",
    ]

    aggregator.run.side_effect = [
        Exception("Aggregation failure"),
        "Final result",
    ]

    with patch.object(basic_seq, "proposer", proposer), patch.object(
        basic_seq, "aggregator", aggregator
    ):

        result = basic_seq.run("Test task")

        assert result is not None
        assert result["final_output"] == "Final result"


# ============================================================================
# Error Handling and Edge Cases Tests
# ============================================================================


def test_run_invalid_task(basic_seq):
    """Test run method with invalid task input."""
    with pytest.raises(
        ValueError, match="task must be a non-empty string"
    ):
        basic_seq.run("")

    with pytest.raises(
        ValueError, match="task must be a non-empty string"
    ):
        basic_seq.run(None)


def test_run_max_iterations_reached(basic_seq, mock_agents):
    """Test run method when max iterations are reached."""
    proposer, aggregator = mock_agents

    # Configure mocks
    proposer.run.return_value = "Generated sample"
    aggregator.run.return_value = "Aggregated result"

    # Set max_iterations to 1 to trigger the warning
    basic_seq.max_iterations = 1

    with patch.object(basic_seq, "proposer", proposer), patch.object(
        basic_seq, "aggregator", aggregator
    ):

        result = basic_seq.run("Test task")

        assert result is not None
        assert result["aggregation_steps"] <= 1


def test_generate_samples_exception_propagation(
    basic_seq, mock_agents
):
    """Test that exceptions in sample generation are properly propagated."""
    proposer, _ = mock_agents
    proposer.run.side_effect = Exception("Persistent failure")

    with patch.object(basic_seq, "proposer", proposer):
        with pytest.raises(Exception, match="Persistent failure"):
            basic_seq._generate_samples("Test task", 3)


def test_aggregate_window_exception_propagation(
    basic_seq, mock_agents
):
    """Test that exceptions in window aggregation are properly propagated."""
    _, aggregator = mock_agents
    aggregator.run.side_effect = Exception(
        "Persistent aggregation failure"
    )

    with patch.object(basic_seq, "aggregator", aggregator):
        with pytest.raises(
            Exception, match="Persistent aggregation failure"
        ):
            basic_seq._aggregate_window(
                "Test task", ["Sample 1", "Sample 2"]
            )


def test_run_exception_propagation(basic_seq, mock_agents):
    """Test that exceptions in run method are properly propagated."""
    proposer, _ = mock_agents
    proposer.run.side_effect = Exception("Persistent run failure")

    with patch.object(basic_seq, "proposer", proposer):
        with pytest.raises(Exception, match="Persistent run failure"):
            basic_seq.run("Test task")


# ============================================================================
# Metrics and Logging Tests
# ============================================================================


def test_metrics_initialization(basic_seq):
    """Test that metrics are properly initialized."""
    metrics = basic_seq.get_metrics()

    assert isinstance(metrics, dict)
    assert "total_samples_generated" in metrics
    assert "total_aggregations" in metrics
    assert "total_tokens_used" in metrics
    assert "execution_time_seconds" in metrics

    assert metrics["total_samples_generated"] == 0
    assert metrics["total_aggregations"] == 0
    assert metrics["total_tokens_used"] == 0
    assert metrics["execution_time_seconds"] == 0


def test_metrics_tracking(basic_seq, mock_agents):
    """Test that metrics are properly tracked during execution."""
    proposer, aggregator = mock_agents

    proposer.run.return_value = "Generated sample"
    aggregator.run.return_value = "Aggregated result"

    with patch.object(basic_seq, "proposer", proposer), patch.object(
        basic_seq, "aggregator", aggregator
    ):

        result = basic_seq.run("Test task")

        metrics = result["metrics"]
        assert metrics["total_samples_generated"] == 3
        assert metrics["total_aggregations"] >= 1
        assert metrics["execution_time_seconds"] > 0


def test_log_summary(basic_seq):
    """Test _log_summary method."""
    result = {
        "final_output": "Test output",
        "aggregation_steps": 2,
        "metrics": {
            "total_samples_generated": 3,
            "execution_time_seconds": 1.5,
        },
    }

    # This should not raise an exception
    basic_seq._log_summary(result)


def test_get_metrics_returns_copy(basic_seq):
    """Test that get_metrics returns a copy of metrics."""
    metrics1 = basic_seq.get_metrics()
    metrics2 = basic_seq.get_metrics()

    # Should be different objects
    assert metrics1 is not metrics2

    # But should have same content
    assert metrics1 == metrics2


# ============================================================================
# Integration Tests
# ============================================================================


def test_full_integration_small_samples():
    """Test full integration with small number of samples."""
    seq = SelfMoASeq(
        num_samples=2,
        window_size=3,
        reserved_slots=1,
        max_iterations=2,
        verbose=False,
        enable_logging=False,
    )

    proposer, aggregator = Mock(), Mock()
    proposer.run.return_value = "Generated sample"
    aggregator.run.return_value = "Aggregated result"

    with patch.object(seq, "proposer", proposer), patch.object(
        seq, "aggregator", aggregator
    ):

        result = seq.run("Integration test task")

        assert result is not None
        assert result["task"] == "Integration test task"
        assert len(result["all_samples"]) == 2
        assert result["final_output"] == "Aggregated result"
        assert result["aggregation_steps"] >= 0


def test_model_name_overrides():
    """Test that model name overrides work correctly."""
    seq = SelfMoASeq(
        model_name="base-model",
        proposer_model_name="proposer-model",
        aggregator_model_name="aggregator-model",
        verbose=False,
        enable_logging=False,
    )

    # The agents should be initialized with the override names
    assert seq.proposer.model_name == "proposer-model"
    assert seq.aggregator.model_name == "aggregator-model"


def test_temperature_settings():
    """Test that temperature settings are applied correctly."""
    seq = SelfMoASeq(
        temperature=0.5, verbose=False, enable_logging=False
    )

    assert seq.proposer.temperature == 0.5
    assert (
        seq.aggregator.temperature == 0.0
    )  # Deterministic aggregation


# ============================================================================
# Performance and Edge Case Tests
# ============================================================================


def test_minimum_valid_configuration():
    """Test with minimum valid configuration."""
    seq = SelfMoASeq(
        window_size=2,
        reserved_slots=1,
        max_iterations=1,
        num_samples=2,
        verbose=False,
        enable_logging=False,
    )

    assert seq.window_size == 2
    assert seq.reserved_slots == 1
    assert seq.max_iterations == 1
    assert seq.num_samples == 2


def test_zero_retries():
    """Test with zero retries (should still work but not retry)."""
    seq = SelfMoASeq(
        max_retries=0,
        num_samples=2,
        verbose=False,
        enable_logging=False,
    )

    assert seq.max_retries == 0

    proposer, aggregator = Mock(), Mock()
    proposer.run.return_value = "Generated sample"
    aggregator.run.return_value = "Aggregated result"

    with patch.object(seq, "proposer", proposer), patch.object(
        seq, "aggregator", aggregator
    ):

        result = seq.run("Test task")
        assert result is not None


def test_large_configuration():
    """Test with large configuration values."""
    seq = SelfMoASeq(
        window_size=20,
        reserved_slots=5,
        max_iterations=50,
        num_samples=100,
        max_retries=10,
        retry_delay=5.0,
        retry_backoff_multiplier=3.0,
        retry_max_delay=300.0,
        verbose=False,
        enable_logging=False,
    )

    assert seq.window_size == 20
    assert seq.reserved_slots == 5
    assert seq.max_iterations == 50
    assert seq.num_samples == 100
    assert seq.max_retries == 10
    assert seq.retry_delay == 5.0
    assert seq.retry_backoff_multiplier == 3.0
    assert seq.retry_max_delay == 300.0


if __name__ == "__main__":
    pytest.main([__file__, "-v"])