讲解一下solver.prototxt文件里面个参数的意义。
DL的任务中,几乎找不到解析解,所以将其转化为数学中的优化问题。sovler的主要作用就是交替调用前向传导和反向传导 (forward & backward) 来更新神经网络的连接权值,从而达到最小化loss,实际上就是迭代优化算法中的参数。
Caffe的solver类提供了6种优化算法,配置文件中可以通过type关键字设置:
1.首先设计好需要优化的对象,以及用于学习的训练网络和测试网络的prototxt文件(通常是train.prototxt和test.prototxt文件)
2.通过forward和backward迭代进行优化来更新参数
3.定期对网络进行评价
4.优化过程中显示模型和solver的状态
base_lr
这个参数代表的是此网络最开始的学习速率(Beginning Learning rate),一般是个浮点数,根据机器学习中的知识,lr过大会导致不收敛,过小会导致收敛过慢,所以这个参数设置也很重要。
lr_policy
这个参数代表的是learning rate应该遵守什么样的变化规则,这个参数对应的是字符串,选项及说明如下:
gamma
这个参数就是和learning rate相关的,lr_policy中包含此参数的话,需要进行设置,一般是一个实数。
stepsize
This parameter indicates how often (at some iteration count) that we should move onto the next “step” of training. This value is a positive integer.
stepvalue
This parameter indicates one of potentially many iteration counts that we should move onto the next “step” of training. This value is a positive integer. There are often more than one of these parameters present, each one indicated the next step iteration.
max_iter
最大迭代次数,这个数值告诉网络何时停止训练,太小会达不到收敛,太大会导致震荡,为正整数。
momentum
上一次梯度更新的权重,real fraction
weight_decay
权重衰减项,用于防止过拟合。
solver_mode
选择CPU训练或者GPU训练。
snapshot
训练快照,确定多久保存一次model和solverstate,positive integer。
snapshot_prefix
snapshot的前缀,就是model和solverstate的命名前缀,也代表路径。
net
path to prototxt (train and val)
test_iter
每次test_interval的test的迭代次数,假设测试样本总数为10000张图片,一次性执行全部的话效率很低,所以将测试数据分为几个批次进行测试,每个批次的数量就是batch_size。如果batch_size=100,那么需要迭代100次才能将10000个数据全部执行完,所以test_iter设置为100。
test_interval
测试间隔,每训练多少次进行一次测试。
display
间隔多久对结果进行输出
iter_size
这个参数乘上train.prototxt中的batch size是你实际使用的batch size。 相当于读取batchsize * itersize个图像才做一下gradient decent。 这个参数可以规避由于gpu内存不足而导致的batchsize的限制 因为你可以用多个iteration做到很大的batch 即使单次batch有限。
average_loss
取多次foward的loss作平均,进行显示输出
net: "examples/mnist/lenet_train_test.prototxt"
test_iter: 100
test_interval: 500
base_lr: 0.01
momentum: 0.9
weight_decay: 0.0005
lr_policy: "inv"
gamma: 0.0001
power: 0.75
display: 100
max_iter: 10000
snapshot: 5000
snapshot_prefix: "examples/mnist/model/"
solver_mode: CPU
net: "examples/mnist/lenet_train_test.prototxt"
test_iter: 100
test_interval: 500
base_lr: 0.001
momentum: 0.9
momentum2: 0.999
lr_policy: "fixed"
display: 100
max_iter: 10000
snapshot: 5000
snapshot_prefix: "examples/mnist/lenet"
type: "Adam"
solver_mode: CPU
net: "examples/mnist/mnist_autoencoder.prototxt"
test_state: { stage: 'test-on-train' }
test_iter: 500
test_state: { stage: 'test-on-test' }
test_iter: 100
test_interval: 500
test_compute_loss: true
base_lr: 0.01
lr_policy: "fixed"
display: 100
max_iter: 65000
weight_decay: 0.0005
snapshot: 10000
snapshot_prefix: "examples/mnist/mnist_autoencoder_adagrad_train"
solver_mode: GPU
type: "AdaGrad"
net: "examples/mnist/lenet_train_test.prototxt"
test_iter: 100
test_interval: 500
base_lr: 0.01
momentum: 0.0
weight_decay: 0.0005
lr_policy: "inv"
gamma: 0.0001
power: 0.75
display: 100
max_iter: 10000
snapshot: 5000
snapshot_prefix: "examples/mnist/lenet_rmsprop"
solver_mode: GPU
type: "RMSProp"
rms_decay: 0.98
net: "examples/mnist/lenet_train_test.prototxt"
test_iter: 100
lr_policy: "fixed"
momentum: 0.95
weight_decay: 0.0005
display: 100
max_iter: 10000
snapshot: 5000
snapshot_prefix: "examples/mnist/lenet_adadelta"
solver_mode: GPU
type: "AdaDelta"
delta: 1e-6
net: "examples/mnist/mnist_autoencoder.prototxt"
test_state: { stage: 'test-on-train' }
test_iter: 500
test_state: { stage: 'test-on-test' }
test_iter: 100
test_interval: 500
test_compute_loss: true
base_lr: 0.01
lr_policy: "step"
gamma: 0.1
stepsize: 10000
display: 100
max_iter: 65000
weight_decay: 0.0005
snapshot: 10000
snapshot_prefix: "examples/mnist/mnist_autoencoder_nesterov_train"
momentum: 0.95
solver_mode: GPU
type: "Nesterov"
train_net: "train.prototxt"
test_net: "val.prototxt"
test_iter: 736
# make test net, but don't invoke it from the solver itself
test_interval: 999999999
display: 20
average_loss: 20
lr_policy: "fixed"
# lr for unnormalized softmax
base_lr: 1e-10
# high momentum
momentum: 0.99
# no gradient accumulation
iter_size: 1
max_iter: 100000
weight_decay: 0.0005
snapshot: 4000
snapshot_prefix: "snapshot/train"
test_initialization: false