/***************************************************************************
                          ann_eval_compet.cpp  -  description
                             -------------------
    begin                : sob wrz  6 2003
    copyright            : (C) 2003 by Bartosz Lis
    email                : bartoszl@ics.p.lodz.pl
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/

#include "config.h"

#include <ann_eval_compet.h>

  //                        //
 // class ANN::Eval_compet //
//                        //

// Competition evaluator built on the network referenced by `ne_`.
// `log_` is handed to the Eval base (feed() logs through `log` when it is
// set).  `scalar_` selects what feed() produces: the winner's index (true) or
// a one-hot vector (false).  A null `label_` falls back to "competition".
// When the base ends up holding a network, the intermediate buffer `comp` is
// sized to that network's outputs.
ANN::Eval_compet::Eval_compet(NIT &ne_, bool scalar_, Log *log_, const char *label_)
: Eval(ne_, log_, (label_ != 0) ? label_ : "competition"), scalar(scalar_)
{
  if (ne)
  {
    comp.resize(ne->get_out_sizes());
  }
}

// Competition evaluator built on the network referenced by `ne_`, with no
// logger.  `scalar_` chooses between a winner-index output (true) and a
// one-hot output (false).  A null `label_` is replaced by "competition".
// `comp` is sized to the network's outputs when the base holds a network.
ANN::Eval_compet::Eval_compet(NIT &ne_, bool scalar_, const char *label_)
: Eval(ne_, label_ == 0 ? "competition" : label_), scalar(scalar_)
{
  if (ne)
  {
    comp.resize(ne->get_out_sizes());
  }
}

// Competition evaluator built on the network referenced by `ne_`.  Output is
// one-hot (scalar mode off).  `log_` is handed to the Eval base, and a null
// `label_` falls back to "competition".
// Like every other network-taking constructor, this one sizes the
// intermediate buffer `comp` to the network's outputs.  Previously this
// overload skipped that step, so feed() handed an unsized `comp` to ne->feed().
ANN::Eval_compet::Eval_compet(NIT &ne_, Log *log_, const char *label_)
: Eval(ne_, log_, label_ ? label_ : "competition"), scalar(false) 
{
  if (ne) comp.resize(ne->get_out_sizes());
}

// Competition evaluator built on the network referenced by `ne_`, with
// one-hot output and no logger.  A null `label_` becomes "competition".
// `comp` is sized to the network's outputs when the base holds a network.
ANN::Eval_compet::Eval_compet(NIT &ne_, const char *label_)
: Eval(ne_, (label_ != 0) ? label_ : "competition"), scalar(false)
{
  if (ne)
  {
    comp.resize(ne->get_out_sizes());
  }
}

// Competition evaluator over the network pointed to by `ne_` (the pointer is
// forwarded to the Eval base as-is).  `log_` goes to the base as well.
// `scalar_` selects winner-index (true) or one-hot (false) output.  A null
// `label_` falls back to "competition".  `comp` is sized only if the base
// ends up with a network.
ANN::Eval_compet::Eval_compet(NIT *ne_, bool scalar_, Log *log_, const char *label_)
: Eval(ne_, log_, label_ == 0 ? "competition" : label_), scalar(scalar_)
{
  if (ne)
  {
    comp.resize(ne->get_out_sizes());
  }
}

// Competition evaluator over the network pointed to by `ne_`, with no
// logger.  `scalar_` selects winner-index (true) or one-hot (false) output.
// A null `label_` becomes "competition".  `comp` is sized to the network's
// outputs when a network is present.
ANN::Eval_compet::Eval_compet(NIT *ne_, bool scalar_, const char *label_)
: Eval(ne_, (label_ != 0) ? label_ : "competition"), scalar(scalar_)
{
  if (ne)
  {
    comp.resize(ne->get_out_sizes());
  }
}

// Competition evaluator over the network pointed to by `ne_`, with one-hot
// output.  `log_` is forwarded to the Eval base, and a null `label_` falls
// back to "competition".  `comp` is sized to the network's outputs when a
// network is present.
ANN::Eval_compet::Eval_compet(NIT *ne_, Log *log_, const char *label_)
: Eval(ne_, log_, label_ == 0 ? "competition" : label_), scalar(false)
{
  if (ne)
  {
    comp.resize(ne->get_out_sizes());
  }
}

// Competition evaluator over the network pointed to by `ne_`, with one-hot
// output and no logger.  A null `label_` becomes "competition".  `comp` is
// sized to the network's outputs when a network is present.
ANN::Eval_compet::Eval_compet(NIT *ne_, const char *label_)
: Eval(ne_, (label_ != 0) ? label_ : "competition"), scalar(false)
{
  if (ne)
  {
    comp.resize(ne->get_out_sizes());
  }
}

// Competition evaluator whose input size is given directly by `in_size_`
// (forwarded to the Eval base).  The base presumably leaves `ne` null here,
// in which case feed() competes on its raw input (TODO confirm in Eval).
// `scalar_` selects winner-index (true) or one-hot (false) output.  `log_`
// goes to the base, and a null `label_` falls back to "competition".
ANN::Eval_compet::Eval_compet(Size in_size_, bool scalar_, Log *log_, 
                              const char *label_)
: Eval(in_size_, log_, label_ == 0 ? "competition" : label_), scalar(scalar_)
{
  // Only relevant if the base did attach a network.
  if (ne)
  {
    comp.resize(ne->get_out_sizes());
  }
}

// Competition evaluator with input size `in_size_` and no logger.
// `scalar_` selects winner-index (true) or one-hot (false) output.  A null
// `label_` becomes "competition".  `comp` is sized only if the base holds a
// network (presumably it does not for this overload -- verify in Eval).
ANN::Eval_compet::Eval_compet(Size in_size_, bool scalar_, const char *label_)
: Eval(in_size_, (label_ != 0) ? label_ : "competition"), scalar(scalar_)
{
  if (ne)
  {
    comp.resize(ne->get_out_sizes());
  }
}

// Competition evaluator with input size `in_size_` and one-hot output.
// `log_` is forwarded to the Eval base, and a null `label_` falls back to
// "competition".  `comp` is sized only if the base holds a network.
ANN::Eval_compet::Eval_compet(Size in_size_, Log *log_, const char *label_)
: Eval(in_size_, log_, label_ == 0 ? "competition" : label_), scalar(false)
{
  if (ne)
  {
    comp.resize(ne->get_out_sizes());
  }
}

// Competition evaluator with input size `in_size_`, one-hot output and no
// logger.  A null `label_` becomes "competition".  `comp` is sized only if
// the base holds a network.
ANN::Eval_compet::Eval_compet(Size in_size_, const char *label_)
: Eval(in_size_, (label_ != 0) ? label_ : "competition"), scalar(false)
{
  if (ne)
  {
    comp.resize(ne->get_out_sizes());
  }
}

// Builds an evaluator from `c` (an EvalC_compet), copying its `scalar` flag.
// Both `c` and `inst` are forwarded to the Eval base.  As in the other
// constructors, `comp` is sized to the network's outputs when the base holds
// a network.
ANN::Eval_compet::Eval_compet(EvalC_compet &c, const Init *inst)
: Eval(c, inst), scalar(c.scalar)
{
  if (ne)
  {
    comp.resize(ne->get_out_sizes());
  }
}

// Destructor.  There is nothing to release explicitly: members such as `comp`
// are destroyed automatically, and the Eval base cleans up after itself.
// The former body re-ran the constructors' `comp.resize(...)`.  That was a
// copy-paste leftover which did useless (possibly allocating, hence
// potentially throwing) work on an object about to be destroyed.
ANN::Eval_compet::~Eval_compet()
{
}

// Output shape of this evaluator:
//  - scalar mode: `single_out_size` (feed() writes only the winner's index);
//  - with a wrapped network: the network's output sizes;
//  - otherwise: the evaluator's own input sizes.
const ANN::Size &
ANN::Eval_compet::get_out_sizes() const
{
  if (scalar) return single_out_size;
  if (ne) return ne->get_out_sizes();
  return get_in_sizes();
}

// Number of output components: 1 in scalar mode, otherwise the wrapped
// network's output size or, without a network, the input size.
size_t
ANN::Eval_compet::get_out_size() const
{
  if (scalar) return 1;
  if (ne) return ne->get_out_size();
  return get_in_size();
}

// Winner-takes-all forward pass.
//
// With a wrapped network `ne`, `in` is first fed through it into the `comp`
// buffer; any status other than `done` is returned unchanged.  Without a
// network, `in` itself is the competition input.
//
// Among the first Eval::get_out_size() components, the first maximal one
// wins (ties go to the lowest index).  In scalar mode out[0] receives the
// winner's index.  Otherwise `out` becomes a one-hot vector: 1 at the winner
// and 0 at every other of those components.  The output is logged under
// `label` when a logger is set.
//
// Returns `done`, or the status reported by `ne` if it did not finish.
ANN::NE::Status
ANN::Eval_compet::feed(const Term &in, Term &out)
{
  if (ne)
  {
    const Status ret=ne->feed(in,comp);
    if (ret!=done) return ret;
  }
  double       *y=out.get_data();
  const double *X=(ne ? static_cast<const Term &>(comp) : in).get_data();
  const size_t  n=Eval::get_out_size();   // size of the competing vector
  if (n)
  {
    size_t max_ind=0;
    double max_val=X[0];
    for (size_t i=1; i<n; ++i)
      if (max_val<X[i]) max_val=X[max_ind=i];
    if (scalar)
    {
      y[0]=static_cast<double>(max_ind);
    }
    else
    {
      // Write every component.  The old trailing loop was bounded by max_ind
      // instead of n, so entries after the winner were never cleared.
      for (size_t i=0; i<n; ++i) y[i]=(i==max_ind) ? 1.0 : 0.0;
    }
  }
  if (log) log->log(out,label);
  return done;
}

// Feedback pass of the competition stage.
//
// The vector sent backwards is derived from `out` (the result of feed());
// `out_fb` is not read here.
//  - With a wrapped network: work only if ne->fb_accept(in_fb) is true.  The
//    vector is then handed to ne->feed_back() together with `comp`.  In
//    scalar mode it is first expanded into the `grad` buffer; otherwise
//    `out` is passed directly.
//  - Without a network: the result is written straight into `in_fb`, if
//    that is non-null.  In scalar mode this is a one-hot vector built from
//    out[0]; otherwise it is a copy of `out`.
void
ANN::Eval_compet::feed_back(const Term &in, const Term &out,
                            const Term *out_fb, Term *in_fb)
{
  Term       *grad_=(ne ? (ne->fb_accept(in_fb) ? &grad : 0) : in_fb);
  const Term *grad_c=grad_;
  if (!grad_) return;
  double       *g=grad_->get_data();
  const size_t  n=Eval::get_out_size();
  if (scalar)
  {
    // out[0] holds the winner's index as written by feed().  Rebuild the
    // one-hot vector over all n components.  The old trailing loop was
    // bounded by max_ind instead of n, so entries after the winner stayed
    // stale.  An out-of-range index now simply yields all zeros instead of
    // writing past the buffer.
    const size_t max_ind=static_cast<size_t>(out[0]);
    for (size_t i=0; i<n; ++i) g[i]=(i==max_ind) ? 1.0 : 0.0;
  }
  else if (ne) grad_c=&out;
  else *grad_=out;
  if (ne) ne->feed_back(in,comp,grad_c,in_fb);
}

