script_rwtBPDN_adaptive

PURPOSE ^

BPDN rwt update (initialize with adaptive_wtBPDN)

SYNOPSIS ^

function out = script_rwtBPDN_adaptive(in)

DESCRIPTION ^

 BPDN rwt update (initialize with adaptive_wtBPDN)

 Solve the following basis pursuit denoising (BPDN) problem
 min_x  \sum_i w_i |x_i| + (1/2)*||y - A x||_2^2

 Initialize with adaptive weighted BPDN
 and dynamically update the weights w_i

 Written by: Salman Asif, Georgia Tech
 Email: sasif@ece.gatech.edu
 Created: April 16, 2011

CROSS-REFERENCE INFORMATION ^

This function calls: This function is called by:

SOURCE CODE ^

0001 % BPDN rwt update (initialize with adaptive_wtBPDN)
0002 %
0003 % Solve the following basis pursuit denoising (BPDN) problem
0004 % min_x  \sum_i w_i |x_i| + (1/2)*||y - A x||_2^2
0005 %
0006 % Initialize with adaptive weighted BPDN
0007 % and dynamically update the weights w_i
0008 %
0009 % Written by: Salman Asif, Georgia Tech
0010 % Email: sasif@ece.gatech.edu
0011 % Created: April 16, 2011
0012 
0013 function out = script_rwtBPDN_adaptive(in)
0014 % SCRIPT_RWTBPDN_ADAPTIVE  Reweighted BPDN via weighted homotopy updates.
0015 %
0016 % Fields read from the input struct IN:
0017 %   A, y      - system matrix and measurement vector
0018 %   x         - reference signal, used only for the err / supp_diff metrics
0019 %   x_init    - starting estimate; norm(x_init) == 0 requests a cold start
0020 %               through wtBPDN_adaptive_function
0021 %   tau       - regularization parameter
0022 %   max_rwt   - number of reweighting iterations
0023 %   rwt_mode  - forwarded to weight_param
0024 %   delx_mode - forwarded to the solvers; factors are precomputed here for
0025 %               'mil', 'qr' and 'qrM'
0026 %   debias    - forwarded to wtBPDN_adaptive_function
0027 %   verbose   - nonzero prints progress of the cold-start solve
0028 %
0029 % Fields written to OUT: x_out, x_init, iter, gamma, W_new, err, time,
0030 % supp_diff. iter, err, time and supp_diff hold one entry per solve.
0031 
0032 A = in.A; y = in.y; x = in.x;
0033 gamma_orig = find(x);
0034 [M, N] = size(A);
0035 tau = in.tau; max_rwt = in.max_rwt;
0036 x_init = in.x_init;
0037 
0038 %% parameter selection
0039 rwt_mode = in.rwt_mode;
0040 delx_mode = in.delx_mode;
0041 homotopy_update = 'v1';
0042 debias = in.debias;
0043 verbose = in.verbose;
0044 
0045 maxiter = 2*N;
0046 iter_ALL = [];
0047 err_ALL = [];
0048 time_ALL = [];
0049 supp_diff = [];
0050 
0051 if norm(x_init) == 0
0052     %% Adaptive support + weight selection
0053     % Fast shrinkage on the active set...???
0054     %
0055     % use a large value for Tsteps (~10-100)
0056     %     ewt = 10; shrinkage_mode = 'Tsteps'; shrinkage_flag = 0;
0057     %
0058     % use a large value for Trwt
0059     %     ewt = 100; shrinkage_mode = 'Trwt'; shrinkage_flag = 0; % Gaussian
0060     % use shrinkage_flag = 0 for Gaussian signal... (amplitude variations)
0061     % use shrinkage_flag = 2 for sign/ones signal... (flat)
0062     %
0063     % 'rwt' works better when nonzero components have diverse amplitudes
0064     % ewt = 2; shrinkage_mode = 'rwt'; shrinkage_flag = 0;
0065     ewt = 1; shrinkage_mode = 'Trwt'; shrinkage_flag = 0;
0066     % ewt = 2; shrinkage_mode = 'Tsteps'; shrinkage_flag = 3;
0067     % ewt = 1; shrinkage_mode = 'OLS'; shrinkage_flag = 0;
0068 
0069     if verbose
0070         fprintf(' ewt = %1.2g, shrinkage_mode: %s, ',ewt, shrinkage_mode);
0071     end
0072 
0073     in = struct();
0074     in.tau = tau;
0075     in.ewt = ewt; % Setting ewt = 1 with Tsteps solves unweighted LASSO
0076     in.shrinkage_flag = shrinkage_flag;
0077     in.shrinkage_mode = shrinkage_mode; % (either fixed or adaptive weighting)
0078     in.maxiter = maxiter;
0079     in.debias = debias;
0080     in.early_terminate = 0;
0081     in.x_orig = x;
0082     in.record = 1;
0083     in.omp = 0;
0084     in.plots = 0;
0085     in.plot_wts = 0;
0086     in.delx_mode = delx_mode;
0087     % in.Te = T;
0088     tic;
0089     out = wtBPDN_adaptive_function(A, y, in);
0090     time_ALL = [time_ALL toc];
0091     x_init = out.x_out;
0092     gamma_old = out.gamma;
0093     tau_vec = out.tau_vec;
0094     iter_ALL = [iter_ALL out.iter];
0095     err_ALL = [err_ALL norm(x-x_init)/norm(x)];
0096     supp_diff = [supp_diff length(setxor(gamma_old,gamma_orig))];
0097     if verbose
0098         fprintf('err = %3.4g, iter = %3.4g, time = %3.4g \n',err_ALL(end), iter_ALL(end), time_ALL(end));
0099     end
0100 else
0101     % Warm start: gamma_old and tau_vec used to be assigned only in the
0102     % cold-start branch, so a nonzero x_init aborted at their first use
0103     % below. Take the support from x_init and start from uniform weights.
0104     % NOTE(review): uniform weights assume x_init solves the unweighted
0105     % BPDN problem at this tau - confirm against callers.
0106     gamma_old = find(x_init);
0107     tau_vec = tau*ones(N,1);
0108 end
0109 W_new = tau_vec; % Cold start: weights chosen adaptively by the homotopy solver
0110 
0111 % THERE SHOULDN'T BE ANY NEED FOR REWEIGHTING HERE
0112 % But just in case...
0113 
0114 %% Iterative reweighting
0115 xh_mod = x_init;
0116 gamma_xh = gamma_old;
0117 for rwt_itr = 1:max_rwt
0118     gamma_old = gamma_xh;
0119     x_old = xh_mod;
0120     W_old = W_new;
0121 
0122     [alpha, beta, epsilon] = weight_param(rwt_mode,rwt_itr,x_old,M);
0123     W_new = tau/alpha./(beta*abs(x_old)+epsilon);
0124 
0125     % W_new = tau/epsilon*ones(N,1);
0126     % W_new(gamma_old) = min([tau*ones(length(gamma_old),1) tau./(beta*abs(x_old(gamma_old)))],[],2);
0127 
0128     switch homotopy_update
0129         case 'v1'
0130             % The following Gram matrix and its inverse can be used from the
0131             % previous homotopy. Too lazy to include that right now...
0132             % wt BPDN homotopy update
0133             pk_old = A'*(A*x_old-y);
0134             pk_old(gamma_old) = sign(pk_old(gamma_old)).*W_old(gamma_old);
0135             in = struct();
0136             in.x_old = x_old;
0137             in.gamma = gamma_old;
0138             in.pk_old = pk_old;
0139             in.W_old = W_old;
0140             in.W_new = W_new;
0141             in.maxiter = maxiter;
0142             in = attach_delx_factors(in, A(:,gamma_old), delx_mode);
0143 
0144             % The factorization above is not part of the timed solve.
0145             tic;
0146             out = wtBPDN_Update_function_v1(A, y, in);
0147             time_update = toc;
0148             xh_mod = out.x_out;
0149             gamma_xh = out.gamma;
0150             iter_update = out.iter;
0151         case 'v2'
0152             % Homotopy update v2 (weighting embedded in the matrix A)
0153             AW_old = A*diag(tau./W_old);
0154             u0_hat = x_old.*(W_old/tau);
0155             ds = AW_old'*(AW_old*u0_hat-y);
0156             AW = A*diag(tau./W_new);
0157             yhat = AW*u0_hat;
0158             pk_old = ds; pk_old(gamma_old) = sign(pk_old(gamma_old)).*tau;
0159             in = struct();
0160             in.x_old = u0_hat;
0161             in.gamma = gamma_old;
0162             in.pk_old = pk_old;
0163             in.tau = tau;
0164             in.maxiter = maxiter;
0165             in.yhat = yhat;
0166             in = attach_delx_factors(in, AW(:,gamma_old), delx_mode);
0167             if strcmp(delx_mode,'cg')
0168                 in.W_new = tau./W_new;
0169             end
0170 
0171             tic;
0172             out = wtBPDN_Update_function_v2(AW, y, in);
0173             time_update = toc;
0174 
0175             % Undo the column scaling to return to the original variable
0176             xh_mod = out.x_out.*(tau./W_new);
0177             gamma_xh = out.gamma;
0178             iter_update = out.iter;
0179     end
0180     iter_ALL = [iter_ALL iter_update];
0181     err_ALL = [err_ALL norm(x-xh_mod)/norm(x)];
0182     time_ALL = [time_ALL time_update];
0183     supp_diff = [supp_diff length(setxor(gamma_xh,gamma_orig))];
0184 end
0185 
0186 out.x_out = xh_mod;
0187 out.x_init = x_init;
0188 out.iter = iter_ALL;
0189 out.gamma = gamma_xh;
0190 out.W_new = W_new;
0191 out.err = err_ALL;
0192 out.time = time_ALL;
0193 out.supp_diff = supp_diff;
0194 
0195 function in = attach_delx_factors(in, Ag, delx_mode)
0196 % Store delx_mode in IN and, for 'mil', 'qr' and 'qrM', the matching
0197 % factors of the active-set columns Ag. Other modes add nothing further.
0198 in.delx_mode = delx_mode;
0199 switch delx_mode
0200     case 'mil'
0201         AtAgx = Ag'*Ag;
0202         in.AtA = AtAgx;
0203         in.iAtA = pinv(AtAgx);
0204     case 'qr'
0205         [Q, R] = qr(Ag,0);
0206         in.Q = Q; in.R = R;
0207     case 'qrM'
0208         [Q0, R0] = qr(Ag);
0209         in.Q0 = Q0; in.R0 = R0;
0210 end

Generated on Mon 10-Jun-2013 23:03:23 by m2html © 2005