commit
95e025454e
@ -0,0 +1,56 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from swarm_models import OpenAIChat
|
||||||
|
|
||||||
|
from swarms import Agent
|
||||||
|
from swarms.prompts.finance_agent_sys_prompt import (
|
||||||
|
FINANCIAL_AGENT_SYS_PROMPT,
|
||||||
|
)
|
||||||
|
from async_executor import HighSpeedExecutor
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
# Get the OpenAI API key from the environment variable
|
||||||
|
api_key = os.getenv("OPENAI_API_KEY")
|
||||||
|
|
||||||
|
# Create an instance of the OpenAIChat class
|
||||||
|
model = OpenAIChat(
|
||||||
|
openai_api_key=api_key, model_name="gpt-4o-mini", temperature=0.1
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize the agent
|
||||||
|
agent = Agent(
|
||||||
|
agent_name="Financial-Analysis-Agent",
|
||||||
|
system_prompt=FINANCIAL_AGENT_SYS_PROMPT,
|
||||||
|
llm=model,
|
||||||
|
max_loops=1,
|
||||||
|
# autosave=True,
|
||||||
|
# dashboard=False,
|
||||||
|
# verbose=True,
|
||||||
|
# dynamic_temperature_enabled=True,
|
||||||
|
# saved_state_path="finance_agent.json",
|
||||||
|
# user_name="swarms_corp",
|
||||||
|
# retry_attempts=1,
|
||||||
|
# context_length=200000,
|
||||||
|
# return_step_meta=True,
|
||||||
|
# output_type="json", # "json", "dict", "csv" OR "string" soon "yaml" and
|
||||||
|
# auto_generate_prompt=False, # Auto generate prompt for the agent based on name, description, and system prompt, task
|
||||||
|
# # artifacts_on=True,
|
||||||
|
# artifacts_output_path="roth_ira_report",
|
||||||
|
# artifacts_file_extension=".txt",
|
||||||
|
# max_tokens=8000,
|
||||||
|
# return_history=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def execute_agent(
|
||||||
|
task: str = "How can I establish a ROTH IRA to buy stocks and get a tax break? What are the criteria. Create a report on this question.",
|
||||||
|
):
|
||||||
|
return agent.run(task)
|
||||||
|
|
||||||
|
|
||||||
|
executor = HighSpeedExecutor()
|
||||||
|
results = executor.run(execute_agent, 2)
|
||||||
|
|
||||||
|
print(results)
|
@ -0,0 +1,131 @@
|
|||||||
|
import asyncio
|
||||||
|
import multiprocessing as mp
|
||||||
|
import time
|
||||||
|
from functools import partial
|
||||||
|
from typing import Any, Dict, Union
|
||||||
|
|
||||||
|
|
||||||
|
class HighSpeedExecutor:
|
||||||
|
def __init__(self, num_processes: int = None):
|
||||||
|
"""
|
||||||
|
Initialize the executor with configurable number of processes.
|
||||||
|
If num_processes is None, it uses CPU count.
|
||||||
|
"""
|
||||||
|
self.num_processes = num_processes or mp.cpu_count()
|
||||||
|
|
||||||
|
async def _worker(
|
||||||
|
self,
|
||||||
|
queue: asyncio.Queue,
|
||||||
|
func: Any,
|
||||||
|
*args: Any,
|
||||||
|
**kwargs: Any,
|
||||||
|
):
|
||||||
|
"""Async worker that processes tasks from the queue"""
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
# Non-blocking get from queue
|
||||||
|
await queue.get()
|
||||||
|
await asyncio.get_event_loop().run_in_executor(
|
||||||
|
None, partial(func, *args, **kwargs)
|
||||||
|
)
|
||||||
|
queue.task_done()
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
break
|
||||||
|
|
||||||
|
async def _distribute_tasks(
|
||||||
|
self, num_tasks: int, queue: asyncio.Queue
|
||||||
|
):
|
||||||
|
"""Distribute tasks across the queue"""
|
||||||
|
for i in range(num_tasks):
|
||||||
|
await queue.put(i)
|
||||||
|
|
||||||
|
async def execute_batch(
|
||||||
|
self,
|
||||||
|
func: Any,
|
||||||
|
num_executions: int,
|
||||||
|
*args: Any,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Dict[str, Union[int, float]]:
|
||||||
|
"""
|
||||||
|
Execute the given function multiple times concurrently.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: The function to execute
|
||||||
|
num_executions: Number of times to execute the function
|
||||||
|
*args, **kwargs: Arguments to pass to the function
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary containing the number of executions, duration, and executions per second.
|
||||||
|
"""
|
||||||
|
queue = asyncio.Queue()
|
||||||
|
|
||||||
|
# Create worker tasks
|
||||||
|
workers = [
|
||||||
|
asyncio.create_task(
|
||||||
|
self._worker(queue, func, *args, **kwargs)
|
||||||
|
)
|
||||||
|
for _ in range(self.num_processes)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Start timing
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
|
# Distribute tasks
|
||||||
|
await self._distribute_tasks(num_executions, queue)
|
||||||
|
|
||||||
|
# Wait for all tasks to complete
|
||||||
|
await queue.join()
|
||||||
|
|
||||||
|
# Cancel workers
|
||||||
|
for worker in workers:
|
||||||
|
worker.cancel()
|
||||||
|
|
||||||
|
# Wait for all workers to finish
|
||||||
|
await asyncio.gather(*workers, return_exceptions=True)
|
||||||
|
|
||||||
|
end_time = time.perf_counter()
|
||||||
|
duration = end_time - start_time
|
||||||
|
|
||||||
|
return {
|
||||||
|
"executions": num_executions,
|
||||||
|
"duration": duration,
|
||||||
|
"executions_per_second": num_executions / duration,
|
||||||
|
}
|
||||||
|
|
||||||
|
def run(
|
||||||
|
self,
|
||||||
|
func: Any,
|
||||||
|
num_executions: int,
|
||||||
|
*args: Any,
|
||||||
|
**kwargs: Any,
|
||||||
|
):
|
||||||
|
return asyncio.run(
|
||||||
|
self.execute_batch(func, num_executions, *args, **kwargs)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# def example_function(x: int = 0) -> int:
|
||||||
|
# """Example function to execute"""
|
||||||
|
# return x * x
|
||||||
|
|
||||||
|
|
||||||
|
# async def main():
|
||||||
|
# # Create executor with number of CPU cores
|
||||||
|
# executor = HighSpeedExecutor()
|
||||||
|
|
||||||
|
# # Execute the function 1000 times
|
||||||
|
# result = await executor.execute_batch(
|
||||||
|
# example_function, num_executions=1000, x=42
|
||||||
|
# )
|
||||||
|
|
||||||
|
# print(
|
||||||
|
# f"Completed {result['executions']} executions in {result['duration']:.2f} seconds"
|
||||||
|
# )
|
||||||
|
# print(
|
||||||
|
# f"Rate: {result['executions_per_second']:.2f} executions/second"
|
||||||
|
# )
|
||||||
|
|
||||||
|
|
||||||
|
# if __name__ == "__main__":
|
||||||
|
# # Run the async main function
|
||||||
|
# asyncio.run(main())
|
@ -0,0 +1,308 @@
|
|||||||
|
import os
|
||||||
|
from swarms import Agent
|
||||||
|
from swarm_models import OpenAIChat
|
||||||
|
from web3 import Web3
|
||||||
|
from typing import Dict, Optional, Any
|
||||||
|
from datetime import datetime
|
||||||
|
import asyncio
|
||||||
|
from loguru import logger
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
import csv
|
||||||
|
import requests
|
||||||
|
import time
|
||||||
|
|
||||||
|
BLOCKCHAIN_AGENT_PROMPT = """
|
||||||
|
You are an expert blockchain and cryptocurrency analyst with deep knowledge of Ethereum markets and DeFi ecosystems.
|
||||||
|
You have access to real-time ETH price data and transaction information.
|
||||||
|
|
||||||
|
For each transaction, analyze:
|
||||||
|
|
||||||
|
1. MARKET CONTEXT
|
||||||
|
- Current ETH price and what this transaction means in USD terms
|
||||||
|
- How this movement compares to typical market volumes
|
||||||
|
- Whether this could impact ETH price
|
||||||
|
|
||||||
|
2. BEHAVIORAL ANALYSIS
|
||||||
|
- Whether this appears to be institutional, whale, or protocol movement
|
||||||
|
- If this fits any known wallet patterns or behaviors
|
||||||
|
- Signs of smart contract interaction or DeFi activity
|
||||||
|
|
||||||
|
3. RISK & IMPLICATIONS
|
||||||
|
- Potential market impact or price influence
|
||||||
|
- Signs of potential market manipulation or unusual activity
|
||||||
|
- Protocol or DeFi risks if applicable
|
||||||
|
|
||||||
|
4. STRATEGIC INSIGHTS
|
||||||
|
- What traders should know about this movement
|
||||||
|
- Potential chain reactions or follow-up effects
|
||||||
|
- Market opportunities or risks created
|
||||||
|
|
||||||
|
Write naturally but precisely. Focus on actionable insights and important patterns.
|
||||||
|
Your analysis helps traders and researchers understand significant market movements in real-time."""
|
||||||
|
|
||||||
|
|
||||||
|
class EthereumAnalyzer:
|
||||||
|
def __init__(self, min_value_eth: float = 100.0):
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
logger.add(
|
||||||
|
"eth_analysis.log",
|
||||||
|
rotation="500 MB",
|
||||||
|
retention="10 days",
|
||||||
|
level="INFO",
|
||||||
|
format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.w3 = Web3(
|
||||||
|
Web3.HTTPProvider(
|
||||||
|
"https://mainnet.infura.io/v3/9aa3d95b3bc440fa88ea12eaa4456161"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if not self.w3.is_connected():
|
||||||
|
raise ConnectionError(
|
||||||
|
"Failed to connect to Ethereum network"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.min_value_eth = min_value_eth
|
||||||
|
self.last_processed_block = self.w3.eth.block_number
|
||||||
|
self.eth_price = self.get_eth_price()
|
||||||
|
self.last_price_update = time.time()
|
||||||
|
|
||||||
|
# Initialize AI agent
|
||||||
|
api_key = os.getenv("OPENAI_API_KEY")
|
||||||
|
if not api_key:
|
||||||
|
raise ValueError(
|
||||||
|
"OpenAI API key not found in environment variables"
|
||||||
|
)
|
||||||
|
|
||||||
|
model = OpenAIChat(
|
||||||
|
openai_api_key=api_key,
|
||||||
|
model_name="gpt-4",
|
||||||
|
temperature=0.1,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.agent = Agent(
|
||||||
|
agent_name="Ethereum-Analysis-Agent",
|
||||||
|
system_prompt=BLOCKCHAIN_AGENT_PROMPT,
|
||||||
|
llm=model,
|
||||||
|
max_loops=1,
|
||||||
|
autosave=True,
|
||||||
|
dashboard=False,
|
||||||
|
verbose=True,
|
||||||
|
dynamic_temperature_enabled=True,
|
||||||
|
saved_state_path="eth_agent.json",
|
||||||
|
user_name="eth_analyzer",
|
||||||
|
retry_attempts=1,
|
||||||
|
context_length=200000,
|
||||||
|
output_type="string",
|
||||||
|
streaming_on=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.csv_filename = "ethereum_analysis.csv"
|
||||||
|
self.initialize_csv()
|
||||||
|
|
||||||
|
def get_eth_price(self) -> float:
|
||||||
|
"""Get current ETH price from CoinGecko API."""
|
||||||
|
try:
|
||||||
|
response = requests.get(
|
||||||
|
"https://api.coingecko.com/api/v3/simple/price",
|
||||||
|
params={"ids": "ethereum", "vs_currencies": "usd"},
|
||||||
|
)
|
||||||
|
return float(response.json()["ethereum"]["usd"])
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error fetching ETH price: {str(e)}")
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
def update_eth_price(self):
|
||||||
|
"""Update ETH price if more than 5 minutes have passed."""
|
||||||
|
if time.time() - self.last_price_update > 300: # 5 minutes
|
||||||
|
self.eth_price = self.get_eth_price()
|
||||||
|
self.last_price_update = time.time()
|
||||||
|
logger.info(f"Updated ETH price: ${self.eth_price:,.2f}")
|
||||||
|
|
||||||
|
def initialize_csv(self):
|
||||||
|
"""Initialize CSV file with headers."""
|
||||||
|
headers = [
|
||||||
|
"timestamp",
|
||||||
|
"transaction_hash",
|
||||||
|
"from_address",
|
||||||
|
"to_address",
|
||||||
|
"value_eth",
|
||||||
|
"value_usd",
|
||||||
|
"eth_price",
|
||||||
|
"gas_used",
|
||||||
|
"gas_price_gwei",
|
||||||
|
"block_number",
|
||||||
|
"analysis",
|
||||||
|
]
|
||||||
|
|
||||||
|
if not os.path.exists(self.csv_filename):
|
||||||
|
with open(self.csv_filename, "w", newline="") as f:
|
||||||
|
writer = csv.writer(f)
|
||||||
|
writer.writerow(headers)
|
||||||
|
|
||||||
|
async def analyze_transaction(
|
||||||
|
self, tx_hash: str
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Analyze a single transaction."""
|
||||||
|
try:
|
||||||
|
tx = self.w3.eth.get_transaction(tx_hash)
|
||||||
|
receipt = self.w3.eth.get_transaction_receipt(tx_hash)
|
||||||
|
|
||||||
|
value_eth = float(self.w3.from_wei(tx.value, "ether"))
|
||||||
|
|
||||||
|
if value_eth < self.min_value_eth:
|
||||||
|
return None
|
||||||
|
|
||||||
|
block = self.w3.eth.get_block(tx.blockNumber)
|
||||||
|
|
||||||
|
# Update ETH price if needed
|
||||||
|
self.update_eth_price()
|
||||||
|
|
||||||
|
value_usd = value_eth * self.eth_price
|
||||||
|
|
||||||
|
analysis = {
|
||||||
|
"timestamp": datetime.fromtimestamp(
|
||||||
|
block.timestamp
|
||||||
|
).isoformat(),
|
||||||
|
"transaction_hash": tx_hash.hex(),
|
||||||
|
"from_address": tx["from"],
|
||||||
|
"to_address": tx.to if tx.to else "Contract Creation",
|
||||||
|
"value_eth": value_eth,
|
||||||
|
"value_usd": value_usd,
|
||||||
|
"eth_price": self.eth_price,
|
||||||
|
"gas_used": receipt.gasUsed,
|
||||||
|
"gas_price_gwei": float(
|
||||||
|
self.w3.from_wei(tx.gasPrice, "gwei")
|
||||||
|
),
|
||||||
|
"block_number": tx.blockNumber,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Check if it's a contract
|
||||||
|
if tx.to:
|
||||||
|
code = self.w3.eth.get_code(tx.to)
|
||||||
|
analysis["is_contract"] = len(code) > 0
|
||||||
|
|
||||||
|
# Get contract events
|
||||||
|
if analysis["is_contract"]:
|
||||||
|
analysis["events"] = receipt.logs
|
||||||
|
|
||||||
|
return analysis
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Error analyzing transaction {tx_hash}: {str(e)}"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def prepare_analysis_prompt(self, tx_data: Dict[str, Any]) -> str:
|
||||||
|
"""Prepare detailed analysis prompt including price context."""
|
||||||
|
value_usd = tx_data["value_usd"]
|
||||||
|
eth_price = tx_data["eth_price"]
|
||||||
|
|
||||||
|
prompt = f"""Analyze this Ethereum transaction in current market context:
|
||||||
|
|
||||||
|
Transaction Details:
|
||||||
|
- Value: {tx_data['value_eth']:.2f} ETH (${value_usd:,.2f} at current price)
|
||||||
|
- Current ETH Price: ${eth_price:,.2f}
|
||||||
|
- From: {tx_data['from_address']}
|
||||||
|
- To: {tx_data['to_address']}
|
||||||
|
- Contract Interaction: {tx_data.get('is_contract', False)}
|
||||||
|
- Gas Used: {tx_data['gas_used']:,} units
|
||||||
|
- Gas Price: {tx_data['gas_price_gwei']:.2f} Gwei
|
||||||
|
- Block: {tx_data['block_number']}
|
||||||
|
- Timestamp: {tx_data['timestamp']}
|
||||||
|
|
||||||
|
{f"Event Count: {len(tx_data['events'])} events" if tx_data.get('events') else "No contract events"}
|
||||||
|
|
||||||
|
Consider the transaction's significance given the current ETH price of ${eth_price:,.2f} and total USD value of ${value_usd:,.2f}.
|
||||||
|
Analyze market impact, patterns, risks, and strategic implications."""
|
||||||
|
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
def save_to_csv(self, tx_data: Dict[str, Any], ai_analysis: str):
|
||||||
|
"""Save transaction data and analysis to CSV."""
|
||||||
|
row = [
|
||||||
|
tx_data["timestamp"],
|
||||||
|
tx_data["transaction_hash"],
|
||||||
|
tx_data["from_address"],
|
||||||
|
tx_data["to_address"],
|
||||||
|
tx_data["value_eth"],
|
||||||
|
tx_data["value_usd"],
|
||||||
|
tx_data["eth_price"],
|
||||||
|
tx_data["gas_used"],
|
||||||
|
tx_data["gas_price_gwei"],
|
||||||
|
tx_data["block_number"],
|
||||||
|
ai_analysis.replace("\n", " "),
|
||||||
|
]
|
||||||
|
|
||||||
|
with open(self.csv_filename, "a", newline="") as f:
|
||||||
|
writer = csv.writer(f)
|
||||||
|
writer.writerow(row)
|
||||||
|
|
||||||
|
async def monitor_transactions(self):
|
||||||
|
"""Monitor and analyze transactions one at a time."""
|
||||||
|
logger.info(
|
||||||
|
f"Starting transaction monitor (minimum value: {self.min_value_eth} ETH)"
|
||||||
|
)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
current_block = self.w3.eth.block_number
|
||||||
|
block = self.w3.eth.get_block(
|
||||||
|
current_block, full_transactions=True
|
||||||
|
)
|
||||||
|
|
||||||
|
for tx in block.transactions:
|
||||||
|
tx_analysis = await self.analyze_transaction(
|
||||||
|
tx.hash
|
||||||
|
)
|
||||||
|
|
||||||
|
if tx_analysis:
|
||||||
|
# Get AI analysis
|
||||||
|
analysis_prompt = (
|
||||||
|
self.prepare_analysis_prompt(tx_analysis)
|
||||||
|
)
|
||||||
|
ai_analysis = self.agent.run(analysis_prompt)
|
||||||
|
print(ai_analysis)
|
||||||
|
|
||||||
|
# Save to CSV
|
||||||
|
self.save_to_csv(tx_analysis, ai_analysis)
|
||||||
|
|
||||||
|
# Print analysis
|
||||||
|
print("\n" + "=" * 50)
|
||||||
|
print("New Transaction Analysis")
|
||||||
|
print(
|
||||||
|
f"Hash: {tx_analysis['transaction_hash']}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"Value: {tx_analysis['value_eth']:.2f} ETH (${tx_analysis['value_usd']:,.2f})"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"Current ETH Price: ${self.eth_price:,.2f}"
|
||||||
|
)
|
||||||
|
print("=" * 50)
|
||||||
|
print(ai_analysis)
|
||||||
|
print("=" * 50 + "\n")
|
||||||
|
|
||||||
|
await asyncio.sleep(1) # Wait for next block
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in monitoring loop: {str(e)}")
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
"""Entry point for the analysis system."""
|
||||||
|
analyzer = EthereumAnalyzer(min_value_eth=100.0)
|
||||||
|
await analyzer.monitor_transactions()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print("Starting Ethereum Transaction Analyzer...")
|
||||||
|
print("Saving results to ethereum_analysis.csv")
|
||||||
|
print("Press Ctrl+C to stop")
|
||||||
|
try:
|
||||||
|
asyncio.run(main())
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\nStopping analyzer...")
|
Can't render this file because it has a wrong number of fields in line 4.
|
@ -0,0 +1,244 @@
|
|||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from typing import List, Dict, Any
|
||||||
|
from swarms import Agent
|
||||||
|
from swarm_models import OpenAIChat
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from swarms.utils.formatter import formatter
|
||||||
|
|
||||||
|
# Load environment variables
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
# Get OpenAI API key
|
||||||
|
api_key = os.getenv("OPENAI_API_KEY")
|
||||||
|
|
||||||
|
|
||||||
|
# Define Pydantic schema for agent outputs
|
||||||
|
class AgentOutput(BaseModel):
|
||||||
|
"""Schema for capturing the output of each agent."""
|
||||||
|
|
||||||
|
agent_name: str = Field(..., description="The name of the agent")
|
||||||
|
message: str = Field(
|
||||||
|
...,
|
||||||
|
description="The agent's response or contribution to the group chat",
|
||||||
|
)
|
||||||
|
metadata: Dict[str, Any] = Field(
|
||||||
|
default_factory=dict,
|
||||||
|
description="Additional metadata about the agent's response",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class GroupChat:
|
||||||
|
"""
|
||||||
|
GroupChat class to enable multiple agents to communicate in an asynchronous group chat.
|
||||||
|
Each agent is aware of all other agents, every message exchanged, and the social context.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
description: str,
|
||||||
|
agents: List[Agent],
|
||||||
|
max_loops: int = 1,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the GroupChat.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): Name of the group chat.
|
||||||
|
description (str): Description of the purpose of the group chat.
|
||||||
|
agents (List[Agent]): A list of agents participating in the chat.
|
||||||
|
max_loops (int): Maximum number of loops to run through all agents.
|
||||||
|
"""
|
||||||
|
self.name = name
|
||||||
|
self.description = description
|
||||||
|
self.agents = agents
|
||||||
|
self.max_loops = max_loops
|
||||||
|
self.chat_history = (
|
||||||
|
[]
|
||||||
|
) # Stores all messages exchanged in the chat
|
||||||
|
|
||||||
|
formatter.print_panel(
|
||||||
|
f"Initialized GroupChat '{self.name}' with {len(self.agents)} agents. Max loops: {self.max_loops}",
|
||||||
|
title="Groupchat Swarm",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _agent_conversation(
|
||||||
|
self, agent: Agent, input_message: str
|
||||||
|
) -> AgentOutput:
|
||||||
|
"""
|
||||||
|
Facilitate a single agent's response to the chat.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent (Agent): The agent responding.
|
||||||
|
input_message (str): The message triggering the response.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AgentOutput: The agent's response captured in a structured format.
|
||||||
|
"""
|
||||||
|
formatter.print_panel(
|
||||||
|
f"Agent '{agent.agent_name}' is responding to the message: {input_message}",
|
||||||
|
title="Groupchat Swarm",
|
||||||
|
)
|
||||||
|
response = await asyncio.to_thread(agent.run, input_message)
|
||||||
|
|
||||||
|
output = AgentOutput(
|
||||||
|
agent_name=agent.agent_name,
|
||||||
|
message=response,
|
||||||
|
metadata={"context_length": agent.context_length},
|
||||||
|
)
|
||||||
|
# logger.debug(f"Agent '{agent.agent_name}' response: {response}")
|
||||||
|
return output
|
||||||
|
|
||||||
|
async def _run(self, initial_message: str) -> List[AgentOutput]:
|
||||||
|
"""
|
||||||
|
Execute the group chat asynchronously, looping through all agents up to max_loops.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
initial_message (str): The initial message to start the chat.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[AgentOutput]: The responses of all agents across all loops.
|
||||||
|
"""
|
||||||
|
formatter.print_panel(
|
||||||
|
f"Starting group chat '{self.name}' with initial message: {initial_message}",
|
||||||
|
title="Groupchat Swarm",
|
||||||
|
)
|
||||||
|
self.chat_history.append(
|
||||||
|
{"sender": "System", "message": initial_message}
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = []
|
||||||
|
for loop in range(self.max_loops):
|
||||||
|
formatter.print_panel(
|
||||||
|
f"Group chat loop {loop + 1}/{self.max_loops}",
|
||||||
|
title="Groupchat Swarm",
|
||||||
|
)
|
||||||
|
|
||||||
|
for agent in self.agents:
|
||||||
|
# Create a custom input message for each agent, sharing the chat history and social context
|
||||||
|
input_message = (
|
||||||
|
f"Chat History:\n{self._format_chat_history()}\n\n"
|
||||||
|
f"Participants:\n"
|
||||||
|
+ "\n".join(
|
||||||
|
[
|
||||||
|
f"- {a.agent_name}: {a.system_prompt}"
|
||||||
|
for a in self.agents
|
||||||
|
]
|
||||||
|
)
|
||||||
|
+ f"\n\nNew Message: {initial_message}\n\n"
|
||||||
|
f"You are '{agent.agent_name}'. Remember to keep track of the social context, who is speaking, "
|
||||||
|
f"and respond accordingly based on your role: {agent.system_prompt}."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Collect agent's response
|
||||||
|
output = await self._agent_conversation(
|
||||||
|
agent, input_message
|
||||||
|
)
|
||||||
|
outputs.append(output)
|
||||||
|
|
||||||
|
# Update chat history with the agent's response
|
||||||
|
self.chat_history.append(
|
||||||
|
{
|
||||||
|
"sender": agent.agent_name,
|
||||||
|
"message": output.message,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
formatter.print_panel(
|
||||||
|
"Group chat completed. All agent responses captured.",
|
||||||
|
title="Groupchat Swarm",
|
||||||
|
)
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def run(self, task: str, *args, **kwargs):
|
||||||
|
return asyncio.run(self.run(task, *args, **kwargs))
|
||||||
|
|
||||||
|
def _format_chat_history(self) -> str:
|
||||||
|
"""
|
||||||
|
Format the chat history for agents to understand the context.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The formatted chat history as a string.
|
||||||
|
"""
|
||||||
|
return "\n".join(
|
||||||
|
[
|
||||||
|
f"{entry['sender']}: {entry['message']}"
|
||||||
|
for entry in self.chat_history
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
"""String representation of the group chat's outputs."""
|
||||||
|
return self._format_chat_history()
|
||||||
|
|
||||||
|
def to_json(self) -> str:
|
||||||
|
"""JSON representation of the group chat's outputs."""
|
||||||
|
return [
|
||||||
|
{"sender": entry["sender"], "message": entry["message"]}
|
||||||
|
for entry in self.chat_history
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# Example Usage
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
# Get the OpenAI API key from the environment variable
|
||||||
|
api_key = os.getenv("OPENAI_API_KEY")
|
||||||
|
|
||||||
|
# Create an instance of the OpenAIChat class
|
||||||
|
model = OpenAIChat(
|
||||||
|
openai_api_key=api_key,
|
||||||
|
model_name="gpt-4o-mini",
|
||||||
|
temperature=0.1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Example agents
|
||||||
|
agent1 = Agent(
|
||||||
|
agent_name="Financial-Analysis-Agent",
|
||||||
|
system_prompt="You are a financial analyst specializing in investment strategies.",
|
||||||
|
llm=model,
|
||||||
|
max_loops=1,
|
||||||
|
autosave=False,
|
||||||
|
dashboard=False,
|
||||||
|
verbose=True,
|
||||||
|
dynamic_temperature_enabled=True,
|
||||||
|
user_name="swarms_corp",
|
||||||
|
retry_attempts=1,
|
||||||
|
context_length=200000,
|
||||||
|
output_type="string",
|
||||||
|
streaming_on=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
agent2 = Agent(
|
||||||
|
agent_name="Tax-Adviser-Agent",
|
||||||
|
system_prompt="You are a tax adviser who provides clear and concise guidance on tax-related queries.",
|
||||||
|
llm=model,
|
||||||
|
max_loops=1,
|
||||||
|
autosave=False,
|
||||||
|
dashboard=False,
|
||||||
|
verbose=True,
|
||||||
|
dynamic_temperature_enabled=True,
|
||||||
|
user_name="swarms_corp",
|
||||||
|
retry_attempts=1,
|
||||||
|
context_length=200000,
|
||||||
|
output_type="string",
|
||||||
|
streaming_on=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create group chat
|
||||||
|
group_chat = GroupChat(
|
||||||
|
name="Financial Discussion",
|
||||||
|
description="A group chat for financial analysis and tax advice.",
|
||||||
|
agents=[agent1, agent2],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run the group chat
|
||||||
|
asyncio.run(
|
||||||
|
group_chat.run(
|
||||||
|
"How can I establish a ROTH IRA to buy stocks and get a tax break? What are the criteria? What do you guys think?"
|
||||||
|
)
|
||||||
|
)
|
@ -0,0 +1,417 @@
|
|||||||
|
import os
|
||||||
|
from typing import List, Dict, Any, Optional, Callable
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
import json
|
||||||
|
from datetime import datetime
|
||||||
|
import inspect
|
||||||
|
import typing
|
||||||
|
from typing import Union
|
||||||
|
from swarms import Agent
|
||||||
|
from swarm_models import OpenAIChat
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ToolDefinition:
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
parameters: Dict[str, Any]
|
||||||
|
required_params: List[str]
|
||||||
|
callable: Optional[Callable] = None
|
||||||
|
|
||||||
|
|
||||||
|
def extract_type_hints(func: Callable) -> Dict[str, Any]:
|
||||||
|
"""Extract parameter types from function type hints."""
|
||||||
|
return typing.get_type_hints(func)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_tool_info(func: Callable) -> ToolDefinition:
|
||||||
|
"""Extract tool information from a callable function."""
|
||||||
|
# Get function name
|
||||||
|
name = func.__name__
|
||||||
|
|
||||||
|
# Get docstring
|
||||||
|
description = inspect.getdoc(func) or "No description available"
|
||||||
|
|
||||||
|
# Get parameters and their types
|
||||||
|
signature = inspect.signature(func)
|
||||||
|
type_hints = extract_type_hints(func)
|
||||||
|
|
||||||
|
parameters = {}
|
||||||
|
required_params = []
|
||||||
|
|
||||||
|
for param_name, param in signature.parameters.items():
|
||||||
|
# Skip self parameter for methods
|
||||||
|
if param_name == "self":
|
||||||
|
continue
|
||||||
|
|
||||||
|
param_type = type_hints.get(param_name, Any)
|
||||||
|
|
||||||
|
# Handle optional parameters
|
||||||
|
is_optional = (
|
||||||
|
param.default != inspect.Parameter.empty
|
||||||
|
or getattr(param_type, "__origin__", None) is Union
|
||||||
|
and type(None) in param_type.__args__
|
||||||
|
)
|
||||||
|
|
||||||
|
if not is_optional:
|
||||||
|
required_params.append(param_name)
|
||||||
|
|
||||||
|
parameters[param_name] = {
|
||||||
|
"type": str(param_type),
|
||||||
|
"default": (
|
||||||
|
None
|
||||||
|
if param.default is inspect.Parameter.empty
|
||||||
|
else param.default
|
||||||
|
),
|
||||||
|
"required": not is_optional,
|
||||||
|
}
|
||||||
|
|
||||||
|
return ToolDefinition(
|
||||||
|
name=name,
|
||||||
|
description=description,
|
||||||
|
parameters=parameters,
|
||||||
|
required_params=required_params,
|
||||||
|
callable=func,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FunctionSpec:
|
||||||
|
"""Specification for a callable tool function."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
parameters: Dict[
|
||||||
|
str, dict
|
||||||
|
] # Contains type and description for each parameter
|
||||||
|
return_type: str
|
||||||
|
return_description: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ExecutionStep:
|
||||||
|
"""Represents a single step in the execution plan."""
|
||||||
|
|
||||||
|
step_id: int
|
||||||
|
function_name: str
|
||||||
|
parameters: Dict[str, Any]
|
||||||
|
expected_output: str
|
||||||
|
completed: bool = False
|
||||||
|
result: Any = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ExecutionContext:
|
||||||
|
"""Maintains state during execution."""
|
||||||
|
|
||||||
|
task: str
|
||||||
|
steps: List[ExecutionStep] = field(default_factory=list)
|
||||||
|
results: Dict[int, Any] = field(default_factory=dict)
|
||||||
|
current_step: int = 0
|
||||||
|
history: List[Dict[str, Any]] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class ToolAgent:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
functions: List[Callable],
|
||||||
|
openai_api_key: str,
|
||||||
|
model_name: str = "gpt-4",
|
||||||
|
temperature: float = 0.1,
|
||||||
|
):
|
||||||
|
self.functions = {func.__name__: func for func in functions}
|
||||||
|
self.function_specs = self._analyze_functions(functions)
|
||||||
|
|
||||||
|
self.model = OpenAIChat(
|
||||||
|
openai_api_key=openai_api_key,
|
||||||
|
model_name=model_name,
|
||||||
|
temperature=temperature,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.system_prompt = self._create_system_prompt()
|
||||||
|
self.agent = Agent(
|
||||||
|
agent_name="Tool-Agent",
|
||||||
|
system_prompt=self.system_prompt,
|
||||||
|
llm=self.model,
|
||||||
|
max_loops=1,
|
||||||
|
verbose=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _analyze_functions(
|
||||||
|
self, functions: List[Callable]
|
||||||
|
) -> Dict[str, FunctionSpec]:
|
||||||
|
"""Analyze functions to create detailed specifications."""
|
||||||
|
specs = {}
|
||||||
|
for func in functions:
|
||||||
|
hints = get_type_hints(func)
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
doc = inspect.getdoc(func) or ""
|
||||||
|
|
||||||
|
# Parse docstring for parameter descriptions
|
||||||
|
param_descriptions = {}
|
||||||
|
current_param = None
|
||||||
|
for line in doc.split("\n"):
|
||||||
|
if ":param" in line:
|
||||||
|
param_name = (
|
||||||
|
line.split(":param")[1].split(":")[0].strip()
|
||||||
|
)
|
||||||
|
desc = line.split(":", 2)[-1].strip()
|
||||||
|
param_descriptions[param_name] = desc
|
||||||
|
elif ":return:" in line:
|
||||||
|
return_desc = line.split(":return:")[1].strip()
|
||||||
|
|
||||||
|
# Build parameter specifications
|
||||||
|
parameters = {}
|
||||||
|
for name, param in sig.parameters.items():
|
||||||
|
param_type = hints.get(name, Any)
|
||||||
|
parameters[name] = {
|
||||||
|
"type": str(param_type),
|
||||||
|
"type_class": param_type,
|
||||||
|
"description": param_descriptions.get(name, ""),
|
||||||
|
"required": param.default == param.empty,
|
||||||
|
}
|
||||||
|
|
||||||
|
specs[func.__name__] = FunctionSpec(
|
||||||
|
name=func.__name__,
|
||||||
|
description=doc.split("\n")[0],
|
||||||
|
parameters=parameters,
|
||||||
|
return_type=str(hints.get("return", Any)),
|
||||||
|
return_description=(
|
||||||
|
return_desc if "return_desc" in locals() else ""
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return specs
|
||||||
|
|
||||||
|
def _create_system_prompt(self) -> str:
|
||||||
|
"""Create system prompt with detailed function specifications."""
|
||||||
|
functions_desc = []
|
||||||
|
for spec in self.function_specs.values():
|
||||||
|
params_desc = []
|
||||||
|
for name, details in spec.parameters.items():
|
||||||
|
params_desc.append(
|
||||||
|
f" - {name}: {details['type']} - {details['description']}"
|
||||||
|
)
|
||||||
|
|
||||||
|
functions_desc.append(
|
||||||
|
f"""
|
||||||
|
Function: {spec.name}
|
||||||
|
Description: {spec.description}
|
||||||
|
Parameters:
|
||||||
|
{chr(10).join(params_desc)}
|
||||||
|
Returns: {spec.return_type} - {spec.return_description}
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
return f"""You are an AI agent that creates and executes plans using available functions.
|
||||||
|
|
||||||
|
Available Functions:
|
||||||
|
{chr(10).join(functions_desc)}
|
||||||
|
|
||||||
|
You must respond in two formats depending on the phase:
|
||||||
|
|
||||||
|
1. Planning Phase:
|
||||||
|
{{
|
||||||
|
"phase": "planning",
|
||||||
|
"plan": {{
|
||||||
|
"description": "Overall plan description",
|
||||||
|
"steps": [
|
||||||
|
{{
|
||||||
|
"step_id": 1,
|
||||||
|
"function": "function_name",
|
||||||
|
"parameters": {{
|
||||||
|
"param1": "value1",
|
||||||
|
"param2": "value2"
|
||||||
|
}},
|
||||||
|
"purpose": "Why this step is needed"
|
||||||
|
}}
|
||||||
|
]
|
||||||
|
}}
|
||||||
|
}}
|
||||||
|
|
||||||
|
2. Execution Phase:
|
||||||
|
{{
|
||||||
|
"phase": "execution",
|
||||||
|
"analysis": "Analysis of current result",
|
||||||
|
"next_action": {{
|
||||||
|
"type": "continue|request_input|complete",
|
||||||
|
"reason": "Why this action was chosen",
|
||||||
|
"needed_input": {{}} # If requesting input
|
||||||
|
}}
|
||||||
|
}}
|
||||||
|
|
||||||
|
Always:
|
||||||
|
- Use exact function names
|
||||||
|
- Ensure parameter types match specifications
|
||||||
|
- Provide clear reasoning for each decision
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _execute_function(
|
||||||
|
self, spec: FunctionSpec, parameters: Dict[str, Any]
|
||||||
|
) -> Any:
|
||||||
|
"""Execute a function with type checking."""
|
||||||
|
converted_params = {}
|
||||||
|
for name, value in parameters.items():
|
||||||
|
param_spec = spec.parameters[name]
|
||||||
|
try:
|
||||||
|
# Convert value to required type
|
||||||
|
param_type = param_spec["type_class"]
|
||||||
|
if param_type in (int, float, str, bool):
|
||||||
|
converted_params[name] = param_type(value)
|
||||||
|
else:
|
||||||
|
converted_params[name] = value
|
||||||
|
except (ValueError, TypeError) as e:
|
||||||
|
raise ValueError(
|
||||||
|
f"Parameter '{name}' conversion failed: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.functions[spec.name](**converted_params)
|
||||||
|
|
||||||
|
def run(self, task: str) -> Dict[str, Any]:
|
||||||
|
"""Execute task with planning and step-by-step execution."""
|
||||||
|
context = ExecutionContext(task=task)
|
||||||
|
execution_log = {
|
||||||
|
"task": task,
|
||||||
|
"start_time": datetime.utcnow().isoformat(),
|
||||||
|
"steps": [],
|
||||||
|
"final_result": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Planning phase
|
||||||
|
plan_prompt = f"Create a plan to: {task}"
|
||||||
|
plan_response = self.agent.run(plan_prompt)
|
||||||
|
plan_data = json.loads(
|
||||||
|
plan_response.replace("System:", "").strip()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert plan to execution steps
|
||||||
|
for step in plan_data["plan"]["steps"]:
|
||||||
|
context.steps.append(
|
||||||
|
ExecutionStep(
|
||||||
|
step_id=step["step_id"],
|
||||||
|
function_name=step["function"],
|
||||||
|
parameters=step["parameters"],
|
||||||
|
expected_output=step["purpose"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Execution phase
|
||||||
|
while context.current_step < len(context.steps):
|
||||||
|
step = context.steps[context.current_step]
|
||||||
|
print(
|
||||||
|
f"\nExecuting step {step.step_id}: {step.function_name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Execute function
|
||||||
|
spec = self.function_specs[step.function_name]
|
||||||
|
result = self._execute_function(
|
||||||
|
spec, step.parameters
|
||||||
|
)
|
||||||
|
context.results[step.step_id] = result
|
||||||
|
step.completed = True
|
||||||
|
step.result = result
|
||||||
|
|
||||||
|
# Get agent's analysis
|
||||||
|
analysis_prompt = f"""
|
||||||
|
Step {step.step_id} completed:
|
||||||
|
Function: {step.function_name}
|
||||||
|
Result: {json.dumps(result)}
|
||||||
|
Remaining steps: {len(context.steps) - context.current_step - 1}
|
||||||
|
|
||||||
|
Analyze the result and decide next action.
|
||||||
|
"""
|
||||||
|
|
||||||
|
analysis_response = self.agent.run(
|
||||||
|
analysis_prompt
|
||||||
|
)
|
||||||
|
analysis_data = json.loads(
|
||||||
|
analysis_response.replace(
|
||||||
|
"System:", ""
|
||||||
|
).strip()
|
||||||
|
)
|
||||||
|
|
||||||
|
execution_log["steps"].append(
|
||||||
|
{
|
||||||
|
"step_id": step.step_id,
|
||||||
|
"function": step.function_name,
|
||||||
|
"parameters": step.parameters,
|
||||||
|
"result": result,
|
||||||
|
"analysis": analysis_data,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
analysis_data["next_action"]["type"]
|
||||||
|
== "complete"
|
||||||
|
):
|
||||||
|
if (
|
||||||
|
context.current_step
|
||||||
|
< len(context.steps) - 1
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
break
|
||||||
|
|
||||||
|
context.current_step += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error in step {step.step_id}: {str(e)}")
|
||||||
|
execution_log["steps"].append(
|
||||||
|
{
|
||||||
|
"step_id": step.step_id,
|
||||||
|
"function": step.function_name,
|
||||||
|
"parameters": step.parameters,
|
||||||
|
"error": str(e),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Final analysis
|
||||||
|
final_prompt = f"""
|
||||||
|
Task completed. Results:
|
||||||
|
{json.dumps(context.results, indent=2)}
|
||||||
|
|
||||||
|
Provide final analysis and recommendations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
final_analysis = self.agent.run(final_prompt)
|
||||||
|
execution_log["final_result"] = {
|
||||||
|
"success": True,
|
||||||
|
"results": context.results,
|
||||||
|
"analysis": json.loads(
|
||||||
|
final_analysis.replace("System:", "").strip()
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
execution_log["final_result"] = {
|
||||||
|
"success": False,
|
||||||
|
"error": str(e),
|
||||||
|
}
|
||||||
|
|
||||||
|
execution_log["end_time"] = datetime.utcnow().isoformat()
|
||||||
|
return execution_log
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_investment_return(
|
||||||
|
principal: float, rate: float, years: int
|
||||||
|
) -> float:
|
||||||
|
"""Calculate investment return with compound interest.
|
||||||
|
|
||||||
|
:param principal: Initial investment amount in dollars
|
||||||
|
:param rate: Annual interest rate as decimal (e.g., 0.07 for 7%)
|
||||||
|
:param years: Number of years to invest
|
||||||
|
:return: Final investment value
|
||||||
|
"""
|
||||||
|
return principal * (1 + rate) ** years
|
||||||
|
|
||||||
|
|
||||||
|
agent = ToolAgent(
|
||||||
|
functions=[calculate_investment_return],
|
||||||
|
openai_api_key=os.getenv("OPENAI_API_KEY"),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = agent.run(
|
||||||
|
"Calculate returns for $10000 invested at 7% for 10 years"
|
||||||
|
)
|
@ -1,109 +1,87 @@
|
|||||||
import os
|
|
||||||
from typing import List, Any
|
|
||||||
from swarms.structs.agent import Agent
|
from swarms.structs.agent import Agent
|
||||||
from loguru import logger
|
from typing import List
|
||||||
import uuid
|
|
||||||
|
|
||||||
WORKSPACE_DIR = os.getenv("WORKSPACE_DIR")
|
|
||||||
uuid_for_log = str(uuid.uuid4())
|
|
||||||
logger.add(
|
|
||||||
os.path.join(
|
|
||||||
WORKSPACE_DIR,
|
|
||||||
"agents_available",
|
|
||||||
f"agents-available-{uuid_for_log}.log",
|
|
||||||
),
|
|
||||||
level="INFO",
|
|
||||||
colorize=True,
|
|
||||||
backtrace=True,
|
|
||||||
diagnose=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_agent_name(agent: Any) -> str:
|
|
||||||
"""Helper function to safely get agent name
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent (Any): The agent object to get name from
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: The agent's name if found, 'Unknown' otherwise
|
|
||||||
"""
|
|
||||||
if isinstance(agent, Agent) and hasattr(agent, "agent_name"):
|
|
||||||
return agent.agent_name
|
|
||||||
return "Unknown"
|
|
||||||
|
|
||||||
|
|
||||||
def get_agent_description(agent: Any) -> str:
|
|
||||||
"""Helper function to get agent description or system prompt preview
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent (Any): The agent object
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: Description or first 100 chars of system prompt
|
|
||||||
"""
|
|
||||||
if not isinstance(agent, Agent):
|
|
||||||
return "N/A"
|
|
||||||
|
|
||||||
if hasattr(agent, "description") and agent.description:
|
|
||||||
return agent.description
|
|
||||||
|
|
||||||
if hasattr(agent, "system_prompt") and agent.system_prompt:
|
|
||||||
return f"{agent.system_prompt[:150]}..."
|
|
||||||
|
|
||||||
return "N/A"
|
|
||||||
|
|
||||||
|
|
||||||
def showcase_available_agents(
|
def showcase_available_agents(
|
||||||
|
agents: List[Agent],
|
||||||
name: str = None,
|
name: str = None,
|
||||||
description: str = None,
|
description: str = None,
|
||||||
agents: List[Agent] = [],
|
format: str = "XML",
|
||||||
update_agents_on: bool = False,
|
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Generate a formatted string showcasing all available agents and their descriptions.
|
Format the available agents in either XML or Table format.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
agents (List[Agent]): List of Agent objects to showcase.
|
agents (List[Agent]): A list of agents to represent
|
||||||
update_agents_on (bool, optional): If True, updates each agent's system prompt with
|
name (str, optional): Name of the swarm
|
||||||
the showcase information. Defaults to False.
|
description (str, optional): Description of the swarm
|
||||||
|
format (str, optional): Output format ("XML" or "Table"). Defaults to "XML"
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: Formatted string containing agent information, including names, descriptions
|
str: Formatted string containing agent information
|
||||||
and IDs for all available agents.
|
|
||||||
"""
|
"""
|
||||||
logger.info(f"Showcasing {len(agents)} available agents")
|
|
||||||
|
|
||||||
formatted_agents = []
|
def truncate(text: str, max_length: int = 130) -> str:
|
||||||
header = f"\n####### Agents available in the swarm: {name} ############\n"
|
return (
|
||||||
header += f"{description}\n"
|
f"{text[:max_length]}..."
|
||||||
row_format = "{:<5} | {:<20} | {:<50}"
|
if len(text) > max_length
|
||||||
header_row = row_format.format("ID", "Agent Name", "Description")
|
else text
|
||||||
separator = "-" * 80
|
)
|
||||||
|
|
||||||
formatted_agents.append(header)
|
output = []
|
||||||
formatted_agents.append(separator)
|
|
||||||
formatted_agents.append(header_row)
|
|
||||||
formatted_agents.append(separator)
|
|
||||||
|
|
||||||
|
if format.upper() == "TABLE":
|
||||||
|
output.append("\n| ID | Agent Name | Description |")
|
||||||
|
output.append("|-----|------------|-------------|")
|
||||||
for idx, agent in enumerate(agents):
|
for idx, agent in enumerate(agents):
|
||||||
if not isinstance(agent, Agent):
|
if isinstance(agent, Agent):
|
||||||
logger.warning(
|
agent_name = getattr(agent, "agent_name", str(agent))
|
||||||
f"Skipping non-Agent object: {type(agent)}"
|
description = getattr(
|
||||||
|
agent,
|
||||||
|
"description",
|
||||||
|
getattr(
|
||||||
|
agent, "system_prompt", "Unknown description"
|
||||||
|
),
|
||||||
)
|
)
|
||||||
continue
|
desc = truncate(description, 50)
|
||||||
|
output.append(
|
||||||
agent_name = get_agent_name(agent)
|
f"| {idx + 1} | {agent_name} | {desc} |"
|
||||||
description = (
|
|
||||||
get_agent_description(agent)[:100] + "..."
|
|
||||||
if len(get_agent_description(agent)) > 100
|
|
||||||
else get_agent_description(agent)
|
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
formatted_agents.append(
|
output.append(
|
||||||
row_format.format(idx + 1, agent_name, description)
|
f"| {idx + 1} | {agent} | Unknown description |"
|
||||||
)
|
)
|
||||||
|
return "\n".join(output)
|
||||||
|
|
||||||
|
# Default XML format
|
||||||
|
output.append("<agents>")
|
||||||
|
if name:
|
||||||
|
output.append(f" <name>{name}</name>")
|
||||||
|
if description:
|
||||||
|
output.append(
|
||||||
|
f" <description>{truncate(description)}</description>"
|
||||||
|
)
|
||||||
|
for idx, agent in enumerate(agents):
|
||||||
|
output.append(f" <agent id='{idx + 1}'>")
|
||||||
|
if isinstance(agent, Agent):
|
||||||
|
agent_name = getattr(agent, "agent_name", str(agent))
|
||||||
|
description = getattr(
|
||||||
|
agent,
|
||||||
|
"description",
|
||||||
|
getattr(
|
||||||
|
agent, "system_prompt", "Unknown description"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
output.append(f" <name>{agent_name}</name>")
|
||||||
|
output.append(
|
||||||
|
f" <description>{truncate(description)}</description>"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
output.append(f" <name>{agent}</name>")
|
||||||
|
output.append(
|
||||||
|
" <description>Unknown description</description>"
|
||||||
|
)
|
||||||
|
output.append(" </agent>")
|
||||||
|
output.append("</agents>")
|
||||||
|
|
||||||
showcase = "\n".join(formatted_agents)
|
return "\n".join(output)
|
||||||
|
|
||||||
return showcase
|
|
||||||
|
@ -0,0 +1,203 @@
|
|||||||
|
from typing import Any
|
||||||
|
import inspect
|
||||||
|
from functools import partial
|
||||||
|
import logging
|
||||||
|
|
||||||
|
|
||||||
|
class NameResolver:
|
||||||
|
"""Utility class for resolving names of various objects"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_name(obj: Any, default: str = "unnamed_callable") -> str:
|
||||||
|
"""
|
||||||
|
Get the name of any object with multiple fallback strategies.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
obj: The object to get the name from
|
||||||
|
default: Default name if all strategies fail
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The resolved name
|
||||||
|
"""
|
||||||
|
strategies = [
|
||||||
|
# Try getting __name__ attribute
|
||||||
|
lambda x: getattr(x, "__name__", None),
|
||||||
|
# Try getting class name
|
||||||
|
lambda x: (
|
||||||
|
x.__class__.__name__
|
||||||
|
if hasattr(x, "__class__")
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
# Try getting function name if it's a partial
|
||||||
|
lambda x: (
|
||||||
|
x.func.__name__ if isinstance(x, partial) else None
|
||||||
|
),
|
||||||
|
# Try getting the name from the class's type
|
||||||
|
lambda x: type(x).__name__,
|
||||||
|
# Try getting qualname
|
||||||
|
lambda x: getattr(x, "__qualname__", None),
|
||||||
|
# Try getting the module and class name
|
||||||
|
lambda x: (
|
||||||
|
f"{x.__module__}.{x.__class__.__name__}"
|
||||||
|
if hasattr(x, "__module__")
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
# For async functions
|
||||||
|
lambda x: (
|
||||||
|
x.__name__ if inspect.iscoroutinefunction(x) else None
|
||||||
|
),
|
||||||
|
# For classes with custom __str__
|
||||||
|
lambda x: (
|
||||||
|
str(x)
|
||||||
|
if hasattr(x, "__str__")
|
||||||
|
and x.__str__ != object.__str__
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
# For wrapped functions
|
||||||
|
lambda x: (
|
||||||
|
getattr(x, "__wrapped__", None).__name__
|
||||||
|
if hasattr(x, "__wrapped__")
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Try each strategy
|
||||||
|
for strategy in strategies:
|
||||||
|
try:
|
||||||
|
name = strategy(obj)
|
||||||
|
if name and isinstance(name, str):
|
||||||
|
return name.replace(" ", "_").replace("-", "_")
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Return default if all strategies fail
|
||||||
|
return default
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_callable_details(obj: Any) -> dict:
|
||||||
|
"""
|
||||||
|
Get detailed information about a callable object.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Dictionary containing:
|
||||||
|
- name: The resolved name
|
||||||
|
- type: The type of callable
|
||||||
|
- signature: The signature if available
|
||||||
|
- module: The module name if available
|
||||||
|
- doc: The docstring if available
|
||||||
|
"""
|
||||||
|
details = {
|
||||||
|
"name": NameResolver.get_name(obj),
|
||||||
|
"type": "unknown",
|
||||||
|
"signature": None,
|
||||||
|
"module": getattr(obj, "__module__", "unknown"),
|
||||||
|
"doc": inspect.getdoc(obj)
|
||||||
|
or "No documentation available",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Determine the type
|
||||||
|
if inspect.isclass(obj):
|
||||||
|
details["type"] = "class"
|
||||||
|
elif inspect.iscoroutinefunction(obj):
|
||||||
|
details["type"] = "async_function"
|
||||||
|
elif inspect.isfunction(obj):
|
||||||
|
details["type"] = "function"
|
||||||
|
elif isinstance(obj, partial):
|
||||||
|
details["type"] = "partial"
|
||||||
|
elif callable(obj):
|
||||||
|
details["type"] = "callable"
|
||||||
|
|
||||||
|
# Try to get signature
|
||||||
|
try:
|
||||||
|
details["signature"] = str(inspect.signature(obj))
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
details["signature"] = "Unknown signature"
|
||||||
|
|
||||||
|
return details
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_safe_name(cls, obj: Any, max_retries: int = 3) -> str:
|
||||||
|
"""
|
||||||
|
Safely get a name with retries and validation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
obj: Object to get name from
|
||||||
|
max_retries: Maximum number of retry attempts
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: A valid name string
|
||||||
|
"""
|
||||||
|
retries = 0
|
||||||
|
last_error = None
|
||||||
|
|
||||||
|
while retries < max_retries:
|
||||||
|
try:
|
||||||
|
name = cls.get_name(obj)
|
||||||
|
|
||||||
|
# Validate and clean the name
|
||||||
|
if name:
|
||||||
|
# Remove invalid characters
|
||||||
|
clean_name = "".join(
|
||||||
|
c
|
||||||
|
for c in name
|
||||||
|
if c.isalnum() or c in ["_", "."]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ensure it starts with a letter or underscore
|
||||||
|
if (
|
||||||
|
not clean_name[0].isalpha()
|
||||||
|
and clean_name[0] != "_"
|
||||||
|
):
|
||||||
|
clean_name = f"_{clean_name}"
|
||||||
|
|
||||||
|
return clean_name
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
last_error = e
|
||||||
|
retries += 1
|
||||||
|
|
||||||
|
# If all retries failed, generate a unique fallback name
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
fallback = f"callable_{uuid.uuid4().hex[:8]}"
|
||||||
|
logging.warning(
|
||||||
|
f"Failed to get name after {max_retries} retries. Using fallback: {fallback}. "
|
||||||
|
f"Last error: {str(last_error)}"
|
||||||
|
)
|
||||||
|
return fallback
|
||||||
|
|
||||||
|
|
||||||
|
# # Example usage
|
||||||
|
# if __name__ == "__main__":
|
||||||
|
# def test_resolver():
|
||||||
|
# # Test cases
|
||||||
|
# class TestClass:
|
||||||
|
# def method(self):
|
||||||
|
# pass
|
||||||
|
|
||||||
|
# async def async_func():
|
||||||
|
# pass
|
||||||
|
|
||||||
|
# test_cases = [
|
||||||
|
# TestClass, # Class
|
||||||
|
# TestClass(), # Instance
|
||||||
|
# async_func, # Async function
|
||||||
|
# lambda x: x, # Lambda
|
||||||
|
# partial(print, end=""), # Partial
|
||||||
|
# TestClass.method, # Method
|
||||||
|
# print, # Built-in function
|
||||||
|
# str, # Built-in class
|
||||||
|
# ]
|
||||||
|
|
||||||
|
# resolver = NameResolver()
|
||||||
|
|
||||||
|
# print("\nName Resolution Results:")
|
||||||
|
# print("-" * 50)
|
||||||
|
# for obj in test_cases:
|
||||||
|
# details = resolver.get_callable_details(obj)
|
||||||
|
# safe_name = resolver.get_safe_name(obj)
|
||||||
|
# print(f"\nObject: {obj}")
|
||||||
|
# print(f"Safe Name: {safe_name}")
|
||||||
|
# print(f"Details: {details}")
|
||||||
|
|
||||||
|
# test_resolver()
|
@ -0,0 +1,53 @@
|
|||||||
|
import concurrent.futures
|
||||||
|
from typing import List, Union
|
||||||
|
from swarms.structs.agent import Agent
|
||||||
|
|
||||||
|
|
||||||
|
def update_system_prompts(
|
||||||
|
agents: List[Union[Agent, str]],
|
||||||
|
prompt: str,
|
||||||
|
) -> List[Agent]:
|
||||||
|
"""
|
||||||
|
Update system prompts for a list of agents concurrently.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agents: List of Agent objects or strings to update
|
||||||
|
prompt: The prompt text to append to each agent's system prompt
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of updated Agent objects
|
||||||
|
"""
|
||||||
|
if not agents:
|
||||||
|
return agents
|
||||||
|
|
||||||
|
def update_agent_prompt(agent: Union[Agent, str]) -> Agent:
|
||||||
|
# Convert string to Agent if needed
|
||||||
|
if isinstance(agent, str):
|
||||||
|
agent = Agent(
|
||||||
|
agent_name=agent,
|
||||||
|
system_prompt=prompt, # Initialize with the provided prompt
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Preserve existing prompt and append new one
|
||||||
|
existing_prompt = (
|
||||||
|
agent.system_prompt if agent.system_prompt else ""
|
||||||
|
)
|
||||||
|
agent.system_prompt = existing_prompt + "\n" + prompt
|
||||||
|
return agent
|
||||||
|
|
||||||
|
# Use ThreadPoolExecutor for concurrent execution
|
||||||
|
max_workers = min(len(agents), 4) # Reasonable thread count
|
||||||
|
with concurrent.futures.ThreadPoolExecutor(
|
||||||
|
max_workers=max_workers
|
||||||
|
) as executor:
|
||||||
|
futures = []
|
||||||
|
for agent in agents:
|
||||||
|
future = executor.submit(update_agent_prompt, agent)
|
||||||
|
futures.append(future)
|
||||||
|
|
||||||
|
# Collect results as they complete
|
||||||
|
updated_agents = []
|
||||||
|
for future in concurrent.futures.as_completed(futures):
|
||||||
|
updated_agents.append(future.result())
|
||||||
|
|
||||||
|
return updated_agents
|
@ -0,0 +1,292 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.distributed as dist
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
from loguru import logger
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StarAttentionConfig:
|
||||||
|
"""Configuration for StarAttention module.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
hidden_size: Dimension of the model's hidden states
|
||||||
|
num_attention_heads: Number of attention heads
|
||||||
|
num_hosts: Number of hosts in the distributed system
|
||||||
|
block_size: Size of each context block
|
||||||
|
anchor_size: Size of the anchor block
|
||||||
|
dropout_prob: Dropout probability (default: 0.1)
|
||||||
|
layer_norm_eps: Layer normalization epsilon (default: 1e-12)
|
||||||
|
"""
|
||||||
|
|
||||||
|
hidden_size: int
|
||||||
|
num_attention_heads: int
|
||||||
|
num_hosts: int
|
||||||
|
block_size: int
|
||||||
|
anchor_size: int
|
||||||
|
dropout_prob: float = 0.1
|
||||||
|
layer_norm_eps: float = 1e-12
|
||||||
|
|
||||||
|
|
||||||
|
class StarAttention(nn.Module):
|
||||||
|
"""
|
||||||
|
Implementation of Star Attention mechanism for distributed inference.
|
||||||
|
|
||||||
|
The module implements a two-phase attention mechanism:
|
||||||
|
1. Local Context Encoding with Anchor Blocks
|
||||||
|
2. Query Encoding and Output Generation with Global Attention
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: StarAttentionConfig):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if config.hidden_size % config.num_attention_heads != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"Hidden size {config.hidden_size} not divisible by number of attention "
|
||||||
|
f"heads {config.num_attention_heads}"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.config = config
|
||||||
|
self.head_dim = (
|
||||||
|
config.hidden_size // config.num_attention_heads
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize components
|
||||||
|
self.query = nn.Linear(config.hidden_size, config.hidden_size)
|
||||||
|
self.key = nn.Linear(config.hidden_size, config.hidden_size)
|
||||||
|
self.value = nn.Linear(config.hidden_size, config.hidden_size)
|
||||||
|
|
||||||
|
self.dropout = nn.Dropout(config.dropout_prob)
|
||||||
|
self.layer_norm = nn.LayerNorm(
|
||||||
|
config.hidden_size, eps=config.layer_norm_eps
|
||||||
|
)
|
||||||
|
|
||||||
|
# KV cache for storing computed key/value pairs
|
||||||
|
self.kv_cache = {}
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Initialized StarAttention with config: {config}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _split_heads(
|
||||||
|
self, tensor: torch.Tensor, num_heads: int
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Split the last dimension into (num_heads, head_dim)."""
|
||||||
|
batch_size, seq_len, _ = tensor.size()
|
||||||
|
tensor = tensor.view(
|
||||||
|
batch_size, seq_len, num_heads, self.head_dim
|
||||||
|
)
|
||||||
|
# Transpose to (batch_size, num_heads, seq_len, head_dim)
|
||||||
|
return tensor.transpose(1, 2)
|
||||||
|
|
||||||
|
def _merge_heads(self, tensor: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Merge the head dimension back into hidden_size."""
|
||||||
|
batch_size, _, seq_len, _ = tensor.size()
|
||||||
|
tensor = tensor.transpose(1, 2)
|
||||||
|
return tensor.reshape(
|
||||||
|
batch_size, seq_len, self.config.hidden_size
|
||||||
|
)
|
||||||
|
|
||||||
|
def _compute_attention_scores(
|
||||||
|
self,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
mask: Optional[torch.Tensor] = None,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Compute attention scores and weighted values."""
|
||||||
|
# Scale dot-product attention
|
||||||
|
scores = torch.matmul(
|
||||||
|
query, key.transpose(-2, -1)
|
||||||
|
) / math.sqrt(self.head_dim)
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
scores = scores.masked_fill(mask == 0, float("-inf"))
|
||||||
|
|
||||||
|
# Online softmax computation
|
||||||
|
attention_probs = torch.nn.functional.softmax(scores, dim=-1)
|
||||||
|
attention_probs = self.dropout(attention_probs)
|
||||||
|
|
||||||
|
context = torch.matmul(attention_probs, value)
|
||||||
|
|
||||||
|
return context, attention_probs
|
||||||
|
|
||||||
|
def phase1_local_context_encoding(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
host_id: int,
|
||||||
|
device: Union[str, torch.device] = "cuda",
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Phase 1: Local Context Encoding with Anchor Blocks
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_ids: Input tensor of shape (batch_size, seq_len)
|
||||||
|
host_id: ID of the current host
|
||||||
|
device: Device to run computations on
|
||||||
|
"""
|
||||||
|
logger.debug(f"Starting Phase 1 on host {host_id}")
|
||||||
|
|
||||||
|
# Calculate block assignments
|
||||||
|
block_start = host_id * self.config.block_size
|
||||||
|
block_end = block_start + self.config.block_size
|
||||||
|
|
||||||
|
# Get local block
|
||||||
|
local_block = input_ids[:, block_start:block_end].to(device)
|
||||||
|
|
||||||
|
# Get anchor block (first block)
|
||||||
|
anchor_block = input_ids[:, : self.config.anchor_size].to(
|
||||||
|
device
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compute KV pairs for local block
|
||||||
|
local_hidden = self.layer_norm(local_block)
|
||||||
|
local_key = self._split_heads(
|
||||||
|
self.key(local_hidden), self.config.num_attention_heads
|
||||||
|
)
|
||||||
|
local_value = self._split_heads(
|
||||||
|
self.value(local_hidden), self.config.num_attention_heads
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store in KV cache
|
||||||
|
self.kv_cache[host_id] = {
|
||||||
|
"key": local_key,
|
||||||
|
"value": local_value,
|
||||||
|
"anchor_key": (
|
||||||
|
None
|
||||||
|
if host_id == 0
|
||||||
|
else self._split_heads(
|
||||||
|
self.key(self.layer_norm(anchor_block)),
|
||||||
|
self.config.num_attention_heads,
|
||||||
|
)
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"Phase 1 complete on host {host_id}. KV cache shapes - "
|
||||||
|
f"key: {local_key.shape}, value: {local_value.shape}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def phase2_query_encoding(
|
||||||
|
self,
|
||||||
|
query_input: torch.Tensor,
|
||||||
|
host_id: int,
|
||||||
|
is_query_host: bool,
|
||||||
|
device: Union[str, torch.device] = "cuda",
|
||||||
|
) -> Optional[torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Phase 2: Query Encoding and Output Generation
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query_input: Query tensor of shape (batch_size, seq_len, hidden_size)
|
||||||
|
host_id: ID of the current host
|
||||||
|
is_query_host: Whether this host is the query host
|
||||||
|
device: Device to run computations on
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Output tensor if this is the query host, None otherwise
|
||||||
|
"""
|
||||||
|
logger.debug(f"Starting Phase 2 on host {host_id}")
|
||||||
|
|
||||||
|
# Transform query
|
||||||
|
query_hidden = self.layer_norm(query_input)
|
||||||
|
query = self._split_heads(
|
||||||
|
self.query(query_hidden), self.config.num_attention_heads
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compute local attention scores
|
||||||
|
local_context, local_probs = self._compute_attention_scores(
|
||||||
|
query,
|
||||||
|
self.kv_cache[host_id]["key"],
|
||||||
|
self.kv_cache[host_id]["value"],
|
||||||
|
)
|
||||||
|
|
||||||
|
if not is_query_host:
|
||||||
|
# Non-query hosts send their local attention statistics
|
||||||
|
dist.send(local_probs, dst=self.config.num_hosts - 1)
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Query host aggregates attention from all hosts
|
||||||
|
all_attention_probs = [local_probs]
|
||||||
|
for src_rank in range(self.config.num_hosts - 1):
|
||||||
|
probs = torch.empty_like(local_probs)
|
||||||
|
dist.recv(probs, src=src_rank)
|
||||||
|
all_attention_probs.append(probs)
|
||||||
|
|
||||||
|
# Compute global attention
|
||||||
|
torch.mean(torch.stack(all_attention_probs), dim=0)
|
||||||
|
|
||||||
|
# Final output computation
|
||||||
|
output = self._merge_heads(local_context)
|
||||||
|
output = self.dropout(output)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"Phase 2 complete on host {host_id}. Output shape: {output.shape}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
query_input: torch.Tensor,
|
||||||
|
host_id: int,
|
||||||
|
is_query_host: bool,
|
||||||
|
device: Union[str, torch.device] = "cuda",
|
||||||
|
) -> Optional[torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Forward pass of the StarAttention module.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_ids: Input tensor of shape (batch_size, seq_len)
|
||||||
|
query_input: Query tensor of shape (batch_size, seq_len, hidden_size)
|
||||||
|
host_id: ID of the current host
|
||||||
|
is_query_host: Whether this host is the query host
|
||||||
|
device: Device to run computations on
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Output tensor if this is the query host, None otherwise
|
||||||
|
"""
|
||||||
|
# Phase 1: Local Context Encoding
|
||||||
|
self.phase1_local_context_encoding(input_ids, host_id, device)
|
||||||
|
|
||||||
|
# Phase 2: Query Encoding and Output Generation
|
||||||
|
return self.phase2_query_encoding(
|
||||||
|
query_input, host_id, is_query_host, device
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Example forward pass
|
||||||
|
config = StarAttentionConfig(
|
||||||
|
hidden_size=768,
|
||||||
|
num_attention_heads=12,
|
||||||
|
num_hosts=3,
|
||||||
|
block_size=512,
|
||||||
|
anchor_size=128,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize model
|
||||||
|
model = StarAttention(config)
|
||||||
|
|
||||||
|
# Example input tensors
|
||||||
|
batch_size = 4
|
||||||
|
seq_len = 512
|
||||||
|
input_ids = torch.randint(
|
||||||
|
0, 1000, (batch_size, seq_len)
|
||||||
|
) # Random input IDs
|
||||||
|
query_input = torch.randn(
|
||||||
|
batch_size, seq_len, config.hidden_size
|
||||||
|
) # Random query input
|
||||||
|
|
||||||
|
# Example forward pass for query host (host_id = 2)
|
||||||
|
output = model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
query_input=query_input,
|
||||||
|
host_id=2,
|
||||||
|
is_query_host=True,
|
||||||
|
device="cpu",
|
||||||
|
)
|
||||||
|
|
||||||
|
print(output)
|
Loading…
Reference in new issue