|
|
|
@ -6,6 +6,7 @@ import random
|
|
|
|
|
import threading
|
|
|
|
|
import time
|
|
|
|
|
import uuid
|
|
|
|
|
import importlib
|
|
|
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
|
|
|
from datetime import datetime
|
|
|
|
|
from typing import (
|
|
|
|
@ -2889,9 +2890,10 @@ class Agent:
|
|
|
|
|
state_dict['logger_handler'] = self._serialize_logger_handler()
|
|
|
|
|
state_dict['short_memory'] = self._serialize_short_memory()
|
|
|
|
|
state_dict['tokenizer'] = self._serialize_tokenizer()
|
|
|
|
|
state_dict['llm'] = self._serialize_llm()
|
|
|
|
|
|
|
|
|
|
# Handle other non-serializable objects
|
|
|
|
|
non_serializable = ['llm', 'long_term_memory', 'agent_output', 'executor']
|
|
|
|
|
non_serializable = ['long_term_memory', 'agent_output', 'executor']
|
|
|
|
|
for key in non_serializable:
|
|
|
|
|
if key in state_dict:
|
|
|
|
|
if state_dict[key] is not None:
|
|
|
|
@ -2939,9 +2941,12 @@ class Agent:
|
|
|
|
|
logger_config = state_dict.pop('logger_handler', None)
|
|
|
|
|
short_memory_config = state_dict.pop('short_memory', None)
|
|
|
|
|
tokenizer_config = state_dict.pop('tokenizer', None)
|
|
|
|
|
llm_config = state_dict.pop('llm', None)
|
|
|
|
|
|
|
|
|
|
self._deserialize_logger_handler(logger_config)
|
|
|
|
|
self._deserialize_short_memory(short_memory_config)
|
|
|
|
|
self._deserialize_tokenizer(tokenizer_config)
|
|
|
|
|
self._deserialize_llm(llm_config)
|
|
|
|
|
|
|
|
|
|
# Update remaining agent attributes
|
|
|
|
|
for key, value in state_dict.items():
|
|
|
|
@ -3007,3 +3012,94 @@ class Agent:
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.warning(f"Failed to recreate tokenizer: {e}")
|
|
|
|
|
self.tokenizer = None
|
|
|
|
|
|
|
|
|
|
def _serialize_llm(self):
|
|
|
|
|
"""Serialize LLM configuration and state."""
|
|
|
|
|
if not hasattr(self, 'llm') or self.llm is None:
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
# Get the LLM's class name and module
|
|
|
|
|
llm_class = type(self.llm).__name__
|
|
|
|
|
llm_module = type(self.llm).__module__
|
|
|
|
|
|
|
|
|
|
# Get all attributes that are serializable
|
|
|
|
|
attrs = {}
|
|
|
|
|
for attr_name in dir(self.llm):
|
|
|
|
|
if not attr_name.startswith('_'): # Skip private attributes
|
|
|
|
|
try:
|
|
|
|
|
attr_value = getattr(self.llm, attr_name)
|
|
|
|
|
if not callable(attr_value): # Skip methods
|
|
|
|
|
# Try to serialize the attribute
|
|
|
|
|
json.dumps(attr_value)
|
|
|
|
|
attrs[attr_name] = attr_value
|
|
|
|
|
except (TypeError, ValueError):
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
return {
|
|
|
|
|
'class': llm_class,
|
|
|
|
|
'module': llm_module,
|
|
|
|
|
'attributes': attrs,
|
|
|
|
|
'model_name': self.model_name if hasattr(self, 'model_name') else None,
|
|
|
|
|
'llm_args': self.llm_args if hasattr(self, 'llm_args') else None
|
|
|
|
|
}
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.warning(f"Failed to serialize LLM: {e}")
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
def _deserialize_llm(self, config: dict) -> None:
|
|
|
|
|
"""Recreate LLM from configuration."""
|
|
|
|
|
if not config:
|
|
|
|
|
self.llm = None
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
# Import the LLM class
|
|
|
|
|
try:
|
|
|
|
|
module_parts = config['module'].split('.')
|
|
|
|
|
module = __import__(config['module'], fromlist=[module_parts[-1]])
|
|
|
|
|
llm_class = getattr(module, config['class'])
|
|
|
|
|
except (ImportError, AttributeError) as e:
|
|
|
|
|
logger.warning(f"Failed to import LLM class: {e}")
|
|
|
|
|
# As a fallback, try to use LiteLLM directly if that's what we need
|
|
|
|
|
if config['class'] == 'LiteLLM':
|
|
|
|
|
from swarms.utils.litellm_wrapper import LiteLLM
|
|
|
|
|
llm_class = LiteLLM
|
|
|
|
|
else:
|
|
|
|
|
raise
|
|
|
|
|
|
|
|
|
|
# Get configuration parameters
|
|
|
|
|
attrs = config.get('attributes', {})
|
|
|
|
|
|
|
|
|
|
# Handle the model_name parameter
|
|
|
|
|
model_name = attrs.get('model_name') or config.get('model_name') or self.model_name
|
|
|
|
|
|
|
|
|
|
# Initialize the LLM
|
|
|
|
|
kwargs = {}
|
|
|
|
|
for attr_name, attr_value in attrs.items():
|
|
|
|
|
if attr_name not in ('model_name',): # Skip these as they're handled separately
|
|
|
|
|
kwargs[attr_name] = attr_value
|
|
|
|
|
|
|
|
|
|
# Special handling for LiteLLM
|
|
|
|
|
if config['class'] == 'LiteLLM':
|
|
|
|
|
logger.info(f"Recreating LiteLLM with model: {model_name}")
|
|
|
|
|
self.llm = llm_class(model_name=model_name, **kwargs)
|
|
|
|
|
else:
|
|
|
|
|
# For other LLM types
|
|
|
|
|
self.llm = llm_class(**kwargs)
|
|
|
|
|
|
|
|
|
|
logger.info(f"Successfully restored LLM of type {config['class']}")
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.warning(f"Failed to recreate LLM: {e}")
|
|
|
|
|
# Fallback: create a new LLM if model_name is available
|
|
|
|
|
if hasattr(self, 'model_name') and self.model_name:
|
|
|
|
|
try:
|
|
|
|
|
from swarms.utils.litellm_wrapper import LiteLLM
|
|
|
|
|
self.llm = LiteLLM(model_name=self.model_name)
|
|
|
|
|
logger.info(f"Created fallback LiteLLM with model: {self.model_name}")
|
|
|
|
|
except Exception as e2:
|
|
|
|
|
logger.error(f"Failed to create fallback LLM: {e2}")
|
|
|
|
|
self.llm = None
|
|
|
|
|
else:
|
|
|
|
|
self.llm = None
|
|
|
|
|