diff --git a/swarms/structs/agent.py b/swarms/structs/agent.py
index f2bd913f..a2597c4e 100644
--- a/swarms/structs/agent.py
+++ b/swarms/structs/agent.py
@@ -2888,9 +2888,10 @@ class Agent:
         # Handle special serialization for non-serializable objects
         state_dict['logger_handler'] = self._serialize_logger_handler()
         state_dict['short_memory'] = self._serialize_short_memory()
+        state_dict['tokenizer'] = self._serialize_tokenizer()
 
-        # Remove other non-serializable objects
-        non_serializable = ['llm', 'tokenizer', 'long_term_memory', 'agent_output', 'executor']
+        # Handle other non-serializable objects
+        non_serializable = ['llm', 'long_term_memory', 'agent_output', 'executor']
         for key in non_serializable:
             if key in state_dict:
                 if state_dict[key] is not None:
@@ -2937,9 +2938,10 @@ class Agent:
         # Handle special deserialization first
         logger_config = state_dict.pop('logger_handler', None)
         short_memory_config = state_dict.pop('short_memory', None)
-
+        tokenizer_config = state_dict.pop('tokenizer', None)
         self._deserialize_logger_handler(logger_config)
         self._deserialize_short_memory(short_memory_config)
+        self._deserialize_tokenizer(tokenizer_config)
 
         # Update remaining agent attributes
         for key, value in state_dict.items():
@@ -2953,3 +2955,55 @@ class Agent:
         except Exception as e:
             logger.error(f"Error loading agent state: {e}")
             raise
+
+    def _serialize_tokenizer(self):
+        """Serialize the tokenizer by capturing its class, module, and JSON-serializable public attributes."""
+        if not hasattr(self, 'tokenizer') or self.tokenizer is None:
+            return None
+
+        try:
+            # Record the tokenizer's class name and module so it can be re-imported later
+            tokenizer_class = type(self.tokenizer).__name__
+            tokenizer_module = type(self.tokenizer).__module__
+
+            # Collect every public, non-callable attribute that is JSON-serializable
+            attrs = {}
+            for attr_name in dir(self.tokenizer):
+                if not attr_name.startswith('_'):  # Skip private attributes
+                    try:
+                        attr_value = getattr(self.tokenizer, attr_name)
+                        if not callable(attr_value):  # Skip methods
+                            # Probe serializability; json.dumps raises TypeError if not JSON-safe
+                            json.dumps(attr_value)
+                            attrs[attr_name] = attr_value
+                    except (TypeError, ValueError, AttributeError):
+                        continue
+
+            return {
+                'class': tokenizer_class,
+                'module': tokenizer_module,
+                'attributes': attrs
+            }
+        except Exception as e:
+            logger.warning(f"Failed to serialize tokenizer: {e}")
+            return None
+
+    def _deserialize_tokenizer(self, config):
+        """Recreate a tokenizer from the config produced by _serialize_tokenizer."""
+        if not config:
+            self.tokenizer = None
+            return
+
+        try:
+            # Re-import the tokenizer class from its original module
+            module = __import__(config['module'], fromlist=[config['class']])
+            tokenizer_class = getattr(module, config['class'])
+
+            # Recreate the instance, passing the saved attributes as keyword arguments
+            attrs = config.get('attributes', {})
+            self.tokenizer = tokenizer_class(**attrs)
+
+            logger.info(f"Successfully restored tokenizer of type {config['class']}")
+        except Exception as e:
+            logger.warning(f"Failed to recreate tokenizer: {e}")
+            self.tokenizer = None
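For context, here is a minimal free-standing sketch of the round-trip this diff enables. `SimpleTokenizer`, `serialize_tokenizer`, and `deserialize_tokenizer` are hypothetical names for illustration (not part of swarms); the sketch assumes, as `_deserialize_tokenizer` does via its `tokenizer_class(**attrs)` call, that every serialized public attribute is accepted as a constructor keyword argument:

```python
import json

# Hypothetical tokenizer for the demo: its public attributes are
# JSON-serializable and double as constructor kwargs, matching the
# assumption baked into _deserialize_tokenizer.
class SimpleTokenizer:
    def __init__(self, model_name="gpt-4", max_tokens=4096):
        self.model_name = model_name
        self.max_tokens = max_tokens

    def count(self, text):
        return len(text.split())  # Stand-in for real token counting


def serialize_tokenizer(tokenizer):
    """Mirrors Agent._serialize_tokenizer outside the Agent class."""
    attrs = {}
    for name in dir(tokenizer):
        if name.startswith('_'):
            continue  # Skip private attributes
        value = getattr(tokenizer, name)
        if callable(value):
            continue  # Skip methods
        try:
            json.dumps(value)  # Probe JSON-serializability
            attrs[name] = value
        except (TypeError, ValueError):
            continue
    return {
        'class': type(tokenizer).__name__,
        'module': type(tokenizer).__module__,
        'attributes': attrs,
    }


def deserialize_tokenizer(config):
    """Mirrors Agent._deserialize_tokenizer: re-import and reconstruct."""
    module = __import__(config['module'], fromlist=[config['class']])
    cls = getattr(module, config['class'])
    return cls(**config.get('attributes', {}))


config = serialize_tokenizer(SimpleTokenizer(model_name="demo", max_tokens=512))
restored = deserialize_tokenizer(config)
assert restored.model_name == "demo" and restored.max_tokens == 512
```

Note the design trade-off: the round-trip only succeeds when the tokenizer's serialized attributes form a valid constructor signature. A tokenizer with required positional arguments or read-only properties would raise inside `_deserialize_tokenizer`, which then falls back to `self.tokenizer = None` rather than failing the whole state load.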