使用transformers库中的convert_pytorch_checkpoint_to_tf,转换pytorch的bert模型成TF格式

公良向阳
2023-12-01

可以使用convert_pytorch_checkpoint_to_tf.py将pytorch版本的 bert模型转换为TF版本的bert模型,不过需要注意的是需要将程序进行一定的修改:
原始代码:

    model = BertModel.from_pretrained(
        pretrained_model_name_or_path=args.model_name,
        state_dict=torch.load(args.pytorch_model_path),
        cache_dir=args.cache_dir
    )

更改为:

    model = BertModel.from_pretrained(
        pretrained_model_name_or_path=args.cache_dir
    )

也就是说,以下两个参数并没有用到

model_name
pytorch_model_path

 类似资料: