pull/916/merge
commit b00dd9860a by 王祥宇, committed by GitHub 3 weeks ago

@@ -1,5 +1,6 @@
 import traceback
-from typing import Optional
+from typing import Optional, List, Union
 import base64
 import requests
 from pathlib import Path
@@ -51,11 +52,24 @@ def get_audio_base64(audio_source: str) -> str:
     return encoded_string


-def get_image_base64(image_source: str) -> str:
+# Modification: updated function signature and implementation to support list input
+def get_image_base64(image_source: Union[str, List[str]]) -> Union[str, List[str]]:
     """
     Convert image from a given source to a base64 encoded string.
     Handles URLs, local file paths, and data URIs.
+    Now supports both a single image path and a list of image paths.
+
+    Args:
+        image_source: String path to an image, or a list of image paths
+
+    Returns:
+        A single base64 string, or a list of base64 strings
     """
+    # Handle a list of images
+    if isinstance(image_source, list):
+        return [get_image_base64(single_image) for single_image in image_source]
+
+    # Handle a single image (original logic)
     # If already a data URI, return as is
     if image_source.startswith("data:image"):
         return image_source
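
For reviewers: `get_image_base64` is now dual-mode, and a list input recurses into the single-image path element-wise, so output order matches input order. A minimal sketch of the expected behavior (file paths are hypothetical):

```python
# Hypothetical paths; assumes the files exist locally.
single = get_image_base64("photos/cat.png")
# -> "data:image/png;base64,..."

batch = get_image_base64(["photos/cat.png", "photos/dog.jpg"])
# -> one data URI per input, in the same order
assert isinstance(batch, list) and len(batch) == 2
```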
@@ -170,21 +184,33 @@ class LiteLLM:
         out = out.model_dump()
         return out

+    # Modification: updated _prepare_messages to accept img as a string or a list of strings
     def _prepare_messages(
         self,
         task: str,
-        img: str = None,
+        img: Union[str, List[str]] = None,  # Modification: string or list of strings
     ):
         """
         Prepare the messages for the given task.

         Args:
             task (str): The task to prepare messages for.
+            img (Union[str, List[str]], optional): Single image input or list of image inputs. Defaults to None.

         Returns:
             list: A list of messages prepared for the task.
         """
-        self.check_if_model_supports_vision(img=img)
+        # Edit: convert a single image string to a list for unified processing
+        image_list = []
+        if img is not None:
+            if isinstance(img, str):
+                image_list = [img]
+            else:
+                image_list = img
+
+        # Edit: check whether there are images to process
+        if image_list:
+            self.check_if_model_supports_vision(image_list=image_list)

         # Initialize messages
         messages = []
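
The normalization above guarantees downstream code only ever sees a list. As a standalone illustration of the rule (this helper is hypothetical, not part of the diff; the diff assigns the list directly, while `list()` here just makes the copy explicit):

```python
from typing import List, Optional, Union

def normalize_images(img: Optional[Union[str, List[str]]]) -> List[str]:
    """None -> [], "a.png" -> ["a.png"], a list passes through."""
    if img is None:
        return []
    return [img] if isinstance(img, str) else list(img)

assert normalize_images(None) == []
assert normalize_images("a.png") == ["a.png"]
assert normalize_images(["a.png", "b.png"]) == ["a.png", "b.png"]
```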
@@ -196,148 +222,143 @@ class LiteLLM:
         )

         # Handle vision case
-        if img is not None:
+        if image_list:  # Modification: handle the image list
             messages = self.vision_processing(
-                task=task, image=img, messages=messages
+                task=task, images=image_list, messages=messages
             )
         else:
             messages.append({"role": "user", "content": task})

         return messages

+    # Modification: updated anthropic_vision_processing to handle multiple images
     def anthropic_vision_processing(
-        self, task: str, image: str, messages: list
+        self, task: str, images: List[str], messages: list
     ) -> list:
         """
         Process vision input specifically for Anthropic models.
         Handles Anthropic's specific image format requirements.
+
+        Args:
+            task (str): The task prompt
+            images (List[str]): List of image paths or URLs
+            messages (list): Current message list
+
+        Returns:
+            list: Updated messages list with images
         """
-        # Get base64 encoded image
-        image_url = get_image_base64(image)
-
-        # Extract mime type from the data URI or use default
-        mime_type = "image/jpeg"  # default
-        if "data:" in image_url and ";base64," in image_url:
-            mime_type = image_url.split(";base64,")[0].split("data:")[
-                1
-            ]
-
-        # Ensure mime type is one of the supported formats
-        supported_formats = [
-            "image/jpeg",
-            "image/png",
-            "image/gif",
-            "image/webp",
-        ]
-        if mime_type not in supported_formats:
-            mime_type = (
-                "image/jpeg"  # fallback to jpeg if unsupported
-            )
-
-        # Construct Anthropic vision message
-        messages.append(
-            {
-                "role": "user",
-                "content": [
-                    {"type": "text", "text": task},
-                    {
-                        "type": "image_url",
-                        "image_url": {
-                            "url": image_url,
-                            "format": mime_type,
-                        },
-                    },
-                ],
-            }
-        )
+        content = [{"type": "text", "text": task}]
+
+        # Modification: use the updated get_image_base64 to handle the image list
+        image_urls = get_image_base64(images)
+        if not isinstance(image_urls, list):
+            image_urls = [image_urls]
+
+        for i, image_url in enumerate(image_urls):
+            # Extract the mime type from the data URI, defaulting to JPEG
+            mime_type = "image/jpeg"
+            if "data:" in image_url and ";base64," in image_url:
+                mime_type = image_url.split(";base64,")[0].split("data:")[1]
+
+            # Ensure the mime type is one of the supported formats
+            supported_formats = [
+                "image/jpeg",
+                "image/png",
+                "image/gif",
+                "image/webp",
+            ]
+            if mime_type not in supported_formats:
+                mime_type = "image/jpeg"  # fallback to jpeg if unsupported
+
+            content.append({
+                "type": "image_url",
+                "image_url": {
+                    "url": image_url,
+                    "format": mime_type,
+                },
+            })
+
+        messages.append({
+            "role": "user",
+            "content": content,
+        })

         return messages

+    # Modification: updated openai_vision_processing to handle multiple images
     def openai_vision_processing(
-        self, task: str, image: str, messages: list
+        self, task: str, images: List[str], messages: list
     ) -> list:
         """
         Process vision input specifically for OpenAI models.
         Handles OpenAI's specific image format requirements.
+
+        Args:
+            task (str): The task prompt
+            images (List[str]): List of image paths or URLs
+            messages (list): Current message list
+
+        Returns:
+            list: Updated messages list with images
         """
-        # Get base64 encoded image with proper format
-        image_url = get_image_base64(image)
-
-        # Prepare vision message
-        vision_message = {
-            "type": "image_url",
-            "image_url": {"url": image_url},
-        }
-
-        # Add format for specific models
-        extension = Path(image).suffix.lower()
-        mime_type = (
-            f"image/{extension[1:]}" if extension else "image/jpeg"
-        )
-        vision_message["image_url"]["format"] = mime_type
-
-        # Append vision message
-        messages.append(
-            {
-                "role": "user",
-                "content": [
-                    {"type": "text", "text": task},
-                    vision_message,
-                ],
-            }
-        )
+        content = [{"type": "text", "text": task}]
+
+        # Modification: use the updated get_image_base64 to handle the image list
+        image_urls = get_image_base64(images)
+        if not isinstance(image_urls, list):
+            image_urls = [image_urls]
+
+        for i, image_url in enumerate(image_urls):
+            vision_message = {
+                "type": "image_url",
+                "image_url": {"url": image_url},
+            }
+
+            # Get the original path of the corresponding image
+            original_image = images[i] if i < len(images) else images[0]
+            extension = Path(original_image).suffix.lower()
+            mime_type = f"image/{extension[1:]}" if extension else "image/jpeg"
+            vision_message["image_url"]["format"] = mime_type
+
+            content.append(vision_message)
+
+        messages.append({
+            "role": "user",
+            "content": content,
+        })

         return messages

+    # Modification: updated vision_processing to handle multiple images
     def vision_processing(
-        self, task: str, image: str, messages: Optional[list] = None
+        self, task: str, images: List[str], messages: Optional[list] = None
     ):
         """
-        Process the image for the given task.
+        Process the images for the given task.
         Handles different image formats and model requirements.
+
+        Args:
+            task (str): The task prompt
+            images (List[str]): List of image paths or URLs
+            messages (Optional[list], optional): Current messages list. Defaults to None.
+
+        Returns:
+            list: Updated messages with image content
         """
-        # # # Handle Anthropic models separately
-        # # if "anthropic" in self.model_name.lower() or "claude" in self.model_name.lower():
-        # #     messages = self.anthropic_vision_processing(task, image, messages)
-        # #     return messages
-
-        # # Get base64 encoded image with proper format
-        # image_url = get_image_base64(image)
-
-        # # Prepare vision message
-        # vision_message = {
-        #     "type": "image_url",
-        #     "image_url": {"url": image_url},
-        # }
-
-        # # Add format for specific models
-        # extension = Path(image).suffix.lower()
-        # mime_type = f"image/{extension[1:]}" if extension else "image/jpeg"
-        # vision_message["image_url"]["format"] = mime_type
-
-        # # Append vision message
-        # messages.append(
-        #     {
-        #         "role": "user",
-        #         "content": [
-        #             {"type": "text", "text": task},
-        #             vision_message,
-        #         ],
-        #     }
-        # )
-
-        # return messages
+        if messages is None:
+            messages = []
+
         if (
             "anthropic" in self.model_name.lower()
             or "claude" in self.model_name.lower()
         ):
             messages = self.anthropic_vision_processing(
-                task, image, messages
+                task, images, messages
             )
             return messages
         else:
             messages = self.openai_vision_processing(
-                task, image, messages
+                task, images, messages
             )
             return messages
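
After this change, both provider paths emit a single user message whose content list carries one text part followed by one image part per input, rather than one message per image. For two images the OpenAI-style payload looks roughly like this (data URIs abbreviated, task text hypothetical):

```python
example_message = {
    "role": "user",
    "content": [
        {"type": "text", "text": "Compare these two photos"},
        {"type": "image_url", "image_url": {"url": "data:image/png;base64,...", "format": "image/png"}},
        {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,...", "format": "image/jpeg"}},
    ],
}
```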
@@ -368,23 +389,33 @@ class LiteLLM:
             }
         )

-    def check_if_model_supports_vision(self, img: str = None):
+    # Modification: updated check_if_model_supports_vision to support image lists
+    def check_if_model_supports_vision(self, img: str = None, image_list: List[str] = None):
         """
         Check if the model supports vision.
+
+        Args:
+            img (str, optional): Single image path (for backward compatibility). Defaults to None.
+            image_list (List[str], optional): List of image paths. Defaults to None.
+
+        Raises:
+            ValueError: If the model does not support vision.
         """
-        if img is not None:
+        # If any images were provided (single or list), check vision support
+        if img is not None or (image_list and len(image_list) > 0):
             out = supports_vision(model=self.model_name)

             if out is False:
                 raise ValueError(
                     f"Model {self.model_name} does not support vision"
                 )

+    # Modification: run's img parameter now accepts a string or a list of strings
     def run(
         self,
         task: str,
         audio: Optional[str] = None,
-        img: Optional[str] = None,
+        img: Union[str, List[str]] = None,
         *args,
         **kwargs,
     ):
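
Note that the vision check keeps the legacy `img` keyword for backward compatibility while the new internal path passes `image_list`; either form triggers the same `supports_vision` lookup. A sketch of both call forms (paths hypothetical):

```python
# Assuming `llm` is a LiteLLM instance:
llm.check_if_model_supports_vision(img="photos/cat.png")           # legacy single-image form
llm.check_if_model_supports_vision(image_list=["a.png", "b.png"])  # new list form
# Either form raises ValueError when supports_vision(model=...) is False.
```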
@@ -394,7 +425,7 @@ class LiteLLM:
         Args:
             task (str): The task to run the model for.
             audio (str, optional): Audio input if any. Defaults to None.
-            img (str, optional): Image input if any. Defaults to None.
+            img (Union[str, List[str]], optional): Single image input or list of image inputs. Defaults to None.
             *args: Additional positional arguments.
             **kwargs: Additional keyword arguments.
@@ -405,6 +436,7 @@ class LiteLLM:
             Exception: If there is an error in processing the request.
         """
         try:
+
             messages = self._prepare_messages(task=task, img=img)

             # Base completion parameters
@@ -492,12 +524,14 @@ class LiteLLM:
         """
         return self.run(task, *args, **kwargs)

-    async def arun(self, task: str, *args, **kwargs):
+    # Modification: updated arun to accept img as a string or a list of strings
+    async def arun(self, task: str, img: Union[str, List[str]] = None, *args, **kwargs):
         """
         Run the LLM model asynchronously for the given task.

         Args:
             task (str): The task to run the model for.
+            img (Union[str, List[str]], optional): Single image input or list of image inputs. Defaults to None.
             *args: Additional positional arguments.
             **kwargs: Additional keyword arguments.
@@ -505,7 +539,8 @@ class LiteLLM:
             str: The content of the response from the model.
         """
         try:
-            messages = self._prepare_messages(task)
+
+            messages = self._prepare_messages(task=task, img=img)

             # Prepare common completion parameters
             completion_params = {
@@ -618,4 +653,5 @@ class LiteLLM:
         logger.info(
             f"Running {len(tasks)} tasks asynchronously in batches of {batch_size}"
         )
+
         return await self._process_batch(tasks, batch_size)
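
End to end, the public API now accepts either shape for `img`. A hedged usage sketch (model name and file paths are placeholders):

```python
llm = LiteLLM(model_name="gpt-4o")  # placeholder model name

# Single image: unchanged behavior
llm.run("Describe this image", img="photos/cat.png")

# Multiple images: new with this PR
llm.run("Compare these images", img=["photos/cat.png", "photos/dog.jpg"])

# The async path now forwards img too (call from inside an event loop):
# await llm.arun("Compare these images", img=["photos/cat.png", "photos/dog.jpg"])
```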