可以使用convert_pytorch_checkpoint_to_tf.py将pytorch版本的 bert模型转换为TF版本的bert模型,不过需要注意的是需要将程序进行一定的修改:
原始代码:
model = BertModel.from_pretrained(
pretrained_model_name_or_path=args.model_name,
state_dict=torch.load(args.pytorch_model_path),
cache_dir=args.cache_dir
)
更改为:
model = BertModel.from_pretrained(
pretrained_model_name_or_path=args.cache_dir
)
也就是说,以下两个参数并没有用到
model_name
pytorch_model_path