/***************************************************************************
                          ann_ne_bprop.cpp  -  description
                             -------------------
    begin                : pon kwi 14 2003
    copyright            : (C) 2003 by Bartosz Lis
    email                : bartoszl@ics.p.lodz.pl
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/

#include <ann_ne_bprop.h>

#include <cmath>

  //                     //
 // class ANN::NE_bprop //
//                     //

// Builds the three logging labels used by NE_bprop::adapt() from `label`.
// For each suffix ("input", "outpt", "desrd"), in that order, the label is
// re-tagged with set_label(suffix, base_len) and a copy is taken with
// dup_label(). NOTE(review): set_label presumably replaces everything after
// the first base_len characters with the suffix; confirm in ANN::Label.
// The returned strings are released with delete[] in ~NE_bprop. On return,
// `label` itself still carries the last ("desrd") suffix.
static void
set_labels(ANN::Label &label, char *&l_in, char *&l_out, char *&l_des_out)
{
  static const char *const suffixes[3] = { "input", "outpt", "desrd" };
  char                   **targets[3]  = { &l_in, &l_out, &l_des_out };

  const size_t base_len = label.length();
  for (int k = 0; k < 3; ++k)
  {
    label.set_label(suffixes[k], base_len);
    *targets[k] = label.dup_label();
  }
}

// Wraps the existing network `ne_` and reports progress through `log_`.
// The input geometry is taken from `ne_`; the label defaults to
// "backpropagation" when `label_` is null. The wrapper keeps a pointer to
// `ne_` and deletes it in ~NE_bprop. NOTE(review): `ne_` therefore appears
// to be expected on the heap, with ownership passing to this object.
ANN::NE_bprop::NE_bprop(NE &ne_, double threshold_, Log *log_, const char *label_)
: NE(ne_.get_in_sizes(), log_, label_ != 0 ? label_ : "backpropagation"),
  ne(&ne_),
  threshold(threshold_)
{
  Label derived(label);
  set_labels(derived,label_in,label_out,label_des_out);
  // The gradient buffer is only needed when the wrapped network consumes it.
  if (!ne_.is_supervised()) return;
  grad.resize(ne_.get_out_sizes());
}

// Wraps the existing network `ne_` without a log. The input geometry is
// taken from `ne_`; the label defaults to "backpropagation" when `label_`
// is null. The wrapper keeps a pointer to `ne_` and deletes it in
// ~NE_bprop. NOTE(review): ownership of `ne_` appears to pass here; confirm
// this with callers.
ANN::NE_bprop::NE_bprop(NE &ne_, double threshold_, const char *label_)
: NE(ne_.get_in_sizes(), label_ != 0 ? label_ : "backpropagation"),
  ne(&ne_),
  threshold(threshold_)
{
  Label derived(label);
  set_labels(derived,label_in,label_out,label_des_out);
  // The gradient buffer is only needed when the wrapped network consumes it.
  if (!ne_.is_supervised()) return;
  grad.resize(ne_.get_out_sizes());
}

// Builds the wrapper from shared configuration `c`. The wrapped network is
// created through the factory c.f, under the instance label suffixed with
// "network". The log and the error threshold are also taken from `c`.
ANN::NE_bprop::NE_bprop(Common_t &c, Instance_t inst_)
: NE(*c.log,inst_), ne(0), threshold(c.threshold)
{
  Init         init(inst_);
  // Capture the base length before set_labels() re-tags the label.
  const size_t base_len = init.length();
  set_labels(init,label_in,label_out,label_des_out);
  init.set_label("network", base_len);
  ne = c.f->create(&init);
  if (ne->is_supervised())
    grad.resize(ne->get_out_sizes());
}

// Releases what the wrapper owns: the wrapped network and the three label
// strings produced by set_labels(). delete / delete[] on a null pointer is
// a no-op, so no explicit null checks are needed.
ANN::NE_bprop::~NE_bprop()
{
  delete ne;
  delete [] label_in;
  delete [] label_out;
  delete [] label_des_out;
}

// The output geometry is exactly that of the wrapped network.
const ANN::Size &
ANN::NE_bprop::get_out_sizes() const
{
  const Size &inner_sizes = ne->get_out_sizes();
  return inner_sizes;
}

// Total number of outputs, forwarded from the wrapped network.
const size_t
ANN::NE_bprop::get_out_size() const
{
  const size_t inner_size = ne->get_out_size();
  return inner_size;
}

void
ANN::NE_bprop::calc(const Term &in, Term &out)
{
  ne->calc(in,out);
}

// Always true. The wrapper's adapt() compares the produced output with the
// desired output (out_bp) itself, whether or not the wrapped network is
// supervised.
bool
ANN::NE_bprop::is_supervised() const
{
  const bool needs_desired_outputs = true;
  return needs_desired_outputs;
}

// Resets the wrapped network; the wrapper keeps no state of its own here.
void
ANN::NE_bprop::reset(TO &t, bool b_reload)
{
  this->ne->reset(t, b_reload);
}

// Start of an epoch: clears the accumulated error, then lets the wrapped
// network prepare.
void
ANN::NE_bprop::prepare(TO &t)
{
  this->error = 0;
  NE &inner = *this->ne;
  inner.prepare(t);
}

// End of an epoch: updates the wrapped network, then combines its result
// with "accumulated error below threshold" through trained_state(). When a
// log is attached, the epoch error is written to it.
int
ANN::NE_bprop::update(TO &t)
{
  const int  inner_state     = ne->update(t);
  const bool below_threshold = error < threshold;
  const int  state           = trained_state(inner_state, below_threshold);
  if (!log) return state;
  log->nl(label) << "epoch error:" << error;
  log->nl();
  return state;
}

// End of training: finishes the wrapped network and, when a log is
// attached, reports the error accumulated in the last epoch.
void
ANN::NE_bprop::finish(TO &t)
{
  ne->finish(t);
  if (!log) return;
  log->nl(label) << "last epoch error:" << error;
  log->nl();
}

void
ANN::NE_bprop::adapt(const Term &in, const Term &out,
                     const Term *out_bp, Term *in_bp)
{
  double        d, e=0;
  double       *g=(ne->is_supervised() ? grad.get_data() : 0);
  const double *y=out.get_data(), *Y=(out_bp ? out_bp->get_data() : 0);
  size_t        i, n=grad.get_size();
  for (i=0; i<n; ++i) if ((Y[i]!=NAN) && Y)
  {
    d=y[i]-Y[i];
    if (g) g[i]=d;
    e+=d*d;
  }
  else if (g) g[i]=0;
  error+=e/2;
  if (log)
  {
    log->log(in,label_in);
    log->log(out,label_out);
    if (out_bp) log->log(*out_bp,label_des_out);
  }  
  ne->adapt(in,out,(g ? &grad : 0),in_bp);
}

// Loads the wrapped network's state from `parser` for the [begin, end]
// time window and returns its success flag unchanged.
bool
ANN::NE_bprop::load(Parser_log &parser, time_t begin, time_t end)
{
  const bool loaded = ne->load(parser, begin, end);
  return loaded;
}

// Stores the wrapped network's state into `log_`; the wrapper adds nothing.
void
ANN::NE_bprop::store(Log &log_)
{
  NE *inner = ne;
  inner->store(log_);
}
