refactor: change repo structure (move code from src/ to src/deepsearch)

main
thinhlpg 1 month ago
parent e3163081a0
commit 2fec4f2f42

@ -12,7 +12,7 @@ from vllm import SamplingParams
load_dotenv(override=True) load_dotenv(override=True)
# Project paths # Project paths
PROJ_ROOT = Path(__file__).resolve().parent.parent PROJ_ROOT = Path(__file__).resolve().parent
DATA_DIR = PROJ_ROOT / "data" DATA_DIR = PROJ_ROOT / "data"
LOG_FOLDER = PROJ_ROOT / "logs" LOG_FOLDER = PROJ_ROOT / "logs"
@ -55,7 +55,7 @@ TRAINING_CONFIG = {
"logging_steps": 1, "logging_steps": 1,
"per_device_train_batch_size": 8, "per_device_train_batch_size": 8,
"gradient_accumulation_steps": 1, # Increase to 4 for smoother training "gradient_accumulation_steps": 1, # Increase to 4 for smoother training
"num_generations": 8, # Decrease if out of memory "num_generations": 6, # Decrease if out of memory
"max_prompt_length": 1024, "max_prompt_length": 1024,
"max_completion_length": 1024, "max_completion_length": 1024,
"max_steps": 101, "max_steps": 101,
@ -244,3 +244,7 @@ _init_logging(env=env)
# Log project root on import # Log project root on import
logger.info(f"Project root path: {PROJ_ROOT}") logger.info(f"Project root path: {PROJ_ROOT}")
logger.debug(f"Running in {env} environment") logger.debug(f"Running in {env} environment")
if __name__ == "__main__":
print(PROJ_ROOT)

@ -21,7 +21,7 @@ from src import (
get_system_prompt, get_system_prompt,
run_eval, run_eval,
) )
from src.config import MODEL_NAME, logger from config import MODEL_NAME, logger
def get_model_config(): def get_model_config():

@ -20,7 +20,7 @@ from src import (
format_search_results, format_search_results,
get_system_prompt, get_system_prompt,
) )
from src.search_module import load_vectorstore, search from src.deepsearch.search_module import load_vectorstore, search
def setup_model_and_tokenizer(model_path: str): def setup_model_and_tokenizer(model_path: str):

@ -20,7 +20,7 @@ from src import (
get_system_prompt, get_system_prompt,
run_eval, run_eval,
) )
from src.config import logger from config import logger
def main(): def main():

@ -21,7 +21,7 @@ from src import (
get_system_prompt, get_system_prompt,
run_eval, run_eval,
) )
from src.config import logger from config import logger
def main(): def main():

@ -28,8 +28,8 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import UnstructuredMarkdownLoader from langchain_community.document_loaders import UnstructuredMarkdownLoader
from langchain_community.vectorstores import FAISS from langchain_community.vectorstores import FAISS
from src.config import DATA_DIR, logger from config import DATA_DIR, logger
from src.embeddings import CustomHuggingFaceEmbeddings from src.deepsearch.embeddings import CustomHuggingFaceEmbeddings
# Load your markdown file (adjust the path as needed) # Load your markdown file (adjust the path as needed)
loader = UnstructuredMarkdownLoader("./data/mission_report.md") loader = UnstructuredMarkdownLoader("./data/mission_report.md")

@ -15,8 +15,8 @@ project_root = Path(__file__).resolve().parent.parent
sys.path.append(str(project_root)) sys.path.append(str(project_root))
# Import our search module and config # Import our search module and config
from src.config import DATA_DIR, logger from config import DATA_DIR, logger
from src.search_module import get_question_answer, get_question_count, search from src.deepsearch.search_module import get_question_answer, get_question_count, search
# TODO: Import verify function and router from appropriate module # TODO: Import verify function and router from appropriate module
# TODO: Consider moving verify function to search_module.py for better organization # TODO: Consider moving verify function to search_module.py for better organization

@ -4,18 +4,18 @@ Main package exports for RL helpers.
from trl.trainer.grpo_trainer import apply_chat_template from trl.trainer.grpo_trainer import apply_chat_template
from src.agent import Agent, extract_search_query from config import logger
from src.config import logger from src.deepsearch.agent import Agent, extract_search_query
from src.evaluation import check_student_answers, run_eval, verify from src.deepsearch.evaluation import check_student_answers, run_eval, verify
from src.prompts import build_user_prompt, format_search_results, get_system_prompt from src.deepsearch.prompts import build_user_prompt, format_search_results, get_system_prompt
from src.rewards import ( from src.deepsearch.rewards import (
build_reward_correctness_fn, build_reward_correctness_fn,
reward_em_chunk, reward_em_chunk,
reward_format, reward_format,
reward_retry, reward_retry,
) )
from src.search_module import get_qa_dataset, search from src.deepsearch.search_module import get_qa_dataset, search
from src.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter from src.deepsearch.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter
__all__ = [ __all__ = [
# Prompts # Prompts

@ -54,7 +54,7 @@ from trl.trainer.grpo_trainer import (
wandb, wandb,
) )
from src.config import logger from config import logger
torch_compile_options = { torch_compile_options = {
"epilogue_fusion": True, "epilogue_fusion": True,

@ -9,10 +9,10 @@ from dataclasses import dataclass
import torch import torch
from trl.trainer.grpo_trainer import apply_chat_template from trl.trainer.grpo_trainer import apply_chat_template
from src.config import logger from config import logger
from src.prompts import build_user_prompt, get_system_prompt from src.deepsearch.prompts import build_user_prompt, get_system_prompt
from src.search_module import search from src.deepsearch.search_module import search
from src.tokenizer_adapter import TokenizerAdapter from src.deepsearch.tokenizer_adapter import TokenizerAdapter
def extract_search_query(text: str) -> str | None: def extract_search_query(text: str) -> str | None:

@ -5,10 +5,10 @@ Evaluation utilities for RL training.
import inspect import inspect
from datetime import datetime from datetime import datetime
from src.agent import Agent from src.deepsearch.agent import Agent
from src.config import logger from config import logger
from src.search_module import get_qa_dataset from src.deepsearch.search_module import get_qa_dataset
from src.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter from src.deepsearch.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter
async def verify(student_answer: str, question: str, answer: str) -> bool: async def verify(student_answer: str, question: str, answer: str) -> bool:

@ -9,8 +9,8 @@ from difflib import SequenceMatcher
import numpy as np import numpy as np
from src.config import LOG_FOLDER, logger from config import LOG_FOLDER, logger
from src.evaluation import check_student_answers from src.deepsearch.evaluation import check_student_answers
def build_reward_correctness_fn( def build_reward_correctness_fn(

@ -9,8 +9,8 @@ import random
from datasets import Dataset from datasets import Dataset
from langchain_community.vectorstores import FAISS from langchain_community.vectorstores import FAISS
from src.config import DATA_DIR, logger from config import DATA_DIR, logger
from src.embeddings import CustomHuggingFaceEmbeddings from src.deepsearch.embeddings import CustomHuggingFaceEmbeddings
# Load pre-saved vectorstore # Load pre-saved vectorstore

@ -7,7 +7,7 @@ from abc import ABC, abstractmethod
import torch import torch
from src.config import logger from config import logger
class TokenizerAdapter(ABC): class TokenizerAdapter(ABC):

@ -2,8 +2,8 @@
from transformers import LlamaTokenizerFast from transformers import LlamaTokenizerFast
from src.agent import Agent from src.deepsearch.agent import Agent
from src.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter from src.deepsearch.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter
def mock_generate_fn(prompts): def mock_generate_fn(prompts):

@ -4,7 +4,7 @@ Test cases for reward functions in rewards.py
import pytest import pytest
from src.rewards import ( from src.deepsearch.rewards import (
build_reward_correctness_fn, build_reward_correctness_fn,
reward_em_chunk, reward_em_chunk,
reward_format, reward_format,

@ -5,8 +5,8 @@ Test module for tokenizer adapters.
import torch import torch
from transformers import AutoTokenizer, LlamaTokenizerFast from transformers import AutoTokenizer, LlamaTokenizerFast
from src.config import logger from config import logger
from src.tokenizer_adapter import LlamaTokenizerAdapter, QwenTokenizerAdapter, R1DistilTokenizerAdapter from src.deepsearch.tokenizer_adapter import LlamaTokenizerAdapter, QwenTokenizerAdapter, R1DistilTokenizerAdapter
# Test conversation used across all tests # Test conversation used across all tests
TEST_CHAT = [ TEST_CHAT = [

@ -6,12 +6,12 @@ import os
from unsloth import FastLanguageModel, is_bfloat16_supported from unsloth import FastLanguageModel, is_bfloat16_supported
import src.UnslothGRPOTrainerTemp as UnslothGRPOTrainerTemp import src.deepsearch.UnslothGRPOTrainerTemp as UnslothGRPOTrainerTemp
# Import reward functions # Import reward functions
from src import build_reward_correctness_fn, get_qa_dataset, reward_em_chunk, reward_format, reward_retry from src import build_reward_correctness_fn, get_qa_dataset, reward_em_chunk, reward_format, reward_retry
from src.agent import Agent from src.deepsearch.agent import Agent
from src.config import ( from config import (
MODEL_CONFIG, MODEL_CONFIG,
MODEL_NAME, MODEL_NAME,
OUTPUT_DIR, OUTPUT_DIR,
@ -21,7 +21,7 @@ from src.config import (
logger, logger,
update_log_path, update_log_path,
) )
from src.rewards import ( from src.deepsearch.rewards import (
build_reward_correctness_fn, build_reward_correctness_fn,
reward_em_chunk, reward_em_chunk,
reward_format, reward_format,
@ -29,8 +29,8 @@ from src.rewards import (
reward_search_diversity, reward_search_diversity,
reward_search_strategy, reward_search_strategy,
) )
from src.search_module import get_qa_dataset from src.deepsearch.search_module import get_qa_dataset
from src.tokenizer_adapter import LlamaTokenizerAdapter, QwenTokenizerAdapter, R1DistilTokenizerAdapter from src.deepsearch.tokenizer_adapter import LlamaTokenizerAdapter, QwenTokenizerAdapter, R1DistilTokenizerAdapter
# Initialize training directories # Initialize training directories
paths = init_training_dirs() paths = init_training_dirs()

Loading…
Cancel
Save