|
|
@ -6,12 +6,12 @@ import os
|
|
|
|
|
|
|
|
|
|
|
|
from unsloth import FastLanguageModel, is_bfloat16_supported
|
|
|
|
from unsloth import FastLanguageModel, is_bfloat16_supported
|
|
|
|
|
|
|
|
|
|
|
|
import src.UnslothGRPOTrainerTemp as UnslothGRPOTrainerTemp
|
|
|
|
import src.deepsearch.UnslothGRPOTrainerTemp as UnslothGRPOTrainerTemp
|
|
|
|
|
|
|
|
|
|
|
|
# Import reward functions
|
|
|
|
# Import reward functions
|
|
|
|
from src import build_reward_correctness_fn, get_qa_dataset, reward_em_chunk, reward_format, reward_retry
|
|
|
|
from src import build_reward_correctness_fn, get_qa_dataset, reward_em_chunk, reward_format, reward_retry
|
|
|
|
from src.agent import Agent
|
|
|
|
from src.deepsearch.agent import Agent
|
|
|
|
from src.config import (
|
|
|
|
from config import (
|
|
|
|
MODEL_CONFIG,
|
|
|
|
MODEL_CONFIG,
|
|
|
|
MODEL_NAME,
|
|
|
|
MODEL_NAME,
|
|
|
|
OUTPUT_DIR,
|
|
|
|
OUTPUT_DIR,
|
|
|
@ -21,7 +21,7 @@ from src.config import (
|
|
|
|
logger,
|
|
|
|
logger,
|
|
|
|
update_log_path,
|
|
|
|
update_log_path,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
from src.rewards import (
|
|
|
|
from src.deepsearch.rewards import (
|
|
|
|
build_reward_correctness_fn,
|
|
|
|
build_reward_correctness_fn,
|
|
|
|
reward_em_chunk,
|
|
|
|
reward_em_chunk,
|
|
|
|
reward_format,
|
|
|
|
reward_format,
|
|
|
@ -29,8 +29,8 @@ from src.rewards import (
|
|
|
|
reward_search_diversity,
|
|
|
|
reward_search_diversity,
|
|
|
|
reward_search_strategy,
|
|
|
|
reward_search_strategy,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
from src.search_module import get_qa_dataset
|
|
|
|
from src.deepsearch.search_module import get_qa_dataset
|
|
|
|
from src.tokenizer_adapter import LlamaTokenizerAdapter, QwenTokenizerAdapter, R1DistilTokenizerAdapter
|
|
|
|
from src.deepsearch.tokenizer_adapter import LlamaTokenizerAdapter, QwenTokenizerAdapter, R1DistilTokenizerAdapter
|
|
|
|
|
|
|
|
|
|
|
|
# Initialize training directories
|
|
|
|
# Initialize training directories
|
|
|
|
paths = init_training_dirs()
|
|
|
|
paths = init_training_dirs()
|
|
|
|