Skip to content

Instantly share code, notes, and snippets.

@jdriordan
Created April 5, 2024 17:17
Show Gist options
  • Select an option

  • Save jdriordan/6addb849ca5ac560ba68ad15acd4953d to your computer and use it in GitHub Desktop.

Select an option

Save jdriordan/6addb849ca5ac560ba68ad15acd4953d to your computer and use it in GitHub Desktop.
(import torch)
(import torch [nn])
(import torch.optim [SGD])
(import torch.utils.data [DataLoader])
(import torchvision [datasets transforms])
(import pytorch_lightning :as pl)
;; CNN image classifier (two conv/pool stages + two linear layers) as a PyTorch Lightning module
(defclass CNNModule [pl.LightningModule]
  "Small convolutional classifier for 1-channel images, trained with Lightning.

  Two Conv2d/ReLU/MaxPool2d stages feed a two-layer MLP producing 10 logits.
  The `(* 64 5 5)` input width of the first Linear layer assumes 28x28
  inputs (as produced by the MNIST dataset used in `main`); other image
  sizes will fail with a shape mismatch at that layer."

  (defn __init__ [self]
    ((. (super) __init__))
    ;; Shape comments assume a (N, 1, 28, 28) input batch.
    (setv self.model (nn.Sequential
                       (nn.Conv2d 1 32 3)     ; -> (N, 32, 26, 26)
                       (nn.ReLU)
                       (nn.MaxPool2d 2 2)     ; -> (N, 32, 13, 13)
                       (nn.Conv2d 32 64 3)    ; -> (N, 64, 11, 11)
                       (nn.ReLU)
                       (nn.MaxPool2d 2 2)     ; -> (N, 64, 5, 5), floor division
                       (nn.Flatten)           ; -> (N, 1600)
                       (nn.Linear (* 64 5 5) 128)
                       (nn.ReLU)
                       (nn.Linear 128 10)))   ; -> (N, 10) logits
    (setv self.criterion (nn.CrossEntropyLoss))
    ;; Optimizer hyperparameters. The optimizer itself is created in
    ;; `configure_optimizers` so it is built after Lightning has set up the
    ;; model (some strategies wrap or shard parameters before that call).
    (setv self.lr 0.001)
    (setv self.momentum 0.9))

  (defn forward [self x]
    "Return the 10-class logits computed by `self.model` for the batch `x`."
    ;; Call the module instead of `.forward` so nn.Module hooks still run.
    (self.model x))

  (defn training_step [self batch batch_idx]
    "Compute and log the cross-entropy loss for one `[inputs labels]` batch."
    (setv [inputs labels] batch)
    (setv loss (self.criterion (self inputs) labels))
    (self.log "train_loss" loss :prog_bar True)
    {"loss" loss})

  (defn configure_optimizers [self]
    "Create SGD with momentum over all of this module's parameters."
    (SGD (.parameters self) :lr self.lr :momentum self.momentum)))
;; Define main function
(defn main []
  "Download MNIST into ./data and train a CNNModule on it for 5 epochs."
  ;; ToTensor converts images to float tensors; Normalize then applies
  ;; (x - 0.5) / 0.5 with one mean/std value per channel (MNIST has one).
  (setv transform (transforms.Compose [(transforms.ToTensor)
                                       (transforms.Normalize [0.5] [0.5])]))
  (setv train-set (datasets.MNIST :root "./data" :train True
                                  :transform transform :download True))
  ;; num_workers > 0 starts worker processes; see the __main__ guard below.
  (setv train-loader (DataLoader train-set :batch_size 4 :shuffle True
                                 :num_workers 2))
  (setv model (CNNModule))
  (setv trainer (pl.Trainer :max_epochs 5))
  (.fit trainer model train-loader))

;; Only train when executed as a script. DataLoader workers started with the
;; "spawn" method (default on Windows and macOS) re-import this module, and an
;; unguarded `(main)` would re-run in each worker and fail during bootstrapping.
(when (= __name__ "__main__")
  (main))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment