当前位置: 首页 > 知识库问答 >
问题:

线性回归的梯度下降实现问题

谷梁智
2023-03-14

我正在学习机器学习/线性回归的Coursera课程。下面是他们如何描述用于求解估计OLS系数的梯度下降算法:

因此,他们对系数使用w,对设计矩阵(或他们称之为特征)使用H,对因变量使用y。它们的收敛准则通常是RSS梯度的范数小于容差ε;也就是说,他们对“不收敛”的定义是:

我很难让这个算法收敛,我想知道在我的实现中是否忽略了一些东西。下面是代码。请注意,我还通过statsmodels回归库运行了我在其中使用的样本数据集(df),只是为了查看回归可以收敛,并获得与之相关的系数值。确实如此,他们是:

Intercept    4.344435
x1           4.387702
x2           0.450958

这是我的实现。在每次迭代中,它将打印RSS梯度的范数:

import numpy as np
import numpy.linalg as LA
import pandas as pd
from pandas import DataFrame

# First define the grad function: grad(RSS) = -2H'(y-Hw)
def grad_rss(df, var_name_y, var_names_h, w):
    # Set up feature matrix H
    H = DataFrame({"Intercept" : [1 for i in range(0,len(df))]})
    for var_name_h in var_names_h:
        H[var_name_h] = df[var_name_h]

    # Set up y vector
    y = df[var_name_y]

    # Calculate the gradient of the RSS:  -2H'(y - Hw)
    result = -2 * np.transpose(H.values) @ (y.values - H.values @ w)

    return result

def ols_gradient_descent(df, var_name_y, var_names_h, epsilon = 0.0001, eta = 0.05):
    # Set all initial w values to 0.0001 (not related to our choice of epsilon)
    w = np.array([0.0001 for i in range(0, len(var_names_h) + 1)])

    # Iteration counter
    t = 0

    # Basic algorithm: keep subtracting eta * grad(RSS) from w until
    # ||grad(RSS)|| < epsilon.
    while True:
        t = t + 1

        grad = grad_rss(df, var_name_y, var_names_h, w)
        norm_grad = LA.norm(grad)

        if norm_grad < epsilon:
            break
        else:
            print("{} : {}".format(t, norm_grad))
            w = w - eta * grad

            if t > 10:
                raise Exception ("Failed to converge")

    return w

# ########################################## 

df = DataFrame({
        "y" : [20,40,60,80,100] ,
        "x1" : [1,5,7,9,11] ,
        "x2" : [23,29,60,85,99]         
    })

# Run
ols_gradient_descent(df, "y", ["x1", "x2"])

不幸的是,这并没有收敛,事实上打印了一个随着每次迭代而爆炸的规范:

1 : 44114.31506051333
2 : 98203544.03067812
3 : 218612547944.95386
4 : 486657040646682.9
5 : 1.083355358314664e+18
6 : 2.411675439503567e+21
7 : 5.368670935963926e+24
8 : 1.1951287949674022e+28
9 : 2.660496151835357e+31
10 : 5.922574875391406e+34
11 : 1.3184342751414824e+38
---------------------------------------------------------------------------
Exception                                 Traceback (most recent call last)
......
Exception: Failed to converge

如果我增加最大迭代次数足够多,它就不会收敛,而是会爆炸到无穷大。

这里是否有实现错误,或者我误解了课堂笔记中的解释?

正如@Kant所建议的,eta需要在每次迭代时更新。课程本身有一些关于这方面的示例公式,但没有一个有助于收敛。维基百科页面关于梯度下降的这一部分提到Barzilai Borwein方法是更新eta的好方法。我实现了它,并在每次迭代时修改代码以更新eta,回归成功收敛。下面是我将维基百科版本的公式翻译成回归中使用的变量,以及实现它的代码。同样,在我原来的ols\u gradient\u descent循环中调用此代码来更新eta

def eta_t (w_t, w_t_minus_1, grad_t, grad_t_minus_1):
    delta_w = w_t - w_t_minus_1
    delta_grad = grad_t - grad_t_minus_1

    eta_t = (delta_w.T @ delta_grad) / (LA.norm(delta_grad))**2

    return eta_t

共有1个答案

唐煜
2023-03-14

尝试降低eta的值。如果eta太高,梯度下降可能会发散。

 类似资料:
  • 我试图实现梯度下降的线性回归,如本文(https://towardsdatascience.com/linear-regression-using-gradient-descent-97a6c8700931)所述。我已经严格遵循了实现,但是经过几次迭代后,我的结果会溢出。我试图得到这个结果大约: y=-0.02x 8499.6。 代码: 在这里,它可以在围棋场上工作:https://play.go

  • 我试图在java中实现线性回归。我的假设是θ0θ1*x[i]。我试图计算θ0和θ1的值,使成本函数最小。我正在用梯度下降来找出值- 在 在收敛之前,这种重复是什么?我知道这是局部最小值,但我应该在while循环中输入的确切代码是什么? 我对机器学习非常陌生,刚开始编写基本的算法以获得更好的理解。任何帮助都将不胜感激。

  • 我试图在MatLab中实现一个函数,该函数使用牛顿法计算最佳线性回归。然而,我陷入了一个问题。我不知道如何求二阶导数。所以我不能实施它。这是我的密码。 谢谢你的帮助。 编辑:: 我用一些纸和笔解决了这个问题。你所需要的只是一些微积分和矩阵运算。我找到了二阶导数,它现在正在工作。我正在为感兴趣的人分享我的工作代码。

  • 我用JavaScript实现了一个非常简单的线性回归和梯度下降算法,但是在查阅了多个源代码并尝试了几件事情之后,我无法使它收敛。 数据是绝对线性的,只是数字0到30作为输入,x*3作为正确的输出来学习。 这就是梯度下降背后的逻辑: 我从不同的地方取了公式,包括: 乌达城深度学习基金会纳米学位的练习 吴恩达的线性回归梯度下降课程(也在这里) 斯坦福CS229讲义 我从卡内基梅隆大学找到的其他PDF幻

  • 好的,那么这个算法到底意味着什么呢? 据我所知: i) 阿尔法:梯度下降的步骤有多大。 ii)现在,∑{hTheta[x(i)]-y(i)}:指给定θ值的总误差。 误差是指预测值{hTheta[x(i)]}与实际值之间的差值。[y(i)] σ{hTheta[x(i)]-y(i)}给出了所有训练示例中所有误差的总和。 结尾的Xj^(i)代表什么? 在为多元线性回归实现梯度下降时,我们是否在执行以下操

  • 在机器学习课程https://share.coursera.org/wiki/index.php/ML:Linear_Regression_with_Multiple_Variables#Gradient_Descent_for_Multiple_Variables中,它说梯度下降应该收敛。 我正在使用scikit学习的线性回归。它不提供梯度下降信息。我已经看到了许多关于stackoverflow