当前位置: 首页 > 工具软件 > Pyro > 使用案例 >

概率模型编程工具 Pyro(三)条件分布

姜飞飙
2023-12-01

概率模型编程工具 Pyro(三)条件分布

本节所使用的代码已经上传GIthub:

https://github.com/l743396657/Pyro_test/tree/master/1 A Simple Example

例子说明:

假设我们试图计算某物的重量,但我们使用的秤不可靠,每次我们称同一物体时给出的答案略有不同。我们可以尝试通过将噪声测量信息与基于物体的一些先验知识(如密度或物质属性)的猜测相结合,来弥补这种变化。

 weight  ∣  guess  ∼  Normal(guess,  1 ) \text { weight } \mid \text { guess } \sim \text { Normal(guess, } 1)  weight  guess  Normal(guess, 1)

 measurement  ∣  guess, weight  ∼  Normal(weight,  0.75  )  \text { measurement } \mid \text { guess, weight } \sim \text { Normal(weight, } 0.75 \text { ) }  measurement  guess, weight  Normal(weight, 0.75 ) 

下面我们使用Pyro来声明这个分布:

def scale(guess):
    weight = pyro.sample("weight", dist.Normal(guess, 1.0))
    return pyro.sample("measurement", dist.Normal(weight, 0.75))

条件分布:

条件分布是概率中非常重要的分布之一,使用Pyro表示条件分布是我们之后会经常用到的功能之一。在后面的概率网络等工作中,‘训练数据’大多数使用条件概率的形式表示的。

那么在我们已知上面的分布之后,我们希望知道,在measurement = 9.5的时候,物体的weight是多少。既:

 (weight  ∣  guess, measurement  = 9.5 ) ∼ ? \text { (weight } \mid \text { guess, measurement }=9.5) \sim ?  (weight  guess, measurement =9.5)?

在Pyro中,有多种方式可以表示条件分布,这里先说明第一种:

conditioned_scale = pyro.condition(scale, data={"measurement": torch.tensor(9.5)})

结合Python的函数,可以写成:

def deferred_conditioned_scale(measurement, guess):
    return pyro.condition(scale, data={"measurement": measurement})(guess)

上面的写法并不唯一,我们可以只用pyro.sample来实现条件分布:

def scale_obs(guess):  # equivalent to conditioned_scale above
    weight = pyro.sample("weight", dist.Normal(guess, 1.))
    # here we condition on measurement == 9.5
    return pyro.sample("measurement", dist.Normal(weight, 0.75), obs=torch.tensor(9.5))

在sample中加入obs选项,可以实现条件概率的表示(这种用法比较常见)。

 类似资料: