Skip to content

Instantly share code, notes, and snippets.

@vkkhare
Created October 11, 2018 14:41
Show Gist options
  • Select an option

  • Save vkkhare/03fd90e2fc566a9f8da5c07311c04924 to your computer and use it in GitHub Desktop.

Select an option

Save vkkhare/03fd90e2fc566a9f8da5c07311c04924 to your computer and use it in GitHub Desktop.
def train(model, resnet, device, train_loader, optimizer_res, optimizer_att,
          epoch, losslist, loss, lmb, att_epochs=20, log_interval=20):
    """Run one alternating two-phase training round.

    Phase 1: freeze ``model`` (the attribute network), train ``resnet`` for a
    single pass over ``train_loader`` using ``optimizer_res``.
    Phase 2: freeze ``resnet``, train ``model`` for ``att_epochs`` passes
    using ``optimizer_att``.

    Args:
        model: attribute network; ``model(attribute)`` -> ``(means, covs)``.
        resnet: feature network; ``resnet(images)`` -> ``(x_feat, max_margin_pred)``.
        device: torch device the batches are moved to.
        train_loader: iterable of dicts with keys ``'feature'``,
            ``'class_label'`` (1-indexed) and ``'attribute'``.
        optimizer_res: optimizer over ``resnet`` parameters (phase 1).
        optimizer_att: optimizer over ``model`` parameters (phase 2).
        epoch: current epoch number, used only for logging.
        losslist: mutated in place — one total-loss float appended per pass
            (1 for phase 1 + ``att_epochs`` for phase 2).
        loss: classification criterion, e.g. ``nn.CrossEntropyLoss``.
        lmb: weight of the classification term relative to the l2 term.
        att_epochs: number of phase-2 passes (was hard-coded to 20).
        log_interval: print progress every this many batches (was 20).

    Note: on return ``model`` parameters have ``requires_grad=True`` and
    ``resnet`` parameters have ``requires_grad=False`` (phase-2 state).
    """
    # ---- Phase 1: train resnet, attribute model frozen ----
    _set_requires_grad(model, False)
    _set_requires_grad(resnet, True)
    model.eval()
    resnet.train()
    losslist.append(_run_epoch(model, resnet, device, train_loader,
                               optimizer_res, epoch, loss, lmb, log_interval))

    # ---- Phase 2: train attribute model, resnet frozen ----
    _set_requires_grad(model, True)
    _set_requires_grad(resnet, False)
    model.train()
    resnet.eval()
    for _ in range(att_epochs):
        losslist.append(_run_epoch(model, resnet, device, train_loader,
                                   optimizer_att, epoch, loss, lmb,
                                   log_interval))


def _set_requires_grad(module, flag):
    # Freeze (False) or unfreeze (True) every parameter of `module`.
    for param in module.parameters():
        param.requires_grad = flag


def _run_epoch(model, resnet, device, train_loader, optimizer, epoch, loss,
               lmb, log_interval):
    """One optimization pass over ``train_loader``; returns total loss (float).

    Steps ``optimizer`` once per batch on
    ``l2 + lmb * loss(max_margin_pred, label - 1)`` where the l2 term is a
    covariance-weighted squared distance between resnet features and the
    attribute-predicted means, minus half the summed log of ``covs``
    (presumably a Gaussian log-likelihood term — TODO confirm against paper).
    """
    total = 0.0
    for b_idx, x in enumerate(train_loader, start=1):
        images = x['feature'].to(device)
        label = x['class_label'].type(torch.LongTensor).to(device)
        attribute = x['attribute'].type(torch.FloatTensor).to(device)

        optimizer.zero_grad()
        means, covs = model(attribute)
        x_feat, max_margin_pred = resnet(images)
        l2 = (torch.sum((x_feat - means) * covs * (x_feat - means))
              - torch.sum(torch.log(covs)) / 2)
        # Dataset labels are 1-indexed; the criterion expects 0-indexed targets.
        loss_eval = l2 + lmb * loss(max_margin_pred, label - 1)
        loss_eval.backward()
        optimizer.step()

        # .item() detaches the scalar: accumulating the tensor itself (as the
        # original did) keeps every batch's autograd graph alive all epoch.
        total += loss_eval.item()

        if b_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, b_idx * x_feat.shape[0], len(train_loader.dataset),
                100. * b_idx / len(train_loader), loss_eval.item()))
            print("l2", l2.item())
    return total
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment