{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import sys\n", "\n", "sys.path.append(\"..\")\n", "\n", "from src.search_module import get_qa_dataset\n", "import random\n", "\n", "\n", "def inspect_qa_dataset():\n", " \"\"\"Inspect the QA dataset used for evaluation to identify potential issues\"\"\"\n", "\n", " # Get the datasets\n", " train_dataset, test_dataset = get_qa_dataset()\n", "\n", " # Print dataset statistics\n", " print(f\"Train dataset size: {len(train_dataset)}\")\n", " print(f\"Test dataset size: {len(test_dataset)}\")\n", "\n", " # Print column information\n", " print(f\"\\nTest dataset columns: {test_dataset.column_names}\")\n", "\n", " # Print a few random examples\n", " sample_size = min(5, len(test_dataset))\n", " sample_indices = random.sample(range(len(test_dataset)), sample_size)\n", "\n", " print(f\"\\n--- {sample_size} Random Test Examples ---\")\n", " for i, idx in enumerate(sample_indices):\n", " example = test_dataset[idx]\n", " print(f\"\\nExample {i+1}:\")\n", " print(f\"Prompt: {example['prompt']}\")\n", " print(f\"Answer: {example['answer']}\")\n", " if \"chunk_content\" in example:\n", " print(f\"Chunk Content: {example['chunk_content'][:200]}... (truncated)\")\n", "\n", " # Check for potential issues\n", " print(\"\\n--- Dataset Analysis ---\")\n", "\n", " # Check for duplicate questions\n", " prompts = test_dataset[\"prompt\"]\n", " duplicate_count = len(prompts) - len(set(prompts))\n", " print(f\"Duplicate prompts: {duplicate_count}\")\n", "\n", " # Check answer length distribution\n", " answer_lengths = [len(ans) for ans in test_dataset[\"answer\"]]\n", " avg_answer_length = sum(answer_lengths) / len(answer_lengths)\n", " min_answer_length = min(answer_lengths)\n", " max_answer_length = max(answer_lengths)\n", " print(\n", " f\"Answer length stats: min={min_answer_length}, avg={avg_answer_length:.1f}, max={max_answer_length}\"\n", " )\n", "\n", " # Analyze prompt types if possible\n", " if len(prompts) > 0:\n", " qa_count = sum(1 for p in prompts if p.endswith(\"?\"))\n", " print(\n", " f\"Questions ending with '?': {qa_count} ({qa_count/len(prompts)*100:.1f}%)\"\n", " )\n", "\n", "\n", "if __name__ == \"__main__\":\n", " inspect_qa_dataset()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "\n", "\n", "train_dataset, test_dataset = get_qa_dataset()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Sanity check 32 test cases: -> 31/32 is correct, nothing wrong with the test data here :/\n", "\n", "brow wtf is happening 😭" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_dataset" ] } ], "metadata": { "kernelspec": { "display_name": "deepsearch-py311", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.11" } }, "nbformat": 4, "nbformat_minor": 2 }