NAME
AI::MXNet::Module::Bucketing
SYNOPSIS
my $buckets       = [10, 20, 30, 40, 50, 60];
my $start_label   = 1;
my $invalid_label = 0;
my ($train_sentences, $vocabulary) = tokenize_text(
    './data/sherlockholmes.train.txt',
    start_label   => $start_label,
    invalid_label => $invalid_label
);
my ($validation_sentences) = tokenize_text(
    './data/sherlockholmes.test.txt',
    vocab         => $vocabulary,
    start_label   => $start_label,
    invalid_label => $invalid_label
);
my $data_train = mx->rnn->BucketSentenceIter(
    $train_sentences, $batch_size,
    buckets       => $buckets,
    invalid_label => $invalid_label
);
my $data_val = mx->rnn->BucketSentenceIter(
    $validation_sentences, $batch_size,
    buckets       => $buckets,
    invalid_label => $invalid_label
);
my $stack = mx->rnn->SequentialRNNCell();
for my $i (0 .. $num_layers - 1)
{
    $stack->add(mx->rnn->LSTMCell(
        num_hidden => $num_hidden,
        prefix     => "lstm_l${i}_"
    ));
}
my $sym_gen = sub {
    my $seq_len = shift;
    my $data    = mx->sym->Variable('data');
    my $label   = mx->sym->Variable('softmax_label');
    my $embed   = mx->sym->Embedding(
        data       => $data,
        input_dim  => scalar(keys %$vocabulary),
        output_dim => $num_embed,
        name       => 'embed'
    );
    $stack->reset;
    my ($outputs, $states) = $stack->unroll(
        $seq_len,
        inputs        => $embed,
        merge_outputs => 1
    );
    my $pred = mx->sym->Reshape($outputs, shape => [-1, $num_hidden]);
    $pred = mx->sym->FullyConnected(
        data       => $pred,
        num_hidden => scalar(keys %$vocabulary),
        name       => 'pred'
    );
    $label = mx->sym->Reshape($label, shape => [-1]);
    $pred  = mx->sym->SoftmaxOutput(
        data  => $pred,
        label => $label,
        name  => 'softmax'
    );
    return ($pred, ['data'], ['softmax_label']);
};
my $contexts;
if (defined $gpus)
{
    $contexts = [map { mx->gpu($_) } split(/,/, $gpus)];
}
else
{
    $contexts = mx->cpu(0);
}
my $model = mx->mod->BucketingModule(
    sym_gen            => $sym_gen,
    default_bucket_key => $data_train->default_bucket_key,
    context            => $contexts
);
$model->fit(
    $data_train,
    eval_data        => $data_val,
    eval_metric      => mx->metric->Perplexity($invalid_label),
    kvstore          => $kv_store,
    optimizer        => $optimizer,
    optimizer_params => {
        learning_rate => $lr,
        momentum      => $mom,
        wd            => $wd,
    },
    initializer        => mx->init->Xavier(factor_type => "in", magnitude => 2.34),
    num_epoch          => $num_epoch,
    batch_end_callback => mx->callback->Speedometer($batch_size, $disp_batches),
    ($chkp_epoch
        ? (epoch_end_callback => mx->rnn->do_rnn_checkpoint($stack, $chkp_prefix, $chkp_epoch))
        : ())
);
DESCRIPTION
Implements the AI::MXNet::Module::Base API, and allows multiple
symbols to be used depending on the `bucket_key` provided by each
different mini-batch of data.
new
Parameters
----------
$sym_gen : subref or any perl object that overloads &{} op
    A sub that, when called with a bucket key, returns a list
    with a triple of ($symbol, $data_names, $label_names).
$default_bucket_key : str or anything else
    The key for the default bucket.
$logger : Logger
$context : AI::MXNet::Context or array ref of AI::MXNet::Context objects
    Default is cpu(0).
$work_load_list : array ref of Num
    Default is undef, indicating uniform workload.
$fixed_param_names : arrayref of str
    Default is undef, indicating no network parameters are fixed.
$state_names : arrayref of str
    States are similar to data and label, but not provided by the data iterator.
    Instead they are initialized to 0 and can be set by set_states().
bind
Binding for a AI::MXNet::Module::Bucketing means setting up the buckets and
binding the executor for the default bucket key. Executors corresponding to
other keys are bound afterwards with switch_bucket.
Parameters
----------
:$data_shapes : ArrayRef[AI::MXNet::DataDesc|NameShape]
    This should correspond to the symbol for the default bucket.
:$label_shapes : Maybe[ArrayRef[AI::MXNet::DataDesc|NameShape]]
    This should correspond to the symbol for the default bucket.
:$for_training : Bool
    Default is 1.
:$inputs_need_grad : Bool
    Default is 0.
:$force_rebind : Bool
    Default is 0.
:$shared_module : AI::MXNet::Module::Bucketing
    Default is undef. This value is currently not used.
:$grad_req : str, array ref of str, hash ref of str to str
    Requirement for gradient accumulation. Can be 'write', 'add', or 'null'
    (defaults to 'write').
    Can be specified globally (str) or for each argument (array ref, hash ref).
:$bucket_key : str
switch_bucket
Switch to a different bucket. This will change $self->_curr_module.
Parameters
----------
:$bucket_key : str (or any perl object that overloads "" op)
    The key of the target bucket.
:$data_shapes : Maybe[ArrayRef[AI::MXNet::DataDesc|NameShape]]
    Typically $data_batch->provide_data.
:$label_shapes : Maybe[ArrayRef[AI::MXNet::DataDesc|NameShape]]
    Typically $data_batch->provide_label.
save_checkpoint
Save current progress to a checkpoint.
Use mx->callback->module_checkpoint as epoch_end_callback to save during training.
Parameters
----------
prefix : str
    The file prefix to checkpoint to.
epoch : int
    The current epoch number.
save_optimizer_states : bool
    Whether to save optimizer states for later training.