Skip to content

Instantly share code, notes, and snippets.

@chrjxj
Forked from sgraaf/ddp_example.py
Created July 24, 2021 02:38
Show Gist options
  • Select an option

  • Save chrjxj/7d1a00eaa2754a5713760d463e99524d to your computer and use it in GitHub Desktop.

Select an option

Save chrjxj/7d1a00eaa2754a5713760d463e99524d to your computer and use it in GitHub Desktop.

Revisions

  1. @sgraaf sgraaf revised this gist Aug 5, 2020. 1 changed file with 1 addition and 0 deletions.
    1 change: 1 addition & 0 deletions ddp_example.py
    Original file line number Diff line number Diff line change
    @@ -32,6 +32,7 @@ def main():

    # initialize PyTorch distributed using environment variables (you could also do this more explicitly by specifying `rank` and `world_size`, but I find using environment variables makes it so that you can easily use the same script on different machines)
    dist.init_process_group(backend='nccl', init_method='env://')
    torch.cuda.set_device(args.local_rank)

    # set the seed for all GPUs (also make sure to set the seed for random, numpy, etc.)
    torch.cuda.manual_seed_all(SEED)
  2. @sgraaf sgraaf revised this gist Aug 5, 2020. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion ddp_example.py
    Original file line number Diff line number Diff line change
    @@ -24,7 +24,7 @@ def main():
    parser.add_argument('--local_rank', type=int, default=-1, metavar='N', help='Local process rank.') # you need this argument in your scripts for DDP to work
    args = parser.parse_args()

    # keep track of whether the current process is the `master` process (totally optional, but I find it useful)
    # keep track of whether the current process is the `master` process (totally optional, but I find it useful for data loading, logging, etc.)
    args.is_master = args.local_rank == 0

    # set the device
  3. @sgraaf sgraaf created this gist Aug 5, 2020.
    87 changes: 87 additions & 0 deletions ddp_example.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,87 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from argparse import ArgumentParser

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from transformers import BertForMaskedLM

# Training hyperparameters for the example run.
SEED = 42        # RNG seed used for torch.cuda.manual_seed_all in main()
BATCH_SIZE = 8   # per-process batch size (effective batch = BATCH_SIZE * world size)
NUM_EPOCHS = 3   # number of full passes over the dataset

    class YourDataset(Dataset):

    def __init__(self):
    pass


    def main():
    parser = ArgumentParser('DDP usage example')
    parser.add_argument('--local_rank', type=int, default=-1, metavar='N', help='Local process rank.') # you need this argument in your scripts for DDP to work
    args = parser.parse_args()

    # keep track of whether the current process is the `master` process (totally optional, but I find it useful)
    args.is_master = args.local_rank == 0

    # set the device
    args.device = torch.cuda.device(args.local_rank)

    # initialize PyTorch distributed using environment variables (you could also do this more explicitly by specifying `rank` and `world_size`, but I find using environment variables makes it so that you can easily use the same script on different machines)
    dist.init_process_group(backend='nccl', init_method='env://')

    # set the seed for all GPUs (also make sure to set the seed for random, numpy, etc.)
    torch.cuda.manual_seed_all(SEED)

    # initialize your model (BERT in this example)
    model = BertForMaskedLM.from_pretrained('bert-base-uncased')

    # send your model to GPU
    model = model.to(device)

    # initialize distributed data parallel (DDP)
    model = DDP(
    model,
    device_ids=[args.local_rank],
    output_device=args.local_rank
    )

    # initialize your dataset
    dataset = YourDataset()

    # initialize the DistributedSampler
    sampler = DistributedSampler(dataset)

    # initialize the dataloader
    dataloader = DataLoader(
    dataset=dataset,
    sampler=sampler,
    batch_size=BATCH_SIZE
    )

    # start your training!
    for epoch in range(NUM_EPOCHS):
    # put model in train mode
    model.train()

    # let all processes sync up before starting with a new epoch of training
    dist.barrier()

    for step, batch in enumerate(dataloader):
    # send batch to device
    batch = tuple(t.to(args.device) for t in batch)

    # forward pass
    outputs = model(*batch)

    # compute loss
    loss = outputs[0]

    # etc.


    if __name__ == '__main__':
    main()
    17 changes: 17 additions & 0 deletions ddp_example.sh
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,17 @@
    #!/bin/bash

    # this example uses a single node (`NUM_NODES=1`) w/ 4 GPUs (`NUM_GPUS_PER_NODE=4`)
    export NUM_NODES=1
    export NUM_GPUS_PER_NODE=4
    export NODE_RANK=0
    export WORLD_SIZE=$(($NUM_NODES * $NUM_GPUS_PER_NODE))

    # launch your script w/ `torch.distributed.launch`
    python -m torch.distributed.launch \
    --nproc_per_node=$NUM_GPUS_PER_NODE \
    --nnodes=$NUM_NODES \
    --node_rank $NODE_RANK \
    ddp_example.py \
    # include any arguments to your script, e.g:
    # --seed 42
    # etc.