You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
195 lines
6.1 KiB
195 lines
6.1 KiB
"""
|
|
Search module for RL training loop.
|
|
This module provides functions to search through vectorized documents and retrieve question-answer pairs.
|
|
"""
|
|
|
|
import pickle
|
|
import json
|
|
import random
|
|
import asyncio
|
|
from typing import List, Tuple, Optional, Union, Dict, Any
|
|
from enum import Enum
|
|
from pydantic import BaseModel
|
|
from langchain.vectorstores import FAISS
|
|
from datasets import Dataset
|
|
from embeddings import CustomHuggingFaceEmbeddings
|
|
|
|
|
|
# Load pre-saved vectorstore
|
|
def load_vectorstore():
|
|
"""Load the pre-saved FAISS index"""
|
|
try:
|
|
import os
|
|
|
|
embeddings = CustomHuggingFaceEmbeddings()
|
|
# Load the FAISS index with absolute path
|
|
index_path = os.path.join(
|
|
os.path.dirname(os.path.abspath(__file__)), "faiss_index"
|
|
)
|
|
print(f"Loading FAISS index from: {index_path}")
|
|
vectorstore = FAISS.load_local(
|
|
index_path, embeddings, allow_dangerous_deserialization=True
|
|
)
|
|
print("Successfully loaded FAISS index")
|
|
return vectorstore
|
|
except Exception as e:
|
|
print(f"Error loading vectorstore: {e}")
|
|
import traceback
|
|
|
|
traceback.print_exc()
|
|
return None
|
|
|
|
|
|
# Load the vectorstore when module is imported
|
|
try:
|
|
vectorstore = load_vectorstore()
|
|
if vectorstore is None:
|
|
print("Warning: FAISS vectorstore could not be loaded.")
|
|
except Exception as e:
|
|
print(f"Error loading vectorstore: {e}")
|
|
vectorstore = None
|
|
|
|
|
|
def search(query: str, return_type=str, results: int = 5) -> Union[str, List[str]]:
|
|
"""
|
|
Search for relevant chunks using similarity search.
|
|
|
|
Args:
|
|
query: The search query
|
|
return_type: Return as string or list (default: str)
|
|
results: Number of results to return (default: 5)
|
|
|
|
Returns:
|
|
Results as string or list depending on return_type
|
|
"""
|
|
if vectorstore is None:
|
|
raise ValueError("Vectorstore not loaded. Please ensure FAISS index exists.")
|
|
|
|
search_results = vectorstore.similarity_search(query, k=results)
|
|
|
|
if return_type == str:
|
|
str_results = ""
|
|
for idx, result in enumerate(search_results, start=1):
|
|
str_results += f"Result {idx}:\n"
|
|
str_results += result.page_content + "\n"
|
|
str_results += "------\n"
|
|
return str_results
|
|
elif return_type == list:
|
|
return [result.page_content for result in search_results]
|
|
else:
|
|
raise ValueError("Invalid return_type. Use str or list.")
|
|
|
|
|
|
# Load questions from saved data
|
|
def load_qa_data():
|
|
"""Load the pre-generated questions and document chunks"""
|
|
try:
|
|
import os
|
|
|
|
# Get absolute paths to data files
|
|
base_dir = os.path.dirname(os.path.abspath(__file__))
|
|
chunks_path = os.path.join(base_dir, "saved_data", "chunks.pkl")
|
|
questions_path = os.path.join(base_dir, "saved_data", "questions.json")
|
|
|
|
print(f"Loading chunks from: {chunks_path}")
|
|
print(f"Loading questions from: {questions_path}")
|
|
|
|
# Load the chunks
|
|
with open(chunks_path, "rb") as f:
|
|
chunks = pickle.load(f)
|
|
|
|
# Load the questions
|
|
with open(questions_path, "r") as f:
|
|
questions = json.load(f)
|
|
|
|
print(
|
|
f"Successfully loaded {len(chunks)} chunks and {len(questions)} questions"
|
|
)
|
|
return chunks, questions
|
|
except Exception as e:
|
|
print(f"Error loading QA data: {e}")
|
|
import traceback
|
|
|
|
traceback.print_exc()
|
|
return None, None
|
|
|
|
|
|
# Load chunks and questions when module is imported
|
|
try:
|
|
chunks, questions = load_qa_data()
|
|
if chunks is None or questions is None:
|
|
print("Warning: Could not load QA data.")
|
|
except Exception as e:
|
|
print(f"Error initializing QA data: {e}")
|
|
chunks, questions = None, None
|
|
|
|
|
|
def get_question_answer(
|
|
idx: Optional[int] = None, return_both: bool = True
|
|
) -> Union[dict, str]:
|
|
"""
|
|
Get a question-answer pair either by index or randomly.
|
|
|
|
Args:
|
|
idx: Index of the question to retrieve (if None, selects random question)
|
|
return_both: Whether to return both question and answer (default: True)
|
|
|
|
Returns:
|
|
Question and answer as tuple if return_both=True, otherwise just the question
|
|
"""
|
|
if questions is None:
|
|
raise ValueError("Questions not loaded. Please ensure questions.json exists.")
|
|
|
|
if idx is None:
|
|
# Select a random question
|
|
qa_pair = random.choice(questions)
|
|
elif 0 <= idx < len(questions):
|
|
# Select question by index
|
|
qa_pair = questions[idx]
|
|
else:
|
|
raise ValueError(
|
|
f"Index out of range. Must be between 0 and {len(questions)-1}"
|
|
)
|
|
|
|
question = qa_pair["question"]
|
|
answer = qa_pair["answer"]
|
|
|
|
if return_both:
|
|
return {"question": question, "answer": answer}
|
|
else:
|
|
return question
|
|
|
|
|
|
# Function to get the total number of questions
|
|
def get_question_count() -> int:
|
|
"""Get the total number of available questions"""
|
|
if questions is None:
|
|
raise ValueError("Questions not loaded. Please ensure questions.json exists.")
|
|
return len(questions)
|
|
|
|
|
|
def get_qa_dataset():
|
|
"""
|
|
Return a HuggingFace Dataset containing question and answer pairs.
|
|
|
|
This dataset is constructed from the loaded questions data (questions.json).
|
|
Each element in the dataset is a dictionary that includes at least:
|
|
- "question": The question text.
|
|
- "answer": The corresponding answer text.
|
|
Additional keys present in the original questions data will also be included.
|
|
|
|
Returns:
|
|
A HuggingFace Dataset object.
|
|
"""
|
|
if questions is None:
|
|
raise ValueError("Questions not loaded. Please ensure questions.json exists.")
|
|
|
|
qa_dataset = Dataset.from_list(questions)
|
|
full_dataset = qa_dataset.shuffle(seed=42)
|
|
train_dataset = full_dataset.train_test_split(test_size=0.1, seed=42)["train"]
|
|
test_dataset = full_dataset.train_test_split(test_size=0.1, seed=42)["test"]
|
|
# rename the column of the dataset from "question" to "input"
|
|
train_dataset = train_dataset.rename_column("question", "prompt")
|
|
test_dataset = test_dataset.rename_column("question", "prompt")
|
|
return train_dataset, test_dataset
|