parent
6c609e316e
commit
4f33a7f81a
@ -0,0 +1,25 @@
|
||||
import pytest
|
||||
from unittest.mock import Mock
|
||||
from swarms.workers.multi_modal_worker import MultiModalVisualAgent, MultiModalVisualAgentTool
|
||||
|
||||
@pytest.fixture
|
||||
def multimodal_agent():
|
||||
# Mock the MultiModalVisualAgent
|
||||
mock_agent = Mock(spec=MultiModalVisualAgent)
|
||||
mock_agent.run_text.return_value = "Expected output from agent"
|
||||
return mock_agent
|
||||
|
||||
@pytest.fixture
|
||||
def multimodal_agent_tool(multimodal_agent):
|
||||
# Use the mocked MultiModalVisualAgent in the MultiModalVisualAgentTool
|
||||
return MultiModalVisualAgentTool(multimodal_agent)
|
||||
|
||||
@pytest.mark.parametrize("text_input, expected_output", [
|
||||
("Hello, world!", "Expected output from agent"),
|
||||
("Another task", "Expected output from agent"),
|
||||
])
|
||||
def test_run(multimodal_agent_tool, text_input, expected_output):
|
||||
assert multimodal_agent_tool._run(text_input) == expected_output
|
||||
|
||||
# You can also test if the MultiModalVisualAgent's run_text method was called with the right argument
|
||||
multimodal_agent_tool.agent.run_text.assert_called_with(text_input)
|
@ -0,0 +1,33 @@
|
||||
import pytest
|
||||
from unittest.mock import Mock
|
||||
from swarms.worker.omni_worker import OmniWorkerAgent # replace 'yourmodule' with the actual module name
|
||||
|
||||
@pytest.fixture
|
||||
def omni_worker():
|
||||
api_key = 'test-key'
|
||||
api_endpoint = 'test-endpoint'
|
||||
api_type = 'test-type'
|
||||
return OmniWorkerAgent(api_key, api_endpoint, api_type)
|
||||
|
||||
@pytest.mark.parametrize("data, expected_response", [
|
||||
(
|
||||
{"messages": ["Hello"], "api_key": "key1", "api_type": "type1", "api_endpoint": "endpoint1"},
|
||||
{"response": "Hello back from Huggingface!"}
|
||||
),
|
||||
(
|
||||
{"messages": ["Goodbye"], "api_key": "key2", "api_type": "type2", "api_endpoint": "endpoint2"},
|
||||
{"response": "Goodbye from Huggingface!"}
|
||||
),
|
||||
])
|
||||
def test_chat_valid_data(mocker, omni_worker, data, expected_response):
|
||||
mocker.patch('yourmodule.chat_huggingface', return_value=expected_response) # replace 'yourmodule' with actual module name
|
||||
assert omni_worker.chat(data) == expected_response
|
||||
|
||||
@pytest.mark.parametrize("invalid_data", [
|
||||
{"messages": ["Hello"]}, # missing api_key, api_type and api_endpoint
|
||||
{"messages": ["Hello"], "api_key": "key1"}, # missing api_type and api_endpoint
|
||||
{"messages": ["Hello"], "api_key": "key1", "api_type": "type1"}, # missing api_endpoint
|
||||
])
|
||||
def test_chat_invalid_data(omni_worker, invalid_data):
|
||||
with pytest.raises(ValueError):
|
||||
omni_worker.chat(invalid_data)
|
Loading…
Reference in new issue