Created
June 14, 2019 16:09
-
-
Save crearo/5c702bf40d8df1f888108c22e0291ed1 to your computer and use it in GitHub Desktop.
Revisions
-
crearo created this gist
Jun 14, 2019 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,53 @@ ''' Keras's ImageDataGenerator is buggy. It first resizes based on `target_size`, and then applies the preprocessing function you specified. This sucks. Who wants that to happen anyway. ''' def crop_image(img): print(img.shape) return img def load_data(): train_datagen = ImageDataGenerator(rotation_range=90, rescale=1. / 255, preprocessing_function=crop_image) test_datagen = ImageDataGenerator(rescale=1. / 255, preprocessing_function=crop_image) train_gen = train_datagen.flow_from_directory( '%s/train' % base, target_size=(256, 256), color_mode='grayscale', batch_size=batch_size, class_mode='categorical') test_gen = test_datagen.flow_from_directory( '%s/test' % base, target_size=(256, 256), color_mode='grayscale', batch_size=batch_size, class_mode='categorical') return train_gen, test_gen def plot_data_gen(train_gen, test_gen): for X, y in train_gen: plt.figure(figsize=(16, 16)) for i in range(25): plt.subplot(5, 5, i + 1) plt.axis('off') plt.title('Label: %d' % np.argmax(y[i])) img = np.uint8(255 * X[i, :, :, 0]) plt.imshow(img, cmap='gray') break plt.show() if __name__ == '__main__': model = load_model() train_gen, test_gen = load_data() plot_data_gen(train_gen, test_gen)