From abb18b10d854c05a9204bdd3b9ee3d58673da01c Mon Sep 17 00:00:00 2001
From: thinhlpg
Date: Tue, 1 Apr 2025 04:03:38 +0700
Subject: [PATCH] feat: add CLI inference script with search functionality

This script is a bit dumb, but it worked. I'll update it later XD
---
 inference.py | 792 +++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 792 insertions(+)
 create mode 100644 inference.py

diff --git a/inference.py b/inference.py
new file mode 100644
index 0000000..6f97746
--- /dev/null
+++ b/inference.py
@@ -0,0 +1,792 @@
+"""
+Simple CLI inference script with search functionality.
+
+This script allows interaction with a model (with optional LoRA adapter)
+and provides search functionality for data retrieval.
+"""
+
+import argparse
+import json
+import os
+import re
+import time
+from datetime import datetime
+from typing import Any, Dict, List, Optional, Union
+
+from unsloth import FastLanguageModel
+from vllm import SamplingParams
+
+from src.config import MODEL_NAME, OUTPUT_DIR, logger
+from src.search_module import load_vectorstore, search
+
+
+def find_latest_checkpoint(search_dir=None):
+    """
+    Find the latest checkpoint in the specified directory or OUTPUT_DIR.
+
+    Args:
+        search_dir: Directory to search for checkpoints (default: OUTPUT_DIR)
+
+    Returns:
+        Path to the latest checkpoint or None if no checkpoints found
+    """
+    if search_dir is None:
+        search_dir = "trainer_output_meta-llama_Llama-3.1-8B-Instruct_gpu1_20250326_134236"
+        logger.info(f"No search directory provided, using default: {search_dir}")
+    else:
+        logger.info(f"Searching for checkpoints in: {search_dir}")
+
+    # Check if the directory exists first
+    if not os.path.exists(search_dir):
+        logger.warning(f"Search directory {search_dir} does not exist")
+        return None
+
+    # First try to find checkpoints in the format checkpoint-{step}
+    import glob
+
+    checkpoints = glob.glob(os.path.join(search_dir, "checkpoint-*"))
+
+    if checkpoints:
+        # Extract checkpoint numbers and sort
+        checkpoint_numbers = []
+        for checkpoint in checkpoints:
+            match = re.search(r"checkpoint-(\d+)$", checkpoint)
+            if match:
+                checkpoint_numbers.append((int(match.group(1)), checkpoint))
+
+        if checkpoint_numbers:
+            # Sort by checkpoint number (descending)
+            checkpoint_numbers.sort(reverse=True)
+            latest = checkpoint_numbers[0][1]
+            logger.info(f"Found latest checkpoint: {latest}")
+            return latest
+
+    # If no checkpoints found, look for saved_adapter_{timestamp}.bin files
+    adapter_files = glob.glob(os.path.join(search_dir, "saved_adapter_*.bin"))
+    if adapter_files:
+        # Sort by modification time (newest first)
+        adapter_files.sort(key=os.path.getmtime, reverse=True)
+        latest = adapter_files[0]
+        logger.info(f"Found latest adapter file: {latest}")
+        return latest
+
+    # If all else fails, look for any .bin files
+    bin_files = glob.glob(os.path.join(search_dir, "*.bin"))
+    if bin_files:
+        # Sort by modification time (newest first)
+        bin_files.sort(key=os.path.getmtime, reverse=True)
+        latest = bin_files[0]
+        logger.info(f"Found latest .bin file: {latest}")
+        return latest
+
+    logger.warning(f"No checkpoints found in {search_dir}")
+    return None
+
+
+def setup_model_and_tokenizer():
+    """Initialize model and tokenizer with LoRA support."""
+    config = {
+        "max_seq_length": 4096 * 2,
+        "lora_rank": 64,
+        "gpu_memory_utilization": 0.6,
+        "model_name": MODEL_NAME,
+        "target_modules": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+            "o_proj",
+            "gate_proj",
+            "up_proj",
+            "down_proj",
+        ],
+    }
+
+    logger.info(f"Setting up model {config['model_name']} with LoRA support...")
+    model, tokenizer = FastLanguageModel.from_pretrained(
+        model_name=config["model_name"],
+        max_seq_length=config["max_seq_length"],
+        load_in_4bit=True,
+        fast_inference=True,
+        max_lora_rank=config["lora_rank"],
+        gpu_memory_utilization=config["gpu_memory_utilization"],
+    )
+
+    # Setup LoRA
+    model = FastLanguageModel.get_peft_model(
+        model,
+        r=config["lora_rank"],
+        target_modules=config["target_modules"],
+        lora_alpha=config["lora_rank"],
+        use_gradient_checkpointing=True,
+        random_state=3407,
+    )
+
+    logger.info("Model and tokenizer setup complete.")
+    return model, tokenizer
+
+
+def get_sampling_params(temperature: float = 0.7, max_tokens: int = 4096) -> SamplingParams:
+    """Get sampling parameters for generation."""
+    return SamplingParams(
+        temperature=temperature,
+        top_p=0.95,
+        max_tokens=max_tokens,
+    )
+
+
+def extract_function_calls(text: str) -> List[Dict[str, Any]]:
+    """Extract function calls from a text."""
+    import json
+    import re
+
+    # Pattern to match JSON objects
+    pattern = r"\{(?:[^{}]|(?:\{(?:[^{}]|(?:\{[^{}]*\}))*\}))*\}"
+    json_matches = re.findall(pattern, text)
+
+    function_calls = []
+    for json_str in json_matches:
+        try:
+            obj = json.loads(json_str)
+            if "function" in obj:
+                function_calls.append(obj)
+        except json.JSONDecodeError:
+            continue
+
+    return function_calls
+
+
+def build_user_prompt(q):
+    """
+    Build a user prompt with the question and search tool definition.
+
+    Args:
+        q (str): The question to ask
+
+    Returns:
+        str: Formatted user prompt
+    """
+    user_prompt = f"""You are a research assistant, and you use the search_corpus tool to find answers to questions.
+Given a question, answer it by doing searches using the search_corpus tool.
+To use the search_corpus tool, respond with a JSON object for a function call with its proper arguments.
+
+PLEASE CONSIDER CHAT HISTORY WHEN ANSWERING THE QUESTION.
+ONLY ANSWER WHEN YOU HAVE 100% CONFIDENCE IN THE SEARCH RESULTS, ELSE CONTINUE SEARCHING.
+PLEASE SEARCH MULTIPLE TIMES WITH DIFFERENT QUERIES.
+
+You may also reason in any message; think step by step about how to answer the question. Wrap your reasoning in <reasoning> and </reasoning> tags.
+
+{json.dumps(SEARCH_TOOL_DEFINITION, indent=2)}
+
+Question: {q}
+"""
+    return user_prompt
+
+
+def format_search_results(results: Union[str, List[str]]) -> str:
+    """
+    Format search results for display.
+
+    Args:
+        results: Search results as string or list of strings
+
+    Returns:
+        Formatted search results
+    """
+    if isinstance(results, list):
+        content = "\n".join([f"Result {i + 1}:\n{r}\n------" for i, r in enumerate(results)])
+    else:
+        content = results
+
+    return f"\n===== SEARCH RESULTS =====\n{content}\n===========================\n"
+
+
+class DeepSearchCLI:
+    """CLI for interacting with the model and search functionality."""
+
+    def __init__(
+        self,
+        lora_path: Optional[str] = None,
+        temperature: float = 0.7,
+        system_prompt: Optional[str] = None,
+    ):
+        """
+        Initialize the CLI.
+
+        Args:
+            lora_path: Path to LoRA weights (None for base model)
+            temperature: Sampling temperature
+            system_prompt: Optional system prompt to guide the model's behavior
+        """
+        self.model, self.tokenizer = setup_model_and_tokenizer()
+        self.lora_path = lora_path
+        self.temperature = temperature
+        self.sampling_params = get_sampling_params(temperature)
+        self.lora_request = None
+        self.history = []
+        self.search_history = []
+        self.system_prompt = (
+            system_prompt
+            or f"""Cutting Knowledge Date: December 2023
+Today Date: {datetime.now().strftime("%d %b %Y")}
+
+When you receive a tool call response, use the output to format an answer to the original user question.
+
+You are a helpful assistant with tool calling capabilities."""
+        )
+
+        # Load LoRA if specified
+        if self.lora_path:
+            logger.info(f"Loading LoRA adapter from {self.lora_path}...")
+            self.lora_request = self.model.load_lora(self.lora_path)
+            if self.lora_request:
+                logger.info(f"LoRA adapter loaded successfully: {self.lora_request}")
+            else:
+                logger.error("Failed to load LoRA adapter")
+
+    def generate(self, prompt: str, max_generations: int = 20) -> str:
+        """
+        Generate a response to the prompt using an agentic loop.
+
+        Args:
+            prompt: The prompt to generate a response to
+            max_generations: Maximum number of turns in the conversation
+
+        Returns:
+            The generated response after completing the conversation
+        """
+        # Initialize chat state
+        chat_state = {
+            "messages": [
+                {"role": "system", "content": self.system_prompt},
+                {"role": "user", "content": build_user_prompt(prompt)},
+            ],
+            "finished": False,
+        }
+
+        # Agent loop
+        for i in range(max_generations):
+            # Generate response
+            chat_state = self._run_agent_generation(chat_state)
+
+            # Check if conversation is finished
+            chat_state = self._check_finished_chat(chat_state)
+            if chat_state.get("finished"):
+                break
+
+            # Process tool calls if any
+            chat_state = self._run_tool_calls(chat_state)
+
+        # Get final response
+        final_response = chat_state["messages"][-1]["content"]
+
+        # Update history
+        self.history.append({"role": "user", "content": prompt})
+        self.history.append({"role": "assistant", "content": final_response})
+
+        return final_response
+
+    def _run_agent_generation(self, chat_state: dict) -> dict:
+        """Run a single generation step for the agent."""
+        formatted_prompt = self.tokenizer.apply_chat_template(
+            chat_state["messages"],
+            tokenize=False,
+            add_generation_prompt=True,
+        )
+
+        start_time = time.time()
+        if self.lora_request:
+            response = self.model.fast_generate(
+                [formatted_prompt],
+                sampling_params=self.sampling_params,
+                lora_request=self.lora_request,
+            )
+        else:
+            response = self.model.fast_generate(
+                [formatted_prompt],
+                sampling_params=self.sampling_params,
+            )
+
+        gen_time = time.time() - start_time
+        logger.debug(f"Generation completed in {gen_time:.2f} seconds")
+
+        if hasattr(response[0], "outputs"):
+            response_text = response[0].outputs[0].text
+        else:
+            response_text = response[0]
+
+        # Extract assistant response
+        assistant_response = response_text.split("<|start_header_id|>assistant<|end_header_id|>")[-1]
+
+        chat_state["messages"].append({"role": "assistant", "content": assistant_response})
+
+        return chat_state
+
+    def _check_finished_chat(self, chat_state: dict) -> dict:
+        """Check if the chat is finished (no more function calls)."""
+        if chat_state.get("finished"):
+            return chat_state
+
+        assert chat_state["messages"][-1]["role"] == "assistant", "Expected the last role to be assistant"
+
+        assistant_response = chat_state["messages"][-1]["content"]
+        function_calls = extract_json_objects(assistant_response)
+
+        if len(function_calls) == 0:
+            chat_state["finished"] = True
+
+        return chat_state
+
+    def _run_tool_calls(self, chat_state: dict) -> dict:
+        """Execute tool calls found in chat state."""
+        if chat_state.get("finished"):
+            return chat_state
+
+        try:
+            assistant_response = chat_state["messages"][-1]["content"]
+            function_calls = extract_json_objects(assistant_response)
+
+            if len(function_calls) > 1:
+                logger.warning("Multiple function calls found in assistant response")
+                raise ValueError("Expected only one function call in assistant response")
+
+            elif len(function_calls) == 1:
+                function_call = function_calls[0]
+                query = function_call["function"]["parameters"]["query"]
+                logger.info(f"šŸ” Search Query: {query}")
+
+                results = search(query, return_type=str, results=2)
+
+                # Print search results to terminal
+                # logger.info("\n===== SEARCH RESULTS =====")
+                # logger.info(
+                #     results
+                # )  # The results are already formatted with Result 1:, Result 2:, etc.
+                # logger.info("===========================\n")
+
+                chat_state["messages"].append({"role": "ipython", "content": results})
+
+                # Record search in history
+                search_entry = {
+                    "turn": len(self.history) // 2,
+                    "searches": [{"query": query, "results": results}],
+                }
+                self.search_history.append(search_entry)
+
+        except Exception as e:
+            logger.error(f"Error during tool call: {str(e)}")
+            chat_state["messages"].append({"role": "system", "content": f"Error during post-processing: {str(e)}"})
+            chat_state["finished"] = True
+
+        return chat_state
+
+    def clear_history(self):
+        """Clear the conversation history."""
+        self.history = []
+        self.search_history = []
+        logger.info("Conversation history cleared.")
+
+    def set_system_prompt(self, prompt: str):
+        """
+        Set a new system prompt.
+
+        Args:
+            prompt: The new system prompt
+        """
+        if not prompt:
+            logger.warning("System prompt cannot be empty. Keeping the current system prompt.")
+            return
+
+        self.system_prompt = prompt
+        logger.info("System prompt updated.")
+        logger.info(f"New system prompt: {self.system_prompt}")
+
+    def display_welcome(self):
+        """Display welcome message."""
+        model_type = "LoRA" if self.lora_path else "Base"
+        logger.info(f"\n{'=' * 50}")
+        logger.info(f"DeepSearch CLI - {model_type} Model")
+        logger.info(f"Model: {MODEL_NAME}")
+        logger.info(f"Temperature: {self.temperature}")
+        if self.lora_path:
+            logger.info(f"LoRA Path: {self.lora_path}")
+        logger.info(f"System Prompt: {self.system_prompt}")
+        logger.info(f"{'=' * 50}")
+        logger.info("Type 'help' to see available commands.")
+
+    def print_pretty_chat_history(self):
+        """Print the full chat history in a pretty format, including searches."""
+        if not self.history:
+            logger.info("No chat history available.")
+            return
+
+        logger.info("\n" + "=" * 80)
+        logger.info("CHAT HISTORY WITH SEARCH DETAILS")
+        logger.info("=" * 80)
+
+        # Group history into conversation turns
+        for i in range(0, len(self.history), 2):
+            turn_number = i // 2
+
+            # Print user message
+            if i < len(self.history):
+                user_msg = self.history[i]["content"]
+                logger.info(f"\n[Turn {turn_number + 1}] USER: ")
+                logger.info("-" * 40)
+                logger.info(user_msg)
+
+            # Print searches associated with this turn if any
+            for search_entry in self.search_history:
+                if search_entry["turn"] == turn_number:
+                    for idx, search in enumerate(search_entry["searches"]):
+                        logger.info(f'\nšŸ” SEARCH {idx + 1}: "{search["query"]}"')
+                        logger.info("-" * 40)
+                        logger.info(search["results"])
+
+            # Print assistant response
+            if i + 1 < len(self.history):
+                assistant_msg = self.history[i + 1]["content"]
+                logger.info(f"\n[Turn {turn_number + 1}] ASSISTANT: ")
+                logger.info("-" * 40)
+                logger.info(assistant_msg)
+
+        logger.info("\n" + "=" * 80 + "\n")
+
+    def save_chat_history(self, filepath=None):
+        """
+        Save chat history to a file.
+
+        Args:
+            filepath: Path to save file (if None, auto-generate based on timestamp)
+
+        Returns:
+            Path to the saved file
+        """
+        if not self.history:
+            logger.info("No chat history to save.")
+            return None
+
+        # Generate a default filepath if none provided
+        if filepath is None:
+            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+            model_type = "lora" if self.lora_path else "base"
+            filepath = os.path.join(OUTPUT_DIR, f"chat_history_{model_type}_{timestamp}.txt")
+
+        # Ensure the directory exists
+        os.makedirs(os.path.dirname(filepath), exist_ok=True)
+
+        # Prepare chat history data
+        pretty_history = []
+
+        # Group history into conversation turns
+        for i in range(0, len(self.history), 2):
+            turn_number = i // 2
+            turn_data = {
+                "turn": turn_number + 1,
+                "user": self.history[i]["content"] if i < len(self.history) else "",
+                "searches": [],
+                "assistant": self.history[i + 1]["content"] if i + 1 < len(self.history) else "",
+            }
+
+            # Add searches for this turn
+            for search_entry in self.search_history:
+                if search_entry["turn"] == turn_number:
+                    turn_data["searches"].extend(search_entry["searches"])
+
+            pretty_history.append(turn_data)
+
+        # Write to file
+        try:
+            with open(filepath, "w", encoding="utf-8") as f:
+                f.write(f"{'=' * 80}\n")
+                f.write(f"DEEPSEARCH CHAT HISTORY\n")
+                f.write(f"Model: {MODEL_NAME}\n")
+                f.write(f"LoRA Path: {self.lora_path if self.lora_path else 'None'}\n")
+                f.write(f"Temperature: {self.temperature}\n")
+                f.write(f"Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
+                f.write(f"{'=' * 80}\n\n")
+
+                for turn in pretty_history:
+                    f.write(f"[Turn {turn['turn']}] USER:\n")
+                    f.write(f"{'-' * 40}\n")
+                    f.write(f"{turn['user']}\n\n")
+
+                    # Write searches
+                    for i, search in enumerate(turn["searches"]):
+                        f.write(f'šŸ” SEARCH {i + 1}: "{search["query"]}"\n')
+                        f.write(f"{'-' * 40}\n")
+                        f.write(f"{search['results']}\n\n")
+
+                    f.write(f"[Turn {turn['turn']}] ASSISTANT:\n")
+                    f.write(f"{'-' * 40}\n")
+                    f.write(f"{turn['assistant']}\n\n")
+                    f.write(f"{'=' * 40}\n\n")
+
+            logger.info(f"Chat history saved to: {filepath}")
+            return filepath
+
+        except Exception as e:
+            logger.error(f"Error saving chat history: {e}")
+            return None
+
+    def save_chat_history_json(self, filepath=None):
+        """
+        Save chat history to a JSON file.
+
+        Args:
+            filepath: Path to save file (if None, auto-generate based on timestamp)
+
+        Returns:
+            Path to the saved file
+        """
+        if not self.history:
+            logger.info("No chat history to save.")
+            return None
+
+        # Generate a default filepath if none provided
+        if filepath is None:
+            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+            model_type = "lora" if self.lora_path else "base"
+            filepath = os.path.join(OUTPUT_DIR, f"chat_history_{model_type}_{timestamp}.json")
+
+        # Ensure the directory exists
+        os.makedirs(os.path.dirname(filepath), exist_ok=True)
+
+        # Prepare chat history data
+        history_data = {
+            "model": MODEL_NAME,
+            "lora_path": self.lora_path if self.lora_path else None,
+            "temperature": self.temperature,
+            "timestamp": datetime.now().isoformat(),
+            "turns": [],
+        }
+
+        # Group history into conversation turns
+        for i in range(0, len(self.history), 2):
+            turn_number = i // 2
+            turn_data = {
+                "turn": turn_number + 1,
+                "user": self.history[i]["content"] if i < len(self.history) else "",
+                "searches": [],
+                "assistant": self.history[i + 1]["content"] if i + 1 < len(self.history) else "",
+            }
+
+            # Add searches for this turn
+            for search_entry in self.search_history:
+                if search_entry["turn"] == turn_number:
+                    turn_data["searches"].extend(search_entry["searches"])
+
+            history_data["turns"].append(turn_data)
+
+        # Write to file
+        try:
+            with open(filepath, "w", encoding="utf-8") as f:
+                json.dump(history_data, f, indent=2, ensure_ascii=False)
+
+            logger.info(f"Chat history saved to JSON: {filepath}")
+            return filepath
+
+        except Exception as e:
+            logger.error(f"Error saving chat history to JSON: {e}")
+            return None
+
+    def display_help(self):
+        """Display help information."""
+        logger.info("\n===== Commands =====")
+        logger.info("search <query> - Search for information")
+        logger.info("system <prompt> - Set a new system prompt")
+        logger.info("clear - Clear conversation history")
+        logger.info("history - Display full chat history with searches")
+        logger.info("save - Save chat history to a text file")
+        logger.info("savejson - Save chat history to a JSON file")
+        logger.info("help - Display this help message")
+        logger.info("exit/quit - Exit the program")
+        logger.info("Any other input will be treated as a prompt to the model.")
+        logger.info("===================\n")
+
+    def run(self):
+        """Run the CLI."""
+        self.display_welcome()
+
+        while True:
+            try:
+                user_input = input("\n> ").strip()
+
+                if not user_input:
+                    continue
+
+                if user_input.lower() in ["exit", "quit"]:
+                    logger.info("Exiting...")
+                    break
+
+                if user_input.lower() == "help":
+                    self.display_help()
+                    continue
+
+                if user_input.lower() == "clear":
+                    self.clear_history()
+                    continue
+
+                if user_input.lower() == "history":
+                    self.print_pretty_chat_history()
+                    continue
+
+                if user_input.lower() == "save":
+                    self.save_chat_history()
+                    continue
+
+                if user_input.lower() == "savejson":
+                    self.save_chat_history_json()
+                    continue
+
+                if user_input.lower().startswith("system "):
+                    new_prompt = user_input[7:].strip()
+                    self.set_system_prompt(new_prompt)
+                    continue
+
+                if user_input.lower().startswith("search "):
+                    query = user_input[7:].strip()
+                    if query:
+                        try:
+                            results = search(query, return_type=str)
+                            formatted_results = format_search_results(results)
+                            logger.info(formatted_results)
+
+                            # Add to search history
+                            search_entry = {
+                                "turn": len(self.history) // 2,
+                                "searches": [{"query": query, "results": results}],
+                            }
+                            self.search_history.append(search_entry)
+                        except Exception as e:
+                            logger.error(f"Error searching: {e}")
+                    else:
+                        logger.warning("Please provide a search query.")
+                    continue
+
+                # Process as a prompt to the model
+                logger.info("\nGenerating response...")
+                response = self.generate(user_input)
+                logger.info("\n----- Response -----")
+                logger.info(response)
+
+            except KeyboardInterrupt:
+                logger.info("\nExiting...")
+                break
+            except Exception as e:
+                logger.error(f"Error: {e}")
+
+
+def extract_json_objects(text):
+    """
+    Extracts JSON objects (dictionaries) from a text that may contain multiple JSON objects.
+
+    Args:
+        text (str): The input text possibly containing JSON objects.
+
+    Returns:
+        list: A list of parsed JSON objects (dictionaries) extracted from the text.
+    """
+    results = []
+    length = len(text)
+    i = 0
+
+    while i < length:
+        # Look for the start of a JSON object
+        if text[i] == "{":
+            start = i
+            stack = 1
+            i += 1
+            # Continue until we find the matching closing brace
+            while i < length and stack > 0:
+                if text[i] == "{":
+                    stack += 1
+                elif text[i] == "}":
+                    stack -= 1
+                i += 1
+            # Only attempt to decode if the braces are balanced
+            if stack == 0:
+                candidate = text[start:i]
+                try:
+                    obj = json.loads(candidate)
+                    # Optionally, ensure it's a dictionary if that's what you expect
+                    if isinstance(obj, dict):
+                        results.append(obj)
+                except json.JSONDecodeError:
+                    # If it's not valid JSON, skip it.
+                    pass
+        else:
+            i += 1
+    return results
+
+
+# Tool definition for search corpus
+SEARCH_TOOL_DEFINITION = {
+    "type": "function",
+    "function": {
+        "name": "search_corpus",
+        "description": "Search over the knowledge corpus with a given query",
+        "parameters": {
+            "type": "object",
+            "properties": {
+                "query": {
+                    "type": "string",
+                    "description": "The query to search the knowledge corpus with",
+                },
+            },
+            "required": ["query"],
+        },
+    },
+}
+
+
+def main():
+    """Main function."""
+    parser = argparse.ArgumentParser(description="DeepSearch CLI")
+    parser.add_argument(
+        "--lora_path",
+        type=str,
+        default="auto",
+        help="Path to LoRA weights (None for base model, 'auto' for auto-detection)",
+    )
+    parser.add_argument(
+        "--temperature",
+        type=float,
+        default=0.7,
+        help="Sampling temperature (default: 0.7)",
+    )
+    parser.add_argument(
+        "--system_prompt",
+        type=str,
+        default=None,
+        help="System prompt to guide model behavior",
+    )
+    args = parser.parse_args()
+
+    # Auto-detect LoRA path if requested
+    lora_path = None
+    if args.lora_path and args.lora_path.lower() != "none":
+        if args.lora_path == "auto":
+            detected_path = find_latest_checkpoint()
+            if detected_path:
+                lora_path = detected_path
+                logger.info(f"Auto-detected LoRA path: {lora_path}")
+            else:
+                logger.warning("No LoRA checkpoint found. Using base model.")
+        else:
+            lora_path = args.lora_path
+
+    # Initialize and run the CLI
+    cli = DeepSearchCLI(
+        lora_path=lora_path,
+        temperature=args.temperature,
+        system_prompt=args.system_prompt,
+    )
+    cli.run()
+
+
+if __name__ == "__main__":
+    # Ensure the vectorstore is loaded
+    if load_vectorstore() is None:
+        logger.warning("FAISS vectorstore could not be loaded. Search functionality may not work.")
+
+    main()
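
For reference, a quick usage sketch. The flags are the ones defined in main(); the adapter path below is only a hypothetical example, not a path that ships with the repo.

    python inference.py --lora_path auto --temperature 0.7
    python inference.py --lora_path outputs/checkpoint-100 --system_prompt "Answer briefly."

Inside the CLI, the user prompt asks the model to emit a search_corpus call as a JSON object; extract_json_objects() picks it out of the assistant message and _run_tool_calls() reads function -> parameters -> query from it, so a conforming call looks like:

    {"function": {"name": "search_corpus", "parameters": {"query": "example question about the corpus"}}}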