/***************************************************************************
                          ann_to.h  -  description
                             -------------------
    begin                : pon kwi 14 2003
    copyright            : (C) 2003 by Bartosz Lis
    email                : bartoszl@ics.p.lodz.pl
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/

#ifndef __ANN_TO_H
#define __ANN_TO_H

#include <ctype.h>

namespace ANN
{
  /* Forward declarations.  NIT keeps a TO* (its training organizer) and
     TO keeps a NIT* (its trainee), so each must be nameable before the
     other is defined; they are declared here, ahead of the project headers
     included below. */
  class NIT;
  class TO;
}

#include <ann_ds.h>
#include <ann_weight.h>
#include <ann_ne.h>

namespace ANN
{

  //           //
 // class NIT //
//           //

  class NIT : public NE
  {
  public:
    typedef Log         Common_t;
    typedef const Init *Instance_t;
    typedef NE          Base_t;

  private:
    /* Private so they cannot be used: a NIT is always built from input
       sizes or from an Init instance (see the protected constructors). */
    NIT();
    NIT(const NE &that);
    /* Copying is forbidden.  NIT(const NE &) above is merely a converting
       constructor; it does NOT suppress the implicitly declared copy
       constructor and copy assignment, which (whenever NE's are usable)
       would silently duplicate the torg back-pointer.  Declared here and
       intentionally never defined. */
    NIT(const NIT &that);
    NIT &operator=(const NIT &that);

  protected:
    /* Training organizer this network is attached to; every constructor
       sets it to 0, and which()/where() test it before use. */
    TO *torg;

    NIT(Size in_size_, Log &log_, const char *label_=0)
    : NE(in_size_,log_,label_), torg(0) {}
    NIT(Size in_size_, Log *log_, const char *label_=0)
    : NE(in_size_,log_,label_), torg(0) {}
    NIT(Size in_size_, const char *label_=0)
    : NE(in_size_,label_), torg(0) {}
    NIT(Log &log_, Instance_t inst) : NE(log_,inst), torg(0) {}
    NIT(Log *log_, Instance_t inst) : NE(log_,inst), torg(0) {}
    NIT(Instance_t inst) : NE(inst), torg(0) {}

  public:
    virtual ~NIT();

    /* returns the attached training organizer, or 0 if there is none */
    TO                   *get_torg() const { return torg; }

    /** before training starts or after it stops **/
    virtual void          register_weights(TO &torg_);
    /* registers weights in a training organizer */

    /** training period commencing and concluding **/
    virtual Status        commence(DS &ds_, bool is_input);
    /* training period commences */
    virtual void          conclude();
    /* training period concludes */

    /** single epoch of training **/
    virtual void          open();
    /* called every time an epoch starts */
    virtual void          adapt(const Term &in, Term &out)=0;
    /* called to conduct epoch steps */
    virtual Status        close()=0;
    /* called every time an epoch ends */

    /** training **/
    virtual Status        adapt_for(size_t epochs=1);
    /* conducts training for no more than the specified number of epochs */
    virtual size_t        get_epoch() const;
    /* reports current epoch */
    inline size_t         which() const;
    /* reports current epoch step (size_max when no organizer is attached) */
    inline const Loc     &where() const;
    /* reports current position in the DS (Loc::empty when no organizer is
       attached) */
  };

  typedef FB<NIT, const Init *> NIT_FB;

  //           //
 // class TOC //
//           //

  class TOC
  {
  public:
    /* Common construction parameters shared by TO objects (TO::Common_t):
       TO(Common_t &, Instance_t) forwards `log` to NE, creates its trainee
       through `f`, and keeps `ds`. */
    Log    *log;   /* log handed to NE; defaults to 0 */
    NIT_FB *f;     /* factory used to create the trainee network */
    DS     *ds;    /* DS copied into each TO built from this object */

    /* Log supplied by reference: its address is stored. */
    TOC(NIT_FB &f_, DS &ds_, Log &log_)
    : log(&log_),
      f(&f_),
      ds(&ds_)
    {}

    /* Log supplied by pointer (optional, defaults to 0). */
    TOC(NIT_FB &f_, DS &ds_, Log *log_=0)
    : log(log_),
      f(&f_),
      ds(&ds_)
    {}

    virtual ~TOC();
  };
  
  //          //
 // class TO //
//          //

  class TO : public NE
  {
  public:
    typedef TOC         Common_t;
    typedef const Init *Instance_t;
    typedef NE          Base_t;
    
  private:
    /* Private so they cannot be used: a TO is always built around a
       trainee (see the protected constructors). */
    TO();
    TO(const NE &that);
    /* Copying is forbidden.  TO(const NE &) above is merely a converting
       constructor and does NOT suppress the implicitly declared copy
       constructor and copy assignment; those would duplicate the trainee,
       ws and ds pointers (ws comes from WS::create(), so two objects would
       share it).  Declared here and intentionally never defined. */
    TO(const TO &that);
    TO &operator=(const TO &that);

  protected:
    NIT      *trainee;   /* network being trained; attached by embed() */
    WS       *ws;        /* obtained from WS::create() in every constructor */
    DS       *ds;        /* DS queried by which()/where() */
    TERM      out_data;
    size_t    epoch;     /* epoch counter, starts at 0 */

    void embed(NIT &trainee_);

    /* trainee is set to 0 before embed() runs so that it never holds an
       indeterminate value; initializers follow declaration order. */
    TO(NIT &trainee_, DS &ds_, Log &log_, const char *label_=0)
    : NE(trainee_.get_in_sizes(), log_, label_), trainee(0),
      ws(WS::create()), ds(&ds_), epoch(0) { embed(trainee_); }
    TO(NIT &trainee_, DS &ds_, Log *log_, const char *label_=0)
    : NE(trainee_.get_in_sizes(), log_, label_), trainee(0),
      ws(WS::create()), ds(&ds_), epoch(0) { embed(trainee_); }
    TO(NIT &trainee_, DS &ds_, const char *label_=0)
    : NE(trainee_.get_in_sizes(), label_), trainee(0),
      ws(WS::create()), ds(&ds_), epoch(0) { embed(trainee_); }
    /* builds the trainee through the factory held by the common part */
    TO(Common_t &c, Instance_t inst)
    : NE(c.log,inst), trainee(0), ws(WS::create()), ds(c.ds), epoch(0)
    { embed(*c.f->create(inst)); }

  public:
    virtual ~TO();

    void                  register_weight(W &w);
    NIT                  &get_trainee() const { return *trainee; }
    virtual const Size   &get_out_sizes() const;
    virtual size_t        get_out_size() const;

    /** output computation **/
    virtual Status        feed(const Term &in, Term &out);

    /** before training starts or after it stops **/
    virtual void          reset(bool prod=false);
    /* switch between training and production modes */
    virtual bool          load(Parser_log &parser, time_t begin=0,
                               time_t end=NAV_get(time_t));
    /* loads state information from a log, you should call reset() then */
    virtual void          store(Log &log_);
    /* stores state information to a log */
    virtual void          randomize();
    /* randomizes weights */

    /** training **/
  protected:  
    virtual Status        adapt_internal()=0;
    /* conducts training for a single epoch */
  public:  
    virtual Status        adapt_for(size_t epochs=1);
    /* conducts training for no more than the specified number of epochs */
    virtual size_t        get_epoch() const;
    /* reports current epoch */
    size_t                which() const { return ds->which(); }
    /* reports current epoch step */
    const Loc            &where() const { return ds->where(); }
    /* reports current position in the DS */
  };

  typedef FB<TO, const Init *> TO_FB;

} //namespace ANN

inline size_t
ANN::NIT::which() const
{
  /* Delegate to the attached training organizer; with none attached,
     report size_max. */
  if (!torg)
    return size_max;
  return torg->which();
}

/* Current position in the DS, delegated to the attached training organizer;
   with none attached, Loc::empty is returned.  `inline` is spelled out here
   (as on which()) because this definition lives in a header: without the
   in-class `inline` it would be defined in every including translation unit
   and fail to link. */
inline const ANN::Loc &
ANN::NIT::where() const
{
  if (torg)
    return torg->where();
  return Loc::empty;
}

#endif /* __ANN_TO_H */
