{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "sys.path.append(\"..\")\n",
    "\n",
    "from src.search_module import get_qa_dataset\n",
    "import random\n",
    "\n",
    "\n",
    "def inspect_qa_dataset():\n",
    "    \"\"\"Inspect the QA dataset used for evaluation to identify potential issues\"\"\"\n",
    "\n",
    "    # Get the datasets\n",
    "    train_dataset, test_dataset = get_qa_dataset()\n",
    "\n",
    "    # Print dataset statistics\n",
    "    print(f\"Train dataset size: {len(train_dataset)}\")\n",
    "    print(f\"Test dataset size: {len(test_dataset)}\")\n",
    "\n",
    "    # Print column information\n",
    "    print(f\"\\nTest dataset columns: {test_dataset.column_names}\")\n",
    "\n",
    "    # Print a few random examples\n",
    "    sample_size = min(5, len(test_dataset))\n",
    "    sample_indices = random.sample(range(len(test_dataset)), sample_size)\n",
    "\n",
    "    print(f\"\\n--- {sample_size} Random Test Examples ---\")\n",
    "    for i, idx in enumerate(sample_indices):\n",
    "        example = test_dataset[idx]\n",
    "        print(f\"\\nExample {i+1}:\")\n",
    "        print(f\"Prompt: {example['prompt']}\")\n",
    "        print(f\"Answer: {example['answer']}\")\n",
    "        if \"chunk_content\" in example:\n",
    "            print(f\"Chunk Content: {example['chunk_content'][:200]}... (truncated)\")\n",
    "\n",
    "    # Check for potential issues\n",
    "    print(\"\\n--- Dataset Analysis ---\")\n",
    "\n",
    "    # Check for duplicate questions\n",
    "    prompts = test_dataset[\"prompt\"]\n",
    "    duplicate_count = len(prompts) - len(set(prompts))\n",
    "    print(f\"Duplicate prompts: {duplicate_count}\")\n",
    "\n",
    "    # Check answer length distribution\n",
    "    answer_lengths = [len(ans) for ans in test_dataset[\"answer\"]]\n",
    "    avg_answer_length = sum(answer_lengths) / len(answer_lengths)\n",
    "    min_answer_length = min(answer_lengths)\n",
    "    max_answer_length = max(answer_lengths)\n",
    "    print(\n",
    "        f\"Answer length stats: min={min_answer_length}, avg={avg_answer_length:.1f}, max={max_answer_length}\"\n",
    "    )\n",
    "\n",
    "    # Analyze prompt types if possible\n",
    "    if len(prompts) > 0:\n",
    "        qa_count = sum(1 for p in prompts if p.endswith(\"?\"))\n",
    "        print(\n",
    "            f\"Questions ending with '?': {qa_count} ({qa_count/len(prompts)*100:.1f}%)\"\n",
    "        )\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    inspect_qa_dataset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "train_dataset, test_dataset = get_qa_dataset()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Sanity check 32 test cases: -> 31/32 is correct, nothing wrong with the test data here :/\n",
    "\n",
    "brow wtf is happening 😭"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_dataset"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "deepsearch-py311",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}