pull/53/head
Kye 2 years ago
parent c2f465265e
commit 272cb5fd1c

@ -1,5 +1,7 @@
import torch
import logging
from transformers import BitsAndBytesConfig
from transformers import AutoModelForCausalLM, AutoTokenizer #,# BitsAndBytesConfig
class HuggingFaceLLM:
@ -10,15 +12,15 @@ class HuggingFaceLLM:
self.max_length = max_length
bnb_config = None
# if quantize:
# if not quantization_config:
# quantization_config = {
# 'load_in_4bit': True,
# 'bnb_4bit_use_double_quant': True,
# 'bnb_4bit_quant_type': "nf4",
# 'bnb_4bit_compute_dtype': torch.bfloat16
# }
# bnb_config = BitsAndBytesConfig(**quantization_config)
if quantize:
if not quantization_config:
quantization_config = {
'load_in_4bit': True,
'bnb_4bit_use_double_quant': True,
'bnb_4bit_quant_type': "nf4",
'bnb_4bit_compute_dtype': torch.bfloat16
}
bnb_config = BitsAndBytesConfig(**quantization_config)
try:
self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)

Loading…
Cancel
Save