[TESTS][swarms.structs.autoscaler]

pull/334/head
Kye 1 year ago
parent aa4f3d1563
commit d66621a9e1

@ -12,9 +12,10 @@ from swarms.utils.decorators import (
log_decorator,
timing_decorator,
)
from swarms.structs.base import BaseStructure
class AutoScaler:
class AutoScaler(BaseStructure):
"""
AutoScaler class
@ -260,11 +261,17 @@ class AutoScaler:
def balance_load(self):
"""Distributes tasks among agents based on their current load."""
try:
while not self.task_queue.empty():
for agent in self.agents_pool:
if agent.can_accept_task():
task = self.task_queue.get()
agent.run(task)
except Exception as error:
print(
f"Error balancing load: {error} try again with a new"
" task"
)
def set_scaling_strategy(
self, strategy: Callable[[int, int], int]
@ -274,6 +281,7 @@ class AutoScaler:
def execute_scaling_strategy(self):
"""Execute the custom scaling strategy if defined."""
try:
if hasattr(self, "custom_scale_strategy"):
scale_amount = self.custom_scale_strategy(
self.task_queue.qsize(), len(self.agents_pool)
@ -285,6 +293,11 @@ class AutoScaler:
for _ in range(abs(scale_amount)):
if len(self.agents_pool) > 10:
del self.agents_pool[-1]
except Exception as error:
print(
f"Error executing scaling strategy: {error} try again"
" with a new task"
)
def report_agent_metrics(self) -> Dict[str, List[float]]:
"""Collects and reports metrics from each agent."""

@ -1,7 +1,9 @@
import os
from dotenv import load_dotenv
from unittest.mock import patch
from unittest.mock import MagicMock, patch
import pytest
from swarms.models import OpenAIChat
from swarms.structs import Agent
@ -138,3 +140,79 @@ def test_autoscaler_print_dashboard(mock_print):
autoscaler = AutoScaler(initial_agents=5, agent=agent)
autoscaler.print_dashboard()
mock_print.assert_called()
@patch("swarms.structs.autoscaler.logging")
def test_check_agent_health_all_healthy(mock_logging):
autoscaler = AutoScaler(initial_agents=5, agent=agent)
for agent in autoscaler.agents_pool:
agent.is_healthy = MagicMock(return_value=True)
autoscaler.check_agent_health()
mock_logging.warning.assert_not_called()
@patch("swarms.structs.autoscaler.logging")
def test_check_agent_health_some_unhealthy(mock_logging):
autoscaler = AutoScaler(initial_agents=5, agent=agent)
for i, agent in enumerate(autoscaler.agents_pool):
agent.is_healthy = MagicMock(return_value=(i % 2 == 0))
autoscaler.check_agent_health()
assert mock_logging.warning.call_count == 2
@patch("swarms.structs.autoscaler.logging")
def test_check_agent_health_all_unhealthy(mock_logging):
autoscaler = AutoScaler(initial_agents=5, agent=agent)
for agent in autoscaler.agents_pool:
agent.is_healthy = MagicMock(return_value=False)
autoscaler.check_agent_health()
assert mock_logging.warning.call_count == 5
@patch("swarms.structs.autoscaler.Agent")
def test_add_agent(mock_agent):
autoscaler = AutoScaler(initial_agents=5, agent=agent)
initial_count = len(autoscaler.agents_pool)
autoscaler.add_agent()
assert len(autoscaler.agents_pool) == initial_count + 1
mock_agent.assert_called_once()
@patch("swarms.structs.autoscaler.Agent")
def test_remove_agent(mock_agent):
autoscaler = AutoScaler(initial_agents=5, agent=agent)
initial_count = len(autoscaler.agents_pool)
autoscaler.remove_agent()
assert len(autoscaler.agents_pool) == initial_count - 1
@patch("swarms.structs.autoscaler.AutoScaler.add_agent")
@patch("swarms.structs.autoscaler.AutoScaler.remove_agent")
def test_scale(mock_remove_agent, mock_add_agent):
autoscaler = AutoScaler(initial_agents=5, agent=agent)
autoscaler.scale(10)
assert mock_add_agent.call_count == 5
assert mock_remove_agent.call_count == 0
mock_add_agent.reset_mock()
mock_remove_agent.reset_mock()
autoscaler.scale(3)
assert mock_add_agent.call_count == 0
assert mock_remove_agent.call_count == 2
def test_add_task_success():
autoscaler = AutoScaler(initial_agents=5)
initial_queue_size = autoscaler.task_queue.qsize()
autoscaler.add_task("test_task")
assert autoscaler.task_queue.qsize() == initial_queue_size + 1
@patch("swarms.structs.autoscaler.queue.Queue.put")
def test_add_task_exception(mock_put):
mock_put.side_effect = Exception("test error")
autoscaler = AutoScaler(initial_agents=5)
with pytest.raises(Exception) as e:
autoscaler.add_task("test_task")
assert str(e.value) == "test error"

Loading…
Cancel
Save