@seba-1511 (last active August 20, 2020)
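A pseudocode sketch of meta-learning a loss-weighting network: a metamodel predicts a weight that scales the base model's training loss, and the metamodel itself is trained by unrolling a few inner training steps and backpropagating a held-out validation loss through them ("backprop-through-backprop").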
import torch

# Placeholders: nnModuleModel is any nn.Module, Optim is any torch.optim
# optimizer, and my_dataset / my_validation_set are iterators yielding
# (images, labels) batches.
inner_model = nnModuleModel(input_size=input_size, output_size=num_classes)
metamodel = nnModuleModel(input_size=input_size, output_size=1)  # predicts a loss weight per batch

opt = Optim(inner_model.parameters())    # conventional optimizer for the model
metaopt = Optim(metamodel.parameters())  # conventional optimizer: its step() runs under no_grad

step = 0
for epoch in range(num_epochs):
    # Normal training loop
    for images, labels in my_dataset:
        opt.zero_grad()
        step += 1
        # Training step: the metamodel's output scales the task loss
        preds = inner_model(images)
        loss_weight = metamodel(images)
        loss = my_loss_function(preds, labels) * loss_weight
        # Update model params
        loss.backward()
        opt.step()
        print("Training loss:", loss.item(), "at step", step)

        # Every meta_update_freq steps (e.g. once per epoch)...
        if step % meta_update_freq == 0:
            state_dict = inner_model.state_dict()
            # ...update the parameters of the metamodel num_meta_updates times
            for meta_step in range(num_meta_updates):
                metaopt.zero_grad()
                # Restart each unroll from the saved weights
                inner_model.load_state_dict(state_dict)  # not currently differentiable
                # Unroll the training loop for num_inner_steps training steps
                for inner_step in range(num_inner_steps):
                    images, labels = next(my_dataset)
                    preds = inner_model(images)
                    loss_weight = metamodel(images)
                    loss = my_loss_function(preds, labels) * loss_weight
                    # torch.autograd.grad leaves the .grad buffers alone, and
                    # create_graph=True keeps the information required to do
                    # "backprop-through-backprop"
                    grads = torch.autograd.grad(loss, inputs=inner_model.parameters(),
                                                create_graph=True)
                    updates = [-lr * g for g in grads]
                    inner_model.update(updates)  # differentiable parameter update; see the sketch below
                # Compute validation loss on the latest version of the unrolled model
                val_images, val_labels = next(my_validation_set)
                val_preds = inner_model(val_images)
                meta_loss = my_loss_function(val_preds, val_labels)  # no loss_weight for the meta-loss
                # Backprop to the parameters of the metamodel and update them:
                # this propagates to every use of the metamodel's parameters
                # within the unrolled loop
                meta_loss.backward()
                metaopt.step()
                print("Meta loss:", meta_loss.item(), "at meta_step", meta_step, "of step", step)
            # Restore the saved weights before resuming normal training
            inner_model.load_state_dict(state_dict)

    # Run a normal validation loop
    for val_images, val_labels in my_validation_set:
        val_preds = inner_model(val_images)
        val_loss = my_loss_function(val_preds, val_labels)  # report the unweighted validation loss
        print("Validation loss:", val_loss.item(), "at step", step)