/*-GNU-GPL-BEGIN-*
nepim - network pipemeter
Copyright (C) 2005 Everton da Silva Marques

nepim is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2, or (at your option)
any later version.

nepim is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with nepim; see the file COPYING.  If not, write to
the Free Software Foundation, Inc., 59 Temple Place - Suite 330,
Boston, MA 02111-1307, USA.
*-GNU-GPL-END-*/


/* $Id: server.c,v 1.9 2005/04/19 03:40:33 evertonm Exp $ */

#include <stdio.h>
#include <stdlib.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netdb.h>
#include <assert.h>
#include <string.h>
#include <errno.h>

#include "conf.h"
#include "sock.h"
#include "pipe.h"
#include "common.h"

/* Table of active TCP pipes, looked up by connected socket descriptor.
   Entries are added by parse_greetings() and removed by tcp_pipe_kill(). */
nepim_pipe_set_t pipes;

/* Wildcard addresses that nepim_server_run() binds its listeners to. */
const char *INET_ANY = "0.0.0.0";
const char *INET6_ANY = "::";

/* Forward declarations: most of these helpers are referenced by the
   event callbacks below before their definitions appear. */
static void schedule_stat_interval(int sd);
static void tcp_pipe_cancel_timers(int sd);
static void tcp_pipe_cancel_io(int sd);
static void tcp_pipe_kill(int sd);
static void *on_tcp_rate_delay(oop_source *src, struct timeval tv, void *user);

/*
 * OOP_READ callback for an established TCP pipe.
 *
 * Reads up to nepim_global.tcp_read_size bytes and accounts them in the
 * pipe's total and per-interval counters; the payload itself is discarded.
 *
 * EINTR/EAGAIN are logged and ignored (the event loop will call again).
 * End-of-stream (read() == 0) or any other error reports the broken-pipe
 * statistics (unless the test duration already elapsed) and tears the
 * pipe down with tcp_pipe_kill().
 *
 * Always returns OOP_CONTINUE.
 */
static void *on_tcp_read(oop_source *src, int sd,
			 oop_event event, void *user)
{
  char buf[nepim_global.tcp_read_size];
  ssize_t rd;
  nepim_pipe_t *pipe = user;

  rd = read(sd, buf, nepim_global.tcp_read_size);

  /* errno is only meaningful when read() reports failure (-1). On EOF
     (0) it may hold a stale EINTR/EAGAIN from an earlier call, which
     would make us ignore EOF forever on a permanently readable socket. */
  if (rd < 0) {
    switch (errno) {
    case EINTR:
      fprintf(stderr, "read: EINTR on TCP socket %d\n", sd);
      return OOP_CONTINUE;
    case EAGAIN:
      fprintf(stderr, "read: EAGAIN on TCP socket %d\n", sd);
      return OOP_CONTINUE;
    }
  }

  if (rd < 1) {
    fprintf(stderr, "read: connection lost on TCP socket %d\n", sd);

    if (!pipe->duration_done)
      report_broken_pipe_stat(stdout, pipe);

    tcp_pipe_kill(sd);

    return OOP_CONTINUE;
  }

  pipe->byte_total_recv += rd;
  pipe->byte_interval_recv += rd;
  ++pipe->total_reads;
  ++pipe->interval_reads;

  return OOP_CONTINUE;
}

/*
 * OOP_WRITE callback for a pipe that sends at full speed (no bit-rate
 * limit).
 *
 * Writes one chunk of nepim_global.tcp_write_size bytes and accounts the
 * bytes actually written in the total and per-interval counters.
 * EINTR/EAGAIN are logged and ignored. Any other failure reports the
 * broken-pipe statistics (unless the test duration already elapsed) and
 * tears the pipe down with tcp_pipe_kill().
 *
 * Always returns OOP_CONTINUE.
 */
static void *on_tcp_write(oop_source *src, int sd,
			  oop_event event, void *user)
{
  char buf[nepim_global.tcp_write_size];
  ssize_t wr;
  nepim_pipe_t *pipe = user;

  /* The payload content does not matter for the measurement, but it must
     not be uninitialised stack memory: that would leak stale process
     data to the remote peer. */
  memset(buf, 0, sizeof(buf));

  wr = write(sd, buf, nepim_global.tcp_write_size);

  /* errno is only meaningful when write() returned -1. */
  if (wr < 0) {
    switch (errno) {
    case EINTR:
      fprintf(stderr, "write: EINTR on TCP socket %d\n", sd);
      return OOP_CONTINUE;
    case EAGAIN:
      fprintf(stderr, "write: EAGAIN on TCP socket %d\n", sd);
      return OOP_CONTINUE;
    case EPIPE:
      /* NOTE(review): EPIPE is only observed if SIGPIPE is ignored or
	 handled elsewhere -- confirm. */
      fprintf(stderr, "write: EPIPE on TCP socket %d\n", sd);
      break;
    }
  }

  if (wr < 1) {
    fprintf(stderr, "write: connection lost on TCP socket %d\n", sd);

    if (!pipe->duration_done)
      report_broken_pipe_stat(stdout, pipe);

    tcp_pipe_kill(sd);

    return OOP_CONTINUE;
  }

  pipe->byte_total_sent += wr;
  pipe->byte_interval_sent += wr;
  ++pipe->total_writes;
  ++pipe->interval_writes;

  return OOP_CONTINUE;
}

/*
 * OOP_WRITE callback for a rate-limited pipe (max_bit_rate > 0).
 *
 * Each period, on_tcp_rate_delay() sets pipe->rate_remaining to the byte
 * quota and enables this watcher. Every call writes at most
 * min(rate_remaining, tcp_write_size) bytes. Once the quota is used up
 * (or there was nothing to write), the watcher is cancelled and
 * on_tcp_rate_delay() is scheduled for pipe->tv_rate.
 *
 * EINTR/EAGAIN are logged and ignored. Any other write failure reports
 * the broken-pipe statistics (unless the duration already elapsed) and
 * tears the pipe down.
 *
 * Always returns OOP_CONTINUE.
 */
static void *on_tcp_rate_write(oop_source *src, int sd,
			       oop_event event, void *user)
{
  char buf[nepim_global.tcp_write_size];
  ssize_t wr;
  nepim_pipe_t *pipe = user;
  int to_write;

  assert(event == OOP_WRITE);
  assert(sd == pipe->sd);
  assert(pipe->max_bit_rate > 0);

  to_write = NEPIM_MIN(pipe->rate_remaining, nepim_global.tcp_write_size);

  /* A zero quota (low bit rate times a short delay rounds down to 0 bytes)
     is not an error: write(..., 0) would return 0, which used to be
     mistaken for a lost connection. */
  if (to_write > 0) {
    /* never send uninitialised stack memory to the peer */
    memset(buf, 0, to_write);

    wr = write(sd, buf, to_write);

    /* errno is only meaningful when write() returned -1. */
    if (wr < 0) {
      switch (errno) {
      case EINTR:
	fprintf(stderr, "rate_write: EINTR on TCP socket %d\n", sd);
	return OOP_CONTINUE;
      case EAGAIN:
	fprintf(stderr, "rate_write: EAGAIN on TCP socket %d\n", sd);
	return OOP_CONTINUE;
      case EPIPE:
	fprintf(stderr, "rate_write: EPIPE on TCP socket %d\n", sd);
	break;
      }
    }

    if (wr < 1) {
      fprintf(stderr, "rate_write: connection lost on TCP socket %d\n", sd);

      if (!pipe->duration_done)
	report_broken_pipe_stat(stdout, pipe);

      tcp_pipe_kill(sd);

      return OOP_CONTINUE;
    }

    pipe->byte_total_sent += wr;
    pipe->byte_interval_sent += wr;
    pipe->rate_remaining -= wr;
    ++pipe->total_writes;
    ++pipe->interval_writes;

#if 0
    fprintf(stderr, 
	    "DEBUG %s %s %d: to_write=%d written=%d missing=%d\n",
	    __FILE__, __PRETTY_FUNCTION__,
	    sd, to_write, (int) wr, pipe->rate_remaining);
#endif
  }

  /* quota for this period exhausted (or nothing to write) ? */
  if (to_write < 1 || pipe->rate_remaining < 1) {
    /* stop writing */
    nepim_global.oop_src->cancel_fd(nepim_global.oop_src,
				    sd, OOP_WRITE);

    /* schedule for next time saved by on_tcp_rate_delay() */
    nepim_global.oop_src->on_time(nepim_global.oop_src,
				  pipe->tv_rate,
				  on_tcp_rate_delay, pipe);
  }

  return OOP_CONTINUE;
}

/*
 * Timer callback that starts one period of the rate-limited writer.
 *
 * It stores the start time of the next period (now + write_delay
 * microseconds) in pipe->tv_rate. It then sets pipe->rate_remaining to
 * max_bit_rate * write_delay / 8,000,000 bytes and enables the
 * on_tcp_rate_write() watcher. If that quota rounds down to zero, the
 * watcher is skipped and this timer is simply re-armed for the next
 * period.
 *
 * Always returns OOP_CONTINUE.
 */
static void *on_tcp_rate_delay(oop_source *src, struct timeval tv, void *user)
{
  nepim_pipe_t *pipe = user;
  long long tmp;

  /* timercmp() with '==' is not portable (see timercmp(3)); compare the
     fields directly. */
  assert(tv.tv_sec == pipe->tv_rate.tv_sec &&
	 tv.tv_usec == pipe->tv_rate.tv_usec);
  assert(pipe->max_bit_rate > 0);

  /* save next scheduling time */
  {
    int result = gettimeofday(&pipe->tv_rate, 0);
    assert(!result);
  }
  /* write_delay is in microseconds and may exceed one second: carry the
     whole seconds explicitly so tv_usec stays within [0, 1000000). */
  pipe->tv_rate.tv_sec  += pipe->write_delay / 1000000;
  pipe->tv_rate.tv_usec += pipe->write_delay % 1000000;
  if (pipe->tv_rate.tv_usec >= 1000000) {
    pipe->tv_rate.tv_usec -= 1000000;
    ++pipe->tv_rate.tv_sec;
  }

  /* calculate bytes to be written from rate:
     bits/s * us / (8 bits/byte * 1000000 us/s) = bytes */
  tmp = pipe->max_bit_rate;
  tmp *= pipe->write_delay;
  tmp /= 8000000;
  pipe->rate_remaining = tmp;

#if 0
  fprintf(stderr, 
	  "DEBUG %s %s %d: "
	  "(bit_rate * delay / 8M = bytes) "
	  "%d * %ld / 8M = %d\n",
	  __FILE__, __PRETTY_FUNCTION__,
	  pipe->sd, 
	  pipe->max_bit_rate,
	  pipe->write_delay, 
	  pipe->rate_remaining);
#endif

  if (pipe->rate_remaining < 1) {
    /* nothing to send this period: just wait for the next one */
    nepim_global.oop_src->on_time(nepim_global.oop_src,
				  pipe->tv_rate,
				  on_tcp_rate_delay, pipe);
    return OOP_CONTINUE;
  }

  /* start to write */
  nepim_global.oop_src->on_fd(nepim_global.oop_src,
			      pipe->sd, OOP_WRITE,
			      on_tcp_rate_write, pipe);

  return OOP_CONTINUE;
}

/*
 * Timer callback fired once the negotiated test duration has elapsed.
 *
 * Prints the cumulative (NEPIM_LABEL_TOTAL) statistics of the pipe and
 * clears the cumulative counters. It sets duration_done, which makes the
 * read/write callbacks skip the broken-pipe report on a later disconnect,
 * and cancels the duration/interval timers. The I/O watchers are left
 * untouched.
 */
static void *on_tcp_duration(oop_source *src, struct timeval tv, void *user)
{
  nepim_pipe_t *p = user;

  /* the whole test is reported as a single interval of test_duration */
  nepim_dump_stat(stdout, NEPIM_LABEL_TOTAL, p->sd,
		  p->byte_total_recv, p->byte_total_sent,
		  p->test_duration,
		  p->tv_start.tv_sec, p->test_duration,
		  p->total_reads, p->total_writes);

  p->total_writes    = 0;
  p->total_reads     = 0;
  p->byte_total_sent = 0;
  p->byte_total_recv = 0;
  p->duration_done   = 1;

  tcp_pipe_cancel_timers(p->sd);
  return OOP_CONTINUE;
}

/*
 * Periodic timer callback. It prints the per-interval
 * (NEPIM_LABEL_PARTIAL) statistics of the pipe, zeroes the per-interval
 * counters and re-arms itself for the next interval through
 * schedule_stat_interval().
 */
static void *on_tcp_interval(oop_source *src, struct timeval tv, void *user)
{
  nepim_pipe_t *p = user;

  nepim_dump_stat(stdout, NEPIM_LABEL_PARTIAL, p->sd,
		  p->byte_interval_recv, p->byte_interval_sent,
		  p->stat_interval,
		  p->tv_start.tv_sec, p->test_duration,
		  p->interval_reads, p->interval_writes);

  /* start counting the next interval from scratch */
  p->interval_writes    = 0;
  p->interval_reads     = 0;
  p->byte_interval_sent = 0;
  p->byte_interval_recv = 0;

  schedule_stat_interval(p->sd);
  return OOP_CONTINUE;
}

/* Key prefixes recognised in the client greeting line, which has the form
   "hello key=value key=value ..."; they are matched by parse_greetings(). */
const char *SERVER_SEND   = "server_send=";
const char *BIT_RATE      = "bit_rate=";
const char *STAT_INTERVAL = "stat_interval=";
const char *TEST_DURATION = "test_duration=";
const char *WRITE_DELAY   = "write_delay=";

/*
 * Parse a client greeting held in [buf, past_end) and register a new pipe
 * for socket sd in the global pipe set. The greeting has the form
 *
 *   hello [server_send=N] [bit_rate=N] [stat_interval=N]
 *         [test_duration=N] [write_delay=N]
 *
 * past_end is exclusive; the range does not need to be NUL-terminated.
 * Tokens are separated by single-space delimiters (strtok_r). Unknown
 * tokens are reported and ignored. Parameters missing from the greeting
 * are passed on as -2. The exception is write_delay, which falls back to
 * nepim_global.write_delay when it is missing or negative.
 *
 * NOTE(review): stat_interval/test_duration are not validated here. A
 * non-positive stat_interval makes schedule_stat_interval() arm a timer
 * in the past, so it re-fires immediately; consider rejecting it.
 *
 * Returns 0 after the pipe has been added, -1 on a malformed greeting
 * (in which case nothing is registered).
 */
static int parse_greetings(int sd, const char *buf, const char *past_end)
{
  const size_t len = past_end - buf;
  char tmp[len + 1];		/* +1 for the terminating NUL */
  const char *SEP = " ";
  const char *tok;
  char *ptr;
  int server_send    = -2;
  long long bit_rate = -2;
  int stat_interval  = -2;
  int test_duration  = -2;
  long write_delay   = -2;
  const int SERVER_SEND_LEN   = strlen(SERVER_SEND);
  const int BIT_RATE_LEN      = strlen(BIT_RATE);
  const int STAT_INTERVAL_LEN = strlen(STAT_INTERVAL);
  const int TEST_DURATION_LEN = strlen(TEST_DURATION);
  const int WRITE_DELAY_LEN   = strlen(WRITE_DELAY);

  /* strtok_r() needs a writable, NUL-terminated string; the input range
     stops right before its terminator, so add one explicitly. */
  memcpy(tmp, buf, len);
  tmp[len] = '\0';

  tok = strtok_r(tmp, SEP, &ptr);

  if (!tok) {
    fprintf(stderr, "%d: bad greeting: missing first token\n",
	    sd);
    return -1;
  }

  if (strncmp(tok, "hello", 5)) {
    fprintf(stderr, "%d: bad greeting: missing hello prefix\n",
	    sd);
    return -1;
  }

  for (;;) {
    tok = strtok_r(0, SEP, &ptr);
    if (!tok)
      break;

    if (!strncmp(tok, SERVER_SEND, SERVER_SEND_LEN)) {
      server_send = atoi(tok + SERVER_SEND_LEN);
      continue;
    }

    if (!strncmp(tok, BIT_RATE, BIT_RATE_LEN)) {
      bit_rate = atoll(tok + BIT_RATE_LEN);
      continue;
    }

    if (!strncmp(tok, STAT_INTERVAL, STAT_INTERVAL_LEN)) {
      stat_interval = atoi(tok + STAT_INTERVAL_LEN);
      continue;
    }

    if (!strncmp(tok, TEST_DURATION, TEST_DURATION_LEN)) {
      test_duration = atoi(tok + TEST_DURATION_LEN);
      continue;
    }

    if (!strncmp(tok, WRITE_DELAY, WRITE_DELAY_LEN)) {
      /* write_delay is a long: parse it as one */
      write_delay = atol(tok + WRITE_DELAY_LEN);
      continue;
    }

    fprintf(stderr, "%d: bad parameter token\n", sd);
  }

  if (write_delay < 0)
    write_delay = nepim_global.write_delay;

  fprintf(stderr, 
	  "%d: TCP incoming: send=%d max_rate=%lld "
	  "interval=%d duration=%d delay=%ld\n", 
	  sd, server_send, bit_rate, stat_interval,
	  test_duration, write_delay);

  nepim_pipe_set_add(&pipes, sd, 
		     server_send, 
		     bit_rate, 
		     stat_interval,
		     test_duration,
		     write_delay);

  return 0;
}

/*
 * Read the newline-terminated client greeting from socket sd and hand it
 * to parse_greetings(), which registers the pipe.
 *
 * The socket is expected to be in blocking mode here: on_tcp_connect()
 * switches it to blocking before calling this function. EINTR and EAGAIN
 * are retried. Any bytes received after the newline in the same read()
 * are discarded.
 *
 * Returns 0 on success. Returns -1 if the peer closes the connection,
 * read() fails, the greeting does not fit in the 1024-byte buffer, or
 * parse_greetings() rejects it.
 */
static int read_greetings(int sd)
{
  char buf[1024];
  size_t len = 0;
  char *eos = 0;

  while (!eos) {
    char *curr = buf + len;
    size_t left = sizeof(buf) - len;
    ssize_t rd;

    /* Buffer full and still no newline: refuse instead of aborting on an
       assertion (this is input controlled by the remote peer). */
    if (left < 1) {
      fprintf(stderr,
	      "%s: client greeting exceeds %lu bytes\n",
	      __PRETTY_FUNCTION__, (unsigned long) sizeof(buf));
      return -1;
    }

    rd = read(sd, curr, left);
    if (!rd) {
      fprintf(stderr, 
	      "%s: incoming connection lost\n", 
	      __PRETTY_FUNCTION__);
      return -1;
    }
    if (rd < 0) {
      if (errno == EINTR || errno == EAGAIN)
	continue;

      fprintf(stderr,
	      "%s: could not read client greetings: %s\n",
	      __PRETTY_FUNCTION__, strerror(errno));
      return -1;
    }

    assert((size_t) rd <= left);

    len += rd;

    /* only the newly received bytes can hold the terminator */
    eos = memchr(curr, '\n', rd);
  }

  *eos = '\0';

  if (parse_greetings(sd, buf, eos)) {
    fprintf(stderr, 
	    "%s: bad client greetings: %s\n", 
	    __PRETTY_FUNCTION__, buf);
    return -1;
  }

  return 0;
}

static void schedule_stat_interval(int sd)
{
  nepim_pipe_t *pipe = nepim_pipe_set_get(&pipes, sd);
  assert(pipe);

  {
    int result = gettimeofday(&pipe->tv_interval, 0);
    assert(!result);
  }

  pipe->tv_interval.tv_sec += pipe->stat_interval;

  nepim_global.oop_src->on_time(nepim_global.oop_src,
				pipe->tv_interval,
				on_tcp_interval, pipe);
}

static void tcp_pipe_start(int sd)
{
  nepim_pipe_t *pipe = nepim_pipe_set_get(&pipes, sd);
  assert(pipe);

  nepim_global.oop_src->on_fd(nepim_global.oop_src,
			      sd, OOP_READ,
			      on_tcp_read, pipe);

  if (pipe->must_send) {
    if (pipe->max_bit_rate > 0) {
      pipe->tv_rate = OOP_TIME_NOW;
      nepim_global.oop_src->on_time(nepim_global.oop_src,
				    pipe->tv_rate,
				    on_tcp_rate_delay, pipe);
    }
    else {
      nepim_global.oop_src->on_fd(nepim_global.oop_src,
				  sd, OOP_WRITE,
				  on_tcp_write, pipe);
    }
  }

  {
    int result = gettimeofday(&pipe->tv_start, 0);
    assert(!result);
    pipe->tv_duration = pipe->tv_start;
  }
  pipe->tv_duration.tv_sec += pipe->test_duration;
  nepim_global.oop_src->on_time(nepim_global.oop_src,
				pipe->tv_duration,
				on_tcp_duration, pipe);

  schedule_stat_interval(sd);
}

static void tcp_pipe_cancel_timers(int sd)
{
  nepim_pipe_t *pipe = nepim_pipe_set_get(&pipes, sd);
  assert(pipe);

  nepim_global.oop_src->cancel_time(nepim_global.oop_src,
				    pipe->tv_duration,
				    on_tcp_duration, pipe);

  nepim_global.oop_src->cancel_time(nepim_global.oop_src,
				    pipe->tv_interval,
				    on_tcp_interval, pipe);
}

static void tcp_pipe_cancel_io(int sd)
{
  nepim_pipe_t *pipe = nepim_pipe_set_get(&pipes, sd);
  assert(pipe);

  nepim_global.oop_src->cancel_fd(nepim_global.oop_src,
				  sd, OOP_READ);

  if (pipe->must_send) {
    if (pipe->max_bit_rate > 0) {
      /* stop current writing, if any */
      nepim_global.oop_src->cancel_fd(nepim_global.oop_src,
				      sd, OOP_WRITE);
      
      /* stop periodic write scheduler */
      nepim_global.oop_src->cancel_time(nepim_global.oop_src,
					pipe->tv_rate,
					on_tcp_rate_delay, pipe);
    }
    else {
      nepim_global.oop_src->cancel_fd(nepim_global.oop_src,
				      sd, OOP_WRITE);
    }
  }
}

/*
 * Tear down the pipe on socket sd: cancel its timers and I/O watchers,
 * close the descriptor and remove the entry from the pipe set.
 */
static void tcp_pipe_kill(int sd)
{
  nepim_pipe_t *doomed = nepim_pipe_set_get(&pipes, sd);

  assert(doomed);

  /* detach every event source before the descriptor goes away */
  tcp_pipe_cancel_timers(sd);
  tcp_pipe_cancel_io(sd);

  close(sd);
  nepim_pipe_set_del(&pipes, sd);
}

/*
 * OOP_READ callback on a listening TCP socket. It accepts one connection
 * and applies the generic and TCP socket options. It then reads the
 * client greeting in blocking mode, which registers the pipe, switches
 * the socket to non-blocking mode and starts the pipe.
 *
 * Any failure is logged and the accepted socket is closed. If the
 * greeting had already been parsed, the pipe entry it created is removed
 * as well, so no stale entry is left for a later connection that reuses
 * the same descriptor number.
 *
 * Always returns OOP_CONTINUE so the listener keeps running.
 */
static void *on_tcp_connect(oop_source *src, int sd,
			    oop_event event, void *unused)
{
  struct sockaddr_storage sa;	/* large enough for IPv4 and IPv6 */
  socklen_t len = sizeof(sa);	/* accept() takes a socklen_t *, not int * */
  int conn_sd;
  int result;

  conn_sd = accept(sd, (struct sockaddr *) &sa, &len);
  if (conn_sd < 0) {
    fprintf(stderr, 
	    "%s: could not accept connection: %s\n", 
	    __PRETTY_FUNCTION__, strerror(errno));
    return OOP_CONTINUE;
  }

  result = nepim_socket_opt(conn_sd);
  if (result) {
    fprintf(stderr, 
	    "%s: could not set socket options: %d\n",
	    __PRETTY_FUNCTION__, result);
    goto fail_close;
  }

  result = nepim_socket_tcp_opt(conn_sd);
  if (result) {
    fprintf(stderr, 
	    "%s: could not set tcp options: %d\n",
	    __PRETTY_FUNCTION__, result);
    goto fail_close;
  }

  if (nepim_socket_block(conn_sd)) {
    fprintf(stderr, 
	    "%s: could not set blocking socket mode\n",
	    __PRETTY_FUNCTION__);
    goto fail_close;
  }

  if (read_greetings(conn_sd)) {
    fprintf(stderr, 
	    "%s: could not parse client greetings\n",
	    __PRETTY_FUNCTION__);
    goto fail_close;
  }

  /* From here on the pipe is registered in the pipe set. */

  if (nepim_socket_nonblock(conn_sd)) {
    fprintf(stderr, 
	    "%s: could not set non-blocking socket mode\n",
	    __PRETTY_FUNCTION__);
    goto fail_unregister;
  }

  tcp_pipe_start(conn_sd);

  return OOP_CONTINUE;

 fail_unregister:
  nepim_pipe_set_del(&pipes, conn_sd);
 fail_close:
  close(conn_sd);
  return OOP_CONTINUE;
}

/*
 * Create a TCP listener on nepim_global.port for every address returned
 * by getaddrinfo() for hostname. Each listening socket gets an OOP_READ
 * watcher (on_tcp_connect). Failures are logged. A socket that cannot be
 * created does not prevent listeners on the remaining addresses.
 */
static void spawn_tcp_listener(const char *hostname)
{
  struct addrinfo hints;
  struct addrinfo *ai_res;
  struct addrinfo *ai;
  int result;

  /* getaddrinfo() requires every hints member that is not explicitly set
     (including ai_addrlen, ai_addr, ai_canonname and ai_next) to be zero
     or NULL. */
  memset(&hints, 0, sizeof(hints));
  hints.ai_socktype = SOCK_STREAM;
  hints.ai_protocol = IPPROTO_TCP;
  hints.ai_flags    = AI_CANONNAME;
  hints.ai_family   = PF_UNSPEC;

  result = getaddrinfo(hostname, 0, &hints, &ai_res);
  if (result) {
    fprintf(stderr, "getaddrinfo(%s): %s\n",
            hostname, gai_strerror(result));
    return;
  }

  for (ai = ai_res; ai; ai = ai->ai_next) {
    int sd;

    fprintf(stderr, 
	    "TCP listener: host=%s,%d len=%d family=%d type=%d proto=%d\n",
	    hostname, nepim_global.port, 
	    (int) ai->ai_addrlen, ai->ai_family, 
	    ai->ai_socktype, ai->ai_protocol);

    sd = nepim_create_listener_socket(ai->ai_addr, ai->ai_addrlen,
				      ai->ai_family, ai->ai_socktype, 
				      ai->ai_protocol, nepim_global.port);
    if (sd < 0) {
      fprintf(stderr, 
	      "%s: could not create TCP listener socket on %s,%d: %d\n",
	      __PRETTY_FUNCTION__, hostname, nepim_global.port, sd);
      continue;			/* still try the remaining addresses */
    }

    nepim_global.oop_src->on_fd(nepim_global.oop_src,
				sd, OOP_READ,
				on_tcp_connect, 0);

    fprintf(stderr, 
	    "running TCP listener socket %d on %s,%d\n",
	    sd, hostname, nepim_global.port);
  }

  freeaddrinfo(ai_res);
}

/* UDP listener setup. Its definition is not in this file. */
void nepim_udp_listener(const char *hostname);

/*
 * Server entry point. It initialises the pipe table, starts the TCP and
 * UDP listeners on the IPv6 wildcard address (unless no_inet6 is set) and
 * on the IPv4 wildcard address, and then runs the liboop event loop.
 */
void nepim_server_run(void)
{
  /* Initialise the pipe table before any listener is registered, so the
     table is ready before anything can be added to it. */
  nepim_pipe_set_init(&pipes);

  if (!nepim_global.no_inet6)
    spawn_tcp_listener(INET6_ANY);
  spawn_tcp_listener(INET_ANY);
  if (!nepim_global.no_inet6)
    nepim_udp_listener(INET6_ANY);
  nepim_udp_listener(INET_ANY);

  /* The event loop's return value is not used here. */
  (void) oop_sys_run(nepim_global.oop_sys);
}


