import io
import struct
import wave
from dataclasses import dataclass
from enum import Enum, auto
from typing import Any, Dict, List, Optional, Tuple, Union

import magic  # python-magic, required by ModalityDetector
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from loguru import logger
from PIL import Image
from torch import Tensor


@dataclass
class ModelConfig:
    """Configuration for the enhanced BytePredictor model."""

    vocab_size: int = 256  # Standard byte range
    hidden_size: int = 1024
    num_layers: int = 12
    num_key_value_heads: int = 8  # For multi-query attention
    num_query_heads: int = 32  # More query heads than kv heads
    dropout: float = 0.1  # reserved; not yet applied by the layers below
    max_sequence_length: int = 8192
    rope_theta: float = 10000.0  # reserved for rotary embeddings (not yet wired in)
    layer_norm_eps: float = 1e-5
    vocab_parallel: bool = False  # reserved; single-device vocab only
    qk_norm: bool = True
    qk_norm_scale: Optional[float] = None
    attention_bias: bool = False  # reserved; projections use nn.Linear defaults

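# A minimal sanity-check sketch (not enforced by the dataclass): hidden_size
# must divide evenly by num_query_heads, and num_query_heads by
# num_key_value_heads, or the attention reshapes below will fail.
#
#     cfg = ModelConfig(hidden_size=512, num_query_heads=8, num_key_value_heads=2)
#     assert cfg.hidden_size % cfg.num_query_heads == 0
#     assert cfg.num_query_heads % cfg.num_key_value_heads == 0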

class MultiQueryAttention(nn.Module):
    """Fixed Multi-Query Attention implementation."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_query_heads = config.num_query_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.head_dim = config.hidden_size // config.num_query_heads
        self.qk_scale = config.qk_norm_scale or (self.head_dim**-0.5)

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_query_heads * self.head_dim
        )
        self.k_proj = nn.Linear(
            config.hidden_size,
            config.num_key_value_heads * self.head_dim,
        )
        self.v_proj = nn.Linear(
            config.hidden_size,
            config.num_key_value_heads * self.head_dim,
        )
        self.o_proj = nn.Linear(
            config.num_query_heads * self.head_dim, config.hidden_size
        )

        self.qk_norm = config.qk_norm
        if self.qk_norm:
            self.q_norm = nn.LayerNorm(self.head_dim)
            self.k_norm = nn.LayerNorm(self.head_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        batch_size, seq_length, _ = hidden_states.shape

        # Project queries, keys, and values
        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        # Reshape to [batch, heads, seq_len, head_dim]
        q = q.view(
            batch_size,
            seq_length,
            self.num_query_heads,
            self.head_dim,
        ).transpose(1, 2)
        k = k.view(
            batch_size,
            seq_length,
            self.num_key_value_heads,
            self.head_dim,
        ).transpose(1, 2)
        v = v.view(
            batch_size,
            seq_length,
            self.num_key_value_heads,
            self.head_dim,
        ).transpose(1, 2)

        # Rotary embeddings are not wired in yet
        # q, k = self.rotary(q, k, seq_length)

        # Apply QK normalization (over head_dim) if enabled
        if self.qk_norm:
            q = self.q_norm(q)
            k = self.k_norm(k)

        # Expand the shared key/value heads to match the query heads
        # (grouped-query attention)
        if self.num_key_value_heads != self.num_query_heads:
            k = k.repeat_interleave(
                self.num_query_heads // self.num_key_value_heads,
                dim=1,
            )
            v = v.repeat_interleave(
                self.num_query_heads // self.num_key_value_heads,
                dim=1,
            )

        # Scaled dot-product attention: [batch, heads, seq_len, seq_len]
        attn_weights = (
            torch.matmul(q, k.transpose(-2, -1)) * self.qk_scale
        )

        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        attn_weights = F.softmax(attn_weights, dim=-1)

        output = torch.matmul(attn_weights, v)

        # Reshape back to [batch, seq_length, hidden_size]
        output = (
            output.transpose(1, 2)
            .contiguous()
            .view(batch_size, seq_length, -1)
        )
        output = self.o_proj(output)

        return output
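
# A quick shape check for MultiQueryAttention (a sketch, not part of the
# model API; uses the default ModelConfig):
#
#     cfg = ModelConfig()
#     attn = MultiQueryAttention(cfg)
#     x = torch.randn(2, 16, cfg.hidden_size)
#     out = attn(x)
#     assert out.shape == x.shape  # [2, 16, 1024]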


class EnhancedBytePredictor(nn.Module):
    """Enhanced byte prediction model with state-of-the-art features."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

        # Token embeddings
        self.tok_embeddings = nn.Embedding(
            config.vocab_size, config.hidden_size
        )

        # Transformer layers
        self.layers = nn.ModuleList(
            [
                nn.ModuleDict(
                    {
                        "attention": MultiQueryAttention(config),
                        "attention_norm": nn.LayerNorm(
                            config.hidden_size,
                            eps=config.layer_norm_eps,
                        ),
                        "feed_forward": nn.Sequential(
                            nn.Linear(
                                config.hidden_size,
                                4 * config.hidden_size,
                            ),
                            nn.GELU(),
                            nn.Linear(
                                4 * config.hidden_size,
                                config.hidden_size,
                            ),
                        ),
                        "feed_forward_norm": nn.LayerNorm(
                            config.hidden_size,
                            eps=config.layer_norm_eps,
                        ),
                    }
                )
                for _ in range(config.num_layers)
            ]
        )

        self.norm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps
        )
        self.output = nn.Linear(
            config.hidden_size, config.vocab_size, bias=False
        )

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module) -> None:
        """Initialize weights with scaled normal distribution."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass of the model.

        Args:
            input_ids: Tensor of shape (batch_size, sequence_length)
            attention_mask: Optional attention mask

        Returns:
            Tensor of logits with shape (batch_size, sequence_length, vocab_size)
        """
        hidden_states = self.tok_embeddings(input_ids)

        # Build an additive float causal mask if none was supplied
        # (masked_fill with -inf on a bool tensor would raise)
        if attention_mask is None:
            causal = torch.triu(
                torch.ones(
                    (input_ids.size(1), input_ids.size(1)),
                    device=input_ids.device,
                    dtype=torch.bool,
                ),
                diagonal=1,
            )
            attention_mask = torch.zeros(
                causal.shape,
                device=input_ids.device,
                dtype=hidden_states.dtype,
            ).masked_fill(causal, float("-inf"))

        # Apply transformer layers
        for layer in self.layers:
            # Attention block
            hidden_states = hidden_states + layer["attention"](
                layer["attention_norm"](hidden_states), attention_mask
            )

            # Feed-forward block
            hidden_states = hidden_states + layer["feed_forward"](
                layer["feed_forward_norm"](hidden_states)
            )

        hidden_states = self.norm(hidden_states)
        logits = self.output(hidden_states)

        return logits

    def compute_loss(
        self,
        input_ids: torch.Tensor,
        target_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Compute cross entropy loss.

        Args:
            input_ids: Input token ids
            target_ids: Target token ids
            attention_mask: Optional attention mask

        Returns:
            Loss value
        """
        logits = self(input_ids, attention_mask)
        loss = F.cross_entropy(
            rearrange(logits, "b s v -> (b s) v"),
            rearrange(target_ids, "b s -> (b s)"),
        )
        return loss
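
    # Note: compute_loss does not shift targets. For next-byte prediction,
    # shift outside the call (a usage sketch, assuming a byte tensor `seq`
    # of shape [batch, length]):
    #
    #     loss = model.compute_loss(seq[:, :-1], seq[:, 1:])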

    @torch.no_grad()
    def _generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        repetition_penalty: float = 1.0,
    ) -> torch.Tensor:
        """
        Generate new tokens autoregressively.

        Args:
            input_ids: Starting sequence
            max_new_tokens: Number of tokens to generate
            temperature: Sampling temperature
            top_k: K for top-k sampling
            top_p: P for nucleus sampling
            repetition_penalty: Penalty for repeating tokens

        Returns:
            Generated sequence
        """
        batch_size, seq_length = input_ids.shape
        generated = input_ids.clone()

        for _ in range(max_new_tokens):
            if generated.size(1) >= self.config.max_sequence_length:
                break

            # Forward pass
            logits = self(generated)[:, -1, :]

            # Apply temperature
            logits = logits / temperature

            # Apply repetition penalty (CTRL-style: scale repeated tokens'
            # logits toward -inf regardless of sign)
            if repetition_penalty != 1.0:
                for i in range(batch_size):
                    for token_id in set(generated[i].tolist()):
                        if logits[i, token_id] > 0:
                            logits[i, token_id] /= repetition_penalty
                        else:
                            logits[i, token_id] *= repetition_penalty

            # Apply top-k filtering (clamp k to the vocabulary size)
            if top_k is not None:
                k_eff = min(top_k, logits.size(-1))
                indices_to_remove = (
                    logits
                    < torch.topk(logits, k_eff)[0][..., -1, None]
                )
                logits[indices_to_remove] = float("-inf")

            # Apply nucleus (top-p) sampling
            if top_p is not None:
                sorted_logits, sorted_indices = torch.sort(
                    logits, descending=True
                )
                cumulative_probs = torch.cumsum(
                    F.softmax(sorted_logits, dim=-1), dim=-1
                )

                # Remove tokens with cumulative probability above the threshold
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = (
                    sorted_indices_to_remove[..., :-1].clone()
                )
                sorted_indices_to_remove[..., 0] = 0

                indices_to_remove = torch.zeros_like(
                    logits, dtype=torch.bool
                )
                indices_to_remove.scatter_(
                    1, sorted_indices, sorted_indices_to_remove
                )
                logits[indices_to_remove] = float("-inf")

            # Sample next token
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            # Append to sequence
            generated = torch.cat([generated, next_token], dim=1)

        return generated

    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        repetition_penalty: float = 1.0,
    ):
        """Generate tokens and decode them into their detected modality.

        Unlike `_generate`, this returns decoded data (text, image, audio,
        or raw bytes) rather than a token tensor.
        """
        tensor_data = self._generate(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )

        return tensor_to_data(tensor_data)




class DataType(Enum):
    TEXT = "text"
    IMAGE = "image"
    AUDIO = "audio"
    VIDEO = "video"
    BINARY = "binary"


class ByteDetokenizer:
    """Utility class for converting model output bytes back to original data formats."""

    @staticmethod
    def tensor_to_bytes(tensor: torch.Tensor) -> bytes:
        """Convert model output tensor to bytes."""
        # Convert logits/probabilities to byte values
        if tensor.dim() > 1:
            # If we have logits, convert to byte indices
            byte_indices = tensor.argmax(dim=-1)
        else:
            byte_indices = tensor

        # Convert to Python bytes
        return bytes(
            byte_indices.cpu().numpy().astype(np.uint8).tolist()
        )

    @staticmethod
    def decode_text(byte_sequence: bytes) -> str:
        """Convert bytes to text."""
        try:
            return byte_sequence.decode("utf-8")
        except UnicodeDecodeError:
            # Try with error handling
            return byte_sequence.decode("utf-8", errors="replace")

    @staticmethod
    def decode_image(
        byte_sequence: bytes,
        mode: str = "RGB",
        size: Optional[tuple] = None,
    ) -> Image.Image:
        """Convert bytes to image.

        Args:
            byte_sequence: Raw image bytes
            mode: Image mode (RGB, RGBA, L, etc.)
            size: Optional tuple of (width, height)
        """
        try:
            # Try to load as-is first (for standard image formats)
            img = Image.open(io.BytesIO(byte_sequence))
            if size:
                img = img.resize(size)
            return img
        except Exception:
            # If that failed, assume raw pixel data
            if not size:
                # Guess a square image from the byte count
                pixel_count = len(byte_sequence) // len(mode)
                side = int(np.sqrt(pixel_count))
                size = (side, side)

            # Convert raw bytes to a pixel array; numpy expects
            # (height, width, channels)
            pixels = np.frombuffer(byte_sequence, dtype=np.uint8)
            pixels = pixels.reshape((size[1], size[0], len(mode)))

            return Image.fromarray(pixels, mode=mode)

    @staticmethod
    def decode_audio(
        byte_sequence: bytes,
        sample_rate: int = 44100,
        channels: int = 2,
        sample_width: int = 2,
    ) -> np.ndarray:
        """Convert bytes to audio samples.

        Args:
            byte_sequence: Raw audio bytes
            sample_rate: Audio sample rate in Hz
            channels: Number of audio channels
            sample_width: Bytes per sample (1, 2, or 4)
        """
        # Determine format string based on sample width
        format_str = {
            1: "b",  # signed char
            2: "h",  # short
            4: "i",  # int
        }[sample_width]

        # Unpack bytes to samples
        sample_count = len(byte_sequence) // (channels * sample_width)
        samples = struct.unpack(
            f"<{sample_count * channels}{format_str}", byte_sequence
        )

        # Reshape to [samples, channels]
        return np.array(samples).reshape(-1, channels)
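
    # A usage sketch for decode_audio (hypothetical values): 16-bit stereo
    # PCM decoded into an int array of shape [n_samples, 2].
    #
    #     pcm = ByteDetokenizer.decode_audio(
    #         raw_bytes, sample_rate=44100, channels=2, sample_width=2
    #     )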

    def decode_data(
        self,
        model_output: Union[torch.Tensor, bytes],
        data_type: DataType,
        **kwargs,
    ) -> Union[str, Image.Image, np.ndarray, bytes]:
        """Main method to decode model output to desired format.

        Args:
            model_output: Either tensor from model or raw bytes
            data_type: Type of data to decode to
            **kwargs: Additional parameters for specific decoders

        Returns:
            Decoded data in specified format
        """
        # Convert tensor to bytes if needed
        if isinstance(model_output, torch.Tensor):
            byte_sequence = self.tensor_to_bytes(model_output)
        else:
            byte_sequence = model_output

        # Decode based on type
        if data_type == DataType.TEXT:
            return self.decode_text(byte_sequence)
        elif data_type == DataType.IMAGE:
            return self.decode_image(byte_sequence, **kwargs)
        elif data_type == DataType.AUDIO:
            return self.decode_audio(byte_sequence, **kwargs)
        elif data_type == DataType.VIDEO:
            raise NotImplementedError(
                "Video decoding not yet implemented"
            )
        else:  # BINARY
            return byte_sequence
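
# A usage sketch for ByteDetokenizer (assumes `logits` is a model output of
# shape [batch, seq_len, vocab_size]):
#
#     detok = ByteDetokenizer()
#     text = detok.decode_data(logits[0], DataType.TEXT)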


class Modality(Enum):
    TEXT = auto()
    IMAGE = auto()
    AUDIO = auto()
    VIDEO = auto()
    BINARY = auto()
    MULTIMODAL = auto()


@dataclass
class ModalityInfo:
    """Information about detected modality."""

    modality: Modality
    confidence: float
    metadata: Dict[str, Any]
    sub_modalities: Optional[List["ModalityInfo"]] = None


class ModalityDetector:
    """Detects data modalities from byte sequences."""

    # Common file signatures (magic numbers)
    SIGNATURES = {
        # Images
        b"\xFF\xD8\xFF": "JPEG",
        b"\x89PNG\r\n\x1a\n": "PNG",
        b"GIF87a": "GIF",
        b"GIF89a": "GIF",
        # RIFF is a container shared by WEBP, WAV, and AVI; bytes 8-12
        # (e.g. b"WEBP" or b"WAVE") are needed to disambiguate, so this
        # entry only marks the container.
        b"RIFF": "RIFF",
        # Audio
        b"ID3": "MP3",
        b"\xFF\xFB": "MP3",
        b"OggS": "OGG",
        # Video
        b"\x00\x00\x00\x18ftypmp42": "MP4",
        b"\x00\x00\x00\x1Cftypav01": "MP4",
        b"\x1A\x45\xDF\xA3": "WEBM",
    }

    def __init__(self):
        self.magic = magic.Magic(mime=True)

    def _check_text_probability(self, data: bytes) -> float:
        """Estimate probability that data is text."""
        if not data:
            return 0.0
        # Check if data is valid UTF-8
        try:
            data.decode("utf-8")
            # Fraction of printable ASCII characters
            printable = sum(1 for b in data if 32 <= b <= 126)
            return printable / len(data)
        except UnicodeDecodeError:
            return 0.0

    def _check_image_validity(self, data: bytes) -> Tuple[bool, Dict]:
        """Check if data is a valid image and extract metadata."""
        try:
            with io.BytesIO(data) as bio:
                img = Image.open(bio)
                return True, {
                    "format": img.format,
                    "size": img.size,
                    "mode": img.mode,
                }
        except Exception:
            return False, {}

    def _check_audio_validity(self, data: bytes) -> Tuple[bool, Dict]:
        """Check if data is valid audio and extract metadata."""
        try:
            with io.BytesIO(data) as bio:
                # Try to parse as WAV
                with wave.open(bio) as wav:
                    return True, {
                        "channels": wav.getnchannels(),
                        "sample_width": wav.getsampwidth(),
                        "framerate": wav.getframerate(),
                        "frames": wav.getnframes(),
                    }
        except Exception:
            # Check for other audio signatures
            for sig in [b"ID3", b"\xFF\xFB", b"OggS"]:
                if data.startswith(sig):
                    return True, {"format": "compressed_audio"}
            return False, {}

    def _detect_boundaries(
        self, data: bytes
    ) -> List[Tuple[int, int, Modality]]:
        """Detect boundaries between different modalities in the data."""
        boundaries = []
        current_pos = 0

        while current_pos < len(data):
            matched = False
            # Look for known signatures
            for sig, format_type in self.SIGNATURES.items():
                if data[current_pos:].startswith(sig):
                    # Found a signature; try to determine the chunk length
                    if format_type in ["JPEG", "PNG", "GIF"]:
                        # Find the image end by letting PIL read the stream
                        try:
                            with io.BytesIO(
                                data[current_pos:]
                            ) as bio:
                                img = Image.open(bio)
                                img.verify()
                                size = bio.tell()
                                boundaries.append(
                                    (
                                        current_pos,
                                        current_pos + size,
                                        Modality.IMAGE,
                                    )
                                )
                                current_pos += size
                                matched = True
                                break
                        except Exception:
                            pass
            if matched:
                continue

            # Check for text sections
            text_prob = self._check_text_probability(
                data[current_pos : current_pos + 1024]
            )
            if text_prob > 0.8:
                # Look for end of text section
                end_pos = current_pos + 1
                while end_pos < len(data):
                    if (
                        self._check_text_probability(
                            data[end_pos : end_pos + 32]
                        )
                        < 0.5
                    ):
                        break
                    end_pos += 1
                boundaries.append(
                    (current_pos, end_pos, Modality.TEXT)
                )
                current_pos = end_pos
                continue

            current_pos += 1

        return boundaries

    def detect_modality(self, data: bytes) -> ModalityInfo:
        """Detect modality of byte sequence."""
        # First check for single modality
        mime_type = self.magic.from_buffer(data)

        # Check text
        text_prob = self._check_text_probability(data)
        if text_prob > 0.9:
            return ModalityInfo(
                modality=Modality.TEXT,
                confidence=text_prob,
                metadata={"mime_type": mime_type},
            )

        # Check image
        is_image, image_meta = self._check_image_validity(data)
        if is_image:
            return ModalityInfo(
                modality=Modality.IMAGE,
                confidence=1.0,
                metadata={**image_meta, "mime_type": mime_type},
            )

        # Check audio
        is_audio, audio_meta = self._check_audio_validity(data)
        if is_audio:
            return ModalityInfo(
                modality=Modality.AUDIO,
                confidence=1.0,
                metadata={**audio_meta, "mime_type": mime_type},
            )

        # Check for multimodal content
        boundaries = self._detect_boundaries(data)
        if len(boundaries) > 1:
            sub_modalities = []
            for start, end, modality in boundaries:
                chunk_data = data[start:end]
                sub_info = self.detect_modality(chunk_data)
                if sub_info.modality != Modality.BINARY:
                    sub_modalities.append(sub_info)

            if sub_modalities:
                return ModalityInfo(
                    modality=Modality.MULTIMODAL,
                    confidence=0.8,
                    metadata={"mime_type": "multipart/mixed"},
                    sub_modalities=sub_modalities,
                )

        # Default to binary
        return ModalityInfo(
            modality=Modality.BINARY,
            confidence=0.5,
            metadata={"mime_type": mime_type},
        )

    def split_modalities(
        self, data: bytes
    ) -> List[Tuple[Modality, bytes, Dict]]:
        """Split multimodal data into separate modalities."""
        boundaries = self._detect_boundaries(data)
        result = []

        for start, end, modality in boundaries:
            chunk = data[start:end]
            info = self.detect_modality(chunk)
            result.append((modality, chunk, info.metadata))

        return result
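
# A detection sketch (requires the python-magic package to be installed):
#
#     detector = ModalityDetector()
#     info = detector.detect_modality(b"Hello, bytes!")
#     # -> Modality.TEXT with confidence 1.0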


class AutoDetectBytesDecoder:
    """Decoder that automatically detects and decodes different modalities."""

    def __init__(self):
        self.detector = ModalityDetector()
        self.text_decoder = ByteDetokenizer()  # reuses the byte decoders defined above

    def decode(
        self, data: bytes
    ) -> Union[str, Image.Image, np.ndarray, List[Any]]:
        """Automatically detect and decode byte sequence."""
        info = self.detector.detect_modality(data)

        if info.modality == Modality.MULTIMODAL:
            # Handle multimodal content
            parts = self.detector.split_modalities(data)
            return [
                self.decode(chunk) for modality, chunk, _ in parts
            ]

        if info.modality == Modality.TEXT:
            return self.text_decoder.decode_text(data)
        elif info.modality == Modality.IMAGE:
            return self.text_decoder.decode_image(data)
        elif info.modality == Modality.AUDIO:
            return self.text_decoder.decode_audio(data)
        else:
            return data


# # Example usage
# def demo_auto_detection():
#     """Demonstrate auto modality detection."""
#     # Create mixed content
#     text = "Hello, World!".encode('utf-8')

#     # Create a small test image
#     img = Image.new('RGB', (100, 100), color='red')
#     img_bytes = io.BytesIO()
#     img.save(img_bytes, format='PNG')

#     # Combine into multimodal content
#     mixed_content = text + img_bytes.getvalue()

#     # Initialize decoder
#     decoder = AutoDetectBytesDecoder()

#     # Decode
#     result = decoder.decode(mixed_content)

#     if isinstance(result, list):
#         print("Detected multimodal content:")
#         for i, part in enumerate(result):
#             print(f"Part {i+1}: {type(part)}")

# if __name__ == "__main__":
#     demo_auto_detection()


def tensor_to_data(tensor: Tensor):
    """Decode a generated token tensor into its auto-detected modality."""
    byte_sequence = ByteDetokenizer.tensor_to_bytes(tensor)

    # Initialize auto-detector
    decoder = AutoDetectBytesDecoder()

    # Decode with automatic detection
    result = decoder.decode(byte_sequence)

    return result


def demo_byte_predictor():
    """Run a quick smoke test of the model with reduced dimensions."""
    # Initialize model configuration with adjusted dimensions
    config = ModelConfig(
        vocab_size=256,
        hidden_size=128,  # Smaller for testing
        num_layers=2,  # Fewer layers for testing
        num_key_value_heads=2,
        num_query_heads=4,
        dropout=0.1,
        max_sequence_length=1024,
    )

    # Initialize model
    model = EnhancedBytePredictor(config)
    logger.info("Model initialized")

    # Move to GPU if available
    device = torch.device(
        "cuda" if torch.cuda.is_available() else "cpu"
    )
    model = model.to(device)
    logger.info(f"Using device: {device}")

    # Create sample input data
    batch_size = 2
    seq_length = 16  # Shorter sequence for testing
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, seq_length), device=device
    )
    logger.info(f"Created input tensor of shape: {input_ids.shape}")

    # Test forward pass
    try:
        logits = model(input_ids)
        logger.info(
            f"Forward pass successful! Output shape: {logits.shape}"
        )

        # Test loss computation
        target_ids = torch.randint(
            0,
            config.vocab_size,
            (batch_size, seq_length),
            device=device,
        )
        loss = model.compute_loss(input_ids, target_ids)
        logger.info(
            f"Loss computation successful! Loss value: {loss.item():.4f}"
        )

        # Test generation
        prompt = torch.randint(
            0,
            config.vocab_size,
            (1, 4),  # Very short prompt for testing
            device=device,
        )
        # Use _generate here to get raw token ids; generate() returns
        # decoded data (text/image/audio/bytes), which has no .shape
        generated = model._generate(
            prompt, max_new_tokens=8, temperature=0.8, top_k=50
        )
        logger.info(
            f"Generation successful! Generated shape: {generated.shape}"
        )

    except Exception as e:
        logger.error(f"Error during execution: {str(e)}")
        raise


if __name__ == "__main__":
    # Set up logging
    # logger.remove()  # Remove default handler
    # logger.add(sys.stderr, format="<green>{time:HH:mm:ss}</green> | {level} | {message}")

    demo_byte_predictor()