@@ -16,7 +16,7 @@ from tavily import TavilyClient
 from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer
 
 # Import from config
-from config import GENERATOR_MODEL_DIR, logger
+from config import GENERATOR_MODEL_DIR, GENERATOR_MODEL_REPO_ID, logger
 from src import (
     apply_chat_template,
     build_user_prompt,
@@ -64,6 +64,8 @@ def setup_model_and_tokenizer(model_path: str):
     """Initialize model and tokenizer."""
     logger.info(f"Setting up model from {model_path}...")
 
+    try:
+        # Try loading from the provided path first
         model = AutoModelForCausalLM.from_pretrained(
             model_path,
             torch_dtype="float16",
@@ -71,6 +73,19 @@ def setup_model_and_tokenizer(model_path: str):
             trust_remote_code=True,
         )
         tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    except Exception as e:
+        logger.warning(f"Failed to load from local path: {e}")
+        logger.info(f"Attempting to load directly from Hugging Face: {GENERATOR_MODEL_REPO_ID}")
+
+        # Fall back to the Hugging Face model repository
+        model = AutoModelForCausalLM.from_pretrained(
+            GENERATOR_MODEL_REPO_ID,
+            torch_dtype="float16",
+            device_map="auto",
+            trust_remote_code=True,
+        )
+        tokenizer = AutoTokenizer.from_pretrained(GENERATOR_MODEL_REPO_ID, trust_remote_code=True)
+        logger.info("Successfully loaded model from Hugging Face repo")
 
     # Defaulting to the one used in inference.py, adjust if needed for your specific model
     assistant_marker = "<|start_header_id|>assistant<|end_header_id|>"
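
For readers who want the new loading behavior in isolation, here is a minimal, self-contained sketch of the local-then-Hub fallback the two hunks above introduce. load_with_fallback is a hypothetical helper written for illustration, not part of the patch, and device_map="auto" assumes the accelerate package is installed.

# Hypothetical standalone sketch of the fallback pattern added above; not
# part of the patch. Tries each source in order, returns the first that loads.
from transformers import AutoModelForCausalLM, AutoTokenizer

def load_with_fallback(local_path, repo_id):
    """Try the local checkpoint first, then fall back to the Hub repo."""
    last_error = None
    for source in (local_path, repo_id):
        try:
            model = AutoModelForCausalLM.from_pretrained(
                source,
                torch_dtype="float16",   # same dtype the patch uses
                device_map="auto",       # requires the accelerate package
                trust_remote_code=True,
            )
            tokenizer = AutoTokenizer.from_pretrained(source, trust_remote_code=True)
            return model, tokenizer
        except Exception as e:           # this source failed; try the next one
            last_error = e
    raise RuntimeError(f"All sources failed; last error: {last_error}")

Called as load_with_fallback(GENERATOR_MODEL_DIR, GENERATOR_MODEL_REPO_ID), this reproduces the patch's behavior with the two attempts kept symmetric; the patch instead inlines the second attempt in the except block, which keeps the Hub-specific log messages next to the failure they report.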
@@ -897,9 +912,7 @@ def main():
         # Display error if model fails to load
         with gr.Blocks() as demo:
             gr.Markdown("# Critical Error")
-            gr.Markdown(
-                f"Failed to load model or tokenizer from '{model_path}'. Check the path and ensure the model exists.\n\nError: {e}"
-            )
+            gr.Markdown(f"Failed to load both local model and Hugging Face model. Error: {e}")
         demo.launch(share=True)
         sys.exit(1)  # Exit if model loading fails
 
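
For context on the last hunk: the shorter message reflects that loading can now fail in two places (local path and Hub). Below is a hedged reconstruction of how main() plausibly wires this up; the try/except scaffolding and the return shape of setup_model_and_tokenizer are assumptions, since the hunk only shows the except body.

# Hypothetical reconstruction of the error path in main(); the try/except
# around the call is assumed, as is the (model, tokenizer) return shape.
import sys
import gradio as gr

try:
    model, tokenizer = setup_model_and_tokenizer(GENERATOR_MODEL_DIR)
except Exception as e:
    # Display error if model fails to load
    with gr.Blocks() as demo:
        gr.Markdown("# Critical Error")
        gr.Markdown(f"Failed to load both local model and Hugging Face model. Error: {e}")
    demo.launch(share=True)  # blocks until the Gradio server stops
    sys.exit(1)              # then exit with a failure code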