wtBPDN_adaptive_function

PURPOSE

wtBPDN_adaptive_function

SYNOPSIS

function out = wtBPDN_adaptive_function(A, y, in)

DESCRIPTION

 wtBPDN_adaptive_function

 Solves the following basis pursuit denoising (BPDN) problem
 min_x  \sum_i \tau e_i |x_i| + 1/2*||y-Ax||_2^2
 where at every homotopy step the e_i corresponding to the incoming element
 is reduced to a very small value (e.g., 1e-6). This way active and
 inactive indices have different values of the regularization parameter.
 The active elements have smaller weight, so they can stay active as long
 as there is no change in sign. The inactive elements have higher weight,
 so they are pushed towards zero. The hope is that, in this adaptive
 procedure, only the true elements will become active and stay active
 along the homotopy path.
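The sketch below makes the effect of the weights concrete. It uses the special case A = eye(n), an assumption chosen only because it gives a closed form. With that choice the weighted problem decouples, and its minimizer is elementwise soft thresholding with thresholds w_i = tau*e_i. Two measurements of equal magnitude (0.5 and -0.5) end up treated differently: the one whose e_i has been reduced to 1e-6 stays active, and the other is set to zero. The last lines check the optimality conditions that the homotopy maintains, with r = y - A*x. On the support, A'*r equals w_i*sign(x_i); off the support, |A'*r| is bounded by w_i. The variable names are illustrative and do not come from the listing.

```matlab
% Weighted soft thresholding: closed-form weighted BPDN solution when A = I.
% Assumption: orthonormal A, so min_x sum_i w_i|x_i| + 1/2||y - x||^2 decouples.
y   = [3; 0.5; -0.5; 2];          % measurements (A = eye(4))
tau = 1;                          % nominal regularization parameter
e   = [1; 1e-6; 1; 1];            % e_2 reduced, as for an "active" index
w   = tau*e;                      % per-element weights w_i = tau*e_i

x = sign(y).*max(abs(y) - w, 0);  % soft thresholding with per-element thresholds

% Optimality conditions (A = I), with r = y - A*x:
%   active   (x_i ~= 0): r_i = w_i*sign(x_i)
%   inactive (x_i == 0): |r_i| <= w_i
r   = y - x;
act = x ~= 0;
kkt_active   = all(abs(r(act) - w(act).*sign(x(act))) < 1e-12);
kkt_inactive = all(abs(r(~act)) <= w(~act));
```

Here x comes out as [2; 0.5 - 1e-6; 0; 1]. Element 2 survives with its small weight, while element 3 has the same magnitude but weight tau and is pushed to zero.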

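 In homotopy, such an adaptive weight selection strategy can be included
 at every step without any additional cost, because changing the weights
 only changes the right-hand side of the update system on the active
 support, not the matrix that has to be factorized.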
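The following is a minimal sketch of why reweighting is cheap. It assumes a fixed support gamma and a small made-up matrix A. As in the 'qr' branch of the listing, the update direction on the support solves (A_gamma'*A_gamma)*delx = rhs(gamma), with rhs = (epsilon_vec - tau_vec).*z_x. A new set of weights therefore only produces a new right-hand side, and one economy QR factorization (A_gamma'*A_gamma = R'*R) serves both weight settings below. The listing maintains its factorization incrementally through update_delx, which is not shown here. The names eps_a and eps_b are hypothetical.

```matlab
% Sketch: one factorization of A_gamma serves any choice of weights.
A = [1 0 2 0 1; 0 1 1 2 0; 1 1 0 1 2; 2 0 1 1 1];
N = size(A, 2);
gamma = [1; 3];                          % assumed active support
z_x = zeros(N,1); z_x(gamma) = [1; -1];  % signs on the support
tau_vec = 0.1*ones(N,1);                 % target weights

Ag = A(:, gamma);
[~, R] = qr(Ag, 0);                      % economy QR, so Ag'*Ag = R'*R

eps_a = ones(N,1);                       % first set of running weights
eps_b = eps_a; eps_b(gamma) = [0.5; 0.2];  % reweighted active entries

rhs_a = (eps_a - tau_vec).*z_x;
rhs_b = (eps_b - tau_vec).*z_x;
delx_a = R\(R'\rhs_a(gamma));            % same solve pattern as the 'qr' branch
delx_b = R\(R'\rhs_b(gamma));            % only the right-hand side changed

ref_a = (Ag'*Ag)\rhs_a(gamma);           % direct solves, for comparison only
ref_b = (Ag'*Ag)\rhs_b(gamma);
```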

 Inputs:
 A - m x n measurement matrix
 y - measurement vector
 in - input structure
   tau - final value of the regularization parameter

   shrinkage_mode - (fixed or adaptive weight selection)
       (for more details, see shrinkage_update.m)
           'Tsteps':   active weights are reduced to tau/ewt
           'frac':     active weights are divided by ewt
           'Trwt':     active weights updated as w_i = tau/(beta*x_i)
           'rwt':      ...
           'OLS':      active weights set from the LS solution on the active
                       support, e.g., w_gamma = 1./abs((A_gamma'*A_gamma)^-1*A_gamma'*y);

           The name "Tsteps" reflects the observation that 'Tsteps' with a
           large value of ewt often yields the solution in T steps (T being
           the number of nonzero elements in the signal).
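The 'OLS' rule can be sketched in a few lines, assuming full column rank on the active support. The backslash operator gives the same least-squares solution as (A_gamma'*A_gamma)^-1*A_gamma'*y without forming an explicit inverse. With noiseless data built from a known x_gamma, the weights come out as 1./abs(x_gamma), so large coefficients receive small weights. This illustrates the stated formula only; shrinkage_update.m may scale or clip these weights differently.

```matlab
% Sketch of the 'OLS' weight rule on an active support gamma.
A = [1 0 2 0 1; 0 1 1 2 0; 1 1 0 1 2; 2 0 1 1 1];
gamma = [1; 3];                        % assumed active support
x_gamma = [2; -0.5];                   % made-up coefficients on gamma
Ag = A(:, gamma);
y = Ag*x_gamma;                        % noiseless measurements

w_gamma = 1./abs(Ag\y);                % LS solution via backslash, then invert
w_formula = 1./abs((Ag'*Ag)\(Ag'*y));  % normal-equations form of the same rule
```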

   ewt - weight factor for the active elements (use a large value)
           ewt controls the tradeoff between speed and accuracy:
           a higher ewt is quicker but potentially unstable.
           ewt = 1 with shrinkage_mode = 'Tsteps' and shrinkage_flag = 2
           solves the standard LASSO homotopy path.
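The following hypothetical illustration shows how ewt sets the target weights under the 'Tsteps' rule described above. Active weights are reduced to tau/ewt, and inactive weights stay at tau. It is written from the description, not from shrinkage_update.m, and the support gamma is made up. With ewt = 1 every target equals tau, which is why that setting reduces to the uniform weights of the standard LASSO path.

```matlab
% Hypothetical illustration of 'Tsteps' target weights (not shrinkage_update.m).
N = 6;
gamma = [2; 5];                  % assumed active support
tau = 0.1;

ewt = 1e3;                       % large ewt: active targets pushed far below tau
tau_vec = tau*ones(N,1);         % inactive targets stay at tau
tau_vec(gamma) = tau/ewt;        % active targets reduced to tau/ewt

ewt_lasso = 1;                   % ewt = 1: every target equals tau
tau_vec_lasso = tau*ones(N,1);
tau_vec_lasso(gamma) = tau/ewt_lasso;
```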

   shrinkage_flag - 0, 1, 2, or 3 (default is 0)
           0 - Instantly set active constraints to their "desired"
               values as long as this does not disturb the active set.
               If any constraint is violated, handle it by resetting
               the running value of epsilon.
               FAST BUT POTENTIALLY UNSTABLE
           1 - Use the step size at which a member of the inactive set
               violates its constraint
           2 - Gradually change the active constraints; the step size
               accounts for both sign changes of elements in the active set
               and constraint violations by members of the inactive set
               (as in standard homotopy)
           3 - FLASH variation???
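The step-size rules for shrinkage_flag = 1 and 2 can be sketched as follows. Along a step delta, the listing moves p = A'*(A*x - y) as pk + delta*dk and the running weights as ak + delta*bk, where ak = epsilon_vec and bk = tau_vec - epsilon_vec (the fields it passes to compute_delta). An inactive index violates its constraint when |pk_i + delta*dk_i| reaches ak_i + delta*bk_i. An active index leaves when x_i + delta*delx_i changes sign. The sketch takes the smallest positive candidate of each kind and then the smaller of the two, which corresponds to option 2; using only the inactive-set candidates corresponds roughly to option 1. It is a hypothetical stand-in, not compute_delta.m. The numbers are made up so that index 4 enters at delta = 2/3. The flag convention (1 for an incoming element, 0 for an outgoing one) follows how the listing uses flag.

```matlab
% Hypothetical sketch of a homotopy step size (not compute_delta.m).
% Along a step delta: p(delta) = pk + delta*dk, weights w(delta) = ak + delta*bk.
pk   = [-1.0; 0.3; -0.2; 0.6];    % current A'*(A*x - y)
dk   = [ 0.5; 0.4; -0.6; 0.1];    % change of pk per unit step
ak   = [ 1.0; 1.0;  1.0; 1.0];    % running weights (epsilon_vec)
bk   = [-0.5; -0.5; -0.5; -0.5];  % tau_vec - epsilon_vec
x    = [ 2.0; 0; 0; 0];           % current solution
delx = [-1.0; 0; 0; 0];           % update direction (nonzero on the support)
gamma   = 1;                      % active set
gamma_c = [2; 3; 4];              % inactive set

% Inactive set: first delta > 0 at which |p_i(delta)| reaches w_i(delta)
d_up = (ak(gamma_c) - pk(gamma_c)) ./ (dk(gamma_c) - bk(gamma_c));   % p_i = +w_i
d_lo = -(ak(gamma_c) + pk(gamma_c)) ./ (dk(gamma_c) + bk(gamma_c));  % p_i = -w_i
cand = [d_up d_lo];
cand(cand <= 0 | ~isfinite(cand)) = inf;  % discard non-positive or undefined steps
[delta_in, j] = min(min(cand, [], 2));
idelta_in = gamma_c(j);

% Active set: first delta > 0 at which x_i + delta*delx_i crosses zero
d_out = -x(gamma) ./ delx(gamma);
d_out(d_out <= 0 | ~isfinite(d_out)) = inf;
[delta_out, j] = min(d_out);
idelta_out = gamma(j);

if delta_in <= delta_out
    delta = delta_in;  idelta = idelta_in;  flag = 1;  % inactive index enters
else
    delta = delta_out; idelta = idelta_out; flag = 0;  % active index leaves
end
if delta > 1
    delta = 1; flag = 1;          % weights reach their targets; clipped as in the listing
end

tau_vec = ak + bk;                       % target weights
eps_new = (1-delta)*ak + delta*tau_vec;  % same weight update as the listing
```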

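   maxiter - maximum number of homotopy iterations
   Te - maximum support size allowed
   delx_mode - direction update method: 'mil' (matrix inversion lemma)
               or 'qr' (QR factorization); required
   omp - compare results with OMP
   record - record iteration history (requires x_orig)
   x_orig - original signal for the error history
   debias - debias the solution at the end
   early_terminate - terminate early if the support is identified
                   (useful only in high SNR settings)
   plots - show debug plots
   plot_wts - plot the evolution of the weights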
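The sketch below assembles the input structure. The actual call needs the helper scripts (shrinkage_update, update_delx, update_supp, compute_delta) on the path, so it is left commented out. The listing reads tau, ewt, shrinkage_mode, maxiter and delx_mode without an isfield check, so those fields must be present; the rest fall back to defaults. The values, including tau as a fraction of max(abs(A'*y)), are assumptions for illustration and not recommended settings.

```matlab
% Sketch: building the input structure (example values are assumptions).
A = [1 0 2 0 1; 0 1 1 2 0; 1 1 0 1 2; 2 0 1 1 1];
y = A*[0; 0; 1.5; 0; 0];

in = struct();
in.tau = 0.01*max(abs(A'*y));  % assumed heuristic for the final tau
in.ewt = 1e3;                  % large weight factor for active elements
in.shrinkage_mode = 'Tsteps';
in.shrinkage_flag = 2;         % optional, default 0
in.maxiter = 100;
in.delx_mode = 'qr';           % 'mil' or 'qr'
in.Te = size(A,1);             % optional cap on the support size
in.record = 0;                 % if 1, in.x_orig must also be set
in.debias = 1;                 % optional, default 0

% Fields the listing reads without an isfield check
required = {'tau', 'ewt', 'shrinkage_mode', 'maxiter', 'delx_mode'};
missing = required(~isfield(in, required));

% out = wtBPDN_adaptive_function(A, y, in);  % requires the helper scripts
```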

 Outputs:
 out - output structure
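   x_out - BPDN solution (debiased if debias = 1)
   gamma - support of the solution
   iter - number of homotopy iterations taken by the solver
   time - CPU time taken by the solver
   error_table - iteration record with columns [epsilon, ||x - x_orig||_2,
                 support size] (only if record = 1)
   tau_vec - final values of the running weights (epsilon_vec)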
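Given the returned support out.gamma, debiasing is a least-squares fit restricted to that support. The listing's early_terminate test then compares the resulting residual norm with tau. The sketch reproduces both steps, using the backslash operator in place of the listing's iAtA or R factors. A, x_true, the noise vector and gamma are made-up values.

```matlab
% Sketch: debiasing on a given support and the early_terminate residual test.
A = [1 0 2 0 1; 0 1 1 2 0; 1 1 0 1 2; 2 0 1 1 1];
x_true = [0; 2; 0; -1; 0];
noise  = [0.01; -0.02; 0; 0.01];  % made-up small perturbation
y = A*x_true + noise;
tau = 0.1;                        % assumed final regularization parameter
gamma = [2; 4];                   % support, as returned in out.gamma

x_deb = zeros(size(A,2), 1);
x_deb(gamma) = A(:,gamma)\y;      % least-squares fit restricted to gamma
res = norm(y - A*x_deb);          % residual norm after debiasing
stop_early = res < tau;           % the listing's early-termination test
```

Because x_true restricted to gamma is itself a feasible fit, the debiased residual can be no larger than norm(noise). That bound is about 0.025 here, so the test passes.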

 Written by: Salman Asif, Georgia Tech
 Email: sasif@gatech.edu

-------------------------------------------+
 Copyright (c) 2012.  M. Salman Asif
-------------------------------------------+

SOURCE CODE

% wtBPDN_adaptive_function
%
% Solves the following basis pursuit denoising (BPDN) problem
% min_x  \sum_i \tau e_i |x_i| + 1/2*||y-Ax||_2^2
% where at every homotopy step the e_i corresponding to the incoming element
% is reduced to a very small value (e.g., 1e-6). This way active and
% inactive indices have different values of the regularization parameter.
% The active elements have smaller weight, so they can stay active as long
% as there is no change in sign. The inactive elements have higher weight,
% so they are pushed towards zero. The hope is that, in this adaptive
% procedure, only the true elements will become active and stay active
% along the homotopy path.
%
% In homotopy, such an adaptive weight selection strategy can be included
% at every step without any additional cost, because changing the weights
% only changes the right-hand side of the update system on the active
% support, not the matrix that has to be factorized.
%
% Inputs:
% A - m x n measurement matrix
% y - measurement vector
% in - input structure
%   tau - final value of the regularization parameter
%
%   shrinkage_mode - (fixed or adaptive weight selection)
%       (for more details, see shrinkage_update.m)
%           'Tsteps':   active weights are reduced to tau/ewt
%           'frac':     active weights are divided by ewt
%           'Trwt':     active weights updated as w_i = tau/(beta*x_i)
%           'rwt':      ...
%           'OLS':      active weights set from the LS solution on the active
%                       support, e.g., w_gamma = 1./abs((A_gamma'*A_gamma)^-1*A_gamma'*y);
%
%           The name "Tsteps" reflects the observation that 'Tsteps' with a
%           large value of ewt often yields the solution in T steps (T being
%           the number of nonzero elements in the signal).
%
%   ewt - weight factor for the active elements (use a large value)
%           ewt controls the tradeoff between speed and accuracy:
%           a higher ewt is quicker but potentially unstable.
%           ewt = 1 with shrinkage_mode = 'Tsteps' and shrinkage_flag = 2
%           solves the standard LASSO homotopy path.
%
%   shrinkage_flag - 0, 1, 2, or 3 (default is 0)
%           0 - Instantly set active constraints to their "desired"
%               values as long as this does not disturb the active set.
%               If any constraint is violated, handle it by resetting
%               the running value of epsilon.
%               FAST BUT POTENTIALLY UNSTABLE
%           1 - Use the step size at which a member of the inactive set
%               violates its constraint
%           2 - Gradually change the active constraints; the step size
%               accounts for both sign changes of elements in the active set
%               and constraint violations by members of the inactive set
%               (as in standard homotopy)
%           3 - FLASH variation???
%
%   maxiter - maximum number of homotopy iterations
%   Te - maximum support size allowed
%   delx_mode - direction update method: 'mil' (matrix inversion lemma)
%               or 'qr' (QR factorization); required
%   omp - compare results with OMP
%   record - record iteration history (requires x_orig)
%   x_orig - original signal for the error history
%   debias - debias the solution at the end
%   early_terminate - terminate early if the support is identified
%                   (useful only in high SNR settings)
%   plots - show debug plots
%   plot_wts - plot the evolution of the weights
%
% Outputs:
% out - output structure
%   x_out - BPDN solution (debiased if debias = 1)
%   gamma - support of the solution
%   iter - number of homotopy iterations taken by the solver
%   time - CPU time taken by the solver
%   error_table - iteration record with columns [epsilon, ||x - x_orig||_2,
%                 support size] (only if record = 1)
%   tau_vec - final values of the running weights (epsilon_vec)
%
% Written by: Salman Asif, Georgia Tech
% Email: sasif@gatech.edu
%
%-------------------------------------------+
% Copyright (c) 2012.  M. Salman Asif
%-------------------------------------------+

function out = wtBPDN_adaptive_function(A, y, in)

N = size(A,2);
M = size(A,1);

% Regularization parameters
tau = in.tau;
ewt = in.ewt;
shrinkage_mode = in.shrinkage_mode;
shrinkage_flag = 0;
if isfield(in,'shrinkage_flag')
    shrinkage_flag = in.shrinkage_flag;
end

maxiter = in.maxiter;
Te = inf;
if isfield(in,'Te')
    Te = in.Te;
end
err_record = 0;
if isfield(in,'record')
    err_record = in.record;
    if err_record
        x_orig = in.x_orig;
    end
end
omp = 0; % compare results with OMP
if isfield(in,'omp')
    omp = in.omp;
end
plots = 0; % debug plots
if isfield(in,'plots')
    plots = in.plots;
end
plot_wts = 0; % plot evolution of weights
if isfield(in,'plot_wts')
    plot_wts = in.plot_wts;
end
debias = 0;
if isfield(in,'debias')
    debias = in.debias;
end
early_terminate = 0;
if isfield(in,'early_terminate')
    early_terminate = in.early_terminate;
end

t0 = cputime;

%% Phase I (support selection)
% Initial step
z_x = zeros(N,1);
pk_old = -A'*y;
[c, idelta] = max(abs(pk_old));

gamma_xh = idelta;
temp_gamma = zeros(N,1);
temp_gamma(gamma_xh) = gamma_xh;
gamma_xc = find([1:N]' ~= temp_gamma);

z_x(gamma_xh) = -sign(pk_old(gamma_xh));
epsilon = c;
pk_old(gamma_xh) = sign(pk_old(gamma_xh))*epsilon;
xk_1 = zeros(N,1);

%% loop parameters
done = 0;
iter = 0;
rwt_step2 = 0;

gamma_omp = gamma_xh;

error_table = [];
if err_record
    error_table = [epsilon norm(xk_1-x_orig) 1];
end


%% (selective support shrinkage)
epsilon_old = epsilon;
Supp_ledger = zeros(N,1);
Supp_ledger(idelta) = 1;

tau_vec = ones(N,1)*tau; % Final values of the regularization parameters
epsilon_vec = ones(N,1)*epsilon; % Running values of the regularization parameters

% Update target weights
shrinkage_update

% initialize delx
in_delx = [];
delx_mode = in.delx_mode;
rhs = (epsilon_vec-tau_vec).*z_x;
update_mode = 'init0';
update_delx;
while iter < maxiter
    iter = iter+1;
    % warning('off','MATLAB:divideByZero')
    %% OMP comparison
    if omp && iter < M
        x_omp = zeros(N,1);
        x_omp(gamma_omp) = A(:,gamma_omp)\y;
        p_omp = A'*(y-A*x_omp);
        gamma_ompC = setdiff([1:N],gamma_omp);
        [val_omp, ind_omp] = max(abs(p_omp(gamma_ompC)));
        gamma_omp = [gamma_omp; gamma_ompC(ind_omp)];
    end

    %% Homotopy
    x_k = xk_1;

    % Update direction
    delx_vec = zeros(N,1);
    delx_vec(gamma_xh) = delx;

    if ~isempty(idelta) && (sign(delx_vec(idelta)) == sign(pk_old(idelta)) && abs(x_k(idelta)) == 0)
        delta = 0; flag = 0;
    else
        pk = pk_old;
        % dk = AtA*delx_vec;
        dk_temp = A*delx_vec;
        dk = A'*dk_temp;

        %%%--- compute step size
        in = [];

        % Setting shrinkage_flag to zero shrinks a new active constraint towards
        % its final value instantly if doing so doesn't disturb the active set
        in.shrinkage_flag = shrinkage_flag;
        in.pk = pk; in.dk = dk;
        in.ak = epsilon_vec; in.bk = tau_vec-epsilon_vec;
        in.gamma = gamma_xh; in.gamma_c = gamma_xc;
        in.delx_vec = delx_vec; in.x = xk_1;
        out = compute_delta(in);
        delta = out.delta; idelta = out.idelta;
        flag = out.flag;

        %% FLASH stepsize selection???
        if shrinkage_flag == 3 && flag == 1
            delta_l = 0.5;
            delta_avg = out.delta_in*(1-delta_l)+delta_l; % Select step size as a convex combination of forward selection (FS) and LASSO...
            if delta_avg < out.delta_out
                delta = delta_avg;
                idelta = out.idelta_in;
            else
                delta = out.delta_out;
                idelta = out.idelta_out;
                flag = 0;
            end
        end

        if delta > 1
            delta = 1;
            flag = 1;
        end

        xk_1(gamma_xh) = x_k(gamma_xh)+delta*delx_vec(gamma_xh);
        pk_old = pk+delta*dk;

        epsilon_vec_old = epsilon_vec;
        % epsilon_vec(gamma_xh) = (1-delta)*epsilon_vec(gamma_xh)+delta*tau_vec(gamma_xh);
        epsilon_vec = (1-delta)*epsilon_vec+delta*tau_vec;

        pk_old(gamma_xh) = sign(pk_old(gamma_xh)).*epsilon_vec(gamma_xh);


%         fig(333); plot(abs([A'*(A(:,gamma_xh)*x_orig(gamma_xh)-y) A'*(A(:,gamma_xh)*xk_1(gamma_xh)-y) tau_vec epsilon_vec]));
%         pause;

        % Check convergence criterion (this can be useful)...
        if early_terminate
            if length(gamma_xh) < M/2
                xhat = zeros(N,1);
                % xhat(gamma_xh) = AtAgx\(A(:,gamma_xh)'*y);
                switch delx_mode
                    case 'mil'
                        xhat(gamma_xh) = iAtA*(A(:,gamma_xh)'*y);
                    case 'qr'
                        xhat(gamma_xh) = R\(R'\(A(:,gamma_xh)'*y));
                end
                if norm(y-A*xhat) < tau
                    xk_1 = xhat;
                    break;
                end
            end
        end

        if max(abs(pk_old)) <= tau
            % if you want to solve exactly according to tau, uncomment the
            % following lines:
            %
            % delta_end = epsilon_old-tau;
            % xk_1(gamma_xh) = x_k(gamma_xh)+delta_end*delx_vec(gamma_xh);
            % pk_old = pk+delta_end*dk;
            % epsilon_vec = epsilon_vec_old;
            % epsilon_vec(gamma_xh) = (1-delta_end)*epsilon_vec(gamma_xh)+delta_end*tau_vec(gamma_xh);

            % disp('epsilon reduce below threshold');
            % fig(303); plot([xk_1(gamma_xh)-(A(:,gamma_xh)'*A(:,gamma_xh))\(A(:,gamma_xh)'*y-epsilon_vec(gamma_xh).*z_x(gamma_xh))])
            if flag == 0
                outx_index = find(gamma_xh==idelta);
                gamma_xh = [gamma_xh(1:outx_index-1); gamma_xh(outx_index+1:end)];

                xk_1(idelta) = 0;
                epsilon_vec(idelta) = epsilon;
            end
            break;
        end
        if length(gamma_xh) >= Te
            total_time = cputime-t0;
            % disp('support size exceeds limit');
            % fig(303); plot([xk_1(gamma_xh) (A(:,gamma_xh)'*A(:,gamma_xh))\(A(:,gamma_xh)'*y-tau*z_x(gamma_xh)) x_orig(gamma_xh)])
            % setxor(gamma_xh,find(abs(x_orig)>0))
            break;
        end

        %% Search for new element
        % The one that violates the constraint
        epsilon_old = epsilon;
        [epsilon, index] = max(abs(pk_old));
        if flag == 1 && delta == 1
            if nnz(index == gamma_xh)
                % The index already exists in the active set: only update
                % the weights and the direction, then continue.
                % iter = iter-1;
                shrinkage_update;
                z_x = -sign(pk_old);
                rhs = (epsilon_vec-tau_vec).*z_x;

                switch delx_mode
                    case 'mil'
                        delx = iAtA*rhs(gamma_xh);
                    case 'qr'
                        delx = R\(R'\rhs(gamma_xh));
                end
                continue;
            end
            idelta = index;
            epsilon_vec(idelta) = epsilon;
        elseif flag == 1
            epsilon_vec(idelta) = abs(pk_old(idelta)); % (1-delta)*epsilon_vec(idelta)+delta*tau_vec(idelta);
        end
    end

    if err_record
        error_table = [error_table; epsilon norm(xk_1-x_orig) length(gamma_xh)];
    end


    % update support
    update_supp;

    temp_gamma = zeros(N,1);
    temp_gamma(gamma_xh) = gamma_xh;
    gamma_xc = find([1:N]' ~= temp_gamma);
    epsilon_vec(gamma_xc) = epsilon;
    % epsilon_vec(gamma_xc) = max(abs(A'*y));

    if flag == 0
        Supp_ledger(idelta) = 0;
    else
        Supp_ledger(gamma_xh) = Supp_ledger(gamma_xh)+1;
    end

    %% Shrinkage parameters selection
    shrinkage_update;

    % update delx
    z_x = -sign(pk_old);
    rhs = (epsilon_vec-tau_vec).*z_x;
    in_delx.max_rec = 1;
    update_mode = 'update';
    update_delx;
    %     AtAgx = A(:,gamma_xh)'*A(:,gamma_xh);
    %     delx2 =  AtAgx\rhs(gamma_xh);% AtAgx\((epsilon_vec(gamma_xh)-tau_vec(gamma_xh)).*z_x(gamma_xh));
    %     fig(111); plot([delx delx2]);
    %     if norm(delx-delx2) > 1e-5
    %         stp = 1;
    %     end

    %% debug...
    if plots
        fig(101); plot([A'*(A*xk_1-y) pk_old epsilon_vec -epsilon_vec]);
        fig(102); clf; hold on;
        subplot(311); plot([abs(pk_old) epsilon_vec epsilon_vec_old]);
        title(sprintf('iter %d',iter));
        subplot(312); plot(delx_vec);
        subplot(313); hold on; stem(x_orig,'Marker','.'); plot([x_k xk_1]);
        pause(1/60);

        % fig(303); plot([xk_1(gamma_xh)-(A(:,gamma_xh)'*A(:,gamma_xh))\(A(:,gamma_xh)'*y-epsilon_vec(gamma_xh).*z_x(gamma_xh))])
        % [pk(gamma_xh) x_k(gamma_xh) pk_old(gamma_xh) xk_1(gamma_xh) x_orig(gamma_xh) gamma_xh]
        if (max(abs(A'*(A*xk_1-y))-abs(pk_old)) > 1e-8)
            disp('constraints mismatch...')
        end
    end
    constr_violation = nnz((abs(pk_old(gamma_xc))-epsilon_vec(gamma_xc))>1e-10);
    sign_violation = nnz(abs(sign(pk_old(gamma_xh))+sign(xk_1(gamma_xh)))>1);
    if constr_violation
        chk = gamma_xc((abs(pk_old(gamma_xc))-epsilon_vec(gamma_xc))>1e-10);
        stp = 1;
        fprintf('problem... with constraint violation -- %s\n', mfilename);
        fprintf('Refactorize the matrix... recompute delx \n');
        % sometimes it comes here due to bad conditioning of AtAgx.
        update_mode = 'init0';
        update_delx;
    end
    if sign_violation>1
        chk = gamma_xh(abs(sign(pk_old(gamma_xh))+sign(xk_1(gamma_xh)))>1);
        stp = 1;
        fprintf('problem... sign mismatch -- %s\n',mfilename);
        fprintf('Refactorize the matrix... recompute delx \n');
        update_mode = 'init0';
        update_delx;
    end

    %% Figure to view evolution of weights
    if plot_wts
        if mod(iter,10)==1
            if ~exist('marker_iter','var')
                weight_marker = {'b','r','k','m','g'};
                fig(121); clf;
                set(gca,'FontSize',16);
                semilogy(1:N,epsilon_vec([gamma_xh; gamma_xc]),'Color',weight_marker{1},'LineWidth',2);
                marker_iter = 1;
                hold on;
                EPS_VEC = [];
            else
                marker_iter = mod(marker_iter,length(weight_marker))+1;
                fig(121); semilogy(1:N,epsilon_vec([gamma_xh; gamma_xc]),'Color',weight_marker{marker_iter},'LineWidth',2);
            end
            axis tight;
            YLim = get(gca,'YLim');
            YLim1 = YLim;
            YLim1(1) = YLim(1)-(YLim(2)-YLim(1))*0.1;
            YLim1(2) = YLim(2)+(YLim(2)-YLim(1))*0.3;
            set(gca,'YLim',YLim1);
            str1 = sprintf('step %d',iter);
            % text(iter+5,epsilon*1.2, str1,'FontSize',16);
            EPS_VEC = [EPS_VEC epsilon_vec([gamma_xh; gamma_xc])];
        end
    end
end

if debias
    x_out = zeros(N,1);
    switch delx_mode
        case 'mil'
            x_out(gamma_xh) = iAtA*(A(:,gamma_xh)'*y);
        case 'qr'
            x_out(gamma_xh) = R\(R'\(A(:,gamma_xh)'*y));
    end
else
    x_out = xk_1;
end

if err_record
    error_table = [error_table; epsilon norm(x_out-x_orig) length(gamma_xh)];
end
total_iter = iter;
total_time = cputime-t0;

out = [];
out.x_out = x_out;
out.gamma = gamma_xh; % find(abs(xk_1)>0);
out.iter = total_iter;
out.time = total_time;
out.error_table = error_table;
out.tau_vec = epsilon_vec;

% if ~isempty(setxor(gamma_xh,find(abs(x_orig)>0))) || iter > nnz(x_orig)
%     stp = 1;
% end
% fig(303); plot([xk_1(gamma_xh)-(A(:,gamma_xh)'*A(:,gamma_xh))\(A(:,gamma_xh)'*y-tau_vec(gamma_xh).*z_x(gamma_xh))])
% fig(303); plot([xk_1(gamma_xh) (A(:,gamma_xh)'*A(:,gamma_xh))\(A(:,gamma_xh)'*y-tau_vec(gamma_xh).*z_x(gamma_xh)) x_orig(gamma_xh)])
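The OMP comparison that the loop performs when in.omp is set can also be run on its own. The sketch below uses a made-up A and a 2-sparse x_true. It applies the same selection rule as the listing, choosing the unselected column with the largest correlation with the residual, and follows each selection with a least-squares refit. It assumes the sparsity K is known. The listing, by contrast, simply takes one OMP step per homotopy iteration while iter < M.

```matlab
% Standalone sketch of the OMP selection step used for comparison.
A = [1 0 2 0 1; 0 1 1 2 0; 1 1 0 1 2; 2 0 1 1 1];
N = size(A, 2);
x_true = [1; 0; 0; 1; 0];          % made-up 2-sparse signal
y = A*x_true;
K = 2;                             % number of OMP steps (assumed known)

gamma_omp = [];
x_omp = zeros(N,1);
for k = 1:K
    p_omp = A'*(y - A*x_omp);                 % correlations with the residual
    gamma_ompC = setdiff(1:N, gamma_omp);     % columns not yet selected
    [~, ind_omp] = max(abs(p_omp(gamma_ompC)));
    gamma_omp = [gamma_omp; gamma_ompC(ind_omp)];
    x_omp = zeros(N,1);
    x_omp(gamma_omp) = A(:,gamma_omp)\y;      % LS refit on the selected support
end
```

On this data, the first step picks column 1. Columns 1 and 4 tie at a correlation of 9, and max returns the first occurrence. The second step picks column 4, after which the least-squares refit recovers x_true exactly.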

Generated on Mon 10-Jun-2013 23:03:23 by m2html © 2005