diff --git a/.gitignore b/.gitignore index 18aa648..343324e 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,10 @@ saved_data/ saved_models/ faiss_index/ .vscode/ +unsloth_compiled_cache/ +full_local_training/ +grpo_trainer_lora_model/ +qa_log.txt # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/docs/00_worklog.md b/docs/00_worklog.md index 20efe9a..9d2406d 100644 --- a/docs/00_worklog.md +++ b/docs/00_worklog.md @@ -14,19 +14,24 @@ - - - +- [ ] Upload datasets to HF Hub +- [ ] Make a simple gradio demo app ## yymmdd - [ ] task description +## 250325 + +- [ ] update new reward functions in [reward-functions.md](reward-functions.md) +- [ ] Train the model v0 (with new data and reward functions) (might be another 2 hours) +- [ ] Convert this notebook to script [250324_generate_data_anatomy.ipynb](../notebooks/250324_generate_data_anatomy.ipynb) + ## 250324 -- [ ] Train the model v0 -- [ ] Make the dataset v0 -- [ ] Upload dataset v0 to HF Hub - - Initial dataset from AutoDidact - - Paraphrased sdataset -- [ ] Make a simple gradio demo app +- [x] Make the dataset v0 +- [x] Train with new data and default reward functions (it took 2 hours on 1xA6000 😭) + - Got poor result (50% Accuracy down to 35%) 📉 ## 250323 diff --git a/docs/assets/reward-function-anatomy.excalidraw.png b/docs/assets/reward-function-anatomy.excalidraw.png new file mode 100644 index 0000000..3fdd340 Binary files /dev/null and b/docs/assets/reward-function-anatomy.excalidraw.png differ diff --git a/docs/dataset.md b/docs/dataset.md index 4cb2d3a..8577917 100644 --- a/docs/dataset.md +++ b/docs/dataset.md @@ -4,11 +4,22 @@ This document describes the creation of a data pipeline to generate a dataset. ## Implementation Phases -- [ ] 1.Simple chunk paraphrasing logic that's just work - - After splitting, feed the splitted chunks into LLM to paraphrase - - Rebuil the FAISS index with the paraphrased chunks - - Don't touch `question.json` -- [ ] 2.Enhance the dataset quality with API (check backlog) +- [x] V0 Initial dataset from AutoDidact (V -1) + - saved_data/chunks.pkl (need to keep this to create later dataset) + - saved_data/questions.json + - faiss_index/ +- [x] V1 Paraphrased dataset + - ~~paraphrased_chunks.pkl (no need, this sucks)~~ + - saved_data/chunks.pkl (this is for the ground truth chunks) + - saved_data/questions.json + - faiss_index/ (already contained all the documents ✅) (this include 3 new paraphrased chunks) + +- [ ] V2 Paraphrased dataset with API + - API (for better quality) + - questions.json + - faiss_index/ (already contained all the documents ✅) (this include 3 new paraphrased chunks) +- [ ] V3 + - IDK, let's survive V1 first. ## Inital idea from @tikikun diff --git a/docs/ds-pipeline-v0.md b/docs/ds-pipeline-v0.md index e16cf0f..0e5e46d 100644 --- a/docs/ds-pipeline-v0.md +++ b/docs/ds-pipeline-v0.md @@ -28,7 +28,9 @@ - iterate over the file - paraphrase the chunk [paraphrase-prompt.md](paraphrase-prompt.md) - add the paraphrased chunks to the vector store (how? will it affect the original chunk id?) - - Can just append the new chunks to the existing file? + - Can just append the new chunks to the existing file - Yes, but: + - The original vectors (first 10 in your example) KEEP their IDs (0-9) + - New vectors (last 10) get new IDs (10-19) - save the vector store - save the question json file - [ ] Should I ass wrong information or not? How correct should the paraphrased chunk be? How many paraphased chunks should I add for each original chunk? - **V0.1? for now just use simple paraphrasing with correct information.** diff --git a/docs/evaluation.md b/docs/evaluation.md index 46f6517..589bae6 100644 --- a/docs/evaluation.md +++ b/docs/evaluation.md @@ -47,6 +47,37 @@ percentage of correct answers: 0.19402985074626866 [rank0]:[W320 07:13:50.651270455 ProcessGroupNCCL.cpp:1250] Warning: WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL operations have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present, but this warning has only been added since PyTorch 2.4 (function operator()) ``` +- Training log with paraphrased dataset (no new reward function yet!) - Disappointing results + + ```bash + +.2587745785713196, 'completion_length': 374.3125, 'kl': 0.004571444820612669, 'epoch': 0.34} +{'train_runtime': 7419.1437, 'train_samples_per_second': 0.218, 'train_steps_per_second': 0.014, 'train_loss': 0.00037626780881639505, 'epoch': 0.34} +100%|████████████████████████████████████████████████████████| 101/101 [2:03:39<00:00, 73.46s/it] +Processed prompts: 100%|█| 67/67 [00:19<00:00, 3.51it/s, est. speed input: 1016.34 toks/s, outpu +Processed prompts: 100%|█| 66/66 [00:21<00:00, 3.03it/s, est. speed input: 2086.78 toks/s, outpu +Processed prompts: 100%|█| 19/19 [00:14<00:00, 1.28it/s, est. speed input: 1326.10 toks/s, outpu +Processed prompts: 100%|█| 14/14 [00:14<00:00, 1.03s/it, est. speed input: 1363.04 toks/s, outpu +Processed prompts: 100%|█| 9/9 [00:13<00:00, 1.55s/it, est. speed input: 1153.10 toks/s, output: +Processed prompts: 100%|█| 67/67 [00:02<00:00, 28.46it/s, est. speed input: 5843.91 toks/s, outpu +RESULTS: +percentage of correct answers: 0.3582089552238806 +============================== + +Processed prompts: 100%|█| 67/67 [00:20<00:00, 3.20it/s, est. speed input: 925.56 toks/s, output +Processed prompts: 100%|█| 36/36 [00:13<00:00, 2.63it/s, est. speed input: 1755.08 toks/s, outpu +Processed prompts: 100%|█| 11/11 [00:09<00:00, 1.19it/s, est. speed input: 1254.10 toks/s, outpu +Processed prompts: 100%|█| 8/8 [00:09<00:00, 1.15s/it, est. speed input: 1192.77 toks/s, output: +Processed prompts: 100%|█| 4/4 [00:06<00:00, 1.67s/it, est. speed input: 1063.38 toks/s, output: +Processed prompts: 100%|█| 67/67 [00:02<00:00, 29.78it/s, est. speed input: 5244.11 toks/s, outpu +RESULTS: +percentage of correct answers: 0.2835820895522388 +============================== + +[rank0]:[W324 11:21:27.955684565 ProcessGroupNCCL.cpp:1250] Warning: WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL operations have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present, but this warning has only been added since PyTorch 2.4 (function operator()) + + ``` + ## Getting some sense of the eval data or benchmark - > For example, benchmarks like ARC-AGI, which involve visual reasoning, remain challenging for these models, even though they might seem straightforward to a human. (ichigo) diff --git a/docs/reward-functions.md b/docs/reward-functions.md index fc2d66b..d045bd0 100644 --- a/docs/reward-functions.md +++ b/docs/reward-functions.md @@ -8,22 +8,81 @@ This note is a collection of stolen reward functions and tips from other project - Label studio suggest consult domain experts -> ask the LLM to be search engine expert?? - Starting from the default of AutoDiact should be good enough, then figure out big brain moves from there +- [ ] Reward exact matches only, don't increase gradually. For example, 4 or 5 attempts would get 1 point or half a point, don't scale up (e.g., 10 attempts doesn't scale up further) (don't reward retry behavior) + - Insight from Alphamaze: don't plan for too many cases, scope down to just 1-2 things to generalize rather than being too detailed + ## Implementation Phases -- [ ] 1.Just keep the default ones from AutoDidact and add the Exact Match Idea +- [x] V0. Just keep the default ones from AutoDidact and add the Exact Match Idea - Oh they only use 2 reward functions "reward_correctness" and "reward_formatting" -- [ ] 2. Add more if needed. +- [ ] V1. Add more reward functions + - Retrying + - Need mechanism to count number of retrying attempts + - Exact match + - Hold up, Do I also need LLM for those two? - NO, we are doing exact match, just write the rules, then if else ## Psuedo code ```python +def reward_exact_match(completions, expected_result, **kwargs) -> list[float]: + """Reward exact matches with search results + Returns 1.0 for exact match, 0.0 otherwise""" + responses = [completion[0]["content"] for completion in completions] + return [1.0 if r == expected_result else 0.0 for r in responses] +def reward_retry_behavior(completions, **kwargs) -> list[float]: + """Reward retrying search behavior but cap it + Returns: + - 0.5 for 2-5 search attempts + - 0.0 for <2 or >5 attempts to avoid reward hacking + """ + def count_search_attempts(response): + # Adjust this pattern based on how your search attempts are formatted + search_pattern = r"Searching for:.*?" + attempts = len(re.findall(search_pattern, response)) + if 2 <= attempts <= 5: + return 0.5 + return 0.0 + responses = [completion[0]["content"] for completion in completions] + return [count_search_attempts(r) for r in responses] +run_agent = rl_helpers.run_agent +reward_correctness = rl_helpers.build_reward_correctness_fn( + verifier_generate_fn, + tokenizer, +) +reward_formatting = rl_helpers.reward_formatting + +import UnslothGRPOTrainerTemp + +trainer = UnslothGRPOTrainerTemp.UnslothGRPOTrainer( + model=model, + processing_class=tokenizer, + reward_funcs=[ + reward_correctness, + reward_formatting, + ], + args=training_args, + train_dataset=train_dataset, +) ``` +## Anatomy of reward_correctness and reward_formatting + +The `reward_correctness` and `reward_formatting` functions are key components in our reinforcement learning setup. Let's break down how they work: + +- `reward_correctness` + - Student LLM generate the answer + - Generated answer is compared with the correct answer, scoring by another LLM +- `reward_formatting` + - Student LLM generate the answer + - Generated answer is compared with the correct answer, scoring by another LLM + +![Reward Function Anatomy](assets/reward-function-anatomy.excalidraw.png) + ## Get a sense of Reward functions - diff --git a/generate_data.py b/generate_data.py index 02d4824..10d029e 100644 --- a/generate_data.py +++ b/generate_data.py @@ -34,7 +34,7 @@ docs = loader.load() text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0) chunks = text_splitter.split_documents(docs) -# Save chunks for later use +# Save chunks for later use # TODO: change to csv? easier inspect. os.makedirs("saved_data", exist_ok=True) with open("saved_data/chunks.pkl", "wb") as f: pickle.dump(chunks, f) diff --git a/notebooks/.gitignore b/notebooks/.gitignore index 192097d..5ac5af4 100644 --- a/notebooks/.gitignore +++ b/notebooks/.gitignore @@ -1,2 +1,3 @@ unsloth_compiled_cache -0_* \ No newline at end of file +0_* +faiss_index* \ No newline at end of file diff --git a/notebooks/250324_generate_data_anatomy.ipynb b/notebooks/250324_generate_data_anatomy.ipynb index edb626c..0c252ee 100644 --- a/notebooks/250324_generate_data_anatomy.ipynb +++ b/notebooks/250324_generate_data_anatomy.ipynb @@ -39,33 +39,41 @@ "\n", "from langchain.text_splitter import RecursiveCharacterTextSplitter\n", "\n", - "# ========= Part 1: Document Processing and Embedding Generation =========\n", - "# Load and split the markdown document using LangChain\n", + "\n", "from langchain_community.document_loaders import UnstructuredMarkdownLoader\n", "from langchain_community.vectorstores import FAISS\n", "\n", - "from embeddings import CustomHuggingFaceEmbeddings\n", - "\n", - "# Load your markdown file (adjust the path as needed)\n", - "loader = UnstructuredMarkdownLoader(\"../data/mission_report.md\")\n", - "docs = loader.load()\n", - "\n", - "# Split the document into smaller chunks (each 1000 characters, no overlap)\n", - "text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n", - "chunks = text_splitter.split_documents(docs)\n", - "\n", - "# Save chunks for later use\n", - "os.makedirs(\"saved_data\", exist_ok=True)\n", - "with open(\"saved_data/chunks.pkl\", \"wb\") as f:\n", - " pickle.dump(chunks, f)\n", - "print(f\"Saved {len(chunks)} chunks to saved_data/chunks.pkl\")\n", - "\n", - "embeddings = CustomHuggingFaceEmbeddings()\n", - "\n", - "# Create a FAISS vector store from the document chunks and save it locally\n", - "vectorstore = FAISS.from_documents(chunks, embeddings)\n", - "vectorstore.save_local(\"faiss_index\")\n", - "print(\"Saved FAISS index to 'faiss_index'\")" + "from embeddings import CustomHuggingFaceEmbeddings" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# # ========= Part 1: Document Processing and Embedding Generation =========\n", + "# # Load and split the markdown document using LangChain\n", + "# # Load your markdown file (adjust the path as needed)\n", + "# loader = UnstructuredMarkdownLoader(\"../data/mission_report.md\")\n", + "# docs = loader.load()\n", + "\n", + "# # Split the document into smaller chunks (each 1000 characters, no overlap)\n", + "# text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n", + "# chunks = text_splitter.split_documents(docs)\n", + "\n", + "# # Save chunks for later use\n", + "# os.makedirs(\"saved_data\", exist_ok=True)\n", + "# with open(\"saved_data/chunks.pkl\", \"wb\") as f:\n", + "# pickle.dump(chunks, f)\n", + "# print(f\"Saved {len(chunks)} chunks to saved_data/chunks.pkl\")\n", + "\n", + "# embeddings = CustomHuggingFaceEmbeddings()\n", + "\n", + "# # Create a FAISS vector store from the document chunks and save it locally\n", + "# vectorstore = FAISS.from_documents(chunks, embeddings)\n", + "# vectorstore.save_local(\"faiss_index\")\n", + "# print(\"Saved FAISS index to 'faiss_index'\")" ] }, { @@ -164,14 +172,12 @@ " \"\"\"Rewrite this text in a formal, scholarly tone. Keep it very concise - summarize in 1-2 short sentences. Only output the paraphrased text:\n", "\n", " TEXT: {text}\"\"\",\n", - " \n", " \"\"\"Rewrite this text in a clear, simple way that's easy to understand. Provide a medium-length explanation with key details. Only output the paraphrased text:\n", " \n", " TEXT: {text}\"\"\",\n", - " \n", " \"\"\"Rewrite this text in a vivid, engaging style. Expand on the details and provide a comprehensive, detailed version. Only output the paraphrased text:\n", " \n", - " TEXT: {text}\"\"\"\n", + " TEXT: {text}\"\"\",\n", "]\n", "\n", "# Update sampling parameters for each style\n", @@ -193,6 +199,7 @@ " max_tokens=512, # Long responses\n", ")\n", "\n", + "\n", "def generate_response(text: str) -> list[str]:\n", " \"\"\"\n", " Generate three different paraphrased versions with varying lengths.\n", @@ -204,9 +211,15 @@ " List of three paraphrased versions (short, medium, long)\n", " \"\"\"\n", " responses = []\n", - " sampling_params_list = [sampling_params_short, sampling_params_medium, sampling_params_long]\n", - "\n", - " for prompt_template, sampling_params in zip(PARAPHRASE_PROMPTS, sampling_params_list):\n", + " sampling_params_list = [\n", + " sampling_params_short,\n", + " sampling_params_medium,\n", + " sampling_params_long,\n", + " ]\n", + "\n", + " for prompt_template, sampling_params in zip(\n", + " PARAPHRASE_PROMPTS, sampling_params_list\n", + " ):\n", " formatted_prompt = tokenizer.apply_chat_template(\n", " [{\"role\": \"user\", \"content\": prompt_template.format(text=text)}],\n", " tokenize=False,\n", @@ -239,18 +252,177 @@ "paraphrased_chunks = []\n", "for chunk in chunks[:3]:\n", " styles = generate_response(chunk.page_content) # Now returns list of 3 styles\n", - " paraphrased_chunks.append(styles)\n", + " paraphrased_chunks.extend(styles)\n", "\n", - "# print the first 3 chunks and their paraphrased versions\n", - "for i, chunk in enumerate(chunks[:3]):\n", - " print(f\"\\n--- Original Chunk {i + 1}/3 ---\")\n", - " print(chunk.page_content)\n", - " print(\"-\" * 50)\n", - " \n", - " for j, style in enumerate(paraphrased_chunks[i], 1):\n", - " print(f\"\\n--- Style {j} Paraphrase ---\")\n", - " print(style)\n", - " print(\"-\" * 50)" + "from pprint import pprint\n", + "\n", + "pprint(paraphrased_chunks) # single list of 3*len(chunks) items" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# TODO: add checkpoint (save to file and resume from there) to this process, as it's long running and resource intensive\n", + "paraphrased_chunks = []\n", + "for chunk in chunks: # all chunks\n", + " styles = generate_response(chunk.page_content) # Now returns list of 3 styles\n", + " paraphrased_chunks.extend(styles) # should be single list of 3*len(chunks) items" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "\n", + "df = pd.DataFrame(paraphrased_chunks)\n", + "# add column names\n", + "df.columns = [\"paraphrased_text\"]\n", + "df.to_csv(\"saved_data/paraphrased_chunks.csv\", index=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# print number of rows\n", + "print(f\"Number of rows: {len(df)}\")\n", + "# wtf, ah 341 * 3 = 1023, make sense" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "\n", + "df = pd.read_csv(\"saved_data/paraphrased_chunks.csv\")\n", + "df.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## ✅ Append final vectorstore here" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Load the existing FAISS index\n", + "from langchain_community.vectorstores import FAISS\n", + "from embeddings import CustomHuggingFaceEmbeddings\n", + "\n", + "\n", + "# Load the paraphrased chunks\n", + "df = pd.read_csv(\"saved_data/paraphrased_chunks.csv\")\n", + "print(f\"Loaded {len(df)} paraphrased chunks\")\n", + "\n", + "# Convert DataFrame to Document objects\n", + "from langchain.schema import Document\n", + "\n", + "paraphrased_documents = [\n", + " Document(page_content=row[\"paraphrased_text\"], metadata={})\n", + " for _, row in df.iterrows()\n", + "]\n", + "\n", + "# Initialize the embeddings model\n", + "embeddings = CustomHuggingFaceEmbeddings()\n", + "\n", + "# Create embeddings for the paraphrased chunks\n", + "paraphrased_vectorstore = FAISS.from_documents(paraphrased_documents, embeddings)\n", + "print(\n", + " f\"Created FAISS index for paraphrased chunks with {paraphrased_vectorstore.index.ntotal} vectors\"\n", + ")\n", + "\n", + "# NOTE: so this load the already generated vectorstore first\n", + "# Load the existing vectorstore - add allow_dangerous_deserialization=True to fix the error\n", + "existing_vectorstore = FAISS.load_local(\n", + " \"faiss_index\", embeddings, allow_dangerous_deserialization=True\n", + ")\n", + "print(f\"Loaded existing FAISS index with {existing_vectorstore.index.ntotal} vectors\")\n", + "\n", + "# Merge the two vectorstores\n", + "# Side effects:\n", + "# Original IDs are not preserved - new IDs are assigned sequentially #TODO: does the final dataset need this information?\n", + "# If vectors are duplicates/very similar, they'll still be added (no deduplication) (don't care for now)\n", + "\n", + "existing_vectorstore.merge_from(paraphrased_vectorstore)\n", + "print(f\"Merged vectorstores, now contains {existing_vectorstore.index.ntotal} vectors\")\n", + "\n", + "# Save the updated vectorstore\n", + "existing_vectorstore.save_local(\"faiss_index_with_paraphrased\")\n", + "print(\"Saved updated FAISS index to 'faiss_index_with_paraphrased'\")\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# TODO: Try loading the new vectorstore and see if it works\n", + "# Expected output size: 341 * 4 = 1364\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 🔍 Inspect FAISS vector store\n", + "Ok so basically the faiss vector store contains: \n", + "- Document ID\n", + "- **Full content** -> no need to save the original chunks.pkl file anymore \n", + "- Metadata\n", + "- **Full vector embedding**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Load your index\n", + "from langchain_community.vectorstores import FAISS\n", + "from embeddings import CustomHuggingFaceEmbeddings\n", + "\n", + "embeddings = CustomHuggingFaceEmbeddings()\n", + "vectorstore = FAISS.load_local(\n", + " \"faiss_index\",\n", + " embeddings,\n", + " allow_dangerous_deserialization=True,\n", + ")\n", + "\n", + "# View contents\n", + "docs = vectorstore.docstore._dict\n", + "print(f\"Total documents: {len(docs)}\\n\")\n", + "\n", + "# Print first 5 docs as sample\n", + "for doc_id, doc in list(docs.items())[:5]:\n", + " print(f\"ID: {doc_id}\")\n", + " print(f\"Content: {doc.page_content[:200]}...\") # First 200 chars\n", + " print(f\"Metadata: {doc.metadata}\\n\")\n", + " print(\"-\" * 80 + \"\\n\")\n", + "\n", + "# Print total vectors for verification\n", + "print(f\"Total vectors in index: {vectorstore.index.ntotal}\")\n", + "print(f\"Vector dimension: {vectorstore.index.d}\")" ] }, { @@ -315,7 +487,7 @@ " print(\"-\" * 50)\n", " print(f\"\\n--- Paraphrased Chunk {i + 1}/3 ---\")\n", " print(paraphrased_chunks[i])\n", - " print(\"-\" * 50)\n" + " print(\"-\" * 50)" ] } ], diff --git a/requirements.txt b/requirements.txt index 8bd9e02..d42da86 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,8 +4,8 @@ langchain langchain-community Markdown tokenizers -transformers unsloth==2025.3.6 +transformers==4.49.0 unsloth_zoo==2025.3.4 unstructured vllm diff --git a/train_autodidact.py b/train_autodidact.py index a7ce045..1415a9c 100644 --- a/train_autodidact.py +++ b/train_autodidact.py @@ -1,9 +1,8 @@ # %% -from unsloth import FastLanguageModel +import torch # %% -from unsloth import is_bfloat16_supported -import torch +from unsloth import FastLanguageModel, is_bfloat16_supported max_seq_length = 4096 * 2 # Can increase for longer reasoning traces lora_rank = 64 # Larger rank = smarter, but slower @@ -36,9 +35,11 @@ model = FastLanguageModel.get_peft_model( # %% import re -from datasets import load_dataset, Dataset -from search_module import search, get_question_answer, get_question_count + +from datasets import Dataset, load_dataset + from rl_helpers import get_qa_dataset +from search_module import get_question_answer, get_question_count, search train_dataset, test_dataset = get_qa_dataset() @@ -87,6 +88,7 @@ training_args = UnslothGRPOTrainerTemp.UnslothGRPOConfig( import rl_helpers + # importlib.reload(rl_helpers) @@ -147,6 +149,7 @@ trainer.train() # %% from vllm import SamplingParams + import rl_helpers sampling_params = SamplingParams( diff --git a/train_autodidact_1B.py b/train_autodidact_1B.py new file mode 100644 index 0000000..e7fdd3e --- /dev/null +++ b/train_autodidact_1B.py @@ -0,0 +1,196 @@ +# %% +import torch + +# %% +from unsloth import FastLanguageModel, is_bfloat16_supported + +max_seq_length = 4096 * 2 # Can increase for longer reasoning traces +lora_rank = 64 # Larger rank = smarter, but slower + +model, tokenizer = FastLanguageModel.from_pretrained( + model_name="meta-llama/Llama-3.2-1B-Instruct", + max_seq_length=max_seq_length, + load_in_4bit=True, # False for LoRA 16bit + fast_inference=True, # Enable vLLM fast inference + max_lora_rank=lora_rank, + gpu_memory_utilization=0.6, # Reduce if out of memory +) + + +print(tokenizer.chat_template) # See what format Qwen expects + + +model = FastLanguageModel.get_peft_model( + model, + r=lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128 + target_modules=[ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + ], # Remove QKVO if out of memory + lora_alpha=lora_rank, + use_gradient_checkpointing="unsloth", # Enable long context finetuning + random_state=3407, +) + +# %% +import re + +from datasets import Dataset, load_dataset + +from rl_helpers import get_qa_dataset +from search_module import get_question_answer, get_question_count, search + +train_dataset, test_dataset = get_qa_dataset() + +# %% [markdown] +# +# ### Train the model +# +# Now set up GRPO Trainer and all configurations! + +# %% +import os + +os.environ["WANDB_PROJECT"] = "bootstrap-search-rl" + +# %% +# from UnslothGRPOTrainerTemp import UnslothGRPOConfig, _UnslothGRPOTrainer +import UnslothGRPOTrainerTemp + +training_args = UnslothGRPOTrainerTemp.UnslothGRPOConfig( + use_vllm=True, # use vLLM for fast inference! + use_agentic_generate=True, # use agentic generation + learning_rate=5e-6, + adam_beta1=0.9, + adam_beta2=0.99, + weight_decay=0.1, + warmup_ratio=0.1, + lr_scheduler_type="cosine", + optim="paged_adamw_8bit", + logging_steps=1, + bf16=is_bfloat16_supported(), + fp16=not is_bfloat16_supported(), + per_device_train_batch_size=8, + gradient_accumulation_steps=1, # Increase to 4 for smoother training + num_generations=8, # Decrease if out of memory + max_prompt_length=1024, + max_completion_length=1024, + # num_train_epochs = 1, # Set to 1 for a full training run + max_steps=101, + save_steps=50, + max_grad_norm=0.1, + report_to="none", # Can use Weights & Biases + output_dir="full_local_training", +) + +# %% + + +import rl_helpers + +# importlib.reload(rl_helpers) + + +def agentic_generate( + prompts: list[str], + generate_fn, + max_generations: int = 6, +): + return run_agent(generate_fn, tokenizer, prompts, max_generations) + + +model.agentic_generate = agentic_generate + + +from vllm import SamplingParams + +verifier_sampling_params = SamplingParams( + temperature=0.1, + top_p=0.95, + max_tokens=4096, +) + + +def verifier_generate_fn(inputs): + return model.fast_generate( + inputs, + sampling_params=verifier_sampling_params, + ) + + +run_agent = rl_helpers.run_agent +reward_correctness = rl_helpers.build_reward_correctness_fn( + verifier_generate_fn, + tokenizer, +) +reward_formatting = rl_helpers.reward_formatting + +import UnslothGRPOTrainerTemp + +trainer = UnslothGRPOTrainerTemp.UnslothGRPOTrainer( + model=model, + processing_class=tokenizer, + reward_funcs=[ + reward_correctness, + reward_formatting, + ], + args=training_args, + train_dataset=train_dataset, +) + +# %% +trainer.train() + +# %% [markdown] +# +# ### Inference +# Now let's try benchmark the model we trained! + +# %% +from vllm import SamplingParams + +import rl_helpers + +sampling_params = SamplingParams( + temperature=0.5, + top_p=0.95, + max_tokens=4096, +) + + +def eval_generate_fn(inputs): + return model.fast_generate( + inputs, + sampling_params=sampling_params, + lora_request=model.load_lora( + "full_local_training/checkpoint-101" + ), # load the trained LoRA + ) + + +rl_helpers.run_eval( + generate_fn=eval_generate_fn, + verify_fn=reward_correctness, + tokenizer=tokenizer, +) + + +# %% +# eval w/o lora +def eval_generate_fn(inputs): + return model.fast_generate( + inputs, + sampling_params=sampling_params, + ) + + +rl_helpers.run_eval( + generate_fn=eval_generate_fn, + verify_fn=reward_correctness, + tokenizer=tokenizer, +)