communication structures

pull/812/merge
Kye Gomez 2 days ago
parent 26033cefd1
commit a5154aa26f

@ -5,7 +5,7 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry] [tool.poetry]
name = "swarms" name = "swarms"
version = "7.7.3" version = "7.7.5"
description = "Swarms - TGSC" description = "Swarms - TGSC"
license = "MIT" license = "MIT"
authors = ["Kye Gomez <kye@apac.ai>"] authors = ["Kye Gomez <kye@apac.ai>"]

File diff suppressed because it is too large Load Diff

@ -0,0 +1,813 @@
import sqlite3
import json
import datetime
from typing import List, Optional, Union, Dict
from pathlib import Path
import threading
from contextlib import contextmanager
import logging
from dataclasses import dataclass
from enum import Enum
import uuid
import yaml
try:
from loguru import logger
LOGURU_AVAILABLE = True
except ImportError:
LOGURU_AVAILABLE = False
class MessageType(Enum):
"""Enum for different types of messages in the conversation."""
SYSTEM = "system"
USER = "user"
ASSISTANT = "assistant"
FUNCTION = "function"
TOOL = "tool"
@dataclass
class Message:
"""Data class representing a message in the conversation."""
role: str
content: Union[str, dict, list]
timestamp: Optional[str] = None
message_type: Optional[MessageType] = None
metadata: Optional[Dict] = None
token_count: Optional[int] = None
class Config:
arbitrary_types_allowed = True
class SQLiteConversation:
"""
A production-grade SQLite wrapper class for managing conversation history.
This class provides persistent storage for conversations with various features
like message tracking, timestamps, and metadata support.
Attributes:
db_path (str): Path to the SQLite database file
table_name (str): Name of the table to store conversations
enable_timestamps (bool): Whether to track message timestamps
enable_logging (bool): Whether to enable logging
use_loguru (bool): Whether to use loguru for logging
max_retries (int): Maximum number of retries for database operations
connection_timeout (float): Timeout for database connections
current_conversation_id (str): Current active conversation ID
"""
def __init__(
self,
db_path: str = "conversations.db",
table_name: str = "conversations",
enable_timestamps: bool = True,
enable_logging: bool = True,
use_loguru: bool = True,
max_retries: int = 3,
connection_timeout: float = 5.0,
**kwargs,
):
"""
Initialize the SQLite conversation manager.
Args:
db_path (str): Path to the SQLite database file
table_name (str): Name of the table to store conversations
enable_timestamps (bool): Whether to track message timestamps
enable_logging (bool): Whether to enable logging
use_loguru (bool): Whether to use loguru for logging
max_retries (int): Maximum number of retries for database operations
connection_timeout (float): Timeout for database connections
"""
self.db_path = Path(db_path)
self.table_name = table_name
self.enable_timestamps = enable_timestamps
self.enable_logging = enable_logging
self.use_loguru = use_loguru and LOGURU_AVAILABLE
self.max_retries = max_retries
self.connection_timeout = connection_timeout
self._lock = threading.Lock()
self.current_conversation_id = (
self._generate_conversation_id()
)
# Setup logging
if self.enable_logging:
if self.use_loguru:
self.logger = logger
else:
self.logger = logging.getLogger(__name__)
handler = logging.StreamHandler()
formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
handler.setFormatter(formatter)
self.logger.addHandler(handler)
self.logger.setLevel(logging.INFO)
# Initialize database
self._init_db()
def _generate_conversation_id(self) -> str:
"""Generate a unique conversation ID using UUID and timestamp."""
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
unique_id = str(uuid.uuid4())[:8]
return f"conv_{timestamp}_{unique_id}"
def start_new_conversation(self) -> str:
"""
Start a new conversation and return its ID.
Returns:
str: The new conversation ID
"""
self.current_conversation_id = (
self._generate_conversation_id()
)
return self.current_conversation_id
def _init_db(self):
"""Initialize the database and create necessary tables."""
with self._get_connection() as conn:
cursor = conn.cursor()
cursor.execute(
f"""
CREATE TABLE IF NOT EXISTS {self.table_name} (
id INTEGER PRIMARY KEY AUTOINCREMENT,
role TEXT NOT NULL,
content TEXT NOT NULL,
timestamp TEXT,
message_type TEXT,
metadata TEXT,
token_count INTEGER,
conversation_id TEXT,
created_at TEXT DEFAULT CURRENT_TIMESTAMP
)
"""
)
conn.commit()
@contextmanager
def _get_connection(self):
"""Context manager for database connections with retry logic."""
conn = None
for attempt in range(self.max_retries):
try:
conn = sqlite3.connect(
str(self.db_path), timeout=self.connection_timeout
)
conn.row_factory = sqlite3.Row
yield conn
break
except sqlite3.Error as e:
if attempt == self.max_retries - 1:
raise
if self.enable_logging:
self.logger.warning(
f"Database connection attempt {attempt + 1} failed: {e}"
)
finally:
if conn:
conn.close()
def add(
self,
role: str,
content: Union[str, dict, list],
message_type: Optional[MessageType] = None,
metadata: Optional[Dict] = None,
token_count: Optional[int] = None,
) -> int:
"""
Add a message to the current conversation.
Args:
role (str): The role of the speaker
content (Union[str, dict, list]): The content of the message
message_type (Optional[MessageType]): Type of the message
metadata (Optional[Dict]): Additional metadata for the message
token_count (Optional[int]): Number of tokens in the message
Returns:
int: The ID of the inserted message
"""
timestamp = (
datetime.datetime.now().isoformat()
if self.enable_timestamps
else None
)
if isinstance(content, (dict, list)):
content = json.dumps(content)
with self._get_connection() as conn:
cursor = conn.cursor()
cursor.execute(
f"""
INSERT INTO {self.table_name}
(role, content, timestamp, message_type, metadata, token_count, conversation_id)
VALUES (?, ?, ?, ?, ?, ?, ?)
""",
(
role,
content,
timestamp,
message_type.value if message_type else None,
json.dumps(metadata) if metadata else None,
token_count,
self.current_conversation_id,
),
)
conn.commit()
return cursor.lastrowid
def batch_add(self, messages: List[Message]) -> List[int]:
"""
Add multiple messages to the current conversation.
Args:
messages (List[Message]): List of messages to add
Returns:
List[int]: List of inserted message IDs
"""
with self._get_connection() as conn:
cursor = conn.cursor()
message_ids = []
for message in messages:
content = message.content
if isinstance(content, (dict, list)):
content = json.dumps(content)
cursor.execute(
f"""
INSERT INTO {self.table_name}
(role, content, timestamp, message_type, metadata, token_count, conversation_id)
VALUES (?, ?, ?, ?, ?, ?, ?)
""",
(
message.role,
content,
(
message.timestamp.isoformat()
if message.timestamp
else None
),
(
message.message_type.value
if message.message_type
else None
),
(
json.dumps(message.metadata)
if message.metadata
else None
),
message.token_count,
self.current_conversation_id,
),
)
message_ids.append(cursor.lastrowid)
conn.commit()
return message_ids
def get_str(self) -> str:
"""
Get the current conversation history as a formatted string.
Returns:
str: Formatted conversation history
"""
with self._get_connection() as conn:
cursor = conn.cursor()
cursor.execute(
f"""
SELECT * FROM {self.table_name}
WHERE conversation_id = ?
ORDER BY id ASC
""",
(self.current_conversation_id,),
)
messages = []
for row in cursor.fetchall():
content = row["content"]
try:
content = json.loads(content)
except json.JSONDecodeError:
pass
timestamp = (
f"[{row['timestamp']}] "
if row["timestamp"]
else ""
)
messages.append(
f"{timestamp}{row['role']}: {content}"
)
return "\n".join(messages)
def get_messages(
self,
limit: Optional[int] = None,
offset: Optional[int] = None,
) -> List[Dict]:
"""
Get messages from the current conversation with optional pagination.
Args:
limit (Optional[int]): Maximum number of messages to return
offset (Optional[int]): Number of messages to skip
Returns:
List[Dict]: List of message dictionaries
"""
with self._get_connection() as conn:
cursor = conn.cursor()
query = f"""
SELECT * FROM {self.table_name}
WHERE conversation_id = ?
ORDER BY id ASC
"""
params = [self.current_conversation_id]
if limit is not None:
query += " LIMIT ?"
params.append(limit)
if offset is not None:
query += " OFFSET ?"
params.append(offset)
cursor.execute(query, params)
return [dict(row) for row in cursor.fetchall()]
def delete_current_conversation(self) -> bool:
"""
Delete the current conversation.
Returns:
bool: True if deletion was successful
"""
with self._get_connection() as conn:
cursor = conn.cursor()
cursor.execute(
f"DELETE FROM {self.table_name} WHERE conversation_id = ?",
(self.current_conversation_id,),
)
conn.commit()
return cursor.rowcount > 0
def update_message(
self,
message_id: int,
content: Union[str, dict, list],
metadata: Optional[Dict] = None,
) -> bool:
"""
Update an existing message in the current conversation.
Args:
message_id (int): ID of the message to update
content (Union[str, dict, list]): New content for the message
metadata (Optional[Dict]): New metadata for the message
Returns:
bool: True if update was successful
"""
if isinstance(content, (dict, list)):
content = json.dumps(content)
with self._get_connection() as conn:
cursor = conn.cursor()
cursor.execute(
f"""
UPDATE {self.table_name}
SET content = ?, metadata = ?
WHERE id = ? AND conversation_id = ?
""",
(
content,
json.dumps(metadata) if metadata else None,
message_id,
self.current_conversation_id,
),
)
conn.commit()
return cursor.rowcount > 0
def search_messages(self, query: str) -> List[Dict]:
"""
Search for messages containing specific text in the current conversation.
Args:
query (str): Text to search for
Returns:
List[Dict]: List of matching messages
"""
with self._get_connection() as conn:
cursor = conn.cursor()
cursor.execute(
f"""
SELECT * FROM {self.table_name}
WHERE conversation_id = ? AND content LIKE ?
""",
(self.current_conversation_id, f"%{query}%"),
)
return [dict(row) for row in cursor.fetchall()]
def get_statistics(self) -> Dict:
"""
Get statistics about the current conversation.
Returns:
Dict: Statistics about the conversation
"""
with self._get_connection() as conn:
cursor = conn.cursor()
cursor.execute(
f"""
SELECT
COUNT(*) as total_messages,
COUNT(DISTINCT role) as unique_roles,
SUM(token_count) as total_tokens,
MIN(timestamp) as first_message,
MAX(timestamp) as last_message
FROM {self.table_name}
WHERE conversation_id = ?
""",
(self.current_conversation_id,),
)
return dict(cursor.fetchone())
def clear_all(self) -> bool:
"""
Clear all messages from the database.
Returns:
bool: True if clearing was successful
"""
with self._get_connection() as conn:
cursor = conn.cursor()
cursor.execute(f"DELETE FROM {self.table_name}")
conn.commit()
return True
def get_conversation_id(self) -> str:
"""
Get the current conversation ID.
Returns:
str: The current conversation ID
"""
return self.current_conversation_id
def to_dict(self) -> List[Dict]:
"""
Convert the current conversation to a list of dictionaries.
Returns:
List[Dict]: List of message dictionaries
"""
with self._get_connection() as conn:
cursor = conn.cursor()
cursor.execute(
f"""
SELECT role, content, timestamp, message_type, metadata, token_count
FROM {self.table_name}
WHERE conversation_id = ?
ORDER BY id ASC
""",
(self.current_conversation_id,),
)
messages = []
for row in cursor.fetchall():
content = row["content"]
try:
content = json.loads(content)
except json.JSONDecodeError:
pass
message = {"role": row["role"], "content": content}
if row["timestamp"]:
message["timestamp"] = row["timestamp"]
if row["message_type"]:
message["message_type"] = row["message_type"]
if row["metadata"]:
message["metadata"] = json.loads(row["metadata"])
if row["token_count"]:
message["token_count"] = row["token_count"]
messages.append(message)
return messages
def to_json(self) -> str:
"""
Convert the current conversation to a JSON string.
Returns:
str: JSON string representation of the conversation
"""
return json.dumps(self.to_dict(), indent=2)
def to_yaml(self) -> str:
"""
Convert the current conversation to a YAML string.
Returns:
str: YAML string representation of the conversation
"""
return yaml.dump(self.to_dict())
def save_as_json(self, filename: str) -> bool:
"""
Save the current conversation to a JSON file.
Args:
filename (str): Path to save the JSON file
Returns:
bool: True if save was successful
"""
try:
with open(filename, "w") as f:
json.dump(self.to_dict(), f, indent=2)
return True
except Exception as e:
if self.enable_logging:
self.logger.error(
f"Failed to save conversation to JSON: {e}"
)
return False
def save_as_yaml(self, filename: str) -> bool:
"""
Save the current conversation to a YAML file.
Args:
filename (str): Path to save the YAML file
Returns:
bool: True if save was successful
"""
try:
with open(filename, "w") as f:
yaml.dump(self.to_dict(), f)
return True
except Exception as e:
if self.enable_logging:
self.logger.error(
f"Failed to save conversation to YAML: {e}"
)
return False
def load_from_json(self, filename: str) -> bool:
"""
Load a conversation from a JSON file.
Args:
filename (str): Path to the JSON file
Returns:
bool: True if load was successful
"""
try:
with open(filename, "r") as f:
messages = json.load(f)
# Start a new conversation
self.start_new_conversation()
# Add all messages
for message in messages:
self.add(
role=message["role"],
content=message["content"],
message_type=(
MessageType(message["message_type"])
if "message_type" in message
else None
),
metadata=message.get("metadata"),
token_count=message.get("token_count"),
)
return True
except Exception as e:
if self.enable_logging:
self.logger.error(
f"Failed to load conversation from JSON: {e}"
)
return False
def load_from_yaml(self, filename: str) -> bool:
"""
Load a conversation from a YAML file.
Args:
filename (str): Path to the YAML file
Returns:
bool: True if load was successful
"""
try:
with open(filename, "r") as f:
messages = yaml.safe_load(f)
# Start a new conversation
self.start_new_conversation()
# Add all messages
for message in messages:
self.add(
role=message["role"],
content=message["content"],
message_type=(
MessageType(message["message_type"])
if "message_type" in message
else None
),
metadata=message.get("metadata"),
token_count=message.get("token_count"),
)
return True
except Exception as e:
if self.enable_logging:
self.logger.error(
f"Failed to load conversation from YAML: {e}"
)
return False
def get_last_message(self) -> Optional[Dict]:
"""
Get the last message from the current conversation.
Returns:
Optional[Dict]: The last message or None if conversation is empty
"""
with self._get_connection() as conn:
cursor = conn.cursor()
cursor.execute(
f"""
SELECT role, content, timestamp, message_type, metadata, token_count
FROM {self.table_name}
WHERE conversation_id = ?
ORDER BY id DESC
LIMIT 1
""",
(self.current_conversation_id,),
)
row = cursor.fetchone()
if not row:
return None
content = row["content"]
try:
content = json.loads(content)
except json.JSONDecodeError:
pass
message = {"role": row["role"], "content": content}
if row["timestamp"]:
message["timestamp"] = row["timestamp"]
if row["message_type"]:
message["message_type"] = row["message_type"]
if row["metadata"]:
message["metadata"] = json.loads(row["metadata"])
if row["token_count"]:
message["token_count"] = row["token_count"]
return message
def get_last_message_as_string(self) -> str:
"""
Get the last message as a formatted string.
Returns:
str: Formatted string of the last message
"""
last_message = self.get_last_message()
if not last_message:
return ""
timestamp = (
f"[{last_message['timestamp']}] "
if "timestamp" in last_message
else ""
)
return f"{timestamp}{last_message['role']}: {last_message['content']}"
def count_messages_by_role(self) -> Dict[str, int]:
"""
Count messages by role in the current conversation.
Returns:
Dict[str, int]: Dictionary with role counts
"""
with self._get_connection() as conn:
cursor = conn.cursor()
cursor.execute(
f"""
SELECT role, COUNT(*) as count
FROM {self.table_name}
WHERE conversation_id = ?
GROUP BY role
""",
(self.current_conversation_id,),
)
return {
row["role"]: row["count"] for row in cursor.fetchall()
}
def get_messages_by_role(self, role: str) -> List[Dict]:
"""
Get all messages from a specific role in the current conversation.
Args:
role (str): Role to filter messages by
Returns:
List[Dict]: List of messages from the specified role
"""
with self._get_connection() as conn:
cursor = conn.cursor()
cursor.execute(
f"""
SELECT role, content, timestamp, message_type, metadata, token_count
FROM {self.table_name}
WHERE conversation_id = ? AND role = ?
ORDER BY id ASC
""",
(self.current_conversation_id, role),
)
messages = []
for row in cursor.fetchall():
content = row["content"]
try:
content = json.loads(content)
except json.JSONDecodeError:
pass
message = {"role": row["role"], "content": content}
if row["timestamp"]:
message["timestamp"] = row["timestamp"]
if row["message_type"]:
message["message_type"] = row["message_type"]
if row["metadata"]:
message["metadata"] = json.loads(row["metadata"])
if row["token_count"]:
message["token_count"] = row["token_count"]
messages.append(message)
return messages
def get_conversation_summary(self) -> Dict:
"""
Get a summary of the current conversation.
Returns:
Dict: Summary of the conversation including message counts, roles, and time range
"""
with self._get_connection() as conn:
cursor = conn.cursor()
cursor.execute(
f"""
SELECT
COUNT(*) as total_messages,
COUNT(DISTINCT role) as unique_roles,
MIN(timestamp) as first_message_time,
MAX(timestamp) as last_message_time,
SUM(token_count) as total_tokens
FROM {self.table_name}
WHERE conversation_id = ?
""",
(self.current_conversation_id,),
)
row = cursor.fetchone()
return {
"conversation_id": self.current_conversation_id,
"total_messages": row["total_messages"],
"unique_roles": row["unique_roles"],
"first_message_time": row["first_message_time"],
"last_message_time": row["last_message_time"],
"total_tokens": row["total_tokens"],
"roles": self.count_messages_by_role(),
}

@ -1,5 +1,5 @@
import os import os
from typing import List from typing import List, Literal
from swarms.structs.agent import Agent from swarms.structs.agent import Agent
from swarms.structs.conversation import Conversation from swarms.structs.conversation import Conversation
from swarms.structs.multi_agent_exec import get_swarms_info from swarms.structs.multi_agent_exec import get_swarms_info
@ -10,6 +10,23 @@ from swarms.utils.history_output_formatter import (
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Union, Callable from typing import Union, Callable
HistoryOutputType = Literal[
"list",
"dict",
"dictionary",
"string",
"str",
"final",
"last",
"json",
"all",
"yaml",
# "dict-final",
"dict-all-except-first",
"str-all-except-first",
]
tools = [ tools = [
{ {
"type": "function", "type": "function",
@ -105,7 +122,7 @@ class HybridHierarchicalClusterSwarm:
description: str = "A swarm that uses a hybrid hierarchical-peer model to solve complex tasks.", description: str = "A swarm that uses a hybrid hierarchical-peer model to solve complex tasks.",
swarms: List[Union[SwarmRouter, Callable]] = [], swarms: List[Union[SwarmRouter, Callable]] = [],
max_loops: int = 1, max_loops: int = 1,
output_type: str = "list", output_type: HistoryOutputType = "list",
router_agent_model_name: str = "gpt-4o-mini", router_agent_model_name: str = "gpt-4o-mini",
*args, *args,
**kwargs, **kwargs,

@ -6,7 +6,7 @@ from swarms.utils.any_to_str import any_to_str
from swarms.utils.loguru_logger import initialize_logger from swarms.utils.loguru_logger import initialize_logger
from swarms.structs.conversation import Conversation from swarms.structs.conversation import Conversation
from swarms.utils.history_output_formatter import ( from swarms.utils.history_output_formatter import (
output_type, HistoryOutputType,
) )
logger = initialize_logger(log_folder="swarm_arange") logger = initialize_logger(log_folder="swarm_arange")
@ -58,7 +58,7 @@ class SwarmRearrange:
Callable[[str], str] Callable[[str], str]
] = None, ] = None,
return_json: bool = False, return_json: bool = False,
output_type: output_type = "dict-all-except-first", output_type: HistoryOutputType = "dict-all-except-first",
*args, *args,
**kwargs, **kwargs,
): ):

@ -0,0 +1,317 @@
import os
import json
from datetime import datetime
from pathlib import Path
import tempfile
import threading
from swarms.communication.duckdb_wrap import (
DuckDBConversation,
Message,
MessageType,
)
def setup_test():
"""Set up test environment."""
temp_dir = tempfile.TemporaryDirectory()
db_path = Path(temp_dir.name) / "test_conversations.duckdb"
conversation = DuckDBConversation(
db_path=str(db_path),
enable_timestamps=True,
enable_logging=True,
)
return temp_dir, db_path, conversation
def cleanup_test(temp_dir, db_path):
"""Clean up test environment."""
if os.path.exists(db_path):
os.remove(db_path)
temp_dir.cleanup()
def test_initialization():
"""Test conversation initialization."""
temp_dir, db_path, _ = setup_test()
try:
conv = DuckDBConversation(db_path=str(db_path))
assert conv.db_path == db_path, "Database path mismatch"
assert conv.table_name == "conversations", "Table name mismatch"
assert conv.enable_timestamps is True, "Timestamps should be enabled"
assert conv.current_conversation_id is not None, "Conversation ID should not be None"
print("✓ Initialization test passed")
finally:
cleanup_test(temp_dir, db_path)
def test_add_message():
"""Test adding a single message."""
temp_dir, db_path, conversation = setup_test()
try:
msg_id = conversation.add(
role="user",
content="Hello, world!",
message_type=MessageType.USER,
)
assert msg_id is not None, "Message ID should not be None"
assert isinstance(msg_id, int), "Message ID should be an integer"
print("✓ Add message test passed")
finally:
cleanup_test(temp_dir, db_path)
def test_add_complex_message():
"""Test adding a message with complex content."""
temp_dir, db_path, conversation = setup_test()
try:
complex_content = {
"text": "Hello",
"data": [1, 2, 3],
"nested": {"key": "value"}
}
msg_id = conversation.add(
role="assistant",
content=complex_content,
message_type=MessageType.ASSISTANT,
metadata={"source": "test"},
token_count=10
)
assert msg_id is not None, "Message ID should not be None"
print("✓ Add complex message test passed")
finally:
cleanup_test(temp_dir, db_path)
def test_batch_add():
"""Test batch adding messages."""
temp_dir, db_path, conversation = setup_test()
try:
messages = [
Message(
role="user",
content="First message",
message_type=MessageType.USER
),
Message(
role="assistant",
content="Second message",
message_type=MessageType.ASSISTANT
)
]
msg_ids = conversation.batch_add(messages)
assert len(msg_ids) == 2, "Should have 2 message IDs"
assert all(isinstance(id, int) for id in msg_ids), "All IDs should be integers"
print("✓ Batch add test passed")
finally:
cleanup_test(temp_dir, db_path)
def test_get_str():
"""Test getting conversation as string."""
temp_dir, db_path, conversation = setup_test()
try:
conversation.add("user", "Hello")
conversation.add("assistant", "Hi there!")
conv_str = conversation.get_str()
assert "user: Hello" in conv_str, "User message not found"
assert "assistant: Hi there!" in conv_str, "Assistant message not found"
print("✓ Get string test passed")
finally:
cleanup_test(temp_dir, db_path)
def test_get_messages():
"""Test getting messages with pagination."""
temp_dir, db_path, conversation = setup_test()
try:
for i in range(5):
conversation.add("user", f"Message {i}")
all_messages = conversation.get_messages()
assert len(all_messages) == 5, "Should have 5 messages"
limited_messages = conversation.get_messages(limit=2)
assert len(limited_messages) == 2, "Should have 2 limited messages"
offset_messages = conversation.get_messages(offset=2)
assert len(offset_messages) == 3, "Should have 3 offset messages"
print("✓ Get messages test passed")
finally:
cleanup_test(temp_dir, db_path)
def test_search_messages():
"""Test searching messages."""
temp_dir, db_path, conversation = setup_test()
try:
conversation.add("user", "Hello world")
conversation.add("assistant", "Hello there")
conversation.add("user", "Goodbye world")
results = conversation.search_messages("world")
assert len(results) == 2, "Should find 2 messages with 'world'"
assert all("world" in msg["content"] for msg in results), "All results should contain 'world'"
print("✓ Search messages test passed")
finally:
cleanup_test(temp_dir, db_path)
def test_get_statistics():
"""Test getting conversation statistics."""
temp_dir, db_path, conversation = setup_test()
try:
conversation.add("user", "Hello", token_count=2)
conversation.add("assistant", "Hi", token_count=1)
stats = conversation.get_statistics()
assert stats["total_messages"] == 2, "Should have 2 total messages"
assert stats["unique_roles"] == 2, "Should have 2 unique roles"
assert stats["total_tokens"] == 3, "Should have 3 total tokens"
print("✓ Get statistics test passed")
finally:
cleanup_test(temp_dir, db_path)
def test_json_operations():
"""Test JSON save and load operations."""
temp_dir, db_path, conversation = setup_test()
try:
conversation.add("user", "Hello")
conversation.add("assistant", "Hi")
json_path = Path(temp_dir.name) / "test_conversation.json"
conversation.save_as_json(str(json_path))
assert json_path.exists(), "JSON file should exist"
new_conversation = DuckDBConversation(
db_path=str(Path(temp_dir.name) / "new.duckdb")
)
assert new_conversation.load_from_json(str(json_path)), "Should load from JSON"
assert len(new_conversation.get_messages()) == 2, "Should have 2 messages after load"
print("✓ JSON operations test passed")
finally:
cleanup_test(temp_dir, db_path)
def test_yaml_operations():
"""Test YAML save and load operations."""
temp_dir, db_path, conversation = setup_test()
try:
conversation.add("user", "Hello")
conversation.add("assistant", "Hi")
yaml_path = Path(temp_dir.name) / "test_conversation.yaml"
conversation.save_as_yaml(str(yaml_path))
assert yaml_path.exists(), "YAML file should exist"
new_conversation = DuckDBConversation(
db_path=str(Path(temp_dir.name) / "new.duckdb")
)
assert new_conversation.load_from_yaml(str(yaml_path)), "Should load from YAML"
assert len(new_conversation.get_messages()) == 2, "Should have 2 messages after load"
print("✓ YAML operations test passed")
finally:
cleanup_test(temp_dir, db_path)
def test_message_types():
"""Test different message types."""
temp_dir, db_path, conversation = setup_test()
try:
conversation.add("system", "System message", message_type=MessageType.SYSTEM)
conversation.add("user", "User message", message_type=MessageType.USER)
conversation.add("assistant", "Assistant message", message_type=MessageType.ASSISTANT)
conversation.add("function", "Function message", message_type=MessageType.FUNCTION)
conversation.add("tool", "Tool message", message_type=MessageType.TOOL)
messages = conversation.get_messages()
assert len(messages) == 5, "Should have 5 messages"
assert all("message_type" in msg for msg in messages), "All messages should have type"
print("✓ Message types test passed")
finally:
cleanup_test(temp_dir, db_path)
def test_delete_operations():
"""Test deletion operations."""
temp_dir, db_path, conversation = setup_test()
try:
conversation.add("user", "Hello")
conversation.add("assistant", "Hi")
assert conversation.delete_current_conversation(), "Should delete conversation"
assert len(conversation.get_messages()) == 0, "Should have no messages after delete"
conversation.add("user", "New message")
assert conversation.clear_all(), "Should clear all messages"
assert len(conversation.get_messages()) == 0, "Should have no messages after clear"
print("✓ Delete operations test passed")
finally:
cleanup_test(temp_dir, db_path)
def test_concurrent_operations():
"""Test concurrent operations."""
temp_dir, db_path, conversation = setup_test()
try:
def add_messages():
for i in range(10):
conversation.add("user", f"Message {i}")
threads = [threading.Thread(target=add_messages) for _ in range(5)]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
messages = conversation.get_messages()
assert len(messages) == 50, "Should have 50 messages (10 * 5 threads)"
print("✓ Concurrent operations test passed")
finally:
cleanup_test(temp_dir, db_path)
def test_error_handling():
"""Test error handling."""
temp_dir, db_path, conversation = setup_test()
try:
# Test invalid message type
try:
conversation.add("user", "Message", message_type="invalid")
assert False, "Should raise exception for invalid message type"
except Exception:
pass
# Test invalid JSON content
try:
conversation.add("user", {"invalid": object()})
assert False, "Should raise exception for invalid JSON content"
except Exception:
pass
# Test invalid file operations
try:
conversation.load_from_json("/nonexistent/path.json")
assert False, "Should raise exception for invalid file path"
except Exception:
pass
print("✓ Error handling test passed")
finally:
cleanup_test(temp_dir, db_path)
def run_all_tests():
"""Run all tests."""
print("Running DuckDB Conversation tests...")
tests = [
test_initialization,
test_add_message,
test_add_complex_message,
test_batch_add,
test_get_str,
test_get_messages,
test_search_messages,
test_get_statistics,
test_json_operations,
test_yaml_operations,
test_message_types,
test_delete_operations,
test_concurrent_operations,
test_error_handling
]
for test in tests:
try:
test()
except Exception as e:
print(f"{test.__name__} failed: {str(e)}")
raise
print("\nAll tests completed successfully!")
if __name__ == '__main__':
run_all_tests()

@ -0,0 +1,386 @@
import json
import datetime
import os
from typing import Dict, List, Any, Tuple
from loguru import logger
from swarms.communication.sqlite_wrap import (
SQLiteConversation,
Message,
MessageType,
)
from rich.console import Console
from rich.table import Table
from rich.panel import Panel
console = Console()
def print_test_header(test_name: str) -> None:
"""Print a formatted test header."""
console.print(
Panel(
f"[bold blue]Running Test: {test_name}[/bold blue]",
expand=False,
)
)
def print_test_result(
test_name: str, success: bool, message: str, execution_time: float
) -> None:
"""Print a formatted test result."""
status = (
"[bold green]PASSED[/bold green]"
if success
else "[bold red]FAILED[/bold red]"
)
console.print(f"\n{status} - {test_name}")
console.print(f"Message: {message}")
console.print(f"Execution time: {execution_time:.3f} seconds\n")
def print_messages(
messages: List[Dict], title: str = "Messages"
) -> None:
"""Print messages in a formatted table."""
table = Table(title=title)
table.add_column("Role", style="cyan")
table.add_column("Content", style="green")
table.add_column("Type", style="yellow")
table.add_column("Timestamp", style="magenta")
for msg in messages:
content = str(msg.get("content", ""))
if isinstance(content, (dict, list)):
content = json.dumps(content)
table.add_row(
msg.get("role", ""),
content,
str(msg.get("message_type", "")),
str(msg.get("timestamp", "")),
)
console.print(table)
def run_test(
test_func: callable, *args, **kwargs
) -> Tuple[bool, str, float]:
"""
Run a test function and return its results.
Args:
test_func: The test function to run
*args: Arguments for the test function
**kwargs: Keyword arguments for the test function
Returns:
Tuple[bool, str, float]: (success, message, execution_time)
"""
start_time = datetime.datetime.now()
try:
result = test_func(*args, **kwargs)
end_time = datetime.datetime.now()
execution_time = (end_time - start_time).total_seconds()
return True, str(result), execution_time
except Exception as e:
end_time = datetime.datetime.now()
execution_time = (end_time - start_time).total_seconds()
return False, str(e), execution_time
def test_basic_conversation() -> bool:
"""Test basic conversation operations."""
print_test_header("Basic Conversation Test")
db_path = "test_conversations.db"
conversation = SQLiteConversation(db_path=db_path)
# Test adding messages
console.print("\n[bold]Adding messages...[/bold]")
msg_id1 = conversation.add("user", "Hello")
msg_id2 = conversation.add("assistant", "Hi there!")
# Test getting messages
console.print("\n[bold]Retrieved messages:[/bold]")
messages = conversation.get_messages()
print_messages(messages)
assert len(messages) == 2
assert messages[0]["role"] == "user"
assert messages[1]["role"] == "assistant"
# Cleanup
os.remove(db_path)
return True
def test_message_types() -> bool:
"""Test different message types and content formats."""
print_test_header("Message Types Test")
db_path = "test_conversations.db"
conversation = SQLiteConversation(db_path=db_path)
# Test different content types
console.print("\n[bold]Adding different message types...[/bold]")
conversation.add("user", "Simple text")
conversation.add(
"assistant", {"type": "json", "content": "Complex data"}
)
conversation.add("system", ["list", "of", "items"])
conversation.add(
"function",
"Function result",
message_type=MessageType.FUNCTION,
)
console.print("\n[bold]Retrieved messages:[/bold]")
messages = conversation.get_messages()
print_messages(messages)
assert len(messages) == 4
# Cleanup
os.remove(db_path)
return True
def test_conversation_operations() -> bool:
"""Test various conversation operations."""
print_test_header("Conversation Operations Test")
db_path = "test_conversations.db"
conversation = SQLiteConversation(db_path=db_path)
# Test batch operations
console.print("\n[bold]Adding batch messages...[/bold]")
messages = [
Message(role="user", content="Message 1"),
Message(role="assistant", content="Message 2"),
Message(role="user", content="Message 3"),
]
conversation.batch_add(messages)
console.print("\n[bold]Retrieved messages:[/bold]")
all_messages = conversation.get_messages()
print_messages(all_messages)
# Test statistics
console.print("\n[bold]Conversation Statistics:[/bold]")
stats = conversation.get_statistics()
console.print(json.dumps(stats, indent=2))
# Test role counting
console.print("\n[bold]Role Counts:[/bold]")
role_counts = conversation.count_messages_by_role()
console.print(json.dumps(role_counts, indent=2))
assert stats["total_messages"] == 3
assert role_counts["user"] == 2
assert role_counts["assistant"] == 1
# Cleanup
os.remove(db_path)
return True
def test_file_operations() -> bool:
"""Test file operations (JSON/YAML)."""
print_test_header("File Operations Test")
db_path = "test_conversations.db"
json_path = "test_conversation.json"
yaml_path = "test_conversation.yaml"
conversation = SQLiteConversation(db_path=db_path)
conversation.add("user", "Test message")
# Test JSON operations
console.print("\n[bold]Testing JSON operations...[/bold]")
assert conversation.save_as_json(json_path)
console.print(f"Saved to JSON: {json_path}")
conversation.start_new_conversation()
assert conversation.load_from_json(json_path)
console.print("Loaded from JSON")
# Test YAML operations
console.print("\n[bold]Testing YAML operations...[/bold]")
assert conversation.save_as_yaml(yaml_path)
console.print(f"Saved to YAML: {yaml_path}")
conversation.start_new_conversation()
assert conversation.load_from_yaml(yaml_path)
console.print("Loaded from YAML")
# Cleanup
os.remove(db_path)
os.remove(json_path)
os.remove(yaml_path)
return True
def test_search_and_filter() -> bool:
"""Test search and filter operations."""
print_test_header("Search and Filter Test")
db_path = "test_conversations.db"
conversation = SQLiteConversation(db_path=db_path)
# Add test messages
console.print("\n[bold]Adding test messages...[/bold]")
conversation.add("user", "Hello world")
conversation.add("assistant", "Hello there")
conversation.add("user", "Goodbye world")
# Test search
console.print("\n[bold]Searching for 'world'...[/bold]")
results = conversation.search_messages("world")
print_messages(results, "Search Results")
# Test role filtering
console.print("\n[bold]Filtering user messages...[/bold]")
user_messages = conversation.get_messages_by_role("user")
print_messages(user_messages, "User Messages")
assert len(results) == 2
assert len(user_messages) == 2
# Cleanup
os.remove(db_path)
return True
def test_conversation_management() -> bool:
"""Test conversation management features."""
print_test_header("Conversation Management Test")
db_path = "test_conversations.db"
conversation = SQLiteConversation(db_path=db_path)
# Test conversation ID generation
console.print("\n[bold]Testing conversation IDs...[/bold]")
conv_id1 = conversation.get_conversation_id()
console.print(f"First conversation ID: {conv_id1}")
conversation.start_new_conversation()
conv_id2 = conversation.get_conversation_id()
console.print(f"Second conversation ID: {conv_id2}")
assert conv_id1 != conv_id2
# Test conversation deletion
console.print("\n[bold]Testing conversation deletion...[/bold]")
conversation.add("user", "Test message")
assert conversation.delete_current_conversation()
console.print("Conversation deleted successfully")
# Cleanup
os.remove(db_path)
return True
def generate_test_report(
test_results: List[Dict[str, Any]]
) -> Dict[str, Any]:
"""
Generate a test report in JSON format.
Args:
test_results: List of test results
Returns:
Dict containing the test report
"""
total_tests = len(test_results)
passed_tests = sum(
1 for result in test_results if result["success"]
)
failed_tests = total_tests - passed_tests
total_time = sum(
result["execution_time"] for result in test_results
)
report = {
"timestamp": datetime.datetime.now().isoformat(),
"summary": {
"total_tests": total_tests,
"passed_tests": passed_tests,
"failed_tests": failed_tests,
"total_execution_time": total_time,
"average_execution_time": (
total_time / total_tests if total_tests > 0 else 0
),
},
"test_results": test_results,
}
return report
def run_all_tests() -> None:
"""Run all tests and generate a report."""
console.print(
Panel(
"[bold blue]Starting Test Suite[/bold blue]", expand=False
)
)
tests = [
("Basic Conversation", test_basic_conversation),
("Message Types", test_message_types),
("Conversation Operations", test_conversation_operations),
("File Operations", test_file_operations),
("Search and Filter", test_search_and_filter),
("Conversation Management", test_conversation_management),
]
test_results = []
for test_name, test_func in tests:
logger.info(f"Running test: {test_name}")
success, message, execution_time = run_test(test_func)
print_test_result(test_name, success, message, execution_time)
result = {
"test_name": test_name,
"success": success,
"message": message,
"execution_time": execution_time,
"timestamp": datetime.datetime.now().isoformat(),
}
if success:
logger.success(f"Test passed: {test_name}")
else:
logger.error(f"Test failed: {test_name} - {message}")
test_results.append(result)
# Generate and save report
report = generate_test_report(test_results)
report_path = "test_report.json"
with open(report_path, "w") as f:
json.dump(report, f, indent=2)
# Print final summary
console.print("\n[bold blue]Test Suite Summary[/bold blue]")
console.print(
Panel(
f"Total tests: {report['summary']['total_tests']}\n"
f"Passed tests: {report['summary']['passed_tests']}\n"
f"Failed tests: {report['summary']['failed_tests']}\n"
f"Total execution time: {report['summary']['total_execution_time']:.2f} seconds",
title="Summary",
expand=False,
)
)
logger.info(f"Test report saved to {report_path}")
if __name__ == "__main__":
run_all_tests()
Loading…
Cancel
Save