
// Copyright (C) 2008 Eric Chassande-Mottin, CNRS (France)

// This program is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation; either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program; if not, see <http://www.gnu.org/licenses/>.

#include <octave/config.h>
#include <octave/defun-dld.h>
#include <octave/error.h>
#include <octave/gripes.h>
#include <octave/oct-obj.h>
#include <octave/pager.h>
#include <octave/quit.h>
#include <octave/variables.h>

#include <limits>
#include <vector>
using namespace std;

// Module-level DP parameters, assigned by dpw() from its arguments before any
// of the helpers below runs (the helpers only read them):
//   Nr1        - 1st order regularity: max |slope offset| of a chirplet,
//                in units of `step` index positions
//   step       - index-position step between neighbouring slopes
//   Nr2        - 2nd order regularity: max slope change between a chirplet
//                and its predecessor (see find_max)
//   debug_flag - debug level; >= 2 forces back tracing, == 2 also runs check_trace
// Because these are file-scope statics rewritten on every call, dpw is not reentrant.
static int Nr1,step,Nr2,debug_flag;

/* Return 1 when n is a positive power of two (1, 2, 4, ...), 0 otherwise. */
static int is_power_of_two(int n)
{
  if (n <= 0)
    return 0;

  // strip the trailing zero bits; a power of two leaves exactly 1 behind
  int rest = n;
  while ((rest & 1) == 0)
    rest >>= 1;

  return (rest == 1) ? 1 : 0;
}

/* Smaller of two ints. */
inline static int min ( int x, int y)
{
  if (y < x)
    return y;
  return x;
}


/* Smallest element of v, converted to int.

   An empty array has no minimum. This used to return octave_NaN, but
   converting a NaN to int is undefined behaviour in C++ (on x86 it
   typically produced INT_MIN). That value is now returned explicitly, so
   a range check such as `min (v) < 0` still rejects an empty array. */
inline static int min (const int32NDArray& v)
{
  int l = v.length ();

  if (l == 0)
    return std::numeric_limits<int>::min ();

  int retval = v(0);
  for (int i = 1; i < l; i++)
    if ((int) v(i) < retval)
      retval = (int) v(i);

  return retval;
}

/* Largest element of v, converted to int.

   An empty array has no maximum. This used to return octave_NaN, but
   converting a NaN to int is undefined behaviour in C++ (on x86 it
   typically produced INT_MIN). INT_MIN, the conventional "max of nothing",
   is now returned explicitly. */
inline static int max (const int32NDArray& v)
{
  int l = v.length ();

  if (l == 0)
    return std::numeric_limits<int>::min ();

  int retval = v(0);
  for (int i = 1; i < l; i++)
    if ((int) v(i) > retval)
      retval = (int) v(i);

  return retval;
}

/* Element-wise x ./ y, where a zero denominator yields 0 instead of Inf/NaN.
   On a length mismatch a nonconformant error is reported and an empty
   vector is returned. */
static ColumnVector divide (const ColumnVector& x, const ColumnVector& y)
{
  const int nx = x.length ();
  const int ny = y.length ();

  if (nx != ny)
    {
      gripe_nonconformant ("divide", nx, ny);
      return ColumnVector ();
    }

  if (nx == 0)
    return ColumnVector (0);

  ColumnVector quotient (nx);
  for (int k = 0; k < nx; k++)
    {
      const double den = y(k);
      quotient(k) = (den == 0.0) ? 0.0 : x(k) / den;
    }

  return quotient;
}


/* Path integral of one chirplet over the current interval.

   The chirplet runs from row m1 at column 0 of `data` towards row m0 at
   column b = data.columns(): at column n it sits on the fractional row
   m1 + (m0 - m1) * n / b, and the value there is linearly interpolated
   between the two neighbouring rows.

   On return *numer holds the integral along that path in `data` and
   *denom the integral along the same path in `norm` (both 0 when b == 0).
   Assumes m0 and m1 are valid row indices of data and norm.

   Fix: the horizontal case (m0 == m1) used two nested loops sharing the
   index n. With b >= 2 this happened to sum the row once, but with b == 1
   the outer loop never ran and both sums stayed 0. It is now a single sum. */
static void
compute_statistic (double *numer, double *denom, int m0, int m1, const  Matrix & data, const Matrix & norm)
{
  const int b = data.columns ();

  *numer = 0.0;
  *denom = 0.0;

  if (b == 0)
    return;

  if (m1 == m0)
    {
      // horizontal chirplet: plain sum along row m1
      for (int n = 0; n < b; n++)
        {
          *numer += data(m1, n);
          *denom += norm(m1, n);
        }
      return;
    }

  // slope of the path, in rows per column
  const double a = (double) (m0 - m1) / (double) b;

  *numer = data(m1, 0);
  *denom = norm(m1, 0);
  for (int n = 1; n < b; n++)
    {
      // for 0 < n < b the fractional row lies strictly between m1 and m0,
      // so it is >= 0 (truncation == floor) and m+1 stays a valid row
      double dm = m1 + a * n;
      const int m = (int) dm;
      dm -= m;                      // interpolation weight of row m+1
      *numer += (1 - dm) * data(m, n) + dm * data(m + 1, n);
      *denom += (1 - dm) * norm(m, n) + dm * norm(m + 1, n);
    }
}

/* Initialise the DP on the first interval.

   For every position pos of `index`, store in current/cur_norm the path
   integrals (in data and norm) of the chirplets ending at row index(pos):
   the horizontal one at label(pos), and those starting at positions
   pos -/+ i*step (i = 1..Nr1, start position inside [0, N)) at labels
   label(pos) -/+ i. */
static void 
DP_first_interval(ColumnVector& current, ColumnVector& cur_norm, const  Matrix& data, const Matrix& norm, const int32NDArray& index, const int32NDArray& label)
{
  const int N = index.length ();

  for (int pos = 0; pos < N; pos++)
    {
      const int base = label(pos);
      const int end_row = (int) index(pos);

      // horizontal chirplet
      compute_statistic (&current(base), &cur_norm(base), end_row, end_row, data, norm);

      // chirplets starting at lower positions
      for (int i = 1, start = pos - step; (i <= Nr1) && (start >= 0); i++, start -= step)
        compute_statistic (&current(base - i), &cur_norm(base - i), end_row, (int) index(start), data, norm);

      // chirplets starting at higher positions
      for (int i = 1, start = pos + step; (i <= Nr1) && (start < N); i++, start += step)
        compute_statistic (&current(base + i), &cur_norm(base + i), end_row, (int) index(start), data, norm);
    }
}

/* Return true when candidate statistic x should replace the current best s
   in find_max().

   DP_next_interval() stores a disconnected chirplet as current = -Inf,
   norm = +Inf, so its normalised statistic is -Inf/+Inf = NaN. A plain
   `x > s` is always false against a NaN incumbent, so a NaN first candidate
   used to beat every admissible predecessor. Here a NaN never wins, and any
   number replaces a NaN incumbent. */
static inline bool
is_better_candidate (double x, double s)
{
  if (x != x)               // x is NaN
    return false;
  return (s != s) || (x > s);
}

/* Choose the predecessor, in the previous interval, of the chirplet that
   starts at index position m1 (column 0 of the interval) and ends at m0.

   Predecessors are the previous-interval chirplets ending at m1. The one
   prolonging the current chirplet with the same slope starts at
   m2 = m1 + (m1 - m0) and has label c = label(m1) + i, i = (m1 - m0)/step.
   A candidate c + j is admissible when |j| <= Nr2 (2nd order regularity),
   |i + j| <= Nr1 (1st order regularity) and its start position m2 + j*step
   lies in [0, N).

   previous: normalised statistic of every previous-interval label.
   Returns the admissible label with the largest statistic, or -1 when none
   exists (only possible when m2 itself falls outside [0, N)).

   Fixes:
   - NaN statistics of disconnected chirplets no longer shadow valid
     candidates (see is_better_candidate);
   - in the two out-of-domain branches, the first candidate is clamped so
     that its start position also stays on the domain's other side of the
     range. Previously it could lie beyond N-1 (or below 0), i.e. a label
     belonging to another position's block. */
static int
find_max(const ColumnVector& previous, const int m0, const int m1, const int N, const int32NDArray& label)
{
  const int d = m1 - m0;              // displacement of the current chirplet
  const int i = d / step;             // its slope offset
  const int c = (int) label(m1) + i;  // predecessor with the same slope
  int m2 = m1 + d;                    // ... and its start position

  int best, j, d2;
  double s;

  // continuation starts below position 0: only j > 0 can reconnect
  if (m2 < 0)
    {
      j = min (Nr2, Nr1 - i);
      j = min (j, (N - 1 - m2) / step);   // start position must stay < N
      d2 = j * step;
      if (m2 + d2 < 0)
        return -1;

      s = previous(c + j);
      best = c + j;
      for (j--, d2 -= step; m2 + d2 >= 0; j--, d2 -= step)
        if (is_better_candidate (previous(c + j), s))
          {
            s = previous(c + j);
            best = c + j;
          }
      return best;
    }

  // continuation starts at or above position N: only j < 0 can reconnect
  if (m2 >= N)
    {
      j = -min (min (Nr2, Nr1 + i), m2 / step);   // start position must stay >= 0
      d2 = j * step;
      if (m2 + d2 >= N)
        return -1;

      s = previous(c + j);
      best = c + j;
      for (j++, d2 += step; m2 + d2 < N; j++, d2 += step)
        if (is_better_candidate (previous(c + j), s))
          {
            s = previous(c + j);
            best = c + j;
          }
      return best;
    }

  // continuation inside the domain: start from it and scan both ways
  s = previous(c);
  best = c;

  // scan downward
  j = -1; m2 = m1 + d - step;
  while ((m2 >= 0) && (j >= -Nr2) && (i + j >= -Nr1))
    {
      if (is_better_candidate (previous(c + j), s))
        {
          s = previous(c + j);
          best = c + j;
        }
      j--; m2 -= step;
    }

  // scan upward
  j = 1; m2 = m1 + d + step;
  while ((m2 < N) && (j <= Nr2) && (i + j <= Nr1))
    {
      if (is_better_candidate (previous(c + j), s))
        {
          s = previous(c + j);
          best = c + j;
        }
      j++; m2 += step;
    }

  return best;
}

/* Advance the DP by one interval.

   For each chirplet of the current interval (labelled as in
   DP_first_interval), add its path integrals in data/norm to those of the
   predecessor chosen by find_max() on the normalised statistics
   previous ./ norm_pre. A chirplet with no admissible predecessor
   (find_max returns -1) gets current = -Inf and norm_cur = +Inf.
   When `trace` is non-empty, the chosen predecessor label (or -1) is
   stored in it at the chirplet's label. */
static void
DP_next_interval(ColumnVector& current, ColumnVector& norm_cur, int32NDArray& trace, const ColumnVector& previous, const ColumnVector& norm_pre, const Matrix& data, const Matrix& norm, const int32NDArray& index, const int32NDArray& label)
{
  const int N = index.length ();
  const bool keep_trace = trace.length () > 0;

  const ColumnVector prev_stat = divide (previous, norm_pre);

  for (int pos = 0; pos < N; pos++)
    {
      const int base = label(pos);
      const int end_row = (int) index(pos);
      double integral = 0.0;
      double weight = 0.0;

      // horizontal chirplet: its continuation starts at pos itself, which is
      // always inside the domain, so find_max never returns -1 here
      compute_statistic (&integral, &weight, end_row, end_row, data, norm);
      int pred = find_max (prev_stat, pos, pos, N, label);
      current(base) = previous(pred) + integral;
      norm_cur(base) = norm_pre(pred) + weight;
      if (keep_trace)
        trace(base) = pred;

      // sloped chirplets: first those starting below pos, then above
      for (int dir = -1; dir <= 1; dir += 2)
        {
          int i = 1;
          int start = pos + dir * step;
          while ((dir < 0 ? start >= 0 : start < N) && (i <= Nr1))
            {
              const int slot = base + dir * i;

              compute_statistic (&integral, &weight, end_row, (int) index(start), data, norm);
              pred = find_max (prev_stat, pos, start, N, label);

              if (pred >= 0) // is chirplet connected?
                {
                  current(slot) = previous(pred) + integral;
                  norm_cur(slot) = norm_pre(pred) + weight;
                }
              else
                {
                  current(slot) = -INFINITY;
                  norm_cur(slot) = +INFINITY;
                }

              if (keep_trace)
                trace(slot) = pred;

              i++;
              start += dir * step;
            }
        }
    }
}

/* Assign a contiguous block of labels to the chirplets ending at each index
   position pos: first those starting at pos - k*step (k = Nr1..1, start
   >= 0), then the horizontal one, then those starting at pos + k*step
   (k = 1..Nr1, start < N). This is exactly the set DP_first_interval and
   DP_next_interval fill in.

   horizontal(pos) receives the label of the horizontal chirplet,
   steepest(pos) the last label of pos's block. Returns the total number of
   labels.

   Fix: the loops previously ran while l <= Nr1, reserving up to Nr1+1
   slopes on each side although the DP only fills Nr1. The extra slots were
   never written (statistic 0), so the global maximum in dpw() could select
   a label that is not a chirplet, e.g. when all real statistics are
   negative. */
static int 
chirplet_labels(int32NDArray & horizontal, int32NDArray & steepest)
{
  const int N = horizontal.length ();
  int n = 0;                    // next free label

  for (int pos = 0; pos < N; pos++)
    {
      // chirplets starting at lower positions
      for (int k = 1; (k <= Nr1) && (pos - k * step >= 0); k++)
        n++;

      horizontal(pos) = n;
      n++;

      // chirplets starting at higher positions
      for (int k = 1; (k <= Nr1) && (pos + k * step < N); k++)
        n++;

      steepest(pos) = n - 1;
    }

  return n;
}

RowVector trace_chain(int seed, int step, const int32NDArray& trace, const int32NDArray& horizontal, const int32NDArray& steepest)
{
  int N=trace.columns();

  if ((int) trace(seed,N-1)<0)
    return RowVector ();

  vector<double> tmp;

  int k=0;
  while ((int) steepest(k)<seed)
    k++;
  tmp.push_back((double)(k+1));

  int n=seed; 
  for (int j=N-1 ; j>-1 ; j--)
    {
      if ((int) trace(n,j)<0)
	break;

      n=trace(n,j);
      k=0;
      while ((int) steepest(k)<n)
	k++;
      tmp.push_back((double)(k+1));
    }

  N=tmp.size();
  RowVector chain(N+1);

  for (int j=N; j>0; j--)
    chain(j)=tmp[N-j];

  chain(0)=chain(1)+step*(n-(double) horizontal(k));

  return (chain); 
}

/* Debug integrity check (run by dpw when debug == 2).

   Every entry of trace must be -1 (disconnected) or a valid label in
   [0, trace.rows()). The first offending entry is reported through error()
   and the check stops. horizontal and steepest are currently unused.

   Fix: the message printed "trace(%d,%d)" with the column and row swapped
   relative to the trace(n,j) access. It also kept looping and reporting
   after the first error. */
void check_trace(int32NDArray & trace, const int32NDArray& horizontal, const int32NDArray& steepest)
{
  const int nlabels = trace.rows ();

  for (int j = 0; j < trace.columns (); j++)
    for (int n = 0; n < nlabels; n++)
      {
        const int t = (int) trace(n, j);
        if ((t < -1) || (t >= nlabels))
          {
            error ("dpw: trace(%d,%d)=%d and %d chirplets\n", n, j, t, nlabels);
            return;
          }
      }
}

/* Octave entry point; see the texinfo help string for the user-level
   description. Outline:
     1. validate the arguments and store Nr1, step, Nr2 and debug_flag in
        the module globals read by the DP helpers;
     2. label the chirplets (chirplet_labels);
     3. run the DP over the columns of `in` in intervals of b columns
        (DP_first_interval, then DP_next_interval for each further interval);
     4. return sqrt of the best normalised statistic, the statistic itself
        and, when nargout > 2 or debug >= 2, the back-traced chain.

   Fixes relative to the previous revision:
     - more than 8 arguments is now a usage error;
     - step == 0 is rejected (find_max divides by step);
     - an empty index range is rejected before min()/max() are called and
       before the global maximum reads cur(0);
     - each interval is normalised with the matching columns of `norm`
       (previously every interval used the first b columns of `norm`);
     - the b / Nr1 / Nr2 error messages now match the actual checks;
     - the Nr2 warning goes through Octave's warning() (the old fprintf
       lacked a trailing newline). */
DEFUN_DLD (dpw, args, nargout,
  "-*- texinfo -*-\n\
@deftypefn {Loadable Function} {[@var{snr} @var{max} @var{trace}] =} dpw (@var{in},@var{norm},@var{index},@var{b},@var{Nr1},@var{step},@var{Nr2})\n\
@deftypefnx {Loadable Function} {[@var{snr} @var{max} @var{trace}] =} dpw (@var{in},@var{norm},@var{index},@var{b},@var{Nr1},@var{step},@var{Nr2},@var{debug})\n\
Find the path of maximum weighted integral in matrix @var{in} with dynamic programming.\n\
The path regularity is specified by the parameters @var{b},@var{Nr1},@var{step} and @var{Nr2}.\n\
The integral is weighted by the inverse path integral in @var{norm}.\n\
The search can be limited to paths contained within the rows of @var{in} in the range @var{index}.\n\
If @var{index} is an empty vector, the index range defaults to the entire set of rows of @var{in}.\n\
@end deftypefn\n\n")
{
  octave_value_list retval;

  int nargin = args.length ();

  if ((nargin < 7) || (nargin > 8))
    {
      print_usage ();
      return retval;
    }

  // back tracing is needed when the chain is requested as 3rd output
  bool trace_flag = (nargout > 2);

  /****************** get and validate arguments ******************/

  /* data */
  Matrix in (args(0).matrix_value ());

  if (error_state)
    {
      gripe_wrong_type_arg ("dpw", args(0));
      return retval;
    }

  int c = in.columns ();
  int r = in.rows ();

  if (!is_power_of_two (c))
    {
      error ("dpw: the number of columns of in should be a power of 2");
      return retval;
    }

  /* normalization weight (same size as in) */
  Matrix norm (args(1).matrix_value ());

  if (error_state)
    {
      gripe_wrong_type_arg ("dpw", args(1));
      return retval;
    }

  if ((norm.rows () != r) || (norm.columns () != c))
    {
      error ("dpw: size of norm should match size of in");
      return retval;
    }

  /* row indices */
  int32NDArray index (args(2).int32_array_value ());

  if (error_state)
    {
      gripe_wrong_type_arg ("dpw", args(2));
      return retval;
    }

  if (index.length () == 0)
    {
      /* set default y indices */
      index.resize (dim_vector (r));
      for (int n = 0; n < r; n++)
        index(n) = n;
    }
  else
    {
      /* convert Octave indices to C++ indices */
      index = index - 1;
    }

  // An empty range (in has no rows) leaves nothing to search. Test it
  // before min()/max(), which have no meaningful value for an empty array.
  if ((index.length () == 0) || (min (index) < 0) || (max (index) >= r))
    {
      error ("dpw: invalid index range");
      return retval;
    }

  int ni = index.length ();

  /* size of chirplet interval */
  int b = args(3).int_value ();

  if (error_state)
    {
      gripe_wrong_type_arg ("dpw", args(3));
      return retval;
    }

  if (!is_power_of_two (b))
    {
      error ("dpw: b should be a power of 2");
      return retval;
    }

  if ((b < 2) || (b >= c))
    {
      error ("dpw: b should be 1 < b < columns(in)");
      return retval;
    }

  // number of intervals; >= 2 since b < c and both are powers of 2
  int Nt = c / b;

  /* 1st order regularity of the chirplet chain */
  Nr1 = args(4).int_value ();

  if (error_state)
    {
      gripe_wrong_type_arg ("dpw", args(4));
      return retval;
    }

  if (Nr1 < 0)
    {
      error ("dpw: Nr1 should be >= 0");
      return retval;
    }

  /* step size */
  step = args(5).int_value ();

  if (error_state)
    {
      gripe_wrong_type_arg ("dpw", args(5));
      return retval;
    }

  // step == 0 would make find_max() divide by zero
  if (step <= 0)
    {
      error ("dpw: step should be > 0");
      return retval;
    }

  /* 2nd order regularity of the chirplet chain */
  Nr2 = args(6).int_value ();

  if (error_state)
    {
      gripe_wrong_type_arg ("dpw", args(6));
      return retval;
    }

  if (Nr2 < 0)
    {
      error ("dpw: Nr2 should be >= 0");
      return retval;
    }

  if (Nr2 >= (2 * Nr1))
    warning ("dpw: no effective constraint on 2nd derivative");

  /* debug flag */
  debug_flag = 0;
  if (nargin == 8)
    {
      debug_flag = args(7).int_value ();

      if (error_state)
        {
          gripe_wrong_type_arg ("dpw", args(7));
          return retval;
        }
    }

  if (debug_flag >= 2)
    trace_flag = true; // force back tracing in debug mode

  /****************** compute chirplet labels ******************/

  dim_vector label_dims (ni);
  int32NDArray label_horiz_chirplet (label_dims);
  int32NDArray label_steepest_chirplet (label_dims);
  int Nc = chirplet_labels (label_horiz_chirplet, label_steepest_chirplet);

  /******************        main loop        ******************/

  ColumnVector cur (Nc, 0.0);
  ColumnVector norm_cur (Nc, 0.0);
  ColumnVector pre (Nc, 0.0);
  ColumnVector norm_pre (Nc, 0.0);

  // column j of trace: predecessor labels chosen for interval j+1
  dim_vector trace_dims (trace_flag ? dim_vector (Nc, Nt-1) : dim_vector (0));
  int32NDArray trace (trace_dims);

  dim_vector this_trace_dims (trace_flag ? dim_vector (Nc) : dim_vector (0));
  int32NDArray this_trace (this_trace_dims);

  // first interval: columns 0 .. b-1 of both in and norm
  int k = 0;
  Matrix interval = in.extract (0, k, r-1, k+b-1);
  Matrix interval_norm = norm.extract (0, k, r-1, k+b-1);

  DP_first_interval (pre, norm_pre, interval, interval_norm, index, label_horiz_chirplet);

  for (int j = 0; j < (Nt-1); j++)
    {
      // interval j+1: columns k .. k+b-1, normalised by the same columns of norm
      k += b;
      interval = in.extract (0, k, r-1, k+b-1);
      interval_norm = norm.extract (0, k, r-1, k+b-1);

      DP_next_interval (cur, norm_cur, this_trace, pre, norm_pre, interval, interval_norm, index, label_horiz_chirplet);
      pre = cur;
      norm_pre = norm_cur;

      if (trace_flag)
        trace.insert (this_trace, 0, j);
    }

  /******************   compute global max    ******************/

  int label_max = 0;
  double stat_max = (norm_cur(0) == 0.0 ? 0.0 : cur(0)/norm_cur(0));
  for (int i = 1; i < cur.length (); i++)
    {
      double statistic = (norm_cur(i) == 0.0 ? 0.0 : cur(i)/norm_cur(i));
      if (statistic > stat_max)
        {
          stat_max = statistic;
          label_max = i;
        }
    }

  retval(0) = sqrt (stat_max);
  retval(1) = stat_max;

  /******************   trace chirplet chain  ******************/

  if (trace_flag)
    retval(2) = trace_chain (label_max, step, trace, label_horiz_chirplet, label_steepest_chirplet);

  if (debug_flag == 2)
    check_trace (trace, label_horiz_chirplet, label_steepest_chirplet);

  return retval;
}

