From 73d23f1995fae920b32eb6de4c42d30c9eb55068 Mon Sep 17 00:00:00 2001 From: Your Name Date: Mon, 25 Nov 2024 23:38:33 -0800 Subject: [PATCH 1/6] [DELETE OLD PACKAGES] --- async_agents.py | 56 ++++ async_executor.py | 131 ++++++++++ groupchat_new.py | 244 ++++++++++++++++++ .../example_async_vs_multithread.py | 8 +- .../sequential_worflow_test.py | 0 .../sequential_workflow.py | 0 pyproject.toml | 25 +- requirements.txt | 11 +- swarms/structs/agent.py | 92 +++---- swarms/structs/base_workflow.py | 115 ++++----- swarms/structs/conversation.py | 62 ++--- swarms/structs/mixture_of_agents.py | 4 +- swarms/structs/multi_agent_exec.py | 2 +- swarms/structs/rearrange.py | 23 +- swarms/structs/sequential_workflow.py | 2 +- swarms/tools/json_former.py | 10 - swarms/tools/tool_utils.py | 20 +- swarms/utils/callable_name.py | 203 +++++++++++++++ swarms/utils/markdown_message.py | 9 +- 19 files changed, 787 insertions(+), 230 deletions(-) create mode 100644 async_agents.py create mode 100644 async_executor.py create mode 100644 groupchat_new.py rename example_async_vs_multithread.py => new_features_examples/example_async_vs_multithread.py (98%) rename sequential_worflow_test.py => new_features_examples/sequential_worflow_test.py (100%) rename sequential_workflow.py => new_features_examples/sequential_workflow.py (100%) create mode 100644 swarms/utils/callable_name.py diff --git a/async_agents.py b/async_agents.py new file mode 100644 index 00000000..0d9353db --- /dev/null +++ b/async_agents.py @@ -0,0 +1,56 @@ +import os + +from dotenv import load_dotenv +from swarm_models import OpenAIChat + +from swarms import Agent +from swarms.prompts.finance_agent_sys_prompt import ( + FINANCIAL_AGENT_SYS_PROMPT, +) +from async_executor import HighSpeedExecutor + +load_dotenv() + +# Get the OpenAI API key from the environment variable +api_key = os.getenv("OPENAI_API_KEY") + +# Create an instance of the OpenAIChat class +model = OpenAIChat( + openai_api_key=api_key, model_name="gpt-4o-mini", temperature=0.1 +) + +# Initialize the agent +agent = Agent( + agent_name="Financial-Analysis-Agent", + system_prompt=FINANCIAL_AGENT_SYS_PROMPT, + llm=model, + max_loops=1, + # autosave=True, + # dashboard=False, + # verbose=True, + # dynamic_temperature_enabled=True, + # saved_state_path="finance_agent.json", + # user_name="swarms_corp", + # retry_attempts=1, + # context_length=200000, + # return_step_meta=True, + # output_type="json", # "json", "dict", "csv" OR "string" soon "yaml" and + # auto_generate_prompt=False, # Auto generate prompt for the agent based on name, description, and system prompt, task + # # artifacts_on=True, + # artifacts_output_path="roth_ira_report", + # artifacts_file_extension=".txt", + # max_tokens=8000, + # return_history=True, +) + + +def execute_agent( + task: str = "How can I establish a ROTH IRA to buy stocks and get a tax break? What are the criteria. Create a report on this question.", +): + return agent.run(task) + + +executor = HighSpeedExecutor() +results = executor.run(execute_agent, 2) + +print(results) diff --git a/async_executor.py b/async_executor.py new file mode 100644 index 00000000..e9fcfa4e --- /dev/null +++ b/async_executor.py @@ -0,0 +1,131 @@ +import asyncio +import multiprocessing as mp +import time +from functools import partial +from typing import Any, Dict, Union + + +class HighSpeedExecutor: + def __init__(self, num_processes: int = None): + """ + Initialize the executor with configurable number of processes. + If num_processes is None, it uses CPU count. + """ + self.num_processes = num_processes or mp.cpu_count() + + async def _worker( + self, + queue: asyncio.Queue, + func: Any, + *args: Any, + **kwargs: Any, + ): + """Async worker that processes tasks from the queue""" + while True: + try: + # Non-blocking get from queue + await queue.get() + await asyncio.get_event_loop().run_in_executor( + None, partial(func, *args, **kwargs) + ) + queue.task_done() + except asyncio.CancelledError: + break + + async def _distribute_tasks( + self, num_tasks: int, queue: asyncio.Queue + ): + """Distribute tasks across the queue""" + for i in range(num_tasks): + await queue.put(i) + + async def execute_batch( + self, + func: Any, + num_executions: int, + *args: Any, + **kwargs: Any, + ) -> Dict[str, Union[int, float]]: + """ + Execute the given function multiple times concurrently. + + Args: + func: The function to execute + num_executions: Number of times to execute the function + *args, **kwargs: Arguments to pass to the function + + Returns: + A dictionary containing the number of executions, duration, and executions per second. + """ + queue = asyncio.Queue() + + # Create worker tasks + workers = [ + asyncio.create_task( + self._worker(queue, func, *args, **kwargs) + ) + for _ in range(self.num_processes) + ] + + # Start timing + start_time = time.perf_counter() + + # Distribute tasks + await self._distribute_tasks(num_executions, queue) + + # Wait for all tasks to complete + await queue.join() + + # Cancel workers + for worker in workers: + worker.cancel() + + # Wait for all workers to finish + await asyncio.gather(*workers, return_exceptions=True) + + end_time = time.perf_counter() + duration = end_time - start_time + + return { + "executions": num_executions, + "duration": duration, + "executions_per_second": num_executions / duration, + } + + def run( + self, + func: Any, + num_executions: int, + *args: Any, + **kwargs: Any, + ): + return asyncio.run( + self.execute_batch(func, num_executions, *args, **kwargs) + ) + + +# def example_function(x: int = 0) -> int: +# """Example function to execute""" +# return x * x + + +# async def main(): +# # Create executor with number of CPU cores +# executor = HighSpeedExecutor() + +# # Execute the function 1000 times +# result = await executor.execute_batch( +# example_function, num_executions=1000, x=42 +# ) + +# print( +# f"Completed {result['executions']} executions in {result['duration']:.2f} seconds" +# ) +# print( +# f"Rate: {result['executions_per_second']:.2f} executions/second" +# ) + + +# if __name__ == "__main__": +# # Run the async main function +# asyncio.run(main()) diff --git a/groupchat_new.py b/groupchat_new.py new file mode 100644 index 00000000..69c424d4 --- /dev/null +++ b/groupchat_new.py @@ -0,0 +1,244 @@ +import os +import asyncio +from pydantic import BaseModel, Field +from typing import List, Dict, Any +from swarms import Agent +from swarm_models import OpenAIChat +from dotenv import load_dotenv +from swarms.utils.formatter import formatter + +# Load environment variables +load_dotenv() + +# Get OpenAI API key +api_key = os.getenv("OPENAI_API_KEY") + + +# Define Pydantic schema for agent outputs +class AgentOutput(BaseModel): + """Schema for capturing the output of each agent.""" + + agent_name: str = Field(..., description="The name of the agent") + message: str = Field( + ..., + description="The agent's response or contribution to the group chat", + ) + metadata: Dict[str, Any] = Field( + default_factory=dict, + description="Additional metadata about the agent's response", + ) + + +class GroupChat: + """ + GroupChat class to enable multiple agents to communicate in an asynchronous group chat. + Each agent is aware of all other agents, every message exchanged, and the social context. + """ + + def __init__( + self, + name: str, + description: str, + agents: List[Agent], + max_loops: int = 1, + ): + """ + Initialize the GroupChat. + + Args: + name (str): Name of the group chat. + description (str): Description of the purpose of the group chat. + agents (List[Agent]): A list of agents participating in the chat. + max_loops (int): Maximum number of loops to run through all agents. + """ + self.name = name + self.description = description + self.agents = agents + self.max_loops = max_loops + self.chat_history = ( + [] + ) # Stores all messages exchanged in the chat + + formatter.print_panel( + f"Initialized GroupChat '{self.name}' with {len(self.agents)} agents. Max loops: {self.max_loops}", + title="Groupchat Swarm", + ) + + async def _agent_conversation( + self, agent: Agent, input_message: str + ) -> AgentOutput: + """ + Facilitate a single agent's response to the chat. + + Args: + agent (Agent): The agent responding. + input_message (str): The message triggering the response. + + Returns: + AgentOutput: The agent's response captured in a structured format. + """ + formatter.print_panel( + f"Agent '{agent.agent_name}' is responding to the message: {input_message}", + title="Groupchat Swarm", + ) + response = await asyncio.to_thread(agent.run, input_message) + + output = AgentOutput( + agent_name=agent.agent_name, + message=response, + metadata={"context_length": agent.context_length}, + ) + # logger.debug(f"Agent '{agent.agent_name}' response: {response}") + return output + + async def _run(self, initial_message: str) -> List[AgentOutput]: + """ + Execute the group chat asynchronously, looping through all agents up to max_loops. + + Args: + initial_message (str): The initial message to start the chat. + + Returns: + List[AgentOutput]: The responses of all agents across all loops. + """ + formatter.print_panel( + f"Starting group chat '{self.name}' with initial message: {initial_message}", + title="Groupchat Swarm", + ) + self.chat_history.append( + {"sender": "System", "message": initial_message} + ) + + outputs = [] + for loop in range(self.max_loops): + formatter.print_panel( + f"Group chat loop {loop + 1}/{self.max_loops}", + title="Groupchat Swarm", + ) + + for agent in self.agents: + # Create a custom input message for each agent, sharing the chat history and social context + input_message = ( + f"Chat History:\n{self._format_chat_history()}\n\n" + f"Participants:\n" + + "\n".join( + [ + f"- {a.agent_name}: {a.system_prompt}" + for a in self.agents + ] + ) + + f"\n\nNew Message: {initial_message}\n\n" + f"You are '{agent.agent_name}'. Remember to keep track of the social context, who is speaking, " + f"and respond accordingly based on your role: {agent.system_prompt}." + ) + + # Collect agent's response + output = await self._agent_conversation( + agent, input_message + ) + outputs.append(output) + + # Update chat history with the agent's response + self.chat_history.append( + { + "sender": agent.agent_name, + "message": output.message, + } + ) + + formatter.print_panel( + "Group chat completed. All agent responses captured.", + title="Groupchat Swarm", + ) + return outputs + + def run(self, task: str, *args, **kwargs): + return asyncio.run(self.run(task, *args, **kwargs)) + + def _format_chat_history(self) -> str: + """ + Format the chat history for agents to understand the context. + + Returns: + str: The formatted chat history as a string. + """ + return "\n".join( + [ + f"{entry['sender']}: {entry['message']}" + for entry in self.chat_history + ] + ) + + def __str__(self) -> str: + """String representation of the group chat's outputs.""" + return self._format_chat_history() + + def to_json(self) -> str: + """JSON representation of the group chat's outputs.""" + return [ + {"sender": entry["sender"], "message": entry["message"]} + for entry in self.chat_history + ] + + +# Example Usage +if __name__ == "__main__": + + load_dotenv() + + # Get the OpenAI API key from the environment variable + api_key = os.getenv("OPENAI_API_KEY") + + # Create an instance of the OpenAIChat class + model = OpenAIChat( + openai_api_key=api_key, + model_name="gpt-4o-mini", + temperature=0.1, + ) + + # Example agents + agent1 = Agent( + agent_name="Financial-Analysis-Agent", + system_prompt="You are a financial analyst specializing in investment strategies.", + llm=model, + max_loops=1, + autosave=False, + dashboard=False, + verbose=True, + dynamic_temperature_enabled=True, + user_name="swarms_corp", + retry_attempts=1, + context_length=200000, + output_type="string", + streaming_on=False, + ) + + agent2 = Agent( + agent_name="Tax-Adviser-Agent", + system_prompt="You are a tax adviser who provides clear and concise guidance on tax-related queries.", + llm=model, + max_loops=1, + autosave=False, + dashboard=False, + verbose=True, + dynamic_temperature_enabled=True, + user_name="swarms_corp", + retry_attempts=1, + context_length=200000, + output_type="string", + streaming_on=False, + ) + + # Create group chat + group_chat = GroupChat( + name="Financial Discussion", + description="A group chat for financial analysis and tax advice.", + agents=[agent1, agent2], + ) + + # Run the group chat + asyncio.run( + group_chat.run( + "How can I establish a ROTH IRA to buy stocks and get a tax break? What are the criteria? What do you guys think?" + ) + ) diff --git a/example_async_vs_multithread.py b/new_features_examples/example_async_vs_multithread.py similarity index 98% rename from example_async_vs_multithread.py rename to new_features_examples/example_async_vs_multithread.py index f547abc8..25d514aa 100644 --- a/example_async_vs_multithread.py +++ b/new_features_examples/example_async_vs_multithread.py @@ -1,6 +1,5 @@ import os import asyncio -import threading from swarms import Agent from swarm_models import OpenAIChat import time @@ -40,18 +39,21 @@ agent = Agent( streaming_on=False, ) + # Function to measure time and memory usage def measure_time_and_memory(func): def wrapper(*args, **kwargs): start_time = time.time() result = func(*args, **kwargs) end_time = time.time() - memory_usage = psutil.Process().memory_info().rss / 1024 ** 2 + memory_usage = psutil.Process().memory_info().rss / 1024**2 print(f"Time taken: {end_time - start_time} seconds") print(f"Memory used: {memory_usage} MB") return result + return wrapper + # Function to run the agent asynchronously @measure_time_and_memory async def run_agent_async(): @@ -61,11 +63,13 @@ async def run_agent_async(): ) ) + # Function to run the agent on another thread @measure_time_and_memory def run_agent_thread(): asyncio.run(run_agent_async()) + # Run the agent asynchronously and on another thread to test the speed asyncio.run(run_agent_async()) run_agent_thread() diff --git a/sequential_worflow_test.py b/new_features_examples/sequential_worflow_test.py similarity index 100% rename from sequential_worflow_test.py rename to new_features_examples/sequential_worflow_test.py diff --git a/sequential_workflow.py b/new_features_examples/sequential_workflow.py similarity index 100% rename from sequential_workflow.py rename to new_features_examples/sequential_workflow.py diff --git a/pyproject.toml b/pyproject.toml index d8d06c61..bb3d5043 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "swarms" -version = "6.2.9" +version = "6.3.6" description = "Swarms - Pytorch" license = "MIT" authors = ["Kye Gomez "] @@ -37,6 +37,14 @@ keywords = [ "Generative AI", "Agent Marketplace", "Agent Store", + "quant", + "finance", + "algorithmic trading", + "portfolio optimization", + "risk management", + "financial modeling", + "machine learning for finance", + "natural language processing for finance", ] classifiers = [ "Development Status :: 4 - Beta", @@ -52,27 +60,18 @@ python = ">=3.10,<4.0" torch = ">=2.1.1,<3.0" transformers = ">= 4.39.0, <5.0.0" asyncio = ">=3.4.3,<4.0" -langchain-community = "0.0.29" -langchain-experimental = "0.0.55" -backoff = "2.2.1" toml = "*" pypdf = "4.3.1" -loguru = "0.7.2" +loguru = "*" pydantic = "2.8.2" -tenacity = "8.5.0" -Pillow = "10.4.0" +tenacity = "*" psutil = "*" sentry-sdk = {version = "*", extras = ["http"]} # Updated here python-dotenv = "*" PyYAML = "*" docstring_parser = "0.16" -fastapi = "*" -openai = ">=1.30.1,<2.0" -termcolor = "*" tiktoken = "*" networkx = "*" -swarms-memory = "*" -black = "*" aiofiles = "*" swarm-models = "*" clusterops = "*" @@ -96,9 +95,7 @@ mypy-protobuf = "^3.0.0" [tool.poetry.group.test.dependencies] pytest = "^8.1.1" -termcolor = "^2.4.0" pandas = "^2.2.2" -fastapi = ">=0.110.1,<0.116.0" [tool.ruff] line-length = 70 diff --git a/requirements.txt b/requirements.txt index d422222b..ca9fdcdd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,21 +2,16 @@ torch>=2.1.1,<3.0 transformers>=4.39.0,<5.0.0 asyncio>=3.4.3,<4.0 -langchain-community==0.0.28 -langchain-experimental==0.0.55 -backoff==2.2.1 toml pypdf==4.3.1 ratelimit==2.2.1 loguru==0.7.2 pydantic==2.8.2 -tenacity==8.5.0 -Pillow==10.4.0 +tenacity rich psutil sentry-sdk python-dotenv -opencv-python-headless PyYAML docstring_parser==0.16 black>=23.1,<25.0 @@ -26,12 +21,8 @@ types-pytz>=2023.3,<2025.0 types-chardet>=5.0.4.6 mypy-protobuf>=3.0.0 pytest>=8.1.1 -termcolor>=2.4.0 pandas>=2.2.2 -fastapi>=0.110.1 networkx -swarms-memory -pre-commit aiofiles swarm-models clusterops diff --git a/swarms/structs/agent.py b/swarms/structs/agent.py index a4c04a16..d1f3a745 100644 --- a/swarms/structs/agent.py +++ b/swarms/structs/agent.py @@ -24,8 +24,6 @@ import toml import yaml from pydantic import BaseModel from swarm_models.tiktoken_wrapper import TikTokenizer -from termcolor import colored - from swarms.agents.ape_agent import auto_generate_prompt from swarms.prompts.agent_system_prompts import AGENT_SYSTEM_PROMPT_3 from swarms.prompts.multi_modal_autonomous_instruction_prompt import ( @@ -671,11 +669,8 @@ class Agent: return self.stopping_condition(response) return False except Exception as error: - print( - colored( - f"Error checking stopping condition: {error}", - "red", - ) + logger.error( + f"Error checking stopping condition: {error}" ) def dynamic_temperature(self): @@ -688,21 +683,20 @@ class Agent: try: if hasattr(self.llm, "temperature"): # Randomly change the temperature attribute of self.llm object - logger.info("Enabling Random Dyamic Temperature") self.llm.temperature = random.uniform(0.0, 1.0) else: # Use a default temperature self.llm.temperature = 0.5 except Exception as error: - print( - colored( - f"Error dynamically changing temperature: {error}" - ) + logger.error( + f"Error dynamically changing temperature: {error}" ) def print_dashboard(self): """Print dashboard""" - print(colored("Initializing Agent Dashboard...", "yellow")) + formatter.print_panel( + f"Initializing Agent: {self.agent_name}" + ) data = self.to_dict() @@ -710,22 +704,19 @@ class Agent: # data = json.dumps(data, indent=4) # json_data = json.dumps(data, indent=4) - print( - colored( - f""" - Agent Dashboard - -------------------------------------------- + formatter.print_panel( + f""" + Agent Dashboard + -------------------------------------------- - Agent {self.agent_name} is initializing for {self.max_loops} with the following configuration: - ---------------------------------------- + Agent {self.agent_name} is initializing for {self.max_loops} with the following configuration: + ---------------------------------------- - Agent Configuration: - Configuration: {data} + Agent Configuration: + Configuration: {data} - ---------------------------------------- - """, - "green", - ) + ---------------------------------------- + """, ) def loop_count_print( @@ -737,7 +728,7 @@ class Agent: loop_count (_type_): _description_ max_loops (_type_): _description_ """ - print(colored(f"\nLoop {loop_count} of {max_loops}", "cyan")) + logger.info(f"\nLoop {loop_count} of {max_loops}") print("\n") # Check parameters @@ -761,8 +752,8 @@ class Agent: self, task: Optional[str] = None, img: Optional[str] = None, - is_last: bool = False, - print_task: bool = False, + is_last: Optional[bool] = False, + print_task: Optional[bool] = False, *args, **kwargs, ) -> Any: @@ -960,7 +951,7 @@ class Agent: if self.interactive: logger.info("Interactive mode enabled.") - user_input = colored(input("You: "), "red") + user_input = formatter.print_panel(input("You: ")) # User-defined exit command if ( @@ -1060,7 +1051,7 @@ class Agent: except Exception as error: logger.info( - f"Error running agent: {error} optimize your input parameters" + f"Error running agent: {error} optimize your input parameter" ) raise error @@ -1261,7 +1252,7 @@ class Agent: logger.info(f"Running bulk tasks: {inputs}") return [self.run(**input_data) for input_data in inputs] except Exception as error: - print(colored(f"Error running bulk run: {error}", "red")) + logger.info(f"Error running bulk run: {error}", "red") def save(self) -> None: """Save the agent history to a file. @@ -1438,9 +1429,7 @@ class Agent: with open(file_path, "w") as f: yaml.dump(self.to_dict(), f) except Exception as error: - logger.error( - colored(f"Error saving agent to YAML: {error}", "red") - ) + logger.error(f"Error saving agent to YAML: {error}") raise error def get_llm_parameters(self): @@ -1505,7 +1494,7 @@ class Agent: role=self.user_name, content=data ) except Exception as error: - print(colored(f"Error ingesting docs: {error}", "red")) + logger.info(f"Error ingesting docs: {error}", "red") def ingest_pdf(self, pdf: str): """Ingest the pdf into the memory @@ -1520,7 +1509,7 @@ class Agent: role=self.user_name, content=text ) except Exception as error: - print(colored(f"Error ingesting pdf: {error}", "red")) + logger.info(f"Error ingesting pdf: {error}", "red") def receieve_message(self, name: str, message: str): """Receieve a message""" @@ -1604,12 +1593,10 @@ class Agent: role=self.user_name, content=text ) except Exception as error: - print( - colored( - f"Error getting docs from doc folders: {error}", - "red", - ) + logger.error( + f"Error getting docs from doc folders: {error}" ) + raise error def check_end_session_agentops(self): if self.agent_ops_on is True: @@ -1629,7 +1616,8 @@ class Agent: try: # Query the long term memory if self.long_term_memory is not None: - logger.info(f"Querying long term memory for: {task}") + formatter.print_panel(f"Querying RAG for: {task}") + memory_retrieval = self.long_term_memory.query( task, *args, **kwargs ) @@ -1638,15 +1626,15 @@ class Agent: f"Documents Available: {str(memory_retrieval)}" ) - # Count the tokens - memory_token_count = self.tokenizer.count_tokens( - memory_retrieval - ) - if memory_token_count > self.memory_chunk_size: - # Truncate the memory by the memory chunk size - memory_retrieval = self.truncate_string_by_tokens( - memory_retrieval, self.memory_chunk_size - ) + # # Count the tokens + # memory_token_count = self.tokenizer.count_tokens( + # memory_retrieval + # ) + # if memory_token_count > self.memory_chunk_size: + # # Truncate the memory by the memory chunk size + # memory_retrieval = self.truncate_string_by_tokens( + # memory_retrieval, self.memory_chunk_size + # ) self.short_memory.add( role="Database", diff --git a/swarms/structs/base_workflow.py b/swarms/structs/base_workflow.py index b75bfe2c..4107042a 100644 --- a/swarms/structs/base_workflow.py +++ b/swarms/structs/base_workflow.py @@ -1,8 +1,7 @@ import json from typing import Any, Dict, List, Optional -from termcolor import colored - +from swarms.utils.formatter import formatter from swarms.structs.agent import Agent from swarms.structs.base_structure import BaseStructure from swarms.structs.task import Task @@ -132,9 +131,10 @@ class BaseWorkflow(BaseStructure): for task in self.tasks: task.result = None except Exception as error: - print( - colored(f"Error resetting workflow: {error}", "red"), + formatter.print_panel( + f"Error resetting workflow: {error}" ) + raise error def get_task_results(self) -> Dict[str, Any]: """ @@ -148,10 +148,8 @@ class BaseWorkflow(BaseStructure): task.description: task.result for task in self.tasks } except Exception as error: - print( - colored( - f"Error getting task results: {error}", "red" - ), + formatter.print_panel( + f"Error getting task results: {error}" ) def remove_task(self, task: str) -> None: @@ -163,12 +161,10 @@ class BaseWorkflow(BaseStructure): if task.description != task ] except Exception as error: - print( - colored( - f"Error removing task from workflow: {error}", - "red", - ), + formatter.print_panel( + f"Error removing task from workflow: {error}", ) + raise error def update_task(self, task: str, **updates) -> None: """ @@ -203,11 +199,9 @@ class BaseWorkflow(BaseStructure): f"Task {task} not found in workflow." ) except Exception as error: - print( - colored( - f"Error updating task in workflow: {error}", "red" - ), - ) + formatter.print_panel( + f"Error updating task in workflow: {error}" + ), def delete_task(self, task: str) -> None: """ @@ -240,12 +234,10 @@ class BaseWorkflow(BaseStructure): f"Task {task} not found in workflow." ) except Exception as error: - print( - colored( - f"Error deleting task from workflow: {error}", - "red", - ), + formatter.print_panel( + f"Error deleting task from workflow: {error}", ) + raise error def save_workflow_state( self, @@ -287,23 +279,18 @@ class BaseWorkflow(BaseStructure): } json.dump(state, f, indent=4) except Exception as error: - print( - colored( - f"Error saving workflow state: {error}", - "red", - ) + formatter.print_panel( + f"Error saving workflow state: {error}", ) + raise error def add_objective_to_workflow(self, task: str, **kwargs) -> None: """Adds an objective to the workflow.""" try: - print( - colored( - """ - Adding Objective to Workflow...""", - "green", - attrs=["bold", "underline"], - ) + formatter.print_panel( + """ + Adding Objective to Workflow...""", + "green", ) task = Task( @@ -314,12 +301,10 @@ class BaseWorkflow(BaseStructure): ) self.tasks.append(task) except Exception as error: - print( - colored( - f"Error adding objective to workflow: {error}", - "red", - ) + formatter.print_panel( + f"Error adding objective to workflow: {error}", ) + raise error def load_workflow_state( self, filepath: str = None, **kwargs @@ -359,11 +344,8 @@ class BaseWorkflow(BaseStructure): ) self.tasks.append(task) except Exception as error: - print( - colored( - f"Error loading workflow state: {error}", - "red", - ) + formatter.print_panel( + f"Error loading workflow state: {error}", ) def workflow_dashboard(self, **kwargs) -> None: @@ -383,25 +365,21 @@ class BaseWorkflow(BaseStructure): >>> workflow.workflow_dashboard() """ - print( - colored( - f""" - Sequential Workflow Dashboard - -------------------------------- - Name: {self.name} - Description: {self.description} - task_pool: {len(self.task_pool)} - Max Loops: {self.max_loops} - Autosave: {self.autosave} - Autosave Filepath: {self.saved_state_filepath} - Restore Filepath: {self.restore_state_filepath} - -------------------------------- - Metadata: - kwargs: {kwargs} - """, - "cyan", - attrs=["bold", "underline"], - ) + formatter.print_panel( + f""" + Sequential Workflow Dashboard + -------------------------------- + Name: {self.name} + Description: {self.description} + task_pool: {len(self.task_pool)} + Max Loops: {self.max_loops} + Autosave: {self.autosave} + Autosave Filepath: {self.saved_state_filepath} + Restore Filepath: {self.restore_state_filepath} + -------------------------------- + Metadata: + kwargs: {kwargs} + """ ) def workflow_bootup(self, **kwargs) -> None: @@ -409,11 +387,6 @@ class BaseWorkflow(BaseStructure): Workflow bootup. """ - print( - colored( - """ - Sequential Workflow Initializing...""", - "green", - attrs=["bold", "underline"], - ) + formatter.print_panel( + """Sequential Workflow Initializing...""", ) diff --git a/swarms/structs/conversation.py b/swarms/structs/conversation.py index f808382d..a86a6d3b 100644 --- a/swarms/structs/conversation.py +++ b/swarms/structs/conversation.py @@ -3,10 +3,9 @@ import json from typing import Any, Optional import yaml -from termcolor import colored - from swarms.structs.base_structure import BaseStructure from typing import TYPE_CHECKING +from swarms.utils.formatter import formatter if TYPE_CHECKING: from swarms.structs.agent import ( @@ -191,18 +190,9 @@ class Conversation(BaseStructure): Args: detailed (bool, optional): detailed. Defaults to False. """ - role_to_color = { - "system": "red", - "user": "green", - "assistant": "blue", - "function": "magenta", - } for message in self.conversation_history: - print( - colored( - f"{message['role']}: {message['content']}\n\n", - role_to_color[message["role"]], - ) + formatter.print_panel( + f"{message['role']}: {message['content']}\n\n" ) def export_conversation(self, filename: str, *args, **kwargs): @@ -307,46 +297,36 @@ class Conversation(BaseStructure): for message in messages: if message["role"] == "system": - print( - colored( - f"system: {message['content']}\n", - role_to_color[message["role"]], - ) + formatter.print_panel( + f"system: {message['content']}\n", + role_to_color[message["role"]], ) elif message["role"] == "user": - print( - colored( - f"user: {message['content']}\n", - role_to_color[message["role"]], - ) + formatter.print_panel( + f"user: {message['content']}\n", + role_to_color[message["role"]], ) elif message["role"] == "assistant" and message.get( "function_call" ): - print( - colored( - f"assistant: {message['function_call']}\n", - role_to_color[message["role"]], - ) + formatter.print_panel( + f"assistant: {message['function_call']}\n", + role_to_color[message["role"]], ) elif message["role"] == "assistant" and not message.get( "function_call" ): - print( - colored( - f"assistant: {message['content']}\n", - role_to_color[message["role"]], - ) + formatter.print_panel( + f"assistant: {message['content']}\n", + role_to_color[message["role"]], ) elif message["role"] == "tool": - print( - colored( - ( - f"function ({message['name']}):" - f" {message['content']}\n" - ), - role_to_color[message["role"]], - ) + formatter.print_panel( + ( + f"function ({message['name']}):" + f" {message['content']}\n" + ), + role_to_color[message["role"]], ) def truncate_memory_with_tokenizer(self): diff --git a/swarms/structs/mixture_of_agents.py b/swarms/structs/mixture_of_agents.py index f80701ef..e91d565f 100644 --- a/swarms/structs/mixture_of_agents.py +++ b/swarms/structs/mixture_of_agents.py @@ -86,9 +86,7 @@ class MixtureOfAgents: self.input_schema = MixtureOfAgentsInput( name=name, description=description, - agents=[ - agent.to_dict() for agent in self.agents - ], + agents=[agent.to_dict() for agent in self.agents], aggregator_agent=aggregator_agent.to_dict(), aggregator_system_prompt=self.aggregator_system_prompt, layers=self.layers, diff --git a/swarms/structs/multi_agent_exec.py b/swarms/structs/multi_agent_exec.py index d733f49f..b66af8a5 100644 --- a/swarms/structs/multi_agent_exec.py +++ b/swarms/structs/multi_agent_exec.py @@ -414,7 +414,7 @@ def run_agents_with_tasks_concurrently( List[Any]: A list of outputs from each agent execution. """ # Make the first agent not use the ifrs - + if no_clusterops: return _run_agents_with_tasks_concurrently( agents, tasks, batch_size, max_workers diff --git a/swarms/structs/rearrange.py b/swarms/structs/rearrange.py index 7c71bd04..a61efbee 100644 --- a/swarms/structs/rearrange.py +++ b/swarms/structs/rearrange.py @@ -121,16 +121,14 @@ class AgentRearrange(BaseSwarm): output_type: OutputType = "final", docs: List[str] = None, doc_folder: str = None, + device: str = "cpu", + device_id: int = 0, + all_cores: bool = False, + all_gpus: bool = True, + no_use_clusterops: bool = True, *args, **kwargs, ): - # reliability_check( - # agents=agents, - # name=name, - # description=description, - # flow=flow, - # max_loops=max_loops, - # ) super(AgentRearrange, self).__init__( name=name, description=description, @@ -150,6 +148,11 @@ class AgentRearrange(BaseSwarm): self.output_type = output_type self.docs = docs self.doc_folder = doc_folder + self.device = device + self.device_id = device_id + self.all_cores = all_cores + self.all_gpus = all_gpus + self.no_use_clusterops = no_use_clusterops self.swarm_history = { agent.agent_name: [] for agent in agents } @@ -506,7 +509,11 @@ class AgentRearrange(BaseSwarm): Returns: The result from executing the task through the cluster operations wrapper. """ - if no_use_clusterops: + no_use_clusterops = ( + no_use_clusterops or self.no_use_clusterops + ) + + if no_use_clusterops is True: return self._run( task=task, img=img, diff --git a/swarms/structs/sequential_workflow.py b/swarms/structs/sequential_workflow.py index cebcd7f0..ed55102d 100644 --- a/swarms/structs/sequential_workflow.py +++ b/swarms/structs/sequential_workflow.py @@ -107,7 +107,7 @@ class SequentialWorkflow: all_cores: bool = False, all_gpus: bool = False, device_id: int = 0, - no_use_clusterops: bool = False, + no_use_clusterops: bool = True, *args, **kwargs, ) -> str: diff --git a/swarms/tools/json_former.py b/swarms/tools/json_former.py index 01d608a5..dcca9932 100644 --- a/swarms/tools/json_former.py +++ b/swarms/tools/json_former.py @@ -1,7 +1,6 @@ import json from typing import Any, Dict, List, Union -from termcolor import cprint from transformers import PreTrainedModel, PreTrainedTokenizer from pydantic import BaseModel from swarms.tools.logits_processor import ( @@ -68,15 +67,6 @@ class Jsonformer: self.temperature = temperature self.max_string_token_length = max_string_token_length - def debug(self, caller: str, value: str, is_prompt: bool = False): - if self.debug_on: - if is_prompt: - cprint(caller, "green", end=" ") - cprint(value, "yellow") - else: - cprint(caller, "green", end=" ") - cprint(value, "blue") - def generate_number( self, temperature: Union[float, None] = None, iterations=0 ): diff --git a/swarms/tools/tool_utils.py b/swarms/tools/tool_utils.py index 9076e2d1..972f98ec 100644 --- a/swarms/tools/tool_utils.py +++ b/swarms/tools/tool_utils.py @@ -3,8 +3,7 @@ from typing import Any, List import inspect from typing import Callable - -from termcolor import colored +from swarms.utils.formatter import formatter def scrape_tool_func_docs(fn: Callable) -> str: @@ -37,17 +36,16 @@ def scrape_tool_func_docs(fn: Callable) -> str: f" {inspect.getdoc(fn)}\nParameters:\n{parameters_str}" ) except Exception as error: - print( - colored( - ( - f"Error scraping tool function docs {error} try" - " optimizing your inputs with different" - " variables and attempt once more." - ), - "red", - ) + error_msg = ( + formatter.print_panel( + f"Error scraping tool function docs {error} try" + " optimizing your inputs with different" + " variables and attempt once more." + ), ) + raise error + def tool_find_by_name(tool_name: str, tools: List[Any]): """Find the tool by name""" diff --git a/swarms/utils/callable_name.py b/swarms/utils/callable_name.py new file mode 100644 index 00000000..9a0b037f --- /dev/null +++ b/swarms/utils/callable_name.py @@ -0,0 +1,203 @@ +from typing import Any +import inspect +from functools import partial +import logging + + +class NameResolver: + """Utility class for resolving names of various objects""" + + @staticmethod + def get_name(obj: Any, default: str = "unnamed_callable") -> str: + """ + Get the name of any object with multiple fallback strategies. + + Args: + obj: The object to get the name from + default: Default name if all strategies fail + + Returns: + str: The resolved name + """ + strategies = [ + # Try getting __name__ attribute + lambda x: getattr(x, "__name__", None), + # Try getting class name + lambda x: ( + x.__class__.__name__ + if hasattr(x, "__class__") + else None + ), + # Try getting function name if it's a partial + lambda x: ( + x.func.__name__ if isinstance(x, partial) else None + ), + # Try getting the name from the class's type + lambda x: type(x).__name__, + # Try getting qualname + lambda x: getattr(x, "__qualname__", None), + # Try getting the module and class name + lambda x: ( + f"{x.__module__}.{x.__class__.__name__}" + if hasattr(x, "__module__") + else None + ), + # For async functions + lambda x: ( + x.__name__ if inspect.iscoroutinefunction(x) else None + ), + # For classes with custom __str__ + lambda x: ( + str(x) + if hasattr(x, "__str__") + and x.__str__ != object.__str__ + else None + ), + # For wrapped functions + lambda x: ( + getattr(x, "__wrapped__", None).__name__ + if hasattr(x, "__wrapped__") + else None + ), + ] + + # Try each strategy + for strategy in strategies: + try: + name = strategy(obj) + if name and isinstance(name, str): + return name.replace(" ", "_").replace("-", "_") + except Exception: + continue + + # Return default if all strategies fail + return default + + @staticmethod + def get_callable_details(obj: Any) -> dict: + """ + Get detailed information about a callable object. + + Returns: + dict: Dictionary containing: + - name: The resolved name + - type: The type of callable + - signature: The signature if available + - module: The module name if available + - doc: The docstring if available + """ + details = { + "name": NameResolver.get_name(obj), + "type": "unknown", + "signature": None, + "module": getattr(obj, "__module__", "unknown"), + "doc": inspect.getdoc(obj) + or "No documentation available", + } + + # Determine the type + if inspect.isclass(obj): + details["type"] = "class" + elif inspect.iscoroutinefunction(obj): + details["type"] = "async_function" + elif inspect.isfunction(obj): + details["type"] = "function" + elif isinstance(obj, partial): + details["type"] = "partial" + elif callable(obj): + details["type"] = "callable" + + # Try to get signature + try: + details["signature"] = str(inspect.signature(obj)) + except (ValueError, TypeError): + details["signature"] = "Unknown signature" + + return details + + @classmethod + def get_safe_name(cls, obj: Any, max_retries: int = 3) -> str: + """ + Safely get a name with retries and validation. + + Args: + obj: Object to get name from + max_retries: Maximum number of retry attempts + + Returns: + str: A valid name string + """ + retries = 0 + last_error = None + + while retries < max_retries: + try: + name = cls.get_name(obj) + + # Validate and clean the name + if name: + # Remove invalid characters + clean_name = "".join( + c + for c in name + if c.isalnum() or c in ["_", "."] + ) + + # Ensure it starts with a letter or underscore + if ( + not clean_name[0].isalpha() + and clean_name[0] != "_" + ): + clean_name = f"_{clean_name}" + + return clean_name + + except Exception as e: + last_error = e + retries += 1 + + # If all retries failed, generate a unique fallback name + import uuid + + fallback = f"callable_{uuid.uuid4().hex[:8]}" + logging.warning( + f"Failed to get name after {max_retries} retries. Using fallback: {fallback}. " + f"Last error: {str(last_error)}" + ) + return fallback + + +# # Example usage +# if __name__ == "__main__": +# def test_resolver(): +# # Test cases +# class TestClass: +# def method(self): +# pass + +# async def async_func(): +# pass + +# test_cases = [ +# TestClass, # Class +# TestClass(), # Instance +# async_func, # Async function +# lambda x: x, # Lambda +# partial(print, end=""), # Partial +# TestClass.method, # Method +# print, # Built-in function +# str, # Built-in class +# ] + +# resolver = NameResolver() + +# print("\nName Resolution Results:") +# print("-" * 50) +# for obj in test_cases: +# details = resolver.get_callable_details(obj) +# safe_name = resolver.get_safe_name(obj) +# print(f"\nObject: {obj}") +# print(f"Safe Name: {safe_name}") +# print(f"Details: {details}") + +# test_resolver() diff --git a/swarms/utils/markdown_message.py b/swarms/utils/markdown_message.py index a85cb4a1..03a35092 100644 --- a/swarms/utils/markdown_message.py +++ b/swarms/utils/markdown_message.py @@ -1,4 +1,4 @@ -from termcolor import colored +from swarms.utils.formatter import formatter def display_markdown_message(message: str, color: str = "cyan"): @@ -12,13 +12,10 @@ def display_markdown_message(message: str, color: str = "cyan"): if line == "": print() elif line == "---": - print(colored("-" * 50, color)) + formatter.print_panel("-" * 50) else: - print(colored(line, color)) + formatter.print_panel(line) if "\n" not in message and message.startswith(">"): # Aesthetic choice. For these tags, they need a space below them print() - - -# display_markdown_message("I love you and you are beautiful.", "cyan") From 96c50bb20c00f25e15fd702735661a1473d70e32 Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 27 Nov 2024 12:20:25 -0800 Subject: [PATCH 2/6] [CLEANUP] --- .../rearrange_test.py => rearrange_test.py | 4 +- swarms/structs/agents_available.py | 160 ++++++++---------- swarms/structs/rearrange.py | 101 ++++++----- swarms/utils/update_agent_system_prompts.py | 53 ++++++ 4 files changed, 180 insertions(+), 138 deletions(-) rename new_features_examples/rearrange_test.py => rearrange_test.py (98%) create mode 100644 swarms/utils/update_agent_system_prompts.py diff --git a/new_features_examples/rearrange_test.py b/rearrange_test.py similarity index 98% rename from new_features_examples/rearrange_test.py rename to rearrange_test.py index ddfd7670..d85e435a 100644 --- a/new_features_examples/rearrange_test.py +++ b/rearrange_test.py @@ -95,12 +95,14 @@ flow = "BossAgent -> ExpenseAnalyzer -> SummaryGenerator" # Using AgentRearrange class to manage the swarm agent_system = AgentRearrange( + name="pe-swarm", + description="ss", agents=agents, flow=flow, return_json=False, output_type="final", max_loops=1, - docs=["SECURITY.md"], + # docs=["SECURITY.md"], ) # Input task for the swarm diff --git a/swarms/structs/agents_available.py b/swarms/structs/agents_available.py index f676877d..5651f9b0 100644 --- a/swarms/structs/agents_available.py +++ b/swarms/structs/agents_available.py @@ -1,109 +1,87 @@ -import os -from typing import List, Any from swarms.structs.agent import Agent -from loguru import logger -import uuid - -WORKSPACE_DIR = os.getenv("WORKSPACE_DIR") -uuid_for_log = str(uuid.uuid4()) -logger.add( - os.path.join( - WORKSPACE_DIR, - "agents_available", - f"agents-available-{uuid_for_log}.log", - ), - level="INFO", - colorize=True, - backtrace=True, - diagnose=True, -) - - -def get_agent_name(agent: Any) -> str: - """Helper function to safely get agent name - - Args: - agent (Any): The agent object to get name from - - Returns: - str: The agent's name if found, 'Unknown' otherwise - """ - if isinstance(agent, Agent) and hasattr(agent, "agent_name"): - return agent.agent_name - return "Unknown" - - -def get_agent_description(agent: Any) -> str: - """Helper function to get agent description or system prompt preview - - Args: - agent (Any): The agent object - - Returns: - str: Description or first 100 chars of system prompt - """ - if not isinstance(agent, Agent): - return "N/A" - - if hasattr(agent, "description") and agent.description: - return agent.description - - if hasattr(agent, "system_prompt") and agent.system_prompt: - return f"{agent.system_prompt[:150]}..." - - return "N/A" +from typing import List def showcase_available_agents( + agents: List[Agent], name: str = None, description: str = None, - agents: List[Agent] = [], - update_agents_on: bool = False, + format: str = "XML", ) -> str: """ - Generate a formatted string showcasing all available agents and their descriptions. + Format the available agents in either XML or Table format. Args: - agents (List[Agent]): List of Agent objects to showcase. - update_agents_on (bool, optional): If True, updates each agent's system prompt with - the showcase information. Defaults to False. + agents (List[Agent]): A list of agents to represent + name (str, optional): Name of the swarm + description (str, optional): Description of the swarm + format (str, optional): Output format ("XML" or "Table"). Defaults to "XML" Returns: - str: Formatted string containing agent information, including names, descriptions - and IDs for all available agents. + str: Formatted string containing agent information """ - logger.info(f"Showcasing {len(agents)} available agents") - - formatted_agents = [] - header = f"\n####### Agents available in the swarm: {name} ############\n" - header += f"{description}\n" - row_format = "{:<5} | {:<20} | {:<50}" - header_row = row_format.format("ID", "Agent Name", "Description") - separator = "-" * 80 - - formatted_agents.append(header) - formatted_agents.append(separator) - formatted_agents.append(header_row) - formatted_agents.append(separator) - - for idx, agent in enumerate(agents): - if not isinstance(agent, Agent): - logger.warning( - f"Skipping non-Agent object: {type(agent)}" - ) - continue - agent_name = get_agent_name(agent) - description = ( - get_agent_description(agent)[:100] + "..." - if len(get_agent_description(agent)) > 100 - else get_agent_description(agent) + def truncate(text: str, max_length: int = 130) -> str: + return ( + f"{text[:max_length]}..." + if len(text) > max_length + else text ) - formatted_agents.append( - row_format.format(idx + 1, agent_name, description) + output = [] + + if format.upper() == "TABLE": + output.append("\n| ID | Agent Name | Description |") + output.append("|-----|------------|-------------|") + for idx, agent in enumerate(agents): + if isinstance(agent, Agent): + agent_name = getattr(agent, "agent_name", str(agent)) + description = getattr( + agent, + "description", + getattr( + agent, "system_prompt", "Unknown description" + ), + ) + desc = truncate(description, 50) + output.append( + f"| {idx + 1} | {agent_name} | {desc} |" + ) + else: + output.append( + f"| {idx + 1} | {agent} | Unknown description |" + ) + return "\n".join(output) + + # Default XML format + output.append("") + if name: + output.append(f" {name}") + if description: + output.append( + f" {truncate(description)}" ) + for idx, agent in enumerate(agents): + output.append(f" ") + if isinstance(agent, Agent): + agent_name = getattr(agent, "agent_name", str(agent)) + description = getattr( + agent, + "description", + getattr( + agent, "system_prompt", "Unknown description" + ), + ) + output.append(f" {agent_name}") + output.append( + f" {truncate(description)}" + ) + else: + output.append(f" {agent}") + output.append( + " Unknown description" + ) + output.append(" ") + output.append("") - showcase = "\n".join(formatted_agents) - - return showcase + return "\n".join(output) diff --git a/swarms/structs/rearrange.py b/swarms/structs/rearrange.py index a61efbee..801861b0 100644 --- a/swarms/structs/rearrange.py +++ b/swarms/structs/rearrange.py @@ -1,5 +1,5 @@ -import traceback import asyncio +import traceback import uuid from concurrent.futures import ThreadPoolExecutor from datetime import datetime @@ -13,10 +13,10 @@ from swarms.structs.agent import Agent from swarms.structs.agents_available import showcase_available_agents from swarms.structs.base_swarm import BaseSwarm from swarms.utils.add_docs_to_agents import handle_input_docs +from swarms.utils.loguru_logger import initialize_logger from swarms.utils.wrapper_clusterop import ( exec_callable_with_clusterops, ) -from swarms.utils.loguru_logger import initialize_logger logger = initialize_logger(log_folder="rearrange") @@ -153,33 +153,6 @@ class AgentRearrange(BaseSwarm): self.all_cores = all_cores self.all_gpus = all_gpus self.no_use_clusterops = no_use_clusterops - self.swarm_history = { - agent.agent_name: [] for agent in agents - } - - self.id = uuid.uuid4().hex if id is None else id - - # Output schema - self.input_config = AgentRearrangeInput( - swarm_id=self.id, - name=self.name, - description=self.description, - flow=self.flow, - max_loops=self.max_loops, - output_type=self.output_type, - ) - - # Output schema - self.output_schema = AgentRearrangeOutput( - Input=self.input_config, - outputs=[], - ) - - # Run the reliability checks to validate the swarm - # self.handle_input_docs() - - # Show the agents whose in the swarm - # self.showcase_agents() def showcase_agents(self): # Get formatted agent info once @@ -187,12 +160,34 @@ class AgentRearrange(BaseSwarm): name=self.name, description=self.description, agents=self.agents, + format="Table", ) - # Update all agents in one pass using values() - for agent in self.agents.values(): - if isinstance(agent, Agent): - agent.system_prompt += agents_available + return agents_available + + def rearrange_prompt_prep(self) -> str: + """Prepares a formatted prompt describing the swarm configuration. + + Returns: + str: A formatted string containing the swarm's name, description, + flow pattern, and participating agents. + """ + agents_available = self.showcase_agents() + prompt = f""" + ===== Swarm Configuration ===== + + Name: {self.name} + Description: {self.description} + + ===== Execution Flow ===== + {self.flow} + + ===== Participating Agents ===== + {agents_available} + + =========================== + """ + return prompt def set_custom_flow(self, flow: str): self.flow = flow @@ -325,6 +320,7 @@ class AgentRearrange(BaseSwarm): current_task = task all_responses = [] response_dict = {} + previous_agent = None logger.info( f"Starting task execution with {len(tasks)} steps" @@ -349,12 +345,19 @@ class AgentRearrange(BaseSwarm): f"Starting loop {loop_count + 1}/{self.max_loops}" ) - for task in tasks: + for task_idx, task in enumerate(tasks): is_last = task == tasks[-1] agent_names = [ name.strip() for name in task.split(",") ] + # Prepare prompt with previous agent info + prompt_prefix = "" + if previous_agent and task_idx > 0: + prompt_prefix = f"Previous agent {previous_agent} output: {current_task}\n" + elif task_idx == 0: + prompt_prefix = "Initial task: " + if len(agent_names) > 1: # Parallel processing logger.info( @@ -370,12 +373,14 @@ class AgentRearrange(BaseSwarm): ): current_task = ( self.custom_human_in_the_loop( - current_task + prompt_prefix + + str(current_task) ) ) else: current_task = input( - "Enter your response:" + prompt_prefix + + "Enter your response: " ) results.append(current_task) response_dict[agent_name] = ( @@ -383,13 +388,13 @@ class AgentRearrange(BaseSwarm): ) else: agent = self.agents[agent_name] - current_task = ( - str(current_task) + task_with_context = ( + prompt_prefix + str(current_task) if current_task - else "" + else prompt_prefix ) result = agent.run( - task=current_task, + task=task_with_context, img=img, is_last=is_last, *args, @@ -407,6 +412,7 @@ class AgentRearrange(BaseSwarm): current_task = "; ".join(results) all_responses.extend(results) + previous_agent = ",".join(agent_names) else: # Sequential processing @@ -422,23 +428,25 @@ class AgentRearrange(BaseSwarm): ): current_task = ( self.custom_human_in_the_loop( - current_task + prompt_prefix + + str(current_task) ) ) else: current_task = input( - "Enter the next task: " + prompt_prefix + + "Enter the next task: " ) response_dict[agent_name] = current_task else: agent = self.agents[agent_name] - current_task = ( - str(current_task) + task_with_context = ( + prompt_prefix + str(current_task) if current_task - else "" + else prompt_prefix ) current_task = agent.run( - task=current_task, + task=task_with_context, img=img, is_last=is_last, *args, @@ -454,6 +462,7 @@ class AgentRearrange(BaseSwarm): ) all_responses.append(current_task) + previous_agent = agent_name loop_count += 1 diff --git a/swarms/utils/update_agent_system_prompts.py b/swarms/utils/update_agent_system_prompts.py new file mode 100644 index 00000000..e6f82426 --- /dev/null +++ b/swarms/utils/update_agent_system_prompts.py @@ -0,0 +1,53 @@ +import concurrent.futures +from typing import List, Union +from swarms.structs.agent import Agent + + +def update_system_prompts( + agents: List[Union[Agent, str]], + prompt: str, +) -> List[Agent]: + """ + Update system prompts for a list of agents concurrently. + + Args: + agents: List of Agent objects or strings to update + prompt: The prompt text to append to each agent's system prompt + + Returns: + List of updated Agent objects + """ + if not agents: + return agents + + def update_agent_prompt(agent: Union[Agent, str]) -> Agent: + # Convert string to Agent if needed + if isinstance(agent, str): + agent = Agent( + agent_name=agent, + system_prompt=prompt, # Initialize with the provided prompt + ) + else: + # Preserve existing prompt and append new one + existing_prompt = ( + agent.system_prompt if agent.system_prompt else "" + ) + agent.system_prompt = existing_prompt + "\n" + prompt + return agent + + # Use ThreadPoolExecutor for concurrent execution + max_workers = min(len(agents), 4) # Reasonable thread count + with concurrent.futures.ThreadPoolExecutor( + max_workers=max_workers + ) as executor: + futures = [] + for agent in agents: + future = executor.submit(update_agent_prompt, agent) + futures.append(future) + + # Collect results as they complete + updated_agents = [] + for future in concurrent.futures.as_completed(futures): + updated_agents.append(future.result()) + + return updated_agents From e6d616b2c3730bd9a95f30d81243812abbe12ae7 Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 27 Nov 2024 12:22:01 -0800 Subject: [PATCH 3/6] [CLEANUP] --- docs/mkdocs.yml | 38 ++-- multi_tool_usage_agent.py | 352 ++++++++++++++++++++++++++++++++++++++ swarms/__init__.py | 34 +++- 3 files changed, 399 insertions(+), 25 deletions(-) create mode 100644 multi_tool_usage_agent.py diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 23958829..53b4d273 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -57,29 +57,35 @@ extra: property: G-MPE9C65596 theme: - name: material - custom_dir: overrides - logo: assets/img/swarms-logo.png - palette: + name: material + custom_dir: overrides + logo: assets/img/swarms-logo.png + palette: - scheme: default - primary: black + primary: white # White background + accent: white # Black accents for interactive elements toggle: - icon: material/brightness-7 + icon: material/brightness-7 name: Switch to dark mode - # Palette toggle for dark mode - - scheme: slate + - scheme: slate # Optional: lighter shades for accessibility primary: black + accent: black toggle: icon: material/brightness-4 name: Switch to light mode - features: - - content.code.copy - - content.code.annotate - - navigation.tabs - - navigation.sections - - navigation.expand - - navigation.top - - announce.dismiss + features: + - content.code.copy + - content.code.annotate + - navigation.tabs + - navigation.sections + - navigation.expand + - navigation.top + - announce.dismiss + font: + text: "Fira Sans" # Clean and readable text + code: "Fira Code" # Modern look for code snippets + + # Extensions markdown_extensions: - abbr diff --git a/multi_tool_usage_agent.py b/multi_tool_usage_agent.py new file mode 100644 index 00000000..16f07cae --- /dev/null +++ b/multi_tool_usage_agent.py @@ -0,0 +1,352 @@ +from typing import List, Dict, Any, Optional, Callable +from dataclasses import dataclass +import json +from datetime import datetime +import inspect +import typing +from typing import Union +from swarms import Agent +from swarm_models import OpenAIChat +from dotenv import load_dotenv + + +@dataclass +class ToolDefinition: + name: str + description: str + parameters: Dict[str, Any] + required_params: List[str] + callable: Optional[Callable] = None + + +@dataclass +class ExecutionStep: + step_id: str + tool_name: str + parameters: Dict[str, Any] + purpose: str + depends_on: List[str] + completed: bool = False + result: Optional[Any] = None + + +def extract_type_hints(func: Callable) -> Dict[str, Any]: + """Extract parameter types from function type hints.""" + return typing.get_type_hints(func) + + +def extract_tool_info(func: Callable) -> ToolDefinition: + """Extract tool information from a callable function.""" + # Get function name + name = func.__name__ + + # Get docstring + description = inspect.getdoc(func) or "No description available" + + # Get parameters and their types + signature = inspect.signature(func) + type_hints = extract_type_hints(func) + + parameters = {} + required_params = [] + + for param_name, param in signature.parameters.items(): + # Skip self parameter for methods + if param_name == "self": + continue + + param_type = type_hints.get(param_name, Any) + + # Handle optional parameters + is_optional = ( + param.default != inspect.Parameter.empty + or getattr(param_type, "__origin__", None) is Union + and type(None) in param_type.__args__ + ) + + if not is_optional: + required_params.append(param_name) + + parameters[param_name] = { + "type": str(param_type), + "default": ( + None + if param.default is inspect.Parameter.empty + else param.default + ), + "required": not is_optional, + } + + return ToolDefinition( + name=name, + description=description, + parameters=parameters, + required_params=required_params, + callable=func, + ) + + +class ToolUsingAgent: + def __init__( + self, + tools: List[Callable], + openai_api_key: str, + model_name: str = "gpt-4", + temperature: float = 0.1, + max_loops: int = 10, + ): + # Convert callable tools to ToolDefinitions + self.available_tools = { + tool.__name__: extract_tool_info(tool) for tool in tools + } + + self.execution_plan: List[ExecutionStep] = [] + self.current_step_index = 0 + self.max_loops = max_loops + + # Initialize the OpenAI model + self.model = OpenAIChat( + openai_api_key=openai_api_key, + model_name=model_name, + temperature=temperature, + ) + + # Create system prompt with tool descriptions + self.system_prompt = self._create_system_prompt() + + self.agent = Agent( + agent_name="Tool-Using-Agent", + system_prompt=self.system_prompt, + llm=self.model, + max_loops=1, + autosave=True, + verbose=True, + saved_state_path="tool_agent_state.json", + context_length=200000, + ) + + def _create_system_prompt(self) -> str: + """Create system prompt with available tools information.""" + tools_description = [] + for tool_name, tool in self.available_tools.items(): + tools_description.append( + f""" + Tool: {tool_name} + Description: {tool.description} + Parameters: {json.dumps(tool.parameters, indent=2)} + Required Parameters: {tool.required_params} + """ + ) + + output = f"""You are an autonomous agent capable of executing complex tasks using available tools. + + Available Tools: + {chr(10).join(tools_description)} + + Follow these protocols: + 1. Create a detailed plan using available tools + 2. Execute each step in order + 3. Handle errors appropriately + 4. Maintain execution state + 5. Return results in structured format + + You must ALWAYS respond in the following JSON format: + {{ + "plan": {{ + "description": "Brief description of the overall plan", + "steps": [ + {{ + "step_number": 1, + "tool_name": "name_of_tool", + "description": "What this step accomplishes", + "parameters": {{ + "param1": "value1", + "param2": "value2" + }}, + "expected_output": "Description of expected output" + }} + ] + }}, + "reasoning": "Explanation of why this plan was chosen" + }} + + Before executing any tool: + 1. Validate all required parameters are present + 2. Verify parameter types match specifications + 3. Check parameter values are within valid ranges/formats + 4. Ensure logical dependencies between steps are met + + If any validation fails: + 1. Return error in JSON format with specific details + 2. Suggest corrections if possible + 3. Do not proceed with execution + + After each step execution: + 1. Verify output matches expected format + 2. Log results and any warnings/errors + 3. Update execution state + 4. Determine if plan adjustment needed + + Error Handling: + 1. Catch and classify all errors + 2. Provide detailed error messages + 3. Suggest recovery steps + 4. Maintain system stability + + The final output must be valid JSON that can be parsed. Always check your response can be parsed as JSON before returning. + """ + return output + + def execute_tool( + self, tool_name: str, parameters: Dict[str, Any] + ) -> Any: + """Execute a tool with given parameters.""" + tool = self.available_tools[tool_name] + if not tool.callable: + raise ValueError( + f"Tool {tool_name} has no associated callable" + ) + + # Convert parameters to appropriate types + converted_params = {} + for param_name, param_value in parameters.items(): + param_info = tool.parameters[param_name] + param_type = eval( + param_info["type"] + ) # Note: Be careful with eval + converted_params[param_name] = param_type(param_value) + + return tool.callable(**converted_params) + + def run(self, task: str) -> Dict[str, Any]: + """Execute the complete task with proper logging and error handling.""" + execution_log = { + "task": task, + "start_time": datetime.utcnow().isoformat(), + "steps": [], + "final_result": None + } + + try: + # Create and execute plan + plan_response = self.agent.run(f"Create a plan for: {task}") + plan_data = json.loads(plan_response) + + # Extract steps from the correct path in JSON + steps = plan_data["plan"]["steps"] # Changed from plan_data["steps"] + + for step in steps: + try: + # Check if parameters need default values + for param_name, param_value in step["parameters"].items(): + if isinstance(param_value, str) and not param_value.replace(".", "").isdigit(): + # If parameter is a description rather than a value, set default + if "income" in param_name.lower(): + step["parameters"][param_name] = 75000.0 + elif "year" in param_name.lower(): + step["parameters"][param_name] = 2024 + elif "investment" in param_name.lower(): + step["parameters"][param_name] = 1000.0 + + # Execute the tool + result = self.execute_tool( + step["tool_name"], + step["parameters"] + ) + + execution_log["steps"].append({ + "step_number": step["step_number"], + "tool": step["tool_name"], + "parameters": step["parameters"], + "success": True, + "result": result, + "description": step["description"] + }) + + except Exception as e: + execution_log["steps"].append({ + "step_number": step["step_number"], + "tool": step["tool_name"], + "parameters": step["parameters"], + "success": False, + "error": str(e), + "description": step["description"] + }) + print(f"Error executing step {step['step_number']}: {str(e)}") + # Continue with next step instead of raising + continue + + # Only mark as success if at least some steps succeeded + successful_steps = [s for s in execution_log["steps"] if s["success"]] + if successful_steps: + execution_log["final_result"] = { + "success": True, + "results": successful_steps, + "reasoning": plan_data.get("reasoning", "No reasoning provided") + } + else: + execution_log["final_result"] = { + "success": False, + "error": "No steps completed successfully", + "plan": plan_data + } + + except Exception as e: + execution_log["final_result"] = { + "success": False, + "error": str(e), + "plan": plan_data if 'plan_data' in locals() else None + } + + execution_log["end_time"] = datetime.utcnow().isoformat() + return execution_log + + +# Example usage +if __name__ == "__main__": + load_dotenv() + + # Example tool functions + def research_ira_requirements() -> Dict[str, Any]: + """Research and return ROTH IRA eligibility requirements.""" + return { + "age_requirement": "Must have earned income", + "income_limits": {"single": 144000, "married": 214000}, + } + + def calculate_contribution_limit( + income: float, tax_year: int + ) -> Dict[str, float]: + """Calculate maximum ROTH IRA contribution based on income and tax year.""" + base_limit = 6000 if tax_year <= 2022 else 6500 + if income > 144000: + return {"limit": 0} + return {"limit": base_limit} + + def find_brokers(min_investment: float) -> List[Dict[str, Any]]: + """Find suitable brokers for ROTH IRA based on minimum investment.""" + return [ + {"name": "Broker A", "min_investment": min_investment}, + { + "name": "Broker B", + "min_investment": min_investment * 1.5, + }, + ] + + # Initialize agent with tools + agent = ToolUsingAgent( + tools=[ + research_ira_requirements, + calculate_contribution_limit, + find_brokers, + ], + openai_api_key="", + ) + + # Run a task + result = agent.run( + "How can I establish a ROTH IRA to buy stocks and get a tax break? " + "What are the criteria?" + ) + + print(json.dumps(result, indent=2)) diff --git a/swarms/__init__.py b/swarms/__init__.py index 59fee672..d73b8439 100644 --- a/swarms/__init__.py +++ b/swarms/__init__.py @@ -1,22 +1,38 @@ +import os import concurrent.futures from dotenv import load_dotenv - -# from swarms.structs.workspace_manager import WorkspaceManager -# workspace_manager = WorkspaceManager() -# workspace_manager.run() +from loguru import logger load_dotenv() +# More reliable string comparison +if os.getenv("SWARMS_VERBOSE_GLOBAL", "True").lower() == "false": + logger.disable("") +# Import telemetry functions with error handling from swarms.telemetry.bootup import bootup # noqa: E402, F403 -from swarms.telemetry.sentry_active import ( +from swarms.telemetry.sentry_active import ( # noqa: E402 activate_sentry, ) # noqa: E402 -# Use ThreadPoolExecutor to run bootup and activate_sentry concurrently -with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor: - executor.submit(bootup) - executor.submit(activate_sentry) + +# Run telemetry functions concurrently with error handling +def run_telemetry(): + try: + with concurrent.futures.ThreadPoolExecutor( + max_workers=2 + ) as executor: + future_bootup = executor.submit(bootup) + future_sentry = executor.submit(activate_sentry) + + # Wait for completion and check for exceptions + future_bootup.result() + future_sentry.result() + except Exception as e: + logger.error(f"Error running telemetry functions: {e}") + + +run_telemetry() from swarms.agents import * # noqa: E402, F403 from swarms.artifacts import * # noqa: E402, F403 From c2cd6775751d1d412b55d131d8bb60e274b19b6a Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 27 Nov 2024 16:51:09 -0800 Subject: [PATCH 4/6] [DOCS][CSS] --- docs/assets/css/extra.css | 38 +- ethchain_agent.py | 308 ++++++++++++++ ethereum_analysis.csv | 4 + multi_tool_usage_agent.py | 513 +++++++++++++----------- pyproject.toml | 2 +- swarms/__init__.py | 4 +- swarms/telemetry/auto_upgrade_swarms.py | 2 +- swarms/tools/tool_utils.py | 2 +- test.py | 292 ++++++++++++++ 9 files changed, 921 insertions(+), 244 deletions(-) create mode 100644 ethchain_agent.py create mode 100644 ethereum_analysis.csv create mode 100644 test.py diff --git a/docs/assets/css/extra.css b/docs/assets/css/extra.css index b639e2f7..d11f77f5 100644 --- a/docs/assets/css/extra.css +++ b/docs/assets/css/extra.css @@ -1,18 +1,26 @@ - -/* Further customization as needed */ +/* * Further customization as needed */ */ .md-typeset__table { - min-width: 100%; -} - -.md-typeset table:not([class]) { - display: table; -} - -/* -:root { - --md-primary-fg-color: #EE0F0F; - --md-primary-fg-color--light: #ECB7B7; - --md-primary-fg-color--dark: #90030C; - } */ \ No newline at end of file + min-width: 100%; + } + + .md-typeset table:not([class]) { + display: table; + } + + /* Dark mode */ + [data-md-color-scheme="slate"] { + --md-default-bg-color: black; + } + + .header__ellipsis { + color: black; + } + + /* + :root { + --md-primary-fg-color: #EE0F0F; + --md-primary-fg-color--light: #ECB7B7; + --md-primary-fg-color--dark: #90030C; + } */ \ No newline at end of file diff --git a/ethchain_agent.py b/ethchain_agent.py new file mode 100644 index 00000000..cc06aeb5 --- /dev/null +++ b/ethchain_agent.py @@ -0,0 +1,308 @@ +import os +from swarms import Agent +from swarm_models import OpenAIChat +from web3 import Web3 +from typing import Dict, Optional, Any +from datetime import datetime +import asyncio +from loguru import logger +from dotenv import load_dotenv +import csv +import requests +import time + +BLOCKCHAIN_AGENT_PROMPT = """ +You are an expert blockchain and cryptocurrency analyst with deep knowledge of Ethereum markets and DeFi ecosystems. +You have access to real-time ETH price data and transaction information. + +For each transaction, analyze: + +1. MARKET CONTEXT +- Current ETH price and what this transaction means in USD terms +- How this movement compares to typical market volumes +- Whether this could impact ETH price + +2. BEHAVIORAL ANALYSIS +- Whether this appears to be institutional, whale, or protocol movement +- If this fits any known wallet patterns or behaviors +- Signs of smart contract interaction or DeFi activity + +3. RISK & IMPLICATIONS +- Potential market impact or price influence +- Signs of potential market manipulation or unusual activity +- Protocol or DeFi risks if applicable + +4. STRATEGIC INSIGHTS +- What traders should know about this movement +- Potential chain reactions or follow-up effects +- Market opportunities or risks created + +Write naturally but precisely. Focus on actionable insights and important patterns. +Your analysis helps traders and researchers understand significant market movements in real-time.""" + + +class EthereumAnalyzer: + def __init__(self, min_value_eth: float = 100.0): + load_dotenv() + + logger.add( + "eth_analysis.log", + rotation="500 MB", + retention="10 days", + level="INFO", + format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}", + ) + + self.w3 = Web3( + Web3.HTTPProvider( + "https://mainnet.infura.io/v3/9aa3d95b3bc440fa88ea12eaa4456161" + ) + ) + if not self.w3.is_connected(): + raise ConnectionError( + "Failed to connect to Ethereum network" + ) + + self.min_value_eth = min_value_eth + self.last_processed_block = self.w3.eth.block_number + self.eth_price = self.get_eth_price() + self.last_price_update = time.time() + + # Initialize AI agent + api_key = os.getenv("OPENAI_API_KEY") + if not api_key: + raise ValueError( + "OpenAI API key not found in environment variables" + ) + + model = OpenAIChat( + openai_api_key=api_key, + model_name="gpt-4", + temperature=0.1, + ) + + self.agent = Agent( + agent_name="Ethereum-Analysis-Agent", + system_prompt=BLOCKCHAIN_AGENT_PROMPT, + llm=model, + max_loops=1, + autosave=True, + dashboard=False, + verbose=True, + dynamic_temperature_enabled=True, + saved_state_path="eth_agent.json", + user_name="eth_analyzer", + retry_attempts=1, + context_length=200000, + output_type="string", + streaming_on=False, + ) + + self.csv_filename = "ethereum_analysis.csv" + self.initialize_csv() + + def get_eth_price(self) -> float: + """Get current ETH price from CoinGecko API.""" + try: + response = requests.get( + "https://api.coingecko.com/api/v3/simple/price", + params={"ids": "ethereum", "vs_currencies": "usd"}, + ) + return float(response.json()["ethereum"]["usd"]) + except Exception as e: + logger.error(f"Error fetching ETH price: {str(e)}") + return 0.0 + + def update_eth_price(self): + """Update ETH price if more than 5 minutes have passed.""" + if time.time() - self.last_price_update > 300: # 5 minutes + self.eth_price = self.get_eth_price() + self.last_price_update = time.time() + logger.info(f"Updated ETH price: ${self.eth_price:,.2f}") + + def initialize_csv(self): + """Initialize CSV file with headers.""" + headers = [ + "timestamp", + "transaction_hash", + "from_address", + "to_address", + "value_eth", + "value_usd", + "eth_price", + "gas_used", + "gas_price_gwei", + "block_number", + "analysis", + ] + + if not os.path.exists(self.csv_filename): + with open(self.csv_filename, "w", newline="") as f: + writer = csv.writer(f) + writer.writerow(headers) + + async def analyze_transaction( + self, tx_hash: str + ) -> Optional[Dict[str, Any]]: + """Analyze a single transaction.""" + try: + tx = self.w3.eth.get_transaction(tx_hash) + receipt = self.w3.eth.get_transaction_receipt(tx_hash) + + value_eth = float(self.w3.from_wei(tx.value, "ether")) + + if value_eth < self.min_value_eth: + return None + + block = self.w3.eth.get_block(tx.blockNumber) + + # Update ETH price if needed + self.update_eth_price() + + value_usd = value_eth * self.eth_price + + analysis = { + "timestamp": datetime.fromtimestamp( + block.timestamp + ).isoformat(), + "transaction_hash": tx_hash.hex(), + "from_address": tx["from"], + "to_address": tx.to if tx.to else "Contract Creation", + "value_eth": value_eth, + "value_usd": value_usd, + "eth_price": self.eth_price, + "gas_used": receipt.gasUsed, + "gas_price_gwei": float( + self.w3.from_wei(tx.gasPrice, "gwei") + ), + "block_number": tx.blockNumber, + } + + # Check if it's a contract + if tx.to: + code = self.w3.eth.get_code(tx.to) + analysis["is_contract"] = len(code) > 0 + + # Get contract events + if analysis["is_contract"]: + analysis["events"] = receipt.logs + + return analysis + + except Exception as e: + logger.error( + f"Error analyzing transaction {tx_hash}: {str(e)}" + ) + return None + + def prepare_analysis_prompt(self, tx_data: Dict[str, Any]) -> str: + """Prepare detailed analysis prompt including price context.""" + value_usd = tx_data["value_usd"] + eth_price = tx_data["eth_price"] + + prompt = f"""Analyze this Ethereum transaction in current market context: + +Transaction Details: +- Value: {tx_data['value_eth']:.2f} ETH (${value_usd:,.2f} at current price) +- Current ETH Price: ${eth_price:,.2f} +- From: {tx_data['from_address']} +- To: {tx_data['to_address']} +- Contract Interaction: {tx_data.get('is_contract', False)} +- Gas Used: {tx_data['gas_used']:,} units +- Gas Price: {tx_data['gas_price_gwei']:.2f} Gwei +- Block: {tx_data['block_number']} +- Timestamp: {tx_data['timestamp']} + +{f"Event Count: {len(tx_data['events'])} events" if tx_data.get('events') else "No contract events"} + +Consider the transaction's significance given the current ETH price of ${eth_price:,.2f} and total USD value of ${value_usd:,.2f}. +Analyze market impact, patterns, risks, and strategic implications.""" + + return prompt + + def save_to_csv(self, tx_data: Dict[str, Any], ai_analysis: str): + """Save transaction data and analysis to CSV.""" + row = [ + tx_data["timestamp"], + tx_data["transaction_hash"], + tx_data["from_address"], + tx_data["to_address"], + tx_data["value_eth"], + tx_data["value_usd"], + tx_data["eth_price"], + tx_data["gas_used"], + tx_data["gas_price_gwei"], + tx_data["block_number"], + ai_analysis.replace("\n", " "), + ] + + with open(self.csv_filename, "a", newline="") as f: + writer = csv.writer(f) + writer.writerow(row) + + async def monitor_transactions(self): + """Monitor and analyze transactions one at a time.""" + logger.info( + f"Starting transaction monitor (minimum value: {self.min_value_eth} ETH)" + ) + + while True: + try: + current_block = self.w3.eth.block_number + block = self.w3.eth.get_block( + current_block, full_transactions=True + ) + + for tx in block.transactions: + tx_analysis = await self.analyze_transaction( + tx.hash + ) + + if tx_analysis: + # Get AI analysis + analysis_prompt = ( + self.prepare_analysis_prompt(tx_analysis) + ) + ai_analysis = self.agent.run(analysis_prompt) + print(ai_analysis) + + # Save to CSV + self.save_to_csv(tx_analysis, ai_analysis) + + # Print analysis + print("\n" + "=" * 50) + print("New Transaction Analysis") + print( + f"Hash: {tx_analysis['transaction_hash']}" + ) + print( + f"Value: {tx_analysis['value_eth']:.2f} ETH (${tx_analysis['value_usd']:,.2f})" + ) + print( + f"Current ETH Price: ${self.eth_price:,.2f}" + ) + print("=" * 50) + print(ai_analysis) + print("=" * 50 + "\n") + + await asyncio.sleep(1) # Wait for next block + + except Exception as e: + logger.error(f"Error in monitoring loop: {str(e)}") + await asyncio.sleep(1) + + +async def main(): + """Entry point for the analysis system.""" + analyzer = EthereumAnalyzer(min_value_eth=100.0) + await analyzer.monitor_transactions() + + +if __name__ == "__main__": + print("Starting Ethereum Transaction Analyzer...") + print("Saving results to ethereum_analysis.csv") + print("Press Ctrl+C to stop") + try: + asyncio.run(main()) + except KeyboardInterrupt: + print("\nStopping analyzer...") diff --git a/ethereum_analysis.csv b/ethereum_analysis.csv new file mode 100644 index 00000000..703c4fbe --- /dev/null +++ b/ethereum_analysis.csv @@ -0,0 +1,4 @@ +timestamp,transaction_hash,from_address,to_address,value_eth,gas_used,gas_price_gwei,block_number,analysis +2024-11-27T13:50:35,ddbb665bc75fe848e7ce3d3ce1729243e92466c38ca407deccce8bf629987652,0x267be1C1D684F78cb4F6a176C4911b741E4Ffdc0,0xa40dFEE99E1C85DC97Fdc594b16A460717838703,3200.0,21000,19.968163737,21281878,"Transaction Analysis: This transaction represents a significant transfer of value in the Ethereum network with 3200 ETH (~$6.72 million USD at the current rate) moved from one address to another. It is essential to note that this transaction did not involve smart contract interaction, suggesting it could be a straightforward transfer of funds rather than part of a more complex operation. Looking at the broader market context, large transactions like this can potentially indicate major investment activities or redistribution of assets, which can have ripple effects in the market. If this transaction is part of a larger pattern of significant transfers, it could suggest substantial liquidity moving in the Ethereum ecosystem, possibly affecting the ETH prices. From a DeFi point of view, since there's no contract interaction, it's difficult to infer any direct implications. However, given the substantial value involved, it could be a step in preparation for involvement in DeFi protocols or a move from one DeFi platform to another by a large investor. The transaction fee paid, calculated from the given Gas Used and Gas Price, appears to be within reasonable range. This suggests that the transaction was not rushed and that the sender was willing to wait for this transaction to be confirmed, which might hint towards the non-urgent nature of the transaction. As for potential risk factors or security concerns, the transaction itself appears to be standard and doesn't raise any immediate red flags. However, the parties involved should always be cautious about the address security, maintaining privacy, and avoiding social engineering attacks. For traders and investors, this transaction can be interpreted as a potential bullish sign if it signifies increased liquidity and investment in the Ethereum market, especially if it's followed by similar large transfers. However, due to the anonymous nature of the transaction, it's critical to combine this with other market indicators and not to rely solely on transaction analysis for investment decisions." +2024-11-27T13:52:23,b98bcbf6d57a158b67a126d8f023766e03fb15c3e74becc1189d4244fda61a13,0xEae7380dD4CeF6fbD1144F49E4D1e6964258A4F4,0x28C6c06298d514Db089934071355E5743bf21d60,401.99463589018103,21000,14.978063737,21281887,"Ethereum-Analysis-Agent: Transaction Analysis: This transaction marks a significant transfer of 401.99 ETH, approximately $845,000 at the current rate. The transaction did not involve any smart contract interaction, suggesting a simple fund transfer rather than a complicated operation or interaction with a DeFi protocol. From a broader market perspective, this transaction is meaningful but not as potentially impactful as larger transactions. It can nonetheless be part of a larger pattern of asset movement within the Ethereum ecosystem. If this transaction is part of larger investment activities, it could suggest an increase in demand for ETH and potentially impact its price. Without contract interaction, it's challenging to assess direct implications for DeFi protocols. However, the substantial ETH transfer could suggest a step towards participation in DeFi activities, or a movement of funds between different DeFi platforms. The transaction fee appears reasonable, given the Gas Used and Gas Price. This implies that the transaction wasn't urgent, and the sender was willing to wait for the transaction to be confirmed, indicating a non-critical movement of funds. In terms of security and risk factors, there are no immediate concerns from the transaction itself. Nevertheless, as with any crypto transaction, the parties involved should ensure secure storage of their keys, maintain privacy, and be wary of potential phishing or social engineering attacks. For traders and investors, this transaction could be seen as a bullish sign if it forms part of a trend of increased investment activities in the Ethereum market. However, it's important to remember that transaction analysis should be combined with other market indicators due to the anonymous nature of blockchain transactions." +2024-11-27T13:59:47,a985b74fd3dfee09cbe4a2e6890509e583a3f0ce13f68c98e82996e0f66428be,0xf7858Da8a6617f7C6d0fF2bcAFDb6D2eeDF64840,0xA294cCa691e4C83B1fc0c8d63D9a3eeF0A196DE1,136.0668,494665.408728,3635.46,21000,18.866443971,21281923,"1. MARKET CONTEXT The transaction of 136.07 ETH, equivalent to $494,665.41, is a significant movement in the Ethereum market. However, compared to the daily trading volume of Ethereum, which often exceeds billions of dollars, this transaction is not large enough to significantly impact the ETH price on its own. 2. BEHAVIORAL ANALYSIS The transaction does not appear to be a protocol movement as there is no contract interaction involved. It could be a whale movement, given the substantial amount of ETH transferred. However, without additional information about the wallets involved, it's difficult to definitively determine the nature of the transaction. The gas price of 18.87 Gwei is relatively standard, suggesting that the transaction was not urgent or time-sensitive. 3. RISK & IMPLICATIONS The transaction does not show signs of market manipulation or unusual activity. The absence of contract interaction suggests that this transaction does not directly involve DeFi protocols, reducing the risk of smart contract vulnerabilities or DeFi-related risks. However, the large amount of ETH transferred could potentially influence market sentiment if it is part of a larger trend of similar transactions. 4. STRATEGIC INSIGHTS Traders should note this transaction as part of the broader market activity. While a single transaction of this size is unlikely to significantly impact the market, a series of similar transactions could indicate a larger trend. If this is part of a larger movement of ETH out of exchanges, it could suggest a decrease in selling pressure, which could be bullish for ETH. Conversely, if this is part of a larger movement into exchanges, it could indicate an increase in selling pressure, which could be bearish for ETH. Traders should monitor the market for further similar transactions to gain a better understanding of the potential market trends." diff --git a/multi_tool_usage_agent.py b/multi_tool_usage_agent.py index 16f07cae..44577528 100644 --- a/multi_tool_usage_agent.py +++ b/multi_tool_usage_agent.py @@ -1,5 +1,6 @@ +import os from typing import List, Dict, Any, Optional, Callable -from dataclasses import dataclass +from dataclasses import dataclass, field import json from datetime import datetime import inspect @@ -7,7 +8,6 @@ import typing from typing import Union from swarms import Agent from swarm_models import OpenAIChat -from dotenv import load_dotenv @dataclass @@ -19,17 +19,6 @@ class ToolDefinition: callable: Optional[Callable] = None -@dataclass -class ExecutionStep: - step_id: str - tool_name: str - parameters: Dict[str, Any] - purpose: str - depends_on: List[str] - completed: bool = False - result: Optional[Any] = None - - def extract_type_hints(func: Callable) -> Dict[str, Any]: """Extract parameter types from function type hints.""" return typing.get_type_hints(func) @@ -86,267 +75,343 @@ def extract_tool_info(func: Callable) -> ToolDefinition: ) -class ToolUsingAgent: +@dataclass +class FunctionSpec: + """Specification for a callable tool function.""" + + name: str + description: str + parameters: Dict[ + str, dict + ] # Contains type and description for each parameter + return_type: str + return_description: str + + +@dataclass +class ExecutionStep: + """Represents a single step in the execution plan.""" + + step_id: int + function_name: str + parameters: Dict[str, Any] + expected_output: str + completed: bool = False + result: Any = None + + +@dataclass +class ExecutionContext: + """Maintains state during execution.""" + + task: str + steps: List[ExecutionStep] = field(default_factory=list) + results: Dict[int, Any] = field(default_factory=dict) + current_step: int = 0 + history: List[Dict[str, Any]] = field(default_factory=list) + + +class ToolAgent: def __init__( self, - tools: List[Callable], + functions: List[Callable], openai_api_key: str, model_name: str = "gpt-4", temperature: float = 0.1, - max_loops: int = 10, ): - # Convert callable tools to ToolDefinitions - self.available_tools = { - tool.__name__: extract_tool_info(tool) for tool in tools - } + self.functions = {func.__name__: func for func in functions} + self.function_specs = self._analyze_functions(functions) - self.execution_plan: List[ExecutionStep] = [] - self.current_step_index = 0 - self.max_loops = max_loops - - # Initialize the OpenAI model self.model = OpenAIChat( openai_api_key=openai_api_key, model_name=model_name, temperature=temperature, ) - # Create system prompt with tool descriptions self.system_prompt = self._create_system_prompt() - self.agent = Agent( - agent_name="Tool-Using-Agent", + agent_name="Tool-Agent", system_prompt=self.system_prompt, llm=self.model, max_loops=1, - autosave=True, verbose=True, - saved_state_path="tool_agent_state.json", - context_length=200000, ) + def _analyze_functions( + self, functions: List[Callable] + ) -> Dict[str, FunctionSpec]: + """Analyze functions to create detailed specifications.""" + specs = {} + for func in functions: + hints = get_type_hints(func) + sig = inspect.signature(func) + doc = inspect.getdoc(func) or "" + + # Parse docstring for parameter descriptions + param_descriptions = {} + current_param = None + for line in doc.split("\n"): + if ":param" in line: + param_name = ( + line.split(":param")[1].split(":")[0].strip() + ) + desc = line.split(":", 2)[-1].strip() + param_descriptions[param_name] = desc + elif ":return:" in line: + return_desc = line.split(":return:")[1].strip() + + # Build parameter specifications + parameters = {} + for name, param in sig.parameters.items(): + param_type = hints.get(name, Any) + parameters[name] = { + "type": str(param_type), + "type_class": param_type, + "description": param_descriptions.get(name, ""), + "required": param.default == param.empty, + } + + specs[func.__name__] = FunctionSpec( + name=func.__name__, + description=doc.split("\n")[0], + parameters=parameters, + return_type=str(hints.get("return", Any)), + return_description=( + return_desc if "return_desc" in locals() else "" + ), + ) + + return specs + def _create_system_prompt(self) -> str: - """Create system prompt with available tools information.""" - tools_description = [] - for tool_name, tool in self.available_tools.items(): - tools_description.append( + """Create system prompt with detailed function specifications.""" + functions_desc = [] + for spec in self.function_specs.values(): + params_desc = [] + for name, details in spec.parameters.items(): + params_desc.append( + f" - {name}: {details['type']} - {details['description']}" + ) + + functions_desc.append( f""" - Tool: {tool_name} - Description: {tool.description} - Parameters: {json.dumps(tool.parameters, indent=2)} - Required Parameters: {tool.required_params} - """ +Function: {spec.name} +Description: {spec.description} +Parameters: +{chr(10).join(params_desc)} +Returns: {spec.return_type} - {spec.return_description} + """ ) - output = f"""You are an autonomous agent capable of executing complex tasks using available tools. - - Available Tools: - {chr(10).join(tools_description)} - - Follow these protocols: - 1. Create a detailed plan using available tools - 2. Execute each step in order - 3. Handle errors appropriately - 4. Maintain execution state - 5. Return results in structured format - - You must ALWAYS respond in the following JSON format: - {{ - "plan": {{ - "description": "Brief description of the overall plan", - "steps": [ - {{ - "step_number": 1, - "tool_name": "name_of_tool", - "description": "What this step accomplishes", - "parameters": {{ - "param1": "value1", - "param2": "value2" - }}, - "expected_output": "Description of expected output" - }} - ] - }}, - "reasoning": "Explanation of why this plan was chosen" - }} - - Before executing any tool: - 1. Validate all required parameters are present - 2. Verify parameter types match specifications - 3. Check parameter values are within valid ranges/formats - 4. Ensure logical dependencies between steps are met - - If any validation fails: - 1. Return error in JSON format with specific details - 2. Suggest corrections if possible - 3. Do not proceed with execution - - After each step execution: - 1. Verify output matches expected format - 2. Log results and any warnings/errors - 3. Update execution state - 4. Determine if plan adjustment needed - - Error Handling: - 1. Catch and classify all errors - 2. Provide detailed error messages - 3. Suggest recovery steps - 4. Maintain system stability - - The final output must be valid JSON that can be parsed. Always check your response can be parsed as JSON before returning. - """ - return output - - def execute_tool( - self, tool_name: str, parameters: Dict[str, Any] + return f"""You are an AI agent that creates and executes plans using available functions. + +Available Functions: +{chr(10).join(functions_desc)} + +You must respond in two formats depending on the phase: + +1. Planning Phase: +{{ + "phase": "planning", + "plan": {{ + "description": "Overall plan description", + "steps": [ + {{ + "step_id": 1, + "function": "function_name", + "parameters": {{ + "param1": "value1", + "param2": "value2" + }}, + "purpose": "Why this step is needed" + }} + ] + }} +}} + +2. Execution Phase: +{{ + "phase": "execution", + "analysis": "Analysis of current result", + "next_action": {{ + "type": "continue|request_input|complete", + "reason": "Why this action was chosen", + "needed_input": {{}} # If requesting input + }} +}} + +Always: +- Use exact function names +- Ensure parameter types match specifications +- Provide clear reasoning for each decision +""" + + def _execute_function( + self, spec: FunctionSpec, parameters: Dict[str, Any] ) -> Any: - """Execute a tool with given parameters.""" - tool = self.available_tools[tool_name] - if not tool.callable: - raise ValueError( - f"Tool {tool_name} has no associated callable" - ) - - # Convert parameters to appropriate types + """Execute a function with type checking.""" converted_params = {} - for param_name, param_value in parameters.items(): - param_info = tool.parameters[param_name] - param_type = eval( - param_info["type"] - ) # Note: Be careful with eval - converted_params[param_name] = param_type(param_value) - - return tool.callable(**converted_params) + for name, value in parameters.items(): + param_spec = spec.parameters[name] + try: + # Convert value to required type + param_type = param_spec["type_class"] + if param_type in (int, float, str, bool): + converted_params[name] = param_type(value) + else: + converted_params[name] = value + except (ValueError, TypeError) as e: + raise ValueError( + f"Parameter '{name}' conversion failed: {str(e)}" + ) + + return self.functions[spec.name](**converted_params) def run(self, task: str) -> Dict[str, Any]: - """Execute the complete task with proper logging and error handling.""" + """Execute task with planning and step-by-step execution.""" + context = ExecutionContext(task=task) execution_log = { "task": task, "start_time": datetime.utcnow().isoformat(), "steps": [], - "final_result": None + "final_result": None, } - + try: - # Create and execute plan - plan_response = self.agent.run(f"Create a plan for: {task}") - plan_data = json.loads(plan_response) - - # Extract steps from the correct path in JSON - steps = plan_data["plan"]["steps"] # Changed from plan_data["steps"] - - for step in steps: + # Planning phase + plan_prompt = f"Create a plan to: {task}" + plan_response = self.agent.run(plan_prompt) + plan_data = json.loads( + plan_response.replace("System:", "").strip() + ) + + # Convert plan to execution steps + for step in plan_data["plan"]["steps"]: + context.steps.append( + ExecutionStep( + step_id=step["step_id"], + function_name=step["function"], + parameters=step["parameters"], + expected_output=step["purpose"], + ) + ) + + # Execution phase + while context.current_step < len(context.steps): + step = context.steps[context.current_step] + print( + f"\nExecuting step {step.step_id}: {step.function_name}" + ) + try: - # Check if parameters need default values - for param_name, param_value in step["parameters"].items(): - if isinstance(param_value, str) and not param_value.replace(".", "").isdigit(): - # If parameter is a description rather than a value, set default - if "income" in param_name.lower(): - step["parameters"][param_name] = 75000.0 - elif "year" in param_name.lower(): - step["parameters"][param_name] = 2024 - elif "investment" in param_name.lower(): - step["parameters"][param_name] = 1000.0 - - # Execute the tool - result = self.execute_tool( - step["tool_name"], - step["parameters"] + # Execute function + spec = self.function_specs[step.function_name] + result = self._execute_function( + spec, step.parameters ) + context.results[step.step_id] = result + step.completed = True + step.result = result + + # Get agent's analysis + analysis_prompt = f""" + Step {step.step_id} completed: + Function: {step.function_name} + Result: {json.dumps(result)} + Remaining steps: {len(context.steps) - context.current_step - 1} - execution_log["steps"].append({ - "step_number": step["step_number"], - "tool": step["tool_name"], - "parameters": step["parameters"], - "success": True, - "result": result, - "description": step["description"] - }) - + Analyze the result and decide next action. + """ + + analysis_response = self.agent.run( + analysis_prompt + ) + analysis_data = json.loads( + analysis_response.replace( + "System:", "" + ).strip() + ) + + execution_log["steps"].append( + { + "step_id": step.step_id, + "function": step.function_name, + "parameters": step.parameters, + "result": result, + "analysis": analysis_data, + } + ) + + if ( + analysis_data["next_action"]["type"] + == "complete" + ): + if ( + context.current_step + < len(context.steps) - 1 + ): + continue + break + + context.current_step += 1 + except Exception as e: - execution_log["steps"].append({ - "step_number": step["step_number"], - "tool": step["tool_name"], - "parameters": step["parameters"], - "success": False, - "error": str(e), - "description": step["description"] - }) - print(f"Error executing step {step['step_number']}: {str(e)}") - # Continue with next step instead of raising - continue + print(f"Error in step {step.step_id}: {str(e)}") + execution_log["steps"].append( + { + "step_id": step.step_id, + "function": step.function_name, + "parameters": step.parameters, + "error": str(e), + } + ) + raise + + # Final analysis + final_prompt = f""" + Task completed. Results: + {json.dumps(context.results, indent=2)} - # Only mark as success if at least some steps succeeded - successful_steps = [s for s in execution_log["steps"] if s["success"]] - if successful_steps: - execution_log["final_result"] = { - "success": True, - "results": successful_steps, - "reasoning": plan_data.get("reasoning", "No reasoning provided") - } - else: - execution_log["final_result"] = { - "success": False, - "error": "No steps completed successfully", - "plan": plan_data - } - + Provide final analysis and recommendations. + """ + + final_analysis = self.agent.run(final_prompt) + execution_log["final_result"] = { + "success": True, + "results": context.results, + "analysis": json.loads( + final_analysis.replace("System:", "").strip() + ), + } + except Exception as e: execution_log["final_result"] = { "success": False, "error": str(e), - "plan": plan_data if 'plan_data' in locals() else None } - + execution_log["end_time"] = datetime.utcnow().isoformat() return execution_log -# Example usage -if __name__ == "__main__": - load_dotenv() - - # Example tool functions - def research_ira_requirements() -> Dict[str, Any]: - """Research and return ROTH IRA eligibility requirements.""" - return { - "age_requirement": "Must have earned income", - "income_limits": {"single": 144000, "married": 214000}, - } +def calculate_investment_return( + principal: float, rate: float, years: int +) -> float: + """Calculate investment return with compound interest. - def calculate_contribution_limit( - income: float, tax_year: int - ) -> Dict[str, float]: - """Calculate maximum ROTH IRA contribution based on income and tax year.""" - base_limit = 6000 if tax_year <= 2022 else 6500 - if income > 144000: - return {"limit": 0} - return {"limit": base_limit} - - def find_brokers(min_investment: float) -> List[Dict[str, Any]]: - """Find suitable brokers for ROTH IRA based on minimum investment.""" - return [ - {"name": "Broker A", "min_investment": min_investment}, - { - "name": "Broker B", - "min_investment": min_investment * 1.5, - }, - ] + :param principal: Initial investment amount in dollars + :param rate: Annual interest rate as decimal (e.g., 0.07 for 7%) + :param years: Number of years to invest + :return: Final investment value + """ + return principal * (1 + rate) ** years - # Initialize agent with tools - agent = ToolUsingAgent( - tools=[ - research_ira_requirements, - calculate_contribution_limit, - find_brokers, - ], - openai_api_key="", - ) - # Run a task - result = agent.run( - "How can I establish a ROTH IRA to buy stocks and get a tax break? " - "What are the criteria?" - ) +agent = ToolAgent( + functions=[calculate_investment_return], + openai_api_key=os.getenv("OPENAI_API_KEY"), +) - print(json.dumps(result, indent=2)) +result = agent.run( + "Calculate returns for $10000 invested at 7% for 10 years" +) diff --git a/pyproject.toml b/pyproject.toml index bb3d5043..51bb898f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "swarms" -version = "6.3.6" +version = "6.3.7" description = "Swarms - Pytorch" license = "MIT" authors = ["Kye Gomez "] diff --git a/swarms/__init__.py b/swarms/__init__.py index d73b8439..0c3b5ca5 100644 --- a/swarms/__init__.py +++ b/swarms/__init__.py @@ -5,8 +5,8 @@ from loguru import logger load_dotenv() -# More reliable string comparison -if os.getenv("SWARMS_VERBOSE_GLOBAL", "True").lower() == "false": +# Disable logging by default +if os.getenv("SWARMS_VERBOSE_GLOBAL", "False").lower() == "false": logger.disable("") # Import telemetry functions with error handling diff --git a/swarms/telemetry/auto_upgrade_swarms.py b/swarms/telemetry/auto_upgrade_swarms.py index 7203cd85..440f70ed 100644 --- a/swarms/telemetry/auto_upgrade_swarms.py +++ b/swarms/telemetry/auto_upgrade_swarms.py @@ -12,7 +12,7 @@ def auto_update(): try: # Check if auto-update is disabled auto_update_enabled = os.getenv( - "SWARMS_AUTOUPDATE_ON", "true" + "SWARMS_AUTOUPDATE_ON", "false" ).lower() if auto_update_enabled == "false": logger.info( diff --git a/swarms/tools/tool_utils.py b/swarms/tools/tool_utils.py index 972f98ec..b448d7a9 100644 --- a/swarms/tools/tool_utils.py +++ b/swarms/tools/tool_utils.py @@ -36,7 +36,7 @@ def scrape_tool_func_docs(fn: Callable) -> str: f" {inspect.getdoc(fn)}\nParameters:\n{parameters_str}" ) except Exception as error: - error_msg = ( + ( formatter.print_panel( f"Error scraping tool function docs {error} try" " optimizing your inputs with different" diff --git a/test.py b/test.py new file mode 100644 index 00000000..ce12ec1c --- /dev/null +++ b/test.py @@ -0,0 +1,292 @@ +import torch +import torch.nn as nn +import torch.distributed as dist +from dataclasses import dataclass +from typing import Optional, Tuple, Union +from loguru import logger +import math + + +@dataclass +class StarAttentionConfig: + """Configuration for StarAttention module. + + Attributes: + hidden_size: Dimension of the model's hidden states + num_attention_heads: Number of attention heads + num_hosts: Number of hosts in the distributed system + block_size: Size of each context block + anchor_size: Size of the anchor block + dropout_prob: Dropout probability (default: 0.1) + layer_norm_eps: Layer normalization epsilon (default: 1e-12) + """ + + hidden_size: int + num_attention_heads: int + num_hosts: int + block_size: int + anchor_size: int + dropout_prob: float = 0.1 + layer_norm_eps: float = 1e-12 + + +class StarAttention(nn.Module): + """ + Implementation of Star Attention mechanism for distributed inference. + + The module implements a two-phase attention mechanism: + 1. Local Context Encoding with Anchor Blocks + 2. Query Encoding and Output Generation with Global Attention + """ + + def __init__(self, config: StarAttentionConfig): + super().__init__() + + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + f"Hidden size {config.hidden_size} not divisible by number of attention " + f"heads {config.num_attention_heads}" + ) + + self.config = config + self.head_dim = ( + config.hidden_size // config.num_attention_heads + ) + + # Initialize components + self.query = nn.Linear(config.hidden_size, config.hidden_size) + self.key = nn.Linear(config.hidden_size, config.hidden_size) + self.value = nn.Linear(config.hidden_size, config.hidden_size) + + self.dropout = nn.Dropout(config.dropout_prob) + self.layer_norm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps + ) + + # KV cache for storing computed key/value pairs + self.kv_cache = {} + + logger.info( + f"Initialized StarAttention with config: {config}" + ) + + def _split_heads( + self, tensor: torch.Tensor, num_heads: int + ) -> torch.Tensor: + """Split the last dimension into (num_heads, head_dim).""" + batch_size, seq_len, _ = tensor.size() + tensor = tensor.view( + batch_size, seq_len, num_heads, self.head_dim + ) + # Transpose to (batch_size, num_heads, seq_len, head_dim) + return tensor.transpose(1, 2) + + def _merge_heads(self, tensor: torch.Tensor) -> torch.Tensor: + """Merge the head dimension back into hidden_size.""" + batch_size, _, seq_len, _ = tensor.size() + tensor = tensor.transpose(1, 2) + return tensor.reshape( + batch_size, seq_len, self.config.hidden_size + ) + + def _compute_attention_scores( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute attention scores and weighted values.""" + # Scale dot-product attention + scores = torch.matmul( + query, key.transpose(-2, -1) + ) / math.sqrt(self.head_dim) + + if mask is not None: + scores = scores.masked_fill(mask == 0, float("-inf")) + + # Online softmax computation + attention_probs = torch.nn.functional.softmax(scores, dim=-1) + attention_probs = self.dropout(attention_probs) + + context = torch.matmul(attention_probs, value) + + return context, attention_probs + + def phase1_local_context_encoding( + self, + input_ids: torch.Tensor, + host_id: int, + device: Union[str, torch.device] = "cuda", + ) -> None: + """ + Phase 1: Local Context Encoding with Anchor Blocks + + Args: + input_ids: Input tensor of shape (batch_size, seq_len) + host_id: ID of the current host + device: Device to run computations on + """ + logger.debug(f"Starting Phase 1 on host {host_id}") + + # Calculate block assignments + block_start = host_id * self.config.block_size + block_end = block_start + self.config.block_size + + # Get local block + local_block = input_ids[:, block_start:block_end].to(device) + + # Get anchor block (first block) + anchor_block = input_ids[:, : self.config.anchor_size].to( + device + ) + + # Compute KV pairs for local block + local_hidden = self.layer_norm(local_block) + local_key = self._split_heads( + self.key(local_hidden), self.config.num_attention_heads + ) + local_value = self._split_heads( + self.value(local_hidden), self.config.num_attention_heads + ) + + # Store in KV cache + self.kv_cache[host_id] = { + "key": local_key, + "value": local_value, + "anchor_key": ( + None + if host_id == 0 + else self._split_heads( + self.key(self.layer_norm(anchor_block)), + self.config.num_attention_heads, + ) + ), + } + + logger.debug( + f"Phase 1 complete on host {host_id}. KV cache shapes - " + f"key: {local_key.shape}, value: {local_value.shape}" + ) + + def phase2_query_encoding( + self, + query_input: torch.Tensor, + host_id: int, + is_query_host: bool, + device: Union[str, torch.device] = "cuda", + ) -> Optional[torch.Tensor]: + """ + Phase 2: Query Encoding and Output Generation + + Args: + query_input: Query tensor of shape (batch_size, seq_len, hidden_size) + host_id: ID of the current host + is_query_host: Whether this host is the query host + device: Device to run computations on + + Returns: + Output tensor if this is the query host, None otherwise + """ + logger.debug(f"Starting Phase 2 on host {host_id}") + + # Transform query + query_hidden = self.layer_norm(query_input) + query = self._split_heads( + self.query(query_hidden), self.config.num_attention_heads + ) + + # Compute local attention scores + local_context, local_probs = self._compute_attention_scores( + query, + self.kv_cache[host_id]["key"], + self.kv_cache[host_id]["value"], + ) + + if not is_query_host: + # Non-query hosts send their local attention statistics + dist.send(local_probs, dst=self.config.num_hosts - 1) + return None + + # Query host aggregates attention from all hosts + all_attention_probs = [local_probs] + for src_rank in range(self.config.num_hosts - 1): + probs = torch.empty_like(local_probs) + dist.recv(probs, src=src_rank) + all_attention_probs.append(probs) + + # Compute global attention + torch.mean(torch.stack(all_attention_probs), dim=0) + + # Final output computation + output = self._merge_heads(local_context) + output = self.dropout(output) + + logger.debug( + f"Phase 2 complete on host {host_id}. Output shape: {output.shape}" + ) + + return output + + def forward( + self, + input_ids: torch.Tensor, + query_input: torch.Tensor, + host_id: int, + is_query_host: bool, + device: Union[str, torch.device] = "cuda", + ) -> Optional[torch.Tensor]: + """ + Forward pass of the StarAttention module. + + Args: + input_ids: Input tensor of shape (batch_size, seq_len) + query_input: Query tensor of shape (batch_size, seq_len, hidden_size) + host_id: ID of the current host + is_query_host: Whether this host is the query host + device: Device to run computations on + + Returns: + Output tensor if this is the query host, None otherwise + """ + # Phase 1: Local Context Encoding + self.phase1_local_context_encoding(input_ids, host_id, device) + + # Phase 2: Query Encoding and Output Generation + return self.phase2_query_encoding( + query_input, host_id, is_query_host, device + ) + + +# Example forward pass +config = StarAttentionConfig( + hidden_size=768, + num_attention_heads=12, + num_hosts=3, + block_size=512, + anchor_size=128, +) + +# Initialize model +model = StarAttention(config) + +# Example input tensors +batch_size = 4 +seq_len = 512 +input_ids = torch.randint( + 0, 1000, (batch_size, seq_len) +) # Random input IDs +query_input = torch.randn( + batch_size, seq_len, config.hidden_size +) # Random query input + +# Example forward pass for query host (host_id = 2) +output = model( + input_ids=input_ids, + query_input=query_input, + host_id=2, + is_query_host=True, + device="cpu", +) + +print(output) From 8b9e424dd5ed6c950fbecbac5132384caca7bff2 Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 27 Nov 2024 16:56:07 -0800 Subject: [PATCH 5/6] [DOCS][CLEANUP] --- docs/assets/css/extra.css | 45 ++++++++++++++++++++------------------- 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/docs/assets/css/extra.css b/docs/assets/css/extra.css index d11f77f5..8a8a758f 100644 --- a/docs/assets/css/extra.css +++ b/docs/assets/css/extra.css @@ -1,26 +1,27 @@ /* * Further customization as needed */ */ - .md-typeset__table { min-width: 100%; - } - - .md-typeset table:not([class]) { - display: table; - } - - /* Dark mode */ - [data-md-color-scheme="slate"] { - --md-default-bg-color: black; - } - - .header__ellipsis { - color: black; - } - - /* - :root { - --md-primary-fg-color: #EE0F0F; - --md-primary-fg-color--light: #ECB7B7; - --md-primary-fg-color--dark: #90030C; - } */ \ No newline at end of file +} + +.md-typeset table:not([class]) { + display: table; +} + +/* Dark mode */ +[data-md-color-scheme="slate"] { + --md-default-bg-color: black; +} + +.header__ellipsis { + color: black; +} + +.md-copyright__highlight { + color: black; +} + + +.md-header.md-header--shadow { + color: black; +} \ No newline at end of file From 9762ee2b03d127c4b095545b66a2d387512ed40d Mon Sep 17 00:00:00 2001 From: evelynmitchell Date: Wed, 27 Nov 2024 20:38:47 -0700 Subject: [PATCH 6/6] remove yamllint --- .github/workflows/lint.yml | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index f2295d07..8787d772 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -1,14 +1,7 @@ --- name: Lint -on: [push, pull_request] # yamllint disable-line rule:truthy +on: [push, pull_request] jobs: - yaml-lint: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - uses: actions/setup-python@v5 - - run: pip install yamllint - - run: yamllint . flake8-lint: runs-on: ubuntu-latest steps: