/*  srx - SR execution manager
 *
 *  Srx is started by an SR program the first time that program needs its
 *  services.  Params are:
 *	argv[1]	    a magic string to validate the call
 *	argv[2]	    protocol version number for sanity check with caller
 *	argv[3]	    caller's path (caller's argv[0])
 *
 *  The address of srx's listener socket is written to file descriptor 0
 *  (by agreement with net.c), on the assumption that it is a pipe to the
 *  initiating process (main machine).
 */


#include "../paths.h"
#include "rts.h"
#include <errno.h>
#include <fcntl.h>
#include <signal.h>
#include <time.h>
#include <unistd.h>

#ifdef ns32000	/* if Encore (nothing better to check than "ns32000"?) */
#include <sys/commparam.h>
#else		/* not Encore */
#include <sys/param.h>
#endif



/* VERSION is defined outside this file (presumably by an included header
 * or the build) */
char	version[] = VERSION;		/* SR version number */



/*  physical machine data -- default entry is head of list
 *
 *  One node per physical machine number that lookup() has been asked for.
 *  physm is the statically allocated head; its num is zero-initialized,
 *  so it is the entry for physical machine 0.  A NULL (or, for exepath,
 *  empty) string field means "not set".
 */

struct pmdata {
    int num;			/* physical machine number */
    char *hostname;		/* name of host on which to create machine */
    char *exepath;		/* path to executable program on that host */
    struct pmdata *next;	/* next data node */
};

struct pmdata physm;		/* list head; exepath is set from argv[3] in main */



/*  virtual machine data
 *
 *  A VM normally moves STARTING (crevm) -> WORKING (hello) -> DYING
 *  (destvm) -> GONE (ackdest); mort() and exitmsg() may also mark a VM
 *  GONE directly.  STARTING must stay the first enumerator: vm[] is a
 *  zero-initialized global, so never-used entries read as STARTING.
 */

typedef enum { STARTING, WORKING, DYING, GONE } vmstate;

struct vmdata {		/* virtual machine data */
    int phys;			/* physical machine number */
    int pid;			/* process id of the local process srx started
				   (the rsh for a remote VM); srx's parent for VM 1 */
    vmstate state;		/* current state */
    char addr[SOCK_ADDR_SIZE];	/* socket address (recorded by hello()) */
    int notify;			/* machine to notify on birth or death; 0 = none */
    remd rem;			/* request message pointer for acking CREVM
				   (destvm also records it) */
};

/* indexed by VM number: entries 1..nvm are in use, vm[0] is unused */
struct vmdata vm[1+MAX_VM];



/*  miscellaneous globals
 *
 *  NOTE(review): ndied and exiting are read and written by mort(), a
 *  signal handler, but are plain ints rather than volatile sig_atomic_t.
 */

char sr_net_exe_path[MAX_PATH];	/* network path of exe file (built in main) */

int sr_my_vm = SRX_VM;		/* VM number for sr_net_send */
int nvm = 0;			/* number of virtual machines started
				   (= highest vm[] index in use) */
int ndied = 0;			/* number that have died */
int exiting = 0;		/* shutdown in progress? (set by exitmsg) */

char my_addr[SOCK_ADDR_SIZE];	/* address of srx socket (from sr_net_start) */



/*  global scratch area used for all packet I/O
 *
 *  main() reads each incoming packet into this union (rejecting packets
 *  larger than it).  The handlers then build their replies in the same
 *  buffer, so a handler must read what it needs from the request before
 *  overwriting it.  mort() also writes here, from signal context.
 */

union {
    struct pach_st header;
    struct saddr_st sock;
    struct num_st npkt;
    struct locn_st locn;
} packet;

#define PH (&packet.header)		/* header of the current packet */
#define ORIGIN (packet.header.origin)	/* VM that sent the current packet */



/*  function declarations
 *
 *  netpath, sr_init_debug and the sr_net_* routines declared extern here
 *  are defined in other modules; the remaining functions are defined
 *  later in this file.
 */

extern char *netpath();

extern void sr_init_debug();
extern char *sr_net_start();
extern void sr_net_more(), sr_net_send();
extern enum ms_type sr_net_recv();

struct pmdata *lookup();
char *alloc(), *salloc();
void callme(), locvm(), crevm(), findvm(), hello();
void destvm(), ackdest(), exitmsg(), eof();
void exe(), mort(), setloc();		/* exe() does not return */
char *getname();



/*  main program
 *
 *  Validates the invocation (see the file header for the arguments).
 *  Enters the caller in vm[] as VM 1, starts network I/O and reports
 *  srx's socket address to the caller on fd 0.  Then dispatches incoming
 *  packets forever.  srx only terminates through exit(), called either
 *  directly or from sr_net_abort().
 */

main(argc, argv)
int argc;
char *argv[];
{
    int i, len;
    char *p;
    struct vmdata *v;
    char cwd[MAX_PATH], mapfile[MAX_PATH];
    char message[100];

    /* init debugging */
    sr_init_debug((char *) NULL);
    DEBUG(2, "srx %s %s %s", (argc>1) ? argv[1] : "--",
	(argc>2) ? argv[2] : "--", (argc>3) ? argv[3] : "--");

    /* check for valid call; all three arguments are required because
     * argv[3] is used unconditionally below */
    if (argc < 4 || strcmp(argv[1],VM_MAGIC) != 0)
	sr_net_abort("invalid call");
    if (strcmp(argv[2],PROTO_VER) != 0)
	sr_net_abort("protocol version mismatch; rerun srl to fix");

    /* save path of executable */
    physm.exepath = salloc(argv[3]);

    /* build network path of executable */
    if ((p = getenv("SRMAP")) != NULL) {
	if (strlen(p) >= sizeof(mapfile))	/* don't overrun mapfile */
	    sr_net_abort("SRMAP path too long");
	strcpy(mapfile,p);
    } else
	sprintf(mapfile,"%s/%s", *SRLIB ? SRLIB : SRDIR, "srmap");
    /* getcwd() is bounded by the buffer size, unlike getwd() */
    if (getcwd(cwd, sizeof(cwd)) == NULL)
	cwd[0] = '\0';
    if (!netpath(argv[3],cwd,mapfile,sr_net_exe_path))
	sr_net_abort("can't build network path for executable");
    DEBUG(2,"netpath is: %s",sr_net_exe_path,0,0);

    /* enter caller in vm table */
    v = &vm[++nvm];
    v->phys = 0;
    v->pid = getppid();
    v->state = STARTING;

    /* close all files except stdin, stdout, stderr */
    for (i = 3; i < NOFILE; i++)
	close(i);

    /* start network I/O and send address to caller; the caller cannot
     * connect without it, so a failed write is fatal */
    strcpy(my_addr,sr_net_start());
    len = strlen(my_addr);
    if (write(0,my_addr,len) != len)	/* fd 0 used by agreement with net.c */
	sr_net_abort("can't send address to caller");

    /* connect stdin to /dev/null; then will have
     *	0. stdin  /dev/null
     *  1. stdout as inherited from caller
     *  2. stderr as inherited from caller
     *  3. socket for listening
     *  4+ available for connections to VMs
     */
    close(0);
    if (open("/dev/null",O_RDONLY) < 0)
	sr_net_abort("can't open /dev/null");

    /* set up interrupt routine to catch deaths of children */
    signal(SIGCHLD,mort);

    /* now just loop, waiting for things to do... */
    for (;;)  {

	sr_net_recv(PH);				/* read packet header */
	if (PH->size > sizeof(packet))
	    sr_net_abort("incoming packet too big");
	sr_net_more(PH);				/* read the rest */

	switch (PH->type) {

	    case MSG_EOF:     eof();     break;
	    case MSG_HELLO:   hello();   break;
	    case MSG_EXIT:    exitmsg(); break;
	    case REQ_CALLME:  callme();  break;
	    case REQ_CREVM:   crevm();   break;
	    case REQ_FINDVM:  findvm();  break;
	    case REQ_DESTVM:  destvm();  break;
	    case ACK_DESTVM:  ackdest(); break;
	    case REQ_LOCVM:   locvm();   break;

	    default:
		sprintf(message,"unexpected packet type %d",PH->type);
		sr_net_abort(message);
	}
    }
}



/*  locvm() - handle REQ_LOCVM: specify location for a physical machine
 *
 *  The text field of packet.locn holds two consecutive NUL-terminated
 *  strings: the host name, then the executable path.  setloc() ignores
 *  either one if it is empty.  The request is acked with a bare header.
 */

void
locvm()
{
    char *host, *exefile;

    host = packet.locn.text;
    exefile = host + strlen(host) + 1;	/* second string starts after host's NUL */
    DEBUG(1, "LOCATE %d %s %s", packet.locn.num, host, exefile);
    setloc(packet.locn.num, host, exefile);
    sr_net_send (ORIGIN, ACK_LOCVM, PH, PACH_SZ);
}



/*  callme() - pass a "call me" message from one VM to another
 *
 *  The request names a target VM in packet.npkt.num.  The same packet is
 *  forwarded to that target with num replaced by the requester's VM
 *  number (presumably so the target can connect back to the requester).
 */

void
callme()
{
    int target;

    target = packet.npkt.num;		/* read before the field is reused */
    DEBUG(1, "CALLME %d from %d", target, ORIGIN, 0);
    packet.npkt.num = ORIGIN;		/* who wants to be called */
    sr_net_send (target, REQ_CALLME, PH, sizeof(packet.npkt));
}



/*  crevm() - create virtual machine.
 *
 *  Handles REQ_CREVM.  packet.npkt.num names the physical machine on which
 *  the new VM should run.  The new VM takes the next slot in vm[].  The
 *  request is acked later: by hello() when the new VM reports in, or by
 *  mort() (with NULL_VM) if it dies while still STARTING.
 */

void
crevm()
{
    int pid, pm;
    struct vmdata *v;

    /* check before bumping nvm, so vm[1..nvm] (scanned by sr_net_abort
     * and mort) never extends past the end of vm[] */
    if (nvm >= MAX_VM)
	sr_net_abort ("too many virtual machines");
    nvm++;
    pm = packet.npkt.num;
    DEBUG(1,"CREVM %d on %d",nvm,pm,0);
    fflush(stdout);
    fflush(stderr);
    if ((pid = vfork()) < 0)
	sr_net_abort("can't vfork for new vm");
    if (pid == 0) 
	exe(pm,nvm);		/* in the child, execute a.out; never returns */
    v = &vm[nvm];		/* init data for new machine */
    v->phys = pm;		/* physical machine it was started on */
    v->pid = pid;
    v->state = STARTING;
    v->notify = ORIGIN;		/* save info for acking when HELLO comes back */
    v->rem = PH->rem;
}



/*  exe(pm,vn) - exec SR program to be virtual machine vn on phys machine pm
 *
 *  Called only in the vfork()ed child made by crevm(); never returns.
 *  The program runs locally via execl() when pm is 0 and no host name is
 *  set for it.  Otherwise it runs via rsh on the machine's host, which is
 *  either set by REQ_LOCVM or derived from pm by getname().  Its arguments
 *  are: the magic string, pm, vn, srx's socket address, and the debug
 *  flags in hex.
 *
 *  NOTE(review): the host name cached below is seen by srx afterwards only
 *  because the vfork child shares srx's memory; doing anything but
 *  exec/_exit in a vfork child is not portable.  The final sr_net_abort()
 *  also calls exit() -- confirm that is acceptable here.
 */

void
exe(pm,vn)
int pm, vn;
{
    struct pmdata *p;
    /* room for any int: "%d" can need 11 chars plus NUL, and pm may
     * encode an entire IP address (see getname) */
    char pmbuf[24], vmbuf[24], dbbuf[24], magicbuf[sizeof(VM_MAGIC)+2];
    char *path, *h;

    sprintf(pmbuf,"%d",pm);
    sprintf(vmbuf,"%d",vn);
    sprintf(dbbuf,"%X",sr_dbg_flags);
    sprintf(magicbuf,"'%s'",VM_MAGIC);	/* quoted for the remote shell */

    p = lookup(pm);
    if (p->exepath && *p->exepath)	/* get path to exe file */
	path = p->exepath;		/* use explicit path if one given */
    else
	path = sr_net_exe_path;		/* else use network path */


    if (pm == 0 && !p->hostname) {	/* exec locally */
	DEBUG(2,"[%d] exec %s args...",vn,path,0);
	fflush(stdout);
	fflush(stderr);
	/* execl's list must end with a null pointer of type char * */
	execl(path,path,VM_MAGIC,pmbuf,vmbuf,my_addr,dbbuf,(char *) NULL);
	perror(path);
    } else {				/* exec remotely via rsh */
	if (!p->hostname) {
	    if (!(h = getname(pm))) {		/* get hostname */
		fprintf(stderr,"srx: unknown machine %d\n",pm);
		/* _exit, not exit: in a vfork child, exit() would flush and
		 * close stdio streams shared with the parent */
		_exit(1);
	    }
	    p->hostname = salloc(h);		/* save for next time */
	}
	DEBUG(2,"[%d] rsh %s -n exec %s args...",vn,p->hostname,path);
	fflush(stdout);
	fflush(stderr);
	execl(RSHPATH,"rsh",p->hostname,"-n","exec",
		path,magicbuf,pmbuf,vmbuf,my_addr,dbbuf,(char *) NULL);
	perror(RSHPATH);
    }
    sr_net_abort("can't execute program");
}



/*  findvm() - find virtual machine.  */

void
findvm()
{
    int n;
    char message[100];

    n = packet.npkt.num;
    DEBUG(1,"FINDM %d from %d",n,ORIGIN,0);
    switch (vm[n].state) {
	case STARTING:
	    sprintf (message,"can't connect to vm %d -- not yet initialized",n);
	    sr_net_abort (message);
	case WORKING:
	    memcpy (packet.sock.addr, vm[n].addr, SOCK_ADDR_SIZE);
	    sr_net_send (ORIGIN, ACK_FINDVM, PH, sizeof(packet.sock));
	    break;
	case DYING:
	case GONE:
	    sprintf (message,"can't connect to vm %d -- already terminated",n);
	    sr_net_abort (message);
    }
}



/*  eof() - process EOF pseudo-message indicating a vm has died  */

void
eof()
{
    char message[100];

    struct vmdata *v = vm + ORIGIN;
    DEBUG(1,"EOF from %d",ORIGIN,0,0);
    if (v->state != GONE)  {
	sprintf (message, "lost connection to virtual machine %d", ORIGIN);
	sr_net_abort (message);
    }
    v->state = GONE;
    if (++ndied == nvm)			/* exit if all alone */
	{ DEBUG(2,"exiting because no VMs left",0,0,0); exit(1); }
}
 



/*  hello() - process HELLO message 
 *
 *  register the new virtual machine, and pass back acknowlegement to its
 *  creator (if any).
 */

void
hello()
{
    struct vmdata *v = vm + ORIGIN;
    DEBUG(1,"HELLO %d at %s",ORIGIN,packet.sock.addr,0);
    if (v->state != STARTING)
	sr_net_abort("unexpected HELLO");
    strncpy(v->addr,packet.sock.addr,SOCK_ADDR_SIZE);
    v->state = WORKING;
    if (v->notify)  {
	PH->rem = v->rem;
	packet.npkt.num = ORIGIN;
	sr_net_send (v->notify, ACK_CREVM, PH, sizeof(packet.npkt));
	v->notify = 0;
    }
}



/*  destvm() - handle REQ_DESTVM message
 *
 *  make a note that a machine is being destroyed, and pass it the message.
 */

void
destvm()
{
    int n;
    char message[100];

    n = packet.npkt.num;
    DEBUG(1,"DESTVM %d from %d",n,ORIGIN,0);
    if (vm[n].state != WORKING) {
	sprintf (message, "can't destroy VM %d -- it's not now running",n);
	sr_net_abort (message);
    }
    vm[n].state = DYING;
    vm[n].notify = ORIGIN;
    vm[n].rem = PH->rem;
    sr_net_send (n, REQ_DESTVM, PH, PACH_SZ);
}



/*  ackdest() - handle ACK_DESTVM message
 *
 *  mark the vm as gone, notify the original destroyer, and kill the process.
 */

void
ackdest()
{
    int n;

    n = ORIGIN;
    DEBUG(1,"GOODBYE from %d",n,0,0);
    sr_net_send (vm[n].notify, ACK_DESTVM, PH, PACH_SZ);
    DEBUG(2,"kill %d [%d]",vm[n].pid,n,0);
    kill(vm[n].pid,SIGINT);
    vm[n].state = GONE;
}



/*  exitmsg() - process EXIT message
 *
 *  The sender is counted as dead and shutdown mode is entered.  The EXIT
 *  message is passed to every other VM not yet GONE.  The others then get
 *  up to 5 seconds to die on their own (mort() exits srx early if all do),
 *  and sr_net_abort() kills any stragglers.
 */

void
exitmsg()
{
    int i;
    time_t deadline, now;

    DEBUG(1,"EXIT %d from %d",packet.npkt.num,ORIGIN,0);
    vm[ORIGIN].state = GONE;		/* note that sender has died */
    ++ndied;
    exiting = 1;			/* flag shutdown in progress */
    for (i = nvm; i > 0; i--)
	if (vm[i].state != GONE)	/* send msg to all still alive */
	    sr_net_send (i, MSG_EXIT, PH, sizeof(packet.npkt));

    /*  Give everybody a chance to die quietly; then kill 'em.
     *  On POSIX systems sleep() returns early whenever a signal handler
     *  runs, and every child death runs mort().  So keep sleeping until
     *  the whole grace period has elapsed or nobody is left.  */
    deadline = time((time_t *) 0) + 5;
    while (ndied < nvm && (now = time((time_t *) 0)) < deadline)
	sleep((unsigned) (deadline - now));
    sr_net_abort((char *) NULL);
}



/*  mort() - interrupt routine called when a child dies
 *
 *  Deaths during srx shutdown are merely counted.  If all VMs are gone we will
 *  exit here, but note that we *aren't* notified by interrupt of VM 1's death
 *  because it's our parent, not our child.
 *
 *  Deaths during VM startup mean failure of REQ_CREVM, which must be acked
 *  (with VM number NULL_VM) to the VM that requested the creation.
 *
 *  Other deaths are ignored and will be caught by EOF processing after
 *  the input pipe is flushed.
 *
 *  errno is saved and restored so that code interrupted by this handler
 *  (e.g. the select() loop in sr_iowait) still sees its own errno.
 *
 *  NOTE(review): the NAK is built in the global packet buffer from signal
 *  context.  If the signal arrives while the main loop is using that
 *  buffer, its packet is overwritten -- confirm this cannot happen.
 *  NOTE(review): only one wait() is done per signal, so if several children
 *  die before the handler runs, some deaths may go unreported.
 */

void
mort()
{
    int n, s, sig, code;
    int saved_errno = errno;
    struct vmdata *v;
    char buf[24];			/* "vm " plus any int */

    n = wait(&s);
    for (v = vm + nvm; v > vm; v--)
	if (v->pid == n)
	    break;
    if (v == vm)
	sr_net_abort("unknown pid returned by wait()");
    sig = s & 0x7F;			/* terminating signal, 0 if none */
    code = s >> 8;			/* exit status */

    if (sig != 0 && sig != SIGINT && sig != SIGQUIT)  {
	/* v - vm is a pointer difference; cast it to match %d */
	sprintf (buf, "vm %d", (int) (v - vm));
	psignal ((unsigned) sig, buf);
    } else
	DEBUG(2,"vm %d exited with signal %d, code %d", (int) (v - vm), sig, code);

    /*  Re-enable the signal.  This is needed under Sys V and derivatives. */
    signal(SIGCHLD,mort);

    /* other deaths are handled when the EOF is seen */
    if (exiting || v->state == STARTING)  {
	v->state = GONE;			/* show vm as down */
	if (++ndied == nvm)			/* exit if all alone */
	    { DEBUG(2,"exiting because no VMs left",0,0,0); exit(1); }
	if (!exiting)  {
	    /* a VM startup failed: NAK the pending REQ_CREVM */
	    PH->rem = v->rem;			/* init caller's reply address */
	    packet.npkt.num = NULL_VM;		/* indicate failure */
	    sr_net_send (v->notify, ACK_CREVM, PH, sizeof(packet.npkt));
	}
    }
    errno = saved_errno;
}



/*  setloc(n,host,path) - set or change location for physical machine
 *
 *  Records a host name and/or executable path for physical machine n,
 *  creating its entry if needed.  A NULL or empty argument leaves that
 *  field unchanged.  A replaced string is freed, and the new one is
 *  copied with salloc().
 */

void
setloc(n,host,path)
int n;
char *host, *path;
{
    struct pmdata *entry = lookup(n);

    if (host != NULL && host[0] != '\0')  {
	DEBUG(2,"HOSTNAME for %d:  %s",n,host,0);
	if (entry->hostname != NULL)
	    free(entry->hostname);
	entry->hostname = salloc(host);
    }

    if (path != NULL && path[0] != '\0')  {
	DEBUG(2,"EXE_PATH for %d:  %s",n,path,0);
	if (entry->exepath != NULL)
	    free(entry->exepath);
	entry->exepath = salloc(path);
    }
}



/*  getname(n) - get hostname for physical machine n
 *
 *  n encodes an IPv4 address with the last octet in its low-order byte.
 *  High-order octets that n leaves unspecified (the value left after
 *  shifting is zero) are taken from srx's own address in my_addr.  That
 *  address is assumed to start with a dotted quad; any octet that can't
 *  be parsed defaults to 0.  Returns the host name from gethostbyaddr(),
 *  or NULL if the lookup fails.  The result points into resolver-owned
 *  storage that later lookups may overwrite, so callers must copy it.
 */

char *
getname(n)
int n;
{
    int d[4];
    unsigned char a[4];
    struct hostent *he;

    /* pre-clear so a malformed my_addr can't leave octets uninitialized */
    d[0] = d[1] = d[2] = d[3] = 0;
    sscanf(my_addr,"%d.%d.%d.%d",d,d+1,d+2,d+3);
              a[3] = n ? n : d[3];
    n >>= 8;  a[2] = n ? n : d[2];
    n >>= 8;  a[1] = n ? n : d[1];
    n >>= 8;  a[0] = n ? n : d[0];
    he = gethostbyaddr ((char *) a, sizeof (a), AF_INET);
    if (he)
	return he->h_name;
    else
	return NULL;
}



/*  lookup(n) - find (create if necessary) entry for physical machine n.  */

struct pmdata *
lookup(n)
int n;
{
    struct pmdata *p;
    static struct pmdata z;

    for (p = &physm; p; p = p->next)
	if (p->num == n)
	    return (p);
    p = (struct pmdata *) alloc(sizeof(struct pmdata));
    *p = z;
    p->num = n;
    p->next = physm.next;
    physm.next = p;
    return (p);
}




/*  alloc(n) - allocate n bytes, with success guaranteed
 *
 *  Returns malloc()ed memory.  Running out of memory is fatal (via
 *  sr_net_abort), so callers never see NULL.
 */

char *
alloc(n)
int n;
{
    char *mem = malloc((unsigned) n);

    if (mem == NULL)
	sr_net_abort("out of memory");
    return mem;
}



/*  salloc(s) - allocate and initialize string, with success guaranteed
 *
 *  Returns a freshly alloc()ed copy of the NUL-terminated string s; the
 *  copy may later be released with free().
 */

char *
salloc(s)
char *s;
{
    char *copy;

    copy = alloc(strlen(s) + 1);	/* +1 for the terminating NUL */
    strcpy(copy, s);
    return copy;
}



/*  sr_iowait (query, results, inout) - wait for input from a set of files
 *
 *  Blocks in select() until some descriptor in *query is readable, and
 *  leaves the ready set in *results.  A select() interrupted by a signal
 *  (e.g. SIGCHLD handled by mort) is retried.  srx only ever waits for
 *  input; any other request, or a select() failure, is fatal.
 */

void
sr_iowait (query, results, inout)
fd_set *query, *results;
enum io_type inout;
{
    /* errno comes from <errno.h>; declaring "extern int errno" here breaks
     * on systems where errno is a (per-thread) macro */
    int n;

    if (inout != INPUT)
	sr_net_abort("srx iowait not on input");
    do {
	DEBUG (0x80, "select: %08X", (unsigned) query->fds_bits[0], 0, 0);
	*results = *query;		/* select() overwrites the set it is given */
	n = select (FD_SETSIZE, results, (fd_set *) 0, (fd_set *) 0,
	    (struct timeval *) 0);
	DEBUG (0x80, "selectd %08X n=%d", (unsigned) results->fds_bits[0], n, 0);
    } while (n < 0 && errno == EINTR);
    if (n < 0) {
	perror("select");
	sr_net_abort("select failure");
    }
}




/*  sr_net_abort(message) - kill all other processes and exit.
 *  Message is optional.  Used for network errors, srx errors, and normal exit.
 *  Never returns: always ends with exit(1).
 *
 *  note: we use SIGINT, not SIGKILL, because SIGKILL won't kill the
 *  far end of an rsh.
 */

void
sr_net_abort(message)
char *message;
{
    int i;

    if (message)				/* print message if given */
	fprintf(stderr,"srx: %s\n",message);
    for (i = 1; i <= nvm; i++)			/* kill all machines */
	/* A slot whose pid is not yet recorded holds 0 -- e.g. vfork failed
	 * in crevm(), or we are the vfork child in exe().  kill(0, ...) would
	 * signal our whole process group, srx included, so skip it. */
	if (vm[i].state != GONE && vm[i].pid > 0)  {
	    DEBUG(2,"kill %d [%d]",vm[i].pid,i,0);
	    kill(vm[i].pid,SIGINT);
	}
    if (!message)
	DEBUG(2,"srx exiting",0,0,0);
    exit(1);					/* and quit */
}
