Author: Yirong Chen from South China University of Technology
My CSDN Blog: https://blog.csdn.net/m0_37201243
My Homepage: http://www.yirongchen.com/
Dependencies:
参考网站:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import csv
import random
import re
import os
import unicodedata
import codecs
from io import open
import itertools
import math
Cornell Movie-Dialogs Corpus是一个丰富的电影角色对话数据集:
这个数据集庞大而多样,在语言形式、时间段、情感上等都有很大的变化。我们希望这种多样性使我们的模型能够适应多种形式的输入和查询。
### 下载数据集
import os
import requests
print("downloading Cornell Movie-Dialogs Corpus数据集")
data_url = "http://www.cs.cornell.edu/~cristian/data/cornell_movie_dialogs_corpus.zip"
path = "./data/"
if not os.path.exists(path):
os.makedirs(path)
res = requests.get(data_url)
with open("./data/cornell_movie_dialogs_corpus.zip", "wb") as fp:
fp.write(res.content)
print("Cornell Movie-Dialogs Corpus数据集下载完毕!")
downloading Cornell Movie-Dialogs Corpus数据集
Cornell Movie-Dialogs Corpus数据集下载完毕!
import time
import zipfile
srcfile = "./data/cornell_movie_dialogs_corpus.zip"
file = zipfile.ZipFile(srcfile, 'r')
file.extractall(path)
print('解压cornell_movie_dialogs_corpus.zip完毕!')
print("Cornell Movie-Dialogs Corpus数据集的文件组成如下:")
corpus_file_list=os.listdir("./data/cornell movie-dialogs corpus")
print(corpus_file_list)
解压cornell_movie_dialogs_corpus.zip完毕!
Cornell Movie-Dialogs Corpus数据集的文件组成如下:
['formatted_movie_lines.txt', 'chameleons.pdf', '.DS_Store', 'README.txt', 'movie_conversations.txt', 'movie_lines.txt', 'raw_script_urls.txt', 'movie_characters_metadata.txt', 'movie_titles_metadata.txt']
def printLines(file, n=10):
with open(file, 'rb') as datafile:
lines = datafile.readlines()
for line in lines[:n]:
print(line)
# corpus_name = "cornell movie-dialogs corpus"
# corpus = os.path.join("data", corpus_name)
corpus_file_list=os.listdir("./data/cornell movie-dialogs corpus")
for file_name in corpus_file_list:
file_dir = os.path.join("./data/cornell movie-dialogs corpus", file_name)
print(file_dir,"的前10行")
printLines(file_dir)
这部分的结果省略在博客中!
Note:movie_lines.txt是关键数据文件,其实我们在找到一个数据集的时候,是可以从它的官网、来源或者相应的论文当中看到相应的介绍。也就是,我们至少知道某个数据集它的文件组成。
以下函数便于解析原始 movie_lines.txt 数据文件。
# 将文件的每一行拆分为字段字典
def loadLines(fileName, fields):
lines = {}
with open(fileName, 'r', encoding='iso-8859-1') as f:
for line in f:
values = line.split(" +++$+++ ")
# Extract fields
lineObj = {}
for i, field in enumerate(fields):
lineObj[field] = values[i]
lines[lineObj['lineID']] = lineObj
return lines
# 将 `loadLines` 中的行字段分组为基于 *movie_conversations.txt* 的对话
def loadConversations(fileName, lines, fields):
conversations = []
with open(fileName, 'r', encoding='iso-8859-1') as f:
for line in f:
values = line.split(" +++$+++ ")
# Extract fields
convObj = {}
for i, field in enumerate(fields):
convObj[field] = values[i]
# Convert string to list (convObj["utteranceIDs"] == "['L598485', 'L598486', ...]")
lineIds = eval(convObj["utteranceIDs"])
# Reassemble lines
convObj["lines"] = []
for lineId in lineIds:
convObj["lines"].append(lines[lineId])
conversations.append(convObj)
return conversations
# 从对话中提取一对句子
def extractSentencePairs(conversations):
qa_pairs = []
for conversation in conversations:
# Iterate over all the lines of the conversation
for i in range(len(conversation["lines"]) - 1): # We ignore the last line (no answer for it)
inputLine = conversation["lines"][i]["text"].strip()
targetLine = conversation["lines"][i+1]["text"].strip()
# Filter wrong samples (if one of the lists is empty)
if inputLine and targetLine:
qa_pairs.append([inputLine, targetLine])
return qa_pairs
Note:以下代码使用上面定义的函数创建格式化数据文件
import csv
import codecs
# 定义新文件的路径
datafile = os.path.join(corpus, "formatted_movie_lines.txt")
delimiter = '\t'
delimiter = str(codecs.decode(delimiter, "unicode_escape"))
# 初始化行dict,对话列表和字段ID
lines = {}
conversations = []
MOVIE_LINES_FIELDS = ["lineID", "characterID", "movieID", "character", "text"]
MOVIE_CONVERSATIONS_FIELDS = ["character1ID", "character2ID", "movieID", "utteranceIDs"]
# 加载行和进程对话
print("\nProcessing corpus...")
lines = loadLines(os.path.join(corpus, "movie_lines.txt"), MOVIE_LINES_FIELDS)
print("\nLoading conversations...")
conversations = loadConversations(os.path.join(corpus, "movie_conversations.txt"),
lines, MOVIE_CONVERSATIONS_FIELDS)
# 写入新的csv文件
print("\nWriting newly formatted file...")
with open(datafile, 'w', encoding='utf-8') as outputfile:
writer = csv.writer(outputfile, delimiter=delimiter, lineterminator='\n')
for pair in extractSentencePairs(conversations):
writer.writerow(pair)
# 打印一个样本的行
print("\nSample lines from file:")
printLines(datafile)
Processing corpus...
Loading conversations...
Writing newly formatted file...
Sample lines from file:
b"Can we make this quick? Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad. Again.\tWell, I thought we'd start with pronunciation, if that's okay with you.\n"
b"Well, I thought we'd start with pronunciation, if that's okay with you.\tNot the hacking and gagging and spitting part. Please.\n"
b"Not the hacking and gagging and spitting part. Please.\tOkay... then how 'bout we try out some French cuisine. Saturday? Night?\n"
b"You're asking me out. That's so cute. What's your name again?\tForget it.\n"
b"No, no, it's my fault -- we didn't have a proper introduction ---\tCameron.\n"
b"Cameron.\tThe thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser. My sister. I can't date until she does.\n"
b"The thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser. My sister. I can't date until she does.\tSeems like she could get a date easy enough...\n"
b'Why?\tUnsolved mystery. She used to be really popular when she started high school, then it was just like she got sick of it or something.\n'
b"Unsolved mystery. She used to be really popular when she started high school, then it was just like she got sick of it or something.\tThat's a shame.\n"
b'Gosh, if only we could find Kat a boyfriend...\tLet me see what I can do.\n'
# 默认词向量
PAD_token = 0 # Used for padding short sentences
SOS_token = 1 # Start-of-sentence token
EOS_token = 2 # End-of-sentence token
创建了一个Voc类,它会存储从单词到索引的映射、索引到单词的反向映射、每个单词的计数和总单词量。这个类提供向词汇表中添加单词的方法(addWord
)、添加句子的所有单词到词汇表中的方法 (addSentence
) 和清洗不常见的单词方法(trim
)。更多的数据清洗在后面进行。
class Voc:
def __init__(self, name):
self.name = name
self.trimmed = False
self.word2index = {}
self.word2count = {}
self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
self.num_words = 3 # Count SOS, EOS, PAD
# 添加句子中的所有单词到词汇表
def addSentence(self, sentence):
for word in sentence.split(' '):
self.addWord(word)
# 向词汇表中添加单词
def addWord(self, word):
if word not in self.word2index:
self.word2index[word] = self.num_words
self.word2count[word] = 1
self.index2word[self.num_words] = word
self.num_words += 1
else:
self.word2count[word] += 1
# 删除低于特定计数阈值的单词
def trim(self, min_count):
if self.trimmed:
return
self.trimmed = True
keep_words = []
for k, v in self.word2count.items():
if v >= min_count:
keep_words.append(k)
print('keep_words {} / {} = {:.4f}'.format(
len(keep_words), len(self.word2index), len(keep_words) / len(self.word2index)
))
# 重初始化字典
self.word2index = {}
self.word2count = {}
self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
self.num_words = 3 # Count default tokens
for word in keep_words:
self.addWord(word)
使用unicodeToAscii
将 unicode 字符串转换为 ASCII。然后,我们应该将所有字母转换为小写字母并清洗掉除基本标点之 外的所有非字母字符 (normalizeString
)。最后,为了帮助训练收敛,我们将过滤掉长度大于MAX_LENGTH 的句子 (filterPairs
)。
# 将Unicode字符串转换为纯ASCII
def unicodeToAscii(s):
return ''.join(
c for c in unicodedata.normalize('NFD', s)
if unicodedata.category(c) != 'Mn'
)
# normalizeString函数是一个正则化的函数,也就是使数据更加标准化的
def normalizeString(s):
s = unicodeToAscii(s.lower().strip())
s = re.sub(r"([.!?])", r" \1", s)
s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
return s
MAX_LENGTH = 10 # Maximum sentence length to consider
# 初始化Voc对象 和 格式化pairs对话存放到list中
def readVocs(datafile, corpus_name):
print("Reading lines...")
# Read the file and split into lines
lines = open(datafile, encoding='utf-8').read().strip().split('\n')
# Split every line into pairs and normalize
pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]
voc = Voc(corpus_name)
return voc, pairs
# 如果对 'p' 中的两个句子都低于 MAX_LENGTH 阈值,则返回True
def filterPair(p):
# Input sequences need to preserve the last word for EOS token
return len(p[0].split(' ')) < MAX_LENGTH and len(p[1].split(' ')) < MAX_LENGTH
# 过滤满足条件的 pairs 对话
def filterPairs(pairs):
return [pair for pair in pairs if filterPair(pair)]
# 使用上面定义的函数,返回一个填充的voc对象和对列表
def loadPrepareData(corpus, corpus_name, datafile, save_dir):
print("Start preparing training data ...")
voc, pairs = readVocs(datafile, corpus_name)
print("Read {!s} sentence pairs".format(len(pairs)))
pairs = filterPairs(pairs)
print("Trimmed to {!s} sentence pairs".format(len(pairs)))
print("Counting words...")
for pair in pairs:
voc.addSentence(pair[0])
voc.addSentence(pair[1])
print("Counted words:", voc.num_words)
return voc, pairs
# 加载/组装voc和对
save_dir = os.path.join("data", "save")
voc, pairs = loadPrepareData(corpus, corpus_name, datafile, save_dir)
# 打印一些对进行验证
print("\npairs:")
for pair in pairs[:10]:
print(pair)
Start preparing training data ...
Reading lines...
Read 221282 sentence pairs
Trimmed to 63446 sentence pairs
Counting words...
Counted words: 17774
pairs:
['there .', 'where ?']
['you have my word . as a gentleman', 'you re sweet .']
['hi .', 'looks like things worked out tonight huh ?']
['you know chastity ?', 'i believe we share an art instructor']
['have fun tonight ?', 'tons']
['well no . . .', 'then that s all you had to say .']
['then that s all you had to say .', 'but']
['but', 'you always been this selfish ?']
['do you listen to this crap ?', 'what crap ?']
['what good stuff ?', ' the real you . ']
另一种有利于让训练更快收敛的策略是去除词汇表中很少使用的单词。减少特征空间也会降低模型学习目标函数的难度。我们通过以下两个步 骤完成这个操作:
voc.trim
函数去除 MIN_COUNT 阈值以下单词 。MIN_COUNT = 3 # 修剪的最小字数阈值
def trimRareWords(voc, pairs, MIN_COUNT):
# 修剪来自voc的MIN_COUNT下使用的单词
voc.trim(MIN_COUNT)
# Filter out pairs with trimmed words
keep_pairs = []
for pair in pairs:
input_sentence = pair[0]
output_sentence = pair[1]
keep_input = True
keep_output = True
# 检查输入句子
for word in input_sentence.split(' '):
if word not in voc.word2index:
keep_input = False
break
# 检查输出句子
for word in output_sentence.split(' '):
if word not in voc.word2index:
keep_output = False
break
# 只保留输入或输出句子中不包含修剪单词的对
if keep_input and keep_output:
keep_pairs.append(pair)
print("Trimmed from {} pairs to {}, {:.4f} of total".format(len(pairs), len(keep_pairs), len(keep_pairs) / len(pairs)))
return keep_pairs
# 修剪voc和对
pairs = trimRareWords(voc, pairs, MIN_COUNT)
keep_words 7706 / 17771 = 0.4336
Trimmed from 63446 pairs to 52456, 0.8268 of total
print("pairs类型:", type(pairs))
print("pairs的Size:", len(pairs))
print("pairs前10个元素:", pairs[0:10])
pairs类型: <class 'list'>
pairs的Size: 52456
pairs前10个元素: [['there .', 'where ?'], ['you have my word . as a gentleman', 'you re sweet .'], ['hi .', 'looks like things worked out tonight huh ?'], ['have fun tonight ?', 'tons'], ['well no . . .', 'then that s all you had to say .'], ['then that s all you had to say .', 'but'], ['but', 'you always been this selfish ?'], ['do you listen to this crap ?', 'what crap ?'], ['what good stuff ?', ' the real you . '], ['wow', 'let s go .']]
Note: 实际上,在python当中,所有数据清洗到最后,在转换成数字之前,基本都转换成列表的形式:
[
[样本1],
[样本2],
[样本3],
...,
[样本n],
]
【作者简介】陈艺荣,男,目前在华南理工大学电子与信息学院广东省人体数据科学工程技术研究中心攻读博士,担任IEEE Access、IEEE Photonics Journal的审稿人。两次获得美国大学生数学建模竞赛(MCM)一等奖,获得2017年全国大学生数学建模竞赛(广东赛区)一等奖、2018年广东省大学生电子设计竞赛一等奖等科技竞赛奖项,主持一项2017-2019年国家级大学生创新训练项目获得优秀结题,参与两项广东大学生科技创新培育专项资金、一项2018-2019年国家级大学生创新训练项目获得良好结题,发表SCI论文4篇,授权实用新型专利8项,受理发明专利13项。
我的主页
我的Github
我的CSDN博客
我的Linkedin