Created
May 15, 2020 19:28
-
-
Save seba-1511/30e08ff2e680769dc06bc6da81cd33e7 to your computer and use it in GitHub Desktop.
Working implementation of Reptile
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| """ | |
| TODO: | |
| * Remove dependency on ppt and randopt. | |
| * Replace get_problems with mini-imagenet dataset. | |
| """ | |
| import os | |
| import random | |
| import copy | |
| import ppt | |
| import randopt as ro | |
| import numpy as np | |
| import torch as th | |
| from torch import nn | |
| from torch import optim | |
| from torchvision import transforms | |
| from statistics import mean | |
| import learn2learn as l2l | |
| from learn2learn.data.transforms import NWays, KShots, LoadData, RemapLabels, ConsecutiveLabels | |
| from get_problems import PROBLEMS | |
def accuracy(predictions, targets):
    """Fraction of rows in `predictions` whose argmax equals `targets`."""
    predicted_classes = predictions.argmax(dim=1).view(targets.shape)
    n_correct = (predicted_classes == targets).sum().float()
    return n_correct / targets.size(0)
def fast_adapt(batch, learner, adapt_opt, loss, adaptation_steps, shots, ways, batch_size, device):
    """Adapt `learner` to one task and return (evaluation error, evaluation accuracy).

    `batch` must contain 2 * shots * ways (data, label) pairs: even positions
    form the adaptation (support) set, odd positions the evaluation (query)
    set.  Adaptation runs `adaptation_steps` updates with `adapt_opt` on
    mini-batches of `batch_size` examples sampled with replacement from the
    support set.

    Args:
        batch: (data, labels) tensors for one task.
        learner: the model clone to adapt in place.
        adapt_opt: optimizer over `learner`'s parameters.
        loss: criterion (expected to use reduction='mean').
        adaptation_steps: number of inner-loop updates.
        shots, ways: task shape; support set has shots * ways examples.
        batch_size: inner-loop mini-batch size.
        device: device to move the batch onto.
    """
    data, labels = batch
    data, labels = data.to(device), labels.to(device)

    # Separate data into adaptation/evaluation sets.
    # Bool masks replace the deprecated byte-tensor indexing; `~mask`
    # replaces the old `1 - mask` complement (which errors on newer torch).
    adaptation_indices = th.zeros(data.size(0), dtype=th.bool)
    adaptation_indices[th.arange(shots * ways) * 2] = True
    adaptation_data = data[adaptation_indices]
    adaptation_labels = labels[adaptation_indices]
    evaluation_data = data[~adaptation_indices]
    evaluation_labels = labels[~adaptation_indices]

    # Adapt the model on random support-set mini-batches.
    for step in range(adaptation_steps):
        idx = th.randint(adaptation_data.size(0), size=(batch_size,))
        adapt_X = adaptation_data[idx]
        adapt_y = adaptation_labels[idx]
        adapt_opt.zero_grad()
        error = loss(learner(adapt_X), adapt_y)
        error.backward()
        adapt_opt.step()

    # Evaluate the adapted model on the held-out query set.
    predictions = learner(evaluation_data)
    valid_error = loss(predictions, evaluation_labels)
    # NOTE(review): `loss` already averages over the query set, so this
    # extra division rescales the reported error by the query-set size.
    # Kept byte-for-byte for backward compatibility with existing logs.
    valid_error /= len(evaluation_data)
    valid_accuracy = accuracy(predictions, evaluation_labels)
    return valid_error, valid_accuracy
@ro.cli
def new_reptile(
    experiment='dev',
    problem='mini-imagenet',
    ways=5,
    train_shots=15,
    test_shots=5,
    meta_lr=1.0,
    meta_bsz=5,
    fast_lr=0.001,
    train_bsz=10,
    test_bsz=15,
    train_steps=8,
    test_steps=50,
    iterations=100000,
    test_interval=100,
    save='',
    cuda=1,
    seed=42,
):
    """Meta-train a model with Reptile (Nichol et al., 2018).

    Outer loop: for each meta-iteration, clone the model, adapt each clone
    to one of `meta_bsz` sampled tasks with an inner Adam optimizer, then
    move the model parameters toward the average of the adapted clones via
    SGD with a linearly annealed meta step size.

    Args:
        experiment: name tag used to build the experiment/plot identifiers.
        problem: key into PROBLEMS selecting the tasks/model factory.
        ways: classes per task.
        train_shots / test_shots: support shots at meta-train / eval time.
        meta_lr: initial outer-loop step size, annealed linearly to 0.
        meta_bsz: tasks per meta-update.
        fast_lr: inner-loop Adam learning rate.
        train_bsz / test_bsz: inner-loop mini-batch sizes.
        train_steps / test_steps: inner adaptation steps at train / eval time.
        iterations: number of meta-iterations.
        test_interval: run valid/test evaluation every this many iterations.
        save: if non-empty, save model + inner-optimizer state under this tag.
        cuda: nonzero to use CUDA when available.
        seed: RNG seed for random / numpy / torch.
    """
    # NOTE(review): locals() already contains exp_name by the time
    # dict_to_constants runs — preserved from the original code.
    exp_name = 'new-reptile-' + experiment + '-' + ro.dict_to_string(locals())
    exp_params = ro.dict_to_constants(locals())
    exp = ro.Experiment(name=exp_name, params=exp_params, directory='results')
    plotter = ppt.Plotter(env=exp_name)

    # Seed every RNG we touch.
    cuda = bool(cuda)
    random.seed(seed)
    np.random.seed(seed)
    th.manual_seed(seed)
    device = th.device('cpu')
    if cuda and th.cuda.device_count():
        th.cuda.manual_seed(seed)
        device = th.device('cuda')

    train_tasks, valid_tasks, test_tasks, model = PROBLEMS[problem](
        ways=ways,
        train_shots=train_shots,
        test_shots=test_shots)

    # Create model and optimizers.  The inner Adam state persists across
    # meta-iterations (a detail of the reference Reptile implementation).
    model.to(device)
    opt = optim.SGD(model.parameters(), meta_lr)
    adapt_opt = optim.Adam(model.parameters(), lr=fast_lr, betas=(0, 0.999))
    adapt_opt_state = adapt_opt.state_dict()
    loss = nn.CrossEntropyLoss(reduction='mean')

    train_inner_errors = []
    train_inner_accuracies = []
    valid_inner_errors = []
    valid_inner_accuracies = []
    test_inner_errors = []
    test_inner_accuracies = []

    for iteration in range(iterations):
        opt.zero_grad()
        meta_train_error = 0.0
        meta_train_accuracy = 0.0
        meta_valid_error = 0.0
        meta_valid_accuracy = 0.0
        meta_test_error = 0.0
        meta_test_accuracy = 0.0

        # Anneal the meta step size linearly from meta_lr down to 0, as in
        # the reference Reptile implementation.
        # BUGFIX: the original computed
        #   frac_done * meta_lr + (1 - frac_done) * meta_lr == meta_lr,
        # which made the anneal a no-op.
        frac_done = float(iteration) / iterations
        new_lr = (1.0 - frac_done) * meta_lr
        for pg in opt.param_groups:
            pg['lr'] = new_lr

        # Pre-allocate grads: they accumulate -(adapted params) below.
        for p in model.parameters():
            p.grad = th.zeros_like(p.data)

        for task in range(meta_bsz):
            # Compute meta-training loss on a fresh clone of the model,
            # resuming the persistent inner-Adam state.
            learner = copy.deepcopy(model)
            adapt_opt = optim.Adam(learner.parameters(),
                                   lr=fast_lr,
                                   betas=(0, 0.999))
            adapt_opt.load_state_dict(adapt_opt_state)
            batch = train_tasks.sample()
            evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                               learner,
                                                               adapt_opt,
                                                               loss,
                                                               train_steps,
                                                               train_shots,
                                                               ways,
                                                               train_bsz,
                                                               device)
            adapt_opt_state = adapt_opt.state_dict()
            # Accumulate -(adapted params); combined with +p below this
            # yields the Reptile meta-gradient p - mean(adapted params).
            # (Modern add_ signature replaces deprecated add_(-1.0, l.data).)
            for p, l in zip(model.parameters(), learner.parameters()):
                p.grad.data.add_(l.data, alpha=-1.0)
            meta_train_error += evaluation_error.item()
            meta_train_accuracy += evaluation_accuracy.item()

            if iteration % test_interval == 0:
                # Compute meta-validation loss (does not update the
                # persistent inner-Adam state).
                learner = copy.deepcopy(model)
                adapt_opt = optim.Adam(learner.parameters(),
                                       lr=fast_lr,
                                       betas=(0, 0.999))
                adapt_opt.load_state_dict(adapt_opt_state)
                batch = valid_tasks.sample()
                evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                                   learner,
                                                                   adapt_opt,
                                                                   loss,
                                                                   test_steps,
                                                                   test_shots,
                                                                   ways,
                                                                   test_bsz,
                                                                   device)
                meta_valid_error += evaluation_error.item()
                meta_valid_accuracy += evaluation_accuracy.item()

                # Compute meta-testing loss.
                learner = copy.deepcopy(model)
                adapt_opt = optim.Adam(learner.parameters(),
                                       lr=fast_lr,
                                       betas=(0, 0.999))
                adapt_opt.load_state_dict(adapt_opt_state)
                batch = test_tasks.sample()
                evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                                   learner,
                                                                   adapt_opt,
                                                                   loss,
                                                                   test_steps,
                                                                   test_shots,
                                                                   ways,
                                                                   test_bsz,
                                                                   device)
                meta_test_error += evaluation_error.item()
                meta_test_accuracy += evaluation_accuracy.item()

        # Print some metrics (valid/test stay 0.0 off-interval iterations).
        print('\n')
        print('Iteration', iteration)
        print('Meta Train Error', meta_train_error / meta_bsz)
        print('Meta Train Accuracy', meta_train_accuracy / meta_bsz)
        print('Meta Valid Error', meta_valid_error / meta_bsz)
        print('Meta Valid Accuracy', meta_valid_accuracy / meta_bsz)
        print('Meta Test Error', meta_test_error / meta_bsz)
        print('Meta Test Accuracy', meta_test_accuracy / meta_bsz)

        # Track quantities
        train_inner_errors.append(meta_train_error / meta_bsz)
        train_inner_accuracies.append(meta_train_accuracy / meta_bsz)
        if iteration % test_interval == 0:
            valid_inner_errors.append(meta_valid_error / meta_bsz)
            valid_inner_accuracies.append(meta_valid_accuracy / meta_bsz)
            test_inner_errors.append(meta_test_error / meta_bsz)
            test_inner_accuracies.append(meta_test_accuracy / meta_bsz)

        # Average the accumulated gradients and optimize: grad becomes
        # p - mean(adapted params), so the SGD step pulls p toward the
        # average adapted solution.
        for p in model.parameters():
            p.grad.data.mul_(1.0 / meta_bsz).add_(p.data)
        opt.step()

    # Final score: mean of the last 100 validation accuracies.
    result = mean(valid_inner_accuracies[-100:])
    data = {
        'train_inner_errors': train_inner_errors,
        'train_inner_accuracies': train_inner_accuracies,
        'valid_inner_errors': valid_inner_errors,
        'valid_inner_accuracies': valid_inner_accuracies,
        'test_inner_errors': test_inner_errors,
        'test_inner_accuracies': test_inner_accuracies,
    }
    exp.add_result(result, data)
    if save != '':
        data['model'] = model.cpu().state_dict()
        data['adapt_opt_state'] = {k: v.cpu() if hasattr(v, 'cpu') else v
                                   for k, v in adapt_opt_state.items()}
        th.save(data, 'archives/' + save + '_' + exp_name + '.pth')
# Parse CLI flags and dispatch to the @ro.cli entry point (new_reptile).
if __name__ == '__main__':
    ro.parse()
Author
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Merged in learnables/learn2learn#225.