Question:

Creating a user-defined function in Spark to process a nested struct column

麻和雅
2023-03-14

In my dataframe I have a complex data structure that I need to process in order to update another column. The approach I tried is a UDF; however, if there is a simpler way to do this, please feel free to answer with it.

The dataframe structure in question is:

root
 |-- user: string (nullable = true)
 |-- cat: string (nullable = true)
 |-- data_to_update: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- _1: array (nullable = true)
 |    |    |    |-- element: double (containsNull = false)
 |    |    |-- _2: double (nullable = false)

The problem I am trying to solve is updating the id column when there is flickering. Flickering happens when the id column switches from one id to another and back multiple times; for example, 4, 5, 4, 5, 4 would be flickering.

data joined with identifiedData (both are built in the How to construct the data section below) looks like this:

+----+-----+----+--------+----+---------------------------------------------------------+
|user|cat  |id  |time_sec|rn  |data_to_update                                           |
+----+-----+----+--------+----+---------------------------------------------------------+
|a   |silom|3.0 |1       |1.0 |[[[4.0, 5.0, 6.0], 0.0], [[10.0, 11.0, 12.0, 13.0], 4.0]]|
|a   |silom|2.0 |2       |2.0 |[[[4.0, 5.0, 6.0], 0.0], [[10.0, 11.0, 12.0, 13.0], 4.0]]|
|a   |silom|1.0 |3       |3.0 |[[[4.0, 5.0, 6.0], 0.0], [[10.0, 11.0, 12.0, 13.0], 4.0]]|
|a   |silom|0.0 |4       |4.0 |[[[4.0, 5.0, 6.0], 0.0], [[10.0, 11.0, 12.0, 13.0], 4.0]]|
|a   |silom|1.0 |5       |5.0 |[[[4.0, 5.0, 6.0], 0.0], [[10.0, 11.0, 12.0, 13.0], 4.0]]|
|a   |silom|0.0 |6       |6.0 |[[[4.0, 5.0, 6.0], 0.0], [[10.0, 11.0, 12.0, 13.0], 4.0]]|
|a   |silom|1.0 |7       |7.0 |[[[4.0, 5.0, 6.0], 0.0], [[10.0, 11.0, 12.0, 13.0], 4.0]]|
|a   |silom|2.0 |8       |8.0 |[[[4.0, 5.0, 6.0], 0.0], [[10.0, 11.0, 12.0, 13.0], 4.0]]|
|a   |silom|3.0 |9       |9.0 |[[[4.0, 5.0, 6.0], 0.0], [[10.0, 11.0, 12.0, 13.0], 4.0]]|
|a   |silom|4.0 |10      |10.0|[[[4.0, 5.0, 6.0], 0.0], [[10.0, 11.0, 12.0, 13.0], 4.0]]|
|a   |silom|3.0 |11      |11.0|[[[4.0, 5.0, 6.0], 0.0], [[10.0, 11.0, 12.0, 13.0], 4.0]]|
|a   |silom|4.0 |12      |12.0|[[[4.0, 5.0, 6.0], 0.0], [[10.0, 11.0, 12.0, 13.0], 4.0]]|
|a   |silom|3.0 |13      |13.0|[[[4.0, 5.0, 6.0], 0.0], [[10.0, 11.0, 12.0, 13.0], 4.0]]|
|a   |silom|4.0 |14      |14.0|[[[4.0, 5.0, 6.0], 0.0], [[10.0, 11.0, 12.0, 13.0], 4.0]]|
|a   |silom|5.0 |15      |15.0|[[[4.0, 5.0, 6.0], 0.0], [[10.0, 11.0, 12.0, 13.0], 4.0]]|

What data_to_update tells us is that the rows with rn in [4.0, 5.0, 6.0] need to be updated by changing id to 0.0, and the rows with rn in [10.0, 11.0, 12.0, 13.0] need to be updated by changing id to 4.0.
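
To make the rule concrete, here is a purely illustrative lookup in plain Scala (the values come from the table above; updateRule and lookupId are hypothetical names and are not part of the pipeline below):

// Illustration of the update rule only; not part of the actual pipeline.
val updateRule: Seq[(Seq[Double], Double)] =
    Seq((Seq(4.0, 5.0, 6.0), 0.0), (Seq(10.0, 11.0, 12.0, 13.0), 4.0))

def lookupId(rn: Double, currentId: Double): Double =
    updateRule
        .collectFirst { case (rns, newId) if rns.contains(rn) => newId }
        .getOrElse(currentId)

lookupId(5.0, 1.0)  // 0.0: rn 5.0 falls in the first group
lookupId(1.0, 3.0)  // 3.0: no group contains rn 1.0, so the id is unchanged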

My attempt

My thinking was that I could use a UDF to process the data_to_update column and update the id column with withColumn. However, the complex data structure is where I am running into problems. What I have so far is:

// I will call this in my UDF
def checkArray(clusterID: Double, colRN: Double, arr: Array[(Array[Double], Double)]): Double = {

    // @tailrec
    def getReturnID(clusterID: Double, colRN: Double, arr: Array[(Array[Double], Double)]): Double = arr match {

        case arr if arr.nonEmpty && arr(0)._1.contains(colRN) =>
            arr(0)._2
        case arr if arr.nonEmpty && !arr(0)._1.contains(colRN) =>
            getReturnID(clusterID, colRN, arr.drop(1))
        case _ => clusterID
    }

    getReturnID(clusterID, colRN, arr)
}

val columnUpdate: UserDefinedFunction = udf {
    (colID: Double, colRN: Double, colArrayData: Array[(Array[Double], Double)]) =>

    if(colArrayData.length > 0) {
        checkArray(colID, colRN, colArrayData)
    }
    else {
        colID
    }
}

data
    .join(broadcast(identifiedData), Seq("user", "cat"), "inner")
    .withColumn("id", columnUpdate($"id", $"rn", $"data_to_update"))
    .show(100, false)
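
For what it is worth, the helper itself behaves as intended on plain Scala data (a quick hypothetical check using the values from the table above), so the failure below is about how Spark hands the column to the UDF rather than about this logic:

// Plain-Scala check of checkArray, outside Spark.
val pairs: Array[(Array[Double], Double)] =
    Array((Array(4.0, 5.0, 6.0), 0.0), (Array(10.0, 11.0, 12.0, 13.0), 4.0))

checkArray(1.0, 5.0, pairs)   // 0.0: rn 5.0 is in the first group
checkArray(3.0, 1.0, pairs)   // 3.0: no group contains rn 1.0, so the original id is kept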

I am unable to access the tuples, which have been converted to structs, and I get the following error:

org.apache.spark.SparkException: Failed to execute user defined function($anonfun$1: (double, double, array<struct<_1:array<double>,_2:double>>) => double)
    at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage7.processNext(Unknown Source)
    at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
    at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$10$$anon$1.hasNext(WholeStageCodegenExec.scala:614)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:253)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:247)
    at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:830)
    at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:830)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
    at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
    at org.apache.spark.scheduler.Task.run(Task.scala:109)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
    at java.lang.Thread.run(Thread.java:745)
Caused by: java.lang.ClassCastException: scala.collection.mutable.WrappedArray$ofRef cannot be cast to [Lscala.Tuple2;
    at $line329.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(<console>:37)

After reading Spark UDF for StructType/Row (inside a UDF, Spark represents a struct as a Row rather than a Scala tuple), I mimicked that setup, but still without success. The changes I made are:

def checkArray(clusterID: Double, 
               colRN: Double, 
               dataStruct: Row): Double = {

    // map back array to fill id value
    val arrData = dataStruct
        .getAs[Seq[Double]](0)
        .zipWithIndex.map{
            case (arr, idx) => (Array(arr), dataStruct.getAs[Seq[Double]](1)(idx))
        }

    // @tailrec
    def getReturnID(clusterID: Double, 
                    colRN: Double, 
                    arr: Seq[(Array[Double], Double)]): Double = arr match {

        case arr if arr.nonEmpty && arr(0)._1.contains(colRN) =>
            arr(0)._2
        case arr if arr.nonEmpty && !arr(0)._1.contains(colRN) =>
            getReturnID(clusterID, colRN, arr.drop(1))
        case _ => clusterID
    }

    getReturnID(clusterID, colRN, arrData)
}

val columnUpdate: UserDefinedFunction = udf {
    (colID: Double, colRN: Double, colStructData: Row) =>

    if(colStructData.getAs[Seq[Double]](0).nonEmpty) {
        checkArray(colID, colRN, colStructData)
    }
    else {
        colID
    }
}

I am now passing in a struct (a Row inside the UDF). The data is of the following form:

data.join(
    broadcast(identifiedData
        .withColumn("data_to_update", $"data_to_update")
        .withColumn("array_data", $"data_to_update._1")
        .withColumn("value", $"data_to_update._2")
        .withColumn("data_struct", struct("array_data", "value"))
        .drop("data_to_update", "array_data", "value")
             ),
              Seq("user", "cat"), 
              "inner"
    )

which has the following schema:

root
 |-- user: string (nullable = true)
 |-- cat: string (nullable = true)
 |-- id: double (nullable = false)
 |-- time_sec: integer (nullable = false)
 |-- rn: double (nullable = true)
 |-- data_struct: struct (nullable = false)
 |    |-- array_data: array (nullable = true)
 |    |    |-- element: array (containsNull = true)
 |    |    |    |-- element: double (containsNull = false)
 |    |-- value: array (nullable = true)
 |    |    |-- element: double (containsNull = true)

data.join(
    broadcast(identifiedData
        .withColumn("data_to_update", $"data_to_update")
        .withColumn("array", $"data_to_update._1")
        .withColumn("value", $"data_to_update._2")
        .withColumn("data_struct", struct("array", "value"))
        .drop("data_to_update", "array", "value")
             ),
              Seq("user", "cat"), 
              "inner"
    )
    .withColumn("id", columnUpdate($"id", $"rn", $"data_struct"))
    .show

With this approach, I get the following error:

Caused by: java.lang.ClassCastException: scala.collection.mutable.WrappedArray$ofRef cannot be cast to java.lang.Double

However, I am passing in a Row and then calling getAs to convert it into the required data structure.
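
The likely cause (a sketch of the type mechanics, not a fix): the first field of data_struct is array<array<double>>, and Row.getAs is essentially a cast, so getAs[Seq[Double]](0) still succeeds because of type erasure; the ClassCastException only surfaces later, when an element is unboxed to Double:

// Minimal demonstration of the erasure behaviour, independent of Spark.
val nested: Seq[Any] = Seq(Seq(4.0, 5.0, 6.0))             // shaped like array<array<double>>
val wrong: Seq[Double] = nested.asInstanceOf[Seq[Double]]  // compiles and "succeeds" due to erasure
// wrong.head                                              // would throw: the element is a Seq, not a java.lang.Double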

How to construct the data

import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.types.DoubleType
import scala.collection.mutable.WrappedArray
import org.apache.spark.sql.expressions.UserDefinedFunction
import scala.util.Sorting.stableSort
import org.apache.spark.sql.Row

val dataDF = Seq(
    ("a", "silom", 3, 1),
    ("a", "silom", 2, 2),
    ("a", "silom", 1, 3),
    ("a", "silom", 0, 4),
    ("a", "silom", 1, 5),
    ("a", "silom", 0, 6),
    ("a", "silom", 1, 7),
    ("a", "silom", 2, 8),
    ("a", "silom", 3, 9),
    ("a", "silom", 4, 10),
    ("a", "silom", 3, 11),
    ("a", "silom", 4, 12),
    ("a", "silom", 3, 13),
    ("a", "silom", 4, 14),
    ("a", "silom", 5, 15),
    ("a", "suk", 18, 1),
    ("a", "suk", 19, 2),
    ("a", "suk", 20, 3),
    ("a", "suk", 21, 4),
    ("a", "suk", 20, 5),
    ("a", "suk", 21, 6),
    ("a", "suk", 0, 7),
    ("a", "suk", 1, 8),
    ("a", "suk", 2, 9),
    ("a", "suk", 3, 10),
    ("a", "suk", 4, 11),
    ("a", "suk", 3, 12),
    ("a", "suk", 4, 13),
    ("a", "suk", 3, 14),
    ("a", "suk", 5, 15),
    ("b", "silom", 4, 1),
    ("b", "silom", 3, 2),
    ("b", "silom", 2, 3),
    ("b", "silom", 1, 4),
    ("b", "silom", 0, 5),
    ("b", "silom", 1, 6),
    ("b", "silom", 2, 7),
    ("b", "silom", 3, 8),
    ("b", "silom", 4, 9),
    ("b", "silom", 5, 10),
    ("b", "silom", 6, 11),
    ("b", "silom", 7, 12),
    ("b", "silom", 8, 13),
    ("b", "silom", 9, 14),
    ("b", "silom", 10, 15),
    ("b", "suk", 11, 1),
    ("b", "suk", 12, 2),
    ("b", "suk", 13, 3),
    ("b", "suk", 14, 4),
    ("b", "suk", 13, 5),
    ("b", "suk", 14, 6),
    ("b", "suk", 13, 7),
    ("b", "suk", 12, 8),
    ("b", "suk", 11, 9),
    ("b", "suk", 10, 10),
    ("b", "suk", 9, 11),
    ("b", "suk", 8, 12),
    ("b", "suk", 7, 13),
    ("b", "suk", 6, 14),
    ("b", "suk", 5, 15)
).toDF("user", "cat", "id", "time_sec")
val recastDataDF = dataDF.withColumn("id", $"id".cast(DoubleType))

val category = recastDataDF.select("cat").distinct.collect.map(x => x(0).toString)

val data = recastDataDF
    .withColumn("rn", row_number.over(Window.partitionBy("user", "cat").orderBy("time_sec")).cast(DoubleType))

val data2 = recastDataDF
    .select($"*" +: category.map(
        name => 
        lag("id", 1).over(
            Window.partitionBy("user", "cat").orderBy("time_sec")
        )
        .alias(s"lag_${name}_id")): _*)
    .withColumn("sequencing_diff", when($"cat" === "silom", ($"lag_silom_id" - $"id").cast(DoubleType))
                .otherwise(($"lag_suk_id" - $"id")))
    .drop("lag_silom_id", "lag_suk_id")
    .withColumn("rn", row_number.over(Window.partitionBy("user", "cat").orderBy("time_sec")).cast(DoubleType))
    .withColumn("id_rn", array($"id", $"rn", $"sequencing_diff"))
    .groupBy($"user", $"cat").agg(collect_list($"id_rn").alias("array_data"))

// Groups consecutive row numbers into runs, pairing each run with the id of its first element.
def getGrps2(arr: Array[(Double, Double)]): Array[(Array[Double], Double)] = {

    // @tailrec
    def returnAlternatingIDs(arr: Array[(Double, Double)], 
                             altIDs: Array[(Array[Double], Double)]): Array[(Array[Double], Double)] = arr match {

        case arr if arr.nonEmpty =>
            val rowNum = arr.take(1)(0)._1
            val keepID = arr.take(1)(0)._2
            val newArr = arr.drop(1)

            val rowNums = (Array(rowNum)) ++ newArr.zipWithIndex.map{
                case (tups, idx) => 
                if(rowNum + idx + 1 == tups._1) {
                    rowNum + 1 + idx
                }
                else {
                    Double.NaN
                }
            }
                .filter(v => v == v) // NaN != NaN, so the NaN placeholders are dropped here

            val updateArray = altIDs ++ Array((rowNums, keepID))
            returnAlternatingIDs(arr.drop(rowNums.length), updateArray)
        case _ => altIDs
    }

    returnAlternatingIDs(arr, Array((Array(Double.NaN), Double.NaN))).drop(1)
}
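
A small, hypothetical check of what getGrps2 produces: consecutive row numbers collapse into one run, and each run keeps the id paired with its first element (the input values here are made up for illustration):

// Hypothetical (rn, id) pairs, shaped like parsedArray inside the UDF below.
val sample = Array((4.0, 0.0), (5.0, 0.0), (6.0, 0.0), (10.0, 4.0), (11.0, 4.0))

getGrps2(sample).foreach { case (rns, keepId) =>
    println(s"rns = ${rns.mkString(", ")} -> id = $keepId")
}
// rns = 4.0, 5.0, 6.0 -> id = 0.0
// rns = 10.0, 11.0 -> id = 4.0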

val identifyFlickeringIDs: UserDefinedFunction = udf {
    (colArrayData: WrappedArray[WrappedArray[Double]]) =>
    val newArray: Array[(Double, Double, Double)] = colArrayData.toArray
        .map(x => (x(0).toDouble, x(1).toDouble, x(2).toDouble))

    // sort array by rn via less than relation
    stableSort(newArray, (e1: (Double, Double, Double), e2: (Double, Double, Double)) => e1._2 < e2._2)

    val shifted: Array[(Double, Double, Double)] = newArray.toArray.drop(1)
    val combined = newArray
        .zipAll(shifted, (Double.NaN, Double.NaN, Double.NaN), (Double.NaN, Double.NaN, Double.NaN))

    val parsedArray = combined.map{
        case (data0, data1) =>
        if(data0._3 != data1._3 && data0._2 > 1 && data0._3 + data1._3 == 0) {
            (data0._2, data0._1)
        }
        else (Double.NaN, Double.NaN)
        }
        .filter(t => t._1 == t._1 && t._2 == t._2) // keep only real (rn, id) pairs; NaN placeholders fail the self-equality test

    getGrps2(parsedArray).filter(data => data._1.length > 1)
}

val identifiedData = data2.withColumn("data_to_update", identifyFlickeringIDs($"array_data"))
    .drop("array_data")

1 Answer

景令秋
2023-03-14

I found a solution by deconstructing the column, since the data sits in arrays:

def checkArray(clusterID: Double, 
               colRN: Double, 
               dataStruct: Row): Double = {
    // Array[(Array[Double], Double)]
    val arrData: Seq[(Seq[Double], Double)] = dataStruct
        .getAs[Seq[Seq[Double]]](0)
        .zipWithIndex.map{
            case (arr, idx) => (arr, dataStruct.getAs[Seq[Double]](1)(idx))
        }

    // @tailrec
    def getReturnID(clusterID: Double, 
                    colRN: Double, 
                    arr: Seq[(Seq[Double], Double)]): Double = arr match {

        case arr if arr.nonEmpty && arr(0)._1.contains(colRN) =>
            arr(0)._2
        case arr if arr.nonEmpty && !arr(0)._1.contains(colRN) =>
            getReturnID(clusterID, colRN, arr.drop(1))
        case _ => clusterID
    }

    getReturnID(clusterID, colRN, arrData)
}

val columnUpdate: UserDefinedFunction = udf {
    (colID: Double, colRN: Double, colStructData: Row) =>

    if(colStructData.getAs[Seq[Double]](0).nonEmpty) {
        checkArray(colID, colRN, colStructData)
    }
    else {
        colID
    }
}

// I believe all of these withColumn calls are unnecessary, but this was the only way
// I could get a working solution.
data.join(
    broadcast(identifiedData
        .withColumn("data_to_update", $"data_to_update")
        .withColumn("array", $"data_to_update._1")
        .withColumn("value", $"data_to_update._2")
        .withColumn("data_struct", struct("array", "value"))
        .drop("data_to_update", "array", "value")
             ),
              Seq("user", "cat"), 
              "inner"
    )
    .withColumn("id", columnUpdate($"id", $"rn", $"data_struct"))
    .show(100, false)
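
On the withColumn calls flagged in the comment above: a possibly simpler variant (a sketch, not verified against the exact Spark version used here) is to pass the original data_to_update column to the UDF as a Seq[Row], so the array<struct> never needs to be repacked into data_struct:

// Sketch: read data_to_update directly; each element arrives as a Row with
// fields _1 (the array of row numbers) and _2 (the replacement id).
val columnUpdateDirect: UserDefinedFunction = udf {
    (colID: Double, colRN: Double, colArrayData: Seq[Row]) =>

    colArrayData
        .find(row => row.getAs[Seq[Double]](0).contains(colRN))
        .map(_.getDouble(1))
        .getOrElse(colID)
}

data
    .join(broadcast(identifiedData), Seq("user", "cat"), "inner")
    .withColumn("id", columnUpdateDirect($"id", $"rn", $"data_to_update"))
    .show(100, false)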
