|
|
@ -5,7 +5,13 @@ import logging
|
|
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
|
|
|
|
|
|
|
|
class HuggingFaceLLM:
|
|
|
|
class HuggingFaceLLM:
|
|
|
|
def __init__(self, model_id: str, device: str = None, max_length: int = 20, quantize: bool = False, quantization_config: dict = None):
|
|
|
|
def __init__(self,
|
|
|
|
|
|
|
|
model_id: str,
|
|
|
|
|
|
|
|
device: str = None,
|
|
|
|
|
|
|
|
max_length: int = 20,
|
|
|
|
|
|
|
|
quantize: bool = False,
|
|
|
|
|
|
|
|
quantization_config: dict = None
|
|
|
|
|
|
|
|
):
|
|
|
|
self.logger = logging.getLogger(__name__)
|
|
|
|
self.logger = logging.getLogger(__name__)
|
|
|
|
self.device = device if device else ('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
|
self.device = device if device else ('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
|
self.model_id = model_id
|
|
|
|
self.model_id = model_id
|
|
|
|