swarms/swarms/structs/graph_workflow.py

264 lines
8.4 KiB

from enum import Enum
from typing import Any, Callable, Dict, List
import networkx as nx
from pydantic.v1 import BaseModel, Field, validator
from swarms.structs.agent import Agent # noqa: F401
from swarms.utils.loguru_logger import logger
class NodeType(str, Enum):
    """
    Kinds of node a workflow graph can contain.

    The enum mixes in ``str`` so each member compares equal to its plain
    string value (e.g. ``NodeType.AGENT == "agent"``).

    Members are intentionally left un-annotated: an annotation such as
    ``AGENT: Agent`` misdescribes the value (it is a string, not an Agent)
    and type checkers reject annotated enum members.
    """

    AGENT = "agent"
    TASK = "task"
class Node(BaseModel):
    """
    Represents a node in a graph workflow.

    Attributes:
        id (str): The unique identifier of the node.
        type (NodeType): The type of the node.
        callable (Callable, optional): The callable associated with the node. Required for task nodes.
        agent (Any, optional): The agent associated with the node.

    Raises:
        ValueError: If the node type is TASK and no callable is provided
            (raised inside the validator, so pydantic reports it as a
            validation error on the ``callable`` field).

    Examples:
        >>> node = Node(id="task1", type=NodeType.TASK, callable=sample_task)
        >>> node = Node(id="agent1", type=NodeType.AGENT, agent=agent1)
        >>> node = Node(id="agent2", type=NodeType.AGENT, agent=agent2)
    """

    id: str
    type: NodeType
    callable: Callable = None
    agent: Any = None

    @validator("callable", always=True)
    def validate_callable(cls, value, values):
        """
        Ensure that task nodes carry a callable.

        Args:
            value: The candidate value for the ``callable`` field.
            values: Fields declared before ``callable`` that have already
                passed validation.

        Returns:
            The unchanged ``value`` when it is acceptable.

        Raises:
            ValueError: If the node is a TASK node and ``value`` is None.
        """
        # `values` omits any earlier field that failed its own validation.
        # Looking `type` up with .get() keeps an invalid `type` from turning
        # into a raw KeyError here; the `type` error is reported on its own.
        if values.get("type") == NodeType.TASK and value is None:
            raise ValueError("Task nodes must have a callable.")
        return value
class Edge(BaseModel):
    """
    Represents a directed connection between two nodes of a workflow graph.

    Attributes:
        source (str): The ID of the node the edge starts from.
        target (str): The ID of the node the edge points to.
    """

    source: str
    target: str
class GraphWorkflow(BaseModel):
    """
    Represents a workflow graph.

    Attributes:
        nodes (Dict[str, Node]): A dictionary of nodes in the graph, where the key is the node ID and the value is the Node object.
        edges (List[Edge]): A list of edges in the graph, where each edge is represented by an Edge object.
        entry_points (List[str]): A list of node IDs that serve as entry points to the graph.
        end_points (List[str]): A list of node IDs that serve as end points of the graph.
        graph (nx.DiGraph): A directed graph object from the NetworkX library representing the workflow graph.
            Excluded from model export.
        max_loops (int): How many times ``run`` executes the whole graph. Defaults to 1.

    Note:
        ``graph`` is only updated by ``add_node`` and ``add_edge``. Nodes or
        edges supplied directly through the ``nodes`` / ``edges`` fields are
        not mirrored into ``graph``.
    """

    nodes: Dict[str, Node] = Field(default_factory=dict)
    edges: List[Edge] = Field(default_factory=list)
    entry_points: List[str] = Field(default_factory=list)
    end_points: List[str] = Field(default_factory=list)
    graph: nx.DiGraph = Field(
        default_factory=nx.DiGraph, exclude=True
    )
    max_loops: int = 1

    class Config:
        # Needed so pydantic accepts the nx.DiGraph field, which is not a
        # pydantic-native type.
        arbitrary_types_allowed = True

    def add_node(self, node: Node):
        """
        Adds a node to the workflow graph.

        Args:
            node (Node): The node object to be added.

        Raises:
            ValueError: If a node with the same ID already exists in the graph.
        """
        try:
            if node.id in self.nodes:
                raise ValueError(
                    f"Node with id {node.id} already exists."
                )
            self.nodes[node.id] = node
            # Mirror the node into the NetworkX graph, keeping its payload
            # as node attributes.
            self.graph.add_node(
                node.id,
                type=node.type,
                callable=node.callable,
                agent=node.agent,
            )
        except Exception as e:
            logger.error(f"Error in adding node to the workflow: {e}")
            raise

    def add_edge(self, edge: Edge):
        """
        Adds an edge to the workflow graph.

        Args:
            edge (Edge): The edge object to be added.

        Raises:
            ValueError: If either the source or target node of the edge does not exist in the graph.
        """
        if (
            edge.source not in self.nodes
            or edge.target not in self.nodes
        ):
            raise ValueError(
                "Both source and target nodes must exist before adding an edge."
            )
        self.edges.append(edge)
        self.graph.add_edge(edge.source, edge.target)

    def set_entry_points(self, entry_points: List[str]):
        """
        Sets the entry points of the workflow graph.

        Args:
            entry_points (List[str]): A list of node IDs to be set as entry points.

        Raises:
            ValueError: If any of the specified node IDs do not exist in the graph.
        """
        for node_id in entry_points:
            if node_id not in self.nodes:
                raise ValueError(
                    f"Node with id {node_id} does not exist."
                )
        # Store a copy so later mutation of the caller's list cannot
        # bypass the validation above.
        self.entry_points = list(entry_points)

    def set_end_points(self, end_points: List[str]):
        """
        Sets the end points of the workflow graph.

        Args:
            end_points (List[str]): A list of node IDs to be set as end points.

        Raises:
            ValueError: If any of the specified node IDs do not exist in the graph.
        """
        for node_id in end_points:
            if node_id not in self.nodes:
                raise ValueError(
                    f"Node with id {node_id} does not exist."
                )
        # Store a copy so later mutation of the caller's list cannot
        # bypass the validation above.
        self.end_points = list(end_points)

    def visualize(self) -> str:
        """
        Generates a string representation of the workflow graph in the Mermaid syntax.

        Returns:
            str: The Mermaid string representation of the workflow graph.
        """
        lines = ["graph TD"]
        lines.extend(f" {node_id}[{node_id}]" for node_id in self.nodes)
        lines.extend(
            f" {edge.source} --> {edge.target}" for edge in self.edges
        )
        # Every line, including the last one, is newline-terminated.
        return "\n".join(lines) + "\n"

    def run(
        self, task: str = None, *args, **kwargs
    ) -> Dict[str, Any]:
        """
        Function to run the workflow graph.

        Every node in ``graph`` is executed in topological order, once per
        loop, for ``max_loops`` loops. Task nodes are invoked with no
        arguments; agent nodes receive ``task`` plus any extra arguments.

        Args:
            task (str): The task to be executed by the workflow.
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            Dict[str, Any]: A dictionary mapping node IDs to their results
                from the last loop. Empty when ``max_loops`` is not positive.

        Raises:
            ValueError: If no entry points or end points are defined in the graph.
            Exception: Any error raised while ordering the graph (NetworkX
                is expected to reject cyclic graphs here) or while
                executing a node is logged and re-raised.
        """
        # Bound up front so a non-positive max_loops returns an empty dict
        # instead of failing on an unbound local at the return statement.
        execution_results: Dict[str, Any] = {}
        try:
            # These checks do not depend on the loop, so do them once --
            # and do them even if the loop body never runs.
            if not self.entry_points:
                raise ValueError(
                    "At least one entry point must be defined."
                )
            if not self.end_points:
                raise ValueError(
                    "At least one end point must be defined."
                )

            for _ in range(self.max_loops):
                # Perform a topological sort of the graph to ensure proper execution order.
                # Recomputed on every loop, as in the original flow, in
                # case the graph changed since the previous loop.
                sorted_nodes = list(nx.topological_sort(self.graph))

                # Each loop starts from a clean slate; only the final
                # loop's results are returned.
                execution_results = {}
                for node_id in sorted_nodes:
                    node = self.nodes[node_id]
                    if node.type == NodeType.TASK:
                        print(f"Executing task: {node_id}")
                        result = node.callable()
                    elif node.type == NodeType.AGENT:
                        print(f"Executing agent: {node_id}")
                        result = node.agent.run(task, *args, **kwargs)
                    else:
                        # Guard against recording a stale or unbound
                        # result for a node type this method cannot run.
                        raise ValueError(
                            f"Unsupported node type for node {node_id}: {node.type}"
                        )
                    execution_results[node_id] = result

            return execution_results
        except Exception as e:
            logger.error(f"Error in running the workflow: {e}")
            raise
# # Example usage
# if __name__ == "__main__":
# from swarms import Agent, OpenAIChat
# import os
# from dotenv import load_dotenv
# load_dotenv()
# api_key = os.environ.get("OPENAI_API_KEY")
# llm = OpenAIChat(
# temperature=0.5, openai_api_key=api_key, max_tokens=4000
# )
# agent1 = Agent(llm=llm, max_loops=1, autosave=True, dashboard=True)
# agent2 = Agent(llm=llm, max_loops=1, autosave=True, dashboard=True)
# def sample_task():
# print("Running sample task")
# return "Task completed"
# wf_graph = GraphWorkflow()
# wf_graph.add_node(Node(id="agent1", type=NodeType.AGENT, agent=agent1))
# wf_graph.add_node(Node(id="agent2", type=NodeType.AGENT, agent=agent2))
# wf_graph.add_node(
# Node(id="task1", type=NodeType.TASK, callable=sample_task)
# )
# wf_graph.add_edge(Edge(source="agent1", target="task1"))
# wf_graph.add_edge(Edge(source="agent2", target="task1"))
# wf_graph.set_entry_points(["agent1", "agent2"])
# wf_graph.set_end_points(["task1"])
# print(wf_graph.visualize())
# # Run the workflow
# results = wf_graph.run()
# print("Execution results:", results)