# Sketch of a meta-learning setup: an inner model is trained with per-batch
# loss weights produced by a metamodel, and the metamodel is periodically
# meta-trained by unrolling the inner training loop and backpropagating a
# validation loss through the unrolled updates.
# nnModuleModel, Optim, and inner_model.update() are placeholders for a proposed API.
inner_model = nnModuleModel(input_size=input_size, output_size=num_classes)
metamodel = nnModuleModel(input_size=input_size, output_size=1)
opt = Optim(inner_model.parameters())  # conventional optimizer for the inner model
inner_opt = Optim(inner_model.parameters())  # used only to clear grads during the unroll
metaopt = Optim(metamodel.parameters(), with_no_grad=True)  # conventional optimizer whose step runs under "with no_grad"
step = 0
for epoch in range(num_epochs):
    # Normal training loop
    for images, labels in my_dataset:
        opt.zero_grad()
        step += 1
        # Training step: weight the loss by the metamodel's output
        preds = inner_model(images)
        loss_weight = metamodel(images)
        loss = my_loss_function(preds, labels) * loss_weight
        # Update the inner model's parameters
        loss.backward()
        opt.step()
        print("Training loss:", loss.item(), "at step", step)
        # Every meta_update_freq steps (e.g. once per epoch)...
        if step % meta_update_freq == 0:
            state_dict = inner_model.state_dict()
            # ...update the parameters of the metamodel num_meta_updates times
            for meta_step in range(num_meta_updates):
                metaopt.zero_grad()
                inner_model.load_state_dict(state_dict)  # not currently differentiable
                # Unroll the training loop for num_inner_steps training steps
                for inner_step in range(num_inner_steps):
                    inner_opt.zero_grad()  # must not delete information required for "backprop-through-backprop"
                    images, labels = my_dataset.next()
                    preds = inner_model(images)
                    loss_weight = metamodel(images)
                    loss = my_loss_function(preds, labels) * loss_weight
                    # create_graph=True keeps the gradient computation in the
                    # graph so meta_loss can backprop through the unrolled updates
                    grads = torch.autograd.grad(loss, inputs=inner_model.parameters(), create_graph=True)
                    updates = [-lr * g for g in grads]
                    inner_model.update(updates)  # proposed differentiable parameter update
                # Compute validation loss on the latest version of the unrolled model
                val_images, val_labels = my_validation_set.next()
                val_preds = inner_model(val_images)
                meta_loss = my_loss_function(val_preds, val_labels)  # no loss_weight for the meta-loss
                # Backprop to the parameters of the metamodel and update them
                meta_loss.backward()  # propagates through every use of the metamodel's parameters within the unrolled loop
                metaopt.step()
                print("Meta loss:", meta_loss.item(), "at meta_step", meta_step, "of step", step)
    # Run a normal validation loop
    for val_images, val_labels in my_validation_set:
        val_preds = inner_model(val_images)
        val_loss = my_loss_function(val_preds, val_labels)  # report the unweighted validation loss
        print("Validation loss:", val_loss.item(), "at step", step)