Skip to content

Instantly share code, notes, and snippets.

@victoroliv2
Created March 21, 2024 01:58
Show Gist options
  • Select an option

  • Save victoroliv2/3668f07e11a0757febb6e55a8d78592a to your computer and use it in GitHub Desktop.

Select an option

Save victoroliv2/3668f07e11a0757febb6e55a8d78592a to your computer and use it in GitHub Desktop.

Revisions

  1. victoroliv2 created this gist Mar 21, 2024.
    27 changes: 27 additions & 0 deletions pytorch_fmha_nested_tensor.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,27 @@
    import torch

    BATCH = 4
    EMB_DIM = 256
    HEADS = 8
    Q_TOKENS = 512
    KV_TOKENS = 16384

    q_proj = torch.nested.nested_tensor([torch.zeros(HEADS, Q_TOKENS // (i+1), EMB_DIM) for i in range(BATCH)], dtype=torch.half, device="cuda")
    k_proj = torch.nested.nested_tensor([torch.zeros(HEADS, KV_TOKENS // (i+1), EMB_DIM) for i in range(BATCH)], dtype=torch.half, device="cuda")
    v_proj = torch.nested.nested_tensor([torch.zeros(HEADS, KV_TOKENS // (i+1), EMB_DIM) for i in range(BATCH)], dtype=torch.half, device="cuda")

    def trace_ready(p):
    print('trace ready!')
    import os
    os.system('rm chrome_trace.gz')
    p.export_chrome_trace('chrome_trace')
    os.system('gzip chrome_trace')

    with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],on_trace_ready=trace_ready,with_stack=True):
    for i in range(10):
    with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
    ):
    out = torch.nn.functional.scaled_dot_product_attention(
    q_proj, k_proj, v_proj, attn_mask=None, dropout_p=0.0
    )