parent b22b02ea1d
commit 04d56325bb
@@ -0,0 +1,279 @@
"""
This script performs two main tasks:

1. It loads a markdown document, splits it into chunks, generates embeddings,
   and builds a FAISS index (which is saved locally).
2. It generates QA pairs from the document using Llama.

For each chunk, it generates multiple question-answer pairs with different
difficulties. The generation is performed in batch, with one retry for failed
prompts. Successfully generated QA pairs are saved to "saved_data/questions.json".

Requirements:
    pip install langchain langchain-community unstructured faiss-cpu pandas loguru unsloth vllm
"""

import json
import os
import re
from typing import Dict, List, Optional, Tuple

import pandas as pd
from langchain.text_splitter import RecursiveCharacterTextSplitter
from loguru import logger

# Configure logger
logger.add(
    "logs/generate_data_{time}.log",
    rotation="500 MB",
    retention="10 days",
    level="INFO",
)

# ========= Part 1: Document Processing and Embedding Generation =========
# Load and split the markdown document using LangChain
from langchain_community.document_loaders import UnstructuredMarkdownLoader
from langchain_community.vectorstores import FAISS

from embeddings import CustomHuggingFaceEmbeddings

# Load your markdown file (adjust the path as needed)
loader = UnstructuredMarkdownLoader("./data/mission_report.md")
docs = loader.load()

# Split the document into smaller chunks (each 1000 characters, no overlap)
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
chunks = text_splitter.split_documents(docs)

# Create output directory
os.makedirs("saved_data", exist_ok=True)

# Save chunks to CSV for easy inspection
chunks_df = pd.DataFrame(
    {
        "chunk_id": range(1, len(chunks) + 1),
        "content": [chunk.page_content for chunk in chunks],
        "metadata": [chunk.metadata for chunk in chunks],
    }
)
chunks_df.to_csv("saved_data/chunks.csv", index=False)
print(f"Saved {len(chunks)} chunks to saved_data/chunks.csv")

embeddings = CustomHuggingFaceEmbeddings()

# Create a FAISS vector store from the document chunks and save it locally
vectorstore = FAISS.from_documents(chunks, embeddings)
vectorstore.save_local("faiss_index")
print("Saved FAISS index to 'faiss_index'")

# TODO: add the paraphrased chunks to the vector store

# ========= Part 2: QA Generation using Llama Backend =========

# Setup Llama backend via unsloth and vLLM
from unsloth import FastLanguageModel
from vllm import SamplingParams

import rl_helpers  # Ensure you have this or remove if not used

# Load the Llama model (adjust parameters as needed)
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="meta-llama/meta-Llama-3.1-8B-Instruct",
    max_seq_length=4096,
    load_in_4bit=True,  # Use 4-bit quantization if desired
    fast_inference=True,  # Enable fast inference
    gpu_memory_utilization=0.6,  # Adjust based on your GPU memory
)

# Define sampling parameters for generation
sampling_params = SamplingParams(
    temperature=0.3,
    top_p=0.95,
    max_tokens=4096,
)


def batch_generate(prompts: List[str]) -> List[str]:
    """
    Given a list of prompt strings, returns a list of generated outputs.
    """

    def format_input(text: str) -> str:
        return tokenizer.apply_chat_template(
            [{"role": "user", "content": text}],
            tokenize=False,
            add_generation_prompt=True,
        )

    formatted = [format_input(p) for p in prompts]
    outputs = model.fast_generate(formatted, sampling_params=sampling_params)
    return [output.outputs[0].text for output in outputs]
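
# Usage (illustrative): each prompt is wrapped in the model's chat template
# before generation, so callers pass plain user messages, e.g.
#   replies = batch_generate(["Summarize the mission report in one sentence."])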


def parse_qa_block(block: str) -> Optional[Tuple[str, str, str]]:
    """
    Parses a QA block that should contain exactly three non-empty lines:
      - A line starting with "Question:"
      - A line starting with "Answer:"
      - A line starting with "Difficulty:"

    If the markers are not present but the block contains exactly three lines,
    those are used in order.

    Returns a tuple (question, answer, difficulty), or None if parsing fails.
    """
    lines = [line.strip() for line in block.splitlines() if line.strip()]
    if not lines:
        return None

    question, answer, difficulty = None, None, None
    for line in lines:
        lower = line.lower()
        if question is None and lower.startswith("question:"):
            question = line[len("question:") :].strip()
        elif answer is None and lower.startswith("answer:"):
            answer = line[len("answer:") :].strip()
        elif difficulty is None and lower.startswith("difficulty:"):
            difficulty = line[len("difficulty:") :].strip()

    if question and answer and difficulty:
        return question, answer, difficulty
    if len(lines) == 3:
        return lines[0], lines[1], lines[2]
    return None


def parse_multiple_qa_output(output: str) -> List[Tuple[str, str, str]]:
    """
    Splits the output into blocks (separated by one or more blank lines) and
    attempts to parse each as a QA pair.

    Returns a list of successfully parsed QA tuples.
    """
    blocks = re.split(r"\n\s*\n", output.strip())
    qa_pairs = []
    for block in blocks:
        parsed = parse_qa_block(block)
        if parsed:
            qa_pairs.append(parsed)
    return qa_pairs
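
# Example (illustrative input/output): a well-formed reply parses to one
# tuple per blank-line-separated block:
#   parse_multiple_qa_output(
#       "Question: Who wrote the report?\nAnswer: NASA.\nDifficulty: easy"
#   )
#   -> [("Who wrote the report?", "NASA.", "easy")]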


def generate_question_batch_for_chunks(
    chunks: List, num_questions: int = 2, difficulty: Optional[str] = None
) -> List[Dict]:
    """
    Generates QA pairs for multiple chunks in batch.

    For each chunk, generates questions based on its content only.
    Each prompt instructs the model to output exactly three lines per QA pair
    with markers. Failed prompts are retried once in batch; if still
    unsuccessful, they are skipped.

    Note: the `difficulty` argument is accepted for API compatibility but is
    not currently used; the model chooses a difficulty per question.

    Returns a list of dicts with keys: "chunk_id", "question", "answer",
    "difficulty", "chunk_content".
    """
    prompts = []
    chunk_ids = []
    chunk_contents = []

    # Prepare one prompt per chunk
    for i, chunk in enumerate(chunks):
        current = chunk.page_content
        prompt = (
            f"You are a question generator. Generate {num_questions} questions based on the following text.\n"
            "Rules:\n"
            "1. Questions must be answerable using ONLY the information in the text\n"
            "2. Answers must be directly stated in the text\n"
            "3. Each question should test understanding of a different aspect of the text\n"
            "4. Questions should be clear and specific\n"
            "5. Answers should be concise and factual\n\n"
            "For each QA pair, output exactly three lines with no extra commentary:\n"
            "Line 1: Question: <your question>\n"
            "Line 2: Answer: <the answer>\n"
            "Line 3: Difficulty: <easy, medium, or hard>\n"
            "Do not include any additional text.\n\n"
            "Text:\n"
            f"{current}\n"
        )
        prompts.append(prompt)
        chunk_ids.append(i + 1)  # 1-based indexing
        chunk_contents.append(current)

    # First batch generation
    outputs = batch_generate(prompts)
    results: List[Optional[List[Tuple[str, str, str]]]] = [None] * len(outputs)
    failed_indices = []

    # Parse each output (parse_multiple_qa_output always returns a list)
    for idx, output in enumerate(outputs):
        qa_pairs = parse_multiple_qa_output(output)
        if len(qa_pairs) < num_questions:
            failed_indices.append(idx)
            logger.warning(f"Failed to generate enough QA pairs for chunk {idx + 1}")
        else:
            # Validate that answers appear verbatim in the chunk content
            valid_pairs = []
            for q, a, d in qa_pairs:
                if a.lower() in chunk_contents[idx].lower():
                    valid_pairs.append((q, a, d))
                else:
                    logger.warning(f"Answer not found in chunk content: {a}")

            if len(valid_pairs) >= num_questions:
                results[idx] = valid_pairs[:num_questions]
            else:
                failed_indices.append(idx)
                logger.warning(f"Not enough valid QA pairs for chunk {idx + 1}")

    # Retry failed prompts in batch
    if failed_indices:
        logger.info(f"Retrying {len(failed_indices)} failed prompt(s)...")
        retry_prompts = [prompts[i] for i in failed_indices]
        retry_outputs = batch_generate(retry_prompts)
        for j, idx in enumerate(failed_indices):
            qa_pairs = parse_multiple_qa_output(retry_outputs[j])
            if len(qa_pairs) >= num_questions:
                # Validate answers again
                valid_pairs = []
                for q, a, d in qa_pairs:
                    if a.lower() in chunk_contents[idx].lower():
                        valid_pairs.append((q, a, d))

                if len(valid_pairs) >= num_questions:
                    results[idx] = valid_pairs[:num_questions]
                else:
                    results[idx] = None
                    logger.warning(
                        f"Retry failed for chunk {idx + 1}: not enough valid QA pairs"
                    )
            else:
                results[idx] = None
                logger.warning(f"Retry failed for chunk {idx + 1}: parsing failed")

    # Build final output, skipping prompts that failed even after retry
    final_questions = []
    for i, qa_list in enumerate(results):
        if qa_list is not None:
            for qa in qa_list:
                final_questions.append(
                    {
                        "chunk_id": chunk_ids[i],
                        "question": qa[0],
                        "answer": qa[1],
                        "difficulty": qa[2],
                        "chunk_content": chunk_contents[i],
                    }
                )

    logger.info(f"Generated {len(final_questions)} valid QA pairs")
    return final_questions


# Generate QA pairs in batch (one prompt per chunk)
all_questions = generate_question_batch_for_chunks(
    chunks, num_questions=2, difficulty="medium"
)
print(f"Generated {len(all_questions)} QA pairs.")

# Save the QA pairs to a JSON file
questions_path = os.path.join("saved_data", "questions.json")
with open(questions_path, "w") as f:
    json.dump(all_questions, f, indent=2)
print(f"Saved questions to {questions_path}")
@@ -0,0 +1,113 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install matplotlib -q"
   ]
  },
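  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A note on the reward shapes plotted below (a sketch of the intent, inferred from the code rather than stated in it): with $x = r - 4$ for retry count $r$, the modified curve is\n",
    "\n",
    "$$\\mathrm{reward}(r) = \\frac{1}{1 + e^{-x + |x|/2}},$$\n",
    "\n",
    "which rises steeply up to about 4 retries and flattens afterwards; past 6 retries a linear penalty of $0.2\\,(r - 6)$ is subtracted, floored at a reward of $0.1$."
   ]
  },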
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "def plot_reward_functions():\n",
    "    # Generate retry counts from 0 to 15\n",
    "    retries = np.linspace(0, 15, 100)\n",
    "\n",
    "    # 1. Basic Sigmoid\n",
    "    basic_sigmoid = 1 / (1 + np.exp(-(retries - 4)))\n",
    "\n",
    "    # 2. Our Modified Sigmoid\n",
    "    x = retries - 4  # Center at 4 retries\n",
    "    modified_sigmoid = 1 / (1 + np.exp(-x + abs(x)/2))\n",
    "\n",
    "    # 3. With Penalty\n",
    "    penalized_reward = modified_sigmoid.copy()\n",
    "    for i, r in enumerate(retries):\n",
    "        if r > 6:\n",
    "            penalty = 0.2 * (r - 6)\n",
    "            penalized_reward[i] = max(0.1, modified_sigmoid[i] - penalty)\n",
    "\n",
    "    # Plotting\n",
    "    plt.figure(figsize=(12, 6))\n",
    "\n",
    "    plt.plot(retries, basic_sigmoid, 'b--', label='Basic Sigmoid')\n",
    "    plt.plot(retries, modified_sigmoid, 'g--', label='Modified Sigmoid')\n",
    "    plt.plot(retries, penalized_reward, 'r-', label='Final Reward (with penalty)', linewidth=2)\n",
    "\n",
    "    # Add vertical lines for key points\n",
    "    plt.axvline(x=4, color='gray', linestyle=':', alpha=0.5, label='Peak (4 retries)')\n",
    "    plt.axvline(x=6, color='gray', linestyle=':', alpha=0.5, label='Penalty Start (6 retries)')\n",
    "\n",
    "    plt.grid(True, alpha=0.3)\n",
    "    plt.xlabel('Number of Retries')\n",
    "    plt.ylabel('Reward')\n",
    "    plt.title('Reward Function Visualization')\n",
    "    plt.legend()\n",
    "    plt.ylim(-0.1, 1.1)\n",
    "\n",
    "    # Add annotations\n",
    "    plt.annotate('Optimal Zone', xy=(4, 0.8), xytext=(4, 0.9),\n",
    "                 ha='center', va='bottom',\n",
    "                 bbox=dict(boxstyle='round,pad=0.5', fc='yellow', alpha=0.3),\n",
    "                 arrowprops=dict(arrowstyle='->'))\n",
    "\n",
    "    plt.annotate('Penalty Zone', xy=(8, 0.3), xytext=(8, 0.5),\n",
    "                 ha='center', va='bottom',\n",
    "                 bbox=dict(boxstyle='round,pad=0.5', fc='red', alpha=0.3),\n",
    "                 arrowprops=dict(arrowstyle='->'))\n",
    "\n",
    "    plt.show()\n",
    "\n",
    "# Run the visualization\n",
    "plot_reward_functions()\n",
    "\n",
    "# Print reward values for specific retry counts\n",
    "def print_reward_examples():\n",
    "    retry_examples = [1, 2, 3, 4, 5, 6, 7, 8, 10, 12]\n",
    "    print(\"\\nReward values for different retry counts:\")\n",
    "    print(\"Retries | Reward\")\n",
    "    print(\"-\" * 20)\n",
    "\n",
    "    for retries in retry_examples:\n",
    "        x = retries - 4\n",
    "        reward = 1 / (1 + np.exp(-x + abs(x)/2))\n",
    "        if retries > 6:\n",
    "            penalty = 0.2 * (retries - 6)\n",
    "            reward = max(0.1, reward - penalty)\n",
    "        print(f\"{retries:7d} | {reward:.3f}\")\n",
    "\n",
    "print_reward_examples()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}