feat: add initial project structure and core functionality

- Added initial files from AutoDidact as a starting point
- Enhanced `README.md` with project overview and setup instructions.
- Removed `ugly_code_file.py` as part of cleanup.
- Added various documentation files and assets for project clarity.
- Included Jupyter notebooks for training and experimentation.
main
thinhlpg 2 months ago
parent 91c2476c28
commit a58722e16f

@ -0,0 +1,2 @@
HF_TOKEN=
OPENROUTER_API_KEY=

7
.gitignore vendored

@ -1,3 +1,10 @@
# DeepSearch
.ruff_cache/
saved_data/
saved_models/
faiss_index/
.vscode/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]

@ -1 +1,39 @@
# DeepSearch
# DeepSearch - A Hard Working Search Engine 🔍
DeepSearch trains a small language model to develop effective search behaviors instead of memorizing static data. It interacts with multiple synthetic search engines, each with unique retrieval mechanisms, to refine queries and persist in searching until it finds exact answers. The project focuses on reinforcement learning, preventing overfitting, and optimizing for efficiency in real-world search applications.
![Project Whiteboard](docs/assets/whiteboard.drawio.png)
## Setup
```bash
python -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt
```
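
The project also expects a `.env` file with `HF_TOKEN` and `OPENROUTER_API_KEY` (see the template added in this commit). A minimal sketch of loading it with `python-dotenv`, which is already in `requirements.txt`:

```python
# Hedged sketch: load the API keys from the .env file at the repo root.
import os

from dotenv import load_dotenv

load_dotenv()  # reads HF_TOKEN and OPENROUTER_API_KEY from .env
hf_token = os.getenv("HF_TOKEN")
openrouter_api_key = os.getenv("OPENROUTER_API_KEY")
```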
## Models
You can find our models on Hugging Face 🤗! We're committed to open-source and easy access for the research community.
| Model | Backbone | Size | Link |
|-------|----------|------|------|
| - | - | - | - |
## Datasets
We've released our datasets on Hugging Face 🤗 to support reproducibility and further research.
| Dataset | Description | Size | Link |
|--------------------------------------|-----------------------------------------------------|-------|-----------------------------------------------------------------------------------------|
| - | - | - | - |
| - | - | - | - |
| - | - | - | - |
## References
- This project is kickstarted from [AutoDidact](https://github.com/dCaples/AutoDidact)
## Personal Notes
- **This is research code**, so I'm prioritizing speed over code quality for now. Expect things to be messy—both the code and commit history. Roasting is welcome, but don't judge me too hard; I'll clean it up later. **I don't know what I don't know**, but I'm eager (and desperate) to learn and improve, so any constructive feedback is highly appreciated! 💖

File diff suppressed because it is too large

Binary file not shown.

File diff suppressed because it is too large

@ -0,0 +1,63 @@
# Worklog
## Backlog
- [ ] Modify `generate_dataset.py` (**ONLY AFTER** the simple training and benchmark work):
- [ ] As a dataset maker, I want to switch from Llama 3.1 8B to an API call (Claude, Gemini, or OpenAI). The original project uses Llama 3.1 8B for its `Self-Bootstrapping` demonstration, but the resulting dataset quality is certainly low.
- [ ] Experimenting with different chunking strategies
- [ ] [search-backends.md](search-backends.md) design (for more dataset noise (**ONLY AFTER** the simple training dataset works))
- [ ] Research a little bit on Agentic Reward Modeling (for designing better reward function maybe?)
- <https://medium.com/@techsachin/agentic-reward-modeling-combine-human-preferences-with-verifiable-correctness-signals-for-reliable-76c408b3491c>
- <https://arxiv.org/pdf/2502.19328>
- <https://github.com/THU-KEG/Agentic-Reward-Modeling>
- <https://www.themoonlight.io/en/review/agentic-reward-modeling-integrating-human-preferences-with-verifiable-correctness-signals-for-reliable-reward-systems>
## yymmdd
- [ ] task description
## 250324
- [ ] @thinhlpg transfers the project to @bachvudinh
## 250323
- [ ] Train the model
- [ ] Make the dataset
- [ ] Upload datasets to HF Hub
- Initial dataset from AutoDidact
- Paraphrased dataset
- [ ] Make a simple gradio demo app
## 250322
- [x] Moved all the scattered and disorganized stuff I've been working on for the past week into this repo.
- [x] Write proposal for DeepSearch
- [x] [evaluation.md](evaluation.md) design (list out the metrics and why)
- [x] [dataset.md](dataset.md) design (pipeline, data structure,...)
- [x] [reward-functions.md](reward-functions.md) design (list out the functions and why)
- [x] As a new member of the research team, I'm curious how we did GRPO with Alphamaze, so that I can inherit the good stuff and improve the workflow!!!
- [Alphamaze](https://github.com/menloresearch/visual-thinker)?
- <https://www.menlo.ai/blog/alpha-maze>
- <https://arxiv.org/pdf/2502.14669>
- > Our training process involved two key stages: creating a specialized dataset and then using a combination of supervised fine-tuning (SFT) and reinforcement learning (RL) to train the model.
- LLaMA-Factory for SFT **(1.5B 6xA6000 1.5 hour)** and Unsloth for GRPO
- 💡 Hmm, so for SFT we have 50% successful data and 50% retry data, and all-successful data for GRPO. Can I apply this to DeepSearch as well? #HACK
## 250321
- [x] Inspect the code of AutoDidact in a more detailed way <https://github.com/menloresearch/DeepSearch/issues/4>
## 250320
- Research on GRPO <https://github.com/menloresearch/DeepSearch/issues/2>
## 250319
- Research on GRPO <https://github.com/menloresearch/DeepSearch/issues/2>
- Run the training script of AutoDidact
## 250318
- Idea received <https://github.com/menloresearch/DeepSearch/issues/1>

@ -0,0 +1,31 @@
# Adaptive Search Behavior
- [Agent Action](agent-action.md) -> mostly: recognize that something is missing -> perform a "refined query"
- [ ] As a model trainer, I want to inspect the full chat state of the agent to know what's going on so I can improve it -> implement a simple CLI inspect tool after training that just prints the full chat state (see the sketch after the example below).
- Example from AutoDidact:
```markdown
Example Question
What was the reason for substituting the backup Command Module Pilot 3 days prior to the Apollo 13 flight?
Step-by-Step Search Process
Query : "Apollo 13 Command Module Pilot substitution"
Outcome: Retrieved operational support details, but no explanation for the substitution.
Agent's Action: Recognized missing information → **Refined query**.
Query : "Apollo 13 Command Module Pilot substitution reason"
Outcome: Retrieved general mission anomaly details, but still no direct answer.
Agent's Action: **Increased query specificity**.
Query : "Apollo 13 John 'Jack' Swigert substitution"
Outcome: Found general mission reports, but still lacked a clear reason for substitution.
Agent's Action: Hypothesized illness might be a factor → **Refined query** accordingly.
Query : "Apollo 13 Jack Swigert illness substitution"
Outcome: Retrieved the exact explanation: "Several days prior to launch, the backup Lunar Module Pilot became sick with measles. Examinations of the prime crew indicated that the Command Module Pilot was not immune to the disease; therefore, the backup Command Module Pilot was substituted."
Final Answer
The original Command Module Pilot lacked immunity to measles, necessitating his replacement by Jack Swigert.
This example shows how llama learns to do multiple searches to find answers to its questions.
```
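
A minimal sketch of that CLI inspect tool, assuming chat states are dumped to a JSON file as a list of `{"messages": [...]}` dicts (the same structure built in `rl_helpers.py`); the file path and CLI are hypothetical:

```python
# Hedged sketch of a post-training chat-state inspector. It assumes chat states were
# saved as JSON: a list of {"messages": [{"role": ..., "content": ...}, ...]} dicts,
# matching the structure built in rl_helpers.py. Path and CLI flags are hypothetical.
import argparse
import json


def print_chat_state(chat_state: dict) -> None:
    for message in chat_state["messages"]:
        print(f"--- {message['role']} ---")
        print(message["content"])
        print()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Print full chat states after training.")
    parser.add_argument("path", help="JSON file containing a list of chat states")
    args = parser.parse_args()
    with open(args.path) as f:
        chat_states = json.load(f)
    for i, chat_state in enumerate(chat_states):
        print(f"===== Chat state {i} =====")
        print_chat_state(chat_state)
```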

@ -0,0 +1,5 @@
# Agent Action
- [ ] Research a bit more on this because I'm a bit outdated on the training side
- [ ] What does the dataset look like?
- [ ] How to evaluate the performance?

Binary file not shown.


@ -0,0 +1,91 @@
# Dataset
This document describes the creation of a data pipeline to generate a dataset.
## Implementation Phases
- [ ] 1. Simple chunk paraphrasing logic that just works (see the sketch below)
- After splitting, feed the split chunks into an LLM to paraphrase them
- Rebuild the FAISS index with the paraphrased chunks
- Don't touch `questions.json`
- [ ] 2. Enhance the dataset quality with an API (check backlog)
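
A minimal sketch of phase 1, assuming the existing `saved_data/chunks.pkl` and `CustomHuggingFaceEmbeddings` from this repo; the `paraphrase` helper is a hypothetical placeholder for whichever LLM ends up doing the rewriting:

```python
# Hedged sketch of phase 1: paraphrase the split chunks and rebuild the FAISS index.
# `paraphrase` is a hypothetical helper; questions.json is left untouched.
import pickle

from langchain_community.vectorstores import FAISS

from embeddings import CustomHuggingFaceEmbeddings


def paraphrase(text: str) -> str:
    """Placeholder: call an LLM to rewrite `text` while keeping its meaning."""
    raise NotImplementedError


with open("saved_data/chunks.pkl", "rb") as f:
    chunks = pickle.load(f)

for chunk in chunks:
    chunk.page_content = paraphrase(chunk.page_content)

embeddings = CustomHuggingFaceEmbeddings()
vectorstore = FAISS.from_documents(chunks, embeddings)
vectorstore.save_local("faiss_index")  # overwrite, or point the trainer at a new path
```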
## Initial idea from @tikikun
- Take a dataset and break it into chunks.
- Use the **ground truth chunks** (the original, correct ones).
- Use an AI model to **paraphrase** those chunks—rewrite them in a different way while keeping the same meaning.
- During training, give the model these **paraphrased chunks** and ask it to **search for the original chunk**.
- If the model finds the correct original chunk, it gets a **reward**.
- This way, the model learns to **retrieve the most accurate chunk** even when given noisy or reworded input.
### Why Does This Work?
- **Paraphrasing adds noise**, making the training more realistic.
- The model learns to **recover the true information** from different ways of saying it.
- It ensures the model **only stops searching when it finds the exact right answer**.
- This makes retrieval stronger because it trains the model to handle **real-world variations in wording**.
### Derived from flow matching
- Flow Matching is a generative modeling technique that trains models to transform simple distributions **(like noise)** into complex data distributions by learning continuous transformations, or "flows"
- Paraphrase as Noise Introduction: By paraphrasing original data chunks, we introduce controlled noise, creating variations that maintain the original meaning but differ in wording.
- Model Training with Paraphrased Data: The model is trained to map these paraphrased (noisy) chunks back to their original form, learning to navigate from noise to truth.
## Dataset Format
- Should start from AutoDidact `generate_dataset.py`
- Outputs 3 things:
- Document chunks (`chunks.pkl`)
- `questions.json`
- FAISS index (`faiss_index`)
- `questions.json` entry example:
```json
{
    "chunk_id": "chunk_1",
    "question": "What is the capital of France?",
    "answer": "Paris",
    "difficulty": "easy"
}
```
- Original Flow: load markdown -> split into chunks -> generate embeddings -> build FAISS index -> generate questions
```mermaid
graph TD
%% === Document Processing and Embedding Generation ===
A1[Load Markdown Document] -->|mission_report.md| A2[Split Document into Chunks]
A2 -->|Save as Pickle| A3[💾 Chunks saved_data/chunks.pkl]
A3 -->|Load Chunks| B1[Generate Embeddings]
B1 -->|Save Embeddings| B2[Build FAISS Vector Store]
B2 -->|Save FAISS Index| B3[💾 FAISS Index faiss_index]
%% === QA Pair Generation ===
C1[Load Llama Model] -->|meta-Llama-3.1-8B-Instruct| C2[Configure Sampling Params]
A3 -->|Load Chunks| D1[Prepare Sliding Window Prompts]
D1 -->|Batch Generation| D2[Generate QA Pairs]
D2 -->|Parse & Validate QA Pairs| D3[Filter Valid Pairs]
D3 -->|Save Questions| D4[💾 QA Pairs saved_data/questions.json]
%% === Error Handling ===
D2 -->|Retry Failed Prompts| D5[Retry Batch Generation]
D5 -->|Parse & Validate QA Pairs| D3
%% Dependencies
A1 -.->|Required| B1
A3 -.->|Required| D1
C1 -.->|Required| D2
C2 -.->|Required| D2
B3 -.->|Required| D2
```
## Get a sense of how to prepare the dataset for GRPO
- <https://docs.unsloth.ai/basics/reasoning-grpo-and-rl/tutorial-train-your-own-reasoning-model-with-grpo#data-preparation>
- > Your dataset should still have at least **2 columns for question and answer pairs**. However the **answer must not reveal the reasoning behind** how it derived the answer from the question. See below for an example:
- Cool basic stuff <https://docs.unsloth.ai/basics/datasets-101> (a minimal two-column sketch follows)
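
A minimal sketch of the two-column question/answer shape Unsloth describes, using the `datasets` library from `requirements.txt` (the column names are an assumption; match whatever the training notebook expects):

```python
# Hedged sketch: a GRPO dataset needs at least a question column and an answer column,
# and the answer must not reveal the reasoning. Column names here are assumptions.
from datasets import Dataset

rows = [
    {"question": "What is the capital of France?", "answer": "Paris"},
    {"question": "2 + 2", "answer": "4"},
]
grpo_dataset = Dataset.from_list(rows)
print(grpo_dataset)
```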

@ -0,0 +1,52 @@
# Evaluation
- **Goal**:
- 1. Better performance than the original one (by auto eval script)
- 2. Better performance by real human eval/preference
## Implementation Phases
- [x] 1. Just take the eval function from the original repo (it simply uses accuracy, i.e. exact match) and do a quick glance at the output quality (see the sketch below).
- [ ] 2. Find more common and conventional datasets and benchmarks (still an auto script)
- [ ] 3. Set up human eval
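
A minimal sketch of a plain exact-match accuracy metric for phase 1 (note that the repo's `run_eval` in `rl_helpers.py` actually grades with an LLM verifier, so treat this as the simplest possible baseline):

```python
# Hedged sketch: plain exact-match accuracy. The repo's run_eval grades with an LLM
# verifier instead, so this is only the simplest baseline metric.
from typing import List


def exact_match_accuracy(predictions: List[str], answers: List[str]) -> float:
    if not predictions or len(predictions) != len(answers):
        raise ValueError("predictions and answers must be non-empty and the same length")
    correct = sum(
        prediction.strip().lower() == answer.strip().lower()
        for prediction, answer in zip(predictions, answers)
    )
    return correct / len(answers)


print(exact_match_accuracy(["Paris", "4"], ["paris", "5"]))  # 0.5
```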
## Baseline
- Info from AutoDidact:
- After just 100 steps of GRPO training (1 hour on a single RTX 4090 GPU), Llama-8B significantly improved its ability to research and answer questions from the Apollo 13 mission report
- On a validation set of 68 questions, accuracy more than doubled from 23% to 59%.
- Training log: I don't know why, but the result I got from actually running the training is a bit lower.
```bash
ceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
Processed prompts: 100%|████████████████| 16/16 [00:00<00:00, 39.27it/s, est. speed input: 6827.13 toks/s, output: 81.01 toks/s]
rewards_per_func: tensor([0.6875, 0.7000], device='cuda:0'):05, 2.55it/s, est. speed input: 385.80 toks/s, output: 5.11 toks/s]
{'loss': 0.0003, 'grad_norm': 0.5810762047767639, 'learning_rate': 0.0, 'rewards/reward_correctness': 0.6875, 'rewards/reward_formatting': 0.699999988079071, 'reward': 1.3875000476837158, 'reward_std': 0.44403791427612305, 'completion_length': 224.125, 'kl': 0.00834659393876791, 'epoch': 0.34}
{'train_runtime': 7992.2854, 'train_samples_per_second': 0.202, 'train_steps_per_second': 0.013, 'train_loss': 0.0005197484556535774, 'epoch': 0.34}
100%|███████████████████████████████████████████████████████████████████████████████████████| 101/101 [2:13:12<00:00, 79.13s/it]
Processed prompts: 100%|████████████████| 67/67 [00:20<00:00, 3.28it/s, est. speed input: 950.44 toks/s, output: 394.51 toks/s]
Processed prompts: 100%|███████████████| 66/66 [00:20<00:00, 3.15it/s, est. speed input: 2383.55 toks/s, output: 323.82 toks/s]
Processed prompts: 100%|███████████████| 20/20 [00:17<00:00, 1.13it/s, est. speed input: 1320.49 toks/s, output: 146.76 toks/s]
Processed prompts: 100%|████████████████| 17/17 [00:16<00:00, 1.04it/s, est. speed input: 1620.28 toks/s, output: 98.35 toks/s]
Processed prompts: 100%|██████████████████| 9/9 [00:15<00:00, 1.73s/it, est. speed input: 1165.77 toks/s, output: 71.38 toks/s]
Processed prompts: 100%|████████████████| 67/67 [00:04<00:00, 16.31it/s, est. speed input: 3617.28 toks/s, output: 61.11 toks/s]
RESULTS:
percentage of correct answers: 0.5074626865671642
==============================
Processed prompts: 100%|███████████████| 67/67 [00:15<00:00, 4.46it/s, est. speed input: 1292.29 toks/s, output: 561.32 toks/s]
Processed prompts: 100%|███████████████| 44/44 [00:18<00:00, 2.44it/s, est. speed input: 1800.84 toks/s, output: 244.13 toks/s]
Processed prompts: 100%|███████████████| 13/13 [00:12<00:00, 1.05it/s, est. speed input: 1209.04 toks/s, output: 126.32 toks/s]
Processed prompts: 100%|███████████████| 10/10 [00:13<00:00, 1.32s/it, est. speed input: 1225.46 toks/s, output: 109.78 toks/s]
Processed prompts: 100%|██████████████████| 7/7 [00:12<00:00, 1.86s/it, est. speed input: 1149.18 toks/s, output: 76.05 toks/s]
Processed prompts: 100%|████████████████| 67/67 [00:02<00:00, 31.53it/s, est. speed input: 6047.70 toks/s, output: 83.31 toks/s]
RESULTS:
percentage of correct answers: 0.19402985074626866
==============================
[rank0]:[W320 07:13:50.651270455 ProcessGroupNCCL.cpp:1250] Warning: WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL operations have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present, but this warning has only been added since PyTorch 2.4 (function operator())
```
## Getting some sense of the eval data or benchmark
- > For example, benchmarks like ARC-AGI, which involve visual reasoning, remain challenging for these models, even though they might seem straightforward to a human. (ichigo)

@ -0,0 +1,15 @@
# GRPO idea
- The training flow of R1 is really simple (thanks to my friend, professional yapper @vTuanpham, for initially clarifying my dumbness 🤣). A code sketch of the tiered scoring follows the quote.
```text
1. Train a base model that knows how to use tools with pure SFT as a boost
Tuan
2. Then let it run free with GRPO: roughly correct syntax gets 0.5, correct syntax but the params are too far off gets 0.65, both correct gets 0.85, ...
```
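
A hedged sketch of that tiered scoring, with a hypothetical JSON tool-call parser standing in for whatever syntax check the trainer actually uses:

```python
# Hedged sketch of the tiered reward from the quote above; the parser and the
# expected-call comparison are hypothetical stand-ins, not the actual implementation.
import json
from typing import Optional


def parse_tool_call(completion: str) -> Optional[dict]:
    """Hypothetical parser: treat the completion as a JSON tool call, None on failure."""
    try:
        call = json.loads(completion)
        return call if isinstance(call, dict) else None
    except json.JSONDecodeError:
        return None


def tiered_tool_reward(completion: str, expected_call: dict) -> float:
    call = parse_tool_call(completion)
    if call is None:
        return 0.0   # no parsable call at all
    if call.get("name") != expected_call["name"]:
        return 0.5   # roughly valid syntax, but not the expected tool
    if call.get("parameters") != expected_call["parameters"]:
        return 0.65  # right tool, parameters too far off
    return 0.85      # syntax and parameters both correct
```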
## Unsloth's guide
- <https://unsloth.ai/blog/r1-reasoning>
- Heheboi let's steal this notebook <https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-GRPO.ipynb>
- <https://docs.unsloth.ai/basics/reasoning-grpo-and-rl> - This is probably the simplest guide

@ -0,0 +1,10 @@
```mermaid
graph TD
A[User Query] -->|Random Search Engine Assigned| B{Synthetic Search Engine}
B -->|Retrieves Initial Results| C[Model Analyzes Results]
C -->|Refines Query if Needed| D[Iterative Search Process]
D -->|Final Answer Found| E[Return Best Match]
E -->|Rewards/Penalties Applied| F[Reinforcement Learning Update]
F -->|Optimized Search Strategy| B
```

@ -0,0 +1,282 @@
# Reward functions
This note is a collection of stolen reward functions and tips from other projects.
- NEED SOMETHING THAT MAKES THE MODEL WORK HARDER!!!
- [x] Goal: design reward functions (Search Task!) for DeepSearch's GRPO training (likely to be exact match) (**Try the suggestion by unsloth below, lol**)
- > You can refer to the examples below. You can input your generations into an LLM like ChatGPT 4o or Llama 3.1 (8B) and design a reward function and verifier to evaluate it. **For example, feed your generations into a LLM of your choice and set a rule: "If the answer sounds too robotic, deduct 3 points." This helps refine outputs based on quality criteria**
- Label Studio suggests consulting domain experts -> ask the LLM to act as a search engine expert??
- Starting from the AutoDidact defaults should be good enough; figure out big-brain moves from there
## Implementation Phases
- [ ] 1. Just keep the default ones from AutoDidact and add the Exact Match idea
- Oh, they only use 2 reward functions: `reward_correctness` and `reward_formatting`
- [ ] 2. Add more if needed.
## Pseudo code
```python
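# Hedged pseudo-code sketch, not the actual implementation: keep AutoDidact's two
# reward functions (reward_correctness, reward_formatting) and add an exact-match bonus.
# The completion/answer structures mirror the ones used in rl_helpers.py.
def reward_exact_match(prompts, completions, **reward_kwargs):
    answers = reward_kwargs["answer"]
    finals = [completion["messages"][-1]["content"] for completion in completions]
    return [
        1.0 if final.strip().lower() == answer.strip().lower() else 0.0
        for final, answer in zip(finals, answers)
    ]


# reward_funcs = [reward_correctness, reward_formatting, reward_exact_match]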
```
## Get a sense of Reward functions
- <https://github.com/kubernetes-bad/reward-composer>
- Reward Composer is a collection of simple building blocks for making your perfect reward function for Reinforcement Learning training of language models... It's like Lego for GRPO.
- <https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb>
- Really minimalist and simple GRPO training script (only 171 lines :O)
- Example from unsloth's blog <https://docs.unsloth.ai/basics/reasoning-grpo-and-rl#reward-function-examples>
- > You can **reuse data** across multiple epochs. - What does this mean 👀?
- From <https://labelstud.io/blog/reinforcement-learning-from-verifiable-rewards/#how-to-design-a-verifiable-reward-function>
- Factual Accuracy: Checking whether the output contains verifiable facts.
- Logical Consistency: Ensuring that arguments or narratives are internally consistent, e.g. when solving propositional logic reasoning problems
- Exact Match and Heuristics: Use deterministic rules to check correctness (e.g., exact match in math answers, passing test cases in code, **matching the predefined categories or taxonomy** etc.)
- > Designing a verifiable reward function **requires expert knowledge, domain expertise**, and structured data interfaces - Can I just LLM Roleplaying search engine expert? 👀
- Multi-Level Scoring: Implement tiered scoring mechanisms to reward partial correctness where applicable. (cool, might try this)
- > 3. Validate the Reward Model Based on Generated Examples
  > Run Controlled Tests: Generate model outputs and measure how well the reward function distinguishes correct from incorrect responses.
  > Evaluate for Robustness: Ensure the function avoids penalizing correct responses due to formatting issues or minor variations.
  > A/B Testing with RL Agents: Compare performance between models trained with and without the verifiable reward function.
## Reward Function vs Verifier
Stolen note from unsloth's docs (a tiny code sketch follows the table):
| Component | Purpose | Characteristics | Examples |
|-----------|---------|-----------------|----------|
| **Verifier** | Determines correctness | - No numerical scoring<br>- Binary correct/incorrect judgment | - Checks if "2+2=5" is wrong<br>- Executes code to validate syntax/logic |
| **Reward Function** | Assigns numerical scores | - Converts verification to numbers<br>- Can include multiple criteria | - Wrong answer: -1 or -2<br>- Correct answer: +1 or +2<br>- Penalties for length/readability |
| **Key Differences** | | - Verifier: checks correctness without scoring<br>- Reward Function: assigns scores without necessarily verifying<br>- Reward Function can use a Verifier, but they're distinct components | |
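
A tiny illustrative sketch of the split described in the table (an arithmetic example, not code from either project): the verifier only judges correctness, and the reward function turns that judgment into a score.

```python
# Illustrative only: a verifier returns a correct/incorrect judgment,
# while a reward function converts that judgment into a numerical score.
def verify_answer(predicted: str, expected: str) -> bool:
    """Verifier: binary judgment, no scoring."""
    return predicted.strip() == expected.strip()


def correctness_reward(predicted: str, expected: str) -> float:
    """Reward function: uses the verifier and assigns a score."""
    return 2.0 if verify_answer(predicted, expected) else -1.0


print(correctness_reward("4", "4"))  # 2.0
print(correctness_reward("5", "4"))  # -1.0
```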
## Idea examples
Note taken from unsloth's docs.
Example #1: Simple Arithmetic Task
- Question: "2 + 2"
- Answer: "4"
- Reward Function 1:
- If a number is detected → +1
- If no number is detected → -1
Example #2: Email Automation Task (a sketch follows this list)
- Question: Inbound email
- Answer: Outbound email
- Reward Functions:
- If the answer contains a required keyword → +1
- If the answer exactly matches the ideal response → +1
- If the response is too long → -1
- If the recipient's name is included → +1
- If a signature block (phone, email, address) is present → +1
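
A hedged sketch of Example #2's reward list; the keyword, recipient, signature markers, and the length threshold are all assumptions for illustration, not values from Unsloth's docs.

```python
# Illustrative sketch of the email-automation rewards listed above; every threshold
# and marker here is an assumption.
def email_reward(response: str, ideal_response: str, keyword: str, recipient: str) -> float:
    score = 0.0
    if keyword.lower() in response.lower():
        score += 1.0  # required keyword present
    if response.strip() == ideal_response.strip():
        score += 1.0  # exact match with the ideal response
    if len(response.split()) > 300:
        score -= 1.0  # too long (300 words is an assumed threshold)
    if recipient.lower() in response.lower():
        score += 1.0  # recipient's name included
    if any(marker in response.lower() for marker in ("phone:", "email:", "address:")):
        score += 1.0  # crude signature-block check
    return score
```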
## Code Examples
- Below is a code snippet from @unslothai sample notebook, which is taken from @willccbb's gist
```python
# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
responses = [completion[0]["content"] for completion in completions]
q = prompts[0][-1]["content"]
extracted_responses = [extract_xml_answer(r) for r in responses]
print(
"-" * 20,
f"Question:\n{q}",
f"\nAnswer:\n{answer[0]}",
f"\nResponse:\n{responses[0]}",
f"\nExtracted:\n{extracted_responses[0]}",
)
return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
def int_reward_func(completions, **kwargs) -> list[float]:
responses = [completion[0]["content"] for completion in completions]
extracted_responses = [extract_xml_answer(r) for r in responses]
return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]
def strict_format_reward_func(completions, **kwargs) -> list[float]:
"""Reward function that checks if the completion has a specific format."""
pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
responses = [completion[0]["content"] for completion in completions]
matches = [re.match(pattern, r) for r in responses]
return [0.5 if match else 0.0 for match in matches]
def soft_format_reward_func(completions, **kwargs) -> list[float]:
"""Reward function that checks if the completion has a specific format."""
pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
responses = [completion[0]["content"] for completion in completions]
matches = [re.match(pattern, r) for r in responses]
return [0.5 if match else 0.0 for match in matches]
def count_xml(text) -> float:
count = 0.0
if text.count("<reasoning>\n") == 1:
count += 0.125
if text.count("\n</reasoning>\n") == 1:
count += 0.125
if text.count("\n<answer>\n") == 1:
count += 0.125
count -= len(text.split("\n</answer>\n")[-1]) * 0.001
if text.count("\n</answer>") == 1:
count += 0.125
count -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001
return count
def xmlcount_reward_func(completions, **kwargs) -> list[float]:
contents = [completion[0]["content"] for completion in completions]
return [count_xml(c) for c in contents]
...
trainer = GRPOTrainer(
model=model,
processing_class=tokenizer,
reward_funcs=[ # Personal note: didn't expect this be so simple to implement @@
xmlcount_reward_func,
soft_format_reward_func,
strict_format_reward_func,
int_reward_func,
correctness_reward_func,
],
args=training_args,
train_dataset=dataset,
)
trainer.train()
```
- [x] Just curious, how did the team implement the reward functions for [Alphamaze](https://github.com/menloresearch/visual-thinker)?
- Below is from Alphamaze's repo
- > We designed a reward function 3 components. Correctness Reward (+0.2 per solution step): This reward is scaled according to the number of steps in the maze solution. Each valid movement step adds 0.2 points to the total score. For example, a solution requiring 4 steps earns a reward of 0.2 × 4 = 0.8 points, incentivizing both accuracy and efficiency in navigation. Integrity Reward (+0.5): This reward is given for each valid movement token (<|up|>, <|down|>, <|left|>, <|right|>) in the predicted sequence, encouraging the generation of meaningful and valid movement steps.
- > Thinking Reward (+0.25): This reward is given for correctly using the <think> tag in the output, ensuring completeness and consistency in the reasoning format. These reward components were weighted to prioritize correctness while also encouraging valid movement sequences and proper reasoning formatting with <think> tag. We adapted the Group Relative Policy Optimization (GRPO) algorithm, as employed in DeepSeek-R1 [Guo et al., 2025], to perform reinforcement learning. GRPO estimates advantages based on relative group scores, offering computational efficiency compared to critic-based methods.
```python
def xmlcount_reward_func(completions, **kwargs) -> List[float]:
"""
Reward function based on proper XML tag usage.
Args:
completions: Model completions
Returns:
List of reward scores
"""
contents = [completion[0]["content"] for completion in completions]
return [count_xml(c) for c in contents]
def int_reward_func(completions, **kwargs) -> List[float]:
"""
Reward function that checks if responses contain valid direction tokens.
Args:
completions: Model completions
Returns:
List of reward scores
"""
allowed_tokens = {"<|up|>", "<|down|>", "<|right|>", "<|left|>"}
responses = [completion[0]['content'] for completion in completions]
extracted_responses = [extract_xml_answer(r) for r in responses]
def correctness_reward_func(prompts, completions, answer, **kwargs) -> List[float]:
"""
Reward function that checks correctness of answers.
Args:
prompts: Input prompts
completions: Model completions
answer: Ground truth answers
Returns:
List of reward scores
"""
rewards = []
responses = [completion[0]['content'] for completion in completions]
q = prompts[0][-1]['content']
extracted_responses = [extract_xml_answer(r) for r in responses]
logger.debug('-'*20)
logger.debug(f"Question:\n{q}")
logger.debug(f"\nAnswer:\n{answer[0]}")
logger.debug(f"\nResponse:\n{responses[0]}")
logger.debug(f"\nExtracted:\n{extracted_responses[0]}")
for r, a in zip(extracted_responses, answer):
if r == a:
direction = r.split("|><|")
rewards.append(len(direction)*0.2)
else:
rewards.append(0.0)
return rewards
# def strict_format_reward_func(completions, **kwargs) -> List[float]:
# """
# Reward function that checks if completions strictly follow the required format.
# Args:
# completions: Model completions
# Returns:
# List of reward scores
# """
# pattern = r"^<think>\n.*?\n</think>\n\n.*?\n$"
# responses = [completion[0]["content"] for completion in completions]
# matches = [re.match(pattern, r, re.DOTALL) for r in responses]
# return [0.5 if match else 0.0 for match in matches]
# def soft_format_reward_func(completions, **kwargs) -> List[float]:
# """
# Reward function that checks if completions loosely follow the required format.
# Args:
# completions: Model completions
# Returns:
# List of reward scores
# """
# pattern = r"<think>.*?</think>\s*.*?"
# responses = [completion[0]["content"] for completion in completions]
# matches = [re.match(pattern, r, re.DOTALL) for r in responses]
# return [0.5 if match else 0.0 for match in matches]
...
reward_funcs=[
xmlcount_reward_func,
# soft_format_reward_func,
# strict_format_reward_func,
int_reward_func,
correctness_reward_func,
],
```
## Comparison of Alphamaze's reward functions and unsloth's
| Feature | Unsloth Example | AlphaMaze | Similarities | Differences |
| :-------------------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :------------------------------------------------------------------------------------------------------------- | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| **Overall Purpose** | To evaluate and score the quality of model-generated text based on various criteria (format, correctness, content). | Same as Unsloth. | Both aim to provide numerical rewards for model outputs based on defined criteria. | AlphaMaze appears more focused on a specific maze-solving task (directions in the answer), while Unsloth's examples are more general, including evaluating whether a number prediction can be cast to integer . |
| **Function Structure** | Functions generally take `completions` (and sometimes `prompts`, `answer`) as input. Return a list of floats (rewards). | Same as Unsloth. | Both use functions that take model outputs (and sometimes inputs) and return lists of reward scores. | AlphaMaze's `correctness_reward_func` calculates a reward based on the *length* of the correct answer (number of directions), while Unsloth's gives a fixed reward (2.0) for a correct answer, and 0 otherwise. |
| **Reward Types** | - `correctness_reward_func`: Checks if the extracted answer matches the ground truth. Binary reward (2.0 or 0.0).<br> - `int_reward_func`: Checks if extracted answer is a digit. Binary reward (0.5 or 0.0).<br> - `strict_format_reward_func`, `soft_format_reward_func`: Check for specific XML-like formatting using regular expressions. Binary reward (0.5 or 0.0).<br> - `xmlcount_reward_func`: Counts XML tags, providing a fractional reward based on tag presence and penalizing trailing text. | - `correctness_reward_func`: Checks if extracted answer matches ground truth. Reward is proportional to answer length (0.2 per direction).<br> - `int_reward_func`: Checks if the answer consists of allowed tokens. The implementation in this code is not complete. <br> - `xmlcount_reward_func`: Same as Unsloth's.<br> - `strict_format_reward_func` (commented out): Checks for a specific format using regex.<br> - `soft_format_reward_func` (commented out): Checks for a looser format. | - Both have `correctness_reward_func`, `int_reward_func`, `xmlcount_reward_func` (though implemented slightly differently).<br>- Both use regular expressions for format checking. | - Unsloth uses a simpler binary reward for correctness. AlphaMaze uses a length-based reward.<br>- Unsloth's `int_reward_func` check if castable to integer, AlphaMaze's intends to check for allowed direction tokens (but the implementation is not finished).<br>- AlphaMaze's formatting functions are commented out. |
| **`correctness_reward_func`** | Compares extracted answer to ground truth. Prints debugging information. Returns 2.0 for correct, 0.0 otherwise. | Compares extracted answer to ground truth, calculates reward based on the *length* of the correct answer (number of direction steps, 0.2 per step). Logs debugging information. | Both compare the extracted answer to the ground truth answer and print/log debugging information. | - Unsloth returns a fixed reward (2.0) for a correct answer.<br> - AlphaMaze's reward is proportional to the length of the correct answer (0.2 per direction). |
| **`int_reward_func`** | Checks if the extracted response `isdigit()`. Returns 0.5 if true, 0.0 otherwise. | Intended to check if the response contains allowed direction tokens (`<|up|>`,`<|down|>`, etc.). The provided code *does not* actually implement this check. The lines where the response is processes are incomplete and non-functional. | Both are intended to evaluate specific characteristics of the extracted response. | - Unsloth's checks for digits.<br>- AlphaMaze's *intended* functionality is to check for specific tokens, but the code, as shown, does not implement this, and the reward return is not defined. |
| **`xmlcount_reward_func`** | Same implementation in both. Counts opening/closing tags, penalizes extra text. | Same implementation in both. | Identical implementation. | None. |
| **Format Checking** | Uses `strict_format_reward_func` and `soft_format_reward_func` with different regular expressions. | Has `strict_format_reward_func` and `soft_format_reward_func` (commented out) with different regular expressions. | Both use regular expressions to check for specific formatting patterns. | - Unsloth's format checks look for `<reasoning>` and `<answer>` tags.<br>- AlphaMaze's (commented out) checks look for `<think>` tags and a general structure.<br>- Unsloth's are active; AlphaMaze's are commented out. |
| **Extracted Answer** | Both use an `extract_xml_answer` function (not shown in the provided snippets, but assumed to be defined elsewhere). | Same as Unsloth. | Both rely on an external function to extract the relevant part of the response for evaluation. | We don't know the exact implementation of `extract_xml_answer`, so there might be subtle differences. However, the *use* is the same. |

@ -0,0 +1,8 @@
# Search backends
- Purpose: add more noise to the training process (we already did this in the initial dataset)
- Different search strategies? Semantic search, keyword search, BM25, an actual API call (see the BM25 sketch below)
- Embedding models, Retrieval mechanisms (BM25, dense, hybrid), Query expansion methods, Reranking strategies
- Random search engine assignment per query
- Noise and inconsistency injection to prevent shortcut learning
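
A minimal sketch of one alternative backend: keyword/BM25 search over the saved chunks, using the `rank_bm25` package (an assumed extra dependency, not in `requirements.txt`).

```python
# Hedged sketch of a BM25 keyword-search backend over the saved chunks.
# rank_bm25 is an assumed extra dependency; chunks.pkl comes from generate_dataset.py.
import pickle
from typing import List

from rank_bm25 import BM25Okapi

with open("saved_data/chunks.pkl", "rb") as f:
    chunks = pickle.load(f)

corpus = [chunk.page_content for chunk in chunks]
bm25 = BM25Okapi([doc.lower().split() for doc in corpus])


def bm25_search(query: str, k: int = 2) -> List[str]:
    """Return the top-k chunk texts for a keyword query."""
    return bm25.get_top_n(query.lower().split(), corpus, n=k)


print(bm25_search("command module pilot substitution"))
```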

@ -0,0 +1,5 @@
# Self Verification
- [x] Investigate this term: the word is mentioned in AutoDidact's About section and also in the DeepSeek R1 paper (not in much detail), but not in blogs or the code base. I think this term is important and should be investigated
- Lol, a "Verifier" is just a synonym for **reward function**
- <https://docs.unsloth.ai/basics/reasoning-grpo-and-rl#reward-functions-verifier>

@ -0,0 +1,92 @@
from typing import List, Union
import torch
import torch.nn.functional as F
from langchain.embeddings.base import Embeddings
from transformers import AutoModel, AutoTokenizer
# Set a default model here
DEFAULT_MODEL_NAME = "avsolatorio/NoInstruct-small-Embedding-v0"
class CustomHuggingFaceEmbeddings(Embeddings):
"""
A custom embeddings class that wraps a Hugging Face model for generating embeddings.
Supports two modes:
- "sentence": uses the [CLS] token representation for sentence/document embeddings.
- "query": uses mean pooling over tokens (weighted by the attention mask) for query embeddings.
"""
def __init__(
self, model_name: str = DEFAULT_MODEL_NAME, default_mode: str = "sentence"
):
self.model_name = model_name
# Set device to GPU if available, else CPU
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.model = AutoModel.from_pretrained(model_name).to(self.device)
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.default_mode = default_mode # "sentence" or "query"
self.model.eval() # Set model to evaluation mode
def get_embedding(self, text: Union[str, List[str]], mode: str = None):
if mode is None:
mode = self.default_mode
assert mode in (
"query",
"sentence",
), f"Unsupported mode: {mode}. Only 'query' and 'sentence' are supported."
# Ensure we are working with a list of texts
if isinstance(text, str):
text = [text]
# Tokenize the input texts
inp = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
# Move the input tensors to the same device as the model
inp = {key: value.to(self.device) for key, value in inp.items()}
# Forward pass (no gradients needed)
with torch.no_grad():
output = self.model(**inp)
if mode == "query":
# Mean pooling: weight by attention mask and average across tokens
vectors = output.last_hidden_state * inp["attention_mask"].unsqueeze(2)
vectors = vectors.sum(dim=1) / inp["attention_mask"].sum(dim=-1).view(-1, 1)
else:
# Sentence/document embedding: use the [CLS] token (first token) representation
vectors = output.last_hidden_state[:, 0, :]
return vectors
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""
Compute embeddings for a list of documents (using sentence mode).
"""
vectors = self.get_embedding(texts, mode="sentence")
return vectors.cpu().numpy().tolist()
def embed_query(self, text: str) -> List[float]:
"""
Compute an embedding for a single query.
"""
vector = self.get_embedding(text, mode="query")
return vector.cpu().numpy()[0].tolist()
# For quick testing
if __name__ == "__main__":
embeddings = CustomHuggingFaceEmbeddings()
# Example texts for document embeddings
texts = [
"Illustration of the REaLTabFormer model. The left block shows the non-relational tabular data model using GPT-2.",
"Predicting human mobility holds significant practical value, with applications in disaster planning and epidemic simulation.",
"As economies adopt digital technologies, policy makers are asking how to prepare the workforce for emerging labor demands.",
]
doc_embeddings = embeddings.embed_documents(texts)
print("Document embeddings:", doc_embeddings)
# Example query embedding
query_embedding = embeddings.embed_query("Which sentence talks about jobs?")
print("Query embedding:", query_embedding)

@ -0,0 +1,230 @@
"""
This script performs two main tasks:
1. It loads a markdown document, splits it into chunks, generates embeddings,
and builds a FAISS index (which is saved locally).
2. It generates QA pairs from the document using llama.
For each chunk (using a sliding window for context), it generates multiple question-answer pairs
with different difficulties. The generation is performed in batch with one retry for failed prompts.
Successfully generated QA pairs are saved to "saved_data/questions.json".
Requirements:
pip install langchain faiss-cpu unsloth vllm
"""
import json
import os
import pickle
import re
from typing import Dict, List, Optional, Tuple
from langchain.text_splitter import RecursiveCharacterTextSplitter
# ========= Part 1: Document Processing and Embedding Generation =========
# Load and split the markdown document using LangChain
from langchain_community.document_loaders import UnstructuredMarkdownLoader
from langchain_community.vectorstores import FAISS
from embeddings import CustomHuggingFaceEmbeddings
# Load your markdown file (adjust the path as needed)
loader = UnstructuredMarkdownLoader("./data/mission_report.md")
docs = loader.load()
# Split the document into smaller chunks (each 1000 characters, no overlap)
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
chunks = text_splitter.split_documents(docs)
# Save chunks for later use
os.makedirs("saved_data", exist_ok=True)
with open("saved_data/chunks.pkl", "wb") as f:
pickle.dump(chunks, f)
print(f"Saved {len(chunks)} chunks to saved_data/chunks.pkl")
embeddings = CustomHuggingFaceEmbeddings()
# Create a FAISS vector store from the document chunks and save it locally
vectorstore = FAISS.from_documents(chunks, embeddings)
vectorstore.save_local("faiss_index")
print("Saved FAISS index to 'faiss_index'")
# ========= Part 2: QA Generation using Llama Backend =========
# Setup Llama backend via unsloth and vLLM
from unsloth import FastLanguageModel
from vllm import SamplingParams
import rl_helpers # Ensure you have this or remove if not used
# Load the Llama model (adjust parameters as needed)
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="meta-llama/meta-Llama-3.1-8B-Instruct",
max_seq_length=4096,
load_in_4bit=True, # Use 4-bit quantization if desired
fast_inference=True, # Enable fast inference
gpu_memory_utilization=0.6, # Adjust based on your GPU memory
)
# Define sampling parameters for generation
sampling_params = SamplingParams(
temperature=0.3,
top_p=0.95,
max_tokens=4096,
)
def batch_generate(prompts: List[str]) -> List[str]:
"""
Given a list of prompt strings, returns a list of generated outputs.
"""
def format_input(text: str) -> str:
return tokenizer.apply_chat_template(
[{"role": "user", "content": text}],
tokenize=False,
add_generation_prompt=True,
)
formatted = [format_input(p) for p in prompts]
outputs = model.fast_generate(formatted, sampling_params=sampling_params)
return [output.outputs[0].text for output in outputs]
def parse_qa_block(block: str) -> Optional[Tuple[str, str, str]]:
"""
Parses a QA block that should contain exactly three non-empty lines:
- A line starting with "Question:"
- A line starting with "Answer:"
- A line starting with "Difficulty:"
If the markers are not present but the block contains exactly three lines,
those are used in order.
Returns a tuple (question, answer, difficulty) or None if parsing fails.
"""
lines = [line.strip() for line in block.splitlines() if line.strip()]
if not lines:
return None
question, answer, difficulty = None, None, None
for line in lines:
lower = line.lower()
if question is None and lower.startswith("question:"):
question = line[len("question:") :].strip()
elif answer is None and lower.startswith("answer:"):
answer = line[len("answer:") :].strip()
elif difficulty is None and lower.startswith("difficulty:"):
difficulty = line[len("difficulty:") :].strip()
if question and answer and difficulty:
return question, answer, difficulty
if len(lines) == 3:
return lines[0], lines[1], lines[2]
return None
def parse_multiple_qa_output(output: str) -> List[Tuple[str, str, str]]:
"""
Splits the output into blocks (separated by one or more blank lines) and
attempts to parse each as a QA pair.
Returns a list of successfully parsed QA tuples.
"""
blocks = re.split(r"\n\s*\n", output.strip())
qa_pairs = []
for block in blocks:
parsed = parse_qa_block(block)
if parsed:
qa_pairs.append(parsed)
return qa_pairs
def generate_question_batch_for_chunks(
chunks: List, num_questions: int = 2, difficulty: str = None
) -> List[Dict]:
"""
Generates QA pairs for multiple chunks in batch.
For each chunk (except the first and last), a sliding window is used for context:
- before: previous chunk's content
- current: current chunk's content
- after: next chunk's content
Each prompt instructs the model to output exactly three lines per QA pair with markers.
Failed prompts are retried once in batch; if still unsuccessful, they are skipped.
Returns a list of dicts with keys: "chunk_id", "question", "answer", "difficulty".
"""
prompts = []
chunk_ids = []
# Prepare prompts using a sliding window
for i in range(1, len(chunks) - 1):
before = chunks[i - 1].page_content
current = chunks[i].page_content
after = chunks[i + 1].page_content
prompt = (
f"From the text within ==BEGIN== and ==END==, generate {num_questions} questions with answers.\n"
"For each QA pair, output exactly three lines with no extra commentary:\n"
"Line 1: Question: <your question>\n"
"Line 2: Answer: <the answer>\n"
"Line 3: Difficulty: <easy, medium, or hard>\n"
"Do not include any additional text.\n\n"
"==BEGIN==\n"
f"{before}\n{current}\n{after}\n"
"==END==\n"
)
prompts.append(prompt)
chunk_ids.append(i)
# First batch generation
outputs = batch_generate(prompts)
results = [None] * len(outputs)
failed_indices = []
# Parse each output
for idx, output in enumerate(outputs):
qa_pairs = parse_multiple_qa_output(output)
if qa_pairs is None or len(qa_pairs) < num_questions:
failed_indices.append(idx)
else:
results[idx] = qa_pairs[:num_questions]
# Retry failed prompts in batch
if failed_indices:
print(f"Retrying {len(failed_indices)} failed prompt(s)...")
retry_prompts = [prompts[i] for i in failed_indices]
retry_outputs = batch_generate(retry_prompts)
for j, idx in enumerate(failed_indices):
qa_pairs = parse_multiple_qa_output(retry_outputs[j])
if qa_pairs is not None and len(qa_pairs) >= num_questions:
results[idx] = qa_pairs[:num_questions]
else:
results[idx] = None # Mark as failed
# Build final output, skipping prompts that failed even after retry
final_questions = []
for i, qa_list in enumerate(results):
if qa_list is not None:
for qa in qa_list:
final_questions.append(
{
"chunk_id": chunk_ids[i],
"question": qa[0],
"answer": qa[1],
"difficulty": qa[2],
}
)
return final_questions
# Generate QA pairs in batch (using a sliding window over the chunks)
all_questions = generate_question_batch_for_chunks(
chunks, num_questions=2, difficulty="medium"
)
print(f"Generated {len(all_questions)} QA pairs.")
# Save the QA pairs to a JSON file
questions_path = os.path.join("saved_data", "questions.json")
with open(questions_path, "w") as f:
json.dump(all_questions, f, indent=2)
print(f"Saved questions to {questions_path}")

@ -0,0 +1,2 @@
unsloth_compiled_cache
0_*

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

@ -0,0 +1,17 @@
datasets
faiss-cpu
langchain
langchain-community
Markdown
tokenizers
transformers
unsloth==2025.3.6
unsloth_zoo==2025.3.4
unstructured
vllm
wandb
ipykernel
python-dotenv
loguru
gradio

@ -0,0 +1,540 @@
"""
RL helpers module for handling tool-based conversations.
This module provides utility functions for handling chat-based tool interactions
and calculating rewards based on the quality of responses.
"""
import asyncio
import json
import re
from dataclasses import dataclass
from datetime import datetime
import nest_asyncio
import torch
from search_module import get_qa_dataset, search
nest_asyncio.apply()
from typing import Callable, List
from trl.trainer.grpo_trainer import apply_chat_template
# Constants for prompts and tool definitions
def get_system_prompt():
"""Get the system prompt with current date."""
current_date = datetime.now().strftime("%d %b %Y")
return f"""Cutting Knowledge Date: December 2023
Today Date: {current_date}
When you receive a tool call response, use the output to format an answer to the original user question.
You are a helpful assistant with tool calling capabilities.
"""
# Tool definition for search corpus
SEARCH_TOOL_DEFINITION = {
"type": "function",
"function": {
"name": "search_corpus",
"description": "Search over the knowledge corpus with a given query",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The query to search the knowledge corpus with",
},
},
"required": ["query"],
},
},
}
def build_user_prompt(q):
"""
Build a user prompt with the question and search tool definition.
Args:
q (str): The question to ask
Returns:
str: Formatted user prompt
"""
user_prompt = f"""You are a research assistant, and you use the search_corpus tool to find answers to questions.
Given a question, answer it using by doing searches using the search_corpus tool.
To use the search_corpus tool, respond with a JSON for a function call with its proper arguments.
You may also reason in any message, thinking step by step about how to answer the question. Wrap your reasoning in <reasoning> and </reasoning> tags.
{json.dumps(SEARCH_TOOL_DEFINITION, indent=2)}
Question: {q}
"""
return user_prompt
def get_initial_chat(question):
"""
Initialize a chat state with the question.
Args:
question (str): The question to ask
Returns:
dict: Initial chat state with system and user messages
"""
return {
"messages": [
{"role": "system", "content": get_system_prompt()},
{"role": "user", "content": build_user_prompt(question)},
]
}
def extract_json_objects(text):
"""
Extracts JSON objects (dictionaries) from a text that may contain multiple JSON objects.
Args:
text (str): The input text possibly containing JSON objects.
Returns:
list: A list of parsed JSON objects (dictionaries) extracted from the text.
"""
results = []
length = len(text)
i = 0
while i < length:
# Look for the start of a JSON object
if text[i] == "{":
start = i
stack = 1
i += 1
# Continue until we find the matching closing brace
while i < length and stack > 0:
if text[i] == "{":
stack += 1
elif text[i] == "}":
stack -= 1
i += 1
# Only attempt to decode if the braces are balanced
if stack == 0:
candidate = text[start:i]
try:
obj = json.loads(candidate)
# Optionally, ensure it's a dictionary if that's what you expect
if isinstance(obj, dict):
results.append(obj)
except json.JSONDecodeError:
# If it's not valid JSON, skip it.
pass
else:
i += 1
return results
def remove_reasoning(text: str) -> str:
"""
Removes all content between <reasoning> and </reasoning> tags,
including the tags themselves.
Parameters:
text (str): The input text that may contain <reasoning>...</reasoning> tags.
Returns:
str: The text with the tags and their content removed.
"""
# The regex pattern matches from <reasoning> to </reasoning> non-greedily.
pattern = r"<reasoning>.*?</reasoning>"
cleaned_text = re.sub(pattern, "", text, flags=re.DOTALL)
return cleaned_text
def run_agent_generations(generate_fn, tokenizer, chat_states):
"""
Run generation for chat states requiring assistant responses.
Args:
generate_fn: Function to generate responses
tokenizer: Tokenizer for processing text
chat_states: List of chat states
Returns:
list: Updated chat states
"""
prompts = []
batch_indices = []
# Prepare prompts for chat states needing an assistant response.
for idx, chat_state in enumerate(chat_states):
if chat_state.get("finished"):
continue
if chat_state["messages"][-1]["role"] in ["ipython", "user"]:
prompt = apply_chat_template(chat_state, tokenizer=tokenizer)["text"]
prompts.append(prompt)
batch_indices.append(idx)
if prompts:
responses = generate_fn(prompts)
for i, idx in enumerate(batch_indices):
chat_state = chat_states[idx]
full_response = responses[i].outputs[0].text
assistant_response = full_response.split(
"<|start_header_id|>assistant<|end_header_id|>"
)[-1]
chat_state["messages"].append(
{"role": "assistant", "content": assistant_response}
)
return chat_states
def check_finished_chats(chat_states):
"""
Check which chat states are finished (no more function calls).
Args:
chat_states: List of chat states
Returns:
list: Updated chat states with finished flag
"""
for chat_state in chat_states:
if chat_state.get("finished"):
continue
assert (
chat_state["messages"][-1]["role"] == "assistant"
), "Expected the last role to be assistant"
assistant_response = chat_state["messages"][-1]["content"]
function_calls = extract_json_objects(assistant_response)
if len(function_calls) == 0:
chat_state["finished"] = True
return chat_states
def run_tool_calls(chat_states):
"""
Execute tool calls found in chat states.
Args:
chat_states: List of chat states
Returns:
list: Updated chat states with tool call results
"""
for chat_state in chat_states:
if chat_state.get("finished"):
continue
assert (
chat_state["messages"][-1]["role"] == "assistant"
), "Expected the last role to be assistant to run tool calls"
try:
assistant_response = chat_state["messages"][-1]["content"]
function_calls = extract_json_objects(assistant_response)
if len(function_calls) > 1:
raise ValueError(
"Expected only one function call in assistant response"
)
elif len(function_calls) == 1:
function_call = function_calls[0]
query = function_call["function"]["parameters"]["query"]
results = search(query, return_type=str, results=2)
chat_state["messages"].append({"role": "ipython", "content": results})
except Exception as e:
chat_state["messages"].append(
{"role": "system", "content": f"Error during post-processing: {str(e)}"}
)
chat_state["finished"] = True
return chat_states
def get_mask(text, tokenizer):
encoding = tokenizer(text, add_special_tokens=False)
start_header_id = tokenizer.convert_tokens_to_ids("<|start_header_id|>")
assistant_token = tokenizer.convert_tokens_to_ids("assistant")
end_header_id = tokenizer.convert_tokens_to_ids("<|end_header_id|>")
eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
assistant_ranges = []
i = 0
while i < len(encoding.input_ids) - 1:
if (
encoding.input_ids[i] == start_header_id
and encoding.input_ids[i + 1] == assistant_token
):
i += 2
while (
i < len(encoding.input_ids) and encoding.input_ids[i] != end_header_id
):
i += 1
i += 2
start_idx = i
while i < len(encoding.input_ids) and encoding.input_ids[i] != eot_id:
i += 1
end_idx = i
assistant_ranges.append((start_idx, end_idx))
else:
i += 1
mask = [0] * len(encoding.input_ids)
for start_idx, end_idx in assistant_ranges:
for idx in range(start_idx, end_idx):
mask[idx] = 1
return torch.tensor(mask, dtype=torch.int)
def check_exceeded_max_new_tokens(chat_states, max_new_tokens, tokenizer):
for chat_state in chat_states:
if chat_state.get("finished"):
continue
initial_length = chat_state["initial_length"]
new_length = get_chat_num_tokens(chat_state, tokenizer)
if new_length - initial_length > max_new_tokens:
chat_state["finished"] = True
return chat_states
@dataclass
class AgenticOutputs:
prompt_tokens: list[torch.Tensor]
response_tokens: list[torch.Tensor]
response_masks: list[torch.Tensor]
final_response_str: list[str]
full_chat_states: list[dict]
def get_chat_num_tokens(chat_state, tokenizer):
chat_text = apply_chat_template(chat_state, tokenizer=tokenizer)["text"]
return (
tokenizer(chat_text, add_special_tokens=False, return_tensors="pt")["input_ids"]
.squeeze()
.shape[0]
)
def run_agent(
generate_fn, tokenizer, questions, max_generations=5, max_new_tokens=4096
):
"""
Run the agent to completion for a batch of questions.
Args:
generate_fn: Function to generate model responses
tokenizer: Tokenizer for processing text
batch: Batch of data containing questions
max_generations: Maximum number of generation steps
Returns:
list: Final answers for each question
"""
chat_states = [get_initial_chat(q) for q in questions]
# set the initial_prompt length
for chat_state in chat_states:
chat_state["initial_length"] = get_chat_num_tokens(chat_state, tokenizer)
# agent loop
for i in range(max_generations):
chat_states = run_agent_generations(generate_fn, tokenizer, chat_states)
chat_states = check_finished_chats(chat_states)
chat_states = run_tool_calls(chat_states)
chat_states = check_exceeded_max_new_tokens(
chat_states, max_new_tokens, tokenizer
)
answers = []
for chat in chat_states:
answers.append(chat["messages"][-1]["content"])
def split_prompt_assistant(convo_text):
marker = "<|start_header_id|>assistant<|end_header_id|>"
idx = convo_text.find(marker)
if idx == -1:
raise ValueError("Could not find assistant marker in conversation text.")
return convo_text, ""
# Include the marker in the prompt by slicing up to the end of the marker.
prompt = convo_text[: idx + len(marker)]
# The assistant response is everything after the marker.
assistant_response = convo_text[idx + len(marker) :]
return prompt, assistant_response
str_chats = [
apply_chat_template(chat, tokenizer=tokenizer)["text"] for chat in chat_states
]
prompt_toks, response_toks, response_masks = [], [], []
for str_chat in str_chats:
prompt, response = split_prompt_assistant(str_chat)
prompt_toks.append(
tokenizer(prompt, add_special_tokens=False, return_tensors="pt")[
"input_ids"
].squeeze()
)
response_toks.append(
tokenizer(response, add_special_tokens=False, return_tensors="pt")[
"input_ids"
].squeeze()[:max_new_tokens]
)
mask = get_mask(str_chat, tokenizer)[len(prompt_toks[-1]) :][:max_new_tokens]
response_masks.append(mask)
final_response_str = [chat["messages"][-1]["content"] for chat in chat_states]
full_chat_states = chat_states
agentic_outputs = AgenticOutputs(
prompt_tokens=prompt_toks,
response_tokens=response_toks,
response_masks=response_masks,
final_response_str=final_response_str,
full_chat_states=full_chat_states,
)
return agentic_outputs
# Verification
async def check_correctness(question, student_answer, answer):
"""
Calculate reward for a given student answer.
Args:
question (str): The original question
student_answer (str): The model's answer
answer (str): The ground truth answer
Returns:
float: Reward value (1 for correct, 0 for incorrect)
"""
# log to "./reward_func.log"
with open("reward_func.log", "a") as f:
f.write("\n" + "==" * 40 + "\n\n")
f.write(f"Question: {question}\n")
f.write(f"Student Answer: {student_answer}\n")
f.write(f"Answer: {answer}\n")
if student_answer.startswith("Error during"):
f.write(f"failed function call")
return 0
if len(student_answer) < 5:
f.write(f"failed Too short answer\n")
return 0
else:
f.write(f"last message didn't fail\n")
student_answer_clean = remove_reasoning(student_answer)
is_correct = await verify(student_answer_clean, question, answer)
f.write(f"Is Correct: {is_correct}, so reward is {int(is_correct)}\n")
return 1 if is_correct else 0
def check_student_answers(
questions: List[str],
answers: List[str],
student_answers: List[str],
vllm_generate_func: Callable[[List[str]], List[str]],
tokenizer,
log_file: str = "qa_log.txt",
) -> List[bool]:
"""
Evaluates a list of student answers against the true answers using a vLLM generate function.
The function applies the chat template to each prompt before passing it to the generate function.
It also appends the details of each QA pair and the verifier's response to a log file.
Args:
questions: A list of strings representing the questions.
answers: A list of strings representing the correct answers.
student_answers: A list of strings containing the student's answers.
vllm_generate_func: A function that takes a list of chat-formatted prompt strings and returns a list of generated outputs.
tokenizer: The tokenizer used to apply the chat template.
log_file: Optional; path to the file where the QA pairs and verification responses will be appended.
Returns:
A list of booleans indicating whether each student's answer is correct.
"""
if not (len(questions) == len(answers) == len(student_answers)):
raise ValueError(
"The number of questions, answers, and student answers must be equal."
)
prompts = []
for question, answer, student_ans in zip(questions, answers, student_answers):
# Construct the plain text prompt for each QA pair.
prompt_text = (
"You are grading a student's answer. For the following question, "
"compare the student's answer to the correct answer. Reply with 'Yes' if the student's answer is correct, or 'No' if it is completely incorrect.\n\n"
f"Question: {question}\n"
f"Correct Answer: {answer}\n"
f"Student Answer: {student_ans}\n"
)
# Apply the chat template to the prompt.
formatted_prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": prompt_text}],
tokenize=False,
add_generation_prompt=True,
)
prompts.append(formatted_prompt)
# Get the model responses in batch (each response should ideally be "Yes" or "No")
responses = vllm_generate_func(prompts)
responses_text = [response.outputs[0].text for response in responses]
# Evaluate each response and mark as correct if "yes" appears in the answer (case-insensitive)
results = []
for response in responses_text:
results.append("yes" in response.lower())
# Append the QA details and verifier's response to the specified log file
with open(log_file, "a") as file:
for question, answer, student_ans, verifier_response in zip(
questions, answers, student_answers, responses_text
):
file.write("Question: " + question + "\n")
file.write("Correct Answer: " + answer + "\n")
file.write("Student Answer: " + student_ans + "\n")
file.write("Verifier said: " + verifier_response + "\n")
file.write("-" * 40 + "\n")
return results
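# Sketch of how this verifier can be wired up (the concrete values are
# assumptions; vllm_generate_func must return vLLM-style outputs where
# response.outputs[0].text is the generated string):
#   is_correct = check_student_answers(
#       questions=["What year did the program launch?"],
#       answers=["2019"],
#       student_answers=["It launched in 2019."],
#       vllm_generate_func=model.fast_generate,
#       tokenizer=tokenizer,
#   )
#   # -> [True] if the grader model replies with "Yes"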
def build_reward_correctness_fn(generate_fn, tokenizer):
def reward_correctness(prompts, completions, **reward_kwargs):
teacher_answers = reward_kwargs["answer"]
student_answers = [
completion["messages"][-1]["content"] for completion in completions
]
correct = check_student_answers(
prompts,
teacher_answers,
student_answers,
vllm_generate_func=generate_fn,
tokenizer=tokenizer,
)
return correct
return reward_correctness
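# Minimal sketch of how the returned reward function is called by the trainer
# (argument shapes are assumptions inferred from the GRPO usage further below):
#   reward_correctness = build_reward_correctness_fn(verifier_generate_fn, tokenizer)
#   scores = reward_correctness(
#       prompts=["Who founded SpaceX?"],
#       completions=[{"messages": [{"role": "assistant", "content": "Elon Musk"}]}],
#       answer=["Elon Musk"],
#   )
#   # -> one True/False per completion, used directly as the correctness reward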
def reward_formatting(prompts, completions, **reward_kwargs):
# make sure the full chats don't contain any failed tool calls
has_error = [False] * len(completions)
for i, chat in enumerate(completions):
for message in chat["messages"]:
if "Error during" in message["content"]:
has_error[i] = True
break
return [0.7 if not e else 0 for e in has_error]
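# Example (hypothetical chat state): a completion whose messages contain an
# "Error during ..." tool message scores 0, otherwise 0.7:
#   reward_formatting(prompts=["q"], completions=[{"messages": [{"content": "ok"}]}])
#   # -> [0.7]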
def run_eval(generate_fn, verify_fn, tokenizer):
train_dataset, test_dataset = get_qa_dataset()
questions = test_dataset["prompt"]
agentic_outputs = run_agent(generate_fn, tokenizer, questions)
full_chat_states = agentic_outputs.full_chat_states
final_responses = agentic_outputs.final_response_str
rewards = verify_fn(questions, full_chat_states, answer=test_dataset["answer"])
print("RESULTS:")
print("percentage of correct answers:", sum(rewards) / len(rewards))
print("=" * 30)
return full_chat_states

@ -0,0 +1,194 @@
"""
Search module for RL training loop.
This module provides functions to search through vectorized documents and retrieve question-answer pairs.
"""
import pickle
import json
import random
from typing import List, Optional, Union
from langchain.vectorstores import FAISS
from datasets import Dataset
from embeddings import CustomHuggingFaceEmbeddings
# Load pre-saved vectorstore
def load_vectorstore():
"""Load the pre-saved FAISS index"""
try:
import os
embeddings = CustomHuggingFaceEmbeddings()
# Load the FAISS index with absolute path
index_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "faiss_index"
)
print(f"Loading FAISS index from: {index_path}")
vectorstore = FAISS.load_local(
index_path, embeddings, allow_dangerous_deserialization=True
)
print("Successfully loaded FAISS index")
return vectorstore
except Exception as e:
print(f"Error loading vectorstore: {e}")
import traceback
traceback.print_exc()
return None
# Load the vectorstore when module is imported
try:
vectorstore = load_vectorstore()
if vectorstore is None:
print("Warning: FAISS vectorstore could not be loaded.")
except Exception as e:
print(f"Error loading vectorstore: {e}")
vectorstore = None
def search(query: str, return_type=str, results: int = 5) -> Union[str, List[str]]:
"""
Search for relevant chunks using similarity search.
Args:
query: The search query
return_type: Return as string or list (default: str)
results: Number of results to return (default: 5)
Returns:
Results as string or list depending on return_type
"""
if vectorstore is None:
raise ValueError("Vectorstore not loaded. Please ensure FAISS index exists.")
search_results = vectorstore.similarity_search(query, k=results)
if return_type == str:
str_results = ""
for idx, result in enumerate(search_results, start=1):
str_results += f"Result {idx}:\n"
str_results += result.page_content + "\n"
str_results += "------\n"
return str_results
elif return_type == list:
return [result.page_content for result in search_results]
else:
raise ValueError("Invalid return_type. Use str or list.")
# Load questions from saved data
def load_qa_data():
"""Load the pre-generated questions and document chunks"""
try:
import os
# Get absolute paths to data files
base_dir = os.path.dirname(os.path.abspath(__file__))
chunks_path = os.path.join(base_dir, "saved_data", "chunks.pkl")
questions_path = os.path.join(base_dir, "saved_data", "questions.json")
print(f"Loading chunks from: {chunks_path}")
print(f"Loading questions from: {questions_path}")
# Load the chunks
with open(chunks_path, "rb") as f:
chunks = pickle.load(f)
# Load the questions
with open(questions_path, "r") as f:
questions = json.load(f)
print(
f"Successfully loaded {len(chunks)} chunks and {len(questions)} questions"
)
return chunks, questions
except Exception as e:
print(f"Error loading QA data: {e}")
import traceback
traceback.print_exc()
return None, None
# Load chunks and questions when module is imported
try:
chunks, questions = load_qa_data()
if chunks is None or questions is None:
print("Warning: Could not load QA data.")
except Exception as e:
print(f"Error initializing QA data: {e}")
chunks, questions = None, None
def get_question_answer(
idx: Optional[int] = None, return_both: bool = True
) -> Union[dict, str]:
"""
Get a question-answer pair either by index or randomly.
Args:
idx: Index of the question to retrieve (if None, selects random question)
return_both: Whether to return both question and answer (default: True)
Returns:
A dict with "question" and "answer" keys if return_both=True, otherwise just the question string
"""
if questions is None:
raise ValueError("Questions not loaded. Please ensure questions.json exists.")
if idx is None:
# Select a random question
qa_pair = random.choice(questions)
elif 0 <= idx < len(questions):
# Select question by index
qa_pair = questions[idx]
else:
raise ValueError(
f"Index out of range. Must be between 0 and {len(questions)-1}"
)
question = qa_pair["question"]
answer = qa_pair["answer"]
if return_both:
return {"question": question, "answer": answer}
else:
return question
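# Example usage (assumes saved_data/questions.json has been loaded):
#   qa = get_question_answer()          # random {"question": ..., "answer": ...}
#   q_only = get_question_answer(0, return_both=False)  # just the question text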
# Function to get the total number of questions
def get_question_count() -> int:
"""Get the total number of available questions"""
if questions is None:
raise ValueError("Questions not loaded. Please ensure questions.json exists.")
return len(questions)
def get_qa_dataset():
"""
Return a HuggingFace Dataset containing question and answer pairs.
This dataset is constructed from the loaded questions data (questions.json).
Each element in the dataset is a dictionary that includes at least:
- "question": The question text.
- "answer": The corresponding answer text.
Additional keys present in the original questions data will also be included.
Returns:
A HuggingFace Dataset object.
"""
if questions is None:
raise ValueError("Questions not loaded. Please ensure questions.json exists.")
qa_dataset = Dataset.from_list(questions)
full_dataset = qa_dataset.shuffle(seed=42)
train_dataset = full_dataset.train_test_split(test_size=0.1, seed=42)["train"]
test_dataset = full_dataset.train_test_split(test_size=0.1, seed=42)["test"]
# rename the "question" column to "prompt" in both splits
train_dataset = train_dataset.rename_column("question", "prompt")
test_dataset = test_dataset.rename_column("question", "prompt")
return train_dataset, test_dataset
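# Quick sanity check (columns after the rename are "prompt", "answer", plus any
# extra keys carried over from questions.json):
#   train_ds, test_ds = get_qa_dataset()
#   print(len(train_ds), len(test_ds), train_ds[0]["prompt"])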

@ -0,0 +1,200 @@
#!/usr/bin/env python3
"""
Simple command-line Q&A environment for testing with search functionality.
"""
import asyncio
import json
import random
import time
# Import our search module (ensure these functions follow the new interfaces)
from search_module import get_question_answer, get_question_count, search
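# NOTE: process_answer() below awaits a `verify(...)` coroutine and passes a
# `router` object; both are assumed to be provided by the project's verification
# utilities and are not imported in this file.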
class SimpleQAEnvironment:
"""Simple command-line environment for Q&A with search capability."""
def __init__(self):
self.score = {"correct": 0, "incorrect": 0, "total": 0}
self.session_data = []
self.current_question = None
def display_welcome(self):
"""Display welcome message and instructions."""
print("\n===== Search & Answer Environment =====")
print("Answer questions using the search tool to find relevant information.")
print("Type 'q' to quit, 'h' for help.\n")
def display_help(self):
"""Display help information."""
print("\n===== Commands =====")
print("n - Get a new question")
print("s <query> - Search for information (e.g., s program launch date)")
print("a <answer> - Submit your answer")
print("h - Display this help message")
print("q - Quit the program\n")
def display_question(self, question: str):
"""Display the current question."""
print("\n===== QUESTION =====")
print(question)
print("=====================\n")
def get_new_question(self) -> str:
"""Get a new random question and set it as current."""
total_questions = get_question_count()
question_id = random.randint(0, total_questions - 1)
# Updated to match new interface: get_question_answer now returns a dict.
qa = get_question_answer(question_id)
question = qa["question"]
correct_answer = qa["answer"]
question_data = {
"id": question_id,
"question": question,
"correct_answer": correct_answer,
"start_time": time.time(),
"searches": [],
}
self.current_question = question_data
return question
def perform_search(self, query: str):
"""Perform a search with the given query."""
if not query:
print("Please provide a search query.")
return
try:
print("\n===== SEARCH RESULTS =====")
results = search(query)
print(results)
print("==========================\n")
# Record search in current question data if available.
if self.current_question is not None:
self.current_question["searches"].append(query)
except Exception as e:
print(f"Error searching: {str(e)}")
async def process_answer(self, user_answer: str):
"""Process and verify the user's answer."""
if self.current_question is None:
print("Please get a question first.")
return
if not user_answer:
print("Please provide an answer.")
return
# Record answer and calculate time taken.
self.current_question["user_answer"] = user_answer
self.current_question["end_time"] = time.time()
self.current_question["time_taken"] = (
self.current_question["end_time"] - self.current_question["start_time"]
)
try:
print("\nVerifying your answer...")
correct = await verify(
user_answer,
self.current_question["question"],
self.current_question["correct_answer"],
router,
)
# Update score and inform the user.
self.score["total"] += 1
if correct:
self.score["correct"] += 1
print("\n✓ Your answer is CORRECT!")
else:
self.score["incorrect"] += 1
print("\n✗ Your answer is INCORRECT.")
print(
f"\nThe correct answer is:\n{self.current_question['correct_answer']}"
)
print(f"\nScore: {self.score['correct']}/{self.score['total']}")
# Record the result and add the current question to the session data.
self.current_question["is_correct"] = correct
self.session_data.append(self.current_question)
# Clear the current question.
self.current_question = None
except Exception as e:
print(f"Error verifying answer: {str(e)}")
def save_session(self):
"""Save the session data to a file."""
if not self.session_data:
return
timestamp = time.strftime("%Y%m%d_%H%M%S")
filename = f"qa_session_{timestamp}.json"
session_data = {
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
"score": self.score,
"questions": self.session_data,
}
try:
with open(filename, "w") as f:
json.dump(session_data, f, indent=2)
print(f"\nSession data saved to {filename}")
except Exception as e:
print(f"Error saving session data: {str(e)}")
async def run(self):
"""Run the main command loop."""
self.display_welcome()
while True:
command = input("\n> ").strip()
if not command:
continue
# Process commands.
if command.lower() == "q":
break
elif command.lower() == "h":
self.display_help()
elif command.lower() == "n":
question = self.get_new_question()
self.display_question(question)
elif command.lower().startswith("s "):
query = command[2:].strip()
self.perform_search(query)
elif command.lower().startswith("a "):
answer = command[2:].strip()
await self.process_answer(answer)
else:
print("Unknown command. Type 'h' for help.")
# Save session data on exit.
self.save_session()
print("\nThank you for using the Q&A environment!")
async def main():
"""Main function to start the application."""
env = SimpleQAEnvironment()
await env.run()
if __name__ == "__main__":
try:
asyncio.run(main())
except KeyboardInterrupt:
print("\nProgram terminated by user.")
except Exception as e:
print(f"\nError: {str(e)}")

@ -0,0 +1,189 @@
# %%
from unsloth import FastLanguageModel
# %%
from unsloth import is_bfloat16_supported
import torch
max_seq_length = 4096 * 2 # Can increase for longer reasoning traces
lora_rank = 64 # Larger rank = smarter, but slower
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="meta-llama/meta-Llama-3.1-8B-Instruct",
max_seq_length=max_seq_length,
load_in_4bit=True, # False for LoRA 16bit
fast_inference=True, # Enable vLLM fast inference
max_lora_rank=lora_rank,
gpu_memory_utilization=0.6, # Reduce if out of memory
)
model = FastLanguageModel.get_peft_model(
model,
r=lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
], # Remove QKVO if out of memory
lora_alpha=lora_rank,
use_gradient_checkpointing="unsloth", # Enable long context finetuning
random_state=3407,
)
# %%
import re
from datasets import load_dataset, Dataset
from search_module import search, get_question_answer, get_question_count
from rl_helpers import get_qa_dataset
train_dataset, test_dataset = get_qa_dataset()
# %% [markdown]
# <a name="Train"></a>
# ### Train the model
#
# Now set up GRPO Trainer and all configurations!
# %%
import os
os.environ["WANDB_PROJECT"] = "bootstrap-search-rl"
# %%
# from UnslothGRPOTrainerTemp import UnslothGRPOConfig, _UnslothGRPOTrainer
import UnslothGRPOTrainerTemp
training_args = UnslothGRPOTrainerTemp.UnslothGRPOConfig(
use_vllm=True, # use vLLM for fast inference!
use_agentic_generate=True, # use agentic generation
learning_rate=5e-6,
adam_beta1=0.9,
adam_beta2=0.99,
weight_decay=0.1,
warmup_ratio=0.1,
lr_scheduler_type="cosine",
optim="paged_adamw_8bit",
logging_steps=1,
bf16=is_bfloat16_supported(),
fp16=not is_bfloat16_supported(),
per_device_train_batch_size=8,
gradient_accumulation_steps=1, # Increase to 4 for smoother training
num_generations=8, # Decrease if out of memory
max_prompt_length=1024,
max_completion_length=1024,
# num_train_epochs = 1, # Set to 1 for a full training run
max_steps=101,
save_steps=50,
max_grad_norm=0.1,
report_to="none", # Can use Weights & Biases
output_dir="full_local_training",
)
# %%
import rl_helpers
# importlib.reload(rl_helpers)
def agentic_generate(
prompts: list[str],
generate_fn,
max_generations: int = 6,
):
return run_agent(generate_fn, tokenizer, prompts, max_generations)
model.agentic_generate = agentic_generate
from vllm import SamplingParams
verifier_sampling_params = SamplingParams(
temperature=0.1,
top_p=0.95,
max_tokens=4096,
)
def verifier_generate_fn(inputs):
return model.fast_generate(
inputs,
sampling_params=verifier_sampling_params,
)
run_agent = rl_helpers.run_agent
reward_correctness = rl_helpers.build_reward_correctness_fn(
verifier_generate_fn,
tokenizer,
)
reward_formatting = rl_helpers.reward_formatting
import UnslothGRPOTrainerTemp
trainer = UnslothGRPOTrainerTemp.UnslothGRPOTrainer(
model=model,
processing_class=tokenizer,
reward_funcs=[
reward_correctness,
reward_formatting,
],
args=training_args,
train_dataset=train_dataset,
)
# %%
trainer.train()
# %% [markdown]
# <a name="Inference"></a>
# ### Inference
# Now let's benchmark the model we trained!
# %%
from vllm import SamplingParams
import rl_helpers
sampling_params = SamplingParams(
temperature=0.5,
top_p=0.95,
max_tokens=4096,
)
def eval_generate_fn(inputs):
return model.fast_generate(
inputs,
sampling_params=sampling_params,
lora_request=model.load_lora(
"full_local_training/checkpoint-101"
), # load the trained LoRA
)
rl_helpers.run_eval(
generate_fn=eval_generate_fn,
verify_fn=reward_correctness,
tokenizer=tokenizer,
)
# %%
# eval w/o lora
def eval_generate_fn(inputs):
return model.fast_generate(
inputs,
sampling_params=sampling_params,
)
rl_helpers.run_eval(
generate_fn=eval_generate_fn,
verify_fn=reward_correctness,
tokenizer=tokenizer,
)

@ -1,26 +0,0 @@
import sys,os,random,time
def f(x):return x*2 if x%2==0 else x+1
class C:
def __init__(self,v):self.v=v
def p(self):print("Value:",self.v)
def m(l):return [f(x) for x in l]
x=[random.randint(1,100) for _ in range(10)]
print("Original:",x)
print("Processed:",m(x))
for i in range(len(x)):
if i%2==0:
x[i]*=2
elif i%3==0:
x[i]+=3
else:
x[i]-=1
c=C(sum(x))
c.p()
try:
for i in range(5):print(i,x[i],f(x[i]))
except:pass
with open("temp.txt","w") as f:f.write("Hello, world!")
while True:
time.sleep(0.1)
if random.random()>0.9:break
print("Done!")