12月15日,facebook宣布开源PyText NLP框架。 PyText是一种基于深度学习的NLP建模框架,基于PyTorch 1.0构建。它可以连接 ONNX 和 Caffe2,借助 PyText,AI 研究人员和工程师可以把 PyTorch 模型转化为 ONNX,然后将其导出为 Caffe2,用于大规模生产部署,让模型的建立,更新,发布更加便捷。
项目地址:https://github.com/facebookresearch/pytext
平台:linux centos7
IDE:anaconda
python:3.6(注意一定要是3.6版本以上,否则bug大人会来找你)
step1.创建虚拟环境
用anaconda可以很方便的进行python的版本管理和包的管理,anaconda下载地址:
https://www.anaconda.com/download/
创建一个python 3.6版本,名称为pytext的虚拟环境:
conda create -n pytext python=3.6
step2.下载相关的包
pip install pytext-nlp
不过值得注意的是,通过这个方式下载的是CPU版,如果需要安装GPU版,需要自己手动去安装PyTorch的GPU版,步骤如下:
export CMAKE_PREFIX_PATH="$(dirname $(which conda))/../" # [anaconda的根目录]
#安装基本的依赖包
conda install numpy pyyaml mkl mkl-include setuptools cmake cffi typing
#为GPU添加LAPACK支持
conda install -c pytorch magma-cuda92 # or [magma-cuda80 | magma-cuda91] 取决于你的cuda版本,一定要在7.5以上
step1.将git上的PyText项目克隆到本机:
git clone https://github.com/facebookresearch/pytext.git
step2.PyText的用法
pytext [OPTIONS] COMMAND [ARGS]...
[OPTIONS]:
–config-file TEXT
–config-json TEXT
–help
Commands:
export 将pytext模型快照转换为caffe2模型.
gen-default-config 根据默认参数生成一个json格式的配置文件
help-config 打印配置文件参数的帮助信息
predict 启动caffe2模型进行预测
predict-py 启动PyTorch进行预测
test 测试训练模型快照
train 训练模型并保存最好的快照
step3.demo实现
数据集是PyText提供的demo数据,做一个分类模型的训练,根据输入的命令文本来判断命令属于哪一个类。这个数据集十分小,因此不要对其准确率抱有期待了,后面的预测环节,更是不要有所期待。熟悉用法之后,还是拿自己的数据来训练吧。
训练集如下:
alarm/modify_alarm 16:24:datetime,39:57:datetime change my alarm tomorrow to wake me up 30 minutes earlier
alarm/set_alarm Turn on all my alarms
训练配置
{
"task": {
"DocClassificationTask": {
"data_handler": {
"train_path": "tests/data/train_data_tiny.tsv",
"eval_path": "tests/data/test_data_tiny.tsv",
"test_path": "tests/data/test_data_tiny.tsv"
}
}
}
}
训练(10 epoch,默认将结果保存在/tmp/test_out.txt)
pytext train < demo/configs/docnn.json
测试
pytext test < demo/configs/docnn.json
模型快照的导出(默认是在/tmp/model.caffe2.predictor)
pytext export --output-path exported_model.c2 (这里是指定的路径)< demo/configs/docnn.json
预测
最简单的方法是使用命令行:
pytext --config-file demo/configs/docnn.json predict <<< '{"raw_text": "create an alarm for 1:30 pm"}'
但是如果你保存在了自己路径,不是/tmp/model.caffe2.predictor的话,就会find不到模型啦,因此,还可以自己写一个脚本运行
import pytext
config_file = 'demo/configs/docnn.json' #配置文件路径
model_file = 'exported_model_demo.c2' #之前export模型的路径
config = pytext.load_config(config_file)
predictor = pytext.create_predictor(config,model_file)
result = predictor({"raw_text":"create an alarm for 1:30 pm"})
print(result)
PyTorch和TensorFlow的战争,会不会因为PyText的出现有所改变,拭目以待。
参考资料: