You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
116 lines
3.6 KiB
116 lines
3.6 KiB
1 year ago
|
import pytest
|
||
|
from swarms.models.llama_function_caller import LlamaFunctionCaller
|
||
|
|
||
|
|
||
|
# Define fixtures if needed
|
||
|
@pytest.fixture
def llama_caller():
    """Provide a ``LlamaFunctionCaller`` constructed with no arguments."""
    # Nothing is passed to the constructor, so whatever defaults the class
    # defines for its model settings are the ones in effect.
    caller = LlamaFunctionCaller()
    return caller
|
||
|
|
||
|
|
||
|
# Basic test for model loading
|
||
|
def test_llama_model_loading(llama_caller):
    """Check that the caller exposes a non-None ``model`` and ``tokenizer``."""
    # Checked in the same order as before: model first, then tokenizer.
    for attribute in ("model", "tokenizer"):
        assert getattr(llama_caller, attribute) is not None
|
||
|
|
||
|
|
||
|
# Test adding and calling custom functions
|
||
|
def test_llama_custom_function(llama_caller):
    """Register a two-argument function and invoke it by name.

    The function is added through ``add_func`` and then run through
    ``call_function`` with both keyword arguments supplied; the returned
    value must equal the string the function formats.
    """

    def sample_function(arg1, arg2):
        return f"Sample function called with args: {arg1}, {arg2}"

    # One spec entry per parameter of ``sample_function``.
    argument_specs = [
        {"name": "arg1", "type": "string", "description": "Argument 1"},
        {"name": "arg2", "type": "string", "description": "Argument 2"},
    ]
    llama_caller.add_func(
        name="sample_function",
        function=sample_function,
        description="Sample custom function",
        arguments=argument_specs,
    )

    expected = "Sample function called with args: arg1_value, arg2_value"
    outcome = llama_caller.call_function(
        "sample_function",
        arg1="arg1_value",
        arg2="arg2_value",
    )
    assert outcome == expected
|
||
|
|
||
|
|
||
|
# Test streaming user prompts
|
||
|
def test_llama_streaming(llama_caller):
    """Call the caller with a prompt and expect a non-empty string back."""
    reply = llama_caller("Tell me about the tallest mountain in the world.")
    assert isinstance(reply, str)
    # For a ``str`` truthiness is exactly "length greater than zero".
    assert reply
|
||
|
|
||
|
|
||
|
# Test custom function not found
|
||
|
def test_llama_custom_function_not_found(llama_caller):
    """Expect ``ValueError`` when calling a name that was never registered."""
    unregistered_name = "non_existent_function"
    with pytest.raises(ValueError):
        llama_caller.call_function(unregistered_name)
|
||
|
|
||
|
|
||
|
# Test invalid arguments for custom function
|
||
|
def test_llama_custom_function_invalid_arguments(llama_caller):
    """Expect ``TypeError`` when a registered function is called short an argument.

    ``sample_function`` takes ``arg1`` and ``arg2``; the call below supplies
    only ``arg1``.
    """

    def sample_function(arg1, arg2):
        return f"Sample function called with args: {arg1}, {arg2}"

    registration = {
        "name": "sample_function",
        "function": sample_function,
        "description": "Sample custom function",
        "arguments": [
            {"name": "arg1", "type": "string", "description": "Argument 1"},
            {"name": "arg2", "type": "string", "description": "Argument 2"},
        ],
    }
    llama_caller.add_func(**registration)

    # ``arg2`` is deliberately omitted.
    with pytest.raises(TypeError):
        llama_caller.call_function("sample_function", arg1="arg1_value")
|
||
|
|
||
|
|
||
|
# Test streaming with custom runtime
|
||
|
def test_llama_custom_runtime():
    """Build a caller with explicit settings and expect a non-empty string reply."""
    # NOTE(review): the model_id and cache_dir values look like placeholders --
    # confirm they are meant to be replaced before this test is run.
    settings = {
        "model_id": "Your-Model-ID",
        "cache_dir": "Your-Cache-Directory",
        "runtime": "cuda",
    }
    caller = LlamaFunctionCaller(**settings)

    prompt = "Tell me about the tallest mountain in the world."
    reply = caller(prompt)

    assert isinstance(reply, str)
    assert len(reply) > 0
|
||
|
|
||
|
|
||
|
# Test caching functionality
|
||
|
def test_llama_cache():
    """Send the same prompt twice and expect the two replies to be equal."""
    caller = LlamaFunctionCaller(
        model_id="Your-Model-ID",
        cache_dir="Your-Cache-Directory",
        runtime="cuda",
    )
    prompt = "Tell me about the tallest mountain in the world."

    # First request, intended to populate the cache.
    first_reply = caller(prompt)

    # NOTE(review): ``from_cache`` is assigned on the model object here;
    # nothing visible in this file shows that attribute being read -- confirm
    # it actually influences the second call.
    caller.model.from_cache = True
    second_reply = caller(prompt)

    assert first_reply == second_reply
|
||
|
|
||
|
|
||
|
# Test response length within max_tokens limit
|
||
|
def test_llama_response_length():
    """Check that the reply to a long prompt has at most 500 whitespace-split words."""
    caller = LlamaFunctionCaller(
        model_id="Your-Model-ID",
        cache_dir="Your-Cache-Directory",
        runtime="cuda",
    )

    # "A " followed by "test " repeated 100 times: 101 whitespace-separated
    # words in total.
    long_prompt = "A " + 100 * "test "

    reply = caller(long_prompt)

    # The bound is applied to whitespace-split words of the reply, not to
    # model tokens.
    word_count = len(reply.split())
    assert word_count <= 500
|
||
|
|
||
|
|
||
|
# Add more test cases as needed to cover different aspects of your code
|
||
|
|
||
|
# ...
|