/* The actual evaluator */
#include "math.h"
#include "support.h"
#include "evalint.h"
#include "operations.h"
#include "variables.h"

/*
 * Helpers that apply a binary arithmetic operator to two numbers and pass
 * the result through check_fp().
 *
 * Both clear _FPERR, compute (res1 op res2), and call
 *     check_fp(result, res1, res2, "<operator spelling>")
 * check_fp presumably inspects _FPERR and reports floating-point errors
 * (TODO confirm against its definition).
 *
 *   RES stores the checked value in a variable named `res`, which must be
 *       in scope where the macro is used.
 *   RET returns the checked value from the enclosing function.
 *
 * The operator is turned into a string with the ANSI `#` operator.
 * Writing "op" (as earlier versions did) only works with pre-ANSI
 * preprocessors that substitute parameters inside string literals. An ANSI
 * preprocessor passes the literal text "op" instead.
 *
 * res1 and res2 are evaluated more than once, so pass side-effect-free
 * expressions.
 */
#define RES(op, res1, res2) \
do { \
    _FPERR = 0; \
    res = check_fp((res1) op (res2), (res1), (res2), #op); \
} while (0)

#define RET(op, res1, res2) \
do { \
    _FPERR = 0; \
    return(check_fp((res1) op (res2), (res1), (res2), #op)); \
} while (0)

/* recursive evaluator */
value _eval(expr, flags)
value expr;
int flags;
{
    register value node;

    switch (expr->type)
    {
	case cste:
	    node = cst(expr->cte);
	    break;
	case var:
	    {
		register value val;

		val = get_var(&expr->vr);
		if (val && (flags & VAR)) node = id(val);
		else if (val && (flags & REC))
		{
		    /* Check for recursion */
		    if (expr->vr.adr->flags & USED)
		    {
			eval_error = RECURSIVE;
			node = NULL;
		    }
		    else
		    {
			/* Mark variable as already referenced */
			expr->vr.adr->flags |= USED;
			node = _eval(val, flags);
			expr->vr.adr->flags &= ~USED;
		    }
		}
		else node = dup_var(&expr->vr);
		break;
	    }
	case fonc:
	    {
		register value arg;
		double c;

		arg = _eval(expr->fn.arg, flags);
		/* Reduce constants ? */
		if (arg && (!(flags & NORED)) && arg->type == cste)
		{
		    c = (*expr->fn.func)(arg->cte);
		    if ((!(flags & NICE)) || c == trunc(c))
		    {
			/* We can replace by a number */
			arg->cte = c;
			node = arg;
			break;
		    }
		}
		node = apply_func(expr->fn.func, arg);
		break;
	    }
	case op:
	    {
		value sarg1, sarg2 = NULL;
		double res;

		sarg1 = _eval(expr->opr.arg1, flags);
		if (sarg1)
		{
		    sarg2 = _eval(expr->opr.arg2, flags);

		    /* Reduce constants ? */
		    if (sarg2 && (!(flags & NORED)) && sarg1->type == cste && sarg2->type == cste)
		    {
			switch (expr->opr.op)
			{
			    case plus: RES(+, sarg1->cte, sarg2->cte); break;
			    case minus: RES(-, sarg1->cte, sarg2->cte); break;
			    case times: RES(*, sarg1->cte, sarg2->cte); break;
			    case divide: RES(/, sarg1->cte, sarg2->cte); break;
			    case power: res = pow(sarg1->cte, sarg2->cte); break;
			}
			if ((!(flags & NICE)) || res == trunc(res))
			{
			    /* Result is a number */
			    sarg1->cte = res;
			    FREE_VALUE(sarg2);
			    node = sarg1;
			    break;
			}
		    }
		}
		node = dup_op(expr, sarg1, sarg2);
		break;
	    }
    }
    /* Do substitutions (don't bother checking for constants and variables) */
    if (node && (flags & PAT) && (node->type == op || node->type == fonc))
    {
	register subst **scan;
	value new = NULL;

	/* scan list of substitutions */
	for (scan = simp; *scan; scan++)
	{
	    if (eval_error == 0 && (new = substitute(*scan, node)))
	    {
		/* found one ! */
		free_expr(node); /* free old expr */
		node = new;
		break;
	    }
	}
    }
    return(node);
}

/* User callable routine (clears eval_error) */
value eval(expr, flags)
value expr;
int flags;
{
    eval_error = 0;
    return(_eval(expr, flags));
}

/*
 * Numerically evaluate `expr` and return its value as a double.
 * Every variable is replaced by the numeric value of its definition.
 *
 * On error, eval_error is set and 0.0 is returned. Errors include:
 *   - NOTNUM, for a variable with no value;
 *   - RECURSIVE, for a self-referencing definition;
 *   - anything reported through check_fp (presumably floating-point
 *     errors; TODO confirm).
 * Once eval_error is set, no further sub-expressions are evaluated and no
 * function or operator is applied to the 0.0 placeholder.
 */
static double _quick_eval(expr)
register value expr;
{
    switch (expr->type)
    {
	case cste:
	    return(expr->cte);
	case var:
	    {
		double ret;
		value val = get_var(&expr->vr);

		if (!val) eval_error = NOTNUM;
		/* All variables are recursively evaluated */
		else if (expr->vr.adr->flags & USED) eval_error = RECURSIVE;
		else
		{
		    /*
		     * Evaluate the same value that was just null-checked
		     * (as _eval does), with USED set to catch recursion.
		     */
		    expr->vr.adr->flags |= USED;
		    ret = _quick_eval(val);
		    expr->vr.adr->flags &= ~USED;
		    return(ret);
		}
	    }
	    return(0.0); /* there was an error ... */
	case fonc:
	    {
		double arg = _quick_eval(expr->fn.arg);

		/* Don't apply the function to the placeholder of a failed argument */
		if (eval_error != 0) return(0.0);
		return((*expr->fn.func)(arg));
	    }
	case op:
	    {
		double arg1, arg2;

		arg1 = _quick_eval(expr->opr.arg1);
		if (eval_error == 0)
		{
		    arg2 = _quick_eval(expr->opr.arg2);

		    if (eval_error == 0)
			switch (expr->opr.op)
			{
			    case plus: RET(+, arg1, arg2);
			    case minus: RET(-, arg1, arg2);
			    case times: RET(*, arg1, arg2);
			    case divide: RET(/, arg1, arg2);
			    case power: return(pow(arg1, arg2));
			}
		}
		return(0.0);
	    }
	default:
	    break;
    }
    return(0.0);	/* unknown node type: never fall off the end of the function */
}

/* User callable version */
double quick_eval(expr)
value expr;
{
    eval_error = 0;
    return(_quick_eval(expr));
}

