function [x, Out] = yall1(A, b, opts)
% YALL1  Driver: set up operators, handle trivial cases, scale, solve, unscale.
%
%   [x, Out] = yall1(A, b, opts)
%
%   Inputs
%     A    - a numeric matrix, a struct with fields 'times' and 'trans',
%            or a function handle called as A(x,1) / A(x,2)
%            (see linear_operators).
%     b    - data vector.
%     opts - option struct.  Fields read here:
%              tol     - tolerance; read WITHOUT an isfield guard, so it
%                        has to be supplied by the caller
%              rho     - > 0 selects "the unconstrained L1L2 problem"
%              delta   - > 0 selects "the constrained L1L2 problem"
%              nu      - > 0 selects "the unconstrained L1L1 problem"
%                        (none of the three > 0: "the basis pursuit problem")
%              weights - weight vector; padded with ones(m,1) in the L1L1
%                        model to cover the m appended variables
%              nonneg  - == 1: the returned x is clipped with max(0,x)
%              x0, z0  - optional starting points handed to the solver
%              xs      - optional reference vector; divided by norm(b,inf)
%                        like b before the solve
%              basis   - optional struct; basis.trans is applied to the
%                        solution before it is returned
%            All remaining fields are passed through to yall1_solve.
%
%   Outputs
%     x    - computed solution, in the scale of the original b.
%     Out  - struct returned by yall1_solve plus the field 'cputime'.
%            On the zero-solution early exit it only holds
%            iter, cntAt, cntA and exit.

    % ---- resolve conflicting model parameters ------------------------
    % This has to happen BEFORE linear_operators is called: that function
    % appends m columns to the operator whenever opts.nu > 0.  Resolving
    % the conflict afterwards would leave the solver with an augmented
    % operator although nu has been reset to 0.
    posrho = isfield(opts,'rho')   && opts.rho   > 0;
    posdel = isfield(opts,'delta') && opts.delta > 0;
    posnu  = isfield(opts,'nu')    && opts.nu    > 0;
    if posdel && posrho || posdel && posnu || posrho && posnu
        fprintf('Model parameter conflict! YALL1: set delta = 0 && nu = 0;\n');
        opts.delta = 0; posdel = false;
        opts.nu    = 0; posnu  = false;
    end

    % ---- operators ---------------------------------------------------
    [A,At,b,opts] = linear_operators(A,b,opts);

    m    = length(b);
    L1L1 = posnu;               % true only if the nu-model is really used
    if L1L1 && isfield(opts,'weights')
        opts.weights = [opts.weights(:); ones(m,1)];
    end

    nonneg = isfield(opts,'nonneg') && opts.nonneg == 1;
    if isfield(opts,'x0'); x0 = opts.x0; else x0 = []; end
    if isfield(opts,'z0'); z0 = opts.z0; else z0 = []; end

    % ---- zero-solution shortcuts -------------------------------------
    Atb  = At(b);
    bmax = norm(b,inf);
    L2Unc_zsol = posrho && norm(Atb,inf) <= opts.rho;
    L2Con_zsol = posdel && norm(b) <= opts.delta;
    L1L1_zsol  = posnu  && bmax < opts.tol;
    BP_zsol    = ~posrho && ~posdel && ~posnu && bmax < opts.tol;
    zsol = L2Unc_zsol || L2Con_zsol || BP_zsol || L1L1_zsol;
    if zsol
        n = length(Atb);
        % In the L1L1 model At returns n+m entries; the caller expects n.
        if L1L1, n = n - m; end
        x = zeros(n,1);
        Out.iter  = 0;
        Out.cntAt = 1;          % the single At(b) evaluated above
        Out.cntA  = 0;
        Out.exit  = 'Data b = 0';
        return;
    end

    % ---- scale the data so that norm(b,inf) == 1 ----------------------
    b1 = b / bmax;
    if posrho, opts.rho   = opts.rho   / bmax; end
    if posdel, opts.delta = opts.delta / bmax; end
    if isfield(opts,'xs'), opts.xs = opts.xs/bmax; end

    % ---- solve -------------------------------------------------------
    t0 = cputime;
    [x1,Out] = yall1_solve(A, At, b1, x0, z0, opts);
    Out.cputime = cputime - t0;

    % ---- undo scaling / augmentation / basis -------------------------
    x = x1 * bmax;
    if L1L1; x = x(1:end-m); end        % drop the m appended variables
    if isfield(opts,'basis')
        x = opts.basis.trans(x);
    end
    if nonneg; x = max(0,x); end

end
% -------------------------------------------------------------------------
function [A,At,b,opts] = linear_operators(A0, b0, opts)
% LINEAR_OPERATORS  Build forward/adjoint function handles from A0.
%
%   A0 may be
%     * a numeric matrix (an error is raised if it has more rows than
%       columns),
%     * a struct with fields 'times' and 'trans',
%     * a function handle, called as A0(x,1) and A0(x,2).
%
%   If opts.basis is present, A is composed with opts.basis.trans and At
%   with opts.basis.times.  If opts.nu > 0 the operator is extended by m
%   trailing variables (m = length(b0)) weighted by nu, and operator and
%   data are both multiplied by 1/sqrt(1+nu^2).  If opts.nonorth is not
%   given it is filled in by check_orth.

    b = b0;

    % ---- stage 1: raw operator pair ----------------------------------
    if isa(A0,'function_handle')
        A  = @(x) A0(x,1);
        At = @(x) A0(x,2);
    elseif isstruct(A0) && isfield(A0,'times') && isfield(A0,'trans')
        A  = A0.times;
        At = A0.trans;
    elseif isnumeric(A0)
        if size(A0,2) < size(A0,1)
            error('A must have m < n');
        end
        A  = @(x) A0*x;
        At = @(y) (y'*A0)';
    else
        error('A must be a matrix, a struct or a function handle');
    end

    % ---- stage 2: optional sparsifying basis -------------------------
    if isfield(opts,'basis')
        rawA   = A;
        rawAt  = At;
        basisT = opts.basis.trans;
        basisF = opts.basis.times;
        A  = @(x) rawA(basisT(x));
        At = @(y) basisF(rawAt(y));
    end

    % ---- stage 3: optional augmentation for nu > 0 -------------------
    useNu = isfield(opts,'nu') && opts.nu > 0;
    if useNu
        innerA  = A;
        innerAt = At;
        m  = length(b0);
        nu = opts.nu;
        t  = 1/sqrt(1 + nu^2);
        A  = @(x) ( innerA(x(1:end-m)) + nu*x(end-m+1:end) )*t;
        At = @(y) [ innerAt(y); nu*y ]*t;
        b  = b0*t;
    end

    % ---- stage 4: detect A*At ~= I unless the caller already said so -
    if ~isfield(opts,'nonorth')
        opts.nonorth = check_orth(A,At,b);
    end

end
% -------------------------------------------------------------------------
function nonorth = check_orth(A, At, b)
% CHECK_ORTH  Probe whether A(At(.)) acts as the identity.
%
%   A random vector of the same size as b is pushed through At and then
%   A.  nonorth is 1 if the relative deviation from the probe exceeds
%   1.e-12, otherwise 0.

    probe     = randn(size(b));
    roundTrip = A(At(probe));
    relDev    = norm(probe - roundTrip)/norm(probe);
    nonorth   = double(relDev > 1.e-12);
end
% -------------------------------------------------------------------------
function [x, Out] = yall1_solve(A,At,b,x0,z0,opts)
% YALL1_SOLVE  Iterative solver working on the triple (x, y, z).
%
%   [x, Out] = yall1_solve(A, At, b, x0, z0, opts)
%
%   Inputs
%     A, At  - function handles for the operator and its adjoint.
%     b      - data vector (length m).
%     x0, z0 - starting points; [] selects x = At(b) and z = zeros(n,1).
%     opts   - option struct.  opts.tol is required; mu, maxit, print, nu,
%              rho, delta, weights, nonneg, nonorth and gamma are optional
%              (defaults in get_opts).  opts.xs, if present, enables the
%              per-iteration history recorded by iprint2 when print > 1.
%
%   Outputs
%     x   - final iterate.
%     Out - struct with fields
%             iter   number of iterations performed
%             mu     [initial mu, final mu]
%             obj    [objp objd] from the last evaluation; [NaN NaN] if
%                    the objectives were never evaluated
%             y, z   final y and z
%             exit   exit message
%             cntA, cntAt   counters for A / At applications
%
%   All nested functions below share this function's workspace.

    m = length(b); bnrm = norm(b);
    [tol,mu,maxit,print,nu,rho,delta, ...
        w,nonneg,nonorth,gamma] = get_opts;

    % ---- starting points ---------------------------------------------
    x = x0; z = z0;
    if isempty(x0); x = At(b); end
    n = length(x);
    if isempty(z0), z = zeros(n,1); end
    if isfield(opts,'nonorth') && opts.nonorth > 0
        % the non-orthonormal branch updates y and Aty incrementally
        y = zeros(m,1); Aty = zeros(n,1);
    end
    if print; fprintf('--- YALL1 v1.4 ---\n'); end
    if print; iprint1(0); end

    % ---- quantities that depend on mu (refreshed in update_mu) --------
    mu_orig = mu;
    rdmu  = rho / mu;
    rdmu1 = rdmu + 1;
    bdmu  = b / mu;
    ddmu  = delta / mu;

    Out.cntA = 0; Out.cntAt = 0;
    rel_gap = 0; rel_rd = 0;
    rel_rp  = 0; stop = 0;

    % objp/objd are first computed in check_stopping, which only runs on
    % even iterations.  Initialise them so that Out.obj below is defined
    % even when the loop ends before that (e.g. maxit = 1).
    objp = NaN; objd = NaN;

    % ---- main loop ---------------------------------------------------
    for iter = 1:maxit

        % -- y-update --
        xdmu = x / mu;
        if ~nonorth
            y = A(z - xdmu) + bdmu;
            if rho > 0
                y = y / rdmu1;
            elseif delta > 0
                y = max(0, 1 - ddmu/norm(y))*y;
            end
            Aty = At(y);
        else
            % one step along -ry with step length stp
            ry = A(Aty - z + xdmu) - bdmu;
            if rho > 0; ry = ry + rdmu*y; end
            Atry = At(ry);
            denom = Atry'*Atry;
            if rho > 0, denom = denom + rdmu * (ry'*ry); end
            stp = real(ry'*ry)/(real(denom) + eps);
            % NOTE(review): this branch calls At once, yet cntAt is
            % incremented here and again after the branch - confirm
            % whether the double count is intended.
            Out.cntAt = Out.cntAt + 1;
            y   = y - stp*ry;
            Aty = Aty - stp*Atry;
        end

        % -- z-update --
        z = Aty + xdmu;
        z = proj2box(z,w,nonneg,nu,m);

        Out.cntA  = Out.cntA + 1;
        Out.cntAt = Out.cntAt + 1;

        % -- x-update --
        rd = Aty - z; xp = x;
        x  = x + (gamma*mu) * rd;

        % -- stopping test and mu update on even iterations only --
        if rem(iter,2) == 0
            check_stopping; update_mu;
        end
        if print > 1; iprint2; end
        if stop; break; end

    end

    % ---- collect output ----------------------------------------------
    Out.iter = iter;
    Out.mu   = [mu_orig mu];
    Out.obj  = [objp objd];
    Out.y = y; Out.z = z;

    % Only report 'maxiter' when no stopping test fired; otherwise a stop
    % raised on the very last iteration would lose its exit message.
    if ~stop; Out.exit = 'Exit: maxiter'; end
    if print; iprint1(1); end

    % ==================================================================
    function [tol,mu,maxit,print,nu,rho,delta, ...
            w,nonneg,nonorth,gamma] = get_opts
        % Read solver options from opts, falling back to defaults.

        tol = opts.tol;             % required
        mu  = mean(abs(b));

        maxit   = 9999;
        print   = 0;
        nu      = 0;
        rho     = eps;
        delta   = 0;
        w       = 1;
        nonneg  = 0;
        nonorth = 0;
        gamma   = 1.;
        if isfield(opts,'mu');      mu      = opts.mu;      end
        if isfield(opts,'maxit');   maxit   = opts.maxit;   end
        if isfield(opts,'print');   print   = opts.print;   end
        if isfield(opts,'nu');      nu      = opts.nu;      end
        if isfield(opts,'rho');     rho     = opts.rho;     end
        if isfield(opts,'delta');   delta   = opts.delta;   end
        if isfield(opts,'weights'); w       = opts.weights; end
        if isfield(opts,'nonneg');  nonneg  = opts.nonneg;  end
        if isfield(opts,'nonorth'); nonorth = opts.nonorth; end
        if isfield(opts,'gamma');   gamma   = opts.gamma;   end
    end

    % ==================================================================
    function z = proj2box(z,w,nonneg,nu,m)
        % Clip z.
        %   nonneg : z <- min(w, real(z)); when nu > 0 the LAST m entries
        %            are additionally bounded below by -1.
        %   else   : z is rescaled entrywise so that abs(z) <= w.
        if nonneg
            z = min(w,real(z));
            if nu > 0
                % end-m+1:end selects exactly the m trailing entries, the
                % same split linear_operators uses (x(end-m+1:end)).
                z(end-m+1:end) = max(-1,z(end-m+1:end));
            end
        else
            z = z .* w ./ max(w,abs(z));
        end
    end

    % ==================================================================
    function check_stopping
        % Update rel_rd, rel_gap, objp, objd and decide whether to stop.
        q = 0.1;
        if delta > 0; q = 0; end

        rdnrm  = norm(rd);
        rel_rd = rdnrm / norm(z);

        objp = sum(abs(w.*x));
        objd = b'*y;
        if delta > 0, objd = objd - delta*norm(y); end
        if rho > 0
            rp = A(x) - b;
            rpnrm = norm(rp);
            Out.cntA = Out.cntA + 1;
            objp = objp + (0.5/rho)*rpnrm^2;
            objd = objd - (0.5*rho)*norm(y)^2;
        end
        rel_gap = abs(objd - objp)/abs(objp);

        % stop as soon as x has (relatively) stopped moving
        xrel_chg = norm(x-xp)/norm(x);
        if xrel_chg < tol*(1 - q)
            Out.exit = 'Exit: Stablized';
            stop = 1; return;
        end

        % otherwise require small change, small gap, small rel_rd ...
        if xrel_chg >= tol*(1 + q); return; end
        gap_small = rel_gap < tol;
        if ~gap_small; return; end
        d_feasible = rel_rd < tol;
        if ~d_feasible; return; end

        % ... and a small residual A(x) - b (computed above when rho > 0)
        if rho == 0
            rp = A(x) - b;
            rpnrm = norm(rp);
            Out.cntA = Out.cntA + 1;
        end
        if rho > 0
            p_feasible = true;
        elseif delta > 0
            p_feasible = rpnrm <= delta*(1 + tol);
        else
            p_feasible = rpnrm < tol*bnrm;
        end
        if p_feasible, stop = 1; Out.exit = 'Exit: Converged'; end
    end

    % ==================================================================
    function iprint1(mode)
        % mode 0: print the initial residual norm.
        % mode 1: print the final relative gap and residuals.
        switch mode
            case 0
                rp = A(x) - b;
                rpnrm = norm(rp);
                fprintf(' norm( A*x0 - b ) = %6.2e\n',rpnrm);
            case 1
                rp = A(x) - b;
                objp = sum(abs(w.*x));
                objd = b'*y;
                if rho > 0
                    objp = objp + (0.5/rho)*(rp'*rp);
                    objd = objd - (0.5*rho)*( y'*y );
                end
                if delta > 0; objd = objd - delta*norm(y); end
                dgap = abs(objd - objp);
                rel_gap = dgap / abs(objp);
                rdnrm = norm(rd);
                rel_rd = rdnrm / norm(z);
                rpnrm = norm(rp);
                rel_rp = rpnrm / bnrm;
                fprintf(' Rel_Gap Rel_ResD Rel_ResP\n');
                fprintf(' %8.2e %8.2e %8.2e\n',rel_gap,rel_rd,rel_rp);
        end
    end

    % ==================================================================
    function iprint2
        % Per-iteration diagnostics (called when print > 1): a progress
        % line every 50 iterations and, if opts.xs is given and opts.nu
        % is not set, a history of objectives, residuals and the relative
        % error with respect to opts.xs stored in Out.
        rdnrm = norm(rd);
        rp = A(x) - b;
        rpnrm = norm(rp);
        objp = sum(abs(w.*x));
        objd = b'*y;
        if rho > 0
            objp = objp + (0.5/rho)*rpnrm^2;
            objd = objd - (0.5*rho)*(y'*y);
        end
        if delta > 0
            objd = objd - delta*norm(y);
        end
        gap = abs(objd - objp);
        if ~rem(iter,50)
            fprintf(' Iter %4i:' ,iter);
            fprintf(' Rel_Gap %6.2e',gap/objp);
            fprintf(' Rel_ResD %6.2e',rdnrm/norm(z));
            fprintf(' Rel_ResP %6.2e',norm(rp)/norm(b));
            fprintf('\n');
        end

        if ~isfield(opts,'xs'), return; end
        if isfield(opts,'nu') && opts.nu, return; end

        if iter <= 1
            % first call: create the history fields
            Out.objd = []; Out.rd = [];
            Out.objp = []; Out.rp = [];
            opts.xsnrm = norm(opts.xs);
            Out.relerr = [];
        end
        Out.objd = [Out.objd objd];
        Out.objp = [Out.objp objp];
        Out.rd = [Out.rd rdnrm];
        Out.rp = [Out.rp rpnrm];
        err = norm(x-opts.xs)/opts.xsnrm;
        Out.relerr = [Out.relerr err];
    end

    % ==================================================================
    function update_mu
        % Shrink mu by the factor mfrac (never below mu_min) when rel_gap
        % exceeds big*rel_rd, mu is still above 1.1*mu_min and more than
        % 10 iterations have passed; then refresh the mu-dependent values.
        mfrac = 0.1; big = 50; nup = 8;
        mu_min = mfrac^nup * mu_orig;
        do_update = rel_gap > big*rel_rd;
        do_update = do_update && mu > 1.1*mu_min;
        do_update = do_update && iter > 10;
        if ~do_update, return; end

        mu = max(mfrac*mu,mu_min);
        rdmu = rho / mu; rdmu1 = rdmu + 1;
        bdmu = b / mu; ddmu = delta / mu;
        if print>1; fprintf(' -- mu updated\n'); end
    end

end