当前位置: 首页 > 工具软件 > AllenNLP > 使用案例 >

AllenNLP框架学习笔记(模型篇之保存与加载)

司空镜
2023-12-01

通常,用户想在磁盘上保存并加载经过训练的模型。这就是使用AllenNLP的配置文件非常有用的地方,因为加载模型所需的所有内容,包括权重、配置和词汇表,都可以存储在单个tar文件中。在本章中,将介绍三种对模型进行保存与加载的方式。

手动保存与加载

为了正确地保存和加载AllenNLP模型,我们一般需要有如下文件:

  • 模型配置(用于训练模型的规范)
  • 模型权重(模型的训练参数)
  • 词汇表

AllenNLP中,模型配置由Params类管理,可以使用to_file()方法保存到磁盘。用户可以使用model.state_dict()检索模型权重,并使用PyTorchtorch.save()将其保存到磁盘中。 Vocabulary.save_to_files()方法将Vocabulary对象序列化到目录。

为了从文件加载模型,可以使用Model.load()类方法。 它需要一个Params对象,该对象包含模型配置以及模型权重和词汇序列化的目录路径。 该方法还加载和还原词汇表。

示例代码如下:

# 存储模型
serialization_dir = 'model'
config_file = os.path.join(serialization_dir, 'config.json')
vocabulary_dir = os.path.join(serialization_dir, 'vocabulary')
weights_file = os.path.join(serialization_dir, 'weights.th')

os.makedirs(serialization_dir, exist_ok=True)
params.to_file(config_file)
vocab.save_to_files(vocabulary_dir)
torch.save(model.state_dict(), weights_file)

# 加载模型
loaded_params = Params.from_file(config_file)
loaded_model = Model.load(loaded_params, serialization_dir, weights_file)
loaded_vocab = loaded_model.vocab  # 在上一步已经加载进去了

archive保存与加载

因为每次需要保存、加载和移动模型时都要处理这三个文件,所以AllenNLP提供了用于归档和取消归档模型文件的实用功能。用户可以使用archive_model()方法将模型配置、权重和词汇表打包成一个tar.gz文件,以及任何附加的补充文件。此方法假设用户使用 training loop训练模型,并打包 training loop运行时保存的文件。 training loop也调用此函数,以便在训练结束时打包最佳模型权重,因此用户不太可能需要自己调用此方法。

另外,用户可以简单地使用load_archive()从存档文件中还原模型。 这将返回一个Archive对象,其中包含配置和模型。

# 建立archive文件
archive_model(serialization_dir, weights='weights.th')

# 加载archive文件
archive = load_archive(os.path.join(serialization_dir, 'model.tar.gz'))

AllenNLP命令保存与加载

实际上,如果用户使用AllenNLP命令(例如allennlp train),就会自动处理模型的保存。 训练结束或中断训练后,该命令会自动将最佳模型保存到model.tar.gz文件中。 用户还可以从序列化目录恢复训练。训练命令样例如下:

# my_text_classifier.jsonnet指模型配置文件,-s 参数是保存模型的文件夹名,--include-package 是放自己写的自定义脚本所在的文件夹名
allennlp train my_text_classifier.jsonnet -s model --include-package my_text_classifier

Model.from_archive除了load_archive()之外,AllenNLP还提供了一种便捷方法Model.from_archive()。 这基本上只是在后台调用load_archive()。 其主要目的是将其注册为from_archive类型的Model构造函数,以便用户可以从存档文件加载保存的模型,并继续使用allennlp train命令对其进行训练。 为此,请将以下代码片段放入训练配置文件中:

{
    ...
    "model": {
        "type": "from_archive",
        "archive_file": "/path/to/saved/archive/file.tar.gz"
    }
    ...
}

参考资料

 类似资料: