Created April 5, 2024 17:17
Save jdriordan/6addb849ca5ac560ba68ad15acd4953d to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| (import torch) | |
| (import torch [nn]) | |
| (import torch.optim [SGD]) | |
| (import torch.utils.data [DataLoader]) | |
| (import torchvision [datasets transforms]) | |
| (import pytorch_lightning :as pl) | |
| ;; Define CNN module | |
| (defclass CNNModule [pl.LightningModule] | |
|   "Two-conv-layer CNN classifier (1 input channel, 10 output classes) trained with SGD." | |
|   (defn __init__ [self [lr 0.001] [momentum 0.9]] | |
|     "Build the network and loss; lr and momentum configure SGD in configure_optimizers." | |
|     (.__init__ (super)) | |
|     (setv self.lr lr) | |
|     (setv self.momentum momentum) | |
|     ;; The (* 64 5 5) flatten size assumes 1x28x28 inputs (as in MNIST): | |
|     ;; 28 -conv3-> 26 -pool2-> 13 -conv3-> 11 -pool2-> 5, with 64 channels. | |
|     (setv self.model (nn.Sequential | |
|                        (nn.Conv2d 1 32 3) | |
|                        (nn.ReLU) | |
|                        (nn.MaxPool2d 2 2) | |
|                        (nn.Conv2d 32 64 3) | |
|                        (nn.ReLU) | |
|                        (nn.MaxPool2d 2 2) | |
|                        (nn.Flatten) | |
|                        (nn.Linear (* 64 5 5) 128) | |
|                        (nn.ReLU) | |
|                        (nn.Linear 128 10))) | |
|     ;; CrossEntropyLoss applies log-softmax itself, so the model emits raw logits. | |
|     (setv self.criterion (nn.CrossEntropyLoss))) | |
|   (defn forward [self x] | |
|     "Return the 10 raw (un-softmaxed) class scores produced by the network for x." | |
|     ;; Call the module itself rather than .forward so nn.Module hooks still run. | |
|     (self.model x)) | |
|   (defn training_step [self batch batch_idx] | |
|     "Compute, log, and return the cross-entropy loss for one (inputs, labels) batch." | |
|     ;; batch_idx is unused but required by the Lightning hook signature. | |
|     (setv [inputs labels] batch) | |
|     (setv loss (self.criterion (self inputs) labels)) | |
|     (.log self "train_loss" loss :prog_bar True) | |
|     {"loss" loss}) | |
|   (defn configure_optimizers [self] | |
|     "Create a fresh SGD optimizer over all of this module's parameters." | |
|     ;; Built here, not in __init__, so every fit starts with clean optimizer state | |
|     ;; and strategies that wrap or shard the model see its final parameters. | |
|     (SGD (.parameters self) :lr self.lr :momentum self.momentum))) | |
| ;; Define main function | |
| (defn main [[max-epochs 5] [batch-size 4] [num-workers 2] [data-root "./data"]] | |
|   "Download MNIST into data-root if needed and train CNNModule on its training split." | |
|   ;; ToTensor scales pixel values to [0, 1]; Normalize with mean 0.5 / std 0.5 maps them to [-1, 1]. | |
|   (setv transform (transforms.Compose [(transforms.ToTensor) | |
|                                        (transforms.Normalize [0.5] [0.5])])) | |
|   (setv train-set (datasets.MNIST :root data-root :train True :transform transform :download True)) | |
|   (setv train-loader (DataLoader train-set | |
|                                  :batch-size batch-size | |
|                                  :shuffle True | |
|                                  :num-workers num-workers)) | |
|   (setv model (CNNModule)) | |
|   (setv trainer (pl.Trainer :max-epochs max-epochs)) | |
|   (.fit trainer model train-loader)) | |
| ;; Only train when run as a script: DataLoader workers started with the "spawn" | |
| ;; method (default on macOS/Windows) re-execute this file as __mp_main__. | |
| (when (= __name__ "__main__") | |
|   (main)) |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.