From 3f7672fc0b5b7d692e24e24cfcb6cffe95a269e9 Mon Sep 17 00:00:00 2001 From: Kye Date: Sat, 14 Oct 2023 19:35:57 -0400 Subject: [PATCH] pegasus tests Former-commit-id: dec684736ad5e1d3058d6b7d5acd447c4744cbd1 --- swarms/{memory => embeddings}/embed.py | 0 tests/embeddings/pegasus.py | 33 ++++++++++++++++++++++++++ 2 files changed, 33 insertions(+) rename swarms/{memory => embeddings}/embed.py (100%) create mode 100644 tests/embeddings/pegasus.py diff --git a/swarms/memory/embed.py b/swarms/embeddings/embed.py similarity index 100% rename from swarms/memory/embed.py rename to swarms/embeddings/embed.py diff --git a/tests/embeddings/pegasus.py b/tests/embeddings/pegasus.py new file mode 100644 index 00000000..29227d28 --- /dev/null +++ b/tests/embeddings/pegasus.py @@ -0,0 +1,33 @@ +import pytest +from unittest.mock import Mock, patch +from swarms.embeddings.pegasus import PegasusEmbedding + + +def test_init(): + with patch("your_module.Pegasus") as MockPegasus: + embedder = PegasusEmbedding(modality="text") + MockPegasus.assert_called_once() + assert embedder.pegasus == MockPegasus.return_value + + +def test_init_exception(): + with patch("your_module.Pegasus", side_effect=Exception("Test exception")): + with pytest.raises(Exception) as e: + embedder = PegasusEmbedding(modality="text") + assert str(e.value) == "Test exception" + + +def test_embed(): + with patch("your_module.Pegasus") as MockPegasus: + embedder = PegasusEmbedding(modality="text") + embedder.embed("Hello world") + MockPegasus.return_value.embed.assert_called_once() + + +def test_embed_exception(): + with patch("your_module.Pegasus") as MockPegasus: + MockPegasus.return_value.embed.side_effect = Exception("Test exception") + embedder = PegasusEmbedding(modality="text") + with pytest.raises(Exception) as e: + embedder.embed("Hello world") + assert str(e.value) == "Test exception"