diff --git a/.env.example b/.env.example index ba87db6..a9fa6d4 100644 --- a/.env.example +++ b/.env.example @@ -1,2 +1,2 @@ -HF_TOKEN= -OPENROUTER_API_KEY= \ No newline at end of file +HF_TOKEN= +OPENROUTER_API_KEY= \ No newline at end of file diff --git a/.gitignore b/.gitignore index 343324e..1855af7 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,8 @@ unsloth_compiled_cache/ full_local_training/ grpo_trainer_lora_model/ qa_log.txt +trainer_output_* +data/ # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/generate_data.py b/generate_data.py deleted file mode 100644 index 10d029e..0000000 --- a/generate_data.py +++ /dev/null @@ -1,232 +0,0 @@ -""" -This script performs two main tasks: -1. It loads a markdown document, splits it into chunks, generates embeddings, - and builds a FAISS index (which is saved locally). -2. It generates QA pairs from the document using llama. - For each chunk (using a sliding window for context), it generates multiple question-answer pairs - with different difficulties. The generation is performed in batch with one retry for failed prompts. - Successfully generated QA pairs are saved to "saved_data/questions.json". - -Requirements: - pip install langchain faiss-cpu unsloth vllm -""" - -import json -import os -import pickle -import re -from typing import Dict, List, Optional, Tuple - -from langchain.text_splitter import RecursiveCharacterTextSplitter - -# ========= Part 1: Document Processing and Embedding Generation ========= -# Load and split the markdown document using LangChain -from langchain_community.document_loaders import UnstructuredMarkdownLoader -from langchain_community.vectorstores import FAISS - -from embeddings import CustomHuggingFaceEmbeddings - -# Load your markdown file (adjust the path as needed) -loader = UnstructuredMarkdownLoader("./data/mission_report.md") -docs = loader.load() - -# Split the document into smaller chunks (each 1000 characters, no overlap) -text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0) -chunks = text_splitter.split_documents(docs) - -# Save chunks for later use # TODO: change to csv? easier inspect. -os.makedirs("saved_data", exist_ok=True) -with open("saved_data/chunks.pkl", "wb") as f: - pickle.dump(chunks, f) -print(f"Saved {len(chunks)} chunks to saved_data/chunks.pkl") - -embeddings = CustomHuggingFaceEmbeddings() - -# Create a FAISS vector store from the document chunks and save it locally -vectorstore = FAISS.from_documents(chunks, embeddings) -vectorstore.save_local("faiss_index") -print("Saved FAISS index to 'faiss_index'") - -# TODO: add the paraphrased chunks to the vector store - -# ========= Part 2: QA Generation using Llama Backend ========= - -# Setup Llama backend via unsloth and vLLM -from unsloth import FastLanguageModel -from vllm import SamplingParams - -import rl_helpers # Ensure you have this or remove if not used - -# Load the Llama model (adjust parameters as needed) -model, tokenizer = FastLanguageModel.from_pretrained( - model_name="meta-llama/meta-Llama-3.1-8B-Instruct", - max_seq_length=4096, - load_in_4bit=True, # Use 4-bit quantization if desired - fast_inference=True, # Enable fast inference - gpu_memory_utilization=0.6, # Adjust based on your GPU memory -) - -# Define sampling parameters for generation -sampling_params = SamplingParams( - temperature=0.3, - top_p=0.95, - max_tokens=4096, -) - - -def batch_generate(prompts: List[str]) -> List[str]: - """ - Given a list of prompt strings, returns a list of generated outputs. 
- """ - - def format_input(text: str) -> str: - return tokenizer.apply_chat_template( - [{"role": "user", "content": text}], - tokenize=False, - add_generation_prompt=True, - ) - - formatted = [format_input(p) for p in prompts] - outputs = model.fast_generate(formatted, sampling_params=sampling_params) - return [output.outputs[0].text for output in outputs] - - -def parse_qa_block(block: str) -> Optional[Tuple[str, str, str]]: - """ - Parses a QA block that should contain exactly three non-empty lines: - - A line starting with "Question:" - - A line starting with "Answer:" - - A line starting with "Difficulty:" - - If the markers are not present but the block contains exactly three lines, - those are used in order. - - Returns a tuple (question, answer, difficulty) or None if parsing fails. - """ - lines = [line.strip() for line in block.splitlines() if line.strip()] - if not lines: - return None - - question, answer, difficulty = None, None, None - for line in lines: - lower = line.lower() - if question is None and lower.startswith("question:"): - question = line[len("question:") :].strip() - elif answer is None and lower.startswith("answer:"): - answer = line[len("answer:") :].strip() - elif difficulty is None and lower.startswith("difficulty:"): - difficulty = line[len("difficulty:") :].strip() - - if question and answer and difficulty: - return question, answer, difficulty - if len(lines) == 3: - return lines[0], lines[1], lines[2] - return None - - -def parse_multiple_qa_output(output: str) -> List[Tuple[str, str, str]]: - """ - Splits the output into blocks (separated by one or more blank lines) and - attempts to parse each as a QA pair. - - Returns a list of successfully parsed QA tuples. - """ - blocks = re.split(r"\n\s*\n", output.strip()) - qa_pairs = [] - for block in blocks: - parsed = parse_qa_block(block) - if parsed: - qa_pairs.append(parsed) - return qa_pairs - - -def generate_question_batch_for_chunks( - chunks: List, num_questions: int = 2, difficulty: str = None -) -> List[Dict]: - """ - Generates QA pairs for multiple chunks in batch. - - For each chunk (except the first and last), a sliding window is used for context: - - before: previous chunk's content - - current: current chunk's content - - after: next chunk's content - - Each prompt instructs the model to output exactly three lines per QA pair with markers. - Failed prompts are retried once in batch; if still unsuccessful, they are skipped. - - Returns a list of dicts with keys: "chunk_id", "question", "answer", "difficulty". 
- """ - prompts = [] - chunk_ids = [] - - # Prepare prompts using a sliding window - for i in range(1, len(chunks) - 1): - before = chunks[i - 1].page_content - current = chunks[i].page_content - after = chunks[i + 1].page_content - prompt = ( - f"From the text within ==BEGIN== and ==END==, generate {num_questions} questions with answers.\n" - "For each QA pair, output exactly three lines with no extra commentary:\n" - "Line 1: Question: \n" - "Line 2: Answer: \n" - "Line 3: Difficulty: \n" - "Do not include any additional text.\n\n" - "==BEGIN==\n" - f"{before}\n{current}\n{after}\n" - "==END==\n" - ) - prompts.append(prompt) - chunk_ids.append(i) - - # First batch generation - outputs = batch_generate(prompts) - results = [None] * len(outputs) - failed_indices = [] - - # Parse each output - for idx, output in enumerate(outputs): - qa_pairs = parse_multiple_qa_output(output) - if qa_pairs is None or len(qa_pairs) < num_questions: - failed_indices.append(idx) - else: - results[idx] = qa_pairs[:num_questions] - - # Retry failed prompts in batch - if failed_indices: - print(f"Retrying {len(failed_indices)} failed prompt(s)...") - retry_prompts = [prompts[i] for i in failed_indices] - retry_outputs = batch_generate(retry_prompts) - for j, idx in enumerate(failed_indices): - qa_pairs = parse_multiple_qa_output(retry_outputs[j]) - if qa_pairs is not None and len(qa_pairs) >= num_questions: - results[idx] = qa_pairs[:num_questions] - else: - results[idx] = None # Mark as failed - - # Build final output, skipping prompts that failed even after retry - final_questions = [] - for i, qa_list in enumerate(results): - if qa_list is not None: - for qa in qa_list: - final_questions.append( - { - "chunk_id": chunk_ids[i], - "question": qa[0], - "answer": qa[1], - "difficulty": qa[2], - } - ) - return final_questions - - -# Generate QA pairs in batch (using a sliding window over the chunks) -all_questions = generate_question_batch_for_chunks( - chunks, num_questions=2, difficulty="medium" -) -print(f"Generated {len(all_questions)} QA pairs.") - -# Save the QA pairs to a JSON file -questions_path = os.path.join("saved_data", "questions.json") -with open(questions_path, "w") as f: - json.dump(all_questions, f, indent=2) -print(f"Saved questions to {questions_path}") diff --git a/notebooks/.gitignore b/notebooks/.gitignore deleted file mode 100644 index 5ac5af4..0000000 --- a/notebooks/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -unsloth_compiled_cache -0_* -faiss_index* \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index d42da86..4c616ef 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,13 +5,13 @@ langchain-community Markdown tokenizers unsloth==2025.3.6 -transformers==4.49.0 unsloth_zoo==2025.3.4 unstructured -vllm -wandb +vllm==0.7.2 +transformers==4.49.0 ipykernel python-dotenv loguru -gradio \ No newline at end of file +gradio +tensorboard diff --git a/scripts/__init__.py b/scripts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/generate_data_but_less_dumb.py b/scripts/generate_data.py similarity index 88% rename from generate_data_but_less_dumb.py rename to scripts/generate_data.py index e5842bf..ff67313 100644 --- a/generate_data_but_less_dumb.py +++ b/scripts/generate_data.py @@ -5,35 +5,31 @@ This script performs two main tasks: 2. It generates QA pairs from the document using llama. For each chunk (using a sliding window for context), it generates multiple question-answer pairs with different difficulties. 
The generation is performed in batch with one retry for failed prompts. - Successfully generated QA pairs are saved to "saved_data/questions.json". + Successfully generated QA pairs are saved to "data/questions.json". Requirements: pip install langchain faiss-cpu unsloth vllm """ import json -import os import re -from typing import Dict, List, Optional, Tuple +import sys +from pathlib import Path + +# Add project root to Python path +project_root = Path(__file__).resolve().parent.parent +sys.path.append(str(project_root)) import pandas as pd from langchain.text_splitter import RecursiveCharacterTextSplitter -from loguru import logger - -# Configure logger -logger.add( - "logs/generate_data_{time}.log", - rotation="500 MB", - retention="10 days", - level="INFO", -) # ========= Part 1: Document Processing and Embedding Generation ========= # Load and split the markdown document using LangChain from langchain_community.document_loaders import UnstructuredMarkdownLoader from langchain_community.vectorstores import FAISS -from embeddings import CustomHuggingFaceEmbeddings +from src.config import DATA_DIR, logger +from src.embeddings import CustomHuggingFaceEmbeddings # Load your markdown file (adjust the path as needed) loader = UnstructuredMarkdownLoader("./data/mission_report.md") @@ -43,9 +39,6 @@ docs = loader.load() text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0) chunks = text_splitter.split_documents(docs) -# Create output directory -os.makedirs("saved_data", exist_ok=True) - # Save chunks to CSV for easy inspection chunks_df = pd.DataFrame( { @@ -54,15 +47,15 @@ chunks_df = pd.DataFrame( "metadata": [chunk.metadata for chunk in chunks], } ) -chunks_df.to_csv("saved_data/chunks.csv", index=False) -print(f"Saved {len(chunks)} chunks to saved_data/chunks.csv") +chunks_df.to_csv(DATA_DIR / "chunks.csv", index=False) +logger.info(f"Saved {len(chunks)} chunks to {DATA_DIR}/chunks.csv") embeddings = CustomHuggingFaceEmbeddings() # Create a FAISS vector store from the document chunks and save it locally vectorstore = FAISS.from_documents(chunks, embeddings) -vectorstore.save_local("faiss_index") -print("Saved FAISS index to 'faiss_index'") +vectorstore.save_local(str(DATA_DIR)) +logger.info(f"Saved FAISS index to {DATA_DIR}") # TODO: add the paraphrased chunks to the vector store @@ -72,8 +65,6 @@ print("Saved FAISS index to 'faiss_index'") from unsloth import FastLanguageModel from vllm import SamplingParams -import rl_helpers # Ensure you have this or remove if not used - # Load the Llama model (adjust parameters as needed) model, tokenizer = FastLanguageModel.from_pretrained( model_name="meta-llama/meta-Llama-3.1-8B-Instruct", @@ -91,7 +82,7 @@ sampling_params = SamplingParams( ) -def batch_generate(prompts: List[str]) -> List[str]: +def batch_generate(prompts: list) -> list: """ Given a list of prompt strings, returns a list of generated outputs. 
""" @@ -108,7 +99,7 @@ def batch_generate(prompts: List[str]) -> List[str]: return [output.outputs[0].text for output in outputs] -def parse_qa_block(block: str) -> Optional[Tuple[str, str, str]]: +def parse_qa_block(block: str): """ Parses a QA block that should contain exactly three non-empty lines: - A line starting with "Question:" @@ -141,7 +132,7 @@ def parse_qa_block(block: str) -> Optional[Tuple[str, str, str]]: return None -def parse_multiple_qa_output(output: str) -> List[Tuple[str, str, str]]: +def parse_multiple_qa_output(output: str) -> list: """ Splits the output into blocks (separated by one or more blank lines) and attempts to parse each as a QA pair. @@ -158,8 +149,8 @@ def parse_multiple_qa_output(output: str) -> List[Tuple[str, str, str]]: def generate_question_batch_for_chunks( - chunks: List, num_questions: int = 2, difficulty: Optional[str] = None -) -> List[Dict]: + chunks: list, num_questions: int = 2, difficulty=None +) -> list: """ Generates QA pairs for multiple chunks in batch. @@ -198,7 +189,9 @@ def generate_question_batch_for_chunks( # First batch generation outputs = batch_generate(prompts) - results: List[Optional[List[Tuple[str, str, str]]]] = [None] * len(outputs) + results = [] + for _ in range(len(outputs)): + results.append(None) failed_indices = [] # Parse each output @@ -270,10 +263,10 @@ def generate_question_batch_for_chunks( all_questions = generate_question_batch_for_chunks( chunks, num_questions=2, difficulty="medium" ) -print(f"Generated {len(all_questions)} QA pairs.") +logger.info(f"Generated {len(all_questions)} QA pairs.") # Save the QA pairs to a JSON file -questions_path = os.path.join("saved_data", "questions.json") +questions_path = DATA_DIR / "questions.json" with open(questions_path, "w") as f: json.dump(all_questions, f, indent=2) -print(f"Saved questions to {questions_path}") +logger.info(f"Saved questions to {questions_path}") diff --git a/simple_qa.py b/scripts/simple_qa.py similarity index 61% rename from simple_qa.py rename to scripts/simple_qa.py index 5c5c4ac..7467526 100644 --- a/simple_qa.py +++ b/scripts/simple_qa.py @@ -8,10 +8,18 @@ import json import random import sys import time -from typing import Any, Dict +from pathlib import Path -# Import our search module (ensure these functions follow the new interfaces) -from search_module import get_question_answer, get_question_count, search +# Add project root to Python path +project_root = Path(__file__).resolve().parent.parent +sys.path.append(str(project_root)) + +# Import our search module and config +from src.config import DATA_DIR, logger +from src.search_module import get_question_answer, get_question_count, search + +# TODO: Import verify function and router from appropriate module +# TODO: Consider moving verify function to search_module.py for better organization class SimpleQAEnvironment: @@ -21,27 +29,30 @@ class SimpleQAEnvironment: self.score = {"correct": 0, "incorrect": 0, "total": 0} self.session_data = [] self.current_question = None + self.session_file = DATA_DIR / "qa_sessions" def display_welcome(self): """Display welcome message and instructions.""" - print("\n===== Search & Answer Environment =====") - print("Answer questions using the search tool to find relevant information.") - print("Type 'q' to quit, 'h' for help.\n") + logger.info("===== Search & Answer Environment =====") + logger.info( + "Answer questions using the search tool to find relevant information." 
+ ) + logger.info("Type 'q' to quit, 'h' for help.\n") def display_help(self): """Display help information.""" - print("\n===== Commands =====") - print("n - Get a new question") - print("s - Search for information (e.g., s program launch date)") - print("a - Submit your answer") - print("h - Display this help message") - print("q - Quit the program\n") + logger.info("\n===== Commands =====") + logger.info("n - Get a new question") + logger.info("s - Search for information (e.g., s program launch date)") + logger.info("a - Submit your answer") + logger.info("h - Display this help message") + logger.info("q - Quit the program\n") def display_question(self, question: str): """Display the current question.""" - print("\n===== QUESTION =====") - print(question) - print("=====================\n") + logger.info("\n===== QUESTION =====") + logger.info(question) + logger.info("=====================\n") def get_new_question(self) -> str: """Get a new random question and set it as current.""" @@ -66,30 +77,30 @@ class SimpleQAEnvironment: def perform_search(self, query: str): """Perform a search with the given query.""" if not query: - print("Please provide a search query.") + logger.warning("Please provide a search query.") return try: - print("\n===== SEARCH RESULTS =====") + logger.info("\n===== SEARCH RESULTS =====") results = search(query) - print(results) - print("==========================\n") + logger.info(results) + logger.info("==========================\n") # Record search in current question data if available. if self.current_question is not None: self.current_question["searches"].append(query) except Exception as e: - print(f"Error searching: {str(e)}") + logger.error(f"Error searching: {str(e)}") async def process_answer(self, user_answer: str): """Process and verify the user's answer.""" if self.current_question is None: - print("Please get a question first.") + logger.warning("Please get a question first.") return if not user_answer: - print("Please provide an answer.") + logger.warning("Please provide an answer.") return # Record answer and calculate time taken. @@ -100,27 +111,29 @@ class SimpleQAEnvironment: ) try: - print("\nVerifying your answer...") - correct = await verify( - user_answer, - self.current_question["question"], - self.current_question["correct_answer"], - router, - ) + logger.info("\nVerifying your answer...") + # TODO: Implement verify function in search_module.py + # correct = await verify( + # user_answer, + # self.current_question["question"], + # self.current_question["correct_answer"], + # router, + # ) + correct = False # Temporary placeholder until verify is implemented # Update score and inform the user. self.score["total"] += 1 if correct: self.score["correct"] += 1 - print("\nāœ“ Your answer is CORRECT!") + logger.success("\nāœ“ Your answer is CORRECT!") else: self.score["incorrect"] += 1 - print("\nāœ— Your answer is INCORRECT.") - print( + logger.error("\nāœ— Your answer is INCORRECT.") + logger.info( f"\nThe correct answer is:\n{self.current_question['correct_answer']}" ) - print(f"\nScore: {self.score['correct']}/{self.score['total']}") + logger.info(f"\nScore: {self.score['correct']}/{self.score['total']}") # Record the result and add the current question to the session data. 
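+            # (editor's note) A minimal sketch of the missing verify() helper,
+            # modelled on the LLM-judge prompt in rl_helpers.check_student_answers.
+            # The `router` object and its acompletion() method are hypothetical
+            # placeholders for whatever async LLM client gets wired in later:
+            #
+            # async def verify(student_answer, question, correct_answer, router):
+            #     prompt = (
+            #         "Reply 'Yes' if the student's answer contains the correct "
+            #         "information, otherwise reply 'No'.\n\n"
+            #         f"Question: {question}\n"
+            #         f"Correct Answer: {correct_answer}\n"
+            #         f"Student Answer: {student_answer}\n"
+            #     )
+            #     response = await router.acompletion(prompt)
+            #     return str(response).strip().lower().startswith("yes")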
self.current_question["is_correct"] = correct @@ -130,15 +143,18 @@ class SimpleQAEnvironment: self.current_question = None except Exception as e: - print(f"Error verifying answer: {str(e)}") + logger.error(f"Error verifying answer: {str(e)}") def save_session(self): """Save the session data to a file.""" if not self.session_data: return + # Ensure session directory exists + self.session_file.mkdir(parents=True, exist_ok=True) + timestamp = time.strftime("%Y%m%d_%H%M%S") - filename = f"qa_session_{timestamp}.json" + filename = self.session_file / f"qa_session_{timestamp}.json" session_data = { "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), @@ -149,9 +165,9 @@ class SimpleQAEnvironment: try: with open(filename, "w") as f: json.dump(session_data, f, indent=2) - print(f"\nSession data saved to {filename}") + logger.info(f"\nSession data saved to {filename}") except Exception as e: - print(f"Error saving session data: {str(e)}") + logger.error(f"Error saving session data: {str(e)}") async def run(self): """Run the main command loop.""" @@ -178,11 +194,11 @@ class SimpleQAEnvironment: answer = command[2:].strip() await self.process_answer(answer) else: - print("Unknown command. Type 'h' for help.") + logger.warning("Unknown command. Type 'h' for help.") # Save session data on exit. self.save_session() - print("\nThank you for using the Q&A environment!") + logger.info("\nThank you for using the Q&A environment!") async def main(): @@ -195,6 +211,6 @@ if __name__ == "__main__": try: asyncio.run(main()) except KeyboardInterrupt: - print("\nProgram terminated by user.") + logger.info("\nProgram terminated by user.") except Exception as e: - print(f"\nError: {str(e)}") + logger.error(f"\nError: {str(e)}") diff --git a/UnslothGRPOTrainerTemp.py b/src/UnslothGRPOTrainerTemp.py similarity index 99% rename from UnslothGRPOTrainerTemp.py rename to src/UnslothGRPOTrainerTemp.py index dcdbd90..8ae417f 100644 --- a/UnslothGRPOTrainerTemp.py +++ b/src/UnslothGRPOTrainerTemp.py @@ -1,18 +1,21 @@ -from torch import Tensor +import os +from contextlib import nullcontext +from dataclasses import dataclass, field +from typing import * + +import numpy as np import torch import torch.nn as nn -from torch.nn import functional as F +from packaging.version import Version from trl.trainer.grpo_trainer import ( Any, AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, Dataset, - GRPOConfig, - GRPOTrainer, GenerationConfig, + GRPOConfig, IterableDataset, - LLM, Optional, PeftConfig, PreTrainedModel, @@ -41,7 +44,6 @@ from trl.trainer.grpo_trainer import ( nn, os, pad, - patch, prepare_deepspeed, set_seed, textwrap, @@ -50,42 +52,8 @@ from trl.trainer.grpo_trainer import ( unwrap_model_for_generation, version, wandb, - warnings, - os, - torch, - transformers, - Any, - LLM, - Union, - apply_chat_template, - broadcast_object_list, - gather, - gather_object, - is_conversational, - maybe_apply_chat_template, - nn, - os, - pad, - torch, - unwrap_model_for_generation, - wandb, - GRPOTrainer, - Trainer, - gather, - os, - torch, ) - -import os -from typing import * -from dataclasses import dataclass, field -from packaging.version import Version -import torch -import numpy as np -from contextlib import nullcontext -from torch.nn import functional as F - torch_compile_options = { "epilogue_fusion": True, "max_autotune": False, diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/config.py b/src/config.py new file mode 100644 index 0000000..838b971 --- 
/dev/null +++ b/src/config.py @@ -0,0 +1,301 @@ +import os +import sys +from datetime import datetime +from pathlib import Path + +import torch +from dotenv import load_dotenv +from loguru import logger +from vllm import SamplingParams + +# Load environment variables from .env file if it exists +load_dotenv(override=True) + +# Project paths +PROJ_ROOT = Path(__file__).resolve().parent.parent +DATA_DIR = PROJ_ROOT / "data" +LOG_FOLDER = PROJ_ROOT / "logs" + +# Model configuration +# MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" +MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" +device_id = ( + 1 if os.environ.get("CUDA_VISIBLE_DEVICES") == "1" else torch.cuda.current_device() +) +timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + +OUTPUT_DIR = ( + PROJ_ROOT + / f"trainer_output_{MODEL_NAME.replace('/', '_')}_gpu{device_id}_{timestamp}" +) + +# Model parameters +MODEL_CONFIG = { + "max_seq_length": 4096 * 2, # Can increase for longer reasoning traces + "lora_rank": 64, # Larger rank = smarter, but slower + "gpu_memory_utilization": 0.6, # Reduce if out of memory + "model_name": MODEL_NAME, + "target_modules": [ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + ], # Remove QKVO if out of memory +} + +# Training parameters +TRAINING_CONFIG = { + "learning_rate": 5e-6, + "adam_beta1": 0.9, + "adam_beta2": 0.99, + "weight_decay": 0.1, + "warmup_ratio": 0.1, + "lr_scheduler_type": "cosine", + "optim": "paged_adamw_8bit", + "logging_steps": 1, + "per_device_train_batch_size": 8, + "gradient_accumulation_steps": 1, # Increase to 4 for smoother training + "num_generations": 8, # Decrease if out of memory + "max_prompt_length": 1024, + "max_completion_length": 1024, + "max_steps": 101, + "save_steps": 50, + "max_grad_norm": 0.1, + "report_to": "tensorboard", +} + + +# Sampling parameters +def get_sampling_params(temperature: float = 0.1) -> SamplingParams: + """Get sampling parameters for text generation""" + return SamplingParams( + temperature=temperature, + top_p=0.95, + max_tokens=4096, + ) + + +# Initialize logging based on environment +def _init_logging(env: str = "development") -> None: + """ + Initialize logging configuration with console logging + and default file logging to ./logs directory. + Additional file logging will be set up later in update_log_path(). 
+ + Args: + env: The environment for logging ('development' or 'production') + """ + # Create default log folder + if not LOG_FOLDER.exists(): + LOG_FOLDER.mkdir(parents=True, exist_ok=True) + + # Remove any existing handlers + logger.remove() + + # Define the logging format + console_format = ( + "{time:YYYY-MM-DD HH:mm:ss} " + "| {level: <8} " + "| {name}:{function}:{line} " + "- {message}" + ) + + file_format = ( + "{time:YYYY-MM-DD at HH:mm:ss} " + "| {level} " + "| {name}:{function}:{line} " + "- {message}" + ) + + # Add console logging + logger.add( + sys.stderr, + format=console_format, + level="DEBUG" if env == "development" else "INFO", + colorize=True, + backtrace=True, + diagnose=env == "development", + ) + + # Add default file logging to ./logs directory + logger.add( + LOG_FOLDER / "app.log", + format=file_format, + level="INFO", + rotation="500 MB", + retention="7 days", + compression="zip", + enqueue=True, # Enables asynchronous logging + ) + + # Add custom level for requests + logger.level("REQUEST", no=25, color="", icon=" ") + + # Configure exception handling + def exception_handler(exc_type, exc_value, exc_traceback): + if issubclass(exc_type, KeyboardInterrupt): + sys.__excepthook__(exc_type, exc_value, exc_traceback) + return + logger.opt(exception=(exc_type, exc_value, exc_traceback)).critical( + "Unhandled exception" + ) + + sys.excepthook = exception_handler + + +# Update the log files to point to the training directory +def update_log_path(log_dir=None): + """ + Add a log file in the training directory while keeping the default ./logs logging. + Should be called after the training directory is created. + + Args: + log_dir: Path to store additional log files (default: uses get_paths()["log_dir"]) + """ + # Use provided log_dir or get from training paths + if log_dir is None: + paths = get_paths(create_dirs=True) + log_dir = paths["log_dir"] + else: + log_dir = Path(log_dir) + log_dir.mkdir(exist_ok=True, parents=True) + + file_format = ( + "{time:YYYY-MM-DD at HH:mm:ss} " + "| {level} " + "| {name}:{function}:{line} " + "- {message}" + ) + + # Add additional file handler pointing to training directory + # No need to remove existing handlers as we want to keep those + logger.add( + log_dir / "app.log", + format=file_format, + level="INFO", + rotation="500 MB", + retention="7 days", + compression="zip", + enqueue=True, # Enables asynchronous logging + ) + + logger.info(f"Additional logs will be stored in: {log_dir}") + + +# Paths configuration without creating directories +def get_paths(create_dirs: bool = False) -> dict: + """ + Get common paths for the project + + Args: + create_dirs: Whether to create the directories + + Returns: + Dictionary with paths + """ + output_dir = Path(OUTPUT_DIR) + log_dir = output_dir / "logs" + tensorboard_dir = output_dir / "runs" + + # Only create directories if explicitly requested + if create_dirs: + output_dir.mkdir(exist_ok=True) + log_dir.mkdir(exist_ok=True) + + # Only create tensorboard directory if it's enabled in config + if TRAINING_CONFIG.get("report_to") == "tensorboard": + tensorboard_dir.mkdir(exist_ok=True) + + return { + "output_dir": output_dir, + "log_dir": log_dir, + "tensorboard_dir": tensorboard_dir, + "proj_root": PROJ_ROOT, + "data_dir": DATA_DIR, + } + + +# Create training directories +def init_training_dirs(): + """Initialize all directories needed for training""" + paths = get_paths(create_dirs=True) + + # Also ensure our standard project directories exist + for directory in [ + DATA_DIR, + LOG_FOLDER, + ]: 
+ directory.mkdir(exist_ok=True, parents=True) + + return paths + + +# For backward compatibility - will be deprecated +def setup_logger(module_name=None, create_dirs: bool = False): + """ + Setup a logger for a specific module with consistent configuration. + + Note: This function is kept for backward compatibility. + Use the global 'logger' instead for new code. + + Args: + module_name: Optional name of module for module-specific log file + create_dirs: Whether to create log directories + + Returns: + Configured logger instance + """ + logger.warning( + "setup_logger is deprecated. Import logger directly from config instead." + ) + return logger + + +# Tensorboard writer singleton +_tensorboard_writer = None + + +# Safe tensorboard logging function +def log_metric(key, value, step=0): + """ + Log a metric safely to tensorboard if writer is available. + + Args: + key: Metric name + value: Metric value + step: Training step + """ + global _tensorboard_writer + + # Skip tensorboard logging if disabled in config + if TRAINING_CONFIG.get("report_to") != "tensorboard": + logger.debug(f"Tensorboard disabled. Metric: {key}={value} (step {step})") + return + + # Get paths and initialize writer if needed + paths = get_paths(create_dirs=False) + if paths["tensorboard_dir"].exists(): + # Only create writer once + if _tensorboard_writer is None: + from torch.utils.tensorboard.writer import SummaryWriter + + _tensorboard_writer = SummaryWriter(paths["tensorboard_dir"]) + logger.debug(f"Created tensorboard writer at {paths['tensorboard_dir']}") + + # Add scalar using existing writer + _tensorboard_writer.add_scalar(key, value, step) + # No need to close the writer - it will be closed at process exit + else: + logger.debug(f"Tensorboard metric: {key}={value} (step {step})") + + +# Initialize logging on module import +env = os.getenv("APP_ENV", "development") +_init_logging(env=env) + +# Log project root on import +logger.info(f"Project root path: {PROJ_ROOT}") +logger.debug(f"Running in {env} environment") diff --git a/embeddings.py b/src/embeddings.py similarity index 90% rename from embeddings.py rename to src/embeddings.py index b9570e5..90af333 100644 --- a/embeddings.py +++ b/src/embeddings.py @@ -1,7 +1,4 @@ -from typing import List, Union - import torch -import torch.nn.functional as F from langchain.embeddings.base import Embeddings from transformers import AutoModel, AutoTokenizer @@ -18,9 +15,7 @@ class CustomHuggingFaceEmbeddings(Embeddings): - "query": uses mean pooling over tokens (weighted by the attention mask) for query embeddings. 
""" - def __init__( - self, model_name: str = DEFAULT_MODEL_NAME, default_mode: str = "sentence" - ): + def __init__(self, model_name=DEFAULT_MODEL_NAME, default_mode="sentence"): self.model_name = model_name # Set device to GPU if available, else CPU self.device = "cuda" if torch.cuda.is_available() else "cpu" @@ -29,7 +24,8 @@ class CustomHuggingFaceEmbeddings(Embeddings): self.default_mode = default_mode # "sentence" or "query" self.model.eval() # Set model to evaluation mode - def get_embedding(self, text: Union[str, List[str]], mode: str = None): + def get_embedding(self, text, mode=None): + """Get embeddings for text using specified mode""" if mode is None: mode = self.default_mode assert mode in ( @@ -59,14 +55,14 @@ class CustomHuggingFaceEmbeddings(Embeddings): vectors = output.last_hidden_state[:, 0, :] return vectors - def embed_documents(self, texts: List[str]) -> List[List[float]]: + def embed_documents(self, texts): """ Compute embeddings for a list of documents (using sentence mode). """ vectors = self.get_embedding(texts, mode="sentence") return vectors.cpu().numpy().tolist() - def embed_query(self, text: str) -> List[float]: + def embed_query(self, text): """ Compute an embedding for a single query. """ diff --git a/rl_helpers.py b/src/rl_helpers.py similarity index 59% rename from rl_helpers.py rename to src/rl_helpers.py index 44ecee4..d60c887 100644 --- a/rl_helpers.py +++ b/src/rl_helpers.py @@ -4,33 +4,21 @@ This module provides utility functions for handling chat-based tool interactions and calculating rewards based on the quality of responses. """ -import asyncio +import inspect import json import re from dataclasses import dataclass from datetime import datetime -from pathlib import Path import nest_asyncio import numpy as np import torch -from loguru import logger -from search_module import get_qa_dataset, search - -# Setup loguru -log_dir = Path("logs") -log_dir.mkdir(exist_ok=True) -logger.add( - log_dir / "rl_helpers_{time}.log", - rotation="500 MB", - retention="10 days", - level="DEBUG", - format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}", -) +from src.config import log_metric, logger +from src.search_module import get_qa_dataset, search +# Apply nest_asyncio for supporting async operations in notebooks nest_asyncio.apply() -from typing import Callable, List from trl.trainer.grpo_trainer import apply_chat_template @@ -238,6 +226,8 @@ def run_tool_calls(chat_states): Execute tool calls found in chat states. 
""" logger.debug(f"Running tool calls for {len(chat_states)} chat states") + total_retries = 0 + for chat_state in chat_states: if chat_state.get("finished"): logger.debug("Chat state already finished, skipping tool calls") @@ -256,9 +246,14 @@ def run_tool_calls(chat_states): elif len(function_calls) == 1: function_call = function_calls[0] query = function_call["function"]["parameters"]["query"] - logger.info(f"Executing search with query: {query}") + logger.info(f"šŸ” Search Query: {query}") results = search(query, return_type=str, results=2) chat_state["messages"].append({"role": "ipython", "content": results}) + + # Count retries + retries = len(extract_json_objects(assistant_response)) + total_retries += retries + logger.debug("Added search results to chat state") except Exception as e: logger.error(f"Error during tool call: {str(e)}") @@ -332,7 +327,12 @@ def get_chat_num_tokens(chat_state, tokenizer): def run_agent( - generate_fn, tokenizer, questions, max_generations=5, max_new_tokens=4096 + generate_fn, + tokenizer, + questions, + max_generations=5, + max_new_tokens=4096, + correct_contents=None, ): """ Run the agent to completion for a batch of questions. @@ -343,6 +343,11 @@ def run_agent( ) chat_states = [get_initial_chat(q) for q in questions] + # Add correct content to chat states if provided + if correct_contents: + for chat_state, correct_content in zip(chat_states, correct_contents): + chat_state["correct_content"] = correct_content + # set the initial_prompt length for i, chat_state in enumerate(chat_states): chat_state["initial_length"] = get_chat_num_tokens(chat_state, tokenizer) @@ -350,7 +355,7 @@ def run_agent( # agent loop for i in range(max_generations): - logger.info(f"Starting generation step {i+1}/{max_generations}") + logger.info(f"Starting generation step {i + 1}/{max_generations}") chat_states = run_agent_generations(generate_fn, tokenizer, chat_states) chat_states = check_finished_chats(chat_states) chat_states = run_tool_calls(chat_states) @@ -359,7 +364,7 @@ def run_agent( ) finished_count = sum(1 for state in chat_states if state.get("finished")) logger.info( - f"Finished {finished_count}/{len(chat_states)} chat states after step {i+1}" + f"Finished {finished_count}/{len(chat_states)} chat states after step {i + 1}" ) logger.info("Agent run completed") @@ -440,15 +445,26 @@ async def verify(student_answer: str, question: str, answer: str) -> bool: def check_student_answers( - questions: List[str], - answers: List[str], - student_answers: List[str], - vllm_generate_func: Callable[[List[str]], List[str]], + questions: list[str], + answers: list[str], + student_answers: list, # Can be strings or dicts + vllm_generate_func, tokenizer, - log_file: str = "qa_log.txt", -) -> List[bool]: + log_file=None, +) -> list[bool]: """ Evaluates a list of student answers against the true answers using a vLLM generate function. + + Args: + questions: List of questions + answers: List of correct answers + student_answers: List of student answers to evaluate + vllm_generate_func: Function to generate verification responses + tokenizer: Tokenizer for formatting prompts + log_file: Optional path to write detailed results + + Returns: + List of boolean results (True for correct answers) """ logger.info(f"Checking {len(questions)} student answers") @@ -463,12 +479,15 @@ def check_student_answers( prompts = [] for question, answer, student_ans in zip(questions, answers, student_answers): prompt_text = ( - "You are grading a student's answer. 
For the following question, " - "compare the student's answer to the correct answer. Reply with 'Yes' if the student's answer is correct, or 'No' if it is completely incorrect.\n\n" + "You are grading a student's answer to a question. For the following question, " + "compare the student's answer to the correct answer. Reply with 'Yes' if the student's answer contains the correct information, " + "even if it's not an exact match. If the student's answer doesn't contain the right information or is completely incorrect, reply with 'No'.\n\n" f"Question: {question}\n" f"Correct Answer: {answer}\n" - f"Student Answer: {student_ans}\n" + f"Student Answer: {student_ans}\n\n" + "Your response should be just 'Yes' or 'No'." ) + formatted_prompt = tokenizer.apply_chat_template( [{"role": "user", "content": prompt_text}], tokenize=False, @@ -481,10 +500,15 @@ def check_student_answers( responses = vllm_generate_func(prompts) responses_text = [] for response in responses: + # Handle different response formats if hasattr(response, "outputs"): - responses_text.append(response.outputs[0].text) + try: + responses_text.append(response.outputs[0].text) + except (AttributeError, IndexError): + # Fallback for simple string responses + responses_text.append(str(response)) else: - responses_text.append(response) + responses_text.append(str(response)) logger.debug(f"Got {len(responses_text)} verification responses") results = [] @@ -495,34 +519,108 @@ def check_student_answers( logger.info(f"Verification complete. {sum(results)}/{len(results)} answers correct") # Append the QA details and verifier's response to the specified log file - with open(log_file, "a") as file: - for question, answer, student_ans, verifier_response in zip( - questions, answers, student_answers, responses_text - ): - file.write("Question: " + question + "\n") - file.write("Correct Answer: " + answer + "\n") - file.write("Student Answer: " + student_ans + "\n") - file.write("Verifier said: " + verifier_response + "\n") - file.write("-" * 40 + "\n") + if log_file: + with open(log_file, "a") as file: + timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + file.write(f"\nšŸ“ === QA Evaluation at {timestamp} ===\n") + file.write(f"šŸ“‚ File: {__file__}\n") + + # Get current frame info safely + frame = inspect.currentframe() + if frame: + file.write(f"šŸ“ Line: {frame.f_lineno}\n") + # Don't forget to delete the frame to avoid reference cycles + del frame + + file.write("=" * 80 + "\n") + + for i, (question, answer, student_ans, verifier_response) in enumerate( + zip(questions, answers, student_answers, responses_text) + ): + file.write(f"\nā“ Question {i+1}:\n") + file.write("-" * 40 + "\n") + file.write(f"šŸ“‹ Question: {question}\n") + file.write(f"āœ… Correct Answer: {answer}\n") + file.write(f"šŸ‘Øā€šŸŽ“ Student Answer: {student_ans}\n") + file.write(f"šŸ” Verifier said: {verifier_response}\n") + + # Add search results if available in the chat state + if isinstance(student_ans, dict) and "messages" in student_ans: + # Get messages from dict + messages = student_ans.get("messages", []) + search_results = [ + msg.get("content", "") + for msg in messages + if msg.get("role") == "ipython" + ] + if search_results: + file.write("\nšŸ”Ž Search Results:\n") + for j, result in enumerate(search_results, 1): + file.write(f"\nSearch {j}:\n{result}\n") + + file.write("-" * 40 + "\n") + + file.write( + f"\nšŸ“Š Summary: {sum(results)}/{len(results)} answers correct ({sum(results)/len(results)*100:.2f}%)\n" + ) + file.write("=" * 80 + "\n\n") 
return results # Reward Functions -def build_reward_correctness_fn(generate_fn, tokenizer): +def build_reward_correctness_fn(generate_fn, tokenizer, log_file=None): def reward_correctness(prompts, completions, **reward_kwargs): teacher_answers = reward_kwargs["answer"] student_answers = [ completion["messages"][-1]["content"] for completion in completions ] + # Log non-exact matches + for i, (student, teacher) in enumerate(zip(student_answers, teacher_answers)): + if student.strip().lower() != teacher.strip().lower(): + logger.warning( + f"Non-exact match at index {i}:\n" + f"Student: {student}\n" + f"Teacher: {teacher}" + ) + correct = check_student_answers( prompts, teacher_answers, student_answers, vllm_generate_func=generate_fn, tokenizer=tokenizer, + log_file=log_file, + ) + + # Log correctness metrics with length info + log_metric( + "rewards/correctness", np.mean(correct), reward_kwargs.get("step", 0) ) + log_metric( + "rewards/correctness_std", np.std(correct), reward_kwargs.get("step", 0) + ) + + # Log length metrics + student_lengths = [len(ans.strip()) for ans in student_answers] + teacher_lengths = [len(ans.strip()) for ans in teacher_answers] + log_metric( + "metrics/avg_student_length", + np.mean(student_lengths), + reward_kwargs.get("step", 0), + ) + log_metric( + "metrics/avg_teacher_length", + np.mean(teacher_lengths), + reward_kwargs.get("step", 0), + ) + log_metric( + "metrics/length_ratio", + np.mean(student_lengths) / np.mean(teacher_lengths), + reward_kwargs.get("step", 0), + ) + return correct return reward_correctness @@ -535,14 +633,23 @@ def reward_formatting(prompts, completions, **reward_kwargs): for message in chat["messages"]: if "Error during" in message["content"]: has_error[i] = True + logger.warning(f"Error in chat {i}: {message['content']}") break - return [0.7 if not e else 0 for e in has_error] + + rewards = [0.7 if not e else 0 for e in has_error] + + # Log formatting metrics + log_metric("rewards/formatting", np.mean(rewards), reward_kwargs.get("step", 0)) + log_metric("rewards/formatting_std", np.std(rewards), reward_kwargs.get("step", 0)) + log_metric("metrics/error_rate", np.mean(has_error), reward_kwargs.get("step", 0)) + + return rewards def reward_retry_behavior(completions: list[dict], **reward_kwargs) -> list[float]: """ - Reward function that encourages optimal retry behavior by counting total function calls - across all assistant messages in the conversation. + Reward function that encourages optimal retry behavior by only rewarding completions + where every assistant message contains at most 1 JSON object. 
""" rewards: list[float] = [] @@ -558,21 +665,61 @@ def reward_retry_behavior(completions: list[dict], **reward_kwargs) -> list[floa rewards.append(0.0) continue - # Count total function calls across all messages - total_retries: int = 0 + # Check if every message has at most 1 JSON object + has_multiple_json = False + total_json_objects = 0 + for msg in assistant_msgs: - total_retries += len(extract_json_objects(msg)) + json_objects = extract_json_objects(msg) + json_count = len(json_objects) + total_json_objects += json_count + + if json_count > 1: + has_multiple_json = True + logger.warning( + f"Message contains {json_count} JSON objects, which exceeds the limit of 1" + ) + break - # Calculate reward using modified sigmoid function - x: float = float(total_retries - 4) # Center peak at 4 retries - base_reward: float = 1.0 / (1.0 + np.exp(-x + abs(x) / 2)) + # Only reward if no message has multiple JSON objects + if has_multiple_json: + rewards.append(0.0) + else: + # Base reward is 1.0 if constraint is met + base_reward = 1.0 + + # Slight penalty for having too many total JSON objects across all messages + if total_json_objects > 4: + penalty = 0.1 * (total_json_objects - 4) + base_reward = max(0.2, base_reward - penalty) + logger.debug( + f"Applied penalty for {total_json_objects} total JSON objects: {penalty}" + ) - # Additional penalty for excessive retries - if total_retries > 6: - penalty: float = 0.2 * (total_retries - 6) - base_reward = max(0.1, base_reward - penalty) + rewards.append(base_reward) - rewards.append(base_reward) + # Log retry behavior metrics + log_metric("rewards/retry_behavior", np.mean(rewards), reward_kwargs.get("step", 0)) + log_metric( + "rewards/retry_behavior_std", np.std(rewards), reward_kwargs.get("step", 0) + ) + log_metric( + "metrics/avg_json_per_msg", + np.mean( + [ + len(extract_json_objects(msg["content"])) + for completion in completions + for msg in completion["messages"] + if msg["role"] == "assistant" + ] + ), + reward_kwargs.get("step", 0), + ) + log_metric( + "metrics/multiple_json_violation_rate", + np.mean([0.0 if rewards[i] > 0.0 else 1.0 for i in range(len(rewards))]), + reward_kwargs.get("step", 0), + ) return rewards @@ -599,6 +746,11 @@ def reward_exact_match_chunk_query(prompts, completions, **reward_kwargs): ] logger.debug(f"Found {len(search_results)} search results for prompt {i}") + # Log ground truth chunk and searched chunks + logger.info(f"šŸ“ Ground Truth Chunk: {correct_content}") + for j, result in enumerate(search_results): + logger.info(f"šŸ” Searched Chunk {j+1}: {result}") + # Check if any search hit the correct chunk content found_correct_chunk = False for result in search_results: @@ -609,30 +761,145 @@ def reward_exact_match_chunk_query(prompts, completions, **reward_kwargs): ) break + if not found_correct_chunk: + logger.warning( + f"Failed to find correct chunk for prompt {i}:\n" + f"Search results: {[r[:100] + '...' 
for r in search_results]}" + ) + reward = 1.0 if found_correct_chunk else 0.0 rewards.append(reward) logger.debug(f"Reward for prompt {i}: {reward}") - logger.info(f"Average reward: {sum(rewards)/len(rewards):.3f}") + # Log detailed metrics for debugging + log_metric( + f"debug/chunk_match_{i}", + 1 if found_correct_chunk else 0, + reward_kwargs.get("step", 0), + ) + log_metric( + f"debug/search_results_count_{i}", + len(search_results), + reward_kwargs.get("step", 0), + ) + if search_results: + log_metric( + f"debug/result_length_{i}", + np.mean([len(r.split()) for r in search_results]), + reward_kwargs.get("step", 0), + ) + + # Log chunk query metrics + log_metric("rewards/chunk_query", np.mean(rewards), reward_kwargs.get("step", 0)) + log_metric("rewards/chunk_query_std", np.std(rewards), reward_kwargs.get("step", 0)) + log_metric( + "metrics/avg_search_results", + np.mean( + [ + len( + [ + msg["content"] + for msg in chat_state["messages"] + if msg["role"] == "ipython" + ] + ) + for chat_state in completions + ] + ), + reward_kwargs.get("step", 0), + ) + log_metric( + "metrics/chunk_match_rate", np.mean(rewards), reward_kwargs.get("step", 0) + ) + + # Log detailed debugging info + logger.info("Chunk Query Rewards Summary:") + logger.info(f"Total prompts: {len(prompts)}") + logger.info(f"Correct matches: {sum(rewards)}") + logger.info(f"Average reward: {np.mean(rewards):.3f}") + logger.info(f"Reward std: {np.std(rewards):.3f}") + return rewards -def run_eval(generate_fn, verify_fn, tokenizer): - logger.info("Starting evaluation") +def run_eval(generate_fn, verify_fn, tokenizer, output_file=None, debug_file=None): + """ + Run evaluation on the test dataset and return results. + + Args: + generate_fn: Function to generate completions + verify_fn: Function to verify results + tokenizer: Tokenizer for processing text + output_file: Path to save evaluation results summary + debug_file: Path to save detailed debug information + + Returns: + full_chat_states: The chat states from evaluation + """ train_dataset, test_dataset = get_qa_dataset() questions = test_dataset["prompt"] - logger.info(f"Loaded {len(questions)} test questions") - agentic_outputs = run_agent(generate_fn, tokenizer, questions) full_chat_states = agentic_outputs.full_chat_states final_responses = agentic_outputs.final_response_str - - logger.info("Calculating rewards") rewards = verify_fn(questions, full_chat_states, answer=test_dataset["answer"]) - avg_reward = sum(rewards) / len(rewards) - logger.info("EVALUATION RESULTS:") - logger.info(f"Percentage of correct answers: {avg_reward:.3f}") + # Calculate results + percent_correct = sum(rewards) / len(rewards) * 100 + + # Log results to console + logger.info("RESULTS:") + logger.info(f"percentage of correct answers: {percent_correct:.2f}%") logger.info("=" * 30) + # Save results to file if specified + if output_file: + try: + with open(output_file, "w") as f: + f.write("EVALUATION RESULTS\n") + f.write("=================\n\n") + f.write(f"Total questions: {len(questions)}\n") + f.write(f"Correct answers: {sum(rewards)}\n") + f.write(f"Percentage correct: {percent_correct:.2f}%\n\n") + + f.write("Individual results:\n") + for i, (q, r, resp) in enumerate( + zip(questions, rewards, final_responses) + ): + f.write(f"\nQ{i+1}: {q[:100]}...\n") + f.write(f"Correct: {'āœ“' if r else 'āœ—'}\n") + f.write(f"Response: {resp[:150]}...\n") + f.write("-" * 40 + "\n") + logger.info(f"Saved evaluation results to {output_file}") + except Exception as e: + logger.error(f"Error saving results 
file: {e}") + + # Save debug information if specified + if debug_file: + try: + import json + + debug_data = [] + for i, (q, r, resp, chat) in enumerate( + zip(questions, rewards, final_responses, full_chat_states) + ): + debug_data.append( + { + "question_id": i, + "question": q, + "is_correct": bool(r), + "final_response": resp, + "chat_state": { + k: str(v) if isinstance(v, (list, dict)) else v + for k, v in chat.items() + if k != "tokenizer" + }, + } + ) + + with open(debug_file, "w") as f: + json.dump(debug_data, f, indent=2) + logger.info(f"Saved debug information to {debug_file}") + except Exception as e: + logger.error(f"Error saving debug file: {e}") + return full_chat_states diff --git a/search_module.py b/src/search_module.py similarity index 67% rename from search_module.py rename to src/search_module.py index e742c28..50955b7 100644 --- a/search_module.py +++ b/src/search_module.py @@ -3,40 +3,33 @@ Search module for RL training loop. This module provides functions to search through vectorized documents and retrieve question-answer pairs. """ -import pickle import json import random -import asyncio -from typing import List, Tuple, Optional, Union, Dict, Any -from enum import Enum -from pydantic import BaseModel -from langchain.vectorstores import FAISS + from datasets import Dataset -from embeddings import CustomHuggingFaceEmbeddings +from langchain.vectorstores import FAISS + +from src.config import DATA_DIR, logger +from src.embeddings import CustomHuggingFaceEmbeddings # Load pre-saved vectorstore def load_vectorstore(): """Load the pre-saved FAISS index""" try: - import os - embeddings = CustomHuggingFaceEmbeddings() - # Load the FAISS index with absolute path - index_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "faiss_index" - ) - print(f"Loading FAISS index from: {index_path}") + # Load the FAISS index from the data directory + logger.info(f"Loading FAISS index from: {DATA_DIR}") vectorstore = FAISS.load_local( - index_path, embeddings, allow_dangerous_deserialization=True + str(DATA_DIR), embeddings, allow_dangerous_deserialization=True ) - print("Successfully loaded FAISS index") + logger.info("Successfully loaded FAISS index") return vectorstore except Exception as e: - print(f"Error loading vectorstore: {e}") + logger.error(f"Error loading vectorstore: {e}") import traceback - traceback.print_exc() + logger.debug(traceback.format_exc()) return None @@ -44,13 +37,13 @@ def load_vectorstore(): try: vectorstore = load_vectorstore() if vectorstore is None: - print("Warning: FAISS vectorstore could not be loaded.") + logger.warning("FAISS vectorstore could not be loaded.") except Exception as e: - print(f"Error loading vectorstore: {e}") + logger.error(f"Error loading vectorstore: {e}") vectorstore = None -def search(query: str, return_type=str, results: int = 5) -> Union[str, List[str]]: +def search(query: str, return_type=str, results: int = 5): """ Search for relevant chunks using similarity search. 
@@ -82,51 +75,36 @@ def search(query: str, return_type=str, results: int = 5) -> Union[str, List[str # Load questions from saved data def load_qa_data(): - """Load the pre-generated questions and document chunks""" + """Load the pre-generated questions""" try: - import os - - # Get absolute paths to data files - base_dir = os.path.dirname(os.path.abspath(__file__)) - chunks_path = os.path.join(base_dir, "saved_data", "chunks.pkl") - questions_path = os.path.join(base_dir, "saved_data", "questions.json") - - print(f"Loading chunks from: {chunks_path}") - print(f"Loading questions from: {questions_path}") - - # Load the chunks - with open(chunks_path, "rb") as f: - chunks = pickle.load(f) + questions_path = DATA_DIR / "questions.json" + logger.info(f"Loading questions from: {questions_path}") # Load the questions with open(questions_path, "r") as f: questions = json.load(f) - print( - f"Successfully loaded {len(chunks)} chunks and {len(questions)} questions" - ) - return chunks, questions + logger.info(f"Successfully loaded {len(questions)} questions") + return questions except Exception as e: - print(f"Error loading QA data: {e}") + logger.error(f"Error loading QA data: {e}") import traceback - traceback.print_exc() - return None, None + logger.debug(traceback.format_exc()) + return None -# Load chunks and questions when module is imported +# Load questions when module is imported try: - chunks, questions = load_qa_data() - if chunks is None or questions is None: - print("Warning: Could not load QA data.") + questions = load_qa_data() + if questions is None: + logger.warning("Could not load QA data.") except Exception as e: - print(f"Error initializing QA data: {e}") - chunks, questions = None, None + logger.error(f"Error initializing QA data: {e}") + questions = None -def get_question_answer( - idx: Optional[int] = None, return_both: bool = True -) -> Union[dict, str]: +def get_question_answer(idx=None, return_both: bool = True) -> dict: """ Get a question-answer pair either by index or randomly. @@ -148,7 +126,7 @@ def get_question_answer( qa_pair = questions[idx] else: raise ValueError( - f"Index out of range. Must be between 0 and {len(questions)-1}" + f"Index out of range. Must be between 0 and {len(questions) - 1}" ) question = qa_pair["question"] @@ -168,7 +146,7 @@ def get_question_count() -> int: return len(questions) -def get_qa_dataset(): +def get_qa_dataset() -> tuple: """ Return a HuggingFace Dataset containing question and answer pairs. diff --git a/train.sh b/train.sh new file mode 100755 index 0000000..5835684 --- /dev/null +++ b/train.sh @@ -0,0 +1,6 @@ +export CUDA_VISIBLE_DEVICES=0 + +python train_grpo.py + + + diff --git a/train_autodidact.py b/train_autodidact.py deleted file mode 100644 index 1415a9c..0000000 --- a/train_autodidact.py +++ /dev/null @@ -1,192 +0,0 @@ -# %% -import torch - -# %% -from unsloth import FastLanguageModel, is_bfloat16_supported - -max_seq_length = 4096 * 2 # Can increase for longer reasoning traces -lora_rank = 64 # Larger rank = smarter, but slower - -model, tokenizer = FastLanguageModel.from_pretrained( - model_name="meta-llama/meta-Llama-3.1-8B-Instruct", - max_seq_length=max_seq_length, - load_in_4bit=True, # False for LoRA 16bit - fast_inference=True, # Enable vLLM fast inference - max_lora_rank=lora_rank, - gpu_memory_utilization=0.6, # Reduce if out of memory -) - -model = FastLanguageModel.get_peft_model( - model, - r=lora_rank, # Choose any number > 0 ! 
Suggested 8, 16, 32, 64, 128 - target_modules=[ - "q_proj", - "k_proj", - "v_proj", - "o_proj", - "gate_proj", - "up_proj", - "down_proj", - ], # Remove QKVO if out of memory - lora_alpha=lora_rank, - use_gradient_checkpointing="unsloth", # Enable long context finetuning - random_state=3407, -) - -# %% -import re - -from datasets import Dataset, load_dataset - -from rl_helpers import get_qa_dataset -from search_module import get_question_answer, get_question_count, search - -train_dataset, test_dataset = get_qa_dataset() - -# %% [markdown] -# -# ### Train the model -# -# Now set up GRPO Trainer and all configurations! - -# %% -import os - -os.environ["WANDB_PROJECT"] = "bootstrap-search-rl" - -# %% -# from UnslothGRPOTrainerTemp import UnslothGRPOConfig, _UnslothGRPOTrainer -import UnslothGRPOTrainerTemp - -training_args = UnslothGRPOTrainerTemp.UnslothGRPOConfig( - use_vllm=True, # use vLLM for fast inference! - use_agentic_generate=True, # use agentic generation - learning_rate=5e-6, - adam_beta1=0.9, - adam_beta2=0.99, - weight_decay=0.1, - warmup_ratio=0.1, - lr_scheduler_type="cosine", - optim="paged_adamw_8bit", - logging_steps=1, - bf16=is_bfloat16_supported(), - fp16=not is_bfloat16_supported(), - per_device_train_batch_size=8, - gradient_accumulation_steps=1, # Increase to 4 for smoother training - num_generations=8, # Decrease if out of memory - max_prompt_length=1024, - max_completion_length=1024, - # num_train_epochs = 1, # Set to 1 for a full training run - max_steps=101, - save_steps=50, - max_grad_norm=0.1, - report_to="none", # Can use Weights & Biases - output_dir="full_local_training", -) - -# %% - - -import rl_helpers - -# importlib.reload(rl_helpers) - - -def agentic_generate( - prompts: list[str], - generate_fn, - max_generations: int = 6, -): - return run_agent(generate_fn, tokenizer, prompts, max_generations) - - -model.agentic_generate = agentic_generate - - -from vllm import SamplingParams - -verifier_sampling_params = SamplingParams( - temperature=0.1, - top_p=0.95, - max_tokens=4096, -) - - -def verifier_generate_fn(inputs): - return model.fast_generate( - inputs, - sampling_params=verifier_sampling_params, - ) - - -run_agent = rl_helpers.run_agent -reward_correctness = rl_helpers.build_reward_correctness_fn( - verifier_generate_fn, - tokenizer, -) -reward_formatting = rl_helpers.reward_formatting - -import UnslothGRPOTrainerTemp - -trainer = UnslothGRPOTrainerTemp.UnslothGRPOTrainer( - model=model, - processing_class=tokenizer, - reward_funcs=[ - reward_correctness, - reward_formatting, - ], - args=training_args, - train_dataset=train_dataset, -) - -# %% -trainer.train() - -# %% [markdown] -# -# ### Inference -# Now let's try benchmark the model we trained! 
-
-# %%
-from vllm import SamplingParams
-
-import rl_helpers
-
-sampling_params = SamplingParams(
-    temperature=0.5,
-    top_p=0.95,
-    max_tokens=4096,
-)
-
-
-def eval_generate_fn(inputs):
-    return model.fast_generate(
-        inputs,
-        sampling_params=sampling_params,
-        lora_request=model.load_lora(
-            "full_local_training/checkpoint-101"
-        ),  # load the trained LoRA
-    )
-
-
-rl_helpers.run_eval(
-    generate_fn=eval_generate_fn,
-    verify_fn=reward_correctness,
-    tokenizer=tokenizer,
-)
-
-
-# %%
-# eval w/o lora
-def eval_generate_fn(inputs):
-    return model.fast_generate(
-        inputs,
-        sampling_params=sampling_params,
-    )
-
-
-rl_helpers.run_eval(
-    generate_fn=eval_generate_fn,
-    verify_fn=reward_correctness,
-    tokenizer=tokenizer,
-)
diff --git a/train_autodidact_1B.py b/train_autodidact_1B.py
deleted file mode 100644
index e7fdd3e..0000000
--- a/train_autodidact_1B.py
+++ /dev/null
@@ -1,196 +0,0 @@
-# %%
-import torch
-
-# %%
-from unsloth import FastLanguageModel, is_bfloat16_supported
-
-max_seq_length = 4096 * 2  # Can increase for longer reasoning traces
-lora_rank = 64  # Larger rank = smarter, but slower
-
-model, tokenizer = FastLanguageModel.from_pretrained(
-    model_name="meta-llama/Llama-3.2-1B-Instruct",
-    max_seq_length=max_seq_length,
-    load_in_4bit=True,  # False for LoRA 16bit
-    fast_inference=True,  # Enable vLLM fast inference
-    max_lora_rank=lora_rank,
-    gpu_memory_utilization=0.6,  # Reduce if out of memory
-)
-
-
-print(tokenizer.chat_template)  # See what format Qwen expects
-
-
-model = FastLanguageModel.get_peft_model(
-    model,
-    r=lora_rank,  # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
-    target_modules=[
-        "q_proj",
-        "k_proj",
-        "v_proj",
-        "o_proj",
-        "gate_proj",
-        "up_proj",
-        "down_proj",
-    ],  # Remove QKVO if out of memory
-    lora_alpha=lora_rank,
-    use_gradient_checkpointing="unsloth",  # Enable long context finetuning
-    random_state=3407,
-)
-
-# %%
-import re
-
-from datasets import Dataset, load_dataset
-
-from rl_helpers import get_qa_dataset
-from search_module import get_question_answer, get_question_count, search
-
-train_dataset, test_dataset = get_qa_dataset()
-
-# %% [markdown]
-#
-# ### Train the model
-#
-# Now set up GRPO Trainer and all configurations!
-
-# %%
-import os
-
-os.environ["WANDB_PROJECT"] = "bootstrap-search-rl"
-
-# %%
-# from UnslothGRPOTrainerTemp import UnslothGRPOConfig, _UnslothGRPOTrainer
-import UnslothGRPOTrainerTemp
-
-training_args = UnslothGRPOTrainerTemp.UnslothGRPOConfig(
-    use_vllm=True,  # use vLLM for fast inference!
-    use_agentic_generate=True,  # use agentic generation
-    learning_rate=5e-6,
-    adam_beta1=0.9,
-    adam_beta2=0.99,
-    weight_decay=0.1,
-    warmup_ratio=0.1,
-    lr_scheduler_type="cosine",
-    optim="paged_adamw_8bit",
-    logging_steps=1,
-    bf16=is_bfloat16_supported(),
-    fp16=not is_bfloat16_supported(),
-    per_device_train_batch_size=8,
-    gradient_accumulation_steps=1,  # Increase to 4 for smoother training
-    num_generations=8,  # Decrease if out of memory
-    max_prompt_length=1024,
-    max_completion_length=1024,
-    # num_train_epochs = 1, # Set to 1 for a full training run
-    max_steps=101,
-    save_steps=50,
-    max_grad_norm=0.1,
-    report_to="none",  # Can use Weights & Biases
-    output_dir="full_local_training",
-)
-
-# %%
-
-
-import rl_helpers
-
-# importlib.reload(rl_helpers)
-
-
-def agentic_generate(
-    prompts: list[str],
-    generate_fn,
-    max_generations: int = 6,
-):
-    return run_agent(generate_fn, tokenizer, prompts, max_generations)
-
-
-model.agentic_generate = agentic_generate
-
-
-from vllm import SamplingParams
-
-verifier_sampling_params = SamplingParams(
-    temperature=0.1,
-    top_p=0.95,
-    max_tokens=4096,
-)
-
-
-def verifier_generate_fn(inputs):
-    return model.fast_generate(
-        inputs,
-        sampling_params=verifier_sampling_params,
-    )
-
-
-run_agent = rl_helpers.run_agent
-reward_correctness = rl_helpers.build_reward_correctness_fn(
-    verifier_generate_fn,
-    tokenizer,
-)
-reward_formatting = rl_helpers.reward_formatting
-
-import UnslothGRPOTrainerTemp
-
-trainer = UnslothGRPOTrainerTemp.UnslothGRPOTrainer(
-    model=model,
-    processing_class=tokenizer,
-    reward_funcs=[
-        reward_correctness,
-        reward_formatting,
-    ],
-    args=training_args,
-    train_dataset=train_dataset,
-)
-
-# %%
-trainer.train()
-
-# %% [markdown]
-#
-# ### Inference
-# Now let's try benchmark the model we trained!
-
-# %%
-from vllm import SamplingParams
-
-import rl_helpers
-
-sampling_params = SamplingParams(
-    temperature=0.5,
-    top_p=0.95,
-    max_tokens=4096,
-)
-
-
-def eval_generate_fn(inputs):
-    return model.fast_generate(
-        inputs,
-        sampling_params=sampling_params,
-        lora_request=model.load_lora(
-            "full_local_training/checkpoint-101"
-        ),  # load the trained LoRA
-    )
-
-
-rl_helpers.run_eval(
-    generate_fn=eval_generate_fn,
-    verify_fn=reward_correctness,
-    tokenizer=tokenizer,
-)
-
-
-# %%
-# eval w/o lora
-def eval_generate_fn(inputs):
-    return model.fast_generate(
-        inputs,
-        sampling_params=sampling_params,
-    )
-
-
-rl_helpers.run_eval(
-    generate_fn=eval_generate_fn,
-    verify_fn=reward_correctness,
-    tokenizer=tokenizer,
-)
diff --git a/train_grpo.py b/train_grpo.py
new file mode 100644
index 0000000..ff0699d
--- /dev/null
+++ b/train_grpo.py
@@ -0,0 +1,124 @@
+import os
+
+from unsloth import FastLanguageModel, is_bfloat16_supported
+
+import src.UnslothGRPOTrainerTemp as UnslothGRPOTrainerTemp
+
+# Import reward functions and dataset/agent helpers
+from src.rl_helpers import (
+    build_reward_correctness_fn,
+    get_qa_dataset,
+    reward_exact_match_chunk_query,
+    reward_formatting,
+    reward_retry_behavior,
+    run_agent,
+)
+from src.config import (
+    MODEL_CONFIG,
+    MODEL_NAME,
+    OUTPUT_DIR,
+    TRAINING_CONFIG,
+    get_sampling_params,
+    init_training_dirs,
+    logger,
+    update_log_path,
+)
+
+# Initialize training directories
+paths = init_training_dirs()
+
+# Update logger to use the training directory
+update_log_path(paths["log_dir"])
+logger.info(f"Training output directory: {paths['output_dir']}")
+logger.info(f"Logs are being saved to both ./logs and {paths['log_dir']}")
+
+# Initialize model and tokenizer
+logger.info(f"Initializing model {MODEL_NAME}")
+model, tokenizer = FastLanguageModel.from_pretrained(
+    model_name=MODEL_NAME,
+    max_seq_length=MODEL_CONFIG["max_seq_length"],
+    load_in_4bit=True,  # False for LoRA 16bit
+    fast_inference=True,  # Enable vLLM fast inference
+    max_lora_rank=MODEL_CONFIG["lora_rank"],
+    gpu_memory_utilization=MODEL_CONFIG["gpu_memory_utilization"],
+)
+
+# Setup LoRA
+logger.info("Setting up LoRA adapter")
+model = FastLanguageModel.get_peft_model(
+    model,
+    r=MODEL_CONFIG["lora_rank"],
+    target_modules=MODEL_CONFIG["target_modules"],
+    lora_alpha=MODEL_CONFIG["lora_rank"],
+    use_gradient_checkpointing=True,  # Enable long context finetuning
+    random_state=3407,
+)
+
+# Load datasets
+logger.info("Loading datasets")
+train_dataset, test_dataset = get_qa_dataset()
+logger.info(
+    f"Loaded {len(train_dataset)} training examples and {len(test_dataset)} test examples"
+)
+
+# Setup training arguments
+logger.info("Setting up training arguments")
+training_args = UnslothGRPOTrainerTemp.UnslothGRPOConfig(
+    use_vllm=True,  # use vLLM for fast inference!
+    use_agentic_generate=True,  # use agentic generation
+    **TRAINING_CONFIG,
+    bf16=is_bfloat16_supported(),
+    fp16=not is_bfloat16_supported(),
+    output_dir=OUTPUT_DIR,
+    # report_to="tensorboard",  # TODO: verify that setting report_to here, rather than in TRAINING_CONFIG, avoids creating an excessive number of TensorBoard event files
+)
+
+
+# Setup model generation functions
+def agentic_generate(
+    prompts: list,
+    generate_fn,
+    max_generations: int = 10,
+):
+    return run_agent(generate_fn, tokenizer, prompts, max_generations)
+
+
+model.agentic_generate = agentic_generate
+
+# Setup verifier
+logger.info("Setting up verifier")
+verifier_sampling_params = get_sampling_params(temperature=0.1)
+
+
+def verifier_generate_fn(inputs):
+    return model.fast_generate(
+        inputs,
+        sampling_params=verifier_sampling_params,
+    )
+
+
+# Setup trainer
+logger.info("Initializing trainer")
+trainer = UnslothGRPOTrainerTemp.UnslothGRPOTrainer(
+    model=model,
+    processing_class=tokenizer,
+    reward_funcs=[
+        build_reward_correctness_fn(
+            verifier_generate_fn,
+            tokenizer,
+            log_file=os.path.join(paths["log_dir"], "qa_log.txt"),
+        ),
+        reward_formatting,
+        reward_retry_behavior,
+        reward_exact_match_chunk_query,
+    ],
+    args=training_args,
+    train_dataset=train_dataset,
+)
+
+# Train the model
+if __name__ == "__main__":
+    logger.info("Starting training")
+    trainer.train()
+    logger.info("Training completed")
+    logger.info(f"Model saved to {OUTPUT_DIR}")
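
The new `train_grpo.py` pulls all model and training hyperparameters from `src/config.py`, which is not included in this diff (the updated `search_module.py` also references its `DATA_DIR` and `logger`). As a reading aid, the sketch below is a minimal, hypothetical reconstruction of what that module could export, with values borrowed from the deleted `train_autodidact.py`; every name and number here is an assumption, not the actual file.

```python
# Hypothetical sketch of src/config.py -- reconstructed from the names this
# diff imports; all values below are assumptions, not the real module.
import logging
import time
from pathlib import Path

from vllm import SamplingParams

MODEL_NAME = "meta-llama/meta-Llama-3.1-8B-Instruct"  # assumed, from the deleted scripts
DATA_DIR = Path("saved_data")  # assumed location of questions.json
OUTPUT_DIR = f"trainer_output_{time.strftime('%Y%m%d_%H%M%S')}"  # assumed per-run directory

MODEL_CONFIG = {
    "max_seq_length": 4096 * 2,
    "lora_rank": 64,
    "gpu_memory_utilization": 0.6,
    "target_modules": [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
}

# Keyword arguments unpacked into UnslothGRPOConfig; copied from the deleted
# train_autodidact.py as a plausible default.
TRAINING_CONFIG = {
    "learning_rate": 5e-6,
    "lr_scheduler_type": "cosine",
    "optim": "paged_adamw_8bit",
    "per_device_train_batch_size": 8,
    "num_generations": 8,
    "max_prompt_length": 1024,
    "max_completion_length": 1024,
    "max_steps": 101,
    "save_steps": 50,
    "max_grad_norm": 0.1,
    "report_to": "none",
}

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("train_grpo")


def get_sampling_params(temperature: float = 0.7) -> SamplingParams:
    """Shared vLLM sampling parameters; callers override the temperature."""
    return SamplingParams(temperature=temperature, top_p=0.95, max_tokens=4096)


def init_training_dirs() -> dict:
    """Create this run's output and log directories and return their paths."""
    output_dir = Path(OUTPUT_DIR)
    log_dir = output_dir / "logs"
    log_dir.mkdir(parents=True, exist_ok=True)
    return {"output_dir": str(output_dir), "log_dir": str(log_dir)}


def update_log_path(log_dir: str) -> None:
    """Mirror log records into a file under the training directory."""
    logger.addHandler(logging.FileHandler(Path(log_dir) / "train.log"))
```

Under those assumptions, the new `train.sh` wrapper is simply a single-GPU launcher: it pins `CUDA_VISIBLE_DEVICES=0` and runs `python train_grpo.py`, with each run writing to its own `OUTPUT_DIR`.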