/***************************************************************************
                          ann_eval_compet.cpp  -  competitive (winner-takes-all) evaluator
                             -------------------
    begin                : Sat Sep  6 2003
    copyright            : (C) 2003 by Bartosz Lis
    email                : bartoszl@ics.p.lodz.pl
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/

#include "config.h"

#include <ann_eval_compet.h>

  //                        //
 // class ANN::Eval_compet //
//                        //

// Competition evaluator over the network referenced by ne_, logging via
// log_. A null label_ falls back to "competition". With compact_ set,
// feed() emits the winning index instead of a one-hot vector.
// When the Eval base leaves ne set, comp is sized to the network output
// because find_max() feeds the network into comp.
ANN::Eval_compet::Eval_compet(NIT &ne_, bool compact_, Log *log_, 
                              const char *label_)
: Eval(ne_, log_, (label_ != 0) ? label_ : "competition"),
  compact(compact_) 
{
  if (!ne) return;
  comp.resize(ne->get_out_sizes());
}

// Competition evaluator over the network referenced by ne_, without a log.
// A null label_ falls back to "competition"; compact_ selects index output
// (true) versus one-hot output (false) in feed().
// comp is sized to the network output (when ne is set) for find_max().
ANN::Eval_compet::Eval_compet(NIT &ne_, bool compact_, const char *label_)
: Eval(ne_, (label_ != 0) ? label_ : "competition"),
  compact(compact_) 
{
  if (ne)
  {
    comp.resize(ne->get_out_sizes());
  }
}

// Non-compact (one-hot output) competition evaluator over the network
// referenced by ne_, logging via log_. A null label_ falls back to
// "competition".
// comp must be sized to the network output, as in every other overload:
// find_max() feeds the network into comp, so leaving it unsized here made
// this overload use a buffer that had never been resized.
ANN::Eval_compet::Eval_compet(NIT &ne_, Log *log_, const char *label_)
: Eval(ne_, log_, label_ ? label_ : "competition"), compact(false) 
{
  if (ne) comp.resize(ne->get_out_sizes());
}

// Non-compact (one-hot output) competition evaluator over the network
// referenced by ne_, without a log. A null label_ falls back to
// "competition". comp is sized to the network output when ne is set.
ANN::Eval_compet::Eval_compet(NIT &ne_, const char *label_)
: Eval(ne_, (label_ != 0) ? label_ : "competition"),
  compact(false) 
{
  if (!ne) return;
  comp.resize(ne->get_out_sizes());
}

// Pointer variant: competition evaluator over the network pointed to by
// ne_, logging via log_. A null label_ falls back to "competition";
// compact_ selects index (true) or one-hot (false) output in feed().
// comp is sized to the network output when the base leaves ne set.
ANN::Eval_compet::Eval_compet(NIT *ne_, bool compact_, Log *log_, 
                              const char *label_)
: Eval(ne_, log_, (label_ != 0) ? label_ : "competition"),
  compact(compact_) 
{
  if (ne)
  {
    comp.resize(ne->get_out_sizes());
  }
}

// Pointer variant without a log. A null label_ falls back to
// "competition"; compact_ selects index (true) or one-hot (false) output.
// comp is sized to the network output when ne is set.
ANN::Eval_compet::Eval_compet(NIT *ne_, bool compact_, const char *label_)
: Eval(ne_, (label_ != 0) ? label_ : "competition"),
  compact(compact_) 
{
  if (!ne) return;
  comp.resize(ne->get_out_sizes());
}

// Pointer variant, non-compact (one-hot output), logging via log_.
// A null label_ falls back to "competition". comp is sized to the
// network output when ne is set.
ANN::Eval_compet::Eval_compet(NIT *ne_, Log *log_, const char *label_)
: Eval(ne_, log_, (label_ != 0) ? label_ : "competition"),
  compact(false) 
{
  if (ne)
  {
    comp.resize(ne->get_out_sizes());
  }
}

// Pointer variant, non-compact (one-hot output), without a log.
// A null label_ falls back to "competition". comp is sized to the
// network output when ne is set.
ANN::Eval_compet::Eval_compet(NIT *ne_, const char *label_)
: Eval(ne_, (label_ != 0) ? label_ : "competition"),
  compact(false) 
{
  if (!ne) return;
  comp.resize(ne->get_out_sizes());
}

// Stand-alone competition evaluator taking input size in_size_ directly
// (no network argument), logging via log_. A null label_ falls back to
// "competition"; compact_ selects index (true) or one-hot (false) output.
// NOTE(review): presumably the Eval base leaves ne null here, making the
// resize below dead code kept for symmetry -- confirm.
ANN::Eval_compet::Eval_compet(Size in_size_, bool compact_, Log *log_, 
                              const char *label_)
: Eval(in_size_, log_, (label_ != 0) ? label_ : "competition"),
  compact(compact_) 
{
  if (ne)
  {
    comp.resize(ne->get_out_sizes());
  }
}

// Stand-alone competition evaluator over inputs of size in_size_, without
// a log. A null label_ falls back to "competition"; compact_ selects index
// (true) or one-hot (false) output.
// NOTE(review): ne is presumably null for this overload -- confirm.
ANN::Eval_compet::Eval_compet(Size in_size_, bool compact_, const char *label_)
: Eval(in_size_, (label_ != 0) ? label_ : "competition"),
  compact(compact_) 
{
  if (!ne) return;
  comp.resize(ne->get_out_sizes());
}

// Stand-alone, non-compact (one-hot output) evaluator over inputs of size
// in_size_, logging via log_. A null label_ falls back to "competition".
// NOTE(review): ne is presumably null for this overload -- confirm.
ANN::Eval_compet::Eval_compet(Size in_size_, Log *log_, const char *label_)
: Eval(in_size_, log_, (label_ != 0) ? label_ : "competition"),
  compact(false) 
{
  if (ne)
  {
    comp.resize(ne->get_out_sizes());
  }
}

// Stand-alone, non-compact (one-hot output) evaluator over inputs of size
// in_size_, without a log. A null label_ falls back to "competition".
// NOTE(review): ne is presumably null for this overload -- confirm.
ANN::Eval_compet::Eval_compet(Size in_size_, const char *label_)
: Eval(in_size_, (label_ != 0) ? label_ : "competition"),
  compact(false) 
{
  if (!ne) return;
  comp.resize(ne->get_out_sizes());
}

// Builds an evaluator from the EvalC_compet object c (and inst, forwarded
// to the Eval base), taking its compact flag from c. comp is sized to the
// network output when the base leaves ne set.
ANN::Eval_compet::Eval_compet(EvalC_compet &c, const Init *inst)
: Eval(c, inst),
  compact(c.compact) 
{
  if (!ne) return;
  comp.resize(ne->get_out_sizes());
}

// Destructor: this class performs no explicit cleanup of its own.
ANN::Eval_compet::~Eval_compet() {}

// Locates the winning component for input `in`.
// If a network ne is attached, `in` is first fed through it into comp and
// the search runs over comp; otherwise it runs over `in` itself.
// Stores in max_ind the index of the first largest element (0 for an empty
// vector). Returns the network's status if feeding failed, else done.
ANN::NE::Status 
ANN::Eval_compet::find_max(const Term &in)
{
  const Term *src=&in;
  if (ne)
  {
    const Status st=ne->feed(in,comp);
    if (st!=done) return st;
    src=&comp;
  }
  const double *data=src->get_data();
  const size_t  count=src->get_size();
  max_ind=0;
  // Strict '<' keeps the first occurrence when several elements tie.
  for (size_t k=1; k<count; ++k)
    if (data[max_ind]<data[k]) max_ind=k;
  return done;
}

// Output shape: single_out_size in compact mode (one index), otherwise the
// attached network's output sizes, or this evaluator's input sizes when no
// network is attached.
const ANN::Size &
ANN::Eval_compet::get_out_sizes() const
{
  if (compact) return single_out_size;
  if (ne)      return ne->get_out_sizes();
  return get_in_sizes();
}

// Number of output elements: 1 in compact mode, otherwise the attached
// network's output size, or the input size when no network is attached.
size_t
ANN::Eval_compet::get_out_size() const
{
  if (compact) return 1;
  return ne ? ne->get_out_size() : get_in_size();
}

// Forward pass: finds the winning component of `in` (see find_max) and
// writes it to `out`, either as the index itself (compact mode, out[0])
// or as a one-hot vector of Eval::get_out_size() elements. Logs `out`
// under `label` when a log is attached. Returns find_max's failure status,
// otherwise done.
ANN::NE::Status
ANN::Eval_compet::feed(const Term &in, Term &out)
{
  Status ret;
  if ((ret=find_max(in))!=done) return ret;
  double *y=out.get_data();
  if (compact) y[0]=max_ind;
  else 
  {
    // Single pass over the n outputs: every write stays in [0,n), so an
    // empty output (n==0, where max_ind is still 0) no longer triggers an
    // out-of-bounds store of the 1.
    size_t i, n=Eval::get_out_size();
    for (i=0; i<n; ++i) y[i]=(i==max_ind) ? 1.0 : 0.0;
  }
  if (log) log->log(out,label);
  return done;
}

// Backward pass. The gradient target is the member `grad` when a network
// ne is attached and accepts feedback into in_fb, or in_fb itself when no
// network is attached; nothing happens if that target is null.
// Compact mode: the target is filled with a one-hot mask of the winner
// (Eval::get_out_size() elements). Non-compact: the one-hot `out` is used
// directly (passed by pointer to ne, or copied into in_fb without ne).
// With a network, the resulting gradient is propagated via ne->feed_back.
// NOTE(review): out_fb is not read; the downstream gradient does not
// depend on incoming feedback -- confirm this is intended.
void
ANN::Eval_compet::feed_back(const Term &in, const Term &out,
                            const Term *out_fb, Term *in_fb)
{
  Term       *grad_=(ne ? (ne->fb_accept(in_fb) ? &grad : 0) : in_fb);
  const Term *grad_c=grad_;
  if (grad_)
  {
    double *g=grad_->get_data();
    size_t  i, n=Eval::get_out_size();
    if (compact)
    {
      // Single pass keeps every write inside [0,n); an empty target (n==0)
      // no longer receives an out-of-bounds store at index max_ind.
      for (i=0; i<n; ++i) g[i]=(i==max_ind) ? 1.0 : 0.0;
    }
    else if (ne) grad_c=&out;
    else *grad_=out;
    if (ne) ne->feed_back(in,comp,grad_c,in_fb);
  }
}

