pull/64/head
Kye 1 year ago
parent 9733541cc7
commit a9a5645792

@ -64,6 +64,7 @@ def omni_agent(task: str = None):
tools = [
hf_agent,
omni_agent,
]

@ -13,3 +13,7 @@ class AbstractModel(ABC):
def chat(self, prompt, history):
pass
def __call__(self, task):
    """Invoke the model on *task*.

    Abstract stub: concrete subclasses override this; the stub itself
    does nothing and returns None.
    """

@ -69,6 +69,23 @@ class Mistral:
except Exception as e:
raise ValueError(f"Error running the model: {str(e)}")
def __call__(self, task: str):
    """Run the model on a single task string and return the decoded text.

    Args:
        task: The prompt to feed the model.

    Returns:
        The generated text as decoded by ``tokenizer.batch_decode``
        (includes the prompt and any special tokens).

    Raises:
        ValueError: Wrapping any exception raised during tokenization,
            generation, or decoding.
    """
    try:
        model_inputs = self.tokenizer([task], return_tensors="pt").to(self.device)
        generated_ids = self.model.generate(
            **model_inputs,
            do_sample=self.do_sample,
            temperature=self.temperature,
            # Pass only max_new_tokens: the original passed both
            # max_length and max_new_tokens, which is contradictory —
            # transformers warns and lets max_new_tokens win, so the
            # max_length kwarg was dead weight.
            max_new_tokens=self.max_length,
        )
        output_text = self.tokenizer.batch_decode(generated_ids)[0]
        return output_text
    except Exception as e:
        raise ValueError(f"Error running the model: {str(e)}")
def chat(self, msg: str = None, streaming: bool = False):
"""
Run chat

Loading…
Cancel
Save