parent 8b9e424dd5
commit 7a66cbd705

@@ -0,0 +1,666 @@
"""
GraphSwarm: A production-grade framework for orchestrating swarms of agents
Author: Claude
License: MIT
Version: 2.0.0
"""

import asyncio
import json
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, Union

import chromadb
import networkx as nx
from loguru import logger
from pydantic import BaseModel, Field

from swarms import Agent


# Configure logging
logger.add(
    "graphswarm.log",
    rotation="500 MB",
    retention="10 days",
    level="INFO",
    format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}",
)


class AgentOutput(BaseModel):
    """Structured output from an agent."""

    agent_name: str
    timestamp: float = Field(default_factory=time.time)
    output: Any
    execution_time: float
    error: Optional[str] = None
    metadata: Dict = Field(default_factory=dict)


class SwarmOutput(BaseModel):
    """Structured output from the entire swarm."""

    timestamp: float = Field(default_factory=time.time)
    outputs: Dict[str, AgentOutput]
    execution_time: float
    success: bool
    error: Optional[str] = None
    metadata: Dict = Field(default_factory=dict)


class SwarmMemory:
    """Vector-based memory system for GraphSwarm using ChromaDB."""

    def __init__(self, collection_name: str = "swarm_memories"):
        """Initialize SwarmMemory with ChromaDB."""
        self.client = chromadb.Client()

        # Get or create collection
        self.collection = self.client.get_or_create_collection(
            name=collection_name,
            metadata={"description": "GraphSwarm execution memories"},
        )

    def store_execution(self, task: str, result: SwarmOutput):
        """Store execution results in vector memory."""
        try:
            # Create metadata
            metadata = {
                "timestamp": datetime.now().isoformat(),
                "success": result.success,
                "execution_time": result.execution_time,
                "agent_sequence": json.dumps(
                    [name for name in result.outputs.keys()]
                ),
                "error": result.error if result.error else "",
            }

            # Create document from outputs
            document = {
                "task": task,
                "outputs": json.dumps(
                    {
                        name: {
                            "output": str(output.output),
                            "execution_time": output.execution_time,
                            "error": output.error,
                        }
                        for name, output in result.outputs.items()
                    }
                ),
            }

            # Store in ChromaDB
            self.collection.add(
                documents=[json.dumps(document)],
                metadatas=[metadata],
                ids=[f"exec_{datetime.now().timestamp()}"],
            )

            logger.info(f"Stored execution in memory: {task}")

        except Exception as e:
            logger.error(
                f"Failed to store execution in memory: {str(e)}"
            )

    def get_similar_executions(self, task: str, limit: int = 5):
        """Retrieve similar past executions."""
        try:
            # Query ChromaDB for similar executions
            results = self.collection.query(
                query_texts=[task],
                n_results=limit,
                include=["documents", "metadatas"],
            )

            if not results["documents"]:
                return []

            # Process results
            executions = []
            for doc, metadata in zip(
                results["documents"][0], results["metadatas"][0]
            ):
                doc_dict = json.loads(doc)
                executions.append(
                    {
                        "task": doc_dict["task"],
                        "outputs": json.loads(doc_dict["outputs"]),
                        "success": metadata["success"],
                        "execution_time": metadata["execution_time"],
                        "agent_sequence": json.loads(
                            metadata["agent_sequence"]
                        ),
                        "timestamp": metadata["timestamp"],
                    }
                )

            return executions

        except Exception as e:
            logger.error(
                f"Failed to retrieve similar executions: {str(e)}"
            )
            return []

    def get_optimal_sequence(self, task: str) -> Optional[List[str]]:
        """Get the most successful agent sequence for similar tasks."""
        similar_executions = self.get_similar_executions(task)

        if not similar_executions:
            return None

        # Keep only the successful executions (most similar first)
        successful_execs = [
            ex for ex in similar_executions if ex["success"]
        ]

        if not successful_execs:
            return None

        # Return the sequence from the most similar successful execution
        return successful_execs[0]["agent_sequence"]

    def clear_memory(self):
        """Clear all memories."""
        self.client.delete_collection(self.collection.name)
        self.collection = self.client.get_or_create_collection(
            name=self.collection.name
        )


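# A minimal round-trip sketch of SwarmMemory on its own (hypothetical
# values; assumes an in-process ChromaDB client with default settings):
#
#   memory = SwarmMemory(collection_name="demo_memories")
#   out = AgentOutput(agent_name="Demo-Agent", output="42", execution_time=0.1)
#   result = SwarmOutput(
#       outputs={"Demo-Agent": out}, execution_time=0.1, success=True
#   )
#   memory.store_execution("Answer the ultimate question", result)
#   memory.get_similar_executions("Answer the ultimate question", limit=1)

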
class GraphSwarm:
    """
    Enhanced framework for creating and managing swarms of collaborative agents.
    """

    def __init__(
        self,
        agents: Union[
            List[Agent], List[Tuple[Agent, List[str]]], None
        ] = None,
        max_workers: Optional[int] = None,
        swarm_name: str = "Collaborative Agent Swarm",
        memory_collection: str = "swarm_memory",
    ):
        """Initialize GraphSwarm."""
        self.graph = nx.DiGraph()
        self.agents: Dict[str, Agent] = {}
        self.dependencies: Dict[str, List[str]] = {}
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        self.swarm_name = swarm_name
        self.memory_collection = memory_collection
        self.memory = SwarmMemory(collection_name=memory_collection)

        if agents:
            self.initialize_agents(agents)

        logger.info(f"Initialized GraphSwarm: {swarm_name}")

    def initialize_agents(
        self,
        agents: Union[List[Agent], List[Tuple[Agent, List[str]]]],
    ):
        """Initialize agents and their dependencies."""
        try:
            # Handle list of Agents or (Agent, dependencies) tuples
            for item in agents:
                if isinstance(item, tuple):
                    agent, dependencies = item
                else:
                    agent, dependencies = item, []

                if not isinstance(agent, Agent):
                    raise ValueError(
                        f"Expected Agent object, got {type(agent)}"
                    )

                self.agents[agent.agent_name] = agent
                self.dependencies[agent.agent_name] = dependencies
                self.graph.add_node(agent.agent_name, agent=agent)

                # Add dependencies; each must already be registered
                for dep in dependencies:
                    if dep not in self.agents:
                        raise ValueError(
                            f"Dependency {dep} not found for agent {agent.agent_name}"
                        )
                    self.graph.add_edge(dep, agent.agent_name)

            self._validate_graph()

        except Exception as e:
            logger.error(f"Failed to initialize agents: {str(e)}")
            raise

    def _validate_graph(self):
        """Validate the agent dependency graph."""
        if not self.graph.nodes():
            raise ValueError("No agents added to swarm")

        if not nx.is_directed_acyclic_graph(self.graph):
            cycles = list(nx.simple_cycles(self.graph))
            raise ValueError(
                f"Agent dependency graph contains cycles: {cycles}"
            )

    def _get_agent_role_description(self, agent_name: str) -> str:
        """Generate a description of the agent's role in the swarm."""
        predecessors = list(self.graph.predecessors(agent_name))
        successors = list(self.graph.successors(agent_name))
        position = (
            "initial"
            if not predecessors
            else ("final" if not successors else "intermediate")
        )

        role = f"""You are {agent_name}, a specialized agent in the {self.swarm_name}.
Position: {position} agent in the workflow

Your relationships:"""

        if predecessors:
            role += (
                f"\nYou receive input from: {', '.join(predecessors)}"
            )
        if successors:
            role += f"\nYour output will be used by: {', '.join(successors)}"

        return role

    def _generate_workflow_context(self) -> str:
        """Generate a description of the entire workflow."""
        execution_order = list(nx.topological_sort(self.graph))

        workflow = f"""Workflow Overview of {self.swarm_name}:

Processing Order:
{' -> '.join(execution_order)}

Agent Roles:
"""

        for agent_name in execution_order:
            predecessors = list(self.graph.predecessors(agent_name))
            successors = list(self.graph.successors(agent_name))

            workflow += f"\n\n{agent_name}:"
            if predecessors:
                workflow += (
                    f"\n- Receives from: {', '.join(predecessors)}"
                )
            if successors:
                workflow += f"\n- Sends to: {', '.join(successors)}"
            if not predecessors and not successors:
                workflow += "\n- Independent agent"

        return workflow

    def _build_agent_prompt(
        self, agent_name: str, task: str, context: Dict = None
    ) -> str:
        """Build a comprehensive prompt for the agent including role and context."""
        prompt_parts = [
            self._get_agent_role_description(agent_name),
            "\nWorkflow Context:",
            self._generate_workflow_context(),
            "\nYour Task:",
            task,
        ]

        if context:
            prompt_parts.extend(
                ["\nContext from Previous Agents:", str(context)]
            )

        prompt_parts.extend(
            [
                "\nInstructions:",
                "1. Process the task according to your role",
                "2. Consider the input from previous agents when available",
                "3. Provide clear, structured output",
                "4. Remember that your output will be used by subsequent agents",
                "\nResponse Guidelines:",
                "- Provide clear, well-organized output",
                "- Include relevant details and insights",
                "- Highlight key findings",
                "- Flag any uncertainties or issues",
            ]
        )

        return "\n".join(prompt_parts)

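    # For the three-agent chain built in __main__ below, the workflow
    # context produced above would read roughly as follows (illustrative,
    # not verbatim output):
    #
    #   Workflow Overview of Market Analysis Intelligence Network:
    #
    #   Processing Order:
    #   Market-Data-Collector -> Market-Trend-Analyzer -> Investment-Report-Generator
    #
    #   Agent Roles:
    #
    #   Market-Data-Collector:
    #   - Sends to: Market-Trend-Analyzer
    #   ...
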
    async def _execute_agent(
        self, agent_name: str, task: str, context: Dict = None
    ) -> AgentOutput:
        """Execute a single agent."""
        start_time = time.time()
        agent = self.agents[agent_name]

        try:
            # Build comprehensive prompt
            full_prompt = self._build_agent_prompt(
                agent_name, task, context
            )
            logger.debug(f"Prompt for {agent_name}:\n{full_prompt}")

            # Execute agent in a worker thread so the event loop stays free
            output = await asyncio.to_thread(agent.run, full_prompt)

            return AgentOutput(
                agent_name=agent_name,
                output=output,
                execution_time=time.time() - start_time,
                metadata={
                    "task": task,
                    "context": context,
                    "position_in_workflow": list(
                        nx.topological_sort(self.graph)
                    ).index(agent_name),
                },
            )

        except Exception as e:
            logger.error(
                f"Error executing agent {agent_name}: {str(e)}"
            )
            return AgentOutput(
                agent_name=agent_name,
                output=None,
                execution_time=time.time() - start_time,
                error=str(e),
                metadata={"task": task},
            )

    async def execute(self, task: str) -> SwarmOutput:
        """
        Execute the entire swarm of agents with memory integration.

        Args:
            task: Initial task to execute

        Returns:
            SwarmOutput: Structured output from all agents
        """
        start_time = time.time()
        outputs = {}
        success = True
        error = None

        try:
            # Get similar past executions
            similar_executions = self.memory.get_similar_executions(
                task, limit=3
            )
            optimal_sequence = self.memory.get_optimal_sequence(task)

            # Get base execution order
            base_execution_order = list(
                nx.topological_sort(self.graph)
            )

            # Determine final execution order
            if optimal_sequence and all(
                agent in base_execution_order
                for agent in optimal_sequence
            ):
                logger.info(
                    f"Using optimal sequence from memory: {optimal_sequence}"
                )
                execution_order = optimal_sequence
            else:
                execution_order = base_execution_order

            # Get historical context if available
            historical_context = {}
            if similar_executions:
                best_execution = similar_executions[0]
                if best_execution["success"]:
                    historical_context = {
                        "similar_task": best_execution["task"],
                        "previous_outputs": best_execution["outputs"],
                        "execution_time": best_execution[
                            "execution_time"
                        ],
                        "success_patterns": self._extract_success_patterns(
                            similar_executions
                        ),
                    }

            # Execute agents in order
            for agent_name in execution_order:
                try:
                    # Get context from dependencies and history
                    agent_context = {
                        "dependencies": {
                            dep: outputs[dep].output
                            for dep in self.graph.predecessors(
                                agent_name
                            )
                            if dep in outputs
                        },
                        "historical": historical_context,
                        "position": execution_order.index(agent_name),
                        "total_agents": len(execution_order),
                    }

                    # Execute agent with enhanced context
                    output = await self._execute_agent(
                        agent_name, task, agent_context
                    )
                    outputs[agent_name] = output

                    # Update historical context with current execution
                    if output.output:
                        historical_context.update(
                            {
                                f"current_{agent_name}_output": output.output
                            }
                        )

                    # Check for errors
                    if output.error:
                        success = False
                        error = f"Agent {agent_name} failed: {output.error}"

                        # Try to recover using memory
                        if similar_executions:
                            recovery_output = self._attempt_recovery(
                                agent_name, task, similar_executions
                            )
                            if recovery_output:
                                outputs[agent_name] = recovery_output
                                success = True
                                error = None
                                continue
                        break

                except Exception as agent_error:
                    logger.error(
                        f"Error executing agent {agent_name}: {str(agent_error)}"
                    )
                    success = False
                    error = f"Agent {agent_name} failed: {str(agent_error)}"
                    break

            # Create result
            result = SwarmOutput(
                outputs=outputs,
                execution_time=time.time() - start_time,
                success=success,
                error=error,
                metadata={
                    "task": task,
                    "used_optimal_sequence": optimal_sequence
                    is not None,
                    "similar_executions_found": len(
                        similar_executions
                    ),
                    "execution_order": execution_order,
                    "historical_context_used": bool(
                        historical_context
                    ),
                },
            )

            # Store execution in memory
            await self._store_execution_async(task, result)

            return result

        except Exception as e:
            logger.error(f"Swarm execution failed: {str(e)}")
            return SwarmOutput(
                outputs=outputs,
                execution_time=time.time() - start_time,
                success=False,
                error=str(e),
                metadata={"task": task},
            )

    def run(self, task: str) -> SwarmOutput:
        """Synchronous interface to execute the swarm."""
        return asyncio.run(self.execute(task))

    def _extract_success_patterns(
        self, similar_executions: List[Dict]
    ) -> Dict:
        """Extract success patterns from similar executions."""
        patterns = {}
        successful_execs = [
            ex for ex in similar_executions if ex["success"]
        ]

        if successful_execs:
            patterns = {
                "common_sequences": self._find_common_sequences(
                    successful_execs
                ),
                "avg_execution_time": sum(
                    ex["execution_time"] for ex in successful_execs
                )
                / len(successful_execs),
                "successful_strategies": self._extract_strategies(
                    successful_execs
                ),
            }

        return patterns

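    # NOTE: _find_common_sequences and _extract_strategies are called in
    # _extract_success_patterns above but are not defined anywhere in this
    # diff. The two methods below are minimal, hypothetical sketches added
    # so the class is runnable; they are not part of the original commit.
    def _find_common_sequences(
        self, successful_execs: List[Dict]
    ) -> List[List[str]]:
        """Sketch: return agent sequences shared by multiple successful runs."""
        from collections import Counter

        counts = Counter(
            tuple(ex["agent_sequence"]) for ex in successful_execs
        )
        return [list(seq) for seq, count in counts.items() if count > 1]

    def _extract_strategies(
        self, successful_execs: List[Dict]
    ) -> List[str]:
        """Sketch: list the agents that appear in successful executions."""
        agents = set()
        for ex in successful_execs:
            agents.update(ex["agent_sequence"])
        return sorted(agents)
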
    def _attempt_recovery(
        self,
        failed_agent: str,
        task: str,
        similar_executions: List[Dict],
    ) -> Optional[AgentOutput]:
        """Attempt to recover from failure using memory."""
        for execution in similar_executions:
            if (
                execution["success"]
                and failed_agent in execution["outputs"]
            ):
                historical_output = execution["outputs"][failed_agent]

                return AgentOutput(
                    agent_name=failed_agent,
                    output=historical_output["output"],
                    execution_time=historical_output[
                        "execution_time"
                    ],
                    metadata={
                        "recovered_from_memory": True,
                        "original_task": execution["task"],
                    },
                )
        return None

    async def _store_execution_async(
        self, task: str, result: SwarmOutput
    ):
        """Asynchronously store execution in memory."""
        try:
            await asyncio.to_thread(
                self.memory.store_execution, task, result
            )
        except Exception as e:
            logger.error(
                f"Failed to store execution in memory: {str(e)}"
            )

    def add_agent(self, agent: Agent, dependencies: List[str] = None):
        """Add a new agent to the swarm."""
        dependencies = dependencies or []
        self.agents[agent.agent_name] = agent
        self.dependencies[agent.agent_name] = dependencies
        self.graph.add_node(agent.agent_name, agent=agent)

        for dep in dependencies:
            if dep not in self.agents:
                raise ValueError(f"Dependency {dep} not found")
            self.graph.add_edge(dep, agent.agent_name)

        self._validate_graph()


if __name__ == "__main__":
    try:
        # Create agents
        data_collector = Agent(
            agent_name="Market-Data-Collector",
            model_name="gpt-4o-mini",
            max_loops=1,
            streaming_on=True,
        )

        trend_analyzer = Agent(
            agent_name="Market-Trend-Analyzer",
            model_name="gpt-4o-mini",
            max_loops=1,
            streaming_on=True,
        )

        report_generator = Agent(
            agent_name="Investment-Report-Generator",
            model_name="gpt-4o-mini",
            max_loops=1,
            streaming_on=True,
        )

        # Create swarm
        swarm = GraphSwarm(
            agents=[
                (data_collector, []),
                (trend_analyzer, ["Market-Data-Collector"]),
                (report_generator, ["Market-Trend-Analyzer"]),
            ],
            swarm_name="Market Analysis Intelligence Network",
        )

        # Run the swarm
        result = swarm.run(
            "Analyze current market trends for tech stocks and provide investment recommendations"
        )

        # Print results
        print(f"Execution success: {result.success}")
        print(f"Total time: {result.execution_time:.2f} seconds")

        for agent_name, output in result.outputs.items():
            print(f"\nAgent: {agent_name}")
            print(f"Output: {output.output}")
            if output.error:
                print(f"Error: {output.error}")
    except Exception as error:
        logger.error(error)
        raise error

@@ -0,0 +1,7 @@
from swarms import Agent

Agent(
    agent_name="Stock-Analysis-Agent",
    model_name="gpt-4o-mini",
    max_loops=1,
).run("What are 5 hft algorithms")

@@ -0,0 +1,276 @@
import asyncio
import json
from typing import Any, Dict, List

import pulsar
from loguru import logger
from pulsar import ConsumerType

from swarms import Agent


class ScalableAsyncAgentSwarm:
    """
    A scalable, asynchronous swarm of agents leveraging Apache Pulsar for inter-agent communication.
    Provides load balancing, health monitoring, dead letter queues, and centralized logging.
    """

    def __init__(
        self,
        pulsar_url: str,
        topic: str,
        dlq_topic: str,
        agents_config: List[Dict[str, Any]],
    ):
        """
        Initializes the async swarm with agents.

        Args:
            pulsar_url (str): The URL of the Apache Pulsar broker.
            topic (str): The main topic for task distribution.
            dlq_topic (str): The Dead Letter Queue topic for failed messages.
            agents_config (List[Dict[str, Any]]): List of agent configurations with `name`, `description`, and `model_name`.
        """
        self.pulsar_url = pulsar_url
        self.topic = topic
        self.dlq_topic = dlq_topic
        self.agents_config = agents_config
        self.client = pulsar.Client(pulsar_url)
        self.consumer = self.client.subscribe(
            topic,
            subscription_name="swarm-task-sub",
            consumer_type=ConsumerType.Shared,
        )
        self.dlq_producer = self.client.create_producer(dlq_topic)
        self.response_logger = []
        self.agents = [
            self.create_agent(config) for config in agents_config
        ]
        self.agent_index = 0

        logger.info(
            "Swarm initialized with agents: {}",
            [agent["name"] for agent in agents_config],
        )

    def create_agent(
        self, agent_config: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Creates a new agent configuration with asynchronous capabilities.

        Args:
            agent_config (Dict[str, Any]): Configuration dictionary with agent details.

        Returns:
            Dict[str, Any]: A dictionary containing agent metadata and functionality.
        """
        agent_name = agent_config["name"]
        description = agent_config["description"]
        model_name = agent_config.get("model_name", "gpt-4o-mini")

        class AsyncAgent:
            """
            An asynchronous agent that processes tasks and communicates via Apache Pulsar.
            """

            def __init__(
                self, name: str, description: str, model_name: str
            ):
                self.name = name
                self.description = description
                self.agent = Agent(
                    agent_name=name,
                    model_name=model_name,
                    max_loops="auto",
                    interactive=True,
                    streaming_on=True,
                )
                logger.info(
                    f"Initialized agent '{name}' - {description}"
                )

            async def process_task(
                self, message: str
            ) -> Dict[str, Any]:
                """
                Processes a single task using the agent.

                Args:
                    message (str): The task message.

                Returns:
                    Dict[str, Any]: JSON-formatted response.
                """
                try:
                    logger.info(
                        f"Agent {self.name} processing task: {message}"
                    )
                    response = await asyncio.to_thread(
                        self.agent.run, message
                    )
                    logger.info(f"Agent {self.name} completed task.")
                    return {
                        "agent_name": self.name,
                        "response": response,
                    }
                except Exception as e:
                    logger.error(
                        f"Agent {self.name} encountered an error: {e}"
                    )
                    return {"agent_name": self.name, "error": str(e)}

        return {
            "name": agent_name,
            "instance": AsyncAgent(
                agent_name, description, model_name
            ),
        }

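    # Shape of the value returned above, for a hypothetical config:
    #
    #   swarm.create_agent({"name": "Demo", "description": "Demo agent"})
    #   # -> {"name": "Demo", "instance": <AsyncAgent wrapping swarms.Agent>}
    #
    # Keeping the AsyncAgent instance behind a plain dict lets
    # distribute_task() below round-robin over self.agents by index.
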
    async def distribute_task(self, message: str):
        """
        Distributes a task to the next available agent using round-robin.

        Args:
            message (str): The task message.
        """
        agent = self.agents[self.agent_index]
        self.agent_index = (self.agent_index + 1) % len(self.agents)

        try:
            response = await agent["instance"].process_task(message)
            self.log_response(response)
        except Exception as e:
            logger.error(
                f"Error processing task by agent {agent['name']}: {e}"
            )
            self.send_to_dlq(message)

    async def monitor_health(self):
        """
        Periodically monitors the health of agents.
        """
        while True:
            logger.info("Performing health check for all agents.")
            for agent in self.agents:
                logger.info(f"Agent {agent['name']} is online.")
            await asyncio.sleep(10)

    def send_to_dlq(self, message: str):
        """
        Sends a failed message to the Dead Letter Queue (DLQ).

        Args:
            message (str): The message to send to the DLQ.
        """
        try:
            self.dlq_producer.send(message.encode("utf-8"))
            logger.info("Message sent to Dead Letter Queue.")
        except Exception as e:
            logger.error(f"Failed to send message to DLQ: {e}")

    def log_response(self, response: Dict[str, Any]):
        """
        Logs the response to a centralized list for later analysis.

        Args:
            response (Dict[str, Any]): The agent's response.
        """
        self.response_logger.append(response)
        logger.info(f"Response logged: {response}")

    async def listen_and_distribute(self):
        """
        Listens to the main Pulsar topic and distributes tasks to agents.
        """
        while True:
            # consumer.receive() blocks, so run it in a worker thread to
            # keep the event loop (and the health monitor) responsive
            msg = await asyncio.to_thread(self.consumer.receive)
            try:
                message = msg.data().decode("utf-8")
                logger.info(f"Received task: {message}")
                await self.distribute_task(message)
                self.consumer.acknowledge(msg)
            except Exception as e:
                logger.error(f"Error processing message: {e}")
                self.send_to_dlq(msg.data().decode("utf-8"))
                self.consumer.negative_acknowledge(msg)

    async def run(self):
        """
        Runs the swarm asynchronously with health monitoring and task distribution.
        """
        logger.info("Starting the async swarm...")
        task_listener = asyncio.create_task(
            self.listen_and_distribute()
        )
        health_monitor = asyncio.create_task(self.monitor_health())
        await asyncio.gather(task_listener, health_monitor)

    def shutdown(self):
        """
        Safely shuts down the swarm and logs all responses.
        """
        logger.info("Shutting down the swarm...")
        self.client.close()
        with open("responses.json", "w") as f:
            json.dump(self.response_logger, f, indent=4)
        logger.info("Responses saved to 'responses.json'.")


# from scalable_agent_swarm import ScalableAsyncAgentSwarm  # Assuming your swarm class is saved here


if __name__ == "__main__":
    # Example Configuration
    PULSAR_URL = "pulsar://localhost:6650"
    TOPIC = "stock-analysis"
    DLQ_TOPIC = "stock-analysis-dlq"

    # Agents configuration
    AGENTS_CONFIG = [
        {
            "name": "Stock-Analysis-Agent-1",
            "description": "Analyzes stock trends.",
            "model_name": "gpt-4o-mini",
        },
        {
            "name": "Stock-News-Agent",
            "description": "Summarizes stock news.",
            "model_name": "gpt-4o-mini",
        },
        {
            "name": "Tech-Trends-Agent",
            "description": "Tracks tech sector trends.",
            "model_name": "gpt-4o-mini",
        },
    ]

    # Tasks to send
    TASKS = [
        "Analyze the trend for tech stocks in Q4 2024",
        "Summarize the latest news on the S&P 500",
        "Identify the top-performing sectors in the stock market",
        "Provide a forecast for AI-related stocks for 2025",
    ]

    # Initialize and run the swarm
    swarm = ScalableAsyncAgentSwarm(
        PULSAR_URL, TOPIC, DLQ_TOPIC, AGENTS_CONFIG
    )

    async def main():
        # Start the swarm in the background (create_task needs a running
        # event loop, so everything happens inside this coroutine)
        swarm_task = asyncio.create_task(swarm.run())

        # Send tasks to the topic
        client = pulsar.Client(PULSAR_URL)
        producer = client.create_producer(TOPIC)

        for task in TASKS:
            producer.send(task.encode("utf-8"))
            print(f"Sent task: {task}")

        producer.close()
        client.close()

        # Keep the swarm running
        await swarm_task

    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        swarm.shutdown()

@@ -0,0 +1,253 @@
import re

from dotenv import load_dotenv
from tenacity import retry, stop_after_attempt, wait_exponential

from swarms import Agent
from swarms.agents.create_agents_from_yaml import (
    create_agents_from_yaml,
)
from swarms.utils.formatter import formatter
from swarms.utils.litellm import LiteLLM

load_dotenv()


def prepare_yaml_for_parsing(raw_yaml: str) -> str:
    """
    Prepares raw YAML content by fixing spacing and formatting issues.

    Args:
        raw_yaml (str): The raw YAML content extracted from Markdown.

    Returns:
        str: The cleaned YAML content ready for parsing.
    """
    # Fix sequence items that are improperly placed on the same line as their key
    fixed_yaml = re.sub(
        r"(\b\w+\b):\s*-\s*", r"\1:\n - ", raw_yaml
    )  # Fix "key: - value" to "key:\n - value"

    # Ensure proper spacing after colons
    fixed_yaml = re.sub(
        r"(\S):(\S)", r"\1: \2", fixed_yaml
    )  # Ensure space after colons

    # Remove trailing spaces before newlines
    fixed_yaml = re.sub(r"\s+\n", "\n", fixed_yaml)

    # Replace non-breaking spaces (if any) with regular spaces
    fixed_yaml = fixed_yaml.replace("\xa0", " ")

    return fixed_yaml.strip()


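# A quick illustration of what prepare_yaml_for_parsing repairs
# (hypothetical input):
#
#   prepare_yaml_for_parsing("agents: - agent_name:Builder")
#   # -> "agents:\n - agent_name: Builder"

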
def parse_yaml_from_swarm_markdown(markdown_text: str) -> str:
    """
    Extracts and prepares YAML content from a Markdown-style 'Auto-Swarm-Builder' block.

    Args:
        markdown_text (str): The Markdown text containing the YAML inside an 'Auto-Swarm-Builder' block.

    Returns:
        str: The normalized YAML content, ready for parsing.
    """
    # Match the 'Auto-Swarm-Builder' block with YAML inside triple backticks
    pattern = r"```yaml\s*\n(.*?)```"
    match = re.search(pattern, markdown_text, re.DOTALL)

    if not match:
        raise ValueError(
            "No YAML content found in the 'Auto-Swarm-Builder' block."
        )

    raw_yaml = match.group(1).strip()

    # Preprocess and normalize the YAML content
    normalized_yaml = prepare_yaml_for_parsing(raw_yaml)

    return normalized_yaml


AUTO_GEN_PROMPT = """
You are a specialized agent responsible for creating YAML configuration files for multi-agent swarms. Your role is to generate well-structured YAML that defines both individual agents and swarm architectures based on user requirements.
Output only the YAML, nothing else. You will be penalized for making mistakes.

GUIDELINES:
1. Each YAML file must contain an `agents` section with at least one agent configuration
2. Each agent configuration requires the following mandatory fields:
   - agent_name (string)
   - system_prompt (string)

3. Optional agent fields include:
   - max_loops (integer)
   - autosave (boolean)
   - dashboard (boolean)
   - verbose (boolean)
   - dynamic_temperature_enabled (boolean)
   - saved_state_path (string)
   - user_name (string)
   - retry_attempts (integer)
   - context_length (integer)
   - return_step_meta (boolean)
   - output_type (string)
   - task (string)

4. When a swarm is needed, include a `swarm_architecture` section with:
   Mandatory fields:
   - name (string)
   - swarm_type (string): one of AgentRearrange, MixtureOfAgents, SpreadSheetSwarm, SequentialWorkflow, ConcurrentWorkflow

   Optional fields:
   - description (string)
   - max_loops (integer)
   - task (string)

TEMPLATE STRUCTURE:
```yaml
agents:
  - agent_name: "Agent-1-Name"
    system_prompt: "Detailed system prompt here"
    max_loops: 1
    # [additional optional fields]

  - agent_name: "Agent-2-Name"
    system_prompt: "Detailed system prompt here"
    # [additional optional fields]

swarm_architecture:
  name: "Swarm-Name"
  description: "Swarm purpose and goals"
  swarm_type: "ConcurrentWorkflow"
  max_loops: 5
  task: "Main swarm task description"
```

VALIDATION RULES:
1. All agent names must be unique
2. System prompts must be clear and specific to the agent's role
3. Integer values must be positive
4. Boolean values must be true or false (lowercase)
5. File paths should use forward slashes
6. Tasks should be specific and aligned with the agent/swarm purpose

When generating a YAML configuration:
1. Ask for specific requirements about the agents and swarm needed
2. Determine if a swarm architecture is necessary based on the task complexity
3. Generate appropriate system prompts for each agent based on their roles
4. Include relevant optional fields based on the use case
5. Validate the configuration against all rules before returning

Example valid YAML configurations are provided below. Use these as references for structure and formatting:

```yaml
agents:
  - agent_name: "Data-Analysis-Agent"
    system_prompt: "You are a specialized data analysis agent focused on processing and interpreting financial data. Provide clear, actionable insights based on the data provided."
    max_loops: 3
    autosave: true
    verbose: true
    context_length: 100000
    output_type: "json"
    task: "Analyze quarterly financial reports and identify trends"

# Multi-Agent Swarm Example
agents:
  - agent_name: "Research-Agent"
    system_prompt: "You are a research agent specialized in gathering and summarizing scientific publications. Focus on peer-reviewed sources and provide comprehensive summaries."
    max_loops: 2
    context_length: 150000
    output_type: "str"

  - agent_name: "Analysis-Agent"
    system_prompt: "You are an analysis agent that processes research summaries and identifies key patterns and insights. Provide detailed analytical reports."
    max_loops: 3
    context_length: 200000
    output_type: "json"

swarm_architecture:
  name: "Research-Analysis-Swarm"
  description: "A swarm for comprehensive research analysis and insight generation"
  swarm_type: "SequentialWorkflow"
  max_loops: 5
  task: "Research and analyze recent developments in quantum computing"
```
"""


def generate_swarm_config(
    task: str,
    file_name: str = "swarm_config_output.yaml",
    model_name: str = "gpt-4o",
    *args,
    **kwargs,
):
    """
    Generates a swarm configuration based on the provided task and model name.

    This function attempts to generate a swarm configuration by running an agent with the specified task and model name.
    It then parses the output into YAML format and creates agents based on the parsed YAML content.

    Args:
        task (str): The task to be performed by the swarm.
        file_name (str, optional): The file name for the output YAML configuration (currently unused). Defaults to "swarm_config_output.yaml".
        model_name (str, optional): The name of the model to use for the agent. Defaults to "gpt-4o".
        *args: Additional positional arguments to be passed to the agent's run method.
        **kwargs: Additional keyword arguments to be passed to the agent's run method.

    Returns:
        Any: The output of the swarm configuration generation process. This can be a SwarmRouter instance or an error message.
    """
    formatter.print_panel(
        "Auto Generating Swarm...", "Auto Swarm Builder"
    )

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(min=4, max=10),
    )
    def attempt_generate_swarm_config():
        try:
            model = LiteLLM(model_name=model_name)

            # Initialize the agent
            agent = Agent(
                agent_name="Auto-Swarm-Builder",
                system_prompt=AUTO_GEN_PROMPT,
                llm=model,
                max_loops=1,
                dynamic_temperature_enabled=True,
                saved_state_path="swarm_builder.json",
                user_name="swarms_corp",
                output_type="str",
            )

            # Generate output from the agent
            raw_output = agent.run(task, *args, **kwargs)
            yaml_content = parse_yaml_from_swarm_markdown(raw_output)
            print(yaml_content)

            # Create agents from the YAML content
            output = create_agents_from_yaml(
                yaml_string=yaml_content,
                return_type="run_swarm",
            )

            formatter.print_panel(
                "Swarm configuration generated successfully.",
                "Success",
            )

            return output

        except Exception as e:
            formatter.print_panel(
                f"Error generating swarm configuration: {str(e)}",
                "Error",
            )
            raise

    return attempt_generate_swarm_config()

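# Hypothetical invocation (requires the provider API key, e.g.
# OPENAI_API_KEY, in the environment loaded via .env; the task string
# is illustrative):
#
#   generate_swarm_config(
#       task="Build a two-agent swarm that researches and summarizes AI news",
#       model_name="gpt-4o",
#   )
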
@@ -0,0 +1,105 @@
try:
    # Import the package itself (not just `completion`) so the
    # set_verbose flag below resolves on the happy path as well
    import litellm
    from litellm import completion
except ImportError:
    import subprocess

    subprocess.check_call(["pip", "install", "litellm"])
    import litellm
    from litellm import completion

litellm.set_verbose = True


class LiteLLM:
    """
    A lightweight wrapper around the litellm `completion` API.
    It is used to interact with LLM models for various tasks.
    """

    def __init__(
        self,
        model_name: str = "gpt-4o",
        system_prompt: str = None,
        stream: bool = False,
        temperature: float = 0.5,
        max_tokens: int = 4000,
    ):
        """
        Initialize the LiteLLM with the given parameters.

        Args:
            model_name (str, optional): The name of the model to use. Defaults to "gpt-4o".
            system_prompt (str, optional): The system prompt to use. Defaults to None.
            stream (bool, optional): Whether to stream the output. Defaults to False.
            temperature (float, optional): The temperature for the model. Defaults to 0.5.
            max_tokens (int, optional): The maximum number of tokens to generate. Defaults to 4000.
        """
        self.model_name = model_name
        self.system_prompt = system_prompt
        self.stream = stream
        self.temperature = temperature
        self.max_tokens = max_tokens

    def _prepare_messages(self, task: str) -> list:
        """
        Prepare the messages for the given task.

        Args:
            task (str): The task to prepare messages for.

        Returns:
            list: A list of messages prepared for the task.
        """
        messages = []

        if self.system_prompt:  # Check if system_prompt is not None
            messages.append(
                {"role": "system", "content": self.system_prompt}
            )

        messages.append({"role": "user", "content": task})

        return messages

    def run(self, task: str, *args, **kwargs):
        """
        Run the LLM model for the given task.

        Args:
            task (str): The task to run the model for.
            *args: Additional positional arguments to pass to the model.
            **kwargs: Additional keyword arguments to pass to the model.

        Returns:
            str: The content of the response from the model.
        """
        messages = self._prepare_messages(task)

        response = completion(
            model=self.model_name,
            messages=messages,
            stream=self.stream,
            temperature=self.temperature,
            # max_completion_tokens=self.max_tokens,
            max_tokens=self.max_tokens,
            *args,
            **kwargs,
        )
        content = response.choices[
            0
        ].message.content  # Accessing the content
        return content

    def __call__(self, task: str, *args, **kwargs):
        """
        Call the LLM model for the given task.

        Args:
            task (str): The task to run the model for.
            *args: Additional positional arguments to pass to the model.
            **kwargs: Additional keyword arguments to pass to the model.

        Returns:
            str: The content of the response from the model.
        """
        return self.run(task, *args, **kwargs)

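# Minimal usage sketch (assumes the matching provider key, e.g.
# OPENAI_API_KEY, is set in the environment):
#
#   llm = LiteLLM(model_name="gpt-4o-mini", temperature=0.2)
#   print(llm.run("Summarize what a vector database does in one sentence."))
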
@@ -1,50 +1,64 @@
 import re
 
 
-def extract_code_from_markdown(markdown_content: str) -> str:
-    """
-    Extracts code blocks from a Markdown string and returns them as a single string.
-
-    Args:
-    - markdown_content (str): The Markdown content as a string.
-
-    Returns:
-    - str: A single string containing all the code blocks separated by newlines.
-    """
-    # Regular expression for fenced code blocks with optional language specifier
-    pattern = r"```(?:\w+\n)?(.*?)```"
-
-    # Check if markdown_content is a string
-    if not isinstance(markdown_content, str):
-        raise TypeError("markdown_content must be a string")
-
-    # Find all matches of the pattern
-    matches = re.finditer(pattern, markdown_content, re.DOTALL)
-
-    # Extract the content inside the backticks
-    code_blocks = []
-    for match in matches:
-        code_block = match.group(1).strip()
-        # Remove any leading or trailing whitespace from the code block
-        code_block = code_block.strip()
-        # Remove any empty lines from the code block
-        code_block = "\n".join(
-            [line for line in code_block.split("\n") if line.strip()]
-        )
-        code_blocks.append(code_block)
-
-    # Concatenate all code blocks separated by newlines
-    if code_blocks:
-        return "\n\n".join(code_blocks)
-    else:
-        return ""
-
-
-# example = """
-# hello im an agent
-# ```bash
-# pip install swarms
-# ```
-# """
-
-# print(extract_code_from_markdown(example)) # Output: { "type": "function", "function": { "name": "fetch_financial_news", "parameters": { "query": "Nvidia news", "num_articles": 5 } } }
+def extract_code_blocks_with_language(markdown_text: str):
+    """
+    Extracts all code blocks from Markdown text along with their languages.
+
+    Args:
+        markdown_text (str): The input Markdown text.
+
+    Returns:
+        list[dict]: A list of dictionaries, each containing:
+            - 'language': The detected language (or 'plaintext' if none specified).
+            - 'content': The content of the code block.
+    """
+    # Regex pattern to match code blocks and optional language specifiers
+    pattern = r"```(\w+)?\n(.*?)```"
+
+    # Find all matches (language and content)
+    matches = re.findall(pattern, markdown_text, re.DOTALL)
+
+    # Parse results
+    code_blocks = []
+    for language, content in matches:
+        language = (
+            language.strip() if language else "plaintext"
+        )  # Default to 'plaintext'
+        code_blocks.append(
+            {"language": language, "content": content.strip()}
+        )
+
+    return code_blocks
+
+
+def extract_code_from_markdown(
+    markdown_text: str, language: str = None
+):
+    """
+    Extracts content of code blocks for a specific language, or all blocks if no language is specified.
+
+    Args:
+        markdown_text (str): The input Markdown text.
+        language (str, optional): The language to filter by (e.g., 'yaml', 'python').
+
+    Returns:
+        str: The concatenated content of matched code blocks, or an empty string if none found.
+    """
+    # Get all code blocks with detected languages
+    code_blocks = extract_code_blocks_with_language(markdown_text)
+
+    # Filter by language if specified
+    if language:
+        code_blocks = [
+            block["content"]
+            for block in code_blocks
+            if block["language"] == language
+        ]
+    else:
+        code_blocks = [
+            block["content"] for block in code_blocks
+        ]  # Include all blocks
+
+    # Return concatenated content
+    return "\n\n".join(code_blocks) if code_blocks else ""

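A quick sketch of the new API on a hypothetical snippet:

    md = "Intro\n```bash\npip install swarms\n```"
    extract_code_blocks_with_language(md)
    # [{'language': 'bash', 'content': 'pip install swarms'}]
    extract_code_from_markdown(md, language="bash")
    # 'pip install swarms'
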
@@ -1,292 +0,0 @@
import torch
import torch.nn as nn
import torch.distributed as dist
from dataclasses import dataclass
from typing import Optional, Tuple, Union
from loguru import logger
import math


@dataclass
class StarAttentionConfig:
    """Configuration for StarAttention module.

    Attributes:
        hidden_size: Dimension of the model's hidden states
        num_attention_heads: Number of attention heads
        num_hosts: Number of hosts in the distributed system
        block_size: Size of each context block
        anchor_size: Size of the anchor block
        dropout_prob: Dropout probability (default: 0.1)
        layer_norm_eps: Layer normalization epsilon (default: 1e-12)
    """

    hidden_size: int
    num_attention_heads: int
    num_hosts: int
    block_size: int
    anchor_size: int
    dropout_prob: float = 0.1
    layer_norm_eps: float = 1e-12


class StarAttention(nn.Module):
    """
    Implementation of Star Attention mechanism for distributed inference.

    The module implements a two-phase attention mechanism:
    1. Local Context Encoding with Anchor Blocks
    2. Query Encoding and Output Generation with Global Attention
    """

    def __init__(self, config: StarAttentionConfig):
        super().__init__()

        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"Hidden size {config.hidden_size} not divisible by number of attention "
                f"heads {config.num_attention_heads}"
            )

        self.config = config
        self.head_dim = (
            config.hidden_size // config.num_attention_heads
        )

        # Initialize components
        self.query = nn.Linear(config.hidden_size, config.hidden_size)
        self.key = nn.Linear(config.hidden_size, config.hidden_size)
        self.value = nn.Linear(config.hidden_size, config.hidden_size)

        self.dropout = nn.Dropout(config.dropout_prob)
        self.layer_norm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps
        )

        # KV cache for storing computed key/value pairs
        self.kv_cache = {}

        logger.info(
            f"Initialized StarAttention with config: {config}"
        )

    def _split_heads(
        self, tensor: torch.Tensor, num_heads: int
    ) -> torch.Tensor:
        """Split the last dimension into (num_heads, head_dim)."""
        batch_size, seq_len, _ = tensor.size()
        tensor = tensor.view(
            batch_size, seq_len, num_heads, self.head_dim
        )
        # Transpose to (batch_size, num_heads, seq_len, head_dim)
        return tensor.transpose(1, 2)

    def _merge_heads(self, tensor: torch.Tensor) -> torch.Tensor:
        """Merge the head dimension back into hidden_size."""
        batch_size, _, seq_len, _ = tensor.size()
        tensor = tensor.transpose(1, 2)
        return tensor.reshape(
            batch_size, seq_len, self.config.hidden_size
        )

    def _compute_attention_scores(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute attention scores and weighted values."""
        # Scaled dot-product attention
        scores = torch.matmul(
            query, key.transpose(-2, -1)
        ) / math.sqrt(self.head_dim)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))

        # Online softmax computation
        attention_probs = torch.nn.functional.softmax(scores, dim=-1)
        attention_probs = self.dropout(attention_probs)

        context = torch.matmul(attention_probs, value)

        return context, attention_probs

    def phase1_local_context_encoding(
        self,
        input_ids: torch.Tensor,
        host_id: int,
        device: Union[str, torch.device] = "cuda",
    ) -> None:
        """
        Phase 1: Local Context Encoding with Anchor Blocks

        Args:
            input_ids: Input tensor of shape (batch_size, seq_len)
            host_id: ID of the current host
            device: Device to run computations on
        """
        logger.debug(f"Starting Phase 1 on host {host_id}")

        # Calculate block assignments
        block_start = host_id * self.config.block_size
        block_end = block_start + self.config.block_size

        # Get local block
        local_block = input_ids[:, block_start:block_end].to(device)

        # Get anchor block (first block)
        anchor_block = input_ids[:, : self.config.anchor_size].to(
            device
        )

        # Compute KV pairs for local block
        local_hidden = self.layer_norm(local_block)
        local_key = self._split_heads(
            self.key(local_hidden), self.config.num_attention_heads
        )
        local_value = self._split_heads(
            self.value(local_hidden), self.config.num_attention_heads
        )

        # Store in KV cache
        self.kv_cache[host_id] = {
            "key": local_key,
            "value": local_value,
            "anchor_key": (
                None
                if host_id == 0
                else self._split_heads(
                    self.key(self.layer_norm(anchor_block)),
                    self.config.num_attention_heads,
                )
            ),
        }

        logger.debug(
            f"Phase 1 complete on host {host_id}. KV cache shapes - "
            f"key: {local_key.shape}, value: {local_value.shape}"
        )

    def phase2_query_encoding(
        self,
        query_input: torch.Tensor,
        host_id: int,
        is_query_host: bool,
        device: Union[str, torch.device] = "cuda",
    ) -> Optional[torch.Tensor]:
        """
        Phase 2: Query Encoding and Output Generation

        Args:
            query_input: Query tensor of shape (batch_size, seq_len, hidden_size)
            host_id: ID of the current host
            is_query_host: Whether this host is the query host
            device: Device to run computations on

        Returns:
            Output tensor if this is the query host, None otherwise
        """
        logger.debug(f"Starting Phase 2 on host {host_id}")

        # Transform query
        query_hidden = self.layer_norm(query_input)
        query = self._split_heads(
            self.query(query_hidden), self.config.num_attention_heads
        )

        # Compute local attention scores
        local_context, local_probs = self._compute_attention_scores(
            query,
            self.kv_cache[host_id]["key"],
            self.kv_cache[host_id]["value"],
        )

        if not is_query_host:
            # Non-query hosts send their local attention statistics
            dist.send(local_probs, dst=self.config.num_hosts - 1)
            return None

        # Query host aggregates attention from all hosts
        all_attention_probs = [local_probs]
        for src_rank in range(self.config.num_hosts - 1):
            probs = torch.empty_like(local_probs)
            dist.recv(probs, src=src_rank)
            all_attention_probs.append(probs)

        # Compute global attention (note: the aggregated mean is computed
        # here but never used in the final output below)
        torch.mean(torch.stack(all_attention_probs), dim=0)

        # Final output computation
        output = self._merge_heads(local_context)
        output = self.dropout(output)

        logger.debug(
            f"Phase 2 complete on host {host_id}. Output shape: {output.shape}"
        )

        return output

    def forward(
        self,
        input_ids: torch.Tensor,
        query_input: torch.Tensor,
        host_id: int,
        is_query_host: bool,
        device: Union[str, torch.device] = "cuda",
    ) -> Optional[torch.Tensor]:
        """
        Forward pass of the StarAttention module.

        Args:
            input_ids: Input tensor of shape (batch_size, seq_len)
            query_input: Query tensor of shape (batch_size, seq_len, hidden_size)
            host_id: ID of the current host
            is_query_host: Whether this host is the query host
            device: Device to run computations on

        Returns:
            Output tensor if this is the query host, None otherwise
        """
        # Phase 1: Local Context Encoding
        self.phase1_local_context_encoding(input_ids, host_id, device)

        # Phase 2: Query Encoding and Output Generation
        return self.phase2_query_encoding(
            query_input, host_id, is_query_host, device
        )


# Example forward pass
config = StarAttentionConfig(
    hidden_size=768,
    num_attention_heads=12,
    num_hosts=3,
    block_size=512,
    anchor_size=128,
)

# Initialize model
model = StarAttention(config)

# Example input tensors
batch_size = 4
seq_len = 512
input_ids = torch.randint(
    0, 1000, (batch_size, seq_len)
)  # Random input IDs
query_input = torch.randn(
    batch_size, seq_len, config.hidden_size
)  # Random query input

# Example forward pass for query host (host_id = 2)
output = model(
    input_ids=input_ids,
    query_input=query_input,
    host_id=2,
    is_query_host=True,
    device="cpu",
)

print(output)