@@ -184,6 +184,7 @@ class Conversation:
        system_prompt: Optional[str] = None,
        time_enabled: bool = False,
        autosave: bool = False,  # Changed default to False
        save_enabled: bool = False,  # New parameter to control if saving is enabled
        save_filepath: Optional[str] = None,
        load_filepath: Optional[str] = None,  # New parameter to specify which file to load from
        context_length: int = 8192,
@@ -222,6 +223,7 @@ class Conversation:
        self.system_prompt = system_prompt
        self.time_enabled = time_enabled
        self.autosave = autosave
        self.save_enabled = save_enabled
        self.conversations_dir = conversations_dir
        self.tokenizer_model_name = tokenizer_model_name
        self.message_id_on = message_id_on
@@ -1019,6 +1021,13 @@ class Conversation:
            )
            return

        # Don't save if saving is disabled
        if not self.save_enabled:
            logger.warning(
                "An attempt to save the conversation failed: save_enabled is False. "
                "Please set save_enabled=True when creating a Conversation object to enable saving."
            )
            return

        # Get the full data including metadata and conversation history
        data = self.get_init_params()
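# --- Illustrative usage sketch (not part of the diff) -----------------------
# How the new opt-in persistence is meant to be used: with the changed defaults
# (autosave=False, save_enabled=False) nothing is written to disk, and a save
# attempt only logs the warning above. The add() call is assumed from the
# existing Conversation API; the method that triggers a save is not shown in
# this hunk, so it is omitted here.

conversation = Conversation(
    save_enabled=True,                  # opt in to persistence
    save_filepath="chat_history.json",  # hypothetical example path
)
conversation.add("user", "Hello!")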
@@ -1267,39 +1276,121 @@ class Conversation:

    def truncate_memory_with_tokenizer(self):
        """
        Truncate the conversation history based on the total token count, using a tokenizer.

        This version is generic: it does not depend on a specific LLM and works with any
        model for which count_tokens can produce a count. It counts and truncates message
        by message, so the result is still valid content.

        Returns:
            None
        """
        from swarms.utils.litellm_tokenizer import count_tokens

        total_tokens = 0
        truncated_history = []

        for message in self.conversation_history:
            role = message.get("role")
            content = message.get("content")

            # Convert content to a string if it isn't one already
            if not isinstance(content, str):
                content = str(content)

            # Calculate the token count for this message
            token_count = count_tokens(content, self.tokenizer_model_name)

            # Check if adding this message would exceed the limit
            if total_tokens + token_count <= self.context_length:
                # Within the limit: keep the full message
                truncated_history.append(message)
                total_tokens += token_count
            else:
                # Calculate how many tokens we can still include
                remaining_tokens = self.context_length - total_tokens

                # If no token space is left, stop
                if remaining_tokens <= 0:
                    break

                # Some space is left, so truncate this message: use binary search
                # to find the longest prefix that fits the remaining token space
                truncated_content = self._binary_search_truncate(
                    content,
                    remaining_tokens,
                    self.tokenizer_model_name,
                )

                # Create the truncated message
                truncated_message = {
                    "role": role,
                    "content": truncated_content,
                }

                # Preserve any other fields from the original message
                for key, value in message.items():
                    if key not in ["role", "content"]:
                        truncated_message[key] = value

                truncated_history.append(truncated_message)
                break

        # Update the conversation history
        self.conversation_history = truncated_history
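# --- Illustrative sketch (not part of the diff) ------------------------------
# The loop above is greedy: whole messages are kept while they fit, and only the
# first message that overflows the budget is shortened, after which iteration
# stops. A self-contained rehearsal of that accounting, using a whitespace
# "tokenizer" as a stand-in for count_tokens (names here are hypothetical):

def _demo_truncate(history, context_length):
    total, kept = 0, []
    for msg in history:
        n = len(str(msg["content"]).split())  # stand-in token counter
        if total + n <= context_length:
            kept.append(msg)          # message fits: keep it whole
            total += n
        else:
            remaining = context_length - total
            if remaining > 0:         # shorten the overflowing message
                words = str(msg["content"]).split()[:remaining]
                kept.append({**msg, "content": " ".join(words)})
            break                     # nothing after the overflow is kept
    return kept

history = [
    {"role": "user", "content": "one two three"},
    {"role": "assistant", "content": "four five six seven"},
]
assert _demo_truncate(history, 5)[-1]["content"] == "four five"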
    def _binary_search_truncate(self, text, target_tokens, model_name):
        """
        Use binary search to find the longest prefix of the text that fits within the target token count.

        Parameters:
            text (str): Original text to truncate
            target_tokens (int): Target token count
            model_name (str): Model name for token counting

        Returns:
            str: Truncated text whose token count does not exceed target_tokens
        """
        from swarms.utils.litellm_tokenizer import count_tokens

        # If the text is empty or no token budget remains, return an empty string
        if not text or target_tokens <= 0:
            return ""

        # If the original text already fits within the target, return it as is
        original_tokens = count_tokens(text, model_name)
        if original_tokens <= target_tokens:
            return text

        # Binary search over the prefix length
        left, right = 0, len(text)
        best_text = ""

        while left <= right:
            mid = (left + right) // 2
            truncated = text[:mid]
            tokens = count_tokens(truncated, model_name)

            if tokens <= target_tokens:
                # The current prefix fits; try a longer one
                best_text = truncated
                left = mid + 1
            else:
                # The current prefix is too long; try a shorter one
                right = mid - 1

        # Prefer to cut at a sentence boundary when one falls near the end,
        # so we don't lose too much content
        sentence_delimiters = [".", "!", "?", "\n"]
        for delimiter in sentence_delimiters:
            last_pos = best_text.rfind(delimiter)
            if last_pos > len(best_text) * 0.75:
                truncated_at_sentence = best_text[: last_pos + 1]
                if count_tokens(truncated_at_sentence, model_name) <= target_tokens:
                    return truncated_at_sentence

        return best_text
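# --- Illustrative sketch (not part of the diff) ------------------------------
# _binary_search_truncate needs only O(log len(text)) calls to count_tokens,
# since every iteration halves the candidate prefix range; a linear scan would
# need one call per character. A standalone rehearsal with a fake counter that
# charges one token per 4 characters (names here are hypothetical):

def _demo_binary_truncate(text, target_tokens, count=lambda s: (len(s) + 3) // 4):
    left, right, best = 0, len(text), ""
    while left <= right:
        mid = (left + right) // 2
        if count(text[:mid]) <= target_tokens:
            best, left = text[:mid], mid + 1  # prefix fits; try a longer one
        else:
            right = mid - 1                   # too long; try a shorter one
    return best

# 20 characters is the longest prefix costing <= 5 fake tokens (ceil(20/4) == 5)
assert len(_demo_binary_truncate("x" * 100, 5)) == 20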
    def clear(self):
        """Clear the conversation history."""
        if self.backend_instance: