You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							106 lines
						
					
					
						
							3.5 KiB
						
					
					
				
			
		
		
	
	
							106 lines
						
					
					
						
							3.5 KiB
						
					
					
import concurrent.futures
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

from swarms.structs.base import BaseStruct
from swarms.structs.task import Task


@dataclass
class ConcurrentWorkflow(BaseStruct):
    """
    ConcurrentWorkflow class for running a set of tasks concurrently using N number of autonomous agents.

    Args:
        max_workers (int): The maximum number of workers to use for concurrent execution.
        autosave (bool): Whether to autosave the workflow state.
        saved_state_filepath (Optional[str]): The file path to save the workflow state.
        print_results (bool): Whether ``run`` prints each task's result as it completes.
        return_results (bool): Whether ``run`` returns the collected results.
        use_processes (bool): Run tasks in a process pool instead of a thread pool.
            Tasks (and their ``execute`` method) must then be picklable.

    Attributes:
        tasks (List[Task]): The list of tasks to execute.
        max_workers (int): The maximum number of workers to use for concurrent execution.
        autosave (bool): Whether to autosave the workflow state.
        saved_state_filepath (Optional[str]): The file path to save the workflow state.

    Note:
        ``autosave`` and ``saved_state_filepath`` are stored on the instance but
        are not read by any method of this class.

    Examples:
    >>> from swarms.structs import ConcurrentWorkflow
    >>> workflow = ConcurrentWorkflow(max_workers=5, return_results=True)
    >>> workflow.add(task_a)  # task_a, task_b: swarms.structs.task.Task instances
    >>> workflow.add(tasks=[task_b])
    >>> results = workflow.run()
    >>> workflow.tasks
    """

    tasks: List[Task] = field(default_factory=list)
    max_workers: int = 5
    autosave: bool = False
    saved_state_filepath: Optional[str] = "runs/concurrent_workflow.json"
    print_results: bool = False
    return_results: bool = False
    use_processes: bool = False

    def add(
        self,
        task: Optional[Task] = None,
        tasks: Optional[List[Task]] = None,
    ):
        """Add a single task and/or a batch of tasks to the workflow.

        Both arguments may be given together; ``task`` is appended first,
        followed by ``tasks`` in order.

        Args:
            task (Optional[Task]): A single task to append. Skipped when None.
            tasks (Optional[List[Task]]): Additional tasks to append, in order.

        Raises:
            ValueError: If neither ``task`` nor a non-empty ``tasks`` is given.
                This prevents ``None`` from being queued and failing later in ``run``.
        """
        if task is None and not tasks:
            raise ValueError(
                "[ConcurrentWorkflow] add() requires a task or a list of tasks"
            )
        if task is not None:
            self.tasks.append(task)
        if tasks:
            self.tasks.extend(tasks)

    def run(self) -> Optional[List[Any]]:
        """Execute all tasks concurrently, calling ``task.execute()`` on each.

        Tasks are submitted to a ``ThreadPoolExecutor``, or to a
        ``ProcessPoolExecutor`` when ``use_processes`` is True. The pool is
        sized by ``max_workers``.

        Behaviour is controlled by instance fields, not arguments:
            print_results: print each task's result as it completes.
            return_results: collect successful results and return them.

        A task that raises is reported via ``print`` and left out of the
        returned list. It does not stop the remaining tasks.

        Returns:
            Optional[List[Any]]: When ``return_results`` is True, the results of
            the tasks that succeeded, ordered as the tasks appear in
            ``self.tasks``. Otherwise None.
        """
        executor_cls = (
            concurrent.futures.ProcessPoolExecutor
            if self.use_processes
            else concurrent.futures.ThreadPoolExecutor
        )
        # Futures finish in arbitrary order; keying by submission index lets
        # the returned list follow the order of self.tasks deterministically.
        results_by_index: Dict[int, Any] = {}

        with executor_cls(max_workers=self.max_workers) as executor:
            futures = {
                executor.submit(task.execute): (index, task)
                for index, task in enumerate(self.tasks)
            }
            for future in concurrent.futures.as_completed(futures):
                index, task = futures[future]
                try:
                    result = future.result()
                except Exception as e:
                    print(f"Task {task} generated an exception: {e}")
                    continue
                if self.print_results:
                    print(f"Task {task}: {result}")
                if self.return_results:
                    results_by_index[index] = result

        if not self.return_results:
            return None
        return [results_by_index[i] for i in sorted(results_by_index)]

    def _execute_task(self, task: Task) -> Any:
        """Run a single task by calling its ``run()`` method.

        Note: ``run`` does not use this helper; it submits ``task.execute``
        directly to the executor.

        Args:
            task (Task): The task to run.

        Returns:
            Any: Whatever ``task.run()`` returns.
        """
        return task.run()