[FEAT-FALLBACK][Added agent model fallback, if primary model fails]

pull/1084/head
CI-DEV 3 weeks ago committed by GitHub
parent a64f57a66c
commit ef9aa85d22
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -202,6 +202,8 @@ class Agent:
list_of_pdf (str): The list of pdf
tokenizer (Any): The tokenizer
long_term_memory (BaseVectorDatabase): The long term memory
fallback_model_name (str): The fallback model name to use if primary model fails
fallback_models (List[str]): List of fallback model names to try in order if primary model fails
preset_stopping_token (bool): Enable preset stopping token
traceback (Any): The traceback
traceback_handlers (Any): The traceback handlers
@ -302,6 +304,15 @@ class Agent:
>>> response = agent.run("Tell me a long story.") # Will stream in real-time >>> response = agent.run("Tell me a long story.") # Will stream in real-time
>>> print(response) # Final complete response >>> print(response) # Final complete response
>>> # Fallback model example
>>> agent = Agent(
... model_name="gpt-4o",
... fallback_models=["gpt-4o-mini", "gpt-3.5-turbo"],
... max_loops=1
... )
>>> response = agent.run("Generate a report on the financials.")
>>> # If gpt-4o fails, it will automatically try gpt-4o-mini, then gpt-3.5-turbo
""" """
def __init__( def __init__(
@ -340,6 +351,7 @@ class Agent:
tokenizer: Optional[Any] = None, tokenizer: Optional[Any] = None,
long_term_memory: Optional[Union[Callable, Any]] = None, long_term_memory: Optional[Union[Callable, Any]] = None,
fallback_model_name: Optional[str] = None, fallback_model_name: Optional[str] = None,
fallback_models: Optional[List[str]] = None,
preset_stopping_token: Optional[bool] = False, preset_stopping_token: Optional[bool] = False,
traceback: Optional[Any] = None, traceback: Optional[Any] = None,
traceback_handlers: Optional[Any] = None, traceback_handlers: Optional[Any] = None,
@ -605,6 +617,9 @@ class Agent:
self.thinking_tokens = thinking_tokens self.thinking_tokens = thinking_tokens
self.reasoning_enabled = reasoning_enabled self.reasoning_enabled = reasoning_enabled
self.fallback_model_name = fallback_model_name self.fallback_model_name = fallback_model_name
self.fallback_models = fallback_models or []
self.current_model_index = 0
self.model_attempts = {}
# self.init_handling() # self.init_handling()
self.setup_config() self.setup_config()
@ -728,6 +743,9 @@ class Agent:
if self.model_name is None: if self.model_name is None:
self.model_name = "gpt-4o-mini" self.model_name = "gpt-4o-mini"
# Use current model (which may be a fallback)
current_model = self.get_current_model()
# Determine if parallel tool calls should be enabled # Determine if parallel tool calls should be enabled
if exists(self.tools) and len(self.tools) >= 2: if exists(self.tools) and len(self.tools) >= 2:
@ -2551,11 +2569,46 @@ class Agent:
return out return out
except AgentLLMError as e: except (AgentLLMError, Exception) as e:
logger.error( logger.error(
f"Error calling LLM: {e}. Task: {task}, Args: {args}, Kwargs: {kwargs} Traceback: {traceback.format_exc()}" f"Error calling LLM with model '{self.get_current_model()}': {e}. "
f"Task: {task}, Args: {args}, Kwargs: {kwargs} Traceback: {traceback.format_exc()}"
) )
raise e
# Try fallback models if available
if self.is_fallback_available() and self.switch_to_next_model():
logger.info(
f"Retrying with fallback model '{self.get_current_model()}' for agent '{self.agent_name}'"
)
try:
# Recursive call with the new model
return self.call_llm(
task=task,
img=img,
current_loop=current_loop,
streaming_callback=streaming_callback,
*args,
**kwargs
)
except Exception as fallback_error:
logger.error(
f"Fallback model '{self.get_current_model()}' also failed: {fallback_error}"
)
# Continue to next fallback or raise if no more models
if self.is_fallback_available() and self.switch_to_next_model():
return self.call_llm(
task=task,
img=img,
current_loop=current_loop,
streaming_callback=streaming_callback,
*args,
**kwargs
)
else:
raise e
else:
# No fallback available or all fallbacks exhausted
raise e
def handle_sop_ops(self): def handle_sop_ops(self):
# If the user inputs a list of strings for the sop then join them and set the sop # If the user inputs a list of strings for the sop then join them and set the sop
@ -2976,6 +3029,85 @@ class Agent:
api_key=self.llm_api_key, api_key=self.llm_api_key,
) )
def get_available_models(self) -> List[str]:
    """
    Build the ordered list of candidate model names.

    The primary model comes first, then the single ``fallback_model_name``
    (if set), then each entry of ``fallback_models`` — with duplicates
    dropped while preserving first-seen order.

    Returns:
        List[str]: Model names in order of preference.
    """
    # Collect raw candidates in preference order, skipping unset values.
    candidates: List[str] = []
    if self.model_name:
        candidates.append(self.model_name)
    if self.fallback_model_name:
        candidates.append(self.fallback_model_name)
    candidates.extend(self.fallback_models or [])

    # De-duplicate, keeping the earliest occurrence of each name.
    ordered: List[str] = []
    for name in candidates:
        if name not in ordered:
            ordered.append(name)
    return ordered
def get_current_model(self) -> str:
    """
    Return the model name currently in effect (primary or fallback).

    Looks up ``current_model_index`` in the preference list; if the index
    has run past the end, falls back to the primary model name or the
    hard default ``"gpt-4o-mini"``.

    Returns:
        str: Current model name.
    """
    models = self.get_available_models()
    try:
        return models[self.current_model_index]
    except IndexError:
        # Index exhausted the list — fall back to the primary/default.
        return self.model_name or "gpt-4o-mini"
def switch_to_next_model(self) -> bool:
    """
    Switch to the next available model in the fallback list.

    Advances ``current_model_index`` and reinitializes the LLM so the
    next call uses the fallback model.

    Returns:
        bool: True if successfully switched to the next model, False if
            no more models are available.
    """
    available_models = self.get_available_models()
    next_index = self.current_model_index + 1
    if next_index >= len(available_models):
        logger.error(
            f"Agent '{self.agent_name}' has exhausted all available models. "
            f"Tried {len(available_models)} models: {available_models}"
        )
        return False

    self.current_model_index = next_index
    new_model = available_models[next_index]
    logger.warning(
        f"Agent '{self.agent_name}' switching to fallback model: {new_model} "
        f"(attempt {next_index + 1}/{len(available_models)})"
    )
    # BUG FIX: do NOT overwrite self.model_name with the fallback name.
    # get_available_models() builds its ordering starting from
    # self.model_name, so clobbering it dropped the primary model from the
    # list and left current_model_index pointing PAST the model we just
    # switched to — get_current_model()/llm_handling() would then skip a
    # fallback. The index alone selects the active model; llm_handling()
    # reads it via get_current_model().
    self.llm = self.llm_handling()
    return True
def reset_model_index(self) -> None:
    """Return to the primary model by resetting the fallback index."""
    self.current_model_index = 0
    if not self.model_name:
        return
    # Reinitialize the LLM so it picks up the primary model again.
    self.llm = self.llm_handling()
def is_fallback_available(self) -> bool:
    """
    Report whether any fallback model is configured.

    Returns:
        bool: True when the preference list holds more than just the
            primary model.
    """
    return len(self.get_available_models()) > 1
def execute_tools(self, response: any, loop_count: int): def execute_tools(self, response: any, loop_count: int):
# Handle None response gracefully # Handle None response gracefully
if response is None: if response is None:

Loading…
Cancel
Save