import os
from pathlib import Path
import tempfile
import threading
from swarms.communication.duckdb_wrap import (
    DuckDBConversation,
    Message,
    MessageType,
)


def setup_test():
    """Create an isolated DuckDB-backed conversation for one test.

    Returns:
        tuple: ``(temp_dir, db_path, conversation)``.
            ``temp_dir`` is the ``tempfile.TemporaryDirectory`` that owns
            the files. ``db_path`` is the ``Path`` of the database file
            inside it. ``conversation`` is a ``DuckDBConversation`` opened
            on that file with timestamps and logging enabled.
    """
    scratch = tempfile.TemporaryDirectory()
    database_file = Path(scratch.name).joinpath(
        "test_conversations.duckdb"
    )
    conv = DuckDBConversation(
        db_path=str(database_file),
        enable_timestamps=True,
        enable_logging=True,
    )
    return scratch, database_file, conv


def cleanup_test(temp_dir, db_path):
    """Delete the test database file and its temporary directory.

    The temporary directory is always cleaned up, even when removing the
    database file fails (e.g. because the file is still locked).

    Args:
        temp_dir: ``tempfile.TemporaryDirectory`` created by ``setup_test``.
        db_path: Path (``str`` or ``Path``) of the database file to remove.
            A missing file is not an error.
    """
    try:
        # EAFP: no exists()/remove() race and no redundant stat call.
        os.remove(db_path)
    except FileNotFoundError:
        pass
    finally:
        # Must run even if os.remove raised, or the directory would leak.
        temp_dir.cleanup()


def test_initialization():
    """Check the attributes of a freshly constructed DuckDBConversation.

    A second conversation is opened on the database file from
    ``setup_test``, passing only ``db_path``. The ``table_name`` and
    ``enable_timestamps`` assertions therefore check constructor defaults.
    """
    temp_dir, db_path, _unused = setup_test()
    try:
        fresh = DuckDBConversation(db_path=str(db_path))
        assert fresh.db_path == db_path, "Database path mismatch"
        assert (
            fresh.table_name == "conversations"
        ), "Table name mismatch"
        assert (
            fresh.enable_timestamps is True
        ), "Timestamps should be enabled"
        assert (
            fresh.current_conversation_id is not None
        ), "Conversation ID should not be None"
        print("✓ Initialization test passed")
    finally:
        cleanup_test(temp_dir, db_path)


def test_add_message():
    """Adding one plain-text user message should return an integer ID."""
    temp_dir, db_path, conversation = setup_test()
    try:
        payload = {
            "role": "user",
            "content": "Hello, world!",
            "message_type": MessageType.USER,
        }
        new_id = conversation.add(**payload)
        assert new_id is not None, "Message ID should not be None"
        assert isinstance(
            new_id, int
        ), "Message ID should be an integer"
        print("✓ Add message test passed")
    finally:
        cleanup_test(temp_dir, db_path)


def test_add_complex_message():
    """Store a message whose content is a nested dict, with metadata.

    Only checks that ``add`` returns a non-None ID for structured content
    plus ``metadata`` and ``token_count``.
    """
    temp_dir, db_path, conversation = setup_test()
    try:
        payload = {
            "role": "assistant",
            "content": {
                "text": "Hello",
                "data": [1, 2, 3],
                "nested": {"key": "value"},
            },
            "message_type": MessageType.ASSISTANT,
            "metadata": {"source": "test"},
            "token_count": 10,
        }
        new_id = conversation.add(**payload)
        assert new_id is not None, "Message ID should not be None"
        print("✓ Add complex message test passed")
    finally:
        cleanup_test(temp_dir, db_path)


def test_batch_add():
    """Insert two ``Message`` objects in one ``batch_add`` call.

    ``batch_add`` should return one integer ID per message, in a
    sequence whose length matches the number of messages supplied.
    """
    temp_dir, db_path, conversation = setup_test()
    try:
        messages = [
            Message(
                role="user",
                content="First message",
                message_type=MessageType.USER,
            ),
            Message(
                role="assistant",
                content="Second message",
                message_type=MessageType.ASSISTANT,
            ),
        ]
        msg_ids = conversation.batch_add(messages)
        assert len(msg_ids) == 2, "Should have 2 message IDs"
        # Named msg_id (not `id`) to avoid shadowing the builtin.
        assert all(
            isinstance(msg_id, int) for msg_id in msg_ids
        ), "All IDs should be integers"
        print("✓ Batch add test passed")
    finally:
        cleanup_test(temp_dir, db_path)


def test_get_str():
    """The string rendering should contain ``role: content`` for each turn."""
    temp_dir, db_path, conversation = setup_test()
    try:
        for role, text in (("user", "Hello"), ("assistant", "Hi there!")):
            conversation.add(role, text)
        rendered = conversation.get_str()
        assert "user: Hello" in rendered, "User message not found"
        assert (
            "assistant: Hi there!" in rendered
        ), "Assistant message not found"
        print("✓ Get string test passed")
    finally:
        cleanup_test(temp_dir, db_path)


def test_get_messages():
    """Exercise ``get_messages`` with no arguments, ``limit`` and ``offset``.

    Five user messages are stored. The unrestricted call should return
    five, ``limit=2`` should return two and ``offset=2`` should return
    three.
    """
    temp_dir, db_path, conversation = setup_test()
    try:
        for n in range(5):
            conversation.add("user", f"Message {n}")

        everything = conversation.get_messages()
        assert len(everything) == 5, "Should have 5 messages"

        first_page = conversation.get_messages(limit=2)
        assert (
            len(first_page) == 2
        ), "Should have 2 limited messages"

        after_skip = conversation.get_messages(offset=2)
        assert (
            len(after_skip) == 3
        ), "Should have 3 offset messages"
        print("✓ Get messages test passed")
    finally:
        cleanup_test(temp_dir, db_path)


def test_search_messages():
    """A substring search for 'world' should match exactly two messages."""
    temp_dir, db_path, conversation = setup_test()
    try:
        seed = (
            ("user", "Hello world"),
            ("assistant", "Hello there"),
            ("user", "Goodbye world"),
        )
        for role, text in seed:
            conversation.add(role, text)

        hits = conversation.search_messages("world")
        assert (
            len(hits) == 2
        ), "Should find 2 messages with 'world'"
        assert all(
            "world" in hit["content"] for hit in hits
        ), "All results should contain 'world'"
        print("✓ Search messages test passed")
    finally:
        cleanup_test(temp_dir, db_path)


def test_get_statistics():
    """Aggregate statistics should reflect message, role and token counts."""
    temp_dir, db_path, conversation = setup_test()
    try:
        conversation.add("user", "Hello", token_count=2)
        conversation.add("assistant", "Hi", token_count=1)

        stats = conversation.get_statistics()
        expected = (
            ("total_messages", 2, "Should have 2 total messages"),
            ("unique_roles", 2, "Should have 2 unique roles"),
            ("total_tokens", 3, "Should have 3 total tokens"),
        )
        for key, want, why in expected:
            assert stats[key] == want, why
        print("✓ Get statistics test passed")
    finally:
        cleanup_test(temp_dir, db_path)


def test_json_operations():
    """Round-trip a conversation through a JSON file.

    The conversation is saved with ``save_as_json``, then loaded into a
    second conversation backed by a separate database file, which should
    then hold both messages.
    """
    temp_dir, db_path, conversation = setup_test()
    try:
        for role, text in (("user", "Hello"), ("assistant", "Hi")):
            conversation.add(role, text)

        workdir = Path(temp_dir.name)
        json_path = workdir / "test_conversation.json"
        conversation.save_as_json(str(json_path))
        assert json_path.exists(), "JSON file should exist"

        reloaded = DuckDBConversation(
            db_path=str(workdir / "new.duckdb")
        )
        assert reloaded.load_from_json(
            str(json_path)
        ), "Should load from JSON"
        assert (
            len(reloaded.get_messages()) == 2
        ), "Should have 2 messages after load"
        print("✓ JSON operations test passed")
    finally:
        cleanup_test(temp_dir, db_path)


def test_yaml_operations():
    """Round-trip a conversation through a YAML file.

    The conversation is saved with ``save_as_yaml``, then loaded into a
    second conversation backed by a separate database file, which should
    then hold both messages.
    """
    temp_dir, db_path, conversation = setup_test()
    try:
        for role, text in (("user", "Hello"), ("assistant", "Hi")):
            conversation.add(role, text)

        workdir = Path(temp_dir.name)
        yaml_path = workdir / "test_conversation.yaml"
        conversation.save_as_yaml(str(yaml_path))
        assert yaml_path.exists(), "YAML file should exist"

        reloaded = DuckDBConversation(
            db_path=str(workdir / "new.duckdb")
        )
        assert reloaded.load_from_yaml(
            str(yaml_path)
        ), "Should load from YAML"
        assert (
            len(reloaded.get_messages()) == 2
        ), "Should have 2 messages after load"
        print("✓ YAML operations test passed")
    finally:
        cleanup_test(temp_dir, db_path)


def test_message_types():
    """Store one message per ``MessageType`` and check each is returned.

    Only the presence of a ``message_type`` key on every returned message
    is checked, not its value.
    """
    temp_dir, db_path, conversation = setup_test()
    try:
        samples = (
            ("system", "System message", MessageType.SYSTEM),
            ("user", "User message", MessageType.USER),
            ("assistant", "Assistant message", MessageType.ASSISTANT),
            ("function", "Function message", MessageType.FUNCTION),
            ("tool", "Tool message", MessageType.TOOL),
        )
        for role, text, kind in samples:
            conversation.add(role, text, message_type=kind)

        stored = conversation.get_messages()
        assert len(stored) == 5, "Should have 5 messages"
        assert all(
            "message_type" in entry for entry in stored
        ), "All messages should have type"
        print("✓ Message types test passed")
    finally:
        cleanup_test(temp_dir, db_path)


def test_delete_operations():
    """Deleting and clearing should both leave ``get_messages`` empty.

    Covers ``delete_current_conversation`` and then ``clear_all``. Each
    must return a truthy result and leave no messages behind.
    """
    temp_dir, db_path, conversation = setup_test()

    def message_count():
        return len(conversation.get_messages())

    try:
        conversation.add("user", "Hello")
        conversation.add("assistant", "Hi")

        assert (
            conversation.delete_current_conversation()
        ), "Should delete conversation"
        assert (
            message_count() == 0
        ), "Should have no messages after delete"

        conversation.add("user", "New message")
        assert conversation.clear_all(), "Should clear all messages"
        assert (
            message_count() == 0
        ), "Should have no messages after clear"
        print("✓ Delete operations test passed")
    finally:
        cleanup_test(temp_dir, db_path)


def test_concurrent_operations():
    """Add messages from several threads at once and check none are lost.

    Five worker threads each add ten user messages to the same
    conversation, so exactly 50 messages must be stored afterwards.

    ``threading.Thread`` only prints an exception raised in its target and
    never propagates it to ``join()``. Each worker therefore records any
    exception, and the main thread re-raises it. Without this, the real
    error would be hidden behind a generic message-count mismatch.
    """
    temp_dir, db_path, conversation = setup_test()
    try:
        # list.append is atomic under the GIL; no extra lock is needed.
        errors = []

        def add_messages():
            try:
                for i in range(10):
                    conversation.add("user", f"Message {i}")
            except Exception as exc:
                errors.append(exc)

        threads = [
            threading.Thread(target=add_messages) for _ in range(5)
        ]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

        if errors:
            raise AssertionError(
                f"{len(errors)} worker thread(s) failed while adding "
                f"messages"
            ) from errors[0]

        messages = conversation.get_messages()
        assert (
            len(messages) == 50
        ), "Should have 50 messages (10 * 5 threads)"
        print("✓ Concurrent operations test passed")
    finally:
        cleanup_test(temp_dir, db_path)


def test_error_handling():
    """Check that invalid input is rejected rather than silently accepted.

    Each case uses ``try``/``except``/``else``. The expectation failure is
    raised from the ``else`` branch, which is outside the ``except
    Exception`` handler. ``AssertionError`` is itself an ``Exception``, so
    raising it inside the ``try`` would be swallowed and the test could
    never fail. ``AssertionError`` is raised explicitly rather than via
    ``assert`` so the check also survives ``python -O``.
    """
    temp_dir, db_path, conversation = setup_test()
    try:
        # Test invalid message type: a plain string, not a MessageType.
        try:
            conversation.add(
                "user", "Message", message_type="invalid"
            )
        except Exception:
            pass
        else:
            raise AssertionError(
                "Should raise exception for invalid message type"
            )

        # Test invalid JSON content: object() cannot be JSON-encoded.
        try:
            conversation.add("user", {"invalid": object()})
        except Exception:
            pass
        else:
            raise AssertionError(
                "Should raise exception for invalid JSON content"
            )

        # Test invalid file operations. load_from_json reports success via
        # its return value (test_json_operations asserts it is truthy), so
        # failure may show up as either an exception or a falsy result.
        try:
            loaded = conversation.load_from_json(
                "/nonexistent/path.json"
            )
        except Exception:
            pass
        else:
            if loaded:
                raise AssertionError(
                    "Should fail for invalid file path"
                )

        print("✓ Error handling test passed")
    finally:
        cleanup_test(temp_dir, db_path)


def run_all_tests():
    """Run every test in order, stopping at the first failure.

    A failing test is reported on a ``✗`` line that names it. Its
    exception is then re-raised so the caller sees the original error.
    """
    print("Running DuckDB Conversation tests...")
    suite = (
        test_initialization,
        test_add_message,
        test_add_complex_message,
        test_batch_add,
        test_get_str,
        test_get_messages,
        test_search_messages,
        test_get_statistics,
        test_json_operations,
        test_yaml_operations,
        test_message_types,
        test_delete_operations,
        test_concurrent_operations,
        test_error_handling,
    )

    for case in suite:
        try:
            case()
        except Exception as err:
            print(f"✗ {case.__name__} failed: {str(err)}")
            raise

    print("\nAll tests completed successfully!")


if __name__ == "__main__":
    # Lets this module run as a plain script, without a test runner.
    run_all_tests()