From 0e626a686e600152e4d81b194143f20585ae182e Mon Sep 17 00:00:00 2001
From: Kye Gomez <98760976+kyegomez@users.noreply.github.com>
Date: Mon, 9 Dec 2024 10:13:41 -0800
Subject: [PATCH] Delete byte.py

---
 byte.py | 898 --------------------------------------------------------
 1 file changed, 898 deletions(-)
 delete mode 100644 byte.py

diff --git a/byte.py b/byte.py
deleted file mode 100644
index d0a5a92f..00000000
--- a/byte.py
+++ /dev/null
@@ -1,898 +0,0 @@
-from enum import Enum
-from typing import Union, Optional
-import io
-from PIL import Image
-import numpy as np
-import torch
-import struct
-
-
-from enum import auto
-from typing import Any, List, Dict, Tuple
-import wave
-from dataclasses import dataclass
-import magic  # python-magic; required by ModalityDetector below
-import torch.nn as nn
-import torch.nn.functional as F
-from loguru import logger
-from einops import rearrange
-from torch import Tensor
-
-
-@dataclass
-class ModelConfig:
-    """Configuration for the enhanced BytePredictor model."""
-
-    vocab_size: int = 256  # Standard byte range
-    hidden_size: int = 1024
-    num_layers: int = 12
-    num_key_value_heads: int = 8  # For multi-query attention
-    num_query_heads: int = 32  # More query heads than kv heads
-    dropout: float = 0.1
-    max_sequence_length: int = 8192
-    rope_theta: float = 10000.0
-    layer_norm_eps: float = 1e-5
-    vocab_parallel: bool = False
-    qk_norm: bool = True
-    qk_norm_scale: Optional[float] = None
-    attention_bias: bool = False
-
-
-class MultiQueryAttention(nn.Module):
-    """Fixed Multi-Query Attention implementation."""
-
-    def __init__(self, config: ModelConfig):
-        super().__init__()
-        self.hidden_size = config.hidden_size
-        self.num_query_heads = config.num_query_heads
-        self.num_key_value_heads = config.num_key_value_heads
-        self.head_dim = config.hidden_size // config.num_query_heads
-        self.qk_scale = config.qk_norm_scale or (self.head_dim**-0.5)
-
-        self.q_proj = nn.Linear(
-            config.hidden_size, config.num_query_heads * self.head_dim
-        )
-        self.k_proj = nn.Linear(
-            config.hidden_size,
-            config.num_key_value_heads * self.head_dim,
-        )
-        self.v_proj = nn.Linear(
-            config.hidden_size,
-            config.num_key_value_heads * self.head_dim,
-        )
-        self.o_proj = nn.Linear(
-            config.num_query_heads * self.head_dim, config.hidden_size
-        )
-
-        self.qk_norm = config.qk_norm
-        if self.qk_norm:
-            self.q_norm = nn.LayerNorm(self.head_dim)
-            self.k_norm = nn.LayerNorm(self.head_dim)
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        batch_size, seq_length, _ = hidden_states.shape
-
-        # Project and reshape
-        q = self.q_proj(hidden_states)
-        k = self.k_proj(hidden_states)
-        v = self.v_proj(hidden_states)
-
-        # Reshape to [seq_len, batch, heads, head_dim]
-        q = q.view(
-            batch_size,
-            seq_length,
-            self.num_query_heads,
-            self.head_dim,
-        ).permute(1, 0, 2, 3)
-        k = k.view(
-            batch_size,
-            seq_length,
-            self.num_key_value_heads,
-            self.head_dim,
-        ).permute(1, 0, 2, 3)
-        v = v.view(
-            batch_size,
-            seq_length,
-            self.num_key_value_heads,
-            self.head_dim,
-        ).permute(1, 0, 2, 3)
-
-        # Apply rotary embeddings
-        # q, k = self.rotary(q, k, seq_length)
-
-        # Apply QK normalization if enabled
-        if self.qk_norm:
-            q = self.q_norm(q)
-            k = self.k_norm(k)
-
-        # Handle MQA head expansion
-        if self.num_key_value_heads != self.num_query_heads:
-            k = k.repeat_interleave(
-                self.num_query_heads // self.num_key_value_heads,
-                dim=2,
-            )
-            v = v.repeat_interleave(
-                self.num_query_heads // self.num_key_value_heads,
-                dim=2,
-            )
-
-        # Compute attention
-        # Reshape for matmul: [batch, heads, seq_length, head_dim]
-        q = q.permute(1, 2, 0, 3)
-        k = k.permute(1, 2, 0, 3)
-        v = v.permute(1, 2, 0, 3)
-
-        attn_weights = (
-            torch.matmul(q, k.transpose(-2, -1)) * self.qk_scale
-        )
-
-        if attention_mask is not None:
-            attn_weights = attn_weights + attention_mask
-
-        attn_weights = F.softmax(attn_weights, dim=-1)
-
-        output = torch.matmul(attn_weights, v)
-
-        # Reshape back to [batch, seq_length, hidden_size]
-        output = (
-            output.transpose(1, 2)
-            .contiguous()
-            .view(batch_size, seq_length, -1)
-        )
-        output = self.o_proj(output)
-
-        return output
-
-
-class EnhancedBytePredictor(nn.Module):
-    """Enhanced byte prediction model with state-of-the-art features."""
-
-    def __init__(self, config: ModelConfig):
-        super().__init__()
-        self.config = config
-
-        # Token embeddings
-        self.tok_embeddings = nn.Embedding(
-            config.vocab_size, config.hidden_size
-        )
-
-        # Transformer layers
-        self.layers = nn.ModuleList(
-            [
-                nn.ModuleDict(
-                    {
-                        "attention": MultiQueryAttention(config),
-                        "attention_norm": nn.LayerNorm(
-                            config.hidden_size,
-                            eps=config.layer_norm_eps,
-                        ),
-                        "feed_forward": nn.Sequential(
-                            nn.Linear(
-                                config.hidden_size,
-                                4 * config.hidden_size,
-                            ),
-                            nn.GELU(),
-                            nn.Linear(
-                                4 * config.hidden_size,
-                                config.hidden_size,
-                            ),
-                        ),
-                        "feed_forward_norm": nn.LayerNorm(
-                            config.hidden_size,
-                            eps=config.layer_norm_eps,
-                        ),
-                    }
-                )
-                for _ in range(config.num_layers)
-            ]
-        )
-
-        self.norm = nn.LayerNorm(
-            config.hidden_size, eps=config.layer_norm_eps
-        )
-        self.output = nn.Linear(
-            config.hidden_size, config.vocab_size, bias=False
-        )
-
-        # Initialize weights
-        self.apply(self._init_weights)
-
-    def _init_weights(self, module: nn.Module) -> None:
-        """Initialize weights with scaled normal distribution."""
-        if isinstance(module, nn.Linear):
-            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
-            if module.bias is not None:
-                torch.nn.init.zeros_(module.bias)
-        elif isinstance(module, nn.Embedding):
-            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        """
-        Forward pass of the model.
-
-        Args:
-            input_ids: Tensor of shape (batch_size, sequence_length)
-            attention_mask: Optional attention mask
-
-        Returns:
-            Tensor of logits with shape (batch_size, sequence_length, vocab_size)
-        """
-        hidden_states = self.tok_embeddings(input_ids)
-
-        # Create causal mask if needed. The mask must be an additive
-        # float mask: a bool tensor cannot hold -inf, so build the
-        # upper-triangular selector in bool and fill a float tensor.
-        if attention_mask is None:
-            causal_mask = torch.triu(
-                torch.ones(
-                    (input_ids.size(1), input_ids.size(1)),
-                    device=input_ids.device,
-                    dtype=torch.bool,
-                ),
-                diagonal=1,
-            )
-            attention_mask = torch.zeros_like(
-                causal_mask, dtype=hidden_states.dtype
-            ).masked_fill(causal_mask, float("-inf"))
-
-        # Apply transformer layers
-        for layer in self.layers:
-            # Attention block
-            hidden_states = hidden_states + layer["attention"](
-                layer["attention_norm"](hidden_states), attention_mask
-            )
-
-            # Feed-forward block
-            hidden_states = hidden_states + layer["feed_forward"](
-                layer["feed_forward_norm"](hidden_states)
-            )
-
-        hidden_states = self.norm(hidden_states)
-        logits = self.output(hidden_states)
-
-        return logits
-
-    def compute_loss(
-        self,
-        input_ids: torch.Tensor,
-        target_ids: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        """
-        Compute cross entropy loss.
-
-        Args:
-            input_ids: Input token ids
-            target_ids: Target token ids
-            attention_mask: Optional attention mask
-
-        Returns:
-            Loss value
-        """
-        logits = self(input_ids, attention_mask)
-        loss = F.cross_entropy(
-            rearrange(logits, "b s v -> (b s) v"),
-            rearrange(target_ids, "b s -> (b s)"),
-        )
-        return loss
-
-    @torch.no_grad()
-    def _generate(
-        self,
-        input_ids: torch.Tensor,
-        max_new_tokens: int = 100,
-        temperature: float = 1.0,
-        top_k: Optional[int] = None,
-        top_p: Optional[float] = None,
-        repetition_penalty: float = 1.0,
-    ) -> torch.Tensor:
-        """
-        Generate new tokens autoregressively.
-
-        Args:
-            input_ids: Starting sequence
-            max_new_tokens: Number of tokens to generate
-            temperature: Sampling temperature
-            top_k: K for top-k sampling
-            top_p: P for nucleus sampling
-            repetition_penalty: Penalty for repeating tokens
-
-        Returns:
-            Generated sequence
-        """
-        batch_size, seq_length = input_ids.shape
-        generated = input_ids.clone()
-
-        for _ in range(max_new_tokens):
-            if generated.size(1) >= self.config.max_sequence_length:
-                break
-
-            # Forward pass
-            logits = self(generated)[:, -1, :]
-
-            # Apply temperature
-            logits = logits / temperature
-
-            # Apply repetition penalty
-            if repetition_penalty != 1.0:
-                for i in range(batch_size):
-                    for token_id in set(generated[i].tolist()):
-                        logits[i, token_id] /= repetition_penalty
-
-            # Apply top-k sampling
-            if top_k is not None:
-                indices_to_remove = (
-                    logits
-                    < torch.topk(logits, top_k)[0][..., -1, None]
-                )
-                logits[indices_to_remove] = float("-inf")
-
-            # Apply nucleus (top-p) sampling
-            if top_p is not None:
-                sorted_logits, sorted_indices = torch.sort(
-                    logits, descending=True
-                )
-                cumulative_probs = torch.cumsum(
-                    F.softmax(sorted_logits, dim=-1), dim=-1
-                )
-
-                # Remove tokens with cumulative probability above the threshold
-                sorted_indices_to_remove = cumulative_probs > top_p
-                sorted_indices_to_remove[..., 1:] = (
-                    sorted_indices_to_remove[..., :-1].clone()
-                )
-                sorted_indices_to_remove[..., 0] = 0
-
-                indices_to_remove = torch.zeros_like(
-                    logits, dtype=torch.bool
-                )
-                indices_to_remove.scatter_(
-                    1, sorted_indices, sorted_indices_to_remove
-                )
-                logits[indices_to_remove] = float("-inf")
-
-            # Sample next token
-            probs = F.softmax(logits, dim=-1)
-            next_token = torch.multinomial(probs, num_samples=1)
-
-            # Append to sequence
-            generated = torch.cat([generated, next_token], dim=1)
-
-        return generated
-
-    def generate(
-        self,
-        input_ids: torch.Tensor,
-        max_new_tokens: int = 100,
-        temperature: float = 1.0,
-        top_k: Optional[int] = None,
-        top_p: Optional[float] = None,
-        repetition_penalty: float = 1.0,
-    ):
-        tensor_data = self._generate(
-            input_ids=input_ids,
-            max_new_tokens=max_new_tokens,
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-            repetition_penalty=repetition_penalty,
-        )
-
-        return tensor_to_data(tensor_data)
-
-
-# import torch
-# from typing import Optional
-
-
-class DataType(Enum):
-    TEXT = "text"
-    IMAGE = "image"
-    AUDIO = "audio"
-    VIDEO = "video"
-    BINARY = "binary"
-
-
-class ByteDetokenizer:
-    """Utility class for converting model output bytes back to original data formats."""
-
-    @staticmethod
-    def tensor_to_bytes(tensor: torch.Tensor) -> bytes:
-        """Convert model output tensor to bytes."""
-        # Convert logits/probabilities to byte values
-        if tensor.dim() > 1:
-            # If we have logits, convert to byte indices
-            byte_indices = tensor.argmax(dim=-1)
-        else:
-            byte_indices = tensor
-
-        # Convert to Python bytes
-        return bytes(
-            byte_indices.cpu().numpy().astype(np.uint8).tolist()
-        )
-
-    @staticmethod
-    def decode_text(byte_sequence: bytes) -> str:
-        """Convert bytes to text."""
-        try:
-            return byte_sequence.decode("utf-8")
-        except UnicodeDecodeError:
-            # Try with error handling
-            return byte_sequence.decode("utf-8", errors="replace")
-
-    @staticmethod
-    def decode_image(
-        byte_sequence: bytes,
-        mode: str = "RGB",
-        size: Optional[tuple] = None,
-    ) -> Image.Image:
-        """Convert bytes to image.
-
-        Args:
-            byte_sequence: Raw image bytes
-            mode: Image mode (RGB, RGBA, L, etc.)
-            size: Optional tuple of (width, height)
-        """
-        try:
-            # Try to load as-is first (for standard image formats)
-            img = Image.open(io.BytesIO(byte_sequence))
-            if size:
-                img = img.resize(size)
-            return img
-        except Exception:
-            # If failed, assume raw pixel data
-            if not size:
-                # Try to determine size from byte count
-                pixel_count = len(byte_sequence) // len(mode)
-                size = (
-                    int(np.sqrt(pixel_count)),
-                    int(np.sqrt(pixel_count)),
-                )
-
-            # Convert raw bytes to pixel array
-            pixels = np.frombuffer(byte_sequence, dtype=np.uint8)
-            pixels = pixels.reshape((*size, len(mode)))
-
-            return Image.fromarray(pixels, mode=mode)
-
-    @staticmethod
-    def decode_audio(
-        byte_sequence: bytes,
-        sample_rate: int = 44100,
-        channels: int = 2,
-        sample_width: int = 2,
-    ) -> np.ndarray:
-        """Convert bytes to audio samples.
-
-        Args:
-            byte_sequence: Raw audio bytes
-            sample_rate: Audio sample rate in Hz
-            channels: Number of audio channels
-            sample_width: Bytes per sample (1, 2, or 4)
-        """
-        # Determine format string based on sample width
-        format_str = {
-            1: "b",  # signed char
-            2: "h",  # short
-            4: "i",  # int
-        }[sample_width]
-
-        # Unpack bytes to samples
-        sample_count = len(byte_sequence) // (channels * sample_width)
-        samples = struct.unpack(
-            f"<{sample_count * channels}{format_str}", byte_sequence
-        )
-
-        # Reshape to [samples, channels]
-        return np.array(samples).reshape(-1, channels)
-
-    def decode_data(
-        self,
-        model_output: Union[torch.Tensor, bytes],
-        data_type: DataType,
-        **kwargs,
-    ) -> Union[str, Image.Image, np.ndarray, bytes]:
-        """Main method to decode model output to desired format.
-
-        Args:
-            model_output: Either tensor from model or raw bytes
-            data_type: Type of data to decode to
-            **kwargs: Additional parameters for specific decoders
-
-        Returns:
-            Decoded data in specified format
-        """
-        # Convert tensor to bytes if needed
-        if isinstance(model_output, torch.Tensor):
-            byte_sequence = self.tensor_to_bytes(model_output)
-        else:
-            byte_sequence = model_output
-
-        # Decode based on type
-        if data_type == DataType.TEXT:
-            return self.decode_text(byte_sequence)
-        elif data_type == DataType.IMAGE:
-            return self.decode_image(byte_sequence, **kwargs)
-        elif data_type == DataType.AUDIO:
-            return self.decode_audio(byte_sequence, **kwargs)
-        elif data_type == DataType.VIDEO:
-            raise NotImplementedError(
-                "Video decoding not yet implemented"
-            )
-        else:  # BINARY
-            return byte_sequence
-
-
-# Usage example
-
-
-class Modality(Enum):
-    TEXT = auto()
-    IMAGE = auto()
-    AUDIO = auto()
-    VIDEO = auto()
-    BINARY = auto()
-    MULTIMODAL = auto()
-
-
-@dataclass
-class ModalityInfo:
-    """Information about detected modality."""
-
-    modality: Modality
-    confidence: float
-    metadata: Dict[str, Any]
-    sub_modalities: Optional[List["ModalityInfo"]] = None
-
-
-class ModalityDetector:
-    """Detects data modalities from byte sequences."""
-
-    # Common file signatures (magic numbers)
-    SIGNATURES = {
-        # Images
-        b"\xFF\xD8\xFF": "JPEG",
-        b"\x89PNG\r\n\x1a\n": "PNG",
-        b"GIF87a": "GIF",
-        b"GIF89a": "GIF",
-        b"RIFF": "WEBP",
-        # Audio
-        # NOTE: the dots below are literal bytes, so this pattern will
-        # not match real WAV headers (and b"RIFF" above matches first)
-        b"RIFF....WAVE": "WAV",
-        b"ID3": "MP3",
-        b"\xFF\xFB": "MP3",
-        b"OggS": "OGG",
-        # Video
-        b"\x00\x00\x00\x18ftypmp42": "MP4",
-        b"\x00\x00\x00\x1Cftypav01": "MP4",
-        b"\x1A\x45\xDF\xA3": "WEBM",
-    }
-
-    def __init__(self):
-        self.magic = magic.Magic(mime=True)
-
-    def _check_text_probability(self, data: bytes) -> float:
-        """Estimate probability that data is text."""
-        # Check if data is valid UTF-8
-        try:
-            data.decode("utf-8")
-            # Count printable ASCII characters
-            printable = sum(1 for b in data if 32 <= b <= 126)
-            return printable / len(data)
-        except UnicodeDecodeError:
-            return 0.0
-
-    def _check_image_validity(self, data: bytes) -> Tuple[bool, Dict]:
-        """Check if data is a valid image and extract metadata."""
-        try:
-            with io.BytesIO(data) as bio:
-                img = Image.open(bio)
-                return True, {
-                    "format": img.format,
-                    "size": img.size,
-                    "mode": img.mode,
-                }
-        except Exception:
-            return False, {}
-
-    def _check_audio_validity(self, data: bytes) -> Tuple[bool, Dict]:
-        """Check if data is valid audio and extract metadata."""
-        try:
-            with io.BytesIO(data) as bio:
-                # Try to parse as WAV
-                with wave.open(bio) as wav:
-                    return True, {
-                        "channels": wav.getnchannels(),
-                        "sample_width": wav.getsampwidth(),
-                        "framerate": wav.getframerate(),
-                        "frames": wav.getnframes(),
-                    }
-        except Exception:
-            # Check for other audio signatures
-            for sig in [b"ID3", b"\xFF\xFB", b"OggS"]:
-                if data.startswith(sig):
-                    return True, {"format": "compressed_audio"}
-            return False, {}
-
-    def _detect_boundaries(
-        self, data: bytes
-    ) -> List[Tuple[int, int, Modality]]:
-        """Detect boundaries between different modalities in the data."""
-        boundaries = []
-        current_pos = 0
-
-        while current_pos < len(data):
-            # Look for known signatures
-            for sig, format_type in self.SIGNATURES.items():
-                if data[current_pos:].startswith(sig):
-                    # Found a signature, determine its length
-                    if format_type in ["JPEG", "PNG", "GIF"]:
-                        # Find image end
-                        try:
-                            with io.BytesIO(
-                                data[current_pos:]
-                            ) as bio:
-                                img = Image.open(bio)
-                                img.verify()
-                                size = bio.tell()
-                                boundaries.append(
-                                    (
-                                        current_pos,
-                                        current_pos + size,
-                                        Modality.IMAGE,
-                                    )
-                                )
-                                current_pos += size
-                                # NOTE: this continue advances the
-                                # signature loop, not the outer while
-                                continue
-                        except Exception:
-                            pass
-
-            # Check for text sections
-            text_prob = self._check_text_probability(
-                data[current_pos : current_pos + 1024]
-            )
-            if text_prob > 0.8:
-                # Look for end of text section
-                end_pos = current_pos + 1
-                while end_pos < len(data):
-                    if (
-                        self._check_text_probability(
-                            data[end_pos : end_pos + 32]
-                        )
-                        < 0.5
-                    ):
-                        break
-                    end_pos += 1
-                boundaries.append(
-                    (current_pos, end_pos, Modality.TEXT)
-                )
-                current_pos = end_pos
-                continue
-
-            current_pos += 1
-
-        return boundaries
-
-    def detect_modality(self, data: bytes) -> ModalityInfo:
-        """Detect modality of byte sequence."""
-        # First check for single modality
-        mime_type = self.magic.from_buffer(data)
-
-        # Check text
-        text_prob = self._check_text_probability(data)
-        if text_prob > 0.9:
-            return ModalityInfo(
-                modality=Modality.TEXT,
-                confidence=text_prob,
-                metadata={"mime_type": mime_type},
-            )
-
-        # Check image
-        is_image, image_meta = self._check_image_validity(data)
-        if is_image:
-            return ModalityInfo(
-                modality=Modality.IMAGE,
-                confidence=1.0,
-                metadata={**image_meta, "mime_type": mime_type},
-            )
-
-        # Check audio
-        is_audio, audio_meta = self._check_audio_validity(data)
-        if is_audio:
-            return ModalityInfo(
-                modality=Modality.AUDIO,
-                confidence=1.0,
-                metadata={**audio_meta, "mime_type": mime_type},
-            )
-
-        # Check for multimodal content
-        boundaries = self._detect_boundaries(data)
-        if len(boundaries) > 1:
-            sub_modalities = []
-            for start, end, modality in boundaries:
-                chunk_data = data[start:end]
-                sub_info = self.detect_modality(chunk_data)
-                if sub_info.modality != Modality.BINARY:
-                    sub_modalities.append(sub_info)
-
-            if sub_modalities:
-                return ModalityInfo(
-                    modality=Modality.MULTIMODAL,
-                    confidence=0.8,
-                    metadata={"mime_type": "multipart/mixed"},
-                    sub_modalities=sub_modalities,
-                )
-
-        # Default to binary
-        return ModalityInfo(
-            modality=Modality.BINARY,
-            confidence=0.5,
-            metadata={"mime_type": mime_type},
-        )
-
-    def split_modalities(
-        self, data: bytes
-    ) -> List[Tuple[Modality, bytes, Dict]]:
-        """Split multimodal data into separate modalities."""
-        boundaries = self._detect_boundaries(data)
-        result = []
-
-        for start, end, modality in boundaries:
-            chunk = data[start:end]
-            info = self.detect_modality(chunk)
-            result.append((modality, chunk, info.metadata))
-
-        return result
-
-
-class AutoDetectBytesDecoder:
-    """Decoder that automatically detects and decodes different modalities."""
-
-    def __init__(self):
-        self.detector = ModalityDetector()
-        self.text_decoder = ByteDetokenizer()  # From previous example
-
-    def decode(
-        self, data: bytes
-    ) -> Union[str, Image.Image, np.ndarray, List[Any]]:
-        """Automatically detect and decode byte sequence."""
-        info = self.detector.detect_modality(data)
-
-        if info.modality == Modality.MULTIMODAL:
-            # Handle multimodal content
-            parts = self.detector.split_modalities(data)
-            return [
-                self.decode(chunk) for modality, chunk, _ in parts
-            ]
-
-        if info.modality == Modality.TEXT:
-            return self.text_decoder.decode_text(data)
-        elif info.modality == Modality.IMAGE:
-            return self.text_decoder.decode_image(data)
-        elif info.modality == Modality.AUDIO:
-            return self.text_decoder.decode_audio(data)
-        else:
-            return data
-
-
-# # Example usage
-# def demo_auto_detection():
-#     """Demonstrate auto modality detection."""
-#     # Create mixed content
-#     text = "Hello, World!".encode('utf-8')
-
-#     # Create a small test image
-#     img = Image.new('RGB', (100, 100), color='red')
-#     img_bytes = io.BytesIO()
-#     img.save(img_bytes, format='PNG')

-#     # Combine into multimodal content
-#     mixed_content = text + img_bytes.getvalue()

-#     # Initialize decoder
-#     decoder = AutoDetectBytesDecoder()

-#     # Decode
-#     result = decoder.decode(mixed_content)

-#     if isinstance(result, list):
-#         print("Detected multimodal content:")
-#         for i, part in enumerate(result):
-#             print(f"Part {i+1}: {type(part)}")

-# if __name__ == "__main__":
-#     demo_auto_detection()
-
-
-def tensor_to_data(tensor: Tensor):
-    byte_sequence = ByteDetokenizer.tensor_to_bytes(tensor)
-
-    # Initialize auto-detector
-    decoder = AutoDetectBytesDecoder()
-
-    # Decode with automatic detection
-    result = decoder.decode(byte_sequence)
-
-    return result
-
-
-def demo_byte_predictor():
-    """Demo with smaller dimensions to test."""
-    # Initialize model configuration with adjusted dimensions
-    config = ModelConfig(
-        vocab_size=256,
-        hidden_size=128,  # Smaller for testing
-        num_layers=2,  # Fewer layers for testing
-        num_key_value_heads=2,
-        num_query_heads=4,
-        dropout=0.1,
-        max_sequence_length=1024,
-    )
-
-    # Initialize model
-    model = EnhancedBytePredictor(config)
-    logger.info("Model initialized")
-
-    # Move to GPU if available
-    device = torch.device(
-        "cuda" if torch.cuda.is_available() else "cpu"
-    )
-    model = model.to(device)
-    logger.info(f"Using device: {device}")
-
-    # Create sample input data
-    batch_size = 2
-    seq_length = 16  # Shorter sequence for testing
-    input_ids = torch.randint(
-        0, config.vocab_size, (batch_size, seq_length), device=device
-    )
-    logger.info(f"Created input tensor of shape: {input_ids.shape}")
-
-    # Test forward pass
-    try:
-        logits = model(input_ids)
-        logger.info(
-            f"Forward pass successful! Output shape: {logits.shape}"
-        )
-
-        # Test loss computation
-        target_ids = torch.randint(
-            0,
-            config.vocab_size,
-            (batch_size, seq_length),
-            device=device,
-        )
-        loss = model.compute_loss(input_ids, target_ids)
-        logger.info(
-            f"Loss computation successful! Loss value: {loss.item():.4f}"
-        )
-
-        # Test generation. generate() decodes the sampled bytes via
-        # tensor_to_data(), so the result is not a tensor and has no
-        # .shape; log its type instead.
-        prompt = torch.randint(
-            0,
-            config.vocab_size,
-            (1, 4),  # Very short prompt for testing
-            device=device,
-        )
-        generated = model.generate(
-            prompt, max_new_tokens=8, temperature=0.8, top_k=50
-        )
-        logger.info(
-            f"Generation successful! Result type: {type(generated).__name__}"
-        )
-
-    except Exception as e:
-        logger.error(f"Error during execution: {str(e)}")
-        raise
-
-
-if __name__ == "__main__":
-    # Set up logging
-    # logger.remove()  # Remove default handler
-    # logger.add(sys.stderr, format="{time:HH:mm:ss} | {level} | {message}")
-
-    demo_byte_predictor()
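
Note for reviewers: byte.py was the only home of the EnhancedBytePredictor
API removed above. For reference, a minimal sketch of how the module was
driven, adapted from the demo_byte_predictor() function in the deleted file
itself. The `byte` import path is an assumption (the file was never part of
a package), and the small dimensions mirror the demo's test configuration.

    import torch

    from byte import ModelConfig, EnhancedBytePredictor  # hypothetical import path

    # Small test configuration, as in the deleted demo
    config = ModelConfig(
        vocab_size=256,
        hidden_size=128,
        num_layers=2,
        num_key_value_heads=2,
        num_query_heads=4,
    )
    model = EnhancedBytePredictor(config)

    # Bytes in, logits over the 256-entry byte vocabulary out
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    logits = model(input_ids)  # shape: (2, 16, 256)

    # Next-byte cross entropy against target ids of the same shape
    target_ids = torch.randint(0, config.vocab_size, (2, 16))
    loss = model.compute_loss(input_ids, target_ids)

    # generate() samples bytes autoregressively, then decodes them via
    # tensor_to_data(), so the result is text/image/audio/bytes, not a tensor
    prompt = torch.randint(0, config.vocab_size, (1, 4))
    result = model.generate(prompt, max_new_tokens=8, temperature=0.8, top_k=50)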
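
The auto-detecting decoder that generate() relied on can be exercised on its
own. This sketch is lifted from the commented-out demo_auto_detection() in
the deleted file (same hypothetical `byte` import path; ModalityDetector
also needs the python-magic package installed): it builds mixed text-plus-PNG
content and lets the decoder split and decode it.

    import io

    from PIL import Image

    from byte import AutoDetectBytesDecoder  # hypothetical import path

    # Mixed content: UTF-8 text followed by a small PNG image
    text = "Hello, World!".encode("utf-8")
    img_bytes = io.BytesIO()
    Image.new("RGB", (100, 100), color="red").save(img_bytes, format="PNG")
    mixed_content = text + img_bytes.getvalue()

    # Decode with automatic modality detection; multimodal input
    # comes back as a list of decoded parts
    decoder = AutoDetectBytesDecoder()
    result = decoder.decode(mixed_content)

    if isinstance(result, list):
        print("Detected multimodal content:")
        for i, part in enumerate(result):
            print(f"Part {i + 1}: {type(part).__name__}")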