NAME
AI::MXNet::Gluon::Trainer
DESCRIPTION
Applies an `Optimizer` to a set of Parameters. Trainer should
be used together with `autograd`.
Parameters
----------
params : AI::MXNet::Gluon::ParameterDict
The set of parameters to optimize.
optimizer : str or Optimizer
The optimizer to use. See the Optimizer help at
https://mxnet.io/api/python/optimization/optimization.html#the-mxnet-optimizer-package
for a list of available optimizers.
optimizer_params : hash ref
Keyword arguments to be passed to the optimizer constructor. For example,
{learning_rate => 0.1}. All optimizers accept learning_rate, wd (weight decay),
clip_gradient, and lr_scheduler. See each optimizer's
constructor for a list of additional supported arguments.
kvstore : str or KVStore
KVStore type for multi-GPU and distributed training. See the help on
mx->kvstore->create for more information.
compression_params : hash ref
Specifies the type of gradient compression and any additional arguments required
by that type. For example, 2bit compression requires a threshold, so the
arguments would be {type => '2bit', threshold => 0.5}.
See the AI::MXNet::KVStore->set_gradient_compression method for more details on gradient compression.
update_on_kvstore : Bool, default undef
Whether to perform parameter updates on the kvstore. If undef, the trainer chooses
the more suitable option depending on the type of kvstore.
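To give some intuition for the `threshold` in `{type => '2bit', threshold => 0.5}`, the pure-Perl sketch below mimics the idea behind 2-bit compression as we understand it from the MXNet documentation. After the residual carried over from earlier steps is added, each gradient element is quantized to `+threshold`, `-threshold`, or 0, and the quantization error becomes the new residual. The function `compress_2bit` is hypothetical and not part of AI::MXNet. The real kvstore works on packed NDArrays, not Perl arrays.

```perl
use strict;
use warnings;

# Hypothetical helper (not an AI::MXNet function): 2-bit quantization
# with error feedback. $grad and $residual are array refs of equal length;
# $residual is updated in place.
sub compress_2bit {
    my ($grad, $residual, $threshold) = @_;
    my @quantized;
    for my $i (0 .. $#$grad) {
        # Add the new gradient to what was left over from earlier steps.
        my $v = $residual->[$i] + $grad->[$i];
        my $q = $v >=  $threshold ?  $threshold
              : $v <= -$threshold ? -$threshold
              :                     0;
        # The part that was not transmitted stays in the residual.
        $residual->[$i] = $v - $q;
        push @quantized, $q;
    }
    return \@quantized;
}

my @residual = (0, 0, 0);
my $q1 = compress_2bit([0.7, 0.3, -0.6], \@residual, 0.5);  # [0.5, 0, -0.5]
my $q2 = compress_2bit([0.1, 0.3, 0.0],  \@residual, 0.5);  # [0, 0.5, 0]
print "@$q1 | @$q2\n";
```

In the second call, the middle element is transmitted even though its new gradient (0.3) is below the threshold. This happens because the 0.3 left over from the first call pushes the accumulated value past 0.5.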
Properties
----------
learning_rate : float
The current learning rate of the optimizer. Given an Optimizer object
$optimizer, its learning rate can be accessed as $optimizer->learning_rate.
step
Makes one step of parameter update. Should be called after
`autograd->backward()` and outside of the `record()` scope.
For normal parameter updates, `step()` should be used, which internally calls
`allreduce_grads()` and then `update()`. However, if you need the reduced
gradients in order to transform them, for example for gradient clipping, you can
call `allreduce_grads()` and `update()` separately instead.
Parameters
----------
$batch_size : Int
Batch size of data processed. The gradient will be normalized by `1/$batch_size`.
Set this to 1 if you normalized the loss manually with `$loss = mean($loss)`.
$ignore_stale_grad : Bool, optional, default=0
If true, ignores Parameters with a stale gradient (a gradient that has not
been updated by `backward` since the last step) and skips their update.
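The pure-Perl sketch below illustrates what `1/$batch_size` normalization means, using a plain SGD update on a single scalar weight. It assumes the gradient buffer holds the sum of per-sample gradients, as produced by backpropagating a summed loss. It also leaves out momentum, weight decay, and clipping. `sgd_step` is a hypothetical name and does not reflect the Trainer's internals.

```perl
use strict;
use warnings;

# Hypothetical SGD step: rescale the accumulated gradient by 1/$batch_size,
# then move the weight against it.
sub sgd_step {
    my ($weight, $grad, $lr, $batch_size) = @_;
    my $rescaled = $grad / $batch_size;
    return $weight - $lr * $rescaled;
}

# Per-sample gradients for a batch of 4 samples.
my @per_sample = (0.2, 0.4, 0.6, 0.8);
my $sum = 0;
$sum += $_ for @per_sample;        # gradient of a summed loss
my $mean = $sum / @per_sample;     # gradient of $loss = mean($loss)

# A summed loss with $batch_size = 4 ...
my $w_a = sgd_step(1.0, $sum, 0.1, scalar @per_sample);
# ... gives the same update as an averaged loss with $batch_size = 1.
my $w_b = sgd_step(1.0, $mean, 0.1, 1);
printf "%.4f %.4f\n", $w_a, $w_b;  # prints 0.9500 0.9500
```

Passing the real batch size after averaging the loss yourself would divide by it twice. That is why the documentation asks for 1 in that case.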
allreduce_grads
For each parameter, reduces the gradients from different contexts.
Should be called after `autograd->backward()`, outside of the `record()` scope,
and before `$trainer->update()`.
For normal parameter updates, `step()` should be used, which internally calls
`allreduce_grads()` and then `update()`. However, if you need the reduced
gradients in order to transform them, for example for gradient clipping, you can
call `allreduce_grads()` and `update()` separately instead.
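The pure-Perl sketch below shows the manual sequence of reduce, transform, and update. It assumes two devices, each holding a partial gradient for the same parameter, and reduces them by element-wise summation. It then clips by global L2 norm, one common form of gradient clipping chosen here as an example, and only then applies a plain SGD update. The helpers `allreduce`, `clip_by_global_norm`, and `apply_update` are hypothetical stand-ins for the Trainer's internals and your own clipping code.

```perl
use strict;
use warnings;
use List::Util qw(sum0);

# Element-wise sum of the gradient copies held by each device.
sub allreduce {
    my @copies  = @_;
    my @reduced = (0) x scalar @{ $copies[0] };
    for my $copy (@copies) {
        $reduced[$_] += $copy->[$_] for 0 .. $#reduced;
    }
    return \@reduced;
}

# Scale the whole gradient down if its L2 norm exceeds $max_norm.
sub clip_by_global_norm {
    my ($grad, $max_norm) = @_;
    my $norm = sqrt(sum0(map { $_ ** 2 } @$grad));
    return $grad if $norm <= $max_norm;
    my $scale = $max_norm / $norm;
    return [ map { $_ * $scale } @$grad ];
}

# Plain SGD update with 1/$batch_size normalization.
sub apply_update {
    my ($weight, $grad, $lr, $batch_size) = @_;
    return [ map { $weight->[$_] - $lr * $grad->[$_] / $batch_size } 0 .. $#$weight ];
}

my $reduced = allreduce([1.0, 2.0], [2.0, 2.0]);   # [3, 4], L2 norm 5
my $clipped = clip_by_global_norm($reduced, 1.0);  # about [0.6, 0.8]
my $weight  = apply_update([0.0, 0.0], $clipped, 1.0, 1);
print "@$weight\n";
```

The clipping step must see the fully reduced gradient. If it ran on each device's partial copy, the norm would be computed on the wrong values. This is the reason for the order: reduce first, then clip, then update.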
set_learning_rate
Sets a new learning rate for the optimizer.
Parameters
----------
lr : float
The new learning rate of the optimizer.
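A common use of `set_learning_rate` is a hand-written decay schedule that runs between epochs. The pure-Perl sketch below uses a hypothetical `ToyTrainer` class that exposes only the two members documented here, the `learning_rate` property and `set_learning_rate($lr)`, so that it runs without AI::MXNet. The "halve every 2 epochs" rule is an arbitrary example schedule.

```perl
use strict;
use warnings;

# Hypothetical stand-in (not AI::MXNet) with the same two members as the
# Trainer: a learning_rate reader and set_learning_rate($lr).
package ToyTrainer;
sub new               { my ($class, %args) = @_; return bless { lr => $args{learning_rate} }, $class }
sub learning_rate     { $_[0]{lr} }
sub set_learning_rate { my ($self, $lr) = @_; $self->{lr} = $lr; return }

package main;

my $trainer = ToyTrainer->new(learning_rate => 0.1);
my @history;
for my $epoch (0 .. 5) {
    # Step decay: halve the learning rate every 2 epochs.
    if ($epoch > 0 && $epoch % 2 == 0) {
        $trainer->set_learning_rate($trainer->learning_rate * 0.5);
    }
    push @history, $trainer->learning_rate;
    # ... one epoch of forward/backward/step would run here ...
}
print "@history\n";   # 0.1 0.1 0.05 0.05 0.025 0.025
```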
update
Makes one step of parameter update.
Should be called after `autograd->backward()`, outside of the `record()` scope,
and after `$trainer->allreduce_grads()`.
For normal parameter updates, `step()` should be used, which internally calls
`allreduce_grads()` and then `update()`. However, if you need the reduced
gradients in order to transform them, for example for gradient clipping, you can
call `allreduce_grads()` and `update()` separately instead.
Parameters
----------
$batch_size : Int
Batch size of data processed. The gradient will be normalized by `1/$batch_size`.
Set this to 1 if you normalized the loss manually with `$loss = mean($loss)`.
$ignore_stale_grad : Bool, optional, default=0
If true, ignores Parameters with a stale gradient (a gradient that has not
been updated by `backward()` since the last step) and skips their update.
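The pure-Perl sketch below illustrates how `$ignore_stale_grad` can be read. Each parameter carries a flag recording whether `backward()` wrote its gradient since the last update. A stale parameter either aborts the update or is skipped, depending on the flag. The record layout and the `update_params` helper are hypothetical. They model the documented behavior and are not AI::MXNet's code.

```perl
use strict;
use warnings;

# Hypothetical parameter records: 'fresh' marks whether backward() wrote
# to the gradient since the last update.
my %params = (
    w => { value => 1.0, grad => 0.5, fresh => 1 },
    b => { value => 0.0, grad => 0.3, fresh => 0 },   # stale: unused this batch
);

sub update_params {
    my ($params, $lr, $batch_size, $ignore_stale_grad) = @_;
    my @updated;
    for my $name (sort keys %$params) {
        my $p = $params->{$name};
        unless ($p->{fresh}) {
            die "Gradient of '$name' is stale\n" unless $ignore_stale_grad;
            next;                 # skip only this parameter
        }
        $p->{value} -= $lr * $p->{grad} / $batch_size;
        $p->{fresh} = 0;          # consumed; the next backward() must refresh it
        push @updated, $name;
    }
    return \@updated;
}

# With $ignore_stale_grad false, the stale 'b' aborts the call.
my $ok = eval { update_params(\%params, 0.1, 1, 0); 1 };
print defined $ok ? "updated\n" : "refused: $@";

# With $ignore_stale_grad true, 'b' is skipped and only 'w' moves.
my $updated = update_params(\%params, 0.1, 1, 1);
print "updated: @$updated, w = $params{w}{value}\n";
```

Skipping is useful when parts of a network are not used on every batch. Failing loudly by default catches parameters that were accidentally left out of the computation graph.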
save_states
Saves trainer states (e.g. optimizer, momentum) to a file.
Parameters
----------
fname : str
Path to the output states file.
load_states
Loads trainer states (e.g. optimizer, momentum) from a file.
Parameters
----------
fname : str
Path to the input states file.
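Conceptually, saving trainer states means serializing the optimizer's per-parameter state, such as momentum buffers, so that training can resume where it left off. The sketch below mimics that round trip with Perl's core Storable module. It is an analogy only. The `%states` layout is invented for illustration, and the file that `save_states` actually writes uses whatever format AI::MXNet's updater produces, which may differ.

```perl
use strict;
use warnings;
use Storable qw(store retrieve);
use File::Temp qw(tempdir);
use File::Spec;

# Hypothetical optimizer state: momentum buffers keyed by parameter name,
# plus an update counter of the kind learning-rate schedulers rely on.
my %states = (
    num_update => 120,
    momentum   => { w => [0.01, -0.02], b => [0.005] },
);

my $dir   = tempdir(CLEANUP => 1);
my $fname = File::Spec->catfile($dir, 'trainer.states');

store(\%states, $fname) or die "cannot save states to $fname\n";   # like save_states
my $restored = retrieve($fname);                                  # like load_states
print "num_update after reload: $restored->{num_update}\n";
```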