Created November 4, 2024 17:34
PyTorch SDPA kernel selection
import contextlib

from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention import sdpa_kernel, SDPBackend

# Default: let PyTorch pick the backend automatically.
ctxmgr = contextlib.nullcontext()
# Uncomment one of the following to pin a specific backend:
# ctxmgr = sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION])
# ctxmgr = sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION])
# ctxmgr = sdpa_kernel(backends=[SDPBackend.MATH])
# ctxmgr = sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION])

with ctxmgr:
    pass  # call scaled_dot_product_attention here