from unittest.mock import MagicMock

from swarms.models.fire_function import FireFunctionCaller


def test_fire_function_caller_run(mocker):
    # Create mock model and tokenizer
    model = MagicMock()
    tokenizer = MagicMock()
    mocker.patch.object(FireFunctionCaller, "model", model)
    mocker.patch.object(FireFunctionCaller, "tokenizer", tokenizer)

    # Create mock task and arguments
    task = "Add 2 and 3"
    args = (2, 3)
    kwargs = {}

    # Create mock generated_ids and decoded output
    generated_ids = [1, 2, 3]
    decoded_output = "5"
    model.generate.return_value = generated_ids
    tokenizer.batch_decode.return_value = [decoded_output]

    # Create FireFunctionCaller instance
    fire_function_caller = FireFunctionCaller()

    # Run the function
    fire_function_caller.run(task, *args, **kwargs)

    # Assert model.generate was called with the correct inputs
    model.generate.assert_called_once_with(
        tokenizer.apply_chat_template.return_value,
        max_new_tokens=fire_function_caller.max_tokens,
        *args,
        **kwargs,
    )

    # Assert tokenizer.batch_decode was called with the correct inputs
    tokenizer.batch_decode.assert_called_once_with(generated_ids)

    # Assert the decoded output is printed
    assert decoded_output in mocker.patch.object(
        print, "call_args_list"
    )