import asyncio
import concurrent.futures
import logging
from typing import List, Tuple

import torch
from termcolor import colored
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)

from swarms.models.base_llm import AbstractLLM


class HuggingfaceLLM(AbstractLLM):
    """
    A class for running inference on a given model.

    Attributes:
        model_id (str): The ID of the model.
        device (str): The device to run the model on (either 'cuda' or 'cpu').
        max_length (int): The maximum length of the output sequence.
        quantize (bool, optional): Whether to use quantization. Defaults to False.
        quantization_config (dict, optional): The configuration for quantization.
        verbose (bool, optional): Whether to print verbose logs. Defaults to False.
        logger (logging.Logger, optional): The logger to use. Defaults to a basic logger.

    Methods:
        run(task: str, max_length: int = 500) -> str:
            Generate a response based on the prompt text.

        __call__(task: str, max_length: int = 500) -> str:
            Generate a response based on the prompt text.

        save_model(path: str):
            Save the model to a given path.

        gpu_available() -> bool:
            Check if GPU is available.

        memory_consumption() -> dict:
            Get the memory consumption of the GPU.

        print_dashboard(task: str):
            Print dashboard.

        set_device(device: str):
            Changes the device used for inference.

        set_max_length(max_length: int):
            Set max_length.

        set_verbose(verbose: bool):
            Set verbose.

        set_distributed(distributed: bool):
            Set distributed.

        set_decoding(decoding: bool):
            Set decoding.

        set_max_workers(max_workers: int):
            Set max_workers.

        set_repitition_penalty(repitition_penalty: float):
            Set repitition_penalty.

        set_no_repeat_ngram_size(no_repeat_ngram_size: int):
            Set no_repeat_ngram_size.

        set_temperature(temperature: float):
            Set temperature.

        set_top_k(top_k: int):
            Set top_k.

        set_top_p(top_p: float):
            Set top_p.

        set_quantize(quantize: bool):
            Set quantize.

        set_quantization_config(quantization_config: dict):
            Set quantization_config.

        set_model_id(model_id: str):
            Set model_id.

        set_model(model):
            Set model.

        set_tokenizer(tokenizer):
            Set tokenizer.

        set_logger(logger):
            Set logger.

    Examples:
        >>> llm = HuggingfaceLLM(
        ...     model_id="EleutherAI/gpt-neo-2.7B",
        ...     device="cuda",
        ...     max_length=500,
        ...     quantize=True,
        ...     quantization_config={
        ...         "load_in_4bit": True,
        ...         "bnb_4bit_use_double_quant": True,
        ...         "bnb_4bit_quant_type": "nf4",
        ...         "bnb_4bit_compute_dtype": torch.bfloat16,
        ...     },
        ... )
        >>> llm("Generate a 10,000 word blog on mental clarity and the benefits of meditation.")
        'Generate a 10,000 word
    """

    def __init__(
        self,
        model_id: str,
        device: str = None,
        max_length: int = 500,
        quantize: bool = False,
        quantization_config: dict = None,
        verbose=False,
        distributed=False,
        decoding=False,
        max_workers: int = 5,
        repitition_penalty: float = 1.3,
        no_repeat_ngram_size: int = 5,
        temperature: float = 0.7,
        top_k: int = 40,
        top_p: float = 0.8,
        dtype=torch.bfloat16,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.logger = logging.getLogger(__name__)
        self.device = (
            device
            if device
            else ("cuda" if torch.cuda.is_available() else "cpu")
        )
        self.model_id = model_id
        self.max_length = max_length
        self.verbose = verbose
        self.distributed = distributed
        self.decoding = decoding
        self.quantize = quantize
        self.quantization_config = quantization_config
        self.max_workers = max_workers
        self.repitition_penalty = repitition_penalty
        self.no_repeat_ngram_size = no_repeat_ngram_size
        self.temperature = temperature
        self.top_k = top_k
        self.top_p = top_p
        self.dtype = dtype

        if self.distributed:
            assert (
                torch.cuda.device_count() > 1
            ), "You need more than 1 gpu for distributed processing"

        bnb_config = None
        if quantize:
            if not quantization_config:
                quantization_config = {
                    "load_in_4bit": True,
                    "bnb_4bit_use_double_quant": True,
                    "bnb_4bit_quant_type": "nf4",
                    "bnb_4bit_compute_dtype": dtype,
                }
            bnb_config = BitsAndBytesConfig(**quantization_config)

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)

        if quantize:
            # A bitsandbytes-quantized model is dispatched to the GPU by
            # from_pretrained; calling .to() on it is not supported, so the
            # explicit device move is only applied in the unquantized path.
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_id,
                *args,
                quantization_config=bnb_config,
                **kwargs,
            )
        else:
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_id, *args, **kwargs
            ).to(self.device)

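    # Usage sketch (not executed): constructing the wrapper without
    # quantization on whatever device is available. The model ID and
    # max_length below are illustrative assumptions, not defaults.
    #
    #     llm = HuggingfaceLLM(model_id="gpt2", max_length=200)
    #     print(llm.device)  # "cuda" if a GPU is visible, otherwise "cpu"
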
    def print_error(self, error: str):
        """Print error"""
        print(colored(f"Error: {error}", "red"))

    async def async_run(self, task: str, *args, **kwargs):
        """Asynchronously generate text for a given prompt."""
        # Offload the blocking run() call to a worker thread so the event
        # loop is not blocked; extra arguments are forwarded to run().
        return await asyncio.to_thread(self.run, task, *args, **kwargs)

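    # Usage sketch (not executed): because async_run offloads run() to a
    # worker thread via asyncio.to_thread (Python 3.9+), several prompts can
    # be awaited concurrently from an event loop. The prompt strings are
    # illustrative assumptions.
    #
    #     async def main(llm):
    #         return await asyncio.gather(
    #             llm.async_run("Explain beam search."),
    #             llm.async_run("Explain top-k sampling."),
    #         )
    #
    #     # results = asyncio.run(main(llm))
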
    def concurrent_run(self, tasks: List[str], max_workers: int = 5):
        """Concurrently generate text for a list of prompts."""
        with concurrent.futures.ThreadPoolExecutor(
            max_workers=max_workers
        ) as executor:
            results = list(executor.map(self.run, tasks))
        return results

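    # Usage sketch (not executed): mapping run() over a list of prompts with
    # a bounded thread pool. A single underlying model still serializes on
    # the GPU, so the threads mainly hide tokenization and I/O overhead.
    # The prompts are illustrative assumptions.
    #
    #     prompts = ["Summarize transformers.", "Define attention."]
    #     outputs = llm.concurrent_run(prompts, max_workers=2)
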
    def run_batch(
        self, tasks_images: List[Tuple[str, str]]
    ) -> List[str]:
        """Process a batch of tasks and images"""
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [
                executor.submit(self.run, task, img)
                for task, img in tasks_images
            ]
            results = [future.result() for future in futures]
        return results

    def run(self, task: str, *args, **kwargs):
        """
        Generate a response based on the prompt text.

        Args:
        - task (str): Text to prompt the model.
        - max_length (int): Maximum length of the response.

        Returns:
        - Generated text (str).
        """
        try:
            # Move the prompt tokens to the same device as the model.
            inputs = self.tokenizer.encode(
                task, return_tensors="pt"
            ).to(self.device)

            if self.decoding:
                with torch.no_grad():
                    # Streaming loop: generate one new token at a time and
                    # print it as soon as it is available.
                    output_sequence = []
                    for _ in range(self.max_length):
                        outputs = self.model.generate(
                            inputs,
                            max_length=inputs.shape[1] + 1,
                            do_sample=True,
                        )
                        output_tokens = outputs[0][-1]
                        output_sequence.append(output_tokens.item())

                        # print token in real-time
                        print(
                            self.tokenizer.decode(
                                [output_tokens],
                                skip_special_tokens=True,
                            ),
                            end="",
                            flush=True,
                        )
                        inputs = outputs
            else:
                with torch.no_grad():
                    outputs = self.model.generate(
                        inputs,
                        *args,
                        max_length=self.max_length,
                        do_sample=True,
                        **kwargs,
                    )

            return self.tokenizer.decode(
                outputs[0], skip_special_tokens=True
            )
        except Exception as e:
            print(
                colored(
                    (
                        "HuggingfaceLLM could not generate text"
                        f" because of error: {e}, try optimizing your"
                        " arguments"
                    ),
                    "red",
                )
            )
            raise

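    # Usage sketch (not executed): in the non-streaming path, extra keyword
    # arguments are forwarded to model.generate, so sampling settings can be
    # overridden per call. The parameter values below are illustrative
    # assumptions, not defaults of this class.
    #
    #     text = llm.run(
    #         "Write a haiku about GPUs.",
    #         temperature=0.9,
    #         top_p=0.95,
    #     )
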
    def __call__(self, task: str, *args, **kwargs):
        return self.run(task, *args, **kwargs)

    async def __call_async__(self, task: str, *args, **kwargs) -> str:
        """Call the model asynchronously."""
        return await self.async_run(task, *args, **kwargs)

    def save_model(self, path: str):
        """Save the model to a given path"""
        self.model.save_pretrained(path)
        self.tokenizer.save_pretrained(path)

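    # Usage sketch (not executed): persisting the model and tokenizer, then
    # reloading them with the standard transformers API. The path is an
    # illustrative assumption.
    #
    #     llm.save_model("./checkpoints/my-model")
    #     reloaded = AutoModelForCausalLM.from_pretrained("./checkpoints/my-model")
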
    def gpu_available(self) -> bool:
        """Check if GPU is available"""
        return torch.cuda.is_available()

    def memory_consumption(self) -> dict:
        """Get the memory consumption of the GPU"""
        if self.gpu_available():
            torch.cuda.synchronize()
            allocated = torch.cuda.memory_allocated()
            reserved = torch.cuda.memory_reserved()
            return {"allocated": allocated, "reserved": reserved}
        else:
            return {"error": "GPU not available"}

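    # Usage sketch (not executed): the returned values are raw byte counts
    # from torch.cuda, so convert them before displaying.
    #
    #     stats = llm.memory_consumption()
    #     if "allocated" in stats:
    #         print(f"allocated: {stats['allocated'] / 1024**2:.1f} MiB")
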
    def print_dashboard(self, task: str):
        """Print dashboard"""

        print(
            colored(
                f"""
                HuggingfaceLLM Dashboard
                --------------------------------------------
                Model Name: {self.model_id}
                Tokenizer: {self.tokenizer}
                Model MaxLength: {self.max_length}
                Model Device: {self.device}
                Model Quantization: {self.quantize}
                Model Quantization Config: {self.quantization_config}
                Model Verbose: {self.verbose}
                Model Distributed: {self.distributed}
                Model Decoding: {self.decoding}

                ----------------------------------------
                Metadata:
                Task Memory Consumption: {self.memory_consumption()}
                GPU Available: {self.gpu_available()}
                ----------------------------------------

                Task Environment:
                Task: {task}

                """,
                "red",
            )
        )

    def set_device(self, device):
        """
        Changes the device used for inference.

        Parameters
        ----------
        device : str
            The new device to use for inference.
        """
        self.device = device
        if self.model is not None:
            self.model.to(self.device)

    def set_max_length(self, max_length):
        """Set max_length"""
        self.max_length = max_length

    def clear_chat_history(self):
        """Clear chat history"""
        self.chat_history = []

    def set_verbose(self, verbose):
        """Set verbose"""
        self.verbose = verbose

    def set_distributed(self, distributed):
        """Set distributed"""
        self.distributed = distributed

    def set_decoding(self, decoding):
        """Set decoding"""
        self.decoding = decoding

    def set_max_workers(self, max_workers):
        """Set max_workers"""
        self.max_workers = max_workers

    def set_repitition_penalty(self, repitition_penalty):
        """Set repitition_penalty"""
        self.repitition_penalty = repitition_penalty

    def set_no_repeat_ngram_size(self, no_repeat_ngram_size):
        """Set no_repeat_ngram_size"""
        self.no_repeat_ngram_size = no_repeat_ngram_size

    def set_temperature(self, temperature):
        """Set temperature"""
        self.temperature = temperature

    def set_top_k(self, top_k):
        """Set top_k"""
        self.top_k = top_k

    def set_top_p(self, top_p):
        """Set top_p"""
        self.top_p = top_p

    def set_quantize(self, quantize):
        """Set quantize"""
        self.quantize = quantize

    def set_quantization_config(self, quantization_config):
        """Set quantization_config"""
        self.quantization_config = quantization_config

    def set_model_id(self, model_id):
        """Set model_id"""
        self.model_id = model_id

    def set_model(self, model):
        """Set model"""
        self.model = model

    def set_tokenizer(self, tokenizer):
        """Set tokenizer"""
        self.tokenizer = tokenizer

    def set_logger(self, logger):
        """Set logger"""
        self.logger = logger