%Psych 221 Final Project: EyeLab
%Mozziyar Etemadi
%2009
% Use this file with compact.avi to test algorithm performance
% For code comments, see EyeLab.m

% --- Environment reset and shared settings ---
% Clear the command window, the workspace (clear all also drops globals)
% and every open figure so each training run starts from a clean state.
clc; clear all; close all;

% Debug switches. They are declared global, presumably so helpers outside
% this script can read them too (NOTE(review): confirm which ones do).
global DEBUG_ON DEBUG_FIG;
DEBUG_FIG = 60;   % figure number for the per-segment debug plots
DEBUG_ON  = 0;    % set to 1 to plot and pause on every frame/digit

% Figure numbers and menu strings. ISOL_FIG and YESNO_STRINGS are not
% referenced in this script (presumably kept for parity with EyeLab.m).
ISOL_FIG = 101;
PREV_FIG = 100;
YESNO_STRINGS = {'Yes','No'};

fprintf('Welcome to EyeLab.  This system will help you process video\n');
fprintf('to record numbers from test equipment.  You are in training mode\n');
fprintf('which means we use compact.avi...\n');

% Ask for the training video. When the dialog is cancelled uigetfile
% returns the number 0 instead of a file name; isequal avoids comparing a
% char array against 0 element by element.
[filename, pathname] = uigetfile('*.avi', 'Please find compact.avi...');
if isequal(filename, 0)
    error('You did not choose a file!');
end

% mmreader was superseded by VideoReader (same constructor/read usage) and
% is no longer available in later MATLAB releases. Prefer VideoReader when
% it exists and fall back to mmreader on older installations.
% fullfile joins folder and name without doubling the separator.
VideoPath = fullfile(pathname, filename);
if exist('VideoReader') ~= 0
    theReader = VideoReader(VideoPath);
else
    theReader = mmreader(VideoPath);
end
% Read every frame. The indexing below treats the result as a 4-D array:
% rows x cols x color channels x frames.
theFrames = read(theReader);

% Keep rows 35:350 and columns 160:420 only -- presumably the region of
% compact.avi that shows the digits (NOTE(review): specific to that video).
theFrames = theFrames(35:350,160:420,:,:);

% DOWNSAMPLE FOR SIMULATION
% theFrames = theFrames(1:4:end,1:4:end,:,:);
% END DOWNSAMPLE

fprintf('File name: %s\n',filename);
fprintf('Number of frames: %g\n',size(theFrames,4));
fprintf('Frame rows: %g\n',size(theFrames,1));
fprintf('Frame cols: %g\n',size(theFrames,2));

% Split the cropped frame width into NumDigits equal column strips, one
% per digit; strip k spans columns ScaleFac(k):ScaleFac(k+1).
NumDigits = 1;
ScaleFac = round((0:NumDigits)./NumDigits*size(theFrames,2));
ScaleFac(1) = 1;   % MATLAB indices start at 1; round(0) would give 0

% Keep one color channel and binarize it at 80% of full scale.
% BUG FIX: read() returns uint8 frames, and uint8 arithmetic rounds, so the
% former "theFrames./255" already collapsed every pixel to 0 or 1
% (threshold effectively 128/255) before the ".8" comparison was applied.
% Compare against 0.8*255 on the integer data instead.
% The result is cast back to uint8 0/1 (not logical) on purpose: later
% regionprops calls then treat each strip as a label matrix with a single
% region, exactly as before; a logical image would be split into
% connected components and return several bounding boxes.
ImgColor = 1;
theFrames = uint8(theFrames(:,:,ImgColor,:) > 0.8*255);

% Show the middle frame as a quick visual check of crop and threshold.
clf;
imagesc(theFrames(:,:,1,round(size(theFrames,4)/2)));
axis image; colormap gray;

% Ask which digit recognizer to run; AlgChoice picks a case of the switch
% that follows. MakeMenu is not defined in this script.
fprintf('Okay.  Now that we have the digits, you need to choose an algorithm\n');
fprintf('to use to ID the digits.\n');
AlgChoice = MakeMenu('Algorithm Selection', ...
    {'Smart Strokes', 'Feature Extraction/Neural Net', 'Cross Correlator'});

% Run the chosen recognizer on every frame and time the whole pass.
% Each case fills DoneData (frames x NumDigits) with one digit per digit
% strip; strips that are essentially blank are recorded as 0.
tic;
switch AlgChoice
    case {1}
        % Smart Strokes: crop the digit to its bounding box, cut thin bands
        % where the seven display segments would lie, and let GuessDigit
        % (not defined in this script) map those bands to a digit.
        fprintf('Welcome to Smart Strokes.  Here goes...\n');
        DoneData = zeros(size(theFrames,4),NumDigits);
        % Fraction of the digit box's height/width used to size each band.
        % It never changes, so it is set once outside the loops.
        BUFFER_FACTOR = .2;
        for idx = 1:size(theFrames,4)
            for jdx = 1:NumDigits
                timg = theFrames(:,ScaleFac(jdx):ScaleFac(jdx+1),:,idx);

                % Strips with at most 2% lit pixels are treated as blank.
                if sum(timg(:))./numel(timg) > 0.02
                    % The binarized strip holds a single label (1), so
                    % regionprops returns one box: [x y width height].
                    temp = regionprops(timg,'BoundingBox');
                    temp.BoundingBox = round(temp.BoundingBox);
                    % A box at least 2.5x taller than wide is classified
                    % as "1" directly; otherwise decode the segments.
                    if temp.BoundingBox(4) < temp.BoundingBox(3)*2.5
                        bottom = max(temp.BoundingBox(2),1);
                        top = min(temp.BoundingBox(2)+temp.BoundingBox(4),size(timg,1));
                        left = max(temp.BoundingBox(1),1);
                        right = min(temp.BoundingBox(1)+temp.BoundingBox(3),size(timg,2));
                        timg = timg(bottom:top,left:right);

                        [RSz,CSz] = size(timg);

                        % Horizontal segments: top and bottom edge bands,
                        % plus a band around the vertical center that is
                        % trimmed at both sides.
                        TOP_SEG = timg(1:round(RSz*BUFFER_FACTOR/2),:);
                        BOTTOM_SEG = flipud(timg(end:-1:round(RSz*(1-BUFFER_FACTOR/2)),:));
                        MIDDLE_SEG = timg(round(RSz/2-RSz*BUFFER_FACTOR/4):round(RSz/2+RSz*BUFFER_FACTOR/4),round(CSz*BUFFER_FACTOR):end-round(CSz*BUFFER_FACTOR));

                        % Vertical segments: left and right edge bands,
                        % each split into upper and lower halves.
                        LT_SEG = timg(1:round(RSz/2),1:round(CSz*BUFFER_FACTOR));
                        LB_SEG = timg(round(RSz/2)+1:end,1:round(CSz*BUFFER_FACTOR));
                        RT_SEG = fliplr(timg(1:round(RSz/2),end:-1:round(CSz*(1-BUFFER_FACTOR))));
                        RB_SEG = fliplr(timg(round(RSz/2)+1:end,end:-1:round(CSz*(1-BUFFER_FACTOR))));

                        if(DEBUG_ON)
                            figure(DEBUG_FIG);
                            subplot(331);
                            imagesc(LT_SEG); axis image;
                            subplot(332);
                            imagesc(TOP_SEG); axis image;
                            subplot(335);
                            imagesc(MIDDLE_SEG); axis image;
                            subplot(338);
                            imagesc(BOTTOM_SEG); axis image;
                            subplot(337);
                            imagesc(LB_SEG); axis image;
                            subplot(333);
                            imagesc(RT_SEG); axis image;
                            subplot(339);
                            imagesc(RB_SEG); axis image;
                        end

                        guess = GuessDigit(TOP_SEG,BOTTOM_SEG,MIDDLE_SEG,LT_SEG,LB_SEG,RT_SEG,RB_SEG);
                    else
                        guess = 1;
                    end
                    if(DEBUG_ON)
                        figure(DEBUG_FIG);
                        subplot(334);
                        title(sprintf('%g',guess));
                        pause;
                    end

                    % A guess of -1 is replaced by the previous frame's
                    % digit (or 0 on the first frame).
                    if(guess==-1)
                        if(idx==1)
                            guess = 0;
                        else
                            guess = DoneData(idx-1,jdx);
                        end
                    end
                else
                    %the image became blank
                    guess = 0;
                end
                DoneData(idx,jdx) = guess;
            end
            if(DEBUG_ON)
                figure(PREV_FIG);
                imagesc(theFrames(:,:,1,idx));
                axis image; colormap gray;
                title(mat2str(DoneData(idx,:)));
                pause;
            end
        end
    case {2}
        fprintf('Welcome to the Neural Network Approach. This could get hairy...\n');
        error('The Neural Net approach is implemented with PullParameters.m and the Neural Net Toolbox!');
    case {3}
        % Cross correlation: RunTrainAuto (not defined in this script)
        % supplies example images per digit. Each example is cropped to its
        % bounding box, resampled to 70x40, normalized to zero mean and unit
        % std, and its DCT is stored as a template. ID_Dig{k} is scored as
        % digit k-1 below.
        fprintf('Welcome to the Cross Correlation method.\n');
        DoneData = zeros(size(theFrames,4),NumDigits);
        ID_Dig = RunTrainAuto(theFrames,ScaleFac,20);
        for idx = 1:10
            for jdx = 1:length(ID_Dig{idx}.TheImage)
                timg = ID_Dig{idx}.TheImage{jdx};
                temp = regionprops(timg,'BoundingBox');
                temp.BoundingBox = round(temp.BoundingBox);
                bottom = max(temp.BoundingBox(2),1);
                top = min(temp.BoundingBox(2)+temp.BoundingBox(4),size(timg,1));
                left = max(temp.BoundingBox(1),1);
                right = min(temp.BoundingBox(1)+temp.BoundingBox(3),size(timg,2));
                timg = timg(bottom:top,left:right);
                [X,Y] = meshgrid(linspace(1,100,size(timg,2)),linspace(1,100,size(timg,1)));
                [XI,YI] = meshgrid(linspace(1,100,40),linspace(1,100,70));
                timg = interp2(X,Y,double(timg),XI,YI);
                ID_Dig{idx}.TheImage{jdx} = (timg-mean(timg(:)))/std(timg(:));
                ID_Dig{idx}.TheDCT{jdx} = dct(ID_Dig{idx}.TheImage{jdx});
            end
        end
        for idx = 1:size(theFrames,4)
            % Progress counter; the nine backspaces after the digit loop
            % erase the nine characters printed by '%4g/%4g'.
            fprintf('%4g/%4g',idx,size(theFrames,4));
            for jdx = 1:NumDigits
                timg = theFrames(:,ScaleFac(jdx):ScaleFac(jdx+1),:,idx);
                if sum(timg(:))./numel(timg) > 0.02
                    temp = regionprops(timg,'BoundingBox');
                    temp.BoundingBox = round(temp.BoundingBox);
                    bottom = max(temp.BoundingBox(2),1);
                    top = min(temp.BoundingBox(2)+temp.BoundingBox(4),size(timg,1));
                    left = max(temp.BoundingBox(1),1);
                    right = min(temp.BoundingBox(1)+temp.BoundingBox(3),size(timg,2));
                    timg = timg(bottom:top,left:right);
                    [X,Y] = meshgrid(linspace(1,100,size(timg,2)),linspace(1,100,size(timg,1)));
                    [XI,YI] = meshgrid(linspace(1,100,40),linspace(1,100,70));
                    timg = interp2(X,Y,double(timg),XI,YI);
                    % The test strip's DCT is the same for every template,
                    % so compute it once instead of once per template.
                    TestDCT = dct(timg);
                    % Score each digit by the mean inner product between
                    % the test DCT and that digit's template DCTs.
                    CorrVec = zeros(1,10);
                    for kdx = 1:10
                        for ldx = 1:length(ID_Dig{kdx}.TheDCT)
                            CorrVec(kdx) = CorrVec(kdx) + sum(sum((TestDCT.*ID_Dig{kdx}.TheDCT{ldx})));
                        end
                        CorrVec(kdx) = CorrVec(kdx)./length(ID_Dig{kdx}.TheDCT);
                    end
                    [junk,WhichDig] = max(CorrVec);
                    DoneData(idx,jdx) = WhichDig - 1;
                    % Debug view lives here because CorrVec and WhichDig
                    % are only computed for non-blank strips (previously
                    % a blank strip made this fail or show stale values).
                    if(DEBUG_ON)
                        figure(PREV_FIG);
                        subplot(211);
                        bar(0:9,CorrVec);
                        subplot(212);
                        imagesc(timg); axis image;
                        title(sprintf('%g',WhichDig-1));
                    end
                else
                    %blank digit
                    DoneData(idx,jdx) = 0;
                end
            end
            fprintf('\b\b\b\b\b\b\b\b\b');
        end
        fprintf('\n');
    otherwise
        % NOTE(review): assumes MakeMenu yields 1..3 for a valid pick. Any
        % other value (e.g. a dismissed menu) previously fell through,
        % leaving DoneData undefined and failing later with an unrelated
        % error.
        error('Unrecognized algorithm choice: %s', mat2str(AlgChoice));
end

toc

% --- Score the run against the known contents of compact.avi ---

% Join each frame's per-strip digits into one number (e.g. [4 2] -> 42).
% sprintf reuses the '%1g' format for every element of the row.
DoneVec = zeros(size(DoneData,1),1);
for idx = 1:length(DoneVec)
    DoneVec(idx) = str2double(sprintf('%1g',DoneData(idx,:)));
end
% Time stamp (seconds) of each frame, assuming a constant frame rate.
% Dot access works for both mmreader and VideoReader objects.
TimeVec = (0:length(DoneVec)-1)./theReader.FrameRate;

% Ground truth for compact.avi: the digits 0..9 in order, 100 frames each.
% Fail clearly if the frame count differs; otherwise the comparisons below
% would raise dimension/index errors or silently broadcast.
FRAMES_PER_DIGIT = 100;
NumTruth = 10*FRAMES_PER_DIGIT;
if numel(DoneVec) ~= NumTruth
    error('Expected %d frames from compact.avi but processed %d; the built-in ground truth does not apply.', ...
        NumTruth, numel(DoneVec));
end
CorrectVec = repmat(0:9,FRAMES_PER_DIGIT,1);
CorrectVec = CorrectVec(:);

MSE = mean((CorrectVec - DoneVec).^2)

% One-hot target/output matrices (row = digit+1, column = frame) for
% plotconfusion.
TargMat = zeros(10,NumTruth);
OutMat = zeros(10,NumTruth);
for idx = 1:NumTruth
    TargMat(CorrectVec(idx)+1,idx) = 1;
    OutMat(DoneVec(idx)+1,idx) = 1;
end

figure;
imagesc([0 9.99],[0 9],OutMat);
axis image; grid on;
plotconfusion(TargMat,OutMat);