Created: June 30, 2025 11:36
Revisions
a-r-r-o-w created this gist on Jun 30, 2025.
Sequential (single-process) simulation of ring attention, checked against full SDPA:

```python
import torch

torch.manual_seed(42)


def torch_sdpa(query, key, value):
    out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
        torch.ops.aten._scaled_dot_product_cudnn_attention(
            query=query,
            key=key,
            value=value,
            attn_bias=None,
            compute_log_sumexp=True,
        )
    )
    return out, lse


def ring_sdpa_sequential(partial_queries, partial_keys, partial_values, *, world_size: int = 1, convert_to_fp32: bool = True):
    outputs, lses = [], []
    for rank in range(world_size):
        query, key, value = partial_queries[rank], partial_keys[rank], partial_values[rank]
        next_rank = (rank + 1) % world_size
        prev_out = prev_lse = None
        for i in range(world_size):
            if i > 0:
                key, value = partial_keys[next_rank], partial_values[next_rank]
                next_rank = (next_rank + 1) % world_size
            out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
                torch.ops.aten._scaled_dot_product_cudnn_attention(
                    query=query,
                    key=key,
                    value=value,
                    attn_bias=None,
                    compute_log_sumexp=True,
                )
            )
            if convert_to_fp32:
                out = out.to(torch.float32)
                lse = lse.to(torch.float32)
            # https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795
            lse = lse.unsqueeze(-1)
            if prev_out is not None:
                out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
                lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse)
            prev_out = out
            prev_lse = lse
        out = out.to(query.dtype)
        lse = lse.squeeze(-1)
        outputs.append(out)
        lses.append(lse)
    return outputs, lses


device = "cuda"
dtype = torch.bfloat16
world_size = 4
batch_size = 1
image_sequence_length = 4096
text_sequence_length = 512
sequence_length = image_sequence_length + text_sequence_length
num_attention_heads = 24
attention_head_dim = 128

query = torch.randn(batch_size, num_attention_heads, sequence_length, attention_head_dim, device=device, dtype=dtype)
key = torch.randn(batch_size, num_attention_heads, sequence_length, attention_head_dim, device=device, dtype=dtype)
value = torch.randn(batch_size, num_attention_heads, sequence_length, attention_head_dim, device=device, dtype=dtype)

partial_queries = query.chunk(world_size, dim=2)
partial_keys = key.chunk(world_size, dim=2)
partial_values = value.chunk(world_size, dim=2)

torch_sdpa_out, torch_sdpa_lse = torch_sdpa(query, key, value)
ring_sdpa_out, ring_sdpa_lse = ring_sdpa_sequential(partial_queries, partial_keys, partial_values, world_size=world_size)

all_ring_sdpa_out = torch.cat(ring_sdpa_out, dim=2)
all_ring_sdpa_lse = torch.cat(ring_sdpa_lse, dim=2)

assert torch_sdpa_out.shape == all_ring_sdpa_out.shape, "Output shapes do not match!"
assert torch_sdpa_lse.shape == all_ring_sdpa_lse.shape, "LSE shapes do not match!"
assert torch.allclose(all_ring_sdpa_out, torch_sdpa_out, atol=1e-3, rtol=1e-3), "Outputs do not match!"
assert torch.allclose(all_ring_sdpa_lse, torch_sdpa_lse, atol=1e-3, rtol=1e-3), "LSE values do not match!"
```
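The accumulation step in ring_sdpa_sequential merges the running partial result with the block just computed using only sigmoid and logsigmoid (the trick from the linked ring-flash-attention discussion). Algebraically it is the standard log-sum-exp combination of two attention partials; with o_1, o_2 the partial outputs and l_1, l_2 their log-sum-exps:

```latex
\begin{aligned}
\ell_{\text{new}} &= \log\!\left(e^{\ell_1} + e^{\ell_2}\right) = \ell_1 - \log\sigma(\ell_1 - \ell_2), \\
o_{\text{new}}    &= \frac{e^{\ell_1} o_1 + e^{\ell_2} o_2}{e^{\ell_1} + e^{\ell_2}} = o_1 - \sigma(\ell_2 - \ell_1)\,(o_1 - o_2),
\end{aligned}
\qquad \sigma(x) = \frac{1}{1 + e^{-x}}.
```

Writing the update this way avoids exponentiating the (potentially large) LSE values directly, which is the usual numerical-stability motivation for the sigmoid/logsigmoid form.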
Sequential simulation of Ulysses (all-to-all head/sequence) attention:

```python
import torch

torch.manual_seed(42)


def torch_sdpa(query, key, value):
    out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
        torch.ops.aten._scaled_dot_product_cudnn_attention(
            query=query,
            key=key,
            value=value,
            attn_bias=None,
            compute_log_sumexp=True,
        )
    )
    return out, lse


def ulysses_sdpa_sequential(partial_queries, partial_keys, partial_values, *, world_size: int = 1):
    B, H, S_LOCAL, D = partial_queries[0].shape
    H_LOCAL = H // world_size
    outputs, lses = [], []

    for partials in [partial_queries, partial_keys, partial_values]:
        for rank in range(world_size):
            x_local = partials[rank]
            # (B, H, S // world_size, D) -> (world_size, S // world_size, B, H // world_size, D)
            partials[rank] = x_local.reshape(B, world_size, H_LOCAL, S_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
        x = all_to_all_single_sequential(partials, world_size)
        for rank in range(world_size):
            x_local = x[rank]
            # (S, B, H // world_size, D) -> (B, H // world_size, S, D)
            partials[rank] = x_local.permute(1, 2, 0, 3).contiguous()

    for rank in range(world_size):
        query_local, key_local, value_local = partial_queries[rank], partial_keys[rank], partial_values[rank]
        out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
            torch.ops.aten._scaled_dot_product_cudnn_attention(
                query=query_local,
                key=key_local,
                value=value_local,
                attn_bias=None,
                compute_log_sumexp=True,
            )
        )
        outputs.append(out)
        lses.append(lse)

    for rank in range(world_size):
        out_local = outputs[rank]
        lse_local = lses[rank]
        # (B, H // world_size, S, D) -> (B, H // world_size, world_size, S // world_size, D) -> (world_size, H // world_size, B, S // world_size, D)
        outputs[rank] = out_local.reshape(B, H_LOCAL, world_size, S_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
        lses[rank] = lse_local.reshape(B, H_LOCAL, world_size, S_LOCAL).permute(2, 1, 0, 3).contiguous()

    outputs = all_to_all_single_sequential(outputs, world_size)
    lses = all_to_all_single_sequential(lses, world_size)

    for rank in range(world_size):
        out_local = outputs[rank]
        lse_local = lses[rank]
        # (H, B, S // world_size, D) -> (B, H, S // world_size, D)
        outputs[rank] = out_local.permute(1, 0, 2, 3).contiguous()
        lses[rank] = lse_local.permute(1, 0, 2).contiguous()

    return outputs, lses


def all_to_all_single_sequential(partials, world_size):
    output_partials = []
    for i in range(world_size):
        received_chunks = [p[i] for p in partials]
        output_partials.append(torch.cat(received_chunks, dim=0))
    return output_partials


device = "cuda"
dtype = torch.bfloat16
world_size = 4
batch_size = 1
image_sequence_length = 4096
text_sequence_length = 512
sequence_length = image_sequence_length + text_sequence_length
num_attention_heads = 24
attention_head_dim = 128

query = torch.randn(batch_size, num_attention_heads, sequence_length, attention_head_dim, device=device, dtype=dtype)
key = torch.randn(batch_size, num_attention_heads, sequence_length, attention_head_dim, device=device, dtype=dtype)
value = torch.randn(batch_size, num_attention_heads, sequence_length, attention_head_dim, device=device, dtype=dtype)

partial_queries = list(query.chunk(world_size, dim=2))
partial_keys = list(key.chunk(world_size, dim=2))
partial_values = list(value.chunk(world_size, dim=2))

torch_sdpa_out, torch_sdpa_lse = torch_sdpa(query, key, value)
ulysses_sdpa_out, ulysses_sdpa_lse = ulysses_sdpa_sequential(partial_queries, partial_keys, partial_values, world_size=world_size)

all_ulysses_sdpa_out = torch.cat(ulysses_sdpa_out, dim=2)
all_ulysses_sdpa_lse = torch.cat(ulysses_sdpa_lse, dim=2)

assert torch_sdpa_out.shape == all_ulysses_sdpa_out.shape, "Output shapes do not match!"
assert torch_sdpa_lse.shape == all_ulysses_sdpa_lse.shape, "LSE shapes do not match!"
assert torch.allclose(all_ulysses_sdpa_out, torch_sdpa_out, atol=1e-3, rtol=1e-3), "Outputs do not match!"
assert torch.allclose(all_ulysses_sdpa_lse, torch_sdpa_lse, atol=1e-3, rtol=1e-3), "LSEs do not match!"
```
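To see what the reshape/permute plus sequential all-to-all is doing, here is a minimal standalone sketch (toy sizes, no attention call, variable names chosen for illustration): each "rank" starts with every head of a sequence shard and ends up with a head shard of the full sequence, and undoing the exchange recovers the original tensor.

```python
import torch

# Toy dimensions: 2 "ranks", batch 1, 4 heads, sequence length 8, head dim 3.
world_size, B, H, S, D = 2, 1, 4, 8, 3
H_LOCAL, S_LOCAL = H // world_size, S // world_size

x = torch.arange(B * H * S * D, dtype=torch.float32).reshape(B, H, S, D)
shards = list(x.chunk(world_size, dim=2))  # what each rank starts with: all heads, S // world_size tokens

# Stage for the exchange: (B, H, S_LOCAL, D) -> (world_size, S_LOCAL, B, H_LOCAL, D)
staged = [s.reshape(B, world_size, H_LOCAL, S_LOCAL, D).permute(1, 3, 0, 2, 4) for s in shards]
# Sequential stand-in for all-to-all: rank i receives chunk i from every rank.
exchanged = [torch.cat([p[i] for p in staged], dim=0) for i in range(world_size)]
# Each rank now holds (S, B, H_LOCAL, D); bring it back to (B, H_LOCAL, S, D).
head_shards = [e.permute(1, 2, 0, 3) for e in exchanged]

# Sanity check: concatenating the head shards along the head dim recovers the original tensor.
assert torch.equal(torch.cat(head_shards, dim=1), x)
```

This is why ulysses_sdpa_sequential can call the ordinary SDPA kernel per rank: after the exchange each rank sees the full sequence, just for a subset of heads, so no output merging is needed beyond the reverse exchange.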
Sequential simulation of unified Ulysses + ring attention:

```python
import torch

torch.manual_seed(42)


def torch_sdpa(query, key, value):
    out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
        torch.ops.aten._scaled_dot_product_cudnn_attention(
            query=query,
            key=key,
            value=value,
            attn_bias=None,
            compute_log_sumexp=True,
        )
    )
    return out, lse


def ring_sdpa_sequential(partial_queries, partial_keys, partial_values, *, ring_size: int = 1, convert_to_fp32: bool = True):
    outputs, lses = [], []
    for rank in range(ring_size):
        query, key, value = partial_queries[rank], partial_keys[rank], partial_values[rank]
        next_rank = (rank + 1) % ring_size
        prev_out = prev_lse = None
        for i in range(ring_size):
            if i > 0:
                key, value = partial_keys[next_rank], partial_values[next_rank]
                next_rank = (next_rank + 1) % ring_size
            out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
                torch.ops.aten._scaled_dot_product_cudnn_attention(
                    query=query,
                    key=key,
                    value=value,
                    attn_bias=None,
                    compute_log_sumexp=True,
                )
            )
            if convert_to_fp32:
                out = out.to(torch.float32)
                lse = lse.to(torch.float32)
            # https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795
            lse = lse.unsqueeze(-1)
            if prev_out is not None:
                out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
                lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse)
            prev_out = out
            prev_lse = lse
        out = out.to(query.dtype)
        lse = lse.squeeze(-1)
        outputs.append(out)
        lses.append(lse)
    return outputs, lses


def unified_ulysses_ring_sdpa_sequential(partial_queries, partial_keys, partial_values, *, ulysses_size: int = 1, ring_size: int = 1):
    B, H, S_LOCAL, D = partial_queries[0][0].shape
    H_LOCAL = H // ulysses_size
    outputs, lses = [], []

    for partials in [partial_queries, partial_keys, partial_values]:
        for ring_rank in range(ring_size):
            for rank in range(ulysses_size):
                x_local = partials[ring_rank][rank]
                partials[ring_rank][rank] = x_local.reshape(B, ulysses_size, H_LOCAL, S_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
            x = all_to_all_single_sequential(partials[ring_rank], ulysses_size)
            for rank in range(ulysses_size):
                x_local = x[rank]
                partials[ring_rank][rank] = x_local.permute(1, 2, 0, 3).contiguous()

    partial_queries = [list(x) for x in zip(*partial_queries)]
    partial_keys = [list(x) for x in zip(*partial_keys)]
    partial_values = [list(x) for x in zip(*partial_values)]

    for rank in range(ulysses_size):
        ring_outputs, ring_lses = ring_sdpa_sequential(partial_queries[rank], partial_keys[rank], partial_values[rank], ring_size=ring_size)
        outputs.append(ring_outputs)
        lses.append(ring_lses)

    outputs = [list(x) for x in zip(*outputs)]
    lses = [list(x) for x in zip(*lses)]

    for ring_rank in range(ring_size):
        for rank in range(ulysses_size):
            outputs[ring_rank][rank] = outputs[ring_rank][rank].reshape(B, H_LOCAL, ulysses_size, S_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
            lses[ring_rank][rank] = lses[ring_rank][rank].reshape(B, H_LOCAL, ulysses_size, S_LOCAL).permute(2, 1, 0, 3).contiguous()
        outputs[ring_rank] = all_to_all_single_sequential(outputs[ring_rank], ulysses_size)
        lses[ring_rank] = all_to_all_single_sequential(lses[ring_rank], ulysses_size)
        for rank in range(ulysses_size):
            outputs[ring_rank][rank] = outputs[ring_rank][rank].permute(1, 0, 2, 3).contiguous()
            lses[ring_rank][rank] = lses[ring_rank][rank].permute(1, 0, 2).contiguous()

    return outputs, lses


def all_to_all_single_sequential(partials, world_size):
    output_partials = []
    for i in range(world_size):
        received_chunks = [p[i] for p in partials]
        output_partials.append(torch.cat(received_chunks, dim=0))
    return output_partials


device = "cuda"
dtype = torch.bfloat16
WORLD_SIZE = 8
ulysses_size = 4
ring_size = 2
assert ulysses_size * ring_size == WORLD_SIZE, "ulysses_size * ring_size must equal WORLD_SIZE"

batch_size = 1
image_sequence_length = 4096
text_sequence_length = 512
sequence_length = image_sequence_length + text_sequence_length
num_attention_heads = 24
attention_head_dim = 128

query = torch.randn(batch_size, num_attention_heads, sequence_length, attention_head_dim, device=device, dtype=dtype)
key = torch.randn(batch_size, num_attention_heads, sequence_length, attention_head_dim, device=device, dtype=dtype)
value = torch.randn(batch_size, num_attention_heads, sequence_length, attention_head_dim, device=device, dtype=dtype)

partial_queries = list(query.chunk(WORLD_SIZE, dim=2))
partial_keys = list(key.chunk(WORLD_SIZE, dim=2))
partial_values = list(value.chunk(WORLD_SIZE, dim=2))

# R=1, U=4 => [[tensor1, tensor2, tensor3, tensor4]]
# R=2, U=2 => [[tensor1, tensor2], [tensor3, tensor4]]
# R=4, U=1 => [[tensor1], [tensor2], [tensor3], [tensor4]]
partial_queries = [partial_queries[i:i + ulysses_size] for i in range(0, WORLD_SIZE, ulysses_size)]
partial_keys = [partial_keys[i:i + ulysses_size] for i in range(0, WORLD_SIZE, ulysses_size)]
partial_values = [partial_values[i:i + ulysses_size] for i in range(0, WORLD_SIZE, ulysses_size)]

torch_sdpa_out, torch_sdpa_lse = torch_sdpa(query, key, value)
unified_sdpa_out, unified_sdpa_lse = unified_ulysses_ring_sdpa_sequential(partial_queries, partial_keys, partial_values, ulysses_size=ulysses_size, ring_size=ring_size)

all_unified_sdpa_out = torch.cat([torch.cat(out, dim=2) for out in unified_sdpa_out], dim=2)
all_unified_sdpa_lse = torch.cat([torch.cat(lse, dim=2) for lse in unified_sdpa_lse], dim=2)

assert torch_sdpa_out.shape == all_unified_sdpa_out.shape, "Output shapes do not match!"
assert torch_sdpa_lse.shape == all_unified_sdpa_lse.shape, "LSE shapes do not match!"
assert torch.allclose(all_unified_sdpa_out, torch_sdpa_out, atol=1e-3, rtol=1e-3), "Outputs do not match!"
assert torch.allclose(all_unified_sdpa_lse, torch_sdpa_lse, atol=1e-3, rtol=1e-3), "LSEs do not match!"
```
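The nested [ring_size][ulysses_size] list layout used above corresponds to a ring-major ordering of the flat sequence shards: consecutive shards land in the same ring group, as the R/U comments illustrate. A small sketch of that mapping, assuming the same layout as the nested-list construction:

```python
# Flat shard index -> (ring group, ulysses slot), matching partial_queries[ring_rank][ulysses_rank].
WORLD_SIZE, ulysses_size = 8, 4
ring_size = WORLD_SIZE // ulysses_size

for rank in range(WORLD_SIZE):
    ring_rank, ulysses_rank = divmod(rank, ulysses_size)
    assert rank == ring_rank * ulysses_size + ulysses_rank
    print(f"shard {rank} -> ring group {ring_rank}, ulysses slot {ulysses_rank}")
```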
Distributed benchmark of ring, Ulysses, and unified context-parallel attention with torch.distributed:

```python
import argparse
from dataclasses import dataclass
from typing import Callable, Literal, List

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
from torch.distributed import DeviceMesh


@dataclass
class ContextParallelOptions:
    mode: Literal["ring", "ulysses", "unified"] = "ring"
    ring_mesh: DeviceMesh | None = None
    ulysses_mesh: DeviceMesh | None = None
    convert_to_fp32: bool = True
    op: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor]] | None = None


cp_options = ContextParallelOptions()


def _templated_ring_attention(query, key, value):
    rank = cp_options.ring_mesh.get_rank()
    world_size = cp_options.ring_mesh.size()
    if world_size == 1:
        return cp_options.op(query, key, value)
    next_rank = (rank + 1) % world_size
    prev_out = prev_lse = None
    kv_buffer = torch.cat([key.flatten(), value.flatten()]).contiguous()
    kv_buffer = funcol.all_gather_tensor(kv_buffer, gather_dim=0, group=cp_options.ring_mesh.get_group())
    kv_buffer = kv_buffer.chunk(world_size)
    for i in range(world_size):
        if i > 0:
            kv = kv_buffer[next_rank]
            key = kv[:key.numel()].reshape_as(key)
            value = kv[key.numel():].reshape_as(value)
            next_rank = (next_rank + 1) % world_size
        out, lse = cp_options.op(query, key, value)
        if cp_options.convert_to_fp32:
            out = out.to(torch.float32)
            lse = lse.to(torch.float32)
        lse = lse.unsqueeze(-1)
        if prev_out is not None:
            out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
            lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse)
        prev_out = out
        prev_lse = lse
    out = out.to(query.dtype)
    lse = lse.squeeze(-1)
    return out, lse


def _templated_ulysses_attention(query, key, value):
    world_size = cp_options.ulysses_mesh.size()
    group = cp_options.ulysses_mesh.get_group()
    if world_size == 1:
        return cp_options.op(query, key, value)
    B, H, S_LOCAL, D = query.shape
    H_LOCAL = H // world_size
    query, key, value = (
        x.reshape(B, world_size, H_LOCAL, S_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
        for x in (query, key, value)
    )
    query, key, value = (
        funcol.all_to_all_single(x, None, None, group=group)
        for x in (query, key, value)
    )
    query, key, value = (
        x.flatten(0, 1).permute(1, 2, 0, 3).contiguous()
        for x in (query, key, value)
    )
    out, lse = cp_options.op(query, key, value)
    out = out.reshape(B, H_LOCAL, world_size, S_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
    lse = lse.reshape(B, H_LOCAL, world_size, S_LOCAL).permute(2, 1, 0, 3).contiguous()
    out = funcol.all_to_all_single(out, None, None, group=group).wait()
    lse = funcol.all_to_all_single(lse, None, None, group=group).wait()
    out = out.flatten(0, 1).permute(1, 0, 2, 3).contiguous()
    lse = lse.flatten(0, 1).permute(1, 0, 2).contiguous()
    return out, lse


def _templated_unified_attention(query, key, value):
    ring_size = cp_options.ring_mesh.size()
    ulysses_size = cp_options.ulysses_mesh.size()
    ulysses_group = cp_options.ulysses_mesh.get_group()
    world_size = ring_size * ulysses_size
    if world_size == 1:
        return cp_options.op(query, key, value)
    B, H, S_LOCAL, D = query.shape
    H_LOCAL = H // ulysses_size
    query, key, value = (
        x.reshape(B, ulysses_size, H_LOCAL, S_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
        for x in (query, key, value)
    )
    query, key, value = (
        funcol.all_to_all_single(x, None, None, group=ulysses_group)
        for x in (query, key, value)
    )
    query, key, value = (
        x.flatten(0, 1).permute(1, 2, 0, 3).contiguous()
        for x in (query, key, value)
    )
    out, lse = _templated_ring_attention(query, key, value)
    out = out.reshape(B, H_LOCAL, ulysses_size, S_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
    lse = lse.reshape(B, H_LOCAL, ulysses_size, S_LOCAL).permute(2, 1, 0, 3).contiguous()
    out = funcol.all_to_all_single(out, None, None, group=ulysses_group).wait()
    lse = funcol.all_to_all_single(lse, None, None, group=ulysses_group).wait()
    out = out.flatten(0, 1).permute(1, 0, 2, 3).contiguous()
    lse = lse.flatten(0, 1).permute(1, 0, 2).contiguous()
    return out, lse


def torch_cudnn_attention(query, key, value):
    out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
        torch.ops.aten._scaled_dot_product_cudnn_attention(
            query=query,
            key=key,
            value=value,
            attn_bias=None,
            compute_log_sumexp=True,
        )
    )
    return out, lse


def torch_flash_attention(query, key, value):
    out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
        torch.ops.aten._scaled_dot_product_flash_attention(
            query=query,
            key=key,
            value=value,
        )
    )
    return out, lse


OPS = {
    "cudnn": torch_cudnn_attention,
    "flash": torch_flash_attention,
}

WORLD_SIZE = -1
RANK = -1


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--ring_degree", type=int, default=1)
    parser.add_argument("--ulysses_degree", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--num_heads", type=int, default=24)
    parser.add_argument("--head_dim", type=int, default=128)
    parser.add_argument("--seq_lens", type=int, nargs="+", default=[512, 1024, 2048, 4096, 4224, 4352, 4480, 4608, 8192])
    parser.add_argument(
        "--ops",
        type=str,
        nargs="+",
        choices=list(OPS.keys()),
        default=list(OPS.keys()),
    )
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()
    return args


def main(
    ring_degree: int,
    ulysses_degree: int,
    batch_size: int,
    num_heads: int,
    head_dim: int,
    seq_lens: List[int],
    ops: List[str],
    seed: int,
):
    global cp_options, WORLD_SIZE, RANK

    mesh_names = ["ring", "ulysses"]
    mesh_dims = [ring_degree, ulysses_degree]
    mesh = dist.device_mesh.init_device_mesh("cuda", mesh_dims, mesh_dim_names=mesh_names)
    cp_options.ring_mesh = mesh["ring"]
    cp_options.ulysses_mesh = mesh["ulysses"]
    cp_options.convert_to_fp32 = True
    cp_attention = None

    num_warmups = 5
    num_repeats = 10
    device = torch.device("cuda")
    dtype = torch.bfloat16

    if ring_degree > 1 and ulysses_degree > 1:
        cp_options.mode = "unified"
        cp_attention = _templated_unified_attention
    elif ulysses_degree > 1:
        cp_options.mode = "ulysses"
        cp_attention = _templated_ulysses_attention
    else:
        cp_options.mode = "ring"
        cp_attention = _templated_ring_attention

    results = {}
    for op_name in ops:
        op = OPS[op_name]
        cp_options.op = op
        results[op_name] = {}
        for seq_len in seq_lens:
            shape = (batch_size, num_heads, seq_len, head_dim)
            query = torch.randn(shape, device=device, dtype=dtype)
            key = torch.randn(shape, device=device, dtype=dtype)
            value = torch.randn(shape, device=device, dtype=dtype)
            dist.broadcast(query, src=0)
            dist.broadcast(key, src=0)
            dist.broadcast(value, src=0)
            dist.barrier()
            torch.cuda.synchronize()

            reference_out, reference_lse = torch_cudnn_attention(query, key, value)
            query, key, value = (x.chunk(WORLD_SIZE, dim=2)[RANK].contiguous() for x in (query, key, value))

            for _ in range(num_warmups):
                if WORLD_SIZE == 1:
                    out, lse = op(query, key, value)
                else:
                    out, lse = cp_attention(query, key, value)
            out = funcol.all_gather_tensor(out, gather_dim=2, group=mesh._flatten().get_group())
            lse = funcol.all_gather_tensor(lse, gather_dim=2, group=mesh._flatten().get_group())
            torch.cuda.synchronize()

            diff = out - reference_out
            absdiff = torch.abs(diff)
            absmax = torch.max(absdiff)
            mae = torch.mean(absdiff)
            mse = torch.mean(diff * diff)
            if RANK == 0:
                print(f"op: {op_name}, seq_len: {seq_len}, absmax: {absmax:.5f}, mae: {mae:.5f}, mse: {mse:.5f}")

            # if not torch.allclose(out, reference_out, atol=1e-2, rtol=1e-2):
            #     raise ValueError(f"Output mismatch for op: {op_name}, seq_len: {seq_len}")
            # if not torch.allclose(lse, reference_lse, atol=1e-2, rtol=1e-2):
            #     raise ValueError(f"LSE mismatch for op: {op_name}, seq_len: {seq_len}")

            dist.barrier()
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            start_event.record()
            for _ in range(num_repeats):
                if WORLD_SIZE == 1:
                    out, lse = op(query, key, value)
                else:
                    out, lse = cp_attention(query, key, value)
            end_event.record()
            torch.cuda.synchronize()
            dist.barrier()
            elapsed_time = start_event.elapsed_time(end_event) / num_repeats
            results[op_name][seq_len] = elapsed_time

    if RANK == 0:
        print("Benchmark results:")
        for op_name, seq_times in results.items():
            print(f"\n\n===== op: {op_name} =====")
            for seq_len, time in seq_times.items():
                print(f"  {seq_len=}, {time:.5f} ms")


if __name__ == "__main__":
    args = get_args()
    torch.manual_seed(args.seed)

    try:
        dist.init_process_group(backend="nccl")
        WORLD_SIZE = dist.get_world_size()
        RANK = dist.get_rank()
        torch.cuda.set_device(RANK)

        if args.ring_degree * args.ulysses_degree != WORLD_SIZE:
            raise ValueError(
                f"ring_degree * ulysses_degree must equal world size, got {args.ring_degree} * {args.ulysses_degree} != {WORLD_SIZE}"
            )

        main(
            ring_degree=args.ring_degree,
            ulysses_degree=args.ulysses_degree,
            batch_size=args.batch_size,
            num_heads=args.num_heads,
            head_dim=args.head_dim,
            seq_lens=args.seq_lens,
            ops=args.ops,
            seed=args.seed,
        )
    finally:
        dist.destroy_process_group()
```
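Note that _templated_ring_attention gathers all key/value shards up front with all_gather_tensor and then walks them in ring order, a simplification compared to a production ring-attention implementation that rotates K/V with point-to-point sends overlapped with compute. Since the script initializes the process group with the NCCL backend and reads the rendezvous from the environment, a typical launch (assuming 8 GPUs and that the script is saved as benchmark_cp.py; the filename is illustrative) would look something like:

```
torchrun --nproc_per_node=8 benchmark_cp.py --ring_degree 2 --ulysses_degree 4 --ops cudnn flash
```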