import math
from dataclasses import dataclass
from typing import Dict, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from loguru import logger
from torch import Tensor
from torch.utils.data import DataLoader, TensorDataset


@dataclass
class TransformerConfig:
    """Configuration class for MoE Transformer model parameters."""

    vocab_size: int = 50257
    hidden_size: int = 768
    num_attention_heads: int = 12
    num_expert_layers: int = 4
    num_experts: int = 8
    expert_capacity: int = 32
    max_position_embeddings: int = 1024
    dropout_prob: float = 0.1
    layer_norm_epsilon: float = 1e-5
    initializer_range: float = 0.02
    num_query_groups: int = 4  # For multi-query attention


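# --- Illustrative sketch (added for clarity; not part of the original module) ---
# The attention layers below assume hidden_size divides evenly by num_attention_heads,
# and num_attention_heads by num_query_groups. With the defaults this gives
# head_dim = 768 // 12 = 64 and heads_per_group = 12 // 4 = 3, i.e. each key/value
# group is shared by 3 query heads. The helper name is ours, for illustration only.
def _derived_attention_sizes(cfg: TransformerConfig) -> Dict[str, int]:
    """Return the attention sizes implied by a config (illustration only)."""
    assert cfg.hidden_size % cfg.num_attention_heads == 0
    assert cfg.num_attention_heads % cfg.num_query_groups == 0
    return {
        "head_dim": cfg.hidden_size // cfg.num_attention_heads,
        "heads_per_group": cfg.num_attention_heads // cfg.num_query_groups,
    }

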
class ExpertLayer(nn.Module):
    """Individual expert neural network."""

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.fc1 = nn.Linear(config.hidden_size, 4 * config.hidden_size)
        self.fc2 = nn.Linear(4 * config.hidden_size, config.hidden_size)
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(config.dropout_prob)

    def forward(self, x: Tensor) -> Tensor:
        x = self.fc1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x


class MixtureOfExperts(nn.Module):
    """Mixture of Experts layer with dynamic routing."""

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.num_experts = config.num_experts
        self.expert_capacity = config.expert_capacity

        # Create expert networks
        self.experts = nn.ModuleList(
            [ExpertLayer(config) for _ in range(config.num_experts)]
        )

        # Router network
        self.router = nn.Linear(config.hidden_size, config.num_experts)

    def forward(self, x: Tensor) -> Tuple[Tensor, Dict]:
        """Route inputs to experts and combine outputs."""
        batch_size, seq_len, hidden_size = x.shape

        # Calculate routing probabilities
        router_logits = self.router(x)
        routing_weights = F.softmax(router_logits, dim=-1)

        # Select top-k experts per token
        top_k = 2
        gates, indices = torch.topk(routing_weights, top_k, dim=-1)
        # Renormalize the selected gate values so they sum to 1 per token
        gates = gates / gates.sum(dim=-1, keepdim=True)

        # Process inputs through selected experts
        final_output = torch.zeros_like(x)
        router_load = torch.zeros(self.num_experts, device=x.device)

        for i in range(top_k):
            expert_index = indices[..., i]
            gate = gates[..., i : i + 1]

            # Count expert assignments for the load-balancing loss
            for j in range(self.num_experts):
                router_load[j] += (expert_index == j).float().sum()

            # Process tokens through their selected expert
            for j in range(self.num_experts):
                mask = expert_index == j
                if not mask.any():
                    continue

                expert_input = x[mask]
                expert_output = self.experts[j](expert_input)
                final_output[mask] += gate[mask] * expert_output

        # Load-balancing loss: squared coefficient of variation of expert load
        aux_loss = router_load.var() / (router_load.mean() ** 2)

        return final_output, {"load_balancing_loss": aux_loss}


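# --- Illustrative sketch (added for clarity; the function name is ours) ---
# Quick shape check for the MoE layer: a (batch, seq_len, hidden_size) tensor goes in,
# a tensor of the same shape comes out, plus a scalar load-balancing loss computed
# from how evenly tokens were assigned to experts.
def _moe_shape_example() -> None:
    cfg = TransformerConfig(hidden_size=64, num_experts=4)
    moe = MixtureOfExperts(cfg)
    x = torch.randn(2, 8, cfg.hidden_size)  # (batch=2, seq_len=8, hidden=64)
    out, aux = moe(x)
    assert out.shape == x.shape
    assert aux["load_balancing_loss"].ndim == 0  # scalar auxiliary loss

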
class MultiQueryAttention(nn.Module):
    """Multi-Query Attention with grouped key/value heads shared across query heads."""

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.num_attention_heads = config.num_attention_heads
        self.num_query_groups = config.num_query_groups
        self.hidden_size = config.hidden_size
        self.head_dim = config.hidden_size // config.num_attention_heads

        # Query projection maintains the full head dimension
        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size)

        # Key and value projections use a reduced number of heads (query groups)
        self.k_proj = nn.Linear(
            config.hidden_size, self.head_dim * config.num_query_groups
        )
        self.v_proj = nn.Linear(
            config.hidden_size, self.head_dim * config.num_query_groups
        )

        self.dropout = nn.Dropout(config.dropout_prob)

        # Heads per group, used to expand keys/values to the full head count
        self.heads_per_group = (
            self.num_attention_heads // self.num_query_groups
        )

    def forward(
        self,
        hidden_states: Tensor,
        attention_mask: Optional[Tensor] = None,
        cache: Optional[Dict[str, Tensor]] = None,
    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        batch_size, seq_length, _ = hidden_states.shape

        # Project queries, keys, and values
        queries = self.q_proj(hidden_states)
        keys = self.k_proj(hidden_states)
        values = self.v_proj(hidden_states)

        # Reshape queries to the full number of heads
        queries = queries.view(
            batch_size, seq_length, self.num_attention_heads, self.head_dim
        )

        # Reshape keys and values to the number of query groups
        keys = keys.view(
            batch_size, seq_length, self.num_query_groups, self.head_dim
        )
        values = values.view(
            batch_size, seq_length, self.num_query_groups, self.head_dim
        )

        # Transpose for batched matrix multiplication
        queries = queries.transpose(1, 2)  # (batch, n_heads, seq_len, head_dim)
        keys = keys.transpose(1, 2)  # (batch, n_groups, seq_len, head_dim)
        values = values.transpose(1, 2)  # (batch, n_groups, seq_len, head_dim)

        # Repeat keys and values for each head in the group
        keys = keys.repeat_interleave(self.heads_per_group, dim=1)
        values = values.repeat_interleave(self.heads_per_group, dim=1)

        # Compute scaled attention scores
        scale = 1.0 / math.sqrt(self.head_dim)
        scores = torch.matmul(queries, keys.transpose(-2, -1)) * scale

        if attention_mask is not None:
            # Expand the padding mask to match the score dimensions
            expanded_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            expanded_mask = expanded_mask.expand(
                batch_size, self.num_attention_heads, seq_length, seq_length
            )
            mask_value = torch.finfo(scores.dtype).min
            additive_mask = expanded_mask.eq(0).float() * mask_value
            scores = scores + additive_mask

        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # Compute attention output and merge heads back into hidden_size
        attention_output = torch.matmul(attention_weights, values)
        attention_output = attention_output.transpose(1, 2)
        attention_output = attention_output.reshape(batch_size, seq_length, -1)

        return attention_output, None


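# --- Illustrative sketch (added for clarity; not used by the model itself) ---
# In the grouped/multi-query attention above, queries keep the full head count while
# keys and values are projected to num_query_groups heads and repeated so that
# heads_per_group query heads share each key/value group. This example only checks
# that the output keeps the (batch, seq_len, hidden_size) shape under a padding mask.
def _mqa_shape_example() -> None:
    cfg = TransformerConfig(hidden_size=64, num_attention_heads=8, num_query_groups=2)
    attn = MultiQueryAttention(cfg)
    hidden = torch.randn(2, 8, cfg.hidden_size)
    mask = torch.ones(2, 8)  # 1 = real token, 0 = padding
    mask[:, -2:] = 0  # pretend the last two positions are padding
    out, _ = attn(hidden, attention_mask=mask)
    assert out.shape == hidden.shape

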
class MoETransformer(nn.Module):
    """
    Transformer model with Mixture of Experts and Multi-Query Attention.

    Features:
    - Multi-Query Attention mechanism for efficient inference
    - Mixture of Experts for dynamic routing and specialization
    - Auxiliary load-balancing loss to encourage even expert utilization
    - Built-in logging and monitoring
    - Type annotations for better code maintainability
    """

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.config = config

        # Initialize components
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embedding = nn.Embedding(
            config.max_position_embeddings, config.hidden_size
        )

        # Multi-Query Attention layers
        self.attention_layers = nn.ModuleList(
            [MultiQueryAttention(config) for _ in range(config.num_expert_layers)]
        )

        # Mixture of Experts layers
        self.moe_layers = nn.ModuleList(
            [MixtureOfExperts(config) for _ in range(config.num_expert_layers)]
        )

        # Layer normalization and dropout
        self.layer_norm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_epsilon
        )
        self.dropout = nn.Dropout(config.dropout_prob)

        # Output projection
        self.output_projection = nn.Linear(config.hidden_size, config.vocab_size)

        # Initialize weights
        self.apply(self._init_weights)
        logger.info("Initialized MoETransformer model")

    def _init_weights(self, module: nn.Module):
        """Initialize model weights."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(
                mean=0.0, std=self.config.initializer_range
            )
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()

    def get_position_embeddings(self, position_ids: Tensor) -> Tensor:
        """Generate position embeddings."""
        return self.position_embedding(position_ids)

    def forward(
        self,
        input_ids: Tensor,
        attention_mask: Optional[Tensor] = None,
        position_ids: Optional[Tensor] = None,
        cache: Optional[Dict[str, Tensor]] = None,
    ) -> Tuple[Tensor, Dict]:
        """
        Forward pass through the model.

        Args:
            input_ids: Input token IDs
            attention_mask: Attention mask for padding
            position_ids: Position IDs for positional encoding
            cache: Cache for key/value states in generation

        Returns:
            tuple: (logits, auxiliary_outputs)
        """
        batch_size, seq_length = input_ids.shape

        if position_ids is None:
            position_ids = torch.arange(
                seq_length, dtype=torch.long, device=input_ids.device
            )
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

        # Get embeddings
        inputs_embeds = self.embedding(input_ids)
        position_embeds = self.get_position_embeddings(position_ids)
        hidden_states = inputs_embeds + position_embeds

        # Initialize auxiliary outputs
        aux_outputs = {"moe_losses": []}

        # Process through transformer layers
        for attention_layer, moe_layer in zip(
            self.attention_layers, self.moe_layers
        ):
            # Multi-Query Attention with residual connection
            attention_output, _ = attention_layer(
                hidden_states, attention_mask, cache
            )
            hidden_states = self.layer_norm(hidden_states + attention_output)

            # Mixture of Experts with residual connection
            moe_output, moe_aux = moe_layer(hidden_states)
            hidden_states = self.layer_norm(hidden_states + moe_output)
            aux_outputs["moe_losses"].append(moe_aux["load_balancing_loss"])

        # Final output projection
        logits = self.output_projection(hidden_states)

        return logits, aux_outputs

    def fetch_loss(
        self,
        logits: Tensor,
        labels: Tensor,
        aux_outputs: Dict,
        reduction: str = "mean",
    ) -> Tensor:
        """
        Calculate the total loss, including the MoE load-balancing losses.

        Args:
            logits: Model output logits
            labels: Ground truth labels
            aux_outputs: Auxiliary outputs from forward pass
            reduction: Loss reduction method

        Returns:
            Tensor: Total loss
        """
        # Cross-entropy loss over the flattened vocabulary dimension
        ce_loss = F.cross_entropy(
            logits.view(-1, self.config.vocab_size),
            labels.view(-1),
            reduction=reduction,
        )

        # Mean MoE load-balancing loss across layers
        moe_loss = torch.stack(aux_outputs["moe_losses"]).mean()

        # Combine losses (0.01 weights the auxiliary loss)
        total_loss = ce_loss + 0.01 * moe_loss

        logger.debug(
            f"CE Loss: {ce_loss.item():.4f}, "
            f"MoE Loss: {moe_loss.item():.4f}"
        )

        return total_loss

    @torch.no_grad()
    def generate(
        self,
        input_ids: Tensor,
        max_length: int = 100,
        temperature: float = 1.0,
        top_k: int = 50,
        top_p: float = 0.9,
    ) -> Tensor:
        """
        Generate text using the model.

        Args:
            input_ids: Initial input tokens
            max_length: Maximum number of new tokens to generate
            temperature: Sampling temperature
            top_k: Number of highest probability tokens to keep
            top_p: Cumulative probability for nucleus sampling

        Returns:
            Tensor: Generated token IDs
        """
        batch_size = input_ids.shape[0]
        device = input_ids.device

        # Initialize sequence with input_ids
        generated = input_ids

        # Cache for key-value pairs (currently unused by the attention layers)
        cache = {}

        for _ in range(max_length):
            # Get position IDs for the current sequence
            position_ids = torch.arange(
                generated.shape[1], dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)

            # Forward pass
            logits, _ = self.forward(
                generated, position_ids=position_ids, cache=cache
            )

            # Get next-token logits
            next_token_logits = logits[:, -1, :] / temperature

            # Apply top-k filtering
            if top_k > 0:
                indices_to_remove = (
                    next_token_logits
                    < torch.topk(next_token_logits, top_k)[0][..., -1, None]
                )
                next_token_logits[indices_to_remove] = float("-inf")

            # Apply top-p (nucleus) filtering
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(
                    next_token_logits, descending=True
                )
                cumulative_probs = torch.cumsum(
                    F.softmax(sorted_logits, dim=-1), dim=-1
                )

                # Remove tokens with cumulative probability above the threshold,
                # always keeping the highest-probability token
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = (
                    sorted_indices_to_remove[..., :-1].clone()
                )
                sorted_indices_to_remove[..., 0] = False

                # Scatter the mask back to vocabulary order before applying it
                indices_to_remove = sorted_indices_to_remove.scatter(
                    1, sorted_indices, sorted_indices_to_remove
                )
                next_token_logits[indices_to_remove] = float("-inf")

            # Sample next token
            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            # Append next token to sequence
            generated = torch.cat((generated, next_token), dim=1)

            # Stop if every sequence produced the end-of-sequence token
            # (here assumed to be the last vocabulary id)
            if (next_token == self.config.vocab_size - 1).all():
                break

        return generated


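# --- Illustrative sketch (added for clarity; a standalone version of the nucleus
# filtering used in MoETransformer.generate, under the same assumptions) ---
# Tokens are sorted by probability, kept while the cumulative probability stays
# below top_p (the most likely token is always kept), and the boolean mask is
# scattered back to vocabulary order before masking the logits.
def _nucleus_filter_example(logits: Tensor, top_p: float = 0.9) -> Tensor:
    """Mask out the low-probability tail of (batch, vocab) logits (illustration only)."""
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_mask = cumulative_probs > top_p
    sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
    sorted_mask[..., 0] = False
    mask = sorted_mask.scatter(1, sorted_indices, sorted_mask)
    return logits.masked_fill(mask, float("-inf"))


# Example: _nucleus_filter_example(torch.tensor([[2.0, 1.0, 0.5, -1.0]]), top_p=0.8)
# keeps only the highest-probability tokens and sets the rest to -inf.

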
# Initialize model configuration
config = TransformerConfig(
    vocab_size=50257,
    hidden_size=768,
    num_attention_heads=12,
    num_expert_layers=4,
    num_experts=8,
    expert_capacity=32,
    max_position_embeddings=1024,
    num_query_groups=4,
)


def prepare_sample_data(
    batch_size: int = 8,
    seq_length: int = 512,
    vocab_size: int = 50257,
) -> DataLoader:
    """Create random sample data for demonstration."""
    # Create random input sequences (100 samples)
    input_ids = torch.randint(0, vocab_size, (100, seq_length))

    # Create random target sequences (real data would shift the inputs by one position)
    labels = torch.randint(0, vocab_size, (100, seq_length))

    # Create attention masks (1 for real tokens, 0 for padding)
    attention_mask = torch.ones_like(input_ids)

    # Create dataset and dataloader
    dataset = TensorDataset(input_ids, attention_mask, labels)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    return dataloader


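# --- Illustrative sketch (added for clarity; not used by the demo pipeline) ---
# The random labels above only stand in for real data. For actual language modelling
# the labels are usually the input ids shifted left by one position, so that the
# model at position t predicts token t+1. The pad_id value here is arbitrary.
def _shifted_labels_example(input_ids: Tensor, pad_id: int = 0) -> Tensor:
    """Return next-token labels derived from input_ids (illustration only)."""
    labels = input_ids.new_full(input_ids.shape, pad_id)
    labels[:, :-1] = input_ids[:, 1:]
    return labels

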
def train_step(
    model: MoETransformer,
    batch: tuple,
    optimizer: torch.optim.Optimizer,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
) -> float:
    """Execute a single training step."""
    model.train()
    optimizer.zero_grad()

    # Unpack batch
    input_ids, attention_mask, labels = [b.to(device) for b in batch]

    # Forward pass
    logits, aux_outputs = model(
        input_ids=input_ids, attention_mask=attention_mask
    )

    # Calculate loss
    loss = model.fetch_loss(logits, labels, aux_outputs)

    # Backward pass
    loss.backward()
    optimizer.step()

    return loss.item()


def main():
    # Set device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Using device: {device}")

    # Initialize model
    model = MoETransformer(config).to(device)
    logger.info("Model initialized")

    # Setup optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=1e-4, weight_decay=0.01
    )

    # Prepare data
    dataloader = prepare_sample_data()
    logger.info("Data prepared")

    # Training loop
    num_epochs = 3
    for epoch in range(num_epochs):
        epoch_losses = []

        for batch_idx, batch in enumerate(dataloader):
            loss = train_step(model, batch, optimizer, device)
            epoch_losses.append(loss)

            if batch_idx % 10 == 0:
                logger.info(
                    f"Epoch {epoch + 1}/{num_epochs} "
                    f"Batch {batch_idx}/{len(dataloader)} "
                    f"Loss: {loss:.4f}"
                )

        avg_loss = np.mean(epoch_losses)
        logger.info(f"Epoch {epoch + 1} average loss: {avg_loss:.4f}")

    # Generation example
    model.eval()
    with torch.no_grad():
        # Prepare input prompt
        prompt = torch.randint(0, config.vocab_size, (1, 10)).to(device)

        # Generate sequence
        generated = model.generate(
            input_ids=prompt,
            max_length=50,
            temperature=0.7,
            top_k=50,
            top_p=0.9,
        )

        logger.info(f"Generated sequence shape: {generated.shape}")


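# --- Illustrative sketch (added for clarity; independent of main() above) ---
# Minimal end-to-end check of the model on a tiny, arbitrary config: a batch of token
# ids goes in, logits of shape (batch, seq_len, vocab_size) and one load-balancing
# loss per expert layer come out.
def _model_smoke_test() -> None:
    tiny = TransformerConfig(
        vocab_size=100,
        hidden_size=32,
        num_attention_heads=4,
        num_expert_layers=2,
        num_experts=2,
        max_position_embeddings=16,
        num_query_groups=2,
    )
    model = MoETransformer(tiny)
    ids = torch.randint(0, tiny.vocab_size, (2, 8))
    logits, aux = model(ids)
    assert logits.shape == (2, 8, tiny.vocab_size)
    assert len(aux["moe_losses"]) == tiny.num_expert_layers

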
if __name__ == "__main__":
    main()