From 406ce98e537b237e940a94cc5922db1648363a50 Mon Sep 17 00:00:00 2001 From: Kye Date: Sat, 14 Oct 2023 19:35:57 -0400 Subject: [PATCH] pegasus tests Former-commit-id: f3cd328eee2174ba4a4319eb32eff053b41555ab --- swarms/{memory => embeddings}/embed.py | 0 tests/embeddings/pegasus.py | 33 ++++++++++++++++++++++++++ 2 files changed, 33 insertions(+) rename swarms/{memory => embeddings}/embed.py (100%) create mode 100644 tests/embeddings/pegasus.py diff --git a/swarms/memory/embed.py b/swarms/embeddings/embed.py similarity index 100% rename from swarms/memory/embed.py rename to swarms/embeddings/embed.py diff --git a/tests/embeddings/pegasus.py b/tests/embeddings/pegasus.py new file mode 100644 index 00000000..29227d28 --- /dev/null +++ b/tests/embeddings/pegasus.py @@ -0,0 +1,33 @@ +import pytest +from unittest.mock import Mock, patch +from swarms.embeddings.pegasus import PegasusEmbedding + + +def test_init(): + with patch("your_module.Pegasus") as MockPegasus: + embedder = PegasusEmbedding(modality="text") + MockPegasus.assert_called_once() + assert embedder.pegasus == MockPegasus.return_value + + +def test_init_exception(): + with patch("your_module.Pegasus", side_effect=Exception("Test exception")): + with pytest.raises(Exception) as e: + embedder = PegasusEmbedding(modality="text") + assert str(e.value) == "Test exception" + + +def test_embed(): + with patch("your_module.Pegasus") as MockPegasus: + embedder = PegasusEmbedding(modality="text") + embedder.embed("Hello world") + MockPegasus.return_value.embed.assert_called_once() + + +def test_embed_exception(): + with patch("your_module.Pegasus") as MockPegasus: + MockPegasus.return_value.embed.side_effect = Exception("Test exception") + embedder = PegasusEmbedding(modality="text") + with pytest.raises(Exception) as e: + embedder.embed("Hello world") + assert str(e.value) == "Test exception"