当前位置: 首页 > 面试题库 >

numpy.where的更快替代品?

蒲昊苍
2023-03-14
问题内容

我有一个3d数组,其中填充了从0到N的整数。我需要一个与该数组等于1、2、3,… N的位置对应的索引列表。我可以使用np.where进行如下操作:

N = 300
shape = (1000,1000,10)
data = np.random.randint(0,N+1,shape)
indx = [np.where(data == i_id) for i_id in range(1,data.max()+1)]

但这很慢。根据这个问题 快速python
numpy在哪里功能?

应该可以大大加快索引搜索的速度,但是我无法将那里提出的方法转移到我获取实际索引的问题上。加快上述代码的最佳方法是什么?

作为附加组件:我想稍后存储索引,使用np.ravel_multi_index可以将大小从保存3个索引减少为仅1个,即使用:

indx = [np.ravel_multi_index(np.where(data == i_id), data.shape) for i_id in range(1, data.max()+1)]

这更接近于Matlab的find函数。可以直接将其合并到不使用np.where的解决方案中吗?


问题答案:

我认为解决此问题的标准向量化方法最终将占用大量内存,对于int64数据,将需要O(8 * N * data.size)字节,对于上面给出的示例,则需要约22
gigs的内存。我假设这不是一个选择。

您可以通过使用稀疏矩阵存储唯一值的位置来取得一些进展。例如:

import numpy as np
from scipy.sparse import csr_matrix

def compute_M(data):
    cols = np.arange(data.size)
    return csr_matrix((cols, (data.ravel(), cols)),
                      shape=(data.max() + 1, data.size))

def get_indices_sparse(data):
    M = compute_M(data)
    return [np.unravel_index(row.data, data.shape) for row in M]

这利用了稀疏矩阵构造函数中的快速代码来以一种有用的方式来组织数据,从而构造了一个稀疏矩阵,其中rowi仅包含扁平化数据等于的索引i

为了进行测试,我还将定义一个执行简单方法的函数:

def get_indices_simple(data):
    return [np.where(data == i) for i in range(0, data.max() + 1)]

对于相同的输入,这两个函数给出相同的结果:

data_small = np.random.randint(0, 100, size=(100, 100, 10))
all(np.allclose(i1, i2)
    for i1, i2 in zip(get_indices_simple(data_small),
                      get_indices_sparse(data_small)))
# True

稀疏方法比数据集的简单方法快一个数量级:

data = np.random.randint(0, 301, size=(1000, 1000, 10))

%time ind = get_indices_simple(data)
# CPU times: user 14.1 s, sys: 638 ms, total: 14.7 s
# Wall time: 14.8 s

%time ind = get_indices_sparse(data)
# CPU times: user 881 ms, sys: 301 ms, total: 1.18 s
# Wall time: 1.18 s

%time M = compute_M(data)
# CPU times: user 216 ms, sys: 148 ms, total: 365 ms
# Wall time: 363 ms

稀疏方法的另一个好处是,矩阵M最终成为存储所有相关信息以供以后使用的非常紧凑而有效的方式,如问题的附加部分所述。希望这是有用的!

编辑:我意识到初始版本中存在一个错误:如果范围中的任何值未出现在数据中,它将失败:现在已在上面修复。



 类似资料:
  • 问题内容: 我知道这个话题已经解决了上千次。但是我找不到解决办法。 我正在尝试计算列表(df2.list2)的列中出现列表(df1.list1的每一行)的频率。所有列表仅包含唯一值。List1包含约300.000行,list2包含30.000行。 我有一个有效的代码,但是它的运行速度非常慢(因为我使用的是迭代程序)。我也尝试过itertuples(),但它给了我一个错误(“要解压缩的值太多(预期2

  • 问题内容: 为了提高其性能,我一直在使用VisualVM采样器对我的一个应用程序进行性能分析,最小采样周期为20ms。根据探查器,主线程在该方法中花费了将近四分之一的CPU时间。 我正在与该模式一起使用,以将数字“转换” 为正好有六个十进制数字的字符串表示形式。我知道这种方法相对昂贵并且 被 多次调用,但是我对这些结果感到有些惊讶。 这种采样分析器的结果在多大程度上准确?我将如何验证它们-最好不借

  • 问题内容: replace方法返回一个字符串对象而不是替换给定字符串的内容这一事实有点让人费解(但是,当您知道字符串在Java中是不可变的时,这是可以理解的)。通过在某些代码中使用深度嵌套的替换,我的性能受到了重大影响。有什么我可以替换的东西可以使它更快吗? 问题答案: 这就是StringBuilder的目的。如果要进行很多操作,请在上进行操作,然后在需要时将其转换为。 因此描述: “可变的字符序

  • 问题内容: 我正在制作一个程序,要求至少每秒捕获24个屏幕截图。目前,使用下面的代码,我每94毫秒仅获得1个,因此大约为10毫秒。 我不想使用任何第三方库,因为我试图将其保持尽可能小,但是如果我希望获得显着的性能提升,我会愿意的。我也试图保持该平台独立,但是,如果确实能够显着提高性能,我愿意将其限于Windows。 编辑:我现在也尝试了两种不同的方法;使用在oracles网站上找到的代码段,并在下

  • 问题内容: 我想用这种方法在一个块内重新加载表数据: 但是else块将不会执行。我得到的错误是: 那么,快速替代的替代方案是什么? 更新: 我现在收到中止错误。 问题答案: 这个简单的C函数: 如何使用以下命令启动功能: 在viewDidLoad()中?

  • Python的http.server(或Python 2的SimpleHTTPServer)是从命令行提供当前目录内容的一种很好的方式: