
MONAI learning notes

任小云
2023-12-01

MetaTensor

This is a type specific to MONAI:
MetaTensor = Tensor + MetaObj
It is still the original tensor, but carrying metadata such as the affine, spacing, direction, and origin.

# a: MetaTensor
a.meta['spatial_shape']
a.meta['affine']
a.meta['original_affine']
a.affine

When two MetaTensors are added and is_batch is False, the metadata of the first operand is copied to the result.

Data augmentation

Transforms generally live under monai.transforms.
They come in two variants: one transforms a single image, the other transforms a group of images (e.g. image + label); the latter carries a d suffix on the class name.
After a transform, the output type becomes MetaTensor.

Transforms with a d suffix take a dict and return a dict.

During training, you can consider

data['image'].as_tensor()

to convert back to a plain tensor and reduce computational overhead.

Loading images

class monai.transforms.LoadImaged(keys, reader=None, dtype=<class 'numpy.float32'>, 
meta_keys=None, meta_key_postfix='meta_dict', overwriting=False, image_only=False, 
ensure_channel_first=False, simple_keys=False, prune_meta_pattern=None, prune_meta_sep='.', 
allow_missing_keys=False, *args, **kwargs)

For example

trans=LoadImaged(['flair', 't1', 't1ce', 't2', 'seg'], image_only=True, allow_missing_keys=True)
trans({'flair':'/mnt/data/xxx', 't1':'/mnt/data/xxxx'})

Here image_only means only the image is returned; if False, an extra {key}_meta_dict entry is added to the returned dict for each key.
allow_missing_keys means missing keys will not raise an error.
The default readers, chosen by file type, are:
(nii, nii.gz -> NibabelReader), (png, jpg, bmp -> PILReader), (npz, npy -> NumpyReader), (dcm, DICOM series and others -> ITKReader).

Type casting

class monai.transforms.CastToTyped(keys, dtype=<class 'numpy.float32'>, 
allow_missing_keys=False)

For example

CastToTyped(keys=['seg'], dtype=torch.long, allow_missing_keys=True)

Adding a channel dimension

class monai.transforms.EnsureChannelFirstd(keys, meta_keys=None, meta_key_postfix='meta_dict', strict_check=True, allow_missing_keys=False, channel_dim=None)

Random rotation

class monai.transforms.RandRotated(keys, range_x=0.0, range_y=0.0, range_z=0.0, prob=0.1, 
keep_size=True, mode=GridSampleMode.BILINEAR, padding_mode=GridSamplePadMode.BORDER, 
align_corners=False, dtype=<class 'numpy.float32'>, allow_missing_keys=False)

For example

RandRotated(keys=['image', 'seg'], range_x=10, range_y=10, range_z=10, allow_missing_keys=True),  # (-10, 10)

Note that by default only floating-point data can be rotated, and the result comes back as float, so consider chaining a type cast afterwards. Also note that range_x/range_y/range_z are given in radians.

Flipping

class monai.transforms.RandFlipd(keys, prob=0.1, spatial_axis=None, allow_missing_keys=False)

Note that this transform flips all the specified axes together, with a single random draw.
For independent per-axis flips, consider writing one instance per axis:

RandFlipd(keys=["image", "seg"], prob=0.5, spatial_axis=0, allow_missing_keys=True),
RandFlipd(keys=["image", "seg"], prob=0.5, spatial_axis=1, allow_missing_keys=True),
RandFlipd(keys=["image", "seg"], prob=0.5, spatial_axis=2, allow_missing_keys=True)

Padding

class monai.transforms.Padd(keys, padder, mode=PytorchPadMode.CONSTANT, allow_missing_keys=False)

This one seems to require a channel dimension, and the padding spec must match the channel-first shape (one (before, after) pair per dimension, channel included).

Clipping

I did not find a dedicated clip transform, but here is a workaround:

# clip(-325, 325)
transforms.ThresholdIntensityd(keys=['image'], threshold=-325, above=True, cval=-325),  # max(img, -325)
transforms.ThresholdIntensityd(keys=['image'], threshold=325, above=False, cval=325),  # min(img, 325)

z-score

class monai.transforms.NormalizeIntensityd(keys, subtrahend=None, divisor=None, nonzero=False, channel_wise=False, dtype=<class 'numpy.float32'>, allow_missing_keys=False)

For example

transforms.NormalizeIntensityd(keys=['image']),  # z-score

Percentile

import torch

from monai.config import KeysCollection
from monai.transforms import MapTransform

class Percentile(MapTransform):
    def __init__(
            self,
            keys: KeysCollection,
            lower_percentile: float = 0.,
            upper_percentile: float = 100.,
            allow_missing_keys: bool = False,
    ) -> None:
        """
        Args:
            keys: keys of the corresponding items to be transformed.
                See also: :py:class:`monai.transforms.compose.MapTransform`
            lower_percentile: lower percentile(0-100)
            upper_percentile: upper percentile(0-100)
            allow_missing_keys: don't raise exception if key is missing.

        """
        MapTransform.__init__(self, keys, allow_missing_keys)
        self.lower_percentile = lower_percentile / 100.
        self.upper_percentile = upper_percentile / 100.

    def __call__(self, data):
        # the first dim of data should be the channel(CHW[D])
        d = dict(data)
        for key in self.key_iterator(d):
            images = data[key]

            lower = torch.quantile(images, self.lower_percentile)
            upper = torch.quantile(images, self.upper_percentile)
            images = torch.clip(images, lower, upper)
            d[key] = images
        return d

A complete example

#!/usr/bin/env python
# _*_ coding:utf-8 _*_
import numpy as np
import torch
from monai.config import KeysCollection
from monai.data import NibabelWriter
from monai.transforms import MapTransform, Compose, LoadImaged, CastToTyped, EnsureChannelFirstd, RandSpatialCropd, \
    RandRotated, RandScaleIntensityd, RandShiftIntensityd, RandFlipd, Pad, Padd

data = {
    'flair': '/mnt/data/datasets/BraTS_2018/MICCAI_BraTS_2018_Data_Training/HGG/Brats18_2013_2_1/Brats18_2013_2_1_flair.nii.gz',
    'seg': '/mnt/data/datasets/BraTS_2018/MICCAI_BraTS_2018_Data_Training/HGG/Brats18_2013_2_1/Brats18_2013_2_1_seg.nii.gz',
    't1': '/mnt/data/datasets/BraTS_2018/MICCAI_BraTS_2018_Data_Training/HGG/Brats18_2013_2_1/Brats18_2013_2_1_t1.nii.gz',
    't1ce': '/mnt/data/datasets/BraTS_2018/MICCAI_BraTS_2018_Data_Training/HGG/Brats18_2013_2_1/Brats18_2013_2_1_t1ce.nii.gz',
    't2': '/mnt/data/datasets/BraTS_2018/MICCAI_BraTS_2018_Data_Training/HGG/Brats18_2013_2_1/Brats18_2013_2_1_t2.nii.gz'
}



class StackImagesd(MapTransform):
    """
    stack images
    add the result in the dict with the key 'image'
    """

    def __call__(self, data):
        d = dict(data)
        result = []
        for key in self.key_iterator(d):
            result.append(d[key])
        d['image'] = torch.stack(result, dim=0)  # (H, W, D)->(4, H, W, D)
        return d


class PercentileAndZScored(MapTransform):
    def __init__(
            self,
            keys: KeysCollection,
            lower_percentile: float = 0.,
            upper_percentile: float = 100.,
            allow_missing_keys: bool = False,
    ) -> None:
        """
        Args:
            keys: keys of the corresponding items to be transformed.
                See also: :py:class:`monai.transforms.compose.MapTransform`
            lower_percentile: lower percentile(0-100)
            upper_percentile: upper percentile(0-100)
            allow_missing_keys: don't raise exception if key is missing.

        """
        MapTransform.__init__(self, keys, allow_missing_keys)
        self.lower_percentile = lower_percentile / 100.
        self.upper_percentile = upper_percentile / 100.

    def __call__(self, data):
        # the first dim of data should be the channel(CHW[D])
        d = dict(data)
        for key in self.key_iterator(d):
            images = data[key]
            C = images.size(0)
            mask = images.sum(0) > 0  # brain
            for k in range(C):
                x = images[k, ...]
                y = x[mask]
                lower = torch.quantile(y, self.lower_percentile)
                upper = torch.quantile(y, self.upper_percentile)

                x[mask & (x < lower)] = lower
                x[mask & (x > upper)] = upper

                # z-score across the brain
                y = x[mask]
                x -= y.mean()
                x /= y.std()

                images[k, ...] = x
            d[key] = images
        return d


if __name__ == '__main__':
    transform = Compose([
        LoadImaged(['flair', 't1', 't1ce', 't2', 'seg'], image_only=True, allow_missing_keys=True),

        CastToTyped(keys=['seg'], dtype=torch.long, allow_missing_keys=True),
        EnsureChannelFirstd(keys=['seg'], allow_missing_keys=True),

        StackImagesd(keys=['flair', 't1', 't1ce', 't2']),  # add ['image']
        PercentileAndZScored(keys=['image'], lower_percentile=0.2, upper_percentile=99.8),

        RandSpatialCropd(keys=['image', 'seg'], roi_size=(128, 128, 128), random_size=False, allow_missing_keys=True),
        RandRotated(keys=['image', 'seg'], range_x=10, range_y=10, range_z=10, allow_missing_keys=True),  # (-10, 10)
        CastToTyped(keys=['seg'], dtype=torch.long, allow_missing_keys=True),
        RandScaleIntensityd(keys=['image'], factors=0.1),  # (-0.1, 0.1), img * (1 + scale)
        RandShiftIntensityd(keys=['image'], offsets=0.1),  # (-0.1, 0.1), img + offset
        RandFlipd(keys=["image", "seg"], prob=0.5, spatial_axis=0, allow_missing_keys=True),
        RandFlipd(keys=["image", "seg"], prob=0.5, spatial_axis=1, allow_missing_keys=True),
        RandFlipd(keys=["image", "seg"], prob=0.5, spatial_axis=2, allow_missing_keys=True),
        CastToTyped(keys=['seg'], dtype=torch.long, allow_missing_keys=True),
        Padd(keys=["image", "seg"], padder=Pad([(0, 0), (0, 0), (0, 0), (0, 5)]), allow_missing_keys=True),
    ])
    result = transform(data)
    print(result['seg'].shape)

Saving

nii

By default the writer resamples; resampling is driven by the metadata's spatial_shape, affine, and original_affine.

writer = NibabelWriter(output_dtype=torch.uint8)
writer.set_data_array(mask.squeeze(0), channel_dim=0) # (C, H, W, D)
writer.set_metadata(mask.meta, resample=True, mode='nearest')
writer.write('faQ.nii.gz', verbose=True)

Without resampling

The snippet below writes the given affine directly into the output, ignoring the affine currently attached to the data:

writer = NibabelWriter(output_dtype=torch.uint8)
writer.set_data_array(mask, channel_dim=0)  # (C, H, W, D)
writer.set_metadata({
    'spatial_shape': result['image'].meta['spatial_shape'],
    'affine': result['image'].meta['original_affine'],
    'original_affine': result['image'].meta['original_affine']
}, resample=False, mode='nearest')

writer.write('faQ.nii.gz')

Loss

They all live in monai.losses.
Common ones: DiceLoss, FocalLoss, GeneralizedDiceLoss.

Metrics

Note that many metrics expect their inputs in one-hot form.

See the example below:

import argparse
import os
from glob import glob

import torch
from monai import transforms
from monai.metrics import DiceMetric
from monai.networks import one_hot
from monai.utils import MetricReduction
from tqdm.contrib import tzip

def get_data(path):
    # assumed helper: collect sorted NIfTI files from a directory
    return sorted(glob(os.path.join(path, '*.nii.gz')))


parser = argparse.ArgumentParser()
parser.add_argument('--predict_path', required=True)
parser.add_argument('--gt_path', required=True)
args = parser.parse_args()

metric = DiceMetric(include_background=False)
loader = transforms.LoadImage(image_only=True, dtype=torch.uint8)
predict_paths = get_data(args.predict_path)
gt_paths = get_data(args.gt_path)
CLASS_NUMBER = 5

for predict_path, gt_path in tzip(predict_paths, gt_paths):
    predict = loader(predict_path).unsqueeze(0)
    gt = loader(gt_path).unsqueeze(0)
    predict = one_hot(predict, CLASS_NUMBER, dtype=torch.uint8, dim=0).unsqueeze(0)
    gt = one_hot(gt, CLASS_NUMBER, dtype=torch.uint8, dim=0).unsqueeze(0)
    dice = metric(predict, gt)  # (B, C-1) = (B, 4) since include_background=False

print('mean_dice: liver={:.2f}, kidney={:.2f}, spleen={:.2f}, pancreas={:.2f}'.format(
    *metric.aggregate(MetricReduction.MEAN_BATCH)))
print('mean_dice: {:.2f}'.format(metric.aggregate(MetricReduction.MEAN).item()))