pytorch-textclassification is a lightweight natural language processing toolkit built on PyTorch and Transformers, focused on text classification. It supports multi-class and multi-label classification of Chinese long and short texts.
All datasets come from the internet and are collated here only for convenient access; if there is any infringement or similar issue, please get in touch and they will be removed promptly.
1. Text classification (txt format, one JSON object per line):
1.1 Multi-class format:
{"text": "人站在地球上为什么没有头朝下的感觉", "label": "教育"}
{"text": "我的小baby", "label": "娱乐"}
{"text": "请问这起交通事故是谁的责任居多小车和摩托车发生事故在无红绿灯", "label": "娱乐"}
1.2 Multi-label format (labels joined with the "|myz|" separator):
{"label": "3|myz|5", "text": "课堂搞东西,没认真听"}
{"label": "3|myz|2", "text": "测验90-94.A-"}
{"label": "3|myz|2", "text": "长江作业未交"}
More samples can be found in the test/tc directory.
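As an illustration of the two formats, here is a minimal loading sketch (the helper name read_corpus is hypothetical, not part of the toolkit) that reads a JSON-lines file and infers multi-class vs. multi-label from the "|myz|" separator:

import json

def read_corpus(path, label_sep="|myz|"):
    # one JSON object per line; labels split on the multi-label separator
    texts, labels, multi_label = [], [], False
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            row = json.loads(line)
            ys = str(row["label"]).split(label_sep)
            multi_label = multi_label or len(ys) > 1
            texts.append(row["text"])
            labels.append(ys)
    return texts, labels, multi_label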
# !/usr/bin/python
# -*- coding: utf-8 -*-
# @time    : 2021/2/23 21:34
# @author  : Mo
# @function: text classification; a label containing the "|myz|" separator marks the task as multi-label rather than multi-class
# make the project root importable (Linux-compatible paths)
import sys
import os

path_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
sys.path.append(os.path.join(path_root, "pytorch_textclassification"))
print(path_root)
# toolkit imports from pytorch_textclassification
from tcTools import get_current_time
from tcRun import TextClassification
from tcConfig import model_config
evaluate_steps = 320  # evaluation interval (steps)
save_steps = 320  # checkpoint interval (steps)
# pretrained-model directory or model name, required
pretrained_model_name_or_path = "bert-base-chinese"
# train/dev corpus paths; the dev path may be omitted
path_corpus = os.path.join(path_root, "corpus", "text_classification", "school")
path_train = os.path.join(path_corpus, "train.json")
path_dev = os.path.join(path_corpus, "dev.json")
if __name__ == "__main__":
    model_config["evaluate_steps"] = evaluate_steps  # evaluation interval (steps)
    model_config["save_steps"] = save_steps  # checkpoint interval (steps)
    model_config["path_train"] = path_train  # training corpus, required
    model_config["path_dev"] = path_dev  # dev corpus, may be None
    model_config["path_tet"] = None  # test corpus, may be None
    # loss-function type:
    # multi-class: None (defaults to BCE), BCE, BCE_LOGITS, MSE, FOCAL_LOSS, DICE_LOSS, LABEL_SMOOTH
    # multi-label: SOFT_MARGIN_LOSS, PRIOR_MARGIN_LOSS, FOCAL_LOSS, CIRCLE_LOSS, DICE_LOSS, etc.
    model_config["loss_type"] = "FOCAL_LOSS"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(model_config["CUDA_VISIBLE_DEVICES"])
    model_config["pretrained_model_name_or_path"] = pretrained_model_name_or_path
    model_config["model_type"] = "BERT"
    model_config["model_save_path"] = "../output/text_classification/model_{}".format(model_config["model_type"])
    # main
    lc = TextClassification(model_config)
    lc.process()
    lc.train()
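The script selects FOCAL_LOSS; in the result tables below it is tagged FOCAL_LOSS [0.5, 2], which plausibly reads as alpha=0.5, gamma=2. For orientation, a generic multi-label focal loss in PyTorch looks roughly like this (a sketch of the standard formulation, not the toolkit's own implementation):

import torch
import torch.nn as nn

class FocalLoss(nn.Module):
    # FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t), averaged over labels
    def __init__(self, alpha=0.5, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, logits, targets):
        p = torch.sigmoid(logits)
        # p_t: probability the model assigns to the true 0/1 value of each label
        p_t = p * targets + (1 - p) * (1 - targets)
        alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
        return (-alpha_t * (1 - p_t) ** self.gamma * torch.log(p_t.clamp(min=1e-8))).mean()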
This library is inspired by and references the following frameworks and papers.
To cite this work, you can refer to the present GitHub project, for example with BibTeX:

@software{Pytorch-NLU,
    url = {https://github.com/yongzhuo/Pytorch-NLU},
    author = {Yongzhuo Mo},
    title = {Pytorch-NLU},
    year = {2021}
}

Hope this is helpful to you!
Multi-label classification results on the school dev set (466 samples), comparing loss functions.

Micro averages:

loss function                                  precision  recall  f1-score  support
MARGIN_LOSS                                    0.7920     0.7189  0.7537    466
PRIOR-MARGIN_LOSS                              0.6706     0.8519  0.7505    466
FOCAL_LOSS [0.5, 2]                            0.8258     0.6309  0.7153    466
CIRCLE_LOSS                                    0.7890     0.7382  0.7627    466
DICE_LOSS (directly learns F1?)                0.7612     0.7661  0.7636    466
BCE                                            0.8062     0.7232  0.7624    466
BCE-Logits                                     0.7825     0.7103  0.7447    466
BCE-smooth                                     0.7899     0.7017  0.7432    466
(FOCAL_LOSS [0.5, 2] + PRIOR-MARGIN_LOSS) / 2  0.7235     0.8197  0.7686    466

Macro averages:

loss function                                  precision  recall  f1-score  support
MARGIN_LOSS                                    0.6198     0.5338  0.5641    466
PRIOR-MARGIN_LOSS                              0.5103     0.7200  0.5793    466
FOCAL_LOSS [0.5, 2]                            0.7655     0.4973  0.5721    466
CIRCLE_LOSS                                    0.6275     0.5235  0.5627    466
DICE_LOSS (directly learns F1?)                0.4287     0.3918  0.4025    466
BCE                                            0.6978     0.5158  0.5828    466
BCE-Logits                                     0.6046     0.5123  0.5433    466
BCE-smooth                                     0.6963     0.5012  0.5721    466
(FOCAL_LOSS [0.5, 2] + PRIOR-MARGIN_LOSS) / 2  0.6033     0.6809  0.6369    466
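The per-class reports below are standard classification reports; micro, macro, and weighted averages of this kind can be reproduced with scikit-learn (a toy multi-label example for illustration; the numbers in this section come from the school dev set, not from this snippet):

import numpy as np
from sklearn.metrics import classification_report

# toy multi-label indicator matrices, shape (n_samples, n_labels)
y_true = np.array([[1, 0, 1], [0, 1, 0], [1, 1, 0], [0, 0, 1]])
y_pred = np.array([[1, 0, 0], [0, 1, 0], [1, 0, 0], [0, 1, 1]])

# prints per-label rows plus micro/macro/weighted averages, 4 decimals as above
print(classification_report(y_true, y_pred, digits=4))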
MARGIN_LOSS:
label    precision    recall    f1-score    support
3 0.8102 0.7919 0.8009 221
2 0.8030 0.8030 0.8030 132
1 0.7333 0.4925 0.5893 67
6 0.7143 0.5000 0.5882 10
5 0.7778 0.4828 0.5957 29
0 0.0000 0.0000 0.0000 4
4 0.5000 0.6667 0.5714 3
micro_avg 0.7920 0.7189 0.7537 466
macro_avg 0.6198 0.5338 0.5641 466
weighted_avg 0.7841 0.7189 0.7454 466
PRIOR-MARGIN_LOSS:
label    precision    recall    f1-score    support
3 0.7279 0.8959 0.8032 221
2 0.7039 0.9545 0.8103 132
1 0.5897 0.6866 0.6345 67
6 0.3333 0.5000 0.4000 10
5 0.6296 0.5862 0.6071 29
0 0.1875 0.7500 0.3000 4
4 0.4000 0.6667 0.5000 3
micro_avg 0.6706 0.8519 0.7505 466
macro_avg 0.5103 0.7200 0.5793 466
weighted_avg 0.6799 0.8519 0.7538 466
FOCAL_LOSS [0.5, 2]:
label    precision    recall    f1-score    support
3 0.8482 0.7330 0.7864 221
2 0.8349 0.6894 0.7552 132
1 0.7586 0.3284 0.4583 67
6 0.6667 0.4000 0.5000 10
5 0.7500 0.4138 0.5333 29
0 1.0000 0.2500 0.4000 4
4 0.5000 0.6667 0.5714 3
micro_avg 0.8258 0.6309 0.7153 466
macro_avg 0.7655 0.4973 0.5721 466
weighted_avg 0.8206 0.6309 0.7038 466
CIRCLE_LOSS:
label    precision    recall    f1-score    support
3 0.8125 0.8235 0.8180 221
2 0.7914 0.8333 0.8118 132
1 0.7333 0.4925 0.5893 67
6 0.6667 0.4000 0.5000 10
5 0.7222 0.4483 0.5532 29
0 0.0000 0.0000 0.0000 4
4 0.6667 0.6667 0.6667 3
micro_avg 0.7890 0.7382 0.7627 466
macro_avg 0.6275 0.5235 0.5627 466
weighted_avg 0.7785 0.7382 0.7521 466
DICE_LOSS:
label    precision    recall    f1-score    support
3 0.7714 0.8552 0.8112 221
2 0.7727 0.9015 0.8322 132
1 0.7347 0.5373 0.6207 67
6 0.0000 0.0000 0.0000 10
5 0.7222 0.4483 0.5532 29
0 0.0000 0.0000 0.0000 4
4 0.0000 0.0000 0.0000 3
micro_avg 0.7612 0.7661 0.7636 466
macro_avg 0.4287 0.3918 0.4025 466
weighted_avg 0.7353 0.7661 0.7441 466
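The "(directly learns F1?)" note on DICE_LOSS above refers to the fact that micro-F1 = 2TP / (2TP + FP + FN) is exactly the dice coefficient of the predicted and true label sets, so a soft dice loss over sigmoid probabilities is a differentiable surrogate for micro-F1. A generic sketch (not the toolkit's implementation):

import torch
import torch.nn as nn

class SoftDiceLoss(nn.Module):
    # 1 - soft dice coefficient; minimizing it maximizes a soft micro-F1
    def __init__(self, eps=1e-8):
        super().__init__()
        self.eps = eps

    def forward(self, logits, targets):
        p = torch.sigmoid(logits)
        # soft counts: the intersection ~ TP, the two sums ~ TP+FP and TP+FN
        intersection = (p * targets).sum()
        dice = (2.0 * intersection + self.eps) / (p.sum() + targets.sum() + self.eps)
        return 1.0 - dice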
BCE:
label    precision    recall    f1-score    support
3 0.8136 0.8100 0.8118 221
2 0.8029 0.8333 0.8178 132
1 0.8235 0.4179 0.5545 67
6 0.6667 0.4000 0.5000 10
5 0.7778 0.4828 0.5957 29
0 0.0000 0.0000 0.0000 4
4 1.0000 0.6667 0.8000 3
micro_avg 0.8062 0.7232 0.7624 466
macro_avg 0.6978 0.5158 0.5828 466
weighted_avg 0.8009 0.7232 0.7493 466
BCE-Logits:
label    precision    recall    f1-score    support
3 0.7973 0.8009 0.7991 221
2 0.8000 0.7879 0.7939 132
1 0.7317 0.4478 0.5556 67
6 0.6667 0.4000 0.5000 10
5 0.7368 0.4828 0.5833 29
0 0.0000 0.0000 0.0000 4
4 0.5000 0.6667 0.5714 3
micro_avg 0.7825 0.7103 0.7447 466
macro_avg 0.6046 0.5123 0.5433 466
weighted_avg 0.7733 0.7103 0.7344 466
BCE-smooth:
label    precision    recall    f1-score    support
3 0.7945 0.7873 0.7909 221
2 0.8120 0.8182 0.8151 132
1 0.7027 0.3881 0.5000 67
6 0.8000 0.4000 0.5333 10
5 0.7647 0.4483 0.5652 29
0 0.0000 0.0000 0.0000 4
4 1.0000 0.6667 0.8000 3
micro_avg 0.7899 0.7017 0.7432 466
macro_avg 0.6963 0.5012 0.5721 466
weighted_avg 0.7790 0.7017 0.7296 466
(FOCAL_LOSS [0.5, 2] + PRIOR-MARGIN_LOSS) / 2:
label    precision    recall    f1-score    support
3 0.7640 0.8643 0.8110 221
2 0.7205 0.8788 0.7918 132
1 0.6620 0.7015 0.6812 67
6 0.4167 0.5000 0.4545 10
5 0.7600 0.6552 0.7037 29
0 0.4000 0.5000 0.4444 4
4 0.5000 0.6667 0.5714 3
micro_avg 0.7235 0.8197 0.7686 466
macro_avg 0.6033 0.6809 0.6369 466
weighted_avg 0.7245 0.8197 0.7679 466
Harmonic mean:
label    precision    recall    f1-score    support
3 0.8474 0.7285 0.7835 221
2 0.8304 0.7045 0.7623 132
1 0.8182 0.4030 0.5400 67
6 0.8000 0.4000 0.5333 10
5 0.7143 0.3448 0.4651 29
0 1.0000 0.2500 0.4000 4
4 0.6667 0.6667 0.6667 3
micro_avg 0.8324 0.6395 0.7233 466
macro_avg 0.8110 0.4996 0.5930 466
weighted_avg 0.8292 0.6395 0.7132 466
1/3 + 2/3-focal:
label    precision    recall    f1-score    support
3 0.7890 0.8462 0.8166 221
2 0.7516 0.8939 0.8166 132
1 0.6935 0.6418 0.6667 67
6 0.3636 0.4000 0.3810 10
5 0.6538 0.5862 0.6182 29
0 0.4000 0.5000 0.4444 4
4 0.5000 0.6667 0.5714 3
micro_avg 0.7430 0.8004 0.7707 466
macro_avg 0.5931 0.6478 0.6164 466
weighted_avg 0.7420 0.8004 0.7686 466
1/4-prior + 3/4-focal:
label    precision    recall    f1-score    support
3 0.7956 0.8100 0.8027 221
2 0.7712 0.8939 0.8281 132
1 0.6981 0.5522 0.6167 67
6 0.6667 0.4000 0.5000 10
5 0.7143 0.5172 0.6000 29
0 0.3333 0.2500 0.2857 4
4 0.5000 0.6667 0.5714 3
micro_avg 0.7656 0.7639 0.7648 466
macro_avg 0.6399 0.5843 0.6007 466
weighted_avg 0.7610 0.7639 0.7581 466
4/9-prior + 5/9-focal:
label    precision    recall    f1-score    support
3 0.7819 0.8597 0.8190 221
2 0.7578 0.9242 0.8328 132
1 0.6567 0.6567 0.6567 67
6 0.5000 0.5000 0.5000 10
5 0.6250 0.5172 0.5660 29
0 0.2857 0.5000 0.3636 4
4 0.5000 0.6667 0.5714 3
micro_avg 0.7364 0.8155 0.7739 466
macro_avg 0.5867 0.6607 0.6156 466
weighted_avg 0.7352 0.8155 0.7715 466
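The mixed rows above (1/2 + 1/2, 1/4-prior + 3/4-focal, ...) are plain weighted sums of two criteria. A generic combinator sketch, with torch built-ins standing in for the toolkit's PRIOR_MARGIN_LOSS and FOCAL_LOSS:

import torch
import torch.nn as nn

class WeightedSumLoss(nn.Module):
    # combine two losses as w * loss_a + (1 - w) * loss_b
    def __init__(self, loss_a, loss_b, w=0.5):
        super().__init__()
        self.loss_a, self.loss_b, self.w = loss_a, loss_b, w

    def forward(self, logits, targets):
        return self.w * self.loss_a(logits, targets) + (1.0 - self.w) * self.loss_b(logits, targets)

# e.g. a 1/4 + 3/4 mix; 7 labels as in the school tables above
criterion = WeightedSumLoss(nn.MultiLabelSoftMarginLoss(), nn.BCEWithLogitsLoss(), w=0.25)
logits = torch.randn(8, 7)
targets = torch.randint(0, 2, (8, 7)).float()
print(criterion(logits, targets))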
Hope this is helpful to you!