|
|
@ -2,9 +2,9 @@ import torch
|
|
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
|
|
|
|
|
|
|
|
from swarms.structs.message import Message
|
|
|
|
from swarms.structs.message import Message
|
|
|
|
|
|
|
|
from swarms.models.base_llm import AbstractLLM
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Mistral(AbstractLLM):
|
|
|
|
class Mistral:
|
|
|
|
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Mistral is an all-new llm
|
|
|
|
Mistral is an all-new llm
|
|
|
|
|
|
|
|
|
|
|
@ -38,7 +38,10 @@ class Mistral:
|
|
|
|
temperature: float = 1.0,
|
|
|
|
temperature: float = 1.0,
|
|
|
|
max_length: int = 100,
|
|
|
|
max_length: int = 100,
|
|
|
|
do_sample: bool = True,
|
|
|
|
do_sample: bool = True,
|
|
|
|
|
|
|
|
*args,
|
|
|
|
|
|
|
|
**kwargs
|
|
|
|
):
|
|
|
|
):
|
|
|
|
|
|
|
|
super().__init__()
|
|
|
|
self.ai_name = ai_name
|
|
|
|
self.ai_name = ai_name
|
|
|
|
self.system_prompt = system_prompt
|
|
|
|
self.system_prompt = system_prompt
|
|
|
|
self.model_name = model_name
|
|
|
|
self.model_name = model_name
|
|
|
@ -46,6 +49,7 @@ class Mistral:
|
|
|
|
self.use_flash_attention = use_flash_attention
|
|
|
|
self.use_flash_attention = use_flash_attention
|
|
|
|
self.temperature = temperature
|
|
|
|
self.temperature = temperature
|
|
|
|
self.max_length = max_length
|
|
|
|
self.max_length = max_length
|
|
|
|
|
|
|
|
self.do_sample = do_sample
|
|
|
|
|
|
|
|
|
|
|
|
# Check if the specified device is available
|
|
|
|
# Check if the specified device is available
|
|
|
|
if not torch.cuda.is_available() and device == "cuda":
|
|
|
|
if not torch.cuda.is_available() and device == "cuda":
|
|
|
@ -54,49 +58,18 @@ class Mistral:
|
|
|
|
" device."
|
|
|
|
" device."
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
# Load the model and tokenizer
|
|
|
|
|
|
|
|
self.model = None
|
|
|
|
|
|
|
|
self.tokenizer = None
|
|
|
|
|
|
|
|
self.load_model()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.history = []
|
|
|
|
self.history = []
|
|
|
|
|
|
|
|
|
|
|
|
def load_model(self):
|
|
|
|
self.model = AutoModelForCausalLM.from_pretrained(
|
|
|
|
try:
|
|
|
|
self.model_name, *args, **kwargs
|
|
|
|
self.model = AutoModelForCausalLM.from_pretrained(
|
|
|
|
)
|
|
|
|
self.model_name
|
|
|
|
self.tokenizer = AutoTokenizer.from_pretrained(
|
|
|
|
)
|
|
|
|
self.model_name, *args, **kwargs
|
|
|
|
self.tokenizer = AutoTokenizer.from_pretrained(
|
|
|
|
)
|
|
|
|
self.model_name
|
|
|
|
|
|
|
|
)
|
|
|
|
self.model.to(self.device)
|
|
|
|
self.model.to(self.device)
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
|
|
raise ValueError(
|
|
|
|
|
|
|
|
f"Error loading the Mistral model: {str(e)}"
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run(self, task: str):
|
|
|
|
def run(self, task: str, *args, **kwargs):
|
|
|
|
"""Run the model on a given task."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
|
|
model_inputs = self.tokenizer(
|
|
|
|
|
|
|
|
[task], return_tensors="pt"
|
|
|
|
|
|
|
|
).to(self.device)
|
|
|
|
|
|
|
|
generated_ids = self.model.generate(
|
|
|
|
|
|
|
|
**model_inputs,
|
|
|
|
|
|
|
|
max_length=self.max_length,
|
|
|
|
|
|
|
|
do_sample=self.do_sample,
|
|
|
|
|
|
|
|
temperature=self.temperature,
|
|
|
|
|
|
|
|
max_new_tokens=self.max_length,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
output_text = self.tokenizer.batch_decode(generated_ids)[
|
|
|
|
|
|
|
|
0
|
|
|
|
|
|
|
|
]
|
|
|
|
|
|
|
|
return output_text
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
|
|
raise ValueError(f"Error running the model: {str(e)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __call__(self, task: str):
|
|
|
|
|
|
|
|
"""Run the model on a given task."""
|
|
|
|
"""Run the model on a given task."""
|
|
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
try:
|
|
|
@ -109,6 +82,7 @@ class Mistral:
|
|
|
|
do_sample=self.do_sample,
|
|
|
|
do_sample=self.do_sample,
|
|
|
|
temperature=self.temperature,
|
|
|
|
temperature=self.temperature,
|
|
|
|
max_new_tokens=self.max_length,
|
|
|
|
max_new_tokens=self.max_length,
|
|
|
|
|
|
|
|
**kwargs
|
|
|
|
)
|
|
|
|
)
|
|
|
|
output_text = self.tokenizer.batch_decode(generated_ids)[
|
|
|
|
output_text = self.tokenizer.batch_decode(generated_ids)[
|
|
|
|
0
|
|
|
|
0
|
|
|
@ -158,17 +132,4 @@ class Mistral:
|
|
|
|
# add error to history
|
|
|
|
# add error to history
|
|
|
|
self.history.append(Message("Agent", error_message))
|
|
|
|
self.history.append(Message("Agent", error_message))
|
|
|
|
|
|
|
|
|
|
|
|
return error_message
|
|
|
|
return error_message
|
|
|
|
|
|
|
|
|
|
|
|
def _stream_response(self, response: str = None):
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
Yield the response token by token (word by word)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Usage:
|
|
|
|
|
|
|
|
--------------
|
|
|
|
|
|
|
|
for token in _stream_response(response):
|
|
|
|
|
|
|
|
print(token)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
for token in response.split():
|
|
|
|
|
|
|
|
yield token
|
|
|
|
|