Skip to content

Instantly share code, notes, and snippets.

@jgao1025
Created May 4, 2016 08:39
Show Gist options
  • Select an option

  • Save jgao1025/1d0b84cac5c11c5c251d29f54581594c to your computer and use it in GitHub Desktop.

Select an option

Save jgao1025/1d0b84cac5c11c5c251d29f54581594c to your computer and use it in GitHub Desktop.
deeplearning4j-cnn test
import org.canova.api.records.reader.RecordReader;
import org.canova.api.split.FileSplit;
import org.canova.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.canova.RecordReaderDataSetIterator;
import org.deeplearning4j.datasets.iterator.DataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.conf.layers.setup.ConvolutionLayerSetup;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.ui.weights.HistogramIterationListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
/**
 * Trains a small convolutional network on a directory of labeled images and evaluates it
 * on a validation directory after every epoch.
 *
 * <p>Expected layout: the training root and the validation root each contain one
 * sub-directory per class; the sub-directory name is used as the class label.
 *
 * <p>Usage: {@code imageclassification [trainDir [validationDir]]}. When an argument is
 * omitted, the corresponding built-in default path is used.
 */
public class imageclassification {
    private static final Logger log = LoggerFactory.getLogger(imageclassification.class);

    /** Training root used when no first command-line argument is supplied. */
    private static final String DEFAULT_TRAIN_PATH = "C:\\Users\\Pictures\\worldCup2\\train";

    /** Validation root used when no second command-line argument is supplied. */
    private static final String DEFAULT_TEST_PATH = "C:\\Users\\Pictures\\worldCup2\\validata";

    /**
     * Reflectively instantiates a configuration builder from its fully-qualified class name.
     * Not called from {@link #main}; retained for external callers.
     *
     * @param modelType fully-qualified name of a {@code MultiLayerConfiguration.Builder} subclass
     * @return a new instance created through the class's no-arg constructor
     * @throws ClassNotFoundException if {@code modelType} cannot be loaded
     * @throws IllegalAccessException if the no-arg constructor is not accessible
     * @throws InstantiationException if the class cannot be instantiated
     * @throws ClassCastException if the class is not a {@code MultiLayerConfiguration.Builder}
     */
    static MultiLayerConfiguration.Builder buildModel(String modelType)
            throws ClassNotFoundException, IllegalAccessException, InstantiationException {
        // asSubclass checks the type up front instead of relying on an unchecked raw-type cast.
        Class<? extends MultiLayerConfiguration.Builder> type =
                Class.forName(modelType).asSubclass(MultiLayerConfiguration.Builder.class);
        return type.newInstance();
    }

    public static void main(String[] args) throws Exception {
        log.info("hello world!");
        String labeledPath = args.length > 0 ? args[0] : DEFAULT_TRAIN_PATH;
        String labeledPathTest = args.length > 1 ? args[1] : DEFAULT_TEST_PATH;
        int nChannels = 3;
        int batchSize = 50;
        int nEpochs = 20;
        int iterations = 1;
        int seed = 123;
        int numRows = 32;
        int numColumns = 32;

        List<String> labels = listLabels(new File(labeledPath));
        List<String> labelsTest = listLabels(new File(labeledPathTest));
        if (!labels.containsAll(labelsTest)) {
            throw new IllegalStateException("Validation classes " + labelsTest
                    + " are not all present in training classes " + labels);
        }
        // The output layer and Evaluation must agree with the number of discovered classes.
        int outputNum = labels.size();

        // Both readers get the SAME (sorted) training label list. The reader is handed this
        // list to map a sample's class to an index — presumably its position in the list —
        // so independent lists could give the same class different indices in train vs test.
        RecordReader recordReader =
                new ImageRecordReader(numRows, numColumns, nChannels, true, labels);
        RecordReader recordReaderTest =
                new ImageRecordReader(numRows, numColumns, nChannels, true, labels);

        // Failures propagate: continuing with an uninitialized reader only defers the error,
        // and swallowing InterruptedException would lose the thread's interrupt status.
        recordReader.initialize(new FileSplit(new File(labeledPath)));
        recordReaderTest.initialize(new FileSplit(new File(labeledPathTest)));

        // With appendLabel=true the label follows the pixel values, so its column index is
        // the flattened image size.
        int labelIndex = numRows * numColumns * nChannels;
        DataSetIterator colTrain =
                new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, outputNum);
        DataSetIterator colTest =
                new RecordReaderDataSetIterator(recordReaderTest, batchSize, labelIndex, outputNum);

        log.info("Build model....");
        MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .iterations(iterations)
                .regularization(true).l2(0.0005)
                .learningRate(0.01)
                .weightInit(WeightInit.XAVIER)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .updater(Updater.NESTEROVS).momentum(0.9)
                .list(8)
                .layer(0, new ConvolutionLayer.Builder(5, 5).nIn(nChannels).stride(1, 1).nOut(16).activation("relu").build())
                .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).stride(2, 2).build())
                .layer(2, new ConvolutionLayer.Builder(5, 5).nIn(16).stride(1, 1).nOut(20).activation("relu").build())
                .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).stride(2, 2).build())
                .layer(4, new ConvolutionLayer.Builder(5, 5).nIn(20).stride(1, 1).padding(1, 1).nOut(20).activation("relu").build())
                .layer(5, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).stride(2, 2).build())
                .layer(6, new DenseLayer.Builder().activation("relu").nOut(320).build())
                .layer(7, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(outputNum).activation("softmax").build())
                .backprop(true).pretrain(false);
        // Constructed for its side effect on `builder` (fills in input sizes such as the
        // dense layer's nIn from the image dimensions); the instance itself is not used.
        new ConvolutionLayerSetup(builder, numRows, numColumns, nChannels);
        MultiLayerConfiguration conf = builder.build();
        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();

        log.info("Train model....");
        model.setListeners(new HistogramIterationListener(1));
        for (int i = 0; i < nEpochs; i++) {
            colTrain.reset();
            while (colTrain.hasNext()) {
                model.fit(colTrain.next());
            }
            log.info("Evaluate model....");
            Evaluation eval = evaluate(model, colTest, outputNum);
            log.info(eval.stats());
        }
        log.info("****************Example finished********************");
    }

    /**
     * Returns the sorted names of the immediate sub-directories of {@code root}; each name is
     * one class label. Sorting makes the label order deterministic, because
     * {@link File#listFiles()} does not guarantee any order. Plain files in {@code root}
     * are ignored so they cannot become spurious classes.
     *
     * @param root directory containing one sub-directory per class
     * @return non-empty, sorted list of class names
     * @throws IOException if {@code root} cannot be listed or contains no sub-directories
     */
    private static List<String> listLabels(File root) throws IOException {
        File[] children = root.listFiles();
        if (children == null) {
            throw new IOException("Not a readable directory: " + root.getAbsolutePath());
        }
        List<String> labels = new ArrayList<>();
        for (File child : children) {
            if (child.isDirectory()) {
                labels.add(child.getName());
            }
        }
        if (labels.isEmpty()) {
            throw new IOException("No class sub-directories found in: " + root.getAbsolutePath());
        }
        Collections.sort(labels);
        return labels;
    }

    /**
     * Resets {@code data} and scores {@code model} on every batch it yields.
     *
     * @param model trained network to evaluate
     * @param data iterator over the evaluation set; it is reset before use
     * @param numClasses number of output classes
     * @return the accumulated evaluation
     */
    private static Evaluation evaluate(MultiLayerNetwork model, DataSetIterator data, int numClasses) {
        data.reset();
        Evaluation eval = new Evaluation(numClasses);
        while (data.hasNext()) {
            DataSet ds = data.next();
            INDArray output = model.output(ds.getFeatureMatrix());
            eval.eval(ds.getLabels(), output);
        }
        return eval;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment