parent
63818982f8
commit
0e626a686e
@ -1,898 +0,0 @@
|
|||||||
from enum import Enum
|
|
||||||
from typing import Union, Optional
|
|
||||||
import io
|
|
||||||
from PIL import Image
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import struct
|
|
||||||
|
|
||||||
|
|
||||||
from enum import auto
|
|
||||||
from typing import List, Dict, Tuple
|
|
||||||
import wave
|
|
||||||
from dataclasses import dataclass
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from loguru import logger
|
|
||||||
from einops import rearrange
|
|
||||||
from torch import Tensor
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ModelConfig:
|
|
||||||
"""Configuration for the enhanced BytePredictor model."""
|
|
||||||
|
|
||||||
vocab_size: int = 256 # Standard byte range
|
|
||||||
hidden_size: int = 1024
|
|
||||||
num_layers: int = 12
|
|
||||||
num_key_value_heads: int = 8 # For multi-query attention
|
|
||||||
num_query_heads: int = 32 # More query heads than kv heads
|
|
||||||
dropout: float = 0.1
|
|
||||||
max_sequence_length: int = 8192
|
|
||||||
rope_theta: float = 10000.0
|
|
||||||
layer_norm_eps: float = 1e-5
|
|
||||||
vocab_parallel: bool = False
|
|
||||||
qk_norm: bool = True
|
|
||||||
qk_norm_scale: float = None
|
|
||||||
attention_bias: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
class MultiQueryAttention(nn.Module):
|
|
||||||
"""Fixed Multi-Query Attention implementation."""
|
|
||||||
|
|
||||||
def __init__(self, config: ModelConfig):
|
|
||||||
super().__init__()
|
|
||||||
self.hidden_size = config.hidden_size
|
|
||||||
self.num_query_heads = config.num_query_heads
|
|
||||||
self.num_key_value_heads = config.num_key_value_heads
|
|
||||||
self.head_dim = config.hidden_size // config.num_query_heads
|
|
||||||
self.qk_scale = config.qk_norm_scale or (self.head_dim**-0.5)
|
|
||||||
|
|
||||||
self.q_proj = nn.Linear(
|
|
||||||
config.hidden_size, config.num_query_heads * self.head_dim
|
|
||||||
)
|
|
||||||
self.k_proj = nn.Linear(
|
|
||||||
config.hidden_size,
|
|
||||||
config.num_key_value_heads * self.head_dim,
|
|
||||||
)
|
|
||||||
self.v_proj = nn.Linear(
|
|
||||||
config.hidden_size,
|
|
||||||
config.num_key_value_heads * self.head_dim,
|
|
||||||
)
|
|
||||||
self.o_proj = nn.Linear(
|
|
||||||
config.num_query_heads * self.head_dim, config.hidden_size
|
|
||||||
)
|
|
||||||
|
|
||||||
self.qk_norm = config.qk_norm
|
|
||||||
if self.qk_norm:
|
|
||||||
self.q_norm = nn.LayerNorm(self.head_dim)
|
|
||||||
self.k_norm = nn.LayerNorm(self.head_dim)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
batch_size, seq_length, _ = hidden_states.shape
|
|
||||||
|
|
||||||
# Project and reshape
|
|
||||||
q = self.q_proj(hidden_states)
|
|
||||||
k = self.k_proj(hidden_states)
|
|
||||||
v = self.v_proj(hidden_states)
|
|
||||||
|
|
||||||
# Reshape to [seq_len, batch, heads, head_dim]
|
|
||||||
q = q.view(
|
|
||||||
batch_size,
|
|
||||||
seq_length,
|
|
||||||
self.num_query_heads,
|
|
||||||
self.head_dim,
|
|
||||||
).permute(1, 0, 2, 3)
|
|
||||||
k = k.view(
|
|
||||||
batch_size,
|
|
||||||
seq_length,
|
|
||||||
self.num_key_value_heads,
|
|
||||||
self.head_dim,
|
|
||||||
).permute(1, 0, 2, 3)
|
|
||||||
v = v.view(
|
|
||||||
batch_size,
|
|
||||||
seq_length,
|
|
||||||
self.num_key_value_heads,
|
|
||||||
self.head_dim,
|
|
||||||
).permute(1, 0, 2, 3)
|
|
||||||
|
|
||||||
# Apply rotary embeddings
|
|
||||||
# q, k = self.rotary(q, k, seq_length)
|
|
||||||
|
|
||||||
# Apply QK normalization if enabled
|
|
||||||
if self.qk_norm:
|
|
||||||
q = self.q_norm(q)
|
|
||||||
k = self.k_norm(k)
|
|
||||||
|
|
||||||
# Handle MQA head expansion
|
|
||||||
if self.num_key_value_heads != self.num_query_heads:
|
|
||||||
k = k.repeat_interleave(
|
|
||||||
self.num_query_heads // self.num_key_value_heads,
|
|
||||||
dim=2,
|
|
||||||
)
|
|
||||||
v = v.repeat_interleave(
|
|
||||||
self.num_query_heads // self.num_key_value_heads,
|
|
||||||
dim=2,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Compute attention
|
|
||||||
# Reshape for matmul: [batch, heads, seq_length, head_dim]
|
|
||||||
q = q.permute(1, 2, 0, 3)
|
|
||||||
k = k.permute(1, 2, 0, 3)
|
|
||||||
v = v.permute(1, 2, 0, 3)
|
|
||||||
|
|
||||||
attn_weights = (
|
|
||||||
torch.matmul(q, k.transpose(-2, -1)) * self.qk_scale
|
|
||||||
)
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
|
||||||
attn_weights = attn_weights + attention_mask
|
|
||||||
|
|
||||||
attn_weights = F.softmax(attn_weights, dim=-1)
|
|
||||||
|
|
||||||
output = torch.matmul(attn_weights, v)
|
|
||||||
|
|
||||||
# Reshape back to [batch, seq_length, hidden_size]
|
|
||||||
output = (
|
|
||||||
output.transpose(1, 2)
|
|
||||||
.contiguous()
|
|
||||||
.view(batch_size, seq_length, -1)
|
|
||||||
)
|
|
||||||
output = self.o_proj(output)
|
|
||||||
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
class EnhancedBytePredictor(nn.Module):
|
|
||||||
"""Enhanced byte prediction model with state-of-the-art features."""
|
|
||||||
|
|
||||||
def __init__(self, config: ModelConfig):
|
|
||||||
super().__init__()
|
|
||||||
self.config = config
|
|
||||||
|
|
||||||
# Token embeddings
|
|
||||||
self.tok_embeddings = nn.Embedding(
|
|
||||||
config.vocab_size, config.hidden_size
|
|
||||||
)
|
|
||||||
|
|
||||||
# Transformer layers
|
|
||||||
self.layers = nn.ModuleList(
|
|
||||||
[
|
|
||||||
nn.ModuleDict(
|
|
||||||
{
|
|
||||||
"attention": MultiQueryAttention(config),
|
|
||||||
"attention_norm": nn.LayerNorm(
|
|
||||||
config.hidden_size,
|
|
||||||
eps=config.layer_norm_eps,
|
|
||||||
),
|
|
||||||
"feed_forward": nn.Sequential(
|
|
||||||
nn.Linear(
|
|
||||||
config.hidden_size,
|
|
||||||
4 * config.hidden_size,
|
|
||||||
),
|
|
||||||
nn.GELU(),
|
|
||||||
nn.Linear(
|
|
||||||
4 * config.hidden_size,
|
|
||||||
config.hidden_size,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
"feed_forward_norm": nn.LayerNorm(
|
|
||||||
config.hidden_size,
|
|
||||||
eps=config.layer_norm_eps,
|
|
||||||
),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
for _ in range(config.num_layers)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
self.norm = nn.LayerNorm(
|
|
||||||
config.hidden_size, eps=config.layer_norm_eps
|
|
||||||
)
|
|
||||||
self.output = nn.Linear(
|
|
||||||
config.hidden_size, config.vocab_size, bias=False
|
|
||||||
)
|
|
||||||
|
|
||||||
# Initialize weights
|
|
||||||
self.apply(self._init_weights)
|
|
||||||
|
|
||||||
def _init_weights(self, module: nn.Module) -> None:
|
|
||||||
"""Initialize weights with scaled normal distribution."""
|
|
||||||
if isinstance(module, nn.Linear):
|
|
||||||
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
|
||||||
if module.bias is not None:
|
|
||||||
torch.nn.init.zeros_(module.bias)
|
|
||||||
elif isinstance(module, nn.Embedding):
|
|
||||||
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
input_ids: torch.Tensor,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Forward pass of the model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_ids: Tensor of shape (batch_size, sequence_length)
|
|
||||||
attention_mask: Optional attention mask
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tensor of logits with shape (batch_size, sequence_length, vocab_size)
|
|
||||||
"""
|
|
||||||
hidden_states = self.tok_embeddings(input_ids)
|
|
||||||
|
|
||||||
# Create causal mask if needed
|
|
||||||
if attention_mask is None:
|
|
||||||
attention_mask = torch.triu(
|
|
||||||
torch.ones(
|
|
||||||
(input_ids.size(1), input_ids.size(1)),
|
|
||||||
device=input_ids.device,
|
|
||||||
dtype=torch.bool,
|
|
||||||
),
|
|
||||||
diagonal=1,
|
|
||||||
)
|
|
||||||
attention_mask = attention_mask.masked_fill(
|
|
||||||
attention_mask == 1, float("-inf")
|
|
||||||
)
|
|
||||||
|
|
||||||
# Apply transformer layers
|
|
||||||
for layer in self.layers:
|
|
||||||
# Attention block
|
|
||||||
hidden_states = hidden_states + layer["attention"](
|
|
||||||
layer["attention_norm"](hidden_states), attention_mask
|
|
||||||
)
|
|
||||||
|
|
||||||
# Feed-forward block
|
|
||||||
hidden_states = hidden_states + layer["feed_forward"](
|
|
||||||
layer["feed_forward_norm"](hidden_states)
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = self.norm(hidden_states)
|
|
||||||
logits = self.output(hidden_states)
|
|
||||||
|
|
||||||
return logits
|
|
||||||
|
|
||||||
def compute_loss(
|
|
||||||
self,
|
|
||||||
input_ids: torch.Tensor,
|
|
||||||
target_ids: torch.Tensor,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Compute cross entropy loss.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_ids: Input token ids
|
|
||||||
target_ids: Target token ids
|
|
||||||
attention_mask: Optional attention mask
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Loss value
|
|
||||||
"""
|
|
||||||
logits = self(input_ids, attention_mask)
|
|
||||||
loss = F.cross_entropy(
|
|
||||||
rearrange(logits, "b s v -> (b s) v"),
|
|
||||||
rearrange(target_ids, "b s -> (b s)"),
|
|
||||||
)
|
|
||||||
return loss
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def _generate(
|
|
||||||
self,
|
|
||||||
input_ids: torch.Tensor,
|
|
||||||
max_new_tokens: int = 100,
|
|
||||||
temperature: float = 1.0,
|
|
||||||
top_k: Optional[int] = None,
|
|
||||||
top_p: Optional[float] = None,
|
|
||||||
repetition_penalty: float = 1.0,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Generate new tokens autoregressively.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_ids: Starting sequence
|
|
||||||
max_new_tokens: Number of tokens to generate
|
|
||||||
temperature: Sampling temperature
|
|
||||||
top_k: K for top-k sampling
|
|
||||||
top_p: P for nucleus sampling
|
|
||||||
repetition_penalty: Penalty for repeating tokens
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Generated sequence
|
|
||||||
"""
|
|
||||||
batch_size, seq_length = input_ids.shape
|
|
||||||
generated = input_ids.clone()
|
|
||||||
|
|
||||||
for _ in range(max_new_tokens):
|
|
||||||
if generated.size(1) >= self.config.max_sequence_length:
|
|
||||||
break
|
|
||||||
|
|
||||||
# Forward pass
|
|
||||||
logits = self(generated)[:, -1, :]
|
|
||||||
|
|
||||||
# Apply temperature
|
|
||||||
logits = logits / temperature
|
|
||||||
|
|
||||||
# Apply repetition penalty
|
|
||||||
if repetition_penalty != 1.0:
|
|
||||||
for i in range(batch_size):
|
|
||||||
for token_id in set(generated[i].tolist()):
|
|
||||||
logits[i, token_id] /= repetition_penalty
|
|
||||||
|
|
||||||
# Apply top-k sampling
|
|
||||||
if top_k is not None:
|
|
||||||
indices_to_remove = (
|
|
||||||
logits
|
|
||||||
< torch.topk(logits, top_k)[0][..., -1, None]
|
|
||||||
)
|
|
||||||
logits[indices_to_remove] = float("-inf")
|
|
||||||
|
|
||||||
# Apply nucleus (top-p) sampling
|
|
||||||
if top_p is not None:
|
|
||||||
sorted_logits, sorted_indices = torch.sort(
|
|
||||||
logits, descending=True
|
|
||||||
)
|
|
||||||
cumulative_probs = torch.cumsum(
|
|
||||||
F.softmax(sorted_logits, dim=-1), dim=-1
|
|
||||||
)
|
|
||||||
|
|
||||||
# Remove tokens with cumulative probability above the threshold
|
|
||||||
sorted_indices_to_remove = cumulative_probs > top_p
|
|
||||||
sorted_indices_to_remove[..., 1:] = (
|
|
||||||
sorted_indices_to_remove[..., :-1].clone()
|
|
||||||
)
|
|
||||||
sorted_indices_to_remove[..., 0] = 0
|
|
||||||
|
|
||||||
indices_to_remove = torch.zeros_like(
|
|
||||||
logits, dtype=torch.bool
|
|
||||||
)
|
|
||||||
indices_to_remove.scatter_(
|
|
||||||
1, sorted_indices, sorted_indices_to_remove
|
|
||||||
)
|
|
||||||
logits[indices_to_remove] = float("-inf")
|
|
||||||
|
|
||||||
# Sample next token
|
|
||||||
probs = F.softmax(logits, dim=-1)
|
|
||||||
next_token = torch.multinomial(probs, num_samples=1)
|
|
||||||
|
|
||||||
# Append to sequence
|
|
||||||
generated = torch.cat([generated, next_token], dim=1)
|
|
||||||
|
|
||||||
return generated
|
|
||||||
|
|
||||||
def generate(
|
|
||||||
self,
|
|
||||||
input_ids: torch.Tensor,
|
|
||||||
max_new_tokens: int = 100,
|
|
||||||
temperature: float = 1.0,
|
|
||||||
top_k: Optional[int] = None,
|
|
||||||
top_p: Optional[float] = None,
|
|
||||||
repetition_penalty: float = 1.0,
|
|
||||||
):
|
|
||||||
tensor_data = self._generate(
|
|
||||||
input_ids=input_ids,
|
|
||||||
max_new_tokens=max_new_tokens,
|
|
||||||
temperature=temperature,
|
|
||||||
top_k=top_k,
|
|
||||||
top_p=top_p,
|
|
||||||
repetition_penalty=repetition_penalty,
|
|
||||||
)
|
|
||||||
|
|
||||||
return tensor_to_data(tensor_data)
|
|
||||||
|
|
||||||
|
|
||||||
# import torch
|
|
||||||
# from typing import Optional
|
|
||||||
|
|
||||||
|
|
||||||
class DataType(Enum):
|
|
||||||
TEXT = "text"
|
|
||||||
IMAGE = "image"
|
|
||||||
AUDIO = "audio"
|
|
||||||
VIDEO = "video"
|
|
||||||
BINARY = "binary"
|
|
||||||
|
|
||||||
|
|
||||||
class ByteDetokenizer:
|
|
||||||
"""Utility class for converting model output bytes back to original data formats."""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def tensor_to_bytes(tensor: torch.Tensor) -> bytes:
|
|
||||||
"""Convert model output tensor to bytes."""
|
|
||||||
# Convert logits/probabilities to byte values
|
|
||||||
if tensor.dim() > 1:
|
|
||||||
# If we have logits, convert to byte indices
|
|
||||||
byte_indices = tensor.argmax(dim=-1)
|
|
||||||
else:
|
|
||||||
byte_indices = tensor
|
|
||||||
|
|
||||||
# Convert to Python bytes
|
|
||||||
return bytes(
|
|
||||||
byte_indices.cpu().numpy().astype(np.uint8).tolist()
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def decode_text(byte_sequence: bytes) -> str:
|
|
||||||
"""Convert bytes to text."""
|
|
||||||
try:
|
|
||||||
return byte_sequence.decode("utf-8")
|
|
||||||
except UnicodeDecodeError:
|
|
||||||
# Try with error handling
|
|
||||||
return byte_sequence.decode("utf-8", errors="replace")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def decode_image(
|
|
||||||
byte_sequence: bytes,
|
|
||||||
mode: str = "RGB",
|
|
||||||
size: Optional[tuple] = None,
|
|
||||||
) -> Image.Image:
|
|
||||||
"""Convert bytes to image.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
byte_sequence: Raw image bytes
|
|
||||||
mode: Image mode (RGB, RGBA, L, etc.)
|
|
||||||
size: Optional tuple of (width, height)
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Try to load as-is first (for standard image formats)
|
|
||||||
img = Image.open(io.BytesIO(byte_sequence))
|
|
||||||
if size:
|
|
||||||
img = img.resize(size)
|
|
||||||
return img
|
|
||||||
except:
|
|
||||||
# If failed, assume raw pixel data
|
|
||||||
if not size:
|
|
||||||
# Try to determine size from byte count
|
|
||||||
pixel_count = len(byte_sequence) // len(mode)
|
|
||||||
size = (
|
|
||||||
int(np.sqrt(pixel_count)),
|
|
||||||
int(np.sqrt(pixel_count)),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Convert raw bytes to pixel array
|
|
||||||
pixels = np.frombuffer(byte_sequence, dtype=np.uint8)
|
|
||||||
pixels = pixels.reshape((*size, len(mode)))
|
|
||||||
|
|
||||||
return Image.fromarray(pixels, mode=mode)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def decode_audio(
|
|
||||||
byte_sequence: bytes,
|
|
||||||
sample_rate: int = 44100,
|
|
||||||
channels: int = 2,
|
|
||||||
sample_width: int = 2,
|
|
||||||
) -> np.ndarray:
|
|
||||||
"""Convert bytes to audio samples.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
byte_sequence: Raw audio bytes
|
|
||||||
sample_rate: Audio sample rate in Hz
|
|
||||||
channels: Number of audio channels
|
|
||||||
sample_width: Bytes per sample (1, 2, or 4)
|
|
||||||
"""
|
|
||||||
# Determine format string based on sample width
|
|
||||||
format_str = {
|
|
||||||
1: "b", # signed char
|
|
||||||
2: "h", # short
|
|
||||||
4: "i", # int
|
|
||||||
}[sample_width]
|
|
||||||
|
|
||||||
# Unpack bytes to samples
|
|
||||||
sample_count = len(byte_sequence) // (channels * sample_width)
|
|
||||||
samples = struct.unpack(
|
|
||||||
f"<{sample_count * channels}{format_str}", byte_sequence
|
|
||||||
)
|
|
||||||
|
|
||||||
# Reshape to [samples, channels]
|
|
||||||
return np.array(samples).reshape(-1, channels)
|
|
||||||
|
|
||||||
def decode_data(
|
|
||||||
self,
|
|
||||||
model_output: Union[torch.Tensor, bytes],
|
|
||||||
data_type: DataType,
|
|
||||||
**kwargs,
|
|
||||||
) -> Union[str, Image.Image, np.ndarray, bytes]:
|
|
||||||
"""Main method to decode model output to desired format.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_output: Either tensor from model or raw bytes
|
|
||||||
data_type: Type of data to decode to
|
|
||||||
**kwargs: Additional parameters for specific decoders
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Decoded data in specified format
|
|
||||||
"""
|
|
||||||
# Convert tensor to bytes if needed
|
|
||||||
if isinstance(model_output, torch.Tensor):
|
|
||||||
byte_sequence = self.tensor_to_bytes(model_output)
|
|
||||||
else:
|
|
||||||
byte_sequence = model_output
|
|
||||||
|
|
||||||
# Decode based on type
|
|
||||||
if data_type == DataType.TEXT:
|
|
||||||
return self.decode_text(byte_sequence)
|
|
||||||
elif data_type == DataType.IMAGE:
|
|
||||||
return self.decode_image(byte_sequence, **kwargs)
|
|
||||||
elif data_type == DataType.AUDIO:
|
|
||||||
return self.decode_audio(byte_sequence, **kwargs)
|
|
||||||
elif data_type == DataType.VIDEO:
|
|
||||||
raise NotImplementedError(
|
|
||||||
"Video decoding not yet implemented"
|
|
||||||
)
|
|
||||||
else: # BINARY
|
|
||||||
return byte_sequence
|
|
||||||
|
|
||||||
|
|
||||||
# Usage example
|
|
||||||
|
|
||||||
|
|
||||||
class Modality(Enum):
|
|
||||||
TEXT = auto()
|
|
||||||
IMAGE = auto()
|
|
||||||
AUDIO = auto()
|
|
||||||
VIDEO = auto()
|
|
||||||
BINARY = auto()
|
|
||||||
MULTIMODAL = auto()
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ModalityInfo:
|
|
||||||
"""Information about detected modality."""
|
|
||||||
|
|
||||||
modality: Modality
|
|
||||||
confidence: float
|
|
||||||
metadata: Dict[str, any]
|
|
||||||
sub_modalities: Optional[List["ModalityInfo"]] = None
|
|
||||||
|
|
||||||
|
|
||||||
class ModalityDetector:
|
|
||||||
"""Detects data modalities from byte sequences."""
|
|
||||||
|
|
||||||
# Common file signatures (magic numbers)
|
|
||||||
SIGNATURES = {
|
|
||||||
# Images
|
|
||||||
b"\xFF\xD8\xFF": "JPEG",
|
|
||||||
b"\x89PNG\r\n\x1a\n": "PNG",
|
|
||||||
b"GIF87a": "GIF",
|
|
||||||
b"GIF89a": "GIF",
|
|
||||||
b"RIFF": "WEBP",
|
|
||||||
# Audio
|
|
||||||
b"RIFF....WAVE": "WAV",
|
|
||||||
b"ID3": "MP3",
|
|
||||||
b"\xFF\xFB": "MP3",
|
|
||||||
b"OggS": "OGG",
|
|
||||||
# Video
|
|
||||||
b"\x00\x00\x00\x18ftypmp42": "MP4",
|
|
||||||
b"\x00\x00\x00\x1Cftypav01": "MP4",
|
|
||||||
b"\x1A\x45\xDF\xA3": "WEBM",
|
|
||||||
}
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.magic = magic.Magic(mime=True)
|
|
||||||
|
|
||||||
def _check_text_probability(self, data: bytes) -> float:
|
|
||||||
"""Estimate probability that data is text."""
|
|
||||||
# Check if data is valid UTF-8
|
|
||||||
try:
|
|
||||||
data.decode("utf-8")
|
|
||||||
# Count printable ASCII characters
|
|
||||||
printable = sum(1 for b in data if 32 <= b <= 126)
|
|
||||||
return printable / len(data)
|
|
||||||
except UnicodeDecodeError:
|
|
||||||
return 0.0
|
|
||||||
|
|
||||||
def _check_image_validity(self, data: bytes) -> Tuple[bool, Dict]:
|
|
||||||
"""Check if data is a valid image and extract metadata."""
|
|
||||||
try:
|
|
||||||
with io.BytesIO(data) as bio:
|
|
||||||
img = Image.open(bio)
|
|
||||||
return True, {
|
|
||||||
"format": img.format,
|
|
||||||
"size": img.size,
|
|
||||||
"mode": img.mode,
|
|
||||||
}
|
|
||||||
except:
|
|
||||||
return False, {}
|
|
||||||
|
|
||||||
def _check_audio_validity(self, data: bytes) -> Tuple[bool, Dict]:
|
|
||||||
"""Check if data is valid audio and extract metadata."""
|
|
||||||
try:
|
|
||||||
with io.BytesIO(data) as bio:
|
|
||||||
# Try to parse as WAV
|
|
||||||
with wave.open(bio) as wav:
|
|
||||||
return True, {
|
|
||||||
"channels": wav.getnchannels(),
|
|
||||||
"sample_width": wav.getsampwidth(),
|
|
||||||
"framerate": wav.getframerate(),
|
|
||||||
"frames": wav.getnframes(),
|
|
||||||
}
|
|
||||||
except:
|
|
||||||
# Check for other audio signatures
|
|
||||||
for sig in [b"ID3", b"\xFF\xFB", b"OggS"]:
|
|
||||||
if data.startswith(sig):
|
|
||||||
return True, {"format": "compressed_audio"}
|
|
||||||
return False, {}
|
|
||||||
|
|
||||||
def _detect_boundaries(
|
|
||||||
self, data: bytes
|
|
||||||
) -> List[Tuple[int, int, Modality]]:
|
|
||||||
"""Detect boundaries between different modalities in the data."""
|
|
||||||
boundaries = []
|
|
||||||
current_pos = 0
|
|
||||||
|
|
||||||
while current_pos < len(data):
|
|
||||||
# Look for known signatures
|
|
||||||
for sig, format_type in self.SIGNATURES.items():
|
|
||||||
if data[current_pos:].startswith(sig):
|
|
||||||
# Found a signature, determine its length
|
|
||||||
if format_type in ["JPEG", "PNG", "GIF"]:
|
|
||||||
# Find image end
|
|
||||||
try:
|
|
||||||
with io.BytesIO(
|
|
||||||
data[current_pos:]
|
|
||||||
) as bio:
|
|
||||||
img = Image.open(bio)
|
|
||||||
img.verify()
|
|
||||||
size = bio.tell()
|
|
||||||
boundaries.append(
|
|
||||||
(
|
|
||||||
current_pos,
|
|
||||||
current_pos + size,
|
|
||||||
Modality.IMAGE,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
current_pos += size
|
|
||||||
continue
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Check for text sections
|
|
||||||
text_prob = self._check_text_probability(
|
|
||||||
data[current_pos : current_pos + 1024]
|
|
||||||
)
|
|
||||||
if text_prob > 0.8:
|
|
||||||
# Look for end of text section
|
|
||||||
end_pos = current_pos + 1
|
|
||||||
while end_pos < len(data):
|
|
||||||
if (
|
|
||||||
self._check_text_probability(
|
|
||||||
data[end_pos : end_pos + 32]
|
|
||||||
)
|
|
||||||
< 0.5
|
|
||||||
):
|
|
||||||
break
|
|
||||||
end_pos += 1
|
|
||||||
boundaries.append(
|
|
||||||
(current_pos, end_pos, Modality.TEXT)
|
|
||||||
)
|
|
||||||
current_pos = end_pos
|
|
||||||
continue
|
|
||||||
|
|
||||||
current_pos += 1
|
|
||||||
|
|
||||||
return boundaries
|
|
||||||
|
|
||||||
def detect_modality(self, data: bytes) -> ModalityInfo:
|
|
||||||
"""Detect modality of byte sequence."""
|
|
||||||
# First check for single modality
|
|
||||||
mime_type = self.magic.from_buffer(data)
|
|
||||||
|
|
||||||
# Check text
|
|
||||||
text_prob = self._check_text_probability(data)
|
|
||||||
if text_prob > 0.9:
|
|
||||||
return ModalityInfo(
|
|
||||||
modality=Modality.TEXT,
|
|
||||||
confidence=text_prob,
|
|
||||||
metadata={"mime_type": mime_type},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check image
|
|
||||||
is_image, image_meta = self._check_image_validity(data)
|
|
||||||
if is_image:
|
|
||||||
return ModalityInfo(
|
|
||||||
modality=Modality.IMAGE,
|
|
||||||
confidence=1.0,
|
|
||||||
metadata={**image_meta, "mime_type": mime_type},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check audio
|
|
||||||
is_audio, audio_meta = self._check_audio_validity(data)
|
|
||||||
if is_audio:
|
|
||||||
return ModalityInfo(
|
|
||||||
modality=Modality.AUDIO,
|
|
||||||
confidence=1.0,
|
|
||||||
metadata={**audio_meta, "mime_type": mime_type},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check for multimodal content
|
|
||||||
boundaries = self._detect_boundaries(data)
|
|
||||||
if len(boundaries) > 1:
|
|
||||||
sub_modalities = []
|
|
||||||
for start, end, modality in boundaries:
|
|
||||||
chunk_data = data[start:end]
|
|
||||||
sub_info = self.detect_modality(chunk_data)
|
|
||||||
if sub_info.modality != Modality.BINARY:
|
|
||||||
sub_modalities.append(sub_info)
|
|
||||||
|
|
||||||
if sub_modalities:
|
|
||||||
return ModalityInfo(
|
|
||||||
modality=Modality.MULTIMODAL,
|
|
||||||
confidence=0.8,
|
|
||||||
metadata={"mime_type": "multipart/mixed"},
|
|
||||||
sub_modalities=sub_modalities,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Default to binary
|
|
||||||
return ModalityInfo(
|
|
||||||
modality=Modality.BINARY,
|
|
||||||
confidence=0.5,
|
|
||||||
metadata={"mime_type": mime_type},
|
|
||||||
)
|
|
||||||
|
|
||||||
def split_modalities(
|
|
||||||
self, data: bytes
|
|
||||||
) -> List[Tuple[Modality, bytes, Dict]]:
|
|
||||||
"""Split multimodal data into separate modalities."""
|
|
||||||
boundaries = self._detect_boundaries(data)
|
|
||||||
result = []
|
|
||||||
|
|
||||||
for start, end, modality in boundaries:
|
|
||||||
chunk = data[start:end]
|
|
||||||
info = self.detect_modality(chunk)
|
|
||||||
result.append((modality, chunk, info.metadata))
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
class AutoDetectBytesDecoder:
|
|
||||||
"""Decoder that automatically detects and decodes different modalities."""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.detector = ModalityDetector()
|
|
||||||
self.text_decoder = ByteDetokenizer() # From previous example
|
|
||||||
|
|
||||||
def decode(
|
|
||||||
self, data: bytes
|
|
||||||
) -> Union[str, Image.Image, np.ndarray, List[any]]:
|
|
||||||
"""Automatically detect and decode byte sequence."""
|
|
||||||
info = self.detector.detect_modality(data)
|
|
||||||
|
|
||||||
if info.modality == Modality.MULTIMODAL:
|
|
||||||
# Handle multimodal content
|
|
||||||
parts = self.detector.split_modalities(data)
|
|
||||||
return [
|
|
||||||
self.decode(chunk) for modality, chunk, _ in parts
|
|
||||||
]
|
|
||||||
|
|
||||||
if info.modality == Modality.TEXT:
|
|
||||||
return self.text_decoder.decode_text(data)
|
|
||||||
elif info.modality == Modality.IMAGE:
|
|
||||||
return self.text_decoder.decode_image(data)
|
|
||||||
elif info.modality == Modality.AUDIO:
|
|
||||||
return self.text_decoder.decode_audio(data)
|
|
||||||
else:
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
# # Example usage
|
|
||||||
# def demo_auto_detection():
|
|
||||||
# """Demonstrate auto modality detection."""
|
|
||||||
# # Create mixed content
|
|
||||||
# text = "Hello, World!".encode('utf-8')
|
|
||||||
|
|
||||||
# # Create a small test image
|
|
||||||
# img = Image.new('RGB', (100, 100), color='red')
|
|
||||||
# img_bytes = io.BytesIO()
|
|
||||||
# img.save(img_bytes, format='PNG')
|
|
||||||
|
|
||||||
# # Combine into multimodal content
|
|
||||||
# mixed_content = text + img_bytes.getvalue()
|
|
||||||
|
|
||||||
# # Initialize decoder
|
|
||||||
# decoder = AutoDetectBytesDecoder()
|
|
||||||
|
|
||||||
# # Decode
|
|
||||||
# result = decoder.decode(mixed_content)
|
|
||||||
|
|
||||||
# if isinstance(result, list):
|
|
||||||
# print("Detected multimodal content:")
|
|
||||||
# for i, part in enumerate(result):
|
|
||||||
# print(f"Part {i+1}: {type(part)}")
|
|
||||||
|
|
||||||
# if __name__ == "__main__":
|
|
||||||
# demo_auto_detection()
|
|
||||||
|
|
||||||
|
|
||||||
def tensor_to_data(tensor: Tensor):
|
|
||||||
byte_sequence = ByteDetokenizer.tensor_to_bytes(tensor)
|
|
||||||
|
|
||||||
# Initialize auto-detector
|
|
||||||
decoder = AutoDetectBytesDecoder()
|
|
||||||
|
|
||||||
# Decode with automatic detection
|
|
||||||
result = decoder.decode(byte_sequence)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def demo_byte_predictor():
|
|
||||||
"""Demo with smaller dimensions to test."""
|
|
||||||
# Initialize model configuration with adjusted dimensions
|
|
||||||
config = ModelConfig(
|
|
||||||
vocab_size=256,
|
|
||||||
hidden_size=128, # Smaller for testing
|
|
||||||
num_layers=2, # Fewer layers for testing
|
|
||||||
num_key_value_heads=2,
|
|
||||||
num_query_heads=4,
|
|
||||||
dropout=0.1,
|
|
||||||
max_sequence_length=1024,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Initialize model
|
|
||||||
model = EnhancedBytePredictor(config)
|
|
||||||
logger.info("Model initialized")
|
|
||||||
|
|
||||||
# Move to GPU if available
|
|
||||||
device = torch.device(
|
|
||||||
"cuda" if torch.cuda.is_available() else "cpu"
|
|
||||||
)
|
|
||||||
model = model.to(device)
|
|
||||||
logger.info(f"Using device: {device}")
|
|
||||||
|
|
||||||
# Create sample input data
|
|
||||||
batch_size = 2
|
|
||||||
seq_length = 16 # Shorter sequence for testing
|
|
||||||
input_ids = torch.randint(
|
|
||||||
0, config.vocab_size, (batch_size, seq_length), device=device
|
|
||||||
)
|
|
||||||
logger.info(f"Created input tensor of shape: {input_ids.shape}")
|
|
||||||
|
|
||||||
# Test forward pass
|
|
||||||
try:
|
|
||||||
logits = model(input_ids)
|
|
||||||
logger.info(
|
|
||||||
f"Forward pass successful! Output shape: {logits.shape}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Test loss computation
|
|
||||||
target_ids = torch.randint(
|
|
||||||
0,
|
|
||||||
config.vocab_size,
|
|
||||||
(batch_size, seq_length),
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
loss = model.compute_loss(input_ids, target_ids)
|
|
||||||
logger.info(
|
|
||||||
f"Loss computation successful! Loss value: {loss.item():.4f}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Test generation
|
|
||||||
prompt = torch.randint(
|
|
||||||
0,
|
|
||||||
config.vocab_size,
|
|
||||||
(1, 4), # Very short prompt for testing
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
generated = model.generate(
|
|
||||||
prompt, max_new_tokens=8, temperature=0.8, top_k=50
|
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
f"Generation successful! Generated shape: {generated.shape}"
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error during execution: {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# Set up logging
|
|
||||||
# logger.remove() # Remove default handler
|
|
||||||
# logger.add(sys.stderr, format="<green>{time:HH:mm:ss}</green> | {level} | {message}")
|
|
||||||
|
|
||||||
demo_byte_predictor()
|
|
Loading…
Reference in new issue