parent
c02105575f
commit
66fedbd7ae
@ -1,77 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
import os
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from langchain import PromptTemplate, LLMChain, HuggingFaceHub
|
|
||||||
from langchain.chat_models import ChatOpenAI
|
|
||||||
|
|
||||||
|
|
||||||
# Configure logging-
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class LLM:
|
|
||||||
def __init__(self,
|
|
||||||
openai_api_key: Optional[str] = None,
|
|
||||||
hf_repo_id: Optional[str] = None,
|
|
||||||
hf_api_token: Optional[str] = None,
|
|
||||||
temperature: Optional[float] = 0.5,
|
|
||||||
max_length: Optional[int] = 64):
|
|
||||||
|
|
||||||
# Check if keys are in the environment variables
|
|
||||||
openai_api_key = openai_api_key or os.getenv('OPENAI_API_KEY')
|
|
||||||
hf_api_token = hf_api_token or os.getenv('HUGGINGFACEHUB_API_TOKEN')
|
|
||||||
|
|
||||||
self.openai_api_key = openai_api_key
|
|
||||||
self.hf_repo_id = hf_repo_id
|
|
||||||
self.hf_api_token = hf_api_token
|
|
||||||
self.temperature = temperature
|
|
||||||
self.max_length = max_length
|
|
||||||
|
|
||||||
# If the HuggingFace API token is provided, set it in environment variables
|
|
||||||
if self.hf_api_token:
|
|
||||||
os.environ["HUGGINGFACEHUB_API_TOKEN"] = self.hf_api_token
|
|
||||||
|
|
||||||
# Initialize the LLM object
|
|
||||||
self.initialize_llm()
|
|
||||||
|
|
||||||
def initialize_llm(self):
|
|
||||||
model_kwargs = {"temperature": self.temperature, "max_length": self.max_length}
|
|
||||||
try:
|
|
||||||
if self.hf_repo_id and self.hf_api_token:
|
|
||||||
self.llm = HuggingFaceHub(repo_id=self.hf_repo_id, model_kwargs=model_kwargs)
|
|
||||||
elif self.openai_api_key:
|
|
||||||
self.llm = ChatOpenAI(api_key=self.openai_api_key, model_kwargs=model_kwargs)
|
|
||||||
else:
|
|
||||||
raise ValueError("Please provide either OpenAI API key or both HuggingFace repository ID and API token.")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("Failed to initialize LLM: %s", e)
|
|
||||||
raise
|
|
||||||
|
|
||||||
def run(self, prompt: str) -> str:
|
|
||||||
template = """Question: {question}
|
|
||||||
Answer: Let's think step by step."""
|
|
||||||
try:
|
|
||||||
prompt_template = PromptTemplate(template=template, input_variables=["question"])
|
|
||||||
llm_chain = LLMChain(prompt=prompt_template, llm=self.llm)
|
|
||||||
return llm_chain.run({"question": prompt})
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("Failed to generate response: %s", e)
|
|
||||||
raise
|
|
||||||
|
|
||||||
# # example
|
|
||||||
# from swarms.utils.llm import LLM
|
|
||||||
# llm_instance = LLM(openai_api_key="your_openai_key")
|
|
||||||
# result = llm_instance.run("Who won the FIFA World Cup in 1998?")
|
|
||||||
# print(result)
|
|
||||||
|
|
||||||
# # using HuggingFaceHub
|
|
||||||
# llm_instance = LLM(hf_repo_id="google/flan-t5-xl", hf_api_token="your_hf_api_token")
|
|
||||||
# result = llm_instance.run("Who won the FIFA World Cup in 1998?")
|
|
||||||
# print(result)
|
|
||||||
|
|
||||||
|
|
||||||
# make super easy to chaneg parameters, in class, use cpu and
|
|
||||||
#add qlora, 8bit inference
|
|
||||||
# look into adding deepspeed
|
|
@ -0,0 +1,777 @@
|
|||||||
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
AsyncIterator,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
Iterator,
|
||||||
|
List,
|
||||||
|
Mapping,
|
||||||
|
Optional,
|
||||||
|
Tuple,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
from pydantic import Field
|
||||||
|
from swarm.utils.logger import logger
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
######### helpers
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
import tiktoken
|
||||||
|
|
||||||
|
|
||||||
|
def _import_tiktoken() -> Any:
|
||||||
|
try:
|
||||||
|
import tiktoken
|
||||||
|
except ImportError:
|
||||||
|
raise ValueError(
|
||||||
|
"Could not import tiktoken python package. "
|
||||||
|
"This is needed in order to calculate get_token_ids. "
|
||||||
|
"Please install it with `pip install tiktoken`."
|
||||||
|
)
|
||||||
|
return tiktoken
|
||||||
|
|
||||||
|
def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||||
|
if isinstance(message, ChatMessage):
|
||||||
|
message_dict = {"role": message.role, "content": message.content}
|
||||||
|
elif isinstance(message, HumanMessage):
|
||||||
|
message_dict = {"role": "user", "content": message.content}
|
||||||
|
elif isinstance(message, AIMessage):
|
||||||
|
message_dict = {"role": "assistant", "content": message.content}
|
||||||
|
if "function_call" in message.additional_kwargs:
|
||||||
|
message_dict["function_call"] = message.additional_kwargs["function_call"]
|
||||||
|
elif isinstance(message, SystemMessage):
|
||||||
|
message_dict = {"role": "system", "content": message.content}
|
||||||
|
elif isinstance(message, FunctionMessage):
|
||||||
|
message_dict = {
|
||||||
|
"role": "function",
|
||||||
|
"content": message.content,
|
||||||
|
"name": message.name,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Got unknown type {message}")
|
||||||
|
if "name" in message.additional_kwargs:
|
||||||
|
message_dict["name"] = message.additional_kwargs["name"]
|
||||||
|
return message_dict
|
||||||
|
|
||||||
|
class BaseMessage:
|
||||||
|
"""Base message class."""
|
||||||
|
|
||||||
|
|
||||||
|
class HumanMessage(BaseMessage):
|
||||||
|
"""Human message class."""
|
||||||
|
|
||||||
|
def __init__(self, content: str):
|
||||||
|
self.role = "user"
|
||||||
|
self.content = content
|
||||||
|
|
||||||
|
|
||||||
|
class AIMessage(BaseMessage):
|
||||||
|
"""AI message class."""
|
||||||
|
|
||||||
|
def __init__(self, content: str, additional_kwargs: Optional[Dict[str, Any]] = None):
|
||||||
|
self.role = "assistant"
|
||||||
|
self.content = content
|
||||||
|
self.additional_kwargs = additional_kwargs or {}
|
||||||
|
|
||||||
|
|
||||||
|
class SystemMessage(BaseMessage):
|
||||||
|
"""System message class."""
|
||||||
|
|
||||||
|
def __init__(self, content: str):
|
||||||
|
self.role = "system"
|
||||||
|
self.content = content
|
||||||
|
|
||||||
|
|
||||||
|
class FunctionMessage(BaseMessage):
|
||||||
|
"""Function message class."""
|
||||||
|
|
||||||
|
def __init__(self, content: str, name: str):
|
||||||
|
self.role = "function"
|
||||||
|
self.content = content
|
||||||
|
self.name = name
|
||||||
|
|
||||||
|
|
||||||
|
class ChatMessage(BaseMessage):
|
||||||
|
"""Chat message class."""
|
||||||
|
|
||||||
|
def __init__(self, content: str, role: str):
|
||||||
|
self.role = role
|
||||||
|
self.content = content
|
||||||
|
|
||||||
|
|
||||||
|
class BaseMessageChunk:
|
||||||
|
"""Base message chunk class."""
|
||||||
|
|
||||||
|
|
||||||
|
class HumanMessageChunk(BaseMessageChunk):
|
||||||
|
"""Human message chunk class."""
|
||||||
|
|
||||||
|
def __init__(self, content: str):
|
||||||
|
self.role = "user"
|
||||||
|
self.content = content
|
||||||
|
|
||||||
|
|
||||||
|
class AIMessageChunk(BaseMessageChunk):
|
||||||
|
"""AI message chunk class."""
|
||||||
|
|
||||||
|
def __init__(self, content: str, additional_kwargs: Optional[Dict[str, Any]] = None):
|
||||||
|
self.role = "assistant"
|
||||||
|
self.content = content
|
||||||
|
self.additional_kwargs = additional_kwargs or {}
|
||||||
|
|
||||||
|
|
||||||
|
class SystemMessageChunk(BaseMessageChunk):
|
||||||
|
"""System message chunk class."""
|
||||||
|
|
||||||
|
def __init__(self, content: str):
|
||||||
|
self.role = "system"
|
||||||
|
self.content = content
|
||||||
|
|
||||||
|
|
||||||
|
class FunctionMessageChunk(BaseMessageChunk):
|
||||||
|
"""Function message chunk class."""
|
||||||
|
|
||||||
|
def __init__(self, content: str, name: str):
|
||||||
|
self.role = "function"
|
||||||
|
self.content = content
|
||||||
|
self.name = name
|
||||||
|
|
||||||
|
class ChatMessageChunk(BaseMessageChunk):
|
||||||
|
"""Chat message chunk class."""
|
||||||
|
|
||||||
|
def __init__(self, content: str, role: str):
|
||||||
|
self.role = role
|
||||||
|
self.content = content
|
||||||
|
|
||||||
|
|
||||||
|
def convert_openai_messages(messages: List[dict]) -> List[BaseMessage]:
|
||||||
|
"""Convert dictionaries representing OpenAI messages to LangChain format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: List of dictionaries representing OpenAI messages
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of LangChain BaseMessage objects.
|
||||||
|
"""
|
||||||
|
converted_messages = []
|
||||||
|
for m in messages:
|
||||||
|
role = m.get("role")
|
||||||
|
content = m.get("content", "")
|
||||||
|
if m.get("function_call"):
|
||||||
|
additional_kwargs = {"function_call": dict(m["function_call"])}
|
||||||
|
else:
|
||||||
|
additional_kwargs = {}
|
||||||
|
|
||||||
|
if role == "user":
|
||||||
|
converted_messages.append(HumanMessage(content=content))
|
||||||
|
elif role == "assistant":
|
||||||
|
converted_messages.append(AIMessage(content=content, additional_kwargs=additional_kwargs))
|
||||||
|
elif role == "system":
|
||||||
|
converted_messages.append(SystemMessage(content=content))
|
||||||
|
elif role == "function":
|
||||||
|
converted_messages.append(FunctionMessage(content=content, name=m["name"]))
|
||||||
|
else:
|
||||||
|
converted_messages.append(ChatMessage(content=content, role=role))
|
||||||
|
|
||||||
|
return converted_messages
|
||||||
|
|
||||||
|
|
||||||
|
class ChatGenerationChunk:
|
||||||
|
"""Chat generation chunk class."""
|
||||||
|
|
||||||
|
def __init__(self, message: BaseMessageChunk):
|
||||||
|
self.message = message
|
||||||
|
|
||||||
|
def __add__(self, other: "ChatGenerationChunk") -> "ChatGenerationChunk":
|
||||||
|
if isinstance(self.message, AIMessageChunk) and isinstance(other.message, AIMessageChunk):
|
||||||
|
combined_kwargs = {
|
||||||
|
**self.message.additional_kwargs,
|
||||||
|
**other.message.additional_kwargs,
|
||||||
|
}
|
||||||
|
return ChatGenerationChunk(
|
||||||
|
AIMessageChunk(content=self.message.content + other.message.content, additional_kwargs=combined_kwargs)
|
||||||
|
)
|
||||||
|
return ChatGenerationChunk(BaseMessageChunk(content=self.message.content + other.message.content))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def content(self) -> str:
|
||||||
|
return self.message.content
|
||||||
|
|
||||||
|
@property
|
||||||
|
def additional_kwargs(self) -> Dict[str, Any]:
|
||||||
|
return getattr(self.message, "additional_kwargs", {})
|
||||||
|
|
||||||
|
|
||||||
|
class ChatResult:
|
||||||
|
"""Chat result class."""
|
||||||
|
|
||||||
|
def __init__(self, generations: List[ChatGenerationChunk]):
|
||||||
|
self.generations = generations
|
||||||
|
|
||||||
|
|
||||||
|
class BaseChatModel:
|
||||||
|
"""Base chat model class."""
|
||||||
|
|
||||||
|
def _convert_delta_to_message_chunk(self, _dict: Mapping[str, Any], default_class: type[BaseMessageChunk]) -> BaseMessageChunk:
|
||||||
|
role = _dict.get("role")
|
||||||
|
content = _dict.get("content") or ""
|
||||||
|
if _dict.get("function_call"):
|
||||||
|
additional_kwargs = {"function_call": dict(_dict["function_call"])}
|
||||||
|
else:
|
||||||
|
additional_kwargs = {}
|
||||||
|
|
||||||
|
if role == "user" or default_class == HumanMessageChunk:
|
||||||
|
return HumanMessageChunk(content=content)
|
||||||
|
elif role == "assistant" or default_class == AIMessageChunk:
|
||||||
|
return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
|
||||||
|
elif role == "system" or default_class == SystemMessageChunk:
|
||||||
|
return SystemMessageChunk(content=content)
|
||||||
|
elif role == "function" or default_class == FunctionMessageChunk:
|
||||||
|
return FunctionMessageChunk(content=content, name=_dict["name"])
|
||||||
|
elif role or default_class == ChatMessageChunk:
|
||||||
|
return ChatMessageChunk(content=content, role=role)
|
||||||
|
else:
|
||||||
|
return default_class(content=content)
|
||||||
|
|
||||||
|
def _convert_dict_to_message(self, _dict: Mapping[str, Any]) -> BaseMessage:
|
||||||
|
role = _dict["role"]
|
||||||
|
if role == "user":
|
||||||
|
return HumanMessage(content=_dict["content"])
|
||||||
|
elif role == "assistant":
|
||||||
|
content = _dict.get("content", "") or ""
|
||||||
|
if _dict.get("function_call"):
|
||||||
|
additional_kwargs = {"function_call": dict(_dict["function_call"])}
|
||||||
|
else:
|
||||||
|
additional_kwargs = {}
|
||||||
|
return AIMessage(content=content, additional_kwargs=additional_kwargs)
|
||||||
|
elif role == "system":
|
||||||
|
return SystemMessage(content=_dict["content"])
|
||||||
|
elif role == "function":
|
||||||
|
return FunctionMessage(content=_dict["content"], name=_dict["name"])
|
||||||
|
else:
|
||||||
|
return ChatMessage(content=_dict["content"], role=role)
|
||||||
|
|
||||||
|
def completion_with_retry(self, run_manager: Optional[Callable] = None, **kwargs: Any) -> Any:
|
||||||
|
"""Use tenacity to retry the completion call."""
|
||||||
|
retry_decorator = self._create_retry_decorator(self, run_manager=run_manager)
|
||||||
|
|
||||||
|
@retry_decorator
|
||||||
|
def _completion_with_retry(**kwargs: Any) -> Any:
|
||||||
|
return self.client.create(**kwargs)
|
||||||
|
|
||||||
|
return _completion_with_retry(**kwargs)
|
||||||
|
|
||||||
|
def _create_retry_decorator(
|
||||||
|
self,
|
||||||
|
llm: "ChatOpenAI",
|
||||||
|
run_manager: Optional[Callable] = None,
|
||||||
|
) -> Callable[[Any], Any]:
|
||||||
|
import openai
|
||||||
|
|
||||||
|
errors = [
|
||||||
|
openai.error.Timeout,
|
||||||
|
openai.error.APIError,
|
||||||
|
openai.error.APIConnectionError,
|
||||||
|
openai.error.RateLimitError,
|
||||||
|
openai.error.ServiceUnavailableError,
|
||||||
|
]
|
||||||
|
return create_base_retry_decorator(
|
||||||
|
error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
|
||||||
|
)
|
||||||
|
|
||||||
|
def _create_message_dicts(
|
||||||
|
self, messages: List[BaseMessage], stop: Optional[List[str]]
|
||||||
|
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
|
||||||
|
params = self._client_params
|
||||||
|
if stop is not None:
|
||||||
|
if "stop" in params:
|
||||||
|
raise ValueError("`stop` found in both the input and default params.")
|
||||||
|
params["stop"] = stop
|
||||||
|
message_dicts = [_convert_message_to_dict(m) for m in messages]
|
||||||
|
return message_dicts, params
|
||||||
|
|
||||||
|
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
|
||||||
|
overall_token_usage: dict = {}
|
||||||
|
for output in llm_outputs:
|
||||||
|
if output is None:
|
||||||
|
# Happens in streaming
|
||||||
|
continue
|
||||||
|
token_usage = output["token_usage"]
|
||||||
|
for k, v in token_usage.items():
|
||||||
|
if k in overall_token_usage:
|
||||||
|
overall_token_usage[k] += v
|
||||||
|
else:
|
||||||
|
overall_token_usage[k] = v
|
||||||
|
return {"token_usage": overall_token_usage, "model_name": self.model_name}
|
||||||
|
|
||||||
|
|
||||||
|
class ChatOpenAI(BaseChatModel):
|
||||||
|
"""Wrapper around OpenAI Chat large language models.
|
||||||
|
|
||||||
|
To use, you should have the ``openai`` python package installed, and the
|
||||||
|
environment variable ``OPENAI_API_KEY`` set with your API key.
|
||||||
|
|
||||||
|
Any parameters that are valid to be passed to the openai.create call can be passed
|
||||||
|
in, even if not explicitly saved on this class.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from langchain.chat_models import ChatOpenAI
|
||||||
|
openai = ChatOpenAI(model_name="gpt-3.5-turbo")
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_name: str = "gpt-3.5-turbo",
|
||||||
|
temperature: float = 0.7,
|
||||||
|
model_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
openai_api_key: Optional[str] = None,
|
||||||
|
openai_api_base: Optional[str] = None,
|
||||||
|
openai_organization: Optional[str] = None,
|
||||||
|
openai_proxy: Optional[str] = None,
|
||||||
|
request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
|
||||||
|
max_retries: int = 6,
|
||||||
|
streaming: bool = False,
|
||||||
|
n: int = 1,
|
||||||
|
max_tokens: Optional[int] = None,
|
||||||
|
tiktoken_model_name: Optional[str] = None,
|
||||||
|
):
|
||||||
|
self.model_name = model_name
|
||||||
|
self.temperature = temperature
|
||||||
|
self.model_kwargs = model_kwargs or {}
|
||||||
|
self.openai_api_key = openai_api_key
|
||||||
|
self.openai_api_base = openai_api_base
|
||||||
|
self.openai_organization = openai_organization
|
||||||
|
self.openai_proxy = openai_proxy
|
||||||
|
self.request_timeout = request_timeout
|
||||||
|
self.max_retries = max_retries
|
||||||
|
self.streaming = streaming
|
||||||
|
self.n = n
|
||||||
|
self.max_tokens = max_tokens
|
||||||
|
self.tiktoken_model_name = tiktoken_model_name
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lc_secrets(self) -> Dict[str, str]:
|
||||||
|
return {"openai_api_key": "OPENAI_API_KEY"}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lc_serializable(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _default_params(self) -> Dict[str, Any]:
|
||||||
|
"""Get the default parameters for calling OpenAI API."""
|
||||||
|
return {
|
||||||
|
"model": self.model_name,
|
||||||
|
"request_timeout": self.request_timeout,
|
||||||
|
"max_tokens": self.max_tokens,
|
||||||
|
"stream": self.streaming,
|
||||||
|
"n": self.n,
|
||||||
|
"temperature": self.temperature,
|
||||||
|
**self.model_kwargs,
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _client_params(self) -> Dict[str, Any]:
|
||||||
|
"""Get the parameters used for the openai client."""
|
||||||
|
openai_creds: Dict[str, Any] = {
|
||||||
|
"api_key": self.openai_api_key,
|
||||||
|
"api_base": self.openai_api_base,
|
||||||
|
"organization": self.openai_organization,
|
||||||
|
"model": self.model_name,
|
||||||
|
}
|
||||||
|
if self.openai_proxy:
|
||||||
|
import openai
|
||||||
|
|
||||||
|
openai.proxy = {"http": self.openai_proxy, "https": self.openai_proxy}
|
||||||
|
return {**self._default_params, **openai_creds}
|
||||||
|
|
||||||
|
def _get_invocation_params(
|
||||||
|
self, stop: Optional[List[str]] = None, **kwargs: Any
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Get the parameters used to invoke the model."""
|
||||||
|
return {
|
||||||
|
"model": self.model_name,
|
||||||
|
**super()._get_invocation_params(stop=stop),
|
||||||
|
**self._default_params,
|
||||||
|
**kwargs,
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _llm_type(self) -> str:
|
||||||
|
"""Return type of chat model."""
|
||||||
|
return "openai-chat"
|
||||||
|
|
||||||
|
def _get_encoding_model(self) -> Tuple[str, tiktoken.Encoding]:
|
||||||
|
tiktoken_ = _import_tiktoken()
|
||||||
|
if self.tiktoken_model_name is not None:
|
||||||
|
model = self.tiktoken_model_name
|
||||||
|
else:
|
||||||
|
model = self.model_name
|
||||||
|
if model == "gpt-3.5-turbo":
|
||||||
|
# gpt-3.5-turbo may change over time.
|
||||||
|
# Returning num tokens assuming gpt-3.5-turbo-0301.
|
||||||
|
model = "gpt-3.5-turbo-0301"
|
||||||
|
elif model == "gpt-4":
|
||||||
|
# gpt-4 may change over time.
|
||||||
|
# Returning num tokens assuming gpt-4-0314.
|
||||||
|
model = "gpt-4-0314"
|
||||||
|
# Returns the number of tokens used by a list of messages.
|
||||||
|
try:
|
||||||
|
encoding = tiktoken_.encoding_for_model(model)
|
||||||
|
except KeyError:
|
||||||
|
logger.warning("Warning: model not found. Using cl100k_base encoding.")
|
||||||
|
model = "cl100k_base"
|
||||||
|
encoding = tiktoken_.get_encoding(model)
|
||||||
|
return model, encoding
|
||||||
|
|
||||||
|
def get_token_ids(self, text: str) -> List[int]:
|
||||||
|
"""Get the tokens present in the text with tiktoken package."""
|
||||||
|
# tiktoken NOT supported for Python 3.7 or below
|
||||||
|
if sys.version_info[1] <= 7:
|
||||||
|
return super().get_token_ids(text)
|
||||||
|
_, encoding_model = self._get_encoding_model()
|
||||||
|
return encoding_model.encode(text)
|
||||||
|
|
||||||
|
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||||
|
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
|
||||||
|
|
||||||
|
Official documentation: https://github.com/openai/openai-cookbook/blob/
|
||||||
|
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
|
||||||
|
if sys.version_info[1] <= 7:
|
||||||
|
return super().get_num_tokens_from_messages(messages)
|
||||||
|
model, encoding = self._get_encoding_model()
|
||||||
|
if model.startswith("gpt-3.5-turbo-0301"):
|
||||||
|
# every message follows <im_start>{role/name}\n{content}<im_end>\n
|
||||||
|
tokens_per_message = 4
|
||||||
|
# if there's a name, the role is omitted
|
||||||
|
tokens_per_name = -1
|
||||||
|
elif model.startswith("gpt-3.5-turbo") or model.startswith("gpt-4"):
|
||||||
|
tokens_per_message = 3
|
||||||
|
tokens_per_name = 1
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"get_num_tokens_from_messages() is not presently implemented "
|
||||||
|
f"for model {model}."
|
||||||
|
"See https://github.com/openai/openai-python/blob/main/chatml.md for "
|
||||||
|
"information on how messages are converted to tokens."
|
||||||
|
)
|
||||||
|
num_tokens = 0
|
||||||
|
messages_dict = [_convert_message_to_dict(m) for m in messages]
|
||||||
|
for message in messages_dict:
|
||||||
|
num_tokens += tokens_per_message
|
||||||
|
for key, value in message.items():
|
||||||
|
# Cast str(value) in case the message value is not a string
|
||||||
|
# This occurs with function messages
|
||||||
|
num_tokens += len(encoding.encode(str(value)))
|
||||||
|
if key == "name":
|
||||||
|
num_tokens += tokens_per_name
|
||||||
|
# every reply is primed with <im_start>assistant
|
||||||
|
num_tokens += 3
|
||||||
|
return num_tokens
|
||||||
|
|
||||||
|
def get_token_usage(self, text: str, prefix_tokens: Optional[List[str]] = None) -> int:
|
||||||
|
"""Get the number of tokens used by the provided text."""
|
||||||
|
token_ids = self.get_token_ids(text)
|
||||||
|
if prefix_tokens:
|
||||||
|
prefix_token_ids = self.get_token_ids(" ".join(prefix_tokens))
|
||||||
|
token_ids = prefix_token_ids + token_ids
|
||||||
|
return len(token_ids)
|
||||||
|
|
||||||
|
def completion_with_retry(
|
||||||
|
self, run_manager: Optional[Callable] = None, **kwargs: Any
|
||||||
|
) -> Any:
|
||||||
|
"""Use tenacity to retry the completion call."""
|
||||||
|
retry_decorator = self._create_retry_decorator(run_manager=run_manager)
|
||||||
|
|
||||||
|
@retry_decorator
|
||||||
|
def _completion_with_retry(**kwargs: Any) -> Any:
|
||||||
|
return self.client.create(**kwargs)
|
||||||
|
|
||||||
|
return _completion_with_retry(**kwargs)
|
||||||
|
|
||||||
|
async def acompletion_with_retry(
|
||||||
|
self, run_manager: Optional[Callable] = None, **kwargs: Any
|
||||||
|
) -> Any:
|
||||||
|
"""Use tenacity to retry the async completion call."""
|
||||||
|
retry_decorator = self._create_retry_decorator(run_manager=run_manager)
|
||||||
|
|
||||||
|
@retry_decorator
|
||||||
|
async def _completion_with_retry(**kwargs: Any) -> Any:
|
||||||
|
# Use OpenAI's async api https://github.com/openai/openai-python#async-api
|
||||||
|
return await self.client.acreate(**kwargs)
|
||||||
|
|
||||||
|
return await _completion_with_retry(**kwargs)
|
||||||
|
|
||||||
|
def _create_retry_decorator(
|
||||||
|
self,
|
||||||
|
run_manager: Optional[Callable] = None,
|
||||||
|
) -> Callable[[Any], Any]:
|
||||||
|
import openai
|
||||||
|
|
||||||
|
errors = [
|
||||||
|
openai.error.Timeout,
|
||||||
|
openai.error.APIError,
|
||||||
|
openai.error.APIConnectionError,
|
||||||
|
openai.error.RateLimitError,
|
||||||
|
openai.error.ServiceUnavailableError,
|
||||||
|
]
|
||||||
|
return create_base_retry_decorator(
|
||||||
|
error_types=errors, max_retries=self.max_retries, run_manager=run_manager
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _astream(
|
||||||
|
self,
|
||||||
|
messages: List[BaseMessage],
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[Callable] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> AsyncIterator[ChatGenerationChunk]:
|
||||||
|
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||||
|
params = {**params, **kwargs, "stream": True}
|
||||||
|
|
||||||
|
default_chunk_class = AIMessageChunk
|
||||||
|
async for chunk in await self.acompletion_with_retry(
|
||||||
|
messages=message_dicts, run_manager=run_manager, **params
|
||||||
|
):
|
||||||
|
if len(chunk["choices"]) == 0:
|
||||||
|
continue
|
||||||
|
delta = chunk["choices"][0]["delta"]
|
||||||
|
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
|
||||||
|
default_chunk_class = chunk.__class__
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#=================================
|
||||||
|
# from typing import (
|
||||||
|
# Any,
|
||||||
|
# Dict,
|
||||||
|
# List,
|
||||||
|
# Mapping,
|
||||||
|
# Optional,
|
||||||
|
# )
|
||||||
|
|
||||||
|
# import openai
|
||||||
|
|
||||||
|
|
||||||
|
# class ChatResult:
|
||||||
|
# """Wrapper for the result of the chat generation process."""
|
||||||
|
|
||||||
|
# def __init__(
|
||||||
|
# self,
|
||||||
|
# generations: List[ChatGeneration],
|
||||||
|
# llm_output: Optional[Mapping[str, Any]] = None,
|
||||||
|
# ):
|
||||||
|
# self.generations = generations
|
||||||
|
# self.llm_output = llm_output or {}
|
||||||
|
|
||||||
|
# class BaseMessage:
|
||||||
|
# """Base class for different types of messages."""
|
||||||
|
|
||||||
|
# def __init__(self, content: str):
|
||||||
|
# self.content = content
|
||||||
|
|
||||||
|
# class AIMessage(BaseMessage):
|
||||||
|
# """Message from the AI Assistant."""
|
||||||
|
|
||||||
|
# def __init__(self, content: str, additional_kwargs: Optional[Dict[str, Any]] = None):
|
||||||
|
# super().__init__(content)
|
||||||
|
# self.additional_kwargs = additional_kwargs or {}
|
||||||
|
|
||||||
|
# class HumanMessage(BaseMessage):
|
||||||
|
# """Message from the User."""
|
||||||
|
|
||||||
|
# pass
|
||||||
|
|
||||||
|
# class SystemMessage(BaseMessage):
|
||||||
|
# """System message."""
|
||||||
|
|
||||||
|
# pass
|
||||||
|
|
||||||
|
# class FunctionMessage(BaseMessage):
|
||||||
|
# """Function message."""
|
||||||
|
|
||||||
|
# def __init__(self, content: str, name: str):
|
||||||
|
# super().__init__(content)
|
||||||
|
# self.name = name
|
||||||
|
|
||||||
|
# class ChatGeneration:
|
||||||
|
# """Wrapper for the chat generation information."""
|
||||||
|
|
||||||
|
# def __init__(
|
||||||
|
# self, message: BaseMessage, generation_info: Optional[Mapping[str, Any]] = None
|
||||||
|
# ):
|
||||||
|
# self.message = message
|
||||||
|
# self.generation_info = generation_info or {}
|
||||||
|
|
||||||
|
# class ChatGenerationChunk:
|
||||||
|
# """Wrapper for a chunk of chat generation."""
|
||||||
|
|
||||||
|
# def __init__(self, message: BaseMessage):
|
||||||
|
# self.message = message
|
||||||
|
|
||||||
|
# def get_from_env_or_raise(var_name: str) -> str:
|
||||||
|
# value = os.getenv(var_name)
|
||||||
|
# if value is None:
|
||||||
|
# raise ValueError(f"Environment variable {var_name} is not set.")
|
||||||
|
# return value
|
||||||
|
|
||||||
|
|
||||||
|
# class OpenAI:
|
||||||
|
# """Wrapper around OpenAI Chat large language models.
|
||||||
|
|
||||||
|
# To use, you should have the ``openai`` python package installed, and the
|
||||||
|
# environment variable ``OPENAI_API_KEY`` set with your API key.
|
||||||
|
|
||||||
|
# Example:
|
||||||
|
# .. code-block:: python
|
||||||
|
|
||||||
|
# from langchain.chat_models import OpenAI
|
||||||
|
# openai = OpenAI(model_name="gpt-3.5-turbo")
|
||||||
|
# """
|
||||||
|
|
||||||
|
# def __init__(
|
||||||
|
# self,
|
||||||
|
# model_name: str = "gpt-3.5-turbo",
|
||||||
|
# temperature: float = 0.7,
|
||||||
|
# openai_api_key: Optional[str] = None,
|
||||||
|
# request_timeout: Optional[float] = None,
|
||||||
|
# max_retries: int = 6,
|
||||||
|
# ):
|
||||||
|
# self.model_name = model_name
|
||||||
|
# self.temperature = temperature
|
||||||
|
# self.openai_api_key = openai_api_key
|
||||||
|
# self.request_timeout = request_timeout
|
||||||
|
# self.max_retries = max_retries
|
||||||
|
# self._initialize_openai()
|
||||||
|
|
||||||
|
# def _initialize_openai(self):
|
||||||
|
# """Initialize the OpenAI client."""
|
||||||
|
# if self.openai_api_key is None:
|
||||||
|
# raise ValueError("OPENAI_API_KEY environment variable is not set.")
|
||||||
|
|
||||||
|
# openai.api_key = self.openai_api_key
|
||||||
|
|
||||||
|
# def _create_retry_decorator(self):
|
||||||
|
# """Create a decorator to handle API call retries."""
|
||||||
|
# errors = [
|
||||||
|
# openai.error.Timeout,
|
||||||
|
# openai.error.APIError,
|
||||||
|
# openai.error.APIConnectionError,
|
||||||
|
# openai.error.RateLimitError,
|
||||||
|
# openai.error.ServiceUnavailableError,
|
||||||
|
# ]
|
||||||
|
|
||||||
|
# def retry_decorator(func):
|
||||||
|
# @wraps(func)
|
||||||
|
# def wrapper(*args, **kwargs):
|
||||||
|
# for _ in range(self.max_retries):
|
||||||
|
# try:
|
||||||
|
# return func(*args, **kwargs)
|
||||||
|
# except tuple(errors):
|
||||||
|
# continue
|
||||||
|
# raise ValueError("Max retries reached. Unable to complete the API call.")
|
||||||
|
|
||||||
|
# return wrapper
|
||||||
|
|
||||||
|
# return retry_decorator
|
||||||
|
|
||||||
|
# def _create_message_dict(self, message: BaseMessage) -> Dict[str, Any]:
|
||||||
|
# """Convert a LangChain message to an OpenAI message dictionary."""
|
||||||
|
# role = message.role
|
||||||
|
# content = message.content
|
||||||
|
# message_dict = {"role": role, "content": content}
|
||||||
|
|
||||||
|
# if role == "assistant" and isinstance(message, AIMessage):
|
||||||
|
# message_dict["function_call"] = message.additional_kwargs.get("function_call", {})
|
||||||
|
|
||||||
|
# if role == "function" and isinstance(message, FunctionMessage):
|
||||||
|
# message_dict["name"] = message.name
|
||||||
|
|
||||||
|
# return message_dict
|
||||||
|
|
||||||
|
# def _create_message_dicts(self, messages: List[BaseMessage]) -> List[Dict[str, Any]]:
|
||||||
|
# """Convert a list of LangChain messages to a list of OpenAI message dictionaries."""
|
||||||
|
# return [self._create_message_dict(message) for message in messages]
|
||||||
|
|
||||||
|
# @retry_decorator
|
||||||
|
# def _openai_completion(self, messages: List[Dict[str, Any]], params: Dict[str, Any]) -> Any:
|
||||||
|
# """Call the OpenAI Chat Completion API."""
|
||||||
|
# response = openai.ChatCompletion.create(messages=messages, **params)
|
||||||
|
# return response
|
||||||
|
|
||||||
|
# def generate(
|
||||||
|
# self,
|
||||||
|
# messages: List[BaseMessage],
|
||||||
|
# stop: Optional[List[str]] = None,
|
||||||
|
# **kwargs: Any,
|
||||||
|
# ) -> ChatResult:
|
||||||
|
# """Generate a response using the OpenAI Chat model.
|
||||||
|
|
||||||
|
# Args:
|
||||||
|
# messages (List[BaseMessage]): List of messages in the conversation.
|
||||||
|
# stop (Optional[List[str]]): List of stop sequences to stop generation.
|
||||||
|
|
||||||
|
# Returns:
|
||||||
|
# ChatResult: The generated response wrapped in ChatResult object.
|
||||||
|
# """
|
||||||
|
# params = {
|
||||||
|
# "model": self.model_name,
|
||||||
|
# "temperature": self.temperature,
|
||||||
|
# "max_tokens": kwargs.get("max_tokens"),
|
||||||
|
# "stream": kwargs.get("streaming", False),
|
||||||
|
# "n": kwargs.get("n", 1),
|
||||||
|
# "request_timeout": kwargs.get("request_timeout", self.request_timeout),
|
||||||
|
# }
|
||||||
|
|
||||||
|
# messages_dicts = self._create_message_dicts(messages)
|
||||||
|
# response = self._openai_completion(messages_dicts, params)
|
||||||
|
|
||||||
|
# # Process the response and create ChatResult
|
||||||
|
# generations = []
|
||||||
|
# for choice in response["choices"]:
|
||||||
|
# message = self._convert_message(choice["message"])
|
||||||
|
# generation_info = {"finish_reason": choice.get("finish_reason")}
|
||||||
|
# generation = ChatGeneration(message=message, generation_info=generation_info)
|
||||||
|
# generations.append(generation)
|
||||||
|
|
||||||
|
# return ChatResult(generations=generations)
|
||||||
|
|
||||||
|
# def _convert_message(self, message_dict: Dict[str, Any]) -> BaseMessage:
|
||||||
|
# """Convert an OpenAI message dictionary to a LangChain message."""
|
||||||
|
# role = message_dict["role"]
|
||||||
|
# content = message_dict["content"]
|
||||||
|
|
||||||
|
# if role == "user":
|
||||||
|
# return HumanMessage(content=content)
|
||||||
|
# elif role == "assistant":
|
||||||
|
# additional_kwargs = message_dict.get("function_call", {})
|
||||||
|
# return AIMessage(content=content, additional_kwargs=additional_kwargs)
|
||||||
|
# elif role == "system":
|
||||||
|
# return SystemMessage(content=content)
|
||||||
|
# elif role == "function":
|
||||||
|
# name = message_dict.get("name", "")
|
||||||
|
# return FunctionMessage(content=content, name=name)
|
||||||
|
# else:
|
||||||
|
# raise ValueError(f"Invalid role found in the message: {role}")
|
@ -0,0 +1,177 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from swarms.utils.logger import logger
|
||||||
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, root_validator
|
||||||
|
from tenacity import (
|
||||||
|
before_sleep_log,
|
||||||
|
retry,
|
||||||
|
retry_if_exception_type,
|
||||||
|
stop_after_attempt,
|
||||||
|
wait_exponential,
|
||||||
|
)
|
||||||
|
|
||||||
|
import google.generativeai as genai
|
||||||
|
|
||||||
|
|
||||||
|
###############helpers
|
||||||
|
|
||||||
|
class ChatGooglePalmError(Exception):
|
||||||
|
"""Error raised when there is an issue with the Google PaLM API."""
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def _truncate_at_stop_tokens(
|
||||||
|
text: str,
|
||||||
|
stop: Optional[List[str]],
|
||||||
|
) -> str:
|
||||||
|
"""Truncates text at the earliest stop token found."""
|
||||||
|
if stop is None:
|
||||||
|
return text
|
||||||
|
|
||||||
|
for stop_token in stop:
|
||||||
|
stop_token_idx = text.find(stop_token)
|
||||||
|
if stop_token_idx != -1:
|
||||||
|
text = text[:stop_token_idx]
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def _response_to_result(response: genai.types.ChatResponse, stop: Optional[List[str]]) -> Dict[str, Any]:
|
||||||
|
"""Convert a PaLM chat response to a result dictionary."""
|
||||||
|
# Code will need to be rewritten due to removal of schema imports.
|
||||||
|
# Instead of creating a ChatResult object, you would directly return a dictionary.
|
||||||
|
result = {
|
||||||
|
"id": response.id,
|
||||||
|
"created": response.created,
|
||||||
|
"model": response.model,
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": response.usage.prompt_tokens,
|
||||||
|
"completion_tokens": response.usage.completion_tokens,
|
||||||
|
"total_tokens": response.usage.total_tokens,
|
||||||
|
},
|
||||||
|
"choices": [],
|
||||||
|
}
|
||||||
|
for choice in response.choices:
|
||||||
|
result["choices"].append({
|
||||||
|
"text": _truncate_at_stop_tokens(choice.text, stop),
|
||||||
|
"index": choice.index,
|
||||||
|
"finish_reason": choice.finish_reason,
|
||||||
|
})
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _messages_to_prompt_dict(messages: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||||
|
"""Convert a list of message dictionaries to a prompt dictionary."""
|
||||||
|
# Code will need to be rewritten due to removal of schema imports.
|
||||||
|
# Instead of using BaseMessage.from_dict, you would directly create a dictionary from each message dict.
|
||||||
|
prompt = {"messages": []}
|
||||||
|
for message in messages:
|
||||||
|
prompt["messages"].append({
|
||||||
|
"role": message["role"],
|
||||||
|
"content": message["content"],
|
||||||
|
})
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
def _create_retry_decorator() -> Callable[[Any], Any]:
|
||||||
|
"""Create a retry decorator with exponential backoff."""
|
||||||
|
return retry(
|
||||||
|
retry=retry_if_exception_type(ChatGooglePalmError),
|
||||||
|
stop=stop_after_attempt(5),
|
||||||
|
wait=wait_exponential(multiplier=1, min=2, max=30),
|
||||||
|
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||||
|
reraise=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
####################### => main class
|
||||||
|
class ChatGooglePalm(BaseModel):
|
||||||
|
"""Wrapper around Google's PaLM Chat API."""
|
||||||
|
|
||||||
|
client: Any #: :meta private:
|
||||||
|
model_name: str = "models/chat-bison-001"
|
||||||
|
google_api_key: Optional[str] = None
|
||||||
|
temperature: Optional[float] = None
|
||||||
|
top_p: Optional[float] = None
|
||||||
|
top_k: Optional[int] = None
|
||||||
|
n: int = 1
|
||||||
|
|
||||||
|
@root_validator()
|
||||||
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
|
# Same as before
|
||||||
|
pass
|
||||||
|
|
||||||
|
def chat_with_retry(self, **kwargs: Any) -> Any:
|
||||||
|
"""Use tenacity to retry the completion call."""
|
||||||
|
retry_decorator = _create_retry_decorator()
|
||||||
|
|
||||||
|
@retry_decorator
|
||||||
|
def _chat_with_retry(**kwargs: Any) -> Any:
|
||||||
|
return self.client.chat(**kwargs)
|
||||||
|
|
||||||
|
return _chat_with_retry(**kwargs)
|
||||||
|
|
||||||
|
async def achat_with_retry(self, **kwargs: Any) -> Any:
|
||||||
|
"""Use tenacity to retry the async completion call."""
|
||||||
|
retry_decorator = _create_retry_decorator()
|
||||||
|
|
||||||
|
@retry_decorator
|
||||||
|
async def _achat_with_retry(**kwargs: Any) -> Any:
|
||||||
|
return await self.client.chat_async(**kwargs)
|
||||||
|
|
||||||
|
return await _achat_with_retry(**kwargs)
|
||||||
|
|
||||||
|
def generate(
|
||||||
|
self,
|
||||||
|
messages: List[Dict[str, Any]],
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
prompt = _messages_to_prompt_dict(messages)
|
||||||
|
|
||||||
|
response: genai.types.ChatResponse = self.chat_with_retry(
|
||||||
|
model=self.model_name,
|
||||||
|
prompt=prompt,
|
||||||
|
temperature=self.temperature,
|
||||||
|
top_p=self.top_p,
|
||||||
|
top_k=self.top_k,
|
||||||
|
candidate_count=self.n,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
return _response_to_result(response, stop)
|
||||||
|
|
||||||
|
async def _agenerate(
|
||||||
|
self,
|
||||||
|
messages: List[Dict[str, Any]],
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
prompt = _messages_to_prompt_dict(messages)
|
||||||
|
|
||||||
|
response: genai.types.ChatResponse = await self.achat_with_retry(
|
||||||
|
model=self.model_name,
|
||||||
|
prompt=prompt,
|
||||||
|
temperature=self.temperature,
|
||||||
|
top_p=self.top_p,
|
||||||
|
top_k=self.top_k,
|
||||||
|
candidate_count=self.n,
|
||||||
|
)
|
||||||
|
|
||||||
|
return _response_to_result(response, stop)
|
||||||
|
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _identifying_params(self) -> Dict[str, Any]:
|
||||||
|
"""Get the identifying parameters."""
|
||||||
|
return {
|
||||||
|
"model_name": self.model_name,
|
||||||
|
"temperature": self.temperature,
|
||||||
|
"top_p": self.top_p,
|
||||||
|
"top_k": self.top_k,
|
||||||
|
"n": self.n,
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _llm_type(self) -> str:
|
||||||
|
return "google-palm-chat"
|
Loading…
Reference in new issue