Pytorch Lightning 之 amp_level

寿子轩
2023-12-01

amp_level of Pytorch Lightning

  • amp_level (Optional[str]) – The optimization level to use (O1, O2, etc…). By default it will be set to “O2” if amp_backend is set to “apex”.

 该参数生效的前提是设置 amp_backbend='apex', 首先简介 APEX: 

APEX 是来自英伟达 (NVIDIA) 的一个很好用的深度学习加速库。由英伟达开源,完美支持PyTorch框架,用于改变数据格式来减小模型显存占用的工具。其中最有价值的是 amp (Automatic Mixed Precision) ,将模型的大部分操作都用 torch.HalfTensor(Float16) 数据类型测试,一些特别操作仍然使用 torch.FloatTensor(Float32)

torch.HalfTensor的优势就是存储小、计算快、更好的利用CUDA设备的Tensor Core。因此训练的时候可以减少显存的占用(可以增加batchsize了),同时训练速度更快;其劣势是:数值范围小(更容易Overflow / Underflow)、舍入误差(Rounding Error,导致一些微小的梯度信息达不到16bit精度的最低分辨率,从而丢失)。

在 Pytorch Lightning 中 amp_level 有以下几个级别(Source):

  • 00: full precision training(全精度)
  • 01: conservative mixed-precision where only some ops are going to be done in 16-bit.
  • 02: fast mixed precision and standard used in training, which maintains 32-bit precision weights and the optimizers perform direct updates using 32-bit precision but the data is in 16-bit precision.(最常用的)
  • 03: 16-bit precision training only where everything is using 16-bit precision.

 

 类似资料: