1. NiftyNet项目概述
NiftyNet项目对tensorflow进行了比较好的封装,实现了一整套的DeepLearning流程。将数据加载、模型加载,网络结构定义等进行了很好的分离,抽象封装成了各自独立的模块。虽然抽象的概念比较多,使得整个项目更为复杂,但是整体结构清晰,支持的模块多。可扩展性还没有进行试验,暂时不是很清楚。 该项目能够实现:
- 图像分割
- 图像分类
- gan
- Autoencoder
- 回归
项目支持医学图像的读取,提供的读取器有:
- nibabel 支持.nii医学文件格式
- simpleitk 支持.dcm和.mhd格式的医疗图像
- opencv 支持.jpg等常见图像读取,读取后通道顺序为BGR
- skimage 支持.jpg等常见图像读取
- pillow 支持.jpg等常见图像读取
在使用中遇到了一些问题,其训练的速度非常慢。最开始单个iter的平均训练时间估计在40秒以上,有的iter时间会有200秒。现在主要在查找性能瓶颈。
一、 项目结构
niftynet.engine.application_driver(ApplicationDriver)定义并驱动着整个Application的生命周期,将配置数据进行解析后,实例化Application并启动流程。
i. Application
Application 作为核心概念,承担整个train或inference的主要功能。所有Application继承于niftynet.application.base_application(简称为BaseApplication)。BaseApplication使用单例模式。
在Application类中,构建了Tensorflow的图结构和创建Session用于驱动计算。
BaseApplication单例模式的具体实现有一点小问题。
Application所完成的工作具体可以划分成以下4个环节
- 输入数据相关 数据加载,数据增强,数据取样等,抽象在这两个接口中在SegmentationApplication中,sampler支持:uniform, weighted, resized, balanced4种方式
initialise_dataset_loader()
initialise_sampler()
- 网络结构相关 网络结构的定义,参数的管理,自定义操作等,抽象在此接口中
initialise_network()
- 模型共享相关 完成由网络的输入到网络的输出,计算loss、gradient,创建optimizer等,抽象在此接口中
connect_data_and_network()
- 输出解码相关 inference将网络输出解码操作,抽象在此接口中
interpret_output()
ii. Config
配置文件需要必须包含的模块:
- [SYSTEM]
- [NETWORK]
- 如果action为train,那么config中需要包含[TRAINING]模块
- 如果action为inference,那么config中需要包含[INFERENCE]模块
- 额外的,根据特定的application,会需要包含指定名称的模块。如:
– [GAN]
– [SEGMENTATION]
– [REGRESSION]
– [AUTOENCODER]
- 除了以上的配置外,其他的数据会处理为input data source specifications【数据声明模块】
l 数据声明模块
Name | 解释 | 例子 | 默认值 |
csv_file | 包含输入图像文件的列表 | csvfile=filelist.csv | '' |
pathtosearch | 如果没有配置csv_file,则从此路径下去搜索输入图像 | pathtosearch=~/ct_data | NiftyNet home folder |
filename_contains | 搜索输入图像时用于匹配的关键词 | filename_contains=foo, bar | '' |
filenamenotcontains | 搜索输入图像时用于排除的关键词 | filenamenotcontains=ti, s1 | '' |
filename_removefromid | 正则表达式,用于从输入图像的文件名中,解析出id | filename_removefromid=foo | '' |
interp_order | 插值法 | interp_order=1 | 3 |
pixdim | 如果指定了,输入的3D图像会重新采样到指定大小再送入网络 | pixdim=1.2, 1.2, 1.2 | '' |
axcodes | 如果指定了,输入的3D图像会重新设定到指定的axcodes顺序再送入网络 参考文章 | axcodes=L, P, S | '' |
spatialwindowsize | 3个整数,指定输入window的大小[能被8整除] | spatialwindowsize=64, 64, 64 | '' |
loader | 指定图像读取loader类型 | loder=simpleitk | None |
[interp_order] 当设定采样方法为resize时,需要这个参数对图片上采样或下采样 1表示双线性插值
0表示最近邻插值
3表示三次样条插值
l [SYSTEM]
Name | 解释 | 例子 | 默认值 |
cude_devices | 指定GPU | cuda_devices=0,1 | '' |
num_threads | 预处理线程的数量 | num_threads=8 | 2 |
num_gpus | 训练时使用GPU数量 | num_gpus=2 | 1 |
model_dir | 保存或读取模型权重和Log的位置 | model_dir=~/niftynet/xxx | config文件所在目录 |
datasetsplitfile | 用于将数据划分成training/validation/inferenct字集 | datasetsplitfile=~/nifnet/xxx | ./datasetsplitfile.csv |
event_handler | 注册事件处理 | eventhandler=modelrestorer | modelsaver, modelrestorer, samplerthreading, applygradients, outputinterpreter, consolelogger, tensorboard_logger |
l [NETWORK]
Names | 解释 | 例子 | 默认值 |
name | 所使用的网络结构 | name=niftynet.network.toynet.ToyNet | ‘’ |
activation_function | 设置网络中使用的激活函数 | activation_function=prelu | Relu |
batch_size | 批大小 | batch_size=10 | 2 |
smaller_final_batch_mode | 当总数据量不能被batch_size整除时,最后一个batch_size的方式 | smaller_final_batch_mode=drop smaller_final_batch_mode=pad smaller_final_batch_mode=dynamic | pad |
decay | 正则化参数 | decay=1e-5 | 0.0 |
reg_type | 正则化类型 | reg_type=L1 | L2 |
volume_padding_size |
| volume_padding_size=4, 4, 4 | 0, 0, 0 |
volume_padding_mode |
| volume_padding_mode=symmetric | minimum |
window_sampling | 采样的类型 | window_sampling=uniform 固定尺寸,相同的概率分布 window_sampling=weighted 固定尺寸,根据intensity作为概率分布 window_sampling=balanced 固定尺寸,每个label拥有相同采样概率 window_sampling=resize 缩放图像到window尺寸 | uniform |
queue_length | 采样时使用的buffer大小 | queue_length=10 | 5 |
keep_prob | 如果网络中使用了dropout | keep_prob=0.2 | 1.0 |
l [TRAINING]
Name | 解释 | 例子 | 默认值 |
optimizer | 优化器类型 | optimizer=momentum | adam |
sample_per_volume | 每个输入图像采样的次数 | sample_per_volume=5 | 1 |
lr | 学习率 | lr=0.0001 | 0.1 |
loss_type | loss计算方式 | loss_type=CrossEntropy | Dice |
starting_iter | 启动的iter | starting_iter=0 | 0 |
save_every_n | 保存的间隔 | save_every_n=50 | 500 |
tensorboard_every_n | tensorboard记录的间隔 | tensorboard_every_n=50 | 20 |
max_iter | 最大iter数 | max_iter=3000 | 10000 |
max_checkpoints | 保存的最多checkpoint数 | max_checkpoints=5 | 100 |
训练时验证
validation_every_n | 训练时进行验证的间隔 | validation_every_n=10 | -1 |
validation_max_iter | 验证时iter的数量 | validation_max_iter=5 | 1 |
exclude_fraction_for_validation | 验证集的比重 | exclude_fraction_for_validation=0.2 | 0.0 |
exclude_fraction_for_inference | 测试集的比重 | exclude_fraction_for_inference=0.1 | 0.0 |
数据增强
rotation_angle | 旋转 | rotation_angle=-10.0, 10.0 | ‘’ |
scaling_percentage | 缩放 | scaling_percentage=-20.0, 20.0 | ‘’ |
random_flipping_axes | 翻转 | random_flipping_axes=1,2 | -1 |
l [INFERENCE]
Name | 解释 | 例子 | 默认值 |
spatial_window_size | 网络输入尺寸大小 | spatial_window_size=64,64,64 | ‘’ |
border | 输入尺寸的边框 | border=5,5,5 | 0,0,0 |
inference_iter | 使用指定iter保存的权重文件 | inference_iter=1000 | -1 |
save_seg_dir | 保存输出路径 | save_seg_dir=output/test | output |
output_postfix | 输出保存的后缀 | output_postfix=_output | _niftynet_out |
output_interp_order | 插值法 | output_interp_order=0 | 0 |
dataset_to_infer | 使用的数据集,可选:’all’, ‘training’, ‘validation’, ‘inference’ | dataset_to_infer=all | ‘’ |
iii. Reader & Dataset
n niftynet.io.image_reader模块
ImageReader的主要作用是,遍历一组目录,搜索并返回一个图像的列表,以及使用iterative的方式将数据加载到内存中。
ImageReader会创建一个tf.data.Dataset的对象,这样使得模块可以很方便地接入到基于tensorflow的程序中。
ImageReader的特点:
l 设计用于支持医疗图像数据的格式
l 支持多模态输入数据
l 支持tf.data.Dataset
n niftynet.contrib.dataset_sampler
sampler将 image reader作为输入,从每张图像中采取出结果输出。
在很多的医学图像处理的情况中,由于GPU显存的限制以及训练效率等的考虑,网络结构会对图像的部分进行处理而非整张图像。
iv. Network
项目中包含了一些已经实现的网络:
- GAN:
– simulator_gan
– siple_gan
- Segmentation:
– highres3dnet, highres3dnetsmall, highres3dnetlarge
– toynet
– unet
– vnet
– dense_vnet
– deepmedic
– scalenet
– holisticnet
– unet_2d
- classification:
– resnet
– se_resnet
- autoencoder:
– vae
v. Loss
已提供支持的loss计算方式
- Segmentation
- CrossEntropy
- CrossEntropy_Dense
- Dice
- Dice_NS
- Dice_Dense
- Dice_Dense_NS
- Tversky
- GDSC
- WGDL
- SensSpec
- Gan
- CrossEntropy
- Regression
- L1Loss
- L2Loss
- RMSE
- MAE
- Huber
- Classification
- CrossEntropy
- AutoEncoder
- VariationalLowerBound
支持的优化器类型
- adam
- gradientdescent
- momentum
- nesterov
- adagrad
- rmsprop
vi. Event机制
NiftyNet项目的设计,使用了Signal和event handler模式,具体实现使用了blinker库。这样可以方便地将模型保存,tensorboard记录等操作进行配置。
目前可供注册的signal有:
- GRAPH_CREATED
- SESS_STARTED
- SESS_FINISHED
- ITER_STARTED
- ITER_FINISHED
信号处理函数注册到对应的信号后,由引擎负责调用。
vii. Layer
网络层的相关设计都封装在Layer类中,可继承layer类,实现定制化结构