Tensor Comprehensions(TC)语言语法简述

皇甫繁
2023-12-01

背景

Tensor Comprehensions 是一种可以构建 just in time(JIT)系统的语言,程序员可通过该语言用高级编程语言去高效的实现 GPU 等底层代码。该项目由face book开源发布。
我的理解是我们通过该语言可以更加有效的理解张量、运算图等。虽然tensorflow已经有了tensorboard这样的可视化工具,但是在运算图的理解方面还是有些困难。

语法示例

下面给出一个TC定义下的简单小函数。该函数表示输入张量为A,形状为[M,K]、x,形状为[K],输出张量为C,形状为[M]。

def mv(float(M,K) A, float(K) x) -> (C) {
	C(i) = 0
	C(i) += A(i,k) * x(k)
}

该函数的功能可以用伪码描述如下,可以看出TC的定义是一种极其简洁的循环(loop)定义。

tensor C({M}).zero(); // 0-filled single-dim tensor
	parallel for (int i = 0; i < M; i++)
		reduction for (int k = 0; k < K; k++)
			C(i) += A(i,k) * x(k);

个人认为这里存在一个降维(reduce)的操作。在该函数中并不显示地定义变量,变量隐含在索引i、k中。i、k的取值范围根据操作来确定。并且运算与运算之间不会因顺序不同而产生运算不一致的结果,例如:
m i n ( m a x ( f ( . ) ) ) ≠ m a x ( m i n ( f ( . ) ) ) ) min(max(f(.))) \neq max(min(f(.)))) min(max(f(.)))̸=max(min(f(.))))

Because k only appears on the right-hand side, stores into C will reduce over k with the reduction operator +.

要注意:a(i,j) = a(j,i)这样的赋值方式是不允许的,除非张量a为空或者只有一个元素。

实例1:Tensor Comprehension for the sgemm BLAS(sgemm矩阵乘法)。在这个例子中,我们可以看出来TC对遍历张量的一个规则,TC将对于张量将提出一中方正(rectangular)的范围。比如 i 将被赋予张量 A 的行的最大值,j 将被赋予张量 B 的最大列值,k被使用两次,所以会给 k 尽可能大的数(该范围是A行B列中的较小值)。

def sgemm(float a, float b, float(N,M) A, float(M,K) B) -> (C) {
		C(i,j) = b * C(i,j) # initialization
		C(i,j) += a * A(i,k) * B(k,j) # accumulation
}

在TC中,定义 +=! 来替代 += ,这样就可以避免初始化操作,例如:

def mv(float(M,K) A, float(K) x) -> (C) {
	C(i) +=! A(i,k) * x(k)
}

实例2:卷积运算的实现。我们可以发现,通过TC可以将传统深度学习框架中需要很多行定义的相关操作,只用几行就搞定(函数 fcrelu 应该是 relu 激活函数,conv2d 估计是卷积操作)。

def fcrelu(float(B,I) in, float(O,I) weight, float(I) bias) -> (out) {
	out(i,j) = bias(j)
	out(b,o) += in(b,i) * weight(o,i)
	out(i,j) = fmaxf(out(i,j), 0)
}

def conv2d(float(B,IP,H,W) in, float(OP,IP,KH,KW) weight) -> (out) {
	out(b,op,h,w) +=! in(b,ip, h + kh, w + kw) * weight(op,ip,kh,kw)
	# out的简约初始化为0
}

紧接着,最大池化操作的定义如下,where 语句用来确定变量的变化范围,从而程序不会出越界错。
:如果没有where语句,那么循环将默认从0开始。

def maxpool2x2(float(B,C,H,W) in) -> (out) {
	out(b,c,i,j) max=! in(b, c, 2 * i + kw, 2 * j + kh)
	where kw in 0:2, kh in 0:2 # 步长为2
}

在TC中我们还可以对两个张量建立联系,这是一种远距离的迭代函数定义(affine function of iterators)。

def gather(float(N) X, int(A,B) I) -> (Z) {
	Z(i,j) = X(I(i,j))
}

实例3:TC还可以很简洁的实现 strided convolution(跳格卷积,即步长不为1),sh代表竖直步长,sw代表水平步长,H、W分别代表图像数据(I)的长、宽。

def sconv2d(int sh, int sw, float(N, C, H, W) I, float(F, C, KH, KW) W, float(F) B) -> (O) {
	O(n,f,h,w) = B(f)
	O(n,f,h,w) += I(n, c, sh * h + kh, sw * w + kw) * W(f, c, kh, kw)
}

再看另外一个例子:

def conv1d(float(M) I, float(N) K) ! (O) {
	O(i) = K(x) * I(i + x)
}

在这个例子中,我们似乎有不同的最大化索引(index)的方法。可以先最大化 i ,令 i 在 I 的整体范围之内,这将强制 x 只接受一个值,所以忽略了大部分内核,从而不会产生期望的卷积输出。
所以 TC 的方法是最大化 x ,令 x 在 K 的整体之内,那么为了不产生越界, i 势必会偏小一些,所以我们最终得出的 O 维度将会偏小。
即得出:
i ⊂ [ 0 , M − N ] i\subset [0, M-N] i[0,MN]

后期学到了新知识会及时补充。

 类似资料: