parent
							
								
									9c0c6c06cc
								
							
						
					
					
						commit
						2aa88b893d
					
				| @ -0,0 +1,53 @@ | |||||||
|  | """Zephyr by HF""" | ||||||
|  | import torch  | ||||||
|  | from transformers import pipeline | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class Zephyr: | ||||||
|  |     """ | ||||||
|  |     Zehpyr model from HF | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |     Usage: | ||||||
|  |     >>> model = Zephyr() | ||||||
|  |     >>> output = model("Generate hello world in python") | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |     """ | ||||||
|  |     def __init__( | ||||||
|  |         self, | ||||||
|  |         max_new_tokens: int = 300, | ||||||
|  |         temperature: float = 0.5, | ||||||
|  |         top_k: float = 50, | ||||||
|  |         top_p: float = 0.95, | ||||||
|  |     ): | ||||||
|  |         super().__init__() | ||||||
|  |         self.max_new_tokens = max_new_tokens | ||||||
|  |         self.temperature = temperature | ||||||
|  |         self.top_k = top_k | ||||||
|  |         self.top_p = top_p | ||||||
|  | 
 | ||||||
|  |         self.pipe = pipeline( | ||||||
|  |             "text-generation", | ||||||
|  |             model="HuggingFaceH4/zephyr-7b-alpha", | ||||||
|  |             torch_dtype=torch.bfloa16, | ||||||
|  |             device_map="auto" | ||||||
|  |         ) | ||||||
|  |         self.messages = [ | ||||||
|  |             { | ||||||
|  |                 "role": "system", | ||||||
|  |                 "content": "You are a friendly chatbot who always responds in the style of a pirate", | ||||||
|  |             }, | ||||||
|  |             {"role": "user", "content": "How many helicopters can a human eat in one sitting?"}, | ||||||
|  |         ] | ||||||
|  | 
 | ||||||
|  |     def __call__(self, text: str): | ||||||
|  |         """Call the model""" | ||||||
|  |         prompt = self.pipe.tokenizer.apply_chat_template(self.messages, tokenize=False, add_generation_prompt=True) | ||||||
|  |         outputs = self.pipe(prompt, max_new_token=self.max_new_tokens) | ||||||
|  |         print(outputs[0])["generated_text"] | ||||||
					Loading…
					
					
				
		Reference in new issue