当前位置: 首页 > 知识库问答 >
问题:

在unittest中比较(断言相等)包含numpy数组的两个复杂数据结构

郭逸清
2023-03-14

我使用Python的unittest模块来检查两个复杂的数据结构是否相等。对象可以是具有各种值的dict列表:数字、字符串、Python容器(列表/元组/dict)和numpyhtml" target="_blank">数组。后者是我提出这个问题的原因,因为我不能只是这样做

self.assertEqual(big_struct1, big_struct2)

因为它会产生

ValueError: The truth value of an array with more than one element is ambiguous.
Use a.any() or a.all()

我想我需要为这个写我自己的平等测试。它应该适用于任意结构。我目前的想法是递归函数:

  • 尝试将arg1的当前“节点”与arg2的对应节点进行直接比较
  • 如果未引发异常,则继续移动(“终端”节点/叶也在此处处理)
  • 如果捕获到ValueError,则会深入搜索,直到找到numpy.array
  • 比较阵列(例如,如下所示)

看起来有点问题的是跟踪两个结构的“对应”节点,但这里可能只需要zip

问题是:有没有好的(更简单的)替代方法?也许numpy会为此提供一些工具?如果没有其他选择,我将实施这个想法(除非我有更好的想法),并将其作为答案发布。

P. S.我有一种模糊的感觉,我可能已经看到了一个解决这个问题的问题,但是我现在找不到了。

P.P.S.另一种方法是使用函数遍历结构并将所有numpy.arrayS转换为列表,但这更容易实现吗?对我来说似乎是一样的。

编辑:子类化numpy.ndarray听起来很有希望,但是显然我没有将比较的双方硬编码到测试中。其中一个确实是硬编码的,所以我可以:

  • numpy.array的自定义子类填充它

我在这方面的问题是:

  1. 这行得通吗(我的意思是,我觉得没问题,但可能一些棘手的边缘案件不会得到正确处理)?我的自定义对象是否会像我所期望的那样,在递归等式检查中始终作为LHS结束?

编辑2:我试过了,答案中显示了(似乎)工作的实现。

共有3个答案

熊哲圣
2023-03-14

因此,jterrace说明的想法似乎对我有用,只需稍加修改:

class SaneEqualityArray(np.ndarray):
    def __eq__(self, other):
        return (isinstance(other, np.ndarray) and self.shape == other.shape and 
            np.allclose(self, other))

正如我所说,包含这些对象的容器应该位于相等检查的左侧。我从现有的numpy.ndarrays创建SaneEqualityArray对象,如下所示:

SaneEqualityArray(my_array.shape, my_array.dtype, my_array)

根据ndarray建造商签名:

ndarray(shape, dtype=float, buffer=None, offset=0,
        strides=None, order=None)

此类在测试套件中定义,仅用于测试目的。相等检查的RHS是被测试函数返回的实际对象,包含realnumpy.ndarray对象。

另外,感谢到目前为止两个答案的作者,他们都非常有帮助。如果有人发现这种方法有任何问题,我将感谢您的反馈。

万楷
2023-03-14

函数将调用对象的__eq__方法,对于复杂的数据类型应该递归。异常是Numpy,它没有sane__eq__方法。使用这个问题中的Numpy子类,您可以恢复平等行为的理智:

import copy
import numpy
import unittest

class SaneEqualityArray(numpy.ndarray):
    def __eq__(self, other):
        return (isinstance(other, SaneEqualityArray) and
                self.shape == other.shape and
                numpy.ndarray.__eq__(self, other).all())

class TestAsserts(unittest.TestCase):

    def testAssert(self):
        tests = [
            [1, 2],
            {'foo': 2},
            [2, 'foo', {'d': 4}],
            SaneEqualityArray([1, 2]),
            {'foo': {'hey': SaneEqualityArray([2, 3])}},
            [{'foo': SaneEqualityArray([3, 4]), 'd': {'doo': 3}},
             SaneEqualityArray([5, 6]), 34]
        ]
        for t in tests:
            self.assertEqual(t, copy.deepcopy(t))

if __name__ == '__main__':
    unittest.main()

这个测试通过。

姚臻
2023-03-14

会评论的,但是太长了。。。

有趣的是,你不能用==来测试数组是否相同,我建议你用np.testing.assert_array_equal

  1. 用于检查数据类型、形状等。

如果你想要这种疯狂嵌套的东西,我几乎建议尝试用泡菜序列化它,但那太严格了(第3点当然是完全破碎的),例如数组的内存布局并不重要,但对其序列化很重要。

 类似资料:
  • 比较两个NumPy数组是否相等的最简单方法是什么(其中相等定义为:A=B iff,用于所有索引i:

  • 是否有一种惯用的方法来比较两个NumPy数组,它们将NaN视为彼此相等(但不等于NaN以外的任何东西)。 例如,我希望以下两个数组比较相等: 和以下两个数组进行比较: 我正在寻找一种可以产生标量布尔结果的方法。 以下方法可以做到这一点: 但它很笨重,并且创建了所有这些中间数组。 有没有一种方法可以更容易地观察眼睛,更好地利用记忆? 另外,如果有帮助的话,已知数组具有相同的形状和数据类型。

  • 问题内容: 对于我的单元测试,我想检查两个数组是否相同。简化示例: 这是行不通的,因为。最好的进行方法是什么? 问题答案: 或者您可以使用或用: 编辑 由于您正在使用它进行单元测试,因此裸露(而不是将其包装成get )可能更自然。

  • 问题内容: 比较两个NumPy数组是否相等的最简单方法是什么(其中相等定义为:对于所有索引i:,A = B iff )? 简单地使用就会给我一个布尔数组: 我是否必须确定该数组的元素是否相等,或者是否有更简单的比较方法? 问题答案: 测试数组(A == B)的所有值是否均为True。 注意:也许您还想测试A和B形状,例如 特殊情况和替代方法 (来自dbaupp的回答和yoavram的评论) 应当指

  • 问题内容: 有人遇到过这个问题吗?假设您有两个类似以下的数组 有没有一种方法可以比较b中a中的哪些元素?例如, 我正在尝试避免循环,因为要花费数百万个元素才能解决问题。有任何想法吗? 干杯 问题答案: 实际上,有一个比以下任何一种方法更简单的解决方案: 所得的c为:

  • 我有以下数组列表 现在我需要比较这两个数组,并检查是否有id中的任何值存在于empIds中。如果是,我需要以布尔值true退出。我是这样做的。 但这需要很多时间。有人能帮我优化一下吗?