## Llama3
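
A lightweight wrapper that exposes Meta's Llama 3 Instruct model through the `swarms` `BaseLLM` interface: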
```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from swarms.models.base_llm import BaseLLM


class Llama3(BaseLLM):
    """
    Llama3 class represents a Llama model for natural language generation.

    Args:
        model_id (str): The ID of the Llama model to use.
        system_prompt (str): The system prompt to use for generating responses.
        temperature (float): The temperature value for controlling the randomness of the generated responses.
        top_p (float): The top-p value for controlling the diversity of the generated responses.
        max_tokens (int): The maximum number of tokens to generate in the response.
        **kwargs: Additional keyword arguments.

    Attributes:
        model_id (str): The ID of the Llama model being used.
        system_prompt (str): The system prompt for generating responses.
        temperature (float): The temperature value for generating responses.
        top_p (float): The top-p value for generating responses.
        max_tokens (int): The maximum number of tokens to generate in the response.
        tokenizer (AutoTokenizer): The tokenizer for the Llama model.
        model (AutoModelForCausalLM): The Llama model for generating responses.

    Methods:
        run(task, *args, **kwargs): Generates a response for the given task.
    """

    def __init__(
        self,
        model_id: str = "meta-llama/Meta-Llama-3-8B-Instruct",
        system_prompt: str = None,
        temperature: float = 0.6,
        top_p: float = 0.9,
        max_tokens: int = 4000,
        **kwargs,
    ):
        self.model_id = model_id
        self.system_prompt = system_prompt
        self.temperature = temperature
        self.top_p = top_p
        self.max_tokens = max_tokens
        # Load the tokenizer and the model weights; device_map="auto" places
        # the bfloat16 weights on the available GPU(s).
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )

    def run(self, task: str, *args, **kwargs):
        """
        Generates a response for the given task.

        Args:
            task (str): The user's task or input.

        Returns:
            str: The generated response.
        """
        # Build the chat history and render it with the model's chat template.
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": task},
        ]

        input_ids = self.tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, return_tensors="pt"
        ).to(self.model.device)

        # Llama 3 Instruct ends an assistant turn with <|eot_id|>, so stop on
        # either the regular EOS token or the end-of-turn token.
        terminators = [
            self.tokenizer.eos_token_id,
            self.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
        ]

        outputs = self.model.generate(
            input_ids,
            max_new_tokens=self.max_tokens,
            eos_token_id=terminators,
            do_sample=True,
            temperature=self.temperature,
            top_p=self.top_p,
            *args,
            **kwargs,
        )
        # Strip the prompt tokens and decode only the newly generated text.
        response = outputs[0][input_ids.shape[-1] :]
        return self.tokenizer.decode(response, skip_special_tokens=True)
```
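
A minimal usage sketch, assuming the `Llama3` class above is in scope, a GPU with enough memory for the bfloat16 8B weights, and that the gated `meta-llama/Meta-Llama-3-8B-Instruct` checkpoint is accessible (the prompt text and parameter values below are illustrative):

```python
# Usage sketch for the Llama3 wrapper defined above.
llm = Llama3(
    system_prompt="You are a helpful assistant.",
    temperature=0.6,
    top_p=0.9,
    max_tokens=512,
)

# run() applies the chat template, generates, and returns only the new text.
print(llm.run("Explain the difference between temperature and top_p sampling."))
```

The custom `terminators` list is needed because Llama 3 Instruct marks the end of an assistant turn with `<|eot_id|>` rather than only the tokenizer's default EOS token; without it, generation can run past the intended reply.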