import time

from loguru import logger
from swarms import Agent

from experimental.airflow_swarm import (
    AirflowDAGSwarm,
    NodeType,
    Conversation,
)

# Configure logger: drop any previously registered loguru handlers, then
# echo every record (DEBUG and above) through print() so log lines land on
# stdout alongside the test output. end="" because the formatted record
# presumably already carries its own trailing newline.
logger.remove()
logger.add(
    lambda message: print(message, end=""),
    level="DEBUG",
)


def test_swarm_initialization():
    """Test basic swarm initialization and configuration.

    Verifies that ``dag_id`` and ``name`` are stored as given, that
    ``nodes`` and ``edges`` start empty, and that ``initial_message``
    appears in the conversation history.

    Returns:
        True if every check passes, False otherwise. Failures are printed
        rather than raised so the runner can continue with the next test.
    """
    try:
        swarm = AirflowDAGSwarm(
            dag_id="test_dag",
            name="Test DAG",
            initial_message="Test message",
        )
        assert swarm.dag_id == "test_dag", "DAG ID not set correctly"
        assert swarm.name == "Test DAG", "Name not set correctly"
        assert (
            len(swarm.nodes) == 0
        ), "Nodes should be empty on initialization"
        assert (
            len(swarm.edges) == 0
        ), "Edges should be empty on initialization"

        # Test initial message
        conv_json = swarm.get_conversation_history()
        assert (
            "Test message" in conv_json
        ), "Initial message not set correctly"
        print("✅ Swarm initialization test passed")
        return True
    except AssertionError as e:
        print(f"❌ Swarm initialization test failed: {e}")
        return False
    except Exception as e:
        # Without this branch an error from the constructor or an accessor
        # would escape and abort every remaining test in run_all_tests.
        print(
            f"❌ Swarm initialization test failed with unexpected error: {e}"
        )
        return False


def test_node_addition():
    """Test adding different types of nodes to the swarm.

    Registers one agent node and one plain callable node, and checks that
    ``add_node`` echoes back each node ID and records the node in
    ``swarm.nodes``.

    Returns:
        True when every check passes, False otherwise. The failure reason
        is printed instead of raised.
    """

    # Callable node body. Only ``x`` is supplied via ``args=[42]`` below;
    # ``conversation`` is presumably injected by the swarm at run time.
    def test_callable(x: int, conversation: Conversation) -> str:
        return f"Test output {x}"

    try:
        dag_swarm = AirflowDAGSwarm(dag_id="test_dag")

        # --- agent node -------------------------------------------------
        worker = Agent(
            agent_name="Test-Agent",
            system_prompt="Test prompt",
            model_name="gpt-4o-mini",
            max_loops=1,
        )
        returned_agent_id = dag_swarm.add_node(
            "test_agent",
            worker,
            NodeType.AGENT,
            query="Test query",
            concurrent=True,
        )
        assert (
            returned_agent_id == "test_agent"
        ), "Agent node ID not returned correctly"
        assert (
            "test_agent" in dag_swarm.nodes
        ), "Agent node not added to nodes dict"

        # --- callable node ----------------------------------------------
        returned_callable_id = dag_swarm.add_node(
            "test_callable",
            test_callable,
            NodeType.CALLABLE,
            args=[42],
            concurrent=False,
        )
        assert (
            returned_callable_id == "test_callable"
        ), "Callable node ID not returned correctly"
        assert (
            "test_callable" in dag_swarm.nodes
        ), "Callable node not added to nodes dict"

        print("✅ Node addition test passed")
        return True
    except AssertionError as failure:
        print(f"❌ Node addition test failed: {failure}")
        return False
    except Exception as unexpected:
        print(
            f"❌ Node addition test failed with unexpected error: {unexpected}"
        )
        return False


def test_edge_addition():
    """Test adding edges between nodes.

    Adds two callable nodes, connects ``node1 -> node2``, and checks that
    ``swarm.edges["node1"]`` contains exactly ``node2``. Also checks that
    adding an edge to an unknown node raises ``ValueError``.

    Returns:
        True if every check passes, False otherwise. Failures are printed
        rather than raised so the runner can continue with the next test.
    """
    try:
        swarm = AirflowDAGSwarm(dag_id="test_dag")

        # Add two nodes
        def node1_fn(conversation: Conversation) -> str:
            return "Node 1 output"

        def node2_fn(conversation: Conversation) -> str:
            return "Node 2 output"

        swarm.add_node("node1", node1_fn, NodeType.CALLABLE)
        swarm.add_node("node2", node2_fn, NodeType.CALLABLE)

        # Add edge between them
        swarm.add_edge("node1", "node2")

        assert (
            "node2" in swarm.edges["node1"]
        ), "Edge not added correctly"
        assert (
            len(swarm.edges["node1"]) == 1
        ), "Incorrect number of edges"

        # Test adding edge with non-existent node. Use try/else with an
        # explicit raise rather than `assert False`: assert statements are
        # stripped under `python -O`, which would let this check pass
        # silently.
        try:
            swarm.add_edge("node1", "non_existent")
        except ValueError:
            pass
        else:
            raise AssertionError(
                "Should raise ValueError for non-existent node"
            )

        print("✅ Edge addition test passed")
        return True
    except AssertionError as e:
        print(f"❌ Edge addition test failed: {e}")
        return False
    except Exception as e:
        # E.g. a KeyError from swarm.edges["node1"], or add_edge raising
        # something other than ValueError. Report it instead of aborting
        # the whole run.
        print(f"❌ Edge addition test failed with unexpected error: {e}")
        return False


def test_execution_order():
    """Test that nodes are executed in the correct order based on dependencies.

    Builds the chain ``node1 -> node2 -> node3`` (all ``concurrent=False``),
    runs the swarm, and asserts that the nodes were invoked in chain order.

    Returns:
        True if the observed order matches, False otherwise. Failures are
        printed rather than raised so the runner can continue.
    """
    try:
        swarm = AirflowDAGSwarm(dag_id="test_dag")
        # Each node appends its own name when invoked, so this list records
        # the order in which the swarm actually ran them.
        execution_order = []

        def node1(conversation: Conversation) -> str:
            execution_order.append("node1")
            return "Node 1 output"

        def node2(conversation: Conversation) -> str:
            execution_order.append("node2")
            return "Node 2 output"

        def node3(conversation: Conversation) -> str:
            execution_order.append("node3")
            return "Node 3 output"

        # Add nodes
        swarm.add_node(
            "node1", node1, NodeType.CALLABLE, concurrent=False
        )
        swarm.add_node(
            "node2", node2, NodeType.CALLABLE, concurrent=False
        )
        swarm.add_node(
            "node3", node3, NodeType.CALLABLE, concurrent=False
        )

        # Add edges to create a chain: node1 -> node2 -> node3
        swarm.add_edge("node1", "node2")
        swarm.add_edge("node2", "node3")

        # Execute
        swarm.run()

        # Check execution order; show the actual order on failure.
        expected_order = ["node1", "node2", "node3"]
        assert execution_order == expected_order, (
            f"Incorrect execution order: expected {expected_order}, "
            f"got {execution_order}"
        )
        print("✅ Execution order test passed")
        return True
    except AssertionError as e:
        print(f"❌ Execution order test failed: {e}")
        return False
    except Exception as e:
        # An error from add_node/add_edge/run would otherwise abort the
        # remaining tests in run_all_tests.
        print(f"❌ Execution order test failed with unexpected error: {e}")
        return False


def test_concurrent_execution():
    """Test concurrent execution of nodes.

    Two independent callable nodes (no edges, ``concurrent=True``) each
    sleep 0.5 s. Run sequentially they would take about 1 s; run
    concurrently they should take about 0.5 s. Anything under 0.8 s is
    accepted as concurrent.

    Returns:
        True if the run finishes under the threshold, False otherwise.
        Failures are printed rather than raised so the runner can continue.
    """
    try:
        swarm = AirflowDAGSwarm(dag_id="test_dag")

        def slow_node1(conversation: Conversation) -> str:
            time.sleep(0.5)
            return "Slow node 1 output"

        def slow_node2(conversation: Conversation) -> str:
            time.sleep(0.5)
            return "Slow node 2 output"

        # Add nodes with concurrent=True
        swarm.add_node(
            "slow1", slow_node1, NodeType.CALLABLE, concurrent=True
        )
        swarm.add_node(
            "slow2", slow_node2, NodeType.CALLABLE, concurrent=True
        )

        # Measure execution time with perf_counter: it is monotonic and
        # high-resolution, whereas time.time() follows the wall clock and
        # can jump (NTP/clock adjustments), skewing a sub-second interval.
        start_time = time.perf_counter()
        swarm.run()
        execution_time = time.perf_counter() - start_time

        # Should take ~0.5s for concurrent execution, not ~1s
        assert (
            execution_time < 0.8
        ), f"Concurrent execution took too long ({execution_time:.2f}s)"
        print("✅ Concurrent execution test passed")
        return True
    except AssertionError as e:
        print(f"❌ Concurrent execution test failed: {e}")
        return False
    except Exception as e:
        # Report errors from add_node/run instead of aborting the suite.
        print(
            f"❌ Concurrent execution test failed with unexpected error: {e}"
        )
        return False


def test_conversation_handling():
    """Test conversation management within the swarm.

    Creates a swarm with an initial message, appends two user messages via
    ``add_user_message``, and checks that all three appear in
    ``get_conversation_history()``.

    Returns:
        True if every message is found, False otherwise. Failures are
        printed rather than raised so the runner can continue.
    """
    try:
        swarm = AirflowDAGSwarm(
            dag_id="test_dag", initial_message="Initial test message"
        )

        # Test adding user messages
        swarm.add_user_message("Test message 1")
        swarm.add_user_message("Test message 2")

        history = swarm.get_conversation_history()
        assert (
            "Initial test message" in history
        ), "Initial message not in history"
        assert (
            "Test message 1" in history
        ), "First message not in history"
        assert (
            "Test message 2" in history
        ), "Second message not in history"

        print("✅ Conversation handling test passed")
        return True
    except AssertionError as e:
        print(f"❌ Conversation handling test failed: {e}")
        return False
    except Exception as e:
        # E.g. a missing/failing add_user_message; report it rather than
        # letting it abort the remaining tests.
        print(
            f"❌ Conversation handling test failed with unexpected error: {e}"
        )
        return False


def test_error_handling():
    """Test error handling in node execution.

    Registers one callable node that raises ``ValueError("Test error")``
    and expects ``swarm.run()`` to capture it rather than raise. The
    returned result must mention both "Error" and "Test error".

    Returns:
        True if the error is captured as expected, False otherwise.
        Failures are printed rather than raised.
    """
    try:
        swarm = AirflowDAGSwarm(dag_id="test_dag")

        def failing_node(conversation: Conversation) -> str:
            raise ValueError("Test error")

        swarm.add_node("failing", failing_node, NodeType.CALLABLE)

        # Execute should not raise an exception
        result = swarm.run()

        assert (
            "Error" in result
        ), "Error not captured in execution result"
        assert (
            "Test error" in result
        ), "Specific error message not captured"

        print("✅ Error handling test passed")
        return True
    except AssertionError as e:
        print(f"❌ Error handling test failed: {e}")
        return False
    except Exception as e:
        # Anything else means the swarm did not capture the error. This
        # includes the node's own ValueError leaking out of run(). Keeping
        # it separate from assertion failures stops "Test error" from
        # looking like an assertion message.
        print(f"❌ Error handling test failed with unexpected error: {e}")
        return False


def run_all_tests(tests=None):
    """Run all test functions and report results.

    Args:
        tests: Optional iterable of zero-argument callables. Each one
            returns a truthy value on success. Defaults to every test
            function in this module.

    Returns:
        True if every test passed (vacuously True for an empty iterable),
        False otherwise. A test that raises is reported and counted as a
        failure instead of aborting the remaining tests.
    """
    if tests is None:
        tests = [
            test_swarm_initialization,
            test_node_addition,
            test_edge_addition,
            test_execution_order,
            test_concurrent_execution,
            test_conversation_handling,
            test_error_handling,
        ]

    results = []
    for test in tests:
        name = getattr(test, "__name__", repr(test))
        print(f"\nRunning {name}...")
        try:
            # Coerce to bool so a test that returns None counts as a
            # failure instead of crashing sum() below.
            result = bool(test())
        except Exception as e:
            print(f"❌ {name} raised an unexpected error: {e}")
            result = False
        results.append(result)

    total = len(results)
    passed = sum(results)
    print("\n=== Test Results ===")
    print(f"Total tests: {total}")
    print(f"Passed: {passed}")
    print(f"Failed: {total - passed}")
    print("==================")
    return passed == total


if __name__ == "__main__":
    # A non-zero exit status lets CI and shell callers detect failures.
    raise SystemExit(0 if run_all_tests() else 1)