Created
July 29, 2020 05:14
-
-
Save otakbeku/c3aed0f2f797177859aa2be7854fc6cb to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # XAI | |
| from captum.attr import Saliency | |
| from captum.attr import visualization as viz | |
| import torchvision | |
def attribute_image_features(model, algorithm, input_data, target_label, **kwargs):
    """Compute attributions of ``input_data`` for ``target_label``.

    Clears the model's gradients, then delegates to
    ``algorithm.attribute``.

    Args:
        model: network whose ``zero_grad()`` is called before attribution.
        algorithm: object exposing ``attribute(inputs, target=..., **kwargs)``
            (e.g. a Captum ``Saliency`` instance).
        input_data: value passed positionally as the attribution input.
        target_label: forwarded to ``algorithm.attribute`` as ``target=``.
        **kwargs: any extra options, forwarded unchanged.

    Returns:
        Whatever ``algorithm.attribute`` returns.
    """
    model.zero_grad()
    return algorithm.attribute(input_data, target=target_label, **kwargs)
# Model is already loaded. Assumes `model`, `inception_v3_funed`,
# `valid_loader_inception` and `classes` were created earlier (e.g. in previous
# notebook cells), and that the model is already in eval mode — TODO confirm.
import numpy as np
import torch
import torch.nn.functional as F

# torch.cuda.empty_cache()  # uncomment if CUDA memory needs to be freed first

# Take one batch from the validation loader. Use the builtin next():
# newer PyTorch DataLoader iterators no longer provide a `.next()` method.
images, labels = next(iter(valid_loader_inception))
print(len(labels))
print('Label: ', ' '.join(classes[label] for label in labels))

# Plain batch inference: no gradients are needed, so don't build a graph.
with torch.no_grad():
    output = inception_v3_funed(images.cuda())
probabilities = F.softmax(output, 1)
_, predicted = torch.max(output, 1)
print('Predicted: ', ' '.join(classes[pred] for pred in predicted))

rd_idx = 6  # Index (within the batch) of the image to explain; change as needed.
if not 0 <= rd_idx < len(labels):
    raise IndexError(f'rd_idx={rd_idx} is out of range for a batch of {len(labels)} images')

# Saliency check
# NOTE(review): predictions above come from `inception_v3_funed`, but the
# saliency map is computed on `model`; confirm both names refer to the same network.
input_image = images[rd_idx].unsqueeze(0).cuda()
input_image.requires_grad_()
saliency = Saliency(model)
grads = attribute_image_features(model, saliency, input_image, labels[rd_idx].item())
# (1, C, H, W) -> (H, W, C) for matplotlib-style display.
grads = np.transpose(grads.squeeze().cpu().detach().numpy(), (1, 2, 0))

print(f'Original Image. Label: {classes[labels[rd_idx]]}')
# Report the top probability for the selected image only, not the maximum
# softmax value across the whole batch.
print('Predicted:', classes[predicted[rd_idx]], 'probability: ',
      probabilities[rd_idx].max().item())

# `/2 + 0.5` undoes a mean=0.5/std=0.5 normalisation — assumed from this
# expression; confirm against the loader's transforms.
original_image = np.transpose((images[rd_idx].cpu().detach().numpy() / 2) + 0.5, (1, 2, 0))
_ = viz.visualize_image_attr(None, original_image, method='original_image', title='Original Image')
_ = viz.visualize_image_attr(grads, original_image, method='blended_heat_map',
                             sign='absolute_value', show_colorbar=True,
                             title='Overlayed Gradient Magnitude')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment