当前位置: 首页 > 面试题库 >

LogisticRegression:未知标签类型:在python中使用sklearn的“ continuous”

呼延运恒
2023-03-14
问题内容

我有以下代码来测试sklearn python库的一些最流行的ML算法:

import numpy as np
from sklearn                        import metrics, svm
from sklearn.linear_model           import LinearRegression
from sklearn.linear_model           import LogisticRegression
from sklearn.tree                   import DecisionTreeClassifier
from sklearn.neighbors              import KNeighborsClassifier
from sklearn.discriminant_analysis  import LinearDiscriminantAnalysis
from sklearn.naive_bayes            import GaussianNB
from sklearn.svm                    import SVC

trainingData    = np.array([ [2.3, 4.3, 2.5],  [1.3, 5.2, 5.2],  [3.3, 2.9, 0.8],  [3.1, 4.3, 4.0]  ])
trainingScores  = np.array( [3.4, 7.5, 4.5, 1.6] )
predictionData  = np.array([ [2.5, 2.4, 2.7],  [2.7, 3.2, 1.2] ])

clf = LinearRegression()
clf.fit(trainingData, trainingScores)
print("LinearRegression")
print(clf.predict(predictionData))

clf = svm.SVR()
clf.fit(trainingData, trainingScores)
print("SVR")
print(clf.predict(predictionData))

clf = LogisticRegression()
clf.fit(trainingData, trainingScores)
print("LogisticRegression")
print(clf.predict(predictionData))

clf = DecisionTreeClassifier()
clf.fit(trainingData, trainingScores)
print("DecisionTreeClassifier")
print(clf.predict(predictionData))

clf = KNeighborsClassifier()
clf.fit(trainingData, trainingScores)
print("KNeighborsClassifier")
print(clf.predict(predictionData))

clf = LinearDiscriminantAnalysis()
clf.fit(trainingData, trainingScores)
print("LinearDiscriminantAnalysis")
print(clf.predict(predictionData))

clf = GaussianNB()
clf.fit(trainingData, trainingScores)
print("GaussianNB")
print(clf.predict(predictionData))

clf = SVC()
clf.fit(trainingData, trainingScores)
print("SVC")
print(clf.predict(predictionData))

前两个工作正常,但在LogisticRegression通话中出现以下错误:

root@ubupc1:/home/ouhma# python stack.py 
LinearRegression
[ 15.72023529   6.46666667]
SVR
[ 3.95570063  4.23426243]
Traceback (most recent call last):
  File "stack.py", line 28, in <module>
    clf.fit(trainingData, trainingScores)
  File "/usr/local/lib/python2.7/dist-packages/sklearn/linear_model/logistic.py", line 1174, in fit
    check_classification_targets(y)
  File "/usr/local/lib/python2.7/dist-packages/sklearn/utils/multiclass.py", line 172, in check_classification_targets
    raise ValueError("Unknown label type: %r" % y_type)
ValueError: Unknown label type: 'continuous'

输入数据与之前的调用中的数据相同,所以这里发生了什么?

顺便说一下,为什么会出现在第一预测一个巨大的性差异LinearRegression()SVR()算法(15.72 vs 3.95)


问题答案:

您正在将浮点数传递给分类器,该分类器期望将分类值作为目标向量。如果将其转换int为输入,那么它将被接受为输入(尽管这样做是否正确还是值得怀疑的)。

最好使用scikit的labelEncoder功能来转换您的训练成绩。

您的DecisionTree和KNeighbors限定符也是如此。

from sklearn import preprocessing
from sklearn import utils

lab_enc = preprocessing.LabelEncoder()
encoded = lab_enc.fit_transform(trainingScores)
>>> array([1, 3, 2, 0], dtype=int64)

print(utils.multiclass.type_of_target(trainingScores))
>>> continuous

print(utils.multiclass.type_of_target(trainingScores.astype('int')))
>>> multiclass

print(utils.multiclass.type_of_target(encoded))
>>> multiclass


 类似资料:
  • 问题内容: 我尝试运行以下代码。顺便说一句,我是python和sklearn的新手。 其中y是0和1的np.ndarray 我收到以下信息: 文件“ C:\ Anaconda3 \ lib \ site-packages \ sklearn \ linear_model \ logistic.py”,行> 1174,适合check_classification_targets(y) 文件“ C:\

  • 部署到本地JBOSS服务器时会发生此错误。是否有解决此警告的方法? 22:31:22992警告[org.jboss.as.server.deployment](MSC服务线程1-13)JBAS015852:无法索引类com/company/core/security/AuthRealm。类位于/C:/DevTools/jboss-eap-6.3/bin/content/platform-ws-0.

  • 我以编程方式签名PDF。每个新签名都是以增量方式添加的,我在%EOF之后添加签名字典,并像这样更新AcroForm(对不起,我在工作,所以我不能上传PDF): 也许问题是我有多个具有相同ID的对象,而我的最后一个AcroForm只引用了最后一个签名?我想能够签署一个文件多次,但我有一个问题。第一次签名没有问题,并显示以下横幅: 然后我尝试用另一个证书再次对同一个文件进行签名,这给了我签名旁边的垃圾

  • 当我试着运行一个特定的应用程序时,Gradle没有编译并显示这个奇怪的错误。我在文件中搜索了标签,但这样的标签不存在。应用程序的确切状态 错误:未知标记 然而,我之前构建的其他应用程序并非如此,它们运行平稳,没有任何错误。请告诉我有什么补救办法。谢谢大家。 编辑:新项目也面临同样的问题。 编辑2:当我删除实现的com时。Android支持约束:约束布局:构建时为1.1.0'。应用程序的渐变。它显示

  • 我正在机器学习项目中使用和从对数据集中的标签(国家名称)进行编码。一切正常,我的模型运行良好。该项目将根据包括客户所在国家在内的许多特征(数据)对银行客户是否继续或离开银行进行分类。 当我想预测(分类)一个新客户(仅一个)时,我的问题就出现了。新客户的数据仍未预处理(即,国家名称未编码)。如下所示: 在我学习机器学习的在线课程中,讲师打开包含编码数据的预处理数据集,手动检查法国代码,并在中进行更新