Former-commit-id: dc8f3d2b19a450169e395309d75e641716eb0375
pull/160/head
Kye 2 years ago
parent b1f3614dc9
commit 7aa61f32b6

@ -1,5 +1,7 @@
import torch import torch
import logging import logging
from transformers import BitsAndBytesConfig
from transformers import AutoModelForCausalLM, AutoTokenizer #,# BitsAndBytesConfig from transformers import AutoModelForCausalLM, AutoTokenizer #,# BitsAndBytesConfig
class HuggingFaceLLM: class HuggingFaceLLM:
@ -10,15 +12,15 @@ class HuggingFaceLLM:
self.max_length = max_length self.max_length = max_length
bnb_config = None bnb_config = None
# if quantize: if quantize:
# if not quantization_config: if not quantization_config:
# quantization_config = { quantization_config = {
# 'load_in_4bit': True, 'load_in_4bit': True,
# 'bnb_4bit_use_double_quant': True, 'bnb_4bit_use_double_quant': True,
# 'bnb_4bit_quant_type': "nf4", 'bnb_4bit_quant_type': "nf4",
# 'bnb_4bit_compute_dtype': torch.bfloat16 'bnb_4bit_compute_dtype': torch.bfloat16
# } }
# bnb_config = BitsAndBytesConfig(**quantization_config) bnb_config = BitsAndBytesConfig(**quantization_config)
try: try:
self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)

Loading…
Cancel
Save