function [Ysel,lam,Xsel]=blockgreedy(A,rhs,tol,Xsel,Ysel,rho)
%BLOCKGREEDY  Block-greedy column/row selection by primal/dual residuals.
%
% Ysel = blockgreedy(A) selects columns of an MxN matrix A by primal/dual
% residuals using an all-ones rhs and the defaults tol = eps, Xsel = [],
% Ysel = [], and rho = N if N <= 9M/7, otherwise rho = max(2,log10(M)).
%
% [Ysel,lam,Xsel] = blockgreedy(A,rhs,tol,Xsel,Ysel,rho) starts the
% primal/dual residual iteration with rows Xsel and columns Ysel.
%
% Inputs (any of rhs, tol, Xsel, Ysel, rho may be omitted or passed as []):
%   A     MxN matrix.
%   rhs   vector of length M (row or column); default ones(M,1).
%   tol   stop when every candidate primal residual is below tol, or when
%         the estimated condition number of the triangular factor exceeds
%         1/tol; default eps.
%   Xsel  preselected row indices.
%   Ysel  preselected column indices (no more than the preselected rows).
%   rho   candidate-growth factor: about max(min(N,20), rho*k) candidate
%         columns are examined when k columns are already selected.
%
% Outputs:
%   Ysel  selected column indices (integer-class row vector).
%   lam   Nx1 vector, zero except lam(Ysel) = A(:,Ysel)\rhs. Requesting
%         lam also prints a short summary; the condition estimate in that
%         summary is NaN when no condition check was performed.
%   Xsel  selected row indices (integer-class row vector).
%
% Example:
% rng(0),mu=1; m=1500; n=10000;
% A=rand(m,n);
% % select columns by block-greedy
% tic,[Ysel]=blockgreedy(A);toc,
% cond( A(:,Ysel)),
% subplot(2,1,1),spy(abs(A(:,Ysel)\A)>mu),
% % select columns by QR
% tic,[q,r,e]=qr(A,0);toc,
% subplot(2,1,2),spy(abs(A(:,e(1:m))\A)>mu),
% cond(A(:,e(1:m)))
% % select columns randomly
% [~,i]=sort(rand(n,1));
% cond( A(:,i(1:m)))
%
% Ref: Leevan Ling, A fast block-greedy algorithm for quasi-optimal meshless
% trial subspace selection, SIAM Scientific Computing, to appear, 2016.
%

%% Very basic inputs check
[mA, nA]=size(A);
if nargin<2, rhs = []; end
if isempty(rhs), rhs=ones(mA,1); end
[mb, nb]=size(rhs);
if min(mb,nb)~=1, error('Second input argument is not a vector'); end
% Non-conjugating transpose: a complex row-vector rhs must keep its values.
if nb~=1, rhs=rhs.'; [mb, ~]=size(rhs); end
if mb~=mA, error('Input dimensions do not match'); end
if nargin<3, tol = []; end
if isempty(tol), tol = eps; end
if nargin<4, Xsel = []; end
if nargin<5, Ysel = []; end
if nargin<6, rho = []; end
if isempty(rho)
    if 9/7*mA>=nA
        rho = nA; % Check all unselected columns
    else
        rho = max(2,log10(mA));
    end
end
%% Initialize
kc = @(k)(max( min(nA,20), rho*k)); % #Candidate in column selection
if ~isempty(Xsel)
    Xcan = setdiff( 1:mA , Xsel );
    if isempty(Ysel)
        % No column given: start from the column with the largest energy
        % on the preselected rows.
        [~, inq] = max( sum((A(Xsel,:)).^2,1)  );
        Ysel = inq; Ycan = [1:inq-1, inq+1:nA];
    else
        if length(Xsel)<length(Ysel)
            error('Too many preselected columns');
        end
        Ycan = setdiff( 1:nA , Ysel );
    end
else
    [~, imr] = max(abs(rhs));       % get maximum primal residual
    [~, inq] = max(abs(A(imr,:)));  % get maximum dual residual
    Xsel = imr; Xcan = [1:imr-1, imr+1:mA]; % X collocation point/ Row
    Ysel = inq; Ycan = [1:inq-1, inq+1:nA]; % Y trial point / Column
end
% Compact integer index storage. uint16 saturates at 65535, so it is only
% used when every index fits; larger problems switch to uint32.
if max(mA,nA) <= double(intmax('uint16'))
    idxClass = 'uint16';
else
    idxClass = 'uint32';
end
Xsel = cast(Xsel(:)', idxClass); Xcan = cast(Xcan(:)', idxClass);
Ysel = cast(Ysel(:)', idxClass); Ycan = cast(Ycan(:)', idxClass);
% Matrix-free version: Generate and store A(Xsel,:) and A(:,Ysel)
[Q,R]= qr(A(Xsel,Ysel),0);
k    = length(Ysel);
kappaR = NaN; % stays NaN if the loop exits before any condition check
while k < min(mA,nA)
    %% Primal residual r
    x = R\(Q'*rhs(Xsel));
    if length(Xsel) < mA
        [mr, imr] = sort(abs( A(Xcan,Ysel)*x-rhs(Xcan) ),'descend');
        % mr is sorted descending, so its first entry bounds all residuals.
        if ~isempty(mr) && mr(1) < tol, disp('r small enough'); break, end
        Xcan = Xcan(imr);
    end
    %% Dual residual q
    v = -Q*(R'\x);
    [~, inq] = sort(abs( A(Xsel,Ycan)'*v ),'descend');
    Ycan = Ycan(inq);
    %% Add rows to QR
    if ~isempty(Xcan)
        Mc   = min(length(Xcan), k+1/k);  % Mc <= #rows to be added
        % Evenly strided sample of the residual-sorted row candidates.
        Mid  = length(Xcan): -floor(length(Xcan)/Mc): 1 ;
        % Matrix-free version: Generate and store A(Xcan(Mid),:)
        Xsel = [ Xsel, Xcan(Mid) ];
        Xcan = setdiff( Xcan, Xcan(Mid));
        [Q,R]= qr(A(Xsel,Ysel),0);
    end
    %% Add columns to QR
    Nc   = min( length(Ycan), kc(k));  % Nc <= #candidates
    Ns   = min( min(mA,nA)-k,  k );    % Exact #seats
    Nid  =  round(length(Ycan): -(length(Ycan)/Nc) :1 );
    QA   = Q'*A(Xsel, Ycan(Nid));
    % Pivoted QR of the candidates after projecting out the current basis;
    % the first Ns pivots are the columns that get the seats.
    [Q2,R2,e] = qr( A(Xsel, Ycan(Nid)) - Q*QA , 0 );
    Q    = [ Q, Q2(:,1:Ns)];
    R    = [ R, QA(:,e(1:Ns)); zeros(Ns,k), R2(1:Ns,1:Ns)];
    % Matrix-free version: Generate and store A(:,Ycan(Nid(e(1:Ns))))
    Ysel = [ Ysel, Ycan(Nid(e(1:Ns))) ];
    Ycan = setdiff( Ycan, Ycan(Nid(e(1:Ns))));
    %% Check condition number
    smax   = esmax(R);
    smin   = esmin(R);
    kappaR = smax(end) / smin;
    if kappaR > 1/tol
        % Bisection in (k, k+Ns): n0 is the largest size treated as
        % acceptable, n1 the smallest size known to be ill-conditioned.
        n0  = k;
        n1  = k+Ns;
        while n1-n0> 1
            cut = round((n0+n1)/2);
            if smax( cut )/esmin( R( 1:cut, 1:cut) ) > 1/tol
                n1 = cut;
            else
                n0 = cut;
            end
        end
        % Keep the acceptable side of the bracket. The last midpoint
        % tested may be the ill-conditioned one, so it must not be used.
        Ysel   = Ysel(1:n0);
        kappaR = smax( n0 )/esmin( R( 1:n0, 1:n0) ); % estimate for the kept set
        disp('Condition bad enough');    break;
    end
    % Repeat
    k = 2*k;
end
if nargout>1
    lam       = zeros(nA,1);
    lam(Ysel) = A(:,Ysel) \ rhs;
    % Remove belows to speed up
    residual  = A(:,Ysel) * lam(Ysel) - rhs;
    fprintf('\nBlock-greedy algorithm stopped after selecting %g columns out of %g\n',length(Ysel), nA)
    fprintf('(%g) inf-norm of residual = %g, 2-norm of residual = %g\n\n',kappaR,norm(residual,'inf'), norm(residual))
end


%% esmax (vector) and esmin (value) for estimating cond number
function [smax, y] = esmax(R)
% ESMAX  Running 2-norm estimates for the leading blocks of R.
% Input : R, n-by-n upper triangular (n >= 1).
% Output: smax(j), j = 1..n, is the estimated matrix 2-norm of the j-by-j
%         upper-left submatrix of R; y is the vector carried along by the
%         recurrence after the last step.
% Ref: Nicholas J. Higham, Estimating the matrix p-norm, Numer. Math. 62 (1992), 511-538.
n       = size(R,1);
y(1)    = R(1,1);
smax(1) = abs(y);
for j = 2 : n
    % Two candidate directions: the previous vector padded with a zero,
    % and the newly exposed column of R.
    B = [ [y;0] , R(1:j,j) ];
    [~,Sig,Vr] = svd( B, 0 );
    smax(j) = Sig(1);          % largest singular value of the pair
    y       = B * Vr(:,1);     % combine along the dominant right singular vector
end
%%
function  smin = esmin(R)
% ESMIN  Estimate the smallest singular value of an upper triangular matrix.
% Input : R, m-by-m upper triangular. Each R(k,k) is used as a divisor, so
%         a zero diagonal entry yields Inf/NaN intermediate values.
% Output: smin = 1/norm(y), the reciprocal of the estimated matrix 2-norm
%         of inv(R).
% Ref: Charles Van Loan, On estimating the condition of eigenvalues and eigenvectors, Linear Algebra
% and its Applications 88-89 (1987), no. 0, 715--732.
m = size(R,1);
p = zeros(m,1);  y = zeros(m,1);
% Sweep from the last row upwards. y holds the entries computed so far and
% p(1:k-1) accumulates the products R(1:k-1,j)*y(j) of the processed columns.
for k = m:-1:1
    if k == m
        c = 1; s = 0;
    else
        % (c,s) is the dominant right singular vector of W: c feeds the new
        % entry y(k), s rescales the entries already computed.
        W = [ zeros(m-k,1) , y(k+1:m)                        ;...
              1            , -p(k)                           ;...
              R(1:k-1,k)   , R(k,k)*p(1:k-1)-p(k)*R(1:k-1,k) ];
        [~,~,V] = svd( W, 0 );
        c = V(1,1);
        s = V(2,1);
    end
    y(k)     = ( c-s*p(k) ) / R(k,k);   % must use p(k) before p is updated
    y(k+1:m) = s*y(k+1:m);
    p(1:k-1) = s*p(1:k-1) + R(1:k-1,k)*y(k);
end
smin = 1/norm(y);
