chore: clean up notebooks

main
thinhlpg 1 month ago
parent 3c2deaced9
commit 7f2f43aa46

@ -31,16 +31,10 @@
"\n", "\n",
"sys.path.append(\"..\")\n", "sys.path.append(\"..\")\n",
"\n", "\n",
"import json\n",
"import os\n",
"import pickle\n", "import pickle\n",
"import re\n",
"from typing import Dict, List, Optional, Tuple\n",
"\n", "\n",
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
"\n", "\n",
"\n", "\n",
"from langchain_community.document_loaders import UnstructuredMarkdownLoader\n",
"from langchain_community.vectorstores import FAISS\n", "from langchain_community.vectorstores import FAISS\n",
"\n", "\n",
"from embeddings import CustomHuggingFaceEmbeddings" "from embeddings import CustomHuggingFaceEmbeddings"
@ -101,7 +95,7 @@
"): # Ok cool, so this is much simpler than i expected!\n", "): # Ok cool, so this is much simpler than i expected!\n",
" print(f\"\\n--- Chunk {i + 1}/{len(chunks)} ---\")\n", " print(f\"\\n--- Chunk {i + 1}/{len(chunks)} ---\")\n",
" print(chunk.page_content)\n", " print(chunk.page_content)\n",
" print(\"-\" * 50)\n" " print(\"-\" * 50)"
] ]
}, },
{ {
@ -323,8 +317,6 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"# Load the existing FAISS index\n", "# Load the existing FAISS index\n",
"from langchain_community.vectorstores import FAISS\n",
"from embeddings import CustomHuggingFaceEmbeddings\n",
"\n", "\n",
"\n", "\n",
"# Load the paraphrased chunks\n", "# Load the paraphrased chunks\n",
@ -365,9 +357,7 @@
"\n", "\n",
"# Save the updated vectorstore\n", "# Save the updated vectorstore\n",
"existing_vectorstore.save_local(\"faiss_index_with_paraphrased\")\n", "existing_vectorstore.save_local(\"faiss_index_with_paraphrased\")\n",
"print(\"Saved updated FAISS index to 'faiss_index_with_paraphrased'\")\n", "print(\"Saved updated FAISS index to 'faiss_index_with_paraphrased'\")"
"\n",
"\n"
] ]
}, },
{ {

@ -0,0 +1,105 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# LLama 3.1 \n",
"\n",
"\"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\\n\\nCutting Knowledge Date: December 2023\\nToday Date: 26 Jul 2024\\n\\n<|eot_id|><|start_header_id|>user<|end_header_id|>\\n\\nHello, how are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\nI'm doing great. How can I help you today?<|eot_id|><|start_header_id|>system<|end_header_id|>\\n\\nYou are a friendly chatbot who always responds in the style of a pirate<|eot_id|>\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoTokenizer\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(\"meta-llama/meta-Llama-3.1-8B-Instruct\")\n",
"chat = [\n",
" {\"role\": \"user\", \"content\": \"Hello, how are you?\"},\n",
" {\"role\": \"assistant\", \"content\": \"I'm doing great. How can I help you today?\"},\n",
" {\n",
" \"role\": \"system\",\n",
" \"content\": \"You are a friendly chatbot who always responds in the style of a pirate\",\n",
" },\n",
"]\n",
"\n",
"tokenizer.apply_chat_template(chat, tokenize=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Qwen 1.5B R1 Distill\n",
"\n",
"\"<begin▁of▁sentence>You are a friendly chatbot who always responds in the style of a pirate<User>Hello, how are you?<Assistant>I'm doing great. How can I help you today?<end▁of▁sentence>\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoTokenizer\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(\"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\")\n",
"chat = [\n",
" {\"role\": \"user\", \"content\": \"Hello, how are you?\"},\n",
" {\"role\": \"assistant\", \"content\": \"I'm doing great. How can I help you today?\"},\n",
" {\n",
" \"role\": \"system\",\n",
" \"content\": \"You are a friendly chatbot who always responds in the style of a pirate\",\n",
" },\n",
"]\n",
"\n",
"tokenizer.apply_chat_template(chat, tokenize=False)\n",
"tokenizer.apply_chat_template(chat, tokenize=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# ✅ Compare the two\n",
"\n",
"- \"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\\n\\nCutting Knowledge Date: December 2023\\nToday Date: 26 Jul 2024\\n\\n<|eot_id|><|start_header_id|>user<|end_header_id|>\\n\\n**Hello, how are you?**<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n**I'm doing great. How can I help you today**?<|eot_id|><|start_header_id|>system<|end_header_id|>\\n\\n**You are a friendly chatbot who always responds in the style of a pirate**<|eot_id|>\"\n",
"- \"<begin▁of▁sentence>**You are a friendly chatbot who always responds in the style of a pirate**<User>**Hello, how are you?**<Assistant>**I'm doing great. How can I help you today?**<end▁of▁sentence>\"\n",
"- Ok make sense now!, so the structure of r1-distil doesn't have closing tags for most of the tags\n"
]
},
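{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick check on the comparison above (a minimal sketch; it assumes both checkpoints can be downloaded): print each tokenizer's registered special tokens."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: list both tokenizers' special tokens to confirm the structural\n",
"# differences between the two chat templates shown above.\n",
"from transformers import AutoTokenizer\n",
"\n",
"for name in (\n",
"    \"meta-llama/meta-Llama-3.1-8B-Instruct\",\n",
"    \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\",\n",
"):\n",
"    tok = AutoTokenizer.from_pretrained(name)\n",
"    print(name, tok.all_special_tokens)"
]
},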
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Just curious, did Alpha Maze touch anything with the chat template?\n",
"- Nope, as alpha maze task isn't as complicated as agent tool call stuffs, so it doesn't need to tweak the chat template"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "deepsearch-py311",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

@ -0,0 +1,123 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"\n",
"sys.path.append(\"..\")\n",
"\n",
"from src.search_module import get_qa_dataset\n",
"import random\n",
"\n",
"\n",
"def inspect_qa_dataset():\n",
" \"\"\"Inspect the QA dataset used for evaluation to identify potential issues\"\"\"\n",
"\n",
" # Get the datasets\n",
" train_dataset, test_dataset = get_qa_dataset()\n",
"\n",
" # Print dataset statistics\n",
" print(f\"Train dataset size: {len(train_dataset)}\")\n",
" print(f\"Test dataset size: {len(test_dataset)}\")\n",
"\n",
" # Print column information\n",
" print(f\"\\nTest dataset columns: {test_dataset.column_names}\")\n",
"\n",
" # Print a few random examples\n",
" sample_size = min(5, len(test_dataset))\n",
" sample_indices = random.sample(range(len(test_dataset)), sample_size)\n",
"\n",
" print(f\"\\n--- {sample_size} Random Test Examples ---\")\n",
" for i, idx in enumerate(sample_indices):\n",
" example = test_dataset[idx]\n",
" print(f\"\\nExample {i+1}:\")\n",
" print(f\"Prompt: {example['prompt']}\")\n",
" print(f\"Answer: {example['answer']}\")\n",
" if \"chunk_content\" in example:\n",
" print(f\"Chunk Content: {example['chunk_content'][:200]}... (truncated)\")\n",
"\n",
" # Check for potential issues\n",
" print(\"\\n--- Dataset Analysis ---\")\n",
"\n",
" # Check for duplicate questions\n",
" prompts = test_dataset[\"prompt\"]\n",
" duplicate_count = len(prompts) - len(set(prompts))\n",
" print(f\"Duplicate prompts: {duplicate_count}\")\n",
"\n",
" # Check answer length distribution\n",
" answer_lengths = [len(ans) for ans in test_dataset[\"answer\"]]\n",
" avg_answer_length = sum(answer_lengths) / len(answer_lengths)\n",
" min_answer_length = min(answer_lengths)\n",
" max_answer_length = max(answer_lengths)\n",
" print(\n",
" f\"Answer length stats: min={min_answer_length}, avg={avg_answer_length:.1f}, max={max_answer_length}\"\n",
" )\n",
"\n",
" # Analyze prompt types if possible\n",
" if len(prompts) > 0:\n",
" qa_count = sum(1 for p in prompts if p.endswith(\"?\"))\n",
" print(\n",
" f\"Questions ending with '?': {qa_count} ({qa_count/len(prompts)*100:.1f}%)\"\n",
" )\n",
"\n",
"\n",
"if __name__ == \"__main__\":\n",
" inspect_qa_dataset()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"train_dataset, test_dataset = get_qa_dataset()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Sanity check 32 test cases: -> 31/32 is correct, nothing wrong with the test data here :/\n",
"\n",
"brow wtf is happening 😭"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"test_dataset"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "deepsearch-py311",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

@ -18,6 +18,7 @@
"import numpy as np\n", "import numpy as np\n",
"import matplotlib.pyplot as plt\n", "import matplotlib.pyplot as plt\n",
"\n", "\n",
"\n",
"def plot_reward_functions():\n", "def plot_reward_functions():\n",
" # Generate retry counts from 0 to 15\n", " # Generate retry counts from 0 to 15\n",
" retries = np.linspace(0, 15, 100)\n", " retries = np.linspace(0, 15, 100)\n",
@ -39,37 +40,57 @@
" # Plotting\n", " # Plotting\n",
" plt.figure(figsize=(12, 6))\n", " plt.figure(figsize=(12, 6))\n",
"\n", "\n",
" plt.plot(retries, basic_sigmoid, 'b--', label='Basic Sigmoid')\n", " plt.plot(retries, basic_sigmoid, \"b--\", label=\"Basic Sigmoid\")\n",
" plt.plot(retries, modified_sigmoid, 'g--', label='Modified Sigmoid')\n", " plt.plot(retries, modified_sigmoid, \"g--\", label=\"Modified Sigmoid\")\n",
" plt.plot(retries, penalized_reward, 'r-', label='Final Reward (with penalty)', linewidth=2)\n", " plt.plot(\n",
" retries,\n",
" penalized_reward,\n",
" \"r-\",\n",
" label=\"Final Reward (with penalty)\",\n",
" linewidth=2,\n",
" )\n",
"\n", "\n",
" # Add vertical lines for key points\n", " # Add vertical lines for key points\n",
" plt.axvline(x=4, color='gray', linestyle=':', alpha=0.5, label='Peak (4 retries)')\n", " plt.axvline(x=4, color=\"gray\", linestyle=\":\", alpha=0.5, label=\"Peak (4 retries)\")\n",
" plt.axvline(x=6, color='gray', linestyle=':', alpha=0.5, label='Penalty Start (6 retries)')\n", " plt.axvline(\n",
" x=6, color=\"gray\", linestyle=\":\", alpha=0.5, label=\"Penalty Start (6 retries)\"\n",
" )\n",
"\n", "\n",
" plt.grid(True, alpha=0.3)\n", " plt.grid(True, alpha=0.3)\n",
" plt.xlabel('Number of Retries')\n", " plt.xlabel(\"Number of Retries\")\n",
" plt.ylabel('Reward')\n", " plt.ylabel(\"Reward\")\n",
" plt.title('Reward Function Visualization')\n", " plt.title(\"Reward Function Visualization\")\n",
" plt.legend()\n", " plt.legend()\n",
" plt.ylim(-0.1, 1.1)\n", " plt.ylim(-0.1, 1.1)\n",
"\n", "\n",
" # Add annotations\n", " # Add annotations\n",
" plt.annotate('Optimal Zone', xy=(4, 0.8), xytext=(4, 0.9),\n", " plt.annotate(\n",
" ha='center', va='bottom',\n", " \"Optimal Zone\",\n",
" bbox=dict(boxstyle='round,pad=0.5', fc='yellow', alpha=0.3),\n", " xy=(4, 0.8),\n",
" arrowprops=dict(arrowstyle='->'))\n", " xytext=(4, 0.9),\n",
" ha=\"center\",\n",
" va=\"bottom\",\n",
" bbox=dict(boxstyle=\"round,pad=0.5\", fc=\"yellow\", alpha=0.3),\n",
" arrowprops=dict(arrowstyle=\"->\"),\n",
" )\n",
"\n", "\n",
" plt.annotate('Penalty Zone', xy=(8, 0.3), xytext=(8, 0.5),\n", " plt.annotate(\n",
" ha='center', va='bottom',\n", " \"Penalty Zone\",\n",
" bbox=dict(boxstyle='round,pad=0.5', fc='red', alpha=0.3),\n", " xy=(8, 0.3),\n",
" arrowprops=dict(arrowstyle='->'))\n", " xytext=(8, 0.5),\n",
" ha=\"center\",\n",
" va=\"bottom\",\n",
" bbox=dict(boxstyle=\"round,pad=0.5\", fc=\"red\", alpha=0.3),\n",
" arrowprops=dict(arrowstyle=\"->\"),\n",
" )\n",
"\n", "\n",
" plt.show()\n", " plt.show()\n",
"\n", "\n",
"\n",
"# Run the visualization\n", "# Run the visualization\n",
"plot_reward_functions()\n", "plot_reward_functions()\n",
"\n", "\n",
"\n",
"# Print reward values for specific retry counts\n", "# Print reward values for specific retry counts\n",
"def print_reward_examples():\n", "def print_reward_examples():\n",
" retry_examples = [1, 2, 3, 4, 5, 6, 7, 8, 10, 12]\n", " retry_examples = [1, 2, 3, 4, 5, 6, 7, 8, 10, 12]\n",
@ -85,6 +106,7 @@
" reward = max(0.1, reward - penalty)\n", " reward = max(0.1, reward - penalty)\n",
" print(f\"{retries:7d} | {reward:.3f}\")\n", " print(f\"{retries:7d} | {reward:.3f}\")\n",
"\n", "\n",
"\n",
"print_reward_examples()" "print_reward_examples()"
] ]
} }

@ -36,7 +36,9 @@
" !pip install -q --no-deps unsloth vllm\n", " !pip install -q --no-deps unsloth vllm\n",
" # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]\n", " # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]\n",
" # Skip restarting message in Colab\n", " # Skip restarting message in Colab\n",
" import sys, re, requests\n", " import sys\n",
" import re\n",
" import requests\n",
"\n", "\n",
" modules = list(sys.modules.keys())\n", " modules = list(sys.modules.keys())\n",
" for x in modules:\n", " for x in modules:\n",
@ -197,7 +199,6 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"from unsloth import FastLanguageModel\n", "from unsloth import FastLanguageModel\n",
"import torch\n",
"\n", "\n",
"max_seq_length = 1024 # Can increase for longer reasoning traces\n", "max_seq_length = 1024 # Can increase for longer reasoning traces\n",
"lora_rank = 32 # Larger rank = smarter, but slower\n", "lora_rank = 32 # Larger rank = smarter, but slower\n",

@ -57,6 +57,7 @@
"source": [ "source": [
"%%capture\n", "%%capture\n",
"import os\n", "import os\n",
"\n",
"if \"COLAB_\" not in \"\".join(os.environ.keys()):\n", "if \"COLAB_\" not in \"\".join(os.environ.keys()):\n",
" !pip install unsloth vllm\n", " !pip install unsloth vllm\n",
"else:\n", "else:\n",
@ -75,19 +76,27 @@
"# @title Colab Extra Install { display-mode: \"form\" }\n", "# @title Colab Extra Install { display-mode: \"form\" }\n",
"%%capture\n", "%%capture\n",
"import os\n", "import os\n",
"\n",
"if \"COLAB_\" not in \"\".join(os.environ.keys()):\n", "if \"COLAB_\" not in \"\".join(os.environ.keys()):\n",
" !pip install unsloth vllm\n", " !pip install unsloth vllm\n",
"else:\n", "else:\n",
" !pip install --no-deps unsloth vllm\n", " !pip install --no-deps unsloth vllm\n",
" # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]\n", " # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]\n",
" # Skip restarting message in Colab\n", " # Skip restarting message in Colab\n",
" import sys, re, requests; modules = list(sys.modules.keys())\n", " import sys\n",
" for x in modules: sys.modules.pop(x) if \"PIL\" in x or \"google\" in x else None\n", " import re\n",
" import requests\n",
"\n",
" modules = list(sys.modules.keys())\n",
" for x in modules:\n",
" sys.modules.pop(x) if \"PIL\" in x or \"google\" in x else None\n",
" !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo\n", " !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo\n",
" !pip install sentencepiece protobuf datasets huggingface_hub hf_transfer\n", " !pip install sentencepiece protobuf datasets huggingface_hub hf_transfer\n",
"\n", "\n",
" # vLLM requirements - vLLM breaks Colab due to reinstalling numpy\n", " # vLLM requirements - vLLM breaks Colab due to reinstalling numpy\n",
" f = requests.get(\"https://raw.githubusercontent.com/vllm-project/vllm/refs/heads/main/requirements/common.txt\").content\n", " f = requests.get(\n",
" \"https://raw.githubusercontent.com/vllm-project/vllm/refs/heads/main/requirements/common.txt\"\n",
" ).content\n",
" with open(\"vllm_requirements.txt\", \"wb\") as file:\n", " with open(\"vllm_requirements.txt\", \"wb\") as file:\n",
" file.write(re.sub(rb\"(transformers|numpy|xformers)[^\\n]{1,}\\n\", b\"\", f))\n", " file.write(re.sub(rb\"(transformers|numpy|xformers)[^\\n]{1,}\\n\", b\"\", f))\n",
" !pip install -r vllm_requirements.txt" " !pip install -r vllm_requirements.txt"
@ -303,7 +312,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"from unsloth import FastLanguageModel, is_bfloat16_supported\n", "from unsloth import FastLanguageModel, is_bfloat16_supported\n",
"import torch\n", "\n",
"max_seq_length = 1024 # Can increase for longer reasoning traces\n", "max_seq_length = 1024 # Can increase for longer reasoning traces\n",
"lora_rank = 64 # Larger rank = smarter, but slower\n", "lora_rank = 64 # Larger rank = smarter, but slower\n",
"\n", "\n",
@ -320,8 +329,13 @@
" model,\n", " model,\n",
" r=lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128\n", " r=lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128\n",
" target_modules=[\n", " target_modules=[\n",
" \"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n", " \"q_proj\",\n",
" \"gate_proj\", \"up_proj\", \"down_proj\",\n", " \"k_proj\",\n",
" \"v_proj\",\n",
" \"o_proj\",\n",
" \"gate_proj\",\n",
" \"up_proj\",\n",
" \"down_proj\",\n",
" ], # Remove QKVO if out of memory\n", " ], # Remove QKVO if out of memory\n",
" lora_alpha=lora_rank,\n", " lora_alpha=lora_rank,\n",
" use_gradient_checkpointing=\"unsloth\", # Enable long context finetuning\n", " use_gradient_checkpointing=\"unsloth\", # Enable long context finetuning\n",
@ -445,43 +459,58 @@
"</answer>\n", "</answer>\n",
"\"\"\"\n", "\"\"\"\n",
"\n", "\n",
"\n",
"def extract_xml_answer(text: str) -> str:\n", "def extract_xml_answer(text: str) -> str:\n",
" answer = text.split(\"<answer>\")[-1]\n", " answer = text.split(\"<answer>\")[-1]\n",
" answer = answer.split(\"</answer>\")[0]\n", " answer = answer.split(\"</answer>\")[0]\n",
" return answer.strip()\n", " return answer.strip()\n",
"\n", "\n",
"\n",
"def extract_hash_answer(text: str) -> str | None:\n", "def extract_hash_answer(text: str) -> str | None:\n",
" if \"####\" not in text:\n", " if \"####\" not in text:\n",
" return None\n", " return None\n",
" return text.split(\"####\")[1].strip()\n", " return text.split(\"####\")[1].strip()\n",
"\n", "\n",
"\n",
"# uncomment middle messages for 1-shot prompting\n", "# uncomment middle messages for 1-shot prompting\n",
"def get_gsm8k_questions(split=\"train\") -> Dataset:\n", "def get_gsm8k_questions(split=\"train\") -> Dataset:\n",
" data = load_dataset('openai/gsm8k', 'main')[split] # type: ignore\n", " data = load_dataset(\"openai/gsm8k\", \"main\")[split] # type: ignore\n",
" data = data.map(lambda x: { # type: ignore\n", " data = data.map(\n",
" 'prompt': [\n", " lambda x: { # type: ignore\n",
" {'role': 'system', 'content': SYSTEM_PROMPT},\n", " \"prompt\": [\n",
" {'role': 'user', 'content': x['question']}\n", " {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
" {\"role\": \"user\", \"content\": x[\"question\"]},\n",
" ],\n", " ],\n",
" 'answer': extract_hash_answer(x['answer'])\n", " \"answer\": extract_hash_answer(x[\"answer\"]),\n",
" }) # type: ignore\n", " }\n",
" ) # type: ignore\n",
" return data # type: ignore\n", " return data # type: ignore\n",
"\n", "\n",
"\n",
"dataset = get_gsm8k_questions()\n", "dataset = get_gsm8k_questions()\n",
"\n", "\n",
"\n",
"# Reward functions\n", "# Reward functions\n",
"def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:\n", "def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:\n",
" responses = [completion[0]['content'] for completion in completions]\n", " responses = [completion[0][\"content\"] for completion in completions]\n",
" q = prompts[0][-1]['content']\n", " q = prompts[0][-1][\"content\"]\n",
" extracted_responses = [extract_xml_answer(r) for r in responses]\n", " extracted_responses = [extract_xml_answer(r) for r in responses]\n",
" print('-'*20, f\"Question:\\n{q}\", f\"\\nAnswer:\\n{answer[0]}\", f\"\\nResponse:\\n{responses[0]}\", f\"\\nExtracted:\\n{extracted_responses[0]}\")\n", " print(\n",
" \"-\" * 20,\n",
" f\"Question:\\n{q}\",\n",
" f\"\\nAnswer:\\n{answer[0]}\",\n",
" f\"\\nResponse:\\n{responses[0]}\",\n",
" f\"\\nExtracted:\\n{extracted_responses[0]}\",\n",
" )\n",
" return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]\n", " return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]\n",
"\n", "\n",
"\n",
"def int_reward_func(completions, **kwargs) -> list[float]:\n", "def int_reward_func(completions, **kwargs) -> list[float]:\n",
" responses = [completion[0]['content'] for completion in completions]\n", " responses = [completion[0][\"content\"] for completion in completions]\n",
" extracted_responses = [extract_xml_answer(r) for r in responses]\n", " extracted_responses = [extract_xml_answer(r) for r in responses]\n",
" return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]\n", " return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]\n",
"\n", "\n",
"\n",
"def strict_format_reward_func(completions, **kwargs) -> list[float]:\n", "def strict_format_reward_func(completions, **kwargs) -> list[float]:\n",
" \"\"\"Reward function that checks if the completion has a specific format.\"\"\"\n", " \"\"\"Reward function that checks if the completion has a specific format.\"\"\"\n",
" pattern = r\"^<reasoning>\\n.*?\\n</reasoning>\\n<answer>\\n.*?\\n</answer>\\n$\"\n", " pattern = r\"^<reasoning>\\n.*?\\n</reasoning>\\n<answer>\\n.*?\\n</answer>\\n$\"\n",
@ -489,6 +518,7 @@
" matches = [re.match(pattern, r) for r in responses]\n", " matches = [re.match(pattern, r) for r in responses]\n",
" return [0.5 if match else 0.0 for match in matches]\n", " return [0.5 if match else 0.0 for match in matches]\n",
"\n", "\n",
"\n",
"def soft_format_reward_func(completions, **kwargs) -> list[float]:\n", "def soft_format_reward_func(completions, **kwargs) -> list[float]:\n",
" \"\"\"Reward function that checks if the completion has a specific format.\"\"\"\n", " \"\"\"Reward function that checks if the completion has a specific format.\"\"\"\n",
" pattern = r\"<reasoning>.*?</reasoning>\\s*<answer>.*?</answer>\"\n", " pattern = r\"<reasoning>.*?</reasoning>\\s*<answer>.*?</answer>\"\n",
@ -496,6 +526,7 @@
" matches = [re.match(pattern, r) for r in responses]\n", " matches = [re.match(pattern, r) for r in responses]\n",
" return [0.5 if match else 0.0 for match in matches]\n", " return [0.5 if match else 0.0 for match in matches]\n",
"\n", "\n",
"\n",
"def count_xml(text) -> float:\n", "def count_xml(text) -> float:\n",
" count = 0.0\n", " count = 0.0\n",
" if text.count(\"<reasoning>\\n\") == 1:\n", " if text.count(\"<reasoning>\\n\") == 1:\n",
@ -510,6 +541,7 @@
" count -= (len(text.split(\"\\n</answer>\")[-1]) - 1) * 0.001\n", " count -= (len(text.split(\"\\n</answer>\")[-1]) - 1) * 0.001\n",
" return count\n", " return count\n",
"\n", "\n",
"\n",
"def xmlcount_reward_func(completions, **kwargs) -> list[float]:\n", "def xmlcount_reward_func(completions, **kwargs) -> list[float]:\n",
" contents = [completion[0][\"content\"] for completion in completions]\n", " contents = [completion[0][\"content\"] for completion in completions]\n",
" return [count_xml(c) for c in contents]" " return [count_xml(c) for c in contents]"
@ -540,6 +572,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"from trl import GRPOConfig, GRPOTrainer\n", "from trl import GRPOConfig, GRPOTrainer\n",
"\n",
"training_args = GRPOConfig(\n", "training_args = GRPOConfig(\n",
" use_vllm=True, # use vLLM for fast inference!\n", " use_vllm=True, # use vLLM for fast inference!\n",
" learning_rate=5e-6,\n", " learning_rate=5e-6,\n",
@ -636,21 +669,30 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"text = tokenizer.apply_chat_template([\n", "text = tokenizer.apply_chat_template(\n",
" [\n",
" {\"role\": \"user\", \"content\": \"How many r's are in strawberry?\"},\n", " {\"role\": \"user\", \"content\": \"How many r's are in strawberry?\"},\n",
"], tokenize = False, add_generation_prompt = True)\n", " ],\n",
" tokenize=False,\n",
" add_generation_prompt=True,\n",
")\n",
"\n", "\n",
"from vllm import SamplingParams\n", "from vllm import SamplingParams\n",
"\n",
"sampling_params = SamplingParams(\n", "sampling_params = SamplingParams(\n",
" temperature=0.8,\n", " temperature=0.8,\n",
" top_p=0.95,\n", " top_p=0.95,\n",
" max_tokens=1024,\n", " max_tokens=1024,\n",
")\n", ")\n",
"output = model.fast_generate(\n", "output = (\n",
" model.fast_generate(\n",
" [text],\n", " [text],\n",
" sampling_params=sampling_params,\n", " sampling_params=sampling_params,\n",
" lora_request=None,\n", " lora_request=None,\n",
")[0].outputs[0].text\n", " )[0]\n",
" .outputs[0]\n",
" .text\n",
")\n",
"\n", "\n",
"output" "output"
] ]
@ -697,22 +739,31 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"text = tokenizer.apply_chat_template([\n", "text = tokenizer.apply_chat_template(\n",
" [\n",
" {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n", " {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
" {\"role\": \"user\", \"content\": \"How many r's are in strawberry?\"},\n", " {\"role\": \"user\", \"content\": \"How many r's are in strawberry?\"},\n",
"], tokenize = False, add_generation_prompt = True)\n", " ],\n",
" tokenize=False,\n",
" add_generation_prompt=True,\n",
")\n",
"\n", "\n",
"from vllm import SamplingParams\n", "from vllm import SamplingParams\n",
"\n",
"sampling_params = SamplingParams(\n", "sampling_params = SamplingParams(\n",
" temperature=0.8,\n", " temperature=0.8,\n",
" top_p=0.95,\n", " top_p=0.95,\n",
" max_tokens=1024,\n", " max_tokens=1024,\n",
")\n", ")\n",
"output = model.fast_generate(\n", "output = (\n",
" model.fast_generate(\n",
" text,\n", " text,\n",
" sampling_params=sampling_params,\n", " sampling_params=sampling_params,\n",
" lora_request=model.load_lora(\"grpo_saved_lora\"),\n", " lora_request=model.load_lora(\"grpo_saved_lora\"),\n",
")[0].outputs[0].text\n", " )[0]\n",
" .outputs[0]\n",
" .text\n",
")\n",
"\n", "\n",
"output" "output"
] ]
@ -747,16 +798,36 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"# Merge to 16bit\n", "# Merge to 16bit\n",
"if False: model.save_pretrained_merged(\"model\", tokenizer, save_method = \"merged_16bit\",)\n", "if False:\n",
"if False: model.push_to_hub_merged(\"hf/model\", tokenizer, save_method = \"merged_16bit\", token = \"\")\n", " model.save_pretrained_merged(\n",
" \"model\",\n",
" tokenizer,\n",
" save_method=\"merged_16bit\",\n",
" )\n",
"if False:\n",
" model.push_to_hub_merged(\n",
" \"hf/model\", tokenizer, save_method=\"merged_16bit\", token=\"\"\n",
" )\n",
"\n", "\n",
"# Merge to 4bit\n", "# Merge to 4bit\n",
"if False: model.save_pretrained_merged(\"model\", tokenizer, save_method = \"merged_4bit\",)\n", "if False:\n",
"if False: model.push_to_hub_merged(\"hf/model\", tokenizer, save_method = \"merged_4bit\", token = \"\")\n", " model.save_pretrained_merged(\n",
" \"model\",\n",
" tokenizer,\n",
" save_method=\"merged_4bit\",\n",
" )\n",
"if False:\n",
" model.push_to_hub_merged(\"hf/model\", tokenizer, save_method=\"merged_4bit\", token=\"\")\n",
"\n", "\n",
"# Just LoRA adapters\n", "# Just LoRA adapters\n",
"if False: model.save_pretrained_merged(\"model\", tokenizer, save_method = \"lora\",)\n", "if False:\n",
"if False: model.push_to_hub_merged(\"hf/model\", tokenizer, save_method = \"lora\", token = \"\")" " model.save_pretrained_merged(\n",
" \"model\",\n",
" tokenizer,\n",
" save_method=\"lora\",\n",
" )\n",
"if False:\n",
" model.push_to_hub_merged(\"hf/model\", tokenizer, save_method=\"lora\", token=\"\")"
] ]
}, },
{ {
@ -785,25 +856,40 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"# Save to 8bit Q8_0\n", "# Save to 8bit Q8_0\n",
"if False: model.save_pretrained_gguf(\"model\", tokenizer,)\n", "if False:\n",
" model.save_pretrained_gguf(\n",
" \"model\",\n",
" tokenizer,\n",
" )\n",
"# Remember to go to https://huggingface.co/settings/tokens for a token!\n", "# Remember to go to https://huggingface.co/settings/tokens for a token!\n",
"# And change hf to your username!\n", "# And change hf to your username!\n",
"if False: model.push_to_hub_gguf(\"hf/model\", tokenizer, token = \"\")\n", "if False:\n",
" model.push_to_hub_gguf(\"hf/model\", tokenizer, token=\"\")\n",
"\n", "\n",
"# Save to 16bit GGUF\n", "# Save to 16bit GGUF\n",
"if False: model.save_pretrained_gguf(\"model\", tokenizer, quantization_method = \"f16\")\n", "if False:\n",
"if False: model.push_to_hub_gguf(\"hf/model\", tokenizer, quantization_method = \"f16\", token = \"\")\n", " model.save_pretrained_gguf(\"model\", tokenizer, quantization_method=\"f16\")\n",
"if False:\n",
" model.push_to_hub_gguf(\"hf/model\", tokenizer, quantization_method=\"f16\", token=\"\")\n",
"\n", "\n",
"# Save to q4_k_m GGUF\n", "# Save to q4_k_m GGUF\n",
"if False: model.save_pretrained_gguf(\"model\", tokenizer, quantization_method = \"q4_k_m\")\n", "if False:\n",
"if False: model.push_to_hub_gguf(\"hf/model\", tokenizer, quantization_method = \"q4_k_m\", token = \"\")\n", " model.save_pretrained_gguf(\"model\", tokenizer, quantization_method=\"q4_k_m\")\n",
"if False:\n",
" model.push_to_hub_gguf(\n",
" \"hf/model\", tokenizer, quantization_method=\"q4_k_m\", token=\"\"\n",
" )\n",
"\n", "\n",
"# Save to multiple GGUF options - much faster if you want multiple!\n", "# Save to multiple GGUF options - much faster if you want multiple!\n",
"if False:\n", "if False:\n",
" model.push_to_hub_gguf(\n", " model.push_to_hub_gguf(\n",
" \"hf/model\", # Change hf to your username!\n", " \"hf/model\", # Change hf to your username!\n",
" tokenizer,\n", " tokenizer,\n",
" quantization_method = [\"q4_k_m\", \"q8_0\", \"q5_k_m\",],\n", " quantization_method=[\n",
" \"q4_k_m\",\n",
" \"q8_0\",\n",
" \"q5_k_m\",\n",
" ],\n",
" token=\"\",\n", " token=\"\",\n",
" )" " )"
] ]

@ -0,0 +1,7 @@
# Clean Notebook Outputs 101
```
pip install nbstripout
nbstripout --install --global
nbstripout --status
```
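
`--install --global` registers nbstripout as a filter in your global git config, so notebook outputs and execution counts are stripped from commits; `--status` then confirms the filter is active. To strip a single notebook in place without going through git (a minimal sketch; the path below is hypothetical):

```
nbstripout notebooks/example.ipynb
```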

@ -61,7 +61,6 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"from unsloth import is_bfloat16_supported\n", "from unsloth import is_bfloat16_supported\n",
"import torch\n",
"\n", "\n",
"max_seq_length = 4096 * 2 # Can increase for longer reasoning traces\n", "max_seq_length = 4096 * 2 # Can increase for longer reasoning traces\n",
"lora_rank = 64 # Larger rank = smarter, but slower\n", "lora_rank = 64 # Larger rank = smarter, but slower\n",
@ -101,9 +100,6 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"import re\n",
"from datasets import load_dataset, Dataset\n",
"from search_module import search, get_question_answer, get_question_count\n",
"from rl_helpers import get_qa_dataset\n", "from rl_helpers import get_qa_dataset\n",
"\n", "\n",
"train_dataset, test_dataset = get_qa_dataset()" "train_dataset, test_dataset = get_qa_dataset()"
@ -187,7 +183,7 @@
"def agentic_generate(\n", "def agentic_generate(\n",
" prompts: list[str],\n", " prompts: list[str],\n",
" generate_fn,\n", " generate_fn,\n",
" max_generations: int = 6,\n", " max_generations: int = 10,\n",
"):\n", "):\n",
" return run_agent(generate_fn, tokenizer, prompts, max_generations)\n", " return run_agent(generate_fn, tokenizer, prompts, max_generations)\n",
"\n", "\n",
