From a9a5645792cc0ef709f7ff53a420bdfdb0f67b9c Mon Sep 17 00:00:00 2001
From: Kye
Date: Sat, 14 Oct 2023 14:13:33 -0400
Subject: [PATCH] clean up

---
 stacked_worker.py        |  1 +
 swarms/models/base.py    |  4 ++++
 swarms/models/mistral.py | 17 +++++++++++++++++
 3 files changed, 22 insertions(+)

diff --git a/stacked_worker.py b/stacked_worker.py
index 9f545f1f..f94c6ae2 100644
--- a/stacked_worker.py
+++ b/stacked_worker.py
@@ -64,6 +64,7 @@ def omni_agent(task: str = None):
 tools = [
     hf_agent,
     omni_agent,
+
 ]
diff --git a/swarms/models/base.py b/swarms/models/base.py
index 63f72671..13b9d433 100644
--- a/swarms/models/base.py
+++ b/swarms/models/base.py
@@ -13,3 +13,7 @@ class AbstractModel(ABC):
     def chat(self, prompt, history):
         pass
+
+    def __call__(self, task):
+        pass
+
diff --git a/swarms/models/mistral.py b/swarms/models/mistral.py
index 61e4305d..6d99ffd9 100644
--- a/swarms/models/mistral.py
+++ b/swarms/models/mistral.py
@@ -69,6 +69,23 @@ class Mistral:
         except Exception as e:
             raise ValueError(f"Error running the model: {str(e)}")

+    def __call__(self, task: str):
+        """Run the model on a given task."""
+
+        try:
+            model_inputs = self.tokenizer([task], return_tensors="pt").to(self.device)
+            generated_ids = self.model.generate(
+                **model_inputs,
+                max_length=self.max_length,
+                do_sample=self.do_sample,
+                temperature=self.temperature,
+                max_new_tokens=self.max_length,
+            )
+            output_text = self.tokenizer.batch_decode(generated_ids)[0]
+            return output_text
+        except Exception as e:
+            raise ValueError(f"Error running the model: {str(e)}")
+
     def chat(self, msg: str = None, streaming: bool = False):
         """
         Run chat
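
A minimal usage sketch of the __call__ interface this patch adds. The constructor call below is an assumption: the diff only shows that Mistral reads self.tokenizer, self.model, self.device, self.max_length, self.do_sample, and self.temperature, so check Mistral.__init__ in swarms/models/mistral.py for the real arguments before running this.

    from swarms.models.mistral import Mistral

    # Hypothetical: assumes the default constructor loads a Mistral
    # checkpoint and tokenizer onto self.device.
    model = Mistral()

    # The new __call__ makes the instance directly callable, mirroring
    # the AbstractModel.__call__ hook added in swarms/models/base.py.
    output = model("Write a haiku about distributed agents.")
    print(output)

Note that the generate() call in the patch passes both max_length and max_new_tokens (both set to self.max_length); recent versions of transformers treat these as competing limits and warn that max_new_tokens takes precedence, so callers should expect output length to be governed by max_new_tokens here.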