From b4f59aad401ecbdd8c8f13dee8264310b327e652 Mon Sep 17 00:00:00 2001 From: Kye Gomez Date: Sun, 6 Apr 2025 08:56:07 +0800 Subject: [PATCH] llama 4 docs --- docs/mkdocs.yml | 36 +++-- docs/swarms/examples/llama4.md | 161 +++++++++++++++++++ example.py | 2 +- examples/swarms_of_vllm.py | 6 +- litellm_example.py | 5 + llama_4.py | 55 +++++++ mcp_example/agent_mcp_test.py | 1 - pyproject.toml | 2 +- swarms/structs/agent.py | 142 +++++++++++------ swarms/utils/litellm_wrapper.py | 264 +++++++++++++++++++++----------- 10 files changed, 515 insertions(+), 159 deletions(-) create mode 100644 docs/swarms/examples/llama4.md create mode 100644 litellm_example.py create mode 100644 llama_4.py diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 52c47599..807792c3 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -267,6 +267,24 @@ nav: - ChromaDB: "swarms_memory/chromadb.md" - Pinecone: "swarms_memory/pinecone.md" - Faiss: "swarms_memory/faiss.md" + + - About Us: + - Swarms Vision: "swarms/concept/vision.md" + - Swarm Ecosystem: "swarms/concept/swarm_ecosystem.md" + - Swarms Products: "swarms/products.md" + - Swarms Framework Architecture: "swarms/concept/framework_architecture.md" + + - Developers and Contributors: + - Bounty Program: "corporate/bounty_program.md" + - Contributing: + - Contributing: "swarms/contributing.md" + - Tests: "swarms/framework/test.md" + - Code Cleanliness: "swarms/framework/code_cleanliness.md" + - Philosophy: "swarms/concept/philosophy.md" + - Changelog: + - Swarms 5.6.8: "swarms/changelog/5_6_8.md" + - Swarms 5.8.1: "swarms/changelog/5_8_1.md" + - Swarms 5.9.2: "swarms/changelog/changelog_new.md" - Examples: - Overview: "swarms/examples/unique_swarms.md" @@ -284,6 +302,7 @@ nav: - OpenRouter: "swarms/examples/openrouter.md" - XAI: "swarms/examples/xai.md" - VLLM: "swarms/examples/vllm_integration.md" + - Llama4: "swarms/examples/llama4.md" - Swarms Tools: - Agent with Yahoo Finance: "swarms/examples/yahoo_finance.md" - Twitter Agents: "swarms_tools/twitter.md" @@ -342,23 +361,6 @@ nav: - Swarm Platform API Keys: "swarms_platform/apikeys.md" - Account Management: "swarms_platform/account_management.md" - - About Us: - - Swarms Vision: "swarms/concept/vision.md" - - Swarm Ecosystem: "swarms/concept/swarm_ecosystem.md" - - Swarms Products: "swarms/products.md" - - Swarms Framework Architecture: "swarms/concept/framework_architecture.md" - - - Contributors: - - Bounty Program: "corporate/bounty_program.md" - - Contributing: - - Contributing: "swarms/contributing.md" - - Tests: "swarms/framework/test.md" - - Code Cleanliness: "swarms/framework/code_cleanliness.md" - - Philosophy: "swarms/concept/philosophy.md" - - Changelog: - - Swarms 5.6.8: "swarms/changelog/5_6_8.md" - - Swarms 5.8.1: "swarms/changelog/5_8_1.md" - - Swarms 5.9.2: "swarms/changelog/changelog_new.md" # - Prompts API: # - Add Prompts: "swarms_platform/prompts/add_prompt.md" diff --git a/docs/swarms/examples/llama4.md b/docs/swarms/examples/llama4.md new file mode 100644 index 00000000..4367bc1c --- /dev/null +++ b/docs/swarms/examples/llama4.md @@ -0,0 +1,161 @@ +# Llama4 Model Integration + +!!! info "Prerequisites" + - Python 3.8 or higher + - `swarms` library installed + - Access to Llama4 model + - Valid environment variables configured + +## Quick Start + +Here's a simple example of integrating Llama4 model for crypto risk analysis: + +```python +from dotenv import load_dotenv +from swarms import Agent +from swarms.utils.vllm_wrapper import VLLM + +load_dotenv() +model = VLLM(model_name="meta-llama/Llama-4-Maverick-17B-128E") +``` + +!!! tip "Environment Setup" + Make sure to set up your environment variables properly before running the code. + Create a `.env` file in your project root if needed. + +## Detailed Implementation + +### 1. Define Custom System Prompt + +```python +CRYPTO_RISK_ANALYSIS_PROMPT = """ +You are a cryptocurrency risk analysis expert. Your role is to: + +1. Analyze market risks: + - Volatility assessment + - Market sentiment analysis + - Trading volume patterns + - Price trend evaluation + +2. Evaluate technical risks: + - Network security + - Protocol vulnerabilities + - Smart contract risks + - Technical scalability + +3. Consider regulatory risks: + - Current regulations + - Potential regulatory changes + - Compliance requirements + - Geographic restrictions + +4. Assess fundamental risks: + - Team background + - Project development status + - Competition analysis + - Use case viability + +Provide detailed, balanced analysis with both risks and potential mitigations. +Base your analysis on established crypto market principles and current market conditions. +""" +``` + +### 2. Initialize Agent + +```python +agent = Agent( + agent_name="Crypto-Risk-Analysis-Agent", + agent_description="Agent for analyzing risks in cryptocurrency investments", + system_prompt=CRYPTO_RISK_ANALYSIS_PROMPT, + max_loops=1, + llm=model, +) +``` + +## Full Code + +```python +from dotenv import load_dotenv + +from swarms import Agent +from swarms.utils.vllm_wrapper import VLLM + +load_dotenv() + +# Define custom system prompt for crypto risk analysis +CRYPTO_RISK_ANALYSIS_PROMPT = """ +You are a cryptocurrency risk analysis expert. Your role is to: + +1. Analyze market risks: + - Volatility assessment + - Market sentiment analysis + - Trading volume patterns + - Price trend evaluation + +2. Evaluate technical risks: + - Network security + - Protocol vulnerabilities + - Smart contract risks + - Technical scalability + +3. Consider regulatory risks: + - Current regulations + - Potential regulatory changes + - Compliance requirements + - Geographic restrictions + +4. Assess fundamental risks: + - Team background + - Project development status + - Competition analysis + - Use case viability + +Provide detailed, balanced analysis with both risks and potential mitigations. +Base your analysis on established crypto market principles and current market conditions. +""" + +model = VLLM(model_name="meta-llama/Llama-4-Maverick-17B-128E") + +# Initialize the agent with custom prompt +agent = Agent( + agent_name="Crypto-Risk-Analysis-Agent", + agent_description="Agent for analyzing risks in cryptocurrency investments", + system_prompt=CRYPTO_RISK_ANALYSIS_PROMPT, + max_loops=1, + llm=model, +) + +print( + agent.run( + "Conduct a risk analysis of the top cryptocurrencies. Think for 2 loops internally" + ) +) +``` + +!!! warning "Resource Usage" + The Llama4 model requires significant computational resources. Ensure your system meets the minimum requirements. + +## FAQ + +??? question "What is the purpose of max_loops parameter?" + The `max_loops` parameter determines how many times the agent will iterate through its thinking process. In this example, it's set to 1 for a single pass analysis. + +??? question "Can I use a different model?" + Yes, you can replace the VLLM wrapper with other compatible models. Just ensure you update the model initialization accordingly. + +??? question "How do I customize the system prompt?" + You can modify the `CRYPTO_RISK_ANALYSIS_PROMPT` string to match your specific use case while maintaining the structured format. + +!!! note "Best Practices" + - Always handle API errors gracefully + - Monitor model performance and resource usage + - Keep your prompts clear and specific + - Test thoroughly before production deployment + +!!! example "Sample Usage" + ```python + response = agent.run( + "Conduct a risk analysis of the top cryptocurrencies. Think for 2 loops internally" + ) + print(response) + ``` \ No newline at end of file diff --git a/example.py b/example.py index a6149463..bb40e770 100644 --- a/example.py +++ b/example.py @@ -13,7 +13,7 @@ agent = Agent( agent_description="Personal finance advisor agent", system_prompt=FINANCIAL_AGENT_SYS_PROMPT, max_loops=2, - model_name="gpt-4o", + model_name="groq/llama-3.3-70b-versatile", dynamic_temperature_enabled=True, user_name="swarms_corp", retry_attempts=3, diff --git a/examples/swarms_of_vllm.py b/examples/swarms_of_vllm.py index a463d443..89191ab0 100644 --- a/examples/swarms_of_vllm.py +++ b/examples/swarms_of_vllm.py @@ -200,7 +200,7 @@ stock_analysis_agents = [ fundamental_analyst, sentiment_analyst, quant_analyst, - portfolio_strategist + portfolio_strategist, ] swarm = ConcurrentWorkflow( @@ -209,4 +209,6 @@ swarm = ConcurrentWorkflow( agents=stock_analysis_agents, ) -swarm.run("Analyze the best etfs for gold and other similiar commodities in volatile markets") \ No newline at end of file +swarm.run( + "Analyze the best etfs for gold and other similiar commodities in volatile markets" +) diff --git a/litellm_example.py b/litellm_example.py new file mode 100644 index 00000000..63b297ef --- /dev/null +++ b/litellm_example.py @@ -0,0 +1,5 @@ +from swarms.utils.litellm_wrapper import LiteLLM + +model = LiteLLM(model_name="gpt-4o-mini", verbose=True) + +print(model.run("What is your purpose in life?")) diff --git a/llama_4.py b/llama_4.py new file mode 100644 index 00000000..df4a08b7 --- /dev/null +++ b/llama_4.py @@ -0,0 +1,55 @@ +from dotenv import load_dotenv + +from swarms import Agent +from swarms.utils.vllm_wrapper import VLLM + +load_dotenv() + +# Define custom system prompt for crypto risk analysis +CRYPTO_RISK_ANALYSIS_PROMPT = """ +You are a cryptocurrency risk analysis expert. Your role is to: + +1. Analyze market risks: + - Volatility assessment + - Market sentiment analysis + - Trading volume patterns + - Price trend evaluation + +2. Evaluate technical risks: + - Network security + - Protocol vulnerabilities + - Smart contract risks + - Technical scalability + +3. Consider regulatory risks: + - Current regulations + - Potential regulatory changes + - Compliance requirements + - Geographic restrictions + +4. Assess fundamental risks: + - Team background + - Project development status + - Competition analysis + - Use case viability + +Provide detailed, balanced analysis with both risks and potential mitigations. +Base your analysis on established crypto market principles and current market conditions. +""" + +model = VLLM(model_name="meta-llama/Llama-4-Maverick-17B-128E") + +# Initialize the agent with custom prompt +agent = Agent( + agent_name="Crypto-Risk-Analysis-Agent", + agent_description="Agent for analyzing risks in cryptocurrency investments", + system_prompt=CRYPTO_RISK_ANALYSIS_PROMPT, + max_loops=1, + llm=model, +) + +print( + agent.run( + "Conduct a risk analysis of the top cryptocurrencies. Think for 2 loops internally" + ) +) \ No newline at end of file diff --git a/mcp_example/agent_mcp_test.py b/mcp_example/agent_mcp_test.py index 0e3b6a8e..86c19a25 100644 --- a/mcp_example/agent_mcp_test.py +++ b/mcp_example/agent_mcp_test.py @@ -1,4 +1,3 @@ - from swarms import Agent from swarms.prompts.finance_agent_sys_prompt import ( FINANCIAL_AGENT_SYS_PROMPT, diff --git a/pyproject.toml b/pyproject.toml index 08f59ac3..be5f833d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "swarms" -version = "7.6.4" +version = "7.6.5" description = "Swarms - TGSC" license = "MIT" authors = ["Kye Gomez "] diff --git a/swarms/structs/agent.py b/swarms/structs/agent.py index 7fd8099c..187860ed 100644 --- a/swarms/structs/agent.py +++ b/swarms/structs/agent.py @@ -46,6 +46,11 @@ from swarms.structs.safe_loading import ( ) from swarms.telemetry.main import log_agent_data from swarms.tools.base_tool import BaseTool +from swarms.tools.mcp_integration import ( + MCPServerSseParams, + batch_mcp_flow, + mcp_flow_get_tool_schema, +) from swarms.tools.tool_parse_exec import parse_and_execute_json from swarms.utils.any_to_str import any_to_str from swarms.utils.data_to_text import data_to_text @@ -55,15 +60,10 @@ from swarms.utils.history_output_formatter import ( history_output_formatter, ) from swarms.utils.litellm_tokenizer import count_tokens +from swarms.utils.litellm_wrapper import LiteLLM from swarms.utils.pdf_to_text import pdf_to_text from swarms.utils.str_to_dict import str_to_dict -from swarms.tools.mcp_integration import ( - batch_mcp_flow, - mcp_flow_get_tool_schema, - MCPServerSseParams, -) - # Utils # Custom stopping condition @@ -104,6 +104,51 @@ agent_output_type = Literal[ ToolUsageType = Union[BaseModel, Dict[str, Any]] +# Agent Exceptions + + +class AgentError(Exception): + """Base class for all agent-related exceptions.""" + + pass + + +class AgentInitializationError(AgentError): + """Exception raised when the agent fails to initialize properly. Please check the configuration and parameters.""" + + pass + + +class AgentRunError(AgentError): + """Exception raised when the agent encounters an error during execution. Ensure that the task and environment are set up correctly.""" + + pass + + +class AgentLLMError(AgentError): + """Exception raised when there is an issue with the language model (LLM). Verify the model's availability and compatibility.""" + + pass + + +class AgentToolError(AgentError): + """Exception raised when the agent fails to utilize a tool. Check the tool's configuration and availability.""" + + pass + + +class AgentMemoryError(AgentError): + """Exception raised when the agent encounters a memory-related issue. Ensure that memory resources are properly allocated and accessible.""" + + pass + + +class AgentLLMInitializationError(AgentError): + """Exception raised when the LLM fails to initialize properly. Please check the configuration and parameters.""" + + pass + + # [FEAT][AGENT] class Agent: """ @@ -479,6 +524,12 @@ class Agent: self.no_print = no_print self.tools_list_dictionary = tools_list_dictionary self.mcp_servers = mcp_servers + self._cached_llm = ( + None # Add this line to cache the LLM instance + ) + self._default_model = ( + "gpt-4o-mini" # Move default model name here + ) if ( self.agent_name is not None @@ -599,50 +650,49 @@ class Agent: self.tools_list_dictionary = self.mcp_tool_handling() def llm_handling(self): - from swarms.utils.litellm_wrapper import LiteLLM + # Use cached instance if available + if self._cached_llm is not None: + return self._cached_llm if self.model_name is None: - # raise ValueError("Model name cannot be None") logger.warning( - "Model name is not provided, using gpt-4o-mini. You can configure any model from litellm if desired." + f"Model name is not provided, using {self._default_model}. You can configure any model from litellm if desired." ) - self.model_name = "gpt-4o-mini" + self.model_name = self._default_model try: + # Simplify initialization logic + common_args = { + "model_name": self.model_name, + "temperature": self.temperature, + "max_tokens": self.max_tokens, + "system_prompt": self.system_prompt, + } + if self.llm_args is not None: - llm = LiteLLM( - model_name=self.model_name, **self.llm_args + self._cached_llm = LiteLLM( + **{**common_args, **self.llm_args} ) elif self.tools_list_dictionary is not None: - - length_of_tools_list_dictionary = len( - self.tools_list_dictionary - ) - - if length_of_tools_list_dictionary > 0: - - parallel_tool_calls = True - - llm = LiteLLM( - model_name=self.model_name, - temperature=self.temperature, - max_tokens=self.max_tokens, - system_prompt=self.system_prompt, + self._cached_llm = LiteLLM( + **common_args, tools_list_dictionary=self.tools_list_dictionary, tool_choice="auto", - parallel_tool_calls=parallel_tool_calls, + parallel_tool_calls=len( + self.tools_list_dictionary + ) + > 1, ) else: - llm = LiteLLM( - model_name=self.model_name, - temperature=self.temperature, - max_tokens=self.max_tokens, - system_prompt=self.system_prompt, - stream=self.streaming_on, + self._cached_llm = LiteLLM( + **common_args, stream=self.streaming_on ) - return llm - except Exception as e: - logger.error(f"Error in llm_handling: {e}") + + return self._cached_llm + except AgentLLMInitializationError as e: + logger.error( + f"Error in llm_handling: {e} Your current configuration is not supported. Please check the configuration and parameters." + ) return None def mcp_execution_flow(self, response: any): @@ -2336,6 +2386,8 @@ class Agent: Args: task (str): The task to be performed by the `llm` object. + img (str, optional): Path or URL to an image file. + audio (str, optional): Path or URL to an audio file. *args: Variable length argument list. **kwargs: Arbitrary keyword arguments. @@ -2347,22 +2399,22 @@ class Agent: TypeError: If task is not a string or llm object is None. ValueError: If task is empty. """ - if not isinstance(task, str): - raise TypeError("Task must be a string") + # if not isinstance(task, str): + # task = any_to_str(task) - if task is None: - raise ValueError("Task cannot be None") + # if img is not None: + # kwargs['img'] = img - # if self.llm is None: - # raise TypeError("LLM object cannot be None") + # if audio is not None: + # kwargs['audio'] = audio try: - out = self.llm.run(task, *args, **kwargs) + out = self.llm.run(task=task, *args, **kwargs) return out - except AttributeError as e: + except AgentLLMError as e: logger.error( - f"Error calling LLM: {e} You need a class with a run(task: str) method" + f"Error calling LLM: {e}. Task: {task}, Args: {args}, Kwargs: {kwargs}" ) raise e diff --git a/swarms/utils/litellm_wrapper.py b/swarms/utils/litellm_wrapper.py index bd934af6..2e899b80 100644 --- a/swarms/utils/litellm_wrapper.py +++ b/swarms/utils/litellm_wrapper.py @@ -5,7 +5,7 @@ import asyncio from typing import List from loguru import logger - +import litellm try: from litellm import completion, acompletion @@ -77,6 +77,8 @@ class LiteLLM: tool_choice: str = "auto", parallel_tool_calls: bool = False, audio: str = None, + retries: int = 3, + verbose: bool = False, *args, **kwargs, ): @@ -100,7 +102,18 @@ class LiteLLM: self.tools_list_dictionary = tools_list_dictionary self.tool_choice = tool_choice self.parallel_tool_calls = parallel_tool_calls - self.modalities = ["text"] + self.modalities = [] + self._cached_messages = {} # Cache for prepared messages + self.messages = [] # Initialize messages list + + # Configure litellm settings + litellm.set_verbose = ( + verbose # Disable verbose mode for better performance + ) + litellm.ssl_verify = ssl_verify + litellm.num_retries = ( + retries # Add retries for better reliability + ) def _prepare_messages(self, task: str) -> list: """ @@ -112,15 +125,20 @@ class LiteLLM: Returns: list: A list of messages prepared for the task. """ - messages = [] + # Check cache first + cache_key = f"{self.system_prompt}:{task}" + if cache_key in self._cached_messages: + return self._cached_messages[cache_key].copy() - if self.system_prompt: # Check if system_prompt is not None + messages = [] + if self.system_prompt: messages.append( {"role": "system", "content": self.system_prompt} ) - messages.append({"role": "user", "content": task}) + # Cache the prepared messages + self._cached_messages[cache_key] = messages.copy() return messages def audio_processing(self, task: str, audio: str): @@ -182,15 +200,16 @@ class LiteLLM: """ Handle the modalities for the given task. """ + self.messages = [] # Reset messages + self.modalities.append("text") + if audio is not None: self.audio_processing(task=task, audio=audio) + self.modalities.append("audio") if img is not None: self.vision_processing(task=task, image=img) - - if audio is not None and img is not None: - self.audio_processing(task=task, audio=audio) - self.vision_processing(task=task, image=img) + self.modalities.append("vision") def run( self, @@ -205,58 +224,78 @@ class LiteLLM: Args: task (str): The task to run the model for. - *args: Additional positional arguments to pass to the model. - **kwargs: Additional keyword arguments to pass to the model. + audio (str, optional): Audio input if any. Defaults to None. + img (str, optional): Image input if any. Defaults to None. + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. Returns: str: The content of the response from the model. + + Raises: + Exception: If there is an error in processing the request. """ try: - messages = self._prepare_messages(task) - self.handle_modalities(task=task, audio=audio, img=img) + if audio is not None or img is not None: + self.handle_modalities( + task=task, audio=audio, img=img + ) + messages = ( + self.messages + ) # Use modality-processed messages + + # Prepare common completion parameters + completion_params = { + "model": self.model_name, + "messages": messages, + "stream": self.stream, + "temperature": self.temperature, + "max_tokens": self.max_tokens, + **kwargs, + } + # Handle tool-based completion if self.tools_list_dictionary is not None: - response = completion( - model=self.model_name, - messages=messages, - stream=self.stream, - temperature=self.temperature, - max_tokens=self.max_tokens, - tools=self.tools_list_dictionary, - modalities=self.modalities, - tool_choice=self.tool_choice, - parallel_tool_calls=self.parallel_tool_calls, - *args, - **kwargs, + completion_params.update( + { + "tools": self.tools_list_dictionary, + "tool_choice": self.tool_choice, + "parallel_tool_calls": self.parallel_tool_calls, + } ) - + response = completion(**completion_params) return ( response.choices[0] .message.tool_calls[0] .function.arguments ) - else: - response = completion( - model=self.model_name, - messages=messages, - stream=self.stream, - temperature=self.temperature, - max_tokens=self.max_tokens, - modalities=self.modalities, - *args, - **kwargs, + # Handle modality-based completion + if ( + self.modalities and len(self.modalities) > 1 + ): # More than just text + completion_params.update( + {"modalities": self.modalities} ) + response = completion(**completion_params) + return response.choices[0].message.content - content = response.choices[ - 0 - ].message.content # Accessing the content + # Standard completion + response = completion(**completion_params) + return response.choices[0].message.content - return content except Exception as error: - logger.error(f"Error in LiteLLM: {error}") + logger.error(f"Error in LiteLLM run: {str(error)}") + if "rate_limit" in str(error).lower(): + logger.warning( + "Rate limit hit, retrying with exponential backoff..." + ) + import time + + time.sleep(2) # Add a small delay before retry + return self.run(task, audio, img, *args, **kwargs) raise error def __call__(self, task: str, *args, **kwargs): @@ -275,12 +314,12 @@ class LiteLLM: async def arun(self, task: str, *args, **kwargs): """ - Run the LLM model for the given task. + Run the LLM model asynchronously for the given task. Args: task (str): The task to run the model for. - *args: Additional positional arguments to pass to the model. - **kwargs: Additional keyword arguments to pass to the model. + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. Returns: str: The content of the response from the model. @@ -288,72 +327,113 @@ class LiteLLM: try: messages = self._prepare_messages(task) + # Prepare common completion parameters + completion_params = { + "model": self.model_name, + "messages": messages, + "stream": self.stream, + "temperature": self.temperature, + "max_tokens": self.max_tokens, + **kwargs, + } + + # Handle tool-based completion if self.tools_list_dictionary is not None: - response = await acompletion( - model=self.model_name, - messages=messages, - stream=self.stream, - temperature=self.temperature, - max_tokens=self.max_tokens, - tools=self.tools_list_dictionary, - tool_choice=self.tool_choice, - parallel_tool_calls=self.parallel_tool_calls, - *args, - **kwargs, + completion_params.update( + { + "tools": self.tools_list_dictionary, + "tool_choice": self.tool_choice, + "parallel_tool_calls": self.parallel_tool_calls, + } ) - - content = ( + response = await acompletion(**completion_params) + return ( response.choices[0] .message.tool_calls[0] .function.arguments ) - # return response - - else: - response = await acompletion( - model=self.model_name, - messages=messages, - stream=self.stream, - temperature=self.temperature, - max_tokens=self.max_tokens, - *args, - **kwargs, - ) + # Standard completion + response = await acompletion(**completion_params) + return response.choices[0].message.content - content = response.choices[ - 0 - ].message.content # Accessing the content - - return content except Exception as error: - logger.error(f"Error in LiteLLM: {error}") + logger.error(f"Error in LiteLLM arun: {str(error)}") + if "rate_limit" in str(error).lower(): + logger.warning( + "Rate limit hit, retrying with exponential backoff..." + ) + await asyncio.sleep(2) # Use async sleep + return await self.arun(task, *args, **kwargs) raise error + async def _process_batch( + self, tasks: List[str], batch_size: int = 10 + ): + """ + Process a batch of tasks asynchronously. + + Args: + tasks (List[str]): List of tasks to process. + batch_size (int): Size of each batch. + + Returns: + List[str]: List of responses. + """ + results = [] + for i in range(0, len(tasks), batch_size): + batch = tasks[i : i + batch_size] + batch_results = await asyncio.gather( + *[self.arun(task) for task in batch], + return_exceptions=True, + ) + + # Handle any exceptions in the batch + for result in batch_results: + if isinstance(result, Exception): + logger.error( + f"Error in batch processing: {str(result)}" + ) + results.append(str(result)) + else: + results.append(result) + + # Add a small delay between batches to avoid rate limits + if i + batch_size < len(tasks): + await asyncio.sleep(0.5) + + return results + def batched_run(self, tasks: List[str], batch_size: int = 10): """ - Run the LLM model for the given tasks in batches. + Run multiple tasks in batches synchronously. + + Args: + tasks (List[str]): List of tasks to process. + batch_size (int): Size of each batch. + + Returns: + List[str]: List of responses. """ logger.info( - f"Running tasks in batches of size {batch_size}. Total tasks: {len(tasks)}" + f"Running {len(tasks)} tasks in batches of {batch_size}" ) - results = [] - for task in tasks: - logger.info(f"Running task: {task}") - results.append(self.run(task)) - logger.info("Completed all tasks.") - return results + return asyncio.run(self._process_batch(tasks, batch_size)) - def batched_arun(self, tasks: List[str], batch_size: int = 10): + async def batched_arun( + self, tasks: List[str], batch_size: int = 10 + ): """ - Run the LLM model for the given tasks in batches. + Run multiple tasks in batches asynchronously. + + Args: + tasks (List[str]): List of tasks to process. + batch_size (int): Size of each batch. + + Returns: + List[str]: List of responses. """ logger.info( - f"Running asynchronous tasks in batches of size {batch_size}. Total tasks: {len(tasks)}" + f"Running {len(tasks)} tasks asynchronously in batches of {batch_size}" ) - results = [] - for task in tasks: - logger.info(f"Running asynchronous task: {task}") - results.append(asyncio.run(self.arun(task))) - logger.info("Completed all asynchronous tasks.") - return results + return await self._process_batch(tasks, batch_size)