[FIX][Conversation] [FIX][Self Consistency, make it less decoupled from agent, now it doesn't inherit from the agent class, removed code and improved self consistency docs] [fixed reasoning agent router, improved the docstrings, and docs for the self consistency
master
Kye Gomez 15 hours ago
parent 440173817d
commit b4410ce9a9

@ -0,0 +1,22 @@
from swarms import SelfConsistencyAgent
# Initialize the SelfConsistencyAgent directly, without the ReasoningAgentRouter
reasoning_agent_router = SelfConsistencyAgent(
name="reasoning-agent",
description="A reasoning agent that can answer questions and help with tasks.",
model_name="gpt-4o-mini",
system_prompt="You are a helpful assistant that can answer questions and help with tasks.",
max_loops=1,
num_samples=3, # Generate 3 independent responses
eval=False, # Disable evaluation mode
random_models_on=False, # Disable random model selection
majority_voting_prompt=None, # Use default majority voting prompt
)
# Run the agent on a financial analysis task
result = reasoning_agent_router.run(
"What is the best possible financial strategy to maximize returns but minimize risk? Give a list of etfs to invest in and the percentage of the portfolio to allocate to each etf."
)
print("Financial Strategy Result:")
print(result)

@ -1,6 +1,5 @@
# Consistency Agent Documentation # Consistency Agent Documentation
The `SelfConsistencyAgent` is a specialized agent designed for generating multiple independent responses to a given task and aggregating them into a single, consistent final answer. It leverages concurrent processing to enhance efficiency and employs a majority voting mechanism to ensure the reliability of the aggregated response. The `SelfConsistencyAgent` is a specialized agent designed for generating multiple independent responses to a given task and aggregating them into a single, consistent final answer. It leverages concurrent processing to enhance efficiency and employs a majority voting mechanism to ensure the reliability of the aggregated response.
## Purpose ## Purpose
@ -17,24 +16,31 @@ The primary objective of the `SelfConsistencyAgent` is to provide a robust mecha
| Argument | Type | Default | Description | | Argument | Type | Default | Description |
|------------------------|---------|---------|-----------------------------------------------------------------------------| |------------------------|---------|---------|-----------------------------------------------------------------------------|
| `num_samples` | `int` | `5` | Number of independent responses to sample. | | `name` | `str` | `"Self-Consistency-Agent"` | Name of the agent. |
| `return_list` | `bool` | `False` | Whether to return the conversation as a list. | | `description` | `str` | `"An agent that uses self consistency to generate a final answer."` | Description of the agent's purpose. |
| `max_loops` | `int` | `1` | Maximum number of loops for the agent to run. | | `system_prompt` | `str` | `CONSISTENCY_SYSTEM_PROMPT` | System prompt for the reasoning agent. |
| `return_dict` | `bool` | `False` | Whether to return the conversation as a dictionary. | | `model_name` | `str` | `"gpt-4o-mini"` | The underlying language model to use. |
| `return_json` | `bool` | `False` | Whether to return the conversation as JSON. | | `num_samples` | `int` | `5` | Number of independent responses to generate. |
| `majority_voting_prompt` | `str` | `None` | Custom prompt for majority voting. | | `max_loops` | `int` | `1` | Maximum number of reasoning loops per sample. |
| `majority_voting_prompt` | `Optional[str]` | `majority_voting_prompt` | Custom prompt for majority voting aggregation. |
| `eval` | `bool` | `False` | Enable evaluation mode for answer validation. |
| `output_type` | `OutputType` | `"dict"` | Format of the output. |
| `random_models_on` | `bool` | `False` | Enable random model selection for diversity. |
### Methods ### Methods
- **`run`**: Generates multiple responses for the given task and aggregates them. - **`run`**: Generates multiple responses for the given task and aggregates them.
- **Arguments**: - **Arguments**:
- `task` (`str`): The input prompt. - `task` (`str`): The input prompt.
- `answer` (`str`, optional): The expected answer to validate responses against. - `img` (`Optional[str]`, optional): Image input for vision tasks.
- **Returns**: `str` - The aggregated final answer. - `answer` (`Optional[str]`, optional): Expected answer for validation (if eval=True).
- **Returns**: `Union[str, Dict[str, Any]]` - The aggregated final answer.
- **`aggregate`**: Aggregates a list of responses into a single final answer using majority voting. - **`aggregation_agent`**: Aggregates a list of responses into a single final answer using majority voting.
- **Arguments**: - **Arguments**:
- `responses` (`List[str]`): The list of responses. - `responses` (`List[str]`): The list of responses.
- `prompt` (`str`, optional): Custom prompt for the aggregation agent.
- `model_name` (`str`, optional): Model to use for aggregation.
- **Returns**: `str` - The aggregated answer. - **Returns**: `str` - The aggregated answer.
- **`check_responses_for_answer`**: Checks if a specified answer is present in any of the provided responses. - **`check_responses_for_answer`**: Checks if a specified answer is present in any of the provided responses.
@ -43,6 +49,11 @@ The primary objective of the `SelfConsistencyAgent` is to provide a robust mecha
- `answer` (`str`): The answer to look for in the responses. - `answer` (`str`): The answer to look for in the responses.
- **Returns**: `bool` - `True` if the answer is found, `False` otherwise. - **Returns**: `bool` - `True` if the answer is found, `False` otherwise.
- **`batched_run`**: Run the agent on multiple tasks in batch.
- **Arguments**:
- `tasks` (`List[str]`): List of tasks to be processed.
- **Returns**: `List[Union[str, Dict[str, Any]]]` - List of results for each task.
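The `check_responses_for_answer` method above is described only by its contract. A minimal sketch of such a check, assuming a case-insensitive substring match (the helper name `contains_answer` is hypothetical and is not part of the library), might look like this:

```python
from typing import List


def contains_answer(responses: List[str], answer: str) -> bool:
    """Return True if `answer` occurs as a substring of any response.

    Assumption: matching is a plain case-insensitive substring test.
    The library's own check may differ.
    """
    needle = answer.strip().lower()
    return any(needle in response.lower() for response in responses)


# Hypothetical sampled responses
responses = ["The sum is 4.", "2 + 2 equals four."]
print(contains_answer(responses, "4"))  # True: the first response contains "4"
print(contains_answer(responses, "5"))  # False: no response contains "5"
```

Substring matching is cheap but permissive: an expected answer of `"4"` also matches a response containing `"14"`, so stricter comparison may be preferable for numeric tasks.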
### Examples ### Examples
#### Example 1: Basic Usage #### Example 1: Basic Usage
@ -52,7 +63,7 @@ from swarms.agents.consistency_agent import SelfConsistencyAgent
# Initialize the agent # Initialize the agent
agent = SelfConsistencyAgent( agent = SelfConsistencyAgent(
agent_name="Reasoning-Agent", name="Math-Reasoning-Agent",
model_name="gpt-4o-mini", model_name="gpt-4o-mini",
max_loops=1, max_loops=1,
num_samples=5 num_samples=5
@ -75,7 +86,7 @@ from swarms.agents.consistency_agent import SelfConsistencyAgent
# Initialize the agent with a custom majority voting prompt # Initialize the agent with a custom majority voting prompt
agent = SelfConsistencyAgent( agent = SelfConsistencyAgent(
agent_name="Reasoning-Agent", name="Reasoning-Agent",
model_name="gpt-4o-mini", model_name="gpt-4o-mini",
max_loops=1, max_loops=1,
num_samples=5, num_samples=5,
@ -92,4 +103,128 @@ final_answer = agent.run(task)
print("Final aggregated answer:", final_answer) print("Final aggregated answer:", final_answer)
``` ```
#### Example 3: Evaluation Mode
```python
from swarms.agents.consistency_agent import SelfConsistencyAgent
# Initialize the agent with evaluation mode
agent = SelfConsistencyAgent(
name="Validation-Agent",
model_name="gpt-4o-mini",
num_samples=3,
eval=True
)
# Run with expected answer for validation
result = agent.run("What is 2 + 2?", answer="4")
if result is not None:
print("Validation passed:", result)
else:
print("Validation failed - expected answer not found")
```
#### Example 4: Random Models for Diversity
```python
from swarms.agents.consistency_agent import SelfConsistencyAgent
# Initialize the agent with random model selection
agent = SelfConsistencyAgent(
name="Diverse-Reasoning-Agent",
model_name="gpt-4o-mini",
num_samples=5,
random_models_on=True
)
# Run the agent
result = agent.run("What are the benefits of renewable energy?")
print("Diverse reasoning result:", result)
```
#### Example 5: Batch Processing
```python
from swarms.agents.consistency_agent import SelfConsistencyAgent
# Initialize the agent
agent = SelfConsistencyAgent(
name="Batch-Processing-Agent",
model_name="gpt-4o-mini",
num_samples=3
)
# Define multiple tasks
tasks = [
"What is the capital of France?",
"What is 15 * 23?",
"Explain photosynthesis in simple terms."
]
# Process all tasks
results = agent.batched_run(tasks)
# Print results
for i, result in enumerate(results):
print(f"Task {i+1} result: {result}")
```
## Key Features
### Self-Consistency Technique
The agent implements the self-consistency approach based on the research paper "Self-Consistency Improves Chain of Thought Reasoning in Language Models" by Wang et al. (2022). This technique:
1. **Generates Multiple Independent Responses**: Creates several reasoning paths for the same problem
2. **Analyzes Consistency**: Examines agreement among different reasoning approaches
3. **Aggregates Results**: Uses majority voting or consensus building
4. **Produces Reliable Output**: Delivers a final answer reflecting the most reliable consensus
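The four steps above can be illustrated with a plain majority vote over already-extracted final answers. This is only a sketch of the classic voting step, not the agent's own aggregation (which, per this documentation, uses an AI-powered aggregation agent); the function name `majority_vote` and the sample values are assumptions made for illustration.

```python
from collections import Counter
from typing import List


def majority_vote(answers: List[str]) -> str:
    """Return the most frequent answer after trimming whitespace and lowercasing."""
    if not answers:
        raise ValueError("at least one answer is required")
    normalized = [answer.strip().lower() for answer in answers]
    # most_common(1) yields a single (answer, count) pair for the top answer
    winner, _count = Counter(normalized).most_common(1)[0]
    return winner


# Hypothetical final answers taken from five independent reasoning paths
samples = ["173", "173", "179", " 173 ", "167"]
print(majority_vote(samples))  # 173
```

Exact-match voting only works when answers can be normalized to a comparable form, which is one reason free-text tasks tend to need a model-based aggregator instead.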
### Benefits
- **Mitigates Random Errors**: Multiple reasoning paths reduce individual path errors
- **Reduces Bias**: Diverse approaches minimize single-method biases
- **Improves Reliability**: Consensus-based results are more trustworthy
- **Handles Complexity**: Better performance on complex problem-solving tasks
### Use Cases
- **Mathematical Problem Solving**: Where accuracy is critical
- **Decision Making**: When reliability is paramount
- **Validation Tasks**: When answers need verification
- **Complex Reasoning**: Multi-step problem solving
- **Research Questions**: Where multiple perspectives are valuable
## Technical Details
### Concurrent Execution
The agent uses `ThreadPoolExecutor` to generate multiple responses concurrently, improving performance while maintaining independence between reasoning paths.
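A minimal sketch of this pattern, using a stand-in generator instead of a real model call (`sample_concurrently` and `fake_generate` are hypothetical names introduced here, not library functions):

```python
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Callable, List


def sample_concurrently(
    generate: Callable[[str], str], task: str, num_samples: int
) -> List[str]:
    """Call `generate(task)` num_samples times in a thread pool and collect results."""
    responses: List[str] = []
    with ThreadPoolExecutor() as executor:
        futures = [executor.submit(generate, task) for _ in range(num_samples)]
        # as_completed yields futures in completion order, not submission order
        for future in as_completed(futures):
            responses.append(future.result())
    return responses


def fake_generate(task: str) -> str:
    # Stand-in for a model call; a real sampler would return varied reasoning paths
    return f"answer to: {task}"


responses = sample_concurrently(fake_generate, "What is 2 + 2?", num_samples=3)
print(len(responses))  # 3
```

Because results arrive in completion order, any downstream logic should not assume that the i-th response corresponds to the i-th submission.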
### Aggregation Process
The aggregation uses an AI-powered agent that:
- Identifies dominant responses
- Analyzes disparities and disagreements
- Evaluates consensus strength
- Synthesizes minority insights
- Provides comprehensive recommendations
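For intuition, the first three bullets can be approximated without a model when answers are short and directly comparable. The sketch below is an assumption-laden stand-in (the name `consensus_report` and its return keys are hypothetical); it tallies answers and reports the dominant one, the share of votes it received, and the dissenting answers.

```python
from collections import Counter
from typing import Any, Dict, List


def consensus_report(answers: List[str]) -> Dict[str, Any]:
    """Summarize agreement among answers: dominant answer, vote share, and minorities."""
    if not answers:
        raise ValueError("at least one answer is required")
    counts = Counter(answer.strip() for answer in answers)
    dominant, votes = counts.most_common(1)[0]
    return {
        "dominant": dominant,
        "consensus_strength": votes / len(answers),  # fraction of samples agreeing
        "minority": sorted(a for a in counts if a != dominant),
    }


report = consensus_report(["A", "B", "A", "A", "C"])
print(report["dominant"])  # A
print(report["minority"])  # ['B', 'C']
```

A low `consensus_strength` is a useful signal that the samples disagree and that the aggregated answer deserves closer review.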
### Output Formats
The agent supports various output types:
- `"dict"`: Dictionary format with conversation history
- `"str"`: Simple string output
- `"list"`: List format
- `"json"`: JSON formatted output
## Limitations
1. **Computational Cost**: Higher `num_samples` increases processing time and cost
2. **Model Dependencies**: Performance depends on the underlying model capabilities
3. **Consensus Challenges**: May struggle with tasks where multiple valid approaches exist
4. **Memory Usage**: Concurrent execution requires more memory resources
## Best Practices
1. **Sample Size**: Use 3-7 samples for most tasks; increase for critical decisions
2. **Model Selection**: Choose models with strong reasoning capabilities
3. **Evaluation Mode**: Enable for tasks with known correct answers
4. **Custom Prompts**: Tailor majority voting prompts for specific domains
5. **Batch Processing**: Use `batched_run` for multiple related tasks
--- ---

@ -38,9 +38,12 @@ graph TD
| `max_loops` | int | 1 | Maximum number of reasoning loops | | `max_loops` | int | 1 | Maximum number of reasoning loops |
| `swarm_type` | agent_types | "reasoning_duo" | Type of reasoning swarm to use | | `swarm_type` | agent_types | "reasoning_duo" | Type of reasoning swarm to use |
| `num_samples` | int | 1 | Number of samples for self-consistency | | `num_samples` | int | 1 | Number of samples for self-consistency |
| `output_type` | OutputType | "dict" | Format of the output | | `output_type` | OutputType | "dict-all-except-first" | Format of the output |
| `num_knowledge_items` | int | 6 | Number of knowledge items for GKP agent | | `num_knowledge_items` | int | 6 | Number of knowledge items for GKP agent |
| `memory_capacity` | int | 6 | Memory capacity for agents that support it | | `memory_capacity` | int | 6 | Memory capacity for agents that support it |
| `eval` | bool | False | Enable evaluation mode for self-consistency |
| `random_models_on` | bool | False | Enable random model selection for diversity |
| `majority_voting_prompt` | Optional[str] | None | Custom prompt for majority voting |
### Available Agent Types ### Available Agent Types
@ -84,12 +87,16 @@ graph TD
- Multiple solution generation - Multiple solution generation
- Consensus building - Consensus building
- Solution verification - Solution verification
- Concurrent execution
- AI-powered aggregation
**Best Use Cases** **Best Use Cases**
- Tasks requiring high reliability - Tasks requiring high reliability
- Problems with multiple approaches - Problems with multiple approaches
- Validation-heavy tasks - Validation-heavy tasks
- Mathematical problem solving
- Decision making scenarios
**Required Parameters** **Required Parameters**
@ -98,9 +105,12 @@ graph TD
**Optional Parameters** **Optional Parameters**
- num_samples - num_samples (default: 5)
- max_loops - max_loops (default: 1)
- output_type - output_type (default: "dict")
- eval (default: False) - Enable answer validation
- random_models_on (default: False) - Enable model diversity
- majority_voting_prompt (default: None) - Custom aggregation prompt
=== "IRE" === "IRE"
**Key Features** **Key Features**
@ -217,14 +227,43 @@ graph TD
system_prompt="You are a helpful assistant that can answer questions and help with tasks.", system_prompt="You are a helpful assistant that can answer questions and help with tasks.",
max_loops=1, max_loops=1,
swarm_type="self-consistency", swarm_type="self-consistency",
num_samples=1, num_samples=3,
output_type="list" eval=False,
random_models_on=False,
majority_voting_prompt=None
) )
# Run a single task # Run a single task
result = router.run("What is the best approach to solve this problem?") result = router.run("What is the best approach to solve this problem?")
``` ```
=== "Self-Consistency Examples"
```python
# Basic self-consistency
router = ReasoningAgentRouter(
swarm_type="self-consistency",
num_samples=3,
model_name="gpt-4o-mini"
)
# Self-consistency with evaluation mode
router = ReasoningAgentRouter(
swarm_type="self-consistency",
num_samples=5,
model_name="gpt-4o-mini",
eval=True,
random_models_on=True
)
# Self-consistency with custom majority voting
router = ReasoningAgentRouter(
swarm_type="self-consistency",
num_samples=3,
model_name="gpt-4o-mini",
majority_voting_prompt="Analyze the responses and provide the most accurate answer."
)
```
=== "ReflexionAgent" === "ReflexionAgent"
```python ```python
router = ReasoningAgentRouter( router = ReasoningAgentRouter(
@ -265,9 +304,13 @@ graph TD
2. **Performance Optimization** 2. **Performance Optimization**
- Adjust max_loops based on task complexity - Adjust max_loops based on task complexity
- Increase num_samples for higher reliability - Increase num_samples for higher reliability (3-7 for most tasks)
- Choose appropriate model_name based on task requirements - Choose appropriate model_name based on task requirements
- Enable random_models_on for diverse reasoning approaches
- Use eval mode for validation tasks with known answers
3. **Output Handling** 3. **Output Handling**
- Use appropriate output_type for your needs - Use appropriate output_type for your needs
@ -275,6 +318,15 @@ graph TD
- Process batched results appropriately - Process batched results appropriately
- Handle errors gracefully - Handle errors gracefully
4. **Self-Consistency Specific**
- Use 3-7 samples for most tasks; increase for critical decisions
- Enable eval mode when you have expected answers for validation
- Customize majority_voting_prompt for domain-specific aggregation
- Consider random_models_on for diverse model perspectives
## Limitations ## Limitations

@ -1,46 +0,0 @@
from swarms.agents.reasoning_agents import ReasoningAgentRouter
reasoning_agent_router = ReasoningAgentRouter(
agent_name="reasoning-agent",
description="A reasoning agent that can answer questions and help with tasks.",
model_name="gpt-4o-mini",
system_prompt="You are a helpful assistant that can answer questions and help with tasks.",
max_loops=1,
swarm_type="self-consistency",
num_samples=1,
output_type="list",
)
reasoning_agent_router.run(
"What is the best possible financial strategy to maximize returns but minimize risk? Give a list of etfs to invest in and the percentage of the portfolio to allocate to each etf."
)
# reasoning_agent_router.batched_run(
# [
# "What is the best possible financial strategy to maximize returns but minimize risk? Give a list of etfs to invest in and the percentage of the portfolio to allocate to each etf.",
# "What is the best possible financial strategy to maximize returns but minimize risk? Give a list of etfs to invest in and the percentage of the portfolio to allocate to each etf.",
# ]
# )
# from swarms import ReasoningAgentRouter
# calculus_router = ReasoningAgentRouter(
# agent_name="calculus-expert",
# description="A calculus problem solving agent",
# model_name="gpt-4o-mini",
# system_prompt="You are a calculus expert. Solve differentiation and integration problems methodically.",
# swarm_type="self-consistency",
# num_samples=3, # Generate 3 samples to ensure consistency
# output_type="list",
# )
# # Example calculus problem
# calculus_problem = "Find the derivative of f(x) = x³ln(x) - 5x²"
# # Get the solution
# solution = calculus_router.run(calculus_problem)
# print(solution)

@ -0,0 +1,23 @@
from swarms.agents.reasoning_agents import ReasoningAgentRouter
# Initialize the reasoning agent router with self-consistency
reasoning_agent_router = ReasoningAgentRouter(
agent_name="reasoning-agent",
description="A reasoning agent that can answer questions and help with tasks.",
model_name="gpt-4o-mini",
system_prompt="You are a helpful assistant that can answer questions and help with tasks.",
max_loops=1,
swarm_type="self-consistency",
num_samples=3, # Generate 3 independent responses
eval=False, # Disable evaluation mode
random_models_on=False, # Disable random model selection
majority_voting_prompt=None, # Use default majority voting prompt
)
# Run the agent on a financial analysis task
result = reasoning_agent_router.run(
"What is the best possible financial strategy to maximize returns but minimize risk? Give a list of etfs to invest in and the percentage of the portfolio to allocate to each etf."
)
print("Financial Strategy Result:")
print(result)

@ -1,22 +1,71 @@
from collections import Counter """
Self-Consistency Agent Implementation
This module implements the SelfConsistencyAgent, a specialized agent that leverages the
self-consistency technique to improve reasoning reliability and accuracy. The agent generates
multiple independent responses to a given task and aggregates them into a single, consistent
final answer using majority voting and sophisticated aggregation techniques.
The self-consistency approach is based on the research paper:
"Self-Consistency Improves Chain of Thought Reasoning in Language Models"
by Wang et al. (2022) - https://arxiv.org/abs/2203.11171
Key Features:
- Concurrent generation of multiple independent responses
- Majority voting aggregation with detailed analysis
- Evaluation mode for answer validation
- Configurable output formats
- Thread-safe execution
Author: Swarms Team
License: MIT
"""
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List from typing import List, Optional, Union, Dict, Any
from loguru import logger from loguru import logger
from swarms.structs.agent import Agent from swarms.structs.agent import Agent
from swarms.structs.conversation import Conversation from swarms.structs.conversation import Conversation
from swarms.structs.malt import majority_voting_prompt
from swarms.utils.output_types import OutputType from swarms.utils.output_types import OutputType
from swarms.utils.any_to_str import any_to_str from swarms.utils.any_to_str import any_to_str
from swarms.utils.history_output_formatter import ( from swarms.utils.history_output_formatter import (
history_output_formatter, history_output_formatter,
) )
# System prompt for the reasoning agent that generates individual responses
CONSISTENCY_SYSTEM_PROMPT = """ CONSISTENCY_SYSTEM_PROMPT = """
You are a reasoning agent designed for complex problem-solving and decision-making. Your objective is to provide clear and reliable responses through structured reasoning. Begin by thoroughly understanding the problem, rephrasing it for clarity, and identifying key components. Develop a logical plan that breaks the problem into manageable steps, detailing your approach and any assumptions made. Validate your information with reliable sources and assess the accuracy of your calculations. Explore multiple solutions, weighing their pros and cons, and maintain transparency by documenting your reasoning process, uncertainties, and biases. Summarize your findings in a concise final answer that reflects your thorough analysis, ensuring it is well-organized and accessible. Adapt your reasoning to the context of the problem, integrating new information as needed, and implement error-handling strategies to address any issues that arise. Finally, reflect on your reasoning process to identify areas for improvement and ensure consistency across all reasoning paths. You are a reasoning agent designed for complex problem-solving and decision-making. Your objective is to provide clear and reliable responses through structured reasoning. Begin by thoroughly understanding the problem, rephrasing it for clarity, and identifying key components. Develop a logical plan that breaks the problem into manageable steps, detailing your approach and any assumptions made. Validate your information with reliable sources and assess the accuracy of your calculations. Explore multiple solutions, weighing their pros and cons, and maintain transparency by documenting your reasoning process, uncertainties, and biases. Summarize your findings in a concise final answer that reflects your thorough analysis, ensuring it is well-organized and accessible. 
Adapt your reasoning to the context of the problem, integrating new information as needed, and implement error-handling strategies to address any issues that arise. Finally, reflect on your reasoning process to identify areas for improvement and ensure consistency across all reasoning paths.
""" """
# Detailed prompt for the majority voting aggregation agent
majority_voting_prompt = """
Engage in a comprehensive and exhaustive majority voting analysis of the following conversation, ensuring a deep and thoughtful examination of the responses provided by each agent. This analysis should not only summarize the responses but also critically engage with the content, context, and implications of each agent's input.
Please adhere to the following detailed guidelines:
1. **Identification of Dominant Responses:**
- Identify the most prevalent answer or recommendation across all agents. Provide a thorough rationale for its dominance, including an exploration of the factors that may have contributed to its acceptance among the agents. Discuss the context in which this consensus emerged and any relevant historical or theoretical frameworks that support this conclusion.
2. **Exploration of Disparities:**
- Delve into any significant disparities or contrasting viewpoints between agents. Explore the underlying reasons for these differences, considering aspects such as differing methodologies, assumptions, or interpretations of the task at hand. Analyze how these contrasting perspectives may reflect broader debates within the field and what implications they hold for the overall understanding of the topic.
3. **Consensus and Disagreement Analysis:**
- Highlight key areas of consensus and disagreement among the agents. Discuss the implications of these findings on the overall argument, including how consensus can strengthen certain claims while disagreement may indicate areas of uncertainty or contention. Provide examples from the conversation to illustrate these points and consider how they might influence future discussions or research directions.
4. **Critical Evaluation of Majority Opinion:**
- Critically evaluate the strength of the majority opinion, considering factors such as the reasoning behind it and its mathematical validity if applicable. Assess whether the majority opinion is well-supported by evidence and logical reasoning, and discuss any potential weaknesses or oversights that may undermine its credibility.
5. **Insights from Minority Viewpoints:**
- Note any unique insights from minority viewpoints, assessing their potential contributions to a more nuanced understanding of the topic. Discuss how these minority perspectives can enrich the conversation and provide alternative angles that may have been overlooked by the majority. Consider the value of dissent in academic discourse and how it can lead to more robust conclusions.
6. **Synthesis of Recommendations:**
- Provide a final synthesized recommendation based on the majority consensus, ensuring that it reflects a thorough consideration of all perspectives and is grounded in sound reasoning. This recommendation should not only summarize the majority view but also integrate insights from minority opinions, creating a comprehensive and balanced conclusion that acknowledges the complexity of the discussion.
Throughout your analysis, focus on uncovering clear patterns while being attentive to the subtleties and complexities inherent in the responses. Pay particular attention to the nuances of mathematical contexts where algorithmic thinking may be required, ensuring that your examination is both rigorous and accessible to a diverse audience.
"""
def aggregation_agent( def aggregation_agent(
responses: List[str], responses: List[str],
@ -24,7 +73,27 @@ def aggregation_agent(
model_name: str = "gpt-4o-mini", model_name: str = "gpt-4o-mini",
) -> str: ) -> str:
""" """
Aggregates a list of responses into a single final answer. Aggregates a list of responses into a single final answer using an AI-powered aggregation agent.
This function creates a specialized agent that analyzes multiple responses and synthesizes
them into a coherent final answer. The aggregation process considers consensus, disagreements,
and minority viewpoints to produce a well-reasoned conclusion.
Args:
responses (List[str]): List of responses to be aggregated
prompt (str, optional): Custom prompt for the aggregation agent.
Defaults to the majority_voting_prompt.
model_name (str, optional): Model to use for aggregation.
Defaults to "gpt-4o-mini".
Returns:
str: The aggregated final answer
Example:
>>> responses = ["Answer A", "Answer B", "Answer A"]
>>> final_answer = aggregation_agent(responses)
>>> print(final_answer)
"Based on the majority consensus..."
""" """
task = any_to_str(responses) task = any_to_str(responses)
@ -41,69 +110,174 @@ def aggregation_agent(
return final_answer return final_answer
class SelfConsistencyAgent(Agent): class SelfConsistencyAgent:
"""
A specialized agent that implements self-consistency for improved reasoning reliability.
The SelfConsistencyAgent generates multiple independent responses to a given task and
aggregates them into a single, consistent final answer. This approach is based on the
research paper "Self-Consistency Improves Chain of Thought Reasoning in Language Models"
by Wang et al. (2022).
Key Features:
- Concurrent generation of multiple independent responses
- Majority voting aggregation with detailed analysis
- Evaluation mode for answer validation
- Configurable output formats
- Thread-safe execution
The self-consistency technique works by:
1. Generating multiple independent reasoning paths for the same problem
2. Analyzing the consistency and agreement among these paths
3. Aggregating the results using majority voting or consensus building
4. Producing a final answer that reflects the most reliable consensus
This approach helps mitigate issues like:
- Random errors in individual reasoning paths
- Biases in single reasoning approaches
- Inconsistencies in complex problem-solving
Reference:
Wang, X., Wei, J., Schuurmans, D., Le, Q., Chi, E., Narang, S., Chowdhery, A., & Zhou, D. (2022).
Self-Consistency Improves Chain of Thought Reasoning in Language Models. arXiv preprint arXiv:2203.11171.
https://arxiv.org/abs/2203.11171
Example:
>>> agent = SelfConsistencyAgent(
... name="Math-Reasoning-Agent",
... model_name="gpt-4o-mini",
... num_samples=5,
... max_loops=1
... )
>>> result = agent.run("What is the 40th prime number?")
>>> print(result)
"""
def __init__( def __init__(
self, self,
name: str = "Self-Consistency-Agent", name: str = "Self-Consistency-Agent",
description: str = "An agent that uses self consistency to generate a final answer.", description: str = "An agent that uses self consistency to generate a final answer.",
model_name: str = "gpt-4o-mini",
system_prompt: str = CONSISTENCY_SYSTEM_PROMPT, system_prompt: str = CONSISTENCY_SYSTEM_PROMPT,
num_samples: int = 5, num_samples: int = 5,
max_loops: int = 1, max_loops: int = 1,
majority_voting_prompt: str = None, majority_voting_prompt: Optional[
str
] = majority_voting_prompt,
eval: bool = False, eval: bool = False,
output_type: OutputType = "dict", output_type: OutputType = "dict",
random_models_on: bool = False,
*args,
**kwargs, **kwargs,
): ):
""" """
Initializes the SelfConsistencyAgent. Initialize the SelfConsistencyAgent.
Args: Args:
num_samples (int): Number of independent responses to sample. name (str, optional): Name of the agent. Defaults to "Self-Consistency-Agent".
**kwargs: Other keyword arguments passed to the base Agent. description (str, optional): Description of the agent's purpose.
Defaults to "An agent that uses self consistency to generate a final answer.".
model_name (str, optional): The underlying language model to use.
Defaults to "gpt-4o-mini".
system_prompt (str, optional): System prompt for the reasoning agent.
Defaults to CONSISTENCY_SYSTEM_PROMPT.
num_samples (int, optional): Number of independent responses to generate.
Defaults to 5.
max_loops (int, optional): Maximum number of reasoning loops per sample.
Defaults to 1.
majority_voting_prompt (Optional[str], optional): Custom prompt for majority voting.
Defaults to the module-level majority_voting_prompt.
eval (bool, optional): Enable evaluation mode for answer validation.
Defaults to False.
output_type (OutputType, optional): Format of the output.
Defaults to "dict".
random_models_on (bool, optional): Enable random model selection for diversity.
Defaults to False.
*args: Additional positional arguments, stored on the instance as self.args.
**kwargs: Additional keyword arguments, stored on the instance as self.kwargs.
Note:
The num_samples parameter determines how many independent reasoning paths
will be generated. Higher values generally lead to more reliable results
but increase computational cost and time.
""" """
super().__init__( self.name = name
name=name, self.description = description
description=description, self.model_name = model_name
**kwargs,
)
self.num_samples = num_samples self.num_samples = num_samples
self.conversation = Conversation()
self.max_loops = max_loops self.max_loops = max_loops
self.majority_voting_prompt = majority_voting_prompt self.majority_voting_prompt = majority_voting_prompt
self.eval = eval self.eval = eval
self.output_type = output_type self.output_type = output_type
self.system_prompt = system_prompt self.system_prompt = system_prompt
self.random_models_on = random_models_on
self.conversation = Conversation()
self.args = args
self.kwargs = kwargs
def run( def run(
self, task: str, answer: str = None, *args, **kwargs self,
) -> str: task: str,
img: Optional[str] = None,
answer: Optional[str] = None,
*args,
**kwargs,
) -> Union[str, Dict[str, Any]]:
""" """
Generates multiple responses for the given prompt and aggregates them concurrently. Generate multiple responses for the given task and aggregate them concurrently.
This method implements the core self-consistency algorithm:
1. Generates multiple independent responses using concurrent execution
2. Optionally validates responses against a known answer (if eval=True)
3. Aggregates responses using an AI-powered aggregation agent
4. Returns the final result in the specified output format
Args: Args:
task (str): The input prompt. task (str): The input prompt or task to be solved
img (Optional[str], optional): Optional image input forwarded to the reasoning agent.
Defaults to None.
answer (Optional[str], optional): Expected answer for validation (if eval=True).
Defaults to None.
*args: Additional positional arguments passed to the reasoning agent's run method
**kwargs: Additional keyword arguments passed to the reasoning agent's run method
Returns: Returns:
str: The aggregated final answer. Union[str, Dict[str, Any]]: The aggregated final answer in the specified format
Raises:
RuntimeError: If evaluation mode is enabled and the expected answer is not found
in any of the generated responses
Example:
>>> agent = SelfConsistencyAgent(num_samples=3)
>>> result = agent.run("What is 2 + 2?")
>>> print(result)
>>> # With evaluation mode (enabled on the constructor, not on run)
>>> agent = SelfConsistencyAgent(num_samples=3, eval=True)
>>> result = agent.run("What is 2 + 2?", answer="4")
""" """
responses = [] responses = []
logger.info(
f"Generating {self.num_samples} responses concurrently..."
)
self.conversation.add(role="User", content=task) self.conversation.add(role="User", content=task)
# Generate multiple independent responses concurrently
reasoning_agent = self._create_reasoning_agent()
with ThreadPoolExecutor() as executor: with ThreadPoolExecutor() as executor:
futures = { futures = {
executor.submit(super().run, task, *args, **kwargs): i executor.submit(
reasoning_agent.run,
task=task,
img=img,
*args,
**kwargs,
): i
for i in range(self.num_samples) for i in range(self.num_samples)
} }
for future in as_completed(futures): for future in as_completed(futures):
response = future.result() response = future.result()
responses.append(response) responses.append(response)
self.conversation.add(role=self.agent_name, content=responses) self.conversation.add(role=self.name, content=responses)
# Optional evaluation against known answer
if self.eval: if self.eval:
if answer is not None: if answer is not None:
correct = self.check_responses_for_answer( correct = self.check_responses_for_answer(
@ -116,9 +290,7 @@ class SelfConsistencyAgent(Agent):
) )
return None return None
# Aggregation agent # Aggregate responses using AI-powered aggregation
# final_answer = self.aggregation_agent(responses)
final_answer = aggregation_agent(responses) final_answer = aggregation_agent(responses)
self.conversation.add( self.conversation.add(
@ -129,39 +301,46 @@ class SelfConsistencyAgent(Agent):
self.conversation, self.output_type self.conversation, self.output_type
) )
def aggregate(self, responses: List[str]) -> str: def _create_reasoning_agent(self) -> Agent:
""" """
Aggregates a list of responses into a single final answer. Create a reasoning agent instance for generating individual responses.
Here we use a simple majority vote (most common answer) as an example. Depending on
the task, you might need a more sophisticated aggregation (e.g., weighting, consensus reasoning, etc.).
Args:
responses (list of str): The list of responses.
Returns: Returns:
str: The aggregated answer. Agent: A configured Agent instance for reasoning tasks
""" """
# Count the frequency of each response. return Agent(
counts = Counter(responses) agent_name=self.name,
most_common, freq = counts.most_common(1)[0] description=self.description,
logger.info( model_name=self.model_name,
f"Aggregation complete. Most common response (appeared {freq} times):" system_prompt=self.system_prompt,
max_loops=self.max_loops,
random_models_on=self.random_models_on,
output_type="str-all-except-first",
**self.kwargs,
) )
return most_common
def check_responses_for_answer( def check_responses_for_answer(
self, responses: List[str], answer: str self, responses: List[str], answer: str
) -> bool: ) -> bool:
""" """
Checks if the specified answer is present in any of the provided responses. Check if the specified answer is present in any of the provided responses.
This method performs simple substring matching to determine whether the expected
answer appears in any of the generated responses. It is useful for validation
and evaluation purposes.
Args: Args:
responses (List[str]): A list of responses to check. responses (List[str]): List of responses to check
answer (str): The answer to look for in the responses. answer (str): The answer to look for in the responses
Returns: Returns:
bool: True if the answer is found in any response, False otherwise. bool: True if the answer is found in any response, False otherwise
Example:
>>> agent = SelfConsistencyAgent()
>>> responses = ["The answer is 42", "I think it's 42", "Not sure"]
>>> found = agent.check_responses_for_answer(responses, "42")
>>> print(found) # True
""" """
for response in responses: for response in responses:
if answer in response: if answer in response:
@ -181,27 +360,30 @@ class SelfConsistencyAgent(Agent):
def batched_run( def batched_run(
self, tasks: List[str], *args, **kwargs self, tasks: List[str], *args, **kwargs
) -> List[str]: ) -> List[Union[str, Dict[str, Any]]]:
""" """
Runs the agent in a batched manner. Run the agent on multiple tasks in batch.
This method processes multiple tasks sequentially, applying the self-consistency
approach to each task independently. It's useful for processing large datasets
or multiple related problems.
Args:
tasks (List[str]): List of tasks to be processed
*args: Additional positional arguments passed to the run method
**kwargs: Additional keyword arguments passed to the run method
Returns:
List[Union[str, Dict[str, Any]]]: List of results for each task
Example:
>>> agent = SelfConsistencyAgent()
>>> tasks = ["What is 2+2?", "What is 3+3?", "What is 4+4?"]
>>> results = agent.batched_run(tasks)
>>> print(len(results)) # 3
""" """
responses = [] responses = []
for task in tasks: for task in tasks:
response = self.run(task, *args, **kwargs) response = self.run(task, *args, **kwargs)
responses.append(response) responses.append(response)
return responses return responses
# # Example usage:
# if __name__ == "__main__":
# agent = SelfConsistencyAgent(
# agent_name="Reasoning-Agent",
# model_name="gpt-4o-mini",
# max_loops=1,
# num_samples=5, # Number of samples for self consistency
# )
# prompt = "What is the 40th prime number?"
# final_answer = agent.run(prompt)
# print("\nFinal aggregated answer:")
# print(final_answer)
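
The four steps listed in the `run` docstring can be sketched without the swarms library. This is a minimal illustration, not the commit's implementation: `fake_sampler` is a hypothetical stand-in for an LLM call, and a plain `Counter` majority vote (as in the removed `aggregate` method) replaces the aggregation agent the new code uses.

```python
from collections import Counter
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Callable, List


def sample_responses(
    sampler: Callable[[str], str], task: str, num_samples: int
) -> List[str]:
    """Call `sampler` num_samples times concurrently and collect the results."""
    responses: List[str] = []
    with ThreadPoolExecutor() as executor:
        futures = [executor.submit(sampler, task) for _ in range(num_samples)]
        for future in as_completed(futures):
            responses.append(future.result())
    return responses


def majority_vote(responses: List[str]) -> str:
    """Return the most frequent response after trimming and lowercasing."""
    normalized = [response.strip().lower() for response in responses]
    return Counter(normalized).most_common(1)[0][0]


def contains_answer(responses: List[str], answer: str) -> bool:
    """Substring check, mirroring check_responses_for_answer in the diff."""
    return any(answer in response for response in responses)


def fake_sampler(task: str) -> str:
    # Hypothetical stand-in for a model call; always answers "4".
    return "4"


if __name__ == "__main__":
    samples = sample_responses(fake_sampler, "What is 2 + 2?", num_samples=3)
    print(majority_vote(samples))
```

Because `contains_answer` is a substring test, an expected answer of "4" would also match a response such as "14", so exact-match evaluation needs a stricter comparison.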

@ -1,3 +1,41 @@
"""
ReasoningAgentRouter: A flexible router for advanced reasoning agent swarms.
This module provides the ReasoningAgentRouter class, which enables dynamic selection and instantiation
of various advanced reasoning agent types (swarms) for complex problem-solving tasks. It supports
multiple reasoning strategies, including self-consistency, collaborative duo agents, iterative
reflection, knowledge prompting, and agent judging.
Key Features:
- Unified interface for multiple agent types (see `agent_types`)
- Caching of agent instances for efficiency and memory management
- Extensible factory-based architecture for easy addition of new agent types
- Batch and single-task execution
- Customizable agent configuration (model, prompt, memory, etc.)
Supported Agent Types:
- "reasoning-duo" / "reasoning-agent": Dual collaborative agent system
- "self-consistency" / "consistency-agent": Multiple independent solutions with consensus
- "ire" / "ire-agent": Iterative Reflective Expansion agent
- "ReflexionAgent": Reflexion agent with memory
- "GKPAgent": Generated Knowledge Prompting agent
- "AgentJudge": Agent judge for evaluation/critique
Example usage:
>>> router = ReasoningAgentRouter(swarm_type="self-consistency", num_samples=3)
>>> result = router.run("What is the capital of France?")
>>> print(result)
>>> # Batch mode
>>> results = router.batched_run(["2+2?", "3+3?"])
>>> print(results)
See also:
- docs/swarms/agents/reasoning_agent_router.md for detailed documentation and architecture diagrams.
- consistency_example.py for a usage example with SelfConsistencyAgent.
"""
from typing import ( from typing import (
List, List,
Literal, Literal,
@ -6,9 +44,9 @@ from typing import (
Any, Any,
Tuple, Tuple,
Hashable, Hashable,
Optional,
) )
from swarms.agents.consistency_agent import SelfConsistencyAgent from swarms.agents.consistency_agent import SelfConsistencyAgent
from swarms.agents.flexion_agent import ReflexionAgent from swarms.agents.flexion_agent import ReflexionAgent
from swarms.agents.gkp_agent import GKPAgent from swarms.agents.gkp_agent import GKPAgent
@ -19,7 +57,7 @@ from swarms.agents.reasoning_duo import ReasoningDuo
from swarms.utils.output_types import OutputType from swarms.utils.output_types import OutputType
from swarms.agents.agent_judge import AgentJudge from swarms.agents.agent_judge import AgentJudge
#: Supported agent type literals for ReasoningAgentRouter
agent_types = Literal[ agent_types = Literal[
"reasoning-duo", "reasoning-duo",
"self-consistency", "self-consistency",
@ -35,18 +73,30 @@ agent_types = Literal[
class ReasoningAgentRouter: class ReasoningAgentRouter:
""" """
A Reasoning Agent that can answer questions and assist with various tasks using different reasoning strategies. A router for advanced reasoning agent swarms.
The ReasoningAgentRouter enables dynamic selection, instantiation, and caching of various
Attributes: reasoning agent types ("swarms") for flexible, robust, and scalable problem-solving.
agent_name (str): The name of the agent.
description (str): A brief description of the agent's capabilities. Args:
model_name (str): The name of the model used for reasoning. agent_name (str): Name identifier for the agent instance.
system_prompt (str): The prompt that guides the agent's reasoning process. description (str): Description of the agent's capabilities.
max_loops (int): The maximum number of loops for the reasoning process. model_name (str): The underlying language model to use.
swarm_type (agent_types): The type of reasoning swarm to use (e.g., reasoning duo, self-consistency, IRE). system_prompt (str): System prompt for the agent.
num_samples (int): The number of samples to generate for self-consistency agents. max_loops (int): Maximum number of reasoning loops.
output_type (OutputType): The format of the output (e.g., dict, list). swarm_type (agent_types): Type of reasoning swarm to use.
num_samples (int): Number of samples for self-consistency or iterations.
output_type (OutputType): Format of the output.
num_knowledge_items (int): Number of knowledge items for GKP agent.
memory_capacity (int): Memory capacity for agents that support it.
eval (bool): Enable evaluation mode for self-consistency.
random_models_on (bool): Enable random model selection for diversity.
majority_voting_prompt (Optional[str]): Custom prompt for majority voting.
Example:
>>> router = ReasoningAgentRouter(swarm_type="reasoning-duo")
>>> result = router.run("Explain quantum entanglement.")
>>> print(result)
""" """
# Class variable to store cached agent instances # Class variable to store cached agent instances
@ -59,12 +109,20 @@ class ReasoningAgentRouter:
model_name: str = "gpt-4o-mini", model_name: str = "gpt-4o-mini",
system_prompt: str = "You are a helpful assistant that can answer questions and help with tasks.", system_prompt: str = "You are a helpful assistant that can answer questions and help with tasks.",
max_loops: int = 1, max_loops: int = 1,
swarm_type: agent_types = "reasoning_duo", swarm_type: agent_types = "reasoning-duo",
num_samples: int = 1, num_samples: int = 1,
output_type: OutputType = "dict", output_type: OutputType = "dict-all-except-first",
num_knowledge_items: int = 6, num_knowledge_items: int = 6,
memory_capacity: int = 6, memory_capacity: int = 6,
eval: bool = False,
random_models_on: bool = False,
majority_voting_prompt: Optional[str] = None,
): ):
"""
Initialize the ReasoningAgentRouter with the specified configuration.
See class docstring for parameter details.
"""
self.agent_name = agent_name self.agent_name = agent_name
self.description = description self.description = description
self.model_name = model_name self.model_name = model_name
@ -75,14 +133,17 @@ class ReasoningAgentRouter:
self.output_type = output_type self.output_type = output_type
self.num_knowledge_items = num_knowledge_items self.num_knowledge_items = num_knowledge_items
self.memory_capacity = memory_capacity self.memory_capacity = memory_capacity
self.eval = eval
self.random_models_on = random_models_on
self.majority_voting_prompt = majority_voting_prompt
# Added: Initialize the factory mapping dictionary # Initialize the factory mapping dictionary
self._initialize_agent_factories() self._initialize_agent_factories()
def _initialize_agent_factories(self) -> None: def _initialize_agent_factories(self) -> None:
""" """
Initialize the agent factory mapping dictionary, mapping various agent types to their respective creation functions. Initialize the agent factory mapping dictionary, mapping various agent types to their respective creation functions.
This method replaces the original if-elif chain, making the code more maintainable and extensible. This method replaces the original if-elif chain, making the code more maintainable and extensible.
""" """
self.agent_factories: Dict[str, Callable[[], Any]] = { self.agent_factories: Dict[str, Callable[[], Any]] = {
@ -104,11 +165,11 @@ class ReasoningAgentRouter:
def _get_cache_key(self) -> Tuple[Hashable, ...]: def _get_cache_key(self) -> Tuple[Hashable, ...]:
""" """
Generate a unique key for cache lookup. Generate a unique key for cache lookup.
The key is based on all relevant configuration parameters of the agent.
The key is based on all relevant configuration parameters of the agent.
Returns: Returns:
Tuple[Hashable, ...]: A hashable tuple to serve as the cache key Tuple[Hashable, ...]: A hashable tuple to serve as the cache key.
""" """
return ( return (
self.swarm_type, self.swarm_type,
@ -121,10 +182,18 @@ class ReasoningAgentRouter:
self.output_type, self.output_type,
self.num_knowledge_items, self.num_knowledge_items,
self.memory_capacity, self.memory_capacity,
self.eval,
self.random_models_on,
self.majority_voting_prompt,
) )
def _create_reasoning_duo(self): def _create_reasoning_duo(self):
"""Create an agent instance for the ReasoningDuo type""" """
Create an agent instance for the ReasoningDuo type.
Returns:
ReasoningDuo: An instance of the ReasoningDuo agent.
"""
return ReasoningDuo( return ReasoningDuo(
agent_name=self.agent_name, agent_name=self.agent_name,
agent_description=self.description, agent_description=self.description,
@ -134,19 +203,32 @@ class ReasoningAgentRouter:
) )
def _create_consistency_agent(self): def _create_consistency_agent(self):
"""Create an agent instance for the SelfConsistencyAgent type""" """
Create an agent instance for the SelfConsistencyAgent type.
Returns:
SelfConsistencyAgent: An instance of the SelfConsistencyAgent.
"""
return SelfConsistencyAgent( return SelfConsistencyAgent(
agent_name=self.agent_name, name=self.agent_name,
description=self.description, description=self.description,
model_name=self.model_name, model_name=self.model_name,
system_prompt=self.system_prompt, system_prompt=self.system_prompt,
max_loops=self.max_loops, max_loops=self.max_loops,
num_samples=self.num_samples, num_samples=self.num_samples,
output_type=self.output_type, output_type=self.output_type,
eval=self.eval,
random_models_on=self.random_models_on,
majority_voting_prompt=self.majority_voting_prompt,
) )
def _create_ire_agent(self): def _create_ire_agent(self):
"""Create an agent instance for the IREAgent type""" """
Create an agent instance for the IREAgent type.
Returns:
IREAgent: An instance of the IterativeReflectiveExpansion agent.
"""
return IREAgent( return IREAgent(
agent_name=self.agent_name, agent_name=self.agent_name,
description=self.description, description=self.description,
@ -158,7 +240,12 @@ class ReasoningAgentRouter:
) )
def _create_agent_judge(self): def _create_agent_judge(self):
"""Create an agent instance for the AgentJudge type""" """
Create an agent instance for the AgentJudge type.
Returns:
AgentJudge: An instance of the AgentJudge agent.
"""
return AgentJudge( return AgentJudge(
agent_name=self.agent_name, agent_name=self.agent_name,
model_name=self.model_name, model_name=self.model_name,
@ -167,16 +254,27 @@ class ReasoningAgentRouter:
) )
def _create_reflexion_agent(self): def _create_reflexion_agent(self):
"""Create an agent instance for the ReflexionAgent type""" """
Create an agent instance for the ReflexionAgent type.
Returns:
ReflexionAgent: An instance of the ReflexionAgent.
"""
return ReflexionAgent( return ReflexionAgent(
agent_name=self.agent_name, agent_name=self.agent_name,
system_prompt=self.system_prompt, system_prompt=self.system_prompt,
model_name=self.model_name, model_name=self.model_name,
max_loops=self.max_loops, max_loops=self.max_loops,
memory_capacity=self.memory_capacity,
) )
def _create_gkp_agent(self): def _create_gkp_agent(self):
"""Create an agent instance for the GKPAgent type""" """
Create an agent instance for the GKPAgent type.
Returns:
GKPAgent: An instance of the GKPAgent.
"""
return GKPAgent( return GKPAgent(
agent_name=self.agent_name, agent_name=self.agent_name,
model_name=self.model_name, model_name=self.model_name,
@ -186,13 +284,15 @@ class ReasoningAgentRouter:
def select_swarm(self): def select_swarm(self):
""" """
Select and initialize the appropriate reasoning swarm based on the specified swarm type. Select and initialize the appropriate reasoning swarm based on the specified swarm type.
Uses a caching mechanism to return a cached instance if an agent with the same configuration already exists.
Uses a caching mechanism to return a cached instance if an agent with the same configuration already exists.
Returns: Returns:
The selected reasoning swarm instance. The selected reasoning swarm instance.
"""
Raises:
ValueError: If the specified swarm type is invalid.
"""
# Generate cache key # Generate cache key
cache_key = self._get_cache_key() cache_key = self._get_cache_key()
@ -216,25 +316,25 @@ class ReasoningAgentRouter:
""" """
Execute the reasoning process of the selected swarm on a given task. Execute the reasoning process of the selected swarm on a given task.
Args: Args:
task (str): The task or question to be processed by the reasoning agent. task (str): The task or question to be processed by the reasoning agent.
*args: Additional positional arguments for the agent's run method.
**kwargs: Additional keyword arguments for the agent's run method.
Returns: Returns:
The result of the reasoning process. The result of the reasoning process (format depends on agent and output_type).
""" """
swarm = self.select_swarm() swarm = self.select_swarm()
return swarm.run(task=task) return swarm.run(task=task, *args, **kwargs)
def batched_run(self, tasks: List[str], *args, **kwargs): def batched_run(self, tasks: List[str], *args, **kwargs):
""" """
Execute the reasoning process on a batch of tasks. Execute the reasoning process on a batch of tasks.
Args: Args:
tasks (List[str]): The list of tasks to process. tasks (List[str]): The list of tasks to process.
*args: Additional positional arguments for the agent's run method.
**kwargs: Additional keyword arguments for the agent's run method.
Returns: Returns:
A list of reasoning process results for each task. A list of reasoning process results for each task.
@ -248,6 +348,7 @@ class ReasoningAgentRouter:
def clear_cache(cls): def clear_cache(cls):
""" """
Clear the agent instance cache. Clear the agent instance cache.
Use this when you need to free memory or force the creation of new instances. Use this when you need to free memory or force the creation of new instances.
""" """
cls._agent_cache.clear() cls._agent_cache.clear()
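
The factory mapping and class-level cache described in `_initialize_agent_factories`, `_get_cache_key`, and `select_swarm` can be sketched as follows. `TinyRouter` and its dictionary "agents" are hypothetical placeholders, and the cache key here uses only two fields, whereas the router in the diff includes every configuration parameter.

```python
from typing import Any, Callable, Dict, Hashable, Tuple


class TinyRouter:
    """Hypothetical, stripped-down router: a factory dict plus a shared cache."""

    # Shared by all instances, like _agent_cache in the diff.
    _cache: Dict[Tuple[Hashable, ...], Any] = {}

    def __init__(self, swarm_type: str = "reasoning-duo", num_samples: int = 1):
        self.swarm_type = swarm_type
        self.num_samples = num_samples
        # Map each type name to a zero-argument creation function.
        self.factories: Dict[str, Callable[[], Any]] = {
            "reasoning-duo": lambda: {"kind": "duo"},
            "self-consistency": lambda: {
                "kind": "consistency",
                "n": self.num_samples,
            },
        }

    def _cache_key(self) -> Tuple[Hashable, ...]:
        return (self.swarm_type, self.num_samples)

    def select(self) -> Any:
        """Return a cached instance for this configuration, creating it if needed."""
        key = self._cache_key()
        if key not in self._cache:
            if self.swarm_type not in self.factories:
                raise ValueError(f"Invalid swarm type: {self.swarm_type}")
            self._cache[key] = self.factories[self.swarm_type]()
        return self._cache[key]

    @classmethod
    def clear_cache(cls) -> None:
        cls._cache.clear()
```

Two routers with identical configuration share one instance, which is why any parameter that changes agent behavior (here `eval`, `random_models_on`, and `majority_voting_prompt` in the diff) has to be part of the key.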

@ -6,7 +6,6 @@ import threading
import uuid import uuid
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Callable,
Dict, Dict,
List, List,
Optional, Optional,
@ -190,18 +189,16 @@ class Conversation(BaseStructure):
save_enabled: bool = False, # New parameter to control if saving is enabled save_enabled: bool = False, # New parameter to control if saving is enabled
save_filepath: str = None, save_filepath: str = None,
load_filepath: str = None, # New parameter to specify which file to load from load_filepath: str = None, # New parameter to specify which file to load from
tokenizer: Callable = None,
context_length: int = 8192, context_length: int = 8192,
rules: str = None, rules: str = None,
custom_rules_prompt: str = None, custom_rules_prompt: str = None,
user: str = "User:", user: str = "User",
save_as_yaml: bool = False, save_as_yaml: bool = False,
save_as_json_bool: bool = False, save_as_json_bool: bool = False,
token_count: bool = True, token_count: bool = False,
message_id_on: bool = False, message_id_on: bool = False,
provider: providers = "in-memory", provider: providers = "in-memory",
backend: Optional[str] = None, backend: Optional[str] = None,
# Backend-specific parameters
supabase_url: Optional[str] = None, supabase_url: Optional[str] = None,
supabase_key: Optional[str] = None, supabase_key: Optional[str] = None,
redis_host: str = "localhost", redis_host: str = "localhost",
@ -210,7 +207,6 @@ class Conversation(BaseStructure):
redis_password: Optional[str] = None, redis_password: Optional[str] = None,
db_path: Optional[str] = None, db_path: Optional[str] = None,
table_name: str = "conversations", table_name: str = "conversations",
# Additional backend parameters
use_embedded_redis: bool = True, use_embedded_redis: bool = True,
persist_redis: bool = True, persist_redis: bool = True,
auto_persist: bool = True, auto_persist: bool = True,
@ -230,20 +226,7 @@ class Conversation(BaseStructure):
self.save_enabled = save_enabled self.save_enabled = save_enabled
self.conversations_dir = conversations_dir self.conversations_dir = conversations_dir
self.message_id_on = message_id_on self.message_id_on = message_id_on
# Handle save filepath
if save_enabled and save_filepath:
self.save_filepath = save_filepath
elif save_enabled and conversations_dir:
self.save_filepath = os.path.join(
conversations_dir, f"{self.id}.json"
)
else:
self.save_filepath = None
self.load_filepath = load_filepath self.load_filepath = load_filepath
self.conversation_history = []
self.tokenizer = tokenizer
self.context_length = context_length self.context_length = context_length
self.rules = rules self.rules = rules
self.custom_rules_prompt = custom_rules_prompt self.custom_rules_prompt = custom_rules_prompt
@ -253,9 +236,40 @@ class Conversation(BaseStructure):
self.token_count = token_count self.token_count = token_count
self.provider = provider # Keep for backwards compatibility self.provider = provider # Keep for backwards compatibility
self.conversations_dir = conversations_dir self.conversations_dir = conversations_dir
self.backend = backend
self.supabase_url = supabase_url
self.supabase_key = supabase_key
self.redis_host = redis_host
self.redis_port = redis_port
self.redis_db = redis_db
self.redis_password = redis_password
self.db_path = db_path
self.table_name = table_name
self.use_embedded_redis = use_embedded_redis
self.persist_redis = persist_redis
self.auto_persist = auto_persist
self.redis_data_dir = redis_data_dir
self.conversation_history = []
# Handle save filepath
if save_enabled and save_filepath:
self.save_filepath = save_filepath
elif save_enabled and conversations_dir:
self.save_filepath = os.path.join(
conversations_dir, f"{self.id}.json"
)
else:
self.save_filepath = None
# Support both 'provider' and 'backend' parameters for backwards compatibility # Support both 'provider' and 'backend' parameters for backwards compatibility
# 'backend' takes precedence if both are provided # 'backend' takes precedence if both are provided
self.backend_setup(backend, provider)
def backend_setup(
self, backend: str = None, provider: str = None
):
self.backend = backend or provider self.backend = backend or provider
self.backend_instance = None self.backend_instance = None
@ -285,19 +299,18 @@ class Conversation(BaseStructure):
]: ]:
try: try:
self._initialize_backend( self._initialize_backend(
supabase_url=supabase_url, supabase_url=self.supabase_url,
supabase_key=supabase_key, supabase_key=self.supabase_key,
redis_host=redis_host, redis_host=self.redis_host,
redis_port=redis_port, redis_port=self.redis_port,
redis_db=redis_db, redis_db=self.redis_db,
redis_password=redis_password, redis_password=self.redis_password,
db_path=db_path, db_path=self.db_path,
table_name=table_name, table_name=self.table_name,
use_embedded_redis=use_embedded_redis, use_embedded_redis=self.use_embedded_redis,
persist_redis=persist_redis, persist_redis=self.persist_redis,
auto_persist=auto_persist, auto_persist=self.auto_persist,
redis_data_dir=redis_data_dir, redis_data_dir=self.redis_data_dir,
**kwargs,
) )
except Exception as e: except Exception as e:
logger.warning( logger.warning(
@ -324,7 +337,6 @@ class Conversation(BaseStructure):
"time_enabled": self.time_enabled, "time_enabled": self.time_enabled,
"autosave": self.autosave, "autosave": self.autosave,
"save_filepath": self.save_filepath, "save_filepath": self.save_filepath,
"tokenizer": self.tokenizer,
"context_length": self.context_length, "context_length": self.context_length,
"rules": self.rules, "rules": self.rules,
"custom_rules_prompt": self.custom_rules_prompt, "custom_rules_prompt": self.custom_rules_prompt,
@ -449,8 +461,8 @@ class Conversation(BaseStructure):
if self.custom_rules_prompt is not None: if self.custom_rules_prompt is not None:
self.add(self.user or "User", self.custom_rules_prompt) self.add(self.user or "User", self.custom_rules_prompt)
if self.tokenizer is not None: # if self.tokenizer is not None:
self.truncate_memory_with_tokenizer() # self.truncate_memory_with_tokenizer()
def _autosave(self): def _autosave(self):
"""Automatically save the conversation if autosave is enabled.""" """Automatically save the conversation if autosave is enabled."""
@ -1051,9 +1063,7 @@ class Conversation(BaseStructure):
for message in self.conversation_history: for message in self.conversation_history:
role = message.get("role") role = message.get("role")
content = message.get("content") content = message.get("content")
tokens = self.tokenizer.count_tokens( tokens = count_tokens(content)
text=content
) # Count the number of tokens
count = tokens # Assign the token count count = tokens # Assign the token count
total_tokens += count total_tokens += count
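
The save-path and backend selection rules in the `Conversation` constructor above reduce to two small decisions. The helper names below are hypothetical and exist only to illustrate the precedence; the class itself sets attributes inline.

```python
import os
from typing import Optional


def resolve_save_filepath(
    save_enabled: bool,
    save_filepath: Optional[str],
    conversations_dir: Optional[str],
    conversation_id: str,
) -> Optional[str]:
    """An explicit path wins; otherwise fall back to <conversations_dir>/<id>.json."""
    if save_enabled and save_filepath:
        return save_filepath
    if save_enabled and conversations_dir:
        return os.path.join(conversations_dir, f"{conversation_id}.json")
    return None


def resolve_backend(backend: Optional[str], provider: str = "in-memory") -> str:
    """'backend' takes precedence over the older 'provider' parameter."""
    return backend or provider
```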

@ -0,0 +1,624 @@
"""
Sparse Mixture-of-Experts (MoE) Transformer Implementation
Based on Gemini 2.5 architecture description
This implementation provides a sparse MoE architecture that activates only a subset
of expert parameters per input token, allowing for decoupling of model capacity
from computation cost.
"""
from typing import Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from loguru import logger
from torch import Tensor
class Expert(nn.Module):
"""
Individual expert network in the MoE architecture.
Each expert is a feed-forward network that specializes in processing
certain types of input patterns.
Args:
hidden_dim: Hidden dimension size
intermediate_dim: Intermediate dimension in feed-forward network
dropout: Dropout probability
activation: Activation function to use
"""
def __init__(
self,
hidden_dim: int,
intermediate_dim: int,
dropout: float = 0.1,
activation: str = "swish",
):
super().__init__()
self.hidden_dim = hidden_dim
self.intermediate_dim = intermediate_dim
# Feed-forward network
self.w1 = nn.Linear(hidden_dim, intermediate_dim, bias=False)
self.w2 = nn.Linear(intermediate_dim, hidden_dim, bias=False)
self.dropout = nn.Dropout(dropout)
# Activation function
if activation == "swish":
self.activation = lambda x: x * torch.sigmoid(x)
elif activation == "gelu":
self.activation = F.gelu
elif activation == "relu":
self.activation = F.relu
else:
raise ValueError(f"Unsupported activation: {activation}")
self._init_weights()
def _init_weights(self) -> None:
"""Initialize weights with proper scaling."""
nn.init.xavier_uniform_(self.w1.weight)
nn.init.xavier_uniform_(self.w2.weight)
def forward(self, x: Tensor) -> Tensor:
"""
Forward pass through the expert network.
Args:
x: Input tensor of shape [..., hidden_dim], e.g. [batch_size, seq_len, hidden_dim]
or the flattened [num_tokens, hidden_dim] that MoELayer passes in
Returns:
Output tensor with the same shape as the input
"""
x = self.w1(x)
x = self.activation(x)
x = self.dropout(x)
x = self.w2(x)
return x
class Router(nn.Module):
"""
Gating network that routes tokens to appropriate experts.
The router learns to assign input tokens to the most suitable experts
based on the token representations.
Args:
hidden_dim: Hidden dimension size
num_experts: Number of experts in the MoE layer
top_k: Number of experts to activate per token
temperature: Temperature for softmax routing
"""
def __init__(
self,
hidden_dim: int,
num_experts: int,
top_k: int = 2,
temperature: float = 1.0,
):
super().__init__()
self.hidden_dim = hidden_dim
self.num_experts = num_experts
self.top_k = top_k
self.temperature = temperature
# Linear layer for routing scores
self.gate = nn.Linear(hidden_dim, num_experts, bias=False)
self._init_weights()
def _init_weights(self) -> None:
"""Initialize routing weights."""
nn.init.xavier_uniform_(self.gate.weight)
def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
"""
Route tokens to experts.
Args:
x: Input tensor of shape [batch_size, seq_len, hidden_dim]
Returns:
Tuple of (routing_weights, expert_indices, routing_probs)
- routing_weights: [batch_size, seq_len, top_k]
- expert_indices: [batch_size, seq_len, top_k]
- routing_probs: [batch_size, seq_len, num_experts]
"""
batch_size, seq_len, hidden_dim = x.shape
# Compute routing scores
routing_logits = self.gate(
x
) # [batch_size, seq_len, num_experts]
routing_logits = routing_logits / self.temperature
# Apply softmax to get probabilities
routing_probs = F.softmax(routing_logits, dim=-1)
# Select top-k experts
routing_weights, expert_indices = torch.topk(
routing_probs, self.top_k, dim=-1
)
# Normalize routing weights
routing_weights = routing_weights / routing_weights.sum(
dim=-1, keepdim=True
)
return routing_weights, expert_indices, routing_probs
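
The routing arithmetic in `Router.forward` (temperature-scaled softmax, top-k selection, renormalization) can be checked by hand for a single token. This is a pure-Python sketch with made-up logits, not a replacement for the tensor code; `route_token` is a hypothetical helper name.

```python
import math
from typing import List, Tuple


def softmax(logits: List[float], temperature: float = 1.0) -> List[float]:
    """Temperature-scaled softmax over one token's routing logits."""
    scaled = [value / temperature for value in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(value - peak) for value in scaled]
    total = sum(exps)
    return [value / total for value in exps]


def route_token(
    logits: List[float], top_k: int = 2, temperature: float = 1.0
) -> Tuple[List[float], List[int], List[float]]:
    """Return (routing_weights, expert_indices, routing_probs) for one token."""
    probs = softmax(logits, temperature)
    # Rank experts by probability, highest first, and keep the top_k.
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    indices = ranked[:top_k]
    selected = [probs[i] for i in indices]
    # Renormalize so the kept weights sum to 1.
    norm = sum(selected)
    weights = [weight / norm for weight in selected]
    return weights, indices, probs


if __name__ == "__main__":
    weights, indices, probs = route_token([2.0, 0.5, 1.0, -1.0], top_k=2)
    print(indices, weights)
```

After renormalization the weights depend only on the selected experts' logits: for logits 2.0 and 1.0 the first weight is 1 / (1 + e^-1), regardless of the other experts.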
class MoELayer(nn.Module):
"""
Sparse Mixture-of-Experts layer.
This layer contains multiple expert networks and a router that decides
which experts to activate for each input token.
Args:
hidden_dim: Hidden dimension size
num_experts: Number of expert networks
top_k: Number of experts to activate per token
intermediate_dim: Intermediate dimension in expert networks
dropout: Dropout probability
activation: Activation function for experts
load_balance_weight: Weight for load balancing loss
"""
def __init__(
self,
hidden_dim: int,
num_experts: int,
top_k: int = 2,
intermediate_dim: Optional[int] = None,
dropout: float = 0.1,
activation: str = "swish",
load_balance_weight: float = 0.01,
):
super().__init__()
self.hidden_dim = hidden_dim
self.num_experts = num_experts
self.top_k = top_k
self.load_balance_weight = load_balance_weight
if intermediate_dim is None:
intermediate_dim = hidden_dim * 4
# Create expert networks
self.experts = nn.ModuleList(
[
Expert(
hidden_dim, intermediate_dim, dropout, activation
)
for _ in range(num_experts)
]
)
# Router for expert selection
self.router = Router(hidden_dim, num_experts, top_k)
logger.info(
f"Created MoE layer with {num_experts} experts, top_k={top_k}"
)

    def forward(self, x: Tensor) -> Tuple[Tensor, Dict[str, Tensor]]:
        """
        Forward pass through MoE layer.

        Args:
            x: Input tensor of shape [batch_size, seq_len, hidden_dim]

        Returns:
            Tuple of (output, aux_losses)
            - output: [batch_size, seq_len, hidden_dim]
            - aux_losses: Dictionary containing auxiliary losses
        """
        batch_size, seq_len, hidden_dim = x.shape

        # Get routing decisions
        routing_weights, expert_indices, routing_probs = self.router(
            x
        )

        # Initialize output
        output = torch.zeros_like(x)

        # Process each expert
        for i in range(self.num_experts):
            # Create mask for tokens routed to this expert
            expert_mask = (expert_indices == i).any(
                dim=-1
            )  # [batch_size, seq_len]

            if not expert_mask.any():
                continue

            # Get tokens for this expert
            expert_tokens = x[expert_mask]  # [num_tokens, hidden_dim]

            if expert_tokens.numel() == 0:
                continue

            # Process through expert
            expert_output = self.experts[i](expert_tokens)

            # Compute weights for this expert. The buffer uses the dtype of
            # routing_weights because indexed assignment requires source
            # and destination dtypes to match.
            expert_weights = torch.zeros(
                batch_size,
                seq_len,
                device=x.device,
                dtype=routing_weights.dtype,
            )
            for k in range(self.top_k):
                mask = expert_indices[:, :, k] == i
                expert_weights[mask] = routing_weights[:, :, k][mask]

            # Add weighted expert output
            expert_contribution = torch.zeros_like(x)
            expert_contribution[expert_mask] = expert_output
            output += expert_contribution * expert_weights.unsqueeze(
                -1
            )

        # Compute auxiliary losses
        aux_losses = self._compute_aux_losses(
            routing_probs, expert_indices
        )

        return output, aux_losses

    def _compute_aux_losses(
        self, routing_probs: Tensor, expert_indices: Tensor
    ) -> Dict[str, Tensor]:
        """
        Compute auxiliary losses for training stability.

        Args:
            routing_probs: Routing probabilities [batch_size, seq_len, num_experts]
            expert_indices: Selected expert indices [batch_size, seq_len, top_k]

        Returns:
            Dictionary of auxiliary losses
        """
        batch_size, seq_len, num_experts = routing_probs.shape

        # Load balancing loss: MSE between the fraction of routing slots
        # assigned to each expert and a uniform target.
        # Note: expert_usage is built from hard index counts, so this term
        # carries no gradient back to the router.
        expert_usage = torch.zeros(
            num_experts, device=routing_probs.device
        )
        total_tokens = batch_size * seq_len * self.top_k

        for i in range(num_experts):
            expert_usage[i] = (
                expert_indices == i
            ).sum().float() / total_tokens

        target_usage = 1.0 / num_experts
        load_balance_loss = F.mse_loss(
            expert_usage, torch.full_like(expert_usage, target_usage)
        )

        # Mean per-token entropy of the routing distribution.
        # Note: added to a minimized loss with a positive weight, this term
        # pushes routing toward lower entropy (sharper expert choices).
        entropy_loss = (
            -(routing_probs * torch.log(routing_probs + 1e-8))
            .sum(dim=-1)
            .mean()
        )

        return {
            "load_balance_loss": load_balance_loss
            * self.load_balance_weight,
            "entropy_loss": entropy_loss * 0.01,
            "expert_usage": expert_usage,
        }
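
To make the load-balancing term concrete, here is a small pure-Python sketch of the arithmetic that `_compute_aux_losses` applies to the hard routing decisions. It counts how often each expert is selected, normalizes by the number of routing slots, and takes the mean squared error against a uniform target. The function name `load_balance_mse` and the toy routing tables are hypothetical and serve only as an illustration.

```python
from typing import List


def load_balance_mse(expert_indices: List[List[int]], num_experts: int) -> float:
    """MSE between observed expert usage and a uniform target.

    expert_indices holds, for each token, the experts chosen for it
    (top_k entries per token).
    """
    total_slots = sum(len(chosen) for chosen in expert_indices)
    # Fraction of routing slots assigned to each expert
    usage = [
        sum(chosen.count(i) for chosen in expert_indices) / total_slots
        for i in range(num_experts)
    ]
    target = 1.0 / num_experts
    return sum((u - target) ** 2 for u in usage) / num_experts


# Toy example: 4 tokens, top_k=2, 4 experts; every token selects expert 0
skewed = [[0, 1], [0, 2], [0, 1], [3, 0]]
print(load_balance_mse(skewed, num_experts=4))  # 0.0234375

# Perfectly balanced routing gives a loss of zero
balanced = [[0, 1], [2, 3], [0, 1], [2, 3]]
print(load_balance_mse(balanced, num_experts=4))  # 0.0
```

Because the counts come from discrete selections, this quantity has no gradient with respect to the router on its own, so it is mainly useful as a balance indicator.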


class MoETransformerBlock(nn.Module):
    """
    Transformer block with MoE feed-forward layer.

    This block combines multi-head attention with a sparse MoE layer,
    following the standard transformer architecture pattern.

    Args:
        hidden_dim: Hidden dimension size
        num_heads: Number of attention heads
        num_experts: Number of experts in MoE layer
        top_k: Number of experts to activate per token
        dropout: Dropout probability
        layer_norm_eps: Epsilon for layer normalization
    """

    def __init__(
        self,
        hidden_dim: int,
        num_heads: int,
        num_experts: int,
        top_k: int = 2,
        dropout: float = 0.1,
        layer_norm_eps: float = 1e-6,
    ):
        super().__init__()
        self.hidden_dim = hidden_dim

        # Multi-head attention
        self.attention = nn.MultiheadAttention(
            hidden_dim, num_heads, dropout=dropout, batch_first=True
        )

        # MoE layer
        self.moe_layer = MoELayer(
            hidden_dim=hidden_dim,
            num_experts=num_experts,
            top_k=top_k,
            dropout=dropout,
        )

        # Layer normalization
        self.norm1 = nn.LayerNorm(hidden_dim, eps=layer_norm_eps)
        self.norm2 = nn.LayerNorm(hidden_dim, eps=layer_norm_eps)

        # Dropout
        self.dropout = nn.Dropout(dropout)

    def forward(
        self, x: Tensor, attention_mask: Optional[Tensor] = None
    ) -> Tuple[Tensor, Dict[str, Tensor]]:
        """
        Forward pass through transformer block.

        Args:
            x: Input tensor [batch_size, seq_len, hidden_dim]
            attention_mask: Optional mask [batch_size, seq_len], passed to
                nn.MultiheadAttention as key_padding_mask. For a boolean
                mask, True marks positions to ignore (padding).

        Returns:
            Tuple of (output, aux_losses)
        """
        # Self-attention with residual connection (pre-norm)
        residual = x
        x = self.norm1(x)
        attn_output, _ = self.attention(
            x, x, x, key_padding_mask=attention_mask
        )
        x = residual + self.dropout(attn_output)

        # MoE layer with residual connection (pre-norm)
        residual = x
        x = self.norm2(x)
        moe_output, aux_losses = self.moe_layer(x)
        x = residual + self.dropout(moe_output)

        return x, aux_losses


class MoETransformer(nn.Module):
    """
    Complete sparse MoE Transformer model.

    This model implements the full transformer architecture with sparse
    mixture-of-experts layers, similar to the Gemini 2.5 architecture.

    Args:
        vocab_size: Vocabulary size
        hidden_dim: Hidden dimension size
        num_layers: Number of transformer layers
        num_heads: Number of attention heads
        num_experts: Number of experts per MoE layer
        top_k: Number of experts to activate per token
        max_seq_len: Maximum sequence length
        dropout: Dropout probability
    """

    def __init__(
        self,
        vocab_size: int,
        hidden_dim: int,
        num_layers: int,
        num_heads: int,
        num_experts: int,
        top_k: int = 2,
        max_seq_len: int = 2048,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.max_seq_len = max_seq_len

        # Token embedding
        self.token_embedding = nn.Embedding(vocab_size, hidden_dim)

        # Positional encoding
        self.pos_embedding = nn.Parameter(
            torch.randn(1, max_seq_len, hidden_dim) * 0.02
        )

        # Transformer layers
        self.layers = nn.ModuleList(
            [
                MoETransformerBlock(
                    hidden_dim=hidden_dim,
                    num_heads=num_heads,
                    num_experts=num_experts,
                    top_k=top_k,
                    dropout=dropout,
                )
                for _ in range(num_layers)
            ]
        )

        # Final layer norm
        self.final_norm = nn.LayerNorm(hidden_dim)

        # Output projection
        self.output_projection = nn.Linear(
            hidden_dim, vocab_size, bias=False
        )

        # Tie input and output embeddings
        self.output_projection.weight = self.token_embedding.weight

        self._init_weights()

        logger.info(
            f"Created MoE Transformer with {num_layers} layers, "
            f"{num_experts} experts per layer, hidden_dim={hidden_dim}"
        )

    def _init_weights(self) -> None:
        """Initialize model weights."""
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.pos_embedding, std=0.02)
        # Initialize output projection
        nn.init.normal_(self.output_projection.weight, std=0.02)

    def forward(
        self,
        input_ids: Tensor,
        attention_mask: Optional[Tensor] = None,
        return_aux_losses: bool = True,
    ) -> Union[Tensor, Tuple[Tensor, Dict[str, Tensor]]]:
        """
        Forward pass through the model.

        Args:
            input_ids: Input token IDs [batch_size, seq_len]
            attention_mask: Optional attention mask [batch_size, seq_len]
            return_aux_losses: Whether to return auxiliary losses

        Returns:
            If return_aux_losses=False: logits [batch_size, seq_len, vocab_size]
            If return_aux_losses=True: (logits, aux_losses)
        """
        batch_size, seq_len = input_ids.shape

        # Token embeddings
        x = self.token_embedding(input_ids)

        # Add positional encoding
        x = x + self.pos_embedding[:, :seq_len, :]

        # Collect auxiliary losses
        all_aux_losses = {}

        # Pass through transformer layers
        for i, layer in enumerate(self.layers):
            x, aux_losses = layer(x, attention_mask)

            if return_aux_losses:
                for key, value in aux_losses.items():
                    if key not in all_aux_losses:
                        all_aux_losses[key] = []
                    all_aux_losses[key].append(value)

        # Final layer norm
        x = self.final_norm(x)

        # Output projection
        logits = self.output_projection(x)

        if not return_aux_losses:
            return logits

        # Average auxiliary losses across layers
        avg_aux_losses = {}
        for key, values in all_aux_losses.items():
            if key == "expert_usage":
                # For expert usage, we want to see all layers
                avg_aux_losses[key] = torch.stack(values)
            else:
                avg_aux_losses[key] = torch.stack(values).mean()

        return logits, avg_aux_losses

    def get_num_parameters(self) -> int:
        """Get total number of parameters."""
        return sum(p.numel() for p in self.parameters())

    def get_num_active_parameters(self) -> int:
        """Get the approximate number of active parameters per forward pass."""
        # This is approximate: actual active parameters depend on routing
        total_params = self.get_num_parameters()

        # Expert parameters active per token, summed over all layers
        # (top_k experts per layer, sized like the layer's first expert)
        active_expert_params = 0
        for layer in self.layers:
            expert_params = sum(
                p.numel()
                for p in layer.moe_layer.experts[0].parameters()
            )
            active_expert_params += (
                expert_params * layer.moe_layer.top_k
            )

        # All expert parameters across all layers
        total_expert_params = sum(
            sum(
                p.numel()
                for expert in layer.moe_layer.experts
                for p in expert.parameters()
            )
            for layer in self.layers
        )

        # Swap the full expert count for the active subset.
        # active_expert_params already spans every layer.
        active_params = (
            total_params
            - total_expert_params
            + active_expert_params
        )

        return active_params


# Example usage and testing
if __name__ == "__main__":
    # Configure logger
    logger.add("moe_training.log", rotation="500 MB", level="INFO")

    # Model configuration
    config = {
        "vocab_size": 32000,
        "hidden_dim": 768,
        "num_layers": 12,
        "num_heads": 12,
        "num_experts": 8,
        "top_k": 2,
        "max_seq_len": 2048,
        "dropout": 0.1,
    }

    # Create model
    model = MoETransformer(**config)

    # Print model info
    total_params = model.get_num_parameters()
    active_params = model.get_num_active_parameters()

    logger.info(f"Total parameters: {total_params:,}")
    logger.info(
        f"Active parameters per forward pass: {active_params:,}"
    )
    logger.info(
        f"Parameter efficiency: {active_params/total_params:.2%}"
    )

    # Test forward pass
    batch_size, seq_len = 2, 512
    input_ids = torch.randint(
        0, config["vocab_size"], (batch_size, seq_len)
    )

    with torch.no_grad():
        logits, aux_losses = model(input_ids)

    logger.info(f"Input shape: {input_ids.shape}")
    logger.info(f"Output shape: {logits.shape}")
    logger.info(f"Auxiliary losses: {list(aux_losses.keys())}")

    # Print expert usage statistics
    expert_usage = aux_losses[
        "expert_usage"
    ]  # [num_layers, num_experts]
    logger.info(f"Expert usage shape: {expert_usage.shape}")
    logger.info(f"Average expert usage: {expert_usage.mean(dim=0)}")
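
As a rough sanity check on the "active parameters" figure, the expert share can be worked out by hand for the example configuration above (hidden_dim=768, 12 layers, 8 experts, top_k=2). The sketch below is pure Python. It assumes each expert is a two-layer feed-forward network (hidden_dim to intermediate_dim and back) with biases, which is a guess, since the `Expert` class is not shown in this part of the file. The helper name `expert_param_counts` is hypothetical.

```python
from typing import Optional, Tuple


def expert_param_counts(
    hidden_dim: int,
    num_layers: int,
    num_experts: int,
    top_k: int,
    intermediate_dim: Optional[int] = None,
) -> Tuple[int, int]:
    """Return (total, active) expert parameters across all layers.

    Assumes each expert is two linear layers with biases:
    hidden_dim -> intermediate_dim -> hidden_dim.
    """
    if intermediate_dim is None:
        intermediate_dim = hidden_dim * 4  # same default as MoELayer
    up = hidden_dim * intermediate_dim + intermediate_dim  # weights + bias
    down = intermediate_dim * hidden_dim + hidden_dim  # weights + bias
    per_expert = up + down
    total = per_expert * num_experts * num_layers
    active = per_expert * top_k * num_layers
    return total, active


total, active = expert_param_counts(
    hidden_dim=768, num_layers=12, num_experts=8, top_k=2
)
print(f"Total expert parameters:  {total:,}")  # 453,353,472
print(f"Active expert parameters: {active:,}")  # 113,338,368
print(f"Active fraction of expert parameters: {active / total:.0%}")  # 25%
```

Under these assumptions only top_k / num_experts of the expert weights are used per token, so the model-level ratio logged above should always come out below 100%.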