diff --git a/eval.py b/eval.py deleted file mode 100644 index 25d02f5..0000000 --- a/eval.py +++ /dev/null @@ -1,380 +0,0 @@ -""" -Evaluate model performance using vLLM and unsloth. - -This script evaluates the performance of a model using vLLM for fast inference -and unsloth for LoRA support. -""" - -import argparse -import os -import time -from datetime import datetime - -from unsloth import FastLanguageModel -from vllm import SamplingParams - -from src import ( - apply_chat_template, - build_reward_correctness_fn, - build_user_prompt, - get_qa_dataset, - get_system_prompt, - run_eval, -) -from config import MODEL_NAME, logger - - -def get_model_config(): - """Get model configuration.""" - return { - "max_seq_length": 4096 * 2, - "lora_rank": 64, - "gpu_memory_utilization": 0.6, - "model_name": MODEL_NAME, - "target_modules": [ - "q_proj", - "k_proj", - "v_proj", - "o_proj", - "gate_proj", - "up_proj", - "down_proj", - ], - } - - -def get_sampling_params(temperature=0.5): - """Get sampling parameters for generation.""" - return SamplingParams( - temperature=temperature, - top_p=0.95, - max_tokens=4096, - ) - - -def setup_model_and_tokenizer(): - """Initialize model and tokenizer with LoRA support.""" - config = get_model_config() - logger.info(f"Setting up model {config['model_name']} with LoRA support...") - model, tokenizer = FastLanguageModel.from_pretrained( - model_name=config["model_name"], - max_seq_length=config["max_seq_length"], - load_in_4bit=True, - fast_inference=True, - max_lora_rank=config["lora_rank"], - gpu_memory_utilization=config["gpu_memory_utilization"], - ) - - # Setup LoRA - model = FastLanguageModel.get_peft_model( - model, - r=config["lora_rank"], - target_modules=config["target_modules"], - lora_alpha=config["lora_rank"], - use_gradient_checkpointing=True, - random_state=3407, - ) - - logger.info("Model and tokenizer setup complete.") - return model, tokenizer - - -def evaluate_model( - model, - tokenizer, - lora_path=None, - temperature=0.5, - output_file="eval_results.txt", - trainer_dir=None, -): - """ - Evaluate model with or without LoRA weights. - - Args: - model: The model to evaluate - tokenizer: The tokenizer - lora_path: Path to LoRA weights (None for base model) - temperature: Sampling temperature - output_file: File to write results to - trainer_dir: Directory containing the checkpoints - """ - sampling_params = get_sampling_params(temperature=temperature) - - # Set up output directory - if trainer_dir: - eval_log_dir = os.path.join(trainer_dir, "eval_logs") - else: - eval_log_dir = "eval_logs" - os.makedirs(eval_log_dir, exist_ok=True) - - # Create file names based on model type - model_prefix = "lora" if lora_path else "base" - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - - # Define all output file paths - eval_log_file = os.path.join(eval_log_dir, f"{model_prefix}_model_eval_{timestamp}.log") - output_file = os.path.join(eval_log_dir, f"{model_prefix}_model_results.txt") - debug_file = os.path.join(eval_log_dir, f"{model_prefix}_model_results_debug.json") - - logger.info(f"Writing evaluation log to: {eval_log_file}") - logger.info(f"Results will be saved to: {output_file}") - - # Function to generate completions using agentic approach - def eval_generate_fn(inputs): - start_time = time.time() - - # Format inputs as chat messages with system prompt - messages = [ - { - "messages": [ - {"role": "system", "content": get_system_prompt()}, - {"role": "user", "content": build_user_prompt(input_text)}, - ] - } - for input_text in inputs - ] - - if lora_path: - lora_request = model.load_lora(lora_path) - load_time = time.time() - start_time - logger.info(f"LoRA adapter loaded in {load_time:.2f} seconds: {lora_request}") - responses = model.fast_generate( - [apply_chat_template(msg, tokenizer=tokenizer)["text"] for msg in messages], - sampling_params=sampling_params, - lora_request=lora_request, - ) - else: - responses = model.fast_generate( - [apply_chat_template(msg, tokenizer=tokenizer)["text"] for msg in messages], - sampling_params=sampling_params, - ) - - gen_time = time.time() - start_time - logger.debug(f"Generation completed in {gen_time:.2f} seconds") - return responses - - def verifier_generate_fn(inputs): - # Use a lower temperature for verification to get more consistent results - verifier_params = get_sampling_params(temperature=0.1) - - # Format inputs as chat messages with system prompt - messages = [ - { - "messages": [ - {"role": "system", "content": get_system_prompt()}, - {"role": "user", "content": build_user_prompt(input_text)}, - ] - } - for input_text in inputs - ] - - return model.fast_generate( - [apply_chat_template(msg, tokenizer=tokenizer)["text"] for msg in messages], - sampling_params=verifier_params, - ) - - # Prepare the verification function - verify_fn = build_reward_correctness_fn(verifier_generate_fn, tokenizer) - - # Get the dataset and prepare questions and answers - train_dataset, test_dataset = get_qa_dataset() - questions = test_dataset["prompt"] - inputs = questions - - logger.info(f"Verifying {len(inputs)} answers...") - - # Run the evaluation - start_time = time.time() - model_type = "LoRA" if lora_path else "Base" - logger.info(f"Starting {model_type} model evaluation...") - - # Run evaluation using the agentic approach - full_chat_states = run_eval( - generate_fn=eval_generate_fn, - verify_fn=verify_fn, - tokenizer=tokenizer, - output_file=output_file, - debug_file=debug_file, - ) - - # Calculate rewards - logger.info(f"Calculating rewards for {model_type} model...") - rewards = verify_fn(questions, full_chat_states, answer=test_dataset["answer"]) - avg_reward = sum(rewards) / len(rewards) - total_time = time.time() - start_time - - # Record the results - results = { - "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), - "model_type": model_type, - "model_name": MODEL_NAME, - "lora_path": lora_path if lora_path else "None", - "accuracy": avg_reward, - "correct_count": sum(rewards), - "total_count": len(rewards), - "temperature": temperature, - "time_taken": total_time, - } - - # Add more detailed output to log file - logger.info(f"\n{'=' * 50}") - logger.info(f"{model_type.upper()} MODEL EVALUATION RESULTS:") - logger.info(f"{'=' * 50}") - logger.info(f"Accuracy: {avg_reward:.4f} ({sum(rewards)}/{len(rewards)} correct)") - logger.info(f"Temperature: {temperature}") - logger.info(f"Time taken: {total_time:.2f} seconds") - logger.info(f"Results file: {output_file}") - logger.info(f"Debug file: {debug_file}") - logger.info(f"Log file: {eval_log_file}") - - # Write a summary to the log file too - with open(eval_log_file, "a") as f: - f.write(f"\n{'=' * 50}\n") - f.write(f"{model_type.upper()} MODEL EVALUATION SUMMARY\n") - f.write(f"{'=' * 50}\n") - f.write(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n") - f.write(f"Accuracy: {avg_reward:.4f} ({sum(rewards)}/{len(rewards)} correct)\n") - f.write(f"Temperature: {temperature}\n") - f.write(f"Time taken: {total_time:.2f} seconds\n") - f.write(f"Results saved to: {output_file}\n") - f.write(f"Debug data saved to: {debug_file}\n\n") - - logger.info(f"Evaluation completed. Results saved to {output_file} and {debug_file}") - - return results - - -def compare_models(lora_path, temperature=0.5, output_file=None, trainer_dir=None): - """ - Compare base model with LoRA model. - - Args: - lora_path: Path to LoRA weights - temperature: Sampling temperature - output_file: File to write results to (optional) - trainer_dir: Directory containing the trainer output - """ - # Set up output directory - if trainer_dir: - eval_log_dir = os.path.join(trainer_dir, "eval_logs") - else: - eval_log_dir = "eval_logs" - os.makedirs(eval_log_dir, exist_ok=True) - - # Define the comparison file path if not provided - if output_file is None: - output_file = os.path.join(eval_log_dir, "model_comparison_results.txt") - - # Define file paths for individual model results - base_output = os.path.join(eval_log_dir, "base_model_results.txt") - lora_output = os.path.join(eval_log_dir, "lora_model_results.txt") - - model, tokenizer = setup_model_and_tokenizer() - - # Evaluate both models - base_results = evaluate_model( - model, - tokenizer, - lora_path=None, - temperature=temperature, - output_file=base_output, - trainer_dir=trainer_dir, - ) - - lora_results = evaluate_model( - model, - tokenizer, - lora_path=lora_path, - temperature=temperature, - output_file=lora_output, - trainer_dir=trainer_dir, - ) - - # Calculate improvement - improvement = lora_results["accuracy"] - base_results["accuracy"] - - # Write comparison results - with open(output_file, "w") as f: - f.write("MODEL COMPARISON RESULTS\n") - f.write("======================\n\n") - f.write(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n") - f.write(f"Base Model: {MODEL_NAME}\n") - f.write(f"LoRA Path: {lora_path}\n\n") - f.write(f"Base Model Accuracy: {base_results['accuracy']:.4f}\n") - f.write(f"LoRA Model Accuracy: {lora_results['accuracy']:.4f}\n") - f.write(f"Improvement: {improvement:.4f}\n") - f.write(f"Temperature: {temperature}\n") - f.write(f"Base Model Time: {base_results['time_taken']:.2f}s\n") - f.write(f"LoRA Model Time: {lora_results['time_taken']:.2f}s\n\n") - f.write(f"Base Model Results File: {base_output}\n") - f.write(f"LoRA Model Results File: {lora_output}\n") - - logger.info("\nModel comparison completed.") - logger.info(f"\n{'=' * 50}") - logger.info("MODEL COMPARISON RESULTS:") - logger.info(f"{'=' * 50}") - logger.info(f"Base Model Accuracy: {base_results['accuracy']:.4f}") - logger.info(f"LoRA Model Accuracy: {lora_results['accuracy']:.4f}") - logger.info(f"Improvement: {improvement:.4f}") - logger.info(f"Temperature: {temperature}") - logger.info(f"Results written to: {output_file}") - logger.info(f"Base Model Results: {base_output}") - logger.info(f"LoRA Model Results: {lora_output}") - logger.info(f"{'=' * 50}") - - return { - "base_accuracy": base_results["accuracy"], - "lora_accuracy": lora_results["accuracy"], - "improvement": improvement, - "output_file": output_file, - "base_output": base_output, - "lora_output": lora_output, - } - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Evaluate and compare models") - parser.add_argument( - "--lora_path", - type=str, - default="trainer_output_example/checkpoint-101", - help="Path to LoRA weights", - ) - parser.add_argument("--temperature", type=float, default=0.5, help="Sampling temperature") - parser.add_argument( - "--output_file", - type=str, - default=None, - help="File to write results to (optional)", - ) - parser.add_argument( - "--trainer_dir", - type=str, - default=None, - help="Directory containing the trainer output", - ) - args = parser.parse_args() - - logger.info(f"Starting model evaluation with temperature {args.temperature}") - results = compare_models(args.lora_path, args.temperature, args.output_file, trainer_dir=args.trainer_dir) - if results: - logger.info("Evaluation completed successfully") - logger.info(f"Final improvement: {results['improvement']:.4f}") - logger.info(f"Results saved to: {results['output_file']}") - - # Print all output files for clarity - logger.info("\nSUMMARY OF OUTPUT FILES:") - logger.info(f"Comparison results: {results['output_file']}") - logger.info(f"Base model results: {results['base_output']}") - logger.info(f"LoRA model results: {results['lora_output']}") - - # Find and print all log files in the eval_logs directory - eval_log_dir = os.path.join(args.trainer_dir, "eval_logs") if args.trainer_dir else "eval_logs" - if os.path.exists(eval_log_dir): - log_files = [f for f in os.listdir(eval_log_dir) if f.endswith(".log")] - if log_files: - logger.info("\nEVALUATION LOG FILES:") - for log_file in log_files: - logger.info(f"- {os.path.join(eval_log_dir, log_file)}") - else: - logger.warning("Evaluation failed or was skipped") diff --git a/inference.py b/inference.py deleted file mode 100644 index d9b7ec1..0000000 --- a/inference.py +++ /dev/null @@ -1,458 +0,0 @@ -""" -Simple CLI inference script with search functionality. - -This script allows interaction with the merged 16-bit model -and provides search functionality for data retrieval. -""" - -import argparse -import os -import time -from datetime import datetime - -from transformers import AutoModelForCausalLM, AutoTokenizer -from vllm import SamplingParams - -from src import ( - apply_chat_template, - build_user_prompt, - extract_search_query, - format_search_results, - get_system_prompt, -) -from src.search_module import load_vectorstore, search - - -def setup_model_and_tokenizer(model_path: str): - """Initialize model and tokenizer.""" - print(f"Setting up model from {model_path}...") - - model = AutoModelForCausalLM.from_pretrained( - model_path, - torch_dtype="float16", - device_map="auto", - trust_remote_code=True, - ) - tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) - - print("Model and tokenizer setup complete.") - return model, tokenizer - - -def get_sampling_params(temperature: float = 0.7, max_tokens: int = 4096) -> SamplingParams: - """Get sampling parameters for generation.""" - return SamplingParams( - temperature=temperature, - top_p=0.95, - max_tokens=max_tokens, - ) - - -class DeepSearchCLI: - """CLI for interacting with the model and search functionality.""" - - def __init__( - self, - model_path: str, - temperature: float = 0.7, - system_prompt: str | None = None, - ): - """ - Initialize the CLI. - - Args: - model_path: Path to the merged 16-bit model - temperature: Sampling temperature - system_prompt: Optional system prompt to guide the model's behavior - """ - self.model, self.tokenizer = setup_model_and_tokenizer(model_path) - self.temperature = temperature - self.sampling_params = get_sampling_params(temperature) - self.history = [] - self.search_history = [] - self.system_prompt = system_prompt or get_system_prompt() - - def _run_agent_generation(self, chat_state: dict) -> dict: - """Run a single generation step for the agent.""" - # Format the chat state using the same template as training - formatted_prompt = apply_chat_template(chat_state, tokenizer=self.tokenizer)["text"] - - start_time = time.time() - inputs = self.tokenizer(formatted_prompt, return_tensors="pt", add_special_tokens=False).to(self.model.device) - outputs = self.model.generate( - **inputs, - max_new_tokens=self.sampling_params.max_tokens, - temperature=self.sampling_params.temperature, - top_p=self.sampling_params.top_p, - do_sample=True, - ) - response_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True) - - gen_time = time.time() - start_time - print(f"Generation completed in {gen_time:.2f} seconds") - - # Extract assistant response - assistant_response = response_text.split("<|start_header_id|>assistant<|end_header_id|>")[-1] - - chat_state["messages"].append({"role": "assistant", "content": assistant_response}) - - return chat_state - - def generate(self, prompt: str, max_generations: int = 20) -> str: - """ - Generate a response to the prompt using agentic mechanism. - - Args: - prompt: The prompt to generate a response to - max_generations: Maximum number of turns in the conversation - - Returns: - The generated response after completing the conversation - """ - # Initialize chat state with the same structure as training - chat_state = { - "messages": [ - {"role": "system", "content": self.system_prompt}, - {"role": "user", "content": build_user_prompt(prompt)}, - ], - "finished": False, - } - - # Agent loop - for i in range(max_generations): - # Generate response - chat_state = self._run_agent_generation(chat_state) - - # Check if conversation is finished - chat_state = self._check_finished_chat(chat_state) - if chat_state.get("finished"): - break - - # Process tool calls if any - chat_state = self._run_tool_calls(chat_state) - - # Get final response - final_response = chat_state["messages"][-1]["content"] - - # Update history - self.history.append({"role": "user", "content": prompt}) - self.history.append({"role": "assistant", "content": final_response}) - - return final_response - - def _check_finished_chat(self, chat_state: dict) -> dict: - """Check if the chat is finished (no more search queries).""" - if chat_state.get("finished"): - return chat_state - - assert chat_state["messages"][-1]["role"] == "assistant", "Expected the last role to be assistant" - - assistant_response = chat_state["messages"][-1]["content"] - search_query = extract_search_query(assistant_response) - - if not search_query: - chat_state["finished"] = True - - return chat_state - - def _run_tool_calls(self, chat_state: dict) -> dict: - """Execute tool calls found in chat state.""" - if chat_state.get("finished"): - return chat_state - - try: - assistant_response = chat_state["messages"][-1]["content"] - search_query = extract_search_query(assistant_response) - - if search_query: - print(f"šŸ” Search Query: {search_query}") - - results = search(search_query, return_type=str, results=2) - # Wrap results in tags - formatted_results = f"{results}" - - # Print search results to terminal - print("\n===== SEARCH RESULTS =====") - print(results) - print("===========================\n") - - chat_state["messages"].append({"role": "ipython", "content": formatted_results}) - - # Record search in history - search_entry = { - "turn": len(self.history) // 2, - "searches": [{"query": search_query, "results": results}], - } - self.search_history.append(search_entry) - - except Exception as e: - print(f"Error during tool call: {str(e)}") - chat_state["messages"].append({"role": "system", "content": f"Error during post-processing: {str(e)}"}) - chat_state["finished"] = True - - return chat_state - - def clear_history(self): - """Clear the conversation history.""" - self.history = [] - self.search_history = [] - print("Conversation history cleared.") - - def set_system_prompt(self, prompt: str): - """ - Set a new system prompt. - - Args: - prompt: The new system prompt - """ - if not prompt: - print("System prompt cannot be empty. Using default.") - return - - self.system_prompt = prompt - print("System prompt updated.") - print(f"New system prompt: {self.system_prompt}") - - def display_welcome(self): - """Display welcome message.""" - print(f"\n{'=' * 50}") - print(f"DeepSearch CLI - {self.model.name_or_path}") - print(f"Model: {self.model.name_or_path}") - print(f"Temperature: {self.temperature}") - print(f"System Prompt: {self.system_prompt}") - print(f"{'=' * 50}") - print("Type 'help' to see available commands.") - - def print_pretty_chat_history(self): - """Print the full chat history in a pretty format, including searches.""" - if not self.history: - print("No chat history available.") - return - - print("\n" + "=" * 80) - print("CHAT HISTORY WITH SEARCH DETAILS") - print("=" * 80) - - # Group history into conversation turns - for i in range(0, len(self.history), 2): - turn_number = i // 2 - - # Print user message - if i < len(self.history): - user_msg = self.history[i]["content"] - print(f"\n[Turn {turn_number + 1}] USER: ") - print("-" * 40) - print(user_msg) - - # Print searches associated with this turn if any - for search_entry in self.search_history: - if search_entry["turn"] == turn_number: - for idx, search in enumerate(search_entry["searches"]): - print(f'\nšŸ” SEARCH {idx + 1}: "{search["query"]}"') - print("-" * 40) - print(search["results"]) - - # Print assistant response - if i + 1 < len(self.history): - assistant_msg = self.history[i + 1]["content"] - print(f"\n[Turn {turn_number + 1}] ASSISTANT: ") - print("-" * 40) - print(assistant_msg) - - print("\n" + "=" * 80 + "\n") - - def save_chat_history(self, filepath=None): - """ - Save chat history to a file. - - Args: - filepath: Path to save file (if None, auto-generate based on timestamp) - - Returns: - Path to the saved file - """ - if not self.history: - print("No chat history to save.") - return None - - # Generate a default filepath if none provided - if filepath is None: - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - filepath = os.path.join(os.getcwd(), f"chat_history_{timestamp}.txt") - - # Ensure the directory exists - os.makedirs(os.path.dirname(filepath), exist_ok=True) - - # Prepare chat history data - pretty_history = [] - - # Group history into conversation turns - for i in range(0, len(self.history), 2): - turn_number = i // 2 - turn_data = { - "turn": turn_number + 1, - "user": self.history[i]["content"] if i < len(self.history) else "", - "searches": [], - "assistant": self.history[i + 1]["content"] if i + 1 < len(self.history) else "", - } - - # Add searches for this turn - for search_entry in self.search_history: - if search_entry["turn"] == turn_number: - turn_data["searches"].extend(search_entry["searches"]) - - pretty_history.append(turn_data) - - # Write to file - try: - with open(filepath, "w", encoding="utf-8") as f: - f.write(f"{'=' * 80}\n") - f.write("DEEPSEARCH CHAT HISTORY\n") - f.write(f"Model: {self.model.name_or_path}\n") - f.write(f"Temperature: {self.temperature}\n") - f.write(f"Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n") - f.write(f"{'=' * 80}\n\n") - - for turn in pretty_history: - f.write(f"[Turn {turn['turn']}] USER:\n") - f.write(f"{'-' * 40}\n") - f.write(f"{turn['user']}\n\n") - - # Write searches - for i, search in enumerate(turn["searches"]): - f.write(f'šŸ” SEARCH {i + 1}: "{search["query"]}"\n') - f.write(f"{'-' * 40}\n") - f.write(f"{search['results']}\n\n") - - f.write(f"[Turn {turn['turn']}] ASSISTANT:\n") - f.write(f"{'-' * 40}\n") - f.write(f"{turn['assistant']}\n\n") - f.write(f"{'=' * 40}\n\n") - - print(f"Chat history saved to: {filepath}") - return filepath - - except Exception as e: - print(f"Error saving chat history: {e}") - return None - - def display_help(self): - """Display help information.""" - print("\n===== Commands =====") - print("search - Search for information") - print("system - Set a new system prompt") - print("clear - Clear conversation history") - print("history - Display full chat history with searches") - print("save - Save chat history to a text file") - print("help - Display this help message") - print("exit/quit - Exit the program") - print("Any other input will be treated as a prompt to the model.") - print("===================\n") - - def run(self): - """Run the CLI.""" - self.display_welcome() - - while True: - try: - user_input = input("\n> ").strip() - - if not user_input: - continue - - if user_input.lower() in ["exit", "quit"]: - print("Exiting...") - break - - if user_input.lower() == "help": - self.display_help() - continue - - if user_input.lower() == "clear": - self.clear_history() - continue - - if user_input.lower() == "history": - self.print_pretty_chat_history() - continue - - if user_input.lower() == "save": - self.save_chat_history() - continue - - if user_input.lower().startswith("system "): - new_prompt = user_input[7:].strip() - self.set_system_prompt(new_prompt) - continue - - if user_input.lower().startswith("search "): - query = user_input[7:].strip() - if query: - try: - results = search(query, return_type=str) - formatted_results = format_search_results(results) - print(formatted_results) - - # Add to search history - search_entry = { - "turn": len(self.history) // 2, - "searches": [{"query": query, "results": results}], - } - self.search_history.append(search_entry) - except Exception as e: - print(f"Error searching: {e}") - else: - print("Please provide a search query.") - continue - - # Process as a prompt to the model - print("\nGenerating response...") - response = self.generate(user_input) - print("\n----- Response -----") - print(response) - - except KeyboardInterrupt: - print("\nExiting...") - break - except Exception as e: - print(f"Error: {e}") - - -def main(): - """Main function.""" - parser = argparse.ArgumentParser(description="DeepSearch CLI") - parser.add_argument( - "--model_path", - type=str, - default="trainer_output_example/model_merged_16bit", - help="Path to the merged 16-bit model (default: trainer_output_example/model_merged_16bit)", - ) - parser.add_argument( - "--temperature", - type=float, - default=0.7, - help="Sampling temperature (default: 0.7)", - ) - parser.add_argument( - "--system_prompt", - type=str, - default=None, - help="System prompt to guide model behavior", - ) - args = parser.parse_args() - - # Initialize and run the CLI - cli = DeepSearchCLI( - model_path=args.model_path, - temperature=args.temperature, - system_prompt=args.system_prompt, - ) - cli.run() - - -if __name__ == "__main__": - # Ensure the vectorstore is loaded - if load_vectorstore() is None: - print("FAISS vectorstore could not be loaded. Search functionality may not work.") - - main() diff --git a/train.sh b/train.sh deleted file mode 100755 index 5835684..0000000 --- a/train.sh +++ /dev/null @@ -1,6 +0,0 @@ -export CUDA_VISIBLE_DEVICES=0 - -python train_grpo.py - - -