parent
b22b02ea1d
commit
04d56325bb
@ -0,0 +1,279 @@
|
||||
"""
|
||||
This script performs two main tasks:
|
||||
1. It loads a markdown document, splits it into chunks, generates embeddings,
|
||||
and builds a FAISS index (which is saved locally).
|
||||
2. It generates QA pairs from the document using llama.
|
||||
For each chunk (using a sliding window for context), it generates multiple question-answer pairs
|
||||
with different difficulties. The generation is performed in batch with one retry for failed prompts.
|
||||
Successfully generated QA pairs are saved to "saved_data/questions.json".
|
||||
|
||||
Requirements:
|
||||
pip install langchain faiss-cpu unsloth vllm
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import pandas as pd
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from loguru import logger
|
||||
|
||||
# Configure logger
|
||||
logger.add(
|
||||
"logs/generate_data_{time}.log",
|
||||
rotation="500 MB",
|
||||
retention="10 days",
|
||||
level="INFO",
|
||||
)
|
||||
|
||||
# ========= Part 1: Document Processing and Embedding Generation =========
|
||||
# Load and split the markdown document using LangChain
|
||||
from langchain_community.document_loaders import UnstructuredMarkdownLoader
|
||||
from langchain_community.vectorstores import FAISS
|
||||
|
||||
from embeddings import CustomHuggingFaceEmbeddings
|
||||
|
||||
# Load your markdown file (adjust the path as needed)
|
||||
loader = UnstructuredMarkdownLoader("./data/mission_report.md")
|
||||
docs = loader.load()
|
||||
|
||||
# Split the document into smaller chunks (each 1000 characters, no overlap)
|
||||
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
|
||||
chunks = text_splitter.split_documents(docs)
|
||||
|
||||
# Create output directory
|
||||
os.makedirs("saved_data", exist_ok=True)
|
||||
|
||||
# Save chunks to CSV for easy inspection
|
||||
chunks_df = pd.DataFrame(
|
||||
{
|
||||
"chunk_id": range(1, len(chunks) + 1),
|
||||
"content": [chunk.page_content for chunk in chunks],
|
||||
"metadata": [chunk.metadata for chunk in chunks],
|
||||
}
|
||||
)
|
||||
chunks_df.to_csv("saved_data/chunks.csv", index=False)
|
||||
print(f"Saved {len(chunks)} chunks to saved_data/chunks.csv")
|
||||
|
||||
embeddings = CustomHuggingFaceEmbeddings()
|
||||
|
||||
# Create a FAISS vector store from the document chunks and save it locally
|
||||
vectorstore = FAISS.from_documents(chunks, embeddings)
|
||||
vectorstore.save_local("faiss_index")
|
||||
print("Saved FAISS index to 'faiss_index'")
|
||||
|
||||
# TODO: add the paraphrased chunks to the vector store
|
||||
|
||||
# ========= Part 2: QA Generation using Llama Backend =========
|
||||
|
||||
# Setup Llama backend via unsloth and vLLM
|
||||
from unsloth import FastLanguageModel
|
||||
from vllm import SamplingParams
|
||||
|
||||
import rl_helpers # Ensure you have this or remove if not used
|
||||
|
||||
# Load the Llama model (adjust parameters as needed)
|
||||
model, tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name="meta-llama/meta-Llama-3.1-8B-Instruct",
|
||||
max_seq_length=4096,
|
||||
load_in_4bit=True, # Use 4-bit quantization if desired
|
||||
fast_inference=True, # Enable fast inference
|
||||
gpu_memory_utilization=0.6, # Adjust based on your GPU memory
|
||||
)
|
||||
|
||||
# Define sampling parameters for generation
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0.3,
|
||||
top_p=0.95,
|
||||
max_tokens=4096,
|
||||
)
|
||||
|
||||
|
||||
def batch_generate(prompts: List[str]) -> List[str]:
|
||||
"""
|
||||
Given a list of prompt strings, returns a list of generated outputs.
|
||||
"""
|
||||
|
||||
def format_input(text: str) -> str:
|
||||
return tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": text}],
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
|
||||
formatted = [format_input(p) for p in prompts]
|
||||
outputs = model.fast_generate(formatted, sampling_params=sampling_params)
|
||||
return [output.outputs[0].text for output in outputs]
|
||||
|
||||
|
||||
def parse_qa_block(block: str) -> Optional[Tuple[str, str, str]]:
|
||||
"""
|
||||
Parses a QA block that should contain exactly three non-empty lines:
|
||||
- A line starting with "Question:"
|
||||
- A line starting with "Answer:"
|
||||
- A line starting with "Difficulty:"
|
||||
|
||||
If the markers are not present but the block contains exactly three lines,
|
||||
those are used in order.
|
||||
|
||||
Returns a tuple (question, answer, difficulty) or None if parsing fails.
|
||||
"""
|
||||
lines = [line.strip() for line in block.splitlines() if line.strip()]
|
||||
if not lines:
|
||||
return None
|
||||
|
||||
question, answer, difficulty = None, None, None
|
||||
for line in lines:
|
||||
lower = line.lower()
|
||||
if question is None and lower.startswith("question:"):
|
||||
question = line[len("question:") :].strip()
|
||||
elif answer is None and lower.startswith("answer:"):
|
||||
answer = line[len("answer:") :].strip()
|
||||
elif difficulty is None and lower.startswith("difficulty:"):
|
||||
difficulty = line[len("difficulty:") :].strip()
|
||||
|
||||
if question and answer and difficulty:
|
||||
return question, answer, difficulty
|
||||
if len(lines) == 3:
|
||||
return lines[0], lines[1], lines[2]
|
||||
return None
|
||||
|
||||
|
||||
def parse_multiple_qa_output(output: str) -> List[Tuple[str, str, str]]:
|
||||
"""
|
||||
Splits the output into blocks (separated by one or more blank lines) and
|
||||
attempts to parse each as a QA pair.
|
||||
|
||||
Returns a list of successfully parsed QA tuples.
|
||||
"""
|
||||
blocks = re.split(r"\n\s*\n", output.strip())
|
||||
qa_pairs = []
|
||||
for block in blocks:
|
||||
parsed = parse_qa_block(block)
|
||||
if parsed:
|
||||
qa_pairs.append(parsed)
|
||||
return qa_pairs
|
||||
|
||||
|
||||
def generate_question_batch_for_chunks(
|
||||
chunks: List, num_questions: int = 2, difficulty: Optional[str] = None
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Generates QA pairs for multiple chunks in batch.
|
||||
|
||||
For each chunk, generates questions based on its content only.
|
||||
Each prompt instructs the model to output exactly three lines per QA pair with markers.
|
||||
Failed prompts are retried once in batch; if still unsuccessful, they are skipped.
|
||||
|
||||
Returns a list of dicts with keys: "chunk_id", "question", "answer", "difficulty", "chunk_content".
|
||||
"""
|
||||
prompts = []
|
||||
chunk_ids = []
|
||||
chunk_contents = []
|
||||
|
||||
# Prepare prompts for each chunk
|
||||
for i, chunk in enumerate(chunks):
|
||||
current = chunk.page_content
|
||||
prompt = (
|
||||
f"You are a question generator. Generate {num_questions} questions based on the following text.\n"
|
||||
"Rules:\n"
|
||||
"1. Questions must be answerable using ONLY the information in the text\n"
|
||||
"2. Answers must be directly stated in the text\n"
|
||||
"3. Each question should test understanding of a different aspect of the text\n"
|
||||
"4. Questions should be clear and specific\n"
|
||||
"5. Answers should be concise and factual\n\n"
|
||||
"For each QA pair, output exactly three lines with no extra commentary:\n"
|
||||
"Line 1: Question: <your question>\n"
|
||||
"Line 2: Answer: <the answer>\n"
|
||||
"Line 3: Difficulty: <easy, medium, or hard>\n"
|
||||
"Do not include any additional text.\n\n"
|
||||
"Text:\n"
|
||||
f"{current}\n"
|
||||
)
|
||||
prompts.append(prompt)
|
||||
chunk_ids.append(i + 1) # 1-based indexing
|
||||
chunk_contents.append(current)
|
||||
|
||||
# First batch generation
|
||||
outputs = batch_generate(prompts)
|
||||
results: List[Optional[List[Tuple[str, str, str]]]] = [None] * len(outputs)
|
||||
failed_indices = []
|
||||
|
||||
# Parse each output
|
||||
for idx, output in enumerate(outputs):
|
||||
qa_pairs = parse_multiple_qa_output(output)
|
||||
if qa_pairs is None or len(qa_pairs) < num_questions:
|
||||
failed_indices.append(idx)
|
||||
logger.warning(f"Failed to generate enough QA pairs for chunk {idx + 1}")
|
||||
else:
|
||||
# Validate that answers exist in chunk content
|
||||
valid_pairs = []
|
||||
for q, a, d in qa_pairs:
|
||||
if a.lower() in chunk_contents[idx].lower():
|
||||
valid_pairs.append((q, a, d))
|
||||
else:
|
||||
logger.warning(f"Answer not found in chunk content: {a}")
|
||||
|
||||
if len(valid_pairs) >= num_questions:
|
||||
results[idx] = valid_pairs[:num_questions]
|
||||
else:
|
||||
failed_indices.append(idx)
|
||||
logger.warning(f"Not enough valid QA pairs for chunk {idx + 1}")
|
||||
|
||||
# Retry failed prompts in batch
|
||||
if failed_indices:
|
||||
logger.info(f"Retrying {len(failed_indices)} failed prompt(s)...")
|
||||
retry_prompts = [prompts[i] for i in failed_indices]
|
||||
retry_outputs = batch_generate(retry_prompts)
|
||||
for j, idx in enumerate(failed_indices):
|
||||
qa_pairs = parse_multiple_qa_output(retry_outputs[j])
|
||||
if qa_pairs is not None and len(qa_pairs) >= num_questions:
|
||||
# Validate answers again
|
||||
valid_pairs = []
|
||||
for q, a, d in qa_pairs:
|
||||
if a.lower() in chunk_contents[idx].lower():
|
||||
valid_pairs.append((q, a, d))
|
||||
|
||||
if len(valid_pairs) >= num_questions:
|
||||
results[idx] = valid_pairs[:num_questions]
|
||||
else:
|
||||
results[idx] = None
|
||||
logger.warning(
|
||||
f"Retry failed for chunk {idx + 1}: not enough valid QA pairs"
|
||||
)
|
||||
else:
|
||||
results[idx] = None
|
||||
logger.warning(f"Retry failed for chunk {idx + 1}: parsing failed")
|
||||
|
||||
# Build final output, skipping prompts that failed even after retry
|
||||
final_questions = []
|
||||
for i, qa_list in enumerate(results):
|
||||
if qa_list is not None:
|
||||
for qa in qa_list:
|
||||
final_questions.append(
|
||||
{
|
||||
"chunk_id": chunk_ids[i],
|
||||
"question": qa[0],
|
||||
"answer": qa[1],
|
||||
"difficulty": qa[2],
|
||||
"chunk_content": chunk_contents[i],
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"Generated {len(final_questions)} valid QA pairs")
|
||||
return final_questions
|
||||
|
||||
|
||||
# Generate QA pairs in batch (using a sliding window over the chunks)
|
||||
all_questions = generate_question_batch_for_chunks(
|
||||
chunks, num_questions=2, difficulty="medium"
|
||||
)
|
||||
print(f"Generated {len(all_questions)} QA pairs.")
|
||||
|
||||
# Save the QA pairs to a JSON file
|
||||
questions_path = os.path.join("saved_data", "questions.json")
|
||||
with open(questions_path, "w") as f:
|
||||
json.dump(all_questions, f, indent=2)
|
||||
print(f"Saved questions to {questions_path}")
|
@ -0,0 +1,113 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install matplotlib -q"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"\n",
|
||||
"def plot_reward_functions():\n",
|
||||
" # Generate retry counts from 0 to 15\n",
|
||||
" retries = np.linspace(0, 15, 100)\n",
|
||||
" \n",
|
||||
" # 1. Basic Sigmoid\n",
|
||||
" basic_sigmoid = 1 / (1 + np.exp(-(retries - 4)))\n",
|
||||
" \n",
|
||||
" # 2. Our Modified Sigmoid\n",
|
||||
" x = retries - 4 # Center at 4 retries\n",
|
||||
" modified_sigmoid = 1 / (1 + np.exp(-x + abs(x)/2))\n",
|
||||
" \n",
|
||||
" # 3. With Penalty\n",
|
||||
" penalized_reward = modified_sigmoid.copy()\n",
|
||||
" for i, r in enumerate(retries):\n",
|
||||
" if r > 6:\n",
|
||||
" penalty = 0.2 * (r - 6)\n",
|
||||
" penalized_reward[i] = max(0.1, modified_sigmoid[i] - penalty)\n",
|
||||
" \n",
|
||||
" # Plotting\n",
|
||||
" plt.figure(figsize=(12, 6))\n",
|
||||
" \n",
|
||||
" plt.plot(retries, basic_sigmoid, 'b--', label='Basic Sigmoid')\n",
|
||||
" plt.plot(retries, modified_sigmoid, 'g--', label='Modified Sigmoid')\n",
|
||||
" plt.plot(retries, penalized_reward, 'r-', label='Final Reward (with penalty)', linewidth=2)\n",
|
||||
" \n",
|
||||
" # Add vertical lines for key points\n",
|
||||
" plt.axvline(x=4, color='gray', linestyle=':', alpha=0.5, label='Peak (4 retries)')\n",
|
||||
" plt.axvline(x=6, color='gray', linestyle=':', alpha=0.5, label='Penalty Start (6 retries)')\n",
|
||||
" \n",
|
||||
" plt.grid(True, alpha=0.3)\n",
|
||||
" plt.xlabel('Number of Retries')\n",
|
||||
" plt.ylabel('Reward')\n",
|
||||
" plt.title('Reward Function Visualization')\n",
|
||||
" plt.legend()\n",
|
||||
" plt.ylim(-0.1, 1.1)\n",
|
||||
" \n",
|
||||
" # Add annotations\n",
|
||||
" plt.annotate('Optimal Zone', xy=(4, 0.8), xytext=(4, 0.9),\n",
|
||||
" ha='center', va='bottom',\n",
|
||||
" bbox=dict(boxstyle='round,pad=0.5', fc='yellow', alpha=0.3),\n",
|
||||
" arrowprops=dict(arrowstyle='->'))\n",
|
||||
" \n",
|
||||
" plt.annotate('Penalty Zone', xy=(8, 0.3), xytext=(8, 0.5),\n",
|
||||
" ha='center', va='bottom',\n",
|
||||
" bbox=dict(boxstyle='round,pad=0.5', fc='red', alpha=0.3),\n",
|
||||
" arrowprops=dict(arrowstyle='->'))\n",
|
||||
" \n",
|
||||
" plt.show()\n",
|
||||
"\n",
|
||||
"# Run the visualization\n",
|
||||
"plot_reward_functions()\n",
|
||||
"\n",
|
||||
"# Print reward values for specific retry counts\n",
|
||||
"def print_reward_examples():\n",
|
||||
" retry_examples = [1, 2, 3, 4, 5, 6, 7, 8, 10, 12]\n",
|
||||
" print(\"\\nReward values for different retry counts:\")\n",
|
||||
" print(\"Retries | Reward\")\n",
|
||||
" print(\"-\" * 20)\n",
|
||||
" \n",
|
||||
" for retries in retry_examples:\n",
|
||||
" x = retries - 4\n",
|
||||
" reward = 1 / (1 + np.exp(-x + abs(x)/2))\n",
|
||||
" if retries > 6:\n",
|
||||
" penalty = 0.2 * (retries - 6)\n",
|
||||
" reward = max(0.1, reward - penalty)\n",
|
||||
" print(f\"{retries:7d} | {reward:.3f}\")\n",
|
||||
"\n",
|
||||
"print_reward_examples()"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.11"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
Loading…
Reference in new issue