Created
May 22, 2016 18:38
-
-
Save shkr/32b82d3ea082f42e12c901531f53ccdb to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.functions.{col, collect_list, count, udf}
| /** | |
| * Optimized Median calculation for a distributed dataframe built on three findings : | |
| * 1. Real world datasets are made from low cardinality domains | |
| * 2. Median calculation requires sort which is O(N * log N), it implies that computation time increase by a factor significantly greater than 2 for a list twice as long. | |
| * 3. Hive UDF support only primitive data types hence in the algo the datastructure transferred to a single node is a Seq[String] instead of Seq[Struct{}] | |
| **/ | |
| def insertMedianOnKey(inputDF: DataFrame, | |
| key: Seq[Column], | |
| column: String, | |
| medianColumnName: String, | |
| errorMargin: Double = 10.0): DataFrame = { | |
| // To further reduce cardinality of the set whose values are found in column | |
| val roundOff = udf((item: Double) => item - (item % errorMargin)) | |
| // Make a String rep of a tuple | |
| val tuple2UDF = udf((item1: Int, item2: Int) => item1.toString + "," + item2.toString) | |
| // Search Median in a quantity grouped set | |
| def getElement(a:Array[(Double, Long)], element: Long): Double = { | |
| if (a.length == 0) | |
| -1.0 // this case should not happen... ;-) | |
| else if (element < a(0)._2) | |
| a(0)._1 | |
| else | |
| getElement(a.tail, element - a(0)._2) | |
| } | |
| // UDF to find Median in a Seq[String] | |
| val findMedian = udf((items: Seq[String]) => { | |
| val numberArray = items.map(item => (item.split(',')(0).toDouble, item.split(',')(1).toLong)).sortBy(_._1).toArray | |
| val rowCount = numberArray.map(item => item._2).sum | |
| val medianOne = Math.floor((rowCount-1)/2.0).toInt | |
| val medianTwo = Math.ceil((rowCount-1)/2.0).toInt | |
| val medianValueOne = getElement(numberArray, medianOne) | |
| val medianValueTwo = getElement(numberArray, medianTwo) | |
| (medianValueOne + medianValueTwo)/2.0 | |
| }) | |
| // UDF to find Sample Size | |
| val sampleSize = udf((items: Seq[String]) => items.map(item => item.split(',')(1).toLong).sum) | |
| // Apply algorithm | |
| val df = inputDF.withColumn("rounded_value", roundOff($"${column}")) | |
| .groupBy((key ++ Array($"rounded_value")):_*) | |
| .agg(count($"rounded_value").as("quantity(rounded_value)")) | |
| .withColumn("value_with_quantity", tuple2UDF($"rounded_value", $"quantity(rounded_value)")) | |
| .groupBy(key:_*) | |
| .agg(collect_list($"value_with_quantity").as("value_with_quantity")) | |
| // Return result | |
| return df.withColumn(medianColumnName, findMedian($"value_with_quantity")).withColumn("n", sampleSize($"value_with_quantity")).drop("value_with_quantity") | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment