|
|
@ -47,9 +47,12 @@ class Mistral:
|
|
|
|
task: str
|
|
|
|
task: str
|
|
|
|
):
|
|
|
|
):
|
|
|
|
"""Run the model on a given task."""
|
|
|
|
"""Run the model on a given task."""
|
|
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
model_inputs = self.tokenizer([task], return_tensors="pt").to(self.device)
|
|
|
|
model_inputs = self.tokenizer(
|
|
|
|
|
|
|
|
[task],
|
|
|
|
|
|
|
|
return_tensors="pt"
|
|
|
|
|
|
|
|
).to(self.device)
|
|
|
|
generated_ids = self.model.generate(
|
|
|
|
generated_ids = self.model.generate(
|
|
|
|
**model_inputs,
|
|
|
|
**model_inputs,
|
|
|
|
max_length=self.max_length,
|
|
|
|
max_length=self.max_length,
|
|
|
|