refactor: simplify inference script by removing logger, load 16-bit model instead of raw LoRA finetuned

main
thinhlpg 1 month ago
parent da79e986b6
commit 58dcf9a99d

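Note: the refactored script assumes a LoRA checkpoint has already been merged into a standalone 16-bit model on disk (the new default path is trainer_output_example/model_merged_16bit). A minimal sketch of how such a merge could be produced with PEFT is shown below; the base model name and adapter directory are assumptions for illustration, not taken from this commit. The training side of this repo uses Unsloth, which offers model.save_pretrained_merged(..., save_method="merged_16bit") for the same purpose.

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumed names: the base model and adapter directory below are illustrative.
BASE_MODEL = "meta-llama/Llama-3.1-8B-Instruct"
ADAPTER_DIR = "trainer_output_example/checkpoint-last"        # hypothetical LoRA checkpoint dir
MERGED_DIR = "trainer_output_example/model_merged_16bit"      # default path expected by the new script

# Load the base weights in fp16, fold the LoRA deltas into them, and save a plain HF model.
base = AutoModelForCausalLM.from_pretrained(BASE_MODEL, torch_dtype=torch.float16)
merged = PeftModel.from_pretrained(base, ADAPTER_DIR).merge_and_unload()
merged.save_pretrained(MERGED_DIR)
AutoTokenizer.from_pretrained(BASE_MODEL).save_pretrained(MERGED_DIR)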
@@ -1,127 +1,36 @@
 """
 Simple CLI inference script with search functionality.
-This script allows interaction with a model (with optional LoRA adapter)
+This script allows interaction with the merged 16-bit model
 and provides search functionality for data retrieval.
 """

 import argparse
 import json
 import os
-import re
 import time
 from datetime import datetime
 from typing import Any, Dict, List, Optional, Union

-from unsloth import FastLanguageModel
+from transformers import AutoModelForCausalLM, AutoTokenizer
 from vllm import SamplingParams

-from src.config import MODEL_NAME, OUTPUT_DIR, logger
 from src.search_module import load_vectorstore, search


-def find_latest_checkpoint(search_dir=None):
-    """
-    Find the latest checkpoint in the specified directory or OUTPUT_DIR.
-
-    Args:
-        search_dir: Directory to search for checkpoints (default: OUTPUT_DIR)
-
-    Returns:
-        Path to the latest checkpoint or None if no checkpoints found
-    """
-    if search_dir is None:
-        search_dir = "trainer_output_meta-llama_Llama-3.1-8B-Instruct_gpu1_20250326_134236"
-        logger.info(f"No search directory provided, using default: {search_dir}")
-    else:
-        logger.info(f"Searching for checkpoints in: {search_dir}")
-
-    # Check if the directory exists first
-    if not os.path.exists(search_dir):
-        logger.warning(f"Search directory {search_dir} does not exist")
-        return None
-
-    # First try to find checkpoints in the format checkpoint-{step}
-    import glob
-
-    checkpoints = glob.glob(os.path.join(search_dir, "checkpoint-*"))
-
-    if checkpoints:
-        # Extract checkpoint numbers and sort
-        checkpoint_numbers = []
-        for checkpoint in checkpoints:
-            match = re.search(r"checkpoint-(\d+)$", checkpoint)
-            if match:
-                checkpoint_numbers.append((int(match.group(1)), checkpoint))
-
-        if checkpoint_numbers:
-            # Sort by checkpoint number (descending)
-            checkpoint_numbers.sort(reverse=True)
-            latest = checkpoint_numbers[0][1]
-            logger.info(f"Found latest checkpoint: {latest}")
-            return latest
-
-    # If no checkpoints found, look for saved_adapter_{timestamp}.bin files
-    adapter_files = glob.glob(os.path.join(search_dir, "saved_adapter_*.bin"))
-    if adapter_files:
-        # Sort by modification time (newest first)
-        adapter_files.sort(key=os.path.getmtime, reverse=True)
-        latest = adapter_files[0]
-        logger.info(f"Found latest adapter file: {latest}")
-        return latest
-
-    # If all else fails, look for any .bin files
-    bin_files = glob.glob(os.path.join(search_dir, "*.bin"))
-    if bin_files:
-        # Sort by modification time (newest first)
-        bin_files.sort(key=os.path.getmtime, reverse=True)
-        latest = bin_files[0]
-        logger.info(f"Found latest .bin file: {latest}")
-        return latest
-
-    logger.warning(f"No checkpoints found in {search_dir}")
-    return None
-
-
-def setup_model_and_tokenizer():
-    """Initialize model and tokenizer with LoRA support."""
-    config = {
-        "max_seq_length": 4096 * 2,
-        "lora_rank": 64,
-        "gpu_memory_utilization": 0.6,
-        "model_name": MODEL_NAME,
-        "target_modules": [
-            "q_proj",
-            "k_proj",
-            "v_proj",
-            "o_proj",
-            "gate_proj",
-            "up_proj",
-            "down_proj",
-        ],
-    }
-    logger.info(f"Setting up model {config['model_name']} with LoRA support...")
-    model, tokenizer = FastLanguageModel.from_pretrained(
-        model_name=config["model_name"],
-        max_seq_length=config["max_seq_length"],
-        load_in_4bit=True,
-        fast_inference=True,
-        max_lora_rank=config["lora_rank"],
-        gpu_memory_utilization=config["gpu_memory_utilization"],
+def setup_model_and_tokenizer(model_path: str):
+    """Initialize model and tokenizer."""
+    print(f"Setting up model from {model_path}...")
+
+    model = AutoModelForCausalLM.from_pretrained(
+        model_path,
+        torch_dtype="float16",
+        device_map="auto",
+        trust_remote_code=True,
     )
+    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

-    # Setup LoRA
-    model = FastLanguageModel.get_peft_model(
-        model,
-        r=config["lora_rank"],
-        target_modules=config["target_modules"],
-        lora_alpha=config["lora_rank"],
-        use_gradient_checkpointing=True,
-        random_state=3407,
-    )
-    logger.info("Model and tokenizer setup complete.")
+    print("Model and tokenizer setup complete.")

     return model, tokenizer
@@ -205,7 +114,7 @@ class DeepSearchCLI:
     def __init__(
         self,
-        lora_path: Optional[str] = None,
+        model_path: str,
         temperature: float = 0.7,
         system_prompt: Optional[str] = None,
     ):
@@ -213,15 +122,13 @@ class DeepSearchCLI:
         Initialize the CLI.

         Args:
-            lora_path: Path to LoRA weights (None for base model)
+            model_path: Path to the merged 16-bit model
            temperature: Sampling temperature
            system_prompt: Optional system prompt to guide the model's behavior
         """
-        self.model, self.tokenizer = setup_model_and_tokenizer()
-        self.lora_path = lora_path
+        self.model, self.tokenizer = setup_model_and_tokenizer(model_path)
         self.temperature = temperature
         self.sampling_params = get_sampling_params(temperature)
-        self.lora_request = None
         self.history = []
         self.search_history = []
         self.system_prompt = (
@@ -234,14 +141,34 @@ When you receive a tool call response, use the output to format an answer to the
 You are a helpful assistant with tool calling capabilities."""
         )

-        # Load LoRA if specified
-        if self.lora_path:
-            logger.info(f"Loading LoRA adapter from {self.lora_path}...")
-            self.lora_request = self.model.load_lora(self.lora_path)
-            if self.lora_request:
-                logger.info(f"LoRA adapter loaded successfully: {self.lora_request}")
-            else:
-                logger.error("Failed to load LoRA adapter")
+    def _run_agent_generation(self, chat_state: dict) -> dict:
+        """Run a single generation step for the agent."""
+        formatted_prompt = self.tokenizer.apply_chat_template(
+            chat_state["messages"],
+            tokenize=False,
+            add_generation_prompt=True,
+        )
+
+        start_time = time.time()
+        inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.model.device)
+        outputs = self.model.generate(
+            **inputs,
+            max_new_tokens=self.sampling_params.max_tokens,
+            temperature=self.sampling_params.temperature,
+            top_p=self.sampling_params.top_p,
+            do_sample=True,
+        )
+        response_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+        gen_time = time.time() - start_time
+        print(f"Generation completed in {gen_time:.2f} seconds")
+
+        # Extract assistant response
+        assistant_response = response_text.split("<|start_header_id|>assistant<|end_header_id|>")[-1]
+        chat_state["messages"].append({"role": "assistant", "content": assistant_response})
+        return chat_state

     def generate(self, prompt: str, max_generations: int = 20) -> str:
         """
@@ -285,42 +212,6 @@ You are a helpful assistant with tool calling capabilities."""

         return final_response

-    def _run_agent_generation(self, chat_state: dict) -> dict:
-        """Run a single generation step for the agent."""
-        formatted_prompt = self.tokenizer.apply_chat_template(
-            chat_state["messages"],
-            tokenize=False,
-            add_generation_prompt=True,
-        )
-
-        start_time = time.time()
-        if self.lora_request:
-            response = self.model.fast_generate(
-                [formatted_prompt],
-                sampling_params=self.sampling_params,
-                lora_request=self.lora_request,
-            )
-        else:
-            response = self.model.fast_generate(
-                [formatted_prompt],
-                sampling_params=self.sampling_params,
-            )
-
-        gen_time = time.time() - start_time
-        logger.debug(f"Generation completed in {gen_time:.2f} seconds")
-
-        if hasattr(response[0], "outputs"):
-            response_text = response[0].outputs[0].text
-        else:
-            response_text = response[0]
-
-        # Extract assistant response
-        assistant_response = response_text.split("<|start_header_id|>assistant<|end_header_id|>")[-1]
-        chat_state["messages"].append({"role": "assistant", "content": assistant_response})
-        return chat_state
-
     def _check_finished_chat(self, chat_state: dict) -> dict:
         """Check if the chat is finished (no more function calls)."""
         if chat_state.get("finished"):
@@ -329,7 +220,7 @@ You are a helpful assistant with tool calling capabilities."""
         assert chat_state["messages"][-1]["role"] == "assistant", "Expected the last role to be assistant"

         assistant_response = chat_state["messages"][-1]["content"]
-        function_calls = extract_json_objects(assistant_response)
+        function_calls = extract_function_calls(assistant_response)

         if len(function_calls) == 0:
             chat_state["finished"] = True
@@ -343,16 +234,16 @@ You are a helpful assistant with tool calling capabilities."""
         try:
             assistant_response = chat_state["messages"][-1]["content"]
-            function_calls = extract_json_objects(assistant_response)
+            function_calls = extract_function_calls(assistant_response)

             if len(function_calls) > 1:
-                logger.warning("Multiple function calls found in assistant response")
+                print("Multiple function calls found in assistant response")
                 raise ValueError("Expected only one function call in assistant response")

             elif len(function_calls) == 1:
                 function_call = function_calls[0]
                 query = function_call["function"]["parameters"]["query"]
-                logger.info(f"🔍 Search Query: {query}")
+                print(f"🔍 Search Query: {query}")

                 results = search(query, return_type=str, results=2)
@@ -373,7 +264,7 @@ You are a helpful assistant with tool calling capabilities."""
                 self.search_history.append(search_entry)
         except Exception as e:
-            logger.error(f"Error during tool call: {str(e)}")
+            print(f"Error during tool call: {str(e)}")
             chat_state["messages"].append({"role": "system", "content": f"Error during post-processing: {str(e)}"})
             chat_state["finished"] = True
@@ -383,7 +274,7 @@ You are a helpful assistant with tool calling capabilities."""
         """Clear the conversation history."""
         self.history = []
         self.search_history = []
-        logger.info("Conversation history cleared.")
+        print("Conversation history cleared.")

     def set_system_prompt(self, prompt: str):
         """
@@ -393,35 +284,32 @@ You are a helpful assistant with tool calling capabilities."""
            prompt: The new system prompt
         """
         if not prompt:
-            logger.warning("System prompt cannot be empty. Using default.")
+            print("System prompt cannot be empty. Using default.")
             return

         self.system_prompt = prompt
-        logger.info("System prompt updated.")
-        logger.info(f"New system prompt: {self.system_prompt}")
+        print("System prompt updated.")
+        print(f"New system prompt: {self.system_prompt}")

     def display_welcome(self):
         """Display welcome message."""
-        model_type = "LoRA" if self.lora_path else "Base"
-        logger.info(f"\n{'=' * 50}")
-        logger.info(f"DeepSearch CLI - {model_type} Model")
-        logger.info(f"Model: {MODEL_NAME}")
-        logger.info(f"Temperature: {self.temperature}")
-        if self.lora_path:
-            logger.info(f"LoRA Path: {self.lora_path}")
-        logger.info(f"System Prompt: {self.system_prompt}")
-        logger.info(f"{'=' * 50}")
-        logger.info("Type 'help' to see available commands.")
+        print(f"\n{'=' * 50}")
+        print(f"DeepSearch CLI - {self.model.name_or_path}")
+        print(f"Model: {self.model.name_or_path}")
+        print(f"Temperature: {self.temperature}")
+        print(f"System Prompt: {self.system_prompt}")
+        print(f"{'=' * 50}")
+        print("Type 'help' to see available commands.")

     def print_pretty_chat_history(self):
         """Print the full chat history in a pretty format, including searches."""
         if not self.history:
-            logger.info("No chat history available.")
+            print("No chat history available.")
             return

-        logger.info("\n" + "=" * 80)
-        logger.info("CHAT HISTORY WITH SEARCH DETAILS")
-        logger.info("=" * 80)
+        print("\n" + "=" * 80)
+        print("CHAT HISTORY WITH SEARCH DETAILS")
+        print("=" * 80)

         # Group history into conversation turns
         for i in range(0, len(self.history), 2):
@@ -430,26 +318,26 @@ You are a helpful assistant with tool calling capabilities."""
             # Print user message
             if i < len(self.history):
                 user_msg = self.history[i]["content"]
-                logger.info(f"\n[Turn {turn_number + 1}] USER: ")
-                logger.info("-" * 40)
-                logger.info(user_msg)
+                print(f"\n[Turn {turn_number + 1}] USER: ")
+                print("-" * 40)
+                print(user_msg)

             # Print searches associated with this turn if any
             for search_entry in self.search_history:
                 if search_entry["turn"] == turn_number:
                     for idx, search in enumerate(search_entry["searches"]):
-                        logger.info(f'\n🔍 SEARCH {idx + 1}: "{search["query"]}"')
-                        logger.info("-" * 40)
-                        logger.info(search["results"])
+                        print(f'\n🔍 SEARCH {idx + 1}: "{search["query"]}"')
+                        print("-" * 40)
+                        print(search["results"])

             # Print assistant response
             if i + 1 < len(self.history):
                 assistant_msg = self.history[i + 1]["content"]
-                logger.info(f"\n[Turn {turn_number + 1}] ASSISTANT: ")
-                logger.info("-" * 40)
-                logger.info(assistant_msg)
+                print(f"\n[Turn {turn_number + 1}] ASSISTANT: ")
+                print("-" * 40)
+                print(assistant_msg)

-        logger.info("\n" + "=" * 80 + "\n")
+        print("\n" + "=" * 80 + "\n")

     def save_chat_history(self, filepath=None):
         """
@@ -462,14 +350,13 @@ You are a helpful assistant with tool calling capabilities."""
            Path to the saved file
         """
         if not self.history:
-            logger.info("No chat history to save.")
+            print("No chat history to save.")
             return None

         # Generate a default filepath if none provided
         if filepath is None:
             timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-            model_type = "lora" if self.lora_path else "base"
-            filepath = os.path.join(OUTPUT_DIR, f"chat_history_{model_type}_{timestamp}.txt")
+            filepath = os.path.join(os.getcwd(), f"chat_history_{timestamp}.txt")

         # Ensure the directory exists
         os.makedirs(os.path.dirname(filepath), exist_ok=True)
@@ -499,8 +386,7 @@ You are a helpful assistant with tool calling capabilities."""
             with open(filepath, "w", encoding="utf-8") as f:
                 f.write(f"{'=' * 80}\n")
                 f.write("DEEPSEARCH CHAT HISTORY\n")
-                f.write(f"Model: {MODEL_NAME}\n")
-                f.write(f"LoRA Path: {self.lora_path if self.lora_path else 'None'}\n")
+                f.write(f"Model: {self.model.name_or_path}\n")
                 f.write(f"Temperature: {self.temperature}\n")
                 f.write(f"Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
                 f.write(f"{'=' * 80}\n\n")
@@ -521,11 +407,11 @@ You are a helpful assistant with tool calling capabilities."""
                     f.write(f"{turn['assistant']}\n\n")
                     f.write(f"{'=' * 40}\n\n")

-            logger.info(f"Chat history saved to: {filepath}")
+            print(f"Chat history saved to: {filepath}")
             return filepath

         except Exception as e:
-            logger.error(f"Error saving chat history: {e}")
+            print(f"Error saving chat history: {e}")
             return None

     def save_chat_history_json(self, filepath=None):
@@ -539,22 +425,20 @@ You are a helpful assistant with tool calling capabilities."""
            Path to the saved file
         """
         if not self.history:
-            logger.info("No chat history to save.")
+            print("No chat history to save.")
             return None

         # Generate a default filepath if none provided
         if filepath is None:
             timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-            model_type = "lora" if self.lora_path else "base"
-            filepath = os.path.join(OUTPUT_DIR, f"chat_history_{model_type}_{timestamp}.json")
+            filepath = os.path.join(os.getcwd(), f"chat_history_{timestamp}.json")

         # Ensure the directory exists
         os.makedirs(os.path.dirname(filepath), exist_ok=True)

         # Prepare chat history data
         history_data = {
-            "model": MODEL_NAME,
-            "lora_path": self.lora_path if self.lora_path else None,
+            "model": self.model.name_or_path,
             "temperature": self.temperature,
             "timestamp": datetime.now().isoformat(),
             "turns": [],
@@ -582,26 +466,26 @@ You are a helpful assistant with tool calling capabilities."""
             with open(filepath, "w", encoding="utf-8") as f:
                 json.dump(history_data, f, indent=2, ensure_ascii=False)

-            logger.info(f"Chat history saved to JSON: {filepath}")
+            print(f"Chat history saved to JSON: {filepath}")
             return filepath

         except Exception as e:
-            logger.error(f"Error saving chat history to JSON: {e}")
+            print(f"Error saving chat history to JSON: {e}")
             return None

     def display_help(self):
         """Display help information."""
-        logger.info("\n===== Commands =====")
-        logger.info("search <query> - Search for information")
-        logger.info("system <prompt> - Set a new system prompt")
-        logger.info("clear - Clear conversation history")
-        logger.info("history - Display full chat history with searches")
-        logger.info("save - Save chat history to a text file")
-        logger.info("savejson - Save chat history to a JSON file")
-        logger.info("help - Display this help message")
-        logger.info("exit/quit - Exit the program")
-        logger.info("Any other input will be treated as a prompt to the model.")
-        logger.info("===================\n")
+        print("\n===== Commands =====")
+        print("search <query> - Search for information")
+        print("system <prompt> - Set a new system prompt")
+        print("clear - Clear conversation history")
+        print("history - Display full chat history with searches")
+        print("save - Save chat history to a text file")
+        print("savejson - Save chat history to a JSON file")
+        print("help - Display this help message")
+        print("exit/quit - Exit the program")
+        print("Any other input will be treated as a prompt to the model.")
+        print("===================\n")

     def run(self):
         """Run the CLI."""
@@ -615,7 +499,7 @@ You are a helpful assistant with tool calling capabilities."""
                    continue

                if user_input.lower() in ["exit", "quit"]:
-                    logger.info("Exiting...")
+                    print("Exiting...")
                    break

                if user_input.lower() == "help":
@@ -649,7 +533,7 @@ You are a helpful assistant with tool calling capabilities."""
                        try:
                            results = search(query, return_type=str)
                            formatted_results = format_search_results(results)
-                            logger.info(formatted_results)
+                            print(formatted_results)

                            # Add to search history
                            search_entry = {
@@ -658,22 +542,22 @@ You are a helpful assistant with tool calling capabilities."""
                            }
                            self.search_history.append(search_entry)
                        except Exception as e:
-                            logger.error(f"Error searching: {e}")
+                            print(f"Error searching: {e}")
                    else:
-                        logger.warning("Please provide a search query.")
+                        print("Please provide a search query.")
                    continue

                # Process as a prompt to the model
-                logger.info("\nGenerating response...")
+                print("\nGenerating response...")
                response = self.generate(user_input)

-                logger.info("\n----- Response -----")
-                logger.info(response)
+                print("\n----- Response -----")
+                print(response)

            except KeyboardInterrupt:
-                logger.info("\nExiting...")
+                print("\nExiting...")
                break
            except Exception as e:
-                logger.error(f"Error: {e}")
+                print(f"Error: {e}")


 def extract_json_objects(text):
@@ -743,10 +627,10 @@ def main():
     """Main function."""
     parser = argparse.ArgumentParser(description="DeepSearch CLI")
     parser.add_argument(
-        "--lora_path",
+        "--model_path",
         type=str,
-        default="auto",
-        help="Path to LoRA weights (None for base model, 'auto' for auto-detection)",
+        default="trainer_output_example/model_merged_16bit",
+        help="Path to the merged 16-bit model (default: trainer_output_example/model_merged_16bit)",
     )
     parser.add_argument(
         "--temperature",
@@ -762,22 +646,9 @@ def main():
     )

     args = parser.parse_args()

-    # Auto-detect LoRA path if requested
-    lora_path = None
-    if args.lora_path and args.lora_path.lower() != "none":
-        if args.lora_path == "auto":
-            detected_path = find_latest_checkpoint()
-            if detected_path:
-                lora_path = detected_path
-                logger.info(f"Auto-detected LoRA path: {lora_path}")
-            else:
-                logger.warning("No LoRA checkpoint found. Using base model.")
-        else:
-            lora_path = args.lora_path
-
     # Initialize and run the CLI
     cli = DeepSearchCLI(
-        lora_path=lora_path,
+        model_path=args.model_path,
         temperature=args.temperature,
         system_prompt=args.system_prompt,
     )
@@ -787,6 +658,6 @@ def main():

 if __name__ == "__main__":
     # Ensure the vectorstore is loaded
     if load_vectorstore() is None:
-        logger.warning("FAISS vectorstore could not be loaded. Search functionality may not work.")
+        print("FAISS vectorstore could not be loaded. Search functionality may not work.")

     main()
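For reference, a hypothetical programmatic use of the refactored CLI; the module name "inference" and the example prompt are assumptions for illustration, and only the --model_path default comes from this commit.

# Hypothetical usage sketch; "inference" is an assumed module name for this script.
from inference import DeepSearchCLI

cli = DeepSearchCLI(
    model_path="trainer_output_example/model_merged_16bit",  # default from the new --model_path argument
    temperature=0.7,
)
print(cli.generate("What is the main conclusion of the report?"))  # illustrative prompt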
