parent 612beb4df3
commit 36a092f6e6
@ -0,0 +1,315 @@
import random
from threading import Lock
from time import sleep
from typing import Callable, List, Optional

from swarms import Agent
from swarms.structs.base_swarm import BaseSwarm
from swarms.utils.loguru_logger import logger


class AgentLoadBalancer(BaseSwarm):
    """
    A load balancer class that distributes tasks among a group of agents.

    Args:
        agents (List[Agent]): The list of agents available for task execution.
        max_retries (int, optional): The maximum number of retries for a task if it fails. Defaults to 3.
        max_loops (int, optional): The maximum number of loops to run a task. Defaults to 5.
        cooldown_time (float, optional): The cooldown time between retries. Defaults to 0.

    Attributes:
        agents (List[Agent]): The list of agents available for task execution.
        agent_status (Dict[str, bool]): The status of each agent, indicating whether it is available or not.
        max_retries (int): The maximum number of retries for a task if it fails.
        max_loops (int): The maximum number of loops to run a task.
        agent_performance (Dict[str, Dict[str, int]]): The performance statistics of each agent.
        lock (Lock): A lock to ensure thread safety.
        cooldown_time (float): The cooldown time between retries.

    Methods:
        get_available_agent: Get an available agent for task execution.
        set_agent_status: Set the status of an agent.
        update_performance: Update the performance statistics of an agent.
        log_performance: Log the performance statistics of all agents.
        run_task: Run a single task using an available agent.
        run_multiple_tasks: Run multiple tasks using available agents.
        run_task_with_loops: Run a task multiple times using an available agent.
        run_task_with_callback: Run a task with a callback function.
        run_task_with_timeout: Run a task with a timeout.

    """

    def __init__(
        self,
        agents: List[Agent],
        max_retries: int = 3,
        max_loops: int = 5,
        cooldown_time: float = 0,
    ):
        self.agents = agents
        self.agent_status = {agent.agent_name: True for agent in agents}
        self.max_retries = max_retries
        self.max_loops = max_loops
        self.agent_performance = {
            agent.agent_name: {"success_count": 0, "failure_count": 0}
            for agent in agents
        }
        self.lock = Lock()
        self.cooldown_time = cooldown_time

    def get_available_agent(self) -> Optional[Agent]:
        """
        Get an available agent for task execution.

        Returns:
            Optional[Agent]: An available agent, or None if no agents are available.

        """
        with self.lock:
            available_agents = [
                agent
                for agent in self.agents
                if self.agent_status[agent.agent_name]
            ]
            if not available_agents:
                return None
            return random.choice(available_agents)

    def set_agent_status(self, agent: Agent, status: bool) -> None:
        """
        Set the status of an agent.

        Args:
            agent (Agent): The agent whose status needs to be set.
            status (bool): The status to set for the agent.

        """
        with self.lock:
            self.agent_status[agent.agent_name] = status

    def update_performance(self, agent: Agent, success: bool) -> None:
        """
        Update the performance statistics of an agent.

        Args:
            agent (Agent): The agent whose performance statistics need to be updated.
            success (bool): Whether the task executed by the agent was successful or not.

        """
        with self.lock:
            if success:
                self.agent_performance[agent.agent_name][
                    "success_count"
                ] += 1
            else:
                self.agent_performance[agent.agent_name][
                    "failure_count"
                ] += 1

    def log_performance(self) -> None:
        """
        Log the performance statistics of all agents.

        """
        logger.info("Agent Performance:")
        for agent_name, stats in self.agent_performance.items():
            logger.info(f"{agent_name}: {stats}")

    def run_task(self, task: str, *args, **kwargs) -> str:
        """
        Run a single task using an available agent.

        Args:
            task (str): The task to be executed.

        Returns:
            str: The output of the task execution.

        Raises:
            RuntimeError: If no available agents are found to handle the request.

        """
        try:
            retries = 0
            while retries < self.max_retries:
                agent = self.get_available_agent()
                if not agent:
                    raise RuntimeError(
                        "No available agents to handle the request."
                    )

                try:
                    self.set_agent_status(agent, False)
                    output = agent.run(task, *args, **kwargs)
                    self.update_performance(agent, True)
                    return output
                except Exception as e:
                    logger.error(
                        f"Error with agent {agent.agent_name}: {e}"
                    )
                    self.update_performance(agent, False)
                    retries += 1
                    sleep(self.cooldown_time)
                    if retries >= self.max_retries:
                        raise e
                finally:
                    self.set_agent_status(agent, True)
        except Exception as e:
            logger.error(
                f"Task failed: {e}. Try again by optimizing the code."
            )
            raise RuntimeError(f"Task failed: {e}")

    def run_multiple_tasks(self, tasks: List[str]) -> List[str]:
        """
        Run multiple tasks using available agents.

        Args:
            tasks (List[str]): The list of tasks to be executed.

        Returns:
            List[str]: The list of outputs corresponding to each task execution.

        """
        results = []
        for task in tasks:
            result = self.run_task(task)
            results.append(result)
        return results

    def run_task_with_loops(self, task: str) -> List[str]:
        """
        Run a task multiple times using an available agent.

        Args:
            task (str): The task to be executed.

        Returns:
            List[str]: The list of outputs corresponding to each task execution.

        """
        results = []
        for _ in range(self.max_loops):
            result = self.run_task(task)
            results.append(result)
        return results

    def run_task_with_callback(
        self, task: str, callback: Callable[[str], None]
    ) -> None:
        """
        Run a task with a callback function.

        Args:
            task (str): The task to be executed.
            callback (Callable[[str], None]): The callback function to be called with the task result.

        """
        try:
            result = self.run_task(task)
            callback(result)
        except Exception as e:
            logger.error(f"Task failed: {e}")
            callback(str(e))

    def run_task_with_timeout(self, task: str, timeout: float) -> str:
        """
        Run a task with a timeout.

        Args:
            task (str): The task to be executed.
            timeout (float): The maximum time (in seconds) to wait for the task to complete.

        Returns:
            str: The output of the task execution.

        Raises:
            TimeoutError: If the task execution exceeds the specified timeout.
            Exception: If the task execution raises an exception.

        """
        import threading

        result = [None]
        exception = [None]

        def target():
            try:
                result[0] = self.run_task(task)
            except Exception as e:
                exception[0] = e

        thread = threading.Thread(target=target)
        thread.start()
        thread.join(timeout)

        if thread.is_alive():
            raise TimeoutError(f"Task timed out after {timeout} seconds.")

        if exception[0]:
            raise exception[0]

        return result[0]


# if __name__ == "__main__":
#     from swarms import llama3Hosted
#     # User initializes the agents
#     agents = [
#         Agent(
#             agent_name="Transcript Generator 1",
#             agent_description="Generate a transcript for a youtube video on what swarms are!",
#             llm=llama3Hosted(),
#             max_loops="auto",
#             autosave=True,
#             dashboard=False,
#             streaming_on=True,
#             verbose=True,
#             stopping_token="<DONE>",
#             interactive=True,
#             state_save_file_type="json",
#             saved_state_path="transcript_generator_1.json",
#         ),
#         Agent(
#             agent_name="Transcript Generator 2",
#             agent_description="Generate a transcript for a youtube video on what swarms are!",
#             llm=llama3Hosted(),
#             max_loops="auto",
#             autosave=True,
#             dashboard=False,
#             streaming_on=True,
#             verbose=True,
#             stopping_token="<DONE>",
#             interactive=True,
#             state_save_file_type="json",
#             saved_state_path="transcript_generator_2.json",
#         ),
#         # Add more agents as needed
#     ]

#     load_balancer = AgentLoadBalancer(agents)

#     try:
#         result = load_balancer.run_task("Generate a transcript for a youtube video on what swarms are!")
#         print(result)

#         # Running multiple tasks
#         tasks = [
#             "Generate a transcript for a youtube video on what swarms are!",
#             "Generate a transcript for a youtube video on AI advancements!"
#         ]
#         results = load_balancer.run_multiple_tasks(tasks)
#         for res in results:
#             print(res)

#         # Running task with loops
#         loop_results = load_balancer.run_task_with_loops("Generate a transcript for a youtube video on what swarms are!")
#         for res in loop_results:
#             print(res)

#     except RuntimeError as e:
#         print(f"Error: {e}")

#     # Log performance
#     load_balancer.log_performance()
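
# ---------------------------------------------------------------------------
# Illustrative usage sketch. It exercises only the AgentLoadBalancer API
# defined above (run_task, run_multiple_tasks, run_task_with_timeout,
# log_performance), mirroring the commented-out example. The EchoAgent stub
# below is hypothetical: the balancer only touches `agent.agent_name` and
# `agent.run(task)`, so any real LLM-backed swarms Agent can be substituted.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class EchoAgent:
        """Minimal stand-in agent that simply echoes the task back."""

        def __init__(self, agent_name: str):
            self.agent_name = agent_name

        def run(self, task: str, *args, **kwargs) -> str:
            return f"[{self.agent_name}] handled: {task}"

    balancer = AgentLoadBalancer(
        agents=[EchoAgent("echo-1"), EchoAgent("echo-2")],
        max_retries=2,
        cooldown_time=0.1,
    )

    # Single task, a batch of tasks, and a bounded-time call.
    print(balancer.run_task("Summarize what swarms are."))
    print(balancer.run_multiple_tasks(["Task A", "Task B"]))
    print(balancer.run_task_with_timeout("Possibly slow task", timeout=5.0))

    # Per-agent success/failure counts collected by the balancer.
    balancer.log_performance()
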
@ -1,59 +0,0 @@
from unittest.mock import MagicMock, patch

import pytest

from swarms.agents.multion_agent import MultiOnAgent


@patch("swarms.agents.multion_agent.multion")
def test_multion_agent_run(mock_multion):
    mock_response = MagicMock()
    mock_response.result = "result"
    mock_response.status = "status"
    mock_response.lastUrl = "lastUrl"
    mock_multion.browse.return_value = mock_response

    agent = MultiOnAgent(
        multion_api_key="test_key",
        max_steps=5,
        starting_url="https://www.example.com",
    )
    result, status, last_url = agent.run("task")

    assert result == "result"
    assert status == "status"
    assert last_url == "lastUrl"
    mock_multion.browse.assert_called_once_with(
        {
            "cmd": "task",
            "url": "https://www.example.com",
            "maxSteps": 5,
        }
    )


# Additional tests for different tasks
@pytest.mark.parametrize(
    "task", ["task1", "task2", "task3", "task4", "task5"]
)
@patch("swarms.agents.multion_agent.multion")
def test_multion_agent_run_different_tasks(mock_multion, task):
    mock_response = MagicMock()
    mock_response.result = "result"
    mock_response.status = "status"
    mock_response.lastUrl = "lastUrl"
    mock_multion.browse.return_value = mock_response

    agent = MultiOnAgent(
        multion_api_key="test_key",
        max_steps=5,
        starting_url="https://www.example.com",
    )
    result, status, last_url = agent.run(task)

    assert result == "result"
    assert status == "status"
    assert last_url == "lastUrl"
    mock_multion.browse.assert_called_once_with(
        {"cmd": task, "url": "https://www.example.com", "maxSteps": 5}
    )
@ -1,276 +0,0 @@
import os
from unittest.mock import MagicMock, patch

import pytest
from dotenv import load_dotenv

from swarms.models import OpenAIChat
from swarms.structs import Agent
from swarms.structs.autoscaler import AutoScaler

load_dotenv()

api_key = os.environ.get("OPENAI_API_KEY")
llm = OpenAIChat(
    temperature=0.5,
    openai_api_key=api_key,
)
agent = Agent(llm=llm, max_loops=1)


def test_autoscaler_init():
    autoscaler = AutoScaler(initial_agents=5, agent=agent)
    assert autoscaler.initial_agents == 5
    assert autoscaler.scale_up_factor == 1
    assert autoscaler.idle_threshold == 0.2
    assert autoscaler.busy_threshold == 0.7
    assert autoscaler.autoscale is True
    assert autoscaler.min_agents == 1
    assert autoscaler.max_agents == 5
    assert autoscaler.custom_scale_strategy is None
    assert len(autoscaler.agents_pool) == 5
    assert autoscaler.task_queue.empty() is True


def test_autoscaler_add_task():
    autoscaler = AutoScaler(initial_agents=5, agent=agent)
    autoscaler.add_task("task1")
    assert autoscaler.task_queue.empty() is False


def test_autoscaler_run():
    autoscaler = AutoScaler(initial_agents=5, agent=agent)
    out = autoscaler.run(
        agent.id,
        "Generate a 10,000 word blog on health and wellness.",
    )
    assert out == "Generate a 10,000 word blog on health and wellness."


def test_autoscaler_add_agent():
    autoscaler = AutoScaler(initial_agents=5, agent=agent)
    autoscaler.add_agent(agent)
    assert len(autoscaler.agents_pool) == 6


def test_autoscaler_remove_agent():
    autoscaler = AutoScaler(initial_agents=5, agent=agent)
    autoscaler.remove_agent(agent)
    assert len(autoscaler.agents_pool) == 4


def test_autoscaler_get_agent():
    autoscaler = AutoScaler(initial_agents=5, agent=agent)
    agent = autoscaler.get_agent()
    assert isinstance(agent, Agent)


def test_autoscaler_get_agent_by_id():
    autoscaler = AutoScaler(initial_agents=5, agent=agent)
    agent = autoscaler.get_agent_by_id(agent.id)
    assert isinstance(agent, Agent)


def test_autoscaler_get_agent_by_id_not_found():
    autoscaler = AutoScaler(initial_agents=5, agent=agent)
    agent = autoscaler.get_agent_by_id("fake_id")
    assert agent is None


@patch("swarms.swarms.Agent.is_healthy")
def test_autoscaler_check_agent_health(mock_is_healthy):
    mock_is_healthy.side_effect = [False, True, True, True, True]
    autoscaler = AutoScaler(initial_agents=5, agent=agent)
    autoscaler.check_agent_health()
    assert mock_is_healthy.call_count == 5


def test_autoscaler_balance_load():
    autoscaler = AutoScaler(initial_agents=5, agent=agent)
    autoscaler.add_task("task1")
    autoscaler.add_task("task2")
    autoscaler.balance_load()
    assert autoscaler.task_queue.empty()


def test_autoscaler_set_scaling_strategy():
    autoscaler = AutoScaler(initial_agents=5, agent=agent)

    def strategy(x, y):
        return x - y

    autoscaler.set_scaling_strategy(strategy)
    assert autoscaler.custom_scale_strategy == strategy


def test_autoscaler_execute_scaling_strategy():
    autoscaler = AutoScaler(initial_agents=5, agent=agent)

    def strategy(x, y):
        return x - y

    autoscaler.set_scaling_strategy(strategy)
    autoscaler.add_task("task1")
    autoscaler.execute_scaling_strategy()
    assert len(autoscaler.agents_pool) == 4


def test_autoscaler_report_agent_metrics():
    autoscaler = AutoScaler(initial_agents=5, agent=agent)
    metrics = autoscaler.report_agent_metrics()
    assert set(metrics.keys()) == {
        "completion_time",
        "success_rate",
        "error_rate",
    }


@patch("swarms.swarms.AutoScaler.report_agent_metrics")
def test_autoscaler_report(mock_report_agent_metrics):
    autoscaler = AutoScaler(initial_agents=5, agent=agent)
    autoscaler.report()
    mock_report_agent_metrics.assert_called_once()


@patch("builtins.print")
def test_autoscaler_print_dashboard(mock_print):
    autoscaler = AutoScaler(initial_agents=5, agent=agent)
    autoscaler.print_dashboard()
    mock_print.assert_called()


@patch("swarms.structs.autoscaler.logging")
def test_check_agent_health_all_healthy(mock_logging):
    autoscaler = AutoScaler(initial_agents=5, agent=agent)
    for agent in autoscaler.agents_pool:
        agent.is_healthy = MagicMock(return_value=True)
    autoscaler.check_agent_health()
    mock_logging.warning.assert_not_called()


@patch("swarms.structs.autoscaler.logging")
def test_check_agent_health_some_unhealthy(mock_logging):
    autoscaler = AutoScaler(initial_agents=5, agent=agent)
    for i, agent in enumerate(autoscaler.agents_pool):
        agent.is_healthy = MagicMock(return_value=(i % 2 == 0))
    autoscaler.check_agent_health()
    assert mock_logging.warning.call_count == 2


@patch("swarms.structs.autoscaler.logging")
def test_check_agent_health_all_unhealthy(mock_logging):
    autoscaler = AutoScaler(initial_agents=5, agent=agent)
    for agent in autoscaler.agents_pool:
        agent.is_healthy = MagicMock(return_value=False)
    autoscaler.check_agent_health()
    assert mock_logging.warning.call_count == 5


@patch("swarms.structs.autoscaler.Agent")
|
||||
def test_add_agent(mock_agent):
|
||||
autoscaler = AutoScaler(initial_agents=5, agent=agent)
|
||||
initial_count = len(autoscaler.agents_pool)
|
||||
autoscaler.add_agent()
|
||||
assert len(autoscaler.agents_pool) == initial_count + 1
|
||||
mock_agent.assert_called_once()
|
||||
|
||||
|
||||
@patch("swarms.structs.autoscaler.Agent")
|
||||
def test_remove_agent(mock_agent):
|
||||
autoscaler = AutoScaler(initial_agents=5, agent=agent)
|
||||
initial_count = len(autoscaler.agents_pool)
|
||||
autoscaler.remove_agent()
|
||||
assert len(autoscaler.agents_pool) == initial_count - 1
|
||||
|
||||
|
||||
@patch("swarms.structs.autoscaler.AutoScaler.add_agent")
|
||||
@patch("swarms.structs.autoscaler.AutoScaler.remove_agent")
|
||||
def test_scale(mock_remove_agent, mock_add_agent):
|
||||
autoscaler = AutoScaler(initial_agents=5, agent=agent)
|
||||
autoscaler.scale(10)
|
||||
assert mock_add_agent.call_count == 5
|
||||
assert mock_remove_agent.call_count == 0
|
||||
|
||||
mock_add_agent.reset_mock()
|
||||
mock_remove_agent.reset_mock()
|
||||
|
||||
autoscaler.scale(3)
|
||||
assert mock_add_agent.call_count == 0
|
||||
assert mock_remove_agent.call_count == 2
|
||||
|
||||
|
||||
def test_add_task_success():
|
||||
autoscaler = AutoScaler(initial_agents=5)
|
||||
initial_queue_size = autoscaler.task_queue.qsize()
|
||||
autoscaler.add_task("test_task")
|
||||
assert autoscaler.task_queue.qsize() == initial_queue_size + 1
|
||||
|
||||
|
||||
@patch("swarms.structs.autoscaler.queue.Queue.put")
|
||||
def test_add_task_exception(mock_put):
|
||||
mock_put.side_effect = Exception("test error")
|
||||
autoscaler = AutoScaler(initial_agents=5)
|
||||
with pytest.raises(Exception) as e:
|
||||
autoscaler.add_task("test_task")
|
||||
assert str(e.value) == "test error"
|
||||
|
||||
|
||||
def test_autoscaler_initialization():
|
||||
autoscaler = AutoScaler(
|
||||
initial_agents=5,
|
||||
scale_up_factor=2,
|
||||
idle_threshold=0.1,
|
||||
busy_threshold=0.8,
|
||||
agent=agent,
|
||||
)
|
||||
assert isinstance(autoscaler, AutoScaler)
|
||||
assert autoscaler.scale_up_factor == 2
|
||||
assert autoscaler.idle_threshold == 0.1
|
||||
assert autoscaler.busy_threshold == 0.8
|
||||
assert len(autoscaler.agents_pool) == 5
|
||||
|
||||
|
||||
def test_autoscaler_add_task():
|
||||
autoscaler = AutoScaler(agent=agent)
|
||||
autoscaler.add_task("task1")
|
||||
assert autoscaler.task_queue.qsize() == 1
|
||||
|
||||
|
||||
def test_autoscaler_scale_up():
|
||||
autoscaler = AutoScaler(
|
||||
initial_agents=5, scale_up_factor=2, agent=agent
|
||||
)
|
||||
autoscaler.scale_up()
|
||||
assert len(autoscaler.agents_pool) == 10
|
||||
|
||||
|
||||
def test_autoscaler_scale_down():
|
||||
autoscaler = AutoScaler(initial_agents=5, agent=agent)
|
||||
autoscaler.scale_down()
|
||||
assert len(autoscaler.agents_pool) == 4
|
||||
|
||||
|
||||
@patch("swarms.swarms.AutoScaler.scale_up")
|
||||
@patch("swarms.swarms.AutoScaler.scale_down")
|
||||
def test_autoscaler_monitor_and_scale(mock_scale_down, mock_scale_up):
|
||||
autoscaler = AutoScaler(initial_agents=5, agent=agent)
|
||||
autoscaler.add_task("task1")
|
||||
autoscaler.monitor_and_scale()
|
||||
mock_scale_up.assert_called_once()
|
||||
mock_scale_down.assert_called_once()
|
||||
|
||||
|
||||
@patch("swarms.swarms.AutoScaler.monitor_and_scale")
|
||||
@patch("swarms.swarms.agent.run")
|
||||
def test_autoscaler_start(mock_run, mock_monitor_and_scale):
|
||||
autoscaler = AutoScaler(initial_agents=5, agent=agent)
|
||||
autoscaler.add_task("task1")
|
||||
autoscaler.start()
|
||||
mock_run.assert_called_once()
|
||||
mock_monitor_and_scale.assert_called_once()
|
||||
|
||||
|
||||
def test_autoscaler_del_agent():
|
||||
autoscaler = AutoScaler(initial_agents=5, agent=agent)
|
||||
autoscaler.del_agent()
|
||||
assert len(autoscaler.agents_pool) == 4
|
@ -1,72 +0,0 @@
import pytest

from swarms.structs.graph_workflow import GraphWorkflow


@pytest.fixture
def graph_workflow():
    return GraphWorkflow()


def test_init(graph_workflow):
    assert graph_workflow.graph == {}
    assert graph_workflow.entry_point is None


def test_add(graph_workflow):
    graph_workflow.add("node1", "value1")
    assert "node1" in graph_workflow.graph
    assert graph_workflow.graph["node1"]["value"] == "value1"
    assert graph_workflow.graph["node1"]["edges"] == {}


def test_set_entry_point(graph_workflow):
    graph_workflow.add("node1", "value1")
    graph_workflow.set_entry_point("node1")
    assert graph_workflow.entry_point == "node1"


def test_set_entry_point_nonexistent_node(graph_workflow):
    with pytest.raises(ValueError, match="Node does not exist in graph"):
        graph_workflow.set_entry_point("nonexistent")


def test_add_edge(graph_workflow):
    graph_workflow.add("node1", "value1")
    graph_workflow.add("node2", "value2")
    graph_workflow.add_edge("node1", "node2")
    assert "node2" in graph_workflow.graph["node1"]["edges"]


def test_add_edge_nonexistent_node(graph_workflow):
    graph_workflow.add("node1", "value1")
    with pytest.raises(ValueError, match="Node does not exist in graph"):
        graph_workflow.add_edge("node1", "nonexistent")


def test_add_conditional_edges(graph_workflow):
    graph_workflow.add("node1", "value1")
    graph_workflow.add("node2", "value2")
    graph_workflow.add_conditional_edges(
        "node1", "condition1", {"condition_value1": "node2"}
    )
    assert "node2" in graph_workflow.graph["node1"]["edges"]


def test_add_conditional_edges_nonexistent_node(graph_workflow):
    graph_workflow.add("node1", "value1")
    with pytest.raises(ValueError, match="Node does not exist in graph"):
        graph_workflow.add_conditional_edges(
            "node1", "condition1", {"condition_value1": "nonexistent"}
        )


def test_run(graph_workflow):
    graph_workflow.add("node1", "value1")
    graph_workflow.set_entry_point("node1")
    assert graph_workflow.run() == graph_workflow.graph


def test_run_no_entry_point(graph_workflow):
    with pytest.raises(ValueError, match="Entry point not set"):
        graph_workflow.run()