From bac5f3b4f7a9ba02840d9372f91757fe0a15350c Mon Sep 17 00:00:00 2001
From: thinhlpg
Date: Wed, 16 Apr 2025 03:06:55 +0000
Subject: [PATCH] feat: update config and paths, update data generation script

---
 app.py                   |   8 +-
 config.py                |  14 ++--
 scripts/generate_data.py | 165 +++++++++++++++++++++++++++++++++------
 src/search_module.py     |   2 +-
 train_grpo.py            |  13 ++-
 5 files changed, 155 insertions(+), 47 deletions(-)

diff --git a/app.py b/app.py
index c66b674..2f2a898 100644
--- a/app.py
+++ b/app.py
@@ -251,7 +251,7 @@ def create_deepsearch_tab(model, tokenizer, assistant_marker, system_prompt, tem
             gr.ChatMessage(
                 role="assistant",
                 content=f"Searching for: {search_query}",
-                metadata={"title": "🔍 Search", "status": "pending"},
+                metadata={"title": "🔍 ReZero Query", "status": "pending"},
             )
         )
         yield messages
@@ -263,7 +263,7 @@ def create_deepsearch_tab(model, tokenizer, assistant_marker, system_prompt, tem
         messages[search_msg_idx] = gr.ChatMessage(
             role="assistant",
             content=f"{search_query}",
-            metadata={"title": "🔍 Search", "duration": search_duration},
+            metadata={"title": "🔍 ReZero Query", "duration": search_duration},
         )
         yield messages
         display_results = format_search_results(results)
@@ -630,7 +630,7 @@ def create_tavily_tab(model, tokenizer, assistant_marker, system_prompt, tempera
             gr.ChatMessage(
                 role="assistant",
                 content=f"Searching (Tavily) for: {search_query}",
-                metadata={"title": "🔍 Tavily Search", "status": "pending"},
+                metadata={"title": "🔍 ReZero Query", "status": "pending"},
             )
         )
         yield messages
@@ -664,7 +664,7 @@ def create_tavily_tab(model, tokenizer, assistant_marker, system_prompt, tempera
         messages[search_msg_idx] = gr.ChatMessage(
             role="assistant",
             content=f"{search_query}",
-            metadata={"title": "🔍 Tavily Search", "duration": search_duration},
+            metadata={"title": "🔍 ReZero Query", "duration": search_duration},
        )
         yield messages

diff --git a/config.py b/config.py
index 20430b9..dff52dd 100644
--- a/config.py
+++ b/config.py
@@ -21,18 +21,14 @@ LOG_FOLDER = PROJ_ROOT / "logs"
 RETRIEVER_MODEL_REPO_ID = "intfloat/e5-base-v2"
 RETRIEVER_MODEL_DIR = MODEL_DIR / "retriever"
 RETRIEVER_SERVER_PORT = 8001
-GENERATOR_MODEL_REPO_ID = "janhq/250404-llama-3.2-3b-instruct-grpo-03-s250"
+GENERATOR_MODEL_REPO_ID = "Menlo/ReZero-v0.1-llama-3.2-3b-it-grpo-250404"
 GENERATOR_MODEL_DIR = MODEL_DIR / "generator"
 GENERATOR_SERVER_PORT = 8002

 # Model configuration
-# MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
-# MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
+
 MODEL_NAME = "meta-llama/Llama-3.2-3B-Instruct"
-# MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
-# MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
-# MODEL_NAME = "unsloth/Qwen2-1.5B"  # Smoke test first

 device_id = 1 if os.environ.get("CUDA_VISIBLE_DEVICES") == "1" else torch.cuda.current_device()
 timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

@@ -40,7 +36,7 @@ OUTPUT_DIR = PROJ_ROOT / f"trainer_output_{MODEL_NAME.replace('/', '_')}_gpu{dev

 # Model parameters
 MODEL_CONFIG = {
-    "max_seq_length": 4096 * 6,  # 24k tokens -> just try to utiliiz
+    "max_seq_length": 4096 * 2,
     "lora_rank": 64,  # Larger rank = smarter, but slower
     "gpu_memory_utilization": 0.6,  # Reduce if out of memory
     "model_name": MODEL_NAME,
@@ -68,7 +64,7 @@ TRAINING_CONFIG = {
     "per_device_train_batch_size": 8,
     "gradient_accumulation_steps": 1,  # Increase to 4 for smoother training
     "num_generations": 6,  # Decrease if out of memory
-    "max_prompt_length": 4096 * 4 - 2048,
+    "max_prompt_length": 4096 * 2,
     "max_completion_length": 2048,
     "max_steps": 1000,
     "save_steps": 50,
@@ -83,7 +79,7 @@ def get_sampling_params(temperature: float = 0.1) -> SamplingParams:
     return SamplingParams(
         temperature=temperature,
         top_p=0.95,
-        max_tokens=4096 * 6,
+        max_tokens=4096 * 2,
     )

diff --git a/scripts/generate_data.py b/scripts/generate_data.py
index 2c8a4b7..64bad3d 100644
--- a/scripts/generate_data.py
+++ b/scripts/generate_data.py
@@ -1,14 +1,11 @@
 """
 This script performs two main tasks:
 1. It loads a markdown document, splits it into chunks, generates embeddings,
-   and builds a FAISS index (which is saved locally).
+   and builds a FAISS index with the original and paraphrased chunks (which is saved locally).
 2. It generates QA pairs from the document using llama. For each chunk (using a sliding window
    for context), it generates multiple question-answer pairs with different difficulties.
    The generation is performed in batch with one retry for failed prompts.
-   Successfully generated QA pairs are saved to "data/questions.json".
-
-Requirements:
-    pip install langchain faiss-cpu unsloth vllm
+   Successfully generated QA pairs are saved to "data/questions.jsonl".
 """

 import json
@@ -21,12 +18,12 @@
 project_root = Path(__file__).resolve().parent.parent
 sys.path.append(str(project_root))

 import pandas as pd
+from langchain.schema import Document
 from langchain.text_splitter import RecursiveCharacterTextSplitter
-
-# ========= Part 1: Document Processing and Embedding Generation =========
-# Load and split the markdown document using LangChain
 from langchain_community.document_loaders import UnstructuredMarkdownLoader
 from langchain_community.vectorstores import FAISS
+from unsloth import FastLanguageModel
+from vllm import SamplingParams

 from config import DATA_DIR, logger
 from src.embeddings import CustomHuggingFaceEmbeddings
@@ -57,13 +54,10 @@ vectorstore = FAISS.from_documents(chunks, embeddings)
 vectorstore.save_local(str(DATA_DIR))
 logger.info(f"Saved FAISS index to {DATA_DIR}")

-# TODO: add the paraphrased chunks to the vector store

 # ========= Part 2: QA Generation using Llama Backend =========
 # Setup Llama backend via unsloth and vLLM
-from unsloth import FastLanguageModel
-from vllm import SamplingParams

 # Load the Llama model (adjust parameters as needed)
 model, tokenizer = FastLanguageModel.from_pretrained(
@@ -81,6 +75,127 @@ sampling_params = SamplingParams(
     max_tokens=4096,
 )

+# Define paraphrasing styles and parameters
+PARAPHRASE_PROMPTS = [
+    """Rewrite this text in a formal, scholarly tone. Keep it very concise - summarize in 1-2 short sentences. Only output the paraphrased text:
+
+    TEXT: {text}""",
+    """Rewrite this text in a clear, simple way that's easy to understand. Provide a medium-length explanation with key details. Only output the paraphrased text:
+
+    TEXT: {text}""",
+    """Rewrite this text in a vivid, engaging style. Expand on the details and provide a comprehensive, detailed version. Only output the paraphrased text:
+
+    TEXT: {text}""",
+]
+
+# Sampling parameters for different lengths
+sampling_params_short = SamplingParams(
+    temperature=0.3,
+    top_p=0.95,
+    max_tokens=64,  # Short responses
+)
+
+sampling_params_medium = SamplingParams(
+    temperature=0.3,
+    top_p=0.95,
+    max_tokens=256,  # Medium responses
+)
+
+sampling_params_long = SamplingParams(
+    temperature=0.3,
+    top_p=0.95,
+    max_tokens=512,  # Long responses
+)
+
+
+def generate_paraphrases(text: str) -> list:
+    """
+    Generate three different paraphrased versions with varying lengths.
+
+    Args:
+        text: Text to paraphrase
+
+    Returns:
+        List of three paraphrased versions (short, medium, long)
+    """
+    responses = []
+    sampling_params_list = [
+        sampling_params_short,
+        sampling_params_medium,
+        sampling_params_long,
+    ]
+
+    for prompt_template, sampling_params in zip(PARAPHRASE_PROMPTS, sampling_params_list):
+        formatted_prompt = tokenizer.apply_chat_template(
+            [{"role": "user", "content": prompt_template.format(text=text)}],
+            tokenize=False,
+            add_generation_prompt=True,
+        )
+
+        output = model.fast_generate([formatted_prompt], sampling_params=sampling_params)
+        responses.append(output[0].outputs[0].text)
+
+    return responses
+
+
+# Paraphrase all chunks and add to vector store
+logger.info("Paraphrasing chunks and adding to vector store...")
+all_paraphrased = []
+chunk_ids = []
+
+for i, chunk in enumerate(chunks):
+    # Get paragraphs from chunk
+    paragraphs = [p.strip() for p in chunk.page_content.split("\n\n") if p.strip()]
+
+    for paragraph in paragraphs:
+        # Generate 3 paraphrased versions
+        paraphrased_versions = generate_paraphrases(paragraph)
+
+        # Save original paragraph ID for reference
+        for version in paraphrased_versions:
+            all_paraphrased.append({"chunk_id": i + 1, "original_paragraph": paragraph, "paraphrased_text": version})
+
+# Save paraphrased chunks to CSV for inspection
+paraphrased_df = pd.DataFrame(all_paraphrased)
+paraphrased_csv_path = DATA_DIR / "paragraphs_noise.csv"
+paraphrased_df.to_csv(paraphrased_csv_path, index=False)
+logger.info(f"Saved {len(all_paraphrased)} paraphrased paragraphs to {paraphrased_csv_path}")
+
+
+paraphrased_docs = [
+    Document(page_content=item["paraphrased_text"], metadata={"chunk_id": item["chunk_id"], "is_paraphrase": True})
+    for item in all_paraphrased
+]
+
+# Process embeddings in smaller batches to avoid OOM
+logger.info(f"Creating FAISS index with {len(paraphrased_docs)} documents in batches")
+batch_size = 100  # Process 100 documents at a time
+paraphrased_vectorstore = None
+
+for i in range(0, len(paraphrased_docs), batch_size):
+    batch = paraphrased_docs[i : i + batch_size]
+    logger.info(f"Processing batch {i // batch_size + 1}/{(len(paraphrased_docs) + batch_size - 1) // batch_size}")
+
+    # Create a new FAISS index for this batch
+    batch_vectorstore = FAISS.from_documents(batch, embeddings)
+
+    # Merge with existing index or create a new one
+    if paraphrased_vectorstore is None:
+        paraphrased_vectorstore = batch_vectorstore
+    else:
+        paraphrased_vectorstore.merge_from(batch_vectorstore)
+
+# Merge with main vectorstore
+if paraphrased_vectorstore is not None:
+    vectorstore.merge_from(paraphrased_vectorstore)
+    logger.info(f"Updated FAISS index with {len(paraphrased_docs)} paraphrased paragraphs")
+
+    # Save the updated vector store
+    vectorstore.save_local(str(DATA_DIR))
+    logger.info(f"Saved updated FAISS index to {DATA_DIR}")
+else:
+    logger.warning("No paraphrased documents were processed successfully")
+

 def batch_generate(prompts: list) -> list:
     """
@@ -148,9 +263,7 @@ def parse_multiple_qa_output(output: str) -> list:
     return qa_pairs


-def generate_question_batch_for_chunks(
-    chunks: list, num_questions: int = 2, difficulty=None
-) -> list:
+def generate_question_batch_for_chunks(chunks: list, num_questions: int = 2, difficulty=None) -> list:
     """
     Generates QA pairs for multiple chunks in batch.

@@ -233,9 +346,7 @@ def generate_question_batch_for_chunks(
                     results[idx] = valid_pairs[:num_questions]
                 else:
                     results[idx] = None
-                    logger.warning(
-                        f"Retry failed for chunk {idx + 1}: not enough valid QA pairs"
-                    )
+                    logger.warning(f"Retry failed for chunk {idx + 1}: not enough valid QA pairs")
             else:
                 results[idx] = None
                 logger.warning(f"Retry failed for chunk {idx + 1}: parsing failed")
@@ -245,13 +356,15 @@ def generate_question_batch_for_chunks(
     for i, qa_list in enumerate(results):
         if qa_list is not None:
             for qa in qa_list:
+                # Get supporting paragraphs by splitting chunk content into paragraphs
+                supporting_paragraphs = [p.strip() for p in chunk_contents[i].split("\n\n") if p.strip()]
+
                 final_questions.append(
                     {
-                        "chunk_id": chunk_ids[i],
+                        "id": str(chunk_ids[i]),
                         "question": qa[0],
                         "answer": qa[1],
-                        "difficulty": qa[2],
-                        "chunk_content": chunk_contents[i],
+                        "supporting_paragraphs": supporting_paragraphs,
                     }
                 )

@@ -260,13 +373,13 @@ def generate_question_batch_for_chunks(

 # Generate QA pairs in batch (using a sliding window over the chunks)
-all_questions = generate_question_batch_for_chunks(
-    chunks, num_questions=2, difficulty="medium"
-)
+logger.info("Generating question-answer pairs...")
+all_questions = generate_question_batch_for_chunks(chunks, num_questions=2, difficulty="medium")
 logger.info(f"Generated {len(all_questions)} QA pairs.")

-# Save the QA pairs to a JSON file
-questions_path = DATA_DIR / "questions.json"
+# Save the QA pairs to a JSONL file
+questions_path = DATA_DIR / "questions.jsonl"
 with open(questions_path, "w") as f:
-    json.dump(all_questions, f, indent=2)
+    for question in all_questions:
+        f.write(json.dumps(question) + "\n")
 logger.info(f"Saved questions to {questions_path}")

diff --git a/src/search_module.py b/src/search_module.py
index a8eba7b..0187ec5 100644
--- a/src/search_module.py
+++ b/src/search_module.py
@@ -12,7 +12,7 @@ from langchain_community.vectorstores import FAISS
 from config import DATA_DIR, logger
 from src.embeddings import CustomHuggingFaceEmbeddings

-PROCESSED_DATA_DIR = DATA_DIR / "processed"
+PROCESSED_DATA_DIR = DATA_DIR


 # Load pre-saved vectorstore

diff --git a/train_grpo.py b/train_grpo.py
index 1460638..42c4f21 100644
--- a/train_grpo.py
+++ b/train_grpo.py
@@ -2,8 +2,6 @@
 Train a model using GRPO (Generative Reward-Penalized Optimization).
""" -import os - from unsloth import FastLanguageModel, is_bfloat16_supported import src.UnslothGRPOTrainerTemp as UnslothGRPOTrainerTemp @@ -64,7 +62,7 @@ model = FastLanguageModel.get_peft_model( # Load datasets logger.info("Loading datasets") -train_dataset, test_dataset = get_qa_dataset(randomize=True, test_size=0, seed=42) +train_dataset, test_dataset = get_qa_dataset(randomize=False, test_size=0.1, seed=42) logger.info(f"Loaded {len(train_dataset)} training examples and {len(test_dataset)} test examples") # Setup training arguments @@ -76,7 +74,7 @@ training_args = UnslothGRPOTrainerTemp.UnslothGRPOConfig( bf16=is_bfloat16_supported(), fp16=not is_bfloat16_supported(), output_dir=OUTPUT_DIR, - reward_weights=[2.0, 1.0, 1.0, 1.0], + reward_weights=[4.0, 2.0, 1.0, 1.0, 1.0, 1.0], ) @@ -85,6 +83,7 @@ def agentic_generate( prompts: list, generate_fn, max_generations: int = 32, + max_new_tokens: int = 4096 * 2, ): # Create agent with appropriate adapter based on tokenizer tokenizer_name = tokenizer.name_or_path.lower() @@ -98,7 +97,7 @@ def agentic_generate( raise ValueError(f"Unsupported tokenizer: {tokenizer_name}") agent = Agent(adapter) - return agent.run_agent(generate_fn, tokenizer, prompts, max_generations) + return agent.run_agent(generate_fn, tokenizer, prompts, max_generations, max_new_tokens=max_new_tokens) model.agentic_generate = agentic_generate @@ -128,8 +127,8 @@ trainer = UnslothGRPOTrainerTemp.UnslothGRPOTrainer( reward_format, reward_retry, reward_em_chunk, - # reward_search_strategy, - # reward_search_diversity, + reward_search_strategy, + reward_search_diversity, ], args=training_args, train_dataset=train_dataset,