migration for openai

Former-commit-id: 7e3e6ca281
grit/923f7c6f-0958-480b-8748-ea6bbf1c2084
Kye 1 year ago
parent 7bf0f75bf1
commit dcd82c0cf6

@ -37,7 +37,7 @@ To use Fuyu, follow these steps:
1. Initialize the Fuyu instance:
```python
from swarms.models import Fuyu
from swarms.models.fuyu import Fuyu
fuyu = Fuyu()
```
@ -54,7 +54,7 @@ output_text = fuyu(text, img_path)
### Example 2 - Text Generation
```python
from swarms.models import Fuyu
from swarms.models.fuyu import Fuyu
fuyu = Fuyu()

@ -101,7 +101,7 @@ def _create_retry_decorator(
import openai
errors = [
openai.error.Timeout,
openai.Timeout,
openai.error.APIError,
openai.error.APIConnectionError,
openai.error.RateLimitError,
@ -547,7 +547,7 @@ class OpenAIChat(BaseChatModel):
if self.openai_proxy:
import openai
openai.proxy = {"http": self.openai_proxy, "https": self.openai_proxy} # type: ignore[assignment] # noqa: E501
raise Exception("The 'openai.proxy' option isn't read in the client API. You will need to pass it when you instantiate the client, e.g. 'OpenAI(proxy={"http": self.openai_proxy, "https": self.openai_proxy})'") # type: ignore[assignment] # noqa: E501
return {**self._default_params, **openai_creds}
def _get_invocation_params(

@ -88,7 +88,7 @@ def _create_retry_decorator(
import openai
errors = [
openai.error.Timeout,
openai.Timeout,
openai.error.APIError,
openai.error.APIConnectionError,
openai.error.RateLimitError,
@ -500,10 +500,10 @@ class BaseOpenAI(BaseLLM):
if self.openai_proxy:
import openai
openai.proxy = {
raise Exception("The 'openai.proxy' option isn't read in the client API. You will need to pass it when you instantiate the client, e.g. 'OpenAI(proxy={
"http": self.openai_proxy,
"https": self.openai_proxy,
} # type: ignore[assignment] # noqa: E501
})'") # type: ignore[assignment] # noqa: E501
return {**openai_creds, **self._default_params}
@property
@ -782,16 +782,16 @@ class OpenAIChat(BaseLLM):
try:
import openai
openai.api_key = openai_api_key
if openai_api_base:
openai.api_base = openai_api_base
raise Exception("The 'openai.api_base' option isn't read in the client API. You will need to pass it when you instantiate the client, e.g. 'OpenAI(api_base=openai_api_base)'")
if openai_organization:
openai.organization = openai_organization
raise Exception("The 'openai.organization' option isn't read in the client API. You will need to pass it when you instantiate the client, e.g. 'OpenAI(organization=openai_organization)'")
if openai_proxy:
openai.proxy = {
raise Exception("The 'openai.proxy' option isn't read in the client API. You will need to pass it when you instantiate the client, e.g. 'OpenAI(proxy={
"http": openai_proxy,
"https": openai_proxy,
} # type: ignore[assignment] # noqa: E501
})'") # type: ignore[assignment] # noqa: E501
except ImportError:
raise ImportError(
"Could not import openai python package. "

@ -1,4 +1,6 @@
import openai
from openai import OpenAI
client = OpenAI(api_key=getenv("OPENAI_API_KEY"))
from dotenv import load_dotenv
from os import getenv
@ -14,13 +16,11 @@ def get_ada_embeddings(text: str, model: str = "text-embedding-ada-002"):
>>> get_ada_embeddings("Hello World", model="text-embedding-ada-001")
"""
openai.api_key = getenv("OPENAI_API_KEY")
text = text.replace("\n", " ")
return openai.Embedding.create(
input=[text],
model=model,
)["data"][
return client.embeddings.create(input=[text],
model=model)["data"][
0
]["embedding"]

@ -24,7 +24,7 @@ def test_texts():
# Basic Test
def test_get_ada_embeddings_basic(test_texts):
with patch("openai.Embedding.create") as mock_create:
with patch("openai.resources.Embeddings.create") as mock_create:
# Mocking the OpenAI API call
mock_create.return_value = {"data": [{"embedding": [0.1, 0.2, 0.3]}]}
@ -49,7 +49,7 @@ def test_get_ada_embeddings_basic(test_texts):
],
)
def test_get_ada_embeddings_models(text, model, expected_call_model):
with patch("openai.Embedding.create") as mock_create:
with patch("openai.resources.Embeddings.create") as mock_create:
mock_create.return_value = {"data": [{"embedding": [0.1, 0.2, 0.3]}]}
_ = get_ada_embeddings(text, model=model)
@ -58,16 +58,16 @@ def test_get_ada_embeddings_models(text, model, expected_call_model):
# Exception Test
def test_get_ada_embeddings_exception():
with patch("openai.Embedding.create") as mock_create:
mock_create.side_effect = openai.error.OpenAIError("Test error")
with pytest.raises(openai.error.OpenAIError):
with patch("openai.resources.Embeddings.create") as mock_create:
mock_create.side_effect = openai.OpenAIError("Test error")
with pytest.raises(openai.OpenAIError):
get_ada_embeddings("Some text")
# Tests for environment variable loading
def test_env_var_loading(monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "testkey123")
with patch("openai.Embedding.create"):
with patch("openai.resources.Embeddings.create"):
assert (
getenv("OPENAI_API_KEY") == "testkey123"
), "Environment variable for API key is not set correctly"

Loading…
Cancel
Save