refactor: moved modules from src/deepsearch to src/

main
thinhlpg 3 months ago
parent 0f662d4330
commit eebf914a81

@@ -20,7 +20,7 @@ from src import (
     format_search_results,
     get_system_prompt,
 )
-from src.deepsearch.search_module import load_vectorstore, search
+from src.search_module import load_vectorstore, search


 def setup_model_and_tokenizer(model_path: str):

@@ -29,7 +29,7 @@ from langchain_community.document_loaders import UnstructuredMarkdownLoader
 from langchain_community.vectorstores import FAISS

 from config import DATA_DIR, logger
-from src.deepsearch.embeddings import CustomHuggingFaceEmbeddings
+from src.embeddings import CustomHuggingFaceEmbeddings

 # Load your markdown file (adjust the path as needed)
 loader = UnstructuredMarkdownLoader("./data/mission_report.md")

@@ -16,7 +16,7 @@ sys.path.append(str(project_root))
 # Import our search module and config
 from config import DATA_DIR, logger
-from src.deepsearch.search_module import get_question_answer, get_question_count, search
+from src.search_module import get_question_answer, get_question_count, search

 # TODO: Import verify function and router from appropriate module
 # TODO: Consider moving verify function to search_module.py for better organization

@@ -5,17 +5,17 @@ Main package exports for RL helpers.
 from trl.trainer.grpo_trainer import apply_chat_template

 from config import logger
-from src.deepsearch.agent import Agent, extract_search_query
-from src.deepsearch.evaluation import check_student_answers, run_eval, verify
-from src.deepsearch.prompts import build_user_prompt, format_search_results, get_system_prompt
-from src.deepsearch.rewards import (
+from src.agent import Agent, extract_search_query
+from src.evaluation import check_student_answers, run_eval, verify
+from src.prompts import build_user_prompt, format_search_results, get_system_prompt
+from src.rewards import (
     build_reward_correctness_fn,
     reward_em_chunk,
     reward_format,
     reward_retry,
 )
-from src.deepsearch.search_module import get_qa_dataset, search
-from src.deepsearch.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter
+from src.search_module import get_qa_dataset, search
+from src.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter

 __all__ = [
     # Prompts

@@ -10,9 +10,9 @@ import torch
 from trl.trainer.grpo_trainer import apply_chat_template

 from config import logger
-from src.deepsearch.prompts import build_user_prompt, get_system_prompt
-from src.deepsearch.search_module import search
-from src.deepsearch.tokenizer_adapter import TokenizerAdapter
+from src.prompts import build_user_prompt, get_system_prompt
+from src.search_module import search
+from src.tokenizer_adapter import TokenizerAdapter


 def extract_search_query(text: str) -> str | None:
@@ -36,9 +36,15 @@ class AgenticOutputs:
 class Agent:
     """Base agent class for handling tool-based conversations."""

-    def __init__(self, tokenizer_adapter: TokenizerAdapter):
-        """Initialize the agent with a tokenizer adapter."""
+    def __init__(self, tokenizer_adapter: TokenizerAdapter, search_fn=None):
+        """Initialize the agent with a tokenizer adapter and optional search function.
+
+        Args:
+            tokenizer_adapter: Tokenizer adapter for handling text
+            search_fn: Optional custom search function. If None, uses default search.
+        """
         self.tokenizer_adapter = tokenizer_adapter
+        self.search_fn = search_fn or search  # Use provided search function or default

     def get_initial_chat(self, question: str) -> dict:
         """Initialize a chat state with the question."""
@@ -113,11 +119,10 @@ class Agent:
             search_query = extract_search_query(assistant_response)
             if search_query:
                 logger.info(f"🔍 Search Query: {search_query}")
-                results = search(search_query, return_type=str, results=2)
+                results = self.search_fn(search_query, return_type=str, results=2)
                 formatted_results = f"<information>{results}</information>"
                 logger.info(f" Information: {formatted_results}")
-                # chat_state["messages"].append({"role": "ipython", "content": formatted_results})
                 chat_state["messages"].append({"role": "user", "content": formatted_results})
                 logger.debug("Added search results to chat state")

         except Exception as e:
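
Note: the new search_fn parameter makes the agent's retrieval step injectable, e.g. so tests can stub out the vectorstore. A minimal sketch of that usage, assuming only the signatures visible in this diff (the fake_search stub is hypothetical, not part of this commit):

from src.agent import Agent
from src.tokenizer_adapter import LlamaTokenizerAdapter

def fake_search(query: str, return_type=str, results: int = 2):
    # Hypothetical stub matching the call shape self.search_fn(query, return_type=str, results=2).
    return f"[stub results for: {query}]"

# Passing search_fn overrides the default search(); omitting it keeps the default.
agent = Agent(LlamaTokenizerAdapter(), search_fn=fake_search)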

@@ -5,10 +5,10 @@ Evaluation utilities for RL training.
 import inspect
 from datetime import datetime

-from src.deepsearch.agent import Agent
 from config import logger
-from src.deepsearch.search_module import get_qa_dataset
-from src.deepsearch.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter
+from src.agent import Agent
+from src.search_module import get_qa_dataset
+from src.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter


 async def verify(student_answer: str, question: str, answer: str) -> bool:
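
Since verify is a coroutine, callers need an event loop. A hedged usage sketch based only on the signature above (the inputs are made up; the verifier's internals are not shown in this diff):

import asyncio

from src.evaluation import verify

is_correct = asyncio.run(
    verify(
        student_answer="Apollo 11 landed in 1969.",
        question="When did Apollo 11 land?",
        answer="1969",
    )
)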

@@ -10,7 +10,7 @@ from difflib import SequenceMatcher
 import numpy as np

 from config import LOG_FOLDER, logger
-from src.deepsearch.evaluation import check_student_answers
+from src.evaluation import check_student_answers


 def build_reward_correctness_fn(

@@ -10,7 +10,7 @@ from datasets import Dataset
 from langchain_community.vectorstores import FAISS

 from config import DATA_DIR, logger
-from src.deepsearch.embeddings import CustomHuggingFaceEmbeddings
+from src.embeddings import CustomHuggingFaceEmbeddings


 # Load pre-saved vectorstore

@@ -2,8 +2,8 @@

 from transformers import LlamaTokenizerFast

-from src.deepsearch.agent import Agent
-from src.deepsearch.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter
+from src.agent import Agent
+from src.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter


 def mock_generate_fn(prompts):

@@ -4,7 +4,7 @@ Test cases for reward functions in rewards.py

 import pytest

-from src.deepsearch.rewards import (
+from src.rewards import (
     build_reward_correctness_fn,
     reward_em_chunk,
     reward_format,

@@ -6,7 +6,7 @@ import torch
 from transformers import AutoTokenizer, LlamaTokenizerFast

 from config import logger
-from src.deepsearch.tokenizer_adapter import LlamaTokenizerAdapter, QwenTokenizerAdapter, R1DistilTokenizerAdapter
+from src.tokenizer_adapter import LlamaTokenizerAdapter, QwenTokenizerAdapter, R1DistilTokenizerAdapter

 # Test conversation used across all tests
 TEST_CHAT = [

@@ -6,11 +6,11 @@ import os
 from unsloth import FastLanguageModel, is_bfloat16_supported

-import src.deepsearch.UnslothGRPOTrainerTemp as UnslothGRPOTrainerTemp
+import src.UnslothGRPOTrainerTemp as UnslothGRPOTrainerTemp

 # Import reward functions
 from src import build_reward_correctness_fn, get_qa_dataset, reward_em_chunk, reward_format, reward_retry
-from src.deepsearch.agent import Agent
+from src.agent import Agent

 from config import (
     MODEL_CONFIG,
     MODEL_NAME,
@@ -21,7 +21,7 @@ from config import (
     logger,
     update_log_path,
 )
-from src.deepsearch.rewards import (
+from src.rewards import (
     build_reward_correctness_fn,
     reward_em_chunk,
     reward_format,
@@ -29,8 +29,8 @@ from src.deepsearch.rewards import (
     reward_search_diversity,
     reward_search_strategy,
 )
-from src.deepsearch.search_module import get_qa_dataset
-from src.deepsearch.tokenizer_adapter import LlamaTokenizerAdapter, QwenTokenizerAdapter, R1DistilTokenizerAdapter
+from src.search_module import get_qa_dataset
+from src.tokenizer_adapter import LlamaTokenizerAdapter, QwenTokenizerAdapter, R1DistilTokenizerAdapter

 # Initialize training directories
 paths = init_training_dirs()
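
After a move like this, a quick smoke test that the package-level exports still resolve helps catch stale src.deepsearch paths. A minimal check, assuming it is run from the project root and using only names already exported in the __init__.py hunk above:

# Run from the project root so that the src package is importable.
from src import (
    build_reward_correctness_fn,
    format_search_results,
    get_qa_dataset,
    reward_em_chunk,
    reward_format,
    reward_retry,
)

print("all src exports resolve after the refactor")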
