diff --git a/tests/agents/workers/multi_model_worker.py b/tests/agents/workers/multi_model_worker.py new file mode 100644 index 00000000..fe8b0d6e --- /dev/null +++ b/tests/agents/workers/multi_model_worker.py @@ -0,0 +1,25 @@ +import pytest +from unittest.mock import Mock +from swarms.workers.multi_modal_worker import MultiModalVisualAgent, MultiModalVisualAgentTool + +@pytest.fixture +def multimodal_agent(): + # Mock the MultiModalVisualAgent + mock_agent = Mock(spec=MultiModalVisualAgent) + mock_agent.run_text.return_value = "Expected output from agent" + return mock_agent + +@pytest.fixture +def multimodal_agent_tool(multimodal_agent): + # Use the mocked MultiModalVisualAgent in the MultiModalVisualAgentTool + return MultiModalVisualAgentTool(multimodal_agent) + +@pytest.mark.parametrize("text_input, expected_output", [ + ("Hello, world!", "Expected output from agent"), + ("Another task", "Expected output from agent"), +]) +def test_run(multimodal_agent_tool, text_input, expected_output): + assert multimodal_agent_tool._run(text_input) == expected_output + + # You can also test if the MultiModalVisualAgent's run_text method was called with the right argument + multimodal_agent_tool.agent.run_text.assert_called_with(text_input) diff --git a/tests/agents/workers/omni_worker.py b/tests/agents/workers/omni_worker.py new file mode 100644 index 00000000..169272cd --- /dev/null +++ b/tests/agents/workers/omni_worker.py @@ -0,0 +1,33 @@ +import pytest +from unittest.mock import Mock +from swarms.worker.omni_worker import OmniWorkerAgent # replace 'yourmodule' with the actual module name + +@pytest.fixture +def omni_worker(): + api_key = 'test-key' + api_endpoint = 'test-endpoint' + api_type = 'test-type' + return OmniWorkerAgent(api_key, api_endpoint, api_type) + +@pytest.mark.parametrize("data, expected_response", [ + ( + {"messages": ["Hello"], "api_key": "key1", "api_type": "type1", "api_endpoint": "endpoint1"}, + {"response": "Hello back from Huggingface!"} + ), + ( + {"messages": ["Goodbye"], "api_key": "key2", "api_type": "type2", "api_endpoint": "endpoint2"}, + {"response": "Goodbye from Huggingface!"} + ), +]) +def test_chat_valid_data(mocker, omni_worker, data, expected_response): + mocker.patch('yourmodule.chat_huggingface', return_value=expected_response) # replace 'yourmodule' with actual module name + assert omni_worker.chat(data) == expected_response + +@pytest.mark.parametrize("invalid_data", [ + {"messages": ["Hello"]}, # missing api_key, api_type and api_endpoint + {"messages": ["Hello"], "api_key": "key1"}, # missing api_type and api_endpoint + {"messages": ["Hello"], "api_key": "key1", "api_type": "type1"}, # missing api_endpoint +]) +def test_chat_invalid_data(omni_worker, invalid_data): + with pytest.raises(ValueError): + omni_worker.chat(invalid_data)