mistral ai model

pull/58/head
Kye 1 year ago
parent b204b60f98
commit 75739e6d00

@ -0,0 +1,10 @@
from swarms.models import Mistral
model = Mistral(
device="cuda",
use_flash_attention=True
)
prompt = "My favourite condiment is"
result = model.run(prompt)
print(result)

@ -7,3 +7,4 @@ from swarms.models.petals import Petals
# from swarms.models.openai import OpenAIChat
#prompts
from swarms.models.prompts.debate import *
from swarms.models.mistral import Mistral

@ -2,6 +2,9 @@ from abc import ABC, abstractmethod
class AbstractModel(ABC):
#abstract base class for language models
def __init__():
pass
@abstractmethod
def run(self, prompt):
#generate text using language model

@ -1,47 +1,29 @@
# from exa import Inference
# class Mistral:
# def __init__(
# self,
# temperature: float = 0.4,
# max_length: int = 500,
# quantize: bool = False,
# ):
# self.temperature = temperature
# self.max_length = max_length
# self.quantize = quantize
# self.model = Inference(
# model_id="from swarms.workers.worker import Worker",
# max_length=self.max_length,
# quantize=self.quantize
# )
# def run(
# self,
# task: str
# ):
# try:
# output = self.model.run(task)
# return output
# except Exception as e:
# raise e
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
class MistralWrapper:
class Mistral:
"""
Mistral
model = MistralWrapper(device="cuda", use_flash_attention=True, temperature=0.7, max_length=200)
task = "My favourite condiment is"
result = model.run(task)
print(result)
"""
def __init__(
self,
model_name="mistralai/Mistral-7B-v0.1",
device="cuda",
use_flash_attention=False
):
self,
model_name: str ="mistralai/Mistral-7B-v0.1",
device: str ="cuda",
use_flash_attention: bool = False,
temperature: float = 1.0,
max_length: int = 100,
do_sample: bool = True
):
self.model_name = model_name
self.device = device
self.use_flash_attention = use_flash_attention
self.temperature = temperature
self.max_length = max_length
# Check if the specified device is available
if not torch.cuda.is_available() and device == "cuda":
@ -60,18 +42,22 @@ class MistralWrapper:
except Exception as e:
raise ValueError(f"Error loading the Mistral model: {str(e)}")
def run(self, prompt, max_new_tokens=100, do_sample=True):
def run(
self,
task: str
):
"""Run the model on a given task."""
try:
model_inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)
generated_ids = self.model.generate(**model_inputs, max_new_tokens=max_new_tokens, do_sample=do_sample)
model_inputs = self.tokenizer([task], return_tensors="pt").to(self.device)
generated_ids = self.model.generate(
**model_inputs,
max_length=self.max_length,
do_sample=self.do_sample,
temperature=self.temperature,
max_new_tokens=self.max_length
)
output_text = self.tokenizer.batch_decode(generated_ids)[0]
return output_text
except Exception as e:
raise ValueError(f"Error running the model: {str(e)}")
# Example usage:
if __name__ == "__main__":
wrapper = MistralWrapper(device="cuda", use_flash_attention=True)
prompt = "My favourite condiment is"
result = wrapper.run(prompt)
print(result)

Loading…
Cancel
Save