# swarms/test.py
import torch
import torch.nn as nn
import torch.distributed as dist
from dataclasses import dataclass
from typing import Optional, Tuple, Union
from loguru import logger
import math
@dataclass
class StarAttentionConfig:
"""Configuration for StarAttention module.
Attributes:
hidden_size: Dimension of the model's hidden states
num_attention_heads: Number of attention heads
num_hosts: Number of hosts in the distributed system
block_size: Size of each context block
anchor_size: Size of the anchor block
dropout_prob: Dropout probability (default: 0.1)
layer_norm_eps: Layer normalization epsilon (default: 1e-12)
"""
hidden_size: int
num_attention_heads: int
num_hosts: int
block_size: int
anchor_size: int
dropout_prob: float = 0.1
layer_norm_eps: float = 1e-12
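# Implicit constraints of this sketch: hidden_size must be divisible by
# num_attention_heads (enforced in StarAttention.__init__), the anchor block
# is taken from the first anchor_size positions (so it is typically no larger
# than block_size), and the input sequence length must be at least
# num_hosts * block_size so every host receives a full block in phase 1.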
class StarAttention(nn.Module):
"""
Implementation of Star Attention mechanism for distributed inference.
The module implements a two-phase attention mechanism:
1. Local Context Encoding with Anchor Blocks
2. Query Encoding and Output Generation with Global Attention
"""
def __init__(self, config: StarAttentionConfig):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
f"Hidden size {config.hidden_size} not divisible by number of attention "
f"heads {config.num_attention_heads}"
)
self.config = config
self.head_dim = (
config.hidden_size // config.num_attention_heads
)
# Initialize components
self.query = nn.Linear(config.hidden_size, config.hidden_size)
self.key = nn.Linear(config.hidden_size, config.hidden_size)
self.value = nn.Linear(config.hidden_size, config.hidden_size)
self.dropout = nn.Dropout(config.dropout_prob)
self.layer_norm = nn.LayerNorm(
config.hidden_size, eps=config.layer_norm_eps
)
# KV cache for storing computed key/value pairs
self.kv_cache = {}
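        # The cache is keyed by host_id; phase 1 fills it with the projected
        # "key", "value", and (for non-zero hosts) "anchor_key" tensors for
        # that host's local block.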
logger.info(
f"Initialized StarAttention with config: {config}"
)
def _split_heads(
self, tensor: torch.Tensor, num_heads: int
) -> torch.Tensor:
"""Split the last dimension into (num_heads, head_dim)."""
batch_size, seq_len, _ = tensor.size()
tensor = tensor.view(
batch_size, seq_len, num_heads, self.head_dim
)
# Transpose to (batch_size, num_heads, seq_len, head_dim)
return tensor.transpose(1, 2)
def _merge_heads(self, tensor: torch.Tensor) -> torch.Tensor:
"""Merge the head dimension back into hidden_size."""
batch_size, _, seq_len, _ = tensor.size()
tensor = tensor.transpose(1, 2)
return tensor.reshape(
batch_size, seq_len, self.config.hidden_size
)
def _compute_attention_scores(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute attention scores and weighted values."""
        # Scaled dot-product attention
scores = torch.matmul(
query, key.transpose(-2, -1)
) / math.sqrt(self.head_dim)
if mask is not None:
scores = scores.masked_fill(mask == 0, float("-inf"))
        # Softmax over the key dimension (a standard, not online, softmax)
attention_probs = torch.nn.functional.softmax(scores, dim=-1)
attention_probs = self.dropout(attention_probs)
context = torch.matmul(attention_probs, value)
return context, attention_probs
def phase1_local_context_encoding(
self,
input_ids: torch.Tensor,
host_id: int,
device: Union[str, torch.device] = "cuda",
) -> None:
"""
Phase 1: Local Context Encoding with Anchor Blocks
Args:
            input_ids: Pre-embedded input tensor of shape (batch_size, seq_len, hidden_size); token embedding is assumed to happen before this module
host_id: ID of the current host
device: Device to run computations on
"""
logger.debug(f"Starting Phase 1 on host {host_id}")
# Calculate block assignments
block_start = host_id * self.config.block_size
block_end = block_start + self.config.block_size
# Get local block
local_block = input_ids[:, block_start:block_end].to(device)
# Get anchor block (first block)
anchor_block = input_ids[:, : self.config.anchor_size].to(
device
)
# Compute KV pairs for local block
local_hidden = self.layer_norm(local_block)
local_key = self._split_heads(
self.key(local_hidden), self.config.num_attention_heads
)
local_value = self._split_heads(
self.value(local_hidden), self.config.num_attention_heads
)
# Store in KV cache
self.kv_cache[host_id] = {
"key": local_key,
"value": local_value,
"anchor_key": (
None
if host_id == 0
else self._split_heads(
self.key(self.layer_norm(anchor_block)),
self.config.num_attention_heads,
)
),
}
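        # NOTE: only anchor keys are cached here; a complete implementation
        # would also cache anchor values and prefix each local block with the
        # anchor block when computing its KV pairs, which this sketch omits.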
logger.debug(
f"Phase 1 complete on host {host_id}. KV cache shapes - "
f"key: {local_key.shape}, value: {local_value.shape}"
)
def phase2_query_encoding(
self,
query_input: torch.Tensor,
host_id: int,
is_query_host: bool,
device: Union[str, torch.device] = "cuda",
) -> Optional[torch.Tensor]:
"""
Phase 2: Query Encoding and Output Generation
Args:
query_input: Query tensor of shape (batch_size, seq_len, hidden_size)
host_id: ID of the current host
is_query_host: Whether this host is the query host
device: Device to run computations on
Returns:
Output tensor if this is the query host, None otherwise
"""
logger.debug(f"Starting Phase 2 on host {host_id}")
# Transform query
query_hidden = self.layer_norm(query_input)
query = self._split_heads(
self.query(query_hidden), self.config.num_attention_heads
)
        # Compute local attention statistics over this host's cached block
        _, local_probs = self._compute_attention_scores(
            query,
            self.kv_cache[host_id]["key"],
            self.kv_cache[host_id]["value"],
        )
        if not is_query_host:
            # Non-query hosts send their local attention statistics to the
            # query host (assumed to run on the last rank, num_hosts - 1).
            # This requires an initialized torch.distributed process group.
            if dist.is_available() and dist.is_initialized():
                dist.send(local_probs, dst=self.config.num_hosts - 1)
            return None
        # Query host gathers attention statistics from the other hosts.
        # In a single-process run (no process group) only the local
        # statistics are available.
        all_attention_probs = [local_probs]
        if dist.is_available() and dist.is_initialized():
            for src_rank in range(self.config.num_hosts - 1):
                probs = torch.empty_like(local_probs)
                dist.recv(probs, src=src_rank)
                all_attention_probs.append(probs)
        # Aggregate the per-host attention distributions. A simple mean is
        # used here; a faithful implementation would combine per-host
        # softmax statistics with their normalizers (log-sum-exp).
        global_probs = torch.mean(torch.stack(all_attention_probs), dim=0)
        # Final output: apply the aggregated attention to the locally
        # cached values, then merge heads back into hidden_size.
        context = torch.matmul(global_probs, self.kv_cache[host_id]["value"])
        output = self._merge_heads(context)
        output = self.dropout(output)
        logger.debug(
            f"Phase 2 complete on host {host_id}. Output shape: {output.shape}"
        )
        return output
def forward(
self,
input_ids: torch.Tensor,
query_input: torch.Tensor,
host_id: int,
is_query_host: bool,
device: Union[str, torch.device] = "cuda",
) -> Optional[torch.Tensor]:
"""
Forward pass of the StarAttention module.
Args:
            input_ids: Pre-embedded input tensor of shape (batch_size, seq_len, hidden_size); token embedding is assumed to happen before this module
query_input: Query tensor of shape (batch_size, seq_len, hidden_size)
host_id: ID of the current host
is_query_host: Whether this host is the query host
device: Device to run computations on
Returns:
Output tensor if this is the query host, None otherwise
"""
# Phase 1: Local Context Encoding
self.phase1_local_context_encoding(input_ids, host_id, device)
# Phase 2: Query Encoding and Output Generation
return self.phase2_query_encoding(
query_input, host_id, is_query_host, device
)
# Example usage
config = StarAttentionConfig(
hidden_size=768,
num_attention_heads=12,
num_hosts=3,
block_size=512,
anchor_size=128,
)
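# With this configuration each host encodes one 512-token block, so phase 1
# covers num_hosts * block_size = 1536 context tokens in total.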
# Initialize model
model = StarAttention(config)
# Example input tensors. The module applies LayerNorm and linear projections
# directly to its inputs, so input_ids here is a float tensor of pre-embedded
# hidden states of shape (batch_size, seq_len, hidden_size), and seq_len is
# large enough to give every host a full block.
batch_size = 4
seq_len = config.num_hosts * config.block_size  # 1536
input_ids = torch.randn(
    batch_size, seq_len, config.hidden_size
)  # Pre-embedded inputs
query_input = torch.randn(
    batch_size, seq_len, config.hidden_size
)  # Random query input
# Example forward pass for the query host (host_id = 2). With no initialized
# torch.distributed process group, only this host's local statistics are used.
output = model(
input_ids=input_ids,
query_input=query_input,
host_id=2,
is_query_host=True,
device="cpu",
)
print(output)
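# A minimal sketch (not part of the original file) of how the same module
# could be driven across several processes with torch.distributed, so that
# the dist.send/dist.recv calls in phase 2 actually exchange attention
# statistics. The gloo backend, the spawn launcher, and the run_host helper
# below are illustrative assumptions.
import os
import torch.multiprocessing as mp


def run_host(rank: int, world_size: int, cfg: StarAttentionConfig) -> None:
    """Run one host: initialize the process group, run both phases, clean up."""
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    host_model = StarAttention(cfg)
    full_len = cfg.num_hosts * cfg.block_size
    inputs = torch.randn(1, full_len, cfg.hidden_size)   # pre-embedded context
    queries = torch.randn(1, full_len, cfg.hidden_size)  # query hidden states
    out = host_model(
        input_ids=inputs,
        query_input=queries,
        host_id=rank,
        is_query_host=(rank == world_size - 1),  # last rank acts as query host
        device="cpu",
    )
    if out is not None:
        print(f"rank {rank}: output shape {tuple(out.shape)}")
    dist.destroy_process_group()


if __name__ == "__main__":
    # Spawn one process per host; only the last rank returns an output.
    # Note: the unguarded single-process example above also re-executes in
    # each spawned child before run_host is called.
    mp.spawn(run_host, args=(config.num_hosts, config), nprocs=config.num_hosts)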