pull/64/head
Kye 1 year ago
parent 9733541cc7
commit a9a5645792

@ -64,6 +64,7 @@ def omni_agent(task: str = None):
tools = [
hf_agent,
omni_agent,
]

@ -13,3 +13,7 @@ class AbstractModel(ABC):
def chat(self, prompt, history):
pass
def __call__(self, task):
pass

@ -69,6 +69,23 @@ class Mistral:
except Exception as e:
raise ValueError(f"Error running the model: {str(e)}")
def __call__(self, task: str):
"""Run the model on a given task."""
try:
model_inputs = self.tokenizer([task], return_tensors="pt").to(self.device)
generated_ids = self.model.generate(
**model_inputs,
max_length=self.max_length,
do_sample=self.do_sample,
temperature=self.temperature,
max_new_tokens=self.max_length,
)
output_text = self.tokenizer.batch_decode(generated_ids)[0]
return output_text
except Exception as e:
raise ValueError(f"Error running the model: {str(e)}")
def chat(self, msg: str = None, streaming: bool = False):
"""
Run chat

Loading…
Cancel
Save