当前位置: 首页 > 知识库问答 >
问题:

无法为多标签分类器执行堆叠

晏沈义
2023-03-14

我正在处理一个多标签文本分类问题(目标标签总数为90)。数据分布具有长尾和类不平衡性,大约有100k条记录。我正在使用OAA策略(一对一)。我试图创建一个合奏使用堆叠。

文本功能:HashingVectorizer(功能数量2**20,字符分析器)
TSVD以降低维度(n_分量=200)。

text_pipeline = Pipeline([
    ('hashing_vectorizer', HashingVectorizer(n_features=2**20,
                                             analyzer='char')),
    ('svd', TruncatedSVD(algorithm='randomized',
                         n_components=200, random_state=19204))])

feat_pipeline = FeatureUnion([('text', text_pipeline)])

estimators_list = [('ExtraTrees',
                    OneVsRestClassifier(ExtraTreesClassifier(n_estimators=30,
                                                             class_weight="balanced",
                                                             random_state=4621))),
                   ('linearSVC',
                    OneVsRestClassifier(LinearSVC(class_weight='balanced')))]
estimators_ensemble = StackingClassifier(estimators=estimators_list,
                                         final_estimator=OneVsRestClassifier(
                                             LogisticRegression(solver='lbfgs',
                                                                max_iter=300)))

classifier_pipeline = Pipeline([
    ('features', feat_pipeline),
    ('clf', estimators_ensemble)])

错误

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-41-ad4e769a0a78> in <module>()
      1 start = time.time()
----> 2 classifier_pipeline.fit(X_train.values, y_train_encoded)
      3 print(f"Execution time {time.time()-start}")
      4 

3 frames
/usr/local/lib/python3.6/dist-packages/sklearn/utils/validation.py in column_or_1d(y, warn)
    795         return np.ravel(y)
    796 
--> 797     raise ValueError("bad input shape {0}".format(shape))
    798 
    799 

ValueError: bad input shape (89792, 83)

共有1个答案

向安福
2023-03-14

到目前为止,Stacking分类器不支持多标签分类。您可以通过查看fi参数的形状值来理解这些功能,例如这里。

解决方案是将OneVsRestClassifier包装器放在StackingClassifier之上,而不是放在单个模型上。

例子:

from sklearn.datasets import make_multilabel_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.ensemble import ExtraTreesClassifier
from sklearn.svm import LinearSVC
from sklearn.ensemble import StackingClassifier
from sklearn.multiclass import OneVsRestClassifier

X, y = make_multilabel_classification(n_classes=3, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, 
                                                    test_size=0.33,
                                                    random_state=42)

estimators_list = [('ExtraTrees', ExtraTreesClassifier(n_estimators=30, 
                                                       class_weight="balanced", 
                                                       random_state=4621)),
                   ('linearSVC', LinearSVC(class_weight='balanced'))]

estimators_ensemble = StackingClassifier(estimators=estimators_list,
                                         final_estimator = LogisticRegression(solver='lbfgs', max_iter=300))

ovr_model = OneVsRestClassifier(estimators_ensemble)

ovr_model.fit(X_train, y_train)
ovr_model.score(X_test, y_test)

# 0.45454545454545453

 类似资料:
  • 校验者: @溪流-十四号 @大魔王飞仙 翻译者: @v Warning All classifiers in scikit-learn do multiclass classification out-of-the-box. You don’t need to use the sklearn.multiclass module unless you want to experiment with

  • 有关错误和可能的解决方案的更多信息,请阅读以下文章:http://cwiki.apache.org/confluence/display/maven/mojoExecutionException

  • subCategories标签 版本5.0.170927 新增 标签名 作用 包含属性 subCategories 获取指定分类下的子分类 categoryId ,item 标签属性: | 标签属性名 | 含义 | | --- | --- | | categoryId | 父级分类 id| | item | 循环变量,默认 vo | 代码演示 <portal:subCategories categ

  • 我正在y_test并y_pred混淆矩阵。我的数据用于多标签分类,因此行值是一种热编码。 我的数据有30个标签,但在输入混淆矩阵后,输出只有11行和列,这让我很困惑。我想我应该有一辆30X30的。 它们的格式是numpy数组。(y\u test和y\u pred是我使用dataframe.values将其转换为numpy数组的数据帧) y\U测试。形状 y_test y\u预测。形状 y\u预测

  • 好的,我是Maven和Eclipse的新手,我必须运行一个现有的项目。我试图使用命令eclipse:eclipse-Dwtpversion=2.0,但它显示了以下错误: [INFO]添加对WTP版本2.0的支持。 [INFO]使用Eclipse工作区:null [INFO]添加默认类路径容器:org.eclipse.jdt.launching.JRE_CONTAINER [INFO] ------

  • categories标签 版本5.0.170927 新增 标签名 作用 包含属性 categories 获取文章分类列表 where,order,item 标签属性: | 标签属性名 | 含义 | | --- | --- | | where | 查询条件变量, 支持数组和字符串,如$where | | order | 排序方式 | | item | 循环变量,默认 vo | 获取所有文章分类 <p