parent 8b9e424dd5
commit 7a66cbd705
@@ -0,0 +1,666 @@
"""
GraphSwarm: A production-grade framework for orchestrating swarms of agents.

Author: Claude
License: MIT
Version: 2.0.0
"""

import asyncio
import json
import time
from collections import Counter
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, Union

import chromadb
import networkx as nx
from loguru import logger
from pydantic import BaseModel, Field

from swarms import Agent


# Configure logging
logger.add(
    "graphswarm.log",
    rotation="500 MB",
    retention="10 days",
    level="INFO",
    format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}",
)


class AgentOutput(BaseModel):
    """Structured output from a single agent."""

    agent_name: str
    timestamp: float = Field(default_factory=time.time)
    output: Any
    execution_time: float
    error: Optional[str] = None
    metadata: Dict = Field(default_factory=dict)


class SwarmOutput(BaseModel):
    """Structured output from the entire swarm."""

    timestamp: float = Field(default_factory=time.time)
    outputs: Dict[str, AgentOutput]
    execution_time: float
    success: bool
    error: Optional[str] = None
    metadata: Dict = Field(default_factory=dict)


class SwarmMemory:
    """Vector-based memory system for GraphSwarm using ChromaDB."""

    def __init__(self, collection_name: str = "swarm_memories"):
        """Initialize SwarmMemory with ChromaDB."""
        self.client = chromadb.Client()

        # Get or create collection
        self.collection = self.client.get_or_create_collection(
            name=collection_name,
            metadata={"description": "GraphSwarm execution memories"},
        )

    def store_execution(self, task: str, result: SwarmOutput):
        """Store execution results in vector memory."""
        try:
            # Create metadata
            metadata = {
                "timestamp": datetime.now().isoformat(),
                "success": result.success,
                "execution_time": result.execution_time,
                "agent_sequence": json.dumps(
                    list(result.outputs.keys())
                ),
                "error": result.error if result.error else "",
            }

            # Create document from outputs
            document = {
                "task": task,
                "outputs": json.dumps(
                    {
                        name: {
                            "output": str(output.output),
                            "execution_time": output.execution_time,
                            "error": output.error,
                        }
                        for name, output in result.outputs.items()
                    }
                ),
            }

            # Store in ChromaDB
            self.collection.add(
                documents=[json.dumps(document)],
                metadatas=[metadata],
                ids=[f"exec_{datetime.now().timestamp()}"],
            )

            logger.info(f"Stored execution in memory: {task}")

        except Exception as e:
            logger.error(
                f"Failed to store execution in memory: {str(e)}"
            )

    def get_similar_executions(self, task: str, limit: int = 5):
        """Retrieve similar past executions."""
        try:
            # Query ChromaDB for similar executions
            results = self.collection.query(
                query_texts=[task],
                n_results=limit,
                include=["documents", "metadatas"],
            )

            if not results["documents"]:
                return []

            # Process results
            executions = []
            for doc, metadata in zip(
                results["documents"][0], results["metadatas"][0]
            ):
                doc_dict = json.loads(doc)
                executions.append(
                    {
                        "task": doc_dict["task"],
                        "outputs": json.loads(doc_dict["outputs"]),
                        "success": metadata["success"],
                        "execution_time": metadata["execution_time"],
                        "agent_sequence": json.loads(
                            metadata["agent_sequence"]
                        ),
                        "timestamp": metadata["timestamp"],
                    }
                )

            return executions

        except Exception as e:
            logger.error(
                f"Failed to retrieve similar executions: {str(e)}"
            )
            return []

    def get_optimal_sequence(self, task: str) -> Optional[List[str]]:
        """Get the most successful agent sequence for similar tasks."""
        similar_executions = self.get_similar_executions(task)
        logger.debug(f"Similar executions: {similar_executions}")

        if not similar_executions:
            return None

        # Keep only the executions that succeeded
        successful_execs = [
            ex for ex in similar_executions if ex["success"]
        ]

        if not successful_execs:
            return None

        # Return the sequence from the most similar successful execution
        return successful_execs[0]["agent_sequence"]

    def clear_memory(self):
        """Clear all memories."""
        collection_name = self.collection.name
        self.client.delete_collection(collection_name)
        self.collection = self.client.get_or_create_collection(
            name=collection_name
        )
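
# A minimal usage sketch (illustrative, not part of the original file): store a
# finished SwarmOutput and query the memory for similar tasks. The task strings
# and the `some_swarm_output` variable are hypothetical.
#
# memory = SwarmMemory(collection_name="demo_memories")
# memory.store_execution("Summarize Q3 earnings", some_swarm_output)
# similar = memory.get_similar_executions("Summarize Q4 earnings", limit=3)
# best_sequence = memory.get_optimal_sequence("Summarize Q4 earnings")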


class GraphSwarm:
    """
    Enhanced framework for creating and managing swarms of collaborative agents.
    """

    def __init__(
        self,
        agents: Union[
            List[Agent], List[Tuple[Agent, List[str]]], None
        ] = None,
        max_workers: Optional[int] = None,
        swarm_name: str = "Collaborative Agent Swarm",
        memory_collection: str = "swarm_memory",
    ):
        """Initialize GraphSwarm."""
        self.graph = nx.DiGraph()
        self.agents: Dict[str, Agent] = {}
        self.dependencies: Dict[str, List[str]] = {}
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        self.swarm_name = swarm_name
        self.memory_collection = memory_collection
        self.memory = SwarmMemory(collection_name=memory_collection)

        if agents:
            self.initialize_agents(agents)

        logger.info(f"Initialized GraphSwarm: {swarm_name}")

    def initialize_agents(
        self,
        agents: Union[List[Agent], List[Tuple[Agent, List[str]]]],
    ):
        """Initialize agents and their dependencies."""
        try:
            # Handle a list of Agents or (Agent, dependencies) tuples
            for item in agents:
                if isinstance(item, tuple):
                    agent, dependencies = item
                else:
                    agent, dependencies = item, []

                if not isinstance(agent, Agent):
                    raise ValueError(
                        f"Expected Agent object, got {type(agent)}"
                    )

                self.agents[agent.agent_name] = agent
                self.dependencies[agent.agent_name] = dependencies
                self.graph.add_node(agent.agent_name, agent=agent)

                # Add dependency edges (dependencies must appear earlier in the list)
                for dep in dependencies:
                    if dep not in self.agents:
                        raise ValueError(
                            f"Dependency {dep} not found for agent {agent.agent_name}"
                        )
                    self.graph.add_edge(dep, agent.agent_name)

            self._validate_graph()

        except Exception as e:
            logger.error(f"Failed to initialize agents: {str(e)}")
            raise

    def _validate_graph(self):
        """Validate the agent dependency graph."""
        if not self.graph.nodes():
            raise ValueError("No agents added to swarm")

        if not nx.is_directed_acyclic_graph(self.graph):
            cycles = list(nx.simple_cycles(self.graph))
            raise ValueError(
                f"Agent dependency graph contains cycles: {cycles}"
            )

    def _get_agent_role_description(self, agent_name: str) -> str:
        """Generate a description of the agent's role in the swarm."""
        predecessors = list(self.graph.predecessors(agent_name))
        successors = list(self.graph.successors(agent_name))
        position = (
            "initial"
            if not predecessors
            else ("final" if not successors else "intermediate")
        )

        role = f"""You are {agent_name}, a specialized agent in the {self.swarm_name}.
Position: {position} agent in the workflow

Your relationships:"""

        if predecessors:
            role += (
                f"\nYou receive input from: {', '.join(predecessors)}"
            )
        if successors:
            role += f"\nYour output will be used by: {', '.join(successors)}"

        return role

    def _generate_workflow_context(self) -> str:
        """Generate a description of the entire workflow."""
        execution_order = list(nx.topological_sort(self.graph))

        workflow = f"""Workflow Overview of {self.swarm_name}:

Processing Order:
{' -> '.join(execution_order)}

Agent Roles:
"""

        for agent_name in execution_order:
            predecessors = list(self.graph.predecessors(agent_name))
            successors = list(self.graph.successors(agent_name))

            workflow += f"\n\n{agent_name}:"
            if predecessors:
                workflow += (
                    f"\n- Receives from: {', '.join(predecessors)}"
                )
            if successors:
                workflow += f"\n- Sends to: {', '.join(successors)}"
            if not predecessors and not successors:
                workflow += "\n- Independent agent"

        return workflow
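
    # For illustration (not in the original file): with the three-agent market
    # workflow defined under __main__ below, _get_agent_role_description would
    # produce for the middle agent roughly:
    #
    #   You are Market-Trend-Analyzer, a specialized agent in the Market Analysis Intelligence Network.
    #   Position: intermediate agent in the workflow
    #
    #   Your relationships:
    #   You receive input from: Market-Data-Collector
    #   Your output will be used by: Investment-Report-Generator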
    def _build_agent_prompt(
        self, agent_name: str, task: str, context: Optional[Dict] = None
    ) -> str:
        """Build a comprehensive prompt for the agent, including role and context."""
        prompt_parts = [
            self._get_agent_role_description(agent_name),
            "\nWorkflow Context:",
            self._generate_workflow_context(),
            "\nYour Task:",
            task,
        ]

        if context:
            prompt_parts.extend(
                ["\nContext from Previous Agents:", str(context)]
            )

        prompt_parts.extend(
            [
                "\nInstructions:",
                "1. Process the task according to your role",
                "2. Consider the input from previous agents when available",
                "3. Provide clear, structured output",
                "4. Remember that your output will be used by subsequent agents",
                "\nResponse Guidelines:",
                "- Provide clear, well-organized output",
                "- Include relevant details and insights",
                "- Highlight key findings",
                "- Flag any uncertainties or issues",
            ]
        )

        return "\n".join(prompt_parts)

    async def _execute_agent(
        self, agent_name: str, task: str, context: Optional[Dict] = None
    ) -> AgentOutput:
        """Execute a single agent."""
        start_time = time.time()
        agent = self.agents[agent_name]

        try:
            # Build comprehensive prompt
            full_prompt = self._build_agent_prompt(
                agent_name, task, context
            )
            logger.debug(f"Prompt for {agent_name}:\n{full_prompt}")

            # Execute the agent in a worker thread so the event loop is not blocked
            output = await asyncio.to_thread(agent.run, full_prompt)

            return AgentOutput(
                agent_name=agent_name,
                output=output,
                execution_time=time.time() - start_time,
                metadata={
                    "task": task,
                    "context": context,
                    "position_in_workflow": list(
                        nx.topological_sort(self.graph)
                    ).index(agent_name),
                },
            )

        except Exception as e:
            logger.error(
                f"Error executing agent {agent_name}: {str(e)}"
            )
            return AgentOutput(
                agent_name=agent_name,
                output=None,
                execution_time=time.time() - start_time,
                error=str(e),
                metadata={"task": task},
            )

    async def execute(self, task: str) -> SwarmOutput:
        """
        Execute the entire swarm of agents with memory integration.

        Args:
            task: Initial task to execute

        Returns:
            SwarmOutput: Structured output from all agents
        """
        start_time = time.time()
        outputs = {}
        success = True
        error = None

        try:
            # Get similar past executions
            similar_executions = self.memory.get_similar_executions(
                task, limit=3
            )
            optimal_sequence = self.memory.get_optimal_sequence(task)

            # Get base execution order
            base_execution_order = list(
                nx.topological_sort(self.graph)
            )

            # Determine final execution order
            if optimal_sequence and all(
                agent in base_execution_order
                for agent in optimal_sequence
            ):
                logger.info(
                    f"Using optimal sequence from memory: {optimal_sequence}"
                )
                execution_order = optimal_sequence
            else:
                execution_order = base_execution_order

            # Get historical context if available
            historical_context = {}
            if similar_executions:
                best_execution = similar_executions[0]
                if best_execution["success"]:
                    historical_context = {
                        "similar_task": best_execution["task"],
                        "previous_outputs": best_execution["outputs"],
                        "execution_time": best_execution[
                            "execution_time"
                        ],
                        "success_patterns": self._extract_success_patterns(
                            similar_executions
                        ),
                    }

            # Execute agents in order
            for agent_name in execution_order:
                try:
                    # Get context from dependencies and history
                    agent_context = {
                        "dependencies": {
                            dep: outputs[dep].output
                            for dep in self.graph.predecessors(
                                agent_name
                            )
                            if dep in outputs
                        },
                        "historical": historical_context,
                        "position": execution_order.index(agent_name),
                        "total_agents": len(execution_order),
                    }

                    # Execute agent with enhanced context
                    output = await self._execute_agent(
                        agent_name, task, agent_context
                    )
                    outputs[agent_name] = output

                    # Update historical context with current execution
                    if output.output:
                        historical_context.update(
                            {
                                f"current_{agent_name}_output": output.output
                            }
                        )

                    # Check for errors
                    if output.error:
                        success = False
                        error = f"Agent {agent_name} failed: {output.error}"

                        # Try to recover using memory
                        if similar_executions:
                            recovery_output = self._attempt_recovery(
                                agent_name, task, similar_executions
                            )
                            if recovery_output:
                                outputs[agent_name] = recovery_output
                                success = True
                                error = None
                                continue
                        break

                except Exception as agent_error:
                    logger.error(
                        f"Error executing agent {agent_name}: {str(agent_error)}"
                    )
                    success = False
                    error = f"Agent {agent_name} failed: {str(agent_error)}"
                    break

            # Create result
            result = SwarmOutput(
                outputs=outputs,
                execution_time=time.time() - start_time,
                success=success,
                error=error,
                metadata={
                    "task": task,
                    "used_optimal_sequence": optimal_sequence
                    is not None,
                    "similar_executions_found": len(
                        similar_executions
                    ),
                    "execution_order": execution_order,
                    "historical_context_used": bool(
                        historical_context
                    ),
                },
            )

            # Store execution in memory
            await self._store_execution_async(task, result)

            return result

        except Exception as e:
            logger.error(f"Swarm execution failed: {str(e)}")
            return SwarmOutput(
                outputs=outputs,
                execution_time=time.time() - start_time,
                success=False,
                error=str(e),
                metadata={"task": task},
            )

    def run(self, task: str) -> SwarmOutput:
        """Synchronous interface to execute the swarm."""
        return asyncio.run(self.execute(task))

    def _extract_success_patterns(
        self, similar_executions: List[Dict]
    ) -> Dict:
        """Extract success patterns from similar executions."""
        patterns = {}
        successful_execs = [
            ex for ex in similar_executions if ex["success"]
        ]

        if successful_execs:
            patterns = {
                "common_sequences": self._find_common_sequences(
                    successful_execs
                ),
                "avg_execution_time": sum(
                    ex["execution_time"] for ex in successful_execs
                )
                / len(successful_execs),
                "successful_strategies": self._extract_strategies(
                    successful_execs
                ),
            }

        return patterns

    def _find_common_sequences(
        self, executions: List[Dict]
    ) -> List[List[str]]:
        """Find agent sequences that recur across successful executions.

        Note: this helper was missing from the original file; this is a
        minimal implementation that returns the distinct sequences
        observed, most frequent first.
        """
        counts = Counter(
            tuple(ex["agent_sequence"]) for ex in executions
        )
        return [list(seq) for seq, _ in counts.most_common()]

    def _extract_strategies(
        self, executions: List[Dict]
    ) -> List[Dict]:
        """Summarize how each successful execution was run.

        Note: this helper was missing from the original file; this is a
        minimal implementation that records the sequence and timing of
        each successful execution.
        """
        return [
            {
                "agent_sequence": ex["agent_sequence"],
                "execution_time": ex["execution_time"],
            }
            for ex in executions
        ]

    def _attempt_recovery(
        self,
        failed_agent: str,
        task: str,
        similar_executions: List[Dict],
    ) -> Optional[AgentOutput]:
        """Attempt to recover from a failure using memory."""
        for execution in similar_executions:
            if (
                execution["success"]
                and failed_agent in execution["outputs"]
            ):
                historical_output = execution["outputs"][failed_agent]

                return AgentOutput(
                    agent_name=failed_agent,
                    output=historical_output["output"],
                    execution_time=historical_output[
                        "execution_time"
                    ],
                    metadata={
                        "recovered_from_memory": True,
                        "original_task": execution["task"],
                    },
                )
        return None

    async def _store_execution_async(
        self, task: str, result: SwarmOutput
    ):
        """Asynchronously store execution in memory."""
        try:
            await asyncio.to_thread(
                self.memory.store_execution, task, result
            )
        except Exception as e:
            logger.error(
                f"Failed to store execution in memory: {str(e)}"
            )

    def add_agent(
        self, agent: Agent, dependencies: Optional[List[str]] = None
    ):
        """Add a new agent to the swarm."""
        dependencies = dependencies or []
        self.agents[agent.agent_name] = agent
        self.dependencies[agent.agent_name] = dependencies
        self.graph.add_node(agent.agent_name, agent=agent)

        for dep in dependencies:
            if dep not in self.agents:
                raise ValueError(f"Dependency {dep} not found")
            self.graph.add_edge(dep, agent.agent_name)

        self._validate_graph()
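
# A minimal sketch (illustrative, not part of the original file) of growing a
# swarm incrementally with add_agent; the agent variables are hypothetical.
#
# swarm = GraphSwarm(swarm_name="Incremental Swarm")
# swarm.add_agent(collector_agent)  # no dependencies
# swarm.add_agent(analyzer_agent, dependencies=[collector_agent.agent_name])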


if __name__ == "__main__":
    try:
        # Create agents
        data_collector = Agent(
            agent_name="Market-Data-Collector",
            model_name="gpt-4o-mini",
            max_loops=1,
            streaming_on=True,
        )

        trend_analyzer = Agent(
            agent_name="Market-Trend-Analyzer",
            model_name="gpt-4o-mini",
            max_loops=1,
            streaming_on=True,
        )

        report_generator = Agent(
            agent_name="Investment-Report-Generator",
            model_name="gpt-4o-mini",
            max_loops=1,
            streaming_on=True,
        )

        # Create swarm
        swarm = GraphSwarm(
            agents=[
                (data_collector, []),
                (trend_analyzer, ["Market-Data-Collector"]),
                (report_generator, ["Market-Trend-Analyzer"]),
            ],
            swarm_name="Market Analysis Intelligence Network",
        )

        # Run the swarm
        result = swarm.run(
            "Analyze current market trends for tech stocks and provide investment recommendations"
        )

        # Print results
        print(f"Execution success: {result.success}")
        print(f"Total time: {result.execution_time:.2f} seconds")

        for agent_name, output in result.outputs.items():
            print(f"\nAgent: {agent_name}")
            print(f"Output: {output.output}")
            if output.error:
                print(f"Error: {output.error}")
    except Exception as error:
        logger.error(error)
        raise
File diff suppressed because it is too large
@@ -0,0 +1,7 @@
from swarms import Agent

Agent(
    agent_name="Stock-Analysis-Agent",
    model_name="gpt-4o-mini",
    max_loops=1,
).run("What are 5 hft algorithms")
@@ -0,0 +1,276 @@
import asyncio
import json
from typing import Any, Dict, List

import pulsar
from loguru import logger
from pulsar import ConsumerType

from swarms import Agent


class ScalableAsyncAgentSwarm:
    """
    A scalable, asynchronous swarm of agents leveraging Apache Pulsar for inter-agent communication.
    Provides load balancing, health monitoring, dead letter queues, and centralized logging.
    """

    def __init__(
        self,
        pulsar_url: str,
        topic: str,
        dlq_topic: str,
        agents_config: List[Dict[str, Any]],
    ):
        """
        Initializes the async swarm with agents.

        Args:
            pulsar_url (str): The URL of the Apache Pulsar broker.
            topic (str): The main topic for task distribution.
            dlq_topic (str): The Dead Letter Queue topic for failed messages.
            agents_config (List[Dict[str, Any]]): List of agent configurations with `name`, `description`, and `model_name`.
        """
        self.pulsar_url = pulsar_url
        self.topic = topic
        self.dlq_topic = dlq_topic
        self.agents_config = agents_config
        self.client = pulsar.Client(pulsar_url)
        self.consumer = self.client.subscribe(
            topic,
            subscription_name="swarm-task-sub",
            consumer_type=ConsumerType.Shared,
        )
        self.dlq_producer = self.client.create_producer(dlq_topic)
        self.response_logger = []
        self.agents = [
            self.create_agent(config) for config in agents_config
        ]
        self.agent_index = 0

        logger.info(
            "Swarm initialized with agents: {}",
            [agent["name"] for agent in agents_config],
        )

    def create_agent(
        self, agent_config: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Creates a new agent configuration with asynchronous capabilities.

        Args:
            agent_config (Dict[str, Any]): Configuration dictionary with agent details.

        Returns:
            Dict[str, Any]: A dictionary containing agent metadata and functionality.
        """
        agent_name = agent_config["name"]
        description = agent_config["description"]
        model_name = agent_config.get("model_name", "gpt-4o-mini")

        class AsyncAgent:
            """
            An asynchronous agent that processes tasks and communicates via Apache Pulsar.
            """

            def __init__(
                self, name: str, description: str, model_name: str
            ):
                self.name = name
                self.description = description
                self.agent = Agent(
                    agent_name=name,
                    model_name=model_name,
                    max_loops="auto",
                    interactive=True,
                    streaming_on=True,
                )
                logger.info(
                    f"Initialized agent '{name}' - {description}"
                )

            async def process_task(
                self, message: str
            ) -> Dict[str, Any]:
                """
                Processes a single task using the agent.

                Args:
                    message (str): The task message.

                Returns:
                    Dict[str, Any]: JSON-formatted response.
                """
                try:
                    logger.info(
                        f"Agent {self.name} processing task: {message}"
                    )
                    # Run the blocking agent call in a worker thread
                    response = await asyncio.to_thread(
                        self.agent.run, message
                    )
                    logger.info(f"Agent {self.name} completed task.")
                    return {
                        "agent_name": self.name,
                        "response": response,
                    }
                except Exception as e:
                    logger.error(
                        f"Agent {self.name} encountered an error: {e}"
                    )
                    return {"agent_name": self.name, "error": str(e)}

        return {
            "name": agent_name,
            "instance": AsyncAgent(
                agent_name, description, model_name
            ),
        }

    async def distribute_task(self, message: str):
        """
        Distributes a task to the next available agent using round-robin.

        Args:
            message (str): The task message.
        """
        agent = self.agents[self.agent_index]
        self.agent_index = (self.agent_index + 1) % len(self.agents)

        try:
            response = await agent["instance"].process_task(message)
            self.log_response(response)
        except Exception as e:
            logger.error(
                f"Error processing task by agent {agent['name']}: {e}"
            )
            self.send_to_dlq(message)

    async def monitor_health(self):
        """
        Periodically monitors the health of agents.
        """
        while True:
            logger.info("Performing health check for all agents.")
            for agent in self.agents:
                logger.info(f"Agent {agent['name']} is online.")
            await asyncio.sleep(10)

    def send_to_dlq(self, message: str):
        """
        Sends a failed message to the Dead Letter Queue (DLQ).

        Args:
            message (str): The message to send to the DLQ.
        """
        try:
            self.dlq_producer.send(message.encode("utf-8"))
            logger.info("Message sent to Dead Letter Queue.")
        except Exception as e:
            logger.error(f"Failed to send message to DLQ: {e}")

    def log_response(self, response: Dict[str, Any]):
        """
        Logs the response to a centralized list for later analysis.

        Args:
            response (Dict[str, Any]): The agent's response.
        """
        self.response_logger.append(response)
        logger.info(f"Response logged: {response}")

    async def listen_and_distribute(self):
        """
        Listens to the main Pulsar topic and distributes tasks to agents.
        """
        while True:
            # Receive in a worker thread so the blocking call does not
            # stall the event loop (and the health monitor with it)
            msg = await asyncio.to_thread(self.consumer.receive)
            try:
                message = msg.data().decode("utf-8")
                logger.info(f"Received task: {message}")
                await self.distribute_task(message)
                self.consumer.acknowledge(msg)
            except Exception as e:
                logger.error(f"Error processing message: {e}")
                self.send_to_dlq(msg.data().decode("utf-8"))
                self.consumer.negative_acknowledge(msg)

    async def run(self):
        """
        Runs the swarm asynchronously with health monitoring and task distribution.
        """
        logger.info("Starting the async swarm...")
        task_listener = asyncio.create_task(
            self.listen_and_distribute()
        )
        health_monitor = asyncio.create_task(self.monitor_health())
        await asyncio.gather(task_listener, health_monitor)

    def shutdown(self):
        """
        Safely shuts down the swarm and logs all responses.
        """
        logger.info("Shutting down the swarm...")
        self.client.close()
        with open("responses.json", "w") as f:
            json.dump(self.response_logger, f, indent=4)
        logger.info("Responses saved to 'responses.json'.")
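
# A minimal sketch (illustrative, not part of the original file) of draining
# the DLQ to inspect failed tasks; it reuses the PULSAR_URL and DLQ_TOPIC
# constants defined below.
#
# dlq_client = pulsar.Client(PULSAR_URL)
# dlq_consumer = dlq_client.subscribe(DLQ_TOPIC, subscription_name="dlq-inspector")
# while True:
#     failed = dlq_consumer.receive()
#     print("Failed task:", failed.data().decode("utf-8"))
#     dlq_consumer.acknowledge(failed)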

# from scalable_agent_swarm import ScalableAsyncAgentSwarm  # Assuming your swarm class is saved here

if __name__ == "__main__":
    # Example configuration
    PULSAR_URL = "pulsar://localhost:6650"
    TOPIC = "stock-analysis"
    DLQ_TOPIC = "stock-analysis-dlq"

    # Agents configuration
    AGENTS_CONFIG = [
        {
            "name": "Stock-Analysis-Agent-1",
            "description": "Analyzes stock trends.",
            "model_name": "gpt-4o-mini",
        },
        {
            "name": "Stock-News-Agent",
            "description": "Summarizes stock news.",
            "model_name": "gpt-4o-mini",
        },
        {
            "name": "Tech-Trends-Agent",
            "description": "Tracks tech sector trends.",
            "model_name": "gpt-4o-mini",
        },
    ]

    # Tasks to send
    TASKS = [
        "Analyze the trend for tech stocks in Q4 2024",
        "Summarize the latest news on the S&P 500",
        "Identify the top-performing sectors in the stock market",
        "Provide a forecast for AI-related stocks for 2025",
    ]

    # Initialize the swarm
    swarm = ScalableAsyncAgentSwarm(
        PULSAR_URL, TOPIC, DLQ_TOPIC, AGENTS_CONFIG
    )

    async def main():
        # Run the swarm in the background (create_task requires a
        # running event loop, so it must happen inside the coroutine)
        swarm_task = asyncio.create_task(swarm.run())

        # Send tasks to the topic
        client = pulsar.Client(PULSAR_URL)
        producer = client.create_producer(TOPIC)

        for task in TASKS:
            producer.send(task.encode("utf-8"))
            print(f"Sent task: {task}")

        producer.close()
        client.close()

        # Keep the swarm running
        await swarm_task

    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        swarm.shutdown()
@@ -0,0 +1,253 @@
import re

from dotenv import load_dotenv
from tenacity import retry, stop_after_attempt, wait_exponential

from swarms import Agent
from swarms.agents.create_agents_from_yaml import (
    create_agents_from_yaml,
)
from swarms.utils.formatter import formatter
from swarms.utils.litellm import LiteLLM

load_dotenv()


def prepare_yaml_for_parsing(raw_yaml: str) -> str:
    """
    Prepares raw YAML content by fixing spacing and formatting issues.

    Args:
        raw_yaml (str): The raw YAML content extracted from Markdown.

    Returns:
        str: The cleaned YAML content, ready for parsing.
    """
    # Fix sequence items improperly placed on the same line as their key:
    # "key: - value" becomes "key:\n - value"
    fixed_yaml = re.sub(
        r"(\b\w+\b):\s*-\s*", r"\1:\n - ", raw_yaml
    )

    # Ensure a space after colons between a key and its value
    fixed_yaml = re.sub(r"(\S):(\S)", r"\1: \2", fixed_yaml)

    # Remove trailing spaces before newlines
    fixed_yaml = re.sub(r"\s+\n", "\n", fixed_yaml)

    # Replace non-breaking spaces (if any) with regular spaces
    fixed_yaml = fixed_yaml.replace("\xa0", " ")

    return fixed_yaml.strip()
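
# For illustration (not in the original file), the regex fixes in action:
#
# prepare_yaml_for_parsing("tools: - search\nmax_loops:3")
# # -> 'tools:\n - search\nmax_loops: 3'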

def parse_yaml_from_swarm_markdown(markdown_text: str) -> str:
    """
    Extracts YAML content from a Markdown-style 'Auto-Swarm-Builder' block and prepares it for parsing.

    Args:
        markdown_text (str): The Markdown text containing the YAML inside an 'Auto-Swarm-Builder' block.

    Returns:
        str: The normalized YAML content, ready to be parsed.
    """
    # Match the 'Auto-Swarm-Builder' block with YAML inside triple backticks
    pattern = r"```yaml\s*\n(.*?)```"
    match = re.search(pattern, markdown_text, re.DOTALL)

    if not match:
        raise ValueError(
            "No YAML content found in the 'Auto-Swarm-Builder' block."
        )

    raw_yaml = match.group(1).strip()

    # Preprocess and normalize the YAML content
    normalized_yaml = prepare_yaml_for_parsing(raw_yaml)

    return normalized_yaml
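
# For illustration (not in the original file): given a model reply that wraps
# its configuration in a fenced yaml block, the function returns the cleaned
# YAML string between the fences.
#
# reply = 'Here is the swarm:\n```yaml\nagents:\n  - agent_name: "Research-Agent"\n```'
# yaml_str = parse_yaml_from_swarm_markdown(reply)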
AUTO_GEN_PROMPT = """
|
||||
You are a specialized agent responsible for creating YAML configuration files for multi-agent swarms. Your role is to generate well-structured YAML that defines both individual agents and swarm architectures based on user requirements.
|
||||
Output only the yaml nothing else. You will be penalized for making mistakes
|
||||
|
||||
GUIDELINES:
|
||||
1. Each YAML file must contain an `agents` section with at least one agent configuration
|
||||
2. Each agent configuration requires the following mandatory fields:
|
||||
- agent_name (string)
|
||||
- system_prompt (string)
|
||||
|
||||
3. Optional agent fields include:
|
||||
- max_loops (integer)
|
||||
- autosave (boolean)
|
||||
- dashboard (boolean)
|
||||
- verbose (boolean)
|
||||
- dynamic_temperature_enabled (boolean)
|
||||
- saved_state_path (string)
|
||||
- user_name (string)
|
||||
- retry_attempts (integer)
|
||||
- context_length (integer)
|
||||
- return_step_meta (boolean)
|
||||
- output_type (string)
|
||||
- task (string)
|
||||
|
||||
4. When a swarm is needed, include a `swarm_architecture` section with:
|
||||
Mandatory fields:
|
||||
- name (string)
|
||||
- swarm_type (string: "ConcurrentWorkflow" or "SequentialWorkflow") [AgentRearrange, MixtureOfAgents, SpreadSheetSwarm, SequentialWorkflow, ConcurrentWorkflow]
|
||||
|
||||
Optional fields:
|
||||
- description (string)
|
||||
- max_loops (integer)
|
||||
- task (string)
|
||||
|
||||
TEMPLATE STRUCTURE:
|
||||
```yaml
|
||||
agents:
|
||||
- agent_name: "Agent-1-Name"
|
||||
system_prompt: "Detailed system prompt here"
|
||||
max_loops: 1
|
||||
# [additional optional fields]
|
||||
|
||||
- agent_name: "Agent-2-Name"
|
||||
system_prompt: "Detailed system prompt here"
|
||||
# [additional optional fields]
|
||||
|
||||
swarm_architecture:
|
||||
name: "Swarm-Name"
|
||||
description: "Swarm purpose and goals"
|
||||
swarm_type: "ConcurrentWorkflow"
|
||||
max_loops: 5
|
||||
task: "Main swarm task description"
|
||||
```
|
||||
|
||||
VALIDATION RULES:
|
||||
1. All agent names must be unique
|
||||
2. System prompts must be clear and specific to the agent's role
|
||||
3. Integer values must be positive
|
||||
4. Boolean values must be true or false (lowercase)
|
||||
5. File paths should use forward slashes
|
||||
6. Tasks should be specific and aligned with the agent/swarm purpose
|
||||
|
||||
When generating a YAML configuration:
|
||||
1. Ask for specific requirements about the agents and swarm needed
|
||||
2. Determine if a swarm architecture is necessary based on the task complexity
|
||||
3. Generate appropriate system prompts for each agent based on their roles
|
||||
4. Include relevant optional fields based on the use case
|
||||
5. Validate the configuration against all rules before returning
|
||||
|
||||
Example valid YAML configurations are provided below. Use these as references for structure and formatting:
|
||||
|
||||
```yaml
|
||||
|
||||
|
||||
agents:
|
||||
- agent_name: "Data-Analysis-Agent"
|
||||
system_prompt: "You are a specialized data analysis agent focused on processing and interpreting financial data. Provide clear, actionable insights based on the data provided."
|
||||
max_loops: 3
|
||||
autosave: true
|
||||
verbose: true
|
||||
context_length: 100000
|
||||
output_type: "json"
|
||||
task: "Analyze quarterly financial reports and identify trends"
|
||||
|
||||
# Multi-Agent Swarm Example
|
||||
agents:
|
||||
- agent_name: "Research-Agent"
|
||||
system_prompt: "You are a research agent specialized in gathering and summarizing scientific publications. Focus on peer-reviewed sources and provide comprehensive summaries."
|
||||
max_loops: 2
|
||||
context_length: 150000
|
||||
output_type: "str"
|
||||
|
||||
- agent_name: "Analysis-Agent"
|
||||
system_prompt: "You are an analysis agent that processes research summaries and identifies key patterns and insights. Provide detailed analytical reports."
|
||||
max_loops: 3
|
||||
context_length: 200000
|
||||
output_type: "json"
|
||||
|
||||
swarm_architecture:
|
||||
name: "Research-Analysis-Swarm"
|
||||
description: "A swarm for comprehensive research analysis and insight generation"
|
||||
swarm_type: "SequentialWorkflow"
|
||||
max_loops: 5
|
||||
task: "Research and analyze recent developments in quantum computing"
|
||||
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||


def generate_swarm_config(
    task: str,
    file_name: str = "swarm_config_output.yaml",
    model_name: str = "gpt-4o",
    *args,
    **kwargs,
):
    """
    Generates a swarm configuration based on the provided task and model name.

    This function attempts to generate a swarm configuration by running an agent with the specified task and model name.
    It then parses the output into YAML format and creates agents based on the parsed YAML content.

    Args:
        task (str): The task to be performed by the swarm.
        file_name (str, optional): The file name for the output YAML configuration. Defaults to "swarm_config_output.yaml".
        model_name (str, optional): The name of the model to use for the agent. Defaults to "gpt-4o".
        *args: Additional positional arguments to be passed to the agent's run method.
        **kwargs: Additional keyword arguments to be passed to the agent's run method.

    Returns:
        Any: The output of the swarm configuration generation process. This can be a SwarmRouter instance or an error message.
    """
    formatter.print_panel(
        "Auto Generating Swarm...", "Auto Swarm Builder"
    )

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(min=4, max=10),
    )
    def attempt_generate_swarm_config():
        try:
            model = LiteLLM(model_name=model_name)

            # Initialize the agent
            agent = Agent(
                agent_name="Auto-Swarm-Builder",
                system_prompt=AUTO_GEN_PROMPT,
                llm=model,
                max_loops=1,
                dynamic_temperature_enabled=True,
                saved_state_path="swarm_builder.json",
                user_name="swarms_corp",
                output_type="str",
            )

            # Generate output from the agent
            raw_output = agent.run(task, *args, **kwargs)
            yaml_content = parse_yaml_from_swarm_markdown(raw_output)
            print(yaml_content)

            # Create agents from the YAML string
            output = create_agents_from_yaml(
                yaml_string=yaml_content,
                return_type="run_swarm",
            )

            formatter.print_panel(
                "Swarm configuration generated successfully.",
                "Success",
            )

            return output

        except Exception as e:
            formatter.print_panel(
                f"Error generating swarm configuration: {str(e)}",
                "Error",
            )
            raise

    return attempt_generate_swarm_config()
@@ -0,0 +1,105 @@
try:
    import litellm
    from litellm import completion
except ImportError:
    import subprocess

    subprocess.check_call(["pip", "install", "litellm"])
    import litellm
    from litellm import completion

litellm.set_verbose = True


class LiteLLM:
    """
    This class represents a LiteLLM wrapper.
    It is used to interact with the LLM model for various tasks.
    """

    def __init__(
        self,
        model_name: str = "gpt-4o",
        system_prompt: str = None,
        stream: bool = False,
        temperature: float = 0.5,
        max_tokens: int = 4000,
    ):
        """
        Initialize the LiteLLM with the given parameters.

        Args:
            model_name (str, optional): The name of the model to use. Defaults to "gpt-4o".
            system_prompt (str, optional): The system prompt to use. Defaults to None.
            stream (bool, optional): Whether to stream the output. Defaults to False.
            temperature (float, optional): The temperature for the model. Defaults to 0.5.
            max_tokens (int, optional): The maximum number of tokens to generate. Defaults to 4000.
        """
        self.model_name = model_name
        self.system_prompt = system_prompt
        self.stream = stream
        self.temperature = temperature
        self.max_tokens = max_tokens

    def _prepare_messages(self, task: str) -> list:
        """
        Prepare the messages for the given task.

        Args:
            task (str): The task to prepare messages for.

        Returns:
            list: A list of messages prepared for the task.
        """
        messages = []

        if self.system_prompt:  # Check that system_prompt is not None
            messages.append(
                {"role": "system", "content": self.system_prompt}
            )

        messages.append({"role": "user", "content": task})

        return messages

    def run(self, task: str, *args, **kwargs):
        """
        Run the LLM model for the given task.

        Args:
            task (str): The task to run the model for.
            *args: Additional positional arguments to pass to the model.
            **kwargs: Additional keyword arguments to pass to the model.

        Returns:
            str: The content of the response from the model.
        """
        messages = self._prepare_messages(task)

        response = completion(
            model=self.model_name,
            messages=messages,
            stream=self.stream,
            temperature=self.temperature,
            # max_completion_tokens=self.max_tokens,
            max_tokens=self.max_tokens,
            *args,
            **kwargs,
        )
        # Access the content of the first choice
        content = response.choices[0].message.content
        return content

    def __call__(self, task: str, *args, **kwargs):
        """
        Call the LLM model for the given task.

        Args:
            task (str): The task to run the model for.
            *args: Additional positional arguments to pass to the model.
            **kwargs: Additional keyword arguments to pass to the model.

        Returns:
            str: The content of the response from the model.
        """
        return self.run(task, *args, **kwargs)
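
# A minimal usage sketch (illustrative, not part of the original file); the
# prompt strings are hypothetical.
#
# llm = LiteLLM(model_name="gpt-4o-mini", system_prompt="You are concise.")
# print(llm.run("Summarize the benefits of vector databases in one sentence."))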
@@ -1,50 +1,64 @@
 import re


-def extract_code_from_markdown(markdown_content: str) -> str:
+def extract_code_blocks_with_language(markdown_text: str):
     """
-    Extracts code blocks from a Markdown string and returns them as a single string.
+    Extracts all code blocks from Markdown text along with their languages.

     Args:
-    - markdown_content (str): The Markdown content as a string.
+        markdown_text (str): The input Markdown text.

     Returns:
-    - str: A single string containing all the code blocks separated by newlines.
+        list[dict]: A list of dictionaries, each containing:
+            - 'language': The detected language (or 'plaintext' if none specified).
+            - 'content': The content of the code block.
     """
-    # Regular expression for fenced code blocks with optional language specifier
-    pattern = r"```(?:\w+\n)?(.*?)```"
+    # Regex pattern to match code blocks and optional language specifiers
+    pattern = r"```(\w+)?\n(.*?)```"

-    # Check if markdown_content is a string
-    if not isinstance(markdown_content, str):
-        raise TypeError("markdown_content must be a string")
+    # Find all matches (language and content)
+    matches = re.findall(pattern, markdown_text, re.DOTALL)

-    # Find all matches of the pattern
-    matches = re.finditer(pattern, markdown_content, re.DOTALL)
-
-    # Extract the content inside the backticks
+    # Parse results
     code_blocks = []
-    for match in matches:
-        code_block = match.group(1).strip()
-        # Remove any leading or trailing whitespace from the code block
-        code_block = code_block.strip()
-        # Remove any empty lines from the code block
-        code_block = "\n".join(
-            [line for line in code_block.split("\n") if line.strip()]
+    for language, content in matches:
+        language = (
+            language.strip() if language else "plaintext"
+        )  # Default to 'plaintext'
+        code_blocks.append(
+            {"language": language, "content": content.strip()}
         )
-        code_blocks.append(code_block)

-    # Concatenate all code blocks separated by newlines
-    if code_blocks:
-        return "\n\n".join(code_blocks)
-    else:
-        return ""
+    return code_blocks
+
+
+def extract_code_from_markdown(
+    markdown_text: str, language: str = None
+):
+    """
+    Extracts the content of code blocks for a specific language, or all blocks if no language is specified.
+
+    Args:
+        markdown_text (str): The input Markdown text.
+        language (str, optional): The language to filter by (e.g., 'yaml', 'python').
+
+    Returns:
+        str: The concatenated content of matched code blocks, or an empty string if none found.
+    """
+    # Get all code blocks with detected languages
+    code_blocks = extract_code_blocks_with_language(markdown_text)
+
+    # Filter by language if specified
+    if language:
+        code_blocks = [
+            block["content"]
+            for block in code_blocks
+            if block["language"] == language
+        ]
+    else:
+        code_blocks = [
+            block["content"] for block in code_blocks
+        ]  # Include all blocks
+
+    # Return concatenated content
+    return "\n\n".join(code_blocks) if code_blocks else ""
+
+
+# example = """
+# hello im an agent
+# ```bash
+# pip install swarms
+# ```
+# """
+# print(extract_code_from_markdown(example))
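# For illustration (not in the original diff): filtering by language with the
# new function.
#
# md = "intro\n```python\nprint('hi')\n```\n```yaml\nkey: value\n```"
# extract_code_from_markdown(md, language="yaml")   # -> "key: value"
# extract_code_blocks_with_language(md)             # -> both blocks with languages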
@@ -1,292 +0,0 @@
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.distributed as dist
import torch.nn as nn
from loguru import logger


@dataclass
class StarAttentionConfig:
    """Configuration for the StarAttention module.

    Attributes:
        hidden_size: Dimension of the model's hidden states
        num_attention_heads: Number of attention heads
        num_hosts: Number of hosts in the distributed system
        block_size: Size of each context block
        anchor_size: Size of the anchor block
        dropout_prob: Dropout probability (default: 0.1)
        layer_norm_eps: Layer normalization epsilon (default: 1e-12)
    """

    hidden_size: int
    num_attention_heads: int
    num_hosts: int
    block_size: int
    anchor_size: int
    dropout_prob: float = 0.1
    layer_norm_eps: float = 1e-12


class StarAttention(nn.Module):
    """
    Implementation of the Star Attention mechanism for distributed inference.

    The module implements a two-phase attention mechanism:
    1. Local Context Encoding with Anchor Blocks
    2. Query Encoding and Output Generation with Global Attention
    """

    def __init__(self, config: StarAttentionConfig):
        super().__init__()

        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"Hidden size {config.hidden_size} not divisible by number of attention "
                f"heads {config.num_attention_heads}"
            )

        self.config = config
        self.head_dim = (
            config.hidden_size // config.num_attention_heads
        )

        # Initialize components
        self.query = nn.Linear(config.hidden_size, config.hidden_size)
        self.key = nn.Linear(config.hidden_size, config.hidden_size)
        self.value = nn.Linear(config.hidden_size, config.hidden_size)

        self.dropout = nn.Dropout(config.dropout_prob)
        self.layer_norm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps
        )

        # KV cache for storing computed key/value pairs
        self.kv_cache = {}

        logger.info(
            f"Initialized StarAttention with config: {config}"
        )

    def _split_heads(
        self, tensor: torch.Tensor, num_heads: int
    ) -> torch.Tensor:
        """Split the last dimension into (num_heads, head_dim)."""
        batch_size, seq_len, _ = tensor.size()
        tensor = tensor.view(
            batch_size, seq_len, num_heads, self.head_dim
        )
        # Transpose to (batch_size, num_heads, seq_len, head_dim)
        return tensor.transpose(1, 2)

    def _merge_heads(self, tensor: torch.Tensor) -> torch.Tensor:
        """Merge the head dimension back into hidden_size."""
        batch_size, _, seq_len, _ = tensor.size()
        tensor = tensor.transpose(1, 2)
        return tensor.reshape(
            batch_size, seq_len, self.config.hidden_size
        )

    def _compute_attention_scores(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute attention scores and weighted values."""
        # Scaled dot-product attention
        scores = torch.matmul(
            query, key.transpose(-2, -1)
        ) / math.sqrt(self.head_dim)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))

        attention_probs = torch.nn.functional.softmax(scores, dim=-1)
        attention_probs = self.dropout(attention_probs)

        context = torch.matmul(attention_probs, value)

        return context, attention_probs

    def phase1_local_context_encoding(
        self,
        input_ids: torch.Tensor,
        host_id: int,
        device: Union[str, torch.device] = "cuda",
    ) -> None:
        """
        Phase 1: Local Context Encoding with Anchor Blocks

        Args:
            input_ids: Input tensor of shape (batch_size, seq_len)
            host_id: ID of the current host
            device: Device to run computations on
        """
        logger.debug(f"Starting Phase 1 on host {host_id}")

        # Calculate block assignments
        block_start = host_id * self.config.block_size
        block_end = block_start + self.config.block_size

        # Get local block
        local_block = input_ids[:, block_start:block_end].to(device)

        # Get anchor block (first block)
        anchor_block = input_ids[:, : self.config.anchor_size].to(
            device
        )

        # Compute KV pairs for the local block
        local_hidden = self.layer_norm(local_block)
        local_key = self._split_heads(
            self.key(local_hidden), self.config.num_attention_heads
        )
        local_value = self._split_heads(
            self.value(local_hidden), self.config.num_attention_heads
        )

        # Store in KV cache
        self.kv_cache[host_id] = {
            "key": local_key,
            "value": local_value,
            "anchor_key": (
                None
                if host_id == 0
                else self._split_heads(
                    self.key(self.layer_norm(anchor_block)),
                    self.config.num_attention_heads,
                )
            ),
        }

        logger.debug(
            f"Phase 1 complete on host {host_id}. KV cache shapes - "
            f"key: {local_key.shape}, value: {local_value.shape}"
        )

    def phase2_query_encoding(
        self,
        query_input: torch.Tensor,
        host_id: int,
        is_query_host: bool,
        device: Union[str, torch.device] = "cuda",
    ) -> Optional[torch.Tensor]:
        """
        Phase 2: Query Encoding and Output Generation

        Args:
            query_input: Query tensor of shape (batch_size, seq_len, hidden_size)
            host_id: ID of the current host
            is_query_host: Whether this host is the query host
            device: Device to run computations on

        Returns:
            Output tensor if this is the query host, None otherwise
        """
        logger.debug(f"Starting Phase 2 on host {host_id}")

        # Transform query
        query_hidden = self.layer_norm(query_input)
        query = self._split_heads(
            self.query(query_hidden), self.config.num_attention_heads
        )

        # Compute local attention scores
        local_context, local_probs = self._compute_attention_scores(
            query,
            self.kv_cache[host_id]["key"],
            self.kv_cache[host_id]["value"],
        )

        if not is_query_host:
            # Non-query hosts send their local attention statistics
            dist.send(local_probs, dst=self.config.num_hosts - 1)
            return None

        # Query host aggregates attention from all hosts
        all_attention_probs = [local_probs]
        for src_rank in range(self.config.num_hosts - 1):
            probs = torch.empty_like(local_probs)
            dist.recv(probs, src=src_rank)
            all_attention_probs.append(probs)

        # Compute global attention (note: the aggregate is computed but
        # left unused in this implementation)
        torch.mean(torch.stack(all_attention_probs), dim=0)

        # Final output computation
        output = self._merge_heads(local_context)
        output = self.dropout(output)

        logger.debug(
            f"Phase 2 complete on host {host_id}. Output shape: {output.shape}"
        )

        return output

    def forward(
        self,
        input_ids: torch.Tensor,
        query_input: torch.Tensor,
        host_id: int,
        is_query_host: bool,
        device: Union[str, torch.device] = "cuda",
    ) -> Optional[torch.Tensor]:
        """
        Forward pass of the StarAttention module.

        Args:
            input_ids: Input tensor of shape (batch_size, seq_len)
            query_input: Query tensor of shape (batch_size, seq_len, hidden_size)
            host_id: ID of the current host
            is_query_host: Whether this host is the query host
            device: Device to run computations on

        Returns:
            Output tensor if this is the query host, None otherwise
        """
        # Phase 1: Local Context Encoding
        self.phase1_local_context_encoding(input_ids, host_id, device)

        # Phase 2: Query Encoding and Output Generation
        return self.phase2_query_encoding(
            query_input, host_id, is_query_host, device
        )


# Example forward pass
config = StarAttentionConfig(
    hidden_size=768,
    num_attention_heads=12,
    num_hosts=3,
    block_size=512,
    anchor_size=128,
)

# Initialize model
model = StarAttention(config)

# Example input tensors
batch_size = 4
seq_len = 512
input_ids = torch.randint(
    0, 1000, (batch_size, seq_len)
)  # Random input IDs
query_input = torch.randn(
    batch_size, seq_len, config.hidden_size
)  # Random query input

# Example forward pass for the query host (host_id = 2)
output = model(
    input_ids=input_ids,
    query_input=query_input,
    host_id=2,
    is_query_host=True,
    device="cpu",
)

print(output)