You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
swarms/tests/models/test_jina_embeds.py

85 lines
2.2 KiB

import pytest
import torch
11 months ago
from swarms.models.jina_embeds import JinaEmbeddings
@pytest.fixture
def model():
return JinaEmbeddings("bert-base-uncased", verbose=True)
def test_initialization(model):
assert isinstance(model, JinaEmbeddings)
assert model.device in ["cuda", "cpu"]
assert model.max_length == 500
assert model.verbose is True
def test_run_sync(model):
task = "Encode this text"
result = model.run(task)
assert isinstance(result, torch.Tensor)
assert result.shape == (model.max_length,)
def test_run_async(model):
task = "Encode this text"
result = model.run_async(task)
assert isinstance(result, torch.Tensor)
assert result.shape == (model.max_length,)
def test_save_model(tmp_path, model):
model_path = tmp_path / "model"
model.save_model(model_path)
assert (model_path / "config.json").is_file()
assert (model_path / "pytorch_model.bin").is_file()
assert (model_path / "vocab.txt").is_file()
def test_gpu_available(model):
gpu_status = model.gpu_available()
if torch.cuda.is_available():
assert gpu_status is True
else:
assert gpu_status is False
def test_memory_consumption(model):
memory_stats = model.memory_consumption()
if torch.cuda.is_available():
assert "allocated" in memory_stats
assert "reserved" in memory_stats
else:
assert "error" in memory_stats
def test_cosine_similarity(model):
task1 = "This is a sample text for testing."
task2 = "Another sample text for testing."
embeddings1 = model.run(task1)
embeddings2 = model.run(task2)
sim = model.cos_sim(embeddings1, embeddings2)
assert isinstance(sim, torch.Tensor)
11 months ago
assert sim.item() >= -1.0
assert sim.item() <= 1.0
def test_failed_load_model(caplog):
with pytest.raises(Exception):
JinaEmbeddings("invalid_model")
assert "Failed to load the model or the tokenizer" in caplog.text
def test_failed_generate_text(caplog, model):
with pytest.raises(Exception):
model.run("invalid_task")
assert "Failed to generate the text" in caplog.text
@pytest.mark.parametrize("device", ["cuda", "cpu"])
def test_change_device(model, device):
model.device = device
assert model.device == device