# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
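
# Tests for the AI::MXNet::Gluon recurrent cells and layers: the basic
# RNN/LSTM/GRU cells, the residual/bidirectional/zoneout/sequential wrappers,
# and the fused RNN, LSTM and GRU layers.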

use strict;
use warnings;
use Test::More tests => 77;
use AI::MXNet 'mx';
use AI::MXNet::Gluon 'gluon';
use AI::MXNet::TestUtils qw/allclose almost_equal/;
use AI::MXNet::Base;
use Scalar::Util 'blessed';

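# Unrolls a single RNNCell over three time steps symbolically and checks the
# parameter names, the output names, and the inferred output shapes.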
sub test_rnn
{
    my $cell = gluon->rnn->RNNCell(100, prefix=>'rnn_');
    my $inputs = [map { mx->sym->Variable("rnn_t${_}_data") } 0..2];
    my ($outputs) = $cell->unroll(3, $inputs);
    $outputs = mx->sym->Group($outputs);
    is_deeply([sort $cell->collect_params()->keys()], ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_i2h_bias', 'rnn_i2h_weight']);
    is_deeply($outputs->list_outputs(), ['rnn_t0_out_output', 'rnn_t1_out_output', 'rnn_t2_out_output']);

    my (undef, $outs) = $outputs->infer_shape(rnn_t0_data=>[10,50], rnn_t1_data=>[10,50], rnn_t2_data=>[10,50]);
    is_deeply($outs, [[10, 100], [10, 100], [10, 100]]);
}

test_rnn();

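# Same symbolic unroll checks as test_rnn, but for an LSTMCell.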
sub test_lstm
{
    my $cell = gluon->rnn->LSTMCell(100, prefix=>'rnn_');
    my $inputs = [map { mx->sym->Variable("rnn_t${_}_data") } 0..2];
    my ($outputs) = $cell->unroll(3, $inputs);
    $outputs = mx->sym->Group($outputs);
    is_deeply([sort $cell->collect_params()->keys()], ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_i2h_bias', 'rnn_i2h_weight']);
    is_deeply($outputs->list_outputs(), ['rnn_t0_out_output', 'rnn_t1_out_output', 'rnn_t2_out_output']);

    my (undef, $outs) = $outputs->infer_shape(rnn_t0_data=>[10,50], rnn_t1_data=>[10,50], rnn_t2_data=>[10,50]);
    is_deeply($outs, [[10, 100], [10, 100], [10, 100]]);
}

test_lstm();

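# Checks that the LSTMBias initializer sets the forget-gate slice of the
# i2h bias to $forget_bias and leaves the remaining gate biases at zero.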
sub test_lstm_forget_bias
{
    my $forget_bias = 2;
    my $stack = gluon->rnn->SequentialRNNCell();
    $stack->add(gluon->rnn->LSTMCell(100, i2h_bias_initializer=>mx->init->LSTMBias($forget_bias), prefix=>'l0_'));
    $stack->add(gluon->rnn->LSTMCell(100, i2h_bias_initializer=>mx->init->LSTMBias($forget_bias), prefix=>'l1_'));

    my $dshape = [32, 1, 200];
    my $data = mx->sym->Variable('data');

    my ($sym) = $stack->unroll(1, $data, merge_outputs=>1);
    my $mod = mx->mod->Module($sym, context=>mx->cpu(0));
    $mod->bind(data_shapes=>[['data', $dshape]]);

    $mod->init_params();

    my ($bias_argument) = grep { /i2h_bias$/ } @{ $sym->list_arguments() };
    my $expected_bias = pdl((0)x100, ($forget_bias)x100, (0)x200);
    ok(allclose(($mod->get_params())[0]->{$bias_argument}->aspdl, $expected_bias));
}

test_lstm_forget_bias();

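# Same symbolic unroll checks as test_rnn, but for a GRUCell.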
sub test_gru
{
    my $cell = gluon->rnn->GRUCell(100, prefix=>'rnn_');
    my $inputs = [map { mx->sym->Variable("rnn_t${_}_data") } 0..2];
    my ($outputs) = $cell->unroll(3, $inputs);
    $outputs = mx->sym->Group($outputs);
    is_deeply([sort $cell->collect_params()->keys()], ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_i2h_bias', 'rnn_i2h_weight']);
    is_deeply($outputs->list_outputs(), ['rnn_t0_out_output', 'rnn_t1_out_output', 'rnn_t2_out_output']);

    my (undef, $outs) = $outputs->infer_shape(rnn_t0_data=>[10,50], rnn_t1_data=>[10,50], rnn_t2_data=>[10,50]);
    is_deeply($outs, [[10, 100], [10, 100], [10, 100]]);
}

test_gru();

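# With all weights and biases set to zero the wrapped GRUCell produces zero
# output, so the ResidualCell should return its inputs unchanged.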
sub test_residual
{
    my $cell = gluon->rnn->ResidualCell(gluon->rnn->GRUCell(50, prefix=>'rnn_'));
    my $inputs = [map { mx->sym->Variable("rnn_t${_}_data") } 0..1];
    my ($outputs) = $cell->unroll(2, $inputs);
    $outputs = mx->sym->Group($outputs);
    is_deeply([sort $cell->collect_params()->keys()], ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_i2h_bias', 'rnn_i2h_weight']);
    my (undef, $outs) = $outputs->infer_shape(rnn_t0_data=>[10,50], rnn_t1_data=>[10,50]);
    is_deeply($outs, [[10, 50], [10, 50]]);
    $outputs = $outputs->eval(args => { rnn_t0_data=>mx->nd->ones([10, 50]),
                           rnn_t1_data=>mx->nd->ones([10, 50]),
                           rnn_i2h_weight=>mx->nd->zeros([150, 50]),
                           rnn_i2h_bias=>mx->nd->zeros([150]),
                           rnn_h2h_weight=>mx->nd->zeros([150, 50]),
                           rnn_h2h_bias=>mx->nd->zeros([150]) });
    my $expected_outputs = mx->nd->ones([10, 50]);
    ok(($outputs->[0] == $expected_outputs)->aspdl->all);
    ok(($outputs->[1] == $expected_outputs)->aspdl->all);
}

test_residual();

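# Same residual identity check as test_residual, but wrapping a
# BidirectionalCell built from two 25-unit GRUCells, whose concatenated
# output matches the 50-dimensional input.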
sub test_residual_bidirectional
{
    my $cell = gluon->rnn->ResidualCell(
        gluon->rnn->BidirectionalCell(
            gluon->rnn->GRUCell(25, prefix=>'rnn_l_'),
            gluon->rnn->GRUCell(25, prefix=>'rnn_r_')
        )
    );
    my $inputs = [map { mx->sym->Variable("rnn_t${_}_data") } 0..1];
    my ($outputs) = $cell->unroll(2, $inputs, merge_outputs => 0);
    $outputs = mx->sym->Group($outputs);
    is_deeply([sort $cell->collect_params()->keys()],
                ['rnn_l_h2h_bias', 'rnn_l_h2h_weight', 'rnn_l_i2h_bias', 'rnn_l_i2h_weight',
                'rnn_r_h2h_bias', 'rnn_r_h2h_weight', 'rnn_r_i2h_bias', 'rnn_r_i2h_weight']);
    my (undef, $outs) = $outputs->infer_shape(rnn_t0_data=>[10,50], rnn_t1_data=>[10,50]);
    is_deeply($outs, [[10, 50], [10, 50]]);
    $outputs = $outputs->eval(args => { rnn_t0_data=>mx->nd->ones([10, 50])+5,
                           rnn_t1_data=>mx->nd->ones([10, 50])+5,
                           rnn_l_i2h_weight=>mx->nd->zeros([75, 50]),
                           rnn_l_i2h_bias=>mx->nd->zeros([75]),
                           rnn_l_h2h_weight=>mx->nd->zeros([75, 25]),
                           rnn_l_h2h_bias=>mx->nd->zeros([75]),
                           rnn_r_i2h_weight=>mx->nd->zeros([75, 50]),
                           rnn_r_i2h_bias=>mx->nd->zeros([75]),
                           rnn_r_h2h_weight=>mx->nd->zeros([75, 25]),
                           rnn_r_h2h_bias=>mx->nd->zeros([75]),
    });
    my $expected_outputs = mx->nd->ones([10, 50])+5;
    ok(($outputs->[0] == $expected_outputs)->aspdl->all);
    ok(($outputs->[1] == $expected_outputs)->aspdl->all);
}

test_residual_bidirectional();

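# Stacks five LSTMCells in a SequentialRNNCell (the second one wrapped in a
# ResidualCell) and checks parameter names, output names, and output shapes.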
sub test_stack
{
    my $cell = gluon->rnn->SequentialRNNCell();
    for my $i (0..4)
    {
        if($i == 1)
        {
            $cell->add(gluon->rnn->ResidualCell(gluon->rnn->LSTMCell(100, prefix=>"rnn_stack${i}_")));
        }
        else
        {
            $cell->add(gluon->rnn->LSTMCell(100, prefix=>"rnn_stack${i}_"));
        }
    }
    my $inputs = [map { mx->sym->Variable("rnn_t${_}_data") } 0..2];
    my ($outputs) = $cell->unroll(3, $inputs);
    $outputs = mx->sym->Group($outputs);
    my %keys = map { $_ => 1 } $cell->collect_params()->keys();
    for my $i (0..4)
    {
        ok($keys{"rnn_stack${i}_h2h_weight"});
        ok($keys{"rnn_stack${i}_h2h_bias"});
        ok($keys{"rnn_stack${i}_i2h_weight"});
        ok($keys{"rnn_stack${i}_i2h_bias"});
    }
    is_deeply($outputs->list_outputs(), ['rnn_stack4_t0_out_output', 'rnn_stack4_t1_out_output', 'rnn_stack4_t2_out_output']);
    my (undef, $outs) = $outputs->infer_shape(rnn_t0_data=>[10,50], rnn_t1_data=>[10,50], rnn_t2_data=>[10,50]);
    is_deeply($outs, [[10, 100], [10, 100], [10, 100]]);
}

test_stack();

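# A BidirectionalCell over two 100-unit LSTMCells should concatenate the
# forward and backward outputs, doubling the output dimension to 200.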
sub test_bidirectional
{
    my $cell = gluon->rnn->BidirectionalCell(
            gluon->rnn->LSTMCell(100, prefix=>'rnn_l0_'),
            gluon->rnn->LSTMCell(100, prefix=>'rnn_r0_'),
            output_prefix=>'rnn_bi_');
    my $inputs = [map { mx->sym->Variable("rnn_t${_}_data") } 0..2];
    my ($outputs) = $cell->unroll(3, $inputs);
    $outputs = mx->sym->Group($outputs);
    is_deeply($outputs->list_outputs(), ['rnn_bi_t0_output', 'rnn_bi_t1_output', 'rnn_bi_t2_output']);
    my (undef, $outs) = $outputs->infer_shape(rnn_t0_data=>[10,50], rnn_t1_data=>[10,50], rnn_t2_data=>[10,50]);
    is_deeply($outs, [[10, 200], [10, 200], [10, 200]]);
}

test_bidirectional();

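# ZoneoutCell should not change the output shapes of the wrapped RNNCell.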
sub test_zoneout
{
    my $cell = gluon->rnn->ZoneoutCell(gluon->rnn->RNNCell(100, prefix=>'rnn_'), zoneout_outputs=>0.5,
                              zoneout_states=>0.5);
    my $inputs = [map { mx->sym->Variable("rnn_t${_}_data") } 0..2];
    my ($outputs) = $cell->unroll(3, $inputs);
    $outputs = mx->sym->Group($outputs);
    my (undef, $outs) = $outputs->infer_shape(rnn_t0_data=>[10,50], rnn_t1_data=>[10,50], rnn_t2_data=>[10,50]);
    is_deeply($outs, [[10, 100], [10, 100], [10, 100]]);
}

test_zoneout();

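# Runs forward and backward through a cell in imperative mode, hybridizes it,
# runs again, and (for deterministic cells) checks that the outputs and input
# gradients match across the two runs.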
sub check_rnn_forward
{
    my ($layer, $inputs, $deterministic) = @_;
    $deterministic //= 1;
    $inputs->attach_grad();
    $layer->collect_params()->initialize();
    my $out;
    mx->autograd->record(sub {
        $out = ($layer->unroll(3, $inputs, merge_outputs=>0))[0];
        mx->autograd->backward($out);
        $out = ($layer->unroll(3, $inputs, merge_outputs=>1))[0];
        $out->backward;
    });

    my $pdl_out = $out->aspdl;
    my $pdl_dx = $inputs->grad->aspdl;

    $layer->hybridize;

    mx->autograd->record(sub {
        $out = ($layer->unroll(3, $inputs, merge_outputs=>0))[0];
        mx->autograd->backward($out);
        $out = ($layer->unroll(3, $inputs, merge_outputs=>1))[0];
        $out->backward;
    });

    if($deterministic)
    {
        ok(almost_equal($pdl_out, $out->aspdl, 1e-3));
        ok(almost_equal($pdl_dx, $inputs->grad->aspdl, 1e-3));
    }
}

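# Forward/backward and hybridization checks for the individual recurrent
# cells and common wrappers. DropoutCell and ZoneoutCell are stochastic, so
# their results are not compared across runs.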
sub test_rnn_cells
{
    check_rnn_forward(gluon->rnn->LSTMCell(100, input_size=>200), mx->nd->ones([8, 3, 200]));
    check_rnn_forward(gluon->rnn->RNNCell(100, input_size=>200), mx->nd->ones([8, 3, 200]));
    check_rnn_forward(gluon->rnn->GRUCell(100, input_size=>200), mx->nd->ones([8, 3, 200]));
    my $bilayer = gluon->rnn->BidirectionalCell(
        gluon->rnn->LSTMCell(100, input_size=>200),
        gluon->rnn->LSTMCell(100, input_size=>200)
    );
    check_rnn_forward($bilayer, mx->nd->ones([8, 3, 200]));
    check_rnn_forward(gluon->rnn->DropoutCell(0.5), mx->nd->ones([8, 3, 200]), 0);
    check_rnn_forward(
        gluon->rnn->ZoneoutCell(
            gluon->rnn->LSTMCell(100, input_size=>200),
            0.5, 0.2
        ),
        mx->nd->ones([8, 3, 200]),
        0
    );
    my $net = gluon->rnn->SequentialRNNCell();
    $net->add(gluon->rnn->LSTMCell(100, input_size=>200));
    $net->add(gluon->rnn->RNNCell(100, input_size=>100));
    $net->add(gluon->rnn->GRUCell(100, input_size=>100));
    check_rnn_forward($net, mx->nd->ones([8, 3, 200]));
}

test_rnn_cells();

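# Same forward/backward/hybridize consistency check as check_rnn_forward, but
# calling a fused recurrent layer directly, optionally with explicit initial
# states, instead of unrolling a cell.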
sub check_rnn_layer_forward
{
    my ($layer, $inputs, $states) = @_;
    $layer->collect_params()->initialize();
    $inputs->attach_grad;
    my $out;
    mx->autograd->record(sub {
        if(defined $states)
        {
            $out = $layer->($inputs, $states);
            ok(@$out == 2);
            $out = $out->[0];
        }
        else
        {
            $out = $layer->($inputs);
            ok(blessed $out and $out->isa('AI::MXNet::NDArray'));
        }
        $out->backward();
    });

    my $pdl_out = $out->aspdl;
    my $pdl_dx = $inputs->grad->aspdl;
    $layer->hybridize;
    mx->autograd->record(sub {
        if(defined $states)
        {
            ($out, $states) = $layer->($inputs, $states);
            ok(blessed $out and $out->isa('AI::MXNet::NDArray'));
        }
        else
        {
            $out = $layer->($inputs);
            ok(blessed $out and $out->isa('AI::MXNet::NDArray'));
        }
        $out->backward();
    });
    ok(almost_equal($pdl_out, $out->aspdl, 1e-3));
    ok(almost_equal($pdl_dx, $inputs->grad->aspdl, 1e-3));
}

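# Checks the fused RNN, LSTM and GRU layers, both unidirectional and
# bidirectional; the bidirectional variants are given explicit initial states.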
sub test_rnn_layers
{
    check_rnn_layer_forward(gluon->rnn->RNN(10, 2), mx->nd->ones([8, 3, 20]));
    check_rnn_layer_forward(gluon->rnn->RNN(10, 2, bidirectional=>1), mx->nd->ones([8, 3, 20]), mx->nd->ones([4, 3, 10]));
    check_rnn_layer_forward(gluon->rnn->LSTM(10, 2), mx->nd->ones([8, 3, 20]));
    check_rnn_layer_forward(gluon->rnn->LSTM(10, 2, bidirectional=>1), mx->nd->ones([8, 3, 20]), [mx->nd->ones([4, 3, 10]), mx->nd->ones([4, 3, 10])]);
    check_rnn_layer_forward(gluon->rnn->GRU(10, 2), mx->nd->ones([8, 3, 20]));
    check_rnn_layer_forward(gluon->rnn->GRU(10, 2, bidirectional=>1), mx->nd->ones([8, 3, 20]), mx->nd->ones([4, 3, 10]));
}

test_rnn_layers();