Fix LLM serialization issue in Agent class

pull/844/head
ascender1729 2 months ago
parent 84bb4b17c9
commit 6c21dfce9e

@ -6,6 +6,7 @@ import random
import threading
import time
import uuid
import importlib
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import (
@ -2889,9 +2890,10 @@ class Agent:
state_dict['logger_handler'] = self._serialize_logger_handler()
state_dict['short_memory'] = self._serialize_short_memory()
state_dict['tokenizer'] = self._serialize_tokenizer()
state_dict['llm'] = self._serialize_llm()
# Handle other non-serializable objects
non_serializable = ['llm', 'long_term_memory', 'agent_output', 'executor']
non_serializable = ['long_term_memory', 'agent_output', 'executor']
for key in non_serializable:
if key in state_dict:
if state_dict[key] is not None:
@ -2939,9 +2941,12 @@ class Agent:
logger_config = state_dict.pop('logger_handler', None)
short_memory_config = state_dict.pop('short_memory', None)
tokenizer_config = state_dict.pop('tokenizer', None)
llm_config = state_dict.pop('llm', None)
self._deserialize_logger_handler(logger_config)
self._deserialize_short_memory(short_memory_config)
self._deserialize_tokenizer(tokenizer_config)
self._deserialize_llm(llm_config)
# Update remaining agent attributes
for key, value in state_dict.items():
@ -3007,3 +3012,94 @@ class Agent:
except Exception as e:
logger.warning(f"Failed to recreate tokenizer: {e}")
self.tokenizer = None
def _serialize_llm(self):
"""Serialize LLM configuration and state."""
if not hasattr(self, 'llm') or self.llm is None:
return None
try:
# Get the LLM's class name and module
llm_class = type(self.llm).__name__
llm_module = type(self.llm).__module__
# Get all attributes that are serializable
attrs = {}
for attr_name in dir(self.llm):
if not attr_name.startswith('_'): # Skip private attributes
try:
attr_value = getattr(self.llm, attr_name)
if not callable(attr_value): # Skip methods
# Try to serialize the attribute
json.dumps(attr_value)
attrs[attr_name] = attr_value
except (TypeError, ValueError):
continue
return {
'class': llm_class,
'module': llm_module,
'attributes': attrs,
'model_name': self.model_name if hasattr(self, 'model_name') else None,
'llm_args': self.llm_args if hasattr(self, 'llm_args') else None
}
except Exception as e:
logger.warning(f"Failed to serialize LLM: {e}")
return None
def _deserialize_llm(self, config: dict) -> None:
"""Recreate LLM from configuration."""
if not config:
self.llm = None
return
try:
# Import the LLM class
try:
module_parts = config['module'].split('.')
module = __import__(config['module'], fromlist=[module_parts[-1]])
llm_class = getattr(module, config['class'])
except (ImportError, AttributeError) as e:
logger.warning(f"Failed to import LLM class: {e}")
# As a fallback, try to use LiteLLM directly if that's what we need
if config['class'] == 'LiteLLM':
from swarms.utils.litellm_wrapper import LiteLLM
llm_class = LiteLLM
else:
raise
# Get configuration parameters
attrs = config.get('attributes', {})
# Handle the model_name parameter
model_name = attrs.get('model_name') or config.get('model_name') or self.model_name
# Initialize the LLM
kwargs = {}
for attr_name, attr_value in attrs.items():
if attr_name not in ('model_name',): # Skip these as they're handled separately
kwargs[attr_name] = attr_value
# Special handling for LiteLLM
if config['class'] == 'LiteLLM':
logger.info(f"Recreating LiteLLM with model: {model_name}")
self.llm = llm_class(model_name=model_name, **kwargs)
else:
# For other LLM types
self.llm = llm_class(**kwargs)
logger.info(f"Successfully restored LLM of type {config['class']}")
except Exception as e:
logger.warning(f"Failed to recreate LLM: {e}")
# Fallback: create a new LLM if model_name is available
if hasattr(self, 'model_name') and self.model_name:
try:
from swarms.utils.litellm_wrapper import LiteLLM
self.llm = LiteLLM(model_name=self.model_name)
logger.info(f"Created fallback LiteLLM with model: {self.model_name}")
except Exception as e2:
logger.error(f"Failed to create fallback LLM: {e2}")
self.llm = None
else:
self.llm = None

Loading…
Cancel
Save