refactor: change repo structure (move code from src/ to src/deepsearch)

Branch: main
Author: thinhlpg, 1 month ago
Parent: e3163081a0
Commit: 2fec4f2f42

@@ -12,7 +12,7 @@ from vllm import SamplingParams
load_dotenv(override=True)

# Project paths
-PROJ_ROOT = Path(__file__).resolve().parent.parent
+PROJ_ROOT = Path(__file__).resolve().parent
DATA_DIR = PROJ_ROOT / "data"
LOG_FOLDER = PROJ_ROOT / "logs"
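Note on the dropped .parent above: assuming config.py moved from src/config.py to the repository root (consistent with the `from config import ...` rewrites in the hunks below), PROJ_ROOT still resolves to the same directory. A minimal sketch under that assumption:

    from pathlib import Path

    # Before: <repo>/src/config.py, so the repo root is two levels up
    old_root = Path("/repo/src/config.py").parent.parent   # /repo
    # After:  <repo>/config.py, so the repo root is one level up
    new_root = Path("/repo/config.py").parent              # /repo
    assert old_root == new_root == Path("/repo")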
@@ -55,7 +55,7 @@ TRAINING_CONFIG = {
    "logging_steps": 1,
    "per_device_train_batch_size": 8,
    "gradient_accumulation_steps": 1,  # Increase to 4 for smoother training
-    "num_generations": 8,  # Decrease if out of memory
+    "num_generations": 6,  # Decrease if out of memory
    "max_prompt_length": 1024,
    "max_completion_length": 1024,
    "max_steps": 101,
@@ -244,3 +244,7 @@ _init_logging(env=env)
# Log project root on import
logger.info(f"Project root path: {PROJ_ROOT}")
logger.debug(f"Running in {env} environment")
+
+
+if __name__ == "__main__":
+    print(PROJ_ROOT)
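The new __main__ guard lets the module double as a quick sanity check after the move; a hypothetical usage sketch (not part of the commit):

    from pathlib import Path

    import config  # assumes config.py is importable from the repo root

    # PROJ_ROOT should now be the directory containing config.py itself
    assert config.PROJ_ROOT == Path(config.__file__).resolve().parent
    print(config.PROJ_ROOT)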

@@ -21,7 +21,7 @@ from src import (
    get_system_prompt,
    run_eval,
)
-from src.config import MODEL_NAME, logger
+from config import MODEL_NAME, logger

def get_model_config():

@@ -20,7 +20,7 @@ from src import (
    format_search_results,
    get_system_prompt,
)
-from src.search_module import load_vectorstore, search
+from src.deepsearch.search_module import load_vectorstore, search

def setup_model_and_tokenizer(model_path: str):

@@ -20,7 +20,7 @@ from src import (
    get_system_prompt,
    run_eval,
)
-from src.config import logger
+from config import logger

def main():

@@ -21,7 +21,7 @@ from src import (
    get_system_prompt,
    run_eval,
)
-from src.config import logger
+from config import logger

def main():
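The same two-part rewrite repeats across these evaluation scripts: package-internal modules gain the deepsearch prefix, while config is imported as a top-level module. The pattern, assuming config.py now sits at the repository root (or elsewhere on sys.path):

    # Before the restructure:
    #   from src.config import logger
    #   from src.search_module import load_vectorstore, search

    # After the restructure:
    from config import logger  # top-level module, found via sys.path
    from src.deepsearch.search_module import load_vectorstore, search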

@@ -28,8 +28,8 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import UnstructuredMarkdownLoader
from langchain_community.vectorstores import FAISS

-from src.config import DATA_DIR, logger
-from src.embeddings import CustomHuggingFaceEmbeddings
+from config import DATA_DIR, logger
+from src.deepsearch.embeddings import CustomHuggingFaceEmbeddings

# Load your markdown file (adjust the path as needed)
loader = UnstructuredMarkdownLoader("./data/mission_report.md")

@@ -15,8 +15,8 @@ project_root = Path(__file__).resolve().parent.parent
sys.path.append(str(project_root))

# Import our search module and config
-from src.config import DATA_DIR, logger
-from src.search_module import get_question_answer, get_question_count, search
+from config import DATA_DIR, logger
+from src.deepsearch.search_module import get_question_answer, get_question_count, search

# TODO: Import verify function and router from appropriate module
# TODO: Consider moving verify function to search_module.py for better organization
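This script already appends the repo root to sys.path before importing, which is what lets the bare `from config import ...` resolve. A sketch of that mechanism (paths illustrative):

    import sys
    from pathlib import Path

    # Two levels up from this file (e.g. <repo>/scripts/demo.py -> <repo>),
    # mirroring the project_root computation in the hunk above.
    project_root = Path(__file__).resolve().parent.parent
    sys.path.append(str(project_root))

    # With <repo> on sys.path, both import styles resolve:
    #   import config                       -> <repo>/config.py
    #   import src.deepsearch.search_module -> <repo>/src/deepsearch/search_module.py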

@@ -4,18 +4,18 @@ Main package exports for RL helpers.
from trl.trainer.grpo_trainer import apply_chat_template

-from src.agent import Agent, extract_search_query
-from src.config import logger
-from src.evaluation import check_student_answers, run_eval, verify
-from src.prompts import build_user_prompt, format_search_results, get_system_prompt
-from src.rewards import (
+from config import logger
+from src.deepsearch.agent import Agent, extract_search_query
+from src.deepsearch.evaluation import check_student_answers, run_eval, verify
+from src.deepsearch.prompts import build_user_prompt, format_search_results, get_system_prompt
+from src.deepsearch.rewards import (
    build_reward_correctness_fn,
    reward_em_chunk,
    reward_format,
    reward_retry,
)
-from src.search_module import get_qa_dataset, search
-from src.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter
+from src.deepsearch.search_module import get_qa_dataset, search
+from src.deepsearch.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter

__all__ = [
    # Prompts
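Because these re-exports live in the package __init__, call sites that do `from src import run_eval, get_system_prompt, ...` (visible in the script hunks above) keep working unchanged after the move. A minimal sketch of the facade, assuming this file is src/__init__.py:

    # src/__init__.py (assumed location): re-export moved symbols so that
    # existing `from src import ...` call sites stay valid.
    from src.deepsearch.evaluation import run_eval
    from src.deepsearch.prompts import get_system_prompt

    __all__ = ["run_eval", "get_system_prompt"]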

@@ -54,7 +54,7 @@ from trl.trainer.grpo_trainer import (
    wandb,
)
-from src.config import logger
+from config import logger

torch_compile_options = {
    "epilogue_fusion": True,

@@ -9,10 +9,10 @@ from dataclasses import dataclass
import torch
from trl.trainer.grpo_trainer import apply_chat_template

-from src.config import logger
-from src.prompts import build_user_prompt, get_system_prompt
-from src.search_module import search
-from src.tokenizer_adapter import TokenizerAdapter
+from config import logger
+from src.deepsearch.prompts import build_user_prompt, get_system_prompt
+from src.deepsearch.search_module import search
+from src.deepsearch.tokenizer_adapter import TokenizerAdapter


def extract_search_query(text: str) -> str | None:

@@ -5,10 +5,10 @@ Evaluation utilities for RL training.
import inspect
from datetime import datetime

-from src.agent import Agent
-from src.config import logger
-from src.search_module import get_qa_dataset
-from src.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter
+from src.deepsearch.agent import Agent
+from config import logger
+from src.deepsearch.search_module import get_qa_dataset
+from src.deepsearch.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter


async def verify(student_answer: str, question: str, answer: str) -> bool:

@@ -9,8 +9,8 @@ from difflib import SequenceMatcher
import numpy as np

-from src.config import LOG_FOLDER, logger
-from src.evaluation import check_student_answers
+from config import LOG_FOLDER, logger
+from src.deepsearch.evaluation import check_student_answers


def build_reward_correctness_fn(

@@ -9,8 +9,8 @@ import random
from datasets import Dataset
from langchain_community.vectorstores import FAISS

-from src.config import DATA_DIR, logger
-from src.embeddings import CustomHuggingFaceEmbeddings
+from config import DATA_DIR, logger
+from src.deepsearch.embeddings import CustomHuggingFaceEmbeddings

# Load pre-saved vectorstore

@@ -7,7 +7,7 @@ from abc import ABC, abstractmethod
import torch

-from src.config import logger
+from config import logger


class TokenizerAdapter(ABC):

@@ -2,8 +2,8 @@
from transformers import LlamaTokenizerFast

-from src.agent import Agent
-from src.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter
+from src.deepsearch.agent import Agent
+from src.deepsearch.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter


def mock_generate_fn(prompts):

@@ -4,7 +4,7 @@ Test cases for reward functions in rewards.py
import pytest

-from src.rewards import (
+from src.deepsearch.rewards import (
    build_reward_correctness_fn,
    reward_em_chunk,
    reward_format,
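With tests importing through src.deepsearch, the suite should be launched from the repository root so both `config` and `src.deepsearch.*` resolve. A hypothetical invocation (the tests/ path is an assumption, not from the commit):

    import subprocess
    import sys

    # `python -m pytest` prepends the CWD to sys.path, so running this from
    # the repo root lets the bare `config` import in the tests resolve.
    subprocess.run([sys.executable, "-m", "pytest", "tests/", "-q"], check=True)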

@@ -5,8 +5,8 @@ Test module for tokenizer adapters.
import torch
from transformers import AutoTokenizer, LlamaTokenizerFast

-from src.config import logger
-from src.tokenizer_adapter import LlamaTokenizerAdapter, QwenTokenizerAdapter, R1DistilTokenizerAdapter
+from config import logger
+from src.deepsearch.tokenizer_adapter import LlamaTokenizerAdapter, QwenTokenizerAdapter, R1DistilTokenizerAdapter

# Test conversation used across all tests
TEST_CHAT = [

@@ -6,12 +6,12 @@ import os
from unsloth import FastLanguageModel, is_bfloat16_supported

-import src.UnslothGRPOTrainerTemp as UnslothGRPOTrainerTemp
+import src.deepsearch.UnslothGRPOTrainerTemp as UnslothGRPOTrainerTemp

# Import reward functions
from src import build_reward_correctness_fn, get_qa_dataset, reward_em_chunk, reward_format, reward_retry
-from src.agent import Agent
-from src.config import (
+from src.deepsearch.agent import Agent
+from config import (
    MODEL_CONFIG,
    MODEL_NAME,
    OUTPUT_DIR,
@@ -21,7 +21,7 @@ from src.config import (
    logger,
    update_log_path,
)
-from src.rewards import (
+from src.deepsearch.rewards import (
    build_reward_correctness_fn,
    reward_em_chunk,
    reward_format,
@@ -29,8 +29,8 @@ from src.rewards import (
    reward_search_diversity,
    reward_search_strategy,
)
-from src.search_module import get_qa_dataset
-from src.tokenizer_adapter import LlamaTokenizerAdapter, QwenTokenizerAdapter, R1DistilTokenizerAdapter
+from src.deepsearch.search_module import get_qa_dataset
+from src.deepsearch.tokenizer_adapter import LlamaTokenizerAdapter, QwenTokenizerAdapter, R1DistilTokenizerAdapter

# Initialize training directories
paths = init_training_dirs()
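Since the training script mixes `src.deepsearch.*` imports with the bare `config` import, it has to run with the repository root on sys.path. A hedged launch sketch (the script name train_grpo.py is an assumption, not from the commit):

    import os
    import subprocess
    import sys

    repo_root = os.path.dirname(os.path.abspath(__file__))

    # Put the repo root on PYTHONPATH so `config` and `src.deepsearch.*`
    # both resolve regardless of the caller's working directory.
    env = {**os.environ, "PYTHONPATH": repo_root}
    subprocess.run([sys.executable, "train_grpo.py"], cwd=repo_root, env=env, check=True)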
