Created
June 5, 2019 12:47
-
-
Save laol777/696643bfa4845594938730051d640343 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import simulation | |
| import helper | |
| import copy | |
| import time | |
| from collections import defaultdict | |
| from torch.utils.data import Dataset, DataLoader | |
| from torchvision import transforms, datasets, models | |
| import torch.nn.functional as F | |
| import torch.nn as nn | |
| import torch | |
| import torch.optim as optim | |
| from torch.optim import lr_scheduler | |
| import numpy as np | |
| from convcrf import convcrf | |
| import ipdb | |
| import matplotlib.pyplot as plt | |
# NOTE(review): this module-level call appears to be dead code — its results
# are never used before being regenerated at the visualization step below,
# and the datasets generate their own data; consider removing it.
input_images, target_masks = simulation.generate_random_data(320, 320, count=3)
class SimDataset(Dataset):
    """Synthetic segmentation dataset backed by simulation.generate_random_data.

    Each item is a [image, mask] pair; the optional transform is applied
    to the image only.
    """

    def __init__(self, count, transform=None):
        images, masks = simulation.generate_random_data(320, 320, count=count)
        self.input_images = images
        self.target_masks = masks
        self.transform = transform

    def __len__(self):
        return len(self.input_images)

    def __getitem__(self, idx):
        img, msk = self.input_images[idx], self.target_masks[idx]
        if self.transform:
            img = self.transform(img)
        return [img, msk]
# Convert HWC uint8 arrays to CHW float tensors in [0, 1].
trans = transforms.Compose([
    transforms.ToTensor(),
])

train_set = SimDataset(2000, transform=trans)
val_set = SimDataset(200, transform=trans)
# Tiny sets for quick smoke tests:
# train_set = SimDataset(4, transform=trans)
# val_set = SimDataset(2, transform=trans)

image_datasets = {
    'train': train_set, 'val': val_set
}

# FIX: the original assigned batch_size = 25 and immediately shadowed it with
# 1 — the effective value was always 1, so the dead assignment is removed.
batch_size = 1

dataloaders = {
    'train': DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0),
    'val': DataLoader(val_set, batch_size=batch_size, shuffle=True, num_workers=0),
}

dataset_sizes = {
    x: len(image_datasets[x]) for x in image_datasets
}
# Create a handful of random samples and inspect their ranges.
input_images, target_masks = simulation.generate_random_data(320, 320, count=3)
# target_masks = target_masks[:, :2, :, :]

for arr in (input_images, target_masks):
    print(arr.shape)
    print(arr.min(), arr.max())

# Cast images for matplotlib display.
input_images_rgb = [img.astype(np.uint8) for img in input_images]

# Map each mask channel (i.e. class) to its own color.
target_masks_rgb = [helper.masks_to_colorimg(msk) for msk in target_masks]

# Left: input image, right: target mask (ground truth).
helper.plot_side_by_side([input_images_rgb, target_masks_rgb])
def double_conv(in_channels, out_channels):
    """Two stacked 3x3 same-padding convolutions, each followed by ReLU."""
    layers = [
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)
class UNet(nn.Module):
    """UNet encoder/decoder with an optional trainable ConvCRF head.

    Three 2x max-poolings mean spatial dims must be divisible by 8; the
    ConvCRF is constructed for fixed 320x320 inputs — TODO confirm other
    shapes are not fed through when postprocessing is enabled.
    """

    def __init__(self, n_class):
        super().__init__()

        # Encoder: channel widths 64 -> 128 -> 256 -> 512.
        self.dconv_down1 = double_conv(3, 64)
        self.dconv_down2 = double_conv(64, 128)
        self.dconv_down3 = double_conv(128, 256)
        self.dconv_down4 = double_conv(256, 512)

        self.maxpool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Decoder: each stage consumes the upsampled features concatenated
        # with the matching encoder skip connection.
        self.dconv_up3 = double_conv(256 + 512, 256)
        self.dconv_up2 = double_conv(128 + 256, 128)
        self.dconv_up1 = double_conv(128 + 64, 64)

        self.conv_last = nn.Conv2d(64, n_class, 1)

        shape = (320, 320)
        # BUG FIX: the original did `config = convcrf.default_conf` and then
        # mutated it, altering the shared module-level default config for
        # every other GaussCRF built afterwards. Deep-copy it instead.
        config = copy.deepcopy(convcrf.default_conf)
        config['pyinn'] = False
        config['trainable'] = True
        config['trainable_bias'] = True
        self.convcrf = convcrf.GaussCRF(conf=config, shape=shape, nclasses=n_class)
        # CRF post-processing is off until postprocessing_state(True) is called.
        self.postprocessing = False

    def forward(self, x):
        # Keep the raw input: the CRF conditions on the original image.
        x_origin = x

        conv1 = self.dconv_down1(x)
        x = self.maxpool(conv1)

        conv2 = self.dconv_down2(x)
        x = self.maxpool(conv2)

        conv3 = self.dconv_down3(x)
        x = self.maxpool(conv3)

        x = self.dconv_down4(x)

        x = self.upsample(x)
        x = torch.cat([x, conv3], dim=1)
        x = self.dconv_up3(x)

        x = self.upsample(x)
        x = torch.cat([x, conv2], dim=1)
        x = self.dconv_up2(x)

        x = self.upsample(x)
        x = torch.cat([x, conv1], dim=1)
        x = self.dconv_up1(x)

        out_x = self.conv_last(x)

        if self.postprocessing:
            # Refine the raw logits with the ConvCRF, conditioned on the input.
            # out_x = torch.clamp(out_x, 0, 1)
            out_x = self.convcrf(out_x, x_origin)

        return out_x

    def postprocessing_state(self, is_enamble=False):
        # Toggle the ConvCRF refinement step in forward().
        # NOTE: the misspelled parameter name ("is_enamble") is kept so that
        # existing keyword callers are not broken.
        self.postprocessing = is_enamble
# Sanity check: run one untrained forward pass on a single training batch.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = UNet(2)
model = model.to(device)
model.eval()

for inputs, labels in dataloaders['train']:
    inputs = inputs.to(device)
    labels = labels.to(device)
    # No gradients needed for a smoke-test inference.
    with torch.no_grad():
        res = model(inputs)
    # Only one batch is needed to verify the model runs end to end.
    break
def dice_loss(pred, target, smooth=1.):
    """Soft Dice loss averaged over batch and channels (NCHW inputs).

    `smooth` stabilizes the ratio when both prediction and target are empty.
    """
    pred = pred.contiguous()
    target = target.contiguous()

    # Reduce the two spatial dims, leaving one score per (batch, channel).
    overlap = (pred * target).sum(dim=2).sum(dim=2)
    denom = pred.sum(dim=2).sum(dim=2) + target.sum(dim=2).sum(dim=2)
    score = (2. * overlap + smooth) / (denom + smooth)

    return (1 - score).mean()
def calc_loss(pred, target, metrics, bce_weight=0.5):
    """Weighted BCE + soft-Dice loss on raw logits.

    Accumulates batch-size-weighted sums into `metrics` under the keys
    'bce', 'dice', and 'loss'; returns the combined loss tensor.
    """
    bce = F.binary_cross_entropy_with_logits(pred, target)

    # Dice operates on probabilities, so squash the logits first.
    pred = torch.sigmoid(pred)
    dice = dice_loss(pred, target)

    loss = bce * bce_weight + dice * (1 - bce_weight)

    # FIX: use .item() instead of the deprecated .data.cpu().numpy() —
    # identical scalar values without going through the legacy .data attribute.
    n = target.size(0)
    metrics['bce'] += bce.item() * n
    metrics['dice'] += dice.item() * n
    metrics['loss'] += loss.item() * n

    return loss
def print_metrics(metrics, epoch_samples, phase):
    """Print the per-sample metric averages for one phase on a single line."""
    summary = ", ".join(
        "{}: {:4f}".format(key, total / epoch_samples)
        for key, total in metrics.items()
    )
    print("{}: {}".format(phase, summary))
def train_model(model, optimizer, scheduler, num_epochs=25):
    """Train and validate `model`, keeping the best-validation-loss weights.

    Relies on the module-level `dataloaders` and `device`. Returns the
    model with the best weights loaded back in.
    """
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = 1e10

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        since = time.time()

        # Each epoch has a training and a validation phase.
        for phase in ['train', 'val']:
            if phase == 'train':
                for param_group in optimizer.param_groups:
                    print("LR", param_group['lr'])
                model.train()  # training mode
            else:
                model.eval()  # evaluation mode

            metrics = defaultdict(float)
            epoch_samples = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Zero the parameter gradients.
                optimizer.zero_grad()

                # Track gradient history only while training.
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss = calc_loss(outputs, labels, metrics)

                    # Backward + optimize only in the training phase.
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                epoch_samples += inputs.size(0)

            if phase == 'train':
                # FIX: since PyTorch 1.1 the LR scheduler must be stepped
                # AFTER the epoch's optimizer steps; the original stepped it
                # before training, decaying the LR one epoch early.
                scheduler.step()

            print_metrics(metrics, epoch_samples, phase)
            epoch_loss = metrics['loss'] / epoch_samples

            # Deep-copy the weights whenever validation loss improves.
            if phase == 'val' and epoch_loss < best_loss:
                print("saving best model")
                best_loss = epoch_loss
                best_model_wts = copy.deepcopy(model.state_dict())

        time_elapsed = time.time() - since
        print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

    print('Best val loss: {:4f}'.format(best_loss))

    # Load the best weights before returning.
    model.load_state_dict(best_model_wts)
    return model
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

num_class = 2
model = UNet(num_class).to(device)

# Observe that all parameters are being optimized.
optimizer_ft = optim.Adam(model.parameters(), lr=1e-4)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=25, gamma=0.1)

# Stage 1: train the plain UNet (CRF head disabled).
model = train_model(model, optimizer_ft, exp_lr_scheduler, num_epochs=2)

# Quick sanity forward pass on one training batch.
model.eval()
for inputs, labels in dataloaders['train']:
    inputs = inputs.to(device)
    labels = labels.to(device)
    with torch.no_grad():
        res = model(inputs)
    break

# Stage 2: fine-tune with ConvCRF post-processing enabled.
# FIX: corrected the misspelled status message ("TRAININ").
print('TRAINING WITH CONVCRF')
model.postprocessing_state(True)
model = train_model(model, optimizer_ft, exp_lr_scheduler, num_epochs=2)

torch.save(model.state_dict(), 'model_artif.torch')

# Reload the saved weights into a fresh model for evaluation.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

num_class = 2
model = UNet(num_class).to(device)
model.load_state_dict(torch.load('model_artif.torch'))
model.eval()
def reverse_transform(inp):
    """CHW float tensor in [0, 1] -> HWC uint8 numpy image."""
    arr = inp.numpy().transpose((1, 2, 0))
    arr = np.clip(arr, 0, 1)
    return (arr * 255).astype(np.uint8)
# Evaluate the trained model on a tiny held-out batch and visualize results.
model.eval()  # set model to evaluate mode

test_dataset = SimDataset(3, transform=trans)
test_loader = DataLoader(test_dataset, batch_size=3, shuffle=False, num_workers=0)

inputs, labels = next(iter(test_loader))
inputs = inputs.to(device)
labels = labels.to(device)

# FIX: run inference under no_grad — the original built an unnecessary
# autograd graph — and drop the deprecated .data access.
with torch.no_grad():
    pred = model(inputs)
pred = pred.cpu().numpy()
print(pred.shape)

# Undo the ToTensor scaling for display.
input_images_rgb = [reverse_transform(x) for x in inputs.cpu()]

# Map each mask channel (i.e. class) to a color.
target_masks_rgb = [helper.masks_to_colorimg(x) for x in labels.cpu().numpy()]
# NOTE(review): raw network outputs (logits) are passed here without a
# sigmoid — confirm helper.masks_to_colorimg thresholds them as intended.
pred_rgb = [helper.masks_to_colorimg(x) for x in pred]

# Columns: input image, ground-truth mask, predicted mask.
helper.plot_side_by_side([input_images_rgb, target_masks_rgb, pred_rgb])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment