@@ -69,6 +69,23 @@ class Mistral:
         except Exception as e:
             raise ValueError(f"Error running the model: {str(e)}")
 
+    def __call__(self, task: str):
+        """Run the model on a given task."""
+
+        try:
+            model_inputs = self.tokenizer([task], return_tensors="pt").to(self.device)
+            generated_ids = self.model.generate(
+                **model_inputs,
+                max_length=self.max_length,
+                do_sample=self.do_sample,
+                temperature=self.temperature,
+                max_new_tokens=self.max_length,
+            )
+            output_text = self.tokenizer.batch_decode(generated_ids)[0]
+            return output_text
+        except Exception as e:
+            raise ValueError(f"Error running the model: {str(e)}")
+
     def chat(self, msg: str = None, streaming: bool = False):
         """
         Run chat
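
The added __call__ mirrors the existing run path, so a Mistral instance becomes directly callable. A minimal usage sketch follows, assuming the class is importable as shown and that the constructor accepts these arguments (the import path and parameter names are assumptions for illustration, not taken from this diff):

# Usage sketch; import path and constructor arguments below are assumed,
# not confirmed by this diff.
from swarms.models import Mistral  # assumed import path

model = Mistral(
    model_name="mistralai/Mistral-7B-v0.1",  # assumed checkpoint name
    device="cuda",
    max_length=200,
)

# With the added __call__, the instance is invoked like a function: it
# tokenizes the prompt, runs model.generate(), and decodes the output.
print(model("What is the capital of France?"))

One design note: generate() receives both max_length and max_new_tokens, each set from self.max_length. In recent Hugging Face transformers versions, max_new_tokens takes precedence when both are given (with a warning), so the max_length argument is effectively redundant here.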