/* This file is part of GNU Libraries and Engines for Games  -*- c++ -*-

   $Id: matrix.h,v 1.5 2004/04/30 20:15:54 jechk Exp $
   $Log: matrix.h,v $
   Revision 1.5  2004/04/30 20:15:54  jechk
   Big merge.  See ChangeLog for details.

   Revision 1.4  2004/03/11 06:44:52  jechk
   Made the Sq function generic; added some TODO comments.

   Revision 1.3  2004/03/08 22:26:36  jechk
   Maths update.  Mainly, added Polynomial class.

   Revision 1.2  2004/03/03 03:50:02  jechk
   Changed some names, comments and other things for consistency.

   Revision 1.1  2004/03/03 02:05:22  jechk
   Merged many changes.  See ChangeLog for details.


   Created 1/20/04 by Jeff Binder <bindej@rpi.edu>
   
   Copyright (c) 2003, 2004 Free Software Foundation
   
   This program is free software; you can redistribute it and/or
   modify it under the terms of the GNU General Public
   License as published by the Free Software Foundation; either
   version 2.1 of the License, or (at your option) any later version.
   
   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
   General Public License for more details.
   
   You should have received a copy of the GNU General Public
   License along with this program; if not, write to the Free Software
   Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
*/
/*! \file leg/support/maths/matrix.h
  \brief A class representing transformation matrices.
*/

#include <cmath>
#include <cstring>
#include <iomanip>
#include <iostream>

#include <leg/support/maths/vector.h>
#include <leg/support/maths/quaternion.h>

#ifndef LEG_SUPPORT_MATHS_MATRIX_H
#define LEG_SUPPORT_MATHS_MATRIX_H

namespace leg
{

namespace support
{

namespace maths
{

  // TODO: Implement the remaining operations (e.g. the decompositions below,
  // Gauss-Jordan elimination) and make the bounds checking optional.

template<unsigned int m, unsigned int n>
class BasicMatrix;

//! An m x n matrix of real numbers.
/*!
 * All element storage and arithmetic live in BasicMatrix<m, n>; this
 * primary template only supplies constructors.  The 1 x 1 and 4 x 4 sizes
 * have explicit specializations later in this header.
 */
template<unsigned int m, unsigned int n>
class Matrix : public BasicMatrix<m, n>
{
  typedef BasicMatrix<m, n> Base;

public:
  //! Default constructor; the elements are left uninitialized.
  Matrix ()
    : Base ()
  {
  }

  //! Build a Matrix from any BasicMatrix of the same size (implicit).
  Matrix (const BasicMatrix<m, n> &other)
    : Base (other)
  {
  }

  //! Copy constructor.
  Matrix (const Matrix<m, n> &other)
    : Base (other)
  {
  }
};

//! Storage and generic operations shared by every Matrix<m, n>.
/*!
 * The m * n elements are kept in row-major order: element (i, j) lives at
 * el[i * n + j].  Matrix<m, n> (and its specializations) derive from this
 * class, so arithmetic results are returned as Matrix<m, n> values.
 * Element access through operator () is bounds-checked and throws
 * utils::Error when an index is out of range.
 */
template<unsigned int m, unsigned int n>
class BasicMatrix
{
 protected:
  //! Row-major element storage; element (i, j) is el[i * n + j].
  real el[m * n];

 public:
  //! Does not initialize the elements.
  BasicMatrix ()
  {
  }

  //! Copy constructor.
  BasicMatrix (const BasicMatrix<m, n> &a)
  {
    std::memcpy (el, a.el, m * n * sizeof *el);
  }

  //! Construct from a Matrix of the same dimensions.
  BasicMatrix (const Matrix<m, n> &a)
  {
    std::memcpy (el, a.el, m * n * sizeof *el);
  }

  ~BasicMatrix ()
  {
  }

  //! Access the specified element of the matrix.
  /*! Throws utils::Error if i >= m or j >= n. */
  real&
  operator () (unsigned int i, unsigned int j)
  {
    if (i >= m || j >= n)
      {
        throw utils::Error ("Matrix index out of bounds");
      }

    return el[i * n + j];
  }

  //! Read the specified element of the matrix.
  /*! Throws utils::Error if i >= m or j >= n. */
  real
  operator () (unsigned int i, unsigned int j) const
  {
    if (i >= m || j >= n)
      {
        throw utils::Error ("Matrix index out of bounds");
      }

    return el[i * n + j];
  }

  //! Returns true if the matrix is a square matrix.
  bool
  IsSquare () const
  {
    return m == n;
  }

  //! Returns the kth row in vector form.
  Vector<n>
  Row (unsigned int k) const
  {
    Vector<n> v;

    for (unsigned int i = 0; i < n; ++i)
      {
        v (i) = (*this) (k, i);
      }

    return v;
  }

  //! Returns the kth column in vector form.
  Vector<m>
  Column (unsigned int k) const
  {
    Vector<m> v;

    for (unsigned int i = 0; i < m; ++i)
      {
        v (i) = (*this) (i, k);
      }

    return v;
  }

  //! Return the matrix with the ith row and jth column dropped.
  /*! Throws utils::Error if i >= m or j >= n. */
  Matrix<m-1, n-1>
  Minor (unsigned int i, unsigned int j) const
  {
    Matrix<m-1, n-1> a;

    if (i >= m || j >= n)
      {
        throw utils::Error ("Matrix index out of bounds");
      }

    // Once the dropped row/column has been passed, later source indices
    // shift down by one in the result.
    bool passed_row = false;
    for (unsigned int ii = 0; ii < m; ++ii)
      {
        if (ii == i)
          {
            passed_row = true;
          }
        else
          {
            bool passed_col = false;
            for (unsigned int jj = 0; jj < n; ++jj)
              {
                if (jj == j)
                  {
                    passed_col = true;
                  }
                else
                  {
                    a (passed_row? ii-1: ii,
                       passed_col? jj-1: jj)
                      = (*this) (ii, jj);
                  }
              }
          }
      }

    return a;
  }

  //! Returns the (i, j) cofactor: the signed determinant of Minor (i, j).
  real
  Cofactor (unsigned int i, unsigned int j) const
  {
    return Minor (i, j).Determinant () * (((i + j) % 2)? -1.: 1.);
  }

  //! Return the determinant of the matrix.
  /*!
   * Uses cofactor (Laplace) expansion along the first row.  The recursion
   * through Minor () ends at the Matrix<1, 1> specialization, which
   * supplies its own Determinant.  Only meaningful for square matrices.
   */
  real
  Determinant () const
  {
    real det = 0.;

    // The first row has n entries (equal to m for square matrices).
    for (unsigned int j = 0; j < n; ++j)
      {
        // Zero entries contribute nothing; skip their costly cofactors.
        if (el[j])
          {
            det += el[j] * Cofactor (0, j);
          }
      }

    return det;
  }

  //! Returns the multiplicative inverse of the matrix.
  /*!
   * Computed as the adjugate (transposed cofactor matrix) divided by the
   * determinant.  Throws utils::Error if the determinant is exactly zero.
   */
  Matrix<n, m>
  Inverse () const
  {
    Matrix<n, m> a;
    real det = Determinant ();

    if (det == 0.)
      {
        throw utils::Error ("Matrix is not invertible.");
      }
    else
      {
        det = 1. / det;
      }
    for (unsigned int i = 0; i < m; ++i)
      {
        for (unsigned int j = 0; j < n; ++j)
          {
            a (j, i) = Cofactor (i, j) * det;
          }
      }

    return a;
  }

  //! LU decomposition.  TODO: not implemented; leaves l, d and u untouched.
  void
  GetLUDecomposition (Matrix<m, n> &l, Matrix<m, n> &d, Matrix<m, n> &u) const
  {
    // TODO
  }

  //! QR decomposition.  TODO: not implemented; leaves q and r untouched.
  void
  GetQRDecomposition (Matrix<m, n> &q, Matrix<m, n> &r) const
  {
    // TODO
  }

  //! Singular value decomposition.  TODO: not implemented; outputs untouched.
  void
  GetSVDecomposition (Matrix<m, n> &q1, Matrix<m, n> &sigma, Matrix<m, n> &q2) const
  {
    // TODO
  }

  //! Eigendecomposition.  TODO: not implemented; leaves q and lambda untouched.
  void
  GetEigenDecomposition (Matrix<m, n> &q, Matrix<m, n> &lambda) const
  {
    // TODO
  }

  //! Return a transposed copy of the matrix.
  Matrix<n, m>
  Transpose () const
  {
    Matrix<n, m> a;

    for (unsigned int i = 0; i < m; ++i)
      {
        for (unsigned int j = 0; j < n; ++j)
          {
            a (j, i) = (*this) (i, j);
          }
      }

    return a;
  }

  //! Return the trace of the matrix (the sum of its diagonal components).
  /*! Throws utils::Error for non-square matrices. */
  real
  Trace () const
  {
    real t = 0.;

    if (m != n)
      {
        throw utils::Error ("Attempt to take the Trace of a non-square matrix");
      }
    for (unsigned int i = 0; i < m; ++i)
      {
        t += (*this) (i, i);
      }

    return t;
  }

  //! Assignment operator; returns a copy of the assigned matrix.
  Matrix<m, n>
  operator = (const Matrix<m, n> &r)
  {
    // memcpy is undefined for overlapping buffers, so skip self-assignment.
    if (static_cast<const BasicMatrix<m, n> *> (&r) != this)
      {
        std::memcpy (el, r.el, m * n * sizeof *el);
      }

    // Convert via Matrix's BasicMatrix constructor rather than downcasting
    // `this`, which would be undefined for a standalone BasicMatrix.
    return *this;
  }

  /// Matrix/real multiplication.
  Matrix<m, n>
  operator * (real r) const
  {
    Matrix<m, n> a;

    for (unsigned int i = 0; i < m * n; ++i)
      {
        a.el[i] = el[i] * r;
      }

    return a;
  }

  /// Destructive matrix/real multiplication; returns the updated matrix.
  Matrix<m, n>
  operator *= (real r)
  {
    for (unsigned int i = 0; i < m * n; ++i)
      {
        el[i] *= r;
      }

    return *this;
  }

  /// Matrix/real division (implemented as multiplication by 1 / r).
  Matrix<m, n>
  operator / (real r) const
  {
    Matrix<m, n> a;

    r = 1. / r;
    for (unsigned int i = 0; i < m * n; ++i)
      {
        a.el[i] = el[i] * r;
      }

    return a;
  }

  /// Destructive matrix/real division; returns the updated matrix.
  Matrix<m, n>
  operator /= (real r)
  {
    r = 1. / r;
    for (unsigned int i = 0; i < m * n; ++i)
      {
        el[i] *= r;
      }

    return *this;
  }

  //! Perform matrix multiplication: (m x n) * (n x k) -> (m x k).
  template<unsigned int k>
  Matrix<m, k>
  operator * (const Matrix<n, k> &a) const
  {
    Matrix<m, k> b;

    for (unsigned int i = 0; i < m; ++i)
      {
        for (unsigned int j = 0; j < k; ++j)
          {
            b (i, j) = Row (i).Dot (a.Column (j));
          }
      }

    return b;
  }

  //! Multiplication and assignment operator (this = this * a).
  Matrix<m, n>
  operator *= (const Matrix<m, n> &a)
  {
    return (*this) = (*this) * a;
  }

  //! Perform matrix addition.
  Matrix<m, n>
  operator + (const Matrix<m, n> &a) const
  {
    Matrix<m, n> b (*this);

    for (unsigned int i = 0; i < m; ++i)
      {
        for (unsigned int j = 0; j < n; ++j)
          {
            b (i, j) += a (i, j);
          }
      }

    return b;
  }

  //! Addition and assignment operator.
  Matrix<m, n>
  operator += (const Matrix<m, n> &a)
  {
    for (unsigned int i = 0; i < m; ++i)
      {
        for (unsigned int j = 0; j < n; ++j)
          {
            (*this) (i, j) += a (i, j);
          }
      }

    return *this;
  }

  //! Perform matrix subtraction.
  Matrix<m, n>
  operator - (const Matrix<m, n> &a) const
  {
    Matrix<m, n> b (*this);

    for (unsigned int i = 0; i < m; ++i)
      {
        for (unsigned int j = 0; j < n; ++j)
          {
            b (i, j) -= a (i, j);
          }
      }

    return b;
  }

  //! Subtraction and assignment operator.
  Matrix<m, n>
  operator -= (const Matrix<m, n> &a)
  {
    for (unsigned int i = 0; i < m; ++i)
      {
        for (unsigned int j = 0; j < n; ++j)
          {
            (*this) (i, j) -= a (i, j);
          }
      }

    return *this;
  }

  //! Perform matrix-vector multiplication.
  Vector<m>
  operator * (const Vector<n> &v) const
  {
    Vector<m> u;

    for (unsigned int i = 0; i < m; ++i)
      {
        u (i) = Row (i).Dot (v);
      }

    return u;
  }
};

//! Specialization for 1 x 1 matrices.
/*!
 * Supplies the base case for BasicMatrix::Determinant's cofactor
 * recursion.  It also provides a direct Inverse: the generic
 * cofactor-based BasicMatrix::Inverse would recurse into 0 x 0 minors.
 */
template<>
class Matrix<1, 1> : public BasicMatrix<1, 1>
{
public:
  //! Copy constructor.
  Matrix (const BasicMatrix<1, 1> &a)
    : BasicMatrix<1, 1> (a)
  {
  }

  /// Copy constructor.
  Matrix (const Matrix<1, 1> &a)
    : BasicMatrix<1, 1> (a)
  {
  }

  //! Does not initialize the elements.
  Matrix ()
    : BasicMatrix<1, 1> ()
  {
  }

  //! Return the determinant, which for a 1 x 1 matrix is its only element.
  real
  Determinant () const
  {
    return el[0];
  }

  //! Return the multiplicative inverse of the matrix.
  /*! Throws utils::Error if the single element is zero. */
  Matrix<1, 1>
  Inverse () const
  {
    if (el[0] == 0.)
      {
        throw utils::Error ("Matrix is not invertible.");
      }

    Matrix<1, 1> a;
    a.el[0] = 1. / el[0];

    return a;
  }
};

//! Specialization for 4 x 4 matrices (3D transforms in homogeneous form).
/*!
 * Adds element-wise constructors, a quaternion constructor, closed-form
 * Determinant and Inverse, and in-place Translate / Scale / Rotate helpers.
 * Each helper right-multiplies this matrix by the corresponding transform
 * (this = this * T).
 */
template<>
class Matrix<4, 4> : public BasicMatrix<4, 4>
{
public:
  //! Copy constructor.
  Matrix (const BasicMatrix<4, 4> &a)
    : BasicMatrix<4, 4> (a)
  {
  }

  /// Copy constructor.
  Matrix (const Matrix<4, 4> &a)
    : BasicMatrix<4, 4> (a)
  {
  }

  //! Does not initialize the elements.
  Matrix ()
    : BasicMatrix<4, 4> ()
  {
  }

  //! Creates a new matrix from an array of 16 elements in row-major order.
  /*!
   * n[0..3] is the first row, n[4..7] the second, and so on.  The caller
   * must supply at least 16 readable elements; no checking is done.  This
   * is faster than the 16-argument constructor below, but not safer.
   * (Added by Jdf on Dec 25 2004.)
   */
  Matrix (const real *n)
    : BasicMatrix<4, 4> ()
  {
    std::memcpy (el, n, 16 * sizeof (real));
  }

  //! Creates a new matrix with the specified elements.
  /*!
   * Note that the elements are specified in row-major order.  That is, the
   * first digit in the name of the argument is the row, and the second is
   * the column.
   */
  Matrix (real n11, real n12, real n13, real n14,
          real n21, real n22, real n23, real n24,
          real n31, real n32, real n33, real n34,
          real n41, real n42, real n43, real n44)
    : BasicMatrix<4, 4> ()
  {
    el[0] = n11;
    el[1] = n12;
    el[2] = n13;
    el[3] = n14;

    el[4] = n21;
    el[5] = n22;
    el[6] = n23;
    el[7] = n24;

    el[8] = n31;
    el[9] = n32;
    el[10] = n33;
    el[11] = n34;

    el[12] = n41;
    el[13] = n42;
    el[14] = n43;
    el[15] = n44;
  }

  //! Creates a new matrix representing the rotation represented by a quaternion.
  /*!
   * The upper-left 3 x 3 block is the rotation; the last row and column are
   * those of the identity.  NOTE(review): the formula treats q.a as the
   * scalar part and relies on q having unit norm -- confirm against
   * Quaternion's conventions.
   */
  Matrix (Quaternion q)
    : BasicMatrix<4, 4> ()
  {
    el[0] = 1 - 2 * (Sq (q.c) + Sq (q.d));
    el[1] = 2 * (q.b * q.c - q.a * q.d);
    el[2] = 2 * (q.a * q.c + q.b * q.d);
    el[3] = 0;

    el[4] = 2 * (q.b * q.c + q.a * q.d);
    el[5] = 1 - 2 * (Sq (q.b) + Sq (q.d));
    el[6] = 2 * (q.c * q.d - q.a * q.b);
    el[7] = 0;

    el[8] = 2 * (q.b * q.d - q.a * q.c);
    el[9] = 2 * (q.c * q.d + q.a * q.b);
    el[10] = 1 - 2 * (Sq (q.b) + Sq (q.c));
    el[11] = 0;

    el[12] = 0;
    el[13] = 0;
    el[14] = 0;
    el[15] = 1;
  }

  //! Return the determinant, by cofactor expansion along the first row.
  /*! Each minor_det is the explicit 3 x 3 determinant of the matching minor. */
  real
  Determinant () const
  {
    real det = 0, minor_det;

    minor_det = (el[5] * el[10] * el[15] + el[6] * el[11] * el[13] + el[7] * el[9] * el[14]
                 - el[5] * el[11] * el[14] - el[6] * el[9] * el[15]
                 - el[7] * el[10] * el[13]);
    det += minor_det * el[0];

    minor_det = (el[4] * el[10] * el[15] + el[6] * el[11] * el[12] + el[7] * el[8] * el[14]
                 - el[4] * el[11] * el[14] - el[6] * el[8] * el[15]
                 - el[7] * el[10] * el[12]);
    det -= minor_det * el[1];

    minor_det = (el[4] * el[9] * el[15] + el[5] * el[11] * el[12] + el[7] * el[8] * el[13]
                 - el[4] * el[11] * el[13] - el[5] * el[8] * el[15]
                 - el[7] * el[9] * el[12]);
    det += minor_det * el[2];

    minor_det = (el[4] * el[9] * el[14] + el[5] * el[10] * el[12] + el[6] * el[8] * el[13]
                 - el[4] * el[10] * el[13] - el[5] * el[8] * el[14]
                 - el[6] * el[9] * el[12]);
    det -= minor_det * el[3];

    return det;
  }

  //! Return the inverse of the matrix.
  /*!
   * Cramer's rule: the adjugate is assembled from precomputed products of
   * element pairs (first for the upper two rows of the result, then for the
   * lower two), and the determinant is recovered from the first row of the
   * adjugate.  Throws utils::Error if the determinant is exactly zero.
   */
  Matrix<4, 4>
  Inverse () const
  {
    real pairs[12], det;
    Matrix inv;

    // Pair products used by the first two rows of the adjugate.
    pairs[0] = el[10] * el[15];
    pairs[1] = el[14] * el[11];
    pairs[2] = el[6] * el[15];
    pairs[3] = el[14] * el[7];
    pairs[4] = el[6] * el[11];
    pairs[5] = el[10] * el[7];
    pairs[6] = el[2] * el[15];
    pairs[7] = el[14] * el[3];
    pairs[8] = el[2] * el[11];
    pairs[9] = el[10] * el[3];
    pairs[10] = el[2] * el[7];
    pairs[11] = el[6] * el[3];

    inv (0, 0) =
      (pairs[0] * el[5] + pairs[3] * el[9] +
       pairs[4] * el[13]) - (pairs[1] * el[5] +
                              pairs[2] * el[9] +
                              pairs[5] * el[13]);
    inv (0, 1) =
      (pairs[1] * el[1] + pairs[6] * el[9] +
       pairs[9] * el[13]) - (pairs[0] * el[1] +
                              pairs[7] * el[9] +
                              pairs[8] * el[13]);
    inv (0, 2) =
      (pairs[2] * el[1] + pairs[7] * el[5] +
       pairs[10] * el[13]) - (pairs[3] * el[1] +
                               pairs[6] * el[5] +
                               pairs[11] * el[13]);
    inv (0, 3) =
      (pairs[5] * el[1] + pairs[8] * el[5] +
       pairs[11] * el[9]) - (pairs[4] * el[1] +
                              pairs[9] * el[5] +
                              pairs[10] * el[9]);
    inv (1, 0) =
      (pairs[1] * el[4] + pairs[2] * el[8] +
       pairs[5] * el[12]) - (pairs[0] * el[4] +
                              pairs[3] * el[8] +
                              pairs[4] * el[12]);
    inv (1, 1) =
      (pairs[0] * el[0] + pairs[7] * el[8] +
       pairs[8] * el[12]) - (pairs[1] * el[0] +
                              pairs[6] * el[8] +
                              pairs[9] * el[12]);
    inv (1, 2) =
      (pairs[3] * el[0] + pairs[6] * el[4] +
       pairs[11] * el[12]) - (pairs[2] * el[0] +
                               pairs[7] * el[4] +
                               pairs[10] * el[12]);
    inv (1, 3) =
      (pairs[4] * el[0] + pairs[9] * el[4] +
       pairs[10] * el[8]) - (pairs[5] * el[0] +
                              pairs[8] * el[4] +
                              pairs[11] * el[8]);

    // Pair products used by the last two rows of the adjugate.
    pairs[0] = el[8] * el[13];
    pairs[1] = el[12] * el[9];
    pairs[2] = el[4] * el[13];
    pairs[3] = el[12] * el[5];
    pairs[4] = el[4] * el[9];
    pairs[5] = el[8] * el[5];
    pairs[6] = el[0] * el[13];
    pairs[7] = el[12] * el[1];
    pairs[8] = el[0] * el[9];
    pairs[9] = el[8] * el[1];
    pairs[10] = el[0] * el[5];
    pairs[11] = el[4] * el[1];

    inv (2, 0) =
      (pairs[0] * el[7] + pairs[3] * el[11] +
       pairs[4] * el[15]) - (pairs[1] * el[7] +
                              pairs[2] * el[11] +
                              pairs[5] * el[15]);
    inv (2, 1) =
      (pairs[1] * el[3] + pairs[6] * el[11] +
       pairs[9] * el[15]) - (pairs[0] * el[3] +
                              pairs[7] * el[11] +
                              pairs[8] * el[15]);
    inv (2, 2) =
      (pairs[2] * el[3] + pairs[7] * el[7] +
       pairs[10] * el[15]) - (pairs[3] * el[3] +
                               pairs[6] * el[7] +
                               pairs[11] * el[15]);
    inv (2, 3) =
      (pairs[5] * el[3] + pairs[8] * el[7] +
       pairs[11] * el[11]) - (pairs[4] * el[3] +
                               pairs[9] * el[7] +
                               pairs[10] * el[11]);
    inv (3, 0) =
      (pairs[2] * el[10] + pairs[5] * el[14] +
       pairs[1] * el[6]) - (pairs[4] * el[14] +
                             pairs[0] * el[6] +
                             pairs[3] * el[10]);
    inv (3, 1) =
      (pairs[8] * el[14] + pairs[0] * el[2] +
       pairs[7] * el[10]) - (pairs[6] * el[10] +
                              pairs[9] * el[14] +
                              pairs[1] * el[2]);
    inv (3, 2) =
      (pairs[6] * el[6] + pairs[11] * el[14] +
       pairs[3] * el[2]) - (pairs[10] * el[14] +
                             pairs[2] * el[2] +
                             pairs[7] * el[6]);
    inv (3, 3) =
      (pairs[10] * el[10] + pairs[4] * el[2] +
       pairs[2] * el[6]) - (pairs[8] * el[6] +
                             pairs[11] * el[10] +
                             pairs[5] * el[2]);

    // Expansion of the determinant using the first row of the adjugate.
    det = el[0] * inv (0, 0) + el[4] * inv (0, 1) +
      el[8] * inv (0, 2) + el[12] * inv (0, 3);

    if (!det)
      {
        throw utils::Error ("No inverse exists.");
      }

    return inv / det;
  }

  //! Create a translation matrix and multiply it into the matrix.
  /*!
   * Right-multiplies by a translation by v: only the last column changes,
   * becoming this * (v.x, v.y, v.z, 1).
   */
  void
  Translate (const Vector<3> &v)
  {
    el[3] = el[0] * v.x + el[1] * v.y + el[2] * v.z + el[3];
    el[7] = el[4] * v.x + el[5] * v.y + el[6] * v.z + el[7];
    el[11] = el[8] * v.x + el[9] * v.y + el[10] * v.z + el[11];
    el[15] = el[12] * v.x + el[13] * v.y + el[14] * v.z + el[15];
  }

  //! Create a scaling matrix and multiply it into the matrix.
  /*!
   * Right-multiplies by diag (v.x, v.y, v.z, 1).  That scales every element
   * of column 0 by v.x, column 1 by v.y and column 2 by v.z.  Scaling only
   * the diagonal would be correct only when the matrix is itself diagonal.
   */
  void
  Scale (const Vector<3> &v)
  {
    for (unsigned int row = 0; row < 4; ++row)
      {
        el[row * 4 + 0] *= v.x;
        el[row * 4 + 1] *= v.y;
        el[row * 4 + 2] *= v.z;
      }
  }

  //! Create a rotation matrix and multiply it into the matrix.
  /*!
   * Right-multiplies by a rotation of `angle` (passed to cos/sin) about
   * the axis v.  The axis is not normalized here; pass a unit vector.
   */
  void
  Rotate (real angle, const Vector<3> &v)
  {
    real c = cos (angle);
    real s = sin (angle);

    Matrix<4, 4> rot (v.x * v.x * (1. - c) + c, v.x * v.y * (1. - c) - v.z * s,
                      v.x * v.z * (1. - c) + v.y * s, 0.,
                      v.y * v.x * (1. - c) + v.z * s, v.y * v.y * (1. - c) + c,
                      v.y * v.z * (1. - c) - v.x * s, 0.,
                      v.z * v.x * (1. - c) - v.y * s, v.z * v.y * (1. - c) + v.x * s,
                      v.z * v.z * (1. - c) + c, 0., 0., 0., 0., 1.);
    *this *= rot;
  }

  //! Multiply with a 3-element vector using homogeneous coordinates.
  /*!
   * Extends v to homogeneous form, multiplies, and divides the first three
   * components of the result by the fourth.
   */
  Vector<3>
  HomogeneousMultiply (const Vector<3> &v) const
  {
    Vector<4> vh = *this * v.GetHomogeneousVector ();
    return Vector<3> (vh (0), vh (1), vh (2)) / vh (3);
  }
  /*
  void Set (real *v)
  {
     std::memcpy (&el[0],&v[0],16*sizeof (real));
  }
  */
};


//! Create an identity matrix.
/*!
 * Returns an m x n matrix with 1 at every (i, i) position and 0 elsewhere.
 * For non-square sizes, this gives ones along the leading diagonal.
 */
template<unsigned int m, unsigned int n>
inline Matrix<m, n>
IdentityMatrix ()
{
  Matrix<m, n> a;

  for (unsigned int i = 0; i < m; ++i)
    {
      for (unsigned int j = 0; j < n; ++j)
        {
          // Compare the indices, not the dimensions: only the diagonal is 1.
          a (i, j) = (i == j)? 1.: 0.;
        }
    }

  return a;
}

//! Create a translation matrix that translates by v.
Matrix<4, 4> TranslationMatrix (const Vector<3> &v);

//! Create a scaling matrix with the scale factors in v.
Matrix<4, 4> ScalingMatrix (const Vector<3> &v);

//! Create a rotation matrix (rotation by angle about the axis v).
/*! NOTE(review): defined elsewhere; presumably uses the same convention as
 *  Matrix<4, 4>::Rotate -- confirm against the implementation.
 */
Matrix<4, 4> RotationMatrix (real angle, const Vector<3> &v);

template<unsigned int m, unsigned int n>
inline std::ostream &
operator << (std::ostream &s, Matrix<m, n> a)
{
  for (unsigned int i = 0; i < m; ++i)
    {
      if (i == 0)
	{
	  s << "/";
	}
      else if (i == m - 1)
	{
	  s << "\\";
	}
      else
	{
	  s << "|";
	}
      for (unsigned int j = 0; j < n; ++j)
	{
	  s << std::setw (14) << a (i, j);
	  
	  if (j == n - 1)
	    {
	      if (i == 0)
		{
		  s << "\\";
		}
	      else if (i == m - 1)
		{
		  s << "/";
		}
	      else
		{
		  s << "|";
		}
	    }
	}
      if (i != m - 1)
	{
	  s << std::endl;
	}
    }
  
  return s;
}

}

}

}

#endif // LEG_SUPPORT_MATHS_MATRIX_H
