Skip to content

Instantly share code, notes, and snippets.

@jgao1025
Created May 4, 2016 08:39
Show Gist options
  • Select an option

  • Save jgao1025/1d0b84cac5c11c5c251d29f54581594c to your computer and use it in GitHub Desktop.

Select an option

Save jgao1025/1d0b84cac5c11c5c251d29f54581594c to your computer and use it in GitHub Desktop.

Revisions

  1. bottles created this gist May 4, 2016.
    154 changes: 154 additions & 0 deletions test
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,154 @@
    import org.canova.api.records.reader.RecordReader;
    import org.canova.api.split.FileSplit;
    import org.canova.image.recordreader.ImageRecordReader;
    import org.deeplearning4j.datasets.canova.RecordReaderDataSetIterator;
    import org.deeplearning4j.datasets.iterator.DataSetIterator;
    import org.deeplearning4j.eval.Evaluation;
    import org.deeplearning4j.nn.api.OptimizationAlgorithm;
    import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
    import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
    import org.deeplearning4j.nn.conf.Updater;
    import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
    import org.deeplearning4j.nn.conf.layers.DenseLayer;
    import org.deeplearning4j.nn.conf.layers.OutputLayer;
    import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
    import org.deeplearning4j.nn.conf.layers.setup.ConvolutionLayerSetup;
    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.deeplearning4j.nn.weights.WeightInit;
    import org.deeplearning4j.ui.weights.HistogramIterationListener;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.dataset.DataSet;
    import org.nd4j.linalg.lossfunctions.LossFunctions;
    import org.slf4j.Logger;
    import org.slf4j.LoggerFactory;

    import java.io.File;
    import java.io.IOException;
    import java.util.ArrayList;
    import java.util.List;

    /**
     * Example: trains a small convolutional network (3x conv/max-pool, dense,
     * softmax) on 32x32 RGB images read from a directory tree in which each
     * sub-directory name is a class label, then evaluates on a held-out set
     * after every epoch.
     */
    public class imageclassification {

        private static final Logger log = LoggerFactory.getLogger(imageclassification.class);

        /**
         * Reflectively instantiates a network-configuration builder from its fully
         * qualified class name.
         *
         * @param modelType fully qualified class name of a
         *                  {@code MultiLayerConfiguration.Builder} subclass with a
         *                  public no-arg constructor
         * @return a freshly constructed builder
         * @throws ClassNotFoundException if no class with that name is on the classpath
         * @throws IllegalAccessException if the no-arg constructor is inaccessible
         * @throws InstantiationException if the class cannot be instantiated
         */
        static MultiLayerConfiguration.Builder buildModel(String modelType)
                throws ClassNotFoundException, IllegalAccessException, InstantiationException {
            // Class<?> instead of the raw Class type; the cast is still checked at
            // runtime and throws ClassCastException for an unrelated type.
            Class<?> builderClass = Class.forName(modelType);
            Object builder = builderClass.newInstance();
            return (MultiLayerConfiguration.Builder) builder;
        }

        public static void main(String[] args) throws Exception {
            log.info("hello world!");

            // --- hyper-parameters and data geometry ---
            int nChannels = 3;    // RGB input images
            int outputNum = 5;    // number of classes; must equal the number of label sub-directories
            int batchSize = 50;
            int nEpochs = 20;
            int iterations = 1;   // parameter updates per minibatch
            int seed = 123;       // fixed RNG seed for reproducibility
            int numRows = 32;     // image height in pixels
            int numColumns = 32;  // image width in pixels

            // NOTE(review): hard-coded Windows paths — adjust for your environment.
            String labeledPath = "C:\\Users\\Pictures\\worldCup2\\train";
            String labeledPathTest = "C:\\Users\\Pictures\\worldCup2\\validata";

            // Each entry directly under the data root is one class label.
            List<String> labels = listLabels(labeledPath);
            List<String> labelsTest = listLabels(labeledPathTest);

            // Record readers decode each image into numRows*numColumns*nChannels
            // values; "true" appends the directory-derived label to each record.
            RecordReader recordReader = new ImageRecordReader(numRows, numColumns, nChannels, true, labels);
            RecordReader recordReaderTest = new ImageRecordReader(numRows, numColumns, nChannels, true, labelsTest);

            try {
                recordReader.initialize(new FileSplit(new File(labeledPath)));
                recordReaderTest.initialize(new FileSplit(new File(labeledPathTest)));
            } catch (IOException e) {
                // Log through the class logger instead of printStackTrace().
                log.error("Failed to initialize record readers", e);
            } catch (InterruptedException e) {
                // Restore the interrupt flag rather than swallowing it.
                Thread.currentThread().interrupt();
                log.error("Interrupted while initializing record readers", e);
            }

            // Label index = position of the label within each record, i.e. right
            // after the numRows*numColumns*nChannels pixel values.
            DataSetIterator colTrain = new RecordReaderDataSetIterator(
                    recordReader, batchSize, numRows * numColumns * nChannels, labels.size());
            DataSetIterator colTest = new RecordReaderDataSetIterator(
                    recordReaderTest, batchSize, numRows * numColumns * nChannels, labelsTest.size());

            log.info("Build model....");
            MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder()
                    .seed(seed)
                    .iterations(iterations)
                    .regularization(true).l2(0.0005)
                    .learningRate(0.01)
                    .weightInit(WeightInit.XAVIER)
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                    .updater(Updater.NESTEROVS).momentum(0.9)
                    .list(8)
                    // Three conv(5x5)/max-pool(2x2) stages, then dense + softmax output.
                    .layer(0, new ConvolutionLayer.Builder(5, 5).nIn(nChannels).stride(1, 1).nOut(16).activation("relu").build())
                    .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).stride(2, 2).build())
                    .layer(2, new ConvolutionLayer.Builder(5, 5).nIn(16).stride(1, 1).nOut(20).activation("relu").build())
                    .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).stride(2, 2).build())
                    .layer(4, new ConvolutionLayer.Builder(5, 5).nIn(20).stride(1, 1).padding(1, 1).nOut(20).activation("relu").build())
                    .layer(5, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).stride(2, 2).build())
                    .layer(6, new DenseLayer.Builder().activation("relu").nOut(320).build())
                    .layer(7, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(outputNum).activation("softmax").build())
                    .backprop(true).pretrain(false);
            // Computes the nIn of each layer from the input image dimensions.
            new ConvolutionLayerSetup(builder, numRows, numColumns, nChannels);

            MultiLayerConfiguration conf = builder.build();
            MultiLayerNetwork model = new MultiLayerNetwork(conf);
            model.init();

            log.info("Train model....");
            model.setListeners(new HistogramIterationListener(1));
            for (int i = 0; i < nEpochs; i++) {
                colTrain.reset();
                colTest.reset();
                while (colTrain.hasNext()) {
                    model.fit(colTrain.next());
                }
                // Evaluate on the held-out set after every epoch.
                log.info("Evaluate model....");
                Evaluation eval = new Evaluation(outputNum);
                while (colTest.hasNext()) {
                    DataSet ds = colTest.next();
                    INDArray output = model.output(ds.getFeatureMatrix());
                    eval.eval(ds.getLabels(), output);
                }
                log.info(eval.stats());
            }

            log.info("****************Example finished********************");
        }

        /**
         * Lists the names of the entries directly under {@code root}; each
         * sub-directory name is treated as one class label.
         *
         * @param root path of the data-set root directory
         * @return the entry names, in directory-listing order
         * @throws IllegalStateException if {@code root} is missing, unreadable,
         *                               or not a directory
         */
        private static List<String> listLabels(String root) {
            File[] entries = new File(root).listFiles();
            if (entries == null) {
                // listFiles() returns null for a missing/non-directory path; fail
                // fast with a clear message instead of an NPE in the for-each loop.
                throw new IllegalStateException("Data directory not found or unreadable: " + root);
            }
            List<String> names = new ArrayList<>();
            for (File f : entries) {
                names.add(f.getName());
            }
            return names;
        }
    }