参考:
Diffusion模型能做什么?
DALL·E介绍
与GAN对比
- 孰是孰非还得等等,目前来看主流的生成模型已经不怎么用GAN了,最新的都是Diffusion Model
- GAN要训练俩网络,感觉难度 较大,容易不收敛,而且多样性较差,只关注能骗过判别器就行(容易走捷径、或走歪路)
- Diffusion MOdel用一种更简单的方法来诠释了生成模型该如何学习已经生成,其实感觉更简单
相关论文
- 主要基于这个来讲:DDPM论文连接:[Ho et al.,2020]
(DDPM的全程是 Denoising Diffusion Probabilistic Model,最早的引用是2015年了)
- 前向过程,就是不断往输入数据加噪声,最快快变成了个纯噪声
- 每一个时刻都要添加高斯噪声,后一时刻都是由前一刻增加噪声得到
- 其实这个过程可以看成我们不断构建标签(噪声)的过程,后续会用到

前向过程的公式:(杂点的增加应该越来越多)
at=1−βtxt=αtxt−1+1−αtz1β要越来越大(论文中0.0001到0.002,从而α会越来越小)一开始加点噪就有效果,越往后得加噪越多才行递归公式xt−1是前一刻的状态,z1表示噪音,后者权重越来越大
递归有点像RNN,有一些缺点,有没有方法不递归直接计算?
简化递归公式(推导):
∵xt=∵xt−1= ∴xt==== 其中,αˉt=αtxt−1+1−αtz1αt−1xt−2+1−αt−1z2αt(αt−1xt−2+1−αt−1z2)+1−αtz1αtαt−1xt−2+(αt(1−αt−1)z2)+1−αtz1)αtαt−1xt−2+1−αtαt−1zˉ2αˉtx0+1−αˉtzˉtαt⋅αt−1⋯α0Z1和Z2都浮肿高斯,然后可以进行化简Z1:N (0,1−αt)Z2:N (0,at(1−αt−1))N(0,σ12Ⅰ)+N(0,σ22Ⅰ)∼N(0,(σ12+σ22)Ⅰ)(累乘)

- Q:x0能不通过递归直接变成xt,那xt能直接变成x0吗?
- A:不行,公式上推导不出来,需要迭代
(化简过程需要 贝叶斯公式。另外说一下,原始论文的推导不是用贝叶斯公式开始的,是从概率论角度说的,难度可能会更大一点)
q(xt−1∣xt,x0)=q(xt∣xt−1,x0)q(xt∣x0)q(xt−1∣x0) 化简思路:使用贝叶斯公式q(xt−1∣x0)=q(xt∣x0)=q(xt−1∣xt,x0)=aˉt−1x0+1−aˉt−1zaˉtx0+1−aˉtzaˉtxt−1+1−aˉtz∼N (均值μ, 方差σ2)∼N (aˉt−1x0, 1−aˉt−1)∼N (aˉtx0, 1−aˉt)∼N (aˉtxt−1, 1−aˉt)
化简过程
这三项都通过前向过程能算出来,分布也列出来了∝exp(−21(βt(xt−atxt−1)2+1−aˉt−1(xt−1−aˉt−1x0)2+1−aˉt(xt−aˉtx0)2))因为exp运算。把标准正态分布展开后,乘法就相当于加,除法就相当于减,把他们汇总 展开∝ = = exp(−2σ2(x−μ)2)= 所以μ~t(xt,x0)= 又xo= exp(−21(βt(xt−atxt−1)2+1−aˉt−1(xt−1−aˉt−1x0)2+1−aˉt(xt−aˉtx0)2))exp(−21(……)),拆开平方exp(−21((βtat+1−aˉtt−11)xt−12−(βt2atxt+1−aˉt−12aˉt−1x0)xt−1+C(xt,x0)))C是常数项。这个任务中,核心就是求跟Xt−1有关的,其他的都不关心。展开是为了给xt−1配方exp(−21(σ21x2−σ22μx+σ2μ2)),这样能得到均值和方差1−aˉtat(1−aˉt−1)xt+1−aˉtaˉt−1βtx0,配完化简的结果aˉt1(xt−1−aˉtzt),之前是已知x0求xt,这条公式逆了一下
得到最终结果
μ~t=at1(xt−1−aˉtβtzˉt)
其中,Zt是什么?Zt是我们估计的每个时刻的噪声
- 这东西看起来无法直接求解,只能训练一个模型来计算
- 这些相关论文里竟然都用Unet这种结构来玩的,可能是编码和解码看着比较舒服
- 模型的输入参数有两个,分别是当前时刻的分布和时刻t
把两个阶段汇总到一起,就是终极流程图了

代码
import math
from inspect import isfunction
from funtools import partial
%matplotlib inline
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from einops import rearrange
import torch
from torch import nn, einsum
import torch.nn.functional as F
辅助函数
def exists(x):
return x is not None
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
class Residual(nn.Module):
……
……