pytorch-lightning的简单入门

云隐水
2023-12-01
import torch
from torch.nn import functional as F
from torch import nn
from pytorch_lightning.core.lightning import LightningModule
import pytorch_lightning as pl

import os
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import datasets, transforms
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.core.lightning import LightningModule


class MNISTDataModule(pl.LightningDataModule):
    """LightningDataModule serving MNIST train / validation / test splits.

    The MNIST training set is split into 55,000 training and 5,000 validation
    images; the official MNIST test set is used as the test split.

    Args:
        batch_size: Batch size used by every dataloader.
        data_dir: Directory holding (or receiving) the MNIST files.
        download: If True, ``prepare_data`` lets torchvision download MNIST
            into ``data_dir``. Defaults to False (previous behavior), in which
            case the data must already exist on disk.
        seed: Seed for the train/val ``random_split`` so the split is
            reproducible across runs and repeated ``setup`` calls.
    """

    def __init__(self, batch_size=64, data_dir='./mnist/', download=False, seed=42):
        super().__init__()
        self.batch_size = batch_size
        self.data_dir = data_dir
        self.download = download
        self.seed = seed

    def prepare_data(self):
        """Fetch MNIST to disk (only when ``download`` is True); builds no splits."""
        # This hook is the download step; the original hard-coded
        # download=False here, so it could never actually download.
        MNIST(self.data_dir, train=True, download=self.download)
        MNIST(self.data_dir, train=False, download=self.download)

    def setup(self, stage=None):
        """Build the train/val/test datasets from the files on disk.

        Args:
            stage: Lightning stage name; unused, all splits are always built.
        """
        transform = transforms.Compose([transforms.ToTensor()])
        mnist_full = MNIST(self.data_dir, train=True, download=False, transform=transform)
        self.test_dataset = MNIST(self.data_dir, train=False, download=False, transform=transform)

        # Seeded split: an unseeded one produces a different train/val
        # partition on every run, so "validation" images may have been
        # trained on in an earlier run.
        generator = torch.Generator().manual_seed(self.seed)
        self.train_dataset, self.val_dataset = random_split(
            mnist_full, [55000, 5000], generator=generator
        )

    def train_dataloader(self):
        # Shuffle so each epoch (and each limit_train_batches subset) differs.
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size)


class LitMNIST(LightningModule):
    """Three-layer MLP MNIST classifier instrumented to show ``global_step``.

    ``training_step`` and ``validation_step`` print ``self.global_step`` and
    log the batch images through ``self.logger.experiment.add_images``, so the
    way the counter moves during training versus validation can be observed.
    """

    def __init__(self):
        super().__init__()

        # MNIST images are (1, 28, 28): (channels, height, width)
        self.layer_1 = torch.nn.Linear(28 * 28, 128)
        self.layer_2 = torch.nn.Linear(128, 256)
        self.layer_3 = torch.nn.Linear(256, 10)

    def forward(self, x):
        """Map images of shape (b, 1, 28, 28) to (b, 10) class log-probabilities."""
        # Unpacking also asserts that x is 4-D.
        batch_size, _channels, _height, _width = x.size()

        # (b, 1, 28, 28) -> (b, 1*28*28)
        x = x.view(batch_size, -1)
        x = F.relu(self.layer_1(x))
        x = F.relu(self.layer_2(x))
        x = self.layer_3(x)

        # log-probabilities, to be consumed by F.nll_loss in the step methods
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        """Compute the NLL loss of one training batch and log step diagnostics."""
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        # Demo output: watch global_step advance during training.
        print(self.global_step)
        self.logger.experiment.add_images('train/imgs', x, self.global_step)
        self.log('train/global_step', self.global_step, prog_bar=False)
        self.log('train/loss', loss, prog_bar=False)
        return loss

    def validation_step(self, batch, batch_nb):
        """Compute the validation loss and return ``[loss, images]`` for epoch end."""
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        # Demo output: compare with training to see global_step during validation.
        print(self.global_step)
        # One tag per global_step; batch_nb is the step within this validation run.
        self.logger.experiment.add_images(f'val/imgs/{self.global_step}', x, batch_nb)
        self.log('val/global_step', self.global_step, prog_bar=False)
        self.log('val/loss', loss, prog_bar=False)
        return [loss, x]

    def validation_epoch_end(self, outputs):
        """Log all images seen in this validation run as one image batch.

        Args:
            outputs: list of the ``[loss, x]`` pairs returned by
                ``validation_step``.
        """
        imgs = [x for _loss, x in outputs]
        if not imgs:
            # No validation batches ran; concatenating an empty list would raise.
            return

        # torch.cat accepts a smaller final batch (torch.stack required equal
        # batch sizes) and directly yields the 4-D (N, 1, 28, 28) image batch.
        imgs = torch.cat(imgs, dim=0)
        self.logger.experiment.add_images('val/imgs', imgs, self.global_step)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

if __name__ == "__main__":
    from pytorch_lightning import Trainer, loggers

    # TensorBoardLogger's `experiment` is a SummaryWriter, which provides the
    # add_images() calls LitMNIST makes. It replaces TestTubeLogger, which
    # needs the separate (deprecated) test_tube package.
    logger = loggers.TensorBoardLogger(
        save_dir="verify_pl",
        name="test_demo",
    )

    dm = MNISTDataModule()
    model = LitMNIST()
    trainer = Trainer(
        # Fall back to CPU instead of crashing on machines without CUDA.
        gpus=1 if torch.cuda.is_available() else None,
        checkpoint_callback=False,
        limit_train_batches=0.05,  # use 5% of the training batches per epoch
        limit_val_batches=0.1,     # use 10% of the validation batches per run
        logger=logger,
        num_sanity_val_steps=3,    # run 3 validation batches before training
        check_val_every_n_epoch=1,
        max_epochs=20,
    )
    trainer.fit(model, dm)

基于 MNIST 的一个训练示例,运行后可以观察 `global_step` 的变化。

代码可以直接运行,但首次运行前需要把 `MNIST(...)` 调用中的 `download` 参数设为 `True`,以便自动下载数据集。

注意training_step、validation_step、validation_epoch_end的区别!

`self.global_step` 是跨越所有 epoch 累计的全局步数,只在训练阶段(`training_step` 所在的训练循环中)递增,在 validation 阶段不会更新。

`batch_idx`、`batch_nb` 则是当前 epoch 内的批次序号,训练 epoch 和验证 epoch 各自独立计数,每个 epoch 结束后都会清零。

`Trainer` 的参数中,`num_sanity_val_steps=3` 表示在正式训练前先跑 3 个验证 batch,做一次健全性检查(sanity check)。

`check_val_every_n_epoch=1` 控制验证的频率,这里表示每训练 1 个 epoch 就做一次验证(注意这是 validation,而不是 test)。

`limit_train_batches=0.05` 表示每个 epoch 只使用 5% 的训练数据;同理,`limit_val_batches=0.1` 表示每次验证只使用 10% 的验证数据。

 类似资料: