Skip to content

Instantly share code, notes, and snippets.

@devanshuDesai
Last active September 10, 2024 16:23
Show Gist options
  • Select an option

  • Save devanshuDesai/9f06681d8939afd04f8fab5ac5f5dbf8 to your computer and use it in GitHub Desktop.

Select an option

Save devanshuDesai/9f06681d8939afd04f8fab5ac5f5dbf8 to your computer and use it in GitHub Desktop.
A convolutional neural network written in PyTorch that reaches over 99% accuracy on the MNIST dataset.
class CNN(nn.Module):
    """Small convolutional image classifier.

    Architecture:
        [Conv2d(k=5) -> ReLU -> MaxPool2d(k=2)]  (C   -> 32 channels)
        [Conv2d(k=5) -> ReLU -> MaxPool2d(k=2)]  (32  -> 64 channels)
        flatten -> Linear(64 * H' * W' -> num_classes)

    where H', W' are the spatial sizes left after both conv/pool stages.
    The layout was written for MNIST-sized inputs of shape (1, 28, 28),
    for which H' = W' = 4.
    """

    def __init__(self, input_size, num_classes):
        """
        Init the convolution, activation, pooling and linear layers.

        Args:
            input_size: (C, H, W) shape of one input image, e.g. (1, 28, 28).
                Only C is strictly required; if H and W are not given, the
                original 28x28 assumption is used so older callers that only
                supplied the channel count keep working.
            num_classes: number of output classes, e.g. 10.

        Raises:
            ValueError: if H or W is too small to leave at least one pixel
                after both conv/pool stages (i.e. smaller than 16).
        """
        super().__init__()
        in_channels = input_size[0]
        if len(input_size) >= 3:
            height, width = input_size[1], input_size[2]
        else:
            # Backward-compatible default: the layer sizes were originally
            # hard-coded for 28x28 inputs.
            height, width = 28, 28

        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2))

        # Derive the flattened feature count from the input size instead of
        # hard-coding 4 * 4 * 64 (which is only correct for 28x28 inputs).
        out_h = self._spatial_out(height)
        out_w = self._spatial_out(width)
        if out_h < 1 or out_w < 1:
            raise ValueError(
                f"input spatial size {height}x{width} is too small; "
                "both H and W must be at least 16 for this architecture")
        self.fc1 = nn.Linear(out_h * out_w * 64, num_classes)

    @staticmethod
    def _spatial_out(n):
        """Return the size of one spatial dim after layer1 and layer2."""
        for _ in range(2):
            # Conv2d(kernel_size=5), stride 1, no padding: shrinks by 4.
            n = n - 4
            # MaxPool2d(kernel_size=2): stride defaults to kernel_size and
            # the output size is floored.
            n = n // 2
        return n

    def forward(self, x):
        """
        Map a batch of images to per-class scores.

        Args:
            x: (N, C, H, W) tensor matching the ``input_size`` given at
                construction, e.g. (N, 1, 28, 28).

        Returns:
            (N, num_classes) tensor of raw linear outputs; no softmax is
            applied here (presumably paired with a loss such as
            CrossEntropyLoss by the caller — confirm).
        """
        x = self.layer1(x)   # (N, 32, (H-4)//2, (W-4)//2)
        x = self.layer2(x)   # (N, 64, H', W')
        x = x.reshape(x.size(0), -1)  # flatten all but the batch dim
        x = self.fc1(x)
        return x
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment