"""
Tokenizer adapter implementations for different models.
This module provides adapter classes for handling different tokenizer formats.
"""

from abc import ABC, abstractmethod

import torch

from config import logger


class TokenizerAdapter(ABC):
    """Abstract interface shared by all model-specific tokenizer adapters.

    A concrete adapter knows the chat-format markers of one model family and
    uses them to split a conversation string and to build a token mask.
    Every method below is abstract and must be overridden.
    """

    @abstractmethod
    def split_prompt_assistant(self, text: str) -> tuple[str, str]:
        """Return ``text`` split into a ``(prompt, assistant_response)`` pair."""

    @abstractmethod
    def get_mask(self, text: str, tokenizer) -> torch.Tensor:
        """Return a tensor mask over the tokens of ``text`` for the model's response."""

    @abstractmethod
    def get_assistant_marker(self) -> str:
        """Return the string that marks the start of an assistant turn."""

    @abstractmethod
    def get_end_marker(self) -> str:
        """Return the string that marks the end of a turn."""


class LlamaTokenizerAdapter(TokenizerAdapter):
    """Adapter for the Llama chat format.

    An assistant turn opens with
    ``<|start_header_id|>assistant<|end_header_id|>`` and is closed by
    ``<|eot_id|>``.
    """

    def get_assistant_marker(self) -> str:
        """Return the header text that opens an assistant turn."""
        return "<|start_header_id|>assistant<|end_header_id|>"

    def get_end_marker(self) -> str:
        """Return the token text that closes a turn."""
        return "<|eot_id|>"

    def split_prompt_assistant(self, convo_text: str) -> tuple[str, str]:
        """Split the conversation at the first assistant marker.

        Args:
            convo_text: The full conversation text.

        Returns:
            A tuple ``(prompt, assistant)``. The prompt ends with the assistant
            marker itself; the assistant part is everything after that marker.

        Raises:
            ValueError: If the assistant marker does not occur in ``convo_text``.
        """
        marker = self.get_assistant_marker()
        idx = convo_text.find(marker)
        if idx == -1:
            raise ValueError("Could not find assistant marker in conversation text.")

        # The marker belongs to the prompt, so cut right after it.
        split_at = idx + len(marker)
        return convo_text[:split_at], convo_text[split_at:]

    def get_mask(self, text: str, tokenizer) -> torch.Tensor:
        """Build a 0/1 mask over the tokens of ``text`` marking assistant replies.

        The text is tokenized without special tokens being added. For every
        ``<|start_header_id|>`` token directly followed by the ``assistant``
        token, the tokens from just past the header up to (not including) the
        next ``<|eot_id|>`` are marked with 1. A reply that is not closed by
        ``<|eot_id|>`` is masked through the end of the sequence.

        Args:
            text: The text to get the mask for.
            tokenizer: The tokenizer to use. It is called on text and must
                provide ``convert_tokens_to_ids`` (assumed to be a Hugging
                Face style tokenizer - TODO confirm).

        Returns:
            An int tensor with one entry per token, 1 for assistant reply
            tokens and 0 elsewhere.
        """
        logger.debug(f"πŸ” Llama: Full text length: {len(text)}")

        encoding = tokenizer(text, add_special_tokens=False)
        input_ids = encoding.input_ids
        num_tokens = len(input_ids)

        start_header_id = tokenizer.convert_tokens_to_ids("<|start_header_id|>")
        assistant_token = tokenizer.convert_tokens_to_ids("assistant")
        end_header_id = tokenizer.convert_tokens_to_ids("<|end_header_id|>")
        eot_id = tokenizer.convert_tokens_to_ids(self.get_end_marker())

        logger.debug(f"πŸ” Llama: Tokenized length: {num_tokens}")
        logger.debug(f"πŸ” Llama: Input IDs: {input_ids}")
        logger.debug(
            f"πŸ” Llama: Special token IDs: start={start_header_id}, assistant={assistant_token}, end={end_header_id}, eot={eot_id}"
        )

        assistant_ranges = []
        i = 0
        while i < num_tokens - 1:
            if input_ids[i] == start_header_id and input_ids[i + 1] == assistant_token:
                logger.debug(f"πŸ” Llama: Found assistant marker at position {i}")
                logger.debug(f"πŸ” Llama: Assistant marker tokens: {input_ids[i : i + 2]}")
                i += 2
                while i < num_tokens and input_ids[i] != end_header_id:
                    i += 1
                # Step past <|end_header_id|> and the single token after it
                # (presumably the "\n\n" separator of the chat template -
                # TODO confirm). Clamp so a header at the very end of the
                # sequence yields an empty range instead of an out-of-range
                # index.
                i = min(i + 2, num_tokens)
                start_idx = i
                # Index defensively: these f-strings are evaluated even when
                # debug logging is off, so an unguarded lookup would raise
                # IndexError whenever the reply is empty or unterminated.
                start_token = input_ids[start_idx] if start_idx < num_tokens else None
                logger.debug(f"πŸ” Llama: Found start of response at {start_idx}")
                logger.debug(f"πŸ” Llama: Start token ID: {start_token}")
                while i < num_tokens and input_ids[i] != eot_id:
                    i += 1
                end_idx = i
                end_token = input_ids[end_idx] if end_idx < num_tokens else None
                logger.debug(f"πŸ” Llama: Found end of response at {end_idx}")
                logger.debug(f"πŸ” Llama: End token ID: {end_token}")
                logger.debug(f"πŸ” Llama: Response token IDs: {input_ids[start_idx:end_idx]}")
                assistant_ranges.append((start_idx, end_idx))
            else:
                i += 1

        mask = torch.zeros(num_tokens, dtype=torch.int)
        for start_idx, end_idx in assistant_ranges:
            mask[start_idx:end_idx] = 1

        logger.debug(f"πŸ” Llama: Final mask shape: {mask.shape}")
        logger.debug(f"πŸ” Llama: Mask sum: {mask.sum().item()}")
        logger.debug(f"πŸ” Llama: Mask: {mask}")

        self._log_split_debug(text, tokenizer, num_tokens)

        return mask

    def _log_split_debug(self, text: str, tokenizer, num_tokens: int) -> None:
        """Log how the prompt/response split tokenizes compared with the full text.

        Diagnostics only: any failure is logged and swallowed so it can never
        affect the mask returned by ``get_mask``.
        """
        try:
            prompt, response = self.split_prompt_assistant(text)
            prompt_tokens = tokenizer(prompt, add_special_tokens=False).input_ids
            response_tokens = tokenizer(response, add_special_tokens=False).input_ids
            split_total = len(prompt_tokens) + len(response_tokens)

            logger.debug(f"πŸ” Llama: Prompt length: {len(prompt)}")
            logger.debug(f"πŸ” Llama: Response length: {len(response)}")
            logger.debug(f"πŸ” Llama: Prompt token IDs: {prompt_tokens}")
            logger.debug(f"πŸ” Llama: Response token IDs: {response_tokens}")
            logger.debug(f"πŸ” Llama: Prompt: {prompt[:100]}...")
            logger.debug(f"πŸ” Llama: Response: {response[:100]}...")
            logger.debug(f"πŸ” Llama: Full input IDs length: {num_tokens}")
            logger.debug(f"πŸ” Llama: Prompt + Response token IDs length: {split_total}")
            logger.debug(f"πŸ” Llama: Difference in lengths: {num_tokens - split_total}")
        except Exception as e:
            logger.error(f"πŸ” Llama: Error splitting prompt/response: {e}")


class R1DistilTokenizerAdapter(TokenizerAdapter):
    """Adapter for the R1-Distil chat format.

    Unlike the single-token comparisons used for other formats, markers here
    are matched as token *sequences*: each marker string is tokenized and the
    resulting ID list is searched for inside the tokenized text.
    """

    def get_assistant_marker(self) -> str:
        """Return the string that opens an assistant turn."""
        return "<|Assistant|>"

    def get_end_marker(self) -> str:
        """Return the string that closes a turn."""
        return "<|end▁of▁sentence|>"

    def get_begin_marker(self) -> str:
        """Return the begin-of-sentence marker string."""
        return "<|begin▁of▁sentence|>"

    def get_user_marker(self) -> str:
        """Return the string that opens a user turn."""
        return "<|User|>"

    def get_mask(self, text: str, tokenizer) -> torch.Tensor:
        """Build a 0/1 mask over the tokens of ``text`` marking assistant replies.

        Every occurrence of the tokenized assistant marker starts a reply. The
        reply runs from the token after the marker up to (not including) the
        next tokenized end marker, or to the end of the sequence when no end
        marker follows.

        Args:
            text: The text to get the mask for.
            tokenizer: The tokenizer to use. It is called on text with
                ``add_special_tokens=False`` and must return an object whose
                ``input_ids`` is a list of token IDs.

        Returns:
            An int tensor with one entry per token, 1 for assistant reply
            tokens and 0 elsewhere.
        """
        logger.debug(f"πŸ” R1Distil: Getting mask for text length: {len(text)}")

        assistant_marker = self.get_assistant_marker()
        end_marker = self.get_end_marker()

        encoding = tokenizer(text, add_special_tokens=False)
        tokens = encoding.input_ids
        num_tokens = len(tokens)
        logger.debug(f"πŸ” R1Distil: Full text token IDs: {tokens}")

        # The mask must have exactly one entry per input token.
        mask = torch.zeros(num_tokens, dtype=torch.int)

        assistant_tokens = tokenizer(assistant_marker, add_special_tokens=False).input_ids
        end_tokens = tokenizer(end_marker, add_special_tokens=False).input_ids
        logger.debug(f"πŸ” R1Distil: Assistant marker token IDs: {assistant_tokens}")
        logger.debug(f"πŸ” R1Distil: End marker token IDs: {end_tokens}")
        marker_width = len(assistant_tokens)
        end_width = len(end_tokens)

        assistant_ranges = []
        i = 0
        while i < num_tokens:
            # A slice that runs off the end is shorter than the marker and
            # therefore never compares equal, so no explicit bounds check is
            # needed.
            if tokens[i : i + marker_width] != assistant_tokens:
                i += 1
                continue

            logger.debug(f"πŸ” R1Distil: Found assistant marker at position {i}")

            # Masking starts AFTER the assistant marker.
            start_idx = i + marker_width

            # Default to "unterminated": mask through the end of the sequence.
            end_idx = num_tokens
            for j in range(start_idx, num_tokens):
                if tokens[j : j + end_width] == end_tokens:
                    end_idx = j  # The end marker itself is not masked.
                    break

            logger.debug(f"πŸ” R1Distil: Response range: {start_idx} to {end_idx}")
            assistant_ranges.append((start_idx, end_idx))

            # Resume after the end marker. The max() guarantees forward
            # progress even if a marker were to tokenize to an empty list.
            i = max(end_idx + end_width, i + 1)

        for start_idx, end_idx in assistant_ranges:
            mask[start_idx:end_idx] = 1

        logger.debug(f"πŸ” R1Distil: Found {len(assistant_ranges)} assistant responses")
        logger.debug(f"πŸ” R1Distil: Final mask sum: {mask.sum().item()}")
        logger.debug(f"πŸ” R1Distil: Final mask length: {len(mask)}")
        logger.debug(f"πŸ” R1Distil: Mask: {mask}")

        return mask

    def split_prompt_assistant(self, text: str) -> tuple[str, str]:
        """Split the text into prompt and assistant parts.

        The prompt is everything before the FIRST assistant marker (the marker
        is not part of the prompt). The response starts at that first marker
        and runs through the end marker that follows the LAST assistant marker,
        or to the end of the text if that reply is unterminated. It therefore
        contains every assistant marker and anything between the replies. Text
        after the last reply's end marker is in neither part.

        Args:
            text: The text to split.

        Returns:
            A tuple ``(prompt, assistant)``.

        Raises:
            ValueError: If the assistant marker does not occur in ``text``.
        """
        logger.debug(f"πŸ” R1Distil: Splitting text of length: {len(text)}")

        marker = self.get_assistant_marker()
        end_marker = self.get_end_marker()

        # Collect the start offset of every (non-overlapping) assistant marker.
        marker_starts = []
        pos = text.find(marker)
        while pos != -1:
            marker_starts.append(pos)
            pos = text.find(marker, pos + len(marker))

        if not marker_starts:
            raise ValueError("Could not find assistant marker in text")

        # Only the first marker and the end of the last reply define the split,
        # so just those two positions are computed.
        first_assistant_pos = marker_starts[0]
        last_response_start = marker_starts[-1] + len(marker)
        end_pos = text.find(end_marker, last_response_start)
        if end_pos == -1:
            last_response_end = len(text)
        else:
            last_response_end = end_pos + len(end_marker)

        prompt = text[:first_assistant_pos]
        response = text[first_assistant_pos:last_response_end]

        logger.debug(f"πŸ” R1Distil: Prompt length: {len(prompt)}")
        logger.debug(f"πŸ” R1Distil: Response length: {len(response)}")
        # Rough estimate only: assumes about four characters per token.
        logger.debug(f"πŸ” R1Distil: Response token count estimate: {len(response) / 4}")
        logger.debug(f"πŸ” R1Distil: Final prompt: {prompt[:100]}...")
        logger.debug(f"πŸ” R1Distil: Final response: {response[:100]}...")

        return prompt, response


class QwenTokenizerAdapter(TokenizerAdapter):
    """Adapter for the Qwen2.5 chat format.

    An assistant turn opens with ``<|im_start|>assistant`` and is closed by
    ``<|im_end|>``.
    """

    def get_assistant_marker(self) -> str:
        """Return the text that opens an assistant turn."""
        return "<|im_start|>assistant"

    def get_end_marker(self) -> str:
        """Return the token text that closes a turn."""
        return "<|im_end|>"

    def split_prompt_assistant(self, convo_text: str) -> tuple[str, str]:
        """Split the conversation at the first assistant marker.

        Args:
            convo_text: The full conversation text.

        Returns:
            A tuple ``(prompt, assistant)``. The prompt ends with the assistant
            marker itself; the assistant part is everything after that marker.

        Raises:
            ValueError: If the assistant marker does not occur in ``convo_text``.
        """
        marker = self.get_assistant_marker()
        idx = convo_text.find(marker)
        if idx == -1:
            raise ValueError("Could not find assistant marker in conversation text.")

        # The marker belongs to the prompt, so cut right after it.
        split_at = idx + len(marker)
        return convo_text[:split_at], convo_text[split_at:]

    def get_mask(self, text: str, tokenizer) -> torch.Tensor:
        """Build a 0/1 mask over the tokens of ``text`` marking assistant replies.

        The text is tokenized without special tokens being added. For every
        ``<|im_start|>`` token directly followed by the ``assistant`` token,
        the tokens right after that pair up to (not including) the next
        ``<|im_end|>`` are marked with 1. Masking starts immediately after the
        two marker tokens, so any separator token the template places before
        the reply text is part of the masked range. A reply that is not closed
        by ``<|im_end|>`` is masked through the end of the sequence.

        Args:
            text: The text to get the mask for.
            tokenizer: The tokenizer to use. It is called on text and must
                provide ``convert_tokens_to_ids`` (assumed to be a Hugging
                Face style tokenizer - TODO confirm).

        Returns:
            An int tensor with one entry per token, 1 for assistant reply
            tokens and 0 elsewhere.
        """
        logger.debug(f"πŸ” Qwen: Full text length: {len(text)}")

        encoding = tokenizer(text, add_special_tokens=False)
        input_ids = encoding.input_ids
        num_tokens = len(input_ids)

        im_start = tokenizer.convert_tokens_to_ids("<|im_start|>")
        assistant_token = tokenizer.convert_tokens_to_ids("assistant")
        im_end = tokenizer.convert_tokens_to_ids(self.get_end_marker())

        logger.debug(f"πŸ” Qwen: Tokenized length: {num_tokens}")
        logger.debug(f"πŸ” Qwen: Input IDs: {input_ids}")
        logger.debug(f"πŸ” Qwen: Special token IDs: im_start={im_start}, assistant={assistant_token}, im_end={im_end}")

        assistant_ranges = []
        i = 0
        while i < num_tokens - 1:
            if input_ids[i] == im_start and input_ids[i + 1] == assistant_token:
                logger.debug(f"πŸ” Qwen: Found assistant marker at position {i}")
                logger.debug(f"πŸ” Qwen: Assistant marker tokens: {input_ids[i : i + 2]}")
                i += 2  # Skip past <|im_start|>assistant
                start_idx = i
                # Index defensively: these f-strings are evaluated even when
                # debug logging is off, so an unguarded lookup would raise
                # IndexError when the text ends with the assistant marker or
                # the reply has no closing <|im_end|>.
                start_token = input_ids[start_idx] if start_idx < num_tokens else None
                logger.debug(f"πŸ” Qwen: Found start of response at {start_idx}")
                logger.debug(f"πŸ” Qwen: Start token ID: {start_token}")

                while i < num_tokens and input_ids[i] != im_end:
                    i += 1
                end_idx = i
                end_token = input_ids[end_idx] if end_idx < num_tokens else None
                logger.debug(f"πŸ” Qwen: Found end of response at {end_idx}")
                logger.debug(f"πŸ” Qwen: End token ID: {end_token}")
                logger.debug(f"πŸ” Qwen: Response token IDs: {input_ids[start_idx:end_idx]}")
                assistant_ranges.append((start_idx, end_idx))
            else:
                i += 1

        mask = torch.zeros(num_tokens, dtype=torch.int)
        for start_idx, end_idx in assistant_ranges:
            mask[start_idx:end_idx] = 1

        logger.debug(f"πŸ” Qwen: Final mask shape: {mask.shape}")
        logger.debug(f"πŸ” Qwen: Mask sum: {mask.sum().item()}")
        logger.debug(f"πŸ” Qwen: Mask: {mask}")

        self._log_split_debug(text, tokenizer, num_tokens)

        return mask

    def _log_split_debug(self, text: str, tokenizer, num_tokens: int) -> None:
        """Log how the prompt/response split tokenizes compared with the full text.

        Diagnostics only: any failure is logged and swallowed so it can never
        affect the mask returned by ``get_mask``.
        """
        try:
            prompt, response = self.split_prompt_assistant(text)
            prompt_tokens = tokenizer(prompt, add_special_tokens=False).input_ids
            response_tokens = tokenizer(response, add_special_tokens=False).input_ids
            split_total = len(prompt_tokens) + len(response_tokens)

            logger.debug(f"πŸ” Qwen: Prompt length: {len(prompt)}")
            logger.debug(f"πŸ” Qwen: Response length: {len(response)}")
            logger.debug(f"πŸ” Qwen: Prompt token IDs: {prompt_tokens}")
            logger.debug(f"πŸ” Qwen: Response token IDs: {response_tokens}")
            logger.debug(f"πŸ” Qwen: Prompt: {prompt[:100]}...")
            logger.debug(f"πŸ” Qwen: Response: {response[:100]}...")
            logger.debug(f"πŸ” Qwen: Full input IDs length: {num_tokens}")
            logger.debug(f"πŸ” Qwen: Prompt + Response token IDs length: {split_total}")
            logger.debug(f"πŸ” Qwen: Difference in lengths: {num_tokens - split_total}")
        except Exception as e:
            logger.error(f"πŸ” Qwen: Error splitting prompt/response: {e}")