/*
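Detects SIFT features in two or more images, matches the features of each additional image against those of the first image, and uses those matches to warp the additional images onto the first before combining them.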

Copyright (C) 2006  Rob Hess <hess@eecs.oregonstate.edu>

@version 1.1.1-20070913
*/

#include <cv.h>
#include <cxcore.h>
#include <highgui.h>
#include <vector>

#include <stdio.h>

extern "C" {
#include "sift.h"
#include "imgfeatures.h"
#include "kdtree.h"
#include "utils.h"
#include "xform.h"
}



/* Defines */
/***********/
//#define kWindowName        "Combined Image"
//#define kScaleFactor       15.0           /* MAKE SURE TO USE AN ODD SCALE FACTOR!!!!!! */
#define kOutputFileName1   "bilateralized.jpg"   /* written by performMatting(): bilateral-filtered base image */
#define kOutputFileName2   "Combined.jpg"        /* written by performMatting(): weighted combination result */
#define kWindowName        "MotionMatte Image"   /* window destroyed at the end of performMatting()/main() */
#define kOutputFileName    "Output.jpg"

/* Neighbourhood extent (pixels) sampled around each pixel by the bilateral
   pre-filter in performMatting(). */
#define PIXEL_SURROUND 5

#define MY_PI 3.1415926535
/* Spatial (domain) sigma of the bilateral pre-filter. */
#define SIGMA_S 4
/* Range (colour) sigma of the bilateral pre-filter. */
//#define SIGMA_R1 70
#define SIGMA_R1 40
/* Range (colour) sigma used to weight each aligned image against the
   filtered base image when combining. */
#define SIGMA_R2 8

/* the maximum number of keypoint NN candidates to check during BBF search */
#define KDTREE_BBF_MAX_NN_CHKS 200

/* threshold on squared ratio of distances between NN and 2nd NN */
#define NN_SQ_DIST_RATIO_THR 0.49

/* maximum number of RANSAC homographies extracted per additional image */
#define MAX_X_ITER 3



/* Macros */
/**********/
/* Fallback definitions, used only when no earlier header has defined them.
   Every argument is parenthesized so that operands containing lower-precedence
   operators (e.g. `x | y`) are compared as a whole. Arguments may be evaluated
   twice, so do not pass expressions with side effects. */
#ifndef MIN
#define MIN(a,b)  ((a) < (b) ? (a) : (b))
#endif

#ifndef MAX
#define MAX(a,b)  ((a) > (b) ? (a) : (b))
#endif




/* Namespaces */
/**************/
using namespace std;

/* Extern Functions */
/***********************/
void showMatte( vector< vector< double **>* >, int, int );
void simpleAvg( IplImage**, vector< vector< IplImage* > > );

int countFeatures(char *keyFilePath);
/* Function Prototypes */
/***********************/
int countFeatures(char *keyFilePath);

int performMatting( IplImage* baseImage, vector< vector< IplImage* > > imageStore ) {
   
   /* Get the number of images supplied */
   unsigned int numImages = imageStore.size();
 
   IplImage *bilateralized = cvCreateImage(cvGetSize(baseImage), baseImage->depth, baseImage->nChannels);

   for (int i=0; i < baseImage->height; i++) {
     for (int j=0; j < baseImage->width; j++) {

       CvScalar basePixel = cvGet2D(baseImage, i, j);
       double totalweight = 0;

       vector< CvScalar > bilatPixels;
       vector< double > weights;

       for (int c1=-PIXEL_SURROUND; c1 < PIXEL_SURROUND; c1++) {
         for (int c2=-PIXEL_SURROUND; c2 < PIXEL_SURROUND; c2++) {

           if( i + c1 >= 0 && i + c1 < baseImage->height
            && j + c2 >= 0 && j + c2 < baseImage->width ) {
             double domaindist = c1 * c1 + c2 * c2;

             double domainweight = (1 / (2 * MY_PI * SIGMA_S)) *
               exp ( - domaindist / (2*SIGMA_S*SIGMA_S) );

             CvScalar bilateralPixel = cvGet2D(baseImage, i + c1, j + c2);

             double rangedist = (basePixel.val[0] - bilateralPixel.val[0]) * (basePixel.val[0] - bilateralPixel.val[0]) +
                 (basePixel.val[1] - bilateralPixel.val[1]) * (basePixel.val[1] - bilateralPixel.val[1]) +
                 (basePixel.val[2] - bilateralPixel.val[2]) * (basePixel.val[2] - bilateralPixel.val[2]);

             double rangeweight = (1 / (2 * MY_PI * SIGMA_R1)) *
               exp ( - rangedist/ (2*SIGMA_R1*SIGMA_R1) );

             double weight = rangeweight * domainweight;

             bilatPixels.push_back( cvScalar( bilateralPixel.val[0]*weight, bilateralPixel.val[1]*weight, bilateralPixel.val[2]*weight ) );
             //bilatPixels.push_back( cvScalar( bilateralPixel.val[0], bilateralPixel.val[1], bilateralPixel.val[2] ) );
             weights.push_back( weight );

             totalweight += weight;
           }
         }
       }

       CvScalar pixelVal = cvScalar(0,0,0);

       /*double tally = 0;

       for( unsigned int doing = 0; doing < bilatPixels.size(); doing++ ) {

         tally += weights[doing];

         if( tally >= (totalweight/2) ) {
           pixelVal.val[0] = bilatPixels[doing].val[0];
           pixelVal.val[1] = bilatPixels[doing].val[1];
           pixelVal.val[2] = bilatPixels[doing].val[2];
           break;
         }
       }*/

       for( unsigned int doing = 0; doing < bilatPixels.size(); ++doing ) {
         pixelVal.val[0] += bilatPixels[doing].val[0] / totalweight;
         pixelVal.val[1] += bilatPixels[doing].val[1] / totalweight;
         pixelVal.val[2] += bilatPixels[doing].val[2] / totalweight;
       }
       
       cvSet2D(bilateralized, i, j, pixelVal);
     }
   }

   cvSaveImage(kOutputFileName1, bilateralized);
  
   /* Create an output image */
   IplImage *outputImage = cvCreateImage(cvGetSize(baseImage), baseImage->depth, baseImage->nChannels);
   
   /* Create an output image */
   printf("Creating an output image\n");

   fprintf( stderr, "images: %d\n", numImages );

   vector< vector< double **>* > matteWeights;

   for (unsigned int imageNum=0; imageNum < numImages; ++imageNum) {
     matteWeights.push_back( new vector< double ** >() );
     for (unsigned int alignNum=0; alignNum < imageStore[imageNum].size(); ++alignNum) {
        double** weightArray = (double**)malloc( baseImage->height * sizeof( double* ) );
        matteWeights[ imageNum ]->push_back( weightArray );

        for (unsigned int temp=0; temp < baseImage->height; ++temp) {
          weightArray[ temp ] = (double*)malloc( baseImage->width * sizeof( double ) );
        }
     }
   }

   for (int i=0; i < baseImage->height; i++) {
      for (int j=0; j < baseImage->width; j++) {

         CvScalar resultPixel = cvScalar(0, 0, 0);
         double totalweight = 0;
         vector< CvScalar > bilatPixels;
         
         for (unsigned int imageNum=0; imageNum < numImages; ++imageNum) {
           for (unsigned int alignNum=0; alignNum < imageStore[imageNum].size(); ++alignNum) {

             CvScalar basePixel = cvGet2D(bilateralized, i, j);
             CvScalar currentPixelValue = cvGet2D(imageStore[imageNum][alignNum], i, j);

             double rangedist = (basePixel.val[0] - currentPixelValue.val[0]) * (basePixel.val[0] - currentPixelValue.val[0]) +
                 (basePixel.val[1] - currentPixelValue.val[1]) * (basePixel.val[1] - currentPixelValue.val[1]) +
                 (basePixel.val[2] - currentPixelValue.val[2]) * (basePixel.val[2] - currentPixelValue.val[2]);

             double rangeweight = (1 / (2 * MY_PI * SIGMA_R2)) *
               exp ( - rangedist/ (2*SIGMA_R2*SIGMA_R2) );

             bilatPixels.push_back( cvScalar( currentPixelValue.val[0]*rangeweight, currentPixelValue.val[1]*rangeweight, currentPixelValue.val[2]*rangeweight ) );

             totalweight += rangeweight;

             (*matteWeights[ imageNum ])[ alignNum ][ i ] [ j ] = rangeweight;
           }
         }

       CvScalar pixelVal = cvScalar(0,0,0);

       for( unsigned int doing = 0; doing < bilatPixels.size(); ++doing ) {
         pixelVal.val[0] += bilatPixels[doing].val[0] / totalweight;
         pixelVal.val[1] += bilatPixels[doing].val[1] / totalweight;
         pixelVal.val[2] += bilatPixels[doing].val[2] / totalweight;
       }
       
       cvSet2D(outputImage, i, j, pixelVal);
     }
   }

   showMatte( matteWeights, baseImage->height, baseImage->width );
   printf("Process complete...\n");
   
   /* Display the image within the window */
   //cvShowImage(kWindowName, outputImage);
   cvWaitKey(0);
   
   /* Save the image for project submission */
   cvSaveImage(kOutputFileName2, outputImage);
   
   /* Release the image objects */
   for (unsigned int i=0; i < numImages; i++) {
      //cvReleaseImage(&imageArray[i]);
      //cvReleaseImage(&matteArray[i]);
   }
   //cvReleaseImage(&outputImage);
   
   /* Release the window object */
   cvDestroyWindow(kWindowName);
   
   
   return 0;
   
}

/******************************************************************************\
|* Function Name: main                                                        *|
|* Prototype:     int main (int argc, char * const argv[]);                   *|
|* Author:        Brendan Duncan - brendand@stanford.edu                      *|
|* Date:          2/11/2010                                                   *|
|* Description:   argv[1] - image1 path                                       *|
|*                argv[2] - image2 path                                       *|
|*                   .                                                        *|
|*                argv[k] - image-k path                                      *|
\******************************************************************************/
int main (int argc, char * const argv[]) {
   
   /* Get the number of images supplied */
   unsigned int numImages = argc - 1;
   
   /* Create an array of images */
   IplImage **imageArray = (IplImage**)malloc( numImages * sizeof(IplImage*) );
   vector< vector< IplImage* > > imageStore;
   int* xFormedIndeces = (int*)calloc( sizeof(int) , numImages );
   
   /* Read in all of the images */
   for (unsigned int i=0; i < numImages; i++) {
      imageArray[i] = cvLoadImage(argv[i+1]);
      if (!imageArray[i]) {
         printf("Failed to import image.\n");
         printf("Example usage: ./Combine image1Path image2Path [...imagekpath]\n");
         return -1;
      } else {
         printf("Loaded image %d\n", i+1);
      }
   }
   
   /* Create a list of features for the first image */
   struct feature *feat1;
   int n1 = sift_features(imageArray[0], &feat1);
   printf("Calculated SIFT features for image 1\n");
   
   /* Build a kd-tree */
   struct kd_node* kd_root;
   kd_root = kdtree_build(feat1, n1);
   printf("Built a KD-Tree\n");
   
   /* Define the pixels which define the region of interest */
   CvPoint upperLeft = cvPoint(0, 0);
   CvPoint lowerRight = cvPoint(imageArray[0]->width - 1, imageArray[0]->height - 1);
      
  imageStore.push_back( vector<IplImage*>() );
  imageStore[0].push_back( imageArray[0] );
   
   /* Loop through the remaining images */
   for (unsigned int imageNum=1; imageNum < numImages; imageNum++) {

      struct feature *feat2;
      struct feature *feat;
      struct feature** nbrs;
      double d0, d1;
      int n2, k, m = 0;

      imageStore.push_back( vector<IplImage*>() );

      n2 = sift_features(imageArray[imageNum], &feat2);
      
      printf("Calculated SIFT features for image %d\n", imageNum+1);
      
      for(int i = 0; i < n2; i++ ) {
         feat = &feat2[i];
         k = kdtree_bbf_knn(kd_root, feat, 2, &nbrs, KDTREE_BBF_MAX_NN_CHKS);
         if( k == 2 ) {
            d0 = descr_dist_sq(feat, nbrs[0]);
            d1 = descr_dist_sq(feat, nbrs[1]);
        
            if( d0 < d1 * NN_SQ_DIST_RATIO_THR ) {
               m++;
               feat2[i].fwd_match = nbrs[0];
            }
         }
         free(nbrs);
      }

      int xIndex = 0;
      int numNextFeatures = 0;
      do {

        IplImage *xformed;
        CvMat* H;
        struct feature ** inliers;
        int numInliers;
        H = ransac_xform( feat2, n2, FEATURE_FWD_MATCH, lsq_homog, 4, 0.01, homog_xfer_err, 3.0, &inliers, &numInliers );

        if( numInliers < 4 ) break;

        xFormedIndeces[imageNum-1]++;

        numNextFeatures = 0;
        for( int i = 0; i < n2; i++ ) {
          bool bbb = 0;
          for( int j = 0; j < numInliers; j++ ) {
            if( &feat2[i] == inliers[j] ) {
              bbb = 1;
              break;
            }
          }
          if( !bbb ) {
            feat2[numNextFeatures++] = feat2[i];
          }
        }
        fprintf(stderr,"inliers:%d prev: %d next:%d\n", numInliers, n2, numNextFeatures);
        n2 = numNextFeatures;

        if(H) {
          xformed = cvCreateImage(cvGetSize(imageArray[0]), IPL_DEPTH_8U, 3);
          //cvWarpPerspective( imageArray[imageNum], xformed, H, CV_INTER_LINEAR + CV_WARP_FILL_OUTLIERS, CV_RGB(255,255,255) );
          cvWarpPerspective( imageArray[imageNum], xformed, H, CV_INTER_CUBIC + CV_WARP_FILL_OUTLIERS, CV_RGB(255,255,255) );
          //cvWarpPerspective( imageArray[imageNum], xformed, H, CV_INTER_LINEAR);
          //cvReleaseImage(&imageArray[imageNum]);
          //imageArray[imageNum] = xformed;

          imageStore[imageNum-1].push_back( cvCreateImage(cvGetSize(imageArray[0]), IPL_DEPTH_8U, 3) );
          cvCopy( xformed, imageStore[imageNum-1][( imageStore[imageNum-1].size() - 1 ) ]);

          cvReleaseMat(&H);
        }

      } while ( numNextFeatures > 4 && ++xIndex < MAX_X_ITER );

      free(feat2);
      
      printf("Calculated warped image %d\n", imageNum+1);
   }
   kdtree_release(kd_root);
   free(feat1);

   simpleAvg( imageArray, imageStore );
   
   /* Release the image objects */
   for (unsigned int i=0; i < numImages; i++) {
      //cvReleaseImage(&imageArray[i]);
   }
   //cvReleaseImage(&combinedImage);
   
   /* Release the window object */
   cvDestroyWindow(kWindowName);

   for( int picIndex = 0; picIndex < imageStore.size(); picIndex++ ) {
     for( int xIndex = 0; xIndex < xFormedIndeces[picIndex]; xIndex++ ) {

       char imgname[20];

       sprintf(imgname, "match%d_%d.jpg", xIndex, picIndex);

       cvSaveImage( imgname, imageStore[picIndex][(xIndex)] );
     }
   }

   performMatting( imageArray[0], imageStore );
   
   return 0;
   
} /* end main */





/* End of File */
