最优化算法 - 带权最小二乘


1 原理


  • $w_i$表示第i个观察样本的权重;
  • $a_i$表示第i个观察样本的特征向量;
  • $b_i$表示第i个观察样本的标签。


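$$\min_{x}\ \frac{1}{2} \sum_{i=1}^n \frac{w_i(a_i^T x - b_i)^2}{\sum_{k=1}^n w_k} + \frac{1}{2}\frac{\lambda}{\delta}\sum_{j=1}^m(\sigma_j x_j)^2$$

  其中,$\lambda$是正则化参数,$\delta$和$\sigma_j$分别由standardizeLabel和standardizeFeatures两个开关控制,$m$是特征数,$n$是样本数。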



  spark ml中使用WeightedLeastSquares求解带权最小二乘问题。WeightedLeastSquares支持L2正则化(使用拟牛顿求解器时,还可以通过elasticNetParam加入L1正则化),并且提供了正则化和标准化
的开关。为了使正规方程(normal equation)方法有效,特征数不能超过4096。如果超过4096,用L-BFGS代替。下面从代码层面介绍带权最小二乘优化算法。

2 代码解析


  1. private[ml] class WeightedLeastSquares( // constructor signature only; the class body is not shown in this listing
  2. val fitIntercept: Boolean, // whether to fit an intercept term
  3. val regParam: Double, // regularization parameter, i.e. lambda in the objective above
  4. val elasticNetParam: Double, // alpha; controls the mix between L1 and L2 regularization
  5. val standardizeFeatures: Boolean, // whether to standardize the features
  6. val standardizeLabel: Boolean, // whether to standardize the label
  7. val solverType: WeightedLeastSquares.Solver = WeightedLeastSquares.Auto, // which solver to use; defaults to Auto
  8. val maxIter: Int = 100, // number of iterations; forwarded to QuasiNewtonSolver
  9. val tol: Double = 1e-6) extends Logging with Serializable // forwarded to QuasiNewtonSolver; presumably the convergence tolerance
  10. sealed trait Solver // closed set of solver choices accepted by solverType
  11. case object Auto extends Solver // let fit pick the solver from the regularization settings
  12. case object Cholesky extends Solver // request the Cholesky solver
  13. case object QuasiNewton extends Solver // request the Quasi-Newton solver


2.1 求解过程


  1. def fit(instances: RDD[Instance]): WeightedLeastSquaresModel


  • 1 统计样本信息
  1. val summary = instances.treeAggregate(new Aggregator)(_.add(_), _.merge(_)) // fold every instance into an Aggregator with add, combine partial aggregators with merge


  1. private class Aggregator extends Serializable { // running weighted statistics; only the fields are shown in this listing
  2. var initialized: Boolean = false // checked by add/merge, which call init when it is still false
  3. var k: Int = _ // number of features
  4. var count: Long = _ // number of instances
  5. var triK: Int = _ // number of entries stored for aaSum; presumably k * (k + 1) / 2 for packed triangular storage (init is not shown, confirm)
  6. var wSum: Double = _ // sum of weights
  7. private var wwSum: Double = _ // sum of squared weights
  8. private var bSum: Double = _ // weighted sum of labels (w * l)
  9. private var bbSum: Double = _ // weighted sum of squared labels (w * l * l)
  10. private var aSum: DenseVector = _ // weighted sum of feature vectors (w * f)
  11. private var abSum: DenseVector = _ // weighted sum of feature-times-label (w * l * f)
  12. private var aaSum: DenseVector = _ // weighted sum of feature outer products (w * f * f^T), updated via BLAS.spr
  13. }


  1. /**
  2. * Adds an instance: updates the count and every weighted running sum, then returns this aggregator.
  3. */
  4. def add(instance: Instance): this.type = {
  5. val Instance(l, w, f) = instance // l: label, w: weight, f: feature vector
  6. val ak = f.size
  7. if (!initialized) { // the first instance seen fixes the feature dimension
  8. init(ak)
  9. }
  10. assert(ak == k, s"Dimension mismatch. Expect vectors of size $k but got $ak.")
  11. count += 1L
  12. wSum += w
  13. wwSum += w * w
  14. bSum += w * l
  15. bbSum += w * l * l
  16. BLAS.axpy(w, f, aSum) // aSum += w * f
  17. BLAS.axpy(w * l, f, abSum) // abSum += (w * l) * f
  18. BLAS.spr(w, f, aaSum) // aaSum += w * f * f^T
  19. this
  20. }
  21. /**
  22. * Merges another [[Aggregator]] into this one by adding its count and running sums, then returns this aggregator.
  23. */
  24. def merge(other: Aggregator): this.type = {
  25. if (!other.initialized) { // nothing was added to other, so there is nothing to merge
  26. this
  27. } else {
  28. if (!initialized) { // this aggregator is still empty: size it from other
  29. init(other.k)
  30. }
  31. assert(k == other.k, s"dimension mismatch: this.k = $k but other.k = ${other.k}")
  32. count += other.count
  33. wSum += other.wSum
  34. wwSum += other.wwSum
  35. bSum += other.bSum
  36. bbSum += other.bbSum
  37. BLAS.axpy(1.0, other.aSum, aSum) // aSum += other.aSum
  38. BLAS.axpy(1.0, other.abSum, abSum) // abSum += other.abSum
  39. BLAS.axpy(1.0, other.aaSum, aaSum) // aaSum += other.aaSum
  40. this
  41. }
  42. }


  1. aBar: 特征加权平均数
  2. bBar: 标签加权平均数
  3. aaBar: 特征平方加权平均数
  4. bbBar: 标签平方加权平均数
  5. aStd: 特征的加权总体标准差
  6. bStd: 标签的加权总体标准差
  7. aVar: 带权的特征总体方差


  1. // scale bBar and bbBar by the label standard deviation
  2. val bBar = summary.bBar / bStd
  3. val bbBar = summary.bbBar / (bStd * bStd)
  4. val aStd = summary.aStd
  5. val aStdValues = aStd.values
  6. // scale aBar by the per-feature standard deviation
  7. val aBar = {
  8. val _aBar = summary.aBar
  9. val _aBarValues = _aBar.values
  10. var i = 0
  11. // scale aBar to standardized space in-place
  12. while (i < numFeatures) {
  13. if (aStdValues(i) == 0.0) { // zero-std feature: store 0.0 instead of dividing by zero
  14. _aBarValues(i) = 0.0
  15. } else {
  16. _aBarValues(i) /= aStdValues(i)
  17. }
  18. i += 1
  19. }
  20. _aBar
  21. }
  22. val aBarValues = aBar.values
  23. // scale abBar by the feature and label standard deviations
  24. val abBar = {
  25. val _abBar = summary.abBar
  26. val _abBarValues = _abBar.values
  27. var i = 0
  28. // scale abBar to standardized space in-place
  29. while (i < numFeatures) {
  30. if (aStdValues(i) == 0.0) { // zero-std feature: store 0.0 instead of dividing by zero
  31. _abBarValues(i) = 0.0
  32. } else {
  33. _abBarValues(i) /= (aStdValues(i) * bStd)
  34. }
  35. i += 1
  36. }
  37. _abBar
  38. }
  39. val abBarValues = abBar.values
  40. // scale aaBar by the standard deviations of both features of each pair
  41. val aaBar = {
  42. val _aaBar = summary.aaBar
  43. val _aaBarValues = _aaBar.values
  44. var j = 0
  45. var p = 0 // running index into the packed values; entries are visited as (i, j) with i <= j
  46. // scale aaBar to standardized space in-place
  47. while (j < numFeatures) {
  48. val aStdJ = aStdValues(j)
  49. var i = 0
  50. while (i <= j) {
  51. val aStdI = aStdValues(i)
  52. if (aStdJ == 0.0 || aStdI == 0.0) { // either feature has zero std: store 0.0
  53. _aaBarValues(p) = 0.0
  54. } else {
  55. _aaBarValues(p) /= (aStdI * aStdJ)
  56. }
  57. p += 1
  58. i += 1
  59. }
  60. j += 1
  61. }
  62. _aaBar
  63. }
  64. val aaBarValues = aaBar.values
  • 2 处理L2正则项
  1. val effectiveRegParam = regParam / bStd
  2. val effectiveL1RegParam = elasticNetParam * effectiveRegParam // L1 share of the penalty
  3. val effectiveL2RegParam = (1.0 - elasticNetParam) * effectiveRegParam // L2 share of the penalty
  4. // add the L2 regularization term to the diagonal entries of aaBar
  5. var i = 0
  6. var j = 2
  7. while (i < triK) { // i takes the values 0, 2, 5, 9, ... because the step j grows by one each round
  8. var lambda = effectiveL2RegParam
  9. if (!standardizeFeatures) {
  10. val std = aStdValues(j - 2) // j - 2 is the feature index for this diagonal entry
  11. if (std != 0.0) {
  12. lambda /= (std * std) // rescale the penalty by the feature variance
  13. } else {
  14. lambda = 0.0
  15. }
  16. }
  17. if (!standardizeLabel) {
  18. lambda *= bStd
  19. }
  20. aaBarValues(i) += lambda
  21. i += j
  22. j += 1
  23. }
  • 3 选择solver


  1. val solver = if ((solverType == WeightedLeastSquares.Auto && elasticNetParam != 0.0 && // Quasi-Newton when Auto is used with a non-zero elasticNetParam and regParam,
  2. regParam != 0.0) || (solverType == WeightedLeastSquares.QuasiNewton)) { // or when QuasiNewton is requested explicitly
  3. val effectiveL1RegFun: Option[(Int) => Double] = if (effectiveL1RegParam != 0.0) { // per-coefficient L1 penalty, only when the L1 part is non-zero
  4. Some((index: Int) => {
  5. if (fitIntercept && index == numFeatures) { // index numFeatures is the intercept slot: no penalty
  6. 0.0
  7. } else {
  8. if (standardizeFeatures) {
  9. effectiveL1RegParam
  10. } else {
  11. if (aStdValues(index) != 0.0) effectiveL1RegParam / aStdValues(index) else 0.0 // rescale by the feature std; zero-std features get no penalty
  12. }
  13. }
  14. })
  15. } else {
  16. None
  17. }
  18. new QuasiNewtonSolver(fitIntercept, maxIter, tol, effectiveL1RegFun)
  19. } else {
  20. new CholeskySolver // every other combination uses Cholesky
  21. }


  • 4 处理结果
  1. val solution = solver match { // ab and aa are defined outside this listing
  2. case cholesky: CholeskySolver =>
  3. try {
  4. cholesky.solve(bBar, bbBar, ab, aa, aBar)
  5. } catch {
  6. // if Auto solver is used and Cholesky fails due to singular AtA, then fall back to
  7. // Quasi-Newton solver.
  8. case _: SingularMatrixException if solverType == WeightedLeastSquares.Auto =>
  9. logWarning("Cholesky solver failed due to singular covariance matrix. " +
  10. "Retrying with Quasi-Newton solver.")
  11. // ab and aa were modified in place, so reconstruct them
  12. val _aa = getAtA(aaBarValues, aBarValues)
  13. val _ab = getAtB(abBarValues, bBar)
  14. val newSolver = new QuasiNewtonSolver(fitIntercept, maxIter, tol, None) // the fallback passes None for the L1 function
  15. newSolver.solve(bBar, bbBar, _ab, _aa, aBar)
  16. }
  17. case qn: QuasiNewtonSolver =>
  18. qn.solve(bBar, bbBar, ab, aa, aBar)
  19. }
  20. val (coefficientArray, intercept) = if (fitIntercept) { // with an intercept, the last solution entry is the intercept
  21. (solution.coefficients.slice(0, solution.coefficients.length - 1),
  22. solution.coefficients.last * bStd) // multiply the intercept by bStd to undo the label scaling
  23. } else {
  24. (solution.coefficients, 0.0)
  25. }



  1. // convert the coefficients from the scaled space to the original space
  2. var q = 0
  3. val len = coefficientArray.length
  4. while (q < len) {
  5. coefficientArray(q) *= { if (aStdValues(q) != 0.0) bStd / aStdValues(q) else 0.0 } // zero-std features end up with a coefficient of 0.0
  6. q += 1
  7. }