You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
293 lines
8.9 KiB
293 lines
8.9 KiB
1 month ago
|
import torch
|
||
|
import torch.nn as nn
|
||
|
import torch.distributed as dist
|
||
|
from dataclasses import dataclass
|
||
|
from typing import Optional, Tuple, Union
|
||
|
from loguru import logger
|
||
|
import math
|
||
|
|
||
|
|
||
|
@dataclass
class StarAttentionConfig:
    """Hyper-parameters shared by every host running a StarAttention layer.

    Attributes:
        hidden_size: Width of the hidden-state vectors the layer consumes.
        num_attention_heads: Number of heads the hidden size is split into.
        num_hosts: Number of hosts in the distributed system.
        block_size: Number of context positions assigned to each host block.
        anchor_size: Number of leading context positions forming the anchor block.
        dropout_prob: Dropout probability (default: 0.1).
        layer_norm_eps: Epsilon of the layer normalization (default: 1e-12).
    """

    # --- model geometry ---
    hidden_size: int
    num_attention_heads: int

    # --- distribution / context blocking layout ---
    num_hosts: int
    block_size: int
    anchor_size: int

    # --- regularization and numerics (optional) ---
    dropout_prob: float = 0.1
    layer_norm_eps: float = 1e-12
|
||
|
|
||
|
|
||
|
class StarAttention(nn.Module):
    """
    Implementation of Star Attention mechanism for distributed inference.

    The module implements a two-phase attention mechanism:

    1. Local Context Encoding with Anchor Blocks: each host projects its own
       ``block_size`` slice of the context into keys/values and caches them.
    2. Query Encoding and Output Generation with Global Attention: every host
       attends the query to its cached block and produces a block-normalized
       context plus the log-sum-exp (lse) of its attention scores. Non-query
       hosts send both to the query host (rank ``num_hosts - 1``), which
       re-weights the partial contexts by ``exp(lse_i - logsumexp_j lse_j)``.
       Without dropout this equals one softmax over all hosts' blocks.

    All inputs are hidden states of width ``config.hidden_size``. There is no
    embedding layer, so token IDs must be embedded before reaching this module.
    """

    def __init__(self, config: StarAttentionConfig):
        super().__init__()

        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"Hidden size {config.hidden_size} not divisible by number of attention "
                f"heads {config.num_attention_heads}"
            )

        self.config = config
        self.head_dim = config.hidden_size // config.num_attention_heads

        # Q/K/V projections (this module has no output projection).
        self.query = nn.Linear(config.hidden_size, config.hidden_size)
        self.key = nn.Linear(config.hidden_size, config.hidden_size)
        self.value = nn.Linear(config.hidden_size, config.hidden_size)

        self.dropout = nn.Dropout(config.dropout_prob)
        self.layer_norm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps
        )

        # KV cache filled by phase 1, keyed by host_id. Each entry is a dict
        # with "key", "value" and "anchor_key" tensors.
        self.kv_cache = {}

        logger.info(f"Initialized StarAttention with config: {config}")

    def _check_hidden_states(self, tensor: torch.Tensor, name: str) -> None:
        """Raise ValueError unless ``tensor`` is (batch, seq_len, hidden_size)."""
        if tensor.dim() != 3 or tensor.size(-1) != self.config.hidden_size:
            raise ValueError(
                f"{name} must have shape (batch_size, seq_len, "
                f"{self.config.hidden_size}), got {tuple(tensor.shape)}; "
                "embed token IDs before calling StarAttention"
            )

    def _split_heads(
        self, tensor: torch.Tensor, num_heads: int
    ) -> torch.Tensor:
        """Split the last dimension into (num_heads, head_dim)."""
        batch_size, seq_len, _ = tensor.size()
        tensor = tensor.view(batch_size, seq_len, num_heads, self.head_dim)
        # Transpose to (batch_size, num_heads, seq_len, head_dim)
        return tensor.transpose(1, 2)

    def _merge_heads(self, tensor: torch.Tensor) -> torch.Tensor:
        """Merge the head dimension back into hidden_size."""
        batch_size, _, seq_len, _ = tensor.size()
        tensor = tensor.transpose(1, 2)
        return tensor.reshape(batch_size, seq_len, self.config.hidden_size)

    def _scaled_scores(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return ``query @ key^T / sqrt(head_dim)``, with ``mask == 0`` set to -inf."""
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(
            self.head_dim
        )
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))
        return scores

    def _compute_attention_scores(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute attention probabilities and the weighted sum of values.

        Returns:
            ``(context, attention_probs)``. As with ``torch.softmax``, a row
            whose keys are all masked yields NaN.
        """
        scores = self._scaled_scores(query, key, mask)
        # Plain (non-online) softmax over the key axis.
        attention_probs = torch.nn.functional.softmax(scores, dim=-1)
        attention_probs = self.dropout(attention_probs)
        context = torch.matmul(attention_probs, value)
        return context, attention_probs

    def _compute_attention_with_lse(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Attend ``query`` to one block; return ``(context, lse)``.

        ``context`` is softmax-normalized over this block only. ``lse`` is the
        logsumexp of the scores over the key axis, kept as a trailing
        singleton dim. These two tensors are all the query host needs to
        merge the block into a global softmax. Both are shaped by the query,
        not by the number of keys. A block with no visible keys (empty or
        fully masked) gives a zero context and ``lse == -inf``, i.e. zero
        weight when merged.
        """
        scores = self._scaled_scores(query, key, mask)
        lse = torch.logsumexp(scores, dim=-1, keepdim=True)
        # Replace -inf by 0 so exp(scores - lse) becomes 0 instead of NaN.
        safe_lse = torch.where(torch.isfinite(lse), lse, torch.zeros_like(lse))
        attention_probs = self.dropout(torch.exp(scores - safe_lse))
        context = torch.matmul(attention_probs, value)
        return context, lse

    @staticmethod
    def _merge_partial_attention(contexts: list, lses: list) -> torch.Tensor:
        """Merge per-block attention results into one global-softmax context.

        ``contexts[i]`` is normalized over block ``i`` and ``lses[i]`` is that
        block's logsumexp. Re-weighting each block by
        ``exp(lse_i - logsumexp_j lse_j)`` reproduces a single softmax over
        the union of all blocks' keys.
        """
        lse_stack = torch.stack(lses, dim=0)
        context_stack = torch.stack(contexts, dim=0)
        global_lse = torch.logsumexp(lse_stack, dim=0, keepdim=True)
        # If no block saw any key, keep the result at zero instead of NaN.
        global_lse = torch.where(
            torch.isfinite(global_lse), global_lse, torch.zeros_like(global_lse)
        )
        weights = torch.exp(lse_stack - global_lse)
        return (weights * context_stack).sum(dim=0)

    def phase1_local_context_encoding(
        self,
        input_ids: torch.Tensor,
        host_id: int,
        device: Union[str, torch.device] = "cuda",
    ) -> None:
        """
        Phase 1: Local Context Encoding with Anchor Blocks

        Args:
            input_ids: Context hidden states of shape
                (batch_size, seq_len, hidden_size). Despite the name, raw
                token IDs are not accepted because there is no embedding
                layer.
            host_id: ID of the current host. It owns context positions
                [host_id * block_size, (host_id + 1) * block_size).
            device: Device to run computations on.

        Raises:
            ValueError: if ``input_ids`` is not (batch, seq_len, hidden_size).
        """
        logger.debug(f"Starting Phase 1 on host {host_id}")
        self._check_hidden_states(input_ids, "input_ids")

        # Calculate block assignments
        block_start = host_id * self.config.block_size
        block_end = block_start + self.config.block_size

        # Local block; may be empty if the context is shorter than this
        # host's range. Phase 2 then gives it zero weight.
        local_block = input_ids[:, block_start:block_end].to(device)

        # Compute KV pairs for local block
        local_hidden = self.layer_norm(local_block)
        local_key = self._split_heads(
            self.key(local_hidden), self.config.num_attention_heads
        )
        local_value = self._split_heads(
            self.value(local_hidden), self.config.num_attention_heads
        )

        # Anchor block = the first anchor_size positions. Its keys are cached
        # for hosts other than 0 (None for host 0). Phase 2 attends only to
        # the local key/value pairs.
        anchor_key = None
        if host_id != 0:
            anchor_block = input_ids[:, : self.config.anchor_size].to(device)
            anchor_key = self._split_heads(
                self.key(self.layer_norm(anchor_block)),
                self.config.num_attention_heads,
            )

        # Store in KV cache
        self.kv_cache[host_id] = {
            "key": local_key,
            "value": local_value,
            "anchor_key": anchor_key,
        }

        logger.debug(
            f"Phase 1 complete on host {host_id}. KV cache shapes - "
            f"key: {local_key.shape}, value: {local_value.shape}"
        )

    def phase2_query_encoding(
        self,
        query_input: torch.Tensor,
        host_id: int,
        is_query_host: bool,
        device: Union[str, torch.device] = "cuda",
    ) -> Optional[torch.Tensor]:
        """
        Phase 2: Query Encoding and Output Generation

        Non-query hosts send their block-local context and lse to rank
        ``num_hosts - 1``. The query host receives from ranks
        ``0 .. num_hosts - 2`` and merges all blocks with a global softmax.
        A process group must be initialized when ``num_hosts > 1``.

        Args:
            query_input: Query tensor of shape (batch_size, seq_len, hidden_size)
            host_id: ID of the current host
            is_query_host: Whether this host is the query host (expected to
                be rank ``num_hosts - 1``)
            device: Kept for API symmetry with phase 1. Computation runs on
                the device holding this host's KV cache.

        Returns:
            Output tensor if this is the query host, None otherwise

        Raises:
            ValueError: if ``query_input`` is not (batch, seq_len, hidden_size).
            KeyError: if phase 1 has not been run for ``host_id``.
        """
        logger.debug(f"Starting Phase 2 on host {host_id}")
        self._check_hidden_states(query_input, "query_input")

        try:
            cache = self.kv_cache[host_id]
        except KeyError:
            raise KeyError(
                f"No KV cache for host {host_id}; "
                "run phase1_local_context_encoding first"
            ) from None

        # Put the query next to the cached keys/values.
        query_input = query_input.to(cache["key"].device)

        # Transform query
        query_hidden = self.layer_norm(query_input)
        query = self._split_heads(
            self.query(query_hidden), self.config.num_attention_heads
        )

        # Block-local attention plus the statistic needed for global merging
        local_context, local_lse = self._compute_attention_with_lse(
            query, cache["key"], cache["value"]
        )

        query_rank = self.config.num_hosts - 1
        if not is_query_host:
            # Both tensors depend only on the query shape, so the query host
            # can preallocate matching buffers whatever this block's length.
            dist.send(local_context.contiguous(), dst=query_rank)
            dist.send(local_lse.contiguous(), dst=query_rank)
            return None

        # Query host gathers partial results from every other rank
        contexts = [local_context]
        lses = [local_lse]
        for src_rank in range(query_rank):
            remote_context = torch.empty_like(local_context)
            remote_lse = torch.empty_like(local_lse)
            dist.recv(remote_context, src=src_rank)
            dist.recv(remote_lse, src=src_rank)
            contexts.append(remote_context)
            lses.append(remote_lse)

        # Global softmax over all hosts' blocks
        context = self._merge_partial_attention(contexts, lses)

        # Final output computation
        output = self._merge_heads(context)
        output = self.dropout(output)

        logger.debug(
            f"Phase 2 complete on host {host_id}. Output shape: {output.shape}"
        )

        return output

    def forward(
        self,
        input_ids: torch.Tensor,
        query_input: torch.Tensor,
        host_id: int,
        is_query_host: bool,
        device: Union[str, torch.device] = "cuda",
    ) -> Optional[torch.Tensor]:
        """
        Forward pass of the StarAttention module.

        Args:
            input_ids: Context hidden states of shape
                (batch_size, seq_len, hidden_size); token IDs must be embedded first
            query_input: Query tensor of shape (batch_size, seq_len, hidden_size)
            host_id: ID of the current host
            is_query_host: Whether this host is the query host
            device: Device to run computations on

        Returns:
            Output tensor if this is the query host, None otherwise
        """
        # Phase 1: Local Context Encoding
        self.phase1_local_context_encoding(input_ids, host_id, device)

        # Phase 2: Query Encoding and Output Generation
        return self.phase2_query_encoding(
            query_input, host_id, is_query_host, device
        )
|
||
|
|
||
|
|
||
|
# Example forward pass.
# Runs only as a script. In the single-process demo (num_hosts=1) this host is
# also the query host, so no torch.distributed process group is needed.
# Multi-host runs must start one process per host and call
# dist.init_process_group(...) first.
if __name__ == "__main__":
    config = StarAttentionConfig(
        hidden_size=768,
        num_attention_heads=12,
        num_hosts=1,
        block_size=512,
        anchor_size=128,
    )

    # Initialize model; eval() disables dropout for inference
    model = StarAttention(config)
    model.eval()

    # Example input tensors. StarAttention consumes hidden states, so the
    # context passed as `input_ids` is already embedded to hidden_size.
    batch_size = 4
    seq_len = 512
    context_hidden = torch.randn(batch_size, seq_len, config.hidden_size)
    query_input = torch.randn(batch_size, seq_len, config.hidden_size)

    # Example forward pass for the (only) query host, host_id = 0
    with torch.no_grad():
        output = model(
            input_ids=context_hidden,
            query_input=query_input,
            host_id=0,
            is_query_host=True,
            device="cpu",
        )

    print(output)
|