parent
36f026344f
commit
35c5dca1d5
@@ -1,9 +1,34 @@
|
||||
import os
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
class PetalsHFLLM:
|
||||
def __init__(self, model_name: str = None, prompt: str = None, device: str = None, use_fast = False, add_bos_token: str = None, cuda=False):
|
||||
class Petals:
    """Petals Bloom models.

    Thin wrapper around a Petals-hosted causal LM loaded through the
    Hugging Face ``transformers`` Auto* classes.  The tokenizer and model
    are loaded eagerly in the constructor; ``generate(prompt)`` performs a
    single text completion using the stored sampling parameters.
    """

    def __init__(
        self,
        model_name="bigscience/bloom-petals",
        temperature=0.7,
        max_new_tokens=256,
        top_p=0.9,
        top_k=None,
        do_sample=True,
        max_length=None,
        prompt=None,
        device=None,
        use_fast=False,
        add_bos_token=None,
        cuda=False,
    ):
        """Load the model/tokenizer and store generation parameters.

        Args:
            model_name: Hugging Face hub id of the Petals model to load.
            temperature: Sampling temperature forwarded to ``model.generate``.
            max_new_tokens: Max tokens to generate per call.
            top_p: Nucleus-sampling probability mass.
            top_k: Top-k sampling cutoff (``None`` = disabled).
            do_sample: Whether to sample (vs. greedy decode).
            max_length: Hard cap on total sequence length.
            prompt: Optional default prompt text stored on the instance.
            device: Optional device identifier stored on the instance.
            use_fast: Stored flag; not consumed by ``generate`` itself.
            add_bos_token: Stored flag; not consumed by ``generate`` itself.
            cuda: Stored flag; not consumed by ``generate`` itself.

        Bug fix: the original body read ``prompt``, ``device``,
        ``use_fast``, ``add_bos_token`` and ``cuda`` — names that were
        never defined in this scope (leftovers from the replaced
        ``PetalsHFLLM.__init__`` signature), so every construction raised
        ``NameError``.  They are now explicit keyword parameters appended
        with defaults, keeping all existing callers working.
        """
        self.model_name = model_name
        self.prompt = prompt
        self.device = device
        self.use_fast = use_fast
        self.add_bos_token = add_bos_token
        self.cuda = cuda
        self.temperature = temperature
        self.max_new_tokens = max_new_tokens
        self.top_p = top_p
        self.top_k = top_k
        self.do_sample = do_sample
        self.max_length = max_length
        # NOTE(review): eager load in the constructor — may download model
        # weights from the HF hub on first instantiation.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name)

    def _default_params(self):
        """Get the default parameters for calling Petals API."""
        return {
            "temperature": self.temperature,
            "max_new_tokens": self.max_new_tokens,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "do_sample": self.do_sample,
            "max_length": self.max_length,
        }

    def generate(self, prompt):
        """Generate text using the Petals API."""
        params = self._default_params()
        # Tokenize the prompt and feed the raw input_ids to generate().
        inputs = self.tokenizer(prompt, return_tensors="pt")["input_ids"]
        outputs = self.model.generate(inputs, **params)
        return self.tokenizer.decode(outputs[0])
Loading…
Reference in new issue