"""Unit tests for ``ShortTermMemory`` from ``swarms.memory.short_term_memory``."""

import os
import tempfile
import threading

from swarms.memory.short_term_memory import ShortTermMemory


def test_init():
    """A fresh instance starts with empty short- and medium-term lists."""
    memory = ShortTermMemory()
    assert memory.short_term_memory == []
    assert memory.medium_term_memory == []


def test_add():
    """``add`` appends a role/message dict to the short-term list."""
    memory = ShortTermMemory()
    memory.add("user", "Hello, world!")
    assert memory.short_term_memory == [
        {"role": "user", "message": "Hello, world!"}
    ]


def test_get_short_term():
    """``get_short_term`` returns the entries that were added."""
    memory = ShortTermMemory()
    memory.add("user", "Hello, world!")
    assert memory.get_short_term() == [
        {"role": "user", "message": "Hello, world!"}
    ]


def test_get_medium_term():
    """``get_medium_term`` returns an entry moved over from short-term."""
    memory = ShortTermMemory()
    memory.add("user", "Hello, world!")
    memory.move_to_medium_term(0)
    assert memory.get_medium_term() == [
        {"role": "user", "message": "Hello, world!"}
    ]


def test_clear_medium_term():
    """``clear_medium_term`` leaves the medium-term list empty."""
    memory = ShortTermMemory()
    memory.add("user", "Hello, world!")
    memory.move_to_medium_term(0)
    memory.clear_medium_term()
    assert memory.get_medium_term() == []


def test_get_short_term_memory_str():
    """``get_short_term_memory_str`` yields the ``str()`` form of the list."""
    memory = ShortTermMemory()
    memory.add("user", "Hello, world!")
    assert (
        memory.get_short_term_memory_str()
        == "[{'role': 'user', 'message': 'Hello, world!'}]"
    )


def test_update_short_term():
    """``update_short_term`` replaces the entry at the given index."""
    memory = ShortTermMemory()
    memory.add("user", "Hello, world!")
    memory.update_short_term(0, "user", "Goodbye, world!")
    assert memory.get_short_term() == [
        {"role": "user", "message": "Goodbye, world!"}
    ]


def test_clear():
    """``clear`` leaves the short-term list empty."""
    memory = ShortTermMemory()
    memory.add("user", "Hello, world!")
    memory.clear()
    assert memory.get_short_term() == []


def test_search_memory():
    """``search_memory`` reports (index, entry) matches per memory tier."""
    memory = ShortTermMemory()
    memory.add("user", "Hello, world!")
    assert memory.search_memory("Hello") == {
        "short_term": [(0, {"role": "user", "message": "Hello, world!"})],
        "medium_term": [],
    }


def test_return_shortmemory_as_str():
    """``return_shortmemory_as_str`` yields the ``str()`` form of the list."""
    memory = ShortTermMemory()
    memory.add("user", "Hello, world!")
    assert (
        memory.return_shortmemory_as_str()
        == "[{'role': 'user', 'message': 'Hello, world!'}]"
    )


def test_move_to_medium_term():
    """Moving an entry adds it to medium-term and removes it from short-term."""
    memory = ShortTermMemory()
    memory.add("user", "Hello, world!")
    memory.move_to_medium_term(0)
    assert memory.get_medium_term() == [
        {"role": "user", "message": "Hello, world!"}
    ]
    assert memory.get_short_term() == []


def test_return_medium_memory_as_str():
    """``return_medium_memory_as_str`` yields the ``str()`` form of the list."""
    memory = ShortTermMemory()
    memory.add("user", "Hello, world!")
    memory.move_to_medium_term(0)
    assert (
        memory.return_medium_memory_as_str()
        == "[{'role': 'user', 'message': 'Hello, world!'}]"
    )


def test_thread_safety():
    """Ten threads adding 1000 messages each must yield 10000 entries."""
    memory = ShortTermMemory()

    def add_messages():
        for _ in range(1000):
            memory.add("user", "Hello, world!")

    threads = [threading.Thread(target=add_messages) for _ in range(10)]
    for thread in threads:
        thread.start()
    # Wait for every writer to finish before counting the entries.
    for thread in threads:
        thread.join()
    assert len(memory.get_short_term()) == 10000


def test_save_and_load():
    """A memory saved to a file and loaded into a new instance compares equal."""
    memory1 = ShortTermMemory()
    memory1.add("user", "Hello, world!")
    # Write inside a throwaway directory so the test leaves no "memory.json"
    # behind in the working directory and cannot collide with other runs.
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "memory.json")
        memory1.save_to_file(path)
        memory2 = ShortTermMemory()
        memory2.load_from_file(path)
    assert memory1.get_short_term() == memory2.get_short_term()
    assert memory1.get_medium_term() == memory2.get_medium_term()