Update a_star_swarm.py

pull/594/head
kirill670 7 months ago committed by GitHub
parent edc293cb6f
commit f79bc9de99
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1,228 +1,132 @@
import heapq
import networkx as nx import networkx as nx
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from typing import List, Dict, Tuple, Any, Optional, Callable
from swarms import Agent from swarms import Agent
from typing import List, Optional, Callable from swarms.utils.loguru_logger import logger
from swarms.structs.base_swarm import BaseSwarm
class AStarSwarm:
class AStarSwarm(BaseSwarm):
def __init__( def __init__(
self, self,
root_agent: Agent, agents: List[Agent],
child_agents: Optional[List[Agent]] = None, communication_costs: Optional[Dict[Tuple[str, str], float]] = None,
heuristic: Optional[Callable[[Agent], float]] = None, heuristic: Optional[Callable[[Agent, str], float]] = None,
*args,
**kwargs,
): ):
""" self.agents = agents
Initializes the A* Swarm with a root agent and optionally a list of child agents. self.communication_costs = communication_costs or {} # Default to no cost
self.heuristic = heuristic or self.default_heuristic
self.graph = self._build_communication_graph()
def _build_communication_graph(self) -> nx.Graph:
graph = nx.Graph()
for agent in self.agents:
graph.add_node(agent.agent_name)
# Add edges with communication costs (if provided)
for (agent1_name, agent2_name), cost in self.communication_costs.items():
if agent1_name in graph.nodes and agent2_name in graph.nodes:
graph.add_edge(agent1_name, agent2_name, weight=cost)
Args: return graph
root_agent (Agent): The root agent in the swarm.
child_agents (Optional[List[Agent]]): List of child agents.
def a_star_search(self, start_agent: Agent, task: str, goal_agent: Optional[Agent]=None) -> Optional[List[Agent]]:
""" """
self.root_agent = root_agent Performs A* search to find a path to the goal agent or all agents.
self.child_agents = child_agents
self.heuristic = heuristic
self.child_agents = (
child_agents if child_agents is not None else []
)
self.parent_map = {
agent: root_agent for agent in self.child_agents
}
def a_star_communicate(
self,
agent: Agent,
task: str,
) -> str:
""" """
Distributes the task among agents using A* search-like communication.
Args: open_set = [(0, start_agent.agent_name)]
agent (Agent): The agent to start the communication from. came_from = {}
task (str): The task to distribute and process. g_score = {agent.agent_name: float('inf') for agent in self.agents}
heuristic (Callable[[Agent], float], optional): Function to prioritize which agent to communicate with first. g_score[start_agent.agent_name] = 0
f_score = {agent.agent_name: float('inf') for agent in self.agents}
f_score[start_agent.agent_name] = self.heuristic(start_agent, task)
Returns: while open_set:
str: The result of the task after processing. _, current_agent_name = heapq.heappop(open_set)
"""
# Perform the task at the current agent
result = agent.run(task)
# Base case: if no child agents, return the result
if agent not in self.parent_map.values():
return result
# Gather child agents if goal_agent and current_agent_name == goal_agent.agent_name: # Stop if specific goal agent is reached
children = [ return self._reconstruct_path(came_from, current_agent_name)
child elif not goal_agent and len(came_from) == len(self.agents) -1: # Stop if all agents (except the starting one) are reached
for child, parent in self.parent_map.items() return self._reconstruct_path(came_from, current_agent_name)
if parent == agent
]
# Sort children based on the heuristic (if provided)
if self.heuristic:
children.sort(key=self.heuristic, reverse=True)
# Communicate with child agents for neighbor_name in self.graph.neighbors(current_agent_name):
for child in children: weight = self.graph[current_agent_name][neighbor_name].get('weight', 1) # Default weight is 1
sub_result = self.a_star_communicate( tentative_g_score = g_score[current_agent_name] + weight
child, task, self.heuristic
)
result += f"\n{sub_result}"
return result if tentative_g_score < g_score[neighbor_name]:
came_from[neighbor_name] = current_agent_name
g_score[neighbor_name] = tentative_g_score
def visualize(self): neighbor_agent = self.get_agent_by_name(neighbor_name)
""" if neighbor_agent:
Visualizes the communication flow between agents in the swarm using networkx and matplotlib. f_score[neighbor_name] = tentative_g_score + self.heuristic(neighbor_agent, task)
""" if (f_score[neighbor_name], neighbor_name) not in open_set:
graph = nx.DiGraph() heapq.heappush(open_set, (f_score[neighbor_name], neighbor_name))
# Add edges between the root agent and child agents return None # No path found
for child in self.child_agents:
graph.add_edge(
self.root_agent.agent_name, child.agent_name
)
self._add_edges(graph, child)
# Draw the graph def _reconstruct_path(self, came_from: Dict[str, str], current: str) -> List[Agent]:
pos = nx.spring_layout(graph) path = [self.get_agent_by_name(current)]
plt.figure(figsize=(10, 8)) while current in came_from:
nx.draw( current = came_from[current]
graph, path.insert(0, self.get_agent_by_name(current)) # Insert at beginning
pos, return path
with_labels=True,
node_color="lightblue",
font_size=10,
node_size=3000,
font_weight="bold",
edge_color="gray",
)
plt.title("Communication Flow Between Agents")
plt.show()
def _add_edges(self, graph: nx.DiGraph, agent: Agent): def get_agent_by_name(self, name:str) -> Optional[Agent]:
""" for agent in self.agents:
Recursively adds edges to the graph for the given agent. if agent.agent_name == name:
return agent
return None
Args: def default_heuristic(self, agent: Agent, task: str) -> float:
graph (nx.DiGraph): The graph to add edges to. return 0 # Default heuristic (equivalent to Dijkstra's algorithm)
agent (Agent): The current agent.
"""
children = [
child
for child, parent in self.parent_map.items()
if parent == agent
]
for child in children:
graph.add_edge(agent.agent_name, child.agent_name)
self._add_edges(graph, child)
def run(
self,
task: str,
) -> str:
"""
Start the task from the root agent using A* communication.
Args:
task (str): The task to execute.
heuristic (Callable[[Agent], float], optional): Heuristic for A* communication.
Returns:
str: The result of the task after processing. def run(self, task: str, start_agent_name: str, goal_agent_name:Optional[str] = None) -> List[Dict[str, Any]]:
""" start_agent = self.get_agent_by_name(start_agent_name)
return self.a_star_communicate( goal_agent = self.get_agent_by_name(goal_agent_name) if goal_agent_name else None
self.root_agent, task, self.heuristic
) if not start_agent:
logger.error(f"Start agent '{start_agent_name}' not found.")
return []
# # Heuristic example (can be customized)
# def example_heuristic(agent: Agent) -> float: if goal_agent_name and not goal_agent:
# """ logger.error(f"Goal agent '{goal_agent_name}' not found.")
# Example heuristic that prioritizes agents based on some custom logic. return []
# Args: agent_path = self.a_star_search(start_agent, task, goal_agent)
# agent (Agent): The agent to evaluate.
results = []
# Returns: if agent_path:
# float: The priority score for the agent. current_input = task
# """ for agent in agent_path:
# # Example heuristic: prioritize based on the length of the agent's name (as a proxy for complexity) logger.info(f"Agent {agent.agent_name} processing task: {current_input}")
# return len(agent.agent_name) try:
result = agent.run(current_input)
results.append({"agent": agent.agent_name, "task": current_input, "result": result})
# # Set up the model as provided current_input = str(result) # Pass output to the next agent
# api_key = os.getenv("OPENAI_API_KEY") except Exception as e:
# model = OpenAIChat( logger.error(f"Agent {agent.agent_name} encountered an error: {e}")
# api_key=api_key, model_name="gpt-4o-mini", temperature=0.1 results.append({"agent": agent.agent_name, "task": current_input, "result": f"Error: {e}"})
# ) break # Stop processing if an agent fails
else:
# # Initialize root agent logger.warning("No path found between agents.")
# root_agent = Agent(
# agent_name="Financial-Analysis-Agent",
# system_prompt=FINANCIAL_AGENT_SYS_PROMPT, return results
# llm=model,
# max_loops=2,
# autosave=True, def visualize(self):
# dashboard=False, pos = nx.spring_layout(self.graph)
# verbose=True, plt.figure(figsize=(10, 8))
# streaming_on=True, nx.draw(self.graph, pos, with_labels=True, node_color="lightblue", node_size=3000, font_weight="bold")
# dynamic_temperature_enabled=True, edge_labels = nx.get_edge_attributes(self.graph, 'weight')
# saved_state_path="finance_agent.json", nx.draw_networkx_edge_labels(self.graph, pos, edge_labels=edge_labels) # Display edge weights
# user_name="swarms_corp", plt.title("Agent Communication Graph")
# retry_attempts=3, plt.show()
# context_length=200000,
# )
# # List of child agents
# child_agents = [
# Agent(
# agent_name="Child-Agent-1",
# system_prompt=FINANCIAL_AGENT_SYS_PROMPT,
# llm=model,
# max_loops=2,
# autosave=True,
# dashboard=False,
# verbose=True,
# streaming_on=True,
# dynamic_temperature_enabled=True,
# saved_state_path="finance_agent_child_1.json",
# user_name="swarms_corp",
# retry_attempts=3,
# context_length=200000,
# ),
# Agent(
# agent_name="Child-Agent-2",
# system_prompt=FINANCIAL_AGENT_SYS_PROMPT,
# llm=model,
# max_loops=2,
# autosave=True,
# dashboard=False,
# verbose=True,
# streaming_on=True,
# dynamic_temperature_enabled=True,
# saved_state_path="finance_agent_child_2.json",
# user_name="swarms_corp",
# retry_attempts=3,
# context_length=200000,
# ),
# ]
# # Create the A* swarm
# swarm = AStarSwarm(
# root_agent=root_agent,
# child_agents=child_agents,
# heauristic=example_heuristic,
# )
# # Run the task with the heuristic
# result = swarm.run(
# "What are the components of a startups stock incentive equity plan",
# )
# print(result)
# # Visualize the communication flow
# swarm.visualize()

Loading…
Cancel
Save