当前位置: 首页 > 工具软件 > Lars > 使用案例 >

Matlab实现Lasso-Lars

穆理
2023-12-01

我的第一篇博客哟(沧海一声笑)
基础知识传送门:https://zhuanlan.zhihu.com/p/46999826
参考博文传送门:http://blog.sina.com.cn/s/blog_4a03c0100101b3c3.html。
我的工作:对代码进行了一些些修改,修正了错误部分,以及去掉了无用的部分。
本来Matlab就不是很熟,算法更是一头雾水(嘤~)花掉了姐妹儿两天的时间,肩周炎都犯了(大哭)
注释都写上啦,接下来更新匹配追踪(OMP)。
废话不多说,上代码==》
------------更新说明------------
该算法似乎不太准确,直接使用matlab的lasso会更好。有时间更新。

%主文件
clc;
clear;
y=[7.6030 -4.1781 3.1123 1.0586 7.8053]';
A=[0.5377 -1.3077 -1.3499 -0.2050 0.6715 1.0347 0.8884;
    1.8339 -0.4336 3.3049 -0.1241 -1.2075 0.7269 -1.1471;
    -2.2588 0.3426 0.7254 1.4897 0.7172 -0.3034 -1.0689;
    0.8622 3.5784 -0.0631 1.4090 1.6302 0.2939 0.8095;
    0.3188 2.7694 0.7147 1.4172 0.4889 -0.7873 -2.9443];
beta1 = lars1(A, y, 'lasso', 0, 1, [], 1); |
%lars1函数
function beta = lars1(X, y, ~, ~, ~, ~, ~)
    [n,p] = size(X);%传感矩阵的[行数,列数]
    nvars = p; %y最多由p个向量决定
    maxk = 8*nvars; % 最多迭代次数
    beta = zeros(2*nvars, p);%beta为每次迭代得稀疏系数的预测
    mu = zeros(n, 1); % LARS向lsq解决方案移动时的当前“位置”,n*1的0向量(列)
    I = 1:p; % 非积极集,开始为所有向量下标(观测矩阵X的总列数),存储下标1、2、3、4、5、6、7
    A = []; % 积极集,存储下标(已选定向量)
    R = []; 
    lassocond = 0; % 出现负相关变量时,lassocond=1
    k = 0; % 当前迭代次数
    vars = 0; % 目前积极变量的个数。
    fprintf('Step\tAdded\tDropped\t\tActive set size\n');%命令窗口显示字,matlab2019版
    
     %并非所有变量都为积极集时,k未达到最大迭代次数进行以下循环
    while vars < nvars  && k < maxk 
      k = k + 1; 
      c = X'*(y - mu);% 各自变量与因变量的相关度。
      fprintf("相关度:");%列向量
      disp(c');
      [C,j] = max(abs(c(I)));%最相关向量,仅产生于非积极集,C为最大相关度,j为下标。此处并未去除负值
      %注意!j为(最相关变量在X中的列数)在I中的下标
      if ~lassocond 
        R = cholinsert(R,X(:,I(j)),X(:,A));
        A = [A I(j)]; %把最相关向量所对应下标(元素I(j))加入到积极集
        I(j) = []; %把进入积极集的向量下标从非积极集中踢出来。注意!元素直接左移,不留空位
        vars = vars + 1;  %目前的积极集中的变量个数+1
        b=size(A);
        %显示:迭代次数,积极集加入变量,总积极向量个数
        fprintf('%d\t\t%d\t\t\t\t\t%d\n', k, A(b(2)), vars);
      end
      
      %计算角平分线
      s = sign(c(A)); % 返回与 c(A)大小相同的数组 s;c(A)>0,s元素是1;c(A)=0,s元素是0;c(A)<0,s元素是-1; 
      GA1 = R\(R'\s);
      fprintf("\nGA1:");%列向量
      disp(GA1');
      AA = 1/sqrt(sum(GA1.*s));
      fprintf("AA:");
      disp(AA);
      w = AA*GA1;
      fprintf("w:");%列向量
      disp(w');%列向量
      u = X(:,A)*w; % 角平分线(单位矢量),n*1的列向量
      fprintf('角平分线 u = ');%列向量
      disp(u');
      
      %计算移动系数gama(伽马),计算在这个方向上走多远,步长
      if vars == nvars % 如果所有的变量都是积极集,则变成一个最小二乘问题。
        gamma = C/AA;  % C/A:相关度/AA
      else
        a = X'*u; % 各变量与角平分线的相关性
        temp = [(C - c(I))./(AA - a(I)); (C + c(I))./(AA + a(I))];%步长。必须选其中的正数。拼起来的列向量
        gamma = min([temp(temp > 0); C/AA]); %[temp中大于0的数字组成列向量,数字]
      end
      
      %找出其中的负相关变量
      lassocond = 0;
      temp = -beta(k,A)./w';
      [gamma_tilde] = min([temp(temp > 0) gamma]);%gamma_tilde:伽马+波浪线
      % fprintf('gamma_tilde = %d',gamma_tilde));
      %判断A中是否有负相关变量
      j = find(temp == gamma_tilde);%j(此处j指A中变量下标)有时有(积极集删除操作,转61行),有时没有
      if gamma_tilde < gamma %存在负相关变量
       gamma = gamma_tilde;%gama取最小
       lassocond = 1;   %lasso条件满足
      end
      fprintf('gamma = %d\n',gamma);
      mu = mu + gamma*u;  % mu:逼近Y 的向量
      fprintf('逼近Y 的向量 mu = ');%列向量
      disp(mu');
      beta(k+1,A) = beta(k,A) + gamma*w';%本次=上次beta中相关位置元素相应增加 gamma*w';beta(k,A)为beta的k行A列,A为列表
      fprintf('beta = ');%列向量
      disp(beta(k,:));
      % 删除负相关变量
      if lassocond == 1
        R = choldelete(R,j);%删除与拟合为反方向的列
        I = [I A(j)];%将该元素加入非积极集
        vars = vars - 1;%积极集减少一个元素
        A(j) = [];
        b=size(I);
        %展示:迭代次数,积极集删除的向量,积极变量个数
        fprintf('%d\t\t\t\t%d\t\t\t%d\n', k, I(b(2)), vars);
      end
      
      %计算残差
      error_pri=y-X*(beta(k,:))';             
      error=sum(error_pri.^2);
      fprintf('误差:');
      disp(error)
      if error < 1e-6
          break;  
      end
      
    end  %while结束
function R = cholinsert(R, x, X)
    diag_k = x'*x; % diagonal(对角线) element k in X'X matrix
    if isempty(R)
      R = sqrt(diag_k);
    else
      col_k = x'*X; % elements of column k in X'X matrix
      R_k = R'\col_k'; % R'R_k = (X'X)_k, solve for R_k
      R_kk = sqrt(diag_k - R_k'*R_k); % norm(x'x) = norm(R'*R), 排除查找最后一个元素
      R = [R R_k; [zeros(1,size(R,2)) R_kk]]; % update R
    end

function R = choldelete(R,j)
    R(:,j) = []; % remove column j
    n = size(R,2);
    for k = j:n
      p = k:k+1;
      [G,R(p,k)] = planerot(R(p,k)); %删除该列额外元素,吉文斯平面旋转
      if k < n
        R(p,k+1:n) = G*R(p,k+1:n); % adjust rest of row
      end
    end
    R(end,:) = []; % remove zero'ed out row


 类似资料: