Skip to content

Instantly share code, notes, and snippets.

@shkr
Created May 22, 2016 18:38
Show Gist options
  • Select an option

  • Save shkr/32b82d3ea082f42e12c901531f53ccdb to your computer and use it in GitHub Desktop.

Select an option

Save shkr/32b82d3ea082f42e12c901531f53ccdb to your computer and use it in GitHub Desktop.
import org.apache.spark.sql.Column
/**
 * Optimized median calculation for a distributed DataFrame, built on three observations:
 * 1. Real-world datasets are frequently drawn from low-cardinality domains, so grouping by
 *    (bucketed) value compresses the data dramatically before it is collected.
 * 2. Median calculation requires a sort, which is O(N log N): computation time grows by a
 *    factor significantly greater than 2 for a list twice as long, so sort the small
 *    histogram instead of the raw rows.
 * 3. Hive UDFs support only primitive element types, hence the data structure transferred
 *    per group is a Seq[String] ("value,count" pairs) instead of a Seq[Struct{}].
 *
 * @param inputDF          source DataFrame
 * @param key              grouping columns; one median is produced per distinct key
 * @param column           numeric column whose median is computed
 * @param medianColumnName name of the output median column
 * @param errorMargin      bucket width: values are rounded down to a multiple of this,
 *                         trading precision for cardinality reduction (default 10.0)
 * @return DataFrame with one row per key, carrying `medianColumnName` (approximate median)
 *         and "n" (sample size); the intermediate "value_with_quantity" column is dropped
 */
def insertMedianOnKey(inputDF: DataFrame,
                      key: Seq[Column],
                      column: String,
                      medianColumnName: String,
                      errorMargin: Double = 10.0): DataFrame = {
  // Bucket each value down to a multiple of errorMargin to reduce cardinality.
  val roundOff = udf((item: Double) => item - (item % errorMargin))

  // Encode a (value, count) pair as the string "value,count".
  // FIX: the original declared (Int, Int); Spark silently casts the Double bucket value
  // and the Long count to Int, truncating the value (fatal for errorMargin < 1) and
  // potentially overflowing the count. Use the actual column types.
  val tuple2UDF = udf((item1: Double, item2: Long) => s"$item1,$item2")

  // Walk the (value, count) histogram, consuming counts until `element` falls inside a
  // bucket, and return that bucket's value. Returns -1.0 for an empty histogram
  // (should not happen: every group has at least one pair).
  // FIX: the original recursed on a.tail — an O(n) array copy per step, O(n^2) overall.
  // Iterate by index instead; @tailrec guarantees constant stack.
  @scala.annotation.tailrec
  def getElement(a: Array[(Double, Long)], element: Long, i: Int = 0): Double = {
    if (i >= a.length) -1.0
    else if (element < a(i)._2) a(i)._1
    else getElement(a, element - a(i)._2, i + 1)
  }

  // Find the median of a histogram encoded as Seq["value,count"].
  val findMedian = udf((items: Seq[String]) => {
    // Parse each pair once (the original called split twice per item), then sort by value.
    val numberArray = items.map { item =>
      val parts = item.split(',')
      (parts(0).toDouble, parts(1).toLong)
    }.sortBy(_._1).toArray
    val rowCount = numberArray.map(_._2).sum
    // Lower/upper median ranks. FIX: keep Long arithmetic — the original's .toInt could
    // overflow for row counts above Int.MaxValue. For rowCount >= 1:
    //   floor((rowCount-1)/2.0) == (rowCount-1)/2  and  ceil((rowCount-1)/2.0) == rowCount/2.
    val medianOne = (rowCount - 1) / 2
    val medianTwo = rowCount / 2
    (getElement(numberArray, medianOne) + getElement(numberArray, medianTwo)) / 2.0
  })

  // Sample size per group: sum of the per-bucket counts.
  val sampleSize = udf((items: Seq[String]) => items.map(_.split(',')(1).toLong).sum)

  // Build the per-key histogram: bucket -> count, then collect the encoded pairs per key.
  val df = inputDF
    .withColumn("rounded_value", roundOff($"$column"))
    .groupBy((key :+ $"rounded_value"): _*)
    .agg(count($"rounded_value").as("quantity(rounded_value)"))
    .withColumn("value_with_quantity", tuple2UDF($"rounded_value", $"quantity(rounded_value)"))
    .groupBy(key: _*)
    .agg(collect_list($"value_with_quantity").as("value_with_quantity"))

  // Attach the median and sample-size columns, dropping the intermediate histogram.
  // (Idiom fix: removed the redundant `return`.)
  df.withColumn(medianColumnName, findMedian($"value_with_quantity"))
    .withColumn("n", sampleSize($"value_with_quantity"))
    .drop("value_with_quantity")
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment