Diffusion Models and Ornstein–Uhlenbeck Processes

In this post, we will give a brief introduction to diffusion models and discuss their connection to Ornstein–Uhlenbeck processes.

What are diffusion models?

Diffusion models are a family of generative models. Broadly, these are models that learn the data distribution by injecting noise into the data via a diffusion process, and then generate new samples by denoising.

Diffusion models currently achieve state-of-the-art performance in image generation and have also been applied to other tasks.

Ornstein–Uhlenbeck processes

The diffusion processes these models usually use are Ornstein–Uhlenbeck processes, defined by the SDE

$$ dX_t = \frac{g(t)^2}{2\sigma^2}(\mu - X_t)\,dt + g(t)\,dB_t, \tag{1} $$

where $\mu\in \mathbb R^d$, $B_t$ is a $d$-dimensional Brownian motion and $g(t)$ is a strictly positive function.

Unlike Brownian motion, the process $X_t$ has a stationary distribution $N(\mu,\sigma^2 I_d)$ and an explicit transition probability: for $s\leq t$, conditioned on $X_s = x_s$, $X_t$ is normally distributed with mean $m_{s,t}$ and variance $\sigma^2_{s,t}$, where

$$\begin{aligned} m_{s,t} &= \mu + (x_s-\mu)\exp\left(-\frac{1}{2\sigma^2}\int_s^t g(u)^2\,du\right),\\ \sigma^2_{s,t} &= \sigma^2\left(1-\exp\left(-\frac{1}{\sigma^2}\int_s^t g(u)^2\,du\right)\right). \end{aligned} \tag{2}$$
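
As a quick illustration, here is a minimal sketch (in Python with NumPy) of sampling $X_t$ given $X_s = x_s$ directly from the Gaussian transition (2). The schedule `g`, as well as `mu` and `sigma`, are illustrative choices, and the integral of $g(u)^2$ is approximated by a simple Riemann sum.

```python
import numpy as np

def g(t):
    # assumed noise schedule, strictly positive on [0, T]
    return 0.1 + 0.9 * t

def transition_params(x_s, s, t, mu, sigma, n_quad=1000):
    """Mean m_{s,t} and variance sigma^2_{s,t} from (2)."""
    u = np.linspace(s, t, n_quad, endpoint=False)
    integral = np.sum(g(u) ** 2) * (t - s) / n_quad  # Riemann sum of g(u)^2 du
    m = mu + (x_s - mu) * np.exp(-integral / (2 * sigma**2))
    var = sigma**2 * (1 - np.exp(-integral / sigma**2))
    return m, var

def sample_transition(x_s, s, t, mu, sigma, rng):
    # X_t | X_s = x_s is N(m_{s,t}, sigma^2_{s,t} I_d)
    m, var = transition_params(x_s, s, t, mu, sigma)
    return m + np.sqrt(var) * rng.standard_normal(x_s.shape)
```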

Fix an end time $T>0$ and let $\widetilde X_t = X_{T-t}$ be the time reversal of $X_t$. Then $\widetilde X_t$ satisfies the SDE

$$ d\widetilde{X}_t = \left[-\frac{g(T-t)^2}{2\sigma^2}(\mu-\widetilde{X}_t) + g(T-t)^2\nabla_x\log p_{T-t}(\widetilde{X}_t)\right]dt + g(T-t)\,dB'_t, \tag{3} $$

where $p_t$ is the density of $X_t$ and $B'_t$ is a different Brownian motion.

Noise injection & denoising

The framework of a diffusion model generally goes as follows: we first fix a sequence of times $0=t_0< t_1 <\dots< t_n = T$, along with $g(t)$, $\mu$ and $\sigma$. Then, given a point $x_0\in\mathbb R^d$ sampled from the data distribution $p$, we add Gaussian noise to the data iteratively by sampling $x_i := X_{t_i}$, where $X_t$ follows the SDE (1) started at $x_0$. Since the transition probability is Gaussian, we can sample $(x_i)$ efficiently.
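
Reusing `sample_transition` from the sketch above, the forward noising pass might look like the following; the time grid and the fake two-dimensional data point are again illustrative.

```python
rng = np.random.default_rng(0)
T, n = 1.0, 100
ts = np.linspace(0.0, T, n + 1)      # 0 = t_0 < t_1 < ... < t_n = T
mu, sigma = 0.0, 1.0                 # illustrative choices
x0 = rng.standard_normal(2)          # a stand-in for a data sample in R^d

xs = [x0]
for i in range(n):
    # the Markov property lets us step directly from t_i to t_{i+1}
    xs.append(sample_transition(xs[-1], ts[i], ts[i + 1], mu, sigma, rng))
```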

Assume $T$ is large enough that $x_n$ is approximately distributed according to the stationary Gaussian distribution $N(\mu,\sigma^2 I_d)$. The objective is then to learn the information of $p_t$ so that we can approximate the time-reversal process (3) with initial data $\widetilde{X}_0 \sim N(\mu,\sigma^2 I_d)$.

Score matching

Since $g(t)$, $\mu$ and $\sigma$ are known, we only need an estimate of $\nabla_x\log p_t$ to approximate the time-reversal process. The gradient of the log density, $\nabla \log p:\mathbb R^d\to \mathbb R^d$, is known as the score. The idea of learning the score of a data distribution rather than the distribution itself is usually referred to as score matching.

One major advantage of this idea is that we do not need to normalize our estimate: if we used a neural network $f(\theta;x)$ to approximate the density $p(x)$ directly, we would need to make sure that $\int f(\theta;x)\,dx=1$. Computing such an integral can be challenging or even intractable.

For diffusion models, our training objective will be given by

$$ \sum_{i=1}^n\lambda_i\,\mathbb{E}\left[\left\Vert f(\theta;x_i,t_i) - \nabla_x\log p_{t_i}(x_i)\right\Vert^2\right], \tag{4} $$

where $(\lambda_i)$ are positive weights summing to $1$, determined by the choices of $\sigma$ and $g(t)$.

Trainable objective

Since $\nabla_x\log p_t$ is not tractable, we need to convert (4) into a trainable objective. To do this, we note that

$$\begin{aligned} &\mathbb{E}\left[\left\Vert f(\theta;x_i,t_i) - \nabla_x\log p_{t_i}(x_i)\right\Vert^2\right]\\ =~&\mathbb{E}\left[\left\Vert f(\theta;x_i,t_i)\right\Vert^2\right] - 2\,\mathbb{E}\left[f(\theta;x_i,t_i)\cdot\nabla_x\log p_{t_i}(x_i)\right] + \text{Const}\\ =~&\mathbb{E}\left[\left\Vert f(\theta;x_i,t_i)\right\Vert^2\right] - 2\int \left(f(\theta;x_i,t_i)\cdot\nabla_x\log p_{t_i}(x_i) \right)p_{t_i}(x_i)\,dx_i+ \text{Const}\\ =~&\mathbb{E}\left[\left\Vert f(\theta;x_i,t_i)\right\Vert^2\right] - 2\int \left(f(\theta;x_i,t_i)\cdot \frac{1}{p_{t_i}(x_i)}\nabla_x p_{t_i}(x_i) \right)p_{t_i}(x_i)\,dx_i+ \text{Const}\\ =~&\mathbb{E}\left[\left\Vert f(\theta;x_i,t_i)\right\Vert^2\right] - 2\int f(\theta;x_i,t_i)\cdot \nabla_x p_{t_i}(x_i) \,dx_i+ \text{Const}. \end{aligned}$$

Using the fact that

$$ p_{t_i}(x_i) = \int p_{0,t_i}(x_i|x_0)\,p_0(x_0)\,dx_0, $$

we can run the same computation in reverse, now conditioned on $x_0$, to conclude that

$$ \mathbb{E}\left[\left\Vert f(\theta;x_i,t_i) - \nabla_x\log p_{t_i}(x_i)\right\Vert^2\right] = \mathbb{E}\left[\left\Vert f(\theta;x_i,t_i) - \nabla_x\log p_{0,t_i}(x_i|x_0)\right\Vert^2\right] + \text{Const}. \tag{5} $$

The expectation on the right-hand side of (5) is particularly nice, as we know $x_i$ conditioned on $x_0$ is normally distributed with mean and variance given explicitly by (2). Therefore, we can set our loss function to be

$$\begin{aligned} &\sum_{i=1}^n\lambda_i\,\mathbb{E}\left[\left\Vert f(\theta;x_i,t_i) - \nabla_x\log p_{0,t_i}(x_i|x_0)\right\Vert^2\right]\\ =~&\sum_{i=1}^n\lambda_i\,\mathbb{E}\left[\left\Vert f(\theta;x_i,t_i) + \frac{x_i-m_{0,t_i}(x_0)}{\sigma^2_{0,t_i}}\right\Vert^2\right]. \end{aligned} \tag{6}$$
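
In code, one batch of the loss (6) might look like the following sketch in PyTorch. The network `model`, and the helpers `m0t` and `sigma2_0t` computing $m_{0,t}$ and $\sigma^2_{0,t}$ from (2), are hypothetical names assumed given.

```python
import torch

def dsm_loss(model, x0, t):
    """Denoising score-matching loss (6) for one batch x0 at times t."""
    m = m0t(x0, t)        # mean m_{0,t}(x0) of x_t given x_0, from (2)
    var = sigma2_0t(t)    # variance sigma^2_{0,t} of x_t given x_0, from (2)
    x_t = m + var.sqrt() * torch.randn_like(x0)
    # score of the Gaussian transition: nabla_x log p_{0,t}(x_t|x_0) = -(x_t - m)/var
    target = -(x_t - m) / var
    return ((model(x_t, t) - target) ** 2).sum(dim=-1).mean()
```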

Training

To sample training data given the loss function (6), we can first sample an initial data point and a random index

$$ x_0\sim p,\qquad I \sim \sum_{i=1}^n \lambda_i \delta_i. $$

We can then sample $x_I$ using (2) and compute the gradient of the per-sample loss

$$ \left\Vert f(\theta;x_I,t_I) + \frac{x_I-m_{0,t_I}(x_0)}{\sigma^2_{0,t_I}}\right\Vert^2. $$
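
A single training step then combines this sampling rule with the `dsm_loss` sketch above. Here `data_loader`, `model`, the weights `lambdas` and the time grid `ts` (as tensors) are assumed given; the $\lambda$-weighting enters through the distribution of the random index.

```python
opt = torch.optim.Adam(model.parameters(), lr=1e-4)

x0 = next(iter(data_loader))                        # x0 ~ p
# I ~ sum_i lambda_i delta_i, one index per element of the batch
I = torch.multinomial(lambdas, x0.shape[0], replacement=True)
t = ts[I].unsqueeze(-1)                             # t_I for each sample

opt.zero_grad()
loss = dsm_loss(model, x0, t)
loss.backward()
opt.step()
```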

Conditioning

In many practical applications of generative models, instead of simply sampling from the data distribution $p$, we want to generate data with certain prescribed properties. For example, in image generation we may want an image that fits a given text description.

Mathematically, this means that we want to learn the conditional distribution $p(\cdot|\ell)$, where $\ell$ is an additional input. In this case, we can still use our noise injection and denoising scheme, but the score $\nabla_x \log p_t(x)$ needs to be replaced with the conditional score $\nabla_x \log p_t(x|\ell)$.

Bayes' rule implies that

$$ \nabla_x\log p_t(x|\ell) = \nabla_x\log p_t(x) + \nabla_x \log p_t(\ell|x). $$

Hence, given a good estimate of the unconditional score, it suffices to learn the information of $\nabla_x \log p_t(\ell|x)$. We refer to this method of learning the conditional score as gradient guidance.
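
As a sketch, gradient guidance just adds the two terms of the Bayes decomposition above; `score_model` and `log_p_label` (a model of $\log p_t(\ell|x)$) are hypothetical names assumed given.

```python
import torch

def guided_score(score_model, log_p_label, x, t, label):
    # nabla_x log p_t(x | l) = nabla_x log p_t(x) + nabla_x log p_t(l | x)
    x = x.detach().requires_grad_(True)
    grad = torch.autograd.grad(log_p_label(x, t, label).sum(), x)[0]
    return score_model(x, t) + grad
```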

Sampling

Once we have trained the network $f(\theta;x,t)\approx \nabla_x\log p_t(x)$, we can generate new data by sampling $\widetilde{x}_0$ from $N(\mu,\sigma^2 I_d)$ and then simulating either the SDE (3) or the ODE

$$ d\widetilde{x}_t = \left[-\frac{g(T-t)^2}{2\sigma^2}(\mu-\widetilde{x}_t) + \frac{1}{2}g(T-t)^2\nabla_x\log p_{T-t}(\widetilde{x}_t)\right]dt. $$

The two processes give the same marginal distributions for $\widetilde{x}_t$. Empirically, the SDE tends to give samples of better quality, while the ODE converges faster.
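
For concreteness, here is a minimal Euler–Maruyama sketch of the reverse SDE (3), back in NumPy; dropping the noise term and halving the score coefficient gives the Euler scheme for the ODE. A function `score(x, t)` approximating $\nabla_x\log p_t(x)$, together with `g`, `mu`, `sigma` and `T`, are assumed given.

```python
import numpy as np

def reverse_sde_sample(score, g, mu, sigma, T, n_steps, d, rng):
    dt = T / n_steps
    x = mu + sigma * rng.standard_normal(d)   # initial data ~ N(mu, sigma^2 I_d)
    for k in range(n_steps):
        s = k * dt                            # reverse time; forward time is T - s
        g2 = g(T - s) ** 2
        drift = -g2 / (2 * sigma**2) * (mu - x) + g2 * score(x, T - s)
        x = x + drift * dt + np.sqrt(g2 * dt) * rng.standard_normal(d)
    return x
```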