/************************************************************************\
 * Magic Square solves magic squares.                                   *
 * Copyright (C) 2019  Asher Gordon <AsDaGo@posteo.net>                 *
 *                                                                      *
 * This file is part of Magic Square.                                   *
 *                                                                      *
 * Magic Square is free software: you can redistribute it and/or modify *
 * it under the terms of the GNU General Public License as published by *
 * the Free Software Foundation, either version 3 of the License, or    *
 * (at your option) any later version.                                  *
 *                                                                      *
 * Magic Square is distributed in the hope that it will be useful,      *
 * but WITHOUT ANY WARRANTY; without even the implied warranty of       *
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the        *
 * GNU General Public License for more details.                         *
 *                                                                      *
 * You should have received a copy of the GNU General Public License    *
 * along with Magic Square.  If not, see                                *
 * <https://www.gnu.org/licenses/>.                                     *
\************************************************************************/

/* square.c -- functions for operating on magic squares */

#ifdef HAVE_CONFIG_H
# include <config.h>
#endif

#include <stdio.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <errno.h>

#ifndef assert
# include <assert.h>
#endif

#include "square.h"
#include "write.h"

/* Static helper functions (defined at the end of this file, under the
   "Static helper functions" banner) */

/* Add two cell values; an invalid operand sets errno to EINVAL and is
   returned unchanged. */
static cellval_t add_cellval(cellval_t, cellval_t);
/* Backtracking search used by square_solve(); returns nonzero on
   error. */
static int square_solve_recursive(square_t *, cellval_t, square_list_t **,
				  size_t *, char, FILE *);
/* Returns 1 if every completed line matches the magic sum, 0 if one
   does not, and -1 on error. */
static int square_valid(square_t *, cellval_t);

/* Solve a magic square and return 0 on success and put all of the
   solutions in `*solutions' (which will be allocated) if `solutions'
   is not NULL. Store the size in `*solutions_size' if
   `solutions_size' is not NULL. If `keep_going' is nonzero, keep
   finding solutions even after the first match. If `file' is not
   NULL, print first the description of the square (if any) followed
   by each of the solutions. Returns nonzero on error (in which case
   `*solutions' should NOT be free()'d). `*solutions' also should NOT
   be free()'d in the event that no solutions are found. Note that all
   the elements of `*solutions' should be free()'d also with
   square_destroy(). */
int square_solve(square_t *square, square_t **solutions,
		 size_t *solutions_size, char keep_going, FILE *file) {
  square_list_t *solutions_list, *beginning;
  cellval_t sum; /* The sum that all the rows, columns, and center
		    diagonals add up to */
  size_t solutions_size_local, *solutions_size_local_ptr;

  solutions_size_local_ptr = solutions_size ?
    solutions_size : &solutions_size_local;

  *solutions_size_local_ptr = 0;

  /* Get the sum */
  sum.type = INT;
  sum.i = 0;

  /* Add all of the immutable squares */
  for (size_t x = 0; x < square->size; x++) {
    for (size_t y = 0; y < square->size; y++) {
      if (!square->cells[x][y].mutable) {
	sum = add_cellval(sum, square->cells[x][y].val);

	if (!cellval_valid(sum)) {
	  /* That means an error occurred */
	  return 1;
	}
      }
    }
  }

  /* Add the other numbers too */
  for (size_t i = 0; i < square->nums_size; i++) {
    sum = add_cellval(sum, square->nums[i]);

    if (!cellval_valid(sum)) {
      /* That means an error occurred */
      return 1;
    }
  }

  /* Divide the sum of all the numbers by the size of the square to
     get the magic sum. */
  if (sum.type == INT) {
    char is_neg = (sum.i < 0);

    if (is_neg) {
      /* Negative integer division and modulo arithmetic does not work
	 properly for some reason when using values computed at
	 run-time. This looks like a compiler bug, but we should work
	 around it for now. */
      sum.i = -sum.i;
    }

    if (sum.i % square->size) {
      /* If the size doesn't divide evenly into the sum of all the
	 numbers, the square is impossible to solve. Think about it:
	 If there are only integers, how do you expect to get a
	 non-integer sum by adding them together? */
      return 0;
    }

    sum.i = sum.i / square->size;

    if (is_neg)
      sum.i = -sum.i;
  }
  else {
    sum.f = sum.f / square->size;
  }

  /* If `solutions' is NULL, we aren't saving any solutions. */
  if (solutions) {
    /* Allocate the solutions list */
    solutions_list = malloc(sizeof(*solutions_list));
    /* Remember where the beginning is */
    beginning = solutions_list;
  }

  /* Print the description, if any and `file' isn't NULL. */
  if (square->description && file) {
    fputs(square->description, file);
    putc('\n', file);
  }

  /* Now solve the square! */
  if (square_solve_recursive(square, sum, solutions ? &solutions_list : NULL,
			     solutions_size_local_ptr, keep_going, file)) {
    /* An error occured */
    if (solutions) {
      while (beginning != solutions_list) {
	square_list_t *old = beginning;

	beginning = beginning->next;

	square_destroy(&(old->square));
	free(old);
      }

      free(solutions_list);
    }

    return 1;
  }

  if (solutions) {
    if (*solutions_size_local_ptr) {
      /* Convert the list to an array */
      *solutions = malloc(sizeof(**solutions) * *solutions_size_local_ptr);

      for (size_t i = 0; i < *solutions_size_local_ptr; i++) {
	square_list_t *old = beginning;

	(*solutions)[i] = beginning->square;
	beginning = beginning->next;

	free(old);
      }
    }

    assert(beginning == solutions_list);

    free(solutions_list);
  }

  return 0;
}

/* Return a deep copy of `square', including a freshly strdup()'d copy
   of its description when it has one. Everything else is copied by
   square_dup_nodesc(). The result should be released with
   square_destroy(). */
square_t square_dup(square_t square) {
  square_t copy = square_dup_nodesc(square);

  copy.description = square.description ? strdup(square.description)
                                         : copy.description;

  return copy;
}

/* Duplicate a square excluding the description. The copy gets its own
   freshly allocated `cells' (one array per column) and `nums' arrays;
   its `description' is NULL, and its `nums' is NULL when
   `square.nums_size' is 0. The result should be released with
   square_destroy().

   On allocation failure everything allocated so far is freed, errno
   is set to ENOMEM, and an empty square (size 0, nums_size 0, all
   pointers NULL) is returned; it is safe to pass to
   square_destroy(). */
square_t square_dup_nodesc(square_t square) {
  square_t new_square;
  size_t x;

  new_square.size        = square.size;
  new_square.nums_size   = square.nums_size;
  /* Don't duplicate it; that's for square_dup() */
  new_square.description = NULL;
  /* Never leave `nums' indeterminate, even when there are no numbers */
  new_square.nums        = NULL;

  new_square.cells = malloc(sizeof(*(new_square.cells)) * new_square.size);
  /* malloc(0) may legitimately return NULL, so only a NULL result for
     a nonzero size is a failure. */
  if (!new_square.cells && new_square.size)
    goto fail;

  for (x = 0; x < new_square.size; x++) {
    new_square.cells[x] = malloc(sizeof(*(new_square.cells[x])) *
                                 new_square.size);
    if (!new_square.cells[x])
      goto fail_columns;

    memcpy(new_square.cells[x], square.cells[x],
           sizeof(*(new_square.cells[x])) * new_square.size);
  }

  if (new_square.nums_size) {
    new_square.nums = malloc(sizeof(*(new_square.nums)) *
                             new_square.nums_size);
    if (!new_square.nums)
      goto fail_columns;

    memcpy(new_square.nums, square.nums,
           sizeof(*(new_square.nums)) * new_square.nums_size);
  }

  return new_square;

 fail_columns:
  /* Free columns 0 .. x-1, the ones successfully allocated so far */
  while (x--)
    free(new_square.cells[x]);
  free(new_square.cells);

 fail:
  new_square.size      = 0;
  new_square.nums_size = 0;
  new_square.cells     = NULL;
  new_square.nums      = NULL;
  errno = ENOMEM;

  return new_square;
}

/* Release the memory owned by `square': its description, every
   column array in `cells' plus `cells' itself, and `nums' when
   `nums_size' is nonzero. The square_t object itself is not freed. */
void square_destroy(square_t *square) {
  size_t column = square->size;

  /* free(NULL) is a no-op, so a missing description needs no check */
  free(square->description);

  while (column > 0)
    free(square->cells[--column]);

  free(square->cells);

  /* `nums' may not point anywhere when there are no numbers, so it is
     only freed when `nums_size' says it holds something. */
  if (square->nums_size != 0)
    free(square->nums);
}

/***************************\
|* Static helper functions *|
\***************************/

/* Return `a' + `b'. INT + INT stays an INT; INT + FLOAT (in either
   order) yields a FLOAT. If an operand is not a valid cell value,
   errno is set to EINVAL and the first invalid operand (`a' is
   checked before `b') is returned unchanged, so callers can detect
   the failure with cellval_valid() on the result. */
static cellval_t add_cellval(cellval_t a, cellval_t b) {
  const cellval_t *invalid = NULL;

  if (!cellval_valid(a))
    invalid = &a;
  else if (!cellval_valid(b))
    invalid = &b;

  if (invalid) {
    errno = EINVAL;
    return *invalid;
  }

  if (a.type == FLOAT) {
    /* Accumulate into the float member whatever `b' holds */
    a.f += (b.type == INT) ? b.i : b.f;
  }
  else if (a.type == INT) {
    if (b.type == INT) {
      a.i += b.i;
    }
    else if (b.type == FLOAT) {
      /* Mixed operands promote the result to FLOAT */
      a.type = FLOAT;
      a.f = a.i + b.f;
    }
  }

  return a;
}

/* Solve a square adding the solutions (if found) to
   `solutions'. Returns 0 on success or nonzero on error. */
static int square_solve_recursive(square_t *square, cellval_t sum,
				  square_list_t **solutions,
				  size_t *solutions_size, char keep_going,
				  FILE *file) {
  char all_immutable = 1; /* Whether all of the cells are immutable */
  int valid;

  if (!keep_going && *solutions_size > 0) {
    /* We've already found a solution and we weren't asked to find
       more. */
    return 0;
  }

  valid = square_valid(square, sum);

  if (valid < 0) {
    /* Error */
    return 1;
  }

  if (!valid) {
    /* We needn't check this square; it's already invalid */
    return 0;
  }

  for (size_t x = 0; x < square->size && all_immutable; x++) {
    for (size_t y = 0; y < square->size && all_immutable; y++) {
      if (square->cells[x][y].mutable) {
	cellval_t *nums;

	if (!square->nums_size) {
	  /* That's an error */
	  errno = EINVAL;
	  return 1;
	}

	all_immutable = 0;

	/* Temporarily make the cell immutable */
	square->cells[x][y].mutable = 0;

	/* Make a copy of `square->nums' so we can delete certain numbers */
	nums = malloc(sizeof(*nums) * square->nums_size);

	for (size_t i = 0; i < square->nums_size; i++) {
	  nums[i] = square->nums[i];
	}

	/* Decrement `nums_size' by one since we will delete one
	   number from `square->nums' */
	square->nums_size--;

	/* Try to solve the square with this cell replaced with each
	   of the other numbers (remember, `square->nums_size + 1' is
	   the actual size of `nums') */
	for (size_t i = 0; i < square->nums_size + 1; i++) {
	  square->cells[x][y].val = nums[i];

	  /* Delete the number from `square->nums' */
	  square->nums[i] = square->nums[square->nums_size];

	  /* Try to solve it */
	  if (square_solve_recursive(square, sum, solutions, solutions_size,
				     keep_going, file)) {
	    /* Error */
	    square->nums[i] = nums[i];
	    square->nums_size++;
	    free(nums);

	    return 1;
	  }

	  /* Put the number back */
	  square->nums[i] = nums[i];
	}

	/* We're done with this now */
	free(nums);

	/* And don't forget to reset `square->nums_size'! */
	square->nums_size++;

	/* Reset the cell to mutable */
	square->cells[x][y].mutable = 1;
      }
    }
  }

  if (all_immutable) {
    /* Hooray! We've found a solution! Add it to `solutions' if it's
       not NULL. */
    if (solutions) {
      (*solutions)->square = square_dup_nodesc(*square);
      (*solutions)->next = malloc(sizeof(*((*solutions)->next)));
      (*solutions) = (*solutions)->next;
    }

    (*solutions_size)++;

    if (file) {
      char *description = square->description;

      /* Write the solution to `file' */

      /* Temporarily get rid of the description */
      square->description = NULL;

      if (*solutions_size > 1 || description) {
	/* Separate this solution from the last solution or the
	   description */
	putc('\n', file);
      }

      if (keep_going)
	fprintf(file, "Solution %zu:\n\n", *solutions_size);

      if (!write_human(square, file)) {
	/* Error! */
	square->description = description;
	return 1;
      }

      /* Replace the description */
      square->description = description;
    }
  }

  return 0;
}

/* Check one line of `square->size' cells against the magic sum
   `sum'. The line starts at cell (`x', `y'), and each step adds `dx'
   to x and `dy' to y. The steps are unsigned; passing (size_t)-1 walks
   backwards through well-defined unsigned wraparound. Returns 1 if
   the line contains a mutable cell (it is incomplete and cannot be
   judged yet) or adds up to `sum', 0 if it is complete and adds up to
   something else, and -1 if adding the cells failed. */
static int square_line_valid(square_t *square, cellval_t sum,
                             size_t x, size_t y, size_t dx, size_t dy) {
  cellval_t check_sum;

  check_sum.type = INT;
  check_sum.i = 0;

  for (size_t k = 0; k < square->size; k++, x += dx, y += dy) {
    if (square->cells[x][y].mutable)
      return 1;

    check_sum = add_cellval(check_sum, square->cells[x][y].val);

    /* Check the running total, not `sum': add_cellval() signals an
       error by returning an invalid value. */
    if (!cellval_valid(check_sum))
      return -1;
  }

  return cellval_equal(check_sum, sum) ? 1 : 0;
}

/* Check whether `square' can still lead to a solution: every column,
   every row, and both center diagonals that contain no mutable cells
   must add up to `sum'. Lines are checked in that order, and the first
   line that fails decides the result. Returns 1 if valid, 0 if
   invalid, and -1 on error. */
static int square_valid(square_t *square, cellval_t sum) {
  size_t n = square->size;
  int valid;

  /* Check each column: cells[x][0 .. n-1] */
  for (size_t x = 0; x < n; x++) {
    valid = square_line_valid(square, sum, x, 0, 0, 1);
    if (valid != 1)
      return valid;
  }

  /* Check each row: cells[0 .. n-1][y] */
  for (size_t y = 0; y < n; y++) {
    valid = square_line_valid(square, sum, 0, y, 1, 0);
    if (valid != 1)
      return valid;
  }

  /* Check the upper left to lower right diagonal: cells[i][i] */
  valid = square_line_valid(square, sum, 0, 0, 1, 1);
  if (valid != 1)
    return valid;

  /* Check the upper right to lower left diagonal: cells[n-1-i][i] */
  return square_line_valid(square, sum, n - 1, 0, (size_t)-1, 1);
}
