import io
import struct
import wave
from dataclasses import dataclass
from enum import Enum, auto
from typing import Any, Dict, List, Optional, Tuple, Union

import magic  # python-magic; required by ModalityDetector for MIME sniffing
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from loguru import logger
from PIL import Image
from torch import Tensor


@dataclass
class ModelConfig:
    """Configuration for the enhanced BytePredictor model."""

    vocab_size: int = 256  # Standard byte range
    hidden_size: int = 1024
    num_layers: int = 12
    num_key_value_heads: int = 8  # For multi-query attention
    num_query_heads: int = 32  # More query heads than kv heads
    dropout: float = 0.1
    max_sequence_length: int = 8192
    rope_theta: float = 10000.0
    layer_norm_eps: float = 1e-5
    vocab_parallel: bool = False
    qk_norm: bool = True
    qk_norm_scale: Optional[float] = None
    attention_bias: bool = False


class MultiQueryAttention(nn.Module):
    """Multi-query attention: many query heads share a smaller set of
    key/value heads, shrinking the K/V projections (and any KV cache)."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_query_heads = config.num_query_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.head_dim = config.hidden_size // config.num_query_heads
        self.qk_scale = config.qk_norm_scale or (self.head_dim**-0.5)

        self.q_proj = nn.Linear(
            config.hidden_size,
            config.num_query_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.k_proj = nn.Linear(
            config.hidden_size,
            config.num_key_value_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.v_proj = nn.Linear(
            config.hidden_size,
            config.num_key_value_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.o_proj = nn.Linear(
            config.num_query_heads * self.head_dim,
            config.hidden_size,
            bias=config.attention_bias,
        )

        self.qk_norm = config.qk_norm
        if self.qk_norm:
            self.q_norm = nn.LayerNorm(self.head_dim)
            self.k_norm = nn.LayerNorm(self.head_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        batch_size, seq_length, _ = hidden_states.shape

        # Project queries, keys, and values
        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        # Reshape to [batch, heads, seq_length, head_dim]
        q = q.view(
            batch_size,
            seq_length,
            self.num_query_heads,
            self.head_dim,
        ).permute(0, 2, 1, 3)
        k = k.view(
            batch_size,
            seq_length,
            self.num_key_value_heads,
            self.head_dim,
        ).permute(0, 2, 1, 3)
        v = v.view(
            batch_size,
            seq_length,
            self.num_key_value_heads,
            self.head_dim,
        ).permute(0, 2, 1, 3)

        # Rotary embeddings (rope_theta in the config) are not wired in yet:
        # q, k = self.rotary(q, k, seq_length)

        # Apply QK normalization if enabled
        if self.qk_norm:
            q = self.q_norm(q)
            k = self.k_norm(k)

        # Expand the shared KV heads to match the number of query heads
        if self.num_key_value_heads != self.num_query_heads:
            k = k.repeat_interleave(
                self.num_query_heads // self.num_key_value_heads,
                dim=1,
            )
            v = v.repeat_interleave(
                self.num_query_heads // self.num_key_value_heads,
                dim=1,
            )

        # Scaled dot-product attention
        attn_weights = (
            torch.matmul(q, k.transpose(-2, -1)) * self.qk_scale
        )

        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        attn_weights = F.softmax(attn_weights, dim=-1)
        output = torch.matmul(attn_weights, v)

        # Reshape back to [batch, seq_length, hidden_size]
        output = (
            output.transpose(1, 2)
            .contiguous()
            .view(batch_size, seq_length, -1)
        )
        output = self.o_proj(output)
        return output
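# A minimal smoke-test sketch for MultiQueryAttention (not part of the
# original module; the config values and shapes below are illustrative
# assumptions). With 8 query heads sharing 2 KV heads, the K/V projections
# are a quarter the size of the query projection.
def _demo_multi_query_attention() -> None:
    config = ModelConfig(
        hidden_size=64, num_query_heads=8, num_key_value_heads=2
    )
    attn = MultiQueryAttention(config)
    x = torch.randn(2, 16, config.hidden_size)  # [batch, seq, hidden]
    out = attn(x)  # no mask: full bidirectional attention
    assert out.shape == x.shape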
class EnhancedBytePredictor(nn.Module):
    """Enhanced byte prediction model with state-of-the-art features."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

        # Token embeddings
        self.tok_embeddings = nn.Embedding(
            config.vocab_size, config.hidden_size
        )

        # Transformer layers
        self.layers = nn.ModuleList(
            [
                nn.ModuleDict(
                    {
                        "attention": MultiQueryAttention(config),
                        "attention_norm": nn.LayerNorm(
                            config.hidden_size,
                            eps=config.layer_norm_eps,
                        ),
                        "feed_forward": nn.Sequential(
                            nn.Linear(
                                config.hidden_size,
                                4 * config.hidden_size,
                            ),
                            nn.GELU(),
                            nn.Linear(
                                4 * config.hidden_size,
                                config.hidden_size,
                            ),
                        ),
                        "feed_forward_norm": nn.LayerNorm(
                            config.hidden_size,
                            eps=config.layer_norm_eps,
                        ),
                    }
                )
                for _ in range(config.num_layers)
            ]
        )

        self.norm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps
        )
        self.output = nn.Linear(
            config.hidden_size, config.vocab_size, bias=False
        )

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module) -> None:
        """Initialize weights with scaled normal distribution."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass of the model.

        Args:
            input_ids: Tensor of shape (batch_size, sequence_length)
            attention_mask: Optional attention mask

        Returns:
            Tensor of logits with shape (batch_size, sequence_length, vocab_size)
        """
        hidden_states = self.tok_embeddings(input_ids)

        # Create an additive causal mask if none was given. The mask is
        # built as a float tensor: masked_fill-ing -inf into a bool tensor
        # (as the earlier version did) raises at runtime.
        if attention_mask is None:
            seq_length = input_ids.size(1)
            causal = torch.triu(
                torch.ones(
                    (seq_length, seq_length),
                    device=input_ids.device,
                    dtype=torch.bool,
                ),
                diagonal=1,
            )
            attention_mask = torch.zeros(
                (seq_length, seq_length),
                device=input_ids.device,
                dtype=hidden_states.dtype,
            ).masked_fill(causal, float("-inf"))

        # Apply transformer layers (pre-norm residual blocks)
        for layer in self.layers:
            # Attention block
            hidden_states = hidden_states + layer["attention"](
                layer["attention_norm"](hidden_states), attention_mask
            )

            # Feed-forward block
            hidden_states = hidden_states + layer["feed_forward"](
                layer["feed_forward_norm"](hidden_states)
            )

        hidden_states = self.norm(hidden_states)
        logits = self.output(hidden_states)
        return logits

    def compute_loss(
        self,
        input_ids: torch.Tensor,
        target_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Compute cross entropy loss.

        Args:
            input_ids: Input token ids
            target_ids: Target token ids
            attention_mask: Optional attention mask

        Returns:
            Loss value
        """
        logits = self(input_ids, attention_mask)
        loss = F.cross_entropy(
            rearrange(logits, "b s v -> (b s) v"),
            rearrange(target_ids, "b s -> (b s)"),
        )
        return loss

    @torch.no_grad()
    def _generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        repetition_penalty: float = 1.0,
    ) -> torch.Tensor:
        """
        Generate new tokens autoregressively.

        Args:
            input_ids: Starting sequence
            max_new_tokens: Number of tokens to generate
            temperature: Sampling temperature
            top_k: K for top-k sampling
            top_p: P for nucleus sampling
            repetition_penalty: Penalty for repeating tokens

        Returns:
            Generated sequence
        """
        batch_size, seq_length = input_ids.shape
        generated = input_ids.clone()

        for _ in range(max_new_tokens):
            if generated.size(1) >= self.config.max_sequence_length:
                break

            # Forward pass; keep only the logits at the last position
            logits = self(generated)[:, -1, :]

            # Apply temperature
            logits = logits / temperature

            # Apply repetition penalty (CTRL-style: divide positive logits,
            # multiply negative ones, so repeats always become less likely;
            # dividing unconditionally would boost negative logits)
            if repetition_penalty != 1.0:
                for i in range(batch_size):
                    for token_id in set(generated[i].tolist()):
                        if logits[i, token_id] > 0:
                            logits[i, token_id] /= repetition_penalty
                        else:
                            logits[i, token_id] *= repetition_penalty

            # Apply top-k sampling
            if top_k is not None:
                indices_to_remove = (
                    logits
                    < torch.topk(logits, top_k)[0][..., -1, None]
                )
                logits[indices_to_remove] = float("-inf")

            # Apply nucleus (top-p) sampling
            if top_p is not None:
                sorted_logits, sorted_indices = torch.sort(
                    logits, descending=True
                )
                cumulative_probs = torch.cumsum(
                    F.softmax(sorted_logits, dim=-1), dim=-1
                )

                # Remove tokens with cumulative probability above the
                # threshold; shift right so the first token crossing the
                # threshold is kept
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = (
                    sorted_indices_to_remove[..., :-1].clone()
                )
                sorted_indices_to_remove[..., 0] = 0

                indices_to_remove = torch.zeros_like(
                    logits, dtype=torch.bool
                )
                indices_to_remove.scatter_(
                    1, sorted_indices, sorted_indices_to_remove
                )
                logits[indices_to_remove] = float("-inf")

            # Sample next token
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            # Append to sequence
            generated = torch.cat([generated, next_token], dim=1)

        return generated

    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        repetition_penalty: float = 1.0,
    ):
        """Generate bytes and decode them to their auto-detected modality."""
        tensor_data = self._generate(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )
        return tensor_to_data(tensor_data)
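# A hedged training sketch (not part of the original module): one optimizer
# step on random bytes with next-byte targets. The optimizer choice,
# learning rate, and tensor shapes are illustrative assumptions.
def _demo_training_step() -> None:
    config = ModelConfig(
        hidden_size=64,
        num_layers=2,
        num_query_heads=4,
        num_key_value_heads=2,
    )
    model = EnhancedBytePredictor(config)
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

    data = torch.randint(0, config.vocab_size, (2, 33))
    input_ids, target_ids = data[:, :-1], data[:, 1:]  # shift by one byte

    loss = model.compute_loss(input_ids, target_ids)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()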
class DataType(Enum):
    TEXT = "text"
    IMAGE = "image"
    AUDIO = "audio"
    VIDEO = "video"
    BINARY = "binary"


class ByteDetokenizer:
    """Utility class for converting model output bytes back to original data formats."""

    @staticmethod
    def tensor_to_bytes(tensor: torch.Tensor) -> bytes:
        """Convert model output tensor to bytes."""
        # Convert logits/probabilities to byte values
        if tensor.dim() > 1:
            # If we have logits, convert to byte indices
            byte_indices = tensor.argmax(dim=-1)
        else:
            byte_indices = tensor

        # Convert to Python bytes
        return bytes(
            byte_indices.cpu().numpy().astype(np.uint8).tolist()
        )

    @staticmethod
    def decode_text(byte_sequence: bytes) -> str:
        """Convert bytes to text."""
        try:
            return byte_sequence.decode("utf-8")
        except UnicodeDecodeError:
            # Fall back to replacement characters for invalid sequences
            return byte_sequence.decode("utf-8", errors="replace")

    @staticmethod
    def decode_image(
        byte_sequence: bytes,
        mode: str = "RGB",
        size: Optional[tuple] = None,
    ) -> Image.Image:
        """Convert bytes to image.

        Args:
            byte_sequence: Raw image bytes
            mode: Image mode (RGB, RGBA, L, etc.)
            size: Optional tuple of (width, height)
        """
        try:
            # Try to load as-is first (for standard image formats)
            img = Image.open(io.BytesIO(byte_sequence))
            if size:
                img = img.resize(size)
            return img
        except Exception:
            # If that failed, assume raw pixel data
            if not size:
                # Guess a square image from the byte count
                pixel_count = len(byte_sequence) // len(mode)
                side = int(np.sqrt(pixel_count))
                size = (side, side)

            # Convert raw bytes to a pixel array; note that numpy
            # reshapes as (rows, cols), i.e. (height, width)
            pixels = np.frombuffer(byte_sequence, dtype=np.uint8)
            pixels = pixels.reshape((*size, len(mode)))
            return Image.fromarray(pixels, mode=mode)

    @staticmethod
    def decode_audio(
        byte_sequence: bytes,
        sample_rate: int = 44100,
        channels: int = 2,
        sample_width: int = 2,
    ) -> np.ndarray:
        """Convert bytes to audio samples.

        Args:
            byte_sequence: Raw audio bytes
            sample_rate: Audio sample rate in Hz
            channels: Number of audio channels
            sample_width: Bytes per sample (1, 2, or 4)
        """
        # Determine struct format character from the sample width
        format_str = {
            1: "b",  # signed char
            2: "h",  # short
            4: "i",  # int
        }[sample_width]

        # Unpack little-endian bytes to interleaved samples
        sample_count = len(byte_sequence) // (channels * sample_width)
        samples = struct.unpack(
            f"<{sample_count * channels}{format_str}", byte_sequence
        )

        # Reshape to [samples, channels]
        return np.array(samples).reshape(-1, channels)

    def decode_data(
        self,
        model_output: Union[torch.Tensor, bytes],
        data_type: DataType,
        **kwargs,
    ) -> Union[str, Image.Image, np.ndarray, bytes]:
        """Main method to decode model output to desired format.

        Args:
            model_output: Either tensor from model or raw bytes
            data_type: Type of data to decode to
            **kwargs: Additional parameters for specific decoders

        Returns:
            Decoded data in specified format
        """
        # Convert tensor to bytes if needed
        if isinstance(model_output, torch.Tensor):
            byte_sequence = self.tensor_to_bytes(model_output)
        else:
            byte_sequence = model_output

        # Decode based on type
        if data_type == DataType.TEXT:
            return self.decode_text(byte_sequence)
        elif data_type == DataType.IMAGE:
            return self.decode_image(byte_sequence, **kwargs)
        elif data_type == DataType.AUDIO:
            return self.decode_audio(byte_sequence, **kwargs)
        elif data_type == DataType.VIDEO:
            raise NotImplementedError(
                "Video decoding not yet implemented"
            )
        else:  # BINARY
            return byte_sequence
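# A hedged round-trip sketch (illustrative, not part of the original
# module): encode a string to byte ids, wrap them in a tensor as the model
# would emit them, and recover the text via ByteDetokenizer.
def _demo_byte_roundtrip() -> None:
    text = "hello bytes"
    ids = torch.tensor(list(text.encode("utf-8")), dtype=torch.long)
    raw = ByteDetokenizer.tensor_to_bytes(ids)
    assert ByteDetokenizer.decode_text(raw) == text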
class Modality(Enum):
    TEXT = auto()
    IMAGE = auto()
    AUDIO = auto()
    VIDEO = auto()
    BINARY = auto()
    MULTIMODAL = auto()


@dataclass
class ModalityInfo:
    """Information about detected modality."""

    modality: Modality
    confidence: float
    metadata: Dict[str, Any]
    sub_modalities: Optional[List["ModalityInfo"]] = None
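# Illustrative sketch (an assumption, not original code): ModalityInfo is a
# plain container, so a detector result for a text buffer might look like
# the value constructed here.
def _demo_modality_info() -> None:
    info = ModalityInfo(
        modality=Modality.TEXT,
        confidence=0.95,
        metadata={"mime_type": "text/plain"},
    )
    assert info.sub_modalities is None  # single-modality result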
class ModalityDetector:
    """Detects data modalities from byte sequences."""

    # Common file signatures (magic numbers). RIFF containers (WAV, WEBP,
    # AVI) share the same 4-byte prefix; the four bytes after "RIFF" are a
    # chunk size, so telling them apart requires inspecting bytes 8-12
    # rather than a plain prefix match.
    SIGNATURES = {
        # Images
        b"\xFF\xD8\xFF": "JPEG",
        b"\x89PNG\r\n\x1a\n": "PNG",
        b"GIF87a": "GIF",
        b"GIF89a": "GIF",
        b"RIFF": "RIFF",  # WAV / WEBP / AVI container
        # Audio
        b"ID3": "MP3",
        b"\xFF\xFB": "MP3",
        b"OggS": "OGG",
        # Video
        b"\x00\x00\x00\x18ftypmp42": "MP4",
        b"\x00\x00\x00\x1Cftypav01": "MP4",
        b"\x1A\x45\xDF\xA3": "WEBM",
    }

    def __init__(self):
        self.magic = magic.Magic(mime=True)

    def _check_text_probability(self, data: bytes) -> float:
        """Estimate probability that data is text."""
        if not data:
            return 0.0
        # Check if data is valid UTF-8
        try:
            data.decode("utf-8")
            # Count printable ASCII characters
            printable = sum(1 for b in data if 32 <= b <= 126)
            return printable / len(data)
        except UnicodeDecodeError:
            return 0.0

    def _check_image_validity(self, data: bytes) -> Tuple[bool, Dict]:
        """Check if data is a valid image and extract metadata."""
        try:
            with io.BytesIO(data) as bio:
                img = Image.open(bio)
                return True, {
                    "format": img.format,
                    "size": img.size,
                    "mode": img.mode,
                }
        except Exception:
            return False, {}

    def _check_audio_validity(self, data: bytes) -> Tuple[bool, Dict]:
        """Check if data is valid audio and extract metadata."""
        try:
            with io.BytesIO(data) as bio:
                # Try to parse as WAV
                with wave.open(bio) as wav:
                    return True, {
                        "channels": wav.getnchannels(),
                        "sample_width": wav.getsampwidth(),
                        "framerate": wav.getframerate(),
                        "frames": wav.getnframes(),
                    }
        except Exception:
            # Check for other audio signatures
            for sig in [b"ID3", b"\xFF\xFB", b"OggS"]:
                if data.startswith(sig):
                    return True, {"format": "compressed_audio"}
            return False, {}

    def _detect_boundaries(
        self, data: bytes
    ) -> List[Tuple[int, int, Modality]]:
        """Detect boundaries between different modalities in the data."""
        boundaries = []
        current_pos = 0

        while current_pos < len(data):
            matched = False

            # Look for known signatures
            for sig, format_type in self.SIGNATURES.items():
                if data[current_pos:].startswith(sig):
                    # Found a signature; determine the payload length
                    if format_type in ["JPEG", "PNG", "GIF"]:
                        # Let PIL parse the stream to find the image end
                        try:
                            with io.BytesIO(
                                data[current_pos:]
                            ) as bio:
                                img = Image.open(bio)
                                img.verify()
                                size = bio.tell()
                            boundaries.append(
                                (
                                    current_pos,
                                    current_pos + size,
                                    Modality.IMAGE,
                                )
                            )
                            current_pos += size
                            matched = True
                            break
                        except Exception:
                            pass
            if matched:
                # Resume scanning after the image payload; a bare
                # `continue` inside the for-loop above would only skip
                # to the next signature, not restart the scan
                continue

            # Check for text sections
            text_prob = self._check_text_probability(
                data[current_pos : current_pos + 1024]
            )
            if text_prob > 0.8:
                # Look for end of text section
                end_pos = current_pos + 1
                while end_pos < len(data):
                    if (
                        self._check_text_probability(
                            data[end_pos : end_pos + 32]
                        )
                        < 0.5
                    ):
                        break
                    end_pos += 1
                boundaries.append(
                    (current_pos, end_pos, Modality.TEXT)
                )
                current_pos = end_pos
                continue

            current_pos += 1

        return boundaries

    def detect_modality(self, data: bytes) -> ModalityInfo:
        """Detect modality of byte sequence."""
        # First check for single modality
        mime_type = self.magic.from_buffer(data)

        # Check text
        text_prob = self._check_text_probability(data)
        if text_prob > 0.9:
            return ModalityInfo(
                modality=Modality.TEXT,
                confidence=text_prob,
                metadata={"mime_type": mime_type},
            )

        # Check image
        is_image, image_meta = self._check_image_validity(data)
        if is_image:
            return ModalityInfo(
                modality=Modality.IMAGE,
                confidence=1.0,
                metadata={**image_meta, "mime_type": mime_type},
            )

        # Check audio
        is_audio, audio_meta = self._check_audio_validity(data)
        if is_audio:
            return ModalityInfo(
                modality=Modality.AUDIO,
                confidence=1.0,
                metadata={**audio_meta, "mime_type": mime_type},
            )

        # Check for multimodal content
        boundaries = self._detect_boundaries(data)
        if len(boundaries) > 1:
            sub_modalities = []
            for start, end, modality in boundaries:
                chunk_data = data[start:end]
                sub_info = self.detect_modality(chunk_data)
                if sub_info.modality != Modality.BINARY:
                    sub_modalities.append(sub_info)

            if sub_modalities:
                return ModalityInfo(
                    modality=Modality.MULTIMODAL,
                    confidence=0.8,
                    metadata={"mime_type": "multipart/mixed"},
                    sub_modalities=sub_modalities,
                )

        # Default to binary
        return ModalityInfo(
            modality=Modality.BINARY,
            confidence=0.5,
            metadata={"mime_type": mime_type},
        )

    def split_modalities(
        self, data: bytes
    ) -> List[Tuple[Modality, bytes, Dict]]:
        """Split multimodal data into separate modalities."""
        boundaries = self._detect_boundaries(data)
        result = []

        for start, end, modality in boundaries:
            chunk = data[start:end]
            info = self.detect_modality(chunk)
            result.append((modality, chunk, info.metadata))

        return result
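# Hedged usage sketch (assumes the third-party python-magic package and its
# libmagic dependency are installed): classify a small UTF-8 buffer.
def _demo_modality_detector() -> None:
    detector = ModalityDetector()
    info = detector.detect_modality(b"Plain ASCII text only.")
    # Expect Modality.TEXT with confidence near 1.0
    print(info.modality, info.confidence, info.metadata)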
class AutoDetectBytesDecoder:
    """Decoder that automatically detects and decodes different modalities."""

    def __init__(self):
        self.detector = ModalityDetector()
        self.text_decoder = ByteDetokenizer()

    def decode(
        self, data: bytes
    ) -> Union[str, Image.Image, np.ndarray, List[Any]]:
        """Automatically detect and decode byte sequence."""
        info = self.detector.detect_modality(data)

        if info.modality == Modality.MULTIMODAL:
            # Handle multimodal content by decoding each part
            parts = self.detector.split_modalities(data)
            return [
                self.decode(chunk) for modality, chunk, _ in parts
            ]

        if info.modality == Modality.TEXT:
            return self.text_decoder.decode_text(data)
        elif info.modality == Modality.IMAGE:
            return self.text_decoder.decode_image(data)
        elif info.modality == Modality.AUDIO:
            return self.text_decoder.decode_audio(data)
        else:
            return data


# # Example usage
# def demo_auto_detection():
#     """Demonstrate auto modality detection."""
#     # Create mixed content
#     text = "Hello, World!".encode("utf-8")
#
#     # Create a small test image
#     img = Image.new("RGB", (100, 100), color="red")
#     img_bytes = io.BytesIO()
#     img.save(img_bytes, format="PNG")
#
#     # Combine into multimodal content
#     mixed_content = text + img_bytes.getvalue()
#
#     # Initialize decoder
#     decoder = AutoDetectBytesDecoder()
#
#     # Decode
#     result = decoder.decode(mixed_content)
#
#     if isinstance(result, list):
#         print("Detected multimodal content:")
#         for i, part in enumerate(result):
#             print(f"Part {i + 1}: {type(part)}")
#
#
# if __name__ == "__main__":
#     demo_auto_detection()


def tensor_to_data(tensor: Tensor):
    """Convert a generated byte tensor to its auto-detected data type."""
    byte_sequence = ByteDetokenizer.tensor_to_bytes(tensor)

    # Initialize auto-detector
    decoder = AutoDetectBytesDecoder()

    # Decode with automatic detection
    result = decoder.decode(byte_sequence)
    return result


def demo_byte_predictor():
    """Demo with smaller dimensions to test."""
    # Initialize model configuration with adjusted dimensions
    config = ModelConfig(
        vocab_size=256,
        hidden_size=128,  # Smaller for testing
        num_layers=2,  # Fewer layers for testing
        num_key_value_heads=2,
        num_query_heads=4,
        dropout=0.1,
        max_sequence_length=1024,
    )

    # Initialize model
    model = EnhancedBytePredictor(config)
    logger.info("Model initialized")

    # Move to GPU if available
    device = torch.device(
        "cuda" if torch.cuda.is_available() else "cpu"
    )
    model = model.to(device)
    logger.info(f"Using device: {device}")

    # Create sample input data
    batch_size = 2
    seq_length = 16  # Shorter sequence for testing
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, seq_length), device=device
    )
    logger.info(f"Created input tensor of shape: {input_ids.shape}")

    # Test forward pass
    try:
        logits = model(input_ids)
        logger.info(
            f"Forward pass successful! Output shape: {logits.shape}"
        )

        # Test loss computation
        target_ids = torch.randint(
            0,
            config.vocab_size,
            (batch_size, seq_length),
            device=device,
        )
        loss = model.compute_loss(input_ids, target_ids)
        logger.info(
            f"Loss computation successful! Loss value: {loss.item():.4f}"
        )

        # Test generation; _generate returns the raw byte tensor, so its
        # shape can be logged (generate() would decode the bytes to
        # text/image/audio instead)
        prompt = torch.randint(
            0,
            config.vocab_size,
            (1, 4),  # Very short prompt for testing
            device=device,
        )
        generated = model._generate(
            prompt, max_new_tokens=8, temperature=0.8, top_k=50
        )
        logger.info(
            f"Generation successful! Generated shape: {generated.shape}"
        )

    except Exception as e:
        logger.error(f"Error during execution: {str(e)}")
        raise


if __name__ == "__main__":
    # Set up logging
    # logger.remove()  # Remove default handler
    # logger.add(sys.stderr, format="{time:HH:mm:ss} | {level} | {message}")

    demo_byte_predictor()