Handling CUDA Out-of-Memory Errors in PyTorch
Jiguang Li
Center for Applied Artificial Intelligence
May 6th, 2022
Motivation

- Suppose we want to build a large CNN model that can predict multiple outcomes from X-ray images.
- The average X-ray image has a very high resolution (2000 × 2000).
- We run into a CUDA out-of-memory error with a batch size as small as 8, even after downsampling the images to 512 × 512.

Question: What are some general approaches to avoid downsampling, or to use a larger batch size?

- Data Parallelization
- Model-Based Parallelization
- Gradient Checkpointing
- ...
Our Multi-Head CNN Model

- Multiple medical imaging studies have shown that the DenseNet architecture works well for X-ray images [2, 3].
- The dense blocks require a lot of GPU memory.
- We have made our model flexible so that we can iterate faster. A sketch of such a multi-head model follows below.
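As a rough illustration (not the actual model from this talk), a multi-head classifier on a torchvision DenseNet-169 backbone might look like the sketch below; the head count, the head output sizes, and the weights=None argument (torchvision >= 0.13) are assumptions.

    import torch
    import torch.nn as nn
    from torchvision import models

    class MultiHeadDenseNet(nn.Module):
        """Hypothetical multi-head classifier on a DenseNet-169 backbone."""
        def __init__(self, head_sizes=(2, 2, 5)):            # head sizes are purely illustrative
            super().__init__()
            backbone = models.densenet169(weights=None)       # assumes torchvision >= 0.13
            self.features = backbone.features                 # the memory-hungry dense blocks
            in_feats = backbone.classifier.in_features        # 1664 for DenseNet-169
            # one linear head per outcome we want to predict
            self.heads = nn.ModuleList([nn.Linear(in_feats, n) for n in head_sizes])

        def forward(self, x):
            h = self.features(x)
            h = nn.functional.relu(h, inplace=True)
            h = nn.functional.adaptive_avg_pool2d(h, 1).flatten(1)
            return [head(h) for head in self.heads]           # one logit tensor per outcome

    model = MultiHeadDenseNet().cuda()
    outputs = model(torch.randn(2, 3, 512, 512, device="cuda"))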
Approach 1: Data Parallelization: nn.DataParallel

- Replicate a copy of the model on each GPU.
- Split the mini-batch across all GPUs.
- Forward: each replica handles a portion of the input.
- Backward: gradients from each replica are summed into the original module. A usage sketch follows below.
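A minimal usage sketch, assuming at least two visible GPUs and a plain torchvision DenseNet-169 standing in for our model; the loss here is only a placeholder to drive the backward pass.

    import torch
    import torch.nn as nn
    from torchvision import models

    model = models.densenet169(weights=None).cuda()          # parameters live on cuda:0 (the "original module")
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)                       # replicated onto all visible GPUs on each forward

    images = torch.randn(32, 3, 512, 512, device="cuda")     # the full mini-batch starts on cuda:0
    logits = model(images)                                   # mini-batch is split across GPUs, outputs gathered on cuda:0
    loss = logits.mean()                                     # placeholder loss
    loss.backward()                                          # per-replica gradients are summed into the original module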
Data Parallelization: The Good, the Bad, and the Ugly

The Good
- Very easy to implement.
- Fast: takes advantage of multiple GPUs.
- Saves memory: each GPU only receives a smaller number of images.

The Bad
- What if the model itself is too large to fit on one GPU?
- Potentially unstable training, because batch normalization statistics are computed on the smaller per-GPU batches.
Approach 2: Model-Based Parallelization

- Evenly distribute a single model across multiple GPUs.
- During the forward pass, each GPU is responsible for only one component of the computation.
DenseNet169 Model-based Parallelization
Figure: DenseNet169 Model-Based Parallelization Implementation
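The figure itself is not reproduced here. As a hedged sketch of what such a split can look like, the code below cuts a torchvision DenseNet-169 feature extractor roughly in half across cuda:0 and cuda:1; the split point and the 14-class head are assumptions, not the implementation shown in the figure.

    import torch
    import torch.nn as nn
    from torchvision import models

    class TwoGPUDenseNet(nn.Module):
        """Sketch: DenseNet-169 features split across cuda:0 and cuda:1."""
        def __init__(self, num_classes=14):                       # class count is illustrative
            super().__init__()
            backbone = models.densenet169(weights=None)
            blocks = list(backbone.features.children())            # conv stem, dense blocks, transitions, final norm
            self.part1 = nn.Sequential(*blocks[:8]).to("cuda:0")   # first half of the feature extractor
            self.part2 = nn.Sequential(*blocks[8:]).to("cuda:1")   # second half
            self.classifier = nn.Linear(backbone.classifier.in_features, num_classes).to("cuda:1")

        def forward(self, x):
            h = self.part1(x.to("cuda:0"))
            h = self.part2(h.to("cuda:1"))                         # move activations between GPUs
            h = nn.functional.adaptive_avg_pool2d(nn.functional.relu(h), 1).flatten(1)
            return self.classifier(h)

    model = TwoGPUDenseNet()
    out = model(torch.randn(4, 3, 512, 512))   # backward() works unchanged; autograd routes gradients back across devices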
Model-Based Parallelization: The Good, the Bad, and the Ugly

Almost the inverse of data parallelization.

The Good
- Stable training: every layer sees the full batch of images.
- Saves memory: especially useful when the model is too large for a single GPU. Note that only one GPU is working at any given time.

The Bad
- We have to implement the split ourselves, from scratch.
- Very slow: often slower than using a single GPU, since each GPU sits idle while waiting for the others.
Model-Based Parallelization: Faster Version
Figure: DenseNet169 Model-Based Parallelization + Data Parallel
Model-Based + Data Parallelization: Implementation
I have done a truly remarkable implementation which this margin is
too small to contain.
Figure: DenseNet169 Model-Based Parallelization + Data Parallel
Approach 3: Gradient Checkpointing

Intuition
- The total memory used by a neural network is the static memory used by the model weights plus the dynamic memory used by the computational graph (activations).
- During the forward pass, gradient checkpointing omits part of the activation values from the computational graph.
- During backpropagation, the omitted forward computations are recomputed on demand.
- One can show that gradient checkpointing can train an n-layer network with only O(√n) activation memory, at the cost of one extra forward pass per mini-batch [1].
Gradient Checkpointing: Implementation
Figure: DenseNet169 Gradient Checkpointing
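Again, the figure is not reproduced here. Below is a minimal sketch of the same idea using torch.utils.checkpoint.checkpoint_sequential on a torchvision DenseNet-169 feature extractor; the number of segments and the 14-class head are assumptions.

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint_sequential
    from torchvision import models

    class CheckpointedDenseNet(nn.Module):
        """Sketch: DenseNet-169 with a checkpointed feature extractor."""
        def __init__(self, num_classes=14, segments=4):          # segment count is illustrative
            super().__init__()
            backbone = models.densenet169(weights=None)
            self.features = backbone.features                    # an nn.Sequential of stem, dense blocks, transitions
            self.segments = segments
            self.classifier = nn.Linear(backbone.classifier.in_features, num_classes)

        def forward(self, x):
            # Only activations at segment boundaries are stored; everything in
            # between is recomputed during the backward pass.
            h = checkpoint_sequential(self.features, self.segments, x)
            h = nn.functional.adaptive_avg_pool2d(nn.functional.relu(h), 1).flatten(1)
            return self.classifier(h)

    model = CheckpointedDenseNet().cuda()
    # requires_grad on the input lets gradients flow through the checkpointed
    # segments with the default (reentrant) checkpoint implementation.
    x = torch.randn(8, 3, 512, 512, device="cuda", requires_grad=True)
    model(x).mean().backward()                                   # the extra forward passes happen here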
Gradient Checkpointing: The Good, the Bad, and the Ugly

The Good
- Relatively easy to implement.
- Stable training.
- Saves memory: parts of the activation values are never stored.
- One GPU is all you need.
- You can combine nn.DataParallel and gradient checkpointing.

The Bad
- Slower: one extra forward pass per mini-batch.
References

[1] Chen, Tianqi, et al. "Training Deep Nets with Sublinear Memory Cost." arXiv:1604.06174 (2016).
[2] Irvin, Jeremy A., et al. "CheXpert: A Large Chest Radiograph Dataset with Uncertainty Labels and Expert Comparison." AAAI (2019).
[3] Rajpurkar, Pranav, et al. "MURA: Large Dataset for Abnormality Detection in Musculoskeletal Radiographs." arXiv (2017).