# Tests for GroupChat and GroupChatManager.
import pytest
|
|
|
|
from swarms.models import OpenAIChat
|
|
from swarms.models.anthropic import Anthropic
|
|
from swarms.structs.flow import Flow
|
|
from swarms.swarms.groupchat import GroupChat, GroupChatManager
|
|
|
|
# Shared model clients: `llm` backs the "Agent1" fixture and the agents that
# individual tests build inline; `llm2` backs the "Agent2" fixture.
llm = OpenAIChat()
llm2 = Anthropic()
|
|
|
|
|
|
# Mock the OpenAI class for testing
|
|
class MockOpenAI:
    """Stand-in for an OpenAI client that always answers with a fixed reply."""

    def __init__(self, *args, **kwargs):
        """Accept any positional or keyword arguments and ignore them."""

    def generate_reply(self, content):
        """Return the canned reply dict; ``content`` is not inspected."""
        canned_reply = {"role": "mocked_agent", "content": "Mocked Reply"}
        return canned_reply
|
|
|
|
|
|
# Create fixtures for agents and a sample message
|
|
@pytest.fixture
def agent1():
    """Provide a Flow named "Agent1" driven by the module-level ``llm``."""
    flow = Flow(name="Agent1", llm=llm)
    return flow
|
|
|
|
|
|
@pytest.fixture
def agent2():
    """Provide a Flow named "Agent2" driven by the module-level ``llm2``."""
    flow = Flow(name="Agent2", llm=llm2)
    return flow
|
|
|
|
|
|
@pytest.fixture
def sample_message():
    """Provide one chat message dict with "role" and "content" keys."""
    message = dict(role="Agent1", content="Hello, World!")
    return message
|
|
|
|
|
|
# Test the initialization of GroupChat
|
|
def test_groupchat_initialization(agent1, agent2):
    """Check the state of a GroupChat constructed from only an agent list."""
    chat = GroupChat(agents=[agent1, agent2])

    # Both agents are registered and no messages exist yet.
    assert len(chat.agents) == 2
    assert len(chat.messages) == 0

    # Values expected when max_round / admin_name are not passed in.
    assert chat.max_round == 10
    assert chat.admin_name == "Admin"
|
|
|
|
|
|
# Test resetting the GroupChat
|
|
def test_groupchat_reset(agent1, agent2, sample_message):
    """Check that reset() leaves the chat with no messages."""
    chat = GroupChat(agents=[agent1, agent2])

    # Seed one message so the reset has something to clear.
    chat.messages.append(sample_message)
    chat.reset()

    assert len(chat.messages) == 0
|
|
|
|
|
|
# Test finding an agent by name
|
|
def test_groupchat_find_agent_by_name(agent1, agent2):
    """Check that agent_by_name("Agent1") returns the agent1 fixture."""
    chat = GroupChat(agents=[agent1, agent2])

    assert chat.agent_by_name("Agent1") == agent1
|
|
|
|
|
|
# Test selecting the next agent
|
|
def test_groupchat_select_next_agent(agent1, agent2):
    """Check that next_agent(agent1) yields agent2 in a two-agent chat."""
    chat = GroupChat(agents=[agent1, agent2])

    assert chat.next_agent(agent1) == agent2
|
|
|
|
|
|
# Add more tests for different methods and scenarios as needed
|
|
|
|
|
|
# Test the GroupChatManager
|
|
def test_groupchat_manager(agent1, agent2):
    """Call a GroupChatManager with a task and check the reply dict.

    NOTE(review): the agents are backed by the module-level model clients,
    so the exact-content assertion presumably needs a stubbed model to
    hold -- confirm.
    """
    chat = GroupChat(agents=[agent1, agent2])

    # agent1 doubles as the selector.
    manager = GroupChatManager(chat, agent1)

    reply = manager("Task for agent2")

    assert reply["role"] == "Agent2"
    assert reply["content"] == "Reply from Agent2"
|
|
|
|
|
|
# Test selecting the next speaker when there is only one agent
|
|
def test_groupchat_select_speaker_single_agent(agent1):
    """Call the manager on a one-agent chat and expect that agent's reply."""
    chat = GroupChat(agents=[agent1])

    # The only agent also acts as the selector.
    manager = GroupChatManager(chat, agent1)

    reply = manager("Task for agent1")

    assert reply["role"] == "Agent1"
    assert reply["content"] == "Reply from Agent1"
|
|
|
|
|
|
# Test selecting the next speaker when GroupChat is underpopulated
|
|
def test_groupchat_select_speaker_underpopulated(agent1, agent2):
    """Call the manager on a two-agent chat and expect a reply from Agent2."""
    chat = GroupChat(agents=[agent1, agent2])

    # agent1 is used as the selector.
    manager = GroupChatManager(chat, agent1)

    reply = manager("Task for agent1")

    assert reply["role"] == "Agent2"
    assert reply["content"] == "Reply from Agent2"
|
|
|
|
|
|
# Test formatting history
|
|
def test_groupchat_format_history(agent1, agent2, sample_message):
    """Check the string format_history() produces for a single message.

    NOTE(review): the expected string opens with a single quote that is
    never closed; presumably this mirrors format_history's actual output --
    confirm against the implementation.
    """
    chat = GroupChat(agents=[agent1, agent2])
    chat.messages.append(sample_message)

    rendered = chat.format_history(chat.messages)

    assert rendered == "'Agent1:Hello, World!"
|
|
|
|
|
|
# Test agent names property
|
|
def test_groupchat_agent_names(agent1, agent2):
    """Check that agent_names holds both fixture agents' names."""
    chat = GroupChat(agents=[agent1, agent2])

    agent_names = chat.agent_names

    assert len(agent_names) == 2
    assert "Agent1" in agent_names
    assert "Agent2" in agent_names
|
|
|
|
|
|
# Test GroupChatManager initialization
|
|
def test_groupchat_manager_initialization(agent1, agent2):
    """Check that the manager exposes the chat and selector it was given."""
    chat = GroupChat(agents=[agent1, agent2])

    # agent1 serves as the selector.
    manager = GroupChatManager(chat, agent1)

    assert manager.groupchat == chat
    assert manager.selector == agent1
|
|
|
|
|
|
# Test case to ensure GroupChatManager generates a reply from an agent
|
|
def test_groupchat_manager_generate_reply(agent1, agent2):
    """Check that calling the manager yields a reply from a known agent.

    ``agent1`` and ``agent2`` are requested as pytest fixtures. Without them
    in the signature, the names in the body resolve to the module-level
    fixture functions instead of Flow instances.

    NOTE(review): ``openai=`` is forwarded unchanged from the original test;
    confirm that GroupChatManager accepts that keyword.
    """
    # Create a GroupChat with two agents
    agents = [agent1, agent2]
    groupchat = GroupChat(agents=agents, messages=[], max_round=10)

    # Mock the OpenAI class and GroupChat selector
    mocked_openai = MockOpenAI()
    selector = agent1

    # Initialize GroupChatManager
    manager = GroupChatManager(
        groupchat=groupchat, selector=selector, openai=mocked_openai
    )

    # Generate a reply
    task = "Write me a riddle"
    reply = manager(task)

    # Check if a valid reply is generated
    assert "role" in reply
    assert "content" in reply
    assert reply["role"] in groupchat.agent_names
|
|
|
|
|
|
# Test case to ensure GroupChat selects the next speaker correctly
|
|
def test_groupchat_select_speaker(agent1, agent2):
    """Check that select_speaker picks agent2 after agent1 in a 3-agent chat.

    ``agent1`` and ``agent2`` are requested as pytest fixtures. Without them
    in the signature, the names in the body resolve to the module-level
    fixture functions instead of Flow instances. The third agent is built
    inline.
    """
    agent3 = Flow(name="agent3", llm=llm)
    agents = [agent1, agent2, agent3]
    groupchat = GroupChat(agents=agents, messages=[], max_round=10)

    # Initialize GroupChatManager with agent1 as selector
    selector = agent1
    manager = GroupChatManager(groupchat=groupchat, selector=selector)

    # Simulate selecting the next speaker
    last_speaker = agent1
    next_speaker = manager.select_speaker(
        last_speaker=last_speaker, selector=selector
    )

    # Ensure the next speaker is agent2
    assert next_speaker == agent2
|
|
|
|
|
|
# Test case to ensure GroupChat handles underpopulated group correctly
|
|
def test_groupchat_underpopulated_group():
    """Check select_speaker on a one-agent chat returns that same agent."""
    # This test builds its own agent instead of using a fixture.
    solo_agent = Flow(name="agent1", llm=llm)
    chat = GroupChat(agents=[solo_agent], messages=[], max_round=10)

    # The single agent is selector and last speaker at once.
    manager = GroupChatManager(groupchat=chat, selector=solo_agent)
    chosen = manager.select_speaker(
        last_speaker=solo_agent, selector=solo_agent
    )

    # With nobody else to pick, the last speaker is expected back.
    assert chosen == solo_agent
|
|
|
|
|
|
# Test case to ensure GroupChatManager handles the maximum rounds correctly
|
|
def test_groupchat_max_rounds(agent1, agent2):
    """Check speaker selection once a max_round=2 chat has run two rounds.

    ``agent1`` and ``agent2`` are requested as pytest fixtures. Without them
    in the signature, the names in the body resolve to the module-level
    fixture functions instead of Flow instances.
    """
    agents = [agent1, agent2]
    groupchat = GroupChat(agents=agents, messages=[], max_round=2)

    # Initialize GroupChatManager with agent1 as selector
    selector = agent1
    manager = GroupChatManager(groupchat=groupchat, selector=selector)

    # Simulate the conversation with max rounds
    last_speaker = agent1
    for _ in range(2):
        next_speaker = manager.select_speaker(
            last_speaker=last_speaker, selector=selector
        )
        last_speaker = next_speaker

    # Try one more round, should stay with the last speaker
    next_speaker = manager.select_speaker(
        last_speaker=last_speaker, selector=selector
    )

    # Ensure the next speaker is the same as the last speaker after reaching max rounds
    assert next_speaker == last_speaker
|
|
|
|
|
|
# Continue adding more test cases as needed to cover various scenarios and functionalities of the code.
|