feat: update config and paths, update data generation script

main
thinhlpg 3 weeks ago
parent bd1d7ced3b
commit bac5f3b4f7

@@ -251,7 +251,7 @@ def create_deepsearch_tab(model, tokenizer, assistant_marker, system_prompt, tem
gr.ChatMessage(
role="assistant",
content=f"Searching for: {search_query}",
metadata={"title": "🔍 Search", "status": "pending"},
metadata={"title": "🔍 ReZero Query", "status": "pending"},
)
)
yield messages
@@ -263,7 +263,7 @@ def create_deepsearch_tab(model, tokenizer, assistant_marker, system_prompt, tem
messages[search_msg_idx] = gr.ChatMessage(
role="assistant",
content=f"{search_query}",
metadata={"title": "🔍 Search", "duration": search_duration},
metadata={"title": "🔍 ReZero Query", "duration": search_duration},
)
yield messages
display_results = format_search_results(results)
@@ -630,7 +630,7 @@ def create_tavily_tab(model, tokenizer, assistant_marker, system_prompt, tempera
gr.ChatMessage(
role="assistant",
content=f"Searching (Tavily) for: {search_query}",
metadata={"title": "🔍 Tavily Search", "status": "pending"},
metadata={"title": "🔍 ReZero Query", "status": "pending"},
)
)
yield messages
@@ -664,7 +664,7 @@ def create_tavily_tab(model, tokenizer, assistant_marker, system_prompt, tempera
messages[search_msg_idx] = gr.ChatMessage(
role="assistant",
content=f"{search_query}",
metadata={"title": "🔍 Tavily Search", "duration": search_duration},
metadata={"title": "🔍 ReZero Query", "duration": search_duration},
)
yield messages
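
The hunks above rely on Gradio's collapsible tool-message flow: yield a message with status "pending", then overwrite the same list index with a finished message carrying a duration. A minimal standalone sketch of that pattern, assuming a recent Gradio release with `gr.ChatMessage` metadata support; the stand-in search result and timing are illustrative only.

```python
import time

import gradio as gr


def respond(message, history):
    """Yield a pending tool message, then replace it in place once the work finishes."""
    messages = [
        gr.ChatMessage(
            role="assistant",
            content=f"Searching for: {message}",
            metadata={"title": "🔍 ReZero Query", "status": "pending"},
        )
    ]
    yield messages  # renders the collapsible "pending" bubble

    start = time.time()
    results = f"(illustrative results for '{message}')"  # stand-in for the real retriever call
    duration = round(time.time() - start, 2)

    # Overwrite the pending message with a completed one carrying the duration.
    messages[0] = gr.ChatMessage(
        role="assistant",
        content=message,
        metadata={"title": "🔍 ReZero Query", "duration": duration},
    )
    messages.append(gr.ChatMessage(role="assistant", content=results))
    yield messages


demo = gr.ChatInterface(respond, type="messages")

if __name__ == "__main__":
    demo.launch()
```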

@@ -21,18 +21,14 @@ LOG_FOLDER = PROJ_ROOT / "logs"
RETRIEVER_MODEL_REPO_ID = "intfloat/e5-base-v2"
RETRIEVER_MODEL_DIR = MODEL_DIR / "retriever"
RETRIEVER_SERVER_PORT = 8001
GENERATOR_MODEL_REPO_ID = "janhq/250404-llama-3.2-3b-instruct-grpo-03-s250"
GENERATOR_MODEL_REPO_ID = "Menlo/ReZero-v0.1-llama-3.2-3b-it-grpo-250404"
GENERATOR_MODEL_DIR = MODEL_DIR / "generator"
GENERATOR_SERVER_PORT = 8002
# Model configuration
# MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
# MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
MODEL_NAME = "meta-llama/Llama-3.2-3B-Instruct"
# MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
# MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
# MODEL_NAME = "unsloth/Qwen2-1.5B" # Smoke test first
device_id = 1 if os.environ.get("CUDA_VISIBLE_DEVICES") == "1" else torch.cuda.current_device()
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
@@ -40,7 +36,7 @@ OUTPUT_DIR = PROJ_ROOT / f"trainer_output_{MODEL_NAME.replace('/', '_')}_gpu{dev
# Model parameters
MODEL_CONFIG = {
"max_seq_length": 4096 * 6, # 24k tokens -> just try to utiliiz
"max_seq_length": 4096 * 2,
"lora_rank": 64, # Larger rank = smarter, but slower
"gpu_memory_utilization": 0.6, # Reduce if out of memory
"model_name": MODEL_NAME,
@@ -68,7 +64,7 @@ TRAINING_CONFIG = {
"per_device_train_batch_size": 8,
"gradient_accumulation_steps": 1, # Increase to 4 for smoother training
"num_generations": 6, # Decrease if out of memory
"max_prompt_length": 4096 * 4 - 2048,
"max_prompt_length": 4096 * 2,
"max_completion_length": 2048,
"max_steps": 1000,
"save_steps": 50,
@@ -83,7 +79,7 @@ def get_sampling_params(temperature: float = 0.1) -> SamplingParams:
return SamplingParams(
temperature=temperature,
top_p=0.95,
max_tokens=4096 * 6,
max_tokens=4096 * 2,
)
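
With GENERATOR_MODEL_REPO_ID now pointing at the published ReZero checkpoint, a minimal sketch of pulling the weights into the configured directory; this assumes `huggingface_hub` is installed and is not the repo's own download tooling.

```python
# Hedged sketch: fetch the swapped-in ReZero generator checkpoint into the
# configured directory. The repo may ship its own download/serve scripts.
from huggingface_hub import snapshot_download

from config import GENERATOR_MODEL_DIR, GENERATOR_MODEL_REPO_ID, logger

local_path = snapshot_download(
    repo_id=GENERATOR_MODEL_REPO_ID,  # "Menlo/ReZero-v0.1-llama-3.2-3b-it-grpo-250404"
    local_dir=GENERATOR_MODEL_DIR,
)
logger.info(f"Downloaded generator weights to {local_path}")
```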

@@ -1,14 +1,11 @@
"""
This script performs two main tasks:
1. It loads a markdown document, splits it into chunks, generates embeddings,
and builds a FAISS index (which is saved locally).
and builds a FAISS index with the original and paraphrased chunks (which is saved locally).
2. It generates QA pairs from the document using llama.
For each chunk (using a sliding window for context), it generates multiple question-answer pairs
with different difficulties. The generation is performed in batch with one retry for failed prompts.
Successfully generated QA pairs are saved to "data/questions.json".
Requirements:
pip install langchain faiss-cpu unsloth vllm
Successfully generated QA pairs are saved to "data/questions.jsonl".
"""
import json
@@ -21,12 +18,12 @@ project_root = Path(__file__).resolve().parent.parent
sys.path.append(str(project_root))
import pandas as pd
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
# ========= Part 1: Document Processing and Embedding Generation =========
# Load and split the markdown document using LangChain
from langchain_community.document_loaders import UnstructuredMarkdownLoader
from langchain_community.vectorstores import FAISS
from unsloth import FastLanguageModel
from vllm import SamplingParams
from config import DATA_DIR, logger
from src.embeddings import CustomHuggingFaceEmbeddings
@@ -57,13 +54,10 @@ vectorstore = FAISS.from_documents(chunks, embeddings)
vectorstore.save_local(str(DATA_DIR))
logger.info(f"Saved FAISS index to {DATA_DIR}")
# TODO: add the paraphrased chunks to the vector store
# ========= Part 2: QA Generation using Llama Backend =========
# Setup Llama backend via unsloth and vLLM
from unsloth import FastLanguageModel
from vllm import SamplingParams
# Load the Llama model (adjust parameters as needed)
model, tokenizer = FastLanguageModel.from_pretrained(
@@ -81,6 +75,127 @@ sampling_params = SamplingParams(
max_tokens=4096,
)
# Define paraphrasing styles and parameters
PARAPHRASE_PROMPTS = [
"""Rewrite this text in a formal, scholarly tone. Keep it very concise - summarize in 1-2 short sentences. Only output the paraphrased text:
TEXT: {text}""",
"""Rewrite this text in a clear, simple way that's easy to understand. Provide a medium-length explanation with key details. Only output the paraphrased text:
TEXT: {text}""",
"""Rewrite this text in a vivid, engaging style. Expand on the details and provide a comprehensive, detailed version. Only output the paraphrased text:
TEXT: {text}""",
]
# Sampling parameters for different lengths
sampling_params_short = SamplingParams(
temperature=0.3,
top_p=0.95,
max_tokens=64, # Short responses
)
sampling_params_medium = SamplingParams(
temperature=0.3,
top_p=0.95,
max_tokens=256, # Medium responses
)
sampling_params_long = SamplingParams(
temperature=0.3,
top_p=0.95,
max_tokens=512, # Long responses
)
def generate_paraphrases(text: str) -> list:
"""
Generate three different paraphrased versions with varying lengths.
Args:
text: Text to paraphrase
Returns:
List of three paraphrased versions (short, medium, long)
"""
responses = []
sampling_params_list = [
sampling_params_short,
sampling_params_medium,
sampling_params_long,
]
for prompt_template, sampling_params in zip(PARAPHRASE_PROMPTS, sampling_params_list):
formatted_prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": prompt_template.format(text=text)}],
tokenize=False,
add_generation_prompt=True,
)
output = model.fast_generate([formatted_prompt], sampling_params=sampling_params)
responses.append(output[0].outputs[0].text)
return responses
# Paraphrase all chunks and add to vector store
logger.info("Paraphrasing chunks and adding to vector store...")
all_paraphrased = []
chunk_ids = []
for i, chunk in enumerate(chunks):
# Get paragraphs from chunk
paragraphs = [p.strip() for p in chunk.page_content.split("\n\n") if p.strip()]
for paragraph in paragraphs:
# Generate 3 paraphrased versions
paraphrased_versions = generate_paraphrases(paragraph)
# Save original paragraph ID for reference
for version in paraphrased_versions:
all_paraphrased.append({"chunk_id": i + 1, "original_paragraph": paragraph, "paraphrased_text": version})
# Save paraphrased chunks to CSV for inspection
paraphrased_df = pd.DataFrame(all_paraphrased)
paraphrased_csv_path = DATA_DIR / "paragraphs_noise.csv"
paraphrased_df.to_csv(paraphrased_csv_path, index=False)
logger.info(f"Saved {len(all_paraphrased)} paraphrased paragraphs to {paraphrased_csv_path}")
paraphrased_docs = [
Document(page_content=item["paraphrased_text"], metadata={"chunk_id": item["chunk_id"], "is_paraphrase": True})
for item in all_paraphrased
]
# Process embeddings in smaller batches to avoid OOM
logger.info(f"Creating FAISS index with {len(paraphrased_docs)} documents in batches")
batch_size = 100 # Process 100 documents at a time
paraphrased_vectorstore = None
for i in range(0, len(paraphrased_docs), batch_size):
batch = paraphrased_docs[i : i + batch_size]
logger.info(f"Processing batch {i // batch_size + 1}/{(len(paraphrased_docs) + batch_size - 1) // batch_size}")
# Create a new FAISS index for this batch
batch_vectorstore = FAISS.from_documents(batch, embeddings)
# Merge with existing index or create a new one
if paraphrased_vectorstore is None:
paraphrased_vectorstore = batch_vectorstore
else:
paraphrased_vectorstore.merge_from(batch_vectorstore)
# Merge with main vectorstore
if paraphrased_vectorstore is not None:
vectorstore.merge_from(paraphrased_vectorstore)
logger.info(f"Updated FAISS index with {len(paraphrased_docs)} paraphrased paragraphs")
# Save the updated vector store
vectorstore.save_local(str(DATA_DIR))
logger.info(f"Saved updated FAISS index to {DATA_DIR}")
else:
logger.warning("No paraphrased documents were processed successfully")
def batch_generate(prompts: list) -> list:
"""
@@ -148,9 +263,7 @@ def parse_multiple_qa_output(output: str) -> list:
return qa_pairs
def generate_question_batch_for_chunks(
chunks: list, num_questions: int = 2, difficulty=None
) -> list:
def generate_question_batch_for_chunks(chunks: list, num_questions: int = 2, difficulty=None) -> list:
"""
Generates QA pairs for multiple chunks in batch.
@@ -233,9 +346,7 @@ def generate_question_batch_for_chunks(
results[idx] = valid_pairs[:num_questions]
else:
results[idx] = None
logger.warning(
f"Retry failed for chunk {idx + 1}: not enough valid QA pairs"
)
logger.warning(f"Retry failed for chunk {idx + 1}: not enough valid QA pairs")
else:
results[idx] = None
logger.warning(f"Retry failed for chunk {idx + 1}: parsing failed")
@@ -245,13 +356,15 @@ def generate_question_batch_for_chunks(
for i, qa_list in enumerate(results):
if qa_list is not None:
for qa in qa_list:
# Get supporting paragraphs by splitting chunk content into paragraphs
supporting_paragraphs = [p.strip() for p in chunk_contents[i].split("\n\n") if p.strip()]
final_questions.append(
{
"chunk_id": chunk_ids[i],
"id": str(chunk_ids[i]),
"question": qa[0],
"answer": qa[1],
"difficulty": qa[2],
"chunk_content": chunk_contents[i],
"supporting_paragraphs": supporting_paragraphs,
}
)
@@ -260,13 +373,13 @@ def generate_question_batch_for_chunks(
# Generate QA pairs in batch (using a sliding window over the chunks)
all_questions = generate_question_batch_for_chunks(
chunks, num_questions=2, difficulty="medium"
)
logger.info("Generating question-answer pairs...")
all_questions = generate_question_batch_for_chunks(chunks, num_questions=2, difficulty="medium")
logger.info(f"Generated {len(all_questions)} QA pairs.")
# Save the QA pairs to a JSON file
questions_path = DATA_DIR / "questions.json"
# Save the QA pairs to a JSONL file
questions_path = DATA_DIR / "questions.jsonl"
with open(questions_path, "w") as f:
json.dump(all_questions, f, indent=2)
for question in all_questions:
f.write(json.dumps(question) + "\n")
logger.info(f"Saved questions to {questions_path}")

@@ -12,7 +12,7 @@ from langchain_community.vectorstores import FAISS
from config import DATA_DIR, logger
from src.embeddings import CustomHuggingFaceEmbeddings
PROCESSED_DATA_DIR = DATA_DIR / "processed"
PROCESSED_DATA_DIR = DATA_DIR
# Load pre-saved vectorstore
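
Because the processed index now sits directly under DATA_DIR and mixes original chunks with paraphrased noise, a hedged sketch of loading it back and telling the two apart via the `is_paraphrase` metadata flag; a default-constructed CustomHuggingFaceEmbeddings is assumed here, and newer langchain_community releases require opting in to pickle deserialization.

```python
# Hedged sketch: load the merged index saved by the generation script and mark
# whether each hit comes from an original chunk or a paraphrased paragraph.
from langchain_community.vectorstores import FAISS

from config import DATA_DIR
from src.embeddings import CustomHuggingFaceEmbeddings

embeddings = CustomHuggingFaceEmbeddings()  # assumed default constructor
vectorstore = FAISS.load_local(
    str(DATA_DIR), embeddings, allow_dangerous_deserialization=True
)

for doc in vectorstore.similarity_search("example query", k=5):
    source = "paraphrase" if doc.metadata.get("is_paraphrase") else "original"
    print(f"[{source}] {doc.page_content[:80]}")
```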

@@ -2,8 +2,6 @@
Train a model using GRPO (Group Relative Policy Optimization).
"""
import os
from unsloth import FastLanguageModel, is_bfloat16_supported
import src.UnslothGRPOTrainerTemp as UnslothGRPOTrainerTemp
@@ -64,7 +62,7 @@ model = FastLanguageModel.get_peft_model(
# Load datasets
logger.info("Loading datasets")
train_dataset, test_dataset = get_qa_dataset(randomize=True, test_size=0, seed=42)
train_dataset, test_dataset = get_qa_dataset(randomize=False, test_size=0.1, seed=42)
logger.info(f"Loaded {len(train_dataset)} training examples and {len(test_dataset)} test examples")
# Setup training arguments
@@ -76,7 +74,7 @@ training_args = UnslothGRPOTrainerTemp.UnslothGRPOConfig(
bf16=is_bfloat16_supported(),
fp16=not is_bfloat16_supported(),
output_dir=OUTPUT_DIR,
reward_weights=[2.0, 1.0, 1.0, 1.0],
reward_weights=[4.0, 2.0, 1.0, 1.0, 1.0, 1.0],
)
@@ -85,6 +83,7 @@ def agentic_generate(
prompts: list,
generate_fn,
max_generations: int = 32,
max_new_tokens: int = 4096 * 2,
):
# Create agent with appropriate adapter based on tokenizer
tokenizer_name = tokenizer.name_or_path.lower()
@@ -98,7 +97,7 @@ raise ValueError(f"Unsupported tokenizer: {tokenizer_name}")
raise ValueError(f"Unsupported tokenizer: {tokenizer_name}")
agent = Agent(adapter)
return agent.run_agent(generate_fn, tokenizer, prompts, max_generations)
return agent.run_agent(generate_fn, tokenizer, prompts, max_generations, max_new_tokens=max_new_tokens)
model.agentic_generate = agentic_generate
@@ -128,8 +127,8 @@ trainer = UnslothGRPOTrainerTemp.UnslothGRPOTrainer(
reward_format,
reward_retry,
reward_em_chunk,
# reward_search_strategy,
# reward_search_diversity,
reward_search_strategy,
reward_search_diversity,
],
args=training_args,
train_dataset=train_dataset,
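
Re-enabling reward_search_strategy and reward_search_diversity is why reward_weights grows from four to six entries. A small sanity-check sketch of the pairing, assuming the Unsloth GRPO config follows TRL's convention of one weight per reward function in list order; the first function in the trainer's list sits above the visible hunk and is named hypothetically below.

```python
# Sanity-check sketch: one weight per reward function, applied in order.
reward_weights = [4.0, 2.0, 1.0, 1.0, 1.0, 1.0]
reward_func_names = [
    "reward_answer",            # hypothetical: first entry not shown in this hunk
    "reward_format",
    "reward_retry",
    "reward_em_chunk",
    "reward_search_strategy",   # newly enabled in this commit
    "reward_search_diversity",  # newly enabled in this commit
]

assert len(reward_weights) == len(reward_func_names), "one weight per reward function"
for name, weight in zip(reward_func_names, reward_weights):
    print(f"{name}: weight {weight}")
```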
