Last active
October 6, 2022 13:54
-
-
Save enhuiz/58c688d3a30678625719f7fe7743c1c2 to your computer and use it in GitHub Desktop.
Low Resource MNIST
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import argparse | |
| import torch | |
| import numpy as np | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from tqdm import tqdm, trange | |
| from torch.nn.utils.weight_norm import weight_norm | |
| from collections import defaultdict | |
| from torch.utils.data import DataLoader | |
| from torchvision import transforms | |
| from torchvision.datasets import MNIST | |
| def _subset(ds: MNIST, n=[10] * 10): | |
| """ | |
| Args: | |
| ds: MNIST dataset | |
| n: the size of new dataset | |
| p: the distribution of targets in the new dataset | |
| """ | |
| assert len(n) == 10, "Must specify the number for each label." | |
| target_to_indices = defaultdict(list) | |
| for i, t in enumerate(ds.targets): | |
| target_to_indices[t.item()].append(i) | |
| data = ds.data | |
| targets = ds.targets | |
| ds.data = [] | |
| ds.targets = [] | |
| for t, nt in enumerate(n): | |
| ds.data.extend(data[target_to_indices[t][:nt]]) | |
| ds.targets.extend(targets[target_to_indices[t][:nt]]) | |
| ds.data = torch.stack(ds.data) | |
| ds.targets = torch.stack(ds.targets) | |
| return ds | |
class Model(nn.Sequential):
    """Small dilated CNN classifier for 28x28 single-channel digits."""

    def __init__(self):
        # Four 3x3 conv stages with exponentially growing dilation
        # (1, 2, 4, 8) — padding matches dilation so spatial size is
        # preserved — then global average pooling and a linear head.
        layers = []
        for dilation in (1, 2, 4, 8):
            in_channels = 1 if dilation == 1 else 32
            layers += [
                nn.Conv2d(in_channels, 32, 3, dilation=dilation, padding=dilation),
                nn.BatchNorm2d(32),
                nn.GELU(),
            ]
        layers += [
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(1),
            nn.Linear(32, 10),
        ]
        super().__init__(*layers)
        self.apply(self._weight_norm)

    @staticmethod
    def _weight_norm(m):
        # Re-parameterize every conv layer with weight normalization.
        if isinstance(m, nn.Conv2d):
            weight_norm(m)
| def _loop_forever(dl): | |
| while True: | |
| yield from dl | |
def run(args, n):
    """Train the model on an n-examples-per-digit MNIST subset.

    Args:
        args: parsed CLI namespace (batch_size, device, lr, val_every,
            verbose).
        n: number of training examples to keep for each of the 10 digits.

    Returns:
        The last test accuracy before it decreased (early stopping), or
        None is never returned — the loop only exits via that return.
    """
    train_ds = MNIST(
        "data",
        train=True,
        download=True,
        transform=transforms.ToTensor(),
    )
    test_ds = MNIST(
        "data",
        train=False,
        download=True,
        transform=transforms.ToTensor(),
    )
    # Shrink the training set (in place) to n examples per digit.
    train_ds = _subset(train_ds, n=[n] * 10)
    train_dl = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
    )
    test_dl = DataLoader(
        test_ds,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False,
    )
    del train_ds, test_ds
    model = Model()
    model.to(args.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    history_accs = []  # test accuracy recorded after each validation pass
    # Iterate indefinitely over the tiny training set; the only exit is
    # the early-stopping return inside the validation branch below.
    for i, (x, y) in enumerate(_loop_forever(train_dl)):
        x, y = map(lambda t: t.to(args.device), [x, y])
        h = model(x)
        loss = F.cross_entropy(h, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Log loss 10x as often as validation (exact only when
        # args.val_every is a multiple of 10).
        if (i * 10) % args.val_every == 0 and args.verbose:
            print(f"Iteration {i} Loss: {loss.item():.5g}")
        if i % args.val_every == 0:
            with torch.inference_mode():
                model.eval()
                digit_accs = defaultdict(list)  # per-digit accuracies
                accs = []
                # Only wrap the loader in tqdm when verbose.
                for x, y in (tqdm if args.verbose else lambda x: x)(test_dl):
                    x, y = map(lambda t: t.to(args.device), [x, y])
                    batch_accs = (model(x).argmax(dim=-1) == y).float().tolist()
                    accs.extend(batch_accs)
                    for yi, ai in zip(y, batch_accs):
                        digit_accs[yi.item()].append(ai)
                acc = np.mean(accs)
                if args.verbose:
                    print(f"Accuracy: {acc:.4g}.")
                for d, daccs in sorted(digit_accs.items()):
                    daccs = np.mean(daccs)
                    if args.verbose:
                        print(f"Accuracy of {d}: {daccs:.4g}.")
                model.train()
                history_accs.append(acc)
                # Early stopping: once test accuracy drops between two
                # consecutive validations, return the previous (peak) value.
                # NOTE(review): `> 2` means the drop is ignored until at
                # least 3 accuracies exist — presumably intentional warm-up.
                if len(history_accs) > 2 and history_accs[-1] < history_accs[-2]:
                    return history_accs[-2]
def main():
    """Sweep per-digit training-set sizes 1..10 and print their accuracies."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch-size", type=int, default=10)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--val-every", type=int, default=1000)
    parser.add_argument("--verbose", action="store_true")
    args = parser.parse_args()
    # Map each per-digit sample count to its early-stopped test accuracy.
    n_to_accs = {n: run(args, n) for n in trange(1, 11)}
    print(n_to_accs)


if __name__ == "__main__":
    main()
Author
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
{1: 0.4242, 2: 0.6128, 3: 0.6846, 4: 0.8414, 5: 0.8363, 6: 0.8857, 7: 0.8516, 8: 0.8808, 9: 0.8765, 10: 0.8995}