Skip to content

Instantly share code, notes, and snippets.

@akesling
Last active August 15, 2024 03:08
Show Gist options
  • Select an option

  • Save akesling/5358964 to your computer and use it in GitHub Desktop.

Select an option

Save akesling/5358964 to your computer and use it in GitHub Desktop.

Revisions

  1. akesling revised this gist Aug 11, 2022. 1 changed file with 4 additions and 0 deletions.
    4 changes: 4 additions & 0 deletions mnist.py
    Original file line number Diff line number Diff line change
    @@ -3,6 +3,10 @@
    import numpy as np

    """
    MNist loading helper for Python 2.7.
    For Python 3.x, see https://gist.github.com/akesling/42393ccb868125071fdea77d98a0d2f0
    Loosely inspired by http://abel.ee.ucla.edu/cvxopt/_downloads/mnist.py
    which is GPL licensed.
    """
  2. Alex Kesling revised this gist Apr 11, 2013. 1 changed file with 3 additions and 2 deletions.
    5 changes: 3 additions & 2 deletions mnist.py
    Original file line number Diff line number Diff line change
    @@ -42,11 +42,12 @@ def show(image):
    """
    Render a given numpy.uint8 2D array of pixel data.
    """
    from matplotlib import pyplot
    import matplotlib as mpl
    fig = figure()
    fig = pyplot.figure()
    ax = fig.add_subplot(1,1,1)
    imgplot = ax.imshow(image, cmap=mpl.cm.Greys)
    imgplot.set_interpolation('nearest')
    ax.xaxis.set_ticks_position('top')
    ax.yaxis.set_ticks_position('left')
    mpl.pyplot.show()
    pyplot.show()
  3. Alex Kesling revised this gist Apr 11, 2013. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion mnist.py
    Original file line number Diff line number Diff line change
    @@ -43,7 +43,7 @@ def show(image):
    Render a given numpy.uint8 2D array of pixel data.
    """
    import matplotlib as mpl
    fig = mpl.pyplot.figure()
    fig = figure()
    ax = fig.add_subplot(1,1,1)
    imgplot = ax.imshow(image, cmap=mpl.cm.Greys)
    imgplot.set_interpolation('nearest')
  4. Alex Kesling created this gist Apr 10, 2013.
    52 changes: 52 additions & 0 deletions mnist.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,52 @@
    import os
    import struct
    import numpy as np

    """
    Loosely inspired by http://abel.ee.ucla.edu/cvxopt/_downloads/mnist.py
    which is GPL licensed.
    """

    def read(dataset = "training", path = "."):
    """
    Python function for importing the MNIST data set. It returns an iterator
    of 2-tuples with the first element being the label and the second element
    being a numpy.uint8 2D array of pixel data for the given image.
    """

    if dataset is "training":
    fname_img = os.path.join(path, 'train-images-idx3-ubyte')
    fname_lbl = os.path.join(path, 'train-labels-idx1-ubyte')
    elif dataset is "testing":
    fname_img = os.path.join(path, 't10k-images-idx3-ubyte')
    fname_lbl = os.path.join(path, 't10k-labels-idx1-ubyte')
    else:
    raise ValueError, "dataset must be 'testing' or 'training'"

    # Load everything in some numpy arrays
    with open(fname_lbl, 'rb') as flbl:
    magic, num = struct.unpack(">II", flbl.read(8))
    lbl = np.fromfile(flbl, dtype=np.int8)

    with open(fname_img, 'rb') as fimg:
    magic, num, rows, cols = struct.unpack(">IIII", fimg.read(16))
    img = np.fromfile(fimg, dtype=np.uint8).reshape(len(lbl), rows, cols)

    get_img = lambda idx: (lbl[idx], img[idx])

    # Create an iterator which returns each image in turn
    for i in xrange(len(lbl)):
    yield get_img(i)

    def show(image):
    """
    Render a given numpy.uint8 2D array of pixel data.
    """
    import matplotlib as mpl
    fig = mpl.pyplot.figure()
    ax = fig.add_subplot(1,1,1)
    imgplot = ax.imshow(image, cmap=mpl.cm.Greys)
    imgplot.set_interpolation('nearest')
    ax.xaxis.set_ticks_position('top')
    ax.yaxis.set_ticks_position('left')
    mpl.pyplot.show()