/* Differentiate an expression */
#include "evalint.h"
#include "operations.h"
#include "variables.h"
#include <string.h>

/* Forward declaration: substitute_derv() calls _differentiate(), which is
   only defined further down in this file */
static value _differentiate(value, char *);

/* Try to apply one derivative rule to an expression.

   If expr matches sub->pat, every pattern variable named d<name> is bound to
   the derivative with respect to `by' of the value of <name> (which will have
   been set by the call to match_pattern), and the rule's replacement sub->rep
   is then evaluated.

   Returns the value produced by _eval, or NULL when the pattern does not
   match or one of the derivatives could not be computed.

   NOTE(review): the result of get_var_name(<name>) is handed to
   _differentiate unchecked, and _differentiate dereferences it -- confirm
   that every d<name> variable has a bound <name>. */
static value substitute_derv(sub, expr, by)
subst *sub;
value expr;
char *by;
{
    struct Node *node;
    value result = NULL;
    int ok = TRUE;

    /* This rule does not apply to expr */
    if (!match_pattern(sub->pat, expr))
	return(NULL);

    /* Bind each d<name> variable, giving up at the first failure */
    node = sub->pat->vl.lh_Head;
    while (ok && node->ln_Succ)
    {
	char *name = node->ln_Name;

	if (name[0] == 'd' && name[1] != '\0')
	{
	    value derivative = _differentiate(get_var_name(name + 1), by);

	    if (derivative)
		set_var_name(name, derivative);
	    else
		ok = FALSE;
	}
	node = node->ln_Succ;
    }

    /* Evaluate differential, and clean up the expression */
    if (ok)
	result = _eval(sub->rep, VAR | NICE | PAT);

    /* Release whatever expressions the d<name> variables currently hold;
       this pass runs whether or not the binding pass succeeded */
    for (node = sub->pat->vl.lh_Head; node->ln_Succ; node = node->ln_Succ)
    {
	char *name = node->ln_Name;
	value bound;

	if (name[0] != 'd' || name[1] == '\0')
	    continue;

	bound = get_var_name(name);
	if (bound)
	    free_expr(bound);
    }
    free_vars(&sub->pat->vl);

    return(result);
}

/* Actual differentiation routine, which is now quite small ... */
static value _differentiate(expr, by)
value expr;
char *by;
{
    switch (expr->type)
    {
	case cste:
	    return(cst(0.0));
	case var:
	    /* d<by>/d<by> = 1, dy/d<by> = 0 */
	    return(cst(strcmp(expr->vr.name, by) == 0 ? 1.0 : 0.0));
	default: /* Check for patterns */
	    {
		register subst **scan;
		value der = NULL;

		for (scan = derv; *scan; scan++)
		{
		    if (eval_error == 0 && (der = substitute_derv(*scan, expr, by)))
			return(der);
		}
		if (eval_error == 0) eval_error = NOT_DIFFERENTIABLE;
		return(NULL);
	    }
    }
}

/* Non-static entry point: differentiate expr with respect to the variable
   named `by'.  Resets eval_error to 0 first, then returns whatever
   _differentiate returns for the expression. */
value differentiate(expr, by)
value expr;
char *by;
{
    value result;

    eval_error = 0;
    result = _differentiate(expr, by);
    return(result);
}

