diff --git a/swarms/structs/conversation.py b/swarms/structs/conversation.py index 82493f38..2d88189c 100644 --- a/swarms/structs/conversation.py +++ b/swarms/structs/conversation.py @@ -6,7 +6,6 @@ import threading import uuid from typing import ( TYPE_CHECKING, - Callable, Dict, List, Optional, @@ -190,18 +189,16 @@ class Conversation(BaseStructure): save_enabled: bool = False, # New parameter to control if saving is enabled save_filepath: str = None, load_filepath: str = None, # New parameter to specify which file to load from - tokenizer: Callable = None, context_length: int = 8192, rules: str = None, custom_rules_prompt: str = None, - user: str = "User:", + user: str = "User", save_as_yaml: bool = False, save_as_json_bool: bool = False, - token_count: bool = True, + token_count: bool = False, message_id_on: bool = False, provider: providers = "in-memory", backend: Optional[str] = None, - # Backend-specific parameters supabase_url: Optional[str] = None, supabase_key: Optional[str] = None, redis_host: str = "localhost", @@ -210,7 +207,6 @@ class Conversation(BaseStructure): redis_password: Optional[str] = None, db_path: Optional[str] = None, table_name: str = "conversations", - # Additional backend parameters use_embedded_redis: bool = True, persist_redis: bool = True, auto_persist: bool = True, @@ -230,20 +226,7 @@ class Conversation(BaseStructure): self.save_enabled = save_enabled self.conversations_dir = conversations_dir self.message_id_on = message_id_on - - # Handle save filepath - if save_enabled and save_filepath: - self.save_filepath = save_filepath - elif save_enabled and conversations_dir: - self.save_filepath = os.path.join( - conversations_dir, f"{self.id}.json" - ) - else: - self.save_filepath = None - self.load_filepath = load_filepath - self.conversation_history = [] - self.tokenizer = tokenizer self.context_length = context_length self.rules = rules self.custom_rules_prompt = custom_rules_prompt @@ -253,9 +236,40 @@ class Conversation(BaseStructure): self.token_count = token_count self.provider = provider # Keep for backwards compatibility self.conversations_dir = conversations_dir + self.backend = backend + self.supabase_url = supabase_url + self.supabase_key = supabase_key + self.redis_host = redis_host + self.redis_port = redis_port + self.redis_db = redis_db + self.redis_password = redis_password + self.db_path = db_path + self.table_name = table_name + self.use_embedded_redis = use_embedded_redis + self.persist_redis = persist_redis + self.auto_persist = auto_persist + self.redis_data_dir = redis_data_dir + + self.conversation_history = [] + + # Handle save filepath + if save_enabled and save_filepath: + self.save_filepath = save_filepath + elif save_enabled and conversations_dir: + self.save_filepath = os.path.join( + conversations_dir, f"{self.id}.json" + ) + else: + self.save_filepath = None # Support both 'provider' and 'backend' parameters for backwards compatibility # 'backend' takes precedence if both are provided + + self.backend_setup(backend, provider) + + def backend_setup( + self, backend: str = None, provider: str = None + ): self.backend = backend or provider self.backend_instance = None @@ -285,19 +299,18 @@ class Conversation(BaseStructure): ]: try: self._initialize_backend( - supabase_url=supabase_url, - supabase_key=supabase_key, - redis_host=redis_host, - redis_port=redis_port, - redis_db=redis_db, - redis_password=redis_password, - db_path=db_path, - table_name=table_name, - use_embedded_redis=use_embedded_redis, - persist_redis=persist_redis, - auto_persist=auto_persist, - redis_data_dir=redis_data_dir, - **kwargs, + supabase_url=self.supabase_url, + supabase_key=self.supabase_key, + redis_host=self.redis_host, + redis_port=self.redis_port, + redis_db=self.redis_db, + redis_password=self.redis_password, + db_path=self.db_path, + table_name=self.table_name, + use_embedded_redis=self.use_embedded_redis, + persist_redis=self.persist_redis, + auto_persist=self.auto_persist, + redis_data_dir=self.redis_data_dir, ) except Exception as e: logger.warning( @@ -324,7 +337,6 @@ class Conversation(BaseStructure): "time_enabled": self.time_enabled, "autosave": self.autosave, "save_filepath": self.save_filepath, - "tokenizer": self.tokenizer, "context_length": self.context_length, "rules": self.rules, "custom_rules_prompt": self.custom_rules_prompt, @@ -449,8 +461,8 @@ class Conversation(BaseStructure): if self.custom_rules_prompt is not None: self.add(self.user or "User", self.custom_rules_prompt) - if self.tokenizer is not None: - self.truncate_memory_with_tokenizer() + # if self.tokenizer is not None: + # self.truncate_memory_with_tokenizer() def _autosave(self): """Automatically save the conversation if autosave is enabled.""" @@ -950,6 +962,10 @@ class Conversation(BaseStructure): # Don't save if saving is disabled if not self.save_enabled: + logger.warning( + "An attempt to save the conversation failed: save_enabled is False." + "Please set save_enabled=True when creating a Conversation object to enable saving." + ) return save_path = filename or self.save_filepath @@ -1051,9 +1067,7 @@ class Conversation(BaseStructure): for message in self.conversation_history: role = message.get("role") content = message.get("content") - tokens = self.tokenizer.count_tokens( - text=content - ) # Count the number of tokens + tokens = count_tokens(content) count = tokens # Assign the token count total_tokens += count