/*
* IPv6 comparison function written by ale in milano on 20 oct 2023

Copyright (C) 2023 Alessandro Vesely

This file is part of Ipqbdb.

Ipqbdb is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

Ipqbdb is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with Ipqbdb.  If not, see <http://www.gnu.org/licenses/>.

*/
#include "max_range_ip6.h"
#include <assert.h>

static int ntz(unsigned char x)
/*
* Count the trailing zero bits of x, scanning from bit 0 upward.
* The result is in 0..7 for a nonzero x, and 8 when x is zero.
*/
{
	if (x == 0)
		return 8;

	// x is nonzero, so some bit at or below 0x80 stops the scan.
	int count = 0;
	for (unsigned bit = 1; (x & bit) == 0; bit <<= 1)
		++count;

	return count;
}

static inline int find10(unsigned char x, int *done)
/*
* Scan x from its most significant bit.  Let lz be the number of
* leading zero bits.  If every bit after those zeroes is one (x has
* the form 0...01...1), set *done to 1 and return lz.  Otherwise a
* zero follows the first one somewhere: set *done to 0 and return
* lz + 1.  The result is the number of bits to subtract from the
* suffix length (the minuend) in max_range().
* x must be neither 0 nor 0xff.
*/
{
	assert(x != 0);
	assert(x != 0xff);
	assert(done);

	// Leading zeroes; the bound keeps the scan finite even for x == 0.
	int lz = 0;
	while (lz < 7 && (x & (0x80 >> lz)) == 0)
		++lz;

	// The value with lz leading zeroes followed by all ones.
	int tail = (1 << (8 - lz)) - 1;
	*done = (x == tail);

	return *done ? lz : lz + 1;
}

int max_range(ip_range const *in)
/*
* Return the lowest prefix length plen (that is, the largest CIDR
* block) such that the block in->u.ipv6/plen starts at in->u.ipv6 and
* ends no later than in->u2.ipv6.  Return -1 if u2.ipv6 < u.ipv6.
*
* The block cannot be larger than the number of trailing zero bits of
* u (mslen, the maximum suffix length).  mslen is then shrunk by
* inspecting the same suffix bits of u2.  Call them s: the result is
* the largest k <= mslen such that 2^k - 1 <= s, and plen = 128 - k.
*
* TODO: There must be a better way to do this.
*/
{
	assert(in);
	assert(in->args > 1);

	/*
	* Count the number of trailing zeroes, where the range starts.
	* On exit ndx indexes the last nonzero byte of u, or is -1 if u
	* is all zeroes.
	*/
	int mslen = 0, ndx = 15;
	while (in->u.ipv6[ndx] == 0)
	{
		mslen += 8;
		if (--ndx < 0)
			break;
	}

	// bits: trailing zero bits inside byte ndx (the partial suffix byte)
	int bits = 0;
	if (ndx >= 0)
	{
		bits = ntz(in->u.ipv6[ndx]);
		mslen += bits;
	}

	/*
	* Full range or impossible if the u2 prefix is higher or lower
	* than the corresponding u prefix.  Bytes before ndx belong
	* entirely to the prefix.
	*/
	for (int i = 0; i < ndx; ++i)
	{
		if (in->u2.ipv6[i] > in->u.ipv6[i])
			return 128 - mslen;
		if (in->u2.ipv6[i] < in->u.ipv6[i])
			return -1;
	}

	// wereone: the first one bit of the u2 suffix has been seen, and
	// every bit after it so far was also one.  While 0, we are still in
	// the leading zeroes of the u2 suffix.
	int wereone = 0;
	if (ndx >= 0)
	{
		int u2 = in->u2.ipv6[ndx], u = in->u.ipv6[ndx];
		int mask = (1 << bits) - 1; // low bits of byte ndx are suffix
		if ((u2 & ~mask) > (u & ~mask))
			return 128 - mslen;

		if ((u2 & ~mask) < (u & ~mask))
			return -1;

		/*
		* Prefixes match.  If the u2 suffix is all ones beyond this
		* point, we still have the full range.  If it is all zeroes
		* (like u's) we have a /128.  Otherwise remove a bit from the
		* suffix for each leading zero.  When a one is found: if all the
		* following bits are ones, that is the suffix; if a zero follows
		* anywhere after the one, remove one more bit.
		*/
		u2 &= mask;
		if (mask && u2 == mask)
			wereone = 1;
		else
		{
			if (u2 == 0)
				mslen -= bits;
			else
			{
				// find10() counts over all 8 bits; discount the
				// (8 - bits) high bits, which belong to the prefix.
				int subtrahend = find10(u2, &wereone) - 8 + bits;
				mslen -= subtrahend;
				if (!wereone)
					return 128 - mslen;
			}
		}
	}

	// The remaining bytes are entirely suffix.
	for (int i = ndx+1; i <= 15; ++i)
	{
		int u2 = in->u2.ipv6[i];
		if (u2 == 0xff)
			wereone = 1;
		else
		{
			if (wereone)
			{
				// a zero after the run of ones: one bit short
				mslen -= 1;
				return 128 - mslen;
			}

			if (u2 == 0)
			{
				mslen -= 8; // still within the leading zeroes
			}
			else
			{
				mslen -= find10(u2, &wereone);
				if (!wereone)
					return 128 - mslen;
			}
		}
	}

	// u2 suffix was all ones from its first one bit, or all zeroes.
	return 128 - mslen;
}

void last_in_range(unsigned char ip[16], int plen)
/*
* Set to one every bit of ip beyond the first plen bits, leaving the
* first plen bits untouched.  Applied to the first address of a
* plen-bit prefix, this yields the last address of that prefix.
* plen must be in 0..128.
*/
{
	assert(plen >= 0);
	assert(plen <= 128);

	int slen = 128 - plen; // number of suffix bits to set
	int ndx = 15;

	// Whole suffix bytes, from the least significant byte upward.
	for (; slen >= 8; slen -= 8)
		ip[ndx--] = 0xff;

	// Remaining low-order suffix bits of the next byte, if any.
	if (slen > 0)
		ip[ndx] |= (unsigned char)((1u << slen) - 1);
}


int add_one(unsigned char u[16])
/*
* Increment by one the 128-bit number stored in u with the most
* significant byte first, carrying from u[15] toward u[0].
* Return 1 if the value wrapped around (u was all ones and is now all
* zeroes), otherwise 0.
*/
{
	int i = 16;
	while (i-- > 0)
	{
		if (u[i] != 0xff)
		{
			++u[i];
			return 0;
		}

		u[i] = 0; // 0xff + 1: keep carrying
	}

	return 1; // overflow
}

#if defined TEST_MAX_RANGE
#include <stdio.h>
#include <string.h>
#include "ip_util.h"

int main(int argc, char *argv[])
/*
* Test driver: for each argument parsed as an IP range, print the
* sequence of CIDR blocks that max_range() and last_in_range() use
* to cover it.  Unparsable arguments are reported on stderr.
*/
{
	for (int i = 1; i < argc; ++i)
	{
		ip_range ip;
		int rtc = parse_ip_address(argv[i], &ip, NULL);
		if (rtc == 0)
		{
			int cnt = 0;
			if (ip.args> 1)
			{
				// code copied from ibd-white.c
				int cmp;
				do // loop for ranges spanning multiple CIDRs
				{
					ip_range next;
					next.ip = 6;
					next.args = 2;
					memcpy(next.u.ipv6, ip.u.ipv6, 16);

					if (ip.args > 1)
					{
						int plen = max_range(&ip);
						if (plen < 0)
						{
							char buf[INET6_ADDRSTRLEN];
							char buf2[INET6_ADDRSTRLEN];
							fprintf(stderr,
								"Cannot get max_range(%s, %s)\n",
								inet_ntop(AF_INET6, ip.u.ipv6, buf, sizeof buf),
								inet_ntop(AF_INET6, ip.u2.ipv6, buf2, sizeof buf2));
							return -2;
						}

						last_in_range(ip.u.ipv6, plen);
					}

					memcpy(next.u2.ipv6, ip.u.ipv6, 16);
					char buf[INET_RANGESTRLEN];
					char buf2[INET6_ADDRSTRLEN];
					printf("%5d)   %s -> %s\n", ++cnt,
						snprint_range(buf, sizeof buf, &next),
						inet_ntop(AF_INET6, next.u2.ipv6, buf2, sizeof buf2));

					/*
					* Step past the block just printed.  If that block ended
					* at ffff:...:ffff the increment wraps to :: and the
					* range is exhausted.  Without this check the wrapped
					* address would compare lower than u2 and the loop
					* would never terminate.
					* NOTE(review): the loop in ibd-white.c may need the
					* same check.
					*/
					if (add_one(ip.u.ipv6))
						break;

					cmp = memcmp(ip.u.ipv6, ip.u2.ipv6, 16);
					ip.args = cmp? 2: 1;
				} while (cmp <= 0);

				putchar('\n');
			}
		}
		else
			fprintf(stderr, "Bad argument %s: %s (%d)\n",
				argv[i], parse_ip_invalid_what(rtc), rtc);
	}

	return 0;
}
#endif // defined TEST_MAX_RANGE
