From c28dea745f70c876777b7f0c506b662e550a4ac9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E7=A5=A5=E5=AE=87?= <625024108@qq.com> Date: Sun, 29 Jun 2025 11:29:38 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0=20litellm=5Fwrapper.py=20?= =?UTF-8?q?=E6=96=87=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- swarms/utils/litellm_wrapper.py | 49 ++++++++++++++++++++------------- 1 file changed, 30 insertions(+), 19 deletions(-) diff --git a/swarms/utils/litellm_wrapper.py b/swarms/utils/litellm_wrapper.py index d5ed3f60..96390825 100644 --- a/swarms/utils/litellm_wrapper.py +++ b/swarms/utils/litellm_wrapper.py @@ -1,3 +1,4 @@ + import traceback from typing import Optional, List, Union import base64 @@ -51,11 +52,24 @@ def get_audio_base64(audio_source: str) -> str: return encoded_string -def get_image_base64(image_source: str) -> str: +# 修改:更新函数签名和实现以支持列表输入 +def get_image_base64(image_source: Union[str, List[str]]) -> Union[str, List[str]]: """ Convert image from a given source to a base64 encoded string. Handles URLs, local file paths, and data URIs. + Now supports both single image path and list of image paths. + + Args: + image_source: String path to image or list of image paths + + Returns: + Single base64 string or list of base64 strings """ + # 处理图像列表 + if isinstance(image_source, list): + return [get_image_base64(single_image) for single_image in image_source] + + # 处理单个图像(原始逻辑) # If already a data URI, return as is if image_source.startswith("data:image"): return image_source @@ -234,17 +248,16 @@ class LiteLLM: content = [{"type": "text", "text": task}] - - for image in images: - - image_url = get_image_base64(image) + # 修改:使用新版get_image_base64函数处理图像列表 + image_urls = get_image_base64(images) + if not isinstance(image_urls, list): + image_urls = [image_urls] - + for i, image_url in enumerate(image_urls): mime_type = "image/jpeg" if "data:" in image_url and ";base64," in image_url: mime_type = image_url.split(";base64,")[0].split("data:")[1] - supported_formats = [ "image/jpeg", "image/png", @@ -254,7 +267,6 @@ class LiteLLM: if mime_type not in supported_formats: mime_type = "image/jpeg" - content.append({ "type": "image_url", "image_url": { @@ -263,7 +275,6 @@ class LiteLLM: }, }) - messages.append({ "role": "user", "content": content, @@ -290,26 +301,25 @@ class LiteLLM: content = [{"type": "text", "text": task}] - - for image in images: - - image_url = get_image_base64(image) - + # 修改:使用新版get_image_base64函数处理图像列表 + image_urls = get_image_base64(images) + if not isinstance(image_urls, list): + image_urls = [image_urls] + for i, image_url in enumerate(image_urls): vision_message = { "type": "image_url", "image_url": {"url": image_url}, } - - extension = Path(image).suffix.lower() + # 获取对应图像的原始路径 + original_image = images[i] if i < len(images) else images[0] + extension = Path(original_image).suffix.lower() mime_type = f"image/{extension[1:]}" if extension else "image/jpeg" vision_message["image_url"]["format"] = mime_type - content.append(vision_message) - messages.append({ "role": "user", "content": content, @@ -637,4 +647,5 @@ class LiteLLM: logger.info( f"Running {len(tasks)} tasks asynchronously in batches of {batch_size}" ) - return await self._process_batch(tasks, batch_size) \ No newline at end of file + return await self._process_batch(tasks, batch_size) + \ No newline at end of file