malt example

pull/792/head
Kye Gomez 1 month ago
parent 45342f896e
commit aafa63c682

@ -195,6 +195,7 @@ nav:
- MultiAgentRouter: "swarms/structs/multi_agent_router.md"
- MatrixSwarm: "swarms/structs/matrix_swarm.md"
- ModelRouter: "swarms/structs/model_router.md"
- MALT: "swarms/structs/malt.md"
- Various Execution Methods: "swarms/structs/various_execution_methods.md"
- Workflows:
- ConcurrentWorkflow: "swarms/structs/concurrentworkflow.md"

@ -0,0 +1,263 @@
# MALT: Multi-Agent Learning Task Framework
## Overview
MALT (Multi-Agent Learning Task) is a sophisticated orchestration framework that coordinates multiple specialized AI agents to tackle complex tasks through structured conversations. Inspired by the principles outlined in the [MALT research paper](https://arxiv.org/pdf/2412.01928), this implementation provides a reliable, extensible system for multi-agent collaboration.
The framework is designed around a three-agent architecture:
1. **Creator Agent**: Generates initial content or solutions
2. **Verifier Agent**: Critically evaluates the creator's output
3. **Refiner Agent**: Improves the solution based on verifier feedback
This collaborative approach enables high-quality outputs for complex tasks by combining the strengths of multiple specialized agents, each focused on a different aspect of the problem-solving process.
## How It Works
The MALT framework follows a structured workflow:
1. A task is submitted to the system
2. The Creator Agent generates an initial solution
3. Multiple instances of the Verifier Agent independently evaluate the solution
4. Multiple instances of the Refiner Agent improve the solution based on verification feedback
5. The final refined output is returned
This process can be configured to run for multiple iterations, with each cycle potentially improving the quality of the output. The system maintains a conversation history, tracking interactions between agents throughout the workflow.
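As a minimal sketch of a single pass through this workflow (it mirrors the orchestrator's `step` method but omits the majority-voting stage, logging, and error handling), the snippet below assumes `creator`, `verifiers`, and `refiners` are already-configured `Agent` instances:

```python
from swarms.structs.conversation import Conversation
from swarms.structs.multi_agent_exec import run_agents_concurrently


def malt_step(creator, verifiers, refiners, task: str) -> str:
    """One simplified MALT pass: create, verify in parallel, refine in parallel."""
    conversation = Conversation()
    conversation.add(role="user", content=task)

    # 1. The creator agent drafts an initial solution
    solution = creator.run(task=task)
    conversation.add(role=creator.agent_name, content=solution)

    # 2. Verifier agents critique the draft concurrently
    feedback = run_agents_concurrently(
        verifiers, task=solution, max_workers=len(verifiers)
    )
    conversation.add(role="verifiers", content=feedback)

    # 3. Refiner agents improve the draft using the collected feedback
    refined = run_agents_concurrently(
        refiners,
        task=f"Solution:\n{solution}\n\nFeedback:\n{feedback}",
        max_workers=len(refiners),
    )
    conversation.add(role="refiners", content=refined)

    return conversation.get_str()
```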
### Key Components
- **Agent**: Represents an individual AI agent with specific capabilities and responsibilities
- **Conversation**: Manages the interaction history between agents
- **MALT Orchestrator**: Coordinates the workflow and manages agent interactions
- **Concurrency Support**: Enables parallel execution of multiple agent instances
## Architecture Diagram
```mermaid
flowchart TD
User[User/Client] -->|Submit Task| MALT[MALT Orchestrator]
subgraph MALT Framework
MALT -->|Task| Creator[Creator Agent]
Creator -->|Initial Solution| Conversation[Conversation Manager]
Conversation -->|Solution| VerifierPool[Verifier Agents Pool]
subgraph VerifierPool
Verifier1[Verifier Agent 1]
Verifier2[Verifier Agent 2]
Verifier3[Verifier Agent 3]
end
VerifierPool -->|Verification Feedback| Conversation
Conversation -->|Solution + Feedback| RefinerPool[Refiner Agents Pool]
subgraph RefinerPool
Refiner1[Refiner Agent 1]
Refiner2[Refiner Agent 2]
Refiner3[Refiner Agent 3]
end
RefinerPool -->|Refined Solutions| Conversation
end
Conversation -->|Final Output| User
```
## Execution Workflow
```mermaid
sequenceDiagram
participant User
participant MALT
participant Creator
participant Verifiers
participant Refiners
participant Conversation
User->>MALT: Submit task
MALT->>Creator: Process task
Creator->>Conversation: Add initial solution
par Verification Phase
Conversation->>Verifiers: Send solution for verification
Verifiers->>Conversation: Return verification feedback
end
par Refinement Phase
Conversation->>Refiners: Send solution + feedback
Refiners->>Conversation: Return refined solutions
end
MALT->>Conversation: Request final output
Conversation->>MALT: Return conversation history
MALT->>User: Return final result
```
## API Reference
### MALT Class
The core orchestrator that manages the multi-agent interaction process.
#### Constructor Parameters
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `main_agent` | `Agent` | `None` | The primary agent (Creator) responsible for generating initial solutions |
| `refiner_agent` | `Agent` | `None` | The agent that refines solutions based on verification feedback |
| `verifier_agent` | `Agent` | `None` | The agent that verifies and evaluates solutions |
| `max_loops` | `int` | `1` | Maximum number of iterations for the task execution |
| `return_list` | `bool` | `False` | Flag to return output as a list |
| `return_dict` | `bool` | `False` | Flag to return output as a dictionary |
| `agents` | `list[Agent]` | `[]` | Alternative list of agents to use in the task |
| `preset_agents` | `bool` | `True` | Use the default mathematical proof agents; when `True`, they override any agents passed in |
#### Methods
| Method | Parameters | Return Type | Description |
|--------|------------|-------------|-------------|
| `reliability_check` | None | None | Validates agent configuration and parameters |
| `step` | `task: str, img: str = None, *args, **kwargs` | `str` | Executes a single iteration of the MALT workflow and returns the updated conversation history |
| `run` | `task: str, img: str = None, *args, **kwargs` | `str` or `list` or `dict` | Executes the complete MALT workflow for a task |
| `run_batched` | `tasks: List[str], *args, **kwargs` | `List[str]` or `List[list]` or `List[dict]` | Sequentially processes multiple tasks |
| `run_concurrently` | `tasks: List[str], *args, **kwargs` | Iterator of `concurrent.futures.Future` | Processes multiple tasks in parallel using a `ThreadPoolExecutor`, yielding futures as they complete |
| `__call__` | `task: str, *args, **kwargs` | Same as `run` | Allows the MALT instance to be called as a function |
| `__str__` | None | `str` | Returns the conversation history as a string |
| `__repr__` | None | `str` | Returns the conversation history as a string |
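As a short, illustrative sketch of how the return-format flags and the batch helper combine (the task strings here are placeholders):

```python
from swarms.structs.malt import MALT

# return_dict=True makes run() hand back the conversation as a dictionary
malt = MALT(preset_agents=True, return_dict=True)

result = malt.run("State and prove a simple result about arithmetic progressions.")

# run_batched processes tasks sequentially and preserves their order
results = malt.run_batched([
    "Prove that the square root of 2 is irrational.",
    "Prove that there are infinitely many prime numbers.",
])
print(len(results))
```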
## Sample Implementations
### Default Mathematical Proof Agents
The MALT framework includes preset agents specialized for mathematical proof generation and refinement:
1. **Proof Creator Agent**: Generates original mathematical theorems and proofs
2. **Proof Verifier Agent**: Critically evaluates and identifies issues in mathematical proofs
3. **Proof Refiner Agent**: Improves proofs based on verification feedback
Each agent has a carefully designed system prompt that guides its behavior and specialization.
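These preset agents are defined at module level in `swarms.structs.malt`, so they can be imported, inspected, or mixed with custom agents. The snippet below is a small illustration and assumes the `Agent` instances expose their `system_prompt` attribute:

```python
from swarms.structs.malt import (
    MALT,
    proof_creator_agent,
    proof_refiner_agent,
    proof_verifier_agent,
)

# Inspect the prompt that shapes the creator agent's behavior
print(proof_creator_agent.system_prompt[:200])

# Wire the presets explicitly; preset_agents=False keeps them from being re-applied
malt = MALT(
    main_agent=proof_creator_agent,
    verifier_agent=proof_verifier_agent,
    refiner_agent=proof_refiner_agent,
    preset_agents=False,
)
```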
## Usage Examples
### Basic Usage
```python
from swarms.structs.malt import MALT
# Initialize with preset mathematical proof agents
malt = MALT(preset_agents=True)
# Run a mathematical proof task
result = malt.run("Develop a theorem and proof related to prime numbers and their distribution.")
print(result)
```
### Custom Agents
```python
from swarms.structs.agent import Agent
from swarms.structs.malt import MALT
# Define custom agents
creator = Agent(
agent_name="Physics-Creator",
model_name="gpt-4o-mini",
max_loops=1,
system_prompt="You are a theoretical physicist specializing in quantum mechanics..."
)
verifier = Agent(
agent_name="Physics-Verifier",
model_name="gpt-4o-mini",
max_loops=1,
system_prompt="You are an experimental physicist who verifies theoretical claims..."
)
refiner = Agent(
agent_name="Physics-Communicator",
model_name="gpt-4o-mini",
max_loops=1,
system_prompt="You excel at explaining complex physics concepts to diverse audiences..."
)
# Initialize MALT with custom agents
malt = MALT(
main_agent=creator,
verifier_agent=verifier,
refiner_agent=refiner,
preset_agents=False,
max_loops=1
)
# Run a physics explanation task
result = malt.run("Explain the quantum entanglement phenomenon and its implications.")
```
### Concurrent Processing
```python
from swarms.structs.malt import MALT
# Initialize MALT
malt = MALT()
# Define multiple tasks
tasks = [
"Prove a theorem related to continuous functions on compact sets.",
"Develop a theorem about convergence in infinite-dimensional Hilbert spaces.",
"Create a theorem relating to measure theory and Lebesgue integration."
]
# Process tasks concurrently
futures = malt.run_concurrently(tasks)
# Collect results as they complete
for future in futures:
result = future.result()
print(result)
```
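Note that `run_concurrently` returns the iterator from `concurrent.futures.as_completed`, so results are yielded as individual tasks finish rather than in submission order; use `run_batched` when the original task ordering must be preserved.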
## Example: Complex Mathematical Domain
Here's an example of how MALT can generate, verify, and refine a mathematical proof:
### Input
```python
malt = MALT(preset_agents=True)
task = "Develop a theorem and rigorous proof related to the convergence properties of infinite series."
result = malt.run(task)
```
### Output Flow
1. **Creator Agent** generates a theorem and proof about conditions for absolute convergence
2. **Verifier Agents** identify issues:
- Logical gap in lemma 2
- Missing justification for uniform convergence claim
- Imprecise definition of certain terms
3. **Refiner Agents** produce improved versions addressing these concerns
4. The final output contains the refined, rigorous mathematical proof
## Best Practices
1. **Task Specificity**: Provide clear, detailed task descriptions for optimal results
2. **Agent Specialization**: Design agent prompts to focus on specific aspects of the task
3. **Iteration Control**: Adjust `max_loops` based on task complexity
4. **Concurrent Verification**: Use multiple verifier instances for comprehensive evaluation
5. **Custom Agents**: Create domain-specific agents for specialized tasks
## Potential Improvements
- Autonomously create specialized agents based on task requirements
- Implement feedback loops between agents for iterative improvement
- Add support for agent-specific memory and knowledge bases
- Expand concurrency capabilities for improved performance
- Implement learning mechanisms for agent improvement over time
## References
- Original MALT paper: [arXiv:2412.01928](https://arxiv.org/pdf/2412.01928)
- Built on the swarms framework for multi-agent systems

@ -0,0 +1,12 @@
from swarms.structs.malt import MALT
malt = MALT(
max_loops=1,
preset_agents=True,
)
malt.run(
task="Prove that the sum of the first n natural numbers is n(n+1)/2."
)
print(malt.conversation.return_json())

@ -81,6 +81,7 @@ from swarms.structs.swarms_api import (
SwarmValidationError,
)
from swarms.structs.agent_builder import AgentsBuilder
from swarms.structs.malt import MALT
__all__ = [
"Agent",
@ -154,4 +155,5 @@ __all__ = [
"SwarmValidationError",
"AgentInput",
"AgentsBuilder",
"MALT",
]

@ -597,9 +597,9 @@ class Agent:
return llm
def prepare_tools_list_dictionary(self):
import json
return json.loads(self.tools_list_dictionary)
def check_if_no_prompt_then_autogenerate(self, task: str = None):

@ -4,7 +4,7 @@ from typing import List, Optional, Tuple
from loguru import logger
from pydantic import BaseModel, Field
from swarms import Agent
from swarms.structs.agent import Agent
from swarms.utils.any_to_str import any_to_str
from swarms.utils.function_caller_model import OpenAIFunctionCaller
from swarms.utils.litellm_tokenizer import count_tokens

@ -0,0 +1,406 @@
"""
Implementation of the MALT (Multi-Agent Learning Task) orchestrator.
MALT is a multi-agent system that orchestrates the interaction between multiple agents
to perform complex tasks through a structured conversation. It ensures reliability and
provides options for output formatting.
- Paper: https://arxiv.org/pdf/2412.01928
Potential Improvements:
- Autonomously create the agents based on the task.
- Feed verifier responses back into the creator to improve the proof.
- Feed refiner responses back into the creator to improve the proof.
This is a simplified implementation of the MALT orchestrator: the original work trains the models with DPO and SFT, whereas this implementation uses the models as-is.
"""
import concurrent.futures
from typing import List
from loguru import logger
from swarms.structs.agent import Agent
from swarms.structs.conversation import Conversation
from swarms.structs.multi_agent_exec import run_agents_concurrently
from swarms.utils.any_to_str import any_to_str
# Agent 1: Proof Creator Agent
proof_creator_prompt = """
You are a world-renowned mathematician with an extensive background in multiple advanced fields including number theory, abstract algebra, topology, advanced calculus, and mathematical logic. You are tasked with generating an original, non-trivial theorem along with a fully rigorous and detailed proof. Your response must include the following elements:
1. **Theorem Statement and Definitions:**
- Clearly articulate a novel theorem in a specific branch of mathematics.
- Provide precise definitions of any non-standard terms or constructs that appear in your theorem.
- Contextualize the theorem within the framework of existing mathematical theory, explaining its significance.
2. **Structured Proof:**
- Develop the proof in a step-by-step format that is logically coherent and rigorous.
- Include intermediate results such as lemmas, propositions, and corollaries where applicable.
- Provide thorough justifications for every logical step and transition, referencing known theorems or axioms when relevant.
- If multiple cases or conditions are involved, handle each case separately with clear demarcations.
3. **Intuitive Explanations and Motivation:**
- Supplement the formal proof with intuitive explanations that help the reader understand the underlying ideas.
- Explain why the theorem is interesting, including possible applications or implications in other areas of mathematics.
- Address potential counterexamples or special conditions that could challenge the generality of your result.
4. **Formatting and Detail:**
- Your output should be verbose, ensuring that each part of the proof is elaborated in detail.
- Use formal mathematical language but also include lay explanations for complex arguments.
- Ensure that the final document is self-contained, clear, and can be understood by both experts and advanced students.
Your response should be as comprehensive as possible, leaving no room for ambiguity, and it should reflect your mastery in constructing original mathematical arguments.
"""
proof_creator_agent = Agent(
agent_name="Proof-Creator-Agent",
model_name="gpt-4o-mini",
max_loops=1,
system_prompt=proof_creator_prompt,
)
# Agent 2: Proof Verifier Agent
proof_verifier_prompt = """
You are an esteemed mathematician and veteran academic known for your precise and critical evaluations of complex mathematical arguments. Your role is to verify the proof produced by the Proof-Creator-Agent. Your detailed analysis should include the following components:
1. **Comprehensive Logical Analysis:**
- Examine every logical step of the proof, ensuring that all transitions are valid and that no step is assumed without proper justification.
- Identify any potential errors, missing justifications, or logical gaps in the argument.
- Provide a thorough commentary on each lemma, proposition, and conclusion presented in the proof.
2. **Mathematical Rigor and Consistency:**
- Cross-reference every argument with established mathematical theories, axioms, and known results.
- Check the consistency of definitions and ensure that they are used uniformly throughout the proof.
- Address any inconsistencies or ambiguities in notation, assumptions, or logical structure.
3. **Critical Feedback and Suggestions:**
- Provide detailed feedback on the strengths and weaknesses of the proof.
- Suggest specific modifications or additional explanations that could enhance the clarity, correctness, and overall rigor.
- If applicable, propose alternative approaches or insights that could further solidify the argument.
4. **Exhaustive Review:**
- Your analysis should be extensive and methodical, examining the proof from multiple angles.
- Ensure that each critique is accompanied by a clear rationale and reference to relevant mathematical principles.
- Summarize your findings in a structured format, highlighting both the successful aspects of the proof and areas that need improvement.
Your review must be exhaustive, ensuring that even the most subtle aspects of the proof are scrutinized in depth.
"""
proof_verifier_agent = Agent(
agent_name="Proof-Verifier-Agent",
model_name="gpt-4o-mini",
max_loops=1,
system_prompt=proof_verifier_prompt,
)
# Agent 3: Proof Refiner Agent
proof_refiner_prompt = """
You are an expert in mathematical exposition and refinement with decades of experience in teaching, publishing, and peer-reviewing advanced mathematics. Your objective is to take the initial proof and the comprehensive feedback provided by the Proof-Verifier-Agent, and then produce a refined, polished version of the proof. Your refined output must address the following points:
1. **Incorporation of Verification Feedback:**
- Meticulously integrate all the detailed suggestions and critiques provided by the Proof-Verifier-Agent.
- Ensure that all logical gaps, ambiguities, and inconsistencies identified in the review are resolved in the revised proof.
- Revisit and revise definitions, lemmas, and intermediate steps where necessary to ensure complete logical consistency.
2. **Enhanced Clarity and Structure:**
- Reorganize the proof for optimal flow and readability, ensuring that each section leads naturally to the next.
- Add comprehensive explanations where needed, emphasizing intuitive reasoning alongside formal arguments.
- Break down complex sections into more manageable parts, and ensure that each is clearly labeled and explained.
3. **Rigorous Detailing and Presentation:**
- Enhance the overall presentation of the proof by ensuring that every assertion is supported by detailed justifications.
- Include additional commentary that not only defends the logical integrity of the argument but also explains its broader significance.
- Maintain a balance between rigorous formalism and accessible exposition so that the refined proof appeals to both experts and advanced learners.
4. **Comprehensive Feedback and Rationale:**
- For every modification made, provide an accompanying explanation that outlines the rationale behind the change.
- If any aspects of the original proof were retained, clarify why they were considered adequate and how they contribute to the overall argument.
- Ensure that your final output is a cohesive, self-contained document that stands up to critical academic scrutiny.
Your refined proof should be a masterpiece of mathematical writing, addressing all the feedback with detailed revisions and explanations.
"""
proof_refiner_agent = Agent(
agent_name="Proof-Refiner-Agent",
model_name="gpt-4o-mini",
max_loops=1,
system_prompt=proof_refiner_prompt,
)
majority_voting_prompt = """
Engage in a comprehensive and exhaustive majority voting analysis of the following conversation, ensuring a deep and thoughtful examination of the responses provided by each agent. This analysis should not only summarize the responses but also critically engage with the content, context, and implications of each agent's input.
Please adhere to the following detailed guidelines:
1. **Identification of Dominant Responses:**
- Identify the most prevalent answer or recommendation across all agents. Provide a thorough rationale for its dominance, including an exploration of the factors that may have contributed to its acceptance among the agents. Discuss the context in which this consensus emerged and any relevant historical or theoretical frameworks that support this conclusion.
2. **Exploration of Disparities:**
- Delve into any significant disparities or contrasting viewpoints between agents. Explore the underlying reasons for these differences, considering aspects such as differing methodologies, assumptions, or interpretations of the task at hand. Analyze how these contrasting perspectives may reflect broader debates within the field and what implications they hold for the overall understanding of the topic.
3. **Consensus and Disagreement Analysis:**
- Highlight key areas of consensus and disagreement among the agents. Discuss the implications of these findings on the overall argument, including how consensus can strengthen certain claims while disagreement may indicate areas of uncertainty or contention. Provide examples from the conversation to illustrate these points and consider how they might influence future discussions or research directions.
4. **Critical Evaluation of Majority Opinion:**
- Critically evaluate the strength of the majority opinion, considering factors such as the reasoning behind it and its mathematical validity if applicable. Assess whether the majority opinion is well-supported by evidence and logical reasoning, and discuss any potential weaknesses or oversights that may undermine its credibility.
5. **Insights from Minority Viewpoints:**
- Note any unique insights from minority viewpoints, assessing their potential contributions to a more nuanced understanding of the topic. Discuss how these minority perspectives can enrich the conversation and provide alternative angles that may have been overlooked by the majority. Consider the value of dissent in academic discourse and how it can lead to more robust conclusions.
6. **Synthesis of Recommendations:**
- Provide a final synthesized recommendation based on the majority consensus, ensuring that it reflects a thorough consideration of all perspectives and is grounded in sound reasoning. This recommendation should not only summarize the majority view but also integrate insights from minority opinions, creating a comprehensive and balanced conclusion that acknowledges the complexity of the discussion.
Throughout your analysis, focus on uncovering clear patterns while being attentive to the subtleties and complexities inherent in the responses. Pay particular attention to the nuances of mathematical contexts where algorithmic thinking may be required, ensuring that your examination is both rigorous and accessible to a diverse audience.
"""
majority_voting_agent = Agent(
agent_name="Majority-Voting-Agent",
model_name="gpt-4o-mini",
max_loops=1,
system_prompt=majority_voting_prompt,
)
class MALT:
"""
MALT (Multi-Agent Learning Task) orchestrates the interaction between multiple agents
to perform complex tasks through a structured conversation. It ensures reliability and
provides options for output formatting.
Attributes:
main_agent (Agent): The primary agent responsible for executing the main task.
refiner_agent (Agent): The agent that refines the output of the main agent.
verifier_agent (Agent): The agent that verifies the output of the main agent.
max_loops (int): The maximum number of iterations for the task execution.
return_list (bool): Flag to return output as a list.
return_dict (bool): Flag to return output as a dictionary.
agents (list[Agent]): A list of agents to be used in the task.
conversation (Conversation): Manages the conversation history between agents.
"""
def __init__(
self,
main_agent: Agent = None,
refiner_agent: Agent = None,
verifier_agent: Agent = None,
max_loops: int = 1,
return_list: bool = False,
return_dict: bool = False,
agents: list[Agent] = [],
preset_agents: bool = True,
):
logger.info(
"Initializing MALT with provided agents and parameters."
)
self.main_agent = main_agent
self.refiner_agent = refiner_agent
self.verifier_agent = verifier_agent
self.max_loops = max_loops
self.return_list = return_list
self.return_dict = return_dict
self.agents = agents
self.conversation = Conversation()
logger.debug("Conversation initialized.")
if preset_agents:
self.main_agent = proof_creator_agent
self.refiner_agent = proof_refiner_agent
self.verifier_agent = proof_verifier_agent
self.reliability_check()
def reliability_check(self):
"""Checks the reliability of the provided agents and parameters."""
logger.info("Performing reliability check.")
if self.max_loops == 0 or self.max_loops is None:
logger.error("max_loops must be greater than 0")
raise ValueError("max_loops must be greater than 0")
# Check if agents list is provided and not empty when needed
if not self.agents and (
self.main_agent is None
or self.refiner_agent is None
or self.verifier_agent is None
):
logger.error(
"Missing agents: Provide individual agents or a list of at least 3 agents."
)
raise ValueError(
"Either provide individual agents (main_agent, refiner_agent, verifier_agent) or a list of at least 3 agents"
)
# If individual agents aren't specified but we have agents list, use the first three
if (
self.main_agent is None
or self.refiner_agent is None
or self.verifier_agent is None
) and len(self.agents) >= 3:
self.main_agent = self.main_agent or self.agents[0]
self.refiner_agent = self.refiner_agent or self.agents[1]
self.verifier_agent = (
self.verifier_agent or self.agents[2]
)
# Final check to ensure we have all required agents
if (
self.main_agent is None
or self.refiner_agent is None
or self.verifier_agent is None
):
logger.error("Missing required agents.")
raise ValueError(
"Missing required agents: main_agent, refiner_agent, and verifier_agent must all be provided"
)
logger.info("Reliability check passed.")
def step(self, task: str, img: str = None, *args, **kwargs):
"""Executes the task using the main agent and processes the output through verifier and refiner agents.
Args:
task (str): The task to be executed by the main agent.
img (str, optional): An optional image input for the agents.
Returns:
str: The conversation history accumulated so far.
"""
self.conversation.add(
role="user",
content=task,
)
logger.info("Running task with main agent.")
main_agent_output = self.main_agent.run(
task=task, img=img, *args, **kwargs
)
self.conversation.add(
role=self.main_agent.agent_name, content=main_agent_output
)
logger.info("Running task with verifier agents")
verified_outputs = run_agents_concurrently(
[
self.verifier_agent,
self.verifier_agent,
self.verifier_agent,
],
task=main_agent_output,
max_workers=3,
)
self.conversation.add(
role=self.verifier_agent.agent_name,
content=verified_outputs,
)
######################### MAJORITY VOTING #########################
# Majority Voting on the verified outputs
majority_voting_verified = majority_voting_agent.run(
task=any_to_str(verified_outputs),
)
self.conversation.add(
role=majority_voting_agent.agent_name,
content=majority_voting_verified,
)
#########################################################
# Refining the majority voting output
logger.info("Running task with refiner agents")
for output in verified_outputs:
refined_outputs = run_agents_concurrently(
[
self.refiner_agent,
self.refiner_agent,
self.refiner_agent,
],
task=output,
max_workers=3,
)
logger.debug(f"Refined outputs: {refined_outputs}")
self.conversation.add(
role=self.refiner_agent.agent_name,
content=refined_outputs,
)
return self.conversation.get_str()
def run(self, task: str, img: str = None, *args, **kwargs):
"""Executes the task using the main agent and processes the output through verifier and refiner agents.
Args:
task (str): The task to be executed by the main agent.
img (str, optional): An optional image input for the agents.
Returns:
str or list or dict: The output from the conversation based on the specified return format.
"""
for i in range(self.max_loops):
logger.info(f"Starting iteration {i+1}/{self.max_loops}")
# Each step appends creator, verifier, and refiner outputs to the shared conversation;
# the requested return format is applied once all loops have completed.
self.step(task, img, *args, **kwargs)
if self.return_list:
logger.info("Returning output as a list.")
return self.conversation.return_messages_as_list()
elif self.return_dict:
logger.info("Returning output as a dictionary.")
return self.conversation.return_messages_as_dictionary()
else:
logger.info("Returning output as a string.")
return self.conversation.get_str()
def run_batched(self, tasks: List[str], *args, **kwargs):
"""Executes a list of tasks using the main agent and processes the output through verifier and refiner agents.
Args:
tasks (list[str]): The list of tasks to be executed by the main agent.
"""
logger.info("Running batch of tasks.")
logger.info(f"Number of tasks: {len(tasks)}")
outputs = []
for task in tasks:
outputs.append(self.run(task, *args, **kwargs))
return outputs
def __call__(self, task: str, *args, **kwargs):
return self.run(task, *args, **kwargs)
def __str__(self):
return self.conversation.get_str()
def __repr__(self):
return self.conversation.get_str()
def run_concurrently(self, tasks: List[str], *args, **kwargs):
"""Executes a list of tasks using the main agent and processes the output through verifier and refiner agents.
Args:
tasks (list[str]): The list of tasks to be executed by the main agent.
"""
logger.info("Running batch of tasks concurrently.")
logger.info(f"Number of tasks: {len(tasks)}")
with concurrent.futures.ThreadPoolExecutor() as executor:
futures = [
executor.submit(self.run, task, *args, **kwargs)
for task in tasks
]
return concurrent.futures.as_completed(futures)

@ -20,6 +20,7 @@ from swarms.structs.spreadsheet_swarm import SpreadSheetSwarm
from swarms.structs.swarm_matcher import swarm_matcher
from swarms.structs.output_type import OutputType
from swarms.utils.loguru_logger import initialize_logger
from swarms.structs.malt import MALT
logger = initialize_logger(log_folder="swarm_router")
@ -35,6 +36,7 @@ SwarmType = Literal[
"HiearchicalSwarm",
"auto",
"MajorityVoting",
"MALT",
]
@ -305,6 +307,15 @@ class SwarmRouter:
**kwargs,
)
elif self.swarm_type == "MALT":
# Pass only parameters that MALT's constructor accepts
return MALT(
max_loops=self.max_loops,
return_dict=True,
preset_agents=True,
)
elif self.swarm_type == "HiearchicalSwarm":
return HierarchicalSwarm(
name=self.name,
@ -442,18 +453,10 @@ class SwarmRouter:
self.swarm = self._create_swarm(task, *args, **kwargs)
try:
# self._log(
#     "info",
#     f"Running task on {self.swarm_type} swarm with task: {task}",
# )
logger.info(f"Running task on {self.swarm_type} swarm with task: {task}")
result = self.swarm.run(task=task, *args, **kwargs)
# self._log(
#     "success",
#     f"Task completed successfully on {self.swarm_type} swarm",
#     task=task,
#     metadata={"result": str(result)},
# )
logger.info("Swarm completed successfully")
return result
except Exception as e:
self._log(

@ -1,964 +0,0 @@
"""
MultiModelOptimizer: A high-performance optimizer for training multiple transformer models simultaneously.
This optimizer implements several advanced techniques:
1. Gradient accumulation with dynamic batch sizing
2. Hierarchical parameter synchronization
3. Memory-efficient gradient sharing with shape compatibility
4. Adaptive learning rate scheduling per model
5. Convergence acceleration via momentum tuning
6. Robust error handling for production environments
Author: Claude 3.7 Sonnet
License: MIT
"""
import math
from typing import Dict, List, Optional, Tuple, Callable
from collections import defaultdict
import torch
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from loguru import logger
import numpy as np
class MultiModelOptimizer(Optimizer):
"""
An optimizer designed for training multiple models simultaneously with shared gradient information,
adaptive learning rates, and efficient memory usage.
Args:
models (Dict[str, nn.Module]): Dictionary mapping model names to model instances
lr (float, optional): Initial learning rate. Default: 1e-3
betas (Tuple[float, float], optional): Coefficients for computing running averages of gradient and its square. Default: (0.9, 0.999)
eps (float, optional): Term added to denominator for numerical stability. Default: 1e-8
weight_decay (float, optional): Weight decay coefficient. Default: 0
amsgrad (bool, optional): Whether to use the AMSGrad variant. Default: False
grad_sync_frequency (int, optional): How often to synchronize gradients between models. Default: 1
warmup_steps (int, optional): Number of warmup steps for learning rate. Default: 1000
model_weights (Dict[str, float], optional): Relative importance weights for each model. Default: None
gradient_accumulation_steps (int, optional): Number of steps to accumulate gradients before update. Default: 1
clip_grad_norm (float, optional): Maximum norm for gradient clipping. Default: None
use_cosine_schedule (bool, optional): Whether to use cosine annealing schedule. Default: True
sync_every_step (bool, optional): Whether to synchronize parameters on every step. Default: False
"""
def __init__(
self,
models: Dict[str, nn.Module],
lr: float = 1e-3,
betas: Tuple[float, float] = (0.9, 0.999),
eps: float = 1e-8,
weight_decay: float = 0,
amsgrad: bool = False,
grad_sync_frequency: int = 1,
warmup_steps: int = 1000,
model_weights: Optional[Dict[str, float]] = None,
gradient_accumulation_steps: int = 1,
clip_grad_norm: Optional[float] = None,
use_cosine_schedule: bool = True,
sync_every_step: bool = False,
):
# Initialize model weights if not provided
if model_weights is None:
model_weights = {name: 1.0 for name in models.keys()}
# Normalize weights to sum to 1
total_weight = sum(model_weights.values())
self.model_weights = {
k: v / total_weight for k, v in model_weights.items()
}
# Store models
self.models = models
# Collect all parameters from all models
parameters = []
self.model_param_groups: Dict[str, List[Dict]] = {}
for model_name, model in models.items():
model_params = []
for param in model.parameters():
if param.requires_grad:
param_dict = {
"params": [param],
"model_name": model_name,
}
parameters.append(param_dict)
model_params.append(param_dict)
self.model_param_groups[model_name] = model_params
# Initialize optimizer with all parameters
defaults = dict(
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
amsgrad=amsgrad,
)
super(MultiModelOptimizer, self).__init__(
parameters, defaults
)
# Additional settings
self.grad_sync_frequency = grad_sync_frequency
self.warmup_steps = warmup_steps
self.step_count = 0
self.gradient_accumulation_steps = gradient_accumulation_steps
self.current_accumulation_step = 0
self.clip_grad_norm = clip_grad_norm
self.use_cosine_schedule = use_cosine_schedule
self.sync_every_step = sync_every_step
# Metrics and tracking
self.model_losses: Dict[str, List[float]] = defaultdict(list)
self.model_gradients: Dict[str, torch.Tensor] = {}
self.shared_gradient_cache: Dict[str, torch.Tensor] = {}
# Set up gradient sharing structures
self.param_name_to_model = {}
for name, model in self.models.items():
for param_name, _ in model.named_parameters():
self.param_name_to_model[f"{name}.{param_name}"] = (
name
)
# Configure logger
logger.configure(
handlers=[
{
"sink": "logs/multi_model_optimizer.log",
"level": "INFO",
},
{"sink": lambda msg: print(msg), "level": "INFO"},
]
)
logger.info(
f"Initialized MultiModelOptimizer with {len(models)} models"
)
for name, weight in self.model_weights.items():
logger.info(f"Model {name} weight: {weight:.4f}")
def get_lr_multiplier(self) -> float:
"""Calculate the learning rate multiplier based on warmup and schedule."""
if self.step_count < self.warmup_steps:
# Linear warmup
return float(self.step_count) / float(
max(1, self.warmup_steps)
)
if self.use_cosine_schedule:
# Cosine decay after warmup
decay_steps = max(1, self.step_count - self.warmup_steps)
cosine_decay = 0.5 * (
1
+ math.cos(
math.pi
* decay_steps
/ (10000 * self.gradient_accumulation_steps)
)
)
return max(
0.1, cosine_decay
) # Don't let LR go below 10% of base value
return 1.0 # Constant LR after warmup if not using cosine
def share_gradients(self):
"""Share gradient information across models for similar parameters."""
# First, collect all gradients by parameter type and shape
param_type_shape_grads = defaultdict(list)
for model_name, model in self.models.items():
for param_name, param in model.named_parameters():
if param.grad is not None:
# Classify parameter by name pattern and include shape to ensure compatibility
param_type = self._classify_parameter(param_name)
param_shape = param.shape
key = (param_type, param_shape)
param_type_shape_grads[key].append(
(model_name, param_name, param.grad)
)
# Now compute shared gradients for each parameter type and shape combination
for (
param_type,
param_shape,
), grads in param_type_shape_grads.items():
if len(grads) <= 1:
continue # Skip if only one model has this parameter type+shape
cache_key = f"{param_type}_{param_shape}"
# Compute weighted average gradient for this parameter type+shape
for model_name, param_name, grad in grads:
weight = self.model_weights[model_name]
# Initialize shared gradient for this parameter if not exists
if cache_key not in self.shared_gradient_cache:
self.shared_gradient_cache[cache_key] = (
torch.zeros_like(grad)
)
# Add weighted contribution
self.shared_gradient_cache[cache_key].add_(
grad * weight
)
# Now apply a fraction of the shared gradient back to each model's parameter
for model_name, param_name, _ in grads:
param = self.models[model_name].get_parameter(
param_name
)
if param.grad is not None:
# Mix original gradient with shared gradient
sharing_ratio = 0.2 # 20% shared, 80% original
param.grad.mul_(1 - sharing_ratio).add_(
self.shared_gradient_cache[cache_key]
* sharing_ratio
)
# Clear the cache for next iteration
self.shared_gradient_cache.clear()
def _classify_parameter(self, param_name: str) -> str:
"""Classify parameter by name to determine which parameters should share gradients."""
# First, make sure we include the model architecture in the classification
# to prevent mixing parameters from different architectures
model_type = "unknown"
if "bert" in param_name:
model_type = "bert"
elif "gpt" in param_name:
model_type = "gpt"
elif "roberta" in param_name:
model_type = "roberta"
elif "transformer" in param_name:
model_type = "transformer"
# Then classify by parameter type
param_type = "other"
if (
"query" in param_name
or "key" in param_name
or "value" in param_name
):
param_type = "attention"
elif (
"dense" in param_name
or "fc" in param_name
or "ffn" in param_name
):
param_type = "ffn"
elif "embedding" in param_name:
param_type = "embedding"
elif "norm" in param_name or "layer_norm" in param_name:
param_type = "norm"
elif "bias" in param_name:
param_type = "bias"
else:
param_type = param_name.split(".")[
-1
] # Use the last component of the name
# Combine model type and parameter type for more specific classification
return f"{model_type}_{param_type}"
def step(
self, closure: Optional[Callable[[], float]] = None
) -> Optional[float]:
"""Perform a single optimization step, handling gradient accumulation and sync."""
loss = None
if closure is not None:
loss = closure()
self.current_accumulation_step += 1
# Only perform the update after accumulating enough gradients
if (
self.current_accumulation_step
< self.gradient_accumulation_steps
):
return loss
self.current_accumulation_step = 0
self.step_count += 1
# Apply gradient clipping if configured
if self.clip_grad_norm is not None:
for model_name, model in self.models.items():
torch.nn.utils.clip_grad_norm_(
model.parameters(), self.clip_grad_norm
)
# Share gradients between models if it's time
if self.step_count % self.grad_sync_frequency == 0:
self.share_gradients()
# Calculate the current learning rate multiplier
lr_multiplier = self.get_lr_multiplier()
# Apply optimizer update for each parameter group
for group in self.param_groups:
# Get model-specific learning rate adjustment
model_name = group["model_name"]
model_weight = self.model_weights[model_name]
# Adjust lr based on model weight and global multiplier
model_lr_multiplier = lr_multiplier * (
0.5 + 0.5 * model_weight
) # Scale between 50-150% based on weight
# Extract parameters for this group
p = group["params"][0]
if p.grad is None:
continue
# State initialization
state = self.state[p]
if len(state) == 0:
state["step"] = 0
state["exp_avg"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
state["exp_avg_sq"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
if group["amsgrad"]:
state["max_exp_avg_sq"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
# Extract optimizer parameters
beta1, beta2 = group["betas"]
exp_avg, exp_avg_sq = (
state["exp_avg"],
state["exp_avg_sq"],
)
# Update step count
state["step"] += 1
# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(p.grad, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(
p.grad, p.grad, value=1 - beta2
)
# Apply AMSGrad if enabled
if group["amsgrad"]:
max_exp_avg_sq = state["max_exp_avg_sq"]
torch.maximum(
max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq
)
denom = max_exp_avg_sq.sqrt().add_(group["eps"])
else:
denom = exp_avg_sq.sqrt().add_(group["eps"])
# Bias correction
bias_correction1 = 1 - beta1 ** state["step"]
bias_correction2 = 1 - beta2 ** state["step"]
step_size = (
group["lr"]
* model_lr_multiplier
* math.sqrt(bias_correction2)
/ bias_correction1
)
# Apply weight decay if configured
if group["weight_decay"] > 0:
p.data.add_(
p.data,
alpha=-group["weight_decay"]
* group["lr"]
* model_lr_multiplier,
)
# Update parameter
p.data.addcdiv_(exp_avg, denom, value=-step_size)
# Synchronize parameters if configured to do so every step
if self.sync_every_step:
self.synchronize_similar_parameters()
return loss
def synchronize_similar_parameters(self):
"""Synchronize similar parameters across models to promote convergence."""
# Only sync occasionally
if self.step_count % 10 != 0:
return
try:
# First, identify similar parameters across models
param_groups = defaultdict(list)
for model_name, model in self.models.items():
for param_name, param in model.named_parameters():
# Only sync parameters of the same shape
param_type = self._classify_parameter(param_name)
param_shape = param.shape
param_groups[(param_type, param_shape)].append(
(model_name, param_name, param)
)
# For each group of similar parameters, synchronize values
for (
param_type,
param_shape,
), params in param_groups.items():
if len(params) <= 1:
continue # Skip if only one parameter in this group
# Calculate weighted average
avg_param = None
total_weight = 0.0
for model_name, _, param in params:
weight = self.model_weights[model_name]
total_weight += weight
if avg_param is None:
avg_param = param.data.clone() * weight
else:
avg_param.add_(param.data * weight)
if total_weight > 0:
avg_param.div_(total_weight)
# Mix original parameters with the average (soft sync)
sync_ratio = 0.1 # 10% shared, 90% original
for _, _, param in params:
param.data.mul_(1 - sync_ratio).add_(
avg_param * sync_ratio
)
except Exception as e:
logger.error(
f"Error during parameter synchronization: {e}"
)
logger.error("Skipping synchronization for this step")
def log_metrics(self, model_losses: Dict[str, float]):
"""Log training metrics and update loss tracking."""
for model_name, loss in model_losses.items():
self.model_losses[model_name].append(loss)
# Log metrics every 100 steps
if self.step_count % 100 == 0:
avg_losses = {
name: np.mean(losses[-100:])
for name, losses in self.model_losses.items()
if losses
}
current_lr = (
self.param_groups[0]["lr"] * self.get_lr_multiplier()
)
logger.info(f"Step {self.step_count}")
logger.info(f"Current learning rate: {current_lr:.6f}")
for model_name, avg_loss in avg_losses.items():
logger.info(
f"Model {model_name} - Avg loss: {avg_loss:.4f}"
)
def state_dict(self) -> Dict:
"""Return the optimizer state dict with additional MultiModelOptimizer specifics."""
state_dict = super(MultiModelOptimizer, self).state_dict()
state_dict["model_weights"] = self.model_weights
state_dict["step_count"] = self.step_count
state_dict["current_accumulation_step"] = (
self.current_accumulation_step
)
return state_dict
def load_state_dict(self, state_dict: Dict):
"""Load optimizer state with MultiModelOptimizer specifics."""
self.model_weights = state_dict.pop("model_weights")
self.step_count = state_dict.pop("step_count")
self.current_accumulation_step = state_dict.pop(
"current_accumulation_step"
)
super(MultiModelOptimizer, self).load_state_dict(state_dict)
class MultiModelScheduler(_LRScheduler):
"""
A learning rate scheduler designed to work with MultiModelOptimizer,
supporting different schedules for different models based on their convergence rates.
Args:
optimizer (MultiModelOptimizer): The optimizer to schedule
total_steps (int): Total number of training steps
warmup_steps (int, optional): Number of warmup steps. Default: 1000
min_lr_ratio (float, optional): Minimum learning rate as a fraction of max. Default: 0.1
model_schedule_weights (Dict[str, float], optional): Per-model schedule weights. Default: None
last_epoch (int, optional): The index of the last epoch. Default: -1
"""
def __init__(
self,
optimizer: MultiModelOptimizer,
total_steps: int,
warmup_steps: int = 1000,
min_lr_ratio: float = 0.1,
model_schedule_weights: Optional[Dict[str, float]] = None,
last_epoch: int = -1,
):
self.total_steps = total_steps
self.warmup_steps = warmup_steps
self.min_lr_ratio = min_lr_ratio
# Use optimizer's model weights if not provided
if model_schedule_weights is None:
self.model_schedule_weights = optimizer.model_weights
else:
self.model_schedule_weights = model_schedule_weights
self.model_convergence_rates: Dict[str, float] = {
name: 1.0 for name in self.model_schedule_weights.keys()
}
super(MultiModelScheduler, self).__init__(
optimizer, last_epoch
)
def get_lr(self):
"""Calculate learning rates for all parameter groups."""
if not self._get_lr_called_within_step:
logger.warning(
"To get the last learning rate computed by the scheduler, please use `get_last_lr()`."
)
# Apply warmup
if self.last_epoch < self.warmup_steps:
lr_scale = float(self.last_epoch) / float(
max(1, self.warmup_steps)
)
else:
# Cosine decay after warmup
progress = float(
self.last_epoch - self.warmup_steps
) / float(max(1, self.total_steps - self.warmup_steps))
lr_scale = max(
self.min_lr_ratio,
0.5 * (1.0 + math.cos(math.pi * progress)),
)
# Apply model-specific adjustments based on convergence rates
lrs = []
for group in self.optimizer.param_groups:
model_name = group["model_name"]
# Adjust learning rate based on model convergence rate
model_lr = group["initial_lr"] * lr_scale
# Apply model-specific adjustment
if model_name in self.model_convergence_rates:
# Models with higher convergence rates get lower learning rates
conv_rate = self.model_convergence_rates[model_name]
model_lr *= max(0.5, min(1.5, 1.0 / conv_rate))
lrs.append(model_lr)
return lrs
def update_convergence_rates(
self, model_losses: Dict[str, List[float]], window: int = 100
):
"""
Update convergence rate estimates based on recent loss trends.
Args:
model_losses: Dictionary mapping model names to their loss histories
window: Number of steps to consider for convergence estimation
"""
for model_name, losses in model_losses.items():
if len(losses) < window:
continue
# Use recent loss values
recent_losses = losses[-window:]
# Calculate slope of loss curve
x = np.arange(len(recent_losses))
y = np.array(recent_losses)
# Simple linear regression to estimate convergence rate
slope, _ = np.polyfit(x, y, 1)
# Normalize slope to a convergence rate
# Negative slope is good (loss is decreasing)
norm_rate = 1.0 / (1.0 + abs(slope))
# Update with exponential moving average
old_rate = self.model_convergence_rates.get(
model_name, 1.0
)
self.model_convergence_rates[model_name] = (
0.9 * old_rate + 0.1 * norm_rate
)
# Log updated convergence rates
logger.info("Updated model convergence rates:")
for model_name, rate in self.model_convergence_rates.items():
logger.info(f" {model_name}: {rate:.4f}")
# Usage example with real dataset
def example_usage_with_real_data():
"""Example demonstrating how to use MultiModelOptimizer with real data from GLUE."""
try:
# Import required libraries
from transformers import (
BertForSequenceClassification,
GPT2ForSequenceClassification,
RobertaForSequenceClassification,
BertTokenizer,
GPT2Tokenizer,
RobertaTokenizer,
DataCollatorWithPadding,
)
from datasets import load_dataset
from torch.utils.data import DataLoader
# Set up logging
logger.info(
"=== Starting MultiModelOptimizer example with real GLUE data ==="
)
# Load SST-2 dataset from GLUE (small sentiment classification dataset)
logger.info("Loading SST-2 dataset from GLUE...")
sst2_dataset = load_dataset("glue", "sst2")
train_dataset = sst2_dataset["train"].select(
range(1000)
) # Use only 1000 examples for quick training
# Load tokenizers
logger.info("Loading tokenizers...")
bert_tokenizer = BertTokenizer.from_pretrained(
"bert-base-uncased"
)
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
roberta_tokenizer = RobertaTokenizer.from_pretrained(
"roberta-base"
)
# Add padding token to GPT2 tokenizer (it doesn't have one by default)
if gpt2_tokenizer.pad_token is None:
gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token
# Tokenization functions
def tokenize_bert(examples):
return bert_tokenizer(
examples["sentence"], truncation=True, max_length=128
)
def tokenize_gpt2(examples):
return gpt2_tokenizer(
examples["sentence"], truncation=True, max_length=128
)
def tokenize_roberta(examples):
return roberta_tokenizer(
examples["sentence"], truncation=True, max_length=128
)
# Tokenize datasets for each model
logger.info("Tokenizing datasets...")
bert_dataset = train_dataset.map(tokenize_bert, batched=True)
gpt2_dataset = train_dataset.map(tokenize_gpt2, batched=True)
roberta_dataset = train_dataset.map(
tokenize_roberta, batched=True
)
# Set format for PyTorch
bert_dataset.set_format(
type="torch",
columns=["input_ids", "attention_mask", "label"],
)
gpt2_dataset.set_format(
type="torch",
columns=["input_ids", "attention_mask", "label"],
)
roberta_dataset.set_format(
type="torch",
columns=["input_ids", "attention_mask", "label"],
)
# Create data collators
bert_data_collator = DataCollatorWithPadding(
tokenizer=bert_tokenizer
)
gpt2_data_collator = DataCollatorWithPadding(
tokenizer=gpt2_tokenizer
)
roberta_data_collator = DataCollatorWithPadding(
tokenizer=roberta_tokenizer
)
# Create dataloaders
logger.info("Creating dataloaders...")
batch_size = 16
bert_dataloader = DataLoader(
bert_dataset,
batch_size=batch_size,
collate_fn=bert_data_collator,
)
gpt2_dataloader = DataLoader(
gpt2_dataset,
batch_size=batch_size,
collate_fn=gpt2_data_collator,
)
roberta_dataloader = DataLoader(
roberta_dataset,
batch_size=batch_size,
collate_fn=roberta_data_collator,
)
# Load models for sequence classification
logger.info(
"Loading transformer models for sequence classification..."
)
models = {
"bert": BertForSequenceClassification.from_pretrained(
"bert-base-uncased", num_labels=2
),
"gpt2": GPT2ForSequenceClassification.from_pretrained(
"gpt2", num_labels=2
),
"roberta": RobertaForSequenceClassification.from_pretrained(
"roberta-base", num_labels=2
),
}
# Set up optimizer with different weights for each model
logger.info("Setting up MultiModelOptimizer...")
optimizer = MultiModelOptimizer(
models=models,
lr=3e-5,
betas=(0.9, 0.999),
weight_decay=0.01,
model_weights={"bert": 1.0, "gpt2": 0.7, "roberta": 1.3},
gradient_accumulation_steps=2,
clip_grad_norm=1.0,
warmup_steps=100,
grad_sync_frequency=50,
)
# Set up scheduler
scheduler = MultiModelScheduler(
optimizer=optimizer, total_steps=5000, warmup_steps=100
)
# Move models to GPU if available
device = torch.device(
"cuda" if torch.cuda.is_available() else "cpu"
)
logger.info(f"Using device: {device}")
for model_name, model in models.items():
models[model_name] = model.to(device)
# Create iterator function for each dataloader
def infinite_iterator(dataloader):
while True:
for batch in dataloader:
yield batch
bert_iter = infinite_iterator(bert_dataloader)
gpt2_iter = infinite_iterator(gpt2_dataloader)
roberta_iter = infinite_iterator(roberta_dataloader)
# Define metrics for tracking
from sklearn.metrics import accuracy_score
total_steps = 1000 # Total training steps
eval_every = 100 # Evaluate every 100 steps
best_accuracy = {"bert": 0.0, "gpt2": 0.0, "roberta": 0.0}
logger.info(f"Starting training for {total_steps} steps...")
# Training loop
for step in range(total_steps):
# Zero gradients
optimizer.zero_grad()
losses = {}
try:
# For BERT
bert_batch = next(bert_iter)
bert_batch = {
k: v.to(device) for k, v in bert_batch.items()
}
bert_outputs = models["bert"](**bert_batch)
bert_loss = bert_outputs.loss
bert_loss.backward()
losses["bert"] = bert_loss.item()
# For GPT2
gpt2_batch = next(gpt2_iter)
gpt2_batch = {
k: v.to(device) for k, v in gpt2_batch.items()
}
gpt2_outputs = models["gpt2"](**gpt2_batch)
gpt2_loss = gpt2_outputs.loss
gpt2_loss.backward()
losses["gpt2"] = gpt2_loss.item()
# For RoBERTa
roberta_batch = next(roberta_iter)
roberta_batch = {
k: v.to(device) for k, v in roberta_batch.items()
}
roberta_outputs = models["roberta"](**roberta_batch)
roberta_loss = roberta_outputs.loss
roberta_loss.backward()
losses["roberta"] = roberta_loss.item()
# Log metrics
optimizer.log_metrics(losses)
# Step the optimizer and scheduler
optimizer.step()
scheduler.step()
# Update convergence rates periodically
if step % 100 == 0:
scheduler.update_convergence_rates(
optimizer.model_losses
)
# Evaluate periodically
if step > 0 and step % eval_every == 0:
logger.info(f"Evaluating at step {step}...")
# Create a small evaluation set
eval_dataset = sst2_dataset["validation"].select(
range(100)
)
for model_name, model in models.items():
model.eval()
# Tokenize evaluation data based on model type
if model_name == "bert":
tokenizer = bert_tokenizer
tokenize_fn = tokenize_bert
elif model_name == "gpt2":
tokenizer = gpt2_tokenizer
tokenize_fn = tokenize_gpt2
else: # roberta
tokenizer = roberta_tokenizer
tokenize_fn = tokenize_roberta
eval_tokenized = eval_dataset.map(
tokenize_fn, batched=True
)
eval_tokenized.set_format(
type="torch",
columns=[
"input_ids",
"attention_mask",
"label",
],
)
# Create dataloader
eval_collator = DataCollatorWithPadding(
tokenizer=tokenizer
)
eval_dataloader = DataLoader(
eval_tokenized,
batch_size=16,
collate_fn=eval_collator,
)
# Evaluate
all_preds = []
all_labels = []
with torch.no_grad():
for batch in eval_dataloader:
batch = {
k: v.to(device)
for k, v in batch.items()
}
outputs = model(**batch)
logits = outputs.logits
preds = (
torch.argmax(logits, dim=-1)
.cpu()
.numpy()
)
labels = batch["label"].cpu().numpy()
all_preds.extend(preds)
all_labels.extend(labels)
# Calculate accuracy
accuracy = accuracy_score(
all_labels, all_preds
)
logger.info(
f"Model {model_name} - Accuracy: {accuracy:.4f}"
)
# Save best model
if accuracy > best_accuracy[model_name]:
best_accuracy[model_name] = accuracy
torch.save(
model.state_dict(),
f"best_{model_name}_model.pt",
)
logger.info(
f"Saved new best {model_name} model with accuracy {accuracy:.4f}"
)
model.train()
except RuntimeError as e:
logger.error(
f"Error during training step {step}: {e}"
)
logger.error("Skipping this step and continuing...")
optimizer.zero_grad()
continue
# Save checkpoint every 500 steps
if step > 0 and step % 500 == 0:
logger.info(f"Saving checkpoint at step {step}...")
torch.save(
{
"step": step,
"model_states": {
name: model.state_dict()
for name, model in models.items()
},
"optimizer_state": optimizer.state_dict(),
"scheduler_state": scheduler.state_dict(),
"best_accuracy": best_accuracy,
},
f"checkpoint_step_{step}.pt",
)
# Final evaluation and results
logger.info("=== Training complete! Final results ===")
for model_name, acc in best_accuracy.items():
logger.info(f"Best {model_name} accuracy: {acc:.4f}")
except Exception as e:
logger.error(
f"Fatal error in example_usage_with_real_data: {e}"
)
import traceback
logger.error(traceback.format_exc())
if __name__ == "__main__":
# Use real data example by default
example_usage_with_real_data()