diff --git a/swarms/agents/models/huggingface.py b/swarms/agents/models/huggingface.py
index 1eb36d79..4c658d80 100644
--- a/swarms/agents/models/huggingface.py
+++ b/swarms/agents/models/huggingface.py
@@ -1,5 +1,7 @@
 import torch
 import logging
+from transformers import BitsAndBytesConfig
+
 from transformers import AutoModelForCausalLM, AutoTokenizer #,# BitsAndBytesConfig
 
 class HuggingFaceLLM:
@@ -10,15 +12,15 @@ class HuggingFaceLLM:
         self.max_length = max_length
 
         bnb_config = None
-        # if quantize:
-        #     if not quantization_config:
-        #         quantization_config = {
-        #             'load_in_4bit': True,
-        #             'bnb_4bit_use_double_quant': True,
-        #             'bnb_4bit_quant_type': "nf4",
-        #             'bnb_4bit_compute_dtype': torch.bfloat16
-        #         }
-        #     bnb_config = BitsAndBytesConfig(**quantization_config)
+        if quantize:
+            if not quantization_config:
+                quantization_config = {
+                    'load_in_4bit': True,
+                    'bnb_4bit_use_double_quant': True,
+                    'bnb_4bit_quant_type': "nf4",
+                    'bnb_4bit_compute_dtype': torch.bfloat16
+                }
+            bnb_config = BitsAndBytesConfig(**quantization_config)
 
         try:
             self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)