首先定义Coral损失函数
import torch
def CORAL(source, target, **kwargs):
d = source.data.shape[1]
ns, nt = source.data.shape[0], target.data.shape[0]
# source covariance
xm = torch.mean(source, 0, keepdim=True) - source
xc = xm.t() @ xm / (ns - 1)
# target covariance
xmt = torch.mean(target, 0, keepdim=True) - target
xct = xmt.t() @ xmt / (nt - 1)
# frobenius norm between source and target
loss = torch.mul((xc - xct), (xc - xct))
loss = torch.sum(loss) / (4*d*d)
return loss
用模拟数据进行验证
# 通过随机数模拟产生经过模型输出的结果source和target
# batch可以不一样,但分类类数要一样
source = torch.rand(64,4) # 源域输出结果为batch=64, 4分类
target = torch.rand(64,4) # 目标域域输出结果为batch=64, 4分类
Coral_loss = CORAL(source=source, target=target)
print(Coral_loss)
>>>output
tensor(3.3486e-05)
参考资料
链接: https://zhuanlan.zhihu.com/p/108778552.