/* "Compile" a string expression into an expression tree (recursive-descent parser) */
#include "evalint.h"
#include "symbol.h"
#include "support.h"
#include "operations.h"
#include <string.h>

static value _compile(void);

/*
 * factor = '(' expression ')'
 *        | number
 *        | variable
 *        | function '(' expression ')'
 *
 * On entry the current symbol (symbol_type / symbol) is the first symbol of
 * the factor; on success it is the first symbol after it.  Returns the new
 * expression node, or NULL with eval_error set.
 *
 * When the nested expression has already failed, its error code is kept
 * rather than being replaced by UNMATCHED: e.g. "(sin 3)" reports WANT_LEFT
 * (the real problem), not an unmatched parenthesis.
 */
static value factor()
{
    register value node;

    switch (symbol_type) {
	case s_left: /* (expression) */
	    advance();
	    node = _compile();
	    if (symbol_type != s_right)
	    {
		/* Don't mask an error already reported by the inner expression. */
		if (!eval_error) eval_error = UNMATCHED;
		free_expr(node);
		node = NULL;
	    }
	    advance();	/* skip ')' (or the offending symbol on error) */
	    return(node);
	case s_value: /* number, variable or function */
	    switch (symbol.type) {
		case cste:
		    node = cst(symbol.cte);
		    advance();
		    return(node);
		case fonc:
		    {
			FUNCTION func = symbol.fn.func;

			advance();
			if (symbol_type != s_left)
			{
			    eval_error = WANT_LEFT;
			    return(NULL);
			}
			advance();
			node = _compile();
			if (symbol_type != s_right)
			{
			    if (!eval_error) eval_error = UNMATCHED;
			    free_expr(node);
			    return(NULL);
			}
			advance();
			/* The argument failed to parse (eval_error normally set by
			   the sub-parse): don't wrap a NULL argument in a call node. */
			if (node == NULL) return(NULL);
			return(apply_func(func, node));
		    }
		case var:
		    node = dup_var(&symbol.vr);
		    advance();
		    return(node);
		default:
		    /* e.g. an operator where a factor was expected */
		    break;
	    }
	    break;
	default:
	    break;
    }
    /* Unexpected symbol */
    eval_error = SYNTAX;
    return(NULL);
}

/* Allow -factor +factor factor */
static value negate()
{
    register value node;
    register e_op _op;

    _op = symbol.opr.op;
    if (symbol_type == s_value && symbol.type == op &&
	(_op == plus || _op == minus))
    {	/* We have + or - */
	advance();
	node = factor();
	if (_op == minus) node = apply_func(neg, node);
    }
    else node = factor();

    return(node);
}

/* power = negate [^ power] */
static value _power()
{
    register value arg1;

    arg1 = negate();

    while (arg1 && symbol_type == s_value && symbol.type == op &&
	   symbol.opr.op == power)
    {
	advance();
	arg1 = topow(arg1, negate());
    }
    return(arg1);
}

/* term = power [ * / term ] */
static value term()
{
    register value arg1;
    register e_op _op;

    arg1 = _power();

    while (arg1 && symbol_type == s_value && symbol.type == op &&
	   ((_op = symbol.opr.op) == times || _op == divide))
    {
	advance();
	if (_op == times) arg1 = mult(arg1, _power());
	else arg1 = div(arg1, _power());
    }
    return(arg1);
}

/* expression = term [+- expression] */
static value _compile()
{
    register value arg1;
    register e_op _op;

    arg1 = term();

    while (arg1 && symbol_type == s_value && symbol.type == op &&
	   ((_op = symbol.opr.op) == plus || _op == minus))
    {
	advance();
	if (_op == plus) arg1 = add(arg1, term());
	else arg1 = sub(arg1, term());
    }
    return(arg1);
}

/*
 * Parse the string `expr` into an expression tree.
 *
 * Resets eval_error, primes the lexer with one symbol of look-ahead and runs
 * the recursive-descent parser.  Any symbols left after the expression are a
 * SYNTAX error, unless parsing had already failed, in which case that earlier
 * (more specific) error code is kept; e.g. "sin 3" reports WANT_LEFT rather
 * than SYNTAX.
 *
 * Returns the tree on success.  On failure eval_error is non-zero and the
 * result is normally NULL.  The caller owns the tree (release it with
 * free_expr()).
 */
value compile(char *expr)
{
    register value res;

    eval_error = 0;

    /* Initialise lexical analyser. We've always got one symbol read ahead */
    init_symbol(expr);
    advance();
    res = _compile();
    if (symbol_type != s_none)
    {
	/* Trailing symbols: don't overwrite an error reported deeper down. */
	if (!eval_error) eval_error = SYNTAX;
	free_expr(res);
	res = NULL;
    }
    return(res);
}

/*
 * Recursively release an expression tree.
 *
 * NULL is accepted and ignored.  A variable node's name string, both operands
 * of an operator node and the argument of a function node are released before
 * the node itself.  Other node types have no sub-objects freed here.
 */
void free_expr(register value expr)
{
    if (expr == NULL)
	return;

    switch (expr->type)
    {
	case op:
	    free_expr(expr->opr.arg1);
	    free_expr(expr->opr.arg2);
	    break;
	case fonc:
	    free_expr(expr->fn.arg);
	    break;
	case var:
	    FREEMEM(strlen(expr->vr.name) + 1, expr->vr.name);
	    break;
	default:
	    break;
    }
    FREE_VALUE(expr);
}

