|
|
|
@ -131,7 +131,7 @@ class HuggingfaceLLM(AbstractLLM):
|
|
|
|
|
temperature: float = 0.7,
|
|
|
|
|
top_k: int = 40,
|
|
|
|
|
top_p: float = 0.8,
|
|
|
|
|
dtype = torch.bfloat16,
|
|
|
|
|
dtype=torch.bfloat16,
|
|
|
|
|
*args,
|
|
|
|
|
**kwargs,
|
|
|
|
|
):
|
|
|
|
@ -189,7 +189,6 @@ class HuggingfaceLLM(AbstractLLM):
|
|
|
|
|
self.model_id, *args, **kwargs
|
|
|
|
|
).to(self.device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def print_error(self, error: str):
    """Print an error message to stdout, colored red via termcolor."""
    message = colored(f"Error: {error}", "red")
    print(message)
|
|
|
|
@ -264,7 +263,7 @@ class HuggingfaceLLM(AbstractLLM):
|
|
|
|
|
*args,
|
|
|
|
|
**kwargs,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return self.tokenizer.decode(
|
|
|
|
|
outputs[0], skip_special_tokens=True
|
|
|
|
|
)
|
|
|
|
|