#ifdef	TRACE
#include <stdio.h>
#endif

#include "machdep.h"
#include "timer.h"
#include "mbuf.h"
#include "netuser.h"
#include "internet.h"
#include "tcp.h"

/* Hash chain headers for all TCBs. Each slot points at the first TCB of a
 * doubly linked chain (tcb->next / tcb->prev); the slot for a connection
 * is chosen by hash_tcb(). See link_tcb() and unlink_tcb().
 */
struct tcb *tcbs[NTCB];

/* Find the TCB belonging to a connection.
 * Walks the hash chain selected by hash_tcb(conn) and returns the first
 * TCB whose local and remote address and port all match those in *conn.
 * Returns NULLTCB if the chain holds no such TCB (nonexistent connection).
 */
struct tcb *
lookup_tcb(conn)
struct connection *conn;
{
	register struct tcb *tp;
	int16 hash_tcb();

	for(tp = tcbs[hash_tcb(conn)];tp != NULLTCB;tp = tp->next){
		/* Compared field by field rather than as whole structures
		 * (a structure compatibility workaround)
		 */
		if(conn->local.address != tp->conn.local.address)
			continue;
		if(conn->remote.address != tp->conn.remote.address)
			continue;
		if(conn->local.port != tp->conn.local.port)
			continue;
		if(conn->remote.port != tp->conn.remote.port)
			continue;
		return tp;
	}
	return NULLTCB;
}

/* Return the TCB for a connection, creating it if necessary.
 * If lookup_tcb() already finds a TCB for *conn, that TCB is returned
 * unchanged. Otherwise a new one is obtained from calloc() (so it starts
 * out zero-filled), given a copy of *conn, the default MSS, an initial
 * round trip estimate and its retransmission timer settings, and is
 * linked onto its hash chain.
 * Returns NULLTCB if the allocation fails.
 */
struct tcb *
create_tcb(conn)
struct connection *conn;
{
	char *calloc();
	void tcp_timeout();
	void tcp_msl();
	void link_tcb();
	register struct tcb *newtcb;

	newtcb = lookup_tcb(conn);
	if(newtcb != NULLTCB)
		return newtcb;		/* Already exists */

	newtcb = (struct tcb *)calloc(1,sizeof (struct tcb));
	if(newtcb == NULLTCB)
		return NULLTCB;		/* Allocation failed */

	bcopy((char *)conn,(char *)&newtcb->conn,sizeof(struct connection));

	newtcb->mss = DEF_MSS;
	newtcb->srtt = DEF_RTT * MSPTICK;

	/* Retransmission timer: the initial timeout is derived from the
	 * initial srtt, and tcp_timeout is handed the TCB as its argument
	 */
	newtcb->timer.start = (BETA * newtcb->srtt)/MSPTICK;
	newtcb->timer.func = tcp_timeout;
	newtcb->timer.arg = (int *)newtcb;

	link_tcb(newtcb);
	return newtcb;
}

/* Close our TCB: stop its timer, record the reason for closing, release
 * everything on the reassembly queue and move to the CLOSED state.
 * This routine does not itself unlink or free the TCB.
 */
void
close_self(tcb,reason)
register struct tcb *tcb;
char reason;
{
	struct reseq *cur,*following;

	stop_timer(&tcb->timer);
	tcb->reason = reason;

	/* Nothing more can arrive, so discard the reassembly queue.
	 * The successor is saved before each entry is freed.
	 */
	cur = tcb->reseq;
	while(cur != NULLRESEQ){
		following = cur->next;
		free_p(cur->bp);
		free((char *)cur);
		cur = following;
	}
	tcb->reseq = NULLRESEQ;
	setstate(tcb,CLOSED);
}

/* Determine initial sequence number */

#ifdef	AMIGA
/*
 *  Sequence number state for AMIGA builds. It lives at file scope so that
 *  both setiss() and iss() can reach it.
 *
 *  setiss() is the routine called at startup time with the initial value
 *  of iss for the system.  This is probably based on the time or something.
 */
static int32 seq;

/* Seed the initial sequence number generator with initval */
void
setiss(initval)
	int32 initval;
{
	seq = initval;
}
#endif

/* Return the next initial sequence number.
 * Every call advances the counter by 250000 and returns the new value.
 * On AMIGA builds the counter is the file-level "seq" seeded by setiss();
 * otherwise it is a function-local static, which starts out as zero.
 */
int32
iss()
{
#ifndef	AMIGA
	static int32 seq;
#endif
	seq = seq + 250000;
	return seq;
}

/* Sequence number comparisons
 * Return true (1) if x is between low and high inclusive, false (0)
 * otherwise.
 *
 * The sequence space is circular. When low <= high the interval is the
 * ordinary range [low,high]. When low > high the interval has wrapped:
 * it runs from low up to the top of the number space and continues from
 * the bottom up to high, so x is inside if it lies in either piece.
 *
 * An empty interval cannot be expressed: low == high+1 denotes the whole
 * sequence space, so callers must check for an empty range (e.g. a closed
 * window) themselves before calling.
 */
int
seq_within(x,low,high)
register int32 x,low,high;
{
	if(low <= high){
		/* Ordinary, unwrapped interval */
		return low <= x && x <= high;
	}
	/* Wrapped interval. Testing "low >= x && x >= high" here would
	 * accept exactly the values lying outside the interval.
	 */
	return x >= low || x <= high;
}
/* Return true if sequence number x is before y, judged by the sign of
 * their difference taken as a long
 */
int
seq_lt(x,y)
register int32 x,y;
{
	long diff;

	diff = (long)(x-y);
	return diff < 0;
}
/* Return true if sequence number x is before or equal to y, judged by
 * the sign of their difference taken as a long
 */
int
seq_le(x,y)
register int32 x,y;
{
	long diff;

	diff = (long)(x-y);
	return diff <= 0;
}
/* Return true if sequence number x is after y, judged by the sign of
 * their difference taken as a long
 */
int
seq_gt(x,y)
register int32 x,y;
{
	long diff;

	diff = (long)(x-y);
	return diff > 0;
}
/* Return true if sequence number x is after or equal to y, judged by
 * the sign of their difference taken as a long
 */
int
seq_ge(x,y)
register int32 x,y;
{
	long diff;

	diff = (long)(x-y);
	return diff >= 0;
}

/* Hash a connect structure into the hash chain header array */
static int16
hash_tcb(conn)
struct connection *conn;
{
	register int16 hval;

	/* Compute hash function on connection structure */
	hval = hiword(conn->remote.address);
	hval ^= loword(conn->remote.address);
	hval ^= hiword(conn->local.address);
	hval ^= loword(conn->local.address);
	hval ^= conn->remote.port;
	hval ^= conn->local.port;
	hval %= NTCB;
	return hval;
}
/* Insert TCB at the head of the hash chain for its connection.
 * The chain is only touched between disable() and restore()
 * (presumably these mask interrupts -- confirm against machdep.h).
 */
void
link_tcb(tcb)
register struct tcb *tcb;
{
	register struct tcb **bucket;
	register struct tcb *oldfirst;
	int16 hash_tcb();
	char saved;

	tcb->prev = NULLTCB;		/* New head has no predecessor */
	saved = disable();
	bucket = &tcbs[hash_tcb(&tcb->conn)];
	oldfirst = *bucket;
	tcb->next = oldfirst;
	if(oldfirst != NULLTCB)
		oldfirst->prev = tcb;	/* Former head now follows us */
	*bucket = tcb;
	restore(saved);
}
/* Remove TCB from whatever hash chain it may be on.
 * The chain header is found by rehashing the TCB's own connection; the
 * neighbours on either side are then joined to each other. The TCB's own
 * next and prev fields are left as they were, and the TCB is not freed.
 */
void
unlink_tcb(tcb)
register struct tcb *tcb;
{
	register struct tcb **bucket;
	register struct tcb *before,*after;
	int16 hash_tcb();
	char saved;

	saved = disable();
	bucket = &tcbs[hash_tcb(&tcb->conn)];
	before = tcb->prev;
	after = tcb->next;
	if(*bucket == tcb){
		/* We're the first one on the chain */
		*bucket = after;
	}
	if(before != NULLTCB){
		before->next = after;
	}
	if(after != NULLTCB){
		after->prev = before;
	}
	restore(saved);
}
/* Move a TCB to a new state and make the user upcalls.
 * The state change upcall (s_upcall), if set, is invoked with the old
 * and new states. Afterwards, if the new state is ESTABLISHED and a
 * transmit upcall (t_upcall) is set, it is invoked with
 * tcb->window - tcb->sndcnt.
 */
void
setstate(tcb,newstate)
register struct tcb *tcb;
register char newstate;
{
	register char prevstate;

	prevstate = tcb->state;
	tcb->state = newstate;

	/* Report the state change, if anybody is listening */
	if(tcb->s_upcall != 0)
		(*tcb->s_upcall)(tcb,prevstate,newstate);

	/* Notify the user that he can begin sending data */
	if(newstate == ESTABLISHED && tcb->t_upcall != 0)
		(*tcb->t_upcall)(tcb,tcb->window - tcb->sndcnt);
}
#ifdef	TRACE
/* TCP connection state names.
 * Indexed directly by tcb->state when printing (see tcpstat() and
 * state_tcp()), so the order here must follow the numeric values of the
 * state constants -- presumably CLOSED is 0 and TIME_WAIT is 10; confirm
 * against tcp.h.
 */
char *tcpstates[] = {
	"Closed",
	"Listen",
	"SYN sent",
	"SYN received",
	"Established",
	"FIN wait 1",
	"FIN wait 2",
	"Close wait",
	"Closing",
	"Last ACK",
	"Time wait"
};
/* TCP segment header flag names.
 * Entry i names the flag bit (1 << i); tcp_dump() tests each of the six
 * bits in turn and prints the matching entry.
 */
char *tcpflags[] = {
	"FIN",	/* 0x01 */
	"SYN",	/* 0x02 */
	"RST",	/* 0x04 */
	"PSH",	/* 0x08 */
	"ACK",	/* 0x10 */
	"URG"	/* 0x20 */
};

/* TCP closing reason names.
 * Not used by the code visible in this file; presumably indexed by the
 * reason code that close_self() stores in tcb->reason -- confirm against
 * the callers and the reason constants.
 */
char *reasons[] = {
	"Normal",
	"Reset",
	"Timeout",
	"ICMP"
};
/* Return 1 if arg is a valid TCB, 0 otherwise */
int
tcpval(tcb)
struct tcb *tcb;
{
	register int i;
	register struct tcb *tcb1;

	for(i=0;i<NTCB;i++){
		for(tcb1=tcbs[i];tcb1 != NULLTCB;tcb1 = tcb1->next){
			if(tcb1 == tcb)
				return 1;
		}
	}
	return 0;
}

/* Dump TCP stats and a summary of all TCBs. The summary looks like:
 *
 * &TCB Rcv-Q Snd-Q  Local socket           Remote socket          State
 * 1234     0     0  xxx.xxx.xxx.xxx:xxxxx  xxx.xxx.xxx.xxx:xxxxx  Established
 *
 * One line is printed per TCB found on the hash chains. Always returns 0.
 */
int
tcpstat()
{
	register int slot;
	register struct tcb *tp;
	char *psocket();

	/* Global counters first */
	printf("conout %u conin %u reset out %u runt %u chksum err %u bdcsts %u\r\n",
		tcp_stat.conout,tcp_stat.conin,tcp_stat.resets,tcp_stat.runt,
		tcp_stat.checksum,tcp_stat.bdcsts);

	/* Column headings; the AMIGA build prints a wider address column */
#ifdef	AMIGA
	printf("&TCB   Rcv-Q Snd-Q  Local socket           Remote socket          State\r\n");
#else
	printf("&TCB Rcv-Q Snd-Q  Local socket           Remote socket          State\r\n");
#endif

	for(slot = 0;slot < NTCB;slot++){
		tp = tcbs[slot];
		while(tp != NULLTCB){
#ifdef	AMIGA
			printf("%6lx%6u%6u  ",(unsigned long)tp,
						tp->rcvcnt,tp->sndcnt);
#else
			printf("%4x%6u%6u  ",(int)tp,tp->rcvcnt,tp->sndcnt);
#endif
			/* Each psocket() result is printed before the next
			 * call is made. NOTE(review): presumably psocket()
			 * returns a shared buffer -- confirm before merging
			 * these calls into one printf.
			 */
			printf("%-23s",psocket(&tp->conn.local));
			printf("%-23s",psocket(&tp->conn.remote));
			printf("%-s\r\n",tcpstates[tp->state]);
			tp = tp->next;
		}
	}
	fflush(stdout);
	return 0;
}
/* Dump a TCP control block.
 * Prints the local and remote sockets, the state, the send and receive
 * sequence variables, window, MSS, queue counts and byte totals, then the
 * reassembly queue (if any), the retry count, the retransmission timer
 * and the smoothed round trip time. Does nothing if tcb is NULLTCB.
 *
 * The byte totals are the distance the sequence numbers have advanced
 * since the initial sequence numbers, less one for each SYN and FIN that
 * the current state implies has been counted.
 */
void
state_tcp(tcb)
struct tcb *tcb;
{
	int32 sent,recvd;
	/* Declared here because the result is handed to printf's %s; without
	 * a declaration in scope the call would be taken to return int.
	 */
	char *psocket();

	if(tcb == NULLTCB)
		return;
	/* Compute total data sent and received; take out SYN and FIN */
	sent = tcb->snd.una - tcb->iss;	/* Acknowledged data only */
	recvd = tcb->rcv.nxt - tcb->irs;
	switch(tcb->state){
	case LISTEN:
	case SYN_SENT:		/* Nothing received or acked yet */
		sent = recvd = 0;	
		break;
	case SYN_RECEIVED:
		recvd--;	/* Got SYN, no data acked yet */
		sent = 0;
		break;
	case ESTABLISHED:	/* Got and sent SYN */
	case FINWAIT1:		/* FIN not acked yet */
		sent--;
		recvd--;
		break;
	case FINWAIT2:		/* Our SYN and FIN both acked */
		sent -= 2;
		recvd--;
		break;
	case CLOSE_WAIT:	/* Got SYN and FIN, our FIN not yet acked */
	case CLOSING:
	case LAST_ACK:
		sent--;
		recvd -= 2;
		break;
	case TIME_WAIT:		/* Sent and received SYN/FIN, all acked */
		sent -= 2;
		recvd -= 2;
		break;
	default:		/* Any other state: totals left unadjusted */
		break;
	}
	/* Each psocket() result is printed before the next call is made */
	printf("Local: %s",psocket(&tcb->conn.local));
	printf(" Remote: %s",psocket(&tcb->conn.remote));
	printf(" State: %s\r\n",tcpstates[tcb->state]);
	printf("      Init seq    Unack     Next      WL1      WL2  Wind   MSS Queue      Total\r\n");
	printf("Send:");
	printf("%9lx",tcb->iss);
	printf("%9lx",tcb->snd.una);
	printf("%9lx",tcb->snd.nxt);
	printf("%9lx",tcb->snd.wl1);
	printf("%9lx",tcb->snd.wl2);
	printf("%6u",tcb->snd.wnd);
	printf("%6u",tcb->mss);
	printf("%6u",tcb->sndcnt);
	printf("%11lu\r\n",sent);

	/* The receive line leaves blanks under the columns that have no
	 * receive-side counterpart (Unack, WL1, WL2, MSS)
	 */
	printf("Recv:");
	printf("%9lx",tcb->irs);
	printf("         ");
	printf("%9lx",tcb->rcv.nxt);
	printf("         ");
	printf("         ");
	printf("%6u",tcb->rcv.wnd);
	printf("      ");
	printf("%6u",tcb->rcvcnt);
	printf("%11lu\r\n",recvd);

	if(tcb->reseq != (struct reseq *)NULL){
		register struct reseq *rp;

		printf("Reassembly queue:\r\n");
		for(rp = tcb->reseq;rp != (struct reseq *)NULL; rp = rp->next){
			printf("  seq x%lx %u bytes\r\n",rp->seg.seq,rp->length);
		}
	}
	printf("Retry %u",tcb->retry);
	switch(tcb->timer.state){
	case TIMER_STOP:
		printf(" Timer stopped");
		break;
	case TIMER_RUN:
		/* Timer fields are in ticks; scale by MSPTICK for milliseconds */
		printf(" Timer running (%ld/%ld mS)",
		 (long)MSPTICK * (tcb->timer.start - tcb->timer.count),
		 (long)MSPTICK * tcb->timer.start);
		break;
	case TIMER_EXPIRE:
		printf(" Timer expired");
		break;
	}
	printf(" Smoothed round trip time %ld mS\r\n",tcb->srtt);
	fflush(stdout);
}

/* Dump a TCP segment header. Assumed to be in network byte order.
 *
 * bp      segment to dump; nothing is printed if it is NULLBUF
 * source  IP source address, used only for the checksum pseudo-header
 * dest    IP destination address, likewise
 * check   0 if the checksum test is to be bypassed
 *
 * Prints the ports, sequence number, the acknowledgment number when ACK
 * is set, the names of the flags that are set, the window, the urgent
 * pointer when URG is set and an MSS option if one directly follows the
 * fixed header. The caller's buffer is not freed; only a temporary copy
 * made here is released.
 */
void
tcp_dump(bp,source,dest,check)
struct mbuf *bp;
int32 source,dest;	/* IP source and dest addresses */
int check;		/* 0 if checksum test is to be bypassed */
{
	int hdr_len,i;
	register struct tcp_header *tcph;
	struct pseudo_header ph;
	char tmpbuf;

	if(bp == NULLBUF)
		return;
	/* If packet isn't in a single buffer, make a temporary copy and
	 * note the fact so we free it later
	 */
	if(bp->next != NULLBUF){
		bp = copy_p(bp,len_mbuf(bp));
		/* Don't dereference a failed copy (copy_p() presumably
		 * returns NULLBUF when it cannot make one)
		 */
		if(bp == NULLBUF)
			return;
		tmpbuf = 1;
	} else
		tmpbuf = 0;
		
	tcph = (struct tcp_header *)bp->data;
	/* The data offset nibble counts 32-bit words */
	hdr_len = hinibble(tcph->offset) * sizeof(int32);
	/* Exactly three values for the three conversions; the ack number
	 * is printed separately below, and only when ACK is set
	 */
	printf("TCP: %u->%u Seq x%lx",
		ntohs(tcph->source),ntohs(tcph->dest),
		ntohl(tcph->seq));
	
	if(tcph->flags & ACK)
		printf(" Ack x%lx",ntohl(tcph->ack));
	for(i=0;i<6;i++){
		if(tcph->flags & 1 << i){
			printf(" %s",tcpflags[i]);
		}
	}
	printf(" Wnd %u",ntohs(tcph->wnd));
	if(tcph->flags & URG)
		printf(" UP x%x",ntohs(tcph->up));

	/* Only an MSS option sitting directly after the fixed header is
	 * recognized
	 */
	if(hdr_len > sizeof(struct tcp_header)){
		struct mss *mssp;
	
		mssp = (struct mss *)(tcph + 1);
		if(mssp->kind == MSS_KIND && mssp->length == MSS_LENGTH){
			printf(" MSS %u",ntohs(mssp->mss));
		}
	}
	/* Verify checksum */
	if(check){
		ph.source = source;
		ph.dest = dest;
		ph.protocol = TCP_PTCL;
		ph.length = len_mbuf(bp);
		ph.zero = 0;
		if((i = cksum(&ph,bp,ph.length)) != 0)
			printf(" CHECKSUM ERROR (%u)",i);
	}
	printf("\r\n");
	if(tmpbuf)
		free_p(bp);
}
#endif
