NAME
AI::MXNet::Module::Bucketing
SYNOPSIS
use AI::MXNet qw(mx);

# tokenize_text() and the hyperparameters used below ($batch_size, $num_layers,
# $num_hidden, $num_embed, $gpus, $kv_store, $optimizer, $lr, $mom, $wd,
# $num_epoch, $disp_batches, $chkp_prefix, $chkp_epoch) are assumed to be
# defined elsewhere in the calling script.
my $buckets = [10, 20, 30, 40, 50, 60];
my $start_label   = 1;
my $invalid_label = 0;

# Tokenize the training and validation text, sharing one vocabulary.
my ($train_sentences, $vocabulary) = tokenize_text(
    './data/ptb.train.txt', start_label => $start_label,
    invalid_label => $invalid_label
);
my ($validation_sentences) = tokenize_text(
    './data/ptb.test.txt', vocab => $vocabulary,
    start_label => $start_label, invalid_label => $invalid_label
);

# Bucketed iterators: each batch carries the bucket key of its sequence length.
my $data_train = mx->rnn->BucketSentenceIter(
    $train_sentences, $batch_size, buckets => $buckets,
    invalid_label => $invalid_label
);
my $data_val = mx->rnn->BucketSentenceIter(
    $validation_sentences, $batch_size, buckets => $buckets,
    invalid_label => $invalid_label
);

# Stack of LSTM cells shared by all buckets.
my $stack = mx->rnn->SequentialRNNCell();
for my $i (0..$num_layers-1)
{
    $stack->add(mx->rnn->LSTMCell(num_hidden => $num_hidden, prefix => "lstm_l${i}_"));
}

# Symbol generator: called with a bucket key (here, the sequence length).
my $sym_gen = sub {
    my $seq_len = shift;
    my $data  = mx->sym->Variable('data');
    my $label = mx->sym->Variable('softmax_label');
    my $embed = mx->sym->Embedding(
        data => $data, input_dim => scalar(keys %$vocabulary),
        output_dim => $num_embed, name => 'embed'
    );
    $stack->reset;
    my ($outputs, $states) = $stack->unroll($seq_len, inputs => $embed, merge_outputs => 1);
    my $pred = mx->sym->Reshape($outputs, shape => [-1, $num_hidden]);
    $pred  = mx->sym->FullyConnected(data => $pred, num_hidden => scalar(keys %$vocabulary), name => 'pred');
    $label = mx->sym->Reshape($label, shape => [-1]);
    $pred  = mx->sym->SoftmaxOutput(data => $pred, label => $label, name => 'softmax');
    return ($pred, ['data'], ['softmax_label']);
};

# Use the listed GPUs if any were given, otherwise fall back to the CPU.
my $contexts;
if(defined $gpus)
{
    $contexts = [map { mx->gpu($_) } split(/,/, $gpus)];
}
else
{
    $contexts = mx->cpu(0);
}

my $model = mx->mod->BucketingModule(
    sym_gen            => $sym_gen,
    default_bucket_key => $data_train->default_bucket_key,
    context            => $contexts
);

$model->fit(
    $data_train,
    eval_data        => $data_val,
    eval_metric      => mx->metric->Perplexity($invalid_label),
    kvstore          => $kv_store,
    optimizer        => $optimizer,
    optimizer_params => {
        learning_rate => $lr,
        momentum      => $mom,
        wd            => $wd,
    },
    initializer        => mx->init->Xavier(factor_type => "in", magnitude => 2.34),
    num_epoch          => $num_epoch,
    batch_end_callback => mx->callback->Speedometer($batch_size, $disp_batches),
    ($chkp_epoch ? (epoch_end_callback => mx->rnn->do_rnn_checkpoint($stack, $chkp_prefix, $chkp_epoch)) : ())
);
DESCRIPTION
Implements the AI::MXNet::Module::Base API and allows a different symbol to be used for each mini-batch of data, selected by the `bucket_key` that the mini-batch provides.
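To make the role of a bucket key concrete, here is a minimal pure-Perl sketch of how a variable-length sentence might be mapped to a bucket and padded to that bucket's length. The helpers bucket_for_length and pad_to_bucket are hypothetical and are not part of AI::MXNet. The sketch assumes the common bucketing convention that a sequence goes into the smallest bucket able to hold it, which may differ in detail from what a particular data iterator does.

```perl
use strict;
use warnings;

# Hypothetical helper (not part of AI::MXNet): return the smallest bucket
# that can hold a sequence of $len tokens, or nothing if no bucket is large
# enough.
sub bucket_for_length {
    my ($len, $buckets) = @_;
    for my $bucket (sort { $a <=> $b } @$buckets) {
        return $bucket if $len <= $bucket;
    }
    return;
}

# Hypothetical helper: pad a sentence (array ref of token ids) up to the
# bucket length, filling the tail with $invalid_label.
sub pad_to_bucket {
    my ($sentence, $bucket, $invalid_label) = @_;
    return [@$sentence, ($invalid_label) x ($bucket - @$sentence)];
}

my $buckets       = [10, 20, 30, 40, 50, 60];
my $invalid_label = 0;
my @sentence      = (5, 17, 42);    # three token ids

my $key    = bucket_for_length(scalar @sentence, $buckets);
my $padded = pad_to_bucket(\@sentence, $key, $invalid_label);

print "bucket_key=$key padded_length=", scalar(@$padded), "\n";
```

In this sketch the bucket key is simply the padded sequence length, which is also what the $sym_gen in the SYNOPSIS receives as $seq_len.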
new
Parameters
----------
$sym_gen : subref or any perl object that overloads &{} op
    A sub that, when called with a bucket key, returns a list of three
    elements: ($symbol, $data_names, $label_names).
$default_bucket_key : str or anything else
    The key for the default bucket.
$logger : Logger
$context : AI::MXNet::Context or array ref of AI::MXNet::Context objects
    Default is cpu(0).
$work_load_list : array ref of Num
    Default is undef, indicating uniform workload.
$fixed_param_names: arrayref of str
    Default is undef, indicating no network parameters are fixed.
$state_names : arrayref of str
    States are similar to data and label, but are not provided by the data
    iterator. Instead they are initialized to 0 and can be set by set_states().
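The $sym_gen contract above accepts either a plain subref or an object that overloads the &{} operator. The following pure-Perl sketch illustrates the second form. My::SymGen is a hypothetical class written for this example, and the hash it returns in place of a symbol is only a stand-in for a real AI::MXNet symbol.

```perl
use strict;
use warnings;

# Hypothetical class, for illustration only: an object that can be called
# like a code reference because it overloads the &{} operator.
package My::SymGen;

use overload '&{}' => sub {
    my $self = shift;
    # Return a closure that forwards the call to the generate() method.
    return sub { $self->generate(@_) };
};

sub new {
    my ($class, %args) = @_;
    return bless { num_hidden => $args{num_hidden} }, $class;
}

# Stand-in for real symbol construction: returns the triple
# ($symbol, $data_names, $label_names) for the given bucket key.
sub generate {
    my ($self, $bucket_key) = @_;
    my $symbol = { seq_len => $bucket_key, num_hidden => $self->{num_hidden} };
    return ($symbol, ['data'], ['softmax_label']);
}

package main;

my $sym_gen = My::SymGen->new(num_hidden => 200);

# The object is invoked exactly like a subref.
my ($symbol, $data_names, $label_names) = $sym_gen->(30);

print "seq_len=$symbol->{seq_len} data=@$data_names label=@$label_names\n";
```

An object-based generator can be convenient when the generator needs configuration or shared state (such as an RNN cell stack) that would otherwise be captured in a closure.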
bind
Binding an AI::MXNet::Module::Bucketing means setting up the buckets and binding the executor for the default bucket key. Executors corresponding to other keys are bound afterwards with switch_bucket.
Parameters
----------
:$data_shapes : ArrayRef[AI::MXNet::DataDesc|NameShape]
    This should correspond to the symbol for the default bucket.
:$label_shapes= : Maybe[ArrayRef[AI::MXNet::DataDesc|NameShape]]
    This should correspond to the symbol for the default bucket.
:$for_training : Bool
    Default is 1.
:$inputs_need_grad : Bool
    Default is 0.
:$force_rebind : Bool
    Default is 0.
:$shared_module : AI::MXNet::Module::Bucketing
    Default is undef. This value is currently not used.
:$grad_req : str, array ref of str, hash ref of str to str
    Requirement for gradient accumulation. Can be 'write', 'add', or 'null'
    (defaults to 'write').
    Can be specified globally (str) or for each argument (array ref, hash ref).
$bucket_key : str
    Bucket key for binding. Defaults to ->default_bucket_key.
switch_bucket
Switch to a different bucket. This will change $self->_curr_module.
Parameters
----------
:$bucket_key : str (or any perl object that overloads "" op)
    The key of the target bucket.
:$data_shapes : Maybe[ArrayRef[AI::MXNet::DataDesc|NameShape]]
    Typically $data_batch->provide_data.
:$label_shapes : Maybe[ArrayRef[AI::MXNet::DataDesc|NameShape]]
    Typically $data_batch->provide_label.
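The relationship between binding the default bucket and switching to other buckets can be sketched as a lazily filled cache keyed by bucket key. The class My::BucketCache and its methods bind_default and switch_bucket are hypothetical and written in pure Perl. They mimic only the caching behaviour described above and are not the AI::MXNet implementation, which binds real executors and shares parameters between them.

```perl
use strict;
use warnings;

# Hypothetical stand-in for a bucketing module; not the AI::MXNet code.
package My::BucketCache;

sub new {
    my ($class, %args) = @_;
    return bless {
        sym_gen            => $args{sym_gen},
        default_bucket_key => $args{default_bucket_key},
        buckets            => {},       # bucket key => built entry
        curr_bucket_key    => undef,    # key of the active entry
        build_count        => 0,        # how many entries were built
    }, $class;
}

# Analogue of bind(): build the entry for the default bucket key first.
sub bind_default {
    my ($self) = @_;
    $self->switch_bucket($self->{default_bucket_key});
    return $self;
}

# Analogue of switch_bucket(): build the entry on first use, then reuse it.
sub switch_bucket {
    my ($self, $bucket_key) = @_;
    if (not exists $self->{buckets}{$bucket_key}) {
        my ($symbol, $data_names, $label_names) = $self->{sym_gen}->($bucket_key);
        $self->{buckets}{$bucket_key} = {
            symbol      => $symbol,
            data_names  => $data_names,
            label_names => $label_names,
        };
        $self->{build_count}++;
    }
    $self->{curr_bucket_key} = $bucket_key;
    return $self->{buckets}{$bucket_key};
}

package main;

my $cache = My::BucketCache->new(
    default_bucket_key => 60,
    sym_gen            => sub {
        my $seq_len = shift;
        # A string stands in for the symbol that a real generator would build.
        return ("symbol_for_$seq_len", ['data'], ['softmax_label']);
    },
);

$cache->bind_default;
$cache->switch_bucket($_) for (20, 60, 20);

print "current=$cache->{curr_bucket_key} built=$cache->{build_count}\n";
```

In this sketch only two entries are built for the four switches (one for key 60, one for key 20), which is the point of caching per bucket key: repeated batches with the same key reuse the existing entry.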