# test_embeddings.py

import pytest
import openai
from unittest.mock import patch
from swarms.models.simple_ada import (
    get_ada_embeddings,
)  # Adjust this import path to your project structure
from os import getenv
from dotenv import load_dotenv

# Runs at import time, i.e. before any test in this module executes.
# python-dotenv's load_dotenv() reads key/value pairs from a .env file (when
# one is found) into the process environment.
# NOTE(review): presumably this supplies OPENAI_API_KEY for the module under
# test -- confirm against swarms.models.simple_ada.
load_dotenv()

# Fixture for test texts
@pytest.fixture
def test_texts():
    """Provide the sample input strings used by the embedding tests.

    Returns:
        A list of three strings; the second one contains an embedded
        newline character.
    """
    plain_greeting = "Hello World"
    multiline_sample = "This is a test string with newline\ncharacters"
    pangram = "A quick brown fox jumps over the lazy dog"
    return [plain_greeting, multiline_sample, pangram]


# Basic Test
def test_get_ada_embeddings_basic(test_texts):
    """Check the returned embedding and the API call for each sample text.

    ``openai.Embedding.create`` is replaced by a mock that returns a canned
    response holding a fixed three-element embedding.
    """
    expected_vector = [0.1, 0.2, 0.3]
    canned_response = {"data": [{"embedding": [0.1, 0.2, 0.3]}]}

    with patch(
        "openai.Embedding.create", return_value=canned_response
    ) as mock_create:
        for sample in test_texts:
            result = get_ada_embeddings(sample)
            assert (
                result == expected_vector
            ), "Embedding does not match expected output"

            # The expected API input has every newline replaced by a space.
            flattened = sample.replace("\n", " ")
            mock_create.assert_called_with(
                input=[flattened], model="text-embedding-ada-002"
            )


# Parameterized Test
@pytest.mark.parametrize(
    "text, model, expected_call_model",
    [
        # Each case requests a model and expects that same model in the call.
        ("Hello World", model_name, model_name)
        for model_name in ("text-embedding-ada-002", "text-embedding-ada-001")
    ],
)
def test_get_ada_embeddings_models(text, model, expected_call_model):
    """Verify that the ``model`` argument is passed on to the embedding call.

    Args:
        text: Input string handed to ``get_ada_embeddings``.
        model: Model name passed as the ``model`` keyword argument.
        expected_call_model: Model name the patched API call must receive.
    """
    fake_response = {"data": [{"embedding": [0.1, 0.2, 0.3]}]}

    with patch(
        "openai.Embedding.create", return_value=fake_response
    ) as mock_create:
        # Only the recorded call matters here; the return value is discarded.
        get_ada_embeddings(text, model=model)
        mock_create.assert_called_with(input=[text], model=expected_call_model)


# Exception Test
def test_get_ada_embeddings_exception():
    """Ensure an ``OpenAIError`` raised by the API call reaches the caller."""
    api_failure = openai.error.OpenAIError("Test error")

    # The patched API call raises ``api_failure`` whenever it is invoked.
    with patch("openai.Embedding.create", side_effect=api_failure), pytest.raises(
        openai.error.OpenAIError
    ):
        get_ada_embeddings("Some text")


# Tests for environment variable loading
def test_env_var_loading(monkeypatch):
    """Set ``OPENAI_API_KEY`` via ``monkeypatch`` and read it back.

    NOTE(review): this test never calls ``get_ada_embeddings``; it only checks
    that the value set through ``monkeypatch`` is visible through ``getenv``.
    """
    fake_key = "testkey123"
    monkeypatch.setenv("OPENAI_API_KEY", fake_key)

    with patch("openai.Embedding.create"):
        observed_key = getenv("OPENAI_API_KEY")
        assert (
            observed_key == fake_key
        ), "Environment variable for API key is not set correctly"


# TODO: add more tests covering other aspects, such as different input types, large inputs, and invalid inputs.