更新 litellm_wrapper.py 文件

pull/916/head
王祥宇 4 weeks ago
parent a2c42fca54
commit c28dea745f

@ -1,3 +1,4 @@
import traceback
from typing import Optional, List, Union
import base64
@ -51,11 +52,24 @@ def get_audio_base64(audio_source: str) -> str:
return encoded_string
def get_image_base64(image_source: str) -> str:
# 修改:更新函数签名和实现以支持列表输入
def get_image_base64(image_source: Union[str, List[str]]) -> Union[str, List[str]]:
"""
Convert image from a given source to a base64 encoded string.
Handles URLs, local file paths, and data URIs.
Now supports both single image path and list of image paths.
Args:
image_source: String path to image or list of image paths
Returns:
Single base64 string or list of base64 strings
"""
# 处理图像列表
if isinstance(image_source, list):
return [get_image_base64(single_image) for single_image in image_source]
# 处理单个图像(原始逻辑)
# If already a data URI, return as is
if image_source.startswith("data:image"):
return image_source
@ -234,17 +248,16 @@ class LiteLLM:
content = [{"type": "text", "text": task}]
for image in images:
image_url = get_image_base64(image)
# 修改使用新版get_image_base64函数处理图像列表
image_urls = get_image_base64(images)
if not isinstance(image_urls, list):
image_urls = [image_urls]
for i, image_url in enumerate(image_urls):
mime_type = "image/jpeg"
if "data:" in image_url and ";base64," in image_url:
mime_type = image_url.split(";base64,")[0].split("data:")[1]
supported_formats = [
"image/jpeg",
"image/png",
@ -254,7 +267,6 @@ class LiteLLM:
if mime_type not in supported_formats:
mime_type = "image/jpeg"
content.append({
"type": "image_url",
"image_url": {
@ -263,7 +275,6 @@ class LiteLLM:
},
})
messages.append({
"role": "user",
"content": content,
@ -290,26 +301,25 @@ class LiteLLM:
content = [{"type": "text", "text": task}]
for image in images:
image_url = get_image_base64(image)
# 修改使用新版get_image_base64函数处理图像列表
image_urls = get_image_base64(images)
if not isinstance(image_urls, list):
image_urls = [image_urls]
for i, image_url in enumerate(image_urls):
vision_message = {
"type": "image_url",
"image_url": {"url": image_url},
}
extension = Path(image).suffix.lower()
# 获取对应图像的原始路径
original_image = images[i] if i < len(images) else images[0]
extension = Path(original_image).suffix.lower()
mime_type = f"image/{extension[1:]}" if extension else "image/jpeg"
vision_message["image_url"]["format"] = mime_type
content.append(vision_message)
messages.append({
"role": "user",
"content": content,
@ -637,4 +647,5 @@ class LiteLLM:
logger.info(
f"Running {len(tasks)} tasks asynchronously in batches of {batch_size}"
)
return await self._process_batch(tasks, batch_size)
return await self._process_batch(tasks, batch_size)
Loading…
Cancel
Save