feat: disable randomization option to get_qa_dataset function by default

main
thinhlpg 1 month ago
parent 56911a73f9
commit 010957cd99

@ -142,7 +142,7 @@ def get_question_count() -> int:
return len(questions)
def get_qa_dataset() -> tuple:
def get_qa_dataset(randomize: bool = False) -> tuple:
"""
Return a HuggingFace Dataset containing question and answer pairs.
@ -159,9 +159,10 @@ def get_qa_dataset() -> tuple:
raise ValueError("Questions not loaded. Please ensure questions.json exists.")
qa_dataset = Dataset.from_list(questions)
full_dataset = qa_dataset.shuffle(seed=42)
train_dataset = full_dataset.train_test_split(test_size=0.1, seed=42)["train"]
test_dataset = full_dataset.train_test_split(test_size=0.1, seed=42)["test"]
if randomize:
qa_dataset = qa_dataset.shuffle(seed=42)
train_dataset = qa_dataset.train_test_split(test_size=0.1, seed=42)["train"]
test_dataset = qa_dataset.train_test_split(test_size=0.1, seed=42)["test"]
# rename the column of the dataset from "question" to "input"
train_dataset = train_dataset.rename_column("question", "prompt")
test_dataset = test_dataset.rename_column("question", "prompt")

Loading…
Cancel
Save