# -*- coding: utf-8 -*- import torch import torch.nn as nn import triton import triton.language as tl from transformers.models.llama.modeling_llama import LlamaRMSNorm @triton.jit def rmsnorm_fwd_kernel_r( x, r, eps, stride_xb, stride_xt, stride_xd, stride_rb, T, D, BT: tl.constexpr, BD: tl.constexpr ): i_t, i_b = tl.program_id(0), tl.program_id(1) p_x = tl.make_block_ptr(x + i_b * stride_xb, (T, D), (stride_xt, stride_xd), (i_t * BT, 0), (BT, BD), (1, 0)) p_r = tl.make_block_ptr(r + i_b * stride_rb, (T,), (stride_xd,), (i_t * BT,), (BT,), (0,)) # [BT,] b_m = tl.zeros([BT,], dtype=tl.float32) for _ in range(0, D, BD): b_x = tl.load(p_x) b_m += tl.sum(tl.math.pow(b_x.to(tl.float32), 2), 1) p_x = tl.advance(p_x, (0, BD)) # [BT,] b_m = b_m / D b_r = tl.math.rsqrt(b_m + eps) tl.store(p_r, b_r) @triton.jit def rmsnorm_fwd_kernel_y( x, z, y, r, w, stride_xb, stride_xt, stride_xd, stride_rb, T, D, BT: tl.constexpr, BD: tl.constexpr ): i_d, i_t, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2) p_x = tl.make_block_ptr(x + i_b * stride_xb, (T, D), (stride_xt, stride_xd), (i_t * BT, i_d * BD), (BT, BD), (1, 0)) p_z = tl.make_block_ptr(z + i_b * stride_xb, (T, D), (stride_xt, stride_xd), (i_t * BT, i_d * BD), (BT, BD), (1, 0)) p_y = tl.make_block_ptr(y + i_b * stride_xb, (T, D), (stride_xt, stride_xd), (i_t * BT, i_d * BD), (BT, BD), (1, 0)) p_r = tl.make_block_ptr(r + i_b * stride_rb, (T,), (stride_xd,), (i_t * BT,), (BT,), (0,)) p_w = tl.make_block_ptr(w, (D,), (stride_xd,), (i_d * BD,), (BD,), (0,)) # [BT,] b_r = tl.load(p_r) # [BT, BD] b_x = tl.load(p_x) # [BD,] b_w = tl.load(p_w) b_z = (b_x.to(tl.float32) * b_r[:, None]).to(b_x.dtype) tl.store(p_z, b_z) tl.store(p_y, b_z * b_w) @triton.jit def rmsnorm_bwd_kernel_s( z, s, w, dy, stride_xb, stride_xt, stride_xd, stride_rb, T, D, BT: tl.constexpr, BD: tl.constexpr ): i_t, i_b = tl.program_id(0), tl.program_id(1) p_s = tl.make_block_ptr(s + i_b * stride_rb, (T,), (stride_xd,), (i_t * BT,), (BT,), (0,)) b_s = tl.zeros([BT,], dtype=tl.float32) for i in range(0, D, BD): p_z = tl.make_block_ptr(z + i_b * stride_xb, (T, D), (stride_xt, stride_xd), (i_t * BT, i), (BT, BD), (1, 0)) p_w = tl.make_block_ptr(w, (D,), (stride_xd,), (i,), (BD,), (0,)) p_dy = tl.make_block_ptr(dy + i_b * stride_xb, (T, D), (stride_xt, stride_xd), (i_t * BT, i), (BT, BD), (1, 0)) b_s += tl.sum(tl.load(p_z) * tl.load(p_dy) * tl.load(p_w)[None, :], 1) tl.store(p_s, (b_s / D).to(p_s.dtype.element_ty)) @triton.jit def rmsnorm_bwd_kernel( z, r, s, w, dy, dx, dw, stride_xb, stride_xt, stride_xd, stride_rb, stride_dw, T, D, BT: tl.constexpr, BD: tl.constexpr ): i_d, i_t, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2) p_z = tl.make_block_ptr(z + i_b * stride_xb, (T, D), (stride_xt, stride_xd), (i_t * BT, i_d * BD), (BT, BD), (1, 0)) p_r = tl.make_block_ptr(r + i_b * stride_rb, (T,), (stride_xd,), (i_t * BT,), (BT,), (0,)) p_s = tl.make_block_ptr(s + i_b * stride_rb, (T,), (stride_xd,), (i_t * BT,), (BT,), (0,)) p_w = tl.make_block_ptr(w, (D,), (stride_xd,), (i_d * BD,), (BD,), (0,)) p_dy = tl.make_block_ptr(dy + i_b * stride_xb, (T, D), (stride_xt, stride_xd), (i_t * BT, i_d * BD), (BT, BD), (1, 0)) p_dx = tl.make_block_ptr(dx + i_b * stride_xb, (T, D), (stride_xt, stride_xd), (i_t * BT, i_d * BD), (BT, BD), (1, 0)) p_dw = tl.make_block_ptr(dw + i_b * stride_dw, (tl.cdiv(T, BT) * D,), (stride_xd), (i_t * D + i_d * BD,), (BD,), (0,)) # [BT,] b_r = tl.load(p_r) b_s = tl.load(p_s) # [BT, BD] b_z = tl.load(p_z) # [BD,] b_w = tl.load(p_w) # [BT, BD] b_dy = tl.load(p_dy) # [BT, BD] b_dx = (b_dy * b_w[None, :] - b_s[:, None] * b_z) * b_r[:, None] # [BD,] b_dw = tl.sum(b_dy * b_z, 0) tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty)) tl.store(p_dw, b_dw.to(p_dw.dtype.element_ty)) class FlashRMSNormFunction(torch.autograd.Function): @staticmethod def forward(ctx, x, w, eps): if not x.is_contiguous(): raise ValueError("data must be contiguous") batch_size, seq_len, hidden_size = x.shape BT, BD = 128, 128 grid_r = (triton.cdiv(seq_len, BT), batch_size) grid_y = (triton.cdiv(hidden_size, BD), triton.cdiv(seq_len, BT), batch_size) r = x.new_empty(batch_size, seq_len, dtype=torch.float) rmsnorm_fwd_kernel_r[grid_r]( x, r, eps, x.stride(0), x.stride(1), x.stride(2), r.stride(0), seq_len, hidden_size, BT=BT, BD=BD, num_stages=3, num_warps=4 ) z, y = torch.empty_like(x), torch.empty_like(x) rmsnorm_fwd_kernel_y[grid_y]( x, z, y, r, w, x.stride(0), x.stride(1), x.stride(2), r.stride(0), seq_len, hidden_size, BT=BT, BD=BD, num_stages=3, num_warps=4 ) ctx.save_for_backward(z, r, w) return y @staticmethod def backward(ctx, dy): z, r, w = ctx.saved_tensors batch_size, seq_len, hidden_size = z.shape BT, BD = 128, 128 grid_s = (triton.cdiv(seq_len, BT), batch_size) grid_d = (triton.cdiv(hidden_size, BD), triton.cdiv(seq_len, BT), batch_size) s, dx, dw = torch.empty_like(r), torch.empty_like(z), z.new_empty(batch_size, grid_d[1], hidden_size) rmsnorm_bwd_kernel_s[grid_s]( z, s, w, dy, z.stride(0), z.stride(1), z.stride(2), r.stride(0), seq_len, hidden_size, BT=BT, BD=BD, num_stages=3, num_warps=4 ) rmsnorm_bwd_kernel[grid_d]( z, r, s, w, dy, dx, dw, z.stride(0), z.stride(1), z.stride(2), r.stride(0), dw.stride(0), seq_len, hidden_size, BT=BT, BD=BD, num_stages=3, num_warps=4 ) dw = dw.sum((0, 1)) return dx, dw, None class LlamaRMSNorm(LlamaRMSNorm): """ RMS Normalization layer along the last dimension. This is similar to torch.nn.functional.normalize but with eps being added instead of max. Expects contiguous input of shape (..., dim), and returns normalized data of the same shape. For each dim-length vector x, the result has x / sqrt( x*x.sum() + eps) If weights are included, they are a parameter of length dim which multiplies the result. This functionality is experimental. Its API might be changed without warnings. Use it at your own risk. """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def forward(self, x: torch.Tensor): return FlashRMSNormFunction.apply(x, self.weight, self.variance_epsilon) class NaiveRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ LlamaRMSNorm is equivalent to T5LayerNorm """ super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps def forward(self, hidden_states): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) if __name__ == "__main__": dtype = torch.float torch.random.manual_seed(0) naive_rmsnorm = NaiveRMSNorm(768).to('cuda').train() flash_rmsnorm = LlamaRMSNorm(768).to('cuda').train() w = torch.randn_like(naive_rmsnorm.weight) naive_rmsnorm.weight.data.copy_(w) flash_rmsnorm.weight.data.copy_(w) if dtype == torch.bfloat16: naive_rmsnorm = naive_rmsnorm.bfloat16() flash_rmsnorm = flash_rmsnorm.bfloat16() if dtype == torch.float: naive_rmsnorm = naive_rmsnorm.float() flash_rmsnorm = flash_rmsnorm.float() if dtype == torch.float16: naive_rmsnorm = naive_rmsnorm.half() flash_rmsnorm = flash_rmsnorm.half() x = torch.randn((8, 2048, 768), device='cuda', dtype=dtype, requires_grad=True) dy = torch.randn_like(x) ref = naive_rmsnorm(x) ref.backward(dy) ref_dw, naive_rmsnorm.weight.grad = naive_rmsnorm.weight.grad.clone(), None ref_dx, x.grad = x.grad.clone(), None tri = flash_rmsnorm(x) tri.backward(dy) tri_dw, flash_rmsnorm.weight.grad = flash_rmsnorm.weight.grad.clone(), None tri_dx, x.grad = x.grad.clone(), None assert ref.allclose(tri, 0, 1e-2), breakpoint() assert ref_dx.allclose(tri_dx, 0, 1e-2), breakpoint() assert ref_dw.allclose(tri_dw, 0, 1e-2), breakpoint() print('Done!') @triton.testing.perf_report( triton.testing.Benchmark( # argument names to use as an x-axis for the plot x_names=['seq_len'], # different possible values for `x_name` x_vals=[128 * 2 ** i for i in range(0, 10)], # argument name whose value corresponds to a different line in the plot line_arg='provider', # possible values for `line_arg`` line_vals=['naive', 'flash', 'naive_bwd', 'flash_bwd'], # label name for the lines line_names=['naive', 'flash', 'naive_bwd', 'flash_bwd'], # line styles styles=[('green', '-'), ('blue', '--'), ('red', '-.'), ('cyan', ':')], ylabel="Execution Time (ms)", # label name for the y-axis # name for the plot. Used also as a file name for saving the plot. plot_name="Performance", args={}, ) ) def benchmark(seq_len, provider): device = 'cuda' dtype = torch.bfloat16 requires_grad = True batch_size, hidden_size = 2, 3200 naive_rmsnorm = NaiveRMSNorm(hidden_size).to(device) flash_rmsnorm = LlamaRMSNorm(hidden_size).to(device) if dtype == torch.bfloat16: naive_rmsnorm = naive_rmsnorm.bfloat16() flash_rmsnorm = flash_rmsnorm.bfloat16() if dtype == torch.float: naive_rmsnorm = naive_rmsnorm.float() flash_rmsnorm = flash_rmsnorm.float() if dtype == torch.float16: naive_rmsnorm = naive_rmsnorm.half() flash_rmsnorm = flash_rmsnorm.half() x = torch.ones(batch_size, seq_len, hidden_size, requires_grad=requires_grad, dtype=dtype, device=device) dy = torch.randn_like(x) quantiles = [0.5, 0.2, 0.8] if provider == 'naive': results = triton.testing.do_bench(lambda: naive_rmsnorm(x), quantiles=quantiles) elif provider == 'flash': results = triton.testing.do_bench(lambda: flash_rmsnorm(x), quantiles=quantiles) elif provider == 'naive_bwd': results = triton.testing.do_bench(lambda: naive_rmsnorm(x).backward(dy), quantiles=quantiles) elif provider == 'flash_bwd': results = triton.testing.do_bench(lambda: flash_rmsnorm(x).backward(dy), quantiles=quantiles) return results benchmark.run(print_data=True)