You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
45 lines
1.3 KiB
from unittest.mock import MagicMock
|
|
|
|
from swarms.models.fire_function import FireFunctionCaller
|
|
|
|
|
|
def test_fire_function_caller_run(mocker):
    """Check that ``FireFunctionCaller.run`` drives the model and prints its output.

    The model and tokenizer are replaced with ``MagicMock`` objects patched onto
    the ``FireFunctionCaller`` class, so the test only asserts on how ``run``
    uses them:

    * ``model.generate`` receives the chat-template output, the extra
      positional/keyword arguments, and ``max_new_tokens`` set to the
      instance's ``max_tokens``;
    * ``tokenizer.batch_decode`` receives the generated ids;
    * the decoded text is passed to ``print``.

    Args:
        mocker: pytest-mock fixture used for patching; every patch is undone
            automatically when the test finishes.
    """
    # Create mock model and tokenizer.
    # NOTE(review): these are patched onto the class, not the instance. If
    # FireFunctionCaller.__init__ assigns self.model / self.tokenizer itself,
    # those instance attributes would shadow the mocks — confirm against the
    # implementation.
    model = MagicMock()
    tokenizer = MagicMock()
    mocker.patch.object(FireFunctionCaller, "model", model)
    mocker.patch.object(FireFunctionCaller, "tokenizer", tokenizer)

    # Create mock task and arguments.
    task = "Add 2 and 3"
    args = (2, 3)
    kwargs = {}

    # Canned generated ids and decoded output returned by the mocks.
    generated_ids = [1, 2, 3]
    decoded_output = "5"
    model.generate.return_value = generated_ids
    tokenizer.batch_decode.return_value = [decoded_output]

    # Create FireFunctionCaller instance.
    fire_function_caller = FireFunctionCaller()

    # Replace print *before* run() so its calls are recorded. The builtin
    # print has no call_args_list of its own and cannot be patched via
    # patch.object(print, ...); patching the name in builtins works instead.
    mock_print = mocker.patch("builtins.print")

    # Run the function.
    fire_function_caller.run(task, *args, **kwargs)

    # Assert model.generate was called with the correct inputs. The extra
    # positional args come right after the chat-template output; this is the
    # same call shape as writing *args after the keyword argument, just clearer.
    model.generate.assert_called_once_with(
        tokenizer.apply_chat_template.return_value,
        *args,
        max_new_tokens=fire_function_caller.max_tokens,
        **kwargs,
    )

    # Assert tokenizer.batch_decode was called with the correct inputs.
    tokenizer.batch_decode.assert_called_once_with(generated_ids)

    # Assert the decoded output was printed, either on its own or inside a
    # larger printed object (e.g. the whole decoded list).
    printed = [
        str(arg)
        for call in mock_print.call_args_list
        for arg in call.args
    ]
    assert any(decoded_output in text for text in printed), printed
|