有关完整的ND4J API索引,请参考javadoc。
ND4J中使用了三种类型的操作:标量、转换和累加。我们将使用“op”这个词作为“operation”的同义词。你可以在这里的目录下看到这三种nd4j操作的列表。每个列表中的每个Java文件都是一个op。
大多数操作只需要枚举,或者一个可以自动完成的离散值列表。激活函数是例外,因为它们采用诸如“relu”或“tanh”之类的字符串。
标量、转换和累加都有自己的模式。转换是最简单的,因为接受单个参数并对其执行操作。绝对值是一种转换,它采用参数x如abs(IComplexNDArray ndarray)
并生成X的绝对值的结果。同样,可能应用于sigmoid转换 sigmoid()
以生成“x的sigmoid”。
标量操作只接受两个参数:输入和要应用于该输入的标量。例如,ScalarAdd()
有两个参数:input INDArray
x和标量数字大小;即ScalarAdd(INDArray x, Number num)
。同样的格式适用于每个标量运算。
最后,我们有积加,也被称为GPU处 缩减。累加将数组和向量互相添加,并可以通过按行操作添加元素来减少结果中这些数组的维数。例如,我们可以在数组上运行累加。
[1 2
3 4]
它会给我们向量
[3
7]
将列(即维度)从两个减少到一个。
累加可以是成对的,也可以是标量的。在成对缩减中,我们可能要处理两个形状相同的数组x和y。在这种情况下,我们可以用x和y的元素2乘2来计算它们的余弦相似度。
cosineSim(x[i], y[i])
或采用 EuclideanDistance(arr, arr2)
, 一个数组arr
和另一个数组arr2
之间的缩减。
许多nd4j操作被重载,这意味着共享一个公共名称的方法具有不同的参数列表。下面我们将只解释最简单的配置。
如你所见,ND4J操作有三种可能的参数类型:输入、可选参数和输出。输出在ops的构造函数中指定。输入在方法名称后面的括号中指定,始终在第一个位置,并且可选参数用于转换输入;例如要添加的标量;要乘以的系数,始终在第二个位置。
Method | What it does |
---|---|
Transforms | |
ACos(INDArray x) | 三角反余弦,元素方向。cos的倒数 ,如果 y = cos(x) , 那么 x = ACos(y) . |
ASin(INDArray x) | 也称为Arcsin。反正弦,元素方向。 |
ATan(INDArray x) | 三角反切线,元素方向。tan的倒数,这样,如果y = tan(x) 那么 x = ATan(y) . |
Transforms.tanh(myArray) | 双曲正切:一个sigmoidal函数。这适用于就地元素tanh运算。 |
Nd4j.getExecutioner().exec(Nd4j.getOpFactory() .createTransform(“tanh”, myArray)) | equivalent to the above |
对于其它转换, 查看此页 。
这里有两个执行z=tanh(x)的示例,其中原始数组x未修改。
INDArray x = Nd4j.rand(3,2); //输入
INDArray z = Nd4j.create(3,2); //输出
Nd4j.getExecutioner().exec(new Tanh(x,z));
Nd4j.getExecutioner().exec(Nd4j.getOpFactory().createTransform("tanh",x,z));
后两个例子使用了nd4j对所有操作的基本约定,其中我们有3个NDArrays,x、y和z。
x 是输入,必要的
y 是可选输入,只在一些操作中使用(余弦相似度,加法操作等)
z 是输出
通常,z=x(如果只使用一个参数的构造函数,则这是默认值)。但也有例外情况,如x=x+y。另一种可能性是z=x+y等。
大多数累加可以直接通过INDArray接口访问。
例如,要累加一个INDArray的所有元素:
double sum = myArray.sumNumber().doubleValue();
沿维累加即每行的总和值:
INDArray tenBy3 = Nd4j.ones(10,3); //10 行, 3列
INDArray sumRows = tenBy3.sum(0);
System.out.println(sumRows); //输出: [ 10.00, 10.00, 10.00]
沿维数的累加可推广,因此可以沿任何具有两个或多个维数的数组的两个维数求和。
一个简单的例子:
INDArray random = Nd4j.rand(3, 3);
System.out.println(random);
[[0.93,0.32,0.18]
[0.20,0.57,0.60]
[0.96,0.65,0.75]]
INDArray lastTwoRows = random.get(NDArrayIndex.interval(1,3),NDArrayIndex.all());
间隔是从fromInclusive到toExclusive,注意到可以等同于使用inclusive版本:NDArrayIndex.interval(1,2,true);
System.out.println(lastTwoRows);
[[0.20,0.57,0.60]
[0.96,0.65,0.75]]
INDArray twoValues = random.get(NDArrayIndex.point(1),NDArrayIndex.interval(0, 2));
System.out.println(twoValues);
[ 0.20, 0.57]
这些是底层数组的视图,而不是复制操作(它提供了更大的灵活性,并且没有复制成本)。
twoValues.addi(5.0);
System.out.println(twoValues);
[ 5.20, 5.57]
System.out.println(random);
[[0.93,0.32,0.18]
[5.20,5.57,0.60]
[0.96,0.65,0.75]]
为了避免就地行为,可以使用random.get(…).dup()进行复制。
Scalar | |
INDArray.add(number) | 返回将数字添加到 |
INDArray.addi(number) | 返回将数字添加到 |
ScalarAdd(INDArray x, Number num) | 返回将数字添加到 |
ScalarDivision(INDArray x, Number num) | 返回将 |
ScalarMax(INDArray x, Number num) | 将 |
ScalarMultiplication(INDArray x, Number num) | 返回将 |
ScalarReverseDivision(INDArray x, Number num) | 返回将num除以 |
ScalarReverseSubtraction(INDArray x, Number num) | 返回从num中减去 |
ScalarSet(INDArray x, Number num) | 这会将 |
ScalarSubtraction(INDArray x, Number num) | 返回从 |
如果您不理解nd4j语法的解释,找不到方法的定义,或者希望请求添加函数,请在Gitter live chat上通知我们。