refactor: restructure code base, better centralize logging logic

main
thinhlpg 1 month ago
parent 04d56325bb
commit 3c2deaced9

@@ -1,2 +1,2 @@
-HF_TOKEN=
-OPENROUTER_API_KEY=
+HF_TOKEN=<your-huggingface-token>
+OPENROUTER_API_KEY=<your-openrouter-api-key>

.gitignore

@@ -8,6 +8,8 @@ unsloth_compiled_cache/
 full_local_training/
 grpo_trainer_lora_model/
 qa_log.txt
+trainer_output_*
+data/

 # Byte-compiled / optimized / DLL files
 __pycache__/

@@ -1,232 +0,0 @@
"""
This script performs two main tasks:
1. It loads a markdown document, splits it into chunks, generates embeddings,
   and builds a FAISS index (which is saved locally).
2. It generates QA pairs from the document using llama.
   For each chunk (using a sliding window for context), it generates multiple question-answer pairs
   with different difficulties. The generation is performed in batch with one retry for failed prompts.
   Successfully generated QA pairs are saved to "saved_data/questions.json".

Requirements:
    pip install langchain faiss-cpu unsloth vllm
"""

import json
import os
import pickle
import re
from typing import Dict, List, Optional, Tuple

from langchain.text_splitter import RecursiveCharacterTextSplitter

# ========= Part 1: Document Processing and Embedding Generation =========

# Load and split the markdown document using LangChain
from langchain_community.document_loaders import UnstructuredMarkdownLoader
from langchain_community.vectorstores import FAISS

from embeddings import CustomHuggingFaceEmbeddings

# Load your markdown file (adjust the path as needed)
loader = UnstructuredMarkdownLoader("./data/mission_report.md")
docs = loader.load()

# Split the document into smaller chunks (each 1000 characters, no overlap)
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
chunks = text_splitter.split_documents(docs)

# Save chunks for later use  # TODO: change to csv? easier inspect.
os.makedirs("saved_data", exist_ok=True)
with open("saved_data/chunks.pkl", "wb") as f:
    pickle.dump(chunks, f)
print(f"Saved {len(chunks)} chunks to saved_data/chunks.pkl")

embeddings = CustomHuggingFaceEmbeddings()

# Create a FAISS vector store from the document chunks and save it locally
vectorstore = FAISS.from_documents(chunks, embeddings)
vectorstore.save_local("faiss_index")
print("Saved FAISS index to 'faiss_index'")
# TODO: add the paraphrased chunks to the vector store

# ========= Part 2: QA Generation using Llama Backend =========

# Setup Llama backend via unsloth and vLLM
from unsloth import FastLanguageModel
from vllm import SamplingParams

import rl_helpers  # Ensure you have this or remove if not used

# Load the Llama model (adjust parameters as needed)
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="meta-llama/meta-Llama-3.1-8B-Instruct",
    max_seq_length=4096,
    load_in_4bit=True,  # Use 4-bit quantization if desired
    fast_inference=True,  # Enable fast inference
    gpu_memory_utilization=0.6,  # Adjust based on your GPU memory
)

# Define sampling parameters for generation
sampling_params = SamplingParams(
    temperature=0.3,
    top_p=0.95,
    max_tokens=4096,
)


def batch_generate(prompts: List[str]) -> List[str]:
    """
    Given a list of prompt strings, returns a list of generated outputs.
    """

    def format_input(text: str) -> str:
        return tokenizer.apply_chat_template(
            [{"role": "user", "content": text}],
            tokenize=False,
            add_generation_prompt=True,
        )

    formatted = [format_input(p) for p in prompts]
    outputs = model.fast_generate(formatted, sampling_params=sampling_params)
    return [output.outputs[0].text for output in outputs]


def parse_qa_block(block: str) -> Optional[Tuple[str, str, str]]:
    """
    Parses a QA block that should contain exactly three non-empty lines:
      - A line starting with "Question:"
      - A line starting with "Answer:"
      - A line starting with "Difficulty:"

    If the markers are not present but the block contains exactly three lines,
    those are used in order.

    Returns a tuple (question, answer, difficulty) or None if parsing fails.
    """
    lines = [line.strip() for line in block.splitlines() if line.strip()]
    if not lines:
        return None

    question, answer, difficulty = None, None, None
    for line in lines:
        lower = line.lower()
        if question is None and lower.startswith("question:"):
            question = line[len("question:") :].strip()
        elif answer is None and lower.startswith("answer:"):
            answer = line[len("answer:") :].strip()
        elif difficulty is None and lower.startswith("difficulty:"):
            difficulty = line[len("difficulty:") :].strip()

    if question and answer and difficulty:
        return question, answer, difficulty
    if len(lines) == 3:
        return lines[0], lines[1], lines[2]
    return None


def parse_multiple_qa_output(output: str) -> List[Tuple[str, str, str]]:
    """
    Splits the output into blocks (separated by one or more blank lines) and
    attempts to parse each as a QA pair.

    Returns a list of successfully parsed QA tuples.
    """
    blocks = re.split(r"\n\s*\n", output.strip())
    qa_pairs = []
    for block in blocks:
        parsed = parse_qa_block(block)
        if parsed:
            qa_pairs.append(parsed)
    return qa_pairs


def generate_question_batch_for_chunks(
    chunks: List, num_questions: int = 2, difficulty: str = None
) -> List[Dict]:
    """
    Generates QA pairs for multiple chunks in batch.

    For each chunk (except the first and last), a sliding window is used for context:
      - before: previous chunk's content
      - current: current chunk's content
      - after: next chunk's content

    Each prompt instructs the model to output exactly three lines per QA pair with markers.
    Failed prompts are retried once in batch; if still unsuccessful, they are skipped.

    Returns a list of dicts with keys: "chunk_id", "question", "answer", "difficulty".
    """
    prompts = []
    chunk_ids = []

    # Prepare prompts using a sliding window
    for i in range(1, len(chunks) - 1):
        before = chunks[i - 1].page_content
        current = chunks[i].page_content
        after = chunks[i + 1].page_content
        prompt = (
            f"From the text within ==BEGIN== and ==END==, generate {num_questions} questions with answers.\n"
            "For each QA pair, output exactly three lines with no extra commentary:\n"
            "Line 1: Question: <your question>\n"
            "Line 2: Answer: <the answer>\n"
            "Line 3: Difficulty: <easy, medium, or hard>\n"
            "Do not include any additional text.\n\n"
            "==BEGIN==\n"
            f"{before}\n{current}\n{after}\n"
            "==END==\n"
        )
        prompts.append(prompt)
        chunk_ids.append(i)

    # First batch generation
    outputs = batch_generate(prompts)
    results = [None] * len(outputs)
    failed_indices = []

    # Parse each output
    for idx, output in enumerate(outputs):
        qa_pairs = parse_multiple_qa_output(output)
        if qa_pairs is None or len(qa_pairs) < num_questions:
            failed_indices.append(idx)
        else:
            results[idx] = qa_pairs[:num_questions]

    # Retry failed prompts in batch
    if failed_indices:
        print(f"Retrying {len(failed_indices)} failed prompt(s)...")
        retry_prompts = [prompts[i] for i in failed_indices]
        retry_outputs = batch_generate(retry_prompts)
        for j, idx in enumerate(failed_indices):
            qa_pairs = parse_multiple_qa_output(retry_outputs[j])
            if qa_pairs is not None and len(qa_pairs) >= num_questions:
                results[idx] = qa_pairs[:num_questions]
            else:
                results[idx] = None  # Mark as failed

    # Build final output, skipping prompts that failed even after retry
    final_questions = []
    for i, qa_list in enumerate(results):
        if qa_list is not None:
            for qa in qa_list:
                final_questions.append(
                    {
                        "chunk_id": chunk_ids[i],
                        "question": qa[0],
                        "answer": qa[1],
                        "difficulty": qa[2],
                    }
                )
    return final_questions


# Generate QA pairs in batch (using a sliding window over the chunks)
all_questions = generate_question_batch_for_chunks(
    chunks, num_questions=2, difficulty="medium"
)
print(f"Generated {len(all_questions)} QA pairs.")

# Save the QA pairs to a JSON file
questions_path = os.path.join("saved_data", "questions.json")
with open(questions_path, "w") as f:
    json.dump(all_questions, f, indent=2)
print(f"Saved questions to {questions_path}")

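The parsing contract in the deleted script is easiest to see on a toy input. A quick sanity check, using a hypothetical model output that is not part of the commit:

sample_output = """Question: What was the mission's primary objective?
Answer: A crewed lunar landing and safe return.
Difficulty: easy

Question: How many lines must each QA block contain?
Answer: Exactly three.
Difficulty: medium"""

# parse_multiple_qa_output splits on blank lines, then parse_qa_block
# reads the Question:/Answer:/Difficulty: markers from each block.
for question, answer, difficulty in parse_multiple_qa_output(sample_output):
    print(f"[{difficulty}] {question} -> {answer}")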
@@ -1,3 +0,0 @@
-unsloth_compiled_cache
-0_*
-faiss_index*

@@ -5,13 +5,13 @@ langchain-community
 Markdown
 tokenizers
 unsloth==2025.3.6
-transformers==4.49.0
 unsloth_zoo==2025.3.4
 unstructured
-vllm
+vllm==0.7.2
+wandb
+transformers==4.49.0
 ipykernel
 python-dotenv
 loguru
 gradio
+tensorboard

@@ -5,35 +5,31 @@ This script performs two main tasks:
 2. It generates QA pairs from the document using llama.
    For each chunk (using a sliding window for context), it generates multiple question-answer pairs
    with different difficulties. The generation is performed in batch with one retry for failed prompts.
-   Successfully generated QA pairs are saved to "saved_data/questions.json".
+   Successfully generated QA pairs are saved to "data/questions.json".

 Requirements:
     pip install langchain faiss-cpu unsloth vllm
 """

 import json
-import os
 import re
-from typing import Dict, List, Optional, Tuple
+import sys
+from pathlib import Path
+
+# Add project root to Python path
+project_root = Path(__file__).resolve().parent.parent
+sys.path.append(str(project_root))

 import pandas as pd
 from langchain.text_splitter import RecursiveCharacterTextSplitter
-from loguru import logger
-
-# Configure logger
-logger.add(
-    "logs/generate_data_{time}.log",
-    rotation="500 MB",
-    retention="10 days",
-    level="INFO",
-)

 # ========= Part 1: Document Processing and Embedding Generation =========
 # Load and split the markdown document using LangChain
 from langchain_community.document_loaders import UnstructuredMarkdownLoader
 from langchain_community.vectorstores import FAISS

-from embeddings import CustomHuggingFaceEmbeddings
+from src.config import DATA_DIR, logger
+from src.embeddings import CustomHuggingFaceEmbeddings

 # Load your markdown file (adjust the path as needed)
 loader = UnstructuredMarkdownLoader("./data/mission_report.md")
@@ -43,9 +39,6 @@ docs = loader.load()
 text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
 chunks = text_splitter.split_documents(docs)

-# Create output directory
-os.makedirs("saved_data", exist_ok=True)
-
 # Save chunks to CSV for easy inspection
 chunks_df = pd.DataFrame(
     {
@@ -54,15 +47,15 @@ chunks_df = pd.DataFrame(
         "metadata": [chunk.metadata for chunk in chunks],
     }
 )
-chunks_df.to_csv("saved_data/chunks.csv", index=False)
-print(f"Saved {len(chunks)} chunks to saved_data/chunks.csv")
+chunks_df.to_csv(DATA_DIR / "chunks.csv", index=False)
+logger.info(f"Saved {len(chunks)} chunks to {DATA_DIR}/chunks.csv")

 embeddings = CustomHuggingFaceEmbeddings()

 # Create a FAISS vector store from the document chunks and save it locally
 vectorstore = FAISS.from_documents(chunks, embeddings)
-vectorstore.save_local("faiss_index")
-print("Saved FAISS index to 'faiss_index'")
+vectorstore.save_local(str(DATA_DIR))
+logger.info(f"Saved FAISS index to {DATA_DIR}")
 # TODO: add the paraphrased chunks to the vector store
@@ -72,8 +65,6 @@ print("Saved FAISS index to 'faiss_index'")
 from unsloth import FastLanguageModel
 from vllm import SamplingParams

-import rl_helpers  # Ensure you have this or remove if not used
-
 # Load the Llama model (adjust parameters as needed)
 model, tokenizer = FastLanguageModel.from_pretrained(
     model_name="meta-llama/meta-Llama-3.1-8B-Instruct",
@@ -91,7 +82,7 @@ sampling_params = SamplingParams(
 )

-def batch_generate(prompts: List[str]) -> List[str]:
+def batch_generate(prompts: list) -> list:
     """
     Given a list of prompt strings, returns a list of generated outputs.
     """
@@ -108,7 +99,7 @@ def batch_generate(prompts: List[str]) -> List[str]:
     return [output.outputs[0].text for output in outputs]

-def parse_qa_block(block: str) -> Optional[Tuple[str, str, str]]:
+def parse_qa_block(block: str):
     """
     Parses a QA block that should contain exactly three non-empty lines:
       - A line starting with "Question:"
@@ -141,7 +132,7 @@ def parse_qa_block(block: str) -> Optional[Tuple[str, str, str]]:
     return None

-def parse_multiple_qa_output(output: str) -> List[Tuple[str, str, str]]:
+def parse_multiple_qa_output(output: str) -> list:
     """
     Splits the output into blocks (separated by one or more blank lines) and
     attempts to parse each as a QA pair.
@@ -158,8 +149,8 @@ def parse_multiple_qa_output(output: str) -> List[Tuple[str, str, str]]:
 def generate_question_batch_for_chunks(
-    chunks: List, num_questions: int = 2, difficulty: Optional[str] = None
-) -> List[Dict]:
+    chunks: list, num_questions: int = 2, difficulty=None
+) -> list:
     """
     Generates QA pairs for multiple chunks in batch.
@@ -198,7 +189,9 @@ def generate_question_batch_for_chunks(
     # First batch generation
     outputs = batch_generate(prompts)
-    results: List[Optional[List[Tuple[str, str, str]]]] = [None] * len(outputs)
+    results = []
+    for _ in range(len(outputs)):
+        results.append(None)
     failed_indices = []

     # Parse each output
@@ -270,10 +263,10 @@ def generate_question_batch_for_chunks(
 all_questions = generate_question_batch_for_chunks(
     chunks, num_questions=2, difficulty="medium"
 )
-print(f"Generated {len(all_questions)} QA pairs.")
+logger.info(f"Generated {len(all_questions)} QA pairs.")

 # Save the QA pairs to a JSON file
-questions_path = os.path.join("saved_data", "questions.json")
+questions_path = DATA_DIR / "questions.json"
 with open(questions_path, "w") as f:
     json.dump(all_questions, f, indent=2)
-print(f"Saved questions to {questions_path}")
+logger.info(f"Saved questions to {questions_path}")

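The rewrite above is the heart of the commit: per-script `logger.add(...)` setup is gone, and every module imports one pre-configured logger from `src/config.py` (added later in this diff). The pattern in minimal form, assuming the same loguru dependency:

# src/config.py: configure the loguru singleton exactly once, at import time
from loguru import logger

logger.remove()  # drop loguru's default stderr handler
logger.add("logs/app.log", rotation="500 MB", retention="7 days", level="INFO")

# generate_data.py (or any other script): reuse the shared handler set
from src.config import logger

logger.info("No per-script logging setup required")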
@@ -8,10 +8,18 @@ import json
 import random
 import sys
 import time
-from typing import Any, Dict
+from pathlib import Path

-# Import our search module (ensure these functions follow the new interfaces)
-from search_module import get_question_answer, get_question_count, search
+# Add project root to Python path
+project_root = Path(__file__).resolve().parent.parent
+sys.path.append(str(project_root))
+
+# Import our search module and config
+from src.config import DATA_DIR, logger
+from src.search_module import get_question_answer, get_question_count, search
+
+# TODO: Import verify function and router from appropriate module
+# TODO: Consider moving verify function to search_module.py for better organization


 class SimpleQAEnvironment:
@@ -21,27 +29,30 @@ class SimpleQAEnvironment:
         self.score = {"correct": 0, "incorrect": 0, "total": 0}
         self.session_data = []
         self.current_question = None
+        self.session_file = DATA_DIR / "qa_sessions"

     def display_welcome(self):
         """Display welcome message and instructions."""
-        print("\n===== Search & Answer Environment =====")
-        print("Answer questions using the search tool to find relevant information.")
-        print("Type 'q' to quit, 'h' for help.\n")
+        logger.info("===== Search & Answer Environment =====")
+        logger.info(
+            "Answer questions using the search tool to find relevant information."
+        )
+        logger.info("Type 'q' to quit, 'h' for help.\n")

     def display_help(self):
         """Display help information."""
-        print("\n===== Commands =====")
-        print("n - Get a new question")
-        print("s <query> - Search for information (e.g., s program launch date)")
-        print("a <answer> - Submit your answer")
-        print("h - Display this help message")
-        print("q - Quit the program\n")
+        logger.info("\n===== Commands =====")
+        logger.info("n - Get a new question")
+        logger.info("s <query> - Search for information (e.g., s program launch date)")
+        logger.info("a <answer> - Submit your answer")
+        logger.info("h - Display this help message")
+        logger.info("q - Quit the program\n")

     def display_question(self, question: str):
         """Display the current question."""
-        print("\n===== QUESTION =====")
-        print(question)
-        print("=====================\n")
+        logger.info("\n===== QUESTION =====")
+        logger.info(question)
+        logger.info("=====================\n")

     def get_new_question(self) -> str:
         """Get a new random question and set it as current."""
@@ -66,30 +77,30 @@ class SimpleQAEnvironment:
     def perform_search(self, query: str):
         """Perform a search with the given query."""
         if not query:
-            print("Please provide a search query.")
+            logger.warning("Please provide a search query.")
             return

         try:
-            print("\n===== SEARCH RESULTS =====")
+            logger.info("\n===== SEARCH RESULTS =====")
             results = search(query)
-            print(results)
-            print("==========================\n")
+            logger.info(results)
+            logger.info("==========================\n")

             # Record search in current question data if available.
             if self.current_question is not None:
                 self.current_question["searches"].append(query)
         except Exception as e:
-            print(f"Error searching: {str(e)}")
+            logger.error(f"Error searching: {str(e)}")

     async def process_answer(self, user_answer: str):
         """Process and verify the user's answer."""
         if self.current_question is None:
-            print("Please get a question first.")
+            logger.warning("Please get a question first.")
             return

         if not user_answer:
-            print("Please provide an answer.")
+            logger.warning("Please provide an answer.")
             return

         # Record answer and calculate time taken.
@@ -100,27 +111,29 @@ class SimpleQAEnvironment:
         )

         try:
-            print("\nVerifying your answer...")
-            correct = await verify(
-                user_answer,
-                self.current_question["question"],
-                self.current_question["correct_answer"],
-                router,
-            )
+            logger.info("\nVerifying your answer...")
+            # TODO: Implement verify function in search_module.py
+            # correct = await verify(
+            #     user_answer,
+            #     self.current_question["question"],
+            #     self.current_question["correct_answer"],
+            #     router,
+            # )
+            correct = False  # Temporary placeholder until verify is implemented

             # Update score and inform the user.
             self.score["total"] += 1
             if correct:
                 self.score["correct"] += 1
-                print("\n✓ Your answer is CORRECT!")
+                logger.success("\n✓ Your answer is CORRECT!")
             else:
                 self.score["incorrect"] += 1
-                print("\n✗ Your answer is INCORRECT.")
-                print(
+                logger.error("\n✗ Your answer is INCORRECT.")
+                logger.info(
                     f"\nThe correct answer is:\n{self.current_question['correct_answer']}"
                 )

-            print(f"\nScore: {self.score['correct']}/{self.score['total']}")
+            logger.info(f"\nScore: {self.score['correct']}/{self.score['total']}")

             # Record the result and add the current question to the session data.
             self.current_question["is_correct"] = correct
@@ -130,15 +143,18 @@ class SimpleQAEnvironment:
             self.current_question = None

         except Exception as e:
-            print(f"Error verifying answer: {str(e)}")
+            logger.error(f"Error verifying answer: {str(e)}")

     def save_session(self):
         """Save the session data to a file."""
         if not self.session_data:
             return

+        # Ensure session directory exists
+        self.session_file.mkdir(parents=True, exist_ok=True)
+
         timestamp = time.strftime("%Y%m%d_%H%M%S")
-        filename = f"qa_session_{timestamp}.json"
+        filename = self.session_file / f"qa_session_{timestamp}.json"

         session_data = {
             "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
@@ -149,9 +165,9 @@ class SimpleQAEnvironment:
         try:
             with open(filename, "w") as f:
                 json.dump(session_data, f, indent=2)
-            print(f"\nSession data saved to {filename}")
+            logger.info(f"\nSession data saved to {filename}")
         except Exception as e:
-            print(f"Error saving session data: {str(e)}")
+            logger.error(f"Error saving session data: {str(e)}")

     async def run(self):
         """Run the main command loop."""
@@ -178,11 +194,11 @@ class SimpleQAEnvironment:
                 answer = command[2:].strip()
                 await self.process_answer(answer)
             else:
-                print("Unknown command. Type 'h' for help.")
+                logger.warning("Unknown command. Type 'h' for help.")

         # Save session data on exit.
         self.save_session()
-        print("\nThank you for using the Q&A environment!")
+        logger.info("\nThank you for using the Q&A environment!")


 async def main():
@@ -195,6 +211,6 @@ if __name__ == "__main__":
     try:
         asyncio.run(main())
     except KeyboardInterrupt:
-        print("\nProgram terminated by user.")
+        logger.info("\nProgram terminated by user.")
     except Exception as e:
-        print(f"\nError: {str(e)}")
+        logger.error(f"\nError: {str(e)}")

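Note that answer checking is now stubbed out (`correct = False`) until `verify` moves into `search_module.py`. `rl_helpers.py` below already defines an async `verify(student_answer, question, answer)`; a sketch of the four-argument variant the commented-out call site expects might look like the following, where the `router` client and its API are assumptions, not code from this repo:

# Hypothetical async verifier matching the commented-out call above.
# Assumes `router` is an OpenAI-compatible async chat client.
async def verify(user_answer: str, question: str, correct_answer: str, router) -> bool:
    prompt = (
        "Reply with 'Yes' if the student answer matches the correct answer, otherwise 'No'.\n\n"
        f"Question: {question}\n"
        f"Correct Answer: {correct_answer}\n"
        f"Student Answer: {user_answer}\n"
    )
    response = await router.chat.completions.create(
        model="openrouter/auto",  # placeholder model id
        messages=[{"role": "user", "content": prompt}],
    )
    return response.choices[0].message.content.strip().lower().startswith("yes")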
@ -1,18 +1,21 @@
from torch import Tensor import os
from contextlib import nullcontext
from dataclasses import dataclass, field
from typing import *
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.nn import functional as F from packaging.version import Version
from trl.trainer.grpo_trainer import ( from trl.trainer.grpo_trainer import (
Any, Any,
AutoModelForCausalLM, AutoModelForCausalLM,
AutoModelForSequenceClassification, AutoModelForSequenceClassification,
AutoTokenizer, AutoTokenizer,
Dataset, Dataset,
GRPOConfig,
GRPOTrainer,
GenerationConfig, GenerationConfig,
GRPOConfig,
IterableDataset, IterableDataset,
LLM,
Optional, Optional,
PeftConfig, PeftConfig,
PreTrainedModel, PreTrainedModel,
@ -41,7 +44,6 @@ from trl.trainer.grpo_trainer import (
nn, nn,
os, os,
pad, pad,
patch,
prepare_deepspeed, prepare_deepspeed,
set_seed, set_seed,
textwrap, textwrap,
@ -50,42 +52,8 @@ from trl.trainer.grpo_trainer import (
unwrap_model_for_generation, unwrap_model_for_generation,
version, version,
wandb, wandb,
warnings,
os,
torch,
transformers,
Any,
LLM,
Union,
apply_chat_template,
broadcast_object_list,
gather,
gather_object,
is_conversational,
maybe_apply_chat_template,
nn,
os,
pad,
torch,
unwrap_model_for_generation,
wandb,
GRPOTrainer,
Trainer,
gather,
os,
torch,
) )
import os
from typing import *
from dataclasses import dataclass, field
from packaging.version import Version
import torch
import numpy as np
from contextlib import nullcontext
from torch.nn import functional as F
torch_compile_options = { torch_compile_options = {
"epilogue_fusion": True, "epilogue_fusion": True,
"max_autotune": False, "max_autotune": False,

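The `torch_compile_options` dict that survives the cleanup is the standard way to pass inductor options to `torch.compile`. A small, self-contained illustration; the decorated function here is made up for this sketch, not taken from the trainer:

import torch

torch_compile_options = {
    "epilogue_fusion": True,
    "max_autotune": False,
}

@torch.compile(dynamic=True, options=torch_compile_options)
def selective_log_softmax(logits: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    # Log-probabilities of the chosen token ids only.
    logps = torch.log_softmax(logits.float(), dim=-1)
    return logps.gather(-1, index.unsqueeze(-1)).squeeze(-1)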
@@ -0,0 +1,301 @@
import os
import sys
from datetime import datetime
from pathlib import Path

import torch
from dotenv import load_dotenv
from loguru import logger
from vllm import SamplingParams

# Load environment variables from .env file if it exists
load_dotenv(override=True)

# Project paths
PROJ_ROOT = Path(__file__).resolve().parent.parent
DATA_DIR = PROJ_ROOT / "data"
LOG_FOLDER = PROJ_ROOT / "logs"

# Model configuration
# MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"

device_id = (
    1 if os.environ.get("CUDA_VISIBLE_DEVICES") == "1" else torch.cuda.current_device()
)

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
OUTPUT_DIR = (
    PROJ_ROOT
    / f"trainer_output_{MODEL_NAME.replace('/', '_')}_gpu{device_id}_{timestamp}"
)

# Model parameters
MODEL_CONFIG = {
    "max_seq_length": 4096 * 2,  # Can increase for longer reasoning traces
    "lora_rank": 64,  # Larger rank = smarter, but slower
    "gpu_memory_utilization": 0.6,  # Reduce if out of memory
    "model_name": MODEL_NAME,
    "target_modules": [
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],  # Remove QKVO if out of memory
}

# Training parameters
TRAINING_CONFIG = {
    "learning_rate": 5e-6,
    "adam_beta1": 0.9,
    "adam_beta2": 0.99,
    "weight_decay": 0.1,
    "warmup_ratio": 0.1,
    "lr_scheduler_type": "cosine",
    "optim": "paged_adamw_8bit",
    "logging_steps": 1,
    "per_device_train_batch_size": 8,
    "gradient_accumulation_steps": 1,  # Increase to 4 for smoother training
    "num_generations": 8,  # Decrease if out of memory
    "max_prompt_length": 1024,
    "max_completion_length": 1024,
    "max_steps": 101,
    "save_steps": 50,
    "max_grad_norm": 0.1,
    "report_to": "tensorboard",
}


# Sampling parameters
def get_sampling_params(temperature: float = 0.1) -> SamplingParams:
    """Get sampling parameters for text generation"""
    return SamplingParams(
        temperature=temperature,
        top_p=0.95,
        max_tokens=4096,
    )


# Initialize logging based on environment
def _init_logging(env: str = "development") -> None:
    """
    Initialize logging configuration with console logging
    and default file logging to ./logs directory.
    Additional file logging will be set up later in update_log_path().

    Args:
        env: The environment for logging ('development' or 'production')
    """
    # Create default log folder
    if not LOG_FOLDER.exists():
        LOG_FOLDER.mkdir(parents=True, exist_ok=True)

    # Remove any existing handlers
    logger.remove()

    # Define the logging format
    console_format = (
        "<green>{time:YYYY-MM-DD HH:mm:ss}</green> "
        "| <level>{level: <8}</level> "
        "| <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> "
        "- <level>{message}</level>"
    )
    file_format = (
        "{time:YYYY-MM-DD at HH:mm:ss} "
        "| {level} "
        "| {name}:{function}:{line} "
        "- {message}"
    )

    # Add console logging
    logger.add(
        sys.stderr,
        format=console_format,
        level="DEBUG" if env == "development" else "INFO",
        colorize=True,
        backtrace=True,
        diagnose=env == "development",
    )

    # Add default file logging to ./logs directory
    logger.add(
        LOG_FOLDER / "app.log",
        format=file_format,
        level="INFO",
        rotation="500 MB",
        retention="7 days",
        compression="zip",
        enqueue=True,  # Enables asynchronous logging
    )

    # Add custom level for requests
    logger.level("REQUEST", no=25, color="<yellow>", icon=" ")

    # Configure exception handling
    def exception_handler(exc_type, exc_value, exc_traceback):
        if issubclass(exc_type, KeyboardInterrupt):
            sys.__excepthook__(exc_type, exc_value, exc_traceback)
            return
        logger.opt(exception=(exc_type, exc_value, exc_traceback)).critical(
            "Unhandled exception"
        )

    sys.excepthook = exception_handler


# Update the log files to point to the training directory
def update_log_path(log_dir=None):
    """
    Add a log file in the training directory while keeping the default ./logs logging.
    Should be called after the training directory is created.

    Args:
        log_dir: Path to store additional log files (default: uses get_paths()["log_dir"])
    """
    # Use provided log_dir or get from training paths
    if log_dir is None:
        paths = get_paths(create_dirs=True)
        log_dir = paths["log_dir"]
    else:
        log_dir = Path(log_dir)
        log_dir.mkdir(exist_ok=True, parents=True)

    file_format = (
        "{time:YYYY-MM-DD at HH:mm:ss} "
        "| {level} "
        "| {name}:{function}:{line} "
        "- {message}"
    )

    # Add additional file handler pointing to training directory
    # No need to remove existing handlers as we want to keep those
    logger.add(
        log_dir / "app.log",
        format=file_format,
        level="INFO",
        rotation="500 MB",
        retention="7 days",
        compression="zip",
        enqueue=True,  # Enables asynchronous logging
    )

    logger.info(f"Additional logs will be stored in: {log_dir}")


# Paths configuration without creating directories
def get_paths(create_dirs: bool = False) -> dict:
    """
    Get common paths for the project

    Args:
        create_dirs: Whether to create the directories

    Returns:
        Dictionary with paths
    """
    output_dir = Path(OUTPUT_DIR)
    log_dir = output_dir / "logs"
    tensorboard_dir = output_dir / "runs"

    # Only create directories if explicitly requested
    if create_dirs:
        output_dir.mkdir(exist_ok=True)
        log_dir.mkdir(exist_ok=True)

        # Only create tensorboard directory if it's enabled in config
        if TRAINING_CONFIG.get("report_to") == "tensorboard":
            tensorboard_dir.mkdir(exist_ok=True)

    return {
        "output_dir": output_dir,
        "log_dir": log_dir,
        "tensorboard_dir": tensorboard_dir,
        "proj_root": PROJ_ROOT,
        "data_dir": DATA_DIR,
    }


# Create training directories
def init_training_dirs():
    """Initialize all directories needed for training"""
    paths = get_paths(create_dirs=True)

    # Also ensure our standard project directories exist
    for directory in [
        DATA_DIR,
        LOG_FOLDER,
    ]:
        directory.mkdir(exist_ok=True, parents=True)

    return paths


# For backward compatibility - will be deprecated
def setup_logger(module_name=None, create_dirs: bool = False):
    """
    Setup a logger for a specific module with consistent configuration.
    Note: This function is kept for backward compatibility.
    Use the global 'logger' instead for new code.

    Args:
        module_name: Optional name of module for module-specific log file
        create_dirs: Whether to create log directories

    Returns:
        Configured logger instance
    """
    logger.warning(
        "setup_logger is deprecated. Import logger directly from config instead."
    )
    return logger


# Tensorboard writer singleton
_tensorboard_writer = None


# Safe tensorboard logging function
def log_metric(key, value, step=0):
    """
    Log a metric safely to tensorboard if writer is available.

    Args:
        key: Metric name
        value: Metric value
        step: Training step
    """
    global _tensorboard_writer

    # Skip tensorboard logging if disabled in config
    if TRAINING_CONFIG.get("report_to") != "tensorboard":
        logger.debug(f"Tensorboard disabled. Metric: {key}={value} (step {step})")
        return

    # Get paths and initialize writer if needed
    paths = get_paths(create_dirs=False)
    if paths["tensorboard_dir"].exists():
        # Only create writer once
        if _tensorboard_writer is None:
            from torch.utils.tensorboard.writer import SummaryWriter

            _tensorboard_writer = SummaryWriter(paths["tensorboard_dir"])
            logger.debug(f"Created tensorboard writer at {paths['tensorboard_dir']}")

        # Add scalar using existing writer
        _tensorboard_writer.add_scalar(key, value, step)
        # No need to close the writer - it will be closed at process exit
    else:
        logger.debug(f"Tensorboard metric: {key}={value} (step {step})")


# Initialize logging on module import
env = os.getenv("APP_ENV", "development")
_init_logging(env=env)

# Log project root on import
logger.info(f"Project root path: {PROJ_ROOT}")
logger.debug(f"Running in {env} environment")

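Downstream modules only need a handful of names from this file. A typical consumer, sketched below; the script name is illustrative:

# train_grpo.py (hypothetical consumer of src/config.py)
from src.config import (
    MODEL_CONFIG,
    TRAINING_CONFIG,
    init_training_dirs,
    log_metric,
    logger,
    update_log_path,
)

paths = init_training_dirs()        # create output/, logs/, runs/ for this run
update_log_path(paths["log_dir"])   # mirror logs into the run directory
logger.info(f"Training {MODEL_CONFIG['model_name']} for {TRAINING_CONFIG['max_steps']} steps")
log_metric("debug/startup", 1.0, step=0)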
@@ -1,7 +1,4 @@
-from typing import List, Union
-
 import torch
-import torch.nn.functional as F
 from langchain.embeddings.base import Embeddings
 from transformers import AutoModel, AutoTokenizer
@@ -18,9 +15,7 @@ class CustomHuggingFaceEmbeddings(Embeddings):
     - "query": uses mean pooling over tokens (weighted by the attention mask) for query embeddings.
     """

-    def __init__(
-        self, model_name: str = DEFAULT_MODEL_NAME, default_mode: str = "sentence"
-    ):
+    def __init__(self, model_name=DEFAULT_MODEL_NAME, default_mode="sentence"):
         self.model_name = model_name
         # Set device to GPU if available, else CPU
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -29,7 +24,8 @@ class CustomHuggingFaceEmbeddings(Embeddings):
         self.default_mode = default_mode  # "sentence" or "query"
         self.model.eval()  # Set model to evaluation mode

-    def get_embedding(self, text: Union[str, List[str]], mode: str = None):
+    def get_embedding(self, text, mode=None):
+        """Get embeddings for text using specified mode"""
         if mode is None:
             mode = self.default_mode
         assert mode in (
@@ -59,14 +55,14 @@ class CustomHuggingFaceEmbeddings(Embeddings):
             vectors = output.last_hidden_state[:, 0, :]
         return vectors

-    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+    def embed_documents(self, texts):
         """
         Compute embeddings for a list of documents (using sentence mode).
         """
         vectors = self.get_embedding(texts, mode="sentence")
         return vectors.cpu().numpy().tolist()

-    def embed_query(self, text: str) -> List[float]:
+    def embed_query(self, text):
         """
         Compute an embedding for a single query.
         """

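For context, this class is the glue between the raw transformer and the FAISS calls used throughout the commit. A brief usage sketch, with illustrative inputs:

from src.embeddings import CustomHuggingFaceEmbeddings

embeddings = CustomHuggingFaceEmbeddings()  # "sentence" mode by default

# Documents go through the sentence-embedding path
doc_vectors = embeddings.embed_documents(["first chunk", "second chunk"])
print(len(doc_vectors), len(doc_vectors[0]))  # 2 x embedding dimension

# Queries use mean pooling weighted by the attention mask instead
query_vector = embeddings.embed_query("what was the mission objective?")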
@@ -4,33 +4,21 @@ This module provides utility functions for handling chat-based tool interactions
 and calculating rewards based on the quality of responses.
 """

-import asyncio
+import inspect
 import json
 import re
 from dataclasses import dataclass
 from datetime import datetime
-from pathlib import Path

 import nest_asyncio
 import numpy as np
 import torch
-from loguru import logger

-from search_module import get_qa_dataset, search
+from src.config import log_metric, logger
+from src.search_module import get_qa_dataset, search

-# Setup loguru
-log_dir = Path("logs")
-log_dir.mkdir(exist_ok=True)
-logger.add(
-    log_dir / "rl_helpers_{time}.log",
-    rotation="500 MB",
-    retention="10 days",
-    level="DEBUG",
-    format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}",
-)
-
+# Apply nest_asyncio for supporting async operations in notebooks
 nest_asyncio.apply()

-from typing import Callable, List
-
 from trl.trainer.grpo_trainer import apply_chat_template
@@ -238,6 +226,8 @@ def run_tool_calls(chat_states):
     Execute tool calls found in chat states.
     """
     logger.debug(f"Running tool calls for {len(chat_states)} chat states")
+    total_retries = 0
+
     for chat_state in chat_states:
         if chat_state.get("finished"):
             logger.debug("Chat state already finished, skipping tool calls")
@@ -256,9 +246,14 @@ def run_tool_calls(chat_states):
             elif len(function_calls) == 1:
                 function_call = function_calls[0]
                 query = function_call["function"]["parameters"]["query"]
-                logger.info(f"Executing search with query: {query}")
+                logger.info(f"🔍 Search Query: {query}")
                 results = search(query, return_type=str, results=2)
                 chat_state["messages"].append({"role": "ipython", "content": results})
+
+                # Count retries
+                retries = len(extract_json_objects(assistant_response))
+                total_retries += retries
+
                 logger.debug("Added search results to chat state")
         except Exception as e:
             logger.error(f"Error during tool call: {str(e)}")
@@ -332,7 +327,12 @@ def get_chat_num_tokens(chat_state, tokenizer):

 def run_agent(
-    generate_fn, tokenizer, questions, max_generations=5, max_new_tokens=4096
+    generate_fn,
+    tokenizer,
+    questions,
+    max_generations=5,
+    max_new_tokens=4096,
+    correct_contents=None,
 ):
     """
     Run the agent to completion for a batch of questions.
@@ -343,6 +343,11 @@ def run_agent(
     )
     chat_states = [get_initial_chat(q) for q in questions]

+    # Add correct content to chat states if provided
+    if correct_contents:
+        for chat_state, correct_content in zip(chat_states, correct_contents):
+            chat_state["correct_content"] = correct_content
+
     # set the initial_prompt length
     for i, chat_state in enumerate(chat_states):
         chat_state["initial_length"] = get_chat_num_tokens(chat_state, tokenizer)
@@ -350,7 +355,7 @@
     # agent loop
     for i in range(max_generations):
-        logger.info(f"Starting generation step {i+1}/{max_generations}")
+        logger.info(f"Starting generation step {i + 1}/{max_generations}")
         chat_states = run_agent_generations(generate_fn, tokenizer, chat_states)
         chat_states = check_finished_chats(chat_states)
         chat_states = run_tool_calls(chat_states)
@@ -359,7 +364,7 @@
         )
         finished_count = sum(1 for state in chat_states if state.get("finished"))
         logger.info(
-            f"Finished {finished_count}/{len(chat_states)} chat states after step {i+1}"
+            f"Finished {finished_count}/{len(chat_states)} chat states after step {i + 1}"
         )

     logger.info("Agent run completed")
@@ -440,15 +445,26 @@ async def verify(student_answer: str, question: str, answer: str) -> bool:

 def check_student_answers(
-    questions: List[str],
-    answers: List[str],
-    student_answers: List[str],
-    vllm_generate_func: Callable[[List[str]], List[str]],
+    questions: list[str],
+    answers: list[str],
+    student_answers: list,  # Can be strings or dicts
+    vllm_generate_func,
     tokenizer,
-    log_file: str = "qa_log.txt",
-) -> List[bool]:
+    log_file=None,
+) -> list[bool]:
     """
     Evaluates a list of student answers against the true answers using a vLLM generate function.
+
+    Args:
+        questions: List of questions
+        answers: List of correct answers
+        student_answers: List of student answers to evaluate
+        vllm_generate_func: Function to generate verification responses
+        tokenizer: Tokenizer for formatting prompts
+        log_file: Optional path to write detailed results
+
+    Returns:
+        List of boolean results (True for correct answers)
     """
     logger.info(f"Checking {len(questions)} student answers")
@@ -463,12 +479,15 @@ def check_student_answers(
     prompts = []
     for question, answer, student_ans in zip(questions, answers, student_answers):
         prompt_text = (
-            "You are grading a student's answer. For the following question, "
-            "compare the student's answer to the correct answer. Reply with 'Yes' if the student's answer is correct, or 'No' if it is completely incorrect.\n\n"
+            "You are grading a student's answer to a question. For the following question, "
+            "compare the student's answer to the correct answer. Reply with 'Yes' if the student's answer contains the correct information, "
+            "even if it's not an exact match. If the student's answer doesn't contain the right information or is completely incorrect, reply with 'No'.\n\n"
             f"Question: {question}\n"
             f"Correct Answer: {answer}\n"
-            f"Student Answer: {student_ans}\n"
+            f"Student Answer: {student_ans}\n\n"
+            "Your response should be just 'Yes' or 'No'."
         )
+
         formatted_prompt = tokenizer.apply_chat_template(
             [{"role": "user", "content": prompt_text}],
             tokenize=False,
@@ -481,10 +500,15 @@ def check_student_answers(
     responses = vllm_generate_func(prompts)
     responses_text = []
     for response in responses:
+        # Handle different response formats
         if hasattr(response, "outputs"):
-            responses_text.append(response.outputs[0].text)
+            try:
+                responses_text.append(response.outputs[0].text)
+            except (AttributeError, IndexError):
+                # Fallback for simple string responses
+                responses_text.append(str(response))
         else:
-            responses_text.append(response)
+            responses_text.append(str(response))
     logger.debug(f"Got {len(responses_text)} verification responses")

     results = []
@@ -495,34 +519,108 @@ def check_student_answers(
     logger.info(f"Verification complete. {sum(results)}/{len(results)} answers correct")

     # Append the QA details and verifier's response to the specified log file
-    with open(log_file, "a") as file:
-        for question, answer, student_ans, verifier_response in zip(
-            questions, answers, student_answers, responses_text
-        ):
-            file.write("Question: " + question + "\n")
-            file.write("Correct Answer: " + answer + "\n")
-            file.write("Student Answer: " + student_ans + "\n")
-            file.write("Verifier said: " + verifier_response + "\n")
-            file.write("-" * 40 + "\n")
+    if log_file:
+        with open(log_file, "a") as file:
+            timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+            file.write(f"\n📝 === QA Evaluation at {timestamp} ===\n")
+            file.write(f"📂 File: {__file__}\n")
+
+            # Get current frame info safely
+            frame = inspect.currentframe()
+            if frame:
+                file.write(f"📍 Line: {frame.f_lineno}\n")
+                # Don't forget to delete the frame to avoid reference cycles
+                del frame
+
+            file.write("=" * 80 + "\n")
+
+            for i, (question, answer, student_ans, verifier_response) in enumerate(
+                zip(questions, answers, student_answers, responses_text)
+            ):
+                file.write(f"\n❓ Question {i+1}:\n")
+                file.write("-" * 40 + "\n")
+                file.write(f"📋 Question: {question}\n")
+                file.write(f"✅ Correct Answer: {answer}\n")
+                file.write(f"👨‍🎓 Student Answer: {student_ans}\n")
+                file.write(f"🔍 Verifier said: {verifier_response}\n")
+
+                # Add search results if available in the chat state
+                if isinstance(student_ans, dict) and "messages" in student_ans:
+                    # Get messages from dict
+                    messages = student_ans.get("messages", [])
+                    search_results = [
+                        msg.get("content", "")
+                        for msg in messages
+                        if msg.get("role") == "ipython"
+                    ]
+                    if search_results:
+                        file.write("\n🔎 Search Results:\n")
+                        for j, result in enumerate(search_results, 1):
+                            file.write(f"\nSearch {j}:\n{result}\n")
+
+                file.write("-" * 40 + "\n")
+
+            file.write(
+                f"\n📊 Summary: {sum(results)}/{len(results)} answers correct ({sum(results)/len(results)*100:.2f}%)\n"
+            )
+            file.write("=" * 80 + "\n\n")

     return results


 # Reward Functions
-def build_reward_correctness_fn(generate_fn, tokenizer):
+def build_reward_correctness_fn(generate_fn, tokenizer, log_file=None):
     def reward_correctness(prompts, completions, **reward_kwargs):
         teacher_answers = reward_kwargs["answer"]
         student_answers = [
             completion["messages"][-1]["content"] for completion in completions
         ]

+        # Log non-exact matches
+        for i, (student, teacher) in enumerate(zip(student_answers, teacher_answers)):
+            if student.strip().lower() != teacher.strip().lower():
+                logger.warning(
+                    f"Non-exact match at index {i}:\n"
+                    f"Student: {student}\n"
+                    f"Teacher: {teacher}"
+                )
+
         correct = check_student_answers(
             prompts,
             teacher_answers,
             student_answers,
             vllm_generate_func=generate_fn,
             tokenizer=tokenizer,
+            log_file=log_file,
         )
+
+        # Log correctness metrics with length info
+        log_metric(
+            "rewards/correctness", np.mean(correct), reward_kwargs.get("step", 0)
+        )
+        log_metric(
+            "rewards/correctness_std", np.std(correct), reward_kwargs.get("step", 0)
+        )
+
+        # Log length metrics
+        student_lengths = [len(ans.strip()) for ans in student_answers]
+        teacher_lengths = [len(ans.strip()) for ans in teacher_answers]
+        log_metric(
+            "metrics/avg_student_length",
+            np.mean(student_lengths),
+            reward_kwargs.get("step", 0),
+        )
+        log_metric(
+            "metrics/avg_teacher_length",
+            np.mean(teacher_lengths),
+            reward_kwargs.get("step", 0),
+        )
+        log_metric(
+            "metrics/length_ratio",
+            np.mean(student_lengths) / np.mean(teacher_lengths),
+            reward_kwargs.get("step", 0),
+        )
+
         return correct

     return reward_correctness
@@ -535,14 +633,23 @@ def reward_formatting(prompts, completions, **reward_kwargs):
         for message in chat["messages"]:
             if "Error during" in message["content"]:
                 has_error[i] = True
+                logger.warning(f"Error in chat {i}: {message['content']}")
                 break
-    return [0.7 if not e else 0 for e in has_error]
+
+    rewards = [0.7 if not e else 0 for e in has_error]
+
+    # Log formatting metrics
+    log_metric("rewards/formatting", np.mean(rewards), reward_kwargs.get("step", 0))
+    log_metric("rewards/formatting_std", np.std(rewards), reward_kwargs.get("step", 0))
+    log_metric("metrics/error_rate", np.mean(has_error), reward_kwargs.get("step", 0))
+
+    return rewards


 def reward_retry_behavior(completions: list[dict], **reward_kwargs) -> list[float]:
     """
-    Reward function that encourages optimal retry behavior by counting total function calls
-    across all assistant messages in the conversation.
+    Reward function that encourages optimal retry behavior by only rewarding completions
+    where every assistant message contains at most 1 JSON object.
     """
     rewards: list[float] = []
@@ -558,22 +665,62 @@ def reward_retry_behavior(completions: list[dict], **reward_kwargs) -> list[float]:
             rewards.append(0.0)
             continue

-        # Count total function calls across all messages
-        total_retries: int = 0
+        # Check if every message has at most 1 JSON object
+        has_multiple_json = False
+        total_json_objects = 0
+
         for msg in assistant_msgs:
-            total_retries += len(extract_json_objects(msg))
+            json_objects = extract_json_objects(msg)
+            json_count = len(json_objects)
+            total_json_objects += json_count
+
+            if json_count > 1:
+                has_multiple_json = True
+                logger.warning(
+                    f"Message contains {json_count} JSON objects, which exceeds the limit of 1"
+                )
+                break

-        # Calculate reward using modified sigmoid function
-        x: float = float(total_retries - 4)  # Center peak at 4 retries
-        base_reward: float = 1.0 / (1.0 + np.exp(-x + abs(x) / 2))
+        # Only reward if no message has multiple JSON objects
+        if has_multiple_json:
+            rewards.append(0.0)
+        else:
+            # Base reward is 1.0 if constraint is met
+            base_reward = 1.0

-        # Additional penalty for excessive retries
-        if total_retries > 6:
-            penalty: float = 0.2 * (total_retries - 6)
-            base_reward = max(0.1, base_reward - penalty)
+            # Slight penalty for having too many total JSON objects across all messages
+            if total_json_objects > 4:
+                penalty = 0.1 * (total_json_objects - 4)
+                base_reward = max(0.2, base_reward - penalty)
+                logger.debug(
+                    f"Applied penalty for {total_json_objects} total JSON objects: {penalty}"
+                )

-        rewards.append(base_reward)
+            rewards.append(base_reward)
+
+    # Log retry behavior metrics
+    log_metric("rewards/retry_behavior", np.mean(rewards), reward_kwargs.get("step", 0))
+    log_metric(
+        "rewards/retry_behavior_std", np.std(rewards), reward_kwargs.get("step", 0)
+    )
+    log_metric(
+        "metrics/avg_json_per_msg",
+        np.mean(
+            [
+                len(extract_json_objects(msg["content"]))
+                for completion in completions
+                for msg in completion["messages"]
+                if msg["role"] == "assistant"
+            ]
+        ),
+        reward_kwargs.get("step", 0),
+    )
+    log_metric(
+        "metrics/multiple_json_violation_rate",
+        np.mean([0.0 if rewards[i] > 0.0 else 1.0 for i in range(len(rewards))]),
+        reward_kwargs.get("step", 0),
+    )

     return rewards
@@ -599,6 +746,11 @@ def reward_exact_match_chunk_query(prompts, completions, **reward_kwargs):
         ]
         logger.debug(f"Found {len(search_results)} search results for prompt {i}")

+        # Log ground truth chunk and searched chunks
+        logger.info(f"📝 Ground Truth Chunk: {correct_content}")
+        for j, result in enumerate(search_results):
+            logger.info(f"🔍 Searched Chunk {j+1}: {result}")
+
         # Check if any search hit the correct chunk content
         found_correct_chunk = False
         for result in search_results:
@@ -609,30 +761,145 @@ def reward_exact_match_chunk_query(prompts, completions, **reward_kwargs):
             )
             break

+        if not found_correct_chunk:
+            logger.warning(
+                f"Failed to find correct chunk for prompt {i}:\n"
+                f"Search results: {[r[:100] + '...' for r in search_results]}"
+            )
+
         reward = 1.0 if found_correct_chunk else 0.0
         rewards.append(reward)
         logger.debug(f"Reward for prompt {i}: {reward}")
-    logger.info(f"Average reward: {sum(rewards)/len(rewards):.3f}")

+        # Log detailed metrics for debugging
+        log_metric(
+            f"debug/chunk_match_{i}",
+            1 if found_correct_chunk else 0,
+            reward_kwargs.get("step", 0),
+        )
+        log_metric(
+            f"debug/search_results_count_{i}",
+            len(search_results),
+            reward_kwargs.get("step", 0),
+        )
+        if search_results:
+            log_metric(
+                f"debug/result_length_{i}",
+                np.mean([len(r.split()) for r in search_results]),
+                reward_kwargs.get("step", 0),
+            )
+
+    # Log chunk query metrics
+    log_metric("rewards/chunk_query", np.mean(rewards), reward_kwargs.get("step", 0))
+    log_metric("rewards/chunk_query_std", np.std(rewards), reward_kwargs.get("step", 0))
+    log_metric(
+        "metrics/avg_search_results",
+        np.mean(
+            [
+                len(
+                    [
+                        msg["content"]
+                        for msg in chat_state["messages"]
+                        if msg["role"] == "ipython"
+                    ]
+                )
+                for chat_state in completions
+            ]
+        ),
+        reward_kwargs.get("step", 0),
+    )
+    log_metric(
+        "metrics/chunk_match_rate", np.mean(rewards), reward_kwargs.get("step", 0)
+    )
+
+    # Log detailed debugging info
+    logger.info("Chunk Query Rewards Summary:")
+    logger.info(f"Total prompts: {len(prompts)}")
+    logger.info(f"Correct matches: {sum(rewards)}")
+    logger.info(f"Average reward: {np.mean(rewards):.3f}")
+    logger.info(f"Reward std: {np.std(rewards):.3f}")

     return rewards


-def run_eval(generate_fn, verify_fn, tokenizer):
-    logger.info("Starting evaluation")
+def run_eval(generate_fn, verify_fn, tokenizer, output_file=None, debug_file=None):
+    """
+    Run evaluation on the test dataset and return results.
+
+    Args:
+        generate_fn: Function to generate completions
+        verify_fn: Function to verify results
+        tokenizer: Tokenizer for processing text
+        output_file: Path to save evaluation results summary
+        debug_file: Path to save detailed debug information
+
+    Returns:
+        full_chat_states: The chat states from evaluation
+    """
     train_dataset, test_dataset = get_qa_dataset()
     questions = test_dataset["prompt"]
-    logger.info(f"Loaded {len(questions)} test questions")
     agentic_outputs = run_agent(generate_fn, tokenizer, questions)
     full_chat_states = agentic_outputs.full_chat_states
     final_responses = agentic_outputs.final_response_str
-    logger.info("Calculating rewards")
     rewards = verify_fn(questions, full_chat_states, answer=test_dataset["answer"])
-    avg_reward = sum(rewards) / len(rewards)
-    logger.info("EVALUATION RESULTS:")
-    logger.info(f"Percentage of correct answers: {avg_reward:.3f}")
+
+    # Calculate results
+    percent_correct = sum(rewards) / len(rewards) * 100
+
+    # Log results to console
+    logger.info("RESULTS:")
+    logger.info(f"percentage of correct answers: {percent_correct:.2f}%")
     logger.info("=" * 30)
+
+    # Save results to file if specified
+    if output_file:
+        try:
+            with open(output_file, "w") as f:
+                f.write("EVALUATION RESULTS\n")
+                f.write("=================\n\n")
+                f.write(f"Total questions: {len(questions)}\n")
+                f.write(f"Correct answers: {sum(rewards)}\n")
+                f.write(f"Percentage correct: {percent_correct:.2f}%\n\n")
+                f.write("Individual results:\n")
+                for i, (q, r, resp) in enumerate(
+                    zip(questions, rewards, final_responses)
+                ):
+                    f.write(f"\nQ{i+1}: {q[:100]}...\n")
+                    f.write(f"Correct: {'✓' if r else '✗'}\n")
+                    f.write(f"Response: {resp[:150]}...\n")
+                    f.write("-" * 40 + "\n")
+            logger.info(f"Saved evaluation results to {output_file}")
+        except Exception as e:
+            logger.error(f"Error saving results file: {e}")
+
+    # Save debug information if specified
+    if debug_file:
+        try:
+            import json
+
+            debug_data = []
+            for i, (q, r, resp, chat) in enumerate(
+                zip(questions, rewards, final_responses, full_chat_states)
+            ):
+                debug_data.append(
+                    {
+                        "question_id": i,
+                        "question": q,
+                        "is_correct": bool(r),
+                        "final_response": resp,
+                        "chat_state": {
+                            k: str(v) if isinstance(v, (list, dict)) else v
+                            for k, v in chat.items()
+                            if k != "tokenizer"
+                        },
+                    }
+                )
+            with open(debug_file, "w") as f:
+                json.dump(debug_data, f, indent=2)
+            logger.info(f"Saved debug information to {debug_file}")
+        except Exception as e:
+            logger.error(f"Error saving debug file: {e}")
+
     return full_chat_states

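These reward builders are designed to be handed to a GRPO-style trainer. A hedged wiring sketch: `generate_fn`, `tokenizer`, and `model` are assumed to exist from model-loading code elsewhere in the repo, and the trainer call follows trl's `reward_funcs` convention rather than anything shown in this diff:

from src import rl_helpers

# Assumed available from model loading code elsewhere in the repo
reward_correctness = rl_helpers.build_reward_correctness_fn(
    generate_fn, tokenizer, log_file="data/qa_log.txt"
)

reward_funcs = [
    reward_correctness,
    rl_helpers.reward_formatting,
    rl_helpers.reward_retry_behavior,
    rl_helpers.reward_exact_match_chunk_query,
]
# e.g. GRPOTrainer(model=model, reward_funcs=reward_funcs, args=..., train_dataset=...)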
@ -3,40 +3,33 @@ Search module for RL training loop.
 This module provides functions to search through vectorized documents and retrieve question-answer pairs.
 """

-import pickle
 import json
 import random
-import asyncio
-from typing import List, Tuple, Optional, Union, Dict, Any
-from enum import Enum
-from pydantic import BaseModel
-from langchain.vectorstores import FAISS
 from datasets import Dataset
-from embeddings import CustomHuggingFaceEmbeddings
+from langchain.vectorstores import FAISS
+
+from src.config import DATA_DIR, logger
+from src.embeddings import CustomHuggingFaceEmbeddings


 # Load pre-saved vectorstore
 def load_vectorstore():
     """Load the pre-saved FAISS index"""
     try:
-        import os
         embeddings = CustomHuggingFaceEmbeddings()
-        # Load the FAISS index with absolute path
-        index_path = os.path.join(
-            os.path.dirname(os.path.abspath(__file__)), "faiss_index"
-        )
-        print(f"Loading FAISS index from: {index_path}")
+        # Load the FAISS index from the data directory
+        logger.info(f"Loading FAISS index from: {DATA_DIR}")
         vectorstore = FAISS.load_local(
-            index_path, embeddings, allow_dangerous_deserialization=True
+            str(DATA_DIR), embeddings, allow_dangerous_deserialization=True
         )
-        print("Successfully loaded FAISS index")
+        logger.info("Successfully loaded FAISS index")
         return vectorstore
     except Exception as e:
-        print(f"Error loading vectorstore: {e}")
+        logger.error(f"Error loading vectorstore: {e}")
         import traceback
-        traceback.print_exc()
+        logger.debug(traceback.format_exc())
         return None
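A minimal sketch of querying the loaded index directly; the query string is illustrative:

# Illustrative use of the loaded index; similarity_search returns the k most similar chunks.
vectorstore = load_vectorstore()
if vectorstore is not None:
    docs = vectorstore.similarity_search("mission objective", k=3)
    for doc in docs:
        print(doc.page_content[:100])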
@ -44,13 +37,13 @@ def load_vectorstore():
 try:
     vectorstore = load_vectorstore()
     if vectorstore is None:
-        print("Warning: FAISS vectorstore could not be loaded.")
+        logger.warning("FAISS vectorstore could not be loaded.")
 except Exception as e:
-    print(f"Error loading vectorstore: {e}")
+    logger.error(f"Error loading vectorstore: {e}")
     vectorstore = None


-def search(query: str, return_type=str, results: int = 5) -> Union[str, List[str]]:
+def search(query: str, return_type=str, results: int = 5):
     """
     Search for relevant chunks using similarity search.
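The module-level wrapper can then be called as below; the queries are illustrative, and the list behavior is inferred from the old Union[str, List[str]] annotation:

# Illustrative calls to the search wrapper.
context = search("What was the mission objective?", results=3)  # assumed: one formatted string
chunk_list = search("mission objective", return_type=list, results=3)  # assumed: list of chunk strings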
@ -82,51 +75,36 @@ def search(query: str, return_type=str, results: int = 5):

 # Load questions from saved data
 def load_qa_data():
-    """Load the pre-generated questions and document chunks"""
+    """Load the pre-generated questions"""
     try:
-        import os
-
-        # Get absolute paths to data files
-        base_dir = os.path.dirname(os.path.abspath(__file__))
-        chunks_path = os.path.join(base_dir, "saved_data", "chunks.pkl")
-        questions_path = os.path.join(base_dir, "saved_data", "questions.json")
-
-        print(f"Loading chunks from: {chunks_path}")
-        print(f"Loading questions from: {questions_path}")
-
-        # Load the chunks
-        with open(chunks_path, "rb") as f:
-            chunks = pickle.load(f)
+        questions_path = DATA_DIR / "questions.json"
+        logger.info(f"Loading questions from: {questions_path}")

         # Load the questions
         with open(questions_path, "r") as f:
             questions = json.load(f)

-        print(
-            f"Successfully loaded {len(chunks)} chunks and {len(questions)} questions"
-        )
-        return chunks, questions
+        logger.info(f"Successfully loaded {len(questions)} questions")
+        return questions
     except Exception as e:
-        print(f"Error loading QA data: {e}")
+        logger.error(f"Error loading QA data: {e}")
         import traceback
-        traceback.print_exc()
-        return None, None
+        logger.debug(traceback.format_exc())
+        return None


-# Load chunks and questions when module is imported
+# Load questions when module is imported
 try:
-    chunks, questions = load_qa_data()
-    if chunks is None or questions is None:
-        print("Warning: Could not load QA data.")
+    questions = load_qa_data()
+    if questions is None:
+        logger.warning("Could not load QA data.")
 except Exception as e:
-    print(f"Error initializing QA data: {e}")
-    chunks, questions = None, None
+    logger.error(f"Error initializing QA data: {e}")
+    questions = None


-def get_question_answer(
-    idx: Optional[int] = None, return_both: bool = True
-) -> Union[dict, str]:
+def get_question_answer(idx=None, return_both: bool = True) -> dict:
     """
     Get a question-answer pair either by index or randomly.

@ -148,7 +126,7 @@ def get_question_answer(idx=None, return_both: bool = True) -> dict:
         qa_pair = questions[idx]
     else:
         raise ValueError(
-            f"Index out of range. Must be between 0 and {len(questions)-1}"
+            f"Index out of range. Must be between 0 and {len(questions) - 1}"
         )

     question = qa_pair["question"]
@ -168,7 +146,7 @@ def get_question_count():
     return len(questions)


-def get_qa_dataset():
+def get_qa_dataset() -> tuple:
     """
     Return a HuggingFace Dataset containing question and answer pairs.
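Taken together, a minimal sketch of the QA accessors; the index is illustrative and assumes the module-level questions loaded successfully:

# Illustrative use of the QA accessors.
qa_pair = get_question_answer(idx=0)  # dict with "question" and "answer" keys
total = get_question_count()
train_dataset, test_dataset = get_qa_dataset()  # each exposes a "prompt" column (see run_eval above)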

@ -0,0 +1,6 @@
export CUDA_VISIBLE_DEVICES=0
python train_grpo.py

@ -1,192 +0,0 @@
# %%
import torch
# %%
from unsloth import FastLanguageModel, is_bfloat16_supported
max_seq_length = 4096 * 2 # Can increase for longer reasoning traces
lora_rank = 64 # Larger rank = smarter, but slower
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="meta-llama/meta-Llama-3.1-8B-Instruct",
max_seq_length=max_seq_length,
load_in_4bit=True, # False for LoRA 16bit
fast_inference=True, # Enable vLLM fast inference
max_lora_rank=lora_rank,
gpu_memory_utilization=0.6, # Reduce if out of memory
)
model = FastLanguageModel.get_peft_model(
model,
r=lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
], # Remove QKVO if out of memory
lora_alpha=lora_rank,
use_gradient_checkpointing="unsloth", # Enable long context finetuning
random_state=3407,
)
# %%
import re
from datasets import Dataset, load_dataset
from rl_helpers import get_qa_dataset
from search_module import get_question_answer, get_question_count, search
train_dataset, test_dataset = get_qa_dataset()
# %% [markdown]
# <a name="Train"></a>
# ### Train the model
#
# Now set up GRPO Trainer and all configurations!
# %%
import os
os.environ["WANDB_PROJECT"] = "bootstrap-search-rl"
# %%
# from UnslothGRPOTrainerTemp import UnslothGRPOConfig, _UnslothGRPOTrainer
import UnslothGRPOTrainerTemp
training_args = UnslothGRPOTrainerTemp.UnslothGRPOConfig(
use_vllm=True, # use vLLM for fast inference!
use_agentic_generate=True, # use agentic generation
learning_rate=5e-6,
adam_beta1=0.9,
adam_beta2=0.99,
weight_decay=0.1,
warmup_ratio=0.1,
lr_scheduler_type="cosine",
optim="paged_adamw_8bit",
logging_steps=1,
bf16=is_bfloat16_supported(),
fp16=not is_bfloat16_supported(),
per_device_train_batch_size=8,
gradient_accumulation_steps=1, # Increase to 4 for smoother training
num_generations=8, # Decrease if out of memory
max_prompt_length=1024,
max_completion_length=1024,
# num_train_epochs = 1, # Set to 1 for a full training run
max_steps=101,
save_steps=50,
max_grad_norm=0.1,
report_to="none", # Can use Weights & Biases
output_dir="full_local_training",
)
# %%
import rl_helpers
# importlib.reload(rl_helpers)
def agentic_generate(
prompts: list[str],
generate_fn,
max_generations: int = 6,
):
return run_agent(generate_fn, tokenizer, prompts, max_generations)
model.agentic_generate = agentic_generate
from vllm import SamplingParams
verifier_sampling_params = SamplingParams(
temperature=0.1,
top_p=0.95,
max_tokens=4096,
)
def verifier_generate_fn(inputs):
return model.fast_generate(
inputs,
sampling_params=verifier_sampling_params,
)
run_agent = rl_helpers.run_agent
reward_correctness = rl_helpers.build_reward_correctness_fn(
verifier_generate_fn,
tokenizer,
)
reward_formatting = rl_helpers.reward_formatting
import UnslothGRPOTrainerTemp
trainer = UnslothGRPOTrainerTemp.UnslothGRPOTrainer(
model=model,
processing_class=tokenizer,
reward_funcs=[
reward_correctness,
reward_formatting,
],
args=training_args,
train_dataset=train_dataset,
)
# %%
trainer.train()
# %% [markdown]
# <a name="Inference"></a>
# ### Inference
# Now let's benchmark the model we trained!
# %%
from vllm import SamplingParams
import rl_helpers
sampling_params = SamplingParams(
temperature=0.5,
top_p=0.95,
max_tokens=4096,
)
def eval_generate_fn(inputs):
return model.fast_generate(
inputs,
sampling_params=sampling_params,
lora_request=model.load_lora(
"full_local_training/checkpoint-101"
), # load the trained LoRA
)
rl_helpers.run_eval(
generate_fn=eval_generate_fn,
verify_fn=reward_correctness,
tokenizer=tokenizer,
)
# %%
# eval w/o lora
def eval_generate_fn(inputs):
return model.fast_generate(
inputs,
sampling_params=sampling_params,
)
rl_helpers.run_eval(
generate_fn=eval_generate_fn,
verify_fn=reward_correctness,
tokenizer=tokenizer,
)

@ -1,196 +0,0 @@
# %%
import torch
# %%
from unsloth import FastLanguageModel, is_bfloat16_supported
max_seq_length = 4096 * 2 # Can increase for longer reasoning traces
lora_rank = 64 # Larger rank = smarter, but slower
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="meta-llama/Llama-3.2-1B-Instruct",
max_seq_length=max_seq_length,
load_in_4bit=True, # False for LoRA 16bit
fast_inference=True, # Enable vLLM fast inference
max_lora_rank=lora_rank,
gpu_memory_utilization=0.6, # Reduce if out of memory
)
print(tokenizer.chat_template)  # See what chat format the model expects
model = FastLanguageModel.get_peft_model(
model,
r=lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
], # Remove QKVO if out of memory
lora_alpha=lora_rank,
use_gradient_checkpointing="unsloth", # Enable long context finetuning
random_state=3407,
)
# %%
import re
from datasets import Dataset, load_dataset
from rl_helpers import get_qa_dataset
from search_module import get_question_answer, get_question_count, search
train_dataset, test_dataset = get_qa_dataset()
# %% [markdown]
# <a name="Train"></a>
# ### Train the model
#
# Now set up GRPO Trainer and all configurations!
# %%
import os
os.environ["WANDB_PROJECT"] = "bootstrap-search-rl"
# %%
# from UnslothGRPOTrainerTemp import UnslothGRPOConfig, _UnslothGRPOTrainer
import UnslothGRPOTrainerTemp
training_args = UnslothGRPOTrainerTemp.UnslothGRPOConfig(
use_vllm=True, # use vLLM for fast inference!
use_agentic_generate=True, # use agentic generation
learning_rate=5e-6,
adam_beta1=0.9,
adam_beta2=0.99,
weight_decay=0.1,
warmup_ratio=0.1,
lr_scheduler_type="cosine",
optim="paged_adamw_8bit",
logging_steps=1,
bf16=is_bfloat16_supported(),
fp16=not is_bfloat16_supported(),
per_device_train_batch_size=8,
gradient_accumulation_steps=1, # Increase to 4 for smoother training
num_generations=8, # Decrease if out of memory
max_prompt_length=1024,
max_completion_length=1024,
# num_train_epochs = 1, # Set to 1 for a full training run
max_steps=101,
save_steps=50,
max_grad_norm=0.1,
report_to="none", # Can use Weights & Biases
output_dir="full_local_training",
)
# %%
import rl_helpers
# importlib.reload(rl_helpers)
def agentic_generate(
prompts: list[str],
generate_fn,
max_generations: int = 6,
):
return run_agent(generate_fn, tokenizer, prompts, max_generations)
model.agentic_generate = agentic_generate
from vllm import SamplingParams
verifier_sampling_params = SamplingParams(
temperature=0.1,
top_p=0.95,
max_tokens=4096,
)
def verifier_generate_fn(inputs):
return model.fast_generate(
inputs,
sampling_params=verifier_sampling_params,
)
run_agent = rl_helpers.run_agent
reward_correctness = rl_helpers.build_reward_correctness_fn(
verifier_generate_fn,
tokenizer,
)
reward_formatting = rl_helpers.reward_formatting
import UnslothGRPOTrainerTemp
trainer = UnslothGRPOTrainerTemp.UnslothGRPOTrainer(
model=model,
processing_class=tokenizer,
reward_funcs=[
reward_correctness,
reward_formatting,
],
args=training_args,
train_dataset=train_dataset,
)
# %%
trainer.train()
# %% [markdown]
# <a name="Inference"></a>
# ### Inference
# Now let's benchmark the model we trained!
# %%
from vllm import SamplingParams
import rl_helpers
sampling_params = SamplingParams(
temperature=0.5,
top_p=0.95,
max_tokens=4096,
)
def eval_generate_fn(inputs):
return model.fast_generate(
inputs,
sampling_params=sampling_params,
lora_request=model.load_lora(
"full_local_training/checkpoint-101"
), # load the trained LoRA
)
rl_helpers.run_eval(
generate_fn=eval_generate_fn,
verify_fn=reward_correctness,
tokenizer=tokenizer,
)
# %%
# eval w/o lora
def eval_generate_fn(inputs):
return model.fast_generate(
inputs,
sampling_params=sampling_params,
)
rl_helpers.run_eval(
generate_fn=eval_generate_fn,
verify_fn=reward_correctness,
tokenizer=tokenizer,
)

@ -0,0 +1,124 @@
import os
from unsloth import FastLanguageModel, is_bfloat16_supported
import src.UnslothGRPOTrainerTemp as UnslothGRPOTrainerTemp
# Import reward functions
from src.rl_helpers import (
build_reward_correctness_fn,
get_qa_dataset,
reward_exact_match_chunk_query,
reward_formatting,
reward_retry_behavior,
run_agent,
)
from src.config import (
MODEL_CONFIG,
MODEL_NAME,
OUTPUT_DIR,
TRAINING_CONFIG,
get_sampling_params,
init_training_dirs,
logger,
update_log_path,
)
# Initialize training directories
paths = init_training_dirs()
# Update logger to use the training directory
update_log_path(paths["log_dir"])
logger.info(f"Training output directory: {paths['output_dir']}")
logger.info(f"Logs are being saved to both ./logs and {paths['log_dir']}")
# Initialize model and tokenizer
logger.info(f"Initializing model {MODEL_NAME}")
model, tokenizer = FastLanguageModel.from_pretrained(
model_name=MODEL_NAME,
max_seq_length=MODEL_CONFIG["max_seq_length"],
load_in_4bit=True, # False for LoRA 16bit
fast_inference=True, # Enable vLLM fast inference
max_lora_rank=MODEL_CONFIG["lora_rank"],
gpu_memory_utilization=MODEL_CONFIG["gpu_memory_utilization"],
)
# Setup LoRA
logger.info("Setting up LoRA adapter")
model = FastLanguageModel.get_peft_model(
model,
r=MODEL_CONFIG["lora_rank"],
target_modules=MODEL_CONFIG["target_modules"],
lora_alpha=MODEL_CONFIG["lora_rank"],
use_gradient_checkpointing=True, # Enable long context finetuning
random_state=3407,
)
# Load datasets
logger.info("Loading datasets")
train_dataset, test_dataset = get_qa_dataset()
logger.info(
f"Loaded {len(train_dataset)} training examples and {len(test_dataset)} test examples"
)
# Setup training arguments
logger.info("Setting up training arguments")
training_args = UnslothGRPOTrainerTemp.UnslothGRPOConfig(
use_vllm=True, # use vLLM for fast inference!
use_agentic_generate=True, # use agentic generation
**TRAINING_CONFIG,
bf16=is_bfloat16_supported(),
fp16=not is_bfloat16_supported(),
output_dir=OUTPUT_DIR,
# report_to="tensorboard",  # ❓ Doesn't this create a huge number of TensorBoard files if report_to is set right here?
)
# Setup model generation functions
def agentic_generate(
prompts: list,
generate_fn,
max_generations: int = 10,
):
return run_agent(generate_fn, tokenizer, prompts, max_generations)
model.agentic_generate = agentic_generate
# Setup verifier
logger.info("Setting up verifier")
verifier_sampling_params = get_sampling_params(temperature=0.1)
def verifier_generate_fn(inputs):
return model.fast_generate(
inputs,
sampling_params=verifier_sampling_params,
)
# Setup trainer
logger.info("Initializing trainer")
trainer = UnslothGRPOTrainerTemp.UnslothGRPOTrainer(
model=model,
processing_class=tokenizer,
reward_funcs=[
build_reward_correctness_fn(
verifier_generate_fn,
tokenizer,
log_file=os.path.join(paths["log_dir"], "qa_log.txt"),
),
reward_formatting,
reward_retry_behavior,
reward_exact_match_chunk_query,
],
args=training_args,
train_dataset=train_dataset,
)
# Train the model
if __name__ == "__main__":
logger.info("Starting training")
trainer.train()
logger.info("Training completed")
logger.info(f"Model saved to {OUTPUT_DIR}")