Fix LLM serialization issue in Agent class

pull/844/head
ascender1729 2 months ago
parent 84bb4b17c9
commit 6c21dfce9e

@@ -6,6 +6,7 @@ import random
 import threading
 import time
 import uuid
+import importlib
 from concurrent.futures import ThreadPoolExecutor
 from datetime import datetime
 from typing import (
@@ -2889,9 +2890,10 @@ class Agent:
         state_dict['logger_handler'] = self._serialize_logger_handler()
         state_dict['short_memory'] = self._serialize_short_memory()
         state_dict['tokenizer'] = self._serialize_tokenizer()
+        state_dict['llm'] = self._serialize_llm()

         # Handle other non-serializable objects
-        non_serializable = ['llm', 'long_term_memory', 'agent_output', 'executor']
+        non_serializable = ['long_term_memory', 'agent_output', 'executor']
         for key in non_serializable:
             if key in state_dict:
                 if state_dict[key] is not None:
@@ -2939,9 +2941,12 @@ class Agent:
         logger_config = state_dict.pop('logger_handler', None)
         short_memory_config = state_dict.pop('short_memory', None)
         tokenizer_config = state_dict.pop('tokenizer', None)
+        llm_config = state_dict.pop('llm', None)

         self._deserialize_logger_handler(logger_config)
         self._deserialize_short_memory(short_memory_config)
         self._deserialize_tokenizer(tokenizer_config)
+        self._deserialize_llm(llm_config)

         # Update remaining agent attributes
         for key, value in state_dict.items():
@@ -3007,3 +3012,94 @@ class Agent:
             except Exception as e:
                 logger.warning(f"Failed to recreate tokenizer: {e}")
                 self.tokenizer = None
def _serialize_llm(self):
"""Serialize LLM configuration and state."""
if not hasattr(self, 'llm') or self.llm is None:
return None
try:
# Get the LLM's class name and module
llm_class = type(self.llm).__name__
llm_module = type(self.llm).__module__
# Get all attributes that are serializable
attrs = {}
for attr_name in dir(self.llm):
if not attr_name.startswith('_'): # Skip private attributes
try:
attr_value = getattr(self.llm, attr_name)
if not callable(attr_value): # Skip methods
# Try to serialize the attribute
json.dumps(attr_value)
attrs[attr_name] = attr_value
except (TypeError, ValueError):
continue
return {
'class': llm_class,
'module': llm_module,
'attributes': attrs,
'model_name': self.model_name if hasattr(self, 'model_name') else None,
'llm_args': self.llm_args if hasattr(self, 'llm_args') else None
}
except Exception as e:
logger.warning(f"Failed to serialize LLM: {e}")
return None
def _deserialize_llm(self, config: dict) -> None:
"""Recreate LLM from configuration."""
if not config:
self.llm = None
return
try:
# Import the LLM class
try:
module_parts = config['module'].split('.')
module = __import__(config['module'], fromlist=[module_parts[-1]])
llm_class = getattr(module, config['class'])
except (ImportError, AttributeError) as e:
logger.warning(f"Failed to import LLM class: {e}")
# As a fallback, try to use LiteLLM directly if that's what we need
if config['class'] == 'LiteLLM':
from swarms.utils.litellm_wrapper import LiteLLM
llm_class = LiteLLM
else:
raise
# Get configuration parameters
attrs = config.get('attributes', {})
# Handle the model_name parameter
model_name = attrs.get('model_name') or config.get('model_name') or self.model_name
# Initialize the LLM
kwargs = {}
for attr_name, attr_value in attrs.items():
if attr_name not in ('model_name',): # Skip these as they're handled separately
kwargs[attr_name] = attr_value
# Special handling for LiteLLM
if config['class'] == 'LiteLLM':
logger.info(f"Recreating LiteLLM with model: {model_name}")
self.llm = llm_class(model_name=model_name, **kwargs)
else:
# For other LLM types
self.llm = llm_class(**kwargs)
logger.info(f"Successfully restored LLM of type {config['class']}")
except Exception as e:
logger.warning(f"Failed to recreate LLM: {e}")
# Fallback: create a new LLM if model_name is available
if hasattr(self, 'model_name') and self.model_name:
try:
from swarms.utils.litellm_wrapper import LiteLLM
self.llm = LiteLLM(model_name=self.model_name)
logger.info(f"Created fallback LiteLLM with model: {self.model_name}")
except Exception as e2:
logger.error(f"Failed to create fallback LLM: {e2}")
self.llm = None
else:
self.llm = None

Loading…
Cancel
Save