最优化算法 - 迭代再加权最小二乘
优质
小牛编辑
129浏览
2023-12-01
1 原理
迭代再加权最小二乘(IRLS
)用于解决特定的最优化问题,这个最优化问题的目标函数如下所示:
arg min{\beta} \sum{i=1}^{n}|y{i} - f{i}(\beta)|^{p}
这个目标函数可以通过迭代的方法求解。在每次迭代中,解决一个带权最小二乘问题,形式如下:
\beta ^{t+1} = argmin{\beta} \sum{i=1}^{n} w{i}(\beta^{(t)}))|y{i} - f_{i}(\beta)|^{2} = (X^{T}W^{(t)}X)^{-1}X^{T}W^{(t)}y
在这个公式中,$W^{(t)}$是权重对角矩阵,它的所有元素都初始化为1。每次迭代中,通过下面的公式更新。
W{i}^{(t)} = |y{i} - X_{i}\beta^{(t)}|^{p-2}
2 源码分析
在spark ml
中,迭代再加权最小二乘主要解决广义线性回归问题。下面看看实现代码。
2.1 更新权重
// Update offsets and weights using reweightFunc
val newInstances = instances.map { instance =>
val (newOffset, newWeight) = reweightFunc(instance, oldModel)
Instance(newOffset, newWeight, instance.features)
}
这里使用reweightFunc
方法更新权重。具体的实现在广义线性回归的实现中。
/**
* The reweight function used to update offsets and weights
* at each iteration of [[IterativelyReweightedLeastSquares]].
*/
val reweightFunc: (Instance, WeightedLeastSquaresModel) => (Double, Double) = {
(instance: Instance, model: WeightedLeastSquaresModel) => {
val eta = model.predict(instance.features)
val mu = fitted(eta)
val offset = eta + (instance.label - mu) * link.deriv(mu)
val weight = instance.weight / (math.pow(this.link.deriv(mu), 2.0) * family.variance(mu))
(offset, weight)
}
}
def fitted(eta: Double): Double = family.project(link.unlink(eta))
这里的model.predict
利用带权最小二乘模型预测样本的取值,然后调用fitted
方法计算均值函数$\mu$。offset
表示
更新后的标签值,weight
表示更新后的权重。关于链接函数的相关计算可以参考广义线性回归的分析。
有一点需要说明的是,这段代码中标签和权重的更新并没有参照上面的原理或者说我理解有误。
2.2 训练新的模型
// 使用更新过的样本训练新的模型
model = new WeightedLeastSquares(fitIntercept, regParam, elasticNetParam = 0.0,
standardizeFeatures = false, standardizeLabel = false).fit(newInstances)
// 检查是否收敛
val oldCoefficients = oldModel.coefficients
val coefficients = model.coefficients
BLAS.axpy(-1.0, coefficients, oldCoefficients)
val maxTolOfCoefficients = oldCoefficients.toArray.reduce { (x, y) =>
math.max(math.abs(x), math.abs(y))
}
val maxTol = math.max(maxTolOfCoefficients, math.abs(oldModel.intercept - model.intercept))
if (maxTol < tol) {
converged = true
}
训练完新的模型后,重复2.1步,直到参数收敛或者到达迭代的最大次数。