Diffusion Posterior Sampling (DPS)
本文档主要介绍DPS的方法流程 · 5 min read
利用训练好的无条件生成模型,通过某种数学上的技巧,将无条件分数函数转化为条件分数函数,从而实现条件生成,在求解反问题中(通俗的来说,解方程)有应用潜力。
在很多科学问题中,我们通过观测到的结果数据,反向求解方程中的系数等问题,称为反问题。反问题的常见形式是:
yδ=Ax+n,
其中n为观测噪声(通常假设为高斯噪声),A为正向算子,yδ为观测数据,x为待求解系数。比如在图像领域,去噪问题中,A为恒等算子;增加分辨率或补全图像问题中,A为某个投影算子。
见基于分数的生成方法(score-based)
在无条件生成中,我们要学习的数据初始分布为x(0)∼p0(x(0))=pdata(x(0)),通过如下方程:
dx=−2β(t)dt+β(t)dw,
实现从数据分布到标准正态分布的演化,其中T=1,β(0)=0.2,β(1)=10。并训练好分数函数∇xlogpt(x)后,通过逆时演化以下方程:
dx=[−2β(t)−∇x(t)logpt(x(t))]dt+β(t)dwˉ
实现从数据分布中采样。
在反问题中,我们通常会有观测数据yδ,以及相应的正向算子A。因此我们通常想要学习条件数据分布,即x(0)∼pdata(x(0)∣yδ)。其对应正向SDE仍为方程(1),也就是说,我们仍然可以从该条件数据分布,顺时演化到标准正态分布。然而,其对应反向SDE为:
dx=[−2β(t)−∇x(t)logpt(x(t)∣yδ)]dt+β(t)dwˉ
通过前一步训练,我们有分数函数∇xlogpt(x),但并没有上式中的条件分布函数。
于是我们将上述(3)式,通过贝叶斯公式转化为:
dx=[−2β(t)−(∇x(t)logpt(x(t))+∇x(t)logpt(yδ∣x(t)))]dt+β(t)dwˉ
其中,分数函数已有训练近似,我们只需计算得出∇x(t)logpt(yδ∣x(t))即可进行逆时演化。
注意到:
p(yδ∣x(t))=∫p(yδ∣x(0),x(t))p(x(0)∣x(t))dx(0)=∫p(yδ∣x(0))p(x(0)∣x(t))dx(0)=Ex(0)∼p(x(0)∣x(t))[p(yδ∣x(0))]
由于(详见 DPS 附录A):
x^(0):=E[x(0)∣x(t)]=αˉ(t)1(x(t)+(1−αˉ(t))∇x(t)logpt(x(t))≃αˉ(t)1(x(t)+(1−αˉ(t))sθ(x(t),t))
于是有:
p(yδ∣x(t))=Ex(0)∼p(x(0)∣x(t))[p(yδ∣x(0))]≃p(yδ∣Ex(0)∼p(x(0)∣x(t))[x(0)])=p(yδ∣x^(0))
(上式利用E[f(x)]≃f(Ex),这里有一个Jessen Gap,详见 DPS 附录A)
由于假设n为高斯噪声,于是(12)式可由正态分布概率函数计算,得到:
∇x(t)logp(y∣x(t))≃−σ21∇x(t)∥y−A(x^(0))∥22
其中,σ为n服从的正态分布的方差。由此,我们可以近似计算反向SDE(3),并从任一标准正态分布的采样出发,逆时演化得到条件数据分布的采样。
- 只需要训练一次无条件分数函数,理论上即可对任一算子进行上述操作进行反演,是一个无监督的训练模型
- 由于基于SDE,采样过程不稳定,结果时好时坏
- 需要进行大量的迭代过程,相比于有监督的端到端模型,速度较慢
- 由于没有数据的监督,相比于有监督的端到端模型,效果较差
- 可以考虑基于ODE的采样过程,增强采样稳定性
- 借助一些加速生成的模型(例如Consistency Model),提高采样速度