TV loss全称Total Variation Loss,其作用主要是降噪,图像中相邻像素值的差异可以通过降低TV Loss来一定程度上进行解决 ,从而保持图像的光滑性。
连续TV Loss的定义为:
J
T
0
(
u
)
=
∫
D
u
u
x
2
+
u
y
2
d
x
d
y
J_{T_0}(u)=\int_{D_u}\sqrt{u^2_x+u^2_y}dxdy
JT0(u)=∫Duux2+uy2dxdy其中
u
x
=
∂
u
∂
x
u_x=\frac{\partial u}{\partial x}
ux=∂x∂u,
u
y
=
∂
u
∂
y
u_y=\frac{\partial u}{\partial y}
uy=∂y∂u,
D
u
D_u
Du是定义域。
带阶数的TV Loss的定义为:
J
T
0
(
u
)
=
∫
D
u
(
u
x
2
+
u
y
2
)
β
2
d
x
d
y
J_{T_0}(u)=\int_{D_u}(u^2_x+u^2_y)^{\frac{\beta}{2}}dxdy
JT0(u)=∫Du(ux2+uy2)2βdxdy
离散TV Loss的定义为:
J
T
0
(
u
)
=
∑
i
,
j
(
(
x
i
,
j
−
1
−
x
i
,
j
)
2
+
(
x
i
+
1
,
j
−
x
i
,
j
)
2
)
β
2
J_{T_0}(u)=\sum\limits_{i,j}((x_{i,j-1}-x_{i,j})^2+(x_{i+1,j}-x_{i,j})^2)^{\frac{\beta}{2}}
JT0(u)=i,j∑((xi,j−1−xi,j)2+(xi+1,j−xi,j)2)2β
import torch
def tv_loss(input_t):
temp1 = torch.cat((input_t[:, :, 1:, :], input_t[:, :, -1, :].unsqueeze(2)), 2)
temp2 = torch.cat((input_t[:, :, :, 1:], input_t[:, :, :, -1].unsqueeze(3)), 3)
temp = (input_t - temp1)**2 + (input_t - temp2)**2
return temp.sum()
if __name__ == '__main__':
input = torch.rand(4,3,32,32)
print(tv_loss(input))