"""Contains classes for querying large language models."""
|
|
from math import ceil
|
|
import os
|
|
import time
|
|
from tqdm import tqdm
|
|
from abc import ABC, abstractmethod
|
|
|
|
import openai
|
|
|
|
gpt_costs_per_thousand = {
|
|
'davinci': 0.0200,
|
|
'curie': 0.0020,
|
|
'babbage': 0.0005,
|
|
'ada': 0.0004
|
|
}
|
|
|
|
|
|


def model_from_config(config, disable_tqdm=True):
    """Returns a model based on the config."""
    model_type = config["name"]
    if model_type == "GPT_forward":
        return GPT_Forward(config, disable_tqdm=disable_tqdm)
    elif model_type == "GPT_insert":
        return GPT_Insert(config, disable_tqdm=disable_tqdm)
    raise ValueError(f"Unknown model type: {model_type}")
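# model_from_config expects a dict with at least the keys read in this module:
# 'name' (either 'GPT_forward' or 'GPT_insert'), 'batch_size', and 'gpt_config'
# (keyword arguments forwarded to openai.Completion.create, e.g. 'model' and
# 'max_tokens'). See the __main__ sketch at the bottom of the file for an
# illustrative (assumed, not documented) example config.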


class LLM(ABC):
    """Abstract base class for large language models."""

    @abstractmethod
    def generate_text(self, prompt):
        """Generates text from the model.

        Parameters:
            prompt: The prompt to use. This can be a string or a list of strings.

        Returns:
            A list of strings.
        """
        pass

    @abstractmethod
    def log_probs(self, text, log_prob_range):
        """Returns the log probs of the text.

        Parameters:
            text: The text to get the log probs of. This can be a string or a list of strings.
            log_prob_range: The range of characters within each string to get the log_probs of.
                This is a list of tuples of the form (start, end).

        Returns:
            A list of log probs.
        """
        pass


class GPT_Forward(LLM):
    """Wrapper for GPT-3."""

    def __init__(self, config, needs_confirmation=False, disable_tqdm=True):
        """Initializes the model."""
        self.config = config
        self.needs_confirmation = needs_confirmation
        self.disable_tqdm = disable_tqdm

    def confirm_cost(self, texts, n, max_tokens):
        total_estimated_cost = 0
        for text in texts:
            total_estimated_cost += gpt_get_estimated_cost(
                self.config, text, max_tokens) * n
        print(f"Estimated cost: ${total_estimated_cost:.2f}")
        # Ask the user to confirm in the command line
        if os.getenv("LLM_SKIP_CONFIRM") is None:
            confirm = input("Continue? (y/n) ")
            if confirm != 'y':
                raise Exception("Aborted.")

    def auto_reduce_n(self, fn, prompt, n):
        """Reduces n by half until the function succeeds."""
        try:
            return fn(prompt, n)
        except BatchSizeException as e:
            if n == 1:
                raise e
            return self.auto_reduce_n(fn, prompt, n // 2) + self.auto_reduce_n(fn, prompt, n // 2)
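    # Illustrative behaviour of auto_reduce_n (not part of the original file):
    # if fn raises BatchSizeException for n=8, the call is retried as two
    # independent calls with n=4 each and the results are concatenated, so the
    # caller still receives 8 completions. Note that for odd n the two n // 2
    # halves return one fewer completion than originally requested.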

    def generate_text(self, prompt, n):
        if not isinstance(prompt, list):
            prompt = [prompt]
        if self.needs_confirmation:
            self.confirm_cost(
                prompt, n, self.config['gpt_config']['max_tokens'])
        batch_size = self.config['batch_size']
        prompt_batches = [prompt[i:i + batch_size]
                          for i in range(0, len(prompt), batch_size)]
        if not self.disable_tqdm:
            print(
                f"[{self.config['name']}] Generating {len(prompt) * n} completions, "
                f"split into {len(prompt_batches)} batches of size {batch_size * n}")
        text = []

        for prompt_batch in tqdm(prompt_batches, disable=self.disable_tqdm):
            text += self.auto_reduce_n(self.__generate_text, prompt_batch, n)
        return text

    def complete(self, prompt, n):
        """Generates text from the model and returns the log prob data."""
        if not isinstance(prompt, list):
            prompt = [prompt]
        batch_size = self.config['batch_size']
        prompt_batches = [prompt[i:i + batch_size]
                          for i in range(0, len(prompt), batch_size)]
        if not self.disable_tqdm:
            print(
                f"[{self.config['name']}] Generating {len(prompt) * n} completions, "
                f"split into {len(prompt_batches)} batches of size {batch_size * n}")
        res = []
        for prompt_batch in tqdm(prompt_batches, disable=self.disable_tqdm):
            res += self.__complete(prompt_batch, n)
        return res

    def log_probs(self, text, log_prob_range=None):
        """Returns the log probs of the text."""
        if not isinstance(text, list):
            text = [text]
        if self.needs_confirmation:
            self.confirm_cost(text, 1, 0)
        batch_size = self.config['batch_size']
        text_batches = [text[i:i + batch_size]
                        for i in range(0, len(text), batch_size)]
        if log_prob_range is None:
            log_prob_range_batches = [None] * len(text)
        else:
            assert len(log_prob_range) == len(text)
            log_prob_range_batches = [log_prob_range[i:i + batch_size]
                                      for i in range(0, len(log_prob_range), batch_size)]
        if not self.disable_tqdm:
            print(
                f"[{self.config['name']}] Getting log probs for {len(text)} strings, "
                f"split into {len(text_batches)} batches of (maximum) size {batch_size}")
        log_probs = []
        tokens = []
        for text_batch, log_prob_range in tqdm(list(zip(text_batches, log_prob_range_batches)),
                                               disable=self.disable_tqdm):
            log_probs_batch, tokens_batch = self.__log_probs(
                text_batch, log_prob_range)
            log_probs += log_probs_batch
            tokens += tokens_batch
        return log_probs, tokens
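    # Example (illustrative, not from the original file): scoring only the
    # answer span of a string by character offsets. For the string
    # 'Q: 2+2?\nA: 4' the range (8, 12) covers 'A: 4', so only the tokens that
    # overlap those characters are returned:
    #
    #   lps, toks = model.log_probs(['Q: 2+2?\nA: 4'], [(8, 12)])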

    def __generate_text(self, prompt, n):
        """Generates text from the model."""
        if not isinstance(prompt, list):
            prompt = [prompt]
        config = self.config['gpt_config'].copy()
        config['n'] = n
        # If there are any [APE] tokens in the prompts, remove them
        for i in range(len(prompt)):
            prompt[i] = prompt[i].replace('[APE]', '').strip()
        response = None
        while response is None:
            try:
                response = openai.Completion.create(
                    **config, prompt=prompt)
            except Exception as e:
                if 'is greater than the maximum' in str(e):
                    raise BatchSizeException()
                print(e)
                print('Retrying...')
                time.sleep(5)

        return [response['choices'][i]['text'] for i in range(len(response['choices']))]

    def __complete(self, prompt, n):
        """Generates text from the model and returns the log prob data."""
        if not isinstance(prompt, list):
            prompt = [prompt]
        config = self.config['gpt_config'].copy()
        config['n'] = n
        # If there are any [APE] tokens in the prompts, remove them
        for i in range(len(prompt)):
            prompt[i] = prompt[i].replace('[APE]', '').strip()
        response = None
        while response is None:
            try:
                response = openai.Completion.create(
                    **config, prompt=prompt)
            except Exception as e:
                print(e)
                print('Retrying...')
                time.sleep(5)
        return response['choices']

    def __log_probs(self, text, log_prob_range=None):
        """Returns the log probs of the text."""
        if not isinstance(text, list):
            text = [text]
        if log_prob_range is not None:
            for i in range(len(text)):
                lower_index, upper_index = log_prob_range[i]
                assert lower_index < upper_index
                assert lower_index >= 0
                assert upper_index - 1 < len(text[i])
        config = self.config['gpt_config'].copy()
        config['logprobs'] = 1
        config['echo'] = True
        config['max_tokens'] = 0
        if isinstance(text, list):
            text = [f'\n{text[i]}' for i in range(len(text))]
        else:
            text = f'\n{text}'
        response = None
        while response is None:
            try:
                response = openai.Completion.create(
                    **config, prompt=text)
            except Exception as e:
                print(e)
                print('Retrying...')
                time.sleep(5)
        log_probs = [response['choices'][i]['logprobs']['token_logprobs'][1:]
                     for i in range(len(response['choices']))]
        tokens = [response['choices'][i]['logprobs']['tokens'][1:]
                  for i in range(len(response['choices']))]
        offsets = [response['choices'][i]['logprobs']['text_offset'][1:]
                   for i in range(len(response['choices']))]

        # Subtract 1 from the offsets to account for the newline
        for i in range(len(offsets)):
            offsets[i] = [offset - 1 for offset in offsets[i]]

        if log_prob_range is not None:
            # First, we need to find the indices of the tokens in the log probs
            # that correspond to the tokens in the log_prob_range
            for i in range(len(log_probs)):
                lower_index, upper_index = self.get_token_indices(
                    offsets[i], log_prob_range[i])
                log_probs[i] = log_probs[i][lower_index:upper_index]
                tokens[i] = tokens[i][lower_index:upper_index]

        return log_probs, tokens

    def get_token_indices(self, offsets, log_prob_range):
        """Returns the (lower, upper) token indices whose character offsets fall within log_prob_range."""
        # For the lower index, find the last token whose character offset is
        # less than or equal to the start of the range
        lower_index = 0
        for i in range(len(offsets)):
            if offsets[i] <= log_prob_range[0]:
                lower_index = i
            else:
                break

        # For the upper index, find the first token whose character offset is
        # at or past the end of the range
        upper_index = len(offsets)
        for i in range(len(offsets)):
            if offsets[i] >= log_prob_range[1]:
                upper_index = i
                break

        return lower_index, upper_index
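    # Worked example (illustrative, not from the original file): with token
    # character offsets [0, 4, 9, 13] and log_prob_range (4, 13), the lower
    # index is 1 (offset 4 is the last offset <= 4) and the upper index is 3
    # (offset 13 is the first offset >= 13), so tokens[1:3] are kept.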


class GPT_Insert(LLM):
    """Wrapper for GPT-3 in insert mode (completes between a prefix and a suffix)."""

    def __init__(self, config, needs_confirmation=False, disable_tqdm=True):
        """Initializes the model."""
        self.config = config
        self.needs_confirmation = needs_confirmation
        self.disable_tqdm = disable_tqdm

    def confirm_cost(self, texts, n, max_tokens):
        total_estimated_cost = 0
        for text in texts:
            total_estimated_cost += gpt_get_estimated_cost(
                self.config, text, max_tokens) * n
        print(f"Estimated cost: ${total_estimated_cost:.2f}")
        # Ask the user to confirm in the command line
        if os.getenv("LLM_SKIP_CONFIRM") is None:
            confirm = input("Continue? (y/n) ")
            if confirm != 'y':
                raise Exception("Aborted.")

    def auto_reduce_n(self, fn, prompt, n):
        """Reduces n by half until the function succeeds."""
        try:
            return fn(prompt, n)
        except BatchSizeException as e:
            if n == 1:
                raise e
            return self.auto_reduce_n(fn, prompt, n // 2) + self.auto_reduce_n(fn, prompt, n // 2)

    def generate_text(self, prompt, n):
        if not isinstance(prompt, list):
            prompt = [prompt]
        if self.needs_confirmation:
            self.confirm_cost(
                prompt, n, self.config['gpt_config']['max_tokens'])
        batch_size = self.config['batch_size']
        assert batch_size == 1
        prompt_batches = [prompt[i:i + batch_size]
                          for i in range(0, len(prompt), batch_size)]
        if not self.disable_tqdm:
            print(
                f"[{self.config['name']}] Generating {len(prompt) * n} completions, "
                f"split into {len(prompt_batches)} batches of (maximum) size {batch_size * n}")
        text = []
        for prompt_batch in tqdm(prompt_batches, disable=self.disable_tqdm):
            text += self.auto_reduce_n(self.__generate_text, prompt_batch, n)
        return text

    def log_probs(self, text, log_prob_range=None):
        raise NotImplementedError

    def __generate_text(self, prompt, n):
        """Generates text from the model."""
        config = self.config['gpt_config'].copy()
        config['n'] = n
        # Split the prompt into a prefix and a suffix at the [APE] token
        # (the [APE] token itself is not included in either part)
        prefix = prompt[0].split('[APE]')[0]
        suffix = prompt[0].split('[APE]')[1]
        response = None
        while response is None:
            try:
                response = openai.Completion.create(
                    **config, prompt=prefix, suffix=suffix)
            except Exception as e:
                print(e)
                print('Retrying...')
                time.sleep(5)

        # Remove the suffix from the generated text
        texts = [response['choices'][i]['text'].replace(suffix, '')
                 for i in range(len(response['choices']))]
        return texts
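# Example (illustrative, not from the original file): a GPT_Insert prompt marks
# the position to be filled in with the [APE] token, e.g.
#
#   prompt = 'Instruction: [APE]\n\nInput: cat\nOutput: chat'
#
# which is split into the prefix 'Instruction: ' and the suffix
# '\n\nInput: cat\nOutput: chat' before being sent to the insert endpoint.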


def gpt_get_estimated_cost(config, prompt, max_tokens):
    """Uses the current API costs/1000 tokens to estimate the cost of generating text from the model."""
    # Get rid of the [APE] token
    prompt = prompt.replace('[APE]', '')
    # Roughly estimate the number of tokens in the prompt (~4 characters per token)
    n_prompt_tokens = len(prompt) // 4
    # Total billed tokens: prompt tokens plus the maximum number of generated tokens
    total_tokens = n_prompt_tokens + max_tokens
    engine = config['gpt_config']['model'].split('-')[1]
    costs_per_thousand = gpt_costs_per_thousand
    if engine not in costs_per_thousand:
        # Try as if it is a fine-tuned model
        engine = config['gpt_config']['model'].split(':')[0]
        costs_per_thousand = {
            'davinci': 0.1200,
            'curie': 0.0120,
            'babbage': 0.0024,
            'ada': 0.0016
        }
    price = costs_per_thousand[engine] * total_tokens / 1000
    return price
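# Worked example (illustrative, not from the original file): a 400-character
# prompt is estimated at 400 // 4 = 100 tokens; with max_tokens=50 that is
# 150 total tokens, so a base 'davinci' model costs roughly
# 0.02 * 150 / 1000 = $0.003 per completion.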


class BatchSizeException(Exception):
    """Raised when the API rejects a request because the batch is too large."""
    pass
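

if __name__ == '__main__':
    # Minimal usage sketch (not part of the original file). It assumes a valid
    # OPENAI_API_KEY in the environment and an OpenAI SDK that still exposes
    # the legacy Completion API used above; the config values are illustrative.
    example_config = {
        'name': 'GPT_forward',
        'batch_size': 2,
        'gpt_config': {
            'model': 'text-davinci-002',  # assumed engine name
            'max_tokens': 32,
            'temperature': 0.7,
        },
    }
    model = model_from_config(example_config, disable_tqdm=False)
    completions = model.generate_text('Write one sentence about the sea. [APE]', n=2)
    print(completions)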