/* Evaluate the probabilistic distribution of hitpoints for one
 * Wesnoth unit (the attacker) in combat with another (the defender).
 *
 * THE PROBLEM:
 * 
 * Unit A fights unit B.  Each has 'hp' hitpoints, 'num_attacks'
 * attacks, each of which does 'damage' hitpoints of damage, and has a
 * probability 'hit_chance' of actually hitting.  A attacks first, if
 * B survives (ie. hp still > 0), it attacks back, alternating until
 * all attacks are done.  So if A has 2 attacks and B has 3, attacks go
 * A->B, B->A, A->B, B->A, B->A.
 * 
 * There are three twists:
 * (1) Some units can "drain" other units: when they hit the opponent,
 *     they gain half of the damage they did (up to their maximum).
 * 
 * (2) Some units can "slow" other units: the first time they hit
 *     their opponent, that opponent begins doing half its normal damage.
 *
 * (3) Some units are "berserk".  If either unit is berserk, the combat
 *     is repeated up to 30 times or (as always) until someone dies.
 *
 * THE ALGORITHM:
 * 
 * The algorithm uses a matrix of probabilities, # hitpoints A x #
 * hitpoints B.  This starts with all zeroes, except [HP A, HP B]
 * which is 1.0.  When A attacks B, it has a certain probability of
 * hitting ('hit_chance') and doing damage ('damage'), so we move
 * 'hit_chance' portion of that probability across 'damage' columns
 * (B's hitpoints are the column index). eg:
 *
 * A has 4 HP, B has 5 HP.  A does 2 damage, 30% chance of hitting.
 * So we move 30% of every element across two columns:
 *
 * 0.0  0.0  0.0  0.0  0.0       0.0  0.0  0.0  0.0  0.0
 * 0.0  0.0  0.0  0.0  0.0   =>  0.0  0.0  0.0  0.0  0.0
 * 0.0  0.0  0.0  0.0  0.0       0.0  0.0  0.0  0.0  0.0
 * 0.0  0.0  0.0  0.0  1.0       0.0  0.0  0.3  0.0  0.7
 *
 * Now B attacks.  B does 7 damage with a 10% chance of hitting.  You
 * can see we never move below 0 on the matrix:
 * 
 * 0.0  0.0  0.0  0.0  0.0       0.0  0.0  0.03 0.0  0.07
 * 0.0  0.0  0.0  0.0  0.0   =>  0.0  0.0  0.0  0.0  0.0
 * 0.0  0.0  0.0  0.0  0.0       0.0  0.0  0.0  0.0  0.0
 * 0.0  0.0  0.3  0.0  0.7       0.0  0.0  0.27 0.0  0.63
 * 
 * If a unit drains, the elements are moved on the other axis as well,
 * to reflect the striker's increase (up to the maximum).  If B
 * drained, we would instead see:
 * 
 * 0.0  0.0  0.0  0.0  0.0       0.0  0.0  0.0  0.0  0.1
 * 0.0  0.0  0.0  0.0  0.0   =>  0.0  0.0  0.0  0.0  0.0
 * 0.0  0.0  0.0  0.0  0.0       0.0  0.0  0.0  0.0  0.0
 * 0.0  0.0  0.3  0.0  0.7       0.0  0.0  0.27 0.0  0.63
 * 
 * We simply repeat this algorithm until all attacks are done.
 *
 * If a unit hits another, and has 'slow' ability, we keep a separate
 * matrix which we add the result to.  We refer to this as a separate
 * 'plane'.  On their 'slow' plane, units will only do half the damage
 * they normally do, so we move numbers by half as much.  There are
 * four planes: NEITHER_SLOWED, A_SLOWED, B_SLOWED and BOTH_SLOWED.
 *
 * Algorithm by Yogin.  Typing by Rusty.
 */
#define _GNU_SOURCE
#include <stdbool.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <assert.h>
#include <sys/time.h>
#include <time.h>

struct matrix;

#define QUIET
#ifndef QUIET
#define debug printf

/* Print one line of a text histogram: the hitpoint value, the
 * probability as a whole percentage, and a bar of '#' characters whose
 * length grows with prob*100.
 *
 * hp:   hitpoint value this line describes.
 * prob: probability of ending on that value.
 */
static void graph_prob(unsigned int hp, double prob)
{
        unsigned int i;

        /* %u expects an unsigned argument, so convert to unsigned
         * rather than int. */
        printf("%3u %2u%% |", hp, (unsigned int)(prob*100));
        for (i = 0; i < prob * 100; i++)
                putchar('#');
        putchar('\n');
}

/* Print the hitpoint histograms of both combatants, highest hitpoint
 * value first.
 *
 * res_a: attacker probabilities, indexed 0..amax.
 * res_b: defender probabilities, indexed 0..bmax.
 */
static void draw_results(double res_a[], unsigned amax,
                         double res_b[], unsigned bmax)
{
        unsigned int hp;

        printf("Attacker HP:\n");
        hp = amax;
        do {
                graph_prob(hp, res_a[hp]);
        } while (hp-- != 0);

        printf("Defender HP:\n");
        hp = bmax;
        do {
                graph_prob(hp, res_b[hp]);
        } while (hp-- != 0);
}

/* Dump every allocated plane of the matrix through debug(): a label
 * line followed by one line per row, each cell printed as value*100.
 *
 * NOTE(review): at this point in the file 'struct matrix' is only
 * forward-declared and val() has not been declared yet, so this debug
 * variant looks like it cannot compile when QUIET is undefined.  val()
 * also takes a non-const matrix while 'matrix' here is const.  Confirm,
 * and move this function below val() if the debug build is needed.
 */
static void draw_matrix(const struct matrix *matrix)
{
        unsigned int row, col, m;
        /* Labels indexed by plane number. */
        const char *names[] __attribute__((unused))
                = { "NEITHER_SLOWED", "A_SLOWED", "B_SLOWED", "BOTH_SLOWED" };

        for (m = 0; m < 4; m++) {
                /* Planes that were not allocated are NULL: skip them. */
                if (!matrix->plane[m])
                        continue;
                debug("%s:\n", names[m]);
                for (row = 0; row < matrix->rows; row++) {
                        debug("  ");
                        for (col = 0; col < matrix->cols; col++)
                                debug("%4.3g", *val(matrix, m, row, col)*100);
                        debug("\n");
                }
        }
}
#else
#define debug(...)
/* Quiet build: dumping the matrix is a no-op. */
static void draw_matrix(const struct matrix *matrix)
{
        (void)matrix;
}
/* Quiet build: the result histograms are not printed and the arrays
 * are left untouched. */
static void draw_results(double res_a[], unsigned amax,
                         double res_b[], unsigned bmax)
{
        (void)res_a;
        (void)amax;
        (void)res_b;
        (void)bmax;
}
#endif

/* max() modelled on the Linux kernel's strictly type-checked macro.
 *
 * Each argument is evaluated exactly once.  The otherwise
 * "unnecessary" pointer comparison makes the compiler complain when x
 * and y have different types.
 *
 * This relies on the GNU statement-expression extension.  __typeof__
 * is spelled with underscores because the plain 'typeof' keyword is
 * not recognised in strict ISO modes such as -std=c99 or -std=c11.
 */
#define max(x,y) ({ \
        __typeof__(x) _x = (x);     \
        __typeof__(y) _y = (y);     \
        (void) (&_x == &_y);            \
        _x > _y ? _x : _y; })

/* Combat statistics for one unit taking part in a fight. */
struct unit
{
        unsigned damage;        /* hitpoints removed by each hit */
        unsigned num_attacks;   /* strikes made per round of combat */
        unsigned hp;            /* hitpoints when combat starts */
        unsigned max_hp;        /* largest hitpoint value tracked */
        double hit_chance;      /* probability that a strike hits */
        bool slows;             /* hits slow the opponent */
        bool drains;            /* hits give hitpoints back */
        bool berserk;           /* combat is repeated for extra rounds */
};

/* The probability matrix is split into up to four "planes", one per
 * combination of slowed states.  Bit 0 of the plane number is set
 * when A is slowed, bit 1 when B is slowed. */
enum {
        NEITHER_SLOWED = 0,
        A_SLOWED = 1,
        B_SLOWED = 2,
        BOTH_SLOWED = A_SLOWED | B_SLOWED
};

/* Probability matrix.  Each plane is a flat rows*cols array stored
 * row by row; a plane that is not needed is left NULL. */
struct matrix
{
        unsigned int rows;      /* number of rows in every plane */
        unsigned int cols;      /* number of columns in every plane */
        double *plane[4];       /* indexed by plane number 0..3 */
};

/* Allocate a new array of 'size' doubles, all initialized to zero.
 *
 * Never returns NULL for a non-zero size: callers index the result
 * without checking it, so running out of memory is reported on stderr
 * and the program is aborted.  The caller owns the array and releases
 * it with free().
 */
static double *new_arr(unsigned int size)
{
        double *arr = calloc(size, sizeof *arr);

        /* calloc() may legitimately return NULL for a zero-sized request. */
        if (!arr && size != 0) {
                fprintf(stderr, "new_arr: out of memory (%u doubles)\n", size);
                abort();
        }
        return arr;
}

/* Build a zeroed probability matrix with a_max_hp+1 rows and
 * b_max_hp+1 columns.
 *
 * Plane 0 (neither slowed) always exists.  Plane 1 (A slowed) is only
 * reachable when B slows, plane 2 (B slowed) only when A slows, and
 * plane 3 (both slowed) only when both units slow, so the others are
 * left NULL.
 *
 * Never returns NULL: an allocation failure is reported on stderr and
 * the program is aborted.  Release the result with free_matrix().
 */
static struct matrix *new_matrix(unsigned int a_max_hp,
                                 unsigned int b_max_hp,
                                 bool a_slows, bool b_slows)
{
        struct matrix *m = malloc(sizeof *m);
        unsigned int cells;

        if (!m) {
                fprintf(stderr, "new_matrix: out of memory\n");
                abort();
        }

        m->rows = a_max_hp+1;
        m->cols = b_max_hp+1;
        cells = m->rows*m->cols;

        m->plane[0] = new_arr(cells);
        m->plane[1] = b_slows ? new_arr(cells) : NULL;
        m->plane[2] = a_slows ? new_arr(cells) : NULL;
        /* With a single slowing unit nothing is ever moved into the
         * "both slowed" plane, so it is only needed when both slow. */
        m->plane[3] = (a_slows && b_slows) ? new_arr(cells) : NULL;

        if (!m->plane[0]
            || (b_slows && !m->plane[1])
            || (a_slows && !m->plane[2])
            || (a_slows && b_slows && !m->plane[3])) {
                fprintf(stderr, "new_matrix: out of memory\n");
                abort();
        }

        return m;
}

/* Release a matrix and all of its planes.  Planes that were never
 * allocated are NULL, which free() accepts. */
static void free_matrix(struct matrix *m)
{
        unsigned int i;

        for (i = 0; i < sizeof m->plane / sizeof m->plane[0]; i++)
                free(m->plane[i]);
        free(m);
}

static double *val(struct matrix *m, unsigned plane,
                   unsigned row, unsigned col)
{
        return &m->plane[plane][row * m->cols + col];
}

/* Move the fraction 'prob' of the probability stored in the source
 * cell (src_plane, row_src, col_src) into the destination cell
 * (dst_plane, row_dst, col_dst).
 *
 * A destination beyond the last row or column is clamped onto it.
 */
static void xfer(struct matrix *m,
                 unsigned dst_plane, unsigned src_plane,
                 unsigned row_dst, unsigned col_dst,
                 unsigned row_src, unsigned col_src,
                 double prob)
{
        double *from, *to;
        double moved;

        /* FIXME: This is here for drain. */
        if (col_dst >= m->cols)
                col_dst = m->cols - 1;
        if (row_dst >= m->rows)
                row_dst = m->rows - 1;

        from = val(m, src_plane, row_src, col_src);
        to = val(m, dst_plane, row_dst, col_dst);

        /* Subtract first, then add: from and to may be the same cell. */
        moved = *from * prob;
        *from -= moved;
        *to += moved;

        if (moved != 0)
                debug("Shifted %4.3g from (%u,%u) to (%u,%u)\n",
                      moved, row_src, col_src, row_dst, col_dst);
}

/* Move the fraction 'prob' of every cell outside row 0 and column 0
 * of plane 'src' 'damage' columns towards column 0 in plane 'dst'.
 *
 * Cells within 'damage' columns of column 0 land on column 0.  With
 * 'drain' set, the destination row is raised by half of the number of
 * columns actually moved.
 */
static void shift_matrix_cols(struct matrix *matrix,
                              unsigned dst, unsigned src,
                              unsigned damage, double prob, bool drain)
{
        unsigned int full_gain;
        unsigned int row, col;

        if (damage >= matrix->cols)
                damage = matrix->cols - 1;

        /* Row gain for a blow that lands at full (clamped) damage. */
        full_gain = drain ? damage/2 : 0;

        /* Highest row first: drain writes into rows that were already
         * processed, which matters when src == dst. */
        for (row = matrix->rows - 1; row > 0; row--) {
                /* Columns 1..damage all end up on column 0. */
                for (col = 1; col <= damage; col++)
                        xfer(matrix, dst, src,
                             row + (drain ? col/2 : 0), 0,
                             row, col, prob);

                /* The remaining columns move by the full damage. */
                for (col = damage+1; col < matrix->cols; col++)
                        xfer(matrix, dst, src,
                             row + full_gain, col - damage,
                             row, col, prob);
        }
}

/* Apply one strike by A against B: with probability 'hit_chance', B
 * loses (up to) 'damage' hitpoints.
 *
 * Every allocated plane is shifted along its columns.  Planes in which
 * A is slowed use half the damage.  If A slows, the result lands on
 * the matching "B slowed" plane.
 */
static void receive_blow_b(struct matrix *matrix,
                           unsigned damage, double hit_chance,
                           bool a_slows, bool a_drains)
{
        int from;

        /* Highest plane first, so a plane that has just received
         * probability is not processed a second time. */
        for (from = 3; from >= 0; from--) {
                int to;
                unsigned dealt;

                if (!matrix->plane[from])
                        continue;

                /* Bit 1 marks "B slowed": 0=>2, 1=>3, 2=>2, 3=>3. */
                to = a_slows ? (from | 2) : from;

                /* Bit 0 marks "A slowed" (planes 1 and 3). */
                dealt = (from & 1) ? damage/2 : damage;

                shift_matrix_cols(matrix, to, from, dealt,
                                  hit_chance, a_drains);
        }
}

/* Move the fraction 'prob' of every cell outside row 0 and column 0
 * of plane 'src' 'damage' rows towards row 0 in plane 'dst'.
 *
 * Cells within 'damage' rows of row 0 land on row 0.  With 'drain'
 * set, the destination column is raised by half of the number of rows
 * actually moved.
 */
static void shift_matrix_rows(struct matrix *matrix,
                              unsigned dst, unsigned src,
                              unsigned damage, double prob, bool drain)
{
        unsigned int full_gain;
        unsigned int row, col;

        if (damage >= matrix->rows)
                damage = matrix->rows - 1;

        /* Column gain for a blow that lands at full (clamped) damage. */
        full_gain = drain ? damage/2 : 0;

        /* Highest column first: drain writes into columns that were
         * already processed. */
        for (col = matrix->cols - 1; col > 0; col--) {
                /* Rows 1..damage all end up on row 0. */
                for (row = 1; row <= damage; row++)
                        xfer(matrix, dst, src,
                             0, col + (drain ? row/2 : 0),
                             row, col, prob);

                /* The remaining rows move by the full damage. */
                for (row = damage+1; row < matrix->rows; row++)
                        xfer(matrix, dst, src,
                             row - damage, col + full_gain,
                             row, col, prob);
        }
}

/* Apply one strike by B against A: with probability 'hit_chance', A
 * loses (up to) 'damage' hitpoints.
 *
 * Every allocated plane is shifted along its rows.  Planes in which B
 * is slowed use half the damage.  If B slows, the result lands on the
 * matching "A slowed" plane.
 */
static void receive_blow_a(struct matrix *matrix,
                           unsigned damage, double hit_chance,
                           bool b_slows, bool b_drains)
{
        int from;

        /* Highest plane first, so a plane that has just received
         * probability is not processed a second time. */
        for (from = 3; from >= 0; from--) {
                int to;
                unsigned dealt;

                if (!matrix->plane[from])
                        continue;

                /* Bit 0 marks "A slowed": 0=>1, 1=>1, 2=>3, 3=>3. */
                to = b_slows ? (from | 1) : from;

                /* Bit 1 marks "B slowed" (planes 2 and 3). */
                dealt = (from & 2) ? damage/2 : damage;

                shift_matrix_rows(matrix, to, from, dealt,
                                  hit_chance, b_drains);
        }
}

/* A attacks B.  Who wins?
 *
 * Simulates the whole combat on a probability matrix and adds the
 * resulting hitpoint distributions onto the output arrays.
 *
 * a, b:    attacker and defender; hp must not exceed max_hp.
 * final_a: a->max_hp+1 entries; final_a[n] is increased by the
 *          probability that A ends with n hitpoints.
 * final_b: b->max_hp+1 entries, likewise for B.
 *
 * The outputs are accumulated with +=, so the caller must zero them
 * before the call to obtain a plain distribution.
 */
static void calculate_attack(const struct unit *a,
                             const struct unit *b,
                             double final_a[], double final_b[])
{
        /* Total number of rounds fought when either unit is berserk. */
        enum { BERSERK_ROUNDS = 30 };
        struct matrix *matrix;
        unsigned int row, col, m, round, rounds, strikes;

        /* The starting cell is indexed by hp and must lie inside the
         * (max_hp+1)-sized matrix. */
        assert(a->hp <= a->max_hp);
        assert(b->hp <= b->max_hp);

        matrix = new_matrix(a->max_hp, b->max_hp, a->slows, b->slows);
        *val(matrix, NEITHER_SLOWED, a->hp, b->hp) = 1.0;

        /* Berserk combat lasts at most BERSERK_ROUNDS rounds in total,
         * where a do { } while (berserk--) countdown from 30 would run
         * 31.  Dead units occupy row or column 0, which the shift
         * functions skip, so no early exit is needed. */
        rounds = (a->berserk || b->berserk) ? BERSERK_ROUNDS : 1;
        strikes = max(a->num_attacks, b->num_attacks);

        for (round = 0; round < rounds; round++) {
                unsigned int i;

                /* Strikes alternate, A first, until both have used
                 * all of their attacks. */
                for (i = 0; i < strikes; i++) {
                        if (i < a->num_attacks) {
                                debug("A strikes\n");
                                receive_blow_b(matrix, a->damage,
                                               a->hit_chance, a->slows,
                                               a->drains);
                                draw_matrix(matrix);
                        }
                        if (i < b->num_attacks) {
                                debug("B strikes\n");
                                receive_blow_a(matrix, b->damage,
                                               b->hit_chance, b->slows,
                                               b->drains);
                                draw_matrix(matrix);
                        }
                }
        }

        debug("Combat ends:\n");
        draw_matrix(matrix);

        /* Sum rows and columns over all planes to give final results */
        for (m = 0; m < 4; m++) {
                if (!matrix->plane[m])
                        continue;
                for (row = 0; row <= a->max_hp; row++) {
                        for (col = 0; col <= b->max_hp; col++) {
                                double p = *val(matrix, m, row, col);

                                final_a[row] += p;
                                final_b[col] += p;
                        }
                }
        }
        free_matrix(matrix);
}

/* We create a significant number of nasty-to-calculate units, and
 * test each one against the others. */
#define NUM_UNITS 100

/* Benchmark driver: builds NUM_UNITS synthetic units, runs every
 * ordered pair through calculate_attack() and reports the total and
 * per-combat wall-clock time.  Command-line arguments are ignored. */
int main(int argc, char *argv[])
{
        /* N^2 battles. */
        struct unit u[NUM_UNITS];
        unsigned int i, j;
        struct timeval start, end;
        long secs, usecs;

        (void)argc;
        (void)argv;

        printf("Creating %i units...\n", NUM_UNITS);
        for (i = 0; i < NUM_UNITS; i++) {
                u[i].hp = i/2 + (i%20);
                u[i].max_hp = u[i].hp + i % 4;
                u[i].damage = (i % 7) + 2;
                u[i].num_attacks = (i % 4) + 1;
                u[i].slows = (i % 2);
                u[i].drains = (i % 9) == 0;
                u[i].berserk = (i % 5) == 0;
                u[i].hit_chance = 0.3 + (i % 6)*0.1;
        }

        printf("Beginning battle...\n");
        gettimeofday(&start, NULL);
        for (i = 0; i < NUM_UNITS; i++) {
                double i_result[u[i].max_hp+1];

                for (j = 0; j < NUM_UNITS; j++) {
                        double j_result[u[j].max_hp+1];

                        /* calculate_attack() adds onto both arrays, so
                         * each must be cleared before every combat for
                         * the result to be a single distribution. */
                        memset(i_result, 0, sizeof(i_result));
                        memset(j_result, 0, sizeof(j_result));
                        calculate_attack(&u[i], &u[j], i_result, j_result);
                        draw_results(i_result, u[i].max_hp,
                                     j_result, u[j].max_hp);
                }
                printf("."); fflush(stdout);
        }
        gettimeofday(&end, NULL);

        /* time_t and suseconds_t have platform-dependent widths;
         * convert to long so the printf formats below match. */
        secs = (long)(end.tv_sec - start.tv_sec);
        usecs = (long)end.tv_usec - (long)start.tv_usec;
        if (usecs < 0) {
                /* Borrow one second. */
                usecs += 1000000;
                secs--;
        }

        printf("\nTotal time for %i combats was %ld.%06ld\n",
               NUM_UNITS*NUM_UNITS, secs, usecs);
        printf("Time per calc = %li us\n",
               (secs*1000000 + usecs) / (NUM_UNITS*NUM_UNITS));
        return 0;
}
