Flatten nested StructType dataframe schema

Original implementations

https://stackoverflow.com/questions/61863489/flatten-nested-json-in-scala-spark-dataframe

(answer: https://stackoverflow.com/a/61863579/3251389)

https://stackoverflow.com/questions/37471346/automatically-and-elegantly-flatten-dataframe-in-spark-sql (schema-only flattening examples)

Goal

Get a flattened schema from nested struct and array columns:

scala> df.printSchema
root 
 |-- author: string (nullable = true)
 |-- frameworks: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- id: long (nullable = true)
 |    |    |-- name: string (nullable = true)
 |-- id: long (nullable = true)
 |-- name: string (nullable = true)


scala> df.explodeColumns.printSchema
root
 |-- author: string (nullable = true)
 |-- frameworks_id: long (nullable = true)
 |-- frameworks_name: string (nullable = true)
 |-- id: long (nullable = true)
 |-- name: string (nullable = true)

scala>
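The gist does not include the input data. A minimal sketch (the values are made up) that reproduces the schema above in spark-shell; spark.read.json infers the same alphabetical field order and types:

// Assumed sample input, not part of the original gist.
import spark.implicits._

val df = spark.read.json(Seq(
  """{"author": "alex", "id": 1, "name": "app", "frameworks": [{"id": 10, "name": "Spark"}, {"id": 11, "name": "Akka"}]}"""
).toDS)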

Refactoring

Intermediate result: pattern matching on struct types. Final result: a recursive schema snakify that avoids calling df.select recursively, packaged as importable syntax.

The original implementation from Stack Overflow, with some remarks

// https://stackoverflow.com/a/61863579/3251389
scala> :paste
// Entering paste mode (ctrl-D to finish)

import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import scala.annotation.tailrec
import scala.util.Try

implicit class DFHelpers(df: DataFrame) {
  def columns = {
    val dfColumns = df.columns.map(_.toLowerCase) // <-- is not used further
    df.schema.fields.flatMap { data =>
      data match {
        case column if column.dataType.isInstanceOf[StructType] =>
          column.dataType.asInstanceOf[StructType].fields.map { field =>
            val columnName = column.name
            val fieldName = field.name
            col(s"${columnName}.${fieldName}").as(s"${columnName}_${fieldName}")
          }.toList
        case column => List(col(s"${column.name}"))
      }
    }
  }

  def flatten: DataFrame = {
    val empty = df.schema.filter(_.dataType.isInstanceOf[StructType]).isEmpty
    empty match {
      case false =>
        df.select(columns: _*).flatten
      case _ => df
    }
  }

  def explodeColumns = {
    @tailrec
    def columns(cdf: DataFrame): DataFrame = cdf.schema.fields.filter(_.dataType.typeName == "array") match {
      case c if !c.isEmpty => columns(c.foldLeft(cdf)((dfa, field) => {
        dfa.withColumn(field.name, explode_outer(col(s"${field.name}"))).flatten // <-- no need to flatten an array without a nested struct type
      }))
      case _ => cdf
    }
    columns(df.flatten)
  }
}

// Exiting paste mode, now interpreting.

import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import scala.annotation.tailrec
import scala.util.Try
defined class DFHelpers
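Side note (not from the gist): explode_outer is used instead of explode because explode drops rows whose array is null or empty, while explode_outer keeps them with null values. A minimal sketch, assuming the spark-shell session above:

import spark.implicits._

val withEmpty = Seq((2L, Seq.empty[Long])).toDF("id", "frameworks")

withEmpty.select(explode(col("frameworks"))).count       // 0 rows: explode drops the row
withEmpty.select(explode_outer(col("frameworks"))).count // 1 row with null: explode_outer keeps it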
Intermediate result: pattern matching on struct types

package com.example.utils.syntax

import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.functions.{col, explode_outer}
import org.apache.spark.sql.types.{ArrayType, StructField, StructType}

import scala.annotation.tailrec

// Implicit classes cannot be top-level, so the class is wrapped in an
// object here (the name `implicits` is this write-up's choice).
object implicits {

  implicit class DFHelpers(df: DataFrame) {

    def snakifyColumns: Array[Column] = {
      df.schema.fields.flatMap { structField: StructField =>
        structField.dataType match {
          case dataType: StructType => dataType.fields.map { field =>
            val columnName = structField.name
            val fieldName = field.name
            col(s"${columnName}.${fieldName}").as(s"${columnName}_${fieldName}")
          }.toList
          case _ => col(s"${structField.name}") :: Nil
        }
      }
    }

    @tailrec
    final def flattenStruct: DataFrame = {
      if (df.schema.fields.exists(_.dataType.isInstanceOf[StructType])) {
        df.select(df.snakifyColumns: _*).flattenStruct
      } else df
    }

    def explodeColumns: DataFrame = {
      @tailrec
      def explodeRecursively(cdf: DataFrame): DataFrame = {
        cdf.schema.fields.collect {
          case field @ StructField(_, ArrayType(_, _), _, _) => field
        }.toList match {
          case Nil => cdf
          case arrayFields =>
            val explodedDf: DataFrame = arrayFields.foldLeft(cdf)((dfa, field) =>
              dfa.withColumn(field.name, explode_outer(col(s"${field.name}")))
            )
            // Flatten again only if at least one exploded array contained structs.
            val flattenedDf: DataFrame = arrayFields.collectFirst {
              case StructField(_, ArrayType(_: StructType, _), _, _) => explodedDf.flattenStruct
            }.getOrElse(explodedDf)
            explodeRecursively(flattenedDf)
        }
      }
      explodeRecursively(df.flattenStruct)
    }
  }
}
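Hypothetical usage of the intermediate version (the wrapping object name `implicits` is an assumption made above, not part of the original gist):

import com.example.utils.syntax.implicits._

df.explodeColumns.printSchema
// root
//  |-- author: string (nullable = true)
//  |-- frameworks_id: long (nullable = true)
//  |-- frameworks_name: string (nullable = true)
//  |-- id: long (nullable = true)
//  |-- name: string (nullable = true)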
Final result: the same helpers packaged as importable syntax. A package object exposes the trait (defined below) through syntax.all / syntax.dataset; the package is aligned with the trait's package so the trait resolves without an import:

package com.example.utils

package object syntax {
  object all     extends DatasetSyntax
  object dataset extends DatasetSyntax
}
package com.example.utils.syntax

import org.apache.spark.sql.{Column, DataFrame, Dataset}
import org.apache.spark.sql.functions.{col, explode_outer}
import org.apache.spark.sql.types.{ArrayType, StructField, StructType}

import scala.annotation.tailrec

// recursive schema snakify without calling df.select recursively
trait DatasetSyntax {

  implicit class DatasetFlattenNestedColumns[A](df: Dataset[A]) {

    def snakifyColumns(delimiter: String = "_", prefix: Option[String] = None): Array[Column] = {
      def flattenSchema(schema: StructType, delimiter: String, prefix: Option[String], aliasPrefix: Option[String]): Array[Column] = {
        schema.fields.flatMap { field =>
          val colName = prefix.map(_ + "." + field.name).getOrElse(field.name)
          val aliasName = aliasPrefix.map(_ + delimiter + field.name).getOrElse(field.name)
          field match {
            case StructField(_, struct: StructType, _, _) => flattenSchema(struct, delimiter, Some(colName), Some(aliasName))
            case _ => Array(col(colName).as(aliasName))
          }
        }
      }
      flattenSchema(df.schema, delimiter, prefix, prefix)
    }

    // One pass over the schema builds all leaf columns, so a single select is enough.
    final def flattenStruct: DataFrame = {
      df.select(df.snakifyColumns(): _*)
    }

    def explodeColumns: DataFrame = {
      @tailrec
      def explodeRecursively(cdf: DataFrame): DataFrame = {
        cdf.schema.fields.collect {
          case field @ StructField(_, ArrayType(_, _), _, _) => field
        }.toList match {
          case Nil => cdf
          case arrayFields =>
            val explodedDf: DataFrame = arrayFields.foldLeft(cdf)((dfa, field) =>
              dfa.withColumn(field.name, explode_outer(col(s"${field.name}")))
            )
            // Flatten again only if at least one exploded array contained structs.
            val flattenedDf: DataFrame = arrayFields.collectFirst {
              case StructField(_, ArrayType(_: StructType, _), _, _) => explodedDf.flattenStruct
            }.getOrElse(explodedDf)
            explodeRecursively(flattenedDf)
        }
      }
      explodeRecursively(df.flattenStruct)
    }
  }

  // Schema-only flattening (moved inside the trait, since implicit classes
  // cannot be top-level): arrays stay arrays, nested struct fields are
  // pulled up and snakified.
  implicit class DataframeOnlySchemaFlattening(df: DataFrame) {

    def explodeSchema: DataFrame = {
      def flattenSchema(schema: StructType, prefix: String = null): Array[String] = {
        schema.fields.flatMap { f =>
          val colName = if (prefix == null) f.name else prefix + "." + f.name
          f match {
            case StructField(_, struct: StructType, _, _) => flattenSchema(struct, colName)
            case StructField(_, ArrayType(x: StructType, _), _, _) => flattenSchema(x, colName)
            case _ => Array(colName)
          }
        }
      }
      val cols = flattenSchema(df.schema)
      val snakifiedCols = cols.map(c => col(c).as(c.replaceAll("\\.", "_")))
      df.select(snakifiedCols: _*)
    }
  }
}
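A sketch of how the final syntax would be used, assuming the sample df from the Goal section; importing syntax.all._ brings both implicit classes into scope:

import com.example.utils.syntax.all._

df.explodeColumns.printSchema  // arrays exploded, nested structs snakified
df.explodeSchema.printSchema   // schema-only: frameworks_id stays an array (one element per framework)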