"""
Search module for RL training loop.

This module provides functions to search through vectorized documents and
retrieve question-answer pairs.
"""

import json
import random

from datasets import Dataset
from langchain.vectorstores import FAISS

from src.config import DATA_DIR, logger
from src.embeddings import CustomHuggingFaceEmbeddings


# Load pre-saved vectorstore
def load_vectorstore():
    """Load the pre-saved FAISS index from DATA_DIR.

    Returns:
        The loaded FAISS vectorstore, or None if loading failed.
    """
    try:
        embedder = CustomHuggingFaceEmbeddings()
        logger.info(f"Loading FAISS index from: {DATA_DIR}")
        # allow_dangerous_deserialization is needed because the index is
        # pickled on disk; acceptable here since we produced the file ourselves.
        index = FAISS.load_local(
            str(DATA_DIR), embedder, allow_dangerous_deserialization=True
        )
        logger.info("Successfully loaded FAISS index")
        return index
    except Exception as e:
        logger.error(f"Error loading vectorstore: {e}")
        import traceback

        logger.debug(traceback.format_exc())
        return None
# Load the vectorstore when module is imported
try:
    vectorstore = load_vectorstore()
    if vectorstore is None:
        # load_vectorstore() already logged the failure details; warn so it is
        # clear at import time why search() will later raise.
        logger.warning("FAISS vectorstore could not be loaded.")
except Exception as e:
    # Defensive: load_vectorstore() catches its own errors, but guard module
    # import against anything unexpected so importing never fails outright.
    logger.error(f"Error loading vectorstore: {e}")
    vectorstore = None
def search(query: str, return_type=str, results: int = 5):
    """
    Search for relevant chunks using similarity search.

    Args:
        query: The search query
        return_type: Return as string or list (default: str)
        results: Number of results to return (default: 5)

    Returns:
        Results as string or list depending on return_type

    Raises:
        ValueError: If the vectorstore is not loaded, or return_type is
            neither str nor list.
    """
    if vectorstore is None:
        raise ValueError("Vectorstore not loaded. Please ensure FAISS index exists.")

    search_results = vectorstore.similarity_search(query, k=results)

    # `is` rather than `==`: return_type is expected to be the type object itself.
    if return_type is str:
        # Build once with join instead of repeated string concatenation.
        return "".join(
            f"Result {idx}:\n{result.page_content}\n------\n"
            for idx, result in enumerate(search_results, start=1)
        )
    if return_type is list:
        return [result.page_content for result in search_results]
    raise ValueError("Invalid return_type. Use str or list.")
# Load questions from saved data
def load_qa_data():
    """Load the pre-generated questions from questions.json.

    Returns:
        The parsed list of question records, or None if loading failed.
    """
    try:
        questions_path = DATA_DIR / "questions.json"
        logger.info(f"Loading questions from: {questions_path}")

        # Load the questions
        with open(questions_path, "r") as f:
            loaded = json.load(f)

        logger.info(f"Successfully loaded {len(loaded)} questions")
        return loaded
    except Exception as e:
        logger.error(f"Error loading QA data: {e}")
        import traceback

        logger.debug(traceback.format_exc())
        return None
# Load questions when module is imported
try:
    questions = load_qa_data()
    if questions is None:
        # load_qa_data() already logged the failure details; warn so it is
        # clear at import time why the QA helpers will later raise.
        logger.warning("Could not load QA data.")
except Exception as e:
    # Defensive: load_qa_data() catches its own errors, but guard module
    # import against anything unexpected so importing never fails outright.
    logger.error(f"Error initializing QA data: {e}")
    questions = None
def get_question_answer(idx=None, return_both: bool = True) -> "dict | str":
    """
    Get a question-answer pair either by index or randomly.

    Args:
        idx: Index of the question to retrieve (if None, selects random question)
        return_both: Whether to return both question and answer (default: True)

    Returns:
        A {"question": ..., "answer": ...} dict if return_both=True,
        otherwise just the question.

    Raises:
        ValueError: If questions are not loaded, or idx is out of range.
    """
    if questions is None:
        raise ValueError("Questions not loaded. Please ensure questions.json exists.")

    if idx is None:
        # Select a random question
        qa_pair = random.choice(questions)
    elif 0 <= idx < len(questions):
        # Select question by index
        qa_pair = questions[idx]
    else:
        raise ValueError(
            f"Index out of range. Must be between 0 and {len(questions) - 1}"
        )

    question = qa_pair["question"]
    answer = qa_pair["answer"]

    if return_both:
        return {"question": question, "answer": answer}
    return question
# Function to get the total number of questions
def get_question_count() -> int:
    """Return how many questions are currently available.

    Raises:
        ValueError: If the questions data was never loaded.
    """
    if questions is None:
        raise ValueError("Questions not loaded. Please ensure questions.json exists.")

    return len(questions)
def get_qa_dataset() -> tuple:
    """
    Return HuggingFace train/test Datasets of question-answer pairs.

    The datasets are constructed from the loaded questions data
    (questions.json). Each element includes at least:
      - "prompt": The question text (renamed from "question").
      - "answer": The corresponding answer text.
    Additional keys present in the original questions data are kept.

    Returns:
        A (train_dataset, test_dataset) tuple of HuggingFace Dataset objects.

    Raises:
        ValueError: If the questions data was never loaded.
    """
    if questions is None:
        raise ValueError("Questions not loaded. Please ensure questions.json exists.")

    qa_dataset = Dataset.from_list(questions)
    full_dataset = qa_dataset.shuffle(seed=42)
    # Split once; the previous code called train_test_split twice, redoing
    # the same deterministic (seed=42) split for each half.
    split = full_dataset.train_test_split(test_size=0.1, seed=42)
    # Rename the "question" column to "prompt" to match the trainer's input.
    train_dataset = split["train"].rename_column("question", "prompt")
    test_dataset = split["test"].rename_column("question", "prompt")
    return train_dataset, test_dataset