demo_KalmanRWT

PURPOSE

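demo_KalmanRWT: track a dynamical signal from streaming compressive measurements with reweighted L1 (l1homotopy and SpaRSA) and compare against an LS-Kalman baseline.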

SYNOPSIS

This is a script file.

DESCRIPTION

 demo_KalmanRWT

 Solves the following dynamic BPDN problem over a window t = t_1,...,t_L
 \min_{x_{t_1},\dots,x_{t_L}} \sum_t \|W_t x_t\|_1 + \tfrac{1}{2}\|A_t x_t - y_t\|_2^2 + \tfrac{1}{2}\|F_t x_t - x_{t+1}\|_2^2

 which updates the solution as the signal changes according to a linear
 dynamical system.

 For instance, y_t = A_t x_t + e_t
               x_{t+1} = F_t x_t + f_t,
       where F_t is a partially known operator that predicts x_{t+1} from
       x_t, e_t is the measurement noise, and f_t denotes the prediction
       error (e.g., a random drift).

 Applications:
       streaming signal recovery using a dynamic model
 
       track a signal as y, A, and/or x change... 
       predict an estimate of the solution and
       update weights according to the predicted solution

 Written by: Salman Asif, Georgia Tech
 Email: sasif@gatech.edu
 Created: November 2012

CROSS-REFERENCE INFORMATION

This function calls: (none listed)
This function is called by: (none listed)

SOURCE CODE

0001 % demo_KalmanRWT
0002 %
0003 % Solves the following dynamic BPDN problem over a window t = t_1,...,t_L
0004 % \min_{x_{t_1},\dots,x_{t_L}} \sum_t \|W_t x_t\|_1 + \tfrac{1}{2}\|A_t x_t - y_t\|_2^2 + \tfrac{1}{2}\|F_t x_t - x_{t+1}\|_2^2
0005 %
0006 % which updates the solution as the signal changes according to a linear
0007 % dynamical system.
0008 %
0009 % For instance, y_t = A_t x_t + e_t
0010 %               x_{t+1} = F_t x_t + f_t,
0011 %       where F_t is a partially known operator that predicts x_{t+1} from
0012 %       x_t, e_t is the measurement noise, and f_t denotes the prediction
0013 %       error (e.g., a random drift)
0014 %
0015 % Applications:
0016 %       streaming signal recovery using a dynamic model
0017 %
0018 %       track a signal as y, A, and/or x change...
0019 %       predict an estimate of the solution and
0020 %       update weights according to the predicted solution
0021 %
0022 % Written by: Salman Asif, Georgia Tech
0023 % Email: sasif@gatech.edu
0024 % Created: November 2012
0025 % Start from an empty workspace and no open figure windows
0026 clear
0027 close all force
0028 
0029 % Limit the number of computational threads (for profiling)
0030 maxNumCompThreads(1);
0031 
0032 %% Setup path: cd into this script's folder so the relative addpath calls below resolve
0033 mname = mfilename;
0034 mpath = mfilename('fullpath');
0035 mdir = mpath(1:end-length(mname)); % strip the script name, keeping the folder and its trailing separator
0036 cd(mdir);
0037 
0038 addpath utils/
0039 addpath utils/utils_Wavelet
0040 addpath utils/utils_LOT
0041 addpath solvers/
0042 
0043 fprintf(['----------',datestr(now),'-------%s------------------\n'],mname) % run banner: timestamp and script name
0044 
0045 % load RandomStates
0046 % Seed the legacy rand/randn generators so every run draws the same drift, matrices and noise
0047 rseed = 2013;
0048 rand('state',rseed); % legacy seeding syntax; NOTE(review): switching to rng(rseed) would change the random stream
0049 randn('state',rseed);
0050 
0051 % simulation parameters
0052 mType = 'sign'; % {'randn','orth','rdct'}; measurement-matrix type handed to genAmat
0053 mFixed = 0; % measurement system time-(in)variant: 1 reuses one matrix for every window, 0 draws a new one per window
0054 sType = 'pcwreg'; % {'heavisine', 'pcwreg', 'blocks','pcwPoly'}; signal type handed to genSignal
0055 SNR = 35;       % additive Gaussian noise, given as a measurement SNR in dB
0056 
0057 wt_pred = sqrt(0.5); % weight on the prediction term ||F_t x_t - x_{t+1}||; scales F and the predicted rows of y
0058 
0059 N = 256;   % signal length
0060 R = 4; % compression rate
0061 M = round(N/R);    % no. of measurements
0062 
0063 LM = 1*N; % LM: length of measurement window (the LS-Kalman baseline only runs when LM == N)
0064 LS_Kalman = 'smooth'; % {'filter','smooth','inst'}; variant of the LS-Kalman baseline
0065 
0066 % streaming window
0067 P = 3; % size of the working window is P*N
0068 
0069 % signal length
0070 sig_length = 2^15; % 128*128; requested length, replaced by the generated signal's actual length later
0071 
0072 % signal dynamics
0073 dType = 'crshift'; % type of dynamics 'crshift', or 'static'
0074 cshift = -1; % known integer circular shift applied between consecutive windows
0075 rshift_max = 0.5; % largest magnitude of the unknown sub-sample drift
0076 rshift_h = @(z) (rand-0.5)*rshift_max*2; % uniform draw in [-rshift_max, rshift_max]; the argument is ignored
0077 
0078 % DWT parameters
0079 % type of scaling function
0080 % depth of scaling functions (number of wavelet levels)
0081 % type of extension for DWT (0 - periodic extension, 3 - streaming)
0082 wType = 'daub79'; sym = 1; % alternative setting; immediately overridden by the next line
0083 wType = 'daub8'; sym = 0; % active setting
0084 J = log2(N)-3; % number of wavelet levels
0085 
0086 % rank-1 update mode
0087 delx_mode = 'mil'; % mil or qr
0088 
0089 % add snapshots of the signal in streaming window and average them before committing to the output.
0090 avg_output = 0; 
0091 
0092 verbose = 0; % print per-window results every `verbose` windows; 0 disables printing
0093 
0094 
0095 %% SM: Sampling modes
0096 % % LM: length of measurement window
0097 % LM = 2*N; % 'universal' sampling scheme (align before the overlapping regions of DWT windows that are measured)
0098 if LM > N % measurement window longer than one signal window:
0099     LeftEdge_trunc = 1; % treat the oldest window's coefficients as known and remove them from the system
0100 else
0101     LeftEdge_trunc = 0;
0102 end
0103 LeftProj_cancel = 1; % subtract the outgoing window's DWT overlap from the measurements
0104 
0105 
0106 %% Print a summary of the configuration
0107 fprintf('CS-Kalman tracking a dynamical signal and reweighting..\n');
0108 str0 = sprintf('mType-%s, sType-%s, SNR = %d, (N,M,R) = %d, %d, %d, P = %d, LM = %d, LS_Kalman-%s \n wType-%s, J = %d, sym = %d, specified signal-length = %d, \n dType-%s, cshift = %d, rshift_max = %0.3g, wt_pred = %0.3g. ', mType, sType, SNR, N, round(N/R), R, P, LM, LS_Kalman, wType, J, sym, sig_length, dType, cshift, rshift_max, wt_pred);
0109 disp(str0);
0110 
0111 %% DWT setup
0112 % DWT parameters
0113 % Length of each window is L. (extend to adaptive/dyadic windows later?)
0114 % wType = 'daub4'; % type of scaling function
0115 % J = 3; % depth of scaling functions (number of wavelet levels)
0116 % sym = 3; % type of extension for DWT (0 - periodic extension, 3 - streaming)
0117 in_Psi = []; in_Psi.N = N; in_Psi.J = J; in_Psi.wType = wType; in_Psi.sym = sym;
0118 Psi = create_DWT(in_Psi); % DWT synthesis matrix over a window
0119 L = size(Psi,1); % samples spanned by one window of N coefficients; L-N is used below as the overlap into the next window
0120 
0121 %% Signal generation
0122 
0123 % Setup dynamical model
0124 % At every time instance, add to the original/previous signal
0125 % an integer circular shift that is known
0126 % a random drift that is unknown
0127 if strcmpi(dType, 'crshift')
0128     % Generate a signal by circular shift and a random drift in a seed
0129     % signal
0130     in = []; in.type = sType; in.randgen = 0; in.take_fwt = 0;
0131     [x_init sig wave_struct] = genSignal(N,in);
0132     
0133     F_h = @(x,cshift,rshift) interp1(1:N,circshift(x,cshift),[1:N]+rshift,'linear','extrap')'; % circularly shift by cshift, then resample at (1:N)+rshift (linear interp/extrap)
0134     
0135     F0 = zeros(N); % matrix form of F_h with zero drift, built one unit-impulse column at a time
0136     for ii = 1:N;
0137         F0(:,ii) = F_h(circshift([1; zeros(N-1,1)],ii-1),cshift,0);
0138     end
0139     sigt = sig; sig = []; % the generated window is the seed; sig is rebuilt as the concatenation of evolved windows
0140     for ii = 1:round(sig_length/N);
0141         rshift = rshift_h(1); % fresh unknown drift for this window
0142         sigt = F_h(sigt, cshift, rshift);
0143         sig = [sig; sigt];
0144     end
0145 else
0146     % Generate a predefined streaming signal
0147     in = []; in.type = sType; in.randgen = 0; in.take_fwt = 0;
0148     [x_init sig wave_struct] = genSignal(N,in);
0149     
0150     cshift = 0; rshift = 0;
0151     F_h = @(x,cshift,rshift) x; % static dynamics: identity prediction
0152     F0 = eye(N);
0153 end
0154 % sig = [zeros(L-N,1);sig];
0155 sig_length = length(sig); % actual length of the generated signal
0156 
0157 % view DWT coefficients... of the whole signal, reshaped to one column per window of N coefficients
0158 alpha_vec = apply_DWT(sig,N,wType,J,sym);
0159 figure(123);
0160 subplot(211); imagesc(reshape(alpha_vec,N,length(alpha_vec)/N));
0161 axis xy;
0162 subplot(212); plot(alpha_vec);
0163 
0164 % view innovations in the signal..
0165 % dsig = []; for n = 0:N:length(sig)-N; dsig = [dsig; sig(n+1:n+N)-circshift(sig(n+N+1:n+2*N),1)]; figure(1); plot([sig(n+1:n+N) sig(n+1:n+N)-circshift(sig(n+N+1:n+2*N),1)]); pause; end
0166 
0167 % Simulation parameters: one streaming iteration per N new samples, plus preallocated result buffers
0168 streaming_iter = ceil(length(sig)/N);
0169 SIM_stack = cell(streaming_iter,1); % one row of per-window results, filled in the main loop
0170 SIM_memory = cell(streaming_iter,1); % NOTE(review): not referenced elsewhere in this script
0171 
0172 x_vec = zeros(N*streaming_iter,1); % true DWT coefficients, window by window
0173 xh_vec = zeros(N*streaming_iter,3); % estimated coefficients; column 1 l1homotopy, column 2 SpaRSA
0174 sig_vec = zeros(length(sig),1); % true signal samples
0175 sigh_vec = zeros(length(sig),3); % estimated signal; columns: 1 l1homotopy, 2 SpaRSA, 3 LS-Kalman
0176 
0177 %% Setup sensing matrices
0178 in = []; in.type = mType;
0179 if mFixed
0180     At = genAmat(M,LM,in);
0181     genAmat_h = @(m,n) At; % time-invariant: always returns the same matrix
0182 else
0183     genAmat_h = @(M,N) genAmat(M,N,in); % time-varying; 'in' is captured by value here, so fields added below do not reach genAmat
0184 end
0185 in.P = P-(LM-N)/N; % presumably the number of measurement blocks in the working window (equals P when LM == N)
0186 in.LM = LM; in.M = M; in.N = N;
0187 PHI = create_PHI(in);
0188 
0189 %% Dynamics matrix
0190 F = zeros(P*N,(P+1)*N);
0191 for p = 1:P
0192     F((p-1)*N+1:p*N,(p-1)*N+1:(p+1)*N) = [F0 -eye(N)]; % block row p encodes F0*x_{p-1} - x_p
0193 end
0194 F = wt_pred*F(:,N+1:end); % drop the outgoing window's block column (its prediction is moved into y) and apply the weight
0195 
0196 %% Create analysis/synthesis matrix explicitly and compute sparse coeffs.
0197 in = [];
0198 in.P = P; in.Psi = Psi;
0199 % in.P = P; in.Jp = Jp; in.wType = wType; in.N = N; in.sym = sym;
0200 PSI = create_PSI_DWT(in); % synthesis matrix for the whole working window
0201 
0202 % Sparse coefficients...
0203 T_length = size(PSI,1);
0204 t_ind = 1:T_length; % sample indices covered by the working window
0205 
0206 sigt = sig(t_ind); % Signal under the LOT window at time t
0207 if sym == 1 || sym == 2
0208     x = pinv(PSI'*PSI)*(PSI'*sigt); % Sparse LOT coefficients (least-squares analysis for these extension modes)
0209 else
0210     x = PSI'*sigt; % analysis by the transpose, i.e. PSI treated as orthonormal
0211 end
0212 
0213 %% initialize with a predicted value of first x
0214 % xt = x(1:N);
0215 %
0216 % At = genAmat_h(M,N);
0217 % sigma = sqrt(norm(At*xt)^2/10^(SNR/10)/M);
0218 % e = randn(M,1)*sigma;
0219 % yt = At*xt+e;
0220 %
0221 % tau = max(1e-2*max(abs(At'*yt)),sigma*sqrt(log(N)));
0222 %
0223 % % rwt L1 with the first set of measurement...
0224 % in = [];
0225 % in.tau = tau; W = tau;
0226 % in.delx_mode = delx_mode;
0227 % for wt_itr = 1:5
0228 %     W_old = W;
0229 %
0230 %     out = l1homotopy(At,yt,in);
0231 %     xh = out.x_out;
0232 %
0233 %     % Update weights
0234 %     xh_old = xh;
0235 %     alpha = 1; epsilon = 1;
0236 %     beta = M*(norm(xh_old,2)/norm(xh_old,1))^2;
0237 %     W = tau/alpha./(beta*abs(xh_old)+epsilon);
0238 %
0239 %     yh = At*xh_old;
0240 %     Atr = At'*(At*xh-yt);
0241 %     u =  -W.*sign(xh)-Atr;
0242 %     pk_old = Atr+u;
0243 %
0244 %     in = out;
0245 %     in.xh_old = xh;
0246 %     in.pk_old = pk_old;
0247 %     in.u = u;
0248 %     in.W_old = W_old;
0249 %     in.W = W;
0250 % end
0251 % xh(abs(xh)<tau/sqrt(log(N))) = 0;
0252 
0253 
0254 % Another way to initialize...
0255 % Best M/2-sparse signal...
0256 % [val_sort ind_sort] = sort(abs(x),'descend');
0257 % xh = x;
0258 % xh(ind_sort(P*N/2+1:end)) = 0;
0259 
0260 % Oracle value for the initialization: start from the true coefficients of the working window
0261 xh = x; disp('oracle initialization');
0262 
0263 % model for the outgoing window...
0264 sim = 1; % index of the window being committed to the output buffers
0265 st_ind = N;
0266 t_ind = st_ind+t_ind; % working window advanced by N samples
0267 s_ind = t_ind(1:L);
0268 
0269 sig_out = PSI(st_ind+1:st_ind+N,:)*xh; % samples N+1:2N synthesized from the full working-window estimate
0270 xh = xh(st_ind+1:end);
0271 
0272 xh_out = xh(1:N); % coefficients committed as window 1
0273 x_vec((sim-1)*N+1:sim*N,1) = x(st_ind+1:st_ind+N);
0274 xh_vec((sim-1)*N+1:sim*N,1:3) = [xh_out xh_out xh_out];
0275 
0276 sig_temp = Psi*xh_out;
0277 sig_temp = [sig_out; sig_temp(N+1:end)]; % first N samples from the full synthesis, overlap tail from this block alone
0278 sig_vec(s_ind) = sigt(s_ind);
0279 sigh_vec(s_ind,1:3) = sigh_vec(s_ind,1:3)+[sig_temp sig_temp sig_temp];
0280 
0281 
0282 %% Generate complete measurement system
0283 % Sparse coefficients...
0284 t_ind = t_ind + N; % advance the working window by another N samples
0285 sigt = sig(t_ind); % Signal under the LOT window at time t
0286 if sym == 1 || sym == 2
0287     x = pinv(PSI'*PSI)*(PSI'*sigt); % Sparse LOT coefficients
0288 else
0289     x = PSI'*sigt;
0290 end
0291 
0292 y = PHI*sigt(1:end-(L-N)); % noiseless measurements; the trailing L-N overlap samples are not measured
0293 
0294 leny = length(y);
0295 sigma = sqrt(norm(y)^2/10^(SNR/10)/leny); % per-entry noise std giving the target SNR (dB)
0296 e = randn(leny,1)*sigma; % kept so the main loop can shift the noise along with the measurements
0297 y = y+e;
0298 
0299 
0300 PSI_M = PSI(1:end-(L-N),:); % synthesis rows for the measured samples only
0301 A = [PHI; F]*PSI_M; % stacked system in the coefficient domain: measurement rows, then prediction rows
0302 
0303 
0304 sig_out = sigh_vec(t_ind(1:N)-N,1); % l1homotopy estimate of the outgoing (previous) window
0305 y = [y; -wt_pred*F0*sig_out; zeros((P-1)*N,1)]; % prediction targets: only the first block depends on the outgoing window
0306 
0307 % REMOVE the part of outgoing DWT projection in the overlapping region
0308 % on left side of streaming window...
0309 if LeftProj_cancel
0310     y = y-[PHI(:,1:(L-N));F(:,1:(L-N))]*(Psi(end-(L-N)+1:end,:)*xh_out(1:N));
0311 end
0311 end
0312 
0313 %% parameter selection
0314 % tau = sigma*sqrt(log(N));
0315 tau = max(1e-2*max(abs(A'*y)),sigma*sqrt(log(P*N))); % larger of 1% of max|A'y| and sigma*sqrt(log(P*N))
0316 
0317 maxiter = 2*P*N; % iteration cap handed to SpaRSA / YALL1
0318 err_fun = @(z) (norm(x-z)/norm(x))^2; % relative squared error against the x captured at this point
0319 
0320 
0321 %% Initialize by solving a rwt L1 problem (5 reweighting passes of l1homotopy)
0322 in = [];
0323 in.tau = tau; W = tau;
0324 in.W = W;
0325 in.delx_mode = delx_mode;
0326 in.debias = 0;
0327 in.verbose = 0;
0328 in.plots = 0;
0329 in.record = 1;
0330 in.err_fun = err_fun;
0331 tic
0332 for wt_itr = 1:5
0333     
0334     out = l1homotopy(A,y,in);
0335     xh = out.x_out;
0336     iter_bpdn = out.iter;
0337     time_bpdn = toc; % cumulative time since the tic before the loop
0338     gamma_bpdn = out.gamma;
0339     
0340     % Update weights: W = tau/(beta*|xh| + epsilon)
0341     xh_old = xh;
0342     
0343     alpha = 1; epsilon = 1;
0344     beta = M*(norm(xh,2)/norm(xh,1))^2;
0345     W = tau/alpha./(beta*abs(xh)+epsilon);
0346     
0347     W_old = W;
0348     yh = A*xh; % NOTE(review): yh is not used afterwards
0349     Atr = A'*(A*xh-y);
0350     u =  -W.*sign(xh)-Atr; % chosen so that pk_old = Atr+u equals -W.*sign(xh)
0351     pk_old = Atr+u;
0352     
0353     in = out; % warm start the next pass from this solution
0354     in.xh_old = xh;
0355     in.pk_old = pk_old;
0356     in.u = u;
0357     in.W_old = W_old;
0358     in.W = W;
0359 end
0360 W = W_old; % W_old already equals W at this point
0361 
0362 sim = sim+1; % commit window 2
0363 x_vec((sim-1)*N+1:sim*N,1) = x(1:N);
0364 xh_vec((sim-1)*N+1:sim*N,1:3) = [xh(1:N) xh(1:N) xh(1:N)];
0365 
0366 s_ind = t_ind(1:L);
0367 sig_temp = Psi*xh(1:N);
0368 sig_vec(s_ind) = sigt(1:L);
0369 sigh_vec(s_ind,1:3) = sigh_vec(s_ind,1:3)+[sig_temp sig_temp sig_temp];
0370 
0371 % average instantaneous estimates before committing to output...
0372 estimate_buffer = repmat(xh(1:(P-1)*N,1),1,P-1)/(P-1); % P-1 copies of the current estimate, pre-scaled for averaging
0373 
0374 xh_streamingRWT = xh; % every solver starts the main loop from the same estimate
0375 x_sparsa = xh;
0376 x_yall1 = xh;
0377 
0378 
0379 %% Kalman initialization (LS-Kalman baseline; only when the measurement window equals the signal window)
0380 if LM == N
0381     Pk_1 = eye(N)/(wt_pred)^2; % initial error covariance
0382     sig_kalman = sig_vec(t_ind(1:N)-N,1); % start from the true samples of the previous window (oracle)
0383     
0384     Ak = PHI(1:M,1:N); % measurement matrix of the first window
0385     yk = y(1:M); % NOTE(review): y already carries the LeftProj_cancel correction (a no-op when L == N)
0386     x_k = F0*sig_kalman; % predicted state
0387     P_k = F0*Pk_1*F0'+1/(wt_pred^2)*eye(N); % predicted covariance; process noise covariance is I/wt_pred^2
0388     PAt = P_k*Ak';
0389     Kk = PAt*(pinv(Ak*PAt+eye(M))); % Kalman gain with identity measurement-noise covariance
0390     Pk_1 = P_k-Kk*PAt'; % updated covariance
0391     sig_kalman = x_k + Kk*(yk-Ak*x_k); % updated (filtered) state
0392     
0393     sig_temp = sigh_vec(t_ind(1:N)-N,3); % previous LS-Kalman estimate
0394     y_kalman = y;
0395     y_kalman(P*M+1:P*M+N) = -wt_pred*(F0*sig_temp); % first prediction block built from the LS-Kalman estimate
0396     
0397     switch LS_Kalman
0398         case 'inst'
0399             sig_P = [PHI;F]\y_kalman; % least squares over the whole working window
0400             sig_kalman = sig_P(1:N);
0401         case 'smooth'
0402             % solves for x_1 using the prediction covariance matrix
0403             % from all previous measurements and smoothing with P-1 future measurements
0404             % minimize 1/2 (x_1-x_1|0)'*inv(P_1|0)*(x_1-x_1|0)
0405             % + \sum_{k = 1,...,P} 1/2||y_k-A_k x_k||_2^2 + wt_pred^2/2||F_k
0406             % x_k-x_{k+1}||_2^2 (the prior replaces the first prediction block)
0407             
0408             iP_k = pinv(P_k);
0409             Pmat = PHI'*PHI + F'*F; % normal-equation matrix of the stacked measurement + prediction system
0410             Pmat(1:N,1:N) = Pmat(1:N,1:N)-wt_pred^2*eye(N)+iP_k; % swap the first prediction block's x_1 term for the prior inv(P_k)
0411             Pty = PHI'*y_kalman(1:M*P)+[iP_k*(F0*sig_temp); zeros((P-1)*N,1)]; % prior mean F0*(previous LS-Kalman estimate)
0412             sig_P2 = pinv(Pmat)*Pty;
0413             sig_kalman = sig_P2(1:N);
0414         case 'filter'
0415             % no change... keep the filtered estimate computed above
0416     end
0417     
0418     sigh_vec(t_ind(1:N),3) = sig_kalman; % LS-Kalman estimate of the first window of the working window
0419 end
0420 
0421 
0422 %% GO... streaming loop: each pass slides the working window by N samples and re-solves
0423 err_ls = NaN; sim_start = sim; % err_ls stays NaN when the LS-Kalman baseline is off (LM ~= N); sim_start marks the first pass
0424 done = 0; % never set to 1; the loop ends through the break below
0425 while ~done
0426     % NOTE(review): the *_old copies below are not read later in this loop
0427     % Update the solution after updating the measurement matrix and/or the
0428     % sparse signal
0429     x_old = x;
0430     y_old = y; A_old = A;
0431     
0432     sigt_old = sigt; t_ind_old = t_ind;
0433     PHI_old = PHI;
0434     
0435     % Shift the sampling window
0436     t_ind = t_ind+N;
0437     if t_ind(end) > length(sig)
0438         break;
0439     end
0440     sigt = sig(t_ind); % Signal under the LOT window at time t
0441     if sym == 1 || sym == 2
0442         x = pinv(PSI'*PSI)*(PSI'*sigt); % Sparse LOT coefficients
0443     else
0444         x = PSI'*sigt;
0445     end
0446     
0447     % System matrix setup...
0448     % Shift up and left
0449     PHI(1:end-M,1:end-N) = PHI(M+1:end,N+1:end);
0450     % new measurement matrix
0451     Phi = genAmat_h(M,LM);
0452     PHI(end-M+1:end,end-LM+1:end) = Phi;
0453     
0454     % shift old measurements and add one new set of measurements
0455     y = PHI*sigt(1:end-(L-N));
0456     e(1:end-M) = e(M+1:end); % shift the stored noise along with the measurements
0457     e(end-M+1:end) = randn(M,1)*sigma; % fresh noise for the newest block
0458     y= y+e;
0459     
0460     A = [PHI; F]*PSI_M;
0461     
0462     A0 = A; x0 = x; y0 = y; % pristine copies; every solver pass starts from these
0463     for solver = {'l1homotopy','sparsa'} % the 'yall1' branches below are not reached with this list
0464         solver = char(solver); % the loop variable is a 1x1 cell
0465         switch solver
0466             case 'l1homotopy'
0467                 xh = xh_streamingRWT;
0468                 sig_out = sigh_vec(t_ind(1:N)-N,1);
0469             case 'sparsa'
0470                 xh = x_sparsa;
0471                 sig_out = sigh_vec(t_ind(1:N)-N,2);
0472             case 'yall1'
0473                 xh = x_yall1;
0474                 sig_out = sigh_vec(t_ind(1:N)-N,3);
0475         end
0476         y = y0; A = A0; x = x0;
0477         xh_old = xh;
0478         y = [y; -wt_pred*F0*sig_out; zeros((P-1)*N,1)]; % prediction targets: only the first block uses the outgoing window
0479         
0480         % REMOVE the part of outgoing DWT projection in the overlapping region
0481         % on left side of streaming window...
0482         if LeftProj_cancel
0483             y = y-[PHI(:,1:(L-N));F(:,1:(L-N))]*(Psi(end-(L-N)+1:end,:)*xh_old(1:N));
0484         end
0485         
0486         % Update the signal estimate (for warm start)
0487         xh(1:end-N) = xh(N+1:end); % shift the previous estimate up by one window
0488         sigh_old = PSI(1:end-L,:)*xh; % samples synthesized from the shifted estimate (last L rows excluded)
0489         % sigh_pred = [sigh_old; F_h(sigh_old(end-N+1:end),cshift,0); zeros(L-N,1)];
0490         sigh_pred = [sigh_old; F_h(sigh_old(end-N+1:end),cshift,0)]; % predict the newest window with the known (zero-drift) dynamics
0491         
0492         if sym == 3
0493             sigh_temp = F_h(sigh_pred(end-N+1:end),cshift,0);
0494             sigh_pred = [sigh_pred; sigh_temp(1:L-N)];
0495             % sigh_pred = [sigh_pred; linspace(sigh_pred(end),0,L-N)'];
0496         end
0497         if sym == 1 || sym == 2
0498             xh = pinv(PSI'*PSI)*(PSI'*sigh_pred); % Sparse LOT coefficients
0499         else
0500             xh = PSI'*sigh_pred;
0501         end
0502         xh(abs(xh)<tau/sqrt(log(P*N))) = 0; % hard-threshold small predicted coefficients
0503         %         xh_temp = xh(end-N+1:end);
0504         %         xh_temp(abs(xh_temp)<tau/sqrt(log(P*N))) = 0;
0505         %         xh(end-N+1:end) = xh_temp;
0506         if sym == 3 % truncate coefficients for overlapping wavelets...
0507             for p = 2.^(0:J)
0508                 xh((P-1)*N+N/p) = 0;
0509             end
0510         end
0511         
0512         fig(111); % view: true vs warm-start coefficients
0513         plot([x xh]);
0514         
0515         % Remove the top-left edge of the system matrix
0516         if LeftEdge_trunc
0517             % fprintf('Consider oldest set of LOT coefficients to be fully known, and remove their contribution from the measurements... \n');
0518             alpha0h = xh(1:N);
0519             xh = xh(N+1:end);
0520             y = y-A(:,1:N)*alpha0h;
0521             A = A(:,N+1:end);
0522             
0523             A_old = A; y_old = y;
0524             A(size(PHI,1)+1:size(PHI,1)+N,:) = [];
0525             y(size(PHI,1)+1:size(PHI,1)+N) = [];
0526             
0527             alpha0 = x(1:N);
0528             x = x(N+1:end);
0529         end
0530         
0531         % Update weights
0532         alpha = 1; epsilon = 1;
0533         beta = M*(norm(xh,2)/norm(xh,1))^1; % NOTE(review): the initialization passes use exponent 2 here
0534         W = tau/alpha./(beta*abs(xh)+epsilon);
0535         W_old = W;
0536         
0537         if strcmpi(solver,'l1homotopy');
0538             
0539             homotopy_mode = 'dummy'; % the only mode implemented
0540             switch homotopy_mode
0541                 case 'dummy'
0542                     % create a dummy variable...
0543                     % use homotopy on the measurements...
0544                     % in principle, we can start with any xh_old with this formulation
0545                     % and any starting value of tau or W...
0546                     gamma = find(xh);
0547                     M_trunc = size(A,1); % P*(M-1);
0548                     if length(gamma) >= M_trunc
0549                         disp('length of gamma exceeded number of rows');
0550                         [xh_sort ind_sort] = sort(abs(xh),'descend');
0551                         xh(ind_sort(M_trunc+1:end)) = 0;
0552                         gamma = ind_sort(1:M_trunc);
0553                     end
0554                     Atr = A'*(A*xh-y);
0555                     u =  -W.*sign(xh)-Atr;
0556                     pk_old = Atr+u;
0557                 otherwise
0558                     error('Go back ... no escape'); % abort: gamma/u/pk_old would otherwise be stale or undefined
0559             end
0560             
0561             % warm-start l1homotopy from its previous output with the new support and weights
0562             in = out;
0563             gamma_old = gamma;
0564             in.gamma = gamma_old;
0565             switch delx_mode
0566                 case 'mil';
0567                     % in.delx_mode = 'mil';
0568                     % The following Gram matrix and its inverse could be reused from the
0569                     % previous homotopy solve; that is not done here.
0570                     % wt BPDN homotopy update
0571                     AtAgx = A(:,gamma_old)'*A(:,gamma_old);
0572                     iAtAgx = pinv(AtAgx);
0573                     in.iAtA = iAtAgx;
0574                 case {'qr','chol'};
0575                     % in.delx_mode = 'qr';
0576                     [Q_g R_g] = qr(A(:,gamma_old),0); % economy QR; named so the compression rate R is not overwritten
0577                     in.Q = Q_g; in.R = R_g;
0578                 case 'qrM'
0579                     % in.delx_mode = 'qrM';
0580                     [Q0 R0] = qr(A(:,gamma_old));
0581                     in.Q0 = Q0; in.R0 = R0;
0582             end
0583             
0584             in.xh_old = xh;
0585             in.pk_old = pk_old;
0586             in.u = u;
0587             in.W = W;
0588             in.delx_mode = delx_mode;
0589             in.debias = 0;
0590             in.verbose = 0;
0591             in.plots = 0;
0592             in.record = 1;
0593             in.err_fun = @(z) (norm(x-z)/norm(x))^2; % error against the (possibly truncated) true coefficients
0594             tic
0595             out = l1homotopy(A,y,in);
0596             time_streamingRWT = toc;
0597             xh_streamingRWT = out.x_out;
0598             gamma_streamingRWT = out.gamma;
0599             iter_streamingRWT = out.iter;
0600             % Reconstructed signal
0601             if LeftEdge_trunc
0602                 x = [alpha0; x];
0603                 xh_streamingRWT = [alpha0h; xh_streamingRWT];
0604             end
0605         elseif  strcmpi(solver,'sparsa')
0606             %% SpaRSA
0607             x_sparsa = xh; W_sparsa = W/tau; iter_sparsa = 0; time_sparsa = 0; % defaults used when SpaRSA is skipped
0608             if norm(y) > 1e-3
0609                 psi_function = @(x,tau) soft(x,tau*W_sparsa);
0610                 phi_function = @(x) sum(abs(W_sparsa.*x));
0611                 tic;
0612                 [x_sparsa,x_debias_SpaRSA_m,obj_SpaRSA_m_cont,...
0613                     times_SpaRSA_m_cont,debias_start_SpaRSA_m,mse_SpaRSA_m,taus_SpaRSA_m, numA, numAt]= ...
0614                     SpaRSA_adpW(y,A,tau,...
0615                     'Monotone',0,...
0616                     'adp_wt',0,...
0617                     'W_new',W_sparsa,...
0618                     'Debias',0,...
0619                     'Initialization',x_sparsa,...
0620                     'StopCriterion',2,...
0621                     'ToleranceA',1e-4,...
0622                     'psi',psi_function,...
0623                     'phi',phi_function,...
0624                     'Safeguard',1,...
0625                     'MaxiterA',maxiter,...
0626                     'Verbose',0,...
0627                     'True_x',x,...
0628                     'Continuation',1,...
0629                     'Continuationsteps',-1);
0630                 
0631                 time_sparsa = toc;
0632                 iter_sparsa = (numA+numAt)/2;
0633                 error_sparsa = norm(x-x_sparsa)/norm(x);
0634             end
0635             % Reconstructed signal
0636             if LeftEdge_trunc
0637                 x = [alpha0; x];
0638                 x_sparsa = [alpha0h; x_sparsa];
0639             end
0640         elseif strcmpi(solver,'yall1')
0641             % NOTE(review): not reached with the solver list above; kept for experiments
0642             %% YALL1
0643             disp('yall1 only works when A is underdetermined');
0644             % set options
0645             digit = 4; if sigma > 0; digit = 4; end % NOTE(review): both branches give digit = 4
0646             opts = [];
0647             opts.tol = 10^(-digit);
0648             opts.weights = W/tau;
0649             opts.print = 0;
0650             opts.maxit = maxiter;
0651             opts.nonorth = 1;
0652             % opts.x0 = xh;
0653             opts.nu = 0; opts.rho = tau;
0654             tic;
0655             [x_yall1,Out_yall1] = yall1(A,y,opts);
0656             % time_yall1 = [time_yall1 Out.cputime];
0657             time_yall1 = toc;
0658             iter_yall1 = (Out_yall1.cntA+Out_yall1.cntAt)/2;
0659             err_yall1 = norm(x-x_yall1)/norm(x);
0660             % Reconstructed signal
0661             if LeftEdge_trunc
0662                 x = [alpha0; x];
0663                 x_yall1 = [alpha0h; x_yall1];
0664             end
0665             
0666             if max(abs(x_yall1)) > 500
0667                 stp = 1; % presumably a breakpoint hook for debugging
0668             end
0669         end
0670     end
0671     
0672     %% Plot DWT coeffs. on the window
0673     fig(1); subplot(211);
0674     plot([x xh_streamingRWT x_sparsa]);
0675     title('Comparison betweeen the original and reconstructed signal')
0676     
0677     %% Reconstructed signal
0678     sim = sim+1;
0679     x_vec((sim-1)*N+1:sim*N,1) = x(1:N);
0680     xh = xh_streamingRWT; % the l1homotopy estimate is the one committed to column 1
0681     
0682     % remove the oldest estimate, shift the remaining up and left, and add the new estimate
0683     estimate_buffer = [[estimate_buffer(N+1:end,2:end); zeros(N,P-2)] xh(1:end-N)/(P-1)];
0684     if avg_output
0685         xh_est = xh(1:N);
0686         xh(1:N) = sum(estimate_buffer(1:N,:),2);
0687         % fig(123); plot([xh_est xh(1:N) x(1:N)])
0688         if sim == sim_start+1 % first pass through the loop
0689             disp('output is averaged');
0690         end
0691     end
0692     xh_vec((sim-1)*N+1:sim*N,1) = xh(1:N);
0693     xh_vec((sim-1)*N+1:sim*N,2) = x_sparsa(1:N);
0694     
0695     s_ind = t_ind(1:L);
0696     sig_vec(s_ind) = sigt(1:L);
0697     sigh_vec(s_ind,1) = sigh_vec(s_ind,1)+Psi*xh(1:N); % overlap-add the committed block's synthesis
0698     sigh_vec(s_ind,2) = sigh_vec(s_ind,2)+Psi*x_sparsa(1:N);
0699     
0700     
0701     
0702     %% plot recovered signals
0703     fig(1); subplot(212)
0704     plot([sig_vec(1:s_ind(end)) sigh_vec(1:s_ind(end),1)]);
0705     
0706     drawnow;
0707     
0708     %% Kalman recursion (same predict/update as the initialization)
0709     if LM == N
0710         
0711         Ak = PHI(1:M,1:N);
0712         yk = y(1:M); % NOTE(review): y is the last solver pass's y, including the LeftProj_cancel correction
0713         x_k = F0*sig_kalman;
0714         P_k = F0*Pk_1*F0'+1/(wt_pred^2)*eye(N);
0715         PAt = P_k*Ak';
0716         Kk = PAt*(pinv(Ak*PAt+eye(M)));
0717         sig_kalman = x_k + Kk*(yk-Ak*x_k);
0718         Pk_1 = P_k-Kk*PAt';
0719         %         if mod(sim,50) == 0
0720         %             Pk_1 = eye(N)/wt_pred^2;
0721         %         end
0722         
0723         sig_temp = sigh_vec(t_ind(1:N)-N,3);
0724         y_kalman = y;
0725         y_kalman(P*M+1:P*M+N) = -wt_pred*(F0*sig_temp);
0726         
0727         switch LS_Kalman
0728             case 'inst'
0729                 sig_P = [PHI;F]\y_kalman;
0730                 sig_kalman = sig_P(1:N);
0731             case 'smooth'
0732                 % solves for x_1 using the prediction covariance matrix
0733                 % from all previous measurements and smoothing with P-1 future measurements
0734                 % minimize 1/2 (x_1-x_1|0)'*inv(P_1|0)*(x_1-x_1|0)
0735                 % + \sum_{k = 1,...,P} 1/2||y_k-A_k x_k||_2^2 + wt_pred^2/2||F_k
0736                 % x_k-x_{k+1}||_2^2 (the prior replaces the first prediction block)
0737                 
0738                 iP_k = pinv(P_k);
0739                 Pmat = PHI'*PHI + F'*F;
0740                 Pmat(1:N,1:N) = Pmat(1:N,1:N)-wt_pred^2*eye(N)+iP_k;
0741                 Pty = PHI'*y_kalman(1:M*P)+[iP_k*(F0*sig_temp); zeros((P-1)*N,1)];
0742                 sig_P2 = pinv(Pmat)*Pty;
0743                 sig_kalman = sig_P2(1:N);
0744             case 'filter'
0745                 % do nothing
0746         end
0747         
0748         sigh_vec(t_ind(1:N),3) = sig_kalman;
0749         
0750         fig(33)
0751         plot([sig_vec(s_ind(1:N),1) sigh_vec(s_ind(1:N),1) sigh_vec(s_ind(1:N),3)]);
0752         title('comparison with LS-Kalman');
0753         err_l1 = norm(sig_vec(s_ind(1:N),1)-sigh_vec(s_ind(1:N),1))^2/norm(sig_vec(s_ind(1:N),1))^2;
0754         err_ls = norm(sig_vec(s_ind(1:N),1)-sigh_vec(s_ind(1:N),3))^2/norm(sig_vec(s_ind(1:N),1))^2;
0755         if mod(sim-1,verbose) == 0 && verbose
0756             fprintf('L1 vs LS Kalman: sim %d  --%3.4g : %3.4g--\n',sim, err_l1, err_ls);
0757         end
0758     end
0759     
0760     %% Record results: [sim, tau, l1homotopy (err,iter,time), SpaRSA (err,iter,time), LS-Kalman err]
0761     SIM_stack{sim} = [sim, tau, ...
0762         norm(x-xh_streamingRWT)^2/norm(x)^2, sum(iter_streamingRWT,2), sum(time_streamingRWT,2), ...
0763         norm(x-x_sparsa)^2/norm(x)^2, sum(iter_sparsa,2), sum(time_sparsa,2), ...
0764         err_ls];
0765     
0766     % print and plot
0767     if mod(sim-1,verbose) == 0 && verbose
0768         fprintf('streaming iter. %d. tau = %3.4g, (err,iter,time): streamingRWT homotopy-%3.4g,%3.4g,%3.4g, SpaRSA-%3.4g,%3.4g,%3.4g, LS-Kalman-%3.4g\n', ...
0769             SIM_stack{sim});
0770     end
0771 end
0772 fprintf('\n');
0773 mS =  sum(cell2mat(SIM_stack),1); % column sums of the per-window result rows
0774 fprintf('Summed results: streaming_iter %d. tau = %3.4g, \n solver-(err,iter,time): \n streamingRWT homotopy-%3.4g,%3.4g,%3.4g; \n SpaRSA-%3.4g,%3.4g,%3.4g; \n LS-Kalman-%3.4g. \n', streaming_iter, mS(2:end)); % NOTE(review): tau is a sum over windows; streaming_iter is the planned window count
0775 % mS =  mean(cell2mat(SIM_stack),1);
0776 % fprintf('Average results: streaming_iter %d. tau = %3.4g, \n solver-(err,iter,time): \n streamingRWT homotopy-%3.4g,%3.4g,%3.4g; \n SpaRSA-%3.4g,%3.4g,%3.4g; \n LS-Kalman-%3.4g. \n', streaming_iter, mS(2:end));
0777 
0778 % l1homotopy-%3.4g,%3.4g,%3.4g;
0779 st_ind = 2*N; % error is measured on samples st_ind+1:N*sim, skipping the first 2*N samples
0780 err_l1homotopy = norm(sig_vec(st_ind+1:N*sim)-sigh_vec(st_ind+1:N*sim,1))^2/norm(sig_vec(st_ind+1:N*sim))^2;
0781 err_sparsa = norm(sig_vec(st_ind+1:N*sim)-sigh_vec(st_ind+1:N*sim,2))^2/norm(sig_vec(st_ind+1:N*sim))^2;
0782 if LM == N % LS-Kalman column is only filled in this mode
0783     err_kalman = norm(sig_vec(st_ind+1:N*sim)-sigh_vec(st_ind+1:N*sim,3))^2/norm(sig_vec(st_ind+1:N*sim))^2;
0784     fprintf('Signal MSE: l1homotopy-%3.4g, sparsa-%3.4g, LS-kalman-%3.4g.\n',([err_l1homotopy,err_sparsa,err_kalman]));
0785     fprintf('Signal SER (in dB): l1homotopy-%3.4g, sparsa-%3.4g, LS-kalman-%3.4g.\n',-10*log10([err_l1homotopy,err_sparsa,err_kalman]));
0786 
0787 else
0788     fprintf('Signal MSE: l1homotopy-%3.4g, sparsa-%3.4g.\n',err_l1homotopy,err_sparsa);
0789     fprintf('Signal SER (in dB): l1homotopy-%3.4g, sparsa-%3.4g.\n',-10*log10([err_l1homotopy,err_sparsa]));
0790 end
0791 
0792 %% plot signal and reconstruction error
0793 x_len = min(length(x_vec),size(xh_vec,1))-(P-1)*N; % coefficients to display; trailing windows dropped
0794 sig_len = min(length(sig_vec),size(sigh_vec,1))-(P-1)*N-L-N; % NOTE(review): the reshapes below need this to be a multiple of N
0795 x_vec1 = x_vec(1:x_len);
0796 xh_vec1 = xh_vec(1:x_len,1); % l1homotopy estimates
0797 sig_vec1 = sig_vec(1:sig_len);
0798 sigh_vec1 = sigh_vec(1:sig_len,1);
0799 
0800 fig(123);
0801 subplot(221);
0802 imagesc(reshape(sig_vec1,N,sig_len/N));
0803 axis tight;
0804 title('original signal')
0805 subplot(2,2,2)
0806 imagesc(reshape(sigh_vec1,N,sig_len/N)); % reconstructed signal (the error-plot variant is commented out below)
0807 % plot((1:sig_len)/N,sigh_vec1-sig_vec1, 'LineWidth',1);
0808 axis tight
0809 title('reconstructed signal')
0810 subplot(2,2,3);
0811 imagesc(reshape(x_vec1,N,x_len/N)); axis xy;
0812 axis tight;
0813 title('DWT coefficients');
0814 colorbar
0815 subplot(2,2,4); % error magnitude, so negative errors are not clipped by the [0, ...] color limits
0816 imagesc(reshape(abs(x_vec1-xh_vec1),N,x_len/N),[0 max(abs(x_vec1))/20]); axis xy
0817 axis tight
0818 title('reconstruction error');
0819 colorbar
0820 
0821 
0822 fig(3);
0823 % view DWT coefficients of the original and reconstructed signals
0824 alpha_vec1 = apply_DWT(sig_vec1,N,wType,J,sym);
0825 alphah_vec1 = apply_DWT(sigh_vec1,N,wType,J,sym);
0826 subplot(221); imagesc(reshape(alpha_vec1,N,length(alpha_vec1)/N));
0827 axis xy;
0828 title('original')
0829 subplot(222); imagesc(reshape(alphah_vec1,N,length(alphah_vec1)/N));
0830 axis xy;
0831 title('reconstructed');
0832 subplot(212); plot([alpha_vec1 alphah_vec1]);
0833 title('comparison');
0834 
0835 %% Per-window signal-to-error ratio (dB) of the L1 and LS-Kalman estimates
0836 if LM == N
0837     st_ind = 2*N;
0838     fig(4); clf; hold on;
0839     s_len = length(sig_vec(st_ind+1:N*sim));
0840     rshp1 = @(x) reshape(x(st_ind+1:N*sim),N,s_len/N); % one column per window of N samples
0841     err_fun = @(x) -10*log10(sum(rshp1(sig_vec-x).^2,1)./sum(rshp1(sig_vec).^2,1)); % redefines err_fun: SER in dB per window
0842     plot(err_fun(sigh_vec(:,1)),'b');
0843     % plot(err_fun(sigh_vec(:,2)),'k');
0844     plot(err_fun(sigh_vec(:,3)),'r');
0845     title('error evolution');
0846     xlabel('iteration');
0847     ylabel('ser in db');
0848     legend('L1','LS');
0849     shg
0850 end

Generated on Mon 10-Jun-2013 23:03:23 by m2html © 2005