# -*- coding: utf-8 -*-

import numpy
import cv
import scipy.weave as weave

from cvarray import *



class OpticalFlow:
  """Dense optical flow estimator - a re-implementation in Python of an algorithm provided by Jian Li, using OpenCV and scipy.weave.

  Feed frames in order with nextFrame(). Each frame is converted to greyscale,
  smoothed with a 3x3 Gaussian and pushed through a recursive temporal filter.
  From the third frame onwards every call also produces a motion mask
  (getMask), horizontal/vertical velocity images (getU/getV) and the strongest
  local maxima of smoothed speed (getMaxima)."""
  def __init__(self, width, height, maxima = 8):
    """width/height give the size of the frames that will be passed to nextFrame; maxima is how many speed maxima to report per frame."""
    self.width = width
    self.height = height

    # Setup parameters with their defaults (call updateParam to change them)...
    self.updateParam()

    # Create the output images...
    size = (self.width,self.height)
    self.mask = cv.CreateImage(size, cv.IPL_DEPTH_8U,  1)
    self.u    = cv.CreateImage(size, cv.IPL_DEPTH_32F, 1)
    self.v    = cv.CreateImage(size, cv.IPL_DEPTH_32F, 1)

    cv.SetZero(self.mask)
    cv.SetZero(self.u)
    cv.SetZero(self.v)

    # Create the runtime images...
    # (grey and real are temporary, but kept around to save being created/destroyed each time.)
    self.grey  = cv.CreateImage(size, cv.IPL_DEPTH_8U, 1)
    self.real  = cv.CreateImage(size, cv.IPL_DEPTH_32F, 1)

    # Spatially smoothed current frame (blur0) and the two frames before it (blur1, blur2).
    self.blur0 = cv.CreateImage(size, cv.IPL_DEPTH_32F, 1)
    self.blur1 = cv.CreateImage(size, cv.IPL_DEPTH_32F, 1)
    self.blur2 = cv.CreateImage(size, cv.IPL_DEPTH_32F, 1)

    cv.SetZero(self.grey)
    cv.SetZero(self.blur0)
    cv.SetZero(self.blur1)
    cv.SetZero(self.blur2)

    # Countdown of frames to swallow before computing flow: need 3 frames, so skip 2 and work on the third.
    self.enoughData = 2

    # State of the recursive temporal filters: y_2_* is the second-order filter, y_3_* the first-order
    # filter applied on top of it; suffix _0 is the current frame, _1 the previous, _2 the one before.
    self.y_2_2 = cv.CreateImage(size, cv.IPL_DEPTH_32F, 1)
    self.y_3_1 = cv.CreateImage(size, cv.IPL_DEPTH_32F, 1)
    self.y_2_1 = cv.CreateImage(size, cv.IPL_DEPTH_32F, 1)
    self.y_3_0 = cv.CreateImage(size, cv.IPL_DEPTH_32F, 1)
    self.y_2_0 = cv.CreateImage(size, cv.IPL_DEPTH_32F, 1)

    cv.SetZero(self.y_2_2)
    cv.SetZero(self.y_3_1)
    cv.SetZero(self.y_2_1)
    cv.SetZero(self.y_3_0)
    cv.SetZero(self.y_2_0)

    # Create the temporary images...
    self.temp1 = cv.CreateImage(size, cv.IPL_DEPTH_32F, 1)
    self.temp2 = cv.CreateImage(size, cv.IPL_DEPTH_32F, 1)
    self.temp3 = cv.CreateImage(size, cv.IPL_DEPTH_32F, 1)

    # Derivatives (x, y, t) and their pairwise products...
    self.ex  = cv.CreateImage(size, cv.IPL_DEPTH_32F, 1)
    self.ey  = cv.CreateImage(size, cv.IPL_DEPTH_32F, 1)
    self.et  = cv.CreateImage(size, cv.IPL_DEPTH_32F, 1)
    self.exx = cv.CreateImage(size, cv.IPL_DEPTH_32F, 1)
    self.exy = cv.CreateImage(size, cv.IPL_DEPTH_32F, 1)
    self.eyy = cv.CreateImage(size, cv.IPL_DEPTH_32F, 1)
    self.ext = cv.CreateImage(size, cv.IPL_DEPTH_32F, 1)
    self.eyt = cv.CreateImage(size, cv.IPL_DEPTH_32F, 1)

    # Create a 3x3 rectangular structuring element for erosion/dilation of the motion mask...
    self.strucElem = cv.CreateStructuringElementEx(3, 3, 1, 1, cv.CV_SHAPE_RECT)

    # Array to store the maxima in, one row per maximum: x (column), y (row), speed.
    self.maxima = numpy.zeros((maxima,3),dtype=numpy.float32)
    # Scratch buffers for the weave speed/blur code, indexed [row, column].
    self.speed = numpy.zeros((self.height,self.width),dtype=numpy.float32)
    self.temp = numpy.zeros((self.height,self.width),dtype=numpy.float32)


  def updateParam(self, winSize = 11, timeScale = 2.0, motionThreshold = 2.0, erode = 0, dilate = 0):
    """Updates the parameters. Whilst in principle it can be called once nextFrame has been called, you probably shouldn't.

    Every call resets all five parameters, so pass every one you care about.
    winSize - size of the Gaussian window used to smooth the derivative products, i.e. 11 for an 11x11 window; must be odd.
    timeScale - multiplies the time-differencing step and sets the temporal filter coefficients.
    motionThreshold - temporal derivative a pixel must exceed before optical flow is computed there.
    erode, dilate - iteration counts for eroding/dilating the motion mask with a 3x3 rectangle (0 disables; dilation runs first)."""
    self.winSize = winSize
    self.timeScale = timeScale
    self.motionThreshold = motionThreshold
    self.erode = erode
    self.dilate = dilate

    # Coefficients of the recursive temporal filters used in nextFrame...
    self.q = self.timeScale / (self.timeScale + 2.0)
    self.r = (self.timeScale - 2.0) / (self.timeScale + 2.0)
    self.rDbl = 2.0 * self.r
    self.qSqr = self.q * self.q
    self.rSqr = self.r * self.r


  def nextFrame(self,image):
    """Takes an OpenCV colour image as input and updates the data accordingly; from the third frame on this also recomputes mask, u, v and maxima."""
    # Put this frame into the mix, shifting the previous two back in time...
    cv.Copy(self.blur1, self.blur2)
    cv.Copy(self.blur0, self.blur1)
    # NOTE(review): legacy OpenCV colour images are usually BGR; CV_RGB2GRAY swaps the red/blue weights - confirm the input channel order.
    cv.CvtColor(image, self.grey, cv.CV_RGB2GRAY)
    cv.ConvertScale(self.grey, self.real, 1.0, 0.0)
    cv.Smooth(self.real, self.blur0, cv.CV_GAUSSIAN, 3, 0, 0, 0)

    # For the first frames there is nothing to calculate against; otherwise run the optical flow algorithm...
    if self.enoughData>0:
      self.enoughData -= 1
    else:
      # Second-order recursive temporal filter:
      #   y_2_0 = q^2*(blur0 + 2*blur1 + blur2) - 2r*y_2_1 - r^2*y_2_2
      cv.ConvertScale(self.blur1, self.temp1, 2.0, 0.0)
      cv.Add(self.blur0, self.temp1, self.temp1)
      cv.Add(self.temp1, self.blur2, self.temp1)
      cv.ConvertScale(self.temp1, self.temp1, self.qSqr, 0.0)

      cv.ConvertScale(self.y_2_1, self.temp2, self.rDbl, 0.0)
      cv.ConvertScale(self.y_2_2, self.temp3, self.rSqr, 0.0)

      cv.Sub(self.temp1, self.temp2, self.temp1)
      cv.Sub(self.temp1, self.temp3, self.y_2_0)

      # First-order recursive filter applied on top of it:
      #   y_3_0 = q*(y_2_0 + y_2_1) - r*y_3_1
      cv.Add(self.y_2_0, self.y_2_1, self.temp1)
      cv.ConvertScale(self.temp1, self.temp1, self.q, 0.0)

      cv.ConvertScale(self.y_3_1, self.temp2, self.r, 0.0)
      cv.Sub(self.temp1, self.temp2, self.y_3_0)

      # Move y_ images back in time ready for next frame...
      cv.Copy(self.y_2_1, self.y_2_2)
      cv.Copy(self.y_2_0, self.y_2_1)
      cv.Copy(self.y_3_0, self.y_3_1)

      # Do the optical flow algorithm...
      self.__opticalFlow(self.y_3_0, self.y_2_0)


  def __opticalFlow(self,img1,img2):
    """Computes the motion mask, velocities (u, v) and speed maxima from the temporally filtered images img1 (y_3_0) and img2 (y_2_0)."""
    # Spatial derivatives of img1 (3x3 Sobel)...
    cv.Sobel(img1, self.ex, 1, 0, 3)
    cv.Sobel(img1, self.ey, 0, 1, 3)

    # Temporal derivative: scaled difference of the two filtered images...
    cv.Sub(img2, img1, self.temp1)
    cv.ConvertScale(self.temp1, self.et, self.timeScale, 0.0)

    # Use the temporal derivative to generate the motion mask...
    # (In weave, A2(i,j) is a[i,j] and Na[0] is a.shape[0], so the first index is the ROW.)
    aet = cv2array(self.et)
    amask = numpy.empty((self.height,self.width),dtype=numpy.uint8)

    # NOTE(review): only positive et counts as motion; a symmetric test would use fabs(AET2(r,c)) - confirm intent.
    maskCreate = """
    for (int r=0;r<Naet[0];r++)
    {
     for (int c=0;c<Naet[1];c++)
     {
      if (AET2(r,c)>threshold) AMASK2(r,c) = 1;
                          else AMASK2(r,c) = 0;
     }
    }
    """

    threshold = self.motionThreshold
    weave.inline(maskCreate,['aet','amask','threshold'])
    self.mask = array2cv(amask)

    # If requested dilate/erode the motion mask...
    if self.dilate!=0: cv.Dilate(self.mask, self.mask, self.strucElem, self.dilate)
    if self.erode!=0: cv.Erode(self.mask, self.mask, self.strucElem, self.erode)
    if self.dilate!=0 or self.erode!=0:
      # array2cv may copy the buffer instead of sharing it, in which case amask would not see the
      # morphology; re-read it so the velocity and maxima steps use the same mask as the derivatives.
      amask = cv2array(self.mask)

    # Mask the derivatives...
    cv.ConvertScale(self.mask, self.temp1, 1.0, 0.0)
    cv.Mul(self.ex, self.temp1, self.ex, 1.0)
    cv.Mul(self.ey, self.temp1, self.ey, 1.0)
    cv.Mul(self.et, self.temp1, self.et, 1.0)

    # Multiply derivatives as needed...
    cv.Mul(self.ex, self.ex, self.exx, 1.0)
    cv.Mul(self.ex, self.ey, self.exy, 1.0)
    cv.Mul(self.ey, self.ey, self.eyy, 1.0)
    cv.Mul(self.ex, self.et, self.ext, 1.0)
    cv.Mul(self.ey, self.et, self.eyt, 1.0)

    # Smooth the derivative products over the window...
    cv.Smooth(self.exx, self.exx, cv.CV_GAUSSIAN, self.winSize, 0, 0, 0)
    cv.Smooth(self.exy, self.exy, cv.CV_GAUSSIAN, self.winSize, 0, 0, 0)
    cv.Smooth(self.eyy, self.eyy, cv.CV_GAUSSIAN, self.winSize, 0, 0, 0)
    cv.Smooth(self.ext, self.ext, cv.CV_GAUSSIAN, self.winSize, 0, 0, 0)
    cv.Smooth(self.eyt, self.eyt, cv.CV_GAUSSIAN, self.winSize, 0, 0, 0)

    # Calculate the final velocities...
    aexx = cv2array(self.exx)
    aexy = cv2array(self.exy)
    aeyy = cv2array(self.eyy)
    aext = cv2array(self.ext)
    aeyt = cv2array(self.eyt)
    au = numpy.zeros((self.height,self.width),dtype=numpy.float32)
    av = numpy.zeros((self.height,self.width),dtype=numpy.float32)

    # Per masked pixel, solve [exx exy; exy eyy]*(u,v) = (ext,eyt) by Cramer's rule.
    # Pixels outside the mask or with a near-singular matrix keep u = v = 0.
    velCalc = """
    for (int r=0;r<Namask[0];r++)
    {
     for (int c=0;c<Namask[1];c++)
     {
      if (AMASK2(r,c)!=0)
      {
       float det = AEXX2(r,c)*AEYY2(r,c) - AEXY2(r,c)*AEXY2(r,c);
       if (fabs(det)>1e-3)
       {
        AU2(r,c) = (AEYY2(r,c)*AEXT2(r,c) - AEXY2(r,c)*AEYT2(r,c))/det;
        AV2(r,c) = (AEXX2(r,c)*AEYT2(r,c) - AEXY2(r,c)*AEXT2(r,c))/det;
       }
      }
     }
    }
    """

    weave.inline(velCalc,['amask','aexx','aexy','aeyy','aext','aeyt','au','av'])


    # Find the top n maxima of the (masked, Gaussian-smoothed) speed...
    maximaFind = """
    // First calculate speed...
    for (int c=0;c<Nspeed[1];c++)
    {
     for (int r=0;r<Nspeed[0];r++)
     {
      SPEED2(r,c) = sqrt(AU2(r,c)*AU2(r,c) + AV2(r,c)*AV2(r,c));
     }
    }

    // Blur it with a Gaussian...
    // Make the (unnormalised) half kernel...
    const int kSize = 8;
    const float sd = 4.0;
    float kernel[kSize];
    for (int i=0;i<kSize;i++) kernel[i] = exp(-float(i*i)/(2.0*sd*sd));

    // speed to temp, along the row index...
    for (int c=0;c<Nspeed[1];c++)
    {
     for (int r=0;r<Nspeed[0];r++)
     {
      TEMP2(r,c) = 0.0;
      float div = 0.0;

      int low = r - kSize + 1;  if (low<0) low = 0;
      int high = r + kSize - 1; if (high>Nspeed[0]-1) high = Nspeed[0]-1;

      for (int i=low;i<=high;i++)
      {
       float w = kernel[abs(r-i)];
       TEMP2(r,c) += w*SPEED2(i,c)*AMASK2(i,c);
       div += w;
      }
      TEMP2(r,c) /= div;
     }
    }

    // temp to speed, along the column index...
    for (int c=0;c<Nspeed[1];c++)
    {
     for (int r=0;r<Nspeed[0];r++)
     {
      SPEED2(r,c) = 0.0;
      float div = 0.0;

      int low = c - kSize + 1;  if (low<0) low = 0;
      int high = c + kSize - 1; if (high>Nspeed[1]-1) high = Nspeed[1]-1;

      for (int i=low;i<=high;i++)
      {
       float w = kernel[abs(c-i)];
       SPEED2(r,c) += w*TEMP2(r,i)*AMASK2(r,i);
       div += w;
      }
      SPEED2(r,c) /= div;
     }
    }

    // Now find strict local maxima over the 8-neighbourhood (border pixels skipped)...
    // Invariant: once filled, entry maxFound-1 holds the lowest speed kept so far.
    int maxFound = 0;
    for (int c=1;c<Nspeed[1]-1;c++)
    {
     for (int r=1;r<Nspeed[0]-1;r++)
     {
      if (AMASK2(r,c)==0) continue;

      float s = SPEED2(r,c);
      if (s<=SPEED2(r-1,c-1)) continue;
      if (s<=SPEED2(r-1,c)) continue;
      if (s<=SPEED2(r-1,c+1)) continue;
      if (s<=SPEED2(r,c-1)) continue;
      if (s<=SPEED2(r,c+1)) continue;
      if (s<=SPEED2(r+1,c-1)) continue;
      if (s<=SPEED2(r+1,c)) continue;
      if (s<=SPEED2(r+1,c+1)) continue;

      // Insert it into the maxima array as (x = column, y = row, speed)...
      bool inserted = true;
      if (maxFound<Nmaxima[0])
      {
       MAXIMA2(maxFound,0) = c;
       MAXIMA2(maxFound,1) = r;
       MAXIMA2(maxFound,2) = s;
       maxFound += 1;
      }
      else if (maxFound>0 && s>MAXIMA2(maxFound-1,2))
      {
       // Array full: replace the current lowest entry (maxFound>0 guards a zero-sized array).
       MAXIMA2(maxFound-1,0) = c;
       MAXIMA2(maxFound-1,1) = r;
       MAXIMA2(maxFound-1,2) = s;
      }
      else inserted = false;

      // Move the lowest score item to the bottom of the array for the next test...
      if (inserted)
      {
       int lowest = 0;
       for (int i=1;i<maxFound;i++)
       {
        if (MAXIMA2(i,2)<MAXIMA2(lowest,2)) lowest = i;
       }
       if (lowest!=(maxFound-1))
       {
        for (int i=0;i<3;i++)
        {
         float swap = MAXIMA2(lowest,i);
         MAXIMA2(lowest,i) = MAXIMA2(maxFound-1,i);
         MAXIMA2(maxFound-1,i) = swap;
        }
       }
      }
     }
    }
    """

    # A fresh zeroed array each frame, so unused rows read (0,0,0) instead of keeping maxima from earlier frames.
    maxima = numpy.zeros_like(self.maxima)
    speed = self.speed
    temp = self.temp
    weave.inline(maximaFind,['amask','au','av','speed','temp','maxima'])

    # Sort fastest first, as getMaxima documents...
    self.maxima = maxima[numpy.argsort(maxima[:,2])[::-1],:]


    # Convert the numpy arrays into cv images...
    self.u = array2cv(au)
    self.v = array2cv(av)



  def getMask(self):
    """Returns the motion mask for the last frame as an image - 1 where motion was detected (and flow computed), 0 elsewhere."""
    return self.mask

  def getU(self):
    """Returns the x (horizontal) motion for the last frame added, if available, as an image; 0 where no flow was computed."""
    return self.u

  def getV(self):
    """Returns the y (vertical) motion for the last frame added, if available, as an image; 0 where no flow was computed."""
    return self.v

  def getMaxima(self):
    """Returns the speed maxima seen in the last frame as a float32 matrix where each row is an entry of x (column), y (row), speed.

    The first row is the fastest, and so on; rows left unused because fewer maxima were found are all zero.
    A new array is produced for each frame."""
    return self.maxima
