You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
83 lines
2.2 KiB
83 lines
2.2 KiB
import pytest
|
|
import torch
|
|
from swarms.models.jina_embeds import JinaEmbeddings
|
|
|
|
|
|
@pytest.fixture
def model():
    """Provide a ``JinaEmbeddings`` instance for a single test.

    The fixture has no explicit scope, so pytest builds a new instance for
    every test that requests it. It wraps the ``"bert-base-uncased"``
    checkpoint and turns ``verbose`` on, which ``test_initialization``
    asserts.
    """
    embedder = JinaEmbeddings("bert-base-uncased", verbose=True)
    return embedder
|
|
|
|
|
|
def test_initialization(model):
    """Check the attributes exposed by a freshly constructed embedder."""
    assert isinstance(model, JinaEmbeddings)
    # The device should resolve to one of the two supported backends.
    assert model.device in ("cuda", "cpu")
    # The fixture never passes max_length, so this pins the constructor default.
    assert model.max_length == 500
    # The fixture passes verbose=True explicitly.
    assert model.verbose is True
|
|
|
|
|
|
def test_run_sync(model):
    """``run`` on a short prompt should return a 1-D tensor of length max_length."""
    embedding = model.run("Encode this text")
    assert isinstance(embedding, torch.Tensor)
    expected_shape = (model.max_length,)
    assert embedding.shape == expected_shape
|
|
|
|
|
|
def test_run_async(model):
    """``run_async`` should yield a tensor shaped like the output of ``run``.

    NOTE(review): the return value is checked directly, without awaiting it.
    If ``run_async`` is a coroutine function, the result is a coroutine, not a
    tensor, and this test needs an event loop (e.g. ``asyncio.run``). Confirm
    against the JinaEmbeddings implementation.
    """
    prompt = "Encode this text"
    embedding = model.run_async(prompt)
    assert isinstance(embedding, torch.Tensor)
    expected_shape = (model.max_length,)
    assert embedding.shape == expected_shape
|
|
|
|
|
|
def test_save_model(tmp_path, model):
    """``save_model`` should write the expected artifact files into the target directory."""
    target_dir = tmp_path / "model"
    model.save_model(target_dir)
    # Checked in a fixed order, so the first missing file is the one reported.
    for artifact in ("config.json", "pytorch_model.bin", "vocab.txt"):
        assert (target_dir / artifact).is_file()
|
|
|
|
|
|
def test_gpu_available(model):
    """``gpu_available()`` should report exactly what ``torch.cuda.is_available()`` does."""
    reported = model.gpu_available()
    expected = torch.cuda.is_available()
    # Identity (not ==) so a truthy non-bool return value still fails the test.
    assert reported is expected
|
|
|
|
|
|
def test_memory_consumption(model):
    """Check that memory statistics depend on CUDA availability.

    With CUDA available, the stats must contain allocation figures.
    Without CUDA, they must contain an "error" entry instead.
    """
    stats = model.memory_consumption()
    if not torch.cuda.is_available():
        assert "error" in stats
        return
    for key in ("allocated", "reserved"):
        assert key in stats
|
|
|
|
|
|
def test_cosine_similarity(model):
    """Cosine similarity of two sentence embeddings should be a tensor in [-1, 1]."""
    task1 = "This is a sample text for testing."
    task2 = "Another sample text for testing."
    embeddings1 = model.run(task1)
    embeddings2 = model.run(task2)
    sim = model.cos_sim(embeddings1, embeddings2)
    assert isinstance(sim, torch.Tensor)
    # .item() raises unless the tensor holds exactly one element, so this also
    # checks that cos_sim returned a single score. Extract it once.
    value = sim.item()
    assert -1.0 <= value <= 1.0
|
|
|
|
|
|
def test_failed_load_model(caplog):
    """Constructing an embedder from an unknown checkpoint should raise and log the failure."""
    with pytest.raises(Exception):
        JinaEmbeddings("invalid_model")
    # The log check must come after the `with` block: code placed after the
    # raising call inside the block would never run.
    expected_message = "Failed to load the model or the tokenizer"
    assert expected_message in caplog.text
|
|
|
|
|
|
def test_failed_generate_text(caplog, model):
    """A failing ``run`` call should raise and leave an error message in the log.

    NOTE(review): "invalid_task" is an ordinary string, which an embedding
    model would normally encode without error. Confirm what input is actually
    meant to drive ``run`` into its failure path.
    """
    with pytest.raises(Exception):
        model.run("invalid_task")
    expected_message = "Failed to generate the text"
    assert expected_message in caplog.text
|
|
|
|
|
|
@pytest.mark.parametrize("device", ("cuda", "cpu"))
def test_change_device(model, device):
    """Assigning ``model.device`` should be visible when the attribute is read back.

    Runs once per device name. NOTE(review): if ``device`` is a property whose
    setter moves weights, the "cuda" case needs a GPU. Confirm.
    """
    requested = device
    model.device = requested
    assert model.device == requested
|