[TESTS][swarms.structs.autoscaler]

pull/334/head
Kye 1 year ago
parent aa4f3d1563
commit d66621a9e1

@ -12,9 +12,10 @@ from swarms.utils.decorators import (
log_decorator, log_decorator,
timing_decorator, timing_decorator,
) )
from swarms.structs.base import BaseStructure
class AutoScaler: class AutoScaler(BaseStructure):
""" """
AutoScaler class AutoScaler class
@ -260,11 +261,17 @@ class AutoScaler:
def balance_load(self): def balance_load(self):
"""Distributes tasks among agents based on their current load.""" """Distributes tasks among agents based on their current load."""
while not self.task_queue.empty(): try:
for agent in self.agents_pool: while not self.task_queue.empty():
if agent.can_accept_task(): for agent in self.agents_pool:
task = self.task_queue.get() if agent.can_accept_task():
agent.run(task) task = self.task_queue.get()
agent.run(task)
except Exception as error:
print(
f"Error balancing load: {error} try again with a new"
" task"
)
def set_scaling_strategy( def set_scaling_strategy(
self, strategy: Callable[[int, int], int] self, strategy: Callable[[int, int], int]
@ -274,17 +281,23 @@ class AutoScaler:
def execute_scaling_strategy(self): def execute_scaling_strategy(self):
"""Execute the custom scaling strategy if defined.""" """Execute the custom scaling strategy if defined."""
if hasattr(self, "custom_scale_strategy"): try:
scale_amount = self.custom_scale_strategy( if hasattr(self, "custom_scale_strategy"):
self.task_queue.qsize(), len(self.agents_pool) scale_amount = self.custom_scale_strategy(
self.task_queue.qsize(), len(self.agents_pool)
)
if scale_amount > 0:
for _ in range(scale_amount):
self.agents_pool.append(self.agent())
elif scale_amount < 0:
for _ in range(abs(scale_amount)):
if len(self.agents_pool) > 10:
del self.agents_pool[-1]
except Exception as error:
print(
f"Error executing scaling strategy: {error} try again"
" with a new task"
) )
if scale_amount > 0:
for _ in range(scale_amount):
self.agents_pool.append(self.agent())
elif scale_amount < 0:
for _ in range(abs(scale_amount)):
if len(self.agents_pool) > 10:
del self.agents_pool[-1]
def report_agent_metrics(self) -> Dict[str, List[float]]: def report_agent_metrics(self) -> Dict[str, List[float]]:
"""Collects and reports metrics from each agent.""" """Collects and reports metrics from each agent."""

@ -1,7 +1,9 @@
import os import os
from dotenv import load_dotenv from dotenv import load_dotenv
from unittest.mock import patch from unittest.mock import MagicMock, patch
import pytest
from swarms.models import OpenAIChat from swarms.models import OpenAIChat
from swarms.structs import Agent from swarms.structs import Agent
@ -138,3 +140,79 @@ def test_autoscaler_print_dashboard(mock_print):
autoscaler = AutoScaler(initial_agents=5, agent=agent) autoscaler = AutoScaler(initial_agents=5, agent=agent)
autoscaler.print_dashboard() autoscaler.print_dashboard()
mock_print.assert_called() mock_print.assert_called()
@patch("swarms.structs.autoscaler.logging")
def test_check_agent_health_all_healthy(mock_logging):
autoscaler = AutoScaler(initial_agents=5, agent=agent)
for agent in autoscaler.agents_pool:
agent.is_healthy = MagicMock(return_value=True)
autoscaler.check_agent_health()
mock_logging.warning.assert_not_called()
@patch("swarms.structs.autoscaler.logging")
def test_check_agent_health_some_unhealthy(mock_logging):
autoscaler = AutoScaler(initial_agents=5, agent=agent)
for i, agent in enumerate(autoscaler.agents_pool):
agent.is_healthy = MagicMock(return_value=(i % 2 == 0))
autoscaler.check_agent_health()
assert mock_logging.warning.call_count == 2
@patch("swarms.structs.autoscaler.logging")
def test_check_agent_health_all_unhealthy(mock_logging):
autoscaler = AutoScaler(initial_agents=5, agent=agent)
for agent in autoscaler.agents_pool:
agent.is_healthy = MagicMock(return_value=False)
autoscaler.check_agent_health()
assert mock_logging.warning.call_count == 5
@patch("swarms.structs.autoscaler.Agent")
def test_add_agent(mock_agent):
autoscaler = AutoScaler(initial_agents=5, agent=agent)
initial_count = len(autoscaler.agents_pool)
autoscaler.add_agent()
assert len(autoscaler.agents_pool) == initial_count + 1
mock_agent.assert_called_once()
@patch("swarms.structs.autoscaler.Agent")
def test_remove_agent(mock_agent):
autoscaler = AutoScaler(initial_agents=5, agent=agent)
initial_count = len(autoscaler.agents_pool)
autoscaler.remove_agent()
assert len(autoscaler.agents_pool) == initial_count - 1
@patch("swarms.structs.autoscaler.AutoScaler.add_agent")
@patch("swarms.structs.autoscaler.AutoScaler.remove_agent")
def test_scale(mock_remove_agent, mock_add_agent):
autoscaler = AutoScaler(initial_agents=5, agent=agent)
autoscaler.scale(10)
assert mock_add_agent.call_count == 5
assert mock_remove_agent.call_count == 0
mock_add_agent.reset_mock()
mock_remove_agent.reset_mock()
autoscaler.scale(3)
assert mock_add_agent.call_count == 0
assert mock_remove_agent.call_count == 2
def test_add_task_success():
autoscaler = AutoScaler(initial_agents=5)
initial_queue_size = autoscaler.task_queue.qsize()
autoscaler.add_task("test_task")
assert autoscaler.task_queue.qsize() == initial_queue_size + 1
@patch("swarms.structs.autoscaler.queue.Queue.put")
def test_add_task_exception(mock_put):
mock_put.side_effect = Exception("test error")
autoscaler = AutoScaler(initial_agents=5)
with pytest.raises(Exception) as e:
autoscaler.add_task("test_task")
assert str(e.value) == "test error"

Loading…
Cancel
Save