Skip to content

Instantly share code, notes, and snippets.

@YingGwan
Forked from sbarratt/torch_jacobian.py
Created November 26, 2022 11:08
Show Gist options
  • Select an option

  • Save YingGwan/aaf10d5a3b07078f3d8afd6ebcdf1c8f to your computer and use it in GitHub Desktop.

Select an option

Save YingGwan/aaf10d5a3b07078f3d8afd6ebcdf1c8f to your computer and use it in GitHub Desktop.
Get the jacobian of a vector-valued function that takes batch inputs, in pytorch.
def get_jacobian(net, x, noutputs):
    """Return the Jacobian of ``net`` at a single input ``x``.

    The input is copied ``noutputs`` times and pushed through ``net`` as one
    batch.  Using the identity matrix as the upstream gradient means row ``i``
    of the batch only receives the gradient of output ``i``, so one backward
    pass yields every row of the Jacobian.

    Args:
        net: Callable applied to a ``(noutputs, n)`` tensor and returning a
            ``(noutputs, noutputs)`` tensor.  The trick assumes each row is
            processed independently of the others (no cross-row coupling).
        x: Input tensor that squeezes down to ``n`` features.
        noutputs: Number of outputs ``net`` produces per input row.

    Returns:
        Tensor of shape ``(noutputs, n)`` whose entry ``[i, j]`` is the
        derivative of output ``i`` with respect to input feature ``j``.
    """
    # Work on a detached copy so the caller's autograd graph (and the
    # caller's ``x.grad``) is never touched.
    x = x.detach().squeeze()
    if x.dim() == 0:
        # A single-feature input squeezes down to a scalar; keep it 1-D.
        x = x.unsqueeze(0)
    x = x.repeat(noutputs, 1).requires_grad_(True)  # noutputs, n
    y = net(x)
    # Build the identity on the output's dtype/device so float64 and
    # accelerator tensors work too.
    eye = torch.eye(noutputs, dtype=y.dtype, device=y.device)
    # autograd.grad returns d(y)/d(x) directly instead of accumulating into
    # ``.grad`` attributes, so the network's parameter gradients stay clean.
    (jacobian,) = torch.autograd.grad(y, x, grad_outputs=eye)
    return jacobian
@YingGwan
Copy link
Author

Remember to flatten the input to 1-D before it is repeated and sent into net();
that is what the squeeze() call is for (it drops the size-1 axes).
In PyTorch, you can also use *.view(-1, n) to do this.
The same approach works in the PyTorch C++ API; I have verified it on both platforms.

image

@YingGwan
Copy link
Author

import torch

def get_batch_jacobian(net, x, noutputs):
    """Return the Jacobian of ``net`` for every sample of a batch.

    Each sample is copied ``noutputs`` times along a new axis.  Feeding a
    per-sample identity matrix as the upstream gradient makes copy ``i`` of a
    sample receive only the gradient of output ``i``, so a single backward
    pass produces all Jacobian rows for all samples.

    Args:
        net: Callable applied to a ``(b, noutputs, in_dim)`` tensor and
            returning a ``(b, noutputs, noutputs)`` tensor.  The trick
            assumes every input row is processed independently.
        x: Batch of inputs with shape ``(b, in_dim)``.
        noutputs: Number of outputs ``net`` produces per input row.

    Returns:
        Tensor of shape ``(b, noutputs, in_dim)``; entry ``[k, i, j]`` is the
        derivative of output ``i`` with respect to feature ``j`` of sample
        ``k``.
    """
    batch_size = x.size(0)
    # Detach so the caller's graph and ``x.grad`` are left untouched.
    # b, in_dim -> b, 1, in_dim -> b, out_dim, in_dim
    x = x.detach().unsqueeze(1).repeat(1, noutputs, 1).requires_grad_(True)
    y = net(x)
    # Identity on the output's dtype/device so non-default tensors work.
    eye = torch.eye(noutputs, dtype=y.dtype, device=y.device)
    grad_outputs = eye.unsqueeze(0).expand(batch_size, noutputs, noutputs)
    # autograd.grad returns the input gradient without accumulating into the
    # ``.grad`` of the network's parameters.
    (jacobian,) = torch.autograd.grad(y, x, grad_outputs=grad_outputs)
    return jacobian

@YingGwan
Copy link
Author

def get_batch_jacobian(net, x, noutputs):
    """Return the per-sample Jacobian of ``net``, printing shapes on the way.

    Verbose variant: the shape of every intermediate tensor is printed so the
    reshaping steps can be followed.  Each sample is copied ``noutputs``
    times; with a per-sample identity matrix as the upstream gradient, copy
    ``i`` only receives the gradient of output ``i``.

    Args:
        net: Callable applied to a ``(b, noutputs, in_dim)`` tensor and
            returning a ``(b, noutputs, noutputs)`` tensor.  The trick
            assumes every input row is processed independently.
        x: Batch of inputs with shape ``(b, in_dim)``.
        noutputs: Number of outputs ``net`` produces per input row.

    Returns:
        Tensor of shape ``(b, noutputs, in_dim)`` holding one Jacobian per
        sample.
    """
    print(f"Exact in: x shape is {x.shape}")  # b, in_dim
    # Detach so the replicated tensor is a fresh leaf: the caller's graph and
    # ``x.grad`` are untouched even when the input already tracks gradients.
    x = x.detach().unsqueeze(1)  # b, 1, in_dim
    print(f"in: x shape is {x.shape}")
    n = x.size()[0]
    print(f"in: n is {n}")
    x = x.repeat(1, noutputs, 1)  # b, out_dim, in_dim
    print(f"in: x shape is {x.shape}")
    x.requires_grad_(True)
    y = net(x)
    # Identity on the output's dtype/device so non-default tensors work.
    eye = torch.eye(noutputs, dtype=y.dtype, device=y.device)
    input_val = eye.reshape(1, noutputs, noutputs).repeat(n, 1, 1)
    print(f"input_val shape is {input_val.shape}")
    # autograd.grad hands back d(y)/d(x) directly and leaves the ``.grad`` of
    # the network's parameters alone.
    (jacobian,) = torch.autograd.grad(y, x, grad_outputs=input_val)
    return jacobian

class Net(torch.nn.Module):
    """One fully connected layer followed by a ReLU activation."""

    def __init__(self, in_dim, out_dim):
        """Create the linear layer.

        Args:
            in_dim: Number of input features.
            out_dim: Number of output features.
        """
        # The constructor must be the dunder ``__init__``; a plain ``init``
        # method is never called by ``Net(...)`` and ``fc1`` would not exist.
        super().__init__()
        self.fc1 = torch.nn.Linear(in_dim, out_dim)

    def forward(self, x):
        """Apply the linear layer, then ReLU, to ``x``."""
        return torch.nn.functional.relu(self.fc1(x))

# Demo: run get_batch_jacobian on a random batch pushed through Net and print
# the shapes involved.
batch, num_features, num_outputs = 9000, 3, 5

# The inputs are drawn before the network is built, so the random stream is
# consumed in the same order on every run.
x = torch.randn(batch, num_features)
print(f"NEW: x shape is {x.shape}")
net = Net(num_features, num_outputs)

print("output dimension\n")
# Input shape first, then the shape of a plain forward pass.
print(x.shape, net(x).shape, sep="\n")
result = get_batch_jacobian(net, x, num_outputs)
print(result.shape)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment