MMSegmentation使用心得(二)——分布式训练

朱自明
2023-12-01

在资源允许的情况下,很多小伙伴想要使用MMSegmentation进行分布式训练,下面我们就来讲解一下如何进行分布式训练。

**MMSegmentation 的多卡训练不支持 DataParallel(非分布式模式下只能单卡训练),多卡训练只能通过命令行调用自带的 tools/dist_train.sh 以分布式方式进行**。另外,很多小伙伴在使用 MMDataParallel 时可能会报错,大家可以参考一下我的方法。

如果要在 Linux 服务器上进行分布式训练,首先需要给 tools/dist_train.sh 添加可执行权限;如果这个脚本在 Windows 下编辑或传输过,还要把它的换行符转换成 Unix 格式(LF),否则执行时会因为多余的 `\r` 字符而报错:

# Give the launcher execute permission. +x is all that is needed; 777 would
# additionally make the script writable by every user on the server.
chmod +x ./mmsegmentation/tools/dist_train.sh
# Convert Windows (CRLF) line endings to Unix (LF) in place, so bash does not
# choke on trailing '\r' characters. This has the same effect as opening the
# file in vi and running the ex commands ":set ff=unix" and then ":wq".
sed -i 's/\r$//' ./mmsegmentation/tools/dist_train.sh

接下来,可以使用以下命令进行分布式训练,其中第一个参数是配置文件路径,第二个参数是使用的 GPU 数量(这里是 4 张卡)。博主使用的是自己编写的配置文件(例如下面的 Swin 配置),大家可以根据需要自行修改。

# Launch distributed training on 4 GPUs in the background. nohup keeps the job
# alive after the SSH session closes; stdout and stderr both go to hehe.log.
nohup ./mmsegmentation/tools/dist_train.sh \
    ./mine/myconfig_swin.py 4 \
    > hehe.log 2>&1 &

(使用 nohup 并在命令末尾加上 &,训练会在后台运行,即使 SSH 断开也不会中断,不需要一直挂着本地连接;所有输出都会重定向到 hehe.log,可以用 tail -f hehe.log 查看训练进度。使用远程服务器时都推荐这样做。)

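下面是博主自定义的 config 文件,供大家参考。注意其中的数据路径、work_dir、load_from/resume_from 等都是博主机器上的路径,使用前需要换成自己的。大家在进行单卡训练时也可以参考这个配置文件来尝试 Swin(单卡训练可以直接用 python tools/train.py 配置文件路径 来启动)。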

# ---------------------------------------------------------------------------
# Model: UPerNet decode head + FCN auxiliary head on a Swin-B backbone.
# ---------------------------------------------------------------------------
# BatchNorm config shared by both heads. SyncBN synchronizes BN statistics
# across the GPUs of a distributed run (tools/dist_train.sh); for a plain
# single-GPU run switch this to dict(type='BN', requires_grad=True).
# NOTE(review): recent mmseg 0.x tools/train.py also reverts SyncBN to BN
# automatically when not distributed - confirm for your version.
norm_cfg = dict(type='SyncBN', requires_grad=True)
# LayerNorm config used inside the Swin Transformer backbone.
backbone_norm_cfg = dict(type='LN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    # With ``pretrained`` commented out, the backbone is not initialized from
    # a pretrained checkpoint here; weights can instead come from
    # ``load_from`` / ``resume_from`` in the runtime settings.
    # pretrained='pretrain/swin_base_patch4_window12_384_22k.pth',
    backbone=dict(
        type='SwinTransformer',
        pretrain_img_size=384,
        embed_dims=128,
        patch_size=4,
        window_size=12,
        mlp_ratio=4,
        depths=[2, 2, 18, 2],
        num_heads=[4, 8, 16, 32],
        strides=(4, 2, 2, 2),
        out_indices=(0, 1, 2, 3),
        qkv_bias=True,
        qk_scale=None,
        patch_norm=True,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.3,
        use_abs_pos_embed=False,
        act_cfg=dict(type='GELU'),
        norm_cfg=backbone_norm_cfg),
    decode_head=dict(
        type='UPerHead',
        # Channels of the four backbone stages (embed_dims * 1, 2, 4, 8).
        in_channels=[128, 256, 512, 1024],
        in_index=[0, 1, 2, 3],
        pool_scales=(1, 2, 3, 6),
        channels=512,
        dropout_ratio=0.1,
        num_classes=2,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
        # Alternatives: CE + Dice loss, and/or OHEM pixel sampling:
        # loss_decode=[
        #     dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0),
        #     dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0)],
        # sampler=dict(type='OHEMPixelSampler', thresh=0.98),
    ),
    auxiliary_head=dict(
        type='FCNHead',
        # Taps the third backbone stage (in_index=2, 512 channels).
        in_channels=512,
        in_index=2,
        channels=256,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        num_classes=2,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4),
        # Alternative: CE + Dice loss:
        # loss_decode=[
        #     dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0),
        #     dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0)],
    ),
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))
# ---------------------------------------------------------------------------
# Dataset settings.
# ---------------------------------------------------------------------------
# Custom dataset class; it must be registered in mmseg's DATASETS registry.
dataset_type = 'myDataset'
# Root of the dataset. The img_dir / ann_dir / split entries below are
# relative: mmseg's CustomDataset joins relative paths onto data_root.
data_root = '/media/home/data/Extractbuildingdata'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (384, 384)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    # ``prob`` replaces the deprecated ``flip_ratio`` argument in mmseg 0.x.
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='Normalize', **img_norm_cfg),
    # Pad labels with 255 (the ignore index) so padding is not scored.
    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(384, 384),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ]),
]
data = dict(
    # Per-GPU batch size; with dist_train.sh every GPU loads this many.
    samples_per_gpu=4,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='data/train/image',
        ann_dir='data/train/label',
        split='data/train/train.txt',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='data/val/image',
        ann_dir='data/val/label',
        split='data/val/val.txt',
        pipeline=test_pipeline),
    # The test split reuses the validation data.
    test=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='data/val/image',
        ann_dir='data/val/label',
        split='data/val/val.txt',
        pipeline=test_pipeline))
# ---------------------------------------------------------------------------
# Runtime / schedule settings.
# ---------------------------------------------------------------------------
log_config = dict(
    interval=5000, hooks=[dict(type='TextLoggerHook', by_epoch=False)])
dist_params = dict(backend='nccl')
log_level = 'INFO'
# Weights to initialize a fresh run from, e.g. the official UPerNet-Swin
# ADE20K checkpoint:
# load_from = '/media/home/data/Extractbuildingdata/upernet_swin_base_patch4_window12_512x512_160k_ade20k_pretrain_384x384_22K_20210531_125459-429057bf.pth'
load_from = None
# Checkpoint to resume an interrupted run from. Keep None for a first run:
# pointing at a checkpoint that does not exist yet makes training fail at
# startup. To continue the author's run, use e.g.:
# resume_from = '/media/home/data/Extractbuildingdata/work/tutorial/iter_20000.pth'
resume_from = None
workflow = [('train', 1)]
cudnn_benchmark = True
optimizer = dict(
    type='AdamW',
    lr=6e-05,
    betas=(0.9, 0.999),
    weight_decay=0.01,
    # No weight decay for parameters whose names match these keys
    # (position embeddings / relative-position bias tables / norm layers).
    paramwise_cfg=dict(
        custom_keys=dict(
            absolute_pos_embed=dict(decay_mult=0.0),
            relative_position_bias_table=dict(decay_mult=0.0),
            norm=dict(decay_mult=0.0))))
optimizer_config = dict()
# 1500-iteration linear warmup, then poly (power=1.0, i.e. linear) decay
# towards min_lr, stepped per iteration.
lr_config = dict(
    policy='poly',
    warmup='linear',
    warmup_iters=1500,
    warmup_ratio=1e-06,
    power=1.0,
    min_lr=0.0,
    by_epoch=False)
runner = dict(type='IterBasedRunner', max_iters=120000)
# Checkpoints (iter_10000.pth, iter_20000.pth, ...) are written to work_dir.
checkpoint_config = dict(by_epoch=False, interval=10000)
evaluation = dict(interval=10000, metric='mIoU', pre_eval=True)
work_dir = '/media/home/data/Extractbuildingdata/work/tutorial'
seed = 42
# NOTE(review): the training script may override gpu_ids for distributed
# launches - confirm for your mmseg version.
gpu_ids = range(0, 4)

我们也可以通过环境变量 CUDA_VISIBLE_DEVICES 手动选择所需的 GPU,此时命令最后的 GPU 数量要与可见的 GPU 数一致(下面的例子使用 2 号和 3 号卡,共 2 张):

# Expose only physical GPUs 2 and 3 to this job; the trailing "2" (number of
# GPUs to launch on) must equal the number of visible devices.
# NOTE(review): the stock mmseg tools/dist_train.sh reads PORT (default
# 29500); if another distributed job already runs on this machine, prefix a
# different one, e.g. PORT=29501 - confirm for your version.
CUDA_VISIBLE_DEVICES=2,3 \
    ./mmsegmentation/tools/dist_train.sh ./mine/myconfig_biet.py 2

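对于数据集,我们可以定义自己的数据集类型(即配置文件中的 type='myDataset'):把下面的代码保存为一个 .py 文件(例如 my_dataset.py),放到 mmsegmentation/mmseg/datasets/ 目录下,并在同目录的 __init__.py 中导入它(例如 from .my_dataset import myDataset,并把 'myDataset' 加入 __all__)。如果不导入,注册装饰器就不会执行,训练时会报 myDataset 未注册的错误;如果 mmseg 不是用 pip install -e . 方式安装的,修改源码后还需要重新安装。(我是直接把下面的代码保存为一个 .py 文件放到源代码的文件夹下面的)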

import os.path as osp

from .builder import DATASETS
from .custom import CustomDataset

# Class names in label-index order (label 0 -> first entry).
classes = ('building', 'background')
# Display colour per class, same order as ``classes``.
# NOTE(review): building is drawn black and background white; confirm this
# matches how the label .tif files encode the two classes.
palette = [[0, 0, 0], [255, 255, 255]]


@DATASETS.register_module()
class myDataset(CustomDataset):
    """Binary building-extraction dataset ('building' vs 'background').

    Both images and annotation maps use the ``.tif`` suffix, and the samples
    are taken from the ``split`` file. The lower-case class name is kept
    because configs refer to it as ``type='myDataset'``.

    Args:
        split (str): Split file listing the sample names to use.
        **kwargs: Forwarded to ``CustomDataset`` (``img_dir``, ``ann_dir``,
            ``data_root``, ``pipeline``, ...).

    Raises:
        FileNotFoundError: If the resolved ``img_dir`` does not exist.
        ValueError: If ``split`` is None.
    """

    CLASSES = classes
    PALETTE = palette

    def __init__(self, split, **kwargs):
        super().__init__(img_suffix='.tif', seg_map_suffix='.tif',
                         split=split, **kwargs)
        # Explicit checks instead of ``assert``: assertions are stripped when
        # Python runs with -O, which would silently skip this validation.
        if not osp.exists(self.img_dir):
            raise FileNotFoundError(f'img_dir does not exist: {self.img_dir}')
        if self.split is None:
            raise ValueError('a split file is required')

训练结束后,就可以用下面的代码加载训练得到的权重(保存在配置文件的 work_dir 目录中)进行预测了:

from mmseg.apis import init_segmentor

# ``cfg`` is the config used for training: a file path such as
# './mine/myconfig_swin.py' or an already-loaded mmcv Config object.
# Checkpoints are saved under the config's work_dir
# (/media/home/data/Extractbuildingdata/work/tutorial).
model = init_segmentor(
    cfg,
    '/media/home/data/Extractbuildingdata/work/tutorial/iter_90000.pth',
    device='cuda:1')
# Predict a single image with mmseg.apis.inference_segmentor, e.g.:
# result = inference_segmentor(model, 'path/to/image.tif')

大家快去试试吧,有什么问题请在评论区留言

 类似资料: