From 2fec4f2f420025fc34da251973947c7c174de544 Mon Sep 17 00:00:00 2001 From: thinhlpg Date: Sun, 6 Apr 2025 22:22:32 +0700 Subject: [PATCH] refactor: change repo structure (move code from src/ to src/deepsearch) --- src/config.py => config.py | 8 ++++++-- eval.py | 2 +- inference.py | 2 +- scripts/eval_base.py | 2 +- scripts/eval_lora.py | 2 +- scripts/generate_data.py | 4 ++-- scripts/simple_qa.py | 4 ++-- src/__init__.py | 14 +++++++------- src/{ => deepsearch}/UnslothGRPOTrainerTemp.py | 2 +- src/deepsearch/__init__.py | 0 src/{ => deepsearch}/agent.py | 8 ++++---- src/{ => deepsearch}/embeddings.py | 0 src/{ => deepsearch}/evaluation.py | 8 ++++---- src/{ => deepsearch}/prompts.py | 0 src/{ => deepsearch}/rewards.py | 4 ++-- src/{ => deepsearch}/search_module.py | 4 ++-- src/{ => deepsearch}/tokenizer_adapter.py | 2 +- tests/test_agent.py | 4 ++-- tests/test_rewards.py | 2 +- tests/test_tokenizer_adapters.py | 4 ++-- train_grpo.py | 12 ++++++------ 21 files changed, 46 insertions(+), 42 deletions(-) rename src/config.py => config.py (98%) rename src/{ => deepsearch}/UnslothGRPOTrainerTemp.py (99%) create mode 100644 src/deepsearch/__init__.py rename src/{ => deepsearch}/agent.py (98%) rename src/{ => deepsearch}/embeddings.py (100%) rename src/{ => deepsearch}/evaluation.py (97%) rename src/{ => deepsearch}/prompts.py (100%) rename src/{ => deepsearch}/rewards.py (99%) rename src/{ => deepsearch}/search_module.py (98%) rename src/{ => deepsearch}/tokenizer_adapter.py (99%) diff --git a/src/config.py b/config.py similarity index 98% rename from src/config.py rename to config.py index cba9d89..e00ec1d 100644 --- a/src/config.py +++ b/config.py @@ -12,7 +12,7 @@ from vllm import SamplingParams load_dotenv(override=True) # Project paths -PROJ_ROOT = Path(__file__).resolve().parent.parent +PROJ_ROOT = Path(__file__).resolve().parent DATA_DIR = PROJ_ROOT / "data" LOG_FOLDER = PROJ_ROOT / "logs" @@ -55,7 +55,7 @@ TRAINING_CONFIG = { "logging_steps": 1, 
"per_device_train_batch_size": 8, "gradient_accumulation_steps": 1, # Increase to 4 for smoother training - "num_generations": 8, # Decrease if out of memory + "num_generations": 6, # Decrease if out of memory "max_prompt_length": 1024, "max_completion_length": 1024, "max_steps": 101, @@ -244,3 +244,7 @@ _init_logging(env=env) # Log project root on import logger.info(f"Project root path: {PROJ_ROOT}") logger.debug(f"Running in {env} environment") + + +if __name__ == "__main__": + print(PROJ_ROOT) diff --git a/eval.py b/eval.py index d623f43..25d02f5 100644 --- a/eval.py +++ b/eval.py @@ -21,7 +21,7 @@ from src import ( get_system_prompt, run_eval, ) -from src.config import MODEL_NAME, logger +from config import MODEL_NAME, logger def get_model_config(): diff --git a/inference.py b/inference.py index d9b7ec1..929e190 100644 --- a/inference.py +++ b/inference.py @@ -20,7 +20,7 @@ from src import ( format_search_results, get_system_prompt, ) -from src.search_module import load_vectorstore, search +from src.deepsearch.search_module import load_vectorstore, search def setup_model_and_tokenizer(model_path: str): diff --git a/scripts/eval_base.py b/scripts/eval_base.py index 53167e9..2857f83 100644 --- a/scripts/eval_base.py +++ b/scripts/eval_base.py @@ -20,7 +20,7 @@ from src import ( get_system_prompt, run_eval, ) -from src.config import logger +from config import logger def main(): diff --git a/scripts/eval_lora.py b/scripts/eval_lora.py index 0a70354..bdec694 100644 --- a/scripts/eval_lora.py +++ b/scripts/eval_lora.py @@ -21,7 +21,7 @@ from src import ( get_system_prompt, run_eval, ) -from src.config import logger +from config import logger def main(): diff --git a/scripts/generate_data.py b/scripts/generate_data.py index ff67313..8830686 100644 --- a/scripts/generate_data.py +++ b/scripts/generate_data.py @@ -28,8 +28,8 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain_community.document_loaders import UnstructuredMarkdownLoader 
from langchain_community.vectorstores import FAISS -from src.config import DATA_DIR, logger -from src.embeddings import CustomHuggingFaceEmbeddings +from config import DATA_DIR, logger +from src.deepsearch.embeddings import CustomHuggingFaceEmbeddings # Load your markdown file (adjust the path as needed) loader = UnstructuredMarkdownLoader("./data/mission_report.md") diff --git a/scripts/simple_qa.py b/scripts/simple_qa.py index 7467526..892d6a2 100644 --- a/scripts/simple_qa.py +++ b/scripts/simple_qa.py @@ -15,8 +15,8 @@ project_root = Path(__file__).resolve().parent.parent sys.path.append(str(project_root)) # Import our search module and config -from src.config import DATA_DIR, logger -from src.search_module import get_question_answer, get_question_count, search +from config import DATA_DIR, logger +from src.deepsearch.search_module import get_question_answer, get_question_count, search # TODO: Import verify function and router from appropriate module # TODO: Consider moving verify function to search_module.py for better organization diff --git a/src/__init__.py b/src/__init__.py index 5e9f975..0f267bd 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -4,18 +4,18 @@ Main package exports for RL helpers. 
from trl.trainer.grpo_trainer import apply_chat_template -from src.agent import Agent, extract_search_query -from src.config import logger -from src.evaluation import check_student_answers, run_eval, verify -from src.prompts import build_user_prompt, format_search_results, get_system_prompt -from src.rewards import ( +from config import logger +from src.deepsearch.agent import Agent, extract_search_query +from src.deepsearch.evaluation import check_student_answers, run_eval, verify +from src.deepsearch.prompts import build_user_prompt, format_search_results, get_system_prompt +from src.deepsearch.rewards import ( build_reward_correctness_fn, reward_em_chunk, reward_format, reward_retry, ) -from src.search_module import get_qa_dataset, search -from src.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter +from src.deepsearch.search_module import get_qa_dataset, search +from src.deepsearch.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter __all__ = [ # Prompts diff --git a/src/UnslothGRPOTrainerTemp.py b/src/deepsearch/UnslothGRPOTrainerTemp.py similarity index 99% rename from src/UnslothGRPOTrainerTemp.py rename to src/deepsearch/UnslothGRPOTrainerTemp.py index 18ca2bf..731e174 100644 --- a/src/UnslothGRPOTrainerTemp.py +++ b/src/deepsearch/UnslothGRPOTrainerTemp.py @@ -54,7 +54,7 @@ from trl.trainer.grpo_trainer import ( wandb, ) -from src.config import logger +from config import logger torch_compile_options = { "epilogue_fusion": True, diff --git a/src/deepsearch/__init__.py b/src/deepsearch/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/agent.py b/src/deepsearch/agent.py similarity index 98% rename from src/agent.py rename to src/deepsearch/agent.py index 0f87e4d..049baf6 100644 --- a/src/agent.py +++ b/src/deepsearch/agent.py @@ -9,10 +9,10 @@ from dataclasses import dataclass import torch from trl.trainer.grpo_trainer import apply_chat_template -from src.config import logger -from src.prompts 
import build_user_prompt, get_system_prompt -from src.search_module import search -from src.tokenizer_adapter import TokenizerAdapter +from config import logger +from src.deepsearch.prompts import build_user_prompt, get_system_prompt +from src.deepsearch.search_module import search +from src.deepsearch.tokenizer_adapter import TokenizerAdapter def extract_search_query(text: str) -> str | None: diff --git a/src/embeddings.py b/src/deepsearch/embeddings.py similarity index 100% rename from src/embeddings.py rename to src/deepsearch/embeddings.py diff --git a/src/evaluation.py b/src/deepsearch/evaluation.py similarity index 97% rename from src/evaluation.py rename to src/deepsearch/evaluation.py index 750b7aa..4096f2b 100644 --- a/src/evaluation.py +++ b/src/deepsearch/evaluation.py @@ -5,10 +5,10 @@ Evaluation utilities for RL training. import inspect from datetime import datetime -from src.agent import Agent -from src.config import logger -from src.search_module import get_qa_dataset -from src.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter +from src.deepsearch.agent import Agent +from config import logger +from src.deepsearch.search_module import get_qa_dataset +from src.deepsearch.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter async def verify(student_answer: str, question: str, answer: str) -> bool: diff --git a/src/prompts.py b/src/deepsearch/prompts.py similarity index 100% rename from src/prompts.py rename to src/deepsearch/prompts.py diff --git a/src/rewards.py b/src/deepsearch/rewards.py similarity index 99% rename from src/rewards.py rename to src/deepsearch/rewards.py index 3587f3c..44000a0 100644 --- a/src/rewards.py +++ b/src/deepsearch/rewards.py @@ -9,8 +9,8 @@ from difflib import SequenceMatcher import numpy as np -from src.config import LOG_FOLDER, logger -from src.evaluation import check_student_answers +from config import LOG_FOLDER, logger +from src.deepsearch.evaluation import 
check_student_answers def build_reward_correctness_fn( diff --git a/src/search_module.py b/src/deepsearch/search_module.py similarity index 98% rename from src/search_module.py rename to src/deepsearch/search_module.py index f0fad85..03149c4 100644 --- a/src/search_module.py +++ b/src/deepsearch/search_module.py @@ -9,8 +9,8 @@ import random from datasets import Dataset from langchain_community.vectorstores import FAISS -from src.config import DATA_DIR, logger -from src.embeddings import CustomHuggingFaceEmbeddings +from config import DATA_DIR, logger +from src.deepsearch.embeddings import CustomHuggingFaceEmbeddings # Load pre-saved vectorstore diff --git a/src/tokenizer_adapter.py b/src/deepsearch/tokenizer_adapter.py similarity index 99% rename from src/tokenizer_adapter.py rename to src/deepsearch/tokenizer_adapter.py index f7aad02..e44f1e8 100644 --- a/src/tokenizer_adapter.py +++ b/src/deepsearch/tokenizer_adapter.py @@ -7,7 +7,7 @@ from abc import ABC, abstractmethod import torch -from src.config import logger +from config import logger class TokenizerAdapter(ABC): diff --git a/tests/test_agent.py b/tests/test_agent.py index 0b19483..64a3a68 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -2,8 +2,8 @@ from transformers import LlamaTokenizerFast -from src.agent import Agent -from src.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter +from src.deepsearch.agent import Agent +from src.deepsearch.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter def mock_generate_fn(prompts): diff --git a/tests/test_rewards.py b/tests/test_rewards.py index ad2b9b2..8d7cf5a 100644 --- a/tests/test_rewards.py +++ b/tests/test_rewards.py @@ -4,7 +4,7 @@ Test cases for reward functions in rewards.py import pytest -from src.rewards import ( +from src.deepsearch.rewards import ( build_reward_correctness_fn, reward_em_chunk, reward_format, diff --git a/tests/test_tokenizer_adapters.py b/tests/test_tokenizer_adapters.py index 
1d1db98..198f1d9 100644 --- a/tests/test_tokenizer_adapters.py +++ b/tests/test_tokenizer_adapters.py @@ -5,8 +5,8 @@ Test module for tokenizer adapters. import torch from transformers import AutoTokenizer, LlamaTokenizerFast -from src.config import logger -from src.tokenizer_adapter import LlamaTokenizerAdapter, QwenTokenizerAdapter, R1DistilTokenizerAdapter +from config import logger +from src.deepsearch.tokenizer_adapter import LlamaTokenizerAdapter, QwenTokenizerAdapter, R1DistilTokenizerAdapter # Test conversation used across all tests TEST_CHAT = [ diff --git a/train_grpo.py b/train_grpo.py index 182a8a3..92d9dc1 100644 --- a/train_grpo.py +++ b/train_grpo.py @@ -6,12 +6,12 @@ import os from unsloth import FastLanguageModel, is_bfloat16_supported -import src.UnslothGRPOTrainerTemp as UnslothGRPOTrainerTemp +import src.deepsearch.UnslothGRPOTrainerTemp as UnslothGRPOTrainerTemp # Import reward functions from src import build_reward_correctness_fn, get_qa_dataset, reward_em_chunk, reward_format, reward_retry -from src.agent import Agent -from src.config import ( +from src.deepsearch.agent import Agent +from config import ( MODEL_CONFIG, MODEL_NAME, OUTPUT_DIR, @@ -21,7 +21,7 @@ from src.config import ( logger, update_log_path, ) -from src.rewards import ( +from src.deepsearch.rewards import ( build_reward_correctness_fn, reward_em_chunk, reward_format, @@ -29,8 +29,8 @@ from src.rewards import ( reward_search_diversity, reward_search_strategy, ) -from src.search_module import get_qa_dataset -from src.tokenizer_adapter import LlamaTokenizerAdapter, QwenTokenizerAdapter, R1DistilTokenizerAdapter +from src.deepsearch.search_module import get_qa_dataset +from src.deepsearch.tokenizer_adapter import LlamaTokenizerAdapter, QwenTokenizerAdapter, R1DistilTokenizerAdapter # Initialize training directories paths = init_training_dirs()