当前位置: 首页 > 面试题库 >

Scikit学习predict_proba给出错误答案

贡俊
2023-03-14
问题内容

在该问题中,我引用了以下代码

>>> import sklearn
>>> sklearn.__version__
'0.13.1'
>>> from sklearn import svm
>>> model = svm.SVC(probability=True)
>>> X = [[1,2,3], [2,3,4]] # feature vectors
>>> Y = ['apple', 'orange'] # classes
>>> model.fit(X, Y)
>>> model.predict_proba([1,2,3])
array([[ 0.39097541,  0.60902459]])

我在该问题中发现,该结果按model.classes_给出的顺序表示了属于每个类的点的概率

>>> zip(model.classes_, model.predict_proba([1,2,3])[0])
[('apple', 0.39097541289393828), ('orange', 0.60902458710606167)]

所以…如果正确解释,此答案表示该点可能是“橙色”(由于数据量很少,因此置信度较低)。但是直觉上,这个结果显然是不正确的,因为给出的点与“苹果”的训练数据相同。可以肯定的是,我也进行了相反的测试:

>>> zip(model.classes_, model.predict_proba([2,3,4])[0])
[('apple', 0.60705475211840931), ('orange', 0.39294524788159074)]

同样,显然是不正确的,但方向相反。

最后,我尝试了更远的点。

>>> X = [[1,1,1], [20,20,20]] # feature vectors
>>> model.fit(X, Y)
>>> zip(model.classes_, model.predict_proba([1,1,1])[0])
[('apple', 0.33333332048410247), ('orange', 0.66666667951589786)]

同样,该模型会预测错误的概率。但是,model.predict函数正确了!

>>> model.predict([1,1,1])[0]
'apple'

现在,我记得在文档中读到一些有关小型数据集的predict_proba错误的信息,尽管我似乎再也找不到它了。这是预期的行为,还是我做错了什么?如果这是预期的行为,那么为什么为什么predict和predict_proba函数的输出不一致?而且重要的是,在我可以信任预报代码_proba的结果之前,数据集需要多大?

--------更新--------

好的,所以我对此做了更多的“实验”:predict_proba的行为严重依赖于“ n”,但并非以任何可预测的方式!

>>> def train_test(n):
...     X = [[1,2,3], [2,3,4]] * n
...     Y = ['apple', 'orange'] * n
...     model.fit(X, Y)
...     print "n =", n, zip(model.classes_, model.predict_proba([1,2,3])[0])
... 
>>> train_test(1)
n = 1 [('apple', 0.39097541289393828), ('orange', 0.60902458710606167)]
>>> for n in range(1,10):
...     train_test(n)
... 
n = 1 [('apple', 0.39097541289393828), ('orange', 0.60902458710606167)]
n = 2 [('apple', 0.98437355278112448), ('orange', 0.015626447218875527)]
n = 3 [('apple', 0.90235408180319321), ('orange', 0.097645918196806694)]
n = 4 [('apple', 0.83333299908143665), ('orange', 0.16666700091856332)]
n = 5 [('apple', 0.85714254878984497), ('orange', 0.14285745121015511)]
n = 6 [('apple', 0.87499969631893626), ('orange', 0.1250003036810636)]
n = 7 [('apple', 0.88888844127886335), ('orange', 0.11111155872113669)]
n = 8 [('apple', 0.89999988018127364), ('orange', 0.10000011981872642)]
n = 9 [('apple', 0.90909082368682159), ('orange', 0.090909176313178491)]

如何在我的代码中安全使用此功能?至少,是否有n的值可以保证与model.predict的结果一致?


问题答案:

如果您使用svm.LinearSVC()estimator,并且.decision_function()(例如svm.SVC的.predict_proba())将结果从最可能的类排序到最不可能的类。这与.predict()功能一致。另外,此估算器速度更快,并且获得的结果几乎相同svm.SVC()

对您来说,唯一的缺点可能是.decision_function()给出一个介于-1和3之间的有符号值sth而不是概率值。但它与预测一致。



 类似资料:
  • 问题内容: 我正在尝试从集群模块调用函数,如下所示: 我收到以下错误: 在IPython中,制表符补全似乎可以访问基本,克隆,外部,re,setup_module,sys和警告模块。sklearn目录中没有其他(包括群集)。 遵循以下pbu的建议并使用 我得到: 我在Windows上使用Python 3.4,scikit-learn 0.16.1。 问题答案: 问题是scipy / numpy安装

  • 我正在练习JPA/Hibernate,在运行下面的代码时: 和Book类(没有getters和setters): 我收到以下错误: 我真的不知道这意味着什么以及如何解决它,我花了几个小时搜索,找不到类似的东西。persistence.xml代码: 我在这里真的找不到任何错误——我在互联网上的许多地方看到了非常相似的代码。正如另一位用户所指出的,这是Hibernate生成的sql语句:创建表Book

  • 我正在学习DOM,希望创建一个简单的JavaScript with Html测验(用于练习)。现在我遇到的问题是,当我点击提交时,所有的答案都是对的,而不是一个是对的,三个是错的。我认为这是我的html和我将ID分配给不同标签的方式的问题,但我不能弄清楚我做错了什么。 代码 多谢了。

  • 我使用下面的代码来格式化日期。但是当我以不正确的格式给出数据时,它会给出意想不到的结果。 在上述情况下,输出为-formattedVal:0009-02-05。 它不是抛出解析异常,而是解析值并给我一个不正确的输出。有人能帮我理解这种反常的行为吗。

  • 从sklearn加载流行数字数据集。数据集模块,并将其分配给可变数字。 分割数字。将数据分为两组,分别命名为X_train和X_test。还有,分割数字。目标分为两组Y_训练和Y_测试。 提示:使用sklearn中的训练测试分割方法。模型选择;将随机_状态设置为30;并进行分层抽样。使用默认参数,从X_序列集和Y_序列标签构建SVM分类器。将模型命名为svm_clf。 在测试数据集上评估模型的准确

  • 问题内容: 我计算出以下内容: 即使执行10.0-9.2也可以得到上述结果。为什么多余的7会出现在结果中? 我在python 3.2上。 问题答案: 浮点算法基于数字的二进制近似,因此存在一些内置问题。 有一个很好的解释在Python文档。 如果您需要更准确的答案,则可以签出该模块。