Skip to content

Instantly share code, notes, and snippets.

@xinyazhang
Created November 5, 2024 21:43
Show Gist options
  • Select an option

  • Save xinyazhang/1e0594ea2c08fc9bf8232b6f06f0b06e to your computer and use it in GitHub Desktop.

Select an option

Save xinyazhang/1e0594ea2c08fc9bf8232b6f06f0b06e to your computer and use it in GitHub Desktop.
A modified version of PyTorch's SDPA.py benchmark
import itertools
from collections import defaultdict
from contextlib import nullcontext
from dataclasses import asdict, dataclass
from typing import Callable, List, Tuple
from tabulate import tabulate
from tqdm import tqdm
import torch
import torch.utils.benchmark as benchmark
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.nn.functional import scaled_dot_product_attention
def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
    """Time ``func(*args, **kwargs)`` and return its median runtime in microseconds.

    Runs five untimed warmup calls first, then delegates the measurement to
    ``torch.utils.benchmark.Timer.adaptive_autorange``.
    """
    # Warm up caches / lazy initialization so the timed runs are steady-state.
    for _ in range(5):
        func(*args, **kwargs)
    timer = benchmark.Timer(
        stmt="func(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "func": func},
    )
    measurement = timer.adaptive_autorange(min_run_time=0.1)
    return measurement.median * 1e6  # seconds -> microseconds
@dataclass(frozen=True)
class ExperimentConfig:
batch_size: int
num_heads: int
q_seq_len: int
kv_seq_len: int
embed_dim: int
is_causal: bool
dtype: torch.dtype
backend: SDPBackend
device: torch.device = torch.device("cuda")
@property
def head_dim(self) -> int:
return self.embed_dim // self.num_heads
def asdict(self):
dict_obj = asdict(self)
dict_obj["head_dim"] = self.head_dim
# del dict_obj["embed_dim"]
return dict_obj
@dataclass(frozen=True)
class ExperimentRatio:
    """Forward/backward timing ratios between two experiment runs."""

    forward_ratio: float
    backward_ratio: float

    def asdict(self):
        """Return the ratios as a plain dict."""
        return asdict(self)
@dataclass(frozen=True)
class ExperimentResults:
    """Measured forward/backward runtimes (microseconds) for one config."""

    forward_time: float
    backward_time: float

    def asdict(self):
        """Return the timings as a plain dict."""
        return asdict(self)

    def __truediv__(self, other) -> ExperimentRatio:
        """Element-wise ratio of this run's timings to ``other``'s."""
        fwd = self.forward_time / other.forward_time
        bwd = self.backward_time / other.backward_time
        return ExperimentRatio(forward_ratio=fwd, backward_ratio=bwd)
@dataclass(frozen=True)
class Experiment:
    """A benchmark case together with its measured and baseline timings."""

    config: ExperimentConfig
    results: ExperimentResults
    baseline: ExperimentResults

    def asdict(self):
        """Flatten config, raw timings, and baseline/results ratios into one dict."""
        merged = dict(self.config.asdict())
        merged.update(asdict(self.results))
        # baseline / results: values > 1 mean `results` was faster than `baseline`.
        merged.update(asdict(self.baseline / self.results))
        return merged
def get_input(
    config: ExperimentConfig,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build random (q, k, v) tensors for ``config``.

    Shapes are (batch_size, num_heads, seq_len, head_dim), where q uses
    q_seq_len and k/v use kv_seq_len.  All three require grad so the
    backward pass can be benchmarked.
    """

    def make(seq_len: int) -> torch.Tensor:
        # One helper for all three tensors; only the sequence length varies.
        return torch.randn(
            (config.batch_size, config.num_heads, seq_len, config.head_dim),
            dtype=config.dtype,
            device=config.device,
            requires_grad=True,
        )

    q = make(config.q_seq_len)
    k = make(config.kv_seq_len)
    v = make(config.kv_seq_len)
    return q, k, v
def run_single_experiment(config: ExperimentConfig, baseline=False) -> ExperimentResults:
    """Benchmark the forward and backward SDPA passes for ``config``.

    When ``baseline`` is True the run is forced onto the MATH backend so it can
    serve as a reference point; otherwise ``config.backend`` is used, or all
    backends when it is None.
    """
    q, k, v = get_input(config)
    causal = config.is_causal
    if baseline:
        ctx = sdpa_kernel(SDPBackend.MATH)
    elif config.backend is not None:
        ctx = sdpa_kernel(config.backend)
    else:
        ctx = nullcontext()
    with ctx:
        forward_time = benchmark_torch_function_in_microseconds(
            scaled_dot_product_attention,
            q,
            k,
            v,
            is_causal=causal,
            attn_mask=None,
        )
        out = scaled_dot_product_attention(q, k, v, is_causal=causal, attn_mask=None)
        grad_out = torch.randn_like(out)
        # retain_graph: the timer re-runs backward many times on the same graph.
        backward_time = benchmark_torch_function_in_microseconds(
            out.backward, grad_out, retain_graph=True
        )
    return ExperimentResults(
        forward_time=forward_time,
        backward_time=backward_time,
    )
def generate_experiment_configs() -> List[ExperimentConfig]:
    """Enumerate every benchmark configuration to run."""
    batch_sizes = [1, 8]
    heads_options = [16]
    seq_len_pairs = [(128, 128), (256, 256), (512, 512), (1024, 1024)]
    embed_dims = [2048]
    # Set to [None] to let PyTorch choose among all enabled backends.
    backends = [SDPBackend.EFFICIENT_ATTENTION]
    dtypes = [torch.float32]
    causal_options = [True, False]

    # Full cartesian product over the sweep axes above.
    configs = [
        ExperimentConfig(
            batch_size=bsz,
            num_heads=heads,
            q_seq_len=q_len,
            kv_seq_len=kv_len,
            embed_dim=dim,
            is_causal=causal,
            dtype=dtype,
            backend=backend,
        )
        for bsz, heads, (q_len, kv_len), dim, causal, dtype, backend in itertools.product(
            batch_sizes,
            heads_options,
            seq_len_pairs,
            embed_dims,
            causal_options,
            dtypes,
            backends,
        )
    ]

    # Extra cases at a smaller transformer geometry (12 heads, 224 tokens, 768 dims).
    for bsz, backend in itertools.product([8, 32], backends):
        configs.append(
            ExperimentConfig(
                batch_size=bsz,
                num_heads=12,
                q_seq_len=224,
                kv_seq_len=224,
                embed_dim=768,
                is_causal=True,
                dtype=torch.float32,
                backend=backend,
            )
        )
    return configs
def print_results(experiments: List[Experiment]):
    """Render all experiments as one column-oriented table on stdout."""
    columns = defaultdict(list)
    for experiment in experiments:
        for key, value in experiment.asdict().items():
            columns[key].append(value)
    # The device column is constant across rows and only adds noise.
    del columns["device"]
    print(tabulate(columns, headers="keys", tablefmt="pretty", floatfmt=".3f"))
def main():
    """Run every generated config against both its backend and the MATH baseline."""
    torch.manual_seed(123)  # reproducible random inputs across invocations
    experiments = []
    for config in tqdm(generate_experiment_configs()):
        measured = run_single_experiment(config)
        reference = run_single_experiment(config, baseline=True)
        experiments.append(Experiment(config, measured, reference))
    print_results(experiments)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment