Today we hit a pitfall with collect_list: its result does not contain null values. Consider this table:
| name | city |
| --- | --- |
| 张三 | 广州 |
| null | 广州 |
| 李四 | 深圳 |
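A minimal repro of the behavior (a sketch, assuming Spark 2.x with a `SparkSession` named `spark`; the DataFrame name `df` is ours):

```scala
import org.apache.spark.sql.functions.collect_list
import spark.implicits._

// Build the sample data; note the null in the name column.
val df = Seq(("张三", "广州"), (null, "广州"), ("李四", "深圳")).toDF("name", "city")

df.groupBy("city").agg(collect_list("name")).show()
// city=广州 yields [张三] -- the null row is silently dropped
// city=深圳 yields [李四]
```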
As shown, after grouping by city, collect_list(name) returns List('张三') for city = '广州': the null value is gone. To see why, let's trace the source code:
```scala
def collect_list(e: Column): Column = withAggregateFunction { CollectList(e.expr) }
```
collect_list delegates to the CollectList aggregate function:
```scala
@ExpressionDescription(
  usage = "_FUNC_(expr) - Collects and returns a list of non-unique elements.")
case class CollectList(
    child: Expression,
    mutableAggBufferOffset: Int = 0,
    inputAggBufferOffset: Int = 0) extends Collect[mutable.ArrayBuffer[Any]] {

  def this(child: Expression) = this(child, 0, 0)

  override lazy val bufferElementType = child.dataType

  override def convertToBufferElement(value: Any): Any = InternalRow.copyValue(value)

  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
    copy(mutableAggBufferOffset = newMutableAggBufferOffset)

  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
    copy(inputAggBufferOffset = newInputAggBufferOffset)

  override def createAggregationBuffer(): mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty

  override def prettyName: String = "collect_list"

  override def eval(buffer: mutable.ArrayBuffer[Any]): Any = {
    new GenericArrayData(buffer.toArray)
  }
}
```
CollectList extends Collect[mutable.ArrayBuffer[Any]]:
```scala
/**
 * A base class for collect_list and collect_set aggregate functions.
 *
 * We have to store all the collected elements in memory, and so notice that too many elements
 * can cause GC paused and eventually OutOfMemory Errors.
 */
abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImperativeAggregate[T] {

  val child: Expression

  override def children: Seq[Expression] = child :: Nil

  override def nullable: Boolean = true

  override def dataType: DataType = ArrayType(child.dataType)

  // Both `CollectList` and `CollectSet` are non-deterministic since their results depend on the
  // actual order of input rows.
  override lazy val deterministic: Boolean = false

  override def defaultResult: Option[Literal] = Option(Literal.create(Array(), dataType))

  protected def convertToBufferElement(value: Any): Any

  override def update(buffer: T, input: InternalRow): T = {
    val value = child.eval(input)

    // Do not allow null values. We follow the semantics of Hive's collect_list/collect_set here.
    // See: org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMkCollectionEvaluator
    if (value != null) {
      buffer += convertToBufferElement(value)
    }
    buffer
  }

  override def merge(buffer: T, other: T): T = {
    buffer ++= other
  }

  protected val bufferElementType: DataType

  private lazy val projection = UnsafeProjection.create(
    Array[DataType](ArrayType(elementType = bufferElementType, containsNull = false)))
  private lazy val row = new UnsafeRow(1)

  override def serialize(obj: T): Array[Byte] = {
    val array = new GenericArrayData(obj.toArray)
    projection.apply(InternalRow.apply(array)).getBytes()
  }

  override def deserialize(bytes: Array[Byte]): T = {
    val buffer = createAggregationBuffer()
    row.pointTo(bytes, bytes.length)
    row.getArray(0).foreach(bufferElementType, (_, x: Any) => buffer += x)
    buffer
  }
}
```
The null check lives in the update method, and the comment explains why: Spark deliberately follows the semantics of Hive's collect_list/collect_set, which also drop nulls.
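Since this update method sits in the shared Collect base class, collect_set drops nulls the same way (a quick check, reusing `df` from the sketch above):

```scala
import org.apache.spark.sql.functions.collect_set

df.groupBy("city").agg(collect_set("name")).show()
// city=广州 yields [张三] -- the null is dropped here too
```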
It would be nice if collect_list accepted an optional parameter controlling whether nulls are filtered. Since it doesn't, we defined our own collect_list variant as a UDAF:
```scala
import scala.collection.mutable

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{ArrayType, DataType, StringType, StructField, StructType}

object CollectListWithNullUDAF extends UserDefinedAggregateFunction {

  // Data type of the input argument of this aggregate function
  override def inputSchema: StructType = StructType(StructField("element", StringType) :: Nil)

  // Data type of the value in the aggregation buffer
  override def bufferSchema: StructType =
    StructType(StructField("buffer", ArrayType(StringType, containsNull = true)) :: Nil)

  // The data type of the returned value
  override def dataType: DataType = ArrayType(StringType, containsNull = true)

  // Like the built-in collect_list, the result depends on the order in which input
  // rows arrive, so the function is not deterministic.
  override def deterministic: Boolean = false

  // Initializes the given aggregation buffer. The buffer itself is a `Row` that, in addition to
  // standard methods like retrieving a value at an index (e.g., get(), getBoolean()), provides
  // the opportunity to update its values. Note that arrays and maps inside the buffer are still
  // immutable.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = List.empty[String]
  }

  // Updates the given aggregation buffer `buffer` with new input data from `input`.
  // Unlike the built-in collect_list, null values are kept.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    var list = buffer.get(0).asInstanceOf[mutable.WrappedArray[String]].toList
    val value = input.get(0)
    if (value == null) {
      list = list.+:(null)
    } else {
      list = list.+:(value.toString)
    }
    buffer(0) = list
  }

  // Merges two aggregation buffers and stores the updated buffer values back to `buffer1`
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val listResult = buffer1.get(0).asInstanceOf[mutable.WrappedArray[String]].toList
    val listTemp = buffer2.get(0).asInstanceOf[mutable.WrappedArray[String]].toList
    buffer1(0) = listResult ++ listTemp
  }

  // Calculates the final result. `update` prepends, so reverse to restore input order.
  override def evaluate(buffer: Row): List[String] = {
    buffer.get(0).asInstanceOf[mutable.WrappedArray[String]].reverse.toList
  }
}
```
Finally, register the function so it can be used from SQL:

```scala
spark.udf.register("collect_list_with_null", CollectListWithNullUDAF)
```
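A quick check on the sample data (a sketch; the view name `t` is our choice, and element order within each list depends on input row order):

```scala
df.createOrReplaceTempView("t")
spark.sql("SELECT city, collect_list_with_null(name) AS names FROM t GROUP BY city").show()
// city=广州 now yields [张三, null] -- the null value is preserved
// city=深圳 yields [李四]
```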