"""Estimate the information transfer rate (ITR) of a classifier-style interface."""

from collections import defaultdict

import numpy as np

try:
    from sklearn.metrics import confusion_matrix
except ImportError:
    # scikit-learn is only used by the example at the bottom of this file;
    # `itr` itself needs numpy alone, so keep the module importable without it.
    confusion_matrix = None


def itr(confusion_matrix, timings, eps=1e-8):
    """Return the mean and standard deviation of the per-class ITR.

    Take a confusion matrix of the form

                      (actual)
                    a    b    c    d
        (intended)
            a      10    0    1    9
            b       2   11    7    2
            c       0    0   15    0
            d       5    3   12    9

    and the average time needed to specify each class a, b, c, d, ...

    Maximum entropy is log2(n_classes).
    Residual entropy of each intended class is -sum(p * log2(p)) over its
    row-normalised outcomes.
    (max entropy - residual entropy) / time per intended class gives the
    approximate maximum possible ITR of that class, in bits per unit of time.

    Args:
        confusion_matrix: square array-like of counts; rows are intended
            classes, columns are detected classes.
        timings: array-like with the average time to specify each intended
            class, one entry per row of the matrix.
        eps: small constant added inside the logarithm so that zero
            probabilities do not produce log2(0).

    Returns:
        (mean, std) of the per-class information transfer rate.

    Note:
        A row with no trials at all has an undefined outcome distribution
        and yields NaN, which propagates to both returned values.
    """
    # Accept nested lists as well as ndarrays.
    counts = np.asarray(confusion_matrix, dtype=float)
    durations = np.asarray(timings, dtype=float)

    # Normalise each row into P(detected | intended).
    probabilities = counts / counts.sum(axis=1, keepdims=True)
    residual_entropy = -np.sum(
        probabilities * np.log2(probabilities + eps), axis=1
    )
    max_entropy = np.log2(len(counts))

    bps_per_class = (max_entropy - residual_entropy) / durations
    return np.mean(bps_per_class), np.std(bps_per_class)


def mean_durations(trials, n_classes):
    """Return the average duration of each intended class.

    Args:
        trials: iterable of (intended_class, detected_class, duration) tuples.
        n_classes: number of classes; classes are the integers 0..n_classes-1.

    Returns:
        A list of length n_classes. Classes that were never intended get 0.
    """
    total_time = defaultdict(float)
    n_trials = defaultdict(int)
    for intended, _detected, duration in trials:
        total_time[intended] += duration
        n_trials[intended] += 1  # count trials, not seconds
    return [
        total_time.get(label, 0) / n_trials.get(label, 1)
        for label in range(n_classes)
    ]


def _example():
    """Print the ITR mean and standard deviation for a small data set."""
    # data in the form (intended_class, detected_class, duration_seconds)
    example_data = [
        (0, 0, 0.5),
        (1, 1, 0.54),
        (3, 3, 0.53),
        (3, 2, 0.44),
        (1, 1, 0.42),
        (0, 0, 0.33),
        (2, 2, 0.92),
    ]

    if confusion_matrix is None:
        raise ImportError("scikit-learn is required to run the example")

    # compute confusion matrix
    confusion = confusion_matrix(
        y_true=[intended for (intended, actual, duration) in example_data],
        y_pred=[actual for (intended, actual, duration) in example_data],
    )

    # compute average duration of each class
    mean_duration = mean_durations(example_data, len(confusion))

    # show mean and std. of the ITR
    print(itr(confusion, mean_duration))


def _self_test():
    """Sanity-check `itr` on perfect and on random classifiers."""
    confusion = np.eye(8)

    # should get ~ 3 bits / second
    mean, _sd = itr(confusion, [1] * 8)
    assert abs(mean - 3.0) < 1e-5

    # should get ~ 6 bits / second
    mean, _sd = itr(confusion, [0.5] * 8)
    print(mean)
    assert abs(mean - 6.0) < 1e-5

    # should get ~ 1.5 bits / second
    mean, _sd = itr(confusion, [2.0] * 8)
    assert abs(mean - 1.5) < 1e-5

    # should be close to 0 for random performance: less than 0.5 bits / second
    for _ in range(20):
        confusion = np.random.randint(0, 200, (8, 8))
        mean, _sd = itr(confusion, [1] * 8)
        assert mean < 0.5


if __name__ == "__main__":
    _example()
    _self_test()