You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
69 lines
2.3 KiB
69 lines
2.3 KiB
1 year ago
|
# test_embeddings.py
|
||
|
|
||
|
import pytest
|
||
|
import openai
|
||
|
from unittest.mock import patch
|
||
|
from swarms.models.simple_ada import get_ada_embeddings # Adjust this import path to your project structure
|
||
|
from os import getenv
|
||
|
from dotenv import load_dotenv
|
||
|
|
||
|
# Load variables from a .env file (python-dotenv) into the process environment
# at import time. NOTE(review): presumably this supplies OPENAI_API_KEY for the
# module under test — confirm against swarms.models.simple_ada.
load_dotenv()
|
||
|
|
||
|
# Fixture for test texts
|
||
|
@pytest.fixture
def test_texts():
    """Provide the sample input strings used by the embedding tests.

    The second sample deliberately contains a newline character so that
    tests can check how newlines in the input are handled.
    """
    plain = "Hello World"
    with_newline = "This is a test string with newline\ncharacters"
    pangram = "A quick brown fox jumps over the lazy dog"
    return [plain, with_newline, pangram]
|
||
|
|
||
|
# Basic Test
|
||
|
def test_get_ada_embeddings_basic(test_texts):
    """Check that every sample text yields the mocked embedding vector.

    ``openai.Embedding.create`` is replaced by a mock that returns a canned
    response. For each sample the test asserts the returned embedding and
    that the mock's most recent call received the text with newlines
    replaced by spaces, together with the default model name.
    """
    expected_vector = [0.1, 0.2, 0.3]
    # Separate list object from ``expected_vector`` so the comparison is by value.
    canned_response = {"data": [{"embedding": [0.1, 0.2, 0.3]}]}

    with patch("openai.Embedding.create", return_value=canned_response) as mock_create:
        for sample in test_texts:
            result = get_ada_embeddings(sample)
            assert result == expected_vector, "Embedding does not match expected output"

            flattened = sample.replace("\n", " ")
            mock_create.assert_called_with(
                input=[flattened],
                model="text-embedding-ada-002",
            )
|
||
|
|
||
|
# Parameterized Test
|
||
|
@pytest.mark.parametrize(
    "text, model, expected_call_model",
    [
        # Each case requests a model and expects that same name to be forwarded.
        ("Hello World", model_name, model_name)
        for model_name in ("text-embedding-ada-002", "text-embedding-ada-001")
    ],
)
def test_get_ada_embeddings_models(text, model, expected_call_model):
    """Check that the requested model name is passed through to the API call.

    ``openai.Embedding.create`` is replaced by a mock returning a canned
    response; only the arguments of the call are verified, the returned
    embedding is ignored.
    """
    canned_response = {"data": [{"embedding": [0.1, 0.2, 0.3]}]}

    with patch("openai.Embedding.create", return_value=canned_response) as mock_create:
        get_ada_embeddings(text, model=model)
        mock_create.assert_called_with(input=[text], model=expected_call_model)
|
||
|
|
||
|
# Exception Test
|
||
|
def test_get_ada_embeddings_exception():
    """Check that an OpenAI error raised by the mocked API call reaches the caller."""
    with patch("openai.Embedding.create") as mock_create:
        error_cls = openai.error.OpenAIError
        # Make the mocked API call fail instead of returning a response.
        mock_create.side_effect = error_cls("Test error")

        with pytest.raises(error_cls):
            get_ada_embeddings("Some text")
|
||
|
|
||
|
# Tests for environment variable loading
|
||
|
def test_env_var_loading(monkeypatch):
    """Check that an API key set through ``monkeypatch`` is visible via ``getenv``.

    ``openai.Embedding.create`` is patched for the duration of the check.
    """
    expected_key = "testkey123"
    monkeypatch.setenv("OPENAI_API_KEY", expected_key)

    with patch("openai.Embedding.create"):
        actual_key = getenv("OPENAI_API_KEY")
        assert actual_key == expected_key, "Environment variable for API key is not set correctly"
|
||
|
|
||
|
# TODO: add tests covering other aspects, such as different input types, large inputs, and invalid inputs.
|