parent
aa5f4477f7
commit
92a7adc1a2
@ -0,0 +1,76 @@
|
||||
# Module Name: Mixtral
|
||||
|
||||
## Introduction
|
||||
The Mixtral module is a powerful language model designed for text generation tasks. It leverages the MistralAI Mixtral-8x7B pre-trained model to generate high-quality text based on user-defined tasks or prompts. In this documentation, we will provide a comprehensive overview of the Mixtral module, including its architecture, purpose, arguments, and detailed usage examples.
|
||||
|
||||
## Purpose
|
||||
The Mixtral module is designed to facilitate text generation tasks using state-of-the-art language models. Whether you need to generate creative content, draft text for various applications, or simply explore the capabilities of Mixtral, this module serves as a versatile and efficient solution. With its easy-to-use interface, you can quickly generate text for a wide range of applications.
|
||||
|
||||
## Architecture
|
||||
The Mixtral module is built on top of the MistralAI Mixtral-8x7B pre-trained model. It utilizes a deep neural network architecture with 8 layers and 7 attention heads to generate coherent and contextually relevant text. The model is capable of handling a variety of text generation tasks, from simple prompts to more complex content generation.
|
||||
|
||||
## Class Definition
|
||||
### `Mixtral(model_name: str = "mistralai/Mixtral-8x7B-v0.1", max_new_tokens: int = 500)`
|
||||
|
||||
#### Parameters
|
||||
- `model_name` (str, optional): The name or path of the pre-trained Mixtral model. Default is "mistralai/Mixtral-8x7B-v0.1".
|
||||
- `max_new_tokens` (int, optional): The maximum number of new tokens to generate. Default is 500.
|
||||
|
||||
## Functionality and Usage
|
||||
The Mixtral module offers a straightforward interface for text generation. It accepts a task or prompt as input and returns generated text based on the provided input.
|
||||
|
||||
### `run(task: Optional[str] = None, **kwargs) -> str`
|
||||
|
||||
#### Parameters
|
||||
- `task` (str, optional): The task or prompt for text generation.
|
||||
|
||||
#### Returns
|
||||
- `str`: The generated text.
|
||||
|
||||
## Usage Examples
|
||||
### Example 1: Basic Usage
|
||||
|
||||
```python
|
||||
from swarms.models import Mixtral
|
||||
|
||||
# Initialize the Mixtral model
|
||||
mixtral = Mixtral()
|
||||
|
||||
# Generate text for a simple task
|
||||
generated_text = mixtral.run("Generate a creative story.")
|
||||
print(generated_text)
|
||||
```
|
||||
|
||||
### Example 2: Custom Model
|
||||
|
||||
You can specify a custom pre-trained model by providing the `model_name` parameter.
|
||||
|
||||
```python
|
||||
custom_model_name = "model_name"
|
||||
mixtral_custom = Mixtral(model_name=custom_model_name)
|
||||
|
||||
generated_text = mixtral_custom.run("Generate text with a custom model.")
|
||||
print(generated_text)
|
||||
```
|
||||
|
||||
### Example 3: Controlling Output Length
|
||||
|
||||
You can control the length of the generated text by adjusting the `max_new_tokens` parameter.
|
||||
|
||||
```python
|
||||
mixtral_length = Mixtral(max_new_tokens=100)
|
||||
|
||||
generated_text = mixtral_length.run("Generate a short text.")
|
||||
print(generated_text)
|
||||
```
|
||||
|
||||
## Additional Information and Tips
|
||||
- It's recommended to use a descriptive task or prompt to guide the text generation process.
|
||||
- Experiment with different prompt styles and lengths to achieve the desired output.
|
||||
- You can fine-tune Mixtral on specific tasks if needed, although pre-trained models often work well out of the box.
|
||||
- Monitor the `max_new_tokens` parameter to control the length of the generated text.
|
||||
|
||||
## Conclusion
|
||||
The Mixtral module is a versatile tool for text generation tasks, powered by the MistralAI Mixtral-8x7B pre-trained model. Whether you need creative writing, content generation, or assistance with text-based tasks, Mixtral can help you achieve your goals. With a simple interface and flexible parameters, it's a valuable addition to your text generation toolkit.
|
||||
|
||||
If you encounter any issues or have questions about using Mixtral, please refer to the MistralAI documentation or reach out to their support team for further assistance. Happy text generation with Mixtral!
|
@ -0,0 +1,73 @@
|
||||
from typing import Optional
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from swarms.models.base_llm import AbstractLLM
|
||||
|
||||
|
||||
class Mixtral(AbstractLLM):
|
||||
"""Mixtral model.
|
||||
|
||||
Args:
|
||||
model_name (str): The name or path of the pre-trained Mixtral model.
|
||||
max_new_tokens (int): The maximum number of new tokens to generate.
|
||||
*args: Variable length argument list.
|
||||
|
||||
|
||||
Examples:
|
||||
>>> from swarms.models import Mixtral
|
||||
>>> mixtral = Mixtral()
|
||||
>>> mixtral.run("Test task")
|
||||
'Generated text'
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = "mistralai/Mixtral-8x7B-v0.1",
|
||||
max_new_tokens: int = 500,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initializes a Mixtral model.
|
||||
|
||||
Args:
|
||||
model_name (str): The name or path of the pre-trained Mixtral model.
|
||||
max_new_tokens (int): The maximum number of new tokens to generate.
|
||||
*args: Variable length argument list.
|
||||
**kwargs: Arbitrary keyword arguments.
|
||||
"""
|
||||
super().__init__(*args, **kwargs)
|
||||
self.model_name = model_name
|
||||
self.max_new_tokens = max_new_tokens
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name, *args, **kwargs
|
||||
)
|
||||
|
||||
def run(self, task: Optional[str] = None, **kwargs):
|
||||
"""
|
||||
Generates text based on the given task.
|
||||
|
||||
Args:
|
||||
task (str, optional): The task or prompt for text generation.
|
||||
|
||||
Returns:
|
||||
str: The generated text.
|
||||
"""
|
||||
try:
|
||||
inputs = self.tokenizer(task, return_tensors="pt")
|
||||
|
||||
outputs = self.model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=self.max_new_tokens,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
out = self.tokenizer.decode(
|
||||
outputs[0],
|
||||
skip_special_tokens=True,
|
||||
)
|
||||
|
||||
return out
|
||||
except Exception as error:
|
||||
print(f"There is an error: {error} in Mixtral model.")
|
||||
raise error
|
@ -0,0 +1,53 @@
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from swarms.models.mixtral import Mixtral
|
||||
|
||||
|
||||
@patch("swarms.models.mixtral.AutoTokenizer")
|
||||
@patch("swarms.models.mixtral.AutoModelForCausalLM")
|
||||
def test_mixtral_init(mock_model, mock_tokenizer):
|
||||
mixtral = Mixtral()
|
||||
mock_tokenizer.from_pretrained.assert_called_once()
|
||||
mock_model.from_pretrained.assert_called_once()
|
||||
assert mixtral.model_name == "mistralai/Mixtral-8x7B-v0.1"
|
||||
assert mixtral.max_new_tokens == 20
|
||||
|
||||
|
||||
@patch("swarms.models.mixtral.AutoTokenizer")
|
||||
@patch("swarms.models.mixtral.AutoModelForCausalLM")
|
||||
def test_mixtral_run(mock_model, mock_tokenizer):
|
||||
mixtral = Mixtral()
|
||||
mock_tokenizer_instance = MagicMock()
|
||||
mock_model_instance = MagicMock()
|
||||
mock_tokenizer.from_pretrained.return_value = (
|
||||
mock_tokenizer_instance
|
||||
)
|
||||
mock_model.from_pretrained.return_value = mock_model_instance
|
||||
mock_tokenizer_instance.return_tensors = "pt"
|
||||
mock_model_instance.generate.return_value = [101, 102, 103]
|
||||
mock_tokenizer_instance.decode.return_value = "Generated text"
|
||||
result = mixtral.run("Test task")
|
||||
assert result == "Generated text"
|
||||
mock_tokenizer_instance.assert_called_once_with(
|
||||
"Test task", return_tensors="pt"
|
||||
)
|
||||
mock_model_instance.generate.assert_called_once()
|
||||
mock_tokenizer_instance.decode.assert_called_once_with(
|
||||
[101, 102, 103], skip_special_tokens=True
|
||||
)
|
||||
|
||||
|
||||
@patch("swarms.models.mixtral.AutoTokenizer")
|
||||
@patch("swarms.models.mixtral.AutoModelForCausalLM")
|
||||
def test_mixtral_run_error(mock_model, mock_tokenizer):
|
||||
mixtral = Mixtral()
|
||||
mock_tokenizer_instance = MagicMock()
|
||||
mock_model_instance = MagicMock()
|
||||
mock_tokenizer.from_pretrained.return_value = (
|
||||
mock_tokenizer_instance
|
||||
)
|
||||
mock_model.from_pretrained.return_value = mock_model_instance
|
||||
mock_tokenizer_instance.return_tensors = "pt"
|
||||
mock_model_instance.generate.side_effect = Exception("Test error")
|
||||
with pytest.raises(Exception, match="Test error"):
|
||||
mixtral.run("Test task")
|
Loading…
Reference in new issue