{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"\n",
"sys.path.append(\"..\")\n",
"\n",
"from src.search_module import get_qa_dataset\n",
"import random\n",
"\n",
"\n",
"def inspect_qa_dataset():\n",
" \"\"\"Inspect the QA dataset used for evaluation to identify potential issues\"\"\"\n",
"\n",
" # Get the datasets\n",
" train_dataset, test_dataset = get_qa_dataset()\n",
"\n",
" # Print dataset statistics\n",
" print(f\"Train dataset size: {len(train_dataset)}\")\n",
" print(f\"Test dataset size: {len(test_dataset)}\")\n",
"\n",
" # Print column information\n",
" print(f\"\\nTest dataset columns: {test_dataset.column_names}\")\n",
"\n",
" # Print a few random examples\n",
" sample_size = min(5, len(test_dataset))\n",
" sample_indices = random.sample(range(len(test_dataset)), sample_size)\n",
"\n",
" print(f\"\\n--- {sample_size} Random Test Examples ---\")\n",
" for i, idx in enumerate(sample_indices):\n",
" example = test_dataset[idx]\n",
" print(f\"\\nExample {i+1}:\")\n",
" print(f\"Prompt: {example['prompt']}\")\n",
" print(f\"Answer: {example['answer']}\")\n",
" if \"chunk_content\" in example:\n",
" print(f\"Chunk Content: {example['chunk_content'][:200]}... (truncated)\")\n",
"\n",
" # Check for potential issues\n",
" print(\"\\n--- Dataset Analysis ---\")\n",
"\n",
" # Check for duplicate questions\n",
" prompts = test_dataset[\"prompt\"]\n",
" duplicate_count = len(prompts) - len(set(prompts))\n",
" print(f\"Duplicate prompts: {duplicate_count}\")\n",
"\n",
" # Check answer length distribution\n",
" answer_lengths = [len(ans) for ans in test_dataset[\"answer\"]]\n",
" avg_answer_length = sum(answer_lengths) / len(answer_lengths)\n",
" min_answer_length = min(answer_lengths)\n",
" max_answer_length = max(answer_lengths)\n",
" print(\n",
" f\"Answer length stats: min={min_answer_length}, avg={avg_answer_length:.1f}, max={max_answer_length}\"\n",
" )\n",
"\n",
" # Analyze prompt types if possible\n",
" if len(prompts) > 0:\n",
" qa_count = sum(1 for p in prompts if p.endswith(\"?\"))\n",
" print(\n",
" f\"Questions ending with '?': {qa_count} ({qa_count/len(prompts)*100:.1f}%)\"\n",
" )\n",
"\n",
"\n",
"if __name__ == \"__main__\":\n",
" inspect_qa_dataset()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"train_dataset, test_dataset = get_qa_dataset()"
]
},
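{
"cell_type": "markdown",
"metadata": {},
"source": [
"A hedged companion check to `inspect_qa_dataset` above: look for prompts shared between the train and test splits (possible leakage). This is an illustrative sketch assuming both splits expose a `prompt` column, as used earlier; it is not part of the original evaluation pipeline."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch (assumption): both splits have a \"prompt\" column, as above\n",
"train_prompts = set(train_dataset[\"prompt\"])\n",
"test_prompts = set(test_dataset[\"prompt\"])\n",
"\n",
"# Any prompt appearing in both splits is a leakage candidate\n",
"overlap = train_prompts & test_prompts\n",
"print(f\"Prompts shared between train and test: {len(overlap)}\")\n",
"for p in sorted(overlap)[:5]:\n",
"    print(f\"  {p}\")"
]
},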
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Sanity check 32 test cases: -> 31/32 is correct, nothing wrong with the test data here :/\n",
"\n",
"brow wtf is happening 😭"
]
},
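{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of how a sanity check like the 31/32 above could be scored. `model_predictions` is a hypothetical list of model answers aligned with `test_dataset` (e.g. collected from an earlier evaluation run); the real pipeline's matching logic may differ from this case-insensitive exact match."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def score_exact_match(predictions, dataset):\n",
"    \"\"\"Score predictions against gold answers with a case-insensitive exact match.\"\"\"\n",
"    correct = sum(\n",
"        pred.strip().lower() == gold.strip().lower()\n",
"        for pred, gold in zip(predictions, dataset[\"answer\"])\n",
"    )\n",
"    print(f\"{correct}/{len(predictions)} correct\")\n",
"    return correct\n",
"\n",
"\n",
"# model_predictions = [...]  # hypothetical: one answer string per test example\n",
"# score_exact_match(model_predictions, test_dataset)"
]
},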
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"test_dataset"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "deepsearch-py311",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}