fix bugs in truncate_memory_with_tokenizer

pull/995/head
王祥宇 1 month ago
parent d6ef64eb4a
commit 932a3e7a47

@@ -184,6 +184,7 @@ class Conversation:
        system_prompt: Optional[str] = None,
        time_enabled: bool = False,
        autosave: bool = False,  # Changed default to False
        save_enabled: bool = False,  # New parameter to control if saving is enabled
        save_filepath: str = None,
        load_filepath: str = None,  # New parameter to specify which file to load from
        context_length: int = 8192,
@@ -222,6 +223,7 @@ class Conversation:
        self.system_prompt = system_prompt
        self.time_enabled = time_enabled
        self.autosave = autosave
        self.save_enabled = save_enabled
        self.conversations_dir = conversations_dir
        self.tokenizer_model_name = tokenizer_model_name
        self.message_id_on = message_id_on
@@ -1019,6 +1021,13 @@ class Conversation:
            )
            return

        # Don't save if saving is disabled
        if not self.save_enabled:
            logger.warning(
                "An attempt to save the conversation failed: save_enabled is False. "
                "Please set save_enabled=True when creating a Conversation object "
                "to enable saving."
            )
            return

        # Get the full data including metadata and conversation history
        data = self.get_init_params()
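
For reference, a minimal sketch of how the new save_enabled flag behaves (the import path and file name are assumptions, not taken from this diff):

    from swarms.structs.conversation import Conversation  # assumed import path

    # Default: save_enabled stays False, so the guard above only logs a warning.
    convo = Conversation(system_prompt="You are a helpful assistant.")

    # Opting in: saving actually proceeds past the guard.
    convo = Conversation(
        system_prompt="You are a helpful assistant.",
        save_enabled=True,
        save_filepath="conversation_history.json",  # hypothetical path
    )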
@@ -1267,39 +1276,121 @@ class Conversation:
    def truncate_memory_with_tokenizer(self):
        """
        Truncate the conversation history based on the total token count using a tokenizer.
        This version is generic: it does not depend on a specific LLM model and works with
        any model for which count_tokens can report a count. It truncates message by
        message, ensuring the result is still valid content.

        Returns:
            None
        """
        from swarms.utils.litellm_tokenizer import count_tokens

        total_tokens = 0
        truncated_history = []

        for message in self.conversation_history:
            role = message.get("role")
            content = message.get("content")

            # Convert content to string if it's not already a string
            if not isinstance(content, str):
                content = str(content)

            # Calculate the token count for this message
            token_count = count_tokens(content, self.tokenizer_model_name)

            # Check if adding this message would exceed the limit
            if total_tokens + token_count <= self.context_length:
                # Within the limit: keep the full message
                truncated_history.append(message)
                total_tokens += token_count
            else:
                # Calculate how many tokens we can still include
                remaining_tokens = self.context_length - total_tokens

                # If no token space is left, stop processing
                if remaining_tokens <= 0:
                    break

                # Use binary search to find the longest prefix of the content
                # that fits in the remaining token budget
                truncated_content = self._binary_search_truncate(
                    content,
                    remaining_tokens,
                    self.tokenizer_model_name,
                )

                # Create the truncated message
                truncated_message = {
                    "role": role,
                    "content": truncated_content,
                }

                # Preserve any other fields from the original message
                for key, value in message.items():
                    if key not in ["role", "content"]:
                        truncated_message[key] = value

                truncated_history.append(truncated_message)
                break

        # Update the conversation history
        self.conversation_history = truncated_history
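
A quick usage sketch of the rewritten truncation (the add() helper and the token numbers are illustrative assumptions):

    convo = Conversation(
        tokenizer_model_name="gpt-4o",  # any model that count_tokens understands
        context_length=50,              # deliberately tiny for the example
    )
    convo.add("user", "one very long message " * 200)  # far more than 50 tokens
    convo.add("assistant", "a short reply")

    convo.truncate_memory_with_tokenizer()
    # The first message is cut to a <= 50-token prefix via binary search and the
    # loop breaks, so the later reply is dropped rather than partially kept.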
    def _binary_search_truncate(self, text, target_tokens, model_name):
        """
        Use binary search to find the longest prefix of the text that fits
        within the target token count.

        Parameters:
            text (str): Original text to truncate
            target_tokens (int): Target token count
            model_name (str): Model name used for token counting

        Returns:
            str: Truncated text whose token count does not exceed target_tokens
        """
        from swarms.utils.litellm_tokenizer import count_tokens

        # If the text is empty or the target is zero, return an empty string
        if not text or target_tokens <= 0:
            return ""

        # If the full text already fits within the target, return it as-is
        original_tokens = count_tokens(text, model_name)
        if original_tokens <= target_tokens:
            return text

        # Binary search over the prefix length
        left, right = 0, len(text)
        best_text = ""

        while left <= right:
            mid = (left + right) // 2
            truncated = text[:mid]
            tokens = count_tokens(truncated, model_name)

            if tokens <= target_tokens:
                # This prefix fits the target; try a longer one
                best_text = truncated
                left = mid + 1
            else:
                # Too long; try a shorter prefix
                right = mid - 1

        # Prefer cutting at a sentence boundary when doing so keeps at least
        # 75% of the found prefix, so we don't lose too much content
        sentence_delimiters = [".", "!", "?", "\n"]
        for delimiter in sentence_delimiters:
            last_pos = best_text.rfind(delimiter)
            if last_pos > len(best_text) * 0.75:
                truncated_at_sentence = best_text[: last_pos + 1]
                if count_tokens(truncated_at_sentence, model_name) <= target_tokens:
                    return truncated_at_sentence

        return best_text
    def clear(self):
        """Clear the conversation history."""
        if self.backend_instance:
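
The helper issues O(log n) count_tokens calls rather than one per character. A self-contained toy version (a whitespace tokenizer standing in for the litellm one) demonstrates the same search:

    def toy_count_tokens(text: str) -> int:
        # Stand-in tokenizer: one token per whitespace-separated word.
        return len(text.split())

    def binary_search_truncate(text: str, target_tokens: int) -> str:
        # Same prefix search as _binary_search_truncate above, minus the
        # sentence-boundary pass, using the toy tokenizer.
        if not text or target_tokens <= 0:
            return ""
        if toy_count_tokens(text) <= target_tokens:
            return text
        left, right = 0, len(text)
        best_text = ""
        while left <= right:
            mid = (left + right) // 2
            candidate = text[:mid]
            if toy_count_tokens(candidate) <= target_tokens:
                best_text = candidate  # fits: try a longer prefix
                left = mid + 1
            else:
                right = mid - 1  # too long: try a shorter prefix
        return best_text

    print(binary_search_truncate("one two three four", 2))  # -> "one two "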
