diff --git a/app.py b/app.py
index 2f2a898..c158e0a 100644
--- a/app.py
+++ b/app.py
@@ -16,7 +16,7 @@ from tavily import TavilyClient
 from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer
 
 # Import from config
-from config import GENERATOR_MODEL_DIR, logger
+from config import GENERATOR_MODEL_DIR, GENERATOR_MODEL_REPO_ID, logger
 from src import (
     apply_chat_template,
     build_user_prompt,
@@ -64,13 +64,28 @@ def setup_model_and_tokenizer(model_path: str):
     """Initialize model and tokenizer."""
     logger.info(f"Setting up model from {model_path}...")
 
-    model = AutoModelForCausalLM.from_pretrained(
-        model_path,
-        torch_dtype="float16",
-        device_map="auto",
-        trust_remote_code=True,
-    )
-    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    try:
+        # Try loading from the provided local path first
+        model = AutoModelForCausalLM.from_pretrained(
+            model_path,
+            torch_dtype="float16",
+            device_map="auto",
+            trust_remote_code=True,
+        )
+        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    except Exception as e:
+        logger.warning(f"Failed to load from local path: {e}")
+        logger.info(f"Attempting to load directly from Hugging Face: {GENERATOR_MODEL_REPO_ID}")
+
+        # Fall back to the Hugging Face model repository
+        model = AutoModelForCausalLM.from_pretrained(
+            GENERATOR_MODEL_REPO_ID,
+            torch_dtype="float16",
+            device_map="auto",
+            trust_remote_code=True,
+        )
+        tokenizer = AutoTokenizer.from_pretrained(GENERATOR_MODEL_REPO_ID, trust_remote_code=True)
+        logger.info("Successfully loaded model from Hugging Face repo")
 
     # Defaulting to the one used in inference.py, adjust if needed for your specific model
     assistant_marker = "<|start_header_id|>assistant<|end_header_id|>"
@@ -897,9 +912,7 @@ def main():
         # Display error if model fails to load
         with gr.Blocks() as demo:
             gr.Markdown("# Critical Error")
-            gr.Markdown(
-                f"Failed to load model or tokenizer from '{model_path}'. Check the path and ensure the model exists.\n\nError: {e}"
-            )
+            gr.Markdown(f"Failed to load both local model and Hugging Face model. Error: {e}")
 
         demo.launch(share=True)
         sys.exit(1)  # Exit if model loading fails
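The diff assumes `config.py` exports the new `GENERATOR_MODEL_REPO_ID` constant alongside the existing `GENERATOR_MODEL_DIR` and `logger`. A minimal sketch of what those exports might look like; the values shown are placeholders, since the actual config is not part of this diff:

```python
# config.py -- sketch of the names app.py imports; values are placeholders
import logging
import os

# Local directory tried first when loading the generator model.
GENERATOR_MODEL_DIR = os.getenv("GENERATOR_MODEL_DIR", "models/generator")

# Hugging Face repository used as a fallback when the local load fails.
# Placeholder id; substitute the real repository.
GENERATOR_MODEL_REPO_ID = os.getenv("GENERATOR_MODEL_REPO_ID", "your-org/generator-model")

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("app")
```

One note on the fallback design: the broad `except Exception` routes any local failure (missing directory, corrupt weights, incompatible config) into a network download, which can mask unrelated bugs; narrowing the clause to `OSError`, which `from_pretrained` raises for missing or unreadable files, would make the fallback trigger only in the intended case.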