diff --git a/swarms/structs/autoscaler.py b/swarms/structs/autoscaler.py index 6f07d0d3..f26247d5 100644 --- a/swarms/structs/autoscaler.py +++ b/swarms/structs/autoscaler.py @@ -12,9 +12,10 @@ from swarms.utils.decorators import ( log_decorator, timing_decorator, ) +from swarms.structs.base import BaseStructure -class AutoScaler: +class AutoScaler(BaseStructure): """ AutoScaler class @@ -260,11 +261,17 @@ class AutoScaler: def balance_load(self): """Distributes tasks among agents based on their current load.""" - while not self.task_queue.empty(): - for agent in self.agents_pool: - if agent.can_accept_task(): - task = self.task_queue.get() - agent.run(task) + try: + while not self.task_queue.empty(): + for agent in self.agents_pool: + if agent.can_accept_task(): + task = self.task_queue.get() + agent.run(task) + except Exception as error: + print( + f"Error balancing load: {error} try again with a new" + " task" + ) def set_scaling_strategy( self, strategy: Callable[[int, int], int] @@ -274,17 +281,23 @@ class AutoScaler: def execute_scaling_strategy(self): """Execute the custom scaling strategy if defined.""" - if hasattr(self, "custom_scale_strategy"): - scale_amount = self.custom_scale_strategy( - self.task_queue.qsize(), len(self.agents_pool) + try: + if hasattr(self, "custom_scale_strategy"): + scale_amount = self.custom_scale_strategy( + self.task_queue.qsize(), len(self.agents_pool) + ) + if scale_amount > 0: + for _ in range(scale_amount): + self.agents_pool.append(self.agent()) + elif scale_amount < 0: + for _ in range(abs(scale_amount)): + if len(self.agents_pool) > 10: + del self.agents_pool[-1] + except Exception as error: + print( + f"Error executing scaling strategy: {error} try again" + " with a new task" ) - if scale_amount > 0: - for _ in range(scale_amount): - self.agents_pool.append(self.agent()) - elif scale_amount < 0: - for _ in range(abs(scale_amount)): - if len(self.agents_pool) > 10: - del self.agents_pool[-1] def report_agent_metrics(self) -> Dict[str, List[float]]: """Collects and reports metrics from each agent.""" diff --git a/tests/structs/test_autoscaler.py b/tests/structs/test_autoscaler.py index f3b9fefa..62abeede 100644 --- a/tests/structs/test_autoscaler.py +++ b/tests/structs/test_autoscaler.py @@ -1,7 +1,9 @@ import os from dotenv import load_dotenv -from unittest.mock import patch +from unittest.mock import MagicMock, patch + +import pytest from swarms.models import OpenAIChat from swarms.structs import Agent @@ -138,3 +140,79 @@ def test_autoscaler_print_dashboard(mock_print): autoscaler = AutoScaler(initial_agents=5, agent=agent) autoscaler.print_dashboard() mock_print.assert_called() + + +@patch("swarms.structs.autoscaler.logging") +def test_check_agent_health_all_healthy(mock_logging): + autoscaler = AutoScaler(initial_agents=5, agent=agent) + for agent in autoscaler.agents_pool: + agent.is_healthy = MagicMock(return_value=True) + autoscaler.check_agent_health() + mock_logging.warning.assert_not_called() + + +@patch("swarms.structs.autoscaler.logging") +def test_check_agent_health_some_unhealthy(mock_logging): + autoscaler = AutoScaler(initial_agents=5, agent=agent) + for i, agent in enumerate(autoscaler.agents_pool): + agent.is_healthy = MagicMock(return_value=(i % 2 == 0)) + autoscaler.check_agent_health() + assert mock_logging.warning.call_count == 2 + + +@patch("swarms.structs.autoscaler.logging") +def test_check_agent_health_all_unhealthy(mock_logging): + autoscaler = AutoScaler(initial_agents=5, agent=agent) + for agent in autoscaler.agents_pool: + agent.is_healthy = MagicMock(return_value=False) + autoscaler.check_agent_health() + assert mock_logging.warning.call_count == 5 + + +@patch("swarms.structs.autoscaler.Agent") +def test_add_agent(mock_agent): + autoscaler = AutoScaler(initial_agents=5, agent=agent) + initial_count = len(autoscaler.agents_pool) + autoscaler.add_agent() + assert len(autoscaler.agents_pool) == initial_count + 1 + mock_agent.assert_called_once() + + +@patch("swarms.structs.autoscaler.Agent") +def test_remove_agent(mock_agent): + autoscaler = AutoScaler(initial_agents=5, agent=agent) + initial_count = len(autoscaler.agents_pool) + autoscaler.remove_agent() + assert len(autoscaler.agents_pool) == initial_count - 1 + + +@patch("swarms.structs.autoscaler.AutoScaler.add_agent") +@patch("swarms.structs.autoscaler.AutoScaler.remove_agent") +def test_scale(mock_remove_agent, mock_add_agent): + autoscaler = AutoScaler(initial_agents=5, agent=agent) + autoscaler.scale(10) + assert mock_add_agent.call_count == 5 + assert mock_remove_agent.call_count == 0 + + mock_add_agent.reset_mock() + mock_remove_agent.reset_mock() + + autoscaler.scale(3) + assert mock_add_agent.call_count == 0 + assert mock_remove_agent.call_count == 2 + + +def test_add_task_success(): + autoscaler = AutoScaler(initial_agents=5) + initial_queue_size = autoscaler.task_queue.qsize() + autoscaler.add_task("test_task") + assert autoscaler.task_queue.qsize() == initial_queue_size + 1 + + +@patch("swarms.structs.autoscaler.queue.Queue.put") +def test_add_task_exception(mock_put): + mock_put.side_effect = Exception("test error") + autoscaler = AutoScaler(initial_agents=5) + with pytest.raises(Exception) as e: + autoscaler.add_task("test_task") + assert str(e.value) == "test error"