0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
function out = script_rwtBPDN_adaptive(in)
% SCRIPT_RWTBPDN_ADAPTIVE  Reweighted BPDN driver built on homotopy updates.
%
%   out = script_rwtBPDN_adaptive(in)
%
%   Performs one initial solve with wtBPDN_adaptive_function and then
%   in.max_rwt reweighting rounds. In every round new weights are computed
%   from the current estimate,
%       W_new = tau/alpha ./ (beta*abs(x_old) + epsilon),
%   with (alpha, beta, epsilon) supplied by weight_param, and the estimate
%   is updated from W_old to W_new by wtBPDN_Update_function_v1.
%
%   Fields read from IN:
%     A, y       system matrix and measurement vector
%     x          reference signal; only used to record the relative error
%                and the support mismatch of each estimate
%     tau        scalar used in the initial solve and in the weight rule
%     max_rwt    number of reweighting rounds
%     x_init     starting point; must be all zeros (see note below)
%     rwt_mode   forwarded to weight_param
%     delx_mode  'mil', 'qr' or 'qrM' select which factorization of the
%                active columns is precomputed; any other value is
%                forwarded to the solver without precomputation
%     debias     forwarded to wtBPDN_adaptive_function
%     verbose    non-zero prints progress of the initial solve
%
%   Fields written to OUT (on top of whatever the last solver call returned):
%     x_out      final estimate
%     x_init     estimate produced by the initial solve
%     iter       iteration count of the initial solve and of each round
%     gamma      support of the final estimate as reported by the solver
%     W_new      weights used in the last round
%     err        norm(x - estimate)/norm(x) after each solve
%     time       wall-clock time of each solver call (tic/toc)
%     supp_diff  length(setxor(support, find(x))) after each solve
%
%   Note: a non-zero in.x_init is rejected. The reweighting loop needs the
%   weights (tau_vec) and support (gamma) produced by the initial solve,
%   and that solve only runs when x_init is zero. Previously a non-zero
%   x_init crashed later with an undefined-variable error.

A = in.A;
y = in.y;
x = in.x;
tau = in.tau;
max_rwt = in.max_rwt;
x_init = in.x_init;
rwt_mode = in.rwt_mode;
delx_mode = in.delx_mode;
debias = in.debias;
verbose = in.verbose;

% Which homotopy update routine to use; 'v2' is kept as an alternative.
homotopy_update = 'v1';

if norm(x_init) ~= 0
    error('script_rwtBPDN_adaptive:warmStartUnsupported', ...
        ['A non-zero in.x_init is not supported: the weights and support ' ...
         'needed by the reweighting loop are only produced by the initial ' ...
         'solve, which runs from x_init = 0.']);
end

[M, N] = size(A);
gamma_orig = find(x);
maxiter = 2*N;

iter_ALL = [];
err_ALL = [];
time_ALL = [];
supp_diff = [];

%% Initial solve
ewt = 1;
shrinkage_mode = 'Trwt';
shrinkage_flag = 0;

if verbose
    fprintf(' ewt = %1.2g, shrinkage_mode: %s, ',ewt, shrinkage_mode);
end

init_opts = struct();
init_opts.tau = tau;
init_opts.ewt = ewt;
init_opts.shrinkage_flag = shrinkage_flag;
init_opts.shrinkage_mode = shrinkage_mode;
init_opts.maxiter = maxiter;
init_opts.debias = debias;
init_opts.early_terminate = 0;
init_opts.x_orig = x;
init_opts.record = 1;
init_opts.omp = 0;
init_opts.plots = 0;
init_opts.plot_wts = 0;
init_opts.delx_mode = delx_mode;

tic;
out = wtBPDN_adaptive_function(A, y, init_opts);
time_ALL = [time_ALL toc];

x_init = out.x_out;
gamma_old = out.gamma;
tau_vec = out.tau_vec;
iter_ALL = [iter_ALL out.iter];
err_ALL = [err_ALL norm(x-x_init)/norm(x)];
supp_diff = [supp_diff length(setxor(gamma_old,gamma_orig))];
if verbose
    fprintf('err = %3.4g, iter = %3.4g, time = %3.4g \n',err_ALL(end), iter_ALL(end), time_ALL(end));
end

%% Reweighting rounds
W_new = tau_vec;      % weights the initial solve ended with
xh_mod = x_init;
gamma_xh = gamma_old;

for rwt_itr = 1:max_rwt
    gamma_old = gamma_xh;
    x_old = xh_mod;
    W_old = W_new;

    % New weights from the current estimate.
    [wt_alpha, wt_beta, wt_eps] = weight_param(rwt_mode,rwt_itr,x_old,M);
    W_new = tau/wt_alpha./(wt_beta*abs(x_old)+wt_eps);

    switch homotopy_update
        case 'v1'
            % Correlations of the residual; on the support keep the sign
            % but replace the magnitude with the old weight.
            pk_old = A'*(A*x_old-y);
            pk_old(gamma_old) = sign(pk_old(gamma_old)).*W_old(gamma_old);

            upd_opts = struct();
            upd_opts.x_old = x_old;
            upd_opts.gamma = gamma_old;
            upd_opts.pk_old = pk_old;
            upd_opts.W_old = W_old;
            upd_opts.W_new = W_new;
            upd_opts.maxiter = maxiter;
            upd_opts.delx_mode = delx_mode;

            % Precompute the factorization of the active columns that the
            % chosen delx_mode expects (not included in the timing).
            switch delx_mode
                case 'mil'
                    AtAgx = A(:,gamma_old)'*A(:,gamma_old);
                    upd_opts.AtA = AtAgx;
                    upd_opts.iAtA = pinv(AtAgx);
                case 'qr'
                    [Q, R] = qr(A(:,gamma_old),0);
                    upd_opts.Q = Q;
                    upd_opts.R = R;
                case 'qrM'
                    [Q0, R0] = qr(A(:,gamma_old));
                    upd_opts.Q0 = Q0;
                    upd_opts.R0 = R0;
            end

            tic;
            out = wtBPDN_Update_function_v1(A, y, upd_opts);
            time_update = toc;

            xh_mod = out.x_out;
            gamma_xh = out.gamma;
            iter_update = out.iter;

        case 'v2'
            % Alternative formulation: scale the columns of A by tau./W
            % and solve in the rescaled variable u = x.*(W/tau).
            AW_old = A*diag(tau./W_old);
            u0_hat = x_old.*(W_old/tau);
            ds = AW_old'*(AW_old*u0_hat-y);
            AW = A*diag(tau./W_new);
            yhat = AW*u0_hat;
            pk_old = ds;
            pk_old(gamma_old) = sign(pk_old(gamma_old)).*tau;

            upd_opts = struct();
            upd_opts.x_old = u0_hat;
            upd_opts.gamma = gamma_old;
            upd_opts.pk_old = pk_old;
            upd_opts.tau = tau;
            upd_opts.maxiter = maxiter;
            upd_opts.yhat = yhat;
            upd_opts.delx_mode = delx_mode;

            switch delx_mode
                case 'mil'
                    AtAgx = AW(:,gamma_old)'*AW(:,gamma_old);
                    upd_opts.AtA = AtAgx;
                    upd_opts.iAtA = pinv(AtAgx);
                case 'qr'
                    [Q, R] = qr(AW(:,gamma_old),0);
                    upd_opts.Q = Q;
                    upd_opts.R = R;
                case 'qrM'
                    [Q0, R0] = qr(AW(:,gamma_old));
                    upd_opts.Q0 = Q0;
                    upd_opts.R0 = R0;
                case 'cg'
                    upd_opts.W_new = tau./W_new;
            end

            tic;
            out = wtBPDN_Update_function_v2(AW, y, upd_opts);
            time_update = toc;

            % Map the rescaled solution back to the original variable.
            xh_mod = out.x_out.*(tau./W_new);
            gamma_xh = out.gamma;
            iter_update = out.iter;
    end

    iter_ALL = [iter_ALL iter_update];
    err_ALL = [err_ALL norm(x-xh_mod)/norm(x)];
    time_ALL = [time_ALL time_update];
    supp_diff = [supp_diff length(setxor(gamma_xh,gamma_orig))];
end

%% Collect results (out still carries the fields of the last solver call)
out.x_out = xh_mod;
out.x_init = x_init;
out.iter = iter_ALL;
out.gamma = gamma_xh;
out.W_new = W_new;
out.err = err_ALL;
out.time = time_ALL;
out.supp_diff = supp_diff;