From 7e3e6ca2816961347ad533b838a8bff25367b138 Mon Sep 17 00:00:00 2001 From: Kye Date: Sat, 11 Nov 2023 18:12:09 -0500 Subject: [PATCH] migration for openai --- docs/swarms/models/fuyu.md | 4 ++-- swarms/models/openai_chat.py | 4 ++-- swarms/models/openai_models.py | 16 ++++++++-------- swarms/models/simple_ada.py | 12 ++++++------ tests/models/ada.py | 12 ++++++------ 5 files changed, 24 insertions(+), 24 deletions(-) diff --git a/docs/swarms/models/fuyu.md b/docs/swarms/models/fuyu.md index 021469e8..e54a4a22 100644 --- a/docs/swarms/models/fuyu.md +++ b/docs/swarms/models/fuyu.md @@ -37,7 +37,7 @@ To use Fuyu, follow these steps: 1. Initialize the Fuyu instance: ```python -from swarms.models import Fuyu +from swarms.models.fuyu import Fuyu fuyu = Fuyu() ``` @@ -54,7 +54,7 @@ output_text = fuyu(text, img_path) ### Example 2 - Text Generation ```python -from swarms.models import Fuyu +from swarms.models.fuyu import Fuyu fuyu = Fuyu() diff --git a/swarms/models/openai_chat.py b/swarms/models/openai_chat.py index 546f3509..aaf2eb19 100644 --- a/swarms/models/openai_chat.py +++ b/swarms/models/openai_chat.py @@ -101,7 +101,7 @@ def _create_retry_decorator( import openai errors = [ - openai.error.Timeout, + openai.Timeout, openai.error.APIError, openai.error.APIConnectionError, openai.error.RateLimitError, @@ -547,7 +547,7 @@ class OpenAIChat(BaseChatModel): if self.openai_proxy: import openai - openai.proxy = {"http": self.openai_proxy, "https": self.openai_proxy} # type: ignore[assignment] # noqa: E501 + raise Exception("The 'openai.proxy' option isn't read in the client API. You will need to pass it when you instantiate the client, e.g. 'OpenAI(proxy={"http": self.openai_proxy, "https": self.openai_proxy})'") # type: ignore[assignment] # noqa: E501 return {**self._default_params, **openai_creds} def _get_invocation_params( diff --git a/swarms/models/openai_models.py b/swarms/models/openai_models.py index 4b0cc91d..c1fdd2b1 100644 --- a/swarms/models/openai_models.py +++ b/swarms/models/openai_models.py @@ -88,7 +88,7 @@ def _create_retry_decorator( import openai errors = [ - openai.error.Timeout, + openai.Timeout, openai.error.APIError, openai.error.APIConnectionError, openai.error.RateLimitError, @@ -500,10 +500,10 @@ class BaseOpenAI(BaseLLM): if self.openai_proxy: import openai - openai.proxy = { + raise Exception("The 'openai.proxy' option isn't read in the client API. You will need to pass it when you instantiate the client, e.g. 'OpenAI(proxy={ "http": self.openai_proxy, "https": self.openai_proxy, - } # type: ignore[assignment] # noqa: E501 + })'") # type: ignore[assignment] # noqa: E501 return {**openai_creds, **self._default_params} @property @@ -782,16 +782,16 @@ class OpenAIChat(BaseLLM): try: import openai - openai.api_key = openai_api_key + if openai_api_base: - openai.api_base = openai_api_base + raise Exception("The 'openai.api_base' option isn't read in the client API. You will need to pass it when you instantiate the client, e.g. 'OpenAI(api_base=openai_api_base)'") if openai_organization: - openai.organization = openai_organization + raise Exception("The 'openai.organization' option isn't read in the client API. You will need to pass it when you instantiate the client, e.g. 'OpenAI(organization=openai_organization)'") if openai_proxy: - openai.proxy = { + raise Exception("The 'openai.proxy' option isn't read in the client API. You will need to pass it when you instantiate the client, e.g. 'OpenAI(proxy={ "http": openai_proxy, "https": openai_proxy, - } # type: ignore[assignment] # noqa: E501 + })'") # type: ignore[assignment] # noqa: E501 except ImportError: raise ImportError( "Could not import openai python package. " diff --git a/swarms/models/simple_ada.py b/swarms/models/simple_ada.py index 7eb923b4..7aa3e6bd 100644 --- a/swarms/models/simple_ada.py +++ b/swarms/models/simple_ada.py @@ -1,4 +1,6 @@ -import openai +from openai import OpenAI + +client = OpenAI(api_key=getenv("OPENAI_API_KEY")) from dotenv import load_dotenv from os import getenv @@ -14,13 +16,11 @@ def get_ada_embeddings(text: str, model: str = "text-embedding-ada-002"): >>> get_ada_embeddings("Hello World", model="text-embedding-ada-001") """ - openai.api_key = getenv("OPENAI_API_KEY") + text = text.replace("\n", " ") - return openai.Embedding.create( - input=[text], - model=model, - )["data"][ + return client.embeddings.create(input=[text], + model=model)["data"][ 0 ]["embedding"] diff --git a/tests/models/ada.py b/tests/models/ada.py index 08f1a687..e65e1470 100644 --- a/tests/models/ada.py +++ b/tests/models/ada.py @@ -24,7 +24,7 @@ def test_texts(): # Basic Test def test_get_ada_embeddings_basic(test_texts): - with patch("openai.Embedding.create") as mock_create: + with patch("openai.resources.Embeddings.create") as mock_create: # Mocking the OpenAI API call mock_create.return_value = {"data": [{"embedding": [0.1, 0.2, 0.3]}]} @@ -49,7 +49,7 @@ def test_get_ada_embeddings_basic(test_texts): ], ) def test_get_ada_embeddings_models(text, model, expected_call_model): - with patch("openai.Embedding.create") as mock_create: + with patch("openai.resources.Embeddings.create") as mock_create: mock_create.return_value = {"data": [{"embedding": [0.1, 0.2, 0.3]}]} _ = get_ada_embeddings(text, model=model) @@ -58,16 +58,16 @@ def test_get_ada_embeddings_models(text, model, expected_call_model): # Exception Test def test_get_ada_embeddings_exception(): - with patch("openai.Embedding.create") as mock_create: - mock_create.side_effect = openai.error.OpenAIError("Test error") - with pytest.raises(openai.error.OpenAIError): + with patch("openai.resources.Embeddings.create") as mock_create: + mock_create.side_effect = openai.OpenAIError("Test error") + with pytest.raises(openai.OpenAIError): get_ada_embeddings("Some text") # Tests for environment variable loading def test_env_var_loading(monkeypatch): monkeypatch.setenv("OPENAI_API_KEY", "testkey123") - with patch("openai.Embedding.create"): + with patch("openai.resources.Embeddings.create"): assert ( getenv("OPENAI_API_KEY") == "testkey123" ), "Environment variable for API key is not set correctly"