You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
131 lines
3.4 KiB
131 lines
3.4 KiB
6 months ago
|
import threading
|
||
|
|
||
|
from swarms.memory.short_term_memory import ShortTermMemory
|
||
|
|
||
|
|
||
|
def test_init():
    """A freshly constructed ShortTermMemory has both tiers set to []."""
    fresh = ShortTermMemory()
    for tier_name in ("short_term_memory", "medium_term_memory"):
        assert getattr(fresh, tier_name) == []
|
||
|
|
||
|
|
||
|
def test_add():
    """add() stores a role/message dict in the short-term tier."""
    stm = ShortTermMemory()
    stm.add("user", "Hello, world!")
    expected = [{"role": "user", "message": "Hello, world!"}]
    assert stm.short_term_memory == expected
|
||
|
|
||
|
|
||
|
def test_get_short_term():
    """get_short_term() returns the entries previously passed to add()."""
    stm = ShortTermMemory()
    stm.add("user", "Hello, world!")
    entry = {"role": "user", "message": "Hello, world!"}
    assert stm.get_short_term() == [entry]
|
||
|
|
||
|
|
||
|
def test_get_medium_term():
    """An entry moved with move_to_medium_term(0) shows up in get_medium_term()."""
    stm = ShortTermMemory()
    stm.add("user", "Hello, world!")
    stm.move_to_medium_term(0)
    expected_entry = {"role": "user", "message": "Hello, world!"}
    assert stm.get_medium_term() == [expected_entry]
|
||
|
|
||
|
|
||
|
def test_clear_medium_term():
    """clear_medium_term() leaves the medium-term tier empty."""
    stm = ShortTermMemory()
    stm.add("user", "Hello, world!")
    # Populate the medium tier first so the clear has something to remove.
    stm.move_to_medium_term(0)
    stm.clear_medium_term()
    remaining = stm.get_medium_term()
    assert remaining == []
|
||
|
|
||
|
|
||
|
def test_get_short_term_memory_str():
    """get_short_term_memory_str() renders the short-term list as its str() form."""
    stm = ShortTermMemory()
    stm.add("user", "Hello, world!")
    rendered = stm.get_short_term_memory_str()
    expected = "[{'role': 'user', 'message': 'Hello, world!'}]"
    assert rendered == expected
|
||
|
|
||
|
|
||
|
def test_update_short_term():
    """update_short_term() overwrites the entry at the given index in place."""
    stm = ShortTermMemory()
    stm.add("user", "Hello, world!")
    stm.update_short_term(0, "user", "Goodbye, world!")
    replaced = {"role": "user", "message": "Goodbye, world!"}
    assert stm.get_short_term() == [replaced]
|
||
|
|
||
|
|
||
|
def test_clear():
    """clear() empties the short-term tier."""
    stm = ShortTermMemory()
    stm.add("user", "Hello, world!")
    stm.clear()
    leftover = stm.get_short_term()
    assert leftover == []
|
||
|
|
||
|
|
||
|
def test_search_memory():
    """search_memory() reports (index, entry) hits per tier for a substring query."""
    stm = ShortTermMemory()
    stm.add("user", "Hello, world!")
    hit = (0, {"role": "user", "message": "Hello, world!"})
    expected = {"short_term": [hit], "medium_term": []}
    assert stm.search_memory("Hello") == expected
|
||
|
|
||
|
|
||
|
def test_return_shortmemory_as_str():
    """return_shortmemory_as_str() yields the str() form of the short-term list."""
    stm = ShortTermMemory()
    stm.add("user", "Hello, world!")
    rendered = stm.return_shortmemory_as_str()
    expected = "[{'role': 'user', 'message': 'Hello, world!'}]"
    assert rendered == expected
|
||
|
|
||
|
|
||
|
def test_move_to_medium_term():
    """move_to_medium_term(0) transfers the entry out of short-term into medium-term."""
    stm = ShortTermMemory()
    stm.add("user", "Hello, world!")
    stm.move_to_medium_term(0)

    moved = {"role": "user", "message": "Hello, world!"}
    assert stm.get_medium_term() == [moved]
    # The entry must no longer be present in the short-term tier.
    assert stm.get_short_term() == []
|
||
|
|
||
|
|
||
|
def test_return_medium_memory_as_str():
    """return_medium_memory_as_str() yields the str() form of the medium-term list."""
    stm = ShortTermMemory()
    stm.add("user", "Hello, world!")
    stm.move_to_medium_term(0)
    rendered = stm.return_medium_memory_as_str()
    expected = "[{'role': 'user', 'message': 'Hello, world!'}]"
    assert rendered == expected
|
||
|
|
||
|
|
||
|
def test_thread_safety():
    """Concurrent add() calls from 10 threads x 1000 messages yield 10000 entries."""
    shared = ShortTermMemory()
    per_thread = 1000
    worker_count = 10

    def writer():
        for _ in range(per_thread):
            shared.add("user", "Hello, world!")

    # Build every thread before starting any of them, then wait for all.
    workers = []
    for _ in range(worker_count):
        workers.append(threading.Thread(target=writer))
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

    assert len(shared.get_short_term()) == per_thread * worker_count
|
||
|
|
||
|
|
||
|
def test_save_and_load():
    """Round-trip both memory tiers through save_to_file() / load_from_file().

    The file is written inside a temporary directory that is removed on exit.
    This prevents the test from leaving a stray ``memory.json`` in the current
    working directory, and from overwriting an existing file there.
    """
    import os
    import tempfile

    memory1 = ShortTermMemory()
    memory1.add("user", "Hello, world!")
    memory1.add("user", "Goodbye, world!")
    # Put one entry in the medium tier. Without this, the medium-term
    # comparison below would be a vacuous ``[] == []`` check.
    memory1.move_to_medium_term(0)

    with tempfile.TemporaryDirectory() as tmpdir:
        path = os.path.join(tmpdir, "memory.json")
        memory1.save_to_file(path)

        memory2 = ShortTermMemory()
        memory2.load_from_file(path)

    assert memory1.get_short_term() == memory2.get_short_term()
    assert memory1.get_medium_term() == memory2.get_medium_term()
|