"""OpenAI OSS sdpa and moe implementations that are suitable for both training and inference.""" from typing import Final import torch import torch.nn.functional as F from einops import einsum, rearrange, repeat from torch import Tensor, nn __all__ = ["sdpa", "MOEBlock"] def sdpa( query: Tensor, key: Tensor, value: Tensor, sink_logits: Tensor | None, *, sliding_window: int = 0, attn_dropout_p: float = 0.0, training: bool | None = None, ) -> Tensor: """Scaled dot-product attention with grouped queries, causal masking, optional sliding window, and an optional per-head *sink* logit. Parameters ---------- query : torch.Tensor Query tensor of shape ``[batch, query_seq_len, heads_kv, group_size, hidden_dim]`` or ``[query_seq_len, heads_kv, group_size, hidden_dim]``. ``group_size`` is the group size used in GQA, i.e. ``heads_queries = heads_kv * group_size``. key : torch.Tensor Key tensor of shape ``[batch, query_seq_len, heads_kv, hidden_dim]`` or ``[query_seq_len, heads_kv, hidden_dim]``. value : torch.Tensor Value tensor of shape ``[batch, query_seq_len, heads_kv, hidden_dim]`` or ``[query_seq_len, heads_kv, hidden_dim]``. sink_logits : torch.Tensor or None Optional per-attention-head sink attention_scores of shape ``[heads_queries] == [heads_kv * group_size]``. When provided, a sink column is appended to the attention attention_scores; it draws probability mass but is discarded before applying values. sliding_window : int, default=0 If ``> 0``, token ``t`` may only attend to keys in ``[t - sliding_window, ..., t]`` (inclusive). attn_dropout_p : float, default=0.0 Dropout probability applied to attention probabilities *after* softmax. training : bool or None, default=None If training (or prefilling), set to True to apply mask Returns ------- torch.Tensor Attention output of shape ``[batch, query_seq_len, heads_queries * hidden_dim]`` (or ``[query_seq_len, heads_queries * hidden_dim]`` if the input did not have a batch dimension). Notes ----- - Einsum index legend used below: ``b``=batch, ``t``=query sequence, ``s``=key value sequence, ``h``=KV head, ``g``=group size, ``d``=head dim. 
""" # --- normalize shapes to 5D with an explicit batch dimension --- added_batch_dim: Final[bool] = query.ndim == 4 if added_batch_dim: query = query.unsqueeze(0) key = key.unsqueeze(0) value = value.unsqueeze(0) batch, query_seq_len, heads_kv, group_size, hidden_dim = query.shape device = query.device dtype = query.dtype heads_queries: Final[int] = heads_kv * group_size scale = torch.rsqrt(torch.sqrt(torch.as_tensor(hidden_dim, dtype=dtype, device=device))) # 1 / sqrt(hidden_dim) attention_scores = einsum(query * scale, key * scale, "b t h g d, b s h d -> b h g t s") if training: # or prefilling # causal mask -- upper triangular matrix with -inf mask_shape = attention_scores.shape[-2:] causal_mask = torch.triu( torch.full(mask_shape, fill_value=-float("inf"), dtype=dtype, device=device), diagonal=1, ) # sliding window mask -- lower triangular matrix with -inf if sliding_window and sliding_window > 0: sliding_mask = torch.tril( torch.full(mask_shape, fill_value=-float("inf"), dtype=dtype, device=device), diagonal=-sliding_window, ) mask = torch.minimum(causal_mask, sliding_mask) else: mask = causal_mask attention_scores = attention_scores + rearrange( mask, "t s -> 1 1 1 t s" ) # [batch,heads_kv,group_size,query_seq_len,S] # --- optional sink column (per head) --- if sink_logits is not None: if sink_logits.numel() != heads_queries: msg = f"sink_logits must have shape [heads_queries]={heads_queries}, got {tuple(sink_logits.shape)}" raise ValueError(msg) sink = rearrange( sink_logits.to(dtype), "(h g) -> 1 h g 1 1", h=heads_kv, g=group_size, ) sink = repeat( sink, "1 h g 1 1 -> b h g t 1", b=batch, t=query_seq_len ) # [batch, heads_kv, group_size, query_seq_len, 1] attention_scores = torch.cat([attention_scores, sink], dim=-1) # append sink column # softmax over keys+sink, then drop sink column attn_prob = torch.softmax(attention_scores, dim=-1)[..., :-1] else: attn_prob = torch.softmax(attention_scores, dim=-1) if attn_dropout_p and training: attn_prob = F.dropout(attn_prob, p=attn_dropout_p, training=True) attn_out = einsum(attn_prob, value, "b h g t s, b s h d -> b t h g d") out = rearrange(attn_out, "b t h g d -> b t (h g d)") if added_batch_dim: out = out.squeeze(0) return out def test_sdpa(): torch.manual_seed(0) b, t, h, g, d = 1, 8, 4, 2, 16 query, key, value = torch.randn((b, t, h, g, d)), torch.randn((b, t, h, d)), torch.randn((b, t, h, d)) sink = torch.ones((h * g)) sdpa(query, key, value, sink, sliding_window=3) def swiglu(x: Tensor, alpha: float = 1.702, clamp_limit: float = 7.0) -> Tensor: """OpenAI's unconventional SwiGLU activation function. Parameters ---------- x : torch.Tensor Input tensor of shape ``[B, T, D]`` or ``[T, D]`` Returns ------- torch.Tensor Output tensor of shape [B, T, D//2] or [T, D//2], where the last dimension is halved. 
""" x_glu, x_linear = x[..., ::2], x[..., 1::2] if clamp_limit: x_glu = x_glu.clamp(min=None, max=clamp_limit) x_linear = x_linear.clamp(min=-clamp_limit, max=clamp_limit) return (x_linear + 1) * x_glu * (alpha * x_glu).sigmoid() class RMSNorm(nn.Module): def __init__(self, hidden_dim: int, eps: float = 1e-05, device: torch.device | None = None): super().__init__() self.hidden_dim = hidden_dim self.eps = eps self.scale = nn.Parameter(torch.ones(hidden_dim, device=device, dtype=torch.float32)) def forward(self, x: torch.Tensor) -> torch.Tensor: assert x.shape[-1] == self.hidden_dim x, dtype = x.float(), x.dtype x = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps) # x / root_mean_squared(x) return (x * self.scale).to(dtype) class MOEBlock(nn.Module): """Mixture-of-Experts two-layer MLP with Top-experts_per_token routing (SwiGLU). The block applies a pre-layer-normalization, routes each token to its top-experts_per_token experts, executes expert-specific MLPs, sums expert outputs with softmax weights, and adds a residual connection. Parameters ---------- hidden_dim : int hidden dimension of the input states intermediate_size : int intermediate dimension of the two-layer MLPs. Can be set different from hidden_dim. If set to > hidden_dim, the MLP will up-project the input to a larger intermediate space If set to < hidden_dim, the MLP will down-project the input to a smaller intermediate space. num_experts : int Number of experts ``num_experts``. experts_per_token : int, default=4 Number of experts selected per token (Top-experts_per_token routing). Shapes ------ Input: ``x`` has shape ``[B, T, hidden_dim]`` or ``[T, hidden_dim]``. Parameters: ``w1``: ``[num_experts, 2*intermediate_size, hidden_dim]``, ``b1``: ``[num_experts, 2*intermediate_size]`` (up-projection) ``w2``: ``[num_experts, hidden_dim, intermediate_size]``, ``b2``: ``[num_experts, hidden_dim]`` (down-projection) Output: Same shape as input, with residual connection applied. 
class MOEBlock(nn.Module):
    """Mixture-of-Experts two-layer MLP with Top-``experts_per_token`` routing (SwiGLU).

    The block applies a pre-RMS-normalization, routes each token to its top
    ``experts_per_token`` experts, executes the expert-specific MLPs, sums the expert
    outputs with softmax routing weights, and adds a residual connection.

    Parameters
    ----------
    hidden_dim : int
        Hidden dimension of the input states.
    intermediate_size : int
        Intermediate dimension of the two-layer MLPs. May differ from ``hidden_dim``:
        if ``> hidden_dim`` the MLP up-projects the input to a larger intermediate space,
        if ``< hidden_dim`` it down-projects to a smaller one.
    num_experts : int
        Total number of experts.
    experts_per_token : int, default=4
        Number of experts selected per token (Top-``experts_per_token`` routing).

    Shapes
    ------
    Input: ``x`` has shape ``[B, T, hidden_dim]`` or ``[T, hidden_dim]``.
    Parameters:
        ``w1``: ``[num_experts, 2*intermediate_size, hidden_dim]``, ``b1``: ``[num_experts, 2*intermediate_size]`` (up-projection)
        ``w2``: ``[num_experts, hidden_dim, intermediate_size]``, ``b2``: ``[num_experts, hidden_dim]`` (down-projection)
    Output: same shape as the input, with the residual connection applied.

    Notes
    -----
    - Einsum index legend: ``b``=batch, ``t``=sequence, ``k``=selected experts (top-k),
      ``h``=hidden dim, ``i``=intermediate size, ``double_i``=2*intermediate size.
    """

    def __init__(
        self,
        hidden_dim: int = 2880,
        intermediate_size: int = 2880,
        num_experts: int = 32,
        experts_per_token: int = 4,
        *,
        device: torch.device | None = None,
    ) -> None:
        super().__init__()
        assert experts_per_token >= 1, "experts_per_token must be >= 1"
        assert num_experts >= experts_per_token, "num_experts must be >= experts_per_token"
        self.hidden_dim: Final[int] = hidden_dim
        self.intermediate_size: Final[int] = intermediate_size
        self.num_experts: Final[int] = num_experts
        self.experts_per_token: Final[int] = experts_per_token

        self.norm = RMSNorm(hidden_dim, device=device)
        self.router = nn.Linear(self.hidden_dim, self.num_experts, device=device)

        # Experts' weights and biases
        # (num_experts stacks of Linear(hidden_dim -> 2*intermediate_size) and Linear(intermediate_size -> hidden_dim))
        w1 = torch.empty(self.num_experts, 2 * self.intermediate_size, self.hidden_dim, device=device)
        b1 = torch.empty(self.num_experts, 2 * self.intermediate_size, device=device)
        w2 = torch.empty(self.num_experts, self.hidden_dim, self.intermediate_size, device=device)
        b2 = torch.empty(self.num_experts, self.hidden_dim, device=device)

        # Xavier init for weights, zeros for biases
        nn.init.xavier_uniform_(w1)
        nn.init.xavier_uniform_(w2)
        nn.init.zeros_(b1)
        nn.init.zeros_(b2)

        self.w1 = nn.Parameter(w1)
        self.b1 = nn.Parameter(b1)
        self.w2 = nn.Parameter(w2)
        self.b2 = nn.Parameter(b2)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the Top-``experts_per_token`` routed expert MLP with a residual connection.

        Parameters
        ----------
        x : torch.Tensor
            Input of shape ``[B, T, hidden_dim]`` or ``[T, hidden_dim]``.

        Returns
        -------
        torch.Tensor
            Output tensor with the same shape as ``x``.
        """
        # --- normalize input shape to [B, T, hidden_dim] ---
        added_batch_dim: Final[bool] = x.ndim == 2
        if added_batch_dim:
            x = x.unsqueeze(0)
        B, T, hidden_dim = x.shape
        assert hidden_dim == self.hidden_dim, f"Expected last dim {self.hidden_dim}, got {hidden_dim}"

        # Pre-RMSNorm
        x_norm = self.norm(x)  # [B, T, hidden_dim]

        # Router: pick the top experts_per_token experts per token and their softmax weights
        router_logits = self.router(x_norm)  # project to [B, T, num_experts]
        expert_scores, expert_indices = torch.topk(
            router_logits, k=self.experts_per_token, dim=-1, sorted=True
        )  # [B, T, experts_per_token]
        expert_weights = torch.softmax(expert_scores, dim=-1)  # [B, T, experts_per_token]

        # Gather per-token expert parameters -> [B, T, experts_per_token, ...]
        experts_w1 = self.w1[expert_indices, ...]  # [B, T, experts_per_token, 2*intermediate_size, hidden_dim]
        experts_b1 = self.b1[expert_indices, ...]  # [B, T, experts_per_token, 2*intermediate_size]
        experts_w2 = self.w2[expert_indices, ...]  # [B, T, experts_per_token, hidden_dim, intermediate_size]
        experts_b2 = self.b2[expert_indices, ...]  # [B, T, experts_per_token, hidden_dim]

        # Expert FFN-1 (up-proj) + SwiGLU
        up_projected = einsum(x_norm, experts_w1, "b t h, b t k double_i h -> b t k double_i") + experts_b1
        intermediate = swiglu(up_projected)  # [B, T, experts_per_token, intermediate_size]

        # Expert FFN-2 (down-proj)
        down_projected = einsum(intermediate, experts_w2, "b t k i, b t k h i -> b t k h") + experts_b2

        # Weighted mixture over the selected experts_per_token experts
        y = einsum(down_projected, expert_weights, "b t k h, b t k -> b t h")

        # Residual connection and shape restore
        out = x + y
        if added_batch_dim:
            out = out.squeeze(0)
        return out


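# Hedged sanity check (an addition, illustrative sketch rather than part of the original
# module): with a single expert and top-1 routing, the routed block should reduce to a
# plain RMSNorm -> SwiGLU MLP with a residual connection, because the softmax over one
# routing score is exactly 1. The small dimensions below are arbitrary.
def test_moe_block_single_expert_matches_dense_mlp():
    torch.manual_seed(0)
    moe = MOEBlock(hidden_dim=32, intermediate_size=48, num_experts=1, experts_per_token=1)
    x = torch.randn((2, 5, 32))
    with torch.inference_mode():
        out = moe(x)
        # Dense reference using the single expert's parameters directly.
        x_norm = moe.norm(x)
        up = einsum(x_norm, moe.w1[0], "b t h, double_i h -> b t double_i") + moe.b1[0]
        down = einsum(swiglu(up), moe.w2[0], "b t i, h i -> b t h") + moe.b2[0]
        ref = x + down
    torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-5)

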
def test_moe_block():
    moe = MOEBlock()
    torch.manual_seed(0)
    x = torch.randn((2, 10, moe.hidden_dim))
    with torch.inference_mode():
        output = moe(x)
    assert output.shape == (
        2,
        10,
        moe.hidden_dim,
    ), f"Expected output shape {(2, 10, moe.hidden_dim)}, got {output.shape}"


if __name__ == "__main__":
    test_sdpa()
    test_moe_block()
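
    # --- Hedged usage sketch (an addition, illustrative only) ---
    # Chain the two components as in a simplified transformer block: grouped-query
    # attention with sink logits and a sliding window, followed by the Top-k routed
    # MoE MLP. The dimensions below are small made-up values, not model settings.
    torch.manual_seed(0)
    b, t, h, g, d = 2, 12, 4, 2, 16
    model_dim = h * g * d  # the attention output dim must match the MoE hidden_dim
    query = torch.randn((b, t, h, g, d))
    key = torch.randn((b, t, h, d))
    value = torch.randn((b, t, h, d))
    sink = torch.zeros(h * g)
    attn_out = sdpa(query, key, value, sink, sliding_window=4, training=True)
    moe = MOEBlock(hidden_dim=model_dim, intermediate_size=model_dim, num_experts=8, experts_per_token=2)
    with torch.inference_mode():
        block_out = moe(attn_out)
    print(f"attention out: {tuple(attn_out.shape)}, MoE block out: {tuple(block_out.shape)}")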