/***************************************************************************
                          ann_eval_bprop.cpp  -  description
                             -------------------
    begin                : pon kwi 14 2003
    copyright            : (C) 2003 by Bartosz Lis
    email                : bartoszl@ics.p.lodz.pl
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/

#include "config.h"

#include <ann_eval_bprop.h>

#include <cmath>

  //                       //
 // class ANN::Eval_bprop //
//                       //

/*
 * Derive the per-stream log labels from a base label.
 *
 *  ne        - when false no "input" label is produced and l_in is set to 0
 *              (feed_back() then skips logging the input term).
 *  label     - scratch Label; it is modified in place by set_label().
 *  l_in, l_out, l_des_out
 *            - receive strings returned by label.dup_label(); the
 *              Eval_bprop destructor releases them with delete[].
 *
 * Each tag ("input", "outpt", "desrd") is installed at the same position,
 * the length of the base label as it was on entry.  Presumably set_label()
 * truncates at that position and appends the tag — verify against
 * ANN::Label.
 */
static void
set_labels(bool ne, ANN::Label &label, char *&l_in, char *&l_out,
           char *&l_des_out)
{
  const size_t base_len=label.length();

  if (!ne)
    l_in=0;
  else
  {
    label.set_label("input",base_len);
    l_in=label.dup_label();
  }

  label.set_label("outpt",base_len);
  l_out=label.dup_label();

  label.set_label("desrd",base_len);
  l_des_out=label.dup_label();
}

// Build a backpropagation evaluator around the network referenced by ne_,
// comparing its outputs with the desired outputs in ds_ and logging to log_.
// A null label_ falls back to "backpropagation".
ANN::Eval_bprop::Eval_bprop(NIT &ne_, DS &ds_, double threshold_, Log *log_,
                            const char *label_)
: Eval(ne_, log_, label_ ? label_ : "backpropagation"),
  des_output(&ds_),
  threshold(threshold_)
{
  // set_labels() modifies the Label it is given, so hand it a scratch copy
  // built from this evaluator's label.
  Label scratch(label);
  set_labels(ne,scratch,label_in,label_out,label_des_out);
}

// Build a backpropagation evaluator around the network referenced by ne_,
// comparing its outputs with the desired outputs in ds_ (no explicit log).
// A null label_ falls back to "backpropagation".
ANN::Eval_bprop::Eval_bprop(NIT &ne_, DS &ds_, double threshold_,
                            const char *label_)
: Eval(ne_, label_ ? label_ : "backpropagation"),
  des_output(&ds_),
  threshold(threshold_)
{
  // set_labels() modifies the Label it is given, so hand it a scratch copy
  // built from this evaluator's label.
  Label scratch(label);
  set_labels(ne,scratch,label_in,label_out,label_des_out);
}

// Build a backpropagation evaluator around the network pointed to by ne_,
// comparing its outputs with the desired outputs in ds_ and logging to log_.
// A null label_ falls back to "backpropagation".
ANN::Eval_bprop::Eval_bprop(NIT *ne_, DS &ds_, double threshold_, Log *log_,
                            const char *label_)
: Eval(ne_, log_, label_ ? label_ : "backpropagation"),
  des_output(&ds_),
  threshold(threshold_)
{
  // set_labels() modifies the Label it is given, so hand it a scratch copy
  // built from this evaluator's label.
  Label scratch(label);
  set_labels(ne,scratch,label_in,label_out,label_des_out);
}

// Build a backpropagation evaluator around the network pointed to by ne_,
// comparing its outputs with the desired outputs in ds_ (no explicit log).
// A null label_ falls back to "backpropagation".
ANN::Eval_bprop::Eval_bprop(NIT *ne_, DS &ds_, double threshold_,
                            const char *label_)
: Eval(ne_, label_ ? label_ : "backpropagation"),
  des_output(&ds_),
  threshold(threshold_)
{
  // set_labels() modifies the Label it is given, so hand it a scratch copy
  // built from this evaluator's label.
  Label scratch(label);
  set_labels(ne,scratch,label_in,label_out,label_des_out);
}

// Build a backpropagation evaluator for inputs of size in_size_, comparing
// against the desired outputs in ds_ and logging to log_.
// A null label_ falls back to "backpropagation".
ANN::Eval_bprop::Eval_bprop(Size in_size_, DS &ds_, double threshold_,
                            Log *log_, const char *label_)
: Eval(in_size_, log_, label_ ? label_ : "backpropagation"),
  des_output(&ds_),
  threshold(threshold_)
{
  // set_labels() modifies the Label it is given, so hand it a scratch copy
  // built from this evaluator's label.
  Label scratch(label);
  set_labels(ne,scratch,label_in,label_out,label_des_out);
}

// Build a backpropagation evaluator for inputs of size in_size_, comparing
// against the desired outputs in ds_ (no explicit log).
// A null label_ falls back to "backpropagation".
ANN::Eval_bprop::Eval_bprop(Size in_size_, DS &ds_, double threshold_,
                            const char *label_)
: Eval(in_size_, label_ ? label_ : "backpropagation"),
  des_output(&ds_),
  threshold(threshold_)
{
  // set_labels() modifies the Label it is given, so hand it a scratch copy
  // built from this evaluator's label.
  Label scratch(label);
  set_labels(ne,scratch,label_in,label_out,label_des_out);
}

// Build an evaluator instance from shared settings: the desired-output set
// and error threshold are taken from c; the base class receives c and inst_.
ANN::Eval_bprop::Eval_bprop(Common_t &c, Instance_t inst_)
: Eval(c,inst_),
  des_output(c.ds),
  threshold(c.threshold)
{
  // set_labels() modifies the Label it is given, so hand it a scratch copy
  // built from this evaluator's label.
  Label scratch(label);
  set_labels(ne,scratch,label_in,label_out,label_des_out);
}

// Release the label strings produced by set_labels() via dup_label().
// delete[] on a null pointer is a no-op, so label_in (which is 0 when no
// network was supplied) needs no guard.
// NOTE(review): these are raw owning pointers; copying an Eval_bprop would
// double-free them unless the class header disables copying — verify.
ANN::Eval_bprop::~Eval_bprop()
{
  delete [] label_in;
  delete [] label_out;
  delete [] label_des_out;
}

// Begin evaluation over ds_.  Returns `wrong` when ds_ and the desired-output
// set report different get_lengths() (presumably term sizes), since outputs
// could not then be compared element-wise; otherwise defers to Eval.
ANN::NE::Status
ANN::Eval_bprop::commence(DS &ds_, bool is_input)
{
  const bool mismatch=(ds_.get_lengths()!=des_output->get_lengths());
  if (mismatch)
    return wrong;
  return Eval::commence(ds_,is_input);
}

// Finish evaluation and, when logging is enabled, report the error
// accumulated during the last epoch.
void
ANN::Eval_bprop::conclude()
{
  Eval::conclude();
  if (!log)
    return;
  log->nl(label) << "last epoch error:" << error;
  log->nl();
}

void
ANN::Eval_bprop::open()
{
  error=0;
  des_output->reset();
  Eval::open();
}

void
ANN::Eval_bprop::feed_back(const Term &in, const Term &out,
                           const Term *out_fb, Term *in_fb)
{
  des_output->locate(torg->where());
  Term         *grad_=(ne ? (ne->fb_accept(in_fb) ? &grad : 0) : in_fb);
  double       *g=(grad_ ? grad_->get_data() : 0), d, e=0;
  const double *y=out.get_data(), *Y=des_output->term().get_data();
  size_t        i, n=get_out_size();
  for (i=0; i<n; ++i) if ((Y[i]!=NAN) && Y)
  {
    d=y[i]-Y[i];
    if (g) g[i]=d;
    e+=d*d;
  }
  else if (g) g[i]=0;
  error+=e/2;
  if (log)
  {
    if (label_in) log->log(in,label_in);
    log->log(out,label_out);
    log->log(des_output->term(),label_des_out);
  }
  if (ne) ne->feed_back(in,out,grad_,in_fb);
}

// End an epoch.  The returned status combines the base-class status with
// whether the accumulated epoch error fell below the threshold (through
// Status's operator+, defined elsewhere).  The epoch error is logged when
// logging is enabled.
ANN::NE::Status
ANN::Eval_bprop::close()
{
  Status status=Eval::close() + (error<threshold);
  if (!log)
    return status;
  log->nl(label) << "epoch error:" << error;
  log->nl();
  return status;
}

