astype = lambda x, *args, **kwargs: x.type(*args, **kwargs)
cmp = d2l.astype(y_hat, y.dtype) == y
解释:x接受第0个参数y_hat,args接收其它后面的参数y.dtype(这里是torch.int64),x.type是将x的元素强制转换成某个属性。综合起来这个lambda的意思是将y_hat的元素类型设置为和y.dtype一样的类型。
reduce_sum = lambda x, *args, **kwargs: x.sum(*args, **kwargs)
d2l.reduce_sum(d2l.astype(cmp, y.dtype))
解释:等价于d2l.astype(cmp, y.dtype).sum();
另外,bool类型是直接可以加和的,例子如下:
a = torch.tensor([[False, True], [True, True]])
a.sum()
output: tensor(3)