diff --git a/rl_helpers.py b/rl_helpers.py
index 95d79ce..9df28f5 100644
--- a/rl_helpers.py
+++ b/rl_helpers.py
@@ -68,7 +68,7 @@ def build_user_prompt(q):
 Given a question, answer it using by doing searches using the search_corpus tool.
 To use the search_corpus tool, respond with a JSON for a function call with its proper arguments.
 
-You may also reason in any message, thinking step by step about how to answer the question. Wrap your reasoning in <reasoning> and </reasoning> tags.
+You may also reason in any message, think step by step about how to answer the question. Wrap your reasoning in <think> and </think> tags.
 
 {json.dumps(SEARCH_TOOL_DEFINITION, indent=2)}
 
@@ -140,17 +140,17 @@ def extract_json_objects(text):
 
 def remove_reasoning(text: str) -> str:
     """
-    Removes all content between <reasoning> and </reasoning> tags,
+    Removes all content between <think> and </think> tags,
     including the tags themselves.
 
     Parameters:
-    text (str): The input text that may contain <reasoning>...</reasoning> tags.
+    text (str): The input text that may contain <think>...</think> tags.
 
     Returns:
     str: The text with the tags and their content removed.
     """
-    # The regex pattern matches from <reasoning> to </reasoning> non-greedily.
-    pattern = r"<reasoning>.*?</reasoning>"
+    # The regex pattern matches from <think> to </think> non-greedily.
+    pattern = r"<think>.*?</think>"
     cleaned_text = re.sub(pattern, "", text, flags=re.DOTALL)
     return cleaned_text
 
@@ -495,6 +495,7 @@ def check_student_answers(
     return results
 
 
+# Reward Functions
 def build_reward_correctness_fn(generate_fn, tokenizer):
     def reward_correctness(prompts, completions, **reward_kwargs):
         teacher_answers = reward_kwargs["answer"]
@@ -525,6 +526,14 @@ def reward_formatting(prompts, completions, **reward_kwargs):
     return [0.7 if not e else 0 for e in has_error]
 
 
+# def reward_retry_behavior(prompts, completions, **reward_kwargs):
+#     pass
+
+
+# def reward_exact_match_chunk_query(prompts, completions, **reward_kwargs):
+#     pass
+
+
 def run_eval(generate_fn, verify_fn, tokenizer):
     train_dataset, test_dataset = get_qa_dataset()
     questions = test_dataset["prompt"]