parent d3a28edd77
commit 754c2b710d
@@ -0,0 +1,640 @@
import asyncio
import base64
import time
import traceback
from pathlib import Path
from typing import Any, List, Optional, Union

import litellm
import requests
from litellm import acompletion, completion, supports_vision
from loguru import logger
from pydantic import BaseModel


class LiteLLMException(Exception):
    """
    Exception for LiteLLM.
    """


def get_audio_base64(audio_source: str) -> str:
    """
    Convert audio from a given source to a base64 encoded string.

    This function handles both URLs and local file paths. If the audio source
    is a URL, it fetches the audio data from the internet. If it is a local
    file path, it reads the audio data from the specified file.

    Args:
        audio_source (str): The source of the audio, which can be a URL or a local file path.

    Returns:
        str: A base64 encoded string representation of the audio data.

    Raises:
        requests.HTTPError: If the HTTP request to fetch audio data fails.
        FileNotFoundError: If the local audio file does not exist.
    """
    # Handle URL
    if audio_source.startswith(("http://", "https://")):
        response = requests.get(audio_source)
        response.raise_for_status()
        audio_data = response.content
    # Handle local file
    else:
        with open(audio_source, "rb") as file:
            audio_data = file.read()

    encoded_string = base64.b64encode(audio_data).decode("utf-8")
    return encoded_string
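
# A minimal usage sketch for get_audio_base64; the URL and file path below are
# hypothetical placeholders, not assets shipped with this module:
#
#   encoded = get_audio_base64("https://example.com/sample.wav")  # fetched via requests
#   encoded = get_audio_base64("recordings/sample.wav")           # read from local disk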


def get_image_base64(image_source: str) -> str:
    """
    Convert image from a given source to a base64 encoded string.
    Handles URLs, local file paths, and data URIs.
    """
    # If already a data URI, return as is
    if image_source.startswith("data:image"):
        return image_source

    # Handle URL
    if image_source.startswith(("http://", "https://")):
        response = requests.get(image_source)
        response.raise_for_status()
        image_data = response.content
    # Handle local file
    else:
        with open(image_source, "rb") as file:
            image_data = file.read()

    # Get file extension for mime type
    extension = Path(image_source).suffix.lower()
    mime_type = (
        f"image/{extension[1:]}" if extension else "image/jpeg"
    )

    encoded_string = base64.b64encode(image_data).decode("utf-8")
    return f"data:{mime_type};base64,{encoded_string}"
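
# A minimal usage sketch for get_image_base64; the sources shown are hypothetical.
# Data URIs pass through unchanged, everything else is encoded into one:
#
#   get_image_base64("data:image/png;base64,iVBORw0...")  # returned as-is
#   get_image_base64("https://example.com/cat.png")       # -> "data:image/png;base64,..."
#   get_image_base64("photos/cat.jpg")                    # -> "data:image/jpg;base64,..."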


class LiteLLM:
    """
    This class represents a LiteLLM client.
    It is used to interact with the LLM model for various tasks.
    """

    def __init__(
        self,
        model_name: str = "gpt-4o",
        system_prompt: str = None,
        stream: bool = False,
        temperature: float = 0.5,
        max_tokens: int = 4000,
        ssl_verify: bool = False,
        max_completion_tokens: int = 4000,
        tools_list_dictionary: List[dict] = None,
        tool_choice: str = "auto",
        parallel_tool_calls: bool = False,
        audio: str = None,
        retries: int = 0,
        verbose: bool = False,
        caching: bool = False,
        mcp_call: bool = False,
        top_p: float = 1.0,
        functions: List[dict] = None,
        return_all: bool = False,
        base_url: str = None,
        api_key: str = None,
        *args,
        **kwargs,
    ):
        """
        Initialize the LiteLLM with the given parameters.

        Args:
            model_name (str, optional): The name of the model to use. Defaults to "gpt-4o".
            system_prompt (str, optional): The system prompt to use. Defaults to None.
            stream (bool, optional): Whether to stream the output. Defaults to False.
            temperature (float, optional): The temperature for the model. Defaults to 0.5.
            max_tokens (int, optional): The maximum number of tokens to generate. Defaults to 4000.
            ssl_verify (bool, optional): Whether to verify SSL certificates. Defaults to False.
            max_completion_tokens (int, optional): The maximum number of completion tokens. Defaults to 4000.
            tools_list_dictionary (List[dict], optional): Tool definitions to expose to the model. Defaults to None.
            tool_choice (str, optional): Tool-choice strategy. Defaults to "auto".
            parallel_tool_calls (bool, optional): Whether to allow parallel tool calls. Defaults to False.
            audio (str, optional): Audio input source. Defaults to None.
            retries (int, optional): Number of retries for failed requests. Defaults to 0.
            verbose (bool, optional): Whether to enable verbose logging. Defaults to False.
            caching (bool, optional): Whether to cache responses. Defaults to False.
            mcp_call (bool, optional): Whether calls are MCP tool calls. Defaults to False.
            top_p (float, optional): The top-p sampling value. Defaults to 1.0.
            functions (List[dict], optional): Legacy function definitions. Defaults to None.
            return_all (bool, optional): Whether to return the full response object. Defaults to False.
            base_url (str, optional): Custom API base URL. Defaults to None.
            api_key (str, optional): API key to use. Defaults to None.
        """
        self.model_name = model_name
        self.system_prompt = system_prompt
        self.stream = stream
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.ssl_verify = ssl_verify
        self.max_completion_tokens = max_completion_tokens
        self.tools_list_dictionary = tools_list_dictionary
        self.tool_choice = tool_choice
        self.parallel_tool_calls = parallel_tool_calls
        self.caching = caching
        self.mcp_call = mcp_call
        self.top_p = top_p
        self.functions = functions
        self.audio = audio
        self.return_all = return_all
        self.base_url = base_url
        self.api_key = api_key
        self.modalities = []
        self.messages = []  # Initialize messages list

        # Configure litellm settings
        litellm.set_verbose = verbose  # Verbose logging is controlled by the caller
        litellm.ssl_verify = ssl_verify
        litellm.num_retries = retries  # Add retries for better reliability

    def output_for_tools(self, response: Any):
        """
        Extract tool-call output from a completion response.

        For MCP calls, returns a dict with the function name and arguments;
        otherwise returns the raw tool_calls payload.
        """
        if self.mcp_call is True:
            out = response.choices[0].message.tool_calls[0].function
            output = {
                "function": {
                    "name": out.name,
                    "arguments": out.arguments,
                }
            }
            return output
        else:
            out = response.choices[0].message.tool_calls

            if isinstance(out, BaseModel):
                out = out.model_dump()
            return out
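
    # A sketch of the two shapes output_for_tools can return, assuming a
    # hypothetical response whose first tool call is get_weather({"city": "Paris"}):
    #
    #   mcp_call=True  -> {"function": {"name": "get_weather",
    #                                   "arguments": '{"city": "Paris"}'}}
    #   mcp_call=False -> the raw tool_calls payload, model_dump()-ed when it
    #                     is a pydantic BaseModel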

    # Modification: Updated _prepare_messages method to accept img parameter as string or list of strings
    def _prepare_messages(
        self,
        task: str,
        img: Union[str, List[str]] = None,  # Modification: parameter accepts a string or a list of strings
    ):
        """
        Prepare the messages for the given task.

        Args:
            task (str): The task to prepare messages for.
            img (Union[str, List[str]], optional): Single image input or list of image inputs. Defaults to None.

        Returns:
            list: A list of messages prepared for the task.
        """
        # Edit: Convert single image string to list for unified processing
        image_list = []
        if img is not None:
            if isinstance(img, str):
                image_list = [img]
            else:
                image_list = img

        # Edit: Check if there is an image to process
        if image_list:
            self.check_if_model_supports_vision(image_list=image_list)

        # Initialize messages
        messages = []

        # Add system prompt if present
        if self.system_prompt is not None:
            messages.append(
                {"role": "system", "content": self.system_prompt}
            )

        # Handle vision case
        if image_list:  # Modification: process the image list
            messages = self.vision_processing(
                task=task, images=image_list, messages=messages
            )
        else:
            messages.append({"role": "user", "content": task})

        return messages
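
    # A sketch of what _prepare_messages yields for a text-only task when a
    # system prompt is set (both strings are hypothetical):
    #
    #   [{"role": "system", "content": "You are helpful."},
    #    {"role": "user", "content": "Summarize this article."}]
    #
    # When img is given, the user entry is instead built by vision_processing
    # and carries a content list of text plus image parts.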

    # Modification: Updated anthropic_vision_processing method to handle multiple images
    def anthropic_vision_processing(
        self, task: str, images: List[str], messages: list
    ) -> list:
        """
        Process vision input specifically for Anthropic models.
        Handles Anthropic's specific image format requirements.

        Args:
            task (str): The task prompt
            images (List[str]): List of image paths or URLs
            messages (list): Current message list

        Returns:
            list: Updated messages list with images
        """
        content = [{"type": "text", "text": task}]

        for image in images:
            image_url = get_image_base64(image)

            # Extract the mime type from the data URI, defaulting to JPEG
            mime_type = "image/jpeg"
            if "data:" in image_url and ";base64," in image_url:
                mime_type = image_url.split(";base64,")[0].split("data:")[1]

            # Fall back to JPEG for formats outside the supported set
            supported_formats = [
                "image/jpeg",
                "image/png",
                "image/gif",
                "image/webp",
            ]
            if mime_type not in supported_formats:
                mime_type = "image/jpeg"

            content.append(
                {
                    "type": "image_url",
                    "image_url": {
                        "url": image_url,
                        "format": mime_type,
                    },
                }
            )

        messages.append(
            {
                "role": "user",
                "content": content,
            }
        )

        return messages
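
    # A sketch of one image entry this method appends, assuming a hypothetical
    # local PNG; the mime type is normalized to one of the four supported formats:
    #
    #   {"type": "image_url",
    #    "image_url": {"url": "data:image/png;base64,...", "format": "image/png"}}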

    # Modification: Updated openai_vision_processing method to handle multiple images
    def openai_vision_processing(
        self, task: str, images: List[str], messages: list
    ) -> list:
        """
        Process vision input specifically for OpenAI models.
        Handles OpenAI's specific image format requirements.

        Args:
            task (str): The task prompt
            images (List[str]): List of image paths or URLs
            messages (list): Current message list

        Returns:
            list: Updated messages list with images
        """
        content = [{"type": "text", "text": task}]

        for image in images:
            image_url = get_image_base64(image)

            vision_message = {
                "type": "image_url",
                "image_url": {"url": image_url},
            }

            extension = Path(image).suffix.lower()
            mime_type = (
                f"image/{extension[1:]}" if extension else "image/jpeg"
            )
            vision_message["image_url"]["format"] = mime_type

            content.append(vision_message)

        messages.append(
            {
                "role": "user",
                "content": content,
            }
        )

        return messages

    # Modification: Updated vision_processing method to handle multiple images
    def vision_processing(
        self, task: str, images: List[str], messages: Optional[list] = None
    ):
        """
        Process the images for the given task.
        Handles different image formats and model requirements.

        Args:
            task (str): The task prompt
            images (List[str]): List of image paths or URLs
            messages (Optional[list], optional): Current messages list. Defaults to None.

        Returns:
            list: Updated messages with image content
        """
        if messages is None:
            messages = []

        if (
            "anthropic" in self.model_name.lower()
            or "claude" in self.model_name.lower()
        ):
            messages = self.anthropic_vision_processing(
                task, images, messages
            )
            return messages
        else:
            messages = self.openai_vision_processing(
                task, images, messages
            )
            return messages
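
    # Routing sketch: any model name containing "anthropic" or "claude"
    # (e.g. "claude-3-5-sonnet") goes through anthropic_vision_processing;
    # everything else (e.g. "gpt-4o") uses openai_vision_processing.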

    def audio_processing(self, task: str, audio: str):
        """
        Process the audio for the given task.

        Args:
            task (str): The task to be processed.
            audio (str): The path or identifier for the audio file.
        """
        encoded_string = get_audio_base64(audio)

        # Append audio message
        self.messages.append(
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": task},
                    {
                        "type": "input_audio",
                        "input_audio": {
                            "data": encoded_string,
                            "format": "wav",
                        },
                    },
                ],
            }
        )

    # Modification: Updated check_if_model_supports_vision method to support image lists
    def check_if_model_supports_vision(
        self, img: str = None, image_list: List[str] = None
    ):
        """
        Check if the model supports vision.

        Args:
            img (str, optional): Single image path (for backward compatibility). Defaults to None.
            image_list (List[str], optional): List of image paths. Defaults to None.

        Raises:
            ValueError: If the model does not support vision.
        """
        # If there are any images (single or multiple), check if the model supports vision
        if img is not None or (image_list and len(image_list) > 0):
            out = supports_vision(model=self.model_name)

            if out is False:
                raise ValueError(
                    f"Model {self.model_name} does not support vision"
                )
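
    # A sketch of the guard, assuming hypothetical model names; supports_vision
    # consults litellm's model capability data:
    #
    #   LiteLLM(model_name="gpt-4o").check_if_model_supports_vision(img="cat.png")
    #   # passes silently for vision-capable models, raises ValueError otherwise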

    # Modification: Update the run method so that the img parameter can accept a string or a list of strings
    def run(
        self,
        task: str,
        audio: Optional[str] = None,
        img: Union[str, List[str]] = None,
        *args,
        **kwargs,
    ):
        """
        Run the LLM model for the given task.

        Args:
            task (str): The task to run the model for.
            audio (str, optional): Audio input if any. Defaults to None.
            img (Union[str, List[str]], optional): Single image input or list of image inputs. Defaults to None.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            str: The content of the response from the model.

        Raises:
            Exception: If there is an error in processing the request.
        """
        try:
            messages = self._prepare_messages(task=task, img=img)

            # Base completion parameters
            completion_params = {
                "model": self.model_name,
                "messages": messages,
                "stream": self.stream,
                "max_tokens": self.max_tokens,
                "caching": self.caching,
                "top_p": self.top_p,
                **kwargs,
            }

            # Add temperature only for models that accept it (o4/o3 do not)
            if self.model_name not in [
                "openai/o4-mini",
                "openai/o3-2025-04-16",
            ]:
                completion_params["temperature"] = self.temperature

            # Add tools if specified
            if self.tools_list_dictionary is not None:
                completion_params.update(
                    {
                        "tools": self.tools_list_dictionary,
                        "tool_choice": self.tool_choice,
                        "parallel_tool_calls": self.parallel_tool_calls,
                    }
                )

            if self.functions is not None:
                completion_params.update(
                    {"functions": self.functions}
                )

            if self.base_url is not None:
                completion_params["base_url"] = self.base_url

            # Add modalities if needed
            if self.modalities and len(self.modalities) >= 2:
                completion_params["modalities"] = self.modalities

            # Make the completion call
            response = completion(**completion_params)

            # Handle tool-based response
            if self.tools_list_dictionary is not None:
                return self.output_for_tools(response)
            elif self.return_all is True:
                return response.model_dump()
            else:
                # Return standard response content
                return response.choices[0].message.content

        except LiteLLMException as error:
            logger.error(
                f"Error in LiteLLM run: {str(error)} Traceback: {traceback.format_exc()}"
            )
            if "rate_limit" in str(error).lower():
                logger.warning(
                    "Rate limit hit, retrying after a short delay..."
                )
                time.sleep(2)
                return self.run(task, audio, img, *args, **kwargs)
            raise error
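
    # A minimal usage sketch for run; the model name and image paths are
    # hypothetical, and an API key is assumed to be configured in the environment:
    #
    #   llm = LiteLLM(model_name="gpt-4o", system_prompt="You are helpful.")
    #   text = llm.run("Describe this image.", img="photos/cat.jpg")
    #   text = llm.run("Compare these two.", img=["a.png", "b.png"])  # list form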

    def __call__(self, task: str, *args, **kwargs):
        """
        Call the LLM model for the given task.

        Args:
            task (str): The task to run the model for.
            *args: Additional positional arguments to pass to the model.
            **kwargs: Additional keyword arguments to pass to the model.

        Returns:
            str: The content of the response from the model.
        """
        return self.run(task, *args, **kwargs)

    # Modification: Updated arun method to accept img parameter as string or list of strings
    async def arun(
        self, task: str, img: Union[str, List[str]] = None, *args, **kwargs
    ):
        """
        Run the LLM model asynchronously for the given task.

        Args:
            task (str): The task to run the model for.
            img (Union[str, List[str]], optional): Single image input or list of image inputs. Defaults to None.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            str: The content of the response from the model.
        """
        try:
            messages = self._prepare_messages(task=task, img=img)

            # Prepare common completion parameters
            completion_params = {
                "model": self.model_name,
                "messages": messages,
                "stream": self.stream,
                "temperature": self.temperature,
                "max_tokens": self.max_tokens,
                **kwargs,
            }

            # Handle tool-based completion
            if self.tools_list_dictionary is not None:
                completion_params.update(
                    {
                        "tools": self.tools_list_dictionary,
                        "tool_choice": self.tool_choice,
                        "parallel_tool_calls": self.parallel_tool_calls,
                    }
                )
                response = await acompletion(**completion_params)
                return (
                    response.choices[0]
                    .message.tool_calls[0]
                    .function.arguments
                )

            # Standard completion
            response = await acompletion(**completion_params)
            return response.choices[0].message.content

        except Exception as error:
            logger.error(f"Error in LiteLLM arun: {str(error)}")
            # if "rate_limit" in str(error).lower():
            #     logger.warning(
            #         "Rate limit hit, retrying with exponential backoff..."
            #     )
            #     await asyncio.sleep(2)  # Use async sleep
            #     return await self.arun(task, *args, **kwargs)
            raise error
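
    # A minimal async usage sketch for arun (hypothetical task string; assumes
    # a configured API key):
    #
    #   async def main():
    #       llm = LiteLLM()
    #       return await llm.arun("Write a haiku about the sea.")
    #
    #   asyncio.run(main())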

    async def _process_batch(
        self, tasks: List[str], batch_size: int = 10
    ):
        """
        Process a batch of tasks asynchronously.

        Args:
            tasks (List[str]): List of tasks to process.
            batch_size (int): Size of each batch.

        Returns:
            List[str]: List of responses.
        """
        results = []
        for i in range(0, len(tasks), batch_size):
            batch = tasks[i : i + batch_size]
            batch_results = await asyncio.gather(
                *[self.arun(task) for task in batch],
                return_exceptions=True,
            )

            # Handle any exceptions in the batch
            for result in batch_results:
                if isinstance(result, Exception):
                    logger.error(
                        f"Error in batch processing: {str(result)}"
                    )
                    results.append(str(result))
                else:
                    results.append(result)

            # Add a small delay between batches to avoid rate limits
            if i + batch_size < len(tasks):
                await asyncio.sleep(0.5)

        return results

    def batched_run(self, tasks: List[str], batch_size: int = 10):
        """
        Run multiple tasks in batches synchronously.

        Args:
            tasks (List[str]): List of tasks to process.
            batch_size (int): Size of each batch.

        Returns:
            List[str]: List of responses.
        """
        logger.info(
            f"Running {len(tasks)} tasks in batches of {batch_size}"
        )
        return asyncio.run(self._process_batch(tasks, batch_size))
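
    # A minimal usage sketch for batched_run, with hypothetical tasks; each
    # batch of batch_size tasks runs concurrently via asyncio.gather, with a
    # 0.5 s pause between batches to ease rate limits:
    #
    #   llm = LiteLLM()
    #   answers = llm.batched_run(["task one", "task two", "task three"], batch_size=2)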

    async def batched_arun(
        self, tasks: List[str], batch_size: int = 10
    ):
        """
        Run multiple tasks in batches asynchronously.

        Args:
            tasks (List[str]): List of tasks to process.
            batch_size (int): Size of each batch.

        Returns:
            List[str]: List of responses.
        """
        logger.info(
            f"Running {len(tasks)} tasks asynchronously in batches of {batch_size}"
        )
        return await self._process_batch(tasks, batch_size)