Finish python migration

pull/388/head
Wyatt Stanke 11 months ago
parent 516620708c
commit a8b3adb50e
No known key found for this signature in database
GPG Key ID: CE6BA5FFF135536D

@ -1,6 +1,7 @@
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
import openai from openai import OpenAI
import requests import requests
from pydantic import BaseModel, validator from pydantic import BaseModel, validator
from tenacity import ( from tenacity import (
@ -147,6 +148,7 @@ class OpenAIFunctionCaller:
self.user = user self.user = user
self.messages = messages if messages is not None else [] self.messages = messages if messages is not None else []
self.timeout_sec = timeout_sec self.timeout_sec = timeout_sec
self.client = OpenAI(api_key=self.openai_api_key)
def add_message(self, role: str, content: str): def add_message(self, role: str, content: str):
self.messages.append({"role": role, "content": content}) self.messages.append({"role": role, "content": content})
@ -163,7 +165,7 @@ class OpenAIFunctionCaller:
): ):
headers = { headers = {
"Content-Type": "application/json", "Content-Type": "application/json",
"Authorization": "Bearer " + openai.api_key, "Authorization": "Bearer " + self.openai_api_key,
} }
json_data = {"model": self.model, "messages": messages} json_data = {"model": self.model, "messages": messages}
if tools is not None: if tools is not None:
@ -235,7 +237,7 @@ class OpenAIFunctionCaller:
) )
def call(self, task: str, *args, **kwargs) -> Dict: def call(self, task: str, *args, **kwargs) -> Dict:
return openai.Completion.create( return self.client.completions.create(
engine=self.model, engine=self.model,
prompt=task, prompt=task,
max_tokens=self.max_tokens, max_tokens=self.max_tokens,

@ -177,11 +177,11 @@ def _create_retry_decorator(
import openai import openai
errors = [ errors = [
openai.error.Timeout, openai.Timeout,
openai.error.APIError, openai.APIError,
openai.error.APIConnectionError, openai.APIConnectionError,
openai.error.RateLimitError, openai.RateLimitError,
openai.error.ServiceUnavailableError, openai.ServiceUnavailableError,
] ]
return create_base_retry_decorator( return create_base_retry_decorator(
error_types=errors, error_types=errors,
@ -239,9 +239,9 @@ class BaseOpenAI(BaseLLM):
attributes["openai_api_base"] = self.openai_api_base attributes["openai_api_base"] = self.openai_api_base
if self.openai_organization != "": if self.openai_organization != "":
attributes["openai_organization"] = ( attributes[
self.openai_organization "openai_organization"
) ] = self.openai_organization
if self.openai_proxy != "": if self.openai_proxy != "":
attributes["openai_proxy"] = self.openai_proxy attributes["openai_proxy"] = self.openai_proxy
@ -352,7 +352,13 @@ class BaseOpenAI(BaseLLM):
try: try:
import openai import openai
values["client"] = openai.Completion values["client"] = openai.OpenAI(
api_key=values["openai_api_key"],
api_base=values["openai_api_base"] or None,
organization=values["openai_organization"] or None,
# TODO: Reenable this when openai package supports proxy
# proxy=values["openai_proxy"] or None,
)
except ImportError: except ImportError:
raise ImportError( raise ImportError(
"Could not import openai python package. " "Could not import openai python package. "
@ -647,7 +653,8 @@ class BaseOpenAI(BaseLLM):
if self.openai_proxy: if self.openai_proxy:
import openai import openai
openai.proxy = {"http": self.openai_proxy, "https": self.openai_proxy} # type: ignore[assignment] # noqa: E501 # TODO: The 'openai.proxy' option isn't read in the client API. You will need to pass it when you instantiate the client, e.g. 'OpenAI(proxy={"http": self.openai_proxy, "https": self.openai_proxy})'
# openai.proxy = {"http": self.openai_proxy, "https": self.openai_proxy} # type: ignore[assignment] # noqa: E501
return {**openai_creds, **self._default_params} return {**openai_creds, **self._default_params}
@property @property
@ -956,21 +963,13 @@ class OpenAIChat(BaseLLM):
) )
try: try:
import openai import openai
openai.api_key = openai_api_key
if openai_api_base:
openai.api_base = openai_api_base
if openai_organization:
openai.organization = openai_organization
if openai_proxy:
openai.proxy = {"http": openai_proxy, "https": openai_proxy} # type: ignore[assignment] # noqa: E501
except ImportError: except ImportError:
raise ImportError( raise ImportError(
"Could not import openai python package. " "Could not import openai python package. "
"Please install it with `pip install openai`." "Please install it with `pip install openai`."
) )
try: try:
values["client"] = openai.ChatCompletion values["client"] = openai.OpenAI
except AttributeError: except AttributeError:
raise ValueError( raise ValueError(
"`openai` has no `ChatCompletion` attribute, this is" "`openai` has no `ChatCompletion` attribute, this is"

Loading…
Cancel
Save