You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
swarms/tests/models/test_llama_function_caller.py

143 lines
3.9 KiB

import pytest
from swarms.models.llama_function_caller import LlamaFunctionCaller
# Shared fixture: one LlamaFunctionCaller for the whole module.
@pytest.fixture(scope="module")
def llama_caller():
    """Provide a default-constructed ``LlamaFunctionCaller``.

    Module scope means the caller (and the model/tokenizer it exposes) is
    constructed once and reused by every test in this file that requests
    the fixture, instead of being rebuilt for each test.

    NOTE(review): tests that call ``add_func`` mutate this shared instance.
    The tests that do so here re-register "sample_function" before calling
    it, so the shared state is benign; new tests should not depend on what
    earlier tests registered.
    """
    return LlamaFunctionCaller()
# Basic test for model loading
def test_llama_model_loading(llama_caller):
    """After construction the caller exposes both a model and a tokenizer."""
    # Checked lazily and in order: model first, then tokenizer.
    for attribute_name in ("model", "tokenizer"):
        assert getattr(llama_caller, attribute_name) is not None
# Test adding and calling custom functions
def test_llama_custom_function(llama_caller):
    """Register a two-argument function and call it back by name.

    Asserts that ``call_function`` invoked with keyword arguments returns
    exactly what the registered ``sample_function`` produces for them.
    """

    def sample_function(arg1, arg2):
        return f"Sample function called with args: {arg1}, {arg2}"

    llama_caller.add_func(
        name="sample_function",
        function=sample_function,
        description="Sample custom function",
        arguments=[
            {
                "name": "arg1",
                "type": "string",
                "description": "Argument 1",
            },
            {
                "name": "arg2",
                "type": "string",
                "description": "Argument 2",
            },
        ],
    )

    result = llama_caller.call_function(
        "sample_function", arg1="arg1_value", arg2="arg2_value"
    )
    assert (
        result
        == "Sample function called with args: arg1_value, arg2_value"
    )
# Test streaming user prompts
def test_llama_streaming(llama_caller):
    """Calling the caller with a prompt returns a non-empty string.

    NOTE(review): despite the name, nothing here observes streaming
    (incremental output); only the final return value of
    ``llama_caller(prompt)`` is checked.
    """
    reply = llama_caller(
        "Tell me about the tallest mountain in the world."
    )
    assert isinstance(reply, str)
    assert len(reply) > 0
# Test custom function not found
def test_llama_custom_function_not_found(llama_caller):
    """Calling a name that was never registered must raise ``ValueError``."""
    unknown_name = "non_existent_function"
    with pytest.raises(ValueError):
        llama_caller.call_function(unknown_name)
# Test invalid arguments for custom function
def test_llama_custom_function_invalid_arguments(llama_caller):
    """Calling a registered function with a missing argument raises ``TypeError``.

    ``sample_function`` requires both ``arg1`` and ``arg2`` but only ``arg1``
    is supplied. Presumably the TypeError comes from Python's own argument
    check when ``call_function`` invokes the registered callable; confirm
    against ``LlamaFunctionCaller.call_function``.
    """

    def sample_function(arg1, arg2):
        return f"Sample function called with args: {arg1}, {arg2}"

    llama_caller.add_func(
        name="sample_function",
        function=sample_function,
        description="Sample custom function",
        arguments=[
            {
                "name": "arg1",
                "type": "string",
                "description": "Argument 1",
            },
            {
                "name": "arg2",
                "type": "string",
                "description": "Argument 2",
            },
        ],
    )

    with pytest.raises(TypeError):
        llama_caller.call_function(
            "sample_function", arg1="arg1_value"
        )
# Test streaming with custom runtime
def test_llama_custom_runtime():
    """Build a caller with explicit model/cache/runtime settings and query it.

    NOTE(review): "Your-Model-ID" and "Your-Cache-Directory" look like
    placeholders, and ``runtime="cuda"`` presumably requires a GPU; this
    test likely cannot pass until real values are supplied.
    """
    settings = dict(
        model_id="Your-Model-ID",
        cache_dir="Your-Cache-Directory",
        runtime="cuda",
    )
    caller = LlamaFunctionCaller(**settings)

    output = caller("Tell me about the tallest mountain in the world.")
    assert isinstance(output, str)
    assert len(output) > 0
# Test caching functionality
def test_llama_cache():
    """Asking the same prompt twice should give the same response.

    NOTE(review): ``model.from_cache = True`` only assigns an attribute on
    the model object; nothing visible here shows any code reading it. This
    test may therefore only check that generation is deterministic, not
    that a cache is actually used. Confirm against the model implementation.
    """
    caller = LlamaFunctionCaller(
        model_id="Your-Model-ID",
        cache_dir="Your-Cache-Directory",
        runtime="cuda",
    )
    prompt = "Tell me about the tallest mountain in the world."

    # First request: intended to populate the cache.
    first_reply = caller(prompt)

    # Second request: intended to be served from the cache.
    caller.model.from_cache = True
    second_reply = caller(prompt)

    assert first_reply == second_reply
# Test that the response to a long prompt stays within a word budget
def test_llama_response_length():
    """Check that the response to a ~100-word prompt has at most 500 words.

    The assertion counts whitespace-separated words in the response, not
    tokenizer tokens, so it is only a rough proxy for a token limit.

    NOTE(review): the caller is built without an explicit token limit, so
    this relies on whatever default limit ``LlamaFunctionCaller`` applies;
    confirm that default keeps responses under ``max_words``.
    """
    llama_caller = LlamaFunctionCaller(
        model_id="Your-Model-ID",
        cache_dir="Your-Cache-Directory",
        runtime="cuda",
    )
    # "A" followed by "test" 100 times: 101 whitespace-separated words.
    long_prompt = "A " + "test " * 100
    max_words = 500

    response = llama_caller(long_prompt)
    assert len(response.split()) <= max_words
# Add more test cases as needed to cover different aspects of your code
# ...