-
-
Save jgao1025/1d0b84cac5c11c5c251d29f54581594c to your computer and use it in GitHub Desktop.
deeplearning4j-cnn test
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import org.canova.api.records.reader.RecordReader; | |
| import org.canova.api.split.FileSplit; | |
| import org.canova.image.recordreader.ImageRecordReader; | |
| import org.deeplearning4j.datasets.canova.RecordReaderDataSetIterator; | |
| import org.deeplearning4j.datasets.iterator.DataSetIterator; | |
| import org.deeplearning4j.eval.Evaluation; | |
| import org.deeplearning4j.nn.api.OptimizationAlgorithm; | |
| import org.deeplearning4j.nn.conf.MultiLayerConfiguration; | |
| import org.deeplearning4j.nn.conf.NeuralNetConfiguration; | |
| import org.deeplearning4j.nn.conf.Updater; | |
| import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; | |
| import org.deeplearning4j.nn.conf.layers.DenseLayer; | |
| import org.deeplearning4j.nn.conf.layers.OutputLayer; | |
| import org.deeplearning4j.nn.conf.layers.SubsamplingLayer; | |
| import org.deeplearning4j.nn.conf.layers.setup.ConvolutionLayerSetup; | |
| import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; | |
| import org.deeplearning4j.nn.weights.WeightInit; | |
| import org.deeplearning4j.ui.weights.HistogramIterationListener; | |
| import org.nd4j.linalg.api.ndarray.INDArray; | |
| import org.nd4j.linalg.dataset.DataSet; | |
| import org.nd4j.linalg.lossfunctions.LossFunctions; | |
| import org.slf4j.Logger; | |
| import org.slf4j.LoggerFactory; | |
| import java.io.File; | |
| import java.io.IOException; | |
| import java.util.ArrayList; | |
| import java.util.List; | |
| public class imageclassification { | |
| private static final Logger log = LoggerFactory.getLogger(imageclassification.class); | |
| static MultiLayerConfiguration.Builder buildModel(String modelType) throws ClassNotFoundException, IllegalAccessException, InstantiationException { | |
| Class builder = Class.forName(modelType); | |
| Object o = builder.newInstance(); | |
| return (MultiLayerConfiguration.Builder) o; | |
| } | |
| public static void main(String[] args) throws Exception { | |
| log.info("hello world!"); | |
| String labeledPath; | |
| String labeledPathTest; | |
| int nChannels = 3; | |
| int outputNum = 5; | |
| int batchSize = 50; | |
| int nEpochs = 20; | |
| int iterations = 1; | |
| int seed = 123; | |
| int numRows = 32; | |
| int numColumns = 32; | |
| RecordReader recordReader; | |
| RecordReader recordReaderTest; | |
| DataSetIterator colTrain; | |
| DataSetIterator colTest; | |
| labeledPath = "C:\\Users\\Pictures\\worldCup2\\train"; | |
| labeledPathTest = "C:\\Users\\Pictures\\worldCup2\\validata"; | |
| //create array of strings called labels | |
| List<String> labels = new ArrayList<>(); | |
| List<String> labelsTest = new ArrayList<>(); | |
| //traverse dataset to get each label | |
| for(File f : new File(labeledPath).listFiles()) { | |
| labels.add(f.getName()); | |
| } | |
| for(File f : new File(labeledPathTest).listFiles()) { | |
| labelsTest.add(f.getName()); | |
| } | |
| // Instantiating RecordReader. Specify height and width of images. | |
| recordReader = new ImageRecordReader(numRows, numColumns, nChannels, true, labels); | |
| recordReaderTest = new ImageRecordReader(numRows, numColumns, nChannels, true, labelsTest); | |
| // Point to data path. | |
| try { | |
| recordReader.initialize(new FileSplit(new File(labeledPath))); | |
| recordReaderTest.initialize(new FileSplit(new File(labeledPathTest))); | |
| } catch (IOException e) { | |
| e.printStackTrace(); | |
| } catch (InterruptedException e) { | |
| e.printStackTrace(); | |
| } | |
| colTrain = new RecordReaderDataSetIterator(recordReader, batchSize, numRows * numColumns * nChannels, labels.size()); | |
| colTest = new RecordReaderDataSetIterator(recordReaderTest, batchSize, numRows * numColumns * nChannels, labelsTest.size()); | |
| // MultiLayerConfiguration.Builder builder = buildModel("convol"); | |
| /* test | |
| while (colTrain.hasNext()) { | |
| DataSet ds = colTrain.next(); | |
| System.out.println(ds.getFeatureMatrix()); | |
| System.out.println(ds.getLabels()); | |
| System.out.println(ds.getLabelNames()); | |
| } | |
| */ | |
| log.info("Build model...."); | |
| MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder() | |
| .seed(seed) | |
| .iterations(iterations) | |
| .regularization(true).l2(0.0005) | |
| .learningRate(0.01) | |
| .weightInit(WeightInit.XAVIER) | |
| .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) | |
| .updater(Updater.NESTEROVS).momentum(0.9) | |
| .list(8) | |
| .layer(0, new ConvolutionLayer.Builder(5, 5).nIn(nChannels).stride(1,1).nOut(16).activation("relu").build()) | |
| .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2,2).stride(2,2).build()) | |
| .layer(2, new ConvolutionLayer.Builder(5, 5).nIn(16).stride(1, 1).nOut(20).activation("relu").build()) | |
| .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2,2).stride(2,2).build()) | |
| .layer(4, new ConvolutionLayer.Builder(5,5).nIn(20).stride(1, 1).padding(1,1).nOut(20).activation("relu").build()) | |
| .layer(5, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2,2).stride(2,2).build()) | |
| .layer(6, new DenseLayer.Builder().activation("relu").nOut(320).build()) | |
| .layer(7, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(outputNum).activation("softmax").build()) | |
| .backprop(true).pretrain(false); | |
| new ConvolutionLayerSetup(builder,numRows, numColumns,nChannels); | |
| MultiLayerConfiguration conf = builder.build(); | |
| MultiLayerNetwork model = new MultiLayerNetwork(conf); | |
| model.init(); | |
| log.info("Train model...."); | |
| model.setListeners(new HistogramIterationListener(1)); | |
| //model.setListeners(new ScoreIterationListener(2)); | |
| for( int i=0; i<nEpochs; i++ ) { | |
| colTrain.reset(); | |
| colTest.reset(); | |
| while (colTrain.hasNext()) { | |
| model.fit(colTrain.next()); | |
| } | |
| log.info("Evaluate model...."); | |
| Evaluation eval = new Evaluation(outputNum); | |
| while (colTest.hasNext()) { | |
| DataSet ds = colTest.next(); | |
| INDArray output = model.output(ds.getFeatureMatrix()); | |
| eval.eval(ds.getLabels(), output); | |
| } | |
| log.info(eval.stats()); | |
| } | |
| log.info("****************Example finished********************"); | |
| } | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment