当前位置: 首页 > 工具软件 > tv > 使用案例 >

TV Loss详解

微生俊捷
2023-12-01

TV Loss介绍

 TV loss全称Total Variation Loss,其作用主要是降噪,图像中相邻像素值的差异可以通过降低TV Loss来一定程度上进行解决 ,从而保持图像的光滑性。

TV Loss定义

 连续TV Loss的定义为:
J T 0 ( u ) = ∫ D u u x 2 + u y 2 d x d y J_{T_0}(u)=\int_{D_u}\sqrt{u^2_x+u^2_y}dxdy JT0(u)=Duux2+uy2 dxdy其中 u x = ∂ u ∂ x u_x=\frac{\partial u}{\partial x} ux=xu u y = ∂ u ∂ y u_y=\frac{\partial u}{\partial y} uy=yu D u D_u Du是定义域。

 带阶数的TV Loss的定义为:
J T 0 ( u ) = ∫ D u ( u x 2 + u y 2 ) β 2 d x d y J_{T_0}(u)=\int_{D_u}(u^2_x+u^2_y)^{\frac{\beta}{2}}dxdy JT0(u)=Du(ux2+uy2)2βdxdy

 离散TV Loss的定义为:
J T 0 ( u ) = ∑ i , j ( ( x i , j − 1 − x i , j ) 2 + ( x i + 1 , j − x i , j ) 2 ) β 2 J_{T_0}(u)=\sum\limits_{i,j}((x_{i,j-1}-x_{i,j})^2+(x_{i+1,j}-x_{i,j})^2)^{\frac{\beta}{2}} JT0(u)=i,j((xi,j1xi,j)2+(xi+1,jxi,j)2)2β

TV Loss运行代码

import torch 

def tv_loss(input_t):
	temp1 = torch.cat((input_t[:, :, 1:, :], input_t[:, :, -1, :].unsqueeze(2)), 2)
	temp2 = torch.cat((input_t[:, :, :, 1:], input_t[:, :, :, -1].unsqueeze(3)), 3)
	temp = (input_t - temp1)**2 + (input_t - temp2)**2
	return temp.sum()

if __name__ == '__main__':
	input = torch.rand(4,3,32,32)
	print(tv_loss(input))
 类似资料: