PyTorch Gradient Checkpoint Usage Example

祁承嗣
2023-12-01

When training deep learning models, you often run into CUDA error: out of memory (OOM). There are a few crude, not particularly elegant workarounds:

  • Reduce the batch size, e.g. 32 → 16
  • Reduce the input size, e.g. 332 × 332 × 3 → 224 × 224 × 3
  • Switch to a GPU with more memory

A quick search shows that PyTorch offers a more elegant solution: gradient checkpoint (as far as I can tell, introduced in version 0.4.0 or later). It trades computation time for memory and can significantly reduce the GPU memory footprint during training. In my model, enabling gradient checkpoint cut memory usage by roughly 30%.

Although PyTorch's gradient checkpoint is very simple to use, it still helps to have some examples when you first encounter it. It took me a long time to find a single reference online, so here I give a UNet example and summarize the problems I ran into and how I solved them.

gradient checkpoint

PyTorch's gradient checkpoint is implemented by the function torch.utils.checkpoint.checkpoint(function, *args, **kwargs).
The official PyTorch documentation describes the function as follows:

Checkpointing works by trading compute for memory. Rather than storing all intermediate activations of the entire computation graph for computing backward, the checkpointed part does not save intermediate activations, and instead recomputes them in backward pass. It can be applied on any part of a model.

In other words, gradient checkpoint trades longer computation time for lower memory usage: instead of storing all intermediate activations for the backward pass, the checkpointed part recomputes them during backward, and any part of a model can be checkpointed.
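Before the full UNet example, here is a minimal sketch of the API. The toy block and tensor shapes below are made up purely for illustration; a checkpointed block behaves like a normal call in the forward pass, but its intermediate activations are recomputed during backward.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# A toy sub-network; in practice this would be a large, memory-hungry block.
block = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512))

# The input to a checkpointed segment must require grad,
# otherwise backward has nothing to connect to (see the note after the UNet code).
x = torch.randn(8, 512, requires_grad=True)

y = checkpoint(block, x)   # forward: intermediate activations inside `block` are not stored
loss = y.sum()
loss.backward()            # backward: `block` is re-run to recompute those activations

Note that recent PyTorch versions may warn you to pass the use_reentrant argument explicitly; the basic call above is enough to show the idea.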

gradient checkpoint usage example

Here a UNet is used to demonstrate how to use gradient checkpoint.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.checkpoint import checkpoint

class in_conv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(in_conv, self).__init__()
        self.op = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.op(x)
        return x

class conv3x3(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(conv3x3, self).__init__()
        self.op = nn.Sequential(
            nn.Conv2d(in_ch,out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch,out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.op(x)
        return x
        
class down_block(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(down_block, self).__init__()
        self.pool = nn.MaxPool2d(2, stride=2)
        self.conv = conv3x3(in_ch, out_ch)
        
    def forward(self, x):
        x = self.pool(x)
        x = self.conv(x)
        return x

class up_block(nn.Module):
    def __init__(self, in_ch, out_ch, residual=False):
        super(up_block, self).__init__()
        
        self.up = nn.ConvTranspose2d(in_ch, in_ch // 2, kernel_size=2, stride=2)
        self.conv = conv3x3(in_ch, out_ch)
        
    def forward(self, x1, x2):
        x1 = self.up(x1)

        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, (diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2))

        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x)

        return x

class out_conv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(out_conv, self).__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=1)

    def forward(self, x):
        x = self.conv(x)
        return x

class UNet(nn.Module):
    def __init__(self, img_channels, n_classes, use_checkpoint=False):
        super(UNet, self).__init__()
        self.use_checkpoint = use_checkpoint
        self.inc = in_conv(img_channels, 64)
        self.down1 = down_block(64, 128)
        self.down2 = down_block(128, 256)
        self.down3 = down_block(256, 512)
        self.down4 = down_block(512, 1024)
        self.up1 = up_block(1024, 512)
        self.up2 = up_block(512, 256)
        self.up3 = up_block(256, 128)
        self.up4 = up_block(128, 64)
        self.outc = out_conv(64, n_classes)

    def forward(self, x):
        # the input to the checkpointed part must require grad (see the note below)
        x = Variable(x, requires_grad=True)
        if self.use_checkpoint:
            x1 = checkpoint(self.inc, x)
            x2 = checkpoint(self.down1, x1)
            x3 = checkpoint(self.down2, x2)
            x4 = checkpoint(self.down3, x3)
            x5 = checkpoint(self.down4, x4)
            x = checkpoint(self.up1, x5, x4)
            x = checkpoint(self.up2, x, x3)
            x = checkpoint(self.up3, x, x2)
            x = checkpoint(self.up4, x, x1)
            x = checkpoint(self.outc, x)
        else:
            x1 = self.inc(x)
            x2 = self.down1(x1)
            x3 = self.down2(x2)
            x4 = self.down3(x3)
            x5 = self.down4(x4)
            x = self.up1(x5, x4)
            x = self.up2(x, x3)
            x = self.up3(x, x2)
            x = self.up4(x, x1)
            x = self.outc(x)

        return x

Note the line x = Variable(x, requires_grad=True) in UNet.forward: you must make sure the tensor fed into the checkpointed part requires gradients (here done by wrapping it in a Variable with requires_grad=True), otherwise you will get the following error at runtime:

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
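For completeness, here is a rough sketch of how the model above might be driven for one training step. The batch size, input size, loss, and optimizer below are arbitrary choices for illustration, not part of the original example.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(img_channels=3, n_classes=1, use_checkpoint=True).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.BCEWithLogitsLoss()

# dummy batch: 2 RGB images of size 332 x 332 with binary masks as targets
images = torch.randn(2, 3, 332, 332, device=device)
targets = torch.randint(0, 2, (2, 1, 332, 332), device=device).float()

optimizer.zero_grad()
outputs = model(images)               # forward: checkpointed blocks discard activations
loss = criterion(outputs, targets)
loss.backward()                       # backward: checkpointed blocks are recomputed here
optimizer.step()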

Monitoring GPU usage

Finally, a recommendation for a GPU monitoring tool that is nicer to use than nvidia-smi: gpustat.
To monitor GPU usage continuously with nvidia-smi you have to loop it and keep printing, which is neither pretty nor easy to read.
gpustat monitors dynamically: run it in a separate tab and you can watch GPU usage change over time.
Installation

pip install gpustat

Usage

watch  --color -n1 gpustat -cpu
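If you want to check the memory saving from inside the script rather than with an external monitor, PyTorch's built-in counters can also be used. A small sketch (the exact saving, about 30% in my case, will vary with the model and batch size):

import torch

torch.cuda.reset_peak_memory_stats()

# ... run one forward/backward pass of the model here ...

peak_mib = torch.cuda.max_memory_allocated() / 1024 ** 2
print(f"peak GPU memory: {peak_mib:.1f} MiB")

Running this once with use_checkpoint=False and once with use_checkpoint=True makes the difference easy to compare.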