You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
swarms/tests/utils/test_litellm_wrapper.py

343 lines
12 KiB

import asyncio
import sys
from loguru import logger
from swarms.utils.litellm_wrapper import LiteLLM
# Swap loguru's default handler for two explicit sinks: a DEBUG-level log
# file that rotates at 1 MB, and an INFO-level stream on stdout.
logger.remove()
logger.add(
    "test_litellm.log",
    level="DEBUG",
    format="{time} | {level} | {message}",
    rotation="1 MB",
)
logger.add(sys.stdout, level="INFO")
# JSON-schema description of the arguments the weather tool accepts.
# Every declared property is also listed under "required".
_WEATHER_PARAMETERS = {
    "type": "object",
    "properties": {
        "location": {
            "type": "string",
            "description": "The city and state, e.g. San Francisco, CA, or a specific geographic coordinate in the format 'latitude,longitude'.",
        },
        "unit": {
            "type": "string",
            "enum": ["celsius", "fahrenheit", "kelvin"],
            "description": "The unit of temperature measurement to be used in the response.",
        },
        "include_forecast": {
            "type": "boolean",
            "description": "Indicates whether to include a short-term weather forecast along with the current conditions.",
        },
        "time": {
            "type": "string",
            "format": "date-time",
            "description": "Optional parameter to specify the time for which the weather data is requested, in ISO 8601 format.",
        },
    },
    "required": ["location", "unit", "include_forecast", "time"],
}

# Tool definitions (a single function tool) passed to LiteLLM through its
# ``tools_list_dictionary`` argument.
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Retrieve detailed current weather information for a specified location, including temperature, humidity, wind speed, and atmospheric conditions.",
            "parameters": _WEATHER_PARAMETERS,
        },
    }
]
# Module-level client used by main(): gpt-4o-mini configured with the weather
# tool schema. NOTE(review): no streaming option is passed here -- confirm
# whether LiteLLM streams by default before relying on a streamed reply.
llm = LiteLLM(
    tools_list_dictionary=tools,
    model_name="gpt-4o-mini",
)
async def main():
    """Ask the module-level ``llm`` a weather question and log what comes back.

    The reply from ``llm.arun`` is logged verbatim, followed by a second line
    stating whether it was ``None``. Nothing is returned.

    NOTE(review): the log text calls the reply a "stream", but no streaming
    option is visibly set on ``llm`` -- confirm what ``arun`` returns here.
    """
    reply = await llm.arun("What is the weather in San Francisco?")
    logger.info(f"Received stream from LLM. {reply}")

    if reply is None:
        logger.info("Stream is None.")
        return

    logger.info(f"Stream is not None. {reply}")
def run_test_suite():
    """Run every LiteLLM wrapper check in order and log a summary report.

    Each check is a small nested function made of plain ``assert``
    statements. A check passes when it returns without raising; any
    ``Exception`` (including ``AssertionError``) is recorded as a failure and
    the suite moves on to the next check.

    NOTE(review): several checks call ``run`` / ``arun`` / ``batched_run`` on
    a real ``LiteLLM`` instance, so they presumably need network access and
    provider credentials -- confirm before running in CI. Because the checks
    rely on ``assert``, running Python with ``-O`` would strip them.

    Returns:
        dict: ``total_tests`` (int), ``passed_tests`` (int),
        ``failed_tests`` (list of ``(test_name, error_message)`` tuples) and
        ``success_rate`` (float, percentage of recorded checks that passed).
    """
    logger.info("Starting LiteLLM Test Suite")
    total_tests = 0
    passed_tests = 0
    failed_tests = []

    # Image reference shared by all vision-related checks.
    image_url = "https://github.com/kyegomez/swarms/blob/master/swarms_logo_new.png?raw=true"

    def log_test_result(test_name: str, passed: bool, error=None):
        """Update the counters and log the outcome of a single check."""
        nonlocal total_tests, passed_tests
        total_tests += 1
        if passed:
            passed_tests += 1
            logger.success(f"{test_name} - PASSED")
        else:
            failed_tests.append((test_name, error))
            logger.error(f"{test_name} - FAILED: {error}")

    def run_case(test_name, check):
        """Execute one check and record exactly one pass/fail result for it."""
        try:
            check()
        except Exception as exc:
            # A bare ``assert`` raises AssertionError with an empty message;
            # fall back to repr() so the report never shows a blank reason.
            log_test_result(test_name, False, str(exc) or repr(exc))
        else:
            log_test_result(test_name, True)

    def check_basic_initialization():
        logger.info("Testing basic initialization")
        llm = LiteLLM()
        assert llm.model_name == "gpt-4o"
        assert llm.temperature == 0.5
        assert llm.max_tokens == 4000

    def check_custom_parameters():
        logger.info("Testing custom parameters")
        llm = LiteLLM(
            model_name="gpt-3.5-turbo",
            temperature=0.7,
            max_tokens=2000,
            system_prompt="You are a helpful assistant",
        )
        assert llm.model_name == "gpt-3.5-turbo"
        assert llm.temperature == 0.7
        assert llm.max_tokens == 2000
        assert llm.system_prompt == "You are a helpful assistant"

    def check_message_preparation():
        logger.info("Testing message preparation")
        llm = LiteLLM(system_prompt="Test system prompt")
        messages = llm._prepare_messages("Test task")
        assert len(messages) == 2
        assert messages[0]["role"] == "system"
        assert messages[0]["content"] == "Test system prompt"
        assert messages[1]["role"] == "user"
        assert messages[1]["content"] == "Test task"

    def check_basic_completion():
        logger.info("Testing basic completion")
        llm = LiteLLM()
        response = llm.run("What is 2+2?")
        assert isinstance(response, str)
        assert len(response) > 0

    def check_tool_usage():
        # main() drives the module-level, tool-enabled client.
        asyncio.run(main())

    def check_tool_calling_setup():
        logger.info("Testing tool calling")
        # Named differently from the module-level ``tools`` to avoid shadowing.
        test_tools = [
            {
                "type": "function",
                "function": {
                    "name": "test_function",
                    "description": "A test function",
                    "parameters": {
                        "type": "object",
                        "properties": {"test": {"type": "string"}},
                    },
                },
            }
        ]
        llm = LiteLLM(
            tools_list_dictionary=test_tools,
            tool_choice="auto",
            model_name="gpt-4o-mini",
        )
        assert llm.tools_list_dictionary == test_tools
        assert llm.tool_choice == "auto"

    def check_async_completion():
        logger.info("Testing async completion")

        async def complete():
            llm = LiteLLM()
            response = await llm.arun("What is 3+3?")
            assert isinstance(response, str)
            assert len(response) > 0

        asyncio.run(complete())

    def check_batched_run():
        logger.info("Testing batched run")
        llm = LiteLLM()
        tasks = ["Task 1", "Task 2", "Task 3"]
        responses = llm.batched_run(tasks, batch_size=2)
        assert isinstance(responses, list)
        assert len(responses) == 3

    def check_vision_support():
        logger.info("Testing vision support check")
        llm = LiteLLM(model_name="gpt-4o")
        # This should not raise an error for vision-capable models.
        llm.check_if_model_supports_vision(img="test.jpg")

    def check_direct_url_processing():
        logger.info("Testing direct URL processing")
        llm = LiteLLM(model_name="gpt-4o")
        should_use_direct = llm._should_use_direct_url(image_url)
        assert isinstance(should_use_direct, bool)

    def check_message_preparation_with_image():
        logger.info("Testing message preparation with image")
        llm = LiteLLM(model_name="gpt-4o")
        messages = llm._prepare_messages("Describe this image", img=image_url)
        assert isinstance(messages, list)
        assert len(messages) >= 1
        # At least one prepared message must carry the user role.
        user_message = next(
            (msg for msg in messages if msg["role"] == "user"), None
        )
        assert user_message is not None

    def check_vision_processing_methods():
        logger.info("Testing vision processing methods")
        llm = LiteLLM(model_name="gpt-4o")
        # Each call gets its own fresh, empty message list.
        processed_messages = llm.openai_vision_processing(
            "Describe this image",
            image_url,
            [],
        )
        assert isinstance(processed_messages, list)
        assert len(processed_messages) > 0

        llm_anthropic = LiteLLM(model_name="claude-3-5-sonnet-20241022")
        processed_messages_anthropic = llm_anthropic.anthropic_vision_processing(
            "Describe this image",
            image_url,
            [],
        )
        assert isinstance(processed_messages_anthropic, list)
        assert len(processed_messages_anthropic) > 0

    def check_local_vs_url_detection():
        logger.info("Testing local vs URL detection")
        llm = LiteLLM(model_name="gpt-4o")
        is_url_direct = llm._should_use_direct_url(image_url)
        is_local_direct = llm._should_use_direct_url("/path/to/local/image.jpg")
        # URLs may or may not be used directly; local files never should be.
        assert isinstance(is_url_direct, bool)
        assert isinstance(is_local_direct, bool)
        assert is_local_direct is False

    def check_vision_message_structure():
        logger.info("Testing vision message structure")
        llm = LiteLLM(model_name="gpt-4o")
        result = llm.vision_processing(
            task="What do you see?",
            image=image_url,
            messages=[],
        )
        assert isinstance(result, list)
        assert len(result) > 0
        # The last message is expected to be the user message.
        user_msg = result[-1]
        assert user_msg["role"] == "user"
        assert "content" in user_msg

    # Ordered (report name, check) pairs; order determines log/report order.
    cases = [
        ("Basic Initialization", check_basic_initialization),
        ("Custom Parameters", check_custom_parameters),
        ("Message Preparation", check_message_preparation),
        ("Basic Completion", check_basic_completion),
        ("Tool Usage", check_tool_usage),
        ("Tool Calling Setup", check_tool_calling_setup),
        ("Async Completion", check_async_completion),
        ("Batched Run", check_batched_run),
        ("Vision Support Check", check_vision_support),
        ("Direct URL Processing", check_direct_url_processing),
        ("Message Preparation with Image", check_message_preparation_with_image),
        ("Vision Processing Methods", check_vision_processing_methods),
        ("Local vs URL Detection", check_local_vs_url_detection),
        ("Vision Message Structure", check_vision_message_structure),
    ]
    for test_name, check in cases:
        run_case(test_name, check)

    # Generate test report (guarded so an empty case list cannot divide by zero).
    success_rate = (passed_tests / total_tests) * 100 if total_tests else 0.0
    logger.info("\n=== Test Suite Report ===")
    logger.info(f"Total Tests: {total_tests}")
    logger.info(f"Passed Tests: {passed_tests}")
    logger.info(f"Failed Tests: {len(failed_tests)}")
    logger.info(f"Success Rate: {success_rate:.2f}%")
    if failed_tests:
        logger.error("\nFailed Tests Details:")
        for test_name, error in failed_tests:
            logger.error(f"{test_name}: {error}")

    return {
        "total_tests": total_tests,
        "passed_tests": passed_tests,
        "failed_tests": failed_tests,
        "success_rate": success_rate,
    }
if __name__ == "__main__":
    # Script entry point: run the whole suite, then point the reader at the
    # log file that holds the DEBUG-level output.
    test_results = run_test_suite()
    completion_message = (
        "Test suite completed. Check test_litellm.log for detailed logs."
    )
    logger.info(completion_message)