feat: enhance model loading with fallback to Hugging Face repo and improved error handling

main
thinhlpg 3 weeks ago
parent bac5f3b4f7
commit 7cd4d18ee6

@ -16,7 +16,7 @@ from tavily import TavilyClient
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer
# Import from config # Import from config
from config import GENERATOR_MODEL_DIR, logger from config import GENERATOR_MODEL_DIR, GENERATOR_MODEL_REPO_ID, logger
from src import ( from src import (
apply_chat_template, apply_chat_template,
build_user_prompt, build_user_prompt,
@ -64,13 +64,28 @@ def setup_model_and_tokenizer(model_path: str):
"""Initialize model and tokenizer.""" """Initialize model and tokenizer."""
logger.info(f"Setting up model from {model_path}...") logger.info(f"Setting up model from {model_path}...")
model = AutoModelForCausalLM.from_pretrained( try:
model_path, # Try loading from the provided path first
torch_dtype="float16", model = AutoModelForCausalLM.from_pretrained(
device_map="auto", model_path,
trust_remote_code=True, torch_dtype="float16",
) device_map="auto",
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
except Exception as e:
logger.warning(f"Failed to load from local path: {e}")
logger.info(f"Attempting to load directly from Hugging Face: {GENERATOR_MODEL_REPO_ID}")
# Fallback to the Hugging Face model repository
model = AutoModelForCausalLM.from_pretrained(
GENERATOR_MODEL_REPO_ID,
torch_dtype="float16",
device_map="auto",
trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(GENERATOR_MODEL_REPO_ID, trust_remote_code=True)
logger.info(f"Successfully loaded model from Hugging Face repo")
# Defaulting to the one used in inference.py, adjust if needed for your specific model # Defaulting to the one used in inference.py, adjust if needed for your specific model
assistant_marker = "<|start_header_id|>assistant<|end_header_id|>" assistant_marker = "<|start_header_id|>assistant<|end_header_id|>"
@ -897,9 +912,7 @@ def main():
# Display error if model fails to load # Display error if model fails to load
with gr.Blocks() as demo: with gr.Blocks() as demo:
gr.Markdown("# Critical Error") gr.Markdown("# Critical Error")
gr.Markdown( gr.Markdown(f"Failed to load both local model and Hugging Face model. Error: {e}")
f"Failed to load model or tokenizer from '{model_path}'. Check the path and ensure the model exists.\n\nError: {e}"
)
demo.launch(share=True) demo.launch(share=True)
sys.exit(1) # Exit if model loading fails sys.exit(1) # Exit if model loading fails

Loading…
Cancel
Save