from __future__ import annotations

import logging
from abc import ABC, abstractmethod
from typing import Optional

import tiktoken
from attr import Factory, define, field
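
# Third-party dependencies (an assumption about the environment): `tiktoken`
# is OpenAI's BPE tokenizer library, and `Factory`/`define`/`field` come from
# the `attrs` package.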


@define(frozen=True)
class BaseTokenizer(ABC):
    DEFAULT_STOP_SEQUENCES = ["Observation:"]

    stop_sequences: list[str] = field(
        default=Factory(lambda: BaseTokenizer.DEFAULT_STOP_SEQUENCES),
        kw_only=True,
    )

    @property
    @abstractmethod
    def max_tokens(self) -> int:
        ...

    def count_tokens_left(self, text: str) -> int:
        # Remaining token budget for `text`, clamped at zero so callers
        # never see a negative count.
        diff = self.max_tokens - self.count_tokens(text)

        if diff > 0:
            return diff
        else:
            return 0

    @abstractmethod
    def count_tokens(self, text: str) -> int:
        ...
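
# A minimal sketch of the subclass contract (illustrative only; the
# whitespace "tokenizer" below is a stand-in, not a real encoder):
#
#   @define(frozen=True)
#   class WhitespaceTokenizer(BaseTokenizer):
#       @property
#       def max_tokens(self) -> int:
#           return 1024
#
#       def count_tokens(self, text: str) -> int:
#           return len(text.split())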


@define(frozen=True)
class OpenAITokenizer(BaseTokenizer):
    DEFAULT_OPENAI_GPT_3_COMPLETION_MODEL = "text-davinci-003"
    DEFAULT_OPENAI_GPT_3_CHAT_MODEL = "gpt-3.5-turbo"
    DEFAULT_OPENAI_GPT_4_MODEL = "gpt-4"
    DEFAULT_ENCODING = "cl100k_base"
    DEFAULT_MAX_TOKENS = 2049
    # Safety margin subtracted from the advertised context size for
    # non-embedding models in `max_tokens` below.
    TOKEN_OFFSET = 8

    # Matched by prefix in insertion order: longer, more specific prefixes
    # (e.g. "gpt-4-32k") must come before their shorter counterparts ("gpt-4").
    MODEL_PREFIXES_TO_MAX_TOKENS = {
        "gpt-4-32k": 32768,
        "gpt-4": 8192,
        "gpt-3.5-turbo-16k": 16384,
        "gpt-3.5-turbo": 4096,
        "gpt-35-turbo-16k": 16384,
        "gpt-35-turbo": 4096,
        "text-davinci-003": 4097,
        "text-davinci-002": 4097,
        "code-davinci-002": 8001,
        "text-embedding-ada-002": 8191,
        "text-embedding-ada-001": 2046,
    }

    EMBEDDING_MODELS = ["text-embedding-ada-002", "text-embedding-ada-001"]

    model: str = field(kw_only=True)

    @property
    def encoding(self) -> tiktoken.Encoding:
        try:
            return tiktoken.encoding_for_model(self.model)
        except KeyError:
            return tiktoken.get_encoding(self.DEFAULT_ENCODING)

    @property
    def max_tokens(self) -> int:
        # Without the None default, next() would raise StopIteration for
        # unknown models and the DEFAULT_MAX_TOKENS fallback below would
        # be dead code.
        tokens = next(
            (
                v
                for k, v in self.MODEL_PREFIXES_TO_MAX_TOKENS.items()
                if self.model.startswith(k)
            ),
            None,
        )
        offset = 0 if self.model in self.EMBEDDING_MODELS else self.TOKEN_OFFSET

        return (tokens if tokens else self.DEFAULT_MAX_TOKENS) - offset
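
    # For example (illustrative values derived from the table above):
    # OpenAITokenizer(model="gpt-4").max_tokens matches the "gpt-4" prefix
    # (8192) and subtracts TOKEN_OFFSET, yielding 8184, while
    # "text-embedding-ada-002" keeps its full 8191 because embedding models
    # skip the offset.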

    def count_tokens(
        self, text: str | list, model: Optional[str] = None
    ) -> int:
        """
        Handles the special case of ChatML. Implementation adapted from the official OpenAI notebook:
        https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
        """
        if isinstance(text, list):
            model = model if model else self.model

            try:
                encoding = tiktoken.encoding_for_model(model)
            except KeyError:
                logging.warning("model not found. Using cl100k_base encoding.")

                encoding = tiktoken.get_encoding("cl100k_base")

            if model in {
                "gpt-3.5-turbo-0613",
                "gpt-3.5-turbo-16k-0613",
                "gpt-4-0314",
                "gpt-4-32k-0314",
                "gpt-4-0613",
                "gpt-4-32k-0613",
            }:
                tokens_per_message = 3
                tokens_per_name = 1
            elif model == "gpt-3.5-turbo-0301":
                # every message follows <|start|>{role/name}\n{content}<|end|>\n
                tokens_per_message = 4
                # if there's a name, the role is omitted
                tokens_per_name = -1
            elif "gpt-3.5-turbo" in model or "gpt-35-turbo" in model:
                logging.info(
                    "gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613."
                )

                return self.count_tokens(text, model="gpt-3.5-turbo-0613")
            elif "gpt-4" in model:
                logging.info(
                    "gpt-4 may update over time. Returning num tokens assuming gpt-4-0613."
                )

                return self.count_tokens(text, model="gpt-4-0613")
            else:
                raise NotImplementedError(
                    f"""token_count() is not implemented for model {model}.
                    See https://github.com/openai/openai-python/blob/main/chatml.md for
                    information on how messages are converted to tokens."""
                )

            num_tokens = 0

            for message in text:
                num_tokens += tokens_per_message
                for key, value in message.items():
                    num_tokens += len(encoding.encode(value))
                    if key == "name":
                        num_tokens += tokens_per_name

            # every reply is primed with <|start|>assistant<|message|>
            num_tokens += 3

            return num_tokens
        else:
            return len(
                self.encoding.encode(
                    text, allowed_special=set(self.stop_sequences)
                )
            )
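

# A minimal usage sketch (illustrative, not part of the module): the model
# name and messages below are assumptions for demonstration purposes.
if __name__ == "__main__":
    tokenizer = OpenAITokenizer(model="gpt-4")

    # Plain strings are encoded directly with tiktoken.
    print(tokenizer.count_tokens("Hello, world!"))
    print(tokenizer.count_tokens_left("Hello, world!"))

    # ChatML-style message lists get the per-message/per-name accounting.
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ]
    print(tokenizer.count_tokens(messages))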