//////////////////////////////////////////////////////////////////////////////
//
// File: nn_node.cc
//
// Purpose: Implementation for NeuNet's Node class.
//
// Authors:
//   txe  Travis Emmitt
//
// Modifications:
//   17-APR-1998  txe  Initial creation
//   18-APR-1998  txe  Added ComputeOutput(), SetError(), AddError()
//   19-APR-1998  txe  Cleaned up, added constructor, destructor
//   20-APR-1998  txe  Purifying...
//   22-APR-1998  txe  Debugging Bus Error
//   23-APR-1998  txe  Using static debug, changed constructor
//
//////////////////////////////////////////////////////////////////////////////

#include <stdio.h>
#include <math.h>
#include <stdlib.h>
#include <new>
#include "common.h"
#include "nn_arch.h"
#include "nn_link.h"
#include "nn_node.h"

//////////////////////////////////////////////////////////////////////////////

//////////////////////////////////////////////////////////////////////////////
// Construct a node with room for num_in_links incoming and num_out_links
// outgoing Link pointers.  All link slots start out NULL; they are filled in
// later by SetLink().  delta, error and level start at 0.
//
//   name          - passed to the Debug base (used in diagnostics)
//   arch          - architecture parameters shared by the network
//   num_in_links  - number of incoming link slots (0..MAX_LINKS)
//   num_out_links - number of outgoing link slots (0..MAX_LINKS)
//
// Exits the process with -1 if a link array cannot be allocated.
Node::Node (char *name, Arch *arch, int num_in_links, int num_out_links)
    : Debug (name) {

  ASSERT (num_in_links  >= 0 && num_in_links  <= MAX_LINKS);
  ASSERT (num_out_links >= 0 && num_out_links <= MAX_LINKS);

  this->arch          = arch;
  this->num_in_links  = num_in_links;
  this->num_out_links = num_out_links;
  this->delta         = 0;
  this->error         = 0;
  this->level         = 0;

  // A node with no links in a given direction owns no array for it.  Keep
  // the pointer NULL instead of leaving it uninitialized.
  this->in_links  = NULL;
  this->out_links = NULL;

  // std::nothrow makes new return NULL on failure.  Plain new throws
  // std::bad_alloc, which would leave these diagnostics unreachable.
  if (num_in_links > 0) {
    this->in_links = new (std::nothrow) Link *[num_in_links];
    if (this->in_links == NULL) {
      ERR << name << " couldn't create in_links[]; out of memory!\n";
      ERR << "  (array of size " << num_in_links << ")\n";
      exit (-1);
    }
  }

  if (num_out_links > 0) {
    this->out_links = new (std::nothrow) Link *[num_out_links];
    if (this->out_links == NULL) {
      ERR << name << " couldn't create out_links[]; out of memory!\n";
      ERR << "  (array of size " << num_out_links << ")\n";
      exit (-1);
    }
  }

  // Every slot is empty until SetLink() installs a Link in it.
  int i;
  for (i = 0; i < num_in_links; i++) {
    in_links[i] = NULL;
  }
  for (i = 0; i < num_out_links; i++) {
    out_links[i] = NULL;
  }
}

//////////////////////////////////////////////////////////////////////////////

Node::~Node () {
  // in_links[] holds shared pointers.  SetLink() stores the same Link in
  // the source node's out_links[] and the dest node's in_links[], and the
  // Links are deleted through out_links[] below.  So only the array itself
  // is released here.
  if (num_in_links > 0) {
    DEBUG(DEL) << "Destroying " << name << "'s in_links[]\n";
    ASSERT (in_links != NULL);
    DELETE_ARRAY in_links;
  }

  // out_links[]: delete each Link this node created in SetLink(), then the
  // array holding them.  Every slot is expected to have been filled.
  if (num_out_links > 0) {
    DEBUG(DEL) << "Destroying " << name << "'s out_links (there are "
               << num_out_links << ")...\n";
    ASSERT (out_links != NULL);

    int i = 0;
    while (i < num_out_links) {
      ASSERT (out_links[i] != NULL);
      delete out_links[i];
      ++i;
    }

    DEBUG(DEL) << "Destroying " << name << "'s out_links[]\n";
    DELETE_ARRAY out_links;
  }

  DEBUG(DEL) << "Destroying " << name << "\n";
}

//////////////////////////////////////////////////////////////////////////////

void Node::AddError (float error) {
  this->error += error;
  this->delta = error * level * (1.0 - level);
}

//////////////////////////////////////////////////////////////////////////////

// Saturation bounds for the net input that ComputeLevel() passes to exp().
// NOTE(review): presumably chosen to keep exp(-total) well inside double
// range (exp overflows a double a little above 709); confirm intent.
#define MAX_TOTAL	500
#define MIN_TOTAL	(-MAX_TOTAL)

float Node::ComputeLevel () {
  DEBUG(4) << "Computing output level for " << name << "...\n";

  double total = 0.0;

  for (int i = 0; i < num_in_links; i++) {
    ASSERT (in_links[i] != NULL);
    total += (in_links[i]->source->level * in_links[i]->weight);
  }

  total = MAX (MIN_TOTAL, MIN (MAX_TOTAL, total));

  level = arch->act_numerator / (1 + exp (-total)) + arch->act_add;

  ASSERT (level >= 0.0 && level <= 1.0);
  return level;
}

//////////////////////////////////////////////////////////////////////////////

// Dump this node's level and a two-column table to cout: incoming links on
// the left, outgoing links on the right.  Each entry is shown as
// "name level * weight = contribution".
void Node::Print () {
  char buffer[100];
  const int num_rows = MAX (num_in_links, num_out_links);

  cout << "\n" << name << " level/activation = " << level
       << "\n"
       << "\n       I N P U T   L I N K S             O U T P U T  L I N K S         "
       << "\n     Node     Level  Weight Contrib       Node     Level  Weight Contrib"
       << "\n  ----------  -----  ------ -------    ----------  -----  ------ -------";

  for (int i = 0; i < num_rows; i++) {
    if (i < num_in_links) {
      ASSERT (in_links[i]         != NULL);
      ASSERT (in_links[i]->source != NULL);
      float contrib = in_links[i]->source->level * in_links[i]->weight;
      // snprintf bounds the write: node names have no length cap in the
      // format, so sprintf could overrun buffer[].
      snprintf (buffer, sizeof (buffer), "%-10s %5.2f * %5.2f = %5.2f",
                in_links[i]->source->name, in_links[i]->source->level,
                in_links[i]->weight, contrib);
    }
    else {
      // A formatted entry is 10+1+5+3+5+3+5 = 32 columns wide.  Pad with
      // exactly that many blanks so the output-link column stays under its
      // header (was 33, shifting it right by one).
      snprintf (buffer, sizeof (buffer), "%32s", " ");
    }
    cout << "\n  " << buffer << "     ";

    if (i < num_out_links) {
      ASSERT (out_links[i]       != NULL);
      ASSERT (out_links[i]->dest != NULL);
      float contrib = level * out_links[i]->weight;
      snprintf (buffer, sizeof (buffer), "%-10s %5.2f * %5.2f = %5.2f",
                out_links[i]->dest->name, level,
                out_links[i]->weight,     contrib);
      cout << buffer;
    }
  }
  cout << "\n\n";
}

//////////////////////////////////////////////////////////////////////////////

void Node::SetError (float error) {
  this->error = error;
  this->delta = error * level * (1.0 - level);
}

//////////////////////////////////////////////////////////////////////////////

// Force this node's activation to `level`.  Unlike ComputeLevel(), the
// value is stored as given, with no clamping or range check.
void Node::SetLevel (float level) {
  Node::level = level;
}

//////////////////////////////////////////////////////////////////////////////

void Node::SetLink (int in_id, int out_id, Node *dest) {
  ASSERT (dest != NULL);
  ASSERT (in_id  >= 0 && in_id  < dest->num_in_links);
  ASSERT (out_id >= 0 && out_id < this->num_out_links);
  ASSERT (dest->in_links[in_id]   == NULL);
  ASSERT (this->out_links[out_id] == NULL);

  char temp[NAME_LEN+1];
  sprintf (temp, "[%s -> %s]", name, dest->name);

  if ((out_links[out_id] = new Link (temp, arch, this, dest)) == NULL) {
    ERR << name << " couldn't create out_links[" << out_id
	<< "]; out of memory!\n";
    exit (-1);
  }

  dest->in_links[in_id] = out_links[out_id];
}

//////////////////////////////////////////////////////////////////////////////

