|
|
|
@ -112,6 +112,7 @@ class ConcurrentWorkflow(BaseSwarm):
|
|
|
|
|
return_str_on: bool = False,
|
|
|
|
|
agent_responses: list = [],
|
|
|
|
|
auto_generate_prompts: bool = False,
|
|
|
|
|
max_workers: int = None,
|
|
|
|
|
*args,
|
|
|
|
|
**kwargs,
|
|
|
|
|
):
|
|
|
|
@ -132,9 +133,12 @@ class ConcurrentWorkflow(BaseSwarm):
|
|
|
|
|
self.return_str_on = return_str_on
|
|
|
|
|
self.agent_responses = agent_responses
|
|
|
|
|
self.auto_generate_prompts = auto_generate_prompts
|
|
|
|
|
self.max_workers = max_workers or os.cpu_count()
|
|
|
|
|
self.tasks = [] # Initialize tasks list
|
|
|
|
|
|
|
|
|
|
self.reliability_check()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def reliability_check(self):
|
|
|
|
|
try:
|
|
|
|
|
logger.info("Starting reliability checks")
|
|
|
|
@ -389,6 +393,9 @@ class ConcurrentWorkflow(BaseSwarm):
|
|
|
|
|
ValueError: If an invalid device is specified.
|
|
|
|
|
Exception: If any other error occurs during execution.
|
|
|
|
|
"""
|
|
|
|
|
if task is not None:
|
|
|
|
|
self.tasks.append(task)
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
logger.info(f"Attempting to run on device: {device}")
|
|
|
|
|
if device == "cpu":
|
|
|
|
@ -406,7 +413,6 @@ class ConcurrentWorkflow(BaseSwarm):
|
|
|
|
|
count, self._run, task, img, *args, **kwargs
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# If device gpu
|
|
|
|
|
elif device == "gpu":
|
|
|
|
|
logger.info("Device set to GPU")
|
|
|
|
|
return execute_on_gpu(
|
|
|
|
|