				| """Contains classes for querying large language models."""
 | |
| from math import ceil
 | |
| import os
 | |
| import time
 | |
| from tqdm import tqdm
 | |
| from abc import ABC, abstractmethod
 | |
| 
 | |
| import openai
 | |
| 
 | |
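# USD prices per 1,000 tokens for the base GPT-3 models; prices for
# fine-tuned models are handled separately in gpt_get_estimated_cost below.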
gpt_costs_per_thousand = {
    'davinci': 0.0200,
    'curie': 0.0020,
    'babbage': 0.0005,
    'ada': 0.0004
}


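# A hypothetical example of the config structure this module expects,
# inferred from the keys read below (values are illustrative only):
#
#     config = {
#         'name': 'GPT_forward',            # or 'GPT_insert'
#         'batch_size': 20,                 # prompts sent per API request
#         'gpt_config': {                   # forwarded to openai.Completion.create
#             'model': 'text-davinci-002',
#             'max_tokens': 50,
#         },
#     }
#     model = model_from_config(config)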
def model_from_config(config, disable_tqdm=True):
    """Returns a model based on the config."""
    model_type = config["name"]
    if model_type == "GPT_forward":
        return GPT_Forward(config, disable_tqdm=disable_tqdm)
    elif model_type == "GPT_insert":
        return GPT_Insert(config, disable_tqdm=disable_tqdm)
    raise ValueError(f"Unknown model type: {model_type}")


class LLM(ABC):
    """Abstract base class for large language models."""

    @abstractmethod
    def generate_text(self, prompt, n):
        """Generates text from the model.

        Parameters:
            prompt: The prompt to use. This can be a string or a list of strings.
            n: The number of completions to generate per prompt.

        Returns:
            A list of strings.
        """
        pass

    @abstractmethod
    def log_probs(self, text, log_prob_range):
        """Returns the log probs of the text.

        Parameters:
            text: The text to get the log probs of. This can be a string or a
                list of strings.
            log_prob_range: The range of characters within each string to get
                the log probs of. This is a list of tuples of the form
                (start, end).

        Returns:
            A list of log probs.
        """
        pass


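# A minimal usage sketch (hypothetical, assuming an OpenAI API key is already
# configured and a config like the example above):
#
#     model = GPT_Forward(config)
#     completions = model.generate_text('Say hello.[APE]', n=2)
#     # -> a list of two completion strings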
class GPT_Forward(LLM):
    """Wrapper for GPT-3."""

    def __init__(self, config, needs_confirmation=False, disable_tqdm=True):
        """Initializes the model."""
        self.config = config
        self.needs_confirmation = needs_confirmation
        self.disable_tqdm = disable_tqdm

    def confirm_cost(self, texts, n, max_tokens):
        total_estimated_cost = 0
        for text in texts:
            total_estimated_cost += gpt_get_estimated_cost(
                self.config, text, max_tokens) * n
        print(f"Estimated cost: ${total_estimated_cost:.2f}")
        # Ask the user to confirm in the command line; setting the
        # LLM_SKIP_CONFIRM environment variable skips the prompt
        if os.getenv("LLM_SKIP_CONFIRM") is None:
            confirm = input("Continue? (y/n) ")
            if confirm != 'y':
                raise Exception("Aborted.")

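    # A hypothetical failure sequence for auto_reduce_n below: if the API
    # rejects a request at n=8 as too large, it is retried as two calls at
    # n=4 (recursing toward n=1 if needed) and the completion lists are
    # concatenated, so the caller still receives n results per prompt.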
    def auto_reduce_n(self, fn, prompt, n):
        """Reduces n by half until the function succeeds."""
        try:
            return fn(prompt, n)
        except BatchSizeException as e:
            if n == 1:
                raise e
            return self.auto_reduce_n(fn, prompt, n // 2) + self.auto_reduce_n(fn, prompt, n // 2)

    def generate_text(self, prompt, n):
        if not isinstance(prompt, list):
            prompt = [prompt]
        if self.needs_confirmation:
            self.confirm_cost(
                prompt, n, self.config['gpt_config']['max_tokens'])
        batch_size = self.config['batch_size']
        prompt_batches = [prompt[i:i + batch_size]
                          for i in range(0, len(prompt), batch_size)]
        if not self.disable_tqdm:
            print(
                f"[{self.config['name']}] Generating {len(prompt) * n} completions, "
                f"split into {len(prompt_batches)} batches of (maximum) size {batch_size * n}")
        text = []
        for prompt_batch in tqdm(prompt_batches, disable=self.disable_tqdm):
            text += self.auto_reduce_n(self.__generate_text, prompt_batch, n)
        return text

    def complete(self, prompt, n):
        """Generates text from the model and returns the log prob data."""
        if not isinstance(prompt, list):
            prompt = [prompt]
        batch_size = self.config['batch_size']
        prompt_batches = [prompt[i:i + batch_size]
                          for i in range(0, len(prompt), batch_size)]
        if not self.disable_tqdm:
            print(
                f"[{self.config['name']}] Generating {len(prompt) * n} completions, "
                f"split into {len(prompt_batches)} batches of (maximum) size {batch_size * n}")
        res = []
        for prompt_batch in tqdm(prompt_batches, disable=self.disable_tqdm):
            res += self.__complete(prompt_batch, n)
        return res

    def log_probs(self, text, log_prob_range=None):
        """Returns the log probs of the text."""
        if not isinstance(text, list):
            text = [text]
        if self.needs_confirmation:
            self.confirm_cost(text, 1, 0)
        batch_size = self.config['batch_size']
        text_batches = [text[i:i + batch_size]
                        for i in range(0, len(text), batch_size)]
        if log_prob_range is None:
            log_prob_range_batches = [None] * len(text_batches)
        else:
            assert len(log_prob_range) == len(text)
            log_prob_range_batches = [log_prob_range[i:i + batch_size]
                                      for i in range(0, len(log_prob_range), batch_size)]
        if not self.disable_tqdm:
            print(
                f"[{self.config['name']}] Getting log probs for {len(text)} strings, "
                f"split into {len(text_batches)} batches of (maximum) size {batch_size}")
        log_probs = []
        tokens = []
        for text_batch, log_prob_range in tqdm(list(zip(text_batches, log_prob_range_batches)),
                                               disable=self.disable_tqdm):
            log_probs_batch, tokens_batch = self.__log_probs(
                text_batch, log_prob_range)
            log_probs += log_probs_batch
            tokens += tokens_batch
        return log_probs, tokens

    def __generate_text(self, prompt, n):
        """Generates text from the model."""
        if not isinstance(prompt, list):
            prompt = [prompt]
        config = self.config['gpt_config'].copy()
        config['n'] = n
        # If there are any [APE] tokens in the prompts, remove them
        for i in range(len(prompt)):
            prompt[i] = prompt[i].replace('[APE]', '').strip()
        response = None
        while response is None:
            try:
                response = openai.Completion.create(
                    **config, prompt=prompt)
            except Exception as e:
                if 'is greater than the maximum' in str(e):
                    raise BatchSizeException()
                print(e)
                print('Retrying...')
                time.sleep(5)

        return [response['choices'][i]['text'] for i in range(len(response['choices']))]

    def __complete(self, prompt, n):
        """Generates text from the model and returns the log prob data."""
        if not isinstance(prompt, list):
            prompt = [prompt]
        config = self.config['gpt_config'].copy()
        config['n'] = n
        # If there are any [APE] tokens in the prompts, remove them
        for i in range(len(prompt)):
            prompt[i] = prompt[i].replace('[APE]', '').strip()
        response = None
        while response is None:
            try:
                response = openai.Completion.create(
                    **config, prompt=prompt)
            except Exception as e:
                print(e)
                print('Retrying...')
                time.sleep(5)
        return response['choices']

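    # Note on __log_probs below: each string is prepended with '\n' before the
    # API call, most likely because the API returns no log prob for the very
    # first echoed token; the code then drops the first token's entries and
    # shifts every text offset back by one to compensate.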
    def __log_probs(self, text, log_prob_range=None):
        """Returns the log probs of the text."""
        if not isinstance(text, list):
            text = [text]
        if log_prob_range is not None:
            for i in range(len(text)):
                lower_index, upper_index = log_prob_range[i]
                assert lower_index < upper_index
                assert lower_index >= 0
                assert upper_index - 1 < len(text[i])
        config = self.config['gpt_config'].copy()
        config['logprobs'] = 1
        config['echo'] = True
        config['max_tokens'] = 0
        text = [f'\n{text[i]}' for i in range(len(text))]
        response = None
        while response is None:
            try:
                response = openai.Completion.create(
                    **config, prompt=text)
            except Exception as e:
                print(e)
                print('Retrying...')
                time.sleep(5)
        log_probs = [response['choices'][i]['logprobs']['token_logprobs'][1:]
                     for i in range(len(response['choices']))]
        tokens = [response['choices'][i]['logprobs']['tokens'][1:]
                  for i in range(len(response['choices']))]
        offsets = [response['choices'][i]['logprobs']['text_offset'][1:]
                   for i in range(len(response['choices']))]

        # Subtract 1 from the offsets to account for the prepended newline
        for i in range(len(offsets)):
            offsets[i] = [offset - 1 for offset in offsets[i]]

        if log_prob_range is not None:
            # Find the token indices in the log probs that correspond to the
            # character range given by log_prob_range
            for i in range(len(log_probs)):
                lower_index, upper_index = self.get_token_indices(
                    offsets[i], log_prob_range[i])
                log_probs[i] = log_probs[i][lower_index:upper_index]
                tokens[i] = tokens[i][lower_index:upper_index]

        return log_probs, tokens

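    # Worked example for get_token_indices (hypothetical numbers): with token
    # text offsets [0, 4, 9] and log_prob_range (4, 9), lower_index becomes 1
    # (the last token starting at or before character 4) and upper_index
    # becomes 2 (the first token starting at or after character 9), so the
    # slice [1:2] selects exactly the token covering characters 4-8.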
    def get_token_indices(self, offsets, log_prob_range):
        """Returns the indices of the tokens in the log probs that correspond to the tokens in the log_prob_range."""
        # For the lower index, find the last token whose offset is at or
        # before the start of the range
        lower_index = 0
        for i in range(len(offsets)):
            if offsets[i] <= log_prob_range[0]:
                lower_index = i
            else:
                break

        # For the upper index, find the first token whose offset is at or
        # after the end of the range
        upper_index = len(offsets)
        for i in range(len(offsets)):
            if offsets[i] >= log_prob_range[1]:
                upper_index = i
                break

        return lower_index, upper_index


class GPT_Insert(LLM):
    """Wrapper for GPT-3 in insert mode (prefix/suffix completion)."""

    def __init__(self, config, needs_confirmation=False, disable_tqdm=True):
        """Initializes the model."""
        self.config = config
        self.needs_confirmation = needs_confirmation
        self.disable_tqdm = disable_tqdm

    def confirm_cost(self, texts, n, max_tokens):
        total_estimated_cost = 0
        for text in texts:
            total_estimated_cost += gpt_get_estimated_cost(
                self.config, text, max_tokens) * n
        print(f"Estimated cost: ${total_estimated_cost:.2f}")
        # Ask the user to confirm in the command line; setting the
        # LLM_SKIP_CONFIRM environment variable skips the prompt
        if os.getenv("LLM_SKIP_CONFIRM") is None:
            confirm = input("Continue? (y/n) ")
            if confirm != 'y':
                raise Exception("Aborted.")

    def auto_reduce_n(self, fn, prompt, n):
        """Reduces n by half until the function succeeds."""
        try:
            return fn(prompt, n)
        except BatchSizeException as e:
            if n == 1:
                raise e
            return self.auto_reduce_n(fn, prompt, n // 2) + self.auto_reduce_n(fn, prompt, n // 2)

    def generate_text(self, prompt, n):
        if not isinstance(prompt, list):
            prompt = [prompt]
        if self.needs_confirmation:
            self.confirm_cost(
                prompt, n, self.config['gpt_config']['max_tokens'])
        batch_size = self.config['batch_size']
        # Insert mode handles one prefix/suffix pair per request
        assert batch_size == 1
        prompt_batches = [prompt[i:i + batch_size]
                          for i in range(0, len(prompt), batch_size)]
        if not self.disable_tqdm:
            print(
                f"[{self.config['name']}] Generating {len(prompt) * n} completions, "
                f"split into {len(prompt_batches)} batches of (maximum) size {batch_size * n}")
        text = []
        for prompt_batch in tqdm(prompt_batches, disable=self.disable_tqdm):
            text += self.auto_reduce_n(self.__generate_text, prompt_batch, n)
        return text

    def log_probs(self, text, log_prob_range=None):
        raise NotImplementedError

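    # For example (hypothetical prompt), 'Input: 2 + 2\n[APE]\nOutput: 4' is
    # split into prefix 'Input: 2 + 2\n' and suffix '\nOutput: 4'; the model
    # then generates the text that belongs at the [APE] position.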
    def __generate_text(self, prompt, n):
        """Generates text from the model."""
        config = self.config['gpt_config'].copy()
        config['n'] = n
        # Split the prompt into a prefix and a suffix at the [APE] token
        # (the token itself is not included in either part)
        prefix = prompt[0].split('[APE]')[0]
        suffix = prompt[0].split('[APE]')[1]
        response = None
        while response is None:
            try:
                response = openai.Completion.create(
                    **config, prompt=prefix, suffix=suffix)
            except Exception as e:
                print(e)
                print('Retrying...')
                time.sleep(5)

        # Remove the suffix from the generated texts
        texts = [response['choices'][i]['text'].replace(suffix, '')
                 for i in range(len(response['choices']))]
        return texts


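# Worked example: a 400-character prompt for 'text-davinci-002' with
# max_tokens=100 is estimated at 400 // 4 + 100 = 200 tokens, giving
# 0.0200 * 200 / 1000 = $0.004.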
def gpt_get_estimated_cost(config, prompt, max_tokens):
    """Uses the current API costs per 1,000 tokens to estimate the cost of generating text from the model."""
    # Get rid of the [APE] token
    prompt = prompt.replace('[APE]', '')
    # Estimate the number of tokens in the prompt (roughly 4 characters per token)
    n_prompt_tokens = len(prompt) // 4
    # Add the maximum number of tokens in the generated text
    total_tokens = n_prompt_tokens + max_tokens
    engine = config['gpt_config']['model'].split('-')[1]
    costs_per_thousand = gpt_costs_per_thousand
    if engine not in costs_per_thousand:
        # Try as if it is a fine-tuned model
        engine = config['gpt_config']['model'].split(':')[0]
        costs_per_thousand = {
            'davinci': 0.1200,
            'curie': 0.0120,
            'babbage': 0.0024,
            'ada': 0.0016
        }
    price = costs_per_thousand[engine] * total_tokens / 1000
    return price


class BatchSizeException(Exception):
    """Raised when a request fails because the batch size is too large."""