Feed multiple images into the agent

Branch: pull/913/head
Author: Wxysnx, 1 month ago
Parent: 7e952234c1
Commit: 2acbfc7e4b

@@ -1,5 +1,5 @@
 import traceback
-from typing import Optional
+from typing import Optional, List, Union
 import base64
 import requests
 from pathlib import Path
@@ -168,21 +168,33 @@ class LiteLLM:
         out = out.model_dump()
         return out

+    # Modification: Updated _prepare_messages method to accept the img parameter as a string or a list of strings
     def _prepare_messages(
         self,
         task: str,
-        img: str = None,
+        img: Union[str, List[str]] = None,  # Modification: parameter type is a string or a list of strings
     ):
         """
         Prepare the messages for the given task.

         Args:
             task (str): The task to prepare messages for.
+            img (Union[str, List[str]], optional): Single image input or list of image inputs. Defaults to None.

         Returns:
             list: A list of messages prepared for the task.
         """
-        self.check_if_model_supports_vision(img=img)
+        # Edit: Convert a single image string to a list for unified processing
+        image_list = []
+        if img is not None:
+            if isinstance(img, str):
+                image_list = [img]
+            else:
+                image_list = img
+
+        # Edit: Check whether there is an image to process
+        if image_list:
+            self.check_if_model_supports_vision(image_list=image_list)

         # Initialize messages
         messages = []
@@ -194,148 +206,147 @@ class LiteLLM:
         )

         # Handle vision case
-        if img is not None:
+        if image_list:  # Modification: process the image list
             messages = self.vision_processing(
-                task=task, image=img, messages=messages
+                task=task, images=image_list, messages=messages
             )
         else:
             messages.append({"role": "user", "content": task})

         return messages
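For reference, a minimal sketch of how the reworked _prepare_messages normalizes its input. The LiteLLM constructor call is hypothetical (only self.model_name is visible in this diff):

    llm = LiteLLM(model_name="gpt-4o")  # hypothetical construction

    # Backward compatible: a single image string becomes a one-element list
    messages_single = llm._prepare_messages(
        task="Describe this image", img="photo.png"
    )

    # New: a list of images is passed through unchanged
    messages_multi = llm._prepare_messages(
        task="Compare these images", img=["a.png", "b.png"]
    )

    # Both calls reach vision_processing with a List[str]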
+    # Modification: Updated anthropic_vision_processing method to handle multiple images
     def anthropic_vision_processing(
-        self, task: str, image: str, messages: list
+        self, task: str, images: List[str], messages: list
     ) -> list:
         """
         Process vision input specifically for Anthropic models.
         Handles Anthropic's specific image format requirements.
+
+        Args:
+            task (str): The task prompt
+            images (List[str]): List of image paths or URLs
+            messages (list): Current message list
+
+        Returns:
+            list: Updated messages list with images
         """
-        # Get base64 encoded image
-        image_url = get_image_base64(image)
-
-        # Extract mime type from the data URI or use default
-        mime_type = "image/jpeg"  # default
-        if "data:" in image_url and ";base64," in image_url:
-            mime_type = image_url.split(";base64,")[0].split("data:")[
-                1
-            ]
-
-        # Ensure mime type is one of the supported formats
-        supported_formats = [
-            "image/jpeg",
-            "image/png",
-            "image/gif",
-            "image/webp",
-        ]
-        if mime_type not in supported_formats:
-            mime_type = (
-                "image/jpeg"  # fallback to jpeg if unsupported
-            )
-
-        # Construct Anthropic vision message
-        messages.append(
-            {
-                "role": "user",
-                "content": [
-                    {"type": "text", "text": task},
-                    {
-                        "type": "image_url",
-                        "image_url": {
-                            "url": image_url,
-                            "format": mime_type,
-                        },
-                    },
-                ],
-            }
-        )
+        content = [{"type": "text", "text": task}]
+
+        for image in images:
+            # Get the base64-encoded image and sniff its mime type
+            image_url = get_image_base64(image)
+
+            mime_type = "image/jpeg"
+            if "data:" in image_url and ";base64," in image_url:
+                mime_type = image_url.split(";base64,")[0].split("data:")[1]
+
+            # Fall back to JPEG if the format is unsupported
+            supported_formats = [
+                "image/jpeg",
+                "image/png",
+                "image/gif",
+                "image/webp",
+            ]
+            if mime_type not in supported_formats:
+                mime_type = "image/jpeg"
+
+            content.append({
+                "type": "image_url",
+                "image_url": {
+                    "url": image_url,
+                    "format": mime_type,
+                },
+            })
+
+        messages.append({
+            "role": "user",
+            "content": content,
+        })

         return messages
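Note the design change: all images now share one user message with a single text part, rather than one message per image. The mime-type sniffing in the loop above can also be exercised in isolation; a minimal sketch of the same logic outside the class:

    def extract_mime_type(image_url: str) -> str:
        # Default to JPEG, then try to read the type out of a data URI
        mime_type = "image/jpeg"
        if "data:" in image_url and ";base64," in image_url:
            mime_type = image_url.split(";base64,")[0].split("data:")[1]
        supported = {"image/jpeg", "image/png", "image/gif", "image/webp"}
        return mime_type if mime_type in supported else "image/jpeg"

    assert extract_mime_type("data:image/png;base64,iVBORw0KGgo=") == "image/png"
    assert extract_mime_type("https://example.com/cat.bmp") == "image/jpeg"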
+    # Modification: Updated openai_vision_processing method to handle multiple images
     def openai_vision_processing(
-        self, task: str, image: str, messages: list
+        self, task: str, images: List[str], messages: list
     ) -> list:
         """
         Process vision input specifically for OpenAI models.
         Handles OpenAI's specific image format requirements.
+
+        Args:
+            task (str): The task prompt
+            images (List[str]): List of image paths or URLs
+            messages (list): Current message list
+
+        Returns:
+            list: Updated messages list with images
         """
-        # Get base64 encoded image with proper format
-        image_url = get_image_base64(image)
-
-        # Prepare vision message
-        vision_message = {
-            "type": "image_url",
-            "image_url": {"url": image_url},
-        }
-
-        # Add format for specific models
-        extension = Path(image).suffix.lower()
-        mime_type = (
-            f"image/{extension[1:]}" if extension else "image/jpeg"
-        )
-        vision_message["image_url"]["format"] = mime_type
-
-        # Append vision message
-        messages.append(
-            {
-                "role": "user",
-                "content": [
-                    {"type": "text", "text": task},
-                    vision_message,
-                ],
-            }
-        )
+        content = [{"type": "text", "text": task}]
+
+        for image in images:
+            # Get the base64-encoded image with the proper format
+            image_url = get_image_base64(image)
+
+            vision_message = {
+                "type": "image_url",
+                "image_url": {"url": image_url},
+            }
+
+            # Derive the mime type from the file extension
+            extension = Path(image).suffix.lower()
+            mime_type = f"image/{extension[1:]}" if extension else "image/jpeg"
+            vision_message["image_url"]["format"] = mime_type
+
+            content.append(vision_message)
+
+        messages.append({
+            "role": "user",
+            "content": content,
+        })

         return messages
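Unlike the Anthropic path, the OpenAI path derives the mime type from the file extension rather than the data URI. Isolated, the rule is:

    from pathlib import Path

    def mime_from_path(image: str) -> str:
        # An empty suffix falls back to JPEG
        extension = Path(image).suffix.lower()
        return f"image/{extension[1:]}" if extension else "image/jpeg"

    print(mime_from_path("chart.png"))  # image/png
    print(mime_from_path("snapshot"))   # image/jpeg
    print(mime_from_path("pic.jpg"))    # image/jpg

One caveat worth noting: a ".jpg" suffix yields "image/jpg" rather than the canonical "image/jpeg", which some backends may reject.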
+    # Modification: Updated vision_processing method to handle multiple images
     def vision_processing(
-        self, task: str, image: str, messages: Optional[list] = None
+        self, task: str, images: List[str], messages: Optional[list] = None
     ):
         """
-        Process the image for the given task.
+        Process the images for the given task.
         Handles different image formats and model requirements.
+
+        Args:
+            task (str): The task prompt
+            images (List[str]): List of image paths or URLs
+            messages (Optional[list], optional): Current messages list. Defaults to None.
+
+        Returns:
+            list: Updated messages with image content
         """
-        # # # Handle Anthropic models separately
-        # # if "anthropic" in self.model_name.lower() or "claude" in self.model_name.lower():
-        # #     messages = self.anthropic_vision_processing(task, image, messages)
-        # #     return messages
-
-        # # Get base64 encoded image with proper format
-        # image_url = get_image_base64(image)
-
-        # # Prepare vision message
-        # vision_message = {
-        #     "type": "image_url",
-        #     "image_url": {"url": image_url},
-        # }
-
-        # # Add format for specific models
-        # extension = Path(image).suffix.lower()
-        # mime_type = f"image/{extension[1:]}" if extension else "image/jpeg"
-        # vision_message["image_url"]["format"] = mime_type
-
-        # # Append vision message
-        # messages.append(
-        #     {
-        #         "role": "user",
-        #         "content": [
-        #             {"type": "text", "text": task},
-        #             vision_message,
-        #         ],
-        #     }
-        # )
-
-        # return messages
+        if messages is None:
+            messages = []
+
         if (
             "anthropic" in self.model_name.lower()
             or "claude" in self.model_name.lower()
         ):
             messages = self.anthropic_vision_processing(
-                task, image, messages
+                task, images, messages
             )
             return messages
         else:
             messages = self.openai_vision_processing(
-                task, image, messages
+                task, images, messages
             )
             return messages
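The provider dispatch is a plain substring check on the model name; a sketch of the routing rule with hypothetical model names:

    for name in ["claude-3-5-sonnet", "gpt-4o", "anthropic/claude-3-haiku"]:
        if "anthropic" in name.lower() or "claude" in name.lower():
            print(name, "-> anthropic_vision_processing")
        else:
            print(name, "-> openai_vision_processing")  # default path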
@@ -366,23 +377,33 @@ class LiteLLM:
             }
         )

-    def check_if_model_supports_vision(self, img: str = None):
+    # Modification: Updated check_if_model_supports_vision method to support image lists
+    def check_if_model_supports_vision(self, img: str = None, image_list: List[str] = None):
         """
         Check if the model supports vision.
+
+        Args:
+            img (str, optional): Single image path (for backward compatibility). Defaults to None.
+            image_list (List[str], optional): List of image paths. Defaults to None.
+
+        Raises:
+            ValueError: If the model does not support vision.
         """
-        if img is not None:
+        # If there are any images (single or multiple), check whether the model supports vision
+        if img is not None or (image_list and len(image_list) > 0):
             out = supports_vision(model=self.model_name)

             if out is False:
                 raise ValueError(
                     f"Model {self.model_name} does not support vision"
                 )
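supports_vision here is litellm's capability lookup (its import is not shown in this hunk); the new check reduces to a sketch like:

    from litellm import supports_vision

    model_name = "gpt-4o"            # hypothetical
    image_list = ["a.png", "b.png"]

    if image_list and not supports_vision(model=model_name):
        raise ValueError(f"Model {model_name} does not support vision")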
+    # Modification: Updated the run method so that the img parameter accepts a string or a list of strings
     def run(
         self,
         task: str,
         audio: Optional[str] = None,
-        img: Optional[str] = None,
+        img: Union[str, List[str]] = None,
         *args,
         **kwargs,
     ):
@@ -392,7 +413,7 @@ class LiteLLM:
         Args:
             task (str): The task to run the model for.
             audio (str, optional): Audio input if any. Defaults to None.
-            img (str, optional): Image input if any. Defaults to None.
+            img (Union[str, List[str]], optional): Single image input or list of image inputs. Defaults to None.
             *args: Additional positional arguments.
             **kwargs: Additional keyword arguments.
@ -403,6 +424,7 @@ class LiteLLM:
Exception: If there is an error in processing the request. Exception: If there is an error in processing the request.
""" """
try: try:
messages = self._prepare_messages(task=task, img=img) messages = self._prepare_messages(task=task, img=img)
# Base completion parameters # Base completion parameters
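A usage sketch for the updated run signature (construction hypothetical, as above):

    llm = LiteLLM(model_name="gpt-4o")

    # Old call sites keep working with a single image string
    one = llm.run(task="What is in this image?", img="photo.jpg")

    # New call sites can pass several images in one request
    many = llm.run(
        task="What changed between these screenshots?",
        img=["before.png", "after.png"],
    )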
@@ -486,12 +508,14 @@ class LiteLLM:
         """
         return self.run(task, *args, **kwargs)

-    async def arun(self, task: str, *args, **kwargs):
+    # Modification: Updated arun method to accept the img parameter as a string or a list of strings
+    async def arun(self, task: str, img: Union[str, List[str]] = None, *args, **kwargs):
         """
         Run the LLM model asynchronously for the given task.

         Args:
             task (str): The task to run the model for.
+            img (Union[str, List[str]], optional): Single image input or list of image inputs. Defaults to None.
             *args: Additional positional arguments.
             **kwargs: Additional keyword arguments.
@ -499,7 +523,8 @@ class LiteLLM:
str: The content of the response from the model. str: The content of the response from the model.
""" """
try: try:
messages = self._prepare_messages(task)
messages = self._prepare_messages(task=task, img=img)
# Prepare common completion parameters # Prepare common completion parameters
completion_params = { completion_params = {
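And the async counterpart, now wired to the same image handling (sketch, same hypothetical construction):

    import asyncio

    async def main():
        llm = LiteLLM(model_name="gpt-4o")
        result = await llm.arun(
            task="Summarize these slides",
            img=["slide1.png", "slide2.png"],
        )
        print(result)

    asyncio.run(main())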
@@ -612,4 +637,4 @@ class LiteLLM:
         logger.info(
             f"Running {len(tasks)} tasks asynchronously in batches of {batch_size}"
         )
-        return await self._process_batch(tasks, batch_size)
+        return await self._process_batch(tasks, batch_size)