import pytest
from unittest.mock import Mock, patch
from swarms.structs.mixture_of_agents import MixtureOfAgents
from swarms.structs.agent import Agent
from swarms_memory import BaseVectorDatabase


def test_init():
    with patch.object(
        MixtureOfAgents, "agent_check"
    ) as mock_agent_check, patch.object(
        MixtureOfAgents, "final_agent_check"
    ) as mock_final_agent_check, patch.object(
        MixtureOfAgents, "swarm_initialization"
    ) as mock_swarm_initialization, patch.object(
        MixtureOfAgents, "communication_protocol"
    ) as mock_communication_protocol:
        agents = [Mock(spec=Agent)]
        final_agent = Mock(spec=Agent)
        scp = Mock(spec=BaseVectorDatabase)
        MixtureOfAgents(
            agents=agents, final_agent=final_agent, scp=scp
        )
        mock_agent_check.assert_called_once()
        mock_final_agent_check.assert_called_once()
        mock_swarm_initialization.assert_called_once()
        mock_communication_protocol.assert_called_once()


def test_communication_protocol():
    agents = [Mock(spec=Agent)]
    final_agent = Mock(spec=Agent)
    scp = Mock(spec=BaseVectorDatabase)
    swarm = MixtureOfAgents(
        agents=agents, final_agent=final_agent, scp=scp
    )
    swarm.communication_protocol()
    for agent in agents:
        agent.long_term_memory.assert_called_once_with(scp)


def test_agent_check():
    final_agent = Mock(spec=Agent)
    with pytest.raises(TypeError):
        MixtureOfAgents(agents="not a list", final_agent=final_agent)
    with pytest.raises(TypeError):
        MixtureOfAgents(
            agents=["not an agent"], final_agent=final_agent
        )


def test_final_agent_check():
    agents = [Mock(spec=Agent)]
    with pytest.raises(TypeError):
        MixtureOfAgents(agents=agents, final_agent="not an agent")


def test_swarm_initialization():
    with patch(
        "swarms.structs.mixture_of_agents.logger"
    ) as mock_logger:
        agents = [Mock(spec=Agent)]
        final_agent = Mock(spec=Agent)
        swarm = MixtureOfAgents(
            agents=agents, final_agent=final_agent
        )
        swarm.swarm_initialization()
        assert mock_logger.info.call_count == 3


def test_run():
    with patch("swarms.structs.mixture_of_agents.logger"), patch(
        "builtins.open", new_callable=Mock
    ) as mock_open:
        agents = [Mock(spec=Agent)]
        final_agent = Mock(spec=Agent)
        swarm = MixtureOfAgents(
            agents=agents, final_agent=final_agent
        )
        swarm.run("task")
        for agent in agents:
            agent.run.assert_called_once()
        final_agent.run.assert_called_once()
        mock_open.assert_called_once_with(swarm.saved_file_name, "w")