# -*- coding: utf-8 -*-
"""Block-parallel cumulative sum along the time axis, three ways.

Compares a reference ``torch.cumsum`` against two Triton kernels that scan a
``(B, H, T, S)`` tensor over ``T`` in tiles of ``BT`` rows:

* ``cumsum_matmul_kernel`` -- within-tile prefix sums via a lower-triangular
  masked matmul (runs on tensor cores), plus a running carry across tiles.
* ``cumsum_triton_kernel`` -- within-tile prefix sums via ``tl.cumsum``,
  plus the same running carry.

Run as a script to print the max absolute error of each kernel against the
torch reference, then a latency benchmark over sequence lengths.
"""

import torch
import triton
import triton.language as tl


@triton.jit
def cumsum_matmul_kernel(
    s,
    z,
    s_s_h,
    s_s_t,
    s_s_d,
    T: tl.constexpr,
    S: tl.constexpr,
    BT: tl.constexpr,
    BS: tl.constexpr,
):
    """Inclusive cumsum over the T axis via triangular matmul per tile.

    One program instance handles one ``BS``-wide feature slab of one
    (batch, head) pair: axis 0 of the grid indexes the feature tile, axis 1
    the flattened batch*head index.  ``s_s_h``/``s_s_t``/``s_s_d`` are the
    strides of the head, time, and feature axes of ``s`` (``z`` is assumed
    to share the same layout).
    """
    i_s, i_bh = tl.program_id(0), tl.program_id(1)
    o_i = tl.arange(0, BT)
    # Lower-triangular ones matrix: (m_s @ b_s)[i] == sum_{j<=i} b_s[j],
    # i.e. an inclusive prefix sum computed as a dense matmul.
    m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)
    # Running per-column totals carried across time tiles (fp32 accumulator).
    b_z = tl.zeros([BS], dtype=tl.float32)
    for i_t in range(tl.cdiv(T, BT)):
        p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d),
                                (i_t * BT, i_s * BS), (BT, BS), (1, 0))
        p_z = tl.make_block_ptr(z + i_bh * s_s_h, (T, S), (s_s_t, s_s_d),
                                (i_t * BT, i_s * BS), (BT, BS), (1, 0))
        # [BT, BS]; out-of-bounds rows/cols load as zero, so they do not
        # perturb the carry update below.
        b_s = tl.load(p_s, boundary_check=(0, 1))
        # do cumsum by tensor cores; allow_tf32=False keeps full precision
        b_sc = b_z[None, :] + tl.dot(m_s.to(b_s.dtype), b_s, allow_tf32=False)
        tl.store(p_z, b_sc.to(p_z.dtype.element_ty), boundary_check=(0, 1))
        b_z = b_z + tl.sum(b_s, 0)


@triton.jit
def cumsum_triton_kernel(
    s,
    z,
    s_s_h,
    s_s_t,
    s_s_d,
    T: tl.constexpr,
    S: tl.constexpr,
    BT: tl.constexpr,
    BS: tl.constexpr,
):
    """Inclusive cumsum over the T axis via ``tl.cumsum`` per tile.

    Same grid layout and stride arguments as ``cumsum_matmul_kernel``.
    """
    i_s, i_bh = tl.program_id(0), tl.program_id(1)
    # Running per-column totals carried across time tiles (fp32 accumulator).
    b_z = tl.zeros([BS], dtype=tl.float32)
    for i_t in range(tl.cdiv(T, BT)):
        p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d),
                                (i_t * BT, i_s * BS), (BT, BS), (1, 0))
        p_z = tl.make_block_ptr(z + i_bh * s_s_h, (T, S), (s_s_t, s_s_d),
                                (i_t * BT, i_s * BS), (BT, BS), (1, 0))
        # [BT, BS]
        b_s = tl.load(p_s, boundary_check=(0, 1))
        b_sc = b_z[None, :] + tl.cumsum(b_s, 0)
        tl.store(p_z, b_sc.to(p_z.dtype.element_ty), boundary_check=(0, 1))
        b_z = b_z + tl.sum(b_s, 0)


def cumsum_torch(s):
    """Reference: fp32 cumulative sum over the time axis (dim 2)."""
    return s.float().cumsum(2).to(s.dtype)


def _cumsum_launch(kernel, s, num_warps):
    """Launch ``kernel`` over a ``(B, H, T, S)`` tensor and return the result.

    Shared by ``cumsum_triton`` and ``cumsum_matmul`` so the grid/stride
    setup is written once.

    NOTE(review): ``s.stride(1)`` is used as the combined batch*head stride,
    which is only correct for a contiguous tensor -- confirm before passing
    views or permuted tensors.
    """
    B, H, T, S = s.shape
    BT, BS = 64, 64
    NS = triton.cdiv(S, BS)
    grid = (NS, B * H)
    z = torch.empty_like(s)
    kernel[grid](
        s, z,
        s.stride(1), s.stride(2), s.stride(3),
        T=T, S=S, BT=BT, BS=BS,
        num_warps=num_warps, num_stages=1,
    )
    return z


def cumsum_triton(s):
    """Cumulative sum over dim 2 using the ``tl.cumsum`` kernel."""
    return _cumsum_launch(cumsum_triton_kernel, s, num_warps=2)


def cumsum_matmul(s):
    """Cumulative sum over dim 2 using the triangular-matmul kernel."""
    return _cumsum_launch(cumsum_matmul_kernel, s, num_warps=1)


# Problem size shared by the correctness check and the benchmark
# (``benchmark`` reads B, H, D at call time).
B, H, T, D = 8, 4, 2048, 256
dtype = torch.float
device = 'cuda'


@triton.testing.perf_report(
    triton.testing.Benchmark(
        # argument names to use as an x-axis for the plot
        x_names=['seq_len'],
        # different possible values for `x_name`
        x_vals=[128 * 2 ** i for i in range(0, 8)],
        # argument name whose value corresponds to a different line in the plot
        line_arg='provider',
        # possible values for `line_arg``
        line_vals=['torch', 'matmul', 'triton'],
        # label name for the lines
        line_names=['torch', 'matmul', 'triton'],
        # line styles
        styles=[('green', '-'), ('blue', '--'), ('red', '-.')],
        ylabel="Execution Time (ms)",  # label name for the y-axis
        # name for the plot. Used also as a file name for saving the plot.
        plot_name="Performance",
        args={},
    )
)
def benchmark(seq_len, provider):
    """Time one provider at one sequence length; returns do_bench quantiles."""
    device = 'cuda'
    dtype = torch.bfloat16
    s = torch.randn(B, H, seq_len, D, device=device, dtype=dtype)
    quantiles = [0.5, 0.2, 0.8]
    # Explicit dispatch: an unknown provider now raises KeyError instead of
    # silently reporting a (0, 0, 0) result row.
    fn = {
        'torch': cumsum_torch,
        'matmul': cumsum_matmul,
        'triton': cumsum_triton,
    }[provider]
    return triton.testing.do_bench(lambda: fn(s), quantiles=quantiles)


def _main():
    """Print max abs error of each kernel vs. torch, then run the benchmark."""
    s = torch.randn(B, H, T, D, device=device, dtype=dtype)
    print("DIFF\t")
    print('triton\t', f"{(cumsum_torch(s) - cumsum_triton(s)).abs().max():>10.6f}")
    print('matmul\t', f"{(cumsum_torch(s) - cumsum_matmul(s)).abs().max():>10.6f}")
    print('Done!')
    benchmark.run(print_data=True)


if __name__ == '__main__':
    _main()