refactor: moved modules from src/deepsearch to src/

main
thinhlpg 3 months ago
parent 0f662d4330
commit eebf914a81

@@ -20,7 +20,7 @@ from src import (
format_search_results,
get_system_prompt,
)
-from src.deepsearch.search_module import load_vectorstore, search
+from src.search_module import load_vectorstore, search
def setup_model_and_tokenizer(model_path: str):

@@ -29,7 +29,7 @@ from langchain_community.document_loaders import UnstructuredMarkdownLoader
from langchain_community.vectorstores import FAISS
from config import DATA_DIR, logger
-from src.deepsearch.embeddings import CustomHuggingFaceEmbeddings
+from src.embeddings import CustomHuggingFaceEmbeddings
# Load your markdown file (adjust the path as needed)
loader = UnstructuredMarkdownLoader("./data/mission_report.md")
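This script rebuilds the FAISS index with the embeddings class from its new location. A rough sketch of the steps that typically follow the loader, assuming the usual LangChain flow (the no-argument CustomHuggingFaceEmbeddings constructor and the save path are assumptions, not shown in this hunk):

docs = loader.load()  # parse the markdown report into Document objects
embeddings = CustomHuggingFaceEmbeddings()  # assumed default constructor
vectorstore = FAISS.from_documents(docs, embeddings)  # build the index in memory
vectorstore.save_local(str(DATA_DIR))  # persist alongside the rest of the data (path assumed)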

@@ -16,7 +16,7 @@ sys.path.append(str(project_root))
# Import our search module and config
from config import DATA_DIR, logger
-from src.deepsearch.search_module import get_question_answer, get_question_count, search
+from src.search_module import get_question_answer, get_question_count, search
# TODO: Import verify function and router from appropriate module
# TODO: Consider moving verify function to search_module.py for better organization

@@ -5,17 +5,17 @@ Main package exports for RL helpers.
from trl.trainer.grpo_trainer import apply_chat_template
from config import logger
-from src.deepsearch.agent import Agent, extract_search_query
-from src.deepsearch.evaluation import check_student_answers, run_eval, verify
-from src.deepsearch.prompts import build_user_prompt, format_search_results, get_system_prompt
-from src.deepsearch.rewards import (
+from src.agent import Agent, extract_search_query
+from src.evaluation import check_student_answers, run_eval, verify
+from src.prompts import build_user_prompt, format_search_results, get_system_prompt
+from src.rewards import (
build_reward_correctness_fn,
reward_em_chunk,
reward_format,
reward_retry,
)
-from src.deepsearch.search_module import get_qa_dataset, search
-from src.deepsearch.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter
+from src.search_module import get_qa_dataset, search
+from src.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter
__all__ = [
# Prompts
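For downstream callers the refactor only shortens the import path; the re-exports in this __init__.py keep package-level imports working as before. A minimal usage sketch with names taken from the imports above (the example query and the search call arguments mirror the agent code later in this commit):

from src import build_reward_correctness_fn, reward_em_chunk, reward_format, search

snippets = search("Where did the crew splash down?", return_type=str, results=2)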

@@ -10,9 +10,9 @@ import torch
from trl.trainer.grpo_trainer import apply_chat_template
from config import logger
-from src.deepsearch.prompts import build_user_prompt, get_system_prompt
-from src.deepsearch.search_module import search
-from src.deepsearch.tokenizer_adapter import TokenizerAdapter
+from src.prompts import build_user_prompt, get_system_prompt
+from src.search_module import search
+from src.tokenizer_adapter import TokenizerAdapter
def extract_search_query(text: str) -> str | None:
@@ -36,9 +36,15 @@ class AgenticOutputs:
class Agent:
"""Base agent class for handling tool-based conversations."""
-def __init__(self, tokenizer_adapter: TokenizerAdapter):
-"""Initialize the agent with a tokenizer adapter."""
+def __init__(self, tokenizer_adapter: TokenizerAdapter, search_fn=None):
+"""Initialize the agent with a tokenizer adapter and optional search function.
+Args:
+tokenizer_adapter: Tokenizer adapter for handling text
+search_fn: Optional custom search function. If None, uses default search.
+"""
self.tokenizer_adapter = tokenizer_adapter
+self.search_fn = search_fn or search  # Use provided search function or default
def get_initial_chat(self, question: str) -> dict:
"""Initialize a chat state with the question."""
@@ -113,11 +119,10 @@ class Agent:
search_query = extract_search_query(assistant_response)
if search_query:
logger.info(f"🔍 Search Query: {search_query}")
-results = search(search_query, return_type=str, results=2)
+results = self.search_fn(search_query, return_type=str, results=2)
formatted_results = f"<information>{results}</information>"
logger.info(f" Information: {formatted_results}")
-# chat_state["messages"].append({"role": "ipython", "content": formatted_results})
chat_state["messages"].append({"role": "user", "content": formatted_results})
logger.debug("Added search results to chat state")
except Exception as e:
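The new search_fn hook makes the retrieval backend injectable, so tests and alternative search implementations no longer need to monkeypatch search_module. A sketch of how a caller might use it (the stub function below is illustrative and not part of the commit; the adapter is shown with a no-argument constructor for brevity, which is an assumption):

from src.agent import Agent
from src.tokenizer_adapter import LlamaTokenizerAdapter

def fake_search(query, return_type=str, results=2):
    # Illustrative stub matching the call the agent makes above.
    return f"canned result for: {query}"

agent = Agent(LlamaTokenizerAdapter(), search_fn=fake_search)  # injected backend
default_agent = Agent(LlamaTokenizerAdapter())  # falls back to search_module.search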

@@ -5,10 +5,10 @@ Evaluation utilities for RL training.
import inspect
from datetime import datetime
-from src.deepsearch.agent import Agent
from config import logger
-from src.deepsearch.search_module import get_qa_dataset
-from src.deepsearch.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter
+from src.agent import Agent
+from src.search_module import get_qa_dataset
+from src.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter
async def verify(student_answer: str, question: str, answer: str) -> bool:

@@ -10,7 +10,7 @@ from difflib import SequenceMatcher
import numpy as np
from config import LOG_FOLDER, logger
-from src.deepsearch.evaluation import check_student_answers
+from src.evaluation import check_student_answers
def build_reward_correctness_fn(

@@ -10,7 +10,7 @@ from datasets import Dataset
from langchain_community.vectorstores import FAISS
from config import DATA_DIR, logger
-from src.deepsearch.embeddings import CustomHuggingFaceEmbeddings
+from src.embeddings import CustomHuggingFaceEmbeddings
# Load pre-saved vectorstore

@@ -2,8 +2,8 @@
from transformers import LlamaTokenizerFast
-from src.deepsearch.agent import Agent
-from src.deepsearch.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter
+from src.agent import Agent
+from src.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter
def mock_generate_fn(prompts):

@@ -4,7 +4,7 @@ Test cases for reward functions in rewards.py
import pytest
-from src.deepsearch.rewards import (
+from src.rewards import (
build_reward_correctness_fn,
reward_em_chunk,
reward_format,

@@ -6,7 +6,7 @@ import torch
from transformers import AutoTokenizer, LlamaTokenizerFast
from config import logger
-from src.deepsearch.tokenizer_adapter import LlamaTokenizerAdapter, QwenTokenizerAdapter, R1DistilTokenizerAdapter
+from src.tokenizer_adapter import LlamaTokenizerAdapter, QwenTokenizerAdapter, R1DistilTokenizerAdapter
# Test conversation used across all tests
TEST_CHAT = [

@@ -6,11 +6,11 @@ import os
from unsloth import FastLanguageModel, is_bfloat16_supported
-import src.deepsearch.UnslothGRPOTrainerTemp as UnslothGRPOTrainerTemp
+import src.UnslothGRPOTrainerTemp as UnslothGRPOTrainerTemp
# Import reward functions
from src import build_reward_correctness_fn, get_qa_dataset, reward_em_chunk, reward_format, reward_retry
-from src.deepsearch.agent import Agent
+from src.agent import Agent
from config import (
MODEL_CONFIG,
MODEL_NAME,
@@ -21,7 +21,7 @@ from config import (
logger,
update_log_path,
)
-from src.deepsearch.rewards import (
+from src.rewards import (
build_reward_correctness_fn,
reward_em_chunk,
reward_format,
@@ -29,8 +29,8 @@ from src.deepsearch.rewards import (
reward_search_diversity,
reward_search_strategy,
)
-from src.deepsearch.search_module import get_qa_dataset
-from src.deepsearch.tokenizer_adapter import LlamaTokenizerAdapter, QwenTokenizerAdapter, R1DistilTokenizerAdapter
+from src.search_module import get_qa_dataset
+from src.tokenizer_adapter import LlamaTokenizerAdapter, QwenTokenizerAdapter, R1DistilTokenizerAdapter
# Initialize training directories
paths = init_training_dirs()
