9/17/2011

Pair-wise Weighted Euclidean distance between 2 sets of vectors

In many tasks we want to compute the pair-wise distances between two sets of vectors in a high-dimensional (p-dimensional) space. For example, in a clustering problem we want the distance from each point in a given set of N points (or vectors), represented by an N-by-p matrix X, to every other point. The simplest implementation in Matlab may look like this:

DD = zeros(N,N);
for i = 1:N
    for k = 1:N
        DD(i,k) = sum((X(k,:)-X(i,:)).^2);
    end
end

(This actually computes the SQUARE of the distance, but in many cases that is enough, e.g. when we only want to know which point is nearest to a given point; otherwise we can simply take DD = sqrt(DD). Computing the squared distance is faster, since it skips the square root.)
However, as you know, FOR loops in Matlab are slow. The usual advice is "Avoid FOR loops as much as possible".
By the way, the squared Euclidean distance between two vectors a and b is d = sum((a-b).^2), and the weighted version is simply d = sum(wts.*(a-b).^2).
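As a quick check of the two formulas, here is a tiny worked example. The vectors and weights are made-up values chosen so the arithmetic is easy to follow by hand.

```matlab
% Illustrative values only (not from any real data set)
a   = [1 2 3];
b   = [2 0 3];
wts = [1 0.5 2];

% (a-b).^2 is [1 4 0]
d  = sum((a-b).^2);        % squared Euclidean distance: 1 + 4 + 0 = 5
dw = sum(wts.*(a-b).^2);   % weighted version: 1*1 + 0.5*4 + 2*0 = 3
```

A weight of zero would drop a coordinate entirely, and a weight vector of all ones gives back the plain squared distance.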
The inner for-loop of the double-loop code above can be easily eliminated, for example:


DD = zeros(N,N);
for i = 1:N
    % equivalent to the inner for-loop of the code above:
    % sum along dimension 2 (across columns), then transpose to a 1-by-N row
    DD(i,:) = sum((X - repmat(X(i,:),N,1)).^2, 2)';
end
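To convince yourself that removing the inner loop does not change the result, the two versions can be compared on a small random matrix. This is only a sanity-check sketch; the sizes N and p are arbitrary. Note that the sum must run along dimension 2 so that each row of the difference matrix collapses to one distance.

```matlab
% Small random problem; N and p are arbitrary illustrative sizes
N = 5; p = 3;
X = rand(N, p);

% Version 1: double loop
D1 = zeros(N, N);
for i = 1:N
    for k = 1:N
        D1(i,k) = sum((X(k,:) - X(i,:)).^2);
    end
end

% Version 2: single loop, one full row of distances per iteration
D2 = zeros(N, N);
for i = 1:N
    D2(i,:) = sum((X - repmat(X(i,:), N, 1)).^2, 2)';
end

% Largest absolute disagreement between the two results
maxDiff = max(abs(D1(:) - D2(:)));
```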


But we still need the outer for-loop. Can we get rid of it as well? YES, it is possible. We just need a small trick:
(a-b)^2 = a^2 + b^2 - 2ab
This is easy for two vectors, but applying it to two matrices A and B takes some rearranging.
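To see how the trick carries over to matrices, write a_i for row i of A (m-by-p) and b_j for row j of B (n-by-p). The expansion can be sketched as follows; the symbols \alpha, \beta and the all-ones vectors \mathbf{1}_m, \mathbf{1}_n are notation introduced here for illustration.

```latex
\begin{aligned}
D_{ij} &= \sum_{k=1}^{p} (a_{ik} - b_{jk})^2
        = \sum_{k=1}^{p} a_{ik}^2 + \sum_{k=1}^{p} b_{jk}^2 - 2 \sum_{k=1}^{p} a_{ik} b_{jk}, \\
D &= \alpha \mathbf{1}_n^{\top} + \mathbf{1}_m \beta^{\top} - 2 A B^{\top},
\qquad \alpha_i = \sum_{k=1}^{p} a_{ik}^2, \quad \beta_j = \sum_{k=1}^{p} b_{jk}^2, \\
\sum_{k=1}^{p} w_k (a_{ik} - b_{jk})^2
  &= \sum_{k=1}^{p} \left( \sqrt{w_k}\, a_{ik} - \sqrt{w_k}\, b_{jk} \right)^2,
\qquad w_k \ge 0 .
\end{aligned}
```

The first two terms are row-wise sums of squares replicated across columns and rows, and the last term is a single matrix product, so no loop is needed. The third line shows why the weighted case reduces to the plain one: scale every column k of both matrices by sqrt(w_k), which requires non-negative weights.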
Check out the following code:

%***************************

function DD = pdist2(A,B,Wts)
% Find pair-wise SQUARED EUCLIDEAN distance
% or 'weighted squared Euclidean' distance
% between each point in A and each point in B
% For 2 vectors a, b:
% Squared Euclidean distance: d = sum((a-b).^2)
% Weighted version:           d = sum(wts.*(a-b).^2)
% ------------------------------
% Input:
%   A= m_by_p, m points in p-dimension
%   B= n_by_p, n points in p-dimension
%   Wts = 1_by_p, default = [1 1 ...]
% Results:
%   DD= m_by_n
% ------------------------------
% Facts: (a-b)^2= a^2+b^2-2ab
% ------------------------------
% trungd@okstate.edu, Feb 2011
% ------------------------------
% Idea from: Piotr Dollar.  [pdollar-at-caltech.edu]


[m,p1] = size(A); [n,p2] = size(B);
if(p1~=p2) % check size
    error('Must have: ncol(A)=ncol(B)=length(Wts)')
end
if nargin < 3 % standard euclidean
    AA = sum(A.*A,2);  % column m_by_1
    BB = sum(B.*B,2)'; % row 1_by_n
    DD = AA(:,ones(1,n)) + BB(ones(1,m),:) - 2*A*B';
else
    if p1 ~=length(Wts)
        error('Must have: ncol(A)=ncol(B)=length(Wts)')
    end
    sW = sqrt(Wts(:)'); % make sure it is a row; square ROOT of Wts
    A = sW(ones(1,m),:).*A;
    B = sW(ones(1,n),:).*B; % modify A,B
    % Process the same as standard Euclidean
    AA = sum(A.*A,2);  
    BB = sum(B.*B,2)'; 
    DD = AA(:,ones(1,n)) + BB(ones(1,m),:) - 2*A*B';    
end

%***************************
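Before trusting the vectorized result, it is worth comparing it with a brute-force loop on a small problem. The sketch below repeats the same steps as the function body inline, so it runs without the function file. It uses the built-in bsxfun instead of index replication, and adds a max(DD,0) clamp that is not in the original function: the expansion can leave tiny negative values from rounding, which matters if you later take sqrt(DD). The sizes are arbitrary.

```matlab
% Illustrative sizes, kept small so the brute-force loop finishes quickly
m = 50; n = 80; p = 10;
A = rand(m, p); B = rand(n, p); W = rand(1, p);

% Vectorized weighted squared distance (same steps as the function body)
sW = sqrt(W(:)');                 % row vector of square roots of the weights
As = bsxfun(@times, A, sW);       % scale each column k of A by sqrt(W(k))
Bs = bsxfun(@times, B, sW);       % same scaling for B
DD = bsxfun(@plus, sum(As.^2, 2), sum(Bs.^2, 2)') - 2*(As*Bs');
DD = max(DD, 0);                  % assumption: clamp rounding noise below zero

% Brute-force reference straight from d = sum(wts.*(a-b).^2)
Dref = zeros(m, n);
for i = 1:m
    for j = 1:n
        Dref(i,j) = sum(W .* (A(i,:) - B(j,:)).^2);
    end
end

% Largest absolute error of the vectorized result
maxErr = max(abs(DD(:) - Dref(:)));
```

The two results should agree to within floating-point rounding, not bit for bit, because the expansion adds and subtracts larger intermediate numbers than the direct formula does.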

It runs fast. For the record, my laptop is average, nothing super:

>> A = rand(1000,100);
>> B = rand(2000,100);
>> W = rand(1,100);
>> tic, D = pdist2(A,B,W);toc
Elapsed time is 0.101665 seconds.


As noted, the idea is from Piotr Dollar [pdollar-at-caltech.edu]; I modified it to handle the weighted Euclidean case (which I run into pretty often).
Other types of distance, such as cosine, L1, Mahalanobis, ..., can be handled too, but some of them do not vectorize as neatly.
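As one example of another distance that fits the same matrix-product pattern, here is a minimal sketch for the cosine distance, assuming the common convention "cosine distance = 1 - cosine similarity". This is not part of the original function; the sizes and variable names are illustrative, and rows of all zeros are not handled.

```matlab
% Illustrative sizes; the +0.1 offset guarantees no all-zero rows
m = 4; n = 6; p = 3;
A = rand(m, p) + 0.1;
B = rand(n, p) + 0.1;

% Normalize every row to unit length, then one matrix product gives
% all pair-wise cosine similarities
An = bsxfun(@rdivide, A, sqrt(sum(A.^2, 2)));
Bn = bsxfun(@rdivide, B, sqrt(sum(B.^2, 2)));
Dcos = 1 - An*Bn';                % m-by-n cosine distances

% Brute-force reference for comparison
Dref = zeros(m, n);
for i = 1:m
    for j = 1:n
        Dref(i,j) = 1 - (A(i,:)*B(j,:)') / (norm(A(i,:))*norm(B(j,:)));
    end
end
maxErr = max(abs(Dcos(:) - Dref(:)));
```

L1 distance, by contrast, has no such product form, which is one reason it is less convenient to vectorize without extra memory.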
One last thing: pdist2() is also available in the Statistics Toolbox of Matlab, and it handles much more general distances. If you have that toolbox, forget about the code here :)

