//  poly.cpp - a public domain polynomial root finder
//             by D.A. Steinman (steinman@me.utoronto.ca)
//
//  As part of a project I was working on recently, I required a
//  robust 4th order polynomial root solver.  Simple iterative
//  methods (e.g. Bairstow) don't always work, and most good black
//  box routines are written in Fortran and cost $$$.  However, since
//  I was concerned only with the 4th order case, I knew it could be
//  solved exactly.  The code which follows represents the fruit of my
//  labours...
//
//  POLY solves the general polynomial:
//
//      a[n] x^n + a[n-1] x^(n-1) + ... + a[1] x + a[0] = 0,   n <= 4
//
//  where the coefficients a[i] are real and, for numerical
//  convenience, the leading term is normalized to unity by dividing
//  each coefficient by a[n].  Polynomials of order 4 or less (i.e.
//  n<=4) can be solved exactly, and this is what POLY will do for
//  you.  No messy iterations.  No worries about multiple roots.  The
//  routines for the cubic and quartic cases are based upon methods
//  outlined in:
//
//      Tignol J-P.  Galois' theory of algebraic equations.  Longman
//          Scientific & Technical, Harlow, 1988  (ISBN 0-470-20919-4)
//
//  POLY was written in C++ (for the complex variables), and was
//  compiled using Turbo C++ v1.0; I haven't tested it for any other
//  flavour of C++.  POLY *should* solve any quartic or lower order
//  polynomial.  If it doesn't, feel free to find the bug and post a
//  fixed version.

#include <stdio.h>
#include <complex.h>
#include <math.h>
#include <stdlib.h>

#define N  5                             // size of main's a[]/x[]: indices 0..4 are used
#define SIGN(a)  ((a)>=0 ? "+" : "-")    // sign string printed in front of |a|;
                                         // argument parenthesized so e.g. SIGN(u & v)
                                         // tests (u & v), not u & (v >= 0)

// Solvers for monic polynomials of order 1..4.  Each reads the lower
// coefficients a[0]..a[n-1] (the leading 1 is implied) and writes the
// roots to x[1]..x[n]; x[0] is not written.
void poly1(double *, complex *);
void poly2(double *, complex *);
void poly3(double *, complex *);
void poly4(double *, complex *);

// poly a[n] .. a[0]  --  print the roots of  a[n] x^n + ... + a[1] x + a[0] = 0
//
// argv[1]..argv[n+1] hold the coefficients from the highest power down to
// the constant term, so 1 <= n <= 4 needs 2..5 arguments.  The polynomial is
// divided through by a[n] and routed to poly1..poly4 according to its order.
// When the lowest coefficients are zero, x^k is factored out first and the
// k roots at 0 are left in x[1]..x[k].  The roots are then printed in
// descending order of real part, each with |p(x)| as an accuracy check.
//
// Exit status: 0 on success, 1 on a wrong argument count, 2 if a[n] == 0,
// 3 if an argument is not a number.
int main(int argc, char *argv[])
{
    complex x[N],resid,tx;
    double a[N];
    char *end;
    int i,j,n;

    /* order of the polynomial: one argument per coefficient a[n]..a[0] */
    n = argc-2;
    if (n<1 || n>4) {
        puts("");
        puts("poly - zeros of 4th order (and lower) polynomials");
        puts("       v1.00 public domain by D.A. Steinman");
        puts("");
        puts("syntax:  poly  a[n]..a[0] (n=[1..4])");
        exit(1);
    }

    /* roots live in x[1]..x[n]; slots a degenerate case skips must read 0 */
    for (i=1;i<=n;i++) x[i] = complex(0.0);

    /* assign coefficients; unlike atof, strtod lets us reject arguments
       that are not (entirely) numbers instead of silently reading 0 */
    for (i=n;i>=0;i--) {
        const char *arg = argv[n+1-i];
        a[i] = strtod(arg,&end);
        if (end==arg || *end!='\0') {
            printf("\n*error* - invalid coefficient \"%s\"\n",arg);
            exit(3);
        }
    }
    if (a[n]==0.0) {
        puts("\n*error* - a[n] must be non-zero");
        exit(2);
    }

    /* make the polynomial monic: the solvers assume a leading 1 */
    for (i=0;i<n;i++) a[i] /= a[n];
    a[n] = 1.0;

    /* write polynomial */
    printf("\np(x):  ");
    if (n>1) printf("x^%1d",n);
    else printf("x");
    for (i=n-1;i>1;i--) if (fabs(a[i])>0.0) printf(" %s %lg x^%1d",SIGN(a[i]),fabs(a[i]),i);
    if (fabs(a[1])>0.0 && n>1) printf(" %s %lg x",SIGN(a[1]),fabs(a[1]));
    if (fabs(a[0])>0.0) printf(" %s %lg",SIGN(a[0]),fabs(a[0]));
    puts(" = 0\n");

    /* route solver based upon order; a zero constant term means x is a
       factor, so solve the reduced polynomial into the later x[] slots */
    switch (n) {
      case 1:
        poly1(&a[0],&x[0]);                      // linear
        break;
      case 2:
        if (a[0]==0.0) {
            poly1(&a[1],&x[1]);                  // x (x + a1)
        } else {
            poly2(&a[0],&x[0]);                  // quadratic
        }
        break;
      case 3:
        if (a[0]==0.0) {
            if (a[1]==0.0) {
                poly1(&a[2],&x[2]);              // x^2 (x + a2)
            } else {
                poly2(&a[1],&x[1]);              // x (x^2 + a2 x + a1)
            }
        } else {
            poly3(&a[0],&x[0]);                  // cubic
        }
        break;
      case 4:
        if (a[0]==0.0) {
            if (a[1]==0.0) {
                if (a[2]==0.0) {
                    poly1(&a[3],&x[3]);          // x^3 (x + a3)
                } else {
                    poly2(&a[2],&x[2]);          // x^2 (x^2 + a3 x + a2)
                }
            } else {
                poly3(&a[1],&x[1]);              // x (cubic)
            }
        } else {
            poly4(&a[0],&x[0]);                  // quartic
        }
        break;
      default:
        break;                                   // unreachable: n range-checked above
    }

    /* sort by descending real part */
    for (i=1;i<=n;i++) for (j=i+1;j<=n;j++)
        if (real(x[j]) > real(x[i])) {
            tx = x[i];
            x[i] = x[j];
            x[j] = tx;
        }

    /* display roots with |p(x)| as a residual check */
    for (i=1;i<=n;i++) {
        resid = a[0];
        for (j=1;j<=n;j++) resid += a[j]*pow(x[i],double(j));
        printf("x[%1d]: % lf%+lfi   p(x) = %lg\n",i,real(x[i]),imag(x[i]),abs(resid));
    }
    return 0;
}

// poly1 - root of the monic linear polynomial  x + a[0] = 0.
//
// a : a[0], the constant term (leading coefficient 1 is implied)
// x : receives the single root in x[1]; x[0] is not touched
void poly1(double *a, complex *x)
{
    const double root = -a[0];      // x + a0 = 0  =>  x = -a0
    x[1] = complex(root, 0.0);
}

// poly2 - roots of the monic quadratic  x^2 + a[1] x + a[0] = 0
//
// a : coefficients a[0], a[1] (leading coefficient 1 is implied)
// x : receives the roots in x[1] and x[2]; x[0] is not touched
//
// Real roots are returned with x[1] >= x[2]; a complex pair is returned as
// x[1] = re + i*im, x[2] = re - i*im with im > 0.
void poly2(double *a, complex *x)
{
    double d, re, im, s, big;

    d = a[1]*a[1] - 4*a[0];
    if (d < 0.0) {
        /* complex-conjugate pair (-a[1] +- i sqrt(-d)) / 2; the real and
           imaginary parts are computed separately, so nothing can cancel */
        re = -0.5*a[1];
        im = 0.5*sqrt(-d);
        x[1] = complex(re, im);
        x[2] = complex(re, -im);
        return;
    }

    /* Real roots.  Evaluating (-a[1] +- s)/2 directly cancels badly for the
       small root when |a[1]| >> |a[0]| (x^2 + 1e8 x + 1 gave a 25% error),
       so compute only the larger-magnitude root, where -a[1] and the sqrt
       term share a sign, and recover the other from x[1]*x[2] = a[0]. */
    s = sqrt(d);
    if (a[1] >= 0.0) {
        big = -0.5*(a[1] + s);                 // the lower root
        x[2] = complex(big);
        // big == 0 only when a[1] == a[0] == 0: double root at 0
        x[1] = complex(big != 0.0 ? a[0]/big : 0.0);
    } else {
        big = -0.5*(a[1] - s);                 // the upper root, > 0
        x[1] = complex(big);
        x[2] = complex(a[0]/big);
    }
}

// poly3 - roots of the monic cubic  x^3 + a[2] x^2 + a[1] x + a[0] = 0
//
// a : coefficients a[0]..a[2] (leading coefficient 1 is implied)
// x : receives the three roots in x[1]..x[3]; x[0] is not touched
//
// The substitution x = y - a[2]/3 removes the x^2 term, leaving the
// depressed cubic y^3 + p y + q = 0, which is solved in closed form
// (Cardano).  The three cube roots are generated by repeated
// multiplication with the primitive cube root of unity w.
void poly3(double *a, complex *x)
{
    const complex w(-0.5, 0.5*sqrt(3.0));      // exp(2 pi i / 3)
    complex y[4], z[4], s;
    double p, q, d;
    int i;

    /* eliminate x^2 term via y=x+a[2]/3 */
    p = a[1] - a[2]*a[2]/3;
    q = a[0] - a[2]*a[1]/3 + 2*a[2]*a[2]*a[2]/27;

    if (p == 0.0 && q == 0.0) {                 // y^3 = 0: triple root
        for (i=1;i<=3;i++) y[i] = complex(0.0);
    } else if (p == 0.0) {                      // y^3 = -q
        y[1] = pow(complex(-q),1.0/3.0);
        for (i=2;i<=3;i++) y[i] = y[i-1]*w;
    } else if (q == 0.0) {                      // y (y^2 + p) = 0
        y[1] = complex(0.0);
        y[2] = sqrt(complex(-p));
        y[3] = -y[2];
    } else {
        /* y = p/(3z) - z with z^3 = q/2 +- sqrt(d).  Either sign is valid;
           take the one that adds to q/2 instead of cancelling it.  A z near
           0 makes p/(3z) lose all precision, or divide by zero outright
           (e.g. y^3 + 1e-6 y - 1).  For d < 0, sqrt(d) is imaginary and
           both signs give the same |z|. */
        d = q*q/4 + p*p*p/27;
        s = sqrt(complex(d));
        if (q < 0.0) s = -s;
        z[1] = pow(q/2 + s,1.0/3.0);
        for (i=2;i<=3;i++) z[i] = z[i-1]*w;
        for (i=1;i<=3;i++) y[i] = p/3/z[i] - z[i];
    }
    for (i=1;i<=3;i++) x[i] = y[i] - a[2]/3;
}

// poly4 - roots of the monic quartic  x^4 + a[3] x^3 + a[2] x^2 + a[1] x + a[0] = 0
//
// a : coefficients a[0]..a[3] (leading coefficient 1 is implied)
// x : receives the four roots in x[1]..x[4]; x[0] is not touched
//
// The substitution x = y - a[3]/4 gives the depressed quartic
// y^4 + p y^2 + q y + r = 0.  For q == 0 it is a quadratic in y^2.
// Otherwise Ferrari's method is used: with m a root of the resolvent cubic
// m^3 + p m^2 + (p^2/4 - r) m - q^2/8 = 0, the quartic factors into
//   y^2 -+ sqrt(2m) y + p/2 + m +- q/(2 sqrt(2m)) = 0.
void poly4(double *a, complex *x)
{
    complex u[4], y[5], s, t, h;
    double ua[3], p, q, r;
    int i, k;

    /* eliminate x^3 term via y=x+a[3]/4 */
    p = a[2] - 3*a[3]*a[3]/8;
    q = a[1] - a[3]*a[2]/2 + a[3]*a[3]*a[3]/8;
    r = a[0] - a[3]*a[1]/4 + a[3]*a[3]*a[2]/16 - 3*a[3]*a[3]*a[3]*a[3]/256;

    if (q == 0.0) {
        /* biquadratic y^4 + p y^2 + r = 0:  y^2 = -p/2 +- sqrt(p^2/4 - r) */
        s = sqrt(complex(p*p/4 - r));
        t = sqrt(-p/2 + s);
        y[1] =  t;
        y[2] = -t;
        t = sqrt(-p/2 - s);
        y[3] =  t;
        y[4] = -t;
    } else {
        /* solve resolvent cubic */
        ua[2] = p;
        ua[1] = p*p/4 - r;
        ua[0] = -q*q/8;
        poly3(ua,u);

        /* Any root works in exact arithmetic, and none is 0 because their
           product is q^2/8 != 0.  Use the largest-magnitude one so that
           dividing q by sqrt(2m) does not amplify rounding error. */
        k = 1;
        for (i=2;i<=3;i++) if (abs(u[i]) > abs(u[k])) k = i;

        s = sqrt(u[k]/2.0);                    // sqrt(m/2) = sqrt(2m)/2
        h = q/2.0/sqrt(2.0*u[k]);              // q / (2 sqrt(2m))
        t = sqrt(-u[k]/2.0 - p/2 - h);
        y[1] =  s + t;
        y[2] =  s - t;
        t = sqrt(-u[k]/2.0 - p/2 + h);
        y[3] = -s + t;
        y[4] = -s - t;
    }
    for (i=1;i<=4;i++) x[i] = y[i] - a[3]/4;
}
