diff --git a/inference.py b/inference.py
index a2aab1a..85ec4cd 100644
--- a/inference.py
+++ b/inference.py
@@ -1,127 +1,36 @@
 """
 Simple CLI inference script with search functionality.

-This script allows interaction with a model (with optional LoRA adapter)
+This script allows interaction with the merged 16-bit model
 and provides search functionality for data retrieval.
 """

 import argparse
 import json
 import os
-import re
 import time
 from datetime import datetime
 from typing import Any, Dict, List, Optional, Union

-from unsloth import FastLanguageModel
+from transformers import AutoModelForCausalLM, AutoTokenizer
 from vllm import SamplingParams

-from src.config import MODEL_NAME, OUTPUT_DIR, logger
 from src.search_module import load_vectorstore, search


-def find_latest_checkpoint(search_dir=None):
-    """
-    Find the latest checkpoint in the specified directory or OUTPUT_DIR.
-
-    Args:
-        search_dir: Directory to search for checkpoints (default: OUTPUT_DIR)
+def setup_model_and_tokenizer(model_path: str):
+    """Initialize model and tokenizer."""
+    print(f"Setting up model from {model_path}...")

-    Returns:
-        Path to the latest checkpoint or None if no checkpoints found
-    """
-    if search_dir is None:
-        search_dir = "trainer_output_meta-llama_Llama-3.1-8B-Instruct_gpu1_20250326_134236"
-        logger.info(f"No search directory provided, using default: {search_dir}")
-    else:
-        logger.info(f"Searching for checkpoints in: {search_dir}")
-
-    # Check if the directory exists first
-    if not os.path.exists(search_dir):
-        logger.warning(f"Search directory {search_dir} does not exist")
-        return None
-
-    # First try to find checkpoints in the format checkpoint-{step}
-    import glob
-
-    checkpoints = glob.glob(os.path.join(search_dir, "checkpoint-*"))
-
-    if checkpoints:
-        # Extract checkpoint numbers and sort
-        checkpoint_numbers = []
-        for checkpoint in checkpoints:
-            match = re.search(r"checkpoint-(\d+)$", checkpoint)
-            if match:
-                checkpoint_numbers.append((int(match.group(1)), checkpoint))
-
-        if checkpoint_numbers:
-            # Sort by checkpoint number (descending)
-            checkpoint_numbers.sort(reverse=True)
-            latest = checkpoint_numbers[0][1]
-            logger.info(f"Found latest checkpoint: {latest}")
-            return latest
-
-    # If no checkpoints found, look for saved_adapter_{timestamp}.bin files
-    adapter_files = glob.glob(os.path.join(search_dir, "saved_adapter_*.bin"))
-    if adapter_files:
-        # Sort by modification time (newest first)
-        adapter_files.sort(key=os.path.getmtime, reverse=True)
-        latest = adapter_files[0]
-        logger.info(f"Found latest adapter file: {latest}")
-        return latest
-
-    # If all else fails, look for any .bin files
-    bin_files = glob.glob(os.path.join(search_dir, "*.bin"))
-    if bin_files:
-        # Sort by modification time (newest first)
-        bin_files.sort(key=os.path.getmtime, reverse=True)
-        latest = bin_files[0]
-        logger.info(f"Found latest .bin file: {latest}")
-        return latest
-
-    logger.warning(f"No checkpoints found in {search_dir}")
-    return None
-
-
-def setup_model_and_tokenizer():
-    """Initialize model and tokenizer with LoRA support."""
-    config = {
-        "max_seq_length": 4096 * 2,
-        "lora_rank": 64,
-        "gpu_memory_utilization": 0.6,
-        "model_name": MODEL_NAME,
-        "target_modules": [
-            "q_proj",
-            "k_proj",
-            "v_proj",
-            "o_proj",
-            "gate_proj",
-            "up_proj",
-            "down_proj",
-        ],
-    }
-
-    logger.info(f"Setting up model {config['model_name']} with LoRA support...")
-    model, tokenizer = FastLanguageModel.from_pretrained(
-        model_name=config["model_name"],
-        max_seq_length=config["max_seq_length"],
-        load_in_4bit=True,
-        fast_inference=True,
-        max_lora_rank=config["lora_rank"],
-        gpu_memory_utilization=config["gpu_memory_utilization"],
+    model = AutoModelForCausalLM.from_pretrained(
+        model_path,
+        torch_dtype="float16",
+        device_map="auto",
+        trust_remote_code=True,
     )
+    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

-    # Setup LoRA
-    model = FastLanguageModel.get_peft_model(
-        model,
-        r=config["lora_rank"],
-        target_modules=config["target_modules"],
-        lora_alpha=config["lora_rank"],
-        use_gradient_checkpointing=True,
-        random_state=3407,
-    )
-
-    logger.info("Model and tokenizer setup complete.")
+    print("Model and tokenizer setup complete.")
     return model, tokenizer


@@ -205,7 +114,7 @@ class DeepSearchCLI:

     def __init__(
         self,
-        lora_path: Optional[str] = None,
+        model_path: str,
         temperature: float = 0.7,
         system_prompt: Optional[str] = None,
     ):
@@ -213,15 +122,13 @@ class DeepSearchCLI:
         Initialize the CLI.

         Args:
-            lora_path: Path to LoRA weights (None for base model)
+            model_path: Path to the merged 16-bit model
             temperature: Sampling temperature
             system_prompt: Optional system prompt to guide the model's behavior
         """
-        self.model, self.tokenizer = setup_model_and_tokenizer()
-        self.lora_path = lora_path
+        self.model, self.tokenizer = setup_model_and_tokenizer(model_path)
        self.temperature = temperature
         self.sampling_params = get_sampling_params(temperature)
-        self.lora_request = None
         self.history = []
         self.search_history = []
         self.system_prompt = (
@@ -234,14 +141,34 @@ When you receive a tool call response, use the output to format an answer to the
 You are a helpful assistant with tool calling capabilities."""
         )

-        # Load LoRA if specified
-        if self.lora_path:
-            logger.info(f"Loading LoRA adapter from {self.lora_path}...")
-            self.lora_request = self.model.load_lora(self.lora_path)
-            if self.lora_request:
-                logger.info(f"LoRA adapter loaded successfully: {self.lora_request}")
-            else:
-                logger.error("Failed to load LoRA adapter")
+    def _run_agent_generation(self, chat_state: dict) -> dict:
+        """Run a single generation step for the agent."""
+        formatted_prompt = self.tokenizer.apply_chat_template(
+            chat_state["messages"],
+            tokenize=False,
+            add_generation_prompt=True,
+        )
+
+        start_time = time.time()
+        inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.model.device)
+        outputs = self.model.generate(
+            **inputs,
+            max_new_tokens=self.sampling_params.max_tokens,
+            temperature=self.sampling_params.temperature,
+            top_p=self.sampling_params.top_p,
+            do_sample=True,
+        )
+        response_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+        gen_time = time.time() - start_time
+        print(f"Generation completed in {gen_time:.2f} seconds")
+
+        # Extract assistant response
+        assistant_response = response_text.split("<|start_header_id|>assistant<|end_header_id|>")[-1]
+
+        chat_state["messages"].append({"role": "assistant", "content": assistant_response})
+
+        return chat_state

     def generate(self, prompt: str, max_generations: int = 20) -> str:
         """
@@ -285,42 +212,6 @@ You are a helpful assistant with tool calling capabilities."""

         return final_response

-    def _run_agent_generation(self, chat_state: dict) -> dict:
-        """Run a single generation step for the agent."""
-        formatted_prompt = self.tokenizer.apply_chat_template(
-            chat_state["messages"],
-            tokenize=False,
-            add_generation_prompt=True,
-        )
-
-        start_time = time.time()
-        if self.lora_request:
-            response = self.model.fast_generate(
-                [formatted_prompt],
-                sampling_params=self.sampling_params,
-                lora_request=self.lora_request,
-            )
-        else:
-            response = self.model.fast_generate(
-                [formatted_prompt],
-                sampling_params=self.sampling_params,
-            )
-
-        gen_time = time.time() - start_time
-        logger.debug(f"Generation completed in {gen_time:.2f} seconds")
-
-        if hasattr(response[0], "outputs"):
-            response_text = response[0].outputs[0].text
-        else:
-            response_text = response[0]
-
-        # Extract assistant response
-        assistant_response = response_text.split("<|start_header_id|>assistant<|end_header_id|>")[-1]
-
-        chat_state["messages"].append({"role": "assistant", "content": assistant_response})
-
-        return chat_state
-
     def _check_finished_chat(self, chat_state: dict) -> dict:
         """Check if the chat is finished (no more function calls)."""
         if chat_state.get("finished"):
@@ -329,7 +220,7 @@ You are a helpful assistant with tool calling capabilities."""
         assert chat_state["messages"][-1]["role"] == "assistant", "Expected the last role to be assistant"

         assistant_response = chat_state["messages"][-1]["content"]
-        function_calls = extract_json_objects(assistant_response)
+        function_calls = extract_function_calls(assistant_response)

         if len(function_calls) == 0:
             chat_state["finished"] = True
@@ -343,16 +234,16 @@ You are a helpful assistant with tool calling capabilities."""

         try:
             assistant_response = chat_state["messages"][-1]["content"]
-            function_calls = extract_json_objects(assistant_response)
+            function_calls = extract_function_calls(assistant_response)

             if len(function_calls) > 1:
-                logger.warning("Multiple function calls found in assistant response")
+                print("Multiple function calls found in assistant response")
                 raise ValueError("Expected only one function call in assistant response")
             elif len(function_calls) == 1:
                 function_call = function_calls[0]
                 query = function_call["function"]["parameters"]["query"]
-                logger.info(f"šŸ” Search Query: {query}")
+                print(f"šŸ” Search Query: {query}")

                 results = search(query, return_type=str, results=2)

@@ -373,7 +264,7 @@ You are a helpful assistant with tool calling capabilities."""

                 self.search_history.append(search_entry)
         except Exception as e:
-            logger.error(f"Error during tool call: {str(e)}")
+            print(f"Error during tool call: {str(e)}")
             chat_state["messages"].append({"role": "system", "content": f"Error during post-processing: {str(e)}"})
             chat_state["finished"] = True

@@ -383,7 +274,7 @@ You are a helpful assistant with tool calling capabilities."""
         """Clear the conversation history."""
         self.history = []
         self.search_history = []
-        logger.info("Conversation history cleared.")
+        print("Conversation history cleared.")

     def set_system_prompt(self, prompt: str):
         """
@@ -393,35 +284,32 @@ You are a helpful assistant with tool calling capabilities."""
             prompt: The new system prompt
         """
         if not prompt:
-            logger.warning("System prompt cannot be empty. Using default.")
+            print("System prompt cannot be empty. Using default.")
             return

         self.system_prompt = prompt
-        logger.info("System prompt updated.")
-        logger.info(f"New system prompt: {self.system_prompt}")
+        print("System prompt updated.")
+        print(f"New system prompt: {self.system_prompt}")

     def display_welcome(self):
         """Display welcome message."""
-        model_type = "LoRA" if self.lora_path else "Base"
-        logger.info(f"\n{'=' * 50}")
-        logger.info(f"DeepSearch CLI - {model_type} Model")
-        logger.info(f"Model: {MODEL_NAME}")
-        logger.info(f"Temperature: {self.temperature}")
-        if self.lora_path:
-            logger.info(f"LoRA Path: {self.lora_path}")
-        logger.info(f"System Prompt: {self.system_prompt}")
-        logger.info(f"{'=' * 50}")
-        logger.info("Type 'help' to see available commands.")
+        print(f"\n{'=' * 50}")
+        print(f"DeepSearch CLI - {self.model.name_or_path}")
+        print(f"Model: {self.model.name_or_path}")
+        print(f"Temperature: {self.temperature}")
+        print(f"System Prompt: {self.system_prompt}")
+        print(f"{'=' * 50}")
+        print("Type 'help' to see available commands.")

     def print_pretty_chat_history(self):
         """Print the full chat history in a pretty format, including searches."""
         if not self.history:
-            logger.info("No chat history available.")
+            print("No chat history available.")
             return

-        logger.info("\n" + "=" * 80)
-        logger.info("CHAT HISTORY WITH SEARCH DETAILS")
-        logger.info("=" * 80)
+        print("\n" + "=" * 80)
+        print("CHAT HISTORY WITH SEARCH DETAILS")
+        print("=" * 80)

         # Group history into conversation turns
         for i in range(0, len(self.history), 2):
@@ -430,26 +318,26 @@ You are a helpful assistant with tool calling capabilities."""
             # Print user message
             if i < len(self.history):
                 user_msg = self.history[i]["content"]
-                logger.info(f"\n[Turn {turn_number + 1}] USER: ")
-                logger.info("-" * 40)
-                logger.info(user_msg)
+                print(f"\n[Turn {turn_number + 1}] USER: ")
+                print("-" * 40)
+                print(user_msg)

             # Print searches associated with this turn if any
             for search_entry in self.search_history:
                 if search_entry["turn"] == turn_number:
                     for idx, search in enumerate(search_entry["searches"]):
-                        logger.info(f'\nšŸ” SEARCH {idx + 1}: "{search["query"]}"')
-                        logger.info("-" * 40)
-                        logger.info(search["results"])
+                        print(f'\nšŸ” SEARCH {idx + 1}: "{search["query"]}"')
+                        print("-" * 40)
+                        print(search["results"])

             # Print assistant response
             if i + 1 < len(self.history):
                 assistant_msg = self.history[i + 1]["content"]
-                logger.info(f"\n[Turn {turn_number + 1}] ASSISTANT: ")
-                logger.info("-" * 40)
-                logger.info(assistant_msg)
+                print(f"\n[Turn {turn_number + 1}] ASSISTANT: ")
+                print("-" * 40)
+                print(assistant_msg)

-        logger.info("\n" + "=" * 80 + "\n")
+        print("\n" + "=" * 80 + "\n")

     def save_chat_history(self, filepath=None):
         """
@@ -462,14 +350,13 @@ You are a helpful assistant with tool calling capabilities."""
             Path to the saved file
         """
         if not self.history:
-            logger.info("No chat history to save.")
+            print("No chat history to save.")
             return None

         # Generate a default filepath if none provided
         if filepath is None:
             timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-            model_type = "lora" if self.lora_path else "base"
-            filepath = os.path.join(OUTPUT_DIR, f"chat_history_{model_type}_{timestamp}.txt")
+            filepath = os.path.join(os.getcwd(), f"chat_history_{timestamp}.txt")

         # Ensure the directory exists
         os.makedirs(os.path.dirname(filepath), exist_ok=True)
@@ -499,8 +386,7 @@ You are a helpful assistant with tool calling capabilities."""
             with open(filepath, "w", encoding="utf-8") as f:
                 f.write(f"{'=' * 80}\n")
                 f.write("DEEPSEARCH CHAT HISTORY\n")
-                f.write(f"Model: {MODEL_NAME}\n")
-                f.write(f"LoRA Path: {self.lora_path if self.lora_path else 'None'}\n")
+                f.write(f"Model: {self.model.name_or_path}\n")
                 f.write(f"Temperature: {self.temperature}\n")
                 f.write(f"Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
                 f.write(f"{'=' * 80}\n\n")
@@ -521,11 +407,11 @@ You are a helpful assistant with tool calling capabilities."""
                     f.write(f"{turn['assistant']}\n\n")
                     f.write(f"{'=' * 40}\n\n")

-            logger.info(f"Chat history saved to: {filepath}")
+            print(f"Chat history saved to: {filepath}")
             return filepath

         except Exception as e:
-            logger.error(f"Error saving chat history: {e}")
+            print(f"Error saving chat history: {e}")
             return None

     def save_chat_history_json(self, filepath=None):
         """
@@ -539,22 +425,20 @@ You are a helpful assistant with tool calling capabilities."""
             Path to the saved file
         """
         if not self.history:
-            logger.info("No chat history to save.")
+            print("No chat history to save.")
             return None

         # Generate a default filepath if none provided
         if filepath is None:
             timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-            model_type = "lora" if self.lora_path else "base"
-            filepath = os.path.join(OUTPUT_DIR, f"chat_history_{model_type}_{timestamp}.json")
+            filepath = os.path.join(os.getcwd(), f"chat_history_{timestamp}.json")

         # Ensure the directory exists
         os.makedirs(os.path.dirname(filepath), exist_ok=True)

         # Prepare chat history data
         history_data = {
-            "model": MODEL_NAME,
-            "lora_path": self.lora_path if self.lora_path else None,
+            "model": self.model.name_or_path,
             "temperature": self.temperature,
             "timestamp": datetime.now().isoformat(),
             "turns": [],
         }
@@ -582,26 +466,26 @@ You are a helpful assistant with tool calling capabilities."""
             with open(filepath, "w", encoding="utf-8") as f:
                 json.dump(history_data, f, indent=2, ensure_ascii=False)

-            logger.info(f"Chat history saved to JSON: {filepath}")
+            print(f"Chat history saved to JSON: {filepath}")
             return filepath

         except Exception as e:
-            logger.error(f"Error saving chat history to JSON: {e}")
+            print(f"Error saving chat history to JSON: {e}")
             return None

     def display_help(self):
         """Display help information."""
-        logger.info("\n===== Commands =====")
-        logger.info("search - Search for information")
-        logger.info("system - Set a new system prompt")
-        logger.info("clear - Clear conversation history")
-        logger.info("history - Display full chat history with searches")
-        logger.info("save - Save chat history to a text file")
-        logger.info("savejson - Save chat history to a JSON file")
-        logger.info("help - Display this help message")
-        logger.info("exit/quit - Exit the program")
-        logger.info("Any other input will be treated as a prompt to the model.")
-        logger.info("===================\n")
+        print("\n===== Commands =====")
+        print("search - Search for information")
+        print("system - Set a new system prompt")
+        print("clear - Clear conversation history")
+        print("history - Display full chat history with searches")
+        print("save - Save chat history to a text file")
+        print("savejson - Save chat history to a JSON file")
+        print("help - Display this help message")
+        print("exit/quit - Exit the program")
+        print("Any other input will be treated as a prompt to the model.")
+        print("===================\n")

     def run(self):
         """Run the CLI."""
@@ -615,7 +499,7 @@ You are a helpful assistant with tool calling capabilities."""
                     continue

                 if user_input.lower() in ["exit", "quit"]:
-                    logger.info("Exiting...")
+                    print("Exiting...")
                     break

                 if user_input.lower() == "help":
@@ -649,7 +533,7 @@ You are a helpful assistant with tool calling capabilities."""
                        try:
                            results = search(query, return_type=str)
                            formatted_results = format_search_results(results)
-                           logger.info(formatted_results)
+                           print(formatted_results)

                            # Add to search history
                            search_entry = {
@@ -658,22 +542,22 @@ You are a helpful assistant with tool calling capabilities."""
                            }
                            self.search_history.append(search_entry)
                        except Exception as e:
-                           logger.error(f"Error searching: {e}")
+                           print(f"Error searching: {e}")
                    else:
-                       logger.warning("Please provide a search query.")
+                       print("Please provide a search query.")
                    continue

                # Process as a prompt to the model
-               logger.info("\nGenerating response...")
+               print("\nGenerating response...")
                response = self.generate(user_input)

-               logger.info("\n----- Response -----")
-               logger.info(response)
+               print("\n----- Response -----")
+               print(response)

            except KeyboardInterrupt:
-               logger.info("\nExiting...")
+               print("\nExiting...")
                break
            except Exception as e:
-               logger.error(f"Error: {e}")
+               print(f"Error: {e}")


 def extract_json_objects(text):
@@ -743,10 +627,10 @@ def main():
     """Main function."""
     parser = argparse.ArgumentParser(description="DeepSearch CLI")
     parser.add_argument(
-        "--lora_path",
+        "--model_path",
         type=str,
-        default="auto",
-        help="Path to LoRA weights (None for base model, 'auto' for auto-detection)",
+        default="trainer_output_example/model_merged_16bit",
+        help="Path to the merged 16-bit model (default: trainer_output_example/model_merged_16bit)",
     )
     parser.add_argument(
         "--temperature",
@@ -762,22 +646,9 @@ You are a helpful assistant with tool calling capabilities."""
     )
     args = parser.parse_args()

-    # Auto-detect LoRA path if requested
-    lora_path = None
-    if args.lora_path and args.lora_path.lower() != "none":
-        if args.lora_path == "auto":
-            detected_path = find_latest_checkpoint()
-            if detected_path:
-                lora_path = detected_path
-                logger.info(f"Auto-detected LoRA path: {lora_path}")
-            else:
-                logger.warning("No LoRA checkpoint found. Using base model.")
-        else:
-            lora_path = args.lora_path
-
     # Initialize and run the CLI
     cli = DeepSearchCLI(
-        lora_path=lora_path,
+        model_path=args.model_path,
         temperature=args.temperature,
         system_prompt=args.system_prompt,
     )
@@ -787,6 +658,6 @@ if __name__ == "__main__":

     # Ensure the vectorstore is loaded
     if load_vectorstore() is None:
-        logger.warning("FAISS vectorstore could not be loaded. Search functionality may not work.")
+        print("FAISS vectorstore could not be loaded. Search functionality may not work.")

     main()
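
A note on the response extraction in the new _run_agent_generation: the output is decoded with skip_special_tokens=True and then split on the <|start_header_id|>assistant<|end_header_id|> marker. Since that marker is made of special tokens, it is typically stripped by the decode step, so the split can fall back to returning the full prompt-plus-completion text. A minimal sketch of one alternative, not part of this patch, that sidesteps the marker by decoding only the tokens produced after the prompt:

    # Sketch only: slice off the prompt tokens before decoding, so no
    # special-token marker is needed to isolate the assistant turn.
    prompt_len = inputs["input_ids"].shape[1]
    new_token_ids = outputs[0][prompt_len:]
    assistant_response = self.tokenizer.decode(new_token_ids, skip_special_tokens=True)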
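
For a quick non-interactive check of the merged checkpoint, the class added in this patch can also be driven directly from Python. The sketch below assumes only what the diff defines (the DeepSearchCLI constructor arguments and generate()); the model path matches the argparse default and the prompt string is an arbitrary example:

    # Sketch: exercise DeepSearchCLI without the interactive run() loop.
    from inference import DeepSearchCLI

    cli = DeepSearchCLI(model_path="trainer_output_example/model_merged_16bit", temperature=0.7)
    print(cli.generate("What does the search tool return for a simple query?"))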