From 35c5dca1d5c30146138c5d22babf110a1b7e7a1a Mon Sep 17 00:00:00 2001
From: Kye
Date: Sat, 29 Jul 2023 13:50:56 -0400
Subject: [PATCH] petals model

Former-commit-id: be96f7689b3698de6a37ed52ce41d29c70b54037
---
 swarms/agents/models/petals.py | 39 ++++++++++++++++++++++++++++------
 1 file changed, 32 insertions(+), 7 deletions(-)

diff --git a/swarms/agents/models/petals.py b/swarms/agents/models/petals.py
index c62ac413..c5c25803 100644
--- a/swarms/agents/models/petals.py
+++ b/swarms/agents/models/petals.py
@@ -1,9 +1,34 @@
+import os
+from transformers import AutoTokenizer, AutoModelForCausalLM
 
-class PetalsHFLLM:
-    def __init__(self, model_name: str = None, prompt: str = None, device: str = None, use_fast = False, add_bos_token: str = None, cuda=False):
+class Petals:
+    """Petals Bloom models."""
+
+    def __init__(self, model_name="bigscience/bloom-petals", temperature=0.7, max_new_tokens=256, top_p=0.9, top_k=None, do_sample=True, max_length=None):
         self.model_name = model_name
-        self.prompt = prompt
-        self.device = device
-        self.use_fast = use_fast
-        self.add_bos_token = add_bos_token
-        self.cuda = cuda
\ No newline at end of file
+        self.temperature = temperature
+        self.max_new_tokens = max_new_tokens
+        self.top_p = top_p
+        self.top_k = top_k
+        self.do_sample = do_sample
+        self.max_length = max_length
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.model = AutoModelForCausalLM.from_pretrained(model_name)
+
+    def _default_params(self):
+        """Get the default parameters for calling Petals API."""
+        return {
+            "temperature": self.temperature,
+            "max_new_tokens": self.max_new_tokens,
+            "top_p": self.top_p,
+            "top_k": self.top_k,
+            "do_sample": self.do_sample,
+            "max_length": self.max_length,
+        }
+
+    def generate(self, prompt):
+        """Generate text using the Petals API."""
+        params = self._default_params()
+        inputs = self.tokenizer(prompt, return_tensors="pt")["input_ids"]
+        outputs = self.model.generate(inputs, **params)
+        return self.tokenizer.decode(outputs[0])
\ No newline at end of file
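
A minimal usage sketch of the Petals class this patch adds. It assumes the diff's own loading path works as written (plain transformers Auto classes against the "bigscience/bloom-petals" repo id; in practice that checkpoint is typically served through the petals client library, so treat this as a sketch, not a verified integration). The import path and prompt string below are illustrative:

    # Usage sketch, assuming the module lands at swarms/agents/models/petals.py
    # as in the diff above.
    from swarms.agents.models.petals import Petals

    # Construct with the patch's defaults; only the sampling knobs are
    # overridden here. top_k and max_length stay None and are passed
    # through to model.generate() unset.
    llm = Petals(model_name="bigscience/bloom-petals", temperature=0.7, max_new_tokens=128)

    # generate() tokenizes the prompt, forwards _default_params() to
    # model.generate(), and decodes the first returned sequence.
    print(llm.generate("Explain what a distributed language model is."))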