from transformers import AutoTokenizer

# NOTE: "bigscience/bloom-petals" is a checkpoint meant to be served by the
# public Petals swarm, so it is loaded through the petals client below;
# plain transformers' AutoModelForCausalLM would try to materialize the full
# 176B-parameter model locally.
from petals import DistributedBloomForCausalLM


class Petals:
    """Wrapper around Petals Bloom models."""

    def __init__(
        self,
        model_name="bigscience/bloom-petals",
        temperature=0.7,
        max_new_tokens=256,
        top_p=0.9,
        top_k=None,
        do_sample=True,
        max_length=None,
    ):
        self.model_name = model_name
        self.temperature = temperature
        self.max_new_tokens = max_new_tokens
        self.top_p = top_p
        self.top_k = top_k
        self.do_sample = do_sample
        self.max_length = max_length
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = DistributedBloomForCausalLM.from_pretrained(model_name)

    def _default_params(self):
        """Get the default parameters for calling the Petals API."""
        params = {
            "temperature": self.temperature,
            "max_new_tokens": self.max_new_tokens,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "do_sample": self.do_sample,
            "max_length": self.max_length,
        }
        # Drop unset options so that generate() falls back to the model's
        # defaults instead of receiving explicit None values.
        return {key: value for key, value in params.items() if value is not None}

    def generate(self, prompt):
        """Generate text from a prompt using the Petals API."""
        params = self._default_params()
        inputs = self.tokenizer(prompt, return_tensors="pt")["input_ids"]
        # Runs distributed inference across the servers in the swarm.
        outputs = self.model.generate(inputs, **params)
        # The output sequence contains the prompt followed by the continuation.
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
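

# Minimal usage sketch, assuming the petals client is installed
# (`pip install petals`) and the public swarm serving
# "bigscience/bloom-petals" is reachable. The first call can be slow while
# the client locates servers holding the model's layers.
if __name__ == "__main__":
    llm = Petals(max_new_tokens=32)
    print(llm.generate("A cat sat on"))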