MXNet Gluon Fit API
In this tutorial, you will learn how to use the Gluon Fit API, a high-level interface that lets you train deep learning models with the Gluon API in Apache MXNet using very little code.
With the Fit API, you can train a deep learning model with a minimal amount of code: just specify the network, the loss function and the data you want to train on. You don’t need to write the boilerplate code that loops through the dataset in batches (often called the ‘training loop’). Advanced users who currently write bespoke training loops will find that the Fit API also covers many of those use cases.
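For context, the boilerplate that the Fit API takes off your hands looks roughly like the loop below. This is a framework-free sketch in plain Python, not MXNet code. It fits a one-parameter toy model y = w * x with mini-batch gradient descent. The helper names `make_batches` and `train_loop` are hypothetical, and the data, learning rate and batch size are illustrative assumptions.

```python
import random


def make_batches(data, batch_size):
    """Split a list of (x, y) pairs into consecutive mini-batches."""
    return [data[i:i + batch_size] for i in range(0, len(data), batch_size)]


def train_loop(data, epochs, batch_size, lr):
    """Hand-written loop: epochs -> mini-batches -> forward, loss, gradient, update."""
    data = list(data)  # copy so shuffling does not modify the caller's list
    w = 0.0  # the toy model's only parameter: y_hat = w * x
    losses = []  # mean batch loss per epoch
    for _ in range(epochs):
        random.shuffle(data)  # reshuffle every epoch, as a data loader would
        batch_losses = []
        for batch in make_batches(data, batch_size):
            preds = [w * x for x, _ in batch]  # forward pass
            errors = [p - y for p, (_, y) in zip(preds, batch)]
            batch_losses.append(sum(e * e for e in errors) / len(batch))  # MSE
            # gradient of the batch MSE with respect to w
            grad = sum(2 * e * x for e, (x, _) in zip(errors, batch)) / len(batch)
            w -= lr * grad  # plain SGD parameter update
        losses.append(sum(batch_losses) / len(batch_losses))
    return w, losses


random.seed(0)
data = [(i / 10, 3 * i / 10) for i in range(1, 21)]  # noise-free samples of y = 3x
w, losses = train_loop(data, epochs=20, batch_size=4, lr=0.05)
print(f"learned w = {w:.4f}, final epoch loss = {losses[-1]:.2e}")
```

Every line inside `train_loop` is the kind of bookkeeping that the Fit API is meant to replace with a single call.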
To demonstrate the Fit API, you will train an image classification model using the ResNet-18 neural network architecture. The model will be trained using the Fashion-MNIST dataset.
Basic Usage
import mxnet as mx
Advanced Usage
The Fit API is also customizable through event handlers. These give you fine-grained control over training by exposing callback methods for each of its stages. The available callback methods are: train_begin, train_end, batch_begin, batch_end, epoch_begin and epoch_end.
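The order in which these callbacks fire can be pictured with a small framework-free toy. `ToyEstimator` and `RecordingHandler` are hypothetical stand-ins, not MXNet classes. The sketch assumes only that the six callbacks nest the way their names suggest: training wraps the epochs, and each epoch wraps its batches.

```python
class RecordingHandler:
    """Hypothetical handler that records the name of every callback it receives."""

    def __init__(self):
        self.events = []

    def train_begin(self, estimator):
        self.events.append("train_begin")

    def epoch_begin(self, estimator):
        self.events.append("epoch_begin")

    def batch_begin(self, estimator):
        self.events.append("batch_begin")

    def batch_end(self, estimator):
        self.events.append("batch_end")

    def epoch_end(self, estimator):
        self.events.append("epoch_end")

    def train_end(self, estimator):
        self.events.append("train_end")


class ToyEstimator:
    """Hypothetical estimator that owns the training loop and notifies handlers."""

    def __init__(self, handlers):
        self.handlers = handlers

    def _notify(self, stage):
        # call the callback named `stage` on every handler, in the order given
        for handler in self.handlers:
            getattr(handler, stage)(self)

    def fit(self, batches, epochs):
        self._notify("train_begin")
        for _ in range(epochs):
            self._notify("epoch_begin")
            for _batch in batches:
                self._notify("batch_begin")
                # forward pass, loss, backward pass and parameter update go here
                self._notify("batch_end")
            self._notify("epoch_end")
        self._notify("train_end")


recorder = RecordingHandler()
ToyEstimator([recorder]).fit(batches=["batch0", "batch1"], epochs=2)
print(recorder.events[:4])  # ['train_begin', 'epoch_begin', 'batch_begin', 'batch_end']
```

With 2 epochs of 2 batches, the handler sees 14 callbacks in total: one train_begin, then 4 batch_begin/batch_end pairs grouped into 2 epochs, and a final train_end.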
You can use built-in event handlers such as LoggingHandler, CheckpointHandler or EarlyStoppingHandler for three purposes: to log training progress, to save the model at chosen points during training, or to stop training once the model’s performance plateaus.
Some utility handlers are also added to your estimator by default. StoppingHandler controls when training ends, based on the number of epochs or batches trained. MetricHandler computes the training metrics at the end of each batch and epoch. ValidationHandler evaluates your model on the validation data at the end of each epoch and computes the validation metrics. If you create one of these utility handlers yourself with a different configuration and pass it to the estimator, it overrides the default configuration.
You can create a custom handler by inheriting from one or more of the base event handler classes: TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin and BatchEnd.
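The mixin idea can be sketched in plain Python. This sketch borrows MXNet's class names for readability, but every class here is a hypothetical toy, not the library's implementation. It rests on three assumptions:

- The estimator sorts handlers with `isinstance` checks against the base classes.
- A handler requests a stop by returning True from `batch_end` or `epoch_end`.
- A user-supplied StoppingHandler replaces the default one.

```python
class TrainBegin:
    """Toy mixin for handlers that want a callback when training starts."""

    def train_begin(self, estimator):
        pass


class BatchEnd:
    """Toy mixin: returning True from batch_end asks the estimator to stop."""

    def batch_end(self, estimator):
        return False


class EpochEnd:
    """Toy mixin: returning True from epoch_end asks the estimator to stop."""

    def epoch_end(self, estimator):
        return False


class StoppingHandler(BatchEnd, EpochEnd):
    """Toy utility handler: stop after max_epoch epochs or max_batch batches."""

    def __init__(self, max_epoch=None, max_batch=None):
        if max_epoch is None and max_batch is None:
            raise ValueError("set max_epoch, max_batch or both")
        self.max_epoch = max_epoch
        self.max_batch = max_batch

    def batch_end(self, estimator):
        if self.max_batch is None:
            return False
        return estimator.batches_seen >= self.max_batch

    def epoch_end(self, estimator):
        if self.max_epoch is None:
            return False
        return estimator.epochs_seen >= self.max_epoch


class StartNotifier(TrainBegin):
    """Hypothetical custom handler that only hooks into the start of training."""

    def __init__(self):
        self.messages = []

    def train_begin(self, estimator):
        self.messages.append("training started")


class ToyEstimator:
    """Hypothetical estimator that sorts handlers by the mixins they inherit."""

    def __init__(self, event_handlers=(), epochs=3):
        handlers = list(event_handlers)
        # add a default StoppingHandler only if the user did not supply one
        if not any(isinstance(h, StoppingHandler) for h in handlers):
            handlers.append(StoppingHandler(max_epoch=epochs))
        self.train_begin_handlers = [h for h in handlers if isinstance(h, TrainBegin)]
        self.batch_end_handlers = [h for h in handlers if isinstance(h, BatchEnd)]
        self.epoch_end_handlers = [h for h in handlers if isinstance(h, EpochEnd)]
        self.epochs_seen = 0
        self.batches_seen = 0

    def fit(self, batches_per_epoch):
        if batches_per_epoch < 1:
            raise ValueError("batches_per_epoch must be at least 1")
        for h in self.train_begin_handlers:
            h.train_begin(self)
        while True:  # ends when some handler requests a stop
            for _ in range(batches_per_epoch):
                self.batches_seen += 1
                # a list (not a generator) so that every handler gets called
                if any([h.batch_end(self) for h in self.batch_end_handlers]):
                    return
            self.epochs_seen += 1
            if any([h.epoch_end(self) for h in self.epoch_end_handlers]):
                return


notifier = StartNotifier()
est = ToyEstimator(event_handlers=[notifier], epochs=3)  # default stopping rule
est.fit(batches_per_epoch=5)
print(est.epochs_seen, est.batches_seen)  # 3 15

custom = ToyEstimator(event_handlers=[StoppingHandler(max_batch=7)])  # override
custom.fit(batches_per_epoch=5)
print(custom.epochs_seen, custom.batches_seen)  # 1 7
```

`StartNotifier` inherits only `TrainBegin`, so the toy estimator never calls it at batch or epoch boundaries. That selective dispatch is the point of composing handlers from small mixins.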
Handler
import mxnet as mx
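Stopping when performance plateaus and keeping the best model, two of the built-in behaviours described above, come down to a little bookkeeping at the end of each epoch. The plain-Python toy below is not MXNet's EarlyStoppingHandler or CheckpointHandler. Its class name and the `patience` and `min_delta` parameters are hypothetical, borrowed from common early-stopping conventions. It assumes that a higher monitored value (such as accuracy) is better, and it keeps an in-memory copy of the parameters in place of writing a checkpoint file.

```python
import copy


class EarlyStopWithBestCheckpoint:
    """Hypothetical toy handler: remember the best parameters seen so far and
    request a stop after `patience` epochs without improvement.
    Assumes a higher monitored value (e.g. accuracy) is better."""

    def __init__(self, patience=2, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_score = float("-inf")
        self.best_epoch = None
        self.best_params = None
        self.wait = 0  # epochs since the last improvement

    def epoch_end(self, epoch, score, params):
        """Return True to request that training stop."""
        if score > self.best_score + self.min_delta:
            self.best_score = score
            self.best_epoch = epoch
            self.best_params = copy.deepcopy(params)  # stand-in for saving to disk
            self.wait = 0
            return False
        self.wait += 1
        return self.wait >= self.patience


# Made-up validation accuracies for illustration: they improve, then plateau.
val_acc = [0.70, 0.78, 0.82, 0.81, 0.82, 0.80, 0.85]
params = {"w": 0.0}
handler = EarlyStopWithBestCheckpoint(patience=2)
stopped_at = None
for epoch, acc in enumerate(val_acc):
    params["w"] += 1.0  # pretend the parameters change every epoch
    if handler.epoch_end(epoch, acc, params):
        stopped_at = epoch
        break
print(handler.best_epoch, handler.best_score, stopped_at)  # 2 0.82 4
```

With a patience of 2, the toy stops at epoch 4 and never sees the later 0.85. This is the usual trade-off: too small a patience can end training just before a late improvement.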
Train
import mxnet as mx