"""Tests for ``LlamaFunctionCaller``: loading, custom functions, and prompting."""

import pytest

from swarms.models.llama_function_caller import LlamaFunctionCaller


# ---------------------------------------------------------------------------
# Fixtures and helpers
# ---------------------------------------------------------------------------


@pytest.fixture
def llama_caller():
    """Return a ``LlamaFunctionCaller`` built with its default arguments."""
    return LlamaFunctionCaller()


def _make_cuda_caller():
    """Build a caller with an explicit model id, cache directory and runtime.

    NOTE(review): the model id and cache directory look like template
    placeholders -- replace them with real values before relying on the
    tests that use this helper. ``runtime="cuda"`` presumably requires a
    GPU; confirm against the ``LlamaFunctionCaller`` implementation.
    """
    return LlamaFunctionCaller(
        model_id="Your-Model-ID",
        cache_dir="Your-Cache-Directory",
        runtime="cuda",
    )


def _register_sample_function(caller):
    """Register a two-argument ``sample_function`` on ``caller``.

    The registered callable takes ``arg1`` and ``arg2`` and returns a
    formatted string echoing both values.
    """

    def sample_function(arg1, arg2):
        return f"Sample function called with args: {arg1}, {arg2}"

    caller.add_func(
        name="sample_function",
        function=sample_function,
        description="Sample custom function",
        arguments=[
            {
                "name": "arg1",
                "type": "string",
                "description": "Argument 1",
            },
            {
                "name": "arg2",
                "type": "string",
                "description": "Argument 2",
            },
        ],
    )


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------


def test_llama_model_loading(llama_caller):
    """A freshly built caller exposes both a model and a tokenizer."""
    assert llama_caller.model is not None
    assert llama_caller.tokenizer is not None


def test_llama_custom_function(llama_caller):
    """A registered custom function can be called by name with keyword args."""
    _register_sample_function(llama_caller)

    result = llama_caller.call_function(
        "sample_function", arg1="arg1_value", arg2="arg2_value"
    )

    assert (
        result == "Sample function called with args: arg1_value, arg2_value"
    )


def test_llama_streaming(llama_caller):
    """Calling the caller with a user prompt returns a non-empty string.

    NOTE(review): despite the test name, no streaming option is enabled
    here; this only exercises the default call path.
    """
    user_prompt = "Tell me about the tallest mountain in the world."

    response = llama_caller(user_prompt)

    assert isinstance(response, str)
    assert len(response) > 0


def test_llama_custom_function_not_found(llama_caller):
    """Calling a function that was never registered raises ``ValueError``."""
    with pytest.raises(ValueError):
        llama_caller.call_function("non_existent_function")


def test_llama_custom_function_invalid_arguments(llama_caller):
    """Omitting a required argument of a registered function raises ``TypeError``."""
    _register_sample_function(llama_caller)

    # ``arg2`` is deliberately left out.
    with pytest.raises(TypeError):
        llama_caller.call_function("sample_function", arg1="arg1_value")


def test_llama_custom_runtime():
    """Prompting works on a caller built with an explicit runtime."""
    llama_caller = _make_cuda_caller()
    user_prompt = "Tell me about the tallest mountain in the world."

    response = llama_caller(user_prompt)

    assert isinstance(response, str)
    assert len(response) > 0


def test_llama_cache():
    """Issuing the same prompt twice yields the same response."""
    llama_caller = _make_cuda_caller()
    user_prompt = "Tell me about the tallest mountain in the world."

    # First request, intended to populate the cache.
    response = llama_caller(user_prompt)

    # NOTE(review): nothing in this file shows the model reading a
    # ``from_cache`` attribute -- confirm this flag has any effect. The
    # equality check below also presumably requires deterministic generation.
    llama_caller.model.from_cache = True
    response_from_cache = llama_caller(user_prompt)

    assert response == response_from_cache


def test_llama_response_length():
    """The response to a long prompt stays within 500 whitespace-separated words."""
    llama_caller = _make_cuda_caller()

    # 101 whitespace-separated words ("A" followed by 100 x "test").
    long_prompt = "A " + "test " * 100

    response = llama_caller(long_prompt)

    # Word count is only a proxy for token count.
    # NOTE(review): 500 presumably mirrors the caller's max_tokens setting --
    # confirm against the ``LlamaFunctionCaller`` defaults.
    assert len(response.split()) <= 500


# Add more test cases as needed to cover different aspects of your code
# ...