From eebf914a81b1369c9bbf4daa9c7e872a6a28b9eb Mon Sep 17 00:00:00 2001 From: thinhlpg Date: Tue, 8 Apr 2025 05:58:03 +0000 Subject: [PATCH] refactor: moved modules from src/deepsearch to src/ --- inference.py | 2 +- scripts/generate_data.py | 2 +- scripts/simple_qa.py | 2 +- .../UnslothGRPOTrainerTemp.py | 0 src/__init__.py | 12 ++++++------ src/{deepsearch => }/agent.py | 19 ++++++++++++------- src/deepsearch/__init__.py | 0 src/{deepsearch => }/embeddings.py | 0 src/{deepsearch => }/evaluation.py | 6 +++--- src/{deepsearch => }/prompts.py | 0 src/{deepsearch => }/rewards.py | 2 +- src/{deepsearch => }/search_module.py | 2 +- src/{deepsearch => }/tokenizer_adapter.py | 0 tests/test_agent.py | 4 ++-- tests/test_rewards.py | 2 +- tests/test_tokenizer_adapters.py | 2 +- train_grpo.py | 10 +++++----- 17 files changed, 35 insertions(+), 30 deletions(-) rename src/{deepsearch => }/UnslothGRPOTrainerTemp.py (100%) rename src/{deepsearch => }/agent.py (93%) delete mode 100644 src/deepsearch/__init__.py rename src/{deepsearch => }/embeddings.py (100%) rename src/{deepsearch => }/evaluation.py (98%) rename src/{deepsearch => }/prompts.py (100%) rename src/{deepsearch => }/rewards.py (99%) rename src/{deepsearch => }/search_module.py (98%) rename src/{deepsearch => }/tokenizer_adapter.py (100%) diff --git a/inference.py b/inference.py index 929e190..d9b7ec1 100644 --- a/inference.py +++ b/inference.py @@ -20,7 +20,7 @@ from src import ( format_search_results, get_system_prompt, ) -from src.deepsearch.search_module import load_vectorstore, search +from src.search_module import load_vectorstore, search def setup_model_and_tokenizer(model_path: str): diff --git a/scripts/generate_data.py b/scripts/generate_data.py index 8830686..2c8a4b7 100644 --- a/scripts/generate_data.py +++ b/scripts/generate_data.py @@ -29,7 +29,7 @@ from langchain_community.document_loaders import UnstructuredMarkdownLoader from langchain_community.vectorstores import FAISS from config import DATA_DIR, logger -from src.deepsearch.embeddings import CustomHuggingFaceEmbeddings +from src.embeddings import CustomHuggingFaceEmbeddings # Load your markdown file (adjust the path as needed) loader = UnstructuredMarkdownLoader("./data/mission_report.md") diff --git a/scripts/simple_qa.py b/scripts/simple_qa.py index 892d6a2..85fae90 100644 --- a/scripts/simple_qa.py +++ b/scripts/simple_qa.py @@ -16,7 +16,7 @@ sys.path.append(str(project_root)) # Import our search module and config from config import DATA_DIR, logger -from src.deepsearch.search_module import get_question_answer, get_question_count, search +from src.search_module import get_question_answer, get_question_count, search # TODO: Import verify function and router from appropriate module # TODO: Consider moving verify function to search_module.py for better organization diff --git a/src/deepsearch/UnslothGRPOTrainerTemp.py b/src/UnslothGRPOTrainerTemp.py similarity index 100% rename from src/deepsearch/UnslothGRPOTrainerTemp.py rename to src/UnslothGRPOTrainerTemp.py diff --git a/src/__init__.py b/src/__init__.py index 0f267bd..7c170db 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -5,17 +5,17 @@ Main package exports for RL helpers. from trl.trainer.grpo_trainer import apply_chat_template from config import logger -from src.deepsearch.agent import Agent, extract_search_query -from src.deepsearch.evaluation import check_student_answers, run_eval, verify -from src.deepsearch.prompts import build_user_prompt, format_search_results, get_system_prompt -from src.deepsearch.rewards import ( +from src.agent import Agent, extract_search_query +from src.evaluation import check_student_answers, run_eval, verify +from src.prompts import build_user_prompt, format_search_results, get_system_prompt +from src.rewards import ( build_reward_correctness_fn, reward_em_chunk, reward_format, reward_retry, ) -from src.deepsearch.search_module import get_qa_dataset, search -from src.deepsearch.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter +from src.search_module import get_qa_dataset, search +from src.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter __all__ = [ # Prompts diff --git a/src/deepsearch/agent.py b/src/agent.py similarity index 93% rename from src/deepsearch/agent.py rename to src/agent.py index 049baf6..d3eb87c 100644 --- a/src/deepsearch/agent.py +++ b/src/agent.py @@ -10,9 +10,9 @@ import torch from trl.trainer.grpo_trainer import apply_chat_template from config import logger -from src.deepsearch.prompts import build_user_prompt, get_system_prompt -from src.deepsearch.search_module import search -from src.deepsearch.tokenizer_adapter import TokenizerAdapter +from src.prompts import build_user_prompt, get_system_prompt +from src.search_module import search +from src.tokenizer_adapter import TokenizerAdapter def extract_search_query(text: str) -> str | None: @@ -36,9 +36,15 @@ class AgenticOutputs: class Agent: """Base agent class for handling tool-based conversations.""" - def __init__(self, tokenizer_adapter: TokenizerAdapter): - """Initialize the agent with a tokenizer adapter.""" + def __init__(self, tokenizer_adapter: TokenizerAdapter, search_fn=None): + """Initialize the agent with a tokenizer adapter and optional search function. + + Args: + tokenizer_adapter: Tokenizer adapter for handling text + search_fn: Optional custom search function. If None, uses default search. + """ self.tokenizer_adapter = tokenizer_adapter + self.search_fn = search_fn or search # Use provided search function or default def get_initial_chat(self, question: str) -> dict: """Initialize a chat state with the question.""" @@ -113,11 +119,10 @@ class Agent: search_query = extract_search_query(assistant_response) if search_query: logger.info(f"🔍 Search Query: {search_query}") - results = search(search_query, return_type=str, results=2) + results = self.search_fn(search_query, return_type=str, results=2) formatted_results = f"{results}" logger.info(f"â„šī¸ Information: {formatted_results}") - # chat_state["messages"].append({"role": "ipython", "content": formatted_results}) chat_state["messages"].append({"role": "user", "content": formatted_results}) logger.debug("Added search results to chat state") except Exception as e: diff --git a/src/deepsearch/__init__.py b/src/deepsearch/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/deepsearch/embeddings.py b/src/embeddings.py similarity index 100% rename from src/deepsearch/embeddings.py rename to src/embeddings.py diff --git a/src/deepsearch/evaluation.py b/src/evaluation.py similarity index 98% rename from src/deepsearch/evaluation.py rename to src/evaluation.py index 4096f2b..4089c24 100644 --- a/src/deepsearch/evaluation.py +++ b/src/evaluation.py @@ -5,10 +5,10 @@ Evaluation utilities for RL training. import inspect from datetime import datetime -from src.deepsearch.agent import Agent from config import logger -from src.deepsearch.search_module import get_qa_dataset -from src.deepsearch.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter +from src.agent import Agent +from src.search_module import get_qa_dataset +from src.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter async def verify(student_answer: str, question: str, answer: str) -> bool: diff --git a/src/deepsearch/prompts.py b/src/prompts.py similarity index 100% rename from src/deepsearch/prompts.py rename to src/prompts.py diff --git a/src/deepsearch/rewards.py b/src/rewards.py similarity index 99% rename from src/deepsearch/rewards.py rename to src/rewards.py index 44000a0..9fce68f 100644 --- a/src/deepsearch/rewards.py +++ b/src/rewards.py @@ -10,7 +10,7 @@ from difflib import SequenceMatcher import numpy as np from config import LOG_FOLDER, logger -from src.deepsearch.evaluation import check_student_answers +from src.evaluation import check_student_answers def build_reward_correctness_fn( diff --git a/src/deepsearch/search_module.py b/src/search_module.py similarity index 98% rename from src/deepsearch/search_module.py rename to src/search_module.py index 03149c4..d196543 100644 --- a/src/deepsearch/search_module.py +++ b/src/search_module.py @@ -10,7 +10,7 @@ from datasets import Dataset from langchain_community.vectorstores import FAISS from config import DATA_DIR, logger -from src.deepsearch.embeddings import CustomHuggingFaceEmbeddings +from src.embeddings import CustomHuggingFaceEmbeddings # Load pre-saved vectorstore diff --git a/src/deepsearch/tokenizer_adapter.py b/src/tokenizer_adapter.py similarity index 100% rename from src/deepsearch/tokenizer_adapter.py rename to src/tokenizer_adapter.py diff --git a/tests/test_agent.py b/tests/test_agent.py index 64a3a68..0b19483 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -2,8 +2,8 @@ from transformers import LlamaTokenizerFast -from src.deepsearch.agent import Agent -from src.deepsearch.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter +from src.agent import Agent +from src.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter def mock_generate_fn(prompts): diff --git a/tests/test_rewards.py b/tests/test_rewards.py index 8d7cf5a..ad2b9b2 100644 --- a/tests/test_rewards.py +++ b/tests/test_rewards.py @@ -4,7 +4,7 @@ Test cases for reward functions in rewards.py import pytest -from src.deepsearch.rewards import ( +from src.rewards import ( build_reward_correctness_fn, reward_em_chunk, reward_format, diff --git a/tests/test_tokenizer_adapters.py b/tests/test_tokenizer_adapters.py index 198f1d9..9220580 100644 --- a/tests/test_tokenizer_adapters.py +++ b/tests/test_tokenizer_adapters.py @@ -6,7 +6,7 @@ import torch from transformers import AutoTokenizer, LlamaTokenizerFast from config import logger -from src.deepsearch.tokenizer_adapter import LlamaTokenizerAdapter, QwenTokenizerAdapter, R1DistilTokenizerAdapter +from src.tokenizer_adapter import LlamaTokenizerAdapter, QwenTokenizerAdapter, R1DistilTokenizerAdapter # Test conversation used across all tests TEST_CHAT = [ diff --git a/train_grpo.py b/train_grpo.py index 92d9dc1..9f3dd58 100644 --- a/train_grpo.py +++ b/train_grpo.py @@ -6,11 +6,11 @@ import os from unsloth import FastLanguageModel, is_bfloat16_supported -import src.deepsearch.UnslothGRPOTrainerTemp as UnslothGRPOTrainerTemp +import src.UnslothGRPOTrainerTemp as UnslothGRPOTrainerTemp # Import reward functions from src import build_reward_correctness_fn, get_qa_dataset, reward_em_chunk, reward_format, reward_retry -from src.deepsearch.agent import Agent +from src.agent import Agent from config import ( MODEL_CONFIG, MODEL_NAME, @@ -21,7 +21,7 @@ from config import ( logger, update_log_path, ) -from src.deepsearch.rewards import ( +from src.rewards import ( build_reward_correctness_fn, reward_em_chunk, reward_format, @@ -29,8 +29,8 @@ from src.deepsearch.rewards import ( reward_search_diversity, reward_search_strategy, ) -from src.deepsearch.search_module import get_qa_dataset -from src.deepsearch.tokenizer_adapter import LlamaTokenizerAdapter, QwenTokenizerAdapter, R1DistilTokenizerAdapter +from src.search_module import get_qa_dataset +from src.tokenizer_adapter import LlamaTokenizerAdapter, QwenTokenizerAdapter, R1DistilTokenizerAdapter # Initialize training directories paths = init_training_dirs()