migration for openai

Former-commit-id: 7e3e6ca281
grit/923f7c6f-0958-480b-8748-ea6bbf1c2084
Kye 1 year ago
parent 7bf0f75bf1
commit dcd82c0cf6

@ -37,7 +37,7 @@ To use Fuyu, follow these steps:
1. Initialize the Fuyu instance: 1. Initialize the Fuyu instance:
```python ```python
from swarms.models import Fuyu from swarms.models.fuyu import Fuyu
fuyu = Fuyu() fuyu = Fuyu()
``` ```
@ -54,7 +54,7 @@ output_text = fuyu(text, img_path)
### Example 2 - Text Generation ### Example 2 - Text Generation
```python ```python
from swarms.models import Fuyu from swarms.models.fuyu import Fuyu
fuyu = Fuyu() fuyu = Fuyu()

@ -101,7 +101,7 @@ def _create_retry_decorator(
import openai import openai
errors = [ errors = [
openai.error.Timeout, openai.Timeout,
openai.error.APIError, openai.error.APIError,
openai.error.APIConnectionError, openai.error.APIConnectionError,
openai.error.RateLimitError, openai.error.RateLimitError,
@ -547,7 +547,7 @@ class OpenAIChat(BaseChatModel):
if self.openai_proxy: if self.openai_proxy:
import openai import openai
openai.proxy = {"http": self.openai_proxy, "https": self.openai_proxy} # type: ignore[assignment] # noqa: E501 raise Exception("The 'openai.proxy' option isn't read in the client API. You will need to pass it when you instantiate the client, e.g. 'OpenAI(proxy={"http": self.openai_proxy, "https": self.openai_proxy})'") # type: ignore[assignment] # noqa: E501
return {**self._default_params, **openai_creds} return {**self._default_params, **openai_creds}
def _get_invocation_params( def _get_invocation_params(

@ -88,7 +88,7 @@ def _create_retry_decorator(
import openai import openai
errors = [ errors = [
openai.error.Timeout, openai.Timeout,
openai.error.APIError, openai.error.APIError,
openai.error.APIConnectionError, openai.error.APIConnectionError,
openai.error.RateLimitError, openai.error.RateLimitError,
@ -500,10 +500,10 @@ class BaseOpenAI(BaseLLM):
if self.openai_proxy: if self.openai_proxy:
import openai import openai
openai.proxy = { raise Exception("The 'openai.proxy' option isn't read in the client API. You will need to pass it when you instantiate the client, e.g. 'OpenAI(proxy={
"http": self.openai_proxy, "http": self.openai_proxy,
"https": self.openai_proxy, "https": self.openai_proxy,
} # type: ignore[assignment] # noqa: E501 })'") # type: ignore[assignment] # noqa: E501
return {**openai_creds, **self._default_params} return {**openai_creds, **self._default_params}
@property @property
@ -782,16 +782,16 @@ class OpenAIChat(BaseLLM):
try: try:
import openai import openai
openai.api_key = openai_api_key
if openai_api_base: if openai_api_base:
openai.api_base = openai_api_base raise Exception("The 'openai.api_base' option isn't read in the client API. You will need to pass it when you instantiate the client, e.g. 'OpenAI(api_base=openai_api_base)'")
if openai_organization: if openai_organization:
openai.organization = openai_organization raise Exception("The 'openai.organization' option isn't read in the client API. You will need to pass it when you instantiate the client, e.g. 'OpenAI(organization=openai_organization)'")
if openai_proxy: if openai_proxy:
openai.proxy = { raise Exception("The 'openai.proxy' option isn't read in the client API. You will need to pass it when you instantiate the client, e.g. 'OpenAI(proxy={
"http": openai_proxy, "http": openai_proxy,
"https": openai_proxy, "https": openai_proxy,
} # type: ignore[assignment] # noqa: E501 })'") # type: ignore[assignment] # noqa: E501
except ImportError: except ImportError:
raise ImportError( raise ImportError(
"Could not import openai python package. " "Could not import openai python package. "

@ -1,4 +1,6 @@
import openai from openai import OpenAI
client = OpenAI(api_key=getenv("OPENAI_API_KEY"))
from dotenv import load_dotenv from dotenv import load_dotenv
from os import getenv from os import getenv
@ -14,13 +16,11 @@ def get_ada_embeddings(text: str, model: str = "text-embedding-ada-002"):
>>> get_ada_embeddings("Hello World", model="text-embedding-ada-001") >>> get_ada_embeddings("Hello World", model="text-embedding-ada-001")
""" """
openai.api_key = getenv("OPENAI_API_KEY")
text = text.replace("\n", " ") text = text.replace("\n", " ")
return openai.Embedding.create( return client.embeddings.create(input=[text],
input=[text], model=model)["data"][
model=model,
)["data"][
0 0
]["embedding"] ]["embedding"]

@ -24,7 +24,7 @@ def test_texts():
# Basic Test # Basic Test
def test_get_ada_embeddings_basic(test_texts): def test_get_ada_embeddings_basic(test_texts):
with patch("openai.Embedding.create") as mock_create: with patch("openai.resources.Embeddings.create") as mock_create:
# Mocking the OpenAI API call # Mocking the OpenAI API call
mock_create.return_value = {"data": [{"embedding": [0.1, 0.2, 0.3]}]} mock_create.return_value = {"data": [{"embedding": [0.1, 0.2, 0.3]}]}
@ -49,7 +49,7 @@ def test_get_ada_embeddings_basic(test_texts):
], ],
) )
def test_get_ada_embeddings_models(text, model, expected_call_model): def test_get_ada_embeddings_models(text, model, expected_call_model):
with patch("openai.Embedding.create") as mock_create: with patch("openai.resources.Embeddings.create") as mock_create:
mock_create.return_value = {"data": [{"embedding": [0.1, 0.2, 0.3]}]} mock_create.return_value = {"data": [{"embedding": [0.1, 0.2, 0.3]}]}
_ = get_ada_embeddings(text, model=model) _ = get_ada_embeddings(text, model=model)
@ -58,16 +58,16 @@ def test_get_ada_embeddings_models(text, model, expected_call_model):
# Exception Test # Exception Test
def test_get_ada_embeddings_exception(): def test_get_ada_embeddings_exception():
with patch("openai.Embedding.create") as mock_create: with patch("openai.resources.Embeddings.create") as mock_create:
mock_create.side_effect = openai.error.OpenAIError("Test error") mock_create.side_effect = openai.OpenAIError("Test error")
with pytest.raises(openai.error.OpenAIError): with pytest.raises(openai.OpenAIError):
get_ada_embeddings("Some text") get_ada_embeddings("Some text")
# Tests for environment variable loading # Tests for environment variable loading
def test_env_var_loading(monkeypatch): def test_env_var_loading(monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "testkey123") monkeypatch.setenv("OPENAI_API_KEY", "testkey123")
with patch("openai.Embedding.create"): with patch("openai.resources.Embeddings.create"):
assert ( assert (
getenv("OPENAI_API_KEY") == "testkey123" getenv("OPENAI_API_KEY") == "testkey123"
), "Environment variable for API key is not set correctly" ), "Environment variable for API key is not set correctly"

Loading…
Cancel
Save