当前位置: 首页 > 知识库问答 >
问题:

LabelPropagation -如何避免被零除?

宗政燕七
2023-03-14

在使用LabelPropagation时,我经常遇到此警告(这应该是一个错误,因为它完全无法传播):

/usr/local/lib/python3.5/dist-packages/sklearn/semi_supervised/label_propagation。py:279:RuntimeWarning:true_divit自身中遇到无效值。label_distributions_/=标准化器

因此,在对RBF内核进行了几次尝试之后,我发现准材料伽马具有影响。

问题来自以下几行:

        if self._variant == 'propagation':
            normalizer = np.sum(
                self.label_distributions_, axis=1)[:, np.newaxis]
            self.label_distributions_ /= normalizer

我不明白label_distributions_怎么会全是零,特别是当它的定义是:

self.label_distributions_ = safe_sparse_dot(
graph_matrix, self.label_distributions_)

伽玛对graph_matrix有影响(因为graph_matrix是_build_graph调用内核函数的结果)。好吧。但是仍然有问题。

我提醒您如何计算传播的图权重:W = exp(-gamma * D),D数据集所有点之间的成对距离矩阵。

问题是:< code>np.exp(x)如果x非常小,则返回0.0。< br >假设我们有两个点< code>i和< code>j,因此< code>dist(i,j) = 10。

>>> np.exp(np.asarray(-10*40, dtype=float)) # gamma = 40 => OKAY
1.9151695967140057e-174
>>> np.exp(np.asarray(-10*120, dtype=float)) # gamma = 120 => NOT OKAY
0.0

实际上,我不是手动设置伽马,而是使用本文(第2.4节)中描述的方法。

我能想到的唯一方法是在每个维度上规范化数据集,但是我们失去了数据集的一些几何/拓扑属性(例如,一个2x10的矩形变成了一个1x1的正方形)

在这个例子中,这是最糟糕的:即使gamma = 20,它也会失败。

In [11]: from sklearn.semi_supervised.label_propagation import LabelPropagation

In [12]: import numpy as np

In [13]: X = np.array([[0, 0], [0, 10]])

In [14]: Y = [0, -1]

In [15]: LabelPropagation(kernel='rbf', tol=0.01, gamma=20).fit(X, Y)
/usr/local/lib/python3.5/dist-packages/sklearn/semi_supervised/label_propagation.py:279: RuntimeWarning: invalid value encountered in true_divide
  self.label_distributions_ /= normalizer
/usr/local/lib/python3.5/dist-packages/sklearn/semi_supervised/label_propagation.py:290: ConvergenceWarning: max_iter=1000 was reached without convergence.
  category=ConvergenceWarning
Out[15]: 
LabelPropagation(alpha=None, gamma=20, kernel='rbf', max_iter=1000, n_jobs=1,
         n_neighbors=7, tol=0.01)

In [16]: LabelPropagation(kernel='rbf', tol=0.01, gamma=2).fit(X, Y)
Out[16]: 
LabelPropagation(alpha=None, gamma=2, kernel='rbf', max_iter=1000, n_jobs=1,
         n_neighbors=7, tol=0.01)

In [17]: 

共有1个答案

相德宇
2023-03-14

基本上你在做一个softmax函数,对吧?

防止< code>softmax上溢/下溢的一般方法是(从这里)

# Instead of this . . . 
def softmax(x, axis = 0):
    return np.exp(x) / np.sum(np.exp(x), axis = axis, keepdims = True)

# Do this
def softmax(x, axis = 0):
    e_x = np.exp(x - np.max(x, axis = axis, keepdims = True))
    return e_x / e_x.sum(axis, keepdims = True)

此边界e_x 0 和 1 之间,并确保e_x的一个值将始终为 1(即元素 np.argmax(x))。这可以防止上溢和下溢(当 np.exp(x.max()) 大于或小于 float64 可以处理的时)。

在这种情况下,由于您无法更改算法,我将采用输入D并制作D_ = D - D.min(),因为这在数字上应该与上述等效,因为W.max()应该是-gamma * D.min()(因为您只是翻转符号)。做你的算法关于D_

编辑:

正如下面@PaulBrodersen所推荐的,您可以在此处基于< code>sklearn实现构建一个“安全”的rbf内核:

def rbf_kernel_safe(X, Y=None, gamma=None): 

      X, Y = sklearn.metrics.pairwise.check_pairwise_arrays(X, Y) 
      if gamma is None: 
          gamma = 1.0 / X.shape[1] 

      K = sklearn.metrics.pairwise.euclidean_distances(X, Y, squared=True) 
      K *= -gamma 
      K -= K.max()
      np.exp(K, K)    # exponentiate K in-place 
      return K 

然后在传播中使用它

LabelPropagation(kernel = rbf_kernel_safe, tol = 0.01, gamma = 20).fit(X, Y)

可惜我只有< code>v0.18,不接受< code>LabelPropagation的用户自定义内核函数,所以无法测试。

第二版:

检查你的源代码为什么有这么大的gamma值让我怀疑你是否使用了gamma=D.min()/3,这是不正确的。定义是sigma=D.min()/3,而w中的sigma的定义是

w = exp(-d**2/sigma**2)  # Equation (1)

这将使正确的gamma1/sigma**2

 类似资料:
  • 问题内容: 我有此错误信息: 消息8134,级别16,状态1,第1行除以零错误。 编写SQL代码的最佳方法是什么,这样我就再也看不到此错误消息了? 我可以执行以下任一操作: 添加一个where子句,这样我的除数永远不会为零 或者 我可以添加一个case语句,以便对零进行特殊处理。 使用子句的最佳方法是吗? 有没有更好的方法,或者如何执行? 问题答案: 为了避免出现“被零除”错误,我们对此进行了如下

  • 问题内容: 我正在尝试通过从客户端向服务器发送密钥和随机数来认证用户。 我的代码未向我显示客户端的响应。执行下面的代码时,我得到了一个空指针异常。 问题答案: 解决大多数问题的固定步骤: 阅读堆栈跟踪以确定哪一行代码引发NPE 在该行代码处设置一个断点 使用调试器,在遇到断点时,确定该行中的对象引用是 弄清楚为什么引用该文件(到目前为止,这是唯一实际的困难部分) 解决根本原因(也可能很困难)

  • 问题内容: 我有两个简单的Java代码。第一个将恒定功率定义为power = a.pow(b); 第二个将恒定功率定义为power = BigInteger.ONE.shiftLeft(b) 在命令行中设置内存标志- Xmx1024m,第一个代码可以正常工作,但是第二个代码却出现错误:java.lang.OutOfMemoryError:Java堆空间 我的问题:我应该在第二个代码中更改什么以避免

  • 问题内容: 我有一个用于将文本添加到现有.doc文件中的代码,它将通过使用apache POI将其另存为另一个名称。 以下是到目前为止我尝试过的代码 以下是我得到的 我已经添加了与此对应的所有jar文件,但仍然找不到解决方案。我对apache poi是陌生的,所以请帮我提供一些解释和示例。谢谢 问题答案: 从我对问题的评论中复制: 看起来您需要Apache POI发行版中的poi-ooxml-sc

  • 我有一个实现远程后台服务的应用程序。这个服务是用来下载线程中的文件的(我想说这个服务是作为下载管理器工作的)。 当我想下载一个文件时,我将url发送给服务,服务启动一个线程(我使用的是AsyncTask,但它只在Android4.1中工作)。但下载迟早会停止,我能够知道这一点,因为我显示的通知不再更新。当我单击取消下载的通知时,将向服务发送一个挂起的意图,告诉它取消下载,但当服务重新创建时,将取消

  • 我正在开发一个以活动和片段为结构的简单应用程序,其中一个要求是使其可访问,因此我完成了所有内容描述、导航、焦点等。 它工作得很好,除了片段,如果有一个活动加载一个片段,对讲读取它的内容,然后用户点击一些东西,一个细节片段可以被添加到堆栈的顶部。 如果用户继续导航对讲,仍然记得丢失片段的每个元素的位置。 有没有办法清除事件的辅助功能列表并强制它再次获取它?可访问性管理器似乎没有任何方法。 - 编辑