/* Send and receive User Datagram Protocol packets */
#include "machdep.h"
#include "mbuf.h"
#include "netuser.h"
#include "udp.h"
#include "internet.h"

/* Hash table of UDP control blocks. Buckets are selected by hash_udp()
 * of the local socket; each bucket is a list linked through the
 * control blocks' next/prev pointers.
 */
struct udp_cb *udps[NUDP];	/* Hash table for UDP structures */
/* Module counters: datagrams sent and received, broadcasts received,
 * checksum failures, and datagrams for sockets with no control block.
 * Printed by doudpstat() when TRACE is defined.
 */
struct udp_stat udp_stat;	/* Statistics */

/* Create a UDP control block for lsocket, so that we can queue
 * incoming datagrams.
 *
 * lsocket  - local socket (address and port) to receive on. Its fields
 *            are copied, so the caller's structure need not persist.
 * r_upcall - optional function that udp_input() calls, with the local
 *            socket and the receive queue count, each time a datagram
 *            is queued. May be NULL.
 *
 * Returns 0 on success. It also returns 0 if a control block for
 * lsocket already exists; that block and its upcall are left unchanged.
 * Returns -1 with net_error = NO_SPACE if memory is exhausted.
 */
int
open_udp(lsocket,r_upcall)
struct socket *lsocket;
void (*r_upcall)();
{
	char *malloc();
	register struct udp_cb *up;
	struct udp_cb *lookup_udp();
	int16 hval,hash_udp();

	if(lookup_udp(lsocket) != NULLUDP)
		return 0;	/* Already exists */
	if((up = (struct udp_cb *)malloc(sizeof (struct udp_cb))) == NULLUDP){
		net_error = NO_SPACE;
		return -1;
	}
	up->rcvq = NULLBUF;
	up->rcvcnt = 0;
	up->socket.address = lsocket->address;
	up->socket.port = lsocket->port;
	up->r_upcall = r_upcall;

	/* Push onto the front of the hash chain. The bucket may be empty.
	 * In that case there is no old head whose back pointer needs fixing,
	 * and dereferencing it would be a null-pointer access.
	 */
	hval = hash_udp(lsocket);
	up->prev = NULLUDP;
	up->next = udps[hval];
	if(up->next != NULLUDP)
		up->next->prev = up;
	udps[hval] = up;
	return 0;
}

/* Send a UDP datagram.
 *
 * Builds a UDP header for the local and foreign sockets and chains the
 * caller's data (bp) behind it. It then computes the checksum over the
 * IP pseudo-header, UDP header and data, and hands the result to
 * ip_send().
 *
 * If length is 0 and bp is non-null, the data length is taken from
 * len_mbuf(bp).
 *
 * Returns the UDP length (header plus data) on success. Returns -1 with
 * net_error = NO_SPACE if the header buffer cannot be allocated.
 * NOTE(review): on that failure path bp is neither sent nor freed here.
 * Confirm that callers free it.
 */
int
send_udp(lsocket,fsocket,tos,ttl,bp,length,id,df)
struct socket *lsocket;		/* Source socket */
struct socket *fsocket;		/* Destination socket */
char tos;					/* Type-of-service for IP */
char ttl;					/* Time-to-live for IP */
struct mbuf *bp;			/* Data field, if any */
int16 length;				/* Length of data field */
int16 id;					/* Optional ID field for IP */
char df;					/* Don't Fragment flag for IP */
{
	struct mbuf *hbp;
	int16 hdr_len;
	struct pseudo_header ph;
	struct udp_header *uhdr;

	if(length == 0 && bp != NULLBUF)
		length = len_mbuf(bp);
	hdr_len = sizeof(struct udp_header);
	length += hdr_len;

	/* Allocate UDP protocol header and fill it in */
	if((hbp = alloc_mbuf(hdr_len)) == NULLBUF){
		net_error = NO_SPACE;
		return -1;
	}
	hbp->cnt = hdr_len;

	uhdr = (struct udp_header *)hbp->data;
	uhdr->source = htons(lsocket->port);
	uhdr->dest = htons(fsocket->port);
	uhdr->length = htons(length);
	/* Cleared so the field contributes nothing to the sum computed below */
	uhdr->checksum = 0;

	/* Link in the user data */
	hbp->next = bp;

	/* Create IP pseudo-header, compute checksum and send it */
	ph.length = length;
	ph.source = lsocket->address;
	ph.dest = fsocket->address;
	ph.protocol = UDP_PTCL;
	ph.zero = 0;
	/* All zeros and all ones are equivalent in one's complement arithmetic.
	 * The spec requires us to change zeros into ones to distinguish an
	 * all-zero checksum from no checksum at all. The field is 16 bits,
	 * so "all ones" is 0xffff.
	 */
	if((uhdr->checksum = cksum(&ph,hbp,length)) == 0)
		uhdr->checksum = 0xffff;

	udp_stat.sent++;
	ip_send(lsocket->address,fsocket->address,UDP_PTCL,tos,ttl,hbp,length,id,df);
	return length;
}

/* Accept a waiting datagram, if available.
 *
 * lsocket - local socket whose receive queue is read
 * fsocket - if non-null, receives the sender's address and port
 * bp      - if non-null, receives the datagram data; this function does
 *           not free it in that case. If null, the data is discarded.
 *
 * Returns the length of the datagram data. Returns -1 with net_error
 * set to NO_CONN if there is no control block for lsocket, or WOULDBLK
 * if nothing is queued.
 */
int
recv_udp(lsocket,fsocket,bp)
struct socket *lsocket;		/* Local socket to receive on */
struct socket *fsocket;		/* Place to stash incoming socket */
struct mbuf **bp;			/* Place to stash data packet */
{
	struct udp_cb *lookup_udp();
	register struct udp_cb *ucb;
	struct socket *from;
	struct mbuf *pkt;
	int16 len;

	if((ucb = lookup_udp(lsocket)) == NULLUDP){
		net_error = NO_CONN;
		return -1;
	}
	if(ucb->rcvcnt == 0){
		net_error = WOULDBLK;
		return -1;
	}

	/* Each queued entry is the sender's struct socket followed by the data */
	pkt = dequeue(&ucb->rcvq);
	ucb->rcvcnt--;

	from = (struct socket *)pkt->data;
	if(fsocket != NULLSOCK){
		fsocket->address = from->address;
		fsocket->port = from->port;
	}

	/* Drop the sender header so only the payload remains */
	pullup(&pkt,NULLCHAR,sizeof(struct socket));
	len = len_mbuf(pkt);

	if(bp == (struct mbuf **)NULL)
		free_p(pkt);
	else
		*bp = pkt;
	return len;
}
/* Delete the UDP control block for lsocket, discarding any datagrams
 * still queued on it.
 *
 * Returns 0 on success, or -1 with net_error = INVALID if no control
 * block exists for lsocket.
 */
int
del_udp(lsocket)
struct socket *lsocket;
{
	register struct udp_cb *up;
	struct udp_cb *lookup_udp();
	int16 hash_udp();
	struct mbuf *bp;
	int16 hval;

	if((up = lookup_udp(lsocket)) == NULLUDP){
		net_error = INVALID;
		return -1;
	}
	/* Get rid of any pending packets */
	while(up->rcvcnt != 0){
		bp = up->rcvq;
		up->rcvq = up->rcvq->anext;
		free_p(bp);
		up->rcvcnt--;
	}
	/* Unlink from the hash chain. The block may be the last one in its
	 * chain (or the only one in the bucket), so up->next can be NULLUDP
	 * and must not be dereferenced in that case.
	 */
	hval = hash_udp(&up->socket);
	if(udps[hval] == up){
		/* First on list */
		udps[hval] = up->next;
		if(up->next != NULLUDP)
			up->next->prev = NULLUDP;
	} else {
		up->prev->next = up->next;
		if(up->next != NULLUDP)
			up->next->prev = up->prev;
	}
	free((char *)up);
	return 0;
}
/* Process an incoming UDP datagram.
 *
 * bp          - the UDP datagram (header plus data); always consumed,
 *               either queued or freed
 * protocol    - IP protocol number, used in the checksum pseudo-header
 * source,dest - IP addresses from the IP layer
 * tos         - IP type of service (not used here)
 * length      - length of the UDP datagram, header included
 * rxbroadcast - nonzero if the datagram arrived as a broadcast
 *
 * The datagram is dropped if it is too short, fails the checksum, is
 * for a socket with no control block, or cannot be queued for lack of
 * memory. Otherwise the sender's socket is prepended to the data and
 * the result is queued on the control block. The block's r_upcall, if
 * set, is then called.
 */
void
udp_input(bp,protocol,source,dest,tos,length,rxbroadcast)
struct mbuf *bp;
char protocol;
int32 source;		/* Source IP address */
int32 dest;		/* Dest IP address */
char tos;
int16 length;
char rxbroadcast;	/* The only protocol that accepts 'em */
{
	struct pseudo_header ph;
	struct udp_header udp;
	struct udp_cb *up,*lookup_udp();
	struct socket lsocket;
	struct socket *fsocket;
	struct mbuf *sp;
	int ckfail = 0;

	if(bp == NULLBUF)
		return;

	udp_stat.rcvd++;

	/* A datagram too short to hold a UDP header is malformed. Drop it
	 * rather than extract an incomplete header into 'udp' below.
	 */
	if(length < sizeof(struct udp_header)){
		free_p(bp);
		return;
	}

	/* Create pseudo-header and verify checksum over the whole datagram */
	ph.source = source;
	ph.dest = dest;
	ph.protocol = protocol;
	ph.length = length;
	ph.zero = 0;
	if(cksum(&ph,bp,length) != 0){
		/* Checksum apparently failed. Act on it only after seeing
		 * whether the sender supplied a checksum at all.
		 */
		ckfail++;
	}

	/* Extract and strip the UDP header. Its fields are still in network
	 * byte order until the ports are converted below.
	 */
	pullup(&bp,(char *)&udp,sizeof(struct udp_header));

	/* If the checksum field is zero, the sender did not compute one, so
	 * ignore a checksum error. I think this is dangerously wrong, but it
	 * is in the spec.
	 */
	if(ckfail && udp.checksum != 0){
		udp_stat.cksum++;
		free_p(bp);
		return;
	}
	udp.dest = ntohs(udp.dest);
	udp.source = ntohs(udp.source);

	/* If this was a broadcast packet, pretend it was sent to us */
	if(rxbroadcast){
		lsocket.address = ip_addr;
		udp_stat.bdcsts++;
	} else
		lsocket.address = dest;

	lsocket.port = udp.dest;
	/* See if there's somebody around to read it */
	if((up = lookup_udp(&lsocket)) == NULLUDP){
		/* Nope, toss it on the floor */
		udp_stat.unknown++;
		free_p(bp);
		return;
	}
	/* Create a buffer which will contain the foreign socket info */
	if((sp = alloc_mbuf(sizeof(struct socket))) == NULLBUF){
		/* No space, drop whole packet */
		free_p(bp);
		return;
	}
	sp->cnt = sizeof(struct socket);

	fsocket = (struct socket *)sp->data;
	fsocket->address = source;
	fsocket->port = udp.source;

	/* Chain the foreign socket info in front of the data (the UDP header
	 * was already stripped by pullup above) and queue it.
	 */
	sp->next = bp;
	enqueue(&up->rcvq,sp);
	up->rcvcnt++;
	if(up->r_upcall)
		(*up->r_upcall)(&lsocket,up->rcvcnt);
}
/* Look up UDP socket, return control block pointer or NULLUDP if nonexistant */
static
struct udp_cb *
lookup_udp(socket)
struct socket *socket;
{
	register struct udp_cb *up;
	int16 hash_udp();

	up = udps[hash_udp(socket)];
	while(up != NULLUDP){
		if(bcmp((char *)socket,(char *)&up->socket,sizeof(struct socket)) == 0)
			break;
		up = up->next;
	}
	return up;
}

/* Map a UDP socket (address and port) to a bucket index in udps[].
 * The 32-bit address is folded to 16 bits by XORing its high and low
 * words. The port is mixed in, and the result is reduced modulo NUDP.
 */
static
int16
hash_udp(socket)
struct socket *socket;
{
	int16 folded;

	/* Fold the 32-bit address down to 16 bits */
	folded = hiword(socket->address);
	folded ^= loword(socket->address);

	/* Mix in the port and reduce to a table index */
	folded ^= socket->port;
	return folded % NUDP;
}
#ifdef	TRACE
/* Dump UDP statistics and control blocks.
 * Prints the udp_stat counters, then one line per control block in the
 * hash table: the block's address, its receive queue count, and its
 * local socket. Always returns 0 (previously no value was returned, so
 * a caller using the result got an undefined value).
 */
int
doudpstat()
{
	extern struct udp_stat udp_stat;
	char *psocket();
	register struct udp_cb *udp;
	register int i;

	printf("sent %u rcvd %u bdcsts %u cksum err %u unknown socket %u\r\n",
	udp_stat.sent,udp_stat.rcvd,udp_stat.bdcsts,udp_stat.cksum,udp_stat.unknown);
#ifdef	AMIGA
	printf("&UCB   Rcv-Q  Local socket\r\n");
#else
	printf("&UCB Rcv-Q  Local socket\r\n");
#endif
	for(i=0;i<NUDP;i++){
		for(udp = udps[i];udp != NULLUDP; udp = udp->next){
			/* Print the block address as unsigned long rather than int,
			 * so it is not truncated where pointers are wider than int.
			 */
#ifdef	AMIGA
			printf("%6lx%6u  %s\r\n",(unsigned long)udp,udp->rcvcnt,
#else
			printf("%lx%6u  %s\r\n",(unsigned long)udp,udp->rcvcnt,
#endif
			 psocket(&udp->socket));
		}
	}
	return 0;
}
/* Dump a UDP header */
void
udp_dump(bp,source,dest,check)
struct mbuf *bp;
int32 source,dest;
int check;		/* If 0, bypass checksum verify */
{
	register struct udp_header *udph;
	struct pseudo_header ph;
	int i;
	char tmpbuf;

	if(bp == NULLBUF)
		return;
	/* If packet isn't in a single buffer, make a temporary copy and
	 * note the fact so we free it later
	 */
	if(bp->next != NULLBUF){
		bp = copy_p(bp,len_mbuf(bp));
		tmpbuf = 1;
	} else
		tmpbuf = 0;

	udph = (struct udp_header *)bp->data;
	printf("UDP:");
	printf(" %u->%u",ntohs(udph->source),
		ntohs(udph->dest));
	printf(" len %u",ntohs(udph->length));
	
	if(check){
		/* Verify checksum */
		ph.source = source;
		ph.dest = dest;
		ph.zero = 0;
		ph.protocol = UDP_PTCL;
		ph.length = len_mbuf(bp);
		if(udph->checksum != 0 && (i = cksum(&ph,bp,ph.length)) != 0)
			printf(" CHECKSUM ERROR (%u)",i);
	}
	printf("\r\n");
	if(tmpbuf)
		free_p(bp);
}

#endif
