Skip to content

Instantly share code, notes, and snippets.

@seba-1511
Created May 15, 2020 19:28
Show Gist options
  • Select an option

  • Save seba-1511/30e08ff2e680769dc06bc6da81cd33e7 to your computer and use it in GitHub Desktop.

Select an option

Save seba-1511/30e08ff2e680769dc06bc6da81cd33e7 to your computer and use it in GitHub Desktop.
Working implementation of Reptile
#!/usr/bin/env python3
"""
TODO:
* Remove dependency on ppt and randopt.
* Replace get_problems with mini-imagenet dataset.
"""
import os
import random
import copy
import ppt
import randopt as ro
import numpy as np
import torch as th
from torch import nn
from torch import optim
from torchvision import transforms
from statistics import mean
import learn2learn as l2l
from learn2learn.data.transforms import NWays, KShots, LoadData, RemapLabels, ConsecutiveLabels
from get_problems import PROBLEMS
def accuracy(predictions, targets):
    """Return the fraction of rows in `predictions` whose argmax equals `targets`."""
    predicted_classes = predictions.argmax(dim=1).view(targets.shape)
    num_correct = (predicted_classes == targets).sum().float()
    return num_correct / targets.size(0)
def fast_adapt(batch, learner, adapt_opt, loss, adaptation_steps, shots, ways, batch_size, device):
    """Adapt `learner` on the support half of `batch`, evaluate on the query half.

    The batch interleaves support/query examples: even positions among the
    first 2*shots*ways samples form the adaptation (support) set, the rest
    form the evaluation (query) set.

    Args:
        batch: `(data, labels)` pair of tensors as produced by the task sampler.
        learner: the (deep-copied) model to adapt in place.
        adapt_opt: optimizer over `learner`'s parameters (inner loop).
        loss: criterion; `reduction='mean'` is assumed (see NOTE below).
        adaptation_steps: number of inner-loop mini-batch updates.
        shots, ways: task shape; `shots * ways` support examples are selected.
        batch_size: inner-loop mini-batch size (sampled with replacement).
        device: device to move the batch to.

    Returns:
        (valid_error, valid_accuracy) tensors computed on the query set.
    """
    data, labels = batch
    data, labels = data.to(device), labels.to(device)

    # Separate data into adaptation/evaluation sets. A bool mask replaces the
    # original uint8 (byte) mask: indexing with byte tensors is deprecated and
    # `1 - mask` on a byte tensor is an error in recent PyTorch; `~mask` is the
    # supported complement.
    adaptation_indices = th.zeros(data.size(0), dtype=th.bool)
    adaptation_indices[th.arange(shots * ways) * 2] = True
    adaptation_data = data[adaptation_indices]
    adaptation_labels = labels[adaptation_indices]
    evaluation_data = data[~adaptation_indices]
    evaluation_labels = labels[~adaptation_indices]

    # Adapt the model: SGD-style inner loop on random support mini-batches
    # (sampled with replacement).
    for step in range(adaptation_steps):
        idx = th.randint(adaptation_data.size(0), size=(batch_size,))
        adapt_X = adaptation_data[idx]
        adapt_y = adaptation_labels[idx]
        adapt_opt.zero_grad()
        error = loss(learner(adapt_X), adapt_y)
        error.backward()
        adapt_opt.step()

    # Evaluate the adapted model on the query set.
    predictions = learner(evaluation_data)
    valid_error = loss(predictions, evaluation_labels)
    # NOTE(review): `loss` is created with reduction='mean' by the caller, so
    # this extra division double-normalizes the reported error. Kept to
    # preserve existing behavior/metrics — confirm whether it is intended.
    valid_error /= len(evaluation_data)
    valid_accuracy = accuracy(predictions, evaluation_labels)
    return valid_error, valid_accuracy
@ro.cli
def new_reptile(
        experiment='dev',
        problem='mini-imagenet',
        ways=5,
        train_shots=15,
        test_shots=5,
        meta_lr=1.0,
        meta_bsz=5,
        fast_lr=0.001,
        train_bsz=10,
        test_bsz=15,
        train_steps=8,
        test_steps=50,
        iterations=100000,
        test_interval=100,
        save='',
        cuda=1,
        seed=42,
        meta_lr_final=0.0,
):
    """Serial Reptile meta-training loop.

    Each meta-iteration adapts `meta_bsz` deep copies of the model on sampled
    tasks and moves the initialization toward the average of the adapted
    parameters. The meta learning rate is annealed linearly from `meta_lr`
    down to `meta_lr_final` over `iterations` (default 0.0, matching the
    reference Reptile implementation).
    """
    exp_name = 'new-reptile-' + experiment + '-' + ro.dict_to_string(locals())
    exp_params = ro.dict_to_constants(locals())
    exp = ro.Experiment(name=exp_name, params=exp_params, directory='results')
    plotter = ppt.Plotter(env=exp_name)
    cuda = bool(cuda)

    # Seed all RNGs for reproducibility.
    random.seed(seed)
    np.random.seed(seed)
    th.manual_seed(seed)
    device = th.device('cpu')
    if cuda and th.cuda.device_count():
        th.cuda.manual_seed(seed)
        device = th.device('cuda')

    train_tasks, valid_tasks, test_tasks, model = PROBLEMS[problem](
        ways=ways,
        train_shots=train_shots,
        test_shots=test_shots)

    # Create model and optimizers. The inner Adam's moment estimates are
    # carried across meta-iterations through `adapt_opt_state`.
    model.to(device)
    opt = optim.SGD(model.parameters(), meta_lr)
    adapt_opt = optim.Adam(model.parameters(), lr=fast_lr, betas=(0, 0.999))
    adapt_opt_state = adapt_opt.state_dict()
    loss = nn.CrossEntropyLoss(reduction='mean')

    train_inner_errors = []
    train_inner_accuracies = []
    valid_inner_errors = []
    valid_inner_accuracies = []
    test_inner_errors = []
    test_inner_accuracies = []

    for iteration in range(iterations):
        opt.zero_grad()
        meta_train_error = 0.0
        meta_train_accuracy = 0.0
        meta_valid_error = 0.0
        meta_valid_accuracy = 0.0
        meta_test_error = 0.0
        meta_test_accuracy = 0.0

        # Anneal meta-lr linearly from meta_lr to meta_lr_final.
        # (The original expression used meta_lr at both endpoints, which
        # made the anneal a no-op; meta_lr_final restores the intended
        # behavior while defaulting to the standard anneal-to-zero.)
        frac_done = float(iteration) / iterations
        new_lr = frac_done * meta_lr_final + (1 - frac_done) * meta_lr
        for pg in opt.param_groups:
            pg['lr'] = new_lr

        # Zero-grad the parameters: grads accumulate the Reptile update below.
        for p in model.parameters():
            p.grad = th.zeros_like(p.data)

        for task in range(meta_bsz):
            # Compute meta-training loss on a cloned learner.
            learner = copy.deepcopy(model)
            adapt_opt = optim.Adam(learner.parameters(),
                                   lr=fast_lr,
                                   betas=(0, 0.999))
            adapt_opt.load_state_dict(adapt_opt_state)
            batch = train_tasks.sample()
            evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                               learner,
                                                               adapt_opt,
                                                               loss,
                                                               train_steps,
                                                               train_shots,
                                                               ways,
                                                               train_bsz,
                                                               device)
            adapt_opt_state = adapt_opt.state_dict()
            # Reptile "gradient": accumulate -(adapted params); combined with
            # the +p.data added after the task loop, grad = p - mean(adapted).
            # Uses the keyword `alpha` form — the two-positional add_ overload
            # is deprecated.
            for p, l in zip(model.parameters(), learner.parameters()):
                p.grad.data.add_(l.data, alpha=-1.0)
            meta_train_error += evaluation_error.item()
            meta_train_accuracy += evaluation_accuracy.item()

        if iteration % test_interval == 0:
            # Compute meta-validation loss
            learner = copy.deepcopy(model)
            adapt_opt = optim.Adam(learner.parameters(),
                                   lr=fast_lr,
                                   betas=(0, 0.999))
            adapt_opt.load_state_dict(adapt_opt_state)
            batch = valid_tasks.sample()
            evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                               learner,
                                                               adapt_opt,
                                                               loss,
                                                               test_steps,
                                                               test_shots,
                                                               ways,
                                                               test_bsz,
                                                               device)
            meta_valid_error += evaluation_error.item()
            meta_valid_accuracy += evaluation_accuracy.item()

            # Compute meta-testing loss
            learner = copy.deepcopy(model)
            adapt_opt = optim.Adam(learner.parameters(),
                                   lr=fast_lr,
                                   betas=(0, 0.999))
            adapt_opt.load_state_dict(adapt_opt_state)
            batch = test_tasks.sample()
            evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                               learner,
                                                               adapt_opt,
                                                               loss,
                                                               test_steps,
                                                               test_shots,
                                                               ways,
                                                               test_bsz,
                                                               device)
            meta_test_error += evaluation_error.item()
            meta_test_accuracy += evaluation_accuracy.item()

        # Print some metrics.
        # NOTE(review): valid/test metrics come from a single task per
        # test_interval yet are divided by meta_bsz (and print 0.0 on other
        # iterations) — kept as-is to preserve the logged scale; confirm.
        print('\n')
        print('Iteration', iteration)
        print('Meta Train Error', meta_train_error / meta_bsz)
        print('Meta Train Accuracy', meta_train_accuracy / meta_bsz)
        print('Meta Valid Error', meta_valid_error / meta_bsz)
        print('Meta Valid Accuracy', meta_valid_accuracy / meta_bsz)
        print('Meta Test Error', meta_test_error / meta_bsz)
        print('Meta Test Accuracy', meta_test_accuracy / meta_bsz)

        # Track quantities
        train_inner_errors.append(meta_train_error / meta_bsz)
        train_inner_accuracies.append(meta_train_accuracy / meta_bsz)
        if iteration % test_interval == 0:
            valid_inner_errors.append(meta_valid_error / meta_bsz)
            valid_inner_accuracies.append(meta_valid_accuracy / meta_bsz)
            test_inner_errors.append(meta_test_error / meta_bsz)
            test_inner_accuracies.append(meta_test_accuracy / meta_bsz)

        # Average the accumulated gradients and optimize: grads become
        # p - mean(adapted params), so the SGD step moves the initialization
        # toward the adapted solutions at rate new_lr.
        for p in model.parameters():
            p.grad.data.mul_(1.0 / meta_bsz).add_(p.data)
        opt.step()

    # Final result: average validation accuracy over the last 100 evaluations.
    result = mean(valid_inner_accuracies[-100:])
    data = {
        'train_inner_errors': train_inner_errors,
        'train_inner_accuracies': train_inner_accuracies,
        'valid_inner_errors': valid_inner_errors,
        'valid_inner_accuracies': valid_inner_accuracies,
        'test_inner_errors': test_inner_errors,
        'test_inner_accuracies': test_inner_accuracies,
    }
    exp.add_result(result, data)
    if save != '':
        data['model'] = model.cpu().state_dict()
        data['adapt_opt_state'] = {k: v.cpu() if hasattr(v, 'cpu') else v
                                   for k, v in adapt_opt_state.items()}
        th.save(data, 'archives/' + save + '_' + exp_name + '.pth')
if __name__ == '__main__':
    # randopt parses CLI arguments and dispatches the @ro.cli-decorated
    # entry point (new_reptile above).
    ro.parse()
@seba-1511
Copy link
Copy Markdown
Author

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment