Tensor Comprehensions 基础算子改写

亢建白
2023-12-01
def conv(self, input, outchannel, k_size, stride = 1, padding = 0):
        """Apply a 2-D convolution compiled with Tensor Comprehensions.

        The filter bank is freshly drawn from ``torch.randn`` on every call,
        so the result is not reproducible unless the RNG is seeded.

        Args:
            input: 4-D tensor; indexed as (batch, channel, height, width) by
                the TC definition below.
            outchannel: Number of output channels (first filter dimension).
            k_size: Side length of the square filter.
            stride: Step used for both spatial axes.
            padding: Number of zero rows/columns added on every spatial
                border before convolving; 0 disables padding.

        Returns:
            Whatever the compiled TC function returns for the CUDA inputs.
        """
        tc_source = """
        def convolution(float(N,C,H,W) I, float(M,C,KH,KW) W1) -> (O) {{
            O(n, m, h, w) +=! I(n, c, {sh} * h + kh, {sw} * w + kw) * W1(m, c, kh, kw)
        }}
        """
        weights = torch.randn(outchannel, input.size(1), k_size, k_size)
        padded = input.type(torch.FloatTensor)
        if padding != 0:
            # Zero rows above and below (height axis) ...
            row_pad = torch.zeros(padded.size(0), padded.size(1), padding, padded.size(3))
            padded = torch.cat((row_pad, padded, row_pad), 2)
            # ... then zero columns left and right, sized to the padded height.
            col_pad = torch.zeros(padded.size(0), padded.size(1), padded.size(2), padding)
            padded = torch.cat((col_pad, padded, col_pad), 3)

        # NOTE(review): backward="convolution_grad" is requested, but the TC
        # source above defines no function with that name -- confirm whether
        # gradients are actually needed/working here.
        conv_op = tc.define(
            tc_source,
            training=True,
            name="convolution",
            backward="convolution_grad",
            constants={"sh": stride, "sw": stride},
        )
        activations = Variable(padded.cuda(), requires_grad=True)
        filters = Parameter(weights.cuda())
        return conv_op(activations, filters, options=tc.CudaMappingOptions("conv"))


    def relu(self, input):
        LANG = """
        def relu(float(Q, W, B, M) I) -> (O1){
            O1(q, w, b, m) = fmax(I(q, w, b, m), 0)
        }
        """
        relu = tc.define(LANG, name="relu")
        inp = input.cuda()
        out = relu(inp, options=tc.CudaMappingOptions("naive"))
        return out


    def maxpool(self, input, p_size, stride, padding = 0):
        LANG = """
        def maxpool(float(B,C,H,W) input) -> (output) {{
            output(b,c,h,w) max=! input(b, c, h * {sH} + kh, w * {sW} + kw)  
            where kh in 0:{kH}, kw in 0:{kW}
        }}
        """
        input = input.type(torch.FloatTensor)
        if padding != 0:
            pad = torch.zeros(input.size(0), input.size(1), padding, input.size(3))
            input = torch.cat((input, pad), 2)
            input = torch.cat((pad, input), 2)
            pad = torch.zeros(input.size(0), input.size(1), input.size(2), padding)
            input = torch.cat((input, pad), 3)
            input = torch.cat((pad, input), 3)
        
        sH, sW = stride, stride
        kH, kW = p_size, p_size
        maxpool = tc.define(LANG, name="maxpool", constants={"sH":sH, "sW":sW, "kH":kH, "kW":kW})
        inp = input.cuda()
        out = maxpool(inp, options=tc.CudaMappingOptions("naive"))
        out = out.type(torch.FloatTensor)
        return out
        
    def avgpool(self, input, p_size, stride, padding = 0):
        LANG = """
        def avgpool(float(B, C, H, W) input) -> (output) {{
            output(b, c, h, w) +=! input(b, c, h * {sH} + kh, w * {sW} + kw) / ({kH} * {kW})
            where kh in 0:{kH}, kw in 0:{kW}
        }}
        """
        input = input.type(torch.FloatTensor)
        if padding != 0:
            pad = torch.zeros(input.size(0), input.size(1), padding, input.size(3))
            input = torch.cat((input, pad), 2)
            input = torch.cat((pad, input), 2)
            pad = torch.zeros(input.size(0), input.size(1), input.size(2), padding)
            input = torch.cat((input, pad), 3)
            input = torch.cat((pad, input), 3)
            
        sH, sW = stride, stride
        kH, kW = p_size, p_size
        avgpool = tc.define(LANG, name="avgpool", constants={"sH":sH, "sW":sW, "kH":kH, "kW":kW})
        inp = input.cuda()
        out = avgpool(inp, options=tc.CudaMappingOptions("naive"))
        out = out.type(torch.FloatTensor)
        return out


    def batchnorm(self, input):
        LANG = """
        def batchnorm(float(N,C,H,W) I, float(C) rMeanIn, float(C) rVarIn)
        -> (O, rMeanOut, rVarOut, mean, centered, variance, expectedVariance, normalizedOut)
        {{
           mean(c) +=! I(nn, c, hh, ww)
           mean(c)  = mean(c) / (N * H * W)
           rMeanOut(c) = (1 - {momentum}) * rMeanIn(c) + {momentum} * mean(c)
           centered(n, c, h, w) = I(n, c, h, w) - rMeanOut(c)
           variance(n, c, h, w) = centered(n, c, h, w) * centered(n, c, h, w)
           expectedVariance(c) +=! (variance(n, c, h, w) + {eps}) / (N * H * W)
           rVarOut(c) = rsqrt(
             (1 - {momentum}) * rVarIn(c) + {momentum} * expectedVariance(c))
           O(n, c, h, w) = centered(n, c, h, w) * rVarOut(c)
           normalizedOut(n, c, h, w) = O(n, c, h, w)
        }}
        """
        channel = input.size(1)
        batchnorm = tc.define(LANG, name="batchnorm", constants={"momentum": 0.9, "eps": 1e-5})
        running_mean, running_var = torch.randn(channel).cuda(), torch.randn(channel).cuda()
        out = batchnorm(input.cuda(), running_mean, running_var, options=tc.CudaMappingOptions("naive"))
        return out[7] 


    def linear(self, input, outchannel):
        LANG = """
        def fc(float(B,M) I, float(N,M) W1, float(N) B1) -> (O1) {
          O1(b, n) +=! I(b, m) * W1(n, m)
          O1(b, n) = O1(b, n) + B1(n)
        }
        """
        I = input.cuda()
        W = torch.randn(outchannel, input.size(1)).cuda()
        B = torch.randn(outchannel).cuda()
        fc = tc.define(LANG, name="fc")
        out = fc(I, W, B, options=tc.CudaMappingOptions("naive"))
        return out
 类似资料: