import org.apache.spark.{SparkConf, SparkContext}
import org.apache.commons.math3.distribution.ExponentialDistribution

import scala.collection._
import scala.util.Random

// A centroid summarises a cluster of points by its running mean and count.
// Centroids are ordered (and considered equal) by mean so the digest can
// keep them in a sorted set.
case class Centroid(var mean: Double, var count: Long)
    extends Ordered[Centroid] with Serializable {

  // Fold `weight` observations at value x into the running mean.
  def update(x: Double, weight: Long): Unit = {
    this.count += weight
    this.mean += weight * (x - this.mean) / this.count
  }

  def compare(that: Centroid): Int = this.mean compare that.mean

  override def equals(o: Any): Boolean = o match {
    case that: Centroid => that.mean == this.mean
    case _ => false
  }

  override def hashCode: Int = this.mean.hashCode()
}

// A t-digest: an adaptive, mergeable sketch of a distribution. Centroids
// near the tails are kept small (see threshold), so extreme quantiles stay
// accurate.
case class TDigest(delta: Double = 0.01,
                   k: Int = 25,
                   var n: Long = 0,
                   // TODO: a structure indexed by cumulative count would make
                   // computeCentroidQuantile sub-linear
                   centroids: mutable.TreeSet[Centroid] = mutable.TreeSet[Centroid]())
    extends Serializable {

  def size: Int = centroids.size

  // Merge two digests by re-inserting both centroid sets in random order.
  def ++(other: TDigest): TDigest = {
    val bothCentroids: Seq[Centroid] =
      Random.shuffle(other.centroids.toSeq ++ this.centroids.toSeq)
    val newDigest = TDigest(this.delta, this.k)
    // increment = true so the merged digest's n reflects the combined counts
    bothCentroids.foreach { c => newDigest.addCentroid(c, increment = true) }
    //newDigest.compress()
    newDigest
  }

  def addCentroid(c: Centroid, increment: Boolean = false): Unit = {
    if (increment) {
      this.n += c.count
    }
    // a centroid with the same mean already exists: fold c into it
    if (this.centroids.contains(c)) this.updateCentroid(c, c.mean, c.count)
    else this.centroids.add(c)
  }

  // Approximate quantile of a centroid: the fraction of total weight at or
  // below its mean. O(size) per call; see the TODO above.
  def computeCentroidQuantile(centroid: Centroid): Double =
    this.centroids.filter { c => c.mean <= centroid.mean }.map { _.count }.sum / this.n.toDouble

  def updateCentroid(c: Centroid, x: Double, weight: Long): Unit = {
    centroids.find { _ == c } match {
      case None =>
      case Some(existing) => existing.update(x, weight)
    }
  }

  // The nearest centroid on either side of x. // maybe protected
  def findClosestCentroids(x: Double): mutable.TreeSet[Centroid] = {
    val (below, above) = this.centroids.partition { c => c.mean < x }
    val out: mutable.TreeSet[Centroid] = mutable.TreeSet[Centroid]()
    (below.lastOption ++ above.headOption).foreach { c => out.add(c) }
    out
  }

  // Size bound for a centroid at quantile q: 4 * n * delta * q * (1 - q),
  // smallest at the tails and largest near the median.
  def threshold(q: Double): Long = Math.round(4 * this.n * this.delta * q * (1 - q))

  // Rebuild the digest by re-inserting its centroids in random order.
  def compress(): Unit = {
    val oldCentroids: Seq[Centroid] = Random.shuffle(this.centroids.toSeq)
    this.centroids.clear()
    this.n = 0 // update() re-accumulates the counts below
    oldCentroids.foreach { c => this.update(c.mean, c.count) }
  }

  // Insert `weight` new observations at value x.
  def update(x: Double, weight: Long = 1): Unit = {
    this.n += weight
    if (this.size == 0) {
      this.addCentroid(Centroid(x, weight))
    } else {
      val S: mutable.TreeSet[Centroid] = this.findClosestCentroids(x)
      var w: Long = weight
      Random.shuffle(S.toSeq).foreach { c =>
        val q: Double = this.computeCentroidQuantile(c)
        // give the centroid as much of the remaining weight as its size
        // bound allows
        val delta_w: Long = Seq(this.threshold(q) - c.count, w).min
        if (w > 0 && delta_w > 0) {
          this.updateCentroid(c, x, delta_w)
          w -= delta_w
        }
      }
      // whatever the neighbours could not absorb becomes a new centroid
      if (w > 0) {
        this.addCentroid(Centroid(x, w))
      }
      /*
      if (this.size > (this.k / this.delta)) {
        this.compress()
      }
      */
    }
  }

  def batchUpdate(X: Seq[Double], weight: Long = 1): Unit = {
    X.foreach(x => this.update(x, weight))
    //this.compress()
  }

  // Inverse CDF (quantile function): the value below which a fraction p of
  // the mass lies.
  def invCDF(p: Double): Double = {
    // cumulative probability *before* each centroid, paired with the centroid
    // (.toSeq before .map, or duplicate probabilities would collapse in a Set)
    val cumProb: Seq[(Double, Centroid)] =
      this.centroids.toSeq.map { _.count / this.n.toDouble }
        .scanLeft(0.0)(_ + _)
        .zip(this.centroids.toSeq)
    val above: Option[(Double, Centroid)] = cumProb.find { _._1 > p }
    val below: Option[(Double, Centroid)] = cumProb.reverse.find(_._1 < p)
    (below, above) match {
      case (None, None) => -1.0 // no centroids yet; should really raise an error
      case (None, Some(aC)) => aC._2.mean
      case (Some(bC), None) => bC._2.mean
      case (Some(bC), Some(aC)) =>
        // linear interpolation between the two centroid means
        val deltaX: Double = aC._2.mean - bC._2.mean
        val deltaP = (p - bC._1) / (aC._1 - bC._1)
        bC._2.mean + (deltaP * deltaX)
    }
  }

  // CDF at x: the fraction of mass below x, treating the distribution as
  // piecewise uniform between neighbouring centroid means.
  def cdf(x: Double): Double = {
    val cumCount: Seq[(Double, Centroid)] =
      this.centroids.toSeq.map { _.count / this.n.toDouble }
        .scanLeft(0.0)(_ + _)
        .zip(this.centroids.toSeq)
    val above: Option[(Double, Centroid)] = cumCount.find { _._2.mean > x }
    val below: Option[(Double, Centroid)] = cumCount.reverse.find { _._2.mean < x }
    (below, above) match {
      case (None, None) => -1.0 // no centroids yet; should really raise an error
      case (None, Some(aC)) => aC._1
      case (Some(bC), None) => bC._1
      case (Some(bC), Some(aC)) =>
        // piecewise-uniform interpolation between the two centroid means
        val deltaX: Double = (x - bC._2.mean) / (aC._2.mean - bC._2.mean)
        val deltaP = aC._1 - bC._1
        bC._1 + (deltaP * deltaX)
    }
  }

  // Mean of the window (x0, x1), weighting each centroid by its count.
  def trimmedMean(x0: Double, x1: Double): Double = {
    val within = this.centroids.filter { c => c.mean > x0 && c.mean < x1 }
    val s = within.map { _.count.toDouble }.sum
    within.map { c => c.mean * (c.count / s) }.sum
  }
}

object TDigestAppCustom {
  def main(arg: Array[String]): Unit = {
    val appName: String = "TDigest-Test"
    val conf: SparkConf = new SparkConf().setAppName(appName).setMaster("local[16]")
    val sc: SparkContext = new SparkContext(conf)

    // Exponential reference distributions with means 15 and 30.
    val trueDist0: ExponentialDistribution = new ExponentialDistribution(15)
    val trueDist1: ExponentialDistribution = new ExponentialDistribution(30)
    val data0: immutable.Seq[Double] = (0 until 10000).map { _ => trueDist0.sample() }
    val data1: immutable.Seq[Double] = (0 until 10000).map { _ => trueDist1.sample() }

    val TD0: TDigest = TDigest()
    val TD1: TDigest = TDigest()
    TD0.batchUpdate(data0, 1)
    TD1.batchUpdate(data1, 1)

    // CDF at the mean of an exponential should be about 1 - 1/e, i.e. 0.632
    println(TD0.cdf(15.0))
    println(TD1.cdf(30.0))
    println(TD0.invCDF(0.50))
    println(TD1.invCDF(0.50))

    val bothTD: TDigest = TD0 ++ TD1
    println(bothTD.cdf(15.0))
    println(bothTD.cdf(30.0))
    println(bothTD.invCDF(0.50))

    sc.stop()
  }
}
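
// --- Sketch, not part of the original code: digesting an RDD ---
// Because TDigest is Serializable and ++ merges two digests, one way to
// summarise a distributed dataset is to build a digest per partition and
// reduce with ++. The object and method names below are illustrative only.
object TDigestSparkSketch {
  import org.apache.spark.rdd.RDD

  def digestRDD(data: RDD[Double]): TDigest = {
    data
      .mapPartitions { xs =>
        val td = TDigest()
        xs.foreach(x => td.update(x)) // one local digest per partition
        Iterator(td)
      }
      .reduce(_ ++ _) // merge the per-partition digests
  }
}
// Usage, e.g. inside main above:
//   println(TDigestSparkSketch.digestRDD(sc.parallelize(data0)).invCDF(0.50))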