/* 
 * This is source code to CASL (Custom Audit Scripting Language)
 *
 * Copyright 1998 Secure Networks, Inc.
 * Copyright 1999 Network Associates, Inc.
 * All Rights Reserved
 *
 * BEFORE YOU INSTALL, USE, OR MODIFY THIS SOFTWARE PRODUCT,
 * CAREFULLY READ THE TERMS AND CONDITIONS IN THE FILE
 * "LICENSE.TXT" ACCOMPANYING THIS DOCUMENT. IF THE FILE
 * "LICENSE.TXT" IS MISSING, IT MAY BE OBTAINED FROM
 * NETWORK ASSOCIATES. NETWORK ASSOCIATES IS PERMITTING
 * THE USE, DISTRIBUTION, AND LIMITED MODIFICATION OF THIS
 * SOFTWARE PRODUCT ON A NON-COMMERCIAL BASIS SUBJECT TO
 * ALL OF THE CONDITIONS IN THE FILE "LICENSE.TXT." BY INSTALLING,
 * USING, OR MODIFYING THE SOFTWARE PRODUCT, YOU AND ANY
 * SUBSEQUENT USER ARE AGREEING TO BE BOUND BY ALL OF THE
 * TERMS AND CONDITIONS IN THE FILE "LICENSE.TXT." IF YOU DO
 * NOT AGREE TO ALL OF THOSE TERMS AND CONDITIONS, DO NOT
 * INSTALL, USE, OR MODIFY THIS SOFTWARE PRODUCT.
 */

#include "casl.h"

/* Manage the symbol tables for CASL.
 *
 * The most important functions in here are st_get() and st_insert(),
 * which retrieve ASR nodes from, and insert them into, the current
 * symbol table based on a string ID.
 * 
 * There are multiple symbol tables active at any point in time during
 * the execution of a CASL script. Each subroutine call creates a new
 * local symbol table, and there's also a constantly available global
 * symbol table. Only one local symbol table (the current one) is 
 * available at any point in time, and the symbol retrieval function
 * will attempt to retrieve from the global table if a local lookup
 * fails.
 *
 * The structure management code is also in here; each structure has
 * its own "symbol table", which contains a definition of the fields
 * of the structure (length, offset, name, etc).
 *
 */

/* Global symbol table; created lazily by st_insert() / st_replace(). */
table_t *Global = NULL;
/* Table of the currently executing scope; NULL until st_push() installs one. */
table_t *Local  = NULL;
/* Builtin-function table, populated by bt_insert(). */
table_t *Builtins = NULL;
/* Saved Local tables; the head holds the caller's scope (see st_get_uplevel()). */
list_t *LocalStack = NULL;
/* Current scope level, consulted by st_replace() in place of its own `level`
 * argument. NOTE(review): it is not updated in this file — confirm which code
 * maintains it. */
int Level = 0;

/* Structure definitions keyed by structure name; created lazily by sd_new(). */
table_t *StructTab = NULL;

static asr_t *st_check(char *id, int level);

/* ---------------------------------------------------------------------------
** st_new: allocate an empty symbol table with t_new(1, NULL, NULL) and
** hand it back to the caller.
*/

table_t *st_new() {
	table_t *fresh;

	fresh = t_new(1, NULL, NULL);

#ifdef TRACE_CALLS
	Dprintf("st_new()\n");
#endif

	return fresh;
}

/* ---------------------------------------------------------------------------
** st_insert: bind `id` to `node` in the local table (level != 0) or in the
** global table (level == 0). The global table is created on first use.
**
** Returns 1, leaving the existing binding untouched, when st_get(id, 1)
** already finds `id`. Otherwise the binding is stored and 0 is returned.
**
** NOTE(review): assuming BUILTIN/UPLEVEL are not 1, st_get(id, 1) already
** falls back to Global whenever a Local table exists. The "shadows a
** global symbol" diagnostic can therefore only fire while Local is NULL.
** Also, with level != 0 and Local == NULL the t_put() target is NULL.
** Confirm both behaviours are intended.
*/

int st_insert(char *id, asr_t *node, int level) {
	table_t *target;

	if(Global == NULL)
		Global = st_new();

#ifdef TRACE_CALLS
	Dprintf("st_insert(%s, %p, %d)\n", id, node, level);
#endif

	/* Already visible from the current scope: keep the old binding. */
	if(st_get(id, 1) != NULL)
		return(1);

	if(st_get(id, 0) != NULL)
		error(E_USER,
			"initialization of \"%s\" shadows a global symbol", id);

	target = level ? Local : Global;
	t_put(target, id, node);

	return(0);
}

/* ---------------------------------------------------------------------------
** Check the symbol table for a symbol.
*/

static asr_t *st_check(char *id, int level) {
	asr_t *ap = NULL;

       	if(level == BUILTIN) {
		if(!Builtins)
			return(0);

		if((ap = t_get(Builtins, id)))
			return(ap);
		else
			return(0);
	}

	if(level == UPLEVEL) {
		if(!LocalStack) {
			if((ap = t_get(Global, id)))
				return(ap);
			else
				return(0);
		} else {
			table_t *t = LocalStack->data;

			if(!t)
				return(0);
			
			if((ap = t_get(t, id)))
				return(ap);
			else
				return(0);
		}
	}

	if(level) {
		if((ap = t_get(Local, id)))
			return(ap);
		else
			return(0);
	}

	if((ap = t_get(Global, id)))
		return(ap);
	else
		return(0);	       		
}

/* ---------------------------------------------------------------------------
** Replace a symbol in the symbol table.
**
** The table is chosen from the global `Level`, not from the `level`
** argument; `level` is only printed in the trace message.
**   1. If st_check(id, Level) finds `id`, that binding is removed and
**      `node` is re-inserted at Level.
**   2. Otherwise, if Level is non-zero and `id` is bound in Global, it is
**      removed from Global and `node` is re-inserted into Global.
**   3. Otherwise `node` is inserted at Level and 0 is returned.
** In cases 1 and 2 the previous node is released with alloc_downref() and
** 1 is returned. The global table is created on first use.
**
** NOTE(review): confirm that using `Level` instead of `level` is intended.
** st_insert()'s return value is ignored. If st_insert() declines to insert
** (it returns 1 whenever st_get(id, 1) still finds `id`, e.g. via the
** Global fallback), `node` is not stored, yet the old node is still
** released.
*/

int st_replace(char *id, asr_t *node, int level) {
	asr_t *old; 

	/* Make sure a global table exists before any lookup or insert. */
	if(!Global)
		Global = st_new();

#ifdef TRACE_CALLS
	Dprintf("st_replace(%s, %p, %d)\n", id, node, level);
#endif

	if((old = st_check(id, Level))) {
		/* Bound at the current level: rebind there. */
		st_remove(id, Level);
		st_insert(id, node, Level);
	} else if(Level && (old = st_check(id, 0))) {
		/* Not bound locally but bound globally: rebind in Global. */
		st_remove(id, 0);
		st_insert(id, node, 0);
	} else {
		/* New symbol: there is no previous node to release. */
		st_insert(id, node, Level);
		return(0);
	}

	/* Drop the reference held by the binding that was replaced. */
	alloc_downref(old);

	return(1);
}

/* ---------------------------------------------------------------------------
** st_global: 1 if `id` is bound in the global table itself, else 0.
** Local is not consulted.
*/

int st_global(char *id) {
	return(st_check(id, 0) != NULL ? 1 : 0);
}

/* ---------------------------------------------------------------------------
** bt_insert: register builtin `func` under the name `id` in the Builtins
** table, creating that table on first use.
*/

void bt_insert(char *id, asr_t *(*func)(asr_t *)) {
	if(Builtins == NULL)
		Builtins = st_new();

#ifdef TRACE_CALLS
	Dprintf("bt_insert(%s, %p)\n", id, func);
#endif

	/* The function pointer itself is the stored value. st_get() at level
	 * BUILTIN returns it as an asr_t *. */
	t_put(Builtins, id, func);
}

/* ---------------------------------------------------------------------------
** st_get: retrieve the node bound to `id`, searching according to `level`:
**   BUILTIN  -> Builtins only (NULL if no builtin was ever registered)
**   UPLEVEL  -> delegated to st_get_uplevel()
**   non-zero -> Local, then Global (NULL if there is no Local table)
**   0        -> Global
** The global table must already exist (asserted).
*/

asr_t *st_get(char *id, int level) {
	table_t *scope;
	asr_t *found;

#ifdef TRACE_CALLS
	Dprintf("st_get(%s, %d)\n", id, level);
#endif

	assert(Global);

	if(level == BUILTIN) {
		if(Builtins == NULL)
			return(NULL);
		return(t_get(Builtins, id));
	}

	if(level == UPLEVEL)
		return(st_get_uplevel(id));

	scope = level ? Local : Global;
	if(scope == NULL)
		return(NULL);

	found = t_get(scope, id);
	if(found == NULL && scope != Global)
		found = t_get(Global, id);	/* local miss: try global scope */

	return(found);
}

/* ---------------------------------------------------------------------------
** st_get_uplevel: used solely during argument passing. Look `id` up in the
** table on top of LocalStack (the calling function's scope), falling back
** to the global table. The global table must exist (asserted).
*/

asr_t *st_get_uplevel(char *id) {
	table_t *caller = LocalStack ? (table_t *) LocalStack->data : NULL;
	asr_t *found = NULL;

#ifdef TRACE_CALLS
	Dprintf("st_get_uplevel(%s)\n", id);
#endif

	assert(Global);

	if(caller != NULL)
		found = t_get(caller, id);

	if(found == NULL)
		found = t_get(Global, id);

	return(found);
}

/* ---------------------------------------------------------------------------
** Remove a symbol from the symbol table.
*/

void st_remove(char *id, int level) {
#ifdef TRACE_CALLS
	Dprintf("st_remove(%s, %d)\n", id, level);
#endif

	if(level > 0 && !Local)
		return;

	t_remove(level ? Local : Global, id);
	return;
}

/* ---------------------------------------------------------------------------
** st_push: enter a new scope. The current Local table (possibly NULL) is
** saved on LocalStack and replaced by a freshly created, empty table.
*/

void st_push() {
	table_t *scope = st_new();

#ifdef TRACE_CALLS
	Dprintf("st_push()\n");
#endif

	/* Remember the caller's table, then switch to the new one. */
	LocalStack = l_push(LocalStack, Local);
	Local = scope;
}

/* ---------------------------------------------------------------------------
** Reference cleanup for symbol-table values.
**
** st_collect() calls alloc_downref() (via scollect) on every value in the
** current Local table. It does not free the table itself; the Local table
** must exist (asserted).
*/

static void scollect (const void *key, void **value, void *cl);

void st_collect () {
	assert (Local);
	t_map (Local, scollect, NULL);
}

/* st_globalfree: drop one reference on every global value, free the Global
** table, and reset Global to NULL. The table must exist (asserted). */
void st_globalfree () {
	assert (Global);

	t_map (Global, scollect, NULL);
	table_free (&Global);
	Global = NULL;
}

/* t_map() callback: release the reference held by one table value.
** The key and closure arguments are unused. */
static void scollect (const void *key, void **value, void *cl)
{
	(void) key;
	(void) cl;

	alloc_downref (*(asr_t **) value);
}

/* ---------------------------------------------------------------------------
** st_pop: leave the current scope. The current Local table is freed with
** table_free() and the table on top of LocalStack becomes Local again.
** When LocalStack is empty, nothing happens.
** This function does not drop references on the freed table's values;
** st_collect() is the routine that does that.
*/

void st_pop() {
	table_t *caller = NULL;

#ifdef TRACE_CALLS
	Dprintf("st_pop()\n");
#endif

	if(LocalStack == NULL)
		return;		/* no saved scope: keep the current Local */

	LocalStack = l_pop(LocalStack, (void *) &caller);

	/* Discard the finished scope's table and restore the caller's. */
	table_free(&Local);
	Local = caller;
}

/* ---------------------------------------------------------------------------
** Create a new structure definition.
*/

void sd_new(char *name) {
	st *s = (st *) NEW(sizeof(*s));
	table_t *sde = t_new(10, NULL, NULL);

#ifdef TRACE_CALLS
	Dprintf("sd_new(%s)\n", name);
#endif

	s->st_name = strdup(name);
	s->st_sds  = sde;
	
	if(!StructTab) 
		StructTab = t_new(512, NULL, NULL);

	t_put(StructTab, s->st_name, s);

	return;
}

/* ---------------------------------------------------------------------------
** sd_addfield: add field definition `s`, keyed by s->sd_name, to the
** structure called `name`.
** Returns 0 if the structure is unknown and 1 otherwise. A duplicate field
** name is reported with error(E_USER, ...) before the field is stored.
*/

int sd_addfield(char *name, sdef *s) {
	st *owner;

#ifdef TRACE_CALLS
	Dprintf("sd_addfield(%s, %p)\n", name, s);
#endif

	owner = StructTab ? t_get(StructTab, name) : NULL;
	if(owner == NULL)
		return(0);

	if(t_get(owner->st_sds, s->sd_name) != NULL)
		error(E_USER, "redefinition of field \"%s\" in struct \"%s\"",
		      s->sd_name, name);

	t_put(owner->st_sds, s->sd_name, s);
	return(1);
}

/* ---------------------------------------------------------------------------
** sd_getfield: look up field `field` of the structure called `name`.
** Returns the field definition, or NULL when either the structure or the
** field is unknown.
*/

sdef *sd_getfield(char *name, char *field) {
	st *owner;

#ifdef TRACE_CALLS
	Dprintf("sd_getfield(%s, %s)\n", name, field);
#endif

	if(StructTab == NULL || (owner = t_get(StructTab, name)) == NULL)
		return(NULL);

	return(t_get(owner->st_sds, field));
}

/* ---------------------------------------------------------------------------
** sd_get: return the structure definition registered as `name`, or NULL
** when no such structure (or no structure at all) has been defined.
*/

st *sd_get(char *name) {
#ifdef TRACE_CALLS
	Dprintf("sd_get(%s)\n", name);
#endif

	return(StructTab ? t_get(StructTab, name) : NULL);
}

/* ---------------------------------------------------------------------------
** sd_lc: t_map() callback used by sd_len(). It adds one field's
** sd_bitlength to the running int total that the closure argument `c`
** points to. The key argument `x` is unused.
*/

static void sd_lc(const void *x, void **v, void *c) {
	const sdef *field = *v;

	(void) x;
	*(int *) c += field->sd_bitlength;
}

/* ---------------------------------------------------------------------------
** sd_len: size of the structure called `name`, computed as the sum of its
** fields' sd_bitlength values (presumably bits — confirm callers expect
** that unit). Returns 0 if the structure is unknown.
*/

int sd_len(char *name) {
	int bits = 0;
	st *def;

#ifdef TRACE_CALLS
	Dprintf("sd_len(%s)\n", name);
#endif

	if(StructTab == NULL)
		return(0);

	if((def = t_get(StructTab, name)) == NULL)
		return(0);

	t_map(def->st_sds, sd_lc, &bits);
	return(bits);
}
