Update a_star_swarm.py

pull/594/head
kirill670 7 months ago committed by GitHub
parent edc293cb6f
commit f79bc9de99
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1,228 +1,132 @@
import heapq
import networkx as nx
import matplotlib.pyplot as plt
from typing import List, Dict, Tuple, Any, Optional, Callable
from swarms import Agent
from typing import List, Optional, Callable
from swarms.structs.base_swarm import BaseSwarm
from swarms.utils.loguru_logger import logger
class AStarSwarm(BaseSwarm):
class AStarSwarm:
def __init__(
self,
root_agent: Agent,
child_agents: Optional[List[Agent]] = None,
heuristic: Optional[Callable[[Agent], float]] = None,
*args,
**kwargs,
agents: List[Agent],
communication_costs: Optional[Dict[Tuple[str, str], float]] = None,
heuristic: Optional[Callable[[Agent, str], float]] = None,
):
"""
Initializes the A* Swarm with a root agent and optionally a list of child agents.
self.agents = agents
self.communication_costs = communication_costs or {} # Default to no cost
self.heuristic = heuristic or self.default_heuristic
self.graph = self._build_communication_graph()
def _build_communication_graph(self) -> nx.Graph:
graph = nx.Graph()
for agent in self.agents:
graph.add_node(agent.agent_name)
# Add edges with communication costs (if provided)
for (agent1_name, agent2_name), cost in self.communication_costs.items():
if agent1_name in graph.nodes and agent2_name in graph.nodes:
graph.add_edge(agent1_name, agent2_name, weight=cost)
Args:
root_agent (Agent): The root agent in the swarm.
child_agents (Optional[List[Agent]]): List of child agents.
return graph
def a_star_search(self, start_agent: Agent, task: str, goal_agent: Optional[Agent]=None) -> Optional[List[Agent]]:
"""
self.root_agent = root_agent
self.child_agents = child_agents
self.heuristic = heuristic
self.child_agents = (
child_agents if child_agents is not None else []
)
self.parent_map = {
agent: root_agent for agent in self.child_agents
}
def a_star_communicate(
self,
agent: Agent,
task: str,
) -> str:
Performs A* search to find a path to the goal agent or all agents.
"""
Distributes the task among agents using A* search-like communication.
Args:
agent (Agent): The agent to start the communication from.
task (str): The task to distribute and process.
heuristic (Callable[[Agent], float], optional): Function to prioritize which agent to communicate with first.
open_set = [(0, start_agent.agent_name)]
came_from = {}
g_score = {agent.agent_name: float('inf') for agent in self.agents}
g_score[start_agent.agent_name] = 0
f_score = {agent.agent_name: float('inf') for agent in self.agents}
f_score[start_agent.agent_name] = self.heuristic(start_agent, task)
Returns:
str: The result of the task after processing.
"""
# Perform the task at the current agent
result = agent.run(task)
while open_set:
_, current_agent_name = heapq.heappop(open_set)
# Base case: if no child agents, return the result
if agent not in self.parent_map.values():
return result
# Gather child agents
children = [
child
for child, parent in self.parent_map.items()
if parent == agent
]
if goal_agent and current_agent_name == goal_agent.agent_name: # Stop if specific goal agent is reached
return self._reconstruct_path(came_from, current_agent_name)
elif not goal_agent and len(came_from) == len(self.agents) -1: # Stop if all agents (except the starting one) are reached
return self._reconstruct_path(came_from, current_agent_name)
# Sort children based on the heuristic (if provided)
if self.heuristic:
children.sort(key=self.heuristic, reverse=True)
# Communicate with child agents
for child in children:
sub_result = self.a_star_communicate(
child, task, self.heuristic
)
result += f"\n{sub_result}"
for neighbor_name in self.graph.neighbors(current_agent_name):
weight = self.graph[current_agent_name][neighbor_name].get('weight', 1) # Default weight is 1
tentative_g_score = g_score[current_agent_name] + weight
return result
if tentative_g_score < g_score[neighbor_name]:
came_from[neighbor_name] = current_agent_name
g_score[neighbor_name] = tentative_g_score
def visualize(self):
"""
Visualizes the communication flow between agents in the swarm using networkx and matplotlib.
"""
graph = nx.DiGraph()
neighbor_agent = self.get_agent_by_name(neighbor_name)
if neighbor_agent:
f_score[neighbor_name] = tentative_g_score + self.heuristic(neighbor_agent, task)
if (f_score[neighbor_name], neighbor_name) not in open_set:
heapq.heappush(open_set, (f_score[neighbor_name], neighbor_name))
# Add edges between the root agent and child agents
for child in self.child_agents:
graph.add_edge(
self.root_agent.agent_name, child.agent_name
)
self._add_edges(graph, child)
return None # No path found
# Draw the graph
pos = nx.spring_layout(graph)
plt.figure(figsize=(10, 8))
nx.draw(
graph,
pos,
with_labels=True,
node_color="lightblue",
font_size=10,
node_size=3000,
font_weight="bold",
edge_color="gray",
)
plt.title("Communication Flow Between Agents")
plt.show()
def _reconstruct_path(self, came_from: Dict[str, str], current: str) -> List[Agent]:
path = [self.get_agent_by_name(current)]
while current in came_from:
current = came_from[current]
path.insert(0, self.get_agent_by_name(current)) # Insert at beginning
return path
def _add_edges(self, graph: nx.DiGraph, agent: Agent):
"""
Recursively adds edges to the graph for the given agent.
def get_agent_by_name(self, name:str) -> Optional[Agent]:
for agent in self.agents:
if agent.agent_name == name:
return agent
return None
Args:
graph (nx.DiGraph): The graph to add edges to.
agent (Agent): The current agent.
"""
children = [
child
for child, parent in self.parent_map.items()
if parent == agent
]
for child in children:
graph.add_edge(agent.agent_name, child.agent_name)
self._add_edges(graph, child)
def run(
self,
task: str,
) -> str:
"""
Start the task from the root agent using A* communication.
def default_heuristic(self, agent: Agent, task: str) -> float:
return 0 # Default heuristic (equivalent to Dijkstra's algorithm)
Args:
task (str): The task to execute.
heuristic (Callable[[Agent], float], optional): Heuristic for A* communication.
Returns:
str: The result of the task after processing.
"""
return self.a_star_communicate(
self.root_agent, task, self.heuristic
)
# # Heuristic example (can be customized)
# def example_heuristic(agent: Agent) -> float:
# """
# Example heuristic that prioritizes agents based on some custom logic.
# Args:
# agent (Agent): The agent to evaluate.
# Returns:
# float: The priority score for the agent.
# """
# # Example heuristic: prioritize based on the length of the agent's name (as a proxy for complexity)
# return len(agent.agent_name)
# # Set up the model as provided
# api_key = os.getenv("OPENAI_API_KEY")
# model = OpenAIChat(
# api_key=api_key, model_name="gpt-4o-mini", temperature=0.1
# )
# # Initialize root agent
# root_agent = Agent(
# agent_name="Financial-Analysis-Agent",
# system_prompt=FINANCIAL_AGENT_SYS_PROMPT,
# llm=model,
# max_loops=2,
# autosave=True,
# dashboard=False,
# verbose=True,
# streaming_on=True,
# dynamic_temperature_enabled=True,
# saved_state_path="finance_agent.json",
# user_name="swarms_corp",
# retry_attempts=3,
# context_length=200000,
# )
# # List of child agents
# child_agents = [
# Agent(
# agent_name="Child-Agent-1",
# system_prompt=FINANCIAL_AGENT_SYS_PROMPT,
# llm=model,
# max_loops=2,
# autosave=True,
# dashboard=False,
# verbose=True,
# streaming_on=True,
# dynamic_temperature_enabled=True,
# saved_state_path="finance_agent_child_1.json",
# user_name="swarms_corp",
# retry_attempts=3,
# context_length=200000,
# ),
# Agent(
# agent_name="Child-Agent-2",
# system_prompt=FINANCIAL_AGENT_SYS_PROMPT,
# llm=model,
# max_loops=2,
# autosave=True,
# dashboard=False,
# verbose=True,
# streaming_on=True,
# dynamic_temperature_enabled=True,
# saved_state_path="finance_agent_child_2.json",
# user_name="swarms_corp",
# retry_attempts=3,
# context_length=200000,
# ),
# ]
# # Create the A* swarm
# swarm = AStarSwarm(
# root_agent=root_agent,
# child_agents=child_agents,
# heauristic=example_heuristic,
# )
# # Run the task with the heuristic
# result = swarm.run(
# "What are the components of a startups stock incentive equity plan",
# )
# print(result)
# # Visualize the communication flow
# swarm.visualize()
def run(self, task: str, start_agent_name: str, goal_agent_name:Optional[str] = None) -> List[Dict[str, Any]]:
start_agent = self.get_agent_by_name(start_agent_name)
goal_agent = self.get_agent_by_name(goal_agent_name) if goal_agent_name else None
if not start_agent:
logger.error(f"Start agent '{start_agent_name}' not found.")
return []
if goal_agent_name and not goal_agent:
logger.error(f"Goal agent '{goal_agent_name}' not found.")
return []
agent_path = self.a_star_search(start_agent, task, goal_agent)
results = []
if agent_path:
current_input = task
for agent in agent_path:
logger.info(f"Agent {agent.agent_name} processing task: {current_input}")
try:
result = agent.run(current_input)
results.append({"agent": agent.agent_name, "task": current_input, "result": result})
current_input = str(result) # Pass output to the next agent
except Exception as e:
logger.error(f"Agent {agent.agent_name} encountered an error: {e}")
results.append({"agent": agent.agent_name, "task": current_input, "result": f"Error: {e}"})
break # Stop processing if an agent fails
else:
logger.warning("No path found between agents.")
return results
def visualize(self):
pos = nx.spring_layout(self.graph)
plt.figure(figsize=(10, 8))
nx.draw(self.graph, pos, with_labels=True, node_color="lightblue", node_size=3000, font_weight="bold")
edge_labels = nx.get_edge_attributes(self.graph, 'weight')
nx.draw_networkx_edge_labels(self.graph, pos, edge_labels=edge_labels) # Display edge weights
plt.title("Agent Communication Graph")
plt.show()

Loading…
Cancel
Save