当前位置: 首页 > 面试题库 >

有效地将每一行的元素相乘

龙霖
2023-03-14
问题内容

定大小的ndarray(n,3)n1000左右,如何乘在一起的每一行,快的所有元素?下面的(不明确的)第二个解决方案运行大约0.3毫秒,是否可以改进?

# dummy data
n = 999
a = np.random.uniform(low=0, high=10, size=n).reshape(n/3,3)

# two solutions
def prod1(array):
    return [np.prod(row) for row in array]

def prod2(array):
    return [row[0]*row[1]*row[2] for row in array]

# benchmark
start = time.time()
prod1(a)
print time.time() - start
# 0.0015

start = time.time()
prod2(a)
print time.time() - start
# 0.0003

问题答案:

进一步提高性能

起初,一般的经验法则。您正在使用数值数组,因此请使用数组而不是列表。列表看起来可能有点像一个通用数组,但是在后端却完全不同,并且对于大多数数值计算来说绝对是不可行的。

如果您使用Numpy-Arrays编写简单的代码,则可以通过简单地按如下所示的方式添加代码来获得性能。如果使用列表,则可以或多或少地重写代码。

import numpy as np
import numba as nb

@nb.njit(fastmath=True)
def prod(array):
  assert array.shape[1]==3 #Enable SIMD-Vectorization (adding some performance)
  res=np.empty(array.shape[0],dtype=array.dtype)
  for i in range(array.shape[0]):
    res[i]=array[i,0]*array[i,1]*array[i,2]

  return res

使用np.prod(a, axis=1)并不是一个坏主意,但是性能并不是很好。对于只有1000x3的数组,函数调用开销非常大。在另一个jitted函数中使用jitted
prod函数时,可以完全避免这种情况。

基准测试

# The first call to the jitted function takes about 200ms compilation overhead. 
#If you use @nb.njit(fastmath=True,cache=True) you can cache the compilation result for every successive call.
n=999
prod1   = 795  µs
prod2   = 187  µs
np.prod = 7.42 µs
prod      0.85 µs

n=9990
prod1   = 7863 µs
prod2   = 1810 µs
np.prod = 50.5 µs
prod      2.96 µs


 类似资料:
  • 问题内容: 小写列表或集合的每个元素的最有效方法是什么? 我对清单的想法: 有没有更好,更快的方法?对于Set,此示例的外观如何?由于当前没有方法可以将操作应用于集合(或列表)的每个元素而无需创建其他临时集合? 这样的事情会很好: 谢谢, 问题答案: 对于列表,这似乎是一个非常干净的解决方案。它应该允许使用特定的List实现,以提供对于线性遍历列表和在恒定时间内替换字符串都是最佳的实现。 这是我可

  • 问题内容: 假设我有一个numpy数组: 我有一个对应的“向量”: 我如何沿每一行进行减法或除法运算,所以结果是: 长话短说:如何使用对应于每一行的1D标量数组在2D数组的每一行上执行操作? 问题答案: 干得好。您只需要与广播结合使用(或):

  • 我有一个数组[25,-6,14,7,100]。预期输出为 基本上,循环时下一个元素被减/加到当前元素中。求和和乘积很容易,因为我只需要做 它在线程“main”java.lang.ArrayIndexOutOfBoundsException中给出了

  • 我有一个数据帧,其中前3列是最后一列的标记名称,这是一个矩阵列表。每个矩阵都是二进制的(只有1和0)。我想从每一行中提取一个标记名称,并用该标记名称替换行矩阵中1的所有值,但是我想避免使用循环,以便我可以并行地对所有行执行此操作。 示例数据框: 目标格式如下所示。。。 我尝试过使用方法,但没有成功,因为结果输出是列表格式的,而不是获得df1的原始结构。有什么想法吗?

  • 问题内容: 给定 我可以复制所有项目到做 有没有更惯用的方法来做到这一点? 仅适用于切片(并作为源)。 问题答案: 这对我来说似乎是一种完美的方法。我认为将一张地图复制到另一张地图并不足以提供单线解决方案。

  • 例如,对于 我想得到 有没有办法不用for循环或使用? 编辑:实际数据由1000行组成,每行100个元素,每个元素的范围从1到365。最终目标是确定有重复的行的百分比。这是一个作业问题,我已经解决了(用for循环),但我只是想知道是否有更好的方法来做它与Numpy。