/***************************************************************************
                          ann_ne_mlayer.h  -  NEC_mlayer and NE_mlayer (multilayer NE)
                             -------------------
    begin                : Mon Apr 14 2003
    copyright            : (C) 2003 by Bartosz Lis
    email                : bartoszl@ics.p.lodz.pl
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/

#ifndef __ANN_NE_MLAYER_H
#define __ANN_NE_MLAYER_H

#include <vector>

#include <ann_ne.h>
#include <ann_factory.h>

namespace ANN
{

  // Forward declarations of the two classes defined below.
  class NEC_mlayer;
  class NE_mlayer;

  //                  //
 // class NEC_mlayer //
//                  //

  class NEC_mlayer
  {
  private:
    // Declared but never defined: a factory array is always required.
    NEC_mlayer();

  public:
    NE_FB  **f;    // layer factories; NEC_mlayer never frees the array itself
    size_t   num;  // number of entries of f that are in use
    Log     *log;  // optional log (may be null)

    // Wrap a null-terminated factory array and count its entries.
    // A null array is treated as empty (num == 0) instead of being
    // dereferenced.
    NEC_mlayer(NE_FB **f_, Log *log_=0)
    : f(f_), num(0), log(log_) { if (f) while (f[num]) ++num; }
    // Wrap the first num_ entries of f_ (no terminator is required).
    NEC_mlayer(NE_FB **f_, size_t num_, Log *log_=0)
    : f(f_), num(num_), log(log_) {}
    // Shallow copy: the copy shares the same factory array.
    NEC_mlayer(const NEC_mlayer &that)
    : f(that.f), num(that.num), log(that.log) {}

    // Delete every factory in f[0..num), last first, and leave num == 0.
    // Each deleted slot is reset to null. Because copies share `f`, a later
    // drop_f() on a shallow copy then only deletes null pointers (a no-op),
    // and a null-terminated scan of the array sees it as empty instead of
    // reaching freed objects.
    void drop_f()
    {
      if (!f) { num=0; return; }
      while (num) { --num; delete f[num]; f[num]=0; }
    }

    // Shallow assignment, consistent with the copy constructor.
    NEC_mlayer &operator = (const NEC_mlayer &that)
    { f=that.f; num=that.num; log=that.log; return *this; }
  };

  //                 //
 // class NE_mlayer //
//                 //

  class NE_mlayer : public NE
  {
  protected:
    struct LayerData
    {
    public:
      NE   *elem;
      Term *in;
      Term *err;

      LayerData(NE &elem_, NE *former);
      ~LayerData()
      { if (elem) delete elem; if (in) delete in; if (err) delete err; }
    };
    typedef std::vector<LayerData *>  Layers;
    typedef Layers::iterator          L_it;
    typedef Layers::reverse_iterator  L_rit;
    typedef NEC_mlayer                Common_t;
    typedef NE::Instance_t            Instance_t;
    typedef NE                        Base_t;

  private:
    NE_mlayer();
    NE_mlayer(const NE_mlayer &that);
    
  protected:  
    Layers layers;

    NE_mlayer(Log &log_, Instance_t inst) : NE(log_,inst) {}
    NE_mlayer(Log *log_, Instance_t inst) : NE(log_,inst) {}
    NE_mlayer(Instance_t inst) : NE(inst) {}

    bool add_layer(NE *elem_);
    bool add_layer(NE &elem_) { return add_layer(&elem_); }
    bool add_layer(NE_FB *factory);
    bool add_layer(NE_FB &factory) { return add_layer(&factory); }

    template<class Iterator>
    bool build(Iterator begin, Iterator end)
    {
      for (Iterator it=begin; it!=end; ++it) if (!add_layer(*it)) return false;
      return true;
    }
    bool build(NE_FB **arr, size_t num)
    {
      for (size_t i=0; i<num; ++i) if (!add_layer(arr[i])) return false;
      return true;
    }
    bool build(NE_FB **arr)
    {
      while (*arr) if (add_layer(*arr)) ++arr; else return false;
      return true;
    }

  public:
    template<class Iterator>
    NE_mlayer(Instance_t inst, Log &log_, Iterator begin, Iterator end)
    : NE(log_,inst) { build(begin,end); }
    template<class Iterator>
    NE_mlayer(Instance_t inst, Log *log_, Iterator begin, Iterator end)
    : NE(log_,inst) { build(begin,end); }
    template<class Iterator>
    NE_mlayer(Instance_t inst, Iterator begin, Iterator end)
    : NE(inst) { build(begin,end); }
    NE_mlayer(Instance_t inst, Log &log_, NE_FB **factory, size_t num)
    : NE(log_,inst) { build(factory,num); }
    NE_mlayer(Instance_t inst, Log *log_, NE_FB **factory, size_t num)
    : NE(log_,inst) { build(factory,num); }
    NE_mlayer(Instance_t inst, NE_FB **factory, size_t num)
    : NE(inst) { build(factory,num); }
    NE_mlayer(Instance_t inst, Log &log_, NE_FB **factory)
    : NE(log_,inst) { build(factory); }
    NE_mlayer(Instance_t inst, Log *log_, NE_FB **factory)
    : NE(log_,inst) { build(factory); }
    NE_mlayer(Instance_t inst, NE_FB **factory)
    : NE(inst) { build(factory); }
    NE_mlayer(NEC_mlayer c, Instance_t inst)
    : NE(c.log,inst) { build(c.f,c.num); }

    virtual ~NE_mlayer();

    virtual const Size &get_out_sizes() const;
    virtual void        calc(const Term &in, Term &out);
    virtual bool        is_supervised() const;
    virtual void        reset(TO &t, bool b_reload=false);
    virtual void        prepare(TO &t);
    virtual int         update(TO &t);
    virtual void        finish(TO &t);
    virtual void        adapt(const Term &in, const Term &out,
                              const Term *out_err, Term *in_err);

    template<class Iterator>
    static void drop_f(Iterator begin, Iterator end)
    { while(begin!=end) delete *begin++; }
    static void drop_f(NE_FB **arr, size_t num)
    { for (size_t i=0; i<num; ++i) delete arr[i]; }
    static void drop_f(NE_FB **arr)
    { while (*arr) delete *arr++; }
  };

  // Factory type for NE_mlayer. NOTE(review): F is presumably the factory
  // template provided by ann_factory.h -- confirm.
  typedef F<NE_mlayer> NE_mlayer_F;
  
} // namespace ANN

#endif /* __ANN_NE_MLAYER_H */
