本节所使用的代码已经上传GIthub:
https://github.com/l743396657/Pyro_test/tree/master/1 A Simple Example
假设我们试图计算某物的重量,但我们使用的秤不可靠,每次我们称同一物体时给出的答案略有不同。我们可以尝试通过将噪声测量信息与基于物体的一些先验知识(如密度或物质属性)的猜测相结合,来弥补这种变化。
weight ∣ guess ∼ Normal(guess, 1 ) \text { weight } \mid \text { guess } \sim \text { Normal(guess, } 1) weight ∣ guess ∼ Normal(guess, 1)
measurement ∣ guess, weight ∼ Normal(weight, 0.75 ) \text { measurement } \mid \text { guess, weight } \sim \text { Normal(weight, } 0.75 \text { ) } measurement ∣ guess, weight ∼ Normal(weight, 0.75 )
下面我们使用Pyro来声明这个分布:
def scale(guess):
weight = pyro.sample("weight", dist.Normal(guess, 1.0))
return pyro.sample("measurement", dist.Normal(weight, 0.75))
条件分布是概率中非常重要的分布之一,使用Pyro表示条件分布是我们之后会经常用到的功能之一。在后面的概率网络等工作中,‘训练数据’大多数使用条件概率的形式表示的。
那么在我们已知上面的分布之后,我们希望知道,在measurement = 9.5的时候,物体的weight是多少。既:
(weight ∣ guess, measurement = 9.5 ) ∼ ? \text { (weight } \mid \text { guess, measurement }=9.5) \sim ? (weight ∣ guess, measurement =9.5)∼?
在Pyro中,有多种方式可以表示条件分布,这里先说明第一种:
conditioned_scale = pyro.condition(scale, data={"measurement": torch.tensor(9.5)})
结合Python的函数,可以写成:
def deferred_conditioned_scale(measurement, guess):
return pyro.condition(scale, data={"measurement": measurement})(guess)
上面的写法并不唯一,我们可以只用pyro.sample来实现条件分布:
def scale_obs(guess): # equivalent to conditioned_scale above
weight = pyro.sample("weight", dist.Normal(guess, 1.))
# here we condition on measurement == 9.5
return pyro.sample("measurement", dist.Normal(weight, 0.75), obs=torch.tensor(9.5))
在sample中加入obs选项,可以实现条件概率的表示(这种用法比较常见)。