mx.symbol.slice_axis可以直接在某一维上切割,选择整行或整列。
pick 是精准挑选指定位置的值,
精准筛选指定位置的值。
代码格式,x,y都需要转成mxnet nd.array形式
import mxnet as mx
import mxnet.gluon.loss as gloss
#
from mxnet import nd
x =np.array([[ 1., 2.],
[ 3., 4.],
[ 5., 6.]])
x = mx.nd.array(x)
y=np.array([ 0, 1,1])
y=mx.nd.array(y)
result= nd.pick(x, y,1)
print(result)
结果:
[1. 4. 6.]
<NDArray 3 @cpu(0)>
官方举例:
Examples::
x = [[ 1., 2.],
[ 3., 4.],
[ 5., 6.]]
// picks elements with specified indices along axis 0
pick(x, y=[0,1], 0) = [ 1., 4.]
// picks elements with specified indices along axis 1
pick(x, y=[0,1,0], 1) = [ 1., 4., 5.]
y = [[ 1.],
[ 0.],
[ 2.]]
// picks ele