Skip to content

Instantly share code, notes, and snippets.

@Timopheym
Created December 21, 2015 23:03
Show Gist options
  • Select an option

  • Save Timopheym/3f3bd380f61876f7e7a5 to your computer and use it in GitHub Desktop.

Select an option

Save Timopheym/3f3bd380f61876f7e7a5 to your computer and use it in GitHub Desktop.

Revisions

  1. Timopheym created this gist Dec 21, 2015.
    141 changes: 141 additions & 0 deletions train_mnist.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,141 @@

    from sklearn.cross_validation import StratifiedShuffleSplit
    from sklearn.decomposition import PCA
    from pandas import read_csv, DataFrame

    import find_mxnet
    import mxnet as mx
    import argparse
    import os, sys
    import train_model_poc

    train = read_csv('/Users/timopheym/Desktop/Projects/data_analysis/learn/kaggle/mnist/train.csv')
    test = read_csv('/Users/timopheym/Desktop/Projects/data_analysis/learn/kaggle/mnist/test.csv')

    train_y = train['label'].as_matrix()
    train_X = train.drop('label', 1).as_matrix()

    # t = train.drop('label', 1).as_matrix()
    train_X = train_X.reshape((len(train_X), 28, 28))
    # train_y = train_X.reshape((len(train_y), 28, 28))
    # pca = PCA()
    # pca.fit(train_X)
    # train_X_pca = pca.transform(train_X)[:, 0:60]
    # train_shuf = StratifiedShuffleSplit(train_y, n_iter = 10, test_size = .2, random_state = 123)

    parser = argparse.ArgumentParser(description='train an image classifer on mnist')
    parser.add_argument('--network', type=str, default='mlp',
    choices = ['mlp', 'lenet'],
    help = 'the cnn to use')
    parser.add_argument('--data-dir', type=str, default='mnist/',
    help='the input data directory')
    parser.add_argument('--gpus', type=str,
    help='the gpus will be used, e.g "0,1,2,3"')
    parser.add_argument('--num-examples', type=int, default=60000,
    help='the number of training examples')
    parser.add_argument('--batch-size', type=int, default=128,
    help='the batch size')
    parser.add_argument('--lr', type=float, default=.1,
    help='the initial learning rate')
    parser.add_argument('--model-prefix', type=str,
    help='the prefix of the model to load/save')
    parser.add_argument('--num-epochs', type=int, default=10,
    help='the number of training epochs')
    parser.add_argument('--load-epoch', type=int,
    help="load the model on an epoch using the model-prefix")
    parser.add_argument('--kv-store', type=str, default='local',
    help='the kvstore type')
    parser.add_argument('--lr-factor', type=float, default=1,
    help='times the lr with a factor for every lr-factor-epoch epoch')
    parser.add_argument('--lr-factor-epoch', type=float, default=1,
    help='the number of epoch to factor the lr, could be .5')
    args = parser.parse_args()

    def _download(data_dir):
    if not os.path.isdir(data_dir):
    os.system("mkdir " + data_dir)
    os.chdir(data_dir)
    if (not os.path.exists('train-images-idx3-ubyte')) or \
    (not os.path.exists('train-labels-idx1-ubyte')) or \
    (not os.path.exists('t10k-images-idx3-ubyte')) or \
    (not os.path.exists('t10k-labels-idx1-ubyte')):
    os.system("wget http://webdocs.cs.ualberta.ca/~bx3/data/mnist.zip")
    os.system("unzip -u mnist.zip; rm mnist.zip")
    os.chdir("..")

    def get_mlp():
    """
    multi-layer perceptron
    """
    data = mx.symbol.Variable('data')
    fc1 = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128)
    act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu")
    fc2 = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64)
    act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu")
    fc3 = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=10)
    mlp = mx.symbol.SoftmaxOutput(data = fc3, name = 'softmax')
    return mlp

    def get_lenet():
    """
    LeCun, Yann, Leon Bottou, Yoshua Bengio, and Patrick
    Haffner. "Gradient-based learning applied to document recognition."
    Proceedings of the IEEE (1998)
    """
    data = mx.symbol.Variable('data')
    # first conv
    conv1 = mx.symbol.Convolution(data=data, kernel=(5,5), num_filter=20)
    tanh1 = mx.symbol.Activation(data=conv1, act_type="tanh")
    pool1 = mx.symbol.Pooling(data=tanh1, pool_type="max",
    kernel=(2,2), stride=(2,2))
    # second conv
    conv2 = mx.symbol.Convolution(data=pool1, kernel=(5,5), num_filter=50)
    tanh2 = mx.symbol.Activation(data=conv2, act_type="tanh")
    pool2 = mx.symbol.Pooling(data=tanh2, pool_type="max",
    kernel=(2,2), stride=(2,2))
    # first fullc
    flatten = mx.symbol.Flatten(data=pool2)
    fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=500)
    tanh3 = mx.symbol.Activation(data=fc1, act_type="tanh")
    # second fullc
    fc2 = mx.symbol.FullyConnected(data=tanh3, num_hidden=10)
    # loss
    lenet = mx.symbol.SoftmaxOutput(data=fc2, name='softmax')
    return lenet

    if args.network == 'mlp':
    data_shape = (784, )
    net = get_mlp()
    else:
    data_shape = (1, 28, 28)
    net = get_lenet()

    def get_iterator(args, kv):
    data_dir = args.data_dir
    if '://' not in args.data_dir:
    _download(args.data_dir)
    flat = False if len(data_shape) == 3 else True

    train = mx.io.MNISTIter(
    image = data_dir + "train-images-idx3-ubyte",
    label = data_dir + "train-labels-idx1-ubyte",
    input_shape = data_shape,
    batch_size = args.batch_size,
    shuffle = True,
    flat = flat,
    num_parts = kv.num_workers,
    part_index = kv.rank)

    val = mx.io.MNISTIter(
    image = data_dir + "t10k-images-idx3-ubyte",
    label = data_dir + "t10k-labels-idx1-ubyte",
    input_shape = data_shape,
    batch_size = args.batch_size,
    flat = flat,
    num_parts = kv.num_workers,
    part_index = kv.rank)

    return (train, val)

    # train
    train_model_poc.fit(args, net, train_X, train_y)
    72 changes: 72 additions & 0 deletions train_model.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,72 @@
    import find_mxnet
    import mxnet as mx
    import logging

    def fit(args, network, train_X, train_Y):
    # kvstore
    kv = mx.kvstore.create(args.kv_store)

    # logging
    head = '%(asctime)-15s Node[' + str(kv.rank) + '] %(message)s'
    logging.basicConfig(level=logging.DEBUG, format=head)
    logging.info('start with arguments %s', args)

    # load model?
    model_prefix = args.model_prefix
    if model_prefix is not None:
    model_prefix += "-%d" % (kv.rank)
    model_args = {}
    if args.load_epoch is not None:
    assert model_prefix is not None
    tmp = mx.model.FeedForward.load(model_prefix, args.load_epoch)
    model_args = {'arg_params' : tmp.arg_params,
    'aux_params' : tmp.aux_params,
    'begin_epoch' : args.load_epoch}
    # save model?
    checkpoint = None if model_prefix is None else mx.callback.do_checkpoint(model_prefix)

    # data
    # (train, val) = data_loader(args, kv)

    # train
    devs = mx.cpu() if args.gpus is None else [
    mx.gpu(int(i)) for i in args.gpus.split(',')]

    epoch_size = args.num_examples / args.batch_size

    if args.kv_store == 'dist_sync':
    epoch_size /= kv.num_workers
    model_args['epoch_size'] = epoch_size

    if 'lr_factor' in args and args.lr_factor < 1:
    model_args['lr_scheduler'] = mx.lr_scheduler.FactorScheduler(
    step = max(int(epoch_size * args.lr_factor_epoch), 1),
    factor = args.lr_factor)

    if 'clip_gradient' in args and args.clip_gradient is not None:
    model_args['clip_gradient'] = args.clip_gradient

    # disable kvstore for single device
    if 'local' in kv.type and (
    args.gpus is None or len(args.gpus.split(',')) is 1):
    kv = None

    model = mx.model.FeedForward(
    ctx = devs,
    symbol = network,
    num_epoch = args.num_epochs,
    learning_rate = args.lr,
    momentum = 0.9,
    wd = 0.00001,
    initializer = mx.init.Xavier(factor_type="in", magnitude=2.34),
    **model_args)
    print(train_Y)
    # model.fit(X= train_X, y = train_Y)
    model.fit(
    X = train_X,
    y = train_Y,
    # eval_data = val,
    # kvstore = kv,
    # batch_end_callback = mx.callback.Speedometer(args.batch_size, 50),
    # epoch_end_callback = checkpoint
    )