"""
|
|
Search module for RL training loop.
|
|
This module provides functions to search through vectorized documents and retrieve question-answer pairs.
|
|
"""
import json
import random

from datasets import Dataset
from langchain_community.vectorstores import FAISS

from config import DATA_DIR, logger
from src.embeddings import CustomHuggingFaceEmbeddings

PROCESSED_DATA_DIR = DATA_DIR


# Load pre-saved vectorstore
def load_vectorstore():
    """Load the pre-saved FAISS index"""
    try:
        embeddings = CustomHuggingFaceEmbeddings()
        # Load the FAISS index from the data directory
        logger.info(f"Loading FAISS index from: {PROCESSED_DATA_DIR}")
        vectorstore = FAISS.load_local(
            str(PROCESSED_DATA_DIR), embeddings, allow_dangerous_deserialization=True
        )
        logger.info("Successfully loaded FAISS index")
        return vectorstore
    except Exception as e:
        logger.error(f"Error loading vectorstore: {e}")
        import traceback

        logger.debug(traceback.format_exc())
        return None


# Load the vectorstore when module is imported
try:
    vectorstore = load_vectorstore()
    if vectorstore is None:
        logger.warning("FAISS vectorstore could not be loaded.")
except Exception as e:
    logger.error(f"Error loading vectorstore: {e}")
    vectorstore = None
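
# Note: loading happens at import time, so `vectorstore` may be None if the
# index is missing; search() below raises ValueError in that case.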


def search(query: str, return_type=str, results: int = 5):
    """
    Search for relevant chunks using similarity search.

    Args:
        query: The search query
        return_type: Return as string or list (default: str)
        results: Number of results to return (default: 5)

    Returns:
        Results as string or list depending on return_type
    """
    if vectorstore is None:
        raise ValueError("Vectorstore not loaded. Please ensure FAISS index exists.")

    search_results = vectorstore.similarity_search(query, k=results)

    if return_type == str:
        str_results = ""
        for idx, result in enumerate(search_results, start=1):
            str_results += f"Result {idx}:\n"
            str_results += result.page_content + "\n"
            str_results += "------\n"
        return str_results
    elif return_type == list:
        return [result.page_content for result in search_results]
    else:
        raise ValueError("Invalid return_type. Use str or list.")
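
# Example usage (a sketch; the query string is illustrative):
#   context = search("What is the reward function?", results=3)
#   chunks = search("What is the reward function?", return_type=list, results=3)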


# Load questions from saved data
def load_qa_data(questions_path=None):
    """
    Load the pre-generated questions.

    Args:
        questions_path: Path to questions file (default: PROCESSED_DATA_DIR / "questions.jsonl")

    Returns:
        List of question-answer pairs, or None if loading fails
    """
    try:
        if questions_path is None:
            questions_path = PROCESSED_DATA_DIR / "questions.jsonl"

        logger.info(f"Loading questions from: {questions_path}")

        # Load the questions
        with open(questions_path, "r") as f:
            questions = [json.loads(line) for line in f]

        logger.info(f"Successfully loaded {len(questions)} questions")
        return questions
    except Exception as e:
        logger.error(f"Error loading QA data: {e}")
        import traceback

        logger.debug(traceback.format_exc())
        return None
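
# Each line of questions.jsonl is one JSON object; a hypothetical record:
#   {"question": "...", "answer": "...", "supporting_paragraphs": ["..."]}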


# Load questions when module is imported
try:
    questions = load_qa_data()
    if questions is None:
        logger.warning("Could not load QA data.")
except Exception as e:
    logger.error(f"Error initializing QA data: {e}")
    questions = None


def get_question_answer(idx=None, return_both: bool = True):
    """
    Get a question-answer pair either by index or randomly.

    Args:
        idx: Index of the question to retrieve (if None, selects a random question)
        return_both: Whether to return both question and answer (default: True)

    Returns:
        A dict with "question" and "answer" keys if return_both=True,
        otherwise just the question string
    """
    if questions is None:
        raise ValueError("Questions not loaded. Please ensure questions.jsonl exists.")

    if idx is None:
        # Select a random question
        qa_pair = random.choice(questions)
    elif 0 <= idx < len(questions):
        # Select question by index
        qa_pair = questions[idx]
    else:
        raise ValueError(f"Index out of range. Must be between 0 and {len(questions) - 1}")

    question = qa_pair["question"]
    answer = qa_pair["answer"]

    if return_both:
        return {"question": question, "answer": answer}
    else:
        return question
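
# Example usage (a sketch; assumes questions were loaded at import time):
#   qa = get_question_answer(idx=0)                  # {"question": ..., "answer": ...}
#   prompt = get_question_answer(return_both=False)  # a random question string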


# Function to get the total number of questions
def get_question_count() -> int:
    """Get the total number of available questions"""
    if questions is None:
        raise ValueError("Questions not loaded. Please ensure questions.jsonl exists.")
    return len(questions)


def get_qa_dataset(randomize: bool = False, test_size: float = 0.1, seed: int = 42, questions_path=None) -> tuple:
    """
    Return train and test HuggingFace Datasets of question-answer pairs.

    The datasets are constructed from the loaded questions data.
    Each element is a dictionary that includes at least:
      - "question": The question text (renamed to "prompt" in the returned datasets).
      - "answer": The corresponding answer text.
      - "supporting_paragraphs": The supporting paragraphs for the question.
    Additional keys present in the original questions data will also be included.

    Args:
        randomize: Whether to shuffle the dataset
        test_size: Proportion of the dataset to include in the test split (0 for train-only)
        seed: Random seed for reproducibility
        questions_path: Path to a questions.jsonl file (if None, uses the globally loaded questions)

    Returns:
        A tuple of (train_dataset, test_dataset) HuggingFace Dataset objects.
        If test_size=0, test_dataset will be empty; if test_size=1, train_dataset will be empty.
    """
    qa_data = questions

    if questions_path is not None:
        qa_data = load_qa_data(questions_path)

    if qa_data is None:
        raise ValueError("Questions not loaded. Please ensure questions.jsonl exists.")

    qa_dataset = Dataset.from_list(qa_data)
    if randomize:
        qa_dataset = qa_dataset.shuffle(seed=seed)

    # Empty dataset for when the train or test split is empty
    empty_dataset = Dataset.from_list([])

    if test_size <= 0:
        # Only a train dataset; empty test dataset
        train_dataset = qa_dataset.rename_column("question", "prompt")
        return train_dataset, empty_dataset
    elif test_size >= 1:
        # Only a test dataset; empty train dataset
        test_dataset = qa_dataset.rename_column("question", "prompt")
        return empty_dataset, test_dataset
    else:
        # Both train and test datasets
        split = qa_dataset.train_test_split(test_size=test_size, seed=seed)
        train_dataset = split["train"]
        test_dataset = split["test"]

        # Rename the "question" column to "prompt"
        train_dataset = train_dataset.rename_column("question", "prompt")
        test_dataset = test_dataset.rename_column("question", "prompt")
        return train_dataset, test_dataset
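

if __name__ == "__main__":
    # Minimal smoke test (a sketch; assumes the FAISS index and questions.jsonl
    # exist under DATA_DIR; both are project-specific paths).
    print(f"Loaded {get_question_count()} questions")
    qa = get_question_answer(idx=0)
    print("Q:", qa["question"])
    print("A:", qa["answer"])
    print(search(qa["question"], results=2))
    train_ds, test_ds = get_qa_dataset(test_size=0.1)
    print(f"Train size: {len(train_ds)}, test size: {len(test_ds)}")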