From 6c21dfce9efd4edc5195f458d22e6e667c1c6fef Mon Sep 17 00:00:00 2001 From: ascender1729 Date: Tue, 13 May 2025 21:40:35 +0530 Subject: [PATCH] Fix LLM serialization issue in Agent class --- swarms/structs/agent.py | 98 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 97 insertions(+), 1 deletion(-) diff --git a/swarms/structs/agent.py b/swarms/structs/agent.py index a2597c4e..104ecec5 100644 --- a/swarms/structs/agent.py +++ b/swarms/structs/agent.py @@ -6,6 +6,7 @@ import random import threading import time import uuid +import importlib from concurrent.futures import ThreadPoolExecutor from datetime import datetime from typing import ( @@ -2889,9 +2890,10 @@ class Agent: state_dict['logger_handler'] = self._serialize_logger_handler() state_dict['short_memory'] = self._serialize_short_memory() state_dict['tokenizer'] = self._serialize_tokenizer() + state_dict['llm'] = self._serialize_llm() # Handle other non-serializable objects - non_serializable = ['llm', 'long_term_memory', 'agent_output', 'executor'] + non_serializable = ['long_term_memory', 'agent_output', 'executor'] for key in non_serializable: if key in state_dict: if state_dict[key] is not None: @@ -2939,9 +2941,12 @@ class Agent: logger_config = state_dict.pop('logger_handler', None) short_memory_config = state_dict.pop('short_memory', None) tokenizer_config = state_dict.pop('tokenizer', None) + llm_config = state_dict.pop('llm', None) + self._deserialize_logger_handler(logger_config) self._deserialize_short_memory(short_memory_config) self._deserialize_tokenizer(tokenizer_config) + self._deserialize_llm(llm_config) # Update remaining agent attributes for key, value in state_dict.items(): @@ -3007,3 +3012,94 @@ class Agent: except Exception as e: logger.warning(f"Failed to recreate tokenizer: {e}") self.tokenizer = None + + def _serialize_llm(self): + """Serialize LLM configuration and state.""" + if not hasattr(self, 'llm') or self.llm is None: + return None + + try: + # Get the LLM's class name and module + llm_class = type(self.llm).__name__ + llm_module = type(self.llm).__module__ + + # Get all attributes that are serializable + attrs = {} + for attr_name in dir(self.llm): + if not attr_name.startswith('_'): # Skip private attributes + try: + attr_value = getattr(self.llm, attr_name) + if not callable(attr_value): # Skip methods + # Try to serialize the attribute + json.dumps(attr_value) + attrs[attr_name] = attr_value + except (TypeError, ValueError): + continue + + return { + 'class': llm_class, + 'module': llm_module, + 'attributes': attrs, + 'model_name': self.model_name if hasattr(self, 'model_name') else None, + 'llm_args': self.llm_args if hasattr(self, 'llm_args') else None + } + except Exception as e: + logger.warning(f"Failed to serialize LLM: {e}") + return None + + def _deserialize_llm(self, config: dict) -> None: + """Recreate LLM from configuration.""" + if not config: + self.llm = None + return + + try: + # Import the LLM class + try: + module_parts = config['module'].split('.') + module = __import__(config['module'], fromlist=[module_parts[-1]]) + llm_class = getattr(module, config['class']) + except (ImportError, AttributeError) as e: + logger.warning(f"Failed to import LLM class: {e}") + # As a fallback, try to use LiteLLM directly if that's what we need + if config['class'] == 'LiteLLM': + from swarms.utils.litellm_wrapper import LiteLLM + llm_class = LiteLLM + else: + raise + + # Get configuration parameters + attrs = config.get('attributes', {}) + + # Handle the model_name parameter + model_name = attrs.get('model_name') or config.get('model_name') or self.model_name + + # Initialize the LLM + kwargs = {} + for attr_name, attr_value in attrs.items(): + if attr_name not in ('model_name',): # Skip these as they're handled separately + kwargs[attr_name] = attr_value + + # Special handling for LiteLLM + if config['class'] == 'LiteLLM': + logger.info(f"Recreating LiteLLM with model: {model_name}") + self.llm = llm_class(model_name=model_name, **kwargs) + else: + # For other LLM types + self.llm = llm_class(**kwargs) + + logger.info(f"Successfully restored LLM of type {config['class']}") + + except Exception as e: + logger.warning(f"Failed to recreate LLM: {e}") + # Fallback: create a new LLM if model_name is available + if hasattr(self, 'model_name') and self.model_name: + try: + from swarms.utils.litellm_wrapper import LiteLLM + self.llm = LiteLLM(model_name=self.model_name) + logger.info(f"Created fallback LiteLLM with model: {self.model_name}") + except Exception as e2: + logger.error(f"Failed to create fallback LLM: {e2}") + self.llm = None + else: + self.llm = None