/* Start services on a virtual server.
   Copyright (c) 1999, 2000 Idaya Ltd.
   Contributed by Nick Burrett <nick@dsvr.net>

   This file is part of the Virtual Server Administrator (FreeVSD)

   FreeVSD is free software; you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation; either version 2, or (at your option)
   any later version.

   FreeVSD is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with FreeVSD; see the file COPYING.  If not, write to
   the Free Software Foundation, Inc., 59 Temple Place - Suite 330,
   Boston, MA 02111-1307, USA.  */

#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <signal.h>
#include <errno.h>
#include <limits.h>
#include <getopt.h>
#include <unistd.h>
#include <sys/wait.h>
#include <sys/types.h>
#include <sys/ioctl.h>
#include <sys/socket.h>
#include <net/if.h>
#include <netinet/in.h>
#include <arpa/inet.h>

#include "libvsd.h"
#include "config.h"

/* fatal (x) expands to a call of _fatal with the source line and file
   name of the call site filled in.  _fatal is only declared here; its
   definition is not visible in this file.  The third argument is
   presumably the name of the failing function (the parameter is called
   `fn') -- confirm against the definition.  */
#define fatal(x) _fatal(__LINE__, __FILE__, x)
void _fatal (int line, char *file, char *fn);

/* Codes that getopt_long_only returns for the command options.  They are
   used as the `val' field of the long option table in main and as case
   labels there, and lie outside the range of the single-letter option
   codes ('d', 'f', 'h', 'p', 'q').  */
#define OPTION_START 150
#define OPTION_STOP 151
#define OPTION_RESTART 152
#define OPTION_BIND 153
#define OPTION_UNBIND 154

/* Network interface that we wish to bind virtual servers to.  Aliases
   are addressed as NETDEVICE ":" <number>.  */
#define NETDEVICE "eth0"

/* Linux limits us to 255 aliases.  Used as the exclusive upper bound of
   the alias numbers probed by find_free_interface and close_interface.  */
#define MAX_ALIASES 255

/* skfd: socket descriptor used for the interface ioctls in
   find_free_interface; main opens it before dispatching the action and
   closes it afterwards.
   quiet: when non-zero, log () prints nothing (set by -q / --quiet).  */
static int skfd = 0, quiet = 0;
/* Process environment; handed unchanged to execve in command ().  */
extern char **environ;

/* Write a printf-style progress message to stdout.  Nothing is printed
   when the file-scope `quiet' flag is non-zero.  */
static void log (const char *fmt, ...)
{
  va_list args;

  va_start (args, fmt);
  if (! quiet)
    vprintf (fmt, args);
  va_end (args);
}

/* Write a printf-style diagnostic to stderr.  Unlike log (), this is
   not affected by the `quiet' flag.

   Always returns 1, which lets a caller report a failure and hand back a
   non-zero status in a single `return log_error (...)' statement.  */
static int log_error (const char *fmt, ...)
{
  const int failure = 1;
  va_list args;

  va_start (args, fmt);
  (void) vfprintf (stderr, fmt, args);
  va_end (args);

  return failure;
}

/* Return the number of the first alias of `device' that is not in use,
   or -1 if none could be found.

   Aliases 1 .. MAX_ALIASES - 1 are probed in ascending order by asking
   the kernel, through the file-scope socket `skfd', for the address of
   `<device>:<n>'.  The first alias for which the SIOCGIFADDR request
   fails is reported as free.  Note that every ioctl failure is treated
   that way, not only "no such interface".

   -1 is also returned when `<device>:<n>' does not fit in an interface
   name, instead of writing past the end of the request structure.  */
static int find_free_interface (const char *device)
{
  struct ifreq ifr;
  int alias;

  for (alias = 1; alias < MAX_ALIASES; alias++)
    {
      int len;

      /* Start every probe from a zeroed request so the kernel never
         sees stack garbage in the address part.  */
      memset (&ifr, 0, sizeof ifr);

      /* ifr_name is a small fixed-size array: bound the write and give
         up rather than probe a truncated name.  */
      len = snprintf (ifr.ifr_name, sizeof ifr.ifr_name, "%s:%d",
                      device, alias);
      if (len < 0 || (size_t) len >= sizeof ifr.ifr_name)
        return -1;

      if (ioctl (skfd, SIOCGIFADDR, &ifr) == -1)
        return alias;
    }

  return -1;
}

/* Bind the IP address `ip' to alias `number' of the network interface
   `device', then add a host route for that address via the new alias.

   `ip' is passed to vsd_ip_ntoa to obtain the textual address.  The
   work is done by running /sbin/ifconfig and /sbin/route through
   system ().  Both commands are always attempted, as before.

   Returns 0 if both commands reported success, and -1 if `number' is
   negative, a command line does not fit in the buffer, or either
   command failed.  */
static int bind_interface (const char *device, int number, void *ip)
{
  char cmd[256];
  char *ipaddr;
  int len, result = 0;

  /* Callers pass the result of find_free_interface straight through,
     and that is -1 when no alias is free.  `<device>:-1' is not a
     usable alias, so refuse it here instead of running ifconfig.  */
  if (number < 0)
    {
      fprintf (stderr, "no free alias available on %s\n", device);
      return -1;
    }

  ipaddr = vsd_ip_ntoa (ip);

  /* FIXME: Need to convert into ioctls.  */
  len = snprintf (cmd, sizeof cmd,
                  "/sbin/ifconfig %s:%d %s netmask 255.255.255.255 up",
                  device, number, ipaddr);
  if (len < 0 || (size_t) len >= sizeof cmd)
    return -1;
  if (system (cmd) != 0)
    result = -1;
#ifdef DEBUG
  printf ("dev (%s), num (%d), ip (%s)\n", device, number, ipaddr);
#endif
  len = snprintf (cmd, sizeof cmd, "/sbin/route add %s gw %s dev %s:%d",
                  ipaddr, ipaddr, device, number);
  if (len < 0 || (size_t) len >= sizeof cmd)
    return -1;
  if (system (cmd) != 0)
    result = -1;

  return result;
}

/* Search the aliases of the network interface `device' for the address
   `ip' and shut the matching alias down.

   Aliases 0 .. MAX_ALIASES - 1 are queried with SIOCGIFADDR on a
   private datagram socket.  The first alias whose address equals `ip'
   is brought down by running /sbin/ifconfig through system (); the
   socket is closed before that command is started.

   Returns 0 if a matching alias was found and ifconfig reported
   success, -1 otherwise (including when the socket cannot be
   created).  */
static int close_interface (const char *device, void *ip)
{
  char cmd[256], ipaddr[32];
  struct ifreq ifr;
  int alias, sock, len, match = -1;

  /* Keep a private copy of the textual address: the loop below calls
     inet_ntoa, which the original author notes would otherwise corrupt
     the vsd_ip_ntoa result.  */
  len = snprintf (ipaddr, sizeof ipaddr, "%s", vsd_ip_ntoa (ip));
  if (len < 0 || (size_t) len >= sizeof ipaddr)
    return -1;

  sock = socket (AF_INET, SOCK_DGRAM, 0);
  if (sock == -1)
    return -1;

  for (alias = 0; alias < MAX_ALIASES && match == -1; alias++)
    {
      struct sockaddr_in *sin;

      memset (&ifr, 0, sizeof ifr);
      len = snprintf (ifr.ifr_name, sizeof ifr.ifr_name, "%s:%d",
                      device, alias);
      if (len < 0 || (size_t) len >= sizeof ifr.ifr_name)
        break;

      /* A failing request means this alias has no address to compare.  */
      if (ioctl (sock, SIOCGIFADDR, &ifr) != 0)
        continue;

      sin = (struct sockaddr_in *) &ifr.ifr_addr;
      if (strcmp (inet_ntoa (sin->sin_addr), ipaddr) == 0)
        match = alias;
    }

  close (sock);

  if (match == -1)
    return -1;

  len = snprintf (cmd, sizeof cmd, "/sbin/ifconfig %s:%d %s down",
                  device, match, ipaddr);
  if (len < 0 || (size_t) len >= sizeof cmd)
    return -1;

  return system (cmd) == 0 ? 0 : -1;
}

/* Run the shell command `cmd' inside the virtual server rooted at
   `server_dir' and wait for it to finish.

   A child process is forked.  It chroots into `server_dir', changes
   directory to the new root and executes `/bin/sh -c <cmd>' with the
   caller's environment.  If any of these steps fails the child reports
   the error and exits with status 127.

   Returns 0 once the child has been reaped; its exit status is not
   inspected.  Returns 1 if `server_dir' is not accessible, the fork
   fails, or waiting for the child fails.  */
static int command (const char *server_dir, const char *cmd)
{
  pid_t pid;
  int status;

  if (access (server_dir, R_OK | X_OK) == -1)
    {
      log_error ("couldn't access %s: %s\n", server_dir, strerror (errno));
      return 1;
    }

  /* Spawn a child process to initialise the virtual server.  */
  pid = fork ();
  if (pid < (pid_t) 0)
    {
      /* Fork failed: nothing was run, so do not report success.  */
      log_error ("vsboot - fork failed: %s\n", strerror (errno));
      return 1;
    }

  if (pid == (pid_t) 0)
    {
      const char *new_argv[4];

      if (chroot (server_dir))
        {
          log_error ("chroot %s failed: %s\n", server_dir,
                     strerror (errno));
          _exit (127);
        }

      /* Without a successful chdir the working directory still refers
         to a place outside the new root, which would let the command
         reach files outside the virtual server.  */
      if (chdir ("/"))
        {
          log_error ("chdir to new root %s failed: %s\n", server_dir,
                     strerror (errno));
          _exit (127);
        }

      new_argv[0] = "sh";
      new_argv[1] = "-c";
      new_argv[2] = cmd;
      new_argv[3] = NULL;
      execve ("/bin/sh", (char *const *) new_argv, environ);

      /* Only reached when execve itself failed.  */
      log_error ("vsboot - execve failed: %s\n", strerror (errno));
      _exit (127);
    }

  /* Parent side: wait for the child, restarting the wait when it is
     merely interrupted by a signal.  */
  while (waitpid (pid, &status, 0) != pid)
    {
      if (errno == EINTR)
        continue;

      log_error ("vsboot - wait failed\n");
      return 1;
    }

  return 0;
}

/* Run one stage of the start-up sequence for the virtual server `vs'.

   Stage 1 binds the server's IP address and each of its IP aliases to
   the next free alias of NETDEVICE, then mounts /proc read-only inside
   the server.  Stage 2 runs /etc/rc inside the server.  Any other
   stage value does nothing.

   The steps of a stage are all attempted even if an earlier one fails.

   Returns -1 if `vs' is NULL, 0 if the server is not active (nothing is
   done) or the stage completed, and 1 if the server root is not
   accessible or any step of the stage failed.  */
static int boot_server (struct vsd_vs_map *map, struct vsd_vs *vs, int stage)
{
  const char *path;
  int x, alias, failed = 0;

  if (vs == NULL)
    return -1;

  if (vs->status != S_ACTIVE)
    return 0;

  path = vsd_map_server_root (map, vs->name);

  if (access (path, R_OK | X_OK) == -1)
    {
      log_error ("couldn't access %s: %s\n", path, strerror (errno));
      return 1;
    }

  if (stage == 1)
    {
      /* Bind an interface to the virtual server IP.  find_free_interface
         returns -1 when every alias is taken; that must not be passed
         on as an alias number.  */
      alias = find_free_interface (NETDEVICE);
      if (alias == -1)
        failed = log_error ("no free alias on %s for VS %s\n",
                            NETDEVICE, vs->name);
      else if (bind_interface (NETDEVICE, alias, vs->ip) != 0)
        failed = 1;

      /* Bind interfaces for any IP aliases that the virtual server has.  */
      for (x = 0; x < vs->ipalias_total; x++)
        {
          alias = find_free_interface (NETDEVICE);
          if (alias == -1)
            failed = log_error ("no free alias on %s for VS %s\n",
                                NETDEVICE, vs->name);
          else if (bind_interface (NETDEVICE, alias, &vs->ipalias[x]) != 0)
            failed = 1;
        }

      /* Always wait for /proc to be mounted.  */
      if (command (path, "mount /proc -o defaults,ro") != 0)
        failed = 1;
    }

  if (stage == 2 && command (path, "/etc/rc") != 0)
    failed = 1;

  return failed;
}

/* Run one stage of the shutdown sequence for the virtual server `vs'.

   Stage 0 takes down the interfaces of the server's IP aliases and of
   its main IP address.  Stage 1 sends SIGTERM, and stage 2 sends
   SIGKILL, to every process that vsd_enum_procs lists for the server.
   Stage 3 unmounts /proc inside the server.  Any other stage value
   does nothing.

   Returns -1 if `vs' is NULL, 0 if the server is not active (nothing is
   done) or the stage completed, and 1 if the server root is not
   accessible, the process list could not be obtained, or the umount
   command could not be run.  */
static int shutdown_server (struct vsd_vs_map *map, struct vsd_vs *vs,
                            int stage)
{
  const char *path;
  int x;

  if (vs == NULL)
    return -1;

  if (vs->status != S_ACTIVE)
    return 0;

  path = vsd_map_server_root (map, vs->name);

  if (access (path, R_OK | X_OK) == -1)
    {
      log_error ("couldn't access %s: %s\n", path, strerror (errno));
      return 1;
    }

  switch (stage)
    {
    case 0:
      /* Shutdown interfaces for any IP aliases that the virtual server
         has.  Do this before killing the processes to block users out
         while we are trying to kill their processes.  The result is
         deliberately ignored: close_interface also returns -1 when the
         address simply is not bound.  */
      for (x = 0; x < vs->ipalias_total; x++)
        close_interface (NETDEVICE, &vs->ipalias[x]);
      close_interface (NETDEVICE, vs->ip);
      break;

    case 1:
    case 2:
      {
        /* Ask politely first, then force.  */
        int signum = (stage == 1) ? SIGTERM : SIGKILL;
        int *procs = vsd_enum_procs (vs->name);

        /* Do not walk the list if the lookup produced none.  */
        if (procs == NULL)
          return log_error ("couldn't enumerate processes of %s\n",
                            vs->name);

        /* The list is terminated by -1.
           NOTE(review): the list is not released here; confirm whether
           vsd_enum_procs expects the caller to free it.  */
        for (x = 0; procs[x] != -1; x++)
          kill (procs[x], signum);
      }
      break;

    case 3:
      return command (path, "umount /proc");

    default:
      break;
    }

  return 0;
}

/* Look up the virtual server named `vs' in `map'.

   Returns the index of the first entry of map->server whose name
   compares equal to `vs', or -1 when there is no such entry.  */
static int find_vs (struct vsd_vs_map *map, const char *vs)
{
  int pos = 0;

  while (pos < map->servers)
    {
      if (! strcmp (map->server[pos].name, vs))
        return pos;
      pos++;
    }

  return -1;
}

/* Print the command line syntax, the options, the commands and a short
   explanatory note to stdout.  */
static void display_help (void)
{
  /* The complete help text, emitted in one go.  */
  static const char usage[] =
    "syntax: vsboot [options] {command} [<virtual server>]\n"
    "Options:\n"
    "  -d  --delay <n> set an arbitary pause during VS startup\n"
    "  -f   --freq <n> number of VSes to start before delaying\n"
    "  -q  --quiet     don't output the work being done\n"
    "  -p   --part <n> only control VSes on Partition <n>\n"
    "Commands:\n"
    "       --bind     bind interfaces for a server\n"
    "     --unbind     remove interfaces for a server\n"
    "      --start     start a virtual server\n"
    "       --stop     stop a virtual server\n"
    "    --restart     restart a virtual server\n"
    "\nThe action will be applied to all virtual servers unless\n"
    "a partition number or virtual server name is given.  The\n"
    "--delay and --freq options are can be used to workaround the\n"
    "error `fork error: cannot allocate memory' that can often occur\n"
    "if the virtual servers are started with haste.\n";

  fputs (usage, stdout);
}

/* Parse `text' as a decimal integer option argument and store it in
   `*value'.  Returns 0 on success, and 1 (leaving `*value' untouched)
   if `text' is empty, has trailing characters or does not fit in an
   int.  */
static int parse_number (const char *text, int *value)
{
  char *end;
  long parsed;

  errno = 0;
  parsed = strtol (text, &end, 10);
  if (end == text || *end != '\0' || errno == ERANGE
      || parsed < INT_MIN || parsed > INT_MAX)
    return 1;

  *value = (int) parsed;
  return 0;
}

/* Entry point of vsboot.

   Parses the command line, reads the virtual server map and then binds
   or unbinds interfaces, or starts, stops or restarts virtual servers.
   The action applies to the named virtual server if one was given,
   otherwise to every server (start/stop/restart can be limited to one
   partition with --part).

   Returns 0 on success and after --help.  Returns 1 on a bad command
   line, when no command was given, when not run as root, when the
   server map or the ioctl socket cannot be set up, or when the named
   virtual server does not exist.  */
int main (int argc, char *argv[])
{
  struct vsd_vs_map *map;
  /* Prefixed so that `bind' does not shadow the socket call.  */
  enum { act_bind, act_unbind, act_start, act_stop, act_restart, act_none }
    action = act_none;
  char *vs = NULL;
  int delay = 0, freq = 0, partition = -1;
  int exit_status = 0;

  /* The leading `-' makes getopt return non-option arguments as the
     argument of option code 1.  */
  const char *shortopts = "-hq";
  struct option longopts[] = {
    { "delay", required_argument, NULL, 'd' },
    { "freq", required_argument, NULL, 'f' },
    { "start", no_argument, NULL, OPTION_START },
    { "stop", no_argument, NULL, OPTION_STOP },
    { "restart", no_argument, NULL, OPTION_RESTART },
    { "bind", no_argument, NULL, OPTION_BIND },
    { "unbind", no_argument, NULL, OPTION_UNBIND },
    { "help", no_argument, NULL, 'h' },
    { "quiet", no_argument, NULL, 'q' },
    { "part", required_argument, NULL, 'p' },
    { NULL, no_argument, NULL, 0 }
  };

  while (1)
    {
      int longind, optc;

      optc = getopt_long_only (argc, argv, shortopts, longopts, &longind);
      if (optc == -1)
        break;
      switch (optc)
        {
        default:
          return 1;
        case 1: /* Virtual server.  */
          vs = optarg;
          break;
        case 'h':
          display_help ();
          return 0;
        case 'q':
          quiet = 1;
          break;
        case 'd':
          if (parse_number (optarg, &delay))
            return log_error ("%s: invalid delay `%s'\n", argv[0], optarg);
          break;
        case 'f':
          if (parse_number (optarg, &freq))
            return log_error ("%s: invalid frequency `%s'\n", argv[0],
                              optarg);
          break;
        case 'p':
          /* A mistyped partition must not silently select partition 0.  */
          if (parse_number (optarg, &partition))
            return log_error ("%s: invalid partition `%s'\n", argv[0],
                              optarg);
          break;
        case OPTION_START:
          action = act_start;
          break;
        case OPTION_STOP:
          action = act_stop;
          break;
        case OPTION_RESTART:
          action = act_restart;
          break;
        case OPTION_BIND:
          action = act_bind;
          break;
        case OPTION_UNBIND:
          action = act_unbind;
          break;
        }
    }

  /* NOTE(review): --delay and --freq are accepted and validated, but
     nothing below reads them, so the start-up pause described by --help
     is not implemented here.  */
  (void) delay;
  (void) freq;

  if (action == act_none)
    return log_error ("%s: nothing to do\n", argv[0]);

  if (getuid ())
    return log_error ("%s: must be run as root\n", argv[0]);

  map = vsd_map_read ();
  if (map == NULL)
    return log_error ("couldn't read server map file: %s\n",
                      strerror (errno));

  skfd = socket (AF_INET, SOCK_DGRAM, 0);
  if (skfd == -1)
    {
      /* Report first, so that errno still belongs to socket ().  */
      log_error ("couldn't create socket: %s\n", strerror (errno));
      vsd_map_free (map);
      return 1;
    }

  switch (action)
    {
    case act_bind:
    case act_unbind:
      {
        int x, y;

        /* Add or remove bind interfaces, for the named server only if
           one was given.  Only active servers get interfaces bound;
           unbinding is attempted regardless of status.  */
        for (x = 0; x < map->servers; x++)
          {
            if (vs && strcmp (map->server[x].name, vs) != 0)
              continue;

            if (action == act_bind && map->server[x].status == S_ACTIVE)
              bind_interface (NETDEVICE, find_free_interface (NETDEVICE),
                              map->server[x].ip);
            if (action == act_unbind)
              close_interface (NETDEVICE, map->server[x].ip);

            for (y = 0; y < map->server[x].ipalias_total; y++)
              {
                if (action == act_bind
                    && map->server[x].status == S_ACTIVE)
                  bind_interface (NETDEVICE,
                                  find_free_interface (NETDEVICE),
                                  &map->server[x].ipalias[y]);
                if (action == act_unbind)
                  close_interface (NETDEVICE, &map->server[x].ipalias[y]);
              }
          }
      }
      break;

    case act_start:
    case act_stop:
    case act_restart:
      if (vs)
        {
          /* Act on the single named virtual server.  */
          int x, stage;

          x = find_vs (map, vs);
          if (x == -1)
            {
              /* Asking for a server that does not exist is an error,
                 so say so on stderr and exit non-zero.  */
              log_error ("vs %s not found\n", vs);
              exit_status = 1;
              break;
            }

          if (action == act_stop || action == act_restart)
            {
              for (stage = 0; stage <= 3; stage++)
                {
                  log ("shutdown VS %s: stage [%d]\n", vs, stage);
                  shutdown_server (map, &map->server[x], stage);
                }
            }

          if (action == act_start || action == act_restart)
            {
              if (map->server[x].status != S_ACTIVE)
                log ("start VS %s failed: VS is disabled\n", vs);
              else
                {
                  for (stage = 1; stage <= 2; stage++)
                    {
                      log ("start VS %s: stage [%d]\n", vs, stage);
                      boot_server (map, &map->server[x], stage);
                    }
                }
            }
        }
      else
        {
          /* Act on every virtual server, or on those of one partition.
             Each stage is applied to all selected servers before the
             next stage begins.  */
          int x, stage;

          if (action == act_stop || action == act_restart)
            {
              for (stage = 0; stage <= 3; stage++)
                {
                  if (partition != -1)
                    log ("shutdown VSes on partition %d: stage [%d]\n",
                         partition, stage);
                  else
                    log ("shutdown all VSes: stage [%d]\n", stage);

                  for (x = 0; x < map->servers; x++)
                    if (partition == -1
                        || map->server[x].partition == partition)
                      shutdown_server (map, &map->server[x], stage);
                }
            }

          if (action == act_start || action == act_restart)
            {
              for (stage = 1; stage <= 2; stage++)
                {
                  if (partition != -1)
                    log ("start VSes on partition %d: stage [%d]\n",
                         partition, stage);
                  else
                    log ("start all VSes: stage [%d]\n", stage);

                  for (x = 0; x < map->servers; x++)
                    if (map->server[x].status == S_ACTIVE
                        && (partition == -1
                            || map->server[x].partition == partition))
                      boot_server (map, &map->server[x], stage);
                }
            }
        }
      break;

    default:
      printf ("unrecognised command\n");
      break;
    }

  close (skfd);
  vsd_map_free (map);

  return exit_status;
}


