继续从理论到实践开发自己的聊天机器人
优质
小牛编辑
131浏览
2023-12-01
上篇回顾
上篇文章《自己动手做聊天机器人 四十二-(重量级长文)从理论到实践开发自己的聊天机器人》使用带attention的seq2seq模型实现一般聊天机器人,经过10个小时对1000条样本的训练,达到了比较好的效果,代码分享在https://github.com/lcdevelop/ChatBotCourse/tree/master/chatbotv5
但是存在一个问题,当把样本量加大的时候内存随之增长,如果样本量达到万级别,内存占用已经达到了10G,样本量如果到几十万几百万,内存已经不知道能到多少了,这个主要问题是每次迭代都是把样本全量加载到内存并一次性训练完再更新模型,另外还有一个问题就是词表是基于样本生成的,没有做任何限制,导致样本大词表就大,那么模型就很大,所以占据内存也更大,所以我做了一版优化,在自己机器上尝试训练20w的样本内存占用不到1G,希望大家能找到更大量的样本来帮我充分测试,我这里有三千万的聊天语料可以使用,欢迎大家尝试,获取方式请见《自己动手做聊天机器人 二十九-重磅:近1GB的三千万聊天语料供出》。
优化方案
首先我们把全量加载样本改成批量加载,修改原来的train()函数,核心部分如下:
# 训练很多次迭代,每隔10次打印一次loss,可以看情况直接ctrl+c停止
previous_losses = []
for step in xrange(20000):
sample_encoder_inputs, sample_decoder_inputs, sample_target_weights = get_samples(train_set, 1000)
input_feed = {}
for l in xrange(input_seq_len):
input_feed[encoder_inputs[l].name] = sample_encoder_inputs[l]
for l in xrange(output_seq_len):
input_feed[decoder_inputs[l].name] = sample_decoder_inputs[l]
input_feed[target_weights[l].name] = sample_target_weights[l]
input_feed[decoder_inputs[output_seq_len].name] = np.zeros([len(sample_decoder_inputs[0])], dtype=np.int32)
[loss_ret, _] = sess.run([loss, update], input_feed)
if step % 10 == 0:
print 'step=', step, 'loss=', loss_ret, 'learning_rate=', learning_rate.eval()
if len(previous_losses) > 5 and loss_ret > max(previous_losses[-5:]):
sess.run(learning_rate_decay_op)
previous_losses.append(loss_ret)
# 模型持久化
saver.save(sess, './model/demo')
这里的get_samples(train_set, 1000)是批量获取样本,其中1000是每次获取的样本量,这个函数增加了如下逻辑:
if batch_num >= len(train_set):
batch_train_set = train_set
else:
random_start = random.randint(0, len(train_set)-batch_num)
batch_train_set = train_set[random_start:random_start+batch_num]
for sample in batch_train_set:
raw_encoder_input.append([PAD_ID] * (input_seq_len - len(sample[0])) + sample[0])
raw_decoder_input.append([GO_ID] + sample[1] + [PAD_ID] * (output_seq_len - len(sample[1]) - 1))
也就是说每次都在全量样本中随机位置抽取连续的1000条样本
另外,在加载样本词表时做了词的最小频率的限制,如下:
def load_file_list(self, file_list, min_freq):
......
for index, item in enumerate(sorted_list):
word = item[1]
if item[0] < min_freq:
break
self.word2id_dict[word] = self.START_ID + index
self.id2word_dict[self.START_ID + index] = word
return index
试验效果
经过如上改造,我把样本量扩展到21w,运行起来内存占用不到1G,最新的代码请见最新更新的:https://github.com/lcdevelop/ChatBotCourse/tree/master/chatbotv5