/* Port redirection service for virtual servers.
   Copyright (c) 2000 Idaya Ltd.
   Contributed by Tim Sellar <tim@idaya.net>

   This file is part of the Virtual Server Administrator (FreeVSD)

   FreeVSD is free software; you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation; either version 2, or (at your option)
   any later version.

   FreeVSD is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with FreeVSD; see the file COPYING.  If not, write to
   the Free Software Foundation, Inc., 59 Temple Place - Suite 330,
   Boston, MA 02111-1307, USA.  */

#include <errno.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

#include <netdb.h>
#include <signal.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/select.h>
#include <sys/socket.h>
#include <sys/wait.h>
#include <netinet/in.h>
#include <arpa/inet.h>

#include "libvsd.h"

/* Report that the call named X failed (reason taken from errno) and
   terminate the process with exit status 1.  Always call it through the
   fatal() macro so the caller's line and file are recorded; they are only
   printed in DEBUG builds.  */
#define fatal(x) _fatal(__LINE__, __FILE__, (x))
void _fatal (int line, char *file, char *fn)
{
#ifdef DEBUG
  /* Capture errno first: the fprintf() calls below may overwrite it.  */
  int saved_errno = errno;

  fprintf (stderr, "fatal error at line %d in %s\n", line, file);
  fprintf (stderr, " '%s'\n", strerror (saved_errno));
  fprintf (stderr, " calling %s\n", fn);
#else
  (void) line;			/* location is only reported in DEBUG builds */
  (void) file;
  perror (fn);
#endif

  exit (1);
}

/* Debug logger: formats FMT and the following arguments printf-style and
   writes the result to stderr.  No newline is appended.  */
void dbgmsg (const char *fmt, ...)
{
  va_list args;

  va_start (args, fmt);
  (void) vfprintf (stderr, fmt, args);
  va_end (args);
}

/* Write all LEN bytes of BUF to FD, resuming after short writes and
   retrying when interrupted by a signal.  Returns 0 on success, -1 on
   failure.  */
static int redirect_write_all (int fd, const char *buf, size_t len)
{
  while (len > 0)
    {
      ssize_t n = write (fd, buf, len);

      if (n < 0 && errno == EINTR)
	continue;
      if (n <= 0)
	return -1;
      buf += n;
      len -= (size_t) n;
    }
  return 0;
}

/* read() that retries when interrupted by a signal.  Returns the byte
   count, 0 at end-of-file, or -1 on error.  */
static ssize_t redirect_read (int fd, char *buf, size_t len)
{
  ssize_t n;

  do
    n = read (fd, buf, len);
  while (n < 0 && errno == EINTR);

  return n;
}

/* Relay data in both directions between INSOCK and OUTSOCK: bytes read
   from INSOCK are written to OUTSOCK and vice versa.  Stops when either
   side reports end-of-file or an error, then shuts down and closes both
   descriptors.  */
void redirect (int insock, int outsock)
{
  fd_set iofds;
  fd_set c_iofds;
  int max_fd;
  int ready;
  ssize_t bytes;		/* signed: read() reports errors as -1 */
  char buf[4096];

#ifdef DEBUG
  unsigned long bytes_in = 0;
  unsigned long bytes_out = 0;
  time_t start_time = time (NULL);
  time_t end_time;
#endif

  FD_ZERO (&iofds);
  FD_SET (insock, &iofds);
  FD_SET (outsock, &iofds);

  max_fd = (insock > outsock ? insock : outsock);

  while (1)
    {
      /* select() overwrites the set it is given, so hand it a copy.  */
      c_iofds = iofds;

      ready = select (max_fd + 1, &c_iofds, NULL, NULL, NULL);
      if (ready < 0 && errno == EINTR)
	continue;		/* a signal is not a reason to drop the link */
      if (ready <= 0)
	break;

      if (FD_ISSET (insock, &c_iofds))
	{
	  if ((bytes = redirect_read (insock, buf, sizeof (buf))) <= 0)
	    break;
	  if (redirect_write_all (outsock, buf, (size_t) bytes) < 0)
	    break;
#ifdef DEBUG
	  bytes_out += (unsigned long) bytes;
#endif
	}
      if (FD_ISSET (outsock, &c_iofds))
	{
	  if ((bytes = redirect_read (outsock, buf, sizeof (buf))) <= 0)
	    break;
	  if (redirect_write_all (insock, buf, (size_t) bytes) < 0)
	    break;
#ifdef DEBUG
	  bytes_in += (unsigned long) bytes;
#endif
	}
    }

  shutdown (insock, SHUT_RD);
  shutdown (outsock, SHUT_RD);
  close (insock);
  close (outsock);

#ifdef DEBUG
  end_time = time (NULL);
  dbgmsg ("closed [insock:%d][outsock:%d][in:%luB][out:%luB][time:%lds]\n",
	  insock, outsock, bytes_in, bytes_out, (long) (end_time - start_time));
#endif
}

/* Accept one connection on SERVERSOCK and hand it to a separate worker
   process, which connects to TARGET and relays traffic with redirect().

   The server process forks an intermediate child, waits for it, closes its
   copy of the client socket and returns, ready to accept again.  The
   intermediate child forks the worker and exits immediately, so the
   server's waitpid() returns at once.  The worker never returns: it
   exit()s when the relay ends, or via fatal() on error.  */
void vsdaccept (int serversock, struct sockaddr_in *target)
{
  struct sockaddr_in client;
  socklen_t clientlen = sizeof (client);
  int targetsock;
  int clientsock;
  pid_t pid;

  clientsock = accept (serversock, (struct sockaddr *) &client, &clientlen);
  if (clientsock < 0)
    {
      /* Transient failures (a signal, or a client that went away before
         we accepted it) must not take the whole service down; the caller
         simply calls us again.  */
      if (errno == EINTR || errno == ECONNABORTED || errno == EPROTO)
	return;
      fatal ("accept()");
    }

  pid = fork ();
  if (pid < 0)
    fatal ("(server) fork()");

  if (pid > 0)
    {
      /* Server: reap exactly the child we just forked (retrying if a
         signal interrupts the wait), then drop our copy of the client.  */
      while (waitpid (pid, NULL, 0) < 0 && errno == EINTR)
	;
      close (clientsock);
      return;
    }

  /* Intermediate child: the listening socket belongs to the server only;
     don't let it leak into the long-lived worker.  */
  close (serversock);

  pid = fork ();
  if (pid < 0)
    fatal ("(child) fork()");
  if (pid > 0)
    exit (0);

  /* Worker: connect to the target and relay until either side closes.  */
  if ((targetsock = socket (AF_INET, SOCK_STREAM, 0)) < 0)
    fatal ("socket()");

  if (connect (targetsock, (struct sockaddr *) target, sizeof (*target)) < 0)
    fatal ("connect()");

  redirect (clientsock, targetsock);

  exit (0);
}

/* Listen on ADDR:SERVERPORT and forward every accepted connection to
   ADDR:TARGETPORT (the same host), using vsdaccept() for each one.

   ADDR is passed straight to gethostbyname(), so it is presumably a
   NUL-terminated host name or dotted-quad string (main passes a char
   buffer).  Ports must lie in 0..65535; otherwise the function returns -1
   with errno set to EINVAL.  If ADDR does not resolve to an IPv4 address,
   it prints a message and exits with status 1.  On success it loops
   forever accepting connections and never returns.  */
int vsdredirect (void *addr, int serverport, int targetport)
{
  const char *host = (const char *) addr;
  struct sockaddr_in server;
  struct sockaddr_in target;
  struct hostent *phostent;
  struct linger linger;
  int serversock;
  int reuseaddr = 1;		/* SOL_SOCKET options take an int, not a char */

  if (serverport < 0 || serverport > 65535
      || targetport < 0 || targetport > 65535)
    {
      errno = EINVAL;
      return -1;
    }

  /* Zero the structures so sin_zero and any padding hold no stack garbage.  */
  memset (&server, 0, sizeof (server));
  memset (&target, 0, sizeof (target));
  memset (&linger, 0, sizeof (linger));	/* l_onoff = 0: lingering off */

  server.sin_family = AF_INET;
  server.sin_port = htons ((unsigned short) serverport);

  target.sin_family = AF_INET;
  target.sin_port = htons ((unsigned short) targetport);

  /* gethostbyname() reports failures via h_errno, not errno, so perror()
     (and hence fatal()) would print an unrelated reason here.  Also make
     sure the address really fits in sin_addr before copying it.  */
  phostent = gethostbyname (host);
  if (phostent == NULL || phostent->h_addrtype != AF_INET
      || phostent->h_length != (int) sizeof (server.sin_addr))
    {
      fprintf (stderr, "gethostbyname(): cannot resolve '%s' to an IPv4 address\n",
	       host);
      exit (1);
    }

  memcpy (&server.sin_addr, phostent->h_addr_list[0], sizeof (server.sin_addr));
  memcpy (&target.sin_addr, phostent->h_addr_list[0], sizeof (target.sin_addr));

#ifdef DEBUG
  dbgmsg ("vsdredirect: %s:%d -> %s:%d\n",
	  host, serverport,
	  host, targetport);
#endif

  if ((serversock = socket (AF_INET, SOCK_STREAM, 0)) < 0)
    fatal ("socket()");

  /* Best effort, as before: failure to set these options is not fatal.  */
  (void) setsockopt (serversock, SOL_SOCKET, SO_REUSEADDR,
		     &reuseaddr, sizeof (reuseaddr));
  (void) setsockopt (serversock, SOL_SOCKET, SO_LINGER,
		     &linger, sizeof (linger));

  if (bind (serversock, (struct sockaddr *) &server, sizeof (server)) < 0)
    fatal ("bind()");

  if (listen (serversock, 10) < 0)
    fatal ("listen()");

  while (1)
    vsdaccept (serversock, &target);

  return (0);
}

/* Parse a TCP port number in 1..65535 starting at S.  The digits must be
   followed immediately by TERM.  Returns the port, or -1 if the text is
   not such a number.  */
static int parse_port_field (const char *s, char term)
{
  char *end;
  long val;

  errno = 0;
  val = strtol (s, &end, 10);
  if (end == s || *end != term || errno == ERANGE || val < 1 || val > 65535)
    return -1;
  return (int) val;
}

/* Entry point.  Expects a single argument "<addr>:<from port>:<to port>".
   The address runs up to the first ':'; the source port follows it; the
   destination port follows the last ':'.  On success it forks a child that
   runs vsdredirect(), and the parent exits 0 straight away.  Malformed
   arguments print the usage line and exit with status -1.  */
int main(int argc, char *argv[])
{
  char to_addr[256];		/* room for a full DNS host name */
  const char *spec;
  const char *first_colon;
  const char *last_colon;
  size_t addrlen;
  int to_port, from_port;

  if (argc != 2)
    goto usage;

  spec = argv[1];
  first_colon = strchr (spec, ':');
  last_colon = strrchr (spec, ':');
  /* Need two separate colons: one after the address, one between ports.  */
  if (first_colon == NULL || first_colon == last_colon)
    goto usage;

  addrlen = (size_t) (first_colon - spec);
  if (addrlen == 0 || addrlen >= sizeof (to_addr))
    goto usage;
  memcpy (to_addr, spec, addrlen);
  to_addr[addrlen] = '\0';

  from_port = parse_port_field (first_colon + 1, ':');
  to_port = parse_port_field (last_colon + 1, '\0');
  if (from_port < 0 || to_port < 0)
    goto usage;

  switch (fork ())
    {
    case -1:
      fatal ("(main) fork()");
      break;
    case 0:
      /* Child: vsdredirect() loops forever unless it fails.  */
      if (vsdredirect (to_addr, from_port, to_port) != 0)
        fatal ("vsdredirect()");
      break;
    default:
      /* Parent: return to the caller, leaving the child running.  */
      break;
    }

  exit (0);

usage:
  fprintf (stderr, "vsdredirect: usage: <ip address>:<port to redirect from>:<port to redirect to>\n");
  exit (-1);
}


