/*=============================*/
/*           NETS              */
/*                             */
/* a product of the AI Section */
/* NASA, Johnson Space Center  */
/*                             */
/* principal author:           */
/*       Paul Baffes           */
/*                             */
/* contributing authors:       */
/*      Bryan Dulock           */
/*      Chris Ortiz            */
/*=============================*/


/*
----------------------------------------------------------------------
  This is the main module for the neural net.  All of the main code   
  plus documentation on the system features lives here.

  A note on nomenclature: Most of the routine calls which you will 
  see in the code throughout these files will be "prefixed" by some 
  letters and an underscore.  That is one of my conventions for indicating
  which file contains the code for the subroutine in question. Because
  I have several files, this makes tracing problems and debugging
  files easier to accomplish.  

  This file is one of 14 source files which make up the back
  propagation code. These files are the following:

      activate.c              semi-linear activation function
      buildnet.c
      compile.c
      convert.c               conversion routines
      layer.c                 layer manipulation and creation
      lnrate.c
      net.c                   net manipulation, learning, propagation
      netio.c                 I/O routines, file handlers
      netmain.c               main routines, menus, user interface
      pairs.c
      prop.c
      shownet.c
      teach.c
      weights.c               weights manipulation and creation

  All of these are covered in detail in the respective files. The prefix
  codes for each of the files are as follows:

      activate.c              "A_"
      buildnet.c              "B_"
      compile.c               "CC_"
      convert.c               "C_"
      dribble.c               "D_"
      net.c                   "N_"
      netio.c                 "IO_"
      netmain.c               *NONE*
      pairs.c                 "PA_"
      parser.c                "PS_"
      prop.c                  "P_"
      layer.c                 "L_"
      lnrate.c                "LR_"
      shownet.c               "S_"
      teach.c                 "T_"
      weights.c               "W_"
      sysdep.c                "sys_"    (system dependent code)

  The rest of this file is organized into the following groups:

  (1) include files
  (2) externed functions
  (3) global variables
  (4) subroutines
----------------------------------------------------------------------
*/


/*
----------------------------------------------------------------------
  INCLUDE FILES
----------------------------------------------------------------------
*/
#include  "common.h"
#include  "netio.h"
#include  "weights.h"
#include  "layer.h"
#include  "net.h"


/*
----------------------------------------------------------------------
  EXTERNED FUNCTIONS
----------------------------------------------------------------------
*/
/*
  Old-style declarations (empty parameter lists): the compiler does not
  check argument counts or types at these call sites.  The file each
  routine lives in follows from the prefix table at the top of this file.
*/

/* buildnet.c ("B_"): net creation and release */
extern Net    *B_create_net();
extern Net    *B_free_net();

/* net.c ("N_"): querying the net, weight load/save, layer lookup */
extern void   N_query_net();
extern int    N_reset_wts();
extern void   N_save_wts();
extern Layer  *N_get_layer();

/* teach.c ("T_"): training */
extern void   T_teach_net();

/* shownet.c ("S_"): display of weights, biases, and configuration */
extern void   S_show_weights();
extern void   S_show_biases();
extern void   S_show_net();

/* netio.c ("IO_"): user input, file-name prompting, and output */
extern float  IO_my_get_float();
extern int    IO_my_get_int();
extern int    IO_get_default_int();
extern int    IO_get_num_cycles();
extern void   IO_my_get_string();
extern void   IO_set_filenames();
extern void   IO_get_io_name();
extern void   IO_get_wts_name();
extern void   IO_print();

/* pairs.c ("PA_"): I/O pair handling */
extern void   PA_initialize();
extern void   PA_setup_iopairs();
extern void   PA_reset_iopairs();
extern void   PA_randomize_file();

/* dribble.c ("D_"): dribble settings */
extern void   D_initialize();
extern void   D_dribble_status(); 

extern Sint   C_float_to_Sint();       /* convert.c ("C_")            */
extern void   L_modify_learning();     /* layer.c ("L_")              */
extern void   sys_init_rand();         /* sysdep.c ("sys_"), system   */
                                       /* dependent random setup      */
extern void   CC_create_delivery();    /* compile.c ("CC_")           */


/*
----------------------------------------------------------------------
  GLOBAL VARIABLES
----------------------------------------------------------------------
 Next come the global variables, defined in other source files, which
  need to be referenced here.  In general I tried to keep the number
  of globals to a minimum since they can be messy, but many of the
  io functions needed to keep "state" variables for the lifetime of
  the program execution.  Examples are the default file names which
  are used when prompting the user.  These names are read here, and
  referenced in net.c as well as netio.c, and thus I needed to be
  able to pass them around.  I could have simply left them as global  
  values and referenced them as needed, but instead I passed them     
  explicitly to the non-io routines that needed them.  I could just   
  as easily have declared them here and made them external to the     
  io package, but since file names and IO are so intimately tied, I   
  thought it more logical to declare these variables with the other   
  IO code.                                                            
----------------------------------------------------------------------
*/
/* default file names, defined with the other IO code in netio.c */
extern char  net_config[];                 /* net configuration file      */
extern char  net_iop[];                    /* I/O pairs file              */
extern char  net_fwt[];                    /* weights file, FAST format   */
extern char  net_pwt[];                    /* weights file, PORTABLE fmt  */
extern char  IO_str[MAX_LINE_SIZE];        /* message buffer: callers     */
                                           /* sprintf into it and then    */
                                           /* call IO_print to emit it    */

static Net  *the_net;          /* variable which holds ptr to the net   */
                               /* (global to this netmain routine only) */
                               /* NULL until the 'c' command builds one */
                               

/*
======================================================================
  ROUTINES IN NETMAIN.C
======================================================================
  The routines in this file are grouped below by function.  NO ROUTINES
  ARE PREFIXED IN THIS FILE.

  The main routine is just below.  It basically amounts to a read-eval-
  print loop, much like an interpreter. After initialization, this code
  loops until the user quits, calling "read_choice" and then "evaluate"
  ("print_menu" is called once at startup and again on the 'm' command).
  The other routines are "initialize", "cleanup", and "check_net_ptr";
  the last is used during evaluate to verify that a valid net exists
  before attempting some operation.
======================================================================
*/


main()
BEGIN
   void  initialize(), print_menu(), evaluate(), cleanup();
   char  read_choice();
   char  c;
      
   initialize();
   print_menu();
   while (TRUE) BEGIN
      c = read_choice();
      if (c == 'q') BEGIN
         cleanup();
         break;
      ENDIF
      evaluate(c);
   ENDWHILE

END /* main */


void initialize()
/*
-----------------------------------------------------------------
 One-time startup for the main program.  Only a single net is
  handled at present (the NET structure leaves room for more), so
  the only net state set up here is the file-static 'the_net',
  which starts out empty (NULL); supporting several nets would
  require changing this.
 The remaining calls, in order, set up the random number generator
  (sysdep.c -- this part IS SYSTEM DEPENDENT AND MAY NOT BE
  PORTABLE), then the pairs module, then the dribble module.
-----------------------------------------------------------------
*/
BEGIN
   /* no net exists until the user asks for one with 'c' */
   the_net = NULL;

   /* system dependent random number setup */
   sys_init_rand();

   /* per-module startup: pairs.c ("PA_") first, then dribble.c ("D_") */
   PA_initialize();
   D_initialize();

END /* initialize */


void print_menu()
/*
----------------------------------------------------------------------
 This routine prints out the variety of choices available to the user 
  for operating on/creating his net.                                  
----------------------------------------------------------------------
*/
BEGIN
   sprintf(IO_str, "\n\n");
   IO_print(0);
   sprintf(IO_str, "NASA-JSC Artificial Intelligence Section\n");
   IO_print(0);
   sprintf(IO_str, "NETS Back Propagation Simulator Version 2.0\n\n");
   IO_print(0);
   sprintf(IO_str, "b -- show bias values\n");
   IO_print(0);
   sprintf(IO_str, "c -- create a net\n");
   IO_print(0);
   sprintf(IO_str, "d -- set dribble parameters\n");
   IO_print(0);
   sprintf(IO_str, "g -- generate delivery code\n");
   IO_print(0);
   sprintf(IO_str, "i -- setup I/O pairs\n");
   IO_print(0);
   sprintf(IO_str, "j -- reset I/O pairs\n");
   IO_print(0);
   sprintf(IO_str, "l -- change learning rate\n");
   IO_print(0);
   sprintf(IO_str, "m -- print this menu\n");
   IO_print(0);
   sprintf(IO_str, "n -- show net configuration\n");
   IO_print(0);
   sprintf(IO_str, "o -- reorder I/O pairs for training\n");
   IO_print(0);
   sprintf(IO_str, "p -- propagate an input through the net\n");
   IO_print(0);
   sprintf(IO_str, "r -- reset weights from a file\n");
   IO_print(0);
   sprintf(IO_str, "s -- save weights to a file\n");
   IO_print(0);
   sprintf(IO_str, "t -- teach the net\n");
   IO_print(0);
   sprintf(IO_str, "w -- show weights between two layers \n");
   IO_print(0);
   sprintf(IO_str, "q -- quit program\n");
   IO_print(0);

END /* print_menu */


char  read_choice()
/*
----------------------------------------------------------------------
 Now, this routine will seem simple, but I wanted to separate it out  
  in the event that the future brings more sophisticated input.  Also 
  this program needs to make sure that the input read in is legal,    
  which is really not part of the print_menu function at all.         
----------------------------------------------------------------------
*/
BEGIN
   char  result[MAX_LINE_SIZE];

   sprintf(IO_str, "\nNETS Choice(m = menu)? ");
   IO_print(0);
   gets(result);
   while (result[0] != 'b' 
          && result[0] != 'c' 
          && result[0] != 'd' 
          && result[0] != 'g' 
          && result[0] != 'i' 
          && result[0] != 'j' 
          && result[0] != 'l' 
          && result[0] != 'm' 
          && result[0] != 'n' 
          && result[0] != 'o' 
          && result[0] != 'p' 
          && result[0] != 'r' 
          && result[0] != 's' 
          && result[0] != 't' 
          && result[0] != 'w' 
          && result[0] != 'q') BEGIN
      sprintf(IO_str, "\nSorry, I don't understand that input.\n");
      IO_print(0);
      sprintf(IO_str, "Please try again: ");
      IO_print(0);
      gets(result);
   ENDWHILE
   return(result[0]);

END /* read_choice */


void  evaluate(input)
char  input;
/*
----------------------------------------------------------------------
 Evaluates the input command and calls the appropriate function to    
  carry out the actions                                               
----------------------------------------------------------------------
*/
BEGIN
   int    t1, t2, check_net_ptr();
   float  tf1;
   char   filename[MAX_WORD_SIZE], file2[MAX_WORD_SIZE], 
          tmp_str[MAX_WORD_SIZE];
   FILE   *fp; /* for query from file */

   switch (input) BEGIN
      case 'b': BEGIN
         if (check_net_ptr(the_net) == OK) BEGIN
            if (the_net->use_biases == FALSE) BEGIN
               sprintf(IO_str, "\n*** network has no biases");
               IO_print(0);
            ENDIF
            else BEGIN
               sprintf(IO_str, "\n   layer number? ");
               IO_print(0);
               t1 = IO_my_get_int();
               S_show_biases(the_net, t1);
            ENDELSE
         ENDIF
         break;
      ENDCASE
      case 'c': BEGIN
         the_net = B_free_net(the_net);
         IO_set_filenames();
         the_net = B_create_net(1, net_config);
         break;
      ENDCASE
      case 'd': BEGIN
         D_dribble_status();
         break;
      ENDCASE
      case 'g': BEGIN
         if (check_net_ptr(the_net) == OK) BEGIN
            /*--------------------------------*/
            /* use the net_config filename as */
            /* the configuration file.        */
            /*--------------------------------*/
             strcpy(filename, net_config);
            /*-----------------------------*/
            /* prompt for the weights file */
            /*-----------------------------*/
            sprintf(IO_str, "\n   Enter filename with PORTABLE weight values");
            IO_print(0);
            sprintf(IO_str, " (default=%s): ", net_pwt);
            IO_print(0);
            IO_my_get_string(file2);
            if (file2[0] == ENDSTRING) 
               strcpy(file2, net_pwt);

            /*-------------------------*/
            /* call the create routine */
            /*-------------------------*/
            CC_create_delivery(the_net, filename, file2);
         ENDIF
         break;
      ENDCASE
      case 'i': BEGIN
         if (check_net_ptr(the_net) == OK) BEGIN
            IO_get_io_name();
            PA_setup_iopairs(the_net, net_iop);
         ENDIF
         break;
      ENDCASE
      case 'j': BEGIN
         if (check_net_ptr(the_net) == OK) BEGIN
            sprintf(IO_str, "\n   Resetting I/O pairs to 'workfile.net'");
            IO_print(0);
            PA_reset_iopairs(the_net, "workfile.net");
         ENDIF
         break;
      ENDCASE
      case 'l': BEGIN
         if (check_net_ptr(the_net) == OK) BEGIN
            sprintf(IO_str, "\n   Layer number: ");
            IO_print(0);
            t1 = IO_my_get_int();
            L_modify_learning(t1, N_get_layer(the_net, t1));
         ENDIF
         break;
      ENDCASE
      case 'm' : BEGIN
         print_menu();
         break;
      ENDCASE
      case 'n': BEGIN
         if(check_net_ptr(the_net) == OK)
            S_show_net(the_net);
         break;
      ENDCASE
      case 'o': BEGIN
         if(check_net_ptr(the_net) == OK)
            if (the_net->num_io_pairs <= 0) BEGIN
               sprintf(IO_str, "\n\n*** no valid set of io pairs to randomize ***\n");
               IO_print(0);
            ENDIF
            else BEGIN
               sprintf(IO_str, "\n   Enter filename for reordered IO pairs: ");
               IO_print(0);
               IO_my_get_string(filename);
               while (filename[0] == ENDSTRING) BEGIN
                  sprintf(IO_str, "\n      try again: ");
                  IO_print(0);
                  IO_my_get_string(filename);
               ENDWHILE
               PA_randomize_file(filename, "workfile.net", the_net->num_io_pairs,
                                 the_net->num_inputs + the_net->num_outputs);
            ENDELSE

         break;
      ENDCASE
      case 'p': BEGIN
         if (check_net_ptr(the_net) == OK) BEGIN
            /*----------------------------------*/
            /* Get the input file/screen source */
            /* and number of IO pairs to prop.  */
            /*----------------------------------*/
            sprintf(IO_str, "\n   Enter filename with INPUT(default=manual entry): ");
            IO_print(0);
            IO_my_get_string(filename);
            
            if (filename[0] != ENDSTRING) BEGIN
               sprintf(IO_str, "\n   Enter number of I/O pairs to propagate (default=all): ");
               IO_print(0);
               t1 = IO_get_default_int(-1);
            ENDIF
            
            /*----------------------------------------*/
            /* Get the output file/screen destination */
            /*----------------------------------------*/
            sprintf(IO_str, "\n   Enter OUTPUT destination file(default=screen): ");
            IO_print(0);
            IO_my_get_string(file2);
            if (file2[0] == ENDSTRING)
               fp = NULL;
            else 
               fp = fopen(file2, "wt");
            
            /*-------------------------------*/
            /* call the query function, then */
            /* tidy up output file if used   */
            /*-------------------------------*/
            N_query_net(the_net, filename, fp, t1);
            if (fp != NULL) fclose(fp);
         ENDIF
         break;
      ENDCASE
      case 'r': BEGIN
         if (check_net_ptr(the_net) == OK) BEGIN
            sprintf(IO_str,"\n   Load weights from FAST or PORTABLE format");
            IO_print(0);
            sprintf(IO_str, "(f/p, default=f)? ");
            IO_print(0);
            IO_my_get_string(tmp_str);
            if ((tmp_str[0] == 'p') || (tmp_str[0] == 'P'))
               t1 = PORTABLE_FORMAT;
            else t1 = FAST_FORMAT;
            sprintf(IO_str, "\n   Enter name of file with weight values");
            IO_print(0);
            IO_get_wts_name(t1);
            t1 = ( (t1 == FAST_FORMAT) 
                   ? N_reset_wts(the_net, net_fwt, t1)
                   : N_reset_wts(the_net, net_pwt, t1) );
            if (t1 == ERROR) BEGIN
               sprintf(IO_str, "\n*** weight resetting was incomplete ***");
               IO_print(0);
            ENDIF
         ENDIF
         break;
      ENDCASE
      case 's': BEGIN
         if (check_net_ptr(the_net) == OK) BEGIN
            sprintf(IO_str,"\n   Save weights in FAST or PORTABLE format");
            IO_print(0);
            sprintf(IO_str, "(f/p, default=f)? ");
            IO_print(0);
            IO_my_get_string(tmp_str);
            if ((tmp_str[0] == 'p') || (tmp_str[0] == 'P'))
               t1 = PORTABLE_FORMAT;
            else t1 = FAST_FORMAT;
            sprintf(IO_str, "\n   Enter file name for storing weights");
            IO_print(0);
            IO_get_wts_name(t1);
            if (t1 == FAST_FORMAT)
               N_save_wts(the_net, net_fwt, t1);
            else N_save_wts(the_net, net_pwt, t1);
         ENDIF
         break;
      ENDCASE
      case 't': BEGIN
         if(check_net_ptr(the_net) == OK) BEGIN
            if (the_net->num_io_pairs <= 0) BEGIN
               sprintf(IO_str, "\n\n*** no valid set of io pairs to teach ***\n");
               IO_print(0);
            ENDIF
            else BEGIN
               sprintf(IO_str, "\n   Enter constraint error: ");
               IO_print(0);
               tf1 = IO_my_get_float();
               sprintf(IO_str, "\n   Enter max number of cycles(default=%d): ", MAX_CYCLES);
               IO_print(0);
               t1  = IO_get_num_cycles();
               sprintf(IO_str, "\n   Enter cycle increment for showing errors");
               IO_print(0);
               sprintf(IO_str, "(default=1): ");
               IO_print(0);
               t2  = IO_get_default_int(1);
               T_teach_net(the_net, C_float_to_Sint(tf1), t1, t2);
            ENDELSE
         ENDIF
         break;
      ENDCASE
      case 'w': BEGIN
         if (check_net_ptr(the_net) == OK) BEGIN
            sprintf(IO_str, "\n   source layer? ");
            IO_print(0);
            t1 = IO_my_get_int();
            sprintf(IO_str, "\n   target layer? ");
            IO_print(0);
            t2 = IO_my_get_int();
            S_show_weights(the_net, t1, t2);
         ENDIF
         break;
      ENDCASE
      default :
         break;
   ENDSWITCH

END /* evaluate */


int  check_net_ptr(ptr_net)
Net  *ptr_net;
/*
----------------------------------------------------------------------
 This small routine is simply a check to see whether or not the arg   
  passed in is actually a valid Net pointer.  If a net has not been   
  created, or if there was an error in its creation, then no          
  operations ought to be allowed using the net pointer.               
 Returns 0 if the ptr is INvalid, otherwise returns a 1.  Note that   
  a net which has an error will return a non-NULL pointer, but the ID 
  of such a net will be negative. (see return values of B_create_net) 
----------------------------------------------------------------------
*/
BEGIN
   if (ptr_net == NULL) BEGIN
      sprintf(IO_str, "\n\n*** no valid net exists ***\n");
      IO_print(0);
      return(0);
   ENDIF
   if (ptr_net->ID == ERROR) BEGIN
      sprintf(IO_str, "\n\n*** no valid net exists ***\n");
      IO_print(0);
      return(0);
   ENDIF
   return(OK);

END /* check_net_ptr */


void  cleanup()
/*
----------------------------------------------------------------------
   Before quitting the program entirely, this routine ensures that all
   the loose ends are accounted for. Right now, that consists of freeing
   all the memory currently used for the network.
----------------------------------------------------------------------
*/
BEGIN

   the_net = B_free_net(the_net);
   
END /* cleanup */

