目录

基本统计

优质
小牛编辑
144浏览
2023-12-01

  MLlib支持RDD[Vector]列的概括统计,它通过调用StatisticscolStats方法实现。colStats返回一个MultivariateStatisticalSummary对象,这个对象包含列式的最大值、最小值、均值、方差等等。
下面是一个应用例子:

  1. import org.apache.spark.mllib.linalg.Vector
  2. import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics}
  3. val observations: RDD[Vector] = ... // an RDD of Vectors
  4. // Compute column summary statistics.
  5. val summary: MultivariateStatisticalSummary = Statistics.colStats(observations)
  6. println(summary.mean) // a dense vector containing the mean value for each column
  7. println(summary.variance) // column-wise variance
  8. println(summary.numNonzeros) // number of nonzeros in each column

  下面我们具体看看colStats方法的实现。

  1. def colStats(X: RDD[Vector]): MultivariateStatisticalSummary = {
  2. new RowMatrix(X).computeColumnSummaryStatistics()
  3. }

  上面的代码非常明显,利用传人的RDD创建RowMatrix对象,利用方法computeColumnSummaryStatistics统计指标。

  1. def computeColumnSummaryStatistics(): MultivariateStatisticalSummary = {
  2. val summary = rows.treeAggregate(new MultivariateOnlineSummarizer)(
  3. (aggregator, data) => aggregator.add(data),
  4. (aggregator1, aggregator2) => aggregator1.merge(aggregator2))
  5. updateNumRows(summary.count)
  6. summary
  7. }

  上面的代码调用了RDDtreeAggregate方法,treeAggregate是聚合方法,它迭代处理RDD中的数据,其中,(aggregator, data) => aggregator.add(data)处理每条数据,将其添加到MultivariateOnlineSummarizer
(aggregator1, aggregator2) => aggregator1.merge(aggregator2)将不同分区的MultivariateOnlineSummarizer对象汇总。所以上述代码实现的重点是add方法和merge方法。它们都定义在MultivariateOnlineSummarizer中。
我们先来看add代码。

  1. @Since("1.1.0")
  2. def add(sample: Vector): this.type = add(sample, 1.0)
  3. private[spark] def add(instance: Vector, weight: Double): this.type = {
  4. if (weight == 0.0) return this
  5. if (n == 0) {
  6. n = instance.size
  7. currMean = Array.ofDim[Double](n)
  8. currM2n = Array.ofDim[Double](n)
  9. currM2 = Array.ofDim[Double](n)
  10. currL1 = Array.ofDim[Double](n)
  11. nnz = Array.ofDim[Double](n)
  12. currMax = Array.fill[Double](n)(Double.MinValue)
  13. currMin = Array.fill[Double](n)(Double.MaxValue)
  14. }
  15. val localCurrMean = currMean
  16. val localCurrM2n = currM2n
  17. val localCurrM2 = currM2
  18. val localCurrL1 = currL1
  19. val localNnz = nnz
  20. val localCurrMax = currMax
  21. val localCurrMin = currMin
  22. instance.foreachActive { (index, value) =>
  23. if (value != 0.0) {
  24. if (localCurrMax(index) < value) {
  25. localCurrMax(index) = value
  26. }
  27. if (localCurrMin(index) > value) {
  28. localCurrMin(index) = value
  29. }
  30. val prevMean = localCurrMean(index)
  31. val diff = value - prevMean
  32. localCurrMean(index) = prevMean + weight * diff / (localNnz(index) + weight)
  33. localCurrM2n(index) += weight * (value - localCurrMean(index)) * diff
  34. localCurrM2(index) += weight * value * value
  35. localCurrL1(index) += weight * math.abs(value)
  36. localNnz(index) += weight
  37. }
  38. }
  39. weightSum += weight
  40. weightSquareSum += weight * weight
  41. totalCnt += 1
  42. this
  43. }

  这段代码使用了在线算法来计算均值和方差。根据文献【1】的介绍,计算均值和方差遵循如下的迭代公式:

1.1 1.2

  在上面的公式中,x表示样本均值,s表示样本方差,delta表示总体方差。MLlib实现的是带有权重的计算,所以使用的迭代公式略有不同,参考文献【2】。

1.1

  merge方法相对比较简单,它只是对两个MultivariateOnlineSummarizer对象的指标作合并操作。

  1. def merge(other: MultivariateOnlineSummarizer): this.type = {
  2. if (this.weightSum != 0.0 && other.weightSum != 0.0) {
  3. totalCnt += other.totalCnt
  4. weightSum += other.weightSum
  5. weightSquareSum += other.weightSquareSum
  6. var i = 0
  7. while (i < n) {
  8. val thisNnz = nnz(i)
  9. val otherNnz = other.nnz(i)
  10. val totalNnz = thisNnz + otherNnz
  11. if (totalNnz != 0.0) {
  12. val deltaMean = other.currMean(i) - currMean(i)
  13. // merge mean together
  14. currMean(i) += deltaMean * otherNnz / totalNnz
  15. // merge m2n together,不单纯是累加
  16. currM2n(i) += other.currM2n(i) + deltaMean * deltaMean * thisNnz * otherNnz / totalNnz
  17. // merge m2 together
  18. currM2(i) += other.currM2(i)
  19. // merge l1 together
  20. currL1(i) += other.currL1(i)
  21. // merge max and min
  22. currMax(i) = math.max(currMax(i), other.currMax(i))
  23. currMin(i) = math.min(currMin(i), other.currMin(i))
  24. }
  25. nnz(i) = totalNnz
  26. i += 1
  27. }
  28. } else if (weightSum == 0.0 && other.weightSum != 0.0) {
  29. this.n = other.n
  30. this.currMean = other.currMean.clone()
  31. this.currM2n = other.currM2n.clone()
  32. this.currM2 = other.currM2.clone()
  33. this.currL1 = other.currL1.clone()
  34. this.totalCnt = other.totalCnt
  35. this.weightSum = other.weightSum
  36. this.weightSquareSum = other.weightSquareSum
  37. this.nnz = other.nnz.clone()
  38. this.currMax = other.currMax.clone()
  39. this.currMin = other.currMin.clone()
  40. }
  41. this
  42. }

  这里需要注意的是,在线算法的并行化实现是一种特殊情况。例如样本集X分到两个不同的分区,分别为X_AX_B,那么它们的合并需要满足下面的公式:

1.6

  依靠文献【3】我们可以知道,样本方差的无偏估计由下面的公式给出:

1.4 1.5

  所以,真实的样本均值和样本方差通过下面的代码实现。

  1. override def mean: Vector = {
  2. val realMean = Array.ofDim[Double](n)
  3. var i = 0
  4. while (i < n) {
  5. realMean(i) = currMean(i) * (nnz(i) / weightSum)
  6. i += 1
  7. }
  8. Vectors.dense(realMean)
  9. }
  10. override def variance: Vector = {
  11. val realVariance = Array.ofDim[Double](n)
  12. val denominator = weightSum - (weightSquareSum / weightSum)
  13. // Sample variance is computed, if the denominator is less than 0, the variance is just 0.
  14. if (denominator > 0.0) {
  15. val deltaMean = currMean
  16. var i = 0
  17. val len = currM2n.length
  18. while (i < len) {
  19. realVariance(i) = (currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) *
  20. (weightSum - nnz(i)) / weightSum) / denominator
  21. i += 1
  22. }
  23. }
  24. Vectors.dense(realVariance)
  25. }

参考文献

【1】Algorithms for calculating variance

【2】Updating mean and variance estimates: an improved method

【3】Weighted arithmetic mean