diff --git a/consistency_example.py b/consistency_example.py new file mode 100644 index 00000000..062d8ee1 --- /dev/null +++ b/consistency_example.py @@ -0,0 +1,22 @@ +from swarms import SelfConsistencyAgent + +# Initialize the reasoning agent router with self-consistency +reasoning_agent_router = SelfConsistencyAgent( + name="reasoning-agent", + description="A reasoning agent that can answer questions and help with tasks.", + model_name="gpt-4o-mini", + system_prompt="You are a helpful assistant that can answer questions and help with tasks.", + max_loops=1, + num_samples=3, # Generate 3 independent responses + eval=False, # Disable evaluation mode + random_models_on=False, # Disable random model selection + majority_voting_prompt=None, # Use default majority voting prompt +) + +# Run the agent on a financial analysis task +result = reasoning_agent_router.run( + "What is the best possible financial strategy to maximize returns but minimize risk? Give a list of etfs to invest in and the percentage of the portfolio to allocate to each etf." +) + +print("Financial Strategy Result:") +print(result) diff --git a/docs/swarms/agents/consistency_agent.md b/docs/swarms/agents/consistency_agent.md index 631af2cb..2f990445 100644 --- a/docs/swarms/agents/consistency_agent.md +++ b/docs/swarms/agents/consistency_agent.md @@ -1,6 +1,5 @@ # Consistency Agent Documentation - The `SelfConsistencyAgent` is a specialized agent designed for generating multiple independent responses to a given task and aggregating them into a single, consistent final answer. It leverages concurrent processing to enhance efficiency and employs a majority voting mechanism to ensure the reliability of the aggregated response. ## Purpose @@ -17,24 +16,31 @@ The primary objective of the `SelfConsistencyAgent` is to provide a robust mecha | Argument | Type | Default | Description | |------------------------|---------|---------|-----------------------------------------------------------------------------| -| `num_samples` | `int` | `5` | Number of independent responses to sample. | -| `return_list` | `bool` | `False` | Whether to return the conversation as a list. | -| `max_loops` | `int` | `1` | Maximum number of loops for the agent to run. | -| `return_dict` | `bool` | `False` | Whether to return the conversation as a dictionary. | -| `return_json` | `bool` | `False` | Whether to return the conversation as JSON. | -| `majority_voting_prompt` | `str` | `None` | Custom prompt for majority voting. | +| `name` | `str` | `"Self-Consistency-Agent"` | Name of the agent. | +| `description` | `str` | `"An agent that uses self consistency to generate a final answer."` | Description of the agent's purpose. | +| `system_prompt` | `str` | `CONSISTENCY_SYSTEM_PROMPT` | System prompt for the reasoning agent. | +| `model_name` | `str` | Required | The underlying language model to use. | +| `num_samples` | `int` | `5` | Number of independent responses to generate. | +| `max_loops` | `int` | `1` | Maximum number of reasoning loops per sample. | +| `majority_voting_prompt` | `Optional[str]` | `majority_voting_prompt` | Custom prompt for majority voting aggregation. | +| `eval` | `bool` | `False` | Enable evaluation mode for answer validation. | +| `output_type` | `OutputType` | `"dict"` | Format of the output. | +| `random_models_on` | `bool` | `False` | Enable random model selection for diversity. | ### Methods - **`run`**: Generates multiple responses for the given task and aggregates them. - **Arguments**: - `task` (`str`): The input prompt. 
- - `answer` (`str`, optional): The expected answer to validate responses against. - - **Returns**: `str` - The aggregated final answer. + - `img` (`Optional[str]`, optional): Image input for vision tasks. + - `answer` (`Optional[str]`, optional): Expected answer for validation (if eval=True). + - **Returns**: `Union[str, Dict[str, Any]]` - The aggregated final answer. -- **`aggregate`**: Aggregates a list of responses into a single final answer using majority voting. +- **`aggregation_agent`**: Aggregates a list of responses into a single final answer using majority voting. - **Arguments**: - `responses` (`List[str]`): The list of responses. + - `prompt` (`str`, optional): Custom prompt for the aggregation agent. + - `model_name` (`str`, optional): Model to use for aggregation. - **Returns**: `str` - The aggregated answer. - **`check_responses_for_answer`**: Checks if a specified answer is present in any of the provided responses. @@ -43,6 +49,11 @@ The primary objective of the `SelfConsistencyAgent` is to provide a robust mecha - `answer` (`str`): The answer to look for in the responses. - **Returns**: `bool` - `True` if the answer is found, `False` otherwise. +- **`batched_run`**: Run the agent on multiple tasks in batch. + - **Arguments**: + - `tasks` (`List[str]`): List of tasks to be processed. + - **Returns**: `List[Union[str, Dict[str, Any]]]` - List of results for each task. + ### Examples #### Example 1: Basic Usage @@ -52,7 +63,7 @@ from swarms.agents.consistency_agent import SelfConsistencyAgent # Initialize the agent agent = SelfConsistencyAgent( - agent_name="Reasoning-Agent", + name="Math-Reasoning-Agent", model_name="gpt-4o-mini", max_loops=1, num_samples=5 @@ -75,7 +86,7 @@ from swarms.agents.consistency_agent import SelfConsistencyAgent # Initialize the agent with a custom majority voting prompt agent = SelfConsistencyAgent( - agent_name="Reasoning-Agent", + name="Reasoning-Agent", model_name="gpt-4o-mini", max_loops=1, num_samples=5, @@ -92,4 +103,128 @@ final_answer = agent.run(task) print("Final aggregated answer:", final_answer) ``` +#### Example 3: Evaluation Mode + +```python +from swarms.agents.consistency_agent import SelfConsistencyAgent + +# Initialize the agent with evaluation mode +agent = SelfConsistencyAgent( + name="Validation-Agent", + model_name="gpt-4o-mini", + num_samples=3, + eval=True +) + +# Run with expected answer for validation +result = agent.run("What is 2 + 2?", answer="4", eval=True) +if result is not None: + print("Validation passed:", result) +else: + print("Validation failed - expected answer not found") +``` + +#### Example 4: Random Models for Diversity + +```python +from swarms.agents.consistency_agent import SelfConsistencyAgent + +# Initialize the agent with random model selection +agent = SelfConsistencyAgent( + name="Diverse-Reasoning-Agent", + model_name="gpt-4o-mini", + num_samples=5, + random_models_on=True +) + +# Run the agent +result = agent.run("What are the benefits of renewable energy?") +print("Diverse reasoning result:", result) +``` + +#### Example 5: Batch Processing + +```python +from swarms.agents.consistency_agent import SelfConsistencyAgent + +# Initialize the agent +agent = SelfConsistencyAgent( + name="Batch-Processing-Agent", + model_name="gpt-4o-mini", + num_samples=3 +) + +# Define multiple tasks +tasks = [ + "What is the capital of France?", + "What is 15 * 23?", + "Explain photosynthesis in simple terms." 
+] + +# Process all tasks +results = agent.batched_run(tasks) + +# Print results +for i, result in enumerate(results): + print(f"Task {i+1} result: {result}") +``` + +## Key Features + +### Self-Consistency Technique +The agent implements the self-consistency approach based on the research paper "Self-Consistency Improves Chain of Thought Reasoning in Language Models" by Wang et al. (2022). This technique: + +1. **Generates Multiple Independent Responses**: Creates several reasoning paths for the same problem +2. **Analyzes Consistency**: Examines agreement among different reasoning approaches +3. **Aggregates Results**: Uses majority voting or consensus building +4. **Produces Reliable Output**: Delivers a final answer reflecting the most reliable consensus + +### Benefits +- **Mitigates Random Errors**: Multiple reasoning paths reduce individual path errors +- **Reduces Bias**: Diverse approaches minimize single-method biases +- **Improves Reliability**: Consensus-based results are more trustworthy +- **Handles Complexity**: Better performance on complex problem-solving tasks + +### Use Cases +- **Mathematical Problem Solving**: Where accuracy is critical +- **Decision Making**: When reliability is paramount +- **Validation Tasks**: When answers need verification +- **Complex Reasoning**: Multi-step problem solving +- **Research Questions**: Where multiple perspectives are valuable + +## Technical Details + +### Concurrent Execution +The agent uses `ThreadPoolExecutor` to generate multiple responses concurrently, improving performance while maintaining independence between reasoning paths. + +### Aggregation Process +The aggregation uses an AI-powered agent that: +- Identifies dominant responses +- Analyzes disparities and disagreements +- Evaluates consensus strength +- Synthesizes minority insights +- Provides comprehensive recommendations + +### Output Formats +The agent supports various output types: +- `"dict"`: Dictionary format with conversation history +- `"str"`: Simple string output +- `"list"`: List format +- `"json"`: JSON formatted output + +## Limitations + +1. **Computational Cost**: Higher `num_samples` increases processing time and cost +2. **Model Dependencies**: Performance depends on the underlying model capabilities +3. **Consensus Challenges**: May struggle with tasks where multiple valid approaches exist +4. **Memory Usage**: Concurrent execution requires more memory resources + +## Best Practices + +1. **Sample Size**: Use 3-7 samples for most tasks; increase for critical decisions +2. **Model Selection**: Choose models with strong reasoning capabilities +3. **Evaluation Mode**: Enable for tasks with known correct answers +4. **Custom Prompts**: Tailor majority voting prompts for specific domains +5. 
**Batch Processing**: Use `batched_run` for multiple related tasks + --- diff --git a/docs/swarms/agents/reasoning_agent_router.md b/docs/swarms/agents/reasoning_agent_router.md index 1415c078..969d323f 100644 --- a/docs/swarms/agents/reasoning_agent_router.md +++ b/docs/swarms/agents/reasoning_agent_router.md @@ -38,9 +38,12 @@ graph TD | `max_loops` | int | 1 | Maximum number of reasoning loops | | `swarm_type` | agent_types | "reasoning_duo" | Type of reasoning swarm to use | | `num_samples` | int | 1 | Number of samples for self-consistency | - | `output_type` | OutputType | "dict" | Format of the output | + | `output_type` | OutputType | "dict-all-except-first" | Format of the output | | `num_knowledge_items` | int | 6 | Number of knowledge items for GKP agent | | `memory_capacity` | int | 6 | Memory capacity for agents that support it | + | `eval` | bool | False | Enable evaluation mode for self-consistency | + | `random_models_on` | bool | False | Enable random model selection for diversity | + | `majority_voting_prompt` | Optional[str] | None | Custom prompt for majority voting | ### Available Agent Types @@ -84,12 +87,16 @@ graph TD - Multiple solution generation - Consensus building - Solution verification + - Concurrent execution + - AI-powered aggregation **Best Use Cases** - Tasks requiring high reliability - Problems with multiple approaches - Validation-heavy tasks + - Mathematical problem solving + - Decision making scenarios **Required Parameters** @@ -98,9 +105,12 @@ graph TD **Optional Parameters** - - num_samples - - max_loops - - output_type + - num_samples (default: 5) + - max_loops (default: 1) + - output_type (default: "dict") + - eval (default: False) - Enable answer validation + - random_models_on (default: False) - Enable model diversity + - majority_voting_prompt (default: None) - Custom aggregation prompt === "IRE" **Key Features** @@ -217,14 +227,43 @@ graph TD system_prompt="You are a helpful assistant that can answer questions and help with tasks.", max_loops=1, swarm_type="self-consistency", - num_samples=1, - output_type="list" + num_samples=3, + eval=False, + random_models_on=False, + majority_voting_prompt=None ) # Run a single task result = router.run("What is the best approach to solve this problem?") ``` +=== "Self-Consistency Examples" + ```python + # Basic self-consistency + router = ReasoningAgentRouter( + swarm_type="self-consistency", + num_samples=3, + model_name="gpt-4o-mini" + ) + + # Self-consistency with evaluation mode + router = ReasoningAgentRouter( + swarm_type="self-consistency", + num_samples=5, + model_name="gpt-4o-mini", + eval=True, + random_models_on=True + ) + + # Self-consistency with custom majority voting + router = ReasoningAgentRouter( + swarm_type="self-consistency", + num_samples=3, + model_name="gpt-4o-mini", + majority_voting_prompt="Analyze the responses and provide the most accurate answer." + ) + ``` + === "ReflexionAgent" ```python router = ReasoningAgentRouter( @@ -265,9 +304,13 @@ graph TD 2. **Performance Optimization** - Adjust max_loops based on task complexity - - Increase num_samples for higher reliability + - Increase num_samples for higher reliability (3-7 for most tasks) - Choose appropriate model_name based on task requirements + + - Enable random_models_on for diverse reasoning approaches + + - Use eval mode for validation tasks with known answers 3. 
**Output Handling** - Use appropriate output_type for your needs @@ -275,6 +318,15 @@ graph TD - Process batched results appropriately - Handle errors gracefully + + 4. **Self-Consistency Specific** + - Use 3-5 samples for most tasks, 7+ for critical decisions + + - Enable eval mode when you have expected answers for validation + + - Customize majority_voting_prompt for domain-specific aggregation + + - Consider random_models_on for diverse model perspectives ## Limitations diff --git a/examples/single_agent/reasoning_agent_examples/reasoning_agent_router.py b/examples/single_agent/reasoning_agent_examples/reasoning_agent_router.py deleted file mode 100644 index 96341179..00000000 --- a/examples/single_agent/reasoning_agent_examples/reasoning_agent_router.py +++ /dev/null @@ -1,46 +0,0 @@ -from swarms.agents.reasoning_agents import ReasoningAgentRouter - -reasoning_agent_router = ReasoningAgentRouter( - agent_name="reasoning-agent", - description="A reasoning agent that can answer questions and help with tasks.", - model_name="gpt-4o-mini", - system_prompt="You are a helpful assistant that can answer questions and help with tasks.", - max_loops=1, - swarm_type="self-consistency", - num_samples=1, - output_type="list", -) - -reasoning_agent_router.run( - "What is the best possible financial strategy to maximize returns but minimize risk? Give a list of etfs to invest in and the percentage of the portfolio to allocate to each etf." -) - - -# reasoning_agent_router.batched_run( -# [ -# "What is the best possible financial strategy to maximize returns but minimize risk? Give a list of etfs to invest in and the percentage of the portfolio to allocate to each etf.", -# "What is the best possible financial strategy to maximize returns but minimize risk? Give a list of etfs to invest in and the percentage of the portfolio to allocate to each etf.", -# ] -# ) - - -# from swarms import ReasoningAgentRouter - - -# calculus_router = ReasoningAgentRouter( -# agent_name="calculus-expert", -# description="A calculus problem solving agent", -# model_name="gpt-4o-mini", -# system_prompt="You are a calculus expert. Solve differentiation and integration problems methodically.", -# swarm_type="self-consistency", -# num_samples=3, # Generate 3 samples to ensure consistency -# output_type="list", -# ) - - -# # Example calculus problem -# calculus_problem = "Find the derivative of f(x) = x³ln(x) - 5x²" - -# # Get the solution -# solution = calculus_router.run(calculus_problem) -# print(solution) diff --git a/reasoning_agent_router.py b/reasoning_agent_router.py new file mode 100644 index 00000000..c1f537e7 --- /dev/null +++ b/reasoning_agent_router.py @@ -0,0 +1,23 @@ +from swarms.agents.reasoning_agents import ReasoningAgentRouter + +# Initialize the reasoning agent router with self-consistency +reasoning_agent_router = ReasoningAgentRouter( + agent_name="reasoning-agent", + description="A reasoning agent that can answer questions and help with tasks.", + model_name="gpt-4o-mini", + system_prompt="You are a helpful assistant that can answer questions and help with tasks.", + max_loops=1, + swarm_type="self-consistency", + num_samples=3, # Generate 3 independent responses + eval=False, # Disable evaluation mode + random_models_on=False, # Disable random model selection + majority_voting_prompt=None, # Use default majority voting prompt +) + +# Run the agent on a financial analysis task +result = reasoning_agent_router.run( + "What is the best possible financial strategy to maximize returns but minimize risk? 
Give a list of etfs to invest in and the percentage of the portfolio to allocate to each etf." +) + +print("Financial Strategy Result:") +print(result) diff --git a/swarms/agents/consistency_agent.py b/swarms/agents/consistency_agent.py index b06db583..855e60e4 100644 --- a/swarms/agents/consistency_agent.py +++ b/swarms/agents/consistency_agent.py @@ -1,22 +1,71 @@ -from collections import Counter +""" +Self-Consistency Agent Implementation + +This module implements the SelfConsistencyAgent, a specialized agent that leverages the +self-consistency technique to improve reasoning reliability and accuracy. The agent generates +multiple independent responses to a given task and aggregates them into a single, consistent +final answer using majority voting and sophisticated aggregation techniques. + +The self-consistency approach is based on the research paper: +"Self-Consistency Improves Chain of Thought Reasoning in Language Models" +by Wang et al. (2022) - https://arxiv.org/abs/2203.07870 + +Key Features: +- Concurrent generation of multiple independent responses +- Majority voting aggregation with detailed analysis +- Evaluation mode for answer validation +- Configurable output formats +- Thread-safe execution + +Author: Swarms Team +License: MIT +""" + from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import List +from typing import List, Optional, Union, Dict, Any from loguru import logger from swarms.structs.agent import Agent from swarms.structs.conversation import Conversation -from swarms.structs.malt import majority_voting_prompt from swarms.utils.output_types import OutputType from swarms.utils.any_to_str import any_to_str from swarms.utils.history_output_formatter import ( history_output_formatter, ) +# System prompt for the reasoning agent that generates individual responses CONSISTENCY_SYSTEM_PROMPT = """ You are a reasoning agent designed for complex problem-solving and decision-making. Your objective is to provide clear and reliable responses through structured reasoning. Begin by thoroughly understanding the problem, rephrasing it for clarity, and identifying key components. Develop a logical plan that breaks the problem into manageable steps, detailing your approach and any assumptions made. Validate your information with reliable sources and assess the accuracy of your calculations. Explore multiple solutions, weighing their pros and cons, and maintain transparency by documenting your reasoning process, uncertainties, and biases. Summarize your findings in a concise final answer that reflects your thorough analysis, ensuring it is well-organized and accessible. Adapt your reasoning to the context of the problem, integrating new information as needed, and implement error-handling strategies to address any issues that arise. Finally, reflect on your reasoning process to identify areas for improvement and ensure consistency across all reasoning paths. """ +# Detailed prompt for the majority voting aggregation agent +majority_voting_prompt = """ +Engage in a comprehensive and exhaustive majority voting analysis of the following conversation, ensuring a deep and thoughtful examination of the responses provided by each agent. This analysis should not only summarize the responses but also critically engage with the content, context, and implications of each agent's input. + +Please adhere to the following detailed guidelines: + +1. **Identification of Dominant Responses:** + - Identify the most prevalent answer or recommendation across all agents. 
Provide a thorough rationale for its dominance, including an exploration of the factors that may have contributed to its acceptance among the agents. Discuss the context in which this consensus emerged and any relevant historical or theoretical frameworks that support this conclusion. + +2. **Exploration of Disparities:** + - Delve into any significant disparities or contrasting viewpoints between agents. Explore the underlying reasons for these differences, considering aspects such as differing methodologies, assumptions, or interpretations of the task at hand. Analyze how these contrasting perspectives may reflect broader debates within the field and what implications they hold for the overall understanding of the topic. + +3. **Consensus and Disagreement Analysis:** + - Highlight key areas of consensus and disagreement among the agents. Discuss the implications of these findings on the overall argument, including how consensus can strengthen certain claims while disagreement may indicate areas of uncertainty or contention. Provide examples from the conversation to illustrate these points and consider how they might influence future discussions or research directions. + +4. **Critical Evaluation of Majority Opinion:** + - Critically evaluate the strength of the majority opinion, considering factors such as the reasoning behind it and its mathematical validity if applicable. Assess whether the majority opinion is well-supported by evidence and logical reasoning, and discuss any potential weaknesses or oversights that may undermine its credibility. + +5. **Insights from Minority Viewpoints:** + - Note any unique insights from minority viewpoints, assessing their potential contributions to a more nuanced understanding of the topic. Discuss how these minority perspectives can enrich the conversation and provide alternative angles that may have been overlooked by the majority. Consider the value of dissent in academic discourse and how it can lead to more robust conclusions. + +6. **Synthesis of Recommendations:** + - Provide a final synthesized recommendation based on the majority consensus, ensuring that it reflects a thorough consideration of all perspectives and is grounded in sound reasoning. This recommendation should not only summarize the majority view but also integrate insights from minority opinions, creating a comprehensive and balanced conclusion that acknowledges the complexity of the discussion. + +Throughout your analysis, focus on uncovering clear patterns while being attentive to the subtleties and complexities inherent in the responses. Pay particular attention to the nuances of mathematical contexts where algorithmic thinking may be required, ensuring that your examination is both rigorous and accessible to a diverse audience. +""" + def aggregation_agent( responses: List[str], @@ -24,7 +73,27 @@ def aggregation_agent( model_name: str = "gpt-4o-mini", ) -> str: """ - Aggregates a list of responses into a single final answer. + Aggregates a list of responses into a single final answer using an AI-powered aggregation agent. + + This function creates a specialized agent that analyzes multiple responses and synthesizes + them into a coherent final answer. The aggregation process considers consensus, disagreements, + and minority viewpoints to produce a well-reasoned conclusion. + + Args: + responses (List[str]): List of responses to be aggregated + prompt (str, optional): Custom prompt for the aggregation agent. + Defaults to the majority_voting_prompt. 
+ model_name (str, optional): Model to use for aggregation. + Defaults to "gpt-4o-mini". + + Returns: + str: The aggregated final answer + + Example: + >>> responses = ["Answer A", "Answer B", "Answer A"] + >>> final_answer = aggregation_agent(responses) + >>> print(final_answer) + "Based on the majority consensus..." """ task = any_to_str(responses) @@ -41,69 +110,174 @@ def aggregation_agent( return final_answer -class SelfConsistencyAgent(Agent): +class SelfConsistencyAgent: + """ + A specialized agent that implements self-consistency for improved reasoning reliability. + + The SelfConsistencyAgent generates multiple independent responses to a given task and + aggregates them into a single, consistent final answer. This approach is based on the + research paper "Self-Consistency Improves Chain of Thought Reasoning in Language Models" + by Wang et al. (2022). + + Key Features: + - Concurrent generation of multiple independent responses + - Majority voting aggregation with detailed analysis + - Evaluation mode for answer validation + - Configurable output formats + - Thread-safe execution + + The self-consistency technique works by: + 1. Generating multiple independent reasoning paths for the same problem + 2. Analyzing the consistency and agreement among these paths + 3. Aggregating the results using majority voting or consensus building + 4. Producing a final answer that reflects the most reliable consensus + + This approach helps mitigate issues like: + - Random errors in individual reasoning paths + - Biases in single reasoning approaches + - Inconsistencies in complex problem-solving + + Reference: + Wang, Y., Dong, W., Han, J., & Wang, W. (2022). Self-Consistency Improves Chain of + Thought Reasoning in Language Models. arXiv preprint arXiv:2203.07870. + https://arxiv.org/abs/2203.07870 + + Example: + >>> agent = SelfConsistencyAgent( + ... name="Math-Reasoning-Agent", + ... model_name="gpt-4o-mini", + ... num_samples=5, + ... max_loops=1 + ... ) + >>> result = agent.run("What is the 40th prime number?") + >>> print(result) + """ + def __init__( self, name: str = "Self-Consistency-Agent", description: str = "An agent that uses self consistency to generate a final answer.", + model_name: str = "gpt-4o-mini", system_prompt: str = CONSISTENCY_SYSTEM_PROMPT, num_samples: int = 5, max_loops: int = 1, - majority_voting_prompt: str = None, + majority_voting_prompt: Optional[ + str + ] = majority_voting_prompt, eval: bool = False, output_type: OutputType = "dict", + random_models_on: bool = False, + *args, **kwargs, ): """ - Initializes the SelfConsistencyAgent. + Initialize the SelfConsistencyAgent. Args: - num_samples (int): Number of independent responses to sample. - **kwargs: Other keyword arguments passed to the base Agent. + name (str, optional): Name of the agent. Defaults to "Self-Consistency-Agent". + description (str, optional): Description of the agent's purpose. + Defaults to "An agent that uses self consistency to generate a final answer.". + model_name (str, optional): The underlying language model to use. + Defaults to "gpt-4o-mini". + system_prompt (str, optional): System prompt for the reasoning agent. + Defaults to CONSISTENCY_SYSTEM_PROMPT. + num_samples (int, optional): Number of independent responses to generate. + Defaults to 5. + max_loops (int, optional): Maximum number of reasoning loops per sample. + Defaults to 1. + majority_voting_prompt (Optional[str], optional): Custom prompt for majority voting. + Defaults to None. 
+ eval (bool, optional): Enable evaluation mode for answer validation. + Defaults to False. + output_type (OutputType, optional): Format of the output. + Defaults to "dict". + random_models_on (bool, optional): Enable random model selection for diversity. + Defaults to False. + **kwargs: Additional keyword arguments passed to the base Agent class. + + Note: + The num_samples parameter determines how many independent reasoning paths + will be generated. Higher values generally lead to more reliable results + but increase computational cost and time. """ - super().__init__( - name=name, - description=description, - **kwargs, - ) + self.name = name + self.description = description + self.model_name = model_name self.num_samples = num_samples - self.conversation = Conversation() self.max_loops = max_loops self.majority_voting_prompt = majority_voting_prompt self.eval = eval self.output_type = output_type self.system_prompt = system_prompt + self.random_models_on = random_models_on + self.conversation = Conversation() + self.args = args + self.kwargs = kwargs def run( - self, task: str, answer: str = None, *args, **kwargs - ) -> str: + self, + task: str, + img: Optional[str] = None, + answer: Optional[str] = None, + *args, + **kwargs, + ) -> Union[str, Dict[str, Any]]: """ - Generates multiple responses for the given prompt and aggregates them concurrently. + Generate multiple responses for the given task and aggregate them concurrently. + + This method implements the core self-consistency algorithm: + 1. Generates multiple independent responses using concurrent execution + 2. Optionally validates responses against a known answer (if eval=True) + 3. Aggregates responses using an AI-powered aggregation agent + 4. Returns the final result in the specified output format Args: - task (str): The input prompt. + task (str): The input prompt or task to be solved + answer (Optional[str], optional): Expected answer for validation (if eval=True). + Defaults to None. + *args: Additional positional arguments passed to the base agent's run method + **kwargs: Additional keyword arguments passed to the base agent's run method Returns: - str: The aggregated final answer. + Union[str, Dict[str, Any]]: The aggregated final answer in the specified format + + Raises: + RuntimeError: If evaluation mode is enabled and the expected answer is not found + in any of the generated responses + + Example: + >>> agent = SelfConsistencyAgent(num_samples=3) + >>> result = agent.run("What is 2 + 2?") + >>> print(result) + + >>> # With evaluation mode + >>> result = agent.run("What is 2 + 2?", answer="4", eval=True) """ responses = [] - logger.info( - f"Generating {self.num_samples} responses concurrently..." 
- ) self.conversation.add(role="User", content=task) + # Generate multiple independent responses concurrently + reasoning_agent = self._create_reasoning_agent() + with ThreadPoolExecutor() as executor: futures = { - executor.submit(super().run, task, *args, **kwargs): i + executor.submit( + reasoning_agent.run, + task=task, + img=img, + *args, + **kwargs, + ): i for i in range(self.num_samples) } for future in as_completed(futures): response = future.result() responses.append(response) - self.conversation.add(role=self.agent_name, content=responses) + self.conversation.add(role=self.name, content=responses) + # Optional evaluation against known answer if self.eval: if answer is not None: correct = self.check_responses_for_answer( @@ -116,9 +290,7 @@ class SelfConsistencyAgent(Agent): ) return None - # Aggregation agent - # final_answer = self.aggregation_agent(responses) - + # Aggregate responses using AI-powered aggregation final_answer = aggregation_agent(responses) self.conversation.add( @@ -129,39 +301,46 @@ class SelfConsistencyAgent(Agent): self.conversation, self.output_type ) - def aggregate(self, responses: List[str]) -> str: + def _create_reasoning_agent(self) -> Agent: """ - Aggregates a list of responses into a single final answer. - - Here we use a simple majority vote (most common answer) as an example. Depending on - the task, you might need a more sophisticated aggregation (e.g., weighting, consensus reasoning, etc.). - - Args: - responses (list of str): The list of responses. + Create a reasoning agent instance for generating individual responses. Returns: - str: The aggregated answer. + Agent: A configured Agent instance for reasoning tasks """ - # Count the frequency of each response. - counts = Counter(responses) - most_common, freq = counts.most_common(1)[0] - logger.info( - f"Aggregation complete. Most common response (appeared {freq} times):" + return Agent( + agent_name=self.name, + description=self.description, + model_name=self.model_name, + system_prompt=self.system_prompt, + max_loops=self.max_loops, + random_models_on=self.random_models_on, + output_type="str-all-except-first", + **self.kwargs, ) - return most_common def check_responses_for_answer( self, responses: List[str], answer: str ) -> bool: """ - Checks if the specified answer is present in any of the provided responses. + Check if the specified answer is present in any of the provided responses. + + This method performs a simple string matching to determine if the expected + answer appears in any of the generated responses. It's useful for validation + and evaluation purposes. Args: - responses (List[str]): A list of responses to check. - answer (str): The answer to look for in the responses. + responses (List[str]): List of responses to check + answer (str): The answer to look for in the responses Returns: - bool: True if the answer is found in any response, False otherwise. + bool: True if the answer is found in any response, False otherwise + + Example: + >>> agent = SelfConsistencyAgent() + >>> responses = ["The answer is 42", "I think it's 42", "Not sure"] + >>> found = agent.check_responses_for_answer(responses, "42") + >>> print(found) # True """ for response in responses: if answer in response: @@ -181,27 +360,30 @@ class SelfConsistencyAgent(Agent): def batched_run( self, tasks: List[str], *args, **kwargs - ) -> List[str]: + ) -> List[Union[str, Dict[str, Any]]]: """ - Runs the agent in a batched manner. + Run the agent on multiple tasks in batch. 
+ + This method processes multiple tasks sequentially, applying the self-consistency + approach to each task independently. It's useful for processing large datasets + or multiple related problems. + + Args: + tasks (List[str]): List of tasks to be processed + *args: Additional positional arguments passed to the run method + **kwargs: Additional keyword arguments passed to the run method + + Returns: + List[Union[str, Dict[str, Any]]]: List of results for each task + + Example: + >>> agent = SelfConsistencyAgent() + >>> tasks = ["What is 2+2?", "What is 3+3?", "What is 4+4?"] + >>> results = agent.batched_run(tasks) + >>> print(len(results)) # 3 """ responses = [] for task in tasks: response = self.run(task, *args, **kwargs) responses.append(response) return responses - - -# # Example usage: -# if __name__ == "__main__": -# agent = SelfConsistencyAgent( -# agent_name="Reasoning-Agent", -# model_name="gpt-4o-mini", -# max_loops=1, -# num_samples=5, # Number of samples for self consistency -# ) - -# prompt = "What is the 40th prime number?" -# final_answer = agent.run(prompt) -# print("\nFinal aggregated answer:") -# print(final_answer) diff --git a/swarms/agents/reasoning_agents.py b/swarms/agents/reasoning_agents.py index e8087dbc..da9760e6 100644 --- a/swarms/agents/reasoning_agents.py +++ b/swarms/agents/reasoning_agents.py @@ -1,3 +1,41 @@ +""" +ReasoningAgentRouter: A flexible router for advanced reasoning agent swarms. + +This module provides the ReasoningAgentRouter class, which enables dynamic selection and instantiation +of various advanced reasoning agent types (swarms) for complex problem-solving tasks. It supports +multiple reasoning strategies, including self-consistency, collaborative duo agents, iterative +reflection, knowledge prompting, and agent judging. + +Key Features: +- Unified interface for multiple agent types (see `agent_types`) +- Caching of agent instances for efficiency and memory management +- Extensible factory-based architecture for easy addition of new agent types +- Batch and single-task execution +- Customizable agent configuration (model, prompt, memory, etc.) + +Supported Agent Types: + - "reasoning-duo" / "reasoning-agent": Dual collaborative agent system + - "self-consistency" / "consistency-agent": Multiple independent solutions with consensus + - "ire" / "ire-agent": Iterative Reflective Expansion agent + - "ReflexionAgent": Reflexion agent with memory + - "GKPAgent": Generated Knowledge Prompting agent + - "AgentJudge": Agent judge for evaluation/critique + +Example usage: + >>> router = ReasoningAgentRouter(swarm_type="self-consistency", num_samples=3) + >>> result = router.run("What is the capital of France?") + >>> print(result) + + >>> # Batch mode + >>> results = router.batched_run(["2+2?", "3+3?"]) + >>> print(results) + +See also: + - docs/swarms/agents/reasoning_agent_router.md for detailed documentation and architecture diagrams. + - consistency_example.py for a usage example with SelfConsistencyAgent. 
+ +""" + from typing import ( List, Literal, @@ -6,9 +44,9 @@ from typing import ( Any, Tuple, Hashable, + Optional, ) - from swarms.agents.consistency_agent import SelfConsistencyAgent from swarms.agents.flexion_agent import ReflexionAgent from swarms.agents.gkp_agent import GKPAgent @@ -19,7 +57,7 @@ from swarms.agents.reasoning_duo import ReasoningDuo from swarms.utils.output_types import OutputType from swarms.agents.agent_judge import AgentJudge - +#: Supported agent type literals for ReasoningAgentRouter agent_types = Literal[ "reasoning-duo", "self-consistency", @@ -35,18 +73,30 @@ agent_types = Literal[ class ReasoningAgentRouter: """ - A Reasoning Agent that can answer questions and assist with various tasks using different reasoning strategies. - - - Attributes: - agent_name (str): The name of the agent. - description (str): A brief description of the agent's capabilities. - model_name (str): The name of the model used for reasoning. - system_prompt (str): The prompt that guides the agent's reasoning process. - max_loops (int): The maximum number of loops for the reasoning process. - swarm_type (agent_types): The type of reasoning swarm to use (e.g., reasoning duo, self-consistency, IRE). - num_samples (int): The number of samples to generate for self-consistency agents. - output_type (OutputType): The format of the output (e.g., dict, list). + A router for advanced reasoning agent swarms. + + The ReasoningAgentRouter enables dynamic selection, instantiation, and caching of various + reasoning agent types ("swarms") for flexible, robust, and scalable problem-solving. + + Args: + agent_name (str): Name identifier for the agent instance. + description (str): Description of the agent's capabilities. + model_name (str): The underlying language model to use. + system_prompt (str): System prompt for the agent. + max_loops (int): Maximum number of reasoning loops. + swarm_type (agent_types): Type of reasoning swarm to use. + num_samples (int): Number of samples for self-consistency or iterations. + output_type (OutputType): Format of the output. + num_knowledge_items (int): Number of knowledge items for GKP agent. + memory_capacity (int): Memory capacity for agents that support it. + eval (bool): Enable evaluation mode for self-consistency. + random_models_on (bool): Enable random model selection for diversity. + majority_voting_prompt (Optional[str]): Custom prompt for majority voting. + + Example: + >>> router = ReasoningAgentRouter(swarm_type="reasoning-duo") + >>> result = router.run("Explain quantum entanglement.") + >>> print(result) """ # Class variable to store cached agent instances @@ -59,12 +109,20 @@ class ReasoningAgentRouter: model_name: str = "gpt-4o-mini", system_prompt: str = "You are a helpful assistant that can answer questions and help with tasks.", max_loops: int = 1, - swarm_type: agent_types = "reasoning_duo", + swarm_type: agent_types = "reasoning-duo", num_samples: int = 1, - output_type: OutputType = "dict", + output_type: OutputType = "dict-all-except-first", num_knowledge_items: int = 6, memory_capacity: int = 6, + eval: bool = False, + random_models_on: bool = False, + majority_voting_prompt: Optional[str] = None, ): + """ + Initialize the ReasoningAgentRouter with the specified configuration. + + See class docstring for parameter details. 
+ """ self.agent_name = agent_name self.description = description self.model_name = model_name @@ -75,14 +133,17 @@ class ReasoningAgentRouter: self.output_type = output_type self.num_knowledge_items = num_knowledge_items self.memory_capacity = memory_capacity + self.eval = eval + self.random_models_on = random_models_on + self.majority_voting_prompt = majority_voting_prompt - # Added: Initialize the factory mapping dictionary - + # Initialize the factory mapping dictionary self._initialize_agent_factories() def _initialize_agent_factories(self) -> None: """ Initialize the agent factory mapping dictionary, mapping various agent types to their respective creation functions. + This method replaces the original if-elif chain, making the code more maintainable and extensible. """ self.agent_factories: Dict[str, Callable[[], Any]] = { @@ -104,11 +165,11 @@ class ReasoningAgentRouter: def _get_cache_key(self) -> Tuple[Hashable, ...]: """ Generate a unique key for cache lookup. - The key is based on all relevant configuration parameters of the agent. + The key is based on all relevant configuration parameters of the agent. Returns: - Tuple[Hashable, ...]: A hashable tuple to serve as the cache key + Tuple[Hashable, ...]: A hashable tuple to serve as the cache key. """ return ( self.swarm_type, @@ -121,10 +182,18 @@ class ReasoningAgentRouter: self.output_type, self.num_knowledge_items, self.memory_capacity, + self.eval, + self.random_models_on, + self.majority_voting_prompt, ) def _create_reasoning_duo(self): - """Create an agent instance for the ReasoningDuo type""" + """ + Create an agent instance for the ReasoningDuo type. + + Returns: + ReasoningDuo: An instance of the ReasoningDuo agent. + """ return ReasoningDuo( agent_name=self.agent_name, agent_description=self.description, @@ -134,19 +203,32 @@ class ReasoningAgentRouter: ) def _create_consistency_agent(self): - """Create an agent instance for the SelfConsistencyAgent type""" + """ + Create an agent instance for the SelfConsistencyAgent type. + + Returns: + SelfConsistencyAgent: An instance of the SelfConsistencyAgent. + """ return SelfConsistencyAgent( - agent_name=self.agent_name, + name=self.agent_name, description=self.description, model_name=self.model_name, system_prompt=self.system_prompt, max_loops=self.max_loops, num_samples=self.num_samples, output_type=self.output_type, + eval=self.eval, + random_models_on=self.random_models_on, + majority_voting_prompt=self.majority_voting_prompt, ) def _create_ire_agent(self): - """Create an agent instance for the IREAgent type""" + """ + Create an agent instance for the IREAgent type. + + Returns: + IREAgent: An instance of the IterativeReflectiveExpansion agent. + """ return IREAgent( agent_name=self.agent_name, description=self.description, @@ -158,7 +240,12 @@ class ReasoningAgentRouter: ) def _create_agent_judge(self): - """Create an agent instance for the AgentJudge type""" + """ + Create an agent instance for the AgentJudge type. + + Returns: + AgentJudge: An instance of the AgentJudge agent. + """ return AgentJudge( agent_name=self.agent_name, model_name=self.model_name, @@ -167,16 +254,27 @@ class ReasoningAgentRouter: ) def _create_reflexion_agent(self): - """Create an agent instance for the ReflexionAgent type""" + """ + Create an agent instance for the ReflexionAgent type. + + Returns: + ReflexionAgent: An instance of the ReflexionAgent. 
+ """ return ReflexionAgent( agent_name=self.agent_name, system_prompt=self.system_prompt, model_name=self.model_name, max_loops=self.max_loops, + memory_capacity=self.memory_capacity, ) def _create_gkp_agent(self): - """Create an agent instance for the GKPAgent type""" + """ + Create an agent instance for the GKPAgent type. + + Returns: + GKPAgent: An instance of the GKPAgent. + """ return GKPAgent( agent_name=self.agent_name, model_name=self.model_name, @@ -186,13 +284,15 @@ class ReasoningAgentRouter: def select_swarm(self): """ Select and initialize the appropriate reasoning swarm based on the specified swarm type. - Uses a caching mechanism to return a cached instance if an agent with the same configuration already exists. + Uses a caching mechanism to return a cached instance if an agent with the same configuration already exists. Returns: The selected reasoning swarm instance. - """ + Raises: + ValueError: If the specified swarm type is invalid. + """ # Generate cache key cache_key = self._get_cache_key() @@ -216,25 +316,25 @@ class ReasoningAgentRouter: """ Execute the reasoning process of the selected swarm on a given task. - Args: task (str): The task or question to be processed by the reasoning agent. - + *args: Additional positional arguments for the agent's run method. + **kwargs: Additional keyword arguments for the agent's run method. Returns: - The result of the reasoning process. + The result of the reasoning process (format depends on agent and output_type). """ swarm = self.select_swarm() - return swarm.run(task=task) + return swarm.run(task=task, *args, **kwargs) def batched_run(self, tasks: List[str], *args, **kwargs): """ Execute the reasoning process on a batch of tasks. - Args: tasks (List[str]): The list of tasks to process. - + *args: Additional positional arguments for the agent's run method. + **kwargs: Additional keyword arguments for the agent's run method. Returns: A list of reasoning process results for each task. @@ -248,6 +348,7 @@ class ReasoningAgentRouter: def clear_cache(cls): """ Clear the agent instance cache. + Use this when you need to free memory or force the creation of new instances. 
""" cls._agent_cache.clear() diff --git a/swarms/structs/conversation.py b/swarms/structs/conversation.py index 82493f38..45371e71 100644 --- a/swarms/structs/conversation.py +++ b/swarms/structs/conversation.py @@ -6,7 +6,6 @@ import threading import uuid from typing import ( TYPE_CHECKING, - Callable, Dict, List, Optional, @@ -190,18 +189,16 @@ class Conversation(BaseStructure): save_enabled: bool = False, # New parameter to control if saving is enabled save_filepath: str = None, load_filepath: str = None, # New parameter to specify which file to load from - tokenizer: Callable = None, context_length: int = 8192, rules: str = None, custom_rules_prompt: str = None, - user: str = "User:", + user: str = "User", save_as_yaml: bool = False, save_as_json_bool: bool = False, - token_count: bool = True, + token_count: bool = False, message_id_on: bool = False, provider: providers = "in-memory", backend: Optional[str] = None, - # Backend-specific parameters supabase_url: Optional[str] = None, supabase_key: Optional[str] = None, redis_host: str = "localhost", @@ -210,7 +207,6 @@ class Conversation(BaseStructure): redis_password: Optional[str] = None, db_path: Optional[str] = None, table_name: str = "conversations", - # Additional backend parameters use_embedded_redis: bool = True, persist_redis: bool = True, auto_persist: bool = True, @@ -230,20 +226,7 @@ class Conversation(BaseStructure): self.save_enabled = save_enabled self.conversations_dir = conversations_dir self.message_id_on = message_id_on - - # Handle save filepath - if save_enabled and save_filepath: - self.save_filepath = save_filepath - elif save_enabled and conversations_dir: - self.save_filepath = os.path.join( - conversations_dir, f"{self.id}.json" - ) - else: - self.save_filepath = None - self.load_filepath = load_filepath - self.conversation_history = [] - self.tokenizer = tokenizer self.context_length = context_length self.rules = rules self.custom_rules_prompt = custom_rules_prompt @@ -253,9 +236,40 @@ class Conversation(BaseStructure): self.token_count = token_count self.provider = provider # Keep for backwards compatibility self.conversations_dir = conversations_dir + self.backend = backend + self.supabase_url = supabase_url + self.supabase_key = supabase_key + self.redis_host = redis_host + self.redis_port = redis_port + self.redis_db = redis_db + self.redis_password = redis_password + self.db_path = db_path + self.table_name = table_name + self.use_embedded_redis = use_embedded_redis + self.persist_redis = persist_redis + self.auto_persist = auto_persist + self.redis_data_dir = redis_data_dir + + self.conversation_history = [] + + # Handle save filepath + if save_enabled and save_filepath: + self.save_filepath = save_filepath + elif save_enabled and conversations_dir: + self.save_filepath = os.path.join( + conversations_dir, f"{self.id}.json" + ) + else: + self.save_filepath = None # Support both 'provider' and 'backend' parameters for backwards compatibility # 'backend' takes precedence if both are provided + + self.backend_setup(backend, provider) + + def backend_setup( + self, backend: str = None, provider: str = None + ): self.backend = backend or provider self.backend_instance = None @@ -285,19 +299,18 @@ class Conversation(BaseStructure): ]: try: self._initialize_backend( - supabase_url=supabase_url, - supabase_key=supabase_key, - redis_host=redis_host, - redis_port=redis_port, - redis_db=redis_db, - redis_password=redis_password, - db_path=db_path, - table_name=table_name, - use_embedded_redis=use_embedded_redis, - 
persist_redis=persist_redis, - auto_persist=auto_persist, - redis_data_dir=redis_data_dir, - **kwargs, + supabase_url=self.supabase_url, + supabase_key=self.supabase_key, + redis_host=self.redis_host, + redis_port=self.redis_port, + redis_db=self.redis_db, + redis_password=self.redis_password, + db_path=self.db_path, + table_name=self.table_name, + use_embedded_redis=self.use_embedded_redis, + persist_redis=self.persist_redis, + auto_persist=self.auto_persist, + redis_data_dir=self.redis_data_dir, ) except Exception as e: logger.warning( @@ -324,7 +337,6 @@ class Conversation(BaseStructure): "time_enabled": self.time_enabled, "autosave": self.autosave, "save_filepath": self.save_filepath, - "tokenizer": self.tokenizer, "context_length": self.context_length, "rules": self.rules, "custom_rules_prompt": self.custom_rules_prompt, @@ -449,8 +461,8 @@ class Conversation(BaseStructure): if self.custom_rules_prompt is not None: self.add(self.user or "User", self.custom_rules_prompt) - if self.tokenizer is not None: - self.truncate_memory_with_tokenizer() + # if self.tokenizer is not None: + # self.truncate_memory_with_tokenizer() def _autosave(self): """Automatically save the conversation if autosave is enabled.""" @@ -1051,9 +1063,7 @@ class Conversation(BaseStructure): for message in self.conversation_history: role = message.get("role") content = message.get("content") - tokens = self.tokenizer.count_tokens( - text=content - ) # Count the number of tokens + tokens = count_tokens(content) count = tokens # Assign the token count total_tokens += count diff --git a/test_llm.py b/test_llm.py new file mode 100644 index 00000000..3ebd8a9d --- /dev/null +++ b/test_llm.py @@ -0,0 +1,624 @@ +""" +Sparse Mixture-of-Experts (MoE) Transformer Implementation +Based on Gemini 2.5 architecture description + +This implementation provides a sparse MoE architecture that activates only a subset +of expert parameters per input token, allowing for decoupling of model capacity +from computation cost. +""" + +from typing import Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from loguru import logger +from torch import Tensor + + +class Expert(nn.Module): + """ + Individual expert network in the MoE architecture. + + Each expert is a feed-forward network that specializes in processing + certain types of input patterns. 
+ + Args: + hidden_dim: Hidden dimension size + intermediate_dim: Intermediate dimension in feed-forward network + dropout: Dropout probability + activation: Activation function to use + """ + + def __init__( + self, + hidden_dim: int, + intermediate_dim: int, + dropout: float = 0.1, + activation: str = "swish", + ): + super().__init__() + self.hidden_dim = hidden_dim + self.intermediate_dim = intermediate_dim + + # Feed-forward network + self.w1 = nn.Linear(hidden_dim, intermediate_dim, bias=False) + self.w2 = nn.Linear(intermediate_dim, hidden_dim, bias=False) + self.dropout = nn.Dropout(dropout) + + # Activation function + if activation == "swish": + self.activation = lambda x: x * torch.sigmoid(x) + elif activation == "gelu": + self.activation = F.gelu + elif activation == "relu": + self.activation = F.relu + else: + raise ValueError(f"Unsupported activation: {activation}") + + self._init_weights() + + def _init_weights(self) -> None: + """Initialize weights with proper scaling.""" + nn.init.xavier_uniform_(self.w1.weight) + nn.init.xavier_uniform_(self.w2.weight) + + def forward(self, x: Tensor) -> Tensor: + """ + Forward pass through the expert network. + + Args: + x: Input tensor of shape [batch_size, seq_len, hidden_dim] + + Returns: + Output tensor of shape [batch_size, seq_len, hidden_dim] + """ + x = self.w1(x) + x = self.activation(x) + x = self.dropout(x) + x = self.w2(x) + return x + + +class Router(nn.Module): + """ + Gating network that routes tokens to appropriate experts. + + The router learns to assign input tokens to the most suitable experts + based on the token representations. + + Args: + hidden_dim: Hidden dimension size + num_experts: Number of experts in the MoE layer + top_k: Number of experts to activate per token + temperature: Temperature for softmax routing + """ + + def __init__( + self, + hidden_dim: int, + num_experts: int, + top_k: int = 2, + temperature: float = 1.0, + ): + super().__init__() + self.hidden_dim = hidden_dim + self.num_experts = num_experts + self.top_k = top_k + self.temperature = temperature + + # Linear layer for routing scores + self.gate = nn.Linear(hidden_dim, num_experts, bias=False) + self._init_weights() + + def _init_weights(self) -> None: + """Initialize routing weights.""" + nn.init.xavier_uniform_(self.gate.weight) + + def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]: + """ + Route tokens to experts. + + Args: + x: Input tensor of shape [batch_size, seq_len, hidden_dim] + + Returns: + Tuple of (routing_weights, expert_indices, routing_probs) + - routing_weights: [batch_size, seq_len, top_k] + - expert_indices: [batch_size, seq_len, top_k] + - routing_probs: [batch_size, seq_len, num_experts] + """ + batch_size, seq_len, hidden_dim = x.shape + + # Compute routing scores + routing_logits = self.gate( + x + ) # [batch_size, seq_len, num_experts] + routing_logits = routing_logits / self.temperature + + # Apply softmax to get probabilities + routing_probs = F.softmax(routing_logits, dim=-1) + + # Select top-k experts + routing_weights, expert_indices = torch.topk( + routing_probs, self.top_k, dim=-1 + ) + + # Normalize routing weights + routing_weights = routing_weights / routing_weights.sum( + dim=-1, keepdim=True + ) + + return routing_weights, expert_indices, routing_probs + + +class MoELayer(nn.Module): + """ + Sparse Mixture-of-Experts layer. + + This layer contains multiple expert networks and a router that decides + which experts to activate for each input token. 
+ + Args: + hidden_dim: Hidden dimension size + num_experts: Number of expert networks + top_k: Number of experts to activate per token + intermediate_dim: Intermediate dimension in expert networks + dropout: Dropout probability + activation: Activation function for experts + load_balance_weight: Weight for load balancing loss + """ + + def __init__( + self, + hidden_dim: int, + num_experts: int, + top_k: int = 2, + intermediate_dim: Optional[int] = None, + dropout: float = 0.1, + activation: str = "swish", + load_balance_weight: float = 0.01, + ): + super().__init__() + self.hidden_dim = hidden_dim + self.num_experts = num_experts + self.top_k = top_k + self.load_balance_weight = load_balance_weight + + if intermediate_dim is None: + intermediate_dim = hidden_dim * 4 + + # Create expert networks + self.experts = nn.ModuleList( + [ + Expert( + hidden_dim, intermediate_dim, dropout, activation + ) + for _ in range(num_experts) + ] + ) + + # Router for expert selection + self.router = Router(hidden_dim, num_experts, top_k) + + logger.info( + f"Created MoE layer with {num_experts} experts, top_k={top_k}" + ) + + def forward(self, x: Tensor) -> Tuple[Tensor, Dict[str, Tensor]]: + """ + Forward pass through MoE layer. + + Args: + x: Input tensor of shape [batch_size, seq_len, hidden_dim] + + Returns: + Tuple of (output, aux_losses) + - output: [batch_size, seq_len, hidden_dim] + - aux_losses: Dictionary containing auxiliary losses + """ + batch_size, seq_len, hidden_dim = x.shape + + # Get routing decisions + routing_weights, expert_indices, routing_probs = self.router( + x + ) + + # Initialize output + output = torch.zeros_like(x) + + # Process each expert + for i in range(self.num_experts): + # Create mask for tokens routed to this expert + expert_mask = (expert_indices == i).any( + dim=-1 + ) # [batch_size, seq_len] + + if not expert_mask.any(): + continue + + # Get tokens for this expert + expert_tokens = x[expert_mask] # [num_tokens, hidden_dim] + + if expert_tokens.numel() == 0: + continue + + # Process through expert + expert_output = self.experts[i](expert_tokens) + + # Compute weights for this expert + expert_weights = torch.zeros( + batch_size, seq_len, device=x.device + ) + for k in range(self.top_k): + mask = expert_indices[:, :, k] == i + expert_weights[mask] = routing_weights[:, :, k][mask] + + # Add weighted expert output + expert_contribution = torch.zeros_like(x) + expert_contribution[expert_mask] = expert_output + output += expert_contribution * expert_weights.unsqueeze( + -1 + ) + + # Compute auxiliary losses + aux_losses = self._compute_aux_losses( + routing_probs, expert_indices + ) + + return output, aux_losses + + def _compute_aux_losses( + self, routing_probs: Tensor, expert_indices: Tensor + ) -> Dict[str, Tensor]: + """ + Compute auxiliary losses for training stability. 
+ + Args: + routing_probs: Routing probabilities [batch_size, seq_len, num_experts] + expert_indices: Selected expert indices [batch_size, seq_len, top_k] + + Returns: + Dictionary of auxiliary losses + """ + batch_size, seq_len, num_experts = routing_probs.shape + + # Load balancing loss + expert_usage = torch.zeros( + num_experts, device=routing_probs.device + ) + total_tokens = batch_size * seq_len * self.top_k + + for i in range(num_experts): + expert_usage[i] = ( + expert_indices == i + ).sum().float() / total_tokens + + target_usage = 1.0 / num_experts + load_balance_loss = F.mse_loss( + expert_usage, torch.full_like(expert_usage, target_usage) + ) + + # Entropy loss to encourage diversity + entropy_loss = ( + -(routing_probs * torch.log(routing_probs + 1e-8)) + .sum(dim=-1) + .mean() + ) + + return { + "load_balance_loss": load_balance_loss + * self.load_balance_weight, + "entropy_loss": entropy_loss * 0.01, + "expert_usage": expert_usage, + } + + +class MoETransformerBlock(nn.Module): + """ + Transformer block with MoE feed-forward layer. + + This block combines multi-head attention with a sparse MoE layer, + following the standard transformer architecture pattern. + + Args: + hidden_dim: Hidden dimension size + num_heads: Number of attention heads + num_experts: Number of experts in MoE layer + top_k: Number of experts to activate per token + dropout: Dropout probability + layer_norm_eps: Epsilon for layer normalization + """ + + def __init__( + self, + hidden_dim: int, + num_heads: int, + num_experts: int, + top_k: int = 2, + dropout: float = 0.1, + layer_norm_eps: float = 1e-6, + ): + super().__init__() + self.hidden_dim = hidden_dim + + # Multi-head attention + self.attention = nn.MultiheadAttention( + hidden_dim, num_heads, dropout=dropout, batch_first=True + ) + + # MoE layer + self.moe_layer = MoELayer( + hidden_dim=hidden_dim, + num_experts=num_experts, + top_k=top_k, + dropout=dropout, + ) + + # Layer normalization + self.norm1 = nn.LayerNorm(hidden_dim, eps=layer_norm_eps) + self.norm2 = nn.LayerNorm(hidden_dim, eps=layer_norm_eps) + + # Dropout + self.dropout = nn.Dropout(dropout) + + def forward( + self, x: Tensor, attention_mask: Optional[Tensor] = None + ) -> Tuple[Tensor, Dict[str, Tensor]]: + """ + Forward pass through transformer block. + + Args: + x: Input tensor [batch_size, seq_len, hidden_dim] + attention_mask: Optional attention mask + + Returns: + Tuple of (output, aux_losses) + """ + # Self-attention with residual connection + residual = x + x = self.norm1(x) + attn_output, _ = self.attention( + x, x, x, key_padding_mask=attention_mask + ) + x = residual + self.dropout(attn_output) + + # MoE layer with residual connection + residual = x + x = self.norm2(x) + moe_output, aux_losses = self.moe_layer(x) + x = residual + self.dropout(moe_output) + + return x, aux_losses + + +class MoETransformer(nn.Module): + """ + Complete sparse MoE Transformer model. + + This model implements the full transformer architecture with sparse + mixture-of-experts layers, similar to the Gemini 2.5 architecture. 
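+
+    A small end-to-end sketch (illustrative; toy sizes, assuming the imports
+    at the top of this module):
+
+        model = MoETransformer(
+            vocab_size=1000, hidden_dim=64, num_layers=2, num_heads=4,
+            num_experts=4, top_k=2, max_seq_len=128,
+        )
+        input_ids = torch.randint(0, 1000, (2, 16))  # [batch_size, seq_len]
+        logits, aux_losses = model(input_ids)        # logits: [2, 16, 1000]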
+ + Args: + vocab_size: Vocabulary size + hidden_dim: Hidden dimension size + num_layers: Number of transformer layers + num_heads: Number of attention heads + num_experts: Number of experts per MoE layer + top_k: Number of experts to activate per token + max_seq_len: Maximum sequence length + dropout: Dropout probability + """ + + def __init__( + self, + vocab_size: int, + hidden_dim: int, + num_layers: int, + num_heads: int, + num_experts: int, + top_k: int = 2, + max_seq_len: int = 2048, + dropout: float = 0.1, + ): + super().__init__() + self.vocab_size = vocab_size + self.hidden_dim = hidden_dim + self.num_layers = num_layers + self.max_seq_len = max_seq_len + + # Token embedding + self.token_embedding = nn.Embedding(vocab_size, hidden_dim) + + # Positional encoding + self.pos_embedding = nn.Parameter( + torch.randn(1, max_seq_len, hidden_dim) * 0.02 + ) + + # Transformer layers + self.layers = nn.ModuleList( + [ + MoETransformerBlock( + hidden_dim=hidden_dim, + num_heads=num_heads, + num_experts=num_experts, + top_k=top_k, + dropout=dropout, + ) + for _ in range(num_layers) + ] + ) + + # Final layer norm + self.final_norm = nn.LayerNorm(hidden_dim) + + # Output projection + self.output_projection = nn.Linear( + hidden_dim, vocab_size, bias=False + ) + + # Tie input and output embeddings + self.output_projection.weight = self.token_embedding.weight + + self._init_weights() + + logger.info( + f"Created MoE Transformer with {num_layers} layers, " + f"{num_experts} experts per layer, hidden_dim={hidden_dim}" + ) + + def _init_weights(self) -> None: + """Initialize model weights.""" + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.pos_embedding, std=0.02) + + # Initialize output projection + nn.init.normal_(self.output_projection.weight, std=0.02) + + def forward( + self, + input_ids: Tensor, + attention_mask: Optional[Tensor] = None, + return_aux_losses: bool = True, + ) -> Union[Tensor, Tuple[Tensor, Dict[str, Tensor]]]: + """ + Forward pass through the model. 
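+
+        Illustrative call patterns (assuming a `model` and `input_ids` built
+        as in the class-level sketch above):
+
+            logits, aux_losses = model(input_ids)               # training
+            logits = model(input_ids, return_aux_losses=False)  # logits only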
+
+        Args:
+            input_ids: Input token IDs [batch_size, seq_len]
+            attention_mask: Optional attention mask [batch_size, seq_len]
+            return_aux_losses: Whether to return auxiliary losses
+
+        Returns:
+            If return_aux_losses=False: logits [batch_size, seq_len, vocab_size]
+            If return_aux_losses=True: (logits, aux_losses)
+        """
+        batch_size, seq_len = input_ids.shape
+
+        # Token embeddings
+        x = self.token_embedding(input_ids)
+
+        # Add positional encoding
+        x = x + self.pos_embedding[:, :seq_len, :]
+
+        # Collect auxiliary losses
+        all_aux_losses = {}
+
+        # Pass through transformer layers
+        for layer in self.layers:
+            x, aux_losses = layer(x, attention_mask)
+
+            if return_aux_losses:
+                for key, value in aux_losses.items():
+                    if key not in all_aux_losses:
+                        all_aux_losses[key] = []
+                    all_aux_losses[key].append(value)
+
+        # Final layer norm
+        x = self.final_norm(x)
+
+        # Output projection
+        logits = self.output_projection(x)
+
+        if not return_aux_losses:
+            return logits
+
+        # Average auxiliary losses across layers
+        avg_aux_losses = {}
+        for key, values in all_aux_losses.items():
+            if key == "expert_usage":
+                # For expert usage, we want to see all layers
+                avg_aux_losses[key] = torch.stack(values)
+            else:
+                avg_aux_losses[key] = torch.stack(values).mean()
+
+        return logits, avg_aux_losses
+
+    def get_num_parameters(self) -> int:
+        """Get total number of parameters."""
+        return sum(p.numel() for p in self.parameters())
+
+    def get_num_active_parameters(self) -> int:
+        """Get the approximate number of parameters active per token."""
+        # Approximate - actual usage depends on routing; per token, only
+        # top_k experts are active in each MoE layer
+        total_params = self.get_num_parameters()
+
+        # Active expert parameters: top_k experts per MoE layer, summed
+        # over all layers
+        active_expert_params = 0
+        for layer in self.layers:
+            expert_params = sum(
+                p.numel()
+                for p in layer.moe_layer.experts[0].parameters()
+            )
+            active_expert_params += (
+                expert_params * layer.moe_layer.top_k
+            )
+
+        total_expert_params = sum(
+            sum(
+                p.numel()
+                for expert in layer.moe_layer.experts
+                for p in expert.parameters()
+            )
+            for layer in self.layers
+        )
+
+        # Replace the full expert parameter count with only the experts
+        # active per token
+        active_params = (
+            total_params - total_expert_params + active_expert_params
+        )
+        return active_params
+
+
+# Example usage and testing
+if __name__ == "__main__":
+    # Configure logger
+    logger.add("moe_training.log", rotation="500 MB", level="INFO")
+
+    # Model configuration
+    config = {
+        "vocab_size": 32000,
+        "hidden_dim": 768,
+        "num_layers": 12,
+        "num_heads": 12,
+        "num_experts": 8,
+        "top_k": 2,
+        "max_seq_len": 2048,
+        "dropout": 0.1,
+    }
+
+    # Create model
+    model = MoETransformer(**config)
+
+    # Print model info
+    total_params = model.get_num_parameters()
+    active_params = model.get_num_active_parameters()
+
+    logger.info(f"Total parameters: {total_params:,}")
+    logger.info(
+        f"Active parameters per token: {active_params:,}"
+    )
+    logger.info(
+        f"Parameter efficiency: {active_params/total_params:.2%}"
+    )
+
+    # Test forward pass
+    batch_size, seq_len = 2, 512
+    input_ids = torch.randint(
+        0, config["vocab_size"], (batch_size, seq_len)
+    )
+
+    with torch.no_grad():
+        logits, aux_losses = model(input_ids)
+
+    logger.info(f"Input shape: {input_ids.shape}")
+    logger.info(f"Output shape: {logits.shape}")
+    logger.info(f"Auxiliary losses: {list(aux_losses.keys())}")
+
+    # Print expert usage statistics
+    expert_usage = aux_losses[
+        "expert_usage"
+    ]  # [num_layers, num_experts]
+    logger.info(f"Expert usage shape: {expert_usage.shape}")
+    logger.info(f"Average expert usage: {expert_usage.mean(dim=0)}")
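+
+
+# Illustrative training-step sketch (not part of the model above). It shows
+# one way the auxiliary losses returned by MoETransformer.forward() could be
+# folded into a language-modeling objective; the function name and argument
+# shapes are hypothetical. Note that the load-balance term is computed from
+# hard expert counts and therefore carries no gradient (it is mostly useful
+# for monitoring), while adding the entropy term as returned penalizes high
+# routing entropy, i.e. it sharpens the routing distribution.
+def example_training_step(
+    model: MoETransformer,
+    input_ids: Tensor,  # [batch_size, seq_len]
+    labels: Tensor,  # [batch_size, seq_len]
+    optimizer: torch.optim.Optimizer,
+) -> Tensor:
+    """Run one hypothetical optimization step and return the total loss."""
+    logits, aux_losses = model(input_ids)
+
+    # Next-token cross-entropy over the vocabulary
+    lm_loss = F.cross_entropy(
+        logits.view(-1, logits.size(-1)), labels.view(-1)
+    )
+
+    # Auxiliary terms are already scaled inside the MoE layers
+    total_loss = (
+        lm_loss
+        + aux_losses["load_balance_loss"]
+        + aux_losses["entropy_loss"]
+    )
+
+    optimizer.zero_grad()
+    total_loss.backward()
+    optimizer.step()
+    return total_loss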