NAME

AI::MXNet::Module::Bucketing

SYNOPSIS

my $buckets = [10, 20, 30, 40, 50, 60];
my $start_label   = 1;
my $invalid_label = 0;

my ($train_sentences, $vocabulary) = tokenize_text(
    './data/ptb.train.txt', start_label => $start_label,
    invalid_label => $invalid_label
);
my ($validation_sentences) = tokenize_text(
    './data/ptb.test.txt', vocab => $vocabulary,
    start_label => $start_label, invalid_label => $invalid_label
);
my $data_train  = mx->rnn->BucketSentenceIter(
    $train_sentences, $batch_size, buckets => $buckets,
    invalid_label => $invalid_label
);
my $data_val    = mx->rnn->BucketSentenceIter(
    $validation_sentences, $batch_size, buckets => $buckets,
    invalid_label => $invalid_label
);

my $stack = mx->rnn->SequentialRNNCell();
for my $i (0..$num_layers-1)
{
    $stack->add(mx->rnn->LSTMCell(num_hidden => $num_hidden, prefix => "lstm_l${i}_"));
}

my $sym_gen = sub {
    my $seq_len = shift;
    my $data  = mx->sym->Variable('data');
    my $label = mx->sym->Variable('softmax_label');
    my $embed = mx->sym->Embedding(
        data => $data, input_dim => scalar(keys %$vocabulary),
        output_dim => $num_embed, name => 'embed'
    );
    $stack->reset;
    my ($outputs, $states) = $stack->unroll($seq_len, inputs => $embed, merge_outputs => 1);
    my $pred = mx->sym->Reshape($outputs, shape => [-1, $num_hidden]);
    $pred    = mx->sym->FullyConnected(data => $pred, num_hidden => scalar(keys %$vocabulary), name => 'pred');
    $label   = mx->sym->Reshape($label, shape => [-1]);
    $pred    = mx->sym->SoftmaxOutput(data => $pred, label => $label, name => 'softmax');
    return ($pred, ['data'], ['softmax_label']);
};

my $contexts;
if(defined $gpus)
{
    $contexts = [map { mx->gpu($_) } split(/,/, $gpus)];
}
else
{
    $contexts = mx->cpu(0);
}

my $model = mx->mod->BucketingModule(
    sym_gen             => $sym_gen,
    default_bucket_key  => $data_train->default_bucket_key,
    context             => $contexts
);

$model->fit(
    $data_train,
    eval_data           => $data_val,
    eval_metric         => mx->metric->Perplexity($invalid_label),
    kvstore             => $kv_store,
    optimizer           => $optimizer,
    optimizer_params    => {
                                learning_rate => $lr,
                                momentum      => $mom,
                                wd            => $wd,
                        },
    initializer         => mx->init->Xavier(factor_type => "in", magnitude => 2.34),
    num_epoch           => $num_epoch,
    batch_end_callback  => mx->callback->Speedometer($batch_size, $disp_batches),
    ($chkp_epoch ? (epoch_end_callback  => mx->rnn->do_rnn_checkpoint($stack, $chkp_prefix, $chkp_epoch)) : ())
);

DESCRIPTION

Implements the AI::MXNet::Module::Base API and allows a different
symbol to be used for each mini-batch of data, selected by the `bucket_key`
that the mini-batch provides.

new

Parameters
----------
$sym_gen : subref or any perl object that overloads &{} op
    A sub that, when called with a bucket key, returns a list of three
    elements: ($symbol, $data_names, $label_names).
$default_bucket_key : str or anything else
    The key for the default bucket.
$logger : Logger
$context : AI::MXNet::Context or array ref of AI::MXNet::Context objects
    Default is cpu(0)
$work_load_list : array ref of Num
    Default is undef, indicating uniform workload.
$fixed_param_names : arrayref of str
    Default is undef, indicating no network parameters are fixed.
$state_names : arrayref of str
    States are similar to data and labels, but are not provided by the data
    iterator. Instead, they are initialized to 0 and can be set by set_states().

bind

Binding an AI::MXNet::Module::Bucketing means setting up the buckets and binding the
executor for the default bucket key. Executors corresponding to other keys are
bound later, via switch_bucket.

Parameters
----------
:$data_shapes : ArrayRef[AI::MXNet::DataDesc|NameShape]
    This should correspond to the symbol for the default bucket.
:$label_shapes : Maybe[ArrayRef[AI::MXNet::DataDesc|NameShape]]
    This should correspond to the symbol for the default bucket.
:$for_training : Bool
    Default is 1.
:$inputs_need_grad : Bool
    Default is 0.
:$force_rebind : Bool
    Default is 0.
:$shared_module : AI::MXNet::Module::Bucketing
    Default is undef. This value is currently not used.
:$grad_req : str, array ref of str, hash ref of str to str
    Requirement for gradient accumulation. Can be 'write', 'add', or 'null'
    (defaults to 'write').
    Can be specified globally (str) or for each argument (array ref, hash ref).
:$bucket_key : str
    Bucket key to bind with. Defaults to ->default_bucket_key.

switch_bucket

Switch to a different bucket. This will change $self->_curr_module.

Parameters
----------
:$bucket_key : str (or any perl object that overloads "" op)
    The key of the target bucket.
:$data_shapes : Maybe[ArrayRef[AI::MXNet::DataDesc|NameShape]]
    Typically $data_batch->provide_data.
:$label_shapes : Maybe[ArrayRef[AI::MXNet::DataDesc|NameShape]]
    Typically $data_batch->provide_label.