
@a-r-r-o-w
Created June 30, 2025 11:36

sequential_ring.py

import torch

torch.manual_seed(42)


def torch_sdpa(query, key, value):
    out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
        torch.ops.aten._scaled_dot_product_cudnn_attention(
            query=query,
            key=key,
            value=value,
            attn_bias=None,
            compute_log_sumexp=True,
        )
    )
    return out, lse


def ring_sdpa_sequential(partial_queries, partial_keys, partial_values, *, world_size: int = 1, convert_to_fp32: bool = True):
    outputs, lses = [], []

    for rank in range(world_size):
        query, key, value = partial_queries[rank], partial_keys[rank], partial_values[rank]
        next_rank = (rank + 1) % world_size
        prev_out = prev_lse = None

        for i in range(world_size):
            if i > 0:
                key, value = partial_keys[next_rank], partial_values[next_rank]
                next_rank = (next_rank + 1) % world_size

            out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
                torch.ops.aten._scaled_dot_product_cudnn_attention(
                    query=query,
                    key=key,
                    value=value,
                    attn_bias=None,
                    compute_log_sumexp=True,
                )
            )

            if convert_to_fp32:
                out = out.to(torch.float32)
                lse = lse.to(torch.float32)

            # https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795
            lse = lse.unsqueeze(-1)
            if prev_out is not None:
                out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
                lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse)
            prev_out = out
            prev_lse = lse

        out = out.to(query.dtype)
        lse = lse.squeeze(-1)
        outputs.append(out)
        lses.append(lse)

    return outputs, lses


device = "cuda"
dtype = torch.bfloat16
world_size = 4

batch_size = 1
image_sequence_length = 4096
text_sequence_length = 512
sequence_length = image_sequence_length + text_sequence_length
num_attention_heads = 24
attention_head_dim = 128

query = torch.randn(batch_size, num_attention_heads, sequence_length, attention_head_dim, device=device, dtype=dtype)
key = torch.randn(batch_size, num_attention_heads, sequence_length, attention_head_dim, device=device, dtype=dtype)
value = torch.randn(batch_size, num_attention_heads, sequence_length, attention_head_dim, device=device, dtype=dtype)
partial_queries = query.chunk(world_size, dim=2)
partial_keys = key.chunk(world_size, dim=2)
partial_values = value.chunk(world_size, dim=2)

torch_sdpa_out, torch_sdpa_lse = torch_sdpa(query, key, value)
ring_sdpa_out, ring_sdpa_lse = ring_sdpa_sequential(partial_queries, partial_keys, partial_values, world_size=world_size)

all_ring_sdpa_out = torch.cat(ring_sdpa_out, dim=2)
all_ring_sdpa_lse = torch.cat(ring_sdpa_lse, dim=2)

assert torch_sdpa_out.shape == all_ring_sdpa_out.shape, "Output shapes do not match!"
assert torch_sdpa_lse.shape == all_ring_sdpa_lse.shape, "LSE shapes do not match!"
assert torch.allclose(all_ring_sdpa_out, torch_sdpa_out, atol=1e-3, rtol=1e-3), "Outputs do not match!"
assert torch.allclose(all_ring_sdpa_lse, torch_sdpa_lse, atol=1e-3, rtol=1e-3), "LSE values do not match!"
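
The two update lines under the linked comment are the standard log-sum-exp merge of block-wise attention: sigmoid(lse - prev_lse) is exactly the softmax weight of the new block, and prev_lse - logsigmoid(prev_lse - lse) equals log(exp(prev_lse) + exp(lse)). Below is a minimal standalone check of that identity; it is not part of the gist, and the tiny shapes and the attn_with_lse helper are made up for illustration.

import torch

torch.manual_seed(0)
q = torch.randn(1, 1, 8, 16)
k = torch.randn(1, 1, 32, 16)
v = torch.randn(1, 1, 32, 16)
scale = q.shape[-1] ** -0.5

def attn_with_lse(q, k, v):
    # Reference attention that also returns the log-sum-exp of the attention scores.
    scores = (q @ k.transpose(-2, -1)) * scale
    return torch.softmax(scores, dim=-1) @ v, torch.logsumexp(scores, dim=-1)

out_full, lse_full = attn_with_lse(q, k, v)
out1, lse1 = attn_with_lse(q, k[..., :16, :], v[..., :16, :])
out2, lse2 = attn_with_lse(q, k[..., 16:, :], v[..., 16:, :])

# Same update as in ring_sdpa_sequential, with prev_* = block 1 and the incoming block = block 2.
out_merged = out1 - torch.sigmoid(lse2 - lse1).unsqueeze(-1) * (out1 - out2)
lse_merged = lse1 - torch.nn.functional.logsigmoid(lse1 - lse2)

assert torch.allclose(out_merged, out_full, atol=1e-5)
assert torch.allclose(lse_merged, lse_full, atol=1e-5)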

sequential_ulysses.py

import torch

torch.manual_seed(42)


def torch_sdpa(query, key, value):
    out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
        torch.ops.aten._scaled_dot_product_cudnn_attention(
            query=query,
            key=key,
            value=value,
            attn_bias=None,
            compute_log_sumexp=True,
        )
    )
    return out, lse


def ulysses_sdpa_sequential(partial_queries, partial_keys, partial_values, *, world_size: int = 1):
    B, H, S_LOCAL, D = partial_queries[0].shape
    H_LOCAL = H // world_size

    outputs, lses = [], []

    for partials in [partial_queries, partial_keys, partial_values]:
        for rank in range(world_size):
            x_local = partials[rank]
            # (B, H, S // world_size, D) -> (world_size, S // world_size, B, H // world_size, D)
            partials[rank] = x_local.reshape(B, world_size, H_LOCAL, S_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
        x = all_to_all_single_sequential(partials, world_size)
        for rank in range(world_size):
            x_local = x[rank]
            # (S, B, H // world_size, D) -> (B, H // world_size, S, D)
            partials[rank] = x_local.permute(1, 2, 0, 3).contiguous()

    for rank in range(world_size):
        query_local, key_local, value_local = partial_queries[rank], partial_keys[rank], partial_values[rank]
        out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
            torch.ops.aten._scaled_dot_product_cudnn_attention(
                query=query_local,
                key=key_local,
                value=value_local,
                attn_bias=None,
                compute_log_sumexp=True,
            )
        )
        outputs.append(out)
        lses.append(lse)

    for rank in range(world_size):
        out_local = outputs[rank]
        lse_local = lses[rank]
        # (B, H // world_size, S, D) -> (B, H // world_size, world_size, S // world_size, D) -> (world_size, H // world_size, B, S // world_size, D)
        outputs[rank] = out_local.reshape(B, H_LOCAL, world_size, S_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
        lses[rank] = lse_local.reshape(B, H_LOCAL, world_size, S_LOCAL).permute(2, 1, 0, 3).contiguous()
    outputs = all_to_all_single_sequential(outputs, world_size)
    lses = all_to_all_single_sequential(lses, world_size)
    for rank in range(world_size):
        out_local = outputs[rank]
        lse_local = lses[rank]
        # (H, B, S // world_size, D) -> (B, H, S // world_size, D)
        outputs[rank] = out_local.permute(1, 0, 2, 3).contiguous()
        lses[rank] = lse_local.permute(1, 0, 2).contiguous()

    return outputs, lses


def all_to_all_single_sequential(partials, world_size):
    output_partials = []
    for i in range(world_size):
        received_chunks = [p[i] for p in partials]
        output_partials.append(torch.cat(received_chunks, dim=0))
    return output_partials


device = "cuda"
dtype = torch.bfloat16
world_size = 4

batch_size = 1
image_sequence_length = 4096
text_sequence_length = 512
sequence_length = image_sequence_length + text_sequence_length
num_attention_heads = 24
attention_head_dim = 128

query = torch.randn(batch_size, num_attention_heads, sequence_length, attention_head_dim, device=device, dtype=dtype)
key = torch.randn(batch_size, num_attention_heads, sequence_length, attention_head_dim, device=device, dtype=dtype)
value = torch.randn(batch_size, num_attention_heads, sequence_length, attention_head_dim, device=device, dtype=dtype)
partial_queries = list(query.chunk(world_size, dim=2))
partial_keys = list(key.chunk(world_size, dim=2))
partial_values = list(value.chunk(world_size, dim=2))

torch_sdpa_out, torch_sdpa_lse = torch_sdpa(query, key, value)
ulysses_sdpa_out, ulysses_sdpa_lse = ulysses_sdpa_sequential(partial_queries, partial_keys, partial_values, world_size=world_size)

all_ulysses_sdpa_out = torch.cat(ulysses_sdpa_out, dim=2)
all_ulysses_sdpa_lse = torch.cat(ulysses_sdpa_lse, dim=2)

assert torch_sdpa_out.shape == all_ulysses_sdpa_out.shape, "Output shapes do not match!"
assert torch_sdpa_lse.shape == all_ulysses_sdpa_lse.shape, "LSE shapes do not match!"
assert torch.allclose(all_ulysses_sdpa_out, torch_sdpa_out, atol=1e-3, rtol=1e-3), "Outputs do not match!"
assert torch.allclose(all_ulysses_sdpa_lse, torch_sdpa_lse, atol=1e-3, rtol=1e-3), "LSEs do not match!"
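
The reshapes and permutes around all_to_all_single_sequential implement the Ulysses swap: each rank starts with all heads of a sequence shard and ends with a head shard of the full sequence, so the per-rank SDPA call is exact rather than approximate. Below is a small standalone shape trace of that exchange; it is illustrative only, and the tiny B/H/S/D values are made up.

import torch

B, H, S, D, W = 1, 8, 64, 16, 4
x = torch.randn(B, H, S, D)
shards = [t.contiguous() for t in x.chunk(W, dim=2)]  # what each rank starts with: (B, H, S // W, D)

# Pre-all-to-all layout: (W, S // W, B, H // W, D); dim 0 indexes the destination rank (head group).
pre = [s.reshape(B, W, H // W, S // W, D).permute(1, 3, 0, 2, 4).contiguous() for s in shards]
# Emulated all_to_all_single: rank r receives chunk r from every source rank, concatenated along dim 0.
post = [torch.cat([p[r] for p in pre], dim=0) for r in range(W)]
# Post-all-to-all layout: (S, B, H // W, D) -> (B, H // W, S, D).
local = [p.permute(1, 2, 0, 3).contiguous() for p in post]

assert local[0].shape == (B, H // W, S, D)
assert torch.allclose(local[0], x[:, : H // W])  # rank 0 now holds the first head group over the full sequence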

sequential_unified.py

import torch

torch.manual_seed(42)


def torch_sdpa(query, key, value):
    out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
        torch.ops.aten._scaled_dot_product_cudnn_attention(
            query=query,
            key=key,
            value=value,
            attn_bias=None,
            compute_log_sumexp=True,
        )
    )
    return out, lse


def ring_sdpa_sequential(partial_queries, partial_keys, partial_values, *, ring_size: int = 1, convert_to_fp32: bool = True):
    outputs, lses = [], []

    for rank in range(ring_size):
        query, key, value = partial_queries[rank], partial_keys[rank], partial_values[rank]
        next_rank = (rank + 1) % ring_size
        prev_out = prev_lse = None

        for i in range(ring_size):
            if i > 0:
                key, value = partial_keys[next_rank], partial_values[next_rank]
                next_rank = (next_rank + 1) % ring_size

            out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
                torch.ops.aten._scaled_dot_product_cudnn_attention(
                    query=query,
                    key=key,
                    value=value,
                    attn_bias=None,
                    compute_log_sumexp=True,
                )
            )

            if convert_to_fp32:
                out = out.to(torch.float32)
                lse = lse.to(torch.float32)

            # https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795
            lse = lse.unsqueeze(-1)
            if prev_out is not None:
                out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
                lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse)
            prev_out = out
            prev_lse = lse

        out = out.to(query.dtype)
        lse = lse.squeeze(-1)
        outputs.append(out)
        lses.append(lse)

    return outputs, lses


def unified_ulysses_ring_sdpa_sequential(partial_queries, partial_keys, partial_values, *, ulysses_size: int = 1, ring_size: int = 1):
    B, H, S_LOCAL, D = partial_queries[0][0].shape
    H_LOCAL = H // ulysses_size

    outputs, lses = [], []

    for partials in [partial_queries, partial_keys, partial_values]:
        for ring_rank in range(ring_size):
            for rank in range(ulysses_size):
                x_local = partials[ring_rank][rank]
                partials[ring_rank][rank] = x_local.reshape(B, ulysses_size, H_LOCAL, S_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
            x = all_to_all_single_sequential(partials[ring_rank], ulysses_size)
            for rank in range(ulysses_size):
                x_local = x[rank]
                partials[ring_rank][rank] = x_local.permute(1, 2, 0, 3).contiguous()

    partial_queries = [list(x) for x in zip(*partial_queries)]
    partial_keys = [list(x) for x in zip(*partial_keys)]
    partial_values = [list(x) for x in zip(*partial_values)]

    for rank in range(ulysses_size):
        ring_outputs, ring_lses = ring_sdpa_sequential(partial_queries[rank], partial_keys[rank], partial_values[rank], ring_size=ring_size)
        outputs.append(ring_outputs)
        lses.append(ring_lses)

    outputs = [list(x) for x in zip(*outputs)]
    lses = [list(x) for x in zip(*lses)]

    for ring_rank in range(ring_size):
        for rank in range(ulysses_size):
            outputs[ring_rank][rank] = outputs[ring_rank][rank].reshape(B, H_LOCAL, ulysses_size, S_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
            lses[ring_rank][rank] = lses[ring_rank][rank].reshape(B, H_LOCAL, ulysses_size, S_LOCAL).permute(2, 1, 0, 3).contiguous()
        outputs[ring_rank] = all_to_all_single_sequential(outputs[ring_rank], ulysses_size)
        lses[ring_rank] = all_to_all_single_sequential(lses[ring_rank], ulysses_size)
        for rank in range(ulysses_size):
            outputs[ring_rank][rank] = outputs[ring_rank][rank].permute(1, 0, 2, 3).contiguous()
            lses[ring_rank][rank] = lses[ring_rank][rank].permute(1, 0, 2).contiguous()

    return outputs, lses


def all_to_all_single_sequential(partials, world_size):
    output_partials = []
    for i in range(world_size):
        received_chunks = [p[i] for p in partials]
        output_partials.append(torch.cat(received_chunks, dim=0))
    return output_partials


device = "cuda"
dtype = torch.bfloat16
WORLD_SIZE = 8
ulysses_size = 4
ring_size = 2
assert ulysses_size * ring_size == WORLD_SIZE, "ulysses_size * ring_size must equal WORLD_SIZE"

batch_size = 1
image_sequence_length = 4096
text_sequence_length = 512
sequence_length = image_sequence_length + text_sequence_length
num_attention_heads = 24
attention_head_dim = 128

query = torch.randn(batch_size, num_attention_heads, sequence_length, attention_head_dim, device=device, dtype=dtype)
key = torch.randn(batch_size, num_attention_heads, sequence_length, attention_head_dim, device=device, dtype=dtype)
value = torch.randn(batch_size, num_attention_heads, sequence_length, attention_head_dim, device=device, dtype=dtype)

partial_queries = list(query.chunk(WORLD_SIZE, dim=2))
partial_keys = list(key.chunk(WORLD_SIZE, dim=2))
partial_values = list(value.chunk(WORLD_SIZE, dim=2))

# Group the flat list of sequence shards into ring_size groups of ulysses_size, e.g. for WORLD_SIZE=4:
# R=1, U=4 => [[tensor1, tensor2, tensor3, tensor4]]
# R=2, U=2 => [[tensor1, tensor2], [tensor3, tensor4]]
# R=4, U=1 => [[tensor1], [tensor2], [tensor3], [tensor4]]
partial_queries = [partial_queries[i:i + ulysses_size] for i in range(0, WORLD_SIZE, ulysses_size)]
partial_keys = [partial_keys[i:i + ulysses_size] for i in range(0, WORLD_SIZE, ulysses_size)]
partial_values = [partial_values[i:i + ulysses_size] for i in range(0, WORLD_SIZE, ulysses_size)]

torch_sdpa_out, torch_sdpa_lse = torch_sdpa(query, key, value)
unified_sdpa_out, unified_sdpa_lse = unified_ulysses_ring_sdpa_sequential(partial_queries, partial_keys, partial_values, ulysses_size=ulysses_size, ring_size=ring_size)

all_unified_sdpa_out = torch.cat([torch.cat(out, dim=2) for out in unified_sdpa_out], dim=2)
all_unified_sdpa_lse = torch.cat([torch.cat(lse, dim=2) for lse in unified_sdpa_lse], dim=2)

assert torch_sdpa_out.shape == all_unified_sdpa_out.shape, "Output shapes do not match!"
assert torch_sdpa_lse.shape == all_unified_sdpa_lse.shape, "LSE shapes do not match!"
assert torch.allclose(all_unified_sdpa_out, torch_sdpa_out, atol=1e-3, rtol=1e-3), "Outputs do not match!"
assert torch.allclose(all_unified_sdpa_lse, torch_sdpa_lse, atol=1e-3, rtol=1e-3), "LSEs do not match!"
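
In the unified layout, the flat list of WORLD_SIZE sequence shards is grouped so that the outer index is the ring rank and the inner index is the Ulysses rank, i.e. shard ring_rank * ulysses_size + ulysses_rank sits at that grid coordinate. A tiny illustration of the grouping used above (the values match the 8-shard configuration in this script; the loop is only for printing):

WORLD_SIZE, ulysses_size, ring_size = 8, 4, 2
shards = list(range(WORLD_SIZE))  # stand-ins for the 8 sequence chunks
grid = [shards[i:i + ulysses_size] for i in range(0, WORLD_SIZE, ulysses_size)]
for ring_rank, row in enumerate(grid):
    for ulysses_rank, shard in enumerate(row):
        print(f"ring_rank={ring_rank}, ulysses_rank={ulysses_rank} -> sequence chunk {shard}")
# ring_rank=0 holds chunks 0..3 and ring_rank=1 holds chunks 4..7.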

templated_benchmark.py

import argparse
from dataclasses import dataclass
from typing import Callable, Literal, List

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
from torch.distributed import DeviceMesh


@dataclass
class ContextParallelOptions:
    mode: Literal["ring", "ulysses", "unified"] = "ring"
    ring_mesh: DeviceMesh | None = None
    ulysses_mesh: DeviceMesh | None = None
    convert_to_fp32: bool = True
    op: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor]] | None = None


cp_options = ContextParallelOptions()


def _templated_ring_attention(query, key, value):
    rank = cp_options.ring_mesh.get_rank()
    world_size = cp_options.ring_mesh.size()

    if world_size == 1:
        return cp_options.op(query, key, value)

    next_rank = (rank + 1) % world_size
    prev_out = prev_lse = None

    kv_buffer = torch.cat([key.flatten(), value.flatten()]).contiguous()
    kv_buffer = funcol.all_gather_tensor(kv_buffer, gather_dim=0, group=cp_options.ring_mesh.get_group())
    kv_buffer = kv_buffer.chunk(world_size)

    for i in range(world_size):
        if i > 0:
            kv = kv_buffer[next_rank]
            key = kv[:key.numel()].reshape_as(key)
            value = kv[key.numel():].reshape_as(value)
            next_rank = (next_rank + 1) % world_size

        out, lse = cp_options.op(query, key, value)

        if cp_options.convert_to_fp32:
            out = out.to(torch.float32)
            lse = lse.to(torch.float32)

        lse = lse.unsqueeze(-1)
        if prev_out is not None:
            out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
            lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse)
        prev_out = out
        prev_lse = lse

    out = out.to(query.dtype)
    lse = lse.squeeze(-1)
    return out, lse


def _templated_ulysses_attention(query, key, value):
    world_size = cp_options.ulysses_mesh.size()
    group = cp_options.ulysses_mesh.get_group()

    if world_size == 1:
        return cp_options.op(query, key, value)

    B, H, S_LOCAL, D = query.shape
    H_LOCAL = H // world_size
    query, key, value = (
        x.reshape(B, world_size, H_LOCAL, S_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
        for x in (query, key, value)
    )
    query, key, value = (
        funcol.all_to_all_single(x, None, None, group=group)
        for x in (query, key, value)
    )
    query, key, value = (
        x.flatten(0, 1).permute(1, 2, 0, 3).contiguous()
        for x in (query, key, value)
    )
    out, lse = cp_options.op(query, key, value)
    out = out.reshape(B, H_LOCAL, world_size, S_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
    lse = lse.reshape(B, H_LOCAL, world_size, S_LOCAL).permute(2, 1, 0, 3).contiguous()
    out = funcol.all_to_all_single(out, None, None, group=group).wait()
    lse = funcol.all_to_all_single(lse, None, None, group=group).wait()
    out = out.flatten(0, 1).permute(1, 0, 2, 3).contiguous()
    lse = lse.flatten(0, 1).permute(1, 0, 2).contiguous()
    return out, lse


def _templated_unified_attention(query, key, value):
    ring_size = cp_options.ring_mesh.size()
    ulysses_size = cp_options.ulysses_mesh.size()
    ulysses_group = cp_options.ulysses_mesh.get_group()
    world_size = ring_size * ulysses_size

    if world_size == 1:
        return cp_options.op(query, key, value)

    B, H, S_LOCAL, D = query.shape
    H_LOCAL = H // ulysses_size
    query, key, value = (
        x.reshape(B, ulysses_size, H_LOCAL, S_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
        for x in (query, key, value)
    )
    query, key, value = (
        funcol.all_to_all_single(x, None, None, group=ulysses_group)
        for x in (query, key, value)
    )
    query, key, value = (
        x.flatten(0, 1).permute(1, 2, 0, 3).contiguous()
        for x in (query, key, value)
    )
    out, lse = _templated_ring_attention(query, key, value)
    out = out.reshape(B, H_LOCAL, ulysses_size, S_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
    lse = lse.reshape(B, H_LOCAL, ulysses_size, S_LOCAL).permute(2, 1, 0, 3).contiguous()
    out = funcol.all_to_all_single(out, None, None, group=ulysses_group).wait()
    lse = funcol.all_to_all_single(lse, None, None, group=ulysses_group).wait()
    out = out.flatten(0, 1).permute(1, 0, 2, 3).contiguous()
    lse = lse.flatten(0, 1).permute(1, 0, 2).contiguous()
    return out, lse


def torch_cudnn_attention(query, key, value):
    out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
        torch.ops.aten._scaled_dot_product_cudnn_attention(
            query=query,
            key=key,
            value=value,
            attn_bias=None,
            compute_log_sumexp=True,
        )
    )
    return out, lse


def torch_flash_attention(query, key, value):
    out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
        torch.ops.aten._scaled_dot_product_flash_attention(
            query=query,
            key=key,
            value=value,
        )
    )
    return out, lse


OPS = {
    "cudnn": torch_cudnn_attention,
    "flash": torch_flash_attention,
}
WORLD_SIZE = -1
RANK = -1


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--ring_degree", type=int, default=1)
    parser.add_argument("--ulysses_degree", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--num_heads", type=int, default=24)
    parser.add_argument("--head_dim", type=int, default=128)
    parser.add_argument("--seq_lens", type=int, nargs="+", default=[512, 1024, 2048, 4096, 4224, 4352, 4480, 4608, 8192])
    parser.add_argument(
        "--ops",
        type=str,
        nargs="+",
        choices=list(OPS.keys()),
        default=list(OPS.keys()),
    )
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()
    return args


def main(
    ring_degree: int,
    ulysses_degree: int,
    batch_size: int,
    num_heads: int,
    head_dim: int,
    seq_lens: List[int],
    ops: List[str],
    seed: int,
):
    global cp_options, WORLD_SIZE, RANK

    mesh_names = ["ring", "ulysses"]
    mesh_dims = [ring_degree, ulysses_degree]
    mesh = dist.device_mesh.init_device_mesh("cuda", mesh_dims, mesh_dim_names=mesh_names)
    cp_options.ring_mesh = mesh["ring"]
    cp_options.ulysses_mesh = mesh["ulysses"]
    cp_options.convert_to_fp32 = True
    cp_attention = None
    num_warmups = 5
    num_repeats = 10
    device = torch.device("cuda")
    dtype = torch.bfloat16

    if ring_degree > 1 and ulysses_degree > 1:
        cp_options.mode = "unified"
        cp_attention = _templated_unified_attention
    elif ulysses_degree > 1:
        cp_options.mode = "ulysses"
        cp_attention = _templated_ulysses_attention
    else:
        cp_options.mode = "ring"
        cp_attention = _templated_ring_attention

    results = {}
    for op_name in ops:
        op = OPS[op_name]
        cp_options.op = op
        results[op_name] = {}

        for seq_len in seq_lens:
            shape = (batch_size, num_heads, seq_len, head_dim)
            query = torch.randn(shape, device=device, dtype=dtype)
            key = torch.randn(shape, device=device, dtype=dtype)
            value = torch.randn(shape, device=device, dtype=dtype)

            dist.broadcast(query, src=0)
            dist.broadcast(key, src=0)
            dist.broadcast(value, src=0)
            dist.barrier()
            torch.cuda.synchronize()

            reference_out, reference_lse = torch_cudnn_attention(query, key, value)
            query, key, value = (x.chunk(WORLD_SIZE, dim=2)[RANK].contiguous() for x in (query, key, value))

            for _ in range(num_warmups):
                if WORLD_SIZE == 1:
                    out, lse = op(query, key, value)
                else:
                    out, lse = cp_attention(query, key, value)
                    # Gather the sequence-sharded results so they can be compared against the full reference.
                    out = funcol.all_gather_tensor(out, gather_dim=2, group=mesh._flatten().get_group())
                    lse = funcol.all_gather_tensor(lse, gather_dim=2, group=mesh._flatten().get_group())
            torch.cuda.synchronize()

            diff = out - reference_out
            absdiff = torch.abs(diff)
            absmax = torch.max(absdiff)
            mae = torch.mean(absdiff)
            mse = torch.mean(diff * diff)
            if RANK == 0:
                print(f"op: {op_name}, seq_len: {seq_len}, absmax: {absmax:.5f}, mae: {mae:.5f}, mse: {mse:.5f}")

            # if not torch.allclose(out, reference_out, atol=1e-2, rtol=1e-2):
            #     raise ValueError(f"Output mismatch for op: {op_name}, seq_len: {seq_len}")
            # if not torch.allclose(lse, reference_lse, atol=1e-2, rtol=1e-2):
            #     raise ValueError(f"LSE mismatch for op: {op_name}, seq_len: {seq_len}")
            dist.barrier()

            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            start_event.record()
            for _ in range(num_repeats):
                if WORLD_SIZE == 1:
                    out, lse = op(query, key, value)
                else:
                    out, lse = cp_attention(query, key, value)
            end_event.record()
            torch.cuda.synchronize()
            dist.barrier()
            elapsed_time = start_event.elapsed_time(end_event) / num_repeats
            results[op_name][seq_len] = elapsed_time

    if RANK == 0:
        print("Benchmark results:")
        for op_name, seq_times in results.items():
            print(f"\n\n===== op: {op_name} =====")
            for seq_len, time in seq_times.items():
                print(f"  {seq_len=}, {time:.5f} ms")


if __name__ == "__main__":
    args = get_args()

    torch.manual_seed(args.seed)

    try:
        dist.init_process_group(backend="nccl")
        WORLD_SIZE = dist.get_world_size()
        RANK = dist.get_rank()
        torch.cuda.set_device(RANK)

        if args.ring_degree * args.ulysses_degree != WORLD_SIZE:
            raise ValueError(
                f"ring_degree * ulysses_degree must equal world size, got {args.ring_degree} * {args.ulysses_degree} != {WORLD_SIZE}"
            )

        main(
            ring_degree=args.ring_degree,
            ulysses_degree=args.ulysses_degree,
            batch_size=args.batch_size,
            num_heads=args.num_heads,
            head_dim=args.head_dim,
            seq_lens=args.seq_lens,
            ops=args.ops,
            seed=args.seed,
        )
    finally:
        dist.destroy_process_group()
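
Since the script initializes an NCCL process group and a 2-D device mesh, it is meant to be launched with torchrun rather than plain python. A typical single-node invocation might look like this (the 8-GPU count is an assumption; ring_degree * ulysses_degree must equal the number of processes):

torchrun --nproc_per_node=8 templated_benchmark.py --ring_degree 2 --ulysses_degree 4 --ops cudnn flash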