Skip to content

Instantly share code, notes, and snippets.

@laol777
Created June 5, 2019 12:47
Show Gist options
  • Select an option

  • Save laol777/696643bfa4845594938730051d640343 to your computer and use it in GitHub Desktop.

Select an option

Save laol777/696643bfa4845594938730051d640343 to your computer and use it in GitHub Desktop.
import simulation
import helper
import copy
import time
from collections import defaultdict
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets, models
import torch.nn.functional as F
import torch.nn as nn
import torch
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
from convcrf import convcrf
import ipdb
import matplotlib.pyplot as plt
# NOTE(review): these module-level arrays are regenerated further down (the
# "Generate some random images" step) before anything reads them, so this
# call looks redundant. It is left in place because it presumably advances
# the random state used to build the datasets below — confirm before deleting.
input_images, target_masks = simulation.generate_random_data(320, 320, count=3)
class SimDataset(Dataset):
    """In-memory dataset of synthetic (image, mask) pairs.

    All samples are produced up front by ``simulation.generate_random_data``
    at 320x320.

    Args:
        count: number of samples to generate.
        transform: optional callable applied to the image only; the mask is
            always returned unchanged.
    """

    def __init__(self, count, transform=None):
        generated = simulation.generate_random_data(320, 320, count=count)
        self.input_images, self.target_masks = generated
        self.transform = transform

    def __len__(self):
        """Return how many samples were generated."""
        return len(self.input_images)

    def __getitem__(self, idx):
        """Return ``[image, mask]`` for sample *idx*, transforming the image if configured."""
        sample_image = self.input_images[idx]
        sample_mask = self.target_masks[idx]
        transform = self.transform
        return [transform(sample_image) if transform else sample_image, sample_mask]
# Image-side preprocessing. ToTensor turns each HxWxC numpy image into a
# CxHxW tensor. SimDataset applies this to images only, never to masks.
trans = transforms.Compose([
    transforms.ToTensor(),
])

train_set = SimDataset(2000, transform=trans)
val_set = SimDataset(200, transform=trans)

image_datasets = {
    'train': train_set, 'val': val_set
}

# NOTE(review): one sample per batch. Presumably kept this small because the
# ConvCRF layer is memory-hungry at 320x320 — confirm before raising it.
batch_size = 1

dataloaders = {
    'train': DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0),
    # Validation metrics are summed over every sample, so their order does not
    # matter. A fixed order keeps validation passes reproducible.
    'val': DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=0),
}

dataset_sizes = {phase: len(ds) for phase, ds in image_datasets.items()}
# Generate a few fresh samples for a quick visual sanity check.
input_images, target_masks = simulation.generate_random_data(320, 320, count=3)

# Report the shape and value range of the images and of the masks.
for arr in [input_images, target_masks]:
    print(arr.shape)
    print(arr.min(), arr.max())

# Cast the images to uint8 for display, presumably the format that
# helper.plot_side_by_side expects. This only casts: it does not reorder or
# add channels.
input_images_rgb = [img.astype(np.uint8) for img in input_images]

# Map each mask channel (i.e. class) to its own color.
target_masks_rgb = [helper.masks_to_colorimg(mask) for mask in target_masks]

# Left: input image, right: target mask (ground truth).
helper.plot_side_by_side([input_images_rgb, target_masks_rgb])
def double_conv(in_channels, out_channels):
    """Build two 3x3 convolution + ReLU stages; padding=1 keeps H and W unchanged.

    Args:
        in_channels: channel count of the incoming feature map.
        out_channels: channel count produced by both convolutions.

    Returns:
        ``nn.Sequential(Conv2d, ReLU, Conv2d, ReLU)``.
    """
    layers = []
    # The first conv maps in->out channels; the second keeps out->out.
    for src_channels in (in_channels, out_channels):
        layers.append(nn.Conv2d(src_channels, out_channels, 3, padding=1))
        layers.append(nn.ReLU(inplace=True))
    return nn.Sequential(*layers)
class UNet(nn.Module):
    """Three-level U-Net with an optional trainable ConvCRF refinement stage.

    The encoder pools three times by 2 and the decoder upsamples three times
    by 2. Input height and width should therefore be divisible by 8, so the
    upsampled maps line up with their skip connections.

    Args:
        n_class: number of output channels (one score map per class).
        shape: spatial size (H, W) passed to the ConvCRF layer. Defaults to
            (320, 320), the value previously hard-coded here. Presumably it
            must match the input size when postprocessing is enabled.
    """

    def __init__(self, n_class, shape=(320, 320)):
        super().__init__()

        # Encoder: each stage doubles the channel count.
        self.dconv_down1 = double_conv(3, 64)
        self.dconv_down2 = double_conv(64, 128)
        self.dconv_down3 = double_conv(128, 256)
        self.dconv_down4 = double_conv(256, 512)

        self.maxpool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Decoder: inputs are upsampled features concatenated with the skip.
        self.dconv_up3 = double_conv(256 + 512, 256)
        self.dconv_up2 = double_conv(128 + 256, 128)
        self.dconv_up1 = double_conv(128 + 64, 64)

        self.conv_last = nn.Conv2d(64, n_class, 1)

        # Work on a private copy of the library defaults. Mutating
        # convcrf.default_conf in place would leak these overrides into every
        # other user of that module-level dict.
        config = copy.deepcopy(convcrf.default_conf)
        config['pyinn'] = False
        config['trainable'] = True
        config['trainable_bias'] = True
        self.convcrf = convcrf.GaussCRF(conf=config, shape=shape, nclasses=n_class)

        # CRF refinement is off by default; toggle it with postprocessing_state().
        self.postprocessing = False

    def forward(self, x):
        """Return per-class score maps for image batch *x*.

        The training loss treats these scores as logits. When postprocessing
        is enabled, the scores and the raw input batch are passed through the
        ConvCRF layer, and its output is returned instead.
        """
        x_origin = x  # raw input, handed to the CRF together with the scores

        # Encoder: keep each stage's output for the skip connections.
        conv1 = self.dconv_down1(x)
        x = self.maxpool(conv1)
        conv2 = self.dconv_down2(x)
        x = self.maxpool(conv2)
        conv3 = self.dconv_down3(x)
        x = self.maxpool(conv3)

        # Bottleneck.
        x = self.dconv_down4(x)

        # Decoder: upsample, concatenate the matching skip, convolve.
        x = self.upsample(x)
        x = torch.cat([x, conv3], dim=1)
        x = self.dconv_up3(x)
        x = self.upsample(x)
        x = torch.cat([x, conv2], dim=1)
        x = self.dconv_up2(x)
        x = self.upsample(x)
        x = torch.cat([x, conv1], dim=1)
        x = self.dconv_up1(x)

        out_x = self.conv_last(x)
        if self.postprocessing:
            out_x = self.convcrf(out_x, x_origin)
        return out_x

    def postprocessing_state(self, is_enamble=False):
        """Enable (True) or disable (False) the ConvCRF refinement in forward().

        The parameter's spelling is kept unchanged for keyword callers.
        """
        self.postprocessing = is_enamble
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Smoke test: build a 2-class network and push a single training batch
# through it without gradients, to check that the layer shapes line up.
model = UNet(2).to(device)
model.eval()

first_batch = next(iter(dataloaders['train']), None)
if first_batch is not None:
    inputs, labels = first_batch
    inputs, labels = inputs.to(device), labels.to(device)
    with torch.no_grad():
        res = model(inputs)
def dice_loss(pred, target, smooth=1.):
    """Return the soft Dice loss, averaged over every (sample, channel) pair.

    Assumes 4-D (N, C, H, W) tensors (TODO confirm for other callers). The
    two successive sums over dim 2 then collapse both spatial axes.

    Args:
        pred: predicted probabilities.
        target: ground-truth masks with the same shape as *pred*.
        smooth: additive smoothing in numerator and denominator; keeps the
            ratio defined for empty masks.

    Returns:
        Scalar tensor ``mean(1 - (2*|P∩G| + s) / (|P| + |G| + s))``.
    """
    def spatial_total(t):
        return t.sum(dim=2).sum(dim=2)

    p = pred.contiguous()
    g = target.contiguous()
    overlap = spatial_total(p * g)
    dice_coeff = (2. * overlap + smooth) / (spatial_total(p) + spatial_total(g) + smooth)
    return (1 - dice_coeff).mean()
def calc_loss(pred, target, metrics, bce_weight=0.5):
    """Blend BCE-with-logits and soft Dice loss, accumulating batch metrics.

    Args:
        pred: raw model outputs (logits). A sigmoid is applied here before
            the Dice term.
        target: ground-truth masks with the same shape as *pred*.
        metrics: mapping updated in place, e.g. ``defaultdict(float)``. Its
            'bce', 'dice' and 'loss' entries grow by the batch-size-weighted
            values, so dividing by the number of samples seen gives
            per-sample averages.
        bce_weight: weight of the BCE term; the Dice term gets
            ``1 - bce_weight``.

    Returns:
        The combined loss tensor, still attached to the autograd graph.
    """
    bce = F.binary_cross_entropy_with_logits(pred, target)
    dice = dice_loss(torch.sigmoid(pred), target)
    loss = bce * bce_weight + dice * (1 - bce_weight)

    # .item() yields a detached Python float. It replaces the legacy `.data`
    # attribute plus numpy round-trip, and accumulates in double precision.
    batch = target.size(0)
    metrics['bce'] += bce.item() * batch
    metrics['dice'] += dice.item() * batch
    metrics['loss'] += loss.item() * batch

    return loss
def print_metrics(metrics, epoch_samples, phase):
    """Print every accumulated metric divided by *epoch_samples* on one line.

    The output looks like ``"train: bce: 0.500000, loss: 1.500000"``, listing
    metrics in the mapping's iteration order.
    """
    summary = ", ".join(
        "{}: {:4f}".format(name, metrics[name] / epoch_samples)
        for name in metrics
    )
    print("{}: {}".format(phase, summary))
def _run_phase(model, optimizer, phase):
    """Make one pass over ``dataloaders[phase]``; return ``(metrics, n_samples)``.

    Gradients are tracked, and the optimizer is stepped, only when *phase* is
    ``'train'``. Reads the module-level ``dataloaders`` and ``device``.
    """
    metrics = defaultdict(float)
    epoch_samples = 0
    is_train = phase == 'train'

    for inputs, labels in dataloaders[phase]:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        # Track history only while training.
        with torch.set_grad_enabled(is_train):
            outputs = model(inputs)
            loss = calc_loss(outputs, labels, metrics)
            if is_train:
                loss.backward()
                optimizer.step()

        epoch_samples += inputs.size(0)

    return metrics, epoch_samples


def train_model(model, optimizer, scheduler, num_epochs=25):
    """Train *model* and keep the weights with the lowest validation loss.

    Each epoch runs a 'train' phase followed by a 'val' phase over the
    module-level ``dataloaders``. The LR scheduler is stepped once per epoch,
    after that epoch's optimizer updates.

    Args:
        model: network to train; modified in place.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: per-epoch learning-rate scheduler wrapping *optimizer*.
        num_epochs: number of epochs to run.

    Returns:
        *model*, with the best validation weights loaded.
    """
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = float('inf')

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        since = time.time()

        for phase in ['train', 'val']:
            if phase == 'train':
                for param_group in optimizer.param_groups:
                    print("LR", param_group['lr'])
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            metrics, epoch_samples = _run_phase(model, optimizer, phase)

            if phase == 'train':
                # Since PyTorch 1.1 the scheduler must step AFTER
                # optimizer.step(). Stepping it first skips the initial LR and
                # makes StepLR decay one epoch early.
                scheduler.step()

            if epoch_samples == 0:
                # An empty loader would otherwise divide by zero below.
                print("{}: no samples".format(phase))
                continue

            print_metrics(metrics, epoch_samples, phase)
            epoch_loss = metrics['loss'] / epoch_samples

            # Snapshot the weights whenever validation loss improves.
            if phase == 'val' and epoch_loss < best_loss:
                print("saving best model")
                best_loss = epoch_loss
                best_model_wts = copy.deepcopy(model.state_dict())

        time_elapsed = time.time() - since
        print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

    print('Best val loss: {:4f}'.format(best_loss))

    # Load the best model weights.
    model.load_state_dict(best_model_wts)
    return model
# --- Stage 1: train the plain U-Net (ConvCRF refinement disabled by default). ---
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

num_class = 2
model = UNet(num_class).to(device)

# All parameters are optimized, including those of the ConvCRF submodule.
optimizer_ft = optim.Adam(params=model.parameters(), lr=1e-4)
# Multiply the learning rate by 0.1 every 25 scheduler steps.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer=optimizer_ft, step_size=25, gamma=0.1)

model = train_model(model, optimizer_ft, exp_lr_scheduler, num_epochs=2)
# Sanity check: one no-grad forward pass of the stage-1 model on a training batch.
model.eval()

sample = next(iter(dataloaders['train']), None)
if sample is not None:
    inputs, labels = sample
    inputs, labels = inputs.to(device), labels.to(device)
    with torch.no_grad():
        res = model(inputs)
# --- Stage 2: fine-tune with the ConvCRF refinement switched on. ---
# The stage-1 optimizer and scheduler are reused. Adam's running moment
# estimates and the scheduler's step count therefore carry over.
print('TRAINING WITH CONVCRF')
model.postprocessing_state(True)
model = train_model(model, optimizer_ft, exp_lr_scheduler, num_epochs=2)
torch.save(model.state_dict(), 'model_artif.torch')
# --- Reload the saved weights into a fresh network for inference. ---
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
num_class = 2
model = UNet(num_class).to(device)
# map_location lets a checkpoint written on a GPU be restored on a CPU-only machine.
model.load_state_dict(torch.load('model_artif.torch', map_location=device))
# NOTE(review): a fresh UNet starts with postprocessing disabled. The ConvCRF
# stage trained above is therefore NOT applied at inference unless
# model.postprocessing_state(True) is called — confirm which is intended.
model.eval()
def reverse_transform(inp):
    """Convert a channels-first image tensor in [0, 1] to a channels-last uint8 array.

    Values outside [0, 1] are clipped and then scaled to [0, 255]. The uint8
    cast truncates rather than rounds.

    Args:
        inp: 3-D image tensor (C, H, W). It may live on any device or require
            grad; it is detached and copied to the CPU before conversion.

    Returns:
        numpy uint8 array of shape (H, W, C).
    """
    # Tensor.numpy() refuses CUDA tensors and tensors that require grad, so
    # normalise both cases first. Both calls are no-ops for plain CPU tensors.
    inp = inp.detach().cpu().numpy().transpose((1, 2, 0))
    inp = np.clip(inp, 0, 1)
    return (inp * 255).astype(np.uint8)
# --- Predict on a few fresh samples; plot input / ground truth / prediction. ---
model.eval()  # Set model to evaluate mode

test_dataset = SimDataset(3, transform=trans)
test_loader = DataLoader(test_dataset, batch_size=3, shuffle=False, num_workers=0)

inputs, labels = next(iter(test_loader))
inputs = inputs.to(device)
labels = labels.to(device)

# Inference only: building an autograd graph would just waste memory.
with torch.no_grad():
    pred = model(inputs)

# The training loss (BCE-with-logits plus Dice on sigmoid outputs) treats the
# network output as logits, so map it to probabilities before plotting.
# NOTE(review): presumably helper.masks_to_colorimg expects values in [0, 1].
pred = torch.sigmoid(pred).cpu().numpy()
print(pred.shape)

input_images_rgb = [reverse_transform(img) for img in inputs.cpu()]
target_masks_rgb = [helper.masks_to_colorimg(mask) for mask in labels.cpu().numpy()]
pred_rgb = [helper.masks_to_colorimg(p) for p in pred]

helper.plot_side_by_side([input_images_rgb, target_masks_rgb, pred_rgb])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment