parent 38438c2266
commit 42ef8134fb
@@ -0,0 +1 @@
# Base implementation for the diffusers library
@@ -0,0 +1,28 @@
from pydantic import BaseModel
from typing import List, Optional


class TextModality(BaseModel):
    content: str


class ImageModality(BaseModel):
    url: str
    alt_text: Optional[str]


class AudioModality(BaseModel):
    url: str
    transcript: Optional[str]


class VideoModality(BaseModel):
    url: str
    transcript: Optional[str]


class MultimodalData(BaseModel):
    text: Optional[List[TextModality]]
    images: Optional[List[ImageModality]]
    audio: Optional[List[AudioModality]]
    video: Optional[List[VideoModality]]
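Reviewer note: a minimal, hedged sketch of how these models compose. The field values below are placeholders invented for illustration; they are not part of the commit.

# Illustrative only: placeholder values, not part of the commit.
sample = MultimodalData(
    text=[TextModality(content="A short caption")],
    images=[ImageModality(url="https://example.com/cat.png", alt_text="a cat")],
    audio=None,
    video=None,
)
print(sample)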
@@ -0,0 +1,96 @@
import concurrent.futures
from dataclasses import dataclass, field
from typing import List, Optional

from swarms.structs.base import BaseStruct
from swarms.structs.task import Task


@dataclass
class ConcurrentWorkflow(BaseStruct):
    """
    ConcurrentWorkflow class for running a set of tasks concurrently, using up to ``max_workers`` autonomous agents at a time.

    Args:
        max_workers (int): The maximum number of workers to use for concurrent execution.
        autosave (bool): Whether to autosave the workflow state.
        saved_state_filepath (Optional[str]): The file path to save the workflow state.

    Attributes:
        tasks (List[Task]): The list of tasks to execute.
        max_workers (int): The maximum number of workers to use for concurrent execution.
        autosave (bool): Whether to autosave the workflow state.
        saved_state_filepath (Optional[str]): The file path to save the workflow state.

    Examples:
        >>> from swarms.models import OpenAIChat
        >>> from swarms.structs import ConcurrentWorkflow
        >>> llm = OpenAIChat(openai_api_key="")
        >>> workflow = ConcurrentWorkflow(max_workers=5)
        >>> workflow.add("What's the weather in miami", llm)
        >>> workflow.add("Create a report on these metrics", llm)
        >>> workflow.run()
        >>> workflow.tasks
    """

    tasks: List[Task] = field(default_factory=list)
    max_workers: int = 5
    autosave: bool = False
    saved_state_filepath: Optional[str] = (
        "runs/concurrent_workflow.json"
    )
    print_results: bool = False
    return_results: bool = False
    use_processes: bool = False

    def add(self, task: Task):
        """Adds a task to the workflow.

        Args:
            task (Task): The task to append to the workflow's task list.
        """
        self.tasks.append(task)

    def run(self):
        """
        Executes all queued tasks in parallel using a ThreadPoolExecutor.

        Behaviour is controlled by the instance attributes:
            print_results (bool): Whether to print the result of each task. Default is False.
            return_results (bool): Whether to return the results of each task. Default is False.

        Returns:
            List[Any]: A list of the results of each task if return_results is True, otherwise None.
        """
        with concurrent.futures.ThreadPoolExecutor(
            max_workers=self.max_workers
        ) as executor:
            futures = {
                executor.submit(task.execute): task
                for task in self.tasks
            }
            results = []

            for future in concurrent.futures.as_completed(futures):
                task = futures[future]
                try:
                    result = future.result()
                    if self.print_results:
                        print(f"Task {task}: {result}")
                    if self.return_results:
                        results.append(result)
                except Exception as e:
                    print(f"Task {task} generated an exception: {e}")

        return results if self.return_results else None

    def _execute_task(self, task: Task):
        """Executes a single task and returns its result.

        Args:
            task (Task): The task to execute.

        Returns:
            Any: The result of the task's execute() call.
        """
        return task.execute()
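Reviewer note: run() only relies on each queued task exposing an execute() method, so the workflow can be exercised without the rest of swarms. A hedged sketch follows; EchoTask is a hypothetical stand-in, not part of this commit.

# Hypothetical stand-in: any object with an execute() method works here,
# because ConcurrentWorkflow.run() only calls task.execute().
class EchoTask:
    def __init__(self, message: str):
        self.message = message

    def execute(self):
        return f"done: {self.message}"


workflow = ConcurrentWorkflow(max_workers=2, return_results=True)
workflow.add(EchoTask("first"))
workflow.add(EchoTask("second"))
results = workflow.run()  # both tasks run in the thread pool; results returned since return_results=True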
@@ -0,0 +1,56 @@
from unittest.mock import Mock, create_autospec, patch
from concurrent.futures import Future
from swarms.structs import ConcurrentWorkflow, Task, Agent


def test_add():
    workflow = ConcurrentWorkflow(max_workers=2)
    task = Mock(spec=Task)
    workflow.add(task)
    assert task in workflow.tasks


def test_run():
    workflow = ConcurrentWorkflow(max_workers=2)
    task1 = create_autospec(Task)
    task2 = create_autospec(Task)
    workflow.add(task1)
    workflow.add(task2)

    with patch(
        "concurrent.futures.ThreadPoolExecutor"
    ) as mock_executor:
        def fake_submit(fn, *args, **kwargs):
            # Execute the submitted callable inline and hand back a completed
            # Future, so run() can iterate concurrent.futures.as_completed()
            # over real Future objects.
            future = Future()
            future.set_result(fn(*args, **kwargs))
            return future

        mock_executor.return_value.__enter__.return_value.submit.side_effect = (
            fake_submit
        )

        workflow.run()

    task1.execute.assert_called_once()
    task2.execute.assert_called_once()


def test_execute_task():
    workflow = ConcurrentWorkflow(max_workers=2)
    task = create_autospec(Task)
    workflow._execute_task(task)
    task.execute.assert_called_once()


def test_agent_execution():
    workflow = ConcurrentWorkflow(max_workers=2)
    agent = create_autospec(Agent)
    task = Task(agent)
    workflow.add(task)
    workflow._execute_task(task)
    agent.execute.assert_called_once()