import torch
from torch.nn import functional as F
from torch import nn
from pytorch_lightning.core.lightning import LightningModule
import pytorch_lightning as pl
import os
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import datasets, transforms
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.core.lightning import LightningModule
class MNISTDataModule(pl.LightningDataModule):
    """LightningDataModule serving MNIST train / val / test dataloaders.

    Args:
        batch_size: Number of samples per batch for every dataloader.
        data_dir: Directory where MNIST is downloaded to and loaded from.
        seed: Seed for the train/val split, so every call to ``setup``
            (e.g. once per process or per stage) yields the same partition.
    """

    def __init__(self, batch_size=64, data_dir='./mnist/', seed=42):
        super().__init__()
        self.batch_size = batch_size
        self.data_dir = data_dir
        self.seed = seed

    def prepare_data(self):
        """Download MNIST to ``data_dir`` (download only, no state assigned)."""
        # download=True is required here: setup() loads with download=False
        # and would fail if the files were never fetched.
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Build the datasets used by the dataloaders.

        Args:
            stage: Lightning stage name (e.g. 'fit' / 'test'); unused, all
                splits are built on every call.
        """
        transform = transforms.Compose([transforms.ToTensor()])
        mnist_train = MNIST(self.data_dir, train=True, download=False, transform=transform)
        mnist_test = MNIST(self.data_dir, train=False, download=False, transform=transform)
        # A seeded generator keeps the train/val split identical across calls,
        # preventing validation samples from leaking into training.
        generator = torch.Generator().manual_seed(self.seed)
        mnist_train, mnist_val = random_split(mnist_train, [55000, 5000], generator=generator)
        self.train_dataset = mnist_train
        self.val_dataset = mnist_val
        self.test_dataset = mnist_test

    def train_dataloader(self):
        """Return the training loader (shuffled every epoch)."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Return the validation loader (fixed order)."""
        return DataLoader(self.val_dataset, batch_size=self.batch_size)

    def test_dataloader(self):
        """Return the test loader (fixed order)."""
        return DataLoader(self.test_dataset, batch_size=self.batch_size)
class LitMNIST(LightningModule):
    """Three-layer MLP MNIST classifier that logs images and prints
    ``global_step`` so its progression through training/validation is visible."""

    def __init__(self):
        super().__init__()
        # MNIST images are (1, 28, 28) = (channels, height, width)
        self.layer_1 = torch.nn.Linear(28 * 28, 128)
        self.layer_2 = torch.nn.Linear(128, 256)
        self.layer_3 = torch.nn.Linear(256, 10)

    def forward(self, x):
        """Map a (batch, channels, height, width) tensor to (batch, 10) log-probabilities."""
        batch_size, channels, height, width = x.size()
        # (b, 1, 28, 28) -> (b, 1*28*28); reshape also copes with non-contiguous input
        x = x.reshape(batch_size, -1)
        x = self.layer_1(x)
        x = F.relu(x)
        x = self.layer_2(x)
        x = F.relu(x)
        x = self.layer_3(x)
        x = F.log_softmax(x, dim=1)
        return x

    def training_step(self, batch, batch_idx):
        """Compute NLL loss on one batch and log images/scalars at ``global_step``."""
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        print(self.global_step)
        self.logger.experiment.add_images('train/imgs', x, self.global_step)
        self.log('train/global_step', self.global_step, prog_bar=False)
        self.log('train/loss', loss, prog_bar=False)
        return loss

    def validation_step(self, batch, batch_nb):
        """Compute validation loss; return ``[loss, x]`` for ``validation_epoch_end``.

        ``batch_nb`` is the batch index within the current validation epoch;
        it is used as the image step, while the tag embeds ``global_step``.
        """
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        print(self.global_step)
        self.logger.experiment.add_images(f'val/imgs/{self.global_step}', x, batch_nb)
        self.log('val/global_step', self.global_step, prog_bar=False)
        self.log('val/loss', loss, prog_bar=False)
        return [loss, x]

    def validation_epoch_end(self, outputs):
        """Log all validation images of the epoch as one grid at ``global_step``.

        Args:
            outputs: list of ``[loss, x]`` pairs returned by ``validation_step``.
        """
        if not outputs:
            return
        imgs = []
        for loss, x in outputs:
            imgs.append(x)
            print(x.shape)
        # torch.cat along the batch dim gives (N, C, H, W) in the same order as
        # stack+reshape, but also accepts a smaller final batch.
        imgs = torch.cat(imgs, dim=0)
        self.logger.experiment.add_images('val/imgs', imgs, self.global_step)

    def configure_optimizers(self):
        """Use Adam with a fixed learning rate of 1e-3."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)
if __name__ == "__main__":
    # Demo entry point: train LitMNIST on MNIST while logging via TestTube.
    from pytorch_lightning import Trainer, loggers

    logger = loggers.TestTubeLogger(
        save_dir="verify_pl",
        name="test_demo",
        debug=False,
        create_git_tag=False,
    )
    dm = MNISTDataModule()
    model = LitMNIST()
    trainer = Trainer(
        # Use one GPU when available; fall back to CPU instead of failing at init.
        gpus=1 if torch.cuda.is_available() else None,
        checkpoint_callback=False,
        limit_train_batches=0.05,  # use 5% of training batches per epoch
        limit_val_batches=0.1,  # use 10% of validation batches per epoch
        logger=logger,
        num_sanity_val_steps=3,  # run 3 validation batches before training
        check_val_every_n_epoch=1,
        max_epochs=20,
    )
    trainer.fit(model, dm)
一个基于 MNIST 的训练示例,可用于观察 global_step 的变化。
可以直接运行;首次运行前需要把 `MNIST(...)` 调用中的 `download` 参数设为 `True`,以便下载数据集。
注意training_step、validation_step、validation_epoch_end的区别!
self.global_step 是跨所有 epoch 的全局计数,只在训练阶段(training_step 之后的每次优化器更新)递增,在验证阶段不会变化。
batch_idx(训练)和 batch_nb(验证)是当前 epoch 内的批次序号;训练 epoch 与验证 epoch 各自独立计数,每个 epoch 结束后都会从 0 重新开始。
Trainer 参数说明:num_sanity_val_steps=3 表示在正式训练前先运行 3 个验证 batch,做完整性检查(sanity check)。
check_val_every_n_epoch=1 指定验证频率,这里每训练 1 个 epoch 就进行一次验证(注意是验证,不是测试)。
limit_train_batches=0.05 表示每个 epoch 只使用 5% 的训练 batch。