Re-understanding KL Approximation from an RL-for-LLM Lens: Notes on “Approximating KL Divergence”

Community Article Published August 11, 2025

What's the difference between KL-divergence estimation methods used in PPO and GRPO?

John Schulman’s blog post “Approximating KL Divergence” talks about how to approximate KL divergence via sampling (Monte Carlo), and it introduces three estimators (\(k_1\), \(k_2\), \(k_3\)) along with their bias–variance behaviors. But the original post is framed in the context of general probability distributions; it doesn’t touch the reinforcement-learning-for-LLM training setting. This write-up records the questions I had while reading, the thoughts I formed after mapping things to RL for LLMs, and a few places where I felt the original explanations could be pushed on a bit.

What “Approximating KL Divergence” Says (in my own words)

In this section, I’m imagining readers who haven’t read the original post yet, so let’s quickly pass through the most important bits. Put simply, the post is about how to build reasonable Monte Carlo–style estimators of KL divergence when we can’t compute it directly.

\[
\mathrm{KL}(q, p) = \sum_x q(x)\,\log\frac{q(x)}{p(x)} = \mathbb{E}_{x\sim q}\!\left[\log\frac{q(x)}{p(x)}\right].
\]

As the formula shows, when estimating the KL between two (complicated) distributions, there’s a coding trick people often use: just approximate KL by the sample mean of \(\log\!\big(\frac{q(x)}{p(x)}\big)\) over samples drawn from \(q\) (as opposed to trying to evaluate the full expectation exactly). The post then points out another move: use the sample average of \(\tfrac{1}{2}(\log r)^2\) in place of the more “standard” \(-\log r\) form, where \(r=\frac{p(x)}{q(x)}\). The write-up explains why this expression can be a good (albeit biased) estimator for KL, and how to make it unbiased while keeping its low variance.

How we compute KL depends on how we can access \(p\) and \(q\). Here we assume we can evaluate \(p(x)\) and \(q(x)\) (probabilities or densities) for any \(x\), but we can’t analytically sum/integrate over \(x\). Why might we fail to do the analytic sum/integral? Maybe the exact computation is too expensive in compute or memory, maybe there’s no closed form, or maybe we only store log-probs instead of full distributions to keep code simpler, which is especially fine when KL is just a diagnostic (as is often the case in RL). The most common strategy for approximating sums or integrals is Monte Carlo. Given samples \(x_1, x_2, \dots, x_n \sim q\), how do we build a good estimator?

A good estimator should be unbiased (right mean) and low-variance. We know one unbiased estimator:

\[
k_1 = \log\frac{q(x)}{p(x)}.
\]

But it has high variance: by definition KL is a nonnegative quantity, yet for the estimator above, roughly “half” the sample values can be negative (if we assume no prior structure on \(p\) and \(q\)), and this swings the average around a lot, hence high variance. For notational convenience, set \(r = \frac{p(x)}{q(x)}\), as in the original post. Then the original KL can be written as

\[
\mathrm{KL}[q, p] \;=\; \mathbb{E}_{x\sim q}\,[-\log r].
\]

To reduce variance, we can design an alternative estimator: \(k_2 = \frac{1}{2}(\log r)^2\). It has lower variance, but it’s biased. Intuitively, \(k_2\) feels nicer because each sample gives a nonnegative “distance” between \(p\) and \(q\), so it stays positive. Empirically, \(k_2\) really does have much lower variance than \(k_1\), and the bias can be quite small. As for why \(k_2\) enjoys such a variance drop compared with \(k_1\), the original post uses an \(f\)-divergence view to give an analytic explanation, and I won’t repeat that here.

Now, can we get an estimator that is both unbiased and low-variance? A general trick is to use a control variate: start from the unbiased \(k_1\) and add something whose expectation is zero and is negatively correlated with it to reduce variance. A very convenient zero-mean quantity here is \(r-1\). Thus, for any \(\lambda\),
\[
k \;=\; -\log r + \lambda\,(r-1)
\]
is still an unbiased KL estimator. In theory we could minimize the variance over \(\lambda\), but the optimal value depends on \(p\) and \(q\) and isn’t easy to get in closed form. Notice, though, that since \(\log(x)\) is concave,
\[
\log(x) \;\le\; x-1,
\]
so if we pick \(\lambda=1\), the expression is guaranteed nonnegative. Here, \(r-1\) is the tangent of \(\log r\) at \(r=1\), so with \(\lambda=1\) we’re really measuring the vertical gap between \(\log(x)\) and its tangent. This leads to the estimator
\[
k_3 \;=\; (r - 1) \;-\; \log r,
\]
which is always nonnegative. And \(k_3\) is exactly the piece where, in practice, GRPO differs from PPO in how KL is estimated (PPO uses \(k_1\)).
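To make the bias and variance claims concrete, here is a minimal sketch in the spirit of the numerical experiment in the original post; I take \(p\) and \(q\) to be two nearby Gaussians (my own illustrative choice, not anything LLM-specific) and compare the three estimators against the closed-form KL.

```python
import torch

# A minimal sketch: p and q are two nearby 1-D Gaussians (purely illustrative).
p = torch.distributions.Normal(loc=0.0, scale=1.0)   # the "other" distribution p
q = torch.distributions.Normal(loc=0.1, scale=1.0)   # the sampling distribution q

true_kl = torch.distributions.kl_divergence(q, p)    # closed form for Gaussians

x = q.sample((500_000,))                             # Monte Carlo samples x ~ q
log_r = p.log_prob(x) - q.log_prob(x)                # log r, with r = p(x)/q(x)
r = log_r.exp()

k1 = -log_r                    # unbiased, high variance
k2 = 0.5 * log_r ** 2          # biased, low variance
k3 = (r - 1) - log_r           # unbiased, low variance (control variate)

for name, k in [("k1", k1), ("k2", k2), ("k3", k3)]:
    print(f"{name}: mean={k.mean().item():.4f}  std={k.std().item():.4f}  "
          f"true KL={true_kl.item():.4f}")
```

All three sample means land near the true KL (with a small bias for \(k_2\)), while the standard deviation of \(k_1\) is visibly the largest.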

Discussing KL Estimation from an RL-for-LLM Perspective

In RL (think PPO, GRPO, etc.), we often tack a KL divergence term onto the loss to keep the new policy from drifting too far from the old one. Here, \(q\) is the old policy distribution (\(\pi_{\text{old}}\)), \(p\) is the new policy distribution (\(\pi_{\text{new}}\)), and \(x\) is a complete action sample (in an LLM this means a token or a token sequence). We usually use \(s\) to denote the state (in an LLM, that’s the prompt or context), and \(x\) is a specific token generated in that context. When we compute KL, what we’re really doing is taking the KL over the action distribution given a state, and then averaging over states:

\[
\mathrm{KL}[q, p] = \mathbb{E}_{s} \left[ \sum_x q(x \mid s)\,\log \frac{q(x \mid s)}{p(x \mid s)} \right].
\]

At sampling time, we typically fix a prompt (state) and then estimate this KL for that prompt.

So why can’t we just compute KL exactly instead of estimating it? The reasons are exactly those listed in the original blog post; in RL for LLMs, the main culprit is Reason #1: the action space (token space) is too large to sum/integrate over all possible \(x\). For example, if a tokenizer has 50,000 vocabulary entries, even computing the KL for a single token means summing over 50,000 actions; and in RL we’re usually doing multi-step (sequence) generation, so the space blows up exponentially, which is completely impractical. There’s also a pragmatic reason: during training we generally don’t store the full distribution (all token probabilities); we only keep the log-probs of the tokens actually generated along the trajectory, to save GPU memory and I/O. So we have to use Monte Carlo sampling: draw \(x\) from some distribution (usually \(q\), the old policy), and use those samples to approximate KL. And that drops us squarely into the territory the blog post is about.
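To make the “we only keep the log-probs of generated tokens” point concrete, here is a minimal sketch of what a rollout typically caches; the shapes and variable names are mine and not tied to any particular framework.

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes: 4 sequences, 16 generated tokens, a 50k-token vocabulary.
batch, seq_len, vocab = 4, 16, 50_000
logits = torch.randn(batch, seq_len, vocab)       # stand-in for the policy's outputs
tokens = torch.randint(vocab, (batch, seq_len))   # the tokens actually sampled

log_probs = F.log_softmax(logits, dim=-1)         # full distribution: 4 x 16 x 50,000 numbers
token_logp = torch.gather(                        # what we actually keep: 4 x 16 numbers
    log_probs, dim=-1, index=tokens.unsqueeze(-1)
).squeeze(-1)

# Exact per-token KL would need the full `log_probs` tensor from *both* policies;
# the Monte Carlo estimators below only need `token_logp` under each policy.
print(log_probs.shape, token_logp.shape)
```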

In that post, the estimator we keep talking about is really just a function of a sample: it takes \(p(x)\) and \(q(x)\) for some sampled \(x\) (or their ratio \(r = \frac{p(x)}{q(x)}\)) and spits out a number. We then take the average of those numbers over our samples to approximate KL. For example:

  • \(k_1(x) = -\log r\)
  • \(k_2(x) = \frac12 (\log r)^2\)
  • \(k_3(x) = (r - 1) - \log r\)

These \(k_i\) are just different KL-estimator formulas. They all approximate KL by averaging over samples, but differ in bias and variance. Once we pick an estimator, we’re really just committing to a specific formula for approximating KL. The process looks like this:

  1. Sampling
    Sample a batch of tokens (or sequences) \(x_1, x_2, \dots, x_N\) from the old policy \(q\).
  2. Compute log-probs
    For each sample, compute the log-probabilities under both the new and old policies,

\[
\log p(x_i),\ \log q(x_i),
\]

    and form \(r_i = \frac{p(x_i)}{q(x_i)}\) or \(\log r_i\).
  3. Plug into the estimator formula
    For example, if we choose \(k_3\):

\[
k_3(x_i) = (r_i - 1) - \log r_i
\]

  4. Average

\[
\widehat{\mathrm{KL}} \approx \frac1N \sum_{i=1}^N k_3(x_i)
\]

That’s the approximate KL value, standing in for the true KL.
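Putting the four steps together, a per-batch computation might look like the following minimal sketch; it assumes we already have the two log-prob vectors from step 2, and the function and variable names are my own.

```python
import torch

def approx_kl(logp_new: torch.Tensor, logp_old: torch.Tensor, kind: str = "k3") -> torch.Tensor:
    """Monte Carlo KL(q, p) estimate, with q = old (sampling) policy and p = new policy.

    logp_new / logp_old hold log p(x_i) and log q(x_i) for tokens x_i sampled from q.
    """
    log_ratio = logp_new - logp_old           # log r_i, with r_i = p(x_i) / q(x_i)
    if kind == "k1":
        per_sample = -log_ratio
    elif kind == "k2":
        per_sample = 0.5 * log_ratio ** 2
    else:                                     # "k3", the GRPO-style estimator
        per_sample = log_ratio.exp() - 1 - log_ratio
    return per_sample.mean()

# Toy usage with made-up log-probs for five sampled tokens:
logp_old = torch.log(torch.tensor([0.20, 0.10, 0.40, 0.05, 0.25]))
logp_new = torch.log(torch.tensor([0.18, 0.12, 0.35, 0.08, 0.27]))
print(approx_kl(logp_new, logp_old, "k3").item())
```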

If we compare this to computing the true KL (no estimation) for a discrete probability distribution (an LLM single-token step), we’d need to iterate over every possible token \(x\):
\[
\mathrm{KL}(q\|p) = \sum_x q(x) \log \frac{q(x)}{p(x)}.
\]
You can see immediately that with an estimator, the computational load is much smaller than doing the full sum, especially in high-dimensional action spaces.
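For a single decoding step with a known vocabulary, the exact sum is still computable, which makes it easy to see what the estimator is standing in for. Here is a small sketch; the vocabulary size and both next-token distributions are made up.

```python
import torch

vocab = 50_000                                     # hypothetical tokenizer size

# Made-up next-token distributions for one prompt (one decoding step).
q_logits = torch.randn(vocab)                      # old / sampling policy q
p_logits = q_logits + 0.05 * torch.randn(vocab)    # new policy p, kept close to q
q = torch.softmax(q_logits, dim=-1)
p = torch.softmax(p_logits, dim=-1)

# Exact KL(q || p): one full pass over all 50,000 tokens.
exact_kl = (q * (q.log() - p.log())).sum()

# Monte Carlo estimate with k3 from a handful of tokens sampled from q.
idx = torch.multinomial(q, num_samples=256, replacement=True)
log_r = p.log()[idx] - q.log()[idx]                # log r at the sampled tokens, r = p/q
k3_estimate = (log_r.exp() - 1 - log_r).mean()

print(exact_kl.item(), k3_estimate.item())
```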

Talking About Variance in Different KL Estimators

Important to note: the “variance” we’re talking about here is the variance of the values the estimator outputs over samples,
\[
\mathrm{Var}_{x \sim q}[k(x)],
\]
that is, how much \(k(x)\) fluctuates across the sample space. An unbiased estimator means that with infinitely many samples, its mean equals the true KL. But a high-variance estimator means that even if the mean is right (unbiased), with a small number of samples the average can be way off. In RL for LLMs, the KL term often enters the loss as a regularization term (e.g., \(\beta \cdot \mathrm{KL}\)). If the KL estimator’s variance is large, it makes the loss noisy, which in turn makes gradients noisy and training unstable.

In the original post, to give readers an intuition for why \(k_1\) is not low-variance, the author writes:

However, it (\(k_1\)) has high-variance, as it’s negative for half of the samples, whereas KL is always positive.

The author points out that although \(k_1\) is unbiased, without prior constraints on \(p\) and \(q\), half the samples will have one bigger than the other, so half of the \(k_1\) values are positive and half negative. Up to here I’m fine. But then the author says: because KL is always greater than 0 (a basic inequality), \(k_1\) must therefore have high variance. And here I think the causal link doesn’t actually hold: you can’t use the sign of the expectation to dictate the sign of individual samples. A quick counterexample: in computing the exact expectation, the summand \(q(x) \log \frac{q(x)}{p(x)}\) is also sometimes positive and sometimes negative; that fact by itself tells you nothing about variance. In reality, a single-sample log ratio (whether \(\log \frac{q(x)}{p(x)}\) or \(\log \frac{p(x)}{q(x)}\)) can be positive or negative, just like \(k_1\), so sign-flipping alone is not the sole reason for high variance.

From the KL definition,
\[
\mathrm{KL}(q \| p) = \mathbb{E}_{x\sim q}\left[ \log \frac{q(x)}{p(x)} \right],
\]
the expectation is guaranteed nonnegative, but the integrand \(\log\frac{q(x)}{p(x)}\) can be positive or negative for individual samples. And \(k_1\) is exactly that integrand:
\[
k_1(x) = \log \frac{q(x)}{p(x)},
\]
so each sample value can indeed be positive or negative, same as the integrand in the KL definition.

So why does \(k_1\) have high variance?

It’s not the mere “sign flipping.” The real reason is that \(k_1\)’s value distribution is often wide (heavy-tailed). For example, if \(p(x)\) is tiny relative to \(q(x)\) for some sample, then \(\log\frac{q}{p}\) is huge and positive (and huge and negative in the reverse case). These extreme values dominate the finite-sample average, pushing variance up. In other words, it’s the combination of extreme values plus positive/negative cancellation: cancellation means you need more samples to converge to the true mean, and extreme values make the sample variance itself larger. So the “half negative” comment in the blog is more of an intuition hook than a complete explanation.
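A quick numeric illustration of the heavy-tail point, with made-up per-token probabilities: a single sample that \(p\) considers very unlikely dominates the batch average of \(k_1\).

```python
import math

# Made-up probabilities of five sampled tokens under q and p.
# The last sample hits a token that p considers very unlikely.
q_probs = [0.20, 0.15, 0.30, 0.25, 0.10]
p_probs = [0.18, 0.16, 0.28, 0.27, 0.001]

k1_vals = [math.log(qi / pi) for qi, pi in zip(q_probs, p_probs)]
print([round(v, 3) for v in k1_vals])   # four small values of mixed sign, then ~4.6
print(sum(k1_vals) / len(k1_vals))      # the one extreme sample dominates the mean
```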

From this perspective, if we look at the other estimators \(k_2\) and \(k_3\), we see: \(k_2 = \frac12 (\log r)^2\) is always positive, so there’s no cancellation, but this introduces bias; squaring also smooths the magnitude, reducing variance. \(k_3\) uses a control variate to knock out part of the fluctuation source, lowering variance while keeping unbiasedness (details next).

In PPO/GRPO, if you use \(k_1\) and the batch is small or the distributions are far apart, the KL estimate will jump around (because a few extreme samples can swing the mean hard). That makes the KL-penalty coefficient unstable: it might suddenly be way too strong or too weak. Switching to a lower-variance estimator (\(k_2\) or \(k_3\)) makes each sample’s KL contribution steadier, less likely to be dominated by a handful of extreme samples.

Why can \(k_3\) be unbiased and low-variance?

At first glance, \(k_3\) is always positive, so you might think its mean must be larger than \(k_1\)’s.
But remember: \(k_3\) is derived from \(k_1\) via a control variate. The blog’s reasoning goes like this:
\[
\tilde{k}(x) = k_1(x) + \lambda \cdot h(x),
\]
where \(h(x) = r - 1\), and under \(x\sim q\) its expectation is
\[
\mathbb{E}_{x\sim q}[h(x)] = \mathbb{E}_q\left[\frac{p(x)}{q(x)} - 1\right] = \sum_x p(x) - 1 = 1 - 1 = 0.
\]
So adding any multiple of \(h(x)\) doesn’t change the expectation. When \(\lambda = 1\):
\[
\tilde{k}(x) = -\log r + (r - 1) = (r - 1) - \log r = k_3(x).
\]
This explains why \(k_3\)’s expectation equals \(k_1\)’s expectation, and equals the KL, making it an unbiased estimator.

The reason \(k_3\) has lower variance than \(k_1\) is this: \(k_1\) is just \(-\log r\), which can swing wildly (both positive and negative, with occasional huge values). But \(r - 1\) and \(-\log r\) are strongly correlated numerically (when one grows, the other shrinks), i.e., the correlation is negative. Adding \((r - 1)\) is like injecting a negatively correlated term to cancel fluctuations. After the cancellation, what’s left in \(k_3\) is tighter in range, always nonnegative, and therefore lower in sample variance.
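The negative-correlation claim is easy to check numerically; here is a small sketch using the same kind of nearby-Gaussian toy setup as in the earlier sketch (again my own choice of distributions).

```python
import torch

p = torch.distributions.Normal(0.0, 1.0)   # plays the role of the "new" policy p
q = torch.distributions.Normal(0.1, 1.0)   # sampling distribution q

x = q.sample((200_000,))
log_r = p.log_prob(x) - q.log_prob(x)      # log r, with r = p(x)/q(x)
k1 = -log_r                                # unbiased but noisy
h = log_r.exp() - 1                        # control variate h = r - 1

corr = torch.corrcoef(torch.stack([k1, h]))[0, 1]
print(corr.item())                              # strongly negative when p and q are close
print(h.mean().item())                          # ~0: the control variate really is zero-mean
print(k1.var().item(), (k1 + h).var().item())   # Var[k1] vs Var[k3]: the latter is much smaller
```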
