/***************************************************************************
                          ann_nit_norm.cpp  -  description
                             -------------------
    begin                : Sat Sep  6  2003
    copyright            : (C) 2003 by Bartosz Lis
    email                : bartoszl@ics.p.lodz.pl
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/

#include "config.h"

#include <ann_nit_norm.h>

  //                //
 // class NIT_norm //
//                //

// Constructs the normalizing transform from its configuration `c`.
// Without bias the output has exactly the input's sizes.  With bias the output
// gets one extra slot beyond the flattened input; feed() stores the
// normalization factor there (X[n] = norm).
// NOTE(review): assumes `out_size << k` defines the output as a dimension of
// length k on the freshly constructed out_size -- confirm against ANN::Size.
ANN::NIT_norm::NIT_norm(NITC_norm &c, const Init *inst)
: NIT(c.log,inst), bias(c.bias), norm(0)
{
  if (!bias)
  {
    out_size=in_size;
    return;
  }
  out_size << (in_size.total_size()+1);
}

// Destructor: no explicit cleanup is performed here; members and the NIT base
// are destroyed automatically.
ANN::NIT_norm::~NIT_norm() {}

// Returns the output size descriptor that was fixed in the constructor.
const ANN::Size &
ANN::NIT_norm::get_out_sizes() const
{
  const ANN::Size &sizes=out_size;
  return sizes;
}

// Returns the total (flattened) number of output values.
size_t
ANN::NIT_norm::get_out_size() const
{
  const size_t total=out_size.total_size();
  return total;
}

/*
 * Forward pass: scales the input vector to unit Euclidean length.
 *
 * The squared length is accumulated starting from `bias`, so
 *   norm = 1/sqrt(bias + sum_i x[i]^2)
 * and out[i] = x[i]*norm for i < n (n = flattened input size).  When bias is
 * enabled the extra output out[n] receives norm itself.
 *
 * The reciprocal length is kept in the member `norm` because feed_back()
 * reuses it.  If the squared length is zero (all-zero input without bias),
 * `norm` stays 0 and the first n outputs are set to 0.
 *
 * in  - input term; its first n values are read.
 * out - output term; must provide get_out_size() values.
 * Returns a value-initialized NE::Status (see NOTE at the return).
 */
ANN::NE::Status
ANN::NIT_norm::feed(const Term &in, Term &out)
{
  size_t  i, n=in_size.total_size();
  double *X=out.get_data();
  norm=bias;
  if (n)
  {
    const double *x=in.get_data();
    for (i=0; i<n; ++i) norm+=x[i]*x[i];
    if (norm)
    {
      norm=1/sqrt(norm);
      for (i=0; i<n; ++i) X[i]=x[i]*norm;
    }
    else
    {
      // Zero-length input: define the output as the zero vector rather than
      // leaving whatever the output buffer held before.  feed_back() likewise
      // propagates a zero error when norm == 0.
      for (i=0; i<n; ++i) X[i]=0;
    }
  }
  if (bias) X[n]=norm;
  if (log) log->log(out,label);
  // Previously control fell off the end of this non-void function, which is
  // undefined behaviour for any caller reading the result.
  // NOTE(review): a value-initialized Status is returned because the
  // "success" enumerator of NE::Status is not visible in this file --
  // replace it with the project's OK value once confirmed.
  return ANN::NE::Status();
}

// Feed-back acceptance: passes the `former` flag through unchanged.
bool
ANN::NIT_norm::fb_accept(bool former) const
{
  bool const accepted=former;
  return accepted;
}

void
ANN::NIT_norm::feed_back(const Term &in, const Term &out,
                         const Term *out_fb, Term *in_fb)
{
  if (in_fb && out_fb)
  {
    size_t        i, j, n=in_size.total_size();
    const double *X=out.get_data(), *E=out_fb->get_data();
    double       *e=in_fb->get_data(), E_;    
    if (norm) 
    {
      for (j=0; j<n; ++j) e[j]=E[j]*norm;
      for (i=0; i<j; ++i) 
      {
        E_=E[i]*X[i]*norm;
        for (j=0; j<n; ++j) e[j]-=E_*X[j];
      }
    }
    else for (j=0; j<n; ++j) e[j]=0;
  }
}

