Skip to content

Instantly share code, notes, and snippets.

@DN6
Last active October 18, 2021 17:10
Show Gist options
  • Select an option

  • Save DN6/6e3652080e7be6079d2d5b1737d6b747 to your computer and use it in GitHub Desktop.

Select an option

Save DN6/6e3652080e7be6079d2d5b1737d6b747 to your computer and use it in GitHub Desktop.
Catalyst + Comet
import comet_ml
import os
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from catalyst import dl
from catalyst.data import ToTensor
from catalyst.contrib.datasets import MNIST
from torch.utils.data import DataLoader
# Hyperparameters used by the optimizer and the data loaders; the same dict is
# also handed to ``runner.train(hparams=...)``.
hparams = {"lr": 1.0e-3, "batch_size": 32}

# Minimal classifier: flatten each 28x28 image and map it linearly to 10 logits.
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 10),
)
# Multi-class classification loss applied to the raw logits.
criterion = nn.CrossEntropyLoss()
# Comet experiment logger, attached to the run via ``loggers={"comet": ...}``.
logger = dl.CometLogger()
optimizer = optim.Adam(model.parameters(), lr=hparams["lr"])
# MNIST train/validation loaders keyed by the names the runner refers to
# ("valid" is selected via ``valid_loader="valid"``). ``download=True`` fetches
# the data into the current working directory (presumably skipped when the
# files already exist - verify against the MNIST dataset implementation).
loaders = {
    "train": DataLoader(
        MNIST(os.getcwd(), train=True, download=True, transform=ToTensor()),
        batch_size=hparams["batch_size"],
        # DataLoader defaults to shuffle=False; reshuffle the training split
        # every epoch so minibatches are not presented in one fixed order.
        shuffle=True,
    ),
    "valid": DataLoader(
        MNIST(os.getcwd(), train=False, download=True, transform=ToTensor()),
        batch_size=hparams["batch_size"],
        # Validation keeps the default fixed order so evaluation is reproducible.
    ),
}
# Supervised runner. These key names label the model input, model output,
# target and loss in the runner's batch; the metric callbacks passed to
# ``runner.train`` read "logits" and "targets" under exactly these names.
runner = dl.SupervisedRunner(
    input_key="features",
    output_key="logits",
    target_key="targets",
    loss_key="loss",
)
# Model training: one epoch, validated on the "valid" loader. The run with the
# lowest validation "loss" is treated as best and reloaded at the end
# (load_best_on_end=True); artifacts go under ./logs and Comet is attached as a
# logger.
runner.train(
    # what is optimized
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    # data and schedule
    loaders=loaders,
    valid_loader="valid",
    num_epochs=1,
    # model selection: smaller validation loss is better
    valid_metric="loss",
    minimize_valid_metric=True,
    load_best_on_end=True,
    # metrics computed from the "logits" output against the "targets" labels
    callbacks=[
        dl.AccuracyCallback(
            input_key="logits",
            target_key="targets",
            topk_args=(1, 3, 5),
        ),
        dl.PrecisionRecallF1SupportCallback(
            input_key="logits",
            target_key="targets",
            num_classes=10,
        ),
    ],
    # logging / experiment tracking
    hparams=hparams,
    logdir="./logs",
    loggers={"comet": logger},
    verbose=True,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment