
function [newClusters,newGroups,newMap,newMeans] = clusterMerge(im, initGroups, clusterMap, initClusters, means, LAB_DIFF_THRESHOLD, PERCENTAGE_THRESHOLD)

% CLUSTERMERGE - merge neighboring clusters
%
%  [NEWCLUSTERS,NEWGROUPS,NEWMAP,NEWMEANS] = CLUSTERMERGE(IM,INITGROUPS,CLUSTERMAP,INITCLUSTERS,MEANS,LAB_DIFF_THRESHOLD,PERCENTAGE_THRESHOLD)
%
%  Input:
%       IM - MxNx3 image, where each pixel is a point in a perceptually uniform color space (e.g. LAB).
%            Only its first two dimensions are used here; colors are summarized by MEANS.
%       INITGROUPS - MxN matrix, where each element is the group number this pixel belongs to
%       CLUSTERMAP - Cx1 matrix, used to map a cluster to a group number
%       INITCLUSTERS - MxNxC matrix, each MxN is a mask for a specific cluster specifying member pixels
%       MEANS - Cx3 matrix, containing the mean LAB values for each cluster
%  ** Please note that the ordering of clusterMap, initClusters, and means MUST match
%
%       LAB_DIFF_THRESHOLD - (P2 from paper) regions whose means differ by less than this amount
%                            will be merged (previously hard-coded as 8)
%       PERCENTAGE_THRESHOLD - (P3 from paper) regions that occupy no more than this FRACTION of the
%                            image (0.05 means 5 percent) will be removed (previously hard-coded as 0.05)
%
%  Output:
%       NEWCLUSTERS,NEWGROUPS,NEWMAP,NEWMEANS are the updated data structures after merging
%       neighboring regions given the threshold defined. Also, any cluster which contains
%       too few pixels (see PERCENTAGE_THRESHOLD) will be removed and those pixels are left
%       unclassified (group number 0 in NEWGROUPS).
%
%  Notes:
%       * Two clusters are neighbors when some pixel of one is 8-connected to a pixel of the
%         other. Pixels on the image border simply have fewer neighbors (no wrap-around).
%       * When two clusters merge, the surviving cluster's mean becomes the pixel-count
%         weighted average of the two means, so it stays the mean of the merged region.
%       * Color distance is sqrt(distances(a,b)); DISTANCES is an external helper that is
%         presumably returning a squared distance - confirm against its definition.
%
% Jeff Walters & Angi Chau
% Feb 2003

% only the image size is needed from im
h = size(im,1);
w = size(im,2);

currClusters = initClusters;
currGroups = initGroups;
currMap = clusterMap;
currMeans = means;

keepGoing = 1;
iteration = 1;

% we'll keep trying to merge unless the last round found nothing to merge or we have 1 cluster
while (keepGoing && size(currClusters,3) > 1)

    disp(sprintf('iteration %d',iteration));
    iteration = iteration + 1;

    % assume this is the last round and we won't try to merge anymore
    keepGoing = 0;

    numClusters = size(currClusters,3);

    for cind = 1:numClusters
        disp(sprintf('checking cluster %d',cind));

        % a map entry of -1 marks a cluster that was already merged into
        % another one, so there is nothing left to examine
        if (currMap(cind) == -1)
            continue;
        end

        thisMask = currClusters(:,:,cind);
        memberPixels = (thisMask ~= 0);

        % an empty cluster has no group and no neighbors; it is pruned below
        if (~any(memberPixels(:)))
            continue;
        end

        % mean in LAB space of this cluster (we have this already from kmeans)
        % and the number of pixels backing that mean
        thisMean = currMeans(cind,:);
        thisCount = sum(thisMask(:));

        % which group number are we in? look up the group of the member pixels
        thisGroup = currGroups(memberPixels);
        % error check : everyone should be in same group
        if (any(thisGroup ~= thisGroup(1)))
            disp('ERROR: cluster masks and groups do not match!!!');
        end
        thisGroup = thisGroup(1);

        % all groups touching this cluster, without repetitions and without ourselves
        adjGroups = neighborGroups(memberPixels, currGroups);
        adjGroups = adjGroups(adjGroups ~= thisGroup);

        % compare the mean of each of the neighbors with the current cluster's mean
        % and if they're within the threshold, merge the clusters
        for ind = 1:length(adjGroups)

            % use the cluster map to find the cluster number for the neighbor
            % (we know their group num)
            theirGroup = adjGroups(ind);
            theirCind = find(currMap==theirGroup);

            % neighbors that belong to no cluster (e.g. unclassified pixels)
            % have no mean to compare against
            if (isempty(theirCind))
                continue;
            end

            disp(sprintf('testing neighbor cluster -> %d',theirCind));

            theirMean = currMeans(theirCind,:);
            labDist = sqrt(distances(theirMean,thisMean));

            disp(sprintf('distance between cluster %d and %d = %f', cind,theirCind,labDist));
            if (labDist < LAB_DIFF_THRESHOLD)

                disp('merging!');

                % merge clusters (zero out one of them - we'll prune later)
                theirMask = currClusters(:,:,theirCind);
                theirCount = sum(theirMask(:));
                currClusters(:,:,cind) = currClusters(:,:,cind) + theirMask;
                currClusters(:,:,theirCind) = zeros(h,w);

                % keep the running mean/count current so that further merges in
                % this same pass build on the merged cluster, not the original one
                thisMean = (thisCount*thisMean + theirCount*theirMean) ./ (thisCount + theirCount);
                thisCount = thisCount + theirCount;
                currMeans(cind,:) = thisMean;
                currMeans(theirCind,:) = zeros(1,size(currMeans,2));

                currMap(theirCind) = -1;

                % relabel the absorbed group's pixels
                currGroups(currGroups==theirGroup) = thisGroup;

                % if we found a pair to merge, we should check again
                keepGoing = 1;
            end
        end
    end % end for each cluster
end  % end while

% once we're done, we should prune off the empty clusters and also, any cluster which occupies
% a small enough number of pixels will be removed
newGroups = currGroups;
numClusters = size(currClusters,3);
newClusters = zeros(h,w,0);
newMeans = zeros(0,size(currMeans,2));
newMap = zeros(1,0);

totalPixels = h*w;

for j = 1:numClusters
    mask = currClusters(:,:,j);
    clusterSize = sum(mask(:));

    % merged-away (or otherwise empty) clusters are silently dropped
    if (clusterSize == 0)
        continue;
    end

    fraction = clusterSize/totalPixels;
    if (fraction > PERCENTAGE_THRESHOLD)   % if this is a significant cluster
        disp(sprintf('cluster %d is significant (%f percent) --> adding',j,fraction*100));
        newClusters(:,:,end+1) = mask;
        newMeans(end+1,:) = currMeans(j,:);
        newMap(end+1) = currMap(j);
    else    % if cluster too small, we have to remove the pixel associations to this group
        disp(sprintf('cluster %d is too small (%f percent) --> removing',j,fraction*100));
        newGroups(newGroups==currMap(j)) = 0;
    end
end


function adjGroups = neighborGroups(memberPixels, groups)
% NEIGHBORGROUPS - sorted unique group numbers of all pixels 8-connected to MEMBERPIXELS
%
%  MEMBERPIXELS is an MxN logical mask, GROUPS the MxN matrix of group numbers.
%  The mask is shifted by one pixel in each of the 8 directions inside a
%  zero-padded frame, so border pixels never pick up neighbors from the
%  opposite edge of the image.

[h,w] = size(memberPixels);
padded = false(h+2,w+2);
padded(2:h+1,2:w+1) = memberPixels;

touched = false(h,w);
for dr = 0:2
    for dc = 0:2
        if (dr ~= 1 || dc ~= 1)     % skip the zero shift (the pixel itself)
            touched = touched | padded((1:h)+dr,(1:w)+dc);
        end
    end
end

adjGroups = unique(groups(touched));

   

    