Train gigantic models with near-linear scaling using sharded data parallelism on Amazon SageMaker

In the pursuit of superior accuracy, deep learning models in areas such as natural language processing and computer vision have significantly grown in size in the past few years, frequently counted in tens to hundreds of billions of parameters. Training these gigantic models is challenging and requires complex distribution strategies. Data scientists and machine learning engineers are constantly looking for the best way to optimize their training compute, yet are struggling with the communication overhead that can increase along with the overall cluster size.

This is why we recently launched sharded data parallelism on Amazon SageMaker, a new memory-saving distributed training technique in the SageMaker model parallel (SMP) library. Sharded data parallelism is purpose-built for extreme-scale models and uses Amazon in-house MiCS technology under the hood, a science effort to minimize the communication scale by bringing down the expensive communication overhead rooted in parameter gathering and gradient synchronization. With a 30B-parameter GPT-2 model with sequence length 2048, this new feature achieved 141 TFLOPs, a 39.7% speedup compared to DeepSpeed ZeRO-3. For a 10B-parameter GPT-2 model with sequence length 512, this new feature also achieved 564 samples per second, a 13.9% speedup compared to PyTorch’s Fully Sharded Data Parallel (FSDP). Remember that in gigantic model training, every percentage point of speedup translates to dollars saved and productivity gained in your team.

In this blog post, we’ll first take a closer look at the key differentiators of sharded data parallelism and when to use it. Then, you’ll learn how to train a 30B-parameter GPT-2 model on SageMaker with this new feature. Finally, we’ll compare the performance with other open source options, notably outperforming DeepSpeed ZeRO by up to 39.7% on 256 GPUs.

How sharded data parallelism works and when to use it

Before we introduce sharded data parallelism, let’s look at its broader technique family. Recent distributed training approaches for large models have moved to a paradigm where model parameters, gradients, and optimizer states are sharded across data-parallel nodes. Unlike pipeline parallelism, which has the innate complexity of choosing layers to partition across devices, especially when your framework doesn’t support automated model splitting, this paradigm elegantly preserves the simplicity of data parallelism while removing data parallelism’s constraint that a model must fit into a single GPU.

In existing frameworks that fall under this paradigm, notably DeepSpeed ZeRO-3 and PyTorch’s FSDP (upstreamed from FairScale), model states are sharded across all GPUs. This strategy lowers the memory consumption on each GPU, but it incurs a large communication overhead that grows with cluster size and therefore causes scalability to drop significantly at scale. In contrast, sharded data parallelism in the SMP library partitions model states in a scale-aware manner, sharding each replica of model states only within a subset of GPUs.
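To make the contrast concrete, here is a minimal sketch of how ranks could be grouped when each replica of model states is sharded within a subset of GPUs. The function name build_groups, the terms "partition group" and "replication group", and the assumption that consecutive ranks share a replica are illustrative choices of ours; the actual rank layout used by the SMP library is not described in this post.

```python
def build_groups(world_size, shard_degree):
    """Return (partition_groups, replication_groups) as lists of rank lists.

    Assumption for illustration: consecutive ranks hold the shards of one replica.
    """
    if world_size % shard_degree != 0:
        raise ValueError("world_size must be divisible by shard_degree")
    num_replicas = world_size // shard_degree

    # Each partition group holds one complete replica of the model states,
    # split into shard_degree pieces.
    partition_groups = [
        list(range(r * shard_degree, (r + 1) * shard_degree))
        for r in range(num_replicas)
    ]
    # Each replication group contains the ranks that hold the same shard
    # index in different replicas.
    replication_groups = [
        [r * shard_degree + i for r in range(num_replicas)]
        for i in range(shard_degree)
    ]
    return partition_groups, replication_groups


partition_groups, replication_groups = build_groups(world_size=16, shard_degree=8)
print(partition_groups[0])    # ranks that together hold the first replica
print(replication_groups[0])  # ranks that hold shard 0 of every replica
```

Setting shard_degree equal to world_size reproduces the "shard across all GPUs" layout, with a single partition group spanning the whole cluster. A smaller shard_degree keeps parameter gathering inside a smaller group.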

Let’s look closer at the scale-aware model partitioning in MiCS, the core technology behind sharded data parallelism. The intuition behind this design is that partitioning training states across the entire data-parallel group may not be required to train a model with tens of billions of parameters. For example, 8 V100 GPUs (32GB each) are sufficient to hold a replica of the model states of a 10B-parameter model, which needs about 200GB of memory when training with the Adam optimizer using mixed precision. By limiting a complete replica of model states to the smallest subset of GPUs, we can effectively reduce the scale of communication overhead compared to DeepSpeed and PyTorch FSDP. Sharded data parallelism also leverages other techniques in MiCS such as Hierarchical Communication and 2-hop Gradient Synchronization. For more information, check out Near-linear scaling of gigantic-model training on AWS or MiCS: Near-linear Scaling for Training Gigantic Model on Public Cloud.
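The memory arithmetic in the example above can be sketched as follows. The figure of 20 bytes per parameter is our assumption, derived only from the numbers quoted above (about 200GB for 10B parameters). The helper names are hypothetical, and the estimate ignores activations and temporary buffers.

```python
import math

# Assumption: ~200 GB of model states for 10B parameters, i.e. 20 bytes per parameter.
BYTES_PER_PARAM = 20


def model_state_gb(num_params, bytes_per_param=BYTES_PER_PARAM):
    """Approximate memory (in GB) for parameters, gradients, and optimizer states."""
    return num_params * bytes_per_param / 1e9


def min_gpus_for_replica(num_params, gpu_mem_gb, bytes_per_param=BYTES_PER_PARAM):
    """Lower bound on the GPUs needed to hold one replica of the model states."""
    return math.ceil(model_state_gb(num_params, bytes_per_param) / gpu_mem_gb)


print(model_state_gb(10e9))            # 200.0
print(min_gpus_for_replica(10e9, 32))  # 7
```

The lower bound of 7 GPUs is consistent with the 8 V100 GPUs mentioned above, which leave some headroom for activations.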

Now, how do you know when to choose sharded data parallelism over other distributed training techniques? The general rule is that if your model has fewer than 1 billion parameters and can fit into GPU memory, the SageMaker data parallel library or SageMaker Training Compiler can be sufficient for you. If you have larger language or computer vision models, our suggestion is to first train them with the sharded data parallelism technique combined with activation checkpointing and activation offloading in the SageMaker model parallel library, before trying other techniques such as tensor parallelism or pipeline parallelism.

Using sharded data parallelism to train GPT-2 on Amazon SageMaker

Let’s now learn how to train a GPT-2 model with sharded data parallel, with SMP encapsulating the complexity for you. This complete tutorial notebook walks you through the entire process, from data processing, defining and submitting training jobs, to monitoring training logs. What follows is a brief overview highlighting key steps for using this feature.

1. Get started

Sharded data parallelism is available in PyTorch v1.12.0+ and works with both FP16 and BF16. The easiest way to use the SMP library is through a prebuilt AWS Deep Learning Container for PyTorch. However, if you want to bring your own Docker container, you can refer to Create Your Own Docker Container with the SageMaker Distributed Model Parallel Library. To get started, follow Modify a PyTorch Training Script to adapt SMP’s APIs in your training script. In this section, we only call out a few main steps with code snippets from the ready-to-use training script. You can follow the comments in the script and the API document to learn more about where SMP APIs are used.

First, import and initialize the library by calling smdistributed.modelparallel.torch.init() at the beginning of the training script:

import smdistributed.modelparallel.torch as smp

# Initialize the library before any other SMP API is used.
smp.init()

Second, wrap the model to be partitioned with smdistributed.modelparallel.torch.DistributedModel and use the returned DistributedModel object going forward:

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_config(model_config)
model = smp.DistributedModel(model, trace_device="gpu", backward_passes_per_step=args.gradient_accumulation)

Wrap the optimizer with smdistributed.modelparallel.torch.DistributedOptimizer for saving and loading optimizer states.

from torch import optim

optimizer = optim.Adam(
    param_groups, betas=(args.beta1, args.beta2), weight_decay=args.weight_decay
)

# Wrap the base optimizer so that SMP can save and load its states.
optimizer = smp.DistributedOptimizer(
    optimizer,
    dynamic_loss_args={"scale_window": 1000, "min_scale": 1, "delayed_shift": 2},
)

Put the forward and backward logic in a step function and decorate it with smdistributed.modelparallel.torch.step.  Any computation defined inside the smp.step-decorated function is executed in a distributed manner.

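@smp.step
def train_step(model, optimizer, input_ids, attention_mask, args):
    # Forward pass: the model returns the loss because labels are provided.
    loss = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)["loss"]
    # Backward pass: call backward on the DistributedModel, not on the loss tensor.
    model.backward(loss)

    return loss

@smp.step
def test_step(model, input_ids, attention_mask):
    loss = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)["loss"]
    return loss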

2. Prepare the dataset

We use the openwebtext dataset in this example. The notebook uses a script to download and preprocess the dataset. You can also train with other datasets by modifying the data pipeline. When dealing with large datasets and models, you can speed up the training job by using data stored in Amazon FSx for Lustre, which provides a high-performance file system natively integrated with Amazon Simple Storage Service (Amazon S3). Please see the instructions in Configure Data Input Channel to Use Amazon FSx for Lustre for guidance on setting an FSx for Lustre file system as the data input channel.

3. Start the training jobs

This step assumes you have already modified your training script and prepared the dataset as mentioned in the preceding sections. To enable sharded data parallelism, simply set sharded_data_parallel_degree in the PyTorch Estimator. In this tutorial, we set sharded_data_parallel_degree=128 and instance_count=32 for p4d.24xlarge nodes, which means that the model states are sharded across 128 GPUs out of the total 256 GPUs. Based on this value, SMP automatically sets the data parallel degree to 2 (because 256/128=2), meaning we’ll have two replicas for data parallelism. A general rule for picking an ideal value for sharded_data_parallel_degree is to add one more node to the sharding group for every 3B model parameters. In this tutorial, our model size is 30B, so we should use at least 10 nodes for sharding. Because 16 nodes (128 GPUs) is the smallest power of 2 above that threshold, we set sharded_data_parallel_degree=128.
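The sizing rule above can be sketched in a few lines. The helper names are hypothetical and the code only restates the rule of thumb from this section (one node per 3B parameters, rounded up to a power of 2, with 8 GPUs per p4d.24xlarge node). It is not part of the SMP API.

```python
import math


def pick_sharding_degree(model_params_b, gpus_per_node=8, params_b_per_node=3):
    """Rule-of-thumb value for sharded_data_parallel_degree (illustrative only)."""
    # One node in the sharding group for every 3B parameters.
    min_nodes = math.ceil(model_params_b / params_b_per_node)
    # Round up to the smallest power of two that is >= min_nodes.
    nodes = 1 << (min_nodes - 1).bit_length()
    return nodes * gpus_per_node


def data_parallel_replicas(total_gpus, sharding_degree):
    """Number of model-state replicas that result from the chosen degree."""
    if total_gpus % sharding_degree != 0:
        raise ValueError("total_gpus must be divisible by sharding_degree")
    return total_gpus // sharding_degree


degree = pick_sharding_degree(model_params_b=30)
replicas = data_parallel_replicas(total_gpus=32 * 8, sharding_degree=degree)
print(degree, replicas)  # 128 2
```

Treat the result as a starting point: a smaller degree reduces communication scale but leaves less memory headroom per GPU.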

For checkpointing, we also provide a set of checkpointing utilities, including a utility to reconstruct the full state_dict for advanced use cases. Finally, we can launch a distributed training job by calling fit() on the Estimator.

from sagemaker.pytorch import PyTorch

smp_estimator = PyTorch(
    # The entry point, role, framework version, hyperparameters, and other
    # estimator arguments are set as in the tutorial notebook.
    instance_type="ml.p4d.24xlarge",
    instance_count=32,
    distribution={
        "mpi": {
            "enabled": True,
            "processes_per_host": processes_per_host,
            "custom_mpi_options": mpioptions,
        },
        "smdistributed": {
            "modelparallel": {
                "enabled": True,
                "parameters": {
                    "ddp": True,
                    "skip_tracing": True,
                    "delayed_parameter_initialization": True,
                    "offload_activations": True,
                    "activation_loading_horizon": 4,
                    # To enable sharded data parallelism.
                    # Here we shard model states across 128 GPUs.
                    "sharded_data_parallel_degree": 128,
                    "fp16": False,
                    "bf16": True,
                    # This is to disable pipeline parallelism.
                    "partitions": 1,
                },
            }
        },
    },
    checkpoint_s3_uri=checkpoint_s3_uri if not use_fsx else None,
    checkpoint_local_path=hyperparameters["checkpoint-dir"] if use_fsx else None,
)

4. Monitor the training jobs

You can access the training logs and track GPU and memory utilization on Amazon CloudWatch. Make sure to look at the logs of “algo-1” because that is the main node whose output stream has the training job logs from all instances.

Benchmarking performance

We benchmarked sharded data parallelism in the SMP library on 16 and 32 p4d.24xlarge nodes for sequence lengths 512 and 2048, respectively. The 30B-parameter GPT-2 model is configured to use a hidden width of 7168, 48 layers, and 64 heads. You can adopt the exact same configuration where the sequence length is 2048 by setting model_config = "gpt2-30b" in the tutorial notebook. With this setting, SMP achieved 73.52 samples per second, a 39.7% speedup compared to DeepSpeed ZeRO-3. If you train on 500 billion tokens, this speedup means nearly 367 hours of savings on p4d.24xlarge nodes, equivalent to more than $12,000 of budget saved per training! The following table summarizes our benchmark results.

| Model/Training | Cluster | DeepSpeed configuration | SMP configuration | Speed (samples/sec), DeepSpeed v0.7.2 | Speed (samples/sec), SMP v1.11 | % Speedup of SMP | TFLOPS achieved by SMP | Time to train with SMP (days), 100 billion tokens | Time to train with SMP (days), 500 billion tokens |
|---|---|---|---|---|---|---|---|---|---|
| 30B GPT-2, Seq length: 512, Global batch size: 3072 | 16 p4d.24xlarge nodes | Activation checkpointing | Activation checkpointing | 142 | 181.05 | 27.5 | 173.6 | 12.49 | 62.43 |
| 30B GPT-2, Seq length: 2048, Global batch size: 1536 | 32 p4d.24xlarge nodes | Activation checkpointing | Activation checkpointing, sharded_data_parallel_degree: 128 | 52.6 | 73.52 | 39.77 | 141 | 7.69 | 38.43 |

1/ For each model configuration, we tested different features, stages, and configurations in DeepSpeed ZeRO and chose the one that provides the best throughput as the DeepSpeed baseline. The benchmark was run on Amazon Elastic Compute Cloud (Amazon EC2).
2/ These results rely on improved communication collectives optimized for AWS which will be made available soon.
3/ Time to train is projected from speed based on number of tokens processed.
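The projected figures in the table can be reproduced from the measured throughput. The sketch below assumes that the number of training samples equals the token count divided by the sequence length, which is our reading of footnote 3. The function names are hypothetical.

```python
SECONDS_PER_DAY = 86_400


def days_to_train(total_tokens, seq_length, samples_per_sec):
    """Projected training time in days at a constant throughput."""
    num_samples = total_tokens / seq_length
    return num_samples / samples_per_sec / SECONDS_PER_DAY


def speedup_percent(baseline_speed, new_speed):
    """Relative throughput gain of new_speed over baseline_speed, in percent."""
    return (new_speed / baseline_speed - 1) * 100


# 30B GPT-2, sequence length 2048, 500 billion tokens, 32 p4d.24xlarge nodes.
smp_days = days_to_train(500e9, 2048, 73.52)
deepspeed_days = days_to_train(500e9, 2048, 52.6)
hours_saved = (deepspeed_days - smp_days) * 24

print(round(smp_days, 2))                      # 38.43
print(round(speedup_percent(52.6, 73.52), 2))  # 39.77
print(round(hours_saved))                      # 367
```

The same functions give about 12.49 days for the sequence length 512 configuration at 100 billion tokens, matching the first row of the table.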

In summary, we observed consistently higher throughput with sharded data parallelism in SMP when compared to DeepSpeed across a range of models and configurations. This new feature also demonstrated a better memory efficiency compared to DeepSpeed, enabling SMP to fit a larger batch size and reduce the level of gradient accumulation required to fit a particular global batch size.


In this post, we introduced a new distributed training technique, sharded data parallelism, and how it speeds up gigantic model training with near-linear scaling on Amazon SageMaker. We also walked through how to train a GPT-2 model with the new technique following this complete example. You can follow the Amazon SageMaker Examples GitHub repo to track all SageMaker model parallel examples or attend our next distributed training workshops. To learn more about sharded data parallelism, please see the documentation.

About the authors

Emily Webber joined AWS just after SageMaker launched, and has been trying to tell the world about it ever since! Outside of building new ML experiences for customers, Emily enjoys meditating and studying Tibetan Buddhism.

Can Karakus is a Senior Applied Scientist at AWS, optimizing large-scale distributed deep learning on AWS. His research interests cover deep learning, distributed optimization, distributed systems, and information theory. Outside of work, he enjoys cycling, traveling, reading and learning.

Rahul Huilgol is a Senior Software Engineer at AWS. He works on distributed deep learning systems, towards making it easy and performant to train large deep learning models in the cloud. In his spare time, he enjoys photography, biking and gardening.

Suhit Kodgule is a Software Development Engineer with AWS Artificial Intelligence group working on deep learning frameworks. In his spare time, he enjoys hiking, traveling and cooking.

Erin Ho is a Product Manager for AWS Deep Learning. She works on products that make it easier for customers to train deep learning models on AWS. For fun outside work, she enjoys hiking and skiing.

Shared by: AWS Machine Learning