# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

use strict;
use warnings;
use AI::MXNet qw(mx);
use AI::MXNet::TestUtils qw(almost_equal same);
# A fixed plan (rather than done_testing) also catches a helper that stops
# early: 3 reshape + 2 moveaxis + 3 output + 2 cached + 7 slice = 17.
use Test::More tests => 17;

# reshape() with a -1 entry infers that dimension from the element count
# (8 elements here), as the expected shapes below demonstrate.
sub test_ndarray_reshape {
    my $tensor = mx->nd->array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);

    # Full flatten: the values 1..8 in row-major order.
    my $true_res = mx->nd->arange(stop => 8) + 1;
    is_deeply($tensor->reshape([-1])->aspdl->unpdl, $true_res->aspdl->unpdl,
        'reshape([-1]) flattens to 1-D');

    $true_res = mx->nd->array([[1, 2, 3, 4], [5, 6, 7, 8]]);
    is_deeply($tensor->reshape([2, -1])->aspdl->unpdl, $true_res->aspdl->unpdl,
        'reshape([2, -1]) infers the trailing dimension');

    $true_res = mx->nd->array([[1, 2], [3, 4], [5, 6], [7, 8]]);
    is_deeply($tensor->reshape([-1, 2])->aspdl->unpdl, $true_res->aspdl->unpdl,
        'reshape([-1, 2]) infers the leading dimension');
}

# moveaxis($src, $dst) relocates one axis and shifts the others to make room.
sub test_moveaxis {
    my $x   = mx->nd->array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]);
    my $res = $x->moveaxis(0, 2)->aspdl;
    my $true_res = mx->nd->array([
        [[ 1.,  7.], [ 2.,  8.], [ 3.,  9.]],
        [[ 4., 10.], [ 5., 11.], [ 6., 12.]],
    ]);
    is_deeply($res->unpdl, $true_res->aspdl->unpdl,
        'moveaxis(0, 2) moves the leading axis to the end');
    is_deeply($x->moveaxis(2, 0)->shape, [3, 2, 2],
        'moveaxis(2, 0) yields shape [3, 2, 2]');
}

# The creation routines ones/zeros/full must write into a caller-supplied
# buffer when given out => $ndarray.
sub test_output {
    my $shape = [2, 2];
    my $ones  = mx->nd->ones($shape);
    my $zeros = mx->nd->zeros($shape);
    my $out   = mx->nd->zeros($shape);

    mx->nd->ones($shape, out => $out);
    ok(almost_equal($out->aspdl, $ones->aspdl), 'ones() fills the out buffer');

    mx->nd->zeros($shape, out => $out);
    ok(almost_equal($out->aspdl, $zeros->aspdl), 'zeros() fills the out buffer');

    mx->nd->full($shape, 2, out => $out);
    ok(almost_equal($out->aspdl, $ones->aspdl * 2), 'full() fills the out buffer');
}

# CachedOp wraps a symbol (Convolution + 2) as a callable. This checks
# forward results, out => buffers, and that autograd can run through it.
sub test_cached {
    my $sym = mx->sym->Convolution(kernel => [3, 3], num_filter => 10) + 2;
    my $op  = mx->nd->CachedOp($sym);

    my $data   = mx->nd->ones([3, 4, 10, 10]);
    my $weight = mx->nd->ones([10, 4, 3, 3]);
    my $bias   = mx->nd->ones([10]);

    my $o1 = $op->($data, $weight, $bias);

    # Raise the bias in place from 1 to 2; every output element should go up by 1.
    $bias .= 2;
    my $o2 = $op->($data, $weight, $bias);
    ok(almost_equal($o2->aspdl, $o1->aspdl + 1),
        'CachedOp sees an in-place bias update');

    # Same computation, written into a zeroed pre-existing buffer via out =>.
    $o2 .= 0;
    $op->($data, $weight, $bias, out => $o2);
    ok(almost_equal($o2->aspdl, $o1->aspdl + 1),
        'CachedOp writes into the out buffer');

    # The remaining sections are smoke tests with no assertions. They only
    # check that recording and back-propagating through the op does not die.
    $weight->attach_grad();
    $bias->attach_grad();
    my $o;
    mx->autograd->record(sub {
        $bias = $bias + 1;
        $o = $op->($data, $weight, $bias);
        $o = $o * 2;
        $o->backward();
    });
    mx->autograd->record(sub {
        $bias = $bias + 1;
        $o = $op->($data, $weight, $bias);
        $o = $o * 2;
        # retain_graph on the first pass so the second backward() can reuse it.
        $o->backward(retain_graph => 1);
        $o->backward();
    });

    # try a different shape
    $data   = mx->nd->ones([5, 2, 10, 10]);
    $weight = mx->nd->ones([10, 2, 3, 3]);
    $bias   = mx->nd->ones([10]);
    $data->attach_grad;
    mx->autograd->record(sub {
        $bias = $bias + 1;
        $o = $op->($data, $weight, $bias);
        $o = $o * 2;
        $o->backward();
    });
}

# NDArray slicing, compared against the PDL view from aspdl(). PDL lists
# dimensions in the opposite order to MXNet, so PDL slice specs and at()
# coordinates are written reversed.
sub test_ndarray_slice {
    my $shape = [10];
    my $A  = mx->random->uniform(-10, 10, $shape);
    my $A2 = $A->aspdl;
    ok(same($A->slice([3, 7])->aspdl, $A2->slice([3, 7])),
        '1-D slice matches the PDL slice');

    # Scale the PDL copy, write it back through an NDArray slice, and compare again.
    $A2->slice([3, 7]) *= 10;
    $A->slice([3, 7]) .= $A2->slice([3, 7]);
    ok(same($A->slice([3, 7])->aspdl, $A2->slice([3, 7])),
        'assignment into a 1-D slice');

    $shape = [3, 4, 5, 6, 7];
    $A  = mx->nd->random->uniform(shape => $shape);
    $A2 = $A->aspdl;
    ok(same($A->slice([1], [3, 3], 'X', [1, 4], 'X')->aspdl,
            $A2->slice('X', [1, 4], 'X', [3, 3], [1])),
        '5-D mixed slice matches the reversed PDL slice');
    ok(($A->slice([1], [3, 3], 'X', [1, 4], 'X')
            == mx->nd->array($A2->slice('X', [1, 4], 'X', [3, 3], [1])))->aspdl->all,
        '5-D mixed slice compares equal element-wise as NDArrays');
    cmp_ok($A->slice(1, 2, 3, 4, 5)->asscalar(), '==', $A2->at(5, 4, 3, 2, 1),
        'scalar slice matches PDL at() with reversed coordinates');

    # Advanced indexing: a pair of row/column index lists picks the elements
    # at (1,0), (1,1), (0,0), which are 2, 3, 0. $grid avoids shadowing sort's $a.
    my $grid = mx->nd->array([[0, 1], [2, 3]]);
    ok(($grid->slice([[1, 1, 0], [0, 1, 0]])->aspdl
            == mx->nd->array([2, 3, 0])->aspdl)->all,
        'advanced indexing with plain Perl index lists');
    ok(($grid->slice([mx->nd->array([1, 1, 0]), mx->nd->array([0, 1, 0])])->aspdl
            == mx->nd->array([2, 3, 0])->aspdl)->all,
        'advanced indexing with NDArray index lists');
}

test_ndarray_slice();
test_ndarray_reshape();
test_moveaxis();
test_output();
test_cached();