import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import get_num_layers_to_build
from megatron.core.transformer.transformer_layer import get_transformer_layer_offset
from transformers.activations import ACT2FN

try:
    from fla.modules import FusedRMSNormGated, ShortConvolution
    from fla.ops.gated_delta_rule import chunk_gated_delta_rule
except ImportError:
    pass

from .hf_attention import HuggingfaceAttention, _load_hf_config

def _get_text_config(hf_config):
    """Extract the text config from a VLM config if needed."""
    if hasattr(hf_config, "text_config"):
        return hf_config.text_config
    return hf_config

# Adapted from Qwen3NextGatedDeltaNet, but with separate in_proj_qkv and in_proj_z.
class Qwen3_5GatedDeltaNet(nn.Module):
    """
    Qwen3.5 GatedDeltaNet with varlen support.

    Unlike Qwen3Next, which uses a combined in_proj_qkvz, Qwen3.5 uses
    separate in_proj_qkv (for Q, K, V) and in_proj_z (for Z).
    """

    def __init__(self, config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_v_heads = config.linear_num_value_heads
        self.num_k_heads = config.linear_num_key_heads
        self.head_k_dim = config.linear_key_head_dim
        self.head_v_dim = config.linear_value_head_dim
        self.key_dim = self.head_k_dim * self.num_k_heads
        self.value_dim = self.head_v_dim * self.num_v_heads
        self.conv_kernel_size = config.linear_conv_kernel_dim
        self.layer_idx = layer_idx
        self.activation = config.hidden_act
        self.act = ACT2FN[config.hidden_act]
        self.layer_norm_epsilon = config.rms_norm_eps

        # Short convolution over the flat QKV projection.
        self.conv_dim = self.key_dim * 2 + self.value_dim
        self.conv1d = ShortConvolution(
            hidden_size=self.conv_dim,
            bias=False,
            kernel_size=self.conv_kernel_size,
        )

        # Separate projections for QKV and Z (unlike Qwen3Next, which combines QKVZ).
        projection_size_qkv = self.key_dim * 2 + self.value_dim
        projection_size_z = self.value_dim
        self.in_proj_qkv = nn.Linear(self.hidden_size, projection_size_qkv, bias=False)
        self.in_proj_z = nn.Linear(self.hidden_size, projection_size_z, bias=False)
        self.in_proj_b = nn.Linear(self.hidden_size, self.num_v_heads, bias=False)
        self.in_proj_a = nn.Linear(self.hidden_size, self.num_v_heads, bias=False)

        # Time-step projection (decay gate parameters).
        self.dt_bias = nn.Parameter(torch.ones(self.num_v_heads))
        A = torch.empty(self.num_v_heads).uniform_(0, 16)
        self.A_log = nn.Parameter(torch.log(A))

        self.norm = FusedRMSNormGated(
            self.head_v_dim,
            eps=self.layer_norm_epsilon,
            activation=self.activation,
            device=torch.cuda.current_device(),
            dtype=config.dtype if config.dtype is not None else torch.get_default_dtype(),
        )
        self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)
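    # Worked example of the projection sizes above (hypothetical numbers, not taken
    # from any released checkpoint): with linear_num_key_heads=16, linear_key_head_dim=128,
    # linear_num_value_heads=32, linear_value_head_dim=128:
    #   key_dim   = 16 * 128 = 2048
    #   value_dim = 32 * 128 = 4096
    #   conv_dim  = projection_size_qkv = 2 * 2048 + 4096 = 8192
    #   projection_size_z = 4096
    # so in_proj_qkv maps hidden_size -> 8192 and in_proj_z maps hidden_size -> 4096.
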
    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor = None,
    ):
        batch_size, seq_len, _ = hidden_states.shape

        # Projections (flat layout: [Q_all, K_all, V_all]).
        mixed_qkv = self.in_proj_qkv(hidden_states)
        z = self.in_proj_z(hidden_states)
        z = z.reshape(batch_size, seq_len, -1, self.head_v_dim)
        b = self.in_proj_b(hidden_states)
        a = self.in_proj_a(hidden_states)

        # Convolution on the flat QKV.
        mixed_qkv, _ = self.conv1d(
            x=mixed_qkv,
            cu_seqlens=cu_seqlens,
        )

        # Split into Q, K, V (flat split, matching the HF layout).
        query, key, value = torch.split(
            mixed_qkv,
            [self.key_dim, self.key_dim, self.value_dim],
            dim=-1,
        )
        query = query.reshape(batch_size, seq_len, -1, self.head_k_dim)
        key = key.reshape(batch_size, seq_len, -1, self.head_k_dim)
        value = value.reshape(batch_size, seq_len, -1, self.head_v_dim)

        beta = b.sigmoid()
        # If the model is loaded in fp16, A might become -inf without the .float() here.
        g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)

        if self.num_v_heads // self.num_k_heads > 1:
            query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
            key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)

        core_attn_out, last_recurrent_state = chunk_gated_delta_rule(
            query,
            key,
            value,
            g=g,
            beta=beta,
            initial_state=None,
            output_final_state=False,
            use_qk_l2norm_in_kernel=True,
            cu_seqlens=cu_seqlens,
        )

        # The gated RMSNorm expects 2D input: flatten, normalize, then restore the shape.
        z_shape_og = z.shape
        core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
        z = z.reshape(-1, z.shape[-1])
        core_attn_out = self.norm(core_attn_out, z)
        core_attn_out = core_attn_out.reshape(z_shape_og)
        core_attn_out = core_attn_out.reshape(batch_size, seq_len, -1)

        output = self.out_proj(core_attn_out)
        return output

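# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original adapter): it builds the
# gated delta net from a dummy config and runs one forward pass. The config
# field names mirror the attributes read in __init__ above; the values are
# illustrative, and the sketch assumes the `fla` package and a CUDA device
# are available. It is defined here but never called by the spec code below.
# ---------------------------------------------------------------------------
def _smoke_test_gated_delta_net():
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        hidden_size=256,
        linear_num_value_heads=8,
        linear_num_key_heads=4,
        linear_key_head_dim=32,
        linear_value_head_dim=32,
        linear_conv_kernel_dim=4,
        hidden_act="silu",
        rms_norm_eps=1e-6,
        dtype=torch.bfloat16,
    )
    layer = Qwen3_5GatedDeltaNet(cfg, layer_idx=0).cuda().to(torch.bfloat16)
    x = torch.randn(2, 64, cfg.hidden_size, device="cuda", dtype=torch.bfloat16)
    out = layer(x)  # cu_seqlens=None -> plain batched (non-packed) sequences
    assert out.shape == x.shape
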
class Attention(HuggingfaceAttention):
    def __init__(
        self,
        args,
        config,
        layer_number: int,
        cp_comm_type: str = "p2p",
        pg_collection=None,
    ):
        super().__init__(
            args,
            config,
            layer_number,
            cp_comm_type,
            pg_collection,
        )
        # Qwen3.5 is a VLM, so the text config is nested under text_config.
        self.hf_config = _get_text_config(self.hf_config)
        self.hf_config._attn_implementation = "flash_attention_2"
        self.linear_attn = Qwen3_5GatedDeltaNet(self.hf_config, self.hf_layer_idx)

        # Use a plain RMSNorm as the input layernorm.
        try:
            from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextRMSNorm

            self.input_layernorm = Qwen3NextRMSNorm(self.hf_config.hidden_size, eps=self.hf_config.rms_norm_eps)
        except ImportError:
            from torch.nn import RMSNorm

            self.input_layernorm = RMSNorm(self.hf_config.hidden_size, eps=self.hf_config.rms_norm_eps)

    def hf_forward(self, hidden_states, packed_seq_params):
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.linear_attn(
            hidden_states=hidden_states,
            cu_seqlens=packed_seq_params.cu_seqlens_q,
        )
        return hidden_states

def get_qwen3_5_spec(args, config, vp_stage):
    # Always build through the MoE-aware spec path; for dense models (no
    # experts), mark every layer as non-MoE.
    if not args.num_experts:
        config.moe_layer_freq = [0] * config.num_layers

    # Build the decoder block spec.
    kwargs = {
        "use_transformer_engine": True,
    }
    if vp_stage is not None:
        kwargs["vp_stage"] = vp_stage
    transformer_layer_spec = get_gpt_decoder_block_spec(config, **kwargs)

    assert config.pipeline_model_parallel_layout is None, "pipeline_model_parallel_layout is not supported at the moment"

    # Slice the layer specs to only include the layers built in this pipeline stage.
    num_layers_to_build = get_num_layers_to_build(config, vp_stage=vp_stage)
    offset = get_transformer_layer_offset(config, vp_stage=vp_stage)

    hf_config = _load_hf_config(args.hf_checkpoint)
    text_config = _get_text_config(hf_config)

    # Compute layer_types if the config class doesn't expose it.
    if not hasattr(text_config, "layer_types"):
        interval = getattr(text_config, "full_attention_interval", 4)
        n = text_config.num_hidden_layers
        text_config.layer_types = [
            "full_attention" if (i + 1) % interval == 0 else "linear_attention" for i in range(n)
        ]

    # Swap in the linear-attention module for the layers that use it.
    for layer_id in range(num_layers_to_build):
        if text_config.layer_types[layer_id + offset] == "linear_attention":
            layer_specs = copy.deepcopy(transformer_layer_spec.layer_specs[layer_id])
            layer_specs.submodules.self_attention = ModuleSpec(
                module=Attention,
                params={"args": args},
            )
            transformer_layer_spec.layer_specs[layer_id] = layer_specs

    return transformer_layer_spec
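
# Example of the fallback layer_types pattern built above (illustrative numbers):
# with full_attention_interval = 4 and num_hidden_layers = 8, layer_types becomes
#   ["linear_attention", "linear_attention", "linear_attention", "full_attention",
#    "linear_attention", "linear_attention", "linear_attention", "full_attention"]
# so every 4th layer keeps the default self-attention spec, and every other layer
# has its self_attention submodule swapped for the Qwen3_5GatedDeltaNet-based
# Attention defined above.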