add agents in batches to graph workflow

master
Kye Gomez 1 day ago
parent 871bc77713
commit a0075c3690

@@ -1,51 +1,44 @@
 #!/usr/bin/env python3
 """
 Basic Graph Workflow Example

 A minimal example showing how to use GraphWorkflow with backend selection.
 """

 from swarms.structs.graph_workflow import GraphWorkflow
 from swarms.structs.agent import Agent

-agent_one = Agent(agent_name="research_agent", model="gpt-4o-mini")
-agent_two = Agent(
-    agent_name="research_agent_two", model="gpt-4o-mini"
-)
-agent_three = Agent(
-    agent_name="research_agent_three", model="gpt-4o-mini"
-)
-
-
-def main():
-    """
-    Run a basic graph workflow example without print statements.
-    """
-    # Create agents
-
-    # Create workflow with backend selection
-    workflow = GraphWorkflow(
-        name="Basic Example",
-        verbose=True,
-    )
-
-    # Add agents to workflow
-    workflow.add_node(agent_one)
-    workflow.add_node(agent_two)
-    workflow.add_node(agent_three)
-
-    # Create simple chain using the actual agent names
-    workflow.add_edge("research_agent", "research_agent_two")
-    workflow.add_edge("research_agent_two", "research_agent_three")
-
-    # Compile the workflow
-    workflow.compile()
-
-    # Run the workflow
-    task = "Complete a simple task"
-    results = workflow.run(task)
-
-    return results
-
-
-if __name__ == "__main__":
-    main()
+agent_one = Agent(
+    agent_name="research_agent",
+    model_name="gpt-4o-mini",
+    name="Research Agent",
+    agent_description="Agent responsible for gathering and summarizing research information."
+)
+agent_two = Agent(
+    agent_name="research_agent_two",
+    model_name="gpt-4o-mini",
+    name="Analysis Agent",
+    agent_description="Agent that analyzes the research data provided and processes insights."
+)
+agent_three = Agent(
+    agent_name="research_agent_three",
+    model_name="gpt-4o-mini",
+    agent_description="Agent tasked with structuring analysis into a final report or output."
+)
+
+# Create workflow with backend selection
+workflow = GraphWorkflow(
+    name="Basic Example",
+    verbose=True,
+)
+
+workflow.add_nodes([agent_one, agent_two, agent_three])
+
+# Create simple chain using the actual agent names
+workflow.add_edge("research_agent", "research_agent_two")
+workflow.add_edge("research_agent_two", "research_agent_three")
+
+workflow.visualize()
+
+# Compile the workflow
+workflow.compile()
+
+# Run the workflow
+task = "Complete a simple task"
+results = workflow.run(task)
+
+print(results)

@@ -0,0 +1,48 @@
from swarms.structs.graph_workflow import GraphWorkflow
from swarms.structs.agent import Agent

agent_one = Agent(
    agent_name="research_agent",
    model_name="claude-haiku-4-5",
    top_p=None,
    temperature=None,
    agent_description="Agent responsible for gathering and summarizing research information."
)

agent_two = Agent(
    agent_name="research_agent_two",
    model_name="claude-haiku-4-5",
    top_p=None,
    temperature=None,
    agent_description="Agent that analyzes the research data provided and processes insights."
)

agent_three = Agent(
    agent_name="research_agent_three",
    model_name="claude-haiku-4-5",
    top_p=None,
    temperature=None,
    agent_description="Agent tasked with structuring analysis into a final report or output."
)

# Create workflow with backend selection
workflow = GraphWorkflow(
    name="Basic Example",
    verbose=True,
    backend="rustworkx",
)

agents = [agent_one, agent_two, agent_three]

workflow.add_nodes(agents, batch_size=3)

workflow.add_edge("research_agent", "research_agent_two")
workflow.add_edge("research_agent_two", "research_agent_three")

workflow.visualize()

# Compile the workflow
workflow.compile()

# Run the workflow
task = "Analyze the best mining companies in the US"
results = workflow.run(task)

print(results)
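The visualize() call above requires the Graphviz binaries to be installed. A minimal fallback sketch, assuming the text-based visualize_simple() method typed later in this commit behaves as its docstring describes (a string renderer rather than an image writer):

# Sketch: prefer Graphviz output, fall back to the text-only renderer.
try:
    workflow.visualize()  # needs the Graphviz toolchain on the machine
except Exception:
    summary = workflow.visualize_simple()  # text-based visualization, per the method further down
    print(summary)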

@@ -1,55 +0,0 @@
import re

from swarms.structs.maker import MAKER


# Define task-specific functions for a counting task
def format_counting_prompt(
    task, state, step_idx, previous_result
):
    """Format prompt for counting task."""
    if previous_result is None:
        return f"{task}\nThis is step 1. What is the first number? Reply with just the number."
    return f"{task}\nThe previous number was {previous_result}. What is the next number? Reply with just the number."


def parse_counting_response(response):
    """Parse the counting response to extract the number."""
    numbers = re.findall(r"\d+", response)
    if numbers:
        return int(numbers[0])
    return response.strip()


def validate_counting_response(response, max_tokens):
    """Validate counting response."""
    if len(response) > max_tokens * 4:
        return False
    return bool(re.search(r"\d+", response))


# Create MAKER instance
maker = MAKER(
    name="CountingExample",
    description="MAKER example: counting numbers",
    model_name="gpt-4o-mini",
    system_prompt="You are a helpful assistant. When asked to count, respond with just the number, nothing else.",
    format_prompt=format_counting_prompt,
    parse_response=parse_counting_response,
    validate_response=validate_counting_response,
    k=2,
    max_tokens=100,
    temperature=0.1,
    verbose=True,
)

# Run the solver with the task as the main input
results = maker.run(
    task="Count from 1 to 10, one number at a time",
    max_steps=5,
)
print(results)

# Show statistics
stats = maker.get_statistics()

@@ -1,10 +1,21 @@
 import asyncio
 import concurrent.futures
 import json
+import os
 import time
+import traceback
 import uuid
 from enum import Enum
-from typing import Any, Dict, Iterator, List, Optional, Set
+from typing import (
+    Any,
+    Dict,
+    Iterator,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    Union,
+)

 import networkx as nx
@@ -596,12 +607,12 @@ class Node:
        )

    @classmethod
-    def from_agent(cls, agent, **kwargs):
+    def from_agent(cls, agent: Agent, **kwargs: Any) -> "Node":
        """
        Create a Node from an Agent object.

        Args:
-            agent: The agent to create a node from.
+            agent (Agent): The agent to create a node from.
            **kwargs: Additional keyword arguments.

        Returns:
@@ -644,13 +655,18 @@ class Edge:
        self.metadata = metadata or {}

    @classmethod
-    def from_nodes(cls, source_node, target_node, **kwargs):
+    def from_nodes(
+        cls,
+        source_node: Union["Node", Agent, str],
+        target_node: Union["Node", Agent, str],
+        **kwargs: Any,
+    ) -> "Edge":
        """
        Create an Edge from node objects or ids.

        Args:
-            source_node: Source node object or ID.
-            target_node: Target node object or ID.
+            source_node (Union[Node, Agent, str]): Source node object or ID.
+            target_node (Union[Node, Agent, str]): Target node object or ID.
            **kwargs: Additional keyword arguments.

        Returns:
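A small usage sketch for the two constructors typed above, assuming Node and Edge are importable from swarms.structs.graph_workflow alongside GraphWorkflow; the agent and workflow objects are the ones from the examples earlier in this commit:

from swarms.structs.graph_workflow import Node, Edge  # assumed import path

node = Node.from_agent(agent_one)  # wrap an Agent in a graph Node (illustrative only)
edge = Edge.from_nodes("research_agent", "research_agent_two")  # ids, Nodes, or Agents

workflow.add_node(agent_one)  # add_node still takes the Agent itself
workflow.add_edge(edge)       # add_edge accepts an Edge object or a source/target pair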
@@ -844,7 +860,7 @@ class GraphWorkflow:
                "GraphWorkflow initialization completed successfully"
            )

-    def _invalidate_compilation(self):
+    def _invalidate_compilation(self) -> None:
        """
        Invalidate compiled optimizations when graph structure changes.
        Forces recompilation on next run to ensure cache coherency.
@@ -864,7 +880,7 @@
        if self.verbose:
            logger.debug("Cleared predecessors cache")

-    def compile(self):
+    def compile(self) -> None:
        """
        Pre-compute expensive operations for faster execution.
        Call this after building the graph structure.
@@ -932,7 +948,7 @@
            )
            raise e

-    def add_node(self, agent: Agent, **kwargs):
+    def add_node(self, agent: Agent, **kwargs: Any) -> None:
        """
        Adds an agent node to the workflow graph.
@@ -971,13 +987,46 @@
            )
            raise e

-    def add_edge(self, edge_or_source, target=None, **kwargs):
+    def add_nodes(self, agents: List[Agent], batch_size: int = 10, **kwargs: Any) -> None:
+        """
+        Add multiple agents to the workflow graph concurrently in batches.
+
+        Args:
+            agents (List[Agent]): List of agents to add.
+            batch_size (int): Number of agents to add concurrently in a batch. Defaults to 10.
+            **kwargs: Additional keyword arguments for each node addition.
+        """
+        try:
+            with concurrent.futures.ThreadPoolExecutor(
+                max_workers=self._max_workers
+            ) as executor:
+                # Process agents in batches
+                for i in range(0, len(agents), batch_size):
+                    batch = agents[i : i + batch_size]
+                    futures = [
+                        executor.submit(self.add_node, agent, **kwargs)
+                        for agent in batch
+                    ]
+                    # Ensure all nodes in batch are added before next batch
+                    for future in concurrent.futures.as_completed(futures):
+                        future.result()
+        except Exception as e:
+            logger.exception(
+                f"Error in GraphWorkflow.add_nodes for agents {agents}: {e} Traceback: {traceback.format_exc()}"
+            )
+            raise e
+
+    def add_edge(
+        self,
+        edge_or_source: Union[Edge, Node, Agent, str],
+        target: Optional[Union[Node, Agent, str]] = None,
+        **kwargs: Any,
+    ) -> None:
        """
        Add an edge by Edge object or by passing node objects/ids.

        Args:
-            edge_or_source: Either an Edge object or the source node/id.
-            target: Target node/id (required if edge_or_source is not an Edge).
+            edge_or_source (Union[Edge, Node, Agent, str]): Either an Edge object or the source node/id.
+            target (Optional[Union[Node, Agent, str]]): Target node/id (required if edge_or_source is not an Edge).
            **kwargs: Additional keyword arguments for the edge.
        """
        try:
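A usage sketch for the new batched API above; the worker names and the batch_size value are illustrative, not taken from the repository:

# Sketch: register a larger pool of agents in concurrent batches of five.
workers = [
    Agent(agent_name=f"worker_{i}", model_name="gpt-4o-mini")
    for i in range(20)
]
workflow.add_nodes(workers, batch_size=5)  # each batch is submitted to the thread pool before the next one starts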
@@ -1022,15 +1071,20 @@
            logger.exception(f"Error in GraphWorkflow.add_edge: {e}")
            raise e

-    def add_edges_from_source(self, source, targets, **kwargs):
+    def add_edges_from_source(
+        self,
+        source: Union[Node, Agent, str],
+        targets: List[Union[Node, Agent, str]],
+        **kwargs: Any,
+    ) -> List[Edge]:
        """
        Add multiple edges from a single source to multiple targets for parallel processing.

        This creates a "fan-out" pattern where the source agent's output is distributed
        to all target agents simultaneously.

        Args:
-            source: Source node/id that will send output to multiple targets.
-            targets: List of target node/ids that will receive the source output in parallel.
+            source (Union[Node, Agent, str]): Source node/id that will send output to multiple targets.
+            targets (List[Union[Node, Agent, str]]): List of target node/ids that will receive the source output in parallel.
            **kwargs: Additional keyword arguments for all edges.
        Returns:
@@ -1091,14 +1145,19 @@
            )
            raise e

-    def add_edges_to_target(self, sources, target, **kwargs):
+    def add_edges_to_target(
+        self,
+        sources: List[Union[Node, Agent, str]],
+        target: Union[Node, Agent, str],
+        **kwargs: Any,
+    ) -> List[Edge]:
        """
        Add multiple edges from multiple sources to a single target for convergence processing.

        This creates a "fan-in" pattern where multiple agents' outputs converge to a single target.

        Args:
-            sources: List of source node/ids that will send output to the target.
-            target: Target node/id that will receive all source outputs.
+            sources (List[Union[Node, Agent, str]]): List of source node/ids that will send output to the target.
+            target (Union[Node, Agent, str]): Target node/id that will receive all source outputs.
            **kwargs: Additional keyword arguments for all edges.
        Returns:
@@ -1159,14 +1218,19 @@
            )
            raise e

-    def add_parallel_chain(self, sources, targets, **kwargs):
+    def add_parallel_chain(
+        self,
+        sources: List[Union[Node, Agent, str]],
+        targets: List[Union[Node, Agent, str]],
+        **kwargs: Any,
+    ) -> List[Edge]:
        """
        Create a parallel processing chain where multiple sources connect to multiple targets.

        This creates a full mesh connection pattern for maximum parallel processing.

        Args:
-            sources: List of source node/ids.
-            targets: List of target node/ids.
+            sources (List[Union[Node, Agent, str]]): List of source node/ids.
+            targets (List[Union[Node, Agent, str]]): List of target node/ids.
            **kwargs: Additional keyword arguments for all edges.
        Returns:
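A sketch of the three connection helpers above, reusing the agent ids from the earlier example; each call returns the list of Edge objects it created, per the annotations:

# Fan-out: one source feeds two targets in parallel.
workflow.add_edges_from_source(
    "research_agent", ["research_agent_two", "research_agent_three"]
)

# Fan-in: two sources converge on one target.
workflow.add_edges_to_target(
    ["research_agent", "research_agent_two"], "research_agent_three"
)

# Full mesh between two groups of nodes.
workflow.add_parallel_chain(
    ["research_agent"], ["research_agent_two", "research_agent_three"]
)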
@@ -1230,7 +1294,7 @@
            )
            raise e

-    def set_entry_points(self, entry_points: List[str]):
+    def set_entry_points(self, entry_points: List[str]) -> None:
        """
        Set the entry points for the workflow.
@@ -1261,7 +1325,7 @@
            )
            raise e

-    def set_end_points(self, end_points: List[str]):
+    def set_end_points(self, end_points: List[str]) -> None:
        """
        Set the end points for the workflow.
@@ -1295,22 +1359,22 @@
    @classmethod
    def from_spec(
        cls,
-        agents,
-        edges,
-        entry_points=None,
-        end_points=None,
-        task=None,
-        **kwargs,
-    ):
+        agents: List[Union[Agent, Node]],
+        edges: List[Union[Edge, Tuple[Any, Any]]],
+        entry_points: Optional[List[str]] = None,
+        end_points: Optional[List[str]] = None,
+        task: Optional[str] = None,
+        **kwargs: Any,
+    ) -> "GraphWorkflow":
        """
        Construct a workflow from a list of agents and connections.

        Args:
-            agents: List of agents or Node objects.
-            edges: List of edges or edge tuples.
-            entry_points: List of entry point node IDs.
-            end_points: List of end point node IDs.
-            task: Task to be executed by the workflow.
+            agents (List[Union[Agent, Node]]): List of agents or Node objects.
+            edges (List[Union[Edge, Tuple[Any, Any]]]): List of edges or edge tuples.
+            entry_points (Optional[List[str]]): List of entry point node IDs.
+            end_points (Optional[List[str]]): List of end point node IDs.
+            task (Optional[str]): Task to be executed by the workflow.
            **kwargs: Additional keyword arguments.

        Returns:
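A minimal from_spec sketch, assuming plain (source_id, target_id) tuples are accepted for edges as the List[Union[Edge, Tuple[Any, Any]]] annotation implies:

spec_workflow = GraphWorkflow.from_spec(
    agents=[agent_one, agent_two, agent_three],
    edges=[
        ("research_agent", "research_agent_two"),
        ("research_agent_two", "research_agent_three"),
    ],
    entry_points=["research_agent"],
    end_points=["research_agent_three"],
    task="Complete a simple task",
)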
@@ -1425,7 +1489,7 @@
            logger.exception(f"Error in GraphWorkflow.from_spec: {e}")
            raise e

-    def auto_set_entry_points(self):
+    def auto_set_entry_points(self) -> None:
        """
        Automatically set entry points to nodes with no incoming edges.
        """
@@ -1455,7 +1519,7 @@
            )
            raise e

-    def auto_set_end_points(self):
+    def auto_set_end_points(self) -> None:
        """
        Automatically set end points to nodes with no outgoing edges.
        """
@@ -1483,7 +1547,7 @@
            )
            raise e

-    def _get_predecessors(self, node_id: str) -> tuple:
+    def _get_predecessors(self, node_id: str) -> Tuple[str, ...]:
        """
        Cached predecessor lookup for faster repeated access.
@@ -1491,7 +1555,7 @@
            node_id (str): The node ID to get predecessors for.

        Returns:
-            tuple: Tuple of predecessor node IDs.
+            Tuple[str, ...]: Tuple of predecessor node IDs.
        """
        # Use instance-level caching instead of @lru_cache to avoid hashing issues
        if not hasattr(self, "_predecessors_cache"):
@@ -1508,7 +1572,7 @@
        self,
        node_id: str,
        task: str,
-        prev_outputs: Dict[str, str],
+        prev_outputs: Dict[str, Any],
        layer_idx: int,
    ) -> str:
        """
@@ -1517,7 +1581,7 @@
        Args:
            node_id (str): The node ID to build a prompt for.
            task (str): The main task.
-            prev_outputs (Dict[str, str]): Previous outputs from predecessor nodes.
+            prev_outputs (Dict[str, Any]): Previous outputs from predecessor nodes.
            layer_idx (int): The current layer index.

        Returns:
@@ -1574,13 +1638,16 @@
            raise e

    async def arun(
-        self, task: str = None, *args, **kwargs
+        self,
+        task: Optional[str] = None,
+        *args: Any,
+        **kwargs: Any,
    ) -> Dict[str, Any]:
        """
        Async version of run for better performance with I/O bound operations.

        Args:
-            task (str, optional): Task to execute. Uses self.task if not provided.
+            task (Optional[str]): Task to execute. Uses self.task if not provided.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.
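A sketch of driving the async entry point; it assumes arun mirrors run apart from being awaitable, as the docstring above states:

import asyncio


async def run_async() -> None:
    # Await the workflow instead of blocking on run() for I/O-bound agent calls.
    results = await workflow.arun("Complete a simple task")
    print(results)


asyncio.run(run_async())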
@@ -1608,16 +1675,17 @@
    def run(
        self,
-        task: str = None,
+        task: Optional[str] = None,
        img: Optional[str] = None,
-        *args,
-        **kwargs,
+        *args: Any,
+        **kwargs: Any,
    ) -> Dict[str, Any]:
        """
        Run the workflow graph with optimized parallel agent execution.

        Args:
-            task (str, optional): Task to execute. Uses self.task if not provided.
+            task (Optional[str]): Task to execute. Uses self.task if not provided.
+            img (Optional[str]): Optional image path for multimodal tasks.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.
@@ -1846,16 +1914,15 @@
        view: bool = True,
        engine: str = "dot",
        show_summary: bool = False,
-    ):
+    ) -> str:
        """
        Visualize the workflow graph using Graphviz with enhanced parallel pattern detection.

        Args:
-            output_path (str, optional): Path to save the visualization file. If None, uses workflow name.
            format (str): Output format ('png', 'svg', 'pdf', 'dot'). Defaults to 'png'.
            view (bool): Whether to open the visualization after creation. Defaults to True.
            engine (str): Graphviz layout engine ('dot', 'neato', 'fdp', 'sfdp', 'twopi', 'circo'). Defaults to 'dot'.
-            show_summary (bool): Whether to print parallel processing summary. Defaults to True.
+            show_summary (bool): Whether to print parallel processing summary. Defaults to False.

        Returns:
            str: Path to the generated visualization file.
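A sketch of calling the visualizer with its defaults overridden; the parameter names and values are taken from the docstring above, and format preceding the other keywords is an assumption since the full signature is not shown in this hunk:

# Render to SVG without auto-opening a viewer, and print the parallel-pattern summary.
output_file = workflow.visualize(
    format="svg",       # assumed keyword, listed in the docstring above
    view=False,
    engine="dot",
    show_summary=True,
)
print(output_file)  # path to the generated file, per the return annotation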
@@ -2138,7 +2205,7 @@
            logger.exception(f"Error in GraphWorkflow.visualize: {e}")
            raise e

-    def visualize_simple(self):
+    def visualize_simple(self) -> str:
        """
        Simple text-based visualization for environments without Graphviz.
@@ -2226,12 +2293,13 @@
            )
            raise e

    def to_json(
        self,
-        fast=True,
-        include_conversation=False,
-        include_runtime_state=False,
-    ):
+        fast: bool = True,
+        include_conversation: bool = False,
+        include_runtime_state: bool = False,
+    ) -> str:
        """
        Serialize the workflow to JSON with comprehensive metadata and configuration.
@@ -2250,7 +2318,7 @@
        try:

-            def node_to_dict(node):
+            def node_to_dict(node: Node) -> Dict[str, Any]:
                node_data = {
                    "id": node.id,
                    "type": str(node.type),
@@ -2285,7 +2353,7 @@
                return node_data

-            def edge_to_dict(edge):
+            def edge_to_dict(edge: Edge) -> Dict[str, Any]:
                return {
                    "source": edge.source,
                    "target": edge.target,
@@ -2402,7 +2470,11 @@
            raise e

    @classmethod
-    def from_json(cls, json_str, restore_runtime_state=False):
+    def from_json(
+        cls,
+        json_str: str,
+        restore_runtime_state: bool = False,
+    ) -> "GraphWorkflow":
        """
        Deserialize a workflow from JSON with comprehensive parameter support and backward compatibility.
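A round-trip sketch for the serialization pair above; which state survives depends on the include_*/restore_runtime_state flags, as the docstrings describe:

payload = workflow.to_json(
    fast=True,
    include_conversation=False,
    include_runtime_state=True,
)
restored = GraphWorkflow.from_json(payload, restore_runtime_state=True)
restored.compile()  # conservative: rebuild cached structures before running again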
@@ -2660,7 +2732,6 @@
            FileExistsError: If file exists and overwrite is False
            Exception: If save operation fails
        """
-        import os

        # Handle file path validation
        if not filepath.endswith(".json"):
@@ -2723,7 +2794,6 @@
            FileNotFoundError: If file doesn't exist
            Exception: If load operation fails
        """
-        import os

        if not os.path.exists(filepath):
            raise FileNotFoundError(
@@ -2755,7 +2825,7 @@
            )
            raise e

-    def validate(self, auto_fix=False) -> Dict[str, Any]:
+    def validate(self, auto_fix: bool = False) -> Dict[str, Any]:
        """
        Validate the workflow structure, checking for potential issues such as isolated nodes,
        cyclic dependencies, etc.
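A sketch of the validation hook; only the Dict[str, Any] return type is visible in this hunk, so the report is simply printed rather than inspected field by field:

report = workflow.validate(auto_fix=True)  # attempt to repair issues such as isolated nodes
print(report)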
