Last active
January 22, 2022 01:25
-
-
Save nelhage/c5f9b2831014963a4b35d2e0311f93c7 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python
import os
import time
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
# How often (in seconds) rank 0 dumps torch.cuda.memory_summary() while
# polling for the collective to complete.
INTERVAL = 1
# Shape of each per-rank tensor fed into the all-gather.
COMM_SIZE = (10,)
def run(rank, size):
    """Per-rank entry point for the allocator-stall reproducer.

    Rank 0 launches an async all-gather on stream ``s1`` whose peer
    (rank 1) deliberately sleeps 10s before joining, then allocates and
    frees small tensors while polling for completion. Per the inline
    notes below, freed blocks with recorded stream uses sit behind the
    pending NCCL-stream event, so no memory is released until the
    collective finishes — visible in the printed memory summaries.

    Args:
        rank: this process's rank; rank 0 is the victim, all other
            ranks take the slow-peer path.
        size: world size (total number of participating processes).
    """
    torch.cuda.set_device(rank)
    pg = torch.distributed.new_group(list(range(size)), backend="nccl")
    if rank == 0:
        s1 = torch.cuda.Stream()
        s2 = torch.cuda.Stream()
        dist.barrier()
        torch.cuda.synchronize()
        outputs = torch.zeros((size, *COMM_SIZE), dtype=torch.float, device="cuda")
        mine = torch.randn(COMM_SIZE, dtype=torch.float, device="cuda")
        # Do an allgather to warm up comms. Somehow the first
        # all-gather we do isn't actually async and waits for the comm
        # to complete.
        pg._allgather_base(outputs, mine).wait()
        with torch.cuda.stream(s1):
            # Allocate a tensor whose backing block comes from stream
            # `s1`.
            mine = torch.randn(COMM_SIZE, dtype=torch.float, device="cuda")
            # When we execute the _allgather_base,
            # ProcessGroupNCCL::collective calls `recordStream` to
            # record `mine` as having been used on the NCCL comms
            # stream.
            handle = pg._allgather_base(outputs, mine)
            # Now we free `mine`. This ends up in
            # DeviceCachingAllocator::free, which notices that the
            # block has non-empty stream uses, and queues an event on
            # the NCCL comms stream.
            #
            # Note that the bug wouldn't show up with point-to-point
            # comms, because they hold on to their input or output
            # tensors in WorkNCCL::outputs_, and so the tensor would
            # not actually be freed here.
            mine = None
        print("[0] Queued the receive.")
        t = time.time()
        # Now we do some concurrent work while the comms happen in the
        # background.
        while not handle.is_completed():
            # We allocate a tensor, and then we `record_stream` to
            # make the allocator record it as having stream_uses. This
            # is the simplest demo for a reproducer; in real code this
            # can happen in autograd, by other comms, or a handful of
            # other ways.
            data = torch.randn((1024,), device="cuda")
            data.record_stream(s2)
            # Now we free `data`. Since it has `stream_uses`, the
            # allocator enqueues an event and marks the underlying
            # buffer for later free.
            #
            # However, `process_events` will walk the event list in
            # order, and stop at the first event which isn't
            # ready. Since we queued an event on the NCCL comms stream
            # up above, it will always stop there, and no memory will
            # be released until the comms complete.
            data = None
            now = time.time()
            if (now - t) > INTERVAL:
                # Dump memory stats every second
                t = now
                print(torch.cuda.memory_summary(abbreviated=True))
        handle.wait()
    else:
        dist.barrier()
        outputs = torch.zeros((size, *COMM_SIZE), dtype=torch.float, device="cuda")
        mine = torch.randn(COMM_SIZE, dtype=torch.float, device="cuda")
        # Matches rank 0's warm-up all-gather above, so both ranks
        # participate in the warm-up collective.
        pg._allgather_base(outputs, mine)
        # On rank 1, we just sleep 10s and then do an all-gather, to
        # achieve the effect of a long-running op on the NCCL stream
        # in rank 0.
        print("[1] Sleeping...")
        time.sleep(10)
        pg._allgather_base(outputs, mine)
        print("[1] Sent a tensor")
def init_process(rank, size, fn, backend="nccl"):
    """Join the default process group, then hand control to *fn*.

    Args:
        rank: rank of this worker within the group.
        size: world size (total number of workers).
        fn: callable invoked as ``fn(rank, size)`` once the group is up.
        backend: torch.distributed backend name (defaults to ``"nccl"``).
    """
    # Point every worker at the same local rendezvous endpoint.
    rendezvous = {"MASTER_ADDR": "127.0.0.1", "MASTER_PORT": "29500"}
    os.environ.update(rendezvous)
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size)
if __name__ == "__main__":
    # Two ranks: rank 0 exhibits the stall, rank 1 plays the slow peer.
    world_size = 2
    workers = []
    mp.set_start_method("spawn")
    for worker_rank in range(world_size):
        proc = mp.Process(target=init_process, args=(worker_rank, world_size, run))
        proc.start()
        workers.append(proc)
    # Block until every rank has exited before the parent returns.
    for proc in workers:
        proc.join()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment