pull/652/head
Your Name 1 month ago
parent 8b9e424dd5
commit 7a66cbd705

@ -113,15 +113,37 @@ Here are some example scripts to get you started. For more comprehensive documen
| Swarms Examples | A collection of simple examples to demonstrate Swarms capabilities. | Basic Usage | [https://github.com/The-Swarm-Corporation/swarms-examples?tab=readme-ov-file](https://github.com/The-Swarm-Corporation/swarms-examples?tab=readme-ov-file) |
| Cookbook | A comprehensive guide with recipes for various use cases and scenarios. | Advanced Usage | [https://github.com/The-Swarm-Corporation/Cookbook](https://github.com/The-Swarm-Corporation/Cookbook) |
---
## `Agent` Class
The `Agent` class is a fundamental component of the Swarms framework, designed to execute tasks autonomously. It combines LLMs, tools, and long-term memory to create a full-stack agent. The `Agent` class is highly customizable, allowing fine-grained control over its behavior and interactions.
### `run` Method
The `run` method is the primary entry point for executing tasks with an `Agent` instance. It accepts a task string as its main input and processes it according to the agent's configuration. It also accepts an `img` parameter, such as `img="image_filepath.png"`, to process images when a vision-language model (VLM) such as `GPT4VisionAPI` is attached.
## Simple Example
```python
from swarms import Agent

agent = Agent(
    agent_name="Stock-Analysis-Agent",
    model_name="gpt-4o-mini",
    max_loops="auto",
    interactive=True,
    streaming_on=True,
)

agent.run("What is the current market trend for tech stocks?")
```
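The optional `img` argument described above can be forwarded only when an image is actually supplied. The following is a minimal sketch, not the library's own code: `run_task` is a hypothetical helper, and `EchoAgent` is a stand-in that only mimics the `run(task, img=...)` call shape so the example runs without API keys.

```python
from pathlib import Path
from typing import Any, Optional


def run_task(agent: Any, task: str, img: Optional[str] = None) -> Any:
    """Call agent.run, passing `img` only when an image path is given.

    `agent` is assumed to expose run(task, img=...), as described for
    `Agent` above. `run_task` itself is a hypothetical helper.
    """
    if img is None:
        return agent.run(task)
    # Fail early with a clear error instead of sending a bad path to the model.
    if not Path(img).is_file():
        raise FileNotFoundError(f"Image not found: {img}")
    return agent.run(task, img=img)


class EchoAgent:
    """Stand-in agent (an assumption for illustration, not a Swarms class)."""

    def run(self, task: str, img: Optional[str] = None) -> str:
        return f"task={task!r}, img={img!r}"


if __name__ == "__main__":
    print(run_task(EchoAgent(), "What is the current market trend for tech stocks?"))
```

With a real `Agent` that has a vision-capable model attached, the same helper would be called as `run_task(agent, "Describe this chart", img="image_filepath.png")`.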
### Settings and Customization
The `Agent` class offers a range of settings to tailor its behavior to specific needs. Some key settings include:

@ -1,4 +0,0 @@
timestamp,transaction_hash,from_address,to_address,value_eth,gas_used,gas_price_gwei,block_number,analysis
2024-11-27T13:50:35,ddbb665bc75fe848e7ce3d3ce1729243e92466c38ca407deccce8bf629987652,0x267be1C1D684F78cb4F6a176C4911b741E4Ffdc0,0xa40dFEE99E1C85DC97Fdc594b16A460717838703,3200.0,21000,19.968163737,21281878,"Transaction Analysis: This transaction represents a significant transfer of value in the Ethereum network with 3200 ETH (~$6.72 million USD at the current rate) moved from one address to another. It is essential to note that this transaction did not involve smart contract interaction, suggesting it could be a straightforward transfer of funds rather than part of a more complex operation. Looking at the broader market context, large transactions like this can potentially indicate major investment activities or redistribution of assets, which can have ripple effects in the market. If this transaction is part of a larger pattern of significant transfers, it could suggest substantial liquidity moving in the Ethereum ecosystem, possibly affecting the ETH prices. From a DeFi point of view, since there's no contract interaction, it's difficult to infer any direct implications. However, given the substantial value involved, it could be a step in preparation for involvement in DeFi protocols or a move from one DeFi platform to another by a large investor. The transaction fee paid, calculated from the given Gas Used and Gas Price, appears to be within reasonable range. This suggests that the transaction was not rushed and that the sender was willing to wait for this transaction to be confirmed, which might hint towards the non-urgent nature of the transaction. As for potential risk factors or security concerns, the transaction itself appears to be standard and doesn't raise any immediate red flags. However, the parties involved should always be cautious about the address security, maintaining privacy, and avoiding social engineering attacks. 
For traders and investors, this transaction can be interpreted as a potential bullish sign if it signifies increased liquidity and investment in the Ethereum market, especially if it's followed by similar large transfers. However, due to the anonymous nature of the transaction, it's critical to combine this with other market indicators and not to rely solely on transaction analysis for investment decisions."
2024-11-27T13:52:23,b98bcbf6d57a158b67a126d8f023766e03fb15c3e74becc1189d4244fda61a13,0xEae7380dD4CeF6fbD1144F49E4D1e6964258A4F4,0x28C6c06298d514Db089934071355E5743bf21d60,401.99463589018103,21000,14.978063737,21281887,"Ethereum-Analysis-Agent: Transaction Analysis: This transaction marks a significant transfer of 401.99 ETH, approximately $845,000 at the current rate. The transaction did not involve any smart contract interaction, suggesting a simple fund transfer rather than a complicated operation or interaction with a DeFi protocol. From a broader market perspective, this transaction is meaningful but not as potentially impactful as larger transactions. It can nonetheless be part of a larger pattern of asset movement within the Ethereum ecosystem. If this transaction is part of larger investment activities, it could suggest an increase in demand for ETH and potentially impact its price. Without contract interaction, it's challenging to assess direct implications for DeFi protocols. However, the substantial ETH transfer could suggest a step towards participation in DeFi activities, or a movement of funds between different DeFi platforms. The transaction fee appears reasonable, given the Gas Used and Gas Price. This implies that the transaction wasn't urgent, and the sender was willing to wait for the transaction to be confirmed, indicating a non-critical movement of funds. In terms of security and risk factors, there are no immediate concerns from the transaction itself. Nevertheless, as with any crypto transaction, the parties involved should ensure secure storage of their keys, maintain privacy, and be wary of potential phishing or social engineering attacks. For traders and investors, this transaction could be seen as a bullish sign if it forms part of a trend of increased investment activities in the Ethereum market. 
However, it's important to remember that transaction analysis should be combined with other market indicators due to the anonymous nature of blockchain transactions."
2024-11-27T13:59:47,a985b74fd3dfee09cbe4a2e6890509e583a3f0ce13f68c98e82996e0f66428be,0xf7858Da8a6617f7C6d0fF2bcAFDb6D2eeDF64840,0xA294cCa691e4C83B1fc0c8d63D9a3eeF0A196DE1,136.0668,494665.408728,3635.46,21000,18.866443971,21281923,"1. MARKET CONTEXT The transaction of 136.07 ETH, equivalent to $494,665.41, is a significant movement in the Ethereum market. However, compared to the daily trading volume of Ethereum, which often exceeds billions of dollars, this transaction is not large enough to significantly impact the ETH price on its own. 2. BEHAVIORAL ANALYSIS The transaction does not appear to be a protocol movement as there is no contract interaction involved. It could be a whale movement, given the substantial amount of ETH transferred. However, without additional information about the wallets involved, it's difficult to definitively determine the nature of the transaction. The gas price of 18.87 Gwei is relatively standard, suggesting that the transaction was not urgent or time-sensitive. 3. RISK & IMPLICATIONS The transaction does not show signs of market manipulation or unusual activity. The absence of contract interaction suggests that this transaction does not directly involve DeFi protocols, reducing the risk of smart contract vulnerabilities or DeFi-related risks. However, the large amount of ETH transferred could potentially influence market sentiment if it is part of a larger trend of similar transactions. 4. STRATEGIC INSIGHTS Traders should note this transaction as part of the broader market activity. While a single transaction of this size is unlikely to significantly impact the market, a series of similar transactions could indicate a larger trend. If this is part of a larger movement of ETH out of exchanges, it could suggest a decrease in selling pressure, which could be bullish for ETH. Conversely, if this is part of a larger movement into exchanges, it could indicate an increase in selling pressure, which could be bearish for ETH. 
Traders should monitor the market for further similar transactions to gain a better understanding of the potential market trends."
Can't render this file because it has a wrong number of fields in line 4.

@ -0,0 +1,666 @@
"""
GraphSwarm: A production-grade framework for orchestrating swarms of agents
Author: Claude
License: MIT
Version: 2.0.0
"""
import asyncio
import json
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, Union
import chromadb
import networkx as nx
from loguru import logger
from pydantic import BaseModel, Field
from swarms import Agent
# Configure logging
logger.add(
"graphswarm.log",
rotation="500 MB",
retention="10 days",
level="INFO",
format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}",
)
class AgentOutput(BaseModel):
"""Structured output from an agent."""
agent_name: str
timestamp: float = Field(default_factory=time.time)
output: Any
execution_time: float
error: Optional[str] = None
metadata: Dict = Field(default_factory=dict)
class SwarmOutput(BaseModel):
"""Structured output from the entire swarm."""
timestamp: float = Field(default_factory=time.time)
outputs: Dict[str, AgentOutput]
execution_time: float
success: bool
error: Optional[str] = None
metadata: Dict = Field(default_factory=dict)
class SwarmMemory:
"""Vector-based memory system for GraphSwarm using ChromaDB."""
def __init__(self, collection_name: str = "swarm_memories"):
"""Initialize SwarmMemory with ChromaDB."""
self.client = chromadb.Client()
# Get or create collection
self.collection = self.client.get_or_create_collection(
name=collection_name,
metadata={"description": "GraphSwarm execution memories"},
)
def store_execution(self, task: str, result: SwarmOutput):
"""Store execution results in vector memory."""
try:
# Create metadata
metadata = {
"timestamp": datetime.now().isoformat(),
"success": result.success,
"execution_time": result.execution_time,
"agent_sequence": json.dumps(
[name for name in result.outputs.keys()]
),
"error": result.error if result.error else "",
}
# Create document from outputs
document = {
"task": task,
"outputs": json.dumps(
{
name: {
"output": str(output.output),
"execution_time": output.execution_time,
"error": output.error,
}
for name, output in result.outputs.items()
}
),
}
# Store in ChromaDB
self.collection.add(
documents=[json.dumps(document)],
metadatas=[metadata],
ids=[f"exec_{datetime.now().timestamp()}"],
)
print("added to database")
logger.info(f"Stored execution in memory: {task}")
except Exception as e:
logger.error(
f"Failed to store execution in memory: {str(e)}"
)
def get_similar_executions(self, task: str, limit: int = 5):
"""Retrieve similar past executions."""
try:
# Query ChromaDB for similar executions
results = self.collection.query(
query_texts=[task],
n_results=limit,
include=["documents", "metadatas"],
)
print(results)
if not results["documents"]:
return []
# Process results
executions = []
for doc, metadata in zip(
results["documents"][0], results["metadatas"][0]
):
doc_dict = json.loads(doc)
executions.append(
{
"task": doc_dict["task"],
"outputs": json.loads(doc_dict["outputs"]),
"success": metadata["success"],
"execution_time": metadata["execution_time"],
"agent_sequence": json.loads(
metadata["agent_sequence"]
),
"timestamp": metadata["timestamp"],
}
)
return executions
except Exception as e:
logger.error(
f"Failed to retrieve similar executions: {str(e)}"
)
return []
    def get_optimal_sequence(self, task: str) -> Optional[List[str]]:
        """Get the agent sequence of the most similar successful execution."""
        similar_executions = self.get_similar_executions(task)
        logger.debug(f"similar_executions {similar_executions}")
        if not similar_executions:
            return None
        # Keep only successful executions; ChromaDB returns query results
        # ordered by similarity, so no further sorting is applied here
        successful_execs = [
            ex for ex in similar_executions if ex["success"]
        ]
        if not successful_execs:
            return None
        # Return the sequence from the most similar successful execution
        return successful_execs[0]["agent_sequence"]
def clear_memory(self):
"""Clear all memories."""
self.client.delete_collection(self.collection.name)
self.collection = self.client.get_or_create_collection(
name=self.collection.name
)
class GraphSwarm:
"""
Enhanced framework for creating and managing swarms of collaborative agents.
"""
def __init__(
self,
agents: Union[
List[Agent], List[Tuple[Agent, List[str]]], None
] = None,
max_workers: Optional[int] = None,
swarm_name: str = "Collaborative Agent Swarm",
memory_collection: str = "swarm_memory",
):
"""Initialize GraphSwarm."""
self.graph = nx.DiGraph()
self.agents: Dict[str, Agent] = {}
self.dependencies: Dict[str, List[str]] = {}
self.executor = ThreadPoolExecutor(max_workers=max_workers)
self.swarm_name = swarm_name
self.memory_collection = memory_collection
self.memory = SwarmMemory(collection_name=memory_collection)
if agents:
self.initialize_agents(agents)
logger.info(f"Initialized GraphSwarm: {swarm_name}")
def initialize_agents(
self,
agents: Union[List[Agent], List[Tuple[Agent, List[str]]]],
):
"""Initialize agents and their dependencies."""
try:
# Handle list of Agents or (Agent, dependencies) tuples
for item in agents:
if isinstance(item, tuple):
agent, dependencies = item
else:
agent, dependencies = item, []
if not isinstance(agent, Agent):
raise ValueError(
f"Expected Agent object, got {type(agent)}"
)
self.agents[agent.agent_name] = agent
self.dependencies[agent.agent_name] = dependencies
self.graph.add_node(agent.agent_name, agent=agent)
# Add dependencies
for dep in dependencies:
if dep not in self.agents:
raise ValueError(
f"Dependency {dep} not found for agent {agent.agent_name}"
)
self.graph.add_edge(dep, agent.agent_name)
self._validate_graph()
except Exception as e:
logger.error(f"Failed to initialize agents: {str(e)}")
raise
def _validate_graph(self):
"""Validate the agent dependency graph."""
if not self.graph.nodes():
raise ValueError("No agents added to swarm")
if not nx.is_directed_acyclic_graph(self.graph):
cycles = list(nx.simple_cycles(self.graph))
raise ValueError(
f"Agent dependency graph contains cycles: {cycles}"
)
def _get_agent_role_description(self, agent_name: str) -> str:
"""Generate a description of the agent's role in the swarm."""
predecessors = list(self.graph.predecessors(agent_name))
successors = list(self.graph.successors(agent_name))
position = (
"initial"
if not predecessors
else ("final" if not successors else "intermediate")
)
role = f"""You are {agent_name}, a specialized agent in the {self.swarm_name}.
Position: {position} agent in the workflow
Your relationships:"""
if predecessors:
role += (
f"\nYou receive input from: {', '.join(predecessors)}"
)
if successors:
role += f"\nYour output will be used by: {', '.join(successors)}"
return role
def _generate_workflow_context(self) -> str:
"""Generate a description of the entire workflow."""
execution_order = list(nx.topological_sort(self.graph))
workflow = f"""Workflow Overview of {self.swarm_name}:
Processing Order:
{' -> '.join(execution_order)}
Agent Roles:
"""
for agent_name in execution_order:
predecessors = list(self.graph.predecessors(agent_name))
successors = list(self.graph.successors(agent_name))
workflow += f"\n\n{agent_name}:"
if predecessors:
workflow += (
f"\n- Receives from: {', '.join(predecessors)}"
)
if successors:
workflow += f"\n- Sends to: {', '.join(successors)}"
if not predecessors and not successors:
workflow += "\n- Independent agent"
return workflow
def _build_agent_prompt(
self, agent_name: str, task: str, context: Dict = None
) -> str:
"""Build a comprehensive prompt for the agent including role and context."""
prompt_parts = [
self._get_agent_role_description(agent_name),
"\nWorkflow Context:",
self._generate_workflow_context(),
"\nYour Task:",
task,
]
if context:
prompt_parts.extend(
["\nContext from Previous Agents:", str(context)]
)
prompt_parts.extend(
[
"\nInstructions:",
"1. Process the task according to your role",
"2. Consider the input from previous agents when available",
"3. Provide clear, structured output",
"4. Remember that your output will be used by subsequent agents",
"\nResponse Guidelines:",
"- Provide clear, well-organized output",
"- Include relevant details and insights",
"- Highlight key findings",
"- Flag any uncertainties or issues",
]
)
return "\n".join(prompt_parts)
async def _execute_agent(
self, agent_name: str, task: str, context: Dict = None
) -> AgentOutput:
"""Execute a single agent."""
start_time = time.time()
agent = self.agents[agent_name]
try:
# Build comprehensive prompt
full_prompt = self._build_agent_prompt(
agent_name, task, context
)
logger.debug(f"Prompt for {agent_name}:\n{full_prompt}")
# Execute agent
output = await asyncio.to_thread(agent.run, full_prompt)
return AgentOutput(
agent_name=agent_name,
output=output,
execution_time=time.time() - start_time,
metadata={
"task": task,
"context": context,
"position_in_workflow": list(
nx.topological_sort(self.graph)
).index(agent_name),
},
)
except Exception as e:
logger.error(
f"Error executing agent {agent_name}: {str(e)}"
)
return AgentOutput(
agent_name=agent_name,
output=None,
execution_time=time.time() - start_time,
error=str(e),
metadata={"task": task},
)
async def execute(self, task: str) -> SwarmOutput:
"""
Execute the entire swarm of agents with memory integration.
Args:
task: Initial task to execute
Returns:
SwarmOutput: Structured output from all agents
"""
start_time = time.time()
outputs = {}
success = True
error = None
try:
# Get similar past executions
similar_executions = self.memory.get_similar_executions(
task, limit=3
)
optimal_sequence = self.memory.get_optimal_sequence(task)
# Get base execution order
base_execution_order = list(
nx.topological_sort(self.graph)
)
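            # Determine final execution order. A remembered sequence is
            # reused only if it covers exactly the current agents and keeps
            # every agent after its dependencies.
            position = {
                name: i for i, name in enumerate(optimal_sequence or [])
            }
            if (
                optimal_sequence
                and len(optimal_sequence) == len(base_execution_order)
                and set(optimal_sequence) == set(base_execution_order)
                and all(
                    position[dep] < position[agent]
                    for dep, agent in self.graph.edges()
                )
            ):
                logger.info(
                    f"Using optimal sequence from memory: {optimal_sequence}"
                )
                execution_order = optimal_sequence
            else:
                execution_order = base_execution_order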
# Get historical context if available
historical_context = {}
if similar_executions:
best_execution = similar_executions[0]
if best_execution["success"]:
historical_context = {
"similar_task": best_execution["task"],
"previous_outputs": best_execution["outputs"],
"execution_time": best_execution[
"execution_time"
],
"success_patterns": self._extract_success_patterns(
similar_executions
),
}
# Execute agents in order
for agent_name in execution_order:
try:
# Get context from dependencies and history
agent_context = {
"dependencies": {
dep: outputs[dep].output
for dep in self.graph.predecessors(
agent_name
)
if dep in outputs
},
"historical": historical_context,
"position": execution_order.index(agent_name),
"total_agents": len(execution_order),
}
# Execute agent with enhanced context
output = await self._execute_agent(
agent_name, task, agent_context
)
outputs[agent_name] = output
# Update historical context with current execution
if output.output:
historical_context.update(
{
f"current_{agent_name}_output": output.output
}
)
# Check for errors
if output.error:
success = False
error = f"Agent {agent_name} failed: {output.error}"
# Try to recover using memory
if similar_executions:
recovery_output = self._attempt_recovery(
agent_name, task, similar_executions
)
if recovery_output:
outputs[agent_name] = recovery_output
success = True
error = None
continue
break
except Exception as agent_error:
logger.error(
f"Error executing agent {agent_name}: {str(agent_error)}"
)
success = False
error = f"Agent {agent_name} failed: {str(agent_error)}"
break
# Create result
result = SwarmOutput(
outputs=outputs,
execution_time=time.time() - start_time,
success=success,
error=error,
metadata={
"task": task,
"used_optimal_sequence": optimal_sequence
is not None,
"similar_executions_found": len(
similar_executions
),
"execution_order": execution_order,
"historical_context_used": bool(
historical_context
),
},
)
# Store execution in memory
await self._store_execution_async(task, result)
return result
except Exception as e:
logger.error(f"Swarm execution failed: {str(e)}")
return SwarmOutput(
outputs=outputs,
execution_time=time.time() - start_time,
success=False,
error=str(e),
metadata={"task": task},
)
def run(self, task: str) -> SwarmOutput:
"""Synchronous interface to execute the swarm."""
return asyncio.run(self.execute(task))
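    def _extract_success_patterns(
        self, similar_executions: List[Dict]
    ) -> Dict:
        """Extract success patterns from similar executions."""
        successful_execs = [
            ex for ex in similar_executions if ex["success"]
        ]
        if not successful_execs:
            return {}
        # Count how often each agent sequence led to a successful run.
        # This is computed inline because the class defines no
        # _find_common_sequences or _extract_strategies helpers.
        sequence_counts: Dict[Tuple[str, ...], int] = {}
        for ex in successful_execs:
            key = tuple(ex["agent_sequence"])
            sequence_counts[key] = sequence_counts.get(key, 0) + 1
        return {
            # Most frequent sequences first
            "common_sequences": [
                list(seq)
                for seq, _ in sorted(
                    sequence_counts.items(), key=lambda kv: -kv[1]
                )
            ],
            "avg_execution_time": sum(
                ex["execution_time"] for ex in successful_execs
            )
            / len(successful_execs),
        }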
def _attempt_recovery(
self,
failed_agent: str,
task: str,
similar_executions: List[Dict],
) -> Optional[AgentOutput]:
"""Attempt to recover from failure using memory."""
for execution in similar_executions:
if (
execution["success"]
and failed_agent in execution["outputs"]
):
historical_output = execution["outputs"][failed_agent]
return AgentOutput(
agent_name=failed_agent,
output=historical_output["output"],
execution_time=historical_output[
"execution_time"
],
metadata={
"recovered_from_memory": True,
"original_task": execution["task"],
},
)
return None
async def _store_execution_async(
self, task: str, result: SwarmOutput
):
"""Asynchronously store execution in memory."""
try:
await asyncio.to_thread(
self.memory.store_execution, task, result
)
except Exception as e:
logger.error(
f"Failed to store execution in memory: {str(e)}"
)
    def add_agent(
        self, agent: Agent, dependencies: Optional[List[str]] = None
    ):
        """Add a new agent to the swarm."""
        dependencies = dependencies or []
        # Check dependencies before mutating any state, so a failed call
        # does not leave a half-registered agent behind
        for dep in dependencies:
            if dep not in self.agents:
                raise ValueError(f"Dependency {dep} not found")
        self.agents[agent.agent_name] = agent
        self.dependencies[agent.agent_name] = dependencies
        self.graph.add_node(agent.agent_name, agent=agent)
        for dep in dependencies:
            self.graph.add_edge(dep, agent.agent_name)
        self._validate_graph()
if __name__ == "__main__":
try:
# Create agents
data_collector = Agent(
agent_name="Market-Data-Collector",
model_name="gpt-4o-mini",
max_loops=1,
streaming_on=True,
)
trend_analyzer = Agent(
agent_name="Market-Trend-Analyzer",
model_name="gpt-4o-mini",
max_loops=1,
streaming_on=True,
)
report_generator = Agent(
agent_name="Investment-Report-Generator",
model_name="gpt-4o-mini",
max_loops=1,
streaming_on=True,
)
# Create swarm
swarm = GraphSwarm(
agents=[
(data_collector, []),
(trend_analyzer, ["Market-Data-Collector"]),
(report_generator, ["Market-Trend-Analyzer"]),
],
swarm_name="Market Analysis Intelligence Network",
)
# Run the swarm
result = swarm.run(
"Analyze current market trends for tech stocks and provide investment recommendations"
)
# Print results
print(f"Execution success: {result.success}")
print(f"Total time: {result.execution_time:.2f} seconds")
for agent_name, output in result.outputs.items():
print(f"\nAgent: {agent_name}")
print(f"Output: {output.output}")
if output.error:
print(f"Error: {output.error}")
except Exception as error:
logger.error(error)
raise error
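The dependency handling in `GraphSwarm` (build a directed graph, reject cycles, run agents in topological order) can be tried in isolation. The following is a minimal sketch, assuming Python 3.9+ so that the standard-library `graphlib` module can stand in for `networkx`; `execution_order` and `respects_dependencies` are hypothetical helper names, not part of this commit or of Swarms.

```python
from graphlib import CycleError, TopologicalSorter
from typing import Dict, List


def execution_order(dependencies: Dict[str, List[str]]) -> List[str]:
    """Order agent names so that each agent comes after its dependencies."""
    try:
        # TopologicalSorter takes {node: [predecessors]}, the same shape as
        # the GraphSwarm.dependencies mapping.
        return list(TopologicalSorter(dependencies).static_order())
    except CycleError as exc:
        # exc.args[1] holds the nodes that form the detected cycle.
        raise ValueError(
            f"Agent dependency graph contains cycles: {exc.args[1]}"
        ) from exc


def respects_dependencies(
    order: List[str], dependencies: Dict[str, List[str]]
) -> bool:
    """Check that `order` lists every agent once, after its dependencies."""
    position = {name: index for index, name in enumerate(order)}
    if len(position) != len(order) or set(position) != set(dependencies):
        return False
    return all(
        dep in position and position[dep] < position[name]
        for name, deps in dependencies.items()
        for dep in deps
    )


if __name__ == "__main__":
    deps = {
        "Market-Data-Collector": [],
        "Market-Trend-Analyzer": ["Market-Data-Collector"],
        "Investment-Report-Generator": ["Market-Trend-Analyzer"],
    }
    print(execution_order(deps))
```

A check along the lines of `respects_dependencies` is one way to decide whether an agent sequence recalled from memory is safe to reuse before it replaces the topological order.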

@ -7,7 +7,7 @@ from swarms import Agent
from swarms.prompts.finance_agent_sys_prompt import (
    FINANCIAL_AGENT_SYS_PROMPT,
)
from async_executor import HighSpeedExecutor
from new_features_examples.async_executor import HighSpeedExecutor
load_dotenv()

File diff suppressed because it is too large

@ -0,0 +1,7 @@
from swarms import Agent

Agent(
    agent_name="Stock-Analysis-Agent",
    model_name="gpt-4o-mini",
    max_loops=1,
).run("What are 5 hft algorithms")

@ -0,0 +1,276 @@
import asyncio
import pulsar
from pulsar import ConsumerType
from loguru import logger
from swarms import Agent
from typing import List, Dict, Any
import json
class ScalableAsyncAgentSwarm:
"""
A scalable, asynchronous swarm of agents leveraging Apache Pulsar for inter-agent communication.
Provides load balancing, health monitoring, dead letter queues, and centralized logging.
"""
def __init__(
self,
pulsar_url: str,
topic: str,
dlq_topic: str,
agents_config: List[Dict[str, Any]],
):
"""
Initializes the async swarm with agents.
Args:
pulsar_url (str): The URL of the Apache Pulsar broker.
topic (str): The main topic for task distribution.
dlq_topic (str): The Dead Letter Queue topic for failed messages.
agents_config (List[Dict[str, Any]]): List of agent configurations with `name`, `description`, and `model_name`.
"""
self.pulsar_url = pulsar_url
self.topic = topic
self.dlq_topic = dlq_topic
self.agents_config = agents_config
self.client = pulsar.Client(pulsar_url)
self.consumer = self.client.subscribe(
topic,
subscription_name="swarm-task-sub",
consumer_type=ConsumerType.Shared,
)
self.dlq_producer = self.client.create_producer(dlq_topic)
self.response_logger = []
self.agents = [
self.create_agent(config) for config in agents_config
]
self.agent_index = 0
logger.info(
"Swarm initialized with agents: {}",
[agent["name"] for agent in agents_config],
)
def create_agent(
self, agent_config: Dict[str, Any]
) -> Dict[str, Any]:
"""
Creates a new agent configuration with asynchronous capabilities.
Args:
agent_config (Dict[str, Any]): Configuration dictionary with agent details.
Returns:
Dict[str, Any]: A dictionary containing agent metadata and functionality.
"""
agent_name = agent_config["name"]
description = agent_config["description"]
model_name = agent_config.get("model_name", "gpt-4o-mini")
class AsyncAgent:
"""
An asynchronous agent that processes tasks and communicates via Apache Pulsar.
"""
def __init__(
self, name: str, description: str, model_name: str
):
self.name = name
self.description = description
self.agent = Agent(
agent_name=name,
model_name=model_name,
max_loops="auto",
interactive=True,
streaming_on=True,
)
logger.info(
f"Initialized agent '{name}' - {description}"
)
async def process_task(
self, message: str
) -> Dict[str, Any]:
"""
Processes a single task using the agent.
Args:
message (str): The task message.
Returns:
Dict[str, Any]: JSON-formatted response.
"""
try:
logger.info(
f"Agent {self.name} processing task: {message}"
)
response = await asyncio.to_thread(
self.agent.run, message
)
logger.info(f"Agent {self.name} completed task.")
return {
"agent_name": self.name,
"response": response,
}
except Exception as e:
logger.error(
f"Agent {self.name} encountered an error: {e}"
)
return {"agent_name": self.name, "error": str(e)}
return {
"name": agent_name,
"instance": AsyncAgent(
agent_name, description, model_name
),
}
async def distribute_task(self, message: str):
"""
Distributes a task to the next available agent using round-robin.
Args:
message (str): The task message.
"""
agent = self.agents[self.agent_index]
self.agent_index = (self.agent_index + 1) % len(self.agents)
try:
response = await agent["instance"].process_task(message)
self.log_response(response)
except Exception as e:
logger.error(
f"Error processing task by agent {agent['name']}: {e}"
)
self.send_to_dlq(message)
async def monitor_health(self):
"""
Periodically monitors the health of agents.
"""
while True:
logger.info("Performing health check for all agents.")
for agent in self.agents:
logger.info(f"Agent {agent['name']} is online.")
await asyncio.sleep(10)
def send_to_dlq(self, message: str):
"""
Sends a failed message to the Dead Letter Queue (DLQ).
Args:
message (str): The message to send to the DLQ.
"""
try:
self.dlq_producer.send(message.encode("utf-8"))
logger.info("Message sent to Dead Letter Queue.")
except Exception as e:
logger.error(f"Failed to send message to DLQ: {e}")
def log_response(self, response: Dict[str, Any]):
"""
Logs the response to a centralized list for later analysis.
Args:
response (Dict[str, Any]): The agent's response.
"""
self.response_logger.append(response)
logger.info(f"Response logged: {response}")
    async def listen_and_distribute(self):
        """
        Listens to the main Pulsar topic and distributes tasks to agents.
        """
        while True:
            # consumer.receive() blocks, so run it in a worker thread to keep
            # the event loop (and the health monitor) responsive
            msg = await asyncio.to_thread(self.consumer.receive)
            try:
                message = msg.data().decode("utf-8")
                logger.info(f"Received task: {message}")
                await self.distribute_task(message)
                self.consumer.acknowledge(msg)
            except Exception as e:
                logger.error(f"Error processing message: {e}")
                # Decode leniently so a malformed payload can still reach the DLQ
                self.send_to_dlq(
                    msg.data().decode("utf-8", errors="replace")
                )
                self.consumer.negative_acknowledge(msg)
async def run(self):
"""
Runs the swarm asynchronously with health monitoring and task distribution.
"""
logger.info("Starting the async swarm...")
task_listener = asyncio.create_task(
self.listen_and_distribute()
)
health_monitor = asyncio.create_task(self.monitor_health())
await asyncio.gather(task_listener, health_monitor)
def shutdown(self):
"""
Safely shuts down the swarm and logs all responses.
"""
logger.info("Shutting down the swarm...")
self.client.close()
with open("responses.json", "w") as f:
json.dump(self.response_logger, f, indent=4)
logger.info("Responses saved to 'responses.json'.")
# from scalable_agent_swarm import ScalableAsyncAgentSwarm # Assuming your swarm class is saved here
if __name__ == "__main__":
# Example Configuration
PULSAR_URL = "pulsar://localhost:6650"
TOPIC = "stock-analysis"
DLQ_TOPIC = "stock-analysis-dlq"
# Agents configuration
AGENTS_CONFIG = [
{
"name": "Stock-Analysis-Agent-1",
"description": "Analyzes stock trends.",
"model_name": "gpt-4o-mini",
},
{
"name": "Stock-News-Agent",
"description": "Summarizes stock news.",
"model_name": "gpt-4o-mini",
},
{
"name": "Tech-Trends-Agent",
"description": "Tracks tech sector trends.",
"model_name": "gpt-4o-mini",
},
]
# Tasks to send
TASKS = [
"Analyze the trend for tech stocks in Q4 2024",
"Summarize the latest news on the S&P 500",
"Identify the top-performing sectors in the stock market",
"Provide a forecast for AI-related stocks for 2025",
]
# Initialize and run the swarm
swarm = ScalableAsyncAgentSwarm(
PULSAR_URL, TOPIC, DLQ_TOPIC, AGENTS_CONFIG
)
    try:
        # Send tasks to the topic. The swarm's subscription was created in
        # its constructor, so these messages are retained until consumed.
        client = pulsar.Client(PULSAR_URL)
        producer = client.create_producer(TOPIC)
        for task in TASKS:
            producer.send(task.encode("utf-8"))
            print(f"Sent task: {task}")
        producer.close()
        client.close()
        # Run the swarm until interrupted. asyncio.run() needs a coroutine,
        # and create_task() cannot be called outside a running event loop.
        asyncio.run(swarm.run())
    except KeyboardInterrupt:
        swarm.shutdown()

@ -5,7 +5,7 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry]
name = "swarms"
version = "6.3.7"
version = "6.4.7"
description = "Swarms - Pytorch"
license = "MIT"
authors = ["Kye Gomez <kye@apac.ai>"]

@ -0,0 +1,253 @@
import re
from dotenv import load_dotenv
from tenacity import retry, stop_after_attempt, wait_exponential
from swarms import Agent
from swarms.agents.create_agents_from_yaml import (
create_agents_from_yaml,
)
from swarms.utils.formatter import formatter
from swarms.utils.litellm import LiteLLM
load_dotenv()
def prepare_yaml_for_parsing(raw_yaml: str) -> str:
"""
Prepares raw YAML content by fixing spacing and formatting issues.
Args:
raw_yaml (str): The raw YAML content extracted from Markdown.
Returns:
str: The cleaned YAML content ready for parsing.
"""
# Fix sequence items that are improperly placed on the same line as their key
fixed_yaml = re.sub(
r"(\b\w+\b):\s*-\s*", r"\1:\n - ", raw_yaml
) # Fix "key: - value" to "key:\n - value"
# Ensure proper spacing after colons
fixed_yaml = re.sub(
r"(\S):(\S)", r"\1: \2", fixed_yaml
) # Ensure space after colons
# Remove trailing spaces before newlines
fixed_yaml = re.sub(r"\s+\n", "\n", fixed_yaml)
# Replace non-breaking spaces (if any) with regular spaces
fixed_yaml = fixed_yaml.replace("\xa0", " ")
return fixed_yaml.strip()
def parse_yaml_from_swarm_markdown(markdown_text: str) -> str:
"""
Extracts YAML content from a Markdown-style 'Auto-Swarm-Builder' block and normalizes it.
Args:
markdown_text (str): The Markdown text containing the YAML inside 'Auto-Swarm-Builder' block.
Returns:
str: The normalized YAML text, ready to be parsed.
"""
# Match the first fenced yaml block (triple backticks) in the text
pattern = r"```yaml\s*\n(.*?)```"
match = re.search(pattern, markdown_text, re.DOTALL)
if not match:
raise ValueError(
"No YAML content found in the 'Auto-Swarm-Builder' block."
)
raw_yaml = match.group(1).strip()
# Preprocess and normalize the YAML content
normalized_yaml = prepare_yaml_for_parsing(raw_yaml)
return normalized_yaml
AUTO_GEN_PROMPT = """
You are a specialized agent responsible for creating YAML configuration files for multi-agent swarms. Your role is to generate well-structured YAML that defines both individual agents and swarm architectures based on user requirements.
Output only the YAML, nothing else. You will be penalized for making mistakes.
GUIDELINES:
1. Each YAML file must contain an `agents` section with at least one agent configuration
2. Each agent configuration requires the following mandatory fields:
- agent_name (string)
- system_prompt (string)
3. Optional agent fields include:
- max_loops (integer)
- autosave (boolean)
- dashboard (boolean)
- verbose (boolean)
- dynamic_temperature_enabled (boolean)
- saved_state_path (string)
- user_name (string)
- retry_attempts (integer)
- context_length (integer)
- return_step_meta (boolean)
- output_type (string)
- task (string)
4. When a swarm is needed, include a `swarm_architecture` section with:
Mandatory fields:
- name (string)
- swarm_type (string, one of: "AgentRearrange", "MixtureOfAgents", "SpreadSheetSwarm", "SequentialWorkflow", "ConcurrentWorkflow")
Optional fields:
- description (string)
- max_loops (integer)
- task (string)
TEMPLATE STRUCTURE:
```yaml
agents:
- agent_name: "Agent-1-Name"
system_prompt: "Detailed system prompt here"
max_loops: 1
# [additional optional fields]
- agent_name: "Agent-2-Name"
system_prompt: "Detailed system prompt here"
# [additional optional fields]
swarm_architecture:
name: "Swarm-Name"
description: "Swarm purpose and goals"
swarm_type: "ConcurrentWorkflow"
max_loops: 5
task: "Main swarm task description"
```
VALIDATION RULES:
1. All agent names must be unique
2. System prompts must be clear and specific to the agent's role
3. Integer values must be positive
4. Boolean values must be true or false (lowercase)
5. File paths should use forward slashes
6. Tasks should be specific and aligned with the agent/swarm purpose
When generating a YAML configuration:
1. Ask for specific requirements about the agents and swarm needed
2. Determine if a swarm architecture is necessary based on the task complexity
3. Generate appropriate system prompts for each agent based on their roles
4. Include relevant optional fields based on the use case
5. Validate the configuration against all rules before returning
Example valid YAML configurations are provided below. Use these as references for structure and formatting:
```yaml
agents:
- agent_name: "Data-Analysis-Agent"
system_prompt: "You are a specialized data analysis agent focused on processing and interpreting financial data. Provide clear, actionable insights based on the data provided."
max_loops: 3
autosave: true
verbose: true
context_length: 100000
output_type: "json"
task: "Analyze quarterly financial reports and identify trends"
# Multi-Agent Swarm Example
agents:
- agent_name: "Research-Agent"
system_prompt: "You are a research agent specialized in gathering and summarizing scientific publications. Focus on peer-reviewed sources and provide comprehensive summaries."
max_loops: 2
context_length: 150000
output_type: "str"
- agent_name: "Analysis-Agent"
system_prompt: "You are an analysis agent that processes research summaries and identifies key patterns and insights. Provide detailed analytical reports."
max_loops: 3
context_length: 200000
output_type: "json"
swarm_architecture:
name: "Research-Analysis-Swarm"
description: "A swarm for comprehensive research analysis and insight generation"
swarm_type: "SequentialWorkflow"
max_loops: 5
task: "Research and analyze recent developments in quantum computing"
```
"""
def generate_swarm_config(
task: str,
file_name: str = "swarm_config_output.yaml",
model_name: str = "gpt-4o",
*args,
**kwargs,
):
"""
Generates a swarm configuration based on the provided task and model name.
This function attempts to generate a swarm configuration by running an agent with the specified task and model name.
It then parses the output into YAML format and creates agents based on the parsed YAML content.
Args:
task (str): The task to be performed by the swarm.
file_name (str, optional): The file name for the output YAML configuration. Defaults to "swarm_config_output.yaml".
model_name (str, optional): The name of the model to use for the agent. Defaults to "gpt-4o".
*args: Additional positional arguments to be passed to the agent's run method.
**kwargs: Additional keyword arguments to be passed to the agent's run method.
Returns:
Any: The output of running the generated swarm. Errors are re-raised after the retry attempts are exhausted rather than returned.
"""
formatter.print_panel(
"Auto Generating Swarm...", "Auto Swarm Builder"
)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(min=4, max=10),
)
def attempt_generate_swarm_config():
try:
model = LiteLLM(model_name=model_name)
# Initialize the agent
agent = Agent(
agent_name="Auto-Swarm-Builder",
system_prompt=AUTO_GEN_PROMPT,
llm=model,
max_loops=1,
dynamic_temperature_enabled=True,
saved_state_path="swarm_builder.json",
user_name="swarms_corp",
output_type="str",
)
# Generate output from the agent
raw_output = agent.run(task, *args, **kwargs)
yaml_content = parse_yaml_from_swarm_markdown(raw_output)
print(yaml_content)
# Create agents from the YAML file
output = create_agents_from_yaml(
yaml_string=yaml_content,
return_type="run_swarm",
)
formatter.print_panel(
"Swarm configuration generated successfully.",
"Success",
)
return output
except Exception as e:
formatter.print_panel(
f"Error generating swarm configuration: {str(e)}",
"Error",
)
raise
return attempt_generate_swarm_config()
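
For reviewers who want to check the extraction step in isolation, here is a minimal sketch of pulling the first fenced yaml block out of a model response. It uses only `re` and returns the YAML as text, leaving parsing to the caller. The function name `extract_yaml_block` and the sample response are illustrative assumptions, not code from this PR; the fence marker is built programmatically only so the example nests cleanly inside Markdown.

```python
import re

# Three backticks, built programmatically so this example can sit inside a fenced block
FENCE = "`" * 3


def extract_yaml_block(markdown_text: str) -> str:
    """Return the body of the first fenced yaml block as a string, or raise ValueError."""
    pattern = FENCE + r"yaml\s*\n(.*?)" + FENCE
    match = re.search(pattern, markdown_text, re.DOTALL)
    if not match:
        raise ValueError("No YAML content found.")
    return match.group(1).strip()


# Hypothetical model response containing one fenced yaml block
sample = (
    "Here is the config:\n"
    + FENCE
    + 'yaml\nagents:\n  - agent_name: "Demo-Agent"\n'
    + FENCE
    + "\n"
)

yaml_text = extract_yaml_block(sample)
print(yaml_text)
```

A response without a fenced yaml block raises `ValueError`, which is the condition the CLI's "No YAML content found" message is keyed on.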

@ -1,22 +1,168 @@
import os
from typing import Any, Callable, Dict, List, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import yaml
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
from pydantic import (
BaseModel,
Field,
field_validator,
)
from swarms.utils.loguru_logger import initialize_logger
from swarms.structs.agent import Agent
from swarms.structs.swarm_router import SwarmRouter
from swarms.utils.litellm import LiteLLM
logger = initialize_logger(log_folder="create_agents_from_yaml")
class AgentConfig(BaseModel):
agent_name: str
system_prompt: str
model_name: Optional[str] = None
max_loops: int = Field(default=1, ge=1)
autosave: bool = True
dashboard: bool = False
verbose: bool = False
dynamic_temperature_enabled: bool = False
saved_state_path: Optional[str] = None
user_name: str = "default_user"
retry_attempts: int = Field(default=3, ge=1)
context_length: int = Field(default=100000, ge=1000)
return_step_meta: bool = False
output_type: str = "str"
auto_generate_prompt: bool = False
artifacts_on: bool = False
artifacts_file_extension: str = ".md"
artifacts_output_path: str = ""
@field_validator("system_prompt")
@classmethod
def validate_system_prompt(cls, v):
if not v or not isinstance(v, str) or len(v.strip()) == 0:
raise ValueError(
"System prompt must be a non-empty string"
)
return v
class SwarmConfig(BaseModel):
name: str
description: str
max_loops: int = Field(default=1, ge=1)
swarm_type: str
task: Optional[str] = None
flow: Optional[Dict] = None
autosave: bool = True
return_json: bool = False
rules: str = ""
@field_validator("swarm_type")
@classmethod
def validate_swarm_type(cls, v):
valid_types = {
"SequentialWorkflow",
"ConcurrentWorkflow",
"AgentRearrange",
"MixtureOfAgents",
"auto",
}
if v not in valid_types:
raise ValueError(
f"Swarm type must be one of: {valid_types}"
)
return v
class YAMLConfig(BaseModel):
agents: List[AgentConfig] = Field(..., min_length=1)
swarm_architecture: Optional[SwarmConfig] = None
model_config = {
"extra": "forbid" # Prevent additional fields not in the model
}
def load_yaml_safely(
yaml_file: Optional[str] = None, yaml_string: Optional[str] = None
) -> Dict:
"""Safely load and validate YAML configuration using Pydantic."""
try:
if yaml_string:
config_dict = yaml.safe_load(yaml_string)
elif yaml_file:
if not os.path.exists(yaml_file):
raise FileNotFoundError(
f"YAML file {yaml_file} not found."
)
with open(yaml_file, "r") as file:
config_dict = yaml.safe_load(file)
else:
raise ValueError(
"Either yaml_file or yaml_string must be provided"
)
# Validate using Pydantic
YAMLConfig(**config_dict)
return config_dict
except yaml.YAMLError as e:
raise ValueError(f"Error parsing YAML: {str(e)}")
except Exception as e:
raise ValueError(f"Error validating configuration: {str(e)}")
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type((ConnectionError, TimeoutError)),
before_sleep=lambda retry_state: logger.info(
f"Retrying after error: {retry_state.outcome.exception()}"
),
)
def create_agent_with_retry(
agent_config: Dict, model: LiteLLM
) -> Agent:
"""Create an agent with retry logic for handling transient failures."""
try:
validated_config = AgentConfig(**agent_config)
agent = Agent(
agent_name=validated_config.agent_name,
system_prompt=validated_config.system_prompt,
llm=model,
max_loops=validated_config.max_loops,
autosave=validated_config.autosave,
dashboard=validated_config.dashboard,
verbose=validated_config.verbose,
dynamic_temperature_enabled=validated_config.dynamic_temperature_enabled,
saved_state_path=validated_config.saved_state_path,
user_name=validated_config.user_name,
retry_attempts=validated_config.retry_attempts,
context_length=validated_config.context_length,
return_step_meta=validated_config.return_step_meta,
output_type=validated_config.output_type,
auto_generate_prompt=validated_config.auto_generate_prompt,
artifacts_on=validated_config.artifacts_on,
artifacts_file_extension=validated_config.artifacts_file_extension,
artifacts_output_path=validated_config.artifacts_output_path,
)
return agent
except Exception as e:
logger.error(
f"Error creating agent {agent_config.get('agent_name', 'unknown')}: {str(e)}"
)
raise
def create_agents_from_yaml(
model: Callable = None,
yaml_file: str = "agents.yaml",
yaml_string: str = None,
return_type: str = "auto",
*args,
**kwargs,
) -> Union[
SwarmRouter,
Agent,
@ -25,171 +171,99 @@ def create_agents_from_yaml(
List[Dict[str, Any]],
]:
"""
Create agents and/or SwarmRouter based on configurations defined in a YAML file.
This function dynamically creates agents and a SwarmRouter (if specified) based on the
configuration in the YAML file. It adapts its behavior based on the presence of a
swarm architecture and the number of agents defined.
Args:
model (Callable): The language model to be used by the agents.
yaml_file (str): Path to the YAML file containing agent and swarm configurations.
return_type (str): Determines the return value. Options are:
"auto" (default): Automatically determine the most appropriate return type.
"swarm": Return SwarmRouter if present, otherwise a single agent or list of agents.
"agents": Return a list of agents (or a single agent if only one is defined).
"both": Return both SwarmRouter (or single agent) and list of agents.
"tasks": Return task results if any tasks were executed.
"run_swarm": Run the swarm and return its output.
*args: Additional positional arguments for agent or SwarmRouter customization.
**kwargs: Additional keyword arguments for agent or SwarmRouter customization.
Returns:
Union[SwarmRouter, Agent, List[Agent], Tuple[Union[SwarmRouter, Agent], List[Agent]], List[Dict[str, Any]]]:
The return type depends on the 'return_type' argument and the configuration in the YAML file.
Raises:
FileNotFoundError: If the specified YAML file is not found.
ValueError: If the YAML configuration is invalid or if an invalid return_type is specified.
Create agents and/or SwarmRouter based on configurations defined in a YAML file or string.
"""
try:
logger.info(
f"Checking if the YAML file {yaml_file} exists..."
)
if not os.path.exists(yaml_file):
logger.error(f"YAML file {yaml_file} not found.")
raise FileNotFoundError(
f"YAML file {yaml_file} not found."
)
logger.info(f"Loading YAML file {yaml_file}")
with open(yaml_file, "r") as file:
config = yaml.safe_load(file)
if "agents" not in config:
logger.error(
"The YAML configuration does not contain 'agents'."
)
raise ValueError(
"The YAML configuration does not contain 'agents'."
)
agents = []
task_results = []
swarm_router = None
agents = []
task_results = []
try:
# Load and validate configuration
config = load_yaml_safely(yaml_file, yaml_string)
# Create agents
# Create agents with retry logic
for agent_config in config["agents"]:
logger.info(
f"Creating agent: {agent_config['agent_name']}"
)
if "system_prompt" not in agent_config:
logger.error(
f"System prompt is missing for agent: {agent_config['agent_name']}"
)
raise ValueError(
f"System prompt is missing for agent: {agent_config['agent_name']}"
if "model_name" in agent_config:
model_instance = LiteLLM(
model_name=agent_config["model_name"]
)
else:
model_name = "gpt-4o"
model_instance = LiteLLM(model_name=model_name)
agent = Agent(
agent_name=agent_config["agent_name"],
system_prompt=agent_config["system_prompt"],
llm=model,
max_loops=agent_config.get("max_loops", 1),
autosave=agent_config.get("autosave", True),
dashboard=agent_config.get("dashboard", False),
verbose=agent_config.get("verbose", False),
dynamic_temperature_enabled=agent_config.get(
"dynamic_temperature_enabled", False
),
saved_state_path=agent_config.get("saved_state_path"),
user_name=agent_config.get(
"user_name", "default_user"
),
retry_attempts=agent_config.get("retry_attempts", 1),
context_length=agent_config.get(
"context_length", 100000
),
return_step_meta=agent_config.get(
"return_step_meta", False
),
output_type=agent_config.get("output_type", "str"),
auto_generate_prompt=agent_config.get(
"auto_generate_prompt", "False"
),
artifacts_on=agent_config.get(
"artifacts_on", "False"
),
artifacts_file_extension=agent_config.get(
"artifacts_file_extension", ".md"
),
artifacts_output_path=agent_config.get(
"artifacts_output_path", ""
),
*args,
**kwargs,
agent = create_agent_with_retry(
agent_config, model_instance
)
logger.info(
f"Agent {agent_config['agent_name']} created successfully."
)
agents.append(agent)
# Create SwarmRouter if swarm_architecture is present
swarm_router = None
# Create SwarmRouter if specified
if "swarm_architecture" in config:
swarm_config = config["swarm_architecture"]
swarm_router = SwarmRouter(
name=swarm_config["name"],
description=swarm_config["description"],
max_loops=swarm_config["max_loops"],
agents=agents,
swarm_type=swarm_config["swarm_type"],
task=swarm_config.get("task"),
flow=swarm_config.get("flow"),
autosave=swarm_config.get("autosave"),
return_json=swarm_config.get("return_json"),
rules=swarm_config.get("rules", "") * args,
**kwargs,
)
logger.info(
f"SwarmRouter '{swarm_config['name']}' created successfully."
try:
swarm_config = SwarmConfig(
**config["swarm_architecture"]
)
swarm_router = SwarmRouter(
name=swarm_config.name,
description=swarm_config.description,
max_loops=swarm_config.max_loops,
agents=agents,
swarm_type=swarm_config.swarm_type,
task=swarm_config.task,
flow=swarm_config.flow,
autosave=swarm_config.autosave,
return_json=swarm_config.return_json,
rules=swarm_config.rules,
)
logger.info(
f"SwarmRouter '{swarm_config.name}' created successfully."
)
except Exception as e:
logger.error(f"Error creating SwarmRouter: {str(e)}")
raise ValueError(
f"Failed to create SwarmRouter: {str(e)}"
)
# Handle return types with improved error checking
valid_return_types = {
"auto",
"swarm",
"agents",
"both",
"tasks",
"run_swarm",
}
if return_type not in valid_return_types:
raise ValueError(
f"Invalid return_type. Must be one of: {valid_return_types}"
)
# Define function to run SwarmRouter
def run_swarm_router(
task: str = (
swarm_config.get("task")
if "swarm_architecture" in config
else None
),
):
if swarm_router:
try:
output = swarm_router.run(task)
print(output)
logger.info(
f"Output for SwarmRouter '{swarm_config['name']}': {output}"
)
return output
except Exception as e:
logger.error(
f"Error running task for SwarmRouter '{swarm_config['name']}': {e}"
)
raise e
else:
logger.error("SwarmRouter not created.")
raise ValueError("SwarmRouter not created.")
if return_type in ("run_swarm", "swarm"):
if not swarm_router:
raise ValueError(
"Cannot run swarm: SwarmRouter not created."
)
try:
return swarm_router.run(
config["swarm_architecture"]["task"]
)
except Exception as e:
logger.error(f"Error running SwarmRouter: {str(e)}")
raise
# Handle return types
# Return appropriate type based on configuration
if return_type == "auto":
if swarm_router:
return swarm_router
elif len(agents) == 1:
return agents[0]
else:
return agents
return (
swarm_router
if swarm_router
else (agents[0] if len(agents) == 1 else agents)
)
elif return_type == "swarm":
return (
swarm_router
@ -205,24 +279,10 @@ def create_agents_from_yaml(
else agents[0] if len(agents) == 1 else agents
), agents
elif return_type == "tasks":
if not task_results:
logger.warning(
"No tasks were executed. Returning empty list."
)
return task_results
elif return_type == "run_swarm":
if swarm_router:
return run_swarm_router()
else:
logger.error(
"Cannot run swarm: SwarmRouter not created."
)
raise ValueError(
"Cannot run swarm: SwarmRouter not created."
)
else:
logger.error(f"Invalid return_type: {return_type}")
raise ValueError(f"Invalid return_type: {return_type}")
except Exception as e:
logger.error(f"An error occurred: {e}")
raise e
logger.error(
f"Critical error in create_agents_from_yaml: {str(e)}"
)
raise
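
When dispatching on a string such as `return_type`, note that an expression of the form `x == "a" or "b"` is always truthy in Python, because the non-empty string `"b"` is evaluated on its own. A membership test states "one of these values" directly. A minimal illustration follows; both helper names are hypothetical and exist only for this comparison.

```python
def should_run_always_true(return_type: str) -> bool:
    # Parsed as (return_type == "run_swarm") or "swarm"; the string "swarm" is truthy
    return bool(return_type == "run_swarm" or "swarm")


def should_run(return_type: str) -> bool:
    # Membership test: true only for the listed values
    return return_type in ("run_swarm", "swarm")


for value in ("run_swarm", "swarm", "agents", "tasks"):
    print(value, should_run_always_true(value), should_run(value))
```

With the first form, every `return_type` (including `"agents"` and `"tasks"`) takes the run branch, so the later return-type branches are never reached.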

@ -1,244 +1,348 @@
import argparse
import os
import subprocess
import time
import webbrowser
from rich.console import Console
from rich.panel import Panel
from rich.progress import Progress, SpinnerColumn, TextColumn
from rich.table import Table
from rich.text import Text
from swarms.cli.onboarding_process import OnboardingProcess
from swarms.agents.auto_generate_swarm_config import (
generate_swarm_config,
)
from swarms.agents.create_agents_from_yaml import (
create_agents_from_yaml,
)
import subprocess
from swarms.cli.onboarding_process import OnboardingProcess
from swarms.utils.formatter import formatter
# Initialize console with custom styling
console = Console()
ASCII_ART = """
_________
/ _____/_ _ _______ _______ _____ ______
\_____ \\ \/ \/ /\__ \\_ __ \/ \ / ___/
/ \\ / / __ \| | \/ Y Y \\___ \
/_______ / \/\_/ (____ /__| |__|_| /____ >
\/ \/ \/ \/
class SwarmCLIError(Exception):
"""Custom exception for Swarm CLI errors"""
pass
# Color scheme
COLORS = {
"primary": "red",
"secondary": "#FF6B6B",
"accent": "#4A90E2",
"success": "#2ECC71",
"warning": "#F1C40F",
"error": "#E74C3C",
"text": "#FFFFFF",
}
ASCII_ART = """
"""
# Function to display the ASCII art in red
def create_spinner(text: str) -> Progress:
"""Create a custom spinner with the given text."""
return Progress(
SpinnerColumn(style=COLORS["primary"]),
TextColumn("[{task.description}]", style=COLORS["text"]),
console=console,
)
def show_ascii_art():
text = Text(ASCII_ART, style="bold cyan")
console.print(text)
"""Display the ASCII art with a glowing effect."""
panel = Panel(
Text(ASCII_ART, style=f"bold {COLORS['primary']}"),
border_style=COLORS["secondary"],
title="[bold]Welcome to Swarms[/bold]",
subtitle="[dim]Power to the Swarms[/dim]",
)
console.print(panel)
# Help command
def show_help():
console.print(
"""
[bold cyan]Swarms CLI - Help[/bold cyan]
[bold magenta]Commands:[/bold magenta]
[bold white]onboarding[/bold white] : Starts the onboarding process
[bold white]help[/bold white] : Shows this help message
[bold white]get-api-key[/bold white] : Retrieves your API key from the platform
[bold white]check-login[/bold white] : Checks if you're logged in and starts the cache
[bold white]read-docs[/bold white] : Redirects you to swarms cloud documentation!
[bold white]run-agents[/bold white] : Run your Agents from your specified yaml file. Specify the yaml file with path the `--yaml-file` arg. Example: `--yaml-file agents.yaml`
[bold white]generate-prompt[/bold white] : Generate a prompt through automated prompt engineering. Requires an OPENAI Key in your `.env` Example: --prompt "Generate a prompt for an agent to analyze legal docs"
[bold white]auto-upgrade[/bold white] : Automatically upgrades Swarms to the latest version
[bold white]book-call[/bold white] : Book a strategy session with our team to discuss your use case and get personalized guidance
For more details, visit: https://docs.swarms.world
"""
def create_command_table() -> Table:
"""Create a beautifully formatted table of commands."""
table = Table(
show_header=True,
header_style=f"bold {COLORS['primary']}",
border_style=COLORS["secondary"],
title="Available Commands",
padding=(0, 2),
)
# [bold white]add-agent[/bold white] : Add an agent to the marketplace under your name. Must have a Dockerfile + your agent.yaml to publish. Learn more Here: https://docs.swarms.world/en/latest/swarms_cloud/vision/
table.add_column("Command", style="bold white")
table.add_column("Description", style="dim white")
commands = [
("onboarding", "Start the interactive onboarding process"),
("help", "Display this help message"),
("get-api-key", "Retrieve your API key from the platform"),
("check-login", "Verify login status and initialize cache"),
("run-agents", "Execute agents from your YAML configuration"),
("auto-upgrade", "Update Swarms to the latest version"),
("book-call", "Schedule a strategy session with our team"),
("autoswarm", "Generate and execute an autonomous swarm"),
]
# Fetch API key from platform
def get_api_key():
for cmd, desc in commands:
table.add_row(cmd, desc)
return table
def show_help():
"""Display a beautifully formatted help message."""
console.print(
"[bold yellow]Opening the API key retrieval page...[/bold yellow]"
"\n[bold]Swarms CLI - Command Reference[/bold]\n",
style=COLORS["primary"],
)
# Simulating API key retrieval process by opening the website
import webbrowser
webbrowser.open("https://swarms.world/platform/api-keys")
time.sleep(2)
console.print(create_command_table())
console.print(
"[bold green]Your API key is available on the dashboard.[/bold green]"
"\n[dim]For detailed documentation, visit: https://docs.swarms.world[/dim]"
)
# Redirect to docs
def redirect_to_docs():
console.print(
"[bold yellow]Opening the Docs page...[/bold yellow]"
def show_error(message: str, help_text: str = None):
"""Display error message in a formatted panel"""
error_panel = Panel(
f"[bold red]{message}[/bold red]",
title="Error",
border_style="red",
)
# Simulating API key retrieval process by opening the website
import webbrowser
console.print(error_panel)
webbrowser.open("https://docs.swarms.world")
time.sleep(2)
if help_text:
console.print(f"\n[yellow] {help_text}[/yellow]")
# Redirect to docs
def redirect_to_call():
def execute_with_spinner(action: callable, text: str):
"""Execute an action with a spinner animation and return its result."""
with create_spinner(text) as progress:
task = progress.add_task(text, total=None)
result = action()
progress.remove_task(task)
return result
def get_api_key():
"""Retrieve API key with visual feedback."""
with create_spinner("Opening API key portal...") as progress:
task = progress.add_task("Opening browser...")
webbrowser.open("https://swarms.world/platform/api-keys")
time.sleep(1)
progress.remove_task(task)
console.print(
"[bold yellow]Opening the Call page...[/bold yellow]"
f"\n[{COLORS['success']}]✓ API key page opened in your browser[/{COLORS['success']}]"
)
# Simulating API key retrieval process by opening the website
import webbrowser
webbrowser.open("https://cal.com/swarms/swarms-strategy-session")
time.sleep(2)
# Check and start cache (login system simulation)
def check_login():
"""Verify login status with enhanced visual feedback."""
cache_file = "cache.txt"
if os.path.exists(cache_file):
with open(cache_file, "r") as f:
cache_content = f.read()
if cache_content == "logged_in":
if f.read() == "logged_in":
console.print(
f"[{COLORS['success']}]✓ Authentication verified[/{COLORS['success']}]"
)
return True
with create_spinner("Authenticating...") as progress:
task = progress.add_task("Initializing session...")
time.sleep(1)
with open(cache_file, "w") as f:
f.write("logged_in")
progress.remove_task(task)
console.print(
f"[{COLORS['success']}]✓ Login successful![/{COLORS['success']}]"
)
return True
def run_autoswarm(task: str, model: str):
"""Run autoswarm with enhanced error handling"""
try:
console.print(
"[yellow]Initializing autoswarm configuration...[/yellow]"
)
# Set LiteLLM verbose mode for debugging
import litellm
litellm.set_verbose = True
# Validate inputs
if not task or task.strip() == "":
raise SwarmCLIError("Task cannot be empty")
if not model or model.strip() == "":
raise SwarmCLIError("Model name cannot be empty")
# Attempt to generate swarm configuration
console.print(
f"[yellow]Generating swarm for task: {task}[/yellow]"
)
result = generate_swarm_config(task=task, model_name=model)
if result:
console.print(
"[bold green]You are already logged in.[/bold green]"
"[green]✓ Swarm configuration generated successfully![/green]"
)
else:
console.print(
"[bold red]You are not logged in.[/bold red]"
raise SwarmCLIError(
"Failed to generate swarm configuration"
)
except Exception as e:
if "No YAML content found" in str(e):
show_error(
"Failed to generate YAML configuration",
"This might be due to an API key issue or invalid model configuration.\n"
+ "1. Check if your OpenAI API key is set correctly\n"
+ "2. Verify the model name is valid\n"
+ "3. Try running with --model gpt-4",
)
else:
show_error(
f"Error during autoswarm execution: {str(e)}",
"For debugging, try:\n"
+ "1. Check your API keys are set correctly\n"
+ "2. Verify your network connection\n"
+ "3. Try a different model",
)
else:
console.print("[bold yellow]Logging in...[/bold yellow]")
time.sleep(2)
with open(cache_file, "w") as f:
f.write("logged_in")
console.print("[bold green]Login successful![/bold green]")
def check_and_upgrade_version():
console.print(
"[bold yellow]Checking for Swarms updates...[/bold yellow]"
)
try:
# Check for updates using pip
"""Check for updates with visual progress."""
def check_update():
result = subprocess.run(
["pip", "list", "--outdated", "--format=freeze"],
capture_output=True,
text=True,
)
outdated_packages = result.stdout.splitlines()
return result.stdout.splitlines()
# Check if Swarms is outdated
for package in outdated_packages:
if package.startswith("swarms=="):
console.print(
"[bold magenta]New version available! Upgrading...[/bold magenta]"
outdated = execute_with_spinner(
check_update, "Checking for updates..."
)
for package in outdated:
if package.startswith("swarms=="):
console.print(
f"[{COLORS['warning']}]↑ Update available![/{COLORS['warning']}]"
)
with create_spinner("Upgrading Swarms...") as progress:
task = progress.add_task(
"Installing latest version..."
)
subprocess.run(
["pip", "install", "--upgrade", "swarms"],
check=True,
)
console.print(
"[bold green]Swarms upgraded successfully![/bold green]"
)
return
progress.remove_task(task)
console.print(
f"[{COLORS['success']}]✓ Swarms upgraded successfully![/{COLORS['success']}]"
)
return
console.print(
"[bold green]Swarms is up-to-date.[/bold green]"
)
except Exception as e:
console.print(
f"[bold red]Error checking for updates: {e}[/bold red]"
)
console.print(
f"[{COLORS['success']}]✓ Swarms is up to date![/{COLORS['success']}]"
)
# Main CLI handler
def main():
parser = argparse.ArgumentParser(description="Swarms Cloud CLI")
# Adding arguments for different commands
parser.add_argument(
"command",
choices=[
"onboarding",
"help",
"get-api-key",
"check-login",
"run-agents",
"generate-prompt", # Added new command for generating prompts
"auto-upgrade", # Added new command for auto-upgrade,
"book-call",
],
help="Command to run",
)
parser.add_argument(
"--yaml-file",
type=str,
default="agents.yaml",
help="Specify the YAML file for running agents",
)
parser.add_argument(
"--prompt",
type=str,
help="Specify the task for generating a prompt",
)
parser.add_argument(
"--num-loops",
type=int,
default=1,
help="Specify the number of loops for generating a prompt",
)
parser.add_argument(
"--autosave",
action="store_true",
help="Enable autosave for the prompt generator",
)
parser.add_argument(
"--save-to-yaml",
action="store_true",
help="Save the generated prompt to a YAML file",
)
try:
args = parser.parse_args()
show_ascii_art()
# Determine which command to run
if args.command == "onboarding":
OnboardingProcess().run()
elif args.command == "help":
show_help()
elif args.command == "get-api-key":
get_api_key()
elif args.command == "check-login":
check_login()
elif args.command == "run-agents":
create_agents_from_yaml(
yaml_file=args.yaml_file, return_type="tasks"
show_ascii_art()
parser = argparse.ArgumentParser(
description="Swarms Cloud CLI"
)
# elif args.command == "generate-prompt":
# if (
# args.prompt
# ): # Corrected from args.prompt_task to args.prompt
# generate_prompt(
# num_loops=args.num_loops,
# autosave=args.autosave,
# save_to_yaml=args.save_to_yaml,
# prompt=args.prompt, # Corrected from args.prompt_task to args.prompt
# )
# else:
# console.print(
# "[bold red]Please specify a task for generating a prompt using '--prompt'.[/bold red]"
# )
elif args.command == "auto-upgrade":
check_and_upgrade_version()
elif args.command == "book-call":
redirect_to_call()
else:
console.print(
"[bold red]Unknown command! Type 'help' for usage.[/bold red]"
parser.add_argument(
"command",
choices=[
"onboarding",
"help",
"get-api-key",
"check-login",
"run-agents",
"auto-upgrade",
"book-call",
"autoswarm",
],
help="Command to execute",
)
parser.add_argument(
"--yaml-file",
type=str,
default="agents.yaml",
help="YAML configuration file path",
)
parser.add_argument(
"--task", type=str, help="Task for autoswarm"
)
parser.add_argument(
"--model",
type=str,
default="gpt-4",
help="Model for autoswarm",
)
args = parser.parse_args()
try:
if args.command == "onboarding":
OnboardingProcess().run()
elif args.command == "help":
show_help()
elif args.command == "get-api-key":
get_api_key()
elif args.command == "check-login":
check_login()
elif args.command == "run-agents":
create_agents_from_yaml(
yaml_file=args.yaml_file, return_type="tasks"
)
elif args.command == "auto-upgrade":
check_and_upgrade_version()
elif args.command == "book-call":
webbrowser.open(
"https://cal.com/swarms/swarms-strategy-session"
)
elif args.command == "autoswarm":
if not args.task:
show_error(
"Missing required argument: --task",
"Example usage: python cli.py autoswarm --task 'analyze this data' --model gpt-4",
)
exit(1)
run_autoswarm(args.task, args.model)
except Exception as e:
console.print(
f"[{COLORS['error']}]Error: {str(e)}[/{COLORS['error']}]"
)
return
except Exception as error:
formatter.print_panel(
f"Error detected: {error} check your args"
)
raise error
if __name__ == "__main__":

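The argument handling for the new `autoswarm` command can be exercised without the rest of the CLI. The sketch below mirrors the flags and defaults shown in the diff (`--yaml-file`, `--task`, `--model`) with a reduced command list. `build_parser` and `validate` are hypothetical helpers introduced for illustration, and raising `ValueError` is an assumption standing in for the CLI's `show_error` plus `exit(1)`.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Build a reduced version of the CLI parser (subset of commands, same flags)."""
    parser = argparse.ArgumentParser(description="Swarms Cloud CLI (sketch)")
    parser.add_argument(
        "command", choices=["help", "run-agents", "autoswarm"]
    )
    parser.add_argument("--yaml-file", type=str, default="agents.yaml")
    parser.add_argument("--task", type=str)
    parser.add_argument("--model", type=str, default="gpt-4")
    return parser


def validate(args: argparse.Namespace) -> None:
    """Reject an autoswarm invocation that has no usable --task value."""
    if args.command == "autoswarm" and not (args.task and args.task.strip()):
        raise ValueError("Missing required argument: --task")


parser = build_parser()
args = parser.parse_args(["autoswarm", "--task", "analyze this data"])
validate(args)
print(args.command, args.task, args.model, args.yaml_file)
```

Keeping the check in a separate function makes the "missing `--task`" path testable without spawning a process.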
@ -87,19 +87,6 @@ class OnboardingProcess:
try:
combined_data = {**self.user_data, **self.system_data}
log_agent_data(combined_data)
# threading.Thread(target=log_agent_data(combined_data)).start()
# with open(self.auto_save_path, "w") as f:
# json.dump(combined_data, f, indent=4)
# # logger.info(
# # "User and system data successfully saved to {}",
# # self.auto_save_path,
# # )
# with open(self.cache_save_path, "w") as f:
# json.dump(combined_data, f, indent=4)
# logger.info(
# "User and system data successfully cached in {}",
# self.cache_save_path,
# )
return # Exit the function if saving was successful
except Exception as e:
logger.error(

@ -338,6 +338,8 @@ class Agent:
scheduled_run_date: Optional[datetime] = None,
do_not_use_cluster_ops: bool = True,
all_gpus: bool = False,
model_name: str = None,
llm_args: dict = None,
*args,
**kwargs,
):
@ -453,6 +455,8 @@ class Agent:
self.scheduled_run_date = scheduled_run_date
self.do_not_use_cluster_ops = do_not_use_cluster_ops
self.all_gpus = all_gpus
self.model_name = model_name
self.llm_args = llm_args
# Initialize the short term memory
self.short_memory = Conversation(
@ -589,6 +593,21 @@ class Agent:
# Telemetry Processor to log agent data
threading.Thread(target=self.log_agent_data).start()
# Resolve the LLM synchronously; the original wrapped this call in a Thread that was never started
self.llm_handling()
def llm_handling(self):
if self.llm is None:
from swarms.utils.litellm import LiteLLM
if self.llm_args is not None:
self.llm = LiteLLM(
model_name=self.model_name, **self.llm_args
)
else:
self.llm = LiteLLM(model_name=self.model_name)
def check_if_no_prompt_then_autogenerate(self, task: str = None):
"""
Checks if auto_generate_prompt is enabled and generates a prompt by combining agent name, description and system prompt if available.
@ -951,7 +970,7 @@ class Agent:
if self.interactive:
logger.info("Interactive mode enabled.")
user_input = formatter.print_panel(input("You: "))
user_input = input("You: ")
# User-defined exit command
if (
@ -1015,6 +1034,11 @@ class Agent:
self.artifacts_file_extension,
)
try:
self.log_agent_data()
except Exception:
pass
# More flexible output types
if (
self.output_type == "string"
@ -1050,8 +1074,16 @@ class Agent:
)
except Exception as error:
self.log_agent_data()
logger.info(
f"Error running agent: {error} optimize your input parameters"
)
raise error
except KeyboardInterrupt as error:
self.log_agent_data()
logger.info(
f"Error running agent: {error} optimize your input parameter"
f"Error running agent: {error} optimize your input parameters"
)
raise error
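The `llm_handling` fallback added to `Agent` in this file can be sketched without the framework. This is a rough illustration under stated assumptions: `StubLLM` is a hypothetical stand-in for the `LiteLLM` wrapper and `resolve_llm` is a hypothetical helper. Unlike the diff, the sketch only forwards `model_name` when it is set, so the wrapper's own default survives a `None`.

```python
from typing import Any, Optional


class StubLLM:
    """Hypothetical stand-in for the LiteLLM wrapper; it only records its arguments."""

    def __init__(self, model_name: str = "gpt-4o", **kwargs: Any):
        self.model_name = model_name
        self.kwargs = kwargs


def resolve_llm(
    llm: Optional[Any],
    model_name: Optional[str] = None,
    llm_args: Optional[dict] = None,
) -> Any:
    """Return the caller's LLM if given, otherwise build one from model_name and llm_args."""
    if llm is not None:
        return llm
    kwargs = dict(llm_args or {})  # copy so the caller's dict is not mutated
    if model_name is not None:
        kwargs["model_name"] = model_name  # an explicit name wins over llm_args
    return StubLLM(**kwargs)


default_llm = resolve_llm(None)
custom_llm = resolve_llm(None, "gpt-4o-mini", {"temperature": 0.2})
print(default_llm.model_name, custom_llm.model_name, custom_llm.kwargs)
```

Merging the arguments into one dictionary also avoids a duplicate-keyword `TypeError` if `llm_args` happens to contain `model_name`.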

@ -0,0 +1,105 @@
try:
from litellm import completion
except ImportError:
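import subprocess
import sys
# Install with the interpreter running this module, not whichever pip is first on PATH
subprocess.check_call([sys.executable, "-m", "pip", "install", "litellm"])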
import litellm
from litellm import completion
litellm.set_verbose = True
class LiteLLM:
"""
Thin wrapper around `litellm.completion`.
It builds a chat message list from an optional system prompt and a task, then returns the text of the first choice.
"""
def __init__(
self,
model_name: str = "gpt-4o",
system_prompt: str = None,
stream: bool = False,
temperature: float = 0.5,
max_tokens: int = 4000,
):
"""
Initialize the LiteLLM with the given parameters.
Args:
model_name (str, optional): The name of the model to use. Defaults to "gpt-4o".
system_prompt (str, optional): The system prompt to use. Defaults to None.
stream (bool, optional): Whether to stream the output. Defaults to False.
temperature (float, optional): The temperature for the model. Defaults to 0.5.
max_tokens (int, optional): The maximum number of tokens to generate. Defaults to 4000.
"""
self.model_name = model_name
self.system_prompt = system_prompt
self.stream = stream
self.temperature = temperature
self.max_tokens = max_tokens
def _prepare_messages(self, task: str) -> list:
"""
Prepare the messages for the given task.
Args:
task (str): The task to prepare messages for.
Returns:
list: A list of messages prepared for the task.
"""
messages = []
if self.system_prompt: # Check if system_prompt is not None
messages.append(
{"role": "system", "content": self.system_prompt}
)
messages.append({"role": "user", "content": task})
return messages
def run(self, task: str, *args, **kwargs):
"""
Run the LLM model for the given task.
Args:
task (str): The task to run the model for.
*args: Additional positional arguments to pass to the model.
**kwargs: Additional keyword arguments to pass to the model.
Returns:
str: The content of the response from the model.
"""
messages = self._prepare_messages(task)
response = completion(
model=self.model_name,
messages=messages,
stream=self.stream,
temperature=self.temperature,
# max_completion_tokens=self.max_tokens,
max_tokens=self.max_tokens,
*args,
**kwargs,
)
content = response.choices[
0
].message.content # Accessing the content
return content
def __call__(self, task: str, *args, **kwargs):
"""
Call the LLM model for the given task.
Args:
task (str): The task to run the model for.
*args: Additional positional arguments to pass to the model.
**kwargs: Additional keyword arguments to pass to the model.
Returns:
str: The content of the response from the model.
"""
return self.run(task, *args, **kwargs)
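The message list that `_prepare_messages` hands to `completion` has a simple shape, which the sketch below reproduces on its own. `prepare_messages` is a hypothetical free function mirroring the method; nothing here calls a model.

```python
from typing import Dict, List, Optional


def prepare_messages(
    task: str, system_prompt: Optional[str] = None
) -> List[Dict[str, str]]:
    """Build a chat-style message list: optional system message, then the user task."""
    messages: List[Dict[str, str]] = []
    if system_prompt:  # skips both None and the empty string
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": task})
    return messages


with_system = prepare_messages(
    "Summarize this report", system_prompt="You are an analyst."
)
without_system = prepare_messages("Summarize this report")
print(with_system)
print(without_system)
```

One thing to watch in the wrapper itself: `run` reads `response.choices[0].message.content`, which appears to assume a non-streaming response, so `stream=True` would likely need separate handling.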

@ -1,50 +1,64 @@
import re
def extract_code_from_markdown(markdown_content: str) -> str:
def extract_code_blocks_with_language(markdown_text: str):
"""
Extracts code blocks from a Markdown string and returns them as a single string.
Extracts all code blocks from Markdown text along with their languages.
Args:
- markdown_content (str): The Markdown content as a string.
markdown_text (str): The input Markdown text.
Returns:
- str: A single string containing all the code blocks separated by newlines.
list[dict]: A list of dictionaries, each containing:
- 'language': The detected language (or 'plaintext' if none specified).
- 'content': The content of the code block.
"""
# Regular expression for fenced code blocks with optional language specifier
pattern = r"```(?:\w+\n)?(.*?)```"
# Regex pattern to match code blocks and optional language specifiers
pattern = r"```(\w+)?\n(.*?)```"
# Check if markdown_content is a string
if not isinstance(markdown_content, str):
raise TypeError("markdown_content must be a string")
# Find all matches (language and content)
matches = re.findall(pattern, markdown_text, re.DOTALL)
# Find all matches of the pattern
matches = re.finditer(pattern, markdown_content, re.DOTALL)
# Extract the content inside the backticks
# Parse results
code_blocks = []
for match in matches:
code_block = match.group(1).strip()
# Remove any leading or trailing whitespace from the code block
code_block = code_block.strip()
# Remove any empty lines from the code block
code_block = "\n".join(
[line for line in code_block.split("\n") if line.strip()]
for language, content in matches:
language = (
language.strip() if language else "plaintext"
) # Default to 'plaintext'
code_blocks.append(
{"language": language, "content": content.strip()}
)
code_blocks.append(code_block)
# Concatenate all code blocks separated by newlines
if code_blocks:
return "\n\n".join(code_blocks)
else:
return ""
return code_blocks
def extract_code_from_markdown(
markdown_text: str, language: str = None
):
"""
Extracts content of code blocks for a specific language or all blocks if no language specified.
Args:
markdown_text (str): The input Markdown text.
language (str, optional): The language to filter by (e.g., 'yaml', 'python').
# example = """
# hello im an agent
# ```bash
# pip install swarms
# ```
# """
Returns:
str: The concatenated content of matched code blocks or an empty string if none found.
"""
# Get all code blocks with detected languages
code_blocks = extract_code_blocks_with_language(markdown_text)
# Filter by language if specified
if language:
code_blocks = [
block["content"]
for block in code_blocks
if block["language"] == language
]
else:
code_blocks = [
block["content"] for block in code_blocks
] # Include all blocks
# print(extract_code_from_markdown(example)) # Output: { "type": "function", "function": { "name": "fetch_financial_news", "parameters": { "query": "Nvidia news", "num_articles": 5 } } }
# Return concatenated content
return "\n\n".join(code_blocks) if code_blocks else ""
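The two helpers in this file can be exercised with a short stand-alone sketch. It uses the same regular expression as the diff; `extract_blocks` and `extract_code` are shortened, hypothetical names, and the fence string is built from single backticks only so that the example can sit inside a fenced block.

```python
import re
from typing import Dict, List, Optional

FENCE = "`" * 3  # three backticks, built indirectly

# Same pattern as in the diff: optional language tag, newline, then a lazy body
PATTERN = FENCE + r"(\w+)?\n(.*?)" + FENCE


def extract_blocks(markdown_text: str) -> List[Dict[str, str]]:
    """Return every fenced block as a dict with 'language' and 'content' keys."""
    blocks = []
    for language, content in re.findall(PATTERN, markdown_text, re.DOTALL):
        blocks.append(
            {"language": language or "plaintext", "content": content.strip()}
        )
    return blocks


def extract_code(markdown_text: str, language: Optional[str] = None) -> str:
    """Join the contents of all blocks, or only those tagged with `language`."""
    blocks = extract_blocks(markdown_text)
    if language:
        blocks = [b for b in blocks if b["language"] == language]
    return "\n\n".join(b["content"] for b in blocks)


sample = (
    "hello im an agent\n"
    f"{FENCE}bash\npip install swarms\n{FENCE}\n"
    f"{FENCE}\nno language here\n{FENCE}\n"
)

print(extract_code(sample, "bash"))
print(extract_blocks(sample))
```

Blocks without a language tag are reported as `plaintext`, so filtering by a language skips them.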

@ -1,292 +0,0 @@
import torch
import torch.nn as nn
import torch.distributed as dist
from dataclasses import dataclass
from typing import Optional, Tuple, Union
from loguru import logger
import math
@dataclass
class StarAttentionConfig:
"""Configuration for StarAttention module.
Attributes:
hidden_size: Dimension of the model's hidden states
num_attention_heads: Number of attention heads
num_hosts: Number of hosts in the distributed system
block_size: Size of each context block
anchor_size: Size of the anchor block
dropout_prob: Dropout probability (default: 0.1)
layer_norm_eps: Layer normalization epsilon (default: 1e-12)
"""
hidden_size: int
num_attention_heads: int
num_hosts: int
block_size: int
anchor_size: int
dropout_prob: float = 0.1
layer_norm_eps: float = 1e-12
class StarAttention(nn.Module):
"""
Implementation of Star Attention mechanism for distributed inference.
The module implements a two-phase attention mechanism:
1. Local Context Encoding with Anchor Blocks
2. Query Encoding and Output Generation with Global Attention
"""
def __init__(self, config: StarAttentionConfig):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
f"Hidden size {config.hidden_size} not divisible by number of attention "
f"heads {config.num_attention_heads}"
)
self.config = config
self.head_dim = (
config.hidden_size // config.num_attention_heads
)
# Initialize components
self.query = nn.Linear(config.hidden_size, config.hidden_size)
self.key = nn.Linear(config.hidden_size, config.hidden_size)
self.value = nn.Linear(config.hidden_size, config.hidden_size)
self.dropout = nn.Dropout(config.dropout_prob)
self.layer_norm = nn.LayerNorm(
config.hidden_size, eps=config.layer_norm_eps
)
# KV cache for storing computed key/value pairs
self.kv_cache = {}
logger.info(
f"Initialized StarAttention with config: {config}"
)
def _split_heads(
self, tensor: torch.Tensor, num_heads: int
) -> torch.Tensor:
"""Split the last dimension into (num_heads, head_dim)."""
batch_size, seq_len, _ = tensor.size()
tensor = tensor.view(
batch_size, seq_len, num_heads, self.head_dim
)
# Transpose to (batch_size, num_heads, seq_len, head_dim)
return tensor.transpose(1, 2)
def _merge_heads(self, tensor: torch.Tensor) -> torch.Tensor:
"""Merge the head dimension back into hidden_size."""
batch_size, _, seq_len, _ = tensor.size()
tensor = tensor.transpose(1, 2)
return tensor.reshape(
batch_size, seq_len, self.config.hidden_size
)
def _compute_attention_scores(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute attention scores and weighted values."""
# Scale dot-product attention
scores = torch.matmul(
query, key.transpose(-2, -1)
) / math.sqrt(self.head_dim)
if mask is not None:
scores = scores.masked_fill(mask == 0, float("-inf"))
# Online softmax computation
attention_probs = torch.nn.functional.softmax(scores, dim=-1)
attention_probs = self.dropout(attention_probs)
context = torch.matmul(attention_probs, value)
return context, attention_probs
def phase1_local_context_encoding(
self,
input_ids: torch.Tensor,
host_id: int,
device: Union[str, torch.device] = "cuda",
) -> None:
"""
Phase 1: Local Context Encoding with Anchor Blocks
Args:
input_ids: Input tensor of shape (batch_size, seq_len)
host_id: ID of the current host
device: Device to run computations on
"""
logger.debug(f"Starting Phase 1 on host {host_id}")
# Calculate block assignments
block_start = host_id * self.config.block_size
block_end = block_start + self.config.block_size
# Get local block
local_block = input_ids[:, block_start:block_end].to(device)
# Get anchor block (first block)
anchor_block = input_ids[:, : self.config.anchor_size].to(
device
)
# Compute KV pairs for local block
local_hidden = self.layer_norm(local_block)
local_key = self._split_heads(
self.key(local_hidden), self.config.num_attention_heads
)
local_value = self._split_heads(
self.value(local_hidden), self.config.num_attention_heads
)
# Store in KV cache
self.kv_cache[host_id] = {
"key": local_key,
"value": local_value,
"anchor_key": (
None
if host_id == 0
else self._split_heads(
self.key(self.layer_norm(anchor_block)),
self.config.num_attention_heads,
)
),
}
logger.debug(
f"Phase 1 complete on host {host_id}. KV cache shapes - "
f"key: {local_key.shape}, value: {local_value.shape}"
)
def phase2_query_encoding(
self,
query_input: torch.Tensor,
host_id: int,
is_query_host: bool,
device: Union[str, torch.device] = "cuda",
) -> Optional[torch.Tensor]:
"""
Phase 2: Query Encoding and Output Generation
Args:
query_input: Query tensor of shape (batch_size, seq_len, hidden_size)
host_id: ID of the current host
is_query_host: Whether this host is the query host
device: Device to run computations on
Returns:
Output tensor if this is the query host, None otherwise
"""
logger.debug(f"Starting Phase 2 on host {host_id}")
# Transform query
query_hidden = self.layer_norm(query_input)
query = self._split_heads(
self.query(query_hidden), self.config.num_attention_heads
)
# Compute local attention scores
local_context, local_probs = self._compute_attention_scores(
query,
self.kv_cache[host_id]["key"],
self.kv_cache[host_id]["value"],
)
if not is_query_host:
# Non-query hosts send their local attention statistics
dist.send(local_probs, dst=self.config.num_hosts - 1)
return None
# Query host aggregates attention from all hosts
all_attention_probs = [local_probs]
for src_rank in range(self.config.num_hosts - 1):
probs = torch.empty_like(local_probs)
dist.recv(probs, src=src_rank)
all_attention_probs.append(probs)
# Compute global attention
torch.mean(torch.stack(all_attention_probs), dim=0)
# Final output computation
output = self._merge_heads(local_context)
output = self.dropout(output)
logger.debug(
f"Phase 2 complete on host {host_id}. Output shape: {output.shape}"
)
return output
def forward(
self,
input_ids: torch.Tensor,
query_input: torch.Tensor,
host_id: int,
is_query_host: bool,
device: Union[str, torch.device] = "cuda",
) -> Optional[torch.Tensor]:
"""
Forward pass of the StarAttention module.
Args:
input_ids: Input tensor of shape (batch_size, seq_len)
query_input: Query tensor of shape (batch_size, seq_len, hidden_size)
host_id: ID of the current host
is_query_host: Whether this host is the query host
device: Device to run computations on
Returns:
Output tensor if this is the query host, None otherwise
"""
# Phase 1: Local Context Encoding
self.phase1_local_context_encoding(input_ids, host_id, device)
# Phase 2: Query Encoding and Output Generation
return self.phase2_query_encoding(
query_input, host_id, is_query_host, device
)
# Example forward pass
config = StarAttentionConfig(
hidden_size=768,
num_attention_heads=12,
num_hosts=3,
block_size=512,
anchor_size=128,
)
# Initialize model
model = StarAttention(config)
# Example input tensors
batch_size = 4
seq_len = 512
input_ids = torch.randint(
0, 1000, (batch_size, seq_len)
) # Random input IDs
query_input = torch.randn(
batch_size, seq_len, config.hidden_size
) # Random query input
# Example forward pass for query host (host_id = 2)
output = model(
input_ids=input_ids,
query_input=query_input,
host_id=2,
is_query_host=True,
device="cpu",
)
print(output)