From f79bc9de993aff92c4b7eee71d112b2af3813c82 Mon Sep 17 00:00:00 2001 From: kirill670 <51964569+kirill670@users.noreply.github.com> Date: Fri, 4 Oct 2024 10:24:49 +0300 Subject: [PATCH] Update a_star_swarm.py --- .../swarms/experimental/a_star_swarm.py | 312 ++++++------------ 1 file changed, 108 insertions(+), 204 deletions(-) diff --git a/examples/structs/swarms/experimental/a_star_swarm.py b/examples/structs/swarms/experimental/a_star_swarm.py index d0f7bbc5..36231fdf 100644 --- a/examples/structs/swarms/experimental/a_star_swarm.py +++ b/examples/structs/swarms/experimental/a_star_swarm.py @@ -1,228 +1,132 @@ +import heapq import networkx as nx import matplotlib.pyplot as plt +from typing import List, Dict, Tuple, Any, Optional, Callable from swarms import Agent -from typing import List, Optional, Callable -from swarms.structs.base_swarm import BaseSwarm +from swarms.utils.loguru_logger import logger - -class AStarSwarm(BaseSwarm): +class AStarSwarm: def __init__( self, - root_agent: Agent, - child_agents: Optional[List[Agent]] = None, - heuristic: Optional[Callable[[Agent], float]] = None, - *args, - **kwargs, + agents: List[Agent], + communication_costs: Optional[Dict[Tuple[str, str], float]] = None, + heuristic: Optional[Callable[[Agent, str], float]] = None, ): - """ - Initializes the A* Swarm with a root agent and optionally a list of child agents. + self.agents = agents + self.communication_costs = communication_costs or {} # Default to no cost + self.heuristic = heuristic or self.default_heuristic + self.graph = self._build_communication_graph() + + def _build_communication_graph(self) -> nx.Graph: + graph = nx.Graph() + for agent in self.agents: + graph.add_node(agent.agent_name) + + # Add edges with communication costs (if provided) + for (agent1_name, agent2_name), cost in self.communication_costs.items(): + if agent1_name in graph.nodes and agent2_name in graph.nodes: + graph.add_edge(agent1_name, agent2_name, weight=cost) - Args: - root_agent (Agent): The root agent in the swarm. - child_agents (Optional[List[Agent]]): List of child agents. + return graph + + + + def a_star_search(self, start_agent: Agent, task: str, goal_agent: Optional[Agent]=None) -> Optional[List[Agent]]: """ - self.root_agent = root_agent - self.child_agents = child_agents - self.heuristic = heuristic - self.child_agents = ( - child_agents if child_agents is not None else [] - ) - self.parent_map = { - agent: root_agent for agent in self.child_agents - } - - def a_star_communicate( - self, - agent: Agent, - task: str, - ) -> str: + Performs A* search to find a path to the goal agent or all agents. + """ - Distributes the task among agents using A* search-like communication. - Args: - agent (Agent): The agent to start the communication from. - task (str): The task to distribute and process. - heuristic (Callable[[Agent], float], optional): Function to prioritize which agent to communicate with first. + open_set = [(0, start_agent.agent_name)] + came_from = {} + g_score = {agent.agent_name: float('inf') for agent in self.agents} + g_score[start_agent.agent_name] = 0 + f_score = {agent.agent_name: float('inf') for agent in self.agents} + f_score[start_agent.agent_name] = self.heuristic(start_agent, task) - Returns: - str: The result of the task after processing. - """ - # Perform the task at the current agent - result = agent.run(task) + while open_set: + _, current_agent_name = heapq.heappop(open_set) - # Base case: if no child agents, return the result - if agent not in self.parent_map.values(): - return result - # Gather child agents - children = [ - child - for child, parent in self.parent_map.items() - if parent == agent - ] + if goal_agent and current_agent_name == goal_agent.agent_name: # Stop if specific goal agent is reached + return self._reconstruct_path(came_from, current_agent_name) + elif not goal_agent and len(came_from) == len(self.agents) -1: # Stop if all agents (except the starting one) are reached + return self._reconstruct_path(came_from, current_agent_name) - # Sort children based on the heuristic (if provided) - if self.heuristic: - children.sort(key=self.heuristic, reverse=True) - # Communicate with child agents - for child in children: - sub_result = self.a_star_communicate( - child, task, self.heuristic - ) - result += f"\n{sub_result}" + for neighbor_name in self.graph.neighbors(current_agent_name): + weight = self.graph[current_agent_name][neighbor_name].get('weight', 1) # Default weight is 1 + tentative_g_score = g_score[current_agent_name] + weight - return result + if tentative_g_score < g_score[neighbor_name]: + came_from[neighbor_name] = current_agent_name + g_score[neighbor_name] = tentative_g_score - def visualize(self): - """ - Visualizes the communication flow between agents in the swarm using networkx and matplotlib. - """ - graph = nx.DiGraph() + neighbor_agent = self.get_agent_by_name(neighbor_name) + if neighbor_agent: + f_score[neighbor_name] = tentative_g_score + self.heuristic(neighbor_agent, task) + if (f_score[neighbor_name], neighbor_name) not in open_set: + heapq.heappush(open_set, (f_score[neighbor_name], neighbor_name)) - # Add edges between the root agent and child agents - for child in self.child_agents: - graph.add_edge( - self.root_agent.agent_name, child.agent_name - ) - self._add_edges(graph, child) + return None # No path found - # Draw the graph - pos = nx.spring_layout(graph) - plt.figure(figsize=(10, 8)) - nx.draw( - graph, - pos, - with_labels=True, - node_color="lightblue", - font_size=10, - node_size=3000, - font_weight="bold", - edge_color="gray", - ) - plt.title("Communication Flow Between Agents") - plt.show() + def _reconstruct_path(self, came_from: Dict[str, str], current: str) -> List[Agent]: + path = [self.get_agent_by_name(current)] + while current in came_from: + current = came_from[current] + path.insert(0, self.get_agent_by_name(current)) # Insert at beginning + return path - def _add_edges(self, graph: nx.DiGraph, agent: Agent): - """ - Recursively adds edges to the graph for the given agent. + def get_agent_by_name(self, name:str) -> Optional[Agent]: + for agent in self.agents: + if agent.agent_name == name: + return agent + return None - Args: - graph (nx.DiGraph): The graph to add edges to. - agent (Agent): The current agent. - """ - children = [ - child - for child, parent in self.parent_map.items() - if parent == agent - ] - for child in children: - graph.add_edge(agent.agent_name, child.agent_name) - self._add_edges(graph, child) - - def run( - self, - task: str, - ) -> str: - """ - Start the task from the root agent using A* communication. + def default_heuristic(self, agent: Agent, task: str) -> float: + return 0 # Default heuristic (equivalent to Dijkstra's algorithm) - Args: - task (str): The task to execute. - heuristic (Callable[[Agent], float], optional): Heuristic for A* communication. - Returns: - str: The result of the task after processing. - """ - return self.a_star_communicate( - self.root_agent, task, self.heuristic - ) - - -# # Heuristic example (can be customized) -# def example_heuristic(agent: Agent) -> float: -# """ -# Example heuristic that prioritizes agents based on some custom logic. - -# Args: -# agent (Agent): The agent to evaluate. - -# Returns: -# float: The priority score for the agent. -# """ -# # Example heuristic: prioritize based on the length of the agent's name (as a proxy for complexity) -# return len(agent.agent_name) - - -# # Set up the model as provided -# api_key = os.getenv("OPENAI_API_KEY") -# model = OpenAIChat( -# api_key=api_key, model_name="gpt-4o-mini", temperature=0.1 -# ) - -# # Initialize root agent -# root_agent = Agent( -# agent_name="Financial-Analysis-Agent", -# system_prompt=FINANCIAL_AGENT_SYS_PROMPT, -# llm=model, -# max_loops=2, -# autosave=True, -# dashboard=False, -# verbose=True, -# streaming_on=True, -# dynamic_temperature_enabled=True, -# saved_state_path="finance_agent.json", -# user_name="swarms_corp", -# retry_attempts=3, -# context_length=200000, -# ) - -# # List of child agents -# child_agents = [ -# Agent( -# agent_name="Child-Agent-1", -# system_prompt=FINANCIAL_AGENT_SYS_PROMPT, -# llm=model, -# max_loops=2, -# autosave=True, -# dashboard=False, -# verbose=True, -# streaming_on=True, -# dynamic_temperature_enabled=True, -# saved_state_path="finance_agent_child_1.json", -# user_name="swarms_corp", -# retry_attempts=3, -# context_length=200000, -# ), -# Agent( -# agent_name="Child-Agent-2", -# system_prompt=FINANCIAL_AGENT_SYS_PROMPT, -# llm=model, -# max_loops=2, -# autosave=True, -# dashboard=False, -# verbose=True, -# streaming_on=True, -# dynamic_temperature_enabled=True, -# saved_state_path="finance_agent_child_2.json", -# user_name="swarms_corp", -# retry_attempts=3, -# context_length=200000, -# ), -# ] - -# # Create the A* swarm -# swarm = AStarSwarm( -# root_agent=root_agent, -# child_agents=child_agents, -# heauristic=example_heuristic, -# ) - -# # Run the task with the heuristic -# result = swarm.run( -# "What are the components of a startups stock incentive equity plan", -# ) -# print(result) - -# # Visualize the communication flow -# swarm.visualize() + + def run(self, task: str, start_agent_name: str, goal_agent_name:Optional[str] = None) -> List[Dict[str, Any]]: + start_agent = self.get_agent_by_name(start_agent_name) + goal_agent = self.get_agent_by_name(goal_agent_name) if goal_agent_name else None + + if not start_agent: + logger.error(f"Start agent '{start_agent_name}' not found.") + return [] + + if goal_agent_name and not goal_agent: + logger.error(f"Goal agent '{goal_agent_name}' not found.") + return [] + + agent_path = self.a_star_search(start_agent, task, goal_agent) + + results = [] + if agent_path: + current_input = task + for agent in agent_path: + logger.info(f"Agent {agent.agent_name} processing task: {current_input}") + try: + result = agent.run(current_input) + results.append({"agent": agent.agent_name, "task": current_input, "result": result}) + current_input = str(result) # Pass output to the next agent + except Exception as e: + logger.error(f"Agent {agent.agent_name} encountered an error: {e}") + results.append({"agent": agent.agent_name, "task": current_input, "result": f"Error: {e}"}) + break # Stop processing if an agent fails + else: + logger.warning("No path found between agents.") + + + return results + + + def visualize(self): + pos = nx.spring_layout(self.graph) + plt.figure(figsize=(10, 8)) + nx.draw(self.graph, pos, with_labels=True, node_color="lightblue", node_size=3000, font_weight="bold") + edge_labels = nx.get_edge_attributes(self.graph, 'weight') + nx.draw_networkx_edge_labels(self.graph, pos, edge_labels=edge_labels) # Display edge weights + plt.title("Agent Communication Graph") + plt.show()