commit
95e025454e
@ -0,0 +1,56 @@
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from swarm_models import OpenAIChat
|
||||
|
||||
from swarms import Agent
|
||||
from swarms.prompts.finance_agent_sys_prompt import (
|
||||
FINANCIAL_AGENT_SYS_PROMPT,
|
||||
)
|
||||
from async_executor import HighSpeedExecutor
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# Get the OpenAI API key from the environment variable
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
|
||||
# Create an instance of the OpenAIChat class
|
||||
model = OpenAIChat(
|
||||
openai_api_key=api_key, model_name="gpt-4o-mini", temperature=0.1
|
||||
)
|
||||
|
||||
# Initialize the agent
|
||||
agent = Agent(
|
||||
agent_name="Financial-Analysis-Agent",
|
||||
system_prompt=FINANCIAL_AGENT_SYS_PROMPT,
|
||||
llm=model,
|
||||
max_loops=1,
|
||||
# autosave=True,
|
||||
# dashboard=False,
|
||||
# verbose=True,
|
||||
# dynamic_temperature_enabled=True,
|
||||
# saved_state_path="finance_agent.json",
|
||||
# user_name="swarms_corp",
|
||||
# retry_attempts=1,
|
||||
# context_length=200000,
|
||||
# return_step_meta=True,
|
||||
# output_type="json", # "json", "dict", "csv" OR "string" soon "yaml" and
|
||||
# auto_generate_prompt=False, # Auto generate prompt for the agent based on name, description, and system prompt, task
|
||||
# # artifacts_on=True,
|
||||
# artifacts_output_path="roth_ira_report",
|
||||
# artifacts_file_extension=".txt",
|
||||
# max_tokens=8000,
|
||||
# return_history=True,
|
||||
)
|
||||
|
||||
|
||||
def execute_agent(
|
||||
task: str = "How can I establish a ROTH IRA to buy stocks and get a tax break? What are the criteria. Create a report on this question.",
|
||||
):
|
||||
return agent.run(task)
|
||||
|
||||
|
||||
executor = HighSpeedExecutor()
|
||||
results = executor.run(execute_agent, 2)
|
||||
|
||||
print(results)
|
@ -0,0 +1,131 @@
|
||||
import asyncio
|
||||
import multiprocessing as mp
|
||||
import time
|
||||
from functools import partial
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
|
||||
class HighSpeedExecutor:
|
||||
def __init__(self, num_processes: int = None):
|
||||
"""
|
||||
Initialize the executor with configurable number of processes.
|
||||
If num_processes is None, it uses CPU count.
|
||||
"""
|
||||
self.num_processes = num_processes or mp.cpu_count()
|
||||
|
||||
async def _worker(
|
||||
self,
|
||||
queue: asyncio.Queue,
|
||||
func: Any,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Async worker that processes tasks from the queue"""
|
||||
while True:
|
||||
try:
|
||||
# Non-blocking get from queue
|
||||
await queue.get()
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None, partial(func, *args, **kwargs)
|
||||
)
|
||||
queue.task_done()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
async def _distribute_tasks(
|
||||
self, num_tasks: int, queue: asyncio.Queue
|
||||
):
|
||||
"""Distribute tasks across the queue"""
|
||||
for i in range(num_tasks):
|
||||
await queue.put(i)
|
||||
|
||||
async def execute_batch(
|
||||
self,
|
||||
func: Any,
|
||||
num_executions: int,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Union[int, float]]:
|
||||
"""
|
||||
Execute the given function multiple times concurrently.
|
||||
|
||||
Args:
|
||||
func: The function to execute
|
||||
num_executions: Number of times to execute the function
|
||||
*args, **kwargs: Arguments to pass to the function
|
||||
|
||||
Returns:
|
||||
A dictionary containing the number of executions, duration, and executions per second.
|
||||
"""
|
||||
queue = asyncio.Queue()
|
||||
|
||||
# Create worker tasks
|
||||
workers = [
|
||||
asyncio.create_task(
|
||||
self._worker(queue, func, *args, **kwargs)
|
||||
)
|
||||
for _ in range(self.num_processes)
|
||||
]
|
||||
|
||||
# Start timing
|
||||
start_time = time.perf_counter()
|
||||
|
||||
# Distribute tasks
|
||||
await self._distribute_tasks(num_executions, queue)
|
||||
|
||||
# Wait for all tasks to complete
|
||||
await queue.join()
|
||||
|
||||
# Cancel workers
|
||||
for worker in workers:
|
||||
worker.cancel()
|
||||
|
||||
# Wait for all workers to finish
|
||||
await asyncio.gather(*workers, return_exceptions=True)
|
||||
|
||||
end_time = time.perf_counter()
|
||||
duration = end_time - start_time
|
||||
|
||||
return {
|
||||
"executions": num_executions,
|
||||
"duration": duration,
|
||||
"executions_per_second": num_executions / duration,
|
||||
}
|
||||
|
||||
def run(
|
||||
self,
|
||||
func: Any,
|
||||
num_executions: int,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
):
|
||||
return asyncio.run(
|
||||
self.execute_batch(func, num_executions, *args, **kwargs)
|
||||
)
|
||||
|
||||
|
||||
# def example_function(x: int = 0) -> int:
|
||||
# """Example function to execute"""
|
||||
# return x * x
|
||||
|
||||
|
||||
# async def main():
|
||||
# # Create executor with number of CPU cores
|
||||
# executor = HighSpeedExecutor()
|
||||
|
||||
# # Execute the function 1000 times
|
||||
# result = await executor.execute_batch(
|
||||
# example_function, num_executions=1000, x=42
|
||||
# )
|
||||
|
||||
# print(
|
||||
# f"Completed {result['executions']} executions in {result['duration']:.2f} seconds"
|
||||
# )
|
||||
# print(
|
||||
# f"Rate: {result['executions_per_second']:.2f} executions/second"
|
||||
# )
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# # Run the async main function
|
||||
# asyncio.run(main())
|
@ -1,18 +1,27 @@
|
||||
|
||||
/* Further customization as needed */
|
||||
|
||||
/* * Further customization as needed */ */
|
||||
|
||||
.md-typeset__table {
|
||||
min-width: 100%;
|
||||
min-width: 100%;
|
||||
}
|
||||
|
||||
.md-typeset table:not([class]) {
|
||||
display: table;
|
||||
}
|
||||
|
||||
/*
|
||||
:root {
|
||||
--md-primary-fg-color: #EE0F0F;
|
||||
--md-primary-fg-color--light: #ECB7B7;
|
||||
--md-primary-fg-color--dark: #90030C;
|
||||
} */
|
||||
/* Dark mode */
|
||||
[data-md-color-scheme="slate"] {
|
||||
--md-default-bg-color: black;
|
||||
}
|
||||
|
||||
.header__ellipsis {
|
||||
color: black;
|
||||
}
|
||||
|
||||
.md-copyright__highlight {
|
||||
color: black;
|
||||
}
|
||||
|
||||
|
||||
.md-header.md-header--shadow {
|
||||
color: black;
|
||||
}
|
@ -0,0 +1,308 @@
|
||||
import os
|
||||
from swarms import Agent
|
||||
from swarm_models import OpenAIChat
|
||||
from web3 import Web3
|
||||
from typing import Dict, Optional, Any
|
||||
from datetime import datetime
|
||||
import asyncio
|
||||
from loguru import logger
|
||||
from dotenv import load_dotenv
|
||||
import csv
|
||||
import requests
|
||||
import time
|
||||
|
||||
BLOCKCHAIN_AGENT_PROMPT = """
|
||||
You are an expert blockchain and cryptocurrency analyst with deep knowledge of Ethereum markets and DeFi ecosystems.
|
||||
You have access to real-time ETH price data and transaction information.
|
||||
|
||||
For each transaction, analyze:
|
||||
|
||||
1. MARKET CONTEXT
|
||||
- Current ETH price and what this transaction means in USD terms
|
||||
- How this movement compares to typical market volumes
|
||||
- Whether this could impact ETH price
|
||||
|
||||
2. BEHAVIORAL ANALYSIS
|
||||
- Whether this appears to be institutional, whale, or protocol movement
|
||||
- If this fits any known wallet patterns or behaviors
|
||||
- Signs of smart contract interaction or DeFi activity
|
||||
|
||||
3. RISK & IMPLICATIONS
|
||||
- Potential market impact or price influence
|
||||
- Signs of potential market manipulation or unusual activity
|
||||
- Protocol or DeFi risks if applicable
|
||||
|
||||
4. STRATEGIC INSIGHTS
|
||||
- What traders should know about this movement
|
||||
- Potential chain reactions or follow-up effects
|
||||
- Market opportunities or risks created
|
||||
|
||||
Write naturally but precisely. Focus on actionable insights and important patterns.
|
||||
Your analysis helps traders and researchers understand significant market movements in real-time."""
|
||||
|
||||
|
||||
class EthereumAnalyzer:
|
||||
def __init__(self, min_value_eth: float = 100.0):
|
||||
load_dotenv()
|
||||
|
||||
logger.add(
|
||||
"eth_analysis.log",
|
||||
rotation="500 MB",
|
||||
retention="10 days",
|
||||
level="INFO",
|
||||
format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}",
|
||||
)
|
||||
|
||||
self.w3 = Web3(
|
||||
Web3.HTTPProvider(
|
||||
"https://mainnet.infura.io/v3/9aa3d95b3bc440fa88ea12eaa4456161"
|
||||
)
|
||||
)
|
||||
if not self.w3.is_connected():
|
||||
raise ConnectionError(
|
||||
"Failed to connect to Ethereum network"
|
||||
)
|
||||
|
||||
self.min_value_eth = min_value_eth
|
||||
self.last_processed_block = self.w3.eth.block_number
|
||||
self.eth_price = self.get_eth_price()
|
||||
self.last_price_update = time.time()
|
||||
|
||||
# Initialize AI agent
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
if not api_key:
|
||||
raise ValueError(
|
||||
"OpenAI API key not found in environment variables"
|
||||
)
|
||||
|
||||
model = OpenAIChat(
|
||||
openai_api_key=api_key,
|
||||
model_name="gpt-4",
|
||||
temperature=0.1,
|
||||
)
|
||||
|
||||
self.agent = Agent(
|
||||
agent_name="Ethereum-Analysis-Agent",
|
||||
system_prompt=BLOCKCHAIN_AGENT_PROMPT,
|
||||
llm=model,
|
||||
max_loops=1,
|
||||
autosave=True,
|
||||
dashboard=False,
|
||||
verbose=True,
|
||||
dynamic_temperature_enabled=True,
|
||||
saved_state_path="eth_agent.json",
|
||||
user_name="eth_analyzer",
|
||||
retry_attempts=1,
|
||||
context_length=200000,
|
||||
output_type="string",
|
||||
streaming_on=False,
|
||||
)
|
||||
|
||||
self.csv_filename = "ethereum_analysis.csv"
|
||||
self.initialize_csv()
|
||||
|
||||
def get_eth_price(self) -> float:
|
||||
"""Get current ETH price from CoinGecko API."""
|
||||
try:
|
||||
response = requests.get(
|
||||
"https://api.coingecko.com/api/v3/simple/price",
|
||||
params={"ids": "ethereum", "vs_currencies": "usd"},
|
||||
)
|
||||
return float(response.json()["ethereum"]["usd"])
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching ETH price: {str(e)}")
|
||||
return 0.0
|
||||
|
||||
def update_eth_price(self):
|
||||
"""Update ETH price if more than 5 minutes have passed."""
|
||||
if time.time() - self.last_price_update > 300: # 5 minutes
|
||||
self.eth_price = self.get_eth_price()
|
||||
self.last_price_update = time.time()
|
||||
logger.info(f"Updated ETH price: ${self.eth_price:,.2f}")
|
||||
|
||||
def initialize_csv(self):
|
||||
"""Initialize CSV file with headers."""
|
||||
headers = [
|
||||
"timestamp",
|
||||
"transaction_hash",
|
||||
"from_address",
|
||||
"to_address",
|
||||
"value_eth",
|
||||
"value_usd",
|
||||
"eth_price",
|
||||
"gas_used",
|
||||
"gas_price_gwei",
|
||||
"block_number",
|
||||
"analysis",
|
||||
]
|
||||
|
||||
if not os.path.exists(self.csv_filename):
|
||||
with open(self.csv_filename, "w", newline="") as f:
|
||||
writer = csv.writer(f)
|
||||
writer.writerow(headers)
|
||||
|
||||
async def analyze_transaction(
|
||||
self, tx_hash: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Analyze a single transaction."""
|
||||
try:
|
||||
tx = self.w3.eth.get_transaction(tx_hash)
|
||||
receipt = self.w3.eth.get_transaction_receipt(tx_hash)
|
||||
|
||||
value_eth = float(self.w3.from_wei(tx.value, "ether"))
|
||||
|
||||
if value_eth < self.min_value_eth:
|
||||
return None
|
||||
|
||||
block = self.w3.eth.get_block(tx.blockNumber)
|
||||
|
||||
# Update ETH price if needed
|
||||
self.update_eth_price()
|
||||
|
||||
value_usd = value_eth * self.eth_price
|
||||
|
||||
analysis = {
|
||||
"timestamp": datetime.fromtimestamp(
|
||||
block.timestamp
|
||||
).isoformat(),
|
||||
"transaction_hash": tx_hash.hex(),
|
||||
"from_address": tx["from"],
|
||||
"to_address": tx.to if tx.to else "Contract Creation",
|
||||
"value_eth": value_eth,
|
||||
"value_usd": value_usd,
|
||||
"eth_price": self.eth_price,
|
||||
"gas_used": receipt.gasUsed,
|
||||
"gas_price_gwei": float(
|
||||
self.w3.from_wei(tx.gasPrice, "gwei")
|
||||
),
|
||||
"block_number": tx.blockNumber,
|
||||
}
|
||||
|
||||
# Check if it's a contract
|
||||
if tx.to:
|
||||
code = self.w3.eth.get_code(tx.to)
|
||||
analysis["is_contract"] = len(code) > 0
|
||||
|
||||
# Get contract events
|
||||
if analysis["is_contract"]:
|
||||
analysis["events"] = receipt.logs
|
||||
|
||||
return analysis
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error analyzing transaction {tx_hash}: {str(e)}"
|
||||
)
|
||||
return None
|
||||
|
||||
def prepare_analysis_prompt(self, tx_data: Dict[str, Any]) -> str:
|
||||
"""Prepare detailed analysis prompt including price context."""
|
||||
value_usd = tx_data["value_usd"]
|
||||
eth_price = tx_data["eth_price"]
|
||||
|
||||
prompt = f"""Analyze this Ethereum transaction in current market context:
|
||||
|
||||
Transaction Details:
|
||||
- Value: {tx_data['value_eth']:.2f} ETH (${value_usd:,.2f} at current price)
|
||||
- Current ETH Price: ${eth_price:,.2f}
|
||||
- From: {tx_data['from_address']}
|
||||
- To: {tx_data['to_address']}
|
||||
- Contract Interaction: {tx_data.get('is_contract', False)}
|
||||
- Gas Used: {tx_data['gas_used']:,} units
|
||||
- Gas Price: {tx_data['gas_price_gwei']:.2f} Gwei
|
||||
- Block: {tx_data['block_number']}
|
||||
- Timestamp: {tx_data['timestamp']}
|
||||
|
||||
{f"Event Count: {len(tx_data['events'])} events" if tx_data.get('events') else "No contract events"}
|
||||
|
||||
Consider the transaction's significance given the current ETH price of ${eth_price:,.2f} and total USD value of ${value_usd:,.2f}.
|
||||
Analyze market impact, patterns, risks, and strategic implications."""
|
||||
|
||||
return prompt
|
||||
|
||||
def save_to_csv(self, tx_data: Dict[str, Any], ai_analysis: str):
|
||||
"""Save transaction data and analysis to CSV."""
|
||||
row = [
|
||||
tx_data["timestamp"],
|
||||
tx_data["transaction_hash"],
|
||||
tx_data["from_address"],
|
||||
tx_data["to_address"],
|
||||
tx_data["value_eth"],
|
||||
tx_data["value_usd"],
|
||||
tx_data["eth_price"],
|
||||
tx_data["gas_used"],
|
||||
tx_data["gas_price_gwei"],
|
||||
tx_data["block_number"],
|
||||
ai_analysis.replace("\n", " "),
|
||||
]
|
||||
|
||||
with open(self.csv_filename, "a", newline="") as f:
|
||||
writer = csv.writer(f)
|
||||
writer.writerow(row)
|
||||
|
||||
async def monitor_transactions(self):
|
||||
"""Monitor and analyze transactions one at a time."""
|
||||
logger.info(
|
||||
f"Starting transaction monitor (minimum value: {self.min_value_eth} ETH)"
|
||||
)
|
||||
|
||||
while True:
|
||||
try:
|
||||
current_block = self.w3.eth.block_number
|
||||
block = self.w3.eth.get_block(
|
||||
current_block, full_transactions=True
|
||||
)
|
||||
|
||||
for tx in block.transactions:
|
||||
tx_analysis = await self.analyze_transaction(
|
||||
tx.hash
|
||||
)
|
||||
|
||||
if tx_analysis:
|
||||
# Get AI analysis
|
||||
analysis_prompt = (
|
||||
self.prepare_analysis_prompt(tx_analysis)
|
||||
)
|
||||
ai_analysis = self.agent.run(analysis_prompt)
|
||||
print(ai_analysis)
|
||||
|
||||
# Save to CSV
|
||||
self.save_to_csv(tx_analysis, ai_analysis)
|
||||
|
||||
# Print analysis
|
||||
print("\n" + "=" * 50)
|
||||
print("New Transaction Analysis")
|
||||
print(
|
||||
f"Hash: {tx_analysis['transaction_hash']}"
|
||||
)
|
||||
print(
|
||||
f"Value: {tx_analysis['value_eth']:.2f} ETH (${tx_analysis['value_usd']:,.2f})"
|
||||
)
|
||||
print(
|
||||
f"Current ETH Price: ${self.eth_price:,.2f}"
|
||||
)
|
||||
print("=" * 50)
|
||||
print(ai_analysis)
|
||||
print("=" * 50 + "\n")
|
||||
|
||||
await asyncio.sleep(1) # Wait for next block
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in monitoring loop: {str(e)}")
|
||||
await asyncio.sleep(1)
|
||||
|
||||
|
||||
async def main():
|
||||
"""Entry point for the analysis system."""
|
||||
analyzer = EthereumAnalyzer(min_value_eth=100.0)
|
||||
await analyzer.monitor_transactions()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Starting Ethereum Transaction Analyzer...")
|
||||
print("Saving results to ethereum_analysis.csv")
|
||||
print("Press Ctrl+C to stop")
|
||||
try:
|
||||
asyncio.run(main())
|
||||
except KeyboardInterrupt:
|
||||
print("\nStopping analyzer...")
|
Can't render this file because it has a wrong number of fields in line 4.
|
@ -0,0 +1,244 @@
|
||||
import os
|
||||
import asyncio
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict, Any
|
||||
from swarms import Agent
|
||||
from swarm_models import OpenAIChat
|
||||
from dotenv import load_dotenv
|
||||
from swarms.utils.formatter import formatter
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
# Get OpenAI API key
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
|
||||
|
||||
# Define Pydantic schema for agent outputs
|
||||
class AgentOutput(BaseModel):
|
||||
"""Schema for capturing the output of each agent."""
|
||||
|
||||
agent_name: str = Field(..., description="The name of the agent")
|
||||
message: str = Field(
|
||||
...,
|
||||
description="The agent's response or contribution to the group chat",
|
||||
)
|
||||
metadata: Dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Additional metadata about the agent's response",
|
||||
)
|
||||
|
||||
|
||||
class GroupChat:
|
||||
"""
|
||||
GroupChat class to enable multiple agents to communicate in an asynchronous group chat.
|
||||
Each agent is aware of all other agents, every message exchanged, and the social context.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
description: str,
|
||||
agents: List[Agent],
|
||||
max_loops: int = 1,
|
||||
):
|
||||
"""
|
||||
Initialize the GroupChat.
|
||||
|
||||
Args:
|
||||
name (str): Name of the group chat.
|
||||
description (str): Description of the purpose of the group chat.
|
||||
agents (List[Agent]): A list of agents participating in the chat.
|
||||
max_loops (int): Maximum number of loops to run through all agents.
|
||||
"""
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.agents = agents
|
||||
self.max_loops = max_loops
|
||||
self.chat_history = (
|
||||
[]
|
||||
) # Stores all messages exchanged in the chat
|
||||
|
||||
formatter.print_panel(
|
||||
f"Initialized GroupChat '{self.name}' with {len(self.agents)} agents. Max loops: {self.max_loops}",
|
||||
title="Groupchat Swarm",
|
||||
)
|
||||
|
||||
async def _agent_conversation(
|
||||
self, agent: Agent, input_message: str
|
||||
) -> AgentOutput:
|
||||
"""
|
||||
Facilitate a single agent's response to the chat.
|
||||
|
||||
Args:
|
||||
agent (Agent): The agent responding.
|
||||
input_message (str): The message triggering the response.
|
||||
|
||||
Returns:
|
||||
AgentOutput: The agent's response captured in a structured format.
|
||||
"""
|
||||
formatter.print_panel(
|
||||
f"Agent '{agent.agent_name}' is responding to the message: {input_message}",
|
||||
title="Groupchat Swarm",
|
||||
)
|
||||
response = await asyncio.to_thread(agent.run, input_message)
|
||||
|
||||
output = AgentOutput(
|
||||
agent_name=agent.agent_name,
|
||||
message=response,
|
||||
metadata={"context_length": agent.context_length},
|
||||
)
|
||||
# logger.debug(f"Agent '{agent.agent_name}' response: {response}")
|
||||
return output
|
||||
|
||||
async def _run(self, initial_message: str) -> List[AgentOutput]:
|
||||
"""
|
||||
Execute the group chat asynchronously, looping through all agents up to max_loops.
|
||||
|
||||
Args:
|
||||
initial_message (str): The initial message to start the chat.
|
||||
|
||||
Returns:
|
||||
List[AgentOutput]: The responses of all agents across all loops.
|
||||
"""
|
||||
formatter.print_panel(
|
||||
f"Starting group chat '{self.name}' with initial message: {initial_message}",
|
||||
title="Groupchat Swarm",
|
||||
)
|
||||
self.chat_history.append(
|
||||
{"sender": "System", "message": initial_message}
|
||||
)
|
||||
|
||||
outputs = []
|
||||
for loop in range(self.max_loops):
|
||||
formatter.print_panel(
|
||||
f"Group chat loop {loop + 1}/{self.max_loops}",
|
||||
title="Groupchat Swarm",
|
||||
)
|
||||
|
||||
for agent in self.agents:
|
||||
# Create a custom input message for each agent, sharing the chat history and social context
|
||||
input_message = (
|
||||
f"Chat History:\n{self._format_chat_history()}\n\n"
|
||||
f"Participants:\n"
|
||||
+ "\n".join(
|
||||
[
|
||||
f"- {a.agent_name}: {a.system_prompt}"
|
||||
for a in self.agents
|
||||
]
|
||||
)
|
||||
+ f"\n\nNew Message: {initial_message}\n\n"
|
||||
f"You are '{agent.agent_name}'. Remember to keep track of the social context, who is speaking, "
|
||||
f"and respond accordingly based on your role: {agent.system_prompt}."
|
||||
)
|
||||
|
||||
# Collect agent's response
|
||||
output = await self._agent_conversation(
|
||||
agent, input_message
|
||||
)
|
||||
outputs.append(output)
|
||||
|
||||
# Update chat history with the agent's response
|
||||
self.chat_history.append(
|
||||
{
|
||||
"sender": agent.agent_name,
|
||||
"message": output.message,
|
||||
}
|
||||
)
|
||||
|
||||
formatter.print_panel(
|
||||
"Group chat completed. All agent responses captured.",
|
||||
title="Groupchat Swarm",
|
||||
)
|
||||
return outputs
|
||||
|
||||
def run(self, task: str, *args, **kwargs):
|
||||
return asyncio.run(self.run(task, *args, **kwargs))
|
||||
|
||||
def _format_chat_history(self) -> str:
|
||||
"""
|
||||
Format the chat history for agents to understand the context.
|
||||
|
||||
Returns:
|
||||
str: The formatted chat history as a string.
|
||||
"""
|
||||
return "\n".join(
|
||||
[
|
||||
f"{entry['sender']}: {entry['message']}"
|
||||
for entry in self.chat_history
|
||||
]
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""String representation of the group chat's outputs."""
|
||||
return self._format_chat_history()
|
||||
|
||||
def to_json(self) -> str:
|
||||
"""JSON representation of the group chat's outputs."""
|
||||
return [
|
||||
{"sender": entry["sender"], "message": entry["message"]}
|
||||
for entry in self.chat_history
|
||||
]
|
||||
|
||||
|
||||
# Example Usage
|
||||
if __name__ == "__main__":
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# Get the OpenAI API key from the environment variable
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
|
||||
# Create an instance of the OpenAIChat class
|
||||
model = OpenAIChat(
|
||||
openai_api_key=api_key,
|
||||
model_name="gpt-4o-mini",
|
||||
temperature=0.1,
|
||||
)
|
||||
|
||||
# Example agents
|
||||
agent1 = Agent(
|
||||
agent_name="Financial-Analysis-Agent",
|
||||
system_prompt="You are a financial analyst specializing in investment strategies.",
|
||||
llm=model,
|
||||
max_loops=1,
|
||||
autosave=False,
|
||||
dashboard=False,
|
||||
verbose=True,
|
||||
dynamic_temperature_enabled=True,
|
||||
user_name="swarms_corp",
|
||||
retry_attempts=1,
|
||||
context_length=200000,
|
||||
output_type="string",
|
||||
streaming_on=False,
|
||||
)
|
||||
|
||||
agent2 = Agent(
|
||||
agent_name="Tax-Adviser-Agent",
|
||||
system_prompt="You are a tax adviser who provides clear and concise guidance on tax-related queries.",
|
||||
llm=model,
|
||||
max_loops=1,
|
||||
autosave=False,
|
||||
dashboard=False,
|
||||
verbose=True,
|
||||
dynamic_temperature_enabled=True,
|
||||
user_name="swarms_corp",
|
||||
retry_attempts=1,
|
||||
context_length=200000,
|
||||
output_type="string",
|
||||
streaming_on=False,
|
||||
)
|
||||
|
||||
# Create group chat
|
||||
group_chat = GroupChat(
|
||||
name="Financial Discussion",
|
||||
description="A group chat for financial analysis and tax advice.",
|
||||
agents=[agent1, agent2],
|
||||
)
|
||||
|
||||
# Run the group chat
|
||||
asyncio.run(
|
||||
group_chat.run(
|
||||
"How can I establish a ROTH IRA to buy stocks and get a tax break? What are the criteria? What do you guys think?"
|
||||
)
|
||||
)
|
@ -0,0 +1,417 @@
|
||||
import os
|
||||
from typing import List, Dict, Any, Optional, Callable
|
||||
from dataclasses import dataclass, field
|
||||
import json
|
||||
from datetime import datetime
|
||||
import inspect
|
||||
import typing
|
||||
from typing import Union
|
||||
from swarms import Agent
|
||||
from swarm_models import OpenAIChat
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolDefinition:
|
||||
name: str
|
||||
description: str
|
||||
parameters: Dict[str, Any]
|
||||
required_params: List[str]
|
||||
callable: Optional[Callable] = None
|
||||
|
||||
|
||||
def extract_type_hints(func: Callable) -> Dict[str, Any]:
|
||||
"""Extract parameter types from function type hints."""
|
||||
return typing.get_type_hints(func)
|
||||
|
||||
|
||||
def extract_tool_info(func: Callable) -> ToolDefinition:
|
||||
"""Extract tool information from a callable function."""
|
||||
# Get function name
|
||||
name = func.__name__
|
||||
|
||||
# Get docstring
|
||||
description = inspect.getdoc(func) or "No description available"
|
||||
|
||||
# Get parameters and their types
|
||||
signature = inspect.signature(func)
|
||||
type_hints = extract_type_hints(func)
|
||||
|
||||
parameters = {}
|
||||
required_params = []
|
||||
|
||||
for param_name, param in signature.parameters.items():
|
||||
# Skip self parameter for methods
|
||||
if param_name == "self":
|
||||
continue
|
||||
|
||||
param_type = type_hints.get(param_name, Any)
|
||||
|
||||
# Handle optional parameters
|
||||
is_optional = (
|
||||
param.default != inspect.Parameter.empty
|
||||
or getattr(param_type, "__origin__", None) is Union
|
||||
and type(None) in param_type.__args__
|
||||
)
|
||||
|
||||
if not is_optional:
|
||||
required_params.append(param_name)
|
||||
|
||||
parameters[param_name] = {
|
||||
"type": str(param_type),
|
||||
"default": (
|
||||
None
|
||||
if param.default is inspect.Parameter.empty
|
||||
else param.default
|
||||
),
|
||||
"required": not is_optional,
|
||||
}
|
||||
|
||||
return ToolDefinition(
|
||||
name=name,
|
||||
description=description,
|
||||
parameters=parameters,
|
||||
required_params=required_params,
|
||||
callable=func,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionSpec:
|
||||
"""Specification for a callable tool function."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
parameters: Dict[
|
||||
str, dict
|
||||
] # Contains type and description for each parameter
|
||||
return_type: str
|
||||
return_description: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExecutionStep:
|
||||
"""Represents a single step in the execution plan."""
|
||||
|
||||
step_id: int
|
||||
function_name: str
|
||||
parameters: Dict[str, Any]
|
||||
expected_output: str
|
||||
completed: bool = False
|
||||
result: Any = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExecutionContext:
|
||||
"""Maintains state during execution."""
|
||||
|
||||
task: str
|
||||
steps: List[ExecutionStep] = field(default_factory=list)
|
||||
results: Dict[int, Any] = field(default_factory=dict)
|
||||
current_step: int = 0
|
||||
history: List[Dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
|
||||
class ToolAgent:
|
||||
def __init__(
|
||||
self,
|
||||
functions: List[Callable],
|
||||
openai_api_key: str,
|
||||
model_name: str = "gpt-4",
|
||||
temperature: float = 0.1,
|
||||
):
|
||||
self.functions = {func.__name__: func for func in functions}
|
||||
self.function_specs = self._analyze_functions(functions)
|
||||
|
||||
self.model = OpenAIChat(
|
||||
openai_api_key=openai_api_key,
|
||||
model_name=model_name,
|
||||
temperature=temperature,
|
||||
)
|
||||
|
||||
self.system_prompt = self._create_system_prompt()
|
||||
self.agent = Agent(
|
||||
agent_name="Tool-Agent",
|
||||
system_prompt=self.system_prompt,
|
||||
llm=self.model,
|
||||
max_loops=1,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
def _analyze_functions(
|
||||
self, functions: List[Callable]
|
||||
) -> Dict[str, FunctionSpec]:
|
||||
"""Analyze functions to create detailed specifications."""
|
||||
specs = {}
|
||||
for func in functions:
|
||||
hints = get_type_hints(func)
|
||||
sig = inspect.signature(func)
|
||||
doc = inspect.getdoc(func) or ""
|
||||
|
||||
# Parse docstring for parameter descriptions
|
||||
param_descriptions = {}
|
||||
current_param = None
|
||||
for line in doc.split("\n"):
|
||||
if ":param" in line:
|
||||
param_name = (
|
||||
line.split(":param")[1].split(":")[0].strip()
|
||||
)
|
||||
desc = line.split(":", 2)[-1].strip()
|
||||
param_descriptions[param_name] = desc
|
||||
elif ":return:" in line:
|
||||
return_desc = line.split(":return:")[1].strip()
|
||||
|
||||
# Build parameter specifications
|
||||
parameters = {}
|
||||
for name, param in sig.parameters.items():
|
||||
param_type = hints.get(name, Any)
|
||||
parameters[name] = {
|
||||
"type": str(param_type),
|
||||
"type_class": param_type,
|
||||
"description": param_descriptions.get(name, ""),
|
||||
"required": param.default == param.empty,
|
||||
}
|
||||
|
||||
specs[func.__name__] = FunctionSpec(
|
||||
name=func.__name__,
|
||||
description=doc.split("\n")[0],
|
||||
parameters=parameters,
|
||||
return_type=str(hints.get("return", Any)),
|
||||
return_description=(
|
||||
return_desc if "return_desc" in locals() else ""
|
||||
),
|
||||
)
|
||||
|
||||
return specs
|
||||
|
||||
def _create_system_prompt(self) -> str:
|
||||
"""Create system prompt with detailed function specifications."""
|
||||
functions_desc = []
|
||||
for spec in self.function_specs.values():
|
||||
params_desc = []
|
||||
for name, details in spec.parameters.items():
|
||||
params_desc.append(
|
||||
f" - {name}: {details['type']} - {details['description']}"
|
||||
)
|
||||
|
||||
functions_desc.append(
|
||||
f"""
|
||||
Function: {spec.name}
|
||||
Description: {spec.description}
|
||||
Parameters:
|
||||
{chr(10).join(params_desc)}
|
||||
Returns: {spec.return_type} - {spec.return_description}
|
||||
"""
|
||||
)
|
||||
|
||||
return f"""You are an AI agent that creates and executes plans using available functions.
|
||||
|
||||
Available Functions:
|
||||
{chr(10).join(functions_desc)}
|
||||
|
||||
You must respond in two formats depending on the phase:
|
||||
|
||||
1. Planning Phase:
|
||||
{{
|
||||
"phase": "planning",
|
||||
"plan": {{
|
||||
"description": "Overall plan description",
|
||||
"steps": [
|
||||
{{
|
||||
"step_id": 1,
|
||||
"function": "function_name",
|
||||
"parameters": {{
|
||||
"param1": "value1",
|
||||
"param2": "value2"
|
||||
}},
|
||||
"purpose": "Why this step is needed"
|
||||
}}
|
||||
]
|
||||
}}
|
||||
}}
|
||||
|
||||
2. Execution Phase:
|
||||
{{
|
||||
"phase": "execution",
|
||||
"analysis": "Analysis of current result",
|
||||
"next_action": {{
|
||||
"type": "continue|request_input|complete",
|
||||
"reason": "Why this action was chosen",
|
||||
"needed_input": {{}} # If requesting input
|
||||
}}
|
||||
}}
|
||||
|
||||
Always:
|
||||
- Use exact function names
|
||||
- Ensure parameter types match specifications
|
||||
- Provide clear reasoning for each decision
|
||||
"""
|
||||
|
||||
def _execute_function(
|
||||
self, spec: FunctionSpec, parameters: Dict[str, Any]
|
||||
) -> Any:
|
||||
"""Execute a function with type checking."""
|
||||
converted_params = {}
|
||||
for name, value in parameters.items():
|
||||
param_spec = spec.parameters[name]
|
||||
try:
|
||||
# Convert value to required type
|
||||
param_type = param_spec["type_class"]
|
||||
if param_type in (int, float, str, bool):
|
||||
converted_params[name] = param_type(value)
|
||||
else:
|
||||
converted_params[name] = value
|
||||
except (ValueError, TypeError) as e:
|
||||
raise ValueError(
|
||||
f"Parameter '{name}' conversion failed: {str(e)}"
|
||||
)
|
||||
|
||||
return self.functions[spec.name](**converted_params)
|
||||
|
||||
def run(self, task: str) -> Dict[str, Any]:
|
||||
"""Execute task with planning and step-by-step execution."""
|
||||
context = ExecutionContext(task=task)
|
||||
execution_log = {
|
||||
"task": task,
|
||||
"start_time": datetime.utcnow().isoformat(),
|
||||
"steps": [],
|
||||
"final_result": None,
|
||||
}
|
||||
|
||||
try:
|
||||
# Planning phase
|
||||
plan_prompt = f"Create a plan to: {task}"
|
||||
plan_response = self.agent.run(plan_prompt)
|
||||
plan_data = json.loads(
|
||||
plan_response.replace("System:", "").strip()
|
||||
)
|
||||
|
||||
# Convert plan to execution steps
|
||||
for step in plan_data["plan"]["steps"]:
|
||||
context.steps.append(
|
||||
ExecutionStep(
|
||||
step_id=step["step_id"],
|
||||
function_name=step["function"],
|
||||
parameters=step["parameters"],
|
||||
expected_output=step["purpose"],
|
||||
)
|
||||
)
|
||||
|
||||
# Execution phase
|
||||
while context.current_step < len(context.steps):
|
||||
step = context.steps[context.current_step]
|
||||
print(
|
||||
f"\nExecuting step {step.step_id}: {step.function_name}"
|
||||
)
|
||||
|
||||
try:
|
||||
# Execute function
|
||||
spec = self.function_specs[step.function_name]
|
||||
result = self._execute_function(
|
||||
spec, step.parameters
|
||||
)
|
||||
context.results[step.step_id] = result
|
||||
step.completed = True
|
||||
step.result = result
|
||||
|
||||
# Get agent's analysis
|
||||
analysis_prompt = f"""
|
||||
Step {step.step_id} completed:
|
||||
Function: {step.function_name}
|
||||
Result: {json.dumps(result)}
|
||||
Remaining steps: {len(context.steps) - context.current_step - 1}
|
||||
|
||||
Analyze the result and decide next action.
|
||||
"""
|
||||
|
||||
analysis_response = self.agent.run(
|
||||
analysis_prompt
|
||||
)
|
||||
analysis_data = json.loads(
|
||||
analysis_response.replace(
|
||||
"System:", ""
|
||||
).strip()
|
||||
)
|
||||
|
||||
execution_log["steps"].append(
|
||||
{
|
||||
"step_id": step.step_id,
|
||||
"function": step.function_name,
|
||||
"parameters": step.parameters,
|
||||
"result": result,
|
||||
"analysis": analysis_data,
|
||||
}
|
||||
)
|
||||
|
||||
if (
|
||||
analysis_data["next_action"]["type"]
|
||||
== "complete"
|
||||
):
|
||||
if (
|
||||
context.current_step
|
||||
< len(context.steps) - 1
|
||||
):
|
||||
continue
|
||||
break
|
||||
|
||||
context.current_step += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in step {step.step_id}: {str(e)}")
|
||||
execution_log["steps"].append(
|
||||
{
|
||||
"step_id": step.step_id,
|
||||
"function": step.function_name,
|
||||
"parameters": step.parameters,
|
||||
"error": str(e),
|
||||
}
|
||||
)
|
||||
raise
|
||||
|
||||
# Final analysis
|
||||
final_prompt = f"""
|
||||
Task completed. Results:
|
||||
{json.dumps(context.results, indent=2)}
|
||||
|
||||
Provide final analysis and recommendations.
|
||||
"""
|
||||
|
||||
final_analysis = self.agent.run(final_prompt)
|
||||
execution_log["final_result"] = {
|
||||
"success": True,
|
||||
"results": context.results,
|
||||
"analysis": json.loads(
|
||||
final_analysis.replace("System:", "").strip()
|
||||
),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
execution_log["final_result"] = {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
execution_log["end_time"] = datetime.utcnow().isoformat()
|
||||
return execution_log
|
||||
|
||||
|
||||
def calculate_investment_return(
|
||||
principal: float, rate: float, years: int
|
||||
) -> float:
|
||||
"""Calculate investment return with compound interest.
|
||||
|
||||
:param principal: Initial investment amount in dollars
|
||||
:param rate: Annual interest rate as decimal (e.g., 0.07 for 7%)
|
||||
:param years: Number of years to invest
|
||||
:return: Final investment value
|
||||
"""
|
||||
return principal * (1 + rate) ** years
|
||||
|
||||
|
||||
agent = ToolAgent(
|
||||
functions=[calculate_investment_return],
|
||||
openai_api_key=os.getenv("OPENAI_API_KEY"),
|
||||
)
|
||||
|
||||
result = agent.run(
|
||||
"Calculate returns for $10000 invested at 7% for 10 years"
|
||||
)
|
@ -1,109 +1,87 @@
|
||||
import os
|
||||
from typing import List, Any
|
||||
from swarms.structs.agent import Agent
|
||||
from loguru import logger
|
||||
import uuid
|
||||
|
||||
WORKSPACE_DIR = os.getenv("WORKSPACE_DIR")
|
||||
uuid_for_log = str(uuid.uuid4())
|
||||
logger.add(
|
||||
os.path.join(
|
||||
WORKSPACE_DIR,
|
||||
"agents_available",
|
||||
f"agents-available-{uuid_for_log}.log",
|
||||
),
|
||||
level="INFO",
|
||||
colorize=True,
|
||||
backtrace=True,
|
||||
diagnose=True,
|
||||
)
|
||||
|
||||
|
||||
def get_agent_name(agent: Any) -> str:
|
||||
"""Helper function to safely get agent name
|
||||
|
||||
Args:
|
||||
agent (Any): The agent object to get name from
|
||||
|
||||
Returns:
|
||||
str: The agent's name if found, 'Unknown' otherwise
|
||||
"""
|
||||
if isinstance(agent, Agent) and hasattr(agent, "agent_name"):
|
||||
return agent.agent_name
|
||||
return "Unknown"
|
||||
|
||||
|
||||
def get_agent_description(agent: Any) -> str:
|
||||
"""Helper function to get agent description or system prompt preview
|
||||
|
||||
Args:
|
||||
agent (Any): The agent object
|
||||
|
||||
Returns:
|
||||
str: Description or first 100 chars of system prompt
|
||||
"""
|
||||
if not isinstance(agent, Agent):
|
||||
return "N/A"
|
||||
|
||||
if hasattr(agent, "description") and agent.description:
|
||||
return agent.description
|
||||
|
||||
if hasattr(agent, "system_prompt") and agent.system_prompt:
|
||||
return f"{agent.system_prompt[:150]}..."
|
||||
|
||||
return "N/A"
|
||||
from typing import List
|
||||
|
||||
|
||||
def showcase_available_agents(
|
||||
agents: List[Agent],
|
||||
name: str = None,
|
||||
description: str = None,
|
||||
agents: List[Agent] = [],
|
||||
update_agents_on: bool = False,
|
||||
format: str = "XML",
|
||||
) -> str:
|
||||
"""
|
||||
Generate a formatted string showcasing all available agents and their descriptions.
|
||||
Format the available agents in either XML or Table format.
|
||||
|
||||
Args:
|
||||
agents (List[Agent]): List of Agent objects to showcase.
|
||||
update_agents_on (bool, optional): If True, updates each agent's system prompt with
|
||||
the showcase information. Defaults to False.
|
||||
agents (List[Agent]): A list of agents to represent
|
||||
name (str, optional): Name of the swarm
|
||||
description (str, optional): Description of the swarm
|
||||
format (str, optional): Output format ("XML" or "Table"). Defaults to "XML"
|
||||
|
||||
Returns:
|
||||
str: Formatted string containing agent information, including names, descriptions
|
||||
and IDs for all available agents.
|
||||
str: Formatted string containing agent information
|
||||
"""
|
||||
logger.info(f"Showcasing {len(agents)} available agents")
|
||||
|
||||
formatted_agents = []
|
||||
header = f"\n####### Agents available in the swarm: {name} ############\n"
|
||||
header += f"{description}\n"
|
||||
row_format = "{:<5} | {:<20} | {:<50}"
|
||||
header_row = row_format.format("ID", "Agent Name", "Description")
|
||||
separator = "-" * 80
|
||||
|
||||
formatted_agents.append(header)
|
||||
formatted_agents.append(separator)
|
||||
formatted_agents.append(header_row)
|
||||
formatted_agents.append(separator)
|
||||
|
||||
for idx, agent in enumerate(agents):
|
||||
if not isinstance(agent, Agent):
|
||||
logger.warning(
|
||||
f"Skipping non-Agent object: {type(agent)}"
|
||||
)
|
||||
continue
|
||||
|
||||
agent_name = get_agent_name(agent)
|
||||
description = (
|
||||
get_agent_description(agent)[:100] + "..."
|
||||
if len(get_agent_description(agent)) > 100
|
||||
else get_agent_description(agent)
|
||||
def truncate(text: str, max_length: int = 130) -> str:
|
||||
return (
|
||||
f"{text[:max_length]}..."
|
||||
if len(text) > max_length
|
||||
else text
|
||||
)
|
||||
|
||||
formatted_agents.append(
|
||||
row_format.format(idx + 1, agent_name, description)
|
||||
output = []
|
||||
|
||||
if format.upper() == "TABLE":
|
||||
output.append("\n| ID | Agent Name | Description |")
|
||||
output.append("|-----|------------|-------------|")
|
||||
for idx, agent in enumerate(agents):
|
||||
if isinstance(agent, Agent):
|
||||
agent_name = getattr(agent, "agent_name", str(agent))
|
||||
description = getattr(
|
||||
agent,
|
||||
"description",
|
||||
getattr(
|
||||
agent, "system_prompt", "Unknown description"
|
||||
),
|
||||
)
|
||||
desc = truncate(description, 50)
|
||||
output.append(
|
||||
f"| {idx + 1} | {agent_name} | {desc} |"
|
||||
)
|
||||
else:
|
||||
output.append(
|
||||
f"| {idx + 1} | {agent} | Unknown description |"
|
||||
)
|
||||
return "\n".join(output)
|
||||
|
||||
# Default XML format
|
||||
output.append("<agents>")
|
||||
if name:
|
||||
output.append(f" <name>{name}</name>")
|
||||
if description:
|
||||
output.append(
|
||||
f" <description>{truncate(description)}</description>"
|
||||
)
|
||||
for idx, agent in enumerate(agents):
|
||||
output.append(f" <agent id='{idx + 1}'>")
|
||||
if isinstance(agent, Agent):
|
||||
agent_name = getattr(agent, "agent_name", str(agent))
|
||||
description = getattr(
|
||||
agent,
|
||||
"description",
|
||||
getattr(
|
||||
agent, "system_prompt", "Unknown description"
|
||||
),
|
||||
)
|
||||
output.append(f" <name>{agent_name}</name>")
|
||||
output.append(
|
||||
f" <description>{truncate(description)}</description>"
|
||||
)
|
||||
else:
|
||||
output.append(f" <name>{agent}</name>")
|
||||
output.append(
|
||||
" <description>Unknown description</description>"
|
||||
)
|
||||
output.append(" </agent>")
|
||||
output.append("</agents>")
|
||||
|
||||
showcase = "\n".join(formatted_agents)
|
||||
|
||||
return showcase
|
||||
return "\n".join(output)
|
||||
|
@ -0,0 +1,203 @@
|
||||
from typing import Any
|
||||
import inspect
|
||||
from functools import partial
|
||||
import logging
|
||||
|
||||
|
||||
class NameResolver:
|
||||
"""Utility class for resolving names of various objects"""
|
||||
|
||||
@staticmethod
|
||||
def get_name(obj: Any, default: str = "unnamed_callable") -> str:
|
||||
"""
|
||||
Get the name of any object with multiple fallback strategies.
|
||||
|
||||
Args:
|
||||
obj: The object to get the name from
|
||||
default: Default name if all strategies fail
|
||||
|
||||
Returns:
|
||||
str: The resolved name
|
||||
"""
|
||||
strategies = [
|
||||
# Try getting __name__ attribute
|
||||
lambda x: getattr(x, "__name__", None),
|
||||
# Try getting class name
|
||||
lambda x: (
|
||||
x.__class__.__name__
|
||||
if hasattr(x, "__class__")
|
||||
else None
|
||||
),
|
||||
# Try getting function name if it's a partial
|
||||
lambda x: (
|
||||
x.func.__name__ if isinstance(x, partial) else None
|
||||
),
|
||||
# Try getting the name from the class's type
|
||||
lambda x: type(x).__name__,
|
||||
# Try getting qualname
|
||||
lambda x: getattr(x, "__qualname__", None),
|
||||
# Try getting the module and class name
|
||||
lambda x: (
|
||||
f"{x.__module__}.{x.__class__.__name__}"
|
||||
if hasattr(x, "__module__")
|
||||
else None
|
||||
),
|
||||
# For async functions
|
||||
lambda x: (
|
||||
x.__name__ if inspect.iscoroutinefunction(x) else None
|
||||
),
|
||||
# For classes with custom __str__
|
||||
lambda x: (
|
||||
str(x)
|
||||
if hasattr(x, "__str__")
|
||||
and x.__str__ != object.__str__
|
||||
else None
|
||||
),
|
||||
# For wrapped functions
|
||||
lambda x: (
|
||||
getattr(x, "__wrapped__", None).__name__
|
||||
if hasattr(x, "__wrapped__")
|
||||
else None
|
||||
),
|
||||
]
|
||||
|
||||
# Try each strategy
|
||||
for strategy in strategies:
|
||||
try:
|
||||
name = strategy(obj)
|
||||
if name and isinstance(name, str):
|
||||
return name.replace(" ", "_").replace("-", "_")
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# Return default if all strategies fail
|
||||
return default
|
||||
|
||||
@staticmethod
|
||||
def get_callable_details(obj: Any) -> dict:
|
||||
"""
|
||||
Get detailed information about a callable object.
|
||||
|
||||
Returns:
|
||||
dict: Dictionary containing:
|
||||
- name: The resolved name
|
||||
- type: The type of callable
|
||||
- signature: The signature if available
|
||||
- module: The module name if available
|
||||
- doc: The docstring if available
|
||||
"""
|
||||
details = {
|
||||
"name": NameResolver.get_name(obj),
|
||||
"type": "unknown",
|
||||
"signature": None,
|
||||
"module": getattr(obj, "__module__", "unknown"),
|
||||
"doc": inspect.getdoc(obj)
|
||||
or "No documentation available",
|
||||
}
|
||||
|
||||
# Determine the type
|
||||
if inspect.isclass(obj):
|
||||
details["type"] = "class"
|
||||
elif inspect.iscoroutinefunction(obj):
|
||||
details["type"] = "async_function"
|
||||
elif inspect.isfunction(obj):
|
||||
details["type"] = "function"
|
||||
elif isinstance(obj, partial):
|
||||
details["type"] = "partial"
|
||||
elif callable(obj):
|
||||
details["type"] = "callable"
|
||||
|
||||
# Try to get signature
|
||||
try:
|
||||
details["signature"] = str(inspect.signature(obj))
|
||||
except (ValueError, TypeError):
|
||||
details["signature"] = "Unknown signature"
|
||||
|
||||
return details
|
||||
|
||||
@classmethod
|
||||
def get_safe_name(cls, obj: Any, max_retries: int = 3) -> str:
|
||||
"""
|
||||
Safely get a name with retries and validation.
|
||||
|
||||
Args:
|
||||
obj: Object to get name from
|
||||
max_retries: Maximum number of retry attempts
|
||||
|
||||
Returns:
|
||||
str: A valid name string
|
||||
"""
|
||||
retries = 0
|
||||
last_error = None
|
||||
|
||||
while retries < max_retries:
|
||||
try:
|
||||
name = cls.get_name(obj)
|
||||
|
||||
# Validate and clean the name
|
||||
if name:
|
||||
# Remove invalid characters
|
||||
clean_name = "".join(
|
||||
c
|
||||
for c in name
|
||||
if c.isalnum() or c in ["_", "."]
|
||||
)
|
||||
|
||||
# Ensure it starts with a letter or underscore
|
||||
if (
|
||||
not clean_name[0].isalpha()
|
||||
and clean_name[0] != "_"
|
||||
):
|
||||
clean_name = f"_{clean_name}"
|
||||
|
||||
return clean_name
|
||||
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
retries += 1
|
||||
|
||||
# If all retries failed, generate a unique fallback name
|
||||
import uuid
|
||||
|
||||
fallback = f"callable_{uuid.uuid4().hex[:8]}"
|
||||
logging.warning(
|
||||
f"Failed to get name after {max_retries} retries. Using fallback: {fallback}. "
|
||||
f"Last error: {str(last_error)}"
|
||||
)
|
||||
return fallback
|
||||
|
||||
|
||||
# # Example usage
|
||||
# if __name__ == "__main__":
|
||||
# def test_resolver():
|
||||
# # Test cases
|
||||
# class TestClass:
|
||||
# def method(self):
|
||||
# pass
|
||||
|
||||
# async def async_func():
|
||||
# pass
|
||||
|
||||
# test_cases = [
|
||||
# TestClass, # Class
|
||||
# TestClass(), # Instance
|
||||
# async_func, # Async function
|
||||
# lambda x: x, # Lambda
|
||||
# partial(print, end=""), # Partial
|
||||
# TestClass.method, # Method
|
||||
# print, # Built-in function
|
||||
# str, # Built-in class
|
||||
# ]
|
||||
|
||||
# resolver = NameResolver()
|
||||
|
||||
# print("\nName Resolution Results:")
|
||||
# print("-" * 50)
|
||||
# for obj in test_cases:
|
||||
# details = resolver.get_callable_details(obj)
|
||||
# safe_name = resolver.get_safe_name(obj)
|
||||
# print(f"\nObject: {obj}")
|
||||
# print(f"Safe Name: {safe_name}")
|
||||
# print(f"Details: {details}")
|
||||
|
||||
# test_resolver()
|
@ -0,0 +1,53 @@
|
||||
import concurrent.futures
|
||||
from typing import List, Union
|
||||
from swarms.structs.agent import Agent
|
||||
|
||||
|
||||
def update_system_prompts(
|
||||
agents: List[Union[Agent, str]],
|
||||
prompt: str,
|
||||
) -> List[Agent]:
|
||||
"""
|
||||
Update system prompts for a list of agents concurrently.
|
||||
|
||||
Args:
|
||||
agents: List of Agent objects or strings to update
|
||||
prompt: The prompt text to append to each agent's system prompt
|
||||
|
||||
Returns:
|
||||
List of updated Agent objects
|
||||
"""
|
||||
if not agents:
|
||||
return agents
|
||||
|
||||
def update_agent_prompt(agent: Union[Agent, str]) -> Agent:
|
||||
# Convert string to Agent if needed
|
||||
if isinstance(agent, str):
|
||||
agent = Agent(
|
||||
agent_name=agent,
|
||||
system_prompt=prompt, # Initialize with the provided prompt
|
||||
)
|
||||
else:
|
||||
# Preserve existing prompt and append new one
|
||||
existing_prompt = (
|
||||
agent.system_prompt if agent.system_prompt else ""
|
||||
)
|
||||
agent.system_prompt = existing_prompt + "\n" + prompt
|
||||
return agent
|
||||
|
||||
# Use ThreadPoolExecutor for concurrent execution
|
||||
max_workers = min(len(agents), 4) # Reasonable thread count
|
||||
with concurrent.futures.ThreadPoolExecutor(
|
||||
max_workers=max_workers
|
||||
) as executor:
|
||||
futures = []
|
||||
for agent in agents:
|
||||
future = executor.submit(update_agent_prompt, agent)
|
||||
futures.append(future)
|
||||
|
||||
# Collect results as they complete
|
||||
updated_agents = []
|
||||
for future in concurrent.futures.as_completed(futures):
|
||||
updated_agents.append(future.result())
|
||||
|
||||
return updated_agents
|
@ -0,0 +1,292 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.distributed as dist
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
from loguru import logger
|
||||
import math
|
||||
|
||||
|
||||
@dataclass
|
||||
class StarAttentionConfig:
|
||||
"""Configuration for StarAttention module.
|
||||
|
||||
Attributes:
|
||||
hidden_size: Dimension of the model's hidden states
|
||||
num_attention_heads: Number of attention heads
|
||||
num_hosts: Number of hosts in the distributed system
|
||||
block_size: Size of each context block
|
||||
anchor_size: Size of the anchor block
|
||||
dropout_prob: Dropout probability (default: 0.1)
|
||||
layer_norm_eps: Layer normalization epsilon (default: 1e-12)
|
||||
"""
|
||||
|
||||
hidden_size: int
|
||||
num_attention_heads: int
|
||||
num_hosts: int
|
||||
block_size: int
|
||||
anchor_size: int
|
||||
dropout_prob: float = 0.1
|
||||
layer_norm_eps: float = 1e-12
|
||||
|
||||
|
||||
class StarAttention(nn.Module):
|
||||
"""
|
||||
Implementation of Star Attention mechanism for distributed inference.
|
||||
|
||||
The module implements a two-phase attention mechanism:
|
||||
1. Local Context Encoding with Anchor Blocks
|
||||
2. Query Encoding and Output Generation with Global Attention
|
||||
"""
|
||||
|
||||
def __init__(self, config: StarAttentionConfig):
|
||||
super().__init__()
|
||||
|
||||
if config.hidden_size % config.num_attention_heads != 0:
|
||||
raise ValueError(
|
||||
f"Hidden size {config.hidden_size} not divisible by number of attention "
|
||||
f"heads {config.num_attention_heads}"
|
||||
)
|
||||
|
||||
self.config = config
|
||||
self.head_dim = (
|
||||
config.hidden_size // config.num_attention_heads
|
||||
)
|
||||
|
||||
# Initialize components
|
||||
self.query = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
self.key = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
self.value = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
|
||||
self.dropout = nn.Dropout(config.dropout_prob)
|
||||
self.layer_norm = nn.LayerNorm(
|
||||
config.hidden_size, eps=config.layer_norm_eps
|
||||
)
|
||||
|
||||
# KV cache for storing computed key/value pairs
|
||||
self.kv_cache = {}
|
||||
|
||||
logger.info(
|
||||
f"Initialized StarAttention with config: {config}"
|
||||
)
|
||||
|
||||
def _split_heads(
|
||||
self, tensor: torch.Tensor, num_heads: int
|
||||
) -> torch.Tensor:
|
||||
"""Split the last dimension into (num_heads, head_dim)."""
|
||||
batch_size, seq_len, _ = tensor.size()
|
||||
tensor = tensor.view(
|
||||
batch_size, seq_len, num_heads, self.head_dim
|
||||
)
|
||||
# Transpose to (batch_size, num_heads, seq_len, head_dim)
|
||||
return tensor.transpose(1, 2)
|
||||
|
||||
def _merge_heads(self, tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""Merge the head dimension back into hidden_size."""
|
||||
batch_size, _, seq_len, _ = tensor.size()
|
||||
tensor = tensor.transpose(1, 2)
|
||||
return tensor.reshape(
|
||||
batch_size, seq_len, self.config.hidden_size
|
||||
)
|
||||
|
||||
def _compute_attention_scores(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Compute attention scores and weighted values."""
|
||||
# Scale dot-product attention
|
||||
scores = torch.matmul(
|
||||
query, key.transpose(-2, -1)
|
||||
) / math.sqrt(self.head_dim)
|
||||
|
||||
if mask is not None:
|
||||
scores = scores.masked_fill(mask == 0, float("-inf"))
|
||||
|
||||
# Online softmax computation
|
||||
attention_probs = torch.nn.functional.softmax(scores, dim=-1)
|
||||
attention_probs = self.dropout(attention_probs)
|
||||
|
||||
context = torch.matmul(attention_probs, value)
|
||||
|
||||
return context, attention_probs
|
||||
|
||||
def phase1_local_context_encoding(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
host_id: int,
|
||||
device: Union[str, torch.device] = "cuda",
|
||||
) -> None:
|
||||
"""
|
||||
Phase 1: Local Context Encoding with Anchor Blocks
|
||||
|
||||
Args:
|
||||
input_ids: Input tensor of shape (batch_size, seq_len)
|
||||
host_id: ID of the current host
|
||||
device: Device to run computations on
|
||||
"""
|
||||
logger.debug(f"Starting Phase 1 on host {host_id}")
|
||||
|
||||
# Calculate block assignments
|
||||
block_start = host_id * self.config.block_size
|
||||
block_end = block_start + self.config.block_size
|
||||
|
||||
# Get local block
|
||||
local_block = input_ids[:, block_start:block_end].to(device)
|
||||
|
||||
# Get anchor block (first block)
|
||||
anchor_block = input_ids[:, : self.config.anchor_size].to(
|
||||
device
|
||||
)
|
||||
|
||||
# Compute KV pairs for local block
|
||||
local_hidden = self.layer_norm(local_block)
|
||||
local_key = self._split_heads(
|
||||
self.key(local_hidden), self.config.num_attention_heads
|
||||
)
|
||||
local_value = self._split_heads(
|
||||
self.value(local_hidden), self.config.num_attention_heads
|
||||
)
|
||||
|
||||
# Store in KV cache
|
||||
self.kv_cache[host_id] = {
|
||||
"key": local_key,
|
||||
"value": local_value,
|
||||
"anchor_key": (
|
||||
None
|
||||
if host_id == 0
|
||||
else self._split_heads(
|
||||
self.key(self.layer_norm(anchor_block)),
|
||||
self.config.num_attention_heads,
|
||||
)
|
||||
),
|
||||
}
|
||||
|
||||
logger.debug(
|
||||
f"Phase 1 complete on host {host_id}. KV cache shapes - "
|
||||
f"key: {local_key.shape}, value: {local_value.shape}"
|
||||
)
|
||||
|
||||
def phase2_query_encoding(
|
||||
self,
|
||||
query_input: torch.Tensor,
|
||||
host_id: int,
|
||||
is_query_host: bool,
|
||||
device: Union[str, torch.device] = "cuda",
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
Phase 2: Query Encoding and Output Generation
|
||||
|
||||
Args:
|
||||
query_input: Query tensor of shape (batch_size, seq_len, hidden_size)
|
||||
host_id: ID of the current host
|
||||
is_query_host: Whether this host is the query host
|
||||
device: Device to run computations on
|
||||
|
||||
Returns:
|
||||
Output tensor if this is the query host, None otherwise
|
||||
"""
|
||||
logger.debug(f"Starting Phase 2 on host {host_id}")
|
||||
|
||||
# Transform query
|
||||
query_hidden = self.layer_norm(query_input)
|
||||
query = self._split_heads(
|
||||
self.query(query_hidden), self.config.num_attention_heads
|
||||
)
|
||||
|
||||
# Compute local attention scores
|
||||
local_context, local_probs = self._compute_attention_scores(
|
||||
query,
|
||||
self.kv_cache[host_id]["key"],
|
||||
self.kv_cache[host_id]["value"],
|
||||
)
|
||||
|
||||
if not is_query_host:
|
||||
# Non-query hosts send their local attention statistics
|
||||
dist.send(local_probs, dst=self.config.num_hosts - 1)
|
||||
return None
|
||||
|
||||
# Query host aggregates attention from all hosts
|
||||
all_attention_probs = [local_probs]
|
||||
for src_rank in range(self.config.num_hosts - 1):
|
||||
probs = torch.empty_like(local_probs)
|
||||
dist.recv(probs, src=src_rank)
|
||||
all_attention_probs.append(probs)
|
||||
|
||||
# Compute global attention
|
||||
torch.mean(torch.stack(all_attention_probs), dim=0)
|
||||
|
||||
# Final output computation
|
||||
output = self._merge_heads(local_context)
|
||||
output = self.dropout(output)
|
||||
|
||||
logger.debug(
|
||||
f"Phase 2 complete on host {host_id}. Output shape: {output.shape}"
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
query_input: torch.Tensor,
|
||||
host_id: int,
|
||||
is_query_host: bool,
|
||||
device: Union[str, torch.device] = "cuda",
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
Forward pass of the StarAttention module.
|
||||
|
||||
Args:
|
||||
input_ids: Input tensor of shape (batch_size, seq_len)
|
||||
query_input: Query tensor of shape (batch_size, seq_len, hidden_size)
|
||||
host_id: ID of the current host
|
||||
is_query_host: Whether this host is the query host
|
||||
device: Device to run computations on
|
||||
|
||||
Returns:
|
||||
Output tensor if this is the query host, None otherwise
|
||||
"""
|
||||
# Phase 1: Local Context Encoding
|
||||
self.phase1_local_context_encoding(input_ids, host_id, device)
|
||||
|
||||
# Phase 2: Query Encoding and Output Generation
|
||||
return self.phase2_query_encoding(
|
||||
query_input, host_id, is_query_host, device
|
||||
)
|
||||
|
||||
|
||||
# Example forward pass
|
||||
config = StarAttentionConfig(
|
||||
hidden_size=768,
|
||||
num_attention_heads=12,
|
||||
num_hosts=3,
|
||||
block_size=512,
|
||||
anchor_size=128,
|
||||
)
|
||||
|
||||
# Initialize model
|
||||
model = StarAttention(config)
|
||||
|
||||
# Example input tensors
|
||||
batch_size = 4
|
||||
seq_len = 512
|
||||
input_ids = torch.randint(
|
||||
0, 1000, (batch_size, seq_len)
|
||||
) # Random input IDs
|
||||
query_input = torch.randn(
|
||||
batch_size, seq_len, config.hidden_size
|
||||
) # Random query input
|
||||
|
||||
# Example forward pass for query host (host_id = 2)
|
||||
output = model(
|
||||
input_ids=input_ids,
|
||||
query_input=query_input,
|
||||
host_id=2,
|
||||
is_query_host=True,
|
||||
device="cpu",
|
||||
)
|
||||
|
||||
print(output)
|
Loading…
Reference in new issue