|
|
|
@ -3,7 +3,6 @@ from io import StringIO
|
|
|
|
|
from contextlib import redirect_stdout
|
|
|
|
|
from swarms.utils.class_args_wrapper import print_class_parameters
|
|
|
|
|
from swarms.structs.agent import Agent
|
|
|
|
|
from swarms.structs.autoscaler import Autoscaler
|
|
|
|
|
from fastapi import FastAPI
|
|
|
|
|
from fastapi.testclient import TestClient
|
|
|
|
|
|
|
|
|
@ -23,19 +22,6 @@ def test_print_class_parameters_agent():
|
|
|
|
|
assert output == expected_output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_print_class_parameters_autoscaler():
|
|
|
|
|
f = StringIO()
|
|
|
|
|
with redirect_stdout(f):
|
|
|
|
|
print_class_parameters(Autoscaler)
|
|
|
|
|
output = f.getvalue().strip()
|
|
|
|
|
# Replace with the expected output for Autoscaler class
|
|
|
|
|
expected_output = (
|
|
|
|
|
"Parameter: min_agents, Type: <class 'int'>\nParameter:"
|
|
|
|
|
" max_agents, Type: <class 'int'>"
|
|
|
|
|
)
|
|
|
|
|
assert output == expected_output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_print_class_parameters_error():
|
|
|
|
|
with pytest.raises(TypeError):
|
|
|
|
|
print_class_parameters("Not a class")
|
|
|
|
@ -43,7 +29,7 @@ def test_print_class_parameters_error():
|
|
|
|
|
|
|
|
|
|
@app.get("/parameters/{class_name}")
|
|
|
|
|
def get_parameters(class_name: str):
|
|
|
|
|
classes = {"Agent": Agent, "Autoscaler": Autoscaler}
|
|
|
|
|
classes = {"Agent": Agent}
|
|
|
|
|
if class_name in classes:
|
|
|
|
|
return print_class_parameters(
|
|
|
|
|
classes[class_name], api_format=True
|
|
|
|
@ -63,17 +49,6 @@ def test_get_parameters_agent():
|
|
|
|
|
assert response.json() == expected_output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_get_parameters_autoscaler():
|
|
|
|
|
response = client.get("/parameters/Autoscaler")
|
|
|
|
|
assert response.status_code == 200
|
|
|
|
|
# Replace with the expected output for Autoscaler class
|
|
|
|
|
expected_output = {
|
|
|
|
|
"min_agents": "<class 'int'>",
|
|
|
|
|
"max_agents": "<class 'int'>",
|
|
|
|
|
}
|
|
|
|
|
assert response.json() == expected_output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_get_parameters_not_found():
|
|
|
|
|
response = client.get("/parameters/NonexistentClass")
|
|
|
|
|
assert response.status_code == 200
|
|
|
|
|