From 010957cd995bf3d2651e212be6747b74582d14f7 Mon Sep 17 00:00:00 2001 From: thinhlpg Date: Fri, 4 Apr 2025 14:56:31 +0700 Subject: [PATCH] feat: add randomize option to get_qa_dataset function, disabled by default --- src/search_module.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/search_module.py b/src/search_module.py index 48aff90..f0fad85 100644 --- a/src/search_module.py +++ b/src/search_module.py @@ -142,7 +142,7 @@ def get_question_count() -> int: return len(questions) -def get_qa_dataset() -> tuple: +def get_qa_dataset(randomize: bool = False) -> tuple: """ Return a HuggingFace Dataset containing question and answer pairs. @@ -159,9 +159,10 @@ def get_qa_dataset() -> tuple: raise ValueError("Questions not loaded. Please ensure questions.json exists.") qa_dataset = Dataset.from_list(questions) - full_dataset = qa_dataset.shuffle(seed=42) - train_dataset = full_dataset.train_test_split(test_size=0.1, seed=42)["train"] - test_dataset = full_dataset.train_test_split(test_size=0.1, seed=42)["test"] + if randomize: + qa_dataset = qa_dataset.shuffle(seed=42) + train_dataset = qa_dataset.train_test_split(test_size=0.1, seed=42)["train"] + test_dataset = qa_dataset.train_test_split(test_size=0.1, seed=42)["test"] # rename the column of the dataset from "question" to "input" train_dataset = train_dataset.rename_column("question", "prompt") test_dataset = test_dataset.rename_column("question", "prompt")