/*
**	msql_proc.c	- build per-query structures for the parser and dispatch them
**
**
** Copyright (c) 1993-95  David J. Hughes
** Copyright (c) 1995-96  Hughes Technologies Pty Ltd
**
*/


#include <stdio.h>
#include <sys/types.h>

#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <signal.h>
#include <netdb.h>
#include <unistd.h>
#include <stdlib.h>
#include <sys/time.h>
#include <string.h>

#include <common/debug.h>
#include <common/portability.h>



#define MSQL_ADT
#define _MSQL_SERVER_SOURCE
#include "msql_priv.h"
#include "msql.h"
#include "y.tab.h"

/*
** Parser state shared with the grammar and the server back end.
**   command         - statement being built (SELECT, INSERT, ...);
**                     dispatched on by msqlProcessQuery, cleared by msqlClean
**   notnullflag,
**   keyflag         - not referenced in this chunk (presumably set by the
**                     grammar -- verify)
**   outSock         - client socket error/reply packets are written to;
**                     set by msqlParseQuery
**   msqlSelectLimit - row limit recorded by msqlSetSelectLimit
*/
int     command,
	notnullflag,
	keyflag,
	outSock,
	msqlSelectLimit;

/*
** curDB    - current database name (msqlSetDB); queries are refused while NULL
** arrayLen - not referenced in this chunk
*/
char 	*curDB,
	*arrayLen;

/*
** Heads of the per-query lists built by the msqlAdd*() routines and
** released by msqlClean().  condStack holds condition lists saved by
** msqlPushCond() and restored by msqlPopCond().  lastField is the cursor
** msqlAddFieldValue() uses to pair values with fields.
*/
cond_t	*condHead = NULL;
cstk_t 	*condStack = NULL;
field_t	*fieldHead = NULL,
	*lastField = NULL;
order_t	*orderHead = NULL;
tname_t	*tableHead = NULL;
/* Index spec filled by msqlAddIndex(); zeroed before each parse */
mindex_t	indexHead;
/* Timestamp taken when a query is dispatched */
time_t	queryTime;
/* Sequence spec recorded by msqlAddSequence() */
char	seqTable[NAME_LEN + 1];
seq_t	sequence;

/* Append points (tails) for the lists above */
static	cond_t	*condTail = NULL;
static	field_t	*fieldTail = NULL;
static	order_t	*orderTail = NULL;
static	tname_t	*tableTail = NULL;
/* Reset by msqlClean(); not otherwise used in this chunk */
static 	int	havePriKey = 0;


/* Shorthand for the register storage class */
#define REG		register


/* Outgoing reply buffer: formatted with sprintf() then sent by writePkt() */
extern	char	*packet;


/*
** Pool of released val_t containers.  msqlFreeValue() parks up to
** MAX_SPARE_VALUES of them here and the value constructors reuse them
** to reduce the amount of mallocing we do.
*/

#define	MAX_SPARE_VALUES	15
static 	val_t *spareValues[MAX_SPARE_VALUES];
static	int numSpareValues = 0;



/****************************************************************************
** 	_msqlClean - clean out the internal structures
**
**	Purpose	: Free all space and reset structures after a query
**	Args	: None
**	Returns	: Nothing
**	Notes	: Updates all public data structures
*/

void msqlClean()
{
	register field_t	*curField, *tmpField;
	register cond_t		*curCond, *tmpCond;
	register order_t 	*curOrder, *tmpOrder;
	register tname_t	*curTable, *tmpTable;

	msqlTrace(TRACE_IN,"msqlClean()");
	command = 0;
	havePriKey = 0;

	/*
	** blow away the table list from the query
	*/
	curTable = tableHead;
	while(curTable)
	{
		tmpTable = curTable;
		curTable = curTable->next;
		(void)free(tmpTable);
	}


	/*
	** blow away the field list from the query
	*/
	curField = fieldHead;
	while(curField)
	{
		msqlFreeValue(curField->value);
		tmpField = curField;
		curField = curField->next;
		(void)free(tmpField);
	}

	/*
	** Blow away the condition list from the query
	*/
	curCond = condHead;
	while(curCond)
	{
		msqlFreeValue(curCond->value);
		curCond->op = curCond->bool = 0;
		tmpCond = curCond;
		curCond = curCond->next;
		(void)free(tmpCond);
	}
	curOrder = orderHead;
	while(curOrder)
	{
		curOrder->dir = 0;
		tmpOrder = curOrder;
		curOrder = curOrder->next;
		(void)free(tmpOrder);
	}


	/*
	** Reset the list pointers
	*/

	condHead = condTail = (cond_t *) NULL;
	fieldHead = fieldTail = lastField = (field_t *) NULL;
	orderHead = orderTail = (order_t *) NULL;
	tableHead = tableTail = (tname_t *) NULL;

	msqlBackendClean();

	msqlTrace(TRACE_OUT,"msqlClean()");
}



/*
** msqlCreateIdent - build an identifier from up to two name segments
**
**	seg1	: first segment (may be NULL)
**	seg2	: second segment (may be NULL)
**	Returns	: the new ident_t, or NULL after sending an error packet
**		  if either segment is longer than NAME_LEN
**
** Segments that are not supplied are left as empty strings.
*/
ident_t *msqlCreateIdent(seg1,seg2)
	char	*seg1,
		*seg2;
{
	ident_t	*new;

	msqlTrace(TRACE_IN,"msqlCreateIdent()");
	if (seg1)
	{
		if ((int)strlen(seg1) > NAME_LEN)
		{
			sprintf(packet,
				"-1:Identifier name '%s' too long\n",seg1);
			writePkt(outSock);
			msqlTrace(TRACE_OUT,"msqlCreateIdent()");
			return(NULL);
		}
	}
	if (seg2)
	{
		if ((int)strlen(seg2) > NAME_LEN)
		{
			sprintf(packet,
				"-1:Identifier name '%s' too long\n",seg2);
			writePkt(outSock);
			msqlTrace(TRACE_OUT,"msqlCreateIdent()");
			return(NULL);
		}
	}
	new = (ident_t *)fastMalloc(sizeof(ident_t));

	/*
	** Start BOTH segments empty.  Callers test *seg2 to decide whether
	** the identifier is qualified, so seg2 must never be left
	** uninitialised when only seg1 is supplied.
	*/
	*(new->seg1) = *(new->seg2) = 0;
	if (seg1)
	{
		(void)strcpy(new->seg1,seg1);
	}
	if (seg2)
	{
		(void)strcpy(new->seg2,seg2);
	}
	msqlTrace(TRACE_OUT,"msqlCreateIdent()");
	return(new);
}



/*
** expandEscape - decode the character that follows a backslash
**
**	c	: points at the character after the backslash
**	remain	: accepted for the caller's convenience; not consulted
**	Returns	: the control character for n, t, r and b; otherwise the
**		  character itself (so \\ yields \ and \' yields ')
*/
static u_char expandEscape(c,remain)
	u_char	*c;
	int	remain;
{
	u_char	ch = *c;

	if (ch == 'n')
		return('\n');
	if (ch == 't')
		return('\t');
	if (ch == 'r')
		return('\r');
	if (ch == 'b')
		return('\b');
	return(ch);
}



/*
** msqlCreateValue - build a val_t from a scanner token
**
**	textRep	: token text (for IDENT_TYPE, an ident_t * passed through)
**	type	: NULL_TYPE, IDENT_TYPE, CHAR_TYPE, INT_TYPE or REAL_TYPE
**	tokLen	: length of the token text.  For CHAR_TYPE the first and
**		  last characters (presumably the quotes) are skipped.
**	Returns	: the new value, or NULL if memory could not be allocated
**
** Containers are reused from the spare pool when available.  For
** CHAR_TYPE, backslash escapes are decoded with expandEscape(), the
** result is NUL-terminated, and dataLen is set to the decoded length so
** later copies of dataLen bytes stay inside the buffer.
*/
val_t *msqlCreateValue(textRep,type,tokLen)
	u_char	*textRep;
	int	type,
		tokLen;
{
	val_t	*new;
	int	length,
		remain;
	REG 	u_char	*cp,
			*cp2;

	msqlTrace(TRACE_IN,"msqlCreateValue()");
	if (numSpareValues > 0)
	{
		numSpareValues--;
		new = spareValues[numSpareValues];
	}
	else
	{
		new = (val_t *)malloc(sizeof(val_t));
		if (!new)
		{
			msqlTrace(TRACE_OUT,"msqlCreateValue()");
			return(NULL);
		}
	}

	new->type = type;
	new->dataLen = tokLen;
	new->nullVal = 0;
	switch(type)
	{
		case NULL_TYPE:
			new->nullVal = 1;
			break;

		case IDENT_TYPE:
			new->val.identVal = (ident_t *)textRep;
			break;

		case CHAR_TYPE:
			length = tokLen - 2;
			if (length < 0)
				length = 0;
			new->val.charVal = (u_char *)malloc(length+1);
			if (!new->val.charVal)
			{
				(void)free(new);
				msqlTrace(TRACE_OUT,"msqlCreateValue()");
				return(NULL);
			}
			cp = textRep+1;
			cp2 = new->val.charVal;
			remain = length;
			/* remain > 0 (not != 0) so the loop can never run away */
			while(remain > 0)
			{
				if (*cp == '\\' && remain > 1)
				{
					/* Consume the backslash, decode what follows */
					cp++;
					remain--;
					*cp2++ = expandEscape(cp,remain);
					cp++;
					remain--;
				}
				else
				{
					/* Plain char, or a lone trailing backslash */
					*cp2++ = *cp++;
					remain--;
				}
			}
			*cp2 = 0;
			new->dataLen = (int)(cp2 - new->val.charVal);
			break;

		case INT_TYPE:
			new->val.intVal = atoi((char *)textRep);
			break;

		case REAL_TYPE:
			sscanf((char *)textRep ,"%lg",&new->val.realVal);
			break;
	}
	msqlTrace(TRACE_OUT,"msqlCreateValue()");
	return(new);
}


val_t *fillValue(val,type,length)
	char	*val;
	int	type,
		length;
{
	val_t	*new;

	msqlTrace(TRACE_IN,"fillValue()");
        if (numSpareValues > 0)
        {
                numSpareValues--;
                new = spareValues[numSpareValues];
        }
        else
        {
                new = (val_t *)malloc(sizeof(val_t));
        }

	new->type = type;
	new->nullVal = 0;
	switch(type)
	{
		case CHAR_TYPE:
			new->val.charVal = (u_char *)malloc(length+1);
			(void)bcopy(val,new->val.charVal,length);
			break;

		case INT_TYPE:
#ifndef _CRAY
			new->val.intVal = (int) * (int *)val;
#else
			new->val.intVal = unpackInt32(val);
#endif
			break;

		case REAL_TYPE:
			new->val.realVal = (double) * (double *)val;
			break;
	}
	msqlTrace(TRACE_OUT,"fillValue()");
	return(new);
}


val_t *nullValue()
{
	val_t	*new;

        if (numSpareValues > 0)
        {
                numSpareValues--;
                new = spareValues[numSpareValues];
        }
        else
        {
                new = (val_t *)malloc(sizeof(val_t));
        }

	new->nullVal = 1;
	return(new);
}



/*
** msqlFreeValue - release a value and any storage it owns
**
**	val	: value to release (NULL is ignored)
**
** IDENT values free their ident_t.  Non-null CHAR/TEXT values free their
** string.  The container itself goes back to the spare pool if there is
** room, otherwise it is freed.
*/
void msqlFreeValue(val)
	val_t	*val;
{
	msqlTrace(TRACE_IN,"msqlFreeValue()");
	if (val != NULL)
	{
		if (val->type == IDENT_TYPE)
		{
			(void)free(val->val.identVal);
		}
		else if ((val->type == CHAR_TYPE || val->type == TEXT_TYPE)
			&& !val->nullVal)
		{
			(void)free(val->val.charVal);
		}

		if (numSpareValues >= MAX_SPARE_VALUES)
		{
			(void)free(val);
		}
		else
		{
			spareValues[numSpareValues++] = val;
		}
	}
	msqlTrace(TRACE_OUT,"msqlFreeValue()");
}





int msqlAddSequence(table, step, val)
	char	*table;
	int	step,
		val;
{
	cache_t	*entry;

	msqlTrace(TRACE_IN,"msqlAddSequence()");

	strcpy(seqTable,table);
	sequence.step = step;
	sequence.value = val;

	msqlTrace(TRACE_OUT,"msqlAddSequence()");
}




int msqlAddSetFunct(funct, ident)
	char	*funct;
	ident_t	*ident;
{
	register field_t	*new;
	char	*name,
		*table;

	msqlTrace(TRACE_IN,"msqlAddSetField()");

	name = ident->seg2;
	table = ident->seg1;

	if (checkSetFunctName(funct) < 0)
	{
		msqlTrace(TRACE_OUT,"msqlAddSetFunct()");
		return(-1);
	}
	new = (field_t *)malloc(sizeof(field_t));
	if (table)
	{
		(void)strncpy(new->table,table,NAME_LEN - 1);
	}
	(void)strncpy(new->name,name,NAME_LEN - 1);
	(void)strncpy(new->funct,funct,NAME_LEN - 1);

	if (!fieldHead)
	{
		fieldHead = fieldTail = new;
	}
	else
	{
		fieldTail->next = new;
		fieldTail = new;
	}
	free(ident);
	msqlTrace(TRACE_OUT,"msqlAddSetField()");
	return(0);
}


/****************************************************************************
** 	_msqlAddField - add a field definition to the list
**
**	Purpose	: store field details from the query for later use
**	Args	: field name, field type, field length, value
**	Returns	: Nothing
**	Notes	: Depending on the query in process, only some of the
**		  args will be supplied.  eg. a SELECT doesn't use the
**		  type arg.  The length arg is only used during a create
**		  if the field is of type CHAR
*/

int msqlAddField(ident,type,length,notNull,priKey)
	ident_t	*ident;
	int 	type;
	char	*length;
	int	notNull,
		priKey;
{
	register field_t	*new;
	char	*name,
		*table;

	msqlTrace(TRACE_IN,"msqlAddField()");

	name = ident->seg2;
	table = ident->seg1;


	/*
	** Look for duplicate field names on a table create
	*/
	if (type != 0)
	{
		new = fieldHead;
		while(new)
		{
			if (strcmp(new->name,name) == 0)
			{
				sprintf(packet,
					"-1:Duplicate field name '%s'\n",
					name);
				writePkt(outSock);
				msqlTrace(TRACE_OUT,"msqlAddField()");
				return(-1);
			}
			new = new->next;
		}
	}

	if (priKey)
	{
		sprintf(packet,
			"-1:Primary keys are obsolete.  Use CREATE INDEX\n");
		writePkt(outSock);
		msqlTrace(TRACE_OUT,"msqlAddField()");
		return(-1);
	}


	new = (field_t *)fastMalloc(sizeof(field_t));
	*(new->table) = *(new->name) = *(new->funct) = 0;
	new->value = NULL;
	new->entry = NULL;
	new->next = NULL;
	new->type = new->sysvar = new->length = new->dataLength =
		new->offset = new->null = new->flags = 
		new->fieldID = new->overflow = 0;


	if (table)
	{
		(void)strncpy(new->table,table,NAME_LEN - 1);
	}
	(void)strncpy(new->name,name,NAME_LEN - 1);
	if (notNull)
	{
		new->flags |= NOT_NULL_FLAG;
	}
	switch(type)
	{
		case INT_TYPE:
			new->type = INT_TYPE;
			new->dataLength = new->length = 4;
			break;

		case CHAR_TYPE:
			new->type = CHAR_TYPE;
			new->dataLength = new->length = atoi(length);
			break;

		case REAL_TYPE:
			new->type = REAL_TYPE;
			new->dataLength = new->length = sizeof(double);
			break;

		case TEXT_TYPE:
			new->type = TEXT_TYPE;
			new->length = atoi(length);
			new->dataLength = new->length + VC_HEAD_SIZE;
			break;

		default:
			new->type = 0;
			new->dataLength = new->length = 0;
			break;
	}
	new->overflow = NO_POS;
	if (!fieldHead)
	{
		fieldHead = fieldTail = new;
	}
	else
	{
		fieldTail->next = new;
		fieldTail = new;
	}
	free(ident);
	msqlTrace(TRACE_OUT,"msqlAddField()");
	return(0);
}

/*
** msqlSetSelectLimit - record the row limit carried by an integer value
*/
void msqlSetSelectLimit(value)
	val_t	*value;
{
	int	limit = value->val.intVal;

	msqlSelectLimit = limit;
}


void msqlAddFieldValue(value)
	val_t	*value;
{
	register field_t	*fieldVal;
	u_char	*buf;

	msqlTrace(TRACE_IN,"msqlAddFieldValue()");
	if (!lastField)
	{
		lastField = fieldVal = fieldHead;
	}
	else
	{	
		fieldVal = lastField->next;
		lastField = lastField->next;
	}
	if (fieldVal)
	{
		if (fieldVal->type == CHAR_TYPE)
		{
			buf = (u_char *)malloc(fieldVal->length+1);
			bcopy(value->val.charVal,buf,value->dataLen);
			free(value->val.charVal);
			value->val.charVal = buf;
		}
		fieldVal->value = value;
	}
	msqlTrace(TRACE_OUT,"msqlAddFieldValue()");
}




/*
** msqlPushCond - save the current condition list and start an empty one
**
** The saved head/tail are restored by the matching msqlPopCond().
*/
void msqlPushCond()
{
	cstk_t	*frame;

	frame = (cstk_t *)malloc(sizeof(cstk_t));
	frame->head = condHead;
	frame->tail = condTail;
	frame->next = condStack;
	condStack = frame;

	condHead = NULL;
	condTail = NULL;
}


void msqlPopCond()
{
	cstk_t	*tmp;

	tmp = condStack;
	condStack = condStack->next;
	condHead = tmp->head;
	condTail = tmp->tail;
	free(tmp);
}


void msqlAddSubCond(bool)
	int	bool;
{
	register cond_t	*new;

	msqlTrace(TRACE_IN,"msqlAddSubCond()");

	new = (cond_t *)malloc(sizeof(cond_t));
	new->op = 0;
	new->bool = bool;
	new->value = NULL;
	new->subCond = condHead;
	msqlPopCond();

	if (!condHead)
	{
		condHead = condTail = new;
	}
	else
	{
		condTail->next = new;
		condTail = new;
	}
	msqlTrace(TRACE_OUT,"msqlAddSubCond()");
}


/****************************************************************************
** 	_msqlAddCond  -  add a conditional spec to the list
**
**	Purpose	: Store part of a "where_clause" for later use
**	Args	: field name, test op, value, bool (ie. AND | OR)
**	Returns	: Nothing
**	Notes	: the BOOL field is only provided if this is not the first
**		  element of a where_clause.
*/

void msqlAddCond(ident,op,value,bool)
	ident_t	*ident;
	int	op;
	val_t	*value;
	int	bool;
{
	register cond_t	*new;
	char	*name,
		*table;

	msqlTrace(TRACE_IN,"msqlAddCond()");

	if (*(ident->seg2))
	{
		name = ident->seg2;
		table = ident->seg1;
	}
	else
	{
		name = ident->seg1;
		table = NULL;
	}

	new = (cond_t *)malloc(sizeof(cond_t));
	(void)strcpy(new->name,name);
	if (table)
	{
		(void)strcpy(new->table,table);
	}
	
	new->op = op;
	new->bool = bool;
	new->value = value;
	new->subCond = NULL;

	if (!condHead)
	{
		condHead = condTail = new;
	}
	else
	{
		condTail->next = new;
		condTail = new;
	}
	free(ident);
	msqlTrace(TRACE_OUT,"msqlAddCond()");
}



/****************************************************************************
** 	_msqlAddOrder  -  add an order definition to the list
**
**	Purpose	: Store part of an "order_clause"
**	Args	: field name, order direction (ie. ASC or DESC)
**	Returns	: Nothing
**	Notes	: 
*/

void msqlAddOrder(ident,dir)
	ident_t	*ident;
	int	dir;
{
	register order_t	*new;

	msqlTrace(TRACE_IN,"msqlAddOrder()");

	new = (order_t *)malloc(sizeof(order_t));
	if (*ident->seg1)
	{
		(void)strcpy(new->table,ident->seg1);
	}
	(void)strcpy(new->name,ident->seg2);
	new->dir = dir;
	if (!orderHead)
	{
		orderHead = orderTail = new;
	}
	else
	{
		orderTail->next = new;
		orderTail = new;
	}
	free(ident);
	msqlTrace(TRACE_OUT,"msqlAddOrder()");
}




/*
** msqlAddTable - append a table reference to the query's table list
**
**	name	: real table name
**	alias	: optional alias (NULL if none)
**
** With an alias, the alias is stored in name and the real table in
** cname.  Without one, name holds the table and cname is empty.  The
** names come from the client's query, so each copy is bounded by the
** size of its destination and always NUL-terminated.  This assumes
** name/cname are fixed-size arrays in tname_t (they are copied into
** directly after allocation).
*/
void msqlAddTable(name,alias)
	char	*name,
		*alias;
{
	register tname_t	*new;

	msqlTrace(TRACE_IN,"msqlAddTable()");

	new = (tname_t *)fastMalloc(sizeof(tname_t));
	*(new->name) = *(new->cname) = 0;
	new->done = 0;
	new->next = NULL;

	if (alias)
	{
		(void)strncpy(new->name,alias,sizeof(new->name) - 1);
		new->name[sizeof(new->name) - 1] = 0;
		(void)strncpy(new->cname,name,sizeof(new->cname) - 1);
		new->cname[sizeof(new->cname) - 1] = 0;
	}
	else
	{
		(void)strncpy(new->name,name,sizeof(new->name) - 1);
		new->name[sizeof(new->name) - 1] = 0;
		*(new->cname) = 0;
	}
	if (!tableHead)
	{
		tableHead = tableTail = new;
	}
	else
	{
		tableTail->next = new;
		tableTail = new;
	}
	msqlTrace(TRACE_OUT,"msqlAddTable()");
}



void msqlAddIndex(name, table, uniq, type)
	char	*name,
		*table;
	int	uniq,
		type;
{
	msqlTrace(TRACE_IN,"msqlAddIndex()");
	strncpy(indexHead.name,name,NAME_LEN);
	strncpy(indexHead.table,table,NAME_LEN);
	indexHead.unique = uniq;
	indexHead.idxType = type;
	msqlTrace(TRACE_OUT,"msqlAddIndex()");
}


/*
** msqlSetDB - select the database later queries run against
**
** Only the pointer is stored; the caller keeps ownership of the string.
*/
void msqlSetDB(db)
	char	*db;
{
	char	*selected = db;

	curDB = selected;
}



/****************************************************************************
** 	msqlProcessQuery  -  send the query to the appropriate routine
**
**	Purpose	: Call the required routine to build the native query
**	Args	: None
**	Returns	: Nothing
**	Notes	: Global command variable used.  This is called when the
**		  end of an individual query is found in the miniSQL code
**		  and provides the hooks into the per-database routines.
**		  A back-end failure (negative result) is reported to the
**		  client using errMsg.  An unrecognised command does
**		  nothing.
*/


/*
** checkQueryAccess - verify the client holds the requested access level
**
** Returns 1 if access is granted.  Otherwise sends "Access Denied" to
** the client and returns 0.
*/
static int checkQueryAccess(level)
	int	level;
{
	if (msqlCheckPerms(level))
		return(1);
	sprintf(packet,"-1:Access Denied\n");
	writePkt(outSock);
	return(0);
}


void msqlProcessQuery()
{
	int	res = 0;

	msqlTrace(TRACE_IN,"msqlProcessQuery()");
	if (!curDB)
	{
		sprintf(packet,"-1:No current database\n");
		writePkt(outSock);
		msqlDebug(MOD_ERR,"No current database\n");
		msqlTrace(TRACE_OUT,"msqlProcessQuery()");
		return;
	}
	queryTime = time(NULL);
	switch(command)
	{
		case SELECT: 
			if (!checkQueryAccess(READ_ACCESS))
				goto done;
			res = msqlServerSelect(tableHead,fieldHead,condHead,
				orderHead,curDB);
			break;
		case CREATE_TABLE: 
			if (!checkQueryAccess(WRITE_ACCESS))
				goto done;
			res = msqlServerCreateTable(tableHead->name,fieldHead,
				curDB);
			break;
		case CREATE_INDEX: 
			if (!checkQueryAccess(WRITE_ACCESS))
				goto done;
			res = msqlServerCreateIndex(&indexHead,fieldHead,curDB);
			break;
		case CREATE_SEQUENCE: 
			if (!checkQueryAccess(WRITE_ACCESS))
				goto done;
			res = msqlServerCreateSequence(seqTable,sequence.step,
				sequence.value,curDB);
			break;
		case UPDATE: 
			if (!checkQueryAccess(RW_ACCESS))
				goto done;
			res = msqlServerUpdate(tableHead->name,fieldHead,
				condHead, curDB);
			break;
		case INSERT: 
			if (!checkQueryAccess(WRITE_ACCESS))
				goto done;
			res = msqlServerInsert(tableHead->name,fieldHead,curDB);
			break;
		case DELETE: 
			if (!checkQueryAccess(WRITE_ACCESS))
				goto done;
			res = msqlServerDelete(tableHead->name,condHead,curDB);
			break;
		case DROP_TABLE: 
			if (!checkQueryAccess(WRITE_ACCESS))
				goto done;
			res = msqlServerDropTable(tableHead->name,curDB);
			break;
		case DROP_INDEX: 
			if (!checkQueryAccess(WRITE_ACCESS))
				goto done;
			res = msqlServerDropIndex(&indexHead,curDB);
			break;
		case DROP_SEQUENCE: 
			if (!checkQueryAccess(WRITE_ACCESS))
				goto done;
			res = msqlServerDropSequence(seqTable,curDB);
			break;
		default:
			/* Nothing to dispatch; res stays 0 */
			break;
	}
	if (res < 0)
	{
		extern	char	errMsg[];

		sprintf(packet,"-1:%s\n",errMsg);
		writePkt(outSock);
	}
done:
	msqlTrace(TRACE_OUT,"msqlProcessQuery()");
}


/*
** msqlQueryOverrunError - report text left over after a complete query
**
**	txt	: the unconsumed trailing text, echoed back to the client
*/
void msqlQueryOverrunError(txt)
	char	*txt;
{
	msqlTrace(TRACE_IN,"msqlQueryOverrunError()");

	(void)sprintf(packet,
		"-1:Syntax error.  Bad text after query. '%s'\n", txt);
	writePkt(outSock);

	msqlTrace(TRACE_OUT,"msqlQueryOverrunError()");
}




/*
** msqlParseQuery - parse (and thereby execute) one query
**
**	inBuf	: query text
**	sock	: client socket that replies and errors are written to
**
** The index spec is cleared before the grammar runs, so every query
** starts with a blank indexHead.
*/
void msqlParseQuery(inBuf,sock)
	char	*inBuf;
	int	sock;
{
	msqlTrace(TRACE_IN,"msqlParseQuery()");

	outSock = sock;
	msqlInitScanner((u_char *)inBuf);
	bzero(&indexHead, sizeof(indexHead));
	(void)yyparse();

	msqlTrace(TRACE_OUT,"msqlParseQuery()");
}

