Skip to content

Instantly share code, notes, and snippets.

@enhuiz
Last active October 6, 2022 13:54
Show Gist options
  • Select an option

  • Save enhuiz/58c688d3a30678625719f7fe7743c1c2 to your computer and use it in GitHub Desktop.

Select an option

Save enhuiz/58c688d3a30678625719f7fe7743c1c2 to your computer and use it in GitHub Desktop.
Low Resource MNIST
import argparse
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm, trange
from torch.nn.utils.weight_norm import weight_norm
from collections import defaultdict
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
def _subset(ds: MNIST, n=[10] * 10):
"""
Args:
ds: MNIST dataset
n: the size of new dataset
p: the distribution of targets in the new dataset
"""
assert len(n) == 10, "Must specify the number for each label."
target_to_indices = defaultdict(list)
for i, t in enumerate(ds.targets):
target_to_indices[t.item()].append(i)
data = ds.data
targets = ds.targets
ds.data = []
ds.targets = []
for t, nt in enumerate(n):
ds.data.extend(data[target_to_indices[t][:nt]])
ds.targets.extend(targets[target_to_indices[t][:nt]])
ds.data = torch.stack(ds.data)
ds.targets = torch.stack(ds.targets)
return ds
class Model(nn.Sequential):
    """Dilated all-convolutional classifier for 1x28x28 inputs, emitting 10 logits.

    Four 3x3 conv stages with dilation 1/2/4/8 (padding matched to the dilation
    so spatial size is preserved), each followed by batch norm and GELU, then
    global average pooling and a linear head. Weight normalization is applied
    to every convolution after construction.
    """

    def __init__(self):
        blocks = []
        in_channels = 1
        for dilation in (1, 2, 4, 8):
            # "same"-padded 3x3 conv: padding equals the dilation rate.
            blocks.append(nn.Conv2d(in_channels, 32, 3, dilation=dilation, padding=dilation))
            blocks.append(nn.BatchNorm2d(32))
            blocks.append(nn.GELU())
            in_channels = 32
        blocks.append(nn.AdaptiveAvgPool2d((1, 1)))
        blocks.append(nn.Flatten(1))
        blocks.append(nn.Linear(32, 10))
        super().__init__(*blocks)
        self.apply(self._weight_norm)

    @staticmethod
    def _weight_norm(m):
        # Re-parameterize every convolution with weight normalization.
        if isinstance(m, nn.Conv2d):
            weight_norm(m)
def _loop_forever(dl):
while True:
yield from dl
def run(args, n):
    """Train a fresh Model on n examples per digit and return its peak test accuracy.

    Trains indefinitely, validating every ``args.val_every`` iterations, and
    stops the first time validation accuracy drops relative to the previous
    checkpoint — returning that previous (peak) accuracy.

    Args:
        args: parsed CLI namespace (batch_size, device, lr, val_every, verbose).
        n: number of training examples to keep per digit class.
    """
    train_ds = MNIST(
        "data",
        train=True,
        download=True,
        transform=transforms.ToTensor(),
    )
    test_ds = MNIST(
        "data",
        train=False,
        download=True,
        transform=transforms.ToTensor(),
    )
    # Shrink the training set in place to n examples for each of the 10 digits.
    train_ds = _subset(train_ds, n=[n] * 10)
    train_dl = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
    )
    test_dl = DataLoader(
        test_ds,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False,
    )
    del train_ds, test_ds
    model = Model()
    model.to(args.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    history_accs = []
    # Infinite training stream; the only exit is the early-stopping return below.
    for i, (x, y) in enumerate(_loop_forever(train_dl)):
        x, y = map(lambda t: t.to(args.device), [x, y])
        h = model(x)
        loss = F.cross_entropy(h, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Logs roughly 10x more often than validation (when val_every is a multiple of 10).
        if (i * 10) % args.val_every == 0 and args.verbose:
            print(f"Iteration {i} Loss: {loss.item():.5g}")
        if i % args.val_every == 0:
            # Full test-set evaluation: no gradients, eval-mode batch norm.
            with torch.inference_mode():
                model.eval()
                digit_accs = defaultdict(list)
                accs = []
                for x, y in (tqdm if args.verbose else lambda x: x)(test_dl):
                    x, y = map(lambda t: t.to(args.device), [x, y])
                    # Per-example 0/1 correctness for this batch.
                    batch_accs = (model(x).argmax(dim=-1) == y).float().tolist()
                    accs.extend(batch_accs)
                    for yi, ai in zip(y, batch_accs):
                        digit_accs[yi.item()].append(ai)
                acc = np.mean(accs)
                if args.verbose:
                    print(f"Accuracy: {acc:.4g}.")
                # Per-digit accuracy breakdown.
                for d, daccs in sorted(digit_accs.items()):
                    daccs = np.mean(daccs)
                    if args.verbose:
                        print(f"Accuracy of {d}: {daccs:.4g}.")
            model.train()
            history_accs.append(acc)
            # Early stopping: once accuracy declines, return the previous checkpoint's.
            if len(history_accs) > 2 and history_accs[-1] < history_accs[-2]:
                return history_accs[-2]
def main():
    """Parse CLI arguments and sweep the per-class training-set size from 1 to 10.

    Runs one full training (see ``run``) for each subset size and prints the
    mapping from subset size to peak test accuracy.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch-size", type=int, default=10)
    # Fall back to CPU automatically so the script also runs on machines
    # without a GPU (the hard-coded "cuda" default crashed there).
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
    )
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--val-every", type=int, default=1000)
    parser.add_argument("--verbose", action="store_true")
    args = parser.parse_args()
    n_to_accs = {}
    for n in trange(1, 11):
        n_to_accs[n] = run(args, n)
    print(n_to_accs)


if __name__ == "__main__":
    main()
@enhuiz
Copy link
Copy Markdown
Author

enhuiz commented Oct 6, 2022

{1: 0.4242, 2: 0.6128, 3: 0.6846, 4: 0.8414, 5: 0.8363, 6: 0.8857, 7: 0.8516, 8: 0.8808, 9: 0.8765, 10: 0.8995}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment