Detecting hidden but non-trivial problems in transfer learning models using Amazon SageMaker Debugger
Rapid development of deep learning technology has produced an abundance of open-sourced, pre-trained models in computer vision and natural language processing. As a result, transfer learning has become a popular approach in deep learning. Transfer learning is a machine learning technique where a model pre-trained on one task is fine-tuned on a new task. Given the significant compute and time resources required to develop neural network models, adapting a pre-trained model to new data is compelling in business applications. If you’re new to deep learning, transfer learning is also a good starting point because you don’t have to build a model from scratch. One question beginners often have is: how do I systematically examine model predictions to find mistakes when my data consists of images or text?
In this post, we show you an end-to-end example of transfer learning in which we use Amazon SageMaker Debugger to detect hidden problems that may have serious consequences. Debugger doesn’t incur additional costs if you’re running training on Amazon SageMaker. Moreover, you can enable the built-in rules with just a few lines of code when you call the Amazon SageMaker estimator function. For our use case, we do transfer learning using a ResNet model to recognize German traffic signs [1].
In this post, we focus on issues that occur during training. For more information about using Debugger for inference and explainability, see Detecting and analyzing incorrect model predictions with Amazon SageMaker Model Monitor and Debugger.
Setting up a transfer learning training job
For our use case, we want to adapt a pre-trained computer vision model to recognize traffic signs in Germany. We use the GTSRB dataset [1] for this new task. You can find the notebook and training script in the GitHub repo.
Applying preprocessing on the dataset
We first apply some typical preprocessing for a ResNet model on our dataset (see the complete notebook for where to download the dataset). To improve model generalization, we apply data augmentation (RandomResizedCrop and RandomHorizontalFlip). These operations ensure that an image looks different in each epoch. Lastly, we normalize the data: because the model has been pre-trained on the ImageNet dataset, we apply the same preprocessing and normalization (subtract the mean and divide by the standard deviation of the ImageNet dataset). See the following code:
For the validation dataset, we don’t apply data augmentation and only resize images to the appropriate size:
Loading a pre-trained ResNet model
Because we have a limited variety of traffic signs, we pick a simpler ResNet model for this task: resnet18. You can load a ResNet18 from the PyTorch model zoo with pre-trained weights using just one line of code:
The model has been pre-trained on the ImageNet dataset, which consists of 1,000 image classes. For our use case, we fine-tune it on a dataset that only has 43 classes. We adjust the last layer, which is a fully connected Linear layer:
Because we train a multi-classification model, we use the cross entropy loss function:
The following code blocks define the training loop. We iterate over ten epochs, perform the forward and backward pass, and update the model parameters.
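A minimal version of such a loop is sketched below. The function name, the SGD optimizer, the learning rate, and the momentum value are assumptions for illustration; the post only states that cross entropy loss is used and that training runs for ten epochs.

```python
import torch
import torch.nn as nn


def train(model, train_loader, val_loader, epochs=10, lr=0.001, device="cpu"):
    """Fine-tune the model and return the validation accuracy of the last epoch."""
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    model.to(device)
    accuracy = 0.0
    for epoch in range(epochs):
        model.train()
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)  # forward pass
            loss.backward()                          # backward pass
            optimizer.step()                         # parameter update

        # Validation phase: no gradients, no parameter updates.
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                predictions = model(inputs).argmax(dim=1)
                correct += (predictions == labels).sum().item()
                total += labels.size(0)
        accuracy = correct / total if total else 0.0
        print(f"epoch {epoch}: validation accuracy {accuracy:.3f}")
    return accuracy
```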
If you just run the preceding code, the training runs only on your Amazon SageMaker notebook instance. To make the most of Amazon SageMaker, you want to use the pre-built AWS Deep Learning Containers (DLC), which come with optimized performance and let you access the full feature set of Debugger at no additional cost. By running on Amazon SageMaker, we can easily train our models at scale. Most deep learning models are trained on GPU due to the computational intensity. With Amazon SageMaker, GPU instances are automatically created and torn down after training completes, so you only pay for the time the resources were used.
Making a training script compatible with Amazon SageMaker
To run training on Amazon SageMaker, you need to change the location variable in your pre-processing code to the generic Amazon SageMaker environment variables. When Amazon SageMaker spins up the training instance, it automatically downloads the training and validation data from Amazon Simple Storage Service (Amazon S3) into a local folder on the training instance. We can retrieve the local path with os.environ['SM_CHANNEL_TRAIN'] and os.environ['SM_CHANNEL_TEST']:
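One way to write this so the same script also runs outside Amazon SageMaker is to fall back to local folders when the variables are not set. The helper name and the fallback paths below are hypothetical.

```python
import os


def data_dirs(default_train="data/train", default_test="data/test"):
    """Return the training and validation folders.

    On a SageMaker training instance the SM_CHANNEL_* variables point to the
    folders where the S3 data was downloaded; elsewhere we use local defaults.
    """
    train_dir = os.environ.get("SM_CHANNEL_TRAIN", default_train)
    test_dir = os.environ.get("SM_CHANNEL_TEST", default_test)
    return train_dir, test_dir
```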
After the change, you should save all the model code as a separate script called train.py.
Uploading data to the S3 bucket
As mentioned in the previous step, Amazon SageMaker automatically downloads training and validation data into the training instance. We need to upload the data to Amazon S3 first. You can find detailed instructions on how to do that in the notebook.
Setting up Debugger
Now that we have defined the training script and uploaded the data, we’re ready to start training on Amazon SageMaker. We run the training with several Debugger built-in rules enabled. Via the Amazon SageMaker Python SDK and the rule_configs module, we can select any of the 20 available built-in rules, which run at no additional cost. For demonstration purposes, we select loss_not_decreasing, class_imbalance and dead_relu. We can configure several parameters for these rules: for instance, most rules take a threshold parameter that can be adjusted to define when a rule should trigger. We can also define the set of tensors the rules should run on.
The class imbalance rule takes the inputs into the loss function and counts the number of samples per class that the model has seen throughout training. To create the rule, we specify rule_configs.class_imbalance() and the rule runs on the inputs of the loss function. To fine-tune the model, we use the cross entropy loss function, which takes predictions and labels and outputs a loss value. See the following code:
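A sketch of this configuration with the Amazon SageMaker Python SDK follows. The helper name is hypothetical, and the parameter values are assumptions consistent with the tensor names used in this post: labels are the second input of the loss function, predictions the first, and argmax converts the per-class scores into a predicted class.

```python
# Which recorded tensors hold labels and predictions (rule parameters are strings).
CLASS_IMBALANCE_PARAMS = {
    "labels_regex": "CrossEntropyLoss_input_1",
    "predictions_regex": "CrossEntropyLoss_input_0",
    "argmax": "True",
}


def make_class_imbalance_rule():
    # Imported lazily so the parameters above can be inspected without the SDK.
    from sagemaker.debugger import Rule, rule_configs

    return Rule.sagemaker(
        base_config=rule_configs.class_imbalance(),
        rule_parameters=dict(CLASS_IMBALANCE_PARAMS),
    )
```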
Next we define the loss_not_decreasing rule. It determines if the training or validation loss is decreasing and raises an issue if the loss has not decreased by a certain percentage in the last few iterations. In contrast to the previous rule, this rule runs on the outputs of the loss function (CrossEntropyLoss_output_0). See the following code:
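A possible configuration is sketched below. The helper name is hypothetical, and restricting the rule to the training phase through the mode parameter is an assumption; without it, the rule chooses the mode itself.

```python
# Run the rule on the loss output recorded during the training phase.
LOSS_NOT_DECREASING_PARAMS = {
    "tensor_regex": "CrossEntropyLoss_output_0",
    "mode": "TRAIN",
}


def make_loss_not_decreasing_rule():
    # Imported lazily so the parameters above can be inspected without the SDK.
    from sagemaker.debugger import Rule, rule_configs

    return Rule.sagemaker(
        base_config=rule_configs.loss_not_decreasing(),
        rule_parameters=dict(LOSS_NOT_DECREASING_PARAMS),
    )
```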
The dead_relu rule identifies how many rectified linear unit (ReLU) activations are outputting zero values. ReLU is a non-linear activation function used in many state-of-the-art models. It increases linearly for increasing positive values and outputs zero otherwise. A model can suffer from the dying ReLU problem, where the gradients become zero due to the activation output being zero. If the majority of ReLU activations output zero values, the model can’t effectively learn because weights are no longer getting updated. We instantiate the rule by specifying rule_configs.dead_relu(), and the rule runs on all tensors that captured outputs from ReLU activations:
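The sketch below does two things. The function dead_relu_fraction is a hypothetical illustration of the quantity the rule is concerned with, namely the share of units that output zero for a whole batch; it is not the rule's actual implementation. The second function instantiates the built-in rule with its default parameters.

```python
import numpy as np


def dead_relu_fraction(activations):
    """Fraction of units whose ReLU output is zero for every sample in the batch."""
    activations = np.asarray(activations)
    flat = activations.reshape(activations.shape[0], -1)  # (batch, units)
    dead_units = np.all(flat == 0, axis=0)
    return float(dead_units.mean())


def make_dead_relu_rule():
    # Imported lazily so the illustration above runs without the SDK.
    from sagemaker.debugger import Rule, rule_configs

    return Rule.sagemaker(base_config=rule_configs.dead_relu())
```

For example, a batch of two samples with activations [[0, 1, 0], [0, 2, 0]] has two of three units that never fire.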
To record additional tensors, we can specify a debugger hook configuration. We can either use default collections such as weights and gradients or define our own custom collection. The following collection saves model inputs and loss function inputs and outputs. We just need to specify a regular expression of tensor names. We save the tensors every 500 steps, and a step is one forward and backward pass. So we get tensors for step 0, 500, 1,000, and so on. See the following code:
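A hook configuration of this kind could be sketched as follows. The collection name and the exact regular expression are assumptions; the expression is chosen to match the ResNet_input and CrossEntropyLoss tensor names used elsewhere in this post.

```python
# Tensors to record, and how often (every 500 steps). Values are strings.
COLLECTION_PARAMS = {
    "include_regex": ".*ResNet_input|.*CrossEntropyLoss",
    "save_interval": "500",
}


def make_hook_config():
    # Imported lazily so the parameters above can be inspected without the SDK.
    from sagemaker.debugger import CollectionConfig, DebuggerHookConfig

    return DebuggerHookConfig(
        collection_configs=[
            CollectionConfig(name="custom_collection", parameters=dict(COLLECTION_PARAMS)),
        ]
    )
```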
For a full list of collections and rules Debugger offers, see List of Debugger Built-in Rules. Debugger captures the tensor collections you specified throughout the training steps and automatically analyzes them against the rules.
Calling the Amazon SageMaker training API
We define the PyTorch estimator that takes the separate training script we saved earlier and specify the instance type that Amazon SageMaker creates for us. To run the training with Debugger and built-in rules, we only have to pass the list of rules and the debugger hook configuration:
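The following sketch shows how the pieces might be wired together. The helper names are hypothetical, ml.p3.2xlarge is only one example of a GPU instance type, and the framework version is left as a required argument because it depends on which containers support Debugger. The argument names follow version 2 of the Amazon SageMaker Python SDK; version 1 used train_instance_type and train_instance_count.

```python
def estimator_kwargs(role, rules, debugger_hook_config, framework_version,
                     instance_type="ml.p3.2xlarge"):
    """Collect the estimator arguments in one place."""
    return {
        "entry_point": "train.py",          # the training script saved earlier
        "role": role,                        # IAM role used by the training job
        "instance_type": instance_type,
        "instance_count": 1,
        "framework_version": framework_version,
        "py_version": "py3",
        "hyperparameters": {"epochs": 10},
        "rules": rules,                      # built-in and custom Debugger rules
        "debugger_hook_config": debugger_hook_config,
    }


def make_estimator(**kwargs):
    # Imported lazily; creating the estimator needs the SDK and AWS credentials.
    from sagemaker.pytorch import PyTorch

    return PyTorch(**estimator_kwargs(**kwargs))
```

In a notebook, the role is typically obtained with sagemaker.get_execution_role().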
Now we start the training on Amazon SageMaker by calling fit(). The function takes a dictionary that specifies the location of the training and validation data in Amazon S3. The keys of the dictionary are the names of the data channels Amazon SageMaker creates in the training instance. See the following code:
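As a sketch, the channel dictionary can be built as follows. The helper name, bucket, and prefix are hypothetical. The keys train and test are what make the data available under SM_CHANNEL_TRAIN and SM_CHANNEL_TEST inside the training script.

```python
def build_channels(bucket, prefix):
    """Map channel names to the S3 locations of the uploaded data."""
    base = f"s3://{bucket}/{prefix.strip('/')}"
    return {"train": f"{base}/train", "test": f"{base}/test"}
```

With an estimator object at hand, training would then start with estimator.fit(build_channels("my-bucket", "gtsrb")).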
While the training is in progress, we can monitor the rule status in real time in Amazon SageMaker Studio. It turns out that the loss_not_decreasing and class_imbalance rules are triggered. The training runs for 10 epochs and reaches a final test accuracy of 96.3%.
This seems good, but why were the rules triggered? Let’s dive into the data Debugger captured to find out the root causes.
Using SageMaker Debugger rules and data to uncover hidden problems
In this section, we investigate the data to find any hidden problems, create custom rules to fix the model, and rerun the training.
Inspecting loss_not_decreasing
We use Debugger to investigate what triggered the loss_not_decreasing rule. We use the smdebug library, which provides all the functionalities to read and access Debugger data. First we create a trial object that takes the path where the Debugger data is stored as input. This can either be a local or Amazon S3 path. With just a few lines of code, we can retrieve and visualize the loss values as training is still in progress. With trial.steps(), we retrieve the recorded steps: a step is one forward and backward pass. We can also specify a mode to retrieve data from the training (modes.TRAIN) or validation phase (modes.EVAL). Debugger’s default sampling interval is 500, so we get loss values for step 0, 500, 1,000, and so on.
To access the loss values, we pass the name of the loss into the trial.tensor() function. The cross entropy loss function we picked measures the performance of a multi-classification model. It takes two inputs: the model outputs and ground truth labels. We can access its outputs via trial.tensor('CrossEntropyLoss_output_0').values():
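A small helper around this call might look as follows. The helper name is hypothetical; it accepts any trial object that offers tensor(name).values(mode=...), which is the smdebug interface described above.

```python
import numpy as np


def loss_by_step(trial, mode, tensor_name="CrossEntropyLoss_output_0"):
    """Return {step: loss} for all steps recorded in the given mode."""
    values = trial.tensor(tensor_name).values(mode=mode)
    # Each recorded value is a small array; reduce it to a plain float.
    return {step: float(np.asarray(value).mean()) for step, value in values.items()}
```

With smdebug installed, the trial comes from create_trial(path) in smdebug.trials, where path is the local or S3 location of the Debugger data, and the mode is modes.TRAIN or modes.EVAL from the smdebug package.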
This code returns a dictionary in which the keys are the step numbers and the values are the loss values. We can now easily visualize the loss values as training is still in progress. See the following code:
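A plotting helper for such a dictionary is sketched below, assuming matplotlib is available. The function names and the output file name are hypothetical.

```python
def sorted_series(loss_values):
    """Split a {step: loss} dictionary into step and loss lists ordered by step."""
    steps = sorted(loss_values)
    return steps, [loss_values[step] for step in steps]


def plot_loss(loss_values, path="loss.png"):
    # Imported lazily so sorted_series can be used without matplotlib.
    import matplotlib
    matplotlib.use("Agg")  # render to a file; no display needed
    import matplotlib.pyplot as plt

    steps, losses = sorted_series(loss_values)
    fig, ax = plt.subplots()
    ax.plot(steps, losses)
    ax.set_xlabel("step")
    ax.set_ylabel("cross entropy loss")
    fig.savefig(path)
    plt.close(fig)
```

Because the trial can be re-read while the job is running, calling plot_loss repeatedly gives an updated curve as new steps arrive.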
The blue curve in the following graph shows that the default training configuration ran the training for too long. Instead of training for 4,000 steps, early_stopping should have been applied after 1,000 steps. We can use Debugger to enable auto-termination, which stops the training when a rule triggers. For our use case, doing so reduces compute time by more than half (orange curve).
Debugger can auto-terminate training jobs. Metrics are sent to Amazon CloudWatch, so you can set up a CloudWatch alarm and AWS Lambda function that stops a training job if a rule triggers.
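A Lambda function for this purpose could be sketched as follows. The event shape is an assumption: we assume the function is invoked with a training job state-change event whose detail section carries TrainingJobName and DebugRuleEvaluationStatuses, mirroring the DescribeTrainingJob response. The optional client argument is a hypothetical addition that makes the handler easy to test.

```python
def rule_found_issue(event):
    """True if any Debugger rule in the event reports IssuesFound."""
    statuses = event.get("detail", {}).get("DebugRuleEvaluationStatuses") or []
    return any(s.get("RuleEvaluationStatus") == "IssuesFound" for s in statuses)


def lambda_handler(event, context, client=None):
    """Stop the training job named in the event if a rule found an issue."""
    if not rule_found_issue(event):
        return {"stopped": False}
    if client is None:
        import boto3  # available in the AWS Lambda Python runtime
        client = boto3.client("sagemaker")
    job_name = event["detail"]["TrainingJobName"]
    client.stop_training_job(TrainingJobName=job_name)
    return {"stopped": True, "job": job_name}
```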
For more information about how the auto-termination feature helped one customer reduce compute costs by 70%, see Autodesk optimizes visual similarity search model in Fusion 360 with Amazon SageMaker Debugger.
Inspecting class_imbalance
Real-world datasets are often imbalanced and noisy. If the model training doesn’t account for these factors, it produces a model that has low or no predictive power for the classes with few samples. You can address this in different ways, such as during data-loading, when more samples can be drawn from the under-represented classes, or you can adjust the loss function to assign a higher penalty to incorrect predictions using class weights.
To investigate the class imbalance issue, we retrieve the inputs of the loss function (previously we retrieved the outputs). The loss function takes the model predictions and the ground truth labels as inputs. We use the latter (CrossEntropyLoss_input_1) to count the number of samples the model has seen during training:
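The counting step can be sketched as follows. The helper name is hypothetical; the trial object and mode are the ones obtained from smdebug as described earlier.

```python
from collections import Counter

import numpy as np


def count_labels(trial, mode, tensor_name="CrossEntropyLoss_input_1"):
    """Count how often each class label appears in the recorded batches."""
    counts = Counter()
    batches = trial.tensor(tensor_name).values(mode=mode)
    for labels in batches.values():
        counts.update(np.asarray(labels).ravel().astype(int).tolist())
    return counts
```

Note that only the batches of recorded steps are counted, so with a save interval of 500 the result is a sample of what the model has seen rather than a full count.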
The following visualization shows a high imbalance and several classes with fewer than a hundred samples.
To fix the class imbalance issue, we change the default configuration of the dataloaders to take the class weights into account and draw more samples from classes with fewer samples. Therefore, we define WeightedRandomSampler:
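A minimal sketch of such a sampler follows. The helper name is hypothetical, and weighting each sample by the inverse of its class count is one common choice, assumed here.

```python
import torch
from torch.utils.data import WeightedRandomSampler


def make_sampler(labels):
    """Build a sampler that draws rare classes more often.

    labels holds the integer class label of every training sample.
    """
    labels = torch.as_tensor(labels, dtype=torch.long)
    class_counts = torch.bincount(labels)
    # Each sample is weighted by the inverse frequency of its class.
    sample_weights = 1.0 / class_counts[labels].float()
    return WeightedRandomSampler(
        sample_weights, num_samples=len(labels), replacement=True
    )
```

The sampler is passed to the DataLoader through its sampler argument, in which case shuffle must not be set.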
During training, the dataloader now draws more samples from classes with lower counts. Class imbalance may lead to the problem where the model performs well on classes with a lot of samples but poorly on classes with fewer counts. Because we trained the model without WeightedRandomSampler, let’s see which classes had particularly low accuracy by looking at the confusion matrix.
Visualizing the confusion matrix in real time
To evaluate the performance of our model, we retrieve labels and predictions and create the confusion matrix:
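Given label and prediction arrays, for example the values of CrossEntropyLoss_input_1 and the argmax of CrossEntropyLoss_input_0, the matrix can be computed with NumPy. This is a sketch with a hypothetical helper name, not the notebook's original code.

```python
import numpy as np


def confusion_matrix(labels, predictions, num_classes):
    """Rows are actual classes, columns are predicted classes."""
    labels = np.asarray(labels).astype(int)
    predictions = np.asarray(predictions).astype(int)
    matrix = np.zeros((num_classes, num_classes), dtype=int)
    # Add 1 at (actual, predicted) for every sample; handles repeated pairs.
    np.add.at(matrix, (labels, predictions), 1)
    return matrix
```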
Each row in the matrix corresponds to the actual class, and each column indicates the predicted class. For example, the first row shows class 0 and how often it was predicted as class 0, class 1, class 2, and so on. Ideally, we want high counts on the diagonal, because these are correctly predicted classes. Elements not on the diagonal are incorrect predictions. The confusion matrix helps us determine if particular classes in our dataset get confused more often with each other. This can happen, for instance, because samples from two different classes may be very similar. Debugger also provides a confusion built-in rule that computes the confusion matrix while the training is in progress and triggers if the ratio of on-diagonal to off-diagonal values exceeds a pre-defined threshold.
The following image shows that in most cases our model is predicting the correct classes, but there are a few outliers. You can use Debugger to look more closely into those outliers.
Inspecting incorrect model predictions
To find out what is causing those outliers in the confusion matrix, we investigate the examples upon which the model made false predictions. To do this analysis, we take both inputs into the loss function into account: CrossEntropyLoss_input_0 represents the model predictions, and CrossEntropyLoss_input_1 holds the labels. We also retrieve the model inputs ResNet_input_0, which represents the input images. We perform the analysis on data recorded during the validation phase, so we specify mode=modes.EVAL.
We iterate over the predictions and model inputs saved by Debugger and select those where the label and prediction do not match. Then we plot the predictions and corresponding images:
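The selection step for a single recorded batch can be sketched as follows. The helper name is hypothetical; logits stands for the values of CrossEntropyLoss_input_0 and labels for the values of CrossEntropyLoss_input_1 at one step.

```python
import numpy as np


def find_mispredictions(logits, labels):
    """Return (index, label, prediction) for every wrongly classified sample."""
    logits = np.asarray(logits)
    labels = np.asarray(labels).astype(int)
    predictions = np.argmax(logits, axis=1)
    wrong = np.flatnonzero(predictions != labels)
    return [(int(i), int(labels[i]), int(predictions[i])) for i in wrong]
```

The returned indices can be used to pick the matching images from ResNet_input_0 of the same step for plotting.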
The following images show the result of the code segment. The analysis reveals that the model is often confused by traffic signs that involve a direction. Clearly this is a severe model bug despite the model achieving a decent test accuracy of 96.3%.
The root cause of this is the data augmentation pipeline, which performs a random horizontal flip on the training data. This data augmentation step is typically used when ResNet models are trained from scratch, but it causes a problem in our use case where the dataset contains similar classes where images just differ in their direction.
Running training with a custom rule
With Debugger, we can easily write a custom rule that checks how often the model was confused about directions. For example, we take the image “Dangerous curve to left” (class 19) and count how often it was mistaken for “Dangerous curve to right” (class 20) or vice versa. We just need to implement the function invoke_at_step that Debugger calls every time data for a new step is available. Like before, we access the inputs into the loss function, check whether image class 19 or 20 is present, and count how often it was mistaken for the other. If this happens more than 10 times, the rule triggers. See the following code:
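A sketch of such a rule follows. The counting logic lives in a plain function so it can be checked on its own. The class name DirectionConfusionRule and the factory function are hypothetical; the factory only keeps the sketch importable without smdebug, and in a real rule file the class would be defined at module level. We also assume the tensors are read by global step.

```python
import numpy as np


def count_direction_confusions(predictions, labels, class_a=19, class_b=20):
    """Count samples of class_a predicted as class_b, or the other way round."""
    predictions = np.asarray(predictions)
    labels = np.asarray(labels)
    a_as_b = (labels == class_a) & (predictions == class_b)
    b_as_a = (labels == class_b) & (predictions == class_a)
    return int(np.sum(a_as_b | b_as_a))


def build_rule_class():
    from smdebug.rules.rule import Rule  # requires the smdebug library

    class DirectionConfusionRule(Rule):
        def __init__(self, base_trial, threshold=10):
            super().__init__(base_trial)
            self.threshold = int(threshold)
            self.counter = 0

        def invoke_at_step(self, step):
            # Inputs of the loss function: per-class scores and labels.
            scores = self.base_trial.tensor("CrossEntropyLoss_input_0").value(step)
            labels = self.base_trial.tensor("CrossEntropyLoss_input_1").value(step)
            predictions = np.argmax(scores, axis=1)
            self.counter += count_direction_confusions(predictions, labels)
            # Returning True signals that the rule condition is met.
            return self.counter > self.threshold

    return DirectionConfusionRule
```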
We can easily test and run the custom rule locally by creating the trial object and invoking the rule on the data:
Running the code cell in the notebook gives the following output:
The rule triggered at step 1812. After the rule has been tested locally, we can run it as part of our Amazon SageMaker training job. First we need to save the rule in a separate file and then define the following configuration where we indicate on which instance type the rule should run:
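Such a configuration could be sketched as follows. The rule name, the file path rules/custom_rule.py, the class name, and the instance type ml.t3.medium are hypothetical examples. The image URI of the Debugger rule evaluator container is region-specific and therefore left as a required argument.

```python
def custom_rule_kwargs(image_uri, source="rules/custom_rule.py",
                       rule_to_invoke="DirectionConfusionRule",
                       instance_type="ml.t3.medium"):
    """Collect the arguments for the custom rule configuration."""
    return {
        "name": "DirectionConfusion",
        "image_uri": image_uri,            # rule evaluator container for your Region
        "instance_type": instance_type,    # instance that runs the rule
        "volume_size_in_gb": 30,
        "source": source,                  # file that contains the rule class
        "rule_to_invoke": rule_to_invoke,  # name of the rule class in that file
    }


def make_custom_rule(image_uri, **overrides):
    # Imported lazily; requires the Amazon SageMaker Python SDK.
    from sagemaker.debugger import Rule

    return Rule.custom(**custom_rule_kwargs(image_uri, **overrides))
```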
After we define the configuration, we add the custom_rule to the list of rules in the estimator object.
Fixing the model and rerunning training
Now that Debugger has helped us identify some critical issues in our model, we apply the fixes and rerun the training. As mentioned before, weighted re-sampling allows us to fix the class imbalance problem. We also change the data augmentation pipeline and remove the horizontal flip. We reduce the number of epochs from 10 to 3, because we have seen that the loss doesn’t decrease after roughly 1,000 iterations.
With Debugger, we can now compare data from different training jobs and see if the issues persist or not. We just need to create a new trial object, read data from both trials, and compare their tensors:
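As a sketch, the per-class label counts of two trials can be put side by side like this. The helper names are hypothetical; both trial objects are assumed to come from smdebug and to contain the CrossEntropyLoss_input_1 tensor.

```python
from collections import Counter

import numpy as np


def label_counts(trial, mode, tensor_name="CrossEntropyLoss_input_1"):
    """Count the class labels in all batches recorded for one trial."""
    counts = Counter()
    for labels in trial.tensor(tensor_name).values(mode=mode).values():
        counts.update(np.asarray(labels).ravel().astype(int).tolist())
    return counts


def compare_label_counts(trial_a, trial_b, mode):
    """Return {class: (count in trial_a, count in trial_b)} for all classes seen."""
    a, b = label_counts(trial_a, mode), label_counts(trial_b, mode)
    return {cls: (a.get(cls, 0), b.get(cls, 0)) for cls in sorted(set(a) | set(b))}
```

The resulting dictionary can be plotted as a grouped bar chart to compare the original job with the re-sampled one.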
The following visualization shows the label counts for the original training job and the one where we applied weighted re-sampling (orange). We see that there is no longer a class imbalance issue and the model sees roughly the same number of instances per class.
We run the training with the same built-in rules as before and add our own custom rule. We monitor the status of the rules and can now see that none of them trigger.
Summary
In this post, we have shown an end-to-end example of how to use Amazon SageMaker Debugger to automatically find, inspect, and fix issues in deep neural network training.
As the state-of-the-art models grow in size and complexity, finding issues early in the prototyping phase is critical to save time and costs. Model bugs may not always be obvious and as we have shown in this post, a suboptimal model may still achieve an overall good accuracy.
In our use case, we not only found critical bugs, but also reduced training time by a factor of two and improved model performance.
There is no extra cost for running Debugger built-in rules, and you benefit by having them enabled because you may discover non-obvious model issues. If you want to learn more about Debugger, check out the following:
- Amazon SageMaker Debugger
- Amazon SageMaker Debugger GitHub repo
- Amazon SageMaker Debugger – Debug Your Machine Learning Models
- ML Explainability with Amazon SageMaker Debugger
- Autodesk optimizes visual similarity search model in Fusion 360 with Amazon SageMaker Debugger
- Detecting and analyzing incorrect model predictions with Amazon SageMaker Model Monitor and Debugger
References
[1] Johannes Stallkamp, Marc Schlipsing, Jan Salmen, Christian Igel, The German traffic sign recognition benchmark: A multi-class classification competition, The 2011 International Joint Conference on Neural Networks, 2011
About the Authors
Nathalie Rauschmayr is an Applied Scientist at AWS, where she helps customers develop deep learning applications.
Lu Huang is a Senior Product Manager on the AWS Deep Engine team, managing SageMaker Debugger.
Satadal Bhattacharjee is Principal Product Manager at AWS AI. He leads the machine learning engine PM team on projects such as SageMaker and optimizes machine learning frameworks such as TensorFlow, PyTorch, and MXNet.