Skip to content

Instantly share code, notes, and snippets.

@galatolofederico
Created November 29, 2019 11:46
Show Gist options
  • Select an option

  • Save galatolofederico/32bd91ec784a5089a3628b46f7ee5395 to your computer and use it in GitHub Desktop.

Select an option

Save galatolofederico/32bd91ec784a5089a3628b46f7ee5395 to your computer and use it in GitHub Desktop.
import os
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
from argparse import ArgumentParser
from pytorch_lightning import Trainer
import pytorch_lightning as pl
def mean_lod(key, lod):
    """Return the mean of ``d[key]`` over a list of dicts ``lod``.

    Raises ZeroDivisionError if ``lod`` is empty (same as the manual
    accumulator it replaces) and KeyError if any dict lacks ``key``.
    """
    return sum(d[key] for d in lod) / len(lod)
class System(pl.LightningModule):
    """LightningModule training a small MLP classifier on MNIST.

    The network consumes flattened 784-pixel images and produces 10-way
    class scores.  ``forward`` returns softmax probabilities; the
    train/val/test steps compute the loss on raw logits, because
    ``CrossEntropyLoss`` applies log-softmax internally — the original
    code fed softmax outputs into it, applying softmax twice and
    flattening the gradients.
    """

    def __init__(self, hparams):
        super(System, self).__init__()
        # NOTE: no Softmax layer here on purpose — the Sequential emits
        # raw logits so CrossEntropyLoss can apply log-softmax itself.
        self.net = torch.nn.Sequential(
            torch.nn.BatchNorm1d(784),
            torch.nn.Linear(784, 100),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(100, 100),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(100, 10),
        )
        self.loss_fn = torch.nn.CrossEntropyLoss()
        # Flatten each (1, 28, 28) MNIST image into a 784-vector.
        self.transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.view(x.shape[0], -1)[0, :])
        ])
        self.hparams = hparams

    def forward(self, x):
        """Return per-class probabilities for a batch of flat images.

        Kept probability-valued (as the original Softmax head was) so
        external callers of ``forward`` see the same output semantics.
        """
        return F.softmax(self.net(x), dim=1)

    def training_step(self, batch, batch_nb):
        """One optimization step: cross-entropy loss plus batch accuracy."""
        x, y = batch
        logits = self.net(x)  # loss wants logits, not probabilities
        acc = (logits.max(dim=1).indices == y).float().mean()
        loss = self.loss_fn(logits, y)
        return {
            "loss": loss,
            "progress_bar": {
                "loss": loss,
                "acc": acc,
            },
            "log": {
                "loss": loss.item(),
                "acc": acc.item()
            }
        }

    def validation_step(self, batch, batch_nb):
        """Per-batch validation metrics, aggregated in validation_end."""
        x, y = batch
        logits = self.net(x)
        acc = (logits.max(dim=1).indices == y).float().mean()
        loss = self.loss_fn(logits, y)
        return {
            "val_loss": loss.item(),
            "val_acc": acc.item()
        }

    def validation_end(self, outputs):
        """Average the per-batch validation metrics for the logger."""
        return {
            "log": {
                "val_accuracy": mean_lod("val_acc", outputs),
                "val_loss": mean_lod("val_loss", outputs),
            }
        }

    def test_step(self, batch, batch_nb):
        """Count correct predictions so test_end can weight by batch size."""
        x, y = batch
        logits = self.net(x)
        rights = (logits.max(dim=1).indices == y).float().sum()
        return {
            "size": logits.shape[0],
            "rights": rights,
            "accuracy": rights / logits.shape[0]
        }

    def test_end(self, outputs):
        """Report overall test accuracy, weighted by batch size."""
        correct = 0
        total = 0
        for output in outputs:
            correct += output["rights"].item()
            total += output["size"]
        print("Accuracy :", correct / total)
        return {
            "log": {
                "test_accuracy": correct / total
            },
            "progress_bar": {
                "test_accuracy": correct / total
            }
        }

    def configure_optimizers(self):
        """Adam over the network parameters with the CLI learning rate."""
        return torch.optim.Adam(self.net.parameters(), lr=self.hparams.lr)

    @pl.data_loader
    def train_dataloader(self):
        # Downloads MNIST into the current working directory on first use.
        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=self.transforms), batch_size=32)

    @pl.data_loader
    def val_dataloader(self):
        # NOTE(review): validation reuses the MNIST test split; there is
        # no held-out validation set distinct from test.
        return DataLoader(MNIST(os.getcwd(), train=False, download=True, transform=self.transforms), batch_size=32)

    @pl.data_loader
    def test_dataloader(self):
        return DataLoader(MNIST(os.getcwd(), train=False, download=True, transform=self.transforms), batch_size=32)
def main():
    """Parse CLI args, train the MNIST system, then evaluate it.

    Only ``--lr`` (Adam learning rate, default 0.001) is configurable.
    """
    parser = ArgumentParser(add_help=False)
    parser.add_argument("--lr", type=float, default=.001)
    model = System(parser.parse_args())
    trainer = Trainer()
    trainer.fit(model)
    trainer.test(model)


# Guard the entry point so importing this module does not kick off
# a full training run as a side effect.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment