http://joschu.net/blog/kl-approx.html
$$ K L[q, p]=\sum_x q(x) \log \frac{q(x)}{p(x)}=E_{x \sim q}\left[\log \frac{q(x)}{p(x)}\right] $$它解释了一个我在各种代码中使用过的技巧,我将 $K L[q, p]$ 近似为 $\frac{1}{2} (\log p(x) - \log q(x))^2$ 的样本平均值,对于来自 $q$ 的样本 $x$,而不是更标准的 $\log \frac{q(x)}{p(x)}$。 这篇文章将解释为什么这个表达式是 KL 的一个好的(虽然有偏差的)估计器,以及如何在保持其低方差的同时使其无偏差。
我们计算 $KL$ 的选项取决于我们对 $p$ 和 $q$ 有什么样的访问权限。 在这里,我们将假设我们可以计算任何 $x$ 的概率(或概率密度)$p(x)$ 和 $q(x)$,但我们无法解析地计算 $x$ 上的总和。 为什么我们不能解析地计算它呢?
精确计算它需要太多的计算或内存。
没有闭合形式的表达式。
我们可以通过仅存储对数概率(log-prob)来简化代码,而无需存储整个分布。如果KL散度仅用作诊断工具,这会是一个合理的选择,就像在强化学习中经常出现的情况一样。
估计总和或积分的最常见策略是使用蒙特卡洛估计。给定样本 $x_1, x_2, \dots \sim q$,我们如何构建一个好的估计?
一个好的估计量是无偏的(它具有正确的均值)并且具有低方差。我们知道一个无偏估计量(在来自 $q$ 的样本下)是 $\log \frac{q(x)}{p(x)}$。然而,它具有高方差,因为它对于一半的样本是负的,而KL散度始终是正的。让我们将这个朴素估计量称为 $k_1 = \log \frac{q(x)}{p(x)} = - \log r$,其中我们定义了比率 $r=\log \frac{p(x)}{q(x)}$,它将在后续计算中频繁出现。
另一种估计量,它具有较低的方差但有偏差,是 $\frac{1}{2}(\log \log \frac{p(x)}{q(x)})^2 = \frac{1}{2}(\log r)^2$。让我们将这个估计量称为 $k_2$。直观地说,$k_2$ 似乎更好,因为每个样本都会告诉你 $p$ 和 $q$ 相差多远,并且它始终是正的。经验表明,$k_2$ 的方差确实比 $k_1$ 低得多,并且偏差也出奇地低。(我们将在下面的实验中展示这一点。)
$$ D_f(p_0, p_{\theta}) = \tfrac{f''(1)}{2} \theta^T F \theta + O(\theta^3) $$其中 $F$ 是在 $p_{\theta}=p_0$ 处计算的 $p_{\theta}$ 的 Fisher 信息矩阵。
$E_q[k_2]=E_q[\frac{1}{2}(\log r)^2]$ 是一个 f-散度,其对应的函数为 $f(x)=\frac{1}{2} (\log x)^2$,而 $K L[q, p]$ 对应的函数为 $f(x)= - \log x$。容易验证,两者都满足 $f''(1)=1$,因此当 $p\approx q$ 时,它们都表现出相同的二次距离函数形式。
$$ k_3 = (r - 1) - \log r $$通过观察凸函数及其切平面之间的差异来测量距离,这种思想在许多领域都有应用。它被称为Bregman散度,并具有许多优良的性质。
我们可以将上述思想推广,从而为任何f-散度获得一个良好且始终为正的估计量,其中最值得注意的是另一种KL散度 $K L[q, p]$(请注意,这里的p和q互换了位置)。由于f是凸函数,且$E_q[r]=1$,因此,f-散度的估计量可以表示为:$f(r) - f'(1)(r-1)$。这个值始终为正,因为它代表了f在r=1处与其切线之间的距离,而凸函数总是位于其切线的上方。现在,$K L[p, q]$ 对应于 $f(x)=x \log x$,其 $f'(1)=1$,因此我们得到的估计量是 $r \log r - (r - 1)$。
总而言之,对于样本 $x \sim q$,以及 $r = \frac{p(x)}{q(x)}$,我们有以下估计量:
$K L[p,q]$: $r \log r - (r - 1)$
$K L[q, p]$: $(r - 1) - \log r$
现在,让我们比较一下 $K L[q, p]$的三种估计器的偏差和方差。假设 $q=N(0,1)$,$p=N(0.1,1)$。在这种情况下,真实的KL散度为$0.005$。

请注意,$k_2$的偏差在这里非常低,仅为$0.2%$。现在,我们尝试一个更大的真实KL散度。如果 $p=N(1,1)$,则真实的KL散度为$0.5$。

在这种情况下,$k_2$的偏差明显增大。$k_3$的标准差比$k_2$更低,同时又是无偏的,因此它似乎是一个更好的估计器。
以下是我用来获得这些结果的代码:
import torch.distributions as dis
p = dis.Normal(loc=0, scale=1)
q = dis.Normal(loc=0.1, scale=1)
x = q.sample(sample_shape=(10_000_000,))
truekl = dis.kl_divergence(p, q)
print("true", truekl)
logr = p.log_prob(x) - q.log_prob(x)
k1 = -logr
k2 = logr ** 2 / 2
k3 = (logr.exp() - 1) - logr
for k in (k1, k2, k3):
print((k.mean() - truekl) / truekl, k.std() / truekl)