@nelhage
Last active January 22, 2022 01:25

Revisions

  1. nelhage revised this gist Jan 22, 2022. 1 changed file with 123 additions and 0 deletions.
    torch_leak_irecv.py:
    #!/usr/bin/env python
    import os
    import time
    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp

    INTERVAL = 1
    COMM_SIZE = (10,)


    def run(rank, size):
        torch.cuda.set_device(rank)

        if rank == 0:
            s1 = torch.cuda.Stream()
            s2 = torch.cuda.Stream()

            dist.barrier()
            torch.cuda.synchronize()

            buf = torch.empty(COMM_SIZE, dtype=torch.float, device="cuda")
            # Do a warmup comm; the first comm seems to block until
            # completion whether or not we do it async.
            dist.irecv(buf, src=1).wait()

            with torch.cuda.stream(s1):
                # Allocate a tensor whose backing block comes from stream
                # `s1`.
                buf = torch.empty(COMM_SIZE, dtype=torch.float, device="cuda")

                # Now use it in a NCCL
                # comm. ProcessGroupNCCL::pointToPoint will call
                # `recordStream` to record `buf` as having been used on
                # the NCCL comms stream.
                #
                # This comm will be fast since rank 1 sends promptly.
                dist.irecv(buf, src=1).wait()

                # Now we start a long-running comm. Rank 1 will sleep
                # before sending this tensor, so this results in a
                # long-running op on the NCCL CUDA stream.
                handle = dist.irecv(
                    torch.empty(COMM_SIZE, dtype=torch.float, device="cuda"), src=1
                )

                # Now we free `buf`. This eventually ends up in
                # DeviceCachingAllocator::free; it notices that the block
                # has non-empty stream uses, and so queues an event on the
                # NCCL comms stream to make sure the tensor is actually
                # done being used before the block is released for reuse.
                buf = None

            print("[0] Queued the receive.")

            t = time.time()

            # Now we do some concurrent work while the comms happen in the
            # background.
            while not handle.is_completed():
                # We allocate a tensor, and then we `record_stream` to
                # make the allocator record it as having stream_uses. This
                # is the simplest demo for a reproducer; in real code this
                # can happen in autograd, by other comms, or a handful of
                # other ways.
                data = torch.randn((1024,), device="cuda")
                data.record_stream(s2)

                # Now we free `data`. Since it has `stream_uses`, the
                # allocator enqueues an event and marks the underlying
                # buffer for later free.
                #
                # However, `process_events` will walk the event list in
                # order, and stop at the first event which isn't
                # ready. Since we queued an event on the NCCL comms stream
                # up above, it will always stop there, and no memory will
                # be released until the comms complete.
                data = None

                now = time.time()
                if (now - t) > INTERVAL:
                    # Dump memory stats every second (see the counter
                    # note after the revisions)
                    t = now
                    print(torch.cuda.memory_summary(abbreviated=True))

            handle.wait()
        else:
            dist.barrier()

            buf = torch.randn(COMM_SIZE, dtype=torch.float, device="cuda")
            # One warm-up send
            dist.isend(buf, dst=0).wait()

            # One send for the first (fast) `irecv` in rank 0
            dist.isend(buf, dst=0).wait()

            # Now we sleep 10s and then do a final isend, to cause the
            # final `irecv` in rank 0 to be long-running.
            print("[1] Sleeping...")
            time.sleep(10)
            dist.isend(buf, dst=0).wait()
            print("[1] Sent a tensor")


    def init_process(rank, size, fn, backend="nccl"):
        """ Initialize the distributed environment. """
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = "29500"
        dist.init_process_group(backend, rank=rank, world_size=size)
        fn(rank, size)


    if __name__ == "__main__":
        size = 2
        processes = []
        mp.set_start_method("spawn")
        for rank in range(size):
            p = mp.Process(target=init_process, args=(rank, size, run))
            p.start()
            processes.append(p)

        for p in processes:
            p.join()
  2. nelhage created this gist Jan 21, 2022.
    torch_leak.py:
    #!/usr/bin/env python
    import os
    import time
    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp

    INTERVAL = 1
    COMM_SIZE = (10,)


    def run(rank, size):
        torch.cuda.set_device(rank)

        pg = torch.distributed.new_group(list(range(size)), backend="nccl")

        if rank == 0:
            s1 = torch.cuda.Stream()
            s2 = torch.cuda.Stream()

            dist.barrier()
            torch.cuda.synchronize()

            outputs = torch.zeros((size, *COMM_SIZE), dtype=torch.float, device="cuda")
            mine = torch.randn(COMM_SIZE, dtype=torch.float, device="cuda")

            # Do an allgather to warm up comms. Somehow the first
            # all-gather we do isn't actually async and waits for the comm
            # to complete.
            pg._allgather_base(outputs, mine).wait()

            with torch.cuda.stream(s1):
                # Allocate a tensor whose backing block comes from stream
                # `s1`.
                mine = torch.randn(COMM_SIZE, dtype=torch.float, device="cuda")

                # When we execute the _allgather_base,
                # ProcessGroupNCCL::collective calls `recordStream` to
                # record `mine` as having been used on the NCCL comms
                # stream.
                handle = pg._allgather_base(outputs, mine)

                # Now we free `mine`. This ends up in
                # DeviceCachingAllocator::free, which notices that the
                # block has non-empty stream uses, and queues an event on
                # the NCCL comms stream.
                #
                # Note that the bug wouldn't show up with point-to-point
                # comms, because they hold on to their input or output
                # tensors in WorkNCCL::outputs_, and so the tensor would
                # not actually be freed here.
                mine = None

            print("[0] Queued the receive.")

            t = time.time()

            # Now we do some concurrent work while the comms happen in the
            # background.
            while not handle.is_completed():
                # We allocate a tensor, and then we `record_stream` to
                # make the allocator record it as having stream_uses. This
                # is the simplest demo for a reproducer; in real code this
                # can happen in autograd, by other comms, or a handful of
                # other ways.
                data = torch.randn((1024,), device="cuda")
                data.record_stream(s2)

                # Now we free `data`. Since it has `stream_uses`, the
                # allocator enqueues an event and marks the underlying
                # buffer for later free.
                #
                # However, `process_events` will walk the event list in
                # order, and stop at the first event which isn't
                # ready. Since we queued an event on the NCCL comms stream
                # up above, it will always stop there, and no memory will
                # be released until the comms complete.
                data = None

                now = time.time()
                if (now - t) > INTERVAL:
                    # Dump memory stats every second (see the counter
                    # note after the revisions)
                    t = now
                    print(torch.cuda.memory_summary(abbreviated=True))

            handle.wait()
        else:
            dist.barrier()
            outputs = torch.zeros((size, *COMM_SIZE), dtype=torch.float, device="cuda")
            mine = torch.randn(COMM_SIZE, dtype=torch.float, device="cuda")

            pg._allgather_base(outputs, mine)

            # On rank 1, we just sleep 10s and then do an all-gather, to
            # achieve the effect of a long-running op on the NCCL stream
            # in rank 0.
            print("[1] Sleeping...")
            time.sleep(10)
            pg._allgather_base(outputs, mine)
            print("[1] Sent a tensor")


    def init_process(rank, size, fn, backend="nccl"):
        """ Initialize the distributed environment. """
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = "29500"
        dist.init_process_group(backend, rank=rank, world_size=size)
        fn(rank, size)


    if __name__ == "__main__":
        size = 2
        processes = []
        mp.set_start_method("spawn")
        for rank in range(size):
            p = mp.Process(target=init_process, args=(rank, size, run))
            p.start()
            processes.append(p)

        for p in processes:
            p.join()
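
Both revisions above print torch.cuda.memory_summary(abbreviated=True) once a second; the symptom of the bug is that the allocator's active and reserved byte counts keep climbing while the long-running NCCL op is in flight, even though Python holds no references to the freed tensors. As a side note that is not part of the gist, here is a minimal sketch of watching those counters directly via torch.cuda.memory_stats(); the helper name is made up, and the exact key names are an assumption about your PyTorch build (they are the standard memory_stats() keys on recent versions).

#!/usr/bin/env python
# Hedged sketch, not part of nelhage's reproducers: watch the specific
# allocator counters that memory_summary() aggregates. The expectation
# (an assumption about the caching allocator's accounting) is that
# "allocated" drops as soon as Python frees a tensor, while "active" and
# "reserved" still include blocks whose free is gated on a pending CUDA
# event, such as the one queued on the NCCL comms stream above.
import torch


def allocator_counters():
    stats = torch.cuda.memory_stats()
    return {
        "allocated_bytes": stats["allocated_bytes.all.current"],
        "active_bytes": stats["active_bytes.all.current"],
        "reserved_bytes": stats["reserved_bytes.all.current"],
    }


if __name__ == "__main__":
    x = torch.randn((1 << 20,), device="cuda")
    print("with x alive:   ", allocator_counters())
    x = None
    print("after freeing x:", allocator_counters())

In this standalone form the two snapshots should look essentially the same, since nothing pins the freed block; the interesting divergence between allocated and active only appears while an event like the one queued on the NCCL stream in the reproducers is still pending.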