You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
swarms/swarms/utils/litellm_wrapper.py

1131 lines
40 KiB

import asyncio
import base64
import socket
import traceback
import uuid
from pathlib import Path
from typing import List, Optional
import litellm
import requests
from litellm import completion, supports_vision
from loguru import logger
from pydantic import BaseModel
class LiteLLMException(Exception):
    """Error type for failures originating from LiteLLM operations."""
class NetworkConnectionError(Exception):
    """Error raised when a network connectivity problem is detected."""
def get_audio_base64(audio_source: str, timeout: float = 30.0) -> str:
    """
    Return the audio at ``audio_source`` as a base64-encoded string.

    Sources starting with ``http://`` or ``https://`` are downloaded with
    ``requests``; anything else is treated as a local file path and read in
    binary mode.

    Args:
        audio_source (str): The path or URL to the audio file.
        timeout (float, optional): Seconds to wait for the remote server when
            ``audio_source`` is a URL. Defaults to 30.0. Ignored for local files.

    Returns:
        str: The base64-encoded string of the audio data.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the server does not respond within ``timeout``.
        FileNotFoundError: If the local audio file does not exist.
    """
    if audio_source.startswith(("http://", "https://")):
        # requests has no default timeout; without one an unresponsive host
        # would block the caller indefinitely.
        response = requests.get(audio_source, timeout=timeout)
        response.raise_for_status()
        audio_data = response.content
    else:
        with open(audio_source, "rb") as file:
            audio_data = file.read()

    return base64.b64encode(audio_data).decode("utf-8")
def get_image_base64(image_source: str, timeout: float = 30.0) -> str:
    """
    Return the image at ``image_source`` as a base64 data URI.

    A source that already starts with ``data:image`` is returned unchanged.
    HTTP/HTTPS sources are downloaded with ``requests``; anything else is read
    as a local file. The MIME type is chosen from the file extension and falls
    back to ``image/jpeg`` for unknown extensions.

    Args:
        image_source (str): The path, URL, or data URI of the image.
        timeout (float, optional): Seconds to wait for the remote server when
            ``image_source`` is a URL. Defaults to 30.0. Ignored otherwise.

    Returns:
        str: The image as a ``data:<mime>;base64,<payload>`` string.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the server does not respond within ``timeout``.
        FileNotFoundError: If the local image file does not exist.
    """
    from urllib.parse import urlparse

    if image_source.startswith("data:image"):
        return image_source

    if image_source.startswith(("http://", "https://")):
        # requests has no default timeout; set one so a stalled host cannot
        # block the caller forever.
        response = requests.get(image_source, timeout=timeout)
        response.raise_for_status()
        image_data = response.content
        # Use only the URL path for the extension: a query string or fragment
        # ("pic.png?size=large") would otherwise end up inside the suffix and
        # make every such image fall back to image/jpeg.
        suffix_source = urlparse(image_source).path
    else:
        with open(image_source, "rb") as file:
            image_data = file.read()
        suffix_source = image_source

    extension = Path(suffix_source).suffix.lower()
    mime_type_mapping = {
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".png": "image/png",
        ".gif": "image/gif",
        ".webp": "image/webp",
        ".bmp": "image/bmp",
        ".tiff": "image/tiff",
        ".svg": "image/svg+xml",
    }
    mime_type = mime_type_mapping.get(extension, "image/jpeg")

    encoded_string = base64.b64encode(image_data).decode("utf-8")
    return f"data:{mime_type};base64,{encoded_string}"
def save_base64_as_image(
    base64_data: str,
    output_dir: str = "images",
) -> str:
    """
    Decode base64 image data and write it to a randomly named file.

    Accepts either a raw base64 string or a data URI
    (``data:image/...;base64,...``). For data URIs the file extension is taken
    from the MIME type in the header; raw strings and unknown MIME types are
    saved with a ``.jpg`` extension. The file name is a fresh UUID4.

    Args:
        base64_data (str): Raw base64 payload or an image data URI.
        output_dir (str, optional): Directory to write into; it is created if
            missing. Defaults to "images". ``None`` means the current working
            directory.

    Returns:
        str: Path of the file that was written.

    Raises:
        ValueError: If a data URI has no comma separating header and payload.
        IOError: If decoding the payload or writing the file fails.
    """
    import os

    target_dir = os.getcwd() if output_dir is None else output_dir
    os.makedirs(target_dir, exist_ok=True)

    # Defaults for a raw base64 string; overridden when a data URI is given.
    payload = base64_data
    mime_type = "image/jpeg"
    if base64_data.startswith("data:image"):
        try:
            header, payload = base64_data.split(",", 1)
            mime_type = header.split(":")[1].split(";")[0]
        except (ValueError, IndexError):
            raise ValueError("Invalid data URI format")

    extension_by_mime = {
        "image/jpeg": ".jpg",
        "image/jpg": ".jpg",
        "image/png": ".png",
        "image/gif": ".gif",
        "image/webp": ".webp",
        "image/bmp": ".bmp",
        "image/tiff": ".tiff",
        "image/svg+xml": ".svg",
    }
    extension = extension_by_mime.get(mime_type, ".jpg")
    file_path = os.path.join(target_dir, f"{uuid.uuid4()}{extension}")

    try:
        logger.debug(
            f"Attempting to decode base64 data of length: {len(payload)}"
        )
        logger.debug(f"Base64 data (first 100 chars): {payload[:100]}...")

        decoded_bytes = base64.b64decode(payload)
        with open(file_path, "wb") as out_file:
            out_file.write(decoded_bytes)

        logger.info(f"Image saved successfully to: {file_path}")
        return file_path
    except Exception as e:
        # Any decode or write failure is reported to the caller as IOError.
        logger.error(
            f"Base64 decoding failed. Data length: {len(payload)}"
        )
        logger.error(f"First 100 chars of data: {payload[:100]}...")
        raise IOError(f"Failed to save image: {str(e)}")
def gemini_output_img_handler(response: any):
    """
    Post-process a Gemini response whose content may be a base64 image.

    Reads ``response.choices[0].message.content``. When that content is a
    string that (ignoring surrounding whitespace) begins with one of the known
    ``data:image/...;base64,`` prefixes, the image is written to disk via
    ``save_base64_as_image`` and the resulting file path is returned. Any other
    content is returned untouched.

    Args:
        response (any): Response object exposing ``choices[0].message.content``.

    Returns:
        The saved image path for base64 image content, otherwise the original
        content.
    """
    image_uri_prefixes = (
        "data:image/jpeg;base64,",
        "data:image/jpg;base64,",
        "data:image/png;base64,",
        "data:image/gif;base64,",
        "data:image/webp;base64,",
        "data:image/bmp;base64,",
        "data:image/tiff;base64,",
        "data:image/svg+xml;base64,",
    )
    content = response.choices[0].message.content

    # Non-string content (None, lists, ...) can never be a data URI.
    if not isinstance(content, str):
        return content

    if content.strip().startswith(image_uri_prefixes):
        return save_base64_as_image(base64_data=content)
    return content
class LiteLLM:
    """
    Wrapper around the `litellm` completion API.

    An instance stores the model configuration and a base message list
    (optionally seeded with a system prompt) and exposes `run` / `__call__`
    to send a task, optionally together with an image, to the configured model.
    """
def __init__(
    self,
    model_name: str = "gpt-4.1",
    system_prompt: str = None,
    stream: bool = False,
    temperature: float = 0.5,
    max_tokens: int = 4000,
    ssl_verify: bool = False,
    max_completion_tokens: int = 4000,
    tools_list_dictionary: List[dict] = None,
    tool_choice: str = "auto",
    parallel_tool_calls: bool = False,
    audio: str = None,
    retries: int = 3,
    verbose: bool = False,
    caching: bool = False,
    mcp_call: bool = False,
    top_p: float = 1.0,
    functions: List[dict] = None,
    return_all: bool = False,
    base_url: str = None,
    api_key: str = None,
    api_version: str = None,
    reasoning_effort: str = None,
    drop_params: bool = True,
    thinking_tokens: int = None,
    reasoning_enabled: bool = False,
    response_format: any = None,
    *args,
    **kwargs,
):
    """
    Store the model configuration and apply the litellm module settings.

    Args:
        model_name (str, optional): Model identifier. Defaults to "gpt-4.1".
        system_prompt (str, optional): When given, becomes the first message
            of the instance's message list. Defaults to None.
        stream (bool, optional): Whether to request a streamed response.
            Defaults to False.
        temperature (float, optional): Sampling temperature. Defaults to 0.5.
        max_tokens (int, optional): Maximum tokens to generate. Defaults to 4000.
        ssl_verify (bool, optional): Value assigned to ``litellm.ssl_verify``.
            Defaults to False.
        max_completion_tokens (int, optional): Stored on the instance.
            Defaults to 4000.
        tools_list_dictionary (List[dict], optional): Tool definitions.
            Defaults to None.
        tool_choice (str, optional): Tool choice strategy. Defaults to "auto".
        parallel_tool_calls (bool, optional): Whether parallel tool calls are
            requested. Defaults to False.
        audio (str, optional): Audio input path, stored on the instance.
            Defaults to None.
        retries (int, optional): Value assigned to ``litellm.num_retries``.
            Defaults to 3.
        verbose (bool, optional): Value assigned to ``litellm.set_verbose``.
            Defaults to False.
        caching (bool, optional): Whether to enable caching. Defaults to False.
        mcp_call (bool, optional): Selects the MCP output shape in
            ``output_for_tools``. Defaults to False.
        top_p (float, optional): Top-p sampling parameter. Defaults to 1.0.
        functions (List[dict], optional): Function definitions. Defaults to None.
        return_all (bool, optional): Whether ``run`` returns the full response
            dump. Defaults to False.
        base_url (str, optional): Base URL for the API. Defaults to None.
        api_key (str, optional): API key, stored on the instance. Defaults to None.
        api_version (str, optional): API version. Defaults to None.
        reasoning_effort (str, optional): Reasoning effort setting.
            Defaults to None.
        drop_params (bool, optional): Value assigned to ``litellm.drop_params``.
            Defaults to True.
        thinking_tokens (int, optional): Token budget used for the ``thinking``
            parameter. Defaults to None.
        reasoning_enabled (bool, optional): Whether reasoning output handling
            is enabled. Defaults to False.
        response_format (any, optional): Response format passed through to the
            completion call. Defaults to None.
        *args: Extra positional arguments, kept for use by ``run``. A single
            dictionary is merged into the completion parameters.
        **kwargs: Extra keyword arguments, kept for use by ``run`` and merged
            into the completion parameters before the runtime kwargs.
    """
    # Model / endpoint selection
    self.model_name = model_name
    self.base_url = base_url
    self.api_key = api_key
    self.api_version = api_version

    # Generation settings
    self.system_prompt = system_prompt
    self.stream = stream
    self.temperature = temperature
    self.top_p = top_p
    self.max_tokens = max_tokens
    self.max_completion_tokens = max_completion_tokens
    self.response_format = response_format
    self.caching = caching
    self.return_all = return_all

    # Tool and function calling
    self.tools_list_dictionary = tools_list_dictionary
    self.tool_choice = tool_choice
    self.parallel_tool_calls = parallel_tool_calls
    self.functions = functions
    self.mcp_call = mcp_call

    # Reasoning configuration
    self.reasoning_effort = reasoning_effort
    self.thinking_tokens = thinking_tokens
    self.reasoning_enabled = reasoning_enabled

    # Miscellaneous
    self.audio = audio
    self.ssl_verify = ssl_verify
    self.verbose = verbose
    self.modalities = []

    # The base message list starts with the system prompt when one is given.
    self.messages = (
        []
        if self.system_prompt is None
        else [{"role": "system", "content": self.system_prompt}]
    )

    # These are module-level attributes of litellm, not per-instance state.
    litellm.set_verbose = verbose
    litellm.ssl_verify = ssl_verify
    litellm.num_retries = retries
    litellm.drop_params = drop_params

    # Kept so that run() can merge them into the completion parameters.
    self.init_args = args
    self.init_kwargs = kwargs

    # NOTE: reasoning_check() is not invoked here; the call is currently disabled.
def reasoning_check(self):
    """
    Adjust temperature, thinking_tokens and top_p for reasoning-enabled models.

    Does nothing unless ``self.reasoning_enabled`` is truthy. Otherwise:

    * If ``litellm.supports_reasoning`` reports support for the model, the
      temperature is set to 1; if not, warnings are logged and the temperature
      is left alone.
    * If the model name refers to Anthropic, a missing ``thinking_tokens``
      budget is filled in with a quarter of ``max_tokens`` and ``top_p`` is
      set to 0.95.
    """
    if not self.reasoning_enabled:
        return

    supports_reasoning = litellm.supports_reasoning(
        model=self.model_name
    )
    uses_anthropic = self.check_if_model_name_uses_anthropic(
        model_name=self.model_name
    )

    if supports_reasoning:
        logger.info(
            f"Model {self.model_name} supports reasoning and reasoning enabled is set to {self.reasoning_enabled}. Temperature will be set to 1 for better reasoning as some models may not work with low temperature."
        )
        self.temperature = 1
    else:
        logger.warning(
            f"Model {self.model_name} does not support reasoning and reasoning enabled is set to {self.reasoning_enabled}. Temperature will not be set to 1."
        )
        logger.warning(
            f"Model {self.model_name} may or may not support reasoning and reasoning enabled is set to {self.reasoning_enabled}"
        )

    if uses_anthropic:
        if self.thinking_tokens is None:
            logger.info(
                f"Model {self.model_name} is an Anthropic model and reasoning enabled is set to {self.reasoning_enabled}. Thinking tokens is mandatory for Anthropic models."
            )
            # Integer division: the budget is a token count, so it must stay
            # an int (true division would produce e.g. 1000.0).
            self.thinking_tokens = self.max_tokens // 4

        logger.info(
            "top_p must be greater than 0.95 for Anthropic models with reasoning enabled"
        )
        self.top_p = 0.95
def _process_additional_args(
    self, completion_params: dict, runtime_args: tuple
):
    """
    Fold init-time and run-time positional arguments into ``completion_params``.

    For each of the two argument tuples (init first, then runtime): when the
    tuple holds exactly one dictionary, that dictionary is merged into
    ``completion_params``; any other non-empty tuple is stored whole under the
    key ``"init_args"`` or ``"runtime_args"``. Runtime values are applied last
    and therefore win over init values on key collisions.

    Args:
        completion_params (dict): Parameter dictionary, updated in place.
        runtime_args (tuple): Positional arguments passed at call time.
    """
    sources = (
        ("init_args", self.init_args),
        ("runtime_args", runtime_args),
    )
    for storage_key, positional in sources:
        if not positional:
            continue

        is_single_dict = len(positional) == 1 and isinstance(
            positional[0], dict
        )
        if is_single_dict:
            completion_params.update(positional[0])
        else:
            # Anything that is not a lone dict is kept as-is under its own key.
            completion_params[storage_key] = positional
def output_for_tools(self, response: any):
    """
    Extract tool call information from an LLM response.

    Reads ``response.choices[0].message.tool_calls``.

    * With ``self.mcp_call`` set to True, each tool call is reduced to
      ``{"function": {"name": ..., "arguments": ...}}``: a single call yields
      one such dictionary, several calls yield a list of them, and a response
      without tool calls yields ``None``.
    * Otherwise the ``tool_calls`` value is returned as-is (dumped to a dict
      first when it is a pydantic ``BaseModel``).

    Args:
        response (any): The response object returned by the LLM API call.

    Returns:
        dict, list or None: The extracted tool call data as described above.
    """
    tool_calls = response.choices[0].message.tool_calls

    if self.mcp_call is True:
        # The model may answer in plain text instead of calling a tool, in
        # which case tool_calls is None or empty. Return None rather than
        # failing on len()/indexing.
        if not tool_calls:
            return None

        if len(tool_calls) > 1:
            return [
                {
                    "function": {
                        "name": tool_call.function.name,
                        "arguments": tool_call.function.arguments,
                    }
                }
                for tool_call in tool_calls
            ]

        function = tool_calls[0].function
        return {
            "function": {
                "name": function.name,
                "arguments": function.arguments,
            }
        }

    if isinstance(tool_calls, BaseModel):
        tool_calls = tool_calls.model_dump()
    return tool_calls
def output_for_reasoning(self, response: any):
    """
    Render a reasoning-model response as one formatted string.

    Sections are emitted in this order, each only when present and non-empty
    on ``response.choices[0].message``: reasoning content, thinking blocks,
    tool calls (formatted through ``output_for_tools``) and the main content.

    Args:
        response: The response object from the LLM API call.

    Returns:
        str: The sections joined with newlines.
    """
    message = response.choices[0].message
    sections = []

    reasoning = getattr(message, "reasoning_content", None)
    if reasoning:
        sections.append(f"Reasoning Content:\n{reasoning}\n")

    # Thinking blocks are dictionaries carrying "type" and "thinking" entries.
    thinking_blocks = getattr(message, "thinking_blocks", None)
    if thinking_blocks:
        sections.append("Thinking Blocks:")
        for i, block in enumerate(thinking_blocks, start=1):
            block_type = block.get("type", "")
            thinking = block.get("thinking", "")
            sections.extend(
                [
                    f"Block {i} (Type: {block_type}):",
                    f" Thinking: {thinking}",
                    "",
                ]
            )

    if getattr(message, "tool_calls", None):
        sections.append(f"Tools:\n{self.output_for_tools(response)}\n")

    content = message.content
    if content:
        sections.append(f"Content:\n{content}")

    return "\n".join(sections)
def _prepare_messages(
    self,
    task: Optional[str] = None,
    img: Optional[str] = None,
):
    """
    Build the message list for a single request.

    Works on a shallow copy of ``self.messages`` so the stored list is not
    extended. With an image, vision support is checked first and
    ``vision_processing`` adds a combined text + image message. Without an
    image, the task (if any) is appended as a plain user message.

    Args:
        task (str, optional): The task text. Defaults to None.
        img (str, optional): Image input if any. Defaults to None.

    Returns:
        list: The messages to send for this request.
    """
    prepared = self.messages.copy()

    if img is None:
        if task is not None:
            prepared.append({"role": "user", "content": task})
        return prepared

    # Image path: vision_processing emits the task text together with the image.
    self.check_if_model_supports_vision(img=img)
    return self.vision_processing(
        task=task, image=img, messages=prepared
    )
def anthropic_vision_processing(
    self, task: str, image: str, messages: list
) -> list:
    """
    Append a text + image user message in the form used for Anthropic models.

    When ``_should_use_direct_url`` approves, the image URL is passed through
    unchanged. Otherwise the image is converted to a base64 data URI and a
    ``format`` entry is added, restricted to JPEG, PNG, GIF or WebP (anything
    else is declared as JPEG).

    Args:
        task (str): Text part of the message.
        image (str): Image URL or file path.
        messages (list): Message list; it is extended in place.

    Returns:
        list: The same ``messages`` list with the new user message appended.
    """
    if self._should_use_direct_url(image):
        image_payload = {"url": image}
    else:
        data_uri = get_image_base64(image)

        # Recover the MIME type from the "data:<mime>;base64," header.
        mime_type = "image/jpeg"
        if "data:" in data_uri and ";base64," in data_uri:
            mime_type = data_uri.split(";base64,")[0].split("data:")[1]

        accepted_formats = (
            "image/jpeg",
            "image/png",
            "image/gif",
            "image/webp",
        )
        if mime_type not in accepted_formats:
            mime_type = "image/jpeg"

        image_payload = {"url": data_uri, "format": mime_type}

    messages.append(
        {
            "role": "user",
            "content": [
                {"type": "text", "text": task},
                {"type": "image_url", "image_url": image_payload},
            ],
        }
    )
    return messages
def openai_vision_processing(
    self, task: str, image: str, messages: list
) -> list:
    """
    Append a text + image user message in the form used for OpenAI-style models.

    When ``_should_use_direct_url`` approves, the image URL is passed through
    unchanged. Otherwise the image is converted to a base64 data URI and a
    ``format`` entry derived from the file extension is added (``image/jpeg``
    for unknown extensions).

    Args:
        task (str): Text part of the message.
        image (str): Image URL or file path.
        messages (list): Message list; it is extended in place.

    Returns:
        list: The same ``messages`` list with the new user message appended.
    """
    if self._should_use_direct_url(image):
        image_part = {
            "type": "image_url",
            "image_url": {"url": image},
        }
    else:
        mime_by_suffix = {
            ".jpg": "image/jpeg",
            ".jpeg": "image/jpeg",
            ".png": "image/png",
            ".gif": "image/gif",
            ".webp": "image/webp",
            ".bmp": "image/bmp",
            ".tiff": "image/tiff",
            ".svg": "image/svg+xml",
        }
        data_uri = get_image_base64(image)
        suffix = Path(image).suffix.lower()
        image_part = {
            "type": "image_url",
            "image_url": {
                "url": data_uri,
                "format": mime_by_suffix.get(suffix, "image/jpeg"),
            },
        }

    user_message = {
        "role": "user",
        "content": [{"type": "text", "text": task}, image_part],
    }
    messages.append(user_message)
    return messages
def _should_use_direct_url(self, image: str) -> bool:
    """
    Decide whether an image can be handed to the model as a plain URL.

    Returns False for anything that is not an HTTP/HTTPS URL, and for models
    whose name or ``base_url`` contains one of the local/custom markers.
    Otherwise the answer is whatever ``supports_vision`` reports for the
    model; if that lookup raises, False is returned.

    Args:
        image (str): The image source (URL or file path).

    Returns:
        bool: True to pass the URL directly, False to convert to base64.
    """
    if not image.startswith(("http://", "https://")):
        return False

    local_markers = (
        "localhost",
        "127.0.0.1",
        "local",
        "custom",
        "ollama",
        "llama-cpp",
    )
    # Both the model name and (when set) the base URL are searched.
    searched = [self.model_name.lower()]
    if self.base_url is not None:
        searched.append(self.base_url.lower())

    if any(
        marker in text for text in searched for marker in local_markers
    ):
        return False

    try:
        return supports_vision(model=self.model_name)
    except Exception:
        # Best effort: an unknown model is treated as not URL-capable.
        return False
def vision_processing(
    self, task: str, image: str, messages: Optional[list] = None
):
    """
    Add a text + image message for ``task`` using the model-appropriate format.

    Logs which transport will be used (direct URL or base64 conversion) and
    then delegates to ``anthropic_vision_processing`` when the model name
    contains "anthropic" or "claude", and to ``openai_vision_processing``
    otherwise.

    Args:
        task (str): Text part of the message.
        image (str): Image URL or file path.
        messages (list, optional): Message list to extend. A new list is
            created when omitted.

    Returns:
        list: The message list returned by the chosen handler.
    """
    if messages is None:
        messages = []

    logger.info(f"Processing image for model: {self.model_name}")

    # Purely informational: the handlers make the same decision themselves.
    if self._should_use_direct_url(image):
        logger.info(
            f"Using direct URL passing for image: {image[:100]}..."
        )
    elif image.startswith(("http://", "https://")):
        logger.info(
            "Converting URL image to base64 (model doesn't support direct URLs)"
        )
    else:
        logger.info("Converting local file to base64")

    name_lower = self.model_name.lower()
    if "anthropic" in name_lower or "claude" in name_lower:
        handler = self.anthropic_vision_processing
    else:
        handler = self.openai_vision_processing
    return handler(task, image, messages)
def audio_processing(self, task: str, audio: str):
    """
    Append a text + audio user message to ``self.messages``.

    The audio is loaded and base64-encoded through ``get_audio_base64`` and
    attached as an ``input_audio`` part. The format is always declared as
    "wav", regardless of the source file.

    Args:
        task (str): Text part of the message.
        audio (str): Path or URL of the audio file.
    """
    text_part = {"type": "text", "text": task}
    audio_part = {
        "type": "input_audio",
        "input_audio": {
            "data": get_audio_base64(audio),
            "format": "wav",
        },
    }
    # Goes into the instance's stored message list, not a per-request copy.
    self.messages.append(
        {"role": "user", "content": [text_part, audio_part]}
    )
def check_if_model_supports_vision(self, img: str = None):
    """
    Reject image input for models without vision support.

    Does nothing when ``img`` is None. Otherwise asks ``supports_vision`` about
    ``self.model_name`` and raises when the answer is exactly ``False``.

    Args:
        img (str, optional): Image path/URL that is about to be sent.

    Raises:
        ValueError: If the model doesn't support vision and an image is provided.
    """
    if img is None:
        return

    if supports_vision(model=self.model_name) is False:
        raise ValueError(
            f"Model {self.model_name} does not support vision"
        )
@staticmethod
def check_if_model_name_uses_anthropic(model_name: str):
"""
Check if the model name uses Anthropic.
"""
if "anthropic" in model_name.lower():
return True
else:
return False
@staticmethod
def check_if_model_name_uses_openai(model_name: str):
"""
Check if the model name uses OpenAI.
"""
if "openai" in model_name.lower():
return True
else:
return False
@staticmethod
def check_internet_connection(
host: str = "8.8.8.8", port: int = 53, timeout: int = 3
) -> bool:
"""
Check if there is an active internet connection.
This method attempts to establish a socket connection to a DNS server
(default is Google's DNS at 8.8.8.8) to verify internet connectivity.
Args:
host (str, optional): The host to connect to for checking connectivity.
Defaults to "8.8.8.8" (Google DNS).
port (int, optional): The port to use for the connection. Defaults to 53 (DNS).
timeout (int, optional): Connection timeout in seconds. Defaults to 3.
Returns:
bool: True if internet connection is available, False otherwise.
"""
try:
socket.setdefaulttimeout(timeout)
socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect(
(host, port)
)
return True
except (socket.error, socket.timeout):
return False
@staticmethod
def is_local_model(
model_name: str, base_url: Optional[str] = None
) -> bool:
"""
Determine if the model is a local model (e.g., Ollama, LlamaCPP).
Args:
model_name (str): The name of the model to check.
base_url (str, optional): The base URL if specified. Defaults to None.
Returns:
bool: True if the model is a local model, False otherwise.
"""
local_indicators = [
"ollama",
"llama-cpp",
"local",
"localhost",
"127.0.0.1",
"custom",
]
model_lower = model_name.lower()
is_local_model = any(
indicator in model_lower for indicator in local_indicators
)
is_local_url = base_url is not None and any(
indicator in base_url.lower()
for indicator in local_indicators
)
return is_local_model or is_local_url
def run(
    self,
    task: str,
    audio: Optional[str] = None,
    img: Optional[str] = None,
    *args,
    **kwargs,
):
    """
    Run the LLM model for the given task.

    Args:
        task (str): The task to run the model for.
        audio (str, optional): Accepted for interface compatibility; this
            method does not currently use it. Defaults to None.
        img (str, optional): Image input if any. Defaults to None.
        *args: Additional positional arguments. If a single dictionary is
            passed, it is merged into the completion parameters last.
        **kwargs: Additional keyword arguments merged into the completion
            parameters, overriding the defaults and the init kwargs.

    Returns:
        Depending on configuration: the streaming response object when
        ``stream`` is set; a formatted string for reasoning output; the tool
        call data when tools are configured; the full response dump when
        ``return_all`` is set; otherwise the message content (for Gemini
        models, a saved image path when the content is a base64 image).
        ``None`` when the completion call returns an empty response.

    Raises:
        NetworkConnectionError: If a network-level error occurs; the message
            contains troubleshooting hints.
        Exception: Any other error is logged and re-raised unchanged.

    Note:
        Parameters are applied in this order, later entries overriding
        earlier ones:
        1. Defaults built from the instance (model, messages, stream,
           max_tokens, caching, top_p, temperature)
        2. Init kwargs (passed to __init__)
        3. Runtime kwargs (passed to this method)
        4. Explicit instance settings (api_version, tools, functions,
           base_url, response_format, modalities, reasoning_effort, thinking)
        5. Init args, then runtime args (when a single dictionary)
    """
    try:
        # Prepare messages properly - this handles both task and image together
        messages = self._prepare_messages(task=task, img=img)

        # Base completion parameters
        completion_params = {
            "model": self.model_name,
            "messages": messages,
            "stream": self.stream,
            "max_tokens": self.max_tokens,
            "caching": self.caching,
            "top_p": self.top_p,
        }

        # The configured temperature is a default only: it is skipped for the
        # o4/o3 models and set before the kwargs merges, so that a
        # temperature passed via kwargs is not overwritten afterwards.
        if self.model_name not in (
            "openai/o4-mini",
            "openai/o3-2025-04-16",
        ):
            completion_params["temperature"] = self.temperature

        # Merge initialization kwargs first (lower priority)
        if self.init_kwargs:
            completion_params.update(self.init_kwargs)

        # Merge runtime kwargs (higher priority - overrides init kwargs)
        if kwargs:
            completion_params.update(kwargs)

        if self.api_version is not None:
            completion_params["api_version"] = self.api_version

        # Add tools if specified
        if self.tools_list_dictionary is not None:
            completion_params.update(
                {
                    "tools": self.tools_list_dictionary,
                    "tool_choice": self.tool_choice,
                    "parallel_tool_calls": self.parallel_tool_calls,
                }
            )

        if self.functions is not None:
            completion_params["functions"] = self.functions

        if self.base_url is not None:
            completion_params["base_url"] = self.base_url

        if self.response_format is not None:
            completion_params["response_format"] = self.response_format

        # Add modalities if needed
        if self.modalities and len(self.modalities) >= 2:
            completion_params["modalities"] = self.modalities

        if (
            self.reasoning_effort is not None
            and litellm.supports_reasoning(model=self.model_name) is True
        ):
            completion_params["reasoning_effort"] = self.reasoning_effort

        if (
            self.reasoning_enabled is True
            and self.thinking_tokens is not None
        ):
            completion_params["thinking"] = {
                "type": "enabled",
                "budget_tokens": self.thinking_tokens,
            }

        # Process additional args if any
        self._process_additional_args(completion_params, args)

        # Make the completion call
        response = completion(**completion_params)

        # Validate response
        if not response:
            logger.error(
                "Received empty response from completion call"
            )
            return None

        # Handle streaming response
        if self.stream:
            return response  # Return the streaming generator directly

        # Handle reasoning model output
        if self.reasoning_enabled and self.reasoning_effort is not None:
            return self.output_for_reasoning(response)

        # Handle tool-based response
        if self.tools_list_dictionary is not None:
            return self.output_for_tools(response)

        if self.return_all is True:
            return response.model_dump()

        if "gemini" in self.model_name.lower():
            return gemini_output_img_handler(response)

        return response.choices[0].message.content

    except (
        # RequestException is the base class of the requests connection and
        # timeout errors, so it covers those as well.
        requests.exceptions.RequestException,
        ConnectionError,
        TimeoutError,
    ) as network_error:
        # Check if this is a local model
        if self.is_local_model(self.model_name, self.base_url):
            error_msg = (
                f"Network error connecting to local model '{self.model_name}': {str(network_error)}\n\n"
                "Troubleshooting steps:\n"
                "1. Ensure your local model server (e.g., Ollama, LlamaCPP) is running\n"
                "2. Verify the base_url is correct and accessible\n"
                "3. Check that the model is properly loaded and available\n"
            )
            logger.error(error_msg)
            raise NetworkConnectionError(error_msg) from network_error

        # Check internet connectivity
        if not self.check_internet_connection():
            error_msg = (
                f"No internet connection detected while trying to use model '{self.model_name}'.\n\n"
                "Possible solutions:\n"
                "1. Check your internet connection and try again\n"
                "2. Reconnect to your network\n"
                "3. Use a local model instead (e.g., Ollama):\n"
                " - Install Ollama from https://ollama.ai\n"
                " - Run: ollama pull llama2\n"
                " - Use model_name='ollama/llama2' in your LiteLLM configuration\n"
                "\nExample:\n"
                " model = LiteLLM(model_name='ollama/llama2')\n"
            )
            logger.error(error_msg)
            raise NetworkConnectionError(error_msg) from network_error

        # Internet is available but request failed
        error_msg = (
            f"Network error occurred while connecting to '{self.model_name}': {str(network_error)}\n\n"
            "Possible causes:\n"
            "1. The API endpoint may be temporarily unavailable\n"
            "2. Connection timeout or slow network\n"
            "3. Firewall or proxy blocking the connection\n"
            "\nConsider using a local model as a fallback:\n"
            " model = LiteLLM(model_name='ollama/llama2')\n"
        )
        logger.error(error_msg)
        raise NetworkConnectionError(error_msg) from network_error

    except LiteLLMException as error:
        logger.error(
            f"Error in LiteLLM run: {str(error)} Traceback: {traceback.format_exc()}"
        )
        raise

    except Exception as error:
        logger.error(
            f"Unexpected error in LiteLLM run: {str(error)} Traceback: {traceback.format_exc()}"
        )
        raise
def __call__(self, task: str, *args, **kwargs):
    """
    Make the instance callable by forwarding to ``run``.

    Args:
        task (str): The task to run the model for.
        *args: Positional arguments forwarded to ``run`` unchanged.
        **kwargs: Keyword arguments forwarded to ``run`` unchanged.

    Returns:
        Whatever ``run`` returns for the same arguments.
    """
    result = self.run(task, *args, **kwargs)
    return result
def batched_run(self, tasks: List[str], batch_size: int = 10):
    """
    Process several tasks in batches and block until they are done.

    Logs the workload and then drives ``self._process_batch`` to completion
    with ``asyncio.run``.

    Args:
        tasks (List[str]): List of tasks to process.
        batch_size (int): Size of each batch.

    Returns:
        The result produced by ``_process_batch``.
    """
    logger.info(
        f"Running {len(tasks)} tasks in batches of {batch_size}"
    )
    batch_coroutine = self._process_batch(tasks, batch_size)
    return asyncio.run(batch_coroutine)