/* Lower half of IP, consisting of gateway routines
 * Includes routing and options processing code
 */
#include "global.h"
#include "mbuf.h"
#include "internet.h"
#include "timer.h"
#include "netuser.h"
#include "ip.h"
#include "icmp.h"
#include "iface.h"
#include "trace.h"

/* Routing table. routes[bits-1][h] heads a doubly-linked chain (next/prev)
 * of entries whose target prefix is "bits" bits long (1..32); h is
 * hash_ip() of the target after masking it to those leading bits.
 */
struct route *routes[32][NROUTE];	/* Routing table */
/* Default route, returned by rt_lookup() when no prefix entry matches.
 * It is treated as unset while its interface pointer is NULLIF.
 */
struct route r_default;			/* Default route entry */

/* Our own IP address; datagrams whose destination equals it go upstairs */
int32 ip_addr;
/* Datagram counters (total seen, runts, bad length, bad checksum, ...) */
struct ip_stats ip_stats;

#ifndef	GWONLY
struct mbuf *loopq;	/* Queue for loopback packets */
#endif

/* Route an IP datagram. This is the "hopper" through which all IP datagrams,
 * coming or going, must pass.
 *
 * This router is a temporary hack, since it only does host-specific or
 * default routing (no hierarchical routing yet).
 *
 * "rxbroadcast" is set to indicate that the packet came in on a subnet
 * broadcast. The router will kick the packet upstairs regardless of the
 * IP destination address.
 */
void
ip_route(bp,rxbroadcast)
struct mbuf *bp;
char rxbroadcast;	/* True if packet had link broadcast address */
{
	struct mbuf *htonip();
	struct ip ip;			/* IP header being processed */
	int16 ip_len;			/* IP header length */
	int16 length;			/* Length of data portion */
	int32 gateway;			/* Gateway IP address */
	register struct route *rp;	/* Route table entry */
	struct route *rt_lookup();
	int16 offset;			/* Offset into current fragment */
	int strict = 0;			/* Strict source routing flag */
	char precedence;		/* Extracted from tos field */
	char delay;
	char throughput;
	char reliability;
	int16 opt_len;	/* Length of current option */
	char *opt;	/* -> beginning of current option */
	char *ptr;	/* -> pointer field in source route fields */
	struct mbuf *tbp;

	ip_stats.total++;
	if(len_mbuf(bp) < IPLEN){
		/* The packet is shorter than a legal IP header */
		ip_stats.runt++;
		free_p(bp);
		return;
	}
	/* Sneak a peek at the IP header's IHL field to find its length */
	ip_len = (bp->data[0] & 0xf) << 2;
	if(ip_len < IPLEN){
		/* The IP header length field is too small */
		ip_stats.length++;
		free_p(bp);
		return;
	}
	if(cksum(NULLHEADER,bp,ip_len) != 0){
		/* Bad IP header checksum; discard */
		ip_stats.checksum++;
		free_p(bp);
		return;
	}
	/* Extract IP header */
	ntohip(&ip,&bp);

	if(ip.version != IPVERSION){
		/* We can't handle this version of IP */
		ip_stats.version++;
		free_p(bp);
		return;
	}
	/* Trim data segment if necessary. */
	length = ip.length - ip_len;	/* Length of data portion */
	trim_mbuf(&bp,length);	
				
	/* Process options, if any. Also compute length of secondary IP
	 * header in case fragmentation is needed later
	 */
	strict = 0;
	for(opt = ip.options; opt < &ip.options[ip.optlen];opt += opt_len){
		int32 get32();

		/* Most options have a length field. If this is a EOL or NOOP,
		 * this (garbage) value won't be used
		 */
		opt_len = opt[1] & 0xff;

		switch(opt[0] & OPT_NUMBER){
		case IP_EOL:
			goto no_opt;	/* End of options list, we're done */
		case IP_NOOP:
			opt_len = 1;
			break;		/* No operation, skip to next option */
		case IP_SSROUTE:	/* Strict source route & record route */
			strict = 1;	/* note fall-thru */
		case IP_LSROUTE:	/* Loose source route & record route */
			/* Source routes are ignored unless we're in the
			 * destination field
			 */
			if(ip.dest != ip_addr)
				break;	/* Skip to next option */
			if((opt[2] & 0xff) >= opt_len){
				break;	/* Route exhausted; it's for us */
			}
			/* Put address for next hop into destination field,
			 * put our address into the route field, and bump
			 * the pointer
			 */
			ptr = opt + (opt[2] & 0xff) - 1;
			ip.dest = get32(ptr);
			put32(ptr,ip_addr);
			opt[2] += 4;
			break;
		case IP_RROUTE:	/* Record route */
			if((opt[2] & 0xff) >= opt_len){
				/* Route area exhausted; kick back an error */
				union icmp_args icmp_args;

				icmp_args.pointer = IPLEN + opt - ip.options;
				icmp_output(&ip,bp,PARAM_PROB,0,&icmp_args);
				free_p(bp);
				return;
			}
			/* Add our address to the route */
			ptr = opt + (opt[2] & 0xff) - 1;
			ptr = put32(ptr,ip_addr);
			opt[2] += 4;
			break;
		}
	}
no_opt:

	/* See if it's a broadcast or addressed to us, and kick it upstairs */
	if(ip.dest == ip_addr || rxbroadcast){
#ifdef	GWONLY
	/* We're only a gateway, we have no host level protocols */
		if(!rxbroadcast)
			icmp_output(&ip,bp,DEST_UNREACH,PROT_UNREACH,(union icmp_args *)NULL);
		free_p(bp);
#else
		/* Put IP header back on */
		if((bp = htonip(&ip,bp)) == NULLBUF)
			return;

		/* If this is a local loopback packet, place on the loopback
		 * queue for processing in the main loop. This prevents the
		 * infinite stack recursion and other problems that would
		 * otherwise occur when we talk to ourselves, e.g., with ftp
		 */
		if(ip.source == ip_addr){
			/* LOOPBACK TRACING GOES HERE */
			/* Copy loopback packet into new buffer.
			 * This avoids an obscure problem with TCP which
			 * dups its outgoing data before transmission and
			 * then frees it when an ack comes, even though the
			 * receiver might not have actually read it yet
			 */
			tbp = copy_p(bp,len_mbuf(bp));
			free_p(bp);
			if(tbp != NULLBUF)
				enqueue(&loopq,tbp);
		} else {
			ip_recv(bp,rxbroadcast);
		}
#endif
		return;
	}

	/* Decrement TTL and discard if zero */
	if(--ip.ttl == 0){
		/* Send ICMP "Time Exceeded" message */
		icmp_output(&ip,bp,TIME_EXCEED,0,NULLICMP);
		free_p(bp);
		return;
	}
	/* Look up target address in routing table */
	if((rp = rt_lookup(ip.dest)) == NULLROUTE){
		/* No route exists, return unreachable message */
		icmp_output(&ip,bp,DEST_UNREACH,HOST_UNREACH,NULLICMP);
		free_p(bp);
		return;
	}
	/* Find gateway; zero gateway in routing table means "send direct" */
	if(rp->gateway == (int32)0)
		gateway = ip.dest;
	else
		gateway = rp->gateway;

	if(strict && gateway != ip.dest){
		/* Strict source routing requires a direct entry */
		icmp_output(&ip,bp,DEST_UNREACH,ROUTE_FAIL,NULLICMP);
		free_p(bp);
		return;
	}
	precedence = PREC(ip.tos);
	delay = ip.tos & DELAY;
	throughput = ip.tos & THRUPUT;
	reliability = ip.tos & RELIABILITY;

	if(ip.length <= rp->interface->mtu){
		/* Datagram smaller than interface MTU; put header
		 * back on and send normally
		 */
		bp = htonip(&ip,bp);
		(*rp->interface->send)(bp,rp->interface,gateway,
			precedence,delay,throughput,reliability);
		return;
	}
	/* Fragmentation needed */
	if(ip.fl_offs & DF){
		/* Don't Fragment set; return ICMP message and drop */
		icmp_output(&ip,bp,DEST_UNREACH,FRAG_NEEDED,NULLICMP);
		free_p(bp);
		return;
	}
	/* Create fragments */
	offset = (ip.fl_offs & F_OFFSET) << 3;
	while(length != 0){		/* As long as there's data left */
		int16 fragsize;		/* Size of this fragment's data */
		struct mbuf *f_data;	/* Data portion of fragment */

		/* After the first fragment, should remove those
		 * options that aren't supposed to be copied on fragmentation
		 */
		ip.fl_offs = offset >> 3;
		if(length + ip_len <= rp->interface->mtu){
			/* Last fragment; send all that remains */
			fragsize = length;
		} else {
			/* More to come, so send multiple of 8 bytes */
			fragsize = (rp->interface->mtu - ip_len) & 0xfff8;
			ip.fl_offs |= MF;
		}
		ip.length = fragsize + ip_len;

		/* Move the data fragment into a new, separate mbuf */
		if((f_data = alloc_mbuf(fragsize)) == NULLBUF){
			free_p(bp);
			break;
		}
		f_data->cnt = pullup(&bp,f_data->data,fragsize);

		/* Put IP header back on */
		if((f_data = htonip(&ip,f_data)) == NULLBUF){
			free_p(bp);
			break;
		}
		/* and ship it out */
		(*rp->interface->send)(f_data,rp->interface,gateway,
			precedence,delay,throughput,reliability);

		offset += fragsize;
		length -= fragsize;
	}
}

/* Add a route for the prefix formed by the leading "bits" bits of "target",
 * or update the existing entry for that prefix in place. bits == 0 sets
 * the default route; values above 32 are treated as 32.
 * Returns 0 on success, -1 if no interface was given or memory ran out.
 */
int
rt_add(target,bits,gateway,metric,interface)
int32 target;	/* Target IP address prefix */
unsigned bits;	/* Size of target address prefix in bits (0-32) */
int32 gateway;
int metric;
struct interface *interface;
{
	register struct route *entry;
	struct route **chain;
	int16 hash_ip(),pos;

	if(interface == NULLIF)
		return -1;

	if(bits != 0){
		if(bits > 32)
			bits = 32;

		/* Clear every bit past the prefix (positions bits..31 from the MSB) */
		for(pos = 31; pos >= bits; pos--){
#if (ATARI_ST && LATTICE)
			target &=  (0x80000000 >> (pos-1));	/* DG2KK */
#else
			target &= ~(0x80000000 >> pos);
#endif
		}

		/* Look for an existing entry on this prefix length's chain */
		chain = &routes[bits-1][hash_ip(target)];
		entry = *chain;
		while(entry != NULLROUTE && entry->target != target)
			entry = entry->next;

		if(entry == NULLROUTE){
			/* New prefix: allocate and link at the head of the chain */
			entry = (struct route *)malloc(sizeof(struct route));
			if(entry == NULLROUTE)
				return -1;	/* No space */
			entry->prev = NULLROUTE;
			entry->next = *chain;
			if(entry->next != NULLROUTE)
				entry->next->prev = entry;
			*chain = entry;
		}
	} else {
		/* Zero bits refers to the default route */
		entry = &r_default;
	}
	entry->target = target;
	entry->gateway = gateway;
	entry->metric = metric;
	entry->interface = interface;
	return 0;
}

/* Remove an entry from the IP routing table. Returns 0 on success, -1
 * if entry was not in table.
 *
 * bits == 0 clears the default route. Otherwise "target" is masked to its
 * leading "bits" bits (values above 32 are treated as 32). The masking must
 * be identical to rt_add()'s, so that the same (target,bits) pair hashes to
 * the chain holding the entry rt_add() created.
 */
int
rt_drop(target,bits)
int32 target;
unsigned bits;
{
	register struct route *rp;
	struct route **hp;
	unsigned i;
	int16 hash_ip();

	if(bits == 0){
		/* Nail the default entry */
		r_default.interface = NULLIF;
		return 0;
	}
	if(bits > 32)
		bits = 32;

	/* Mask off don't-care bits: positions bits..31 counted from the MSB.
	 * (">=" matters: stopping at "> bits" would leave one extra bit set
	 * and miss entries that rt_add() stored with that bit cleared.)
	 */
	for(i=31;i >= bits;i--)
#if (ATARI_ST && LATTICE)
		target &=  (0x80000000 >> (i-1));	/* DG2KK */
#else
		target &= ~(0x80000000 >> i);
#endif

	/* Search appropriate chain for existing entry */
	hp = &routes[bits-1][hash_ip(target)];
	for(rp = *hp;rp != NULLROUTE;rp = rp->next){
		if(rp->target == target)
			break;
	}
	if(rp == NULLROUTE)
		return -1;	/* Not in table */

	/* Unlink from the doubly-linked chain */
	if(rp->next != NULLROUTE)
		rp->next->prev = rp->prev;
	if(rp->prev != NULLROUTE)
		rp->prev->next = rp->next;
	else
		*hp = rp->next;

	free((char *)rp);
	return 0;
}

/* Compute hash function on IP address */
static int16
hash_ip(addr)
register int32 addr;
{
	register int16 ret;

	ret = hiword(addr);
	ret ^= loword(addr);
	ret %= NROUTE;
	return ret;
}
#ifndef	GWONLY
/* Return the MTU of the local interface that rt_lookup() selects for
 * reaching "addr", or 0 when there is no route (or it has no interface).
 * TCP uses this to size segments so they aren't fragmented locally.
 */
int16
ip_mtu(addr)
int32 addr;
{
	struct route *route;
	struct route *rt_lookup();

	if((route = rt_lookup(addr)) == NULLROUTE || route->interface == NULLIF)
		return 0;
	return route->interface->mtu;
}
#endif
/* Longest-prefix match: try prefix lengths 32 down to 1, clearing one more
 * low-order bit of "target" before each shorter probe, and return the first
 * entry whose stored target equals the masked address. If none matches,
 * return the default route when it is set (non-NULL interface), else
 * NULLROUTE.
 */
static struct route *
rt_lookup(target)
int32 target;
{
	register struct route *rp;
	int16 hash_ip();
	unsigned bits = 32;

	for(;;){
		rp = routes[bits-1][hash_ip(target)];
		while(rp != NULLROUTE && rp->target != target)
			rp = rp->next;
		if(rp != NULLROUTE)
			return rp;
		if(--bits == 0)
			break;
		/* Shorten the prefix by clearing bit position "bits" from the MSB */
#if (ATARI_ST && LATTICE)
		target &=  (0x80000000 >> (bits-1));	/* DG2KK */
#else
		target &= ~(0x80000000 >> bits);
#endif
	}
	return (r_default.interface != NULLIF) ? &r_default : NULLROUTE;
}
/* Build a network-format IP header, with a freshly computed checksum, from
 * the host-format header "ip", and prepend it to the chain "data".
 * Returns the new chain head, or NULLBUF if the header mbuf can't be
 * allocated; in that case "data" is not freed or linked anywhere.
 */
struct mbuf *
htonip(ip,data)
struct ip *ip;
struct mbuf *data;
{
	struct mbuf *hbp;
	register char *p;
	int16 hlen = IPLEN + ip->optlen;

	hbp = alloc_mbuf(hlen);
	if(hbp == NULLBUF)
		return NULLBUF;
	hbp->cnt = hlen;

	p = hbp->data;
	*p++ = (IPVERSION << 4) | (hlen >> 2);	/* Version, IHL in 32-bit words */
	*p++ = ip->tos;
	p = put16(p,ip->length);
	p = put16(p,ip->id);
	p = put16(p,ip->fl_offs);
	*p++ = ip->ttl;
	*p++ = ip->protocol;
	p = put16(p,0);		/* Checksum must be zero while it is computed */
	p = put32(p,ip->source);
	p = put32(p,ip->dest);
	if(ip->optlen != 0)
		memcpy(p,ip->options,ip->optlen);

	/* The checksum field sits at byte offset 10 of the header */
	put16(&hbp->data[10],cksum(NULLHEADER,hbp,hlen));

	hbp->next = data;
	return hbp;
}
/* Pull an IP header (fixed part plus options) off the front of *bpp into
 * the host-format structure "ip". The received checksum is discarded.
 * Returns the header length in bytes, or -1 if the IHL field is smaller
 * than a minimal header; the fixed part has already been consumed then.
 */
int
ntohip(ip,bpp)
struct ip *ip;
struct mbuf **bpp;
{
	char first;
	int16 hlen;

	first = pullchar(bpp);
	hlen = (first & 0xf) << 2;	/* IHL counts 32-bit words */
	ip->version = (first >> 4) & 0xf;
	ip->tos = pullchar(bpp);
	ip->length = pull16(bpp);
	ip->id = pull16(bpp);
	ip->fl_offs = pull16(bpp);
	ip->ttl = pullchar(bpp);
	ip->protocol = pullchar(bpp);
	(void)pull16(bpp);	/* Skip over the checksum */
	ip->source = pull32(bpp);
	ip->dest = pull32(bpp);

	if(hlen < IPLEN)
		return -1;	/* Header claims to be shorter than the fixed part */

	ip->optlen = hlen - IPLEN;
	if(ip->optlen != 0)
		pullup(bpp,ip->options,ip->optlen);
	return ip->optlen + IPLEN;
}
/* One's-complement end-around carry: repeatedly add the carries held in
 * the upper 16 bits of "sum" back into the low 16 bits until none remain,
 * then return the low 16 bits.
 */
int16
eac(sum)
register int32 sum;	/* Carries in high order 16 bits */
{
	register int16 carry;

	for(carry = sum >> 16; carry != 0; carry = sum >> 16)
		sum = carry + (sum & 0xffffL);
	return (int16)(sum & 0xffffL);
}
/* Compute the Internet (one's-complement) checksum over the first "len"
 * bytes of the mbuf chain "m", optionally including the pseudo-header "ph"
 * (pass NULLHEADER for none). Returns the complemented 16-bit sum, so a
 * region that already carries a correct checksum yields 0.
 *
 * Bytes are summed across mbuf boundaries as one continuous stream; "swap"
 * tracks whether the current mbuf starts on an odd stream offset so that
 * its partial sum can be byte-swapped into place.
 */
int16
cksum(ph,m,len)
struct pseudo_header *ph;
register struct mbuf *m;
int16 len;
{
	register unsigned int cnt, total;
	register int32 sum, csum;
	register unsigned char *up;
	int16 csum1;
	int swap = 0;
	int16 lcsum();

	sum = 0l;

	/* Sum pseudo-header, if present */
	if(ph != NULLHEADER){
		sum = hiword(ph->source);
		sum += loword(ph->source);
		sum += hiword(ph->dest);
		sum += loword(ph->dest);
		sum += ph->protocol & 0xff;
		sum += ph->length;
	}
	/* Now do each mbuf on the chain */
	for(total = 0; m != NULLBUF && total < len; m = m->next) {
		cnt = min(m->cnt, len - total);
		if(cnt == 0)
			continue;	/* Empty mbuf; the odd-byte step below would underflow cnt */
		up = (unsigned char *)m->data;
		csum = 0;

		if(((unsigned long)up) & 1){
			/* Handle odd leading byte so lcsum sees aligned shorts */
			if(swap)
				csum = *up++ & 0xff;
			else
				csum = (int16)((*up++ & 0xff) << 8);
			cnt--;
			swap = !swap;
		}
		if(cnt > 1){
			/* Have the primitive checksumming routine do most of
			 * the work. At this point, up is guaranteed to be on
			 * a short boundary
			 */
			csum1 = lcsum((unsigned short *)up, cnt >> 1);
			if(swap)
				csum1 = (csum1 << 8) | (csum1 >> 8);
			csum += csum1;
		}
		/* Handle odd trailing byte */
		if(cnt & 1){
			if(swap)
				csum += up[--cnt] & 0xff;
			else
				csum += (int16)((up[--cnt] & 0xff) << 8);
			swap = !swap;
		}
		sum += csum;
		total += m->cnt;
	}
	/* Do final end-around carry, complement and return */
	return ~eac(sum) & 0xffff;
}
/* Read a 32-bit value stored in network (big-endian) byte order at "cp".
 * The bytes are read one at a time, so "cp" needs no particular alignment.
 */
static int32
get32(cp)
register char *cp;
{
	int32 val = 0;
	int n;

	for(n = 0; n < 4; n++)
		val = (val << 8) | (cp[n] & 0xff);
	return val;
}
