@drisspg
Created October 19, 2024 00:51
sdpa.py
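"""Microbenchmark for torch.nn.functional.scaled_dot_product_attention.

Sweeps batch size, sequence length, causal masking, dtype, SDPA backend
(FlashAttention vs. cuDNN), and contiguous vs. transposed input layouts,
timing the forward and backward passes and printing the results as a table.
"""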
import itertools
from collections import defaultdict
from contextlib import nullcontext
from dataclasses import asdict, dataclass
from typing import Callable, List, Tuple

from tabulate import tabulate
from tqdm import tqdm

import torch
import torch.utils.benchmark as benchmark
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.nn.functional import scaled_dot_product_attention


def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
    # Warm up outside the timed region so one-time setup costs don't skew results.
    for _ in range(5):
        func(*args, **kwargs)
    t0 = benchmark.Timer(
        stmt="func(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "func": func},
    )
    # adaptive_autorange reports seconds; convert the median to microseconds.
    return t0.adaptive_autorange(min_run_time=0.1).median * 1e6
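# Illustrative standalone use of the helper above (hypothetical tensors, not
# part of the benchmark sweep):
#   a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
#   t_us = benchmark_torch_function_in_microseconds(torch.mm, a, a)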


@dataclass(frozen=True)
class ExperimentConfig:
    batch_size: int
    num_heads: int
    q_seq_len: int
    kv_seq_len: int
    embed_dim: int
    is_causal: bool
    dtype: torch.dtype
    backend: SDPBackend
    transposed: bool  # If True, inputs are allocated (B, S, H, D) and transposed to (B, H, S, D)
    device: torch.device = torch.device("cuda")

    @property
    def head_dim(self) -> int:
        # e.g. embed_dim=2048 with num_heads=32 gives head_dim=64
        return self.embed_dim // self.num_heads

    def asdict(self):
        dict_obj = asdict(self)
        dict_obj["head_dim"] = self.head_dim
        return dict_obj


@dataclass(frozen=True)
class ExperimentResults:
    forward_time: float
    backward_time: float

    def asdict(self):
        return asdict(self)


@dataclass(frozen=True)
class Experiment:
    config: ExperimentConfig
    results: ExperimentResults

    def asdict(self):
        # Use the config's custom asdict so the derived head_dim column is included.
        dict1 = self.config.asdict()
        dict2 = asdict(self.results)
        return {**dict1, **dict2}


def get_input(
    config: ExperimentConfig,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    if config.transposed:
        # Allocate as (B, S, H, D) and transpose to (B, H, S, D): SDPA sees the
        # expected shape but with a non-contiguous memory layout.
        q = torch.randn(
            (config.batch_size, config.q_seq_len, config.num_heads, config.head_dim),
            dtype=config.dtype,
            device=config.device,
            requires_grad=True,
        ).transpose(1, 2)
        k = torch.randn(
            (config.batch_size, config.kv_seq_len, config.num_heads, config.head_dim),
            dtype=config.dtype,
            device=config.device,
            requires_grad=True,
        ).transpose(1, 2)
        v = torch.randn(
            (config.batch_size, config.kv_seq_len, config.num_heads, config.head_dim),
            dtype=config.dtype,
            device=config.device,
            requires_grad=True,
        ).transpose(1, 2)
    else:
        # Contiguous (B, H, S, D) inputs.
        q = torch.randn(
            (config.batch_size, config.num_heads, config.q_seq_len, config.head_dim),
            dtype=config.dtype,
            device=config.device,
            requires_grad=True,
        )
        k = torch.randn(
            (config.batch_size, config.num_heads, config.kv_seq_len, config.head_dim),
            dtype=config.dtype,
            device=config.device,
            requires_grad=True,
        )
        v = torch.randn(
            (config.batch_size, config.num_heads, config.kv_seq_len, config.head_dim),
            dtype=config.dtype,
            device=config.device,
            requires_grad=True,
        )
    return q, k, v



def run_single_experiment(config: ExperimentConfig) -> ExperimentResults:
    q, k, v = get_input(config)
    is_causal = config.is_causal
    # Restrict SDPA to the requested backend; if None, all backends stay enabled.
    context = (
        sdpa_kernel(config.backend) if config.backend is not None else nullcontext()
    )
    with context:
        forward_time = benchmark_torch_function_in_microseconds(
            scaled_dot_product_attention,
            q,
            k,
            v,
            is_causal=is_causal,
            attn_mask=None,
        )
        out_torch = scaled_dot_product_attention(
            q, k, v, is_causal=is_causal, attn_mask=None
        )
        dOut = torch.randn_like(out_torch)
        # retain_graph=True lets the timer replay backward on the same graph.
        backward_time = benchmark_torch_function_in_microseconds(
            out_torch.backward, dOut, retain_graph=True
        )

    return ExperimentResults(
        forward_time=forward_time,
        backward_time=backward_time,
    )


def generate_experiment_configs() -> List[ExperimentConfig]:
    batch_sizes = [
        1,
        8,
    ]
    num_heads = [32]
    q_kv_seq_lens = [(128, 128), (256, 256), (512, 512), (1024, 1024)]
    transposed_configs = [True, False]
    embed_dims = [2048]
    # If set to [None], all backends are enabled.
    backends = [SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION]
    dtypes = [
        torch.float16,
    ]
    is_causal = [True, False]
    all_configs = []
    for (
        bsz,
        heads,
        (q_seq_len, kv_seq_len),
        embed_dim,
        causal,
        dtype,
        backend,
        transpose,
    ) in itertools.product(
        batch_sizes, num_heads, q_kv_seq_lens, embed_dims, is_causal, dtypes, backends, transposed_configs
    ):
        all_configs.append(
            ExperimentConfig(
                batch_size=bsz,
                num_heads=heads,
                q_seq_len=q_seq_len,
                kv_seq_len=kv_seq_len,
                embed_dim=embed_dim,
                is_causal=causal,
                dtype=dtype,
                backend=backend,
                transposed=transpose,
            )
        )

    return all_configs


def print_results(experiments: List[Experiment]):
    table_data = defaultdict(list)
    for experiment in experiments:
        for key, value in experiment.asdict().items():
            table_data[key].append(value)
    # The device column is constant; drop it. Drop the backend column too when
    # no specific backend was requested.
    del table_data["device"]
    if table_data["backend"][0] is None:
        del table_data["backend"]
    print(tabulate(table_data, headers="keys", tablefmt="pretty", floatfmt=".3f"))


def main():
    seed = 123
    torch.manual_seed(seed)
    results = []
    for config in tqdm(generate_experiment_configs()):
        results.append(Experiment(config, run_single_experiment(config)))

    print_results(results)


if __name__ == "__main__":
    main()
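# To run (assumes a CUDA device with the FlashAttention and cuDNN SDPA
# backends available):
#   python sdpa.py
# Each row of the printed table reports forward_time and backward_time in
# microseconds for one configuration.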