You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
swarms/tests/workers/omni_worker.py

59 lines
1.6 KiB

1 year ago
import pytest
1 year ago
from swarms.worker.omni_worker import OmniWorkerAgent
1 year ago
@pytest.fixture
def omni_worker():
    """Provide an OmniWorkerAgent built from placeholder credentials.

    The key, endpoint and type are dummy strings passed positionally in the
    order (api_key, api_endpoint, api_type).
    """
    return OmniWorkerAgent("test-key", "test-endpoint", "test-type")
1 year ago
@pytest.mark.parametrize(
    "data, expected_response",
    [
        (
            {
                "messages": ["Hello"],
                "api_key": "key1",
                "api_type": "type1",
                "api_endpoint": "endpoint1",
            },
            {"response": "Hello back from Huggingface!"},
        ),
        (
            {
                "messages": ["Goodbye"],
                "api_key": "key2",
                "api_type": "type2",
                "api_endpoint": "endpoint2",
            },
            {"response": "Goodbye from Huggingface!"},
        ),
    ],
)
def test_chat_valid_data(mocker, omni_worker, data, expected_response):
    """chat() should return the response produced by the mocked backend.

    ``chat_huggingface`` is replaced with a mock returning
    ``expected_response``. The assertion checks that
    ``OmniWorkerAgent.chat`` hands that value back unchanged for a payload
    that carries messages, api_key, api_type and api_endpoint.
    """
    # Patch the name in the module under test, i.e. where chat() looks it up.
    # The previous target, the placeholder "yourmodule.chat_huggingface",
    # cannot be imported, so mocker.patch raised before chat() ever ran.
    # NOTE(review): assumes swarms/worker/omni_worker.py exposes
    # chat_huggingface at module level (mock.patch raises AttributeError
    # otherwise) — confirm against that module.
    mocker.patch(
        "swarms.worker.omni_worker.chat_huggingface",
        return_value=expected_response,
    )

    assert omni_worker.chat(data) == expected_response
1 year ago
@pytest.mark.parametrize(
    "invalid_data",
    [
        # Only messages: api_key, api_type and api_endpoint are all absent.
        dict(messages=["Hello"]),
        # api_key supplied, but api_type and api_endpoint are absent.
        dict(messages=["Hello"], api_key="key1"),
        # Everything except api_endpoint.
        dict(messages=["Hello"], api_key="key1", api_type="type1"),
    ],
)
def test_chat_invalid_data(omni_worker, invalid_data):
    """chat() must raise ValueError for each incomplete payload above."""
    with pytest.raises(ValueError):
        omni_worker.chat(invalid_data)