# Tensor Comprehensions basic-operator rewrites
# (Tensor Comprehensions 基础算子改写)
# Author: 亢建白
# Date: 2023-12-01
def conv(self, input, outchannel, k_size, stride = 1, padding = 0):
    """2-D convolution implemented as a Tensor Comprehensions kernel.

    A weight tensor of shape (outchannel, C_in, k_size, k_size) is drawn
    from randn inside this call; the stride is baked into the TC template
    as the compile-time constants ``sh``/``sw``.  Zero padding is applied
    symmetrically on H (dim 2) and W (dim 3) before the kernel runs.
    Returns the CUDA output tensor (autograd-enabled via training=True).
    """
    LANG = """
def convolution(float(N,C,H,W) I, float(M,C,KH,KW) W1) -> (O) {{
O(n, m, h, w) +=! I(n, c, {sh} * h + kh, {sw} * w + kw) * W1(m, c, kh, kw)
}}
"""
    weight = torch.randn(outchannel, input.size(1), k_size, k_size)
    input = input.type(torch.FloatTensor)
    if padding != 0:
        # Symmetric zero padding: first along H (dim 2), then along W (dim 3).
        for dim in (2, 3):
            shape = list(input.size())
            shape[dim] = padding
            zeros = torch.zeros(*shape)
            input = torch.cat((input, zeros), dim)
            input = torch.cat((zeros, input), dim)
    convolution = tc.define(
        LANG,
        training=True,
        name="convolution",
        backward="convolution_grad",
        constants={"sh": stride, "sw": stride},
    )
    I = Variable(input.cuda(), requires_grad=True)
    W = Parameter(weight.cuda())
    return convolution(I, W, options=tc.CudaMappingOptions("conv"))
def relu(self, input):
    """Element-wise ReLU over a 4-D tensor via a Tensor Comprehensions
    kernel (``fmax(x, 0)``); returns the CUDA result tensor."""
    LANG = """
def relu(float(Q, W, B, M) I) -> (O1){
O1(q, w, b, m) = fmax(I(q, w, b, m), 0)
}
"""
    relu_op = tc.define(LANG, name="relu")
    return relu_op(input.cuda(), options=tc.CudaMappingOptions("naive"))
def maxpool(self, input, p_size, stride, padding = 0):
    """Max pooling (window p_size x p_size, given stride) via a Tensor
    Comprehensions kernel.

    Fix: padding is filled with -inf instead of 0.  With zero padding a
    padded cell could win the max whenever the real inputs in the window
    are negative, corrupting the result; -inf matches standard maxpool
    padding semantics and is identical to the old behavior for
    non-negative inputs.  (A window lying entirely inside the padding —
    only possible when padding >= p_size — would now yield -inf.)

    Returns the pooled tensor moved back to a CPU FloatTensor.
    """
    LANG = """
def maxpool(float(B,C,H,W) input) -> (output) {{
output(b,c,h,w) max=! input(b, c, h * {sH} + kh, w * {sW} + kw)
where kh in 0:{kH}, kw in 0:{kW}
}}
"""
    input = input.type(torch.FloatTensor)
    if padding != 0:
        # Symmetric -inf padding on H (dim 2) then W (dim 3).
        for dim in (2, 3):
            shape = list(input.size())
            shape[dim] = padding
            # fill_ keeps compatibility with the old torch API this file uses
            pad = torch.zeros(*shape).fill_(float("-inf"))
            input = torch.cat((input, pad), dim)
            input = torch.cat((pad, input), dim)
    maxpool = tc.define(
        LANG,
        name="maxpool",
        constants={"sH": stride, "sW": stride, "kH": p_size, "kW": p_size},
    )
    out = maxpool(input.cuda(), options=tc.CudaMappingOptions("naive"))
    return out.type(torch.FloatTensor)
def avgpool(self, input, p_size, stride, padding = 0):
    """Average pooling (window p_size x p_size, given stride) via a
    Tensor Comprehensions kernel.

    Padded zeros are included in the average — the kernel always divides
    by kH * kW (count-include-pad behavior).  Returns the pooled tensor
    moved back to a CPU FloatTensor.
    """
    LANG = """
def avgpool(float(B, C, H, W) input) -> (output) {{
output(b, c, h, w) +=! input(b, c, h * {sH} + kh, w * {sW} + kw) / ({kH} * {kW})
where kh in 0:{kH}, kw in 0:{kW}
}}
"""
    input = input.type(torch.FloatTensor)
    if padding != 0:
        # Symmetric zero padding on H (dim 2) then W (dim 3).
        for dim in (2, 3):
            shape = list(input.size())
            shape[dim] = padding
            zeros = torch.zeros(*shape)
            input = torch.cat((input, zeros), dim)
            input = torch.cat((zeros, input), dim)
    avgpool = tc.define(
        LANG,
        name="avgpool",
        constants={"sH": stride, "sW": stride, "kH": p_size, "kW": p_size},
    )
    out = avgpool(input.cuda(), options=tc.CudaMappingOptions("naive"))
    return out.type(torch.FloatTensor)
def batchnorm(self, input):
    """Training-mode batch normalization via a Tensor Comprehensions
    kernel; returns the normalized tensor (the kernel's 8th output,
    ``normalizedOut``).

    NOTE(review): the running mean/var are re-initialized from randn on
    every call instead of being persistent statistics — looks like demo
    behavior; confirm against the caller before relying on it.
    """
    LANG = """
def batchnorm(float(N,C,H,W) I, float(C) rMeanIn, float(C) rVarIn)
-> (O, rMeanOut, rVarOut, mean, centered, variance, expectedVariance, normalizedOut)
{{
mean(c) +=! I(nn, c, hh, ww)
mean(c) = mean(c) / (N * H * W)
rMeanOut(c) = (1 - {momentum}) * rMeanIn(c) + {momentum} * mean(c)
centered(n, c, h, w) = I(n, c, h, w) - rMeanOut(c)
variance(n, c, h, w) = centered(n, c, h, w) * centered(n, c, h, w)
expectedVariance(c) +=! (variance(n, c, h, w) + {eps}) / (N * H * W)
rVarOut(c) = rsqrt(
(1 - {momentum}) * rVarIn(c) + {momentum} * expectedVariance(c))
O(n, c, h, w) = centered(n, c, h, w) * rVarOut(c)
normalizedOut(n, c, h, w) = O(n, c, h, w)
}}
"""
    num_channels = input.size(1)
    bn = tc.define(
        LANG, name="batchnorm", constants={"momentum": 0.9, "eps": 1e-5}
    )
    r_mean = torch.randn(num_channels).cuda()
    r_var = torch.randn(num_channels).cuda()
    outputs = bn(
        input.cuda(), r_mean, r_var, options=tc.CudaMappingOptions("naive")
    )
    # index 7 == normalizedOut in the kernel's output tuple
    return outputs[7]
def linear(self, input, outchannel):
    """Fully connected layer (matmul plus bias) via a Tensor
    Comprehensions kernel.  Weight (outchannel x in_features) and bias
    (outchannel) are drawn from randn inside this call; returns the CUDA
    output tensor of shape (batch, outchannel)."""
    LANG = """
def fc(float(B,M) I, float(N,M) W1, float(N) B1) -> (O1) {
O1(b, n) +=! I(b, m) * W1(n, m)
O1(b, n) = O1(b, n) + B1(n)
}
"""
    fc_op = tc.define(LANG, name="fc")
    weight = torch.randn(outchannel, input.size(1)).cuda()
    bias = torch.randn(outchannel).cuda()
    return fc_op(
        input.cuda(), weight, bias, options=tc.CudaMappingOptions("naive")
    )