refactor: restructure code base, better centralize logging logic

branch: main
thinhlpg committed 1 month ago
parent 04d56325bb
commit 3c2deaced9

@ -1,2 +1,2 @@
HF_TOKEN=
OPENROUTER_API_KEY=
HF_TOKEN=<your-huggingface-token>
OPENROUTER_API_KEY=<your-openrouter-api-key>

.gitignore

@ -8,6 +8,8 @@ unsloth_compiled_cache/
full_local_training/
grpo_trainer_lora_model/
qa_log.txt
trainer_output_*
data/
# Byte-compiled / optimized / DLL files
__pycache__/

@ -1,232 +0,0 @@
"""
This script performs two main tasks:
1. It loads a markdown document, splits it into chunks, generates embeddings,
and builds a FAISS index (which is saved locally).
2. It generates QA pairs from the document using llama.
For each chunk (using a sliding window for context), it generates multiple question-answer pairs
with different difficulties. The generation is performed in batch with one retry for failed prompts.
Successfully generated QA pairs are saved to "saved_data/questions.json".
Requirements:
pip install langchain faiss-cpu unsloth vllm
"""
import json
import os
import pickle
import re
from typing import Dict, List, Optional, Tuple
from langchain.text_splitter import RecursiveCharacterTextSplitter
# ========= Part 1: Document Processing and Embedding Generation =========
# Load and split the markdown document using LangChain
from langchain_community.document_loaders import UnstructuredMarkdownLoader
from langchain_community.vectorstores import FAISS
from embeddings import CustomHuggingFaceEmbeddings
# Load your markdown file (adjust the path as needed)
loader = UnstructuredMarkdownLoader("./data/mission_report.md")
docs = loader.load()
# Split the document into smaller chunks (each 1000 characters, no overlap)
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
chunks = text_splitter.split_documents(docs)
# Save chunks for later use  # TODO: change to CSV? Easier to inspect.
os.makedirs("saved_data", exist_ok=True)
with open("saved_data/chunks.pkl", "wb") as f:
pickle.dump(chunks, f)
print(f"Saved {len(chunks)} chunks to saved_data/chunks.pkl")
embeddings = CustomHuggingFaceEmbeddings()
# Create a FAISS vector store from the document chunks and save it locally
vectorstore = FAISS.from_documents(chunks, embeddings)
vectorstore.save_local("faiss_index")
print("Saved FAISS index to 'faiss_index'")
# TODO: add the paraphrased chunks to the vector store
# ========= Part 2: QA Generation using Llama Backend =========
# Setup Llama backend via unsloth and vLLM
from unsloth import FastLanguageModel
from vllm import SamplingParams
import rl_helpers # Ensure you have this or remove if not used
# Load the Llama model (adjust parameters as needed)
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="meta-llama/meta-Llama-3.1-8B-Instruct",
max_seq_length=4096,
load_in_4bit=True, # Use 4-bit quantization if desired
fast_inference=True, # Enable fast inference
gpu_memory_utilization=0.6, # Adjust based on your GPU memory
)
# Define sampling parameters for generation
sampling_params = SamplingParams(
temperature=0.3,
top_p=0.95,
max_tokens=4096,
)
def batch_generate(prompts: List[str]) -> List[str]:
"""
Given a list of prompt strings, returns a list of generated outputs.
"""
def format_input(text: str) -> str:
return tokenizer.apply_chat_template(
[{"role": "user", "content": text}],
tokenize=False,
add_generation_prompt=True,
)
formatted = [format_input(p) for p in prompts]
outputs = model.fast_generate(formatted, sampling_params=sampling_params)
return [output.outputs[0].text for output in outputs]
def parse_qa_block(block: str) -> Optional[Tuple[str, str, str]]:
"""
Parses a QA block that should contain exactly three non-empty lines:
- A line starting with "Question:"
- A line starting with "Answer:"
- A line starting with "Difficulty:"
If the markers are not present but the block contains exactly three lines,
those are used in order.
Returns a tuple (question, answer, difficulty) or None if parsing fails.
"""
lines = [line.strip() for line in block.splitlines() if line.strip()]
if not lines:
return None
question, answer, difficulty = None, None, None
for line in lines:
lower = line.lower()
if question is None and lower.startswith("question:"):
question = line[len("question:") :].strip()
elif answer is None and lower.startswith("answer:"):
answer = line[len("answer:") :].strip()
elif difficulty is None and lower.startswith("difficulty:"):
difficulty = line[len("difficulty:") :].strip()
if question and answer and difficulty:
return question, answer, difficulty
if len(lines) == 3:
return lines[0], lines[1], lines[2]
return None
def parse_multiple_qa_output(output: str) -> List[Tuple[str, str, str]]:
"""
Splits the output into blocks (separated by one or more blank lines) and
attempts to parse each as a QA pair.
Returns a list of successfully parsed QA tuples.
"""
blocks = re.split(r"\n\s*\n", output.strip())
qa_pairs = []
for block in blocks:
parsed = parse_qa_block(block)
if parsed:
qa_pairs.append(parsed)
return qa_pairs
def generate_question_batch_for_chunks(
chunks: List, num_questions: int = 2, difficulty: str = None
) -> List[Dict]:
"""
Generates QA pairs for multiple chunks in batch.
For each chunk (except the first and last), a sliding window is used for context:
- before: previous chunk's content
- current: current chunk's content
- after: next chunk's content
Each prompt instructs the model to output exactly three lines per QA pair with markers.
Failed prompts are retried once in batch; if still unsuccessful, they are skipped.
Returns a list of dicts with keys: "chunk_id", "question", "answer", "difficulty".
"""
prompts = []
chunk_ids = []
# Prepare prompts using a sliding window
for i in range(1, len(chunks) - 1):
before = chunks[i - 1].page_content
current = chunks[i].page_content
after = chunks[i + 1].page_content
prompt = (
f"From the text within ==BEGIN== and ==END==, generate {num_questions} questions with answers.\n"
"For each QA pair, output exactly three lines with no extra commentary:\n"
"Line 1: Question: <your question>\n"
"Line 2: Answer: <the answer>\n"
"Line 3: Difficulty: <easy, medium, or hard>\n"
"Do not include any additional text.\n\n"
"==BEGIN==\n"
f"{before}\n{current}\n{after}\n"
"==END==\n"
)
prompts.append(prompt)
chunk_ids.append(i)
# First batch generation
outputs = batch_generate(prompts)
results = [None] * len(outputs)
failed_indices = []
# Parse each output
for idx, output in enumerate(outputs):
qa_pairs = parse_multiple_qa_output(output)
if qa_pairs is None or len(qa_pairs) < num_questions:
failed_indices.append(idx)
else:
results[idx] = qa_pairs[:num_questions]
# Retry failed prompts in batch
if failed_indices:
print(f"Retrying {len(failed_indices)} failed prompt(s)...")
retry_prompts = [prompts[i] for i in failed_indices]
retry_outputs = batch_generate(retry_prompts)
for j, idx in enumerate(failed_indices):
qa_pairs = parse_multiple_qa_output(retry_outputs[j])
if qa_pairs is not None and len(qa_pairs) >= num_questions:
results[idx] = qa_pairs[:num_questions]
else:
results[idx] = None # Mark as failed
# Build final output, skipping prompts that failed even after retry
final_questions = []
for i, qa_list in enumerate(results):
if qa_list is not None:
for qa in qa_list:
final_questions.append(
{
"chunk_id": chunk_ids[i],
"question": qa[0],
"answer": qa[1],
"difficulty": qa[2],
}
)
return final_questions
# Generate QA pairs in batch (using a sliding window over the chunks)
all_questions = generate_question_batch_for_chunks(
chunks, num_questions=2, difficulty="medium"
)
print(f"Generated {len(all_questions)} QA pairs.")
# Save the QA pairs to a JSON file
questions_path = os.path.join("saved_data", "questions.json")
with open(questions_path, "w") as f:
json.dump(all_questions, f, indent=2)
print(f"Saved questions to {questions_path}")

@ -1,3 +0,0 @@
unsloth_compiled_cache
0_*
faiss_index*

@ -5,13 +5,13 @@ langchain-community
Markdown
tokenizers
unsloth==2025.3.6
transformers==4.49.0
unsloth_zoo==2025.3.4
unstructured
vllm
wandb
vllm==0.7.2
transformers==4.49.0
ipykernel
python-dotenv
loguru
gradio
gradio
tensorboard

@ -5,35 +5,31 @@ This script performs two main tasks:
2. It generates QA pairs from the document using llama.
For each chunk (using a sliding window for context), it generates multiple question-answer pairs
with different difficulties. The generation is performed in batch with one retry for failed prompts.
Successfully generated QA pairs are saved to "saved_data/questions.json".
Successfully generated QA pairs are saved to "data/questions.json".
Requirements:
pip install langchain faiss-cpu unsloth vllm
"""
import json
import os
import re
from typing import Dict, List, Optional, Tuple
import sys
from pathlib import Path
# Add project root to Python path
project_root = Path(__file__).resolve().parent.parent
sys.path.append(str(project_root))
import pandas as pd
from langchain.text_splitter import RecursiveCharacterTextSplitter
from loguru import logger
# Configure logger
logger.add(
"logs/generate_data_{time}.log",
rotation="500 MB",
retention="10 days",
level="INFO",
)
# ========= Part 1: Document Processing and Embedding Generation =========
# Load and split the markdown document using LangChain
from langchain_community.document_loaders import UnstructuredMarkdownLoader
from langchain_community.vectorstores import FAISS
from embeddings import CustomHuggingFaceEmbeddings
from src.config import DATA_DIR, logger
from src.embeddings import CustomHuggingFaceEmbeddings
# Load your markdown file (adjust the path as needed)
loader = UnstructuredMarkdownLoader("./data/mission_report.md")
@ -43,9 +39,6 @@ docs = loader.load()
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
chunks = text_splitter.split_documents(docs)
# Create output directory
os.makedirs("saved_data", exist_ok=True)
# Save chunks to CSV for easy inspection
chunks_df = pd.DataFrame(
{
@ -54,15 +47,15 @@ chunks_df = pd.DataFrame(
"metadata": [chunk.metadata for chunk in chunks],
}
)
chunks_df.to_csv("saved_data/chunks.csv", index=False)
print(f"Saved {len(chunks)} chunks to saved_data/chunks.csv")
chunks_df.to_csv(DATA_DIR / "chunks.csv", index=False)
logger.info(f"Saved {len(chunks)} chunks to {DATA_DIR}/chunks.csv")
embeddings = CustomHuggingFaceEmbeddings()
# Create a FAISS vector store from the document chunks and save it locally
vectorstore = FAISS.from_documents(chunks, embeddings)
vectorstore.save_local("faiss_index")
print("Saved FAISS index to 'faiss_index'")
vectorstore.save_local(str(DATA_DIR))
logger.info(f"Saved FAISS index to {DATA_DIR}")
# TODO: add the paraphrased chunks to the vector store
@ -72,8 +65,6 @@ print("Saved FAISS index to 'faiss_index'")
from unsloth import FastLanguageModel
from vllm import SamplingParams
import rl_helpers # Ensure you have this or remove if not used
# Load the Llama model (adjust parameters as needed)
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="meta-llama/meta-Llama-3.1-8B-Instruct",
@ -91,7 +82,7 @@ sampling_params = SamplingParams(
)
def batch_generate(prompts: List[str]) -> List[str]:
def batch_generate(prompts: list) -> list:
"""
Given a list of prompt strings, returns a list of generated outputs.
"""
@ -108,7 +99,7 @@ def batch_generate(prompts: List[str]) -> List[str]:
return [output.outputs[0].text for output in outputs]
def parse_qa_block(block: str) -> Optional[Tuple[str, str, str]]:
def parse_qa_block(block: str):
"""
Parses a QA block that should contain exactly three non-empty lines:
- A line starting with "Question:"
@ -141,7 +132,7 @@ def parse_qa_block(block: str) -> Optional[Tuple[str, str, str]]:
return None
def parse_multiple_qa_output(output: str) -> List[Tuple[str, str, str]]:
def parse_multiple_qa_output(output: str) -> list:
"""
Splits the output into blocks (separated by one or more blank lines) and
attempts to parse each as a QA pair.
@ -158,8 +149,8 @@ def parse_multiple_qa_output(output: str) -> List[Tuple[str, str, str]]:
def generate_question_batch_for_chunks(
chunks: List, num_questions: int = 2, difficulty: Optional[str] = None
) -> List[Dict]:
chunks: list, num_questions: int = 2, difficulty=None
) -> list:
"""
Generates QA pairs for multiple chunks in batch.
@ -198,7 +189,9 @@ def generate_question_batch_for_chunks(
# First batch generation
outputs = batch_generate(prompts)
results: List[Optional[List[Tuple[str, str, str]]]] = [None] * len(outputs)
results = []
for _ in range(len(outputs)):
results.append(None)
failed_indices = []
# Parse each output
@ -270,10 +263,10 @@ def generate_question_batch_for_chunks(
all_questions = generate_question_batch_for_chunks(
chunks, num_questions=2, difficulty="medium"
)
print(f"Generated {len(all_questions)} QA pairs.")
logger.info(f"Generated {len(all_questions)} QA pairs.")
# Save the QA pairs to a JSON file
questions_path = os.path.join("saved_data", "questions.json")
questions_path = DATA_DIR / "questions.json"
with open(questions_path, "w") as f:
json.dump(all_questions, f, indent=2)
print(f"Saved questions to {questions_path}")
logger.info(f"Saved questions to {questions_path}")

@ -8,10 +8,18 @@ import json
import random
import sys
import time
from typing import Any, Dict
from pathlib import Path
# Import our search module (ensure these functions follow the new interfaces)
from search_module import get_question_answer, get_question_count, search
# Add project root to Python path
project_root = Path(__file__).resolve().parent.parent
sys.path.append(str(project_root))
# Import our search module and config
from src.config import DATA_DIR, logger
from src.search_module import get_question_answer, get_question_count, search
# TODO: Import verify function and router from appropriate module
# TODO: Consider moving verify function to search_module.py for better organization
class SimpleQAEnvironment:
@ -21,27 +29,30 @@ class SimpleQAEnvironment:
self.score = {"correct": 0, "incorrect": 0, "total": 0}
self.session_data = []
self.current_question = None
self.session_file = DATA_DIR / "qa_sessions"
def display_welcome(self):
"""Display welcome message and instructions."""
print("\n===== Search & Answer Environment =====")
print("Answer questions using the search tool to find relevant information.")
print("Type 'q' to quit, 'h' for help.\n")
logger.info("===== Search & Answer Environment =====")
logger.info(
"Answer questions using the search tool to find relevant information."
)
logger.info("Type 'q' to quit, 'h' for help.\n")
def display_help(self):
"""Display help information."""
print("\n===== Commands =====")
print("n - Get a new question")
print("s <query> - Search for information (e.g., s program launch date)")
print("a <answer> - Submit your answer")
print("h - Display this help message")
print("q - Quit the program\n")
logger.info("\n===== Commands =====")
logger.info("n - Get a new question")
logger.info("s <query> - Search for information (e.g., s program launch date)")
logger.info("a <answer> - Submit your answer")
logger.info("h - Display this help message")
logger.info("q - Quit the program\n")
def display_question(self, question: str):
"""Display the current question."""
print("\n===== QUESTION =====")
print(question)
print("=====================\n")
logger.info("\n===== QUESTION =====")
logger.info(question)
logger.info("=====================\n")
def get_new_question(self) -> str:
"""Get a new random question and set it as current."""
@ -66,30 +77,30 @@ class SimpleQAEnvironment:
def perform_search(self, query: str):
"""Perform a search with the given query."""
if not query:
print("Please provide a search query.")
logger.warning("Please provide a search query.")
return
try:
print("\n===== SEARCH RESULTS =====")
logger.info("\n===== SEARCH RESULTS =====")
results = search(query)
print(results)
print("==========================\n")
logger.info(results)
logger.info("==========================\n")
# Record search in current question data if available.
if self.current_question is not None:
self.current_question["searches"].append(query)
except Exception as e:
print(f"Error searching: {str(e)}")
logger.error(f"Error searching: {str(e)}")
async def process_answer(self, user_answer: str):
"""Process and verify the user's answer."""
if self.current_question is None:
print("Please get a question first.")
logger.warning("Please get a question first.")
return
if not user_answer:
print("Please provide an answer.")
logger.warning("Please provide an answer.")
return
# Record answer and calculate time taken.
@ -100,27 +111,29 @@ class SimpleQAEnvironment:
)
try:
print("\nVerifying your answer...")
correct = await verify(
user_answer,
self.current_question["question"],
self.current_question["correct_answer"],
router,
)
logger.info("\nVerifying your answer...")
# TODO: Implement verify function in search_module.py
# correct = await verify(
# user_answer,
# self.current_question["question"],
# self.current_question["correct_answer"],
# router,
# )
correct = False # Temporary placeholder until verify is implemented
# Update score and inform the user.
self.score["total"] += 1
if correct:
self.score["correct"] += 1
print("\n✓ Your answer is CORRECT!")
logger.success("\n✓ Your answer is CORRECT!")
else:
self.score["incorrect"] += 1
print("\n✗ Your answer is INCORRECT.")
print(
logger.error("\n✗ Your answer is INCORRECT.")
logger.info(
f"\nThe correct answer is:\n{self.current_question['correct_answer']}"
)
print(f"\nScore: {self.score['correct']}/{self.score['total']}")
logger.info(f"\nScore: {self.score['correct']}/{self.score['total']}")
# Record the result and add the current question to the session data.
self.current_question["is_correct"] = correct
@ -130,15 +143,18 @@ class SimpleQAEnvironment:
self.current_question = None
except Exception as e:
print(f"Error verifying answer: {str(e)}")
logger.error(f"Error verifying answer: {str(e)}")
def save_session(self):
"""Save the session data to a file."""
if not self.session_data:
return
# Ensure session directory exists
self.session_file.mkdir(parents=True, exist_ok=True)
timestamp = time.strftime("%Y%m%d_%H%M%S")
filename = f"qa_session_{timestamp}.json"
filename = self.session_file / f"qa_session_{timestamp}.json"
session_data = {
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
@ -149,9 +165,9 @@ class SimpleQAEnvironment:
try:
with open(filename, "w") as f:
json.dump(session_data, f, indent=2)
print(f"\nSession data saved to {filename}")
logger.info(f"\nSession data saved to {filename}")
except Exception as e:
print(f"Error saving session data: {str(e)}")
logger.error(f"Error saving session data: {str(e)}")
async def run(self):
"""Run the main command loop."""
@ -178,11 +194,11 @@ class SimpleQAEnvironment:
answer = command[2:].strip()
await self.process_answer(answer)
else:
print("Unknown command. Type 'h' for help.")
logger.warning("Unknown command. Type 'h' for help.")
# Save session data on exit.
self.save_session()
print("\nThank you for using the Q&A environment!")
logger.info("\nThank you for using the Q&A environment!")
async def main():
@ -195,6 +211,6 @@ if __name__ == "__main__":
try:
asyncio.run(main())
except KeyboardInterrupt:
print("\nProgram terminated by user.")
logger.info("\nProgram terminated by user.")
except Exception as e:
print(f"\nError: {str(e)}")
logger.error(f"\nError: {str(e)}")

@ -1,18 +1,21 @@
from torch import Tensor
import os
from contextlib import nullcontext
from dataclasses import dataclass, field
from typing import *
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from packaging.version import Version
from trl.trainer.grpo_trainer import (
Any,
AutoModelForCausalLM,
AutoModelForSequenceClassification,
AutoTokenizer,
Dataset,
GRPOConfig,
GRPOTrainer,
GenerationConfig,
GRPOConfig,
IterableDataset,
LLM,
Optional,
PeftConfig,
PreTrainedModel,
@ -41,7 +44,6 @@ from trl.trainer.grpo_trainer import (
nn,
os,
pad,
patch,
prepare_deepspeed,
set_seed,
textwrap,
@ -50,42 +52,8 @@ from trl.trainer.grpo_trainer import (
unwrap_model_for_generation,
version,
wandb,
warnings,
os,
torch,
transformers,
Any,
LLM,
Union,
apply_chat_template,
broadcast_object_list,
gather,
gather_object,
is_conversational,
maybe_apply_chat_template,
nn,
os,
pad,
torch,
unwrap_model_for_generation,
wandb,
GRPOTrainer,
Trainer,
gather,
os,
torch,
)
import os
from typing import *
from dataclasses import dataclass, field
from packaging.version import Version
import torch
import numpy as np
from contextlib import nullcontext
from torch.nn import functional as F
torch_compile_options = {
"epilogue_fusion": True,
"max_autotune": False,

@ -0,0 +1,301 @@
import os
import sys
from datetime import datetime
from pathlib import Path
import torch
from dotenv import load_dotenv
from loguru import logger
from vllm import SamplingParams
# Load environment variables from .env file if it exists
load_dotenv(override=True)
# Project paths
PROJ_ROOT = Path(__file__).resolve().parent.parent
DATA_DIR = PROJ_ROOT / "data"
LOG_FOLDER = PROJ_ROOT / "logs"
# Model configuration
# MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
device_id = (
1 if os.environ.get("CUDA_VISIBLE_DEVICES") == "1" else torch.cuda.current_device()
)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
OUTPUT_DIR = (
PROJ_ROOT
/ f"trainer_output_{MODEL_NAME.replace('/', '_')}_gpu{device_id}_{timestamp}"
)
# Model parameters
MODEL_CONFIG = {
"max_seq_length": 4096 * 2, # Can increase for longer reasoning traces
"lora_rank": 64, # Larger rank = smarter, but slower
"gpu_memory_utilization": 0.6, # Reduce if out of memory
"model_name": MODEL_NAME,
"target_modules": [
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
], # Remove QKVO if out of memory
}
# Training parameters
TRAINING_CONFIG = {
"learning_rate": 5e-6,
"adam_beta1": 0.9,
"adam_beta2": 0.99,
"weight_decay": 0.1,
"warmup_ratio": 0.1,
"lr_scheduler_type": "cosine",
"optim": "paged_adamw_8bit",
"logging_steps": 1,
"per_device_train_batch_size": 8,
"gradient_accumulation_steps": 1, # Increase to 4 for smoother training
"num_generations": 8, # Decrease if out of memory
"max_prompt_length": 1024,
"max_completion_length": 1024,
"max_steps": 101,
"save_steps": 50,
"max_grad_norm": 0.1,
"report_to": "tensorboard",
}
# Sampling parameters
def get_sampling_params(temperature: float = 0.1) -> SamplingParams:
"""Get sampling parameters for text generation"""
return SamplingParams(
temperature=temperature,
top_p=0.95,
max_tokens=4096,
)
# Initialize logging based on environment
def _init_logging(env: str = "development") -> None:
"""
Initialize logging configuration with console logging
and default file logging to ./logs directory.
Additional file logging will be set up later in update_log_path().
Args:
env: The environment for logging ('development' or 'production')
"""
# Create default log folder
if not LOG_FOLDER.exists():
LOG_FOLDER.mkdir(parents=True, exist_ok=True)
# Remove any existing handlers
logger.remove()
# Define the logging format
console_format = (
"<green>{time:YYYY-MM-DD HH:mm:ss}</green> "
"| <level>{level: <8}</level> "
"| <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> "
"- <level>{message}</level>"
)
file_format = (
"{time:YYYY-MM-DD at HH:mm:ss} "
"| {level} "
"| {name}:{function}:{line} "
"- {message}"
)
# Add console logging
logger.add(
sys.stderr,
format=console_format,
level="DEBUG" if env == "development" else "INFO",
colorize=True,
backtrace=True,
diagnose=env == "development",
)
# Add default file logging to ./logs directory
logger.add(
LOG_FOLDER / "app.log",
format=file_format,
level="INFO",
rotation="500 MB",
retention="7 days",
compression="zip",
enqueue=True, # Enables asynchronous logging
)
# Add custom level for requests
logger.level("REQUEST", no=25, color="<yellow>", icon=" ")
# Configure exception handling
def exception_handler(exc_type, exc_value, exc_traceback):
if issubclass(exc_type, KeyboardInterrupt):
sys.__excepthook__(exc_type, exc_value, exc_traceback)
return
logger.opt(exception=(exc_type, exc_value, exc_traceback)).critical(
"Unhandled exception"
)
sys.excepthook = exception_handler
# Update the log files to point to the training directory
def update_log_path(log_dir=None):
"""
Add a log file in the training directory while keeping the default ./logs logging.
Should be called after the training directory is created.
Args:
log_dir: Path to store additional log files (default: uses get_paths()["log_dir"])
"""
# Use provided log_dir or get from training paths
if log_dir is None:
paths = get_paths(create_dirs=True)
log_dir = paths["log_dir"]
else:
log_dir = Path(log_dir)
log_dir.mkdir(exist_ok=True, parents=True)
file_format = (
"{time:YYYY-MM-DD at HH:mm:ss} "
"| {level} "
"| {name}:{function}:{line} "
"- {message}"
)
# Add additional file handler pointing to training directory
# No need to remove existing handlers as we want to keep those
logger.add(
log_dir / "app.log",
format=file_format,
level="INFO",
rotation="500 MB",
retention="7 days",
compression="zip",
enqueue=True, # Enables asynchronous logging
)
logger.info(f"Additional logs will be stored in: {log_dir}")
# Paths configuration without creating directories
def get_paths(create_dirs: bool = False) -> dict:
"""
Get common paths for the project
Args:
create_dirs: Whether to create the directories
Returns:
Dictionary with paths
"""
output_dir = Path(OUTPUT_DIR)
log_dir = output_dir / "logs"
tensorboard_dir = output_dir / "runs"
# Only create directories if explicitly requested
if create_dirs:
output_dir.mkdir(exist_ok=True)
log_dir.mkdir(exist_ok=True)
# Only create tensorboard directory if it's enabled in config
if TRAINING_CONFIG.get("report_to") == "tensorboard":
tensorboard_dir.mkdir(exist_ok=True)
return {
"output_dir": output_dir,
"log_dir": log_dir,
"tensorboard_dir": tensorboard_dir,
"proj_root": PROJ_ROOT,
"data_dir": DATA_DIR,
}
# Create training directories
def init_training_dirs():
"""Initialize all directories needed for training"""
paths = get_paths(create_dirs=True)
# Also ensure our standard project directories exist
for directory in [
DATA_DIR,
LOG_FOLDER,
]:
directory.mkdir(exist_ok=True, parents=True)
return paths
# For backward compatibility - will be deprecated
def setup_logger(module_name=None, create_dirs: bool = False):
"""
Setup a logger for a specific module with consistent configuration.
Note: This function is kept for backward compatibility.
Use the global 'logger' instead for new code.
Args:
module_name: Optional name of module for module-specific log file
create_dirs: Whether to create log directories
Returns:
Configured logger instance
"""
logger.warning(
"setup_logger is deprecated. Import logger directly from config instead."
)
return logger
# Tensorboard writer singleton
_tensorboard_writer = None
# Safe tensorboard logging function
def log_metric(key, value, step=0):
"""
Log a metric safely to tensorboard if writer is available.
Args:
key: Metric name
value: Metric value
step: Training step
"""
global _tensorboard_writer
# Skip tensorboard logging if disabled in config
if TRAINING_CONFIG.get("report_to") != "tensorboard":
logger.debug(f"Tensorboard disabled. Metric: {key}={value} (step {step})")
return
# Get paths and initialize writer if needed
paths = get_paths(create_dirs=False)
if paths["tensorboard_dir"].exists():
# Only create writer once
if _tensorboard_writer is None:
from torch.utils.tensorboard.writer import SummaryWriter
_tensorboard_writer = SummaryWriter(paths["tensorboard_dir"])
logger.debug(f"Created tensorboard writer at {paths['tensorboard_dir']}")
# Add scalar using existing writer
_tensorboard_writer.add_scalar(key, value, step)
# No need to close the writer - it will be closed at process exit
else:
logger.debug(f"Tensorboard metric: {key}={value} (step {step})")
# Initialize logging on module import
env = os.getenv("APP_ENV", "development")
_init_logging(env=env)
# Log project root on import
logger.info(f"Project root path: {PROJ_ROOT}")
logger.debug(f"Running in {env} environment")

@ -1,7 +1,4 @@
from typing import List, Union
import torch
import torch.nn.functional as F
from langchain.embeddings.base import Embeddings
from transformers import AutoModel, AutoTokenizer
@ -18,9 +15,7 @@ class CustomHuggingFaceEmbeddings(Embeddings):
- "query": uses mean pooling over tokens (weighted by the attention mask) for query embeddings.
"""
def __init__(
self, model_name: str = DEFAULT_MODEL_NAME, default_mode: str = "sentence"
):
def __init__(self, model_name=DEFAULT_MODEL_NAME, default_mode="sentence"):
self.model_name = model_name
# Set device to GPU if available, else CPU
self.device = "cuda" if torch.cuda.is_available() else "cpu"
@ -29,7 +24,8 @@ class CustomHuggingFaceEmbeddings(Embeddings):
self.default_mode = default_mode # "sentence" or "query"
self.model.eval() # Set model to evaluation mode
def get_embedding(self, text: Union[str, List[str]], mode: str = None):
def get_embedding(self, text, mode=None):
"""Get embeddings for text using specified mode"""
if mode is None:
mode = self.default_mode
assert mode in (
@ -59,14 +55,14 @@ class CustomHuggingFaceEmbeddings(Embeddings):
vectors = output.last_hidden_state[:, 0, :]
return vectors
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_documents(self, texts):
"""
Compute embeddings for a list of documents (using sentence mode).
"""
vectors = self.get_embedding(texts, mode="sentence")
return vectors.cpu().numpy().tolist()
def embed_query(self, text: str) -> List[float]:
def embed_query(self, text):
"""
Compute an embedding for a single query.
"""

@ -4,33 +4,21 @@ This module provides utility functions for handling chat-based tool interactions
and calculating rewards based on the quality of responses.
"""
import asyncio
import inspect
import json
import re
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
import nest_asyncio
import numpy as np
import torch
from loguru import logger
from search_module import get_qa_dataset, search
# Setup loguru
log_dir = Path("logs")
log_dir.mkdir(exist_ok=True)
logger.add(
log_dir / "rl_helpers_{time}.log",
rotation="500 MB",
retention="10 days",
level="DEBUG",
format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}",
)
from src.config import log_metric, logger
from src.search_module import get_qa_dataset, search
# Apply nest_asyncio for supporting async operations in notebooks
nest_asyncio.apply()
from typing import Callable, List
from trl.trainer.grpo_trainer import apply_chat_template
@ -238,6 +226,8 @@ def run_tool_calls(chat_states):
Execute tool calls found in chat states.
"""
logger.debug(f"Running tool calls for {len(chat_states)} chat states")
total_retries = 0
for chat_state in chat_states:
if chat_state.get("finished"):
logger.debug("Chat state already finished, skipping tool calls")
@ -256,9 +246,14 @@ def run_tool_calls(chat_states):
elif len(function_calls) == 1:
function_call = function_calls[0]
query = function_call["function"]["parameters"]["query"]
logger.info(f"Executing search with query: {query}")
logger.info(f"🔍 Search Query: {query}")
results = search(query, return_type=str, results=2)
chat_state["messages"].append({"role": "ipython", "content": results})
# Count retries
retries = len(extract_json_objects(assistant_response))
total_retries += retries
logger.debug("Added search results to chat state")
except Exception as e:
logger.error(f"Error during tool call: {str(e)}")
@ -332,7 +327,12 @@ def get_chat_num_tokens(chat_state, tokenizer):
def run_agent(
generate_fn, tokenizer, questions, max_generations=5, max_new_tokens=4096
generate_fn,
tokenizer,
questions,
max_generations=5,
max_new_tokens=4096,
correct_contents=None,
):
"""
Run the agent to completion for a batch of questions.
@ -343,6 +343,11 @@ def run_agent(
)
chat_states = [get_initial_chat(q) for q in questions]
# Add correct content to chat states if provided
if correct_contents:
for chat_state, correct_content in zip(chat_states, correct_contents):
chat_state["correct_content"] = correct_content
# set the initial_prompt length
for i, chat_state in enumerate(chat_states):
chat_state["initial_length"] = get_chat_num_tokens(chat_state, tokenizer)
@ -350,7 +355,7 @@ def run_agent(
# agent loop
for i in range(max_generations):
logger.info(f"Starting generation step {i+1}/{max_generations}")
logger.info(f"Starting generation step {i + 1}/{max_generations}")
chat_states = run_agent_generations(generate_fn, tokenizer, chat_states)
chat_states = check_finished_chats(chat_states)
chat_states = run_tool_calls(chat_states)
@ -359,7 +364,7 @@ def run_agent(
)
finished_count = sum(1 for state in chat_states if state.get("finished"))
logger.info(
f"Finished {finished_count}/{len(chat_states)} chat states after step {i+1}"
f"Finished {finished_count}/{len(chat_states)} chat states after step {i + 1}"
)
logger.info("Agent run completed")
@ -440,15 +445,26 @@ async def verify(student_answer: str, question: str, answer: str) -> bool:
def check_student_answers(
questions: List[str],
answers: List[str],
student_answers: List[str],
vllm_generate_func: Callable[[List[str]], List[str]],
questions: list[str],
answers: list[str],
student_answers: list, # Can be strings or dicts
vllm_generate_func,
tokenizer,
log_file: str = "qa_log.txt",
) -> List[bool]:
log_file=None,
) -> list[bool]:
"""
Evaluates a list of student answers against the true answers using a vLLM generate function.
Args:
questions: List of questions
answers: List of correct answers
student_answers: List of student answers to evaluate
vllm_generate_func: Function to generate verification responses
tokenizer: Tokenizer for formatting prompts
log_file: Optional path to write detailed results
Returns:
List of boolean results (True for correct answers)
"""
logger.info(f"Checking {len(questions)} student answers")
@ -463,12 +479,15 @@ def check_student_answers(
prompts = []
for question, answer, student_ans in zip(questions, answers, student_answers):
prompt_text = (
"You are grading a student's answer. For the following question, "
"compare the student's answer to the correct answer. Reply with 'Yes' if the student's answer is correct, or 'No' if it is completely incorrect.\n\n"
"You are grading a student's answer to a question. For the following question, "
"compare the student's answer to the correct answer. Reply with 'Yes' if the student's answer contains the correct information, "
"even if it's not an exact match. If the student's answer doesn't contain the right information or is completely incorrect, reply with 'No'.\n\n"
f"Question: {question}\n"
f"Correct Answer: {answer}\n"
f"Student Answer: {student_ans}\n"
f"Student Answer: {student_ans}\n\n"
"Your response should be just 'Yes' or 'No'."
)
formatted_prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": prompt_text}],
tokenize=False,
@ -481,10 +500,15 @@ def check_student_answers(
responses = vllm_generate_func(prompts)
responses_text = []
for response in responses:
# Handle different response formats
if hasattr(response, "outputs"):
responses_text.append(response.outputs[0].text)
try:
responses_text.append(response.outputs[0].text)
except (AttributeError, IndexError):
# Fallback for simple string responses
responses_text.append(str(response))
else:
responses_text.append(response)
responses_text.append(str(response))
logger.debug(f"Got {len(responses_text)} verification responses")
results = []
@ -495,34 +519,108 @@ def check_student_answers(
logger.info(f"Verification complete. {sum(results)}/{len(results)} answers correct")
# Append the QA details and verifier's response to the specified log file
with open(log_file, "a") as file:
for question, answer, student_ans, verifier_response in zip(
questions, answers, student_answers, responses_text
):
file.write("Question: " + question + "\n")
file.write("Correct Answer: " + answer + "\n")
file.write("Student Answer: " + student_ans + "\n")
file.write("Verifier said: " + verifier_response + "\n")
file.write("-" * 40 + "\n")
if log_file:
with open(log_file, "a") as file:
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
file.write(f"\n📝 === QA Evaluation at {timestamp} ===\n")
file.write(f"📂 File: {__file__}\n")
# Get current frame info safely
frame = inspect.currentframe()
if frame:
file.write(f"📍 Line: {frame.f_lineno}\n")
# Don't forget to delete the frame to avoid reference cycles
del frame
file.write("=" * 80 + "\n")
for i, (question, answer, student_ans, verifier_response) in enumerate(
zip(questions, answers, student_answers, responses_text)
):
file.write(f"\n❓ Question {i+1}:\n")
file.write("-" * 40 + "\n")
file.write(f"📋 Question: {question}\n")
file.write(f"✅ Correct Answer: {answer}\n")
file.write(f"👨‍🎓 Student Answer: {student_ans}\n")
file.write(f"🔍 Verifier said: {verifier_response}\n")
# Add search results if available in the chat state
if isinstance(student_ans, dict) and "messages" in student_ans:
# Get messages from dict
messages = student_ans.get("messages", [])
search_results = [
msg.get("content", "")
for msg in messages
if msg.get("role") == "ipython"
]
if search_results:
file.write("\n🔎 Search Results:\n")
for j, result in enumerate(search_results, 1):
file.write(f"\nSearch {j}:\n{result}\n")
file.write("-" * 40 + "\n")
file.write(
f"\n📊 Summary: {sum(results)}/{len(results)} answers correct ({sum(results)/len(results)*100:.2f}%)\n"
)
file.write("=" * 80 + "\n\n")
return results
# Reward Functions
def build_reward_correctness_fn(generate_fn, tokenizer):
def build_reward_correctness_fn(generate_fn, tokenizer, log_file=None):
def reward_correctness(prompts, completions, **reward_kwargs):
teacher_answers = reward_kwargs["answer"]
student_answers = [
completion["messages"][-1]["content"] for completion in completions
]
# Log non-exact matches
for i, (student, teacher) in enumerate(zip(student_answers, teacher_answers)):
if student.strip().lower() != teacher.strip().lower():
logger.warning(
f"Non-exact match at index {i}:\n"
f"Student: {student}\n"
f"Teacher: {teacher}"
)
correct = check_student_answers(
prompts,
teacher_answers,
student_answers,
vllm_generate_func=generate_fn,
tokenizer=tokenizer,
log_file=log_file,
)
# Log correctness metrics with length info
log_metric(
"rewards/correctness", np.mean(correct), reward_kwargs.get("step", 0)
)
log_metric(
"rewards/correctness_std", np.std(correct), reward_kwargs.get("step", 0)
)
# Log length metrics
student_lengths = [len(ans.strip()) for ans in student_answers]
teacher_lengths = [len(ans.strip()) for ans in teacher_answers]
log_metric(
"metrics/avg_student_length",
np.mean(student_lengths),
reward_kwargs.get("step", 0),
)
log_metric(
"metrics/avg_teacher_length",
np.mean(teacher_lengths),
reward_kwargs.get("step", 0),
)
log_metric(
"metrics/length_ratio",
np.mean(student_lengths) / np.mean(teacher_lengths),
reward_kwargs.get("step", 0),
)
return correct
return reward_correctness
@ -535,14 +633,23 @@ def reward_formatting(prompts, completions, **reward_kwargs):
for message in chat["messages"]:
if "Error during" in message["content"]:
has_error[i] = True
logger.warning(f"Error in chat {i}: {message['content']}")
break
return [0.7 if not e else 0 for e in has_error]
rewards = [0.7 if not e else 0 for e in has_error]
# Log formatting metrics
log_metric("rewards/formatting", np.mean(rewards), reward_kwargs.get("step", 0))
log_metric("rewards/formatting_std", np.std(rewards), reward_kwargs.get("step", 0))
log_metric("metrics/error_rate", np.mean(has_error), reward_kwargs.get("step", 0))
return rewards
def reward_retry_behavior(completions: list[dict], **reward_kwargs) -> list[float]:
"""
Reward function that encourages optimal retry behavior by counting total function calls
across all assistant messages in the conversation.
Reward function that encourages optimal retry behavior by only rewarding completions
where every assistant message contains at most 1 JSON object.
"""
rewards: list[float] = []
@ -558,21 +665,61 @@ def reward_retry_behavior(completions: list[dict], **reward_kwargs) -> list[floa
rewards.append(0.0)
continue
# Count total function calls across all messages
total_retries: int = 0
# Check if every message has at most 1 JSON object
has_multiple_json = False
total_json_objects = 0
for msg in assistant_msgs:
total_retries += len(extract_json_objects(msg))
json_objects = extract_json_objects(msg)
json_count = len(json_objects)
total_json_objects += json_count
if json_count > 1:
has_multiple_json = True
logger.warning(
f"Message contains {json_count} JSON objects, which exceeds the limit of 1"
)
break
# Calculate reward using modified sigmoid function
x: float = float(total_retries - 4) # Center peak at 4 retries
base_reward: float = 1.0 / (1.0 + np.exp(-x + abs(x) / 2))
# Only reward if no message has multiple JSON objects
if has_multiple_json:
rewards.append(0.0)
else:
# Base reward is 1.0 if constraint is met
base_reward = 1.0
# Slight penalty for having too many total JSON objects across all messages
if total_json_objects > 4:
penalty = 0.1 * (total_json_objects - 4)
base_reward = max(0.2, base_reward - penalty)
logger.debug(
f"Applied penalty for {total_json_objects} total JSON objects: {penalty}"
)
# Additional penalty for excessive retries
if total_retries > 6:
penalty: float = 0.2 * (total_retries - 6)
base_reward = max(0.1, base_reward - penalty)
rewards.append(base_reward)
rewards.append(base_reward)
# Log retry behavior metrics
log_metric("rewards/retry_behavior", np.mean(rewards), reward_kwargs.get("step", 0))
log_metric(
"rewards/retry_behavior_std", np.std(rewards), reward_kwargs.get("step", 0)
)
log_metric(
"metrics/avg_json_per_msg",
np.mean(
[
len(extract_json_objects(msg["content"]))
for completion in completions
for msg in completion["messages"]
if msg["role"] == "assistant"
]
),
reward_kwargs.get("step", 0),
)
log_metric(
"metrics/multiple_json_violation_rate",
np.mean([0.0 if rewards[i] > 0.0 else 1.0 for i in range(len(rewards))]),
reward_kwargs.get("step", 0),
)
return rewards
@ -599,6 +746,11 @@ def reward_exact_match_chunk_query(prompts, completions, **reward_kwargs):
]
logger.debug(f"Found {len(search_results)} search results for prompt {i}")
# Log ground truth chunk and searched chunks
logger.info(f"📝 Ground Truth Chunk: {correct_content}")
for j, result in enumerate(search_results):
logger.info(f"🔍 Searched Chunk {j+1}: {result}")
# Check if any search hit the correct chunk content
found_correct_chunk = False
for result in search_results:
@ -609,30 +761,145 @@ def reward_exact_match_chunk_query(prompts, completions, **reward_kwargs):
)
break
if not found_correct_chunk:
logger.warning(
f"Failed to find correct chunk for prompt {i}:\n"
f"Search results: {[r[:100] + '...' for r in search_results]}"
)
reward = 1.0 if found_correct_chunk else 0.0
rewards.append(reward)
logger.debug(f"Reward for prompt {i}: {reward}")
logger.info(f"Average reward: {sum(rewards)/len(rewards):.3f}")
# Log detailed metrics for debugging
log_metric(
f"debug/chunk_match_{i}",
1 if found_correct_chunk else 0,
reward_kwargs.get("step", 0),
)
log_metric(
f"debug/search_results_count_{i}",
len(search_results),
reward_kwargs.get("step", 0),
)
if search_results:
log_metric(
f"debug/result_length_{i}",
np.mean([len(r.split()) for r in search_results]),
reward_kwargs.get("step", 0),
)
# Log chunk query metrics
log_metric("rewards/chunk_query", np.mean(rewards), reward_kwargs.get("step", 0))
log_metric("rewards/chunk_query_std", np.std(rewards), reward_kwargs.get("step", 0))
log_metric(
"metrics/avg_search_results",
np.mean(
[
len(
[
msg["content"]
for msg in chat_state["messages"]
if msg["role"] == "ipython"
]
)
for chat_state in completions
]
),
reward_kwargs.get("step", 0),
)
log_metric(
"metrics/chunk_match_rate", np.mean(rewards), reward_kwargs.get("step", 0)
)
# Log detailed debugging info
logger.info("Chunk Query Rewards Summary:")
logger.info(f"Total prompts: {len(prompts)}")
logger.info(f"Correct matches: {sum(rewards)}")
logger.info(f"Average reward: {np.mean(rewards):.3f}")
logger.info(f"Reward std: {np.std(rewards):.3f}")
return rewards
def run_eval(generate_fn, verify_fn, tokenizer):
logger.info("Starting evaluation")
def run_eval(generate_fn, verify_fn, tokenizer, output_file=None, debug_file=None):
"""
Run evaluation on the test dataset and return results.
Args:
generate_fn: Function to generate completions
verify_fn: Function to verify results
tokenizer: Tokenizer for processing text
output_file: Path to save evaluation results summary
debug_file: Path to save detailed debug information
Returns:
full_chat_states: The chat states from evaluation
"""
train_dataset, test_dataset = get_qa_dataset()
questions = test_dataset["prompt"]
logger.info(f"Loaded {len(questions)} test questions")
agentic_outputs = run_agent(generate_fn, tokenizer, questions)
full_chat_states = agentic_outputs.full_chat_states
final_responses = agentic_outputs.final_response_str
logger.info("Calculating rewards")
rewards = verify_fn(questions, full_chat_states, answer=test_dataset["answer"])
avg_reward = sum(rewards) / len(rewards)
logger.info("EVALUATION RESULTS:")
logger.info(f"Percentage of correct answers: {avg_reward:.3f}")
# Calculate results
percent_correct = sum(rewards) / len(rewards) * 100
# Log results to console
logger.info("RESULTS:")
logger.info(f"percentage of correct answers: {percent_correct:.2f}%")
logger.info("=" * 30)
# Save results to file if specified
if output_file:
try:
with open(output_file, "w") as f:
f.write("EVALUATION RESULTS\n")
f.write("=================\n\n")
f.write(f"Total questions: {len(questions)}\n")
f.write(f"Correct answers: {sum(rewards)}\n")
f.write(f"Percentage correct: {percent_correct:.2f}%\n\n")
f.write("Individual results:\n")
for i, (q, r, resp) in enumerate(
zip(questions, rewards, final_responses)
):
f.write(f"\nQ{i+1}: {q[:100]}...\n")
f.write(f"Correct: {'' if r else ''}\n")
f.write(f"Response: {resp[:150]}...\n")
f.write("-" * 40 + "\n")
logger.info(f"Saved evaluation results to {output_file}")
except Exception as e:
logger.error(f"Error saving results file: {e}")
# Save debug information if specified
if debug_file:
try:
import json
debug_data = []
for i, (q, r, resp, chat) in enumerate(
zip(questions, rewards, final_responses, full_chat_states)
):
debug_data.append(
{
"question_id": i,
"question": q,
"is_correct": bool(r),
"final_response": resp,
"chat_state": {
k: str(v) if isinstance(v, (list, dict)) else v
for k, v in chat.items()
if k != "tokenizer"
},
}
)
with open(debug_file, "w") as f:
json.dump(debug_data, f, indent=2)
logger.info(f"Saved debug information to {debug_file}")
except Exception as e:
logger.error(f"Error saving debug file: {e}")
return full_chat_states
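Putting the new run_eval signature together with the generation and verifier functions from the training scripts, an evaluation call would look roughly like this sketch (eval_generate_fn, reward_correctness, and tokenizer are assumed to be defined as in the training scripts; the output paths are hypothetical):

from src import rl_helpers

chat_states = rl_helpers.run_eval(
    generate_fn=eval_generate_fn,     # wrapper around model.fast_generate
    verify_fn=reward_correctness,     # built by build_reward_correctness_fn(...)
    tokenizer=tokenizer,
    output_file="eval_results.txt",   # hypothetical path for the plain-text summary
    debug_file="eval_debug.json",     # hypothetical path for the per-question JSON dump
)
print(f"Evaluated {len(chat_states)} questions")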

@ -3,40 +3,33 @@ Search module for RL training loop.
This module provides functions to search through vectorized documents and retrieve question-answer pairs.
"""
import pickle
import json
import random
import asyncio
from typing import List, Tuple, Optional, Union, Dict, Any
from enum import Enum
from pydantic import BaseModel
from langchain.vectorstores import FAISS
from datasets import Dataset
from embeddings import CustomHuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from src.config import DATA_DIR, logger
from src.embeddings import CustomHuggingFaceEmbeddings
# Load pre-saved vectorstore
def load_vectorstore():
"""Load the pre-saved FAISS index"""
try:
import os
embeddings = CustomHuggingFaceEmbeddings()
# Load the FAISS index with absolute path
index_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "faiss_index"
)
print(f"Loading FAISS index from: {index_path}")
# Load the FAISS index from the data directory
logger.info(f"Loading FAISS index from: {DATA_DIR}")
vectorstore = FAISS.load_local(
index_path, embeddings, allow_dangerous_deserialization=True
str(DATA_DIR), embeddings, allow_dangerous_deserialization=True
)
print("Successfully loaded FAISS index")
logger.info("Successfully loaded FAISS index")
return vectorstore
except Exception as e:
print(f"Error loading vectorstore: {e}")
logger.error(f"Error loading vectorstore: {e}")
import traceback
traceback.print_exc()
logger.debug(traceback.format_exc())
return None
@ -44,13 +37,13 @@ def load_vectorstore():
try:
vectorstore = load_vectorstore()
if vectorstore is None:
print("Warning: FAISS vectorstore could not be loaded.")
logger.warning("FAISS vectorstore could not be loaded.")
except Exception as e:
print(f"Error loading vectorstore: {e}")
logger.error(f"Error loading vectorstore: {e}")
vectorstore = None
def search(query: str, return_type=str, results: int = 5) -> Union[str, List[str]]:
def search(query: str, return_type=str, results: int = 5):
"""
Search for relevant chunks using similarity search.
@ -82,51 +75,36 @@ def search(query: str, return_type=str, results: int = 5) -> Union[str, List[str
# Load questions from saved data
def load_qa_data():
"""Load the pre-generated questions and document chunks"""
"""Load the pre-generated questions"""
try:
import os
# Get absolute paths to data files
base_dir = os.path.dirname(os.path.abspath(__file__))
chunks_path = os.path.join(base_dir, "saved_data", "chunks.pkl")
questions_path = os.path.join(base_dir, "saved_data", "questions.json")
print(f"Loading chunks from: {chunks_path}")
print(f"Loading questions from: {questions_path}")
# Load the chunks
with open(chunks_path, "rb") as f:
chunks = pickle.load(f)
questions_path = DATA_DIR / "questions.json"
logger.info(f"Loading questions from: {questions_path}")
# Load the questions
with open(questions_path, "r") as f:
questions = json.load(f)
print(
f"Successfully loaded {len(chunks)} chunks and {len(questions)} questions"
)
return chunks, questions
logger.info(f"Successfully loaded {len(questions)} questions")
return questions
except Exception as e:
print(f"Error loading QA data: {e}")
logger.error(f"Error loading QA data: {e}")
import traceback
traceback.print_exc()
return None, None
logger.debug(traceback.format_exc())
return None
# Load chunks and questions when module is imported
# Load questions when module is imported
try:
chunks, questions = load_qa_data()
if chunks is None or questions is None:
print("Warning: Could not load QA data.")
questions = load_qa_data()
if questions is None:
logger.warning("Could not load QA data.")
except Exception as e:
print(f"Error initializing QA data: {e}")
chunks, questions = None, None
logger.error(f"Error initializing QA data: {e}")
questions = None
def get_question_answer(
idx: Optional[int] = None, return_both: bool = True
) -> Union[dict, str]:
def get_question_answer(idx=None, return_both: bool = True) -> dict:
"""
Get a question-answer pair either by index or randomly.
@ -148,7 +126,7 @@ def get_question_answer(
qa_pair = questions[idx]
else:
raise ValueError(
f"Index out of range. Must be between 0 and {len(questions)-1}"
f"Index out of range. Must be between 0 and {len(questions) - 1}"
)
question = qa_pair["question"]
@ -168,7 +146,7 @@ def get_question_count() -> int:
return len(questions)
def get_qa_dataset():
def get_qa_dataset() -> tuple:
"""
Return a HuggingFace Dataset containing question and answer pairs.

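Callers elsewhere in the commit use this module roughly like the following sketch (the exact shape returned by get_question_answer is assumed from its docstring):

from src.search_module import get_question_answer, get_question_count, search

print(f"{get_question_count()} QA pairs loaded")

qa = get_question_answer()  # random pair; assumed to return {"question": ..., "answer": ...}
results = search(qa["question"], return_type=str, results=2)
print(results)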
@ -0,0 +1,6 @@
export CUDA_VISIBLE_DEVICES=0
python train_grpo.py

@ -1,192 +0,0 @@
# %%
import torch
# %%
from unsloth import FastLanguageModel, is_bfloat16_supported
max_seq_length = 4096 * 2 # Can increase for longer reasoning traces
lora_rank = 64 # Larger rank = smarter, but slower
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="meta-llama/meta-Llama-3.1-8B-Instruct",
max_seq_length=max_seq_length,
load_in_4bit=True, # False for LoRA 16bit
fast_inference=True, # Enable vLLM fast inference
max_lora_rank=lora_rank,
gpu_memory_utilization=0.6, # Reduce if out of memory
)
model = FastLanguageModel.get_peft_model(
model,
r=lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
], # Remove QKVO if out of memory
lora_alpha=lora_rank,
use_gradient_checkpointing="unsloth", # Enable long context finetuning
random_state=3407,
)
# %%
import re
from datasets import Dataset, load_dataset
from rl_helpers import get_qa_dataset
from search_module import get_question_answer, get_question_count, search
train_dataset, test_dataset = get_qa_dataset()
# %% [markdown]
# <a name="Train"></a>
# ### Train the model
#
# Now set up GRPO Trainer and all configurations!
# %%
import os
os.environ["WANDB_PROJECT"] = "bootstrap-search-rl"
# %%
# from UnslothGRPOTrainerTemp import UnslothGRPOConfig, _UnslothGRPOTrainer
import UnslothGRPOTrainerTemp
training_args = UnslothGRPOTrainerTemp.UnslothGRPOConfig(
use_vllm=True, # use vLLM for fast inference!
use_agentic_generate=True, # use agentic generation
learning_rate=5e-6,
adam_beta1=0.9,
adam_beta2=0.99,
weight_decay=0.1,
warmup_ratio=0.1,
lr_scheduler_type="cosine",
optim="paged_adamw_8bit",
logging_steps=1,
bf16=is_bfloat16_supported(),
fp16=not is_bfloat16_supported(),
per_device_train_batch_size=8,
gradient_accumulation_steps=1, # Increase to 4 for smoother training
num_generations=8, # Decrease if out of memory
max_prompt_length=1024,
max_completion_length=1024,
# num_train_epochs = 1, # Set to 1 for a full training run
max_steps=101,
save_steps=50,
max_grad_norm=0.1,
report_to="none", # Can use Weights & Biases
output_dir="full_local_training",
)
# %%
import rl_helpers
# importlib.reload(rl_helpers)
def agentic_generate(
prompts: list[str],
generate_fn,
max_generations: int = 6,
):
return run_agent(generate_fn, tokenizer, prompts, max_generations)
model.agentic_generate = agentic_generate
from vllm import SamplingParams
verifier_sampling_params = SamplingParams(
temperature=0.1,
top_p=0.95,
max_tokens=4096,
)
def verifier_generate_fn(inputs):
return model.fast_generate(
inputs,
sampling_params=verifier_sampling_params,
)
run_agent = rl_helpers.run_agent
reward_correctness = rl_helpers.build_reward_correctness_fn(
verifier_generate_fn,
tokenizer,
)
reward_formatting = rl_helpers.reward_formatting
import UnslothGRPOTrainerTemp
trainer = UnslothGRPOTrainerTemp.UnslothGRPOTrainer(
model=model,
processing_class=tokenizer,
reward_funcs=[
reward_correctness,
reward_formatting,
],
args=training_args,
train_dataset=train_dataset,
)
# %%
trainer.train()
# %% [markdown]
# <a name="Inference"></a>
# ### Inference
# Now let's benchmark the model we trained!
# %%
from vllm import SamplingParams
import rl_helpers
sampling_params = SamplingParams(
temperature=0.5,
top_p=0.95,
max_tokens=4096,
)
def eval_generate_fn(inputs):
return model.fast_generate(
inputs,
sampling_params=sampling_params,
lora_request=model.load_lora(
"full_local_training/checkpoint-101"
), # load the trained LoRA
)
rl_helpers.run_eval(
generate_fn=eval_generate_fn,
verify_fn=reward_correctness,
tokenizer=tokenizer,
)
# %%
# eval w/o lora
def eval_generate_fn(inputs):
return model.fast_generate(
inputs,
sampling_params=sampling_params,
)
rl_helpers.run_eval(
generate_fn=eval_generate_fn,
verify_fn=reward_correctness,
tokenizer=tokenizer,
)

@ -1,196 +0,0 @@
# %%
import torch
# %%
from unsloth import FastLanguageModel, is_bfloat16_supported
max_seq_length = 4096 * 2 # Can increase for longer reasoning traces
lora_rank = 64 # Larger rank = smarter, but slower
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="meta-llama/Llama-3.2-1B-Instruct",
max_seq_length=max_seq_length,
load_in_4bit=True, # False for LoRA 16bit
fast_inference=True, # Enable vLLM fast inference
max_lora_rank=lora_rank,
gpu_memory_utilization=0.6, # Reduce if out of memory
)
print(tokenizer.chat_template)  # See what chat format the model expects
model = FastLanguageModel.get_peft_model(
model,
r=lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
], # Remove QKVO if out of memory
lora_alpha=lora_rank,
use_gradient_checkpointing="unsloth", # Enable long context finetuning
random_state=3407,
)
# %%
import re
from datasets import Dataset, load_dataset
from rl_helpers import get_qa_dataset
from search_module import get_question_answer, get_question_count, search
train_dataset, test_dataset = get_qa_dataset()
# %% [markdown]
# <a name="Train"></a>
# ### Train the model
#
# Now set up GRPO Trainer and all configurations!
# %%
import os
os.environ["WANDB_PROJECT"] = "bootstrap-search-rl"
# %%
# from UnslothGRPOTrainerTemp import UnslothGRPOConfig, _UnslothGRPOTrainer
import UnslothGRPOTrainerTemp
training_args = UnslothGRPOTrainerTemp.UnslothGRPOConfig(
use_vllm=True, # use vLLM for fast inference!
use_agentic_generate=True, # use agentic generation
learning_rate=5e-6,
adam_beta1=0.9,
adam_beta2=0.99,
weight_decay=0.1,
warmup_ratio=0.1,
lr_scheduler_type="cosine",
optim="paged_adamw_8bit",
logging_steps=1,
bf16=is_bfloat16_supported(),
fp16=not is_bfloat16_supported(),
per_device_train_batch_size=8,
gradient_accumulation_steps=1, # Increase to 4 for smoother training
num_generations=8, # Decrease if out of memory
max_prompt_length=1024,
max_completion_length=1024,
# num_train_epochs = 1, # Set to 1 for a full training run
max_steps=101,
save_steps=50,
max_grad_norm=0.1,
report_to="none", # Can use Weights & Biases
output_dir="full_local_training",
)
# %%
import rl_helpers
# importlib.reload(rl_helpers)
def agentic_generate(
prompts: list[str],
generate_fn,
max_generations: int = 6,
):
return run_agent(generate_fn, tokenizer, prompts, max_generations)
model.agentic_generate = agentic_generate
from vllm import SamplingParams
verifier_sampling_params = SamplingParams(
temperature=0.1,
top_p=0.95,
max_tokens=4096,
)
def verifier_generate_fn(inputs):
return model.fast_generate(
inputs,
sampling_params=verifier_sampling_params,
)
run_agent = rl_helpers.run_agent
reward_correctness = rl_helpers.build_reward_correctness_fn(
verifier_generate_fn,
tokenizer,
)
reward_formatting = rl_helpers.reward_formatting
import UnslothGRPOTrainerTemp
trainer = UnslothGRPOTrainerTemp.UnslothGRPOTrainer(
model=model,
processing_class=tokenizer,
reward_funcs=[
reward_correctness,
reward_formatting,
],
args=training_args,
train_dataset=train_dataset,
)
# %%
trainer.train()
# %% [markdown]
# <a name="Inference"></a>
# ### Inference
# Now let's benchmark the model we trained!
# %%
from vllm import SamplingParams
import rl_helpers
sampling_params = SamplingParams(
temperature=0.5,
top_p=0.95,
max_tokens=4096,
)
def eval_generate_fn(inputs):
return model.fast_generate(
inputs,
sampling_params=sampling_params,
lora_request=model.load_lora(
"full_local_training/checkpoint-101"
), # load the trained LoRA
)
rl_helpers.run_eval(
generate_fn=eval_generate_fn,
verify_fn=reward_correctness,
tokenizer=tokenizer,
)
# %%
# eval w/o lora
def eval_generate_fn(inputs):
return model.fast_generate(
inputs,
sampling_params=sampling_params,
)
rl_helpers.run_eval(
generate_fn=eval_generate_fn,
verify_fn=reward_correctness,
tokenizer=tokenizer,
)

@ -0,0 +1,124 @@
import os
from unsloth import FastLanguageModel, is_bfloat16_supported
import src.UnslothGRPOTrainerTemp as UnslothGRPOTrainerTemp
# Import reward functions
from src.rl_helpers import (
build_reward_correctness_fn,
get_qa_dataset,
reward_exact_match_chunk_query,
reward_formatting,
reward_retry_behavior,
run_agent,
)
from src.config import (
MODEL_CONFIG,
MODEL_NAME,
OUTPUT_DIR,
TRAINING_CONFIG,
get_sampling_params,
init_training_dirs,
logger,
update_log_path,
)
# Initialize training directories
paths = init_training_dirs()
# Update logger to use the training directory
update_log_path(paths["log_dir"])
logger.info(f"Training output directory: {paths['output_dir']}")
logger.info(f"Logs are being saved to both ./logs and {paths['log_dir']}")
# Initialize model and tokenizer
logger.info(f"Initializing model {MODEL_NAME}")
model, tokenizer = FastLanguageModel.from_pretrained(
model_name=MODEL_NAME,
max_seq_length=MODEL_CONFIG["max_seq_length"],
load_in_4bit=True, # False for LoRA 16bit
fast_inference=True, # Enable vLLM fast inference
max_lora_rank=MODEL_CONFIG["lora_rank"],
gpu_memory_utilization=MODEL_CONFIG["gpu_memory_utilization"],
)
# Setup LoRA
logger.info("Setting up LoRA adapter")
model = FastLanguageModel.get_peft_model(
model,
r=MODEL_CONFIG["lora_rank"],
target_modules=MODEL_CONFIG["target_modules"],
lora_alpha=MODEL_CONFIG["lora_rank"],
use_gradient_checkpointing=True, # Enable long context finetuning
random_state=3407,
)
# Load datasets
logger.info("Loading datasets")
train_dataset, test_dataset = get_qa_dataset()
logger.info(
f"Loaded {len(train_dataset)} training examples and {len(test_dataset)} test examples"
)
# Setup training arguments
logger.info("Setting up training arguments")
training_args = UnslothGRPOTrainerTemp.UnslothGRPOConfig(
use_vllm=True, # use vLLM for fast inference!
use_agentic_generate=True, # use agentic generation
**TRAINING_CONFIG,
bf16=is_bfloat16_supported(),
fp16=not is_bfloat16_supported(),
output_dir=OUTPUT_DIR,
# report_to="tensorboard", # ❓ Does't have billions of tensorboard files if set report to right here
)
# Setup model generation functions
def agentic_generate(
prompts: list,
generate_fn,
max_generations: int = 10,
):
return run_agent(generate_fn, tokenizer, prompts, max_generations)
model.agentic_generate = agentic_generate
# Setup verifier
logger.info("Setting up verifier")
verifier_sampling_params = get_sampling_params(temperature=0.1)
def verifier_generate_fn(inputs):
return model.fast_generate(
inputs,
sampling_params=verifier_sampling_params,
)
# Setup trainer
logger.info("Initializing trainer")
trainer = UnslothGRPOTrainerTemp.UnslothGRPOTrainer(
model=model,
processing_class=tokenizer,
reward_funcs=[
build_reward_correctness_fn(
verifier_generate_fn,
tokenizer,
log_file=os.path.join(paths["log_dir"], "qa_log.txt"),
),
reward_formatting,
reward_retry_behavior,
reward_exact_match_chunk_query,
],
args=training_args,
train_dataset=train_dataset,
)
# Train the model
if __name__ == "__main__":
logger.info("Starting training")
trainer.train()
logger.info("Training completed")
logger.info(f"Model saved to {OUTPUT_DIR}")