下面代码不包括DDPG强化学习参数优化器和Distill蒸馏训练
conda create --name PocketFlow python=3.6
source activate PocketFlow
pip install tensorflow-gpu==1.10.0
pip install numpy==1.14.5
conda install pandas
conda install scikit-learn
cifar-10:使用binary版本
wget https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz
下载完成后,解压到data目录
并在path.conf中设置路径:data_dir_local_cifar10 = /home/mars/hewu/tensorflow/PocketFlow/data/cifar-10-batches-bin
./scripts/run_local.sh nets/resnet_at_cifar10_run.py
官方教程需要在工程根目录配置path.conf文件,然后执行上述脚本。个人觉得不太方便调试,直接启动py,单步跟踪,更加便于理解程序逻辑。
nets/resnet_at_cifar10_run.py --model_http_url https://api.ai.tencent.com/pocketflow --data_dir_local /home/mars/hewu/tensorflow/PocketFlow/data/cifar-10-batches-bin --learner channel --cp_prune_option uniform
main.py
#/home/mars/hewu/tensorflow/PocketFlow/main.py
from nets.resnet_at_cifar10 import ModelHelper
from learners.learner_utils import create_learner
#1创建模型helper和learner
model_helper = ModelHelper()#网络和数据集的类
learner = create_learner(sm_writer,model_helper)#跳转到不同的压缩算法learner
#2进入训练,或者评估
learner.train()
learner.evaluate()
channel_pruning/learner.py
#/home/mars/hewu/tensorflow/PocketFlow/learners/channel_pruning/learner.py
from learners.distillation_helper import DistillationHelper #蒸馏相关
from learners.abstract_learner import AbstractLearner
from learners.channel_pruning.model_wrapper import Model #模型相关
from learners.channel_pruning.channel_pruner import ChannelPruner #裁剪相关
from rl_agents.ddpg.agent import Agent as DdpgAgent #强化学习代理DDPG
#继承自AbstractLearner
class ChannelPrunedLearner(AbstractLearner):
#继承初始化
super(ChannelPrunedLearner,self).__init__(sm_writer,model_helper)
#类内初始化
#蒸馏类初始化
self.learner_dst = DistillationHelper(sm_writer,model_helper)
#构建
#构建输入数据,模型定义,计算裁剪上下限等
self.__build(is_train=True)
#1train函数
def train(self):
#下载预训练模型,恢复权重,创建裁剪者pruner
#...
self.create_pruner()
#选择裁剪策略:list,auto,uniform
if FLAGS.cp_prune_option == 'list':
self.__prune_and_finetune_list()
#self.__prune_and_finetune_auto()
#self.__prune_and_finetune_uniform()
#2
def create_pruner(self):
#...
self.model = Model(self.sess_train)
self.pruner = ChannelPruner(
self.model,#模型
images=train_images,
labels=train_labels,
mem_images=mem_images,
mem_labels=mem_labels,
metrics=metrics,#度量,loss,accuracy
lbound=self.lbound,#裁剪保留通道比例
summary_op=summary_op,
sm_writer=self.sm_writer)
#3以auto策略为例介绍具体裁剪方法
def __prune_and_finetune_auto(self):
self.__prune_rl()#初始化RL类并进行裁剪(调用compress),学习最佳裁剪方法
while not done:#完成prune和finetune
done = self.__prune_list_layers(queue, [FLAGS.cp_list_group])
def __prune_rl(self):
#RL学习搜索裁剪策略
#5__prune_rl()和__prune_list_layers()中都会调用compress
def compress(self, c_ratio):
#裁剪时,只把选中的裁剪通道的权值置0,并没有真的裁剪掉
self.prune_kernel(conv_op,c_ratio)#裁剪策略lasso等
self.prune_W1(father_conv, idxs)#裁剪父conv的输出通道数(即当前conv的输入通道数)
self.prune_W2(conv_op, idxs, W2)#裁剪当前conv的输入通道数
def prune_kernel(self, op, nb_channel_new): #裁剪的具体步骤
#当前卷积:裁剪后通道数,newX输入feature map,Y目标值,W2权值
nb_channel_new = max(int(np.around(c * nb_channel_new)), 1)#hw new channel number
newX = self.__extract_input(op)
Y = self.feats_dict[outname]
W2 = self._model.param_data(op)
#lasso裁剪,得到新的权值newW2,以及通道索引(True/False)
idxs, newW2 = self.compute_pruned_kernel(newX, W2, Y, c_new=nb_channel_new)
def compute_pruned_kernel(
self,
X,
W2,
Y,
alpha=1e-4,
c_new=None,
tolerance=0.02):
#固定beta,优化W,即求解W
while True:
_, tmp, coef = solve(right)
...
#固定W,优化beta,即求解beta(idxs索引就是beta)
while True:
idxs, tmp, coef = solve(alpha)
...
channel_pruning/model_wrapper.py
#/home/mars/hewu/tensorflow/PocketFlow/learners/channel_pruning/model_wrapper.py
def get_Add_if_is_first_after_resblock(self, op):
#Add的输出层
def get_Add_if_is_last_in_resblock(cls, op):
#Add的输入层
def is_W1_prunable(self, conv):
#可以裁剪的层
channel_pruning/channel_pruner.py
#/home/mars/hewu/tensorflow/PocketFlow/learners/channel_pruning/channel_pruner.py
from sklearn.linear_model import LassoLars
from sklearn.linear_model import LinearRegression
class ChannelPruner(object):
def __init__(self,...):
self._model = model
self.thisconvs = self._model.get_operations_by_type()#网络中的卷积层
self.__build()
def __build(self):
self.__extract_output_of_conv_and_sum()#获取conv和add op,存入self.names列表
self.__create_extractor()#创建用于获取卷积输入feature map的extractor
self.initialize_state()#初始化状态:主要是确定哪些能裁剪,裁剪率等
def initialize_state(self):
#op名,对应裁剪保留范围:[] 例如第一个和最后一个卷积不裁剪,则范围为[1.0, 1.0]
self.max_strategy_dict = {} # collection of initial max [inp preserve, out preserve]
#op名,对应输入通道列表和输出通道列表,里面的值为True保留这个通道,False裁剪这个通道
self.fake_pruning_dict = {} # collection of fake pruning indices
#layer n c H W stride maxreduce layercomp
#状态 输出通道 输入通道 高 宽 stride 最大缩减 层计算量 都是除以每一列最大值后的归一化结果
resnet20裁剪:
权值维度:[KH,KW,Cin,Cout]
当前卷积都是裁剪输入通道,父卷积都是裁剪输出通道。
depthwise conv可以往前递推,直至找到一个普通的Conv2D OP,因为depthwise conv中不同channel之间没有dependency
裁剪的当前卷积 | 裁剪的父卷积 |
---|---|
conv2d_1 | conv2d |
conv2d_2 | conv2d |
conv2d_3 | conv2d_2 |
conv2d_5 | conv2d_4 |
conv2d_7 | conv2d_6 |
conv2d_10 | conv2d_9 |
conv2d_12 | conv2d_11 |
conv2d_14 | conv2d_13 |
conv2d_17 | conv2d_16 |
conv2d_19 | conv2d_18 |
conv2d_4可以裁剪输入通道,但是转pb时需要在其前面插入tf.gather | Add |
conv2d_6 | Add |
conv2d_8 | Add |
conv2d_9 | Add |
conv2d_11 | Add |
conv2d_13 | Add |
conv2d_15 | Add |
conv2d_16 | Add |
conv2d_18 | Add |
conv2d_20 | Add |
最后一个卷积不裁剪
conv2d_21