当前位置: 首页 > 工具软件 > collect > 使用案例 >

Spark开发注意: collect_list、collect_set会去除Null值

凌琦
2023-12-01

今天我们踩到了 collect_list 的一个坑：collect_list 的结果中不包含 null 值。

| name | city |
|------|------|
| 张三 | 广州 |
| null | 广州 |
| 李四 | 深圳 |

对 city 做 group by 后执行 collect_list(name)，得到的结果中 city='广州' 对应的是 List('张三')，其中没有 null 值。跟踪源码：

  // Builds a collect_list aggregate Column by wrapping `e`'s expression in a CollectList
  // aggregate (via withAggregateFunction). Null inputs never reach the result, because the
  // shared Collect.update only buffers non-null values.
  def collect_list(e: Column): Column = withAggregateFunction { CollectList(e.expr) }

collect_list使用CollectList计算

/**
 * Aggregate behind `collect_list`: gathers the non-null input values, duplicates included, into
 * an ArrayBuffer and returns them as an array.
 *
 * Null filtering is inherited from [[Collect]]'s `update`. This class only chooses the buffer
 * type (an ArrayBuffer, so duplicates are kept) and how each value is copied into it.
 */
@ExpressionDescription(
  usage = "_FUNC_(expr) - Collects and returns a list of non-unique elements.")
case class CollectList(
    child: Expression,
    mutableAggBufferOffset: Int = 0,
    inputAggBufferOffset: Int = 0) extends Collect[mutable.ArrayBuffer[Any]] {

  // Auxiliary constructor taking only the child expression; both buffer offsets are 0.
  def this(child: Expression) = this(child, 0, 0)

  // Buffered elements have the same data type as the input expression.
  override lazy val bufferElementType = child.dataType

  // Buffers a copy of each value rather than the value itself. NOTE(review): presumably because
  // the input row's backing data may be reused between rows — confirm against InternalRow.copyValue.
  override def convertToBufferElement(value: Any): Any = InternalRow.copyValue(value)

  // The two offset "setters" return a copy of this expression with only that offset changed.
  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
    copy(mutableAggBufferOffset = newMutableAggBufferOffset)

  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
    copy(inputAggBufferOffset = newInputAggBufferOffset)

  // Every aggregation buffer starts as an empty ArrayBuffer.
  override def createAggregationBuffer(): mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty

  override def prettyName: String = "collect_list"

  // Final result: the buffered elements, in buffer order, wrapped as a GenericArrayData.
  override def eval(buffer: mutable.ArrayBuffer[Any]): Any = {
    new GenericArrayData(buffer.toArray)
  }
}

CollectList 继承自 Collect[mutable.ArrayBuffer[Any]]：


/**
 * A base class for the collect_list and collect_set aggregate functions.
 *
 * Subclasses supply the buffer collection `T`, the buffer element type, and how an input value
 * is converted before it is buffered. Null input values are never added to the buffer (see
 * `update`), so neither function's result contains nulls.
 *
 * All collected elements are held in memory, so a very large group can cause long GC pauses
 * and eventually an OutOfMemoryError.
 */
abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImperativeAggregate[T] {

  // The expression whose per-row values are collected.
  val child: Expression

  override def children: Seq[Expression] = child :: Nil

  override def nullable: Boolean = true

  // The result is an array whose element type is the child's data type.
  override def dataType: DataType = ArrayType(child.dataType)

  // Both `CollectList` and `CollectSet` are non-deterministic since their results depend on the
  // actual order of input rows.
  override lazy val deterministic: Boolean = false

  // An empty array of `dataType`. NOTE(review): presumably what Spark returns when the aggregate
  // sees no input — confirm against the TypedImperativeAggregate/defaultResult contract.
  override def defaultResult: Option[Literal] = Option(Literal.create(Array(), dataType))

  // Subclass hook: transforms a non-null input value before it is added to the buffer.
  protected def convertToBufferElement(value: Any): Any

  // Evaluates the child on `input` and appends the converted value unless it is null.
  // Returns the same (mutated) buffer.
  override def update(buffer: T, input: InternalRow): T = {
    val value = child.eval(input)

    // Do not allow null values. We follow the semantics of Hive's collect_list/collect_set here.
    // See: org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMkCollectionEvaluator
    if (value != null) {
      buffer += convertToBufferElement(value)
    }
    buffer
  }

  // Appends all elements of `other` to `buffer` in place and returns `buffer`.
  override def merge(buffer: T, other: T): T = {
    buffer ++= other
  }

  // Element type used when (de)serializing the buffer; supplied by subclasses.
  protected val bufferElementType: DataType

  // Projects a one-column row holding the buffer as an array. containsNull = false is safe here
  // because `update` never puts a null into the buffer.
  private lazy val projection = UnsafeProjection.create(
    Array[DataType](ArrayType(elementType = bufferElementType, containsNull = false)))
  // Single-field UnsafeRow that `deserialize` re-points at each incoming byte array.
  private lazy val row = new UnsafeRow(1)

  // Serializes the buffer by projecting it into a one-column UnsafeRow and returning its bytes.
  override def serialize(obj: T): Array[Byte] = {
    val array = new GenericArrayData(obj.toArray)
    projection.apply(InternalRow.apply(array)).getBytes()
  }

  // Inverse of `serialize`: points `row` at the bytes and re-adds each array element to a fresh
  // buffer from createAggregationBuffer().
  override def deserialize(bytes: Array[Byte]): T = {
    val buffer = createAggregationBuffer()
    row.pointTo(bytes, bytes.length)
    row.getArray(0).foreach(bufferElementType, (_, x: Any) => buffer += x)
    buffer
  }
}

在 update 方法中可以看到判空逻辑：值为 null 时不会被加入 buffer，注释说明这是为了与 Hive 的 collect_list/collect_set 语义保持一致。
如果 collect_list 能再支持一个可选参数来控制是否过滤 null 就好了。由于内置函数没有这样的参数，我们自定义了一个保留 null 值的 collect_list 函数（collect_list_with_null）：

/**
 * A `collect_list` replacement that keeps null elements.
 *
 * Spark's built-in `collect_list`/`collect_set` drop null inputs: their shared `update` only
 * buffers non-null values, following Hive's semantics. This UDAF buffers every input value,
 * including nulls, and returns them as an array of strings.
 *
 * Ordering: within one partial aggregation, elements keep the order in which rows were passed
 * to `update`. Partial buffers are concatenated in the order Spark merges them, so, as with the
 * built-in, the overall order across partitions is not guaranteed.
 *
 * NOTE(review): `UserDefinedAggregateFunction` is deprecated since Spark 3.0. Consider an
 * `Aggregator` registered through `functions.udaf` when upgrading.
 */
object CollectListWithNullUDAF extends UserDefinedAggregateFunction {

  // One (nullable) string argument: the element to collect.
  override def inputSchema: StructType = StructType(StructField("element", StringType) :: Nil)

  // The aggregation buffer is a single array column whose elements may be null.
  override def bufferSchema: StructType =
    StructType(StructField("buffer", ArrayType(StringType, containsNull = true)) :: Nil)

  // The result is an array of strings that may contain nulls.
  override def dataType: DataType = ArrayType(StringType, containsNull = true)

  // Like Spark's own CollectList, the result depends on the order in which rows reach the
  // aggregate (which varies with partitioning and shuffles), so do not claim determinism.
  override def deterministic: Boolean = false

  // Every buffer starts as an empty list.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Seq.empty[String]
  }

  // Appends the input value to the buffer. A null input is kept as a null element.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val element: String = if (input.isNullAt(0)) null else input.getString(0)
    // Append rather than prepend, so the buffer is already in input order and needs no reverse.
    // getSeq does not depend on the concrete Seq class Spark returns for array columns (it
    // differs across Scala versions), unlike a cast to mutable.WrappedArray.
    buffer(0) = buffer.getSeq[String](0) :+ element
  }

  // Concatenates buffer2's elements after buffer1's and stores the result back into buffer1.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getSeq[String](0) ++ buffer2.getSeq[String](0)
  }

  // Final result: the buffered elements (nulls included) in buffer order.
  override def evaluate(buffer: Row): List[String] = buffer.getSeq[String](0).toList
}

// Register the UDAF so SQL queries can call it as collect_list_with_null(...).
spark.udf.register(name = "collect_list_with_null", udaf = CollectListWithNullUDAF)

 类似资料: