/*
* Ipq Berkeley db Daemon  BAN - ibd-ban
* written by ale in milano on 13 sep 2008

Copyright (C) 2008-2023 Alessandro Vesely

This file is part of Ipqbdb.

Ipqbdb is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

Ipqbdb is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with Ipqbdb.  If not, see <http://www.gnu.org/licenses/>.

*/
#include <syslog.h>
#include <popt.h>

// format of data packet
#include "config_names.h"
#include "dbstruct.h"

#include <arpa/inet.h>

#include "ip_util.h"
#include "initial_count.h"
#include "percent_prob.h"
#include "ext_exec.h"


// Error-reporting context passed to report_error(); main() fills in its
// mode (stderr, or LOG_DAEMON when --log-syslog) and err_prefix.
// Both definitions precede the #include below, which presumably refers
// to them — confirm in logline_written_block.h.
static app_private ap;
// Program name: message prefix, popt context name and syslog ident.
static char const err_prefix[] = "ibd-ban";
#include "logline_written_block.h"

#include "setsig_func.h"
#include <assert.h>

// Command-line settings, filled in by popt from opttab below.  The
// defaults come from the IPQBDB_* macros.
static int initial_count = IPQBDB_INITIAL_COUNT;
static double initial_decay = IPQBDB_INITIAL_DECAY;
static double index_factor = IPQBDB_INDEX_FACTOR;
static char *db_block_name = IPQBDB_DATABASE_NAME;
static char *db_white_name = IPQBDB_WHITE_DATABASE_NAME;
static char *db_descr_name = IPQBDB_DESCR_DATABASE_NAME;
static char *reason_string = "n.a.";
// NULL-terminated array built by the repeatable --ip-addr option.
static char **ip_address;
// verbose starts at -1: popt stores 0 for a bare -v and main() adds 1, so
// no -v gives 0, a bare -v gives 1, and -v N gives N+1.
static int version_opt, help_opt, syslog_opt, cleanup_opt, verbose = -1;
static int force_probability_opt, force_decay_opt, force_reason_opt;
static int force_add_category_opt; // not a "force" option, but passed like one
static int exec_connkill_opt;
static struct poptOption opttab[] =
{
	{"ip-addr", 'i', POPT_ARG_ARGV, &ip_address, 0,
	"Address or range to ban, repeatable option", "ip_addr"},
	{"reason", 'r', POPT_ARG_STRING|POPT_ARGFLAG_SHOW_DEFAULT, &reason_string, 0,
	"The reason string (or its record number) for the block record", "int-or-string"},
	{"force-reason", '\0', POPT_ARG_NONE, &force_reason_opt, 0,
	"Change the reason even if probability and decay are not changed", NULL},
	{"initial-decay", 't', POPT_ARG_DOUBLE|POPT_ARGFLAG_SHOW_DEFAULT, &initial_decay, 0,
	"The time taken for the block probability to halve (seconds)", "float"},
	{"force-decay", '\0', POPT_ARG_NONE, &force_decay_opt, 0,
	"Change the decay even if the current value is higher", NULL},
	{"initial-count", 'c', POPT_ARG_INT|POPT_ARGFLAG_SHOW_DEFAULT, &initial_count, 0,
	"Set initial probability count", "integer"},
	{"force-probability", '\0', POPT_ARG_NONE, &force_probability_opt, 0,
	"Change the probability even if the current value is higher", NULL},
	{"index-factor", 'f', POPT_ARG_DOUBLE|POPT_ARGFLAG_SHOW_DEFAULT, &index_factor, 0,
	"Set range index factor", "float"},
	{"log-category", '\0', POPT_ARG_NONE, &force_add_category_opt, 0,
	"Enable logging the report category", NULL},
	{"exec-connkill", 'e', POPT_ARG_NONE, &exec_connkill_opt, 0,
	"Execute connection kill if probability reaches 100%", NULL},
	{"verbose", 'v', POPT_ARG_INT|POPT_ARGFLAG_OPTIONAL, &verbose, 0,
	"Be verbose", "level"},
	{"log-syslog", 'l', POPT_ARG_NONE, &syslog_opt, 0,
	"Log to syslog rather than std I/O", NULL},
	{"version", 'V', POPT_ARG_NONE, &version_opt, 0,
	"Print version number and exit", NULL},
	{"help", 'h', POPT_ARG_NONE, &help_opt, 0,
	"This help.", NULL},
	{"db-block", 'b', POPT_ARG_STRING|POPT_ARGFLAG_SHOW_DEFAULT, &db_block_name, 0,
	"The database of blocked addresses", "filename"},
	{"db-white", 'w', POPT_ARG_STRING|POPT_ARGFLAG_SHOW_DEFAULT, &db_white_name, 0,
	"The whitelist database.", "filename"},
	{"db-descr", 'd', POPT_ARG_STRING|POPT_ARGFLAG_SHOW_DEFAULT, &db_descr_name, 0,
	"The reason description table.", "filename"},
	{"db-cleanup", '\0', POPT_ARG_NONE, &cleanup_opt, 0,
	"On exit cleanup environment (__db.00? files) if not still busy", NULL},
	POPT_TABLEEND
};

/*
* Bit flags returned by get_all_ip_address() and combined with "|=";
* main() tests them with "&".  have_both_addr == have_ipv4_addr |
* have_ipv6_addr, and have_error_addr is a separate bit.
*/
typedef enum have_addr
{
	have_no_addr,
	have_ipv4_addr,
	have_ipv6_addr,
	have_both_addr,
	have_error_addr
} have_addr;

static int get_all_ip_address(ip_range **ip_addr)
/*
* Parse every string collected by the repeatable --ip-addr option into a
* calloc'ed array of ip_range, stored in *ip_addr (caller frees).  The
* array ends with a sentinel element whose ip field is -1.
*
* Entries that fail to parse, IPv4 ranges, and IPv6 ranges that do not
* form a valid prefix block are reported.  They stay in the array with
* ip == 0 so that callers skip them.
*
* Return -1 if the array cannot be allocated, 0 if no address was given,
* otherwise a bitwise OR of have_ipv4_addr, have_ipv6_addr and
* have_error_addr.  If no entry is usable, the array is freed and *ip_addr
* is set to NULL; the returned flags still describe what was seen.
*/
{
	unsigned count = 0;
	char **a = ip_address;
	if (a && *a)
		while (*a++)
			++count;

	if (count == 0)
	{
		*ip_addr = NULL;
		return 0;
	}

	ip_range *addr = *ip_addr = calloc(count + 1, sizeof(ip_range));
	if (addr == NULL)
		return -1;

	addr[count].ip = -1; // sentinel

	int collect = 0, good = 0;
	for (a = ip_address; *a; ++a, ++addr)
	{
		int err;
		if ((err = parse_ip_address(*a, addr, NULL)) != 0)
		{
			report_error(&ap, LOG_ERR,
				"invalid ip_address %s (%s)\n", *a,
				parse_ip_invalid_what(err));
			collect |= have_error_addr;
			// parse_ip_address() may have filled part of *addr before
			// failing; make sure callers skip this entry.
			addr->ip = 0;
		}
		else if (addr->ip)
		{
			++good;
			if (addr->ip == 4)
			{
				if (addr->args > 1)
				{
					--good;
					report_error(&ap, LOG_ERR,
						"invalid range %s (IPv4 range not supported)\n",
						*a);
					collect |= have_error_addr;
					addr->ip = 0;
				}
				else
					collect |= have_ipv4_addr;
			}
			else
			{
				assert(addr->ip == 6);
				int invalid = 0;
				if (addr->args > 1) // range
				{
					unsigned char ipv6[16];
					int plen = addr->plen;
					if (plen == 0)
					{
						// No explicit prefix length: range_ip() presumably
						// computes it and adjusts the end address; a changed
						// end means start..end is not a single prefix block.
						memcpy(ipv6, addr->u2.ipv6, sizeof ipv6);
						plen = range_ip(addr->u.ipv6, ipv6, 16);
						if (memcmp(ipv6, addr->u2.ipv6, sizeof ipv6) != 0)
							++invalid;
						else
							addr->plen = plen;
					}

					if (plen < 0 || plen >= 128)
						++invalid;

					// The start address must be the first one of its /plen
					// block.  Only check when plen is usable.
					if (!invalid)
					{
						memcpy(ipv6, addr->u.ipv6, sizeof ipv6);
						first_in_range(ipv6, plen);
						if (memcmp(ipv6, addr->u.ipv6, sizeof ipv6) != 0)
							++invalid;
					}

					if (invalid)
					{
						--good;
						report_error(&ap, LOG_ERR,
							"invalid range %s (not valid /%d)\n",
							*a, plen);
						collect |= have_error_addr;
						addr->ip = 0;
					}
				}

				if (!invalid)
					collect |= have_ipv6_addr;
			}
		}
	}

	if (good == 0)
	{
		free(*ip_addr);
		*ip_addr = NULL;
	}
	return collect;
}

static int display_check_whitelist(DB *db, ip_range *ip_addr, double *decay)
/*
* Look up ip_addr in the whitelist and, if verbose, print the outcome.
* main() passes the IPv4 whitelist DB or, for IPv6, a cursor on it.
*
* Return 0 if addr is to be blocked, 1 if it is whitelisted, < 0 on error.
* On 0, *decay is set to initial_decay.  On 1 it holds whatever
* check_whitelist_e() stored.
*
* An IPv6 range counts as whitelisted only if it is fully contained in
* the whitelisted range that was found, i.e. its end address does not go
* past that range's end.
*/
{
	assert(db);
	assert(ip_addr);
	assert(decay);

	// Zeroed so that a match which does not report an end address counts
	// as "not fully contained" instead of comparing indeterminate bytes.
	unsigned char end_range[16] = {0};
	int rtc = check_whitelist_e(db, base_ip_u(ip_addr), decay, end_range);
	if (rtc < 0)
		return rtc;

	const int is_range = ip_addr->args > 1;
	if (rtc == 1 && is_range) // found range
	{
		assert(ip_addr->ip == 6);
		if (memcmp(end_range, ip_addr->u2.ipv6, sizeof end_range) < 0)
			rtc = 0; // not fully contained
	}

	if (rtc == 0)
	{
		*decay = initial_decay;
	}

	if (verbose)
	{
		double const d = *decay;
		char addr[INET_RANGESTRLEN];

		if (rtc == 1)
		{
			printf("%s is whitelisted", is_range?
				snprint_range(addr, sizeof addr, ip_addr):
				my_inet_ntop(base_ip_u(ip_addr), addr));
			if (d > 0)
				printf(" with initial decay of %g seconds\n", d);
			else
				fputs(" and won't be blocked\n", stdout);
		}
		else if (rtc == 0)
		{
			printf("%s is not whitelisted, initial decay of %g seconds\n",
				is_range?
				snprint_range(addr, sizeof addr, ip_addr):
				my_inet_ntop(base_ip_u(ip_addr), addr), d);
		}
	}

	return rtc;
}

static const char *month_abbr(unsigned ndx)
/*
* Three-letter English month name for a struct tm tm_mon value
* (0 = January); any index past December yields "---".
*/
{
	static const char *const names[] =
	{
		"Jan", "Feb", "Mar", "Apr", "May", "Jun",
		"Jul", "Aug", "Sep", "Oct", "Nov", "Dec"
	};
	unsigned const n_names = sizeof names / sizeof *names;

	return ndx < n_names? names[ndx]: "---";
}

static void tab_printtime(char const *hdr, time_t ref, time_t now)
/*
* Print one tab-indented line made of hdr, the local calendar date of ref
* (e.g. "13 Sep 2008") and how long before now ref was, e.g.
* "(2 day(s)  3h 05m  7s ago)".  Days and hours are omitted when zero.
*
* A ref later than now is shown as 0 elapsed instead of wrapping around
* to a huge unsigned value.  If localtime() cannot convert ref, the date
* is shown as "?" instead of dereferencing a NULL pointer.
*/
{
	unsigned elapsed = ref < now? (unsigned)(now - ref): 0U;

	unsigned const dd = elapsed / (24U * 3600U);
	elapsed -= dd * 24U * 3600U;
	unsigned const hh = elapsed / 3600U;
	elapsed -= hh * 3600U;
	unsigned const mm = elapsed / 60U;
	unsigned const ss = elapsed - mm * 60U;

	char ago[64];
	if (dd)
		snprintf(ago, sizeof ago, "(%u day(s) %2uh %02um %2us ago)",
			dd, hh, mm, ss);
	else if (hh)
		snprintf(ago, sizeof ago, "(%2uh %02um %2us ago)", hh, mm, ss);
	else
		snprintf(ago, sizeof ago, "(%02um %2us ago)", mm, ss);

	char date[32];
	struct tm const *tm = localtime(&ref);
	if (tm)
		snprintf(date, sizeof date, "%d %s %d",
			tm->tm_mday, month_abbr(tm->tm_mon), tm->tm_year + 1900);
	else
		snprintf(date, sizeof date, "?");

	printf("\t%-15s %18s %29s\n", hdr, date, ago);
}

// Room for an IPv6 address plus a "/NNN" prefix suffix and the NUL.
#define INET6_RANGESTRLEN (INET6_ADDRSTRLEN + 5)
static const char *display_range(ip_u *ip_addr, int plen, char addr[INET6_RANGESTRLEN])
/*
* Format ip_addr into addr and return addr.  For IPv6 with 0 < plen < 128,
* the address is reduced with first_in_range() (presumably to the first
* address of its /plen block) and printed as "address/plen".  Otherwise
* my_inet_ntop() formats the plain address.
*/
{
	if (ip_addr->ip == 6 && plen > 0 && plen < 128)
	{
		unsigned char rng[16];
		memcpy(rng, ip_addr->u.ipv6, sizeof rng);
		first_in_range(rng, plen);

		// Format in two steps.  Passing addr both as sprintf()'s
		// destination and as a %s argument copies between overlapping
		// objects, which is undefined behavior.
		if (inet_ntop(AF_INET6, rng, addr, INET6_RANGESTRLEN) == NULL)
			strcpy(addr, "?");
		size_t const len = strlen(addr);
		snprintf(addr + len, INET6_RANGESTRLEN - len, "/%d", plen);
		return addr;
	}
	else
		return my_inet_ntop(ip_addr, addr);
}

static void display_written_block(ip_u *ip_addr, void *key_data,
	int reason_id, int write_block_force,
	ip_data_t *ip_data, ip_data_t *old_data, DB *db_d)
/*
* Callback handed to write_block() when verbose (see write_block_display).
* It prints a human-readable report of the record just written.
*
* ip_data is the new record.  old_data is the previous record, or NULL
* when the record was inserted.  reason_id and write_block_force are only
* forwarded to logline_written_block(), which also runs when logging to
* syslog (ap.mode != 0).
*/
{
	// if syslog_opt, log a line anyway (unlikely given -v).
	if (ap.mode)
		logline_written_block(ip_addr, key_data, reason_id,
			write_block_force, ip_data, old_data, db_d);

	// Use the previous record's prefix length when updating.
	char addr[INET6_RANGESTRLEN];
	display_range(ip_addr, old_data? old_data->plen: ip_data->plen, addr);
	// Presumably sized for the 18-character reason column printed below.
	char descr[19 /* IPQBDB_DESCR_MAX_LENGTH */];
	// Elapsed times are measured relative to the new record's update time.
	time_t now = ip_data->last_update;

	get_descr(db_d, ip_data->reason_id,
		get_descr_quoted | get_descr_add_id, descr, sizeof descr);
	if (old_data)
	{
		printf("changed probability record for %s", addr);
		if (ip_addr->ip == 6 &&
			memcmp(ip_addr->u.ipv6, key_data, sizeof ip_addr->u.ipv6))
		{
			// NOTE(review): the stored key differs from the requested address,
			// presumably because the record is keyed by a range; the new
			// prefix length is shown after "->" — confirm against write_block().
			display_range(ip_addr, ip_data->plen, addr);
			printf("->%s", addr);
		}
		// An unchanged reason goes on the header line; a changed one gets
		// its own "reason:" row below.
		if (ip_data->reason_id == old_data->reason_id)
			printf(", %s", descr);
		fputc('\n', stdout);
		tab_printtime("created:", old_data->created, now);
		tab_printtime("last updated:", old_data->last_update, now);
		if (old_data->last_block)
			tab_printtime("last blocked:", old_data->last_block, now);
		printf("\t%-15s %18g -> %26g\n", "decay:",
			old_data->decay, ip_data->decay);
		printf("\t%-15s %10d=%6.2f%% -> %18d=%6.2f%%\n", "probability:",
			old_data->probability, percent_prob(old_data->probability),
			ip_data->probability, percent_prob(ip_data->probability));
		if (ip_data->reason_id != old_data->reason_id)
		{
			char old_descr[sizeof descr];
			get_descr(db_d, old_data->reason_id,
				get_descr_quoted | get_descr_add_id,
				old_descr, sizeof old_descr);
			printf("\t%-15s %18.18s -> %26.18s\n", "reason:",
				old_descr, descr);
		}
	}
	else
	{
		printf("inserted probability record for %s (reason %d)\n",
			addr, ip_data->reason_id);
		printf("\tdecay: %g, probability: %d=%.2f%%\n",
			ip_data->decay,
			ip_data->probability, percent_prob(ip_data->probability));
	}
}

static int
write_block_display(DB *db, ip_range *ip_addr, double decay, int reason_id, DB *db_d)
/*
* Derive the initial probability from initial_count, turn the --force-*
* and --log-category options into write_block() flags, then write the
* block record.  The record is printed by display_written_block when
* verbose, or logged by logline_written_block when logging to syslog.
*
* Return -1 on error, 0 if all ok, 1 if all ok and probability reached 100%
* (the probability written back by write_block() is >= RAND_MAX).
*/
{
	assert(db);
	assert(ip_addr);
	assert(ip_addr->ip == 4 || ip_addr->ip == 6);
	assert(ip_addr->args == 1 || ip_addr->ip == 6); // range only IPv6

	int probability;
	if (initial_count == 999)
		probability = 0; // special count: zero out the probability
	else
	{
		probability = count2prob(initial_count);
		if (probability <= 0)
		{
			report_error(&ap, LOG_NOTICE,
				"count %d is too high: initial probability set to 1\n",
				initial_count);
			probability = 0; // becomes 1 below, as the notice says
		}
		// Round up, without overflowing a signed int when count2prob()
		// already returns the maximum (e.g. RAND_MAX == INT_MAX).
		if (probability < INT_MAX)
			probability += 1;
	}

	int force = 0;
	if (force_probability_opt) force |= write_block_force_probability;
	if (force_decay_opt) force |= write_block_force_decay;
	if (force_reason_opt) force |= write_block_force_reason;
	if (force_add_category_opt) force |= write_block_force_add_category;

	if (write_block(db, base_ip_u(ip_addr), decay, reason_id,
			&probability, force, verbose? display_written_block:
			ap.mode? logline_written_block: NULL, db_d))
				return -1;

	return probability >= RAND_MAX;
}

static int system_exec_ip(connkill_cmd *cmd, ip_u *ip_addr)
/*
* Hand the textual form of ip_addr to system_exec() together with the
* connection-kill command, and return system_exec()'s result unchanged.
* main() treats -1 as failure and otherwise decodes it as a wait status.
*/
{
	assert(cmd);
	assert(ip_addr);
	assert(ip_addr->ip == 4 || ip_addr->ip == 6);

	char text[INET6_ADDRSTRLEN];
	return system_exec(cmd, my_inet_ntop(ip_addr, text), &ap);
}

static void run_exec_ip(connkill_cmd *cmd, ip_u *ip_addr)
/*
* Hand the textual form of ip_addr to run_exec() together with the
* connection-kill command.  main() calls this when a single address needs
* killing and expects it not to return — presumably run_exec() replaces
* the process; confirm in ext_exec.
*/
{
	assert(cmd);
	assert(ip_addr);
	assert(ip_addr->ip == 4 || ip_addr->ip == 6);

	char text[INET6_ADDRSTRLEN];
	run_exec(cmd, my_inet_ntop(ip_addr, text), &ap);
}

int main(int argc, char const *argv[])
/*
* Parse options and validate the --ip-addr list.  For each usable address,
* check the whitelist and write (or update) its block record.  With
* --exec-connkill, addresses whose probability reached 100% are passed to
* the connection-kill command after the databases are closed.
*
* Exit code:
*   0  success;
*   1  option/setup errors, including after --help or --version;
*   2  database/processing errors or connkill failures;
*   or a higher exit status returned by a connkill command.
*/
{
	static const char optaliases[] = IPQBDB_OPTION_FILE;
	// errs: 1 = bad command line (usage is printed), 2 = --version/--help
	// handled, 3 = other invalid setting.  Any nonzero value skips the work.
	int rtc = 0, errs = 0;
	ip_range *ip_addr = NULL;
	int collect_have_addr = 0;

	poptContext opt = poptGetContext(err_prefix, argc, argv, opttab, 0);

	// Read popt aliases from IPQBDB_OPTION_FILE, if it exists.
	if (access(optaliases, F_OK) == 0 &&
		(rtc = poptReadConfigFile(opt, optaliases)) < 0)
	{
		fprintf(stderr, "%s: cannot read %s: %s\n",
			err_prefix, optaliases, poptStrerror(rtc));
		errs = 3;
	}

	rtc = poptGetNextOpt(opt);
	if (rtc != -1)
	{
		fprintf(stderr, "%s: %s at %s\n",
			err_prefix, poptStrerror(rtc), poptBadOption(opt, 0));
		errs = 1;
	}
	else
	{
		if (version_opt)
		{
			fprintf(stdout, "%s: version " PACKAGE_VERSION "\n"
			DB_VERSION_STRING "\n", err_prefix);
			errs = 2;
		}

		if (help_opt)
		{
			poptPrintHelp(opt, stdout, 0);
			fputs_database_help();
			fputs_initial_count_help();
			fputs("The special value 999 for the initial count"
				" with the --force-probability\nflag can be used to zero out"
				" the probability and grant access to the given IP.\n", stdout);
			errs = 2;
		}

		// popt sets verbose to 0 if no arg is given
		verbose += 1;

		ap.mode = error_report_stderr; // 0
		ap.err_prefix = err_prefix;
		if (syslog_opt)
		{
			openlog(err_prefix, LOG_PID, LOG_DAEMON);
			ap.mode = LOG_DAEMON;
		}

		if (poptPeekArg(opt) != NULL)
		{
			report_error(&ap, LOG_ERR, "unexpected argument: %s\n",
				poptGetArg(opt));
			errs = 1;
		}

		// Fails unless at least one address is usable.  Invalid entries
		// among valid ones are reported but do not stop the run.
		collect_have_addr = get_all_ip_address(&ip_addr);
		if (collect_have_addr < 0 || ip_addr == NULL)
			errs = 3;

		if (initial_decay < 0.0)
		{
			report_error(&ap, LOG_WARNING, "initial decay %g is negative\n",
				initial_decay);
			errs = 3;
		}

		if (index_factor != IPQBDB_INDEX_FACTOR)
			set_ipqbdb_index_factor(index_factor);
	}

	if (errs == 1 && !syslog_opt)
		poptPrintUsage(opt, stdout, 0);
	poptFreeContext(opt);
	rtc = 0;

	if (errs)
		rtc = 1;
	else
	{
		switchable_fname *block_fname = NULL;
		switchable_fname *descr_fname = NULL;
		switchable_fname *white_fname = NULL;

		// A purely numeric --reason is taken as a record number; otherwise
		// reason_id stays -1 and get_reason_id() resolves the string below.
		// NOTE(review): an empty --reason "" also parses as 0 here, because
		// strtoul() leaves t at the (empty) string's terminator.
		int reason_id = -1;
		char *t = NULL;
		unsigned long l = strtoul(reason_string, &t, 0);
		if (t && *t == 0 && l < INT_MAX)
			reason_id = (int)l;

		if (ip_addr)
		{
			block_fname = database_fname(db_block_name, &ap);
			white_fname = database_fname(db_white_name, &ap);
			if (block_fname == NULL || white_fname == NULL)
				rtc = 1;
		}

		// need descrdb anyway, for write_block()
		descr_fname = database_fname(db_descr_name, &ap);
		if (descr_fname == NULL)
			rtc = 1;

		if (rtc)
		{
			if (verbose)
				report_error(&ap, LOG_INFO, "Bailing out...\n");
		}
		// NOTE(review): this branch looks unreachable.  A NULL descr_fname
		// already set rtc above, and a NULL ip_addr made errs nonzero.
		else if (block_fname == NULL && descr_fname == NULL)
		{
			if (verbose)
			{
				report_error(&ap, LOG_WARNING,
					"No ip address to ban, no string description to check\n");
			}
		}
		else
		{
			DB_ENV *db_env = NULL;
			DB *db_block = NULL, *db_white = NULL, *db_de1 = NULL, *db_de2 = NULL;
			DB *db6_block = NULL, *db6_white = NULL;
			DBC *db6c_white = NULL;
			int exec_connkill_cnt = 0;

			setsigs();

			// Open only the IPv4/IPv6 tables actually needed by the addresses.
			rtc = open_database(block_fname, &ap, &db_env,
					collect_have_addr & have_ipv4_addr? &db_block: NULL,
					collect_have_addr & have_ipv6_addr? &db6_block: NULL);

			if (rtc == 0 && white_fname && caught_signal == 0)
				rtc = open_database(white_fname, &ap, &db_env,
					collect_have_addr & have_ipv4_addr? &db_white: NULL,
					collect_have_addr & have_ipv6_addr? &db6_white: NULL);

			// IPv6 whitelist lookups go through a cursor.
			if (rtc == 0 && db6_white && caught_signal == 0)
				rtc = db6_white->cursor(db6_white, NULL, &db6c_white, 0);

			if (rtc == 0 && descr_fname && caught_signal == 0)
				rtc = open_descrdb(descr_fname->fname, db_env, &db_de1, &db_de2);

			if (rtc == 0 && reason_id == -1 && caught_signal == 0)
				rtc = get_reason_id(&reason_id, reason_string,
					&ap, db_de1, db_de2, verbose);

			if (rtc == 0 && ip_addr && caught_signal == 0)
			{
				double decay;
				ip_range *next_ip = ip_addr;
				// The array ends with ip == -1.  Entries with ip == 0 were
				// rejected by get_all_ip_address() and are skipped.
				while (rtc == 0 && next_ip->ip >= 0 && caught_signal == 0)
				{
					int ip = next_ip->ip;
					if (ip > 0)
					{
						assert(ip == 4 || ip == 6);
						// IPv4: whitelist DB handle; IPv6: cursor on it.
						void *db_or_c;
						if (ip == 4)
							db_or_c = db_white;
						else
							db_or_c = db6c_white;
						rtc = display_check_whitelist(db_or_c, next_ip, &decay);
						if (rtc < 0)
						{
							DB *db_err = ip == 4? db_white: db6_white;
							db_err->errx(db_err, "cannot check whitelist");
						}

						// A decay of 0 means "whitelisted, never block".
						if (rtc == 0 && caught_signal == 0 && decay > 0.0)
						{
							int a_rc = write_block_display(
								ip == 4? db_block: db6_block,
								next_ip, decay, reason_id, db_de1);
							// Probability reached 100%: candidate for connkill.
							if (a_rc > 0)
							{
								exec_connkill_cnt += 1;
								next_ip->kill = 1;
							}
							else if (a_rc != 0)
								rtc = -1; // bail out on error
						}
						else if (rtc == 1) // whitelisted
							rtc = 0;
					}
					++next_ip;
				}
			}

			if (rtc)
				rtc = 2;

			// The cursor is closed before the database it belongs to.
			close_db(db_de2);
			close_db(db_de1);
			close_db(db_white);
			close_db(db_block);
			if (db6c_white)
				db6c_white->close(db6c_white);
			close_db(db6_white);
			close_db(db6_block);

			close_dbenv(db_env, cleanup_opt);

			// Kill connections only after all databases are closed.
			// NOTE(review): cmd is never freed here.
			if (exec_connkill_cnt > 0 && exec_connkill_opt && caught_signal == 0)
			{
				connkill_cmd *const cmd = read_connkill_cmd(&ap);
				ip_range *next_ip = ip_addr;
				if (cmd)
				{
					while (next_ip->ip >= 0 && caught_signal == 0)
					{
						int ip = next_ip->ip;
						if (ip > 0 && next_ip->kill)
						{
							// A single target is handed to run_exec_ip(), which
							// is not expected to return.  Otherwise each target
							// runs via system_exec_ip() and is checked.
							if (exec_connkill_cnt == 1)
								run_exec_ip(cmd, base_ip_u(next_ip)); // won't return

							int rtc_kill = system_exec_ip(cmd, base_ip_u(next_ip));
							if (rtc_kill == -1)
							{
								rtc = 2;
								break;
							}

							// rtc_kill is decoded as a wait status.
							if (rtc_kill)
							{
								if (WIFSIGNALED(rtc_kill))
								{
									rtc = 2;
									caught_signal = 1;
								}
								else
								{
									// 127 usually means the command could not run.
									int r = WEXITSTATUS(rtc_kill);
									if (r == 127)
									{
										rtc = 2;
										break;
									}

									if (r > rtc)
										rtc = r;
								}
							}
						}
						++next_ip;
					}
				}
			}
		}
		free(block_fname);
		free(descr_fname);
		free(white_fname);
	}

	free(ip_addr);
	return rtc;
}

