/* **************************************************** */
/* file misc.c:  contains pattern manipulation routines */
/*               and miscellaneous other functions.     */
/*                                                      */
/* Copyright (c) 1990 by Donald R. Tveter               */
/*                                                      */
/* **************************************************** */

#include <stdio.h>
#include <stdlib.h>
#ifdef INTEGER
#include "ibp.h"
#else
#include "rbp.h"
#endif

extern short backoutput();
extern void backinner();
extern short cbackoutput();
extern void cbackinner();
extern WTTYPE rdr();
extern WTTYPE readchar();
extern void saveweights();
extern WTTYPE scale();
extern double unscaleint();
extern void updatej();
extern void updateo();

extern char backprop;
extern FILE *data;
extern char datafilename[50];
extern UNIT *hlayer;
extern UNIT *ilayer;
extern char informat;
extern UNIT *jlayer;
extern UNIT *klayer;
extern LAYER *last;
extern int lastprint;
extern int npats;
extern int prevnpats;
extern int readerror;
extern int saverate;
extern int skiprate;
extern LAYER *start;
extern char summary;
extern WTTYPE toler;
#ifdef INTEGER
extern int totaldiff;
#else
extern double totaldiff;
#endif
extern int totaliter;
extern int unlearnedpats;
extern char update;
extern WTTYPE wtlimit;
extern char wtlimithit;
extern int wttotal;

/* Free every PATLIST node in the list beginning at pl, together with the
   PATNODE list hanging off each node.  pl may be NULL (empty list), and a
   node whose pats list is NULL is handled too.  Each link is read BEFORE
   its node is freed, so freed memory is never touched and the NULL
   terminator is never dereferenced. */
static void freepatlist(pl)
PATLIST *pl;
{
  PATLIST *nextpl;
  PATNODE *pn, *nextpn;

  while (pl != NULL)
     {
       nextpl = pl->next;
       pn = pl->pats;
       while (pn != NULL)
          {
            nextpn = pn->next;
            free(pn);
            pn = nextpn;
          };
       free(pl);
       pl = nextpl;
     };
}

/* Dispose of any patterns before reading more: free the pattern lists of
   both the input layer (start) and the output layer (last), set both list
   heads to NULL, and reset the counters npats and prevnpats to 0.
   Each list is freed independently, so an empty output list no longer
   causes a NULL dereference, and a non-empty output list is freed even
   when the input list is already empty. */
void nullpatterns()  /* dispose of any patterns before reading more */
{
  freepatlist(start->patstart);
  freepatlist(last->patstart);
  start->patstart = NULL;
  last->patstart = NULL;
  npats = 0;
  prevnpats = 0;
}

/* Clear the current-pattern cursor of both the input layer (start) and the
   output layer (last).  With both cursors NULL, the next nextpat() call
   restarts at the head of each pattern list. */
void resetpats()
{
 start->currentpat = last->currentpat = NULL;
}

void findendofpats(layer)  /* set layer->currentpat to the last node of */
LAYER *layer;              /* the layer's pattern list so more patterns  */
                           /* can be appended after it.                  */
/* When the list is empty (patstart == NULL) currentpat is set to NULL
   instead of dereferencing the empty head. */
{
 PATLIST *pl;

 pl = (PATLIST *) layer->patstart;
 if (pl != NULL)
    while (pl->next != NULL) pl = pl->next;
 layer->currentpat = pl;
}

/* Copy the output value (oj) of the next unit of a hidden-layer unit list
   onto an input unit, then advance that list.
   input  - unit that receives the value
   hidden - address of the caller's cursor into the hidden layer's unit
            list; on success the cursor is moved to the following unit
   l      - layer number, used only in the error message
   Returns 0 on success.  Returns 1, after printing a message, when the
   cursor has already run off the end of the hidden layer. */
int copyhidden(input,hidden,l)
UNIT *input, **hidden;
int l;
{
  /* When the layer is exhausted, it is the cursor (*hidden) that becomes
     NULL; the address itself never does.  Checking *hidden prevents
     (*hidden)->oj from dereferencing NULL. */
  if (hidden == NULL || *hidden == NULL)
     {
       printf("ran out of hidden units in layer %d\n",l);
       return(1);
     }
  input->oj = (*hidden)->oj;
  *hidden = (*hidden)->next;
  return(0);
}

/* Step the input (start) and output (last) layers to their next pattern.
   The decision is made on start->currentpat.  If it is set, both cursors
   move to their following list node, which is NULL after the final
   pattern.  Otherwise both cursors restart at the head of their lists. */
void nextpat()
{
  if (start->currentpat != NULL)
     {
       start->currentpat = start->currentpat->next;
       last->currentpat = last->currentpat->next;
       return;
     }
  start->currentpat = start->patstart;
  last->currentpat = last->patstart;
}

void setonepat()       /* sets up patterns on input units */
{
  register PATNODE *p;
  register UNIT *u;
  register LAYER *innerlayers;
  UNIT *hunit, *iunit, *junit, *kunit;
  PATLIST *pl;
  
  hunit = hlayer;
  iunit = ilayer;
  junit = jlayer;
  kunit = klayer;
  pl = start->currentpat;
  p = (PATNODE *) pl->pats;
  u = (UNIT *) start->units;
  while (p != NULL)
     {
       if (p->val > KCODE) u->oj = p->val;
       else if (p->val == HCODE)
               {if (copyhidden(u,&hunit,2) == 1) return;}
       else if (p->val == ICODE)
               {if (copyhidden(u,&iunit,3) == 1) return;}
       else if (p->val == JCODE)
               {if (copyhidden(u,&junit,4) == 1) return;}
       else if (copyhidden(u,&kunit,5) == 1) return;
       u = u->next;
       p = p->next;
     };

  innerlayers = start->next;
  while (innerlayers->next != NULL)
     {  /* set errors on the inner layer units to 0 */
       u = (UNIT *) innerlayers->units;
       while (u != NULL)
          {
            u->error = 0;
            u = u->next;
          };
       innerlayers = innerlayers->next;
     };
}

/* Clamp every weight in the wtlist of every unit, for all layers from
   start->next onward, into the range [-wtlimit, wtlimit].  Whenever a
   weight had to be clipped, the global flag wtlimithit is set to 1.
   Under SYMMETRIC the weight is reached through the w->weight pointer;
   otherwise it is stored in the node itself. */
void limitwts()
{
  LAYER *lay;
  UNIT *unit;
  WTNODE *wn;

  for (lay = start->next; lay != NULL; lay = lay->next)
    for (unit = (UNIT *) lay->units; unit != NULL; unit = unit->next)
      for (wn = (WTNODE *) unit->wtlist; wn != NULL; wn = wn->next)
        {
#ifdef SYMMETRIC
          if (*(wn->weight) > wtlimit) *(wn->weight) = wtlimit;
          else if (*(wn->weight) < -wtlimit) *(wn->weight) = -wtlimit;
          else continue;      /* already within range: nothing clipped */
#else
          if (wn->weight > wtlimit) wn->weight = wtlimit;
          else if (wn->weight < -wtlimit) wn->weight = -wtlimit;
          else continue;      /* already within range: nothing clipped */
#endif
          wtlimithit = 1;
        }
}

#ifndef SYMMETRIC

void whittle(amount)    /* removes weights whose absolute */
WTTYPE amount;          /* value is less than amount      */
{LAYER *layer;
 UNIT *u;
 WTNODE *w, *wprev;

 layer = start->next;
 while (layer != NULL)
   {
     u = (UNIT *) layer->units;
     while (u != NULL)
       {
         w = (WTNODE *) u->wtlist;
         wprev = (WTNODE *) NULL;
         while (w->next != (WTNODE *) NULL)
           {
             if ((w->weight) < amount && (w->weight) > -amount)
               {
                 if (wprev == NULL) (WTNODE *) u->wtlist = w->next;
                 else (WTNODE *) wprev->next = w->next;
                 wttotal = wttotal - 1;
               }
             else wprev = w;
             w = w->next;
           }
         u = u->next;
       }
     layer = layer->next;
   }
}

#endif

/* Go through the patterns once and update the weights.
   1. Set the total field of every weight to 0, for every layer that has a
      backlayer (every layer except the input layer).
   2. For each of the npats patterns in turn:
      - if its bypass counter is positive, just decrement it;
      - otherwise load it (setonepat), run forward(), then call
        cbackoutput() for update 'c'/'C' or backoutput() otherwise.
        A zero result means the pattern is learned: passed is bumped,
        unlearnedpats is decremented and bypass is set to skiprate.
        Unless SYMMETRIC is defined, cbackinner()/backinner() is then
        called if the pattern was not learned, or if backprop is set.
   3. If no pattern is left unlearned, return without a weight update.
   4. If every pattern actually attempted was learned (the unlearned ones
      were only bypassed), clear every bypass counter and repeat the pass.
   5. Otherwise call updatej() for update 'j', or updateo() for 'o'/'d'.
      Then call limitwts() when wtlimit is nonzero. */
void oneset()
{ int i;
  LAYER *layer;
  UNIT *u;
  WTNODE *w;
  short numbernotclose, attempted, passed;

  for (;;)
     {
       for (layer = last; layer->backlayer != NULL; layer = layer->backlayer)
          for (u = (UNIT *) layer->units; u != NULL; u = u->next)
             for (w = (WTNODE *) u->wtlist; w != NULL; w = w->next)
                {
#ifdef SYMMETRIC
                  *(w->total) = 0;
#else
                  w->total = 0;
#endif
                }

       attempted = passed = 0;
       unlearnedpats = npats;
       resetpats();
       for (i = 1; i <= npats; i++)
          {
            nextpat();
            if (last->currentpat->bypass > 0)
               {   /* still being skipped: count down and move on */
                 last->currentpat->bypass = last->currentpat->bypass - 1;
                 continue;
               }
            setonepat();
            forward();
            attempted++;
            numbernotclose = (update == 'c' || update == 'C')
                             ? cbackoutput() : backoutput();
            if (numbernotclose == 0)
               {   /* learned: skip it for the next skiprate passes */
                 passed++;
                 unlearnedpats--;
                 last->currentpat->bypass = skiprate;
               }
#ifndef SYMMETRIC
            if (numbernotclose != 0 || backprop)
               {
                 if (update == 'c' || update == 'C') cbackinner();
                 else backinner();
               }
#endif
          }

       if (unlearnedpats == 0) return;
       if (attempted != passed) break;

       /* only bypassed patterns remain unlearned: clear bypass, redo */
       resetpats();
       for (i = 1; i <= npats; i++)
          {
            nextpat();
            last->currentpat->bypass = 0;
          }
     }

  if (update == 'j') updatej();
  else if (update == 'o' || update == 'd') updateo();
  if (wtlimit != 0) limitwts();
}

void kick(size,amount) /* give the network a kick */
WTTYPE size;
WTTYPE amount;
{ LAYER *layer;
  UNIT *u;
  WTNODE *w;
  WTTYPE value;
  WTTYPE delta;
  int sign;

  layer = start->next;
  while (layer != NULL)
   {
    u = (UNIT *) layer->units;
    while (u != NULL)
     {
      w = (WTNODE *) u->wtlist;
      while (w != NULL)
       {
#ifdef SYMMETRIC
         value = *(w->weight);
#else
         value = w->weight;
#endif
         if (value != 0) sign = 1;
         else if (rand() > 16383) sign = -1;
         else sign = 1;
         delta = (sign * amount * rand()) / 32768;
         if (value >= size) value = value - delta;
         else if (value < -size) value = value + delta;
#ifdef SYMMETRIC
         if (((UNIT *) w->backunit)->unitnumber != u->unitnumber)
            *(w->weight) = value;
#else
         w->weight = value;
#endif
         w = w->next;
       }
      u = u->next;
     }
    layer = layer->next;
   } 
}

/* Print the network's outputs for patterns first..finish.
   first, finish - 1-based range of patterns to print; finish is clamped
                   to npats
   printheader   - when 1, print an "iterations, file" header line first
   printerrors   - passed through to printoutunits()
   callfromrun   - nonzero when called from the training loop
   When summary is '+' and callfromrun is set, only a one-line summary is
   printed: the iteration count, the learned/unlearned counts and the
   average error per output unit.  That average is reported as 0 when there
   are no patterns or no output units, instead of dividing by zero.
   Otherwise lastprint is set to totaliter, and each pattern in the range
   is loaded, run forward and printed. */
void printpats(first,finish,printheader,printerrors,callfromrun)
int first,finish,printheader,printerrors,callfromrun;
{
  int i;
  double err;

  if (summary == '+' && callfromrun)
     {
       printf("%6d   ",totaliter);
       printf("%6d learned ",npats-unlearnedpats);
       printf("%6d unlearned     ",unlearnedpats);
       if (npats > 0 && last->unitcount > 0)
          err = unscaleint(totaldiff) / (npats * last->unitcount);
       else err = 0.0;
       printf("%7.5lf error/unit\n",err);
       return;
     };
  lastprint = totaliter;
  if (printheader == 1)
     printf("%d iterations, file = %s\n",totaliter,datafilename);
  /* Past the last pattern nextpat() leaves currentpat NULL, which
     setonepat() would dereference, so never walk beyond npats.
     NOTE(review): assumes the pattern lists hold exactly npats nodes. */
  if (finish > npats) finish = npats;
  resetpats();
  for (i=2;i<=first;i++) nextpat();
  for (i=first;i<=finish;i++)
     { 
       nextpat();
       setonepat();
       printf("%3d ",i);
       forward();
       printoutunits(last,printerrors);
     };
}

/* Train for up to n passes over the pattern set.
   n          - maximum number of passes (calls to oneset())
   prpatsrate - print all patterns every prpatsrate passes and after the
                final pass; 0 or negative disables this printing
   Each pass resets totaldiff, calls oneset() and increments totaliter.
   A message is printed on the pass where the weight limit is first hit.
   When all patterns are learned:
   - for update methods other than 'c'/'C', totaliter is first decremented
     (oneset() returns before its batch weight update when nothing is left
     unlearned, so that pass changed no weights);
   - the patterns are printed if requested and not already printed for
     this iteration;
   - the tolerance is reported and the function returns.
   Weights are saved whenever totaliter is a multiple of saverate.  A
   saverate of 0 disables saving rather than dividing by zero. */
void run(n,prpatsrate)
int n;              /* the number of iterations to run */
int prpatsrate;     /* rate at which to print output patterns */

{ int i;
  char wtlimitbefore;

  printf("running . . .\n");
  for (i=1;i<=n;i++)
    {
      totaldiff = 0;
      wtlimitbefore = wtlimithit;
      oneset();
      totaliter = totaliter + 1;
      if (wtlimitbefore == 0 && wtlimithit == 1)
         printf(">>>>> WEIGHT LIMIT HIT <<<<< at %d\n",totaliter);
      if (unlearnedpats == 0)
        {
          if (update != 'c' && update != 'C') totaliter = totaliter - 1;
          if ((prpatsrate > 0) && (lastprint != totaliter))
             printpats(1,npats,1,1,1);
          printf("patterns learned to within %4.2lf",unscale(toler));
          printf(" at iteration %d\n",totaliter);
          return;
        };
      if (saverate != 0 && totaliter % saverate == 0) saveweights();
      if ((prpatsrate > 0) && ((i % prpatsrate == 0) || (i == n)))
         printpats(1,npats,1,1,1);
    };
} 
