当前位置: 首页 > 工具软件 > lightning > 使用案例 >

pytorch lightning最简上手

孟思远
2023-12-01

pytorch lightning最简上手

pytorch lightning 是对原生 pytorch 的通用模型开发过程进行封装的一个工具库。本文不会介绍它的高级功能,而是通过几个最简单的例子来帮助读者快速理解、上手基本的使用方式。在掌握基础 API 和使用方式之后,读者可自行到 pytorch lightning 的官方文档,了解进阶 API。本文假设读者对原生 pytorch 训练脚本的搭建方法已经比较熟悉。

安装

pytorch lighning 的安装非常简单,直接使用 pip 安装即可:

pip install pytorch-lightning

最简例子

pytorch lightning 有两个最核心的 API:LigtningModuleTrainer

其中 LightningModule 是我们熟悉的 torch.nn.Module 的子类,可以通过

print(isinstance(pl.LightningModule(), torch.nn.Module))

来验证。这意味着该类同样需要实现 forward 方法,并可直接通过实例调用。

Trainer 则是开始执行模型训练、测试过程的类,传入一个 LightningModule 和对应控制参数来实例化即可开始训练。

我们从一个最简单的例子——MNIST 手写数字识别开始:

1 导入必要的库

导入 pytorch_lightning 和 pytorch 常用的库。

import os

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl

2 实现最简LigntningModule

我们先实现一个最简的 LightningModule。

  • __init__

    构造函数中,像常见的 torch.nn.Module 一样,我们定义好模型的层。由于是最简实例,这里只有一层线性层,将手写数字图像映射为输出 logits。

  • forward

    由于是继承自 torch.nn.Module,因此实现 forward 方法是必须的。forward 方法要完成模型的前向过程,这里直接调用 __init__ 中定义好的线性层,完成模型前向过程。

  • train_dataloader

    train_dataloader 方法也是最简实现中必须的,它的功能是获取训练集的 DataLoader。这里我们返回 MNIST 数据集的 DataLoader。dataloader 的获取也可以不在类内实现,而是在 fit 时传入,后面会介绍。

  • training_step

    training_step 是是 LigtningModule 的核心方法,它定义了一个训练步中需要做的事情。在深度学习的训练步中,最核心的事情就是模型前向,得到结果,计算损失,反向传播,更新参数,这几步在 pytorch 中都有对应的方法供调用。但是在 pytorch lightning 中,我们只需要进行模型前向,并返回必要的信息即可。在最简实现中,我们只需返回损失。

  • configure_optimizer

    在 training_step 中,我们只需返回损失,这意味着模型的反向传播和参数更新过程由 pytorch lightning 帮我们完成了。虽然这个过程可以有框架自己完成,但是我们还是要指定参数更新所用的优化器,在很多模型中,优化器、学习率等超参数设置对结果影响很大。在最简实现中,我们设置好学习率,并返回一个 Adam 优化器。

class MNISTModel(pl.LightningModule):

    def __init__(self):
        super(MNISTModel, self).__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))
      
    def train_dataloader(self):
        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)
      
    def training_step(self, batch, batch_nb):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

以上我们实现 training_step,train_dataloader, configure_optimizer,已经是最简单的 LightningModule 的实现了。如果连这三个方法都没有实现的话,将会报错:

 No `xxx` method defined. Lightning `Trainer` expects as minimum a `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined

3 开始训练

在实现好 LightningModule 之后,就可以开始训练了。

启动训练的最简实现非常简单,只需三行:实例化模型、实例化训练器、开始训练!

model = MNISTModel()
trainer = pl.Trainer(gpus=1, max_epochs=2)
trainer.fit(model)

开始训练后,pytorch lightning 会打印出可用设备、模型参数等丰富的信息。

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2]

  | Name | Type   | Params
--------------------------------
0 | l1   | Linear | 7.9 K
--------------------------------
7.9 K     Trainable params
0         Non-trainable params
7.9 K     Total params
0.031     Total estimated model params size (MB)
Epoch 1: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:07<00:00, 261.53it/s, loss=1.3, v_num=10]

总结

以上我们用 30 行左右代码,实现了一个最简的 pytorch lightning 训练过程。这足以体现出 pytorch lightning 的简洁、易用。但是,显然这个最简实现缺少了很多东西,比如验证、测试、日志打印、模型保存等。接下来,我们将实现相对完整但依旧简洁的 pytorch lightning 模型开发过程。

pytorch lightning更多功能

本节将介绍相对更完整的 pytorch lightning 模型开发过程。

LighningModeul需实现方法

在一个相对完整的 LightnintModule 中,用户应当实现以下方法:

1 模型定义 (__init__)

通常定义模型的各个层,在 forward 调用这些层,完成模型前向。与原生 pytorch 类似。

2 前向计算 (forward)

与 torch.nn.Module 的 forward 中做的事情一样,调用 _init_ 中定义的层。完成模型前向。与原生 pytorch 类似。

3 训练/验证/测试步 (training_step/validation_step/test_step)

定义训练/测试/训练每一步中要做的事情,一般是计算损失、指标并返回。

def training_step(self, batch, batch_idx):
    # ....
    return xxx # 如果是training_step, 则必须包含损失

通常有两个入参 batch 和 batch_idx。是 batch 是 dataloader 给出的输入数据和标签,batch_idx 是当前 batch 的索引。

注意训练步的返回值必须是损失值,或者是包含 ‘loss’ 字段的字典。验证/测试步的返回值不必包括损失,可以是任意结果。

4 训练/验证/测试步结束后 (training_step_end/validation_step_end/test_step_end)

只在使用多个node进行训练且结果涉及如softmax之类需要全部输出联合运算的步骤时使用该函数。

5 训练/验证/测试轮结束后 (training_epoch_end/validation_epoch_end/test_epoch_end)

以 training_epoch_end 为例,其他类似。

如果需要对整一轮的结果进行处理,比如计算一些平均指标等,可以通过 training_epoch_end 来实现。

def training_epoch_end(self, outputs):
    # ....
    return xxx

其中入参 outputs 是一个列表,包含了每一步 training_step 返回的内容。我们可以在每一轮结束后,对每一步的结果进行处理。

4 选用优化器 (configure_optimizers)

设置模型参数更新所用的优化器。值得一提的是如果需要多个优化器(比如在训练 GAN 时),可以返回优化器列表。也可以在优化器的基础上返回学习率调整器,那就要返回两个列表。

5 数据加载器 (train_dataloader, val_dataloader, test_dataloader)

返回 dataloader。

各个 dataloader 也可以在运行 fit/validation/test 时传入,如:

train_loader = DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)
model = MNISTModel()		# 不需要实现get_dataloader方法
trainer.fit(model, train_loader)

LightningModule自带工具

LightningModule 中提供了一些常用工具供用户直接使用:

log

Tensorboard 损失/指标日志保存和查看,不要自己定义,直接用即可。用法非常简单,将要记录的值传入:

self.log('train loss', loss)

当然一个功能完整的日志保存接口肯定提供了很多参数来控制,比如是按照 epoch 记录还是按照 step 记录、多卡训练时如何同步、指标是否要展示在进度条上、指标是否要保存在日志文件中等等。pytorch lightning 为这些选项都提供了控制参数,读者可以参考官方文档中 log 相关部分

print

python 自带的 print 函数在进行多进程训练时会在每个进程都打印内容,这是原生 pytorch 进行分布式训练时一个很小但是很头疼的问题。LightningModule 提供的 print 只打印一次。

freeze

冻结所有权重以供预测时候使用。仅当已经训练完成且后面只测试时使用。

Trainer实例化参数

在实例化 Trainer 时,pytorch lightning 也提供了很多控制参数,这里介绍常用的几个,完整参数及含义请参考官方文档中 Trainer 相关部分

  • default_root_dir:默认存储地址。所有的实验变量和权重全部会被存到这个文件夹里面。默认情况下,实验结果会存在 lightning_logs/version_x/
  • max_epochs:最大训练周期数,默认为 1000,如果不设上限 epoch 数,设置为 -1。
  • auto_scale_batch_size:在进行训练前自动选择合适的batch size。
  • auto_select_gpus:自动选择合适的GPU。尤其是在有GPU处于独占模式时候,非常有用。
  • gpus:控制使用的GPU数。当设定为None时,使用 cpu。
  • auto_lr_find:自动找到合适的初始学习率。使用了该论文的技术。当且仅当执行 trainer.tune(model) 代码时工作。
  • precision:浮点数精度。默认 32,即常规单精度 fp32 旬来呢。指定为 16 可以使用 fp16 精度加快模型训练并减少显存占用。
  • val_check_interval:进行验证的周期。默认为 1,如果要训练 10 个 epoch 进行一次验证,设置为 10。
  • fast_dev_run:如果设定为true,会只执行一个 batch 的 train, val 和 test,然后结束。仅用于debug。
  • callbacks:需要调用的 callback 函数列表,关于常用 callback 函数下面会介绍。

callback函数

Callback 是一个自包含的程序,可以与训练流程交织在一起,而不会污染主要的研究逻辑。Callback 并不一定只能在 epoch 结尾调用。pytorch-lightning 提供了数十个hook(接口,调用位置)可供选择,也可以自定义callback,实现任何想实现的模块。

推荐使用方式是,随问题和项目变化的操作,实现到 lightning module里面。而独立的、可复用的内容则可以定义单独的模块,方便多个模型调用。

常见的内建 callback 如:EarlyStopping,根据某个值,在数个epoch没有提升的情况下提前停止训练。。PrintTableMetricsCallback,在每个epoch结束后打印一份结果整理表格等。更多内建 callbacks 可参考相关文档

模型加载与保存

模型保存

ModelCheckpoint 是一个自动储存的 callback 模块。默认情况下训练过程中只会自动储存最新的模型与相关参数,而用户可以通过这个 module 自定义。如观测一个 val_loss 的值,并储存 top 3 好的模型,且同时储存最后一个 epoch 的模型,等等。例:

from pytorch_lightning.callbacks import ModelCheckpoint

# saves a file like: my/path/sample-mnist-epoch=02-val_loss=0.32.ckpt
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    filename='sample-mnist-{epoch:02d}-{val_loss:.2f}',
    save_top_k=3,
    mode='min',
    save_last=True
)

trainer = pl.Trainer(gpus=1, max_epochs=3, callbacks=[checkpoint_callback])

ModelCheckpoint Callback中,如果 save_weights_only=True,那么将会只储存模型的权重,相当于 model.save_weights(filepath),反之会储存整个模型(包括模型结构),相当于model.save(filepath))。

另外,也可以手动存储checkpoint: trainer.save_checkpoint("example.ckpt")

模型加载

加载一个模型,包括它的模型权重和超参数:

model = MyLightingModule.load_from_checkpoint(PATH)

print(model.learning_rate)
# 打印出超参数

model.eval()
y_hat = model(x)

加载模型时替换一些超参数:

class LitModel(LightningModule):
    def __init__(self, in_dim, out_dim):
      super().__init__()
      self.save_hyperparameters()
      self.l1 = nn.Linear(self.hparams.in_dim, self.hparams.out_dim)

# 如果在训练和保存模型时,超参数设置如下,在加载后可以替换这些超参数。
LitModel(in_dim=32, out_dim=10)

# 仍然使用in_dim=32, out_dim=10
model = LitModel.load_from_checkpoint(PATH)

# 替换为in_dim=128, out_dim=10
model = LitModel.load_from_checkpoint(PATH, in_dim=128, out_dim=10)

完整加载训练状态,包括模型的一切,以及和训练相关的一切参数,如 model, epoch, step, LR schedulers, apex 等。

model = LitModel()
trainer = Trainer(resume_from_checkpoint='some/path/to/my_checkpoint.ckpt')

# 自动恢复 model, epoch, step, LR schedulers, apex, etc...
trainer.fit(model)

实例

基于第三节介绍的更多功能,我们扩展第二节 MNIST 训练程序。代码如下。

import os

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl
import numpy as np


class MNISTModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.fc(x.view(-1, 28 * 28)))

    def training_step(self, batch, batch_nb):
        # REQUIRED
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('train_loss', loss, on_step=False, on_epoch=True)
        return {'loss': loss}

    def validation_step(self, batch, batch_nb):
        # OPTIONAL
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        pred = y_hat.argmax(dim=1, keepdim=True)
        correct = pred.eq(y.view_as(pred)).sum().item()
        acc = correct / x.shape[0]
        self.log('val_acc', acc, on_step=False, on_epoch=True)
        self.log('val_loss', loss, on_step=False, on_epoch=True)
        return {'val_loss': loss, 'val_acc': acc}

    def validation_epoch_end(self, outputs):
        # OPTIONAL
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        avg_acc = np.mean([x['val_acc'] for x in outputs])
        return {'val_loss': avg_loss, 'val_acc': avg_acc}

    def test_step(self, batch, batch_nb):
        # OPTIONAL
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        return {'test_loss': loss}

    def test_epoch_end(self, outputs):
        # OPTIONAL
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        return {'test_loss': avg_loss}

    def configure_optimizers(self):
        # REQUIRED
        return torch.optim.Adam(self.parameters(), lr=0.02)

    def train_dataloader(self):
        # REQUIRED
        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)

    def val_dataloader(self):
        # OPTIONAL
        return DataLoader(MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()), batch_size=32)

    def test_dataloader(self):
        # OPTIONAL
        return DataLoader(MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()), batch_size=32)

model = MNISTModel()
trainer = pl.Trainer(
        gpus=1,
        max_epochs=10,
        callbacks=[
            pl.callbacks.EarlyStopping( monitor="val_loss", patience=3),
        ]
)
trainer.fit(model)
trainer.test()

Ref

 类似资料: