From 4abc8232153d9998c05260563defce3a0ba30269 Mon Sep 17 00:00:00 2001
From: Kye
Date: Wed, 30 Aug 2023 11:27:00 -0400
Subject: [PATCH] scaffold of openai class

---
 swarms/models/openai.py | 379 +++++++++++++++++++++++++++++++---------
 1 file changed, 292 insertions(+), 87 deletions(-)

diff --git a/swarms/models/openai.py b/swarms/models/openai.py
index 89a3b72e..627a2eb4 100644
--- a/swarms/models/openai.py
+++ b/swarms/models/openai.py
@@ -1,99 +1,304 @@
 #kye
 #aug 8, 11:51
+import os
+import sys
+import warnings
+from typing import (
+    AbstractSet,
+    Any,
+    AsyncIterator,
+    Collection,
+    Dict,
+    Iterator,
+    List,
+    Literal,
+    Mapping,
+    Optional,
+    Tuple,
+    Union,
+)
+
+# Field and root_validator are pydantic exports, not typing ones
+from pydantic import Field, root_validator
+
+# base class, callback managers, result types, and retry helpers this scaffold
+# builds on; import paths assume a 2023-era langchain release
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
+from langchain.llms.base import BaseLLM
+from langchain.llms.openai import acompletion_with_retry, completion_with_retry
+from langchain.schema import Generation, LLMResult
+from langchain.schema.output import GenerationChunk
 
-from simpleaichat import AIChat, AsyncAIChat
-import asyncio
+
+
+def get_from_dict_or_env(
+    data: Dict[str, Any], key: str, env_key: str, default: Optional[str] = None
+) -> str:
+    """Get a value from a dictionary or an environment variable."""
+    if key in data and data[key]:
+        return data[key]
+    else:
+        return get_from_env(key, env_key, default=default)
 
-class OpenAI:
-    def __init__(self,
-                 api_key=None,
-                 system=None,
-                 console=True,
-                 model=None,
-                 params=None,
-                 save_messages=True):
-        self.api_key = api_key or self.fetch_api_key()
-        self.system = system or "You are a helpful assistant"
-        try:
-
-            self.ai = AIChat(api_key=self.api_key,
-                             system=self.system,
-                             console=self.console,
-                             model=self.model,
-                             params=self.params,
-                             save_messages=self.save_messages)
-
-            self.async_ai = AsyncAIChat(
-                api_key=self.api_key,
-                system=self.system,
-                console=self.console,
-                model=self.model,
-                params=self.params,
-                save_messages=self.save_messages
-            )
-
-        except Exception as error:
-            raise ValueError(f"Failed to initialize the chat with error: {error}, check inputs and input types")
+
+
+def get_from_env(key: str, env_key: str, default: Optional[str] = None) -> str:
+    """Get a value from an environment variable."""
+    if env_key in os.environ and os.environ[env_key]:
+        return os.environ[env_key]
+    elif default is not None:
+        return default
+    else:
+        raise ValueError(
+            f"Did not find {key}, please add an environment variable"
+            f" `{env_key}` which contains it, or pass"
+            f" `{key}` as a named parameter."
+        )
 
-    def __call__(self, message, **kwargs):
-        try:
-            return self.ai(message, **kwargs)
-        except Exception as error:
-            print(f"Error in OpenAI, {error}")
-
-    def generate(self, message, **kwargs):
-        try:
-            return self.ai(message, **kwargs)
-        except Exception as error:
-            print(f"Error in OpenAI, {error}")
-
-    async def generate_async(self, message, **kwargs):
-        try:
-            return await self.async_ai(message, **kwargs)
-        except Exception as error:
-            raise Exception(f"Error in asynchronous OpenAI Call, {error}")
-
-    def initialize_chat(self, ids):
-        for id in ids:
-            try:
-                self.async_ai.new_session(api_key=self.api_key, id=id)
-            except Exception as error:
-                raise ValueError(f"Failed to initialize session for ID {id} with error: {error}")
-
-    async def ask_multiple(self, ids, question_template):
+
+
+
+class OpenAIChat(BaseLLM):
+    """OpenAI Chat large language models.
+
+    To use, you should have the ``openai`` python package installed, and the
+    environment variable ``OPENAI_API_KEY`` set with your API key.
+
+    Any parameters that are valid to be passed to the openai.create call can be passed
+    in, even if not explicitly saved on this class.
+
+    Example:
+        .. code-block:: python
+
+            from swarms.models.openai import OpenAIChat
+            openaichat = OpenAIChat(model_name="gpt-3.5-turbo")
+    """
+
+    client: Any  #: :meta private:
+    model_name: str = "gpt-3.5-turbo"
+    """Model name to use."""
+    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+    """Holds any model parameters valid for `create` call not explicitly specified."""
+    openai_api_key: Optional[str] = None
+    openai_api_base: Optional[str] = None
+    # to support explicit proxy for OpenAI
+    openai_proxy: Optional[str] = None
+    max_retries: int = 6
+    """Maximum number of retries to make when generating."""
+    prefix_messages: List = Field(default_factory=list)
+    """Series of messages for Chat input."""
+    streaming: bool = False
+    """Whether to stream the results or not."""
+    allowed_special: Union[Literal["all"], AbstractSet[str]] = set()
+    """Set of special tokens that are allowed."""
+    disallowed_special: Union[Literal["all"], Collection[str]] = "all"
+    """Set of special tokens that are not allowed."""
+
+    @root_validator(pre=True)
+    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        """Build extra kwargs from additional params that were passed in."""
+        all_required_field_names = {field.alias for field in cls.__fields__.values()}
+
+        extra = values.get("model_kwargs", {})
+        for field_name in list(values):
+            if field_name not in all_required_field_names:
+                if field_name in extra:
+                    raise ValueError(f"Found {field_name} supplied twice.")
+                extra[field_name] = values.pop(field_name)
+        values["model_kwargs"] = extra
+        return values
+
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that api key and python package exists in environment."""
+        openai_api_key = get_from_dict_or_env(
+            values, "openai_api_key", "OPENAI_API_KEY"
+        )
+        openai_api_base = get_from_dict_or_env(
+            values,
+            "openai_api_base",
+            "OPENAI_API_BASE",
+            default="",
+        )
+        openai_proxy = get_from_dict_or_env(
+            values,
+            "openai_proxy",
+            "OPENAI_PROXY",
+            default="",
+        )
+        openai_organization = get_from_dict_or_env(
+            values, "openai_organization", "OPENAI_ORGANIZATION", default=""
+        )
         try:
-            self.initialize_chat(ids)
-            tasks = [self.async_ai(question_template.format(id=id), id=id) for id in ids]
-            return await asyncio.gather(*tasks)
-        except Exception as error:
-            raise Exception(f"Error in ask_multiple: method: {error}")
-
-    async def stream_multiple(self, ids, question_template):
+            import openai
+
+            openai.api_key = openai_api_key
+            if openai_api_base:
+                openai.api_base = openai_api_base
+            if openai_organization:
+                openai.organization = openai_organization
+            if openai_proxy:
+                openai.proxy = {"http": openai_proxy, "https": openai_proxy}  # type: ignore[assignment]  # noqa: E501
+        except ImportError:
+            raise ImportError(
+                "Could not import openai python package. "
+                "Please install it with `pip install openai`."
+            )
         try:
-            self.initialize_chat(ids)
-
-            async def stream_id(id):
-                async for chunk in await self.async_ai.stream(question_template.format(id=id), id=id):
-                    response = chunk["response"]
-                    return response
-
-            tasks = [stream_id(id) for id in ids]
-            return await asyncio.gather(*tasks)
-        except Exception as error:
-            raise Exception(f"Error in stream_multiple method: {error}")
-
-    def fetch_api_key(self):
-        pass
+            values["client"] = openai.ChatCompletion
+        except AttributeError:
+            raise ValueError(
+                "`openai` has no `ChatCompletion` attribute, this is likely "
+                "due to an old version of the openai package. Try upgrading it "
+                "with `pip install --upgrade openai`."
+            )
+        warnings.warn(
+            "You are trying to use a chat model. This way of initializing it is "
+            "no longer supported. Instead, please use: "
+            "`from langchain.chat_models import ChatOpenAI`"
+        )
+        return values
+
+    @property
+    def _default_params(self) -> Dict[str, Any]:
+        """Get the default parameters for calling OpenAI API."""
+        return self.model_kwargs
+
+    def _get_chat_params(
+        self, prompts: List[str], stop: Optional[List[str]] = None
+    ) -> Tuple:
+        if len(prompts) > 1:
+            raise ValueError(
+                f"OpenAIChat currently only supports single prompt, got {prompts}"
+            )
+        messages = self.prefix_messages + [{"role": "user", "content": prompts[0]}]
+        params: Dict[str, Any] = {**{"model": self.model_name}, **self._default_params}
+        if stop is not None:
+            if "stop" in params:
+                raise ValueError("`stop` found in both the input and default params.")
+            params["stop"] = stop
+        if params.get("max_tokens") == -1:
+            # for ChatGPT api, omitting max_tokens is equivalent to having no limit
+            del params["max_tokens"]
+        return messages, params
+
+    def _stream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[GenerationChunk]:
+        messages, params = self._get_chat_params([prompt], stop)
+        params = {**params, **kwargs, "stream": True}
+        for stream_resp in completion_with_retry(
+            self, messages=messages, run_manager=run_manager, **params
+        ):
+            token = stream_resp["choices"][0]["delta"].get("content", "")
+            chunk = GenerationChunk(text=token)
+            yield chunk
+            if run_manager:
+                run_manager.on_llm_new_token(token, chunk=chunk)
 
-#usage
-#from swarms import OpenAI()
-#chat = OpenAI()
-#response = chat.generate("Hello World")
-#print(response)
+
+    async def _astream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[GenerationChunk]:
+        messages, params = self._get_chat_params([prompt], stop)
+        params = {**params, **kwargs, "stream": True}
+        async for stream_resp in await acompletion_with_retry(
+            self, messages=messages, run_manager=run_manager, **params
+        ):
+            token = stream_resp["choices"][0]["delta"].get("content", "")
+            chunk = GenerationChunk(text=token)
+            yield chunk
+            if run_manager:
+                await run_manager.on_llm_new_token(token, chunk=chunk)
+
+    def _generate(
+        self,
+        prompts: List[str],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> LLMResult:
+        if self.streaming:
+            generation: Optional[GenerationChunk] = None
+            for chunk in self._stream(prompts[0], stop, run_manager, **kwargs):
+                if generation is None:
+                    generation = chunk
+                else:
+                    generation += chunk
+            assert generation is not None
+            return LLMResult(generations=[[generation]])
+
+        messages, params = self._get_chat_params(prompts, stop)
+        params = {**params, **kwargs}
+        full_response = completion_with_retry(
+            self, messages=messages, run_manager=run_manager, **params
+        )
+        llm_output = {
+            "token_usage": full_response["usage"],
+            "model_name": self.model_name,
+        }
+        return LLMResult(
+            generations=[
+                [Generation(text=full_response["choices"][0]["message"]["content"])]
+            ],
+            llm_output=llm_output,
+        )
+
+    async def _agenerate(
+        self,
+        prompts: List[str],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> LLMResult:
+        if self.streaming:
+            generation: Optional[GenerationChunk] = None
+            async for chunk in self._astream(prompts[0], stop, run_manager, **kwargs):
+                if generation is None:
+                    generation = chunk
+                else:
+                    generation += chunk
+            assert generation is not None
+            return LLMResult(generations=[[generation]])
+
+        messages, params = self._get_chat_params(prompts, stop)
+        params = {**params, **kwargs}
+        full_response = await acompletion_with_retry(
+            self, messages=messages, run_manager=run_manager, **params
+        )
+        llm_output = {
+            "token_usage": full_response["usage"],
+            "model_name": self.model_name,
+        }
+        return LLMResult(
+            generations=[
+                [Generation(text=full_response["choices"][0]["message"]["content"])]
+            ],
+            llm_output=llm_output,
+        )
+
+    @property
+    def _identifying_params(self) -> Mapping[str, Any]:
+        """Get the identifying parameters."""
+        return {**{"model_name": self.model_name}, **self._default_params}
+
+    @property
+    def _llm_type(self) -> str:
+        """Return type of llm."""
+        return "openai-chat"
+
+    def get_token_ids(self, text: str) -> List[int]:
+        """Get the token IDs using the tiktoken package."""
+        # tiktoken NOT supported for Python < 3.8
+        if sys.version_info < (3, 8):
+            return super().get_token_ids(text)
+        try:
+            import tiktoken
+        except ImportError:
+            raise ImportError(
+                "Could not import tiktoken python package. "
+                "This is needed in order to calculate get_num_tokens. "
+                "Please install it with `pip install tiktoken`."
+            )
-#async
-# async_responses = asyncio.run(chat.ask_multiple(['id1', 'id2'], "How is {id}"))
-# print(async_responses)
+
+        enc = tiktoken.encoding_for_model(self.model_name)
+        return enc.encode(
+            text,
+            allowed_special=self.allowed_special,
+            disallowed_special=self.disallowed_special,
+        )
\ No newline at end of file
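
For reference, a minimal usage sketch of the scaffolded class. It assumes the
langchain base classes imported above are installed and an API key is
available; ``llm(...)`` and ``generate(...)`` are the entry points BaseLLM
wraps around ``_generate``:

.. code-block:: python

    import os

    from swarms.models.openai import OpenAIChat

    os.environ["OPENAI_API_KEY"] = "sk-..."  # placeholder key

    # extra kwargs such as temperature are swept into model_kwargs by build_extra
    llm = OpenAIChat(model_name="gpt-3.5-turbo", temperature=0.7)

    # BaseLLM.__call__ routes a single prompt through _generate
    print(llm("Hello world"))

    # generate() returns an LLMResult; usage stats land in llm_output
    result = llm.generate(["Hello world"])
    print(result.llm_output["token_usage"])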
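
The configuration helpers resolve values with dict entries taking precedence
over environment variables, then defaults. A small sketch of the lookup order
(all values are illustrative):

.. code-block:: python

    import os

    from swarms.models.openai import get_from_dict_or_env

    os.environ["OPENAI_API_BASE"] = "https://example.invalid/v1"

    # a truthy dict entry wins outright
    get_from_dict_or_env(
        {"openai_api_base": "http://localhost:8000"},
        "openai_api_base",
        "OPENAI_API_BASE",
    )  # -> "http://localhost:8000"

    # a missing (or falsy) entry falls through to the environment variable
    get_from_dict_or_env({}, "openai_api_base", "OPENAI_API_BASE")
    # -> "https://example.invalid/v1"

    # with neither set, the default is returned; without a default, ValueError
    get_from_dict_or_env({}, "openai_proxy", "OPENAI_PROXY", default="")  # -> ""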
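
``get_token_ids`` defers to tiktoken's model-keyed encoding. The equivalent
standalone computation, handy for sanity-checking prompt sizes, looks like
this (the keyword arguments mirror the class attribute defaults):

.. code-block:: python

    import tiktoken

    enc = tiktoken.encoding_for_model("gpt-3.5-turbo")
    token_ids = enc.encode(
        "Hello world",
        allowed_special=set(),     # allowed_special default
        disallowed_special="all",  # disallowed_special default
    )
    print(len(token_ids))  # tokens this prompt will consume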