# swarms/swarms/structs/airflow_swarm.py


import subprocess
import sys
import uuid
import threading
from concurrent.futures import (
    FIRST_COMPLETED,
    ThreadPoolExecutor,
    wait,
)
from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import Enum
from typing import Any, Dict, List, Optional, Set, Union
from graphviz import Digraph
from loguru import logger
# Airflow imports
try:
    from airflow import DAG
    from airflow.operators.python import PythonOperator
except ImportError:
    logger.error(
        "Airflow is not installed. Please install it using 'pip install apache-airflow'."
    )
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "apache-airflow"]
    )
    from airflow import DAG
    from airflow.operators.python import PythonOperator
# Import the Conversation structure from the swarms package.
from swarms.structs.conversation import Conversation


class NodeType(Enum):
    AGENT = "agent"
    CALLABLE = "callable"
    TOOL = "tool"


def generate_dag_id() -> str:
    """Generate a unique DAG id."""
    return uuid.uuid4().hex


@dataclass
class Node:
    """Represents a node in the DAG."""

    id: str
    type: NodeType
    component: Any  # Agent, Callable, or Tool
    query: Optional[str] = None  # Query string for agent nodes
    args: Optional[List[Any]] = None  # Positional args for callables/tools
    kwargs: Optional[Dict[str, Any]] = None  # Keyword args for callables/tools
    concurrent: bool = False  # Whether this node may run in parallel


# ======= Airflow DAG Swarm Class =========
class AirflowDAGSwarm:
    """
    A simplified and more intuitive DAG-based swarm for orchestrating agents, callables, and tools.
    Provides an easy-to-use API for building agent pipelines with support for concurrent execution.
    """

    def __init__(
        self,
        dag_id: Optional[str] = None,
        description: str = "A DAG Swarm for Airflow",
        name: str = "Airflow DAG Swarm",
        schedule_interval: Union[timedelta, str] = timedelta(days=1),
        start_date: datetime = datetime(2025, 2, 14),
        default_args: Optional[Dict[str, Any]] = None,
        initial_message: Optional[str] = None,
        max_workers: int = 5,
    ):
        """Initialize the AirflowDAGSwarm with improved configuration."""
        # Generate a fresh id per instance; calling generate_dag_id() in the
        # signature default would be evaluated only once, at class definition,
        # giving every swarm the same id.
        self.dag_id = dag_id or generate_dag_id()
        self.name = name
        self.description = description
        self.max_workers = max_workers
        self.default_args = default_args or {
            "owner": "airflow",
            "depends_on_past": False,
            "email_on_failure": False,
            "email_on_retry": False,
            "retries": 1,
            "retry_delay": timedelta(minutes=5),
        }
        self.dag = DAG(
            dag_id=self.dag_id,
            description=self.description,
            default_args=self.default_args,
            schedule_interval=schedule_interval,
            start_date=start_date,
            catchup=False,
        )
        # node_id -> Node
        self.nodes: Dict[str, Node] = {}
        # node_id -> set of child node_ids
        self.edges: Dict[str, Set[str]] = {}

        # Initialize conversation
        self.conversation = Conversation()
        if initial_message:
            self.conversation.add("user", initial_message)

        self.lock = threading.Lock()

    def add_user_message(self, message: str) -> None:
        """Add a user message to the conversation."""
        with self.lock:
            self.conversation.add("user", message)
        logger.info(f"Added user message: {message}")

    def get_conversation_history(self) -> str:
        """Get the conversation history as JSON."""
        return self.conversation.to_json()

    def add_node(
        self,
        node_id: str,
        component: Any,
        node_type: NodeType,
        query: Optional[str] = None,
        args: Optional[List[Any]] = None,
        kwargs: Optional[Dict[str, Any]] = None,
        concurrent: bool = False,
    ) -> str:
        """
        Add a node to the DAG with improved type checking and validation.

        Args:
            node_id: Unique identifier for the node
            component: Agent, callable, or tool to execute
            node_type: Type of the node (AGENT, CALLABLE, or TOOL)
            query: Query string for agents
            args: Positional arguments for callables/tools
            kwargs: Keyword arguments for callables/tools
            concurrent: Whether to execute this node concurrently

        Returns:
            node_id: The ID of the created node
        """
        if node_id in self.nodes:
            raise ValueError(f"Node with ID {node_id} already exists")

        if node_type == NodeType.AGENT and not hasattr(component, "run"):
            raise ValueError("Agent must have a 'run' method")
        elif node_type in (
            NodeType.CALLABLE,
            NodeType.TOOL,
        ) and not callable(component):
            raise ValueError(f"{node_type.value} must be callable")

        node = Node(
            id=node_id,
            type=node_type,
            component=component,
            query=query,
            args=args or [],
            kwargs=kwargs or {},
            concurrent=concurrent,
        )
        self.nodes[node_id] = node
        self.edges[node_id] = set()
        logger.info(f"Added {node_type.value} node: {node_id}")
        return node_id

    def add_edge(self, from_node: str, to_node: str) -> None:
        """
        Add a directed edge between two nodes in the DAG.

        Args:
            from_node: ID of the source node
            to_node: ID of the target node
        """
        if from_node not in self.nodes or to_node not in self.nodes:
            raise ValueError("Both nodes must exist in the DAG")
        self.edges[from_node].add(to_node)
        logger.info(f"Added edge: {from_node} -> {to_node}")

    def _execute_node(self, node: Node) -> str:
        """Execute a single node and return its output."""
        try:
            if node.type == NodeType.AGENT:
                query = (
                    node.query
                    or self.conversation.get_last_message_as_string()
                    or "Default query"
                )
                logger.info(
                    f"Executing agent node {node.id} with query: {query}"
                )
                return node.component.run(query)
            elif node.type in (NodeType.CALLABLE, NodeType.TOOL):
                logger.info(
                    f"Executing {node.type.value} node {node.id}"
                )
                return node.component(
                    *node.args,
                    conversation=self.conversation,
                    **node.kwargs,
                )
        except Exception as e:
            logger.exception(f"Error executing node {node.id}: {e}")
            return f"Error in node {node.id}: {str(e)}"

    def _get_root_nodes(self) -> List[str]:
        """Get nodes with no incoming edges."""
        all_nodes = set(self.nodes.keys())
        nodes_with_incoming = {
            node for edges in self.edges.values() for node in edges
        }
        return list(all_nodes - nodes_with_incoming)

    def run(self, **context: Any) -> str:
        """
        Execute the DAG with improved concurrency handling and error recovery.

        Returns:
            The final conversation state as a JSON string
        """
        logger.info("Starting swarm execution")

        # Track completed nodes and their results
        completed: Dict[str, str] = {}

        def can_execute_node(node_id: str) -> bool:
            """Check whether every parent of a node has completed."""
            return all(
                parent in completed
                for parent, children in self.edges.items()
                if node_id in children
            )

        with ThreadPoolExecutor(
            max_workers=self.max_workers
        ) as executor:
            # Submit concurrent root nodes to the executor
            futures_dict = {
                executor.submit(
                    self._execute_node, self.nodes[node_id]
                ): node_id
                for node_id in self._get_root_nodes()
                if self.nodes[node_id].concurrent
            }

            # Execute root nodes that shouldn't run concurrently
            for node_id in self._get_root_nodes():
                if not self.nodes[node_id].concurrent:
                    output = self._execute_node(self.nodes[node_id])
                    with self.lock:
                        completed[node_id] = output
                        self.conversation.add("assistant", output)

            # Process remaining nodes as futures finish
            while futures_dict:
                done, _ = wait(
                    futures_dict.keys(), return_when=FIRST_COMPLETED
                )
                for future in done:
                    node_id = futures_dict.pop(future)
                    try:
                        output = future.result()
                        with self.lock:
                            completed[node_id] = output
                            self.conversation.add("assistant", output)
                    except Exception as e:
                        logger.exception(
                            f"Error in node {node_id}: {e}"
                        )
                        with self.lock:
                            completed[node_id] = f"Error: {str(e)}"

                # Schedule nodes whose dependencies are now satisfied
                new_nodes = [
                    node_id
                    for node_id in self.nodes
                    if node_id not in completed
                    and can_execute_node(node_id)
                ]
                for node_id in new_nodes:
                    if self.nodes[node_id].concurrent:
                        future = executor.submit(
                            self._execute_node,
                            self.nodes[node_id],
                        )
                        futures_dict[future] = node_id
                    else:
                        output = self._execute_node(
                            self.nodes[node_id]
                        )
                        with self.lock:
                            completed[node_id] = output
                            self.conversation.add("assistant", output)

            # Once no futures remain, drain any nodes that became ready after
            # the last future finished (e.g. chains of non-concurrent nodes);
            # these run sequentially.
            progress = True
            while progress:
                progress = False
                for node_id in list(self.nodes):
                    if node_id not in completed and can_execute_node(
                        node_id
                    ):
                        output = self._execute_node(self.nodes[node_id])
                        with self.lock:
                            completed[node_id] = output
                            self.conversation.add("assistant", output)
                        progress = True

        return self.conversation.to_json()

    def visualize(
        self, filename: str = "dag_visualization", view: bool = True
    ) -> Digraph:
        """
        Generate a visualization of the DAG structure.

        Args:
            filename: Output filename for the visualization
            view: Whether to open the visualization

        Returns:
            Graphviz Digraph object
        """
        dot = Digraph(comment=f"DAG Visualization: {self.name}")

        # Add nodes (boxes for concurrent nodes, ellipses otherwise)
        for node_id, node in self.nodes.items():
            label = f"{node_id}\n({node.type.value})"
            shape = "box" if node.concurrent else "ellipse"
            dot.node(node_id, label, shape=shape)

        # Add edges
        for from_node, to_nodes in self.edges.items():
            for to_node in to_nodes:
                dot.edge(from_node, to_node)

        dot.render(filename, view=view, format="pdf")
        return dot
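
    # Rendering example: the call below writes "dag_visualization.pdf" next to
    # the working directory. It assumes the system Graphviz binaries are
    # installed, since the graphviz Python package only wraps them:
    #
    #     swarm.visualize(filename="dag_visualization", view=False)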

    def create_dag(self) -> DAG:
        """
        Create an Airflow DAG with a single PythonOperator that executes the entire swarm.
        In a production environment, you might break the components into multiple tasks.

        :return: The configured Airflow DAG.
        """
        logger.info("Creating Airflow DAG for swarm execution.")
        # Airflow injects its task context into run(**context) at execution time.
        PythonOperator(
            task_id="run",
            python_callable=self.run,
            dag=self.dag,
        )
        return self.dag
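

# To have the Airflow scheduler pick up the swarm's DAG, expose it as a
# module-level object in a file under the Airflow dags/ folder, which the
# scheduler scans for top-level DAG objects. A minimal sketch (the file name
# "dags/swarm_dag.py" and the dag_id are assumptions, not part of this module):
#
#     # dags/swarm_dag.py
#     from swarms.structs.airflow_swarm import AirflowDAGSwarm
#
#     swarm = AirflowDAGSwarm(dag_id="swarm_dag")
#     dag = swarm.create_dag()  # top-level DAG object, discoverable by Airflow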

# # ======= Example Usage =========
# if __name__ == "__main__":
#     from swarms import Agent
#     # NOTE: the prompt import path below is an assumption; adjust it to
#     # wherever FINANCIAL_AGENT_SYS_PROMPT lives in your swarms version.
#     from swarms.prompts.finance_agent_sys_prompt import (
#         FINANCIAL_AGENT_SYS_PROMPT,
#     )
#
#     # Configure logger to output to console.
#     logger.remove()
#     logger.add(lambda msg: print(msg, end=""), level="DEBUG")
#
#     # Create the DAG swarm with an initial message
#     swarm = AirflowDAGSwarm(
#         dag_id="swarm_conversation_dag",
#         initial_message="Hello, how can I help you with financial planning?",
#     )
#
#     # Create a real financial agent using the swarms package.
#     financial_agent = Agent(
#         agent_name="Financial-Analysis-Agent",
#         system_prompt=FINANCIAL_AGENT_SYS_PROMPT,
#         model_name="gpt-4o-mini",
#         max_loops=1,
#     )
#
#     # Add the real agent with a specific query.
#     swarm.add_node(
#         "financial_advisor",
#         financial_agent,
#         NodeType.AGENT,
#         query="How can I establish a ROTH IRA to buy stocks and get a tax break? What are the criteria?",
#         concurrent=True,
#     )
#
#     # Add a callable component.
#     def extra_processing(x: int, conversation: Conversation) -> str:
#         return f"Extra processing output with data {x} and conversation length {len(conversation.messages)}"
#
#     swarm.add_node(
#         "extra_processing",
#         extra_processing,
#         NodeType.CALLABLE,
#         args=[42],
#         concurrent=True,
#     )
#
#     # Add a tool component (for example, a tool to create a conversation graph).
#     def create_conversation_graph(conversation: Conversation) -> str:
#         # In this tool, we generate the graph and return a confirmation message.
#         swarm.visualize(
#             filename="swarm_conversation_tool_graph", view=False
#         )
#         return "Graph created."
#
#     swarm.add_node(
#         "conversation_graph",
#         create_conversation_graph,
#         NodeType.TOOL,
#         concurrent=False,
#     )
#
#     # Add edges to create the pipeline
#     swarm.add_edge("financial_advisor", "extra_processing")
#     swarm.add_edge("extra_processing", "conversation_graph")
#
#     # Execute the swarm
#     final_state = swarm.run()
#     logger.info(f"Final conversation: {final_state}")
#
#     # Visualize the DAG
#     print(
#         swarm.visualize(
#             filename="swarm_conversation_final", view=False
#         )
#     )
#
#     # Create the Airflow DAG.
#     dag = swarm.create_dag()
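
# # ======= Minimal Sketch (callables only) =========
# # A minimal sketch of the pipeline mechanics using plain callables, so it
# # runs without an LLM or API key. The instance name and node ids below are
# # illustrative assumptions, not part of the library.
#
# def fetch(conversation: Conversation) -> str:
#     return "fetched data"
#
# def summarize(conversation: Conversation) -> str:
#     # Runs only after "fetch" completes, per the edge below.
#     return "summary of: " + conversation.get_last_message_as_string()
#
# sketch = AirflowDAGSwarm(initial_message="start")
# sketch.add_node("fetch", fetch, NodeType.CALLABLE, concurrent=True)
# sketch.add_node("summarize", summarize, NodeType.CALLABLE)
# sketch.add_edge("fetch", "summarize")
# print(sketch.run())  # final conversation state as JSON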