diff --git a/notebooks/250324_generate_data_anatomy.ipynb b/notebooks/250324_generate_data_anatomy.ipynb
index 0c252ee..008177d 100644
--- a/notebooks/250324_generate_data_anatomy.ipynb
+++ b/notebooks/250324_generate_data_anatomy.ipynb
@@ -31,16 +31,10 @@
"\n",
"sys.path.append(\"..\")\n",
"\n",
- "import json\n",
- "import os\n",
"import pickle\n",
- "import re\n",
- "from typing import Dict, List, Optional, Tuple\n",
"\n",
- "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
"\n",
"\n",
- "from langchain_community.document_loaders import UnstructuredMarkdownLoader\n",
"from langchain_community.vectorstores import FAISS\n",
"\n",
"from embeddings import CustomHuggingFaceEmbeddings"
@@ -101,7 +95,7 @@
"): # Ok cool, so this is much simpler than i expected!\n",
" print(f\"\\n--- Chunk {i + 1}/{len(chunks)} ---\")\n",
" print(chunk.page_content)\n",
- " print(\"-\" * 50)\n"
+ " print(\"-\" * 50)"
]
},
{
@@ -323,8 +317,6 @@
"outputs": [],
"source": [
"# Load the existing FAISS index\n",
- "from langchain_community.vectorstores import FAISS\n",
- "from embeddings import CustomHuggingFaceEmbeddings\n",
"\n",
"\n",
"# Load the paraphrased chunks\n",
@@ -365,9 +357,7 @@
"\n",
"# Save the updated vectorstore\n",
"existing_vectorstore.save_local(\"faiss_index_with_paraphrased\")\n",
- "print(\"Saved updated FAISS index to 'faiss_index_with_paraphrased'\")\n",
- "\n",
- "\n"
+ "print(\"Saved updated FAISS index to 'faiss_index_with_paraphrased'\")"
]
},
{
diff --git a/notebooks/250325_fak_you_chattemplate.ipynb b/notebooks/250325_fak_you_chattemplate.ipynb
new file mode 100644
index 0000000..7ba7a2f
--- /dev/null
+++ b/notebooks/250325_fak_you_chattemplate.ipynb
@@ -0,0 +1,125 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# LLama 3.1 \n",
+ "\n",
+ "\"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\\n\\nCutting Knowledge Date: December 2023\\nToday Date: 26 Jul 2024\\n\\n<|eot_id|><|start_header_id|>user<|end_header_id|>\\n\\nHello, how are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\nI'm doing great. How can I help you today?<|eot_id|><|start_header_id|>system<|end_header_id|>\\n\\nYou are a friendly chatbot who always responds in the style of a pirate<|eot_id|>\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from transformers import AutoTokenizer\n",
+ "\n",
+ "tokenizer = AutoTokenizer.from_pretrained(\"meta-llama/meta-Llama-3.1-8B-Instruct\")\n",
+ "chat = [\n",
+ " {\"role\": \"user\", \"content\": \"Hello, how are you?\"},\n",
+ " {\"role\": \"assistant\", \"content\": \"I'm doing great. How can I help you today?\"},\n",
+ " {\n",
+ " \"role\": \"system\",\n",
+ " \"content\": \"You are a friendly chatbot who always responds in the style of a pirate\",\n",
+ " },\n",
+ "]\n",
+ "\n",
+ "tokenizer.apply_chat_template(chat, tokenize=False)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Qwen 1.5B R1 Distill\n",
+ "\n",
+ "\"<|begin▁of▁sentence|>You are a friendly chatbot who always responds in the style of a pirate<|User|>Hello, how are you?<|Assistant|>I'm doing great. How can I help you today?<|end▁of▁sentence|>\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from transformers import AutoTokenizer\n",
+ "\n",
+ "tokenizer = AutoTokenizer.from_pretrained(\"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\")\n",
+ "chat = [\n",
+ " {\"role\": \"user\", \"content\": \"Hello, how are you?\"},\n",
+ " {\"role\": \"assistant\", \"content\": \"I'm doing great. How can I help you today?\"},\n",
+ " {\n",
+ " \"role\": \"system\",\n",
+ " \"content\": \"You are a friendly chatbot who always responds in the style of a pirate\",\n",
+ " },\n",
+ "]\n",
+ "\n",
+ "tokenizer.apply_chat_template(chat, tokenize=False)\n",
+ "tokenizer.apply_chat_template(chat, tokenize=False)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# ✅ Compare the two\n",
+ "\n",
+ "- \"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\\n\\nCutting Knowledge Date: December 2023\\nToday Date: 26 Jul 2024\\n\\n<|eot_id|><|start_header_id|>user<|end_header_id|>\\n\\n**Hello, how are you?**<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n**I'm doing great. How can I help you today**?<|eot_id|><|start_header_id|>system<|end_header_id|>\\n\\n**You are a friendly chatbot who always responds in the style of a pirate**<|eot_id|>\"\n",
+ "- \"<|begin▁of▁sentence|>**You are a friendly chatbot who always responds in the style of a pirate**<|User|>**Hello, how are you?**<|Assistant|>**I'm doing great. How can I help you today?**<|end▁of▁sentence|>\"\n",
+ "- Ok make sense now!, so the structure of r1-distil doesn't have closing tags for most of the tags\n"
+ ]
+ },
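+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sanity-check sketch (added for illustration, not part of the original run):\n",
+    "# print the raw Jinja chat template of each tokenizer so the structural\n",
+    "# difference noted above can be confirmed directly.\n",
+    "from transformers import AutoTokenizer\n",
+    "\n",
+    "for name in [\n",
+    "    \"meta-llama/meta-Llama-3.1-8B-Instruct\",\n",
+    "    \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\",\n",
+    "]:\n",
+    "    tok = AutoTokenizer.from_pretrained(name)\n",
+    "    print(f\"=== {name} ===\")\n",
+    "    print(tok.chat_template)"
+   ]
+  },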
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Just curious, did Alpha Maze touch anything with the chat template?\n",
+ "- Nope, as alpha maze task isn't as complicated as agent tool call stuffs, so it doesn't need to tweak the chat template"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "deepsearch-py311",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.11"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/notebooks/250325_inspect_qa_dataset.ipynb b/notebooks/250325_inspect_qa_dataset.ipynb
new file mode 100644
index 0000000..26c55e8
--- /dev/null
+++ b/notebooks/250325_inspect_qa_dataset.ipynb
@@ -0,0 +1,123 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import sys\n",
+ "\n",
+ "sys.path.append(\"..\")\n",
+ "\n",
+ "from src.search_module import get_qa_dataset\n",
+ "import random\n",
+ "\n",
+ "\n",
+ "def inspect_qa_dataset():\n",
+ " \"\"\"Inspect the QA dataset used for evaluation to identify potential issues\"\"\"\n",
+ "\n",
+ " # Get the datasets\n",
+ " train_dataset, test_dataset = get_qa_dataset()\n",
+ "\n",
+ " # Print dataset statistics\n",
+ " print(f\"Train dataset size: {len(train_dataset)}\")\n",
+ " print(f\"Test dataset size: {len(test_dataset)}\")\n",
+ "\n",
+ " # Print column information\n",
+ " print(f\"\\nTest dataset columns: {test_dataset.column_names}\")\n",
+ "\n",
+ " # Print a few random examples\n",
+ " sample_size = min(5, len(test_dataset))\n",
+ " sample_indices = random.sample(range(len(test_dataset)), sample_size)\n",
+ "\n",
+ " print(f\"\\n--- {sample_size} Random Test Examples ---\")\n",
+ " for i, idx in enumerate(sample_indices):\n",
+ " example = test_dataset[idx]\n",
+ " print(f\"\\nExample {i+1}:\")\n",
+ " print(f\"Prompt: {example['prompt']}\")\n",
+ " print(f\"Answer: {example['answer']}\")\n",
+ " if \"chunk_content\" in example:\n",
+ " print(f\"Chunk Content: {example['chunk_content'][:200]}... (truncated)\")\n",
+ "\n",
+ " # Check for potential issues\n",
+ " print(\"\\n--- Dataset Analysis ---\")\n",
+ "\n",
+ " # Check for duplicate questions\n",
+ " prompts = test_dataset[\"prompt\"]\n",
+ " duplicate_count = len(prompts) - len(set(prompts))\n",
+ " print(f\"Duplicate prompts: {duplicate_count}\")\n",
+ "\n",
+ " # Check answer length distribution\n",
+ " answer_lengths = [len(ans) for ans in test_dataset[\"answer\"]]\n",
+ " avg_answer_length = sum(answer_lengths) / len(answer_lengths)\n",
+ " min_answer_length = min(answer_lengths)\n",
+ " max_answer_length = max(answer_lengths)\n",
+ " print(\n",
+ " f\"Answer length stats: min={min_answer_length}, avg={avg_answer_length:.1f}, max={max_answer_length}\"\n",
+ " )\n",
+ "\n",
+ " # Analyze prompt types if possible\n",
+ " if len(prompts) > 0:\n",
+ " qa_count = sum(1 for p in prompts if p.endswith(\"?\"))\n",
+ " print(\n",
+ " f\"Questions ending with '?': {qa_count} ({qa_count/len(prompts)*100:.1f}%)\"\n",
+ " )\n",
+ "\n",
+ "\n",
+ "if __name__ == \"__main__\":\n",
+ " inspect_qa_dataset()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "\n",
+ "train_dataset, test_dataset = get_qa_dataset()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Sanity check 32 test cases: -> 31/32 is correct, nothing wrong with the test data here :/\n",
+ "\n",
+ "brow wtf is happening 😭"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "test_dataset"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "deepsearch-py311",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.11"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/notebooks/250325_visualize_reward_function.ipynb b/notebooks/250325_visualize_reward_function.ipynb
index 04c691a..80b14ed 100644
--- a/notebooks/250325_visualize_reward_function.ipynb
+++ b/notebooks/250325_visualize_reward_function.ipynb
@@ -18,73 +18,95 @@
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
+ "\n",
"def plot_reward_functions():\n",
" # Generate retry counts from 0 to 15\n",
" retries = np.linspace(0, 15, 100)\n",
- " \n",
+ "\n",
" # 1. Basic Sigmoid\n",
" basic_sigmoid = 1 / (1 + np.exp(-(retries - 4)))\n",
- " \n",
+ "\n",
" # 2. Our Modified Sigmoid\n",
" x = retries - 4 # Center at 4 retries\n",
- " modified_sigmoid = 1 / (1 + np.exp(-x + abs(x)/2))\n",
- " \n",
+ " modified_sigmoid = 1 / (1 + np.exp(-x + abs(x) / 2))\n",
+ "\n",
" # 3. With Penalty\n",
" penalized_reward = modified_sigmoid.copy()\n",
" for i, r in enumerate(retries):\n",
" if r > 6:\n",
" penalty = 0.2 * (r - 6)\n",
" penalized_reward[i] = max(0.1, modified_sigmoid[i] - penalty)\n",
- " \n",
+ "\n",
" # Plotting\n",
" plt.figure(figsize=(12, 6))\n",
- " \n",
- " plt.plot(retries, basic_sigmoid, 'b--', label='Basic Sigmoid')\n",
- " plt.plot(retries, modified_sigmoid, 'g--', label='Modified Sigmoid')\n",
- " plt.plot(retries, penalized_reward, 'r-', label='Final Reward (with penalty)', linewidth=2)\n",
- " \n",
+ "\n",
+ " plt.plot(retries, basic_sigmoid, \"b--\", label=\"Basic Sigmoid\")\n",
+ " plt.plot(retries, modified_sigmoid, \"g--\", label=\"Modified Sigmoid\")\n",
+ " plt.plot(\n",
+ " retries,\n",
+ " penalized_reward,\n",
+ " \"r-\",\n",
+ " label=\"Final Reward (with penalty)\",\n",
+ " linewidth=2,\n",
+ " )\n",
+ "\n",
" # Add vertical lines for key points\n",
- " plt.axvline(x=4, color='gray', linestyle=':', alpha=0.5, label='Peak (4 retries)')\n",
- " plt.axvline(x=6, color='gray', linestyle=':', alpha=0.5, label='Penalty Start (6 retries)')\n",
- " \n",
+ " plt.axvline(x=4, color=\"gray\", linestyle=\":\", alpha=0.5, label=\"Peak (4 retries)\")\n",
+ " plt.axvline(\n",
+ " x=6, color=\"gray\", linestyle=\":\", alpha=0.5, label=\"Penalty Start (6 retries)\"\n",
+ " )\n",
+ "\n",
" plt.grid(True, alpha=0.3)\n",
- " plt.xlabel('Number of Retries')\n",
- " plt.ylabel('Reward')\n",
- " plt.title('Reward Function Visualization')\n",
+ " plt.xlabel(\"Number of Retries\")\n",
+ " plt.ylabel(\"Reward\")\n",
+ " plt.title(\"Reward Function Visualization\")\n",
" plt.legend()\n",
" plt.ylim(-0.1, 1.1)\n",
- " \n",
+ "\n",
" # Add annotations\n",
- " plt.annotate('Optimal Zone', xy=(4, 0.8), xytext=(4, 0.9),\n",
- " ha='center', va='bottom',\n",
- " bbox=dict(boxstyle='round,pad=0.5', fc='yellow', alpha=0.3),\n",
- " arrowprops=dict(arrowstyle='->'))\n",
- " \n",
- " plt.annotate('Penalty Zone', xy=(8, 0.3), xytext=(8, 0.5),\n",
- " ha='center', va='bottom',\n",
- " bbox=dict(boxstyle='round,pad=0.5', fc='red', alpha=0.3),\n",
- " arrowprops=dict(arrowstyle='->'))\n",
- " \n",
+ " plt.annotate(\n",
+ " \"Optimal Zone\",\n",
+ " xy=(4, 0.8),\n",
+ " xytext=(4, 0.9),\n",
+ " ha=\"center\",\n",
+ " va=\"bottom\",\n",
+ " bbox=dict(boxstyle=\"round,pad=0.5\", fc=\"yellow\", alpha=0.3),\n",
+ " arrowprops=dict(arrowstyle=\"->\"),\n",
+ " )\n",
+ "\n",
+ " plt.annotate(\n",
+ " \"Penalty Zone\",\n",
+ " xy=(8, 0.3),\n",
+ " xytext=(8, 0.5),\n",
+ " ha=\"center\",\n",
+ " va=\"bottom\",\n",
+ " bbox=dict(boxstyle=\"round,pad=0.5\", fc=\"red\", alpha=0.3),\n",
+ " arrowprops=dict(arrowstyle=\"->\"),\n",
+ " )\n",
+ "\n",
" plt.show()\n",
"\n",
+ "\n",
"# Run the visualization\n",
"plot_reward_functions()\n",
"\n",
+ "\n",
"# Print reward values for specific retry counts\n",
"def print_reward_examples():\n",
" retry_examples = [1, 2, 3, 4, 5, 6, 7, 8, 10, 12]\n",
" print(\"\\nReward values for different retry counts:\")\n",
" print(\"Retries | Reward\")\n",
" print(\"-\" * 20)\n",
- " \n",
+ "\n",
" for retries in retry_examples:\n",
" x = retries - 4\n",
- " reward = 1 / (1 + np.exp(-x + abs(x)/2))\n",
+ " reward = 1 / (1 + np.exp(-x + abs(x) / 2))\n",
" if retries > 6:\n",
" penalty = 0.2 * (retries - 6)\n",
" reward = max(0.1, reward - penalty)\n",
" print(f\"{retries:7d} | {reward:.3f}\")\n",
"\n",
+ "\n",
"print_reward_examples()"
]
}
diff --git a/notebooks/Llama3_1_(8B)_GRPO.ipynb b/notebooks/Llama3_1_(8B)_GRPO.ipynb
index 5bb6f80..05b0969 100644
--- a/notebooks/Llama3_1_(8B)_GRPO.ipynb
+++ b/notebooks/Llama3_1_(8B)_GRPO.ipynb
@@ -36,7 +36,9 @@
" !pip install -q --no-deps unsloth vllm\n",
" # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]\n",
" # Skip restarting message in Colab\n",
- " import sys, re, requests\n",
+ " import sys\n",
+ " import re\n",
+ " import requests\n",
"\n",
" modules = list(sys.modules.keys())\n",
" for x in modules:\n",
@@ -197,7 +199,6 @@
"outputs": [],
"source": [
"from unsloth import FastLanguageModel\n",
- "import torch\n",
"\n",
"max_seq_length = 1024 # Can increase for longer reasoning traces\n",
"lora_rank = 32 # Larger rank = smarter, but slower\n",
diff --git a/notebooks/Qwen2_5_(3B)_GRPO.ipynb b/notebooks/Qwen2_5_(3B)_GRPO.ipynb
index 4735d8e..8bac994 100644
--- a/notebooks/Qwen2_5_(3B)_GRPO.ipynb
+++ b/notebooks/Qwen2_5_(3B)_GRPO.ipynb
@@ -57,6 +57,7 @@
"source": [
"%%capture\n",
"import os\n",
+ "\n",
"if \"COLAB_\" not in \"\".join(os.environ.keys()):\n",
" !pip install unsloth vllm\n",
"else:\n",
@@ -72,22 +73,30 @@
},
"outputs": [],
"source": [
- "#@title Colab Extra Install { display-mode: \"form\" }\n",
+ "# @title Colab Extra Install { display-mode: \"form\" }\n",
"%%capture\n",
"import os\n",
+ "\n",
"if \"COLAB_\" not in \"\".join(os.environ.keys()):\n",
" !pip install unsloth vllm\n",
"else:\n",
" !pip install --no-deps unsloth vllm\n",
" # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]\n",
" # Skip restarting message in Colab\n",
- " import sys, re, requests; modules = list(sys.modules.keys())\n",
- " for x in modules: sys.modules.pop(x) if \"PIL\" in x or \"google\" in x else None\n",
+ " import sys\n",
+ " import re\n",
+ " import requests\n",
+ "\n",
+ " modules = list(sys.modules.keys())\n",
+ " for x in modules:\n",
+ " sys.modules.pop(x) if \"PIL\" in x or \"google\" in x else None\n",
" !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo\n",
" !pip install sentencepiece protobuf datasets huggingface_hub hf_transfer\n",
"\n",
" # vLLM requirements - vLLM breaks Colab due to reinstalling numpy\n",
- " f = requests.get(\"https://raw.githubusercontent.com/vllm-project/vllm/refs/heads/main/requirements/common.txt\").content\n",
+ " f = requests.get(\n",
+ " \"https://raw.githubusercontent.com/vllm-project/vllm/refs/heads/main/requirements/common.txt\"\n",
+ " ).content\n",
" with open(\"vllm_requirements.txt\", \"wb\") as file:\n",
" file.write(re.sub(rb\"(transformers|numpy|xformers)[^\\n]{1,}\\n\", b\"\", f))\n",
" !pip install -r vllm_requirements.txt"
@@ -303,29 +312,34 @@
"outputs": [],
"source": [
"from unsloth import FastLanguageModel, is_bfloat16_supported\n",
- "import torch\n",
- "max_seq_length = 1024 # Can increase for longer reasoning traces\n",
- "lora_rank = 64 # Larger rank = smarter, but slower\n",
+ "\n",
+ "max_seq_length = 1024 # Can increase for longer reasoning traces\n",
+ "lora_rank = 64 # Larger rank = smarter, but slower\n",
"\n",
"model, tokenizer = FastLanguageModel.from_pretrained(\n",
- " model_name = \"Qwen/Qwen2.5-3B-Instruct\",\n",
- " max_seq_length = max_seq_length,\n",
- " load_in_4bit = True, # False for LoRA 16bit\n",
- " fast_inference = True, # Enable vLLM fast inference\n",
- " max_lora_rank = lora_rank,\n",
- " gpu_memory_utilization = 0.5, # Reduce if out of memory\n",
+ " model_name=\"Qwen/Qwen2.5-3B-Instruct\",\n",
+ " max_seq_length=max_seq_length,\n",
+ " load_in_4bit=True, # False for LoRA 16bit\n",
+ " fast_inference=True, # Enable vLLM fast inference\n",
+ " max_lora_rank=lora_rank,\n",
+ " gpu_memory_utilization=0.5, # Reduce if out of memory\n",
")\n",
"\n",
"model = FastLanguageModel.get_peft_model(\n",
" model,\n",
- " r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128\n",
- " target_modules = [\n",
- " \"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
- " \"gate_proj\", \"up_proj\", \"down_proj\",\n",
- " ], # Remove QKVO if out of memory\n",
- " lora_alpha = lora_rank,\n",
- " use_gradient_checkpointing = \"unsloth\", # Enable long context finetuning\n",
- " random_state = 3407,\n",
+ " r=lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128\n",
+ " target_modules=[\n",
+ " \"q_proj\",\n",
+ " \"k_proj\",\n",
+ " \"v_proj\",\n",
+ " \"o_proj\",\n",
+ " \"gate_proj\",\n",
+ " \"up_proj\",\n",
+ " \"down_proj\",\n",
+ " ], # Remove QKVO if out of memory\n",
+ " lora_alpha=lora_rank,\n",
+ " use_gradient_checkpointing=\"unsloth\", # Enable long context finetuning\n",
+ " random_state=3407,\n",
")"
]
},
@@ -445,43 +459,58 @@
"\n",
"\"\"\"\n",
"\n",
+ "\n",
"def extract_xml_answer(text: str) -> str:\n",
" answer = text.split(\"\")[-1]\n",
" answer = answer.split(\"\")[0]\n",
" return answer.strip()\n",
"\n",
+ "\n",
"def extract_hash_answer(text: str) -> str | None:\n",
" if \"####\" not in text:\n",
" return None\n",
" return text.split(\"####\")[1].strip()\n",
"\n",
+ "\n",
"# uncomment middle messages for 1-shot prompting\n",
- "def get_gsm8k_questions(split = \"train\") -> Dataset:\n",
- " data = load_dataset('openai/gsm8k', 'main')[split] # type: ignore\n",
- " data = data.map(lambda x: { # type: ignore\n",
- " 'prompt': [\n",
- " {'role': 'system', 'content': SYSTEM_PROMPT},\n",
- " {'role': 'user', 'content': x['question']}\n",
- " ],\n",
- " 'answer': extract_hash_answer(x['answer'])\n",
- " }) # type: ignore\n",
- " return data # type: ignore\n",
+ "def get_gsm8k_questions(split=\"train\") -> Dataset:\n",
+ " data = load_dataset(\"openai/gsm8k\", \"main\")[split] # type: ignore\n",
+ " data = data.map(\n",
+ " lambda x: { # type: ignore\n",
+ " \"prompt\": [\n",
+ " {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
+ " {\"role\": \"user\", \"content\": x[\"question\"]},\n",
+ " ],\n",
+ " \"answer\": extract_hash_answer(x[\"answer\"]),\n",
+ " }\n",
+ " ) # type: ignore\n",
+ " return data # type: ignore\n",
+ "\n",
"\n",
"dataset = get_gsm8k_questions()\n",
"\n",
+ "\n",
"# Reward functions\n",
"def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:\n",
- " responses = [completion[0]['content'] for completion in completions]\n",
- " q = prompts[0][-1]['content']\n",
+ " responses = [completion[0][\"content\"] for completion in completions]\n",
+ " q = prompts[0][-1][\"content\"]\n",
" extracted_responses = [extract_xml_answer(r) for r in responses]\n",
- " print('-'*20, f\"Question:\\n{q}\", f\"\\nAnswer:\\n{answer[0]}\", f\"\\nResponse:\\n{responses[0]}\", f\"\\nExtracted:\\n{extracted_responses[0]}\")\n",
+ " print(\n",
+ " \"-\" * 20,\n",
+ " f\"Question:\\n{q}\",\n",
+ " f\"\\nAnswer:\\n{answer[0]}\",\n",
+ " f\"\\nResponse:\\n{responses[0]}\",\n",
+ " f\"\\nExtracted:\\n{extracted_responses[0]}\",\n",
+ " )\n",
" return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]\n",
"\n",
+ "\n",
"def int_reward_func(completions, **kwargs) -> list[float]:\n",
- " responses = [completion[0]['content'] for completion in completions]\n",
+ " responses = [completion[0][\"content\"] for completion in completions]\n",
" extracted_responses = [extract_xml_answer(r) for r in responses]\n",
" return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]\n",
"\n",
+ "\n",
"def strict_format_reward_func(completions, **kwargs) -> list[float]:\n",
" \"\"\"Reward function that checks if the completion has a specific format.\"\"\"\n",
" pattern = r\"^\\n.*?\\n\\n\\n.*?\\n\\n$\"\n",
@@ -489,6 +518,7 @@
" matches = [re.match(pattern, r) for r in responses]\n",
" return [0.5 if match else 0.0 for match in matches]\n",
"\n",
+ "\n",
"def soft_format_reward_func(completions, **kwargs) -> list[float]:\n",
" \"\"\"Reward function that checks if the completion has a specific format.\"\"\"\n",
" pattern = r\".*?\\s*.*?\"\n",
@@ -496,6 +526,7 @@
" matches = [re.match(pattern, r) for r in responses]\n",
" return [0.5 if match else 0.0 for match in matches]\n",
"\n",
+ "\n",
"def count_xml(text) -> float:\n",
" count = 0.0\n",
" if text.count(\"\\n\") == 1:\n",
@@ -504,12 +535,13 @@
" count += 0.125\n",
" if text.count(\"\\n\\n\") == 1:\n",
" count += 0.125\n",
- " count -= len(text.split(\"\\n\\n\")[-1])*0.001\n",
+ " count -= len(text.split(\"\\n\\n\")[-1]) * 0.001\n",
" if text.count(\"\\n\") == 1:\n",
" count += 0.125\n",
- " count -= (len(text.split(\"\\n\")[-1]) - 1)*0.001\n",
+ " count -= (len(text.split(\"\\n\")[-1]) - 1) * 0.001\n",
" return count\n",
"\n",
+ "\n",
"def xmlcount_reward_func(completions, **kwargs) -> list[float]:\n",
" contents = [completion[0][\"content\"] for completion in completions]\n",
" return [count_xml(c) for c in contents]"
@@ -540,29 +572,30 @@
"outputs": [],
"source": [
"from trl import GRPOConfig, GRPOTrainer\n",
+ "\n",
"training_args = GRPOConfig(\n",
- " use_vllm = True, # use vLLM for fast inference!\n",
- " learning_rate = 5e-6,\n",
- " adam_beta1 = 0.9,\n",
- " adam_beta2 = 0.99,\n",
- " weight_decay = 0.1,\n",
- " warmup_ratio = 0.1,\n",
- " lr_scheduler_type = \"cosine\",\n",
- " optim = \"adamw_8bit\",\n",
- " logging_steps = 1,\n",
- " bf16 = is_bfloat16_supported(),\n",
- " fp16 = not is_bfloat16_supported(),\n",
- " per_device_train_batch_size = 1,\n",
- " gradient_accumulation_steps = 1, # Increase to 4 for smoother training\n",
- " num_generations = 8, # Decrease if out of memory\n",
- " max_prompt_length = 256,\n",
- " max_completion_length = 200,\n",
+ " use_vllm=True, # use vLLM for fast inference!\n",
+ " learning_rate=5e-6,\n",
+ " adam_beta1=0.9,\n",
+ " adam_beta2=0.99,\n",
+ " weight_decay=0.1,\n",
+ " warmup_ratio=0.1,\n",
+ " lr_scheduler_type=\"cosine\",\n",
+ " optim=\"adamw_8bit\",\n",
+ " logging_steps=1,\n",
+ " bf16=is_bfloat16_supported(),\n",
+ " fp16=not is_bfloat16_supported(),\n",
+ " per_device_train_batch_size=1,\n",
+ " gradient_accumulation_steps=1, # Increase to 4 for smoother training\n",
+ " num_generations=8, # Decrease if out of memory\n",
+ " max_prompt_length=256,\n",
+ " max_completion_length=200,\n",
" # num_train_epochs = 1, # Set to 1 for a full training run\n",
- " max_steps = 250,\n",
- " save_steps = 250,\n",
- " max_grad_norm = 0.1,\n",
- " report_to = \"none\", # Can use Weights & Biases\n",
- " output_dir = \"outputs\",\n",
+ " max_steps=250,\n",
+ " save_steps=250,\n",
+ " max_grad_norm=0.1,\n",
+ " report_to=\"none\", # Can use Weights & Biases\n",
+ " output_dir=\"outputs\",\n",
")"
]
},
@@ -597,17 +630,17 @@
"outputs": [],
"source": [
"trainer = GRPOTrainer(\n",
- " model = model,\n",
- " processing_class = tokenizer,\n",
- " reward_funcs = [\n",
+ " model=model,\n",
+ " processing_class=tokenizer,\n",
+ " reward_funcs=[\n",
" xmlcount_reward_func,\n",
" soft_format_reward_func,\n",
" strict_format_reward_func,\n",
" int_reward_func,\n",
" correctness_reward_func,\n",
" ],\n",
- " args = training_args,\n",
- " train_dataset = dataset,\n",
+ " args=training_args,\n",
+ " train_dataset=dataset,\n",
")\n",
"trainer.train()"
]
@@ -636,21 +669,30 @@
},
"outputs": [],
"source": [
- "text = tokenizer.apply_chat_template([\n",
- " {\"role\" : \"user\", \"content\" : \"How many r's are in strawberry?\"},\n",
- "], tokenize = False, add_generation_prompt = True)\n",
+ "text = tokenizer.apply_chat_template(\n",
+ " [\n",
+ " {\"role\": \"user\", \"content\": \"How many r's are in strawberry?\"},\n",
+ " ],\n",
+ " tokenize=False,\n",
+ " add_generation_prompt=True,\n",
+ ")\n",
"\n",
"from vllm import SamplingParams\n",
+ "\n",
"sampling_params = SamplingParams(\n",
- " temperature = 0.8,\n",
- " top_p = 0.95,\n",
- " max_tokens = 1024,\n",
+ " temperature=0.8,\n",
+ " top_p=0.95,\n",
+ " max_tokens=1024,\n",
+ ")\n",
+ "output = (\n",
+ " model.fast_generate(\n",
+ " [text],\n",
+ " sampling_params=sampling_params,\n",
+ " lora_request=None,\n",
+ " )[0]\n",
+ " .outputs[0]\n",
+ " .text\n",
")\n",
- "output = model.fast_generate(\n",
- " [text],\n",
- " sampling_params = sampling_params,\n",
- " lora_request = None,\n",
- ")[0].outputs[0].text\n",
"\n",
"output"
]
@@ -697,22 +739,31 @@
},
"outputs": [],
"source": [
- "text = tokenizer.apply_chat_template([\n",
- " {\"role\" : \"system\", \"content\" : SYSTEM_PROMPT},\n",
- " {\"role\" : \"user\", \"content\" : \"How many r's are in strawberry?\"},\n",
- "], tokenize = False, add_generation_prompt = True)\n",
+ "text = tokenizer.apply_chat_template(\n",
+ " [\n",
+ " {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
+ " {\"role\": \"user\", \"content\": \"How many r's are in strawberry?\"},\n",
+ " ],\n",
+ " tokenize=False,\n",
+ " add_generation_prompt=True,\n",
+ ")\n",
"\n",
"from vllm import SamplingParams\n",
+ "\n",
"sampling_params = SamplingParams(\n",
- " temperature = 0.8,\n",
- " top_p = 0.95,\n",
- " max_tokens = 1024,\n",
+ " temperature=0.8,\n",
+ " top_p=0.95,\n",
+ " max_tokens=1024,\n",
+ ")\n",
+ "output = (\n",
+ " model.fast_generate(\n",
+ " text,\n",
+ " sampling_params=sampling_params,\n",
+ " lora_request=model.load_lora(\"grpo_saved_lora\"),\n",
+ " )[0]\n",
+ " .outputs[0]\n",
+ " .text\n",
")\n",
- "output = model.fast_generate(\n",
- " text,\n",
- " sampling_params = sampling_params,\n",
- " lora_request = model.load_lora(\"grpo_saved_lora\"),\n",
- ")[0].outputs[0].text\n",
"\n",
"output"
]
@@ -747,16 +798,36 @@
"outputs": [],
"source": [
"# Merge to 16bit\n",
- "if False: model.save_pretrained_merged(\"model\", tokenizer, save_method = \"merged_16bit\",)\n",
- "if False: model.push_to_hub_merged(\"hf/model\", tokenizer, save_method = \"merged_16bit\", token = \"\")\n",
+ "if False:\n",
+ " model.save_pretrained_merged(\n",
+ " \"model\",\n",
+ " tokenizer,\n",
+ " save_method=\"merged_16bit\",\n",
+ " )\n",
+ "if False:\n",
+ " model.push_to_hub_merged(\n",
+ " \"hf/model\", tokenizer, save_method=\"merged_16bit\", token=\"\"\n",
+ " )\n",
"\n",
"# Merge to 4bit\n",
- "if False: model.save_pretrained_merged(\"model\", tokenizer, save_method = \"merged_4bit\",)\n",
- "if False: model.push_to_hub_merged(\"hf/model\", tokenizer, save_method = \"merged_4bit\", token = \"\")\n",
+ "if False:\n",
+ " model.save_pretrained_merged(\n",
+ " \"model\",\n",
+ " tokenizer,\n",
+ " save_method=\"merged_4bit\",\n",
+ " )\n",
+ "if False:\n",
+ " model.push_to_hub_merged(\"hf/model\", tokenizer, save_method=\"merged_4bit\", token=\"\")\n",
"\n",
"# Just LoRA adapters\n",
- "if False: model.save_pretrained_merged(\"model\", tokenizer, save_method = \"lora\",)\n",
- "if False: model.push_to_hub_merged(\"hf/model\", tokenizer, save_method = \"lora\", token = \"\")"
+ "if False:\n",
+ " model.save_pretrained_merged(\n",
+ " \"model\",\n",
+ " tokenizer,\n",
+ " save_method=\"lora\",\n",
+ " )\n",
+ "if False:\n",
+ " model.push_to_hub_merged(\"hf/model\", tokenizer, save_method=\"lora\", token=\"\")"
]
},
{
@@ -785,26 +856,41 @@
"outputs": [],
"source": [
"# Save to 8bit Q8_0\n",
- "if False: model.save_pretrained_gguf(\"model\", tokenizer,)\n",
+ "if False:\n",
+ " model.save_pretrained_gguf(\n",
+ " \"model\",\n",
+ " tokenizer,\n",
+ " )\n",
"# Remember to go to https://huggingface.co/settings/tokens for a token!\n",
"# And change hf to your username!\n",
- "if False: model.push_to_hub_gguf(\"hf/model\", tokenizer, token = \"\")\n",
+ "if False:\n",
+ " model.push_to_hub_gguf(\"hf/model\", tokenizer, token=\"\")\n",
"\n",
"# Save to 16bit GGUF\n",
- "if False: model.save_pretrained_gguf(\"model\", tokenizer, quantization_method = \"f16\")\n",
- "if False: model.push_to_hub_gguf(\"hf/model\", tokenizer, quantization_method = \"f16\", token = \"\")\n",
+ "if False:\n",
+ " model.save_pretrained_gguf(\"model\", tokenizer, quantization_method=\"f16\")\n",
+ "if False:\n",
+ " model.push_to_hub_gguf(\"hf/model\", tokenizer, quantization_method=\"f16\", token=\"\")\n",
"\n",
"# Save to q4_k_m GGUF\n",
- "if False: model.save_pretrained_gguf(\"model\", tokenizer, quantization_method = \"q4_k_m\")\n",
- "if False: model.push_to_hub_gguf(\"hf/model\", tokenizer, quantization_method = \"q4_k_m\", token = \"\")\n",
+ "if False:\n",
+ " model.save_pretrained_gguf(\"model\", tokenizer, quantization_method=\"q4_k_m\")\n",
+ "if False:\n",
+ " model.push_to_hub_gguf(\n",
+ " \"hf/model\", tokenizer, quantization_method=\"q4_k_m\", token=\"\"\n",
+ " )\n",
"\n",
"# Save to multiple GGUF options - much faster if you want multiple!\n",
"if False:\n",
" model.push_to_hub_gguf(\n",
- " \"hf/model\", # Change hf to your username!\n",
+ " \"hf/model\", # Change hf to your username!\n",
" tokenizer,\n",
- " quantization_method = [\"q4_k_m\", \"q8_0\", \"q5_k_m\",],\n",
- " token = \"\",\n",
+ " quantization_method=[\n",
+ " \"q4_k_m\",\n",
+ " \"q8_0\",\n",
+ " \"q5_k_m\",\n",
+ " ],\n",
+ " token=\"\",\n",
" )"
]
},
diff --git a/notebooks/README.md b/notebooks/README.md
new file mode 100644
index 0000000..f3c905f
--- /dev/null
+++ b/notebooks/README.md
@@ -0,0 +1,14 @@
+# Clean Notebook Outputs 101
+
+```
+pip install nbstripout
+nbstripout --install --global
+nbstripout --status
+```
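+
+The commands above wire nbstripout into your global git config, so notebook outputs are stripped from commits in every repo on this machine. A per-repo setup is also possible; a minimal sketch, assuming a recent nbstripout that supports the `--attributes` flag:
+
+```
+nbstripout --install --attributes .gitattributes  # run inside the repo; writes to .gitattributes
+nbstripout --status                               # verify the filter is active
+```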
diff --git a/notebooks/train_autodidact.ipynb b/notebooks/train_autodidact.ipynb
index ff33f4d..9221516 100644
--- a/notebooks/train_autodidact.ipynb
+++ b/notebooks/train_autodidact.ipynb
@@ -61,7 +61,6 @@
"outputs": [],
"source": [
"from unsloth import is_bfloat16_supported\n",
- "import torch\n",
"\n",
"max_seq_length = 4096 * 2 # Can increase for longer reasoning traces\n",
"lora_rank = 64 # Larger rank = smarter, but slower\n",
@@ -101,9 +100,6 @@
},
"outputs": [],
"source": [
- "import re\n",
- "from datasets import load_dataset, Dataset\n",
- "from search_module import search, get_question_answer, get_question_count\n",
"from rl_helpers import get_qa_dataset\n",
"\n",
"train_dataset, test_dataset = get_qa_dataset()"
@@ -187,7 +183,7 @@
"def agentic_generate(\n",
" prompts: list[str],\n",
" generate_fn,\n",
- " max_generations: int = 6,\n",
+ " max_generations: int = 10,\n",
"):\n",
" return run_agent(generate_fn, tokenizer, prompts, max_generations)\n",
"\n",