feat: update config and paths, update data generation script

main
thinhlpg 3 weeks ago
parent bd1d7ced3b
commit bac5f3b4f7

@@ -251,7 +251,7 @@ def create_deepsearch_tab(model, tokenizer, assistant_marker, system_prompt, tem
gr.ChatMessage(
    role="assistant",
    content=f"Searching for: {search_query}",
-    metadata={"title": "🔍 Search", "status": "pending"},
+    metadata={"title": "🔍 ReZero Query", "status": "pending"},
)
)
yield messages
@@ -263,7 +263,7 @@ def create_deepsearch_tab(model, tokenizer, assistant_marker, system_prompt, tem
messages[search_msg_idx] = gr.ChatMessage(
    role="assistant",
    content=f"{search_query}",
-    metadata={"title": "🔍 Search", "duration": search_duration},
+    metadata={"title": "🔍 ReZero Query", "duration": search_duration},
)
yield messages
display_results = format_search_results(results)
@@ -630,7 +630,7 @@ def create_tavily_tab(model, tokenizer, assistant_marker, system_prompt, tempera
gr.ChatMessage(
    role="assistant",
    content=f"Searching (Tavily) for: {search_query}",
-    metadata={"title": "🔍 Tavily Search", "status": "pending"},
+    metadata={"title": "🔍 ReZero Query", "status": "pending"},
)
)
yield messages
@@ -664,7 +664,7 @@ def create_tavily_tab(model, tokenizer, assistant_marker, system_prompt, tempera
messages[search_msg_idx] = gr.ChatMessage(
    role="assistant",
    content=f"{search_query}",
-    metadata={"title": "🔍 Tavily Search", "duration": search_duration},
+    metadata={"title": "🔍 ReZero Query", "duration": search_duration},
)
yield messages

@@ -21,18 +21,14 @@ LOG_FOLDER = PROJ_ROOT / "logs"
RETRIEVER_MODEL_REPO_ID = "intfloat/e5-base-v2"
RETRIEVER_MODEL_DIR = MODEL_DIR / "retriever"
RETRIEVER_SERVER_PORT = 8001
-GENERATOR_MODEL_REPO_ID = "janhq/250404-llama-3.2-3b-instruct-grpo-03-s250"
+GENERATOR_MODEL_REPO_ID = "Menlo/ReZero-v0.1-llama-3.2-3b-it-grpo-250404"
GENERATOR_MODEL_DIR = MODEL_DIR / "generator"
GENERATOR_SERVER_PORT = 8002
# Model configuration
-# MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
-# MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
MODEL_NAME = "meta-llama/Llama-3.2-3B-Instruct"
-# MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
-# MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
-# MODEL_NAME = "unsloth/Qwen2-1.5B" # Smoke test first
device_id = 1 if os.environ.get("CUDA_VISIBLE_DEVICES") == "1" else torch.cuda.current_device()
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
@@ -40,7 +36,7 @@ OUTPUT_DIR = PROJ_ROOT / f"trainer_output_{MODEL_NAME.replace('/', '_')}_gpu{dev
# Model parameters
MODEL_CONFIG = {
-    "max_seq_length": 4096 * 6,  # 24k tokens -> just try to utiliiz
+    "max_seq_length": 4096 * 2,
    "lora_rank": 64,  # Larger rank = smarter, but slower
    "gpu_memory_utilization": 0.6,  # Reduce if out of memory
    "model_name": MODEL_NAME,
@@ -68,7 +64,7 @@ TRAINING_CONFIG = {
    "per_device_train_batch_size": 8,
    "gradient_accumulation_steps": 1,  # Increase to 4 for smoother training
    "num_generations": 6,  # Decrease if out of memory
-    "max_prompt_length": 4096 * 4 - 2048,
+    "max_prompt_length": 4096 * 2,
    "max_completion_length": 2048,
    "max_steps": 1000,
    "save_steps": 50,
@@ -83,7 +79,7 @@ def get_sampling_params(temperature: float = 0.1) -> SamplingParams:
    return SamplingParams(
        temperature=temperature,
        top_p=0.95,
-        max_tokens=4096 * 6,
+        max_tokens=4096 * 2,
    )
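Note on the new token budget above: a minimal arithmetic sketch (constant names mirror the config values in this diff; the snippet itself is not part of the commit).

# Token budget implied by the updated config values.
MAX_SEQ_LENGTH = 4096 * 2         # 8192-token context window
MAX_PROMPT_LENGTH = 4096 * 2      # 8192 tokens allowed for the prompt
MAX_COMPLETION_LENGTH = 2048      # 2048 tokens allowed for the completion

# 8192 + 2048 = 10240 > 8192, so a prompt that uses its full budget may leave
# fewer than 2048 tokens of context for the completion.
print(MAX_SEQ_LENGTH, MAX_PROMPT_LENGTH + MAX_COMPLETION_LENGTH)  # 8192 10240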

@@ -1,14 +1,11 @@
"""
This script performs two main tasks:
1. It loads a markdown document, splits it into chunks, generates embeddings,
-and builds a FAISS index (which is saved locally).
+and builds a FAISS index with the original and paraphrased chunks (which is saved locally).
2. It generates QA pairs from the document using llama.
For each chunk (using a sliding window for context), it generates multiple question-answer pairs
with different difficulties. The generation is performed in batch with one retry for failed prompts.
-Successfully generated QA pairs are saved to "data/questions.json".
+Successfully generated QA pairs are saved to "data/questions.jsonl".
-Requirements:
-pip install langchain faiss-cpu unsloth vllm
"""
import json
@@ -21,12 +18,12 @@ project_root = Path(__file__).resolve().parent.parent
sys.path.append(str(project_root))
import pandas as pd
+from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
-# ========= Part 1: Document Processing and Embedding Generation =========
-# Load and split the markdown document using LangChain
from langchain_community.document_loaders import UnstructuredMarkdownLoader
from langchain_community.vectorstores import FAISS
+from unsloth import FastLanguageModel
+from vllm import SamplingParams
from config import DATA_DIR, logger
from src.embeddings import CustomHuggingFaceEmbeddings
@@ -57,13 +54,10 @@ vectorstore = FAISS.from_documents(chunks, embeddings)
vectorstore.save_local(str(DATA_DIR))
logger.info(f"Saved FAISS index to {DATA_DIR}")
-# TODO: add the paraphrased chunks to the vector store
# ========= Part 2: QA Generation using Llama Backend =========
# Setup Llama backend via unsloth and vLLM
-from unsloth import FastLanguageModel
-from vllm import SamplingParams
# Load the Llama model (adjust parameters as needed)
model, tokenizer = FastLanguageModel.from_pretrained(
@@ -81,6 +75,127 @@ sampling_params = SamplingParams(
    max_tokens=4096,
)
+# Define paraphrasing styles and parameters
+PARAPHRASE_PROMPTS = [
+    """Rewrite this text in a formal, scholarly tone. Keep it very concise - summarize in 1-2 short sentences. Only output the paraphrased text:
+TEXT: {text}""",
+    """Rewrite this text in a clear, simple way that's easy to understand. Provide a medium-length explanation with key details. Only output the paraphrased text:
+TEXT: {text}""",
+    """Rewrite this text in a vivid, engaging style. Expand on the details and provide a comprehensive, detailed version. Only output the paraphrased text:
+TEXT: {text}""",
+]
+
+# Sampling parameters for different lengths
+sampling_params_short = SamplingParams(
+    temperature=0.3,
+    top_p=0.95,
+    max_tokens=64,  # Short responses
+)
+sampling_params_medium = SamplingParams(
+    temperature=0.3,
+    top_p=0.95,
+    max_tokens=256,  # Medium responses
+)
+sampling_params_long = SamplingParams(
+    temperature=0.3,
+    top_p=0.95,
+    max_tokens=512,  # Long responses
+)
+
+
+def generate_paraphrases(text: str) -> list:
+    """
+    Generate three different paraphrased versions with varying lengths.
+
+    Args:
+        text: Text to paraphrase
+
+    Returns:
+        List of three paraphrased versions (short, medium, long)
+    """
+    responses = []
+    sampling_params_list = [
+        sampling_params_short,
+        sampling_params_medium,
+        sampling_params_long,
+    ]
+    for prompt_template, sampling_params in zip(PARAPHRASE_PROMPTS, sampling_params_list):
+        formatted_prompt = tokenizer.apply_chat_template(
+            [{"role": "user", "content": prompt_template.format(text=text)}],
+            tokenize=False,
+            add_generation_prompt=True,
+        )
+        output = model.fast_generate([formatted_prompt], sampling_params=sampling_params)
+        responses.append(output[0].outputs[0].text)
+    return responses
+
+
+# Paraphrase all chunks and add to vector store
+logger.info("Paraphrasing chunks and adding to vector store...")
+all_paraphrased = []
+chunk_ids = []
+for i, chunk in enumerate(chunks):
+    # Get paragraphs from chunk
+    paragraphs = [p.strip() for p in chunk.page_content.split("\n\n") if p.strip()]
+    for paragraph in paragraphs:
+        # Generate 3 paraphrased versions
+        paraphrased_versions = generate_paraphrases(paragraph)
+        # Save original paragraph ID for reference
+        for version in paraphrased_versions:
+            all_paraphrased.append({"chunk_id": i + 1, "original_paragraph": paragraph, "paraphrased_text": version})
+
+# Save paraphrased chunks to CSV for inspection
+paraphrased_df = pd.DataFrame(all_paraphrased)
+paraphrased_csv_path = DATA_DIR / "paragraphs_noise.csv"
+paraphrased_df.to_csv(paraphrased_csv_path, index=False)
+logger.info(f"Saved {len(all_paraphrased)} paraphrased paragraphs to {paraphrased_csv_path}")
+
+paraphrased_docs = [
+    Document(page_content=item["paraphrased_text"], metadata={"chunk_id": item["chunk_id"], "is_paraphrase": True})
+    for item in all_paraphrased
+]
+
+# Process embeddings in smaller batches to avoid OOM
+logger.info(f"Creating FAISS index with {len(paraphrased_docs)} documents in batches")
+batch_size = 100  # Process 100 documents at a time
+paraphrased_vectorstore = None
+for i in range(0, len(paraphrased_docs), batch_size):
+    batch = paraphrased_docs[i : i + batch_size]
+    logger.info(f"Processing batch {i // batch_size + 1}/{(len(paraphrased_docs) + batch_size - 1) // batch_size}")
+    # Create a new FAISS index for this batch
+    batch_vectorstore = FAISS.from_documents(batch, embeddings)
+    # Merge with existing index or create a new one
+    if paraphrased_vectorstore is None:
+        paraphrased_vectorstore = batch_vectorstore
+    else:
+        paraphrased_vectorstore.merge_from(batch_vectorstore)
+
+# Merge with main vectorstore
+if paraphrased_vectorstore is not None:
+    vectorstore.merge_from(paraphrased_vectorstore)
+    logger.info(f"Updated FAISS index with {len(paraphrased_docs)} paraphrased paragraphs")
+    # Save the updated vector store
+    vectorstore.save_local(str(DATA_DIR))
+    logger.info(f"Saved updated FAISS index to {DATA_DIR}")
+else:
+    logger.warning("No paraphrased documents were processed successfully")
+
def batch_generate(prompts: list) -> list:
    """
@@ -148,9 +263,7 @@ def parse_multiple_qa_output(output: str) -> list:
    return qa_pairs
-def generate_question_batch_for_chunks(
-    chunks: list, num_questions: int = 2, difficulty=None
-) -> list:
+def generate_question_batch_for_chunks(chunks: list, num_questions: int = 2, difficulty=None) -> list:
    """
    Generates QA pairs for multiple chunks in batch.
@@ -233,9 +346,7 @@ def generate_question_batch_for_chunks(
            results[idx] = valid_pairs[:num_questions]
        else:
            results[idx] = None
-            logger.warning(
-                f"Retry failed for chunk {idx + 1}: not enough valid QA pairs"
-            )
+            logger.warning(f"Retry failed for chunk {idx + 1}: not enough valid QA pairs")
    else:
        results[idx] = None
        logger.warning(f"Retry failed for chunk {idx + 1}: parsing failed")
@@ -245,13 +356,15 @@ def generate_question_batch_for_chunks(
    for i, qa_list in enumerate(results):
        if qa_list is not None:
            for qa in qa_list:
+                # Get supporting paragraphs by splitting chunk content into paragraphs
+                supporting_paragraphs = [p.strip() for p in chunk_contents[i].split("\n\n") if p.strip()]
                final_questions.append(
                    {
-                        "chunk_id": chunk_ids[i],
+                        "id": str(chunk_ids[i]),
                        "question": qa[0],
                        "answer": qa[1],
-                        "difficulty": qa[2],
-                        "chunk_content": chunk_contents[i],
+                        "supporting_paragraphs": supporting_paragraphs,
                    }
                )
@@ -260,13 +373,13 @@ def generate_question_batch_for_chunks(
# Generate QA pairs in batch (using a sliding window over the chunks)
-all_questions = generate_question_batch_for_chunks(
-    chunks, num_questions=2, difficulty="medium"
-)
+logger.info("Generating question-answer pairs...")
+all_questions = generate_question_batch_for_chunks(chunks, num_questions=2, difficulty="medium")
logger.info(f"Generated {len(all_questions)} QA pairs.")
-# Save the QA pairs to a JSON file
+# Save the QA pairs to a JSONL file
-questions_path = DATA_DIR / "questions.json"
+questions_path = DATA_DIR / "questions.jsonl"
with open(questions_path, "w") as f:
-    json.dump(all_questions, f, indent=2)
+    for question in all_questions:
+        f.write(json.dumps(question) + "\n")
logger.info(f"Saved questions to {questions_path}")

@@ -12,7 +12,7 @@ from langchain_community.vectorstores import FAISS
from config import DATA_DIR, logger
from src.embeddings import CustomHuggingFaceEmbeddings
-PROCESSED_DATA_DIR = DATA_DIR / "processed"
+PROCESSED_DATA_DIR = DATA_DIR
# Load pre-saved vectorstore

@@ -2,8 +2,6 @@
Train a model using GRPO (Generative Reward-Penalized Optimization).
"""
-import os
from unsloth import FastLanguageModel, is_bfloat16_supported
import src.UnslothGRPOTrainerTemp as UnslothGRPOTrainerTemp
@@ -64,7 +62,7 @@ model = FastLanguageModel.get_peft_model(
# Load datasets
logger.info("Loading datasets")
-train_dataset, test_dataset = get_qa_dataset(randomize=True, test_size=0, seed=42)
+train_dataset, test_dataset = get_qa_dataset(randomize=False, test_size=0.1, seed=42)
logger.info(f"Loaded {len(train_dataset)} training examples and {len(test_dataset)} test examples")
# Setup training arguments
@@ -76,7 +74,7 @@ training_args = UnslothGRPOTrainerTemp.UnslothGRPOConfig(
    bf16=is_bfloat16_supported(),
    fp16=not is_bfloat16_supported(),
    output_dir=OUTPUT_DIR,
-    reward_weights=[2.0, 1.0, 1.0, 1.0],
+    reward_weights=[4.0, 2.0, 1.0, 1.0, 1.0, 1.0],
)
@@ -85,6 +83,7 @@ def agentic_generate(
    prompts: list,
    generate_fn,
    max_generations: int = 32,
+    max_new_tokens: int = 4096 * 2,
):
    # Create agent with appropriate adapter based on tokenizer
    tokenizer_name = tokenizer.name_or_path.lower()
@@ -98,7 +97,7 @@ def agentic_generate(
        raise ValueError(f"Unsupported tokenizer: {tokenizer_name}")
    agent = Agent(adapter)
-    return agent.run_agent(generate_fn, tokenizer, prompts, max_generations)
+    return agent.run_agent(generate_fn, tokenizer, prompts, max_generations, max_new_tokens=max_new_tokens)

model.agentic_generate = agentic_generate
@@ -128,8 +127,8 @@ trainer = UnslothGRPOTrainerTemp.UnslothGRPOTrainer(
        reward_format,
        reward_retry,
        reward_em_chunk,
-        # reward_search_strategy,
-        # reward_search_diversity,
+        reward_search_strategy,
+        reward_search_diversity,
    ],
    args=training_args,
    train_dataset=train_dataset,
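A minimal sketch of how the six reward_weights line up with the reward functions enabled above; the first function sits outside this hunk's context, so its name here is an assumption.

# Assumed pairing of reward functions and weights (order follows the diff above).
reward_funcs = [
    "reward_correctness",       # assumed first entry, weight 4.0
    "reward_format",            # weight 2.0
    "reward_retry",             # weight 1.0
    "reward_em_chunk",          # weight 1.0
    "reward_search_strategy",   # newly enabled, weight 1.0
    "reward_search_diversity",  # newly enabled, weight 1.0
]
reward_weights = [4.0, 2.0, 1.0, 1.0, 1.0, 1.0]
assert len(reward_funcs) == len(reward_weights)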
