Created November 29, 2019 11:46
-
-
Save galatolofederico/32bd91ec784a5089a3628b46f7ee5395 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import os | |
| import torch | |
| from torch.nn import functional as F | |
| from torch.utils.data import DataLoader | |
| from torchvision.datasets import MNIST | |
| from torchvision import transforms | |
| from argparse import ArgumentParser | |
| from pytorch_lightning import Trainer | |
| import pytorch_lightning as pl | |
def mean_lod(key, lod):
    """Return the arithmetic mean of ``d[key]`` across every dict ``d`` in ``lod``.

    ``lod`` is a "list of dicts". A dict missing ``key`` raises ``KeyError``;
    an empty ``lod`` raises ``ZeroDivisionError``.
    """
    total = sum(entry[key] for entry in lod)
    return total / len(lod)
class System(pl.LightningModule):
    """Small MLP classifier for MNIST digits, wrapped as a LightningModule.

    ``hparams`` must expose an ``lr`` attribute (read by
    ``configure_optimizers``), e.g. the ``argparse.Namespace`` produced by the
    script's ``--lr`` option.

    NOTE(review): ``validation_end``, ``test_end`` and ``@pl.data_loader`` are
    hooks from early PyTorch Lightning releases; later releases renamed or
    removed them -- confirm against the installed version.
    """

    def __init__(self, hparams):
        super(System, self).__init__()
        # The network ends in raw logits (no Softmax layer): CrossEntropyLoss
        # applies log-softmax internally, so a trailing Softmax would normalize
        # twice, squash the loss inputs and weaken the gradients.
        self.net = torch.nn.Sequential(
            torch.nn.BatchNorm1d(784),
            torch.nn.Linear(784, 100),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(100, 100),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(100, 10),
        )
        self.loss_fn = torch.nn.CrossEntropyLoss()
        self.transforms = transforms.Compose([
            transforms.ToTensor(),
            # ToTensor yields a (1, 28, 28) tensor for a grayscale MNIST image;
            # flatten it to the 784-vector the first layer expects.
            transforms.Lambda(lambda x: x.view(x.shape[0], -1)[0, :]),
        ])
        self.hparams = hparams

    def forward(self, x):
        """Return class probabilities of shape (batch, 10) for flattened images ``x``.

        Softmax is applied here, outside ``self.net``, so callers still get
        probabilities while the losses below are computed on raw logits.
        """
        return F.softmax(self.net(x), dim=1)

    def _step_metrics(self, batch):
        """Return ``(loss, correct_count, batch_size)`` for one ``(x, y)`` batch."""
        x, y = batch
        logits = self.net(x)
        loss = self.loss_fn(logits, y)
        # argmax of logits equals argmax of softmax probabilities.
        correct = (logits.argmax(dim=1) == y).float().sum()
        return loss, correct, y.shape[0]

    def training_step(self, batch, batch_nb):
        """Compute the training loss and batch accuracy for the progress bar and logger."""
        loss, correct, size = self._step_metrics(batch)
        acc = correct / size
        return {
            "loss": loss,
            "progress_bar": {
                "loss": loss,
                "acc": acc,
            },
            "log": {
                "loss": loss.item(),
                "acc": acc.item(),
            },
        }

    def validation_step(self, batch, batch_nb):
        """Return the batch's mean loss, accuracy and size as plain Python numbers."""
        loss, correct, size = self._step_metrics(batch)
        return {
            "val_loss": loss.item(),
            "val_acc": (correct / size).item(),
            # Kept so validation_end can weight each batch; the last batch of
            # the split may be smaller than batch_size.
            "size": size,
        }

    def validation_end(self, outputs):
        """Aggregate per-batch validation metrics into size-weighted means for the logger."""
        total = sum(output["size"] for output in outputs)
        return {
            "log": {
                "val_accuracy": sum(o["val_acc"] * o["size"] for o in outputs) / total,
                "val_loss": sum(o["val_loss"] * o["size"] for o in outputs) / total,
            }
        }

    def test_step(self, batch, batch_nb):
        """Return the batch size, the number of correct predictions and the batch accuracy."""
        _, correct, size = self._step_metrics(batch)
        return {
            "size": size,
            "rights": correct,
            "accuracy": correct / size,
        }

    def test_end(self, outputs):
        """Print and report overall test accuracy (correct predictions / total samples)."""
        rights = sum(output["rights"].item() for output in outputs)
        total = sum(output["size"] for output in outputs)
        accuracy = rights / total
        print("Accuracy :", accuracy)
        return {
            "log": {
                "test_accuracy": accuracy
            },
            "progress_bar": {
                "test_accuracy": accuracy
            },
        }

    def configure_optimizers(self):
        """Use Adam over the network parameters with learning rate ``hparams.lr``."""
        return torch.optim.Adam(self.net.parameters(), lr=self.hparams.lr)

    def _mnist_loader(self, train, shuffle=False):
        """Build a batch-size-32 DataLoader over MNIST, downloaded into the working directory."""
        dataset = MNIST(os.getcwd(), train=train, download=True, transform=self.transforms)
        return DataLoader(dataset, batch_size=32, shuffle=shuffle)

    @pl.data_loader
    def train_dataloader(self):
        # Reshuffle every epoch so the model does not see identical batches.
        return self._mnist_loader(train=True, shuffle=True)

    # NOTE(review): validation and test both read the MNIST test split
    # (train=False), so validation metrics are not held out from the test set.
    @pl.data_loader
    def val_dataloader(self):
        return self._mnist_loader(train=False)

    @pl.data_loader
    def test_dataloader(self):
        return self._mnist_loader(train=False)
def main():
    """Parse ``--lr``, then train and test ``System`` with a default Trainer."""
    # Default ArgumentParser keeps ``--help`` available; the original
    # ``add_help=False`` is only needed when a parser is used as a parent.
    parser = ArgumentParser()
    parser.add_argument("--lr", type=float, default=.001,
                        help="learning rate for the Adam optimizer")
    model = System(parser.parse_args())
    trainer = Trainer()
    trainer.fit(model)
    trainer.test(model)


# Guard so importing this module does not start a training run.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.