/* Vgrid - Virtual grid program for radiology
   Copyright (C) 2020, 2021 Sonia Diaz Pacheco.

   This program is free software: you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation, either version 2 of the License, or
   (at your option) any later version.

   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/

#define _FILE_OFFSET_BITS 64

#include <algorithm>
#include <climits>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <new>
#include <string>
#include <vector>
#include <png.h>

#include "vgrid.h"
#include "matrix.h"

#if PNG_LIBPNG_VER < 10209
#error "Wrong libpng version. At least 1.2.9 is required."
#endif


/* Check to see if a file is a PNG file using png_sig_cmp().
 * png_sig_cmp() returns zero if the image is a PNG, and nonzero otherwise.
 *
 * If this call is successful, and you are going to keep the file open,
 * you should call png_set_sig_bytes(png_ptr, PNG_BYTES_TO_CHECK); once
 * you have created the png_ptr, so that libpng knows your application
 * has read that many bytes from the start of the file.
 */
bool read_check_png_sig8( FILE * const f )
  {
  enum { sig_size = 8 };
  png_byte buf[sig_size];

  // read the signature bytes
  if( std::fread( buf, 1, sig_size, f ) != sig_size ) return false;
  return !png_sig_cmp( buf, 0, sig_size );
  }


/* Create a Matrix with elements in range [0, maxval] from a PNG file.
   'f' is the stream to read from and 'sig_read' is the number of
   signature (magic) bytes already consumed from it (maybe 0).
   If maxvalp != 0, return in *maxvalp the value ( 1 << bit_depth ) - 1,
   where bit_depth is the one reported by libpng after reading the image.
   If color_infop != 0 and the image has 3 channels, the r, g, b samples
   of every pixel are also appended to the three planes of *color_infop.
   If 'invert' is true, each element is stored as maxval - value.
   For 3-channel images the element is the mean of the r, g, b samples.

   Throws std::bad_alloc if the libpng structs can't be created, and
   Error if libpng reports a read error or if the decoded image is not
   1-channel greyscale or 3-channel RGB.
   The libpng structs are released on every exit path.
*/
Matrix::Matrix( FILE * const f, const int sig_read, int * const maxvalp,
                Color_info * const color_infop, const bool invert )
  {
  /* Create the read struct with libpng's default error handling (NULL
     handlers): errors are reported through longjmp to the setjmp below.
     The header version string lets libpng verify that the library in use
     is compatible with the one used at compile time. */
  png_structp png_ptr =
    png_create_read_struct( PNG_LIBPNG_VER_STRING, NULL, NULL, NULL );
  if( !png_ptr ) throw std::bad_alloc();

  /* Allocate the struct that receives the image information. */
  png_infop info_ptr = png_create_info_struct( png_ptr );
  if( !info_ptr )
    {
    png_destroy_read_struct( &png_ptr, NULL, NULL );
    throw std::bad_alloc();
    }

  if( setjmp( png_jmpbuf( png_ptr ) ) )		// libpng errors land here
    {
    /* Free all of the memory associated with the png_ptr and info_ptr. */
    png_destroy_read_struct( &png_ptr, &info_ptr, NULL );
    throw Error( "Error reading PNG image." );
    }

  png_init_io( png_ptr, f );			// read from a standard C stream

  /* Tell libpng how many signature bytes the caller has already read. */
  png_set_sig_bytes( png_ptr, sig_read );

  /* Read the entire image (including pixels) into the info structure.
     png_read_png is equivalent to png_read_info(), followed by the
     transformations in the mask, then png_read_image() and png_read_end().
     The mask asks libpng to discard any alpha channel and to expand the
     image so that only whole-byte greyscale or RGB samples remain. */
  png_read_png( png_ptr, info_ptr,
                PNG_TRANSFORM_STRIP_ALPHA | PNG_TRANSFORM_EXPAND, NULL );

  /* The whole image is in memory now; no libpng call below can longjmp. */
  const unsigned height     = png_get_image_height( png_ptr, info_ptr );
  const unsigned width      = png_get_image_width( png_ptr, info_ptr );
  const unsigned bit_depth  = png_get_bit_depth( png_ptr, info_ptr );
  const unsigned maxval = ( 1 << bit_depth ) - 1;
  if( maxvalp ) *maxvalp = maxval;
  const unsigned color_type = png_get_color_type( png_ptr, info_ptr );
  const unsigned channels   = png_get_channels( png_ptr, info_ptr );
  /* bit_depth  - bit depth of one of the image channels
     color_type - describes which color/alpha channels are present
     channels   - number of channels of info for the color type
                  (1 for GRAY, 3 for RGB are the only ones accepted here) */

  if( ( color_type != PNG_COLOR_TYPE_GRAY && color_type != PNG_COLOR_TYPE_RGB ) ||
      ( channels != 1 && channels != 3 ) )
    {
    // release the libpng structs before throwing, else they are leaked
    png_destroy_read_struct( &png_ptr, &info_ptr, NULL );
    throw Error( "Unsupported type of PNG image." );
    }

  const png_bytepp row_pointers = png_get_rows( png_ptr, info_ptr );

  try		// an exception while filling must not leak the libpng structs
    {
    data.resize( height );			// fill data vector
    if( channels == 1 )
      for( unsigned row = 0; row < height; ++row )
        {
        const png_byte * ptr = row_pointers[row];
        for( unsigned col = 0; col < width; ++col )
          {
          unsigned val = *ptr++;
          // 16-bit samples are stored most significant byte first
          if( bit_depth == 16 ) val = ( val << 8 ) + *ptr++;
          data[row].push_back( invert ? maxval - val : val );
          }
        }
    else if( channels == 3 )
      {
      if( color_infop ) color_infop->resize( height, width );
      for( unsigned row = 0; row < height; ++row )
        {
        const png_byte * ptr = row_pointers[row];
        for( unsigned col = 0; col < width; ++col )
          {
          unsigned r = *ptr++;
          if( bit_depth == 16 ) r = ( r << 8 ) + *ptr++;
          unsigned g = *ptr++;
          if( bit_depth == 16 ) g = ( g << 8 ) + *ptr++;
          unsigned b = *ptr++;
          if( bit_depth == 16 ) b = ( b << 8 ) + *ptr++;
          const double val = ( r + g + b ) / 3.0;	// mean of the 3 samples
          data[row].push_back( invert ? maxval - val : val );
          if( color_infop )
            {
            color_infop->data[0][row].push_back( r );
            color_infop->data[1][row].push_back( g );
            color_infop->data[2][row].push_back( b );
            }
          }
        }
      }
    }
  catch( ... )
    {
    png_destroy_read_struct( &png_ptr, &info_ptr, NULL );
    throw;
    }

  /* Clean up after the read, and free any memory allocated by libpng. */
  png_destroy_read_struct( &png_ptr, &info_ptr, NULL );
  }


bool Matrix::write_png( FILE * const f, const unsigned bit_depth,
                        const Color_info & color_info, const bool expand ) const
  {
  if( bit_depth != 8 && bit_depth != 16 )
    throw Error( "Invalid bit depth writing PNG image." );
  const unsigned channels = ( color_info.data.size() == 3 ) ? 3 : 1;
  /* row_bytes is the width x number of channels x (bit-depth / 8) */
  const unsigned row_bytes = width() * channels * ((bit_depth <= 8) ? 1 : 2);
  png_byte * const png_pixels = (png_byte *)std::malloc( height() * row_bytes );
  if( !png_pixels ) return false;
  png_byte ** const row_pointers =
    (png_byte **)std::malloc( height() * sizeof row_pointers[0] );
  if( !row_pointers ) { std::free( png_pixels ); return false; }

  // fill png_pixels[] and set row_pointers
  const double maxval = ( 1 << bit_depth ) - 1;
  int idx = 0;					// index in png_pixels[]
  if( channels == 1 )
    {
    if( !expand )				// write image as-is
      for( int row = 0; row < height(); ++row )
        for( int col = 0; col < width(); ++col )
          {
          const double dval = nearbyint( data[row][col] );
          const unsigned val = (unsigned)std::max( 0.0, std::min( maxval, dval ) );
          if( bit_depth == 16 ) png_pixels[idx++] = ( val >> 8 ) & 0xFF;
          png_pixels[idx++] = val & 0xFF;
          }
    else					// fully expand pixel values
      {
      const double cached_max = max();
      const double cached_min = min();
      const double range = (cached_max > cached_min) ? cached_max - cached_min : 1;
      for( int row = 0; row < height(); ++row )
        for( int col = 0; col < width(); ++col )
          {
          const unsigned val = (unsigned)nearbyint(
                maxval * ( data[row][col] - cached_min ) / range );
          if( bit_depth == 16 ) png_pixels[idx++] = ( val >> 8 ) & 0xFF;
          png_pixels[idx++] = val & 0xFF;
          }
      }
    }
  else	// channels == 3
    {
    if( !expand )				// write image as-is
      for( int row = 0; row < height(); ++row )
        for( int col = 0; col < width(); ++col )
          {
          unsigned rgb[3];
          for( int plane = 0; plane < 3; ++plane )
            rgb[plane] = color_info.data[plane][row][col];
          const double old_val = ( rgb[0] + rgb[1] + rgb[2] ) / 3.0;
          const double new_val = std::max( 0.0, std::min( maxval, data[row][col] ) );
          double a = new_val; if( old_val > 0 ) a /= old_val;	// amplification
          for( int plane = 0; plane < 3; ++plane )	// limit amplification
            if( a * rgb[plane] > maxval ) a = maxval / rgb[plane];
          for( int plane = 0; plane < 3; ++plane )
            {
            const unsigned val = (unsigned)nearbyint(
              std::max( 0.0, std::min( maxval, a * rgb[plane] ) ) );
            if( bit_depth == 16 ) png_pixels[idx++] = ( val >> 8 ) & 0xFF;
            png_pixels[idx++] = val & 0xFF;
            }
          }
    else					// fully expand pixel values
      {
      const double cached_max = max();
      const double cached_min = min();
      const double range = (cached_max > cached_min) ? cached_max - cached_min : 1;
      for( int row = 0; row < height(); ++row )
        for( int col = 0; col < width(); ++col )
          {
          unsigned rgb[3];
          for( int plane = 0; plane < 3; ++plane )
            rgb[plane] = color_info.data[plane][row][col];
          const double old_val = ( rgb[0] + rgb[1] + rgb[2] ) / 3.0;
          const double new_val = maxval * ( data[row][col] - cached_min ) / range;
          double a = new_val; if( old_val > 0 ) a /= old_val;	// amplification
          for( int plane = 0; plane < 3; ++plane )	// limit amplification
            if( a * rgb[plane] > maxval ) a = maxval / rgb[plane];
          for( int plane = 0; plane < 3; ++plane )
            {
            const unsigned val = (unsigned)nearbyint(
              std::max( 0.0, std::min( maxval, a * rgb[plane] ) ) );
            if( bit_depth == 16 ) png_pixels[idx++] = ( val >> 8 ) & 0xFF;
            png_pixels[idx++] = val & 0xFF;
            }
          }
      }
    }
  for( int i = 0; i < height(); ++i )
    row_pointers[i] = png_pixels + ( i * row_bytes );

  /* Create and initialize the png_struct with the desired error handler
   * functions.  If you want to use the default stderr and longjump method,
   * you can supply NULL for the last three parameters.  We also check that
   * the library version is compatible with the one used at compile time,
   * in case we are using dynamically linked libraries.  REQUIRED.
   */
  png_structp png_ptr =
    png_create_write_struct( PNG_LIBPNG_VER_STRING, NULL, NULL, NULL );
  if( !png_ptr )
    { std::free( row_pointers ); std::free( png_pixels ); return false; }

  /* Allocate/initialize the image information data.  REQUIRED. */
  png_infop info_ptr = png_create_info_struct( png_ptr );
  if( !info_ptr )
    {
    png_destroy_write_struct( &png_ptr,  NULL );
    std::free( row_pointers );
    std::free( png_pixels );
    return false;
    }

  /* Set up error handling. */
  if( setjmp( png_jmpbuf( png_ptr ) ) )
    {
    /* If we get here, we had a problem writing the file. */
    png_destroy_write_struct( &png_ptr, &info_ptr );
    std::free( row_pointers );
    std::free( png_pixels );
    return false;
    }

  /* Set up the output control if you are using standard C streams. */
  png_init_io( png_ptr, f );

  png_set_IHDR( png_ptr, info_ptr, width(), height(), bit_depth,
                ( channels == 1 ) ? PNG_COLOR_TYPE_GRAY : PNG_COLOR_TYPE_RGB,
                PNG_INTERLACE_NONE,
                PNG_COMPRESSION_TYPE_BASE, PNG_FILTER_TYPE_BASE );

  png_set_rows( png_ptr, info_ptr, row_pointers );

  // write the PNG image
  png_write_png( png_ptr, info_ptr, PNG_TRANSFORM_IDENTITY, NULL );

  /* Clean up after the write, and free any allocated memory. */
  png_destroy_write_struct( &png_ptr, &info_ptr );
  std::free( row_pointers );
  std::free( png_pixels );
  return true;
  }


/* Save the matrix as a greyscale PNG file; used to save intermediate
   (numbered if n >= 0) debug images.
   The file name is 'filename', followed by the decimal digits of 'n' if
   n >= 0, followed by the extension ".png".
   'bit_depth' and 'expand' are passed unchanged to the FILE * overload
   of write_png, together with an empty Color_info.

   Return true only if the file was opened, the image was written, and
   the file was closed without error.
   Any exception thrown by the FILE * overload is propagated after
   closing the file.
*/
bool Matrix::write_png( const std::string & filename, int n,
                        const unsigned bit_depth, const bool expand ) const
  {
  std::string name( filename );				// base file name
  if( n >= 0 )
    {
    /* Digits are produced least significant first; inserting each one at
       the same position leaves them in the right order. */
    const std::string::size_type pos = name.size();
    do { name.insert( name.begin() + pos, n % 10 + '0' ); n /= 10; }
    while( n > 0 );					// add number
    }
  name += ".png";					// add extension
  FILE * const f = std::fopen( name.c_str(), "wb" );
  if( !f ) return false;

  /* Write first, then close. The two calls must be separate statements
     because the operands of '&' may be evaluated in any order, which
     would allow the file to be closed before it is written. */
  bool written = false;
  try { written = write_png( f, bit_depth, Color_info(), expand ); }
  catch( ... ) { std::fclose( f ); throw; }		// don't leak the file
  const bool closed = ( std::fclose( f ) == 0 );	// close even if write fails
  return written && closed;
  }


int raw_to_png( const char * const input_filename, FILE * const infile,
                FILE * const f, const Raw_params & raw_params )
  {
  if( raw_params.bit_depth != 8 && raw_params.bit_depth != 16 )
    { show_error( "Invalid bit depth writing PNG image." ); return 1; }
  if( raw_params.channels != 1 && raw_params.channels != 3 )
    { show_error( "Invalid number of channels writing PNG image." ); return 1; }
  /* row_bytes is the width x number of channels x (bit-depth / 8) */
  const unsigned row_bytes = raw_params.width * raw_params.channels *
                             ((raw_params.bit_depth <= 8) ? 1 : 2);
  const unsigned long raw_size = raw_params.height * row_bytes;
  png_byte * png_pixels = new png_byte[raw_size];
  png_byte ** row_pointers = new png_byte *[raw_params.height];

  // fill png_pixels[] and set row_pointers
  if( std::fread( png_pixels, raw_size, 1, infile ) != 1 )
    { show_file_error( input_filename, "Error reading raw data." ); return 1; }
  for( unsigned i = 0; i < raw_params.height; ++i )
    row_pointers[i] = png_pixels + ( i * row_bytes );

  /* Create and initialize the png_struct with the desired error handler
   * functions.  If you want to use the default stderr and longjump method,
   * you can supply NULL for the last three parameters.  We also check that
   * the library version is compatible with the one used at compile time,
   * in case we are using dynamically linked libraries.  REQUIRED.
   */
  png_structp png_ptr =
    png_create_write_struct( PNG_LIBPNG_VER_STRING, NULL, NULL, NULL );
  if( !png_ptr )
    { delete[] row_pointers; delete[] png_pixels; throw std::bad_alloc(); }

  /* Allocate/initialize the image information data.  REQUIRED. */
  png_infop info_ptr = png_create_info_struct( png_ptr );
  if( !info_ptr )
    {
    png_destroy_write_struct( &png_ptr,  NULL );
    delete[] row_pointers;
    delete[] png_pixels;
    throw std::bad_alloc();
    }

  /* Set up error handling. */
  if( setjmp( png_jmpbuf( png_ptr ) ) )
    {
    /* If we get here, we had a problem writing the file. */
    png_destroy_write_struct( &png_ptr, &info_ptr );
    delete[] row_pointers;
    delete[] png_pixels;
    throw std::bad_alloc();
    }

  /* Set up the output control if you are using standard C streams. */
  png_init_io( png_ptr, f );

  png_set_IHDR( png_ptr, info_ptr, raw_params.width, raw_params.height,
                raw_params.bit_depth, ( raw_params.channels == 1 ) ?
                PNG_COLOR_TYPE_GRAY : PNG_COLOR_TYPE_RGB, PNG_INTERLACE_NONE,
                PNG_COMPRESSION_TYPE_BASE, PNG_FILTER_TYPE_BASE );

  png_set_rows( png_ptr, info_ptr, row_pointers );

  // write the PNG image
  png_write_png( png_ptr, info_ptr, PNG_TRANSFORM_IDENTITY, NULL );

  /* Clean up after the write, and free any allocated memory. */
  png_destroy_write_struct( &png_ptr, &info_ptr );
  delete[] row_pointers;
  delete[] png_pixels;
  return 0;
  }


int show_png_info( const char * const input_filename, FILE * const f,
                   const int sig_read )
  {
  if( verbosity >= 0 ) std::printf( "%s\n", input_filename );
  png_structp png_ptr =
    png_create_read_struct( PNG_LIBPNG_VER_STRING, NULL, NULL, NULL );
  if( !png_ptr ) throw std::bad_alloc();

  png_infop info_ptr = png_create_info_struct( png_ptr );
  if( !info_ptr )
    {
    png_destroy_read_struct( &png_ptr, NULL, NULL );	// avoid memory leak
    throw std::bad_alloc();
    }

  if( setjmp( png_jmpbuf( png_ptr ) ) )		// Set error handling
    {
    /* Free all of the memory associated with the png_ptr and info_ptr. */
    png_destroy_read_struct( &png_ptr, &info_ptr, NULL );
    throw Error( "Error reading PNG image." );
    }

  /* Set up the input control if you are using standard C streams. */
  png_init_io( png_ptr, f );

  /* If we have already read some of the signature */
  png_set_sig_bytes( png_ptr, sig_read );

  if( verbosity <= 0 ) png_read_info( png_ptr, info_ptr );	// read info
  else png_read_png( png_ptr, info_ptr,		// read the entire image
                     PNG_TRANSFORM_STRIP_ALPHA | PNG_TRANSFORM_EXPAND, NULL );

  /* Now let's print the data. */

  const unsigned height     = png_get_image_height( png_ptr, info_ptr );
  const unsigned width      = png_get_image_width( png_ptr, info_ptr );
  const long size           = height * width;
  const unsigned bit_depth  = png_get_bit_depth( png_ptr, info_ptr );
  const unsigned maxval     = ( 1 << bit_depth ) - 1;
  const unsigned color_type = png_get_color_type( png_ptr, info_ptr );
  const unsigned channels   = png_get_channels( png_ptr, info_ptr );
  const unsigned interlace_type = png_get_interlace_type( png_ptr, info_ptr );
  const char * ct;
  if( color_type == PNG_COLOR_TYPE_GRAY_ALPHA ) ct = "greyscale with alpha channel";
  else if( color_type == PNG_COLOR_TYPE_GRAY )      ct = "greyscale";
  else if( color_type == PNG_COLOR_TYPE_PALETTE )   ct = " colormap";
  else if( color_type == PNG_COLOR_TYPE_RGB )       ct = "RGB";
  else if( color_type == PNG_COLOR_TYPE_RGB_ALPHA ) ct = "RGB with alpha channel";
  else if( color_type == PNG_COLOR_MASK_PALETTE )   ct = "mask colormap";
  else if( color_type == PNG_COLOR_MASK_COLOR )     ct = "mask color";
  else if( color_type == PNG_COLOR_MASK_ALPHA )     ct = "mask alpha";
  else ct = "unknown color_type";

  if( verbosity >= 0 )
    std::printf( "  PNG image %4u x %4u (%5.2f megapixels), "
                 "%2u-bit %s, %u channel(s), %sinterlaced\n",
                 width, height, size / 1000000.0, bit_depth, ct, channels,
                 ( interlace_type == PNG_INTERLACE_NONE ) ? "non-" : "" );

  if( verbosity >= 1 )
    {
    const png_bytepp row_pointers = png_get_rows( png_ptr, info_ptr );

    double mean = 0;
    unsigned min = UINT_MAX, max = 0;			// find extremes
    std::vector< long > histogram( maxval + 1, 0 );	// range of pixel values [0,maxval]
    if( channels == 1 )
      for( unsigned row = 0; row < height; ++row )
        {
        const png_byte * ptr = row_pointers[row];
        for( unsigned col = 0; col < width; ++col )
          {
          unsigned val = *ptr++;
          if( bit_depth == 16 ) val = ( val << 8 ) + *ptr++;
          mean += val;
          ++histogram[val];
          if( min > val ) { min = val; } if( max < val ) max = val;
          }
        }
    else if( channels == 3 )
      for( unsigned row = 0; row < height; ++row )
        {
        const png_byte * ptr = row_pointers[row];
        for( unsigned col = 0; col < width; ++col )
          {
          unsigned r = *ptr++;
          if( bit_depth == 16 ) r = ( r << 8 ) + *ptr++;
          unsigned g = *ptr++;
          if( bit_depth == 16 ) g = ( g << 8 ) + *ptr++;
          unsigned b = *ptr++;
          if( bit_depth == 16 ) b = ( b << 8 ) + *ptr++;
          const unsigned val = (unsigned)nearbyint( ( r + g + b ) / 3.0 );
          mean += val;
          ++histogram[val];
          if( min > val ) { min = val; } if( max < val ) max = val;
          }
        }
    mean /= size;
    long mode_count = 0;
    unsigned mode = 0;					// statistical mode
    for( unsigned i = 0; i < histogram.size(); ++i )
      if( mode_count < histogram[i] ) { mode_count = histogram[i]; mode = i; }

    std::printf( "  min  = %5u      max  = %5u  range = %5u = %6.2f%% (maxval = %5u)\n"
                 "  mean = %9.3f  mode = %5u (%ld)\n",
                 min, max, max - min + 1,
                 ( 100.0 * ( max - min + 1 ) ) / ( maxval + 1 ), maxval,
                 mean, mode, mode_count );
    }

  png_destroy_read_struct( &png_ptr, &info_ptr, NULL );
  return 0;
  }
