no openai model class

Former-commit-id: a085a1e233
group-chat
Kye 1 year ago
parent 6a1bca74b7
commit 6a89e9165d

@ -13,6 +13,6 @@ from swarms.models.anthropic import Anthropic
from swarms.models.huggingface import HuggingFaceLLM from swarms.models.huggingface import HuggingFaceLLM
# from swarms.models.palm import GooglePalm # from swarms.models.palm import GooglePalm
from swarms.models.petals import Petals from swarms.models.petals import Petals
from swarms.models.openai import OpenAIChat #from swarms.models.openai import OpenAIChat

@ -2,4 +2,4 @@ from swarms.models.anthropic import Anthropic
from swarms.models.huggingface import HuggingFaceLLM from swarms.models.huggingface import HuggingFaceLLM
# from swarms.models.palm import GooglePalm # from swarms.models.palm import GooglePalm
from swarms.models.petals import Petals from swarms.models.petals import Petals
from swarms.models.openai import OpenAIChat #from swarms.models.openai import OpenAIChat

@ -1,318 +1,318 @@
from __future__ import annotations # from __future__ import annotations
import logging # import logging
import sys # import sys
import warnings # import warnings
from typing import ( # from typing import (
AbstractSet, # AbstractSet,
Any, # Any,
AsyncIterator, # AsyncIterator,
Collection, # Collection,
Dict, # Dict,
Iterator, # Iterator,
List, # List,
Literal, # Literal,
Mapping, # Mapping,
Optional, # Optional,
Tuple, # Tuple,
Union, # Union,
) # )
from langchain.callbacks.manager import ( # from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun, # AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun, # CallbackManagerForLLMRun,
) # )
from langchain.pydantic_v1 import Field, root_validator # from langchain.pydantic_v1 import Field, root_validator
from langchain.schema import Generation, LLMResult # from langchain.schema import Generation, LLMResult
from langchain.schema.output import GenerationChunk # from langchain.schema.output import GenerationChunk
from langchain.utils import get_from_dict_or_env # from langchain.utils import get_from_dict_or_env
logger = logging.getLogger(__name__) # logger = logging.getLogger(__name__)
import os # import os
def get_from_dict_or_env( # def get_from_dict_or_env(
data: Dict[str, Any], # data: Dict[str, Any],
key: str, # key: str,
env_key: str, # env_key: str,
default: Optional[str] = None # default: Optional[str] = None
) -> str: # ) -> str:
"""Get a value from a dictionary or an environment variable.""" # """Get a value from a dictionary or an environment variable."""
if key in data and data[key]: # if key in data and data[key]:
return data[key] # return data[key]
else: # else:
return get_from_env(key, env_key, default=default) # return get_from_env(key, env_key, default=default)
def get_from_env(key: str, env_key: str, default: Optional[str] = None) -> str: # def get_from_env(key: str, env_key: str, default: Optional[str] = None) -> str:
"""Get a value from a dictionary or an environment variable.""" # """Get a value from a dictionary or an environment variable."""
if env_key in os.environ and os.environ[env_key]: # if env_key in os.environ and os.environ[env_key]:
return os.environ[env_key] # return os.environ[env_key]
elif default is not None: # elif default is not None:
return default # return default
else: # else:
raise ValueError( # raise ValueError(
f"Did not find {key}, please add an environment variable" # f"Did not find {key}, please add an environment variable"
f" `{env_key}` which contains it, or pass" # f" `{env_key}` which contains it, or pass"
f" `{key}` as a named parameter." # f" `{key}` as a named parameter."
) # )
class OpenAIChat: # class OpenAIChat:
"""OpenAI Chat large language models. # """OpenAI Chat large language models.
To use, you should have the ``openai`` python package installed, and the # To use, you should have the ``openai`` python package installed, and the
environment variable ``OPENAI_API_KEY`` set with your API key. # environment variable ``OPENAI_API_KEY`` set with your API key.
Any parameters that are valid to be passed to the openai.create call can be passed # Any parameters that are valid to be passed to the openai.create call can be passed
in, even if not explicitly saved on this class. # in, even if not explicitly saved on this class.
Example: # Example:
.. code-block:: python # .. code-block:: python
from langchain.llms import OpenAIChat # from langchain.llms import OpenAIChat
openaichat = OpenAIChat(model_name="gpt-3.5-turbo") # openaichat = OpenAIChat(model_name="gpt-3.5-turbo")
""" # """
client: Any #: :meta private: # client: Any #: :meta private:
model_name: str = "gpt-3.5-turbo" # model_name: str = "gpt-3.5-turbo"
"""Model name to use.""" # """Model name to use."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict) # model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified.""" # """Holds any model parameters valid for `create` call not explicitly specified."""
openai_api_key: Optional[str] = None # openai_api_key: Optional[str] = None
openai_api_base: Optional[str] = None # openai_api_base: Optional[str] = None
# to support explicit proxy for OpenAI # # to support explicit proxy for OpenAI
openai_proxy: Optional[str] = None # openai_proxy: Optional[str] = None
max_retries: int = 6 # max_retries: int = 6
"""Maximum number of retries to make when generating.""" # """Maximum number of retries to make when generating."""
prefix_messages: List = Field(default_factory=list) # prefix_messages: List = Field(default_factory=list)
"""Series of messages for Chat input.""" # """Series of messages for Chat input."""
streaming: bool = False # streaming: bool = False
"""Whether to stream the results or not.""" # """Whether to stream the results or not."""
allowed_special: Union[Literal["all"], AbstractSet[str]] = set() # allowed_special: Union[Literal["all"], AbstractSet[str]] = set()
"""Set of special tokens that are allowed。""" # """Set of special tokens that are allowed。"""
disallowed_special: Union[Literal["all"], Collection[str]] = "all" # disallowed_special: Union[Literal["all"], Collection[str]] = "all"
"""Set of special tokens that are not allowed。""" # """Set of special tokens that are not allowed。"""
@root_validator(pre=True) # @root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: # def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in.""" # """Build extra kwargs from additional params that were passed in."""
all_required_field_names = {field.alias for field in cls.__fields__.values()} # all_required_field_names = {field.alias for field in cls.__fields__.values()}
extra = values.get("model_kwargs", {}) # extra = values.get("model_kwargs", {})
for field_name in list(values): # for field_name in list(values):
if field_name not in all_required_field_names: # if field_name not in all_required_field_names:
if field_name in extra: # if field_name in extra:
raise ValueError(f"Found {field_name} supplied twice.") # raise ValueError(f"Found {field_name} supplied twice.")
extra[field_name] = values.pop(field_name) # extra[field_name] = values.pop(field_name)
values["model_kwargs"] = extra # values["model_kwargs"] = extra
return values # return values
@root_validator() # @root_validator()
def validate_environment(cls, values: Dict) -> Dict: # def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment.""" # """Validate that api key and python package exists in environment."""
openai_api_key = get_from_dict_or_env( # openai_api_key = get_from_dict_or_env(
values, "openai_api_key", "OPENAI_API_KEY" # values, "openai_api_key", "OPENAI_API_KEY"
) # )
openai_api_base = get_from_dict_or_env( # openai_api_base = get_from_dict_or_env(
values, # values,
"openai_api_base", # "openai_api_base",
"OPENAI_API_BASE", # "OPENAI_API_BASE",
default="", # default="",
) # )
openai_proxy = get_from_dict_or_env( # openai_proxy = get_from_dict_or_env(
values, # values,
"openai_proxy", # "openai_proxy",
"OPENAI_PROXY", # "OPENAI_PROXY",
default="", # default="",
) # )
openai_organization = get_from_dict_or_env( # openai_organization = get_from_dict_or_env(
values, "openai_organization", "OPENAI_ORGANIZATION", default="" # values, "openai_organization", "OPENAI_ORGANIZATION", default=""
) # )
try: # try:
import openai # import openai
openai.api_key = openai_api_key # openai.api_key = openai_api_key
if openai_api_base: # if openai_api_base:
openai.api_base = openai_api_base # openai.api_base = openai_api_base
if openai_organization: # if openai_organization:
openai.organization = openai_organization # openai.organization = openai_organization
if openai_proxy: # if openai_proxy:
openai.proxy = {"http": openai_proxy, "https": openai_proxy} # type: ignore[assignment] # noqa: E501 # openai.proxy = {"http": openai_proxy, "https": openai_proxy} # type: ignore[assignment] # noqa: E501
except ImportError: # except ImportError:
raise ImportError( # raise ImportError(
"Could not import openai python package. " # "Could not import openai python package. "
"Please install it with `pip install openai`." # "Please install it with `pip install openai`."
) # )
try: # try:
values["client"] = openai.ChatCompletion # values["client"] = openai.ChatCompletion
except AttributeError: # except AttributeError:
raise ValueError( # raise ValueError(
"`openai` has no `ChatCompletion` attribute, this is likely " # "`openai` has no `ChatCompletion` attribute, this is likely "
"due to an old version of the openai package. Try upgrading it " # "due to an old version of the openai package. Try upgrading it "
"with `pip install --upgrade openai`." # "with `pip install --upgrade openai`."
) # )
warnings.warn( # warnings.warn(
"You are trying to use a chat model. This way of initializing it is " # "You are trying to use a chat model. This way of initializing it is "
"no longer supported. Instead, please use: " # "no longer supported. Instead, please use: "
"`from langchain.chat_models import ChatOpenAI`" # "`from langchain.chat_models import ChatOpenAI`"
) # )
return values # return values
@property # @property
def _default_params(self) -> Dict[str, Any]: # def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling OpenAI API.""" # """Get the default parameters for calling OpenAI API."""
return self.model_kwargs # return self.model_kwargs
def _get_chat_params( # def _get_chat_params(
self, prompts: List[str], stop: Optional[List[str]] = None # self, prompts: List[str], stop: Optional[List[str]] = None
) -> Tuple: # ) -> Tuple:
if len(prompts) > 1: # if len(prompts) > 1:
raise ValueError( # raise ValueError(
f"OpenAIChat currently only supports single prompt, got {prompts}" # f"OpenAIChat currently only supports single prompt, got {prompts}"
) # )
messages = self.prefix_messages + [{"role": "user", "content": prompts[0]}] # messages = self.prefix_messages + [{"role": "user", "content": prompts[0]}]
params: Dict[str, Any] = {**{"model": self.model_name}, **self._default_params} # params: Dict[str, Any] = {**{"model": self.model_name}, **self._default_params}
if stop is not None: # if stop is not None:
if "stop" in params: # if "stop" in params:
raise ValueError("`stop` found in both the input and default params.") # raise ValueError("`stop` found in both the input and default params.")
params["stop"] = stop # params["stop"] = stop
if params.get("max_tokens") == -1: # if params.get("max_tokens") == -1:
# for ChatGPT api, omitting max_tokens is equivalent to having no limit # # for ChatGPT api, omitting max_tokens is equivalent to having no limit
del params["max_tokens"] # del params["max_tokens"]
return messages, params # return messages, params
def _stream( # def _stream(
self, # self,
prompt: str, # prompt: str,
stop: Optional[List[str]] = None, # stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, # run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any, # **kwargs: Any,
) -> Iterator[GenerationChunk]: # ) -> Iterator[GenerationChunk]:
messages, params = self._get_chat_params([prompt], stop) # messages, params = self._get_chat_params([prompt], stop)
params = {**params, **kwargs, "stream": True} # params = {**params, **kwargs, "stream": True}
for stream_resp in completion_with_retry( # for stream_resp in completion_with_retry(
self, messages=messages, run_manager=run_manager, **params # self, messages=messages, run_manager=run_manager, **params
): # ):
token = stream_resp["choices"][0]["delta"].get("content", "") # token = stream_resp["choices"][0]["delta"].get("content", "")
chunk = GenerationChunk(text=token) # chunk = GenerationChunk(text=token)
yield chunk # yield chunk
if run_manager: # if run_manager:
run_manager.on_llm_new_token(token, chunk=chunk) # run_manager.on_llm_new_token(token, chunk=chunk)
async def _astream( # async def _astream(
self, # self,
prompt: str, # prompt: str,
stop: Optional[List[str]] = None, # stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, # run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any, # **kwargs: Any,
) -> AsyncIterator[GenerationChunk]: # ) -> AsyncIterator[GenerationChunk]:
messages, params = self._get_chat_params([prompt], stop) # messages, params = self._get_chat_params([prompt], stop)
params = {**params, **kwargs, "stream": True} # params = {**params, **kwargs, "stream": True}
async for stream_resp in await acompletion_with_retry( # async for stream_resp in await acompletion_with_retry(
self, messages=messages, run_manager=run_manager, **params # self, messages=messages, run_manager=run_manager, **params
): # ):
token = stream_resp["choices"][0]["delta"].get("content", "") # token = stream_resp["choices"][0]["delta"].get("content", "")
chunk = GenerationChunk(text=token) # chunk = GenerationChunk(text=token)
yield chunk # yield chunk
if run_manager: # if run_manager:
await run_manager.on_llm_new_token(token, chunk=chunk) # await run_manager.on_llm_new_token(token, chunk=chunk)
def _generate( # def _generate(
self, # self,
prompts: List[str], # prompts: List[str],
stop: Optional[List[str]] = None, # stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, # run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any, # **kwargs: Any,
) -> LLMResult: # ) -> LLMResult:
if self.streaming: # if self.streaming:
generation: Optional[GenerationChunk] = None # generation: Optional[GenerationChunk] = None
for chunk in self._stream(prompts[0], stop, run_manager, **kwargs): # for chunk in self._stream(prompts[0], stop, run_manager, **kwargs):
if generation is None: # if generation is None:
generation = chunk # generation = chunk
else: # else:
generation += chunk # generation += chunk
assert generation is not None # assert generation is not None
return LLMResult(generations=[[generation]]) # return LLMResult(generations=[[generation]])
messages, params = self._get_chat_params(prompts, stop) # messages, params = self._get_chat_params(prompts, stop)
params = {**params, **kwargs} # params = {**params, **kwargs}
full_response = completion_with_retry( # full_response = completion_with_retry(
self, messages=messages, run_manager=run_manager, **params # self, messages=messages, run_manager=run_manager, **params
) # )
llm_output = { # llm_output = {
"token_usage": full_response["usage"], # "token_usage": full_response["usage"],
"model_name": self.model_name, # "model_name": self.model_name,
} # }
return LLMResult( # return LLMResult(
generations=[ # generations=[
[Generation(text=full_response["choices"][0]["message"]["content"])] # [Generation(text=full_response["choices"][0]["message"]["content"])]
], # ],
llm_output=llm_output, # llm_output=llm_output,
) # )
async def _agenerate( # async def _agenerate(
self, # self,
prompts: List[str], # prompts: List[str],
stop: Optional[List[str]] = None, # stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, # run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any, # **kwargs: Any,
) -> LLMResult: # ) -> LLMResult:
if self.streaming: # if self.streaming:
generation: Optional[GenerationChunk] = None # generation: Optional[GenerationChunk] = None
async for chunk in self._astream(prompts[0], stop, run_manager, **kwargs): # async for chunk in self._astream(prompts[0], stop, run_manager, **kwargs):
if generation is None: # if generation is None:
generation = chunk # generation = chunk
else: # else:
generation += chunk # generation += chunk
assert generation is not None # assert generation is not None
return LLMResult(generations=[[generation]]) # return LLMResult(generations=[[generation]])
messages, params = self._get_chat_params(prompts, stop) # messages, params = self._get_chat_params(prompts, stop)
params = {**params, **kwargs} # params = {**params, **kwargs}
full_response = await acompletion_with_retry( # full_response = await acompletion_with_retry(
self, messages=messages, run_manager=run_manager, **params # self, messages=messages, run_manager=run_manager, **params
) # )
llm_output = { # llm_output = {
"token_usage": full_response["usage"], # "token_usage": full_response["usage"],
"model_name": self.model_name, # "model_name": self.model_name,
} # }
return LLMResult( # return LLMResult(
generations=[ # generations=[
[Generation(text=full_response["choices"][0]["message"]["content"])] # [Generation(text=full_response["choices"][0]["message"]["content"])]
], # ],
llm_output=llm_output, # llm_output=llm_output,
) # )
@property # @property
def _identifying_params(self) -> Mapping[str, Any]: # def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters.""" # """Get the identifying parameters."""
return {**{"model_name": self.model_name}, **self._default_params} # return {**{"model_name": self.model_name}, **self._default_params}
@property # @property
def _llm_type(self) -> str: # def _llm_type(self) -> str:
"""Return type of llm.""" # """Return type of llm."""
return "openai-chat" # return "openai-chat"
def get_token_ids(self, text: str) -> List[int]: # def get_token_ids(self, text: str) -> List[int]:
"""Get the token IDs using the tiktoken package.""" # """Get the token IDs using the tiktoken package."""
# tiktoken NOT supported for Python < 3.8 # # tiktoken NOT supported for Python < 3.8
if sys.version_info[1] < 8: # if sys.version_info[1] < 8:
return super().get_token_ids(text) # return super().get_token_ids(text)
try: # try:
import tiktoken # import tiktoken
except ImportError: # except ImportError:
raise ImportError( # raise ImportError(
"Could not import tiktoken python package. " # "Could not import tiktoken python package. "
"This is needed in order to calculate get_num_tokens. " # "This is needed in order to calculate get_num_tokens. "
"Please install it with `pip install tiktoken`." # "Please install it with `pip install tiktoken`."
) # )
enc = tiktoken.encoding_for_model(self.model_name) # enc = tiktoken.encoding_for_model(self.model_name)
return enc.encode( # return enc.encode(
text, # text,
allowed_special=self.allowed_special, # allowed_special=self.allowed_special,
disallowed_special=self.disallowed_special, # disallowed_special=self.disallowed_special,
) # )
Loading…
Cancel
Save