/* Vgrid - Virtual grid program for radiology
   Copyright (C) 2020, 2021 Sonia Diaz Pacheco.

   This program is free software: you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation, either version 2 of the License, or
   (at your option) any later version.

   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/

#define _FILE_OFFSET_BITS 64

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <string>
#include <vector>

#include "vgrid.h"
#include "matrix.h"


namespace {

/* Linear interpolation between 'a' and 'b'.
   'pos' is the position inside a segment of 'length' steps, in [0, length]:
   'pos == 0' yields 'a' and 'pos == length' yields 'b'.
   'length' must be nonzero (it is the divisor). */
inline double weighted_average( const double a, const double b,
                                const long pos, const long length )
  {
  const long weight_a = length - pos;		// weight of 'a' grows toward pos 0
  const double weighted_sum = a * weight_a + b * pos;
  return weighted_sum / length;
  }


struct Area		// ranges of pixel values inside a rectangular area
  {
  long left, top, width, height;
  double vmax, vmin, mean, b_range, d_range;	// bright and dark ranges

  // Area with geometry only; the range members stay unset until find_ranges().
  Area( const long l, const long t, const long w, const long h )
    : left( l ), top( t ), width( w ), height( h ) {}
  // Area with geometry and ranges computed from 'image'.
  Area( const long l, const long t, const long w, const long h,
        const Matrix & image )
    : left( l ), top( t ), width( w ), height( h ) { find_ranges( image ); }

  long hcenter() const { return left + width / 2; }  // horizontal coord of center
  long vcenter() const { return top + height / 2; }  // vertical coord of center
  long right() const { return left + width; }
  long bottom() const { return top + height; }

  /* Finds the mean pixel value of the area and its extreme values, then the
     ranges above mean (bright) and below mean (dark).
     vmax/vmin are taken from image.median( row, col, 1 ), evaluated only at
     pixels whose raw value beats the current extreme (presumably so single
     noisy pixels do not widen the range). NOTE(review): if median() uses a
     window around (row, col), it may include pixels outside this area.
     Assumes a nonempty area: mean is divided by width * height. */
  void find_ranges( const Matrix & image )
    {
    const long size = height * width;
    vmax = -INFINITY;
    vmin =  INFINITY;
    mean = 0;
    for( long row = top; row < top + height; ++row )
      for( long col = left; col < left + width; ++col )
        {
        const double val = image.get_element( row, col );
        mean += val;
        if( vmax < val )
          { const double mval = image.median( row, col, 1 );
            if( vmax < mval ) vmax = mval; }
        if( vmin > val )
          { const double mval = image.median( row, col, 1 );
            if( vmin > mval ) vmin = mval; }
        }
    mean /= size;
    update_ranges();
    }

  /* If this area reaches the maximum of 'parent', checks whether the
     brightest histogram bin is at least 4 times larger than the bin below it
     (probably overprinted white letters). If so, vmax is lowered to the
     brightest populated bin below the letters and mean is recomputed without
     them. Assumes pixel values are integers in [0, maxval]; they index the
     histogram directly. */
  void remove_letters_from_range( const Matrix & image, const Area & parent,
                                  const int maxval )
    {
    if( vmax < parent.vmax ) return;		// no white letters found
    std::vector< long > histogram( maxval + 1, 0 );	// range of pixel values [0,maxval]
    for( long row = top; row < top + height; ++row )
      for( long col = left; col < left + width; ++col )
        ++histogram[(int)image.get_element( row, col )];
    std::vector< long > chistogram( histogram );	// cumulative histogram
    for( unsigned i = 1; i < chistogram.size(); ++i )
      chistogram[i] += chistogram[i-1];

    // ignore a large last bin of white pixels (probably overprinted letters)
    const long size = height * width;
    for( unsigned i = 0; i < chistogram.size(); ++i )
      if( chistogram[i] >= size )		// 'i' is the brightest populated bin
        {
        if( i == 0 ) break;		// no letters, all image is black
        if( histogram[i] < 4 * histogram[i-1] ) break;	// not large enough
        // skip possible gap between letters and brighter pixels in image
        while( --i > 0 && chistogram[i] > 0 && histogram[i] < 1 ) {}
        /* Remove letters from range. If every pixel of the area was in the
           last bin (fully saturated area), no pixel remains at or below 'i'
           (chistogram[i] == 0) and the new mean would be 0/0 (NaN), so keep
           the original ranges in that case. */
        if( chistogram[i] > 0 && i > vmin && i < vmax )
          {
          vmax = i; mean = 0;
          for( unsigned j = 0; j <= i; ++j ) mean += j * histogram[j];
          mean /= chistogram[i];	// mean without the letters
          update_ranges();
          }
        break;
        }
    }

  // Set ranges from vmax, vmin and mean; an empty range becomes 1 so it can
  // safely be used as a divisor.
  void update_ranges()
    {
    b_range = ( vmax > mean ) ? vmax - mean : 1;
    d_range = ( mean > vmin ) ? mean - vmin : 1;
    }

  // Enlarge local ranges so that they are at least 1/max_gain of the parent
  // ranges; this limits the gain when they are later expanded to full range.
  void set_max_gain( const Area & parent, const double max_gain )
    {
    b_range = std::max( b_range, parent.b_range / max_gain );
    d_range = std::max( d_range, parent.d_range / max_gain );
    }
  };


class Partition			// matrix of areas with adjusted ranges
  {
  std::vector< std::vector< Area > > data;	// 2D array of Areas

public:
          /* Keep the areas (almost) square by making N parts in the shorter
             side and (N * aspect_ratio) parts in the longer side. */
  Partition( const Area & global_area, const Matrix & image, const int parts,
             const int maxval )
    {
    const long gheight = global_area.height;
    const long gwidth = global_area.width;
    if( gheight < parts || gwidth < parts ) return;
    const int vparts = std::max( gheight * parts / gwidth, (long)parts );
    const int hparts = std::max( gwidth * parts / gheight, (long)parts );
    data.resize( vparts );
    long t = global_area.top;
    for( int i = 1; i <= vparts; ++i )
      {
      const long h = ( ( i * gheight ) / vparts ) + global_area.top - t;
      long l = global_area.left;
      for( int j = 1; j <= hparts; ++j )
        {
        const long w = ( ( j * gwidth ) / hparts ) + global_area.left - l;
        data[i-1].push_back( Area( l, t, w, h ) );
        l += w;
        }
      t += h;
      }
    for( int i = 0; i < vparts; ++i )		// find ranges of each area
      for( int j = 0; j < hparts; ++j )
        {
        data[i][j].find_ranges( image );
        data[i][j].remove_letters_from_range( image, global_area, maxval );
        }

    const std::vector< std::vector< Area > > copy = data;
    for( int i = 0; i < vparts; ++i )		// adjust ranges of each area
      for( int j = 0; j < hparts; ++j )
        {
        double fvmax = -INFINITY;		// vmax of fringe areas
        double fvmin =  INFINITY;		// vmin of fringe areas
        double fmean = 0;		// sum of means of fringe areas
        int fareas = 0;				// number of fringe areas
        for( int dr = -1; dr <= 1; ++dr )		// row delta
          {
          if( i + dr < 0 || i + dr >= height() ) continue;	// no area here
          for( int dc = -1; dc <= 1; ++dc )		// column delta
            {
            if( j + dc < 0 || j + dc >= width() ) continue;	// no area here
            if( dr == 0 && dc == 0 ) continue;		// skip central area
            const Area & f = copy[i+dr][j+dc];		// fringe area
            if( fvmax < f.vmax ) fvmax = f.vmax;
            if( fvmin > f.vmin ) fvmin = f.vmin;
            fmean += f.mean; ++fareas;
            }
          }
        Area & a = data[i][j];				// central area
        if( a.vmax < fvmax ) a.vmax = ( a.vmax + fvmax ) / 2;
        if( a.vmin > fvmin ) a.vmin = ( a.vmin + fvmin ) / 2;
        a.mean = ( a.mean + fmean ) / ( fareas + 1 );	// widened local mean
        a.update_ranges();
        }

    for( int i = 0; i < vparts; ++i )		// limit gain in each area
      for( int j = 0; j < hparts; ++j )
        data[i][j].set_max_gain( global_area, 5 );
    }

  int height() const { return data.size(); }
  int width() const { return data[0].size(); }
  const Area & get_element( const long row, const long col ) const
    { return data[row][col]; }
  };


/* Replace pixel (row, col) of 'r' with its contrast-expanded value.
   Pixels darker than 'mean' are mapped linearly so that 'mean' lands on
   'midval1' and each 'd_range' below it lowers the result by 'midval1'.
   Other pixels are mapped so that 'mean' lands on 'midval' and each 'b_range'
   above it raises the result by 'midval1'. The result is rounded with
   nearbyint(). */
inline void expand_pixel( Matrix & r, const long row, const long col,
                          const double d_range, const double mean,
                          const double b_range, const int midval,
                          const int midval1 )
  {
  const double pixel = r.get_element( row, col );
  const double expanded = ( pixel < mean ) ?
    midval1 - midval1 * ( mean - pixel ) / d_range :	// dark side
    midval + midval1 * ( pixel - mean ) / b_range;	// bright side
  r.set_element( row, col, nearbyint( expanded ) );
  }


/* Expand the contrast of every pixel inside 'a' using the ranges and mean of
   'a' itself (no interpolation). expand_corners() calls this with the
   outermost quadrant of each corner area. */
void expand_corner( Matrix & r, const Area & a, const int midval )
  {
  const int midval1 = midval - 1;
  const long row_end = a.bottom();
  const long col_end = a.right();
  for( long y = a.top; y < row_end; ++y )
    for( long x = a.left; x < col_end; ++x )
      expand_pixel( r, y, x, a.d_range, a.mean, a.b_range, midval, midval1 );
  }


// calculate coords of the most exterior quadrant of each corner area
void expand_corners( Matrix & r, const Partition & p, const int maxval )
  {
  const int midval = ( maxval + 1 ) / 2;
  Area a = p.get_element( 0, 0 ); a.height /= 2; a.width /= 2;
  expand_corner( r, a, midval );		// expand upper left corner

  a = p.get_element( 0, p.width() - 1 ); a.height /= 2;
  a.left += a.width / 2; a.width -= a.width / 2;
  expand_corner( r, a, midval );		// expand upper right corner

  a = p.get_element( p.height() - 1, 0 ); a.width /= 2;
  a.top += a.height / 2; a.height -= a.height / 2;
  expand_corner( r, a, midval );		// expand lower left corner

  a = p.get_element( p.height() - 1, p.width() - 1 );
  a.top += a.height / 2; a.height -= a.height / 2;
  a.left += a.width / 2; a.width -= a.width / 2;
  expand_corner( r, a, midval );		// expand lower right corner
  }


/* Expand contrast along the four borders of the partition, in the strips
   lying between the centers of each two adjacent border areas.
   Top (bottom) strips cover the rows from the top edge (center row) of the
   first (last) row of areas to its center row (bottom edge). Left (right)
   strips do the same with columns of the first (last) column of areas.
   Along a strip the ranges and mean are linearly interpolated between the
   two area centers; across the strip they are constant. */
void expand_periphery( Matrix & r, const Partition & p, const int maxval )
  {
  const int midval = ( maxval + 1 ) / 2;
  const int midval1 = midval - 1;
  const int last_row = p.height() - 1;
  const int last_col = p.width() - 1;

  // top and bottom strips: one segment per pair of horizontally adjacent areas
  for( int col_pair = 0; col_pair < last_col; ++col_pair )
    {
    const Area & top_l = p.get_element( 0, col_pair );
    const Area & top_r = p.get_element( 0, col_pair + 1 );
    const Area & bot_l = p.get_element( last_row, col_pair );
    const Area & bot_r = p.get_element( last_row, col_pair + 1 );
    const long x_begin = top_l.hcenter();
    const long x_end = top_r.hcenter();
    const long span = x_end - x_begin;
    for( long x = x_begin; x < x_end; ++x )
      {
      const long pos = x - x_begin;	// distance from the left center
      const double top_d = weighted_average( top_l.d_range, top_r.d_range, pos, span );
      const double top_m = weighted_average( top_l.mean, top_r.mean, pos, span );
      const double top_b = weighted_average( top_l.b_range, top_r.b_range, pos, span );
      for( long y = top_l.top; y < top_l.vcenter(); ++y )
        expand_pixel( r, y, x, top_d, top_m, top_b, midval, midval1 );

      const double bot_d = weighted_average( bot_l.d_range, bot_r.d_range, pos, span );
      const double bot_m = weighted_average( bot_l.mean, bot_r.mean, pos, span );
      const double bot_b = weighted_average( bot_l.b_range, bot_r.b_range, pos, span );
      for( long y = bot_l.vcenter(); y < bot_l.bottom(); ++y )
        expand_pixel( r, y, x, bot_d, bot_m, bot_b, midval, midval1 );
      }
    }

  // left and right strips: one segment per pair of vertically adjacent areas
  for( int row_pair = 0; row_pair < last_row; ++row_pair )
    {
    const Area & left_t = p.get_element( row_pair, 0 );
    const Area & left_b = p.get_element( row_pair + 1, 0 );
    const Area & right_t = p.get_element( row_pair, last_col );
    const Area & right_b = p.get_element( row_pair + 1, last_col );
    const long y_begin = left_t.vcenter();
    const long y_end = left_b.vcenter();
    const long span = y_end - y_begin;
    for( long y = y_begin; y < y_end; ++y )
      {
      const long pos = y - y_begin;	// distance from the upper center
      const double left_d = weighted_average( left_t.d_range, left_b.d_range, pos, span );
      const double left_m = weighted_average( left_t.mean, left_b.mean, pos, span );
      const double left_bri = weighted_average( left_t.b_range, left_b.b_range, pos, span );
      for( long x = left_t.left; x < left_t.hcenter(); ++x )
        expand_pixel( r, y, x, left_d, left_m, left_bri, midval, midval1 );

      const double right_d = weighted_average( right_t.d_range, right_b.d_range, pos, span );
      const double right_m = weighted_average( right_t.mean, right_b.mean, pos, span );
      const double right_bri = weighted_average( right_t.b_range, right_b.b_range, pos, span );
      for( long x = right_t.hcenter(); x < right_t.right(); ++x )
        expand_pixel( r, y, x, right_d, right_m, right_bri, midval, midval1 );
      }
    }
  }


/* Expand contrast in the interior rectangles whose corners are the centers of
   four adjacent areas (i,j), (i,j+1), (i+1,j) and (i+1,j+1).
   Ranges and mean are interpolated bilinearly: first vertically along the
   left and right edges of the rectangle, then horizontally between them.
   Only the range needed by each pixel (dark or bright) is interpolated,
   after the pixel is compared with the interpolated mean. */
void expand_interior( Matrix & r, const Partition & p, const int maxval )
  {
  const int midval = ( maxval + 1 ) / 2;		// 128 or 32768
  const int midval1 = midval - 1;
  for( int i = 0; i < p.height() - 1; ++i )		// upper row of each group
    for( int j = 0; j < p.width() - 1; ++j )		// left column of each group
      {
      const Area & ul = p.get_element( i, j );
      const Area & ur = p.get_element( i, j + 1 );
      const Area & ll = p.get_element( i + 1, j );
      const Area & lr = p.get_element( i + 1, j + 1 );
      const long x0 = ul.hcenter();
      const long x1 = ur.hcenter();
      const long y0 = ul.vcenter();
      const long y1 = ll.vcenter();
      const long w = x1 - x0;
      const long h = y1 - y0;
      for( long y = y0; y < y1; ++y )
        {
        const long dy = y - y0;
        // parameters on the left edge (ul -> ll) and right edge (ur -> lr)
        const double left_d = weighted_average( ul.d_range, ll.d_range, dy, h );
        const double left_m = weighted_average( ul.mean, ll.mean, dy, h );
        const double left_b = weighted_average( ul.b_range, ll.b_range, dy, h );
        const double right_d = weighted_average( ur.d_range, lr.d_range, dy, h );
        const double right_m = weighted_average( ur.mean, lr.mean, dy, h );
        const double right_b = weighted_average( ur.b_range, lr.b_range, dy, h );
        for( long x = x0; x < x1; ++x )
          {
          const long dx = x - x0;
          const double m = weighted_average( left_m, right_m, dx, w );
          const double val = r.get_element( y, x );
          const double expanded = ( val < m ) ?
            midval1 - midval1 * ( m - val ) / weighted_average( left_d, right_d, dx, w ) :
            midval + midval1 * ( val - m ) / weighted_average( left_b, right_b, dx, w );
          r.set_element( y, x, nearbyint( expanded ) );
          }
        }
      }
  }

} // end namespace


/* Implementation of local contrast amplification by Adaptive Windowing (AW).
 * image - Input image
 * outfile - Pointer to output file
 * maxval - Maximum possible pixel value of the input PNG image (255 or 65535)
 * bit_depth - Desired bit depth of the output PNG image (8 or 16)
 * median_radius - If > 0, run output through a median filter of this radius
 * color_info - Passed unchanged to Matrix::write_png
 * expand - If true, adjust image values to fit in the output bit depth
 *
 * The image is expanded once for each partition size in 'areas' and the
 * resulting planes are averaged. A partition size that does not fit in the
 * image (fewer pixels than parts in either dimension) is skipped. If none
 * fits, the input image is used as is. It is still subject to the bit depth
 * scaling, median filter and 'expand' below.
 * Returns 0.
 */
int recontrast( const Matrix & image, FILE * outfile, const int maxval,
                const unsigned bit_depth, const int median_radius,
                const Color_info & color_info, const bool expand )
  {
  const int areas[] = { 3, 5, 9, 15, 0 };  // areas in short side of each plane
  // global area containing the ranges of the full image
  const Area global_area( 0, 0, image.width(), image.height(), image );
  Matrix sum( image.height(), image.width() );	// sum of image planes
  int planes = 0;			// total number of planes averaged
  for( int i = 0; areas[i] > 0; ++i )
    {
    // divide image in NxM areas
    Partition p( global_area, image, areas[i], maxval );
    // Partition stays empty when the image is too small for areas[i] parts;
    // the expand_* functions would then index nonexistent areas.
    if( p.height() <= 0 ) continue;
    Matrix r( image );				// one image plane
    expand_corners( r, p, maxval );
    expand_periphery( r, p, maxval );
    expand_interior( r, p, maxval );
    sum += r;
    ++planes;
//    r.write_png( "plane-", areas[i], real_bits( maxval ) );
    }
  if( planes > 0 ) sum /= planes;		// average of all image planes
  else sum += image;		// no partition fits; 'sum' was all zeros
  if( maxval != ( 1 << bit_depth ) - 1 )	// adjust to bit_depth
    sum *= (double)( ( 1 << bit_depth ) - 1 ) / maxval;
  if( median_radius > 0 ) sum = sum.median_filter( median_radius );
  sum.write_png( outfile, bit_depth, color_info, expand );	// save PNG image
  return 0;
  }
