Last active
May 17, 2016 17:44
-
-
Save m-philipp/799b8de8765e766924c8417f232e70e6 to your computer and use it in GitHub Desktop.
Learn with WEKA in JAVA
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| public class classifyDatum { | |
| public static void main(String[] args) throws Exception { | |
| // load the classifier | |
| Classifier cls = (Classifier) weka.core.SerializationHelper.read("your/path/to/training/data.csv.model"); | |
| // now classify data from a csv file | |
| classifyCsv(cls, "your/path/to/test/data.csv"); | |
| // or classify some value captured from somewhere | |
| classifyProgrammaticDatum(cls); | |
| } | |
| private static void classifyProgrammaticDatum(Classifier cls) throws Exception { | |
| // generate the instance programmatic | |
| Instances testData = getInstance(); | |
| // load and classify the first (and only instance) | |
| double pred = cls.classifyInstance(testData.instance(0)); | |
| System.out.print("Prediction was: " + pred); | |
| } | |
| private static void classifyCsv(Classifier cls, String testSetPath) throws Exception { | |
| // load testData | |
| DataSource source = new DataSource(testSetPath); | |
| Instances testData = source.getDataSet(); | |
| testData.setClassIndex(0); | |
| // iterate through the data | |
| for (int i = 0; i < testData.numInstances(); i++) { | |
| // classify instance wise | |
| double pred = cls.classifyInstance(testData.instance(i)); | |
| System.out.print("Prediction was: " + pred); | |
| } | |
| } | |
| public static Instances getInstance() { | |
| // some madeup values put your's in here | |
| Double val1 = 0.34; | |
| Double val2 = 0.82; | |
| Double val3 = 0.32; | |
| // Instances have Attributes so create a list for them | |
| ArrayList<Attribute> atts = new ArrayList<Attribute>(4); | |
| ArrayList<String> classVal = new ArrayList<String>(); | |
| classVal.add("something"); // here put in your first class label | |
| classVal.add("something else"); // here put in another class label etc. | |
| atts.add(new Attribute("@@class@@", classVal)); | |
| // add the attributes eg. describing some mean values | |
| atts.add(new Attribute("mean_X")); | |
| atts.add(new Attribute("mean_Y")); | |
| atts.add(new Attribute("mean_Z")); | |
| // create a new Instances Object and a double array containing the values | |
| Instances dataRaw = new Instances("TestInstances", atts, 0); | |
| double[] instanceValue = new double[dataRaw.numAttributes()]; | |
| // set the class | |
| instanceValue[0] = 0; | |
| // set the values | |
| instanceValue[1] = val1; | |
| instanceValue[2] = val2; | |
| instanceValue[3] = val3; | |
| // add the values as an instance to our Instances object | |
| dataRaw.add(new DenseInstance(1.0, instanceValue)); | |
| // set the class index | |
| dataRaw.setClassIndex(0); | |
| // return tha Instance packed in an instances object | |
| return dataRaw; | |
| } | |
| } | |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| public class evaluateModel { | |
| public static void main(String[] args) throws Exception { | |
| Classifier cls = (Classifier) weka.core.SerializationHelper.read("your/path/to/training/data.csv.model"); | |
| testClassifier("your/path/to/training/data.csv", "your/path/to/test/data.csv", cls); | |
| } | |
| private static void testClassifier(String trainSetPath, String testSetPath, Classifier cls) throws Exception { | |
| // load training Data Set (Needed for the Evaluator) | |
| DataSource source = new DataSource(trainSetPath); | |
| Instances trainData = source.getDataSet(); | |
| trainData.setClassIndex(0); | |
| // load your test Data Set | |
| source = new DataSource(testSetPath); | |
| Instances testData = source.getDataSet(); | |
| testData.setClassIndex(0); | |
| // evaluate classifier | |
| Evaluation(trainData); | |
| eval.evaluateModel(cls, testData); | |
| // print the evaluation | |
| System.out.println(eval.toSummaryString("\nResults\n======\n", false)); | |
| } | |
| } | |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| public class generateModel { | |
| public static void main(String[] args) throws Exception { | |
| String trainingDataPath = "your/path/to/training/data.csv"; | |
| trainClassifier(trainingDataPath, trainPath + ".model"); | |
| } | |
| private static void trainClassifier(String trainSetPath, String classifierSavePath) throws Exception { | |
| // load the csv file | |
| CSVLoader loader = new CSVLoader(); | |
| loader.setSource(new File(trainSetPath)); | |
| // get Instances | |
| Instances trainData = loader.getDataSet(); | |
| // set the class index. In this case the first colum indicates the class | |
| trainData.setClassIndex(0); | |
| // create a new Random Forest Classifier | |
| Classifier cls = new RandomForest(); | |
| // train the classifier on your data. | |
| cls.buildClassifier(trainData); | |
| // write your trained model to a file | |
| weka.core.SerializationHelper.write(classifierSavePath, cls); | |
| } | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment