From 754c2b710d2afb78583286f4217262e97b85b8db Mon Sep 17 00:00:00 2001
From: Wxysnx <625024108@qq.com>
Date: Sat, 21 Jun 2025 16:30:31 +0800
Subject: [PATCH] Feed multiple images into the agent

---
 swarms/utils/litellm_wrapper_new.py | 640 ++++++++++++++++++++++++++++
 1 file changed, 640 insertions(+)
 create mode 100644 swarms/utils/litellm_wrapper_new.py

diff --git a/swarms/utils/litellm_wrapper_new.py b/swarms/utils/litellm_wrapper_new.py
new file mode 100644
index 00000000..d5ed3f60
--- /dev/null
+++ b/swarms/utils/litellm_wrapper_new.py
@@ -0,0 +1,640 @@
+import asyncio
+import base64
+import traceback
+from pathlib import Path
+from typing import Any, List, Optional, Union
+
+import litellm
+import requests
+from litellm import completion, acompletion, supports_vision
+from loguru import logger
+from pydantic import BaseModel
+
+
+class LiteLLMException(Exception):
+    """
+    Exception raised for errors in the LiteLLM wrapper.
+    """
+
+
+def get_audio_base64(audio_source: str) -> str:
+    """
+    Convert audio from a given source to a base64 encoded string.
+
+    This function handles both URLs and local file paths. If the audio
+    source is a URL, it fetches the audio data from the internet. If it
+    is a local file path, it reads the audio data from the specified file.
+
+    Args:
+        audio_source (str): The source of the audio, which can be a URL
+            or a local file path.
+
+    Returns:
+        str: A base64 encoded string representation of the audio data.
+
+    Raises:
+        requests.HTTPError: If the HTTP request to fetch audio data fails.
+        FileNotFoundError: If the local audio file does not exist.
+    """
+    # Handle URL
+    if audio_source.startswith(("http://", "https://")):
+        response = requests.get(audio_source)
+        response.raise_for_status()
+        audio_data = response.content
+    # Handle local file
+    else:
+        with open(audio_source, "rb") as file:
+            audio_data = file.read()
+
+    encoded_string = base64.b64encode(audio_data).decode("utf-8")
+    return encoded_string
+
+
+def get_image_base64(image_source: str) -> str:
+    """
+    Convert an image from a given source to a base64 encoded data URI.
+    Handles URLs, local file paths, and existing data URIs.
+    """
+    # If already a data URI, return as is
+    if image_source.startswith("data:image"):
+        return image_source
+
+    # Handle URL
+    if image_source.startswith(("http://", "https://")):
+        response = requests.get(image_source)
+        response.raise_for_status()
+        image_data = response.content
+    # Handle local file
+    else:
+        with open(image_source, "rb") as file:
+            image_data = file.read()
+
+    # Get file extension for mime type
+    extension = Path(image_source).suffix.lower()
+    mime_type = (
+        f"image/{extension[1:]}" if extension else "image/jpeg"
+    )
+
+    encoded_string = base64.b64encode(image_data).decode("utf-8")
+    return f"data:{mime_type};base64,{encoded_string}"
+
+
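+# Illustrative usage of the helpers above (file names and the truncated
+# base64 output are hypothetical):
+#
+#   get_image_base64("photos/cat.png")
+#   -> "data:image/png;base64,iVBORw0KGgo..."
+#
+#   get_audio_base64("https://example.com/clip.wav")
+#   -> "UklGRiQAAABXQVZF..."  (raw base64, no data-URI prefix)
+
+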
+ """ + + def __init__( + self, + model_name: str = "gpt-4o", + system_prompt: str = None, + stream: bool = False, + temperature: float = 0.5, + max_tokens: int = 4000, + ssl_verify: bool = False, + max_completion_tokens: int = 4000, + tools_list_dictionary: List[dict] = None, + tool_choice: str = "auto", + parallel_tool_calls: bool = False, + audio: str = None, + retries: int = 0, + verbose: bool = False, + caching: bool = False, + mcp_call: bool = False, + top_p: float = 1.0, + functions: List[dict] = None, + return_all: bool = False, + base_url: str = None, + api_key: str = None, + *args, + **kwargs, + ): + """ + Initialize the LiteLLM with the given parameters. + + Args: + model_name (str, optional): The name of the model to use. Defaults to "gpt-4o". + system_prompt (str, optional): The system prompt to use. Defaults to None. + stream (bool, optional): Whether to stream the output. Defaults to False. + temperature (float, optional): The temperature for the model. Defaults to 0.5. + max_tokens (int, optional): The maximum number of tokens to generate. Defaults to 4000. + """ + self.model_name = model_name + self.system_prompt = system_prompt + self.stream = stream + self.temperature = temperature + self.max_tokens = max_tokens + self.ssl_verify = ssl_verify + self.max_completion_tokens = max_completion_tokens + self.tools_list_dictionary = tools_list_dictionary + self.tool_choice = tool_choice + self.parallel_tool_calls = parallel_tool_calls + self.caching = caching + self.mcp_call = mcp_call + self.top_p = top_p + self.functions = functions + self.audio = audio + self.return_all = return_all + self.base_url = base_url + self.api_key = api_key + self.modalities = [] + self.messages = [] # Initialize messages list + + # Configure litellm settings + litellm.set_verbose = ( + verbose # Disable verbose mode for better performance + ) + litellm.ssl_verify = ssl_verify + litellm.num_retries = ( + retries # Add retries for better reliability + ) + + def output_for_tools(self, response: any): + if self.mcp_call is True: + out = response.choices[0].message.tool_calls[0].function + output = { + "function": { + "name": out.name, + "arguments": out.arguments, + } + } + return output + else: + out = response.choices[0].message.tool_calls + + if isinstance(out, BaseModel): + out = out.model_dump() + return out + + # Modification: Updated _prepare_messages method to accept img parameter as string or list of strings + def _prepare_messages( + self, + task: str, + img: Union[str, List[str]] = None, #Modification: Parameter type is string or string list + ): + """ + Prepare the messages for the given task. + + Args: + task (str): The task to prepare messages for. + img (Union[str, List[str]], optional): Single image input or list of image inputs. Defaults to None. + + Returns: + list: A list of messages prepared for the task. 
+ """ + # Edit: Convert single image string to list for unified processing + image_list = [] + if img is not None: + if isinstance(img, str): + image_list = [img] + else: + image_list = img + + # Edit: Check if there is an image to process + if image_list: + self.check_if_model_supports_vision(image_list=image_list) + + # Initialize messages + messages = [] + + # Add system prompt if present + if self.system_prompt is not None: + messages.append( + {"role": "system", "content": self.system_prompt} + ) + + # Handle vision case + if image_list: # 修改:处理图像列表 + messages = self.vision_processing( + task=task, images=image_list, messages=messages + ) + else: + messages.append({"role": "user", "content": task}) + + return messages + + # Modification: Updated anthropic_vision_processing method to handle multiple images + def anthropic_vision_processing( + self, task: str, images: List[str], messages: list + ) -> list: + """ + Process vision input specifically for Anthropic models. + Handles Anthropic's specific image format requirements. + + Args: + task (str): The task prompt + images (List[str]): List of image paths or URLs + messages (list): Current message list + + Returns: + list: Updated messages list with images + """ + + content = [{"type": "text", "text": task}] + + + for image in images: + + image_url = get_image_base64(image) + + + mime_type = "image/jpeg" + if "data:" in image_url and ";base64," in image_url: + mime_type = image_url.split(";base64,")[0].split("data:")[1] + + + supported_formats = [ + "image/jpeg", + "image/png", + "image/gif", + "image/webp", + ] + if mime_type not in supported_formats: + mime_type = "image/jpeg" + + + content.append({ + "type": "image_url", + "image_url": { + "url": image_url, + "format": mime_type, + }, + }) + + + messages.append({ + "role": "user", + "content": content, + }) + + return messages + + # Modification: Updated openai_vision_processing method to handle multiple images + def openai_vision_processing( + self, task: str, images: List[str], messages: list + ) -> list: + """ + Process vision input specifically for OpenAI models. + Handles OpenAI's specific image format requirements. + + Args: + task (str): The task prompt + images (List[str]): List of image paths or URLs + messages (list): Current message list + + Returns: + list: Updated messages list with images + """ + + content = [{"type": "text", "text": task}] + + + for image in images: + + image_url = get_image_base64(image) + + + vision_message = { + "type": "image_url", + "image_url": {"url": image_url}, + } + + + extension = Path(image).suffix.lower() + mime_type = f"image/{extension[1:]}" if extension else "image/jpeg" + vision_message["image_url"]["format"] = mime_type + + + content.append(vision_message) + + + messages.append({ + "role": "user", + "content": content, + }) + + return messages + + # Modification: Updated vision_processing method to handle multiple images + def vision_processing( + self, task: str, images: List[str], messages: Optional[list] = None + ): + """ + Process the images for the given task. + Handles different image formats and model requirements. + + Args: + task (str): The task prompt + images (List[str]): List of image paths or URLs + messages (Optional[list], optional): Current messages list. Defaults to None. 
+
+    # Changed: vision_processing now handles multiple images
+    def vision_processing(
+        self, task: str, images: List[str], messages: Optional[list] = None
+    ):
+        """
+        Process the images for the given task.
+        Handles different image formats and model requirements.
+
+        Args:
+            task (str): The task prompt
+            images (List[str]): List of image paths or URLs
+            messages (Optional[list], optional): Current messages list. Defaults to None.
+
+        Returns:
+            list: Updated messages with image content
+        """
+        if messages is None:
+            messages = []
+
+        if (
+            "anthropic" in self.model_name.lower()
+            or "claude" in self.model_name.lower()
+        ):
+            messages = self.anthropic_vision_processing(
+                task, images, messages
+            )
+        else:
+            messages = self.openai_vision_processing(
+                task, images, messages
+            )
+        return messages
+
+    def audio_processing(self, task: str, audio: str):
+        """
+        Process the audio for the given task.
+
+        Args:
+            task (str): The task to be processed.
+            audio (str): The path or identifier for the audio file.
+        """
+        encoded_string = get_audio_base64(audio)
+
+        # Append audio message
+        self.messages.append(
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": task},
+                    {
+                        "type": "input_audio",
+                        "input_audio": {
+                            "data": encoded_string,
+                            "format": "wav",
+                        },
+                    },
+                ],
+            }
+        )
+
+    # Changed: check_if_model_supports_vision now supports image lists
+    def check_if_model_supports_vision(
+        self, img: str = None, image_list: List[str] = None
+    ):
+        """
+        Check if the model supports vision.
+
+        Args:
+            img (str, optional): Single image path (for backward compatibility). Defaults to None.
+            image_list (List[str], optional): List of image paths. Defaults to None.
+
+        Raises:
+            ValueError: If the model does not support vision.
+        """
+        # If there are any images (single or multiple), check whether
+        # the model supports vision
+        if img is not None or (image_list and len(image_list) > 0):
+            out = supports_vision(model=self.model_name)
+
+            if out is False:
+                raise ValueError(
+                    f"Model {self.model_name} does not support vision"
+                )
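+
+    # Illustrative behaviour of the vision check (model names are
+    # examples; actual support is reported by litellm.supports_vision):
+    #   supports_vision(model="gpt-4o")        -> True   -> proceed
+    #   supports_vision(model="gpt-3.5-turbo") -> False  -> ValueError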
+ """ + try: + + messages = self._prepare_messages(task=task, img=img) + + # Base completion parameters + completion_params = { + "model": self.model_name, + "messages": messages, + "stream": self.stream, + "max_tokens": self.max_tokens, + "caching": self.caching, + "temperature": self.temperature, + "top_p": self.top_p, + **kwargs, + } + + # Add temperature for non-o4/o3 models + if self.model_name not in [ + "openai/o4-mini", + "openai/o3-2025-04-16", + ]: + completion_params["temperature"] = self.temperature + + # Add tools if specified + if self.tools_list_dictionary is not None: + completion_params.update( + { + "tools": self.tools_list_dictionary, + "tool_choice": self.tool_choice, + "parallel_tool_calls": self.parallel_tool_calls, + } + ) + + if self.functions is not None: + completion_params.update( + {"functions": self.functions} + ) + + if self.base_url is not None: + completion_params["base_url"] = self.base_url + + # Add modalities if needed + if self.modalities and len(self.modalities) >= 2: + completion_params["modalities"] = self.modalities + + # Make the completion call + response = completion(**completion_params) + + # Handle tool-based response + if self.tools_list_dictionary is not None: + return self.output_for_tools(response) + elif self.return_all is True: + return response.model_dump() + else: + # Return standard response content + return response.choices[0].message.content + + except LiteLLMException as error: + logger.error( + f"Error in LiteLLM run: {str(error)} Traceback: {traceback.format_exc()}" + ) + if "rate_limit" in str(error).lower(): + logger.warning( + "Rate limit hit, retrying with exponential backoff..." + ) + import time + + time.sleep(2) + return self.run(task, audio, img, *args, **kwargs) + raise error + + def __call__(self, task: str, *args, **kwargs): + """ + Call the LLM model for the given task. + + Args: + task (str): The task to run the model for. + *args: Additional positional arguments to pass to the model. + **kwargs: Additional keyword arguments to pass to the model. + + Returns: + str: The content of the response from the model. + """ + return self.run(task, *args, **kwargs) + + # Modification: Updated arun method to accept img parameter as string or list of strings + async def arun(self, task: str, img: Union[str, List[str]] = None, *args, **kwargs): + """ + Run the LLM model asynchronously for the given task. + + Args: + task (str): The task to run the model for. + img (Union[str, List[str]], optional): Single image input or list of image inputs. Defaults to None. + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + + Returns: + str: The content of the response from the model. 
+ """ + try: + + messages = self._prepare_messages(task=task, img=img) + + # Prepare common completion parameters + completion_params = { + "model": self.model_name, + "messages": messages, + "stream": self.stream, + "temperature": self.temperature, + "max_tokens": self.max_tokens, + **kwargs, + } + + # Handle tool-based completion + if self.tools_list_dictionary is not None: + completion_params.update( + { + "tools": self.tools_list_dictionary, + "tool_choice": self.tool_choice, + "parallel_tool_calls": self.parallel_tool_calls, + } + ) + response = await acompletion(**completion_params) + return ( + response.choices[0] + .message.tool_calls[0] + .function.arguments + ) + + # Standard completion + response = await acompletion(**completion_params) + + print(response) + return response + + except Exception as error: + logger.error(f"Error in LiteLLM arun: {str(error)}") + # if "rate_limit" in str(error).lower(): + # logger.warning( + # "Rate limit hit, retrying with exponential backoff..." + # ) + # await asyncio.sleep(2) # Use async sleep + # return await self.arun(task, *args, **kwargs) + raise error + + async def _process_batch( + self, tasks: List[str], batch_size: int = 10 + ): + """ + Process a batch of tasks asynchronously. + + Args: + tasks (List[str]): List of tasks to process. + batch_size (int): Size of each batch. + + Returns: + List[str]: List of responses. + """ + results = [] + for i in range(0, len(tasks), batch_size): + batch = tasks[i : i + batch_size] + batch_results = await asyncio.gather( + *[self.arun(task) for task in batch], + return_exceptions=True, + ) + + # Handle any exceptions in the batch + for result in batch_results: + if isinstance(result, Exception): + logger.error( + f"Error in batch processing: {str(result)}" + ) + results.append(str(result)) + else: + results.append(result) + + # Add a small delay between batches to avoid rate limits + if i + batch_size < len(tasks): + await asyncio.sleep(0.5) + + return results + + def batched_run(self, tasks: List[str], batch_size: int = 10): + """ + Run multiple tasks in batches synchronously. + + Args: + tasks (List[str]): List of tasks to process. + batch_size (int): Size of each batch. + + Returns: + List[str]: List of responses. + """ + logger.info( + f"Running {len(tasks)} tasks in batches of {batch_size}" + ) + return asyncio.run(self._process_batch(tasks, batch_size)) + + async def batched_arun( + self, tasks: List[str], batch_size: int = 10 + ): + """ + Run multiple tasks in batches asynchronously. + + Args: + tasks (List[str]): List of tasks to process. + batch_size (int): Size of each batch. + + Returns: + List[str]: List of responses. + """ + logger.info( + f"Running {len(tasks)} tasks asynchronously in batches of {batch_size}" + ) + return await self._process_batch(tasks, batch_size) \ No newline at end of file