Skip to content

Instantly share code, notes, and snippets.

@jamesdellinger
Created June 16, 2019 00:58
Show Gist options
  • Select an option

  • Save jamesdellinger/c4a7aa588f2971a89c01484bb680c4e5 to your computer and use it in GitHub Desktop.

Select an option

Save jamesdellinger/c4a7aa588f2971a89c01484bb680c4e5 to your computer and use it in GitHub Desktop.
DALI train and validation pipelines for Imagenette
from nvidia.dali.pipeline import Pipeline
import nvidia.dali.ops as ops
import nvidia.dali.types as types
from torch import Tensor
from fastai import datasets
# Side length (pixels) of the square images produced by both pipelines, and
# the default mini-batch size they are constructed with.
im_size, bs = 128, 64
# fastai's untar_data fetches/extracts the Imagenette archive. The result is
# used as a Path below (path/'train', path/'val'). Note this runs at import time.
path = datasets.untar_data('https://s3.amazonaws.com/fast-ai-imageclas/imagenette')
# Per-channel mean / std of Imagenette pixels on a 0-1 scale; the pipelines
# multiply these by 255 before handing them to DALI's normalization ops.
imagenette_means, imagenette_std_devs = (
    Tensor([0.4879, 0.4740, 0.4307]),
    Tensor([0.2814, 0.2831, 0.3079]),
)
class ImagenetteTrainPipeline(Pipeline):
    """DALI training pipeline for Imagenette.

    Reads JPEGs from ``path/'train'`` with shuffling, then on the GPU:
    random-area / random-aspect-ratio crop during decoding, resize to
    ``im_size`` x ``im_size``, random vertical flip (p=0.1), a set of
    randomly-masked fixed affine warps, random small rotation (p=0.075), and
    finally a random horizontal mirror (p=0.5) fused with per-channel
    normalization and an NHWC -> NCHW layout change.

    ``define_graph`` returns ``(images, labels)``.
    """

    def __init__(self, batch_size=bs, num_threads=8, device_id=0):
        super(ImagenetteTrainPipeline, self).__init__(batch_size, num_threads, device_id, seed=42)
        self.input = ops.FileReader(file_root=path/'train', random_shuffle=True)
        # Randomly crop (8%-100% of the area, aspect ratio 0.75-1.33) while decoding.
        self.decode = ops.nvJPEGDecoderRandomCrop(device='mixed', output_type=types.RGB,
                                                  random_area=[0.08, 1.0],
                                                  random_aspect_ratio=[0.75, 1.333333],
                                                  num_attempts=100)
        self.resize = ops.Resize(device='gpu', resize_x=im_size, resize_y=im_size,
                                 interp_type=types.INTERP_NN)
        # Flip vertically with probability 0.1 (coin drawn in define_graph).
        self.vert_flip = ops.Flip(device='gpu', horizontal=0, interp_type=types.INTERP_NN)
        self.vert_coin = ops.CoinFlip(probability=0.1)
        # Workaround for DALI not supporting random affine transforms: build a
        # fixed set of affine matrices once, then apply each one to a random
        # subset of samples via its `mask` argument. With 7 warps at p=0.025
        # each, roughly 16% of samples receive at least one warp.
        # NOTE(review): get_affine_tfm is not defined or imported in this file;
        # presumably it returns a matrix accepted by WarpAffine -- confirm.
        self.num_warp_tfms = 7
        self.affine_tfms = [get_affine_tfm(im_size, 0.1) for _ in range(self.num_warp_tfms)]
        self.warp_tfms = [ops.WarpAffine(device='gpu', matrix=m, interp_type=types.INTERP_NN)
                          for m in self.affine_tfms]
        # Called once per warp inside define_graph (not here) so the masks are
        # part of the graph like every other random coin in this pipeline.
        self.warp_prob = ops.CoinFlip(probability=0.025)
        # Rotate by an angle drawn uniformly from [-7, 7] with probability 0.075.
        self.rotate = ops.Rotate(device='gpu', interp_type=types.INTERP_NN)
        self.rotate_range = ops.Uniform(range=(-7, 7))
        self.rotate_coin = ops.CoinFlip(probability=0.075)
        # Flip horizontally with prob 0.5, then convert NHWC -> NCHW and
        # normalize. .tolist() hands DALI plain Python floats (0-255 scale).
        self.cmnp = ops.CropMirrorNormalize(device='gpu',
                                            output_dtype=types.FLOAT,
                                            output_layout=types.NCHW,
                                            crop=(im_size, im_size),
                                            image_type=types.RGB,
                                            mean=(imagenette_means*255).tolist(),
                                            std=(imagenette_std_devs*255).tolist())
        self.mirror_coin = ops.CoinFlip(probability=0.5)

    def define_graph(self):
        """Build the training graph; returns ``(images, labels)``."""
        # Draw all random values inside the graph definition.
        prob_vert_flip = self.vert_coin()
        prob_rotate = self.rotate_coin()
        prob_mirror = self.mirror_coin()
        angle_range = self.rotate_range()
        # One independent mask per affine warp (previously drawn in __init__,
        # which the original TODO flagged as wrong).
        warp_masks = [self.warp_prob() for _ in self.warp_tfms]
        # Pipeline operations run in the order below.
        self.jpegs, self.labels = self.input(name='r')
        images = self.decode(self.jpegs)
        images = self.resize(images)
        images = self.vert_flip(images, vertical=prob_vert_flip)
        for tfm, mask in zip(self.warp_tfms, warp_masks):
            images = tfm(images, mask=mask)
        images = self.rotate(images, angle=angle_range, mask=prob_rotate)
        images = self.cmnp(images, mirror=prob_mirror)
        return (images, self.labels)
class ImagenetteValPipeline(Pipeline):
    """DALI validation pipeline for Imagenette.

    Reads JPEGs from ``path/'val'`` (``random_shuffle`` is not set), decodes
    them on the GPU, resizes each to ``im_size`` x ``im_size`` and normalizes
    per channel while converting NHWC -> NCHW. No random augmentation.

    ``define_graph`` returns ``(images, labels)``.
    """

    def __init__(self, batch_size=bs, num_threads=8, device_id=0):
        super(ImagenetteValPipeline, self).__init__(batch_size, num_threads, device_id, seed=42)
        self.input = ops.FileReader(file_root=path/'val')
        self.decode = ops.nvJPEGDecoder(device='mixed', output_type=types.RGB)
        # No center crop: the whole image is resized to im_size x im_size, so
        # non-square images are squashed (aspect ratio is not preserved).
        self.resize = ops.Resize(device='gpu', resize_x=im_size, resize_y=im_size,
                                 interp_type=types.INTERP_NN)
        # Convert tensor format from NHWC to NCHW and normalize. .tolist()
        # hands DALI plain Python floats on the 0-255 pixel scale.
        self.normperm = ops.NormalizePermute(device="gpu",
                                             height=im_size,
                                             width=im_size,
                                             output_dtype=types.FLOAT,
                                             image_type=types.RGB,
                                             mean=(imagenette_means*255).tolist(),
                                             std=(imagenette_std_devs*255).tolist())

    def define_graph(self):
        """Build the validation graph; returns ``(images, labels)``."""
        self.jpegs, self.labels = self.input(name='r')
        images = self.decode(self.jpegs)
        images = self.resize(images)
        images = self.normperm(images)
        return (images, self.labels)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment