/* tls.c -- encryption and authentication code
   Copyright (C) 2006-2020 Maximiliano Pin

   This program is free software; you can redistribute it and/or
   modify it under the terms of the GNU General Public License
   as published by the Free Software Foundation; either version 2
   of the License, or (at your option) any later version.

   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License along
   with this program; if not, write to the Free Software Foundation, Inc.,
   51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
*/

#define _POSIX_SOURCE

#include <unistd.h>		/* read, write, fcntl */
#include <fcntl.h>		/* fcntl */
#include <errno.h>		/* errno */
#include <stdint.h>		/* intptr_t */
#include <string.h>		/* strerror */
#include "common.h"
#if HAVE_GNUTLS
#include <gnutls/gnutls.h>
#endif
#include "tls.h"
#include "demux.h"
#include "user_iface.h"

/* errno comes from <errno.h>; it may be a macro (e.g. thread-local), so it
   must not be redeclared here (C11 7.5p2 makes that undefined behavior). */

#if HAVE_GNUTLS

/* Implementation based on example at:
   https://gnutls.org/manual/html_node/Simple-client-example-with-anonymous-authentication.html
*/

/* Debug switch: when defined, tls_session_init() prints the cipher suites
   allowed by its priority string and tls_handshake() prints a session
   description after a successful handshake.
#define TLS_DEBUG 1
*/

/* Expands to CHECK() applied to "X returned GNUTLS_E_SUCCESS".  CHECK() is
   defined in common.h (not visible here); presumably it jumps to a local
   `error:' label when its condition is false -- confirm. */
#define CHECK_TLS_SUCCESS(x) CHECK ((x) == GNUTLS_E_SUCCESS)

/* Evaluate the GnuTLS call X exactly once; on failure show the GnuTLS
   error text and carry on (the error is not propagated to the caller). */
#define CALL_TLS(x)                                                     \
	do {                                                            \
		int tls_rc_ = (x);                                      \
		if (tls_rc_ != GNUTLS_E_SUCCESS)                        \
			ui_output_err ("%s", gnutls_strerror (tls_rc_)); \
	} while (0)

/* Anonymous credentials shared by every session: allocated in tls_init(),
   released in tls_finish(). */
static gnutls_anon_server_credentials_t anoncred_srv;
static gnutls_anon_client_credentials_t anoncred_cli;

/* Prototypes */
static void cb_continue_handshake (int fd, void *data);
#ifdef TLS_DEBUG
static void print_cipher_suite_list (const char *priorities);
#endif

#endif

/* TRUE only while tls_init() has successfully set up GnuTLS and both
   anonymous credential sets; tls_finish() releases them only then. */
static BOOL enabled = FALSE;

/* Initialise GnuTLS and allocate the anonymous server and client
   credentials shared by every session.  On success tls_enabled() starts
   returning TRUE.  On failure the GnuTLS error is shown, everything
   acquired so far is released again and TLS stays disabled, so a later
   tls_finish() has nothing to undo.  Without GnuTLS support this does
   nothing. */
void
tls_init (void)
{
#if HAVE_GNUTLS
	int ret;

	ret = gnutls_global_init ();
	if (ret != GNUTLS_E_SUCCESS)
		goto fail;

	ret = gnutls_anon_allocate_server_credentials (&anoncred_srv);
	if (ret != GNUTLS_E_SUCCESS)
		goto fail_global;

	ret = gnutls_anon_allocate_client_credentials (&anoncred_cli);
	if (ret != GNUTLS_E_SUCCESS)
		goto fail_server;

	enabled = TRUE;
	return;

	/* Unwind in reverse order of acquisition. */
fail_server:
	gnutls_anon_free_server_credentials (anoncred_srv);
	anoncred_srv = NULL;
fail_global:
	gnutls_global_deinit ();
fail:
	ui_output_err ("%s", gnutls_strerror (ret));
#endif
}

/* Undo tls_init(): free both credential sets and deinitialise GnuTLS.
   Does nothing unless tls_init() succeeded, so it is safe after a failed
   or missing initialisation and safe to call twice; TLS is reported as
   disabled afterwards. */
void
tls_finish (void)
{
#if HAVE_GNUTLS
	if (!enabled)
		return;

	gnutls_anon_free_server_credentials (anoncred_srv);
	gnutls_anon_free_client_credentials (anoncred_cli);
	anoncred_srv = NULL;
	anoncred_cli = NULL;
	gnutls_global_deinit ();
	enabled = FALSE;
#endif
}

/* Report whether tls_init() completed every initialisation step, i.e.
   whether encrypted sessions can be created. */
BOOL
tls_enabled (void)
{
	BOOL is_on = enabled;

	return is_on;
}

/* Create the GnuTLS session for CONTACT's socket and store it in
   contact->state.tls_info; tls_handshake() later performs the handshake
   on that stored session.  SERVER selects the server or client role and
   the matching anonymous credentials.

   Returns TRUE on success (and always when built without GnuTLS).  On
   failure the error is shown, the half-built session is freed, nothing
   is attached to CONTACT and FALSE is returned. */
BOOL
tls_session_init (contact_t *contact, int server)
{
#if HAVE_GNUTLS
	gnutls_session_t session;
	void *cred;
	int ret;
	const char *errp = NULL;
	/* Anonymous (EC)DH key exchange added to GnuTLS' PERFORMANCE set. */
	const char *priorities = "PERFORMANCE:+ANON-ECDH:+ANON-DH";

	ret = gnutls_init (&session, server ? GNUTLS_SERVER : GNUTLS_CLIENT);
	if (ret != GNUTLS_E_SUCCESS) {
		ui_output_err ("SSL/TLS init failed!");
		ui_output_err ("%s", gnutls_strerror (ret));
		return FALSE;
	}

	ret = gnutls_priority_set_direct (session, priorities, &errp);
	if (ret != GNUTLS_E_SUCCESS) {
		ui_output_err ("SSL/TLS set priority failed!");
		/* errp may be left NULL; never hand NULL to %s */
		ui_output_err ("%s (at: \"%s\")", gnutls_strerror (ret),
		               errp ? errp : "?");
		goto fail;
	}

#ifdef TLS_DEBUG
	print_cipher_suite_list (priorities);
#endif

	cred = server ? (void *)anoncred_srv : (void *)anoncred_cli;
	ret = gnutls_credentials_set (session, GNUTLS_CRD_ANON, cred);
	if (ret != GNUTLS_E_SUCCESS) {
		/* without credentials the anonymous handshake cannot succeed */
		ui_output_err ("SSL/TLS set credentials failed!");
		ui_output_err ("%s", gnutls_strerror (ret));
		goto fail;
	}

	/* The socket descriptor travels through the pointer-sized transport
	   argument. */
	gnutls_transport_set_ptr
	    (session, (gnutls_transport_ptr_t)(intptr_t)contact->state.socket);

	contact->state.tls_info = (void *)session;
	return TRUE;

fail:
	gnutls_deinit (session);
	return FALSE;
#else
	(void)contact;
	(void)server;
	return TRUE;
#endif
}

/* Free the GnuTLS session attached to CONTACT, if there is one, and clear
   contact->state.tls_info so that tls_send()/tls_recv() fall back to the
   plain socket.  No-op without GnuTLS support. */
void
tls_session_deinit (contact_t *contact)
{
#if HAVE_GNUTLS
	gnutls_session_t sess = (gnutls_session_t)contact->state.tls_info;

	if (sess == NULL)
		return;

	gnutls_deinit (sess);
	contact->state.tls_info = NULL;
#endif
}

/* Send the TLS close_notify for CONTACT's session (write side only, via
   gnutls_bye(GNUTLS_SHUT_WR)).  The session itself is not freed; see
   tls_session_deinit().  Errors are shown but not returned.  No-op when
   no session is attached or without GnuTLS support. */
void
tls_session_finish (contact_t *contact)
{
#if HAVE_GNUTLS
	gnutls_session_t session = (gnutls_session_t)contact->state.tls_info;
	int flags;
	int ret;

	if (!session)
		return;

	/* gnutls_bye() may return GNUTLS_E_AGAIN on a non-blocking socket, so
	   switch the socket to blocking mode first.  Only O_NONBLOCK is
	   cleared; other file status flags are preserved.  Best effort: if
	   fcntl() fails we still attempt the shutdown. */
	flags = fcntl (contact->state.socket, F_GETFL);
	if (flags != -1 && (flags & O_NONBLOCK))
		(void)fcntl (contact->state.socket, F_SETFL, flags & ~O_NONBLOCK);

	/* A signal may interrupt the blocking send; just retry. */
	do {
		ret = gnutls_bye (session, GNUTLS_SHUT_WR);
	} while (ret == GNUTLS_E_INTERRUPTED);

	if (ret != GNUTLS_E_SUCCESS)
		ui_output_err ("%s", gnutls_strerror (ret));
#endif
}

/* Drive the TLS handshake for CONTACT; cb_continue_handshake() calls this
   again while the handshake is in progress.

   - No session attached: the contact is simply marked CS_CONNECTED.
   - GnuTLS needs more I/O (GNUTLS_E_AGAIN / GNUTLS_E_INTERRUPTED): the
     first time, cb_continue_handshake() is registered with
     dmx_add_output_fd() and the contact enters CS_HANDSHAKING.
   - Handshake failed or succeeded: dmx_remove_output_fd() is called, the
     contact becomes CS_CONNECTED and the contact list is redrawn.
   Without GnuTLS support the contact is marked CS_CONNECTED. */
void
tls_handshake (contact_t *contact)
{
#if HAVE_GNUTLS
	gnutls_session_t tls = (gnutls_session_t)contact->state.tls_info;
	int status;

	if (tls == NULL) {
		contact->state.cn_state = CS_CONNECTED;
		return;
	}

	status = gnutls_handshake (tls);

	/* TODO what if the HELLO isn't still fully sent?
	        (the socket is still in demux!) */
	if (status == GNUTLS_E_AGAIN || status == GNUTLS_E_INTERRUPTED) {
		/* Not finished yet: register the retry callback only once. */
		if (contact->state.cn_state == CS_HANDSHAKING)
			return;
		dmx_add_output_fd (contact->state.socket,
		                   cb_continue_handshake, (void *)contact);
		contact->state.cn_state = CS_HANDSHAKING;
		return;
	}

	if (status < 0) {
		/*TODO*/
		ui_output_err ("SSL/TLS handshake failed!");
		ui_output_err ("%s", gnutls_strerror (status));
		dmx_remove_output_fd (contact->state.socket);
		/* TODO should it be disconnected? or show warning */
		/* NOTE(review): the failed session stays in
		   contact->state.tls_info while cn_state becomes CS_CONNECTED,
		   so tls_send()/tls_recv() will presumably keep routing data
		   through it -- confirm this is intended. */
		tls_session_finish (contact);
		contact->state.cn_state = CS_CONNECTED;
		ui_redraw_contacts ();
		return;
	}

	/* Handshake completed. */
	dmx_remove_output_fd (contact->state.socket);
	contact->state.cn_state = CS_CONNECTED;
	/* this will create a window for the contact */
	ui_redraw_contacts ();
#ifdef TLS_DEBUG
	{
		char *desc = gnutls_session_get_desc (tls);
		ui_output_info ("Session info: %s", desc);
		gnutls_free (desc);
	}
#endif
#else
	contact->state.cn_state = CS_CONNECTED;
#endif
}

/* Send SIZE bytes from DATA to CONTACT.  When a TLS session is attached
   and CT_IS_CONNECTED(contact) holds, the bytes go through
   gnutls_record_send(); otherwise they are written to the socket as-is.

   The underlying result is returned unchanged: a byte count, or a
   negative GnuTLS error code / -1 with errno set, depending on the path.
   TODO think about the return value */
ssize_t
tls_send (contact_t *contact, const void *data, size_t size)
{
#if HAVE_GNUTLS
	gnutls_session_t tls = (gnutls_session_t)contact->state.tls_info;

	if (tls != NULL && CT_IS_CONNECTED (contact))
		return gnutls_record_send (tls, data, size);
#endif
	/* connection is not encrypted */
	return write (contact->state.socket, data, size);
}

ssize_t
tls_recv (contact_t *contact, void *data, size_t size)
{
	ssize_t ret;

#if HAVE_GNUTLS
	gnutls_session_t session = (gnutls_session_t)contact->state.tls_info;

	/* TODO implement functionality of net_read() in transport.c and
	   remove that */
	/* TODO think about return value */

	if (session && CT_IS_CONNECTED (contact)) {
		ret = gnutls_record_recv (session, data, size);
		if (ret == 0) {
			ret = ERR; /* disconnect */
		}
		else if (ret < 0) {
			if (ret == GNUTLS_E_INTERRUPTED ||
			    ret == GNUTLS_E_AGAIN)
				ret = 0;
			else {
				ui_output_err ("Error receiving data.");
				ui_output_err ("%s", gnutls_strerror (ret));
				ret = ERR;
			}
		}
	}
	else
#endif
	{
		/* connection is not encrypted */
		ret = read (contact->state.socket, data, size);
		if (ret == 0) {
			ret = ERR; /* disconnect */
		}
		else if (ret < 0) {
			if (errno == EINTR || errno == EAGAIN)
				ret = 0;
			else {
				ui_output_err ("Error receiving data.");
				ret = ERR;
			}
		}
	}

	return ret;
}

#if HAVE_GNUTLS
/* Return NAME, or a placeholder when GnuTLS had no name to give (the
   gnutls_*_get_name() lookups may return NULL), so it is safe for %s. */
static const char *
tls_name_or_unknown (const char *name)
{
	return name ? name : "unknown";
}
#endif

/* Show, via ui_output_info(), the parameters of CONTACT's TLS session:
   key exchange, authentication details, protocol, certificate type,
   cipher and MAC.  Prints a notice instead when no session is attached.
   No-op without GnuTLS support. */
void
tls_session_info (contact_t *contact)
{
#if HAVE_GNUTLS
	gnutls_session_t session = (gnutls_session_t)contact->state.tls_info;
	const char *tmp;
	gnutls_credentials_type_t cred;
	gnutls_kx_algorithm_t kx;
	int dh_bits;

	if (!session) {
		ui_output_info ("\\b\\2-\\0 SSL/TLS session not established");
		return;
	}

	/* print the key exchange's algorithm name */
	kx = gnutls_kx_get (session);
	tmp = gnutls_kx_get_name (kx);
	ui_output_info ("\\b\\2-\\0 Key Exchange: %s", tls_name_or_unknown (tmp));

	/* report details specific to the authentication type used */
	cred = gnutls_auth_get_type (session);
	switch (cred) {
	case GNUTLS_CRD_SRP:
		ui_output_info ("\\b\\2-\\0 SRP session with username %s",
		                tls_name_or_unknown
		                    (gnutls_srp_server_get_username (session)));
		break;
	case GNUTLS_CRD_ANON:	/* anonymous authentication */
		/* > 0: DH was used; 0: no DH (e.g. ECDH); < 0: error */
		dh_bits = gnutls_dh_get_prime_bits (session);
		if (dh_bits > 0) {
			ui_output_info ("\\b\\2-\\0 Anonymous DH using prime "
					"of %d bits", dh_bits);
		}
		else {
			ui_output_info ("\\b\\2-\\0 Anonymous authentication");
		}
		break;
	case GNUTLS_CRD_CERTIFICATE:	/* certificate authentication */
		/* check if we have been using ephemeral Diffie Hellman */
		if (kx == GNUTLS_KX_DHE_RSA || kx == GNUTLS_KX_DHE_DSS) {
			ui_output_info ("\\b\\2-\\0 Ephemeral DH using prime "
			                "of %d bits",
			                gnutls_dh_get_prime_bits (session));
		}
		/* if the certificate list is available, then print some
		   information about it */
		/*TODO print_x509_certificate_info (session);*/
		break;
	default:
		/* other credential types (e.g. PSK): nothing extra to show */
		break;
	}

	/* print the protocol's name (ie TLS 1.0) */
	tmp = gnutls_protocol_get_name (gnutls_protocol_get_version (session));
	ui_output_info ("\\b\\2-\\0 Protocol: %s", tls_name_or_unknown (tmp));

	/* print the certificate type of the peer. ie X.509 */
	tmp = gnutls_certificate_type_get_name
		(gnutls_certificate_type_get (session));
	ui_output_info ("\\b\\2-\\0 Certificate Type: %s",
	                tls_name_or_unknown (tmp));

	/* print the name of the cipher used. ie 3DES */
	tmp = gnutls_cipher_get_name (gnutls_cipher_get (session));
	ui_output_info ("\\b\\2-\\0 Cipher: %s", tls_name_or_unknown (tmp));

	/* print the MAC algorithms name. ie SHA1 */
	tmp = gnutls_mac_get_name (gnutls_mac_get (session));
	ui_output_info ("\\b\\2-\\0 MAC: %s", tls_name_or_unknown (tmp));
#endif
}

#if HAVE_GNUTLS

static void
cb_continue_handshake (int fd, void *data)
{
	tls_handshake ((contact_t *)data);
}

#ifdef TLS_DEBUG

/* Debug helper: list, via ui_output_info(), every cipher suite enabled by
   the GnuTLS priority string PRIORITIES, with its two-byte cipher-suite id
   and the minimum protocol version reported by gnutls_cipher_suite_info(). */
static void
print_cipher_suite_list (const char *priorities)
{
	unsigned int i;
	unsigned int idx;
	int ret;
	const char *name;
	const char *err = NULL;	/* not set by every failure path */
	unsigned char id[2];
	gnutls_protocol_t version;
	gnutls_priority_t pcache;

	ui_output_info ("Cipher suites for %s", priorities);

	ret = gnutls_priority_init (&pcache, priorities, &err);
	if (ret < 0) {
		ui_output_err ("Syntax error at: %s", err ? err : "?");
		return;
	}

	for (i = 0;; i++) {
		ret = gnutls_priority_get_cipher_suite_index (pcache, i, &idx);
		if (ret == GNUTLS_E_REQUESTED_DATA_NOT_AVAILABLE)
			break;
		/* GNUTLS_E_UNKNOWN_CIPHER_SUITE or any other failure: idx is
		   not valid for this entry, skip it */
		if (ret < 0)
			continue;

		name = gnutls_cipher_suite_info (idx, id, NULL, NULL, NULL,
		                                 &version);
		if (name != NULL) {
			ui_output_info ("%-50s\t0x%02x, 0x%02x\t%s",
			                name, id[0], id[1],
			                gnutls_protocol_get_name (version));
		}
	}

	gnutls_priority_deinit (pcache);
}

#endif
#endif
