[AGENT][LiteLLM FIX] [API FIX]

pull/692/merge
Kye Gomez 3 weeks ago
parent a54785cb5f
commit 321000a299

@ -3,9 +3,8 @@ from loguru import logger
import time
from typing import Dict, Optional, Tuple
from uuid import UUID
import sys
BASE_URL = "http://localhost:8000/v1"
BASE_URL = "http://0.0.0.0:8000/v1"
def check_api_server() -> bool:
@ -199,6 +198,7 @@ def test_completion(session: TestSession, agent_id: UUID) -> bool:
if response.status_code == 200:
completion_data = response.json()
print(completion_data)
logger.success(
f"Got completion, used {completion_data['token_usage']['total_tokens']} tokens"
)
@ -317,4 +317,4 @@ def run_test_workflow():
if __name__ == "__main__":
success = run_test_workflow()
sys.exit(0 if success else 1)
print(success)

@ -2,7 +2,7 @@ import requests
import json
from time import sleep
BASE_URL = "http://api.swarms.ai:8000"
BASE_URL = "http://swarms-api-893767232.us-east-2.elb.amazonaws.com"
def make_request(method, endpoint, data=None):

@ -1,4 +1,3 @@
import asyncio
from typing import List
@ -26,8 +25,8 @@ async def create_specialized_agents() -> List[Agent]:
financial_agent = Agent(
agent_name="Financial-Analysis-Agent",
agent_description="Personal finance advisor agent",
system_prompt=FINANCIAL_AGENT_SYS_PROMPT +
"Output the <DONE> token when you're done creating a portfolio of etfs, index, funds, and more for AI",
system_prompt=FINANCIAL_AGENT_SYS_PROMPT
+ "Output the <DONE> token when you're done creating a portfolio of etfs, index, funds, and more for AI",
max_loops=1,
llm=model,
dynamic_temperature_enabled=True,
@ -81,6 +80,7 @@ async def create_specialized_agents() -> List[Agent]:
return [financial_agent, risk_agent, research_agent]
async def main():
# Create specialized agents
agents = await create_specialized_agents()
@ -89,7 +89,7 @@ async def main():
workflow = create_default_workflow(
agents=agents,
name="AI-Investment-Analysis-Workflow",
enable_group_chat=True
enable_group_chat=True,
)
# Configure speaker roles
@ -99,7 +99,7 @@ async def main():
agent=agents[0], # Financial agent as coordinator
priority=1,
concurrent=False,
required=True
required=True,
)
)
@ -108,7 +108,7 @@ async def main():
role=SpeakerRole.CRITIC,
agent=agents[1], # Risk agent as critic
priority=2,
concurrent=True
concurrent=True,
)
)
@ -117,7 +117,7 @@ async def main():
role=SpeakerRole.EXECUTOR,
agent=agents[2], # Research agent as executor
priority=2,
concurrent=True
concurrent=True,
)
)
@ -134,9 +134,7 @@ async def main():
try:
# Run workflow with retry
result = await run_workflow_with_retry(
workflow=workflow,
task=investment_task,
max_retries=3
workflow=workflow, task=investment_task, max_retries=3
)
print("\nWorkflow Results:")
@ -172,6 +170,7 @@ async def main():
finally:
await workflow.cleanup()
if __name__ == "__main__":
# Run the example
asyncio.run(main())

@ -93,88 +93,186 @@ print(results)
## Production-Grade Financial Example: Multiple Agents
### Example: Stock Analysis and Investment Strategy
```python
import asyncio
from swarms import Agent, AsyncWorkflow
from swarms.prompts.finance_agent_sys_prompt import FINANCIAL_AGENT_SYS_PROMPT
# Initialize multiple Financial Agents
portfolio_analysis_agent = Agent(
agent_name="Portfolio-Analysis-Agent",
system_prompt=FINANCIAL_AGENT_SYS_PROMPT,
model_name="gpt-4o-mini",
autosave=True,
verbose=True,
)
from typing import List
stock_strategy_agent = Agent(
agent_name="Stock-Strategy-Agent",
system_prompt=FINANCIAL_AGENT_SYS_PROMPT,
model_name="gpt-4o-mini",
autosave=True,
verbose=True,
from swarm_models import OpenAIChat
from swarms.structs.async_workflow import (
SpeakerConfig,
SpeakerRole,
create_default_workflow,
run_workflow_with_retry,
)
from swarms.prompts.finance_agent_sys_prompt import (
FINANCIAL_AGENT_SYS_PROMPT,
)
from swarms.structs.agent import Agent
async def create_specialized_agents() -> List[Agent]:
"""Create a set of specialized agents for financial analysis"""
# Base model configuration
model = OpenAIChat(model_name="gpt-4o")
# Financial Analysis Agent
financial_agent = Agent(
agent_name="Financial-Analysis-Agent",
agent_description="Personal finance advisor agent",
system_prompt=FINANCIAL_AGENT_SYS_PROMPT
+ "Output the <DONE> token when you're done creating a portfolio of etfs, index, funds, and more for AI",
max_loops=1,
llm=model,
dynamic_temperature_enabled=True,
user_name="Kye",
retry_attempts=3,
context_length=8192,
return_step_meta=False,
output_type="str",
auto_generate_prompt=False,
max_tokens=4000,
stopping_token="<DONE>",
saved_state_path="financial_agent.json",
interactive=False,
)
risk_management_agent = Agent(
agent_name="Risk-Management-Agent",
system_prompt=FINANCIAL_AGENT_SYS_PROMPT,
model_name="gpt-4o-mini",
autosave=True,
verbose=True,
# Risk Assessment Agent
risk_agent = Agent(
agent_name="Risk-Assessment-Agent",
agent_description="Investment risk analysis specialist",
system_prompt="Analyze investment risks and provide risk scores. Output <DONE> when analysis is complete.",
max_loops=1,
llm=model,
dynamic_temperature_enabled=True,
user_name="Kye",
retry_attempts=3,
context_length=8192,
output_type="str",
max_tokens=4000,
stopping_token="<DONE>",
saved_state_path="risk_agent.json",
interactive=False,
)
# Create a workflow with multiple agents
workflow = AsyncWorkflow(
name="Financial-Workflow",
agents=[portfolio_analysis_agent, stock_strategy_agent, risk_management_agent],
verbose=True,
# Market Research Agent
research_agent = Agent(
agent_name="Market-Research-Agent",
agent_description="AI and tech market research specialist",
system_prompt="Research AI market trends and growth opportunities. Output <DONE> when research is complete.",
max_loops=1,
llm=model,
dynamic_temperature_enabled=True,
user_name="Kye",
retry_attempts=3,
context_length=8192,
output_type="str",
max_tokens=4000,
stopping_token="<DONE>",
saved_state_path="research_agent.json",
interactive=False,
)
# Run the workflow
return [financial_agent, risk_agent, research_agent]
async def main():
task = "Analyze the current stock market trends and provide an investment strategy with risk assessment."
results = await workflow.run(task)
for agent_result in results:
print(agent_result)
# Create specialized agents
agents = await create_specialized_agents()
# Create workflow with group chat enabled
workflow = create_default_workflow(
agents=agents,
name="AI-Investment-Analysis-Workflow",
enable_group_chat=True,
)
asyncio.run(main())
```
# Configure speaker roles
workflow.speaker_system.add_speaker(
SpeakerConfig(
role=SpeakerRole.COORDINATOR,
agent=agents[0], # Financial agent as coordinator
priority=1,
concurrent=False,
required=True,
)
)
**Output**:
```
INFO: Agent Portfolio-Analysis-Agent processing task: Analyze the current stock market trends and provide an investment strategy with risk assessment.
INFO: Agent Stock-Strategy-Agent processing task: Analyze the current stock market trends and provide an investment strategy with risk assessment.
INFO: Agent Risk-Management-Agent processing task: Analyze the current stock market trends and provide an investment strategy with risk assessment.
INFO: Agent Portfolio-Analysis-Agent completed task
INFO: Agent Stock-Strategy-Agent completed task
INFO: Agent Risk-Management-Agent completed task
Results:
- Detailed portfolio analysis...
- Stock investment strategies...
- Risk assessment insights...
```
workflow.speaker_system.add_speaker(
SpeakerConfig(
role=SpeakerRole.CRITIC,
agent=agents[1], # Risk agent as critic
priority=2,
concurrent=True,
)
)
---
workflow.speaker_system.add_speaker(
SpeakerConfig(
role=SpeakerRole.EXECUTOR,
agent=agents[2], # Research agent as executor
priority=2,
concurrent=True,
)
)
## Notes
1. **Autosave**: The autosave functionality is a placeholder. Users can implement custom logic to save `self.results`.
2. **Error Handling**: Exceptions raised by agents are logged and returned as part of the results.
3. **Dashboard**: The `dashboard` feature is currently not implemented but can be extended for visualization.
# Investment analysis task
investment_task = """
Create a comprehensive investment analysis for a $40k portfolio focused on AI growth opportunities:
1. Identify high-growth AI ETFs and index funds
2. Analyze risks and potential returns
3. Create a diversified portfolio allocation
4. Provide market trend analysis
Present the results in a structured markdown format.
"""
try:
# Run workflow with retry
result = await run_workflow_with_retry(
workflow=workflow, task=investment_task, max_retries=3
)
---
print("\nWorkflow Results:")
print("================")
# Process and display agent outputs
for output in result.agent_outputs:
print(f"\nAgent: {output.agent_name}")
print("-" * (len(output.agent_name) + 8))
print(output.output)
# Display group chat history if enabled
if workflow.enable_group_chat:
print("\nGroup Chat Discussion:")
print("=====================")
for msg in workflow.speaker_system.message_history:
print(f"\n{msg.role} ({msg.agent_name}):")
print(msg.content)
# Save detailed results
if result.metadata.get("shared_memory_keys"):
print("\nShared Insights:")
print("===============")
for key in result.metadata["shared_memory_keys"]:
value = workflow.shared_memory.get(key)
if value:
print(f"\n{key}:")
print(value)
except Exception as e:
print(f"Workflow failed: {str(e)}")
finally:
await workflow.cleanup()
if __name__ == "__main__":
# Run the example
asyncio.run(main())
## Dependencies
- `asyncio`: Python's asynchronous I/O framework.
- `loguru`: Logging utility for better log management.
- `swarms`: Base components (`BaseWorkflow`, `Agent`).
---
```
## Future Extensions
- **Dashboard**: Implement a real-time dashboard for monitoring agent performance.
- **Autosave**: Add persistent storage support for task results.
- **Task Management**: Extend task pooling and scheduling logic to support dynamic workloads.
---
## License
This class is part of the `swarms` framework and follows the framework's licensing terms.

@ -2,10 +2,6 @@ from swarms import Agent
from swarms.prompts.finance_agent_sys_prompt import (
FINANCIAL_AGENT_SYS_PROMPT,
)
from swarm_models import OpenAIChat
model = OpenAIChat(model_name="gpt-4o")
# Initialize the agent
agent = Agent(
@ -14,7 +10,7 @@ agent = Agent(
system_prompt=FINANCIAL_AGENT_SYS_PROMPT
+ "Output the <DONE> token when you're done creating a portfolio of etfs, index, funds, and more for AI",
max_loops=1,
llm=model,
model_name="gpt-4o",
dynamic_temperature_enabled=True,
user_name="Kye",
retry_attempts=3,

@ -24,7 +24,7 @@ from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from pydantic import BaseModel, Field
from swarms import Agent
from swarms.structs.agent import Agent
# Load environment variables
load_dotenv()
@ -127,8 +127,8 @@ class AgentUpdate(BaseModel):
description: Optional[str] = None
system_prompt: Optional[str] = None
temperature: Optional[float] = None
max_loops: Optional[int] = None
temperature: Optional[float] = 0.5
max_loops: Optional[int] = 1
tags: Optional[List[str]] = None
status: Optional[AgentStatus] = None
@ -167,7 +167,7 @@ class CompletionRequest(BaseModel):
max_tokens: Optional[int] = Field(
None, description="Maximum tokens to generate"
)
temperature_override: Optional[float] = None
temperature_override: Optional[float] = 0.5
stream: bool = Field(
default=False, description="Enable streaming response"
)
@ -267,7 +267,7 @@ class AgentStore:
autosave=config.autosave,
dashboard=config.dashboard,
verbose=config.verbose,
dynamic_temperature_enabled=config.dynamic_temperature_enabled,
dynamic_temperature_enabled=True,
saved_state_path=f"states/{config.agent_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json",
user_name=config.user_name,
retry_attempts=config.retry_attempts,
@ -328,8 +328,6 @@ class AgentStore:
if update.system_prompt:
agent.system_prompt = update.system_prompt
if update.temperature is not None:
agent.llm.temperature = update.temperature
if update.max_loops is not None:
agent.max_loops = update.max_loops
if update.tags is not None:
@ -434,8 +432,8 @@ class AgentStore:
agent_name=new_name,
description=f"Clone of {original_agent.agent_name}",
system_prompt=original_agent.system_prompt,
model_name=original_agent.llm.model_name,
temperature=original_agent.llm.temperature,
model_name=original_agent.model_name,
temperature=0.5,
max_loops=original_agent.max_loops,
tags=original_metadata["tags"],
)
@ -476,18 +474,9 @@ class AgentStore:
metadata["status"] = AgentStatus.PROCESSING
metadata["last_used"] = start_time
# Apply temporary overrides if specified
original_temp = agent.llm.temperature
if temperature_override is not None:
agent.llm.temperature = temperature_override
# Process the completion
response = agent.run(prompt)
# Reset overrides
if temperature_override is not None:
agent.llm.temperature = original_temp
# Update metrics
processing_time = (
datetime.utcnow() - start_time
@ -518,8 +507,8 @@ class AgentStore:
response=response,
metadata={
"agent_name": agent.agent_name,
"model_name": agent.llm.model_name,
"temperature": agent.llm.temperature,
# "model_name": agent.llm.model_name,
# "temperature": 0.5,
},
timestamp=datetime.utcnow(),
processing_time=processing_time,
@ -790,7 +779,7 @@ class SwarmsAPI:
request.prompt,
request.agent_id,
request.max_tokens,
request.temperature_override,
0.5,
)
# Schedule background cleanup

@ -572,7 +572,7 @@ class Agent:
# Telemetry Processor to log agent data
threading.Thread(target=self.log_agent_data).start()
if self.llm is not None and self.model_name is not None:
if self.llm is None and self.model_name is not None:
self.llm = self.llm_handling()
def llm_handling(self):
@ -2406,20 +2406,37 @@ class Agent:
**kwargs: Arbitrary keyword arguments.
Returns:
The result of the method call on the `llm` object.
str: The result of the method call on the `llm` object.
Raises:
AttributeError: If no suitable method is found in the llm object.
TypeError: If task is not a string or llm object is None.
ValueError: If task is empty.
"""
if hasattr(self.llm, "__call__"):
return self.llm(task, *args, **kwargs)
elif hasattr(self.llm, "run"):
return self.llm.run(task, *args, **kwargs)
elif hasattr(self.llm, "generate"):
return self.llm.generate(task, *args, **kwargs)
elif hasattr(self.llm, "invoke"):
return self.llm.invoke(task, *args, **kwargs)
else:
if not isinstance(task, str):
raise TypeError("Task must be a string")
if not task.strip():
raise ValueError("Task cannot be empty")
if self.llm is None:
raise TypeError("LLM object cannot be None")
# Define common method names for LLM interfaces
method_names = ["run", "__call__", "generate", "invoke"]
for method_name in method_names:
if hasattr(self.llm, method_name):
try:
method = getattr(self.llm, method_name)
return method(task, *args, **kwargs)
except Exception as e:
raise RuntimeError(
f"Error calling {method_name}: {str(e)}"
)
raise AttributeError(
"No suitable method found in the llm object."
f"No suitable method found in the llm object. Expected one of: {method_names}"
)
def handle_sop_ops(self):

@ -12,11 +12,7 @@ from logging.handlers import RotatingFileHandler
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
from swarm_models import OpenAIChat
from swarms.prompts.finance_agent_sys_prompt import (
FINANCIAL_AGENT_SYS_PROMPT,
)
from swarms.structs.agent import Agent
from swarms.structs.base_workflow import BaseWorkflow
from swarms.utils.loguru_logger import initialize_logger
@ -24,6 +20,7 @@ from swarms.utils.loguru_logger import initialize_logger
# Base logger initialization
logger = initialize_logger("async_workflow")
# Pydantic models for structured data
class AgentOutput(BaseModel):
agent_id: str
@ -36,6 +33,7 @@ class AgentOutput(BaseModel):
status: str
error: Optional[str] = None
class WorkflowOutput(BaseModel):
workflow_id: str
workflow_name: str
@ -47,6 +45,7 @@ class WorkflowOutput(BaseModel):
agent_outputs: List[AgentOutput]
metadata: Dict[str, Any] = Field(default_factory=dict)
class SpeakerRole(str, Enum):
COORDINATOR = "coordinator"
CRITIC = "critic"
@ -54,6 +53,7 @@ class SpeakerRole(str, Enum):
VALIDATOR = "validator"
DEFAULT = "default"
class SpeakerMessage(BaseModel):
role: SpeakerRole
content: Any
@ -61,6 +61,7 @@ class SpeakerMessage(BaseModel):
agent_name: str
metadata: Dict[str, Any] = Field(default_factory=dict)
class GroupChatConfig(BaseModel):
max_turns: int = 10
timeout_per_turn: float = 30.0
@ -68,6 +69,7 @@ class GroupChatConfig(BaseModel):
allow_concurrent: bool = True
save_history: bool = True
@dataclass
class SharedMemoryItem:
key: str
@ -76,6 +78,7 @@ class SharedMemoryItem:
author: str
metadata: Dict[str, Any] = None
@dataclass
class SpeakerConfig:
role: SpeakerRole
@ -85,22 +88,30 @@ class SpeakerConfig:
timeout: float = 30.0
required: bool = False
class SharedMemory:
"""Thread-safe shared memory implementation with persistence"""
def __init__(self, persistence_path: Optional[str] = None):
self._memory = {}
self._lock = threading.Lock()
self._persistence_path = persistence_path
self._load_from_disk()
def set(self, key: str, value: Any, author: str, metadata: Dict[str, Any] = None) -> None:
def set(
self,
key: str,
value: Any,
author: str,
metadata: Dict[str, Any] = None,
) -> None:
with self._lock:
item = SharedMemoryItem(
key=key,
value=value,
timestamp=datetime.utcnow(),
author=author,
metadata=metadata or {}
metadata=metadata or {},
)
self._memory[key] = item
self._persist_to_disk()
@ -110,25 +121,33 @@ class SharedMemory:
item = self._memory.get(key)
return item.value if item else None
def get_with_metadata(self, key: str) -> Optional[SharedMemoryItem]:
def get_with_metadata(
self, key: str
) -> Optional[SharedMemoryItem]:
with self._lock:
return self._memory.get(key)
def _persist_to_disk(self) -> None:
if self._persistence_path:
with open(self._persistence_path, 'w') as f:
json.dump({k: asdict(v) for k, v in self._memory.items()}, f)
with open(self._persistence_path, "w") as f:
json.dump(
{k: asdict(v) for k, v in self._memory.items()}, f
)
def _load_from_disk(self) -> None:
if self._persistence_path and os.path.exists(self._persistence_path):
with open(self._persistence_path, 'r') as f:
if self._persistence_path and os.path.exists(
self._persistence_path
):
with open(self._persistence_path, "r") as f:
data = json.load(f)
self._memory = {
k: SharedMemoryItem(**v) for k, v in data.items()
}
class SpeakerSystem:
"""Manages speaker interactions and group chat functionality"""
def __init__(self, default_timeout: float = 30.0):
self.speakers: Dict[SpeakerRole, SpeakerConfig] = {}
self.message_history: List[SpeakerMessage] = []
@ -147,12 +166,11 @@ class SpeakerSystem:
self,
config: SpeakerConfig,
input_data: Any,
context: Dict[str, Any] = None
context: Dict[str, Any] = None,
) -> SpeakerMessage:
try:
result = await asyncio.wait_for(
config.agent.arun(input_data),
timeout=config.timeout
config.agent.arun(input_data), timeout=config.timeout
)
return SpeakerMessage(
@ -160,7 +178,7 @@ class SpeakerSystem:
content=result,
timestamp=datetime.utcnow(),
agent_name=config.agent.agent_name,
metadata={"context": context or {}}
metadata={"context": context or {}},
)
except asyncio.TimeoutError:
return SpeakerMessage(
@ -168,7 +186,7 @@ class SpeakerSystem:
content=None,
timestamp=datetime.utcnow(),
agent_name=config.agent.agent_name,
metadata={"error": "Timeout"}
metadata={"error": "Timeout"},
)
except Exception as e:
return SpeakerMessage(
@ -176,9 +194,10 @@ class SpeakerSystem:
content=None,
timestamp=datetime.utcnow(),
agent_name=config.agent.agent_name,
metadata={"error": str(e)}
metadata={"error": str(e)},
)
class AsyncWorkflow(BaseWorkflow):
"""Enhanced asynchronous workflow with advanced speaker system"""
@ -209,20 +228,26 @@ class AsyncWorkflow(BaseWorkflow):
self.shared_memory = SharedMemory(shared_memory_path)
self.speaker_system = SpeakerSystem()
self.enable_group_chat = enable_group_chat
self.group_chat_config = group_chat_config or GroupChatConfig()
self.group_chat_config = (
group_chat_config or GroupChatConfig()
)
self._setup_logging(log_path)
self.metadata = {}
def _setup_logging(self, log_path: str) -> None:
"""Configure rotating file logger"""
self.logger = logging.getLogger(f"workflow_{self.workflow_id}")
self.logger.setLevel(logging.DEBUG if self.verbose else logging.INFO)
self.logger = logging.getLogger(
f"workflow_{self.workflow_id}"
)
self.logger.setLevel(
logging.DEBUG if self.verbose else logging.INFO
)
handler = RotatingFileHandler(
log_path, maxBytes=10 * 1024 * 1024, backupCount=5
)
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
handler.setFormatter(formatter)
self.logger.addHandler(handler)
@ -235,35 +260,35 @@ class AsyncWorkflow(BaseWorkflow):
agent=agent,
concurrent=True,
timeout=30.0,
required=False
required=False,
)
self.speaker_system.add_speaker(config)
async def run_concurrent_speakers(
self,
task: str,
context: Dict[str, Any] = None
self, task: str, context: Dict[str, Any] = None
) -> List[SpeakerMessage]:
"""Run all concurrent speakers in parallel"""
concurrent_tasks = [
self.speaker_system._execute_speaker(config, task, context)
self.speaker_system._execute_speaker(
config, task, context
)
for config in self.speaker_system.speakers.values()
if config.concurrent
]
results = await asyncio.gather(*concurrent_tasks, return_exceptions=True)
results = await asyncio.gather(
*concurrent_tasks, return_exceptions=True
)
return [r for r in results if isinstance(r, SpeakerMessage)]
async def run_sequential_speakers(
self,
task: str,
context: Dict[str, Any] = None
self, task: str, context: Dict[str, Any] = None
) -> List[SpeakerMessage]:
"""Run non-concurrent speakers in sequence"""
results = []
for config in sorted(
self.speaker_system.speakers.values(),
key=lambda x: x.priority
key=lambda x: x.priority,
):
if not config.concurrent:
result = await self.speaker_system._execute_speaker(
@ -273,13 +298,13 @@ class AsyncWorkflow(BaseWorkflow):
return results
async def run_group_chat(
self,
initial_message: str,
context: Dict[str, Any] = None
self, initial_message: str, context: Dict[str, Any] = None
) -> List[SpeakerMessage]:
"""Run a group chat discussion among speakers"""
if not self.enable_group_chat:
raise ValueError("Group chat is not enabled for this workflow")
raise ValueError(
"Group chat is not enabled for this workflow"
)
messages: List[SpeakerMessage] = []
current_turn = 0
@ -288,18 +313,26 @@ class AsyncWorkflow(BaseWorkflow):
turn_context = {
"turn": current_turn,
"history": messages,
**(context or {})
**(context or {}),
}
if self.group_chat_config.allow_concurrent:
turn_messages = await self.run_concurrent_speakers(
initial_message if current_turn == 0 else messages[-1].content,
turn_context
(
initial_message
if current_turn == 0
else messages[-1].content
),
turn_context,
)
else:
turn_messages = await self.run_sequential_speakers(
initial_message if current_turn == 0 else messages[-1].content,
turn_context
(
initial_message
if current_turn == 0
else messages[-1].content
),
turn_context,
)
messages.extend(turn_messages)
@ -315,7 +348,9 @@ class AsyncWorkflow(BaseWorkflow):
return messages
def _should_end_group_chat(self, messages: List[SpeakerMessage]) -> bool:
def _should_end_group_chat(
self, messages: List[SpeakerMessage]
) -> bool:
"""Determine if group chat should end based on messages"""
if not messages:
return True
@ -324,7 +359,8 @@ class AsyncWorkflow(BaseWorkflow):
if self.group_chat_config.require_all_speakers:
participating_roles = {msg.role for msg in messages}
required_roles = {
role for role, config in self.speaker_system.speakers.items()
role
for role, config in self.speaker_system.speakers.items()
if config.required
}
if not required_roles.issubset(participating_roles):
@ -344,9 +380,7 @@ class AsyncWorkflow(BaseWorkflow):
await self._save_results(start_time, end_time)
async def _execute_agent_task(
self,
agent: Agent,
task: str
self, agent: Agent, task: str
) -> AgentOutput:
"""Execute a single agent task with enhanced error handling and monitoring"""
start_time = datetime.utcnow()
@ -372,14 +406,14 @@ class AsyncWorkflow(BaseWorkflow):
output=result,
start_time=start_time,
end_time=end_time,
status="success"
status="success",
)
except Exception as e:
end_time = datetime.utcnow()
self.logger.error(
f"Error in agent {agent.agent_name} task {task_id}: {str(e)}",
exc_info=True
exc_info=True,
)
return AgentOutput(
@ -391,7 +425,7 @@ class AsyncWorkflow(BaseWorkflow):
start_time=start_time,
end_time=end_time,
status="error",
error=str(e)
error=str(e),
)
async def run(self, task: str) -> WorkflowOutput:
@ -408,15 +442,21 @@ class AsyncWorkflow(BaseWorkflow):
if self.enable_group_chat:
speaker_outputs = await self.run_group_chat(task)
else:
concurrent_outputs = await self.run_concurrent_speakers(task)
sequential_outputs = await self.run_sequential_speakers(task)
speaker_outputs = concurrent_outputs + sequential_outputs
concurrent_outputs = (
await self.run_concurrent_speakers(task)
)
sequential_outputs = (
await self.run_sequential_speakers(task)
)
speaker_outputs = (
concurrent_outputs + sequential_outputs
)
# Store speaker outputs in shared memory
self.shared_memory.set(
"speaker_outputs",
[msg.dict() for msg in speaker_outputs],
"workflow"
"workflow",
)
# Create tasks for all agents
@ -426,13 +466,19 @@ class AsyncWorkflow(BaseWorkflow):
]
# Execute all tasks concurrently
agent_outputs = await asyncio.gather(*tasks, return_exceptions=True)
agent_outputs = await asyncio.gather(
*tasks, return_exceptions=True
)
end_time = datetime.utcnow()
# Calculate success/failure counts
successful_tasks = sum(1 for output in agent_outputs
if isinstance(output, AgentOutput) and output.status == "success")
successful_tasks = sum(
1
for output in agent_outputs
if isinstance(output, AgentOutput)
and output.status == "success"
)
failed_tasks = len(agent_outputs) - successful_tasks
return WorkflowOutput(
@ -443,22 +489,36 @@ class AsyncWorkflow(BaseWorkflow):
total_agents=len(self.agents),
successful_tasks=successful_tasks,
failed_tasks=failed_tasks,
agent_outputs=[output for output in agent_outputs
if isinstance(output, AgentOutput)],
agent_outputs=[
output
for output in agent_outputs
if isinstance(output, AgentOutput)
],
metadata={
"max_workers": self.max_workers,
"shared_memory_keys": list(self.shared_memory._memory.keys()),
"shared_memory_keys": list(
self.shared_memory._memory.keys()
),
"group_chat_enabled": self.enable_group_chat,
"total_speaker_messages": len(speaker_outputs),
"speaker_outputs": [msg.dict() for msg in speaker_outputs]
}
"total_speaker_messages": len(
speaker_outputs
),
"speaker_outputs": [
msg.dict() for msg in speaker_outputs
],
},
)
except Exception as e:
self.logger.error(f"Critical workflow error: {str(e)}", exc_info=True)
self.logger.error(
f"Critical workflow error: {str(e)}",
exc_info=True,
)
raise
async def _save_results(self, start_time: datetime, end_time: datetime) -> None:
async def _save_results(
self, start_time: datetime, end_time: datetime
) -> None:
"""Save workflow results to disk"""
if not self.autosave:
return
@ -469,38 +529,59 @@ class AsyncWorkflow(BaseWorkflow):
filename = f"{output_dir}/workflow_{self.workflow_id}_{end_time.strftime('%Y%m%d_%H%M%S')}.json"
try:
with open(filename, 'w') as f:
json.dump({
with open(filename, "w") as f:
json.dump(
{
"workflow_id": self.workflow_id,
"start_time": start_time.isoformat(),
"end_time": end_time.isoformat(),
"results": [
asdict(result) if hasattr(result, '__dict__')
else result.dict() if hasattr(result, 'dict')
(
asdict(result)
if hasattr(result, "__dict__")
else (
result.dict()
if hasattr(result, "dict")
else str(result)
)
)
for result in self.results
],
"speaker_history": [
msg.dict() for msg in self.speaker_system.message_history
msg.dict()
for msg in self.speaker_system.message_history
],
"metadata": self.metadata
}, f, default=str, indent=2)
"metadata": self.metadata,
},
f,
default=str,
indent=2,
)
self.logger.info(f"Workflow results saved to {filename}")
except Exception as e:
self.logger.error(f"Error saving workflow results: {str(e)}")
self.logger.error(
f"Error saving workflow results: {str(e)}"
)
def _validate_config(self) -> None:
"""Validate workflow configuration"""
if self.max_workers < 1:
raise ValueError("max_workers must be at least 1")
if self.enable_group_chat and not self.speaker_system.speakers:
raise ValueError("Group chat enabled but no speakers configured")
if (
self.enable_group_chat
and not self.speaker_system.speakers
):
raise ValueError(
"Group chat enabled but no speakers configured"
)
for config in self.speaker_system.speakers.values():
if config.timeout <= 0:
raise ValueError(f"Invalid timeout for speaker {config.role}")
raise ValueError(
f"Invalid timeout for speaker {config.role}"
)
async def cleanup(self) -> None:
"""Cleanup workflow resources"""
@ -513,7 +594,14 @@ class AsyncWorkflow(BaseWorkflow):
# Persist final state
if self.autosave:
end_time = datetime.utcnow()
await self._save_results(self.results[0].start_time if self.results else end_time, end_time)
await self._save_results(
(
self.results[0].start_time
if self.results
else end_time
),
end_time,
)
# Clear shared memory if configured
self.shared_memory._memory.clear()
@ -522,11 +610,12 @@ class AsyncWorkflow(BaseWorkflow):
self.logger.error(f"Error during cleanup: {str(e)}")
raise
# Utility functions for the workflow
def create_default_workflow(
agents: List[Agent],
name: str = "DefaultWorkflow",
enable_group_chat: bool = False
enable_group_chat: bool = False,
) -> AsyncWorkflow:
"""Create a workflow with default configuration"""
workflow = AsyncWorkflow(
@ -540,18 +629,19 @@ def create_default_workflow(
group_chat_config=GroupChatConfig(
max_turns=5,
allow_concurrent=True,
require_all_speakers=False
)
require_all_speakers=False,
),
)
workflow.add_default_speakers()
return workflow
async def run_workflow_with_retry(
workflow: AsyncWorkflow,
task: str,
max_retries: int = 3,
retry_delay: float = 1.0
retry_delay: float = 1.0,
) -> WorkflowOutput:
"""Run workflow with retry logic"""
for attempt in range(max_retries):

Loading…
Cancel
Save