|
|
@ -12,11 +12,7 @@ from logging.handlers import RotatingFileHandler
|
|
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
|
|
|
|
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
from swarm_models import OpenAIChat
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from swarms.prompts.finance_agent_sys_prompt import (
|
|
|
|
|
|
|
|
FINANCIAL_AGENT_SYS_PROMPT,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
from swarms.structs.agent import Agent
|
|
|
|
from swarms.structs.agent import Agent
|
|
|
|
from swarms.structs.base_workflow import BaseWorkflow
|
|
|
|
from swarms.structs.base_workflow import BaseWorkflow
|
|
|
|
from swarms.utils.loguru_logger import initialize_logger
|
|
|
|
from swarms.utils.loguru_logger import initialize_logger
|
|
|
@ -24,6 +20,7 @@ from swarms.utils.loguru_logger import initialize_logger
|
|
|
|
# Base logger initialization
|
|
|
|
# Base logger initialization
|
|
|
|
logger = initialize_logger("async_workflow")
|
|
|
|
logger = initialize_logger("async_workflow")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Pydantic models for structured data
|
|
|
|
# Pydantic models for structured data
|
|
|
|
class AgentOutput(BaseModel):
|
|
|
|
class AgentOutput(BaseModel):
|
|
|
|
agent_id: str
|
|
|
|
agent_id: str
|
|
|
@ -36,6 +33,7 @@ class AgentOutput(BaseModel):
|
|
|
|
status: str
|
|
|
|
status: str
|
|
|
|
error: Optional[str] = None
|
|
|
|
error: Optional[str] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class WorkflowOutput(BaseModel):
|
|
|
|
class WorkflowOutput(BaseModel):
|
|
|
|
workflow_id: str
|
|
|
|
workflow_id: str
|
|
|
|
workflow_name: str
|
|
|
|
workflow_name: str
|
|
|
@ -47,6 +45,7 @@ class WorkflowOutput(BaseModel):
|
|
|
|
agent_outputs: List[AgentOutput]
|
|
|
|
agent_outputs: List[AgentOutput]
|
|
|
|
metadata: Dict[str, Any] = Field(default_factory=dict)
|
|
|
|
metadata: Dict[str, Any] = Field(default_factory=dict)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SpeakerRole(str, Enum):
|
|
|
|
class SpeakerRole(str, Enum):
|
|
|
|
COORDINATOR = "coordinator"
|
|
|
|
COORDINATOR = "coordinator"
|
|
|
|
CRITIC = "critic"
|
|
|
|
CRITIC = "critic"
|
|
|
@ -54,6 +53,7 @@ class SpeakerRole(str, Enum):
|
|
|
|
VALIDATOR = "validator"
|
|
|
|
VALIDATOR = "validator"
|
|
|
|
DEFAULT = "default"
|
|
|
|
DEFAULT = "default"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SpeakerMessage(BaseModel):
|
|
|
|
class SpeakerMessage(BaseModel):
|
|
|
|
role: SpeakerRole
|
|
|
|
role: SpeakerRole
|
|
|
|
content: Any
|
|
|
|
content: Any
|
|
|
@ -61,6 +61,7 @@ class SpeakerMessage(BaseModel):
|
|
|
|
agent_name: str
|
|
|
|
agent_name: str
|
|
|
|
metadata: Dict[str, Any] = Field(default_factory=dict)
|
|
|
|
metadata: Dict[str, Any] = Field(default_factory=dict)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GroupChatConfig(BaseModel):
|
|
|
|
class GroupChatConfig(BaseModel):
|
|
|
|
max_turns: int = 10
|
|
|
|
max_turns: int = 10
|
|
|
|
timeout_per_turn: float = 30.0
|
|
|
|
timeout_per_turn: float = 30.0
|
|
|
@ -68,6 +69,7 @@ class GroupChatConfig(BaseModel):
|
|
|
|
allow_concurrent: bool = True
|
|
|
|
allow_concurrent: bool = True
|
|
|
|
save_history: bool = True
|
|
|
|
save_history: bool = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
|
@dataclass
|
|
|
|
class SharedMemoryItem:
|
|
|
|
class SharedMemoryItem:
|
|
|
|
key: str
|
|
|
|
key: str
|
|
|
@ -76,6 +78,7 @@ class SharedMemoryItem:
|
|
|
|
author: str
|
|
|
|
author: str
|
|
|
|
metadata: Dict[str, Any] = None
|
|
|
|
metadata: Dict[str, Any] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
|
@dataclass
|
|
|
|
class SpeakerConfig:
|
|
|
|
class SpeakerConfig:
|
|
|
|
role: SpeakerRole
|
|
|
|
role: SpeakerRole
|
|
|
@ -85,22 +88,30 @@ class SpeakerConfig:
|
|
|
|
timeout: float = 30.0
|
|
|
|
timeout: float = 30.0
|
|
|
|
required: bool = False
|
|
|
|
required: bool = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SharedMemory:
|
|
|
|
class SharedMemory:
|
|
|
|
"""Thread-safe shared memory implementation with persistence"""
|
|
|
|
"""Thread-safe shared memory implementation with persistence"""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, persistence_path: Optional[str] = None):
|
|
|
|
def __init__(self, persistence_path: Optional[str] = None):
|
|
|
|
self._memory = {}
|
|
|
|
self._memory = {}
|
|
|
|
self._lock = threading.Lock()
|
|
|
|
self._lock = threading.Lock()
|
|
|
|
self._persistence_path = persistence_path
|
|
|
|
self._persistence_path = persistence_path
|
|
|
|
self._load_from_disk()
|
|
|
|
self._load_from_disk()
|
|
|
|
|
|
|
|
|
|
|
|
def set(self, key: str, value: Any, author: str, metadata: Dict[str, Any] = None) -> None:
|
|
|
|
def set(
|
|
|
|
|
|
|
|
self,
|
|
|
|
|
|
|
|
key: str,
|
|
|
|
|
|
|
|
value: Any,
|
|
|
|
|
|
|
|
author: str,
|
|
|
|
|
|
|
|
metadata: Dict[str, Any] = None,
|
|
|
|
|
|
|
|
) -> None:
|
|
|
|
with self._lock:
|
|
|
|
with self._lock:
|
|
|
|
item = SharedMemoryItem(
|
|
|
|
item = SharedMemoryItem(
|
|
|
|
key=key,
|
|
|
|
key=key,
|
|
|
|
value=value,
|
|
|
|
value=value,
|
|
|
|
timestamp=datetime.utcnow(),
|
|
|
|
timestamp=datetime.utcnow(),
|
|
|
|
author=author,
|
|
|
|
author=author,
|
|
|
|
metadata=metadata or {}
|
|
|
|
metadata=metadata or {},
|
|
|
|
)
|
|
|
|
)
|
|
|
|
self._memory[key] = item
|
|
|
|
self._memory[key] = item
|
|
|
|
self._persist_to_disk()
|
|
|
|
self._persist_to_disk()
|
|
|
@ -110,25 +121,33 @@ class SharedMemory:
|
|
|
|
item = self._memory.get(key)
|
|
|
|
item = self._memory.get(key)
|
|
|
|
return item.value if item else None
|
|
|
|
return item.value if item else None
|
|
|
|
|
|
|
|
|
|
|
|
def get_with_metadata(self, key: str) -> Optional[SharedMemoryItem]:
|
|
|
|
def get_with_metadata(
|
|
|
|
|
|
|
|
self, key: str
|
|
|
|
|
|
|
|
) -> Optional[SharedMemoryItem]:
|
|
|
|
with self._lock:
|
|
|
|
with self._lock:
|
|
|
|
return self._memory.get(key)
|
|
|
|
return self._memory.get(key)
|
|
|
|
|
|
|
|
|
|
|
|
def _persist_to_disk(self) -> None:
|
|
|
|
def _persist_to_disk(self) -> None:
|
|
|
|
if self._persistence_path:
|
|
|
|
if self._persistence_path:
|
|
|
|
with open(self._persistence_path, 'w') as f:
|
|
|
|
with open(self._persistence_path, "w") as f:
|
|
|
|
json.dump({k: asdict(v) for k, v in self._memory.items()}, f)
|
|
|
|
json.dump(
|
|
|
|
|
|
|
|
{k: asdict(v) for k, v in self._memory.items()}, f
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
def _load_from_disk(self) -> None:
|
|
|
|
def _load_from_disk(self) -> None:
|
|
|
|
if self._persistence_path and os.path.exists(self._persistence_path):
|
|
|
|
if self._persistence_path and os.path.exists(
|
|
|
|
with open(self._persistence_path, 'r') as f:
|
|
|
|
self._persistence_path
|
|
|
|
|
|
|
|
):
|
|
|
|
|
|
|
|
with open(self._persistence_path, "r") as f:
|
|
|
|
data = json.load(f)
|
|
|
|
data = json.load(f)
|
|
|
|
self._memory = {
|
|
|
|
self._memory = {
|
|
|
|
k: SharedMemoryItem(**v) for k, v in data.items()
|
|
|
|
k: SharedMemoryItem(**v) for k, v in data.items()
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SpeakerSystem:
|
|
|
|
class SpeakerSystem:
|
|
|
|
"""Manages speaker interactions and group chat functionality"""
|
|
|
|
"""Manages speaker interactions and group chat functionality"""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, default_timeout: float = 30.0):
|
|
|
|
def __init__(self, default_timeout: float = 30.0):
|
|
|
|
self.speakers: Dict[SpeakerRole, SpeakerConfig] = {}
|
|
|
|
self.speakers: Dict[SpeakerRole, SpeakerConfig] = {}
|
|
|
|
self.message_history: List[SpeakerMessage] = []
|
|
|
|
self.message_history: List[SpeakerMessage] = []
|
|
|
@ -147,12 +166,11 @@ class SpeakerSystem:
|
|
|
|
self,
|
|
|
|
self,
|
|
|
|
config: SpeakerConfig,
|
|
|
|
config: SpeakerConfig,
|
|
|
|
input_data: Any,
|
|
|
|
input_data: Any,
|
|
|
|
context: Dict[str, Any] = None
|
|
|
|
context: Dict[str, Any] = None,
|
|
|
|
) -> SpeakerMessage:
|
|
|
|
) -> SpeakerMessage:
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
result = await asyncio.wait_for(
|
|
|
|
result = await asyncio.wait_for(
|
|
|
|
config.agent.arun(input_data),
|
|
|
|
config.agent.arun(input_data), timeout=config.timeout
|
|
|
|
timeout=config.timeout
|
|
|
|
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
return SpeakerMessage(
|
|
|
|
return SpeakerMessage(
|
|
|
@ -160,7 +178,7 @@ class SpeakerSystem:
|
|
|
|
content=result,
|
|
|
|
content=result,
|
|
|
|
timestamp=datetime.utcnow(),
|
|
|
|
timestamp=datetime.utcnow(),
|
|
|
|
agent_name=config.agent.agent_name,
|
|
|
|
agent_name=config.agent.agent_name,
|
|
|
|
metadata={"context": context or {}}
|
|
|
|
metadata={"context": context or {}},
|
|
|
|
)
|
|
|
|
)
|
|
|
|
except asyncio.TimeoutError:
|
|
|
|
except asyncio.TimeoutError:
|
|
|
|
return SpeakerMessage(
|
|
|
|
return SpeakerMessage(
|
|
|
@ -168,7 +186,7 @@ class SpeakerSystem:
|
|
|
|
content=None,
|
|
|
|
content=None,
|
|
|
|
timestamp=datetime.utcnow(),
|
|
|
|
timestamp=datetime.utcnow(),
|
|
|
|
agent_name=config.agent.agent_name,
|
|
|
|
agent_name=config.agent.agent_name,
|
|
|
|
metadata={"error": "Timeout"}
|
|
|
|
metadata={"error": "Timeout"},
|
|
|
|
)
|
|
|
|
)
|
|
|
|
except Exception as e:
|
|
|
|
except Exception as e:
|
|
|
|
return SpeakerMessage(
|
|
|
|
return SpeakerMessage(
|
|
|
@ -176,9 +194,10 @@ class SpeakerSystem:
|
|
|
|
content=None,
|
|
|
|
content=None,
|
|
|
|
timestamp=datetime.utcnow(),
|
|
|
|
timestamp=datetime.utcnow(),
|
|
|
|
agent_name=config.agent.agent_name,
|
|
|
|
agent_name=config.agent.agent_name,
|
|
|
|
metadata={"error": str(e)}
|
|
|
|
metadata={"error": str(e)},
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AsyncWorkflow(BaseWorkflow):
|
|
|
|
class AsyncWorkflow(BaseWorkflow):
|
|
|
|
"""Enhanced asynchronous workflow with advanced speaker system"""
|
|
|
|
"""Enhanced asynchronous workflow with advanced speaker system"""
|
|
|
|
|
|
|
|
|
|
|
@ -209,20 +228,26 @@ class AsyncWorkflow(BaseWorkflow):
|
|
|
|
self.shared_memory = SharedMemory(shared_memory_path)
|
|
|
|
self.shared_memory = SharedMemory(shared_memory_path)
|
|
|
|
self.speaker_system = SpeakerSystem()
|
|
|
|
self.speaker_system = SpeakerSystem()
|
|
|
|
self.enable_group_chat = enable_group_chat
|
|
|
|
self.enable_group_chat = enable_group_chat
|
|
|
|
self.group_chat_config = group_chat_config or GroupChatConfig()
|
|
|
|
self.group_chat_config = (
|
|
|
|
|
|
|
|
group_chat_config or GroupChatConfig()
|
|
|
|
|
|
|
|
)
|
|
|
|
self._setup_logging(log_path)
|
|
|
|
self._setup_logging(log_path)
|
|
|
|
self.metadata = {}
|
|
|
|
self.metadata = {}
|
|
|
|
|
|
|
|
|
|
|
|
def _setup_logging(self, log_path: str) -> None:
|
|
|
|
def _setup_logging(self, log_path: str) -> None:
|
|
|
|
"""Configure rotating file logger"""
|
|
|
|
"""Configure rotating file logger"""
|
|
|
|
self.logger = logging.getLogger(f"workflow_{self.workflow_id}")
|
|
|
|
self.logger = logging.getLogger(
|
|
|
|
self.logger.setLevel(logging.DEBUG if self.verbose else logging.INFO)
|
|
|
|
f"workflow_{self.workflow_id}"
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
self.logger.setLevel(
|
|
|
|
|
|
|
|
logging.DEBUG if self.verbose else logging.INFO
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
handler = RotatingFileHandler(
|
|
|
|
handler = RotatingFileHandler(
|
|
|
|
log_path, maxBytes=10 * 1024 * 1024, backupCount=5
|
|
|
|
log_path, maxBytes=10 * 1024 * 1024, backupCount=5
|
|
|
|
)
|
|
|
|
)
|
|
|
|
formatter = logging.Formatter(
|
|
|
|
formatter = logging.Formatter(
|
|
|
|
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
|
|
|
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
|
|
|
)
|
|
|
|
)
|
|
|
|
handler.setFormatter(formatter)
|
|
|
|
handler.setFormatter(formatter)
|
|
|
|
self.logger.addHandler(handler)
|
|
|
|
self.logger.addHandler(handler)
|
|
|
@ -235,35 +260,35 @@ class AsyncWorkflow(BaseWorkflow):
|
|
|
|
agent=agent,
|
|
|
|
agent=agent,
|
|
|
|
concurrent=True,
|
|
|
|
concurrent=True,
|
|
|
|
timeout=30.0,
|
|
|
|
timeout=30.0,
|
|
|
|
required=False
|
|
|
|
required=False,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
self.speaker_system.add_speaker(config)
|
|
|
|
self.speaker_system.add_speaker(config)
|
|
|
|
|
|
|
|
|
|
|
|
async def run_concurrent_speakers(
|
|
|
|
async def run_concurrent_speakers(
|
|
|
|
self,
|
|
|
|
self, task: str, context: Dict[str, Any] = None
|
|
|
|
task: str,
|
|
|
|
|
|
|
|
context: Dict[str, Any] = None
|
|
|
|
|
|
|
|
) -> List[SpeakerMessage]:
|
|
|
|
) -> List[SpeakerMessage]:
|
|
|
|
"""Run all concurrent speakers in parallel"""
|
|
|
|
"""Run all concurrent speakers in parallel"""
|
|
|
|
concurrent_tasks = [
|
|
|
|
concurrent_tasks = [
|
|
|
|
self.speaker_system._execute_speaker(config, task, context)
|
|
|
|
self.speaker_system._execute_speaker(
|
|
|
|
|
|
|
|
config, task, context
|
|
|
|
|
|
|
|
)
|
|
|
|
for config in self.speaker_system.speakers.values()
|
|
|
|
for config in self.speaker_system.speakers.values()
|
|
|
|
if config.concurrent
|
|
|
|
if config.concurrent
|
|
|
|
]
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
results = await asyncio.gather(*concurrent_tasks, return_exceptions=True)
|
|
|
|
results = await asyncio.gather(
|
|
|
|
|
|
|
|
*concurrent_tasks, return_exceptions=True
|
|
|
|
|
|
|
|
)
|
|
|
|
return [r for r in results if isinstance(r, SpeakerMessage)]
|
|
|
|
return [r for r in results if isinstance(r, SpeakerMessage)]
|
|
|
|
|
|
|
|
|
|
|
|
async def run_sequential_speakers(
|
|
|
|
async def run_sequential_speakers(
|
|
|
|
self,
|
|
|
|
self, task: str, context: Dict[str, Any] = None
|
|
|
|
task: str,
|
|
|
|
|
|
|
|
context: Dict[str, Any] = None
|
|
|
|
|
|
|
|
) -> List[SpeakerMessage]:
|
|
|
|
) -> List[SpeakerMessage]:
|
|
|
|
"""Run non-concurrent speakers in sequence"""
|
|
|
|
"""Run non-concurrent speakers in sequence"""
|
|
|
|
results = []
|
|
|
|
results = []
|
|
|
|
for config in sorted(
|
|
|
|
for config in sorted(
|
|
|
|
self.speaker_system.speakers.values(),
|
|
|
|
self.speaker_system.speakers.values(),
|
|
|
|
key=lambda x: x.priority
|
|
|
|
key=lambda x: x.priority,
|
|
|
|
):
|
|
|
|
):
|
|
|
|
if not config.concurrent:
|
|
|
|
if not config.concurrent:
|
|
|
|
result = await self.speaker_system._execute_speaker(
|
|
|
|
result = await self.speaker_system._execute_speaker(
|
|
|
@ -273,13 +298,13 @@ class AsyncWorkflow(BaseWorkflow):
|
|
|
|
return results
|
|
|
|
return results
|
|
|
|
|
|
|
|
|
|
|
|
async def run_group_chat(
|
|
|
|
async def run_group_chat(
|
|
|
|
self,
|
|
|
|
self, initial_message: str, context: Dict[str, Any] = None
|
|
|
|
initial_message: str,
|
|
|
|
|
|
|
|
context: Dict[str, Any] = None
|
|
|
|
|
|
|
|
) -> List[SpeakerMessage]:
|
|
|
|
) -> List[SpeakerMessage]:
|
|
|
|
"""Run a group chat discussion among speakers"""
|
|
|
|
"""Run a group chat discussion among speakers"""
|
|
|
|
if not self.enable_group_chat:
|
|
|
|
if not self.enable_group_chat:
|
|
|
|
raise ValueError("Group chat is not enabled for this workflow")
|
|
|
|
raise ValueError(
|
|
|
|
|
|
|
|
"Group chat is not enabled for this workflow"
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
messages: List[SpeakerMessage] = []
|
|
|
|
messages: List[SpeakerMessage] = []
|
|
|
|
current_turn = 0
|
|
|
|
current_turn = 0
|
|
|
@ -288,18 +313,26 @@ class AsyncWorkflow(BaseWorkflow):
|
|
|
|
turn_context = {
|
|
|
|
turn_context = {
|
|
|
|
"turn": current_turn,
|
|
|
|
"turn": current_turn,
|
|
|
|
"history": messages,
|
|
|
|
"history": messages,
|
|
|
|
**(context or {})
|
|
|
|
**(context or {}),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if self.group_chat_config.allow_concurrent:
|
|
|
|
if self.group_chat_config.allow_concurrent:
|
|
|
|
turn_messages = await self.run_concurrent_speakers(
|
|
|
|
turn_messages = await self.run_concurrent_speakers(
|
|
|
|
initial_message if current_turn == 0 else messages[-1].content,
|
|
|
|
(
|
|
|
|
turn_context
|
|
|
|
initial_message
|
|
|
|
|
|
|
|
if current_turn == 0
|
|
|
|
|
|
|
|
else messages[-1].content
|
|
|
|
|
|
|
|
),
|
|
|
|
|
|
|
|
turn_context,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
turn_messages = await self.run_sequential_speakers(
|
|
|
|
turn_messages = await self.run_sequential_speakers(
|
|
|
|
initial_message if current_turn == 0 else messages[-1].content,
|
|
|
|
(
|
|
|
|
turn_context
|
|
|
|
initial_message
|
|
|
|
|
|
|
|
if current_turn == 0
|
|
|
|
|
|
|
|
else messages[-1].content
|
|
|
|
|
|
|
|
),
|
|
|
|
|
|
|
|
turn_context,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
messages.extend(turn_messages)
|
|
|
|
messages.extend(turn_messages)
|
|
|
@ -315,7 +348,9 @@ class AsyncWorkflow(BaseWorkflow):
|
|
|
|
|
|
|
|
|
|
|
|
return messages
|
|
|
|
return messages
|
|
|
|
|
|
|
|
|
|
|
|
def _should_end_group_chat(self, messages: List[SpeakerMessage]) -> bool:
|
|
|
|
def _should_end_group_chat(
|
|
|
|
|
|
|
|
self, messages: List[SpeakerMessage]
|
|
|
|
|
|
|
|
) -> bool:
|
|
|
|
"""Determine if group chat should end based on messages"""
|
|
|
|
"""Determine if group chat should end based on messages"""
|
|
|
|
if not messages:
|
|
|
|
if not messages:
|
|
|
|
return True
|
|
|
|
return True
|
|
|
@ -324,7 +359,8 @@ class AsyncWorkflow(BaseWorkflow):
|
|
|
|
if self.group_chat_config.require_all_speakers:
|
|
|
|
if self.group_chat_config.require_all_speakers:
|
|
|
|
participating_roles = {msg.role for msg in messages}
|
|
|
|
participating_roles = {msg.role for msg in messages}
|
|
|
|
required_roles = {
|
|
|
|
required_roles = {
|
|
|
|
role for role, config in self.speaker_system.speakers.items()
|
|
|
|
role
|
|
|
|
|
|
|
|
for role, config in self.speaker_system.speakers.items()
|
|
|
|
if config.required
|
|
|
|
if config.required
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if not required_roles.issubset(participating_roles):
|
|
|
|
if not required_roles.issubset(participating_roles):
|
|
|
@ -344,9 +380,7 @@ class AsyncWorkflow(BaseWorkflow):
|
|
|
|
await self._save_results(start_time, end_time)
|
|
|
|
await self._save_results(start_time, end_time)
|
|
|
|
|
|
|
|
|
|
|
|
async def _execute_agent_task(
|
|
|
|
async def _execute_agent_task(
|
|
|
|
self,
|
|
|
|
self, agent: Agent, task: str
|
|
|
|
agent: Agent,
|
|
|
|
|
|
|
|
task: str
|
|
|
|
|
|
|
|
) -> AgentOutput:
|
|
|
|
) -> AgentOutput:
|
|
|
|
"""Execute a single agent task with enhanced error handling and monitoring"""
|
|
|
|
"""Execute a single agent task with enhanced error handling and monitoring"""
|
|
|
|
start_time = datetime.utcnow()
|
|
|
|
start_time = datetime.utcnow()
|
|
|
@ -372,14 +406,14 @@ class AsyncWorkflow(BaseWorkflow):
|
|
|
|
output=result,
|
|
|
|
output=result,
|
|
|
|
start_time=start_time,
|
|
|
|
start_time=start_time,
|
|
|
|
end_time=end_time,
|
|
|
|
end_time=end_time,
|
|
|
|
status="success"
|
|
|
|
status="success",
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
except Exception as e:
|
|
|
|
end_time = datetime.utcnow()
|
|
|
|
end_time = datetime.utcnow()
|
|
|
|
self.logger.error(
|
|
|
|
self.logger.error(
|
|
|
|
f"Error in agent {agent.agent_name} task {task_id}: {str(e)}",
|
|
|
|
f"Error in agent {agent.agent_name} task {task_id}: {str(e)}",
|
|
|
|
exc_info=True
|
|
|
|
exc_info=True,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
return AgentOutput(
|
|
|
|
return AgentOutput(
|
|
|
@ -391,7 +425,7 @@ class AsyncWorkflow(BaseWorkflow):
|
|
|
|
start_time=start_time,
|
|
|
|
start_time=start_time,
|
|
|
|
end_time=end_time,
|
|
|
|
end_time=end_time,
|
|
|
|
status="error",
|
|
|
|
status="error",
|
|
|
|
error=str(e)
|
|
|
|
error=str(e),
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
async def run(self, task: str) -> WorkflowOutput:
|
|
|
|
async def run(self, task: str) -> WorkflowOutput:
|
|
|
@ -408,15 +442,21 @@ class AsyncWorkflow(BaseWorkflow):
|
|
|
|
if self.enable_group_chat:
|
|
|
|
if self.enable_group_chat:
|
|
|
|
speaker_outputs = await self.run_group_chat(task)
|
|
|
|
speaker_outputs = await self.run_group_chat(task)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
concurrent_outputs = await self.run_concurrent_speakers(task)
|
|
|
|
concurrent_outputs = (
|
|
|
|
sequential_outputs = await self.run_sequential_speakers(task)
|
|
|
|
await self.run_concurrent_speakers(task)
|
|
|
|
speaker_outputs = concurrent_outputs + sequential_outputs
|
|
|
|
)
|
|
|
|
|
|
|
|
sequential_outputs = (
|
|
|
|
|
|
|
|
await self.run_sequential_speakers(task)
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
speaker_outputs = (
|
|
|
|
|
|
|
|
concurrent_outputs + sequential_outputs
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
# Store speaker outputs in shared memory
|
|
|
|
# Store speaker outputs in shared memory
|
|
|
|
self.shared_memory.set(
|
|
|
|
self.shared_memory.set(
|
|
|
|
"speaker_outputs",
|
|
|
|
"speaker_outputs",
|
|
|
|
[msg.dict() for msg in speaker_outputs],
|
|
|
|
[msg.dict() for msg in speaker_outputs],
|
|
|
|
"workflow"
|
|
|
|
"workflow",
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
# Create tasks for all agents
|
|
|
|
# Create tasks for all agents
|
|
|
@ -426,13 +466,19 @@ class AsyncWorkflow(BaseWorkflow):
|
|
|
|
]
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
# Execute all tasks concurrently
|
|
|
|
# Execute all tasks concurrently
|
|
|
|
agent_outputs = await asyncio.gather(*tasks, return_exceptions=True)
|
|
|
|
agent_outputs = await asyncio.gather(
|
|
|
|
|
|
|
|
*tasks, return_exceptions=True
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
end_time = datetime.utcnow()
|
|
|
|
end_time = datetime.utcnow()
|
|
|
|
|
|
|
|
|
|
|
|
# Calculate success/failure counts
|
|
|
|
# Calculate success/failure counts
|
|
|
|
successful_tasks = sum(1 for output in agent_outputs
|
|
|
|
successful_tasks = sum(
|
|
|
|
if isinstance(output, AgentOutput) and output.status == "success")
|
|
|
|
1
|
|
|
|
|
|
|
|
for output in agent_outputs
|
|
|
|
|
|
|
|
if isinstance(output, AgentOutput)
|
|
|
|
|
|
|
|
and output.status == "success"
|
|
|
|
|
|
|
|
)
|
|
|
|
failed_tasks = len(agent_outputs) - successful_tasks
|
|
|
|
failed_tasks = len(agent_outputs) - successful_tasks
|
|
|
|
|
|
|
|
|
|
|
|
return WorkflowOutput(
|
|
|
|
return WorkflowOutput(
|
|
|
@ -443,22 +489,36 @@ class AsyncWorkflow(BaseWorkflow):
|
|
|
|
total_agents=len(self.agents),
|
|
|
|
total_agents=len(self.agents),
|
|
|
|
successful_tasks=successful_tasks,
|
|
|
|
successful_tasks=successful_tasks,
|
|
|
|
failed_tasks=failed_tasks,
|
|
|
|
failed_tasks=failed_tasks,
|
|
|
|
agent_outputs=[output for output in agent_outputs
|
|
|
|
agent_outputs=[
|
|
|
|
if isinstance(output, AgentOutput)],
|
|
|
|
output
|
|
|
|
|
|
|
|
for output in agent_outputs
|
|
|
|
|
|
|
|
if isinstance(output, AgentOutput)
|
|
|
|
|
|
|
|
],
|
|
|
|
metadata={
|
|
|
|
metadata={
|
|
|
|
"max_workers": self.max_workers,
|
|
|
|
"max_workers": self.max_workers,
|
|
|
|
"shared_memory_keys": list(self.shared_memory._memory.keys()),
|
|
|
|
"shared_memory_keys": list(
|
|
|
|
|
|
|
|
self.shared_memory._memory.keys()
|
|
|
|
|
|
|
|
),
|
|
|
|
"group_chat_enabled": self.enable_group_chat,
|
|
|
|
"group_chat_enabled": self.enable_group_chat,
|
|
|
|
"total_speaker_messages": len(speaker_outputs),
|
|
|
|
"total_speaker_messages": len(
|
|
|
|
"speaker_outputs": [msg.dict() for msg in speaker_outputs]
|
|
|
|
speaker_outputs
|
|
|
|
}
|
|
|
|
),
|
|
|
|
|
|
|
|
"speaker_outputs": [
|
|
|
|
|
|
|
|
msg.dict() for msg in speaker_outputs
|
|
|
|
|
|
|
|
],
|
|
|
|
|
|
|
|
},
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
except Exception as e:
|
|
|
|
self.logger.error(f"Critical workflow error: {str(e)}", exc_info=True)
|
|
|
|
self.logger.error(
|
|
|
|
|
|
|
|
f"Critical workflow error: {str(e)}",
|
|
|
|
|
|
|
|
exc_info=True,
|
|
|
|
|
|
|
|
)
|
|
|
|
raise
|
|
|
|
raise
|
|
|
|
|
|
|
|
|
|
|
|
async def _save_results(self, start_time: datetime, end_time: datetime) -> None:
|
|
|
|
async def _save_results(
|
|
|
|
|
|
|
|
self, start_time: datetime, end_time: datetime
|
|
|
|
|
|
|
|
) -> None:
|
|
|
|
"""Save workflow results to disk"""
|
|
|
|
"""Save workflow results to disk"""
|
|
|
|
if not self.autosave:
|
|
|
|
if not self.autosave:
|
|
|
|
return
|
|
|
|
return
|
|
|
@ -469,38 +529,59 @@ class AsyncWorkflow(BaseWorkflow):
|
|
|
|
filename = f"{output_dir}/workflow_{self.workflow_id}_{end_time.strftime('%Y%m%d_%H%M%S')}.json"
|
|
|
|
filename = f"{output_dir}/workflow_{self.workflow_id}_{end_time.strftime('%Y%m%d_%H%M%S')}.json"
|
|
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
with open(filename, 'w') as f:
|
|
|
|
with open(filename, "w") as f:
|
|
|
|
json.dump({
|
|
|
|
json.dump(
|
|
|
|
|
|
|
|
{
|
|
|
|
"workflow_id": self.workflow_id,
|
|
|
|
"workflow_id": self.workflow_id,
|
|
|
|
"start_time": start_time.isoformat(),
|
|
|
|
"start_time": start_time.isoformat(),
|
|
|
|
"end_time": end_time.isoformat(),
|
|
|
|
"end_time": end_time.isoformat(),
|
|
|
|
"results": [
|
|
|
|
"results": [
|
|
|
|
asdict(result) if hasattr(result, '__dict__')
|
|
|
|
(
|
|
|
|
else result.dict() if hasattr(result, 'dict')
|
|
|
|
asdict(result)
|
|
|
|
|
|
|
|
if hasattr(result, "__dict__")
|
|
|
|
|
|
|
|
else (
|
|
|
|
|
|
|
|
result.dict()
|
|
|
|
|
|
|
|
if hasattr(result, "dict")
|
|
|
|
else str(result)
|
|
|
|
else str(result)
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
)
|
|
|
|
for result in self.results
|
|
|
|
for result in self.results
|
|
|
|
],
|
|
|
|
],
|
|
|
|
"speaker_history": [
|
|
|
|
"speaker_history": [
|
|
|
|
msg.dict() for msg in self.speaker_system.message_history
|
|
|
|
msg.dict()
|
|
|
|
|
|
|
|
for msg in self.speaker_system.message_history
|
|
|
|
],
|
|
|
|
],
|
|
|
|
"metadata": self.metadata
|
|
|
|
"metadata": self.metadata,
|
|
|
|
}, f, default=str, indent=2)
|
|
|
|
},
|
|
|
|
|
|
|
|
f,
|
|
|
|
|
|
|
|
default=str,
|
|
|
|
|
|
|
|
indent=2,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
self.logger.info(f"Workflow results saved to {filename}")
|
|
|
|
self.logger.info(f"Workflow results saved to {filename}")
|
|
|
|
except Exception as e:
|
|
|
|
except Exception as e:
|
|
|
|
self.logger.error(f"Error saving workflow results: {str(e)}")
|
|
|
|
self.logger.error(
|
|
|
|
|
|
|
|
f"Error saving workflow results: {str(e)}"
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
def _validate_config(self) -> None:
|
|
|
|
def _validate_config(self) -> None:
|
|
|
|
"""Validate workflow configuration"""
|
|
|
|
"""Validate workflow configuration"""
|
|
|
|
if self.max_workers < 1:
|
|
|
|
if self.max_workers < 1:
|
|
|
|
raise ValueError("max_workers must be at least 1")
|
|
|
|
raise ValueError("max_workers must be at least 1")
|
|
|
|
|
|
|
|
|
|
|
|
if self.enable_group_chat and not self.speaker_system.speakers:
|
|
|
|
if (
|
|
|
|
raise ValueError("Group chat enabled but no speakers configured")
|
|
|
|
self.enable_group_chat
|
|
|
|
|
|
|
|
and not self.speaker_system.speakers
|
|
|
|
|
|
|
|
):
|
|
|
|
|
|
|
|
raise ValueError(
|
|
|
|
|
|
|
|
"Group chat enabled but no speakers configured"
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
for config in self.speaker_system.speakers.values():
|
|
|
|
for config in self.speaker_system.speakers.values():
|
|
|
|
if config.timeout <= 0:
|
|
|
|
if config.timeout <= 0:
|
|
|
|
raise ValueError(f"Invalid timeout for speaker {config.role}")
|
|
|
|
raise ValueError(
|
|
|
|
|
|
|
|
f"Invalid timeout for speaker {config.role}"
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
async def cleanup(self) -> None:
|
|
|
|
async def cleanup(self) -> None:
|
|
|
|
"""Cleanup workflow resources"""
|
|
|
|
"""Cleanup workflow resources"""
|
|
|
@ -513,7 +594,14 @@ class AsyncWorkflow(BaseWorkflow):
|
|
|
|
# Persist final state
|
|
|
|
# Persist final state
|
|
|
|
if self.autosave:
|
|
|
|
if self.autosave:
|
|
|
|
end_time = datetime.utcnow()
|
|
|
|
end_time = datetime.utcnow()
|
|
|
|
await self._save_results(self.results[0].start_time if self.results else end_time, end_time)
|
|
|
|
await self._save_results(
|
|
|
|
|
|
|
|
(
|
|
|
|
|
|
|
|
self.results[0].start_time
|
|
|
|
|
|
|
|
if self.results
|
|
|
|
|
|
|
|
else end_time
|
|
|
|
|
|
|
|
),
|
|
|
|
|
|
|
|
end_time,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
# Clear shared memory if configured
|
|
|
|
# Clear shared memory if configured
|
|
|
|
self.shared_memory._memory.clear()
|
|
|
|
self.shared_memory._memory.clear()
|
|
|
@ -522,11 +610,12 @@ class AsyncWorkflow(BaseWorkflow):
|
|
|
|
self.logger.error(f"Error during cleanup: {str(e)}")
|
|
|
|
self.logger.error(f"Error during cleanup: {str(e)}")
|
|
|
|
raise
|
|
|
|
raise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Utility functions for the workflow
|
|
|
|
# Utility functions for the workflow
|
|
|
|
def create_default_workflow(
|
|
|
|
def create_default_workflow(
|
|
|
|
agents: List[Agent],
|
|
|
|
agents: List[Agent],
|
|
|
|
name: str = "DefaultWorkflow",
|
|
|
|
name: str = "DefaultWorkflow",
|
|
|
|
enable_group_chat: bool = False
|
|
|
|
enable_group_chat: bool = False,
|
|
|
|
) -> AsyncWorkflow:
|
|
|
|
) -> AsyncWorkflow:
|
|
|
|
"""Create a workflow with default configuration"""
|
|
|
|
"""Create a workflow with default configuration"""
|
|
|
|
workflow = AsyncWorkflow(
|
|
|
|
workflow = AsyncWorkflow(
|
|
|
@ -540,18 +629,19 @@ def create_default_workflow(
|
|
|
|
group_chat_config=GroupChatConfig(
|
|
|
|
group_chat_config=GroupChatConfig(
|
|
|
|
max_turns=5,
|
|
|
|
max_turns=5,
|
|
|
|
allow_concurrent=True,
|
|
|
|
allow_concurrent=True,
|
|
|
|
require_all_speakers=False
|
|
|
|
require_all_speakers=False,
|
|
|
|
)
|
|
|
|
),
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
workflow.add_default_speakers()
|
|
|
|
workflow.add_default_speakers()
|
|
|
|
return workflow
|
|
|
|
return workflow
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def run_workflow_with_retry(
|
|
|
|
async def run_workflow_with_retry(
|
|
|
|
workflow: AsyncWorkflow,
|
|
|
|
workflow: AsyncWorkflow,
|
|
|
|
task: str,
|
|
|
|
task: str,
|
|
|
|
max_retries: int = 3,
|
|
|
|
max_retries: int = 3,
|
|
|
|
retry_delay: float = 1.0
|
|
|
|
retry_delay: float = 1.0,
|
|
|
|
) -> WorkflowOutput:
|
|
|
|
) -> WorkflowOutput:
|
|
|
|
"""Run workflow with retry logic"""
|
|
|
|
"""Run workflow with retry logic"""
|
|
|
|
for attempt in range(max_retries):
|
|
|
|
for attempt in range(max_retries):
|
|
|
|