petals model

Former-commit-id: be96f7689b
pull/47/head
Kye 1 year ago
parent 36f026344f
commit 35c5dca1d5

@ -1,9 +1,34 @@
import os
from transformers import AutoTokenizer, AutoModelForCausalLM
class PetalsHFLLM:
def __init__(self, model_name: str = None, prompt: str = None, device: str = None, use_fast = False, add_bos_token: str = None, cuda=False):
class Petals:
"""Petals Bloom models."""
def __init__(self, model_name="bigscience/bloom-petals", temperature=0.7, max_new_tokens=256, top_p=0.9, top_k=None, do_sample=True, max_length=None):
self.model_name = model_name
self.prompt = prompt
self.device = device
self.use_fast = use_fast
self.add_bos_token = add_bos_token
self.cuda = cuda
self.temperature = temperature
self.max_new_tokens = max_new_tokens
self.top_p = top_p
self.top_k = top_k
self.do_sample = do_sample
self.max_length = max_length
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForCausalLM.from_pretrained(model_name)
def _default_params(self):
"""Get the default parameters for calling Petals API."""
return {
"temperature": self.temperature,
"max_new_tokens": self.max_new_tokens,
"top_p": self.top_p,
"top_k": self.top_k,
"do_sample": self.do_sample,
"max_length": self.max_length,
}
def generate(self, prompt):
"""Generate text using the Petals API."""
params = self._default_params()
inputs = self.tokenizer(prompt, return_tensors="pt")["input_ids"]
outputs = self.model.generate(inputs, **params)
return self.tokenizer.decode(outputs[0])
Loading…
Cancel
Save