/* Lower half of IP, consisting of gateway routines
 * Includes routing and options processing code
 */
#include <stdio.h>
#include "machdep.h"
#include "mbuf.h"
#include "internet.h"
#include "timer.h"
#include "netuser.h"
#include "ip.h"
#include "icmp.h"
#include "iface.h"
#ifdef	TRACE
#include "trace.h"
#endif

/* Routing table. routes[n-1][h] heads the doubly linked chain of entries
 * whose prefix is n bits long (1-32) and whose masked target hashes to h
 * (see hash_ip()).
 */
struct route *routes[32][NROUTE];	/* Routing table */
/* Default route. Treated as unset while its interface field is NULLIF */
struct route r_default;			/* Default route entry */

/* Our own IP address; compared against ntohl() of header fields, so it is
 * kept in host byte order
 */
int32 ip_addr;
/* Counters bumped by ip_route(): total datagrams seen plus one counter per
 * reason for discarding a malformed datagram
 */
struct ip_stats ip_stats;

/* Route an IP datagram. This is the "hopper" through which all IP datagrams,
 * coming or going, must pass.
 *
 * This router is a temporary hack, since it only does host-specific or
 * default routing (no hierarchical routing yet).
 *
 * "rxbroadcast" is set to indicate that the packet came in on a subnet
 * broadcast. The router will kick the packet upstairs regardless of the
 * IP destination address.
 */
void
ip_route(bp,rxbroadcast)
struct mbuf *bp;
char rxbroadcast;	/* True if packet had link broadcast address */
{
	register struct ip_header *ip;	/* IP header being processed */
	int16 ip_len;			/* IP header length */
	int16 buflen;			/* Length of mbuf */
	int16 length;			/* Total datagram length */
	int32 target;			/* Target IP address */
	int32 gateway;			/* Gateway IP address */
	register struct route *rp;	/* Route table entry */
	struct route *rt_lookup();
	int opi;			/* Index into options field */
	int opt_len;			/* Length of current option */
	int strict;			/* Strict source routing flag */
	struct mbuf *sbp;		/* IP header for fragmenting */
	int16 fl_offs;			/* fl_offs field of datagram */
	int16 offset;			/* Offset of fragment */
	char precedence;		/* Extracted from tos field */
	char delay;
	char throughput;
	char reliability;

	ip_stats.total++;
	buflen = len_mbuf(bp);
	if(buflen < sizeof(struct ip_header)){
		/* The packet is shorter than a legal IP header */
		ip_stats.runt++;
		free_p(bp);
		return;
	}
	ip = (struct ip_header *)bp->data;
	length = ntohs(ip->length);
	if(buflen > length){
		/* Packet has excess garbage (e.g., Ethernet padding); trim */
		if(bp->next == NULLBUF){
			/* One mbuf, just adjust count */
			bp->cnt = length;
		} else {
			struct mbuf *nbp;
			/* Copy to a new one */
			nbp = copy_p(bp,length);
			free((char *)bp);
			bp = nbp;
			ip = (struct ip_header *)bp->data;
		}
	}
	ip_len = lonibble(ip->v_ihl) * sizeof(int32);
	if(ip_len < sizeof(struct ip_header)){
		/* The IP header length field is too small */
		ip_stats.length++;
		free_p(bp);
		return;
	}
	if(cksum(NULLHEADER,bp,ip_len) != 0){
		/* Bad IP header checksum; discard */
		ip_stats.checksum++;
		free_p(bp);
		return;
	}
	if(hinibble(ip->v_ihl) != IPVERSION){
		/* We can't handle this version of IP */
		ip_stats.version++;
		free_p(bp);
		return;
	}
	/* See if it's a broadcast or addressed to us, and kick it upstairs */
	if(ntohl(ip->dest) == ip_addr || rxbroadcast){
#ifdef	GWONLY
	/* We're only a gateway, we have no host level protocols */
		if(!rxbroadcast)
			icmp_output(bp,DEST_UNREACH,PROT_UNREACH,(union icmp_args *)NULL);
		free_p(bp);
#else
#ifdef	TRACE
		if(trace & TRACE_SELF && ntohl(ip->source) == ip_addr){
			printf("loopback:\r\n");
			if((trace & TRACE_HDR) > 2)
				ip_dump(bp);
			if(trace & TRACE_DUMP)
				hexdump(bp);
			fflush(stdout);
		}
#endif
		ip_recv(bp,rxbroadcast);
#endif
		return;
	}
	/* If we get here, we must forward the packet.
	 * Process options, if any. Also compute length of secondary IP
	 * header in case fragmentation is needed later
	 */
	strict = 0;
	for(opi = sizeof(struct ip_header);opi < ip_len; opi += opt_len){
		char *opt;	/* Points to current option */
		int opt_type;	/* Type of current option */
		int pointer;	/* Pointer field of current option */
		int32 *addr;	/* Pointer to an IP address field in option */

		opt = (char *)ip + opi;
		opt_type = opt[0] & OPT_NUMBER;

		/* Handle special 1-byte do-nothing options */
		if(opt_type == IP_EOL)
			break;		/* End of options list, we're done */
		if(opt_type == IP_NOOP){
			opt_len = 1;	/* No operation, skip to next option */
			continue;
		}
		/* Other options have a length field */
		opt_len = opt[1] & 0xff;

		/* Process options */
		switch(opt_type){
		case IP_SSROUTE:/* Strict source route & record route */
			strict = 1;
		case IP_LSROUTE:/* Loose source route & record route */
			/* Source routes are ignored unless the datagram appears to
			 * be for us
			 */
			if(ntohl(ip->dest) != ip_addr)
				continue;
		case IP_RROUTE:	/* Record route */
			pointer = (opt[2] & 0xff) - 1;
			if(pointer + sizeof(int32) <= opt_len){
				/* Insert our address in the list */
				addr = (int32 *)&opt[pointer];
				if(opt_type != IP_RROUTE)
					/* Old value is next dest only for source routing */
					ip->dest = *addr;
				*addr = htonl(ip_addr);
				opt[2] += 4;
			} else {
				/* Out of space; return a parameter problem and drop */
				union icmp_args icmp_args;

				icmp_args.unused = 0;
				icmp_args.pointer = sizeof(struct ip_header) + opi;
				icmp_output(bp,PARAM_PROB,0,&icmp_args);
				free_p(bp);
				return;
			}
			break;
		}
	}
	/* Decrement TTL and discard if zero */
	if(--ip->ttl == 0){
		/* Send ICMP "Time Exceeded" message */
		icmp_output(bp,TIME_EXCEED,0,NULLICMP);
		free_p(bp);
		return;
	}
	/* Note this address may have been modified by source routing */
	target = ntohl(ip->dest);

	/* Look up target address in routing table */
	if((rp = rt_lookup(target)) == NULLROUTE){
		/* No route exists, return unreachable message */
		icmp_output(bp,DEST_UNREACH,HOST_UNREACH,NULLICMP);
		free_p(bp);
		return;
	}
	/* Find gateway; zero gateway in routing table means "send direct" */
	if(rp->gateway == (int32)0)
		gateway = target;
	else
		gateway = rp->gateway;

	if(strict && gateway != target){
		/* Strict source routing requires a direct entry */
		icmp_output(bp,DEST_UNREACH,ROUTE_FAIL,NULLICMP);
		free_p(bp);
		return;
	}
	precedence = PREC(ip->tos);
	delay = ip->tos & DELAY;
	throughput = ip->tos & THRUPUT;
	reliability = ip->tos & RELIABILITY;

	if(length <= rp->interface->mtu){
		/* Datagram smaller than interface MTU; send normally */
		/* Recompute header checksum */
		ip->checksum = 0;
		ip->checksum = cksum(NULLHEADER,bp,ip_len);
		(*rp->interface->send)(bp,rp->interface,gateway,
			precedence,delay,throughput,reliability);
		return;
	}
	/* Fragmentation needed */
	fl_offs = ntohs(ip->fl_offs);
	if(fl_offs & DF){
		/* Don't Fragment set; return ICMP message and drop */
		icmp_output(bp,DEST_UNREACH,FRAG_NEEDED,NULLICMP);
		free_p(bp);
		return;
	}
	/* Create copy of IP header for each fragment */
	sbp = copy_p(bp,ip_len);
	pullup(&bp,NULLCHAR,ip_len);
	length -= ip_len;

	/* Create fragments */
	offset = (fl_offs & F_OFFSET) << 3;
	while(length != 0){
		int16 fragsize;		/* Size of this fragment's data */
		struct mbuf *f_header;	/* Header portion of fragment */
		struct ip_header *fip;	/* IP header */
		struct mbuf *f_data;	/* Data portion of fragment */

		f_header = copy_p(sbp,ip_len);
		fip = (struct ip_header *)f_header->data;
		fip->fl_offs = htons(offset >> 3);
		if(length + ip_len <= rp->interface->mtu){
			/* Last fragment; send all that remains */
			fragsize = length;
		} else {
			/* More to come, so send multiple of 8 bytes */
			fragsize = (rp->interface->mtu - ip_len) & 0xfff8;
			fip->fl_offs |= htons(MF);
		}
		fip->length = htons(fragsize + ip_len);
		/* Recompute header checksum */
		fip->checksum = 0;
		fip->checksum = cksum(NULLHEADER,f_header,ip_len);

		/* Extract portion of data and link in */
		f_data = copy_p(bp,fragsize);
		pullup(&bp,NULLCHAR,fragsize);
		f_header->next = f_data;

		(*rp->interface->send)(f_header,rp->interface,gateway,
			precedence,delay,throughput,reliability);
		offset += fragsize;
		length -= fragsize;
	}
	free_p(sbp);
}

/* Install or update a route.
 *
 * "target" is the destination prefix and "bits" the number of leading bits
 * of it that are significant (values above 32 are treated as 32; 0 selects
 * the default route). "gateway", "metric" and "interface" are stored in the
 * entry; an existing entry for the same prefix is overwritten in place.
 *
 * Returns 0 on success, -1 if "interface" is NULLIF or no memory is
 * available for a new entry.
 */
int
rt_add(target,bits,gateway,metric,interface)
int32 target;	/* Target IP address prefix */
unsigned bits;	/* Size of target address prefix in bits (0-32) */
int32 gateway;
int metric;
struct interface *interface;
{
	struct route *entry,**bucket,*rt_lookup();
	int16 hash_ip(),i;
	char *malloc();

	if(interface == NULLIF)
		return -1;

	if(bits == 0){
		/* A zero-length prefix means the default route */
		entry = &r_default;
	} else {
		if(bits > 32)
			bits = 32;

		/* Clear every bit to the right of the prefix */
		for(i = bits;i <= 31;i++)
			target &= ~(0x80000000 >> i);

		/* Walk the hash chain looking for this exact prefix */
		entry = routes[bits-1][hash_ip(target)];
		while(entry != NULLROUTE && entry->target != target)
			entry = entry->next;

		if(entry == NULLROUTE){
			/* Not there yet; allocate one and push it on the
			 * front of its chain
			 */
			entry = (struct route *)malloc(sizeof(struct route));
			if(entry == NULLROUTE)
				return -1;	/* No space */
			bucket = &routes[bits-1][hash_ip(target)];
			entry->prev = NULLROUTE;
			entry->next = *bucket;
			if(*bucket != NULLROUTE)
				(*bucket)->prev = entry;
			*bucket = entry;
		}
	}
	/* Fill in (or refresh) the entry */
	entry->target = target;
	entry->gateway = gateway;
	entry->metric = metric;
	entry->interface = interface;
	return 0;
}

/* Remove an entry from the IP routing table. Returns 0 on success, -1
 * if entry was not in table.
 */
int
rt_drop(target,bits)
int32 target;
unsigned bits;
{
	register struct route *rp;
	struct route *rt_lookup();
	unsigned i;

	if(bits == 0){
		/* Nail the default entry */
		r_default.interface = NULLIF;
		return 0;
	}
	if(bits > 32)
		bits = 32;

	/* Mask off don't-care bits */
	for(i=31;i > bits;i--)
		target &= ~(0x80000000 >> i);

	/* Search appropriate chain for existing entry */
	for(rp = routes[bits-1][hash_ip(target)];rp != NULLROUTE;rp = rp->next){
		if(rp->target == target)
			break;
	}
	if(rp == NULLROUTE)
		return -1;	/* Not in table */

	if(rp->next != NULLROUTE)
		rp->next->prev = rp->prev;
	if(rp->prev != NULLROUTE)
		rp->prev->next = rp->next;
	else
		routes[bits-1][hash_ip(target)] = rp->next;

	free((char *)rp);
	return 0;
}

/* Map an IP address to a routing table bucket: XOR the two 16-bit halves
 * of the address together and reduce the result modulo NROUTE.
 */
static int16
hash_ip(addr)
register int32 addr;
{
	register int16 bucket;

	bucket = hiword(addr) ^ loword(addr);
	return bucket % NROUTE;
}
#ifndef	GWONLY
/* Return the MTU of the local interface that rt_lookup() selects for
 * destination "addr", or 0 when there is no route or the route has no
 * interface. This is used by TCP to avoid local fragmentation.
 */
int16
ip_mtu(addr)
int32 addr;
{
	register struct route *entry;
	struct route *rt_lookup();

	entry = rt_lookup(addr);
	if(entry == NULLROUTE || entry->interface == NULLIF)
		return 0;
	return entry->interface->mtu;
}
#endif
/* Find the route for "target" by longest prefix match: try the full 32-bit
 * address first, then progressively shorter prefixes, clearing one more
 * low-order bit of the address at each step. If no table entry matches,
 * fall back to the default route; if that is not set either, return
 * NULLROUTE.
 */
static struct route *
rt_lookup(target)
int32 target;
{
	register struct route *cand;
	int16 hash_ip();
	unsigned plen;

	plen = 32;
	while(plen != 0){
		/* Nothing to clear for a full-length match */
		if(plen < 32)
			target &= ~(0x80000000 >> plen);

		cand = routes[plen-1][hash_ip(target)];
		while(cand != NULLROUTE){
			if(cand->target == target)
				return cand;
			cand = cand->next;
		}
		plen--;
	}
	if(r_default.interface == NULLIF)
		return NULLROUTE;
	return &r_default;
}
/* Internet checksum routines
 * Improved portability courtesy Rick Spanbauer, WB2CFV
 */
#define SLOWCHECK
#ifdef SLOWCHECK
/*
 * Word aligned linear buffer checksum routine.  Called from mbuf checksum
 * routine with simple args.  Intent is that this routine may be replaced
 * by assembly language routine for speed if so desired.
 *
 * "sum" is the running total to start from, "wp" points at the first
 * 16-bit word and "len" is the number of 16-bit words (not bytes) to add.
 * The words are added exactly as they sit in memory; no byte swapping is
 * done here. Returns the total with all carries folded back into the low
 * 16 bits.
 */
static int16
lcsum(sum, wp, len)
register int32 sum;
register int16 *wp;
int16 len;
{
	register int16 csum;

	while(len-- != 0)
		sum += *wp++;
	/* End-around carry: fold anything above bit 15 back in */
	while((csum = sum >> 16) != 0)
		sum = csum + (sum & 0xffff);
	return sum & 0xffff;
}
#endif /* SLOWCHECK */

/* End-around carry: keep adding whatever has accumulated above bit 15 back
 * into the low 16 bits until nothing is left up there, then return the
 * 16-bit result.
 */
static int16
eac(sum)
register int32 sum;	/* Carries in high order 16 bits */
{
	register int16 carry;

	for(;;){
		carry = sum >> 16;
		if(carry == 0)
			break;
		sum = carry + (sum & 0xffff);
	}
	return sum;	/* Chops to 16 bits */
}
/* Compute the Internet checksum over the first "len" bytes of mbuf chain
 * "m". If "ph" is not NULLHEADER, the pseudo-header it points to (source,
 * dest, protocol, length) is included in the sum. Returns the one's
 * complement of the folded 16-bit sum.
 */
int16
cksum(ph,m,len)
struct pseudo_header *ph;
register struct mbuf *m;
int16 len;
{
	register unsigned int nbytes, done;
	register int32 acc, part;
	register unsigned char *p;

	acc = 0l;

	if(ph != NULLHEADER){
		/* Pseudo-header fields are summed as host-order values */
		acc = hiword(ph->source);
		acc += loword(ph->source);
		acc += hiword(ph->dest);
		acc += loword(ph->dest);
		acc += ph->protocol & 0xff;
		acc += ph->length;
		/* Folding first and swapping the total once is the same as
		 * swapping each element before adding, only cheaper
		 */
		acc = htons(eac(acc));
	}
	/* Walk the chain one buffer at a time */
	done = 0;
	while(m != NULLBUF && done < len){
		nbytes = min(m->cnt, len - done);
		p = (unsigned char *)m->data;
		part = 0;

		/* A buffer starting on an odd address: take its first byte
		 * separately so the rest is short-aligned
		 */
		if(((long)p) & 1){
			part = (int16)ntohs(*p++);
			nbytes--;
		}
		/* Peel off a leftover last byte so an even count remains */
		if(nbytes & 1)
			part += (int16)ntohs(p[--nbytes]<<8);

		/* What is left is aligned and even in length; hand it to
		 * the word summing primitive
		 */
		if(nbytes != 0)
			part = lcsum(part, (unsigned short *)p, nbytes >> 1);

		/* When this buffer's alignment in memory disagrees with its
		 * position inside the packet, its bytes were summed in the
		 * wrong halves; swap the partial sum to compensate
		 */
		if((done&1) ^ (((long)m->data)&1)){
			part = eac(part);
			part = (part >> 8) + ((part&0xff) << 8);
		}
		acc += part;
		done += m->cnt;
		m = m->next;
	}
	/* Final fold, then one's complement */
	return ~eac(acc) & 0xffff;
}
#ifdef	TRACE
#include "trace.h"

/* Print a one-line summary of the IP header at the front of "bp" on stdout
 * (addresses, lengths, ttl, protocol, optional tos / fragment info, header
 * checksum status). If the header trace level is above 3 and this is the
 * first fragment, the payload is also passed to the TCP, UDP or ICMP dump
 * routine. "bp" itself is left for the caller to free.
 */
void
ip_dump(bp)
struct mbuf *bp;
{
	void tcp_dump(),udp_dump(),icmp_dump();
	register struct ip_header *hdr;
	int32 src,dst;
	int16 hdrlen;
	int16 totlen;
	int16 flags;
	int16 fragoff;
	struct mbuf *payload;
	int hsum;
	int verify;
	int proto;
	char copied;

	if(bp == NULLBUF)
		return;

	/* The header is read through a plain pointer, so a packet spread
	 * over several buffers is first copied into a temporary one, which
	 * is released again before returning
	 */
	copied = 0;
	if(bp->next != NULLBUF){
		bp = copy_p(bp,len_mbuf(bp));
		copied = 1;
	}

	hdr = (struct ip_header *)bp->data;
	hdrlen = lonibble(hdr->v_ihl) * sizeof(int32);
	totlen = ntohs(hdr->length);
	flags = ntohs(hdr->fl_offs);
	fragoff = (flags & F_OFFSET) << 3 ;
	src = ntohl(hdr->source);
	dst = ntohl(hdr->dest);

	/* Two separate calls: one inet_ntoa() result per printf */
	printf("IP: %s",inet_ntoa(src));
	printf("->%s len %u ihl %u ttl %u prot %u",
		inet_ntoa(dst),totlen,hdrlen,hdr->ttl & 0xff,
		hdr->protocol & 0xff);

	if(hdr->tos != 0)
		printf(" tos %u",hdr->tos);
	if(fragoff != 0 || (flags & MF))
		printf(" id %u offs %u",ntohs(hdr->id),fragoff);

	if(flags & DF)
		printf(" DF");

	verify = 1;
	if(flags & MF){
		printf(" MF");
		verify = 0;	/* Bypass host-level checksum verify */
	}

	hsum = cksum(NULLHEADER,bp,hdrlen);
	if(hsum != 0)
		printf(" CHECKSUM ERROR (%u)",hsum);
	printf("\r\n");

	/* Only a first fragment starts with a transport header */
	if((trace & TRACE_HDR) > 3 && fragoff == 0){
		dup_p(&payload,bp,hdrlen,totlen - hdrlen);
		proto = hdr->protocol & 0xff;
		if(proto == TCP_PTCL)
			tcp_dump(payload,src,dst,verify);
		else if(proto == UDP_PTCL)
			udp_dump(payload,src,dst,verify);
		else if(proto == ICMP_PTCL)
			icmp_dump(payload,src,dst,verify);
		free_p(payload);
	}
	if(copied)
		free_p(bp);
	fflush(stdout);
}
/* Print the IP routing table on stdout, default route (if set) first and
 * then every table entry in order of increasing prefix length. Always
 * returns 0. Sample output:
 * Dest              Length    Interface    Gateway          Metric
 * 192.001.002.003   32        sl0          192.002.003.004       4
 */
int
dumproute()
{
	register unsigned int slot,plen;
	register struct route *entry;

	printf("Dest              Length    Interface    Gateway          Metric\r\n");

	/* The default route lives outside the hash table */
	if(r_default.interface != NULLIF){
		printf("default           0         %-13s",
		 r_default.interface->name);
		/* A zero gateway means "send direct"; leave the column blank */
		if(r_default.gateway == 0)
			printf("%-17s","");
		else
			printf("%-17s",inet_ntoa(r_default.gateway));
		printf("%6u\r\n",r_default.metric);
	}

	plen = 1;
	while(plen <= 32){
		for(slot = 0;slot < NROUTE;slot++){
			entry = routes[plen-1][slot];
			while(entry != NULLROUTE){
				printf("%-18s",inet_ntoa(entry->target));
				printf("%-10u",plen);
				printf("%-13s",entry->interface->name);
				if(entry->gateway == 0)
					printf("%-17s","");
				else
					printf("%-17s",inet_ntoa(entry->gateway));
				printf("%6u\r\n",entry->metric);
				entry = entry->next;
			}
		}
		plen++;
	}
	return 0;
}
#endif
