fix save_as_json

pull/989/head
王祥宇 2 months ago
parent 440173817d
commit ac91191c4b

@ -6,7 +6,6 @@ import threading
import uuid import uuid
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Callable,
Dict, Dict,
List, List,
Optional, Optional,
@ -190,18 +189,16 @@ class Conversation(BaseStructure):
save_enabled: bool = False, # New parameter to control if saving is enabled save_enabled: bool = False, # New parameter to control if saving is enabled
save_filepath: str = None, save_filepath: str = None,
load_filepath: str = None, # New parameter to specify which file to load from load_filepath: str = None, # New parameter to specify which file to load from
tokenizer: Callable = None,
context_length: int = 8192, context_length: int = 8192,
rules: str = None, rules: str = None,
custom_rules_prompt: str = None, custom_rules_prompt: str = None,
user: str = "User:", user: str = "User",
save_as_yaml: bool = False, save_as_yaml: bool = False,
save_as_json_bool: bool = False, save_as_json_bool: bool = False,
token_count: bool = True, token_count: bool = False,
message_id_on: bool = False, message_id_on: bool = False,
provider: providers = "in-memory", provider: providers = "in-memory",
backend: Optional[str] = None, backend: Optional[str] = None,
# Backend-specific parameters
supabase_url: Optional[str] = None, supabase_url: Optional[str] = None,
supabase_key: Optional[str] = None, supabase_key: Optional[str] = None,
redis_host: str = "localhost", redis_host: str = "localhost",
@ -210,7 +207,6 @@ class Conversation(BaseStructure):
redis_password: Optional[str] = None, redis_password: Optional[str] = None,
db_path: Optional[str] = None, db_path: Optional[str] = None,
table_name: str = "conversations", table_name: str = "conversations",
# Additional backend parameters
use_embedded_redis: bool = True, use_embedded_redis: bool = True,
persist_redis: bool = True, persist_redis: bool = True,
auto_persist: bool = True, auto_persist: bool = True,
@ -230,20 +226,7 @@ class Conversation(BaseStructure):
self.save_enabled = save_enabled self.save_enabled = save_enabled
self.conversations_dir = conversations_dir self.conversations_dir = conversations_dir
self.message_id_on = message_id_on self.message_id_on = message_id_on
# Handle save filepath
if save_enabled and save_filepath:
self.save_filepath = save_filepath
elif save_enabled and conversations_dir:
self.save_filepath = os.path.join(
conversations_dir, f"{self.id}.json"
)
else:
self.save_filepath = None
self.load_filepath = load_filepath self.load_filepath = load_filepath
self.conversation_history = []
self.tokenizer = tokenizer
self.context_length = context_length self.context_length = context_length
self.rules = rules self.rules = rules
self.custom_rules_prompt = custom_rules_prompt self.custom_rules_prompt = custom_rules_prompt
@ -253,9 +236,40 @@ class Conversation(BaseStructure):
self.token_count = token_count self.token_count = token_count
self.provider = provider # Keep for backwards compatibility self.provider = provider # Keep for backwards compatibility
self.conversations_dir = conversations_dir self.conversations_dir = conversations_dir
self.backend = backend
self.supabase_url = supabase_url
self.supabase_key = supabase_key
self.redis_host = redis_host
self.redis_port = redis_port
self.redis_db = redis_db
self.redis_password = redis_password
self.db_path = db_path
self.table_name = table_name
self.use_embedded_redis = use_embedded_redis
self.persist_redis = persist_redis
self.auto_persist = auto_persist
self.redis_data_dir = redis_data_dir
self.conversation_history = []
# Handle save filepath
if save_enabled and save_filepath:
self.save_filepath = save_filepath
elif save_enabled and conversations_dir:
self.save_filepath = os.path.join(
conversations_dir, f"{self.id}.json"
)
else:
self.save_filepath = None
# Support both 'provider' and 'backend' parameters for backwards compatibility # Support both 'provider' and 'backend' parameters for backwards compatibility
# 'backend' takes precedence if both are provided # 'backend' takes precedence if both are provided
self.backend_setup(backend, provider)
def backend_setup(
self, backend: str = None, provider: str = None
):
self.backend = backend or provider self.backend = backend or provider
self.backend_instance = None self.backend_instance = None
@ -285,19 +299,18 @@ class Conversation(BaseStructure):
]: ]:
try: try:
self._initialize_backend( self._initialize_backend(
supabase_url=supabase_url, supabase_url=self.supabase_url,
supabase_key=supabase_key, supabase_key=self.supabase_key,
redis_host=redis_host, redis_host=self.redis_host,
redis_port=redis_port, redis_port=self.redis_port,
redis_db=redis_db, redis_db=self.redis_db,
redis_password=redis_password, redis_password=self.redis_password,
db_path=db_path, db_path=self.db_path,
table_name=table_name, table_name=self.table_name,
use_embedded_redis=use_embedded_redis, use_embedded_redis=self.use_embedded_redis,
persist_redis=persist_redis, persist_redis=self.persist_redis,
auto_persist=auto_persist, auto_persist=self.auto_persist,
redis_data_dir=redis_data_dir, redis_data_dir=self.redis_data_dir,
**kwargs,
) )
except Exception as e: except Exception as e:
logger.warning( logger.warning(
@ -324,7 +337,6 @@ class Conversation(BaseStructure):
"time_enabled": self.time_enabled, "time_enabled": self.time_enabled,
"autosave": self.autosave, "autosave": self.autosave,
"save_filepath": self.save_filepath, "save_filepath": self.save_filepath,
"tokenizer": self.tokenizer,
"context_length": self.context_length, "context_length": self.context_length,
"rules": self.rules, "rules": self.rules,
"custom_rules_prompt": self.custom_rules_prompt, "custom_rules_prompt": self.custom_rules_prompt,
@ -449,8 +461,8 @@ class Conversation(BaseStructure):
if self.custom_rules_prompt is not None: if self.custom_rules_prompt is not None:
self.add(self.user or "User", self.custom_rules_prompt) self.add(self.user or "User", self.custom_rules_prompt)
if self.tokenizer is not None: # if self.tokenizer is not None:
self.truncate_memory_with_tokenizer() # self.truncate_memory_with_tokenizer()
def _autosave(self): def _autosave(self):
"""Automatically save the conversation if autosave is enabled.""" """Automatically save the conversation if autosave is enabled."""
@ -950,6 +962,10 @@ class Conversation(BaseStructure):
# Don't save if saving is disabled # Don't save if saving is disabled
if not self.save_enabled: if not self.save_enabled:
logger.warning(
"An attempt to save the conversation failed: save_enabled is False."
"Please set save_enabled=True when creating a Conversation object to enable saving."
)
return return
save_path = filename or self.save_filepath save_path = filename or self.save_filepath
@ -1051,9 +1067,7 @@ class Conversation(BaseStructure):
for message in self.conversation_history: for message in self.conversation_history:
role = message.get("role") role = message.get("role")
content = message.get("content") content = message.get("content")
tokens = self.tokenizer.count_tokens( tokens = count_tokens(content)
text=content
) # Count the number of tokens
count = tokens # Assign the token count count = tokens # Assign the token count
total_tokens += count total_tokens += count

Loading…
Cancel
Save