[AGENT][LiteLLM FIX] [API FIX]

pull/696/head
Kye Gomez 3 weeks ago
parent a54785cb5f
commit 321000a299

@ -3,9 +3,8 @@ from loguru import logger
import time import time
from typing import Dict, Optional, Tuple from typing import Dict, Optional, Tuple
from uuid import UUID from uuid import UUID
import sys
BASE_URL = "http://localhost:8000/v1" BASE_URL = "http://0.0.0.0:8000/v1"
def check_api_server() -> bool: def check_api_server() -> bool:
@ -199,6 +198,7 @@ def test_completion(session: TestSession, agent_id: UUID) -> bool:
if response.status_code == 200: if response.status_code == 200:
completion_data = response.json() completion_data = response.json()
print(completion_data)
logger.success( logger.success(
f"Got completion, used {completion_data['token_usage']['total_tokens']} tokens" f"Got completion, used {completion_data['token_usage']['total_tokens']} tokens"
) )
@ -317,4 +317,4 @@ def run_test_workflow():
if __name__ == "__main__": if __name__ == "__main__":
success = run_test_workflow() success = run_test_workflow()
sys.exit(0 if success else 1) print(success)

@ -2,7 +2,7 @@ import requests
import json import json
from time import sleep from time import sleep
BASE_URL = "http://api.swarms.ai:8000" BASE_URL = "http://swarms-api-893767232.us-east-2.elb.amazonaws.com"
def make_request(method, endpoint, data=None): def make_request(method, endpoint, data=None):

@ -1,4 +1,3 @@
import asyncio import asyncio
from typing import List from typing import List
@ -26,8 +25,8 @@ async def create_specialized_agents() -> List[Agent]:
financial_agent = Agent( financial_agent = Agent(
agent_name="Financial-Analysis-Agent", agent_name="Financial-Analysis-Agent",
agent_description="Personal finance advisor agent", agent_description="Personal finance advisor agent",
system_prompt=FINANCIAL_AGENT_SYS_PROMPT + system_prompt=FINANCIAL_AGENT_SYS_PROMPT
"Output the <DONE> token when you're done creating a portfolio of etfs, index, funds, and more for AI", + "Output the <DONE> token when you're done creating a portfolio of etfs, index, funds, and more for AI",
max_loops=1, max_loops=1,
llm=model, llm=model,
dynamic_temperature_enabled=True, dynamic_temperature_enabled=True,
@ -81,6 +80,7 @@ async def create_specialized_agents() -> List[Agent]:
return [financial_agent, risk_agent, research_agent] return [financial_agent, risk_agent, research_agent]
async def main(): async def main():
# Create specialized agents # Create specialized agents
agents = await create_specialized_agents() agents = await create_specialized_agents()
@ -89,7 +89,7 @@ async def main():
workflow = create_default_workflow( workflow = create_default_workflow(
agents=agents, agents=agents,
name="AI-Investment-Analysis-Workflow", name="AI-Investment-Analysis-Workflow",
enable_group_chat=True enable_group_chat=True,
) )
# Configure speaker roles # Configure speaker roles
@ -99,7 +99,7 @@ async def main():
agent=agents[0], # Financial agent as coordinator agent=agents[0], # Financial agent as coordinator
priority=1, priority=1,
concurrent=False, concurrent=False,
required=True required=True,
) )
) )
@ -108,7 +108,7 @@ async def main():
role=SpeakerRole.CRITIC, role=SpeakerRole.CRITIC,
agent=agents[1], # Risk agent as critic agent=agents[1], # Risk agent as critic
priority=2, priority=2,
concurrent=True concurrent=True,
) )
) )
@ -117,7 +117,7 @@ async def main():
role=SpeakerRole.EXECUTOR, role=SpeakerRole.EXECUTOR,
agent=agents[2], # Research agent as executor agent=agents[2], # Research agent as executor
priority=2, priority=2,
concurrent=True concurrent=True,
) )
) )
@ -134,9 +134,7 @@ async def main():
try: try:
# Run workflow with retry # Run workflow with retry
result = await run_workflow_with_retry( result = await run_workflow_with_retry(
workflow=workflow, workflow=workflow, task=investment_task, max_retries=3
task=investment_task,
max_retries=3
) )
print("\nWorkflow Results:") print("\nWorkflow Results:")
@ -172,6 +170,7 @@ async def main():
finally: finally:
await workflow.cleanup() await workflow.cleanup()
if __name__ == "__main__": if __name__ == "__main__":
# Run the example # Run the example
asyncio.run(main()) asyncio.run(main())

@ -93,88 +93,186 @@ print(results)
## Production-Grade Financial Example: Multiple Agents ## Production-Grade Financial Example: Multiple Agents
### Example: Stock Analysis and Investment Strategy ### Example: Stock Analysis and Investment Strategy
```python ```python
import asyncio import asyncio
from swarms import Agent, AsyncWorkflow from typing import List
from swarms.prompts.finance_agent_sys_prompt import FINANCIAL_AGENT_SYS_PROMPT
# Initialize multiple Financial Agents
portfolio_analysis_agent = Agent(
agent_name="Portfolio-Analysis-Agent",
system_prompt=FINANCIAL_AGENT_SYS_PROMPT,
model_name="gpt-4o-mini",
autosave=True,
verbose=True,
)
stock_strategy_agent = Agent( from swarm_models import OpenAIChat
agent_name="Stock-Strategy-Agent",
system_prompt=FINANCIAL_AGENT_SYS_PROMPT, from swarms.structs.async_workflow import (
model_name="gpt-4o-mini", SpeakerConfig,
autosave=True, SpeakerRole,
verbose=True, create_default_workflow,
run_workflow_with_retry,
)
from swarms.prompts.finance_agent_sys_prompt import (
FINANCIAL_AGENT_SYS_PROMPT,
)
from swarms.structs.agent import Agent
async def create_specialized_agents() -> List[Agent]:
"""Create a set of specialized agents for financial analysis"""
# Base model configuration
model = OpenAIChat(model_name="gpt-4o")
# Financial Analysis Agent
financial_agent = Agent(
agent_name="Financial-Analysis-Agent",
agent_description="Personal finance advisor agent",
system_prompt=FINANCIAL_AGENT_SYS_PROMPT
+ "Output the <DONE> token when you're done creating a portfolio of etfs, index, funds, and more for AI",
max_loops=1,
llm=model,
dynamic_temperature_enabled=True,
user_name="Kye",
retry_attempts=3,
context_length=8192,
return_step_meta=False,
output_type="str",
auto_generate_prompt=False,
max_tokens=4000,
stopping_token="<DONE>",
saved_state_path="financial_agent.json",
interactive=False,
) )
risk_management_agent = Agent( # Risk Assessment Agent
agent_name="Risk-Management-Agent", risk_agent = Agent(
system_prompt=FINANCIAL_AGENT_SYS_PROMPT, agent_name="Risk-Assessment-Agent",
model_name="gpt-4o-mini", agent_description="Investment risk analysis specialist",
autosave=True, system_prompt="Analyze investment risks and provide risk scores. Output <DONE> when analysis is complete.",
verbose=True, max_loops=1,
llm=model,
dynamic_temperature_enabled=True,
user_name="Kye",
retry_attempts=3,
context_length=8192,
output_type="str",
max_tokens=4000,
stopping_token="<DONE>",
saved_state_path="risk_agent.json",
interactive=False,
) )
# Create a workflow with multiple agents # Market Research Agent
workflow = AsyncWorkflow( research_agent = Agent(
name="Financial-Workflow", agent_name="Market-Research-Agent",
agents=[portfolio_analysis_agent, stock_strategy_agent, risk_management_agent], agent_description="AI and tech market research specialist",
verbose=True, system_prompt="Research AI market trends and growth opportunities. Output <DONE> when research is complete.",
max_loops=1,
llm=model,
dynamic_temperature_enabled=True,
user_name="Kye",
retry_attempts=3,
context_length=8192,
output_type="str",
max_tokens=4000,
stopping_token="<DONE>",
saved_state_path="research_agent.json",
interactive=False,
) )
# Run the workflow return [financial_agent, risk_agent, research_agent]
async def main(): async def main():
task = "Analyze the current stock market trends and provide an investment strategy with risk assessment." # Create specialized agents
results = await workflow.run(task) agents = await create_specialized_agents()
for agent_result in results:
print(agent_result) # Create workflow with group chat enabled
workflow = create_default_workflow(
agents=agents,
name="AI-Investment-Analysis-Workflow",
enable_group_chat=True,
)
asyncio.run(main()) # Configure speaker roles
``` workflow.speaker_system.add_speaker(
SpeakerConfig(
role=SpeakerRole.COORDINATOR,
agent=agents[0], # Financial agent as coordinator
priority=1,
concurrent=False,
required=True,
)
)
**Output**: workflow.speaker_system.add_speaker(
``` SpeakerConfig(
INFO: Agent Portfolio-Analysis-Agent processing task: Analyze the current stock market trends and provide an investment strategy with risk assessment. role=SpeakerRole.CRITIC,
INFO: Agent Stock-Strategy-Agent processing task: Analyze the current stock market trends and provide an investment strategy with risk assessment. agent=agents[1], # Risk agent as critic
INFO: Agent Risk-Management-Agent processing task: Analyze the current stock market trends and provide an investment strategy with risk assessment. priority=2,
INFO: Agent Portfolio-Analysis-Agent completed task concurrent=True,
INFO: Agent Stock-Strategy-Agent completed task )
INFO: Agent Risk-Management-Agent completed task )
Results:
- Detailed portfolio analysis...
- Stock investment strategies...
- Risk assessment insights...
```
--- workflow.speaker_system.add_speaker(
SpeakerConfig(
role=SpeakerRole.EXECUTOR,
agent=agents[2], # Research agent as executor
priority=2,
concurrent=True,
)
)
## Notes # Investment analysis task
1. **Autosave**: The autosave functionality is a placeholder. Users can implement custom logic to save `self.results`. investment_task = """
2. **Error Handling**: Exceptions raised by agents are logged and returned as part of the results. Create a comprehensive investment analysis for a $40k portfolio focused on AI growth opportunities:
3. **Dashboard**: The `dashboard` feature is currently not implemented but can be extended for visualization. 1. Identify high-growth AI ETFs and index funds
2. Analyze risks and potential returns
3. Create a diversified portfolio allocation
4. Provide market trend analysis
Present the results in a structured markdown format.
"""
try:
# Run workflow with retry
result = await run_workflow_with_retry(
workflow=workflow, task=investment_task, max_retries=3
)
--- print("\nWorkflow Results:")
print("================")
# Process and display agent outputs
for output in result.agent_outputs:
print(f"\nAgent: {output.agent_name}")
print("-" * (len(output.agent_name) + 8))
print(output.output)
# Display group chat history if enabled
if workflow.enable_group_chat:
print("\nGroup Chat Discussion:")
print("=====================")
for msg in workflow.speaker_system.message_history:
print(f"\n{msg.role} ({msg.agent_name}):")
print(msg.content)
# Save detailed results
if result.metadata.get("shared_memory_keys"):
print("\nShared Insights:")
print("===============")
for key in result.metadata["shared_memory_keys"]:
value = workflow.shared_memory.get(key)
if value:
print(f"\n{key}:")
print(value)
except Exception as e:
print(f"Workflow failed: {str(e)}")
finally:
await workflow.cleanup()
if __name__ == "__main__":
# Run the example
asyncio.run(main())
## Dependencies
- `asyncio`: Python's asynchronous I/O framework.
- `loguru`: Logging utility for better log management.
- `swarms`: Base components (`BaseWorkflow`, `Agent`).
--- ```
## Future Extensions
- **Dashboard**: Implement a real-time dashboard for monitoring agent performance.
- **Autosave**: Add persistent storage support for task results.
- **Task Management**: Extend task pooling and scheduling logic to support dynamic workloads.
--- ---
## License
This class is part of the `swarms` framework and follows the framework's licensing terms.

@ -2,10 +2,6 @@ from swarms import Agent
from swarms.prompts.finance_agent_sys_prompt import ( from swarms.prompts.finance_agent_sys_prompt import (
FINANCIAL_AGENT_SYS_PROMPT, FINANCIAL_AGENT_SYS_PROMPT,
) )
from swarm_models import OpenAIChat
model = OpenAIChat(model_name="gpt-4o")
# Initialize the agent # Initialize the agent
agent = Agent( agent = Agent(
@ -14,7 +10,7 @@ agent = Agent(
system_prompt=FINANCIAL_AGENT_SYS_PROMPT system_prompt=FINANCIAL_AGENT_SYS_PROMPT
+ "Output the <DONE> token when you're done creating a portfolio of etfs, index, funds, and more for AI", + "Output the <DONE> token when you're done creating a portfolio of etfs, index, funds, and more for AI",
max_loops=1, max_loops=1,
llm=model, model_name="gpt-4o",
dynamic_temperature_enabled=True, dynamic_temperature_enabled=True,
user_name="Kye", user_name="Kye",
retry_attempts=3, retry_attempts=3,

@ -24,7 +24,7 @@ from fastapi.middleware.cors import CORSMiddleware
from loguru import logger from loguru import logger
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from swarms import Agent from swarms.structs.agent import Agent
# Load environment variables # Load environment variables
load_dotenv() load_dotenv()
@ -127,8 +127,8 @@ class AgentUpdate(BaseModel):
description: Optional[str] = None description: Optional[str] = None
system_prompt: Optional[str] = None system_prompt: Optional[str] = None
temperature: Optional[float] = None temperature: Optional[float] = 0.5
max_loops: Optional[int] = None max_loops: Optional[int] = 1
tags: Optional[List[str]] = None tags: Optional[List[str]] = None
status: Optional[AgentStatus] = None status: Optional[AgentStatus] = None
@ -167,7 +167,7 @@ class CompletionRequest(BaseModel):
max_tokens: Optional[int] = Field( max_tokens: Optional[int] = Field(
None, description="Maximum tokens to generate" None, description="Maximum tokens to generate"
) )
temperature_override: Optional[float] = None temperature_override: Optional[float] = 0.5
stream: bool = Field( stream: bool = Field(
default=False, description="Enable streaming response" default=False, description="Enable streaming response"
) )
@ -267,7 +267,7 @@ class AgentStore:
autosave=config.autosave, autosave=config.autosave,
dashboard=config.dashboard, dashboard=config.dashboard,
verbose=config.verbose, verbose=config.verbose,
dynamic_temperature_enabled=config.dynamic_temperature_enabled, dynamic_temperature_enabled=True,
saved_state_path=f"states/{config.agent_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json", saved_state_path=f"states/{config.agent_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json",
user_name=config.user_name, user_name=config.user_name,
retry_attempts=config.retry_attempts, retry_attempts=config.retry_attempts,
@ -328,8 +328,6 @@ class AgentStore:
if update.system_prompt: if update.system_prompt:
agent.system_prompt = update.system_prompt agent.system_prompt = update.system_prompt
if update.temperature is not None:
agent.llm.temperature = update.temperature
if update.max_loops is not None: if update.max_loops is not None:
agent.max_loops = update.max_loops agent.max_loops = update.max_loops
if update.tags is not None: if update.tags is not None:
@ -434,8 +432,8 @@ class AgentStore:
agent_name=new_name, agent_name=new_name,
description=f"Clone of {original_agent.agent_name}", description=f"Clone of {original_agent.agent_name}",
system_prompt=original_agent.system_prompt, system_prompt=original_agent.system_prompt,
model_name=original_agent.llm.model_name, model_name=original_agent.model_name,
temperature=original_agent.llm.temperature, temperature=0.5,
max_loops=original_agent.max_loops, max_loops=original_agent.max_loops,
tags=original_metadata["tags"], tags=original_metadata["tags"],
) )
@ -476,18 +474,9 @@ class AgentStore:
metadata["status"] = AgentStatus.PROCESSING metadata["status"] = AgentStatus.PROCESSING
metadata["last_used"] = start_time metadata["last_used"] = start_time
# Apply temporary overrides if specified
original_temp = agent.llm.temperature
if temperature_override is not None:
agent.llm.temperature = temperature_override
# Process the completion # Process the completion
response = agent.run(prompt) response = agent.run(prompt)
# Reset overrides
if temperature_override is not None:
agent.llm.temperature = original_temp
# Update metrics # Update metrics
processing_time = ( processing_time = (
datetime.utcnow() - start_time datetime.utcnow() - start_time
@ -518,8 +507,8 @@ class AgentStore:
response=response, response=response,
metadata={ metadata={
"agent_name": agent.agent_name, "agent_name": agent.agent_name,
"model_name": agent.llm.model_name, # "model_name": agent.llm.model_name,
"temperature": agent.llm.temperature, # "temperature": 0.5,
}, },
timestamp=datetime.utcnow(), timestamp=datetime.utcnow(),
processing_time=processing_time, processing_time=processing_time,
@ -790,7 +779,7 @@ class SwarmsAPI:
request.prompt, request.prompt,
request.agent_id, request.agent_id,
request.max_tokens, request.max_tokens,
request.temperature_override, 0.5,
) )
# Schedule background cleanup # Schedule background cleanup

@ -572,7 +572,7 @@ class Agent:
# Telemetry Processor to log agent data # Telemetry Processor to log agent data
threading.Thread(target=self.log_agent_data).start() threading.Thread(target=self.log_agent_data).start()
if self.llm is not None and self.model_name is not None: if self.llm is None and self.model_name is not None:
self.llm = self.llm_handling() self.llm = self.llm_handling()
def llm_handling(self): def llm_handling(self):
@ -2406,20 +2406,37 @@ class Agent:
**kwargs: Arbitrary keyword arguments. **kwargs: Arbitrary keyword arguments.
Returns: Returns:
The result of the method call on the `llm` object. str: The result of the method call on the `llm` object.
Raises:
AttributeError: If no suitable method is found in the llm object.
TypeError: If task is not a string or llm object is None.
ValueError: If task is empty.
""" """
if hasattr(self.llm, "__call__"): if not isinstance(task, str):
return self.llm(task, *args, **kwargs) raise TypeError("Task must be a string")
elif hasattr(self.llm, "run"):
return self.llm.run(task, *args, **kwargs) if not task.strip():
elif hasattr(self.llm, "generate"): raise ValueError("Task cannot be empty")
return self.llm.generate(task, *args, **kwargs)
elif hasattr(self.llm, "invoke"): if self.llm is None:
return self.llm.invoke(task, *args, **kwargs) raise TypeError("LLM object cannot be None")
else:
# Define common method names for LLM interfaces
method_names = ["run", "__call__", "generate", "invoke"]
for method_name in method_names:
if hasattr(self.llm, method_name):
try:
method = getattr(self.llm, method_name)
return method(task, *args, **kwargs)
except Exception as e:
raise RuntimeError(
f"Error calling {method_name}: {str(e)}"
)
raise AttributeError( raise AttributeError(
"No suitable method found in the llm object." f"No suitable method found in the llm object. Expected one of: {method_names}"
) )
def handle_sop_ops(self): def handle_sop_ops(self):

@ -12,11 +12,7 @@ from logging.handlers import RotatingFileHandler
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from swarm_models import OpenAIChat
from swarms.prompts.finance_agent_sys_prompt import (
FINANCIAL_AGENT_SYS_PROMPT,
)
from swarms.structs.agent import Agent from swarms.structs.agent import Agent
from swarms.structs.base_workflow import BaseWorkflow from swarms.structs.base_workflow import BaseWorkflow
from swarms.utils.loguru_logger import initialize_logger from swarms.utils.loguru_logger import initialize_logger
@ -24,6 +20,7 @@ from swarms.utils.loguru_logger import initialize_logger
# Base logger initialization # Base logger initialization
logger = initialize_logger("async_workflow") logger = initialize_logger("async_workflow")
# Pydantic models for structured data # Pydantic models for structured data
class AgentOutput(BaseModel): class AgentOutput(BaseModel):
agent_id: str agent_id: str
@ -36,6 +33,7 @@ class AgentOutput(BaseModel):
status: str status: str
error: Optional[str] = None error: Optional[str] = None
class WorkflowOutput(BaseModel): class WorkflowOutput(BaseModel):
workflow_id: str workflow_id: str
workflow_name: str workflow_name: str
@ -47,6 +45,7 @@ class WorkflowOutput(BaseModel):
agent_outputs: List[AgentOutput] agent_outputs: List[AgentOutput]
metadata: Dict[str, Any] = Field(default_factory=dict) metadata: Dict[str, Any] = Field(default_factory=dict)
class SpeakerRole(str, Enum): class SpeakerRole(str, Enum):
COORDINATOR = "coordinator" COORDINATOR = "coordinator"
CRITIC = "critic" CRITIC = "critic"
@ -54,6 +53,7 @@ class SpeakerRole(str, Enum):
VALIDATOR = "validator" VALIDATOR = "validator"
DEFAULT = "default" DEFAULT = "default"
class SpeakerMessage(BaseModel): class SpeakerMessage(BaseModel):
role: SpeakerRole role: SpeakerRole
content: Any content: Any
@ -61,6 +61,7 @@ class SpeakerMessage(BaseModel):
agent_name: str agent_name: str
metadata: Dict[str, Any] = Field(default_factory=dict) metadata: Dict[str, Any] = Field(default_factory=dict)
class GroupChatConfig(BaseModel): class GroupChatConfig(BaseModel):
max_turns: int = 10 max_turns: int = 10
timeout_per_turn: float = 30.0 timeout_per_turn: float = 30.0
@ -68,6 +69,7 @@ class GroupChatConfig(BaseModel):
allow_concurrent: bool = True allow_concurrent: bool = True
save_history: bool = True save_history: bool = True
@dataclass @dataclass
class SharedMemoryItem: class SharedMemoryItem:
key: str key: str
@ -76,6 +78,7 @@ class SharedMemoryItem:
author: str author: str
metadata: Dict[str, Any] = None metadata: Dict[str, Any] = None
@dataclass @dataclass
class SpeakerConfig: class SpeakerConfig:
role: SpeakerRole role: SpeakerRole
@ -85,22 +88,30 @@ class SpeakerConfig:
timeout: float = 30.0 timeout: float = 30.0
required: bool = False required: bool = False
class SharedMemory: class SharedMemory:
"""Thread-safe shared memory implementation with persistence""" """Thread-safe shared memory implementation with persistence"""
def __init__(self, persistence_path: Optional[str] = None): def __init__(self, persistence_path: Optional[str] = None):
self._memory = {} self._memory = {}
self._lock = threading.Lock() self._lock = threading.Lock()
self._persistence_path = persistence_path self._persistence_path = persistence_path
self._load_from_disk() self._load_from_disk()
def set(self, key: str, value: Any, author: str, metadata: Dict[str, Any] = None) -> None: def set(
self,
key: str,
value: Any,
author: str,
metadata: Dict[str, Any] = None,
) -> None:
with self._lock: with self._lock:
item = SharedMemoryItem( item = SharedMemoryItem(
key=key, key=key,
value=value, value=value,
timestamp=datetime.utcnow(), timestamp=datetime.utcnow(),
author=author, author=author,
metadata=metadata or {} metadata=metadata or {},
) )
self._memory[key] = item self._memory[key] = item
self._persist_to_disk() self._persist_to_disk()
@ -110,25 +121,33 @@ class SharedMemory:
item = self._memory.get(key) item = self._memory.get(key)
return item.value if item else None return item.value if item else None
def get_with_metadata(self, key: str) -> Optional[SharedMemoryItem]: def get_with_metadata(
self, key: str
) -> Optional[SharedMemoryItem]:
with self._lock: with self._lock:
return self._memory.get(key) return self._memory.get(key)
def _persist_to_disk(self) -> None: def _persist_to_disk(self) -> None:
if self._persistence_path: if self._persistence_path:
with open(self._persistence_path, 'w') as f: with open(self._persistence_path, "w") as f:
json.dump({k: asdict(v) for k, v in self._memory.items()}, f) json.dump(
{k: asdict(v) for k, v in self._memory.items()}, f
)
def _load_from_disk(self) -> None: def _load_from_disk(self) -> None:
if self._persistence_path and os.path.exists(self._persistence_path): if self._persistence_path and os.path.exists(
with open(self._persistence_path, 'r') as f: self._persistence_path
):
with open(self._persistence_path, "r") as f:
data = json.load(f) data = json.load(f)
self._memory = { self._memory = {
k: SharedMemoryItem(**v) for k, v in data.items() k: SharedMemoryItem(**v) for k, v in data.items()
} }
class SpeakerSystem: class SpeakerSystem:
"""Manages speaker interactions and group chat functionality""" """Manages speaker interactions and group chat functionality"""
def __init__(self, default_timeout: float = 30.0): def __init__(self, default_timeout: float = 30.0):
self.speakers: Dict[SpeakerRole, SpeakerConfig] = {} self.speakers: Dict[SpeakerRole, SpeakerConfig] = {}
self.message_history: List[SpeakerMessage] = [] self.message_history: List[SpeakerMessage] = []
@ -147,12 +166,11 @@ class SpeakerSystem:
self, self,
config: SpeakerConfig, config: SpeakerConfig,
input_data: Any, input_data: Any,
context: Dict[str, Any] = None context: Dict[str, Any] = None,
) -> SpeakerMessage: ) -> SpeakerMessage:
try: try:
result = await asyncio.wait_for( result = await asyncio.wait_for(
config.agent.arun(input_data), config.agent.arun(input_data), timeout=config.timeout
timeout=config.timeout
) )
return SpeakerMessage( return SpeakerMessage(
@ -160,7 +178,7 @@ class SpeakerSystem:
content=result, content=result,
timestamp=datetime.utcnow(), timestamp=datetime.utcnow(),
agent_name=config.agent.agent_name, agent_name=config.agent.agent_name,
metadata={"context": context or {}} metadata={"context": context or {}},
) )
except asyncio.TimeoutError: except asyncio.TimeoutError:
return SpeakerMessage( return SpeakerMessage(
@ -168,7 +186,7 @@ class SpeakerSystem:
content=None, content=None,
timestamp=datetime.utcnow(), timestamp=datetime.utcnow(),
agent_name=config.agent.agent_name, agent_name=config.agent.agent_name,
metadata={"error": "Timeout"} metadata={"error": "Timeout"},
) )
except Exception as e: except Exception as e:
return SpeakerMessage( return SpeakerMessage(
@ -176,9 +194,10 @@ class SpeakerSystem:
content=None, content=None,
timestamp=datetime.utcnow(), timestamp=datetime.utcnow(),
agent_name=config.agent.agent_name, agent_name=config.agent.agent_name,
metadata={"error": str(e)} metadata={"error": str(e)},
) )
class AsyncWorkflow(BaseWorkflow): class AsyncWorkflow(BaseWorkflow):
"""Enhanced asynchronous workflow with advanced speaker system""" """Enhanced asynchronous workflow with advanced speaker system"""
@ -209,20 +228,26 @@ class AsyncWorkflow(BaseWorkflow):
self.shared_memory = SharedMemory(shared_memory_path) self.shared_memory = SharedMemory(shared_memory_path)
self.speaker_system = SpeakerSystem() self.speaker_system = SpeakerSystem()
self.enable_group_chat = enable_group_chat self.enable_group_chat = enable_group_chat
self.group_chat_config = group_chat_config or GroupChatConfig() self.group_chat_config = (
group_chat_config or GroupChatConfig()
)
self._setup_logging(log_path) self._setup_logging(log_path)
self.metadata = {} self.metadata = {}
def _setup_logging(self, log_path: str) -> None: def _setup_logging(self, log_path: str) -> None:
"""Configure rotating file logger""" """Configure rotating file logger"""
self.logger = logging.getLogger(f"workflow_{self.workflow_id}") self.logger = logging.getLogger(
self.logger.setLevel(logging.DEBUG if self.verbose else logging.INFO) f"workflow_{self.workflow_id}"
)
self.logger.setLevel(
logging.DEBUG if self.verbose else logging.INFO
)
handler = RotatingFileHandler( handler = RotatingFileHandler(
log_path, maxBytes=10 * 1024 * 1024, backupCount=5 log_path, maxBytes=10 * 1024 * 1024, backupCount=5
) )
formatter = logging.Formatter( formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s' "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
) )
handler.setFormatter(formatter) handler.setFormatter(formatter)
self.logger.addHandler(handler) self.logger.addHandler(handler)
@ -235,35 +260,35 @@ class AsyncWorkflow(BaseWorkflow):
agent=agent, agent=agent,
concurrent=True, concurrent=True,
timeout=30.0, timeout=30.0,
required=False required=False,
) )
self.speaker_system.add_speaker(config) self.speaker_system.add_speaker(config)
async def run_concurrent_speakers( async def run_concurrent_speakers(
self, self, task: str, context: Dict[str, Any] = None
task: str,
context: Dict[str, Any] = None
) -> List[SpeakerMessage]: ) -> List[SpeakerMessage]:
"""Run all concurrent speakers in parallel""" """Run all concurrent speakers in parallel"""
concurrent_tasks = [ concurrent_tasks = [
self.speaker_system._execute_speaker(config, task, context) self.speaker_system._execute_speaker(
config, task, context
)
for config in self.speaker_system.speakers.values() for config in self.speaker_system.speakers.values()
if config.concurrent if config.concurrent
] ]
results = await asyncio.gather(*concurrent_tasks, return_exceptions=True) results = await asyncio.gather(
*concurrent_tasks, return_exceptions=True
)
return [r for r in results if isinstance(r, SpeakerMessage)] return [r for r in results if isinstance(r, SpeakerMessage)]
async def run_sequential_speakers( async def run_sequential_speakers(
self, self, task: str, context: Dict[str, Any] = None
task: str,
context: Dict[str, Any] = None
) -> List[SpeakerMessage]: ) -> List[SpeakerMessage]:
"""Run non-concurrent speakers in sequence""" """Run non-concurrent speakers in sequence"""
results = [] results = []
for config in sorted( for config in sorted(
self.speaker_system.speakers.values(), self.speaker_system.speakers.values(),
key=lambda x: x.priority key=lambda x: x.priority,
): ):
if not config.concurrent: if not config.concurrent:
result = await self.speaker_system._execute_speaker( result = await self.speaker_system._execute_speaker(
@ -273,13 +298,13 @@ class AsyncWorkflow(BaseWorkflow):
return results return results
async def run_group_chat( async def run_group_chat(
self, self, initial_message: str, context: Dict[str, Any] = None
initial_message: str,
context: Dict[str, Any] = None
) -> List[SpeakerMessage]: ) -> List[SpeakerMessage]:
"""Run a group chat discussion among speakers""" """Run a group chat discussion among speakers"""
if not self.enable_group_chat: if not self.enable_group_chat:
raise ValueError("Group chat is not enabled for this workflow") raise ValueError(
"Group chat is not enabled for this workflow"
)
messages: List[SpeakerMessage] = [] messages: List[SpeakerMessage] = []
current_turn = 0 current_turn = 0
@ -288,18 +313,26 @@ class AsyncWorkflow(BaseWorkflow):
turn_context = { turn_context = {
"turn": current_turn, "turn": current_turn,
"history": messages, "history": messages,
**(context or {}) **(context or {}),
} }
if self.group_chat_config.allow_concurrent: if self.group_chat_config.allow_concurrent:
turn_messages = await self.run_concurrent_speakers( turn_messages = await self.run_concurrent_speakers(
initial_message if current_turn == 0 else messages[-1].content, (
turn_context initial_message
if current_turn == 0
else messages[-1].content
),
turn_context,
) )
else: else:
turn_messages = await self.run_sequential_speakers( turn_messages = await self.run_sequential_speakers(
initial_message if current_turn == 0 else messages[-1].content, (
turn_context initial_message
if current_turn == 0
else messages[-1].content
),
turn_context,
) )
messages.extend(turn_messages) messages.extend(turn_messages)
@ -315,7 +348,9 @@ class AsyncWorkflow(BaseWorkflow):
return messages return messages
def _should_end_group_chat(self, messages: List[SpeakerMessage]) -> bool: def _should_end_group_chat(
self, messages: List[SpeakerMessage]
) -> bool:
"""Determine if group chat should end based on messages""" """Determine if group chat should end based on messages"""
if not messages: if not messages:
return True return True
@ -324,7 +359,8 @@ class AsyncWorkflow(BaseWorkflow):
if self.group_chat_config.require_all_speakers: if self.group_chat_config.require_all_speakers:
participating_roles = {msg.role for msg in messages} participating_roles = {msg.role for msg in messages}
required_roles = { required_roles = {
role for role, config in self.speaker_system.speakers.items() role
for role, config in self.speaker_system.speakers.items()
if config.required if config.required
} }
if not required_roles.issubset(participating_roles): if not required_roles.issubset(participating_roles):
@ -344,9 +380,7 @@ class AsyncWorkflow(BaseWorkflow):
await self._save_results(start_time, end_time) await self._save_results(start_time, end_time)
async def _execute_agent_task( async def _execute_agent_task(
self, self, agent: Agent, task: str
agent: Agent,
task: str
) -> AgentOutput: ) -> AgentOutput:
"""Execute a single agent task with enhanced error handling and monitoring""" """Execute a single agent task with enhanced error handling and monitoring"""
start_time = datetime.utcnow() start_time = datetime.utcnow()
@ -372,14 +406,14 @@ class AsyncWorkflow(BaseWorkflow):
output=result, output=result,
start_time=start_time, start_time=start_time,
end_time=end_time, end_time=end_time,
status="success" status="success",
) )
except Exception as e: except Exception as e:
end_time = datetime.utcnow() end_time = datetime.utcnow()
self.logger.error( self.logger.error(
f"Error in agent {agent.agent_name} task {task_id}: {str(e)}", f"Error in agent {agent.agent_name} task {task_id}: {str(e)}",
exc_info=True exc_info=True,
) )
return AgentOutput( return AgentOutput(
@ -391,7 +425,7 @@ class AsyncWorkflow(BaseWorkflow):
start_time=start_time, start_time=start_time,
end_time=end_time, end_time=end_time,
status="error", status="error",
error=str(e) error=str(e),
) )
async def run(self, task: str) -> WorkflowOutput: async def run(self, task: str) -> WorkflowOutput:
@ -408,15 +442,21 @@ class AsyncWorkflow(BaseWorkflow):
if self.enable_group_chat: if self.enable_group_chat:
speaker_outputs = await self.run_group_chat(task) speaker_outputs = await self.run_group_chat(task)
else: else:
concurrent_outputs = await self.run_concurrent_speakers(task) concurrent_outputs = (
sequential_outputs = await self.run_sequential_speakers(task) await self.run_concurrent_speakers(task)
speaker_outputs = concurrent_outputs + sequential_outputs )
sequential_outputs = (
await self.run_sequential_speakers(task)
)
speaker_outputs = (
concurrent_outputs + sequential_outputs
)
# Store speaker outputs in shared memory # Store speaker outputs in shared memory
self.shared_memory.set( self.shared_memory.set(
"speaker_outputs", "speaker_outputs",
[msg.dict() for msg in speaker_outputs], [msg.dict() for msg in speaker_outputs],
"workflow" "workflow",
) )
# Create tasks for all agents # Create tasks for all agents
@ -426,13 +466,19 @@ class AsyncWorkflow(BaseWorkflow):
] ]
# Execute all tasks concurrently # Execute all tasks concurrently
agent_outputs = await asyncio.gather(*tasks, return_exceptions=True) agent_outputs = await asyncio.gather(
*tasks, return_exceptions=True
)
end_time = datetime.utcnow() end_time = datetime.utcnow()
# Calculate success/failure counts # Calculate success/failure counts
successful_tasks = sum(1 for output in agent_outputs successful_tasks = sum(
if isinstance(output, AgentOutput) and output.status == "success") 1
for output in agent_outputs
if isinstance(output, AgentOutput)
and output.status == "success"
)
failed_tasks = len(agent_outputs) - successful_tasks failed_tasks = len(agent_outputs) - successful_tasks
return WorkflowOutput( return WorkflowOutput(
@ -443,22 +489,36 @@ class AsyncWorkflow(BaseWorkflow):
total_agents=len(self.agents), total_agents=len(self.agents),
successful_tasks=successful_tasks, successful_tasks=successful_tasks,
failed_tasks=failed_tasks, failed_tasks=failed_tasks,
agent_outputs=[output for output in agent_outputs agent_outputs=[
if isinstance(output, AgentOutput)], output
for output in agent_outputs
if isinstance(output, AgentOutput)
],
metadata={ metadata={
"max_workers": self.max_workers, "max_workers": self.max_workers,
"shared_memory_keys": list(self.shared_memory._memory.keys()), "shared_memory_keys": list(
self.shared_memory._memory.keys()
),
"group_chat_enabled": self.enable_group_chat, "group_chat_enabled": self.enable_group_chat,
"total_speaker_messages": len(speaker_outputs), "total_speaker_messages": len(
"speaker_outputs": [msg.dict() for msg in speaker_outputs] speaker_outputs
} ),
"speaker_outputs": [
msg.dict() for msg in speaker_outputs
],
},
) )
except Exception as e: except Exception as e:
self.logger.error(f"Critical workflow error: {str(e)}", exc_info=True) self.logger.error(
f"Critical workflow error: {str(e)}",
exc_info=True,
)
raise raise
async def _save_results(self, start_time: datetime, end_time: datetime) -> None: async def _save_results(
self, start_time: datetime, end_time: datetime
) -> None:
"""Save workflow results to disk""" """Save workflow results to disk"""
if not self.autosave: if not self.autosave:
return return
@ -469,38 +529,59 @@ class AsyncWorkflow(BaseWorkflow):
filename = f"{output_dir}/workflow_{self.workflow_id}_{end_time.strftime('%Y%m%d_%H%M%S')}.json" filename = f"{output_dir}/workflow_{self.workflow_id}_{end_time.strftime('%Y%m%d_%H%M%S')}.json"
try: try:
with open(filename, 'w') as f: with open(filename, "w") as f:
json.dump({ json.dump(
{
"workflow_id": self.workflow_id, "workflow_id": self.workflow_id,
"start_time": start_time.isoformat(), "start_time": start_time.isoformat(),
"end_time": end_time.isoformat(), "end_time": end_time.isoformat(),
"results": [ "results": [
asdict(result) if hasattr(result, '__dict__') (
else result.dict() if hasattr(result, 'dict') asdict(result)
if hasattr(result, "__dict__")
else (
result.dict()
if hasattr(result, "dict")
else str(result) else str(result)
)
)
for result in self.results for result in self.results
], ],
"speaker_history": [ "speaker_history": [
msg.dict() for msg in self.speaker_system.message_history msg.dict()
for msg in self.speaker_system.message_history
], ],
"metadata": self.metadata "metadata": self.metadata,
}, f, default=str, indent=2) },
f,
default=str,
indent=2,
)
self.logger.info(f"Workflow results saved to {filename}") self.logger.info(f"Workflow results saved to {filename}")
except Exception as e: except Exception as e:
self.logger.error(f"Error saving workflow results: {str(e)}") self.logger.error(
f"Error saving workflow results: {str(e)}"
)
def _validate_config(self) -> None: def _validate_config(self) -> None:
"""Validate workflow configuration""" """Validate workflow configuration"""
if self.max_workers < 1: if self.max_workers < 1:
raise ValueError("max_workers must be at least 1") raise ValueError("max_workers must be at least 1")
if self.enable_group_chat and not self.speaker_system.speakers: if (
raise ValueError("Group chat enabled but no speakers configured") self.enable_group_chat
and not self.speaker_system.speakers
):
raise ValueError(
"Group chat enabled but no speakers configured"
)
for config in self.speaker_system.speakers.values(): for config in self.speaker_system.speakers.values():
if config.timeout <= 0: if config.timeout <= 0:
raise ValueError(f"Invalid timeout for speaker {config.role}") raise ValueError(
f"Invalid timeout for speaker {config.role}"
)
async def cleanup(self) -> None: async def cleanup(self) -> None:
"""Cleanup workflow resources""" """Cleanup workflow resources"""
@ -513,7 +594,14 @@ class AsyncWorkflow(BaseWorkflow):
# Persist final state # Persist final state
if self.autosave: if self.autosave:
end_time = datetime.utcnow() end_time = datetime.utcnow()
await self._save_results(self.results[0].start_time if self.results else end_time, end_time) await self._save_results(
(
self.results[0].start_time
if self.results
else end_time
),
end_time,
)
# Clear shared memory if configured # Clear shared memory if configured
self.shared_memory._memory.clear() self.shared_memory._memory.clear()
@ -522,11 +610,12 @@ class AsyncWorkflow(BaseWorkflow):
self.logger.error(f"Error during cleanup: {str(e)}") self.logger.error(f"Error during cleanup: {str(e)}")
raise raise
# Utility functions for the workflow # Utility functions for the workflow
def create_default_workflow( def create_default_workflow(
agents: List[Agent], agents: List[Agent],
name: str = "DefaultWorkflow", name: str = "DefaultWorkflow",
enable_group_chat: bool = False enable_group_chat: bool = False,
) -> AsyncWorkflow: ) -> AsyncWorkflow:
"""Create a workflow with default configuration""" """Create a workflow with default configuration"""
workflow = AsyncWorkflow( workflow = AsyncWorkflow(
@ -540,18 +629,19 @@ def create_default_workflow(
group_chat_config=GroupChatConfig( group_chat_config=GroupChatConfig(
max_turns=5, max_turns=5,
allow_concurrent=True, allow_concurrent=True,
require_all_speakers=False require_all_speakers=False,
) ),
) )
workflow.add_default_speakers() workflow.add_default_speakers()
return workflow return workflow
async def run_workflow_with_retry( async def run_workflow_with_retry(
workflow: AsyncWorkflow, workflow: AsyncWorkflow,
task: str, task: str,
max_retries: int = 3, max_retries: int = 3,
retry_delay: float = 1.0 retry_delay: float = 1.0,
) -> WorkflowOutput: ) -> WorkflowOutput:
"""Run workflow with retry logic""" """Run workflow with retry logic"""
for attempt in range(max_retries): for attempt in range(max_retries):

Loading…
Cancel
Save