- Added initial files from AutoDidact as a starting point.
- Enhanced `README.md` with a project overview and setup instructions.
- Removed `ugly_code_file.py` as part of cleanup.
- Added various documentation files and assets for project clarity.
- Included Jupyter notebooks for training and experimentation.
parent 91c2476c28
commit a58722e16f
@@ -0,0 +1,2 @@
HF_TOKEN=
OPENROUTER_API_KEY=
File diff suppressed because it is too large
Binary file not shown.
File diff suppressed because it is too large
@@ -0,0 +1,63 @@
# Worklog

## Backlog

- [ ] Modify `generate_dataset.py` (**ONLY AFTER** the simple training and benchmark work):
  - [ ] As a dataset maker, I want to switch from Llama 3.1 8B to an API call (e.g., Claude, Gemini, or OpenAI). The original work uses Llama 3.1 8B to demonstrate `Self-Bootstrapping`, but the resulting dataset quality is clearly low. (A rough sketch of an API-based generator follows this list.)
  - [ ] Experiment with different chunking strategies
  - [ ] [search-backends.md](search-backends.md) design (for more dataset noise; **ONLY AFTER** the simple training dataset works)
- [ ] Research Agentic Reward Modeling a little (perhaps for designing a better reward function)
  - <https://medium.com/@techsachin/agentic-reward-modeling-combine-human-preferences-with-verifiable-correctness-signals-for-reliable-76c408b3491c>
  - <https://arxiv.org/pdf/2502.19328>
  - <https://github.com/THU-KEG/Agentic-Reward-Modeling>
  - <https://www.themoonlight.io/en/review/agentic-reward-modeling-integrating-human-preferences-with-verifiable-correctness-signals-for-reliable-reward-systems>

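A minimal sketch of what the API-based generator above could look like, assuming the `openai` client package and the `OPENROUTER_API_KEY` already listed in `.env`; the model id and prompt wording are placeholders, not project decisions:

```python
# Sketch only: generate QA pairs for one chunk via OpenRouter's OpenAI-compatible
# API instead of the local Llama 3.1 8B pipeline in generate_dataset.py.
import os

from openai import OpenAI  # assumes the `openai` client package is installed

client = OpenAI(
    base_url="https://openrouter.ai/api/v1",
    api_key=os.environ["OPENROUTER_API_KEY"],
)


def generate_qa_for_chunk(chunk_text: str, num_questions: int = 2) -> str:
    """Ask a hosted model for QA pairs in the same 3-line format generate_dataset.py expects."""
    prompt = (
        f"From the text below, generate {num_questions} question-answer pairs.\n"
        "For each pair output exactly three lines: Question:, Answer:, Difficulty:.\n\n"
        f"{chunk_text}"
    )
    response = client.chat.completions.create(
        model="anthropic/claude-3.5-sonnet",  # placeholder model id
        messages=[{"role": "user", "content": prompt}],
        temperature=0.3,
    )
    return response.choices[0].message.content
```

The raw output could then be fed through the existing `parse_multiple_qa_output` parser so the rest of the dataset pipeline stays unchanged.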
## yymmdd

- [ ] task description

## 250324

- [ ] @thinhlpg transfers the project to @bachvudinh

## 250323

- [ ] Train the model
- [ ] Make the dataset
- [ ] Upload datasets to HF Hub
  - Initial dataset from AutoDidact
  - Paraphrased dataset
- [ ] Make a simple Gradio demo app

## 250322

- [x] Move all the scattered, disorganized work from the past week into this repo.
- [x] Write the proposal for DeepSearch
  - [x] [evaluation.md](evaluation.md) design (list the metrics and why)
  - [x] [dataset.md](dataset.md) design (pipeline, data structure, ...)
  - [x] [reward-functions.md](reward-functions.md) design (list the functions and why)
- [x] As a new member of the research team, I'm curious how we did GRPO with Alphamaze, so that I can inherit the good parts and improve the workflow!
  - [Alphamaze](https://github.com/menloresearch/visual-thinker)
  - <https://www.menlo.ai/blog/alpha-maze>
  - <https://arxiv.org/pdf/2502.14669>
  - > Our training process involved two key stages: creating a specialized dataset and then using a combination of supervised fine-tuning (SFT) and reinforcement learning (RL) to train the model.
  - LLaMA-Factory for SFT **(1.5B, 6xA6000, 1.5 hours)** and Unsloth for GRPO
  - 💡 So for SFT they used 50% successful data and 50% retry data, and only successful data for GRPO. Can I apply this to DeepSearch as well? #HACK

## 250321

- [x] Inspect the AutoDidact code in more detail <https://github.com/menloresearch/DeepSearch/issues/4>

## 250320

- Research on GRPO <https://github.com/menloresearch/DeepSearch/issues/2>

## 250319

- Research on GRPO <https://github.com/menloresearch/DeepSearch/issues/2>
- Run the AutoDidact training script

## 250318

- Idea received <https://github.com/menloresearch/DeepSearch/issues/1>
@@ -0,0 +1,31 @@
# Adaptive Search Behavior

- [Agent Action](agent-action.md) -> the agent mostly recognizes that something is missing -> performs a "refined query"
- [ ] As a model trainer, I want to inspect the agent's full chat state to see what is going on so I can improve it -> implement a simple CLI inspection tool after training that just prints out the full chat state (a sketch follows the example below).
- Example from AutoDidact:

```markdown
Example Question

What was the reason for substituting the backup Command Module Pilot 3 days prior to the Apollo 13 flight?

Step-by-Step Search Process

Query: "Apollo 13 Command Module Pilot substitution"

Outcome: Retrieved operational support details, but no explanation for the substitution.
Agent's Action: Recognized missing information → **Refined query**.

Query: "Apollo 13 Command Module Pilot substitution reason"

Outcome: Retrieved general mission anomaly details, but still no direct answer.
Agent's Action: **Increased query specificity**.

Query: "Apollo 13 John 'Jack' Swigert substitution"

Outcome: Found general mission reports, but still lacked a clear reason for substitution.
Agent's Action: Hypothesized illness might be a factor → **Refined query** accordingly.

Query: "Apollo 13 Jack Swigert illness substitution"

Outcome: Retrieved the exact explanation: "Several days prior to launch, the backup Lunar Module Pilot became sick with measles. Examinations of the prime crew indicated that the Command Module Pilot was not immune to the disease; therefore, the backup Command Module Pilot was substituted."

Final Answer

The original Command Module Pilot lacked immunity to measles, necessitating his replacement by Jack Swigert.

This example shows how Llama learns to do multiple searches to find answers to its questions.
```
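A minimal sketch of the CLI inspection tool proposed above, assuming chat states shaped like those produced by `rl_helpers.run_agent` (a dict with a `"messages"` list of role/content entries):

```python
# Pretty-print the full chat state of each agent rollout after training/eval.
def print_chat_state(chat_state: dict) -> None:
    for turn, message in enumerate(chat_state["messages"]):
        print(f"--- turn {turn} [{message['role']}] ---")
        print(message["content"])
        print()


# Example usage after an eval run (run_eval returns the full chat states):
# full_chat_states = run_eval(generate_fn, verify_fn, tokenizer)
# for chat_state in full_chat_states:
#     print_chat_state(chat_state)
```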
@@ -0,0 +1,5 @@
# Agent Action

- [ ] Research this a bit more because I'm a bit outdated on the training side
- [ ] What does the dataset look like? (see the sketch below for the current record format)
- [ ] How do we evaluate the performance?
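Partial answer to the dataset question: based on `generate_dataset.py`, each record written to `saved_data/questions.json` has the shape below (the field values are invented for illustration):

```python
# Shape of one QA record produced by generate_dataset.py.
example_record = {
    "chunk_id": 42,  # index of the source chunk in saved_data/chunks.pkl
    "question": "What anomaly occurred during the Apollo 13 mission?",  # illustrative
    "answer": "An oxygen tank in the service module ruptured.",  # illustrative
    "difficulty": "medium",  # one of "easy", "medium", "hard"
}
```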
Binary image file added (37 KiB).
@@ -0,0 +1,52 @@
# Evaluation

- **Goal**:
  1. Better performance than the original (measured by an automated eval script)
  2. Better performance by real human eval/preference

## Implementation Phases

- [x] 1. Take the eval function from the original repo (it simply uses accuracy (exact match)) and take a quick glance at the output quality. (A minimal sketch of exact-match accuracy follows this list.)
- [ ] 2. Find more common and conventional datasets and benchmarks (still an automated script)
- [ ] 3. Set up human evaluation

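For reference, a minimal sketch of what "accuracy (exact match)" means in phase 1; this is not the repo's actual eval code (`rl_helpers.check_student_answers` grades with an LLM verifier instead):

```python
# Exact-match accuracy: fraction of predictions identical to the reference
# after trivial normalization (whitespace and case).
def exact_match_accuracy(predictions: list[str], references: list[str]) -> float:
    matches = sum(
        p.strip().lower() == r.strip().lower()
        for p, r in zip(predictions, references)
    )
    return matches / len(references)
```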
## Baseline

- Info from AutoDidact:
  - After just 100 steps of GRPO training (1 hour on a single RTX 4090 GPU), Llama-8B significantly improved its ability to research and answer questions from the Apollo 13 mission report.
  - On a validation set of 68 questions, accuracy more than doubled from 23% to 59%.
- Training log: not sure why, but the result I got from actually running the training is a bit lower.

```bash
ceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
Processed prompts: 100%|████████████████| 16/16 [00:00<00:00, 39.27it/s, est. speed input: 6827.13 toks/s, output: 81.01 toks/s]
rewards_per_func: tensor([0.6875, 0.7000], device='cuda:0'):05, 2.55it/s, est. speed input: 385.80 toks/s, output: 5.11 toks/s]
{'loss': 0.0003, 'grad_norm': 0.5810762047767639, 'learning_rate': 0.0, 'rewards/reward_correctness': 0.6875, 'rewards/reward_formatting': 0.699999988079071, 'reward': 1.3875000476837158, 'reward_std': 0.44403791427612305, 'completion_length': 224.125, 'kl': 0.00834659393876791, 'epoch': 0.34}
{'train_runtime': 7992.2854, 'train_samples_per_second': 0.202, 'train_steps_per_second': 0.013, 'train_loss': 0.0005197484556535774, 'epoch': 0.34}
100%|███████████████████████████████████████████████████████████████████████████████████████| 101/101 [2:13:12<00:00, 79.13s/it]
Processed prompts: 100%|████████████████| 67/67 [00:20<00:00, 3.28it/s, est. speed input: 950.44 toks/s, output: 394.51 toks/s]
Processed prompts: 100%|███████████████| 66/66 [00:20<00:00, 3.15it/s, est. speed input: 2383.55 toks/s, output: 323.82 toks/s]
Processed prompts: 100%|███████████████| 20/20 [00:17<00:00, 1.13it/s, est. speed input: 1320.49 toks/s, output: 146.76 toks/s]
Processed prompts: 100%|████████████████| 17/17 [00:16<00:00, 1.04it/s, est. speed input: 1620.28 toks/s, output: 98.35 toks/s]
Processed prompts: 100%|██████████████████| 9/9 [00:15<00:00, 1.73s/it, est. speed input: 1165.77 toks/s, output: 71.38 toks/s]
Processed prompts: 100%|████████████████| 67/67 [00:04<00:00, 16.31it/s, est. speed input: 3617.28 toks/s, output: 61.11 toks/s]
RESULTS:
percentage of correct answers: 0.5074626865671642
==============================
Processed prompts: 100%|███████████████| 67/67 [00:15<00:00, 4.46it/s, est. speed input: 1292.29 toks/s, output: 561.32 toks/s]
Processed prompts: 100%|███████████████| 44/44 [00:18<00:00, 2.44it/s, est. speed input: 1800.84 toks/s, output: 244.13 toks/s]
Processed prompts: 100%|███████████████| 13/13 [00:12<00:00, 1.05it/s, est. speed input: 1209.04 toks/s, output: 126.32 toks/s]
Processed prompts: 100%|███████████████| 10/10 [00:13<00:00, 1.32s/it, est. speed input: 1225.46 toks/s, output: 109.78 toks/s]
Processed prompts: 100%|██████████████████| 7/7 [00:12<00:00, 1.86s/it, est. speed input: 1149.18 toks/s, output: 76.05 toks/s]
Processed prompts: 100%|████████████████| 67/67 [00:02<00:00, 31.53it/s, est. speed input: 6047.70 toks/s, output: 83.31 toks/s]
RESULTS:
percentage of correct answers: 0.19402985074626866
==============================
[rank0]:[W320 07:13:50.651270455 ProcessGroupNCCL.cpp:1250] Warning: WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL operations have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present, but this warning has only been added since PyTorch 2.4 (function operator())
```

## Getting some sense of the eval data or benchmark

- > For example, benchmarks like ARC-AGI, which involve visual reasoning, remain challenging for these models, even though they might seem straightforward to a human. (ichigo)
@@ -0,0 +1,15 @@
# GRPO idea

- The training flow of R1 is really simple (thanks to my friend, professional yapper @vTuanpham, for initially clarifying my confusion 🤣). Translated from his messages (a sketch of the tiered reward follows the block):

```text
1. First train a base model that knows how to use tools with plain SFT, as a boost.

Tuan

2. Then let it loose with GRPO: roughly correct syntax gets 0.5, correct syntax but badly-off params gets 0.65, both correct gets 0.85, ...
```
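A rough sketch of how the tiered scores quoted above could be encoded against this repo's JSON `search_corpus` tool-call format. This is an illustration only, not the implemented reward function (the repo's actual rewards are `reward_correctness` and `reward_formatting` in `rl_helpers.py`), and it assumes the assistant message contains only the JSON call (in practice `extract_json_objects` would be used first):

```python
import json


def reward_tool_call_syntax(assistant_message: str) -> float:
    """Tiered reward: 0.5 for valid JSON, 0.65 if it names search_corpus,
    0.85 if the required "query" parameter is also present."""
    try:
        call = json.loads(assistant_message)
    except json.JSONDecodeError:
        return 0.0
    score = 0.5  # roughly correct syntax: parses as JSON
    function = call.get("function", {})
    if function.get("name") == "search_corpus":
        score = 0.65  # correct syntax, but parameters may still be off
        if "query" in function.get("parameters", {}):
            score = 0.85  # syntax and parameters both look right
    return score
```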

## Unsloth's guide

- <https://unsloth.ai/blog/r1-reasoning>
- Heheboi, let's steal this notebook: <https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-GRPO.ipynb>
- <https://docs.unsloth.ai/basics/reasoning-grpo-and-rl> - this is the simplest one
@@ -0,0 +1,10 @@
```mermaid
graph TD
    A[User Query] -->|Random Search Engine Assigned| B{Synthetic Search Engine}
    B -->|Retrieves Initial Results| C[Model Analyzes Results]
    C -->|Refines Query if Needed| D[Iterative Search Process]
    D -->|Final Answer Found| E[Return Best Match]
    E -->|Rewards/Penalties Applied| F[Reinforcement Learning Update]
    F -->|Optimized Search Strategy| B
```
@@ -0,0 +1,8 @@
# Search backends

- Purpose: add more noise to the training process (this is already done in the initial dataset).
- Different search strategies? Semantic search, keyword search, BM25, actual API calls
  - Embedding models, retrieval mechanisms (BM25, dense, hybrid), query expansion methods, reranking strategies
- Random search engine assignment per query (see the sketch below)
- Noise and inconsistency injection to prevent shortcut learning
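A minimal sketch of random backend assignment plus noise injection on top of the existing `search()` in `search_module.py`. Only the dense FAISS search exists today, so the backend registry entries are placeholders that all point to it; the retrieval-depth range and shuffling are illustrative choices, not project decisions:

```python
import random

from search_module import search

# Placeholder backend registry: every entry currently maps to the dense FAISS
# search; BM25 / hybrid backends would be swapped in here once they exist.
BACKENDS = {
    "dense": lambda q, k: search(q, return_type=list, results=k),
    "bm25": lambda q, k: search(q, return_type=list, results=k),
    "hybrid": lambda q, k: search(q, return_type=list, results=k),
}


def noisy_search(query: str, results: int = 5) -> list[str]:
    backend = random.choice(list(BACKENDS))  # random engine per query
    k = random.choice([max(results - 2, 1), results, results + 2])  # inconsistent depth
    retrieved = BACKENDS[backend](query, k)
    random.shuffle(retrieved)  # inconsistent ordering to discourage shortcut learning
    return retrieved
```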
@@ -0,0 +1,5 @@
# Self Verification

- [x] Investigate this term: it is mentioned in AutoDidact's "about" section and also in the DeepSeek R1 paper (not in much detail), but not in blogs or the codebase. I think this term is important and worth investigating.
  - Lol, a "Verifier" is just a synonym for a **reward function**.
  - <https://docs.unsloth.ai/basics/reasoning-grpo-and-rl#reward-functions-verifier>
@@ -0,0 +1,92 @@
from typing import List, Union

import torch
import torch.nn.functional as F
from langchain.embeddings.base import Embeddings
from transformers import AutoModel, AutoTokenizer

# Set a default model here
DEFAULT_MODEL_NAME = "avsolatorio/NoInstruct-small-Embedding-v0"


class CustomHuggingFaceEmbeddings(Embeddings):
    """
    A custom embeddings class that wraps a Hugging Face model for generating embeddings.

    Supports two modes:
    - "sentence": uses the [CLS] token representation for sentence/document embeddings.
    - "query": uses mean pooling over tokens (weighted by the attention mask) for query embeddings.
    """

    def __init__(
        self, model_name: str = DEFAULT_MODEL_NAME, default_mode: str = "sentence"
    ):
        self.model_name = model_name
        # Set device to GPU if available, else CPU
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = AutoModel.from_pretrained(model_name).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.default_mode = default_mode  # "sentence" or "query"
        self.model.eval()  # Set model to evaluation mode

    def get_embedding(self, text: Union[str, List[str]], mode: str = None):
        if mode is None:
            mode = self.default_mode
        assert mode in (
            "query",
            "sentence",
        ), f"Unsupported mode: {mode}. Only 'query' and 'sentence' are supported."

        # Ensure we are working with a list of texts
        if isinstance(text, str):
            text = [text]

        # Tokenize the input texts
        inp = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        # Move the input tensors to the same device as the model
        inp = {key: value.to(self.device) for key, value in inp.items()}

        # Forward pass (no gradients needed)
        with torch.no_grad():
            output = self.model(**inp)

        if mode == "query":
            # Mean pooling: weight by attention mask and average across tokens
            vectors = output.last_hidden_state * inp["attention_mask"].unsqueeze(2)
            vectors = vectors.sum(dim=1) / inp["attention_mask"].sum(dim=-1).view(-1, 1)
        else:
            # Sentence/document embedding: use the [CLS] token (first token) representation
            vectors = output.last_hidden_state[:, 0, :]
        return vectors

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """
        Compute embeddings for a list of documents (using sentence mode).
        """
        vectors = self.get_embedding(texts, mode="sentence")
        return vectors.cpu().numpy().tolist()

    def embed_query(self, text: str) -> List[float]:
        """
        Compute an embedding for a single query.
        """
        vector = self.get_embedding(text, mode="query")
        return vector.cpu().numpy()[0].tolist()


# For quick testing
if __name__ == "__main__":
    embeddings = CustomHuggingFaceEmbeddings()

    # Example texts for document embeddings
    texts = [
        "Illustration of the REaLTabFormer model. The left block shows the non-relational tabular data model using GPT-2.",
        "Predicting human mobility holds significant practical value, with applications in disaster planning and epidemic simulation.",
        "As economies adopt digital technologies, policy makers are asking how to prepare the workforce for emerging labor demands.",
    ]
    doc_embeddings = embeddings.embed_documents(texts)
    print("Document embeddings:", doc_embeddings)

    # Example query embedding
    query_embedding = embeddings.embed_query("Which sentence talks about jobs?")
    print("Query embedding:", query_embedding)
@@ -0,0 +1,230 @@
"""
This script performs two main tasks:
1. It loads a markdown document, splits it into chunks, generates embeddings,
   and builds a FAISS index (which is saved locally).
2. It generates QA pairs from the document using Llama.
For each chunk (using a sliding window for context), it generates multiple question-answer pairs
with different difficulties. The generation is performed in batch with one retry for failed prompts.
Successfully generated QA pairs are saved to "saved_data/questions.json".

Requirements:
    pip install langchain faiss-cpu unsloth vllm
"""

import json
import os
import pickle
import re
from typing import Dict, List, Optional, Tuple

from langchain.text_splitter import RecursiveCharacterTextSplitter

# ========= Part 1: Document Processing and Embedding Generation =========
# Load and split the markdown document using LangChain
from langchain_community.document_loaders import UnstructuredMarkdownLoader
from langchain_community.vectorstores import FAISS

from embeddings import CustomHuggingFaceEmbeddings

# Load your markdown file (adjust the path as needed)
loader = UnstructuredMarkdownLoader("./data/mission_report.md")
docs = loader.load()

# Split the document into smaller chunks (each 1000 characters, no overlap)
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
chunks = text_splitter.split_documents(docs)

# Save chunks for later use
os.makedirs("saved_data", exist_ok=True)
with open("saved_data/chunks.pkl", "wb") as f:
    pickle.dump(chunks, f)
print(f"Saved {len(chunks)} chunks to saved_data/chunks.pkl")

embeddings = CustomHuggingFaceEmbeddings()

# Create a FAISS vector store from the document chunks and save it locally
vectorstore = FAISS.from_documents(chunks, embeddings)
vectorstore.save_local("faiss_index")
print("Saved FAISS index to 'faiss_index'")

# ========= Part 2: QA Generation using Llama Backend =========

# Setup Llama backend via unsloth and vLLM
from unsloth import FastLanguageModel
from vllm import SamplingParams

import rl_helpers  # Ensure you have this or remove if not used

# Load the Llama model (adjust parameters as needed)
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="meta-llama/meta-Llama-3.1-8B-Instruct",
    max_seq_length=4096,
    load_in_4bit=True,  # Use 4-bit quantization if desired
    fast_inference=True,  # Enable fast inference
    gpu_memory_utilization=0.6,  # Adjust based on your GPU memory
)

# Define sampling parameters for generation
sampling_params = SamplingParams(
    temperature=0.3,
    top_p=0.95,
    max_tokens=4096,
)


def batch_generate(prompts: List[str]) -> List[str]:
    """
    Given a list of prompt strings, returns a list of generated outputs.
    """

    def format_input(text: str) -> str:
        return tokenizer.apply_chat_template(
            [{"role": "user", "content": text}],
            tokenize=False,
            add_generation_prompt=True,
        )

    formatted = [format_input(p) for p in prompts]
    outputs = model.fast_generate(formatted, sampling_params=sampling_params)
    return [output.outputs[0].text for output in outputs]


def parse_qa_block(block: str) -> Optional[Tuple[str, str, str]]:
    """
    Parses a QA block that should contain exactly three non-empty lines:
    - A line starting with "Question:"
    - A line starting with "Answer:"
    - A line starting with "Difficulty:"

    If the markers are not present but the block contains exactly three lines,
    those are used in order.

    Returns a tuple (question, answer, difficulty) or None if parsing fails.
    """
    lines = [line.strip() for line in block.splitlines() if line.strip()]
    if not lines:
        return None

    question, answer, difficulty = None, None, None
    for line in lines:
        lower = line.lower()
        if question is None and lower.startswith("question:"):
            question = line[len("question:") :].strip()
        elif answer is None and lower.startswith("answer:"):
            answer = line[len("answer:") :].strip()
        elif difficulty is None and lower.startswith("difficulty:"):
            difficulty = line[len("difficulty:") :].strip()

    if question and answer and difficulty:
        return question, answer, difficulty
    if len(lines) == 3:
        return lines[0], lines[1], lines[2]
    return None


def parse_multiple_qa_output(output: str) -> List[Tuple[str, str, str]]:
    """
    Splits the output into blocks (separated by one or more blank lines) and
    attempts to parse each as a QA pair.

    Returns a list of successfully parsed QA tuples.
    """
    blocks = re.split(r"\n\s*\n", output.strip())
    qa_pairs = []
    for block in blocks:
        parsed = parse_qa_block(block)
        if parsed:
            qa_pairs.append(parsed)
    return qa_pairs


def generate_question_batch_for_chunks(
    chunks: List, num_questions: int = 2, difficulty: str = None
) -> List[Dict]:
    """
    Generates QA pairs for multiple chunks in batch.

    For each chunk (except the first and last), a sliding window is used for context:
    - before: previous chunk's content
    - current: current chunk's content
    - after: next chunk's content

    Each prompt instructs the model to output exactly three lines per QA pair with markers.
    Failed prompts are retried once in batch; if still unsuccessful, they are skipped.

    Returns a list of dicts with keys: "chunk_id", "question", "answer", "difficulty".
    """
    prompts = []
    chunk_ids = []

    # Prepare prompts using a sliding window
    for i in range(1, len(chunks) - 1):
        before = chunks[i - 1].page_content
        current = chunks[i].page_content
        after = chunks[i + 1].page_content
        prompt = (
            f"From the text within ==BEGIN== and ==END==, generate {num_questions} questions with answers.\n"
            "For each QA pair, output exactly three lines with no extra commentary:\n"
            "Line 1: Question: <your question>\n"
            "Line 2: Answer: <the answer>\n"
            "Line 3: Difficulty: <easy, medium, or hard>\n"
            "Do not include any additional text.\n\n"
            "==BEGIN==\n"
            f"{before}\n{current}\n{after}\n"
            "==END==\n"
        )
        prompts.append(prompt)
        chunk_ids.append(i)

    # First batch generation
    outputs = batch_generate(prompts)
    results = [None] * len(outputs)
    failed_indices = []

    # Parse each output
    for idx, output in enumerate(outputs):
        qa_pairs = parse_multiple_qa_output(output)
        if qa_pairs is None or len(qa_pairs) < num_questions:
            failed_indices.append(idx)
        else:
            results[idx] = qa_pairs[:num_questions]

    # Retry failed prompts in batch
    if failed_indices:
        print(f"Retrying {len(failed_indices)} failed prompt(s)...")
        retry_prompts = [prompts[i] for i in failed_indices]
        retry_outputs = batch_generate(retry_prompts)
        for j, idx in enumerate(failed_indices):
            qa_pairs = parse_multiple_qa_output(retry_outputs[j])
            if qa_pairs is not None and len(qa_pairs) >= num_questions:
                results[idx] = qa_pairs[:num_questions]
            else:
                results[idx] = None  # Mark as failed

    # Build final output, skipping prompts that failed even after retry
    final_questions = []
    for i, qa_list in enumerate(results):
        if qa_list is not None:
            for qa in qa_list:
                final_questions.append(
                    {
                        "chunk_id": chunk_ids[i],
                        "question": qa[0],
                        "answer": qa[1],
                        "difficulty": qa[2],
                    }
                )
    return final_questions


# Generate QA pairs in batch (using a sliding window over the chunks)
all_questions = generate_question_batch_for_chunks(
    chunks, num_questions=2, difficulty="medium"
)
print(f"Generated {len(all_questions)} QA pairs.")

# Save the QA pairs to a JSON file
questions_path = os.path.join("saved_data", "questions.json")
with open(questions_path, "w") as f:
    json.dump(all_questions, f, indent=2)
print(f"Saved questions to {questions_path}")
@@ -0,0 +1,2 @@
unsloth_compiled_cache
0_*
File diff suppressed because it is too large
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -0,0 +1,17 @@
datasets
faiss-cpu
langchain
langchain-community
Markdown
tokenizers
transformers
unsloth==2025.3.6
unsloth_zoo==2025.3.4
unstructured
vllm
wandb

ipykernel
python-dotenv
loguru
gradio
@@ -0,0 +1,540 @@
"""
RL helpers module for handling tool-based conversations.
This module provides utility functions for handling chat-based tool interactions
and calculating rewards based on the quality of responses.
"""

import asyncio
import json
import re
from dataclasses import dataclass
from datetime import datetime

import nest_asyncio
import torch

from search_module import get_qa_dataset, search

nest_asyncio.apply()
from typing import Callable, List

from trl.trainer.grpo_trainer import apply_chat_template


# Constants for prompts and tool definitions
def get_system_prompt():
    """Get the system prompt with current date."""
    current_date = datetime.now().strftime("%d %b %Y")
    return f"""Cutting Knowledge Date: December 2023
Today Date: {current_date}

When you receive a tool call response, use the output to format an answer to the original user question.

You are a helpful assistant with tool calling capabilities.
"""


# Tool definition for search corpus
SEARCH_TOOL_DEFINITION = {
    "type": "function",
    "function": {
        "name": "search_corpus",
        "description": "Search over the knowledge corpus with a given query",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "The query to search the knowledge corpus with",
                },
            },
            "required": ["query"],
        },
    },
}


def build_user_prompt(q):
    """
    Build a user prompt with the question and search tool definition.

    Args:
        q (str): The question to ask

    Returns:
        str: Formatted user prompt
    """
    user_prompt = f"""You are a research assistant, and you use the search_corpus tool to find answers to questions.
Given a question, answer it using by doing searches using the search_corpus tool.
To use the search_corpus tool, respond with a JSON for a function call with its proper arguments.

You may also reason in any message, thinking step by step about how to answer the question. Wrap your reasoning in <reasoning> and </reasoning> tags.

{json.dumps(SEARCH_TOOL_DEFINITION, indent=2)}

Question: {q}
"""
    return user_prompt


def get_initial_chat(question):
    """
    Initialize a chat state with the question.

    Args:
        question (str): The question to ask

    Returns:
        dict: Initial chat state with system and user messages
    """
    return {
        "messages": [
            {"role": "system", "content": get_system_prompt()},
            {"role": "user", "content": build_user_prompt(question)},
        ]
    }


def extract_json_objects(text):
    """
    Extracts JSON objects (dictionaries) from a text that may contain multiple JSON objects.

    Args:
        text (str): The input text possibly containing JSON objects.

    Returns:
        list: A list of parsed JSON objects (dictionaries) extracted from the text.
    """
    results = []
    length = len(text)
    i = 0

    while i < length:
        # Look for the start of a JSON object
        if text[i] == "{":
            start = i
            stack = 1
            i += 1
            # Continue until we find the matching closing brace
            while i < length and stack > 0:
                if text[i] == "{":
                    stack += 1
                elif text[i] == "}":
                    stack -= 1
                i += 1
            # Only attempt to decode if the braces are balanced
            if stack == 0:
                candidate = text[start:i]
                try:
                    obj = json.loads(candidate)
                    # Optionally, ensure it's a dictionary if that's what you expect
                    if isinstance(obj, dict):
                        results.append(obj)
                except json.JSONDecodeError:
                    # If it's not valid JSON, skip it.
                    pass
        else:
            i += 1
    return results


def remove_reasoning(text: str) -> str:
    """
    Removes all content between <reasoning> and </reasoning> tags,
    including the tags themselves.

    Parameters:
        text (str): The input text that may contain <reasoning>...</reasoning> tags.

    Returns:
        str: The text with the tags and their content removed.
    """
    # The regex pattern matches from <reasoning> to </reasoning> non-greedily.
    pattern = r"<reasoning>.*?</reasoning>"
    cleaned_text = re.sub(pattern, "", text, flags=re.DOTALL)
    return cleaned_text


def run_agent_generations(generate_fn, tokenizer, chat_states):
    """
    Run generation for chat states requiring assistant responses.

    Args:
        generate_fn: Function to generate responses
        tokenizer: Tokenizer for processing text
        chat_states: List of chat states

    Returns:
        list: Updated chat states
    """
    prompts = []
    batch_indices = []
    # Prepare prompts for chat states needing an assistant response.
    for idx, chat_state in enumerate(chat_states):
        if chat_state.get("finished"):
            continue

        if chat_state["messages"][-1]["role"] in ["ipython", "user"]:
            prompt = apply_chat_template(chat_state, tokenizer=tokenizer)["text"]
            prompts.append(prompt)
            batch_indices.append(idx)

    if prompts:
        responses = generate_fn(prompts)
        for i, idx in enumerate(batch_indices):
            chat_state = chat_states[idx]
            full_response = responses[i].outputs[0].text
            assistant_response = full_response.split(
                "<|start_header_id|>assistant<|end_header_id|>"
            )[-1]
            chat_state["messages"].append(
                {"role": "assistant", "content": assistant_response}
            )
    return chat_states


def check_finished_chats(chat_states):
    """
    Check which chat states are finished (no more function calls).

    Args:
        chat_states: List of chat states

    Returns:
        list: Updated chat states with finished flag
    """
    for chat_state in chat_states:
        if chat_state.get("finished"):
            continue
        assert (
            chat_state["messages"][-1]["role"] == "assistant"
        ), "Expected the last role to be assistant"
        assistant_response = chat_state["messages"][-1]["content"]
        function_calls = extract_json_objects(assistant_response)
        if len(function_calls) == 0:
            chat_state["finished"] = True
    return chat_states


def run_tool_calls(chat_states):
    """
    Execute tool calls found in chat states.

    Args:
        chat_states: List of chat states

    Returns:
        list: Updated chat states with tool call results
    """
    for chat_state in chat_states:
        if chat_state.get("finished"):
            continue
        assert (
            chat_state["messages"][-1]["role"] == "assistant"
        ), "Expected the last role to be assistant to run tool calls"
        try:
            assistant_response = chat_state["messages"][-1]["content"]
            function_calls = extract_json_objects(assistant_response)
            if len(function_calls) > 1:
                raise ValueError(
                    "Expected only one function call in assistant response"
                )
            elif len(function_calls) == 1:
                function_call = function_calls[0]
                query = function_call["function"]["parameters"]["query"]
                results = search(query, return_type=str, results=2)
                chat_state["messages"].append({"role": "ipython", "content": results})
        except Exception as e:
            chat_state["messages"].append(
                {"role": "system", "content": f"Error during post-processing: {str(e)}"}
            )
            chat_state["finished"] = True
    return chat_states


def get_mask(text, tokenizer):
    encoding = tokenizer(text, add_special_tokens=False)
    start_header_id = tokenizer.convert_tokens_to_ids("<|start_header_id|>")
    assistant_token = tokenizer.convert_tokens_to_ids("assistant")
    end_header_id = tokenizer.convert_tokens_to_ids("<|end_header_id|>")
    eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
    assistant_ranges = []
    i = 0
    while i < len(encoding.input_ids) - 1:
        if (
            encoding.input_ids[i] == start_header_id
            and encoding.input_ids[i + 1] == assistant_token
        ):
            i += 2
            while (
                i < len(encoding.input_ids) and encoding.input_ids[i] != end_header_id
            ):
                i += 1
            i += 2
            start_idx = i
            while i < len(encoding.input_ids) and encoding.input_ids[i] != eot_id:
                i += 1
            end_idx = i
            assistant_ranges.append((start_idx, end_idx))
        else:
            i += 1
    mask = [0] * len(encoding.input_ids)
    for start_idx, end_idx in assistant_ranges:
        for idx in range(start_idx, end_idx):
            mask[idx] = 1
    return torch.tensor(mask, dtype=torch.int)


def check_exceeded_max_new_tokens(chat_states, max_new_tokens, tokenizer):
    for chat_state in chat_states:
        if chat_state.get("finished"):
            continue
        initial_length = chat_state["initial_length"]
        new_length = get_chat_num_tokens(chat_state, tokenizer)
        if new_length - initial_length > max_new_tokens:
            chat_state["finished"] = True
    return chat_states


@dataclass
class AgenticOutputs:
    prompt_tokens: list[torch.Tensor]
    response_tokens: list[torch.Tensor]
    response_masks: list[torch.Tensor]
    final_response_str: list[str]
    full_chat_states: list[dict]


def get_chat_num_tokens(chat_state, tokenizer):
    chat_text = apply_chat_template(chat_state, tokenizer=tokenizer)["text"]
    return (
        tokenizer(chat_text, add_special_tokens=False, return_tensors="pt")["input_ids"]
        .squeeze()
        .shape[0]
    )


def run_agent(
    generate_fn, tokenizer, questions, max_generations=5, max_new_tokens=4096
):
    """
    Run the agent to completion for a batch of questions.

    Args:
        generate_fn: Function to generate model responses
        tokenizer: Tokenizer for processing text
        batch: Batch of data containing questions
        max_generations: Maximum number of generation steps

    Returns:
        list: Final answers for each question
    """
    chat_states = [get_initial_chat(q) for q in questions]
    # set the initial_prompt length
    for chat_state in chat_states:
        chat_state["initial_length"] = get_chat_num_tokens(chat_state, tokenizer)

    # agent loop
    for i in range(max_generations):
        chat_states = run_agent_generations(generate_fn, tokenizer, chat_states)
        chat_states = check_finished_chats(chat_states)
        chat_states = run_tool_calls(chat_states)
        chat_states = check_exceeded_max_new_tokens(
            chat_states, max_new_tokens, tokenizer
        )

    answers = []
    for chat in chat_states:
        answers.append(chat["messages"][-1]["content"])

    def split_prompt_assistant(convo_text):
        marker = "<|start_header_id|>assistant<|end_header_id|>"
        idx = convo_text.find(marker)
        if idx == -1:
            raise ValueError("Could not find assistant marker in conversation text.")
            return convo_text, ""
        # Include the marker in the prompt by slicing up to the end of the marker.
        prompt = convo_text[: idx + len(marker)]
        # The assistant response is everything after the marker.
        assistant_response = convo_text[idx + len(marker) :]
        return prompt, assistant_response

    str_chats = [
        apply_chat_template(chat, tokenizer=tokenizer)["text"] for chat in chat_states
    ]
    prompt_toks, response_toks, response_masks = [], [], []
    for str_chat in str_chats:
        prompt, response = split_prompt_assistant(str_chat)
        prompt_toks.append(
            tokenizer(prompt, add_special_tokens=False, return_tensors="pt")[
                "input_ids"
            ].squeeze()
        )
        response_toks.append(
            tokenizer(response, add_special_tokens=False, return_tensors="pt")[
                "input_ids"
            ].squeeze()[:max_new_tokens]
        )
        mask = get_mask(str_chat, tokenizer)[len(prompt_toks[-1]) :][:max_new_tokens]

        response_masks.append(mask)

    final_response_str = [chat["messages"][-1]["content"] for chat in chat_states]
    full_chat_states = chat_states
    agentic_outputs = AgenticOutputs(
        prompt_tokens=prompt_toks,
        response_tokens=response_toks,
        response_masks=response_masks,
        final_response_str=final_response_str,
        full_chat_states=full_chat_states,
    )

    return agentic_outputs


# Verification
async def check_correctness(question, student_answer, answer):
    """
    Calculate reward for a given student answer.

    Args:
        question (str): The original question
        student_answer (str): The model's answer
        answer (str): The ground truth answer

    Returns:
        float: Reward value (1 for correct, 0 for incorrect)
    """
    # log to "./reward_func.log"
    with open("reward_func.log", "a") as f:
        f.write("\n" + "==" * 40 + "\n\n")
        f.write(f"Question: {question}\n")
        f.write(f"Student Answer: {student_answer}\n")
        f.write(f"Answer: {answer}\n")
        if student_answer.startswith("Error during"):
            f.write(f"failed function call")
            return 0
        if len(student_answer) < 5:
            f.write(f"failed Too short answer\n")
            return 0
        else:
            f.write(f"last message didn't fail\n")
            student_answer_clean = remove_reasoning(student_answer)
            is_correct = await verify(student_answer_clean, question, answer)
            f.write(f"Is Correct: {is_correct}, so reward is {int(is_correct)}\n")
            return 1 if is_correct else 0


def check_student_answers(
    questions: List[str],
    answers: List[str],
    student_answers: List[str],
    vllm_generate_func: Callable[[List[str]], List[str]],
    tokenizer,
    log_file: str = "qa_log.txt",
) -> List[bool]:
    """
    Evaluates a list of student answers against the true answers using a vLLM generate function.
    The function applies the chat template to each prompt before passing it to the generate function.
    It also appends the details of each QA pair and the verifier's response to a log file.

    Args:
        questions: A list of strings representing the questions.
        answers: A list of strings representing the correct answers.
        student_answers: A list of strings containing the student's answers.
        vllm_generate_func: A function that takes a list of chat-formatted prompt strings and returns a list of generated outputs.
        tokenizer: The tokenizer used to apply the chat template.
        log_file: Optional; path to the file where the QA pairs and verification responses will be appended.

    Returns:
        A list of booleans indicating whether each student's answer is correct.
    """
    if not (len(questions) == len(answers) == len(student_answers)):
        raise ValueError(
            "The number of questions, answers, and student answers must be equal."
        )

    prompts = []
    for question, answer, student_ans in zip(questions, answers, student_answers):
        # Construct the plain text prompt for each QA pair.
        prompt_text = (
            "You are grading a student's answer. For the following question, "
            "compare the student's answer to the correct answer. Reply with 'Yes' if the student's answer is correct, or 'No' if it is completely incorrect.\n\n"
            f"Question: {question}\n"
            f"Correct Answer: {answer}\n"
            f"Student Answer: {student_ans}\n"
        )
        # Apply the chat template to the prompt.
        formatted_prompt = tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt_text}],
            tokenize=False,
            add_generation_prompt=True,
        )
        prompts.append(formatted_prompt)

    # Get the model responses in batch (each response should ideally be "Yes" or "No")
    responses = vllm_generate_func(prompts)
    responses_text = [response.outputs[0].text for response in responses]

    # Evaluate each response and mark as correct if "yes" appears in the answer (case-insensitive)
    results = []
    for response in responses_text:
        results.append("yes" in response.lower())

    # Append the QA details and verifier's response to the specified log file
    with open(log_file, "a") as file:
        for question, answer, student_ans, verifier_response in zip(
            questions, answers, student_answers, responses_text
        ):
            file.write("Question: " + question + "\n")
            file.write("Correct Answer: " + answer + "\n")
            file.write("Student Answer: " + student_ans + "\n")
            file.write("Verifier said: " + verifier_response + "\n")
            file.write("-" * 40 + "\n")

    return results


def build_reward_correctness_fn(generate_fn, tokenizer):
    def reward_correctness(prompts, completions, **reward_kwargs):
        teacher_answers = reward_kwargs["answer"]
        student_answers = [
            completion["messages"][-1]["content"] for completion in completions
        ]

        correct = check_student_answers(
            prompts,
            teacher_answers,
            student_answers,
            vllm_generate_func=generate_fn,
            tokenizer=tokenizer,
        )
        return correct

    return reward_correctness


def reward_formatting(prompts, completions, **reward_kwargs):
    # make sure full chats doesn't have any error function calls
    has_error = [False] * len(completions)
    for i, chat in enumerate(completions):
        for message in chat["messages"]:
            if "Error during" in message["content"]:
                has_error[i] = True
                break
    return [0.7 if not e else 0 for e in has_error]


def run_eval(generate_fn, verify_fn, tokenizer):
    train_dataset, test_dataset = get_qa_dataset()
    questions = test_dataset["prompt"]
    agentic_outputs = run_agent(generate_fn, tokenizer, questions)
    full_chat_states = agentic_outputs.full_chat_states
    final_responses = agentic_outputs.final_response_str
    rewards = verify_fn(questions, full_chat_states, answer=test_dataset["answer"])

    print("RESULTS:")
    print("percentage of correct answers:", sum(rewards) / len(rewards))
    print("=" * 30)

    return full_chat_states
@ -0,0 +1,194 @@
|
|||||||
|
"""
|
||||||
|
Search module for RL training loop.
|
||||||
|
This module provides functions to search through vectorized documents and retrieve question-answer pairs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pickle
|
||||||
|
import json
|
||||||
|
import random
|
||||||
|
import asyncio
|
||||||
|
from typing import List, Tuple, Optional, Union, Dict, Any
|
||||||
|
from enum import Enum
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from langchain.vectorstores import FAISS
|
||||||
|
from datasets import Dataset
|
||||||
|
from embeddings import CustomHuggingFaceEmbeddings
|
||||||
|
|
||||||
|
|
||||||
|
# Load pre-saved vectorstore
|
||||||
|
def load_vectorstore():
|
||||||
|
"""Load the pre-saved FAISS index"""
|
||||||
|
try:
|
||||||
|
import os
|
||||||
|
|
||||||
|
embeddings = CustomHuggingFaceEmbeddings()
|
||||||
|
# Load the FAISS index with absolute path
|
||||||
|
index_path = os.path.join(
|
||||||
|
os.path.dirname(os.path.abspath(__file__)), "faiss_index"
|
||||||
|
)
|
||||||
|
print(f"Loading FAISS index from: {index_path}")
|
||||||
|
vectorstore = FAISS.load_local(
|
||||||
|
index_path, embeddings, allow_dangerous_deserialization=True
|
||||||
|
)
|
||||||
|
print("Successfully loaded FAISS index")
|
||||||
|
return vectorstore
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error loading vectorstore: {e}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# Load the vectorstore when module is imported
|
||||||
|
try:
|
||||||
|
vectorstore = load_vectorstore()
|
||||||
|
if vectorstore is None:
|
||||||
|
print("Warning: FAISS vectorstore could not be loaded.")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error loading vectorstore: {e}")
|
||||||
|
vectorstore = None
|
||||||
|
|
||||||
|
|
||||||
|
def search(query: str, return_type=str, results: int = 5) -> Union[str, List[str]]:
|
||||||
|
"""
|
||||||
|
Search for relevant chunks using similarity search.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: The search query
|
||||||
|
return_type: Return as string or list (default: str)
|
||||||
|
results: Number of results to return (default: 5)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Results as string or list depending on return_type
|
||||||
|
"""
|
||||||
|
if vectorstore is None:
|
||||||
|
raise ValueError("Vectorstore not loaded. Please ensure FAISS index exists.")
|
||||||
|
|
||||||
|
search_results = vectorstore.similarity_search(query, k=results)
|
||||||
|
|
||||||
|
if return_type == str:
|
||||||
|
str_results = ""
|
||||||
|
for idx, result in enumerate(search_results, start=1):
|
||||||
|
str_results += f"Result {idx}:\n"
|
||||||
|
str_results += result.page_content + "\n"
|
||||||
|
str_results += "------\n"
|
||||||
|
return str_results
|
||||||
|
elif return_type == list:
|
||||||
|
return [result.page_content for result in search_results]
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid return_type. Use str or list.")
|
||||||
|
|
||||||
|
|
||||||
|
# Load questions from saved data
|
||||||
|
def load_qa_data():
|
||||||
|
"""Load the pre-generated questions and document chunks"""
|
||||||
|
try:
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Get absolute paths to data files
|
||||||
|
base_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
chunks_path = os.path.join(base_dir, "saved_data", "chunks.pkl")
|
||||||
|
questions_path = os.path.join(base_dir, "saved_data", "questions.json")
|
||||||
|
|
||||||
|
print(f"Loading chunks from: {chunks_path}")
|
||||||
|
        print(f"Loading questions from: {questions_path}")

        # Load the chunks
        with open(chunks_path, "rb") as f:
            chunks = pickle.load(f)

        # Load the questions
        with open(questions_path, "r") as f:
            questions = json.load(f)

        print(f"Successfully loaded {len(chunks)} chunks and {len(questions)} questions")
        return chunks, questions
    except Exception as e:
        print(f"Error loading QA data: {e}")
        import traceback

        traceback.print_exc()
        return None, None


# Load chunks and questions when the module is imported
try:
    chunks, questions = load_qa_data()
    if chunks is None or questions is None:
        print("Warning: Could not load QA data.")
except Exception as e:
    print(f"Error initializing QA data: {e}")
    chunks, questions = None, None


def get_question_answer(idx: Optional[int] = None, return_both: bool = True) -> Union[dict, str]:
    """
    Get a question-answer pair either by index or randomly.

    Args:
        idx: Index of the question to retrieve (if None, selects a random question)
        return_both: Whether to return both question and answer (default: True)

    Returns:
        A dict with "question" and "answer" keys if return_both=True,
        otherwise just the question string.
    """
    if questions is None:
        raise ValueError("Questions not loaded. Please ensure questions.json exists.")

    if idx is None:
        # Select a random question
        qa_pair = random.choice(questions)
    elif 0 <= idx < len(questions):
        # Select question by index
        qa_pair = questions[idx]
    else:
        raise ValueError(f"Index out of range. Must be between 0 and {len(questions) - 1}")

    question = qa_pair["question"]
    answer = qa_pair["answer"]

    if return_both:
        return {"question": question, "answer": answer}
    return question


# Function to get the total number of questions
def get_question_count() -> int:
    """Get the total number of available questions."""
    if questions is None:
        raise ValueError("Questions not loaded. Please ensure questions.json exists.")
    return len(questions)


def get_qa_dataset():
    """
    Return HuggingFace train/test Datasets containing question and answer pairs.

    The datasets are constructed from the loaded questions data (questions.json).
    Each element is a dictionary that includes at least:
      - "prompt": the question text (renamed from "question")
      - "answer": the corresponding answer text
    Additional keys present in the original questions data are preserved.

    Returns:
        A (train_dataset, test_dataset) tuple of HuggingFace Dataset objects.
    """
    if questions is None:
        raise ValueError("Questions not loaded. Please ensure questions.json exists.")

    qa_dataset = Dataset.from_list(questions)
    full_dataset = qa_dataset.shuffle(seed=42)
    split = full_dataset.train_test_split(test_size=0.1, seed=42)
    train_dataset = split["train"]
    test_dataset = split["test"]
    # Rename the "question" column to "prompt" (the field name the trainer expects)
    train_dataset = train_dataset.rename_column("question", "prompt")
    test_dataset = test_dataset.rename_column("question", "prompt")
    return train_dataset, test_dataset
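A quick smoke test for the helpers above (a minimal sketch; it assumes this file is saved as the project's search_module.py and that questions.json and the chunk pickle are present on disk):

# Hypothetical smoke test for the QA helpers above (assumed module name and data files).
from search_module import get_qa_dataset, get_question_answer, get_question_count

print(get_question_count())        # total number of QA pairs
print(get_question_answer(0))      # {"question": ..., "answer": ...}
train_ds, test_ds = get_qa_dataset()
print(train_ds.column_names)       # includes "prompt" and "answer"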
@ -0,0 +1,200 @@
#!/usr/bin/env python3
"""
Simple command-line Q&A environment for testing with search functionality.
"""

import asyncio
import json
import random
import sys
import time
from typing import Any, Dict

# Import our search module (ensure these functions follow the new interfaces)
from search_module import get_question_answer, get_question_count, search


class SimpleQAEnvironment:
    """Simple command-line environment for Q&A with search capability."""

    def __init__(self):
        self.score = {"correct": 0, "incorrect": 0, "total": 0}
        self.session_data = []
        self.current_question = None

    def display_welcome(self):
        """Display welcome message and instructions."""
        print("\n===== Search & Answer Environment =====")
        print("Answer questions using the search tool to find relevant information.")
        print("Type 'q' to quit, 'h' for help.\n")

    def display_help(self):
        """Display help information."""
        print("\n===== Commands =====")
        print("n - Get a new question")
        print("s <query> - Search for information (e.g., s program launch date)")
        print("a <answer> - Submit your answer")
        print("h - Display this help message")
        print("q - Quit the program\n")

    def display_question(self, question: str):
        """Display the current question."""
        print("\n===== QUESTION =====")
        print(question)
        print("=====================\n")

    def get_new_question(self) -> str:
        """Get a new random question and set it as current."""
        total_questions = get_question_count()
        question_id = random.randint(0, total_questions - 1)

        # Updated to match the new interface: get_question_answer returns a dict.
        qa = get_question_answer(question_id)
        question = qa["question"]
        correct_answer = qa["answer"]

        question_data = {
            "id": question_id,
            "question": question,
            "correct_answer": correct_answer,
            "start_time": time.time(),
            "searches": [],
        }
        self.current_question = question_data
        return question

    def perform_search(self, query: str):
        """Perform a search with the given query."""
        if not query:
            print("Please provide a search query.")
            return

        try:
            print("\n===== SEARCH RESULTS =====")
            results = search(query)
            print(results)
            print("==========================\n")

            # Record search in current question data if available.
            if self.current_question is not None:
                self.current_question["searches"].append(query)
        except Exception as e:
            print(f"Error searching: {str(e)}")

    async def process_answer(self, user_answer: str):
        """Process and verify the user's answer."""
        if self.current_question is None:
            print("Please get a question first.")
            return

        if not user_answer:
            print("Please provide an answer.")
            return

        # Record answer and calculate time taken.
        self.current_question["user_answer"] = user_answer
        self.current_question["end_time"] = time.time()
        self.current_question["time_taken"] = (
            self.current_question["end_time"] - self.current_question["start_time"]
        )

        try:
            print("\nVerifying your answer...")
            # NOTE: verify() and router are expected to come from the project's
            # verification helpers; they are not defined in this file.
            correct = await verify(
                user_answer,
                self.current_question["question"],
                self.current_question["correct_answer"],
                router,
            )

            # Update score and inform the user.
            self.score["total"] += 1
            if correct:
                self.score["correct"] += 1
                print("\n✓ Your answer is CORRECT!")
            else:
                self.score["incorrect"] += 1
                print("\n✗ Your answer is INCORRECT.")
                print(f"\nThe correct answer is:\n{self.current_question['correct_answer']}")

            print(f"\nScore: {self.score['correct']}/{self.score['total']}")

            # Record the result and add the current question to the session data.
            self.current_question["is_correct"] = correct
            self.session_data.append(self.current_question)

            # Clear the current question.
            self.current_question = None
        except Exception as e:
            print(f"Error verifying answer: {str(e)}")

    def save_session(self):
        """Save the session data to a file."""
        if not self.session_data:
            return

        timestamp = time.strftime("%Y%m%d_%H%M%S")
        filename = f"qa_session_{timestamp}.json"

        session_data = {
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
            "score": self.score,
            "questions": self.session_data,
        }

        try:
            with open(filename, "w") as f:
                json.dump(session_data, f, indent=2)
            print(f"\nSession data saved to {filename}")
        except Exception as e:
            print(f"Error saving session data: {str(e)}")

    async def run(self):
        """Run the main command loop."""
        self.display_welcome()

        while True:
            command = input("\n> ").strip()

            if not command:
                continue

            # Process commands.
            if command.lower() == "q":
                break
            elif command.lower() == "h":
                self.display_help()
            elif command.lower() == "n":
                question = self.get_new_question()
                self.display_question(question)
            elif command.lower().startswith("s "):
                query = command[2:].strip()
                self.perform_search(query)
            elif command.lower().startswith("a "):
                answer = command[2:].strip()
                await self.process_answer(answer)
            else:
                print("Unknown command. Type 'h' for help.")

        # Save session data on exit.
        self.save_session()
        print("\nThank you for using the Q&A environment!")


async def main():
    """Main function to start the application."""
    env = SimpleQAEnvironment()
    await env.run()


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        print("\nProgram terminated by user.")
    except Exception as e:
        print(f"\nError: {str(e)}")
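The process_answer() call above relies on verify() and router being provided elsewhere in the project. A minimal stand-in with the same call shape is sketched below (hypothetical names and logic, not the project's actual verifier):

# Hypothetical stand-in for the external verifier used above:
#   correct = await verify(user_answer, question, correct_answer, router)
# "router" is assumed to be an LLM client in the real project; this stub only
# does a naive substring check so the CLI can run end-to-end.
async def verify(user_answer: str, question: str, correct_answer: str, router=None) -> bool:
    return correct_answer.strip().lower() in user_answer.strip().lower()

router = None  # placeholder for the real LLM client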
@ -0,0 +1,189 @@
# %%
from unsloth import FastLanguageModel

# %%
import torch
from unsloth import is_bfloat16_supported

max_seq_length = 4096 * 2  # Can increase for longer reasoning traces
lora_rank = 64  # Larger rank = smarter, but slower

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="meta-llama/meta-Llama-3.1-8B-Instruct",
    max_seq_length=max_seq_length,
    load_in_4bit=True,  # False for LoRA 16bit
    fast_inference=True,  # Enable vLLM fast inference
    max_lora_rank=lora_rank,
    gpu_memory_utilization=0.6,  # Reduce if out of memory
)

model = FastLanguageModel.get_peft_model(
    model,
    r=lora_rank,  # Choose any number > 0! Suggested 8, 16, 32, 64, 128
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],  # Remove QKVO if out of memory
    lora_alpha=lora_rank,
    use_gradient_checkpointing="unsloth",  # Enable long-context finetuning
    random_state=3407,
)

# %%
import re

from datasets import Dataset, load_dataset
from rl_helpers import get_qa_dataset
from search_module import get_question_answer, get_question_count, search

train_dataset, test_dataset = get_qa_dataset()

# %% [markdown]
# <a name="Train"></a>
# ### Train the model
#
# Now set up the GRPO Trainer and all configurations!

# %%
import os

os.environ["WANDB_PROJECT"] = "bootstrap-search-rl"

# %%
# from UnslothGRPOTrainerTemp import UnslothGRPOConfig, _UnslothGRPOTrainer
import UnslothGRPOTrainerTemp

training_args = UnslothGRPOTrainerTemp.UnslothGRPOConfig(
    use_vllm=True,  # use vLLM for fast inference!
    use_agentic_generate=True,  # use agentic generation
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    optim="paged_adamw_8bit",
    logging_steps=1,
    bf16=is_bfloat16_supported(),
    fp16=not is_bfloat16_supported(),
    per_device_train_batch_size=8,
    gradient_accumulation_steps=1,  # Increase to 4 for smoother training
    num_generations=8,  # Decrease if out of memory
    max_prompt_length=1024,
    max_completion_length=1024,
    # num_train_epochs=1,  # Set to 1 for a full training run
    max_steps=101,
    save_steps=50,
    max_grad_norm=0.1,
    report_to="none",  # Can use Weights & Biases
    output_dir="full_local_training",
)

# %%
import rl_helpers

# importlib.reload(rl_helpers)


def agentic_generate(
    prompts: list[str],
    generate_fn,
    max_generations: int = 6,
):
    return run_agent(generate_fn, tokenizer, prompts, max_generations)


model.agentic_generate = agentic_generate


from vllm import SamplingParams

verifier_sampling_params = SamplingParams(
    temperature=0.1,
    top_p=0.95,
    max_tokens=4096,
)


def verifier_generate_fn(inputs):
    return model.fast_generate(
        inputs,
        sampling_params=verifier_sampling_params,
    )


run_agent = rl_helpers.run_agent
reward_correctness = rl_helpers.build_reward_correctness_fn(
    verifier_generate_fn,
    tokenizer,
)
reward_formatting = rl_helpers.reward_formatting

trainer = UnslothGRPOTrainerTemp.UnslothGRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        reward_correctness,
        reward_formatting,
    ],
    args=training_args,
    train_dataset=train_dataset,
)

# %%
trainer.train()

# %% [markdown]
# <a name="Inference"></a>
# ### Inference
# Now let's benchmark the model we trained!

# %%
import rl_helpers
from vllm import SamplingParams

sampling_params = SamplingParams(
    temperature=0.5,
    top_p=0.95,
    max_tokens=4096,
)


def eval_generate_fn(inputs):
    return model.fast_generate(
        inputs,
        sampling_params=sampling_params,
        lora_request=model.load_lora("full_local_training/checkpoint-101"),  # load the trained LoRA
    )


rl_helpers.run_eval(
    generate_fn=eval_generate_fn,
    verify_fn=reward_correctness,
    tokenizer=tokenizer,
)


# %%
# Evaluate without the LoRA adapter (base model baseline) for comparison.
def eval_generate_fn(inputs):
    return model.fast_generate(
        inputs,
        sampling_params=sampling_params,
    )


rl_helpers.run_eval(
    generate_fn=eval_generate_fn,
    verify_fn=reward_correctness,
    tokenizer=tokenizer,
)
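For reference, the reward functions passed to the trainer above are expected to score a batch of completions. A minimal custom reward of that shape is sketched below; the (prompts, completions, **kwargs) signature returning one float per completion is an assumption about this GRPO trainer's callback interface, not the project's actual reward code:

# Hypothetical extra reward with the batch-scoring shape assumed above:
# one float per completion, higher is better.
def reward_brevity(prompts, completions, **kwargs):
    return [1.0 if len(str(c)) < 2000 else 0.0 for c in completions]

# It could then be added alongside the existing rewards, e.g.
# reward_funcs=[reward_correctness, reward_formatting, reward_brevity].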
@ -1,26 +0,0 @@
import sys,os,random,time
def f(x):return x*2 if x%2==0 else x+1
class C:
    def __init__(self,v):self.v=v
    def p(self):print("Value:",self.v)
def m(l):return [f(x) for x in l]
x=[random.randint(1,100) for _ in range(10)]
print("Original:",x)
print("Processed:",m(x))
for i in range(len(x)):
    if i%2==0:
        x[i]*=2
    elif i%3==0:
        x[i]+=3
    else:
        x[i]-=1
c=C(sum(x))
c.p()
try:
    for i in range(5):print(i,x[i],f(x[i]))
except:pass
with open("temp.txt","w") as f:f.write("Hello, world!")
while True:
    time.sleep(0.1)
    if random.random()>0.9:break
print("Done!")