"""Load MNIST and build train/test DataLoaders.

Both splits are downloaded (if absent) into ``DATA_ROOT`` and converted to
normalized tensors. The first training batch is pulled eagerly into
``example_data`` / ``example_targets`` for quick inspection.
"""
import torch
import torchvision

# Directory where torchvision stores / looks for the MNIST files.
DATA_ROOT = './files/'
# Per-channel (mean,) and (std,) passed to Normalize. These are the values
# commonly quoted for MNIST training pixels after ToTensor() -- confirm if
# the dataset or preprocessing changes.
MNIST_MEAN = (0.1307,)
MNIST_STD = (0.3081,)
TRAIN_BATCH_SIZE = 20
TEST_BATCH_SIZE = 1000

# ToTensor converts each image to a float tensor scaled to [0, 1];
# Normalize then applies (x - mean) / std per channel.
tt = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(MNIST_MEAN, MNIST_STD),
])

mnist_train = torchvision.datasets.MNIST(
    DATA_ROOT, train=True, download=True, transform=tt
)
mnist_test = torchvision.datasets.MNIST(
    DATA_ROOT, train=False, download=True, transform=tt
)

# Reshuffle training data every epoch. Keep the test order fixed so
# evaluation runs are reproducible; shuffling would not change aggregate
# test metrics anyway.
train_loader = torch.utils.data.DataLoader(
    mnist_train, batch_size=TRAIN_BATCH_SIZE, shuffle=True
)
test_loader = torch.utils.data.DataLoader(
    mnist_test, batch_size=TEST_BATCH_SIZE, shuffle=False
)

# Take the first (shuffled) training batch for inspection. `examples` is
# kept at module level so later code can continue drawing
# (index, (data, targets)) pairs from the same iterator.
examples = enumerate(train_loader)
batch_idx, (example_data, example_targets) = next(examples)