from typing import Optional

from transformers import AutoModelForCausalLM, AutoTokenizer

from swarms.models.base_llm import AbstractLLM


class Mixtral(AbstractLLM):
    """Mixtral model.

    Args:
        model_name (str): The name or path of the pre-trained Mixtral model.
        max_new_tokens (int): The maximum number of new tokens to generate.
        *args: Variable length argument list.
        **kwargs: Arbitrary keyword arguments.

    Examples:
        >>> from swarms.models import Mixtral
        >>> mixtral = Mixtral()
        >>> mixtral.run("Test task")
        'Generated text'
    """

    def __init__(
        self,
        model_name: str = "mistralai/Mixtral-8x7B-v0.1",
        max_new_tokens: int = 500,
        *args,
        **kwargs,
    ):
        """
        Initializes a Mixtral model.

        Args:
            model_name (str): The name or path of the pre-trained Mixtral model.
            max_new_tokens (int): The maximum number of new tokens to generate.
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.
        """
        super().__init__(*args, **kwargs)
        self.model_name = model_name
        self.max_new_tokens = max_new_tokens
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, *args, **kwargs
        )

    def run(self, task: Optional[str] = None, **kwargs):
        """
        Generates text based on the given task.

        Args:
            task (str, optional): The task or prompt for text generation.

        Returns:
            str: The generated text.
        """
        try:
            # Tokenize the prompt and generate a continuation.
            inputs = self.tokenizer(task, return_tensors="pt")

            outputs = self.model.generate(
                **inputs,
                max_new_tokens=self.max_new_tokens,
                **kwargs,
            )

            # Decode the full output sequence, dropping special tokens.
            out = self.tokenizer.decode(
                outputs[0],
                skip_special_tokens=True,
            )

            return out
        except Exception as error:
            print(f"There is an error: {error} in the Mixtral model.")
            raise error
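

# Minimal usage sketch (not part of the original module): assumes the swarms
# package is installed and that the Mixtral-8x7B weights can be pulled from
# the Hugging Face Hub, which requires substantial disk space and memory.
# The prompt string below is purely illustrative.
if __name__ == "__main__":
    mixtral = Mixtral(max_new_tokens=100)
    print(mixtral.run("Write a haiku about distributed systems."))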