//////////////////////////////////////////////////////////////////////////////
//
// File: school.cc
//
// Purpose: This is the implementation for the School class.
//
// Authors:
//   txe  Travis Emmitt
//
// Modifications:
//   20-APR-1998  txe  Initial creation, stuff moved from game.cc
//   21-APR-1998  txe  Completed construction
//   22-APR-1998  txe  Debugging...
//   23-APR-1998  txe  Using static debug, changed constructor
//   24-APR-1998  txe  Added FindBestArch(), FindBestWeights()
//   05-MAY-1998  txe  Changed breakdown labels to be consistent w/ Loss Bar
//
//////////////////////////////////////////////////////////////////////////////

#include <iostream.h>
#include <string.h>
#include <stdio.h>
#include <stdlib.h>
#include "common.h"
#include "debug.h"
#include "game.h"
#include "player.h"
#include "nplayer.h"
#include "school.h"

/////////////////////////////////////////////////////////////////////////

// Build a School on top of a Game: find the student among the players,
// put it into training mode, grab its architecture, allocate the
// best-architecture snapshot and derive the output file names.
//
//   name          label used in diagnostics (also forwarded to Game)
//   players_file  forwarded to the Game constructor
//   rules_file    forwarded to the Game constructor
//   best_root     root of the file names the best results are saved under;
//                 ".cfg" and ".wts" are appended to it
//
// Any failure here is fatal: a message is written and exit(-1) is called.
School::School (char *name, char *players_file, char *rules_file,
		char *best_root)
      : Game   (name, players_file, rules_file) {

  // The student is a player whose name starts with 'N'; if several //
  // players match, the last one in the list is the one kept.       //

  student = NULL;

  for (int i = 1; i <= num_players; i++) {
    ASSERT (players[i] != NULL);
    if (players[i]->name[0] == 'N') {
      student = (NeuralPlayer *) players[i];
      DEBUG(SCH) << name << " enrolled " << players[i]->name << "\n";
    }
  }

  // This is fatal, so report it through ERR like the other fatal paths //
  // below; a DEBUG(SCH) message could be filtered out, leaving a silent //
  // exit.                                                               //

  if (student == NULL) {
    ERR << name
	<< " doesn't have any students (need to be NeuralPlayers)!\n";
    exit (-1);
  }
  student->SetTraining (1);

  // Get a pointer to the student's arch so that we can mutate it //

  if ((arch = student->GetArch ()) == NULL) {
    ERR << name << " couldn't load student's NeuNet\n";
    exit (-1);
  }

  // create best_arch, which we'll use to save the best architecture //

  if ((best_arch = new Arch ("best arch")) == NULL) {
    ERR << name << " couldn't create best_arch; out of memory!\n";
    exit (-1);
  }

  ASSERT (best_root != NULL);

  // Root (at most MAX_LEN-5 chars) + 4-char suffix + terminating NUL is  //
  // at most MAX_LEN bytes.  NOTE(review): assumes the two file name      //
  // buffers hold MAX_LEN bytes -- confirm against school.h.              //

  sprintf (this->best_arch_file, "%.*s.cfg", MAX_LEN-5, best_root);
  sprintf (this->best_wts_file,  "%.*s.wts", MAX_LEN-5, best_root);

  this->mutation_rate  = DEFAULT_MUTATION_RATE;
  this->loss_rate_goal = DEFAULT_LOSS_RATE_GOAL;
}

/////////////////////////////////////////////////////////////////////////

// Search for a good architecture by running up to num_sessions sessions.
// The student's arch is mutated before every session but the first; a
// snapshot of the best architecture seen so far is kept in best_arch.
//
//   num_sessions  maximum number of sessions to run
//   num_matches   matches per session (stored in the member as well)
//   num_games     games per match (stored in the member as well)
//
// Returns 1 when the search ends (goal reached or sessions exhausted),
// 0 when RunSession() reports a negative value.
int School::FindBestArch (int num_sessions, int num_matches, int num_games) {
  ASSERT (num_sessions > 0);
  ASSERT (num_matches  > 0);
  ASSERT (num_games    > 0);
  ASSERT (student   != NULL);
  ASSERT (best_arch != NULL);
  ASSERT (arch      != NULL);

  // Session parameters are kept in members; 100% is the worst possible //
  // loss rate, so any real result will beat the initial "best".        //

  this->finding_arch = 1;
  this->num_games    = num_games;
  this->num_matches  = num_matches;
  best_loss_rate     = 100;

  DEBUG(0) << "\n\n" << name << " running " << num_sessions
	   << " sessions of " << num_matches << " matches, each of "
	   << num_games << " games\n\n";

  // The starting architecture is the first "best" one //

  best_arch->Copy (arch);

  for (session = 1; session <= num_sessions; session++) {

    // Every session after the first starts from a mutated arch //

    if (session >= 2) {
      arch->Mutate (mutation_rate);
    }

    // A negative result means the session could not be completed //

    const float loss_rate = RunSession (num_matches, num_games);

    if (loss_rate < 0) {
      ERR << "Couldn't finish school session #" << session << "\n";
      return 0;
    }

    DEBUG(SCH) << "\nSession loss rate = " << (int) loss_rate << "% -- ";

    // Clearly better: take a new snapshot, and stop if the goal is met //

    if (loss_rate < best_loss_rate - LOSS_RATE_THRESHOLD) {
      DEBUG(SCH) << "IMPROVED!!!  (previous best was "
		 << (int) best_loss_rate << "%)\n";

      best_arch->Copy (arch);
      best_loss_rate = loss_rate;

      if (loss_rate <= loss_rate_goal) {
	DEBUG(SCH) << "\nSession loss rate goal of " << (int) loss_rate_goal
		   << "% was achieved!  (ending architecture search)\n";
	break;
      }
      continue;
    }

    // Clearly worse: throw the mutation away and go back to the snapshot //

    if (loss_rate > best_loss_rate + LOSS_RATE_THRESHOLD) {
      DEBUG(SCH) << "got worse, restoring old arch...\n";
      arch->Copy (best_arch);
      continue;
    }

    // Within the threshold either way: keep the current arch as it is //

    DEBUG(SCH) << "stayed within " << LOSS_RATE_THRESHOLD << "%\n";
  }

  cout << "\nDone searching for optimal architecture.\n";
  return 1;
}

/////////////////////////////////////////////////////////////////////////

int School::FindBestWeights (int num_matches, int num_games) {
  ASSERT (num_matches > 0);
  ASSERT (num_games   > 0);

  best_loss_rate = 100;
  finding_arch   = 0;

  RunSession (num_matches, num_games);

  return 1;
}

/////////////////////////////////////////////////////////////////////////

// Print a two-line loss summary for the match that was just played:
//   1. a bar of up to BAR_LEN characters in which each period's losses are
//      drawn with that period's index (mod 10), followed by the loss
//      percentage;
//   2. the per-period loss counts and the total.
//
// Reads the members session, match, losses[] and num_games; num_games must
// be positive because it is used as a divisor.
void School::PrintLossBar () {
  long total_losses = 0;
  int  s            = 1;    // next bar column to draw (1-based) //
  int  i;

  cout << "s" << session << ".m" << match << "\tLosses:    |";

  for (i = 0; i < NUM_PERIODS; i++) {
    total_losses += losses[i];

    // Column the bar should reach once this period's losses are included. //
    // RunSession() rounds the games per period up, so a match can contain //
    // more games (and therefore more losses) than num_games; clamp so the //
    // bar never grows past BAR_LEN and BLANKS() never sees a negative     //
    // count.                                                              //

    long filled = total_losses * BAR_LEN / num_games;
    if (filled > BAR_LEN) {
      filled = BAR_LEN;
    }

    for (; s <= filled; s++) {
      cout << (i % 10);
    }
  }

  // Pad the unused part of the bar; the percentage itself is not clamped //

  cout << BLANKS (BAR_LEN-s+1) << "| "
       << total_losses * 100 / num_games << "%\n"
       << "s" << session << ".m" << match << "\tBreakdown:";

  for (i = 0; i < NUM_PERIODS; i++) {
    cout << " " << i << ":" << losses[i];
  }

  cout << "\tTotal:" << total_losses << "/" << num_games << "\n";
}

/////////////////////////////////////////////////////////////////////////

// Run one session: num_matches matches, each split into NUM_PERIODS
// periods of games, using a freshly built NeuNet for the student.
//
// While searching for an architecture (finding_arch set) the session is
// abandoned as soon as its losses exceed those of the best session so far.
// Otherwise (weight search) the student is saved every time a match beats
// the best match loss rate seen in this session.
//
//   num_matches  matches to play
//   num_games    nominal games per match (rounded up to a whole number of
//                games per period)
//
// Returns the session loss rate in percent, 100 when the session was
// aborted early, QUIT when asked to quit or when saving fails, and -1 when
// the student's NeuNet could not be built.
float School::RunSession (int num_matches, int num_games) {
  ASSERT (num_matches > 0);
  ASSERT (num_games   > 0);
  ASSERT (ref     != NULL);

  // Widen before multiplying so the product is computed as a long //

  long  total_games      = (long) num_matches * num_games;
  long  games_per_period = (long) ((float) num_games / NUM_PERIODS + .99);
  long  best_losses      = (long) (best_loss_rate * total_games) / 100;
  long  num_losses       = 0;    // losses in the current match //
  long  total_losses     = 0;    // losses in the whole session //
  float best_rate        = 100;  // best rate for a match //
  float loss_rate        = 0;    // loss rate for a match //

  if (finding_arch) {
    DEBUG(SCH) << "\n****************************** Session #" << session
	       << " *************************\n";
  }
  else {
    DEBUG(SCH) << "\n*************** Final Session (looking for best weights) "
	       << " ****************\n";
  }

  arch->Print (SCH);

  DEBUG(SCH) << "Splitting matches (" << num_games << " games each) into "
	     << NUM_PERIODS << " periods of " << games_per_period
	     << " games each\n";

  // Rebuild NeuNet for each session (because of the change in arch). //
  // Failure must be negative: FindBestArch() treats any non-negative //
  // return as a loss rate, and 0 would look like a perfect session.  //

  if (!student->BuildNet ()) {
    ERR << name << " couldn't get student to build NeuNet\n";
    return -1;
  }

  // Loop through matches until either done or too many losses //

  for (match = 1; match <= num_matches; match++) {
    DEBUG(SCH+1) << "\nMatch #" << match << "\n";
    ref->NewMatch ();

    num_losses = 0;

    // Scores are reset per period, so losses[j] is that period's count //

    for (int j = 0; j < NUM_PERIODS; j++) {
      DEBUG(SCH+1) << "\nPeriod #" << j + 1 << "\n";
      ref->ResetScores ();

      for (int k = 0; k < games_per_period; k++) {
	if (ref->NewGame () == QUIT) {
	  DEBUG(0) << name << " was asked to quit\n";
	  return QUIT;
	}
	PrintWins (1);
      }

      losses[j] = student->GetNumLosses ();
      num_losses += losses[j];
    }

    PrintLossBar ();

    // Accumulate in both modes so the rate returned at the end is real //
    // (original note: change this later (losses[9]))                   //

    total_losses += num_losses;

    // If still searching for best arch, we might be able to abort early //

    if (finding_arch) {
      if (total_losses > best_losses && match < num_matches) {
	DEBUG(SCH) << "  Total session losses (" << total_losses
		   << ") already " << "exceed best session's losses ("
		   << best_losses << "), aborting session early\n";
	return 100;
      }
    }

    // If we've got the best arch and are looking for the best weight... //

    else {
      // Divide in floating point so the fraction is not truncated away //
      loss_rate = (float) num_losses * 100 / num_games;

      if (loss_rate < best_rate - LOSS_RATE_THRESHOLD) {
	DEBUG(SCH) << "  Best loss rate yet for this architecture ("
		   << (int) loss_rate << "% beats " << (int) best_rate
		   << "%)\n  Saving architecture to '" << best_arch_file
		   << "' and weights to '" << best_wts_file << "'...\n";

	if (!student->Save (best_arch_file, best_wts_file, loss_rate)) {
	  ERR << "Couldn't save NeuNet arch and weights\n";
	  return QUIT;
	}
	best_rate = loss_rate;
      }
    }
  }

  // Calculate and return session loss rate (floating-point division) //

  return (float) total_losses * 100 / total_games;
}

/////////////////////////////////////////////////////////////////////////

// Replace the mutation rate that FindBestArch() hands to Arch::Mutate().
//
//   mutation_rate  new rate; stored as given, no range check is made
void School::SetMutationRate (float mutation_rate) {
  this->mutation_rate = mutation_rate;

  // Log the value that was requested //

  DEBUG(SCH) << name << " setting Arch's mutation rate to "
	     << mutation_rate << "\n";
}

/////////////////////////////////////////////////////////////////////////

// Replace the loss rate (in percent) at which FindBestArch() stops
// searching for a better architecture.
//
//   loss_rate_goal  new goal; stored as given, no range check is made
void School::SetLossRateGoal (float loss_rate_goal) {
  this->loss_rate_goal = loss_rate_goal;

  // Log the value that was requested //

  DEBUG(SCH) << name << " setting loss rate goal to "
	     << loss_rate_goal << "%\n";
}

/////////////////////////////////////////////////////////////////////////

