-
-
Save kennyballou/e26ddeb469509f059b70 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from __future__ import division | |
| import collections | |
| import math | |
class Model(object):
    """Categorical Naive Bayes classifier trained from a nominal ARFF file.

    The last ``@attribute`` declared in the file is treated as the class
    label; every other attribute is a feature.  Only nominal attributes
    (declared with a ``{a, b, c}`` value list) are supported.  Conditional
    probabilities use add-one (Laplace) smoothing:

        P(x_i = v | y) = (count(y, x_i = v) + 1) / (count(y) + |V_i|)

    where ``|V_i|`` is the number of values declared for feature ``i``, and
    the class prior is P(y) = count(y) / N.

    Typical use::

        model = Model("train.arff")
        model.GetValues()
        model.TrainClassifier()
        model.TestClassifier("test.arff")
    """

    def __init__(self, arffFile):
        """Create an untrained model that will read training data from *arffFile*.

        No file I/O happens here; call ``GetValues`` and then
        ``TrainClassifier`` before classifying.
        """
        self.trainingFile = arffFile
        # attribute name -> list of declared nominal values (class attribute included)
        self.features = {}
        # attribute names in declaration order; the last one is the class label
        self.featureNameList = []
        # (label, feature_name, feature_value) -> count + 1; the default of 1
        # is the Laplace "+1" in the numerator
        self.featureCounts = collections.defaultdict(lambda: 1)
        # one list of string values per data row, with the label as the last entry
        self.featureVectors = []
        # label -> number of training rows carrying that label (raw, unsmoothed)
        self.labelCounts = collections.defaultdict(lambda: 0)

    @staticmethod
    def _SplitDataLine(line):
        """Split one ARFF data row into lower-cased, whitespace-trimmed values."""
        return [value.strip() for value in line.strip().lower().split(',')]

    @staticmethod
    def _ParseNominalValues(line):
        """Return the lower-cased values inside ``{...}`` on an @attribute line.

        Returns an empty list when the line has no ``{...}`` value list
        (e.g. a numeric attribute, which this classifier does not support).
        """
        openBrace = line.find('{')
        closeBrace = line.find('}')
        if openBrace == -1 or closeBrace < openBrace:
            return []
        rawValues = line[openBrace + 1:closeBrace].split(',')
        return [value.strip().lower() for value in rawValues if value.strip()]

    def TrainClassifier(self):
        """Accumulate label and (label, feature, value) counts from ``featureVectors``.

        Counts are added on top of any existing ones, so calling this twice
        on the same data double-counts it.
        """
        # The last declared attribute is the class label, not a feature.
        featureNames = self.featureNameList[:-1]
        for fv in self.featureVectors:
            if not fv:
                continue
            label = fv[-1]
            self.labelCounts[label] += 1
            for featureName, featureValue in zip(featureNames, fv[:-1]):
                self.featureCounts[(label, featureName, featureValue)] += 1

    def Classify(self, featureVector):
        """Return the most probable class label for *featureVector*.

        *featureVector* lists feature values in attribute order; a trailing
        label entry (as in a full ARFF data row) is accepted and ignored.
        Scores are compared in log space to avoid floating-point underflow.
        The model's counts are not modified.

        Raises:
            ValueError: if the model has not been trained (no labels seen).
        """
        totalCount = sum(self.labelCounts.values())
        # zip() stops after the real features, so a trailing label is skipped.
        featureNames = self.featureNameList[:-1]
        logScores = {}
        for label, labelCount in self.labelCounts.items():
            if labelCount <= 0:
                continue  # never observed; its prior would be log(0)
            logProb = math.log(labelCount / totalCount)  # class prior P(y)
            for featureName, featureValue in zip(featureNames, featureVector):
                # .get() instead of [] so classification never inserts keys.
                numerator = self.featureCounts.get(
                    (label, featureName, featureValue), 1)
                denominator = labelCount + len(self.features.get(featureName, ()))
                logProb += math.log(numerator / denominator)
            logScores[label] = logProb
        return max(logScores, key=logScores.get)

    def GetValues(self):
        """Parse ``self.trainingFile`` into attribute metadata and data rows.

        ``@attribute`` lines fill ``featureNameList`` and ``features``. Every
        other non-blank, non-comment line not starting with ``@`` is a data
        row and is appended to ``featureVectors``. Blank lines and ``%``
        comments are skipped.
        """
        with open(self.trainingFile, 'r') as arff:
            for line in arff:
                stripped = line.strip()
                if not stripped or stripped.startswith('%'):
                    continue
                if not stripped.startswith('@'):
                    self.featureVectors.append(self._SplitDataLine(stripped))
                elif stripped.lower().startswith('@attribute'):
                    name = stripped.split()[1]
                    self.featureNameList.append(name)
                    self.features[name] = self._ParseNominalValues(stripped)

    def TestClassifier(self, arffFile):
        """Classify every data row of *arffFile* and print prediction vs. given label.

        Header (``@``), comment (``%``) and blank lines are skipped. Each data
        row's last value is taken as its true label.
        """
        with open(arffFile, 'r') as arff:
            for line in arff:
                stripped = line.strip()
                if not stripped or stripped.startswith(('%', '@')):
                    continue
                vector = self._SplitDataLine(stripped)
                print("classifier: " + self.Classify(vector) + " given " + vector[-1])
if __name__ == "__main__":
    import sys

    # Usage: python <script> [TRAIN.arff [TEST.arff]]
    # With no arguments, the original hard-coded path is used for both files.
    # When only TRAIN.arff is given, it is also used as the test file.
    defaultPath = "/home/wirelive/Dropbox/code/NaiveBayesPython/tennis.arff"
    trainingPath = sys.argv[1] if len(sys.argv) > 1 else defaultPath
    testPath = sys.argv[2] if len(sys.argv) > 2 else trainingPath
    model = Model(trainingPath)
    model.GetValues()
    model.TrainClassifier()
    model.TestClassifier(testPath)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.