From 92194104bf611dbd635fdea9a452ea05f272bf6a Mon Sep 17 00:00:00 2001
From: Kye <kye@apacmediasolutions.com>
Date: Fri, 6 Oct 2023 01:02:36 -0400
Subject: [PATCH] clean up

Former-commit-id: 6fcc3c457dec52d80a822ff3f0d6a42379cf7daf
---
 apps/omni_ui.py          |  4 +--
 swarms/models/mistral.py | 85 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 87 insertions(+), 2 deletions(-)
 create mode 100644 swarms/models/mistral.py

diff --git a/apps/omni_ui.py b/apps/omni_ui.py
index 7a843938..13c76dda 100644
--- a/apps/omni_ui.py
+++ b/apps/omni_ui.py
@@ -4,7 +4,7 @@ import threading
 import os
 import glob
 import base64
-from langchain.llms import OpenAIChat  # Replace with your actual class
+from langchain.chat_models import ChatOpenAI  # Replace with your actual class
 from swarms.agents import OmniModalAgent  # Replace with your actual class
 
 #Function to convert image to base64
@@ -21,7 +21,7 @@ def get_latest_image():
     return latest_file
 
 #Initialize your OmniModalAgent
-llm = OpenAIChat(model_name="gpt-4")  # Replace with your actual initialization
+llm = ChatOpenAI(model_name="gpt-4")  # Replace with your actual initialization
 agent = OmniModalAgent(llm)  # Replace with your actual initialization
 
 #Global variable to store chat history
diff --git a/swarms/models/mistral.py b/swarms/models/mistral.py
new file mode 100644
index 00000000..abfcc422
--- /dev/null
+++ b/swarms/models/mistral.py
@@ -0,0 +1,85 @@
+# from exa import Inference
+
+
+# class Mistral:
+#     def __init__(
+#         self,
+#         temperature: float = 0.4,
+#         max_length: int = 500,
+#         quantize: bool = False,
+#     ):
+#         self.temperature = temperature
+#         self.max_length = max_length
+#         self.quantize = quantize
+
+#         self.model = Inference(
+#             model_id="mistralai/Mistral-7B-v0.1",
+#             max_length=self.max_length,
+#             quantize=self.quantize
+#         )
+    
+#     def run(
+#         self,
+#         task: str
+#     ):
+#         try:
+#             output = self.model.run(task)
+#             return output
+#         except Exception as e:
+#             raise e
+        
+
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+class MistralWrapper:
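+    """Minimal Hugging Face wrapper around a Mistral checkpoint for text generation."""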
+    def __init__(
+        self,
+        model_name="mistralai/Mistral-7B-v0.1",
+        device="cuda",
+        use_flash_attention=False,
+    ):
+        self.model_name = model_name
+        self.device = device
+        self.use_flash_attention = use_flash_attention
+
+        # Check if the specified device is available
+        if not torch.cuda.is_available() and device == "cuda":
+            raise ValueError("CUDA is not available. Please choose a different device.")
+
+        # Load the model and tokenizer
+        self.model = None
+        self.tokenizer = None
+        self.load_model()
+
+    def load_model(self):
+        try:
+            # Assumes transformers >= 4.34; flash attention 2 also needs the flash-attn package and fp16/bf16 weights
+            if self.use_flash_attention:
+                kwargs = {"use_flash_attention_2": True, "torch_dtype": torch.float16}
+            else:
+                kwargs = {}
+            self.model = AutoModelForCausalLM.from_pretrained(self.model_name, **kwargs)
+            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+            self.model.to(self.device)
+        except Exception as e:
+            raise ValueError(f"Error loading the Mistral model: {str(e)}")
+
+    def run(self, prompt, max_new_tokens=100, do_sample=True):
+        try:
+            model_inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)
+            generated_ids = self.model.generate(**model_inputs, max_new_tokens=max_new_tokens, do_sample=do_sample)
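+            # generate() returns prompt + continuation token ids; decoding below keeps special tokens such as <s>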
+            output_text = self.tokenizer.batch_decode(generated_ids)[0]
+            return output_text
+        except Exception as e:
+            raise ValueError(f"Error running the model: {str(e)}")
+
+# Example usage:
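+# Note: assumes a CUDA GPU; use_flash_attention=True additionally requires the flash-attn package.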
+if __name__ == "__main__":
+    wrapper = MistralWrapper(device="cuda", use_flash_attention=True)
+    prompt = "My favourite condiment is"
+    result = wrapper.run(prompt)
+    print(result)