
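/*
     Use the Crank-Nicolson Algorithm (Algorithm 12.3, p. 630) to solve
     Exercise Set 12.3, p. 633, #3(d):

          u_t = (1 / PI^2) u_xx,        0 < x < 1,  t > 0,
          u(0, t) = u(1, t) = 0,        u(x, 0) = cos(PI * (x - 0.5)),

     with h = 0.1, k = 0.04 and N = 10 time steps.

     True solution : u(x, t) = exp(-t) * cos(PI * (x - 0.5))
*/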

#include <stdio.h>       //Standard C i/o
#include <stdlib.h>      //Standard C library
#include <math.h>        //Standard C math routines for exp(x)

#define PI 3.14159265358979323846   //pi to full double precision

typedef double (*funptr)(double);  //pointer to a function of one real variable

double F(double),                  //f(x) = u(x, 0), the initial condition
       U(double, double);          //u(x, t), the exact solution

//Capacity of the solver's work arrays.  They are indexed 0..m, so size must
//exceed m.  Declared long so it matches the %ld used when it is printed.
const long size = 50;

//Problem data shared with the solver: interval, step sizes, the
//Crank-Nicolson ratio and the initial-condition function.
class UserEquation {
     public:
          UserEquation();          //fills in every parameter below

     protected:
          double    l;             //right endpoint of the interval [0, l]
          double    alpha;         //constant alpha in u_t = alpha^2 u_xx
          double    namda;         //lambda = alpha^2 * k / h^2
          long      m;             //index of the right boundary node w[m]
          long      N;             //number of time steps
          double    T;             //maximum time
          double    h;             //space step
          double    k;             //time step
          funptr    f;             //initial condition u(x, 0)
};

//Use object inheritance
class CrankNicolsonAlgorithm : public UserEquation
{
public:
     //Runs the algorithm on the inherited problem data and prints, for
     //every time step, the approximation next to the exact solution.
     CrankNicolsonAlgorithm();
};

//Entry point: constructing the solver performs the whole computation.
int main()
{
     CrankNicolsonAlgorithm cna;   //This is all we need.
     return 0;
}

//The constructor initializes its member variables.
UserEquation::UserEquation()
{
     l = 1.0;                     //interval is [0, l]
     h = 0.1;                     //space step
     k = 0.04;                    //time step
     N = 10;                      //number of time steps
     T = N * k;                   //maximum time reached by the march
     alpha = 1. / PI;
     namda = alpha * alpha * k / (h * h);

     //Algorithm 12.3 sets h = l / m, so m is the number of subintervals and
     //w[m] is the right boundary node x = l.  Round instead of truncating so
     //a quotient such as 9.999999 still gives 10.
     m = (long)(l / h + 0.5);

     //The solver writes w[0..m], so m must be strictly less than size.
     if(m >= size) {
          printf("m = %ld too big for size = %ld!\n", m, (long)size);
          exit(EXIT_FAILURE);
     }
     //Steps 3-5 need L[1], L[m-2] and L[m-1] to be distinct nodes.
     if(m < 3) {
          printf("m = %ld too small; Algorithm 12.3 needs m >= 3!\n", m);
          exit(EXIT_FAILURE);
     }
     f = F;         //f points to the right function
}

//The constructor function does everything.
CrankNicolsonAlgorithm::CrankNicolsonAlgorithm()
{
     double w[size], L[size], u[size], z[size];
     double t;
     int i;             //shared by all the loops below (standard for-scope)

     w[m] = 0.;                                   //Step 1: right boundary.

     for(i = 1; i < m; i++)                       //Step 2.
          w[i] = f(h * i);    //initial values w_i = f(x_i)

     //Steps 3-5: Crout factorization of the tridiagonal left-hand matrix.
     L[1] = 1. + namda;                           //Step 3.
     u[1] = - namda / (L[1] * 2.);

     for(i = 2; i < m - 1; i++) {                 //Step 4.
          L[i] = 1 + namda + namda * u[i-1] / 2.;
          u[i] = - namda / (2. * L[i]);
     }

     L[m-1] = 1 + namda + namda * u[m-2] / 2.;    //Step 5.

     for(int j = 1; j < N + 1; j++) {             //Step 6.
          t = j * k;                              //Step 7.
          z[1] = ((1. - namda) * w[1] + namda / 2. * w[2]) / L[1];

          for(i = 2; i < m; i++)                  //Step 8: forward sweep.
               z[i] = ((1.- namda) * w[i] + namda/2. *
                         (w[i+1] + w[i-1] + z[i-1])) / L[i];

          w[m-1] = z[m-1];                        //Step 9.

          for(i = m - 2; i > 0; i--)              //Step 10: back substitution.
               w[i] = z[i] - u[i] * w[i+1];

          //Step 11: all nodes of this row belong to the same time t, so the
          //exact solution is evaluated at that t (t is not advanced here).
          printf("\nt = %lf\n", t);
          for(i = 1; i < m; i++) {
               double x = i * h;                  //no accumulated round-off
               double soln = U(x, t);
               double abserror = fabs(soln - w[i]);
               printf("x = %lf, w[%d] = %lf, u(x,t) = %lf, |error| = %lf\n",
                    x, i, w[i], soln, abserror);
          }
     }
     //Step 12: STOP -- return normally instead of calling exit() from
     //inside a constructor.
}

//F(x) = f(x) = cos(PI * (x - 0.5)), the initial condition u(x, 0)
double F(double x)
{
     //Initial profile: vanishes at x = 0 and x = 1, equals 1 at x = 0.5.
     const double phase = PI * (x - 0.5);
     return cos(phase);
}

double U(double x, double t)
{
     //Exact solution used to measure the error: an exp(-t) decay factor
     //times the initial profile cos(PI * (x - 0.5)).
     const double decay   = exp(-t);
     const double profile = cos(PI * (x - 0.5));
     return decay * profile;
}
