llama 4 docs

pull/811/head
Kye Gomez 2 weeks ago
parent 643bc3ccef
commit b4f59aad40

@@ -268,6 +268,24 @@ nav:
- Pinecone: "swarms_memory/pinecone.md"
- Faiss: "swarms_memory/faiss.md"
+ - About Us:
+ - Swarms Vision: "swarms/concept/vision.md"
+ - Swarm Ecosystem: "swarms/concept/swarm_ecosystem.md"
+ - Swarms Products: "swarms/products.md"
+ - Swarms Framework Architecture: "swarms/concept/framework_architecture.md"
+ - Developers and Contributors:
+ - Bounty Program: "corporate/bounty_program.md"
+ - Contributing:
+ - Contributing: "swarms/contributing.md"
+ - Tests: "swarms/framework/test.md"
+ - Code Cleanliness: "swarms/framework/code_cleanliness.md"
+ - Philosophy: "swarms/concept/philosophy.md"
+ - Changelog:
+ - Swarms 5.6.8: "swarms/changelog/5_6_8.md"
+ - Swarms 5.8.1: "swarms/changelog/5_8_1.md"
+ - Swarms 5.9.2: "swarms/changelog/changelog_new.md"
- Examples:
- Overview: "swarms/examples/unique_swarms.md"
- Swarms API Examples:
@@ -284,6 +302,7 @@ nav:
- OpenRouter: "swarms/examples/openrouter.md"
- XAI: "swarms/examples/xai.md"
- VLLM: "swarms/examples/vllm_integration.md"
+ - Llama4: "swarms/examples/llama4.md"
- Swarms Tools:
- Agent with Yahoo Finance: "swarms/examples/yahoo_finance.md"
- Twitter Agents: "swarms_tools/twitter.md"
@@ -342,23 +361,6 @@ nav:
- Swarm Platform API Keys: "swarms_platform/apikeys.md"
- Account Management: "swarms_platform/account_management.md"
- - About Us:
- - Swarms Vision: "swarms/concept/vision.md"
- - Swarm Ecosystem: "swarms/concept/swarm_ecosystem.md"
- - Swarms Products: "swarms/products.md"
- - Swarms Framework Architecture: "swarms/concept/framework_architecture.md"
- - Contributors:
- - Bounty Program: "corporate/bounty_program.md"
- - Contributing:
- - Contributing: "swarms/contributing.md"
- - Tests: "swarms/framework/test.md"
- - Code Cleanliness: "swarms/framework/code_cleanliness.md"
- - Philosophy: "swarms/concept/philosophy.md"
- - Changelog:
- - Swarms 5.6.8: "swarms/changelog/5_6_8.md"
- - Swarms 5.8.1: "swarms/changelog/5_8_1.md"
- - Swarms 5.9.2: "swarms/changelog/changelog_new.md"
# - Prompts API:
# - Add Prompts: "swarms_platform/prompts/add_prompt.md"

@@ -0,0 +1,161 @@
# Llama4 Model Integration
!!! info "Prerequisites"
- Python 3.8 or higher
- `swarms` library installed
- Access to Llama4 model
- Valid environment variables configured
## Quick Start
Here's a simple example of integrating the Llama4 model for crypto risk analysis:
```python
from dotenv import load_dotenv
from swarms import Agent
from swarms.utils.vllm_wrapper import VLLM
load_dotenv()
model = VLLM(model_name="meta-llama/Llama-4-Maverick-17B-128E")
```
!!! tip "Environment Setup"
Make sure to set up your environment variables properly before running the code.
Create a `.env` file in your project root if needed.
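As a rough sketch of that setup (the `HF_TOKEN` variable below is only an example; which variables you actually need depends on how you serve the model):

```python
import os

from dotenv import load_dotenv

# Load variables from a .env file in the project root, if one exists.
load_dotenv()

# Example check: a Hugging Face token is often required to download gated
# Llama weights. Replace with whatever variables your deployment needs.
if os.getenv("HF_TOKEN") is None:
    print("HF_TOKEN is not set; downloading gated model weights may fail.")
```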
## Detailed Implementation
### 1. Define Custom System Prompt
```python
CRYPTO_RISK_ANALYSIS_PROMPT = """
You are a cryptocurrency risk analysis expert. Your role is to:
1. Analyze market risks:
- Volatility assessment
- Market sentiment analysis
- Trading volume patterns
- Price trend evaluation
2. Evaluate technical risks:
- Network security
- Protocol vulnerabilities
- Smart contract risks
- Technical scalability
3. Consider regulatory risks:
- Current regulations
- Potential regulatory changes
- Compliance requirements
- Geographic restrictions
4. Assess fundamental risks:
- Team background
- Project development status
- Competition analysis
- Use case viability
Provide detailed, balanced analysis with both risks and potential mitigations.
Base your analysis on established crypto market principles and current market conditions.
"""
```
### 2. Initialize Agent
```python
agent = Agent(
agent_name="Crypto-Risk-Analysis-Agent",
agent_description="Agent for analyzing risks in cryptocurrency investments",
system_prompt=CRYPTO_RISK_ANALYSIS_PROMPT,
max_loops=1,
llm=model,
)
```
## Full Code
```python
from dotenv import load_dotenv
from swarms import Agent
from swarms.utils.vllm_wrapper import VLLM
load_dotenv()
# Define custom system prompt for crypto risk analysis
CRYPTO_RISK_ANALYSIS_PROMPT = """
You are a cryptocurrency risk analysis expert. Your role is to:
1. Analyze market risks:
- Volatility assessment
- Market sentiment analysis
- Trading volume patterns
- Price trend evaluation
2. Evaluate technical risks:
- Network security
- Protocol vulnerabilities
- Smart contract risks
- Technical scalability
3. Consider regulatory risks:
- Current regulations
- Potential regulatory changes
- Compliance requirements
- Geographic restrictions
4. Assess fundamental risks:
- Team background
- Project development status
- Competition analysis
- Use case viability
Provide detailed, balanced analysis with both risks and potential mitigations.
Base your analysis on established crypto market principles and current market conditions.
"""
model = VLLM(model_name="meta-llama/Llama-4-Maverick-17B-128E")
# Initialize the agent with custom prompt
agent = Agent(
agent_name="Crypto-Risk-Analysis-Agent",
agent_description="Agent for analyzing risks in cryptocurrency investments",
system_prompt=CRYPTO_RISK_ANALYSIS_PROMPT,
max_loops=1,
llm=model,
)
print(
agent.run(
"Conduct a risk analysis of the top cryptocurrencies. Think for 2 loops internally"
)
)
```
!!! warning "Resource Usage"
The Llama4 model requires significant computational resources. Ensure your system meets the minimum requirements.
## FAQ
??? question "What is the purpose of max_loops parameter?"
The `max_loops` parameter determines how many times the agent will iterate through its thinking process. In this example, it's set to 1 for a single pass analysis.
??? question "Can I use a different model?"
Yes, you can replace the VLLM wrapper with other compatible models. Just ensure you update the model initialization accordingly.
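For instance, here is a minimal sketch of swapping in a different backend (the `LiteLLM` wrapper and the `gpt-4o-mini` model name are illustrative; any model supported by your setup should work):

```python
from swarms import Agent
from swarms.utils.litellm_wrapper import LiteLLM

# Illustrative alternative backend instead of the VLLM wrapper above.
model = LiteLLM(model_name="gpt-4o-mini")

agent = Agent(
    agent_name="Crypto-Risk-Analysis-Agent",
    agent_description="Agent for analyzing risks in cryptocurrency investments",
    system_prompt="You are a cryptocurrency risk analysis expert.",  # or reuse CRYPTO_RISK_ANALYSIS_PROMPT from above
    max_loops=2,  # e.g., allow a second refinement pass
    llm=model,
)

print(agent.run("Summarize the main risk categories for Bitcoin."))
```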
??? question "How do I customize the system prompt?"
You can modify the `CRYPTO_RISK_ANALYSIS_PROMPT` string to match your specific use case while maintaining the structured format.
!!! note "Best Practices"
- Always handle API errors gracefully
- Monitor model performance and resource usage
- Keep your prompts clear and specific
- Test thoroughly before production deployment
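As a rough illustration of the first point above, here is a minimal sketch of wrapping `agent.run` with simple retry handling (the helper below is illustrative, not part of the swarms API):

```python
import time


def run_with_retries(agent, task: str, attempts: int = 3, delay: float = 2.0) -> str:
    """Call agent.run with basic retry logic; a sketch, not a library utility."""
    last_error = None
    for attempt in range(1, attempts + 1):
        try:
            return agent.run(task)
        except Exception as error:  # narrow to specific exceptions in real code
            last_error = error
            print(f"Attempt {attempt} failed: {error}")
            time.sleep(delay)
    raise RuntimeError(f"All {attempts} attempts failed") from last_error
```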
!!! example "Sample Usage"
```python
response = agent.run(
"Conduct a risk analysis of the top cryptocurrencies. Think for 2 loops internally"
)
print(response)
```

@@ -13,7 +13,7 @@ agent = Agent(
agent_description="Personal finance advisor agent",
system_prompt=FINANCIAL_AGENT_SYS_PROMPT,
max_loops=2,
- model_name="gpt-4o",
+ model_name="groq/llama-3.3-70b-versatile",
dynamic_temperature_enabled=True,
user_name="swarms_corp",
retry_attempts=3,

@@ -200,7 +200,7 @@ stock_analysis_agents = [
fundamental_analyst,
sentiment_analyst,
quant_analyst,
- portfolio_strategist
+ portfolio_strategist,
]
swarm = ConcurrentWorkflow(
@@ -209,4 +209,6 @@ swarm = ConcurrentWorkflow(
agents=stock_analysis_agents,
)
- swarm.run("Analyze the best etfs for gold and other similiar commodities in volatile markets")
+ swarm.run(
+ "Analyze the best etfs for gold and other similiar commodities in volatile markets"
+ )

@@ -0,0 +1,5 @@
from swarms.utils.litellm_wrapper import LiteLLM
model = LiteLLM(model_name="gpt-4o-mini", verbose=True)
print(model.run("What is your purpose in life?"))

@@ -0,0 +1,55 @@
from dotenv import load_dotenv
from swarms import Agent
from swarms.utils.vllm_wrapper import VLLM
load_dotenv()
# Define custom system prompt for crypto risk analysis
CRYPTO_RISK_ANALYSIS_PROMPT = """
You are a cryptocurrency risk analysis expert. Your role is to:
1. Analyze market risks:
- Volatility assessment
- Market sentiment analysis
- Trading volume patterns
- Price trend evaluation
2. Evaluate technical risks:
- Network security
- Protocol vulnerabilities
- Smart contract risks
- Technical scalability
3. Consider regulatory risks:
- Current regulations
- Potential regulatory changes
- Compliance requirements
- Geographic restrictions
4. Assess fundamental risks:
- Team background
- Project development status
- Competition analysis
- Use case viability
Provide detailed, balanced analysis with both risks and potential mitigations.
Base your analysis on established crypto market principles and current market conditions.
"""
model = VLLM(model_name="meta-llama/Llama-4-Maverick-17B-128E")
# Initialize the agent with custom prompt
agent = Agent(
agent_name="Crypto-Risk-Analysis-Agent",
agent_description="Agent for analyzing risks in cryptocurrency investments",
system_prompt=CRYPTO_RISK_ANALYSIS_PROMPT,
max_loops=1,
llm=model,
)
print(
agent.run(
"Conduct a risk analysis of the top cryptocurrencies. Think for 2 loops internally"
)
)

@@ -1,4 +1,3 @@
from swarms import Agent
from swarms.prompts.finance_agent_sys_prompt import (
FINANCIAL_AGENT_SYS_PROMPT,

@@ -5,7 +5,7 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry]
name = "swarms"
- version = "7.6.4"
+ version = "7.6.5"
description = "Swarms - TGSC"
license = "MIT"
authors = ["Kye Gomez <kye@apac.ai>"]

@@ -46,6 +46,11 @@ from swarms.structs.safe_loading import (
)
from swarms.telemetry.main import log_agent_data
from swarms.tools.base_tool import BaseTool
+ from swarms.tools.mcp_integration import (
+ MCPServerSseParams,
+ batch_mcp_flow,
+ mcp_flow_get_tool_schema,
+ )
from swarms.tools.tool_parse_exec import parse_and_execute_json
from swarms.utils.any_to_str import any_to_str
from swarms.utils.data_to_text import data_to_text
@@ -55,15 +60,10 @@ from swarms.utils.history_output_formatter import (
history_output_formatter,
)
from swarms.utils.litellm_tokenizer import count_tokens
+ from swarms.utils.litellm_wrapper import LiteLLM
from swarms.utils.pdf_to_text import pdf_to_text
from swarms.utils.str_to_dict import str_to_dict
- from swarms.tools.mcp_integration import (
- batch_mcp_flow,
- mcp_flow_get_tool_schema,
- MCPServerSseParams,
- )
# Utils
# Custom stopping condition
@@ -104,6 +104,51 @@ agent_output_type = Literal[
ToolUsageType = Union[BaseModel, Dict[str, Any]]
+ # Agent Exceptions
+ class AgentError(Exception):
+ """Base class for all agent-related exceptions."""
+ pass
+ class AgentInitializationError(AgentError):
+ """Exception raised when the agent fails to initialize properly. Please check the configuration and parameters."""
+ pass
+ class AgentRunError(AgentError):
+ """Exception raised when the agent encounters an error during execution. Ensure that the task and environment are set up correctly."""
+ pass
+ class AgentLLMError(AgentError):
+ """Exception raised when there is an issue with the language model (LLM). Verify the model's availability and compatibility."""
+ pass
+ class AgentToolError(AgentError):
+ """Exception raised when the agent fails to utilize a tool. Check the tool's configuration and availability."""
+ pass
+ class AgentMemoryError(AgentError):
+ """Exception raised when the agent encounters a memory-related issue. Ensure that memory resources are properly allocated and accessible."""
+ pass
+ class AgentLLMInitializationError(AgentError):
+ """Exception raised when the LLM fails to initialize properly. Please check the configuration and parameters."""
+ pass
# [FEAT][AGENT]
class Agent:
"""
@@ -479,6 +524,12 @@ class Agent:
self.no_print = no_print
self.tools_list_dictionary = tools_list_dictionary
self.mcp_servers = mcp_servers
+ self._cached_llm = (
+ None # Add this line to cache the LLM instance
+ )
+ self._default_model = (
+ "gpt-4o-mini" # Move default model name here
+ )
if (
self.agent_name is not None
@@ -599,50 +650,49 @@
self.tools_list_dictionary = self.mcp_tool_handling()
def llm_handling(self):
- from swarms.utils.litellm_wrapper import LiteLLM
+ # Use cached instance if available
+ if self._cached_llm is not None:
+ return self._cached_llm
if self.model_name is None:
- # raise ValueError("Model name cannot be None")
logger.warning(
- "Model name is not provided, using gpt-4o-mini. You can configure any model from litellm if desired."
+ f"Model name is not provided, using {self._default_model}. You can configure any model from litellm if desired."
)
- self.model_name = "gpt-4o-mini"
+ self.model_name = self._default_model
try:
+ # Simplify initialization logic
+ common_args = {
+ "model_name": self.model_name,
+ "temperature": self.temperature,
+ "max_tokens": self.max_tokens,
+ "system_prompt": self.system_prompt,
+ }
if self.llm_args is not None:
- llm = LiteLLM(
+ self._cached_llm = LiteLLM(
- model_name=self.model_name, **self.llm_args
+ **{**common_args, **self.llm_args}
)
elif self.tools_list_dictionary is not None:
+ self._cached_llm = LiteLLM(
- length_of_tools_list_dictionary = len(
+ **common_args,
- self.tools_list_dictionary
- )
- if length_of_tools_list_dictionary > 0:
- parallel_tool_calls = True
- llm = LiteLLM(
- model_name=self.model_name,
- temperature=self.temperature,
- max_tokens=self.max_tokens,
- system_prompt=self.system_prompt,
tools_list_dictionary=self.tools_list_dictionary,
tool_choice="auto",
- parallel_tool_calls=parallel_tool_calls,
+ parallel_tool_calls=len(
+ self.tools_list_dictionary
+ )
+ > 1,
)
else:
- llm = LiteLLM(
+ self._cached_llm = LiteLLM(
- model_name=self.model_name,
+ **common_args, stream=self.streaming_on
- temperature=self.temperature,
+ )
- max_tokens=self.max_tokens,
- system_prompt=self.system_prompt,
+ return self._cached_llm
- stream=self.streaming_on,
+ except AgentLLMInitializationError as e:
+ logger.error(
+ f"Error in llm_handling: {e} Your current configuration is not supported. Please check the configuration and parameters."
)
- return llm
- except Exception as e:
- logger.error(f"Error in llm_handling: {e}")
return None
def mcp_execution_flow(self, response: any):
@@ -2336,6 +2386,8 @@ class Agent:
Args:
task (str): The task to be performed by the `llm` object.
+ img (str, optional): Path or URL to an image file.
+ audio (str, optional): Path or URL to an audio file.
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
@@ -2347,22 +2399,22 @@
TypeError: If task is not a string or llm object is None.
ValueError: If task is empty.
"""
- if not isinstance(task, str):
+ # if not isinstance(task, str):
- raise TypeError("Task must be a string")
+ # task = any_to_str(task)
- if task is None:
+ # if img is not None:
- raise ValueError("Task cannot be None")
+ # kwargs['img'] = img
- # if self.llm is None:
+ # if audio is not None:
- # raise TypeError("LLM object cannot be None")
+ # kwargs['audio'] = audio
try:
- out = self.llm.run(task, *args, **kwargs)
+ out = self.llm.run(task=task, *args, **kwargs)
return out
- except AttributeError as e:
+ except AgentLLMError as e:
logger.error(
- f"Error calling LLM: {e} You need a class with a run(task: str) method"
+ f"Error calling LLM: {e}. Task: {task}, Args: {args}, Kwargs: {kwargs}"
)
raise e

@@ -5,7 +5,7 @@ import asyncio
from typing import List
from loguru import logger
+ import litellm
try:
from litellm import completion, acompletion
@@ -77,6 +77,8 @@ class LiteLLM:
tool_choice: str = "auto",
parallel_tool_calls: bool = False,
audio: str = None,
+ retries: int = 3,
+ verbose: bool = False,
*args,
**kwargs,
):
@@ -100,7 +102,18 @@ class LiteLLM:
self.tools_list_dictionary = tools_list_dictionary
self.tool_choice = tool_choice
self.parallel_tool_calls = parallel_tool_calls
- self.modalities = ["text"]
+ self.modalities = []
+ self._cached_messages = {} # Cache for prepared messages
+ self.messages = [] # Initialize messages list
+ # Configure litellm settings
+ litellm.set_verbose = (
+ verbose # Disable verbose mode for better performance
+ )
+ litellm.ssl_verify = ssl_verify
+ litellm.num_retries = (
+ retries # Add retries for better reliability
+ )
def _prepare_messages(self, task: str) -> list:
"""
@@ -112,15 +125,20 @@ class LiteLLM:
Returns:
list: A list of messages prepared for the task.
"""
- messages = []
+ # Check cache first
+ cache_key = f"{self.system_prompt}:{task}"
+ if cache_key in self._cached_messages:
+ return self._cached_messages[cache_key].copy()
- if self.system_prompt: # Check if system_prompt is not None
+ messages = []
+ if self.system_prompt:
messages.append(
{"role": "system", "content": self.system_prompt}
)
messages.append({"role": "user", "content": task})
+ # Cache the prepared messages
+ self._cached_messages[cache_key] = messages.copy()
return messages
def audio_processing(self, task: str, audio: str):
@@ -182,15 +200,16 @@ class LiteLLM:
"""
Handle the modalities for the given task.
"""
+ self.messages = [] # Reset messages
+ self.modalities.append("text")
if audio is not None:
self.audio_processing(task=task, audio=audio)
+ self.modalities.append("audio")
if img is not None:
self.vision_processing(task=task, image=img)
+ self.modalities.append("vision")
- if audio is not None and img is not None:
- self.audio_processing(task=task, audio=audio)
- self.vision_processing(task=task, image=img)
def run(
self,
@@ -205,58 +224,78 @@ class LiteLLM:
Args:
task (str): The task to run the model for.
- *args: Additional positional arguments to pass to the model.
+ audio (str, optional): Audio input if any. Defaults to None.
- **kwargs: Additional keyword arguments to pass to the model.
+ img (str, optional): Image input if any. Defaults to None.
+ *args: Additional positional arguments.
+ **kwargs: Additional keyword arguments.
Returns:
str: The content of the response from the model.
+ Raises:
+ Exception: If there is an error in processing the request.
"""
try:
messages = self._prepare_messages(task)
- self.handle_modalities(task=task, audio=audio, img=img)
+ if audio is not None or img is not None:
+ self.handle_modalities(
+ task=task, audio=audio, img=img
+ )
+ messages = (
+ self.messages
+ ) # Use modality-processed messages
+ # Prepare common completion parameters
+ completion_params = {
+ "model": self.model_name,
+ "messages": messages,
+ "stream": self.stream,
+ "temperature": self.temperature,
+ "max_tokens": self.max_tokens,
+ **kwargs,
+ }
+ # Handle tool-based completion
if self.tools_list_dictionary is not None:
- response = completion(
+ completion_params.update(
- model=self.model_name,
+ {
- messages=messages,
+ "tools": self.tools_list_dictionary,
- stream=self.stream,
+ "tool_choice": self.tool_choice,
- temperature=self.temperature,
+ "parallel_tool_calls": self.parallel_tool_calls,
- max_tokens=self.max_tokens,
+ }
- tools=self.tools_list_dictionary,
- modalities=self.modalities,
- tool_choice=self.tool_choice,
- parallel_tool_calls=self.parallel_tool_calls,
- *args,
- **kwargs,
)
+ response = completion(**completion_params)
return (
response.choices[0]
.message.tool_calls[0]
.function.arguments
)
- else:
+ # Handle modality-based completion
- response = completion(
+ if (
- model=self.model_name,
+ self.modalities and len(self.modalities) > 1
- messages=messages,
+ ): # More than just text
- stream=self.stream,
+ completion_params.update(
- temperature=self.temperature,
+ {"modalities": self.modalities}
- max_tokens=self.max_tokens,
- modalities=self.modalities,
- *args,
- **kwargs,
)
+ response = completion(**completion_params)
+ return response.choices[0].message.content
- content = response.choices[
+ # Standard completion
- 0
+ response = completion(**completion_params)
- ].message.content # Accessing the content
+ return response.choices[0].message.content
- return content
except Exception as error:
- logger.error(f"Error in LiteLLM: {error}")
+ logger.error(f"Error in LiteLLM run: {str(error)}")
+ if "rate_limit" in str(error).lower():
+ logger.warning(
+ "Rate limit hit, retrying with exponential backoff..."
+ )
+ import time
+ time.sleep(2) # Add a small delay before retry
+ return self.run(task, audio, img, *args, **kwargs)
raise error
def __call__(self, task: str, *args, **kwargs):
@@ -275,12 +314,12 @@ class LiteLLM:
async def arun(self, task: str, *args, **kwargs):
"""
- Run the LLM model for the given task.
+ Run the LLM model asynchronously for the given task.
Args:
task (str): The task to run the model for.
- *args: Additional positional arguments to pass to the model.
+ *args: Additional positional arguments.
- **kwargs: Additional keyword arguments to pass to the model.
+ **kwargs: Additional keyword arguments.
Returns:
str: The content of the response from the model.
@@ -288,72 +327,113 @@ class LiteLLM:
try:
messages = self._prepare_messages(task)
- if self.tools_list_dictionary is not None:
+ # Prepare common completion parameters
- response = await acompletion(
+ completion_params = {
- model=self.model_name,
+ "model": self.model_name,
- messages=messages,
+ "messages": messages,
- stream=self.stream,
+ "stream": self.stream,
- temperature=self.temperature,
+ "temperature": self.temperature,
- max_tokens=self.max_tokens,
+ "max_tokens": self.max_tokens,
- tools=self.tools_list_dictionary,
- tool_choice=self.tool_choice,
- parallel_tool_calls=self.parallel_tool_calls,
- *args,
**kwargs,
- )
+ }
- content = (
+ # Handle tool-based completion
+ if self.tools_list_dictionary is not None:
+ completion_params.update(
+ {
+ "tools": self.tools_list_dictionary,
+ "tool_choice": self.tool_choice,
+ "parallel_tool_calls": self.parallel_tool_calls,
+ }
+ )
+ response = await acompletion(**completion_params)
+ return (
response.choices[0]
.message.tool_calls[0]
.function.arguments
)
- # return response
+ # Standard completion
+ response = await acompletion(**completion_params)
+ return response.choices[0].message.content
- else:
+ except Exception as error:
- response = await acompletion(
+ logger.error(f"Error in LiteLLM arun: {str(error)}")
- model=self.model_name,
+ if "rate_limit" in str(error).lower():
- messages=messages,
+ logger.warning(
- stream=self.stream,
+ "Rate limit hit, retrying with exponential backoff..."
- temperature=self.temperature,
- max_tokens=self.max_tokens,
- *args,
- **kwargs,
)
+ await asyncio.sleep(2) # Use async sleep
+ return await self.arun(task, *args, **kwargs)
+ raise error
- content = response.choices[
+ async def _process_batch(
- 0
+ self, tasks: List[str], batch_size: int = 10
- ].message.content # Accessing the content
+ ):
+ """
+ Process a batch of tasks asynchronously.
- return content
+ Args:
- except Exception as error:
+ tasks (List[str]): List of tasks to process.
- logger.error(f"Error in LiteLLM: {error}")
+ batch_size (int): Size of each batch.
- raise error
+ Returns:
+ List[str]: List of responses.
+ """
+ results = []
+ for i in range(0, len(tasks), batch_size):
+ batch = tasks[i : i + batch_size]
+ batch_results = await asyncio.gather(
+ *[self.arun(task) for task in batch],
+ return_exceptions=True,
+ )
+ # Handle any exceptions in the batch
+ for result in batch_results:
+ if isinstance(result, Exception):
+ logger.error(
+ f"Error in batch processing: {str(result)}"
+ )
+ results.append(str(result))
+ else:
+ results.append(result)
+ # Add a small delay between batches to avoid rate limits
+ if i + batch_size < len(tasks):
+ await asyncio.sleep(0.5)
+ return results
def batched_run(self, tasks: List[str], batch_size: int = 10):
"""
- Run the LLM model for the given tasks in batches.
+ Run multiple tasks in batches synchronously.
+ Args:
+ tasks (List[str]): List of tasks to process.
+ batch_size (int): Size of each batch.
+ Returns:
+ List[str]: List of responses.
"""
logger.info(
- f"Running tasks in batches of size {batch_size}. Total tasks: {len(tasks)}"
+ f"Running {len(tasks)} tasks in batches of {batch_size}"
)
- results = []
+ return asyncio.run(self._process_batch(tasks, batch_size))
- for task in tasks:
- logger.info(f"Running task: {task}")
- results.append(self.run(task))
- logger.info("Completed all tasks.")
- return results
- def batched_arun(self, tasks: List[str], batch_size: int = 10):
+ async def batched_arun(
+ self, tasks: List[str], batch_size: int = 10
+ ):
"""
- Run the LLM model for the given tasks in batches.
+ Run multiple tasks in batches asynchronously.
+ Args:
+ tasks (List[str]): List of tasks to process.
+ batch_size (int): Size of each batch.
+ Returns:
+ List[str]: List of responses.
"""
logger.info(
- f"Running asynchronous tasks in batches of size {batch_size}. Total tasks: {len(tasks)}"
+ f"Running {len(tasks)} tasks asynchronously in batches of {batch_size}"
)
- results = []
+ return await self._process_batch(tasks, batch_size)
- for task in tasks:
- logger.info(f"Running asynchronous task: {task}")
- results.append(asyncio.run(self.arun(task)))
- logger.info("Completed all asynchronous tasks.")
- return results
