/***************************************************************************
                          ann_nit_linear.cpp  -  description
                             -------------------
    begin                : sob wrz  6 2003
    copyright            : (C) 2003 by Bartosz Lis
    email                : bartoszl@ics.p.lodz.pl
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/

#include "config.h"

#include <ann_nit_linear.h>

  //                  //
 // class NIT_linear //
//                  //

// Builds a linear layer from its configuration `c` and instance descriptor
// `inst`. One weight vector per output element is obtained from the weight
// factory `c.wf`. Each vector is sized for the flattened input plus an
// optional bias slot (`w_grad` is sized the same way). A duplicated "sum"
// label is kept in `label_sum` for use by feed()'s logging.
ANN::NIT_linear::NIT_linear(NITC_linear &c, const Init *inst)
: NIT(c.log,inst), out_size(c.out_size), bias(c.bias), w(c.out_size),
  w_grad(inst->size.total_size()+c.bias)
{
  // Descriptor handed to the weight factory; width = inputs (+ bias slot).
  Init         neuron_inst(inst->size.total_size()+bias,label);
  const size_t base_len=neuron_inst.length();
  neuron_inst.set_label("neuron",base_len);
  // Position right after "neuron", where the per-neuron index is written.
  const size_t idx_pos=neuron_inst.length();
  const size_t n_neurons=c.out_size.total_size();
  for (size_t k=0; k<n_neurons; ++k)
  {
    neuron_inst.set_index(k,idx_pos);
    neuron_inst.set_label("weights",neuron_inst.length());
    w[k]=c.wf->create(&neuron_inst);
  }
  // Re-label the base as "sum" and keep a private copy for output logging.
  neuron_inst.set_label("sum",base_len);
  label_sum=neuron_inst.dup_label();
}

// Releases the label buffer obtained from dup_label() in the constructor.
// delete[] on a null pointer is a no-op, so no guard is needed.
// NOTE(review): the class owns this raw buffer; if NIT_linear can be copied,
// the copy operations must be disabled or deep-copy it. Check the header.
ANN::NIT_linear::~NIT_linear()
{
  delete [] label_sum;
}

// Returns the (possibly multi-dimensional) output shape configured for this
// layer. The reference refers to the layer's own member.
const ANN::Size &
ANN::NIT_linear::get_out_sizes() const
{
  return out_size;
}

// Returns the flattened number of output elements, i.e. the number of
// linear neurons (one weight vector each).
size_t
ANN::NIT_linear::get_out_size() const
{
  const size_t total=out_size.total_size();
  return total;
}

// Registers the base-class weights first, then each neuron's weight vector,
// with the training organiser `torg`.
void
ANN::NIT_linear::register_weights(TO &torg)
{
  NIT::register_weights(torg);
  const size_t count=out_size.total_size();
  for (size_t k=0; k<count; ++k)
    torg.register_weight(*w[k]);
}

// The linear layer has nothing to finalise; it unconditionally reports the
// `dumb` status. The exact meaning of `dumb` is defined by ANN::NE and is
// not visible here.
ANN::NE::Status
ANN::NIT_linear::close()
{
  return dumb;
}

// Forward pass: for every output neuron k,
//   out[k] = (bias ? w_k[n_in] : 0) + sum_j in[j] * w_k[j]
// where w_k is neuron k's weight vector and n_in the flattened input size.
// If a log is attached and active, the output is logged under label_sum.
// Always returns `done`.
ANN::NE::Status
ANN::NIT_linear::feed(const Term &in, Term &out)
{
  const size_t  n_out=out_size.total_size();
  const size_t  n_in=in_size.total_size();
  const double *src=in.get_data();
  double       *dst=out.get_data();
  for (size_t k=0; k<n_out; ++k)
  {
    const double *wk=w[k]->values().get_data();
    double       &acc=dst[k];   // accumulate directly in the output slot
    if (bias)
      acc=wk[n_in];             // bias weight is stored after the inputs
    else
      acc=0;
    for (size_t j=0; j<n_in; ++j)
      acc+=src[j]*wk[j];
  }
  if (label_sum && log && log->is_active())
    log->log(out,label_sum);
  return done;
}

// Reports whether this layer accepts feedback. It returns true
// unconditionally; the argument is ignored.
bool
ANN::NIT_linear::fb_accept(bool /*former*/) const
{
  return true;
}

void
ANN::NIT_linear::feed_back(const Term &in, const Term &out,
                          const Term *out_fb, Term *in_fb)
{
  size_t        i, n=out_size.total_size();
  size_t        j, m=in_size.total_size();
  const double *u, *x=in.get_data(), *X=out.get_data(), *E=out_fb->get_data();
  double       *d=w_grad.get_data(), *e, E_;
  if (in_fb)
  {
    e=in_fb->get_data();
    for (j=0; j<m; ++j) e[j]=0;
  }  
  for (i=0; i<n; ++i)
  {
    E_=E[i];
    u=w[i]->values().get_data();
    if (bias) d[m]=-E_;
    for (j=0; j<m; ++j)
    {
      if (in_fb) e[j]+=E_*u[j];
      d[j]=-E_*x[j];
    }  
    w[i]->feed_back(w_grad);
  }
}

