当前位置: 首页 > 工具软件 > d2l-torch > 使用案例 >

../d2l/torch.py中的lambda表达式

解浩渺
2023-12-01

astype

astype = lambda x, *args, **kwargs: x.type(*args, **kwargs)
cmp = d2l.astype(y_hat, y.dtype) == y

解释:x接受第0个参数y_hat,args接收其它后面的参数y.dtype(这里是torch.int64),x.type是将x的元素强制转换成某个属性。综合起来这个lambda的意思是将y_hat的元素类型设置为和y.dtype一样的类型。

reduce_sum

reduce_sum = lambda x, *args, **kwargs: x.sum(*args, **kwargs)
d2l.reduce_sum(d2l.astype(cmp, y.dtype))

解释:等价于d2l.astype(cmp, y.dtype).sum();

另外,bool类型是直接可以加和的,例子如下:

a = torch.tensor([[False, True], [True, True]])
a.sum()
output: tensor(3)

 类似资料: