/* Input/output handling.
   Copyright (c) 1999, 2000 Idaya Ltd.
   Contributed by Nick Burrett <nick@dsvr.net>

   This file is part of the Virtual Server Administrator (FreeVSD)

   FreeVSD is free software; you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation; either version 2, or (at your option)
   any later version.

   FreeVSD is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with FreeVSD; see the file COPYING.  If not, write to
   the Free Software Foundation, Inc., 59 Temple Place - Suite 330,
   Boston, MA 02111-1307, USA.  */

/* Protocol input format:
     virtual-server {
       command [args]
       command {
     <data>
       }
     }
     virtual-server {
       command [args]
       command [args]
     }
     EOF

   Output format:
     virtual-server {
       command {
         result-code/result-data
       }
       command {
         result-code/result-data
       }
     }
     EOF

   Please see freevsd/doc/freevsd-protocol.txt for a fuller description
   of the protocol.  */

#include "config.h"
#include <ctype.h>
#include <stdio.h>
#include <errno.h>
#include <time.h>
#include <stdarg.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include "io.h"
#include "vsd.h"
#include "libvsd.h"

/* This source file presents two input/output methods, selected at build
   time by WITH_OPENSSL.  One is SSL-based (encrypted, with client
   certificate verification), and the other is a simple line-oriented
   text interface with no authentication.  The former is recommended for
   production servers, whereas the latter is useful if you are debugging
   VSD.  */

#ifdef WITH_OPENSSL

#if 0
/* Temporary RSA key callback for export-grade ciphers.  Disabled: its
   only registration (SSL_CTX_set_tmp_rsa_callback in io_initialise) is
   commented out.  NOTE(review): RSA_generate_key is deprecated and the
   temporary-RSA callback API was removed in OpenSSL 1.1.0, so this needs
   rewriting before it could be re-enabled.  */
static RSA *rsa_cb (SSL *ssl, int is_export, int key_length)
{
  RSA *tmp_rsa;

  vsd_log_error (VSD_LOG_DEBUG, "in rsa_cb");

  tmp_rsa = RSA_generate_key (key_length, RSA_F4, NULL, NULL);
  if (! tmp_rsa)
    {
      /* The format takes exactly two arguments; the stray trailing 0
         that used to follow them has been removed.  */
      vsd_log_error (VSD_LOG_ERROR,
                     "Failed to generate temporary %d-bit %s RSA key",
                     key_length, is_export ? "export" : "domestic");
      return NULL;
    }
  return tmp_rsa;
}
#endif

/* OpenSSL peer-verification callback, installed by io_initialise through
   SSL_CTX_set_verify.  It logs the certificate at the current chain depth
   and, when verification has failed, the reason reported by OpenSSL.
   OK is returned unchanged, so OpenSSL's own verdict decides whether the
   handshake continues.  */
static int verify_callback (int ok, X509_STORE_CTX *ctx)
{
  X509 *cert = X509_STORE_CTX_get_current_cert (ctx);
  int errnum = X509_STORE_CTX_get_error (ctx);
  int errdepth = X509_STORE_CTX_get_error_depth (ctx);
  char *sname = NULL, *iname = NULL;

  /* There may be no current certificate for some errors; do not
     dereference it in that case.  */
  if (cert != NULL)
    {
      sname = X509_NAME_oneline (X509_get_subject_name (cert), NULL, 0);
      iname = X509_NAME_oneline (X509_get_issuer_name (cert), NULL, 0);
    }

  vsd_log_error (VSD_LOG_DEBUG,
                 "SSL certificate verification: ok: %d, depth: %d, subject: %s, issuer: %s",
                 ok, errdepth,
                 sname ? sname : "-unknown-",
                 iname ? iname : "-unknown-");
  if (! ok)
    vsd_log_error (VSD_LOG_ERROR,
                   "SSL certificate verification failed at depth %d: %s",
                   errdepth, X509_verify_cert_error_string (errnum));

  /* Strings from X509_NAME_oneline (..., NULL, 0) are allocated by
     OpenSSL and must be released with OPENSSL_free; CRYPTO_free takes
     extra file/line arguments in OpenSSL 1.1 and later.  */
  if (sname)
    OPENSSL_free (sname);
  if (iname)
    OPENSSL_free (iname);

  return ok;
}

/* Copy the value of the CN (common name) component of SUBJECT into CN.
   SUBJECT uses the "/C=../O=../CN=.." one-line form produced by
   X509_NAME_oneline.  The value runs up to the next '/' or the end of the
   string.  CN has room for CNSIZE bytes including the terminating NUL.
   Return 0 on success, 1 if SUBJECT has no "/CN=" component, or 2 if the
   value does not fit in CN.  CN is left empty on failure.  */
static int extract_cn (const char *subject, char *cn, size_t cnsize)
{
  const char *s;
  size_t x = 0;

  if (cnsize == 0)
    return 2;
  cn[0] = '\0';

  s = strstr (subject, "/CN=");
  if (s == NULL)
    return 1;

  for (s += 4; *s != '\0' && *s != '/'; s++)
    {
      /* Always keep one byte free for the terminating NUL.  */
      if (x + 1 >= cnsize)
        {
          cn[0] = '\0';
          return 2;
        }
      cn[x++] = *s;
    }
  cn[x] = '\0';
  return 0;
}

/* Set up the SSL transport for connection VC.

   The steps are:
   - Create an SSL context from the key and certificate under
     SSL_BASE "/server".
   - Require and verify a client certificate.
   - Perform the server side of the handshake on io->in_fd.

   If the client presents a certificate, the CN of its subject is stored,
   strdup'ed, in io->certificate_vs.  Before any of this, an alarm of
   vc->command_timeout seconds is armed so a stalled handshake cannot hang
   the process; the alarm is still pending on return.

   Return 0 on success, 1 on failure (the reason is logged).  */
int io_initialise (struct connection *vc)
{
  int err, rc;
  X509 *client_cert;
  char buf[256], subject[256], cn[128];
  struct io *io = (struct io *) vc->io;

  /* Setting this alarm will stop our program hanging indefinitely, waiting
     for an SSL session to be established.  */
  alarm (vc->command_timeout);

  SSL_load_error_strings ();
  SSLeay_add_ssl_algorithms ();

  /* Create a context.  */
  io->ctx = SSL_CTX_new (SSLv23_method ());

  /* In order to create new session contexts, we need the certificate
     and RSA key.  */
  if (io->ctx == NULL
      || SSL_CTX_use_PrivateKey_file (io->ctx, SSL_BASE "/server/host.key",
                                      SSL_FILETYPE_PEM) <= 0
      || SSL_CTX_use_certificate_file (io->ctx, SSL_BASE "/server/host.crt",
                                       SSL_FILETYPE_PEM) <= 0)
    {
      vsd_log_error (VSD_LOG_ERROR, "%s",
                     ERR_error_string (ERR_get_error (), buf));
      return 1;
    }

  /* Check the private key is valid.  */
  if (! SSL_CTX_check_private_key (io->ctx))
    {
      vsd_log_error (VSD_LOG_ERROR, "Private key does not match the certificate public key");
      return 1;
    }

  /* Failing to load the verification locations is logged, not fatal.  */
  if (! SSL_CTX_load_verify_locations (io->ctx, SSL_BASE "/ca/cacert.pem", SSL_BASE "/ca")
      || !SSL_CTX_set_default_verify_paths (io->ctx))
    {
      vsd_log_error (VSD_LOG_ERROR, "load_verify_locations: %s",
                     ERR_error_string (ERR_get_error (), buf));
    }

  /* Add a Certificate Authority.  */
  SSL_CTX_set_client_CA_list (io->ctx,
                              SSL_load_client_CA_file (SSL_BASE "/server/ca-cert.pem"));

  /* Demand a client certificate and log each verification step.  */
  SSL_CTX_set_verify (io->ctx, SSL_VERIFY_PEER
                      | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, verify_callback);

  /*SSL_CTX_set_tmp_rsa_callback (io->ctx, rsa_cb);*/

  /* Create a new SSL structure.  */
  io->ssl = SSL_new (io->ctx);
  if (io->ssl == NULL)
    {
      vsd_log_error (VSD_LOG_ERROR, "cannot initialise SSL");
      return 1;
    }

  /* Associate the input file descriptor with the SSL structure.  */
  if (! SSL_set_fd (io->ssl, io->in_fd))
    {
      vsd_log_error (VSD_LOG_ERROR, "%s",
                     ERR_error_string (ERR_get_error (), buf));
      return 1;
    }

  /* Do SSL using our certificate and RSA key.  SSL_accept returns 1 on
     success; 0 means the handshake failed but was shut down cleanly, and
     a negative value is a fatal error.  Both are failures.  */
  err = SSL_accept (io->ssl);
  if (err <= 0)
    {
      vsd_log_error (VSD_LOG_ERROR, "%s",
                     ERR_error_string (ERR_get_error (), buf));
      return 1;
    }
  vsd_log_error (VSD_LOG_DEBUG, "SSL connection using %s",
                 SSL_get_cipher (io->ssl));

  client_cert = SSL_get_peer_certificate (io->ssl);
  if (client_cert == NULL)
    {
      vsd_log_error (VSD_LOG_DEBUG, "client does not have a certificate");
      return 0;
    }

  vsd_log_error (VSD_LOG_DEBUG, "client certificate:");
  if (X509_NAME_oneline (X509_get_subject_name (client_cert),
                         subject, sizeof (subject)) == NULL)
    {
      /* SUBJECT is not valid to read if X509_NAME_oneline failed.  */
      vsd_log_error (VSD_LOG_ERROR, "  cannot read certificate subject");
      X509_free (client_cert);
      return 1;
    }
  vsd_log_error (VSD_LOG_DEBUG, "  subject: %s", subject);

  /* Extract the CN field from the certificate subject.  A certificate
     without a usable CN is rejected.  */
  rc = extract_cn (subject, cn, sizeof (cn));
  if (rc == 1)
    {
      vsd_log_error (VSD_LOG_ERROR,
                     "  expected a CN field in certificate subject: %s", subject);
      X509_free (client_cert);
      return 1;
    }
  if (rc != 0)
    {
      vsd_log_error (VSD_LOG_ERROR,
                     "  unexpected certificate subject format or buffer overflow: %s",
                     subject);
      X509_free (client_cert);
      return 1;
    }

  vsd_log_error (VSD_LOG_DEBUG, "  found cn: %s", cn);
  io->certificate_vs = strdup (cn);
  if (io->certificate_vs == NULL)
    {
      vsd_log_error (VSD_LOG_ERROR, "  cannot allocate memory for certificate CN");
      X509_free (client_cert);
      return 1;
    }

  if (X509_NAME_oneline (X509_get_issuer_name (client_cert),
                         buf, sizeof (buf)) != NULL)
    vsd_log_error (VSD_LOG_DEBUG, "  issuer: %s", buf);

  X509_free (client_cert);
  return 0;
}

int io_finalise (struct connection *vc)
{
  struct io *io = (struct io *) vc->io;

  SSL_free (io->ssl);
  SSL_CTX_free (io->ctx);
  return 0;
}

/* Read the next line of input from the SSL connection into
   io->input_buffer.  The buffer is allocated on first use and grown as
   needed.  The trailing "\n" or "\r\n" is replaced by a NUL terminator.

   STATE is currently unused.  An alarm of vc->command_timeout seconds is
   armed for the duration of the read and cancelled before returning.

   Return 0 when a complete line has been read.  Return 1 on an
   unrecoverable read error or if memory cannot be obtained.  */
static int input (struct connection *vc, int state)
{
  int pos, ret;
  struct io *io = (struct io *) vc->io;

  alarm (vc->command_timeout);

  if (io->input_buffer == NULL)
    {
      io->input_bufsiz = 256;
      io->input_buffer = (char *) malloc (io->input_bufsiz);
      if (io->input_buffer == NULL)
        {
          io->input_bufsiz = 0;
          alarm (0);
          return 1;
        }
    }

  pos = 0;
  while (1)
    {
      if (io->input_bufsiz <= pos)
        {
          /* Need more space.  On failure the old buffer stays valid (and
             owned by IO) rather than being leaked.  */
          char *bigger = (char *) realloc (io->input_buffer,
                                           io->input_bufsiz + 256);
          if (bigger == NULL)
            break;
          io->input_buffer = bigger;
          io->input_bufsiz += 256;
        }

      /* Read the stream one byte at a time, as the original author did to
         side-step a buffering problem.  */
      while ((ret = SSL_read (io->ssl, io->input_buffer + pos, 1)) == -1
             && errno == EINTR)
        ;

      if (ret < 0)
        /* Unrecoverable error.  This used to decrement POS and index
           before the start of the buffer.  */
        break;

      if (ret == 0)
        {
          /* No byte was stored, so do not inspect the buffer.  Sleep for
             a short period to prevent a CPU resource drain, then retry.
             NOTE(review): only the alarm armed above ends this wait;
             this assumes its SIGALRM handler terminates the session.  */
          struct timespec period;
          period.tv_sec = 0;
          period.tv_nsec = 200000000;
          nanosleep (&period, NULL);
          continue;
        }

      if (io->input_buffer[pos] == '\n')
        {
          /* Strip "\n" or "\r\n"; guard against a line that is just
             "\n", where there is no preceding byte to check.  */
          if (pos > 0 && io->input_buffer[pos - 1] == '\r')
            pos--;
          io->input_buffer[pos] = '\0';
          alarm (0);
          return 0;
        }
      pos++;
    }

  alarm (0);
  return 1;
}

/* Format a message printf-style and send it over the SSL connection of
   VC.  When VC is NULL or SSL has not been set up (for instance after a
   timeout during SSL negotiation), the message goes to stderr instead.
   When vsd_log_level >= 4 the message is also logged as protocol output.
   For encryption, it is better to collect all data together, encrypt the
   lot and then send, so the whole message is formatted into one buffer
   first.  The message is silently dropped if formatting or allocation
   fails.  */
void output (struct connection *vc, const char *fmt, ...)
{
  va_list ap;
  char *s;
  int len;
  struct io *io = NULL;

  if (vc != NULL)
    io = (struct io *) vc->io;

  /* Measure first.  A va_list is indeterminate after vsnprintf has used
     it, so the argument list is restarted for the second pass.  */
  va_start (ap, fmt);
  len = vsnprintf (NULL, 0, fmt, ap);
  va_end (ap);
  if (len < 0)
    return;

  s = (char *) malloc (len + 1);
  if (s == NULL)
    return;

  va_start (ap, fmt);
  vsnprintf (s, len + 1, fmt, ap);
  va_end (ap);

  if (vsd_log_level >= 4)
    vsd_log_error (VSD_LOG_PROTOCOL, "out: %s", s);

  if (io && io->ssl)
    {
      /* SSL_write with a length of 0 is not well-defined; skip it.  */
      if (len > 0)
        SSL_write (io->ssl, s, len);
    }
  else
    {
      /* SSL might not be setup, if we timed-out during SSL negotiation.
         Loop so that partial writes and EINTR are both handled.  */
      const char *p = s;
      int left = len;

      while (left > 0)
        {
          ssize_t n = write (fileno (stderr), p, left);
          if (n == -1)
            {
              if (errno == EINTR)
                continue;
              break;
            }
          p += n;
          left -= (int) n;
        }
    }

  free (s);
}

#else /* ! WITH_OPENSSL */
int io_initialise (struct connection *vc)
{
  /* Plain-text transport: there is no session to negotiate, so set-up
     cannot fail and VC is not needed.  */
  (void) vc;
  return 0;
}

int io_finalise (struct connection *vc)
{
  /* Plain-text transport: nothing was acquired by io_initialise, so
     there is nothing to release.  */
  (void) vc;
  return 0;
}

/* Read the next line of input from io->in_fd into io->input_buffer.
   The buffer is allocated on first use and grown as needed.  The trailing
   "\n" or "\r\n" is replaced by a NUL terminator.

   STATE is currently unused.  An alarm of vc->command_timeout seconds is
   armed for the duration of the read and cancelled before returning.

   Return 0 when a complete line has been read.  Return 1 on an
   unrecoverable read error or if memory cannot be obtained.  */
static int input (struct connection *vc, int state)
{
  int pos, ret;
  struct io *io = (struct io *) vc->io;

  alarm (vc->command_timeout);

  if (io->input_buffer == NULL)
    {
      io->input_bufsiz = 256;
      io->input_buffer = (char *) malloc (io->input_bufsiz);
      if (io->input_buffer == NULL)
        {
          io->input_bufsiz = 0;
          alarm (0);
          return 1;
        }
    }

  pos = 0;
  while (1)
    {
      if (io->input_bufsiz <= pos)
        {
          /* Need more space.  On failure the old buffer stays valid (and
             owned by IO) rather than being leaked.  */
          char *bigger = (char *) realloc (io->input_buffer,
                                           io->input_bufsiz + 256);
          if (bigger == NULL)
            break;
          io->input_buffer = bigger;
          io->input_bufsiz += 256;
        }

      /* Read one byte at a time so nothing past the end of this line is
         consumed from the descriptor.  */
      while ((ret = read (io->in_fd, io->input_buffer + pos, 1)) == -1
             && errno == EINTR)
        ;

      if (ret < 0)
        /* Unrecoverable error.  This used to decrement POS and index
           before the start of the buffer.  */
        break;

      if (ret == 0)
        {
          /* No byte was stored, so do not inspect the buffer.  Sleep for
             a short period to prevent a CPU resource drain, then retry.
             NOTE(review): only the alarm armed above ends this wait;
             this assumes its SIGALRM handler terminates the session.  */
          struct timespec period;
          period.tv_sec = 0;
          period.tv_nsec = 200000000;
          nanosleep (&period, NULL);
          continue;
        }

      if (io->input_buffer[pos] == '\n')
        {
          /* Strip "\n" or "\r\n"; guard against a line that is just
             "\n", where there is no preceding byte to check.  */
          if (pos > 0 && io->input_buffer[pos - 1] == '\r')
            pos--;
          io->input_buffer[pos] = '\0';
          alarm (0);
          return 0;
        }
      pos++;
    }

  alarm (0);
  return 1;
}

/* Format a message printf-style and write it, unbuffered, to the
   connection's output descriptor, or to stderr when VC is NULL.  When
   vsd_log_level >= 4 the message is also logged as protocol output.
   The message is silently dropped if formatting or allocation fails.  */
void output (struct connection *vc, const char *fmt, ...)
{
  va_list ap;
  char *s;
  const char *p;
  int len, fd;

  if (vc != NULL)
    fd = ((struct io *) vc->io)->out_fd;
  else
    fd = fileno (stderr);

  /* Measure first.  A va_list is indeterminate after vsnprintf has used
     it, so the argument list is restarted for the second pass.  */
  va_start (ap, fmt);
  len = vsnprintf (NULL, 0, fmt, ap);
  va_end (ap);
  if (len < 0)
    return;

  s = (char *) malloc (len + 1);
  if (s == NULL)
    return;

  va_start (ap, fmt);
  vsnprintf (s, len + 1, fmt, ap);
  va_end (ap);

  if (vsd_log_level >= 4)
    /* Log output (result) protocol to syslog.  */
    vsd_log_error (VSD_LOG_PROTOCOL, "out: %s", s);

  /* write may be interrupted or may accept only part of the data.  */
  p = s;
  while (len > 0)
    {
      ssize_t n = write (fd, p, len);
      if (n == -1)
        {
          if (errno == EINTR)
            continue;
          break;
        }
      p += n;
      len -= (int) n;
    }
  free (s);
}
#endif /* ! WITH_OPENSSL */

/* Resize-or-create allocator.  With PTR == NULL, return a fresh block of
   BYTES zero-filled bytes.  Otherwise resize PTR to BYTES bytes with
   realloc; any newly added tail is NOT zeroed.  Returns NULL on failure,
   exactly as calloc/realloc do.  */
static void *xalloc (void *ptr, size_t bytes)
{
  return (ptr != NULL) ? realloc (ptr, bytes)
                       : calloc (bytes, sizeof (char));
}

/* Parse incoming data, storing it in an `io_requests' structure.  */
struct io_requests *io_parse_input (struct connection *vc)
{
  struct io_requests *req = NULL;
  int level, req_num = 0, cmd_num;
  struct io *io = (struct io *) vc->io;

  /* Nesting level. At level 0 we expect a virtual-server name. Increment/
     decrement when we come across a `{'/`}' respectively. At level 1
     we expect commands.  */
  level = 0;

  /* Process data one line at a time, the terminator is the string "EOF".  */
  while (! input (vc, STATE_TRANSACTION))
    {
      char *in = io->input_buffer;
      char **vec = NULL;
      size_t vecc = 0;

      if (vsd_log_level >= 4)
	/* Log input protocol to syslog.  */
	vsd_log_error (VSD_LOG_PROTOCOL, "in: %s", in);

      /* Nesting level 2 is where we accept raw data. So it isn't
	 a good idea to split the line up into vectors.  */
      if (level != 2)
	{
	  /* Skip blanks.  */
	  while (*in && isspace (*in))
	    in++;
	  if (!*in)
	    continue; /* Blank line.  */
	  /* EOF terminates an input stream.  */
	  if (strcmp (in, "EOF") == 0)
	    break;

	  /* Split space-seperated arguments into an argument vector.  */
	  vec = vsd_argv_parse (' ', in, &vecc);
	  if (vec == NULL)
	    continue;
	}

      if (level == 0)
	{
	  /* Expect `virtual-server {'.  */
	  req = (struct io_requests *)
	    xalloc (req, (req_num + 1) * sizeof (struct io_requests));
	  memset (&req[req_num], 0, sizeof (struct io_requests));
	  req[req_num].vs = strdup (vec[0]);
	  cmd_num = 0;
	  level ++;
	  /* vsd_argv_parse allocates memory for the argument vector.  */
	  free (vec);
	}
      else if (level == 1)
	{
	  /* Expect either a command or a '}'.  */
	  if (*in == '}')
	    {
	      /* And so ends this block. Increment request number.  */
	      req[req_num++].actionc = cmd_num;
	      level --;
	    }
	  else
	    {
	      int x;

	      /* Found a command. Create space for it and make a copy of
		 the command and its arguments.  */
	      req[req_num].actionv = (struct req_command *)
		xalloc (req[req_num].actionv,
			(cmd_num + 1) * sizeof (struct req_command));

	      req[req_num].actionv[cmd_num].argv = (char **)
		malloc (sizeof (char *) * (vecc + 1));
	      if (*vec[vecc - 1] == '{')
		{
		  /* Command has extra data.  */
		  level ++;
		  /* Avoid adding curly bracket to the command line.  */
		  vecc --;
		}
	      for (x = 0; x < vecc; x++)
		req[req_num].actionv[cmd_num].argv[x] = strdup (vec[x]);

	      /* req[req_num].actionv[cmd_num].argv = vec; */
	      req[req_num].actionv[cmd_num].argc = vecc; /* Arg count.  */
	      /* Don't change command numbers if we have just gone up a
		 level as we want to associate the extra data with this
		 command.  */
	      if (level == 1)
		cmd_num ++; 
	    }
	  /* vsd_argv_parse allocates memory for the argument vector.  */
	  free (vec);
	}
      else if (level == 2)
	{
	  if (*in == '}')
	    {
	      /* End of data segment.  */
	      req[req_num].actionv[cmd_num].data = strdup (io->io_buffer);
	      io_buffer_free (vc);
	      level --;
	      cmd_num ++;
	    }
	  else
	    io_buffer_store (vc, "%s\n", in);
	}
    }
  
  io->requests = req;
  io->requests_count = req_num;
  return req;
}

/* Send the collected results (io->results) to the client, followed by the
   "EOF" terminator.  io->io_output_format selects the layout:
   - 1 gives one flat "<vs> <command> <data>" line per command.
   - Any other value gives nested "vs { command { data } }" blocks.
   Commands with no result data (NULL) are written with empty data.  */
void io_generate_output (struct connection *vc)
{
  struct io *io = (struct io *) vc->io;
  struct io_results *res = io->results;
  int i, j;

  switch (io->io_output_format)
    {
    case 1:
      /* Output format:
         <virtual server> <cmd name> <result data>
         <virtual server> <cmd name> <result data>

         XXX This is unused.
      */
      for (i = 0; i < io->results_count; i++)
        for (j = 0; j < res[i].cmdc; j++)
          {
            /* Passing NULL to %s is undefined behaviour; data may be NULL
               when a command produced no output.  */
            const char *data = res[i].cmdv[j].data;

            output (vc, "%s %s %s\n",
                    res[i].vs, res[i].cmdv[j].cmd, data ? data : "");
          }
      break;
    default:
      for (i = 0; i < io->results_count; i++)
        {
          output (vc, "%s {\n", res[i].vs);
          for (j = 0; j < res[i].cmdc; j++)
            {
              output (vc, "  %s {\n", res[i].cmdv[j].cmd);
              output (vc, "%s\n",
                      (res[i].cmdv[j].data) ? res[i].cmdv[j].data : "");
              output (vc, "  }\n");
            }
          output (vc, "}\n");
        }
      break;
    }

  output (vc, "EOF\n");
}

/* Store the data collected by `io_buffer_store' into the results
   structure.  */
void io_collect_output (struct connection *vc, const char *command)
{
  struct io *io = (struct io *) vc->io;
  struct io_results *res = io->results;
  int i;

  /* Try and find a matching virtual server in the results collected
     so far. If we do, then add the command and its data to the `command'
     struct.  */
  for (i = 0; i < io->results_count; i++)
    if (strcmp (vc->virtual_server, res[i].vs) == 0)
      break;

  if (i == io->results_count)
    {
      /* No match found.  Create a new slot.  */
      io->results_count ++;
      res = (struct io_results *) xalloc (res, io->results_count
					  * sizeof (struct io_results));
      memset (&res[i], 0, sizeof (struct io_results));
      res[i].vs = strdup (vc->virtual_server);
    }
  /* Found a matching VS. Add a new command.  */
  res[i].cmdv = (struct res_command *) xalloc (res[i].cmdv, (res[i].cmdc + 1)
					   * sizeof (struct res_command));
  res[i].cmdv[res[i].cmdc].cmd = strdup (command);
  res[i].cmdv[res[i].cmdc].data = ((io->io_buffer)
				   ? strdup (io->io_buffer)
				   : NULL);
  res[i].cmdc ++;
  io->results = res;
  io_buffer_free (vc);
}

/* Append text, formatted printf-style from FMT and its arguments, to an
   extendable buffer (io->io_buffer).  The buffer is allocated on first use
   and grown as needed.  io->io_bufferp is left pointing at the
   terminating NUL, ready for the next append.

   errno is restored to its entry value before each formatting pass and
   on return.  The original comments say this is for correct printf
   output, e.g. glibc's %m.

   Aborts if memory cannot be obtained.  If formatting fails, nothing is
   appended.  */
void io_buffer_store (struct connection *vc, const char *fmt, ...)
{
  va_list ap;
  int saved_errno = errno, nchars, freespace, used;
  struct io *io = (struct io *) vc->io;

  if (! io->io_buffer_size)
    {
      /* Size of the buffer.  */
      io->io_buffer_size = 1024;
      /* Pointer to the start of the buffer.  */
      io->io_buffer = (char *) malloc (io->io_buffer_size);
      if (! io->io_buffer) /* Memory failure.  */
        abort ();
      /* Pointer to the end of the last bit of text written out.  */
      io->io_bufferp = io->io_buffer;
      *io->io_bufferp = '\0';
    }

  /* Calculate actual space available in buffer for writing a string.  */
  used = io->io_bufferp - io->io_buffer;
  freespace = io->io_buffer_size - used;

  /* First attempt: this succeeds outright when the text fits.  */
  errno = saved_errno;
  va_start (ap, fmt);
  nchars = vsnprintf (io->io_bufferp, freespace, fmt, ap);
  va_end (ap);

  if (nchars < 0)
    {
      /* Formatting error: keep the buffer terminated at its old end.  */
      *io->io_bufferp = '\0';
      errno = saved_errno;
      return;
    }

  if (nchars >= freespace)
    {
      /* Truncated.  Grow at least geometrically, so that many small
         appends do not each trigger a realloc and copy.  */
      int newsize = io->io_buffer_size * 2;
      char *grown;

      if (newsize < used + nchars + 1)
        newsize = used + nchars + 1;
      grown = (char *) realloc (io->io_buffer, newsize);
      if (! grown) /* Memory failure.  */
        abort ();
      io->io_buffer = grown;
      io->io_bufferp = grown + used;
      io->io_buffer_size = newsize;

      /* The first va_list has been consumed; start a fresh one, and bound
         the write by the space remaining, not the total size.  */
      errno = saved_errno;
      va_start (ap, fmt);
      vsnprintf (io->io_bufferp, newsize - used, fmt, ap);
      va_end (ap);
    }

  io->io_bufferp += nchars;
  errno = saved_errno;
}

/* Release both per-connection scratch buffers:
   - the text accumulated by io_buffer_store, with its size and write
     pointer reset;
   - the line buffer, io->input_buffer.
   Each buffer is released only if currently allocated, so repeated calls
   are harmless.  */
void io_buffer_free (struct connection *vc)
{
  struct io *io = (struct io *) vc->io;

  if (io->io_buffer_size != 0)
    {
      free (io->io_buffer);
      io->io_buffer_size = 0;
      io->io_bufferp = NULL;
      io->io_buffer = NULL;
    }

  if (io->input_buffer == NULL)
    return;

  free (io->input_buffer);
  io->input_buffer = NULL;
}
