/***************************************************************************
                          ann_ne_mlayer.cpp  -  description
                             -------------------
    begin                : pon kwi 14 2003
    copyright            : (C) 2003 by Bartosz Lis
    email                : bartoszl@ics.p.lodz.pl
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/

#include <ann_ne_mlayer.h>

  //                      //
 // class ANN::NE_mlayer //
//                      //

// Bookkeeping record for one layer of the stack.
//
//   elem_  - the network element implementing this layer (stored by address).
//   former - the element of the preceding layer, or 0 for the first layer.
//
// `in` is a buffer sized like this layer's input. It receives the preceding
// layer's output during calc(), so it exists only when there is a preceding
// layer. `err` has the same size. adapt() passes it as this layer's in_err
// and then as the preceding layer's out_err. It is allocated only when the
// preceding layer reports itself as supervised.
ANN::NE_mlayer::LayerData::LayerData(NE &elem_, NE *former)
: elem(&elem_),
  in(former!=0 ? new Term(elem_.get_in_sizes()) : 0),
  err((former!=0 && former->is_supervised()) ? new Term(elem_.get_in_sizes())
                                              : 0)
{
}

// Destroy every LayerData record that add_layer() allocated with new.
ANN::NE_mlayer::~NE_mlayer()
{
  for (L_it pos=layers.begin(), stop=layers.end(); pos!=stop; ++pos)
    delete *pos;
}

// Append an already constructed element as the new last layer.
//
// The element's input size must equal the output size of the current last
// layer. If no layer exists yet, it must equal the total input size of this
// multi-layer element. A zero-sized connection is rejected as well.
//
// Returns true if the layer was appended. Returns false if elem_ is null or
// the sizes do not match.
bool
ANN::NE_mlayer::add_layer(NE *elem_)
{
  if (elem_==0) return false;

  NE     *former=0;
  size_t  expected;
  if (layers.empty())
    expected=in_size.total_size();
  else
  {
    former=layers.back()->elem;
    expected=former->get_out_size();
  }

  if (expected==0 || expected!=elem_->get_in_size()) return false;

  layers.push_back(new LayerData(*elem_,former));
  return true;
}

// Append a new last layer built by `factory`.
//
// The factory receives an Init record whose size is the output size of the
// current last layer. If no layer exists yet, it is this element's own input
// size. Before create() is called, the record is tagged with the label
// "layer" and with the index the new layer will have (layers.size()).
//
// Returns false, leaving the layer list unchanged, when:
//   - factory is null,
//   - the computed input size is empty, or
//   - the factory returns a null element.
bool
ANN::NE_mlayer::add_layer(NE_FB *factory)
{
  if (!factory) return false;
  Init inst(layers.size() ? layers.back()->elem->get_out_sizes() : in_size,
            label);
  if (!inst.size.total_size()) return false;
  inst.set_label("layer",inst.length());
  inst.set_index(layers.size(),inst.length());
  NE *elem_=factory->create(&inst);
  // The result is dereferenced when the LayerData is built below. A null
  // element must therefore be rejected here rather than crash.
  if (!elem_) return false;
  layers.push_back(new LayerData(*elem_,layers.size() ? layers.back()->elem
                                                      : 0));
  return true;
}

// Output sizes of the whole stack. These are the last layer's output sizes,
// or this element's input sizes when no layer has been added.
const ANN::Size &
ANN::NE_mlayer::get_out_sizes() const
{
  if (layers.empty()) return in_size;
  return layers.back()->elem->get_out_sizes();
}

void
ANN::NE_mlayer::calc(const Term &in, Term &out)
{
  L_it        it=layers.begin();
  L_it        end=layers.end();
  const Term *in_=&in;
  Term       *out_;
  LayerData  *l;
  while (it!=end)
  {
    l=*it++;
    l->elem->calc(*in_, *(out_=(it==end ? &out : (*it)->in)));
    in_=out_;
  }
}

// The stack is supervised exactly when its last layer is. An empty stack is
// not supervised.
bool
ANN::NE_mlayer::is_supervised() const
{
  if (layers.empty()) return false;
  return layers.back()->elem->is_supervised();
}

// Forward reset(t, b_reload) to every layer, first to last.
void
ANN::NE_mlayer::reset(TO &t, bool b_reload)
{
  for (L_it cur=layers.begin(), stop=layers.end(); cur!=stop; ++cur)
    (*cur)->elem->reset(t,b_reload);
}

// Forward prepare(t) to every layer, first to last.
void
ANN::NE_mlayer::prepare(TO &t)
{
  for (L_it cur=layers.begin(), stop=layers.end(); cur!=stop; ++cur)
    (*cur)->elem->prepare(t);
}

// Call update(t) on every layer, first to last, and combine the results.
//
// Each layer's status is folded into the running status with
// trained_state(), starting from 0. The combined value is returned, or 0
// when there are no layers.
int
ANN::NE_mlayer::update(TO &t)
{
  int state=0;
  for (L_it cur=layers.begin(), stop=layers.end(); cur!=stop; ++cur)
  {
    int layer_state=(*cur)->elem->update(t);
    state=trained_state(state,layer_state);
  }
  return state;
}

// Forward finish(t) to every layer, first to last.
void
ANN::NE_mlayer::finish(TO &t)
{
  for (L_it cur=layers.begin(), stop=layers.end(); cur!=stop; ++cur)
    (*cur)->elem->finish(t);
}

void
ANN::NE_mlayer::adapt(const Term &in, const Term &out,
                      const Term *out_err, Term *in_err)
{
  L_rit       it=layers.rbegin();
  L_rit       end=layers.rend();
  const Term *in_, *out_=&out, *out_err_=out_err;
  Term       *in_err_;
  LayerData *l;
  while (it!=end)
  {
    l=*it++;
    l->elem->adapt(*(in_=(it==end ? &in : l->in)), *out_,
                   out_err_, in_err_=(it==end ? in_err : l->err));
    out_=in_;
    out_err_=in_err_;
  }
}

