[CLEANUP][ADA]

pull/334/head
Kye 1 year ago
parent 1c4f0d8ad5
commit 8cdb82bd9d

@ -1,22 +0,0 @@
from openai import OpenAI
client = OpenAI()
def get_ada_embeddings(
text: str, model: str = "text-embedding-ada-002"
):
"""
Simple function to get embeddings from ada
Usage:
>>> get_ada_embeddings("Hello World")
>>> get_ada_embeddings("Hello World", model="text-embedding-ada-001")
"""
text = text.replace("\n", " ")
return client.embeddings.create(input=[text], model=model)[
"data"
][0]["embedding"]

@ -1,91 +0,0 @@
# test_embeddings.py
import pytest
import openai
from unittest.mock import patch
from swarms.models.simple_ada import (
get_ada_embeddings,
) # Adjust this import path to your project structure
from os import getenv
from dotenv import load_dotenv
load_dotenv()
# Fixture for test texts
@pytest.fixture
def test_texts():
return [
"Hello World",
"This is a test string with newline\ncharacters",
"A quick brown fox jumps over the lazy dog",
]
# Basic Test
def test_get_ada_embeddings_basic(test_texts):
with patch("openai.resources.Embeddings.create") as mock_create:
# Mocking the OpenAI API call
mock_create.return_value = {
"data": [{"embedding": [0.1, 0.2, 0.3]}]
}
for text in test_texts:
embedding = get_ada_embeddings(text)
assert embedding == [
0.1,
0.2,
0.3,
], "Embedding does not match expected output"
mock_create.assert_called_with(
input=[text.replace("\n", " ")],
model="text-embedding-ada-002",
)
# Parameterized Test
@pytest.mark.parametrize(
"text, model, expected_call_model",
[
(
"Hello World",
"text-embedding-ada-002",
"text-embedding-ada-002",
),
(
"Hello World",
"text-embedding-ada-001",
"text-embedding-ada-001",
),
],
)
def test_get_ada_embeddings_models(text, model, expected_call_model):
with patch("openai.resources.Embeddings.create") as mock_create:
mock_create.return_value = {
"data": [{"embedding": [0.1, 0.2, 0.3]}]
}
_ = get_ada_embeddings(text, model=model)
mock_create.assert_called_with(
input=[text], model=expected_call_model
)
# Exception Test
def test_get_ada_embeddings_exception():
with patch("openai.resources.Embeddings.create") as mock_create:
mock_create.side_effect = openai.OpenAIError("Test error")
with pytest.raises(openai.OpenAIError):
get_ada_embeddings("Some text")
# Tests for environment variable loading
def test_env_var_loading(monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "testkey123")
with patch("openai.resources.Embeddings.create"):
assert (
getenv("OPENAI_API_KEY") == "testkey123"
), "Environment variable for API key is not set correctly"
# ... more tests to cover other aspects such as different input types, large inputs, invalid inputs, etc.
Loading…
Cancel
Save