From 4a1d45271d566fd84cdc862c1d2d66ce21093ea8 Mon Sep 17 00:00:00 2001 From: thinhlpg Date: Fri, 11 Apr 2025 17:18:18 +0000 Subject: [PATCH] feat: add scripts for musique data processing --- Makefile | 42 ++- notebooks/250410_cook_better_data.ipynb | 278 ++++++++++++------ pyproject.toml | 1 + scripts/check_data.py | 117 ++++++++ scripts/train_data/build_musique_index.py | 135 +++++++++ scripts/train_data/download_data_musique.sh | 30 ++ .../train_data/extract_musique_paragraphs.py | 101 +++++++ scripts/train_data/prepare_musique_jsonl.py | 172 +++++++++++ src/embeddings.py | 2 +- 9 files changed, 790 insertions(+), 88 deletions(-) create mode 100644 scripts/check_data.py create mode 100644 scripts/train_data/build_musique_index.py create mode 100644 scripts/train_data/download_data_musique.sh create mode 100644 scripts/train_data/extract_musique_paragraphs.py create mode 100644 scripts/train_data/prepare_musique_jsonl.py diff --git a/Makefile b/Makefile index 32e77ba..efd5f74 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: style quality install tensorboard clean fix update-worklog test +.PHONY: style quality install tensorboard clean fix update-worklog test data download-musique prepare-musique-jsonl extract-musique-paragraphs build-musique-index prepare-musique-index prepare-all-musique check-data # make sure to test the local checkout in scripts and not the pre-installed one export PYTHONPATH = src @@ -37,6 +37,43 @@ list-runs: @echo "Available run directories:" @ls -d trainer_output_*_runs 2>/dev/null || echo "No run directories found" +# Data Preparation +data: prepare-musique-jsonl + @echo "Data preparation complete." + +# Index Preparation +prepare-musique-index: build-musique-index + @echo "Musique index preparation complete." + +download-musique: + @echo "Downloading Musique dataset..." + bash scripts/train_data/download_data_musique.sh + @echo "Musique dataset ready in ./data/raw/" + +prepare-musique-jsonl: download-musique + @echo "Preparing Musique data (JSONL)..." + python scripts/train_data/prepare_musique_jsonl.py + @echo "Processed Musique JSONL ready in ./data/processed/questions.jsonl" + +extract-musique-paragraphs: download-musique + @echo "Extracting unique paragraphs from raw Musique data..." + python scripts/train_data/extract_musique_paragraphs.py + @echo "Musique paragraphs extracted to ./data/processed/paragraphs.csv" + +build-musique-index: extract-musique-paragraphs + @echo "Building Musique FAISS index from paragraphs..." + python scripts/train_data/build_musique_index.py + @echo "Musique FAISS index files saved to ./data/processed/" + +# Combined Preparation +prepare-all-musique: data prepare-musique-index + @echo "All Musique data and index preparation complete." + +# Check Data +check-data: prepare-all-musique + @echo "Checking generated data files..." + python scripts/check_data.py + # Clean up clean: find . -type d -name "__pycache__" -exec rm -r {} + @@ -52,6 +89,9 @@ clean: find . -type d -name "htmlcov" -exec rm -r {} + find . -type d -name "build" -exec rm -r {} + find . 
-type d -name "dist" -exec rm -r {} + + rm -rf ./data/raw ./data/processed # Clean raw and processed data + # Clean up the old faiss_index directory if it exists + rm -rf ./data/processed/faiss_index # Update worklog in GitHub issue update-worklog: diff --git a/notebooks/250410_cook_better_data.ipynb b/notebooks/250410_cook_better_data.ipynb index cf0165b..8d50569 100644 --- a/notebooks/250410_cook_better_data.ipynb +++ b/notebooks/250410_cook_better_data.ipynb @@ -166,7 +166,13 @@ "\n", "Unique characters in answers (lowercased):\n", " \"$%&'()+,-./0123456789:`abcdefghijklmnopqrstuvwxyz€‚‡ˆ’“™¡¢£¤¥§¨©ª¬­°±³¶¸ºáâãäæè\n", - " - bro wtf is \"€‚‡ˆ’“™\"????\n" + " - bro wtf is \"€‚‡ˆ’“™\"????\n", + "\n", + "- 600 system + 50 (think) + 600 (infor) + 50 (think) + 600 (info) ..... \n", + "- ---------------1----------------------------2------------------------- -> 650 for each genrations ->\n", + "- 24k: max_generations from 32 to 36 should be good\n", + "- 16k: max_generations from 20 to 24 should be good\n", + "- 8k: 10 to 12 generations should be good " ] }, { @@ -196,15 +202,15 @@ "prefix_counts = Counter()\n", "\n", "# Read the jsonl file\n", - "with open(file_path, 'r', encoding='utf-8') as f:\n", + "with open(file_path, \"r\", encoding=\"utf-8\") as f:\n", " for line in f:\n", " total_rows += 1\n", " data = json.loads(line)\n", - " unique_ids.add(data['id'])\n", - " \n", + " unique_ids.add(data[\"id\"])\n", + "\n", " # Count prefixes\n", - " for prefix in ['2hop_', '3hop1_', '3hop2_', '4hop1_', '4hop2_', '4hop3_']:\n", - " if data['id'].startswith(prefix):\n", + " for prefix in [\"2hop_\", \"3hop1_\", \"3hop2_\", \"4hop1_\", \"4hop2_\", \"4hop3_\"]:\n", + " if data[\"id\"].startswith(prefix):\n", " prefix_counts[prefix] += 1\n", " break\n", "\n", @@ -252,67 +258,63 @@ "total_samples = 0\n", "\n", "# For hop analysis\n", - "hop_question_lengths = {\n", - " '2hop': [],\n", - " '3hop': [],\n", - " '4hop': []\n", - "}\n", + "hop_question_lengths = {\"2hop\": [], \"3hop\": [], \"4hop\": []}\n", "\n", "# Read the jsonl file\n", - "with open(file_path, 'r', encoding='utf-8') as f:\n", + "with open(file_path, \"r\", encoding=\"utf-8\") as f:\n", " for line in f:\n", " data = json.loads(line)\n", " total_samples += 1\n", - " \n", + "\n", " # Check if answerable\n", - " if data.get('answerable', False):\n", + " if data.get(\"answerable\", False):\n", " answerable_count += 1\n", - " \n", + "\n", " # Check answer aliases\n", - " if data.get('answer_aliases') and len(data['answer_aliases']) > 0:\n", + " if data.get(\"answer_aliases\") and len(data[\"answer_aliases\"]) > 0:\n", " non_empty_aliases_count += 1\n", - " \n", + "\n", " # Collect unique characters in questions and answers\n", - " if 'question' in data:\n", - " question = data['question']\n", + " if \"question\" in data:\n", + " question = data[\"question\"]\n", " question_chars.update(question.lower())\n", - " \n", + "\n", " # Count words and characters in question\n", " question_words = question.split()\n", " question_word_lengths.append(len(question_words))\n", " question_char_lengths.append(len(question))\n", - " \n", + "\n", " # Extract hop count from ID for analysis\n", - " if 'id' in data:\n", - " if data['id'].startswith('2hop'):\n", - " hop_question_lengths['2hop'].append(len(question_words))\n", - " elif data['id'].startswith('3hop'):\n", - " hop_question_lengths['3hop'].append(len(question_words))\n", - " elif data['id'].startswith('4hop'):\n", - " hop_question_lengths['4hop'].append(len(question_words))\n", - " \n", - " if 
'answer' in data:\n", + " if \"id\" in data:\n", + " if data[\"id\"].startswith(\"2hop\"):\n", + " hop_question_lengths[\"2hop\"].append(len(question_words))\n", + " elif data[\"id\"].startswith(\"3hop\"):\n", + " hop_question_lengths[\"3hop\"].append(len(question_words))\n", + " elif data[\"id\"].startswith(\"4hop\"):\n", + " hop_question_lengths[\"4hop\"].append(len(question_words))\n", + "\n", + " if \"answer\" in data:\n", " # Handle unicode escape sequences\n", - " answer = data['answer'].encode().decode('unicode_escape')\n", + " answer = data[\"answer\"].encode().decode(\"unicode_escape\")\n", " answer_chars.update(answer.lower())\n", - " \n", + "\n", " # Count words and characters in answer\n", " answer_words = answer.split()\n", " answer_word_lengths.append(len(answer_words))\n", " answer_char_lengths.append(len(answer))\n", - " \n", + "\n", " # Process paragraphs\n", - " for para in data.get('paragraphs', []):\n", - " if 'paragraph_text' in para:\n", + " for para in data.get(\"paragraphs\", []):\n", + " if \"paragraph_text\" in para:\n", " # Handle unicode escape sequences\n", - " text = para['paragraph_text'].encode().decode('unicode_escape')\n", + " text = para[\"paragraph_text\"].encode().decode(\"unicode_escape\")\n", " words = text.split()\n", " para_word_lengths.append(len(words))\n", " para_char_lengths.append(len(text))\n", - " \n", - " if 'title' in para:\n", + "\n", + " if \"title\" in para:\n", " # Handle unicode escape sequences\n", - " title = para['title'].encode().decode('unicode_escape')\n", + " title = para[\"title\"].encode().decode(\"unicode_escape\")\n", " title_words = title.split()\n", " title_word_lengths.append(len(title_words))\n", " title_char_lengths.append(len(title))\n", @@ -322,64 +324,66 @@ "\n", "# Plot paragraph length distributions\n", "axs[0, 0].hist(para_word_lengths, bins=50, alpha=0.7)\n", - "axs[0, 0].set_title('Paragraph Length (Words)')\n", - "axs[0, 0].set_xlabel('Number of Words')\n", - "axs[0, 0].set_ylabel('Frequency')\n", + "axs[0, 0].set_title(\"Paragraph Length (Words)\")\n", + "axs[0, 0].set_xlabel(\"Number of Words\")\n", + "axs[0, 0].set_ylabel(\"Frequency\")\n", "\n", "axs[0, 1].hist(para_char_lengths, bins=50, alpha=0.7)\n", - "axs[0, 1].set_title('Paragraph Length (Characters)')\n", - "axs[0, 1].set_xlabel('Number of Characters')\n", - "axs[0, 1].set_ylabel('Frequency')\n", + "axs[0, 1].set_title(\"Paragraph Length (Characters)\")\n", + "axs[0, 1].set_xlabel(\"Number of Characters\")\n", + "axs[0, 1].set_ylabel(\"Frequency\")\n", "\n", "# Plot title length distributions\n", "axs[1, 0].hist(title_word_lengths, bins=30, alpha=0.7)\n", - "axs[1, 0].set_title('Title Length (Words)')\n", - "axs[1, 0].set_xlabel('Number of Words')\n", - "axs[1, 0].set_ylabel('Frequency')\n", + "axs[1, 0].set_title(\"Title Length (Words)\")\n", + "axs[1, 0].set_xlabel(\"Number of Words\")\n", + "axs[1, 0].set_ylabel(\"Frequency\")\n", "\n", "axs[1, 1].hist(title_char_lengths, bins=30, alpha=0.7)\n", - "axs[1, 1].set_title('Title Length (Characters)')\n", - "axs[1, 1].set_xlabel('Number of Characters')\n", - "axs[1, 1].set_ylabel('Frequency')\n", + "axs[1, 1].set_title(\"Title Length (Characters)\")\n", + "axs[1, 1].set_xlabel(\"Number of Characters\")\n", + "axs[1, 1].set_ylabel(\"Frequency\")\n", "\n", "# Plot answer length distributions\n", "axs[2, 0].hist(answer_word_lengths, bins=30, alpha=0.7)\n", - "axs[2, 0].set_title('Answer Length (Words)')\n", - "axs[2, 0].set_xlabel('Number of Words')\n", - "axs[2, 0].set_ylabel('Frequency')\n", + "axs[2, 
0].set_title(\"Answer Length (Words)\")\n", + "axs[2, 0].set_xlabel(\"Number of Words\")\n", + "axs[2, 0].set_ylabel(\"Frequency\")\n", "\n", "axs[2, 1].hist(answer_char_lengths, bins=30, alpha=0.7)\n", - "axs[2, 1].set_title('Answer Length (Characters)')\n", - "axs[2, 1].set_xlabel('Number of Characters')\n", - "axs[2, 1].set_ylabel('Frequency')\n", + "axs[2, 1].set_title(\"Answer Length (Characters)\")\n", + "axs[2, 1].set_xlabel(\"Number of Characters\")\n", + "axs[2, 1].set_ylabel(\"Frequency\")\n", "\n", "# Plot question length distributions\n", "axs[3, 0].hist(question_word_lengths, bins=30, alpha=0.7)\n", - "axs[3, 0].set_title('Question Length (Words)')\n", - "axs[3, 0].set_xlabel('Number of Words')\n", - "axs[3, 0].set_ylabel('Frequency')\n", + "axs[3, 0].set_title(\"Question Length (Words)\")\n", + "axs[3, 0].set_xlabel(\"Number of Words\")\n", + "axs[3, 0].set_ylabel(\"Frequency\")\n", "\n", "# Plot question length by hop count\n", - "hop_labels = ['2hop', '3hop', '4hop']\n", + "hop_labels = [\"2hop\", \"3hop\", \"4hop\"]\n", "hop_means = [np.mean(hop_question_lengths[hop]) for hop in hop_labels]\n", "hop_counts = [len(hop_question_lengths[hop]) for hop in hop_labels]\n", "\n", "axs[3, 1].bar(hop_labels, hop_means, alpha=0.7)\n", - "axs[3, 1].set_title('Average Question Length by Hop Count')\n", - "axs[3, 1].set_xlabel('Hop Count')\n", - "axs[3, 1].set_ylabel('Average Number of Words')\n", + "axs[3, 1].set_title(\"Average Question Length by Hop Count\")\n", + "axs[3, 1].set_xlabel(\"Hop Count\")\n", + "axs[3, 1].set_ylabel(\"Average Number of Words\")\n", "\n", "# Add count labels on top of bars\n", "for i, (count, mean) in enumerate(zip(hop_counts, hop_means)):\n", - " axs[3, 1].text(i, mean + 0.5, f'n={count}\\n{mean:.1f}', ha='center')\n", + " axs[3, 1].text(i, mean + 0.5, f\"n={count}\\n{mean:.1f}\", ha=\"center\")\n", "\n", "plt.tight_layout()\n", "plt.show()\n", "\n", "# Print statistics\n", "print(f\"Total samples: {total_samples}\")\n", - "print(f\"Answerable samples: {answerable_count} ({answerable_count/total_samples*100:.2f}%)\")\n", - "print(f\"Samples with non-empty answer_aliases: {non_empty_aliases_count} ({non_empty_aliases_count/total_samples*100:.2f}%)\")\n", + "print(f\"Answerable samples: {answerable_count} ({answerable_count / total_samples * 100:.2f}%)\")\n", + "print(\n", + " f\"Samples with non-empty answer_aliases: {non_empty_aliases_count} ({non_empty_aliases_count / total_samples * 100:.2f}%)\"\n", + ")\n", "\n", "# Print paragraph length statistics\n", "print(\"\\nParagraph length statistics (words):\")\n", @@ -428,10 +432,10 @@ "\n", "# Print unique characters\n", "print(\"\\nUnique characters in questions (lowercased):\")\n", - "print(''.join(sorted(question_chars)))\n", + "print(\"\".join(sorted(question_chars)))\n", "\n", "print(\"\\nUnique characters in answers (lowercased):\")\n", - "print(''.join(sorted(answer_chars)))\n" + "print(\"\".join(sorted(answer_chars)))\n" ] }, { @@ -447,34 +451,40 @@ "# Define suspicious characters (non-ASCII and special characters)\n", "suspicious_chars = \"€‚‡ˆ’“™\"\n", "\n", + "\n", "# Function to check if text contains any suspicious characters\n", "def contains_suspicious_chars(text, chars_to_check):\n", " return any(char in text for char in chars_to_check)\n", "\n", + "\n", "# Lists to store samples with suspicious characters\n", "question_samples = []\n", "answer_samples = []\n", "\n", "# Read the jsonl file again to find examples\n", - "with open(file_path, 'r', encoding='utf-8') as f:\n", + "with 
open(file_path, \"r\", encoding=\"utf-8\") as f:\n", " for line in f:\n", " data = json.loads(line)\n", - " \n", + "\n", " # Check question\n", - " if 'question' in data and contains_suspicious_chars(data['question'].lower(), suspicious_chars):\n", - " question_samples.append({\n", - " 'id': data.get('id', 'unknown'),\n", - " 'question': data['question'],\n", - " 'suspicious_chars': [char for char in data['question'] if char.lower() in suspicious_chars]\n", - " })\n", - " \n", + " if \"question\" in data and contains_suspicious_chars(data[\"question\"].lower(), suspicious_chars):\n", + " question_samples.append(\n", + " {\n", + " \"id\": data.get(\"id\", \"unknown\"),\n", + " \"question\": data[\"question\"],\n", + " \"suspicious_chars\": [char for char in data[\"question\"] if char.lower() in suspicious_chars],\n", + " }\n", + " )\n", + "\n", " # Check answer\n", - " if 'answer' in data and contains_suspicious_chars(data['answer'].lower(), suspicious_chars):\n", - " answer_samples.append({\n", - " 'id': data.get('id', 'unknown'),\n", - " 'answer': data['answer'],\n", - " 'suspicious_chars': [char for char in data['answer'] if char.lower() in suspicious_chars]\n", - " })\n", + " if \"answer\" in data and contains_suspicious_chars(data[\"answer\"].lower(), suspicious_chars):\n", + " answer_samples.append(\n", + " {\n", + " \"id\": data.get(\"id\", \"unknown\"),\n", + " \"answer\": data[\"answer\"],\n", + " \"suspicious_chars\": [char for char in data[\"answer\"] if char.lower() in suspicious_chars],\n", + " }\n", + " )\n", "\n", "# Print some samples with suspicious characters in questions\n", "print(f\"Found {len(question_samples)} samples with suspicious characters in questions\")\n", @@ -866,6 +876,103 @@ "}\n" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# example 4 hop - original\n", + "{\n", + " \"id\": \"4hop2__160585_14670_8987_8974\",\n", + " \"paragraphs\": [\n", + " {\n", + " \"idx\": 0,\n", + " \"title\": \"United States Army\",\n", + " \"paragraph_text\": \"Currently, the army is divided into the Regular Army, the Army Reserve, and the Army National Guard. The army is also divided into major branches such as Air Defense Artillery, Infantry, Aviation, Signal Corps, Corps of Engineers, and Armor. Before 1903 members of the National Guard were considered state soldiers unless federalized (i.e., activated) by the President. Since the Militia Act of 1903 all National Guard soldiers have held dual status: as National Guardsmen under the authority of the governor of their state or territory and, when activated, as a reserve of the U.S. Army under the authority of the President.\",\n", + " \"is_supporting\": true,\n", + " },\n", + " {\n", + " \"idx\": 1,\n", + " \"title\": \"Iron Beam\",\n", + " \"paragraph_text\": 'Iron Beam (, \"\") is an air defense system which is in development by Israeli defense contractor Rafael Advanced Defense Systems. Unveiled at the Singapore Air Show on February 11, 2014. The system is designed to destroy short-range rockets, artillery, and mortars with a range of up to , too small for the Iron Dome system to intercept effectively. In addition, the system could also intercept unmanned aerial vehicles. Iron Beam will use a \"directed high energy laser beam\" to destroy hostile targets with ranges of up to . Iron Beam will constitute the fifth element of Israel\\'s integrated air defense system, in addition to Arrow 2, Arrow 3, David\\'s Sling, and Iron Dome. 
However, Iron Beam is also a stand-alone system.',\n", + " \"is_supporting\": false,\n", + " },\n", + " ...\n", + " {\n", + " \"idx\": 19,\n", + " \"title\": \"Josip Broz Tito\",\n", + " \"paragraph_text\": \"In 1968, Tito offered Czechoslovak leader Alexander Dub\\u010dek to fly to Prague on three hours notice if Dub\\u010dek needed help in facing down the Soviets. In April 1969, Tito removed generals Ivan Go\\u0161njak and Rade Hamovi\\u0107 in the aftermath of the invasion of Czechoslovakia due to the unpreparedness of the Yugoslav army to respond to a similar invasion of Yugoslavia.\",\n", + " \"is_supporting\": true,\n", + " },\n", + " ],\n", + " \"question\": \"Hana Mandlikova was born in Country A that invaded Country B because the military branch the Air Defense Artillery is part of was unprepared. Country B was the only communist country to have an embassy where?\",\n", + " \"question_decomposition\": [\n", + " {\n", + " \"id\": 160585,\n", + " \"question\": \"Where was Hana Mandlikova born?\",\n", + " \"answer\": \"Czechoslovakia\",\n", + " \"paragraph_support_idx\": 15,\n", + " },\n", + " {\n", + " \"id\": 14670,\n", + " \"question\": \"The Air Defense Artillery is a branch of what?\",\n", + " \"answer\": \"the Army\",\n", + " \"paragraph_support_idx\": 0,\n", + " },\n", + " {\n", + " \"id\": 8987,\n", + " \"question\": \"What #2 was unprepared for the invasion of #1 ?\",\n", + " \"answer\": \"Yugoslavia\",\n", + " \"paragraph_support_idx\": 19,\n", + " },\n", + " {\n", + " \"id\": 8974,\n", + " \"question\": \"#3 was the only communist country to have an embassy where?\",\n", + " \"answer\": \"Alfredo Stroessner's Paraguay\",\n", + " \"paragraph_support_idx\": 2,\n", + " },\n", + " ],\n", + " \"answer\": \"Alfredo Stroessner's Paraguay\",\n", + " \"answer_aliases\": [\"Alfredo Stroessner\"],\n", + " \"answerable\": true,\n", + "}\n", + "\n", + "# example format - desired\n", + "{\n", + " \"id\": \"4hop2__160585_14670_8987_8974\",\n", + " \"question\": \"Hana Mandlikova was born in Country A that invaded Country B because the military branch the Air Defense Artillery is part of was unprepared. Country B was the only communist country to have an embassy where?\",\n", + " \"answer\": \"Alfredo Stroessner's Paraguay\",\n", + " \"supporting_paragraphs\": [ # this is the list of paragraphs that is_supporting=True\n", + " \"Currently, the army is divided into the Regular Army, the Army Reserve, and the Army National Guard. The army is also divided into major branches such as Air Defense Artillery, Infantry, Aviation, Signal Corps, Corps of Engineers, and Armor. Before 1903 members of the National Guard were considered state soldiers unless federalized (i.e., activated) by the President. Since the Militia Act of 1903 all National Guard soldiers have held dual status: as National Guardsmen under the authority of the governor of their state or territory and, when activated, as a reserve of the U.S. Army under the authority of the President.\",\n", + " \"bla bla bla\",\n", + " \"bla bla bla\",\n", + " \"In 1968, Tito offered Czechoslovak leader Alexander Dub\\u010dek to fly to Prague on three hours notice if Dub\\u010dek needed help in facing down the Soviets. 
In April 1969, Tito removed generals Ivan Go\\u0161njak and Rade Hamovi\\u0107 in the aftermath of the invasion of Czechoslovakia due to the unpreparedness of the Yugoslav army to respond to a similar invasion of Yugoslavia.\",\n", + " ]\n", + "},\n", + "\n", + "# quick check from the script:\n", + "{\"id\": \"4hop1__337568_570923_833580_61459\", \n", + " \"question\": \"One of the actors in a Pound of Flesh is from a European county whose king Albert I lived during a major war. When did Italy enter that war?\", \"answer\": \"1915\", \"supporting_paragraphs\": [\"The Queen Elisabeth Medical Foundation (QEMF) is a Belgian non-profit organization, founded in 1926 by Elisabeth of Bavaria, wife of Albert I. She founded the organization, based on her experience with the wounded from the front-line during the First World War. The foundation wants to encourage laboratory research and contacts between researchers and clinical practitioners, with a particular focus on neurosciences. The QEMF supports seventeen university teams throughout Belgium.\", \"On 3 May 1915 Italy officially revoked the Triple Alliance. In the following days Giolitti and the neutralist majority of the Parliament opposed declaring war, while nationalist crowds demonstrated in public areas for it. (The nationalist poet Gabriele D'Annunzio called this period le radiose giornate di Maggio -- ``the sunny days of May ''). Giolitti had the support of the majority of Italian parliament so on 13 May Salandra offered his resignation to King Victor Emmanuel III, but then Giolitti learned that the London Pact was already signed: fearful of a conflict between the Crown and the Parliament and the consequences on both internal stability and foreign relationships, Giolitti accepted the fait accompli, declined to succeed as prime minister and Salandra's resignation was not accepted. On 23 May, Italy declared war on Austria - Hungary. This was followed by declarations of war on the Ottoman Empire (21 August 1915, following an ultimatum of 3 August), Bulgaria (19 October 1915) and the German Empire (28 August 1916).\", \"JCVD is a 2008 Belgian crime drama film directed by French Tunisian film director Mabrouk el Mechri, and starring Jean-Claude van Damme as a semi-fictionalized version of himself, a down and out action star whose family and career are crumbling around him as he is caught in the middle of a post office heist in his hometown of Brussels, Belgium.\", \"Pound of Flesh is a 2015 Canadian action thriller film directed by Ernie Barbarash, and starring Jean-Claude Van Damme and Darren Shahlavi. It is the third collaboration between Van Damme and Barbarash (following \\\"Assassination Games\\\" in 2011 and \\\"Six Bullets\\\" in 2012).\"]}\n", + "\n", + "\n", + "# also, need a script to extract all paragraphs from the original data (ensure they are unique) \n", + "# FROM ALL TRAIN DEV TEST SPLIT\n", + "# to a simple csv file like this. 
\n", + "# chunk_id should be incremental (good enough, since there is no chunk_id in the original data)\n", + "# title is the title of the paragraph\n", + "# content is the text of the paragraph\n", + "# metadata is a list of original question_id that this paragraph is supporting (just for backward compatibility with previous data format)\n", + "\"\"\"\n", + "chunk_id,content,metadata\n", + "1, Bla bla bla, bla bla bla, [2hop_xxx_xxx, ....]\n", + "\"\"\"\n", + "\n", + "# then the an faiss index will be generated from this csv file with intfloat/e5-base-v2 embedding model\n", + "\n", + "# remember the distribution of the number of hops?" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -927,8 +1034,7 @@ "question = \"\"\"\n", "Each vertex of a regular octagon is independently colored either red or blue with equal probability. The probability that the octagon can then be rotated so that all of the blue vertices end up at positions where there were originally red vertices is $\\tfrac{m}{n}$, where $m$ and $n$ are relatively prime positive integers. What is $m+n$?\n", "\"\"\"\n", - "print(pipe(question))\n", - "\n" + "print(pipe(question))\n" ] } ], diff --git a/pyproject.toml b/pyproject.toml index bc7d10c..3af8445 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,4 +39,5 @@ dependencies = [ "tqdm>=4.66.1", "tavily-python", "sglang[all]>=0.4.5", + "gdown", ] \ No newline at end of file diff --git a/scripts/check_data.py b/scripts/check_data.py new file mode 100644 index 0000000..606532e --- /dev/null +++ b/scripts/check_data.py @@ -0,0 +1,117 @@ +import json +import sys +from pathlib import Path + +import pandas as pd + +# Add project root to Python path for imports +project_root = Path(__file__).resolve().parent.parent +sys.path.append(str(project_root)) + +# Assuming these are defined in your project structure +from config import DATA_DIR, logger # Adjust import as needed +from src.embeddings import CustomHuggingFaceEmbeddings + +# Import FAISS after potentially adding to sys.path +try: + from langchain_community.vectorstores import FAISS + + faiss_installed = True +except ImportError: + print("Warning: langchain_community or faiss not installed. Cannot check FAISS index.") + faiss_installed = False + + +def check_output_files(processed_dir: Path): + """Prints head and tail of key processed files and FAISS index info. + + Args: + processed_dir: The path to the 'data/processed' directory. + """ + print("--- Checking Processed Files ---") + + # 1. Check paragraphs.csv + csv_path = processed_dir / "paragraphs.csv" + print(f"\n--- Checking {csv_path} ---") + try: + df = pd.read_csv(csv_path) + print("First 3 rows:") + print(df.head(3).to_string()) + print("\nLast 3 rows:") + print(df.tail(3).to_string()) + print(f"Total rows: {len(df)}") + except FileNotFoundError: + print(f"Error: {csv_path} not found.") + except Exception as e: + print(f"Error reading {csv_path}: {e}") + + # 2. 
Check questions.jsonl + jsonl_path = processed_dir / "questions.jsonl" + print(f"\n--- Checking {jsonl_path} ---") + try: + with open(jsonl_path, "r", encoding="utf-8") as f: + lines = f.readlines() + + num_lines = len(lines) + print(f"Total lines: {num_lines}") + + if num_lines > 0: + print("\nFirst 3 lines (parsed JSON):") + for i in range(min(3, num_lines)): + try: + print(json.loads(lines[i].strip())) + except json.JSONDecodeError: + print(f" (Error parsing line {i + 1})") + + if num_lines > 3: + print("\nLast 3 lines (parsed JSON):") + for i in range(max(0, num_lines - 3), num_lines): + try: + print(json.loads(lines[i].strip())) + except json.JSONDecodeError: + print(f" (Error parsing line {i + 1})") + elif num_lines > 0: + print("\n(Less than 6 lines total, showing all)") + + except FileNotFoundError: + print(f"Error: {jsonl_path} not found.") + except Exception as e: + print(f"Error reading {jsonl_path}: {e}") + + # 3. Check FAISS index + print(f"\n--- Checking FAISS Index in {processed_dir} ---") + if not faiss_installed: + print("Skipping FAISS check as required libraries are not installed.") + return + + # FAISS loads from the directory containing index.faiss and index.pkl + index_dir = processed_dir + index_file = index_dir / "index.faiss" + pkl_file = index_dir / "index.pkl" + + if not index_file.exists() or not pkl_file.exists(): + print(f"Error: FAISS index files (index.faiss, index.pkl) not found in {index_dir}") + return + + try: + print("Initializing embeddings model for loading index...") + embeddings = CustomHuggingFaceEmbeddings() + print("Loading FAISS index...") + # FAISS.load_local requires the folder_path and the embeddings object + vectorstore = FAISS.load_local(str(index_dir), embeddings, allow_dangerous_deserialization=True) + print("FAISS index loaded successfully.") + # Access the underlying FAISS index object to get the total number of vectors + print(f"Total vectors in index: {vectorstore.index.ntotal}") + except Exception as e: + print(f"Error loading or checking FAISS index from {index_dir}: {e}") + import traceback + + traceback.print_exc() + + print("\n--- Check Complete ---") + + +if __name__ == "__main__": + # Assuming the script is run from the project root or paths are relative + PROCESSED_PATH = Path("data/processed") + check_output_files(PROCESSED_PATH) diff --git a/scripts/train_data/build_musique_index.py b/scripts/train_data/build_musique_index.py new file mode 100644 index 0000000..0cb0106 --- /dev/null +++ b/scripts/train_data/build_musique_index.py @@ -0,0 +1,135 @@ +import json +import math # Import math for ceiling division +import sys +import traceback # Import traceback +from pathlib import Path + +import pandas as pd + +# Add project root to Python path if needed (adjust relative path as necessary) +project_root = Path(__file__).resolve().parent.parent +sys.path.append(str(project_root)) + +from src.embeddings import CustomHuggingFaceEmbeddings + +# Import FAISS after potentially adding to sys.path +try: + from langchain_community.vectorstores import FAISS +except ImportError: + print("Error: langchain_community or FAISS not installed. Please install with 'pip install langchain faiss-cpu'") + sys.exit(1) + + +def build_faiss_index_from_csv(csv_path: str, index_save_path: str, batch_size: int = 128) -> None: + """Builds a FAISS index from a CSV containing paragraph content and metadata. + + Reads a CSV file, generates embeddings for the 'content' column in batches, + and saves the FAISS index files (index.faiss, index.pkl) locally. 
+ + Args: + csv_path: Path to the input CSV file (e.g., data/processed/paragraphs.csv). + index_save_path: Path to the directory where the index files should be saved. + batch_size: Number of texts to process in each embedding batch. + """ + print(f"Loading paragraphs from {csv_path}") + try: + df = pd.read_csv(csv_path) + except FileNotFoundError: + print(f"Error: CSV file not found at {csv_path}. Please run the extraction script first.") + return + except Exception as e: + print(f"Error reading CSV file: {e}") + return + + if "content" not in df.columns or "metadata" not in df.columns: + print("Error: CSV file must contain 'content' and 'metadata' columns.") + return + + if df.empty: + print("Warning: Input CSV file is empty. No index will be built.") + return + + # Prepare documents for FAISS + texts = df["content"].astype(str).tolist() + metadatas = [] + try: + metadatas = [json.loads(m) for m in df["metadata"].tolist()] + print(f"Prepared {len(texts)} texts and {len(metadatas)} metadatas.") + except json.JSONDecodeError as e: + print(f"Error parsing metadata JSON: {e}. Check the format in {csv_path}") + traceback.print_exc() # Print traceback for JSON errors + return + except Exception as e: + print(f"Error processing metadata: {e}") + traceback.print_exc() # Print traceback for other metadata errors + return + + if not texts or not metadatas or len(texts) != len(metadatas): + print(f"Error: Mismatch or empty texts/metadatas. Texts: {len(texts)}, Metadatas: {len(metadatas)}") + return + + print("Initializing embeddings model...") + try: + embeddings = CustomHuggingFaceEmbeddings() + except Exception as e: + print(f"Error initializing embeddings model: {e}") + traceback.print_exc() + return + print("Embeddings model initialized successfully.") + + vectorstore = None + num_batches = math.ceil(len(texts) / batch_size) + print(f"Processing {len(texts)} texts in {num_batches} batches of size {batch_size}...") + + for i in range(num_batches): + start_idx = i * batch_size + end_idx = min((i + 1) * batch_size, len(texts)) + batch_texts = texts[start_idx:end_idx] + batch_metadatas = metadatas[start_idx:end_idx] + print(f" Processing batch {i + 1}/{num_batches} (indices {start_idx}-{end_idx - 1})...") + + try: + if i == 0: + # Initialize the vector store with the first batch + print(f" Initializing FAISS index with first batch...") + vectorstore = FAISS.from_texts(texts=batch_texts, embedding=embeddings, metadatas=batch_metadatas) + print(" FAISS index initialized.") + else: + # Add subsequent batches to the existing store + if vectorstore is None: + print("Error: vectorstore is None after first batch, cannot add more texts.") + return # Should not happen if first batch succeeded + print(f" Adding batch {i + 1} to FAISS index...") + vectorstore.add_texts(texts=batch_texts, metadatas=batch_metadatas) + print(f" Batch {i + 1} added.") + + except Exception as e: + print(f"Error processing batch {i + 1} (indices {start_idx}-{end_idx - 1}): {e}") + traceback.print_exc() + print("Stopping index creation due to error in batch processing.") + return # Exit if any batch fails + + if vectorstore is None: + print("Error: Failed to create or add any data to the vectorstore.") + return + + # Save the completed index + try: + print(f"Attempting to save final FAISS index files to directory: {index_save_path}") + # Ensure the target directory exists before saving + Path(index_save_path).mkdir(parents=True, exist_ok=True) + vectorstore.save_local(index_save_path) + print(f"Successfully saved final FAISS index 
files (index.faiss, index.pkl) to: {index_save_path}") + except Exception as e: + print(f"Error during final vectorstore.save_local to {index_save_path}: {e}") + traceback.print_exc() + + +if __name__ == "__main__": + # Define paths relative to this script or use absolute paths + PROCESSED_DIR = Path("data/processed") + INPUT_CSV = str(PROCESSED_DIR / "paragraphs.csv") + # FAISS save_local will save index.faiss and index.pkl in this directory + INDEX_SAVE_DIR = str(PROCESSED_DIR) # Save directly to processed dir + + build_faiss_index_from_csv(INPUT_CSV, INDEX_SAVE_DIR, batch_size=128) diff --git a/scripts/train_data/download_data_musique.sh b/scripts/train_data/download_data_musique.sh new file mode 100644 index 0000000..9860a68 --- /dev/null +++ b/scripts/train_data/download_data_musique.sh @@ -0,0 +1,30 @@ +#!/bin/bash +# This script is taken from https://github.com/StonyBrookNLP/musique with slight modifications + +set -e +set -x + +# If gdown doesn't work, you can download files from mentioned URLs manually +# and put them at appropriate locations. +pip install gdown + +ZIP_NAME="musique_v1.0.zip" + +# URL: https://drive.google.com/file/d/1tGdADlNjWFaHLeZZGShh2IRcpO6Lv24h/view?usp=sharing +gdown --id 1tGdADlNjWFaHLeZZGShh2IRcpO6Lv24h --output $ZIP_NAME + +TARGET_DIR="./data/raw" +mkdir -p $TARGET_DIR +unzip -o $(basename $ZIP_NAME) -d $TARGET_DIR # Extract directly into target + +# Move contents from the extracted 'data' folder up one level +mv $TARGET_DIR/data/* $TARGET_DIR/ + +# Clean up the empty directory and the zip +rm -rf $TARGET_DIR/data +rm $ZIP_NAME + +# TODO: prevent these from zipping in. +rm -rf __MACOSX +# Clean up potential extracted .DS_Store +rm -f $TARGET_DIR/.DS_Store diff --git a/scripts/train_data/extract_musique_paragraphs.py b/scripts/train_data/extract_musique_paragraphs.py new file mode 100644 index 0000000..8b4ef7c --- /dev/null +++ b/scripts/train_data/extract_musique_paragraphs.py @@ -0,0 +1,101 @@ +import json +import sys +from collections import defaultdict # Use defaultdict for cleaner accumulation +from pathlib import Path + +import pandas as pd + +# Add project root to Python path if needed (adjust relative path as necessary) +# project_root = Path(__file__).resolve().parent.parent +# sys.path.append(str(project_root)) +# from config import logger # Assuming you have a logger setup + + +def extract_unique_paragraphs(input_paths: list[str], output_csv_path: str) -> None: + """Extracts unique paragraphs from specified JSONL files. + + Reads Musique JSONL files (train, dev, test), finds unique paragraphs + (regardless of is_supporting flag), combines title and text, + tracks source question IDs, and saves to CSV. + + Args: + input_paths: A list of paths to the input JSONL files. + output_csv_path: Path to save the output CSV file. 
+ """ + output_dir = Path(output_csv_path).parent + output_dir.mkdir(parents=True, exist_ok=True) + + # Use paragraph content as key, value is the set of source question IDs + paragraphs_data = defaultdict(set) + print("Starting paragraph extraction (including non-supporting)...") + + for file_path in input_paths: + print(f"Processing file: {file_path}") + try: + with open(file_path, "r", encoding="utf-8") as infile: + for line_num, line in enumerate(infile, 1): + try: + data = json.loads(line) + main_question_id = data.get("id") + if not main_question_id: + print(f"Warning: Missing 'id' in line {line_num} of {file_path}") + continue + + for p in data.get("paragraphs", []): + title = p.get("title", "No Title") + text = p.get("paragraph_text", "") + content = f"{title}\n{text}".strip() + + if not content: + continue # Skip empty paragraphs + + paragraphs_data[content].add(main_question_id) + + except json.JSONDecodeError: + print(f"Warning: Skipping invalid JSON in line {line_num} of {file_path}") + except Exception as e: + print(f"Warning: Error processing line {line_num} in {file_path}: {e}") + except FileNotFoundError: + print(f"Error: Input file not found: {file_path}") + except Exception as e: + print(f"Error reading file {file_path}: {e}") + + print(f"Found {len(paragraphs_data)} unique paragraphs (supporting and non-supporting).") + + # Prepare data for DataFrame + output_list = [] + sorted_content = sorted(paragraphs_data.keys()) + for chunk_id, content in enumerate(sorted_content, 1): + question_ids = paragraphs_data[content] + metadata = {"source_question_ids": sorted(list(question_ids))} + output_list.append( + { + "chunk_id": chunk_id, + "content": content, + "metadata": json.dumps(metadata), # Store metadata as JSON string + } + ) + + if not output_list: + print("No paragraphs found to save.") + return + df = pd.DataFrame(output_list) + try: + df.to_csv(output_csv_path, index=False) + print(f"Successfully saved unique paragraphs to {output_csv_path}") + except Exception as e: + print(f"Error saving CSV file: {e}") + + +if __name__ == "__main__": + RAW_DIR = Path("data/raw") + PROCESSED_DIR = Path("data/processed") + + input_files = [ + str(RAW_DIR / "musique_ans_v1.0_train.jsonl"), + str(RAW_DIR / "musique_ans_v1.0_dev.jsonl"), + str(RAW_DIR / "musique_ans_v1.0_test.jsonl"), + ] + output_csv = str(PROCESSED_DIR / "paragraphs.csv") + + extract_unique_paragraphs(input_files, output_csv) diff --git a/scripts/train_data/prepare_musique_jsonl.py b/scripts/train_data/prepare_musique_jsonl.py new file mode 100644 index 0000000..74e41da --- /dev/null +++ b/scripts/train_data/prepare_musique_jsonl.py @@ -0,0 +1,172 @@ +import json +import math # Keep math import +import os +import re # Import re for parsing ID +from collections import defaultdict +from pathlib import Path + +# import random # No longer needed +# SEED = 42 # No longer needed +# random.seed(SEED) # No longer needed + + +def transform_musique_data(input_path: str, output_path: str, sample_config: dict) -> None: + """Transforms Musique data with deterministic stratified sampling using uniform selection from sorted lists. + + Reads data, categorizes by detailed hop type, sorts categories by ID, + selects N samples uniformly spaced from each sorted category, + combines, sorts final list by ID, and writes to output. + + Args: + input_path: Path to the input JSONL file. + output_path: Path to the output JSONL file. + sample_config: Dictionary specifying samples per detailed hop type (e.g., {"2hop": 400, "3hop1": 150, ...}). 
+ """ + output_dir = Path(output_path).parent + output_dir.mkdir(parents=True, exist_ok=True) + + print(f"Reading all data from {input_path} for sampling...") + all_data = [] + try: + with open(input_path, "r", encoding="utf-8") as infile: + for line_num, line in enumerate(infile, 1): + try: + data = json.loads(line) + if "id" in data: + all_data.append(data) + else: + print(f"Warning: Skipping line {line_num} due to missing 'id' field in {input_path}") + except json.JSONDecodeError: + print(f"Warning: Skipping invalid JSON in line {line_num} of {input_path}") + except FileNotFoundError: + print(f"Error: Input file not found at {input_path}") + return + except Exception as e: + print(f"Error reading file {input_path}: {e}") + return + print(f"Read {len(all_data)} total samples with IDs.") + + # Detailed Categorization by hop type + categorized_data = defaultdict(list) + print("Categorizing data by detailed hop type (e.g., 3hop1, 4hop2)...") + for data in all_data: + q_id = data["id"] + match = re.match(r"^(2hop|3hop[12]|4hop[123])__", q_id) + if match: + detailed_hop_type = match.group(1) + categorized_data[detailed_hop_type].append(data) + # else: # Optional: log if an ID doesn't match expected pattern + # print(f"Warning: ID {q_id} does not match expected hop pattern.") + + # Deterministic sampling using sorting and uniform index selection + final_sample_list = [] + total_target = sum(sample_config.values()) + print(f"Sampling deterministically via uniform selection from sorted lists to get {total_target} samples...") + # Check if all requested hop types exist in config + for hop_type in sample_config.keys(): + if hop_type not in categorized_data: + print(f"Warning: Hop type '{hop_type}' requested in config but not found in data.") + + for hop_type, target_count in sample_config.items(): + available_samples = categorized_data.get(hop_type, []) + current_count = len(available_samples) + print(f" {hop_type}: Found {current_count} samples, need {target_count}.") + + if current_count == 0: + continue + + # Sort the list for this category by ID + available_samples.sort(key=lambda x: x["id"]) + + selected_samples_for_hop = [] + if current_count < target_count: + print(f" Warning: Not enough samples for {hop_type}. 
Taking all {current_count} sorted samples.") + selected_samples_for_hop = available_samples + else: + # Select target_count indices spread uniformly across the available samples + print(f" Selecting {target_count} samples uniformly from {current_count}...") + # Calculate indices using integer interpretation of evenly spaced points + indices_to_take = [int(i * current_count / target_count) for i in range(target_count)] + # Ensure uniqueness in case of rounding issues with small numbers (though unlikely here) + indices_to_take = sorted(list(set(indices_to_take))) + # Adjust if rounding resulted in fewer than target_count unique indices + while len(indices_to_take) < target_count: + # This is a fallback, shouldn't happen if current_count >= target_count + # Add indices from the end if needed, avoiding duplicates + next_idx = indices_to_take[-1] + 1 + if next_idx < current_count and next_idx not in indices_to_take: + indices_to_take.append(next_idx) + else: # Should not be reachable if logic is sound + break + + # Select samples at the calculated indices + selected_samples_for_hop = [ + available_samples[idx] for idx in indices_to_take[:target_count] + ] # Ensure we take exactly target_count + + final_sample_list.extend(selected_samples_for_hop) + + print(f"Selected {len(final_sample_list)} samples in total.") + + # Sort the final combined list by ID for consistent output order + print("Sorting the final combined sample list by ID...") + final_sample_list.sort(key=lambda x: x["id"]) + + # Process and write the selected samples + print(f"Processing and writing {len(final_sample_list)} selected samples to {output_path}...") + count = 0 + try: + with open(output_path, "w", encoding="utf-8") as outfile: + for data in final_sample_list: + try: + supporting_paragraphs = [ + p["paragraph_text"] for p in data.get("paragraphs", []) if p.get("is_supporting", False) + ] + + main_answer = data.get("answer", "") + aliases = data.get("answer_aliases", []) + + all_answers = [main_answer] + (aliases if isinstance(aliases, list) else []) + valid_answers = [str(ans).strip() for ans in all_answers if ans and str(ans).strip()] + unique_valid_answers = list(set(valid_answers)) + + combined_answer_str = " OR ".join(unique_valid_answers) + + output_data = { + "id": data.get("id"), + "question": data.get("question"), + "answer": combined_answer_str, + "supporting_paragraphs": supporting_paragraphs, + } + outfile.write(json.dumps(output_data) + "\n") + count += 1 + except KeyError as e: + print(f"Skipping sample due to missing key {e}: {data.get('id')}") + print(f"Successfully processed and wrote {count} samples.") + except Exception as e: + print(f"An unexpected error occurred during writing: {e}") + + +if __name__ == "__main__": + # Define file paths + RAW_DIR = Path("data/raw") + PROCESSED_DIR = Path("data/processed") + + # Define detailed sampling configuration + SAMPLING_CONFIG = { + "2hop": 400, + "3hop1": 150, + "3hop2": 150, + "4hop1": 100, + "4hop2": 100, + "4hop3": 100, + } # Total = 1000 + + transform_musique_data( + str(RAW_DIR / "musique_ans_v1.0_train.jsonl"), str(PROCESSED_DIR / "questions.jsonl"), SAMPLING_CONFIG + ) + + print( + "\nMusique JSONL transformation and detailed deterministic sampling (uniform selection from sorted) complete." + ) + # Note: Dev/Test files are not processed by default with this sampling logic. 
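For readers skimming the patch, the deterministic sampling in transform_musique_data above boils down to one rule: sort each hop category by question id, then take evenly spaced indices. The sketch below restates that rule in isolation so it can be sanity-checked without running the full script; the helper name and the demo ids are illustrative only and are not part of this patch (the fallback while-loop for rounding collisions is omitted).

# Standalone sketch of the selection rule used in prepare_musique_jsonl.py.
def select_uniform(samples, target_count):
    samples = sorted(samples, key=lambda x: x["id"])  # sort the category by id
    n = len(samples)
    if n <= target_count:  # not enough samples -> take them all
        return samples
    # evenly spaced indices over the sorted list; the set() guards against rounding collisions
    indices = sorted({int(i * n / target_count) for i in range(target_count)})
    return [samples[i] for i in indices[:target_count]]

# Hypothetical ids: the same 3 of 7 entries come back on every run.
demo = [{"id": f"2hop__{k}"} for k in (101, 205, 233, 310, 412, 498, 577)]
print([d["id"] for d in select_uniform(demo, 3)])  # -> 2hop__101, 2hop__233, 2hop__412
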
diff --git a/src/embeddings.py b/src/embeddings.py index 90af333..ca6e5e2 100644 --- a/src/embeddings.py +++ b/src/embeddings.py @@ -3,7 +3,7 @@ from langchain.embeddings.base import Embeddings from transformers import AutoModel, AutoTokenizer # Set a default model here -DEFAULT_MODEL_NAME = "avsolatorio/NoInstruct-small-Embedding-v0" +DEFAULT_MODEL_NAME = "intfloat/e5-base-v2" # Changed model name class CustomHuggingFaceEmbeddings(Embeddings):
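
As a closing usage note: once `make prepare-all-musique` has populated data/processed/, the index can be smoke-tested interactively along the same lines as scripts/check_data.py. The snippet below is a sketch under that assumption; the query string is just an example question taken from the notebook, and similarity_search is the standard langchain_community FAISS retrieval call rather than anything added by this patch.

# Sketch: query the Musique FAISS index built by scripts/train_data/build_musique_index.py.
# Assumes index.faiss / index.pkl sit in data/processed/, as produced by `make prepare-musique-index`.
from pathlib import Path

from langchain_community.vectorstores import FAISS

from src.embeddings import CustomHuggingFaceEmbeddings  # now defaults to intfloat/e5-base-v2

index_dir = Path("data/processed")
embeddings = CustomHuggingFaceEmbeddings()
vectorstore = FAISS.load_local(str(index_dir), embeddings, allow_dangerous_deserialization=True)

# Top-3 paragraphs for an example question; metadata carries the source_question_ids
# written by extract_musique_paragraphs.py.
for doc in vectorstore.similarity_search("Where was Hana Mandlikova born?", k=3):
    print(doc.metadata.get("source_question_ids", []), doc.page_content[:80], "...")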