|
|
|
@ -1,13 +1,13 @@
|
|
|
|
|
import pytest
|
|
|
|
|
from unittest.mock import MagicMock
|
|
|
|
|
from swarms.models.fire_function import FireFunctionCaller
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_fire_function_caller_run(mocker):
|
|
|
|
|
# Create mock model and tokenizer
|
|
|
|
|
model = MagicMock()
|
|
|
|
|
tokenizer = MagicMock()
|
|
|
|
|
mocker.patch.object(FireFunctionCaller, 'model', model)
|
|
|
|
|
mocker.patch.object(FireFunctionCaller, 'tokenizer', tokenizer)
|
|
|
|
|
mocker.patch.object(FireFunctionCaller, "model", model)
|
|
|
|
|
mocker.patch.object(FireFunctionCaller, "tokenizer", tokenizer)
|
|
|
|
|
|
|
|
|
|
# Create mock task and arguments
|
|
|
|
|
task = "Add 2 and 3"
|
|
|
|
@ -38,4 +38,6 @@ def test_fire_function_caller_run(mocker):
|
|
|
|
|
tokenizer.batch_decode.assert_called_once_with(generated_ids)
|
|
|
|
|
|
|
|
|
|
# Assert the decoded output is printed
|
|
|
|
|
assert decoded_output in mocker.patch.object(print, 'call_args_list')
|
|
|
|
|
assert decoded_output in mocker.patch.object(
|
|
|
|
|
print, "call_args_list"
|
|
|
|
|
)
|
|
|
|
|