/***************************************************************************
                          ann_nfb_bprop.cpp  -  description
                             -------------------
    begin                : pon kwi 14 2003
    copyright            : (C) 2003 by Bartosz Lis
    email                : bartoszl@ics.p.lodz.pl
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/

#include <ann_nfb_bprop.h>

  //                      //
 // class ANN::NFB_bprop //
//                      //

// Derives the three log labels used by NFB_bprop from one base label.
//
// The length of `label` on entry is remembered and passed to every
// set_label() call together with one of the suffixes "input", "outpt" and
// "desrd" (presumably the suffix is written at that offset - confirm against
// ANN::Label).  After each call a copy obtained from dup_label() is stored
// in the matching out-parameter.
//
// label      scratch label object; on return it still holds the "desrd"
//            variant.
// l_in       receives the copy made for the "input" suffix.
// l_out      receives the copy made for the "outpt" suffix.
// l_des_out  receives the copy made for the "desrd" suffix.
//
// The copies are released with delete [] in ~NFB_bprop().
static void
set_labels(ANN::Label &label, char *&l_in, char *&l_out, char *&l_des_out)
{
  const size_t base_len=label.length();

  // network input
  label.set_label("input",base_len);
  l_in=label.dup_label();

  // network output
  label.set_label("outpt",base_len);
  l_out=label.dup_label();

  // desired output
  label.set_label("desrd",base_len);
  l_des_out=label.dup_label();
}

// Wraps an existing network element and attaches an explicit log.
//
// ne_         network element to train; its address is stored and later
//             passed to delete by ~NFB_bprop().
// ds_         source of the desired outputs.
// threshold_  error level compared against the epoch error in close().
// log_        log handed to the NFB base class.
// label_      label of this element; "bprop" is used when it is null.
//
// The gradient buffer is sized to the wrapped element's output only when
// that element accepts feedback.
//
// NOTE(review): the overload without a log defaults the label to
// "backpropagation" instead of "bprop" - confirm the difference is intended.
ANN::NFB_bprop::NFB_bprop(NFB &ne_, DS &ds_, double threshold_, Log *log_,
                          const char *label_)
: NFB(ne_.get_in_sizes(), log_, label_ ? label_ : "bprop"),
  ne(&ne_),
  des_output(&ds_),
  threshold(threshold_)
{
  Label scratch(label);
  set_labels(scratch,label_in,label_out,label_des_out);

  if (ne->fb_accept())
  {
    grad.resize(ne->get_out_sizes());
  }
}

// Wraps an existing network element without passing a log to the base class.
//
// ne_         network element to train; its address is stored and later
//             passed to delete by ~NFB_bprop().
// ds_         source of the desired outputs.
// threshold_  error level compared against the epoch error in close().
// label_      label of this element; "backpropagation" is used when it is
//             null.
//
// The gradient buffer is sized to the wrapped element's output only when
// that element accepts feedback.
//
// NOTE(review): the overload taking a Log* defaults the label to "bprop"
// instead of "backpropagation" - confirm the difference is intended.
ANN::NFB_bprop::NFB_bprop(NFB &ne_, DS &ds_, double threshold_,
                          const char *label_)
: NFB(ne_.get_in_sizes(), label_ ? label_ : "backpropagation"),
  ne(&ne_),
  des_output(&ds_),
  threshold(threshold_)
{
  Label scratch(label);
  set_labels(scratch,label_in,label_out,label_des_out);

  if (ne->fb_accept())
  {
    grad.resize(ne->get_out_sizes());
  }
}

// Builds the element from shared construction data and creates the wrapped
// network itself.
//
// c      common data: supplies the log, the desired-output source, the
//        error threshold and the factory `f` that creates the network.
// inst_  instance description forwarded to the NFB base class and used to
//        derive the labels.
//
// The three log labels are derived first; the same scratch object is then
// given the "network" suffix and handed to the factory.  The factory result
// is used without a null check.
ANN::NFB_bprop::NFB_bprop(Common_t &c, Instance_t inst_)
: NFB(*c.log,inst_),
  ne(0),
  des_output(c.ds),
  threshold(c.threshold)
{
  Init         child(inst_);
  const size_t base_len=child.length();

  set_labels(child,label_in,label_out,label_des_out);

  // Re-use the scratch instance to name the wrapped network.
  child.set_label("network",base_len);
  ne=c.f->create(&child);

  if (ne->fb_accept())
  {
    grad.resize(ne->get_out_sizes());
  }
}

// Releases the wrapped network element and the three label copies made by
// set_labels().  No null checks are needed: applying delete / delete [] to
// a null pointer has no effect.
ANN::NFB_bprop::~NFB_bprop()
{
  delete ne;
  delete [] label_in;
  delete [] label_out;
  delete [] label_des_out;
}

// Output sizes of this element, taken unchanged from the wrapped network
// element.
const ANN::Size &
ANN::NFB_bprop::get_out_sizes() const
{
  const ANN::Size &inner_sizes=ne->get_out_sizes();
  return inner_sizes;
}

// Output size of this element, taken unchanged from the wrapped network
// element.
size_t
ANN::NFB_bprop::get_out_size() const
{
  const size_t inner_size=ne->get_out_size();
  return inner_size;
}

// Forward pass: hands `in` to the wrapped network element, which fills
// `out`, and returns that element's status unchanged.
ANN::NE::Status
ANN::NFB_bprop::feed(const Term &in, Term &out)
{
  const Status inner_status=ne->feed(in,out);
  return inner_status;
}

// Registers with `torg_` twice: first through the NFB base class, then
// through the wrapped network element.
void
ANN::NFB_bprop::register_weights(TO &torg_)
{
  // base-class part of this element
  NFB::register_weights(torg_);

  // wrapped network element
  ne->register_weights(torg_);
}

// Resets the wrapped network element; `prod` is passed through untouched.
void
ANN::NFB_bprop::reset(bool prod)
{
  ne->reset(prod);
}

// Loads the state of the wrapped network element through `parser`.
// `begin` and `end` are forwarded as given; the result is whatever the
// wrapped element reports.
bool
ANN::NFB_bprop::load(Parser_log &parser, time_t begin, time_t end)
{
  const bool loaded=ne->load(parser,begin,end);
  return loaded;
}

// Stores the wrapped network element into `log_`.  This element writes
// nothing of its own here.
void
ANN::NFB_bprop::store(Log &log_)
{
  ne->store(log_);
}

// Randomizes the wrapped network element.  This element holds no state of
// its own that is randomized here.
void
ANN::NFB_bprop::randomize()
{
  ne->randomize();
}

// Prepares this element and the wrapped network for work on `ds_`.
//
// Returns `wrong` when the NFB base class rejects the data source or when
// the lengths of `ds_` differ from those of the desired-output source.
// Otherwise the result is the status of the wrapped network element; a
// non-`wrong` status from the base class is not propagated.
ANN::NE::Status
ANN::NFB_bprop::commence(DS &ds_, bool is_input)
{
  const Status base_status=NFB::commence(ds_,is_input);
  if (base_status==wrong)
  {
    return base_status;
  }

  // Input data and desired outputs must have the same lengths.
  if (ds_.get_lengths()!=des_output->get_lengths())
  {
    return wrong;
  }

  return ne->commence(ds_,is_input);
}

// Concludes the wrapped network element and then, when a log is attached,
// writes the currently accumulated error under this element's label.
void
ANN::NFB_bprop::conclude()
{
  ne->conclude();

  if (!log) return;

  log->nl(label) << "last epoch error:" << error;
  log->nl();
}

// This element never accepts feedback from a following element: it always
// answers false.  feed_back() computes the gradient itself from the desired
// outputs.
bool
ANN::NFB_bprop::fb_accept() const
{
  const bool accepts_feedback=false;
  return accepts_feedback;
}

void
ANN::NFB_bprop::feed_back(const Term &in, const Term &out,
                          const Term *out_bp, Term *in_bp)
{
  des_output->locate(torg->where());
  double        d, e=0;
  double       *g=(ne->fb_accept() ? grad.get_data() : 0);
  const double *y=out.get_data(), *Y=des_output->term().get_data();
  size_t        i, n=grad.get_size();
  for (i=0; i<n; ++i) if ((Y[i]!=NAN) && Y)
  {
    d=y[i]-Y[i];
    if (g) g[i]=d;
    e+=d*d;
  }
  else if (g) g[i]=0;
  error+=e/2;
  if (log)
  {
    log->log(in,label_in);
    log->log(out,label_out);
    log->log(des_output->term(),label_des_out);
  }
  ne->feed_back(in,out,(g ? &grad : 0),in_bp);
}

// Starts a new epoch: the accumulated error is cleared, the desired-output
// source is reset, and the wrapped network element is opened.
void
ANN::NFB_bprop::open()
{
  // fresh error accumulator for feed_back()
  error=0;

  des_output->reset();
  ne->open();
}

// Ends the epoch.  The returned status is the wrapped element's close()
// status combined, through Status's `+`, with the outcome of the test
// `error<threshold`; how `+` merges the two is defined outside this file.
// When a log is attached the epoch error is written under this element's
// label.
ANN::NE::Status
ANN::NFB_bprop::close()
{
  const Status result=ne->close() + (error<threshold);

  if (!log) return result;

  log->nl(label) << "epoch error:" << error;
  log->nl();
  return result;
}

