You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
118 lines
3.2 KiB
118 lines
3.2 KiB
11 months ago
|
from swarms.structs.agent import Agent
|
||
|
from swarms.structs.message_pool import MessagePool
|
||
|
from swarms import OpenAIChat
|
||
|
|
||
|
|
||
|
def test_message_pool_initialization():
    """Check that a freshly built MessagePool keeps its constructor arguments.

    The pool is created with two agents, a separate moderator and
    ``turns=5``; the test expects those values back unchanged and an empty
    message list.
    """
    # Distinct names keep the three fixtures distinguishable from each other.
    agent1 = Agent(llm=OpenAIChat(), agent_name="agent1")
    agent2 = Agent(llm=OpenAIChat(), agent_name="agent2")
    moderator = Agent(llm=OpenAIChat(), agent_name="moderator")
    agents = [agent1, agent2]

    message_pool = MessagePool(
        agents=agents, moderator=moderator, turns=5
    )

    # NOTE(review): the agent list is read back through the singular
    # ``agent`` attribute -- confirm this matches MessagePool's attribute name.
    assert message_pool.agent == agents
    assert message_pool.moderator == moderator
    assert message_pool.turns == 5
    assert message_pool.messages == []
|
||
|
|
||
|
|
||
|
def test_message_pool_add():
    """Check that add() stores exactly one record for a single message.

    Only ``agent``, ``content`` and ``turn`` are passed to ``add``; the test
    expects the stored record to also carry ``visible_to="all"`` and
    ``logged=True``.
    """
    speaker = Agent(llm=OpenAIChat(), agent_name="agent1")
    pool = MessagePool(agents=[speaker], moderator=speaker, turns=5)

    pool.add(agent=speaker, content="Hello, world!", turn=1)

    stored = pool.messages
    expected_record = {
        "agent": speaker,
        "content": "Hello, world!",
        "turn": 1,
        "visible_to": "all",
        "logged": True,
    }
    assert stored == [expected_record]
|
||
|
|
||
|
|
||
|
def test_message_pool_reset():
    """Check that reset() leaves the pool with an empty message list."""
    lone_agent = Agent(llm=OpenAIChat(), agent_name="agent1")
    pool = MessagePool(agents=[lone_agent], moderator=lone_agent, turns=5)

    # Add a message first so the empty list afterwards is due to reset().
    pool.add(agent=lone_agent, content="Hello, world!", turn=1)
    pool.reset()

    remaining = pool.messages
    assert remaining == []
|
||
|
|
||
|
|
||
|
def test_message_pool_last_turn():
    """Check that last_turn() reports the turn of the only message added."""
    lone_agent = Agent(llm=OpenAIChat(), agent_name="agent1")
    pool = MessagePool(agents=[lone_agent], moderator=lone_agent, turns=5)

    pool.add(agent=lone_agent, content="Hello, world!", turn=1)

    reported_turn = pool.last_turn()
    assert reported_turn == 1
|
||
|
|
||
|
|
||
|
def test_message_pool_last_message():
    """Check that last_message exposes the record of the only message added.

    ``last_message`` is read as an attribute here, not called.
    """
    speaker = Agent(llm=OpenAIChat(), agent_name="agent1")
    pool = MessagePool(agents=[speaker], moderator=speaker, turns=5)

    pool.add(agent=speaker, content="Hello, world!", turn=1)

    latest = pool.last_message
    expected_record = {
        "agent": speaker,
        "content": "Hello, world!",
        "turn": 1,
        "visible_to": "all",
        "logged": True,
    }
    assert latest == expected_record
|
||
|
|
||
|
|
||
|
def test_message_pool_get_all_messages():
    """Check that get_all_messages() returns the single stored record."""
    speaker = Agent(llm=OpenAIChat(), agent_name="agent1")
    pool = MessagePool(agents=[speaker], moderator=speaker, turns=5)

    pool.add(agent=speaker, content="Hello, world!", turn=1)

    everything = pool.get_all_messages()
    expected_record = {
        "agent": speaker,
        "content": "Hello, world!",
        "turn": 1,
        "visible_to": "all",
        "logged": True,
    }
    assert everything == [expected_record]
|
||
|
|
||
|
|
||
|
def test_message_pool_get_visible_messages():
    """Check that a message addressed to agent2 is visible to agent2.

    agent1 posts at turn 1 with ``visible_to`` set to agent2's name; the test
    expects agent2 to get that record back when asking at turn 2.
    """
    sender = Agent(llm=OpenAIChat(), agent_name="agent1")
    # The recipient is built without an llm, unlike the sender.
    recipient = Agent(agent_name="agent2")
    pool = MessagePool(
        agents=[sender, recipient], moderator=sender, turns=5
    )

    pool.add(
        agent=sender,
        content="Hello, agent2!",
        turn=1,
        visible_to=[recipient.agent_name],
    )

    seen_by_recipient = pool.get_visible_messages(agent=recipient, turn=2)
    expected_record = {
        "agent": sender,
        "content": "Hello, agent2!",
        "turn": 1,
        "visible_to": [recipient.agent_name],
        "logged": True,
    }
    assert seen_by_recipient == [expected_record]
|