264 lines
8.4 KiB
264 lines
8.4 KiB
from enum import Enum
|
|
from typing import Any, Callable, Dict, List
|
|
|
|
import networkx as nx
|
|
from pydantic.v1 import BaseModel, Field, validator
|
|
|
|
from swarms.structs.agent import Agent # noqa: F401
|
|
from swarms.utils.loguru_logger import logger
|
|
|
|
|
|
class NodeType(str, Enum):
    """
    Kinds of nodes a GraphWorkflow can execute.

    Members subclass ``str``, so they compare equal to their plain string
    values (e.g. ``NodeType.TASK == "task"``).
    """

    # Enum members are plain values and should not be annotated; the former
    # ``AGENT: Agent = "agent"`` mislabelled a string value as an Agent.
    AGENT = "agent"
    TASK = "task"
|
|
|
|
|
|
class Node(BaseModel):
    """
    Represents a node in a graph workflow.

    Attributes:
        id (str): The unique identifier of the node.
        type (NodeType): The type of the node.
        callable (Callable, optional): The callable associated with the node.
            Required for task nodes; invoked with no arguments by
            GraphWorkflow.run.
        agent (Any, optional): The agent associated with the node. Its
            ``run`` method is invoked by GraphWorkflow.run for agent nodes.

    Raises:
        ValueError: If the node type is TASK and no callable is provided.

    Examples:
        >>> node = Node(id="task1", type=NodeType.TASK, callable=sample_task)
        >>> node = Node(id="agent1", type=NodeType.AGENT, agent=agent1)
        >>> node = Node(id="agent2", type=NodeType.AGENT, agent=agent2)
    """

    # NOTE: `id`, `type` and `callable` shadow builtins, but they are part of
    # the public model schema and cannot be renamed without breaking callers.
    id: str
    type: NodeType
    callable: Callable = None
    agent: Any = None

    @validator("callable", always=True)
    def validate_callable(cls, value, values):
        """Require a callable for TASK nodes.

        ``always=True`` makes this run even when ``callable`` is omitted, so
        the missing-callable case is actually caught.
        """
        # pydantic v1 only places successfully validated fields in `values`.
        # If `type` failed validation it is absent, and indexing would raise a
        # bare KeyError that masks the real validation error.
        if values.get("type") == NodeType.TASK and value is None:
            raise ValueError("Task nodes must have a callable.")
        return value
|
|
|
|
|
|
class Edge(BaseModel):
    """
    A directed connection between two nodes of a graph workflow.

    Attributes:
        source (str): ID of the node the edge starts from.
        target (str): ID of the node the edge points to.
    """

    # Both endpoints are required; `Field(...)` declares a field with no default.
    source: str = Field(...)
    target: str = Field(...)
|
|
|
|
|
|
class GraphWorkflow(BaseModel):
    """
    Represents a workflow graph.

    Attributes:
        nodes (Dict[str, Node]): A dictionary of nodes in the graph, where the key is the node ID and the value is the Node object.
        edges (List[Edge]): A list of edges in the graph, where each edge is represented by an Edge object.
        entry_points (List[str]): A list of node IDs that serve as entry points to the graph.
        end_points (List[str]): A list of node IDs that serve as end points of the graph.
        graph (nx.DiGraph): A directed graph object from the NetworkX library representing the workflow graph.
            Excluded from serialization.
        max_loops (int): How many times `run` executes the whole graph. Only the
            results of the final pass are returned.
    """

    nodes: Dict[str, Node] = Field(default_factory=dict)
    edges: List[Edge] = Field(default_factory=list)
    entry_points: List[str] = Field(default_factory=list)
    end_points: List[str] = Field(default_factory=list)
    graph: nx.DiGraph = Field(default_factory=nx.DiGraph, exclude=True)
    max_loops: int = 1

    class Config:
        # Needed so pydantic accepts the non-pydantic `nx.DiGraph` field.
        arbitrary_types_allowed = True

    def add_node(self, node: Node):
        """
        Adds a node to the workflow graph.

        The node is registered both in `self.nodes` and in `self.graph`, with
        its type, callable and agent stored as graph node attributes.

        Args:
            node (Node): The node object to be added.

        Raises:
            ValueError: If a node with the same ID already exists in the graph.
        """
        try:
            if node.id in self.nodes:
                raise ValueError(
                    f"Node with id {node.id} already exists."
                )
            self.nodes[node.id] = node
            self.graph.add_node(
                node.id,
                type=node.type,
                callable=node.callable,
                agent=node.agent,
            )
        except Exception as e:
            logger.error(f"Error in adding node to the workflow: {e}")
            # Bare `raise` re-raises the active exception unchanged.
            raise

    def add_edge(self, edge: Edge):
        """
        Adds an edge to the workflow graph.

        Args:
            edge (Edge): The edge object to be added.

        Raises:
            ValueError: If either the source or target node of the edge does not exist in the graph.
        """
        if (
            edge.source not in self.nodes
            or edge.target not in self.nodes
        ):
            raise ValueError(
                "Both source and target nodes must exist before adding an edge."
            )
        self.edges.append(edge)
        self.graph.add_edge(edge.source, edge.target)

    def _validate_node_ids(self, node_ids: List[str]) -> None:
        """Raise ValueError for the first ID in `node_ids` that is not a registered node."""
        for node_id in node_ids:
            if node_id not in self.nodes:
                raise ValueError(
                    f"Node with id {node_id} does not exist."
                )

    def set_entry_points(self, entry_points: List[str]):
        """
        Sets the entry points of the workflow graph.

        A copy of the given IDs is stored, so later mutation of the caller's
        list does not silently change the workflow.

        Args:
            entry_points (List[str]): A list of node IDs to be set as entry points.

        Raises:
            ValueError: If any of the specified node IDs do not exist in the graph.
        """
        node_ids = list(entry_points)
        self._validate_node_ids(node_ids)
        self.entry_points = node_ids

    def set_end_points(self, end_points: List[str]):
        """
        Sets the end points of the workflow graph.

        A copy of the given IDs is stored, so later mutation of the caller's
        list does not silently change the workflow.

        Args:
            end_points (List[str]): A list of node IDs to be set as end points.

        Raises:
            ValueError: If any of the specified node IDs do not exist in the graph.
        """
        node_ids = list(end_points)
        self._validate_node_ids(node_ids)
        self.end_points = node_ids

    def visualize(self) -> str:
        """
        Generates a string representation of the workflow graph in the Mermaid syntax.

        Nodes are listed in insertion order, followed by edges in the order
        they were added. Each line, including the last, ends with a newline.

        Returns:
            str: The Mermaid string representation of the workflow graph.
        """
        lines = ["graph TD"]
        lines.extend(f" {node_id}[{node_id}]" for node_id in self.nodes)
        lines.extend(
            f" {edge.source} --> {edge.target}" for edge in self.edges
        )
        return "\n".join(lines) + "\n"

    def run(
        self, task: str = None, *args, **kwargs
    ) -> Dict[str, Any]:
        """
        Function to run the workflow graph.

        Every node in the graph is executed once per loop, in topological
        order. Task nodes are called with no arguments; agent nodes have
        `agent.run(task, *args, **kwargs)` called. Note that entry/end points
        must be set but do not restrict which nodes are executed.

        Args:
            task (str): The task passed to every agent node.
            *args: Variable length argument list forwarded to agent nodes.
            **kwargs: Arbitrary keyword arguments forwarded to agent nodes.

        Returns:
            Dict[str, Any]: Mapping of node ID to that node's result from the
            final loop. Empty if `max_loops` is not positive.

        Raises:
            ValueError: If no entry points or end points are defined in the graph,
                or a node has an unsupported type.
            networkx.NetworkXUnfeasible: If the graph contains a cycle.
        """
        try:
            # The graph structure does not change between loops, so validation
            # and ordering are done once up front.
            if not self.entry_points:
                raise ValueError(
                    "At least one entry point must be defined."
                )
            if not self.end_points:
                raise ValueError(
                    "At least one end point must be defined."
                )

            # Topological order guarantees every node runs after its predecessors.
            sorted_nodes = list(nx.topological_sort(self.graph))

            # Bound before the loop so `max_loops <= 0` returns {} instead of
            # raising UnboundLocalError.
            execution_results: Dict[str, Any] = {}
            for _ in range(self.max_loops):
                # Each pass starts fresh; only the final pass is returned.
                execution_results = {}
                for node_id in sorted_nodes:
                    node = self.nodes[node_id]
                    if node.type == NodeType.TASK:
                        logger.info(f"Executing task: {node_id}")
                        result = node.callable()
                    elif node.type == NodeType.AGENT:
                        logger.info(f"Executing agent: {node_id}")
                        result = node.agent.run(task, *args, **kwargs)
                    else:
                        # Guards against a stale `result` from a previous node.
                        raise ValueError(
                            f"Unsupported node type: {node.type}"
                        )
                    execution_results[node_id] = result

            return execution_results
        except Exception as e:
            logger.error(f"Error in running the workflow: {e}")
            raise
|
|
|
|
|
|
# # Example usage
|
|
# if __name__ == "__main__":
|
|
# from swarms import Agent, OpenAIChat
|
|
|
|
# import os
|
|
# from dotenv import load_dotenv
|
|
|
|
# load_dotenv()
|
|
|
|
# api_key = os.environ.get("OPENAI_API_KEY")
|
|
|
|
# llm = OpenAIChat(
|
|
# temperature=0.5, openai_api_key=api_key, max_tokens=4000
|
|
# )
|
|
# agent1 = Agent(llm=llm, max_loops=1, autosave=True, dashboard=True)
|
|
# agent2 = Agent(llm=llm, max_loops=1, autosave=True, dashboard=True)
|
|
|
|
# def sample_task():
|
|
# print("Running sample task")
|
|
# return "Task completed"
|
|
|
|
# wf_graph = GraphWorkflow()
|
|
# wf_graph.add_node(Node(id="agent1", type=NodeType.AGENT, agent=agent1))
|
|
# wf_graph.add_node(Node(id="agent2", type=NodeType.AGENT, agent=agent2))
|
|
# wf_graph.add_node(
|
|
# Node(id="task1", type=NodeType.TASK, callable=sample_task)
|
|
# )
|
|
# wf_graph.add_edge(Edge(source="agent1", target="task1"))
|
|
# wf_graph.add_edge(Edge(source="agent2", target="task1"))
|
|
|
|
# wf_graph.set_entry_points(["agent1", "agent2"])
|
|
# wf_graph.set_end_points(["task1"])
|
|
|
|
# print(wf_graph.visualize())
|
|
|
|
# # Run the workflow
|
|
# results = wf_graph.run()
|
|
# print("Execution results:", results)
|