/*!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
! Copyright 2010.  Los Alamos National Security, LLC. This material was    !
! produced under U.S. Government contract DE-AC52-06NA25396 for Los Alamos !
! National Laboratory (LANL), which is operated by Los Alamos National     !
! Security, LLC for the U.S. Department of Energy. The U.S. Government has !
! rights to use, reproduce, and distribute this software.  NEITHER THE     !
! GOVERNMENT NOR LOS ALAMOS NATIONAL SECURITY, LLC MAKES ANY WARRANTY,     !
! EXPRESS OR IMPLIED, OR ASSUMES ANY LIABILITY FOR THE USE OF THIS         !
! SOFTWARE.  If software is modified to produce derivative works, such     !
! modified software should be clearly marked, so as not to confuse it      !
! with the version available from LANL.                                    !
!                                                                          !
! Additionally, this program is free software; you can redistribute it     !
! and/or modify it under the terms of the GNU General Public License as    !
! published by the Free Software Foundation; version 2.0 of the License.   !
! Accordingly, this program is distributed in the hope that it will be     !
! useful, but WITHOUT ANY WARRANTY; without even the implied warranty of   !
! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General !
! Public License for more details.                                         !
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!*/

#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "Matrix.h"

/*
 * sign(X,Y): magnitude of X carrying the sign of Y (like Fortran SIGN).
 * Uses fabs, not abs: abs() takes an int in C, so a REAL magnitude such as
 * 2.5 would be silently truncated to 2.
 */
#define sign(X,Y) ((Y)>=0.0?fabs(X):-fabs(X))

// Ugh... this is a workaround so this variable only gets declared once.
// NOTE(review): presumably this file is compiled once per precision
// (REALSIZE 4 and 8), and only the REALSIZE==4 build defines the object.
//
// signlist holds one entry per SP2 recursion step:
//   +1 -> X <- 2X - X^2
//   -1 -> X <- X^2
// The sp2fermi_init_* routines (re)allocate and fill it; sp2fermi_nospin and
// sp2fermi_spin replay it.
#if REALSIZE==4
  int *signlist=NULL;
#else
  extern int *signlist;
#endif

/*
 * sp2fermi_init_nospin -- spin-unpolarized initialization of Niklasson's
 * finite-temperature SP2 (Fermi-operator) expansion.
 *
 * Each outer iteration does the following:
 *   1. Rebuilds the scaled start matrix
 *        X0 = (maxeval*I - H - chempot*I) / maxminusmin
 *      in bo_pointer.
 *   2. Seeds the chemical-potential response X1 = dX0/dchempot = -I/maxminusmin.
 *   3. Runs norecs SP2 steps on both. At every step it picks the branch
 *      (+1: X <- 2X - X^2, -1: X <- X^2) whose trace lands closer to the
 *      target occupation, and records it in the global signlist for replay
 *      by sp2fermi_nospin.
 *   4. Takes a Newton step lambda = (occ - Tr X) / Tr X1, clamped to
 *      +/-maxshift, to update chempot.
 *
 * Parameters:
 *   bndfil           band filling; the target trace is bndfil*hdim.
 *   hdim             dimension of the hdim x hdim matrices.
 *   bo_pointer       output: 2*(X + lambda*X1) after the last iteration.
 *   maxeval          upper spectral bound used for scaling.
 *   h_pointer        Hamiltonian. Only h[j+i*hdim] with j >= i is read; bo
 *                    is filled symmetrically from it.
 *   maxminusmin      spectral width used for scaling.
 *   chempot_pointer  in/out chemical potential.
 *   norecs           number of SP2 recursion steps (entries in signlist).
 *   kbt_pointer      output: |1/beta0|, or 0 when Tr(X - X^2) is ~0.
 *   beta0_pointer    output: Tr(X1)/Tr(X - X^2), or 1e10 when undetermined.
 *   breaktol         absolute occupation tolerance (used only for 8-byte REAL).
 *
 * Calls exit(1) if there is no convergence within 99 iterations, or if
 * signlist cannot be allocated.
 */
void sp2fermi_init_nospin(REAL bndfil, int hdim, REAL *bo_pointer, REAL maxeval, REAL *h_pointer, REAL maxminusmin, REAL *chempot_pointer,
			  int norecs, REAL *kbt_pointer, REAL *beta0_pointer, REAL breaktol) {
  int i, j, ii, iter;
  int breakloop, mysign;
  int *grown;
  REAL trx, occ=1.0*bndfil*hdim, trx2, trxomx, trx1, lambda=0.0;
  REAL maxshift=1.0, occerror=0.0, kbt=*kbt_pointer, chempot=*chempot_pointer;

  REAL *x1_pointer;
  REAL preverror, preverror2, preverror3, beta0;

  Matrix bo, x2;
  M_InitWithLocal(bo, bo_pointer, hdim, hdim);

  M_Init(x2, hdim, hdim);

  // Resize the shared sign record. Going through a temporary keeps the old
  // block reachable if realloc fails, instead of leaking it and writing
  // through NULL below.
  grown=(int *)realloc(signlist, norecs*sizeof(int));
  if (grown == NULL && norecs > 0) {
    fprintf(stderr, "sp2fermi_init_nospin: cannot allocate signlist\n");
    exit(1);
  }
  signlist=grown;

  beta0=*beta0_pointer;

  iter = 0;
  breakloop = 0;

  preverror = 0.0;
  preverror2 = 0.0;
  preverror3 = 0.0;

  Matrix x0x1, x1x0, x1;
  M_Init(x0x1, hdim, hdim);
  M_Init(x1x0, hdim, hdim);
  M_Init(x1, hdim, hdim);
  x1_pointer=x1.Local;

  while(breakloop==0) { // This loop locates the chemical potential
    iter = iter + 1;
    if (iter == 100) {
      printf("sp2fermiinit is not converging: stop!\n");
      exit(1);
    }

    // X1 <- -I/maxminusmin (diagonal written in the loop below)
    memset(x1_pointer, '\0', hdim*hdim*sizeof(REAL));

    // X0 <- (maxeval*I - H - chempot*I)/maxminusmin, filled symmetrically
    for (i = 0; i<hdim; i++) {
      for (j = i; j<hdim; j++) {
        if (i == j)
          bo_pointer[i+i*hdim] = (maxeval - h_pointer[i+i*hdim] - chempot)/maxminusmin;
        else {
          bo_pointer[j+i*hdim] = -h_pointer[j+i*hdim]/maxminusmin;
          bo_pointer[i+j*hdim] = bo_pointer[j+i*hdim];
        }
      }
      x1_pointer[i+i*hdim] = -1.0/maxminusmin;
    }

    M_Push(bo);
    M_Push(x1);
    for (ii=0; ii<norecs; ii++) {
      // density matrix squared
      M_Multiply(bo, bo, x2);
      trx=M_Trace(bo);
      trx2=M_Trace(x2);

      // X^2 has trace trx2 and 2X - X^2 has trace 2*trx - trx2;
      // keep whichever lands closer to the target occupation.
      if (fabs(trx2-occ)<fabs(2.0*trx-trx2-occ))
        mysign = -1;
      else
        mysign = 1;

      signlist[ii] = mysign;

      // density matrix x response
      M_Multiply(bo, x1, x0x1);
      M_Multiply(x1, bo, x1x0);

      if (mysign==1) {
        // X1 <- 2*X1 - X*X1 - X1*X ;  X <- 2X - X^2
        M_Multiply(2.0, x1, x1);
        M_Subtract(x1, x0x1, x1);
        M_Subtract(x1, x1x0, x1);
        M_Multiply(2.0, bo, bo);
        M_Subtract(bo, x2, bo);
      }
      else {
        // X1 <- X*X1 + X1*X ;  X <- X^2
        // Copy instead of computing (x1 - x1), which turns inf into NaN.
        M_Copy(x0x1, x1);
        M_Add(x1, x1x0, x1);
        M_Copy(x2, bo);
      }
    }

    trx1=M_Trace(x1);
    trx=M_Trace(bo);
    trx2=M_TraceX2(bo);
    trxomx=trx-trx2;

    if (fabs(trxomx)>1e-6) {
      beta0 = trx1/trxomx;
      kbt = fabs(1.0/beta0);
    }
    else {
      beta0 = 1e10;
      kbt = 0.0;
    }

    // Newton step on the occupation. Fall back to a fixed shift when the
    // response is too small to divide by, and to no shift when the
    // occupation is already exact (rather than reusing the last lambda).
    if (fabs(trx1)>1e-6)
      lambda = (occ - trx)/trx1;
    else if (occ > trx)
      lambda = maxshift;
    else if (occ < trx)
      lambda = -maxshift;
    else
      lambda = 0.0;

    // clamp the step to +/- maxshift
    if (fabs(lambda) > maxshift)
      lambda = (lambda >= 0.0) ? maxshift : -maxshift;

    chempot = chempot + lambda;

    preverror3 = preverror2;
    preverror2 = preverror;
    preverror = occerror;

    occerror = fabs(occ - trx);

    // How we decide convergence: an absolute tolerance works well in double
    // precision, but in single precision we look for noise (a repeated
    // error value) near self-consistency, with a hard stop at iteration 25.
    if (sizeof(REAL) == 8) {
      if (occerror < breaktol)
        breakloop = 1;
    }
    else {
      if (occerror == preverror || occerror == preverror2 || occerror == preverror3 || iter == 25 )
        breakloop = 1;
    }
  }

  // First-order correction for the last chempot shift: X <- X + lambda*X1.
  // Then double it; the spin-polarized version instead targets
  // 2*bndfil*hdim over both spins.
  M_Multiply(lambda, x1, x1);
  M_Add(bo, x1, bo);
  M_Multiply(2.0, bo, bo);

  M_Pull(bo);

  *kbt_pointer=kbt;
  *chempot_pointer=chempot;
  *beta0_pointer=beta0;

  M_DeallocateLocal(x0x1);
  M_DeallocateDevice(x0x1);
  M_DeallocateLocal(x1x0);
  M_DeallocateDevice(x1x0);
  M_DeallocateLocal(x1);
  M_DeallocateDevice(x1);
  M_DeallocateLocal(x2);
  M_DeallocateDevice(x2);
  M_DeallocateDevice(bo);
}

/*
 * sp2fermi_init_spin -- spin-polarized initialization of Niklasson's
 * finite-temperature SP2 (Fermi-operator) expansion.
 *
 * Each outer iteration does the following:
 *   1. Builds scaled start matrices for both spins from a shared chempot:
 *        X0 = (maxeval*I - H - chempot*I) / maxminusmin
 *   2. Seeds both responses with X1 = -I/maxminusmin.
 *   3. Runs norecs SP2 steps on all four matrices. The branch for each step
 *      (+1: X <- 2X - X^2, -1: X <- X^2) is chosen from the combined trace
 *      Tr(rhoup)+Tr(rhodown) against totne = 2*bndfil*hdim. Both spins
 *      therefore share one sign sequence, recorded in the global signlist
 *      for replay by sp2fermi_spin.
 *   4. Takes a Newton step on the combined trace to update chempot.
 *
 * Parameters:
 *   bndfil                band filling; the target total trace is 2*bndfil*hdim.
 *   hdim                  dimension of the hdim x hdim matrices.
 *   rhoup_ptr/rhodown_ptr output: X + lambda*X1 for each spin after the last
 *                         iteration (not doubled).
 *   maxeval               upper spectral bound used for scaling.
 *   hup/hdown             spin Hamiltonians; only h[j+i*hdim] with j >= i is read.
 *   maxminusmin           spectral width used for scaling.
 *   chempot_pointer       in/out chemical potential.
 *   norecs                number of SP2 recursion steps (entries in signlist).
 *   kbt_pointer           output: |1/beta0|, or 0 when undetermined.
 *   beta0_pointer         output: spin-summed Tr(X1)/Tr(X - X^2), or 1e10 when
 *                         undetermined.
 *   breaktol              absolute occupation tolerance (8-byte REAL only).
 *
 * Calls exit(1) if there is no convergence within 99 iterations, or if
 * signlist cannot be allocated.
 */
void sp2fermi_init_spin(REAL bndfil, int hdim, REAL *rhoup_ptr, REAL *rhodown_ptr, REAL maxeval, REAL *hup, REAL *hdown, REAL maxminusmin, REAL *chempot_pointer,
			  int norecs, REAL *kbt_pointer, REAL *beta0_pointer, REAL breaktol) {
  int i, j, ii, iter=0;
  int breakloop=0, mysign;
  int *grown;
  REAL trx, trx2, trxomx, trx1;
  REAL lambda = 0.0, maxshift = 1.0, occerror = 0.0;
  REAL preverror=0.0, preverror2=0.0, preverror3=0.0, chempot, kbt, beta0;
  // target occupation summed over both spins
  REAL totne=2.0*bndfil*hdim;

  Matrix x0x1up, x0x1down, x1x0up, x1x0down, x1up, x1down, x2up, x2down, rhoup, rhodown;
  M_Init(x0x1up, hdim, hdim);
  M_Init(x0x1down, hdim, hdim);
  M_Init(x1x0up, hdim, hdim);
  M_Init(x1x0down, hdim, hdim);
  M_Init(x1up, hdim, hdim);
  M_Init(x1down, hdim, hdim);
  M_Init(x2up, hdim, hdim);
  M_Init(x2down, hdim, hdim);
  M_InitWithLocal(rhoup, rhoup_ptr, hdim, hdim);
  M_InitWithLocal(rhodown, rhodown_ptr, hdim, hdim);

  chempot=*chempot_pointer;
  kbt=*kbt_pointer;
  beta0=*beta0_pointer;

  // Resize the shared sign record. Going through a temporary keeps the old
  // block reachable if realloc fails.
  grown=(int *)realloc(signlist, norecs*sizeof(int));
  if (grown == NULL && norecs > 0) {
    fprintf(stderr, "sp2fermi_init_spin: cannot allocate signlist\n");
    exit(1);
  }
  signlist=grown;

  while (breakloop == 0) {
    iter++;
    if (iter == 100) {
      printf("sp2fermiinit is not converging: stop!\n");
      exit(1);
    }

    // X1 <- -I/maxminusmin (diagonal written in the loop below)
    memset(x1up.Local, '\0', hdim*hdim*sizeof(REAL));

    // X0 <- (maxeval*I - H - chempot*I)/maxminusmin for each spin, symmetric
    for (i=0; i<hdim; i++) {
      for (j=i; j<hdim; j++) {
        if (i == j) {
          rhoup.Local[i+i*hdim] = (maxeval - hup[i+i*hdim] - chempot)/maxminusmin;
          rhodown.Local[i+i*hdim] = (maxeval - hdown[i+i*hdim] - chempot)/maxminusmin;
        }
        else {
          rhoup.Local[j+i*hdim] = -hup[j+i*hdim]/maxminusmin;
          rhoup.Local[i+j*hdim] = rhoup.Local[j+i*hdim];

          rhodown.Local[j+i*hdim] = -hdown[j+i*hdim]/maxminusmin;
          rhodown.Local[i+j*hdim] = rhodown.Local[j+i*hdim];
        }
      }
      x1up.Local[i+i*hdim] = -1.0/maxminusmin;
    }
    memcpy(x1down.Local, x1up.Local, hdim*hdim*sizeof(REAL));

    M_Push(rhoup);
    M_Push(rhodown);
    M_Push(x1up);
    M_Push(x1down);

    for (ii=0; ii<norecs; ii++) {
      // density matrix squared
      M_Multiply(rhoup, rhoup, x2up);
      M_Multiply(rhodown, rhodown, x2down);
      trx = M_Trace(rhoup) + M_Trace(rhodown);
      trx2 = M_Trace(x2up) + M_Trace(x2down);

      // X^2 has trace trx2 and 2X - X^2 has trace 2*trx - trx2;
      // keep whichever lands closer to totne.
      if (fabs(trx2 - totne) < fabs(2.0 * trx - trx2 - totne))
        mysign = -1;
      else
        mysign = 1;
      signlist[ii] = mysign;

      M_Multiply(rhoup, x1up, x0x1up);
      M_Multiply(rhodown, x1down, x0x1down);

      M_Multiply(x1up, rhoup, x1x0up);
      M_Multiply(x1down, rhodown, x1x0down);

      if (mysign==1) {
        // X1 <- 2*X1 - X*X1 - X1*X ;  X <- 2X - X^2
        M_Multiply(2.0, x1up, x1up);
        M_Subtract(x1up, x0x1up, x1up);
        M_Subtract(x1up, x1x0up, x1up);

        M_Multiply(2.0, x1down, x1down);
        M_Subtract(x1down, x0x1down, x1down);
        M_Subtract(x1down, x1x0down, x1down);

        M_Multiply(2.0, rhoup, rhoup);
        M_Subtract(rhoup, x2up, rhoup);
        M_Multiply(2.0, rhodown, rhodown);
        M_Subtract(rhodown, x2down, rhodown);
      }
      else {
        // X1 <- X*X1 + X1*X ;  X <- X^2
        M_Copy(x0x1up, x1up);
        M_Add(x1up, x1x0up, x1up);

        M_Copy(x0x1down, x1down);
        M_Add(x1down, x1x0down, x1down);

        M_Copy(x2up, rhoup);
        M_Copy(x2down, rhodown);
      }
    }

    trx = M_Trace(rhoup) + M_Trace(rhodown);
    trxomx = trx - (M_TraceX2(rhoup) + M_TraceX2(rhodown));
    trx1 = M_Trace(x1up) + M_Trace(x1down);

    if (fabs(trxomx)>1e-6 && fabs(trx1)>1e-6) {
      beta0 = trx1/trxomx;
      kbt = fabs(1.0/beta0);
    }
    else {
      beta0 = 1e10;
      kbt = 0.0;
    }

    // Newton step on the occupation. Fall back to a fixed shift when the
    // response is too small to divide by, and to no shift when the
    // occupation is already exact.
    if (fabs(trx1)>1e-6)
      lambda = (totne - trx)/trx1;
    else if (totne > trx)
      lambda = maxshift;
    else if (totne < trx)
      lambda = -maxshift;
    else
      lambda = 0.0;

    // clamp the step to +/- maxshift
    if (fabs(lambda) > maxshift)
      lambda = (lambda >= 0.0) ? maxshift : -maxshift;

    chempot = chempot + lambda;

    preverror3 = preverror2;
    preverror2 = preverror;
    preverror = occerror;

    occerror = fabs(totne - trx);

    // How we decide convergence: an absolute tolerance works well in double
    // precision, but in single precision we look for noise (a repeated
    // error value) near self-consistency, with a hard stop at iteration 25.
    if (sizeof(REAL) == 8) {
      if (occerror < breaktol)
        breakloop = 1;
    }
    else {
      if (occerror == preverror || occerror == preverror2 || occerror == preverror3 || iter == 25 )
        breakloop = 1;
    }
  }

  // First-order correction for the last chempot shift: X <- X + lambda*X1
  M_Multiply(lambda, x1up, x1up);
  M_Add(rhoup, x1up, rhoup);
  M_Multiply(lambda, x1down, x1down);
  M_Add(rhodown, x1down, rhodown);

  // Bring the corrected matrices back into the caller's arrays before the
  // device copies are released. The other SP2 routines in this file pull
  // their density matrices the same way; this one previously did not.
  M_Pull(rhoup);
  M_Pull(rhodown);

  M_DeallocateDevice(x0x1up);
  M_DeallocateLocal(x0x1up);
  M_DeallocateDevice(x0x1down);
  M_DeallocateLocal(x0x1down);
  M_DeallocateDevice(x1x0up);
  M_DeallocateLocal(x1x0up);
  M_DeallocateDevice(x1x0down);
  M_DeallocateLocal(x1x0down);
  M_DeallocateDevice(x1up);
  M_DeallocateLocal(x1up);
  M_DeallocateDevice(x1down);
  M_DeallocateLocal(x1down);
  M_DeallocateDevice(x2up);
  M_DeallocateLocal(x2up);
  M_DeallocateDevice(x2down);
  M_DeallocateLocal(x2down);
  M_DeallocateDevice(rhoup);
  M_DeallocateDevice(rhodown);

  *chempot_pointer=chempot;
  *kbt_pointer=kbt;
  *beta0_pointer=beta0;
}

/*
 * sp2fermi_nospin -- spin-unpolarized finite-temperature SP2 solve. It
 * replays the branch sequence that sp2fermi_init_nospin recorded in the
 * global signlist.
 *
 * GERSHGORIN and SP2FERMIINIT must be run first to initialize.
 *
 * Each iteration does the following:
 *   1. Rebuilds X0 = (maxeval*I - H - chempot*I)/maxminusmin in bo_pointer.
 *   2. Applies X <- X + s*(X - X^2) for each recorded sign s
 *      (s = +1 gives 2X - X^2, s = -1 gives X^2).
 *   3. Moves chempot by a Newton step, using beta0*Tr(X - X^2) as the
 *      occupation response.
 *
 * Parameters:
 *   bndfil           band filling; the target trace is bndfil*hdim.
 *   hdim             dimension of the hdim x hdim matrices.
 *   bo_pointer       output: 2*X after the last iteration.
 *   maxeval          upper spectral bound used for scaling.
 *   h_pointer        Hamiltonian; only h[j+i*hdim] with j >= i is read.
 *   maxminusmin      spectral width used for scaling.
 *   chempot_pointer  in/out chemical potential.
 *   norecs           number of recursion steps. signlist must hold at least
 *                    this many entries, which cannot be checked here.
 *   kbt_pointer      written back unchanged.
 *   beta0_pointer    read only; the response scale from the init routine.
 *   breaktol         absolute occupation tolerance (8-byte REAL only).
 *
 * Calls exit(1) if signlist was never allocated, or if there is no
 * convergence within 99 iterations.
 */
void sp2fermi_nospin(REAL bndfil, int hdim, REAL *bo_pointer, 
                     REAL maxeval, REAL *h_pointer, REAL maxminusmin, 
                     REAL *chempot_pointer, int norecs, REAL *kbt_pointer, 
                     REAL *beta0_pointer, REAL breaktol) {
  int i, j, ii;
  int iter, breakloop;
  REAL occ=1.0*bndfil*hdim;
  REAL trx, trx1;
  REAL lambda=0.0;
  REAL maxshift = 1.0;
  REAL occerror=0.0;
  REAL preverror = 0.0, preverror2 = 0.0, preverror3 = 0.0;

  // The recursion below indexes signlist; fail loudly instead of
  // dereferencing NULL when the init routine has not been run.
  if (signlist == NULL && norecs > 0) {
    fprintf(stderr, "sp2fermi_nospin: signlist not set; run sp2fermi_init_nospin first\n");
    exit(1);
  }

  Matrix bo, x2;
  M_InitWithLocal(bo, bo_pointer, hdim, hdim);
  M_Init(x2, hdim, hdim);

  iter = 0;
  breakloop = 0;

  REAL chempot=*chempot_pointer;
  REAL kbt=*kbt_pointer;
  REAL beta0=*beta0_pointer;

  while (breakloop == 0) {
    iter = iter + 1;
    if (iter == 100) {
      printf("sp2fermi is not converging: stop!\n");
      exit(1);
    }

    // X0 <- (maxeval*I - H - chempot*I)/maxminusmin, filled symmetrically
    for(i=0; i<hdim; i++) {
      for(j=i; j<hdim; j++) {
        if (i==j) {
          bo_pointer[i+i*hdim] = (maxeval - h_pointer[i+i*hdim] - chempot)/maxminusmin;
        }
        else {
          bo_pointer[j+i*hdim] = (0.0 - h_pointer[j+i*hdim])/maxminusmin;
          bo_pointer[i+j*hdim] = bo_pointer[j+i*hdim];
        }
      }
    }
    M_Push(bo);

    for(ii = 0;ii<norecs;ii++) {
      // 'Recurse' using the sequence of operations defined in sp2fermiinit:
      // X <- X + s*(X - X^2)
      // Can we make this faster by choosing between +1 and -1? -- EJS
      M_Multiply(bo, bo, x2);
      M_Subtract(bo, x2, x2);
      M_Multiply((REAL)signlist[ii], x2, x2);
      M_Add(bo, x2, bo);
    }
    trx = M_Trace(bo);
    trx1 = M_TraceX2(bo);
    trx1 = beta0*(trx-trx1);

    // Newton step on the occupation. Fall back to a fixed shift when the
    // response is too small to divide by, and to no shift when the
    // occupation is already exact (rather than reusing the last lambda).
    if (fabs(trx1)>1e-6)
      lambda = (occ - trx)/trx1;
    else if (occ > trx)
      lambda = maxshift;
    else if (occ < trx)
      lambda = -maxshift;
    else
      lambda = 0.0;

    // clamp the step to +/- maxshift, then update chempot
    if (fabs(lambda) > maxshift)
      lambda = (lambda >= 0.0) ? maxshift : -maxshift;

    chempot = chempot + lambda;

    preverror3 = preverror2;
    preverror2 = preverror;
    preverror = occerror;
    occerror = fabs(occ - trx);

    // Double precision uses an absolute tolerance. Single precision stops
    // when the error repeats (noise floor) or at iteration 10.
    if (sizeof(REAL) == 8) {
      if (occerror < breaktol) {
        breakloop = 1;
      }
    }
    else {
      if (occerror == preverror || occerror == preverror2 || occerror == preverror3 || iter == 10 ) {
        breakloop = 1;
      }
    }
  }
  // If you forget the following you'll spend about a day
  // trying to find the bug in every other subroutine...
  M_Multiply(2.0, bo, bo);
  M_Pull(bo);

  *chempot_pointer=chempot;
  *kbt_pointer=kbt;

  M_DeallocateDevice(bo);
  M_DeallocateDevice(x2);
  M_DeallocateLocal(x2);
}

/*
 * sp2fermi_spin -- spin-polarized finite-temperature SP2 solve. It replays
 * the shared branch sequence that sp2fermi_init_spin recorded in the global
 * signlist.
 *
 * GERSHGORIN and SP2FERMIINIT must be run first to initialize.
 *
 * Each iteration does the following:
 *   1. Rebuilds X0 = (maxeval*I - H - chempot*I)/maxminusmin for each spin.
 *   2. Applies the recorded steps
 *      (+1: X <- 2X - X^2, -1: X <- X^2) to both spins.
 *   3. Moves the shared chempot by a Newton step on the combined trace,
 *      using beta0*Tr(X - X^2) as the response.
 *
 * Parameters:
 *   bndfil                band filling; the target total trace is 2*bndfil*hdim.
 *   hdim                  dimension of the hdim x hdim matrices.
 *   rhoup_ptr/rhodown_ptr output: X for each spin after the last iteration.
 *   maxeval               upper spectral bound used for scaling.
 *   hup/hdown             spin Hamiltonians; only h[j+i*hdim] with j >= i is
 *                         read, the same triangle sp2fermi_init_spin reads.
 *   maxminusmin           spectral width used for scaling.
 *   chempot_pointer       in/out chemical potential.
 *   norecs                number of recursion steps. signlist must hold at
 *                         least this many entries, which cannot be checked here.
 *   kbt_pointer           unused.
 *   beta0_pointer         read only.
 *   breaktol              absolute occupation tolerance (8-byte REAL only).
 *
 * Calls exit(1) if signlist was never allocated, or if there is no
 * convergence within 49 iterations.
 */
void sp2fermi_spin(REAL bndfil, int hdim, REAL *rhoup_ptr, REAL *rhodown_ptr, REAL maxeval, REAL *hup, REAL *hdown, REAL maxminusmin, REAL *chempot_pointer,
			  int norecs, REAL *kbt_pointer, REAL *beta0_pointer, REAL breaktol) {
  int i, j, ii;
  int iter=0, breakloop=0;
  REAL trx, trx1, lambda=0.0, maxshift = 1.0, occerror=0.0;
  REAL preverror = 0.0, preverror2 = 0.0, preverror3 = 0.0;
  REAL beta0=*beta0_pointer, chempot=*chempot_pointer;
  // target occupation summed over both spins
  REAL totne=2.0*bndfil*hdim;

  // The recursion below indexes signlist; fail loudly instead of
  // dereferencing NULL when the init routine has not been run.
  if (signlist == NULL && norecs > 0) {
    fprintf(stderr, "sp2fermi_spin: signlist not set; run sp2fermi_init_spin first\n");
    exit(1);
  }

  Matrix rhoup, rhodown, x2up, x2down;
  M_InitWithLocal(rhoup, rhoup_ptr, hdim, hdim);
  M_InitWithLocal(rhodown, rhodown_ptr, hdim, hdim);
  M_Init(x2up, hdim, hdim);
  M_Init(x2down, hdim, hdim);

  while ( breakloop == 0 ) {
    iter++;
    if (iter == 50) {
      printf("sp2fermi is not converging: stop\n");
      exit(1);
    }

    // X0 <- (maxeval*I - H - chempot*I)/maxminusmin for each spin.
    // Filled from the j >= i triangle only, matching sp2fermi_init_spin.
    for(i = 0; i<hdim; i++) {
      for (j = i; j<hdim; j++) {
        if (i == j) {
          rhoup.Local[i+i*hdim] = (maxeval - hup[i+i*hdim] - chempot)/maxminusmin;
          rhodown.Local[i+i*hdim] = (maxeval - hdown[i+i*hdim] - chempot)/maxminusmin;
        }
        else {
          rhoup.Local[j+i*hdim] = -hup[j+i*hdim]/maxminusmin;
          rhoup.Local[i+j*hdim] = rhoup.Local[j+i*hdim];

          rhodown.Local[j+i*hdim] = -hdown[j+i*hdim]/maxminusmin;
          rhodown.Local[i+j*hdim] = rhodown.Local[j+i*hdim];
        }
      }
    }

    M_Push(rhoup);
    M_Push(rhodown);
    for (ii = 0; ii<norecs; ii++) {
      M_Multiply(rhoup, rhoup, x2up);
      M_Multiply(rhodown, rhodown, x2down);

      if (signlist[ii]==1) {
        // X <- 2X - X^2
        M_Multiply(2.0, rhoup, rhoup);
        M_Subtract(rhoup, x2up, rhoup);
        M_Multiply(2.0, rhodown, rhodown);
        M_Subtract(rhodown, x2down, rhodown);
      }
      else if (signlist[ii]==-1) {
        // X <- X^2
        M_Copy(x2up, rhoup);
        M_Copy(x2down, rhodown);
      }
    }

    trx = M_Trace(rhoup) + M_Trace(rhodown);
    trx1 = beta0*(trx - (M_TraceX2(rhoup) + M_TraceX2(rhodown)));

    // Newton step on the occupation. Fall back to a fixed shift when the
    // response is too small to divide by, and to no shift when the
    // occupation is already exact.
    if (fabs(trx1)>1e-6)
      lambda = (totne - trx)/trx1;
    else if (totne > trx)
      lambda = maxshift;
    else if (totne < trx)
      lambda = -maxshift;
    else
      lambda = 0.0;

    // clamp the step to +/- maxshift
    if (fabs(lambda) > maxshift)
      lambda = (lambda >= 0.0) ? maxshift : -maxshift;

    chempot = chempot + lambda;

    preverror3 = preverror2;
    preverror2 = preverror;
    preverror = occerror;

    occerror = fabs(totne - trx);

    // How we decide convergence: an absolute tolerance works well in double
    // precision, but in single precision we look for noise (a repeated
    // error value) near self-consistency, with a hard stop at iteration 25.
    if (sizeof(REAL) == 8) {
      if (occerror < breaktol)
        breakloop = 1;
    }
    else {
      if (occerror == preverror || occerror == preverror2 || occerror == preverror3 || iter == 25 )
        breakloop = 1;
    }
  }
  M_Pull(rhoup);
  M_Pull(rhodown);

  M_DeallocateDevice(x2up);
  M_DeallocateLocal(x2up);
  M_DeallocateDevice(x2down);
  M_DeallocateLocal(x2down);
  M_DeallocateDevice(rhoup);
  M_DeallocateDevice(rhodown);

  *chempot_pointer=chempot;
}
  
