Created
November 5, 2024 21:43
-
-
Save xinyazhang/1e0594ea2c08fc9bf8232b6f06f0b06e to your computer and use it in GitHub Desktop.
A modified version of PyTorch's SDPA.py benchmark
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import itertools | |
| from collections import defaultdict | |
| from contextlib import nullcontext | |
| from dataclasses import asdict, dataclass | |
| from typing import Callable, List, Tuple | |
| from tabulate import tabulate | |
| from tqdm import tqdm | |
| import torch | |
| import torch.utils.benchmark as benchmark | |
| from torch.nn.attention import sdpa_kernel, SDPBackend | |
| from torch.nn.functional import scaled_dot_product_attention | |
def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
    """Measure the median runtime of ``func(*args, **kwargs)`` in microseconds.

    A short warmup loop runs first so one-time costs (lazy init, kernel
    caching) do not pollute the measurement; timing itself is delegated to
    ``torch.utils.benchmark.Timer.adaptive_autorange``.
    """
    warmup_iters = 5
    for _ in range(warmup_iters):
        func(*args, **kwargs)
    timer = benchmark.Timer(
        stmt="func(*args, **kwargs)",
        globals={"func": func, "args": args, "kwargs": kwargs},
    )
    measurement = timer.adaptive_autorange(min_run_time=0.1)
    # Timer reports seconds; convert the median to microseconds.
    return measurement.median * 1e6
| @dataclass(frozen=True) | |
| class ExperimentConfig: | |
| batch_size: int | |
| num_heads: int | |
| q_seq_len: int | |
| kv_seq_len: int | |
| embed_dim: int | |
| is_causal: bool | |
| dtype: torch.dtype | |
| backend: SDPBackend | |
| device: torch.device = torch.device("cuda") | |
| @property | |
| def head_dim(self) -> int: | |
| return self.embed_dim // self.num_heads | |
| def asdict(self): | |
| dict_obj = asdict(self) | |
| dict_obj["head_dim"] = self.head_dim | |
| # del dict_obj["embed_dim"] | |
| return dict_obj | |
@dataclass(frozen=True)
class ExperimentRatio:
    """Forward/backward speedup ratios of one measurement over another."""

    forward_ratio: float
    backward_ratio: float

    def asdict(self):
        """Return both ratios as a plain dict."""
        return asdict(self)
@dataclass(frozen=True)
class ExperimentResults:
    """Median forward/backward runtimes (microseconds) for one config."""

    forward_time: float
    backward_time: float

    def asdict(self):
        """Return both timings as a plain dict."""
        return asdict(self)

    def __truediv__(self, other) -> "ExperimentRatio":
        """Elementwise ratio of this run's timings over ``other``'s timings."""
        return ExperimentRatio(
            forward_ratio=self.forward_time / other.forward_time,
            backward_ratio=self.backward_time / other.backward_time,
        )
@dataclass(frozen=True)
class Experiment:
    """Pairs a config with its measured results and a MATH-backend baseline."""

    config: ExperimentConfig
    results: ExperimentResults
    baseline: ExperimentResults

    def asdict(self):
        """Flatten config, raw timings, and baseline/results ratios into one row."""
        row = dict(self.config.asdict())
        row.update(asdict(self.results))
        # baseline / results: ratios > 1 mean this run beat the baseline.
        row.update(asdict(self.baseline / self.results))
        return row
def get_input(
    config: ExperimentConfig,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Create random (query, key, value) tensors matching ``config``.

    All three tensors are created with ``requires_grad=True`` so the
    backward pass can be benchmarked as well.
    """

    def _rand(seq_len: int) -> torch.Tensor:
        # Shared shape template: (batch, heads, seq_len, head_dim).
        return torch.randn(
            (config.batch_size, config.num_heads, seq_len, config.head_dim),
            dtype=config.dtype,
            device=config.device,
            requires_grad=True,
        )

    query = _rand(config.q_seq_len)
    key = _rand(config.kv_seq_len)
    value = _rand(config.kv_seq_len)
    return query, key, value
def run_single_experiment(config: ExperimentConfig, baseline=False) -> ExperimentResults:
    """Benchmark SDPA forward and backward for one configuration.

    When ``baseline`` is True the MATH backend is forced regardless of
    ``config.backend`` so callers can compute speedup ratios against it.
    """
    q, k, v = get_input(config)
    is_causal = config.is_causal
    # Kernel-selection context: baseline always pins MATH; otherwise honor the
    # configured backend, and let PyTorch pick freely when it is None.
    if baseline:
        context = sdpa_kernel(SDPBackend.MATH)
    elif config.backend is not None:
        context = sdpa_kernel(config.backend)
    else:
        context = nullcontext()

    with context:
        forward_time = benchmark_torch_function_in_microseconds(
            scaled_dot_product_attention,
            q,
            k,
            v,
            is_causal=is_causal,
            attn_mask=None,
        )
        out_torch = scaled_dot_product_attention(
            q, k, v, is_causal=is_causal, attn_mask=None
        )
        grad_out = torch.randn_like(out_torch)
        # retain_graph=True lets the timer replay backward repeatedly.
        backward_time = benchmark_torch_function_in_microseconds(
            out_torch.backward, grad_out, retain_graph=True
        )

    return ExperimentResults(
        forward_time=forward_time,
        backward_time=backward_time,
    )
def generate_experiment_configs() -> List[ExperimentConfig]:
    """Build the list of benchmark configurations to sweep.

    Returns the cartesian product of the parameter lists below, plus a few
    hand-picked shapes (12 heads, embed_dim 768, seq len 224).
    """
    batch_sizes = [1, 8]
    num_heads = [16]
    q_kv_seq_lens = [(128, 128), (256, 256), (512, 512), (1024, 1024)]
    embed_dims = [2048]
    # backends = [SDPBackend.FLASH_ATTENTION, SDPBackend.MATH]
    backends = [SDPBackend.EFFICIENT_ATTENTION]  # None would enable all backends
    dtypes = [
        # torch.bfloat16,
        torch.float32,
    ]
    is_causal = [True, False]

    all_configs = [
        ExperimentConfig(
            batch_size=bsz,
            num_heads=heads,
            q_seq_len=q_seq_len,
            kv_seq_len=kv_seq_len,
            embed_dim=embed_dim,
            is_causal=causal,
            dtype=dtype,
            backend=backend,
        )
        for (
            bsz,
            heads,
            (q_seq_len, kv_seq_len),
            embed_dim,
            causal,
            dtype,
            backend,
        ) in itertools.product(
            batch_sizes, num_heads, q_kv_seq_lens, embed_dims, is_causal, dtypes, backends
        )
    ]

    # return all_configs  # debug: uncomment to skip the extra shapes below
    # Extra hand-picked shapes appended after the product sweep.
    for batch_size in [8, 32]:
        for backend in backends:
            all_configs.append(
                ExperimentConfig(
                    batch_size=batch_size,
                    num_heads=12,
                    q_seq_len=224,
                    kv_seq_len=224,
                    embed_dim=768,
                    is_causal=True,
                    dtype=torch.float32,
                    backend=backend,
                )
            )
    return all_configs
def print_results(experiments: List[Experiment]):
    """Print the collected experiments as a single table on stdout.

    Columns come from ``Experiment.asdict()``; the constant "device" column
    is dropped to keep the table narrow.
    """
    table_data = defaultdict(list)
    for experiment in experiments:
        for key, value in experiment.asdict().items():
            table_data[key].append(value)
    # pop() instead of del: an empty experiment list has no "device" column,
    # and del would raise KeyError in that case.
    table_data.pop("device", None)
    # If it's default, show default
    #
    # if table_data["backend"][0] is None:
    #     del table_data["backend"]
    print(tabulate(table_data, headers="keys", tablefmt="pretty", floatfmt=".3f"))
def main():
    """Benchmark every generated config against its MATH baseline and print a table."""
    torch.manual_seed(123)
    experiments = []
    for config in tqdm(generate_experiment_configs()):
        measured = run_single_experiment(config)
        baseline = run_single_experiment(config, baseline=True)
        experiments.append(Experiment(config, measured, baseline))
    print_results(experiments)


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment