From 6496feff109d3d86727c760bd3d68fbd70df4d4d Mon Sep 17 00:00:00 2001 From: Kye Gomez Date: Thu, 5 Jun 2025 11:30:17 -0700 Subject: [PATCH] examples re-shuffle --- docs/swarms/structs/swarm_router.md | 200 ---- .../deep_research_swarm.py | 13 + .../deep_research_swarm_example.py | 23 + examples/news_aggregator_summarizer.py | 14 +- .../reasoning_agent_router.py | 22 + .../multii_tool_use/many_tool_use_demo.py | 4 +- long_agent_example.py | 8 - reasoning_agent_router.py | 20 - swarms/communication/supabase_wrap.py | 973 +++++++++++++----- swarms/structs/__init__.py | 2 - swarms/structs/swarm_router.py | 86 +- .../test_supabase_conversation.py | 859 ++++++++++------ 12 files changed, 1332 insertions(+), 892 deletions(-) create mode 100644 examples/multi_agent/deep_research_examples/deep_research_swarm.py create mode 100644 examples/multi_agent/deep_research_examples/deep_research_swarm_example.py delete mode 100644 long_agent_example.py delete mode 100644 reasoning_agent_router.py diff --git a/docs/swarms/structs/swarm_router.md b/docs/swarms/structs/swarm_router.md index db7892fe..cf0a87e9 100644 --- a/docs/swarms/structs/swarm_router.md +++ b/docs/swarms/structs/swarm_router.md @@ -107,25 +107,6 @@ Main class for routing tasks to different swarm types. | `concurrent_run` | `task: str, *args, **kwargs` | Execute a task using concurrent execution | | `concurrent_batch_run` | `tasks: List[str], *args, **kwargs` | Execute multiple tasks concurrently | -## Function: swarm_router - -A convenience function to create and run a SwarmRouter instance. - -| Parameter | Type | Default | Description | -| --- | --- | --- | --- | -| `name` | str | "swarm-router" | Name of the swarm router. | -| `description` | str | "Routes your task to the desired swarm" | Description of the router. | -| `max_loops` | int | 1 | Maximum number of execution loops. | -| `agents` | List[Union[Agent, Callable]] | [] | List of agents or callables. | -| `swarm_type` | SwarmType | "SequentialWorkflow" | Type of swarm to use. | -| `autosave` | bool | False | Whether to autosave results. | -| `flow` | str | None | Flow configuration. | -| `return_json` | bool | True | Whether to return results as JSON. | -| `auto_generate_prompts` | bool | False | Whether to auto-generate prompts. | -| `task` | str | None | Task to execute. | -| `rules` | str | None | Rules to inject into every agent. | -| `*args` | | | Additional positional arguments passed to SwarmRouter.run() | -| `**kwargs` | | | Additional keyword arguments passed to SwarmRouter.run() | ## Installation @@ -564,184 +545,3 @@ result = swarm_router( task="Analyze the quarterly report" ) ``` - -I'll help you create tables for each section while maintaining their titles. I'll convert the bullet points into organized tables. - -## Best Practices - -**1. Choose the Right Swarm Type** - -| Swarm Type | Use Case | -|------------|----------| -| `SequentialWorkflow` | Tasks that need to be executed in order | -| `ConcurrentWorkflow` | Independent tasks that can run in parallel | -| `MALT` | Complex language processing tasks | -| `DeepResearchSwarm` | Comprehensive research tasks | -| `CouncilAsAJudge` | Consensus-based decision making | -| `auto` | When unsure which swarm type is best | - -**2. Configure Multi-Agent Collaboration** - -| Feature | Purpose | -|---------|----------| -| `multi_agent_collab_prompt` | Enable better agent cooperation | -| `GroupChat` swarm type | Interactive agent discussions | -| Shared memory | Share context between agents | - -**3. 
Optimize Performance** - -| Feature | Purpose | -|---------|----------| -| `output_type` | Set appropriate output format | -| `autosave` | Enable for long-running tasks | -| `concurrent_batch_run` | Process multiple independent tasks | -| `max_loops` | Set reasonable limits to prevent infinite loops | - -**4. Error Handling and Logging** - -| Feature | Purpose | -|---------|----------| -| `get_logs()` | Regular log checking | -| Error handling | Proper application error management | -| `reliability_check` | Verify system reliability | - -**5. Agent Configuration** - -| Feature | Purpose | -|---------|----------| -| System prompts | Provide clear and specific instructions | -| Model settings | Configure appropriate temperature and parameters | -| CSV loading | Manage large agent sets efficiently | - -**6. Resource Management** - -| Feature | Purpose | -|---------|----------| -| `no_cluster_ops` | Work with limited resources | -| Background tasks | Handle long-running operations | -| Resource cleanup | Proper cleanup and release | - -**7. Security and Rules** - -| Feature | Purpose | -|---------|----------| -| `rules` | Enforce agent behavior constraints | -| Authentication | Implement proper API security | -| Input validation | Sanitize and validate inputs | - -**8. Monitoring and Debugging** - -| Feature | Purpose | -|---------|----------| -| Logging system | Effective system monitoring | -| Performance monitoring | Track swarm performance | -| Error reporting | Implement proper error tracking | - -**9. Scalability** - -| Feature | Purpose | -|---------|----------| -| Horizontal scaling | Design for system expansion | -| Batch sizes | Configure appropriate concurrent operations | -| Resource limitations | Consider system constraints | - -**10. Documentation and Maintenance** - -| Feature | Purpose | -|---------|----------| -| Custom configurations | Document system extensions | -| Performance metrics | Track swarm performance | -| Agent prompts | Regular updates of prompts and rules | - -## Common Use Cases - -**1. Document Analysis** - -| Component | Purpose | -|-----------|----------| -| `DeepResearchSwarm` | Comprehensive document analysis | -| `MALT` | Complex language processing | -| `SequentialWorkflow` | Structured document processing | - -**2. Decision Making** -| Component | Purpose | -|-----------|----------| -| `CouncilAsAJudge` | Consensus-based decisions | -| `MajorityVoting` | Democratic decision making | -| `GroupChat` | Interactive decision processes | - -**3. Research and Analysis** - -| Component | Purpose | -|-----------|----------| -| `DeepResearchSwarm` | Deep research tasks | -| `MixtureOfAgents` | Combine multiple expert agents | -| `HiearchicalSwarm` | Organized research workflows | - -**4. Data Processing** - -| Component | Purpose | -|-----------|----------| -| `ConcurrentWorkflow` | Parallel data processing | -| `SpreadSheetSwarm` | Structured data operations | -| `SequentialWorkflow` | Ordered data transformations | - -**5. Collaborative Tasks** - -| Feature | Purpose | -|---------|----------| -| `multi_agent_collab_prompt` | Better cooperation | -| `GroupChat` | Interactive collaboration | -| Shared memory | Context sharing | - -## Troubleshooting - -**1. Common Issues** - -| Area | Action | -|------|---------| -| Agent configurations | Check configurations and prompts | -| Swarm type | Verify compatibility with task | -| Resource usage | Monitor limitations | - -**2. 
Performance Issues**
-
-| Area | Action |
-|------|---------|
-| Operations | Optimize batch sizes and concurrent operations |
-| Memory | Check for leaks and resource cleanup |
-| System resources | Monitor scaling |
-
-**3. Integration Issues**
-
-| Area | Action |
-|------|---------|
-| API configurations | Verify settings and authentication |
-| External services | Check compatibility |
-| Network | Monitor connectivity and timeouts |
-
-## Future Development
-
-**1. Planned Features**
-
-| Feature | Description |
-|---------|-------------|
-| Swarm types | Additional specialized tasks |
-| Collaboration | Enhanced mechanisms |
-| Performance | Improved optimizations |
-
-**2. Contributing**
-
-| Area | Action |
-|------|---------|
-| Guidelines | Follow contribution rules |
-| Bug reports | Submit issues and requests |
-| Community | Participate in discussions |
-
-**3. Roadmap**
-
-| Feature | Description |
-|---------|-------------|
-| AI models | Integration with more models |
-| Tools | Enhanced monitoring and debugging |
-| Performance | Improved scalability |
diff --git a/examples/multi_agent/deep_research_examples/deep_research_swarm.py b/examples/multi_agent/deep_research_examples/deep_research_swarm.py
new file mode 100644
index 00000000..c52d9370
--- /dev/null
+++ b/examples/multi_agent/deep_research_examples/deep_research_swarm.py
@@ -0,0 +1,13 @@
+from swarms.structs.deep_research_swarm import DeepResearchSwarm
+
+
+def main():
+    swarm = DeepResearchSwarm(
+        name="Deep Research Swarm",
+        description="A swarm of agents that can perform deep research on a given topic",
+    )
+
+    swarm.run("What is the latest news in the AI and crypto space?")
+
+
+main()
diff --git a/examples/multi_agent/deep_research_examples/deep_research_swarm_example.py b/examples/multi_agent/deep_research_examples/deep_research_swarm_example.py
new file mode 100644
index 00000000..3cc26c9e
--- /dev/null
+++ b/examples/multi_agent/deep_research_examples/deep_research_swarm_example.py
@@ -0,0 +1,23 @@
+from swarms.structs.deep_research_swarm import DeepResearchSwarm
+
+
+def main():
+    swarm = DeepResearchSwarm(
+        name="Deep Research Swarm",
+        description="A swarm of agents that can perform deep research on a given topic",
+        output_type="string",  # Change to string output type for better readability
+    )
+
+    # Format the query as a proper question
+    query = "What are the latest developments and news in the AI and cryptocurrency space?"
+ + try: + result = swarm.run(query) + print("\nResearch Results:") + print(result) + except Exception as e: + print(f"Error occurred: {str(e)}") + + +if __name__ == "__main__": + main() diff --git a/examples/news_aggregator_summarizer.py b/examples/news_aggregator_summarizer.py index ce55e956..83c89de2 100644 --- a/examples/news_aggregator_summarizer.py +++ b/examples/news_aggregator_summarizer.py @@ -16,7 +16,8 @@ def fetch_hackernews_headlines(limit: int = 5): """Fetch top headlines from Hacker News using its public API.""" try: ids = httpx.get( - "https://hacker-news.firebaseio.com/v0/topstories.json", timeout=10 + "https://hacker-news.firebaseio.com/v0/topstories.json", + timeout=10, ).json() except Exception: return [] @@ -29,7 +30,12 @@ def fetch_hackernews_headlines(limit: int = 5): ).json() except Exception: continue - headlines.append({"title": item.get("title", "No title"), "url": item.get("url", "")}) + headlines.append( + { + "title": item.get("title", "No title"), + "url": item.get("url", ""), + } + ) return headlines @@ -96,7 +102,9 @@ if __name__ == "__main__": for article in headlines: content = fetch_article_content(article["url"]) summary = summarize_article(content) - summaries.append({"title": article["title"], "summary": summary}) + summaries.append( + {"title": article["title"], "summary": summary} + ) print("\nArticle Summaries:\n") for s in summaries: diff --git a/examples/single_agent/reasoning_agent_examples/reasoning_agent_router.py b/examples/single_agent/reasoning_agent_examples/reasoning_agent_router.py index 00e7a96e..96341179 100644 --- a/examples/single_agent/reasoning_agent_examples/reasoning_agent_router.py +++ b/examples/single_agent/reasoning_agent_examples/reasoning_agent_router.py @@ -22,3 +22,25 @@ reasoning_agent_router.run( # "What is the best possible financial strategy to maximize returns but minimize risk? Give a list of etfs to invest in and the percentage of the portfolio to allocate to each etf.", # ] # ) + + +# from swarms import ReasoningAgentRouter + + +# calculus_router = ReasoningAgentRouter( +# agent_name="calculus-expert", +# description="A calculus problem solving agent", +# model_name="gpt-4o-mini", +# system_prompt="You are a calculus expert. Solve differentiation and integration problems methodically.", +# swarm_type="self-consistency", +# num_samples=3, # Generate 3 samples to ensure consistency +# output_type="list", +# ) + + +# # Example calculus problem +# calculus_problem = "Find the derivative of f(x) = x³ln(x) - 5x²" + +# # Get the solution +# solution = calculus_router.run(calculus_problem) +# print(solution) diff --git a/examples/tools/multii_tool_use/many_tool_use_demo.py b/examples/tools/multii_tool_use/many_tool_use_demo.py index 4b3d1f4c..f15369f0 100644 --- a/examples/tools/multii_tool_use/many_tool_use_demo.py +++ b/examples/tools/multii_tool_use/many_tool_use_demo.py @@ -442,7 +442,5 @@ agent = Agent( ) # agent.run("Use defi stats to find the best defi project to invest in") -agent.run( - "Get the price of bitcoin on both functions get_htx_crypto_price and get_crypto_price and also get the market sentiment for bitcoin" -) +agent.run("Get the market sentiment for bitcoin") # Automatically executes any number and combination of tools you have uploaded to the tools parameter! 
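
A side note for reviewers: the demo above registers many tool functions on a single `Agent`, and the agent picks which ones to call for each task. Below is a minimal sketch of that pattern, using hypothetical stand-in tools (the real demo defines `get_htx_crypto_price`, `get_crypto_price`, and other functions not shown in this hunk):

```python
from swarms import Agent


def get_crypto_price(coin_id: str) -> str:
    """Hypothetical stand-in tool: return the current price of a coin."""
    return f"{coin_id}: $100.00"


def get_market_sentiment(coin_id: str) -> str:
    """Hypothetical stand-in tool: return a one-line sentiment summary."""
    return f"{coin_id}: neutral"


agent = Agent(
    agent_name="crypto-agent",
    model_name="gpt-4o-mini",
    tools=[get_crypto_price, get_market_sentiment],
    max_loops=1,
)

# The agent selects and executes whichever registered tools the task needs.
agent.run("Get the market sentiment for bitcoin")
```
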
diff --git a/long_agent_example.py b/long_agent_example.py deleted file mode 100644 index bccf9608..00000000 --- a/long_agent_example.py +++ /dev/null @@ -1,8 +0,0 @@ -from swarms.structs.long_agent import LongAgent - - -if __name__ == "__main__": - long_agent = LongAgent( - token_count_per_agent=3000, output_type="final" - ) - print(long_agent.run([""])) diff --git a/reasoning_agent_router.py b/reasoning_agent_router.py deleted file mode 100644 index 31e1b6c7..00000000 --- a/reasoning_agent_router.py +++ /dev/null @@ -1,20 +0,0 @@ -from swarms import ReasoningAgentRouter - - -calculus_router = ReasoningAgentRouter( - agent_name="calculus-expert", - description="A calculus problem solving agent", - model_name="gpt-4o-mini", - system_prompt="You are a calculus expert. Solve differentiation and integration problems methodically.", - swarm_type="self-consistency", - num_samples=3, # Generate 3 samples to ensure consistency - output_type="list", -) - - -# Example calculus problem -calculus_problem = "Find the derivative of f(x) = x³ln(x) - 5x²" - -# Get the solution -solution = calculus_router.run(calculus_problem) -print(solution) diff --git a/swarms/communication/supabase_wrap.py b/swarms/communication/supabase_wrap.py index b80c21bf..321f084c 100644 --- a/swarms/communication/supabase_wrap.py +++ b/swarms/communication/supabase_wrap.py @@ -3,8 +3,6 @@ import json import logging import threading import uuid -from contextlib import contextmanager -from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Union import yaml @@ -12,6 +10,7 @@ import yaml try: from supabase import Client, create_client from postgrest import APIResponse, APIError as PostgrestAPIError + SUPABASE_AVAILABLE = True except ImportError: SUPABASE_AVAILABLE = False @@ -29,27 +28,35 @@ from swarms.communication.base_communication import ( # Try to import loguru logger, fallback to standard logging try: from loguru import logger + LOGURU_AVAILABLE = True except ImportError: LOGURU_AVAILABLE = False logger = None + # Custom Exceptions for Supabase Communication class SupabaseConnectionError(Exception): """Custom exception for Supabase connection errors.""" + pass + class SupabaseOperationError(Exception): """Custom exception for Supabase operation errors.""" + pass + class DateTimeEncoder(json.JSONEncoder): """Custom JSON encoder for handling datetime objects.""" + def default(self, obj): if isinstance(obj, datetime.datetime): return obj.isoformat() return super().default(obj) + class SupabaseConversation(BaseCommunication): """ A Supabase-backed implementation of the BaseCommunication class for managing @@ -80,21 +87,21 @@ class SupabaseConversation(BaseCommunication): system_prompt: Optional[str] = None, time_enabled: bool = False, autosave: bool = False, # Standardized parameter name - less relevant for DB-backed, but kept for interface - save_filepath: str = None, # Used for export/import + save_filepath: str = None, # Used for export/import tokenizer: Any = None, context_length: int = 8192, rules: str = None, custom_rules_prompt: str = None, user: str = "User:", - save_as_yaml: bool = True, # Default export format - save_as_json_bool: bool = False, # Alternative export format + save_as_yaml: bool = True, # Default export format + save_as_json_bool: bool = False, # Alternative export format token_count: bool = True, - cache_enabled: bool = True, # Currently for token counting + cache_enabled: bool = True, # Currently for token counting table_name: str = "conversations", - enable_timestamps: bool = 
True, # DB schema handles this with DEFAULT NOW() + enable_timestamps: bool = True, # DB schema handles this with DEFAULT NOW() enable_logging: bool = True, use_loguru: bool = True, - max_retries: int = 3, # For Supabase API calls (not implemented yet, supabase-py might handle) + max_retries: int = 3, # For Supabase API calls (not implemented yet, supabase-py might handle) *args, **kwargs, ): @@ -121,11 +128,13 @@ class SupabaseConversation(BaseCommunication): self.supabase_url = supabase_url self.supabase_key = supabase_key self.table_name = table_name - self.enable_timestamps = enable_timestamps # DB handles actual timestamping + self.enable_timestamps = ( + enable_timestamps # DB handles actual timestamping + ) self.enable_logging = enable_logging self.use_loguru = use_loguru and LOGURU_AVAILABLE self.max_retries = max_retries - + # Setup logging if self.enable_logging: if self.use_loguru and logger: @@ -146,28 +155,50 @@ class SupabaseConversation(BaseCommunication): self.logger.addHandler(logging.NullHandler()) self.current_conversation_id: Optional[str] = None - self._lock = threading.Lock() # For thread-safe operations if any (e.g. token calculation) + self._lock = ( + threading.Lock() + ) # For thread-safe operations if any (e.g. token calculation) try: - self.client: Client = create_client(supabase_url, supabase_key) + self.client: Client = create_client( + supabase_url, supabase_key + ) if self.enable_logging: - self.logger.info(f"Successfully initialized Supabase client for URL: {supabase_url}") + self.logger.info( + f"Successfully initialized Supabase client for URL: {supabase_url}" + ) except Exception as e: if self.enable_logging: - self.logger.error(f"Failed to initialize Supabase client: {e}") - raise SupabaseConnectionError(f"Failed to connect to Supabase: {e}") + self.logger.error( + f"Failed to initialize Supabase client: {e}" + ) + raise SupabaseConnectionError( + f"Failed to connect to Supabase: {e}" + ) self._init_db() # Verifies table existence - self.start_new_conversation() # Initializes a conversation ID + self.start_new_conversation() # Initializes a conversation ID # Add initial prompts if provided if self.system_prompt: - self.add(role="system", content=self.system_prompt, message_type=MessageType.SYSTEM) + self.add( + role="system", + content=self.system_prompt, + message_type=MessageType.SYSTEM, + ) if self.rules: # Assuming rules are spoken by the system or user based on context - self.add(role="system", content=self.rules, message_type=MessageType.SYSTEM) + self.add( + role="system", + content=self.rules, + message_type=MessageType.SYSTEM, + ) if self.custom_rules_prompt: - self.add(role=self.user, content=self.custom_rules_prompt, message_type=MessageType.USER) + self.add( + role=self.user, + content=self.custom_rules_prompt, + message_type=MessageType.USER, + ) def _init_db(self): """ @@ -190,47 +221,77 @@ class SupabaseConversation(BaseCommunication): created_at TIMESTAMPTZ DEFAULT NOW() ); """ - + # Try to create index as well create_index_sql = f""" CREATE INDEX IF NOT EXISTS idx_{self.table_name}_conversation_id ON {self.table_name} (conversation_id); """ - + # Attempt to create table using RPC function # Note: This requires a stored procedure to be created in Supabase # If RPC is not available, we'll fall back to checking if table exists try: # Try using a custom RPC function if available - self.client.rpc('exec_sql', {'sql': create_table_sql}).execute() + self.client.rpc( + "exec_sql", {"sql": create_table_sql} + ).execute() if self.enable_logging: 
- self.logger.info(f"Successfully created or verified table '{self.table_name}' using RPC.") + self.logger.info( + f"Successfully created or verified table '{self.table_name}' using RPC." + ) except Exception as rpc_error: if self.enable_logging: - self.logger.debug(f"RPC table creation failed (expected if no custom function): {rpc_error}") - + self.logger.debug( + f"RPC table creation failed (expected if no custom function): {rpc_error}" + ) + # Fallback: Try to verify table exists, if not provide helpful error try: - response = self.client.table(self.table_name).select("id").limit(1).execute() - if response.error and "does not exist" in str(response.error).lower(): + response = ( + self.client.table(self.table_name) + .select("id") + .limit(1) + .execute() + ) + if ( + response.error + and "does not exist" + in str(response.error).lower() + ): # Table doesn't exist, try alternative creation method self._create_table_fallback() elif response.error: - raise SupabaseOperationError(f"Error accessing table: {response.error.message}") + raise SupabaseOperationError( + f"Error accessing table: {response.error.message}" + ) else: if self.enable_logging: - self.logger.info(f"Successfully verified existing table '{self.table_name}'.") + self.logger.info( + f"Successfully verified existing table '{self.table_name}'." + ) except Exception as table_check_error: - if "does not exist" in str(table_check_error).lower() or "relation" in str(table_check_error).lower(): + if ( + "does not exist" + in str(table_check_error).lower() + or "relation" + in str(table_check_error).lower() + ): # Table definitely doesn't exist, provide creation instructions self._handle_missing_table() else: - raise SupabaseOperationError(f"Failed to access or create table: {table_check_error}") - + raise SupabaseOperationError( + f"Failed to access or create table: {table_check_error}" + ) + except Exception as e: if self.enable_logging: - self.logger.error(f"Database initialization failed: {e}") - raise SupabaseOperationError(f"Failed to initialize database: {e}") + self.logger.error( + f"Database initialization failed: {e}" + ) + raise SupabaseOperationError( + f"Failed to initialize database: {e}" + ) def _create_table_fallback(self): """ @@ -255,18 +316,26 @@ class SupabaseConversation(BaseCommunication): CREATE INDEX IF NOT EXISTS idx_{self.table_name}_conversation_id ON {self.table_name} (conversation_id); """ - + # Note: This might not work with all Supabase configurations # but we attempt it anyway - if hasattr(self.client, 'postgrest') and hasattr(self.client.postgrest, 'rpc'): - result = self.client.postgrest.rpc('exec_sql', {'query': admin_sql}).execute() + if hasattr(self.client, "postgrest") and hasattr( + self.client.postgrest, "rpc" + ): + result = self.client.postgrest.rpc( + "exec_sql", {"query": admin_sql} + ).execute() if self.enable_logging: - self.logger.info(f"Successfully created table '{self.table_name}' using admin API.") + self.logger.info( + f"Successfully created table '{self.table_name}' using admin API." 
+ ) return except Exception as e: if self.enable_logging: - self.logger.debug(f"Admin API table creation failed: {e}") - + self.logger.debug( + f"Admin API table creation failed: {e}" + ) + # If all else fails, call the missing table handler self._handle_missing_table() @@ -301,7 +370,7 @@ ALTER TABLE {self.table_name} ENABLE ROW LEVEL SECURITY; CREATE POLICY "Users can manage their own conversations" ON {self.table_name} FOR ALL USING (true); -- Adjust this policy based on your security requirements """ - + error_msg = ( f"Table '{self.table_name}' does not exist in your Supabase database and cannot be created automatically. " f"Please create it manually by running the following SQL in your Supabase SQL Editor:\n\n{table_creation_sql}\n\n" @@ -316,17 +385,19 @@ CREATE POLICY "Users can manage their own conversations" ON {self.table_name} f"$$ LANGUAGE plpgsql SECURITY DEFINER;\n\n" f"After creating either the table or the RPC function, retry initializing the SupabaseConversation." ) - + if self.enable_logging: self.logger.error(error_msg) raise SupabaseOperationError(error_msg) - def _handle_api_response(self, response, operation_name: str = "Supabase operation"): + def _handle_api_response( + self, response, operation_name: str = "Supabase operation" + ): """Handles Supabase API response, checking for errors and returning data.""" # The new supabase-py client structure: response has .data and .count attributes # Errors are raised as exceptions rather than being in response.error try: - if hasattr(response, 'data'): + if hasattr(response, "data"): # Return the data, which could be None, a list, or a dict return response.data else: @@ -335,19 +406,25 @@ CREATE POLICY "Users can manage their own conversations" ON {self.table_name} except Exception as e: if self.enable_logging: self.logger.error(f"{operation_name} failed: {e}") - raise SupabaseOperationError(f"{operation_name} failed: {e}") + raise SupabaseOperationError( + f"{operation_name} failed: {e}" + ) - def _serialize_content(self, content: Union[str, dict, list]) -> str: + def _serialize_content( + self, content: Union[str, dict, list] + ) -> str: """Serializes content to JSON string if it's a dict or list.""" if isinstance(content, (dict, list)): return json.dumps(content, cls=DateTimeEncoder) return str(content) - def _deserialize_content(self, content_str: str) -> Union[str, dict, list]: + def _deserialize_content( + self, content_str: str + ) -> Union[str, dict, list]: """Deserializes content from JSON string if it looks like JSON. 
More robust approach.""" if not content_str: return content_str - + # Always try to parse as JSON first, fall back to string try: return json.loads(content_str) @@ -355,18 +432,26 @@ CREATE POLICY "Users can manage their own conversations" ON {self.table_name} # Not valid JSON, return as string return content_str - def _serialize_metadata(self, metadata: Optional[Dict]) -> Optional[str]: + def _serialize_metadata( + self, metadata: Optional[Dict] + ) -> Optional[str]: """Serializes metadata dict to JSON string using simplified encoder.""" if metadata is None: return None try: - return json.dumps(metadata, default=str, ensure_ascii=False) + return json.dumps( + metadata, default=str, ensure_ascii=False + ) except (TypeError, ValueError) as e: if self.enable_logging: - self.logger.warning(f"Failed to serialize metadata: {e}") + self.logger.warning( + f"Failed to serialize metadata: {e}" + ) return None - def _deserialize_metadata(self, metadata_str: Optional[str]) -> Optional[Dict]: + def _deserialize_metadata( + self, metadata_str: Optional[str] + ) -> Optional[Dict]: """Deserializes metadata from JSON string with better error handling.""" if metadata_str is None: return None @@ -374,19 +459,27 @@ CREATE POLICY "Users can manage their own conversations" ON {self.table_name} return json.loads(metadata_str) except (json.JSONDecodeError, TypeError) as e: if self.enable_logging: - self.logger.warning(f"Failed to deserialize metadata: {metadata_str[:50]}... Error: {e}") + self.logger.warning( + f"Failed to deserialize metadata: {metadata_str[:50]}... Error: {e}" + ) return None def _generate_conversation_id(self) -> str: """Generate a unique conversation ID using UUID and timestamp.""" - timestamp = datetime.datetime.now(datetime.timezone.utc).strftime("%Y%m%d_%H%M%S_%f") + timestamp = datetime.datetime.now( + datetime.timezone.utc + ).strftime("%Y%m%d_%H%M%S_%f") unique_id = str(uuid.uuid4())[:8] return f"conv_{timestamp}_{unique_id}" def start_new_conversation(self) -> str: """Starts a new conversation and returns its ID.""" - self.current_conversation_id = self._generate_conversation_id() - self.logger.info(f"Started new conversation with ID: {self.current_conversation_id}") + self.current_conversation_id = ( + self._generate_conversation_id() + ) + self.logger.info( + f"Started new conversation with ID: {self.current_conversation_id}" + ) return self.current_conversation_id def add( @@ -402,46 +495,72 @@ CREATE POLICY "Users can manage their own conversations" ON {self.table_name} self.start_new_conversation() serialized_content = self._serialize_content(content) - current_timestamp_iso = datetime.datetime.now(datetime.timezone.utc).isoformat() + current_timestamp_iso = datetime.datetime.now( + datetime.timezone.utc + ).isoformat() message_data = { "conversation_id": self.current_conversation_id, "role": role, "content": serialized_content, - "timestamp": current_timestamp_iso, # Supabase will use its default if not provided / column allows NULL - "message_type": message_type.value if message_type else None, + "timestamp": current_timestamp_iso, # Supabase will use its default if not provided / column allows NULL + "message_type": ( + message_type.value if message_type else None + ), "metadata": self._serialize_metadata(metadata), # token_count handled below } # Calculate token_count if enabled and not provided - if self.calculate_token_count and token_count is None and self.tokenizer: + if ( + self.calculate_token_count + and token_count is None + and self.tokenizer + ): try: # For now, do 
this synchronously. For long content, consider async/threading. - message_data["token_count"] = self.tokenizer.count_tokens(str(content)) + message_data["token_count"] = ( + self.tokenizer.count_tokens(str(content)) + ) except Exception as e: if self.enable_logging: - self.logger.warning(f"Failed to count tokens for content: {e}") + self.logger.warning( + f"Failed to count tokens for content: {e}" + ) elif token_count is not None: message_data["token_count"] = token_count - + # Filter out None values to let Supabase handle defaults or NULLs appropriately - message_to_insert = {k: v for k, v in message_data.items() if v is not None} + message_to_insert = { + k: v for k, v in message_data.items() if v is not None + } try: - response = self.client.table(self.table_name).insert(message_to_insert).execute() + response = ( + self.client.table(self.table_name) + .insert(message_to_insert) + .execute() + ) data = self._handle_api_response(response, "add_message") if data and len(data) > 0 and "id" in data[0]: inserted_id = data[0]["id"] if self.enable_logging: - self.logger.debug(f"Added message with ID {inserted_id} to conversation {self.current_conversation_id}") + self.logger.debug( + f"Added message with ID {inserted_id} to conversation {self.current_conversation_id}" + ) return inserted_id if self.enable_logging: - self.logger.error(f"Failed to retrieve ID for inserted message in conversation {self.current_conversation_id}") - raise SupabaseOperationError("Failed to retrieve ID for inserted message.") + self.logger.error( + f"Failed to retrieve ID for inserted message in conversation {self.current_conversation_id}" + ) + raise SupabaseOperationError( + "Failed to retrieve ID for inserted message." + ) except Exception as e: if self.enable_logging: - self.logger.error(f"Error adding message to Supabase: {e}") + self.logger.error( + f"Error adding message to Supabase: {e}" + ) raise SupabaseOperationError(f"Error adding message: {e}") def batch_add(self, messages: List[Message]) -> List[int]: @@ -451,61 +570,108 @@ CREATE POLICY "Users can manage their own conversations" ON {self.table_name} messages_to_insert = [] for msg_obj in messages: - serialized_content = self._serialize_content(msg_obj.content) - current_timestamp_iso = (msg_obj.timestamp or datetime.datetime.now(datetime.timezone.utc).isoformat()) + serialized_content = self._serialize_content( + msg_obj.content + ) + current_timestamp_iso = ( + msg_obj.timestamp + or datetime.datetime.now( + datetime.timezone.utc + ).isoformat() + ) msg_data = { "conversation_id": self.current_conversation_id, "role": msg_obj.role, "content": serialized_content, "timestamp": current_timestamp_iso, - "message_type": msg_obj.message_type.value if msg_obj.message_type else None, - "metadata": self._serialize_metadata(msg_obj.metadata), + "message_type": ( + msg_obj.message_type.value + if msg_obj.message_type + else None + ), + "metadata": self._serialize_metadata( + msg_obj.metadata + ), } - + # Token count current_token_count = msg_obj.token_count - if self.calculate_token_count and current_token_count is None and self.tokenizer: + if ( + self.calculate_token_count + and current_token_count is None + and self.tokenizer + ): try: - current_token_count = self.tokenizer.count_tokens(str(msg_obj.content)) + current_token_count = self.tokenizer.count_tokens( + str(msg_obj.content) + ) except Exception as e: - self.logger.warning(f"Failed to count tokens for batch message: {e}") + self.logger.warning( + f"Failed to count tokens for batch message: {e}" + ) if 
current_token_count is not None: - msg_data["token_count"] = current_token_count - - messages_to_insert.append({k: v for k, v in msg_data.items() if v is not None}) + msg_data["token_count"] = current_token_count + + messages_to_insert.append( + {k: v for k, v in msg_data.items() if v is not None} + ) if not messages_to_insert: return [] try: - response = self.client.table(self.table_name).insert(messages_to_insert).execute() - data = self._handle_api_response(response, "batch_add_messages") - inserted_ids = [item['id'] for item in data if 'id' in item] + response = ( + self.client.table(self.table_name) + .insert(messages_to_insert) + .execute() + ) + data = self._handle_api_response( + response, "batch_add_messages" + ) + inserted_ids = [ + item["id"] for item in data if "id" in item + ] if len(inserted_ids) != len(messages_to_insert): - self.logger.warning("Mismatch in expected and inserted message counts during batch_add.") - self.logger.debug(f"Batch added {len(inserted_ids)} messages to conversation {self.current_conversation_id}") + self.logger.warning( + "Mismatch in expected and inserted message counts during batch_add." + ) + self.logger.debug( + f"Batch added {len(inserted_ids)} messages to conversation {self.current_conversation_id}" + ) return inserted_ids except Exception as e: - self.logger.error(f"Error batch adding messages to Supabase: {e}") - raise SupabaseOperationError(f"Error batch adding messages: {e}") + self.logger.error( + f"Error batch adding messages to Supabase: {e}" + ) + raise SupabaseOperationError( + f"Error batch adding messages: {e}" + ) def _format_row_to_dict(self, row: Dict) -> Dict: """Helper to format a raw row from Supabase to our standard message dict.""" formatted_message = { "id": row.get("id"), "role": row.get("role"), - "content": self._deserialize_content(row.get("content", "")), + "content": self._deserialize_content( + row.get("content", "") + ), "timestamp": row.get("timestamp"), "message_type": row.get("message_type"), - "metadata": self._deserialize_metadata(row.get("metadata")), + "metadata": self._deserialize_metadata( + row.get("metadata") + ), "token_count": row.get("token_count"), "conversation_id": row.get("conversation_id"), "created_at": row.get("created_at"), } # Clean None values from the root, but keep them within deserialized content/metadata - return {k: v for k, v in formatted_message.items() if v is not None or k in ["metadata", "token_count", "message_type"]} - + return { + k: v + for k, v in formatted_message.items() + if v is not None + or k in ["metadata", "token_count", "message_type"] + } def get_messages( self, @@ -516,33 +682,48 @@ CREATE POLICY "Users can manage their own conversations" ON {self.table_name} if self.current_conversation_id is None: return [] try: - query = self.client.table(self.table_name).select("*") \ - .eq("conversation_id", self.current_conversation_id) \ - .order("timestamp", desc=False) # Assuming 'timestamp' or 'id' for ordering + query = ( + self.client.table(self.table_name) + .select("*") + .eq("conversation_id", self.current_conversation_id) + .order("timestamp", desc=False) + ) # Assuming 'timestamp' or 'id' for ordering if limit is not None: query = query.limit(limit) if offset is not None: query = query.offset(offset) - + response = query.execute() data = self._handle_api_response(response, "get_messages") return [self._format_row_to_dict(row) for row in data] except Exception as e: - self.logger.error(f"Error getting messages from Supabase: {e}") - raise 
SupabaseOperationError(f"Error getting messages: {e}") + self.logger.error( + f"Error getting messages from Supabase: {e}" + ) + raise SupabaseOperationError( + f"Error getting messages: {e}" + ) def get_str(self) -> str: """Get the current conversation history as a formatted string.""" messages_dict = self.get_messages() conv_str = [] for msg in messages_dict: - ts_prefix = f"[{msg['timestamp']}] " if msg.get('timestamp') and self.time_enabled else "" + ts_prefix = ( + f"[{msg['timestamp']}] " + if msg.get("timestamp") and self.time_enabled + else "" + ) # Content might be dict/list if deserialized - content_display = msg['content'] + content_display = msg["content"] if isinstance(content_display, (dict, list)): - content_display = json.dumps(content_display, indent=2, cls=DateTimeEncoder) - conv_str.append(f"{ts_prefix}{msg['role']}: {content_display}") + content_display = json.dumps( + content_display, indent=2, cls=DateTimeEncoder + ) + conv_str.append( + f"{ts_prefix}{msg['role']}: {content_display}" + ) return "\n".join(conv_str) def display_conversation(self, detailed: bool = False): @@ -554,46 +735,69 @@ CREATE POLICY "Users can manage their own conversations" ON {self.table_name} """Delete a message from the conversation history by its primary key 'id'.""" if self.current_conversation_id is None: if self.enable_logging: - self.logger.warning("Cannot delete message: No current conversation.") + self.logger.warning( + "Cannot delete message: No current conversation." + ) return - + try: # Handle both string and int message IDs try: message_id = int(index) except ValueError: if self.enable_logging: - self.logger.error(f"Invalid message ID for delete: {index}. Must be an integer.") - raise ValueError(f"Invalid message ID for delete: {index}. Must be an integer.") - - response = self.client.table(self.table_name).delete() \ - .eq("id", message_id) \ - .eq("conversation_id", self.current_conversation_id) \ + self.logger.error( + f"Invalid message ID for delete: {index}. Must be an integer." + ) + raise ValueError( + f"Invalid message ID for delete: {index}. Must be an integer." + ) + + response = ( + self.client.table(self.table_name) + .delete() + .eq("id", message_id) + .eq("conversation_id", self.current_conversation_id) .execute() - self._handle_api_response(response, f"delete_message (id: {message_id})") + ) + self._handle_api_response( + response, f"delete_message (id: {message_id})" + ) if self.enable_logging: - self.logger.info(f"Deleted message with ID {message_id} from conversation {self.current_conversation_id}") + self.logger.info( + f"Deleted message with ID {message_id} from conversation {self.current_conversation_id}" + ) except Exception as e: if self.enable_logging: - self.logger.error(f"Error deleting message ID {index} from Supabase: {e}") - raise SupabaseOperationError(f"Error deleting message ID {index}: {e}") + self.logger.error( + f"Error deleting message ID {index} from Supabase: {e}" + ) + raise SupabaseOperationError( + f"Error deleting message ID {index}: {e}" + ) - def update(self, index: str, role: str, content: Union[str, dict]): + def update( + self, index: str, role: str, content: Union[str, dict] + ): """Update a message in the conversation history. 
Matches BaseCommunication signature exactly.""" # Use the flexible internal method - return self._update_flexible(index=index, role=role, content=content) + return self._update_flexible( + index=index, role=role, content=content + ) def _update_flexible( - self, - index: Union[str, int], - role: Optional[str] = None, + self, + index: Union[str, int], + role: Optional[str] = None, content: Optional[Union[str, dict]] = None, - metadata: Optional[Dict] = None + metadata: Optional[Dict] = None, ) -> bool: """Internal flexible update method. Returns True if successful, False otherwise.""" if self.current_conversation_id is None: if self.enable_logging: - self.logger.warning("Cannot update message: No current conversation.") + self.logger.warning( + "Cannot update message: No current conversation." + ) return False # Handle both string and int message IDs @@ -604,7 +808,9 @@ CREATE POLICY "Users can manage their own conversations" ON {self.table_name} message_id = index except ValueError: if self.enable_logging: - self.logger.error(f"Invalid message ID for update: {index}. Must be an integer.") + self.logger.error( + f"Invalid message ID for update: {index}. Must be an integer." + ) return False update_data = {} @@ -614,39 +820,60 @@ CREATE POLICY "Users can manage their own conversations" ON {self.table_name} update_data["content"] = self._serialize_content(content) if self.calculate_token_count and self.tokenizer: try: - update_data["token_count"] = self.tokenizer.count_tokens(str(content)) + update_data["token_count"] = ( + self.tokenizer.count_tokens(str(content)) + ) except Exception as e: if self.enable_logging: - self.logger.warning(f"Failed to count tokens for updated content: {e}") - if metadata is not None: # Allows setting metadata to null by passing {} then serializing - update_data["metadata"] = self._serialize_metadata(metadata) - + self.logger.warning( + f"Failed to count tokens for updated content: {e}" + ) + if ( + metadata is not None + ): # Allows setting metadata to null by passing {} then serializing + update_data["metadata"] = self._serialize_metadata( + metadata + ) + if not update_data: if self.enable_logging: - self.logger.info("No fields provided to update for message.") + self.logger.info( + "No fields provided to update for message." 
+ ) return False try: - response = self.client.table(self.table_name).update(update_data) \ - .eq("id", message_id) \ - .eq("conversation_id", self.current_conversation_id) \ + response = ( + self.client.table(self.table_name) + .update(update_data) + .eq("id", message_id) + .eq("conversation_id", self.current_conversation_id) .execute() - - data = self._handle_api_response(response, f"update_message (id: {message_id})") - + ) + + data = self._handle_api_response( + response, f"update_message (id: {message_id})" + ) + # Check if any rows were actually updated if data and len(data) > 0: if self.enable_logging: - self.logger.info(f"Updated message with ID {message_id} in conversation {self.current_conversation_id}") + self.logger.info( + f"Updated message with ID {message_id} in conversation {self.current_conversation_id}" + ) return True else: if self.enable_logging: - self.logger.warning(f"No message found with ID {message_id} in conversation {self.current_conversation_id}") + self.logger.warning( + f"No message found with ID {message_id} in conversation {self.current_conversation_id}" + ) return False - + except Exception as e: if self.enable_logging: - self.logger.error(f"Error updating message ID {message_id} in Supabase: {e}") + self.logger.error( + f"Error updating message ID {message_id} in Supabase: {e}" + ) return False def query(self, index: str) -> Dict: @@ -659,22 +886,31 @@ CREATE POLICY "Users can manage their own conversations" ON {self.table_name} message_id = int(index) except ValueError: if self.enable_logging: - self.logger.warning(f"Invalid message ID for query: {index}. Must be an integer.") + self.logger.warning( + f"Invalid message ID for query: {index}. Must be an integer." + ) return {} - response = self.client.table(self.table_name).select("*") \ - .eq("id", message_id) \ - .eq("conversation_id", self.current_conversation_id) \ - .maybe_single() \ - .execute() # maybe_single returns one record or None - - data = self._handle_api_response(response, f"query_message (id: {message_id})") + response = ( + self.client.table(self.table_name) + .select("*") + .eq("id", message_id) + .eq("conversation_id", self.current_conversation_id) + .maybe_single() + .execute() + ) # maybe_single returns one record or None + + data = self._handle_api_response( + response, f"query_message (id: {message_id})" + ) if data: return self._format_row_to_dict(data) return {} except Exception as e: if self.enable_logging: - self.logger.error(f"Error querying message ID {index} from Supabase: {e}") + self.logger.error( + f"Error querying message ID {index} from Supabase: {e}" + ) return {} def query_optional(self, index: str) -> Optional[Dict]: @@ -688,46 +924,67 @@ CREATE POLICY "Users can manage their own conversations" ON {self.table_name} return [] try: # PostgREST ilike is case-insensitive - response = self.client.table(self.table_name).select("*") \ - .eq("conversation_id", self.current_conversation_id) \ - .ilike("content", f"%{keyword}%") \ - .order("timestamp", desc=False) \ + response = ( + self.client.table(self.table_name) + .select("*") + .eq("conversation_id", self.current_conversation_id) + .ilike("content", f"%{keyword}%") + .order("timestamp", desc=False) .execute() - data = self._handle_api_response(response, f"search_messages (keyword: {keyword})") + ) + data = self._handle_api_response( + response, f"search_messages (keyword: {keyword})" + ) return [self._format_row_to_dict(row) for row in data] except Exception as e: - self.logger.error(f"Error searching messages in 
Supabase: {e}") - raise SupabaseOperationError(f"Error searching messages: {e}") + self.logger.error( + f"Error searching messages in Supabase: {e}" + ) + raise SupabaseOperationError( + f"Error searching messages: {e}" + ) def _export_to_file(self, filename: str, format_type: str): """Helper to export conversation to JSON or YAML file.""" if self.current_conversation_id is None: self.logger.warning("No current conversation to export.") return - - data_to_export = self.to_dict() # Gets messages for current_conversation_id + + data_to_export = ( + self.to_dict() + ) # Gets messages for current_conversation_id try: with open(filename, "w") as f: if format_type == "json": - json.dump(data_to_export, f, indent=2, cls=DateTimeEncoder) + json.dump( + data_to_export, + f, + indent=2, + cls=DateTimeEncoder, + ) elif format_type == "yaml": yaml.dump(data_to_export, f, sort_keys=False) else: - raise ValueError(f"Unsupported export format: {format_type}") - self.logger.info(f"Conversation {self.current_conversation_id} exported to {filename} as {format_type}.") + raise ValueError( + f"Unsupported export format: {format_type}" + ) + self.logger.info( + f"Conversation {self.current_conversation_id} exported to {filename} as {format_type}." + ) except Exception as e: - self.logger.error(f"Failed to export conversation to {format_type}: {e}") + self.logger.error( + f"Failed to export conversation to {format_type}: {e}" + ) raise def export_conversation(self, filename: str): """Export the current conversation history to a file (JSON or YAML based on init flags).""" if self.save_as_json_on_export: self._export_to_file(filename, "json") - elif self.save_as_yaml_on_export: # Default if json is false + elif self.save_as_yaml_on_export: # Default if json is false + self._export_to_file(filename, "yaml") + else: # Fallback if somehow both are false self._export_to_file(filename, "yaml") - else: # Fallback if somehow both are false - self._export_to_file(filename, "yaml") - def _import_from_file(self, filename: str, format_type: str): """Helper to import conversation from JSON or YAML file.""" @@ -738,38 +995,56 @@ CREATE POLICY "Users can manage their own conversations" ON {self.table_name} elif format_type == "yaml": imported_data = yaml.safe_load(f) else: - raise ValueError(f"Unsupported import format: {format_type}") + raise ValueError( + f"Unsupported import format: {format_type}" + ) if not isinstance(imported_data, list): - raise ValueError("Imported data must be a list of messages.") + raise ValueError( + "Imported data must be a list of messages." 
+ ) # Start a new conversation for the imported data self.start_new_conversation() - + messages_to_batch = [] for msg_data in imported_data: # Adapt to Message dataclass structure if possible role = msg_data.get("role") content = msg_data.get("content") if role is None or content is None: - self.logger.warning(f"Skipping message due to missing role/content: {msg_data}") + self.logger.warning( + f"Skipping message due to missing role/content: {msg_data}" + ) continue - messages_to_batch.append(Message( - role=role, - content=content, - timestamp=msg_data.get("timestamp"), # Will be handled by batch_add - message_type=MessageType(msg_data["message_type"]) if msg_data.get("message_type") else None, - metadata=msg_data.get("metadata"), - token_count=msg_data.get("token_count") - )) - + messages_to_batch.append( + Message( + role=role, + content=content, + timestamp=msg_data.get( + "timestamp" + ), # Will be handled by batch_add + message_type=( + MessageType(msg_data["message_type"]) + if msg_data.get("message_type") + else None + ), + metadata=msg_data.get("metadata"), + token_count=msg_data.get("token_count"), + ) + ) + if messages_to_batch: self.batch_add(messages_to_batch) - self.logger.info(f"Conversation imported from {filename} ({format_type}) into new ID {self.current_conversation_id}.") + self.logger.info( + f"Conversation imported from {filename} ({format_type}) into new ID {self.current_conversation_id}." + ) except Exception as e: - self.logger.error(f"Failed to import conversation from {format_type}: {e}") + self.logger.error( + f"Failed to import conversation from {format_type}: {e}" + ) raise def import_conversation(self, filename: str): @@ -783,12 +1058,18 @@ CREATE POLICY "Users can manage their own conversations" ON {self.table_name} # Try JSON first, then YAML as a fallback try: self._import_from_file(filename, "json") - except (json.JSONDecodeError, ValueError): # ValueError if not list - self.logger.info(f"Failed to import {filename} as JSON, trying YAML.") + except ( + json.JSONDecodeError, + ValueError, + ): # ValueError if not list + self.logger.info( + f"Failed to import {filename} as JSON, trying YAML." 
+ ) self._import_from_file(filename, "yaml") - except Exception as e: # Catch errors from _import_from_file - raise SupabaseOperationError(f"Could not import {filename}: {e}") - + except Exception as e: # Catch errors from _import_from_file + raise SupabaseOperationError( + f"Could not import {filename}: {e}" + ) def count_messages_by_role(self) -> Dict[str, int]: """Count messages by role for the current conversation.""" @@ -797,7 +1078,9 @@ CREATE POLICY "Users can manage their own conversations" ON {self.table_name} try: # Supabase rpc might be better for direct count, but select + python count is also fine # For direct DB count: self.client.rpc('count_roles', {'conv_id': self.current_conversation_id}).execute() - messages = self.get_messages() # Fetches for current_conversation_id + messages = ( + self.get_messages() + ) # Fetches for current_conversation_id counts = {} for msg in messages: role = msg.get("role", "unknown") @@ -805,7 +1088,9 @@ CREATE POLICY "Users can manage their own conversations" ON {self.table_name} return counts except Exception as e: self.logger.error(f"Error counting messages by role: {e}") - raise SupabaseOperationError(f"Error counting messages by role: {e}") + raise SupabaseOperationError( + f"Error counting messages by role: {e}" + ) def return_history_as_string(self) -> str: """Return the conversation history as a string.""" @@ -817,25 +1102,41 @@ CREATE POLICY "Users can manage their own conversations" ON {self.table_name} self.logger.info("No current conversation to clear.") return try: - response = self.client.table(self.table_name).delete() \ - .eq("conversation_id", self.current_conversation_id) \ + response = ( + self.client.table(self.table_name) + .delete() + .eq("conversation_id", self.current_conversation_id) .execute() + ) # response.data will be a list of deleted items. # response.count might be available for delete operations in some supabase-py versions or configurations. # For now, we assume success if no error. 
- self._handle_api_response(response, f"clear_conversation (id: {self.current_conversation_id})") - self.logger.info(f"Cleared conversation with ID: {self.current_conversation_id}") + self._handle_api_response( + response, + f"clear_conversation (id: {self.current_conversation_id})", + ) + self.logger.info( + f"Cleared conversation with ID: {self.current_conversation_id}" + ) except Exception as e: - self.logger.error(f"Error clearing conversation {self.current_conversation_id} from Supabase: {e}") - raise SupabaseOperationError(f"Error clearing conversation: {e}") + self.logger.error( + f"Error clearing conversation {self.current_conversation_id} from Supabase: {e}" + ) + raise SupabaseOperationError( + f"Error clearing conversation: {e}" + ) def to_dict(self) -> List[Dict]: """Convert the current conversation history to a list of dictionaries.""" - return self.get_messages() # Already fetches for current_conversation_id + return ( + self.get_messages() + ) # Already fetches for current_conversation_id def to_json(self) -> str: """Convert the current conversation history to a JSON string.""" - return json.dumps(self.to_dict(), indent=2, cls=DateTimeEncoder) + return json.dumps( + self.to_dict(), indent=2, cls=DateTimeEncoder + ) def to_yaml(self) -> str: """Convert the current conversation history to a YAML string.""" @@ -862,27 +1163,42 @@ CREATE POLICY "Users can manage their own conversations" ON {self.table_name} if self.current_conversation_id is None: return None try: - response = self.client.table(self.table_name).select("*") \ - .eq("conversation_id", self.current_conversation_id) \ - .order("timestamp", desc=True) \ - .limit(1) \ - .maybe_single() \ + response = ( + self.client.table(self.table_name) + .select("*") + .eq("conversation_id", self.current_conversation_id) + .order("timestamp", desc=True) + .limit(1) + .maybe_single() .execute() - data = self._handle_api_response(response, "get_last_message") + ) + data = self._handle_api_response( + response, "get_last_message" + ) return self._format_row_to_dict(data) if data else None except Exception as e: - self.logger.error(f"Error getting last message from Supabase: {e}") - raise SupabaseOperationError(f"Error getting last message: {e}") + self.logger.error( + f"Error getting last message from Supabase: {e}" + ) + raise SupabaseOperationError( + f"Error getting last message: {e}" + ) def get_last_message_as_string(self) -> str: """Get the last message as a formatted string.""" last_msg = self.get_last_message() if not last_msg: return "" - ts_prefix = f"[{last_msg['timestamp']}] " if last_msg.get('timestamp') and self.time_enabled else "" - content_display = last_msg['content'] + ts_prefix = ( + f"[{last_msg['timestamp']}] " + if last_msg.get("timestamp") and self.time_enabled + else "" + ) + content_display = last_msg["content"] if isinstance(content_display, (dict, list)): - content_display = json.dumps(content_display, cls=DateTimeEncoder) + content_display = json.dumps( + content_display, cls=DateTimeEncoder + ) return f"{ts_prefix}{last_msg['role']}: {content_display}" def get_messages_by_role(self, role: str) -> List[Dict]: @@ -890,22 +1206,31 @@ CREATE POLICY "Users can manage their own conversations" ON {self.table_name} if self.current_conversation_id is None: return [] try: - response = self.client.table(self.table_name).select("*") \ - .eq("conversation_id", self.current_conversation_id) \ - .eq("role", role) \ - .order("timestamp", desc=False) \ + response = ( + self.client.table(self.table_name) + .select("*") + 
.eq("conversation_id", self.current_conversation_id) + .eq("role", role) + .order("timestamp", desc=False) .execute() - data = self._handle_api_response(response, f"get_messages_by_role (role: {role})") + ) + data = self._handle_api_response( + response, f"get_messages_by_role (role: {role})" + ) return [self._format_row_to_dict(row) for row in data] except Exception as e: - self.logger.error(f"Error getting messages by role '{role}' from Supabase: {e}") - raise SupabaseOperationError(f"Error getting messages by role '{role}': {e}") + self.logger.error( + f"Error getting messages by role '{role}' from Supabase: {e}" + ) + raise SupabaseOperationError( + f"Error getting messages by role '{role}': {e}" + ) def get_conversation_summary(self) -> Dict: """Get a summary of the current conversation.""" if self.current_conversation_id is None: return {"error": "No current conversation."} - + # This could be optimized with an RPC call in Supabase for better performance # Example RPC: CREATE OR REPLACE FUNCTION get_conversation_summary(conv_id TEXT) ... messages = self.get_messages() @@ -923,10 +1248,12 @@ CREATE POLICY "Users can manage their own conversations" ON {self.table_name} roles_counts = {} total_tokens_sum = 0 for msg in messages: - roles_counts[msg["role"]] = roles_counts.get(msg["role"], 0) + 1 + roles_counts[msg["role"]] = ( + roles_counts.get(msg["role"], 0) + 1 + ) if msg.get("token_count") is not None: total_tokens_sum += int(msg["token_count"]) - + return { "conversation_id": self.current_conversation_id, "total_messages": len(messages), @@ -948,14 +1275,17 @@ CREATE POLICY "Users can manage their own conversations" ON {self.table_name} def delete_current_conversation(self) -> bool: """Delete the current conversation. Returns True if successful.""" if self.current_conversation_id: - self.clear() # clear messages for current_conversation_id - self.logger.info(f"Deleted current conversation: {self.current_conversation_id}") - self.current_conversation_id = None # No active conversation after deletion + self.clear() # clear messages for current_conversation_id + self.logger.info( + f"Deleted current conversation: {self.current_conversation_id}" + ) + self.current_conversation_id = ( + None # No active conversation after deletion + ) return True self.logger.info("No current conversation to delete.") return False - def search_messages(self, query: str) -> List[Dict]: """Search for messages containing specific text (alias for search).""" return self.search(keyword=query) @@ -967,7 +1297,7 @@ CREATE POLICY "Users can manage their own conversations" ON {self.table_name} if self.current_conversation_id is None: return {"error": "No current conversation."} summary = self.get_conversation_summary() - + # Example of additional metadata one might compute client-side or via RPC # message_type_distribution, average_tokens_per_message, hourly_message_frequency return { @@ -976,36 +1306,42 @@ CREATE POLICY "Users can manage their own conversations" ON {self.table_name} # Placeholder for more detailed stats if implemented } - def get_conversation_timeline_dict(self) -> Dict[str, List[Dict]]: """Get the conversation organized by timestamps (dates as keys).""" if self.current_conversation_id is None: return {} - - messages = self.get_messages() # Assumes messages are ordered by timestamp + + messages = ( + self.get_messages() + ) # Assumes messages are ordered by timestamp timeline_dict = {} for msg in messages: try: # Ensure timestamp is a string and valid ISO format ts_str = msg.get("timestamp") if 
isinstance(ts_str, str): - date_key = datetime.datetime.fromisoformat(ts_str.replace("Z", "+00:00")).strftime('%Y-%m-%d') + date_key = datetime.datetime.fromisoformat( + ts_str.replace("Z", "+00:00") + ).strftime("%Y-%m-%d") if date_key not in timeline_dict: timeline_dict[date_key] = [] timeline_dict[date_key].append(msg) else: - self.logger.warning(f"Message ID {msg.get('id')} has invalid timestamp format: {ts_str}") + self.logger.warning( + f"Message ID {msg.get('id')} has invalid timestamp format: {ts_str}" + ) except ValueError as e: - self.logger.warning(f"Could not parse timestamp for message ID {msg.get('id')}: {ts_str}, Error: {e}") + self.logger.warning( + f"Could not parse timestamp for message ID {msg.get('id')}: {ts_str}, Error: {e}" + ) return timeline_dict - def get_conversation_by_role_dict(self) -> Dict[str, List[Dict]]: """Get the conversation organized by roles.""" if self.current_conversation_id is None: return {} - + messages = self.get_messages() role_dict = {} for msg in messages: @@ -1019,46 +1355,56 @@ CREATE POLICY "Users can manage their own conversations" ON {self.table_name} """Get the entire current conversation as a dictionary with messages and metadata.""" if self.current_conversation_id is None: return {"error": "No current conversation."} - + return { "conversation_id": self.current_conversation_id, "messages": self.get_messages(), - "metadata": self.get_conversation_summary(), # Using summary as metadata + "metadata": self.get_conversation_summary(), # Using summary as metadata } def truncate_memory_with_tokenizer(self): """Truncate the conversation history based on token count if a tokenizer is provided. Optimized for better performance.""" if not self.tokenizer or self.current_conversation_id is None: if self.enable_logging: - self.logger.info("Tokenizer not available or no current conversation, skipping truncation.") + self.logger.info( + "Tokenizer not available or no current conversation, skipping truncation." 
+ ) return try: # Fetch messages with only necessary fields for efficiency - response = self.client.table(self.table_name).select("id, content, token_count") \ - .eq("conversation_id", self.current_conversation_id) \ - .order("timestamp", desc=False) \ + response = ( + self.client.table(self.table_name) + .select("id, content, token_count") + .eq("conversation_id", self.current_conversation_id) + .order("timestamp", desc=False) .execute() - - messages = self._handle_api_response(response, "fetch_messages_for_truncation") + ) + + messages = self._handle_api_response( + response, "fetch_messages_for_truncation" + ) if not messages: return # Calculate tokens and determine which messages to delete total_tokens = 0 message_tokens = [] - + for msg in messages: token_count = msg.get("token_count") if token_count is None and self.calculate_token_count: # Recalculate if missing - content = self._deserialize_content(msg.get("content", "")) - token_count = self.tokenizer.count_tokens(str(content)) - - message_tokens.append({ - "id": msg["id"], - "tokens": token_count or 0 - }) + content = self._deserialize_content( + msg.get("content", "") + ) + token_count = self.tokenizer.count_tokens( + str(content) + ) + + message_tokens.append( + {"id": msg["id"], "tokens": token_count or 0} + ) total_tokens += token_count or 0 tokens_to_remove = total_tokens - self.context_length @@ -1079,30 +1425,50 @@ CREATE POLICY "Users can manage their own conversations" ON {self.table_name} # Batch delete for better performance if len(ids_to_delete) == 1: # Single delete - response = self.client.table(self.table_name).delete() \ - .eq("id", ids_to_delete[0]) \ - .eq("conversation_id", self.current_conversation_id) \ + response = ( + self.client.table(self.table_name) + .delete() + .eq("id", ids_to_delete[0]) + .eq( + "conversation_id", + self.current_conversation_id, + ) .execute() + ) else: # Batch delete using 'in' operator - response = self.client.table(self.table_name).delete() \ - .in_("id", ids_to_delete) \ - .eq("conversation_id", self.current_conversation_id) \ + response = ( + self.client.table(self.table_name) + .delete() + .in_("id", ids_to_delete) + .eq( + "conversation_id", + self.current_conversation_id, + ) .execute() - - self._handle_api_response(response, "truncate_conversation_batch_delete") - + ) + + self._handle_api_response( + response, "truncate_conversation_batch_delete" + ) + if self.enable_logging: - self.logger.info(f"Truncated conversation {self.current_conversation_id}, removed {len(ids_to_delete)} oldest messages.") + self.logger.info( + f"Truncated conversation {self.current_conversation_id}, removed {len(ids_to_delete)} oldest messages." + ) except Exception as e: if self.enable_logging: - self.logger.error(f"Error during memory truncation for conversation {self.current_conversation_id}: {e}") + self.logger.error( + f"Error during memory truncation for conversation {self.current_conversation_id}: {e}" + ) # Don't re-raise, truncation is best-effort # Methods from duckdb_wrap.py that seem generally useful and can be adapted def get_visible_messages( - self, agent: Optional[Callable] = None, turn: Optional[int] = None + self, + agent: Optional[Callable] = None, + turn: Optional[int] = None, ) -> List[Dict]: """ Get visible messages, optionally filtered by agent visibility and turn. 
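Aside on `truncate_memory_with_tokenizer` above: the selection logic is an oldest-first eviction loop. A minimal, dependency-free sketch of just that policy, with illustrative names (`MsgTokens`, `ids_to_truncate`) that are not part of the wrapper's API:

```python
from dataclasses import dataclass
from typing import List


@dataclass
class MsgTokens:
    id: int
    tokens: int


def ids_to_truncate(messages: List[MsgTokens], context_length: int) -> List[int]:
    """Walk messages oldest-first, collecting ids to delete until the
    remaining conversation fits within context_length tokens."""
    tokens_to_remove = sum(m.tokens for m in messages) - context_length
    ids: List[int] = []
    for m in messages:  # assumed ordered oldest -> newest, as in the query above
        if tokens_to_remove <= 0:
            break
        ids.append(m.id)
        tokens_to_remove -= m.tokens
    return ids


# 15 tokens total against a 10-token budget: only the oldest message goes.
assert ids_to_truncate(
    [MsgTokens(1, 5), MsgTokens(2, 4), MsgTokens(3, 6)], context_length=10
) == [1]
```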
@@ -1113,42 +1479,65 @@ CREATE POLICY "Users can manage their own conversations" ON {self.table_name} return [] # Base query - query = self.client.table(self.table_name).select("*") \ - .eq("conversation_id", self.current_conversation_id) \ + query = ( + self.client.table(self.table_name) + .select("*") + .eq("conversation_id", self.current_conversation_id) .order("timestamp", desc=False) - + ) + # Execute and then filter in Python, as JSONB querying for array containment or # numeric comparison within JSON can be complex with supabase-py's fluent API. # For complex filtering, an RPC function in Supabase would be more efficient. - + try: response = query.execute() - all_messages = self._handle_api_response(response, "get_visible_messages_fetch_all") + all_messages = self._handle_api_response( + response, "get_visible_messages_fetch_all" + ) except Exception as e: - self.logger.error(f"Error fetching messages for visibility check: {e}") + self.logger.error( + f"Error fetching messages for visibility check: {e}" + ) return [] visible_messages = [] for row_data in all_messages: msg = self._format_row_to_dict(row_data) - metadata = msg.get("metadata") if isinstance(msg.get("metadata"), dict) else {} + metadata = ( + msg.get("metadata") + if isinstance(msg.get("metadata"), dict) + else {} + ) # Turn filtering if turn is not None: msg_turn = metadata.get("turn") - if not (isinstance(msg_turn, int) and msg_turn < turn): - continue # Skip if turn condition not met + if not ( + isinstance(msg_turn, int) and msg_turn < turn + ): + continue # Skip if turn condition not met # Agent visibility filtering if agent is not None: visible_to = metadata.get("visible_to") - agent_name_attr = getattr(agent, 'agent_name', None) # Safely get agent_name - if agent_name_attr is None: # If agent has no name, assume it can't see restricted msgs + agent_name_attr = getattr( + agent, "agent_name", None + ) # Safely get agent_name + if ( + agent_name_attr is None + ): # If agent has no name, assume it can't see restricted msgs if visible_to is not None and visible_to != "all": continue - elif isinstance(visible_to, list) and agent_name_attr not in visible_to: - continue # Skip if agent not in visible_to list - elif isinstance(visible_to, str) and visible_to != "all": + elif ( + isinstance(visible_to, list) + and agent_name_attr not in visible_to + ): + continue # Skip if agent not in visible_to list + elif ( + isinstance(visible_to, str) + and visible_to != "all" + ): # If visible_to is a string but not "all", and doesn't match agent_name if visible_to != agent_name_attr: continue @@ -1168,26 +1557,39 @@ CREATE POLICY "Users can manage their own conversations" ON {self.table_name} """Return the conversation messages as a list of dictionaries [{role: R, content: C}].""" messages_dict = self.get_messages() return [ - {"role": msg.get("role"), "content": msg.get("content")} # Content already deserialized by _format_row_to_dict + { + "role": msg.get("role"), + "content": msg.get("content"), + } # Content already deserialized by _format_row_to_dict for msg in messages_dict ] - def add_tool_output_to_agent(self, role: str, tool_output: dict): # role is usually "tool" + def add_tool_output_to_agent( + self, role: str, tool_output: dict + ): # role is usually "tool" """Add a tool output to the conversation history.""" # Assuming tool_output is a dict that should be stored as content - self.add(role=role, content=tool_output, message_type=MessageType.TOOL) + self.add( + role=role, + content=tool_output, + 
message_type=MessageType.TOOL, + ) def get_final_message(self) -> Optional[str]: """Return the final message from the conversation history as 'role: content' string.""" last_msg = self.get_last_message() if not last_msg: return None - content_display = last_msg['content'] + content_display = last_msg["content"] if isinstance(content_display, (dict, list)): - content_display = json.dumps(content_display, cls=DateTimeEncoder) + content_display = json.dumps( + content_display, cls=DateTimeEncoder + ) return f"{last_msg.get('role', 'unknown')}: {content_display}" - def get_final_message_content(self) -> Union[str, dict, list, None]: + def get_final_message_content( + self, + ) -> Union[str, dict, list, None]: """Return the content of the final message from the conversation history.""" last_msg = self.get_last_message() return last_msg.get("content") if last_msg else None @@ -1199,17 +1601,24 @@ CREATE POLICY "Users can manage their own conversations" ON {self.table_name} all_messages = self.get_messages() return all_messages[1:] if len(all_messages) > 1 else [] - def return_all_except_first_string(self) -> str: """Return all messages except the first one as a concatenated string.""" messages_to_format = self.return_all_except_first() conv_str = [] for msg in messages_to_format: - ts_prefix = f"[{msg['timestamp']}] " if msg.get('timestamp') and self.time_enabled else "" - content_display = msg['content'] + ts_prefix = ( + f"[{msg['timestamp']}] " + if msg.get("timestamp") and self.time_enabled + else "" + ) + content_display = msg["content"] if isinstance(content_display, (dict, list)): - content_display = json.dumps(content_display, indent=2, cls=DateTimeEncoder) - conv_str.append(f"{ts_prefix}{msg['role']}: {content_display}") + content_display = json.dumps( + content_display, indent=2, cls=DateTimeEncoder + ) + conv_str.append( + f"{ts_prefix}{msg['role']}: {content_display}" + ) return "\n".join(conv_str) def update_message( @@ -1220,4 +1629,6 @@ CREATE POLICY "Users can manage their own conversations" ON {self.table_name} ) -> bool: """Update an existing message. Matches BaseCommunication.update_message signature exactly.""" # Use the flexible internal method - return self._update_flexible(index=message_id, content=content, metadata=metadata) \ No newline at end of file + return self._update_flexible( + index=message_id, content=content, metadata=metadata + ) diff --git a/swarms/structs/__init__.py b/swarms/structs/__init__.py index 377fb2b4..53181f32 100644 --- a/swarms/structs/__init__.py +++ b/swarms/structs/__init__.py @@ -55,7 +55,6 @@ from swarms.structs.swarm_arange import SwarmRearrange from swarms.structs.swarm_router import ( SwarmRouter, SwarmType, - swarm_router, ) from swarms.structs.swarming_architectures import ( broadcast, @@ -135,7 +134,6 @@ __all__ = [ "run_agents_with_different_tasks", "run_agent_with_timeout", "run_agents_with_resource_monitoring", - "swarm_router", "run_agents_with_tasks_concurrently", "GroupChat", "expertise_based", diff --git a/swarms/structs/swarm_router.py b/swarms/structs/swarm_router.py index ff09209f..98c7da19 100644 --- a/swarms/structs/swarm_router.py +++ b/swarms/structs/swarm_router.py @@ -332,9 +332,7 @@ class SwarmRouter: ) logger.info("🚀 [SYSTEM] Swarm is ready for deployment") - def _create_swarm( - self, task: str = None, *args, **kwargs - ): + def _create_swarm(self, task: str = None, *args, **kwargs): """ Dynamically create and return the specified swarm type or automatically match the best swarm type for a given task. 
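The docstring above describes a dispatch: a concrete `swarm_type` maps straight to a constructor, while `"auto"` is first resolved by matching the task. A hedged sketch of that shape (the factory registry and matcher here are assumptions, not `SwarmRouter` internals):

```python
from typing import Any, Callable, Dict, Optional


def create_swarm(
    swarm_type: str,
    factories: Dict[str, Callable[..., Any]],
    match_best: Callable[[str], str],
    task: Optional[str] = None,
    **kwargs: Any,
) -> Any:
    # Resolve "auto" to a concrete type before looking up the factory.
    if swarm_type == "auto":
        swarm_type = match_best(task or "")
    try:
        return factories[swarm_type](**kwargs)
    except KeyError:
        raise ValueError(f"Unknown swarm type: {swarm_type}") from None
```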
@@ -768,85 +766,3 @@ class SwarmRouter: results.append(None) return results - - -def swarm_router( - name: str = "swarm-router", - description: str = "Routes your task to the desired swarm", - max_loops: int = 1, - agents: List[Union[Agent, Callable]] = [], - swarm_type: SwarmType = "SequentialWorkflow", # "SpreadSheetSwarm" # "auto" - autosave: bool = False, - flow: str = None, - return_json: bool = True, - auto_generate_prompts: bool = False, - task: str = None, - rules: str = None, - *args, - **kwargs, -) -> SwarmRouter: - """ - Create and run a SwarmRouter instance with the given configuration. - - Args: - name (str, optional): Name of the swarm router. Defaults to "swarm-router". - description (str, optional): Description of the router. Defaults to "Routes your task to the desired swarm". - max_loops (int, optional): Maximum number of execution loops. Defaults to 1. - agents (List[Union[Agent, Callable]], optional): List of agents or callables. Defaults to []. - swarm_type (SwarmType, optional): Type of swarm to use. Defaults to "SequentialWorkflow". - autosave (bool, optional): Whether to autosave results. Defaults to False. - flow (str, optional): Flow configuration. Defaults to None. - return_json (bool, optional): Whether to return results as JSON. Defaults to True. - auto_generate_prompts (bool, optional): Whether to auto-generate prompts. Defaults to False. - task (str, optional): Task to execute. Defaults to None. - *args: Additional positional arguments passed to SwarmRouter.run() - **kwargs: Additional keyword arguments passed to SwarmRouter.run() - - Returns: - Any: Result from executing the swarm router - - Raises: - ValueError: If invalid arguments are provided - Exception: If an error occurs during router creation or task execution - """ - try: - logger.info( - f"Creating SwarmRouter with name: {name}, swarm_type: {swarm_type}" - ) - - if not agents: - logger.warning( - "No agents provided, router may have limited functionality" - ) - - if task is None: - logger.warning("No task provided") - - swarm_router = SwarmRouter( - name=name, - description=description, - max_loops=max_loops, - agents=agents, - swarm_type=swarm_type, - autosave=autosave, - flow=flow, - return_json=return_json, - auto_generate_prompts=auto_generate_prompts, - rules=rules, - ) - - logger.info(f"Executing task with SwarmRouter: {task}") - result = swarm_router.run(task, *args, **kwargs) - logger.info( - f"Task execution completed successfully: {result}" - ) - return result - - except ValueError as e: - logger.error( - f"Invalid arguments provided to swarm_router: {str(e)}" - ) - raise - except Exception as e: - logger.error(f"Error in swarm_router execution: {str(e)}") - raise diff --git a/tests/communication/test_supabase_conversation.py b/tests/communication/test_supabase_conversation.py index ac30c323..17f67745 100644 --- a/tests/communication/test_supabase_conversation.py +++ b/tests/communication/test_supabase_conversation.py @@ -2,7 +2,6 @@ import os import sys import json import datetime -import tempfile import threading from typing import Dict, List, Any, Tuple from pathlib import Path @@ -13,9 +12,11 @@ sys.path.insert(0, str(project_root)) try: from loguru import logger + LOGURU_AVAILABLE = True except ImportError: import logging + logger = logging.getLogger(__name__) LOGURU_AVAILABLE = False @@ -23,6 +24,7 @@ try: from rich.console import Console from rich.table import Table from rich.panel import Panel + RICH_AVAILABLE = True console = Console() except ImportError: @@ -36,7 +38,11 @@ try: 
SupabaseConnectionError, SupabaseOperationError, ) - from swarms.communication.base_communication import Message, MessageType + from swarms.communication.base_communication import ( + Message, + MessageType, + ) + SUPABASE_AVAILABLE = True except ImportError as e: SUPABASE_AVAILABLE = False @@ -46,13 +52,14 @@ except ImportError as e: # Try to load environment variables try: from dotenv import load_dotenv + load_dotenv() except ImportError: pass # dotenv is optional # Test configuration TEST_SUPABASE_URL = os.getenv("SUPABASE_URL") -TEST_SUPABASE_KEY = os.getenv("SUPABASE_KEY") +TEST_SUPABASE_KEY = os.getenv("SUPABASE_KEY") TEST_TABLE_NAME = "conversations_test" @@ -81,7 +88,9 @@ def print_test_result( ) console.print(f"\n{status} - {test_name}") console.print(f"Message: {message}") - console.print(f"Execution time: {execution_time:.3f} seconds\n") + console.print( + f"Execution time: {execution_time:.3f} seconds\n" + ) else: status = "PASSED" if success else "FAILED" print(f"\n{status} - {test_name}") @@ -89,7 +98,9 @@ def print_test_result( print(f"Execution time: {execution_time:.3f} seconds\n") -def print_messages(messages: List[Dict], title: str = "Messages") -> None: +def print_messages( + messages: List[Dict], title: str = "Messages" +) -> None: """Print messages in a formatted table.""" if RICH_AVAILABLE and console: table = Table(title=title) @@ -99,19 +110,29 @@ def print_messages(messages: List[Dict], title: str = "Messages") -> None: table.add_column("Type", style="magenta") table.add_column("Timestamp", style="blue") - for msg in messages[:10]: # Limit to first 10 messages for display + for msg in messages[ + :10 + ]: # Limit to first 10 messages for display content = str(msg.get("content", "")) if isinstance(content, (dict, list)): - content = json.dumps(content)[:50] + "..." if len(json.dumps(content)) > 50 else json.dumps(content) + content = ( + json.dumps(content)[:50] + "..." + if len(json.dumps(content)) > 50 + else json.dumps(content) + ) elif len(content) > 50: content = content[:50] + "..." - + table.add_row( str(msg.get("id", "")), msg.get("role", ""), content, str(msg.get("message_type", "")), - str(msg.get("timestamp", ""))[:19] if msg.get("timestamp") else "", + ( + str(msg.get("timestamp", ""))[:19] + if msg.get("timestamp") + else "" + ), ) console.print(table) @@ -126,7 +147,9 @@ def print_messages(messages: List[Dict], title: str = "Messages") -> None: print(f"{i+1}. {msg.get('role', '')}: {content}") -def run_test(test_func: callable, *args, **kwargs) -> Tuple[bool, str, float]: +def run_test( + test_func: callable, *args, **kwargs +) -> Tuple[bool, str, float]: """ Run a test function and return its results. 
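The diff above only reformats `run_test`'s signature, so its body is elided. For orientation, here is a plausible minimal shape that satisfies the `(success, message, execution_time)` tuple the suite unpacks later; this is an illustration, not the repository's actual implementation:

```python
import time
import traceback
from typing import Callable, Tuple


def run_test_sketch(
    test_func: Callable[..., bool], *args, **kwargs
) -> Tuple[bool, str, float]:
    start = time.time()
    try:
        success = bool(test_func(*args, **kwargs))
        message = "Test completed" if success else "Test returned False"
    except Exception as e:
        success, message = False, f"{type(e).__name__}: {e}"
        traceback.print_exc()  # keep the stack trace visible during test runs
    return success, message, time.time() - start
```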
@@ -154,12 +177,12 @@ def setup_test_conversation(): """Set up a test conversation instance.""" if not SUPABASE_AVAILABLE: raise ImportError("Supabase dependencies not available") - + if not TEST_SUPABASE_URL or not TEST_SUPABASE_KEY: raise ValueError( "SUPABASE_URL and SUPABASE_KEY environment variables must be set for testing" ) - + conversation = SupabaseConversation( supabase_url=TEST_SUPABASE_URL, supabase_key=TEST_SUPABASE_KEY, @@ -176,46 +199,72 @@ def cleanup_test_conversation(conversation): conversation.clear() except Exception as e: if LOGURU_AVAILABLE: - logger.warning(f"Failed to clean up test conversation: {e}") + logger.warning( + f"Failed to clean up test conversation: {e}" + ) else: - print(f"Warning: Failed to clean up test conversation: {e}") + print( + f"Warning: Failed to clean up test conversation: {e}" + ) def test_import_availability() -> bool: """Test that Supabase imports are properly handled.""" print_test_header("Import Availability Test") - + if not SUPABASE_AVAILABLE: - print("✓ Import availability test passed - detected missing dependencies correctly") + print( + "✓ Import availability test passed - detected missing dependencies correctly" + ) return True - + # Test that all required classes are available - assert SupabaseConversation is not None, "SupabaseConversation should be available" - assert SupabaseConnectionError is not None, "SupabaseConnectionError should be available" - assert SupabaseOperationError is not None, "SupabaseOperationError should be available" + assert ( + SupabaseConversation is not None + ), "SupabaseConversation should be available" + assert ( + SupabaseConnectionError is not None + ), "SupabaseConnectionError should be available" + assert ( + SupabaseOperationError is not None + ), "SupabaseOperationError should be available" assert Message is not None, "Message should be available" assert MessageType is not None, "MessageType should be available" - - print("✓ Import availability test passed - all imports successful") + + print( + "✓ Import availability test passed - all imports successful" + ) return True def test_initialization() -> bool: """Test SupabaseConversation initialization.""" print_test_header("Initialization Test") - + if not SUPABASE_AVAILABLE: - print("✓ Initialization test skipped - Supabase not available") + print( + "✓ Initialization test skipped - Supabase not available" + ) return True - + conversation = setup_test_conversation() try: - assert conversation.supabase_url == TEST_SUPABASE_URL, "Supabase URL mismatch" - assert conversation.table_name == TEST_TABLE_NAME, "Table name mismatch" - assert conversation.current_conversation_id is not None, "Conversation ID should not be None" - assert conversation.client is not None, "Supabase client should not be None" - assert isinstance(conversation.get_conversation_id(), str), "Conversation ID should be string" - + assert ( + conversation.supabase_url == TEST_SUPABASE_URL + ), "Supabase URL mismatch" + assert ( + conversation.table_name == TEST_TABLE_NAME + ), "Table name mismatch" + assert ( + conversation.current_conversation_id is not None + ), "Conversation ID should not be None" + assert ( + conversation.client is not None + ), "Supabase client should not be None" + assert isinstance( + conversation.get_conversation_id(), str + ), "Conversation ID should be string" + # Test that initialization doesn't call super().__init__() improperly # This should not raise any errors print("✓ Initialization test passed") @@ -227,11 +276,13 @@ def test_initialization() -> bool: 
def test_logging_configuration() -> bool: """Test logging configuration options.""" print_test_header("Logging Configuration Test") - + if not SUPABASE_AVAILABLE: - print("✓ Logging configuration test skipped - Supabase not available") + print( + "✓ Logging configuration test skipped - Supabase not available" + ) return True - + # Test with logging enabled conversation_with_logging = SupabaseConversation( supabase_url=TEST_SUPABASE_URL, @@ -240,11 +291,15 @@ def test_logging_configuration() -> bool: enable_logging=True, use_loguru=False, # Force standard logging ) - + try: - assert conversation_with_logging.enable_logging == True, "Logging should be enabled" - assert conversation_with_logging.logger is not None, "Logger should be configured" - + assert ( + conversation_with_logging.enable_logging == True + ), "Logging should be enabled" + assert ( + conversation_with_logging.logger is not None + ), "Logger should be configured" + # Test with logging disabled conversation_no_logging = SupabaseConversation( supabase_url=TEST_SUPABASE_URL, @@ -252,9 +307,11 @@ def test_logging_configuration() -> bool: table_name=TEST_TABLE_NAME + "_no_log", enable_logging=False, ) - - assert conversation_no_logging.enable_logging == False, "Logging should be disabled" - + + assert ( + conversation_no_logging.enable_logging == False + ), "Logging should be disabled" + print("✓ Logging configuration test passed") return True finally: @@ -268,22 +325,24 @@ def test_logging_configuration() -> bool: def test_add_message() -> bool: """Test adding a single message.""" print_test_header("Add Message Test") - + if not SUPABASE_AVAILABLE: print("✓ Add message test skipped - Supabase not available") return True - + conversation = setup_test_conversation() try: msg_id = conversation.add( role="user", content="Hello, Supabase!", message_type=MessageType.USER, - metadata={"test": True} + metadata={"test": True}, ) assert msg_id is not None, "Message ID should not be None" - assert isinstance(msg_id, int), "Message ID should be an integer" - + assert isinstance( + msg_id, int + ), "Message ID should be an integer" + # Verify message was stored messages = conversation.get_messages() assert len(messages) >= 1, "Should have at least 1 message" @@ -296,19 +355,21 @@ def test_add_message() -> bool: def test_add_complex_message() -> bool: """Test adding a message with complex content.""" print_test_header("Add Complex Message Test") - + if not SUPABASE_AVAILABLE: - print("✓ Add complex message test skipped - Supabase not available") + print( + "✓ Add complex message test skipped - Supabase not available" + ) return True - + conversation = setup_test_conversation() try: complex_content = { "text": "Hello from Supabase", "data": [1, 2, 3, {"nested": "value"}], - "metadata": {"source": "test", "priority": "high"} + "metadata": {"source": "test", "priority": "high"}, } - + msg_id = conversation.add( role="assistant", content=complex_content, @@ -316,19 +377,23 @@ def test_add_complex_message() -> bool: metadata={ "model": "test-model", "temperature": 0.7, - "tokens": 42 + "tokens": 42, }, - token_count=42 + token_count=42, ) - + assert msg_id is not None, "Message ID should not be None" - + # Verify complex content was stored and retrieved correctly message = conversation.query(str(msg_id)) assert message is not None, "Message should be retrievable" - assert message["content"] == complex_content, "Complex content should match" - assert message["token_count"] == 42, "Token count should match" - + assert ( + message["content"] == 
complex_content + ), "Complex content should match" + assert ( + message["token_count"] == 42 + ), "Token count should match" + print("✓ Add complex message test passed") return True finally: @@ -338,11 +403,11 @@ def test_add_complex_message() -> bool: def test_batch_add() -> bool: """Test batch adding messages.""" print_test_header("Batch Add Test") - + if not SUPABASE_AVAILABLE: print("✓ Batch add test skipped - Supabase not available") return True - + conversation = setup_test_conversation() try: messages = [ @@ -350,30 +415,44 @@ def test_batch_add() -> bool: role="user", content="First batch message", message_type=MessageType.USER, - metadata={"batch": 1} + metadata={"batch": 1}, ), Message( role="assistant", - content={"response": "First response", "confidence": 0.9}, + content={ + "response": "First response", + "confidence": 0.9, + }, message_type=MessageType.ASSISTANT, - metadata={"batch": 1} + metadata={"batch": 1}, ), Message( role="user", content="Second batch message", message_type=MessageType.USER, - metadata={"batch": 2} - ) + metadata={"batch": 2}, + ), ] - + msg_ids = conversation.batch_add(messages) assert len(msg_ids) == 3, "Should have 3 message IDs" - assert all(isinstance(id, int) for id in msg_ids), "All IDs should be integers" - + assert all( + isinstance(id, int) for id in msg_ids + ), "All IDs should be integers" + # Verify messages were stored all_messages = conversation.get_messages() - assert len([m for m in all_messages if m.get("metadata", {}).get("batch")]) == 3, "Should find 3 batch messages" - + assert ( + len( + [ + m + for m in all_messages + if m.get("metadata", {}).get("batch") + ] + ) + == 3 + ), "Should find 3 batch messages" + print("✓ Batch add test passed") return True finally: @@ -383,20 +462,24 @@ def test_batch_add() -> bool: def test_get_str() -> bool: """Test getting conversation as string.""" print_test_header("Get String Test") - + if not SUPABASE_AVAILABLE: print("✓ Get string test skipped - Supabase not available") return True - + conversation = setup_test_conversation() try: conversation.add("user", "Hello!") conversation.add("assistant", "Hi there!") - + conv_str = conversation.get_str() - assert "user: Hello!" in conv_str, "User message not found in string" - assert "assistant: Hi there!" in conv_str, "Assistant message not found in string" - + assert ( + "user: Hello!" in conv_str + ), "User message not found in string" + assert ( + "assistant: Hi there!" 
in conv_str + ), "Assistant message not found in string" + print("✓ Get string test passed") return True finally: @@ -406,11 +489,11 @@ def test_get_str() -> bool: def test_get_messages() -> bool: """Test getting messages with pagination.""" print_test_header("Get Messages Test") - + if not SUPABASE_AVAILABLE: print("✓ Get messages test skipped - Supabase not available") return True - + conversation = setup_test_conversation() try: # Add multiple messages @@ -419,15 +502,21 @@ def test_get_messages() -> bool: # Test getting all messages all_messages = conversation.get_messages() - assert len(all_messages) >= 5, "Should have at least 5 messages" + assert ( + len(all_messages) >= 5 + ), "Should have at least 5 messages" # Test pagination limited_messages = conversation.get_messages(limit=2) - assert len(limited_messages) == 2, "Should have 2 limited messages" + assert ( + len(limited_messages) == 2 + ), "Should have 2 limited messages" offset_messages = conversation.get_messages(offset=2, limit=2) - assert len(offset_messages) == 2, "Should have 2 offset messages" - + assert ( + len(offset_messages) == 2 + ), "Should have 2 offset messages" + print("✓ Get messages test passed") return True finally: @@ -437,11 +526,13 @@ def test_get_messages() -> bool: def test_search_messages() -> bool: """Test searching messages.""" print_test_header("Search Messages Test") - + if not SUPABASE_AVAILABLE: - print("✓ Search messages test skipped - Supabase not available") + print( + "✓ Search messages test skipped - Supabase not available" + ) return True - + conversation = setup_test_conversation() try: conversation.add("user", "Hello world from Supabase") @@ -451,14 +542,20 @@ def test_search_messages() -> bool: # Test search functionality world_results = conversation.search("world") - assert len(world_results) >= 2, "Should find at least 2 messages with 'world'" - + assert ( + len(world_results) >= 2 + ), "Should find at least 2 messages with 'world'" + hello_results = conversation.search("Hello") - assert len(hello_results) >= 2, "Should find at least 2 messages with 'Hello'" - + assert ( + len(hello_results) >= 2 + ), "Should find at least 2 messages with 'Hello'" + supabase_results = conversation.search("Supabase") - assert len(supabase_results) >= 1, "Should find at least 1 message with 'Supabase'" - + assert ( + len(supabase_results) >= 1 + ), "Should find at least 1 message with 'Supabase'" + print("✓ Search messages test passed") return True finally: @@ -468,33 +565,37 @@ def test_search_messages() -> bool: def test_update_and_delete() -> bool: """Test updating and deleting messages.""" print_test_header("Update and Delete Test") - + if not SUPABASE_AVAILABLE: - print("✓ Update and delete test skipped - Supabase not available") + print( + "✓ Update and delete test skipped - Supabase not available" + ) return True - + conversation = setup_test_conversation() try: # Add a message to update/delete msg_id = conversation.add("user", "Original message") - + # Test update method (BaseCommunication signature) conversation.update( - index=str(msg_id), - role="user", - content="Updated message" + index=str(msg_id), role="user", content="Updated message" ) - + updated_msg = conversation.query_optional(str(msg_id)) - assert updated_msg is not None, "Message should exist after update" - assert updated_msg["content"] == "Updated message", "Message should be updated" - + assert ( + updated_msg is not None + ), "Message should exist after update" + assert ( + updated_msg["content"] == "Updated message" + ), 
"Message should be updated" + # Test delete conversation.delete(str(msg_id)) - + deleted_msg = conversation.query_optional(str(msg_id)) assert deleted_msg is None, "Message should be deleted" - + print("✓ Update and delete test passed") return True finally: @@ -504,43 +605,55 @@ def test_update_and_delete() -> bool: def test_update_message_method() -> bool: """Test the new update_message method.""" print_test_header("Update Message Method Test") - + if not SUPABASE_AVAILABLE: - print("✓ Update message method test skipped - Supabase not available") + print( + "✓ Update message method test skipped - Supabase not available" + ) return True - + conversation = setup_test_conversation() try: # Add a message to update msg_id = conversation.add( - role="user", + role="user", content="Original content", - metadata={"version": 1} + metadata={"version": 1}, ) - + # Test update_message method success = conversation.update_message( message_id=msg_id, content="Updated content via update_message", - metadata={"version": 2, "updated": True} + metadata={"version": 2, "updated": True}, ) - - assert success == True, "update_message should return True on success" - + + assert ( + success == True + ), "update_message should return True on success" + # Verify the update updated_msg = conversation.query(str(msg_id)) assert updated_msg is not None, "Message should still exist" - assert updated_msg["content"] == "Updated content via update_message", "Content should be updated" - assert updated_msg["metadata"]["version"] == 2, "Metadata should be updated" - assert updated_msg["metadata"]["updated"] == True, "New metadata field should be added" - + assert ( + updated_msg["content"] + == "Updated content via update_message" + ), "Content should be updated" + assert ( + updated_msg["metadata"]["version"] == 2 + ), "Metadata should be updated" + assert ( + updated_msg["metadata"]["updated"] == True + ), "New metadata field should be added" + # Test update_message with non-existent ID failure = conversation.update_message( - message_id=999999, - content="This should fail" + message_id=999999, content="This should fail" ) - assert failure == False, "update_message should return False for non-existent message" - + assert ( + failure == False + ), "update_message should return False for non-existent message" + print("✓ Update message method test passed") return True finally: @@ -550,30 +663,46 @@ def test_update_message_method() -> bool: def test_conversation_statistics() -> bool: """Test getting conversation statistics.""" print_test_header("Conversation Statistics Test") - + if not SUPABASE_AVAILABLE: - print("✓ Conversation statistics test skipped - Supabase not available") + print( + "✓ Conversation statistics test skipped - Supabase not available" + ) return True - + conversation = setup_test_conversation() try: # Add messages with different roles and token counts conversation.add("user", "Hello", token_count=2) conversation.add("assistant", "Hi there!", token_count=3) conversation.add("system", "System message", token_count=5) - conversation.add("user", "Another user message", token_count=4) + conversation.add( + "user", "Another user message", token_count=4 + ) stats = conversation.get_conversation_summary() - assert stats["total_messages"] >= 4, "Should have at least 4 messages" - assert stats["unique_roles"] >= 3, "Should have at least 3 unique roles" - assert stats["total_tokens"] >= 14, "Should have at least 14 total tokens" - + assert ( + stats["total_messages"] >= 4 + ), "Should have at least 4 messages" + 
assert ( + stats["unique_roles"] >= 3 + ), "Should have at least 3 unique roles" + assert ( + stats["total_tokens"] >= 14 + ), "Should have at least 14 total tokens" + # Test role counting role_counts = conversation.count_messages_by_role() - assert role_counts.get("user", 0) >= 2, "Should have at least 2 user messages" - assert role_counts.get("assistant", 0) >= 1, "Should have at least 1 assistant message" - assert role_counts.get("system", 0) >= 1, "Should have at least 1 system message" - + assert ( + role_counts.get("user", 0) >= 2 + ), "Should have at least 2 user messages" + assert ( + role_counts.get("assistant", 0) >= 1 + ), "Should have at least 1 assistant message" + assert ( + role_counts.get("system", 0) >= 1 + ), "Should have at least 1 system message" + print("✓ Conversation statistics test passed") return True finally: @@ -583,38 +712,53 @@ def test_conversation_statistics() -> bool: def test_json_operations() -> bool: """Test JSON save and load operations.""" print_test_header("JSON Operations Test") - + if not SUPABASE_AVAILABLE: - print("✓ JSON operations test skipped - Supabase not available") + print( + "✓ JSON operations test skipped - Supabase not available" + ) return True - + conversation = setup_test_conversation() json_file = "test_conversation.json" - + try: # Add test messages conversation.add("user", "Test message for JSON") - conversation.add("assistant", {"response": "JSON test response", "data": [1, 2, 3]}) - + conversation.add( + "assistant", + {"response": "JSON test response", "data": [1, 2, 3]}, + ) + # Test JSON export conversation.save_as_json(json_file) - assert os.path.exists(json_file), "JSON file should be created" - + assert os.path.exists( + json_file + ), "JSON file should be created" + # Verify JSON content - with open(json_file, 'r') as f: + with open(json_file, "r") as f: json_data = json.load(f) - assert isinstance(json_data, list), "JSON data should be a list" - assert len(json_data) >= 2, "Should have at least 2 messages in JSON" - + assert isinstance( + json_data, list + ), "JSON data should be a list" + assert ( + len(json_data) >= 2 + ), "Should have at least 2 messages in JSON" + # Test JSON import (creates new conversation) original_conv_id = conversation.get_conversation_id() conversation.load_from_json(json_file) new_conv_id = conversation.get_conversation_id() - assert new_conv_id != original_conv_id, "Should create new conversation on import" - + assert ( + new_conv_id != original_conv_id + ), "Should create new conversation on import" + imported_messages = conversation.get_messages() - assert len(imported_messages) >= 2, "Should have imported messages" - + assert ( + len(imported_messages) >= 2 + ), "Should have imported messages" + print("✓ JSON operations test passed") return True finally: @@ -627,32 +771,40 @@ def test_json_operations() -> bool: def test_yaml_operations() -> bool: """Test YAML save and load operations.""" print_test_header("YAML Operations Test") - + if not SUPABASE_AVAILABLE: - print("✓ YAML operations test skipped - Supabase not available") + print( + "✓ YAML operations test skipped - Supabase not available" + ) return True - + conversation = setup_test_conversation() yaml_file = "test_conversation.yaml" - + try: # Add test messages conversation.add("user", "Test message for YAML") conversation.add("assistant", "YAML test response") - + # Test YAML export conversation.save_as_yaml(yaml_file) - assert os.path.exists(yaml_file), "YAML file should be created" - + assert os.path.exists( + yaml_file + ), "YAML 
file should be created" + # Test YAML import (creates new conversation) original_conv_id = conversation.get_conversation_id() conversation.load_from_yaml(yaml_file) new_conv_id = conversation.get_conversation_id() - assert new_conv_id != original_conv_id, "Should create new conversation on import" - + assert ( + new_conv_id != original_conv_id + ), "Should create new conversation on import" + imported_messages = conversation.get_messages() - assert len(imported_messages) >= 2, "Should have imported messages" - + assert ( + len(imported_messages) >= 2 + ), "Should have imported messages" + print("✓ YAML operations test passed") return True finally: @@ -665,11 +817,11 @@ def test_yaml_operations() -> bool: def test_message_types() -> bool: """Test different message types.""" print_test_header("Message Types Test") - + if not SUPABASE_AVAILABLE: print("✓ Message types test skipped - Supabase not available") return True - + conversation = setup_test_conversation() try: # Test all message types @@ -678,23 +830,33 @@ def test_message_types() -> bool: (MessageType.ASSISTANT, "assistant"), (MessageType.SYSTEM, "system"), (MessageType.FUNCTION, "function"), - (MessageType.TOOL, "tool") + (MessageType.TOOL, "tool"), ] - + for msg_type, role in types_to_test: msg_id = conversation.add( role=role, content=f"Test {msg_type.value} message", - message_type=msg_type + message_type=msg_type, ) - assert msg_id is not None, f"Should create {msg_type.value} message" - + assert ( + msg_id is not None + ), f"Should create {msg_type.value} message" + # Verify all message types were stored messages = conversation.get_messages() - stored_types = {msg.get("message_type") for msg in messages if msg.get("message_type")} - expected_types = {msg_type.value for msg_type, _ in types_to_test} - assert stored_types.issuperset(expected_types), "Should store all message types" - + stored_types = { + msg.get("message_type") + for msg in messages + if msg.get("message_type") + } + expected_types = { + msg_type.value for msg_type, _ in types_to_test + } + assert stored_types.issuperset( + expected_types + ), "Should store all message types" + print("✓ Message types test passed") return True finally: @@ -704,42 +866,60 @@ def test_message_types() -> bool: def test_conversation_management() -> bool: """Test conversation management operations.""" print_test_header("Conversation Management Test") - + if not SUPABASE_AVAILABLE: - print("✓ Conversation management test skipped - Supabase not available") + print( + "✓ Conversation management test skipped - Supabase not available" + ) return True - + conversation = setup_test_conversation() try: # Test getting conversation ID conv_id = conversation.get_conversation_id() assert conv_id, "Should have a conversation ID" - assert isinstance(conv_id, str), "Conversation ID should be a string" - + assert isinstance( + conv_id, str + ), "Conversation ID should be a string" + # Add some messages conversation.add("user", "First conversation message") - conversation.add("assistant", "Response in first conversation") - + conversation.add( + "assistant", "Response in first conversation" + ) + first_conv_messages = len(conversation.get_messages()) - assert first_conv_messages >= 2, "Should have messages in first conversation" - + assert ( + first_conv_messages >= 2 + ), "Should have messages in first conversation" + # Start new conversation new_conv_id = conversation.start_new_conversation() - assert new_conv_id != conv_id, "New conversation should have different ID" - assert 
conversation.get_conversation_id() == new_conv_id, "Should switch to new conversation" - assert isinstance(new_conv_id, str), "New conversation ID should be a string" - + assert ( + new_conv_id != conv_id + ), "New conversation should have different ID" + assert ( + conversation.get_conversation_id() == new_conv_id + ), "Should switch to new conversation" + assert isinstance( + new_conv_id, str + ), "New conversation ID should be a string" + # Verify new conversation is empty (except any system prompts) new_messages = conversation.get_messages() conversation.add("user", "New conversation message") updated_messages = conversation.get_messages() - assert len(updated_messages) > len(new_messages), "Should be able to add to new conversation" - + assert len(updated_messages) > len( + new_messages + ), "Should be able to add to new conversation" + # Test clear conversation conversation.clear() cleared_messages = conversation.get_messages() - assert len(cleared_messages) == 0, "Conversation should be cleared" - + assert ( + len(cleared_messages) == 0 + ), "Conversation should be cleared" + print("✓ Conversation management test passed") return True finally: @@ -749,11 +929,13 @@ def test_conversation_management() -> bool: def test_get_messages_by_role() -> bool: """Test getting messages filtered by role.""" print_test_header("Get Messages by Role Test") - + if not SUPABASE_AVAILABLE: - print("✓ Get messages by role test skipped - Supabase not available") + print( + "✓ Get messages by role test skipped - Supabase not available" + ) return True - + conversation = setup_test_conversation() try: # Add messages with different roles @@ -762,20 +944,34 @@ def test_get_messages_by_role() -> bool: conversation.add("user", "User message 2") conversation.add("system", "System message") conversation.add("assistant", "Assistant message 2") - + # Test filtering by role user_messages = conversation.get_messages_by_role("user") - assert len(user_messages) >= 2, "Should have at least 2 user messages" - assert all(msg["role"] == "user" for msg in user_messages), "All messages should be from user" - - assistant_messages = conversation.get_messages_by_role("assistant") - assert len(assistant_messages) >= 2, "Should have at least 2 assistant messages" - assert all(msg["role"] == "assistant" for msg in assistant_messages), "All messages should be from assistant" - + assert ( + len(user_messages) >= 2 + ), "Should have at least 2 user messages" + assert all( + msg["role"] == "user" for msg in user_messages + ), "All messages should be from user" + + assistant_messages = conversation.get_messages_by_role( + "assistant" + ) + assert ( + len(assistant_messages) >= 2 + ), "Should have at least 2 assistant messages" + assert all( + msg["role"] == "assistant" for msg in assistant_messages + ), "All messages should be from assistant" + system_messages = conversation.get_messages_by_role("system") - assert len(system_messages) >= 1, "Should have at least 1 system message" - assert all(msg["role"] == "system" for msg in system_messages), "All messages should be from system" - + assert ( + len(system_messages) >= 1 + ), "Should have at least 1 system message" + assert all( + msg["role"] == "system" for msg in system_messages + ), "All messages should be from system" + print("✓ Get messages by role test passed") return True finally: @@ -785,35 +981,47 @@ def test_get_messages_by_role() -> bool: def test_timeline_and_organization() -> bool: """Test conversation timeline and organization features.""" print_test_header("Timeline and 
Organization Test") - + if not SUPABASE_AVAILABLE: - print("✓ Timeline and organization test skipped - Supabase not available") + print( + "✓ Timeline and organization test skipped - Supabase not available" + ) return True - + conversation = setup_test_conversation() try: # Add messages conversation.add("user", "Timeline test message 1") conversation.add("assistant", "Timeline test response 1") conversation.add("user", "Timeline test message 2") - + # Test timeline organization timeline = conversation.get_conversation_timeline_dict() - assert isinstance(timeline, dict), "Timeline should be a dictionary" + assert isinstance( + timeline, dict + ), "Timeline should be a dictionary" assert len(timeline) > 0, "Timeline should have entries" - + # Test organization by role by_role = conversation.get_conversation_by_role_dict() - assert isinstance(by_role, dict), "Role organization should be a dictionary" + assert isinstance( + by_role, dict + ), "Role organization should be a dictionary" assert "user" in by_role, "Should have user messages" - assert "assistant" in by_role, "Should have assistant messages" - + assert ( + "assistant" in by_role + ), "Should have assistant messages" + # Test conversation as dict conv_dict = conversation.get_conversation_as_dict() - assert isinstance(conv_dict, dict), "Conversation dict should be a dictionary" - assert "conversation_id" in conv_dict, "Should have conversation ID" + assert isinstance( + conv_dict, dict + ), "Conversation dict should be a dictionary" + assert ( + "conversation_id" in conv_dict + ), "Should have conversation ID" assert "messages" in conv_dict, "Should have messages" - + print("✓ Timeline and organization test passed") return True finally: @@ -823,14 +1031,16 @@ def test_timeline_and_organization() -> bool: def test_concurrent_operations() -> bool: """Test concurrent operations for thread safety.""" print_test_header("Concurrent Operations Test") - + if not SUPABASE_AVAILABLE: - print("✓ Concurrent operations test skipped - Supabase not available") + print( + "✓ Concurrent operations test skipped - Supabase not available" + ) return True - + conversation = setup_test_conversation() results = [] - + def add_messages(thread_id): """Add messages in a separate thread.""" try: @@ -838,12 +1048,15 @@ def test_concurrent_operations() -> bool: msg_id = conversation.add( role="user", content=f"Thread {thread_id} message {i}", - metadata={"thread_id": thread_id, "message_num": i} + metadata={ + "thread_id": thread_id, + "message_num": i, + }, ) results.append(("success", thread_id, msg_id)) except Exception as e: results.append(("error", thread_id, str(e))) - + try: # Create and start multiple threads threads = [] @@ -851,20 +1064,30 @@ def test_concurrent_operations() -> bool: thread = threading.Thread(target=add_messages, args=(i,)) threads.append(thread) thread.start() - + # Wait for all threads to complete for thread in threads: thread.join() - + # Check results - successful_operations = [r for r in results if r[0] == "success"] - assert len(successful_operations) >= 6, "Should have successful concurrent operations" - + successful_operations = [ + r for r in results if r[0] == "success" + ] + assert ( + len(successful_operations) >= 6 + ), "Should have successful concurrent operations" + # Verify messages were actually stored all_messages = conversation.get_messages() - thread_messages = [m for m in all_messages if m.get("metadata", {}).get("thread_id") is not None] - assert len(thread_messages) >= 6, "Should have stored concurrent messages" - 
+ thread_messages = [ + m + for m in all_messages + if m.get("metadata", {}).get("thread_id") is not None + ] + assert ( + len(thread_messages) >= 6 + ), "Should have stored concurrent messages" + print("✓ Concurrent operations test passed") return True finally: @@ -874,60 +1097,86 @@ def test_concurrent_operations() -> bool: def test_enhanced_error_handling() -> bool: """Test enhanced error handling for various edge cases.""" print_test_header("Enhanced Error Handling Test") - + if not SUPABASE_AVAILABLE: - print("✓ Enhanced error handling test skipped - Supabase not available") + print( + "✓ Enhanced error handling test skipped - Supabase not available" + ) return True - + # Test invalid credentials try: invalid_conversation = SupabaseConversation( supabase_url="https://invalid-url.supabase.co", supabase_key="invalid_key", - enable_logging=False + enable_logging=False, ) # This should raise an exception during initialization assert False, "Should raise exception for invalid credentials" except (SupabaseConnectionError, Exception): pass # Expected behavior - + # Test with valid conversation conversation = setup_test_conversation() try: # Test querying non-existent message with query (should return empty dict) non_existent = conversation.query("999999") - assert non_existent == {}, "Non-existent message should return empty dict" - + assert ( + non_existent == {} + ), "Non-existent message should return empty dict" + # Test querying non-existent message with query_optional (should return None) non_existent_opt = conversation.query_optional("999999") - assert non_existent_opt is None, "Non-existent message should return None with query_optional" - + assert ( + non_existent_opt is None + ), "Non-existent message should return None with query_optional" + # Test deleting non-existent message (should not raise exception) conversation.delete("999999") # Should handle gracefully - + # Test updating non-existent message (should return False) - update_result = conversation._update_flexible("999999", "user", "content") - assert update_result == False, "_update_flexible should return False for invalid ID" - + update_result = conversation._update_flexible( + "999999", "user", "content" + ) + assert ( + update_result == False + ), "_update_flexible should return False for invalid ID" + # Test update_message with invalid ID - result = conversation.update_message(999999, "invalid content") - assert result == False, "update_message should return False for invalid ID" - + result = conversation.update_message( + 999999, "invalid content" + ) + assert ( + result == False + ), "update_message should return False for invalid ID" + # Test search with empty query empty_results = conversation.search("") - assert isinstance(empty_results, list), "Empty search should return list" - + assert isinstance( + empty_results, list + ), "Empty search should return list" + # Test invalid message ID formats (should return empty dict now) invalid_query = conversation.query("not_a_number") - assert invalid_query == {}, "Invalid ID should return empty dict" - - invalid_query_opt = conversation.query_optional("not_a_number") - assert invalid_query_opt is None, "Invalid ID should return None with query_optional" - + assert ( + invalid_query == {} + ), "Invalid ID should return empty dict" + + invalid_query_opt = conversation.query_optional( + "not_a_number" + ) + assert ( + invalid_query_opt is None + ), "Invalid ID should return None with query_optional" + # Test update with invalid ID (should return False) - invalid_update = 
conversation._update_flexible("not_a_number", "user", "content") - assert invalid_update == False, "Invalid ID should return False for update" - + invalid_update = conversation._update_flexible( + "not_a_number", "user", "content" + ) + assert ( + invalid_update == False + ), "Invalid ID should return False for update" + print("✓ Enhanced error handling test passed") return True finally: @@ -937,60 +1186,80 @@ def test_enhanced_error_handling() -> bool: def test_fallback_functionality() -> bool: """Test fallback functionality when dependencies are missing.""" print_test_header("Fallback Functionality Test") - + # This test always passes as it tests the import fallback mechanism if not SUPABASE_AVAILABLE: - print("✓ Fallback functionality test passed - gracefully handled missing dependencies") + print( + "✓ Fallback functionality test passed - gracefully handled missing dependencies" + ) return True else: - print("✓ Fallback functionality test passed - dependencies available, no fallback needed") + print( + "✓ Fallback functionality test passed - dependencies available, no fallback needed" + ) return True -def generate_test_report(test_results: List[Dict[str, Any]]) -> Dict[str, Any]: +def generate_test_report( + test_results: List[Dict[str, Any]], +) -> Dict[str, Any]: """Generate a comprehensive test report.""" total_tests = len(test_results) - passed_tests = sum(1 for result in test_results if result["success"]) + passed_tests = sum( + 1 for result in test_results if result["success"] + ) failed_tests = total_tests - passed_tests - - total_time = sum(result["execution_time"] for result in test_results) + + total_time = sum( + result["execution_time"] for result in test_results + ) avg_time = total_time / total_tests if total_tests > 0 else 0 - + report = { "summary": { "total_tests": total_tests, "passed_tests": passed_tests, "failed_tests": failed_tests, - "success_rate": (passed_tests / total_tests * 100) if total_tests > 0 else 0, + "success_rate": ( + (passed_tests / total_tests * 100) + if total_tests > 0 + else 0 + ), "total_execution_time": total_time, "average_execution_time": avg_time, "timestamp": datetime.datetime.now().isoformat(), "supabase_available": SUPABASE_AVAILABLE, - "environment_configured": bool(TEST_SUPABASE_URL and TEST_SUPABASE_KEY), + "environment_configured": bool( + TEST_SUPABASE_URL and TEST_SUPABASE_KEY + ), }, "test_results": test_results, "failed_tests": [ result for result in test_results if not result["success"] ], } - + return report def run_all_tests() -> None: """Run all SupabaseConversation tests.""" print("🚀 Starting Enhanced SupabaseConversation Test Suite") - print(f"Supabase Available: {'✅' if SUPABASE_AVAILABLE else '❌'}") - + print( + f"Supabase Available: {'✅' if SUPABASE_AVAILABLE else '❌'}" + ) + if TEST_SUPABASE_URL and TEST_SUPABASE_KEY: print(f"Using Supabase URL: {TEST_SUPABASE_URL[:30]}...") print(f"Using table: {TEST_TABLE_NAME}") else: - print("❌ Environment variables SUPABASE_URL and SUPABASE_KEY not set") + print( + "❌ Environment variables SUPABASE_URL and SUPABASE_KEY not set" + ) print("Some tests will be skipped") - + print("=" * 60) - + # Define tests to run tests = [ ("Import Availability", test_import_availability), @@ -1015,14 +1284,14 @@ def run_all_tests() -> None: ("Concurrent Operations", test_concurrent_operations), ("Enhanced Error Handling", test_enhanced_error_handling), ] - + test_results = [] - + # Run each test for test_name, test_func in tests: print_test_header(test_name) success, message, execution_time = 
run_test(test_func) - + test_result = { "test_name": test_name, "success": success, @@ -1030,12 +1299,12 @@ def run_all_tests() -> None: "execution_time": execution_time, } test_results.append(test_result) - + print_test_result(test_name, success, message, execution_time) - + # Generate and display report report = generate_test_report(test_results) - + print("\n" + "=" * 60) print("📊 ENHANCED TEST SUMMARY") print("=" * 60) @@ -1043,33 +1312,43 @@ def run_all_tests() -> None: print(f"Passed: {report['summary']['passed_tests']}") print(f"Failed: {report['summary']['failed_tests']}") print(f"Success Rate: {report['summary']['success_rate']:.1f}%") - print(f"Total Time: {report['summary']['total_execution_time']:.3f} seconds") - print(f"Average Time: {report['summary']['average_execution_time']:.3f} seconds") - print(f"Supabase Available: {'✅' if report['summary']['supabase_available'] else '❌'}") - print(f"Environment Configured: {'✅' if report['summary']['environment_configured'] else '❌'}") - - if report['failed_tests']: + print( + f"Total Time: {report['summary']['total_execution_time']:.3f} seconds" + ) + print( + f"Average Time: {report['summary']['average_execution_time']:.3f} seconds" + ) + print( + f"Supabase Available: {'✅' if report['summary']['supabase_available'] else '❌'}" + ) + print( + f"Environment Configured: {'✅' if report['summary']['environment_configured'] else '❌'}" + ) + + if report["failed_tests"]: print("\n❌ FAILED TESTS:") - for failed_test in report['failed_tests']: - print(f" - {failed_test['test_name']}: {failed_test['message']}") + for failed_test in report["failed_tests"]: + print( + f" - {failed_test['test_name']}: {failed_test['message']}" + ) else: print("\n✅ All tests passed!") - + # Additional information if not SUPABASE_AVAILABLE: print("\n🔍 NOTE: Supabase dependencies not available.") print("Install with: pip install supabase") - + if not (TEST_SUPABASE_URL and TEST_SUPABASE_KEY): print("\n🔍 NOTE: Environment variables not set.") print("Set SUPABASE_URL and SUPABASE_KEY to run full tests.") - + # Save detailed report report_file = f"supabase_test_report_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.json" - with open(report_file, 'w') as f: + with open(report_file, "w") as f: json.dump(report, f, indent=2, default=str) print(f"\n📄 Detailed report saved to: {report_file}") if __name__ == "__main__": - run_all_tests() \ No newline at end of file + run_all_tests()
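Two usage notes on the changes above, as hedged sketches rather than authoritative examples.

First, with the `swarm_router` convenience function removed from the public API, callers construct the class directly and call `run`, which mirrors what the deleted helper did internally:

```python
from swarms.structs.swarm_router import SwarmRouter

agents: list = []  # populate with Agent instances, as previously passed to swarm_router(...)

router = SwarmRouter(
    name="swarm-router",
    agents=agents,
    swarm_type="SequentialWorkflow",
    max_loops=1,
)
result = router.run("Summarize the incident report")
```

Second, a minimal end-to-end exercise of the Supabase wrapper consistent with the test suite, assuming SUPABASE_URL and SUPABASE_KEY are set and the backing table exists:

```python
import os

from swarms.communication.supabase_wrap import SupabaseConversation
from swarms.communication.base_communication import MessageType

conversation = SupabaseConversation(
    supabase_url=os.environ["SUPABASE_URL"],
    supabase_key=os.environ["SUPABASE_KEY"],
    table_name="conversations_test",
)
conversation.add(
    role="user",
    content="Hello, Supabase!",
    message_type=MessageType.USER,
    metadata={"demo": True},
)
print(conversation.get_str())                  # user: Hello, Supabase!
print(conversation.get_conversation_summary())
conversation.delete_current_conversation()     # drop the demo rows
```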