Former-commit-id: 9402dab489
grit/923f7c6f-0958-480b-8748-ea6bbf1c2084
parent
3279130e95
commit
7d744dc6a0
@ -0,0 +1,214 @@
|
||||
import pytest
|
||||
|
||||
from swarms.models import OpenAIChat
|
||||
from swarms.models.anthropic import Anthropic
|
||||
from swarms.structs.flow import Flow
|
||||
from swarms.swarms.flow import GroupChat, GroupChatManager
|
||||
|
||||
llm = OpenAIChat()
|
||||
llm2 = Anthropic()
|
||||
|
||||
# Mock the OpenAI class for testing
|
||||
class MockOpenAI:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def generate_reply(self, content):
|
||||
return {"role": "mocked_agent", "content": "Mocked Reply"}
|
||||
|
||||
|
||||
# Create fixtures for agents and a sample message
|
||||
@pytest.fixture
|
||||
def agent1():
|
||||
return Flow(name="Agent1", llm=llm)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agent2():
|
||||
return Flow(name="Agent2", llm=llm2)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_message():
|
||||
return {"role": "Agent1", "content": "Hello, World!"}
|
||||
|
||||
|
||||
# Test the initialization of GroupChat
|
||||
def test_groupchat_initialization(agent1, agent2):
|
||||
groupchat = GroupChat(agents=[agent1, agent2])
|
||||
assert len(groupchat.agents) == 2
|
||||
assert len(groupchat.messages) == 0
|
||||
assert groupchat.max_round == 10
|
||||
assert groupchat.admin_name == "Admin"
|
||||
|
||||
|
||||
# Test resetting the GroupChat
|
||||
def test_groupchat_reset(agent1, agent2, sample_message):
|
||||
groupchat = GroupChat(agents=[agent1, agent2])
|
||||
groupchat.messages.append(sample_message)
|
||||
groupchat.reset()
|
||||
assert len(groupchat.messages) == 0
|
||||
|
||||
|
||||
# Test finding an agent by name
|
||||
def test_groupchat_find_agent_by_name(agent1, agent2):
|
||||
groupchat = GroupChat(agents=[agent1, agent2])
|
||||
found_agent = groupchat.agent_by_name("Agent1")
|
||||
assert found_agent == agent1
|
||||
|
||||
|
||||
# Test selecting the next agent
|
||||
def test_groupchat_select_next_agent(agent1, agent2):
|
||||
groupchat = GroupChat(agents=[agent1, agent2])
|
||||
next_agent = groupchat.next_agent(agent1)
|
||||
assert next_agent == agent2
|
||||
|
||||
|
||||
# Add more tests for different methods and scenarios as needed
|
||||
|
||||
|
||||
# Test the GroupChatManager
|
||||
def test_groupchat_manager(agent1, agent2):
|
||||
groupchat = GroupChat(agents=[agent1, agent2])
|
||||
selector = agent1 # Assuming agent1 is the selector
|
||||
manager = GroupChatManager(groupchat, selector)
|
||||
task = "Task for agent2"
|
||||
reply = manager(task)
|
||||
assert reply["role"] == "Agent2"
|
||||
assert reply["content"] == "Reply from Agent2"
|
||||
|
||||
|
||||
# Test selecting the next speaker when there is only one agent
|
||||
def test_groupchat_select_speaker_single_agent(agent1):
|
||||
groupchat = GroupChat(agents=[agent1])
|
||||
selector = agent1
|
||||
manager = GroupChatManager(groupchat, selector)
|
||||
task = "Task for agent1"
|
||||
reply = manager(task)
|
||||
assert reply["role"] == "Agent1"
|
||||
assert reply["content"] == "Reply from Agent1"
|
||||
|
||||
|
||||
# Test selecting the next speaker when GroupChat is underpopulated
|
||||
def test_groupchat_select_speaker_underpopulated(agent1, agent2):
|
||||
groupchat = GroupChat(agents=[agent1, agent2])
|
||||
selector = agent1
|
||||
manager = GroupChatManager(groupchat, selector)
|
||||
task = "Task for agent1"
|
||||
reply = manager(task)
|
||||
assert reply["role"] == "Agent2"
|
||||
assert reply["content"] == "Reply from Agent2"
|
||||
|
||||
|
||||
# Test formatting history
|
||||
def test_groupchat_format_history(agent1, agent2, sample_message):
|
||||
groupchat = GroupChat(agents=[agent1, agent2])
|
||||
groupchat.messages.append(sample_message)
|
||||
formatted_history = groupchat.format_history(groupchat.messages)
|
||||
expected_history = "'Agent1:Hello, World!"
|
||||
assert formatted_history == expected_history
|
||||
|
||||
|
||||
# Test agent names property
|
||||
def test_groupchat_agent_names(agent1, agent2):
|
||||
groupchat = GroupChat(agents=[agent1, agent2])
|
||||
names = groupchat.agent_names
|
||||
assert len(names) == 2
|
||||
assert "Agent1" in names
|
||||
assert "Agent2" in names
|
||||
|
||||
|
||||
# Test GroupChatManager initialization
|
||||
def test_groupchat_manager_initialization(agent1, agent2):
|
||||
groupchat = GroupChat(agents=[agent1, agent2])
|
||||
selector = agent1
|
||||
manager = GroupChatManager(groupchat, selector)
|
||||
assert manager.groupchat == groupchat
|
||||
assert manager.selector == selector
|
||||
|
||||
# Test case to ensure GroupChatManager generates a reply from an agent
|
||||
def test_groupchat_manager_generate_reply():
|
||||
# Create a GroupChat with two agents
|
||||
agents = [agent1, agent2]
|
||||
groupchat = GroupChat(agents=agents, messages=[], max_round=10)
|
||||
|
||||
# Mock the OpenAI class and GroupChat selector
|
||||
mocked_openai = MockOpenAI()
|
||||
selector = agent1
|
||||
|
||||
# Initialize GroupChatManager
|
||||
manager = GroupChatManager(
|
||||
groupchat=groupchat, selector=selector, openai=mocked_openai
|
||||
)
|
||||
|
||||
# Generate a reply
|
||||
task = "Write me a riddle"
|
||||
reply = manager(task)
|
||||
|
||||
# Check if a valid reply is generated
|
||||
assert "role" in reply
|
||||
assert "content" in reply
|
||||
assert reply["role"] in groupchat.agent_names
|
||||
|
||||
|
||||
# Test case to ensure GroupChat selects the next speaker correctly
|
||||
def test_groupchat_select_speaker():
|
||||
agent3 = Flow(name="agent3", llm=llm)
|
||||
agents = [agent1, agent2, agent3]
|
||||
groupchat = GroupChat(agents=agents, messages=[], max_round=10)
|
||||
|
||||
# Initialize GroupChatManager with agent1 as selector
|
||||
selector = agent1
|
||||
manager = GroupChatManager(groupchat=groupchat, selector=selector)
|
||||
|
||||
# Simulate selecting the next speaker
|
||||
last_speaker = agent1
|
||||
next_speaker = manager.select_speaker(last_speaker=last_speaker, selector=selector)
|
||||
|
||||
# Ensure the next speaker is agent2
|
||||
assert next_speaker == agent2
|
||||
|
||||
|
||||
# Test case to ensure GroupChat handles underpopulated group correctly
|
||||
def test_groupchat_underpopulated_group():
|
||||
agent1 = Flow(name="agent1", llm=llm)
|
||||
agents = [agent1]
|
||||
groupchat = GroupChat(agents=agents, messages=[], max_round=10)
|
||||
|
||||
# Initialize GroupChatManager with agent1 as selector
|
||||
selector = agent1
|
||||
manager = GroupChatManager(groupchat=groupchat, selector=selector)
|
||||
|
||||
# Simulate selecting the next speaker in an underpopulated group
|
||||
last_speaker = agent1
|
||||
next_speaker = manager.select_speaker(last_speaker=last_speaker, selector=selector)
|
||||
|
||||
# Ensure the next speaker is the same as the last speaker in an underpopulated group
|
||||
assert next_speaker == last_speaker
|
||||
|
||||
|
||||
# Test case to ensure GroupChatManager handles the maximum rounds correctly
|
||||
def test_groupchat_max_rounds():
|
||||
agents = [agent1, agent2]
|
||||
groupchat = GroupChat(agents=agents, messages=[], max_round=2)
|
||||
|
||||
# Initialize GroupChatManager with agent1 as selector
|
||||
selector = agent1
|
||||
manager = GroupChatManager(groupchat=groupchat, selector=selector)
|
||||
|
||||
# Simulate the conversation with max rounds
|
||||
last_speaker = agent1
|
||||
for _ in range(2):
|
||||
next_speaker = manager.select_speaker(
|
||||
last_speaker=last_speaker, selector=selector
|
||||
)
|
||||
last_speaker = next_speaker
|
||||
|
||||
# Try one more round, should stay with the last speaker
|
||||
next_speaker = manager.select_speaker(last_speaker=last_speaker, selector=selector)
|
||||
|
||||
# Ensure the next speaker is the same as the last speaker after reaching max rounds
|
||||
assert next_speaker == last_speaker
|
||||
|
||||
|
||||
# Continue adding more test cases as needed to cover various scenarios and functionalities of the code.
|
Loading…
Reference in new issue