//////////////////////////////////////////////////////////////////////////////
//
// File: nplayer.cc
//
// Purpose: Routines for managing the NeuralPlayer class.
//
// Authors:
//   txe  Travis Emmitt
//
// Modifications:
//   16-APR-1998  txe  Initial creation
//   17-APR-1998  txe  Using one hidden level
//   18-APR-1998  txe  Added Peek(), NewGame(), peek_watch to Feedback()
//   19-APR-1998  txe  Using new NeuNet Architecture class
//   20-APR-1998  txe  Ported to UNIX, added LoadArch(), Purifying...
//   21-APR-1998  txe  Added GetArch(), Load(), Save()
//   22-APR-1998  txe  Debugging...
//   23-APR-1998  txe  Added BuildNet(), ResetWeights(), using static debug
//   24-APR-1998  txe  Debugging, removed ability
//
//////////////////////////////////////////////////////////////////////////////

#include <iostream.h>
#include <stdio.h>
#include <stdlib.h>
#include "common.h"
#include "board.h"
#include "player.h"
#include "referee.h"
#include "neunet.h"
#include "nn_arch.h"
#include "nplayer.h"

///////////////////////////////////////////////////////////////////////////

// Construct a neural-network player.
//
//   name      - player name (forwarded to Player)
//   color     - player color (forwarded to Player)
//   arch_file - path of the architecture description; must be non-NULL and
//               loadable, otherwise the process exits
//   wts_file  - path of a weights file; must be non-NULL. An empty string
//               means "start from random weights".
//
// The NeuNet itself is not built here; that happens later in BuildNet().
// NOTE(review): peek_watch (read in Feedback) is not initialized in this
// constructor -- confirm that Player's constructor sets it.
NeuralPlayer::NeuralPlayer (char *name, int color, char *arch_file,
			    char *wts_file)
            : Player       (name, color) {

  ASSERT (arch_file != NULL);
  ASSERT (wts_file  != NULL);

  // Nothing exists yet; BuildNet() sizes the net and allocates nn later.
  nn          = NULL;
  training    = 0;
  num_outputs = 0;
  num_inputs  = 0;

  DEBUG(0) << name << " building architecture data structure\n";

  arch = new Arch ("NN Arch");
  if (arch == NULL) {
    ERR << name << " couldn't create arch; out of memory!\n";
    exit (-1);
  }

  DEBUG(0) << name << " loading architecture from '" << arch_file << "'\n";

  if (!arch->Load (arch_file)) {
    ERR << name << " couldn't load arch_file\n";
    exit (-1);
  }

  // Remember the file names (truncated to MAX_LEN chars) for a later Save().
  // Assumes the member buffers hold at least MAX_LEN+1 chars -- TODO confirm.
  sprintf (this->arch_file, "%.*s", MAX_LEN, arch_file);
  sprintf (this->wts_file,  "%.*s", MAX_LEN, wts_file);

  if (!wts_file[0]) {
    DEBUG(0) << name << " will use random weights\n";
  }
  else {
    DEBUG(0) << name << " will (try to) use weights in '" << wts_file << "'\n";
  }
}

///////////////////////////////////////////////////////////////////////////

// (Re)build the player's NeuNet to match the current board and player count.
//
// There is one input per (cell, player) pair and one output per cell. Any
// existing net is destroyed first. After construction the weights are
// initialized through ResetWeights().
//
// Returns 0 if the net could not be created or its weights could not be
// reset; otherwise returns the result of ResetWeights().
int NeuralPlayer::BuildNet () {
  num_inputs  = (size_x * size_y * num_players);
  num_outputs = (size_x * size_y);

  if (nn != NULL) {
    DEBUG(0) << "\n" << name << " replacing its old NeuNet...\n";
    delete nn;
    // Clear the pointer immediately. If the allocation below throws
    // (standard `new` throws std::bad_alloc instead of returning NULL),
    // nn must not be left pointing at freed memory, or a later BuildNet()
    // would delete it a second time.
    nn = NULL;
  }
  else {
    DEBUG(0) << "\n" << name << " creating its first NeuNet\n";
  }

  // The NULL check only matters for non-throwing (pre-standard) allocators.
  if ((nn = new NeuNet ("NeuNet", arch, num_inputs, num_outputs)) == NULL) {
    ERR << name << " couldn't create nn; out of memory!\n";
    return 0;
  }

  return ResetWeights ();
}

///////////////////////////////////////////////////////////////////////////

// Receive the result code for the last move and update the NeuNet.
//
// The code is first forwarded to Player::Feedback. On INVALID or LOSE the
// net is negatively reinforced:
//   - the output goal for the cell of the last chosen move (move_x, move_y)
//     is set to 0;
//   - every other output's goal is set to 1.
// Then Reinforce() is called.
//
// When peek_watch is set, the user is asked interactively whether to peek
// at the net:
//   'y' - peek before and after reinforcement
//   'a' - peek only after reinforcement
//   'q' - stop asking
//   other input - no peek
void NeuralPlayer::Feedback (int code) {
  Player::Feedback (code);

#if RESTORE_THIS_LATER
  if (!training) {
    DEBUG(2) << name << " not training, so no changes to NeuNet\n";
    return;
  }
#endif

  DEBUG(2) << name << " training, updating NeuNet...\n";

  ASSERT (nn != NULL);
  char line[MAX_LEN+1] = "";

  if (code == INVALID || code == LOSE) {
    // Index of the output neuron that corresponds to the last move.
    int move_index = (move_y * size_x) + move_x;
    for (int i = 0; i < num_outputs; i++) {
      nn->SetOutputGoal (i, i != move_index);
    }

    if (peek_watch) {
      cout << "\n" << name
           << " reinforcing, peek? [y/n/'a'fter/'q'uit peeking]: ";
      // The reply is read through C stdio, so make sure the prompt is
      // visible before blocking on input.
      cout.flush ();
      // Bounded read: gets() could overflow `line` and no longer exists in
      // C11/C++14. Only line[0] is inspected below, so a trailing newline
      // kept by fgets() is harmless. On EOF or error, treat it as "no".
      if (fgets (line, sizeof (line), stdin) == NULL) {
        line[0] = '\0';
      }
      switch (line[0]) {
        case 'y' : nn->Peek ();    break;
        case 'q' : peek_watch = 0; break;
      }
    }

    nn->Reinforce ();

    if (line[0] == 'y' || line[0] == 'a') {
      DEBUG(0) << "Peeking after reinforcement...\n";
      nn->Run ();
      nn->Peek ();
    }
  }
}

///////////////////////////////////////////////////////////////////////////

// Accessor for the architecture object created in the constructor.
// The player keeps ownership; the pointer is not copied.
Arch *NeuralPlayer::GetArch () {
  Arch *current = this->arch;
  return current;
}

///////////////////////////////////////////////////////////////////////////

// Choose a move by running the NeuNet on the current board.
//
// Each board cell contributes num_players inputs. The input matching the
// cell's color is 1 and the rest are 0. After Run(), the cell whose output
// level is highest is stored in (move_x, move_y).
//
// Returns PASS if the best output level is not positive; otherwise 1.
int NeuralPlayer::GetMove () {
  ASSERT (nn != NULL);

  int cell_x, cell_y;
  int slot = 0;

  // One-hot encode every cell's color into consecutive input slots.
  ITERATE (, cell_x, cell_y, size_x, size_y) {
    int occupant = board->GetColor (cell_x, cell_y);
    for (int p = 0; p < num_players; p++) {
      nn->SetInputLevel (slot++, ((p == occupant) ? 1 : 0));
    }
  }

  nn->Run ();

  // Scan the outputs in the same cell order and keep the strongest one.
  float best = -9999;
  slot = 0;
  ITERATE (, cell_x, cell_y, size_x, size_y) {
    float out = nn->GetOutputLevel (slot++);
    if (out > best) {
      best   = out;
      move_x = cell_x;
      move_y = cell_y;
    }
  }

  DEBUG(2) << name << "'s max output level = " << best << "\n";

  if (best > 0) {
    return 1;
  }

  DEBUG(2) << name << " couldn't find a move, passing\n";
  return PASS;
}

//////////////////////////////////////////////////////////////////////////////

// Start a new game: let the base Player reset itself, then dump the net at
// debug level 2. The net must already have been built. Always returns 1.
int NeuralPlayer::NewGame () {
  Player::NewGame ();

  ASSERT (nn != NULL);
  nn->Print (2);

  return 1;
}

//////////////////////////////////////////////////////////////////////////////

// Start a new match.
//
// The first match builds the NeuNet via BuildNet(), which also resets its
// weights. Later matches reuse the existing net and only call
// ResetWeights(). Returns the result of whichever call was made.
int NeuralPlayer::NewMatch () {
  Player::NewMatch ();
  DEBUG(2) << name << " starting new match (calling Load)\n";

  return (nn == NULL) ? BuildNet () : ResetWeights ();
}

///////////////////////////////////////////////////////////////////////////

// Turn on interactive peeking, which makes Feedback() prompt during
// reinforcement, and show the net's current state right away.
void NeuralPlayer::Peek () {
  ASSERT (nn != NULL);

  peek_watch = 1;   // enable the prompts in Feedback()
  nn->Peek ();
}

///////////////////////////////////////////////////////////////////////////

int NeuralPlayer::ResetWeights () {
  if (!training && wts_file[0]) {
    DEBUG(1) << name << " loading weights from file\n";
    if (!nn->Load (wts_file)) {
      ERR << name << " couldn't load NeuNet's weights\n";
      return 0;
    }
  }
  else {
    DEBUG(1) << name << " randomizing NeuNet weights...\n";
    nn->RandomizeWeights ();
  }
  return 1;
}

///////////////////////////////////////////////////////////////////////////

// Save the architecture and weights to the file names given at
// construction, recording a loss rate of 100.
int NeuralPlayer::Save () {
  const float default_loss_rate = 100;
  return Save (this->arch_file, this->wts_file, default_loss_rate);
}

///////////////////////////////////////////////////////////////////////////

// Save the architecture to arch_file, then the NeuNet weights to wts_file.
// loss_rate is passed through to NeuNet::Save. The weights are not written
// if saving the architecture fails.
//
// Returns 1 on success, 0 if either save fails (an error is reported).
int NeuralPlayer::Save (char *arch_file, char *wts_file, float loss_rate) {
  ASSERT (arch_file != NULL);
  ASSERT (wts_file  != NULL);
  ASSERT (arch      != NULL);
  ASSERT (nn        != NULL);

  if (arch->Save (arch_file)) {
    if (nn->Save (wts_file, loss_rate)) {
      return 1;
    }
    ERR << name << " couldn't save NN weights file\n";
    return 0;
  }

  ERR << name << " couldn't save NN architecture file\n";
  return 0;
}

///////////////////////////////////////////////////////////////////////////

// Turn training mode on (non-zero) or off (zero) and log the change.
// ResetWeights() reads this flag to choose between loading weights from a
// file and randomizing them.
void NeuralPlayer::SetTraining (int training) {
  this->training = training;

  const char *qualifier = training ? "" : "not ";
  DEBUG(0) << name << " is " << qualifier << "training...\n";
}

///////////////////////////////////////////////////////////////////////////

