"""Test agent functionality."""

from transformers import LlamaTokenizerFast

from src.agent import Agent
from src.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter


def mock_generate_fn(prompts):
    """Build one canned response object per prompt.

    The prompt contents are ignored; only their position is used. Each
    returned object carries a single-element ``outputs`` list whose item
    exposes the response string as ``text``.
    """

    class MockResponse:
        def __init__(self, text):
            # Throwaway class whose only attribute is the class-level ``text``.
            output_cls = type("obj", (object,), {"text": text})
            self.outputs = [output_cls()]

    responses = []
    for index, _prompt in enumerate(prompts):
        responses.append(MockResponse(f"Assistant: Test response for {index}"))
    return responses


def test_llama_agent_response_mask_lengths():
    """Test that response tokens and masks have the same length for Llama.

    Runs the agent over two questions using the mock generation function,
    then checks that one token sequence and one mask came back per question
    and that, for every example, the tokens and mask are equally long, the
    mask marks at least one position, and the mask holds only 0s and 1s.
    """
    # Test data
    questions = ["What is Python?", "How to write tests?"]

    # Setup Llama agent
    tokenizer = LlamaTokenizerFast.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
    agent = Agent(LlamaTokenizerAdapter())

    # Run agent
    outputs = agent.run_agent(
        generate_fn=mock_generate_fn,
        tokenizer=tokenizer,
        questions=questions,
        max_generations=1,
        max_new_tokens=100,
    )

    # zip() stops at the shortest input, so an empty or truncated result would
    # skip the loop below and let the test pass without checking anything.
    assert len(outputs.response_tokens) == len(questions), (
        f"Expected {len(questions)} response token sequences, got {len(outputs.response_tokens)}"
    )
    assert len(outputs.response_masks) == len(questions), (
        f"Expected {len(questions)} response masks, got {len(outputs.response_masks)}"
    )

    # Check lengths match for each example
    for i, (tokens, mask) in enumerate(zip(outputs.response_tokens, outputs.response_masks)):
        print(f"\nExample {i}:")
        print(f"Question: {questions[i]}")
        print(f"Response tokens length: {len(tokens)}")
        print(f"Response mask length: {len(mask)}")

        assert len(tokens) == len(mask), f"Mismatch in example {i}: tokens={len(tokens)}, mask={len(mask)}"
        # The mask is used as a tensor-like object here (.sum().item(), .tolist()).
        assert mask.sum().item() > 0, "Mask should have some 1s indicating response tokens"
        assert all(x in [0, 1] for x in mask.tolist()), "Mask should only contain 0s and 1s"


def test_r1_distil_agent_response_mask_lengths():
    """Test that response tokens and masks have the same length for R1-Distil.

    Runs the agent over two questions using the mock generation function,
    then checks that one token sequence and one mask came back per question
    and that, for every example, the tokens and mask are equally long, the
    mask marks at least one position, and the mask holds only 0s and 1s.
    """
    # Test data
    questions = ["What is Python?", "How to write tests?"]

    # Setup R1-Distil agent
    tokenizer = LlamaTokenizerFast.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")
    agent = Agent(R1DistilTokenizerAdapter())

    # Run agent
    outputs = agent.run_agent(
        generate_fn=mock_generate_fn,
        tokenizer=tokenizer,
        questions=questions,
        max_generations=1,
        max_new_tokens=100,
    )

    # zip() stops at the shortest input, so an empty or truncated result would
    # skip the loop below and let the test pass without checking anything.
    assert len(outputs.response_tokens) == len(questions), (
        f"Expected {len(questions)} response token sequences, got {len(outputs.response_tokens)}"
    )
    assert len(outputs.response_masks) == len(questions), (
        f"Expected {len(questions)} response masks, got {len(outputs.response_masks)}"
    )

    # Check lengths match for each example
    for i, (tokens, mask) in enumerate(zip(outputs.response_tokens, outputs.response_masks)):
        print(f"\nExample {i}:")
        print(f"Question: {questions[i]}")
        print(f"Response tokens length: {len(tokens)}")
        print(f"Response mask length: {len(mask)}")

        assert len(tokens) == len(mask), f"Mismatch in example {i}: tokens={len(tokens)}, mask={len(mask)}"
        # The mask is used as a tensor-like object here (.sum().item(), .tolist()).
        assert mask.sum().item() > 0, "Mask should have some 1s indicating response tokens"
        assert all(x in [0, 1] for x in mask.tolist()), "Mask should only contain 0s and 1s"