parent 45342f896e
commit aafa63c682
@@ -0,0 +1,263 @@
# MALT: Multi-Agent Learning Task Framework

## Overview

MALT (Multi-Agent Learning Task) is a sophisticated orchestration framework that coordinates multiple specialized AI agents to tackle complex tasks through structured conversations. Inspired by the principles outlined in the [MALT research paper](https://arxiv.org/pdf/2412.01928), this implementation provides a reliable, extensible system for multi-agent collaboration.

The framework is designed around a three-agent architecture:

1. **Creator Agent**: Generates initial content or solutions
2. **Verifier Agent**: Critically evaluates the creator's output
3. **Refiner Agent**: Improves the solution based on verifier feedback

This collaborative approach enables high-quality outputs for complex tasks by combining the strengths of multiple specialized agents, each focused on a different aspect of the problem-solving process.

## How It Works

The MALT framework follows a structured workflow:

1. A task is submitted to the system
2. The Creator Agent generates an initial solution
3. Multiple instances of the Verifier Agent independently evaluate the solution
4. Multiple instances of the Refiner Agent improve the solution based on verification feedback
5. The final refined output is returned

This process can be configured to run for multiple iterations, with each cycle potentially improving the quality of the output. The system maintains a conversation history, tracking interactions between agents throughout the workflow.

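A minimal end-to-end sketch of this workflow is shown below; it assumes the preset mathematical-proof agents and the `swarms.structs.malt` module used elsewhere in this commit.

```python
from swarms.structs.malt import MALT

# One workflow cycle: create -> verify (in parallel) -> refine (in parallel)
malt = MALT(preset_agents=True, max_loops=1)

# run() returns the accumulated conversation (creator, verifier, and refiner turns)
history = malt.run("Develop a theorem and proof about the distribution of prime gaps.")
print(history)
```
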
### Key Components

- **Agent**: Represents an individual AI agent with specific capabilities and responsibilities
- **Conversation**: Manages the interaction history between agents (see the sketch after this list)
- **MALT Orchestrator**: Coordinates the workflow and manages agent interactions
- **Concurrency Support**: Enables parallel execution of multiple agent instances

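As a minimal sketch of how these components fit together, the orchestrator exposes the shared history on its `conversation` attribute; the `return_json()` helper shown here is the one used in the example script later in this commit.

```python
from swarms.structs.malt import MALT

malt = MALT(preset_agents=True)
malt.run("State and prove a short result about arithmetic progressions.")

# The orchestrator keeps the full multi-agent exchange on its conversation attribute
print(malt.conversation.return_json())
```
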
## Architecture Diagram

```mermaid
flowchart TD
    User[User/Client] -->|Submit Task| MALT[MALT Orchestrator]

    subgraph MALT Framework
        MALT -->|Task| Creator[Creator Agent]
        Creator -->|Initial Solution| Conversation[Conversation Manager]
        Conversation -->|Solution| VerifierPool[Verifier Agents Pool]

        subgraph VerifierPool
            Verifier1[Verifier Agent 1]
            Verifier2[Verifier Agent 2]
            Verifier3[Verifier Agent 3]
        end

        VerifierPool -->|Verification Feedback| Conversation
        Conversation -->|Solution + Feedback| RefinerPool[Refiner Agents Pool]

        subgraph RefinerPool
            Refiner1[Refiner Agent 1]
            Refiner2[Refiner Agent 2]
            Refiner3[Refiner Agent 3]
        end

        RefinerPool -->|Refined Solutions| Conversation
    end

    Conversation -->|Final Output| User
```

## Execution Workflow

```mermaid
sequenceDiagram
    participant User
    participant MALT
    participant Creator
    participant Verifiers
    participant Refiners
    participant Conversation

    User->>MALT: Submit task
    MALT->>Creator: Process task
    Creator->>Conversation: Add initial solution

    par Verification Phase
        Conversation->>Verifiers: Send solution for verification
        Verifiers->>Conversation: Return verification feedback
    end

    par Refinement Phase
        Conversation->>Refiners: Send solution + feedback
        Refiners->>Conversation: Return refined solutions
    end

    MALT->>Conversation: Request final output
    Conversation->>MALT: Return conversation history
    MALT->>User: Return final result
```

## API Reference

### MALT Class

The core orchestrator that manages the multi-agent interaction process.

#### Constructor Parameters

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `main_agent` | `Agent` | `None` | The primary agent (Creator) responsible for generating initial solutions |
| `refiner_agent` | `Agent` | `None` | The agent that refines solutions based on verification feedback |
| `verifier_agent` | `Agent` | `None` | The agent that verifies and evaluates solutions |
| `max_loops` | `int` | `1` | Maximum number of iterations for the task execution |
| `return_list` | `bool` | `False` | Flag to return output as a list |
| `return_dict` | `bool` | `False` | Flag to return output as a dictionary |
| `agents` | `list[Agent]` | `[]` | Alternative list of agents to use in the task |
| `preset_agents` | `bool` | `True` | Use default preset agents for mathematical proofs |

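As a minimal sketch of the parameters above (assuming the preset agents), the output-format flags are chosen at construction time:

```python
from swarms.structs.malt import MALT

# Return the conversation as structured messages instead of one long string
malt = MALT(
    preset_agents=True,
    max_loops=1,
    return_dict=True,
)
```
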
#### Methods

| Method | Parameters | Return Type | Description |
|--------|------------|-------------|-------------|
| `reliability_check` | None | None | Validates agent configuration and parameters |
| `step` | `task: str, img: str = None, *args, **kwargs` | `str` or `list` or `dict` | Executes a single iteration of the MALT workflow |
| `run` | `task: str, img: str = None, *args, **kwargs` | `str` or `list` or `dict` | Executes the complete MALT workflow for a task |
| `run_batched` | `tasks: List[str], *args, **kwargs` | `List[str]` or `List[list]` or `List[dict]` | Sequentially processes multiple tasks |
| `run_concurrently` | `tasks: List[str], *args, **kwargs` | Iterator of `concurrent.futures.Future` | Processes multiple tasks in parallel using a `ThreadPoolExecutor` |
| `__call__` | `task: str, *args, **kwargs` | Same as `run` | Allows the MALT instance to be called as a function |
| `__str__` | None | `str` | Returns the conversation history as a string |
| `__repr__` | None | `str` | Returns the conversation history as a string |

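A short sketch of the convenience methods above (assuming the preset agents):

```python
from swarms.structs.malt import MALT

malt = MALT(preset_agents=True)

# __call__ is an alias for run()
result = malt("Sketch a short proof that the square root of 2 is irrational.")

# __str__ / __repr__ expose the accumulated conversation history
print(str(malt))
```
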
## Sample Implementations

### Default Mathematical Proof Agents

The MALT framework includes preset agents specialized for mathematical proof generation and refinement:

1. **Proof Creator Agent**: Generates original mathematical theorems and proofs
2. **Proof Verifier Agent**: Critically evaluates and identifies issues in mathematical proofs
3. **Proof Refiner Agent**: Improves proofs based on verification feedback

Each agent has a carefully designed system prompt that guides its behavior and specialization.

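For reference, the preset creator agent is constructed along the following lines (abridged from the implementation added in this commit; the real system prompt is considerably more detailed):

```python
from swarms.structs.agent import Agent

# Abridged system prompt; the preset prompt spells out theorem, proof,
# explanation, and formatting requirements in full
proof_creator_prompt = """
You are a world-renowned mathematician. Generate an original, non-trivial theorem
along with a fully rigorous, step-by-step proof, including precise definitions,
intermediate lemmas, and intuitive explanations.
"""

proof_creator_agent = Agent(
    agent_name="Proof-Creator-Agent",
    model_name="gpt-4o-mini",
    max_loops=1,
    system_prompt=proof_creator_prompt,
)
```
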
## Usage Examples

### Basic Usage

```python
from swarms.structs.malt import MALT

# Initialize with preset mathematical proof agents
malt = MALT(preset_agents=True)

# Run a mathematical proof task
result = malt.run("Develop a theorem and proof related to prime numbers and their distribution.")

print(result)
```

### Custom Agents

```python
from swarms.structs.agent import Agent
from swarms.structs.malt import MALT

# Define custom agents
creator = Agent(
    agent_name="Physics-Creator",
    model_name="gpt-4o-mini",
    max_loops=1,
    system_prompt="You are a theoretical physicist specializing in quantum mechanics..."
)

verifier = Agent(
    agent_name="Physics-Verifier",
    model_name="gpt-4o-mini",
    max_loops=1,
    system_prompt="You are an experimental physicist who verifies theoretical claims..."
)

refiner = Agent(
    agent_name="Physics-Communicator",
    model_name="gpt-4o-mini",
    max_loops=1,
    system_prompt="You excel at explaining complex physics concepts to diverse audiences..."
)

# Initialize MALT with custom agents
malt = MALT(
    main_agent=creator,
    verifier_agent=verifier,
    refiner_agent=refiner,
    preset_agents=False,
    max_loops=1
)

# Run a physics explanation task
result = malt.run("Explain the quantum entanglement phenomenon and its implications.")
```

### Concurrent Processing

```python
from swarms.structs.malt import MALT

# Initialize MALT
malt = MALT()

# Define multiple tasks
tasks = [
    "Prove a theorem related to continuous functions on compact sets.",
    "Develop a theorem about convergence in infinite-dimensional Hilbert spaces.",
    "Create a theorem relating to measure theory and Lebesgue integration."
]

# Process tasks concurrently
futures = malt.run_concurrently(tasks)

# Collect results as they complete
for future in futures:
    result = future.result()
    print(result)
```

## Example: Complex Mathematical Domain

Here's an example of how MALT can generate, verify, and refine a mathematical proof:

### Input

```python
malt = MALT(preset_agents=True)
task = "Develop a theorem and rigorous proof related to the convergence properties of infinite series."
result = malt.run(task)
```

### Output Flow

1. **Creator Agent** generates a theorem and proof about conditions for absolute convergence
2. **Verifier Agents** identify issues:
    - Logical gap in lemma 2
    - Missing justification for uniform convergence claim
    - Imprecise definition of certain terms
3. **Refiner Agents** produce improved versions addressing these concerns
4. The final output contains the refined, rigorous mathematical proof

## Best Practices

1. **Task Specificity**: Provide clear, detailed task descriptions for optimal results
2. **Agent Specialization**: Design agent prompts to focus on specific aspects of the task
3. **Iteration Control**: Adjust `max_loops` based on task complexity (see the sketch after this list)
4. **Concurrent Verification**: Use multiple verifier instances for comprehensive evaluation
5. **Custom Agents**: Create domain-specific agents for specialized tasks

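For instance, iteration control is just a constructor argument; this is a minimal sketch assuming the preset agents:

```python
from swarms.structs.malt import MALT

# Harder tasks may benefit from additional create/verify/refine cycles
malt = MALT(preset_agents=True, max_loops=3)
result = malt.run("Develop and refine a theorem on uniform convergence of power series.")
```
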
## Potential Improvements

- Autonomously create specialized agents based on task requirements
- Implement feedback loops between agents for iterative improvement
- Add support for agent-specific memory and knowledge bases
- Expand concurrency capabilities for improved performance
- Implement learning mechanisms for agent improvement over time

## References

- Original MALT paper: [arXiv:2412.01928](https://arxiv.org/pdf/2412.01928)
- Built on the swarms framework for multi-agent systems
@@ -0,0 +1,12 @@
from swarms.structs.malt import MALT

# Use the preset mathematical-proof agents for a single create/verify/refine cycle
malt = MALT(
    max_loops=1,
    preset_agents=True,
)

malt.run(
    task="Prove that the sum of the first n natural numbers is n(n+1)/2."
)

# Inspect the full multi-agent conversation as JSON
print(malt.conversation.return_json())
@@ -0,0 +1,406 @@
"""
Implementation of the MALT (Multi-Agent Learning Task) orchestrator.

MALT is a multi-agent system that orchestrates the interaction between multiple agents
to perform complex tasks through a structured conversation. It ensures reliability and
provides options for output formatting.

- Paper: https://arxiv.org/pdf/2412.01928

Potential Improvements:
- Autonomously create the agents based on the task.
- Feed verifier responses back into the creator to improve the proof.
- Feed refiner responses back into the creator to improve the proof.

This is a simplified implementation of the MALT orchestrator. The original work trains
the models with DPO and SFT, whereas this implementation uses the models as-is.
"""

import concurrent.futures
from typing import List

from loguru import logger

from swarms.structs.agent import Agent
from swarms.structs.conversation import Conversation
from swarms.structs.multi_agent_exec import run_agents_concurrently
from swarms.utils.any_to_str import any_to_str

# Agent 1: Proof Creator Agent
proof_creator_prompt = """
You are a world-renowned mathematician with an extensive background in multiple advanced fields including number theory, abstract algebra, topology, advanced calculus, and mathematical logic. You are tasked with generating an original, non-trivial theorem along with a fully rigorous and detailed proof. Your response must include the following elements:

1. **Theorem Statement and Definitions:**
   - Clearly articulate a novel theorem in a specific branch of mathematics.
   - Provide precise definitions of any non-standard terms or constructs that appear in your theorem.
   - Contextualize the theorem within the framework of existing mathematical theory, explaining its significance.

2. **Structured Proof:**
   - Develop the proof in a step-by-step format that is logically coherent and rigorous.
   - Include intermediate results such as lemmas, propositions, and corollaries where applicable.
   - Provide thorough justifications for every logical step and transition, referencing known theorems or axioms when relevant.
   - If multiple cases or conditions are involved, handle each case separately with clear demarcations.

3. **Intuitive Explanations and Motivation:**
   - Supplement the formal proof with intuitive explanations that help the reader understand the underlying ideas.
   - Explain why the theorem is interesting, including possible applications or implications in other areas of mathematics.
   - Address potential counterexamples or special conditions that could challenge the generality of your result.

4. **Formatting and Detail:**
   - Your output should be verbose, ensuring that each part of the proof is elaborated in detail.
   - Use formal mathematical language but also include lay explanations for complex arguments.
   - Ensure that the final document is self-contained, clear, and can be understood by both experts and advanced students.

Your response should be as comprehensive as possible, leaving no room for ambiguity, and it should reflect your mastery in constructing original mathematical arguments.
"""

proof_creator_agent = Agent(
    agent_name="Proof-Creator-Agent",
    model_name="gpt-4o-mini",
    max_loops=1,
    system_prompt=proof_creator_prompt,
)

# Agent 2: Proof Verifier Agent
proof_verifier_prompt = """
You are an esteemed mathematician and veteran academic known for your precise and critical evaluations of complex mathematical arguments. Your role is to verify the proof produced by the Proof-Creator-Agent. Your detailed analysis should include the following components:

1. **Comprehensive Logical Analysis:**
   - Examine every logical step of the proof, ensuring that all transitions are valid and that no step is assumed without proper justification.
   - Identify any potential errors, missing justifications, or logical gaps in the argument.
   - Provide a thorough commentary on each lemma, proposition, and conclusion presented in the proof.

2. **Mathematical Rigor and Consistency:**
   - Cross-reference every argument with established mathematical theories, axioms, and known results.
   - Check the consistency of definitions and ensure that they are used uniformly throughout the proof.
   - Address any inconsistencies or ambiguities in notation, assumptions, or logical structure.

3. **Critical Feedback and Suggestions:**
   - Provide detailed feedback on the strengths and weaknesses of the proof.
   - Suggest specific modifications or additional explanations that could enhance the clarity, correctness, and overall rigor.
   - If applicable, propose alternative approaches or insights that could further solidify the argument.

4. **Exhaustive Review:**
   - Your analysis should be extensive and methodical, examining the proof from multiple angles.
   - Ensure that each critique is accompanied by a clear rationale and reference to relevant mathematical principles.
   - Summarize your findings in a structured format, highlighting both the successful aspects of the proof and areas that need improvement.

Your review must be exhaustive, ensuring that even the most subtle aspects of the proof are scrutinized in depth.
"""

proof_verifier_agent = Agent(
    agent_name="Proof-Verifier-Agent",
    model_name="gpt-4o-mini",
    max_loops=1,
    system_prompt=proof_verifier_prompt,
)

# Agent 3: Proof Refiner Agent
proof_refiner_prompt = """
You are an expert in mathematical exposition and refinement with decades of experience in teaching, publishing, and peer-reviewing advanced mathematics. Your objective is to take the initial proof and the comprehensive feedback provided by the Proof-Verifier-Agent, and then produce a refined, polished version of the proof. Your refined output must address the following points:

1. **Incorporation of Verification Feedback:**
   - Meticulously integrate all the detailed suggestions and critiques provided by the Proof-Verifier-Agent.
   - Ensure that all logical gaps, ambiguities, and inconsistencies identified in the review are resolved in the revised proof.
   - Revisit and revise definitions, lemmas, and intermediate steps where necessary to ensure complete logical consistency.

2. **Enhanced Clarity and Structure:**
   - Reorganize the proof for optimal flow and readability, ensuring that each section leads naturally to the next.
   - Add comprehensive explanations where needed, emphasizing intuitive reasoning alongside formal arguments.
   - Break down complex sections into more manageable parts, and ensure that each is clearly labeled and explained.

3. **Rigorous Detailing and Presentation:**
   - Enhance the overall presentation of the proof by ensuring that every assertion is supported by detailed justifications.
   - Include additional commentary that not only defends the logical integrity of the argument but also explains its broader significance.
   - Maintain a balance between rigorous formalism and accessible exposition so that the refined proof appeals to both experts and advanced learners.

4. **Comprehensive Feedback and Rationale:**
   - For every modification made, provide an accompanying explanation that outlines the rationale behind the change.
   - If any aspects of the original proof were retained, clarify why they were considered adequate and how they contribute to the overall argument.
   - Ensure that your final output is a cohesive, self-contained document that stands up to critical academic scrutiny.

Your refined proof should be a masterpiece of mathematical writing, addressing all the feedback with detailed revisions and explanations.
"""

proof_refiner_agent = Agent(
    agent_name="Proof-Refiner-Agent",
    model_name="gpt-4o-mini",
    max_loops=1,
    system_prompt=proof_refiner_prompt,
)

majority_voting_prompt = """
Engage in a comprehensive and exhaustive majority voting analysis of the following conversation, ensuring a deep and thoughtful examination of the responses provided by each agent. This analysis should not only summarize the responses but also critically engage with the content, context, and implications of each agent's input.

Please adhere to the following detailed guidelines:

1. **Identification of Dominant Responses:**
   - Identify the most prevalent answer or recommendation across all agents. Provide a thorough rationale for its dominance, including an exploration of the factors that may have contributed to its acceptance among the agents. Discuss the context in which this consensus emerged and any relevant historical or theoretical frameworks that support this conclusion.

2. **Exploration of Disparities:**
   - Delve into any significant disparities or contrasting viewpoints between agents. Explore the underlying reasons for these differences, considering aspects such as differing methodologies, assumptions, or interpretations of the task at hand. Analyze how these contrasting perspectives may reflect broader debates within the field and what implications they hold for the overall understanding of the topic.

3. **Consensus and Disagreement Analysis:**
   - Highlight key areas of consensus and disagreement among the agents. Discuss the implications of these findings on the overall argument, including how consensus can strengthen certain claims while disagreement may indicate areas of uncertainty or contention. Provide examples from the conversation to illustrate these points and consider how they might influence future discussions or research directions.

4. **Critical Evaluation of Majority Opinion:**
   - Critically evaluate the strength of the majority opinion, considering factors such as the reasoning behind it and its mathematical validity if applicable. Assess whether the majority opinion is well-supported by evidence and logical reasoning, and discuss any potential weaknesses or oversights that may undermine its credibility.

5. **Insights from Minority Viewpoints:**
   - Note any unique insights from minority viewpoints, assessing their potential contributions to a more nuanced understanding of the topic. Discuss how these minority perspectives can enrich the conversation and provide alternative angles that may have been overlooked by the majority. Consider the value of dissent in academic discourse and how it can lead to more robust conclusions.

6. **Synthesis of Recommendations:**
   - Provide a final synthesized recommendation based on the majority consensus, ensuring that it reflects a thorough consideration of all perspectives and is grounded in sound reasoning. This recommendation should not only summarize the majority view but also integrate insights from minority opinions, creating a comprehensive and balanced conclusion that acknowledges the complexity of the discussion.

Throughout your analysis, focus on uncovering clear patterns while being attentive to the subtleties and complexities inherent in the responses. Pay particular attention to the nuances of mathematical contexts where algorithmic thinking may be required, ensuring that your examination is both rigorous and accessible to a diverse audience.
"""

majority_voting_agent = Agent(
    agent_name="Majority-Voting-Agent",
    model_name="gpt-4o-mini",
    max_loops=1,
    system_prompt=majority_voting_prompt,
)

class MALT:
    """
    MALT (Multi-Agent Learning Task) orchestrates the interaction between multiple agents
    to perform complex tasks through a structured conversation. It ensures reliability and
    provides options for output formatting.

    Attributes:
        main_agent (Agent): The primary agent responsible for executing the main task.
        refiner_agent (Agent): The agent that refines the output of the main agent.
        verifier_agent (Agent): The agent that verifies the output of the main agent.
        max_loops (int): The maximum number of iterations for the task execution.
        return_list (bool): Flag to return output as a list.
        return_dict (bool): Flag to return output as a dictionary.
        agents (list[Agent]): A list of agents to be used in the task.
        conversation (Conversation): Manages the conversation history between agents.
    """

    def __init__(
        self,
        main_agent: Agent = None,
        refiner_agent: Agent = None,
        verifier_agent: Agent = None,
        max_loops: int = 1,
        return_list: bool = False,
        return_dict: bool = False,
        agents: list[Agent] = [],
        preset_agents: bool = True,
    ):
        logger.info(
            "Initializing MALT with provided agents and parameters."
        )
        self.main_agent = main_agent
        self.refiner_agent = refiner_agent
        self.verifier_agent = verifier_agent
        self.max_loops = max_loops
        self.return_list = return_list
        self.return_dict = return_dict
        self.agents = agents

        self.conversation = Conversation()
        logger.debug("Conversation initialized.")

        # Preset agents override any individually supplied agents
        if preset_agents:
            self.main_agent = proof_creator_agent
            self.refiner_agent = proof_refiner_agent
            self.verifier_agent = proof_verifier_agent

        self.reliability_check()

    def reliability_check(self):
        """Checks the reliability of the provided agents and parameters."""
        logger.info("Performing reliability check.")
        if self.max_loops == 0 or self.max_loops is None:
            logger.error("max_loops must be greater than 0")
            raise ValueError("max_loops must be greater than 0")

        # Check if agents list is provided and not empty when needed
        if not self.agents and (
            self.main_agent is None
            or self.refiner_agent is None
            or self.verifier_agent is None
        ):
            logger.error(
                "Missing agents: Provide individual agents or a list of at least 3 agents."
            )
            raise ValueError(
                "Either provide individual agents (main_agent, refiner_agent, verifier_agent) or a list of at least 3 agents"
            )

        # If individual agents aren't specified but we have agents list, use the first three
        if (
            self.main_agent is None
            or self.refiner_agent is None
            or self.verifier_agent is None
        ) and len(self.agents) >= 3:
            self.main_agent = self.main_agent or self.agents[0]
            self.refiner_agent = self.refiner_agent or self.agents[1]
            self.verifier_agent = (
                self.verifier_agent or self.agents[2]
            )

        # Final check to ensure we have all required agents
        if (
            self.main_agent is None
            or self.refiner_agent is None
            or self.verifier_agent is None
        ):
            logger.error("Missing required agents.")
            raise ValueError(
                "Missing required agents: main_agent, refiner_agent, and verifier_agent must all be provided"
            )
        logger.info("Reliability check passed.")

    def step(self, task: str, img: str = None, *args, **kwargs):
        """Executes the task using the main agent and processes the output through verifier and refiner agents.

        Args:
            task (str): The task to be executed by the main agent.
            img (str, optional): An optional image input for the agents.

        Returns:
            str: The full conversation history after this iteration.
        """
        self.conversation.add(
            role="user",
            content=task,
        )

        logger.info("Running task with main agent.")
        main_agent_output = self.main_agent.run(
            task=task, img=img, *args, **kwargs
        )

        self.conversation.add(
            role=self.main_agent.agent_name, content=main_agent_output
        )

        logger.info("Running task with verifier agents")
        verified_outputs = run_agents_concurrently(
            [
                self.verifier_agent,
                self.verifier_agent,
                self.verifier_agent,
            ],
            task=main_agent_output,
            max_workers=3,
        )

        self.conversation.add(
            role=self.verifier_agent.agent_name,
            content=verified_outputs,
        )

        ######################### MAJORITY VOTING #########################

        # Majority voting on the verified outputs
        majority_voting_verified = majority_voting_agent.run(
            task=any_to_str(verified_outputs),
        )

        self.conversation.add(
            role=majority_voting_agent.agent_name,
            content=majority_voting_verified,
        )

        #########################################################

        # Refine each verified output with a pool of refiner agents
        logger.info("Running task with refiner agents")
        for output in verified_outputs:
            refined_outputs = run_agents_concurrently(
                [
                    self.refiner_agent,
                    self.refiner_agent,
                    self.refiner_agent,
                ],
                task=output,
                max_workers=3,
            )
            logger.debug(f"Refined outputs: {refined_outputs}")

            self.conversation.add(
                role=self.refiner_agent.agent_name,
                content=refined_outputs,
            )

        return self.conversation.get_str()

    def run(self, task: str, img: str = None, *args, **kwargs):
        """Executes the task using the main agent and processes the output through verifier and refiner agents.

        Args:
            task (str): The task to be executed by the main agent.
            img (str, optional): An optional image input for the agents.

        Returns:
            str or list or dict: The output from the conversation based on the specified return format.
        """
        for i in range(self.max_loops):
            logger.info(f"Starting iteration {i + 1}/{self.max_loops}")
            self.step(task, img, *args, **kwargs)

        if self.return_list:
            logger.info("Returning output as a list.")
            return self.conversation.return_messages_as_list()
        elif self.return_dict:
            logger.info("Returning output as a dictionary.")
            return self.conversation.return_messages_as_dictionary()
        else:
            logger.info("Returning output as a string.")
            return self.conversation.get_str()

    def run_batched(self, tasks: List[str], *args, **kwargs):
        """Executes a list of tasks sequentially, running the full MALT workflow for each task.

        Args:
            tasks (list[str]): The list of tasks to be executed by the main agent.
        """
        logger.info("Running batch of tasks.")
        logger.info(f"Number of tasks: {len(tasks)}")

        outputs = []
        for task in tasks:
            outputs.append(self.run(task, *args, **kwargs))
        return outputs

    def __call__(self, task: str, *args, **kwargs):
        return self.run(task, *args, **kwargs)

    def __str__(self):
        return self.conversation.get_str()

    def __repr__(self):
        return self.conversation.get_str()

    def run_concurrently(self, tasks: List[str], *args, **kwargs):
        """Executes a list of tasks concurrently, running the full MALT workflow for each task in its own thread.

        Args:
            tasks (list[str]): The list of tasks to be executed by the main agent.

        Returns:
            An iterator over the futures as they complete, one per task.
        """
        logger.info("Running batch of tasks concurrently.")
        logger.info(f"Number of tasks: {len(tasks)}")
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [
                executor.submit(self.run, task, *args, **kwargs)
                for task in tasks
            ]
            return concurrent.futures.as_completed(futures)
@@ -1,964 +0,0 @@
|
|||||||
"""
|
|
||||||
MultiModelOptimizer: A high-performance optimizer for training multiple transformer models simultaneously.
|
|
||||||
|
|
||||||
This optimizer implements several advanced techniques:
|
|
||||||
1. Gradient accumulation with dynamic batch sizing
|
|
||||||
2. Hierarchical parameter synchronization
|
|
||||||
3. Memory-efficient gradient sharing with shape compatibility
|
|
||||||
4. Adaptive learning rate scheduling per model
|
|
||||||
5. Convergence acceleration via momentum tuning
|
|
||||||
6. Robust error handling for production environments
|
|
||||||
|
|
||||||
Author: Claude 3.7 Sonnet
|
|
||||||
License: MIT
|
|
||||||
"""
|
|
||||||
|
|
||||||
import math
|
|
||||||
from typing import Dict, List, Optional, Tuple, Callable
|
|
||||||
from collections import defaultdict
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from torch.optim import Optimizer
|
|
||||||
from torch.optim.lr_scheduler import _LRScheduler
|
|
||||||
from loguru import logger
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
class MultiModelOptimizer(Optimizer):
|
|
||||||
"""
|
|
||||||
An optimizer designed for training multiple models simultaneously with shared gradient information,
|
|
||||||
adaptive learning rates, and efficient memory usage.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
models (Dict[str, nn.Module]): Dictionary mapping model names to model instances
|
|
||||||
lr (float, optional): Initial learning rate. Default: 1e-3
|
|
||||||
betas (Tuple[float, float], optional): Coefficients for computing running averages of gradient and its square. Default: (0.9, 0.999)
|
|
||||||
eps (float, optional): Term added to denominator for numerical stability. Default: 1e-8
|
|
||||||
weight_decay (float, optional): Weight decay coefficient. Default: 0
|
|
||||||
amsgrad (bool, optional): Whether to use the AMSGrad variant. Default: False
|
|
||||||
grad_sync_frequency (int, optional): How often to synchronize gradients between models. Default: 1
|
|
||||||
warmup_steps (int, optional): Number of warmup steps for learning rate. Default: 1000
|
|
||||||
model_weights (Dict[str, float], optional): Relative importance weights for each model. Default: None
|
|
||||||
gradient_accumulation_steps (int, optional): Number of steps to accumulate gradients before update. Default: 1
|
|
||||||
clip_grad_norm (float, optional): Maximum norm for gradient clipping. Default: None
|
|
||||||
use_cosine_schedule (bool, optional): Whether to use cosine annealing schedule. Default: True
|
|
||||||
sync_every_step (bool, optional): Whether to synchronize parameters on every step. Default: False
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
models: Dict[str, nn.Module],
|
|
||||||
lr: float = 1e-3,
|
|
||||||
betas: Tuple[float, float] = (0.9, 0.999),
|
|
||||||
eps: float = 1e-8,
|
|
||||||
weight_decay: float = 0,
|
|
||||||
amsgrad: bool = False,
|
|
||||||
grad_sync_frequency: int = 1,
|
|
||||||
warmup_steps: int = 1000,
|
|
||||||
model_weights: Optional[Dict[str, float]] = None,
|
|
||||||
gradient_accumulation_steps: int = 1,
|
|
||||||
clip_grad_norm: Optional[float] = None,
|
|
||||||
use_cosine_schedule: bool = True,
|
|
||||||
sync_every_step: bool = False,
|
|
||||||
):
|
|
||||||
|
|
||||||
# Initialize model weights if not provided
|
|
||||||
if model_weights is None:
|
|
||||||
model_weights = {name: 1.0 for name in models.keys()}
|
|
||||||
|
|
||||||
# Normalize weights to sum to 1
|
|
||||||
total_weight = sum(model_weights.values())
|
|
||||||
self.model_weights = {
|
|
||||||
k: v / total_weight for k, v in model_weights.items()
|
|
||||||
}
|
|
||||||
|
|
||||||
# Store models
|
|
||||||
self.models = models
|
|
||||||
|
|
||||||
# Collect all parameters from all models
|
|
||||||
parameters = []
|
|
||||||
self.model_param_groups: Dict[str, List[Dict]] = {}
|
|
||||||
|
|
||||||
for model_name, model in models.items():
|
|
||||||
model_params = []
|
|
||||||
for param in model.parameters():
|
|
||||||
if param.requires_grad:
|
|
||||||
param_dict = {
|
|
||||||
"params": [param],
|
|
||||||
"model_name": model_name,
|
|
||||||
}
|
|
||||||
parameters.append(param_dict)
|
|
||||||
model_params.append(param_dict)
|
|
||||||
self.model_param_groups[model_name] = model_params
|
|
||||||
|
|
||||||
# Initialize optimizer with all parameters
|
|
||||||
defaults = dict(
|
|
||||||
lr=lr,
|
|
||||||
betas=betas,
|
|
||||||
eps=eps,
|
|
||||||
weight_decay=weight_decay,
|
|
||||||
amsgrad=amsgrad,
|
|
||||||
)
|
|
||||||
super(MultiModelOptimizer, self).__init__(
|
|
||||||
parameters, defaults
|
|
||||||
)
|
|
||||||
|
|
||||||
# Additional settings
|
|
||||||
self.grad_sync_frequency = grad_sync_frequency
|
|
||||||
self.warmup_steps = warmup_steps
|
|
||||||
self.step_count = 0
|
|
||||||
self.gradient_accumulation_steps = gradient_accumulation_steps
|
|
||||||
self.current_accumulation_step = 0
|
|
||||||
self.clip_grad_norm = clip_grad_norm
|
|
||||||
self.use_cosine_schedule = use_cosine_schedule
|
|
||||||
self.sync_every_step = sync_every_step
|
|
||||||
|
|
||||||
# Metrics and tracking
|
|
||||||
self.model_losses: Dict[str, List[float]] = defaultdict(list)
|
|
||||||
self.model_gradients: Dict[str, torch.Tensor] = {}
|
|
||||||
self.shared_gradient_cache: Dict[str, torch.Tensor] = {}
|
|
||||||
|
|
||||||
# Set up gradient sharing structures
|
|
||||||
self.param_name_to_model = {}
|
|
||||||
for name, model in self.models.items():
|
|
||||||
for param_name, _ in model.named_parameters():
|
|
||||||
self.param_name_to_model[f"{name}.{param_name}"] = (
|
|
||||||
name
|
|
||||||
)
|
|
||||||
|
|
||||||
# Configure logger
|
|
||||||
logger.configure(
|
|
||||||
handlers=[
|
|
||||||
{
|
|
||||||
"sink": "logs/multi_model_optimizer.log",
|
|
||||||
"level": "INFO",
|
|
||||||
},
|
|
||||||
{"sink": lambda msg: print(msg), "level": "INFO"},
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Initialized MultiModelOptimizer with {len(models)} models"
|
|
||||||
)
|
|
||||||
for name, weight in self.model_weights.items():
|
|
||||||
logger.info(f"Model {name} weight: {weight:.4f}")
|
|
||||||
|
|
||||||
def get_lr_multiplier(self) -> float:
|
|
||||||
"""Calculate the learning rate multiplier based on warmup and schedule."""
|
|
||||||
if self.step_count < self.warmup_steps:
|
|
||||||
# Linear warmup
|
|
||||||
return float(self.step_count) / float(
|
|
||||||
max(1, self.warmup_steps)
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.use_cosine_schedule:
|
|
||||||
# Cosine decay after warmup
|
|
||||||
decay_steps = max(1, self.step_count - self.warmup_steps)
|
|
||||||
cosine_decay = 0.5 * (
|
|
||||||
1
|
|
||||||
+ math.cos(
|
|
||||||
math.pi
|
|
||||||
* decay_steps
|
|
||||||
/ (10000 * self.gradient_accumulation_steps)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return max(
|
|
||||||
0.1, cosine_decay
|
|
||||||
) # Don't let LR go below 10% of base value
|
|
||||||
|
|
||||||
return 1.0 # Constant LR after warmup if not using cosine
|
|
||||||
|
|
||||||
def share_gradients(self):
|
|
||||||
"""Share gradient information across models for similar parameters."""
|
|
||||||
# First, collect all gradients by parameter type and shape
|
|
||||||
param_type_shape_grads = defaultdict(list)
|
|
||||||
|
|
||||||
for model_name, model in self.models.items():
|
|
||||||
for param_name, param in model.named_parameters():
|
|
||||||
if param.grad is not None:
|
|
||||||
# Classify parameter by name pattern and include shape to ensure compatibility
|
|
||||||
param_type = self._classify_parameter(param_name)
|
|
||||||
param_shape = param.shape
|
|
||||||
key = (param_type, param_shape)
|
|
||||||
param_type_shape_grads[key].append(
|
|
||||||
(model_name, param_name, param.grad)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Now compute shared gradients for each parameter type and shape combination
|
|
||||||
for (
|
|
||||||
param_type,
|
|
||||||
param_shape,
|
|
||||||
), grads in param_type_shape_grads.items():
|
|
||||||
if len(grads) <= 1:
|
|
||||||
continue # Skip if only one model has this parameter type+shape
|
|
||||||
|
|
||||||
cache_key = f"{param_type}_{param_shape}"
|
|
||||||
|
|
||||||
# Compute weighted average gradient for this parameter type+shape
|
|
||||||
for model_name, param_name, grad in grads:
|
|
||||||
weight = self.model_weights[model_name]
|
|
||||||
|
|
||||||
# Initialize shared gradient for this parameter if not exists
|
|
||||||
if cache_key not in self.shared_gradient_cache:
|
|
||||||
self.shared_gradient_cache[cache_key] = (
|
|
||||||
torch.zeros_like(grad)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add weighted contribution
|
|
||||||
self.shared_gradient_cache[cache_key].add_(
|
|
||||||
grad * weight
|
|
||||||
)
|
|
||||||
|
|
||||||
# Now apply a fraction of the shared gradient back to each model's parameter
|
|
||||||
for model_name, param_name, _ in grads:
|
|
||||||
param = self.models[model_name].get_parameter(
|
|
||||||
param_name
|
|
||||||
)
|
|
||||||
if param.grad is not None:
|
|
||||||
# Mix original gradient with shared gradient
|
|
||||||
sharing_ratio = 0.2 # 20% shared, 80% original
|
|
||||||
param.grad.mul_(1 - sharing_ratio).add_(
|
|
||||||
self.shared_gradient_cache[cache_key]
|
|
||||||
* sharing_ratio
|
|
||||||
)
|
|
||||||
|
|
||||||
# Clear the cache for next iteration
|
|
||||||
self.shared_gradient_cache.clear()
|
|
||||||
|
|
||||||
def _classify_parameter(self, param_name: str) -> str:
|
|
||||||
"""Classify parameter by name to determine which parameters should share gradients."""
|
|
||||||
# First, make sure we include the model architecture in the classification
|
|
||||||
# to prevent mixing parameters from different architectures
|
|
||||||
model_type = "unknown"
|
|
||||||
if "bert" in param_name:
|
|
||||||
model_type = "bert"
|
|
||||||
elif "gpt" in param_name:
|
|
||||||
model_type = "gpt"
|
|
||||||
elif "roberta" in param_name:
|
|
||||||
model_type = "roberta"
|
|
||||||
elif "transformer" in param_name:
|
|
||||||
model_type = "transformer"
|
|
||||||
|
|
||||||
# Then classify by parameter type
|
|
||||||
param_type = "other"
|
|
||||||
if (
|
|
||||||
"query" in param_name
|
|
||||||
or "key" in param_name
|
|
||||||
or "value" in param_name
|
|
||||||
):
|
|
||||||
param_type = "attention"
|
|
||||||
elif (
|
|
||||||
"dense" in param_name
|
|
||||||
or "fc" in param_name
|
|
||||||
or "ffn" in param_name
|
|
||||||
):
|
|
||||||
param_type = "ffn"
|
|
||||||
elif "embedding" in param_name:
|
|
||||||
param_type = "embedding"
|
|
||||||
elif "norm" in param_name or "layer_norm" in param_name:
|
|
||||||
param_type = "norm"
|
|
||||||
elif "bias" in param_name:
|
|
||||||
param_type = "bias"
|
|
||||||
else:
|
|
||||||
param_type = param_name.split(".")[
|
|
||||||
-1
|
|
||||||
] # Use the last component of the name
|
|
||||||
|
|
||||||
# Combine model type and parameter type for more specific classification
|
|
||||||
return f"{model_type}_{param_type}"
|
|
||||||
|
|
||||||
def step(
|
|
||||||
self, closure: Optional[Callable[[], float]] = None
|
|
||||||
) -> Optional[float]:
|
|
||||||
"""Perform a single optimization step, handling gradient accumulation and sync."""
|
|
||||||
loss = None
|
|
||||||
if closure is not None:
|
|
||||||
loss = closure()
|
|
||||||
|
|
||||||
self.current_accumulation_step += 1
|
|
||||||
|
|
||||||
# Only perform the update after accumulating enough gradients
|
|
||||||
if (
|
|
||||||
self.current_accumulation_step
|
|
||||||
< self.gradient_accumulation_steps
|
|
||||||
):
|
|
||||||
return loss
|
|
||||||
|
|
||||||
self.current_accumulation_step = 0
|
|
||||||
self.step_count += 1
|
|
||||||
|
|
||||||
# Apply gradient clipping if configured
|
|
||||||
if self.clip_grad_norm is not None:
|
|
||||||
for model_name, model in self.models.items():
|
|
||||||
torch.nn.utils.clip_grad_norm_(
|
|
||||||
model.parameters(), self.clip_grad_norm
|
|
||||||
)
|
|
||||||
|
|
||||||
# Share gradients between models if it's time
|
|
||||||
if self.step_count % self.grad_sync_frequency == 0:
|
|
||||||
self.share_gradients()
|
|
||||||
|
|
||||||
# Calculate the current learning rate multiplier
|
|
||||||
lr_multiplier = self.get_lr_multiplier()
|
|
||||||
|
|
||||||
# Apply optimizer update for each parameter group
|
|
||||||
for group in self.param_groups:
|
|
||||||
# Get model-specific learning rate adjustment
|
|
||||||
model_name = group["model_name"]
|
|
||||||
model_weight = self.model_weights[model_name]
|
|
||||||
|
|
||||||
# Adjust lr based on model weight and global multiplier
|
|
||||||
model_lr_multiplier = lr_multiplier * (
|
|
||||||
0.5 + 0.5 * model_weight
|
|
||||||
) # Scale between 50-150% based on weight
|
|
||||||
|
|
||||||
# Extract parameters for this group
|
|
||||||
p = group["params"][0]
|
|
||||||
if p.grad is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# State initialization
|
|
||||||
state = self.state[p]
|
|
||||||
if len(state) == 0:
|
|
||||||
state["step"] = 0
|
|
||||||
state["exp_avg"] = torch.zeros_like(
|
|
||||||
p, memory_format=torch.preserve_format
|
|
||||||
)
|
|
||||||
state["exp_avg_sq"] = torch.zeros_like(
|
|
||||||
p, memory_format=torch.preserve_format
|
|
||||||
)
|
|
||||||
if group["amsgrad"]:
|
|
||||||
state["max_exp_avg_sq"] = torch.zeros_like(
|
|
||||||
p, memory_format=torch.preserve_format
|
|
||||||
)
|
|
||||||
|
|
||||||
# Extract optimizer parameters
|
|
||||||
beta1, beta2 = group["betas"]
|
|
||||||
exp_avg, exp_avg_sq = (
|
|
||||||
state["exp_avg"],
|
|
||||||
state["exp_avg_sq"],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Update step count
|
|
||||||
state["step"] += 1
|
|
||||||
|
|
||||||
# Decay the first and second moment running average coefficient
|
|
||||||
exp_avg.mul_(beta1).add_(p.grad, alpha=1 - beta1)
|
|
||||||
exp_avg_sq.mul_(beta2).addcmul_(
|
|
||||||
p.grad, p.grad, value=1 - beta2
|
|
||||||
)
|
|
||||||
|
|
||||||
# Apply AMSGrad if enabled
|
|
||||||
if group["amsgrad"]:
|
|
||||||
max_exp_avg_sq = state["max_exp_avg_sq"]
|
|
||||||
torch.maximum(
|
|
||||||
max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq
|
|
||||||
)
|
|
||||||
denom = max_exp_avg_sq.sqrt().add_(group["eps"])
|
|
||||||
else:
|
|
||||||
denom = exp_avg_sq.sqrt().add_(group["eps"])
|
|
||||||
|
|
||||||
# Bias correction
|
|
||||||
bias_correction1 = 1 - beta1 ** state["step"]
|
|
||||||
bias_correction2 = 1 - beta2 ** state["step"]
|
|
||||||
step_size = (
|
|
||||||
group["lr"]
|
|
||||||
* model_lr_multiplier
|
|
||||||
* math.sqrt(bias_correction2)
|
|
||||||
/ bias_correction1
|
|
||||||
)
|
|
||||||
|
|
||||||
# Apply weight decay if configured
|
|
||||||
if group["weight_decay"] > 0:
|
|
||||||
p.data.add_(
|
|
||||||
p.data,
|
|
||||||
alpha=-group["weight_decay"]
|
|
||||||
* group["lr"]
|
|
||||||
* model_lr_multiplier,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Update parameter
|
|
||||||
p.data.addcdiv_(exp_avg, denom, value=-step_size)
|
|
||||||
|
|
||||||
# Synchronize parameters if configured to do so every step
|
|
||||||
if self.sync_every_step:
|
|
||||||
self.synchronize_similar_parameters()
|
|
||||||
|
|
||||||
return loss
|
|
||||||
|
|
||||||
def synchronize_similar_parameters(self):
|
|
||||||
"""Synchronize similar parameters across models to promote convergence."""
|
|
||||||
# Only sync occasionally
|
|
||||||
if self.step_count % 10 != 0:
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
# First, identify similar parameters across models
|
|
||||||
param_groups = defaultdict(list)
|
|
||||||
|
|
||||||
for model_name, model in self.models.items():
|
|
||||||
for param_name, param in model.named_parameters():
|
|
||||||
# Only sync parameters of the same shape
|
|
||||||
param_type = self._classify_parameter(param_name)
|
|
||||||
param_shape = param.shape
|
|
||||||
param_groups[(param_type, param_shape)].append(
|
|
||||||
(model_name, param_name, param)
|
|
||||||
)
|
|
||||||
|
|
||||||
# For each group of similar parameters, synchronize values
|
|
||||||
for (
|
|
||||||
param_type,
|
|
||||||
param_shape,
|
|
||||||
), params in param_groups.items():
|
|
||||||
if len(params) <= 1:
|
|
||||||
continue # Skip if only one parameter in this group
|
|
||||||
|
|
||||||
# Calculate weighted average
|
|
||||||
avg_param = None
|
|
||||||
total_weight = 0.0
|
|
||||||
|
|
||||||
for model_name, _, param in params:
|
|
||||||
weight = self.model_weights[model_name]
|
|
||||||
total_weight += weight
|
|
||||||
|
|
||||||
if avg_param is None:
|
|
||||||
avg_param = param.data.clone() * weight
|
|
||||||
else:
|
|
||||||
avg_param.add_(param.data * weight)
|
|
||||||
|
|
||||||
if total_weight > 0:
|
|
||||||
avg_param.div_(total_weight)
|
|
||||||
|
|
||||||
# Mix original parameters with the average (soft sync)
|
|
||||||
sync_ratio = 0.1 # 10% shared, 90% original
|
|
||||||
for _, _, param in params:
|
|
||||||
param.data.mul_(1 - sync_ratio).add_(
|
|
||||||
avg_param * sync_ratio
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Error during parameter synchronization: {e}"
|
|
||||||
)
|
|
||||||
logger.error("Skipping synchronization for this step")
|
|
||||||
|
|
||||||
def log_metrics(self, model_losses: Dict[str, float]):
|
|
||||||
"""Log training metrics and update loss tracking."""
|
|
||||||
for model_name, loss in model_losses.items():
|
|
||||||
self.model_losses[model_name].append(loss)
|
|
||||||
|
|
||||||
# Log metrics every 100 steps
|
|
||||||
if self.step_count % 100 == 0:
|
|
||||||
avg_losses = {
|
|
||||||
name: np.mean(losses[-100:])
|
|
||||||
for name, losses in self.model_losses.items()
|
|
||||||
if losses
|
|
||||||
}
|
|
||||||
current_lr = (
|
|
||||||
self.param_groups[0]["lr"] * self.get_lr_multiplier()
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"Step {self.step_count}")
|
|
||||||
logger.info(f"Current learning rate: {current_lr:.6f}")
|
|
||||||
for model_name, avg_loss in avg_losses.items():
|
|
||||||
logger.info(
|
|
||||||
f"Model {model_name} - Avg loss: {avg_loss:.4f}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def state_dict(self) -> Dict:
|
|
||||||
"""Return the optimizer state dict with additional MultiModelOptimizer specifics."""
|
|
||||||
state_dict = super(MultiModelOptimizer, self).state_dict()
|
|
||||||
state_dict["model_weights"] = self.model_weights
|
|
||||||
state_dict["step_count"] = self.step_count
|
|
||||||
state_dict["current_accumulation_step"] = (
|
|
||||||
self.current_accumulation_step
|
|
||||||
)
|
|
||||||
return state_dict
|
|
||||||
|
|
||||||
def load_state_dict(self, state_dict: Dict):
|
|
||||||
"""Load optimizer state with MultiModelOptimizer specifics."""
|
|
||||||
self.model_weights = state_dict.pop("model_weights")
|
|
||||||
self.step_count = state_dict.pop("step_count")
|
|
||||||
self.current_accumulation_step = state_dict.pop(
|
|
||||||
"current_accumulation_step"
|
|
||||||
)
|
|
||||||
super(MultiModelOptimizer, self).load_state_dict(state_dict)


class MultiModelScheduler(_LRScheduler):
    """
    A learning rate scheduler designed to work with MultiModelOptimizer,
    supporting different schedules for different models based on their convergence rates.

    Args:
        optimizer (MultiModelOptimizer): The optimizer to schedule
        total_steps (int): Total number of training steps
        warmup_steps (int, optional): Number of warmup steps. Default: 1000
        min_lr_ratio (float, optional): Minimum learning rate as a fraction of max. Default: 0.1
        model_schedule_weights (Dict[str, float], optional): Per-model schedule weights. Default: None
        last_epoch (int, optional): The index of the last epoch. Default: -1
    """

    def __init__(
        self,
        optimizer: MultiModelOptimizer,
        total_steps: int,
        warmup_steps: int = 1000,
        min_lr_ratio: float = 0.1,
        model_schedule_weights: Optional[Dict[str, float]] = None,
        last_epoch: int = -1,
    ):
        self.total_steps = total_steps
        self.warmup_steps = warmup_steps
        self.min_lr_ratio = min_lr_ratio

        # Use optimizer's model weights if not provided
        if model_schedule_weights is None:
            self.model_schedule_weights = optimizer.model_weights
        else:
            self.model_schedule_weights = model_schedule_weights

        self.model_convergence_rates: Dict[str, float] = {
            name: 1.0 for name in self.model_schedule_weights.keys()
        }
        super(MultiModelScheduler, self).__init__(
            optimizer, last_epoch
        )

    def get_lr(self):
        """Calculate learning rates for all parameter groups."""
        if not self._get_lr_called_within_step:
            logger.warning(
                "To get the last learning rate computed by the scheduler, please use `get_last_lr()`."
            )

        # Apply warmup
        if self.last_epoch < self.warmup_steps:
            lr_scale = float(self.last_epoch) / float(
                max(1, self.warmup_steps)
            )
        else:
            # Cosine decay after warmup
            progress = float(
                self.last_epoch - self.warmup_steps
            ) / float(max(1, self.total_steps - self.warmup_steps))
            lr_scale = max(
                self.min_lr_ratio,
                0.5 * (1.0 + math.cos(math.pi * progress)),
            )

        # Apply model-specific adjustments based on convergence rates
        lrs = []
        for group in self.optimizer.param_groups:
            model_name = group["model_name"]
            # Adjust learning rate based on model convergence rate
            model_lr = group["initial_lr"] * lr_scale

            # Apply model-specific adjustment
            if model_name in self.model_convergence_rates:
                # Models with higher convergence rates get lower learning rates
                conv_rate = self.model_convergence_rates[model_name]
                model_lr *= max(0.5, min(1.5, 1.0 / conv_rate))

            lrs.append(model_lr)

        return lrs

    def update_convergence_rates(
        self, model_losses: Dict[str, List[float]], window: int = 100
    ):
        """
        Update convergence rate estimates based on recent loss trends.

        Args:
            model_losses: Dictionary mapping model names to their loss histories
            window: Number of steps to consider for convergence estimation
        """
        for model_name, losses in model_losses.items():
            if len(losses) < window:
                continue

            # Use recent loss values
            recent_losses = losses[-window:]

            # Calculate slope of loss curve
            x = np.arange(len(recent_losses))
            y = np.array(recent_losses)

            # Simple linear regression to estimate convergence rate
            slope, _ = np.polyfit(x, y, 1)

            # Normalize slope to a convergence rate
            # Negative slope is good (loss is decreasing)
            norm_rate = 1.0 / (1.0 + abs(slope))

            # Update with exponential moving average
            old_rate = self.model_convergence_rates.get(
                model_name, 1.0
            )
            self.model_convergence_rates[model_name] = (
                0.9 * old_rate + 0.1 * norm_rate
            )

        # Log updated convergence rates
        logger.info("Updated model convergence rates:")
        for model_name, rate in self.model_convergence_rates.items():
            logger.info(f"  {model_name}: {rate:.4f}")
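

# A minimal standalone sketch (illustrative, not part of the class API) of the
# schedule shape implemented in MultiModelScheduler.get_lr above: linear warmup
# for `warmup_steps`, then cosine decay toward `min_lr_ratio`. The function name
# and the sample values in the comments below are assumptions for illustration only.
def _warmup_cosine_scale_sketch(
    step: int,
    total_steps: int,
    warmup_steps: int = 1000,
    min_lr_ratio: float = 0.1,
) -> float:
    """Return the scalar lr_scale factor applied before per-model adjustments."""
    import math  # local import so the sketch stands alone

    if step < warmup_steps:
        # Linear warmup from 0 to 1
        return float(step) / float(max(1, warmup_steps))
    # Cosine decay from 1 down to min_lr_ratio
    progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
    return max(min_lr_ratio, 0.5 * (1.0 + math.cos(math.pi * progress)))


# For the settings used in the example below (total_steps=5000, warmup_steps=100):
# _warmup_cosine_scale_sketch(0, 5000, 100)    -> 0.0  (start of warmup)
# _warmup_cosine_scale_sketch(100, 5000, 100)  -> 1.0  (end of warmup)
# _warmup_cosine_scale_sketch(5000, 5000, 100) -> 0.1  (floor at min_lr_ratio)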


# Usage example with real dataset
def example_usage_with_real_data():
    """Example demonstrating how to use MultiModelOptimizer with real data from GLUE."""
    try:
        # Import required libraries
        from transformers import (
            BertForSequenceClassification,
            GPT2ForSequenceClassification,
            RobertaForSequenceClassification,
            BertTokenizer,
            GPT2Tokenizer,
            RobertaTokenizer,
            DataCollatorWithPadding,
        )
        from datasets import load_dataset
        from torch.utils.data import DataLoader

        # Set up logging
        logger.info(
            "=== Starting MultiModelOptimizer example with real GLUE data ==="
        )

        # Load SST-2 dataset from GLUE (small sentiment classification dataset)
        logger.info("Loading SST-2 dataset from GLUE...")
        sst2_dataset = load_dataset("glue", "sst2")
        train_dataset = sst2_dataset["train"].select(
            range(1000)
        )  # Use only 1000 examples for quick training

        # Load tokenizers
        logger.info("Loading tokenizers...")
        bert_tokenizer = BertTokenizer.from_pretrained(
            "bert-base-uncased"
        )
        gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        roberta_tokenizer = RobertaTokenizer.from_pretrained(
            "roberta-base"
        )

        # Add padding token to GPT2 tokenizer (it doesn't have one by default)
        if gpt2_tokenizer.pad_token is None:
            gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token

        # Tokenization functions
        def tokenize_bert(examples):
            return bert_tokenizer(
                examples["sentence"], truncation=True, max_length=128
            )

        def tokenize_gpt2(examples):
            return gpt2_tokenizer(
                examples["sentence"], truncation=True, max_length=128
            )

        def tokenize_roberta(examples):
            return roberta_tokenizer(
                examples["sentence"], truncation=True, max_length=128
            )

        # Tokenize datasets for each model
        logger.info("Tokenizing datasets...")
        bert_dataset = train_dataset.map(tokenize_bert, batched=True)
        gpt2_dataset = train_dataset.map(tokenize_gpt2, batched=True)
        roberta_dataset = train_dataset.map(
            tokenize_roberta, batched=True
        )

        # Set format for PyTorch
        bert_dataset.set_format(
            type="torch",
            columns=["input_ids", "attention_mask", "label"],
        )
        gpt2_dataset.set_format(
            type="torch",
            columns=["input_ids", "attention_mask", "label"],
        )
        roberta_dataset.set_format(
            type="torch",
            columns=["input_ids", "attention_mask", "label"],
        )

        # Create data collators
        bert_data_collator = DataCollatorWithPadding(
            tokenizer=bert_tokenizer
        )
        gpt2_data_collator = DataCollatorWithPadding(
            tokenizer=gpt2_tokenizer
        )
        roberta_data_collator = DataCollatorWithPadding(
            tokenizer=roberta_tokenizer
        )

        # Create dataloaders
        logger.info("Creating dataloaders...")
        batch_size = 16
        bert_dataloader = DataLoader(
            bert_dataset,
            batch_size=batch_size,
            collate_fn=bert_data_collator,
        )
        gpt2_dataloader = DataLoader(
            gpt2_dataset,
            batch_size=batch_size,
            collate_fn=gpt2_data_collator,
        )
        roberta_dataloader = DataLoader(
            roberta_dataset,
            batch_size=batch_size,
            collate_fn=roberta_data_collator,
        )

        # Load models for sequence classification
        logger.info(
            "Loading transformer models for sequence classification..."
        )
        models = {
            "bert": BertForSequenceClassification.from_pretrained(
                "bert-base-uncased", num_labels=2
            ),
            "gpt2": GPT2ForSequenceClassification.from_pretrained(
                "gpt2", num_labels=2
            ),
            "roberta": RobertaForSequenceClassification.from_pretrained(
                "roberta-base", num_labels=2
            ),
        }

        # GPT-2's config has no pad token by default; point the classification
        # head at the tokenizer's pad token id so padded batches can be processed
        models["gpt2"].config.pad_token_id = gpt2_tokenizer.pad_token_id

        # Set up optimizer with different weights for each model
        logger.info("Setting up MultiModelOptimizer...")
        optimizer = MultiModelOptimizer(
            models=models,
            lr=3e-5,
            betas=(0.9, 0.999),
            weight_decay=0.01,
            model_weights={"bert": 1.0, "gpt2": 0.7, "roberta": 1.3},
            gradient_accumulation_steps=2,
            clip_grad_norm=1.0,
            warmup_steps=100,
            grad_sync_frequency=50,
        )

        # Set up scheduler
        scheduler = MultiModelScheduler(
            optimizer=optimizer, total_steps=5000, warmup_steps=100
        )

        # Move models to GPU if available
        device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        logger.info(f"Using device: {device}")

        for model_name, model in models.items():
            models[model_name] = model.to(device)

        # Create iterator function for each dataloader
        def infinite_iterator(dataloader):
            while True:
                for batch in dataloader:
                    yield batch

        bert_iter = infinite_iterator(bert_dataloader)
        gpt2_iter = infinite_iterator(gpt2_dataloader)
        roberta_iter = infinite_iterator(roberta_dataloader)

        # Define metrics for tracking
        from sklearn.metrics import accuracy_score

        total_steps = 1000  # Total training steps
        eval_every = 100  # Evaluate every 100 steps
        best_accuracy = {"bert": 0.0, "gpt2": 0.0, "roberta": 0.0}

        logger.info(f"Starting training for {total_steps} steps...")

        # Training loop
        for step in range(total_steps):
            # Zero gradients
            optimizer.zero_grad()

            losses = {}

            try:
                # For BERT
                bert_batch = next(bert_iter)
                bert_batch = {
                    k: v.to(device) for k, v in bert_batch.items()
                }
                bert_outputs = models["bert"](**bert_batch)
                bert_loss = bert_outputs.loss
                bert_loss.backward()
                losses["bert"] = bert_loss.item()

                # For GPT2
                gpt2_batch = next(gpt2_iter)
                gpt2_batch = {
                    k: v.to(device) for k, v in gpt2_batch.items()
                }
                gpt2_outputs = models["gpt2"](**gpt2_batch)
                gpt2_loss = gpt2_outputs.loss
                gpt2_loss.backward()
                losses["gpt2"] = gpt2_loss.item()

                # For RoBERTa
                roberta_batch = next(roberta_iter)
                roberta_batch = {
                    k: v.to(device) for k, v in roberta_batch.items()
                }
                roberta_outputs = models["roberta"](**roberta_batch)
                roberta_loss = roberta_outputs.loss
                roberta_loss.backward()
                losses["roberta"] = roberta_loss.item()

                # Log metrics
                optimizer.log_metrics(losses)

                # Step the optimizer and scheduler
                optimizer.step()
                scheduler.step()

                # Update convergence rates periodically
                if step % 100 == 0:
                    scheduler.update_convergence_rates(
                        optimizer.model_losses
                    )

                # Evaluate periodically
                if step > 0 and step % eval_every == 0:
                    logger.info(f"Evaluating at step {step}...")

                    # Create a small evaluation set
                    eval_dataset = sst2_dataset["validation"].select(
                        range(100)
                    )

                    for model_name, model in models.items():
                        model.eval()

                        # Tokenize evaluation data based on model type
                        if model_name == "bert":
                            tokenizer = bert_tokenizer
                            tokenize_fn = tokenize_bert
                        elif model_name == "gpt2":
                            tokenizer = gpt2_tokenizer
                            tokenize_fn = tokenize_gpt2
                        else:  # roberta
                            tokenizer = roberta_tokenizer
                            tokenize_fn = tokenize_roberta

                        eval_tokenized = eval_dataset.map(
                            tokenize_fn, batched=True
                        )
                        eval_tokenized.set_format(
                            type="torch",
                            columns=[
                                "input_ids",
                                "attention_mask",
                                "label",
                            ],
                        )

                        # Create dataloader
                        eval_collator = DataCollatorWithPadding(
                            tokenizer=tokenizer
                        )
                        eval_dataloader = DataLoader(
                            eval_tokenized,
                            batch_size=16,
                            collate_fn=eval_collator,
                        )

                        # Evaluate
                        all_preds = []
                        all_labels = []

                        with torch.no_grad():
                            for batch in eval_dataloader:
                                batch = {
                                    k: v.to(device)
                                    for k, v in batch.items()
                                }
                                outputs = model(**batch)
                                logits = outputs.logits
                                preds = (
                                    torch.argmax(logits, dim=-1)
                                    .cpu()
                                    .numpy()
                                )
                                # DataCollatorWithPadding renames "label" to "labels"
                                labels = batch["labels"].cpu().numpy()

                                all_preds.extend(preds)
                                all_labels.extend(labels)

                        # Calculate accuracy
                        accuracy = accuracy_score(
                            all_labels, all_preds
                        )
                        logger.info(
                            f"Model {model_name} - Accuracy: {accuracy:.4f}"
                        )

                        # Save best model
                        if accuracy > best_accuracy[model_name]:
                            best_accuracy[model_name] = accuracy
                            torch.save(
                                model.state_dict(),
                                f"best_{model_name}_model.pt",
                            )
                            logger.info(
                                f"Saved new best {model_name} model with accuracy {accuracy:.4f}"
                            )

                        model.train()

            except RuntimeError as e:
                logger.error(
                    f"Error during training step {step}: {e}"
                )
                logger.error("Skipping this step and continuing...")
                optimizer.zero_grad()
                continue

            # Save checkpoint every 500 steps
            if step > 0 and step % 500 == 0:
                logger.info(f"Saving checkpoint at step {step}...")
                torch.save(
                    {
                        "step": step,
                        "model_states": {
                            name: model.state_dict()
                            for name, model in models.items()
                        },
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_accuracy": best_accuracy,
                    },
                    f"checkpoint_step_{step}.pt",
                )

        # Final evaluation and results
        logger.info("=== Training complete! Final results ===")
        for model_name, acc in best_accuracy.items():
            logger.info(f"Best {model_name} accuracy: {acc:.4f}")

    except Exception as e:
        logger.error(
            f"Fatal error in example_usage_with_real_data: {e}"
        )
        import traceback

        logger.error(traceback.format_exc())
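

# A minimal sketch (illustrative, not part of the original example) of resuming
# from one of the "checkpoint_step_{step}.pt" files written in the training loop
# above. It assumes `models`, `optimizer`, and `scheduler` were constructed the
# same way as in example_usage_with_real_data and relies on the module-level
# `torch` import used elsewhere in this file.
def _resume_from_checkpoint_sketch(models, optimizer, scheduler, path: str):
    """Restore model, optimizer, and scheduler state from a saved checkpoint."""
    checkpoint = torch.load(path, map_location="cpu")
    for name, model in models.items():
        model.load_state_dict(checkpoint["model_states"][name])
    optimizer.load_state_dict(checkpoint["optimizer_state"])
    scheduler.load_state_dict(checkpoint["scheduler_state"])
    # Return the step to continue from and the best accuracies seen so far
    return checkpoint["step"], checkpoint["best_accuracy"]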


if __name__ == "__main__":
    # Use real data example by default
    example_usage_with_real_data()