
Running ABSA-PyTorch fails with ImportError: cannot import name 'SAVE_STATE_WARNING' from 'torch.optim.lr_scheduler'

楚雪松
2023-12-01

Project scenario:

ABSA-PyTorch-master, available at https://github.com/songyouwei/ABSA-PyTorch

Aspect-based sentiment analysis (ABSA) built on PyTorch


Problem description:

The installed torch and transformers versions do not match, so importing transformers fails:

Traceback (most recent call last):
  File "train.py", line 17, in <module>
    from transformers import BertModel
  File "C:\Users\admin\anaconda3\lib\site-packages\transformers\__init__.py", line 626, in <module>
    from .trainer import Trainer
  File "C:\Users\admin\anaconda3\lib\site-packages\transformers\trainer.py", line 69, in <module>
    from .trainer_pt_utils import (
  File "C:\Users\admin\anaconda3\lib\site-packages\transformers\trainer_pt_utils.py", line 40, in <module>
    from torch.optim.lr_scheduler import SAVE_STATE_WARNING
ImportError: cannot import name 'SAVE_STATE_WARNING' from 'torch.optim.lr_scheduler' (C:\Users\admin\anaconda3\lib\site-packages\torch\optim\lr_scheduler.py)
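To check whether your own environment is in the same state before changing anything, a small script along these lines can report the installed versions and whether `torch.optim.lr_scheduler` still defines `SAVE_STATE_WARNING`. It is a sketch that uses only the standard library (`importlib` and `importlib.metadata`, Python 3.8+); the helper names `installed_version` and `torch_has_save_state_warning` are hypothetical and not part of ABSA-PyTorch or transformers.

```python
import importlib
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


def torch_has_save_state_warning() -> Optional[bool]:
    """True/False if torch imports, None if torch is not installed at all.

    This probes the same name that the failing line in transformers'
    trainer_pt_utils.py tries to import from torch.optim.lr_scheduler.
    """
    try:
        sched = importlib.import_module("torch.optim.lr_scheduler")
    except ImportError:
        return None
    return hasattr(sched, "SAVE_STATE_WARNING")


if __name__ == "__main__":
    print("torch:", installed_version("torch"))
    print("transformers:", installed_version("transformers"))
    print("SAVE_STATE_WARNING available:", torch_has_save_state_warning())
```

With torch 1.13.0 installed, the last line should report False, which matches the traceback; on a torch version where `from transformers import BertModel` succeeds, it should report True.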


Cause analysis:

The project's requirements.txt specifies:

numpy>=1.13.3
torch>=0.4.0
transformers>=3.5.1,<4.0.0
sklearn
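These pins bound torch only from below, so nothing in requirements.txt rules out a torch release that transformers 3.5.1 cannot import from. The sketch below checks the versions used in this post (transformers 3.5.1, torch 1.13.0) against these ranges with plain tuple comparison. `version_tuple`, `PINS` and `satisfies` are illustrative assumptions; a real project would more likely use `packaging.specifiers.SpecifierSet`, and this sketch simply ignores suffixes such as `+cpu`.

```python
import re
from typing import Dict, Optional, Tuple

Version = Tuple[int, int, int]


def version_tuple(v: str) -> Version:
    """Turn '1.13.0+cpu' or '3.5.1' into (major, minor, patch); suffixes are ignored."""
    m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", v)
    if m is None:
        raise ValueError(f"unrecognised version string: {v!r}")
    major, minor, patch = m.groups()
    return (int(major), int(minor), int(patch or 0))


# Ranges copied from requirements.txt: (inclusive lower bound, exclusive upper bound or None).
PINS: Dict[str, Tuple[Version, Optional[Version]]] = {
    "torch": ((0, 4, 0), None),              # torch>=0.4.0, no upper bound
    "transformers": ((3, 5, 1), (4, 0, 0)),  # transformers>=3.5.1,<4.0.0
}


def satisfies(name: str, v: str) -> bool:
    """Check one installed version string against its requirements.txt range."""
    low, high = PINS[name]
    t = version_tuple(v)
    return t >= low and (high is None or t < high)


if __name__ == "__main__":
    for name, v in [("transformers", "3.5.1"), ("torch", "1.13.0")]:
        print(name, v, "satisfies requirements.txt:", satisfies(name, v))
```

Both checks print True, which is why the pins alone could not prevent the ImportError: the compatibility constraint between the two libraries is simply not encoded in the file.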

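I installed transformers==3.5.1 and torch==1.13.0, and running the project then produced the error above. Both versions satisfy requirements.txt, because torch is only bounded from below, but transformers 3.5.1 imports SAVE_STATE_WARNING from torch.optim.lr_scheduler, and torch 1.13.0 does not provide it.
According to what I found online, the cause is a torch/transformers version mismatch. The fix other users suggested is to downgrade PyTorch to 1.4.0
or to upgrade transformers.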

Source: https://stackoverflow.com/questions/66590981/transformer-error-importing-packages-importerror-cannot-import-name-save-st
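If you would rather keep a newer torch together with transformers 3.5.1, one hypothetical workaround (not the one used in this post) is to make the failing import tolerant of the missing constant. This sketch assumes you edit the single import line from the traceback in your installed trainer_pt_utils.py, and that an empty-string fallback is acceptable because torch defined the constant as a warning message string; verify how your copy of the file uses the name before editing anything under site-packages.

```python
# Hypothetical replacement for the import line shown in the traceback
# (transformers/trainer_pt_utils.py, line 40 in the installed transformers 3.5.1):
#     from torch.optim.lr_scheduler import SAVE_STATE_WARNING
try:
    # Torch versions that work in this post (such as 1.7.1) still export the constant.
    from torch.optim.lr_scheduler import SAVE_STATE_WARNING
except ImportError:
    # torch 1.13.0 does not, so fall back to an empty message string.
    # Note: this branch also runs if torch is missing entirely, because
    # ModuleNotFoundError is a subclass of ImportError.
    SAVE_STATE_WARNING = ""
```

Pinning torch, as this post does, or upgrading transformers avoids editing installed packages and is the cleaner option; a manual patch like this is lost whenever transformers is reinstalled.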


Solution:

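Since I could not find a torch 1.4.0 package to install, I used 1.7.1 instead, and it turned out to work as well: torch 1.7.1 still provides SAVE_STATE_WARNING, so the import in transformers 3.5.1 succeeds.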

pip install torch==1.7.1

Installation completed normally. Run train.py to start training:

python train.py --model_name bert_spc --dataset restaurant

The output looks like this:

> n_trainable_params: 109484547, n_nontrainable_params: 0
> training arguments:
>>> model_name: bert_spc
>>> dataset: restaurant
>>> optimizer: <class 'torch.optim.adam.Adam'>
>>> initializer: <function xavier_uniform_ at 0x00000217B3EF1AF0>
>>> lr: 2e-05
>>> dropout: 0.1
>>> l2reg: 0.01
>>> num_epoch: 20
>>> batch_size: 16
>>> log_step: 10
>>> embed_dim: 300
>>> hidden_dim: 300
>>> bert_dim: 768
>>> pretrained_bert_name: bert-base-uncased
>>> max_seq_len: 85
>>> polarities_dim: 3
>>> hops: 3
>>> patience: 5
>>> device: cpu
>>> seed: 1234
>>> valset_ratio: 0
>>> local_context_focus: cdm
>>> SRD: 3
>>> model_class: <class 'models.bert_spc.BERT_SPC'>
>>> dataset_file: {'train': './datasets/semeval14/Restaurants_Train.xml.seg', 'test': './datasets/semeval14/Restaurants_Test_Gold.xml.seg'}
>>> inputs_cols: ['concat_bert_indices', 'concat_segments_indices']
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
epoch: 0
loss: 1.0587, acc: 0.5125
loss: 1.0007, acc: 0.5531
loss: 0.9748, acc: 0.5792
loss: 0.9881, acc: 0.5781
loss: 0.9771, acc: 0.5813
loss: 0.9650, acc: 0.5875
loss: 0.9570, acc: 0.5893
loss: 0.9282, acc: 0.6016
loss: 0.8982, acc: 0.6111
loss: 0.8898, acc: 0.6138
loss: 0.8741, acc: 0.6244
loss: 0.8566, acc: 0.6307
loss: 0.8370, acc: 0.6404
loss: 0.8285, acc: 0.6438
loss: 0.8195, acc: 0.6467
loss: 0.7989, acc: 0.6570
loss: 0.7943, acc: 0.6614
loss: 0.7826, acc: 0.6677
loss: 0.7751, acc: 0.6720
