/*
 * Routines to compress and uncompress TCP packets (for transmission
 * over low-speed serial lines).
 *
 * Copyright (c) 1989 Regents of the University of California.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms are permitted
 * provided that the above copyright notice and this paragraph are
 * duplicated in all such forms and that any documentation,
 * advertising materials, and other materials related to such
 * distribution and use acknowledge that the software was developed
 * by the University of California, Berkeley.  The name of the
 * University may not be used to endorse or promote products derived
 * from this software without specific prior written permission.
 * THIS SOFTWARE IS PROVIDED ``AS IS'' AND WITHOUT ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, WITHOUT LIMITATION, THE IMPLIED
 * WARRANTIES OF MERCHANTIBILITY AND FITNESS FOR A PARTICULAR PURPOSE.
 *
 *	Van Jacobson (van@helios.ee.lbl.gov), Dec 31, 1989:
 *	- Initial distribution.
 */
/*
 * modified for KA9Q Internet Software Package by
 * Katie Stevens (dkstevens@ucdavis.edu)
 * University of California, Davis
 * Computing Services
 *	- 01-31-90	initial adaptation (from 1.19)
 *	PPP.05	02-15-90 [ks]
 *	PPP.08	05-02-90 [ks]	use PPP protocol field to signal compression
 *	PPP.15	09-90	 [ks]	improve mbuf handling
 *	PPP.16  11-02	 [karn]	substantially rewritten to use NOS facilities
 */

#include <mem.h>
#include "global.h"
#include "mbuf.h"
#include "tcp.h"
#include "ip.h"
#include "internet.h"
#include "slcomp.h"

/* Helpers for the variable-length delta encoding: a value is sent as a
 * single byte, or as a 0 byte followed by a 16-bit value.
 */
static char *encode __ARGS((char *cp,int16 n));
static long decode __ARGS((struct mbuf **bufp));

void
sl_compress_init(comp)
struct slcompress *comp;
{
	register int16 i;
	register struct cstate *tstate = comp->tstate;

	memset((char *)comp, 0, sizeof(*comp));
	for(i = MAX_STATES - 1; i > 0; --i){
		tstate[i].cs_id = i;
		tstate[i].cs_next = &tstate[i - 1];
	}
	tstate[0].cs_next = &tstate[MAX_STATES - 1];
	tstate[0].cs_id = 0;
	comp->last_cs = &tstate[0];
	comp->last_recv = 255;
	comp->last_xmit = 255;
}

/*
 * Append the delta n to the buffer at cp and return the new write
 * position.  A nonzero value below 256 takes a single byte.  Zero and
 * values of 256 or more are escaped: a 0 byte, then the value written
 * by put16().
 */
static char *
encode(cp,n)
register char *cp;
int16 n;
{
	if(n != 0 && n < 256){
		/* Short form: the value fits in one nonzero byte */
		*cp++ = n;
		return cp;
	}
	/* Long form: escape byte followed by the 16-bit value */
	*cp++ = 0;
	return put16(cp,n);
}

/*
 * Read one delta written by encode() from the front of *bufp.
 * A leading 0 byte means the real value follows as 16 bits (read with
 * pull16()); any other leading byte is the value itself.
 * Returns -1 when the buffer runs out (pull16() returns -1 on error,
 * and PULLCHAR presumably yields -1 on an empty buffer).
 */
static long
decode(bufp)
struct mbuf **bufp;
{
	register int first;

	first = PULLCHAR(bufp);
	if(first != 0)
		return (long)first;	/* one-byte form, or -1 on error */
	return pull16(bufp);		/* escaped 16-bit form */
}

int
sl_compress_tcp(bpp,comp, compress_cid)
struct mbuf **bpp;
struct slcompress *comp;
int compress_cid;
{
	register struct cstate *cs = comp->last_cs->cs_next;
	register int16 hlen;
	register struct tcp *oth;
	register unsigned long deltaS, deltaA;
	register int16 changes = 0;
	char new_seq[16];
	register char *cp = new_seq;
	struct mbuf *bp;
	struct tcp th;
	struct ip iph;

	/* Extract IP header */
	hlen = ntohip(&iph,bpp);

	/* Bail if this packet isn't TCP, or is an IP fragment */
	if(iph.protocol != TCP_PTCL || iph.offset != 0 || iph.flags.mf){
		/* Send as regular IP */
		if(iph.protocol != TCP_PTCL)
			comp->sls_nontcp++;
		else
			comp->sls_asistcp++;
		*bpp = htonip(&iph,*bpp,1);
		return SL_TYPE_IP;
	}
	/* Extract TCP header */
	hlen += ntohtcp(&th,bpp);

	/*  Bail if the TCP packet isn't `compressible' (i.e., ACK isn't set or
	 *  some other control bit is set).
	 */
	if(th.flags.syn || th.flags.fin || th.flags.rst || !th.flags.ack){
		/* TCP connection stuff; send as regular IP */
		comp->sls_asistcp++;
		*bpp = htontcp(&th,*bpp,NULLHEADER);
		*bpp = htonip(&iph,*bpp,1);
		return SL_TYPE_IP;
	}
	/*
	 * Packet is compressible -- we're going to send either a
	 * COMPRESSED_TCP or UNCOMPRESSED_TCP packet.  Either way we need
	 * to locate (or create) the connection state.  Special case the
	 * most recently used connection since it's most likely to be used
	 * again & we don't have to do any reordering if it's used.
	 */
	if(iph.source != cs->cs_ip.source ||
	 iph.dest  != cs->cs_ip.dest ||
	 th.source != cs->cs_tcp.source ||
	 th.dest != cs->cs_tcp.dest){
		/*
		 * Wasn't the first -- search for it.
		 *
		 * States are kept in a circularly linked list with
		 * last_cs pointing to the end of the list.  The
		 * list is kept in lru order by moving a state to the
		 * head of the list whenever it is referenced.  Since
		 * the list is short and, empirically, the connection
		 * we want is almost always near the front, we locate
		 * states via linear search.  If we don't find a state
		 * for the datagram, the oldest state is (re-)used.
		 */
		register struct cstate *lcs;
		register struct cstate *lastcs = comp->last_cs;

		do {
			lcs = cs; cs = cs->cs_next;
			comp->sls_searches++;
			if(iph.source == cs->cs_ip.source
			 && iph.dest == cs->cs_ip.dest
			 && th.source == cs->cs_tcp.source
			 && th.dest == cs->cs_tcp.dest)
				goto found;
		} while(cs != lastcs);

		/*
		 * Didn't find it -- re-use oldest cstate.  Send an
		 * uncompressed packet that tells the other side what
		 * connection number we're using for this conversation.
		 * Note that since the state list is circular, the oldest
		 * state points to the newest and we only need to set
		 * last_cs to update the lru linkage.
		 */
		comp->sls_misses++;
		comp->last_cs = lcs;

		goto uncompressed;

	found:
		/*
		 * Found it -- move to the front on the connection list.
		 */
		if(cs == lastcs)
			comp->last_cs = lcs;
		else {
			lcs->cs_next = cs->cs_next;
			cs->cs_next = lastcs->cs_next;
			lastcs->cs_next = cs;
		}
	}

	/*
	 * Make sure that only what we expect to change changed.
	 * Check the following:
	 * IP protocol version, header length & type of service.
	 * The "Don't fragment" bit.
	 * The time-to-live field.
	 * The TCP header length.
	 * IP options, if any.
	 * TCP options, if any.
	 * If any of these things are different between the previous &
	 * current datagram, we send the current datagram `uncompressed'.
	 */
	oth = &cs->cs_tcp;

	if(iph.version != cs->cs_ip.version || iph.optlen != cs->cs_ip.optlen
	 || iph.tos != cs->cs_ip.tos
	 || iph.flags.df != cs->cs_ip.flags.df
	 || iph.ttl != cs->cs_ip.ttl
	 || th.optlen != cs->cs_tcp.optlen
	 || iph.optlen != cs->cs_ip.optlen
	 || (iph.optlen > 0 && memcmp(iph.options,cs->cs_ip.options,iph.optlen) != 0)
	 || (th.optlen > 0 && memcmp(th.options,cs->cs_tcp.options,th.optlen) != 0)){
		goto uncompressed;
	}
	/*
	 * Figure out which of the changing fields changed.  The
	 * receiver expects changes in the order: urgent, window,
	 * ack, seq (the order minimizes the number of temporaries
	 * needed in this section of code).
	 */
	if(th.flags.urg){
		deltaS = th.up;
		cp = encode(cp,deltaS);
		changes |= NEW_U;
	} else if(th.up != oth->up){
		/* argh! URG not set but urp changed -- a sensible
		 * implementation should never do this but RFC793
		 * doesn't prohibit the change so we have to deal
		 * with it. */
		goto uncompressed;
	}
	if((deltaS = th.wnd - oth->wnd) != 0){
		cp = encode(cp,deltaS);
		changes |= NEW_W;
	}
	if((deltaA = th.ack - oth->ack) != 0L){
		if(deltaA > 0x0000ffff)
			goto uncompressed;
		cp = encode(cp,deltaA);
		changes |= NEW_A;
	}
	if((deltaS = th.seq - oth->seq) != 0L){
		if(deltaS > 0x0000ffff)
			goto uncompressed;
		cp = encode(cp,deltaS);
		changes |= NEW_S;
	}

	switch(changes){
	case 0:	/* Nothing changed. If this packet contains data and the
		 * last one didn't, this is probably a data packet following
		 * an ack (normal on an interactive connection) and we send
		 * it compressed.  Otherwise it's probably a retransmit,
		 * retransmitted ack or window probe.  Send it uncompressed
		 * in case the other side missed the compressed version.
		 */
		if(iph.length != cs->cs_ip.length && cs->cs_ip.length == hlen)
			break;
		goto uncompressed;
	case SPECIAL_I:
	case SPECIAL_D:
		/* actual changes match one of our special case encodings --
		 * send packet uncompressed.
		 */
		goto uncompressed;
	case NEW_S|NEW_A:
		if(deltaS == deltaA &&
		    deltaS == cs->cs_ip.length - hlen){
			/* special case for echoed terminal traffic */
			changes = SPECIAL_I;
			cp = new_seq;
		}
		break;
	case NEW_S:
		if(deltaS == cs->cs_ip.length - hlen){
			/* special case for data xfer */
			changes = SPECIAL_D;
			cp = new_seq;
		}
		break;
	}
	deltaS = iph.id - cs->cs_ip.id;
	if(deltaS != 1){
		cp = encode(cp,deltaS);
		changes |= NEW_I;
	}
	if(th.flags.psh)
		changes |= TCP_PUSH_BIT;
	/* Grab the cksum before we overwrite it below.  Then update our
	 * state with this packet's header.
	 */
	deltaA = th.checksum;
	ASSIGN(cs->cs_ip,iph);
	ASSIGN(cs->cs_tcp,th);
	/* We want to use the original packet as our compressed packet.
	 * (cp - new_seq) is the number of bytes we need for compressed
	 * sequence numbers.  In addition we need one byte for the change
	 * mask, one for the connection id and two for the tcp checksum.
	 * So, (cp - new_seq) + 4 bytes of header are needed.
	 */
	deltaS = cp - new_seq;
	if(compress_cid == 0 || comp->last_xmit != cs->cs_id){
		bp = *bpp = pushdown(*bpp,deltaS + 4);
		cp = bp->data;
		*cp++ = changes | NEW_C;
		*cp++ = cs->cs_id;
	} else {
		bp = *bpp = pushdown(*bpp,deltaS + 3);
		cp = bp->data;
		*cp++ = changes;
	}
	cp = put16(cp,(int16)deltaA);	/* Write TCP checksum */
	memcpy(cp,new_seq,deltaS);	/* Write list of deltas */
	comp->sls_compressed++;
	return SL_TYPE_COMPRESSED_TCP;

	/* Update connection state cs & send uncompressed packet (i.e.,
	 * a regular ip/tcp packet but with the 'conversation id' we hope
	 * to use on future compressed packets in the protocol field).
	 */
uncompressed:
	iph.protocol = cs->cs_id;
	ASSIGN(cs->cs_ip,iph);
	ASSIGN(cs->cs_tcp,th);
	comp->last_xmit = cs->cs_id;
	comp->sls_uncompressed++;
	*bpp = htontcp(&th,*bpp,NULLHEADER);
	*bpp = htonip(&iph,*bpp,1);
	return SL_TYPE_UNCOMPRESSED_TCP;
}


int
sl_uncompress_tcp(bufp, len, type, comp)
struct mbuf **bufp;
int len;
int16 type;
struct slcompress *comp;
{
	register int changes;
	long x;
	register struct tcp *thp;
	register struct cstate *cs;
	struct ip iph;
	struct tcp th;

	switch(type){
	case SL_TYPE_UNCOMPRESSED_TCP:
		/* Extract IP and TCP headers and verify conn ID */
		ntohip(&iph,bufp);
		ntohtcp(&th,bufp);
		if(uchar(iph.protocol) >= MAX_STATES)
			goto bad;

		/* Update local state */
		cs = &comp->rstate[comp->last_recv = uchar(iph.protocol)];
		comp->flags &=~ SLF_TOSS;
		iph.protocol = TCP_PTCL;
		ASSIGN(cs->cs_ip,iph);
		ASSIGN(cs->cs_tcp,th);

		/* Put headers back on packet
		 * Neither header checksum is recalculated
		 */
		*bufp = htontcp(&th,*bufp,NULLHEADER);
		*bufp = htonip(&iph,*bufp,1);
		comp->sls_uncompressedin++;
		return len;

	default:
		goto bad;

	case SL_TYPE_COMPRESSED_TCP:
		break;
	}
	/* We've got a compressed packet; read the change byte */
	comp->sls_compressedin++;
	if(len < 3){
		comp->sls_errorin++;
		return 0;
	}
	changes = PULLCHAR(bufp);	/* "Can't fail" */
	if(changes & NEW_C){
		/* Make sure the state index is in range, then grab the state.
		 * If we have a good state index, clear the 'discard' flag.
		 */
		x = PULLCHAR(bufp);	/* Read conn index */
		if(x < 0 || x >= MAX_STATES)
			goto bad;

		comp->flags &=~ SLF_TOSS;
		comp->last_recv = x;
	} else {
		/* this packet has an implicit state index.  If we've
		 * had a line error since the last time we got an
		 * explicit state index, we have to toss the packet. */
		if(comp->flags & SLF_TOSS){
			comp->sls_tossed++;
			return 0;
		}
	}
	cs = &comp->rstate[comp->last_recv];
	thp = &cs->cs_tcp;
	
	if((x = pull16(bufp)) == -1)	/* Read the TCP checksum */
		goto bad; 
	thp->checksum = x;

	thp->flags.psh = (changes & TCP_PUSH_BIT) ? 1 : 0;

	switch(changes & SPECIALS_MASK){
	case SPECIAL_I:		/* Echoed terminal traffic */
		{
		register int16 i;
		i = cs->cs_ip.length;
		i -= (cs->cs_ip.optlen + IPLEN + TCPLEN);
		thp->ack += i;
		thp->seq += i;
		}
		break;

	case SPECIAL_D:			/* Unidirectional data */
		thp->seq += cs->cs_ip.length - (cs->cs_ip.optlen +IPLEN + TCPLEN);
		break;

	default:
		if(changes & NEW_U){
			thp->flags.urg = 1;
			if((x = decode(bufp)) == -1)
				goto bad;
			thp->up = x;
		} else
			thp->flags.urg = 0;
		if(changes & NEW_W){
			if((x = decode(bufp)) == -1)
				goto bad;
			thp->wnd += x;
		}
		if(changes & NEW_A){
			if((x = decode(bufp)) == -1)
				goto bad;
			thp->ack += x;
		}
		if(changes & NEW_S){
			if((x = decode(bufp)) == -1)
				goto bad;
			thp->seq += x;
		}
		break;
	}
	if(changes & NEW_I){
		if((x = decode(bufp)) == -1)
			goto bad;
		cs->cs_ip.id += x;
	} else
		cs->cs_ip.id++;

	/*
	 * At this point, bufp points to the first byte of data in the
	 * packet.  Put the reconstructed TCP and IP headers back on the
	 * packet.
	 */
	len = len_p(*bufp) + IPLEN + TCPLEN + cs->cs_ip.optlen;
	cs->cs_ip.length = len;

	*bufp = htontcp(thp,*bufp,NULLHEADER);
	*bufp = htonip(&cs->cs_ip,*bufp,0);
	return len;
bad:
	comp->flags |= SLF_TOSS;
	comp->sls_errorin++;
	return 0;
}
