yall1

PURPOSE ^

SYNOPSIS ^

function [x, Out] = yall1(A, b, opts)

DESCRIPTION ^

 A solver for L1-minimization models:

 min ||Wx||_{w,1}, st Ax = b
 min ||Wx||_{w,1} + (1/nu)||Ax - b||_1
 min ||Wx||_{w,1} + (1/(2*rho))||Ax - b||_2^2
 min ||x||_{w,1}, st Ax = b                     and x >= 0
 min ||x||_{w,1} + (1/nu)||Ax - b||_1,           st x >= 0
 min ||x||_{w,1} + (1/(2*rho))||Ax - b||_2^2,    st x >= 0

 where (A,b,x) can be complex or real 
 (but x must be real in the last 3 models)

 Copyright(c) 2009-2011 Yin Zhang, Junfeng Yang, Wotao Yin

 --- Input:
     A --- either an m x n matrix or
           a structure with 2 fields:
           1) A.times: a function handle for A*x
           2) A.trans: a function handle for A'*y
     b --- an m-vector, real or complex
  opts --- a structure with fields:
           opts.tol   -- tolerance *** required ***
           opts.nu    -- values > 0 for L1/L1 model
           opts.rho   -- values > 0 for L1/L2 model
           opts.basis -- sparsifying unitary basis W (W'*W = I)
                        a struct with 2 fields:
                        1) times: a function handle for W*x
                        2) trans: a function handle for W'*y
           opts.nonneg  -- 1 for nonnegativity constraints
           opts.nonorth -- 1 for A with non-orthonormal rows
           see the User's Guide for all other options

 --- Output: 
     x --- last iterate (hopefully an approximate solution)
   Out --- a structure with fields:
           Out.exit    --- exit information
           Out.iter    --- #iterations taken
           Out.cputime --- solver CPU time
           Out.y       --- dual variable
           Out.z       --- dual slack
           .....       --- and some more
 --------------------------------------------------------------
 define linear operators

CROSS-REFERENCE INFORMATION ^

This function calls: This function is called by:

SUBFUNCTIONS ^

SOURCE CODE ^

function [x, Out] = yall1(A, b, opts)
%
% A solver for L1-minimization models:
%
% min ||Wx||_{w,1}, st Ax = b
% min ||Wx||_{w,1} + (1/nu)||Ax - b||_1
% min ||Wx||_{w,1} + (1/(2*rho))||Ax - b||_2^2
% min ||x||_{w,1}, st Ax = b                     and x >= 0
% min ||x||_{w,1} + (1/nu)||Ax - b||_1,           st x >= 0
% min ||x||_{w,1} + (1/(2*rho))||Ax - b||_2^2,    st x >= 0
%
% where (A,b,x) can be complex or real
% (but x must be real in the last 3 models)
%
% Copyright(c) 2009-2011 Yin Zhang, Junfeng Yang, Wotao Yin
%
% --- Input:
%     A --- either an m x n matrix or
%           a structure with 2 fields:
%           1) A.times: a function handle for A*x
%           2) A.trans: a function handle for A'*y
%           (a single function handle A(x,mode) is also accepted:
%            mode 1 applies A, mode 2 applies A')
%     b --- an m-vector, real or complex
%  opts --- a structure with fields:
%           opts.tol   -- tolerance *** required ***
%           opts.nu    -- values > 0 for L1/L1 model
%           opts.rho   -- values > 0 for L1/L2 model
%           opts.delta -- values > 0 for the constrained L1/L2 model
%           opts.basis -- sparsifying unitary basis W (W'*W = I)
%                        a struct with 2 fields:
%                        1) times: a function handle for W*x
%                        2) trans: a function handle for W'*y
%           opts.nonneg  -- 1 for nonnegativity constraints
%           opts.nonorth -- 1 for A with non-orthonormal rows
%           opts.x0, opts.z0 -- optional starting points
%           see the User's Guide for all other options
%
% --- Output:
%     x --- last iterate (hopefully an approximate solution)
%   Out --- a structure with fields:
%           Out.exit    --- exit information
%           Out.iter    --- #iterations taken
%           Out.cputime --- solver CPU time
%           Out.y       --- dual variable
%           Out.z       --- dual slack
%           .....       --- and some more
% --------------------------------------------------------------

% parse model parameters
posrho = isfield(opts,'rho')    && opts.rho > 0;
posdel = isfield(opts,'delta')  && opts.delta > 0;
posnu  = isfield(opts,'nu')     && opts.nu > 0;
nonneg = isfield(opts,'nonneg') && opts.nonneg == 1;
if isfield(opts,'x0'); x0 = opts.x0; else x0 = []; end
if isfield(opts,'z0'); z0 = opts.z0; else z0 = []; end

% check conflicts % modified by Junfeng
% This must happen BEFORE the operators are built: linear_operators
% augments A with the L1/L1 residual block whenever opts.nu > 0, so
% resetting nu afterwards would leave an augmented operator (and
% augmented weights) in place while the solver runs with nu = 0.
if posdel && posrho || posdel && posnu || posrho && posnu
    fprintf('Model parameter conflict! YALL1: set delta = 0 && nu = 0;\n');
    opts.delta = 0; posdel = false;
    opts.nu    = 0; posnu  = false;
end

% define linear operators (b is rescaled inside when nu > 0)
[A,At,b,opts] = linear_operators(A,b,opts);

% in the L1/L1 model the unknown is [x; r] with m auxiliary residual
% variables r, each carrying unit weight
m = length(b);
L1L1 = posnu;
if L1L1 && isfield(opts,'weights')
    opts.weights = [opts.weights(:); ones(m,1)];
end

% check zero solution % modified by Junfeng
Atb = At(b);
bmax = norm(b,inf);
L2Unc_zsol = posrho && norm(Atb,inf) <= opts.rho;
L2Con_zsol = posdel && norm(b) <= opts.delta;
L1L1_zsol  = posnu  && bmax < opts.tol;
BP_zsol    = ~posrho && ~posdel && ~posnu && bmax < opts.tol;
zsol = L2Unc_zsol || L2Con_zsol || BP_zsol || L1L1_zsol;
if zsol
    n = length(Atb);
    % Atb comes from the augmented adjoint in the L1/L1 model, so it is
    % m entries longer than the x returned to the caller
    if L1L1, n = n - m; end
    x = zeros(n,1);
    Out.iter = 0;
    Out.cntAt = 1;
    Out.cntA = 0;
    Out.cputime = 0;   % no solver time spent; keeps Out fields consistent
    Out.exit = 'Data b = 0';
    return;
end
% ========================================================================

% scaling data and model parameters so that norm(b1,inf) = 1
b1 = b / bmax;
if posrho, opts.rho   = opts.rho / bmax; end
if posdel, opts.delta = opts.delta / bmax; end
if isfield(opts,'xs'), opts.xs = opts.xs/bmax; end

% solve the problem
t0 = cputime;
[x1,Out] = yall1_solve(A, At, b1, x0, z0, opts);
Out.cputime = cputime - t0;

% restore solution x: undo scaling, drop the auxiliary L1/L1 variables,
% map back from the sparsifying basis, and enforce nonnegativity
x = x1 * bmax;
if L1L1; x = x(1:end-m); end
if isfield(opts,'basis')
    x = opts.basis.trans(x);
end
if nonneg; x = max(0,x); end

end
0113 
0114 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [A,At,b,opts] = linear_operators(A0, b0, opts)
%
% Build the forward/adjoint function handles A and At used by the solver.
%
%   A0   -- numeric matrix, a struct with fields 'times' and 'trans',
%           or a function handle called as A0(x,1) / A0(x,2)
%   b0   -- right-hand side; returned unchanged unless opts.nu > 0,
%           in which case it is scaled by 1/sqrt(1 + nu^2)
%   opts -- option struct; opts.nonorth is filled in via check_orth
%           when the caller did not supply it
%
b = b0;

% --- wrap the user-supplied operator
if isnumeric(A0)
    if size(A0,2) < size(A0,1)
        error('A must have m < n');
    end
    A  = @(v) A0*v;
    At = @(u) (u'*A0)';
elseif isa(A0,'function_handle')
    A  = @(v) A0(v,1);
    At = @(v) A0(v,2);
elseif isstruct(A0) && all(isfield(A0,{'times','trans'}))
    A  = A0.times;
    At = A0.trans;
else
    error('A must be a matrix, a struct or a function handle');
end

% --- compose with the sparsifying basis W: A <- A*W', At <- W*A'
if isfield(opts,'basis')
    fwd = A;  adj = At;
    W   = opts.basis.times;
    Wt  = opts.basis.trans;
    A   = @(v) fwd(Wt(v));
    At  = @(u) W(adj(u));
end

% --- L1/L1 model (nu > 0): append m auxiliary variables, i.e. work
%     with the scaled operator [A, nu*I]*scal and right-hand side b*scal
useL1L1 = isfield(opts,'nu') && opts.nu > 0;
if useL1L1
    fwd  = A;  adj = At;
    nu   = opts.nu;
    m    = length(b0);
    scal = 1/sqrt(1 + nu^2);
    A    = @(v) ( fwd(v(1:end-m)) + nu*v(end-m+1:end) )*scal;
    At   = @(u) [ adj(u);  nu*u ]*scal;
    b    = b0*scal;
end

% --- detect non-orthonormal rows unless the caller already said so
if ~isfield(opts,'nonorth')
    opts.nonorth = check_orth(A,At,b);
end

end
0162 
0163 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function nonorth = check_orth(A, At, b)
%
% Randomized test of whether A*At acts as the identity, i.e. whether
% the rows of A are orthonormal.
%
%   A, At   -- function handles for the operator and its adjoint
%   b       -- only its size is used, to shape the random probe
%   nonorth -- 1 if the relative round-trip error exceeds 1e-12, else 0
%
probe = randn(size(b));          % random test vector (draws from the global RNG)
echo  = A(At(probe));            % equals probe when A*At = I
relerr = norm(probe - echo)/norm(probe);
nonorth = double(relerr > 1.e-12);
end
0174 
0175 
0176 
0177 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
0178 %%%%                   solver                %%%%%
0179 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [x, Out] = yall1_solve(A,At,b,x0,z0,opts)

% yall1_solve version 1.4 (July, 2011)
% Copyright(c) 2009-2011 Yin Zhang
%
% Iterative core of YALL1, called by yall1 on the scaled problem.
%
% --- Input:
%   A, At  --- function handles for the operator and its adjoint
%   b      --- right-hand side (m-vector)
%   x0, z0 --- starting points; [] selects x = At(b) and z = 0
%   opts   --- option struct; opts.tol is required, the optional fields
%              read here are mu, maxit, print, nu, rho, delta, weights,
%              nonneg, nonorth, gamma and xs (see get_opts / iprint2)
%
% --- Output:
%   x   --- last primal iterate
%   Out --- struct with fields
%           iter, exit, cntA, cntAt  (iteration / operator-call counts)
%           mu   = [initial mu, final mu]
%           obj  = [primal objective, dual objective] from the last
%                  stopping check (NaN if no check was performed)
%           y, z --- dual variable and dual slack
%           plus per-iteration histories when print > 1 and opts.xs
%           is supplied (see iprint2)

%% initialization
m = length(b); bnrm = norm(b);
[tol,mu,maxit,print,nu,rho,delta, ...
    w,nonneg,nonorth,gamma] = get_opts;
x = x0; z = z0;
if isempty(x0); x = At(b); end
n = length(x);
if isempty(z0), z = zeros(n,1); end
% the non-orthonormal branch of the main loop updates y and Aty
% incrementally, so they need starting values; use the same flag the
% loop branches on
if nonorth
    y = zeros(m,1); Aty = zeros(n,1);
end
if print; fprintf('--- YALL1 v1.4 ---\n'); end
if print; iprint1(0); end;

mu_orig = mu;
rdmu = rho / mu;
rdmu1 = rdmu + 1;
bdmu = b / mu;
ddmu = delta / mu;

Out.cntA = 0; Out.cntAt = 0;
rel_gap = 0;  rel_rd  = 0;
rel_rp  = 0;  stop = 0;
% objective values are only computed inside check_stopping (every 2nd
% iteration); predefine them so Out.obj is valid even if it never runs
objp = NaN;   objd = NaN;

%% main iterations
for iter = 1:maxit

    %% calculations
    xdmu = x / mu;
    if ~nonorth; % orthonormal A
        y = A(z - xdmu) + bdmu;
        if rho > 0;
            y = y / rdmu1;
        elseif delta > 0
            y = max(0, 1 - ddmu/norm(y))*y;
        end
        Aty = At(y);
    else     % non-orthonormal A: one steepest-descent step on y
        ry = A(Aty - z + xdmu) - bdmu;
        if rho > 0; ry = ry + rdmu*y; end
        Atry = At(ry);
        denom = Atry'*Atry;
        if rho > 0, denom = denom + rdmu * (ry'*ry); end
        stp = real(ry'*ry)/(real(denom) + eps);  % eps guards against 0/0
        Out.cntAt = Out.cntAt + 1;
        y = y - stp*ry;
        Aty = Aty - stp*Atry;
    end

    z = Aty + xdmu;
    z = proj2box(z,w,nonneg,nu,m);

    Out.cntA  = Out.cntA  + 1;
    Out.cntAt = Out.cntAt + 1;

    rd = Aty - z; xp = x;
    x = x + (gamma*mu) * rd;

    %% other chores
    if rem(iter,2) == 0,
        check_stopping; update_mu;
    end
    if print > 1; iprint2; end
    if stop; break; end

end % main iterations

% output
Out.iter = iter;
Out.mu = [mu_orig mu];
Out.obj = [objp objd];
Out.y = y; Out.z = z;

% report 'maxiter' only when no stopping test fired; testing
% iter == maxit would overwrite a successful exit on the last iteration
if ~stop; Out.exit = 'Exit: maxiter'; end
if print; iprint1(1); end

%% nested functions
    function [tol,mu,maxit,print,nu,rho,delta, ...
             w,nonneg,nonorth,gamma] = get_opts
        % get or set options: defaults first, then overrides from opts
        tol = opts.tol;
        mu = mean(abs(b));
        %mu = norm(b)/numel(b);
        maxit = 9999;
        print = 0;
        nu = 0;
        rho = eps;   % tiny but positive: "rho > 0" branches are taken
                     % unless opts.rho is supplied explicitly
        delta = 0;
        w = 1;
        nonneg = 0;
        nonorth = 0;
        gamma = 1.; % ADM parameter
        if isfield(opts,'mu');       mu = opts.mu;    end
        if isfield(opts,'maxit'); maxit = opts.maxit; end
        if isfield(opts,'print'); print = opts.print; end
        if isfield(opts,'nu');       nu = opts.nu;    end
        if isfield(opts,'rho');     rho = opts.rho;   end
        if isfield(opts,'delta'); delta = opts.delta; end
        if isfield(opts,'weights'); w = opts.weights; end
        if isfield(opts,'nonneg');   nonneg = opts.nonneg;  end
        if isfield(opts,'nonorth'); nonorth = opts.nonorth; end
        if isfield(opts,'gamma');   gamma   = opts.gamma;   end
    end

    function z = proj2box(z,w,nonneg,nu,m)
        % project the dual slack z onto its feasible set
        if nonneg
            z = min(w,real(z));
            if nu > 0 %L1L1 model
                % only the last m entries (the auxiliary residual block
                % appended by the L1/L1 operator) are bounded below by -1;
                % "end-m:end" would also clamp the last entry of x's block
                z(end-m+1:end) = max(-1,z(end-m+1:end));
            end
        else
            z = z .* w ./ max(w,abs(z));
        end
    end

    function check_stopping
        % update rel_rd / rel_gap / objp / objd and set stop + Out.exit
        % when the iterate has stabilised or converged
        q = 0.1; % q in [0,1)
        if delta > 0; q = 0; end
        % dual residual
        rdnrm = norm(rd);
        rel_rd = rdnrm / norm(z);
        % duality gap
        objp = sum(abs(w.*x));
        objd = b'*y;
        if delta > 0, objd = objd - delta*norm(y); end
        if rho > 0
            rp = A(x) - b;
            rpnrm = norm(rp);
            Out.cntA = Out.cntA + 1;
            objp = objp + (0.5/rho)*rpnrm^2;
            objd = objd - (0.5*rho)*norm(y)^2;
        end
        rel_gap = abs(objd - objp)/abs(objp);

        % check relative change
        xrel_chg = norm(x-xp)/norm(x);
        if xrel_chg < tol*(1 - q)
            Out.exit = 'Exit: Stablized';
            stop = 1; return;
        end

        % decide whether to go further
        if xrel_chg >= tol*(1 + q); return; end
        gap_small = rel_gap < tol;
        if ~gap_small; return; end
        d_feasible = rel_rd < tol;
        if ~d_feasible; return; end

        % check primal residual (already computed above when rho > 0)
        if rho == 0,
            rp = A(x) - b;
            rpnrm = norm(rp);
            Out.cntA = Out.cntA + 1;
        end;
        if rho > 0;
            p_feasible = true;
        elseif delta > 0
            p_feasible = rpnrm <= delta*(1 + tol);
        else
            p_feasible = rpnrm < tol*bnrm;
        end
        if p_feasible, stop = 1; Out.exit = 'Exit: Converged'; end
    end

    function iprint1(mode)
        % summary printout: mode 0 before the loop, mode 1 after it
        switch mode;
            case 0; % at the beginning
                rp = A(x) - b;
                rpnrm = norm(rp);
                fprintf(' norm( A*x0 - b ) = %6.2e\n',rpnrm);
            case 1; % at the end
                rp = A(x) - b;
                objp = sum(abs(w.*x));
                objd = b'*y;
                if rho > 0
                    objp = objp + (0.5/rho)*(rp'*rp);
                    objd = objd - (0.5*rho)*( y'*y );
                end
                if delta > 0; objd = objd - delta*norm(y); end
                dgap = abs(objd - objp);
                rel_gap = dgap / abs(objp);
                rdnrm = norm(rd);
                rel_rd = rdnrm / norm(z);
                rpnrm = norm(rp);
                rel_rp = rpnrm / bnrm;
                fprintf(' Rel_Gap   Rel_ResD  Rel_ResP\n');
                fprintf(' %8.2e  %8.2e  %8.2e\n',rel_gap,rel_rd,rel_rp);
        end
    end

    function iprint2
        % per-iteration diagnostics (print > 1): prints every 50th
        % iteration and, when opts.xs is given (and nu is not active),
        % appends objective/residual/error histories to Out
        rdnrm = norm(rd);
        rp = A(x) - b;
        rpnrm = norm(rp);
        objp = sum(abs(w.*x));
        objd = b'*y;
        if rho > 0
            objp = objp + (0.5/rho)*rpnrm^2;
            objd = objd - (0.5*rho)*(y'*y);
        end
        if delta > 0;
            objd = objd - delta*norm(y);
        end
        gap = abs(objd - objp);
        if ~rem(iter,50)
            fprintf('  Iter %4i:' ,iter);
            fprintf('  Rel_Gap  %6.2e',gap/objp);
            fprintf('  Rel_ResD %6.2e',rdnrm/norm(z));
            fprintf('  Rel_ResP %6.2e',norm(rp)/norm(b));
            fprintf('\n');
        end

        if ~isfield(opts,'xs'), return; end
        if isfield(opts,'nu') && opts.nu, return; end

        if iter <= 1,
            Out.objd = []; Out.rd = [];
            Out.objp = []; Out.rp = [];
            opts.xsnrm = norm(opts.xs);
            Out.relerr = [];
        end
        Out.objd = [Out.objd objd];
        Out.objp = [Out.objp objp];
        Out.rd = [Out.rd rdnrm];
        Out.rp = [Out.rp rpnrm];
        err = norm(x-opts.xs)/opts.xsnrm;
        Out.relerr = [Out.relerr err];
    end

    function update_mu    % added to v14
        % shrink mu by a factor mfrac (never below mfrac^nup * mu_orig)
        % when the duality gap dominates the dual residual
        mfrac = 0.1; big = 50; nup = 8;
        mu_min = mfrac^nup * mu_orig;
        do_update = rel_gap > big*rel_rd;
        do_update = do_update && mu > 1.1*mu_min;
        do_update = do_update && iter > 10;
        if ~do_update, return; end
        % do update, refreshing every quantity cached as (something)/mu
        mu = max(mfrac*mu,mu_min);
        rdmu = rho / mu; rdmu1 = rdmu + 1;
        bdmu = b / mu; ddmu = delta / mu;
        if print>1; fprintf('  -- mu updated\n'); end
    end

end

Generated on Mon 10-Jun-2013 23:03:23 by m2html © 2005