From 74b53e939c4e151af4885a64bec17e6fe5bbb7e2 Mon Sep 17 00:00:00 2001
From: Kye
Date: Fri, 6 Oct 2023 01:02:36 -0400
Subject: [PATCH] clean up

Former-commit-id: b0a085dc07bb5925c1dad41ec63327a11c780e07
---
 apps/omni_ui.py          |  4 +--
 swarms/models/mistral.py | 77 ++++++++++++++++++++++++++++++++++++++++
 2 files changed, 79 insertions(+), 2 deletions(-)
 create mode 100644 swarms/models/mistral.py

diff --git a/apps/omni_ui.py b/apps/omni_ui.py
index 7a843938..13c76dda 100644
--- a/apps/omni_ui.py
+++ b/apps/omni_ui.py
@@ -4,7 +4,7 @@ import threading
 import os
 import glob
 import base64
-from langchain.llms import OpenAIChat # Replace with your actual class
+from langchain.llms import ChatOpenAI #
 from swarms.agents import OmniModalAgent # Replace with your actual class
 
 #Function to convert image to base64
@@ -21,7 +21,7 @@ def get_latest_image():
     return latest_file
 
 #Initialize your OmniModalAgent
-llm = OpenAIChat(model_name="gpt-4") # Replace with your actual initialization
+llm = ChatOpenAI(model_name="gpt-4") # Replace with your actual initialization
 agent = OmniModalAgent(llm) # Replace with your actual initialization
 
 #Global variable to store chat history
diff --git a/swarms/models/mistral.py b/swarms/models/mistral.py
new file mode 100644
index 00000000..abfcc422
--- /dev/null
+++ b/swarms/models/mistral.py
@@ -0,0 +1,77 @@
+# from exa import Inference
+
+
+# class Mistral:
+#     def __init__(
+#         self,
+#         temperature: float = 0.4,
+#         max_length: int = 500,
+#         quantize: bool = False,
+#     ):
+#         self.temperature = temperature
+#         self.max_length = max_length
+#         self.quantize = quantize
+
+#         self.model = Inference(
+#             model_id="from swarms.workers.worker import Worker",
+#             max_length=self.max_length,
+#             quantize=self.quantize
+#         )
+
+#     def run(
+#         self,
+#         task: str
+#     ):
+#         try:
+#             output = self.model.run(task)
+#             return output
+#         except Exception as e:
+#             raise e
+
+
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+class MistralWrapper:
+    def __init__(
+        self,
+        model_name="mistralai/Mistral-7B-v0.1",
+        device="cuda",
+        use_flash_attention=False
+    ):
+        self.model_name = model_name
+        self.device = device
+        self.use_flash_attention = use_flash_attention
+
+        # Check if the specified device is available
+        if not torch.cuda.is_available() and device == "cuda":
+            raise ValueError("CUDA is not available. Please choose a different device.")
+
+        # Load the model and tokenizer
+        self.model = None
+        self.tokenizer = None
+        self.load_model()
+
+    def load_model(self):
+        try:
+            self.model = AutoModelForCausalLM.from_pretrained(self.model_name)
+            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+            self.model.to(self.device)
+        except Exception as e:
+            raise ValueError(f"Error loading the Mistral model: {str(e)}")
+
+    def run(self, prompt, max_new_tokens=100, do_sample=True):
+        try:
+            model_inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)
+            generated_ids = self.model.generate(**model_inputs, max_new_tokens=max_new_tokens, do_sample=do_sample)
+            output_text = self.tokenizer.batch_decode(generated_ids)[0]
+            return output_text
+        except Exception as e:
+            raise ValueError(f"Error running the model: {str(e)}")
+
+# Example usage:
+if __name__ == "__main__":
+    wrapper = MistralWrapper(device="cuda", use_flash_attention=True)
+    prompt = "My favourite condiment is"
+    result = wrapper.run(prompt)
+    print(result)