/* stats.c - statistical tools for analysing the text
 * 
 * This program is part of Crank, a cryptanalysis tool
 * Copyright (C) 2000 Matthew Russell
 *
 * This program is free software; you can redistribute it and/or modify it 
 * under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License (LICENSE) for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307
 * USA
 */

#include "crank.h"

/* Global variables */

/* Weights that total_error() applies to the single-letter, bigram and
 * trigram errors, in this order: */
/*                           {slft, bift, trift} */
const float error_weight[] = {1.0,  1.0,  2.0};
/* Not used by the code visible here; presumably the stats of the text
 * currently being analysed -- TODO confirm against the callers. */
stats *text_stats;
/* Reference ("standard") frequency tables that make_stats() and
 * transform_stats_with_key() pass to slft_error(), bift_error() and
 * trift_error() as the tables to compare a sample against.  They are not
 * assigned in the code visible here; presumably they are loaded elsewhere
 * before any stats are computed -- TODO confirm. */
float *slft_std;
float *bift_std;
float *trift_std;

/* Return a pseudo-random integer in the inclusive range [a, b], drawn from
 * rand().
 *
 * Callers are expected to pass a <= b.  When b < a the range is empty and a
 * is returned without consuming a random number; without this guard the
 * case b == a - 1 would evaluate rand() % 0, which is undefined behaviour.
 *
 * The value is reduced with '%', so it carries the usual slight modulo bias.
 */
int randnum (int a, int b) {
    if (b < a)
        return a;
    return (a + (rand() % (b - a + 1)));
}

/* Combine the three per-table errors into a single score: each error is
 * scaled by its entry in error_weight ({slft, bift, trift} order) and the
 * scaled values are added together. */
float total_error(float slft_error, float bift_error, float trift_error) {
    float weighted_sum = error_weight[0] * slft_error;

    weighted_sum += error_weight[1] * bift_error;
    weighted_sum += error_weight[2] * trift_error;
    return weighted_sum;
}

/* Sum of squared differences between two single-letter frequency tables.
 * Both tables are indexed directly by character code; only the entries for
 * 'A'..'Z' are compared. */
float slft_error(float *slft_std, float *slft_sample) {
    float sum_sq = 0.0;
    int letter = 'A';

    while (letter <= 'Z') {
        float delta = slft_std[letter] - slft_sample[letter];
        sum_sq += delta * delta;
        letter++;
    }
    return sum_sq;
}

/* Sum of squared differences between two bigram frequency tables.
 * Each table is a flat array in which the pair of character codes (c1, c2)
 * lives at index 26 * c1 + c2; only pairs with both codes in 'A'..'Z' are
 * compared. */
float bift_error(float *bift_std, float *bift_sample) {
    float sum_sq = 0.0;
    int first, second;

    for (first = 'A'; first <= 'Z'; first++) {
        const int row = 26 * first;

        for (second = 'A'; second <= 'Z'; second++) {
            float delta = bift_std[row + second] - bift_sample[row + second];
            sum_sq += delta * delta;
        }
    }
    return sum_sq;
}

/* Sum of squared differences between two trigram frequency tables.
 * Each table is a flat array in which the triple of character codes
 * (c1, c2, c3) lives at index 26 * 26 * c1 + 26 * c2 + c3; only triples with
 * all three codes in 'A'..'Z' are compared. */
float trift_error(float *trift_std, float *trift_sample) {
    float sum_sq = 0.0;
    int first, second, third;

    for (first = 'A'; first <= 'Z'; first++) {
        for (second = 'A'; second <= 'Z'; second++) {
            const int base = 26 * 26 * first + 26 * second;

            for (third = 'A'; third <= 'Z'; third++) {
                float delta = trift_std[base + third] - trift_sample[base + third];
                sum_sq += delta * delta;
            }
        }
    }
    return sum_sq;
}

/* Make frequency tables
 *
 * Scans text and fills in the single-letter (slft), bigram (bift) and
 * trigram (trift) relative-frequency tables.  Letters are folded to upper
 * case; every other character is skipped and does not interrupt a bigram
 * or trigram (so "A-B" still counts the bigram AB).  A NULL text is treated
 * as empty.
 *
 * Table layout (the same one the *_error functions read): entries are
 * addressed by character code,
 *     slft[c], bift[26 * c1 + c2], trift[26 * 26 * c1 + 26 * c2 + c3],
 * for codes 'A'..'Z'.  Only those entries are written; the caller's arrays
 * must be large enough for the highest such index.
 *
 * Each entry is count / total for its table.  When a table has no samples
 * (for example empty text, or fewer than three letters for trift) its
 * entries are set to 0.0 instead of the NaN that 0/0 would produce.
 *
 * Returns the number of letters counted.
 */
int make_ft(char *text, float *slft, float *bift, float *trift) {
    enum { LETTERS = 26 };
    /* Counters are indexed by (letter - 'A'), which keeps the trigram
     * counter at 26^3 ints instead of a multi-megabyte array on the stack. */
    int islft[LETTERS] = {0};
    int ibift[LETTERS][LETTERS] = {{0}};
    int itrift[LETTERS][LETTERS][LETTERS] = {{{0}}};
    int slft_total = 0, bift_total = 0, trift_total = 0;
    int i, j, k;
    int c, pc = -1, ppc = -1;   /* current and two previous letters; -1 = none yet */
    const char *p;

    for (p = text; p && *p != '\0'; p++) {
        /* <ctype.h> functions require an unsigned char value (or EOF) */
        unsigned char raw = (unsigned char) *p;

        if (!isalpha(raw))
            continue;
        c = toupper(raw);
        /* Some locales classify characters outside A-Z as letters; they
         * have no slot in the tables, so skip them. */
        if (c < 'A' || c > 'Z')
            continue;
        c -= 'A';

        islft[c] += 1;
        slft_total++;
        if (pc >= 0) {
            ibift[pc][c] += 1;
            bift_total++;
        }
        if (pc >= 0 && ppc >= 0) {
            itrift[ppc][pc][c] += 1;
            trift_total++;
        }
        ppc = pc;
        pc = c;
    }

    for (i = 0; i < LETTERS; i++) {
        const int ci = 'A' + i;

        for (j = 0; j < LETTERS; j++) {
            const int cj = 'A' + j;

            for (k = 0; k < LETTERS; k++) {
                const int ck = 'A' + k;

                trift[26 * 26 * ci + 26 * cj + ck] = trift_total
                    ? (float) itrift[i][j][k] / (float) trift_total
                    : 0.0f;
            }
            bift[26 * ci + cj] = bift_total
                ? (float) ibift[i][j] / (float) bift_total
                : 0.0f;
        }
        slft[ci] = slft_total
            ? (float) islft[i] / (float) slft_total
            : 0.0f;
    }
    return slft_total; /* i.e. letter_count */
}

/* Copy the frequency tables held in s into the caller-supplied slft, bift
 * and trift arrays.  Only the entries addressed by character codes
 * 'A'..'Z' are copied (slft[c], bift[26 * c1 + c2],
 * trift[26 * 26 * c1 + 26 * c2 + c3]); every other element of the
 * destination arrays is left untouched. */
void dup_ft(stats *s, float *slft, float *bift, float *trift) {
    int first, second, third;

    for (first = 'A'; first <= 'Z'; first++) {
        const int bi_row = 26 * first;

        for (second = 'A'; second <= 'Z'; second++) {
            const int tri_row = 26 * 26 * first + 26 * second;

            for (third = 'A'; third <= 'Z'; third++) {
                trift[tri_row + third] = s->trift[tri_row + third];
            }
            bift[bi_row + second] = s->bift[bi_row + second];
        }
        slft[first] = s->slft[first];
    }
}

/* Build a new stats object whose frequency tables are those of s with every
 * letter relabelled through key.
 *
 * Assumes key and stats are complete: for each character code c in
 * 'A'..'Z', (*key)[c] + ('A' - 'a') is used directly as a table index, so
 * each key entry is presumably the lowercase letter that c maps to (confirm
 * against the key type's definition).  Table entries that no letter maps
 * onto are left uninitialised.
 *
 * The error fields are recomputed against the global slft_std, bift_std and
 * trift_std tables; letter_count is copied from s.
 *
 * Returns a newly allocated stats (release it with free_stats()), or NULL
 * if any allocation fails, in which case nothing is leaked.
 */
stats *transform_stats_with_key(stats *s, key *key) {
    int i, j, k, it, jt, kt;
    const int to_upper = 'A' - 'a';
    const size_t dim = (size_t) 'Z' + 1;   /* tables are indexed by character code */
    float *slft = malloc(dim * sizeof *slft);
    float *bift = malloc(dim * dim * sizeof *bift);
    float *trift = malloc(dim * dim * dim * sizeof *trift);
    stats *new_stats = malloc(sizeof *new_stats);

    if (!slft || !bift || !trift || !new_stats) {
        /* free(NULL) is a no-op, so release whatever was obtained */
        free(slft);
        free(bift);
        free(trift);
        free(new_stats);
        return NULL;
    }

    for (i = (int) 'A'; i <= (int) 'Z'; i++) {
        it = (*key)[i] + to_upper;
        for (j = (int) 'A'; j <= (int) 'Z'; j++) {
            jt = (*key)[j] + to_upper;
            for (k = (int) 'A'; k <= (int) 'Z'; k++) {
                kt = (*key)[k] + to_upper;
                (trift + 26 * 26 * it + 26 * jt)[kt] = (s->trift + 26 * 26 * i + 26 * j)[k];
            }
            (bift + 26 * it)[jt] = (s->bift + 26 * i)[j];
        }
        slft[it] = s->slft[i];
    }
    new_stats->slft = slft;
    new_stats->bift = bift;
    new_stats->trift = trift;

    new_stats->slft_error = slft_error(slft_std, new_stats->slft);
    new_stats->bift_error = bift_error(bift_std, new_stats->bift);
    new_stats->trift_error = trift_error(trift_std, new_stats->trift);
    new_stats->total_error = total_error(new_stats->slft_error, new_stats->bift_error, new_stats->trift_error);
    new_stats->letter_count = s->letter_count;
    return new_stats;
}

/* Calculate stats for a section of text
 *
 * Allocates the three frequency tables, fills them from text with
 * make_ft(), and computes the per-table and total errors against the global
 * slft_std, bift_std and trift_std tables.  letter_count is the value
 * returned by make_ft().
 *
 * Returns a newly allocated stats (release it with free_stats()), or NULL
 * if any allocation fails, in which case nothing is leaked.
 */
stats *make_stats(char *text) {
    const size_t dim = (size_t) 'Z' + 1;   /* tables are indexed by character code */
    float *slft = malloc(dim * sizeof *slft);
    float *bift = malloc(dim * dim * sizeof *bift);
    float *trift = malloc(dim * dim * dim * sizeof *trift);
    stats *new_stats = malloc(sizeof *new_stats);

    if (!slft || !bift || !trift || !new_stats) {
        /* free(NULL) is a no-op, so release whatever was obtained */
        free(slft);
        free(bift);
        free(trift);
        free(new_stats);
        return NULL;
    }

    new_stats->letter_count = make_ft(text, slft, bift, trift);
    new_stats->slft = slft;
    new_stats->bift = bift;
    new_stats->trift = trift;

    new_stats->slft_error = slft_error(slft_std, new_stats->slft);
    new_stats->bift_error = bift_error(bift_std, new_stats->bift);
    new_stats->trift_error = trift_error(trift_std, new_stats->trift);
    new_stats->total_error = total_error(new_stats->slft_error, new_stats->bift_error, new_stats->trift_error);

    return new_stats;
}

/* Release a stats object together with its three frequency tables.
 * Passing a null pointer is allowed and does nothing; table pointers that
 * are null are likewise harmless, because free() ignores null. */
void free_stats(stats *the_stats) {
    if (the_stats) {
        free(the_stats->slft);
        free(the_stats->bift);
        free(the_stats->trift);
        free(the_stats);
    }
}

/* Make an independent copy of the_stats.
 *
 * Fresh frequency tables are allocated and filled with dup_ft(); the error
 * fields and letter_count are copied as they are, not recomputed.
 *
 * Returns a newly allocated stats (release it with free_stats()), or NULL
 * if the_stats is NULL or any allocation fails, in which case nothing is
 * leaked.
 */
stats *dup_stats(stats *the_stats) {
    const size_t dim = (size_t) 'Z' + 1;   /* tables are indexed by character code */
    float *slft;
    float *bift;
    float *trift;
    stats *new_stats;

    if (!the_stats)
        return NULL;

    slft = malloc(dim * sizeof *slft);
    bift = malloc(dim * dim * sizeof *bift);
    trift = malloc(dim * dim * dim * sizeof *trift);
    new_stats = malloc(sizeof *new_stats);

    if (!slft || !bift || !trift || !new_stats) {
        /* free(NULL) is a no-op, so release whatever was obtained */
        free(slft);
        free(bift);
        free(trift);
        free(new_stats);
        return NULL;
    }

    dup_ft(the_stats, slft, bift, trift);

    new_stats->slft = slft;
    new_stats->bift = bift;
    new_stats->trift = trift;

    new_stats->slft_error = the_stats->slft_error;
    new_stats->bift_error = the_stats->bift_error;
    new_stats->trift_error = the_stats->trift_error;
    new_stats->total_error = the_stats->total_error;
    new_stats->letter_count = the_stats->letter_count;
    return new_stats;
}
