Tree ensembles - Random forest


1 Bagging

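  Bagging draws its training data by bootstrap sampling. Given a data set of m samples, we randomly pick one sample and put it into the sampling set, then put that sample back into the original data set so that it can still be picked in later draws. Repeating this m times gives a sampling set of m samples.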
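
  The sampling step can be sketched in a few lines of plain Scala. This is only an illustration, not the Spark implementation; the object name `BootstrapSketch` and both helper names are made up for this example. Because the draw is with replacement, some samples appear several times and others not at all; for large m roughly 63.2% (1 - 1/e) of the distinct samples end up in one bootstrap sample.

```scala
import scala.util.Random

object BootstrapSketch {
  // Draw data.length samples from data, with replacement.
  def bootstrap[T](data: IndexedSeq[T], rng: Random): IndexedSeq[T] =
    IndexedSeq.fill(data.length)(data(rng.nextInt(data.length)))

  // Fraction of the m original samples that appear at least once
  // in a single bootstrap sample of size m.
  def distinctFraction(m: Int, rng: Random): Double =
    bootstrap(0 until m, rng).distinct.length.toDouble / m
}
```

  Calling `BootstrapSketch.distinctFraction(10000, new Random(42))` should return a value close to 0.632.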




2 Random forest



3 Optimization strategies for random forests in a distributed environment


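  • Split point sampling, as shown in the figure below. When a single-machine decision tree chooses split points for a continuous variable, it usually sorts the feature values and takes the points between adjacent values as candidate splits. This is feasible on a single machine, but in a distributed environment the sort requires moving large amounts of data across the network, so the split points are computed on a sample of the data instead.
  • Feature binning, as shown in the figure below. Building a decision tree is a process of repeatedly partitioning the values of the features. A categorical feature with M values has at most 2^(M-1) - 1 partitions; if the values are ordered, it has at most M - 1.
  • Level-wise training, as shown in the figure below. The single-machine algorithm builds the tree through recursive calls (essentially depth-first). While building the tree it has to move data so that the data belonging to the same child node ends up together.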
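
  The counts mentioned in the list above can be checked with a small sketch. It is plain Scala with hypothetical names (`SplitCountSketch` and its methods); taking the midpoint of two adjacent values as the candidate threshold is an assumption of this example, since the text only says "a point between adjacent values".

```scala
object SplitCountSketch {
  // Binary partitions of M unordered categories: 2^(M-1) - 1.
  def unorderedSplits(m: Int): Long = (1L << (m - 1)) - 1

  // Splits of M ordered values: M - 1.
  def orderedSplits(m: Int): Int = m - 1

  // Single-machine candidate thresholds for a continuous feature:
  // sort the distinct values and take the midpoint of each adjacent pair.
  def midpoints(values: Seq[Double]): Seq[Double] = {
    val sorted = values.distinct.sorted
    sorted.zip(sorted.drop(1)).map { case (a, b) => (a + b) / 2 }
  }
}
```

  With 4 categories this gives 7 unordered splits but only 3 ordered ones, which is why the number of bins grows so quickly for unordered features.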

4 Usage examples


    // Classification example. Assumes `sc` is an existing SparkContext (for example in spark-shell).
    import org.apache.spark.mllib.tree.RandomForest
    import org.apache.spark.mllib.tree.model.RandomForestModel
    import org.apache.spark.mllib.util.MLUtils

    // Load and parse the data file.
    val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
    // Split the data into training and test sets (30% held out for testing)
    val splits = data.randomSplit(Array(0.7, 0.3))
    val (trainingData, testData) = (splits(0), splits(1))

    // Train a RandomForest model.
    // An empty categoricalFeaturesInfo means that all features are continuous.
    val numClasses = 2
    val categoricalFeaturesInfo = Map[Int, Int]()
    val numTrees = 3 // Use more in practice.
    val featureSubsetStrategy = "auto" // Let the algorithm choose.
    val impurity = "gini"
    val maxDepth = 4
    val maxBins = 32
    val model = RandomForest.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
      numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)

    // Evaluate model on test instances and compute test error
    val labelAndPreds = testData.map { point =>
      val prediction = model.predict(point.features)
      (point.label, prediction)
    }
    val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count()
    println("Test Error = " + testErr)
    println("Learned classification forest model:\n" + model.toDebugString)


    // Regression example. Assumes `sc` is an existing SparkContext (for example in spark-shell).
    import org.apache.spark.mllib.tree.RandomForest
    import org.apache.spark.mllib.tree.model.RandomForestModel
    import org.apache.spark.mllib.util.MLUtils

    // Load and parse the data file.
    val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
    // Split the data into training and test sets (30% held out for testing)
    val splits = data.randomSplit(Array(0.7, 0.3))
    val (trainingData, testData) = (splits(0), splits(1))

    // Train a RandomForest model.
    // An empty categoricalFeaturesInfo means that all features are continuous.
    val numClasses = 2
    val categoricalFeaturesInfo = Map[Int, Int]()
    val numTrees = 3 // Use more in practice.
    val featureSubsetStrategy = "auto" // Let the algorithm choose.
    val impurity = "variance"
    val maxDepth = 4
    val maxBins = 32
    val model = RandomForest.trainRegressor(trainingData, categoricalFeaturesInfo,
      numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)

    // Evaluate model on test instances and compute test error
    val labelsAndPredictions = testData.map { point =>
      val prediction = model.predict(point.features)
      (point.label, prediction)
    }
    val testMSE = labelsAndPredictions.map { case (v, p) => math.pow((v - p), 2) }.mean()
    println("Test Mean Squared Error = " + testMSE)
    println("Learned regression forest model:\n" + model.toDebugString)

5 Source code analysis

5.1 Training


5.1.1 Initialization

    val retaggedInput = input.retag(classOf[LabeledPoint])
    // Build the decision tree metadata (number of bins per feature, categorical feature info, and so on)
    val metadata =
      DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy)
    // Find the candidate splits and the bins.
    // For continuous features, the split points are computed on a sample to reduce the cost.
    // An unordered categorical feature with M categories has at most 2^(M-1) - 1 splits;
    // an ordered one has at most M - 1 splits.
    val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, metadata)
    // Convert to an RDD of TreePoint; after the conversion every feature value
    // has been mapped to its bin according to the splits.
    val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata)
    val withReplacement = if (numTrees > 1) true else false
    // convertToBaggedRDD gives each tree its own subsample of the data
    val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput,
      strategy.subsamplingRate, numTrees,
      withReplacement, seed).persist(StorageLevel.MEMORY_AND_DISK)
    // Depth of the decision trees, at most 30
    val maxDepth = strategy.maxDepth
    // Maximum memory available for aggregation
    val maxMemoryUsage: Long = strategy.maxMemoryInMB * 1024L * 1024L
    val maxMemoryPerNode = {
      val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
        // Find numFeaturesPerNode largest bins to get an upper bound on memory usage.
        Some(metadata.numBins.zipWithIndex.sortBy(- _._1)
          .take(metadata.numFeaturesPerNode).map(_._2))
      } else {
        None
      }
      // Memory needed to aggregate the statistics of one node
      RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L
    }


    def buildMetadata(
        input: RDD[LabeledPoint],
        strategy: Strategy,
        numTrees: Int,
        featureSubsetStrategy: String): DecisionTreeMetadata = {
      // Number of features
      val numFeatures = input.map(_.features.size).take(1).headOption.getOrElse {
        throw new IllegalArgumentException(s"DecisionTree requires size of input RDD > 0, " +
          s"but was given by empty one.")
      }
      val numExamples = input.count()
      val numClasses = strategy.algo match {
        case Classification => strategy.numClasses
        case Regression => 0
      }
      // Maximum possible number of bins
      val maxPossibleBins = math.min(strategy.maxBins, numExamples).toInt
      if (maxPossibleBins < strategy.maxBins) {
        logWarning(s"DecisionTree reducing maxBins from ${strategy.maxBins} to $maxPossibleBins" +
          s" (= number of training instances)")
      }
      // We check the number of bins here against maxPossibleBins.
      // This needs to be checked here instead of in Strategy since maxPossibleBins can be modified
      // based on the number of training examples.
      // The largest number of categories must not exceed the maximum possible number of bins.
      // categoricalFeaturesInfo is passed in by the caller; this map stores the arity of each
      // categorical feature. For example, (n -> k) means that feature n has the k categories
      // (0, 1, ..., k-1).
      if (strategy.categoricalFeaturesInfo.nonEmpty) {
        val maxCategoriesPerFeature = strategy.categoricalFeaturesInfo.values.max
        val maxCategory =
          strategy.categoricalFeaturesInfo.find(_._2 == maxCategoriesPerFeature).get._1
        require(maxCategoriesPerFeature <= maxPossibleBins,
          s"DecisionTree requires maxBins (= $maxPossibleBins) to be at least as large as the " +
          s"number of values in each categorical feature, but categorical feature $maxCategory " +
          s"has $maxCategoriesPerFeature values. Considering remove this and other categorical " +
          "features with a large number of values, or add more training examples.")
      }
      val unorderedFeatures = new mutable.HashSet[Int]()
      val numBins = Array.fill[Int](numFeatures)(maxPossibleBins)
      if (numClasses > 2) {
        // Multiclass classification
        val maxCategoriesForUnorderedFeature =
          ((math.log(maxPossibleBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt
        strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) =>
          // A categorical feature with only one category is treated as continuous
          if (numCategories > 1) {
            // Decide if some categorical features should be treated as unordered features,
            // which require 2 * ((1 << numCategories - 1) - 1) bins.
            // We do this check with log values to prevent overflows in case numCategories is large.
            // The next check is equivalent to: 2 * ((1 << numCategories - 1) - 1) <= maxBins
            if (numCategories <= maxCategoriesForUnorderedFeature) {
              unorderedFeatures.add(featureIndex)
              numBins(featureIndex) = numUnorderedBins(numCategories)
            } else {
              numBins(featureIndex) = numCategories
            }
          }
        }
      } else {
        // Binary classification or regression
        strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) =>
          // A categorical feature with only one category is treated as continuous
          if (numCategories > 1) {
            numBins(featureIndex) = numCategories
          }
        }
      }
      // Set the number of features per node (for random forests).
      val _featureSubsetStrategy = featureSubsetStrategy match {
        case "auto" =>
          if (numTrees == 1) { // A single decision tree uses all features
            "all"
          } else {
            if (strategy.algo == Classification) { // Classification uses the square root
              "sqrt"
            } else { // Regression uses one third of the features
              "onethird"
            }
          }
        case _ => featureSubsetStrategy
      }
      val numFeaturesPerNode: Int = _featureSubsetStrategy match {
        case "all" => numFeatures
        case "sqrt" => math.sqrt(numFeatures).ceil.toInt
        case "log2" => math.max(1, (math.log(numFeatures) / math.log(2)).ceil.toInt)
        case "onethird" => (numFeatures / 3.0).ceil.toInt
      }
      new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
        strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
        strategy.impurity, strategy.quantileCalculationStrategy, strategy.maxDepth,
        strategy.minInstancesPerNode, strategy.minInfoGain, numTrees, numFeaturesPerNode)
    }
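
  To make the feature subset rule at the end of `buildMetadata` easier to try out, here is a standalone sketch that mirrors it as a pure function. The object name `FeatureSubsetSketch` and the `isClassification` flag are assumptions of this example (the real code reads `strategy.algo`), and the final `case other` branch is added here so that unknown strategy names fail with a clear message.

```scala
object FeatureSubsetSketch {
  def numFeaturesPerNode(
      strategy: String,
      numFeatures: Int,
      numTrees: Int,
      isClassification: Boolean): Int = {
    // Resolve "auto" to a concrete strategy first.
    val resolved = strategy match {
      case "auto" if numTrees == 1 => "all"       // single decision tree
      case "auto" if isClassification => "sqrt"   // classification forest
      case "auto" => "onethird"                   // regression forest
      case other => other
    }
    resolved match {
      case "all" => numFeatures
      case "sqrt" => math.sqrt(numFeatures).ceil.toInt
      case "log2" => math.max(1, (math.log(numFeatures) / math.log(2)).ceil.toInt)
      case "onethird" => (numFeatures / 3.0).ceil.toInt
      case other => throw new IllegalArgumentException(s"Unsupported strategy: $other")
    }
  }
}
```

  For 100 features, a classification forest considers 10 features per node and a regression forest considers 34.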


    /**
     * Returns splits and bins for decision tree calculation.
     * Continuous and categorical features are handled differently.
     *
     * Continuous features:
     *   For each feature, there are numBins - 1 possible splits representing the possible binary
     *   decisions at each node in the tree.
     *   This finds locations (feature values) for splits using a subsample of the data.
     *
     * Categorical features:
     *   For each feature, there is 1 bin per split.
     *   Splits and bins are handled in 2 ways:
     *   (a) "unordered features"
     *       For multiclass classification with a low-arity feature
     *       (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits),
     *       the feature is split based on subsets of categories.
     *   (b) "ordered features"
     *       For regression and binary classification,
     *       and for multiclass classification with a high-arity feature,
     *       there is one bin per category.
     *
     * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
     * @param metadata Learning and dataset metadata
     * @return A tuple of (splits, bins).
     *         Splits is an Array of [[org.apache.spark.mllib.tree.model.Split]]
     *         of size (numFeatures, numSplits).
     *         Bins is an Array of [[org.apache.spark.mllib.tree.model.Bin]]
     *         of size (numFeatures, numBins).
     */
    protected[tree] def findSplitsBins(
        input: RDD[LabeledPoint],
        metadata: DecisionTreeMetadata): (Array[Array[Split]], Array[Array[Bin]]) = {
      // Number of features
      val numFeatures = metadata.numFeatures
      // Sample the input only if there are continuous features.
      // Check whether any of the features is continuous
      val continuousFeatures = Range(0, numFeatures).filter(metadata.isContinuous)
      val sampledInput = if (continuousFeatures.nonEmpty) {
        // Calculate the number of samples for approximate quantile calculation.
        // Number of samples to draw, at least 10000
        val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000)
        // Compute the sampling fraction
        val fraction = if (requiredSamples < metadata.numExamples) {
          requiredSamples.toDouble / metadata.numExamples
        } else {
          1.0
        }
        // Sample the data, without replacement
        input.sample(withReplacement = false, fraction, new XORShiftRandom().nextInt())
      } else {
        input.sparkContext.emptyRDD[LabeledPoint]
      }
      // Split-point strategy; this version of Spark implements only one: Sort
      metadata.quantileStrategy match {
        case Sort =>
          findSplitsBinsBySorting(sampledInput, metadata, continuousFeatures)
        case MinMax =>
          throw new UnsupportedOperationException("minmax not supported yet.")
        case ApproxHist =>
          throw new UnsupportedOperationException("approximate histogram not supported yet.")
      }
    }
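
  The sampling fraction used above is easy to isolate. The following sketch copies only that arithmetic into a standalone function; `SampleFractionSketch` and `sampleFraction` are names invented for this example.

```scala
object SampleFractionSketch {
  // Fraction of the input that is sampled to estimate split points
  // for continuous features.
  def sampleFraction(maxBins: Int, numExamples: Long): Double = {
    // At least 10000 samples, or maxBins^2 if that is larger.
    val requiredSamples = math.max(maxBins * maxBins, 10000)
    if (requiredSamples < numExamples) {
      requiredSamples.toDouble / numExamples
    } else {
      1.0 // small data set: use everything
    }
  }
}
```

  With the default-sized `maxBins = 32` and one million examples the fraction is 0.01, so only about 10000 rows take part in the sort that finds the split points.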


    private def findSplitsBinsBySorting(
        input: RDD[LabeledPoint],
        metadata: DecisionTreeMetadata,
        continuousFeatures: IndexedSeq[Int]): (Array[Array[Split]], Array[Array[Bin]]) = {
      def findSplits(
          featureIndex: Int,
          featureSamples: Iterable[Double]): (Int, (Array[Split], Array[Bin])) = {
        // Each feature has its own group of split positions; splits is sorted here
        val splits = {
          // findSplitsForContinuousFeature returns all split positions of a continuous feature
          val featureSplits = findSplitsForContinuousFeature(
            featureSamples.toArray,
            metadata,
            featureIndex)
          featureSplits.map(threshold => new Split(featureIndex, threshold, Continuous, Nil))
        }
        // Bins that correspond to the split positions
        val bins = {
          // The leftmost bin uses the minimum threshold Double.MinValue as its left split
          val lowSplit = new DummyLowSplit(featureIndex, Continuous)
          // The last bin uses the maximum threshold Double.MaxValue as its right split
          val highSplit = new DummyHighSplit(featureIndex, Continuous)
          // tack the dummy splits on either side of the computed splits
          val allSplits = lowSplit +: splits.toSeq :+ highSplit
          // Combine each pair of adjacent splits into one bin
          allSplits.sliding(2).map {
            case Seq(left, right) => new Bin(left, right, Continuous, Double.MinValue)
          }.toArray
        }
        (featureIndex, (splits, bins))
      }
      val continuousSplits = {
        // reduce the parallelism for split computations when there are less
        // continuous features than input partitions. this prevents tasks from
        // being spun up that will definitely do no work.
        val numPartitions = math.min(continuousFeatures.length, input.partitions.length)
        input
          .flatMap(point => continuousFeatures.map(idx => (idx, point.features(idx))))
          .groupByKey(numPartitions)
          .map { case (k, v) => findSplits(k, v) }
          .collectAsMap()
      }
      val numFeatures = metadata.numFeatures
      // Iterate over all features
      val (splits, bins) = Range(0, numFeatures).unzip {
        // Continuous feature
        case i if metadata.isContinuous(i) =>
          val (split, bin) = continuousSplits(i)
          metadata.setNumSplits(i, split.length)
          (split, bin)
        // Unordered categorical feature
        case i if metadata.isCategorical(i) && metadata.isUnordered(i) =>
          // Unordered features
          // 2^(maxFeatureValue - 1) - 1 combinations
          val featureArity = metadata.featureArity(i)
          val split = Range(0, metadata.numSplits(i)).map { splitIndex =>
            val categories = extractMultiClassCategories(splitIndex + 1, featureArity)
            new Split(i, Double.MinValue, Categorical, categories)
          }
          // For unordered categorical features, there is no need to construct the bins.
          // since there is a one-to-one correspondence between the splits and the bins.
          (split.toArray, Array.empty[Bin])
        // Ordered categorical feature
        case i if metadata.isCategorical(i) =>
          // Ordered features need no preprocessing: bins correspond to feature values
          // Ordered features
          // Bins correspond to feature values, so we do not need to compute splits or bins
          // beforehand. Splits are constructed as needed during training.
          (Array.empty[Split], Array.empty[Bin])
      }
      (splits.toArray, bins.toArray)
    }


    private[tree] def findSplitsForContinuousFeature(
        featureSamples: Array[Double],
        metadata: DecisionTreeMetadata,
        featureIndex: Int): Array[Double] = {
      val splits = {
        // The number of splits is the number of bins minus 1, i.e. m - 1
        val numSplits = metadata.numSplits(featureIndex)
        // (feature value, number of occurrences of that value)
        val valueCountMap = featureSamples.foldLeft(Map.empty[Double, Int]) { (m, x) =>
          m + ((x, m.getOrElse(x, 0) + 1))
        }
        // Sort by feature value
        val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray
        // if possible splits is not enough or just enough, just return all possible splits
        val possibleSplits = valueCounts.length
        // If there are no more distinct values than splits, every distinct value is a split point
        if (possibleSplits <= numSplits) {
          valueCounts.map(_._1)
        } else {
          // Equal-frequency splitting
          // Stride between two split points
          val stride: Double = featureSamples.length.toDouble / (numSplits + 1)
          val splitsBuilder = Array.newBuilder[Double]
          var index = 1
          // currentCount: sum of counts of values that have been visited
          // Starts with the count of the first (smallest) value
          var currentCount = valueCounts(0)._2
          // targetCount: target value for `currentCount`.
          // If `currentCount` is closest value to `targetCount`,
          // then current value is a split threshold.
          // After finding a split threshold, `targetCount` is added by stride.
          var targetCount = stride
          while (index < valueCounts.length) {
            val previousCount = currentCount
            currentCount += valueCounts(index)._2
            val previousGap = math.abs(previousCount - targetCount)
            val currentGap = math.abs(currentCount - targetCount)
            // If adding count of current value to currentCount
            // makes the gap between currentCount and targetCount smaller,
            // previous value is a split threshold.
            if (previousGap < currentGap) {
              splitsBuilder += valueCounts(index - 1)._1
              targetCount += stride
            }
            index += 1
          }
          splitsBuilder.result()
        }
      }
      splits
    }
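
  To see the equal-frequency logic in action, the sketch below re-implements the same loop without the Spark metadata object, so it can be run on a small array. `QuantileSplitSketch` and `findSplits` are names made up for this example, and `numSplits` is passed in directly instead of being read from `metadata`.

```scala
object QuantileSplitSketch {
  def findSplits(samples: Array[Double], numSplits: Int): Array[Double] = {
    // (distinct value, count), sorted by value
    val valueCounts = samples.groupBy(identity)
      .map { case (v, xs) => (v, xs.length) }
      .toSeq.sortBy(_._1).toArray
    if (valueCounts.length <= numSplits) {
      // Not enough distinct values: every distinct value becomes a split
      valueCounts.map(_._1)
    } else {
      val stride = samples.length.toDouble / (numSplits + 1)
      val builder = Array.newBuilder[Double]
      var currentCount = valueCounts(0)._2
      var targetCount = stride
      var index = 1
      while (index < valueCounts.length) {
        val previousCount = currentCount
        currentCount += valueCounts(index)._2
        // The previous value is a threshold if stepping past it moves the
        // running count further away from the target.
        if (math.abs(previousCount - targetCount) < math.abs(currentCount - targetCount)) {
          builder += valueCounts(index - 1)._1
          targetCount += stride
        }
        index += 1
      }
      builder.result()
    }
  }
}
```

  For the nine samples 1, 1, 1, 2, 3, 4, 5, 5, 6 and `numSplits = 2`, the stride is 3 and the function returns the thresholds 1.0 and 4.0, which leaves three samples in each of the three bins.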


5.1.2 Building the random forest iteratively

    // Whether to use the node ID cache. Node IDs start at 1: 1 is the root of a tree,
    // its left child is 2, its right child is 3, and so on.
    val nodeIdCache = if (strategy.useNodeIdCache) {
      Some(NodeIdCache.init(
        data = baggedInput,
        numTrees = numTrees,
        checkpointInterval = strategy.checkpointInterval,
        initVal = 1))
    } else {
      None
    }
    // FIFO queue of nodes to train: (treeIndex, node)
    val nodeQueue = new mutable.Queue[(Int, Node)]()
    val rng = new scala.util.Random()
    rng.setSeed(seed)
    // Allocate and queue root nodes.
    // Create the root node of every tree
    val topNodes: Array[Node] = Array.fill[Node](numTrees)(Node.emptyNode(nodeIndex = 1))
    // Enqueue (tree index, root node); tree indices start at 0, node indices start at 1
    Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex))))
    while (nodeQueue.nonEmpty) {
      // Collect some nodes to split, and choose features for each node (if subsampling).
      // Each group of nodes may come from one or multiple trees, and at multiple levels.
      // Collect the nodes to split for each tree; nodesForGroup holds the nodes to split
      val (nodesForGroup, treeToNodeToIndexInfo) =
        RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng)
      // Find the best splits
      DecisionTree.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
        treeToNodeToIndexInfo, splits, bins, nodeQueue, timer, nodeIdCache = nodeIdCache)
    }
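
  The node numbering described in the first comment (root 1, left child 2, right child 3) matches the usual heap-style layout of a binary tree. The sketch below shows the index arithmetic that follows from that layout. It is an illustration with hypothetical names, not the `Node` companion object from Spark, and it assumes the numbering continues level by level in the same way.

```scala
object NodeIndexSketch {
  def leftChild(id: Int): Int = id * 2
  def rightChild(id: Int): Int = id * 2 + 1
  def parent(id: Int): Int = id / 2

  // Level of a node, with the root (id 1) at level 0:
  // the position of the highest set bit of the id.
  def level(id: Int): Int = {
    require(id >= 1, "node ids start at 1")
    31 - Integer.numberOfLeadingZeros(id)
  }
}
```

  Under this numbering the nodes 4 to 7 all sit on level 2, and `parent(7)` is 3.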


  • Collect the nodes that need to be split for each tree
    private[tree] def selectNodesToSplit(
        nodeQueue: mutable.Queue[(Int, Node)],
        maxMemoryUsage: Long,
        metadata: DecisionTreeMetadata,
        rng: scala.util.Random): (Map[Int, Array[Node]], Map[Int, Map[Int, NodeIndexInfo]]) = {
      // nodesForGroup holds the nodes to split: treeIndex --> nodes
      val mutableNodesForGroup = new mutable.HashMap[Int, mutable.ArrayBuffer[Node]]()
      // mutableTreeToNodeToIndexInfo holds the indices of the features selected for each node:
      // treeIndex --> (global) node index --> (node index in group, feature indices)
      // The (global) node index is the index in the tree; node indices in the group
      // are in the range [0, numNodesInGroup)
      val mutableTreeToNodeToIndexInfo =
        new mutable.HashMap[Int, mutable.HashMap[Int, NodeIndexInfo]]()
      var memUsage: Long = 0L
      var numNodesInGroup = 0
      while (nodeQueue.nonEmpty && memUsage < maxMemoryUsage) {
        val (treeIndex, node) = nodeQueue.head
        // Choose subset of features for node (if subsampling).
        val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
          Some(SamplingUtils.reservoirSampleAndCount(Range(0,
            metadata.numFeatures).iterator, metadata.numFeaturesPerNode, rng.nextLong)._1)
        } else {
          None
        }
        // Check if enough memory remains to add this node to the group.
        val nodeMemUsage = RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L
        if (memUsage + nodeMemUsage <= maxMemoryUsage) {
          nodeQueue.dequeue()
          mutableNodesForGroup.getOrElseUpdate(treeIndex, new mutable.ArrayBuffer[Node]()) += node
          mutableTreeToNodeToIndexInfo
            .getOrElseUpdate(treeIndex, new mutable.HashMap[Int, NodeIndexInfo]())(node.id)
            = new NodeIndexInfo(numNodesInGroup, featureSubset)
        }
        numNodesInGroup += 1
        memUsage += nodeMemUsage
      }
      // Convert the mutable maps to immutable ones
      val nodesForGroup: Map[Int, Array[Node]] = mutableNodesForGroup.mapValues(_.toArray).toMap
      val treeToNodeToIndexInfo = mutableTreeToNodeToIndexInfo.mapValues(_.toMap).toMap
      (nodesForGroup, treeToNodeToIndexInfo)
    }
  • Select the best split
    // All nodes that can be split
    val nodes = new Array[Node](numNodes)
    nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
      nodesForTree.foreach { node =>
        nodes(treeToNodeToIndexInfo(treeIndex)(node.id).nodeIndexInGroup) = node
      }
    }
    // In each partition, iterate all instances and compute aggregate stats for each node,
    // yield an (nodeIndex, nodeAggregateStats) pair for each node.
    // After a `reduceByKey` operation,
    // stats of a node will be shuffled to a particular partition and be combined together,
    // then best splits for nodes are found there.
    // Finally, only best Splits for nodes are collected to driver to construct decision tree.
    // Get the features that belong to each node
    val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo)
    val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures)
    val partitionAggregates : RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) {
      input.zip(nodeIdCache.get.nodeIdsForInstances).mapPartitions { points =>
        // Construct a nodeStatsAggregators array to hold node aggregate stats,
        // each node will have a nodeStatsAggregator
        val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
          // Feature subset of this node
          val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
            Some(nodeToFeatures(nodeIndex))
          }
          // DTStatsAggregator refers to an ImpurityAggregator,
          // which provides the logic for computing the impurity
          new DTStatsAggregator(metadata, featuresForNode)
        }
        // Iterate over all instances of this partition and update the aggregate statistics;
        // each instance contributes with its sampling weight
        points.foreach(binSeqOpWithNodeIdCache(nodeStatsAggregators, _))
        // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
        // which can be combined with other partition using `reduceByKey`
        nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
      }
    } else {
      input.mapPartitions { points =>
        // Construct a nodeStatsAggregators array to hold node aggregate stats,
        // each node will have a nodeStatsAggregator
        val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
          // Feature subset of this node
          val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
            Some(nodeToFeatures(nodeIndex))
          }
          // DTStatsAggregator refers to an ImpurityAggregator,
          // which provides the logic for computing the impurity
          new DTStatsAggregator(metadata, featuresForNode)
        }
        // Iterate over all instances of this partition and update the aggregate statistics
        points.foreach(binSeqOp(nodeStatsAggregators, _))
        // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
        // which can be combined with other partition using `reduceByKey`
        nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
      }
    }
    val nodeToBestSplits = partitionAggregates.reduceByKey((a, b) => a.merge(b))
      .map { case (nodeIndex, aggStats) =>
        val featuresForNode = nodeToFeaturesBc.value.map { nodeToFeatures =>
          nodeToFeatures(nodeIndex)
        }
        // find best split for each node
        val (split: Split, stats: InformationGainStats, predict: Predict) =
          binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex))
        (nodeIndex, (split, stats, predict))
      }.collectAsMap()
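
  The aggregation pattern above (per-partition statistics, then a merge by node index) can be imitated with ordinary Scala collections. In this sketch a "partition" is just a `Seq`, an instance is a `(nodeIndex, label)` pair, and the statistics are plain per-class counts. All names are hypothetical, and the real `DTStatsAggregator` keeps statistics per feature and per bin rather than one counter array per node.

```scala
object AggregateSketch {
  type Stats = Array[Long] // label counts for one node

  // One pass over a partition: update the counter array of the node
  // each instance currently belongs to.
  def aggregatePartition(
      points: Seq[(Int, Int)],
      numNodes: Int,
      numClasses: Int): Seq[(Int, Stats)] = {
    val aggs = Array.fill(numNodes)(new Array[Long](numClasses))
    points.foreach { case (nodeIndex, label) => aggs(nodeIndex)(label) += 1 }
    aggs.zipWithIndex.map(_.swap).toSeq
  }

  // Element-wise sum, playing the role of DTStatsAggregator.merge.
  def merge(a: Stats, b: Stats): Stats = a.zip(b).map { case (x, y) => x + y }

  // Local stand-in for reduceByKey over all partitions.
  def reduceByKey(
      partitions: Seq[Seq[(Int, Int)]],
      numNodes: Int,
      numClasses: Int): Map[Int, Stats] =
    partitions.flatMap(aggregatePartition(_, numNodes, numClasses))
      .groupBy(_._1)
      .map { case (node, xs) => node -> xs.map(_._2).reduce(merge) }
}
```

  Only the small per-node arrays cross partition boundaries, not the instances themselves, which is the point of aggregating before the shuffle.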


    private def binsToBestSplit(
        binAggregates: DTStatsAggregator,
        splits: Array[Array[Split]],
        featuresForNode: Option[Array[Int]],
        node: Node): (Split, InformationGainStats, Predict) = {
      // For the root node (level 0) the prediction and impurity are not known yet and are
      // computed below; for any other node the values stored on the node are reused.
      val level = Node.indexToLevel(node.id)
      var predictWithImpurity: Option[(Predict, Double)] = if (level == 0) {
        None
      } else {
        Some((node.predict, node.impurity))
      }
      // For every feature and split, compute the information gain and pick the best (feature, split)
      val (bestSplit, bestSplitStats) =
        Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx =>
          val featureIndex = if (featuresForNode.nonEmpty) {
            featuresForNode.get.apply(featureIndexIdx)
          } else {
            featureIndexIdx
          }
          val numSplits = binAggregates.metadata.numSplits(featureIndex)
          // Continuous feature
          if (binAggregates.metadata.isContinuous(featureIndex)) {
            // Cumulative sum (scanLeft) of bin statistics.
            // Afterwards, binAggregates for a bin is the sum of aggregates for
            // that bin + all preceding bins.
            val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
            var splitIndex = 0
            while (splitIndex < numSplits) {
              binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
              splitIndex += 1
            }
            // Find best split.
            val (bestFeatureSplitIndex, bestFeatureGainStats) =
              Range(0, numSplits).map { case splitIdx =>
                // Statistics used to compute the impurity of the left and right child
                val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
                val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
                rightChildStats.subtract(leftChildStats)
                // Compute the prediction and impurity of this node unless they are already known
                predictWithImpurity = Some(predictWithImpurity.getOrElse(
                  calculatePredictImpurity(leftChildStats, rightChildStats)))
                // Compute the information gain, which is used to judge whether a split is the
                // best one; see section 1.4.4 of the decision tree chapter
                val gainStats = calculateGainForSplit(leftChildStats,
                  rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
                (splitIdx, gainStats)
              }.maxBy(_._2.gain)
            (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
          }
          // Unordered categorical feature
          else if (binAggregates.metadata.isUnordered(featureIndex)) {
            // Unordered categorical feature
            val (leftChildOffset, rightChildOffset) =
              binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
            val (bestFeatureSplitIndex, bestFeatureGainStats) =
              Range(0, numSplits).map { splitIndex =>
                val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
                val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
                predictWithImpurity = Some(predictWithImpurity.getOrElse(
                  calculatePredictImpurity(leftChildStats, rightChildStats)))
                val gainStats = calculateGainForSplit(leftChildStats,
                  rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
                (splitIndex, gainStats)
              }.maxBy(_._2.gain)
            (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
          } else { // Ordered categorical feature
            // Ordered categorical feature
            val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
            val numBins = binAggregates.metadata.numBins(featureIndex)
            /* Each bin is one category (feature value).
             * The bins are ordered based on centroidForCategories, and this ordering determines which
             * splits are considered. (With K categories, we consider K - 1 possible splits.)
             *
             * centroidForCategories is a list: (category, centroid)
             */
            // Multiclass classification
            val centroidForCategories = if (binAggregates.metadata.isMulticlass) {
              // For categorical variables in multiclass classification,
              // the bins are ordered by the impurity of their corresponding labels.
              Range(0, numBins).map { case featureValue =>
                val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
                val centroid = if (categoryStats.count != 0) {
                  // The centroid is the impurity of the labels in this category
                  categoryStats.calculate()
                } else {
                  Double.MaxValue
                }
                (featureValue, centroid)
              }
            } else { // Regression or binary classification
              // For categorical variables in regression and binary classification,
              // the bins are ordered by the centroid of their corresponding labels.
              Range(0, numBins).map { case featureValue =>
                val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
                val centroid = if (categoryStats.count != 0) {
                  // The centroid is the prediction (mean label) of this category
                  categoryStats.predict
                } else {
                  Double.MaxValue
                }
                (featureValue, centroid)
              }
            }
            // bins sorted by centroids
            val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2)
            // Cumulative sum (scanLeft) of bin statistics.
            // Afterwards, binAggregates for a bin is the sum of aggregates for
            // that bin + all preceding bins.
            var splitIndex = 0
            while (splitIndex < numSplits) {
              val currentCategory = categoriesSortedByCentroid(splitIndex)._1
              val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1
              // Merge the statistics of the two bins
              binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory)
              splitIndex += 1
            }
            // lastCategory = index of bin with total aggregates for this (node, feature)
            val lastCategory = categoriesSortedByCentroid.last._1
            // Find best split.
            // Choose the best split by information gain
            val (bestFeatureSplitIndex, bestFeatureGainStats) =
              Range(0, numSplits).map { splitIndex =>
                val featureValue = categoriesSortedByCentroid(splitIndex)._1
                val leftChildStats =
                  binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
                val rightChildStats =
                  binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
                rightChildStats.subtract(leftChildStats)
                predictWithImpurity = Some(predictWithImpurity.getOrElse(
                  calculatePredictImpurity(leftChildStats, rightChildStats)))
                val gainStats = calculateGainForSplit(leftChildStats,
                  rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
                (splitIndex, gainStats)
              }.maxBy(_._2.gain)
            val categoriesForSplit =
              categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
            val bestFeatureSplit =
              new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit)
            (bestFeatureSplit, bestFeatureGainStats)
          }
        }.maxBy(_._2.gain)
      (bestSplit, bestSplitStats, predictWithImpurity.get._1)
    }
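
  The continuous-feature branch is easier to follow on a toy example. The sketch below takes the per-bin label counts of one feature, builds the cumulative sums, and evaluates every split with the Gini impurity. It is a simplified stand-in with hypothetical names: the gain is taken to be the parent impurity minus the weighted child impurities, and checks such as `minInstancesPerNode` and `minInfoGain` that appear in the metadata are left out.

```scala
object BestSplitSketch {
  // Gini impurity of a vector of class counts.
  def gini(counts: Array[Double]): Double = {
    val total = counts.sum
    if (total == 0) 0.0 else 1.0 - counts.map(c => (c / total) * (c / total)).sum
  }

  // bins(i) holds the per-class label counts of bin i for one continuous feature.
  // Split i sends bins 0..i to the left child and the remaining bins to the right.
  // Returns (index of the best split, its gain).
  def bestSplit(bins: Array[Array[Double]]): (Int, Double) = {
    require(bins.length >= 2, "need at least two bins to split")
    val numClasses = bins.head.length
    // Cumulative sums: cumulative(i) = bins(0) + ... + bins(i)
    val cumulative = bins.scanLeft(Array.fill(numClasses)(0.0)) { (acc, b) =>
      acc.zip(b).map { case (x, y) => x + y }
    }.tail
    val total = cumulative.last
    val totalCount = total.sum
    val parentImpurity = gini(total)
    (0 until (bins.length - 1)).map { i =>
      val left = cumulative(i)
      // Right child = total minus left, as in rightChildStats.subtract(leftChildStats)
      val right = total.zip(left).map { case (t, l) => t - l }
      val gain = parentImpurity -
        left.sum / totalCount * gini(left) -
        right.sum / totalCount * gini(right)
      (i, gain)
    }.maxBy(_._2)
  }
}
```

  For the three bins (4, 0), (3, 1), (0, 4) the best split is index 1, that is, the first two bins go left and the last bin goes right.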

5.2 Prediction


    // Different strategies use different prediction methods
    def predict(features: Vector): Double = {
      (algo, combiningStrategy) match {
        case (Regression, Sum) =>
          predictBySumming(features)
        case (Regression, Average) =>
          predictBySumming(features) / sumWeights
        case (Classification, Sum) => // binary classification
          val prediction = predictBySumming(features)
          // TODO: predicted labels are +1 or -1 for GBT. Need a better way to store this info.
          if (prediction > 0.0) 1.0 else 0.0
        case (Classification, Vote) =>
          predictByVoting(features)
        case _ =>
          throw new IllegalArgumentException()
      }
    }

    private def predictBySumming(features: Vector): Double = {
      val treePredictions = trees.map(_.predict(features))
      // Inner product of the two vectors
      blas.ddot(numTrees, treePredictions, 1, treeWeights, 1)
    }

    // Predict by (weighted) voting
    private def predictByVoting(features: Vector): Double = {
      val votes = mutable.Map.empty[Int, Double]
      trees.view.zip(treeWeights).foreach { case (tree, weight) =>
        val prediction = tree.predict(features).toInt
        votes(prediction) = votes.getOrElse(prediction, 0.0) + weight
      }
      votes.maxBy(_._2)._1
    }
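
  The two combining strategies relevant to a random forest, weighted voting for classification and weighted averaging for regression, can be tried out without any trained trees. The sketch below takes the per-tree predictions as plain sequences; `EnsemblePredictSketch` and its method names are made up for this example.

```scala
import scala.collection.mutable

object EnsemblePredictSketch {
  // Classification: each tree votes for its predicted class with its weight.
  def predictByVoting(treePredictions: Seq[Int], treeWeights: Seq[Double]): Int = {
    val votes = mutable.Map.empty[Int, Double]
    treePredictions.zip(treeWeights).foreach { case (prediction, weight) =>
      votes(prediction) = votes.getOrElse(prediction, 0.0) + weight
    }
    votes.maxBy(_._2)._1
  }

  // Regression: weighted sum of the tree predictions divided by the sum of weights.
  def predictByAveraging(treePredictions: Seq[Double], treeWeights: Seq[Double]): Double = {
    val weightedSum = treePredictions.zip(treeWeights).map { case (p, w) => p * w }.sum
    weightedSum / treeWeights.sum
  }
}
```

  With equal weights the vote is a simple majority; with weights (3, 1, 1) a single tree predicting class 1 outvotes two trees predicting class 0.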



