
@xinyazhang
Created November 4, 2024 17:34
PyTorch SDPA kernel selection
import contextlib
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention import sdpa_kernel, SDPBackend
ctxmgr = contextlib.nullcontext()
# ctxmgr = sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION])
# ctxmgr = sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION])
# ctxmgr = sdpa_kernel(backends=[SDPBackend.MATH])
# ctxmgr = sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION])
with ctxmgr:
    pass  # call scaled_dot_product_attention here