Merge pull request #1 from kyegomez/master

Catchup 20241128
pull/681/head
evelynmitchell 1 month ago committed by GitHub
commit 95e025454e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1,14 +1,7 @@
---
name: Lint
on: [push, pull_request] # yamllint disable-line rule:truthy
on: [push, pull_request]
jobs:
yaml-lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
- run: pip install yamllint
- run: yamllint .
flake8-lint:
runs-on: ubuntu-latest
steps:

@ -0,0 +1,56 @@
import os
from dotenv import load_dotenv
from swarm_models import OpenAIChat
from swarms import Agent
from swarms.prompts.finance_agent_sys_prompt import (
FINANCIAL_AGENT_SYS_PROMPT,
)
from async_executor import HighSpeedExecutor
load_dotenv()

# Get the OpenAI API key from the environment variable
api_key = os.getenv("OPENAI_API_KEY")

# Create an instance of the OpenAIChat class
model = OpenAIChat(
    openai_api_key=api_key, model_name="gpt-4o-mini", temperature=0.1
)

# Initialize the agent
agent = Agent(
    agent_name="Financial-Analysis-Agent",
    system_prompt=FINANCIAL_AGENT_SYS_PROMPT,
    llm=model,
    max_loops=1,
    # autosave=True,
    # dashboard=False,
    # verbose=True,
    # dynamic_temperature_enabled=True,
    # saved_state_path="finance_agent.json",
    # user_name="swarms_corp",
    # retry_attempts=1,
    # context_length=200000,
    # return_step_meta=True,
    # output_type="json", # "json", "dict", "csv" OR "string" soon "yaml" and
    # auto_generate_prompt=False, # Auto generate prompt for the agent based on name, description, and system prompt, task
    # # artifacts_on=True,
    # artifacts_output_path="roth_ira_report",
    # artifacts_file_extension=".txt",
    # max_tokens=8000,
    # return_history=True,
)


def execute_agent(
    task: str = "How can I establish a ROTH IRA to buy stocks and get a tax break? What are the criteria. Create a report on this question.",
):
    """Run the module-level agent once on ``task`` and return whatever it produces.

    Args:
        task: The prompt handed to ``agent.run``.
    """
    return agent.run(task)


# Launch the concurrent run only when executed as a script, so that importing
# this module (e.g. to reuse ``execute_agent``) does not fire agent calls.
if __name__ == "__main__":
    executor = HighSpeedExecutor()
    # Execute the agent twice through the executor and print its timing stats.
    results = executor.run(execute_agent, 2)
    print(results)

@ -0,0 +1,131 @@
import asyncio
import multiprocessing as mp
import time
from functools import partial
from typing import Any, Dict, Union
class HighSpeedExecutor:
    """Run a blocking callable many times concurrently and report throughput.

    Calls are issued by asyncio worker tasks that hand the callable to the
    event loop's default executor (``run_in_executor(None, ...)``). Despite
    the ``num_processes`` name, no process pool is created in this class;
    ``multiprocessing`` is only consulted for the CPU count.
    """

    def __init__(self, num_processes: Union[int, None] = None):
        """
        Initialize the executor with configurable number of processes.
        If num_processes is None (or 0), it uses CPU count.
        """
        self.num_processes = num_processes or mp.cpu_count()

    async def _worker(
        self,
        queue: asyncio.Queue,
        func: Any,
        *args: Any,
        **kwargs: Any,
    ):
        """Async worker that processes tasks from the queue.

        Each queue item is only a token meaning "run ``func`` once"; its
        value is ignored. The worker loops until it is cancelled.
        """
        loop = asyncio.get_running_loop()
        while True:
            try:
                await queue.get()
            except asyncio.CancelledError:
                break
            try:
                await loop.run_in_executor(
                    None, partial(func, *args, **kwargs)
                )
            except asyncio.CancelledError:
                break
            finally:
                # Always balance the get() above; otherwise queue.join()
                # would wait forever after a failed or cancelled call.
                queue.task_done()

    async def _distribute_tasks(
        self, num_tasks: int, queue: asyncio.Queue
    ):
        """Distribute tasks across the queue (one token per execution)."""
        for i in range(num_tasks):
            await queue.put(i)

    async def execute_batch(
        self,
        func: Any,
        num_executions: int,
        *args: Any,
        **kwargs: Any,
    ) -> Dict[str, Union[int, float]]:
        """
        Execute the given function multiple times concurrently.

        Args:
            func: The function to execute
            num_executions: Number of times to execute the function
            *args, **kwargs: Arguments to pass to the function

        Returns:
            A dictionary containing the number of executions, duration, and executions per second.

        Raises:
            Exception: The first exception raised by ``func``, re-raised
                after every queued execution has been attempted.
        """
        queue: asyncio.Queue = asyncio.Queue()
        failures = []

        def _guarded(*call_args: Any, **call_kwargs: Any) -> Any:
            # Record failures instead of letting them kill a worker, so the
            # remaining executions still run and join() can complete.
            try:
                return func(*call_args, **call_kwargs)
            except Exception as exc:
                failures.append(exc)
                return None

        # Create worker tasks
        workers = [
            asyncio.create_task(
                self._worker(queue, _guarded, *args, **kwargs)
            )
            for _ in range(self.num_processes)
        ]

        # Start timing
        start_time = time.perf_counter()
        try:
            # Distribute tasks, then wait for all of them to complete
            await self._distribute_tasks(num_executions, queue)
            await queue.join()
        finally:
            # Cancel workers and wait for them to finish, even on error
            for worker in workers:
                worker.cancel()
            await asyncio.gather(*workers, return_exceptions=True)
        duration = time.perf_counter() - start_time

        if failures:
            raise failures[0]

        return {
            "executions": num_executions,
            "duration": duration,
            "executions_per_second": (
                num_executions / duration if duration > 0 else 0.0
            ),
        }

    def run(
        self,
        func: Any,
        num_executions: int,
        *args: Any,
        **kwargs: Any,
    ):
        """Synchronous entry point: run ``execute_batch`` in a fresh event loop."""
        return asyncio.run(
            self.execute_batch(func, num_executions, *args, **kwargs)
        )
# def example_function(x: int = 0) -> int:
# """Example function to execute"""
# return x * x
# async def main():
# # Create executor with number of CPU cores
# executor = HighSpeedExecutor()
# # Execute the function 1000 times
# result = await executor.execute_batch(
# example_function, num_executions=1000, x=42
# )
# print(
# f"Completed {result['executions']} executions in {result['duration']:.2f} seconds"
# )
# print(
# f"Rate: {result['executions_per_second']:.2f} executions/second"
# )
# if __name__ == "__main__":
# # Run the async main function
# asyncio.run(main())

@ -1,18 +1,27 @@
/* Further customization as needed */
/* Further customization as needed */
.md-typeset__table {
min-width: 100%;
min-width: 100%;
}
.md-typeset table:not([class]) {
display: table;
}
/*
:root {
--md-primary-fg-color: #EE0F0F;
--md-primary-fg-color--light: #ECB7B7;
--md-primary-fg-color--dark: #90030C;
} */
/* Dark mode */
[data-md-color-scheme="slate"] {
--md-default-bg-color: black;
}
.header__ellipsis {
color: black;
}
.md-copyright__highlight {
color: black;
}
.md-header.md-header--shadow {
color: black;
}

@ -57,29 +57,35 @@ extra:
property: G-MPE9C65596
theme:
name: material
custom_dir: overrides
logo: assets/img/swarms-logo.png
palette:
name: material
custom_dir: overrides
logo: assets/img/swarms-logo.png
palette:
- scheme: default
primary: black
primary: white # White background
accent: white # White accents for interactive elements
toggle:
icon: material/brightness-7
icon: material/brightness-7
name: Switch to dark mode
# Palette toggle for dark mode
- scheme: slate
- scheme: slate # Optional: lighter shades for accessibility
primary: black
accent: black
toggle:
icon: material/brightness-4
name: Switch to light mode
features:
- content.code.copy
- content.code.annotate
- navigation.tabs
- navigation.sections
- navigation.expand
- navigation.top
- announce.dismiss
features:
- content.code.copy
- content.code.annotate
- navigation.tabs
- navigation.sections
- navigation.expand
- navigation.top
- announce.dismiss
font:
text: "Fira Sans" # Clean and readable text
code: "Fira Code" # Modern look for code snippets
# Extensions
markdown_extensions:
- abbr

@ -0,0 +1,308 @@
import os
from swarms import Agent
from swarm_models import OpenAIChat
from web3 import Web3
from typing import Dict, Optional, Any
from datetime import datetime
import asyncio
from loguru import logger
from dotenv import load_dotenv
import csv
import requests
import time
# System prompt for the Ethereum analysis agent. It is passed verbatim as
# `system_prompt` when the agent is constructed, so edits here change the
# model's instructions (the four numbered sections are the requested
# structure of every analysis).
BLOCKCHAIN_AGENT_PROMPT = """
You are an expert blockchain and cryptocurrency analyst with deep knowledge of Ethereum markets and DeFi ecosystems.
You have access to real-time ETH price data and transaction information.
For each transaction, analyze:
1. MARKET CONTEXT
- Current ETH price and what this transaction means in USD terms
- How this movement compares to typical market volumes
- Whether this could impact ETH price
2. BEHAVIORAL ANALYSIS
- Whether this appears to be institutional, whale, or protocol movement
- If this fits any known wallet patterns or behaviors
- Signs of smart contract interaction or DeFi activity
3. RISK & IMPLICATIONS
- Potential market impact or price influence
- Signs of potential market manipulation or unusual activity
- Protocol or DeFi risks if applicable
4. STRATEGIC INSIGHTS
- What traders should know about this movement
- Potential chain reactions or follow-up effects
- Market opportunities or risks created
Write naturally but precisely. Focus on actionable insights and important patterns.
Your analysis helps traders and researchers understand significant market movements in real-time."""
class EthereumAnalyzer:
    """Watch new Ethereum blocks and have an AI agent comment on large transfers.

    Transactions whose value is at least ``min_value_eth`` are enriched with
    receipt, block and ETH/USD price data, analysed by the agent, printed,
    and appended to ``ethereum_analysis.csv``.
    """

    # NOTE(review): this endpoint embeds an Infura project id that is committed
    # to source control. It is kept only as a backward-compatible fallback;
    # prefer passing ``rpc_url`` or setting ETH_RPC_URL, and rotate this key.
    DEFAULT_RPC_URL = (
        "https://mainnet.infura.io/v3/9aa3d95b3bc440fa88ea12eaa4456161"
    )
    # Minimum age (seconds) of the cached ETH price before it is refreshed.
    PRICE_REFRESH_SECONDS = 300
    # Upper bound (seconds) for the price HTTP request so it cannot hang.
    PRICE_REQUEST_TIMEOUT = 10
    # Single source of truth for the CSV layout; "analysis" must stay last.
    CSV_HEADERS = (
        "timestamp",
        "transaction_hash",
        "from_address",
        "to_address",
        "value_eth",
        "value_usd",
        "eth_price",
        "gas_used",
        "gas_price_gwei",
        "block_number",
        "analysis",
    )

    def __init__(
        self, min_value_eth: float = 100.0, rpc_url: Optional[str] = None
    ):
        """Connect to Ethereum, fetch the ETH price and build the AI agent.

        Args:
            min_value_eth: Transactions below this value (in ETH) are ignored.
            rpc_url: JSON-RPC HTTP endpoint. Falls back to the ``ETH_RPC_URL``
                environment variable, then to ``DEFAULT_RPC_URL``.

        Raises:
            ConnectionError: If the Ethereum node cannot be reached.
            ValueError: If ``OPENAI_API_KEY`` is not set.
        """
        load_dotenv()
        logger.add(
            "eth_analysis.log",
            rotation="500 MB",
            retention="10 days",
            level="INFO",
            format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}",
        )

        endpoint = (
            rpc_url or os.getenv("ETH_RPC_URL") or self.DEFAULT_RPC_URL
        )
        self.w3 = Web3(Web3.HTTPProvider(endpoint))
        if not self.w3.is_connected():
            raise ConnectionError(
                "Failed to connect to Ethereum network"
            )

        self.min_value_eth = min_value_eth
        # Cursor of the last fully fetched block. Starting one behind the
        # head makes the first poll analyse the current head block.
        self.last_processed_block = self.w3.eth.block_number - 1
        self.eth_price = self.get_eth_price()
        self.last_price_update = time.time()

        # Initialize AI agent
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError(
                "OpenAI API key not found in environment variables"
            )
        model = OpenAIChat(
            openai_api_key=api_key,
            model_name="gpt-4",
            temperature=0.1,
        )
        self.agent = Agent(
            agent_name="Ethereum-Analysis-Agent",
            system_prompt=BLOCKCHAIN_AGENT_PROMPT,
            llm=model,
            max_loops=1,
            autosave=True,
            dashboard=False,
            verbose=True,
            dynamic_temperature_enabled=True,
            saved_state_path="eth_agent.json",
            user_name="eth_analyzer",
            retry_attempts=1,
            context_length=200000,
            output_type="string",
            streaming_on=False,
        )

        self.csv_filename = "ethereum_analysis.csv"
        self.initialize_csv()

    def get_eth_price(self) -> float:
        """Get current ETH price from CoinGecko API.

        Returns:
            The USD price, or ``0.0`` if the request or parsing fails (the
            error is logged).
        """
        try:
            response = requests.get(
                "https://api.coingecko.com/api/v3/simple/price",
                params={"ids": "ethereum", "vs_currencies": "usd"},
                timeout=self.PRICE_REQUEST_TIMEOUT,
            )
            response.raise_for_status()
            return float(response.json()["ethereum"]["usd"])
        except Exception as e:
            logger.error(f"Error fetching ETH price: {str(e)}")
            return 0.0

    def update_eth_price(self):
        """Update ETH price if more than 5 minutes have passed.

        A failed fetch (reported as ``0.0``) keeps the previous price instead
        of zeroing every USD valuation; the timestamp is still advanced so a
        failing API is not re-queried for every transaction.
        """
        if time.time() - self.last_price_update <= self.PRICE_REFRESH_SECONDS:
            return
        fresh_price = self.get_eth_price()
        self.last_price_update = time.time()
        if fresh_price > 0:
            self.eth_price = fresh_price
            logger.info(f"Updated ETH price: ${self.eth_price:,.2f}")
        else:
            logger.warning(
                f"ETH price refresh failed; keeping ${self.eth_price:,.2f}"
            )

    def initialize_csv(self):
        """Initialize CSV file with headers.

        An existing file is left untouched; if its header row differs from
        ``CSV_HEADERS`` a warning is logged, because appended rows would not
        line up with its columns.
        """
        if not os.path.exists(self.csv_filename):
            with open(
                self.csv_filename, "w", newline="", encoding="utf-8"
            ) as f:
                csv.writer(f).writerow(self.CSV_HEADERS)
            return

        with open(self.csv_filename, newline="", encoding="utf-8") as f:
            existing_header = next(csv.reader(f), None)
        if existing_header != list(self.CSV_HEADERS):
            logger.warning(
                f"{self.csv_filename} has a different header than expected; "
                "appended rows may not match its columns"
            )

    async def analyze_transaction(
        self, tx_hash: str
    ) -> Optional[Dict[str, Any]]:
        """Analyze a single transaction.

        Args:
            tx_hash: Hash of the transaction (the monitor passes ``tx.hash``).

        Returns:
            A dict of transaction facts, or ``None`` when the value is below
            ``min_value_eth`` or any lookup fails (failures are logged).
        """
        try:
            tx = self.w3.eth.get_transaction(tx_hash)
            value_eth = float(self.w3.from_wei(tx.value, "ether"))
            if value_eth < self.min_value_eth:
                return None

            # Only pay for the receipt/block lookups once the value qualifies.
            receipt = self.w3.eth.get_transaction_receipt(tx_hash)
            block = self.w3.eth.get_block(tx.blockNumber)

            # Update ETH price if needed
            self.update_eth_price()
            value_usd = value_eth * self.eth_price

            # Accept both bytes-like hashes (which provide .hex()) and str.
            hash_text = (
                tx_hash.hex() if hasattr(tx_hash, "hex") else str(tx_hash)
            )
            analysis = {
                # NOTE(review): fromtimestamp() without tz yields naive local
                # time - confirm whether UTC is wanted in the CSV.
                "timestamp": datetime.fromtimestamp(
                    block.timestamp
                ).isoformat(),
                "transaction_hash": hash_text,
                "from_address": tx["from"],
                "to_address": tx.to if tx.to else "Contract Creation",
                "value_eth": value_eth,
                "value_usd": value_usd,
                "eth_price": self.eth_price,
                "gas_used": receipt.gasUsed,
                "gas_price_gwei": float(
                    self.w3.from_wei(tx.gasPrice, "gwei")
                ),
                "block_number": tx.blockNumber,
            }

            # Check if it's a contract
            if tx.to:
                code = self.w3.eth.get_code(tx.to)
                analysis["is_contract"] = len(code) > 0
                # Get contract events
                if analysis["is_contract"]:
                    analysis["events"] = receipt.logs
            return analysis
        except Exception as e:
            logger.error(
                f"Error analyzing transaction {tx_hash}: {str(e)}"
            )
            return None

    def prepare_analysis_prompt(self, tx_data: Dict[str, Any]) -> str:
        """Prepare detailed analysis prompt including price context.

        Args:
            tx_data: A dict as produced by ``analyze_transaction``.

        Returns:
            The prompt text sent to the agent.
        """
        value_usd = tx_data["value_usd"]
        eth_price = tx_data["eth_price"]
        events = tx_data.get("events")
        events_line = (
            f"Event Count: {len(events)} events"
            if events
            else "No contract events"
        )
        prompt = f"""Analyze this Ethereum transaction in current market context:
Transaction Details:
- Value: {tx_data['value_eth']:.2f} ETH (${value_usd:,.2f} at current price)
- Current ETH Price: ${eth_price:,.2f}
- From: {tx_data['from_address']}
- To: {tx_data['to_address']}
- Contract Interaction: {tx_data.get('is_contract', False)}
- Gas Used: {tx_data['gas_used']:,} units
- Gas Price: {tx_data['gas_price_gwei']:.2f} Gwei
- Block: {tx_data['block_number']}
- Timestamp: {tx_data['timestamp']}
{events_line}
Consider the transaction's significance given the current ETH price of ${eth_price:,.2f} and total USD value of ${value_usd:,.2f}.
Analyze market impact, patterns, risks, and strategic implications."""
        return prompt

    def save_to_csv(self, tx_data: Dict[str, Any], ai_analysis: str):
        """Save transaction data and analysis to CSV.

        The analysis is flattened onto one line so each record stays on a
        single physical row.
        """
        flattened = " ".join(str(ai_analysis).splitlines())
        # Every column except the trailing "analysis" comes from tx_data.
        row = [tx_data[column] for column in self.CSV_HEADERS[:-1]]
        row.append(flattened)
        with open(
            self.csv_filename, "a", newline="", encoding="utf-8"
        ) as f:
            csv.writer(f).writerow(row)

    def _pending_blocks(self, current_block: int) -> range:
        """Return the block numbers after the cursor up to ``current_block``."""
        return range(self.last_processed_block + 1, current_block + 1)

    async def _process_transaction(self, tx: Any) -> None:
        """Analyse, persist and print one transaction from a fetched block."""
        # Cheap pre-filter on the already-fetched transaction so small
        # transfers do not cost additional RPC round trips.
        if (
            float(self.w3.from_wei(tx.value, "ether"))
            < self.min_value_eth
        ):
            return
        try:
            tx_analysis = await self.analyze_transaction(tx.hash)
            if not tx_analysis:
                return

            # Get AI analysis
            analysis_prompt = self.prepare_analysis_prompt(tx_analysis)
            ai_analysis = self.agent.run(analysis_prompt)

            # Save to CSV
            self.save_to_csv(tx_analysis, ai_analysis)

            # Print analysis
            print("\n" + "=" * 50)
            print("New Transaction Analysis")
            print(f"Hash: {tx_analysis['transaction_hash']}")
            print(
                f"Value: {tx_analysis['value_eth']:.2f} ETH (${tx_analysis['value_usd']:,.2f})"
            )
            print(f"Current ETH Price: ${self.eth_price:,.2f}")
            print("=" * 50)
            print(ai_analysis)
            print("=" * 50 + "\n")
        except Exception as e:
            # One failing transaction must not abort the rest of the block.
            logger.error(f"Error processing transaction: {str(e)}")

    async def monitor_transactions(self):
        """Monitor and analyze transactions one at a time.

        Polls the chain head once per second and walks every block after the
        ``last_processed_block`` cursor exactly once, so a block is neither
        re-analysed while it remains the head nor skipped when analysis takes
        longer than a block interval. Runs until cancelled.
        """
        logger.info(
            f"Starting transaction monitor (minimum value: {self.min_value_eth} ETH)"
        )
        while True:
            try:
                current_block = self.w3.eth.block_number
                for block_number in self._pending_blocks(current_block):
                    block = self.w3.eth.get_block(
                        block_number, full_transactions=True
                    )
                    # Advance as soon as the block is fetched so a failure
                    # further down cannot wedge the monitor on this block.
                    self.last_processed_block = block_number
                    for tx in block.transactions:
                        await self._process_transaction(tx)
                await asyncio.sleep(1)  # Wait for next block
            except Exception as e:
                logger.error(f"Error in monitoring loop: {str(e)}")
                await asyncio.sleep(1)
async def main():
    """Create an analyzer with a 100 ETH threshold and run its monitor coroutine."""
    await EthereumAnalyzer(min_value_eth=100.0).monitor_transactions()
if __name__ == "__main__":
    # Tell the operator what is running and how to stop it.
    startup_banner = (
        "Starting Ethereum Transaction Analyzer...",
        "Saving results to ethereum_analysis.csv",
        "Press Ctrl+C to stop",
    )
    for banner_line in startup_banner:
        print(banner_line)

    # Ctrl+C surfaces as KeyboardInterrupt out of asyncio.run(); report it
    # as a normal shutdown instead of a traceback.
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        print("\nStopping analyzer...")

@ -0,0 +1,4 @@
timestamp,transaction_hash,from_address,to_address,value_eth,gas_used,gas_price_gwei,block_number,analysis
2024-11-27T13:50:35,ddbb665bc75fe848e7ce3d3ce1729243e92466c38ca407deccce8bf629987652,0x267be1C1D684F78cb4F6a176C4911b741E4Ffdc0,0xa40dFEE99E1C85DC97Fdc594b16A460717838703,3200.0,21000,19.968163737,21281878,"Transaction Analysis: This transaction represents a significant transfer of value in the Ethereum network with 3200 ETH (~$6.72 million USD at the current rate) moved from one address to another. It is essential to note that this transaction did not involve smart contract interaction, suggesting it could be a straightforward transfer of funds rather than part of a more complex operation. Looking at the broader market context, large transactions like this can potentially indicate major investment activities or redistribution of assets, which can have ripple effects in the market. If this transaction is part of a larger pattern of significant transfers, it could suggest substantial liquidity moving in the Ethereum ecosystem, possibly affecting the ETH prices. From a DeFi point of view, since there's no contract interaction, it's difficult to infer any direct implications. However, given the substantial value involved, it could be a step in preparation for involvement in DeFi protocols or a move from one DeFi platform to another by a large investor. The transaction fee paid, calculated from the given Gas Used and Gas Price, appears to be within reasonable range. This suggests that the transaction was not rushed and that the sender was willing to wait for this transaction to be confirmed, which might hint towards the non-urgent nature of the transaction. As for potential risk factors or security concerns, the transaction itself appears to be standard and doesn't raise any immediate red flags. However, the parties involved should always be cautious about the address security, maintaining privacy, and avoiding social engineering attacks. 
For traders and investors, this transaction can be interpreted as a potential bullish sign if it signifies increased liquidity and investment in the Ethereum market, especially if it's followed by similar large transfers. However, due to the anonymous nature of the transaction, it's critical to combine this with other market indicators and not to rely solely on transaction analysis for investment decisions."
2024-11-27T13:52:23,b98bcbf6d57a158b67a126d8f023766e03fb15c3e74becc1189d4244fda61a13,0xEae7380dD4CeF6fbD1144F49E4D1e6964258A4F4,0x28C6c06298d514Db089934071355E5743bf21d60,401.99463589018103,21000,14.978063737,21281887,"Ethereum-Analysis-Agent: Transaction Analysis: This transaction marks a significant transfer of 401.99 ETH, approximately $845,000 at the current rate. The transaction did not involve any smart contract interaction, suggesting a simple fund transfer rather than a complicated operation or interaction with a DeFi protocol. From a broader market perspective, this transaction is meaningful but not as potentially impactful as larger transactions. It can nonetheless be part of a larger pattern of asset movement within the Ethereum ecosystem. If this transaction is part of larger investment activities, it could suggest an increase in demand for ETH and potentially impact its price. Without contract interaction, it's challenging to assess direct implications for DeFi protocols. However, the substantial ETH transfer could suggest a step towards participation in DeFi activities, or a movement of funds between different DeFi platforms. The transaction fee appears reasonable, given the Gas Used and Gas Price. This implies that the transaction wasn't urgent, and the sender was willing to wait for the transaction to be confirmed, indicating a non-critical movement of funds. In terms of security and risk factors, there are no immediate concerns from the transaction itself. Nevertheless, as with any crypto transaction, the parties involved should ensure secure storage of their keys, maintain privacy, and be wary of potential phishing or social engineering attacks. For traders and investors, this transaction could be seen as a bullish sign if it forms part of a trend of increased investment activities in the Ethereum market. 
However, it's important to remember that transaction analysis should be combined with other market indicators due to the anonymous nature of blockchain transactions."
2024-11-27T13:59:47,a985b74fd3dfee09cbe4a2e6890509e583a3f0ce13f68c98e82996e0f66428be,0xf7858Da8a6617f7C6d0fF2bcAFDb6D2eeDF64840,0xA294cCa691e4C83B1fc0c8d63D9a3eeF0A196DE1,136.0668,494665.408728,3635.46,21000,18.866443971,21281923,"1. MARKET CONTEXT The transaction of 136.07 ETH, equivalent to $494,665.41, is a significant movement in the Ethereum market. However, compared to the daily trading volume of Ethereum, which often exceeds billions of dollars, this transaction is not large enough to significantly impact the ETH price on its own. 2. BEHAVIORAL ANALYSIS The transaction does not appear to be a protocol movement as there is no contract interaction involved. It could be a whale movement, given the substantial amount of ETH transferred. However, without additional information about the wallets involved, it's difficult to definitively determine the nature of the transaction. The gas price of 18.87 Gwei is relatively standard, suggesting that the transaction was not urgent or time-sensitive. 3. RISK & IMPLICATIONS The transaction does not show signs of market manipulation or unusual activity. The absence of contract interaction suggests that this transaction does not directly involve DeFi protocols, reducing the risk of smart contract vulnerabilities or DeFi-related risks. However, the large amount of ETH transferred could potentially influence market sentiment if it is part of a larger trend of similar transactions. 4. STRATEGIC INSIGHTS Traders should note this transaction as part of the broader market activity. While a single transaction of this size is unlikely to significantly impact the market, a series of similar transactions could indicate a larger trend. If this is part of a larger movement of ETH out of exchanges, it could suggest a decrease in selling pressure, which could be bullish for ETH. Conversely, if this is part of a larger movement into exchanges, it could indicate an increase in selling pressure, which could be bearish for ETH. 
Traders should monitor the market for further similar transactions to gain a better understanding of the potential market trends."
Can't render this file because it has a wrong number of fields in line 4.

@ -0,0 +1,244 @@
import os
import asyncio
from pydantic import BaseModel, Field
from typing import List, Dict, Any
from swarms import Agent
from swarm_models import OpenAIChat
from dotenv import load_dotenv
from swarms.utils.formatter import formatter
# Load environment variables
load_dotenv()
# Get OpenAI API key
api_key = os.getenv("OPENAI_API_KEY")
# Define Pydantic schema for agent outputs
class AgentOutput(BaseModel):
    """Structured record of a single agent reply within the group chat."""

    # Who produced the reply (required).
    agent_name: str = Field(..., description="The name of the agent")

    # The reply text itself (required).
    message: str = Field(
        ...,
        description="The agent's response or contribution to the group chat",
    )

    # Free-form extras; each instance gets its own empty dict by default.
    metadata: Dict[str, Any] = Field(
        description="Additional metadata about the agent's response",
        default_factory=dict,
    )
class GroupChat:
    """
    GroupChat class to enable multiple agents to communicate in an asynchronous group chat.
    Each agent is aware of all other agents, every message exchanged, and the social context.

    Agents answer one after another; every reply is appended to
    ``chat_history`` so later agents see it in their prompt.
    """

    def __init__(
        self,
        name: str,
        description: str,
        agents: List[Agent],
        max_loops: int = 1,
    ):
        """
        Initialize the GroupChat.

        Args:
            name (str): Name of the group chat.
            description (str): Description of the purpose of the group chat.
            agents (List[Agent]): A list of agents participating in the chat.
            max_loops (int): Maximum number of loops to run through all agents.
        """
        self.name = name
        self.description = description
        self.agents = agents
        self.max_loops = max_loops
        # Stores all messages exchanged in the chat as
        # {"sender": ..., "message": ...} entries.
        self.chat_history: List[Dict[str, Any]] = []
        formatter.print_panel(
            f"Initialized GroupChat '{self.name}' with {len(self.agents)} agents. Max loops: {self.max_loops}",
            title="Groupchat Swarm",
        )

    async def _agent_conversation(
        self, agent: Agent, input_message: str, *args, **kwargs
    ) -> AgentOutput:
        """
        Facilitate a single agent's response to the chat.

        Args:
            agent (Agent): The agent responding.
            input_message (str): The message triggering the response.
            *args, **kwargs: Forwarded unchanged to ``agent.run``.

        Returns:
            AgentOutput: The agent's response captured in a structured format.
        """
        formatter.print_panel(
            f"Agent '{agent.agent_name}' is responding to the message: {input_message}",
            title="Groupchat Swarm",
        )
        # agent.run is a blocking call; run it in a thread so the event loop
        # stays responsive.
        response = await asyncio.to_thread(
            agent.run, input_message, *args, **kwargs
        )
        return AgentOutput(
            agent_name=agent.agent_name,
            message=response,
            metadata={"context_length": agent.context_length},
        )

    async def _run(
        self, initial_message: str, *args, **kwargs
    ) -> List[AgentOutput]:
        """
        Execute the group chat asynchronously, looping through all agents up to max_loops.

        Args:
            initial_message (str): The initial message to start the chat.
            *args, **kwargs: Forwarded to every ``agent.run`` call.

        Returns:
            List[AgentOutput]: The responses of all agents across all loops.
        """
        formatter.print_panel(
            f"Starting group chat '{self.name}' with initial message: {initial_message}",
            title="Groupchat Swarm",
        )
        self.chat_history.append(
            {"sender": "System", "message": initial_message}
        )

        # The participant roster is identical for every prompt of this run.
        participants = "\n".join(
            f"- {a.agent_name}: {a.system_prompt}" for a in self.agents
        )

        outputs = []
        for loop in range(self.max_loops):
            formatter.print_panel(
                f"Group chat loop {loop + 1}/{self.max_loops}",
                title="Groupchat Swarm",
            )
            for agent in self.agents:
                # Create a custom input message for each agent, sharing the chat history and social context
                input_message = (
                    f"Chat History:\n{self._format_chat_history()}\n\n"
                    f"Participants:\n"
                    + participants
                    + f"\n\nNew Message: {initial_message}\n\n"
                    f"You are '{agent.agent_name}'. Remember to keep track of the social context, who is speaking, "
                    f"and respond accordingly based on your role: {agent.system_prompt}."
                )

                # Collect agent's response
                output = await self._agent_conversation(
                    agent, input_message, *args, **kwargs
                )
                outputs.append(output)

                # Update chat history with the agent's response
                self.chat_history.append(
                    {
                        "sender": agent.agent_name,
                        "message": output.message,
                    }
                )

        formatter.print_panel(
            "Group chat completed. All agent responses captured.",
            title="Groupchat Swarm",
        )
        return outputs

    def run(self, task: str, *args, **kwargs):
        """
        Run the group chat synchronously.

        Starts a new event loop via ``asyncio.run`` and drives ``_run``; it
        must therefore not be called from inside a running event loop.

        Args:
            task (str): The initial message to start the chat.
            *args, **kwargs: Forwarded to every ``agent.run`` call.

        Returns:
            List[AgentOutput]: The responses of all agents across all loops.
        """
        # Delegate to the async implementation (_run). Calling self.run here
        # would recurse until RecursionError.
        return asyncio.run(self._run(task, *args, **kwargs))

    def _format_chat_history(self) -> str:
        """
        Format the chat history for agents to understand the context.

        Returns:
            str: The formatted chat history as a string.
        """
        return "\n".join(
            f"{entry['sender']}: {entry['message']}"
            for entry in self.chat_history
        )

    def __str__(self) -> str:
        """String representation of the group chat's outputs."""
        return self._format_chat_history()

    def to_json(self) -> List[Dict[str, Any]]:
        """JSON-serializable representation of the chat history.

        Returns:
            A list of ``{"sender": ..., "message": ...}`` dicts (not an
            encoded JSON string).
        """
        return [
            {"sender": entry["sender"], "message": entry["message"]}
            for entry in self.chat_history
        ]
# Example Usage
if __name__ == "__main__":
    load_dotenv()

    # Get the OpenAI API key from the environment variable
    api_key = os.getenv("OPENAI_API_KEY")

    # Create an instance of the OpenAIChat class
    model = OpenAIChat(
        openai_api_key=api_key,
        model_name="gpt-4o-mini",
        temperature=0.1,
    )

    # Example agents
    agent1 = Agent(
        agent_name="Financial-Analysis-Agent",
        system_prompt="You are a financial analyst specializing in investment strategies.",
        llm=model,
        max_loops=1,
        autosave=False,
        dashboard=False,
        verbose=True,
        dynamic_temperature_enabled=True,
        user_name="swarms_corp",
        retry_attempts=1,
        context_length=200000,
        output_type="string",
        streaming_on=False,
    )
    agent2 = Agent(
        agent_name="Tax-Adviser-Agent",
        system_prompt="You are a tax adviser who provides clear and concise guidance on tax-related queries.",
        llm=model,
        max_loops=1,
        autosave=False,
        dashboard=False,
        verbose=True,
        dynamic_temperature_enabled=True,
        user_name="swarms_corp",
        retry_attempts=1,
        context_length=200000,
        output_type="string",
        streaming_on=False,
    )

    # Create group chat
    group_chat = GroupChat(
        name="Financial Discussion",
        description="A group chat for financial analysis and tax advice.",
        agents=[agent1, agent2],
    )

    # Run the group chat. GroupChat.run is a synchronous method that manages
    # its own event loop, so it must not be wrapped in asyncio.run().
    group_chat.run(
        "How can I establish a ROTH IRA to buy stocks and get a tax break? What are the criteria? What do you guys think?"
    )

@ -0,0 +1,417 @@
import os
from typing import List, Dict, Any, Optional, Callable
from dataclasses import dataclass, field
import json
from datetime import datetime
import inspect
import typing
from typing import Union
from swarms import Agent
from swarm_models import OpenAIChat
@dataclass
class ToolDefinition:
    """Description of a callable exposed as a tool."""

    # Function name (``func.__name__``).
    name: str
    # Docstring of the function, or a placeholder when it has none.
    description: str
    # Per-parameter info: {"type": str, "default": Any, "required": bool}.
    parameters: Dict[str, Any]
    # Names of parameters the caller must always supply.
    required_params: List[str]
    # The underlying function object, when available.
    callable: Optional[Callable] = None


def extract_type_hints(func: Callable) -> Dict[str, Any]:
    """Extract parameter types from function type hints."""
    return typing.get_type_hints(func)


def _accepts_none(annotation: Any) -> bool:
    """Return True when ``annotation`` is a union type that includes ``None``.

    Covers ``Optional[X]`` / ``Union[X, None]`` and, on interpreters that
    provide ``types.UnionType``, the ``X | None`` spelling.
    """
    import types  # local import: only needed for the PEP 604 check

    pep604_union = getattr(types, "UnionType", None)
    is_union = getattr(annotation, "__origin__", None) is Union or (
        pep604_union is not None and isinstance(annotation, pep604_union)
    )
    return is_union and type(None) in getattr(annotation, "__args__", ())


def extract_tool_info(func: Callable) -> ToolDefinition:
    """Extract tool information from a callable function.

    A parameter counts as optional when it has a default value, when its
    annotation admits ``None``, or when it is variadic (``*args`` /
    ``**kwargs``); every other parameter is listed in ``required_params``.
    A parameter named ``self`` is skipped.

    Args:
        func: The function to describe.

    Returns:
        A ``ToolDefinition`` holding the name, docstring, parameter table,
        required parameter names and ``func`` itself.
    """
    # Get function name
    name = func.__name__
    # Get docstring
    description = inspect.getdoc(func) or "No description available"
    # Get parameters and their types
    signature = inspect.signature(func)
    type_hints = extract_type_hints(func)

    variadic_kinds = (
        inspect.Parameter.VAR_POSITIONAL,
        inspect.Parameter.VAR_KEYWORD,
    )
    parameters = {}
    required_params = []
    for param_name, param in signature.parameters.items():
        # Skip self parameter for methods
        if param_name == "self":
            continue
        param_type = type_hints.get(param_name, Any)

        # Identity check against the sentinel: "!=" would invoke the default
        # value's own comparison, which may not yield a plain bool.
        has_default = param.default is not inspect.Parameter.empty
        is_optional = (
            has_default
            # *args / **kwargs can always be omitted by the caller.
            or param.kind in variadic_kinds
            or _accepts_none(param_type)
        )
        if not is_optional:
            required_params.append(param_name)

        parameters[param_name] = {
            "type": str(param_type),
            "default": param.default if has_default else None,
            "required": not is_optional,
        }

    return ToolDefinition(
        name=name,
        description=description,
        parameters=parameters,
        required_params=required_params,
        callable=func,
    )
@dataclass
class FunctionSpec:
    """Describes one callable tool: its name, parameters and return value."""

    name: str
    description: str
    # Maps each parameter name to a dict holding its type and description.
    parameters: Dict[str, dict]
    return_type: str
    return_description: str
@dataclass
class ExecutionStep:
    """One planned function call plus the outcome recorded after it runs."""

    step_id: int
    function_name: str
    parameters: Dict[str, Any]
    expected_output: str
    # Both fields start "empty" and are filled in once the step has executed.
    completed: bool = field(default=False)
    result: Any = field(default=None)
@dataclass
class ExecutionContext:
    """Mutable bookkeeping carried through a single task execution."""

    task: str
    # Each container below is created fresh per instance (default_factory).
    steps: List[ExecutionStep] = field(default_factory=list)
    results: Dict[int, Any] = field(default_factory=dict)
    # Index into ``steps``; starts at the first step.
    current_step: int = 0
    history: List[Dict[str, Any]] = field(default_factory=list)
class ToolAgent:
def __init__(
    self,
    functions: List[Callable],
    openai_api_key: str,
    model_name: str = "gpt-4",
    temperature: float = 0.1,
):
    """Register the tool functions and build the LLM-backed agent.

    Args:
        functions: Callables the agent may plan with; indexed by ``__name__``.
        openai_api_key: Key passed to ``OpenAIChat``.
        model_name: Model identifier passed to ``OpenAIChat``.
        temperature: Sampling temperature passed to ``OpenAIChat``.
    """
    # Name -> callable lookup used when a planned step is executed.
    self.functions = dict((fn.__name__, fn) for fn in functions)
    self.function_specs = self._analyze_functions(functions)

    llm = OpenAIChat(
        openai_api_key=openai_api_key,
        model_name=model_name,
        temperature=temperature,
    )
    self.model = llm

    # The system prompt is rendered from function_specs, so it is built
    # only after the specs exist.
    prompt = self._create_system_prompt()
    self.system_prompt = prompt

    self.agent = Agent(
        agent_name="Tool-Agent",
        system_prompt=prompt,
        llm=llm,
        max_loops=1,
        verbose=True,
    )
def _analyze_functions(
    self, functions: List[Callable]
) -> Dict[str, FunctionSpec]:
    """Analyze functions to create detailed specifications.

    Parameter and return descriptions are taken from ``:param name: ...``
    and ``:return: ...`` lines of each function's docstring; types come
    from its type hints (``Any`` when a hint is missing).

    Args:
        functions: The callables to describe.

    Returns:
        A mapping from function name to its ``FunctionSpec``.
    """
    specs = {}
    for func in functions:
        # Use the imported ``typing`` module; a bare ``get_type_hints``
        # name is not defined in this file.
        hints = typing.get_type_hints(func)
        sig = inspect.signature(func)
        doc = inspect.getdoc(func) or ""

        # Parse docstring for parameter descriptions. ``return_desc`` is
        # reset for every function so a description never carries over
        # from a previously analysed function.
        param_descriptions = {}
        return_desc = ""
        for line in doc.split("\n"):
            if ":param" in line:
                param_name = (
                    line.split(":param")[1].split(":")[0].strip()
                )
                desc = line.split(":", 2)[-1].strip()
                param_descriptions[param_name] = desc
            elif ":return:" in line:
                return_desc = line.split(":return:")[1].strip()

        # Build parameter specifications
        parameters = {}
        for name, param in sig.parameters.items():
            param_type = hints.get(name, Any)
            parameters[name] = {
                "type": str(param_type),
                "type_class": param_type,
                "description": param_descriptions.get(name, ""),
                # Identity check against the "no default" sentinel.
                "required": param.default is param.empty,
            }

        specs[func.__name__] = FunctionSpec(
            name=func.__name__,
            description=doc.split("\n")[0],
            parameters=parameters,
            return_type=str(hints.get("return", Any)),
            return_description=return_desc,
        )
    return specs
def _create_system_prompt(self) -> str:
    """Render the system prompt listing every function in ``function_specs``.

    Returns:
        The prompt text: one section per available function followed by the
        JSON formats expected for the planning and execution phases.
    """
    function_sections = []
    for spec in self.function_specs.values():
        # One bullet line per parameter: name, type string, description.
        param_lines = chr(10).join(
            f" - {name}: {details['type']} - {details['description']}"
            for name, details in spec.parameters.items()
        )
        function_sections.append(
            f"""
Function: {spec.name}
Description: {spec.description}
Parameters:
{param_lines}
Returns: {spec.return_type} - {spec.return_description}
"""
        )
    available_functions = chr(10).join(function_sections)
    return f"""You are an AI agent that creates and executes plans using available functions.
Available Functions:
{available_functions}
You must respond in two formats depending on the phase:
1. Planning Phase:
{{
"phase": "planning",
"plan": {{
"description": "Overall plan description",
"steps": [
{{
"step_id": 1,
"function": "function_name",
"parameters": {{
"param1": "value1",
"param2": "value2"
}},
"purpose": "Why this step is needed"
}}
]
}}
}}
2. Execution Phase:
{{
"phase": "execution",
"analysis": "Analysis of current result",
"next_action": {{
"type": "continue|request_input|complete",
"reason": "Why this action was chosen",
"needed_input": {{}} # If requesting input
}}
}}
Always:
- Use exact function names
- Ensure parameter types match specifications
- Provide clear reasoning for each decision
"""
def _execute_function(
    self, spec: "FunctionSpec", parameters: Dict[str, Any]
) -> Any:
    """Coerce the supplied arguments and call the registered function.

    Args:
        spec: Specification of the function to call. ``spec.parameters``
            maps each parameter name to a dict whose ``"type_class"`` entry
            is the declared type; ``spec.name`` selects the callable in
            ``self.functions``.
        parameters: Argument values keyed by parameter name.

    Returns:
        Whatever the registered function returns.

    Raises:
        ValueError: If a name in ``parameters`` is not declared in the spec,
            or a value cannot be converted to its declared primitive type.
    """
    # Strings that must become False for a bool parameter; plain bool(value)
    # would turn every non-empty string (including "false") into True.
    false_strings = {"", "0", "false", "f", "no", "n", "off"}

    converted_params = {}
    for name, value in parameters.items():
        if name not in spec.parameters:
            # Previously surfaced as a bare KeyError with no context.
            raise ValueError(
                f"Unknown parameter '{name}' for function '{spec.name}'"
            )
        param_type = spec.parameters[name]["type_class"]
        try:
            if param_type is bool and isinstance(value, str):
                converted_params[name] = (
                    value.strip().lower() not in false_strings
                )
            elif param_type in (int, float, str, bool):
                # Primitive types are coerced; anything else is passed as-is.
                converted_params[name] = param_type(value)
            else:
                converted_params[name] = value
        except (ValueError, TypeError) as e:
            raise ValueError(
                f"Parameter '{name}' conversion failed: {str(e)}"
            ) from e
    return self.functions[spec.name](**converted_params)
def run(self, task: str) -> Dict[str, Any]:
    """Plan a task with the agent, then execute the plan step by step.

    The agent is first asked for a JSON plan. Each planned step is executed
    through ``_execute_function`` and its result is sent back to the agent
    for analysis. After the last step the agent is asked for a final
    analysis.

    Args:
        task: Natural-language description of the task.

    Returns:
        An execution log with the keys ``"task"``, ``"start_time"``,
        ``"steps"``, ``"final_result"`` and ``"end_time"``. On success
        ``final_result`` holds ``success=True``, the per-step ``results``
        and the agent's final ``analysis``; on any exception it holds
        ``success=False`` and the error text. Exceptions are reported this
        way rather than propagated to the caller.
    """

    def parse_reply(reply: str) -> Any:
        # Replies may carry a "System:" marker; drop it before decoding.
        return json.loads(reply.replace("System:", "").strip())

    context = ExecutionContext(task=task)
    execution_log = {
        "task": task,
        # NOTE: utcnow() yields a naive timestamp and is deprecated since
        # Python 3.12; kept so the logged format does not change.
        "start_time": datetime.utcnow().isoformat(),
        "steps": [],
        "final_result": None,
    }
    try:
        # Planning phase
        plan_data = parse_reply(
            self.agent.run(f"Create a plan to: {task}")
        )
        # Convert plan to execution steps
        for step in plan_data["plan"]["steps"]:
            context.steps.append(
                ExecutionStep(
                    step_id=step["step_id"],
                    function_name=step["function"],
                    parameters=step["parameters"],
                    expected_output=step["purpose"],
                )
            )
        # Execution phase
        while context.current_step < len(context.steps):
            step = context.steps[context.current_step]
            print(
                f"\nExecuting step {step.step_id}: {step.function_name}"
            )
            try:
                # Execute function
                spec = self.function_specs[step.function_name]
                result = self._execute_function(spec, step.parameters)
                context.results[step.step_id] = result
                step.completed = True
                step.result = result
                # Get agent's analysis. default=str: this dump is only
                # prompt text, so a result json cannot encode is shown via
                # str() instead of failing a step that already succeeded.
                analysis_prompt = f"""
Step {step.step_id} completed:
Function: {step.function_name}
Result: {json.dumps(result, default=str)}
Remaining steps: {len(context.steps) - context.current_step - 1}
Analyze the result and decide next action.
"""
                analysis_data = parse_reply(
                    self.agent.run(analysis_prompt)
                )
                execution_log["steps"].append(
                    {
                        "step_id": step.step_id,
                        "function": step.function_name,
                        "parameters": step.parameters,
                        "result": result,
                        "analysis": analysis_data,
                    }
                )
                is_last_step = (
                    context.current_step >= len(context.steps) - 1
                )
                # "complete" ends the loop only on the final planned step.
                # While steps remain we must still advance: the previous
                # `continue` skipped the increment below and re-ran the
                # same step forever.
                if (
                    analysis_data["next_action"]["type"] == "complete"
                    and is_last_step
                ):
                    break
                context.current_step += 1
            except Exception as e:
                print(f"Error in step {step.step_id}: {str(e)}")
                execution_log["steps"].append(
                    {
                        "step_id": step.step_id,
                        "function": step.function_name,
                        "parameters": step.parameters,
                        "error": str(e),
                    }
                )
                # Re-raised so the outer handler marks the run as failed.
                raise
        # Final analysis
        final_prompt = f"""
Task completed. Results:
{json.dumps(context.results, indent=2, default=str)}
Provide final analysis and recommendations.
"""
        final_analysis = self.agent.run(final_prompt)
        execution_log["final_result"] = {
            "success": True,
            "results": context.results,
            "analysis": parse_reply(final_analysis),
        }
    except Exception as e:
        execution_log["final_result"] = {
            "success": False,
            "error": str(e),
        }
    execution_log["end_time"] = datetime.utcnow().isoformat()
    return execution_log
def calculate_investment_return(
    principal: float, rate: float, years: int
) -> float:
    """Calculate investment return with compound interest.
    :param principal: Initial investment amount in dollars
    :param rate: Annual interest rate as decimal (e.g., 0.07 for 7%)
    :param years: Number of years to invest
    :return: Final investment value
    """
    # The docstring wording is left untouched on purpose: the spec builder in
    # this file reads its first line and its ":return:" field at runtime.
    # Interest compounds once per year, so the principal is scaled by
    # (1 + rate) raised to the number of years.
    growth_factor = (1 + rate) ** years
    return principal * growth_factor
# Example usage, executed at module level (i.e. whenever this file runs).
# The agent is given a single callable tool; the API key is read from the
# OPENAI_API_KEY environment variable (os.getenv returns None when unset).
agent = ToolAgent(
    functions=[calculate_investment_return],
    openai_api_key=os.getenv("OPENAI_API_KEY"),
)
# `result` is the execution-log dict returned by ToolAgent.run; nothing in
# this script prints or otherwise consumes it.
result = agent.run(
    "Calculate returns for $10000 invested at 7% for 10 years"
)

@ -1,6 +1,5 @@
import os
import asyncio
import threading
from swarms import Agent
from swarm_models import OpenAIChat
import time
@ -40,18 +39,21 @@ agent = Agent(
streaming_on=False,
)
# Function to measure time and memory usage
def measure_time_and_memory(func):
    """Decorator that prints elapsed time and process memory after a call.

    Works for both plain functions and ``async def`` functions. For a
    coroutine function the wrapper is itself a coroutine function and the
    timer spans the awaited execution; a synchronous wrapper would only
    time the creation of the coroutine object.

    Args:
        func: The callable (sync or async) to instrument.

    Returns:
        A wrapper with the same call signature that returns ``func``'s
        result. The "Memory used" figure is the resident set size of the
        current process after the call, not a delta.
    """
    # Imported locally so this decorator stays self-contained.
    import functools
    import inspect

    def report(started_at):
        # perf_counter is monotonic, so the interval cannot go negative if
        # the system clock is adjusted mid-run.
        elapsed = time.perf_counter() - started_at
        memory_usage = psutil.Process().memory_info().rss / 1024**2
        print(f"Time taken: {elapsed} seconds")
        print(f"Memory used: {memory_usage} MB")

    if inspect.iscoroutinefunction(func):

        @functools.wraps(func)
        async def async_wrapper(*args, **kwargs):
            started_at = time.perf_counter()
            result = await func(*args, **kwargs)
            report(started_at)
            return result

        return async_wrapper

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        started_at = time.perf_counter()
        result = func(*args, **kwargs)
        report(started_at)
        return result

    return wrapper
# Function to run the agent asynchronously
@measure_time_and_memory
async def run_agent_async():
@ -61,11 +63,13 @@ async def run_agent_async():
)
)
# Function to run the agent on another thread
@measure_time_and_memory
def run_agent_thread():
asyncio.run(run_agent_async())
# Run the agent asynchronously and on another thread to test the speed
asyncio.run(run_agent_async())
run_agent_thread()

@ -5,7 +5,7 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry]
name = "swarms"
version = "6.2.9"
version = "6.3.7"
description = "Swarms - Pytorch"
license = "MIT"
authors = ["Kye Gomez <kye@apac.ai>"]
@ -37,6 +37,14 @@ keywords = [
"Generative AI",
"Agent Marketplace",
"Agent Store",
"quant",
"finance",
"algorithmic trading",
"portfolio optimization",
"risk management",
"financial modeling",
"machine learning for finance",
"natural language processing for finance",
]
classifiers = [
"Development Status :: 4 - Beta",
@ -52,27 +60,18 @@ python = ">=3.10,<4.0"
torch = ">=2.1.1,<3.0"
transformers = ">= 4.39.0, <5.0.0"
asyncio = ">=3.4.3,<4.0"
langchain-community = "0.0.29"
langchain-experimental = "0.0.55"
backoff = "2.2.1"
toml = "*"
pypdf = "4.3.1"
loguru = "0.7.2"
loguru = "*"
pydantic = "2.8.2"
tenacity = "8.5.0"
Pillow = "10.4.0"
tenacity = "*"
psutil = "*"
sentry-sdk = {version = "*", extras = ["http"]} # Updated here
python-dotenv = "*"
PyYAML = "*"
docstring_parser = "0.16"
fastapi = "*"
openai = ">=1.30.1,<2.0"
termcolor = "*"
tiktoken = "*"
networkx = "*"
swarms-memory = "*"
black = "*"
aiofiles = "*"
swarm-models = "*"
clusterops = "*"
@ -96,9 +95,7 @@ mypy-protobuf = "^3.0.0"
[tool.poetry.group.test.dependencies]
pytest = "^8.1.1"
termcolor = "^2.4.0"
pandas = "^2.2.2"
fastapi = ">=0.110.1,<0.116.0"
[tool.ruff]
line-length = 70

@ -95,12 +95,14 @@ flow = "BossAgent -> ExpenseAnalyzer -> SummaryGenerator"
# Using AgentRearrange class to manage the swarm
agent_system = AgentRearrange(
name="pe-swarm",
description="ss",
agents=agents,
flow=flow,
return_json=False,
output_type="final",
max_loops=1,
docs=["SECURITY.md"],
# docs=["SECURITY.md"],
)
# Input task for the swarm

@ -2,21 +2,16 @@
torch>=2.1.1,<3.0
transformers>=4.39.0,<5.0.0
asyncio>=3.4.3,<4.0
langchain-community==0.0.28
langchain-experimental==0.0.55
backoff==2.2.1
toml
pypdf==4.3.1
ratelimit==2.2.1
loguru==0.7.2
pydantic==2.8.2
tenacity==8.5.0
Pillow==10.4.0
tenacity
rich
psutil
sentry-sdk
python-dotenv
opencv-python-headless
PyYAML
docstring_parser==0.16
black>=23.1,<25.0
@ -26,12 +21,8 @@ types-pytz>=2023.3,<2025.0
types-chardet>=5.0.4.6
mypy-protobuf>=3.0.0
pytest>=8.1.1
termcolor>=2.4.0
pandas>=2.2.2
fastapi>=0.110.1
networkx
swarms-memory
pre-commit
aiofiles
swarm-models
clusterops

@ -1,22 +1,38 @@
import os
import concurrent.futures
from dotenv import load_dotenv
# from swarms.structs.workspace_manager import WorkspaceManager
# workspace_manager = WorkspaceManager()
# workspace_manager.run()
from loguru import logger
load_dotenv()
# Disable logging by default
if os.getenv("SWARMS_VERBOSE_GLOBAL", "False").lower() == "false":
logger.disable("")
# Import telemetry functions with error handling
from swarms.telemetry.bootup import bootup # noqa: E402, F403
from swarms.telemetry.sentry_active import (
from swarms.telemetry.sentry_active import ( # noqa: E402
activate_sentry,
) # noqa: E402
# Use ThreadPoolExecutor to run bootup and activate_sentry concurrently
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
executor.submit(bootup)
executor.submit(activate_sentry)
# Run telemetry functions concurrently with error handling
def run_telemetry():
    """Run ``bootup`` and ``activate_sentry`` concurrently and log failures.

    Both callables are submitted to a two-worker thread pool. Any exception
    is logged through ``logger.error`` and never propagated, so a telemetry
    problem cannot break the import of the package.
    """
    try:
        with concurrent.futures.ThreadPoolExecutor(
            max_workers=2
        ) as executor:
            pending = [
                executor.submit(bootup),
                executor.submit(activate_sentry),
            ]
            for future in pending:
                # result() re-raises the task's exception. Each future is
                # checked on its own so that a failure in the first task
                # does not hide a failure in the second.
                try:
                    future.result()
                except Exception as e:
                    logger.error(
                        f"Error running telemetry functions: {e}"
                    )
    except Exception as e:
        # Covers failures to create the pool or to submit the tasks.
        logger.error(f"Error running telemetry functions: {e}")
run_telemetry()
from swarms.agents import * # noqa: E402, F403
from swarms.artifacts import * # noqa: E402, F403

@ -24,8 +24,6 @@ import toml
import yaml
from pydantic import BaseModel
from swarm_models.tiktoken_wrapper import TikTokenizer
from termcolor import colored
from swarms.agents.ape_agent import auto_generate_prompt
from swarms.prompts.agent_system_prompts import AGENT_SYSTEM_PROMPT_3
from swarms.prompts.multi_modal_autonomous_instruction_prompt import (
@ -671,11 +669,8 @@ class Agent:
return self.stopping_condition(response)
return False
except Exception as error:
print(
colored(
f"Error checking stopping condition: {error}",
"red",
)
logger.error(
f"Error checking stopping condition: {error}"
)
def dynamic_temperature(self):
@ -688,21 +683,20 @@ class Agent:
try:
if hasattr(self.llm, "temperature"):
# Randomly change the temperature attribute of self.llm object
logger.info("Enabling Random Dyamic Temperature")
self.llm.temperature = random.uniform(0.0, 1.0)
else:
# Use a default temperature
self.llm.temperature = 0.5
except Exception as error:
print(
colored(
f"Error dynamically changing temperature: {error}"
)
logger.error(
f"Error dynamically changing temperature: {error}"
)
def print_dashboard(self):
"""Print dashboard"""
print(colored("Initializing Agent Dashboard...", "yellow"))
formatter.print_panel(
f"Initializing Agent: {self.agent_name}"
)
data = self.to_dict()
@ -710,22 +704,19 @@ class Agent:
# data = json.dumps(data, indent=4)
# json_data = json.dumps(data, indent=4)
print(
colored(
f"""
Agent Dashboard
--------------------------------------------
formatter.print_panel(
f"""
Agent Dashboard
--------------------------------------------
Agent {self.agent_name} is initializing for {self.max_loops} with the following configuration:
----------------------------------------
Agent {self.agent_name} is initializing for {self.max_loops} with the following configuration:
----------------------------------------
Agent Configuration:
Configuration: {data}
Agent Configuration:
Configuration: {data}
----------------------------------------
""",
"green",
)
----------------------------------------
""",
)
def loop_count_print(
@ -737,7 +728,7 @@ class Agent:
loop_count (_type_): _description_
max_loops (_type_): _description_
"""
print(colored(f"\nLoop {loop_count} of {max_loops}", "cyan"))
logger.info(f"\nLoop {loop_count} of {max_loops}")
print("\n")
# Check parameters
@ -761,8 +752,8 @@ class Agent:
self,
task: Optional[str] = None,
img: Optional[str] = None,
is_last: bool = False,
print_task: bool = False,
is_last: Optional[bool] = False,
print_task: Optional[bool] = False,
*args,
**kwargs,
) -> Any:
@ -960,7 +951,7 @@ class Agent:
if self.interactive:
logger.info("Interactive mode enabled.")
user_input = colored(input("You: "), "red")
user_input = formatter.print_panel(input("You: "))
# User-defined exit command
if (
@ -1060,7 +1051,7 @@ class Agent:
except Exception as error:
logger.info(
f"Error running agent: {error} optimize your input parameters"
f"Error running agent: {error} optimize your input parameter"
)
raise error
@ -1261,7 +1252,7 @@ class Agent:
logger.info(f"Running bulk tasks: {inputs}")
return [self.run(**input_data) for input_data in inputs]
except Exception as error:
print(colored(f"Error running bulk run: {error}", "red"))
logger.info(f"Error running bulk run: {error}", "red")
def save(self) -> None:
"""Save the agent history to a file.
@ -1438,9 +1429,7 @@ class Agent:
with open(file_path, "w") as f:
yaml.dump(self.to_dict(), f)
except Exception as error:
logger.error(
colored(f"Error saving agent to YAML: {error}", "red")
)
logger.error(f"Error saving agent to YAML: {error}")
raise error
def get_llm_parameters(self):
@ -1505,7 +1494,7 @@ class Agent:
role=self.user_name, content=data
)
except Exception as error:
print(colored(f"Error ingesting docs: {error}", "red"))
logger.info(f"Error ingesting docs: {error}", "red")
def ingest_pdf(self, pdf: str):
"""Ingest the pdf into the memory
@ -1520,7 +1509,7 @@ class Agent:
role=self.user_name, content=text
)
except Exception as error:
print(colored(f"Error ingesting pdf: {error}", "red"))
logger.info(f"Error ingesting pdf: {error}", "red")
def receieve_message(self, name: str, message: str):
"""Receieve a message"""
@ -1604,12 +1593,10 @@ class Agent:
role=self.user_name, content=text
)
except Exception as error:
print(
colored(
f"Error getting docs from doc folders: {error}",
"red",
)
logger.error(
f"Error getting docs from doc folders: {error}"
)
raise error
def check_end_session_agentops(self):
if self.agent_ops_on is True:
@ -1629,7 +1616,8 @@ class Agent:
try:
# Query the long term memory
if self.long_term_memory is not None:
logger.info(f"Querying long term memory for: {task}")
formatter.print_panel(f"Querying RAG for: {task}")
memory_retrieval = self.long_term_memory.query(
task, *args, **kwargs
)
@ -1638,15 +1626,15 @@ class Agent:
f"Documents Available: {str(memory_retrieval)}"
)
# Count the tokens
memory_token_count = self.tokenizer.count_tokens(
memory_retrieval
)
if memory_token_count > self.memory_chunk_size:
# Truncate the memory by the memory chunk size
memory_retrieval = self.truncate_string_by_tokens(
memory_retrieval, self.memory_chunk_size
)
# # Count the tokens
# memory_token_count = self.tokenizer.count_tokens(
# memory_retrieval
# )
# if memory_token_count > self.memory_chunk_size:
# # Truncate the memory by the memory chunk size
# memory_retrieval = self.truncate_string_by_tokens(
# memory_retrieval, self.memory_chunk_size
# )
self.short_memory.add(
role="Database",

@ -1,109 +1,87 @@
import os
from typing import List, Any
from swarms.structs.agent import Agent
from loguru import logger
import uuid
WORKSPACE_DIR = os.getenv("WORKSPACE_DIR")
uuid_for_log = str(uuid.uuid4())
logger.add(
os.path.join(
WORKSPACE_DIR,
"agents_available",
f"agents-available-{uuid_for_log}.log",
),
level="INFO",
colorize=True,
backtrace=True,
diagnose=True,
)
def get_agent_name(agent: Any) -> str:
"""Helper function to safely get agent name
Args:
agent (Any): The agent object to get name from
Returns:
str: The agent's name if found, 'Unknown' otherwise
"""
if isinstance(agent, Agent) and hasattr(agent, "agent_name"):
return agent.agent_name
return "Unknown"
def get_agent_description(agent: Any) -> str:
"""Helper function to get agent description or system prompt preview
Args:
agent (Any): The agent object
Returns:
str: Description or first 100 chars of system prompt
"""
if not isinstance(agent, Agent):
return "N/A"
if hasattr(agent, "description") and agent.description:
return agent.description
if hasattr(agent, "system_prompt") and agent.system_prompt:
return f"{agent.system_prompt[:150]}..."
return "N/A"
from typing import List
def showcase_available_agents(
agents: List[Agent],
name: str = None,
description: str = None,
agents: List[Agent] = [],
update_agents_on: bool = False,
format: str = "XML",
) -> str:
"""
Generate a formatted string showcasing all available agents and their descriptions.
Format the available agents in either XML or Table format.
Args:
agents (List[Agent]): List of Agent objects to showcase.
update_agents_on (bool, optional): If True, updates each agent's system prompt with
the showcase information. Defaults to False.
agents (List[Agent]): A list of agents to represent
name (str, optional): Name of the swarm
description (str, optional): Description of the swarm
format (str, optional): Output format ("XML" or "Table"). Defaults to "XML"
Returns:
str: Formatted string containing agent information, including names, descriptions
and IDs for all available agents.
str: Formatted string containing agent information
"""
logger.info(f"Showcasing {len(agents)} available agents")
formatted_agents = []
header = f"\n####### Agents available in the swarm: {name} ############\n"
header += f"{description}\n"
row_format = "{:<5} | {:<20} | {:<50}"
header_row = row_format.format("ID", "Agent Name", "Description")
separator = "-" * 80
formatted_agents.append(header)
formatted_agents.append(separator)
formatted_agents.append(header_row)
formatted_agents.append(separator)
for idx, agent in enumerate(agents):
if not isinstance(agent, Agent):
logger.warning(
f"Skipping non-Agent object: {type(agent)}"
)
continue
agent_name = get_agent_name(agent)
description = (
get_agent_description(agent)[:100] + "..."
if len(get_agent_description(agent)) > 100
else get_agent_description(agent)
def truncate(text: str, max_length: int = 130) -> str:
return (
f"{text[:max_length]}..."
if len(text) > max_length
else text
)
formatted_agents.append(
row_format.format(idx + 1, agent_name, description)
output = []
if format.upper() == "TABLE":
output.append("\n| ID | Agent Name | Description |")
output.append("|-----|------------|-------------|")
for idx, agent in enumerate(agents):
if isinstance(agent, Agent):
agent_name = getattr(agent, "agent_name", str(agent))
description = getattr(
agent,
"description",
getattr(
agent, "system_prompt", "Unknown description"
),
)
desc = truncate(description, 50)
output.append(
f"| {idx + 1} | {agent_name} | {desc} |"
)
else:
output.append(
f"| {idx + 1} | {agent} | Unknown description |"
)
return "\n".join(output)
# Default XML format
output.append("<agents>")
if name:
output.append(f" <name>{name}</name>")
if description:
output.append(
f" <description>{truncate(description)}</description>"
)
for idx, agent in enumerate(agents):
output.append(f" <agent id='{idx + 1}'>")
if isinstance(agent, Agent):
agent_name = getattr(agent, "agent_name", str(agent))
description = getattr(
agent,
"description",
getattr(
agent, "system_prompt", "Unknown description"
),
)
output.append(f" <name>{agent_name}</name>")
output.append(
f" <description>{truncate(description)}</description>"
)
else:
output.append(f" <name>{agent}</name>")
output.append(
" <description>Unknown description</description>"
)
output.append(" </agent>")
output.append("</agents>")
showcase = "\n".join(formatted_agents)
return showcase
return "\n".join(output)

@ -1,8 +1,7 @@
import json
from typing import Any, Dict, List, Optional
from termcolor import colored
from swarms.utils.formatter import formatter
from swarms.structs.agent import Agent
from swarms.structs.base_structure import BaseStructure
from swarms.structs.task import Task
@ -132,9 +131,10 @@ class BaseWorkflow(BaseStructure):
for task in self.tasks:
task.result = None
except Exception as error:
print(
colored(f"Error resetting workflow: {error}", "red"),
formatter.print_panel(
f"Error resetting workflow: {error}"
)
raise error
def get_task_results(self) -> Dict[str, Any]:
"""
@ -148,10 +148,8 @@ class BaseWorkflow(BaseStructure):
task.description: task.result for task in self.tasks
}
except Exception as error:
print(
colored(
f"Error getting task results: {error}", "red"
),
formatter.print_panel(
f"Error getting task results: {error}"
)
def remove_task(self, task: str) -> None:
@ -163,12 +161,10 @@ class BaseWorkflow(BaseStructure):
if task.description != task
]
except Exception as error:
print(
colored(
f"Error removing task from workflow: {error}",
"red",
),
formatter.print_panel(
f"Error removing task from workflow: {error}",
)
raise error
def update_task(self, task: str, **updates) -> None:
"""
@ -203,11 +199,9 @@ class BaseWorkflow(BaseStructure):
f"Task {task} not found in workflow."
)
except Exception as error:
print(
colored(
f"Error updating task in workflow: {error}", "red"
),
)
formatter.print_panel(
f"Error updating task in workflow: {error}"
),
def delete_task(self, task: str) -> None:
"""
@ -240,12 +234,10 @@ class BaseWorkflow(BaseStructure):
f"Task {task} not found in workflow."
)
except Exception as error:
print(
colored(
f"Error deleting task from workflow: {error}",
"red",
),
formatter.print_panel(
f"Error deleting task from workflow: {error}",
)
raise error
def save_workflow_state(
self,
@ -287,23 +279,18 @@ class BaseWorkflow(BaseStructure):
}
json.dump(state, f, indent=4)
except Exception as error:
print(
colored(
f"Error saving workflow state: {error}",
"red",
)
formatter.print_panel(
f"Error saving workflow state: {error}",
)
raise error
def add_objective_to_workflow(self, task: str, **kwargs) -> None:
"""Adds an objective to the workflow."""
try:
print(
colored(
"""
Adding Objective to Workflow...""",
"green",
attrs=["bold", "underline"],
)
formatter.print_panel(
"""
Adding Objective to Workflow...""",
"green",
)
task = Task(
@ -314,12 +301,10 @@ class BaseWorkflow(BaseStructure):
)
self.tasks.append(task)
except Exception as error:
print(
colored(
f"Error adding objective to workflow: {error}",
"red",
)
formatter.print_panel(
f"Error adding objective to workflow: {error}",
)
raise error
def load_workflow_state(
self, filepath: str = None, **kwargs
@ -359,11 +344,8 @@ class BaseWorkflow(BaseStructure):
)
self.tasks.append(task)
except Exception as error:
print(
colored(
f"Error loading workflow state: {error}",
"red",
)
formatter.print_panel(
f"Error loading workflow state: {error}",
)
def workflow_dashboard(self, **kwargs) -> None:
@ -383,25 +365,21 @@ class BaseWorkflow(BaseStructure):
>>> workflow.workflow_dashboard()
"""
print(
colored(
f"""
Sequential Workflow Dashboard
--------------------------------
Name: {self.name}
Description: {self.description}
task_pool: {len(self.task_pool)}
Max Loops: {self.max_loops}
Autosave: {self.autosave}
Autosave Filepath: {self.saved_state_filepath}
Restore Filepath: {self.restore_state_filepath}
--------------------------------
Metadata:
kwargs: {kwargs}
""",
"cyan",
attrs=["bold", "underline"],
)
formatter.print_panel(
f"""
Sequential Workflow Dashboard
--------------------------------
Name: {self.name}
Description: {self.description}
task_pool: {len(self.task_pool)}
Max Loops: {self.max_loops}
Autosave: {self.autosave}
Autosave Filepath: {self.saved_state_filepath}
Restore Filepath: {self.restore_state_filepath}
--------------------------------
Metadata:
kwargs: {kwargs}
"""
)
def workflow_bootup(self, **kwargs) -> None:
@ -409,11 +387,6 @@ class BaseWorkflow(BaseStructure):
Workflow bootup.
"""
print(
colored(
"""
Sequential Workflow Initializing...""",
"green",
attrs=["bold", "underline"],
)
formatter.print_panel(
"""Sequential Workflow Initializing...""",
)

@ -3,10 +3,9 @@ import json
from typing import Any, Optional
import yaml
from termcolor import colored
from swarms.structs.base_structure import BaseStructure
from typing import TYPE_CHECKING
from swarms.utils.formatter import formatter
if TYPE_CHECKING:
from swarms.structs.agent import (
@ -191,18 +190,9 @@ class Conversation(BaseStructure):
Args:
detailed (bool, optional): detailed. Defaults to False.
"""
role_to_color = {
"system": "red",
"user": "green",
"assistant": "blue",
"function": "magenta",
}
for message in self.conversation_history:
print(
colored(
f"{message['role']}: {message['content']}\n\n",
role_to_color[message["role"]],
)
formatter.print_panel(
f"{message['role']}: {message['content']}\n\n"
)
def export_conversation(self, filename: str, *args, **kwargs):
@ -307,46 +297,36 @@ class Conversation(BaseStructure):
for message in messages:
if message["role"] == "system":
print(
colored(
f"system: {message['content']}\n",
role_to_color[message["role"]],
)
formatter.print_panel(
f"system: {message['content']}\n",
role_to_color[message["role"]],
)
elif message["role"] == "user":
print(
colored(
f"user: {message['content']}\n",
role_to_color[message["role"]],
)
formatter.print_panel(
f"user: {message['content']}\n",
role_to_color[message["role"]],
)
elif message["role"] == "assistant" and message.get(
"function_call"
):
print(
colored(
f"assistant: {message['function_call']}\n",
role_to_color[message["role"]],
)
formatter.print_panel(
f"assistant: {message['function_call']}\n",
role_to_color[message["role"]],
)
elif message["role"] == "assistant" and not message.get(
"function_call"
):
print(
colored(
f"assistant: {message['content']}\n",
role_to_color[message["role"]],
)
formatter.print_panel(
f"assistant: {message['content']}\n",
role_to_color[message["role"]],
)
elif message["role"] == "tool":
print(
colored(
(
f"function ({message['name']}):"
f" {message['content']}\n"
),
role_to_color[message["role"]],
)
formatter.print_panel(
(
f"function ({message['name']}):"
f" {message['content']}\n"
),
role_to_color[message["role"]],
)
def truncate_memory_with_tokenizer(self):

@ -86,9 +86,7 @@ class MixtureOfAgents:
self.input_schema = MixtureOfAgentsInput(
name=name,
description=description,
agents=[
agent.to_dict() for agent in self.agents
],
agents=[agent.to_dict() for agent in self.agents],
aggregator_agent=aggregator_agent.to_dict(),
aggregator_system_prompt=self.aggregator_system_prompt,
layers=self.layers,

@ -414,7 +414,7 @@ def run_agents_with_tasks_concurrently(
List[Any]: A list of outputs from each agent execution.
"""
# Make the first agent not use the ifrs
if no_clusterops:
return _run_agents_with_tasks_concurrently(
agents, tasks, batch_size, max_workers

@ -1,5 +1,5 @@
import traceback
import asyncio
import traceback
import uuid
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
@ -13,10 +13,10 @@ from swarms.structs.agent import Agent
from swarms.structs.agents_available import showcase_available_agents
from swarms.structs.base_swarm import BaseSwarm
from swarms.utils.add_docs_to_agents import handle_input_docs
from swarms.utils.loguru_logger import initialize_logger
from swarms.utils.wrapper_clusterop import (
exec_callable_with_clusterops,
)
from swarms.utils.loguru_logger import initialize_logger
logger = initialize_logger(log_folder="rearrange")
@ -121,16 +121,14 @@ class AgentRearrange(BaseSwarm):
output_type: OutputType = "final",
docs: List[str] = None,
doc_folder: str = None,
device: str = "cpu",
device_id: int = 0,
all_cores: bool = False,
all_gpus: bool = True,
no_use_clusterops: bool = True,
*args,
**kwargs,
):
# reliability_check(
# agents=agents,
# name=name,
# description=description,
# flow=flow,
# max_loops=max_loops,
# )
super(AgentRearrange, self).__init__(
name=name,
description=description,
@ -150,33 +148,11 @@ class AgentRearrange(BaseSwarm):
self.output_type = output_type
self.docs = docs
self.doc_folder = doc_folder
self.swarm_history = {
agent.agent_name: [] for agent in agents
}
self.id = uuid.uuid4().hex if id is None else id
# Output schema
self.input_config = AgentRearrangeInput(
swarm_id=self.id,
name=self.name,
description=self.description,
flow=self.flow,
max_loops=self.max_loops,
output_type=self.output_type,
)
# Output schema
self.output_schema = AgentRearrangeOutput(
Input=self.input_config,
outputs=[],
)
# Run the reliability checks to validate the swarm
# self.handle_input_docs()
# Show the agents whose in the swarm
# self.showcase_agents()
self.device = device
self.device_id = device_id
self.all_cores = all_cores
self.all_gpus = all_gpus
self.no_use_clusterops = no_use_clusterops
def showcase_agents(self):
# Get formatted agent info once
@ -184,12 +160,34 @@ class AgentRearrange(BaseSwarm):
name=self.name,
description=self.description,
agents=self.agents,
format="Table",
)
# Update all agents in one pass using values()
for agent in self.agents.values():
if isinstance(agent, Agent):
agent.system_prompt += agents_available
return agents_available
def rearrange_prompt_prep(self) -> str:
    """Assemble a text overview of this swarm for use in a prompt.

    The overview lists the swarm's name and description, its execution flow
    and the agent listing returned by ``self.showcase_agents()`` (which is
    invoked on every call).

    Returns:
        str: The formatted multi-line overview; it starts and ends with a
        newline.
    """
    roster = self.showcase_agents()
    # The empty first and last entries yield the leading and trailing
    # newline of the block.
    sections = [
        "",
        "===== Swarm Configuration =====",
        f"Name: {self.name}",
        f"Description: {self.description}",
        "===== Execution Flow =====",
        f"{self.flow}",
        "===== Participating Agents =====",
        f"{roster}",
        "===========================",
        "",
    ]
    return "\n".join(sections)
def set_custom_flow(self, flow: str):
self.flow = flow
@ -322,6 +320,7 @@ class AgentRearrange(BaseSwarm):
current_task = task
all_responses = []
response_dict = {}
previous_agent = None
logger.info(
f"Starting task execution with {len(tasks)} steps"
@ -346,12 +345,19 @@ class AgentRearrange(BaseSwarm):
f"Starting loop {loop_count + 1}/{self.max_loops}"
)
for task in tasks:
for task_idx, task in enumerate(tasks):
is_last = task == tasks[-1]
agent_names = [
name.strip() for name in task.split(",")
]
# Prepare prompt with previous agent info
prompt_prefix = ""
if previous_agent and task_idx > 0:
prompt_prefix = f"Previous agent {previous_agent} output: {current_task}\n"
elif task_idx == 0:
prompt_prefix = "Initial task: "
if len(agent_names) > 1:
# Parallel processing
logger.info(
@ -367,12 +373,14 @@ class AgentRearrange(BaseSwarm):
):
current_task = (
self.custom_human_in_the_loop(
current_task
prompt_prefix
+ str(current_task)
)
)
else:
current_task = input(
"Enter your response:"
prompt_prefix
+ "Enter your response: "
)
results.append(current_task)
response_dict[agent_name] = (
@ -380,13 +388,13 @@ class AgentRearrange(BaseSwarm):
)
else:
agent = self.agents[agent_name]
current_task = (
str(current_task)
task_with_context = (
prompt_prefix + str(current_task)
if current_task
else ""
else prompt_prefix
)
result = agent.run(
task=current_task,
task=task_with_context,
img=img,
is_last=is_last,
*args,
@ -404,6 +412,7 @@ class AgentRearrange(BaseSwarm):
current_task = "; ".join(results)
all_responses.extend(results)
previous_agent = ",".join(agent_names)
else:
# Sequential processing
@ -419,23 +428,25 @@ class AgentRearrange(BaseSwarm):
):
current_task = (
self.custom_human_in_the_loop(
current_task
prompt_prefix
+ str(current_task)
)
)
else:
current_task = input(
"Enter the next task: "
prompt_prefix
+ "Enter the next task: "
)
response_dict[agent_name] = current_task
else:
agent = self.agents[agent_name]
current_task = (
str(current_task)
task_with_context = (
prompt_prefix + str(current_task)
if current_task
else ""
else prompt_prefix
)
current_task = agent.run(
task=current_task,
task=task_with_context,
img=img,
is_last=is_last,
*args,
@ -451,6 +462,7 @@ class AgentRearrange(BaseSwarm):
)
all_responses.append(current_task)
previous_agent = agent_name
loop_count += 1
@ -506,7 +518,11 @@ class AgentRearrange(BaseSwarm):
Returns:
The result from executing the task through the cluster operations wrapper.
"""
if no_use_clusterops:
no_use_clusterops = (
no_use_clusterops or self.no_use_clusterops
)
if no_use_clusterops is True:
return self._run(
task=task,
img=img,

@ -107,7 +107,7 @@ class SequentialWorkflow:
all_cores: bool = False,
all_gpus: bool = False,
device_id: int = 0,
no_use_clusterops: bool = False,
no_use_clusterops: bool = True,
*args,
**kwargs,
) -> str:

@ -12,7 +12,7 @@ def auto_update():
try:
# Check if auto-update is disabled
auto_update_enabled = os.getenv(
"SWARMS_AUTOUPDATE_ON", "true"
"SWARMS_AUTOUPDATE_ON", "false"
).lower()
if auto_update_enabled == "false":
logger.info(

@ -1,7 +1,6 @@
import json
from typing import Any, Dict, List, Union
from termcolor import cprint
from transformers import PreTrainedModel, PreTrainedTokenizer
from pydantic import BaseModel
from swarms.tools.logits_processor import (
@ -68,15 +67,6 @@ class Jsonformer:
self.temperature = temperature
self.max_string_token_length = max_string_token_length
def debug(self, caller: str, value: str, is_prompt: bool = False):
if self.debug_on:
if is_prompt:
cprint(caller, "green", end=" ")
cprint(value, "yellow")
else:
cprint(caller, "green", end=" ")
cprint(value, "blue")
def generate_number(
self, temperature: Union[float, None] = None, iterations=0
):

@ -3,8 +3,7 @@ from typing import Any, List
import inspect
from typing import Callable
from termcolor import colored
from swarms.utils.formatter import formatter
def scrape_tool_func_docs(fn: Callable) -> str:
@ -37,17 +36,16 @@ def scrape_tool_func_docs(fn: Callable) -> str:
f" {inspect.getdoc(fn)}\nParameters:\n{parameters_str}"
)
except Exception as error:
print(
colored(
(
f"Error scraping tool function docs {error} try"
" optimizing your inputs with different"
" variables and attempt once more."
),
"red",
)
(
formatter.print_panel(
f"Error scraping tool function docs {error} try"
" optimizing your inputs with different"
" variables and attempt once more."
),
)
raise error
def tool_find_by_name(tool_name: str, tools: List[Any]):
"""Find the tool by name"""

@ -0,0 +1,203 @@
from typing import Any
import inspect
from functools import partial
import logging
class NameResolver:
    """Utility class for resolving names of various objects."""

    @staticmethod
    def get_name(obj: Any, default: str = "unnamed_callable") -> str:
        """
        Get the name of any object with multiple fallback strategies.

        Strategies are tried from most specific to most generic. The
        class-name fallback matches every object, so it must come after
        the ``partial`` / ``__wrapped__`` strategies; otherwise a
        ``functools.partial`` would always resolve to ``"partial"``.

        Args:
            obj: The object to get the name from
            default: Default name if all strategies fail

        Returns:
            str: The resolved name, with spaces and hyphens replaced by
            underscores
        """
        strategies = [
            # Functions, classes, methods, builtins, coroutine functions
            lambda x: getattr(x, "__name__", None),
            # functools.partial has no __name__; use the wrapped callable's
            lambda x: (
                getattr(x.func, "__name__", None)
                if isinstance(x, partial)
                else None
            ),
            # Wrapper objects that expose the original via __wrapped__
            lambda x: (
                getattr(x.__wrapped__, "__name__", None)
                if hasattr(x, "__wrapped__")
                else None
            ),
            # Qualified name, when present without a plain __name__
            lambda x: getattr(x, "__qualname__", None),
            # Generic fallbacks: the name of the object's class
            lambda x: (
                x.__class__.__name__
                if hasattr(x, "__class__")
                else None
            ),
            lambda x: type(x).__name__,
            # Module-qualified class name
            lambda x: (
                f"{x.__module__}.{x.__class__.__name__}"
                if hasattr(x, "__module__")
                else None
            ),
            # Classes with a custom __str__ (compare on the type: a bound
            # method never equals object.__str__)
            lambda x: (
                str(x)
                if type(x).__str__ is not object.__str__
                else None
            ),
        ]

        for strategy in strategies:
            try:
                name = strategy(obj)
            except Exception:
                # Best effort: a strategy that raises is simply skipped.
                continue
            if name and isinstance(name, str):
                return name.replace(" ", "_").replace("-", "_")

        # Return default if all strategies fail
        return default

    @staticmethod
    def get_callable_details(obj: Any) -> dict:
        """
        Get detailed information about a callable object.

        Args:
            obj: The object to describe

        Returns:
            dict: Dictionary containing:
                - name: The resolved name (see ``get_name``)
                - type: One of "class", "async_function", "function",
                  "partial", "callable" or "unknown"
                - signature: ``str(inspect.signature(obj))``, or
                  "Unknown signature" when it cannot be determined
                - module: ``obj.__module__`` or "unknown"
                - doc: The docstring or "No documentation available"
        """
        details = {
            "name": NameResolver.get_name(obj),
            "type": "unknown",
            "signature": None,
            "module": getattr(obj, "__module__", "unknown"),
            "doc": inspect.getdoc(obj)
            or "No documentation available",
        }

        # Order matters: classes and coroutine functions are also
        # callable, so the most specific checks come first.
        if inspect.isclass(obj):
            details["type"] = "class"
        elif inspect.iscoroutinefunction(obj):
            details["type"] = "async_function"
        elif inspect.isfunction(obj):
            details["type"] = "function"
        elif isinstance(obj, partial):
            details["type"] = "partial"
        elif callable(obj):
            details["type"] = "callable"

        # inspect.signature raises for non-callables and some builtins
        try:
            details["signature"] = str(inspect.signature(obj))
        except (ValueError, TypeError):
            details["signature"] = "Unknown signature"

        return details

    @classmethod
    def get_safe_name(cls, obj: Any, max_retries: int = 3) -> str:
        """
        Safely get a name with retries and validation.

        The resolved name is reduced to alphanumerics, ``_`` and ``.``,
        and prefixed with ``_`` when it does not start with a letter or
        underscore. If no usable name is produced within ``max_retries``
        attempts, a unique ``callable_<hex>`` fallback is returned and a
        warning is logged.

        Args:
            obj: Object to get name from
            max_retries: Maximum number of attempts

        Returns:
            str: A valid name string
        """
        last_error = None

        # A bounded for-loop guarantees termination even when get_name
        # keeps returning an empty value instead of raising.
        for _ in range(max_retries):
            try:
                name = cls.get_name(obj)
                clean_name = "".join(
                    c for c in name if c.isalnum() or c in ("_", ".")
                )
                if not clean_name:
                    raise ValueError(
                        f"no usable characters in resolved name {name!r}"
                    )
                # Ensure it starts with a letter or underscore
                if (
                    not clean_name[0].isalpha()
                    and clean_name[0] != "_"
                ):
                    clean_name = f"_{clean_name}"
                return clean_name
            except Exception as e:
                last_error = e

        # If all attempts failed, generate a unique fallback name
        import uuid

        fallback = f"callable_{uuid.uuid4().hex[:8]}"
        logging.warning(
            f"Failed to get name after {max_retries} retries. Using fallback: {fallback}. "
            f"Last error: {str(last_error)}"
        )
        return fallback
# # Example usage
# if __name__ == "__main__":
# def test_resolver():
# # Test cases
# class TestClass:
# def method(self):
# pass
# async def async_func():
# pass
# test_cases = [
# TestClass, # Class
# TestClass(), # Instance
# async_func, # Async function
# lambda x: x, # Lambda
# partial(print, end=""), # Partial
# TestClass.method, # Method
# print, # Built-in function
# str, # Built-in class
# ]
# resolver = NameResolver()
# print("\nName Resolution Results:")
# print("-" * 50)
# for obj in test_cases:
# details = resolver.get_callable_details(obj)
# safe_name = resolver.get_safe_name(obj)
# print(f"\nObject: {obj}")
# print(f"Safe Name: {safe_name}")
# print(f"Details: {details}")
# test_resolver()

@ -1,4 +1,4 @@
from termcolor import colored
from swarms.utils.formatter import formatter
def display_markdown_message(message: str, color: str = "cyan"):
@ -12,13 +12,10 @@ def display_markdown_message(message: str, color: str = "cyan"):
if line == "":
print()
elif line == "---":
print(colored("-" * 50, color))
formatter.print_panel("-" * 50)
else:
print(colored(line, color))
formatter.print_panel(line)
if "\n" not in message and message.startswith(">"):
# Aesthetic choice. For these tags, they need a space below them
print()
# display_markdown_message("I love you and you are beautiful.", "cyan")

@ -0,0 +1,53 @@
import concurrent.futures
from typing import List, Union
from swarms.structs.agent import Agent
def update_system_prompts(
    agents: List[Union[Agent, str]],
    prompt: str,
) -> List[Agent]:
    """
    Update system prompts for a list of agents concurrently.

    String entries are turned into new ``Agent`` objects whose system
    prompt is ``prompt``. Existing ``Agent`` objects are modified in
    place: ``prompt`` is appended to their current system prompt after a
    newline.

    Args:
        agents: List of Agent objects or strings to update
        prompt: The prompt text to append to each agent's system prompt

    Returns:
        List of updated Agent objects, in the same order as ``agents``.
        An empty or ``None`` input is returned unchanged.
    """
    if not agents:
        return agents

    def update_agent_prompt(agent: Union[Agent, str]) -> Agent:
        # Convert string to Agent if needed
        if isinstance(agent, str):
            return Agent(
                agent_name=agent,
                system_prompt=prompt,  # Initialize with the provided prompt
            )
        # Preserve existing prompt and append new one
        existing_prompt = (
            agent.system_prompt if agent.system_prompt else ""
        )
        agent.system_prompt = existing_prompt + "\n" + prompt
        return agent

    max_workers = min(len(agents), 4)  # Reasonable thread count
    with concurrent.futures.ThreadPoolExecutor(
        max_workers=max_workers
    ) as executor:
        # executor.map yields results in input order, so the returned
        # list lines up index-for-index with ``agents`` (as_completed
        # would return them in whatever order the threads finish).
        return list(executor.map(update_agent_prompt, agents))

@ -0,0 +1,292 @@
import torch
import torch.nn as nn
import torch.distributed as dist
from dataclasses import dataclass
from typing import Optional, Tuple, Union
from loguru import logger
import math
@dataclass
class StarAttentionConfig:
    """Configuration for StarAttention module.

    Attributes:
        hidden_size: Dimension of the model's hidden states
        num_attention_heads: Number of attention heads
        num_hosts: Number of hosts in the distributed system
        block_size: Size of each context block
        anchor_size: Size of the anchor block
        dropout_prob: Dropout probability (default: 0.1)
        layer_norm_eps: Layer normalization epsilon (default: 1e-12)
    """

    hidden_size: int
    num_attention_heads: int
    num_hosts: int
    block_size: int
    anchor_size: int
    dropout_prob: float = 0.1
    layer_norm_eps: float = 1e-12


class StarAttention(nn.Module):
    """
    Implementation of Star Attention mechanism for distributed inference.

    The module implements a two-phase attention mechanism:
    1. Local Context Encoding with Anchor Blocks
    2. Query Encoding and Output Generation with Global Attention

    In phase 2 every host attends to its own cached key/value block. The
    query host then combines the per-host results with a log-sum-exp
    weighted sum, which equals softmax attention over the concatenation
    of all blocks (exactly so when dropout is inactive).
    """

    def __init__(self, config: StarAttentionConfig):
        """Build the projections, dropout, layer norm and an empty KV cache.

        Raises:
            ValueError: If hidden_size is not divisible by
                num_attention_heads.
        """
        super().__init__()

        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"Hidden size {config.hidden_size} not divisible by number of attention "
                f"heads {config.num_attention_heads}"
            )

        self.config = config
        self.head_dim = (
            config.hidden_size // config.num_attention_heads
        )

        # Initialize components
        self.query = nn.Linear(config.hidden_size, config.hidden_size)
        self.key = nn.Linear(config.hidden_size, config.hidden_size)
        self.value = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.dropout_prob)
        self.layer_norm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps
        )

        # KV cache for storing computed key/value pairs, keyed by host_id
        self.kv_cache = {}

        logger.info(
            f"Initialized StarAttention with config: {config}"
        )

    def _split_heads(
        self, tensor: torch.Tensor, num_heads: int
    ) -> torch.Tensor:
        """Split the last dimension into (num_heads, head_dim)."""
        batch_size, seq_len, _ = tensor.size()
        tensor = tensor.view(
            batch_size, seq_len, num_heads, self.head_dim
        )
        # Transpose to (batch_size, num_heads, seq_len, head_dim)
        return tensor.transpose(1, 2)

    def _merge_heads(self, tensor: torch.Tensor) -> torch.Tensor:
        """Merge the head dimension back into hidden_size."""
        batch_size, _, seq_len, _ = tensor.size()
        tensor = tensor.transpose(1, 2)
        return tensor.reshape(
            batch_size, seq_len, self.config.hidden_size
        )

    def _scaled_scores(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return scaled dot-product scores; positions where mask == 0 become -inf."""
        scores = torch.matmul(
            query, key.transpose(-2, -1)
        ) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))
        return scores

    def _compute_attention_scores(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute attention scores and weighted values.

        Returns:
            Tuple of (context, attention_probs); attention_probs has had
            dropout applied.
        """
        scores = self._scaled_scores(query, key, mask)
        attention_probs = torch.nn.functional.softmax(scores, dim=-1)
        attention_probs = self.dropout(attention_probs)
        context = torch.matmul(attention_probs, value)
        return context, attention_probs

    def _local_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Attend to one host's block and return (context, log-sum-exp of scores).

        The log-sum-exp over the key axis is the log of the softmax
        normaliser; it is what lets several hosts' contexts be combined
        later without exchanging the scores themselves.
        """
        scores = self._scaled_scores(query, key)
        lse = torch.logsumexp(scores, dim=-1)
        probs = self.dropout(
            torch.nn.functional.softmax(scores, dim=-1)
        )
        return torch.matmul(probs, value), lse

    @staticmethod
    def _merge_host_outputs(contexts, lses) -> torch.Tensor:
        """Combine per-host contexts using their log-sum-exp values.

        Each host's context is weighted by exp(lse_h) / sum_h exp(lse_h),
        i.e. a softmax over the host axis of the stacked lse tensors.

        Args:
            contexts: Sequence of per-host context tensors, all of the
                same shape (..., head_dim)
            lses: Sequence of matching log-sum-exp tensors of shape (...)

        Returns:
            The merged context, same shape as one entry of ``contexts``.
        """
        weights = torch.softmax(torch.stack(list(lses)), dim=0)
        return (weights.unsqueeze(-1) * torch.stack(list(contexts))).sum(
            dim=0
        )

    def phase1_local_context_encoding(
        self,
        input_ids: torch.Tensor,
        host_id: int,
        device: Union[str, torch.device] = "cuda",
    ) -> None:
        """
        Phase 1: Local Context Encoding with Anchor Blocks

        Args:
            input_ids: Context tensor of shape
                (batch_size, seq_len, hidden_size). Despite the name it
                must already be embedded: it is passed directly to
                LayerNorm and the key/value projections, and this module
                has no embedding layer.
            host_id: ID of the current host; selects the slice
                [host_id * block_size, (host_id + 1) * block_size) along
                the sequence axis
            device: Device the sliced blocks are moved to

        Raises:
            ValueError: If input_ids is not a floating-point tensor of
                shape (batch_size, seq_len, hidden_size).
        """
        logger.debug(f"Starting Phase 1 on host {host_id}")

        # Fail early with a clear message instead of an obscure
        # LayerNorm / shape-unpacking error further down.
        if (
            input_ids.dim() != 3
            or not torch.is_floating_point(input_ids)
            or input_ids.size(-1) != self.config.hidden_size
        ):
            raise ValueError(
                "input_ids must be a floating-point tensor of shape "
                f"(batch_size, seq_len, {self.config.hidden_size}); got "
                f"shape {tuple(input_ids.shape)} with dtype {input_ids.dtype}. "
                "Embed token IDs before calling StarAttention."
            )

        # Calculate block assignments
        block_start = host_id * self.config.block_size
        block_end = block_start + self.config.block_size

        # Get local block
        local_block = input_ids[:, block_start:block_end].to(device)

        # Get anchor block (first block)
        anchor_block = input_ids[:, : self.config.anchor_size].to(
            device
        )

        # Compute KV pairs for local block
        local_hidden = self.layer_norm(local_block)
        local_key = self._split_heads(
            self.key(local_hidden), self.config.num_attention_heads
        )
        local_value = self._split_heads(
            self.value(local_hidden), self.config.num_attention_heads
        )

        # Store in KV cache
        # NOTE(review): anchor_key is cached but not read anywhere in
        # this class - confirm whether phase 2 is meant to use it.
        self.kv_cache[host_id] = {
            "key": local_key,
            "value": local_value,
            "anchor_key": (
                None
                if host_id == 0
                else self._split_heads(
                    self.key(self.layer_norm(anchor_block)),
                    self.config.num_attention_heads,
                )
            ),
        }

        logger.debug(
            f"Phase 1 complete on host {host_id}. KV cache shapes - "
            f"key: {local_key.shape}, value: {local_value.shape}"
        )

    def phase2_query_encoding(
        self,
        query_input: torch.Tensor,
        host_id: int,
        is_query_host: bool,
        device: Union[str, torch.device] = "cuda",
    ) -> Optional[torch.Tensor]:
        """
        Phase 2: Query Encoding and Output Generation

        Non-query hosts send their local context and log-sum-exp tensors
        (in that order) to rank ``num_hosts - 1``. The query host
        receives one such pair from each rank in ``range(num_hosts - 1)``
        and merges them with its own.

        Args:
            query_input: Query tensor of shape (batch_size, seq_len, hidden_size)
            host_id: ID of the current host; must have been populated in
                the KV cache by phase 1
            is_query_host: Whether this host is the query host
            device: Accepted for signature compatibility; not used here

        Returns:
            Output tensor if this is the query host, None otherwise
        """
        logger.debug(f"Starting Phase 2 on host {host_id}")

        # Transform query
        query_hidden = self.layer_norm(query_input)
        query = self._split_heads(
            self.query(query_hidden), self.config.num_attention_heads
        )

        # Local attention over this host's cached block
        cache = self.kv_cache[host_id]
        local_context, local_lse = self._local_attention(
            query, cache["key"], cache["value"]
        )

        if not is_query_host:
            # Send the statistics needed for an exact merge. Their shapes
            # depend only on the query, not on the local block length.
            # NOTE(review): assumes the query host is rank num_hosts - 1
            # and that host_id equals the torch.distributed rank - confirm.
            dst = self.config.num_hosts - 1
            dist.send(local_context, dst=dst)
            dist.send(local_lse, dst=dst)
            return None

        # Query host aggregates attention from all hosts
        contexts = [local_context]
        lses = [local_lse]
        for src_rank in range(self.config.num_hosts - 1):
            remote_context = torch.empty_like(local_context)
            remote_lse = torch.empty_like(local_lse)
            dist.recv(remote_context, src=src_rank)
            dist.recv(remote_lse, src=src_rank)
            contexts.append(remote_context)
            lses.append(remote_lse)

        # Compute global attention. With a single host there is nothing
        # to merge, and skipping the merge avoids softmax over a lone
        # -inf when that host's block is empty.
        if len(contexts) == 1:
            global_context = local_context
        else:
            global_context = self._merge_host_outputs(contexts, lses)

        # Final output computation
        output = self._merge_heads(global_context)
        output = self.dropout(output)

        logger.debug(
            f"Phase 2 complete on host {host_id}. Output shape: {output.shape}"
        )
        return output

    def forward(
        self,
        input_ids: torch.Tensor,
        query_input: torch.Tensor,
        host_id: int,
        is_query_host: bool,
        device: Union[str, torch.device] = "cuda",
    ) -> Optional[torch.Tensor]:
        """
        Forward pass of the StarAttention module.

        Args:
            input_ids: Embedded context tensor of shape
                (batch_size, seq_len, hidden_size)
            query_input: Query tensor of shape (batch_size, seq_len, hidden_size)
            host_id: ID of the current host
            is_query_host: Whether this host is the query host
            device: Device to run computations on

        Returns:
            Output tensor if this is the query host, None otherwise
        """
        # Phase 1: Local Context Encoding
        self.phase1_local_context_encoding(input_ids, host_id, device)

        # Phase 2: Query Encoding and Output Generation
        return self.phase2_query_encoding(
            query_input, host_id, is_query_host, device
        )
# Example forward pass.
# Guarded so that importing this module does not build a model or run a
# forward pass. It uses a single host so it runs in one process without
# an initialised torch.distributed process group.
if __name__ == "__main__":
    config = StarAttentionConfig(
        hidden_size=768,
        num_attention_heads=12,
        num_hosts=1,
        block_size=512,
        anchor_size=128,
    )

    # Initialize model
    model = StarAttention(config)

    # Example input tensors
    batch_size = 4
    seq_len = 512

    # The context is fed straight into LayerNorm and the key/value
    # projections, so it must be embedded hidden states, not token IDs.
    input_ids = torch.randn(
        batch_size, seq_len, config.hidden_size
    )  # Random context hidden states

    query_input = torch.randn(
        batch_size, seq_len, config.hidden_size
    )  # Random query input

    # Host 0 owns context block [0, block_size) and is also the query host
    output = model(
        input_ids=input_ids,
        query_input=query_input,
        host_id=0,
        is_query_host=True,
        device="cpu",
    )
    print(output)
Loading…
Cancel
Save