# swarms/swarms/utils/litellm_wrapper.py

import base64
import asyncio
from typing import Any, List

import requests
from loguru import logger

try:
    import litellm
    from litellm import completion, acompletion
except ImportError:
    # litellm is a runtime dependency; install it on the fly if missing
    import subprocess
    import sys

    print("Installing litellm")
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "-U", "litellm"]
    )
    print("litellm installed")

    import litellm
    from litellm import completion, acompletion

# Module-level defaults; LiteLLM.__init__ overrides these per instance
litellm.set_verbose = True
litellm.ssl_verify = False


class LiteLLMException(Exception):
    """
    Exception raised for LiteLLM-specific errors.
    """


def get_audio_base64(audio_source: str) -> str:
    """
    Convert audio from a given source to a base64 encoded string.

    This function handles both URLs and local file paths. If the audio source
    is a URL, it fetches the audio data from the internet. If it is a local
    file path, it reads the audio data from the specified file.

    Args:
        audio_source (str): The source of the audio, which can be a URL or a local file path.

    Returns:
        str: A base64 encoded string representation of the audio data.

    Raises:
        requests.HTTPError: If the HTTP request to fetch audio data fails.
        FileNotFoundError: If the local audio file does not exist.
    """
    # Handle URL
    if audio_source.startswith(("http://", "https://")):
        response = requests.get(audio_source)
        response.raise_for_status()
        audio_data = response.content
    # Handle local file
    else:
        with open(audio_source, "rb") as file:
            audio_data = file.read()

    encoded_string = base64.b64encode(audio_data).decode("utf-8")
    return encoded_string
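
# A minimal usage sketch for get_audio_base64; the file path below is
# hypothetical and assumes a local WAV file exists at that location:
#
#     encoded = get_audio_base64("samples/question.wav")
#     assert isinstance(encoded, str)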


class LiteLLM:
    """
    A wrapper around the litellm library.

    It is used to interact with LLM models for various tasks, including
    plain text completion, tool calling, and audio/vision inputs.
    """
    def __init__(
        self,
        model_name: str = "gpt-4o",
        system_prompt: str = None,
        stream: bool = False,
        temperature: float = 0.5,
        max_tokens: int = 4000,
        ssl_verify: bool = False,
        max_completion_tokens: int = 4000,
        tools_list_dictionary: List[dict] = None,
        tool_choice: str = "auto",
        parallel_tool_calls: bool = False,
        audio: str = None,
        retries: int = 3,
        verbose: bool = False,
        caching: bool = False,
        *args,
        **kwargs,
    ):
        """
        Initialize the LiteLLM wrapper with the given parameters.

        Args:
            model_name (str, optional): The name of the model to use. Defaults to "gpt-4o".
            system_prompt (str, optional): The system prompt to use. Defaults to None.
            stream (bool, optional): Whether to stream the output. Defaults to False.
            temperature (float, optional): The temperature for the model. Defaults to 0.5.
            max_tokens (int, optional): The maximum number of tokens to generate. Defaults to 4000.
            ssl_verify (bool, optional): Whether to verify SSL certificates. Defaults to False.
            max_completion_tokens (int, optional): Cap on completion tokens. Defaults to 4000.
            tools_list_dictionary (List[dict], optional): Tool definitions for function calling. Defaults to None.
            tool_choice (str, optional): Tool selection strategy. Defaults to "auto".
            parallel_tool_calls (bool, optional): Whether to allow parallel tool calls. Defaults to False.
            audio (str, optional): Audio input source. Defaults to None.
            retries (int, optional): Number of retries for failed requests. Defaults to 3.
            verbose (bool, optional): Whether to enable verbose litellm logging. Defaults to False.
            caching (bool, optional): Whether to enable litellm caching. Defaults to False.
        """
        self.model_name = model_name
        self.system_prompt = system_prompt
        self.stream = stream
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.ssl_verify = ssl_verify
        self.max_completion_tokens = max_completion_tokens
        self.tools_list_dictionary = tools_list_dictionary
        self.tool_choice = tool_choice
        self.parallel_tool_calls = parallel_tool_calls
        self.caching = caching
        self.modalities = []
        self._cached_messages = {}  # Cache for prepared messages
        self.messages = []  # Initialize messages list

        # Configure global litellm settings
        litellm.set_verbose = verbose  # Verbose logging only on demand
        litellm.ssl_verify = ssl_verify
        litellm.num_retries = retries  # Retries for better reliability

    def output_for_tools(self, response: Any):
        """Extract the first tool call from a completion response."""
        return response["choices"][0]["message"]["tool_calls"][0]

    def _prepare_messages(self, task: str) -> list:
        """
        Prepare the messages for the given task.

        Args:
            task (str): The task to prepare messages for.

        Returns:
            list: A list of messages prepared for the task.
        """
        # Check cache first
        cache_key = f"{self.system_prompt}:{task}"
        if cache_key in self._cached_messages:
            return self._cached_messages[cache_key].copy()

        messages = []
        if self.system_prompt:
            messages.append(
                {"role": "system", "content": self.system_prompt}
            )
        messages.append({"role": "user", "content": task})

        # Cache the prepared messages
        self._cached_messages[cache_key] = messages.copy()
        return messages
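
    # For reference, _prepare_messages("Hello") with a system prompt set
    # produces an OpenAI-style message list (a sketch, not an exhaustive spec):
    #
    #     [
    #         {"role": "system", "content": "<system prompt>"},
    #         {"role": "user", "content": "Hello"},
    #     ]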

    def audio_processing(self, task: str, audio: str):
        """
        Process the audio for the given task.

        Args:
            task (str): The task to be processed.
            audio (str): The path or identifier for the audio file.
        """
        self.modalities.append("audio")
        encoded_string = get_audio_base64(audio)

        # Append a multimodal user message with the base64-encoded audio
        self.messages.append(
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": task},
                    {
                        "type": "input_audio",
                        "input_audio": {
                            "data": encoded_string,
                            "format": "wav",
                        },
                    },
                ],
            }
        )

    def vision_processing(self, task: str, image: str):
        """
        Process the image for the given task.

        Args:
            task (str): The task to be processed.
            image (str): URL or path of the image to attach.
        """
        self.modalities.append("vision")

        # Append a multimodal user message with the image URL
        self.messages.append(
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": task},
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": image,
                            # "detail": "high",
                            # "format": "image",
                        },
                    },
                ],
            }
        )

    def handle_modalities(
        self, task: str, audio: str = None, img: str = None
    ):
        """
        Build modality-aware messages for the given task.

        Resets the message and modality lists for this call;
        audio_processing and vision_processing each register their own
        modality, so nothing is appended twice here.
        """
        self.messages = []  # Reset messages
        self.modalities = ["text"]  # Reset modalities

        if audio is not None:
            self.audio_processing(task=task, audio=audio)
        if img is not None:
            self.vision_processing(task=task, image=img)

    def run(
        self,
        task: str,
        audio: str = None,
        img: str = None,
        *args,
        **kwargs,
    ):
        """
        Run the LLM model for the given task.

        Args:
            task (str): The task to run the model for.
            audio (str, optional): Audio input if any. Defaults to None.
            img (str, optional): Image input if any. Defaults to None.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            str: The content of the response from the model.

        Raises:
            Exception: If there is an error in processing the request.
        """
        try:
            messages = self._prepare_messages(task)

            if audio is not None or img is not None:
                self.handle_modalities(
                    task=task, audio=audio, img=img
                )
                messages = self.messages  # Use modality-processed messages

            # Reasoning models (o-series) reject the temperature parameter
            # and expect max_completion_tokens instead of max_tokens
            if self.model_name in (
                "openai/o4-mini",
                "openai/o3-2025-04-16",
            ):
                completion_params = {
                    "model": self.model_name,
                    "messages": messages,
                    "stream": self.stream,
                    "max_completion_tokens": self.max_tokens,
                    "caching": self.caching,
                    **kwargs,
                }
            else:
                # Prepare common completion parameters
                completion_params = {
                    "model": self.model_name,
                    "messages": messages,
                    "stream": self.stream,
                    "temperature": self.temperature,
                    "max_tokens": self.max_tokens,
                    "caching": self.caching,
                    **kwargs,
                }

            # Handle tool-based completion
            if self.tools_list_dictionary is not None:
                completion_params.update(
                    {
                        "tools": self.tools_list_dictionary,
                        "tool_choice": self.tool_choice,
                        "parallel_tool_calls": self.parallel_tool_calls,
                    }
                )
                response = completion(**completion_params)
                return (
                    response.choices[0]
                    .message.tool_calls[0]
                    .function.arguments
                )

            # Handle modality-based completion (more than just text)
            if self.modalities and len(self.modalities) > 1:
                completion_params.update(
                    {"modalities": self.modalities}
                )
                response = completion(**completion_params)
                return response.choices[0].message.content

            # Standard completion; when streaming, return the stream object
            response = completion(**completion_params)
            if self.stream:
                return response
            return response.choices[0].message.content
        except LiteLLMException as error:
            logger.error(f"Error in LiteLLM run: {str(error)}")
            if "rate_limit" in str(error).lower():
                logger.warning(
                    "Rate limit hit, retrying after a short delay..."
                )
                import time

                time.sleep(2)  # Fixed delay before a single retry pass
                return self.run(task, audio, img, *args, **kwargs)
            raise error

    def __call__(self, task: str, *args, **kwargs):
        """
        Call the LLM model for the given task.

        Args:
            task (str): The task to run the model for.
            *args: Additional positional arguments to pass to the model.
            **kwargs: Additional keyword arguments to pass to the model.

        Returns:
            str: The content of the response from the model.
        """
        return self.run(task, *args, **kwargs)
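
    # A minimal usage sketch (model name and prompt are illustrative; any
    # litellm-supported model string should work, given valid credentials):
    #
    #     llm = LiteLLM(model_name="gpt-4o-mini", system_prompt="Be terse.")
    #     answer = llm("Explain retries in one sentence.")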

    async def arun(self, task: str, *args, **kwargs):
        """
        Run the LLM model asynchronously for the given task.

        Args:
            task (str): The task to run the model for.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            str: The content of the response from the model.
        """
        try:
            messages = self._prepare_messages(task)

            # Prepare common completion parameters
            completion_params = {
                "model": self.model_name,
                "messages": messages,
                "stream": self.stream,
                "temperature": self.temperature,
                "max_tokens": self.max_tokens,
                **kwargs,
            }

            # Handle tool-based completion
            if self.tools_list_dictionary is not None:
                completion_params.update(
                    {
                        "tools": self.tools_list_dictionary,
                        "tool_choice": self.tool_choice,
                        "parallel_tool_calls": self.parallel_tool_calls,
                    }
                )
                response = await acompletion(**completion_params)
                return (
                    response.choices[0]
                    .message.tool_calls[0]
                    .function.arguments
                )

            # Standard completion
            response = await acompletion(**completion_params)
            return response.choices[0].message.content
        except Exception as error:
            logger.error(f"Error in LiteLLM arun: {str(error)}")
            raise error

    async def _process_batch(
        self, tasks: List[str], batch_size: int = 10
    ):
        """
        Process a batch of tasks asynchronously.

        Args:
            tasks (List[str]): List of tasks to process.
            batch_size (int): Size of each batch.

        Returns:
            List[str]: List of responses.
        """
        results = []
        for i in range(0, len(tasks), batch_size):
            batch = tasks[i : i + batch_size]
            batch_results = await asyncio.gather(
                *[self.arun(task) for task in batch],
                return_exceptions=True,
            )

            # Handle any exceptions in the batch
            for result in batch_results:
                if isinstance(result, Exception):
                    logger.error(
                        f"Error in batch processing: {str(result)}"
                    )
                    results.append(str(result))
                else:
                    results.append(result)

            # Add a small delay between batches to avoid rate limits
            if i + batch_size < len(tasks):
                await asyncio.sleep(0.5)
        return results

    def batched_run(self, tasks: List[str], batch_size: int = 10):
        """
        Run multiple tasks in batches synchronously.

        Args:
            tasks (List[str]): List of tasks to process.
            batch_size (int): Size of each batch.

        Returns:
            List[str]: List of responses.
        """
        logger.info(
            f"Running {len(tasks)} tasks in batches of {batch_size}"
        )
        return asyncio.run(self._process_batch(tasks, batch_size))

    async def batched_arun(
        self, tasks: List[str], batch_size: int = 10
    ):
        """
        Run multiple tasks in batches asynchronously.

        Args:
            tasks (List[str]): List of tasks to process.
            batch_size (int): Size of each batch.

        Returns:
            List[str]: List of responses.
        """
        logger.info(
            f"Running {len(tasks)} tasks asynchronously in batches of {batch_size}"
        )
        return await self._process_batch(tasks, batch_size)
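

# A minimal, hedged demo of the wrapper. The model name, prompts, and tool
# schema below are illustrative assumptions, not values required by this
# module; running it requires valid provider credentials in the environment.
if __name__ == "__main__":
    llm = LiteLLM(
        model_name="gpt-4o-mini",
        system_prompt="You answer in one sentence.",
    )

    # Single synchronous call
    print(llm.run("What does litellm's num_retries setting do?"))

    # Batched calls (drives the async batcher under the hood)
    print(
        llm.batched_run(
            [
                "Name one LLM provider.",
                "Name one vector database.",
            ],
            batch_size=2,
        )
    )

    # Tool calling with an OpenAI-style function schema (hypothetical tool)
    weather_tool = {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get the current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {"type": "string"},
                },
                "required": ["city"],
            },
        },
    }
    tool_llm = LiteLLM(
        model_name="gpt-4o-mini",
        tools_list_dictionary=[weather_tool],
    )
    # run() returns the JSON-encoded arguments of the first tool call
    print(tool_llm.run("What's the weather in Paris?"))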