diff --git a/swarms/memory/embed.py b/swarms/embeddings/embed.py similarity index 100% rename from swarms/memory/embed.py rename to swarms/embeddings/embed.py diff --git a/tests/embeddings/pegasus.py b/tests/embeddings/pegasus.py new file mode 100644 index 00000000..29227d28 --- /dev/null +++ b/tests/embeddings/pegasus.py @@ -0,0 +1,33 @@ +import pytest +from unittest.mock import Mock, patch +from swarms.embeddings.pegasus import PegasusEmbedding + + +def test_init(): + with patch("your_module.Pegasus") as MockPegasus: + embedder = PegasusEmbedding(modality="text") + MockPegasus.assert_called_once() + assert embedder.pegasus == MockPegasus.return_value + + +def test_init_exception(): + with patch("your_module.Pegasus", side_effect=Exception("Test exception")): + with pytest.raises(Exception) as e: + embedder = PegasusEmbedding(modality="text") + assert str(e.value) == "Test exception" + + +def test_embed(): + with patch("your_module.Pegasus") as MockPegasus: + embedder = PegasusEmbedding(modality="text") + embedder.embed("Hello world") + MockPegasus.return_value.embed.assert_called_once() + + +def test_embed_exception(): + with patch("your_module.Pegasus") as MockPegasus: + MockPegasus.return_value.embed.side_effect = Exception("Test exception") + embedder = PegasusEmbedding(modality="text") + with pytest.raises(Exception) as e: + embedder.embed("Hello world") + assert str(e.value) == "Test exception"