参考资料

数学基础

两个一元正态分布的 KL 散度计算推导: https://statproofbook.github.io/P/norm-kl.html

多元正态分布的 KL 散度可参考 wiki

概率空间

概率空间的严格定义: 概率空间是一个三元组 $(\Omega, \mathcal{F}, P)$, 其中

(1) $\Omega$ 是样本空间, 它是一个集合, 即随机事件的所有可能的取值. 例如在离散情形下, 掷骰子的样本空间是 $\Omega=\lbrace 1, 2, 3, 4, 5, 6\rbrace$; 在连续情形下, 一维正态分布的 $\Omega=\mathbb{R}$.

(2) $\mathcal{F}$ 代表事件域, 它需要是定义在 $\Omega$ 上的一个 $\sigma$ 代数: 换句话说, $\mathcal{F}$ 是一个集合, 它每个元素都是 $\Omega$ 的子集, 并且满足以下三个条件:

  • 空集属于 $\mathcal{F}$: $\Phi\in\mathcal{F}$
  • 如果一个集合 $A$ 属于 $\mathcal{F}$, 那么它相对于全集 $\Omega$ 的补集 $\bar{A}=\Omega\backslash A$ 也属于 $\mathcal{F}$
  • 如果可数个 $A_1, A_2, …$ 都属于 $\mathcal{F}$, 那么它们的并集 $\bigcup_{i=1}A_i$ 也属于 $\mathcal{F}$.

TODO: 什么是可数个集合的并集

(3) $P$ 代表概率测度, 它是定义在事件域上的一个函数, 且需要满足:

  • 对于任意 $A\in\mathcal{F}$, $P(A) \geq 0$
  • $P(\Omega)=1$
  • 假设可数个 $A_1,A_2, …$ 都属于 $\mathcal{F}$, 且两两不相交, 那么 $P(\bigcup_{i}A_i)=\sum_{i}{P(A_i)}$

在离散情形投骰子的情形下, 定义样本空间 $\Omega=\lbrace1, 2, 3, 4, 5, 6\rbrace$, 定义事件域 $\mathcal{F}$ 为 $\Omega$ 的所有可能的子集, 例如: 事件 $\lbrace1, 3, 6\rbrace$ 代表投出的骰子取值是 ${1, 3, 6}$ 之一. $P(\lbrace i\rbrace)=\frac{1}{6}, i\in\Omega$, 注意对于任意 $A\in\mathcal{F}$, $P(A)$ 就完全被确定了.

在连续情形的标准正态分布的情形下, 定义样本空间 $\Omega=\mathbb{R}$, 而定义事件域是 $\mathbb{R}$ 上的博雷尔集, 而概率 $P$ 的定义是:

\[P(A)=\int_{x\in A}{p(x)dx},\]

其中

\[p(x)=\frac{1}{\sqrt{2\pi}}\exp^{-\frac{x^2}{2}}\]

所谓博雷尔集, 有几种定义形式, 都对应着同样的东西:

(1) $\mathbb{R}$ 上可以拆解为可数个开区间的并集或者交集的集合全体构成的集合.

(2) 包含所有开区间的最小 $\sigma$ 代数

博雷尔集的一些性质是: 它包含所有的开集闭集(因此开区间和闭区间都在里面), 注意博雷尔集不止包含开集和闭集, 例如博雷尔集包含半开半闭区间, 但它既不是开集也不是闭集

所谓开集的一般性定义是 TODO, 在 $\mathbb{R}$ 上开集的定义是 TODO

所谓闭集的一般性定义是 TODO, 在 $\mathbb{R}$ 上闭集的定义是 TODO

在连续情形下, 一般是在 $\mathbb{R}$ 上定义概率密度函数, 而概率测度里的积分需要对任意博雷尔集上有定义, 严格的说, 这个积分是勒贝格积分而不是高等代数里的黎曼积分. 而同时满足这两个条件概率密度函数可以保证概率测度是良定义的, 即:

(1) $\int_{-\infty}^{+\infty}{p(x)}{dx}=1$

(2) $p(x)\geq 0, \forall x\in\mathbb{R}$

在其他的讨论里, 例如讨论期望和方差时, 还一般会假定概率密度函数的矩的存在性, 一般还会做一些关于概率密度函数连续性之类的假定. 但这些条件不是必须的, 甚至于不一定需要定义密度函数, 只需要概率测度 $P$ 有定义即可. 这些关于 $\mathcal{F}$, 以及勒贝格积分之类的讨论属于高等概率论的范畴

以下讨论主要是对离散型和连续型概率空间:

  • 离散型概率空间的概率测度可以完全由概率质量函数 (Probability Mass Function, PMF) 刻画: 一对一的关系
  • 连续型概率空间的概率测度可以由概率密度函数 (Probability Density Function, PDF) 刻画: 概率密度函数 -> 概率测度应该是单射, 但可能不一定是双射, TODO, 待确认?

条件分布, 联合分布, 边缘分布

记号上来说, 条件概率的记号是令人有些混淆的: 一般来说, 我们会用 $f(x)$ 和 $g(x)$ 来表示不同的函数, 因为用了 $f$ 和 $g$ 做区分. 然而在条件概率的情形下, $p(x,z)$ 与 $p(x)$ 代表了相同概率测度 $P$ 的联合与边缘概率, 这里的两个 $p$ 可能会有一定的混淆.

边缘概率与条件概率的联系:

\[\begin{align} p(x)&:=\int_z{p(x,z)dz}\\ &=\int_z{p(z)p(x|z)dz} \end{align}\]

正态分布及性质

一元正态分布

\[\mathcal{N}(\mu,\sigma^2)\sim \frac{1}{\sigma\sqrt{2\pi}}e^{-\frac{1}{2}\left(\frac{x-\mu}{\sigma}\right)^2}\]

特别地, 一元标准正态分布的概率密度函数是

\[\mathcal{N}(0, 1)\sim \frac{1}{\sqrt{2\pi}}e^{-\frac{1}{2}x^2}\]

多元正态分布

\[\mathcal{N}(\boldsymbol{\mu},\Sigma)\sim\frac{1}{\sqrt{2\pi}^k}\cdot\frac{1}{\sqrt{\text{det}(\Sigma)}}\exp\left(-\frac{1}{2}(\mathbf{x}-\boldsymbol{\mu})^T\Sigma^{-1}(\mathbf{x}-\boldsymbol{\mu})\right)\]

特别地, 多元标准正态分布地概率密度函数

\[\mathcal{N}(\mathbf{0},I)\sim \frac{1}{\sqrt{2\pi}^k}\exp(-\frac{1}{2}\mathbf{x}^T\mathbf{x})\]

两个多元正态分布的 KL 散度:

\[D_{KL}(\mathcal{N}(\boldsymbol{\mu}_1,\Sigma_1)\parallel\mathcal{N}(\boldsymbol{\mu}_2,\Sigma_2)) = \frac{1}{2}\left(\text{tr}(\Sigma_2^{-1}\Sigma_1)-k+(\boldsymbol{\mu}_2-\boldsymbol{\mu}_1)^T\Sigma_2^{-1}(\boldsymbol{\mu}_2-\boldsymbol{\mu}_1)+\ln\frac{\text{det}\Sigma_2}{\text{det}\Sigma_1}\right)\]

特别地:

\[D_{KL}(\mathcal{N}(\boldsymbol{\mu},\Sigma)\parallel\mathcal{N}(\mathbf{0},I)) = \frac{1}{2}\left(\text{tr}(\Sigma)-k+\boldsymbol{\mu}^T\boldsymbol{\mu}-\ln(\text{det}\Sigma)\right)\]

更特别地, 如果 $\Sigma$ 是对角阵, 且对角元素为 $\sigma_1^2,…,\sigma_k^2$, 那么:

\[D_{KL}(\mathcal{N}(\boldsymbol{\mu},\text{diag}(\sigma_1^2,...,\sigma_k^2))\parallel\mathcal{N}(\mathbf{0}, I)) = \frac{1}{2}\left(\sum_{i=1}^{k}{(\sigma_i^2-2\ln\sigma_i)}-k+\boldsymbol{\mu}^T\boldsymbol{\mu}\right)\]

反过来, 特别地

\[D_{KL}(\mathcal{N}(\mathbf{0},I) \parallel \mathcal{N}(\boldsymbol{\mu},\Sigma))=\frac{1}{2}\left(\text{tr}(\Sigma^{-1})-k+\boldsymbol{\mu}\Sigma^{-1}\boldsymbol{\mu}+\ln(\text{det}\Sigma)\right)\]

更特别地

\[D_{KL}(\mathcal{N}(\mathbf{0},I) \parallel \mathcal{N}(\boldsymbol{\mu},\text{diag}(\sigma_1^2,...,\sigma_k^2)))=\frac{1}{2}\left(\sum_{i=1}^{k}{(\frac{1}{\sigma_i^2}+2\ln\sigma_i)}-k+\sum_{i=1}^{k}\sigma_i^2\mu_i^2\right)\]

KL 散度

KL 散度在离散型和连续型概率空间上有定义, 在部分混合型分布上也可以有定义, 我们只考虑连续型和离散型的情况. KL 散度用于刻画两个分布间的“距离”, 具有非负性, 非对称性.

离散型

\[D_{KL}(P\parallel Q)=\sum_{x\in\Omega}{P(x)\ln\frac{P(x)}{Q(x)}}\]

其中 $\Omega$ 是样本空间, $P(x)$ 与 $Q(x)$ 是两个概率测度对应的概率质量函数

连续型

\[D_{KL}(P\parallel Q)=\int_{-\infty}^{+\infty}{p(x)\ln\frac{p(x)}{q(x)}}dx\]

其中 $p(x)$ 与 $q(x)$ 是两个概率测度对应的概率密度函数

不等式

假设概率空间是 $(\Omega_z \times \Omega_x, \mathcal{F}, P)$, $p(x,z)$ 是对应的联合概率密度函数, 对于每个特定的 $x$ 来说, 记 $p(z x)$ 是样本空间 $\Omega_z$ 上的边缘密度函数.

另一方面, 假设 $q(z;x)$ 也是定义在样本空间 $\Omega_z$ 上的密度函数, 注意我们这里只假设对u任意的 $x$, $q(z;x)$ 是一个密度函数, 而没有声明 $x$ 是否是随机变量, 因此也更没有联合概率的说法.

我们做如下推导, 对于任意的 $x$:

\[\begin{align} D_{KL}(q(z;x)\parallel p(z|x))&=\int{q(z;x)\ln\frac{q(z;x)}{p(z|x)}dz}\\ &=\int{q(z;x)\ln\frac{q(z;x)p(x)}{p(x,z)}dz}\\ &=\int{q(z;x)\left[\ln p(x)+\ln\frac{q(z;x)}{p(x,z)}\right]dz}\\ &=\ln p(x)+\int{q(z;x)\ln\frac{q(z;x)}{p(z)p(x|z)}dz}\\ &=\ln p(x)+\int{q(z;x)\left[\ln\frac{q(z;x)}{p(z)}-\ln p(x|z)\right]dz}\\ &=\ln p(x)+D_{KL}(q(z;x)\parallel p(z))-\mathbb{E}_{z\sim q(z;x)}\ln p(x|z) \end{align}\]

我们来解释下各项的含义:

  • 等号左侧是 $q(z;x)$ 与条件概率 $p(z x)$ 的 KL 散度.
  • 等号右侧第二项是 $q(z;x)$ 与边缘概率 $q(x)$ 的 KL 散度.
  • 等号右侧第三项是条件概率 $\ln p(x z)$ 作为一个对于 $z$ 的函数来说, 在分布 $q(z;x)$ 下的期望

以上推导是数学上的形式推导, 后续再套上实际含义进行分析.

VAE

问题定义

假设我们有一个训练数据集 $D=\lbrace\mathbf{x}i\rbrace{i=1}^{N}$, 其中 $\mathbf{x}_i\in\mathbb{R}^c$, 我们认为这些样本点是由一个分布 $p(\mathbf{x})$ 里采样而来的, 我们的目标是有个办法在这个分布中里进行采样.

给一个具体的例子: 给定 MNIST 数据集中的图像, 即 60000 张 $1 \times 28 \times 28$ 的图片, 我们认为这些图片作为 $c=784=1 \times 28 \times 28$ 维的向量, 是从一个分布中抽样而来的. 我们现在希望能随机生成一些满足这个分布的样本, 就是生成一些图像, 我们可以预期大概率是一些数字为 $0 \sim 9$ 的图片.

在我们上面的设定下, 优化的目标自然是:

\[\theta^*=\arg\max_{\theta} \prod_{i=1}^N p_\theta(\mathbf{x}_i)\]

注意我们这里假设了我们所有可能的分布由一组参数 $\theta$ 所刻画, 我们的目的是在这些所有可能的分布里找一个最优的 $\theta^*$.

首先, 常见的做法会变成优化这个问题

\[\theta^*=\arg\max_{\theta} \sum_{i=1}^N \ln p_\theta(\mathbf{x}_i)\]

所以我们应该怎么做呢?

思路探讨

可以直接定义一个神经网络来刻画概率密度函数, 这个神经网络的输入是 $c$ 维的, 输出是 1 维的, 并且需要满足概率密度函数的约束条件, 即非负性以及在样本空间上的概率积分为 1. 以我们举的 MNIST 的例子为例, 譬如说我们可以这样定义一个神经网络及优化策略:

import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(1, 32, kernel_size=3, padding=1),
    nn.BatchNorm2d(32),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    
    nn.Conv2d(32, 64, kernel_size=3, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.MaxPool2d(2),

    nn.Flatten(),
    nn.Linear(64 * 7 * 7, 512),
    nn.ReLU(),
    nn.Linear(512, 1),
    nn.ReLU(),  # 因为要保证非负性
)

# model(torch.rand(4, 1, 28, 28))  # output shape: (4, 1)

for batch_image in dataloader:
    # 通常会做些预处理, 例如减均值除方差, 总的来说输入是连续型的(但严格地说在某个区间之外概率密度应该要为 0)
    batch_input = transform(batch_image)
    loss = -model(batch_image).sum()
    loss.backward()
    optimizer.step()

这里有几个问题:

(1) 首先, 最要紧的是我们无法通过这种方式约束概率密度函数积分为 1.

(2) 其次, 假设我们确实找到了这个 $\theta^*$, 而且它也满足了概率密度函数积分为 1 的条件 (几乎不可能), 我们没有办法对这个分布进行抽样

(3) 最后还有个小问题, 其实我们这样定义的这一族函数里, 且满足概率密度函数积分为 1 的条件的函数, 有很大一部分在某个区间之外概率值不为 0. 这个小问题后续基本可以忽略, 因为作为概率密度函数, 其实在某个区间之外, 取值必然是接近 0 的, 不然就没法满足积分和为 1 的条件了.

于是, 有个做法是我们假设 $p_\theta(\mathbf{x})$ 是我们已知的一族常见分布, 而看起来在概率论中我们知道符合条件的分布是多元正态分布 $\mathcal{N}(\boldsymbol{\mu},\Sigma)$, 而正态分布完全由 $\boldsymbol{\mu},\Sigma$ 确定, 因此, 是我们只需要能找到最优的 $\mathcal{N}(\boldsymbol{\mu},\Sigma)$ 即可.

但是这同样存在几个问题:

(1) 首先是假设空间不够大: 我们对 $p_\theta(\mathbf{x})\sim \mathcal{N}(\boldsymbol{\mu}\theta,\Sigma\theta)$ 的这个假设其实不一定成立. 譬如说以 MNIST 数据集为例, 我们做这个假设: 数字 0 的图像满足某个专属于数字 0 的正态分布, 针对数字 1~9 同理(注意实际上在我们的问题设定下, 知道这批图像对应的标签以及标签的含义是不可能的, 这里属于上帝视角), 这看起来似乎是合理的. 这么一来,

\[p(\mathcal{x})=p(\mathcal{x}|z)p(z)\]
其中 $z$ 对应类别, 而 $p(\mathcal{x} z=i)$ 对应于每个数字的正态分布. 由于这里比较特殊, 我们的 $z$ 实际上是一个离散型均匀分布 (注意我们这么说仍然是开上帝视角在审视 MNIST 这个特殊的数据集), 这么一来 $p(\mathcal{x})$ 实际上并不是一个正态分布. 而在我们的思路下, 假定 $p_\theta(\mathbf{x})\sim \mathcal{N}(\boldsymbol{\mu}\theta,\Sigma\theta)$, 在这个分布集合里, 离真实 $p(\mathbf{x})$ 最近的分布可能都相去甚远.

(2) 正态分布的协方差矩阵想要用神经网络拟合是一件困难的事情, 直接的做法是我们需要一个神经网络, 输入是 $c$ 维的, 输出是 $c\times c$ 维的, 还要保证其半正定性, 是比较麻烦的. 当然, 我们可以先拟合 $\Sigma=XX^T$ 中的 $X$ 矩阵, 其中 $X$ 是无约束的 $c\times c$ 矩阵. 尽管如此, 在图像的例子里, $c$ 可能会达到上百万, 而 $c\times c$ 就更是不可想象了. 因此我们可能又不得不对协方差矩阵也做些假设, 例如他是对角矩阵. 但这么一来的话, 假设空间就变得更小了.

上面的思路虽然仍然不可行, 但我们受到 MNIST 例子的启发, 在那个特定的例子里, 如果我们确实有标签信息, 那么实际上已经有解法了, 就是假设 $p(\mathcal{x} z=i)\sim\mathcal{N}(\boldsymbol{\mu_i},\text{diag}(\boldsymbol{\sigma}_i))$. 然后用统计的办法简单估计 $p(z)$. 这样一来, $p(\mathbf{x})$ 有显式解, 并且对 $p(\mathbf{x})$ 的抽样也是简单直接的:
  • 先用多项分布在 $p(z)$ 中抽样出类别 $i$
  • 然后抽样 $\mathbf{x}\sim\mathcal{N}(\mathbf{0}, I)$
  • 最后做变换: $\mathbf{x}’=\boldsymbol{\sigma}_i\odot\mathbf{x}+\boldsymbol{\mu_i}$

TODO: 这么采样为什么是对的, 严格地数学证明.

VAE

由此延申开来, 我们引入一个隐变量 $\mathbf{z}\in\mathbb{R}^d$, 假设存在一个联合概率 $p’(\mathbf{x}, \mathbf{z})$, 使得边缘分布 $p’(\mathbf{x})$ 恰好为 $p(\mathbf{x})$, 后面我们就直接不区分 $p’$ 与 $p$, 而直接认为 $p(\mathbf{x}, \mathbf{z})$ 是联合分布.

如果我们这样做假设, 后面将看到对 $p(\mathbf{x} \mathbf{z};\theta)$ 的假设可以有其他选项:
\[\begin{align} p(\mathbf{z}) &\sim \mathcal{N}(\mathbf{0}, I)\\ p(\mathbf{x}|\mathbf{z};\theta) &\sim \mathcal{N}(\boldsymbol{\mu}(\mathbf{z};\theta),\text{diag}(\boldsymbol{\sigma}^2(\mathbf{z};\theta))) \end{align}\]

那么这样定义的联合概率密度函数组 $p(\mathbf{x}, \mathbf{z})$ 所对应的边缘分布 $p(\mathbf{x})$ 是否可以拟合足够多的函数呢? TODO

问题是针对这样的构想, 我们怎么求得 $p(\mathbf{x})$ 以及怎么做采样呢?

首先采样过程实际上是简单的 (TODO, 怎么证明):

  • 先在 $d$ 维空间中抽样 $\mathbf{z}\sim\mathcal{N}(\mathbf{0}, I)$
  • 然后在 $c$ 维空间中抽样 $\delta_x\sim\mathcal{N}(\mathbf{0}, I)$
  • 然后做变换 $\mathbf{x}=\boldsymbol{\sigma}(\mathbf{z};\theta)\cdot\delta_x+\boldsymbol{\mu}(\mathbf{z};\theta)$

但现在问题是如何求解 $p(\mathbf{x})$, 按定义:

\[p(\mathbf{x})=\int p(\mathbf{x}|\mathbf{z};\theta)p(\mathbf{z}) d\mathbf{z}\]

这个积分我们没法处理: 注意这里联合分布和对 $\mathbf{x}$ 的边缘分布其实都不一定是正态分布.

笔者浅见, 可能在数学上是不合理的: 但其实我们也许可以采几个样做估算, 譬如说在 $\mathbb{R}^d$ 中采样几个 $\mathbf{z}_k$, 用如下式子估算 $p(\mathbf{x})$:

\[p(\mathbf{x})=\text{mean}_{k}(p(\mathbf{x}|\mathbf{z}_k;\theta)p(\mathbf{z}_k))\]

但大概是这种估算太不精确了吧, 所以不能这么做

求解 $p(\mathbf{x})$ 的另一个思路是利用贝叶斯公式

\[p(\mathbf{x})=\frac{p(\mathbf{z})p(\mathbf{x}|\mathbf{z})}{p(\mathbf{z}|\mathbf{x})}\]
我们假设用一组分布 $q(\mathbf{z}; \mathbf{x},\phi)$ 来估计 $p(\mathbf{z} \mathbf{x})$, 对此, 我们似乎又没有什么好的选项, 只好又选择正态分布族 (后面马上会解释这个选择的好处):
\[q(\mathbf{z}; \mathbf{x}, \phi)\sim \mathcal{N}(\boldsymbol{\mu}(\mathbf{x};\phi),\text{diag}(\boldsymbol{\sigma}^2(\mathbf{x};\phi)))\]

利用 KL 散度的性质, 以下这行式子是关键:

\[D_{KL}\left[q(\mathbf{z};\mathbf{x},\phi)\parallel p(\mathbf{z}|\mathbf{x};\theta)\right]-\ln p(\mathbf{x};\theta)=D_{KL}\left[q(\mathbf{z};\mathbf{x},\phi)\parallel p(\mathbf{z})\right]-\mathbb{E}_{\mathbf{z}\sim q(\mathbf{z};\mathbf{x},\phi)}\ln p(\mathbf{x}|\mathbf{z};\theta)\]

注意, 我们的优化目标是 $\min_{\theta}\ln p(\mathbf{x},\theta)$, 通过以上等式, 我们转而优化它的上界, 也就是左侧的第一项总是大于 0 的, 因此也等价于最小化右侧.

这里说明一下, 为什么我们就这样忽略了左侧第一项呢, 主要原因还是左侧第一项无法计算, 原因是按照定义:

\[p(\mathbf{z}|\mathbf{x};\theta)=\frac{p(\mathbf{x}|\mathbf{z};\theta)p(\mathbf{z})}{p(\mathbf{x};\theta)}=\frac{p(\mathbf{x}|\mathbf{z};\theta)p(\mathbf{z})}{\int p(\mathbf{x}|\mathbf{z};\theta)p(\mathbf{z}) d\mathbf{z}}\]

其实就是分母我们无法计算, 这等于回到了我们之前遇到的难题. 因此更不必说计算左侧第一项的 KL 散度.

那么现在问题在于右侧怎么做优化呢? 首先, 右侧的第一项是两个正态分布的 KL 散度, 有显式的计算公式 (这也是为什么选择假设 $q(\mathbf{z};\mathbf{x},\phi)$ 是正态分布的原因, 因为这样可以显式计算):

\[D_{KL}\left[q(\mathbf{z};\mathbf{x},\phi)\parallel p(\mathbf{z})\right]=\frac{1}{2}\left(\sum_{i=1}^{d}{(\sigma_i^2(\mathbf{x};\phi)-\ln\sigma_i^2(\mathbf{x};\phi))}-d+ \boldsymbol{\mu}(\mathbf{x};\phi)^T\boldsymbol{\mu}(\mathbf{x};\phi)\right)\]

而右侧第二项通过采样来估计, 按理来说应该采样多个 $\mathbf{z}\sim q(\mathbf{z};\mathbf{x},\phi)$, 但 VAE 的具体实现上很多时候就仅用了一个采样点, 具体的做法采用了所谓的重参数技巧, 实际上也就是:

  • 从 $d$ 维标准正态分布里抽样出一个点 $\boldsymbol{\epsilon}\sim\mathcal N(\mathbf{0}, I)$
  • 应用变换计算: $\mathbf{z}=\boldsymbol{\sigma}(\mathbf{x};\phi)\odot\boldsymbol{\epsilon}+\boldsymbol{\mu}(\mathbf{x};\phi)$
  • 最后由此估算右侧第二项:
\[\begin{align} \ln p(\mathbf{x}|\mathbf{z};\theta)&=-\frac{c}{2}\ln(2\pi)-\sum_{i=1}^{c}\ln\sigma_i(\mathbf{z};\theta)-\frac{1}{2}[\mathbf{x}-\boldsymbol{\mu}(\mathbf{z};\theta)]^T\Sigma^{-1}(\mathbf{z};\theta)[\mathbf{x}-\boldsymbol{\mu}(\mathbf{z};\theta)] \\ &=-\frac{c}{2}\ln(2\pi)-\frac{1}{2}\sum_{i=1}^{c}\ln\sigma_i^2(\mathbf{z};\theta)-\frac{1}{2}\sum_{i=1}^{c}\frac{[x_i-\mu_i(\mathbf{z}; \theta)]^2}{\sigma_i^2(\mathbf{z};\theta)} \end{align}\]
备注: 在很多网上的实现里, 会将右侧第一项损失乘上 $\lambda$ 作为权重系数(一般小于 1, 例如 0.00025), 而右侧第二项损失被替换为 $\sum_{i=1}^{c}{[x_i-\mu_i(\mathbf{z};\theta)]^2}$. 这其实等价于左侧第一项不变, 右侧第二项的损失被直接替换为了 $\frac{1}{\lambda}*\sum_{i=1}^{c}{[x_i-\mu_i(\mathbf{z};\theta)]^2}$, 以至于由 $\mathbf{z}$ 生成 $\mathbf{x}$ 的神经网络只计算均值 $\boldsymbol{\mu}(\mathbf{z},\theta)$, 而忽略 $\boldsymbol{\sigma}^2(\mathbf{z},\theta)$. 而这么做的本质实际上是将 $p(\mathbf{x} \mathbf{z})$ 的分布假定为了 $\mathcal{N}(\boldsymbol{\mu}(\mathbf{z},\theta),\lambda I)$.

算法流程整理

综合以上, 将算法流程梳理如下:

输入:

  • 数据集 $\lbrace\mathbf{x}i\rbrace{i=1}^N$, 其中 $\mathbf{x}_i\in\mathbb{R}^c$
  • encoder: $\mathbb{R}^c\to(\mathbb{R}^d,\mathbb{R}^d)$, 分别代表 $\boldsymbol{\mu}(\mathbf{x};\theta)$ 和 $\ln(\boldsymbol{\sigma^2}(\mathbf{x};\theta))$.
  • decoder: $\mathbb{R}^d\to(\mathbb{R}^c,\mathbb{R}^c)$, 分别代表 $\boldsymbol{\mu}(\mathbf{z};\phi)$ 和 $\ln(\boldsymbol{\sigma^2}(\mathbf{z};\phi))$.

其中 $d$ 一般远小于 $c$.

训练流程:

对于单个样本 $\mathbf{x}j$, 首先从 $d$ 维 标准正态分布中随机抽样 $S$ 个点, 记为 $\lbrace\boldsymbol{\epsilon}\rbrace{s=1}^{S}$, 然后根据 encoder 计算:

\[\mathbf{z}_s=\boldsymbol{\sigma}(\mathbf{x}_j;\theta)\odot\boldsymbol{\epsilon}_s+\boldsymbol{\mu}(\mathbf{x}_j;\theta),\quad s=1,\ldots,S\]

损失函数按如下定义计算, 然后使用正常基于梯度的优化算法针对 batch 数据进行优化.

\[\begin{align} L_j =& \frac{1}{2}\left(\sum_{i=1}^{d}{(\sigma_i^2(\mathbf{x}_j;\phi)-\ln\sigma_i^2(\mathbf{x}_j;\phi))}-d+ \boldsymbol{\mu}(\mathbf{x}_j;\phi)^T\boldsymbol{\mu}(\mathbf{x}_j;\phi)\right)\\ &+\frac{1}{S}\sum_{s=1}^{S}{\left[\frac{k}{2}\ln(2\pi)+\frac{1}{2}\sum_{i=1}^{c}\ln\sigma_i^2(\mathbf{z}_s;\theta)+\frac{1}{2}\sum_{i=1}^{c}\frac{[x_{ji}-\mu_i(\mathbf{z}_s; \theta)]^2}{\sigma_i^2(\mathbf{z}_s;\theta)}\right]} \end{align}\]

生成

  • 先在 $d$ 维空间中抽样 $\mathbf{z}\sim\mathcal{N}(\mathbf{0}, I)$
  • 然后在 $c$ 维空间中抽样 $\delta_x\sim\mathcal{N}(\mathbf{0}, I)$
  • 然后做变换 $\mathbf{x}=\boldsymbol{\sigma}(\mathbf{z};\theta)\cdot\delta_x+\boldsymbol{\mu}(\mathbf{z};\theta)$

代码实现 (TODO, 不 work)

备注: 这里的实现与网上的大多数实现保持一致, 也就是将 $p(\mathbf{x} \mathbf{z})$ 的分布假定为了 $\mathcal{N}(\boldsymbol{\mu}(\mathbf{z},\theta),\lambda I)$, 其中 $\lambda$ 为超参数, 另外在训练好模型后, 生成数据是, 不加上 $\boldsymbol{\sigma}(\mathbf{z};\theta)\cdot\delta_x$ 这一项.
import math
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms


hidden_dim = 64
epoch_nums = 10
lam = 0.00025 * 28 * 28 / 2  # 0.098
device = "cuda:0"
global_iter_num = 0
print_every = 100

class VAEEncoder(nn.Module):
    def __init__(self, hidden_dim=64):
        super().__init__()
        self.hidden_dim = hidden_dim
        # self.base = nn.Sequential(
        #     nn.Conv2d(1, 32, kernel_size=3, padding=1),
        #     # nn.BatchNorm2d(32),
        #     nn.ReLU(),
        #     nn.MaxPool2d(kernel_size=2, stride=2),
            
        #     nn.Conv2d(32, 64, kernel_size=3, padding=1),
        #     # nn.BatchNorm2d(64),
        #     nn.ReLU(),
        #     nn.MaxPool2d(2),

        #     nn.Flatten(),
        #     nn.Linear(64 * 7 * 7, 512),
        #     nn.ReLU()
        # )
        self.base = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
        )

        self.mean_layer = nn.Linear(512, hidden_dim)
        self.log_var_layer = nn.Linear(512, hidden_dim)
    
    def forward(self, x):
        out = self.base(x)
        mean = self.mean_layer(out)
        log_var = self.log_var_layer(out)
        return mean, log_var
    
class VAEDecoder(nn.Module):
    def __init__(self, hidden_dim=64):
        super().__init__()
        self.hidden_dim = hidden_dim
        # self.base = nn.Sequential(
        #     nn.Linear(hidden_dim, 512),
        #     nn.ReLU(),
        #     nn.Linear(512, 64 * 7 * 7),
        #     nn.Unflatten(1, (64, 7, 7)),

        #     nn.Upsample(scale_factor=2, mode="nearest"),
        #     nn.ReLU(),
        #     # nn.BatchNorm2d(64),
        #     nn.Conv2d(64, 32, kernel_size=3, padding=1),
            
        #     nn.Upsample(scale_factor=2, mode="nearest"),
        #     nn.ReLU(),
        #     # nn.BatchNorm2d(32),
        #     nn.Conv2d(32, 1, kernel_size=3, padding=1),
        # )

        self.base = nn.Sequential(
            nn.Linear(hidden_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 28*28),
            nn.Unflatten(1, (1, 28, 28))
        )

    def forward(self, z):
        return self.base(z)


class VAE(nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        assert self.encoder.hidden_dim == self.decoder.hidden_dim
        self.hidden_dim = self.encoder.hidden_dim
        self.dim = 1 * 28 * 28

    def reparametrize(self, mean, log_var):
        z = torch.randn_like(mean)
        out = mean + torch.exp(0.5 ** log_var) * z
        return out

    def forward(self, x):
        mean, log_var = self.encoder(x)
        z = self.reparametrize(mean, log_var)
        out = self.decoder(z)
        return mean, log_var, out
    
    def get_loss(self, x, mean, log_var, out, lam):
        batch_size = x.shape[0]
        input_dim = x.numel() / batch_size
        
        reconstruction_loss = ((out - x) ** 2).flatten(1).sum(dim=1).mean() 
        # expectation_loss = reconstruction_loss + input_dim * math.log(2*math.pi*lam)
        expectation_loss = 0.5 * reconstruction_loss

        kld_loss = (torch.exp(log_var)-log_var-1+mean**2).flatten(1).sum(dim=1).mean()
        kld_loss = 0.5 * kld_loss
        weighted_kld_loss = lam * kld_loss
        
        loss = expectation_loss + weighted_kld_loss
        return {
            "total_loss": loss,
            "expectation_loss": expectation_loss,
            "weighted_kld_loss": weighted_kld_loss,
            "kld_loss": kld_loss,
            "reconstruction_loss": reconstruction_loss,
        }
    
    def sample(self, n):
        with torch.no_grad():
            z = torch.randn((n, self.hidden_dim)).to(device)  # TODO
            return self.decoder(z)

encoder = VAEEncoder(hidden_dim)
decoder = VAEDecoder(hidden_dim)
model = VAE(encoder, decoder).to(device)
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[1.])
    ]
)
dataset = torchvision.datasets.MNIST(root="./mnist_dataset", transform=transform, download=True)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
optimizer = torch.optim.Adam(model.parameters())

for epoch in range(epoch_nums):
    for batch_x, batch_y in dataloader:
        batch_x = batch_x.to(device)
        optimizer.zero_grad()
        mean, log_var, out = model(batch_x)
        losses = model.get_loss(batch_x, mean, log_var, out, lam)

        losses["total_loss"].backward()
        optimizer.step()
        global_iter_num += 1
        if global_iter_num % print_every == 0:
            print(f"{global_iter_num} steps")
            for key, value in losses.items():
                print(key, value.item())

# sampling
transforms.ToPILImage()((model.sample(1) + 0.5)[9])