Pytorch学习(十一)--- Tensor Comprehensions 初探

应俊爽
2023-12-01

初衷

其实, 有时候你会很难受. 因为你有一些想法, 在python端写会很慢, 写成cuda代码又有难度, 还要想着各种优化. 这时候你就不爽了,好不容易想到骚操作, 竟然因为写不来代码, 就泯灭这个想法吗. 这时候, 你可以看看Tensor Comprehesions(TC)这个包.

先吹一波: TC是一个让你不用写高性能代码的包, 它会直接根据简单的语法来生成GPU代码.
如果你还在为一下事情烦恼:
- 当你的pytorch 层很慢, 然后你想写CUDA代码, 这时候你没必要真的写CUDA代码.
- 你有一个CUDA层, 你花了一周写啊,调试啊,优化啊.但是你现在想一小时搞定.
- 其他..

使用

安装

这个没啥好说的:

conda install -c pytorch -c tensorcomp tensor_comprehensions

大概会安装这些包:

   package                    |            build
    ---------------------------|-----------------
    protobuf-3.4.1             |       h21cfbc1_2         5.6 MB  tensorcomp
    llvm-tapir50-0.1.0         |       h186cc49_2       331.5 MB  tensorcomp
    cudatoolkit-8.0            |                3       322.4 MB
    rhash-1.3.5                |       hbf7ad62_1         178 KB
    unzip-6.0                  |       h611a1e1_0          91 KB
    halide-0.1.0               |       h9df8326_2        37.9 MB  tensorcomp
    isl-tc-0.1.0               |       h9c8d533_2         1.3 MB  tensorcomp
    libuv-1.19.2               |       h14c3975_0         713 KB
    pytorch-0.3.1              |py36_cuda8.0.61_cudnn7.0.5_2       205.1 MB  pytorch
    gflags-2.4.4               |       h4126541_2         113 KB  tensorcomp
    cudnn-7.0.5                |        cuda8.0_0       249.3 MB
    ------------------------------------------------------------
                                           Total:        1.13 GB

开始用:

import torch
import tensor_comprehensions as tc
lang = """
def fcrelu(float(B,M) I, float(N,M) W1, float(N) B1) -> (O1) {
    O1(b, n) +=! I(b, m) * W1(n, m)
    O1(b, n) = O1(b, n) + B1(n)
    O1(b, n) = fmax(O1(b, n), 0)
}
"""
fcrelu = tc.define(lang, name="fcrelu")
B, M, N = 100, 128, 100
I, W1, B1 = torch.randn(B, M).cuda(), torch.randn(N, M).cuda(), torch.randn(N).cuda()
fcrelu.autotune(I, W1, B1, cache="fcrelu_100_128_100.tc")

首先我们要开始写一种TC语法lang, 这个等会儿再说. 然后通过 tc.define定义操作, 在autotune自动调参.
最终的结果会存在 ‘fcrelu_100_128_100.tc’, 下一次执行同一条语句是,就会直接加载结果.

out = fcrelu(I, W1, B1)

简要介绍TC语法

其实刚才看到了上面的东西,主要的疑惑是,咋写lang啊.

lang = """
def matmul(float(M,N) A, float(N,K) B) -> (output) {
  output(i, j) +=! A(i, kk) * B(kk, j)
}
"""
float(M,N) A, float(N,K) B) -> (output)

这里表示A是M*N的tensor, B是N*K的tensor.输入输出都是tensor类型.关键是output怎么得到?

  output(i, j) +=! A(i, kk) * B(kk, j)
  1. 如何确定范围:表达式的范围根据相应的输入的坐标即可.比如 output(i,j)这个i的值就是[0,M-1].
    如果我们需要进行类似stride的操作呢?这个以后再说.
  2. 右边有的索引而坐标没了, 那么该索引就要进行Reduction 就是说,这个索引要遍历所有能遍历的值.比如
    这里kk, 我们看到 kk在A中的范围是[0, N-1], 在B中的范围也是[0, N-1],所以A的范围就是取交集,同时kk只在右边有,左边没有.我们对kk进行reduction. 
  3. 最后! 表示每次都要进行初始化.这个初始化不一定是初始化成0的,不同操作都有一个初始化的值.有些操作的初始化值可能不是0.这个暂时不说.
output(i, j) = 0   # 初始化
output(i, j) += A(i, kk) * B(kk, j)  # for kk in [0, N-1]

再看一个例子:
实现AvgPool2d

LANG="""

def avgpool(float(B, C, H, W) input) -> (output) {{
    output(b, c, h, w) += input(b, c, h * {sH} + kh, w * {sW} + kw) where kh in 0:{kH}, kw in 0:{kW}
}}
"""
avgpool = tc.define(LANG, name="avgpool", constants={"sH":1, "sW":1, "kH":2, "kW":2})

我们可以用 where 搭配索引的范围0:{kH},等价于 python中的range(kH)
下面是几个例子:
简单卷积

def convolution(float(N, C, H, W) I, float(M, C, KH, KW) W1, float(M) B) -> (O) {
    O(n, m, h, w) +=! I(n, c, h + kh, w + kw) * W1(m, c, kh, kw)
    O(n, m, h, w) = O(n, m, h, w) + B(m)
}

strided卷积

def convolution_strided(float(N, C, H, W) I, float(M, C, KH, KW) W1, float(M) B) -> (O) {{
    O(n, m, h, w) +=! I(n, c, {sh} * h + kh, {sw} * w + kw) * W1(m, c, kh, kw)
    O(n, m, h, w) = O(n, m, h, w) + B(m)
}}

strided 卷积梯度

def convolution_grad(float(N, C, H, W) I, float(M, C, KH, KW) W1, float(N, M, H, W) O_grad) -> (I_grad, W1_grad) {{
    I_grad(n, c, h, w) +=! O_grad(n, m, {sh} * h - kh, {sw} * w - kw) * W1(m, c, kh, kw)
    W1_grad(m, c, kh, kw) +=! O_grad(n, m, {sh} * h - kh, {sw} * w - kw) * I(n, c, h, w)
}}

全连接层

def fully_connected(float(B, M) I, float(N, M) W1, float(N) B1) -> (O1) {
    O1(b, n) +=! I(b, m) * W1(n, m)
    O1(b, n) = O1(b, n) + B1(n)
}

Softmax层



def softmax(float(N, D) I) -> (O, maxVal, expDistance, expSum) {
    maxVal(n) max= I(n, d)
    expDistance(n, d) = exp(I(n, d) - maxVal(n))
    expSum(n) +=! expDistance(n, d)
    O(n, d) = expDistance(n, d) / expSum(n)
}

 类似资料: