|
|
|
@ -142,7 +142,7 @@ def get_question_count() -> int:
|
|
|
|
|
return len(questions)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_qa_dataset() -> tuple:
|
|
|
|
|
def get_qa_dataset(randomize: bool = False) -> tuple:
|
|
|
|
|
"""
|
|
|
|
|
Return a HuggingFace Dataset containing question and answer pairs.
|
|
|
|
|
|
|
|
|
@ -159,9 +159,10 @@ def get_qa_dataset() -> tuple:
|
|
|
|
|
raise ValueError("Questions not loaded. Please ensure questions.json exists.")
|
|
|
|
|
|
|
|
|
|
qa_dataset = Dataset.from_list(questions)
|
|
|
|
|
full_dataset = qa_dataset.shuffle(seed=42)
|
|
|
|
|
train_dataset = full_dataset.train_test_split(test_size=0.1, seed=42)["train"]
|
|
|
|
|
test_dataset = full_dataset.train_test_split(test_size=0.1, seed=42)["test"]
|
|
|
|
|
if randomize:
|
|
|
|
|
qa_dataset = qa_dataset.shuffle(seed=42)
|
|
|
|
|
train_dataset = qa_dataset.train_test_split(test_size=0.1, seed=42)["train"]
|
|
|
|
|
test_dataset = qa_dataset.train_test_split(test_size=0.1, seed=42)["test"]
|
|
|
|
|
# rename the column of the dataset from "question" to "input"
|
|
|
|
|
train_dataset = train_dataset.rename_column("question", "prompt")
|
|
|
|
|
test_dataset = test_dataset.rename_column("question", "prompt")
|
|
|
|
|