Skip to content

Instantly share code, notes, and snippets.

@crearo
Created June 14, 2019 16:09
Show Gist options
  • Select an option

  • Save crearo/5c702bf40d8df1f888108c22e0291ed1 to your computer and use it in GitHub Desktop.

Select an option

Save crearo/5c702bf40d8df1f888108c22e0291ed1 to your computer and use it in GitHub Desktop.

Revisions

  1. crearo created this gist Jun 14, 2019.
    53 changes: 53 additions & 0 deletions keras-preprocessing-bug.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,53 @@
    '''
    Keras's ImageDataGenerator is buggy. It first resizes based on `target_size`,
    and then applies the preprocessing function you specified. This sucks.
    Who wants that to happen anyway.
    '''

    def crop_image(img):
    print(img.shape)
    return img


    def load_data():
    train_datagen = ImageDataGenerator(rotation_range=90,
    rescale=1. / 255,
    preprocessing_function=crop_image)

    test_datagen = ImageDataGenerator(rescale=1. / 255,
    preprocessing_function=crop_image)

    train_gen = train_datagen.flow_from_directory(
    '%s/train' % base,
    target_size=(256, 256),
    color_mode='grayscale',
    batch_size=batch_size,
    class_mode='categorical')

    test_gen = test_datagen.flow_from_directory(
    '%s/test' % base,
    target_size=(256, 256),
    color_mode='grayscale',
    batch_size=batch_size,
    class_mode='categorical')

    return train_gen, test_gen


    def plot_data_gen(train_gen, test_gen):
    for X, y in train_gen:
    plt.figure(figsize=(16, 16))
    for i in range(25):
    plt.subplot(5, 5, i + 1)
    plt.axis('off')
    plt.title('Label: %d' % np.argmax(y[i]))
    img = np.uint8(255 * X[i, :, :, 0])
    plt.imshow(img, cmap='gray')
    break
    plt.show()


    if __name__ == '__main__':
    model = load_model()
    train_gen, test_gen = load_data()
    plot_data_gen(train_gen, test_gen)