diff --git a/src/search_module.py b/src/search_module.py index 48aff90..f0fad85 100644 --- a/src/search_module.py +++ b/src/search_module.py @@ -142,7 +142,7 @@ def get_question_count() -> int: return len(questions) -def get_qa_dataset() -> tuple: +def get_qa_dataset(randomize: bool = False) -> tuple: """ Return a HuggingFace Dataset containing question and answer pairs. @@ -159,9 +159,10 @@ def get_qa_dataset() -> tuple: raise ValueError("Questions not loaded. Please ensure questions.json exists.") qa_dataset = Dataset.from_list(questions) - full_dataset = qa_dataset.shuffle(seed=42) - train_dataset = full_dataset.train_test_split(test_size=0.1, seed=42)["train"] - test_dataset = full_dataset.train_test_split(test_size=0.1, seed=42)["test"] + if randomize: + qa_dataset = qa_dataset.shuffle(seed=42) + train_dataset = qa_dataset.train_test_split(test_size=0.1, seed=42)["train"] + test_dataset = qa_dataset.train_test_split(test_size=0.1, seed=42)["test"] # rename the column of the dataset from "question" to "input" train_dataset = train_dataset.rename_column("question", "prompt") test_dataset = test_dataset.rename_column("question", "prompt")