Accelerate deep learning model training up to 35% with Amazon SageMaker smart sifting
In today’s rapidly evolving landscape of artificial intelligence, deep learning models have found themselves at the forefront of innovation, with applications spanning computer vision (CV), natural language processing (NLP), and recommendation systems. However, the increasing cost associated with training and fine-tuning these models poses a challenge for enterprises. This cost is primarily driven by the sheer volume of data used in training deep learning models. Today, large models are often trained on terabytes of data and can take weeks to train, even with powerful GPU or AWS Trainium-based hardware. Typically, customers rely on techniques and optimizations that improve the efficiency of a model’s training loop, such as optimized kernels or layers, mixed precision training, or features such as the Amazon SageMaker distributed training libraries. However, there is less focus today on the efficiency of the training data itself. Not all data contributes equally to the learning process during model training: a significant proportion of the computational resources may be spent on processing simple examples that don’t contribute substantially to the model’s overall accuracy.
Customers have traditionally relied on preprocessing techniques such as upsampling or downsampling and deduplication to refine and improve the information quality of their data. These techniques can help, but are often time consuming, require specialized data science experience, and can sometimes be more art than science. Customers often also rely on curated datasets, such as RefinedWeb, to improve the performance of their models; however, these datasets aren’t always fully open source and are often more general purpose and not related to your specific use case.
How else can you overcome this inefficiency related to low-information data samples during model training?
We’re excited to announce a public preview of smart sifting, a new capability of SageMaker that can reduce the cost of training deep learning models by up to 35%. Smart sifting is a new data efficiency technique that actively analyzes your data samples during training and filters out the samples that are less informative to the model. By training on a smaller subset of data with only the samples that contribute the most to model convergence, total training time and cost decrease with minimal or no impact to accuracy. Additionally, because the feature operates online during model training, smart sifting doesn’t require changes to your upstream data or downstream training pipeline.
In this post, we discuss the following topics:
- The new smart sifting capability in SageMaker and how it works
- How to use smart sifting with PyTorch training workloads
You can also check out our documentation and sample notebooks for additional resources on how to get started with smart sifting.
How SageMaker smart sifting works
We begin this post with an overview of how the smart sifting capability can accelerate your model training on SageMaker.
Smart sifting’s task is to sift through your training data during the training process and only feed the more informative samples to the model. During typical training with PyTorch, the PyTorch DataLoader iteratively sends data in batches to the training loop and to accelerator devices (for example, GPUs or Trainium chips). Smart sifting is implemented at this data loading stage and is therefore independent of any upstream data preprocessing in your training pipeline.
Smart sifting uses your model and a user-specified loss function to do an evaluative forward pass of each data sample as it’s loaded. Samples that are high-loss will materially impact model training and therefore are used in training; data samples that are relatively low-loss are set aside and excluded from training.
A key input to smart sifting is the proportion of data to exclude: for example, by setting the proportion to 33% (beta_value=0.5), samples in approximately the bottom third of loss of each batch will be excluded from training. When enough high-loss samples have been identified to complete a batch, the data is sent through the full training loop and the model learns and trains normally. You don’t need to make any changes to your training loop when smart sifting is enabled.
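To make the per-batch filtering concrete, here is a minimal, library-free sketch. It is an illustration only, not the SageMaker implementation: the function name keep_high_loss and the exclude_fraction parameter are hypothetical, and the sketch uses a hard cutoff within a single batch, which is a simplification.

```python
from typing import List, Sequence


def keep_high_loss(losses: Sequence[float], exclude_fraction: float) -> List[int]:
    """Return indices of the samples to keep, dropping the lowest-loss fraction.

    Hypothetical helper for illustration; not part of any SageMaker API.
    """
    if not 0.0 <= exclude_fraction < 1.0:
        raise ValueError("exclude_fraction must be in [0, 1)")
    n_drop = round(len(losses) * exclude_fraction)
    # Order sample indices from lowest to highest loss, then drop the first n_drop.
    order = sorted(range(len(losses)), key=lambda i: losses[i])
    dropped = set(order[:n_drop])
    return [i for i in range(len(losses)) if i not in dropped]


# Per-sample losses from an evaluative forward pass over one batch of six samples.
batch_losses = [0.05, 2.30, 0.90, 0.10, 1.70, 0.40]
kept = keep_high_loss(batch_losses, exclude_fraction=1 / 3)
print(kept)  # [1, 2, 4, 5]: the two lowest-loss samples (indices 0 and 3) are excluded
```

Only the kept samples would go on to the full forward and backward pass, which is where the compute savings come from.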
The following diagram illustrates this workflow.
By including only a subset of your training data, smart sifting reduces the time and computation needed to train the model. In our tests, we achieved a reduction of up to nearly 40% in total training time and cost. With smart sifting of data, there can be minimal or no impact to model accuracy because the excluded samples were relatively low-loss for the model. In the following table, we include a set of experimental results demonstrating the performance improvement possible with SageMaker smart sifting.
In the table, the % Accepted column indicates the proportion of data that is included and used in the training loop. Decreasing this tunable parameter decreases the cost (as demonstrated in the IMR Savings % column), but it can also affect the accuracy. The appropriate setting for % Accepted is a function of your dataset and model; you should experiment with and tune this parameter to achieve the best balance between reduced cost and impact to accuracy.
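When choosing a starting point for tuning, it can help to estimate how a given beta_value translates into an expected % Accepted. This post gives only one data point (beta_value=0.5 excludes about 33%). The sketch below assumes the relationship accepted = 1 / (1 + beta_value), which reproduces that data point but is an assumption on our part; check the smart sifting documentation for the authoritative mapping. The function name is hypothetical.

```python
def expected_accepted_fraction(beta_value: float) -> float:
    """Expected share of samples kept, ASSUMING accepted = 1 / (1 + beta_value).

    The formula is an assumption chosen to match the example in this post
    (beta_value=0.5 -> roughly one third of samples excluded).
    """
    if beta_value <= 0:
        raise ValueError("beta_value must be positive")
    return 1.0 / (1.0 + beta_value)


for beta in (0.25, 0.5, 1.0, 3.0):
    accepted = expected_accepted_fraction(beta)
    print(f"beta_value={beta}: ~{accepted:.0%} accepted, ~{1 - accepted:.0%} excluded")
```

Under this assumption, larger beta values sift more aggressively, so they trade more cost savings for more risk to accuracy.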
Solution overview
In the following sections, we walk through a practical example of enabling smart sifting with a PyTorch training job on SageMaker. If you want to get started quickly, you can jump to the PyTorch or PyTorch Lightning examples.
Prerequisites
We assume that you already know how to train a model with PyTorch or PyTorch Lightning using the SageMaker Python SDK and the Estimator class, with SageMaker Deep Learning Containers for training. If not, refer to Using the SageMaker Python SDK before continuing.
Get started with SageMaker smart sifting
In a typical PyTorch training job, you initialize the PyTorch training DataLoader with your dataset and other required parameters, which provides input batches as the training progresses. To enable smart sifting of your training data, you use a new DataLoader class: smart_sifting.dataloader.sift_dataloader.SiftingDataloader. This class is used as a wrapper on top of your existing PyTorch DataLoader, and the training process instead uses SiftingDataloader to get input batches. The SiftingDataloader gets the input batch from your original PyTorch DataLoader, evaluates the importance of samples in the batch, and constructs a sifted batch with high-loss samples, which is then passed to the training step. The wrapper looks like the following code:
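As a rough, library-free illustration of the wrapping pattern described above, the following sketch wraps any iterable of batches and yields sifted batches of the original size. It is not the SiftingDataloader class and does not use the smart_sifting package: the class name SiftingWrapperSketch, its constructor arguments, and the hard per-batch cutoff are all assumptions made for the example.

```python
from typing import Callable, Iterable, Iterator, List, Sequence


class SiftingWrapperSketch:
    """Hypothetical stand-in that wraps an iterable of batches (not a SageMaker class)."""

    def __init__(
        self,
        batches: Iterable[Sequence[float]],
        per_sample_loss: Callable[[Sequence[float]], List[float]],
        exclude_fraction: float,
        batch_size: int,
    ) -> None:
        self.batches = batches
        self.per_sample_loss = per_sample_loss
        self.exclude_fraction = exclude_fraction
        self.batch_size = batch_size

    def __iter__(self) -> Iterator[List[float]]:
        pending: List[float] = []  # high-loss samples waiting to fill a batch
        for batch in self.batches:
            losses = self.per_sample_loss(batch)  # evaluative pass, one loss per sample
            n_drop = round(len(batch) * self.exclude_fraction)
            order = sorted(range(len(batch)), key=lambda i: losses[i])
            dropped = set(order[:n_drop])
            pending.extend(s for i, s in enumerate(batch) if i not in dropped)
            # Emit a batch as soon as enough high-loss samples have accumulated.
            while len(pending) >= self.batch_size:
                yield pending[: self.batch_size]
                pending = pending[self.batch_size:]
        if pending:
            yield pending  # final partial batch


# Toy data: each sample is a number, and its "loss" is simply its own value.
original_batches = [[0.1, 0.9, 0.5], [0.2, 0.8, 0.7], [0.3, 0.6, 0.4]]
loader = SiftingWrapperSketch(original_batches, list, exclude_fraction=1 / 3, batch_size=3)
sifted = list(loader)
print(sifted)  # [[0.9, 0.5, 0.8], [0.7, 0.6, 0.4]]
```

Because the wrapper is itself iterable, a training loop that consumes batches from the original loader can consume them from the wrapper without other changes, which mirrors the design described in this section.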
The SiftingDataloader requires some additional parameters to analyze your training data, which you can specify via the sift_config parameter. First, create a smart_sifting.sift_config.sift_configs.RelativeProbabilisticSiftConfig object. This object holds the configurable and required beta_value and loss_history_length, which respectively control the proportion of samples to keep and the window of samples to include when evaluating relative loss. Note that, because smart sifting uses your model to determine the importance of each sample, sifting with a model whose weights are still completely random can have negative effects. Instead, you can use loss_based_sift_config and a sift_delay to delay the sift process until the parameter weights in the model are updated beyond random values. (For more details, refer to Apply smart sifting to your training script.) In the following code, we define sift_config and specify beta_value and loss_history_length, as well as delay the start of sifting using loss_based_sift_config:
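To build intuition for how these three settings interact, the following library-free sketch simulates a relative, probabilistic sifter. It does not use the smart_sifting package, and it is not the actual algorithm: the class name, the keep-probability rule (percentile rank raised to the power beta_value), and counting sift_delay in calls rather than training steps are all assumptions made for illustration.

```python
import random
from collections import deque


class RelativeSifterSketch:
    """Hypothetical simulation of relative, probabilistic sifting (assumed rule)."""

    def __init__(self, beta_value: float, loss_history_length: int,
                 sift_delay: int, seed: int = 0) -> None:
        self.beta_value = beta_value
        self.history = deque(maxlen=loss_history_length)  # rolling window of recent losses
        self.sift_delay = sift_delay
        self.calls = 0
        self.rng = random.Random(seed)

    def keep(self, loss: float) -> bool:
        self.calls += 1
        self.history.append(loss)
        if self.calls <= self.sift_delay:
            return True  # warm-up: accept everything while the weights are near random
        # Percentile rank of this loss within the recent history window.
        rank = sum(1 for h in self.history if h <= loss) / len(self.history)
        # ASSUMED rule: higher-ranked (higher-loss) samples are more likely to be kept.
        return self.rng.random() < rank ** self.beta_value


sifter = RelativeSifterSketch(beta_value=0.5, loss_history_length=500, sift_delay=10)
loss_source = random.Random(1)
decisions = [sifter.keep(loss_source.random()) for _ in range(5000)]
print(f"kept {sum(decisions) / len(decisions):.1%} of samples")
```

With losses drawn uniformly at random, this assumed rule keeps roughly two thirds of the samples at beta_value=0.5, in line with the 33% exclusion example given earlier in this post.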
Next, you must also include a loss_impl parameter in the SiftingDataloader object. Smart sifting works on an individual sample level, and it’s crucial to have access to a loss calculation method to determine the importance of the sample. You must implement a sifting loss method that returns an nx1 tensor, which holds the loss values of n samples. Typically, you specify the same loss method used by your model during training. Finally, include a pointer to your model in the SiftingDataloader object, which is used to evaluate samples before they are included in training. See the following code:
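The key requirement is that the loss is returned per sample rather than averaged over the batch. In PyTorch, built-in losses such as torch.nn.CrossEntropyLoss support this through reduction='none'. The following library-free sketch shows the same idea in plain Python so the shape of the result is easy to see; the function name is hypothetical, and it does not implement the smart_sifting loss interface.

```python
import math
from typing import List, Sequence


def per_sample_cross_entropy(
    logits: Sequence[Sequence[float]], labels: Sequence[int]
) -> List[float]:
    """Return one cross-entropy loss value per sample (no averaging over the batch)."""
    losses = []
    for row, label in zip(logits, labels):
        # Log-sum-exp with the row maximum subtracted for numerical stability.
        row_max = max(row)
        log_norm = row_max + math.log(sum(math.exp(x - row_max) for x in row))
        losses.append(log_norm - row[label])
    return losses


# Two samples, two classes: the first is predicted confidently, the second is not.
batch_logits = [[2.0, 0.0], [0.0, 0.0]]
batch_labels = [0, 1]
sample_losses = per_sample_cross_entropy(batch_logits, batch_labels)
print(sample_losses)  # n values for n samples; the second sample has the higher loss
```

A batch-averaged loss would collapse these values into a single number, leaving nothing with which to rank individual samples.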
The following code shows a complete example of enabling smart sifting with an existing BERT training job:
Conclusion
In this post, we explored the public preview of smart sifting, a new capability of SageMaker that can reduce deep learning model training costs by up to 35%. This feature improves data efficiency by filtering out less informative data samples during training. By including only the most impactful data for model convergence, you can significantly reduce training time and expense, all while maintaining accuracy. What’s more, it seamlessly integrates into your existing processes without requiring alterations to your data or training pipeline.
To dive deeper into SageMaker smart sifting, explore how it works, and implement it with PyTorch training workloads, check out our documentation and sample notebooks and get started with this new capability.
About the authors
Robert Van Dusen is a Senior Product Manager with Amazon SageMaker. He leads frameworks, compilers, and optimization techniques for deep learning training.
K Lokesh Kumar Reddy is a Senior Engineer in the Amazon Applied AI team. He is focused on efficient ML training techniques and building tools to improve conversational AI systems. In his spare time he enjoys seeking out new cultures, new experiences, and staying up to date with the latest technology trends.
Abhishek Dan is a Senior Dev Manager in the Amazon Applied AI team and works on machine learning and conversational AI systems. He is passionate about AI technologies and works at the intersection of science and engineering to advance the capabilities of AI systems and create more intuitive and seamless human-computer interactions. He is currently building applications on large language models to drive efficiency and CX improvements for Amazon.