/*
 * cryptengine_nettle.cpp
 *
 * Copyright (C) 2006 Jernimo Pellegrini
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License as
 * published by the Free Software Foundation; either version 2 of the
 * License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to:
 *   The Free Software Foundation, Inc.,
 *   51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
 */


#include <cstdlib>
#include <cstring>
#include <iostream>
#include <limits>
#include <new>
#include <stdexcept>
#include <vector>

#include <boost/filesystem/fstream.hpp>

#include <gmpxx.h>

extern "C" {
#include <fcntl.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <fcntl.h>
#include <nettle/aes.h>
#include <nettle/sha.h>
#include <nettle/base64.h>
#include <nettle/base16.h>
#include <nettle/cbc.h>
#include <nettle/nettle-types.h>
#include <nettle/yarrow.h>
#include <nettle/rsa.h>
#include <nettle/bignum.h>
}


#include "cctools.h"
#include "cryptengine_nettle.h"


// FIXME: RANDOM_DEVICE should actually be "platform::RANDOM_DEVICE",
// so we can have "platform" being Linux, BSD, Windows, etc.
#define RANDOM_DEVICE "/dev/urandom"

namespace apso {


// Static members:
struct yarrow256_ctx CryptEngineNettle::yarrow;


/**
 * Constructor for Nettle engine.
 *
 * Initializes the shared Yarrow-256 RNG (CryptEngineNettle::yarrow) with no
 * entropy sources and seeds it with up to 50 bytes read from RANDOM_DEVICE.
 * Note that because the generator is a static member, constructing another
 * engine re-initializes and reseeds the state used by all instances.
 *
 * FIXME: The RNG should reseed periodically, so the constructor
 * should use two sources of entropy and use them to update the pools:
 * yarrow256_init   (yarrow&, 2, char ** sources);
 * Later the random_fill method can do:
 * yarrow256_update (yarrow&, source_number, 0, length, data);
 *
 * @throws std::runtime_error if no bytes could be read from RANDOM_DEVICE
 *         (including when the device cannot be opened).
 */
CryptEngineNettle::CryptEngineNettle()
	throw (std::runtime_error) {
	// Number of seed bytes requested from the random device.
	const unsigned source_size = 50;
	// Fixed-size stack buffer: nothing to allocate, nothing to leak if we throw.
	uint8_t seedbuffer[source_size];

	yarrow256_init (& CryptEngineNettle::yarrow, 0, NULL);

	// A stream that failed to open reads nothing, so gcount() is 0 and the
	// check below reports it.
	boost::filesystem::ifstream in (RANDOM_DEVICE);
	in.read(reinterpret_cast<char *>(seedbuffer), source_size);
	const unsigned length = static_cast<unsigned>(in.gcount());
	if (length == 0)
		throw std::runtime_error ("CryptEngineNettle::CryptEngineNettle(): Can't initialize RNG: RANDOM_DEVICE returned zero bytes");

	yarrow256_seed (& CryptEngineNettle::yarrow, length, seedbuffer);
}


/**
 * Produces a freshly malloc()ed buffer of random bytes drawn from the shared
 * Yarrow-256 generator.
 *
 * The returned memory comes from malloc(); the caller owns it and should
 * release it with free().
 *
 * @param size Number of random bytes to produce.
 * @return Pointer to the new buffer.
 * @throws std::bad_alloc if malloc() returns NULL (which malloc may also do
 *         for size == 0).
 */
char *
CryptEngineNettle::random_fill_char(const size_t size) throw (std::runtime_error, std::bad_alloc) {
	const unsigned byte_count = size_to_unsigned(size);
	char * const out = static_cast<char *>(malloc(sizeof (char) * size));

	if (!out)
		throw std::bad_alloc ();

	yarrow256_random (& CryptEngineNettle::yarrow, byte_count, reinterpret_cast<uint8_t *>(out));
	return out;
}

/**
 * Fills a bdata buffer with random bits. Resize the bdata to match 'size' given.
 *
 * @param buffer The buffer to be filled.
 * @param size The size of the buffer.
 * @throws std::runtime_error if the random buffer cannot be allocated.
 */
void
CryptEngineNettle::random_fill(bdata& buffer, const size_t size) throw (std::runtime_error) {
	char *random_buf;
	try {
		random_buf = random_fill_char(size);
	} catch (const std::bad_alloc&) {
		// This function only declares std::runtime_error. Letting bad_alloc
		// escape would violate the exception specification, which calls
		// std::unexpected() (terminate by default), so translate it.
		throw std::runtime_error ("CryptEngineNettle::random_fill(): Out of memory.");
	}

	// NOTE(review): random_buf is not freed here. Elsewhere in this file the
	// bdata(char*, size) constructor is given buffers that are free()d right
	// afterwards, which suggests bdata copies its input. If set_data() also
	// copies, this leaks `size` bytes per call. Confirm set_data()'s ownership
	// semantics before adding a free().
	buffer.set_data(random_buf, size);
}


/**
 * Calculates the SHA-256 hash of a text.
 *
 * @param text Data to hash.
 * @return A new bdata holding the SHA256_DIGEST_SIZE-byte binary digest.
 */
bdata_ptr
CryptEngineNettle::hash (const bdata& text)
	throw (std::bad_alloc,std::runtime_error) {

	// The digest is small and fixed-size, so it lives on the stack: there is
	// no allocation that could fail after sha256_digest() writes into it.
	uint8_t digest[SHA256_DIGEST_SIZE];
	struct sha256_ctx ctx;

	// Compute the length before taking ownership of the raw data copy, so a
	// throw here cannot leak that copy.
	const unsigned usize = size_to_unsigned(text.get_size());
	uint8_t* buffer = (uint8_t *) text.get_data();

	sha256_init(&ctx);
	sha256_update(&ctx, usize, buffer);
	sha256_digest(&ctx, SHA256_DIGEST_SIZE, digest);
	free(buffer);

	bdata_ptr result (new bdata (reinterpret_cast<char *>(digest), SHA256_DIGEST_SIZE));
	return result;
}


/**
 * Decodes data using base64.
 *
 * @param text Base64 text to decode.
 * @return A new bdata holding the decoded bytes.
 * @throws std::runtime_error if the input is not valid base64, including
 *         input whose final group is incomplete (base64_decode_final fails).
 */
bdata_ptr
CryptEngineNettle::decode(const bdata& text)
	throw (std::bad_alloc,std::runtime_error) {

	const unsigned usize = size_to_unsigned(text.get_size());

	// BASE64_DECODE_LENGTH() bounds the output size; decoded_size is then
	// updated by Nettle to the number of bytes actually written. One spare
	// byte keeps the vector non-empty so &decoded_text[0] is valid for empty input.
	unsigned decoded_size = BASE64_DECODE_LENGTH(usize);
	std::vector<uint8_t> decoded_text(decoded_size + 1);

	struct base64_decode_ctx b64_ctx;
	base64_decode_init(&b64_ctx);

	// Every path after get_data() frees the raw copy before leaving.
	uint8_t *encoded_text = (uint8_t *) text.get_data();

	const bool ok = base64_decode_update(&b64_ctx, &decoded_size,
			&decoded_text[0], usize, encoded_text)
		&& base64_decode_final(&b64_ctx);
	if (!ok) {
		// Dump the offending input for debugging. The buffer is not
		// NUL-terminated, so write exactly usize bytes.
		std::cerr << "+++\n";
		std::cerr.write((const char *) encoded_text, usize);
		std::cerr << "\n+++\n";
		free(encoded_text);
		throw std::runtime_error ("CryptEngineNettle::decode(): Could not decode!");
	}
	free(encoded_text);

	bdata_ptr result (new bdata(reinterpret_cast<char *>(&decoded_text[0]), unsigned_to_size(decoded_size)));
	return result;
}

/**
 * Encodes data using base64.
 *
 * @param text Raw bytes to encode.
 * @return A new bdata holding the base64 text, including the final padding
 *         produced by base64_encode_final(). It is not NUL-terminated.
 */
bdata_ptr
CryptEngineNettle::encode(const bdata& text)
        throw (std::bad_alloc,std::runtime_error) {

	const unsigned usize = size_to_unsigned(text.get_size());

	// Worst-case output size: the body plus the final (padding) group. The
	// vector is allocated before the raw input copy is taken, and it releases
	// itself if anything below throws.
	std::vector<uint8_t> encoded_buf(BASE64_ENCODE_LENGTH(usize) + BASE64_ENCODE_FINAL_LENGTH);
	uint8_t * const out = &encoded_buf[0];

	struct base64_encode_ctx b64_ctx;
	base64_encode_init(&b64_ctx);

	uint8_t *text_buf = (uint8_t *) text.get_data();
	unsigned done = base64_encode_update(&b64_ctx, out, usize, text_buf);
	done += base64_encode_final(&b64_ctx, out + done);
	free(text_buf);

	bdata_ptr result (new bdata (reinterpret_cast<char *>(out), unsigned_to_size(done)));
	return result;
}

/**
 * Encodes data using base16 (hex).
 *
 * @param text Raw bytes to encode.
 * @return A new bdata holding BASE16_ENCODE_LENGTH(size) bytes of hex text
 *         (not NUL-terminated).
 */
bdata_ptr
CryptEngineNettle::encode16(const bdata& text)
        throw (std::bad_alloc,std::runtime_error) {

	const unsigned usize = size_to_unsigned(text.get_size());
	const unsigned encoded_size = BASE16_ENCODE_LENGTH(usize);

	// One spare byte keeps the vector non-empty so &encoded_buf[0] is valid
	// for empty input. The vector releases itself if anything below throws.
	std::vector<uint8_t> encoded_buf(encoded_size + 1);

	uint8_t *text_buf = (uint8_t *) text.get_data();
	base16_encode_update(&encoded_buf[0], usize, text_buf);
	free(text_buf);

	bdata_ptr result (new bdata (reinterpret_cast<char *>(&encoded_buf[0]), unsigned_to_size(encoded_size)));
	return result;
}

/**
 * Does symmetric encryption.
 *
 * Encrypts bdatas. This is a GOOD way of encrypting things, since values and size will get
 * in and out encapsulated, and you don't have to worry about them.
 *
 * Currently uses AES in CBC mode. The output is the base64 encoding of
 *   32-byte random IV header || AES-CBC(padded plaintext)
 * where the plaintext is padded with N bytes of value N (1..AES_BLOCK_SIZE).
 * Only the first AES_BLOCK_SIZE bytes of the header are used as the CBC IV,
 * because CBC_SET_IV copies exactly the context's IV size. The full 32 bytes
 * are still emitted so the format stays compatible with sym_dec().
 *
 * @param text Plaintext to encrypt.
 * @param key  AES key; its size must be within [AES_MIN_KEY_SIZE, AES_MAX_KEY_SIZE].
 * @return Base64 text of IV header plus ciphertext.
 * @throws std::runtime_error on an invalid key size or an input too large
 *         for the unsigned lengths Nettle is called with.
 */
bdata_ptr
CryptEngineNettle::sym_enc  (const bdata& text, const Key& key) throw(std::runtime_error, std::bad_alloc) {
	const unsigned iv_header_size = 32;

	const unsigned usize = size_to_unsigned(text.get_size());
	// Guard the size arithmetic below against unsigned wrap-around, which
	// would otherwise under-allocate the padded buffer.
	if (usize > std::numeric_limits<unsigned>::max() - iv_header_size - AES_BLOCK_SIZE)
		throw std::runtime_error ("CryptEngineNettle::sym_enc(): Input too large.");

	//
	// Pad the original: always add 1..AES_BLOCK_SIZE bytes, each holding
	// the pad length (a full block when usize is already block-aligned).
	//
	const unsigned padded_size = ((usize / AES_BLOCK_SIZE) + 1) * AES_BLOCK_SIZE;
	const uint8_t pad_size = static_cast<uint8_t>(padded_size - usize);

	// Allocate every buffer up front; the vectors release themselves on throw.
	std::vector<uint8_t> padded_clear_text(padded_size, pad_size);
	std::vector<uint8_t> cr_text(iv_header_size + padded_size);

	// Reject key sizes Nettle's AES key setup would assert/abort on.
	bdata_ptr b_key = key.get_value();
	const unsigned keysize = size_to_unsigned(b_key->get_size());
	if (keysize < AES_MIN_KEY_SIZE || keysize > AES_MAX_KEY_SIZE)
		throw std::runtime_error ("CryptEngineNettle::sym_enc(): Invalid AES key size.");

	// Each raw buffer below is released immediately after the (non-throwing)
	// C call that consumes it.
	uint8_t *text_buf = (uint8_t *) text.get_data();
	if (usize > 0)
		memcpy(&padded_clear_text[0], text_buf, usize);
	free(text_buf);

	// Random IV header, stored at the front of the output.
	char *iv = random_fill_char(iv_header_size);
	memcpy(&cr_text[0], iv, iv_header_size);
	free(iv);

	//
	// Encrypt
	//
	struct CBC_CTX (struct aes_ctx, AES_BLOCK_SIZE) cr_ctx;
	uint8_t *keydata = (uint8_t *) b_key->get_data();
	aes_set_encrypt_key (&(cr_ctx.ctx), keysize, keydata); // Set AES key
	free(keydata);

	CBC_SET_IV (&cr_ctx, &cr_text[0]);
	CBC_ENCRYPT (&cr_ctx, aes_encrypt, padded_size,
		     &cr_text[0] + iv_header_size, &padded_clear_text[0]);

	//
	// Encode using base64 (the IV header is encoded too).
	//
	bdata_ptr b_text (new bdata(reinterpret_cast<char *>(&cr_text[0]),
				    unsigned_to_size(iv_header_size + padded_size)));
	return encode(*b_text);
}

/**
 * Does symmetric decrypting.
 *
 * Decrypts bdatas. This is a GOOD way of decrypting things, since values and size will get
 * in and out encapsulated, and you don't have to worry about them.
 *
 * Currently uses AES in CBC mode, reversing sym_enc(). The input is
 * base64-decoded, and the first 32 bytes are taken as the IV header (only its
 * first AES_BLOCK_SIZE bytes seed CBC). The remainder is decrypted, and the
 * trailing padding is stripped; the last plaintext byte gives its length.
 *
 * @param text Base64 text as produced by sym_enc().
 * @param key  AES key; its size must be within [AES_MIN_KEY_SIZE, AES_MAX_KEY_SIZE].
 * @return The recovered plaintext.
 * @throws std::runtime_error on malformed input (not block-aligned, too short,
 *         or with an impossible pad length, the usual symptom of a wrong key)
 *         or on an invalid key size.
 */
bdata_ptr
CryptEngineNettle::sym_dec(const bdata& text, const Key& key)
	throw(std::runtime_error, std::bad_alloc) {

	const size_t iv_header_size = 32;

	// Decode from base64:
	bdata_ptr decoded = decode(text);
	const size_t decoded_size = decoded->get_size();

	// Check block size
	if (decoded_size % AES_BLOCK_SIZE) {
		throw std::runtime_error("CryptEngineNettle::sym_dec(): Wrong block size for AES!\n");
	}
	// Require the IV header plus at least one block (sym_enc always emits a
	// padding block); otherwise the size arithmetic below would underflow.
	if (decoded_size < iv_header_size + AES_BLOCK_SIZE) {
		throw std::runtime_error("CryptEngineNettle::sym_dec(): Ciphertext too short for AES!");
	}
	const size_t cipher_size = decoded_size - iv_header_size;
	const unsigned ucipher_size = size_to_unsigned(cipher_size);

	// Reject key sizes Nettle's AES key setup would assert/abort on.
	bdata_ptr b_key = key.get_value();
	const unsigned keysize = size_to_unsigned(b_key->get_size());
	if (keysize < AES_MIN_KEY_SIZE || keysize > AES_MAX_KEY_SIZE) {
		throw std::runtime_error("CryptEngineNettle::sym_dec(): Invalid AES key size.");
	}

	std::vector<uint8_t> result(cipher_size);

	//
	// Decrypt. Each raw buffer is released right after the (non-throwing)
	// C calls that use it.
	//
	struct CBC_CTX (struct aes_ctx, AES_BLOCK_SIZE) cr_ctx;
	uint8_t *keydata = (uint8_t *) b_key->get_data();
	aes_set_decrypt_key (&(cr_ctx.ctx), keysize, keydata); // Set AES key
	free(keydata);

	uint8_t *decoded_text = (uint8_t *) decoded->get_data();
	CBC_SET_IV (&cr_ctx, decoded_text);
	CBC_DECRYPT (&cr_ctx, aes_decrypt, ucipher_size, &result[0],
		     decoded_text + iv_header_size); // skip the IV header
	free(decoded_text);

	// Strip padding. Validate the pad length so a wrong key or corrupted
	// data cannot make the size underflow.
	const size_t pad = result[cipher_size - 1];
	if (pad == 0 || pad > AES_BLOCK_SIZE) {
		throw std::runtime_error("CryptEngineNettle::sym_dec(): Bad padding (wrong key or corrupted data?)");
	}

	bdata_ptr b_text (new bdata(reinterpret_cast<char *>(&result[0]), cipher_size - pad));
	return b_text;
}

/**
 * Does asymmetric encrypting.
 *
 * Currently uses RSA: Nettle's rsa_encrypt(), with randomness drawn from the
 * shared Yarrow generator. The resulting ciphertext integer is exported to
 * bytes with nettle_mpz_get_str_256() and then base64-encoded.
 *
 * @param text Plaintext to encrypt.
 * @param key  Public key as an S-expression accepted by rsa_keypair_from_sexp().
 * @return Base64 text of the RSA ciphertext.
 * @throws std::runtime_error if the key cannot be parsed or rsa_encrypt() fails.
 */
bdata_ptr
CryptEngineNettle::asym_enc (const bdata& text, const Key& key) throw(std::runtime_error, std::bad_alloc) {
	bdata_ptr b_key = key.get_value();

	// GMP-backed state, released on every exit path below.
	struct rsa_public_key rsa_key;
	mpz_t x;
	rsa_public_key_init(&rsa_key);
	mpz_init(x);

	std::vector<uint8_t> pre_result;
	try {
		//
		// Extract the RSA public key to encrypt.
		//
		const unsigned keysize = size_to_unsigned(b_key->get_size());
		uint8_t *keydata = (uint8_t *) b_key->get_data();
		const int key_ok = rsa_keypair_from_sexp(&rsa_key, NULL, 0, keysize, keydata);
		free(keydata);
		if (!key_ok) {
			throw std::runtime_error ("CryptEngineNettle::asym_enc(): Can't read public key.");
		}

		//
		// Encrypt.
		//
		const unsigned usize = size_to_unsigned(text.get_size());
		uint8_t *text_buf = (uint8_t *) text.get_data();
		const int enc_ok = rsa_encrypt(&rsa_key,
				 & CryptEngineNettle::yarrow,
				 (nettle_random_func) yarrow256_random,
				 usize,
				 text_buf,
				 x);
		free(text_buf);
		if (!enc_ok) {
			throw std::runtime_error ("CryptEngineNettle::asym_enc(): RSA_ENCRYPT FAILED");
		}

		// Export the ciphertext integer as bytes.
		const unsigned unew_size = nettle_mpz_sizeinbase_256_u (x);
		pre_result.resize(unew_size);
		nettle_mpz_get_str_256(unew_size, &pre_result[0], x);
	} catch (...) {
		mpz_clear(x);
		rsa_public_key_clear(&rsa_key);
		throw;
	}
	mpz_clear(x);
	rsa_public_key_clear(&rsa_key);

	bdata_ptr b_pre_result (new bdata(reinterpret_cast<char *>(&pre_result[0]), pre_result.size()));
	return encode(*b_pre_result);
}

/**
 * Does asymmetric decrypting.
 *
 * Currently uses RSA, reversing asym_enc(). The input is base64-decoded and
 * the bytes are read back into an integer with nettle_mpz_set_str_256_u().
 * That integer is then decrypted with rsa_decrypt().
 *
 * @param text Base64 text as produced by asym_enc().
 * @param key  Private key as an S-expression accepted by rsa_keypair_from_sexp().
 * @return The recovered plaintext.
 * @throws std::runtime_error if the key cannot be parsed or decryption fails.
 */
bdata_ptr
CryptEngineNettle::asym_dec (const bdata& text, const Key& key) throw(std::runtime_error, std::bad_alloc) {
	//
	// Base64 decoding (done before any GMP state exists, so a throw here
	// leaks nothing).
	//
	bdata_ptr decoded = decode(text);
	bdata_ptr b_key = key.get_value();

	// GMP-backed state, released on every exit path below.
	struct rsa_private_key rsa_key;
	mpz_t x;
	rsa_private_key_init(&rsa_key);
	mpz_init(x);

	std::vector<uint8_t> plaintext;
	unsigned plainsize = 0;
	try {
		//
		// Extract RSA private key to decrypt.
		//
		const unsigned keysize = size_to_unsigned(b_key->get_size());
		uint8_t *keydata = (uint8_t *) b_key->get_data();
		const int key_ok = rsa_keypair_from_sexp(NULL, &rsa_key, 0, keysize, keydata);
		free(keydata);
		if (!key_ok)
			throw std::runtime_error ("CryptEngineNettle::asym_dec(): Can't read private key");

		//
		// Decrypt.
		//
		const unsigned decoded_size = size_to_unsigned(decoded->get_size());
		uint8_t *decoded_text = (uint8_t *) decoded->get_data();
		nettle_mpz_set_str_256_u(x, decoded_size, decoded_text);
		free(decoded_text);

		// Size the output by the key's modulus length (set when the key was
		// parsed), which bounds the recovered message. It is non-zero after
		// a successful parse, so &plaintext[0] is valid.
		plaintext.resize(rsa_key.size);
		plainsize = size_to_unsigned(plaintext.size()); // in: capacity, out: length
		if (!rsa_decrypt(&rsa_key, &plainsize, &plaintext[0], x))
			throw std::runtime_error ("Can't decrypt.");
	} catch (...) {
		mpz_clear(x);
		rsa_private_key_clear(&rsa_key);
		throw;
	}
	mpz_clear(x);
	rsa_private_key_clear(&rsa_key);

	bdata_ptr result (new bdata(reinterpret_cast<char *>(&plaintext[0]), unsigned_to_size(plainsize)));
	return result;
}

/**
 * Convert from size_t to unsigned.
 *
 * The Nettle calls in this file are passed lengths as `unsigned`, while
 * bdata sizes arrive as size_t. Where size_t is wider than unsigned, silently
 * narrowing would make Nettle process fewer bytes than the buffers actually
 * hold, so out-of-range values are rejected instead.
 *
 * @param value Size to convert.
 * @return value as an unsigned.
 * @throws std::runtime_error if value does not fit in an unsigned.
 */
unsigned
CryptEngineNettle::size_to_unsigned(const size_t& value) const {
	if (value > static_cast<size_t>(std::numeric_limits<unsigned>::max())) {
		throw std::runtime_error("CryptEngineNettle::size_to_unsigned(): size does not fit in unsigned");
	}
	return static_cast<unsigned>(value);
}

/**
 * Convert from unsigned to size_t.
 *
 * This is a widening conversion. NOTE(review): it is lossless wherever size_t
 * is at least as wide as unsigned, which holds on common ILP32/LP64/LLP64
 * targets but is not a strict language guarantee.
 *
 * @param value Length to convert.
 * @return value as a size_t.
 */
size_t
CryptEngineNettle::unsigned_to_size(const unsigned& value) const {
	const size_t widened = value;
	return widened;
}

}
