|
|
|
@ -6,7 +6,6 @@ import threading
|
|
|
|
|
import uuid
|
|
|
|
|
from typing import (
|
|
|
|
|
TYPE_CHECKING,
|
|
|
|
|
Callable,
|
|
|
|
|
Dict,
|
|
|
|
|
List,
|
|
|
|
|
Optional,
|
|
|
|
@ -190,18 +189,16 @@ class Conversation(BaseStructure):
|
|
|
|
|
save_enabled: bool = False, # New parameter to control if saving is enabled
|
|
|
|
|
save_filepath: str = None,
|
|
|
|
|
load_filepath: str = None, # New parameter to specify which file to load from
|
|
|
|
|
tokenizer: Callable = None,
|
|
|
|
|
context_length: int = 8192,
|
|
|
|
|
rules: str = None,
|
|
|
|
|
custom_rules_prompt: str = None,
|
|
|
|
|
user: str = "User:",
|
|
|
|
|
user: str = "User",
|
|
|
|
|
save_as_yaml: bool = False,
|
|
|
|
|
save_as_json_bool: bool = False,
|
|
|
|
|
token_count: bool = True,
|
|
|
|
|
token_count: bool = False,
|
|
|
|
|
message_id_on: bool = False,
|
|
|
|
|
provider: providers = "in-memory",
|
|
|
|
|
backend: Optional[str] = None,
|
|
|
|
|
# Backend-specific parameters
|
|
|
|
|
supabase_url: Optional[str] = None,
|
|
|
|
|
supabase_key: Optional[str] = None,
|
|
|
|
|
redis_host: str = "localhost",
|
|
|
|
@ -210,7 +207,6 @@ class Conversation(BaseStructure):
|
|
|
|
|
redis_password: Optional[str] = None,
|
|
|
|
|
db_path: Optional[str] = None,
|
|
|
|
|
table_name: str = "conversations",
|
|
|
|
|
# Additional backend parameters
|
|
|
|
|
use_embedded_redis: bool = True,
|
|
|
|
|
persist_redis: bool = True,
|
|
|
|
|
auto_persist: bool = True,
|
|
|
|
@ -230,20 +226,7 @@ class Conversation(BaseStructure):
|
|
|
|
|
self.save_enabled = save_enabled
|
|
|
|
|
self.conversations_dir = conversations_dir
|
|
|
|
|
self.message_id_on = message_id_on
|
|
|
|
|
|
|
|
|
|
# Handle save filepath
|
|
|
|
|
if save_enabled and save_filepath:
|
|
|
|
|
self.save_filepath = save_filepath
|
|
|
|
|
elif save_enabled and conversations_dir:
|
|
|
|
|
self.save_filepath = os.path.join(
|
|
|
|
|
conversations_dir, f"{self.id}.json"
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
self.save_filepath = None
|
|
|
|
|
|
|
|
|
|
self.load_filepath = load_filepath
|
|
|
|
|
self.conversation_history = []
|
|
|
|
|
self.tokenizer = tokenizer
|
|
|
|
|
self.context_length = context_length
|
|
|
|
|
self.rules = rules
|
|
|
|
|
self.custom_rules_prompt = custom_rules_prompt
|
|
|
|
@ -253,9 +236,40 @@ class Conversation(BaseStructure):
|
|
|
|
|
self.token_count = token_count
|
|
|
|
|
self.provider = provider # Keep for backwards compatibility
|
|
|
|
|
self.conversations_dir = conversations_dir
|
|
|
|
|
self.backend = backend
|
|
|
|
|
self.supabase_url = supabase_url
|
|
|
|
|
self.supabase_key = supabase_key
|
|
|
|
|
self.redis_host = redis_host
|
|
|
|
|
self.redis_port = redis_port
|
|
|
|
|
self.redis_db = redis_db
|
|
|
|
|
self.redis_password = redis_password
|
|
|
|
|
self.db_path = db_path
|
|
|
|
|
self.table_name = table_name
|
|
|
|
|
self.use_embedded_redis = use_embedded_redis
|
|
|
|
|
self.persist_redis = persist_redis
|
|
|
|
|
self.auto_persist = auto_persist
|
|
|
|
|
self.redis_data_dir = redis_data_dir
|
|
|
|
|
|
|
|
|
|
self.conversation_history = []
|
|
|
|
|
|
|
|
|
|
# Handle save filepath
|
|
|
|
|
if save_enabled and save_filepath:
|
|
|
|
|
self.save_filepath = save_filepath
|
|
|
|
|
elif save_enabled and conversations_dir:
|
|
|
|
|
self.save_filepath = os.path.join(
|
|
|
|
|
conversations_dir, f"{self.id}.json"
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
self.save_filepath = None
|
|
|
|
|
|
|
|
|
|
# Support both 'provider' and 'backend' parameters for backwards compatibility
|
|
|
|
|
# 'backend' takes precedence if both are provided
|
|
|
|
|
|
|
|
|
|
self.backend_setup(backend, provider)
|
|
|
|
|
|
|
|
|
|
def backend_setup(
|
|
|
|
|
self, backend: str = None, provider: str = None
|
|
|
|
|
):
|
|
|
|
|
self.backend = backend or provider
|
|
|
|
|
self.backend_instance = None
|
|
|
|
|
|
|
|
|
@ -285,19 +299,18 @@ class Conversation(BaseStructure):
|
|
|
|
|
]:
|
|
|
|
|
try:
|
|
|
|
|
self._initialize_backend(
|
|
|
|
|
supabase_url=supabase_url,
|
|
|
|
|
supabase_key=supabase_key,
|
|
|
|
|
redis_host=redis_host,
|
|
|
|
|
redis_port=redis_port,
|
|
|
|
|
redis_db=redis_db,
|
|
|
|
|
redis_password=redis_password,
|
|
|
|
|
db_path=db_path,
|
|
|
|
|
table_name=table_name,
|
|
|
|
|
use_embedded_redis=use_embedded_redis,
|
|
|
|
|
persist_redis=persist_redis,
|
|
|
|
|
auto_persist=auto_persist,
|
|
|
|
|
redis_data_dir=redis_data_dir,
|
|
|
|
|
**kwargs,
|
|
|
|
|
supabase_url=self.supabase_url,
|
|
|
|
|
supabase_key=self.supabase_key,
|
|
|
|
|
redis_host=self.redis_host,
|
|
|
|
|
redis_port=self.redis_port,
|
|
|
|
|
redis_db=self.redis_db,
|
|
|
|
|
redis_password=self.redis_password,
|
|
|
|
|
db_path=self.db_path,
|
|
|
|
|
table_name=self.table_name,
|
|
|
|
|
use_embedded_redis=self.use_embedded_redis,
|
|
|
|
|
persist_redis=self.persist_redis,
|
|
|
|
|
auto_persist=self.auto_persist,
|
|
|
|
|
redis_data_dir=self.redis_data_dir,
|
|
|
|
|
)
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.warning(
|
|
|
|
@ -324,7 +337,6 @@ class Conversation(BaseStructure):
|
|
|
|
|
"time_enabled": self.time_enabled,
|
|
|
|
|
"autosave": self.autosave,
|
|
|
|
|
"save_filepath": self.save_filepath,
|
|
|
|
|
"tokenizer": self.tokenizer,
|
|
|
|
|
"context_length": self.context_length,
|
|
|
|
|
"rules": self.rules,
|
|
|
|
|
"custom_rules_prompt": self.custom_rules_prompt,
|
|
|
|
@ -449,8 +461,8 @@ class Conversation(BaseStructure):
|
|
|
|
|
if self.custom_rules_prompt is not None:
|
|
|
|
|
self.add(self.user or "User", self.custom_rules_prompt)
|
|
|
|
|
|
|
|
|
|
if self.tokenizer is not None:
|
|
|
|
|
self.truncate_memory_with_tokenizer()
|
|
|
|
|
# if self.tokenizer is not None:
|
|
|
|
|
# self.truncate_memory_with_tokenizer()
|
|
|
|
|
|
|
|
|
|
def _autosave(self):
|
|
|
|
|
"""Automatically save the conversation if autosave is enabled."""
|
|
|
|
@ -950,6 +962,10 @@ class Conversation(BaseStructure):
|
|
|
|
|
|
|
|
|
|
# Don't save if saving is disabled
|
|
|
|
|
if not self.save_enabled:
|
|
|
|
|
logger.warning(
|
|
|
|
|
"An attempt to save the conversation failed: save_enabled is False."
|
|
|
|
|
"Please set save_enabled=True when creating a Conversation object to enable saving."
|
|
|
|
|
)
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
save_path = filename or self.save_filepath
|
|
|
|
@ -1051,9 +1067,7 @@ class Conversation(BaseStructure):
|
|
|
|
|
for message in self.conversation_history:
|
|
|
|
|
role = message.get("role")
|
|
|
|
|
content = message.get("content")
|
|
|
|
|
tokens = self.tokenizer.count_tokens(
|
|
|
|
|
text=content
|
|
|
|
|
) # Count the number of tokens
|
|
|
|
|
tokens = count_tokens(content)
|
|
|
|
|
count = tokens # Assign the token count
|
|
|
|
|
total_tokens += count
|
|
|
|
|
|
|
|
|
|