From 75739e6d00e8338d2056715083cd6e3db4197998 Mon Sep 17 00:00:00 2001
From: Kye
Date: Fri, 6 Oct 2023 02:06:54 -0400
Subject: [PATCH] mistral ai model

---
 playground/models/mistral.py | 10 +++++
 swarms/models/__init__.py    |  3 +-
 swarms/models/base.py        |  3 ++
 swarms/models/mistral.py     | 82 +++++++++++++++---------------------
 4 files changed, 49 insertions(+), 49 deletions(-)
 create mode 100644 playground/models/mistral.py

diff --git a/playground/models/mistral.py b/playground/models/mistral.py
new file mode 100644
index 00000000..8ae3c413
--- /dev/null
+++ b/playground/models/mistral.py
@@ -0,0 +1,10 @@
+from swarms.models import Mistral
+
+model = Mistral(
+    device="cuda",
+    use_flash_attention=True
+)
+
+prompt = "My favourite condiment is"
+result = model.run(prompt)
+print(result)
\ No newline at end of file
diff --git a/swarms/models/__init__.py b/swarms/models/__init__.py
index a943844d..575d69a7 100644
--- a/swarms/models/__init__.py
+++ b/swarms/models/__init__.py
@@ -6,4 +6,5 @@ from swarms.models.anthropic import Anthropic
 from swarms.models.petals import Petals
 # from swarms.models.openai import OpenAIChat
 #prompts
-from swarms.models.prompts.debate import *
\ No newline at end of file
+from swarms.models.prompts.debate import *
+from swarms.models.mistral import Mistral
\ No newline at end of file
diff --git a/swarms/models/base.py b/swarms/models/base.py
index cb3ecbf5..53f86c77 100644
--- a/swarms/models/base.py
+++ b/swarms/models/base.py
@@ -2,6 +2,9 @@ from abc import ABC, abstractmethod
 
 class AbstractModel(ABC):
     #abstract base class for language models
+    def __init__(self):
+        pass
+
     @abstractmethod
     def run(self, prompt):
         #generate text using language model
diff --git a/swarms/models/mistral.py b/swarms/models/mistral.py
index abfcc422..b0ef5ccc 100644
--- a/swarms/models/mistral.py
+++ b/swarms/models/mistral.py
@@ -1,47 +1,30 @@
-# from exa import Inference
-
-
-# class Mistral:
-#     def __init__(
-#         self,
-#         temperature: float = 0.4,
-#         max_length: int = 500,
-#         quantize: bool = False,
-#     ):
-#         self.temperature = temperature
-#         self.max_length = max_length
-#         self.quantize = quantize
-
-#         self.model = Inference(
-#             model_id="from swarms.workers.worker import Worker",
-#             max_length=self.max_length,
-#             quantize=self.quantize
-#         )
-
-#     def run(
-#         self,
-#         task: str
-#     ):
-#         try:
-#             output = self.model.run(task)
-#             return output
-#         except Exception as e:
-#             raise e
-
-
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-class MistralWrapper:
+class Mistral:
+    """
+    Mistral 7B language model wrapper.
+
+    model = Mistral(device="cuda", use_flash_attention=True, temperature=0.7, max_length=200)
+    task = "My favourite condiment is"
+    result = model.run(task)
+    print(result)
+    """
     def __init__(
-        self,
-        model_name="mistralai/Mistral-7B-v0.1",
-        device="cuda",
-        use_flash_attention=False
-    ):
+        self,
+        model_name: str = "mistralai/Mistral-7B-v0.1",
+        device: str = "cuda",
+        use_flash_attention: bool = False,
+        temperature: float = 1.0,
+        max_length: int = 100,
+        do_sample: bool = True
+    ):
         self.model_name = model_name
         self.device = device
         self.use_flash_attention = use_flash_attention
+        self.temperature = temperature
+        self.max_length = max_length
+        self.do_sample = do_sample
 
         # Check if the specified device is available
         if not torch.cuda.is_available() and device == "cuda":
@@ -60,18 +43,21 @@ class MistralWrapper:
         except Exception as e:
             raise ValueError(f"Error loading the Mistral model: {str(e)}")
 
-    def run(self, prompt, max_new_tokens=100, do_sample=True):
+    def run(
+        self,
+        task: str
+    ):
+        """Run the model on a given task."""
+
         try:
-            model_inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)
-            generated_ids = self.model.generate(**model_inputs, max_new_tokens=max_new_tokens, do_sample=do_sample)
+            model_inputs = self.tokenizer([task], return_tensors="pt").to(self.device)
+            generated_ids = self.model.generate(
+                **model_inputs,
+                max_new_tokens=self.max_length,
+                do_sample=self.do_sample,
+                temperature=self.temperature
+            )
             output_text = self.tokenizer.batch_decode(generated_ids)[0]
             return output_text
         except Exception as e:
-            raise ValueError(f"Error running the model: {str(e)}")
-
-# Example usage:
-if __name__ == "__main__":
-    wrapper = MistralWrapper(device="cuda", use_flash_attention=True)
-    prompt = "My favourite condiment is"
-    result = wrapper.run(prompt)
-    print(result)
+            raise ValueError(f"Error running the model: {str(e)}")
\ No newline at end of file
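
A patch in this git format-patch style can be applied with git am. After applying, a minimal usage sketch of the new wrapper, assuming the patched swarms package is importable; the CPU fallback here is illustrative and not part of the patch (the constructor itself checks whether the requested device is available):

    import torch
    from swarms.models import Mistral

    # Pick a device that actually exists on this machine.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = Mistral(
        device=device,
        use_flash_attention=False,  # the flash attention path assumes a supported GPU
        temperature=0.7,
        max_length=200,  # caps generated tokens via max_new_tokens in run()
    )

    print(model.run("My favourite condiment is"))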