wtBPDN_homotopy_function

PURPOSE ^

wtBPDN_homotopy_function: solve weighted basis pursuit denoising (BPDN) by homotopy

SYNOPSIS ^

function out = wtBPDN_homotopy_function(A, y, in)

DESCRIPTION ^

 wtBPDN_homotopy_function

 Solves the following weighted basis pursuit denoising (BPDN) problem
 min_x  \sum_i w_i |x_i| + 1/2*||y-Ax||_2^2

 using homotopy from scratch. 

 (A simpler way to solve weighted BPDN is to scale the columns of A by the
 inverse weights and solve standard BPDN, i.e., min_u \|u\|_1 + 1/2\|AW^{-1}u-y\|_2^2
 with W = diag(w) and x = W^{-1}u.)
 Inputs:
 A - m x n measurement matrix
 y - measurement vector
 in - input structure
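   W - vector of weights w_i (final values of the regularization parameters)
   maxiter - maximum number of homotopy iterations
   delx_mode - (required) mode used by the update_delx script
   Te - (optional) maximum support size used as a stopping criterion (default inf)
   record - (optional) record iteration history
   x_orig - original signal for error history (required when record is set)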

 Outputs:
 out - output structure
   x_out - output for BPDN
   gamma - support of the solution
   iter - number of homotopy iterations taken by the solver
   time - time taken by the solver
   error_table - error table with iteration record

 Written by: Salman Asif, Georgia Tech
 Email: sasif@ece.gatech.edu

-------------------------------------------+
 Copyright (c) 2011.  Muhammad Salman Asif
-------------------------------------------+

CROSS-REFERENCE INFORMATION ^

This function calls: This function is called by:

SOURCE CODE ^

0001 % wtBPDN_homotopy_function
0002 %
0003 % Solves the following weighted basis pursuit denoising (BPDN) problem
0004 % min_x  \sum_i w_i |x_i| + 1/2*||y-Ax||_2^2
0005 %
0006 % using homotopy from scratch.
0007 %
0008 % (A simpler way to solve weighted BPDN is to scale the columns of A by the
0009 % inverse weights and solve standard BPDN, i.e.,
0010 % min_u \|u\|_1 + 1/2\|AW^{-1}u-y\|_2^2 with W = diag(w) and x = W^{-1}u.)
0011 %
0012 % Inputs:
0013 % A - m x n measurement matrix
0014 % y - measurement vector
0015 % in - input structure
0016 %   W - n x 1 vector of weights w_i (final values of the regularization
0017 %       parameters)
0018 %   maxiter - maximum number of homotopy iterations
0019 %   delx_mode - (required) mode used by the update_delx script
0020 %   Te - (optional) stop once the support size reaches Te (default inf)
0021 %   record - (optional) record iteration history (default 0)
0022 %   x_orig - original signal for error history (required when record is set)
0023 %
0024 % Outputs:
0025 % out - output structure
0026 %   x_out - output for BPDN
0027 %   gamma - support of the solution
0028 %   iter - number of homotopy iterations taken by the solver
0029 %   time - CPU time (cputime) taken by the solver
0030 %   error_table - error table with iteration record; each row is
0031 %                 [epsilon, norm(x_out - x_orig), support size]
0032 %
0033 % Written by: Salman Asif, Georgia Tech
0034 % Email: sasif@ece.gatech.edu
0035 %
0036 %-------------------------------------------+
0037 % Copyright (c) 2011.  Muhammad Salman Asif
0038 %-------------------------------------------+
0039 
0040 function out = wtBPDN_homotopy_function(A, y, in)
0041 
0042 N = size(A,2);
0043 M = size(A,1);
0044 
0045 W_vec = in.W;
0046 maxiter = in.maxiter;
0047 Te = inf;
0048 if isfield(in,'Te')
0049     Te = in.Te;
0050 end
0051 err_record = 0;
0052 if isfield(in,'record')
0053     err_record = in.record;
0054     if err_record
0055         x_orig = in.x_orig;
0056     end
0057 end
0058 t0 = cputime;
0059 
0060 % Regularization parameters: the distinct weight values in descending
0061 % order. The homotopy parameter epsilon is driven down through them.
0062 unique_eps = sort(unique(W_vec),'descend');
0063 
0064 % Initialization of primal sign and support
0065 z_x = zeros(N,1);
0066 gamma_x = [];       % not used below (the support is gamma_xh); kept in
0067                     % case the shared update_* scripts reference it
0068 
0069 % Initial step: at x = 0, pk_old = A'*(A*x-y) = -A'*y
0070 pk_old = -A'*y;
0071 constr_mask = abs(pk_old)>W_vec;
0072 if ~any(constr_mask)
0073     % |A'*y| <= w elementwise, so x = 0 already satisfies the optimality
0074     % conditions. Return it directly: otherwise c = 0 below and eps_iter
0075     % can index past the end of unique_eps.
0076     out = [];
0077     out.x_out = zeros(N,1);
0078     out.gamma = [];
0079     out.iter = 0;
0080     out.time = cputime-t0;
0081     out.error_table = [];
0082     if err_record
0083         out.error_table = [min(W_vec) norm(x_orig) 0];
0084     end
0085     return;
0086 end
0087 [c idelta] = max(abs(pk_old.*constr_mask));
0088 % index of the largest weight value that does not exceed the starting epsilon
0089 eps_iter = sum(unique_eps>c)+1;
0090 
0091 gamma_xh = idelta;  % Primal support
0092 temp_gamma = zeros(N,1);
0093 temp_gamma(gamma_xh) = gamma_xh;
0094 gamma_xc = find([1:N]' ~= temp_gamma);
0095 
0096 z_x(gamma_xh) = -sign(pk_old(gamma_xh));
0097 epsilon = c;
0098 pk_old(gamma_xh) = sign(pk_old(gamma_xh))*epsilon;
0099 xk_1 = zeros(N,1);
0100 
0101 % loop parameters
0102 done = 0;
0103 iter = 0;
0104 itr_history = [];
0105 
0106 error_table = [];
0107 if err_record
0108     error_table = [epsilon norm(xk_1-x_orig) 1];
0109 end
0110 
0111 % initialize delx (update_delx is a script: it reads workspace variables
0112 % such as rhs and update_mode and provides delx)
0113 in_delx = [];
0114 delx_mode = in.delx_mode;
0115 indicator_temp = epsilon>W_vec;
0116 rhs = indicator_temp.*z_x;
0117 update_mode = 'init0';
0118 update_delx;
0119 flag = 1;
0120 
0121 while iter < maxiter
0122     iter = iter+1;
0123     % warning('off','MATLAB:divideByZero')
0124     
0125     x_k = xk_1;
0126     
0127     %%%%%%%%%%%%%%%%%%%%%
0128     %%%% update on x %%%%
0129     %%%%%%%%%%%%%%%%%%%%%
0130     
0131     % update direction: delx on the support, zero elsewhere
0132     delx_vec = zeros(N,1);
0133     delx_vec(gamma_xh) = delx;
0134     
0135     if sign(delx_vec(idelta)) == sign(pk_old(idelta)) && iter > 1 && flag == 0
0136         delta = 0; flag = 0;
0137     else
0138         pk = pk_old;
0139         % dk = AtA*del_x_vec;
0140         dk_temp = A*delx_vec;
0141         dk = A'*dk_temp;
0142         
0143         %%%--- compute step size
0144         in = [];
0145         
0146         % Setting shrinkage_flag to zero shrinks new active constraint towards the
0147         % final value instantly if doing so doesn't disturb the active set
0148         
0149         % constraint level per coordinate: epsilon while it exceeds w_i,
0150         % w_i afterwards, i.e. max(epsilon, w_i)
0151         epsilon_temp = max(epsilon, W_vec);
0152         one_temp = epsilon>W_vec;
0153         
0154         in.delta_flag = 2;
0155         in.pk = pk; in.dk = dk;
0156         in.ak = epsilon_temp; in.bk = -one_temp;
0157         in.gamma = gamma_xh; in.gamma_c = gamma_xc;
0158         in.delx_vec = delx_vec; in.x = xk_1;
0159         out = compute_delta(in);
0160         delta = out.delta; idelta = out.idelta;
0161         flag = out.flag;
0162         
0163         xk_1 = x_k+delta*delx_vec;
0164         pk_old = pk+delta*dk;
0165         epsilon_old = epsilon;
0166         epsilon = epsilon-delta;
0167         
0168         if epsilon <= unique_eps(eps_iter)
0169             % epsilon reached the next weight value: step exactly to it
0170             epsilon = unique_eps(eps_iter);
0171             delta_end = epsilon_old-epsilon;
0172             pk_old = pk+delta_end*dk;
0173             epsilon_temp = max(epsilon, W_vec);
0174             pk_old(gamma_xh) = sign(pk_old(gamma_xh)).*epsilon_temp(gamma_xh);
0175             
0176             xk_1 = x_k + delta_end*delx_vec;
0177             eps_iter = eps_iter+1;
0178             if eps_iter > length(unique_eps)
0179                 %disp('done!');
0180                 break;
0181             else
0182                 %disp('switch epsilon!');
0183                 % the set {i : epsilon > w_i} has shrunk: recompute delx
0184                 flag = 1;
0185                 z_x = -sign(pk_old);
0186                 indicator_temp = epsilon>W_vec;
0187                 rhs = indicator_temp.*z_x;
0188                 update_mode = 'recompute';
0189                 update_delx;
0190                 continue;
0191             end
0192         end
0193         if epsilon <= min(W_vec)
0194             % Safeguard: stop at the smallest weight (the check above
0195             % normally terminates the loop there first).
0196             delta_end = epsilon_old-min(W_vec);
0197             pk_old = pk+delta_end*dk;
0198             xk_1 = x_k + delta_end*delx_vec;
0199             % disp('done!');
0200             break;
0201         end
0202         if length(gamma_xh)-flag >= Te
0203             % support-size limit reached
0204             break;
0205         end
0206     end
0207     % disp(sprintf(['iter = %d, delta = %3.4g, idelta = %d, flag = %d'], iter, delta, idelta, flag));
0208     itr_history = [itr_history; idelta delta flag];
0209     
0210     % update support
0211     update_supp;
0212     
0213     epsilon_temp = max(epsilon, W_vec);
0214     pk_old([gamma_xh; idelta]) = sign(pk_old([gamma_xh; idelta])).*epsilon_temp([gamma_xh; idelta]);
0215     
0216     % update delx
0217     z_x = -sign(pk_old);
0218     indicator_temp = epsilon>W_vec;
0219     rhs = indicator_temp.*z_x;
0220     update_mode = 'update';
0221     update_delx;
0222     
0223     % Sanity checks on the optimality conditions (pk_old = A'*(A*xk_1-y)):
0224     % off the support |pk_old| must not exceed its constraint level, and on
0225     % the support pk_old and xk_1 must not share a sign. Any violation
0226     % triggers a full refactorization.
0227     constr_violation = nnz((abs(pk_old(gamma_xc))-epsilon_temp(gamma_xc))>1e-10);
0228     sign_violation = nnz(abs(sign(pk_old(gamma_xh))+sign(xk_1(gamma_xh)))>1);
0229     if constr_violation
0230         chk = gamma_xc((abs(pk_old(gamma_xc))-epsilon_temp(gamma_xc))>1e-10);
0231         stp = 1;
0232         fprintf('problem... with constraint violation -- %s\n', mfilename);
0233         fprintf('Refactorize the matrix... recompute delx \n');
0234         % some times it comes here due to bad conditioning of AtAgx.
0235         update_mode = 'init0';
0236         update_delx;
0237     end
0238     if sign_violation > 0
0239         chk = gamma_xh(abs(sign(pk_old(gamma_xh))+sign(xk_1(gamma_xh)))>1);
0240         stp = 1;
0241         fprintf('problem... sign mismatch -- %s\n',mfilename);
0242         fprintf('Refactorize the matrix... recompute delx \n');
0243         update_mode = 'init0';
0244         update_delx;
0245     end
0246     if err_record
0247         error_table = [error_table; epsilon norm(xk_1-x_orig) length(gamma_xh)];
0248     end
0249 end
0250 
0251 if err_record
0252     error_table = [error_table; epsilon norm(xk_1-x_orig) length(gamma_xh)];
0253 end
0254 total_iter = iter;
0255 total_time = cputime-t0;
0256 
0257 out = [];
0258 out.x_out = xk_1;
0259 out.gamma = gamma_xh;
0260 out.iter = total_iter;
0261 out.time = total_time;
0262 out.error_table = error_table;

Generated on Mon 10-Jun-2013 23:03:23 by m2html © 2005