基于分数的生成方法
本文主要介绍score-based generative model的方法流程 · 6 min read
相比于DDPM等离散格式方法基于概率分布的方法(DDPM),该方法提出了连续意义下的生成过程,为后续DPS等方法的发展奠定了基础。该方法考虑连续情形的SDE,正向演变(顺着时间)会将初始数据分布演化为高斯分布;依照相应反向SDE,反向演变会将高斯分布演化为数据分布,从而实现在数据分布中采样。
我们的目标是建立一个连续扩散序列{x(t)}t=0T,使得x(0)服从数据分布,且x(T)服从某个容易采样的分布,比如标准高斯分布。于是我们考虑建立以下伊藤随机微分方程:
dx=f(x,t)dt+g(t)dw,
这里dw是标准布朗运动。有许多f,g的选择可以满足我们的上述要求,比如将DDPM连续化后得到如下随机微分方程(连续化过程见s0lu5lblzl4.feishu.cn):
dx=−2β(t)dt+β(t)dw,
就可以实现从数据分布到标准正态分布的演化,其中T=1,β(0)=0.2,β(1)=10.
以下我们记pt(x)为x(t)所服从的概率分布函数,pst(x(t)∣x(s))为x(s)到x(t)的转移概率核。
对每个SDE(1),都存在相应反SDE方程:
dx=[f(x,t)−g(t)2∇xlogpt(x)]dt+g(t)dwˉ,
其中,dwˉ为反向过程的标准布朗运动。如果我们能得到分数函数∇xlogpt(x),就可以沿着反SDE(3),从标准正态分布的采样出发,演化到数据分布中的采样。遗憾的是,解析意义下我们难以得到上述分数函数,不过,我们可以建立一个神经网络sθ(x(t),t)去拟合该分数函数。
在训练过程中,我们有一个自然的损失函数:
J1=Et{Ex(t)[∥sθ(x(t),t)−∇x(t)logp(x(t))∥22]}
但是,上述损失函数中的分数函数我们没法提前计算得知。幸运的是,上述损失函数可以进行转化,从而得到以下等价的损失函数(转化过程见附录A):
J2=Et{Ex(0)Ex(t)∣x(0)[∥sθ(x(t),t)−∇x(t)logp0t(x(t)∣x(0))∥22]}
至于训练样本生成,我们可以数据集中任意x(0)出发,正向演化到x(t),然后代入计算上述损失函数。这里的∇x(t)logp0t(x(t)∣x(0))可以是易于计算的。
至此,我们可以完成训练,并通过反向SDE方程进行采样。
注意到,上述采样方式依照一个SDE方程,相同的采样起点(标准正态分布采样)并不能得到相同的采样终点(数据分布采样)。这在生成中没有什么影响,但在条件生成(即生成满足某种条件的图片)时,会引入一些问题。于是,本文作者提出了一个ODE采样的方法。我们可以将上述反向SDE改写为如下ODE:
dx=[f(x,t)−2g(t)2∇xlogpt(x)]dt,
可以证明,在相同初始分布下(x(0)服从同一分布),任意时刻t,由(3)和(6)得到的x(t)服从同一分布(详细证明见附录B)。因此,(3)和(6)拥有相同的分数函数。我们可以通过SDE进行训练,然后将训练得到的分数函数代入(6)式中,通过上述ODE方程进行数据分布的采样,从而保证采样的稳定性。
J1J2=Et∼U(0,1), x(t)∼p(x(t))[∥sθ(x(t),t)−∇x(t)logp(x(t))∥22]=Et∼U(0,1), x(t)∼p(x(t)∣x(0)), x(0)∼p(x(0))[∥sθ(x(t),t)−∇x(t)logp(x(t)∣x(0))∥22].
由于t服从相同分布,因此我们考虑:
J^1J^2=Ex(t)∼p(x(t))[∥sθ(x(t),t)−∇x(t)logp(x(t))∥22]=Ex(t)∼p(x(t)∣x(0)), x(0)∼p(x(0))[∥sθ(x(t),t)−∇x(t)logp(x(t)∣x(0))∥22].
J^1J^2=Ex(t)∼p(x(t))[∥sθ(x(t))∥22]−S1(θ)+C1=Ex(t)∼p(x(t)∣x(0)), x(0)∼p(x(0))[∥sθ(x(t))∥22]−S2(θ)+C2,
S1(θ)S2(θ)=2Ex(t)∼p(x(t))[⟨sθ(x(t),t),∇x(t)logp(x(t))⟩]=2Ex(t)∼p(x(t)∣x(0)), x(0)∼p(x(0))[⟨sθ(x(t),t),∇x(t)logp(x(t)∣x(0))⟩].
21S1(θ)=∫x(t)p(x(t))⟨sθ(x(t),t),∇x(t)logp(x(t))⟩dx(t)=∫x(t)⟨sθ(x(t),t),∇x(t)p(x(t))⟩dx(t)=∫x(t)⟨sθ(x(t),t),∫x(0)p(x(0))∇x(t)p(x(t)∣x(0))dx(0)⟩dx(t)=∫x(t)⟨sθ(x(t),t),∫x(0)p(x(0))p(x(t)∣x(0))∇x(t)logp(x(t)∣x(0))dx(0)⟩dx(t)=∫x(t)∫x(0)p(x(0))p(x(t)∣x(0))⟨sθ(x(t),t),∇x(t)logp(x(t)∣x(0))⟩dx(0)dx(t)=Ex(t)∼p(x(t)∣x(0)), x(0)∼p(x(0))[⟨sθ(x(t),t),∇x(t)logp(x(t)∣x(0))⟩]=21S2(θ).
因此有J2=J1−C1+C2,两者具有相同的最优点,证毕。
我们考虑SDE(3),对应的Fokker-Planch方程为:
∂t∂p(x(t))=−∇x(t)⋅[f(x(t),t)p(x(t))]+21g2(t)∇x(t)⋅∇x(t)p(x(t)).
由于事实∇x(t)(logp(x(t)))p(x(t))=∇x(t)p(x(t)),我们有以下结果:
∂t∂p(x(t))=−∇x(t)⋅[f(x(t),t)p(x(t))−21g2(t)∇x(t)p(x(t))]=−∇x(t)⋅{[(f(x(t),t)−21g2(t)∇x(t)logp(x(t)))]p(x(t))}
对比两式,我们得到(8)对应方程:
dx=[f(x(t),t)−21g2(t)∇x(t)logp(x(t))]dt.
由于(8)是ODE,其对应反向方程仍为:
dx=[f(x(t),t)−21g2(t)∇x(t)logp(x(t))]dt.
证毕。