From 932a3e7a472d79b1cb704489e06a73dd534e31cf Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E7=8E=8B=E7=A5=A5=E5=AE=87?= <625024108@qq.com>
Date: Wed, 30 Jul 2025 15:13:14 +0800
Subject: [PATCH] fix bugs in truncate_memory_with_tokenizer

---
 swarms/structs/conversation.py | 119 +++++++++++++++++++++++++++++----
 1 file changed, 105 insertions(+), 14 deletions(-)

diff --git a/swarms/structs/conversation.py b/swarms/structs/conversation.py
index 7c8d3109..3e47cae1 100644
--- a/swarms/structs/conversation.py
+++ b/swarms/structs/conversation.py
@@ -184,6 +184,7 @@ class Conversation:
         system_prompt: Optional[str] = None,
         time_enabled: bool = False,
         autosave: bool = False,  # Changed default to False
+        save_enabled: bool = False,  # New parameter to control if saving is enabled
         save_filepath: str = None,
         load_filepath: str = None,  # New parameter to specify which file to load from
         context_length: int = 8192,
@@ -222,6 +223,7 @@ class Conversation:
         self.system_prompt = system_prompt
         self.time_enabled = time_enabled
         self.autosave = autosave
+        self.save_enabled = save_enabled
         self.conversations_dir = conversations_dir
         self.tokenizer_model_name = tokenizer_model_name
         self.message_id_on = message_id_on
@@ -1019,6 +1021,13 @@ class Conversation:
             )
             return
 
+        # Don't save if saving is disabled
+        if not self.save_enabled:
+            logger.warning(
+                "An attempt to save the conversation failed: save_enabled is False. "
+                "Please set save_enabled=True when creating a Conversation object to enable saving."
+            )
+            return
 
         # Get the full data including metadata and conversation history
         data = self.get_init_params()
@@ -1267,39 +1276,121 @@ class Conversation:
 
     def truncate_memory_with_tokenizer(self):
         """
-        Truncates the conversation history based on the total number of tokens using a tokenizer.
-
+        Truncate the conversation history based on the total token count using a tokenizer.
+
+        This version is model-agnostic: it does not depend on a specific LLM and works with any model that count_tokens supports.
+        Messages are counted and truncated one at a time, so the result is still valid content.
+
         Returns:
             None
         """
+        from swarms.utils.litellm_tokenizer import count_tokens
+
         total_tokens = 0
         truncated_history = []
-
+
         for message in self.conversation_history:
             role = message.get("role")
             content = message.get("content")
-            tokens = count_tokens(content, self.tokenizer_model_name)
-            count = tokens  # Assign the token count
-            total_tokens += count
-
-            if total_tokens <= self.context_length:
+
+            # Convert content to string if it's not already a string
+            if not isinstance(content, str):
+                content = str(content)
+
+            # Calculate the token count for this message
+            token_count = count_tokens(content, self.tokenizer_model_name)
+
+            # Check if adding this message would exceed the limit
+            if total_tokens + token_count <= self.context_length:
+                # If not, add the full message
                 truncated_history.append(message)
+                total_tokens += token_count
             else:
-                remaining_tokens = self.context_length - (
-                    total_tokens - count
+                # Calculate the remaining tokens we can include
+                remaining_tokens = self.context_length - total_tokens
+
+                # If no token space is left, stop
+                if remaining_tokens <= 0:
+                    break
+
+                # If we have space left, truncate this message: use binary
+                # search to find the content length that fits the remaining token space
+                truncated_content = self._binary_search_truncate(
+                    content,
+                    remaining_tokens,
+                    self.tokenizer_model_name
                 )
-                truncated_content = content[
-                    :remaining_tokens
-                ]  # Truncate the content based on the remaining tokens
+
+                # Create the truncated message
                 truncated_message = {
                     "role": role,
                     "content": truncated_content,
                 }
+
+                # Carry over any other fields from the original message
+                for key, value in message.items():
+                    if key not in ["role", "content"]:
+                        truncated_message[key] = value
+
                 truncated_history.append(truncated_message)
                 break
-
+
+        # Update the conversation history
         self.conversation_history = truncated_history
 
+    def _binary_search_truncate(self, text, target_tokens, model_name):
+        """
+        Use binary search to find the longest prefix of the text that fits within the target token count.
+
+        Parameters:
+            text (str): Original text to truncate
+            target_tokens (int): Target token count
+            model_name (str): Model name for token counting
+
+        Returns:
+            str: Truncated text whose token count does not exceed target_tokens
+        """
+        from swarms.utils.litellm_tokenizer import count_tokens
+
+        # If the text is empty or the target token count is 0, return an empty string
+        if not text or target_tokens <= 0:
+            return ""
+
+        # If the original text already fits within the target, return it as is
+        original_tokens = count_tokens(text, model_name)
+        if original_tokens <= target_tokens:
+            return text
+
+        # Binary search over prefix lengths
+        left, right = 0, len(text)
+        best_length = 0
+        best_text = ""
+
+        while left <= right:
+            mid = (left + right) // 2
+            truncated = text[:mid]
+            tokens = count_tokens(truncated, model_name)
+
+            if tokens <= target_tokens:
+                # The current prefix fits the target, so try a longer one
+                best_length = mid
+                best_text = truncated
+                left = mid + 1
+            else:
+                # Otherwise try a shorter prefix
+                right = mid - 1
+
+        # Prefer truncating at a sentence boundary if possible
+        sentence_delimiters = ['.', '!', '?', '\n']
+        for delimiter in sentence_delimiters:
+            last_pos = best_text.rfind(delimiter)
+            if last_pos > len(best_text) * 0.75:  # Only cut at a sentence boundary if we don't lose too much content
+                truncated_at_sentence = best_text[:last_pos + 1]
+                if count_tokens(truncated_at_sentence, model_name) <= target_tokens:
+                    return truncated_at_sentence
+
+        return best_text
+
     def clear(self):
         """Clear the conversation history."""
         if self.backend_instance:
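
A minimal standalone sketch of the binary-search idea behind _binary_search_truncate, for review purposes. The whitespace word counter below stands in for swarms' count_tokens so the snippet runs without litellm, and the free function binary_search_truncate is a hypothetical name, not part of the patch:

def count_tokens(text: str, model_name: str = "stub") -> int:
    # Stand-in tokenizer: one token per whitespace-separated word.
    return len(text.split())

def binary_search_truncate(text: str, target_tokens: int, model_name: str = "stub") -> str:
    # Empty input or no token budget yields an empty string.
    if not text or target_tokens <= 0:
        return ""
    # Already within budget: nothing to cut.
    if count_tokens(text, model_name) <= target_tokens:
        return text
    # Binary search over prefix lengths for the longest prefix that fits.
    left, right = 0, len(text)
    best_text = ""
    while left <= right:
        mid = (left + right) // 2
        if count_tokens(text[:mid], model_name) <= target_tokens:
            best_text = text[:mid]  # This prefix fits; try a longer one.
            left = mid + 1
        else:
            right = mid - 1  # Too long; try a shorter prefix.
    return best_text

print(binary_search_truncate("one two three four five six", 3))  # "one two three " (trailing space kept)

Each probe costs one tokenizer call, so the search needs O(log n) calls for a text of length n. This is also why the patch replaces the old content[:remaining_tokens] slice, which silently conflated a token count with a character count.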
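And a hypothetical usage sketch of the two behavioral changes together: save_enabled gating persistence, and token-based truncation. The add() and save_as_json() calls, the file path, and the model name are assumptions about the surrounding swarms API rather than anything this patch defines:

from swarms.structs.conversation import Conversation

# save_enabled defaults to False; without it, the save path above
# now logs a warning and returns without writing anything.
convo = Conversation(
    save_enabled=True,
    save_filepath="conversation.json",  # placeholder path
    context_length=8192,
    tokenizer_model_name="gpt-4o",  # placeholder model name
)

convo.add("user", "a long message " * 4000)  # add() assumed from the existing API
convo.truncate_memory_with_tokenizer()  # history now fits within context_length
convo.save_as_json()  # save method name assumed; saving is gated by save_enabled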