parent
edc293cb6f
commit
f79bc9de99
@ -1,228 +1,132 @@
|
||||
import heapq
|
||||
import networkx as nx
|
||||
import matplotlib.pyplot as plt
|
||||
from typing import List, Dict, Tuple, Any, Optional, Callable
|
||||
from swarms import Agent
|
||||
from typing import List, Optional, Callable
|
||||
from swarms.structs.base_swarm import BaseSwarm
|
||||
from swarms.utils.loguru_logger import logger
|
||||
|
||||
|
||||
class AStarSwarm(BaseSwarm):
|
||||
class AStarSwarm:
|
||||
def __init__(
|
||||
self,
|
||||
root_agent: Agent,
|
||||
child_agents: Optional[List[Agent]] = None,
|
||||
heuristic: Optional[Callable[[Agent], float]] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
agents: List[Agent],
|
||||
communication_costs: Optional[Dict[Tuple[str, str], float]] = None,
|
||||
heuristic: Optional[Callable[[Agent, str], float]] = None,
|
||||
):
|
||||
"""
|
||||
Initializes the A* Swarm with a root agent and optionally a list of child agents.
|
||||
self.agents = agents
|
||||
self.communication_costs = communication_costs or {} # Default to no cost
|
||||
self.heuristic = heuristic or self.default_heuristic
|
||||
self.graph = self._build_communication_graph()
|
||||
|
||||
def _build_communication_graph(self) -> nx.Graph:
|
||||
graph = nx.Graph()
|
||||
for agent in self.agents:
|
||||
graph.add_node(agent.agent_name)
|
||||
|
||||
# Add edges with communication costs (if provided)
|
||||
for (agent1_name, agent2_name), cost in self.communication_costs.items():
|
||||
if agent1_name in graph.nodes and agent2_name in graph.nodes:
|
||||
graph.add_edge(agent1_name, agent2_name, weight=cost)
|
||||
|
||||
Args:
|
||||
root_agent (Agent): The root agent in the swarm.
|
||||
child_agents (Optional[List[Agent]]): List of child agents.
|
||||
return graph
|
||||
|
||||
|
||||
|
||||
def a_star_search(self, start_agent: Agent, task: str, goal_agent: Optional[Agent]=None) -> Optional[List[Agent]]:
|
||||
"""
|
||||
self.root_agent = root_agent
|
||||
self.child_agents = child_agents
|
||||
self.heuristic = heuristic
|
||||
self.child_agents = (
|
||||
child_agents if child_agents is not None else []
|
||||
)
|
||||
self.parent_map = {
|
||||
agent: root_agent for agent in self.child_agents
|
||||
}
|
||||
|
||||
def a_star_communicate(
|
||||
self,
|
||||
agent: Agent,
|
||||
task: str,
|
||||
) -> str:
|
||||
Performs A* search to find a path to the goal agent or all agents.
|
||||
|
||||
"""
|
||||
Distributes the task among agents using A* search-like communication.
|
||||
|
||||
Args:
|
||||
agent (Agent): The agent to start the communication from.
|
||||
task (str): The task to distribute and process.
|
||||
heuristic (Callable[[Agent], float], optional): Function to prioritize which agent to communicate with first.
|
||||
open_set = [(0, start_agent.agent_name)]
|
||||
came_from = {}
|
||||
g_score = {agent.agent_name: float('inf') for agent in self.agents}
|
||||
g_score[start_agent.agent_name] = 0
|
||||
f_score = {agent.agent_name: float('inf') for agent in self.agents}
|
||||
f_score[start_agent.agent_name] = self.heuristic(start_agent, task)
|
||||
|
||||
Returns:
|
||||
str: The result of the task after processing.
|
||||
"""
|
||||
# Perform the task at the current agent
|
||||
result = agent.run(task)
|
||||
while open_set:
|
||||
_, current_agent_name = heapq.heappop(open_set)
|
||||
|
||||
# Base case: if no child agents, return the result
|
||||
if agent not in self.parent_map.values():
|
||||
return result
|
||||
|
||||
# Gather child agents
|
||||
children = [
|
||||
child
|
||||
for child, parent in self.parent_map.items()
|
||||
if parent == agent
|
||||
]
|
||||
if goal_agent and current_agent_name == goal_agent.agent_name: # Stop if specific goal agent is reached
|
||||
return self._reconstruct_path(came_from, current_agent_name)
|
||||
elif not goal_agent and len(came_from) == len(self.agents) -1: # Stop if all agents (except the starting one) are reached
|
||||
return self._reconstruct_path(came_from, current_agent_name)
|
||||
|
||||
# Sort children based on the heuristic (if provided)
|
||||
if self.heuristic:
|
||||
children.sort(key=self.heuristic, reverse=True)
|
||||
|
||||
# Communicate with child agents
|
||||
for child in children:
|
||||
sub_result = self.a_star_communicate(
|
||||
child, task, self.heuristic
|
||||
)
|
||||
result += f"\n{sub_result}"
|
||||
for neighbor_name in self.graph.neighbors(current_agent_name):
|
||||
weight = self.graph[current_agent_name][neighbor_name].get('weight', 1) # Default weight is 1
|
||||
tentative_g_score = g_score[current_agent_name] + weight
|
||||
|
||||
return result
|
||||
if tentative_g_score < g_score[neighbor_name]:
|
||||
came_from[neighbor_name] = current_agent_name
|
||||
g_score[neighbor_name] = tentative_g_score
|
||||
|
||||
def visualize(self):
|
||||
"""
|
||||
Visualizes the communication flow between agents in the swarm using networkx and matplotlib.
|
||||
"""
|
||||
graph = nx.DiGraph()
|
||||
neighbor_agent = self.get_agent_by_name(neighbor_name)
|
||||
if neighbor_agent:
|
||||
f_score[neighbor_name] = tentative_g_score + self.heuristic(neighbor_agent, task)
|
||||
if (f_score[neighbor_name], neighbor_name) not in open_set:
|
||||
heapq.heappush(open_set, (f_score[neighbor_name], neighbor_name))
|
||||
|
||||
# Add edges between the root agent and child agents
|
||||
for child in self.child_agents:
|
||||
graph.add_edge(
|
||||
self.root_agent.agent_name, child.agent_name
|
||||
)
|
||||
self._add_edges(graph, child)
|
||||
return None # No path found
|
||||
|
||||
# Draw the graph
|
||||
pos = nx.spring_layout(graph)
|
||||
plt.figure(figsize=(10, 8))
|
||||
nx.draw(
|
||||
graph,
|
||||
pos,
|
||||
with_labels=True,
|
||||
node_color="lightblue",
|
||||
font_size=10,
|
||||
node_size=3000,
|
||||
font_weight="bold",
|
||||
edge_color="gray",
|
||||
)
|
||||
plt.title("Communication Flow Between Agents")
|
||||
plt.show()
|
||||
def _reconstruct_path(self, came_from: Dict[str, str], current: str) -> List[Agent]:
|
||||
path = [self.get_agent_by_name(current)]
|
||||
while current in came_from:
|
||||
current = came_from[current]
|
||||
path.insert(0, self.get_agent_by_name(current)) # Insert at beginning
|
||||
return path
|
||||
|
||||
def _add_edges(self, graph: nx.DiGraph, agent: Agent):
|
||||
"""
|
||||
Recursively adds edges to the graph for the given agent.
|
||||
def get_agent_by_name(self, name:str) -> Optional[Agent]:
|
||||
for agent in self.agents:
|
||||
if agent.agent_name == name:
|
||||
return agent
|
||||
return None
|
||||
|
||||
Args:
|
||||
graph (nx.DiGraph): The graph to add edges to.
|
||||
agent (Agent): The current agent.
|
||||
"""
|
||||
children = [
|
||||
child
|
||||
for child, parent in self.parent_map.items()
|
||||
if parent == agent
|
||||
]
|
||||
for child in children:
|
||||
graph.add_edge(agent.agent_name, child.agent_name)
|
||||
self._add_edges(graph, child)
|
||||
|
||||
def run(
|
||||
self,
|
||||
task: str,
|
||||
) -> str:
|
||||
"""
|
||||
Start the task from the root agent using A* communication.
|
||||
def default_heuristic(self, agent: Agent, task: str) -> float:
|
||||
return 0 # Default heuristic (equivalent to Dijkstra's algorithm)
|
||||
|
||||
Args:
|
||||
task (str): The task to execute.
|
||||
heuristic (Callable[[Agent], float], optional): Heuristic for A* communication.
|
||||
|
||||
Returns:
|
||||
str: The result of the task after processing.
|
||||
"""
|
||||
return self.a_star_communicate(
|
||||
self.root_agent, task, self.heuristic
|
||||
)
|
||||
|
||||
|
||||
# # Heuristic example (can be customized)
|
||||
# def example_heuristic(agent: Agent) -> float:
|
||||
# """
|
||||
# Example heuristic that prioritizes agents based on some custom logic.
|
||||
|
||||
# Args:
|
||||
# agent (Agent): The agent to evaluate.
|
||||
|
||||
# Returns:
|
||||
# float: The priority score for the agent.
|
||||
# """
|
||||
# # Example heuristic: prioritize based on the length of the agent's name (as a proxy for complexity)
|
||||
# return len(agent.agent_name)
|
||||
|
||||
|
||||
# # Set up the model as provided
|
||||
# api_key = os.getenv("OPENAI_API_KEY")
|
||||
# model = OpenAIChat(
|
||||
# api_key=api_key, model_name="gpt-4o-mini", temperature=0.1
|
||||
# )
|
||||
|
||||
# # Initialize root agent
|
||||
# root_agent = Agent(
|
||||
# agent_name="Financial-Analysis-Agent",
|
||||
# system_prompt=FINANCIAL_AGENT_SYS_PROMPT,
|
||||
# llm=model,
|
||||
# max_loops=2,
|
||||
# autosave=True,
|
||||
# dashboard=False,
|
||||
# verbose=True,
|
||||
# streaming_on=True,
|
||||
# dynamic_temperature_enabled=True,
|
||||
# saved_state_path="finance_agent.json",
|
||||
# user_name="swarms_corp",
|
||||
# retry_attempts=3,
|
||||
# context_length=200000,
|
||||
# )
|
||||
|
||||
# # List of child agents
|
||||
# child_agents = [
|
||||
# Agent(
|
||||
# agent_name="Child-Agent-1",
|
||||
# system_prompt=FINANCIAL_AGENT_SYS_PROMPT,
|
||||
# llm=model,
|
||||
# max_loops=2,
|
||||
# autosave=True,
|
||||
# dashboard=False,
|
||||
# verbose=True,
|
||||
# streaming_on=True,
|
||||
# dynamic_temperature_enabled=True,
|
||||
# saved_state_path="finance_agent_child_1.json",
|
||||
# user_name="swarms_corp",
|
||||
# retry_attempts=3,
|
||||
# context_length=200000,
|
||||
# ),
|
||||
# Agent(
|
||||
# agent_name="Child-Agent-2",
|
||||
# system_prompt=FINANCIAL_AGENT_SYS_PROMPT,
|
||||
# llm=model,
|
||||
# max_loops=2,
|
||||
# autosave=True,
|
||||
# dashboard=False,
|
||||
# verbose=True,
|
||||
# streaming_on=True,
|
||||
# dynamic_temperature_enabled=True,
|
||||
# saved_state_path="finance_agent_child_2.json",
|
||||
# user_name="swarms_corp",
|
||||
# retry_attempts=3,
|
||||
# context_length=200000,
|
||||
# ),
|
||||
# ]
|
||||
|
||||
# # Create the A* swarm
|
||||
# swarm = AStarSwarm(
|
||||
# root_agent=root_agent,
|
||||
# child_agents=child_agents,
|
||||
# heauristic=example_heuristic,
|
||||
# )
|
||||
|
||||
# # Run the task with the heuristic
|
||||
# result = swarm.run(
|
||||
# "What are the components of a startups stock incentive equity plan",
|
||||
# )
|
||||
# print(result)
|
||||
|
||||
# # Visualize the communication flow
|
||||
# swarm.visualize()
|
||||
|
||||
def run(self, task: str, start_agent_name: str, goal_agent_name:Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
start_agent = self.get_agent_by_name(start_agent_name)
|
||||
goal_agent = self.get_agent_by_name(goal_agent_name) if goal_agent_name else None
|
||||
|
||||
if not start_agent:
|
||||
logger.error(f"Start agent '{start_agent_name}' not found.")
|
||||
return []
|
||||
|
||||
if goal_agent_name and not goal_agent:
|
||||
logger.error(f"Goal agent '{goal_agent_name}' not found.")
|
||||
return []
|
||||
|
||||
agent_path = self.a_star_search(start_agent, task, goal_agent)
|
||||
|
||||
results = []
|
||||
if agent_path:
|
||||
current_input = task
|
||||
for agent in agent_path:
|
||||
logger.info(f"Agent {agent.agent_name} processing task: {current_input}")
|
||||
try:
|
||||
result = agent.run(current_input)
|
||||
results.append({"agent": agent.agent_name, "task": current_input, "result": result})
|
||||
current_input = str(result) # Pass output to the next agent
|
||||
except Exception as e:
|
||||
logger.error(f"Agent {agent.agent_name} encountered an error: {e}")
|
||||
results.append({"agent": agent.agent_name, "task": current_input, "result": f"Error: {e}"})
|
||||
break # Stop processing if an agent fails
|
||||
else:
|
||||
logger.warning("No path found between agents.")
|
||||
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def visualize(self):
|
||||
pos = nx.spring_layout(self.graph)
|
||||
plt.figure(figsize=(10, 8))
|
||||
nx.draw(self.graph, pos, with_labels=True, node_color="lightblue", node_size=3000, font_weight="bold")
|
||||
edge_labels = nx.get_edge_attributes(self.graph, 'weight')
|
||||
nx.draw_networkx_edge_labels(self.graph, pos, edge_labels=edge_labels) # Display edge weights
|
||||
plt.title("Agent Communication Graph")
|
||||
plt.show()
|
||||
|
Loading…
Reference in new issue