diff --git a/examples/utils/misc/conversation_test.py b/examples/utils/misc/conversation_test.py
index ec8a0534..a7a6750e 100644
--- a/examples/utils/misc/conversation_test.py
+++ b/examples/utils/misc/conversation_test.py
@@ -19,4 +19,4 @@ print(
     conversation.export_and_count_categories(
         tokenizer_model_name="claude-3-5-sonnet-20240620"
     )
-)
+)
\ No newline at end of file
diff --git a/examples/utils/misc/conversation_test_truncate.py b/examples/utils/misc/conversation_test_truncate.py
new file mode 100644
index 00000000..f660ff45
--- /dev/null
+++ b/examples/utils/misc/conversation_test_truncate.py
@@ -0,0 +1,91 @@
+from swarms.structs.conversation import Conversation
+from dotenv import load_dotenv
+from swarms.utils.litellm_tokenizer import count_tokens
+
+# Load environment variables from .env file
+load_dotenv()
+
+
+def demonstrate_truncation():
+    # Use a small context length so the truncation effect is easy to see
+    context_length = 25
+    print(f"Creating a conversation instance with context length {context_length}")
+
+    # Use a Claude model as the tokenizer model
+    conversation = Conversation(
+        context_length=context_length,
+        tokenizer_model_name="claude-3-7-sonnet-20250219"
+    )
+
+    # Add the first message - a short one
+    short_message = "Hello, I am a user."
+    print(f"\nAdding short message: '{short_message}'")
+    conversation.add("user", short_message)
+
+    # Display its token count
+    tokens = count_tokens(short_message, conversation.tokenizer_model_name)
+    print(f"Short message token count: {tokens}")
+
+    # Add the second message - a long one that should be truncated
+    long_message = "I have a question about artificial intelligence. I want to understand how large language models handle long texts, especially under token constraints. This issue is important because it relates to the model's practicality and effectiveness. I hope to get a detailed answer that helps me understand this complex technical problem."
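+    # This reply is far longer than the 25-token context length, so the
+    # truncation step below should shorten it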
+ print(f"\nAdding long message:\n'{long_message}'") + conversation.add("assistant", long_message) + + # Display long message token count + tokens = count_tokens(long_message, conversation.tokenizer_model_name) + print(f"Long message token count: {tokens}") + + # Display current conversation total token count + total_tokens = sum(count_tokens(msg["content"], conversation.tokenizer_model_name) + for msg in conversation.conversation_history) + print(f"Total token count before truncation: {total_tokens}") + + # Print the complete conversation history before truncation + print("\nConversation history before truncation:") + for i, msg in enumerate(conversation.conversation_history): + print(f"[{i}] {msg['role']}: {msg['content']}") + print(f" Token count: {count_tokens(msg['content'], conversation.tokenizer_model_name)}") + + # Execute truncation + print("\nExecuting truncation...") + conversation.truncate_memory_with_tokenizer() + + # Print conversation history after truncation + print("\nConversation history after truncation:") + for i, msg in enumerate(conversation.conversation_history): + print(f"[{i}] {msg['role']}: {msg['content']}") + print(f" Token count: {count_tokens(msg['content'], conversation.tokenizer_model_name)}") + + # Display total token count after truncation + total_tokens = sum(count_tokens(msg["content"], conversation.tokenizer_model_name) + for msg in conversation.conversation_history) + print(f"\nTotal token count after truncation: {total_tokens}") + print(f"Context length limit: {context_length}") + + # Verify if successfully truncated below the limit + if total_tokens <= context_length: + print("✅ Success: Total token count is now less than or equal to context length limit") + else: + print("❌ Failure: Total token count still exceeds context length limit") + + # Test sentence boundary truncation + print("\n\nTesting sentence boundary truncation:") + sentence_test = Conversation(context_length=15, tokenizer_model_name="claude-3-opus-20240229") + test_text = "This is the first sentence. This is the second very long sentence that contains a lot of content. This is the third sentence." + print(f"Original text: '{test_text}'") + print(f"Original token count: {count_tokens(test_text, sentence_test.tokenizer_model_name)}") + + # Using binary search for truncation + truncated = sentence_test._binary_search_truncate(test_text, 10, sentence_test.tokenizer_model_name) + print(f"Truncated text: '{truncated}'") + print(f"Truncated token count: {count_tokens(truncated, sentence_test.tokenizer_model_name)}") + + # Check if truncated at period + if truncated.endswith("."): + print("✅ Success: Text was truncated at sentence boundary") + else: + print("Note: Text was not truncated at sentence boundary") + + +if __name__ == "__main__": + demonstrate_truncation() \ No newline at end of file diff --git a/swarms/structs/conversation.py b/swarms/structs/conversation.py index 7c8d3109..6e11ebdd 100644 --- a/swarms/structs/conversation.py +++ b/swarms/structs/conversation.py @@ -1267,39 +1267,121 @@ class Conversation: def truncate_memory_with_tokenizer(self): """ - Truncates the conversation history based on the total number of tokens using a tokenizer. - + Truncate conversation history based on the total token count using tokenizer. + + This version is more generic, not dependent on a specific LLM model, and can work with any model that provides a counter. + Uses count_tokens function to calculate and truncate by message, ensuring the result is still valid content. 
+
         Returns:
             None
         """
+
         total_tokens = 0
         truncated_history = []
-
+
         for message in self.conversation_history:
             role = message.get("role")
             content = message.get("content")
-            tokens = count_tokens(content, self.tokenizer_model_name)
-            count = tokens  # Assign the token count
-            total_tokens += count
-
-            if total_tokens <= self.context_length:
+
+            # Convert content to a string if it is not one already
+            if not isinstance(content, str):
+                content = str(content)
+
+            # Count the tokens in this message
+            token_count = count_tokens(content, self.tokenizer_model_name)
+
+            # Check whether adding this message would exceed the limit
+            if total_tokens + token_count <= self.context_length:
+                # Within the limit: keep the full message
                 truncated_history.append(message)
+                total_tokens += token_count
             else:
-                remaining_tokens = self.context_length - (
-                    total_tokens - count
+                # Compute how many tokens we can still include
+                remaining_tokens = self.context_length - total_tokens
+
+                # No token budget left: stop here
+                if remaining_tokens <= 0:
+                    break
+
+                # Some budget remains, so use binary search to find the
+                # longest prefix of this message that fits it
+                truncated_content = self._binary_search_truncate(
+                    content,
+                    remaining_tokens,
+                    self.tokenizer_model_name
                 )
-                truncated_content = content[
-                    :remaining_tokens
-                ]  # Truncate the content based on the remaining tokens
+
+                # Create the truncated message
                 truncated_message = {
                     "role": role,
                     "content": truncated_content,
                 }
+
+                # Carry over any other fields from the original message
+                for key, value in message.items():
+                    if key not in ["role", "content"]:
+                        truncated_message[key] = value
+
                 truncated_history.append(truncated_message)
                 break
-
+
+        # Update the conversation history
         self.conversation_history = truncated_history

+    def _binary_search_truncate(self, text, target_tokens, model_name):
+        """
+        Use binary search to find the longest prefix of text whose token count does not exceed target_tokens.
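+        The search runs over character positions, so it needs only O(log len(text))
+        calls to count_tokens instead of counting every possible prefix.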
+
+        Parameters:
+            text (str): Original text to truncate
+            target_tokens (int): Target token count
+            model_name (str): Model name used for token counting
+
+        Returns:
+            str: Truncated text whose token count does not exceed target_tokens
+        """
+        # If the text is empty or the target is zero, return an empty string
+        if not text or target_tokens <= 0:
+            return ""
+
+        # If the full text already fits the target, return it unchanged
+        original_tokens = count_tokens(text, model_name)
+        if original_tokens <= target_tokens:
+            return text
+
+        # Binary search over prefix length
+        left, right = 0, len(text)
+        best_text = ""
+
+        while left <= right:
+            mid = (left + right) // 2
+            truncated = text[:mid]
+            tokens = count_tokens(truncated, model_name)
+
+            if tokens <= target_tokens:
+                # This prefix fits, so try a longer one
+                best_text = truncated
+                left = mid + 1
+            else:
+                # Too long, so try a shorter prefix
+                right = mid - 1
+
+        # Prefer to cut at a sentence boundary when possible
+        sentence_delimiters = ['.', '!', '?', '\n']
+        for delimiter in sentence_delimiters:
+            last_pos = best_text.rfind(delimiter)
+            # Only cut at the boundary if it keeps most of the content
+            if last_pos > len(best_text) * 0.75:
+                truncated_at_sentence = best_text[:last_pos + 1]
+                if count_tokens(truncated_at_sentence, model_name) <= target_tokens:
+                    return truncated_at_sentence
+
+        return best_text
+
     def clear(self):
         """Clear the conversation history."""
         if self.backend_instance:
@@ -1705,4 +1787,4 @@ class Conversation:
 # # # conversation.add("assistant", "I am doing well, thanks.")
 # # # # print(conversation.to_json())
 # # print(type(conversation.to_dict()))
-# # print(conversation.to_yaml())
+# # print(conversation.to_yaml())
\ No newline at end of file
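
A minimal, self-contained sketch of the binary-search prefix truncation introduced above, for illustration only. The whitespace-based toy_count_tokens is a hypothetical stand-in for swarms.utils.litellm_tokenizer.count_tokens, so the sketch runs without the swarms package or a real tokenizer.

def toy_count_tokens(text: str) -> int:
    # Hypothetical counter: one "token" per whitespace-separated word.
    return len(text.split())


def binary_search_truncate(text: str, target_tokens: int) -> str:
    # Same shape as Conversation._binary_search_truncate: find the longest
    # prefix whose token count does not exceed the target.
    if not text or target_tokens <= 0:
        return ""
    if toy_count_tokens(text) <= target_tokens:
        return text
    left, right = 0, len(text)
    best_text = ""
    while left <= right:
        mid = (left + right) // 2
        candidate = text[:mid]
        if toy_count_tokens(candidate) <= target_tokens:
            # This prefix fits, so try a longer one
            best_text = candidate
            left = mid + 1
        else:
            # Too long, so try a shorter prefix
            right = mid - 1
    return best_text


if __name__ == "__main__":
    sample = "one two three four five six seven eight"
    # Longest prefix within a 3-token budget: "one two three "
    print(repr(binary_search_truncate(sample, 3)))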