/*
     Use the Nonlinear Shooting Algorithm (Algorithm 11.2, p.570) to solve
     
     Exercise Set 11.2, p.573, #3(d):
     
     y" =  2 * pow(y, 3) - 6 * y - 2 * pow(x, 3), 
     
                    where 1 <= x <= 2, y(1) = 2, y(2) = 2.5, h = 0.05
     
     note: true solution is x + 1/x
*/
     

#include <stdio.h>       //Standard C i/o
#include <stdlib.h>      //Standard C library
#include <math.h>        //Standard C math routines for pow(x,y)

//Closed-form solution y(x) = x + 1 / x, written as a function-like macro.
//It only expands where the name y is directly followed by '(' , so the
//plain parameter named y in F and Fy is left untouched.
#define y(x) ((x) + 1.0 / (x))

//Pointer to a function of (x, y, y') returning double
typedef double (*funptr)(double, double, double);

double F(double, double, double);       //f(x, y, y')
double Fy(double, double, double);      //fy(x, y, y')
double Fyp(double, double, double);     //fy'(x, y, y')

//Capacity of the w1/w2 work arrays in the solver; the number of
//subintervals N is checked against it before the arrays are used.
//An explicit type is required: standard C++ has no implicit int.
const int n = 500;

//Holds the data describing one boundary-value problem: the mesh, the
//boundary values, the stopping parameters and the three callbacks for
//f, fy and fy'. Everything is protected so that a derived solver can
//read it directly; the constructor fills in all of the fields.
class UserEquation
{
     protected:
          //Mesh description
          double h;           //step size
          double a;           //left endpoint
          double b;           //right endpoint

          //Boundary conditions
          double alpha;
          double beta;

          //Stopping parameters
          double TOL;         //tolerance
          long N;             //number of subintervals
          long M;             //maximum number of iterations

          //Callbacks, each taking (x, y, y')
          funptr f;
          funptr fy;
          funptr fyp;

     public:
          UserEquation();     //sets up the problem data
};

//Use class inheritance
//Solver type layered on top of UserEquation through public inheritance,
//so the protected problem data is directly accessible to it.
//The default constructor is the only member declared here.
class NonlinearShootingAlgorithm
     : public UserEquation
{
     public:
          NonlinearShootingAlgorithm();
};

//Program entry point. Constructing the solver object runs the whole
//computation: the constructors print their results and call exit(0)
//on every path, so the return below is not expected to be reached.
//The explicit int return type is required by standard C++.
int main()
{
     NonlinearShootingAlgorithm nsa;    //This is all we need.
     return 0;
}

//Fills in the data for the problem being solved: step size, tolerance,
//boundary values, interval endpoints, subinterval count, iteration cap
//and the three function pointers. Prints a message and exits if the
//subinterval count does not fit the fixed array capacity n.
UserEquation::UserEquation()
{
     h = 0.05;
     TOL = 0.00001;
     alpha = 2.0;
     beta = 2.5;
     a = 1.0;
     b = 2.0;

     //N is the number of subintervals, i.e. a + N * h must land on b.
     //Adding 1 here would push the last mesh point to b + h and the
     //boundary value beta would be matched at the wrong x. Round to the
     //nearest integer so floating-point error in the quotient cannot
     //truncate the count by one.
     N = static_cast<long>(floor((b - a) / h + 0.5));

     //The solver indexes its arrays of size n from 0 through N, so N
     //itself must be a valid index: N == n is already too large.
     if(N >= n) {
          printf("Number of subintervals should be increased.\n");
          exit(0);
     }
     M = 10000;

     f = F;         //f(x, y, y') = F(x, y, y')
     fy = Fy;       //fy(x, y, y') = Fy(x, y, y')
     fyp = Fyp;     //fy'(x, y, y') = Fyp(x, y, y')
}

//Runs the nonlinear shooting iteration and prints the result.
//
//The unknown initial slope TK starts at (beta - alpha) / (b - a). Each
//pass integrates, with fourth-order Runge-Kutta steps of size h,
//   w1' = w2,  w2' = f(x, w1, w2),          w1 = alpha, w2 = TK at x = a
//together with the companion system
//   u1' = u2,  u2' = fy * u1 + fyp * u2,    u1 = 0,     u2 = 1  at x = a
//so that u1 at the last mesh point supplies the derivative used in the
//Newton-style correction of TK (Step 10).
//
//When |w1[N] - beta| <= TOL the mesh values are printed next to the
//closed-form solution from the y(x) macro and the program exits. If M
//passes go by without meeting the tolerance, a message is printed and
//the program exits as well. The constructor never returns normally.
NonlinearShootingAlgorithm::NonlinearShootingAlgorithm()
{
     double w1[n], w2[n], u1, u2, x, soln;
     double k11, k12, k21, k22, k31, k32, k41, k42;
     double kp11, kp12, kp21, kp22, kp31, kp32, kp41, kp42;

     int k = 1;                              //Step 1.
     double TK = (beta - alpha) / (b - a);

     while(k <= M) {                              //Step 2.
          w1[0] = alpha;                          //Step 3.
          w2[0] = TK;
          u1 = 0.;
          u2 = 1.;

          //Runge-Kutta method for systems is used in step 5 and 6.
          for(long i = 1; i <= N; i++) {          //Step 4.
               x = a + h * (i - 1);                    //Step 5.

               //Stages for (w1, w2): advance from mesh point i-1 to i.
               k11 = h * w2[i-1];                      //Step 6.
               k12 = h * f(x, w1[i-1], w2[i-1]);
               k21 = h * (w2[i-1] + k12/2);
               k22 = h * f(x+h/2, w1[i-1]+k11/2, w2[i-1]+k12/2);
               k31 = h * (w2[i-1] + k22/2);
               k32 = h * f(x+h/2, w1[i-1]+k21/2, w2[i-1]+k22/2);
               k41 = h * (w2[i-1] + k32);
               k42 = h * f(x+h, w1[i-1] + k31, w2[i-1] + k32);
               w1[i] = w1[i-1] + (k11 + k21*2 + k31*2 + k41) / 6;
               w2[i] = w2[i-1] + (k12 + k22*2 + k32*2 + k42) / 6;

               //Stages for (u1, u2). fy and fyp are evaluated at the
               //values from mesh point i-1 in every stage.
               kp11 = h * u2;
               kp12 = h * (fy(x, w1[i-1], w2[i-1]) * u1 +
                          fyp(x, w1[i-1], w2[i-1]) * u2);
               kp21 = h * (u2 + kp12 / 2);
               //Second stage: u1 moves by half of kp11 and u2 by half of
               //kp12, the stage-1 increments of u1 and u2 respectively.
               //(kp21 is a u1 increment and must not be applied to u2.)
               kp22 = h * (fy(x+h/2, w1[i-1], w2[i-1]) * (u1 + kp11/2) +
                          fyp(x+h/2, w1[i-1], w2[i-1]) * (u2 + kp12/2));
               kp31 = h * (u2 + kp22 / 2);
               kp32 = h * (fy(x+h/2, w1[i-1], w2[i-1]) * (u1 + kp21/2) +
                          fyp(x+h/2, w1[i-1], w2[i-1]) * (u2 + kp22/2));
               kp41 = h * (u2 + kp32);
               kp42 = h * (fy(x+h, w1[i-1], w2[i-1]) * (u1 + kp31) +
                          fyp(x+h, w1[i-1], w2[i-1]) * (u2 + kp32));
               u1 += (kp11 + kp21 * 2 + kp31 * 2 + kp41) / 6;
               u2 += (kp12 + kp22 * 2 + kp32 * 2 + kp42) / 6;
          }

          if(fabs(w1[N] - beta) <= TOL) {         //Step 7.
               //The index is declared in this loop's own header: a
               //variable from the earlier for-header is out of scope
               //here under standard C++ rules.
               for(long i = 0; i <= N; i++) {          //Step 8.
                    x = a + h * i;
                    soln = y(x);
                    printf("x = %lf, w1[i] = %lf, w2[i] = %lf\n",
                              x, w1[i], w2[i]);
                    printf("The solution is %lf, %%error = %lf%%\n",
                              soln, fabs((soln - w1[i])/soln*100));
               }
               exit(0);                                //Step 9.
          }
          //Newton correction: the miss at the right endpoint divided by
          //u1, the sensitivity of w1[N] to the initial slope.
          TK -= (w1[N] - beta) / u1;              //Step 10.
          k = k + 1;
     }
     printf("Maximum number of iterations exceeded!\n");//Step 11.
     exit(0);
}

//F = f(x, y, y') = 2y**3 - 6y - 2x**3
//Right-hand side of the differential equation y" = f(x, y, y').
//  x  : independent variable
//  y  : solution value
//  yp : first derivative of the solution (not used by this f)
//Returns 2y**3 - 6y - 2x**3.
double F(double x, double y, double yp)
{
     (void)yp;      //this particular f does not depend on y'

     const double yCubed = pow(y, 3);
     const double xCubed = pow(x, 3);
     return 2.0 * yCubed - 6.0 * y - 2.0 * xCubed;
}

//Fy = fy(x, y, y') = 6y**2 - 6
//Partial derivative of f with respect to y.
//  x  : independent variable (not used)
//  y  : solution value
//  yp : first derivative of the solution (not used)
//Returns 6y**2 - 6, the y-derivative of 2y**3 - 6y - 2x**3.
double Fy(double x, double y, double yp)
{
     (void)x;
     (void)yp;

     return (6.0 * y) * y - 6.0;   //quadratic in y, not 6y - 6
}

//Fyp = Fy'(x, y, y') = 0.0 for f
//Partial derivative of f with respect to y'.
//All three arguments are ignored: the function returns 0.0 for every
//input.
double Fyp(double x, double y, double yp)
{
     (void)x;
     (void)y;
     (void)yp;

     const double noDependence = 0.0;
     return noDependence;
}
