当前位置: 首页 > 工具软件 > NiftyNet > 使用案例 >

基于NiftyNet训练自己的数据

邰德业
2023-12-01

1. 环境准备

参考官方文档:https://niftynet.readthedocs.io/en/dev/installation.html

# 创建 conda 环境
conda create -n tensorflow-gpu python=3.6
# 安装 tensorflow
pip install tensorflow-gpu==1.10.0
# 安装 niftynet
git clone https://github.com/NifTK/NiftyNet.git
# 安装依赖
cd NiftyNet/
pip install -r requirements-gpu.txt

2. 训练

2.1 数据准备

将训练数据放在 home/niftynet/data/my_test 目录下,数据格式为 nii,训练和测试数据分别打包为一个 nii 文件。

注意文件名中的关键词,在后面的配置文件中需要设置

2.2 修改配置文件

参考官方文档:https://niftynet.readthedocs.io/en/dev/config_spec.html#configuration-sections

~/niftynet/extensions/ 下创建文件夹 mytest,复制 dense_vnet_abdominal_ct 下的配置文件 config.ini 到该文件夹,根据自己的情况进行修改。

配置文件 config.ini 至少需要包含两部分:[SYSTEM][NETWORK],训练和测试阶段分别对应 [TRAINING][INFERENCE],对于分割任务还需要包含 [SEGMENTATION]

以 dense v-net 的配置文件为例:

############################ input configuration sections
# 用于指定训练图像
[ct]
# 训练图像路径
path_to_search = ./data/dense_vnet_abdominal_ct/
# 用于匹配文件名的关键字,在训练时此关键字将被删除
filename_contains = CT
# 输入大小(H,W,img_num)
spatial_window_size = (144, 144, 144)
# 输入数据的插值顺序,1为线性插值
interp_order = 1
# 图像读入内存转换为 A R S 方向
axcodes=(A, R, S)

[label]
# 训练对象对应的mask
path_to_search = ./data/dense_vnet_abdominal_ct/
filename_contains = Label
spatial_window_size = (144, 144, 144)
interp_order = 0
axcodes=(A, R, S)

############################## system configuration sections
[SYSTEM]
# 设置使用的 GPU
cuda_devices = ""
# 预处理线程数
num_threads = 1
# 可用 GPU 数量
num_gpus = 1
# 加载模型的目录,默认是当前路径
model_dir = models/dense_vnet_abdominal_ct
queue_length = 36

[NETWORK]
# 网络名称,设置为 niftynet.network.toynet.ToyNet则会导入 ToyNet类
name = dense_vnet
# batch size 1 for inference
# batch size 6 for training
batch_size = 1

# volume level preprocessing
volume_padding_size = 0
# 将图片 resize 到 window size
window_sampling = resize

[TRAINING]
sample_per_volume = 1
lr = 0.001
loss_type = dense_vnet_abdominal_ct.dice_hinge.dice
starting_iter = 0
# 每训练多少次保存一次模型
save_every_n = 1000
# 最大迭代次数
max_iter = 3001

[INFERENCE]
# 裁剪网络的输出大小
border = (0, 0, 0)
# 指定用于测试的模型
inference_iter = 3000
# 网络输出的插值循序
output_interp_order = 0
spatial_window_size = (144, 144, 144)
# 测试结果保存路径
save_seg_dir = ./segmentation_output/

############################ custom configuration sections
[SEGMENTATION]
image = ct
label = label
label_normalisation = False
output_prob = False
num_classes = 9

2.3 训练

参考官方文档:https://niftynet.readthedocs.io/en/dev/config_spec.html#overview

训练命令:

# command to run from git-cloned NiftyNet source code folder
python net_run.py train -c <path_to/config.ini> -a <application>

其中,<path_to/config.ini> 是配置文件的路径,<application> 内容格式为 user.path.python.module.MyApplication,效果是导入 user/path/python/module.py 文件下的 MyApplication 类。

现成儿可以用的在 ./niftynet/application/ 目录下,这里使用图像分割为例:

python net_run.py train -c <path_to/config.ini> -a niftynet.application.segmentation_application.SegmentationApplication 

举个例子:

python net_run.py train -c /home/tzq-zyy/niftynet/extensions/mytest/config.ini -a niftynet.application.segmentation_application.SegmentationApplication

训练正常的话,终端输出如下:

dice[0.177198395 0.103725933 0.0715316236 0.0127105657 0.00791870896 0.100049123 0.100663938 0.0575987399 0.0468616374]
INFO:niftynet: training iter 1, total_loss=3.3650364875793457, loss=3.3650364875793457 (30.331078s)
dice[0.199683234 0.0769364685 0.128822058 0.0285055339 0.0270173885 0.0912262872 0.0768199712 0.143819273 0.0744452626]
INFO:niftynet: training iter 2, total_loss=2.26554536819458, loss=2.26554536819458 (11.347821s)
dice[0.243362054 0.177675933 0.108794935 0.0613077283 0.050180003 0.0818266049 0.199541628 0.152265966 0.0877473578]
INFO:niftynet: training iter 3, total_loss=1.3663134574890137, loss=1.3663134574890137 (11.688515s)
dice[0.294665605 0.201383099 0.116870329 0.16179 0.102572314 0.168376252 0.15632163 0.154163659 0.0982201174]
INFO:niftynet: training iter 4, total_loss=0.8387561440467834, loss=0.8387561440467834 (10.815102s)
dice[0.334134877 0.209191054 0.114639401 0.173099115 0.101677962 0.242945343 0.135729268 0.146350488 0.101767704]
INFO:niftynet: training iter 5, total_loss=0.8267183303833008, loss=0.8267183303833008 (11.016553s)
...

训练过程中保存的 model 保存在 model_dir/models 下。

 类似资料: