diff --git a/notebooks/250324_generate_data_anatomy.ipynb b/notebooks/250324_generate_data_anatomy.ipynb index 0c252ee..008177d 100644 --- a/notebooks/250324_generate_data_anatomy.ipynb +++ b/notebooks/250324_generate_data_anatomy.ipynb @@ -31,16 +31,10 @@ "\n", "sys.path.append(\"..\")\n", "\n", - "import json\n", - "import os\n", "import pickle\n", - "import re\n", - "from typing import Dict, List, Optional, Tuple\n", "\n", - "from langchain.text_splitter import RecursiveCharacterTextSplitter\n", "\n", "\n", - "from langchain_community.document_loaders import UnstructuredMarkdownLoader\n", "from langchain_community.vectorstores import FAISS\n", "\n", "from embeddings import CustomHuggingFaceEmbeddings" @@ -101,7 +95,7 @@ "): # Ok cool, so this is much simpler than i expected!\n", " print(f\"\\n--- Chunk {i + 1}/{len(chunks)} ---\")\n", " print(chunk.page_content)\n", - " print(\"-\" * 50)\n" + " print(\"-\" * 50)" ] }, { @@ -323,8 +317,6 @@ "outputs": [], "source": [ "# Load the existing FAISS index\n", - "from langchain_community.vectorstores import FAISS\n", - "from embeddings import CustomHuggingFaceEmbeddings\n", "\n", "\n", "# Load the paraphrased chunks\n", @@ -365,9 +357,7 @@ "\n", "# Save the updated vectorstore\n", "existing_vectorstore.save_local(\"faiss_index_with_paraphrased\")\n", - "print(\"Saved updated FAISS index to 'faiss_index_with_paraphrased'\")\n", - "\n", - "\n" + "print(\"Saved updated FAISS index to 'faiss_index_with_paraphrased'\")" ] }, { diff --git a/notebooks/250325_fak_you_chattemplate.ipynb b/notebooks/250325_fak_you_chattemplate.ipynb new file mode 100644 index 0000000..7ba7a2f --- /dev/null +++ b/notebooks/250325_fak_you_chattemplate.ipynb @@ -0,0 +1,105 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# LLama 3.1 \n", + "\n", + "\"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\\n\\nCutting Knowledge Date: December 2023\\nToday Date: 26 Jul 2024\\n\\n<|eot_id|><|start_header_id|>user<|end_header_id|>\\n\\nHello, how are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\nI'm doing great. How can I help you today?<|eot_id|><|start_header_id|>system<|end_header_id|>\\n\\nYou are a friendly chatbot who always responds in the style of a pirate<|eot_id|>\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import AutoTokenizer\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(\"meta-llama/meta-Llama-3.1-8B-Instruct\")\n", + "chat = [\n", + " {\"role\": \"user\", \"content\": \"Hello, how are you?\"},\n", + " {\"role\": \"assistant\", \"content\": \"I'm doing great. How can I help you today?\"},\n", + " {\n", + " \"role\": \"system\",\n", + " \"content\": \"You are a friendly chatbot who always responds in the style of a pirate\",\n", + " },\n", + "]\n", + "\n", + "tokenizer.apply_chat_template(chat, tokenize=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Qwen 1.5B R1 Distill\n", + "\n", + "\"<|begin▁of▁sentence|>You are a friendly chatbot who always responds in the style of a pirate<|User|>Hello, how are you?<|Assistant|>I'm doing great. 
How can I help you today?<|end▁of▁sentence|>\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import AutoTokenizer\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(\"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\")\n", + "chat = [\n", + " {\"role\": \"user\", \"content\": \"Hello, how are you?\"},\n", + " {\"role\": \"assistant\", \"content\": \"I'm doing great. How can I help you today?\"},\n", + " {\n", + " \"role\": \"system\",\n", + " \"content\": \"You are a friendly chatbot who always responds in the style of a pirate\",\n", + " },\n", + "]\n", + "\n", + "tokenizer.apply_chat_template(chat, tokenize=False)\n", + "tokenizer.apply_chat_template(chat, tokenize=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# ✅ Compare the two\n", + "\n", + "- \"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\\n\\nCutting Knowledge Date: December 2023\\nToday Date: 26 Jul 2024\\n\\n<|eot_id|><|start_header_id|>user<|end_header_id|>\\n\\n**Hello, how are you?**<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n**I'm doing great. How can I help you today**?<|eot_id|><|start_header_id|>system<|end_header_id|>\\n\\n**You are a friendly chatbot who always responds in the style of a pirate**<|eot_id|>\"\n", + "- \"<|begin▁of▁sentence|>**You are a friendly chatbot who always responds in the style of a pirate**<|User|>**Hello, how are you?**<|Assistant|>**I'm doing great. How can I help you today?**<|end▁of▁sentence|>\"\n", + "- Ok make sense now!, so the structure of r1-distil doesn't have closing tags for most of the tags\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Just curious, did Alpha Maze touch anything with the chat template?\n", + "- Nope, as alpha maze task isn't as complicated as agent tool call stuffs, so it doesn't need to tweak the chat template" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "deepsearch-py311", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/notebooks/250325_inspect_qa_dataset.ipynb b/notebooks/250325_inspect_qa_dataset.ipynb new file mode 100644 index 0000000..26c55e8 --- /dev/null +++ b/notebooks/250325_inspect_qa_dataset.ipynb @@ -0,0 +1,123 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "\n", + "sys.path.append(\"..\")\n", + "\n", + "from src.search_module import get_qa_dataset\n", + "import random\n", + "\n", + "\n", + "def inspect_qa_dataset():\n", + " \"\"\"Inspect the QA dataset used for evaluation to identify potential issues\"\"\"\n", + "\n", + " # Get the datasets\n", + " train_dataset, test_dataset = get_qa_dataset()\n", + "\n", + " # Print dataset statistics\n", + " print(f\"Train dataset size: {len(train_dataset)}\")\n", + " print(f\"Test dataset size: {len(test_dataset)}\")\n", + "\n", + " # Print column information\n", + " print(f\"\\nTest dataset columns: {test_dataset.column_names}\")\n", + "\n", + " # Print a few random examples\n", + " sample_size = min(5, len(test_dataset))\n", + " sample_indices = random.sample(range(len(test_dataset)), sample_size)\n", + 
"\n", + " print(f\"\\n--- {sample_size} Random Test Examples ---\")\n", + " for i, idx in enumerate(sample_indices):\n", + " example = test_dataset[idx]\n", + " print(f\"\\nExample {i+1}:\")\n", + " print(f\"Prompt: {example['prompt']}\")\n", + " print(f\"Answer: {example['answer']}\")\n", + " if \"chunk_content\" in example:\n", + " print(f\"Chunk Content: {example['chunk_content'][:200]}... (truncated)\")\n", + "\n", + " # Check for potential issues\n", + " print(\"\\n--- Dataset Analysis ---\")\n", + "\n", + " # Check for duplicate questions\n", + " prompts = test_dataset[\"prompt\"]\n", + " duplicate_count = len(prompts) - len(set(prompts))\n", + " print(f\"Duplicate prompts: {duplicate_count}\")\n", + "\n", + " # Check answer length distribution\n", + " answer_lengths = [len(ans) for ans in test_dataset[\"answer\"]]\n", + " avg_answer_length = sum(answer_lengths) / len(answer_lengths)\n", + " min_answer_length = min(answer_lengths)\n", + " max_answer_length = max(answer_lengths)\n", + " print(\n", + " f\"Answer length stats: min={min_answer_length}, avg={avg_answer_length:.1f}, max={max_answer_length}\"\n", + " )\n", + "\n", + " # Analyze prompt types if possible\n", + " if len(prompts) > 0:\n", + " qa_count = sum(1 for p in prompts if p.endswith(\"?\"))\n", + " print(\n", + " f\"Questions ending with '?': {qa_count} ({qa_count/len(prompts)*100:.1f}%)\"\n", + " )\n", + "\n", + "\n", + "if __name__ == \"__main__\":\n", + " inspect_qa_dataset()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "\n", + "train_dataset, test_dataset = get_qa_dataset()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Sanity check 32 test cases: -> 31/32 is correct, nothing wrong with the test data here :/\n", + "\n", + "brow wtf is happening 😭" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "test_dataset" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "deepsearch-py311", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/notebooks/250325_visualize_reward_function.ipynb b/notebooks/250325_visualize_reward_function.ipynb index 04c691a..80b14ed 100644 --- a/notebooks/250325_visualize_reward_function.ipynb +++ b/notebooks/250325_visualize_reward_function.ipynb @@ -18,73 +18,95 @@ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", + "\n", "def plot_reward_functions():\n", " # Generate retry counts from 0 to 15\n", " retries = np.linspace(0, 15, 100)\n", - " \n", + "\n", " # 1. Basic Sigmoid\n", " basic_sigmoid = 1 / (1 + np.exp(-(retries - 4)))\n", - " \n", + "\n", " # 2. Our Modified Sigmoid\n", " x = retries - 4 # Center at 4 retries\n", - " modified_sigmoid = 1 / (1 + np.exp(-x + abs(x)/2))\n", - " \n", + " modified_sigmoid = 1 / (1 + np.exp(-x + abs(x) / 2))\n", + "\n", " # 3. 
With Penalty\n", " penalized_reward = modified_sigmoid.copy()\n", " for i, r in enumerate(retries):\n", " if r > 6:\n", " penalty = 0.2 * (r - 6)\n", " penalized_reward[i] = max(0.1, modified_sigmoid[i] - penalty)\n", - " \n", + "\n", " # Plotting\n", " plt.figure(figsize=(12, 6))\n", - " \n", - " plt.plot(retries, basic_sigmoid, 'b--', label='Basic Sigmoid')\n", - " plt.plot(retries, modified_sigmoid, 'g--', label='Modified Sigmoid')\n", - " plt.plot(retries, penalized_reward, 'r-', label='Final Reward (with penalty)', linewidth=2)\n", - " \n", + "\n", + " plt.plot(retries, basic_sigmoid, \"b--\", label=\"Basic Sigmoid\")\n", + " plt.plot(retries, modified_sigmoid, \"g--\", label=\"Modified Sigmoid\")\n", + " plt.plot(\n", + " retries,\n", + " penalized_reward,\n", + " \"r-\",\n", + " label=\"Final Reward (with penalty)\",\n", + " linewidth=2,\n", + " )\n", + "\n", " # Add vertical lines for key points\n", - " plt.axvline(x=4, color='gray', linestyle=':', alpha=0.5, label='Peak (4 retries)')\n", - " plt.axvline(x=6, color='gray', linestyle=':', alpha=0.5, label='Penalty Start (6 retries)')\n", - " \n", + " plt.axvline(x=4, color=\"gray\", linestyle=\":\", alpha=0.5, label=\"Peak (4 retries)\")\n", + " plt.axvline(\n", + " x=6, color=\"gray\", linestyle=\":\", alpha=0.5, label=\"Penalty Start (6 retries)\"\n", + " )\n", + "\n", " plt.grid(True, alpha=0.3)\n", - " plt.xlabel('Number of Retries')\n", - " plt.ylabel('Reward')\n", - " plt.title('Reward Function Visualization')\n", + " plt.xlabel(\"Number of Retries\")\n", + " plt.ylabel(\"Reward\")\n", + " plt.title(\"Reward Function Visualization\")\n", " plt.legend()\n", " plt.ylim(-0.1, 1.1)\n", - " \n", + "\n", " # Add annotations\n", - " plt.annotate('Optimal Zone', xy=(4, 0.8), xytext=(4, 0.9),\n", - " ha='center', va='bottom',\n", - " bbox=dict(boxstyle='round,pad=0.5', fc='yellow', alpha=0.3),\n", - " arrowprops=dict(arrowstyle='->'))\n", - " \n", - " plt.annotate('Penalty Zone', xy=(8, 0.3), xytext=(8, 0.5),\n", - " ha='center', va='bottom',\n", - " bbox=dict(boxstyle='round,pad=0.5', fc='red', alpha=0.3),\n", - " arrowprops=dict(arrowstyle='->'))\n", - " \n", + " plt.annotate(\n", + " \"Optimal Zone\",\n", + " xy=(4, 0.8),\n", + " xytext=(4, 0.9),\n", + " ha=\"center\",\n", + " va=\"bottom\",\n", + " bbox=dict(boxstyle=\"round,pad=0.5\", fc=\"yellow\", alpha=0.3),\n", + " arrowprops=dict(arrowstyle=\"->\"),\n", + " )\n", + "\n", + " plt.annotate(\n", + " \"Penalty Zone\",\n", + " xy=(8, 0.3),\n", + " xytext=(8, 0.5),\n", + " ha=\"center\",\n", + " va=\"bottom\",\n", + " bbox=dict(boxstyle=\"round,pad=0.5\", fc=\"red\", alpha=0.3),\n", + " arrowprops=dict(arrowstyle=\"->\"),\n", + " )\n", + "\n", " plt.show()\n", "\n", + "\n", "# Run the visualization\n", "plot_reward_functions()\n", "\n", + "\n", "# Print reward values for specific retry counts\n", "def print_reward_examples():\n", " retry_examples = [1, 2, 3, 4, 5, 6, 7, 8, 10, 12]\n", " print(\"\\nReward values for different retry counts:\")\n", " print(\"Retries | Reward\")\n", " print(\"-\" * 20)\n", - " \n", + "\n", " for retries in retry_examples:\n", " x = retries - 4\n", - " reward = 1 / (1 + np.exp(-x + abs(x)/2))\n", + " reward = 1 / (1 + np.exp(-x + abs(x) / 2))\n", " if retries > 6:\n", " penalty = 0.2 * (retries - 6)\n", " reward = max(0.1, reward - penalty)\n", " print(f\"{retries:7d} | {reward:.3f}\")\n", "\n", + "\n", "print_reward_examples()" ] } diff --git a/notebooks/Llama3_1_(8B)_GRPO.ipynb b/notebooks/Llama3_1_(8B)_GRPO.ipynb index 5bb6f80..05b0969 100644 --- 
a/notebooks/Llama3_1_(8B)_GRPO.ipynb +++ b/notebooks/Llama3_1_(8B)_GRPO.ipynb @@ -36,7 +36,9 @@ " !pip install -q --no-deps unsloth vllm\n", " # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]\n", " # Skip restarting message in Colab\n", - " import sys, re, requests\n", + " import sys\n", + " import re\n", + " import requests\n", "\n", " modules = list(sys.modules.keys())\n", " for x in modules:\n", @@ -197,7 +199,6 @@ "outputs": [], "source": [ "from unsloth import FastLanguageModel\n", - "import torch\n", "\n", "max_seq_length = 1024 # Can increase for longer reasoning traces\n", "lora_rank = 32 # Larger rank = smarter, but slower\n", diff --git a/notebooks/Qwen2_5_(3B)_GRPO.ipynb b/notebooks/Qwen2_5_(3B)_GRPO.ipynb index 4735d8e..8bac994 100644 --- a/notebooks/Qwen2_5_(3B)_GRPO.ipynb +++ b/notebooks/Qwen2_5_(3B)_GRPO.ipynb @@ -57,6 +57,7 @@ "source": [ "%%capture\n", "import os\n", + "\n", "if \"COLAB_\" not in \"\".join(os.environ.keys()):\n", " !pip install unsloth vllm\n", "else:\n", @@ -72,22 +73,30 @@ }, "outputs": [], "source": [ - "#@title Colab Extra Install { display-mode: \"form\" }\n", + "# @title Colab Extra Install { display-mode: \"form\" }\n", "%%capture\n", "import os\n", + "\n", "if \"COLAB_\" not in \"\".join(os.environ.keys()):\n", " !pip install unsloth vllm\n", "else:\n", " !pip install --no-deps unsloth vllm\n", " # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]\n", " # Skip restarting message in Colab\n", - " import sys, re, requests; modules = list(sys.modules.keys())\n", - " for x in modules: sys.modules.pop(x) if \"PIL\" in x or \"google\" in x else None\n", + " import sys\n", + " import re\n", + " import requests\n", + "\n", + " modules = list(sys.modules.keys())\n", + " for x in modules:\n", + " sys.modules.pop(x) if \"PIL\" in x or \"google\" in x else None\n", " !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo\n", " !pip install sentencepiece protobuf datasets huggingface_hub hf_transfer\n", "\n", " # vLLM requirements - vLLM breaks Colab due to reinstalling numpy\n", - " f = requests.get(\"https://raw.githubusercontent.com/vllm-project/vllm/refs/heads/main/requirements/common.txt\").content\n", + " f = requests.get(\n", + " \"https://raw.githubusercontent.com/vllm-project/vllm/refs/heads/main/requirements/common.txt\"\n", + " ).content\n", " with open(\"vllm_requirements.txt\", \"wb\") as file:\n", " file.write(re.sub(rb\"(transformers|numpy|xformers)[^\\n]{1,}\\n\", b\"\", f))\n", " !pip install -r vllm_requirements.txt" @@ -303,29 +312,34 @@ "outputs": [], "source": [ "from unsloth import FastLanguageModel, is_bfloat16_supported\n", - "import torch\n", - "max_seq_length = 1024 # Can increase for longer reasoning traces\n", - "lora_rank = 64 # Larger rank = smarter, but slower\n", + "\n", + "max_seq_length = 1024 # Can increase for longer reasoning traces\n", + "lora_rank = 64 # Larger rank = smarter, but slower\n", "\n", "model, tokenizer = FastLanguageModel.from_pretrained(\n", - " model_name = \"Qwen/Qwen2.5-3B-Instruct\",\n", - " max_seq_length = max_seq_length,\n", - " load_in_4bit = True, # False for LoRA 16bit\n", - " fast_inference = True, # Enable vLLM fast inference\n", - " max_lora_rank = lora_rank,\n", - " gpu_memory_utilization = 0.5, # Reduce if out of memory\n", + " model_name=\"Qwen/Qwen2.5-3B-Instruct\",\n", + " max_seq_length=max_seq_length,\n", + " load_in_4bit=True, # False for LoRA 16bit\n", + " fast_inference=True, # Enable 
vLLM fast inference\n", + " max_lora_rank=lora_rank,\n", + " gpu_memory_utilization=0.5, # Reduce if out of memory\n", ")\n", "\n", "model = FastLanguageModel.get_peft_model(\n", " model,\n", - " r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128\n", - " target_modules = [\n", - " \"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n", - " \"gate_proj\", \"up_proj\", \"down_proj\",\n", - " ], # Remove QKVO if out of memory\n", - " lora_alpha = lora_rank,\n", - " use_gradient_checkpointing = \"unsloth\", # Enable long context finetuning\n", - " random_state = 3407,\n", + " r=lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128\n", + " target_modules=[\n", + " \"q_proj\",\n", + " \"k_proj\",\n", + " \"v_proj\",\n", + " \"o_proj\",\n", + " \"gate_proj\",\n", + " \"up_proj\",\n", + " \"down_proj\",\n", + " ], # Remove QKVO if out of memory\n", + " lora_alpha=lora_rank,\n", + " use_gradient_checkpointing=\"unsloth\", # Enable long context finetuning\n", + " random_state=3407,\n", ")" ] }, @@ -445,43 +459,58 @@ "\n", "\"\"\"\n", "\n", + "\n", "def extract_xml_answer(text: str) -> str:\n", " answer = text.split(\"\")[-1]\n", " answer = answer.split(\"\")[0]\n", " return answer.strip()\n", "\n", + "\n", "def extract_hash_answer(text: str) -> str | None:\n", " if \"####\" not in text:\n", " return None\n", " return text.split(\"####\")[1].strip()\n", "\n", + "\n", "# uncomment middle messages for 1-shot prompting\n", - "def get_gsm8k_questions(split = \"train\") -> Dataset:\n", - " data = load_dataset('openai/gsm8k', 'main')[split] # type: ignore\n", - " data = data.map(lambda x: { # type: ignore\n", - " 'prompt': [\n", - " {'role': 'system', 'content': SYSTEM_PROMPT},\n", - " {'role': 'user', 'content': x['question']}\n", - " ],\n", - " 'answer': extract_hash_answer(x['answer'])\n", - " }) # type: ignore\n", - " return data # type: ignore\n", + "def get_gsm8k_questions(split=\"train\") -> Dataset:\n", + " data = load_dataset(\"openai/gsm8k\", \"main\")[split] # type: ignore\n", + " data = data.map(\n", + " lambda x: { # type: ignore\n", + " \"prompt\": [\n", + " {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n", + " {\"role\": \"user\", \"content\": x[\"question\"]},\n", + " ],\n", + " \"answer\": extract_hash_answer(x[\"answer\"]),\n", + " }\n", + " ) # type: ignore\n", + " return data # type: ignore\n", + "\n", "\n", "dataset = get_gsm8k_questions()\n", "\n", + "\n", "# Reward functions\n", "def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:\n", - " responses = [completion[0]['content'] for completion in completions]\n", - " q = prompts[0][-1]['content']\n", + " responses = [completion[0][\"content\"] for completion in completions]\n", + " q = prompts[0][-1][\"content\"]\n", " extracted_responses = [extract_xml_answer(r) for r in responses]\n", - " print('-'*20, f\"Question:\\n{q}\", f\"\\nAnswer:\\n{answer[0]}\", f\"\\nResponse:\\n{responses[0]}\", f\"\\nExtracted:\\n{extracted_responses[0]}\")\n", + " print(\n", + " \"-\" * 20,\n", + " f\"Question:\\n{q}\",\n", + " f\"\\nAnswer:\\n{answer[0]}\",\n", + " f\"\\nResponse:\\n{responses[0]}\",\n", + " f\"\\nExtracted:\\n{extracted_responses[0]}\",\n", + " )\n", " return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]\n", "\n", + "\n", "def int_reward_func(completions, **kwargs) -> list[float]:\n", - " responses = [completion[0]['content'] for completion in completions]\n", + " responses = [completion[0][\"content\"] for completion in completions]\n", " 
extracted_responses = [extract_xml_answer(r) for r in responses]\n", " return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]\n", "\n", + "\n", "def strict_format_reward_func(completions, **kwargs) -> list[float]:\n", " \"\"\"Reward function that checks if the completion has a specific format.\"\"\"\n", " pattern = r\"^\\n.*?\\n\\n\\n.*?\\n\\n$\"\n", @@ -489,6 +518,7 @@ " matches = [re.match(pattern, r) for r in responses]\n", " return [0.5 if match else 0.0 for match in matches]\n", "\n", + "\n", "def soft_format_reward_func(completions, **kwargs) -> list[float]:\n", " \"\"\"Reward function that checks if the completion has a specific format.\"\"\"\n", " pattern = r\".*?\\s*.*?\"\n", @@ -496,6 +526,7 @@ " matches = [re.match(pattern, r) for r in responses]\n", " return [0.5 if match else 0.0 for match in matches]\n", "\n", + "\n", "def count_xml(text) -> float:\n", " count = 0.0\n", " if text.count(\"\\n\") == 1:\n", @@ -504,12 +535,13 @@ " count += 0.125\n", " if text.count(\"\\n\\n\") == 1:\n", " count += 0.125\n", - " count -= len(text.split(\"\\n\\n\")[-1])*0.001\n", + " count -= len(text.split(\"\\n\\n\")[-1]) * 0.001\n", " if text.count(\"\\n\") == 1:\n", " count += 0.125\n", - " count -= (len(text.split(\"\\n\")[-1]) - 1)*0.001\n", + " count -= (len(text.split(\"\\n\")[-1]) - 1) * 0.001\n", " return count\n", "\n", + "\n", "def xmlcount_reward_func(completions, **kwargs) -> list[float]:\n", " contents = [completion[0][\"content\"] for completion in completions]\n", " return [count_xml(c) for c in contents]" @@ -540,29 +572,30 @@ "outputs": [], "source": [ "from trl import GRPOConfig, GRPOTrainer\n", + "\n", "training_args = GRPOConfig(\n", - " use_vllm = True, # use vLLM for fast inference!\n", - " learning_rate = 5e-6,\n", - " adam_beta1 = 0.9,\n", - " adam_beta2 = 0.99,\n", - " weight_decay = 0.1,\n", - " warmup_ratio = 0.1,\n", - " lr_scheduler_type = \"cosine\",\n", - " optim = \"adamw_8bit\",\n", - " logging_steps = 1,\n", - " bf16 = is_bfloat16_supported(),\n", - " fp16 = not is_bfloat16_supported(),\n", - " per_device_train_batch_size = 1,\n", - " gradient_accumulation_steps = 1, # Increase to 4 for smoother training\n", - " num_generations = 8, # Decrease if out of memory\n", - " max_prompt_length = 256,\n", - " max_completion_length = 200,\n", + " use_vllm=True, # use vLLM for fast inference!\n", + " learning_rate=5e-6,\n", + " adam_beta1=0.9,\n", + " adam_beta2=0.99,\n", + " weight_decay=0.1,\n", + " warmup_ratio=0.1,\n", + " lr_scheduler_type=\"cosine\",\n", + " optim=\"adamw_8bit\",\n", + " logging_steps=1,\n", + " bf16=is_bfloat16_supported(),\n", + " fp16=not is_bfloat16_supported(),\n", + " per_device_train_batch_size=1,\n", + " gradient_accumulation_steps=1, # Increase to 4 for smoother training\n", + " num_generations=8, # Decrease if out of memory\n", + " max_prompt_length=256,\n", + " max_completion_length=200,\n", " # num_train_epochs = 1, # Set to 1 for a full training run\n", - " max_steps = 250,\n", - " save_steps = 250,\n", - " max_grad_norm = 0.1,\n", - " report_to = \"none\", # Can use Weights & Biases\n", - " output_dir = \"outputs\",\n", + " max_steps=250,\n", + " save_steps=250,\n", + " max_grad_norm=0.1,\n", + " report_to=\"none\", # Can use Weights & Biases\n", + " output_dir=\"outputs\",\n", ")" ] }, @@ -597,17 +630,17 @@ "outputs": [], "source": [ "trainer = GRPOTrainer(\n", - " model = model,\n", - " processing_class = tokenizer,\n", - " reward_funcs = [\n", + " model=model,\n", + " processing_class=tokenizer,\n", + " reward_funcs=[\n", 
" xmlcount_reward_func,\n", " soft_format_reward_func,\n", " strict_format_reward_func,\n", " int_reward_func,\n", " correctness_reward_func,\n", " ],\n", - " args = training_args,\n", - " train_dataset = dataset,\n", + " args=training_args,\n", + " train_dataset=dataset,\n", ")\n", "trainer.train()" ] @@ -636,21 +669,30 @@ }, "outputs": [], "source": [ - "text = tokenizer.apply_chat_template([\n", - " {\"role\" : \"user\", \"content\" : \"How many r's are in strawberry?\"},\n", - "], tokenize = False, add_generation_prompt = True)\n", + "text = tokenizer.apply_chat_template(\n", + " [\n", + " {\"role\": \"user\", \"content\": \"How many r's are in strawberry?\"},\n", + " ],\n", + " tokenize=False,\n", + " add_generation_prompt=True,\n", + ")\n", "\n", "from vllm import SamplingParams\n", + "\n", "sampling_params = SamplingParams(\n", - " temperature = 0.8,\n", - " top_p = 0.95,\n", - " max_tokens = 1024,\n", + " temperature=0.8,\n", + " top_p=0.95,\n", + " max_tokens=1024,\n", + ")\n", + "output = (\n", + " model.fast_generate(\n", + " [text],\n", + " sampling_params=sampling_params,\n", + " lora_request=None,\n", + " )[0]\n", + " .outputs[0]\n", + " .text\n", ")\n", - "output = model.fast_generate(\n", - " [text],\n", - " sampling_params = sampling_params,\n", - " lora_request = None,\n", - ")[0].outputs[0].text\n", "\n", "output" ] @@ -697,22 +739,31 @@ }, "outputs": [], "source": [ - "text = tokenizer.apply_chat_template([\n", - " {\"role\" : \"system\", \"content\" : SYSTEM_PROMPT},\n", - " {\"role\" : \"user\", \"content\" : \"How many r's are in strawberry?\"},\n", - "], tokenize = False, add_generation_prompt = True)\n", + "text = tokenizer.apply_chat_template(\n", + " [\n", + " {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n", + " {\"role\": \"user\", \"content\": \"How many r's are in strawberry?\"},\n", + " ],\n", + " tokenize=False,\n", + " add_generation_prompt=True,\n", + ")\n", "\n", "from vllm import SamplingParams\n", + "\n", "sampling_params = SamplingParams(\n", - " temperature = 0.8,\n", - " top_p = 0.95,\n", - " max_tokens = 1024,\n", + " temperature=0.8,\n", + " top_p=0.95,\n", + " max_tokens=1024,\n", + ")\n", + "output = (\n", + " model.fast_generate(\n", + " text,\n", + " sampling_params=sampling_params,\n", + " lora_request=model.load_lora(\"grpo_saved_lora\"),\n", + " )[0]\n", + " .outputs[0]\n", + " .text\n", ")\n", - "output = model.fast_generate(\n", - " text,\n", - " sampling_params = sampling_params,\n", - " lora_request = model.load_lora(\"grpo_saved_lora\"),\n", - ")[0].outputs[0].text\n", "\n", "output" ] @@ -747,16 +798,36 @@ "outputs": [], "source": [ "# Merge to 16bit\n", - "if False: model.save_pretrained_merged(\"model\", tokenizer, save_method = \"merged_16bit\",)\n", - "if False: model.push_to_hub_merged(\"hf/model\", tokenizer, save_method = \"merged_16bit\", token = \"\")\n", + "if False:\n", + " model.save_pretrained_merged(\n", + " \"model\",\n", + " tokenizer,\n", + " save_method=\"merged_16bit\",\n", + " )\n", + "if False:\n", + " model.push_to_hub_merged(\n", + " \"hf/model\", tokenizer, save_method=\"merged_16bit\", token=\"\"\n", + " )\n", "\n", "# Merge to 4bit\n", - "if False: model.save_pretrained_merged(\"model\", tokenizer, save_method = \"merged_4bit\",)\n", - "if False: model.push_to_hub_merged(\"hf/model\", tokenizer, save_method = \"merged_4bit\", token = \"\")\n", + "if False:\n", + " model.save_pretrained_merged(\n", + " \"model\",\n", + " tokenizer,\n", + " save_method=\"merged_4bit\",\n", + " )\n", + "if False:\n", + " 
model.push_to_hub_merged(\"hf/model\", tokenizer, save_method=\"merged_4bit\", token=\"\")\n", "\n", "# Just LoRA adapters\n", - "if False: model.save_pretrained_merged(\"model\", tokenizer, save_method = \"lora\",)\n", - "if False: model.push_to_hub_merged(\"hf/model\", tokenizer, save_method = \"lora\", token = \"\")" + "if False:\n", + " model.save_pretrained_merged(\n", + " \"model\",\n", + " tokenizer,\n", + " save_method=\"lora\",\n", + " )\n", + "if False:\n", + " model.push_to_hub_merged(\"hf/model\", tokenizer, save_method=\"lora\", token=\"\")" ] }, { @@ -785,26 +856,41 @@ "outputs": [], "source": [ "# Save to 8bit Q8_0\n", - "if False: model.save_pretrained_gguf(\"model\", tokenizer,)\n", + "if False:\n", + " model.save_pretrained_gguf(\n", + " \"model\",\n", + " tokenizer,\n", + " )\n", "# Remember to go to https://huggingface.co/settings/tokens for a token!\n", "# And change hf to your username!\n", - "if False: model.push_to_hub_gguf(\"hf/model\", tokenizer, token = \"\")\n", + "if False:\n", + " model.push_to_hub_gguf(\"hf/model\", tokenizer, token=\"\")\n", "\n", "# Save to 16bit GGUF\n", - "if False: model.save_pretrained_gguf(\"model\", tokenizer, quantization_method = \"f16\")\n", - "if False: model.push_to_hub_gguf(\"hf/model\", tokenizer, quantization_method = \"f16\", token = \"\")\n", + "if False:\n", + " model.save_pretrained_gguf(\"model\", tokenizer, quantization_method=\"f16\")\n", + "if False:\n", + " model.push_to_hub_gguf(\"hf/model\", tokenizer, quantization_method=\"f16\", token=\"\")\n", "\n", "# Save to q4_k_m GGUF\n", - "if False: model.save_pretrained_gguf(\"model\", tokenizer, quantization_method = \"q4_k_m\")\n", - "if False: model.push_to_hub_gguf(\"hf/model\", tokenizer, quantization_method = \"q4_k_m\", token = \"\")\n", + "if False:\n", + " model.save_pretrained_gguf(\"model\", tokenizer, quantization_method=\"q4_k_m\")\n", + "if False:\n", + " model.push_to_hub_gguf(\n", + " \"hf/model\", tokenizer, quantization_method=\"q4_k_m\", token=\"\"\n", + " )\n", "\n", "# Save to multiple GGUF options - much faster if you want multiple!\n", "if False:\n", " model.push_to_hub_gguf(\n", - " \"hf/model\", # Change hf to your username!\n", + " \"hf/model\", # Change hf to your username!\n", " tokenizer,\n", - " quantization_method = [\"q4_k_m\", \"q8_0\", \"q5_k_m\",],\n", - " token = \"\",\n", + " quantization_method=[\n", + " \"q4_k_m\",\n", + " \"q8_0\",\n", + " \"q5_k_m\",\n", + " ],\n", + " token=\"\",\n", " )" ] }, diff --git a/notebooks/README.md b/notebooks/README.md new file mode 100644 index 0000000..f3c905f --- /dev/null +++ b/notebooks/README.md @@ -0,0 +1,7 @@ +# Clean Notebooks Output 101 + +``` +pip install nbstripout +nbstripout --status +nbstripout --install --global +``` diff --git a/notebooks/train_autodidact.ipynb b/notebooks/train_autodidact.ipynb index ff33f4d..9221516 100644 --- a/notebooks/train_autodidact.ipynb +++ b/notebooks/train_autodidact.ipynb @@ -61,7 +61,6 @@ "outputs": [], "source": [ "from unsloth import is_bfloat16_supported\n", - "import torch\n", "\n", "max_seq_length = 4096 * 2 # Can increase for longer reasoning traces\n", "lora_rank = 64 # Larger rank = smarter, but slower\n", @@ -101,9 +100,6 @@ }, "outputs": [], "source": [ - "import re\n", - "from datasets import load_dataset, Dataset\n", - "from search_module import search, get_question_answer, get_question_count\n", "from rl_helpers import get_qa_dataset\n", "\n", "train_dataset, test_dataset = get_qa_dataset()" @@ -187,7 +183,7 @@ "def agentic_generate(\n", " 
prompts: list[str],\n", " generate_fn,\n", - " max_generations: int = 6,\n", + " max_generations: int = 10,\n", "):\n", " return run_agent(generate_fn, tokenizer, prompts, max_generations)\n", "\n",
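For context on the `max_generations` bump above (6 → 10): `run_agent` is defined in `rl_helpers`, which is not part of this diff, so the following is only a minimal sketch of how such a cap typically bounds an agent's generate/search loop. Apart from `generate_fn`, `tokenizer`, `prompts`, and `max_generations`, every name and the `<answer>`-tag stopping rule are assumptions for illustration, not the repo's actual implementation.

```python
# Hypothetical sketch -- the real run_agent lives in rl_helpers and is not shown in this diff.
def run_agent_sketch(generate_fn, tokenizer, prompts, max_generations=10):
    """Loop at most `max_generations` times, letting the model alternate between
    tool/search turns and a final answer (assumed to be marked with <answer> tags)."""
    chats = [[{"role": "user", "content": p}] for p in prompts]
    for _ in range(max_generations):  # the cap raised from 6 to 10 in the hunk above
        pending = [c for c in chats if "<answer>" not in c[-1].get("content", "")]
        if not pending:
            break  # every conversation already produced a final answer
        texts = [
            tokenizer.apply_chat_template(c, tokenize=False, add_generation_prompt=True)
            for c in pending
        ]
        for chat, reply in zip(pending, generate_fn(texts)):
            # A real implementation would parse a search/tool call here, execute it,
            # and append the tool result as an extra message before the next turn.
            chat.append({"role": "assistant", "content": reply})
    return chats
```

Under that reading, the hunk above only raises the default number of agent turns per episode before the loop is cut off; the prompts, rewards, and chat template are untouched by this change.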