更新 litellm_wrapper.py 文件

pull/916/head
王祥宇 4 weeks ago
parent a2c42fca54
commit c28dea745f

@ -1,3 +1,4 @@
import traceback import traceback
from typing import Optional, List, Union from typing import Optional, List, Union
import base64 import base64
@ -51,11 +52,24 @@ def get_audio_base64(audio_source: str) -> str:
return encoded_string return encoded_string
def get_image_base64(image_source: str) -> str: # 修改:更新函数签名和实现以支持列表输入
def get_image_base64(image_source: Union[str, List[str]]) -> Union[str, List[str]]:
""" """
Convert image from a given source to a base64 encoded string. Convert image from a given source to a base64 encoded string.
Handles URLs, local file paths, and data URIs. Handles URLs, local file paths, and data URIs.
Now supports both single image path and list of image paths.
Args:
image_source: String path to image or list of image paths
Returns:
Single base64 string or list of base64 strings
""" """
# 处理图像列表
if isinstance(image_source, list):
return [get_image_base64(single_image) for single_image in image_source]
# 处理单个图像(原始逻辑)
# If already a data URI, return as is # If already a data URI, return as is
if image_source.startswith("data:image"): if image_source.startswith("data:image"):
return image_source return image_source
@ -234,17 +248,16 @@ class LiteLLM:
content = [{"type": "text", "text": task}] content = [{"type": "text", "text": task}]
# 修改使用新版get_image_base64函数处理图像列表
image_urls = get_image_base64(images)
if not isinstance(image_urls, list):
image_urls = [image_urls]
for image in images: for i, image_url in enumerate(image_urls):
image_url = get_image_base64(image)
mime_type = "image/jpeg" mime_type = "image/jpeg"
if "data:" in image_url and ";base64," in image_url: if "data:" in image_url and ";base64," in image_url:
mime_type = image_url.split(";base64,")[0].split("data:")[1] mime_type = image_url.split(";base64,")[0].split("data:")[1]
supported_formats = [ supported_formats = [
"image/jpeg", "image/jpeg",
"image/png", "image/png",
@ -254,7 +267,6 @@ class LiteLLM:
if mime_type not in supported_formats: if mime_type not in supported_formats:
mime_type = "image/jpeg" mime_type = "image/jpeg"
content.append({ content.append({
"type": "image_url", "type": "image_url",
"image_url": { "image_url": {
@ -263,7 +275,6 @@ class LiteLLM:
}, },
}) })
messages.append({ messages.append({
"role": "user", "role": "user",
"content": content, "content": content,
@ -290,26 +301,25 @@ class LiteLLM:
content = [{"type": "text", "text": task}] content = [{"type": "text", "text": task}]
# 修改使用新版get_image_base64函数处理图像列表
image_urls = get_image_base64(images)
if not isinstance(image_urls, list):
image_urls = [image_urls]
for image in images: for i, image_url in enumerate(image_urls):
image_url = get_image_base64(image)
vision_message = { vision_message = {
"type": "image_url", "type": "image_url",
"image_url": {"url": image_url}, "image_url": {"url": image_url},
} }
# 获取对应图像的原始路径
extension = Path(image).suffix.lower() original_image = images[i] if i < len(images) else images[0]
extension = Path(original_image).suffix.lower()
mime_type = f"image/{extension[1:]}" if extension else "image/jpeg" mime_type = f"image/{extension[1:]}" if extension else "image/jpeg"
vision_message["image_url"]["format"] = mime_type vision_message["image_url"]["format"] = mime_type
content.append(vision_message) content.append(vision_message)
messages.append({ messages.append({
"role": "user", "role": "user",
"content": content, "content": content,
@ -638,3 +648,4 @@ class LiteLLM:
f"Running {len(tasks)} tasks asynchronously in batches of {batch_size}" f"Running {len(tasks)} tasks asynchronously in batches of {batch_size}"
) )
return await self._process_batch(tasks, batch_size) return await self._process_batch(tasks, batch_size)
Loading…
Cancel
Save