import pytest

from swarms.worker.omni_worker import OmniWorkerAgent


@pytest.fixture
def omni_worker():
    api_key = "test-key"
    api_endpoint = "test-endpoint"
    api_type = "test-type"
    return OmniWorkerAgent(api_key, api_endpoint, api_type)


@pytest.mark.parametrize(
    "data, expected_response",
    [
        (
            {
                "messages": ["Hello"],
                "api_key": "key1",
                "api_type": "type1",
                "api_endpoint": "endpoint1",
            },
            {"response": "Hello back from Huggingface!"},
        ),
        (
            {
                "messages": ["Goodbye"],
                "api_key": "key2",
                "api_type": "type2",
                "api_endpoint": "endpoint2",
            },
            {"response": "Goodbye from Huggingface!"},
        ),
    ],
)
def test_chat_valid_data(mocker, omni_worker, data, expected_response):
    mocker.patch(
        "yourmodule.chat_huggingface", return_value=expected_response
    )  # replace 'yourmodule' with actual module name
    assert omni_worker.chat(data) == expected_response


@pytest.mark.parametrize(
    "invalid_data",
    [
        {"messages": ["Hello"]},  # missing api_key, api_type and api_endpoint
        {"messages": ["Hello"], "api_key": "key1"},  # missing api_type and api_endpoint
        {
            "messages": ["Hello"],
            "api_key": "key1",
            "api_type": "type1",
        },  # missing api_endpoint
    ],
)
def test_chat_invalid_data(omni_worker, invalid_data):
    with pytest.raises(ValueError):
        omni_worker.chat(invalid_data)