(P0) VAE 浅析
参考资料
- https://borisburkov.net/2022-12-31-1/
- https://lilianweng.github.io/posts/2018-08-12-vae/
- https://spaces.ac.cn/archives/5253
- 原始论文 Auto-Encoding Variational Bayes https://arxiv.org/abs/1312.6114
- Tutorial on Variational Autoencoders https://arxiv.org/abs/1606.05908
数学基础
两个一元正态分布的 KL 散度计算推导: https://statproofbook.github.io/P/norm-kl.html
多元正态分布的 KL 散度可参考 wiki
概率空间
概率空间的严格定义: 概率空间是一个三元组 $(\Omega, \mathcal{F}, P)$, 其中
(1) $\Omega$ 是样本空间, 它是一个集合, 即随机事件的所有可能的取值. 例如在离散情形下, 掷骰子的样本空间是 $\Omega=\lbrace 1, 2, 3, 4, 5, 6\rbrace$; 在连续情形下, 一维正态分布的 $\Omega=\mathbb{R}$.
(2) $\mathcal{F}$ 代表事件域, 它需要是定义在 $\Omega$ 上的一个 $\sigma$ 代数: 换句话说, $\mathcal{F}$ 是一个集合, 它每个元素都是 $\Omega$ 的子集, 并且满足以下三个条件:
- 空集属于 $\mathcal{F}$: $\Phi\in\mathcal{F}$
- 如果一个集合 $A$ 属于 $\mathcal{F}$, 那么它相对于全集 $\Omega$ 的补集 $\bar{A}=\Omega\backslash A$ 也属于 $\mathcal{F}$
- 如果可数个 $A_1, A_2, …$ 都属于 $\mathcal{F}$, 那么它们的并集 $\bigcup_{i=1}A_i$ 也属于 $\mathcal{F}$.
TODO: 什么是可数个集合的并集
(3) $P$ 代表概率测度, 它是定义在事件域上的一个函数, 且需要满足:
- 对于任意 $A\in\mathcal{F}$, $P(A) \geq 0$
- $P(\Omega)=1$
- 假设可数个 $A_1,A_2, …$ 都属于 $\mathcal{F}$, 且两两不相交, 那么 $P(\bigcup_{i}A_i)=\sum_{i}{P(A_i)}$
在离散情形投骰子的情形下, 定义样本空间 $\Omega=\lbrace1, 2, 3, 4, 5, 6\rbrace$, 定义事件域 $\mathcal{F}$ 为 $\Omega$ 的所有可能的子集, 例如: 事件 $\lbrace1, 3, 6\rbrace$ 代表投出的骰子取值是 ${1, 3, 6}$ 之一. $P(\lbrace i\rbrace)=\frac{1}{6}, i\in\Omega$, 注意对于任意 $A\in\mathcal{F}$, $P(A)$ 就完全被确定了.
在连续情形的标准正态分布的情形下, 定义样本空间 $\Omega=\mathbb{R}$, 而定义事件域是 $\mathbb{R}$ 上的博雷尔集, 而概率 $P$ 的定义是:
\[P(A)=\int_{x\in A}{p(x)dx},\]其中
\[p(x)=\frac{1}{\sqrt{2\pi}}\exp^{-\frac{x^2}{2}}\]所谓博雷尔集, 有几种定义形式, 都对应着同样的东西:
(1) $\mathbb{R}$ 上可以拆解为可数个开区间的并集或者交集的集合全体构成的集合.
(2) 包含所有开区间的最小 $\sigma$ 代数
博雷尔集的一些性质是: 它包含所有的开集和闭集(因此开区间和闭区间都在里面), 注意博雷尔集不止包含开集和闭集, 例如博雷尔集包含半开半闭区间, 但它既不是开集也不是闭集
所谓开集的一般性定义是 TODO, 在 $\mathbb{R}$ 上开集的定义是 TODO
所谓闭集的一般性定义是 TODO, 在 $\mathbb{R}$ 上闭集的定义是 TODO
在连续情形下, 一般是在 $\mathbb{R}$ 上定义概率密度函数, 而概率测度里的积分需要对任意博雷尔集上有定义, 严格的说, 这个积分是勒贝格积分而不是高等代数里的黎曼积分. 而同时满足这两个条件概率密度函数可以保证概率测度是良定义的, 即:
(1) $\int_{-\infty}^{+\infty}{p(x)}{dx}=1$
(2) $p(x)\geq 0, \forall x\in\mathbb{R}$
在其他的讨论里, 例如讨论期望和方差时, 还一般会假定概率密度函数的矩的存在性, 一般还会做一些关于概率密度函数连续性之类的假定. 但这些条件不是必须的, 甚至于不一定需要定义密度函数, 只需要概率测度 $P$ 有定义即可. 这些关于 $\mathcal{F}$, 以及勒贝格积分之类的讨论属于高等概率论的范畴
以下讨论主要是对离散型和连续型概率空间:
- 离散型概率空间的概率测度可以完全由概率质量函数 (Probability Mass Function, PMF) 刻画: 一对一的关系
- 连续型概率空间的概率测度可以由概率密度函数 (Probability Density Function, PDF) 刻画: 概率密度函数 -> 概率测度应该是单射, 但可能不一定是双射, TODO, 待确认?
条件分布, 联合分布, 边缘分布
记号上来说, 条件概率的记号是令人有些混淆的: 一般来说, 我们会用 $f(x)$ 和 $g(x)$ 来表示不同的函数, 因为用了 $f$ 和 $g$ 做区分. 然而在条件概率的情形下, $p(x,z)$ 与 $p(x)$ 代表了相同概率测度 $P$ 的联合与边缘概率, 这里的两个 $p$ 可能会有一定的混淆.
边缘概率与条件概率的联系:
\[\begin{align} p(x)&:=\int_z{p(x,z)dz}\\ &=\int_z{p(z)p(x|z)dz} \end{align}\]正态分布及性质
一元正态分布
\[\mathcal{N}(\mu,\sigma^2)\sim \frac{1}{\sigma\sqrt{2\pi}}e^{-\frac{1}{2}\left(\frac{x-\mu}{\sigma}\right)^2}\]特别地, 一元标准正态分布的概率密度函数是
\[\mathcal{N}(0, 1)\sim \frac{1}{\sqrt{2\pi}}e^{-\frac{1}{2}x^2}\]多元正态分布
\[\mathcal{N}(\boldsymbol{\mu},\Sigma)\sim\frac{1}{\sqrt{2\pi}^k}\cdot\frac{1}{\sqrt{\text{det}(\Sigma)}}\exp\left(-\frac{1}{2}(\mathbf{x}-\boldsymbol{\mu})^T\Sigma^{-1}(\mathbf{x}-\boldsymbol{\mu})\right)\]特别地, 多元标准正态分布地概率密度函数
\[\mathcal{N}(\mathbf{0},I)\sim \frac{1}{\sqrt{2\pi}^k}\exp(-\frac{1}{2}\mathbf{x}^T\mathbf{x})\]两个多元正态分布的 KL 散度:
\[D_{KL}(\mathcal{N}(\boldsymbol{\mu}_1,\Sigma_1)\parallel\mathcal{N}(\boldsymbol{\mu}_2,\Sigma_2)) = \frac{1}{2}\left(\text{tr}(\Sigma_2^{-1}\Sigma_1)-k+(\boldsymbol{\mu}_2-\boldsymbol{\mu}_1)^T\Sigma_2^{-1}(\boldsymbol{\mu}_2-\boldsymbol{\mu}_1)+\ln\frac{\text{det}\Sigma_2}{\text{det}\Sigma_1}\right)\]特别地:
\[D_{KL}(\mathcal{N}(\boldsymbol{\mu},\Sigma)\parallel\mathcal{N}(\mathbf{0},I)) = \frac{1}{2}\left(\text{tr}(\Sigma)-k+\boldsymbol{\mu}^T\boldsymbol{\mu}-\ln(\text{det}\Sigma)\right)\]更特别地, 如果 $\Sigma$ 是对角阵, 且对角元素为 $\sigma_1^2,…,\sigma_k^2$, 那么:
\[D_{KL}(\mathcal{N}(\boldsymbol{\mu},\text{diag}(\sigma_1^2,...,\sigma_k^2))\parallel\mathcal{N}(\mathbf{0}, I)) = \frac{1}{2}\left(\sum_{i=1}^{k}{(\sigma_i^2-2\ln\sigma_i)}-k+\boldsymbol{\mu}^T\boldsymbol{\mu}\right)\]反过来, 特别地
\[D_{KL}(\mathcal{N}(\mathbf{0},I) \parallel \mathcal{N}(\boldsymbol{\mu},\Sigma))=\frac{1}{2}\left(\text{tr}(\Sigma^{-1})-k+\boldsymbol{\mu}\Sigma^{-1}\boldsymbol{\mu}+\ln(\text{det}\Sigma)\right)\]更特别地
\[D_{KL}(\mathcal{N}(\mathbf{0},I) \parallel \mathcal{N}(\boldsymbol{\mu},\text{diag}(\sigma_1^2,...,\sigma_k^2)))=\frac{1}{2}\left(\sum_{i=1}^{k}{(\frac{1}{\sigma_i^2}+2\ln\sigma_i)}-k+\sum_{i=1}^{k}\sigma_i^2\mu_i^2\right)\]KL 散度
KL 散度在离散型和连续型概率空间上有定义, 在部分混合型分布上也可以有定义, 我们只考虑连续型和离散型的情况. KL 散度用于刻画两个分布间的“距离”, 具有非负性, 非对称性.
离散型
\[D_{KL}(P\parallel Q)=\sum_{x\in\Omega}{P(x)\ln\frac{P(x)}{Q(x)}}\]其中 $\Omega$ 是样本空间, $P(x)$ 与 $Q(x)$ 是两个概率测度对应的概率质量函数
连续型
\[D_{KL}(P\parallel Q)=\int_{-\infty}^{+\infty}{p(x)\ln\frac{p(x)}{q(x)}}dx\]其中 $p(x)$ 与 $q(x)$ 是两个概率测度对应的概率密度函数
不等式
假设概率空间是 $(\Omega_z \times \Omega_x, \mathcal{F}, P)$, $p(x,z)$ 是对应的联合概率密度函数, 对于每个特定的 $x$ 来说, 记 $p(z | x)$ 是样本空间 $\Omega_z$ 上的边缘密度函数. |
另一方面, 假设 $q(z;x)$ 也是定义在样本空间 $\Omega_z$ 上的密度函数, 注意我们这里只假设对u任意的 $x$, $q(z;x)$ 是一个密度函数, 而没有声明 $x$ 是否是随机变量, 因此也更没有联合概率的说法.
我们做如下推导, 对于任意的 $x$:
\[\begin{align} D_{KL}(q(z;x)\parallel p(z|x))&=\int{q(z;x)\ln\frac{q(z;x)}{p(z|x)}dz}\\ &=\int{q(z;x)\ln\frac{q(z;x)p(x)}{p(x,z)}dz}\\ &=\int{q(z;x)\left[\ln p(x)+\ln\frac{q(z;x)}{p(x,z)}\right]dz}\\ &=\ln p(x)+\int{q(z;x)\ln\frac{q(z;x)}{p(z)p(x|z)}dz}\\ &=\ln p(x)+\int{q(z;x)\left[\ln\frac{q(z;x)}{p(z)}-\ln p(x|z)\right]dz}\\ &=\ln p(x)+D_{KL}(q(z;x)\parallel p(z))-\mathbb{E}_{z\sim q(z;x)}\ln p(x|z) \end{align}\]我们来解释下各项的含义:
-
等号左侧是 $q(z;x)$ 与条件概率 $p(z x)$ 的 KL 散度. - 等号右侧第二项是 $q(z;x)$ 与边缘概率 $q(x)$ 的 KL 散度.
-
等号右侧第三项是条件概率 $\ln p(x z)$ 作为一个对于 $z$ 的函数来说, 在分布 $q(z;x)$ 下的期望
以上推导是数学上的形式推导, 后续再套上实际含义进行分析.
VAE
问题定义
假设我们有一个训练数据集 $D=\lbrace\mathbf{x}i\rbrace{i=1}^{N}$, 其中 $\mathbf{x}_i\in\mathbb{R}^c$, 我们认为这些样本点是由一个分布 $p(\mathbf{x})$ 里采样而来的, 我们的目标是有个办法在这个分布中里进行采样.
给一个具体的例子: 给定 MNIST 数据集中的图像, 即 60000 张 $1 \times 28 \times 28$ 的图片, 我们认为这些图片作为 $c=784=1 \times 28 \times 28$ 维的向量, 是从一个分布中抽样而来的. 我们现在希望能随机生成一些满足这个分布的样本, 就是生成一些图像, 我们可以预期大概率是一些数字为 $0 \sim 9$ 的图片.
在我们上面的设定下, 优化的目标自然是:
\[\theta^*=\arg\max_{\theta} \prod_{i=1}^N p_\theta(\mathbf{x}_i)\]注意我们这里假设了我们所有可能的分布由一组参数 $\theta$ 所刻画, 我们的目的是在这些所有可能的分布里找一个最优的 $\theta^*$.
首先, 常见的做法会变成优化这个问题
\[\theta^*=\arg\max_{\theta} \sum_{i=1}^N \ln p_\theta(\mathbf{x}_i)\]所以我们应该怎么做呢?
思路探讨
可以直接定义一个神经网络来刻画概率密度函数, 这个神经网络的输入是 $c$ 维的, 输出是 1 维的, 并且需要满足概率密度函数的约束条件, 即非负性以及在样本空间上的概率积分为 1. 以我们举的 MNIST 的例子为例, 譬如说我们可以这样定义一个神经网络及优化策略:
import torch
import torch.nn as nn
model = nn.Sequential(
nn.Conv2d(1, 32, kernel_size=3, padding=1),
nn.BatchNorm2d(32),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(32, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Flatten(),
nn.Linear(64 * 7 * 7, 512),
nn.ReLU(),
nn.Linear(512, 1),
nn.ReLU(), # 因为要保证非负性
)
# model(torch.rand(4, 1, 28, 28)) # output shape: (4, 1)
for batch_image in dataloader:
# 通常会做些预处理, 例如减均值除方差, 总的来说输入是连续型的(但严格地说在某个区间之外概率密度应该要为 0)
batch_input = transform(batch_image)
loss = -model(batch_image).sum()
loss.backward()
optimizer.step()
这里有几个问题:
(1) 首先, 最要紧的是我们无法通过这种方式约束概率密度函数积分为 1.
(2) 其次, 假设我们确实找到了这个 $\theta^*$, 而且它也满足了概率密度函数积分为 1 的条件 (几乎不可能), 我们没有办法对这个分布进行抽样
(3) 最后还有个小问题, 其实我们这样定义的这一族函数里, 且满足概率密度函数积分为 1 的条件的函数, 有很大一部分在某个区间之外概率值不为 0. 这个小问题后续基本可以忽略, 因为作为概率密度函数, 其实在某个区间之外, 取值必然是接近 0 的, 不然就没法满足积分和为 1 的条件了.
于是, 有个做法是我们假设 $p_\theta(\mathbf{x})$ 是我们已知的一族常见分布, 而看起来在概率论中我们知道符合条件的分布是多元正态分布 $\mathcal{N}(\boldsymbol{\mu},\Sigma)$, 而正态分布完全由 $\boldsymbol{\mu},\Sigma$ 确定, 因此, 是我们只需要能找到最优的 $\mathcal{N}(\boldsymbol{\mu},\Sigma)$ 即可.
但是这同样存在几个问题:
(1) 首先是假设空间不够大: 我们对 $p_\theta(\mathbf{x})\sim \mathcal{N}(\boldsymbol{\mu}\theta,\Sigma\theta)$ 的这个假设其实不一定成立. 譬如说以 MNIST 数据集为例, 我们做这个假设: 数字 0 的图像满足某个专属于数字 0 的正态分布, 针对数字 1~9 同理(注意实际上在我们的问题设定下, 知道这批图像对应的标签以及标签的含义是不可能的, 这里属于上帝视角), 这看起来似乎是合理的. 这么一来,
\[p(\mathcal{x})=p(\mathcal{x}|z)p(z)\]其中 $z$ 对应类别, 而 $p(\mathcal{x} | z=i)$ 对应于每个数字的正态分布. 由于这里比较特殊, 我们的 $z$ 实际上是一个离散型均匀分布 (注意我们这么说仍然是开上帝视角在审视 MNIST 这个特殊的数据集), 这么一来 $p(\mathcal{x})$ 实际上并不是一个正态分布. 而在我们的思路下, 假定 $p_\theta(\mathbf{x})\sim \mathcal{N}(\boldsymbol{\mu}\theta,\Sigma\theta)$, 在这个分布集合里, 离真实 $p(\mathbf{x})$ 最近的分布可能都相去甚远. |
(2) 正态分布的协方差矩阵想要用神经网络拟合是一件困难的事情, 直接的做法是我们需要一个神经网络, 输入是 $c$ 维的, 输出是 $c\times c$ 维的, 还要保证其半正定性, 是比较麻烦的. 当然, 我们可以先拟合 $\Sigma=XX^T$ 中的 $X$ 矩阵, 其中 $X$ 是无约束的 $c\times c$ 矩阵. 尽管如此, 在图像的例子里, $c$ 可能会达到上百万, 而 $c\times c$ 就更是不可想象了. 因此我们可能又不得不对协方差矩阵也做些假设, 例如他是对角矩阵. 但这么一来的话, 假设空间就变得更小了.
上面的思路虽然仍然不可行, 但我们受到 MNIST 例子的启发, 在那个特定的例子里, 如果我们确实有标签信息, 那么实际上已经有解法了, 就是假设 $p(\mathcal{x} | z=i)\sim\mathcal{N}(\boldsymbol{\mu_i},\text{diag}(\boldsymbol{\sigma}_i))$. 然后用统计的办法简单估计 $p(z)$. 这样一来, $p(\mathbf{x})$ 有显式解, 并且对 $p(\mathbf{x})$ 的抽样也是简单直接的: |
- 先用多项分布在 $p(z)$ 中抽样出类别 $i$
- 然后抽样 $\mathbf{x}\sim\mathcal{N}(\mathbf{0}, I)$
- 最后做变换: $\mathbf{x}’=\boldsymbol{\sigma}_i\odot\mathbf{x}+\boldsymbol{\mu_i}$
TODO: 这么采样为什么是对的, 严格地数学证明.
VAE
由此延申开来, 我们引入一个隐变量 $\mathbf{z}\in\mathbb{R}^d$, 假设存在一个联合概率 $p’(\mathbf{x}, \mathbf{z})$, 使得边缘分布 $p’(\mathbf{x})$ 恰好为 $p(\mathbf{x})$, 后面我们就直接不区分 $p’$ 与 $p$, 而直接认为 $p(\mathbf{x}, \mathbf{z})$ 是联合分布.
如果我们这样做假设, 后面将看到对 $p(\mathbf{x} | \mathbf{z};\theta)$ 的假设可以有其他选项: |
那么这样定义的联合概率密度函数组 $p(\mathbf{x}, \mathbf{z})$ 所对应的边缘分布 $p(\mathbf{x})$ 是否可以拟合足够多的函数呢? TODO
问题是针对这样的构想, 我们怎么求得 $p(\mathbf{x})$ 以及怎么做采样呢?
首先采样过程实际上是简单的 (TODO, 怎么证明):
- 先在 $d$ 维空间中抽样 $\mathbf{z}\sim\mathcal{N}(\mathbf{0}, I)$
- 然后在 $c$ 维空间中抽样 $\delta_x\sim\mathcal{N}(\mathbf{0}, I)$
- 然后做变换 $\mathbf{x}=\boldsymbol{\sigma}(\mathbf{z};\theta)\cdot\delta_x+\boldsymbol{\mu}(\mathbf{z};\theta)$
但现在问题是如何求解 $p(\mathbf{x})$, 按定义:
\[p(\mathbf{x})=\int p(\mathbf{x}|\mathbf{z};\theta)p(\mathbf{z}) d\mathbf{z}\]这个积分我们没法处理: 注意这里联合分布和对 $\mathbf{x}$ 的边缘分布其实都不一定是正态分布.
笔者浅见, 可能在数学上是不合理的: 但其实我们也许可以采几个样做估算, 譬如说在 $\mathbb{R}^d$ 中采样几个 $\mathbf{z}_k$, 用如下式子估算 $p(\mathbf{x})$:
\[p(\mathbf{x})=\text{mean}_{k}(p(\mathbf{x}|\mathbf{z}_k;\theta)p(\mathbf{z}_k))\]但大概是这种估算太不精确了吧, 所以不能这么做
求解 $p(\mathbf{x})$ 的另一个思路是利用贝叶斯公式
\[p(\mathbf{x})=\frac{p(\mathbf{z})p(\mathbf{x}|\mathbf{z})}{p(\mathbf{z}|\mathbf{x})}\]我们假设用一组分布 $q(\mathbf{z}; \mathbf{x},\phi)$ 来估计 $p(\mathbf{z} | \mathbf{x})$, 对此, 我们似乎又没有什么好的选项, 只好又选择正态分布族 (后面马上会解释这个选择的好处): |
利用 KL 散度的性质, 以下这行式子是关键:
\[D_{KL}\left[q(\mathbf{z};\mathbf{x},\phi)\parallel p(\mathbf{z}|\mathbf{x};\theta)\right]-\ln p(\mathbf{x};\theta)=D_{KL}\left[q(\mathbf{z};\mathbf{x},\phi)\parallel p(\mathbf{z})\right]-\mathbb{E}_{\mathbf{z}\sim q(\mathbf{z};\mathbf{x},\phi)}\ln p(\mathbf{x}|\mathbf{z};\theta)\]注意, 我们的优化目标是 $\min_{\theta}\ln p(\mathbf{x},\theta)$, 通过以上等式, 我们转而优化它的上界, 也就是左侧的第一项总是大于 0 的, 因此也等价于最小化右侧.
这里说明一下, 为什么我们就这样忽略了左侧第一项呢, 主要原因还是左侧第一项无法计算, 原因是按照定义:
\[p(\mathbf{z}|\mathbf{x};\theta)=\frac{p(\mathbf{x}|\mathbf{z};\theta)p(\mathbf{z})}{p(\mathbf{x};\theta)}=\frac{p(\mathbf{x}|\mathbf{z};\theta)p(\mathbf{z})}{\int p(\mathbf{x}|\mathbf{z};\theta)p(\mathbf{z}) d\mathbf{z}}\]其实就是分母我们无法计算, 这等于回到了我们之前遇到的难题. 因此更不必说计算左侧第一项的 KL 散度.
那么现在问题在于右侧怎么做优化呢? 首先, 右侧的第一项是两个正态分布的 KL 散度, 有显式的计算公式 (这也是为什么选择假设 $q(\mathbf{z};\mathbf{x},\phi)$ 是正态分布的原因, 因为这样可以显式计算):
\[D_{KL}\left[q(\mathbf{z};\mathbf{x},\phi)\parallel p(\mathbf{z})\right]=\frac{1}{2}\left(\sum_{i=1}^{d}{(\sigma_i^2(\mathbf{x};\phi)-\ln\sigma_i^2(\mathbf{x};\phi))}-d+ \boldsymbol{\mu}(\mathbf{x};\phi)^T\boldsymbol{\mu}(\mathbf{x};\phi)\right)\]而右侧第二项通过采样来估计, 按理来说应该采样多个 $\mathbf{z}\sim q(\mathbf{z};\mathbf{x},\phi)$, 但 VAE 的具体实现上很多时候就仅用了一个采样点, 具体的做法采用了所谓的重参数技巧, 实际上也就是:
- 从 $d$ 维标准正态分布里抽样出一个点 $\boldsymbol{\epsilon}\sim\mathcal N(\mathbf{0}, I)$
- 应用变换计算: $\mathbf{z}=\boldsymbol{\sigma}(\mathbf{x};\phi)\odot\boldsymbol{\epsilon}+\boldsymbol{\mu}(\mathbf{x};\phi)$
- 最后由此估算右侧第二项:
备注: 在很多网上的实现里, 会将右侧第一项损失乘上 $\lambda$ 作为权重系数(一般小于 1, 例如 0.00025), 而右侧第二项损失被替换为 $\sum_{i=1}^{c}{[x_i-\mu_i(\mathbf{z};\theta)]^2}$. 这其实等价于左侧第一项不变, 右侧第二项的损失被直接替换为了 $\frac{1}{\lambda}*\sum_{i=1}^{c}{[x_i-\mu_i(\mathbf{z};\theta)]^2}$, 以至于由 $\mathbf{z}$ 生成 $\mathbf{x}$ 的神经网络只计算均值 $\boldsymbol{\mu}(\mathbf{z},\theta)$, 而忽略 $\boldsymbol{\sigma}^2(\mathbf{z},\theta)$. 而这么做的本质实际上是将 $p(\mathbf{x} | \mathbf{z})$ 的分布假定为了 $\mathcal{N}(\boldsymbol{\mu}(\mathbf{z},\theta),\lambda I)$. |
算法流程整理
综合以上, 将算法流程梳理如下:
输入:
- 数据集 $\lbrace\mathbf{x}i\rbrace{i=1}^N$, 其中 $\mathbf{x}_i\in\mathbb{R}^c$
- encoder: $\mathbb{R}^c\to(\mathbb{R}^d,\mathbb{R}^d)$, 分别代表 $\boldsymbol{\mu}(\mathbf{x};\theta)$ 和 $\ln(\boldsymbol{\sigma^2}(\mathbf{x};\theta))$.
- decoder: $\mathbb{R}^d\to(\mathbb{R}^c,\mathbb{R}^c)$, 分别代表 $\boldsymbol{\mu}(\mathbf{z};\phi)$ 和 $\ln(\boldsymbol{\sigma^2}(\mathbf{z};\phi))$.
其中 $d$ 一般远小于 $c$.
训练流程:
对于单个样本 $\mathbf{x}j$, 首先从 $d$ 维 标准正态分布中随机抽样 $S$ 个点, 记为 $\lbrace\boldsymbol{\epsilon}\rbrace{s=1}^{S}$, 然后根据 encoder 计算:
\[\mathbf{z}_s=\boldsymbol{\sigma}(\mathbf{x}_j;\theta)\odot\boldsymbol{\epsilon}_s+\boldsymbol{\mu}(\mathbf{x}_j;\theta),\quad s=1,\ldots,S\]损失函数按如下定义计算, 然后使用正常基于梯度的优化算法针对 batch 数据进行优化.
\[\begin{align} L_j =& \frac{1}{2}\left(\sum_{i=1}^{d}{(\sigma_i^2(\mathbf{x}_j;\phi)-\ln\sigma_i^2(\mathbf{x}_j;\phi))}-d+ \boldsymbol{\mu}(\mathbf{x}_j;\phi)^T\boldsymbol{\mu}(\mathbf{x}_j;\phi)\right)\\ &+\frac{1}{S}\sum_{s=1}^{S}{\left[\frac{k}{2}\ln(2\pi)+\frac{1}{2}\sum_{i=1}^{c}\ln\sigma_i^2(\mathbf{z}_s;\theta)+\frac{1}{2}\sum_{i=1}^{c}\frac{[x_{ji}-\mu_i(\mathbf{z}_s; \theta)]^2}{\sigma_i^2(\mathbf{z}_s;\theta)}\right]} \end{align}\]生成
- 先在 $d$ 维空间中抽样 $\mathbf{z}\sim\mathcal{N}(\mathbf{0}, I)$
- 然后在 $c$ 维空间中抽样 $\delta_x\sim\mathcal{N}(\mathbf{0}, I)$
- 然后做变换 $\mathbf{x}=\boldsymbol{\sigma}(\mathbf{z};\theta)\cdot\delta_x+\boldsymbol{\mu}(\mathbf{z};\theta)$
代码实现 (TODO, 不 work)
备注: 这里的实现与网上的大多数实现保持一致, 也就是将 $p(\mathbf{x} | \mathbf{z})$ 的分布假定为了 $\mathcal{N}(\boldsymbol{\mu}(\mathbf{z},\theta),\lambda I)$, 其中 $\lambda$ 为超参数, 另外在训练好模型后, 生成数据是, 不加上 $\boldsymbol{\sigma}(\mathbf{z};\theta)\cdot\delta_x$ 这一项. |
import math
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
hidden_dim = 64
epoch_nums = 10
lam = 0.00025 * 28 * 28 / 2 # 0.098
device = "cuda:0"
global_iter_num = 0
print_every = 100
class VAEEncoder(nn.Module):
def __init__(self, hidden_dim=64):
super().__init__()
self.hidden_dim = hidden_dim
# self.base = nn.Sequential(
# nn.Conv2d(1, 32, kernel_size=3, padding=1),
# # nn.BatchNorm2d(32),
# nn.ReLU(),
# nn.MaxPool2d(kernel_size=2, stride=2),
# nn.Conv2d(32, 64, kernel_size=3, padding=1),
# # nn.BatchNorm2d(64),
# nn.ReLU(),
# nn.MaxPool2d(2),
# nn.Flatten(),
# nn.Linear(64 * 7 * 7, 512),
# nn.ReLU()
# )
self.base = nn.Sequential(
nn.Flatten(),
nn.Linear(28*28, 512),
nn.ReLU(),
nn.Linear(512, 512),
nn.ReLU(),
)
self.mean_layer = nn.Linear(512, hidden_dim)
self.log_var_layer = nn.Linear(512, hidden_dim)
def forward(self, x):
out = self.base(x)
mean = self.mean_layer(out)
log_var = self.log_var_layer(out)
return mean, log_var
class VAEDecoder(nn.Module):
def __init__(self, hidden_dim=64):
super().__init__()
self.hidden_dim = hidden_dim
# self.base = nn.Sequential(
# nn.Linear(hidden_dim, 512),
# nn.ReLU(),
# nn.Linear(512, 64 * 7 * 7),
# nn.Unflatten(1, (64, 7, 7)),
# nn.Upsample(scale_factor=2, mode="nearest"),
# nn.ReLU(),
# # nn.BatchNorm2d(64),
# nn.Conv2d(64, 32, kernel_size=3, padding=1),
# nn.Upsample(scale_factor=2, mode="nearest"),
# nn.ReLU(),
# # nn.BatchNorm2d(32),
# nn.Conv2d(32, 1, kernel_size=3, padding=1),
# )
self.base = nn.Sequential(
nn.Linear(hidden_dim, 512),
nn.ReLU(),
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, 28*28),
nn.Unflatten(1, (1, 28, 28))
)
def forward(self, z):
return self.base(z)
class VAE(nn.Module):
def __init__(self, encoder, decoder):
super().__init__()
self.encoder = encoder
self.decoder = decoder
assert self.encoder.hidden_dim == self.decoder.hidden_dim
self.hidden_dim = self.encoder.hidden_dim
self.dim = 1 * 28 * 28
def reparametrize(self, mean, log_var):
z = torch.randn_like(mean)
out = mean + torch.exp(0.5 ** log_var) * z
return out
def forward(self, x):
mean, log_var = self.encoder(x)
z = self.reparametrize(mean, log_var)
out = self.decoder(z)
return mean, log_var, out
def get_loss(self, x, mean, log_var, out, lam):
batch_size = x.shape[0]
input_dim = x.numel() / batch_size
reconstruction_loss = ((out - x) ** 2).flatten(1).sum(dim=1).mean()
# expectation_loss = reconstruction_loss + input_dim * math.log(2*math.pi*lam)
expectation_loss = 0.5 * reconstruction_loss
kld_loss = (torch.exp(log_var)-log_var-1+mean**2).flatten(1).sum(dim=1).mean()
kld_loss = 0.5 * kld_loss
weighted_kld_loss = lam * kld_loss
loss = expectation_loss + weighted_kld_loss
return {
"total_loss": loss,
"expectation_loss": expectation_loss,
"weighted_kld_loss": weighted_kld_loss,
"kld_loss": kld_loss,
"reconstruction_loss": reconstruction_loss,
}
def sample(self, n):
with torch.no_grad():
z = torch.randn((n, self.hidden_dim)).to(device) # TODO
return self.decoder(z)
encoder = VAEEncoder(hidden_dim)
decoder = VAEDecoder(hidden_dim)
model = VAE(encoder, decoder).to(device)
transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize(mean=[0.5], std=[1.])
]
)
dataset = torchvision.datasets.MNIST(root="./mnist_dataset", transform=transform, download=True)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
optimizer = torch.optim.Adam(model.parameters())
for epoch in range(epoch_nums):
for batch_x, batch_y in dataloader:
batch_x = batch_x.to(device)
optimizer.zero_grad()
mean, log_var, out = model(batch_x)
losses = model.get_loss(batch_x, mean, log_var, out, lam)
losses["total_loss"].backward()
optimizer.step()
global_iter_num += 1
if global_iter_num % print_every == 0:
print(f"{global_iter_num} steps")
for key, value in losses.items():
print(key, value.item())
# sampling
transforms.ToPILImage()((model.sample(1) + 0.5)[9])