From 7cd6f25353f062557c1fecce317ebc31cb4bd8f5 Mon Sep 17 00:00:00 2001 From: Kye Date: Mon, 25 Dec 2023 18:05:04 -0500 Subject: [PATCH] [FEAT][print_class_parameters] --- pyproject.toml | 2 +- swarms/utils/class_args_wrapper.py | 36 +++++++++ tests/utils/test_class_args_wrapper.py | 81 +++++++++++++++++++ ...els_torch.py => test_load_models_torch.py} | 0 ....py => test_prep_torch_model_inference.py} | 0 5 files changed, 118 insertions(+), 1 deletion(-) create mode 100644 swarms/utils/class_args_wrapper.py create mode 100644 tests/utils/test_class_args_wrapper.py rename tests/utils/{load_models_torch.py => test_load_models_torch.py} (100%) rename tests/utils/{prep_torch_model_inference.py => test_prep_torch_model_inference.py} (100%) diff --git a/pyproject.toml b/pyproject.toml index 76150dde..907d1914 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "swarms" -version = "2.4.0" +version = "2.4.1" description = "Swarms - Pytorch" license = "MIT" authors = ["Kye Gomez "] diff --git a/swarms/utils/class_args_wrapper.py b/swarms/utils/class_args_wrapper.py new file mode 100644 index 00000000..f24932cf --- /dev/null +++ b/swarms/utils/class_args_wrapper.py @@ -0,0 +1,36 @@ +import inspect + + +def print_class_parameters(cls, api_format: bool = False): + """ + Print the parameters of a class constructor. + + Parameters: + cls (type): The class to inspect. + + Example: + >>> print_class_parameters(Agent) + Parameter: x, Type: + Parameter: y, Type: + """ + try: + # Get the parameters of the class constructor + sig = inspect.signature(cls.__init__) + params = sig.parameters + + if api_format: + param_dict = {} + for name, param in params.items(): + if name == "self": + continue + param_dict[name] = str(param.annotation) + return param_dict + + # Print the parameters + for name, param in params.items(): + if name == "self": + continue + print(f"Parameter: {name}, Type: {param.annotation}") + + except Exception as e: + print(f"An error occurred while inspecting the class: {e}") diff --git a/tests/utils/test_class_args_wrapper.py b/tests/utils/test_class_args_wrapper.py new file mode 100644 index 00000000..d846f786 --- /dev/null +++ b/tests/utils/test_class_args_wrapper.py @@ -0,0 +1,81 @@ +import pytest +from io import StringIO +from contextlib import redirect_stdout +from swarms.utils.class_args_wrapper import print_class_parameters +from swarms.structs import Agent, Autoscaler +from fastapi import FastAPI +from fastapi.testclient import TestClient +from swarms.utils.class_args_wrapper import print_class_parameters +from swarms.structs import Agent, Autoscaler + +app = FastAPI() + + +def test_print_class_parameters_agent(): + f = StringIO() + with redirect_stdout(f): + print_class_parameters(Agent) + output = f.getvalue().strip() + # Replace with the expected output for Agent class + expected_output = ( + "Parameter: name, Type: \nParameter: age, Type:" + " " + ) + assert output == expected_output + + +def test_print_class_parameters_autoscaler(): + f = StringIO() + with redirect_stdout(f): + print_class_parameters(Autoscaler) + output = f.getvalue().strip() + # Replace with the expected output for Autoscaler class + expected_output = ( + "Parameter: min_agents, Type: \nParameter:" + " max_agents, Type: " + ) + assert output == expected_output + + +def test_print_class_parameters_error(): + with pytest.raises(TypeError): + print_class_parameters("Not a class") + + +@app.get("/parameters/{class_name}") +def get_parameters(class_name: str): + classes = {"Agent": Agent, "Autoscaler": Autoscaler} + if class_name in classes: + return print_class_parameters( + classes[class_name], api_format=True + ) + else: + return {"error": "Class not found"} + + +client = TestClient(app) + + +def test_get_parameters_agent(): + response = client.get("/parameters/Agent") + assert response.status_code == 200 + # Replace with the expected output for Agent class + expected_output = {"x": "", "y": ""} + assert response.json() == expected_output + + +def test_get_parameters_autoscaler(): + response = client.get("/parameters/Autoscaler") + assert response.status_code == 200 + # Replace with the expected output for Autoscaler class + expected_output = { + "min_agents": "", + "max_agents": "", + } + assert response.json() == expected_output + + +def test_get_parameters_not_found(): + response = client.get("/parameters/NonexistentClass") + assert response.status_code == 200 + assert response.json() == {"error": "Class not found"} diff --git a/tests/utils/load_models_torch.py b/tests/utils/test_load_models_torch.py similarity index 100% rename from tests/utils/load_models_torch.py rename to tests/utils/test_load_models_torch.py diff --git a/tests/utils/prep_torch_model_inference.py b/tests/utils/test_prep_torch_model_inference.py similarity index 100% rename from tests/utils/prep_torch_model_inference.py rename to tests/utils/test_prep_torch_model_inference.py