Re-understanding KL Approximation from an RL-for-LLM Lens: Notes on “Approximating KL Divergence”
What's the difference between KL-divergence estimation methods used in PPO and GRPO?
John Schulman’s blog post “Approximating KL Divergence” talks about how to approximate KL divergence via sampling (Monte Carlo), and it introduces three estimators (\(k_1\), \(k_2\), \(k_3\)) along with their bias–variance behaviors. But the original post is framed in the context of general probability distributions; it doesn’t touch the reinforcement-learning-for-LLM training setting. This write-up records the questions I had while reading, the thoughts I formed after mapping things to RL for LLMs, and a few places where I felt the original explanations could be pushed on a bit.
What “Approximating KL Divergence” Says (in my own words)
In this section, I’m imagining readers who haven’t read the original post yet, so let’s quickly pass through the most important bits. Put simply, the post is about how to build reasonable Monte Carlo–style estimators for the KL divergence when we can’t compute it directly.
The quantity to estimate is the KL divergence between two (complicated) distributions,
\[
\mathrm{KL}[q, p] = \sum_x q(x)\,\log\frac{q(x)}{p(x)} = \mathbb{E}_{x\sim q}\!\left[\log\frac{q(x)}{p(x)}\right].
\]
As the formula shows, there’s a coding trick people often use: just approximate the KL by the sample mean of \(\log\frac{q(x)}{p(x)}\), with samples \(x\) drawn from \(q\) (as opposed to trying to evaluate the full expectation exactly). The post then points out another move: use the sample average of \(\frac{1}{2}\bigl(\log r\bigr)^2\) to replace the more “standard” form \(\log\frac{q(x)}{p(x)} = -\log r\), where \(r = \frac{p(x)}{q(x)}\). The write-up explains why this expression can be a good (albeit biased) estimator of the KL, and how to make it unbiased while keeping its low variance.
How we compute KL depends on how we can access \(p\) and \(q\). Here we assume we can evaluate \(p(x)\) and \(q(x)\) (probabilities or densities) for any \(x\), but we can’t analytically sum/integrate over \(x\). Why might we fail to do the analytic sum/integral? Maybe the exact computation is too expensive in compute or memory, maybe there’s no closed form, or maybe we only store log-probs instead of full distributions to keep the code simpler, which is especially fine when KL is just a diagnostic (as is often the case in RL). The most common strategy for approximating sums or integrals is Monte Carlo. Given samples \(x_1, x_2, \ldots \sim q\), how do we build a good estimator?
A good estimator should be unbiased (right mean) and low-variance. We know one unbiased estimator: the sample average of
\[
k_1 = \log\frac{q(x)}{p(x)} = -\log r.
\]
But it has high variance: by definition KL is a nonnegative quantity, yet for the estimator above, roughly “half” the sample values can be negative (if we assume no prior structure on \(p\) and \(q\)), and this swings the average around a lot, hence high variance. For notational convenience, set \(r = \frac{p(x)}{q(x)}\). Then the original KL can be written as
\[
\mathrm{KL}[q, p] = \mathbb{E}_{x\sim q}\left[-\log r\right] = \mathbb{E}_{x\sim q}\left[k_1\right].
\]
To reduce variance, we can design an alternative estimator:
\[
k_2 = \frac{1}{2}\left(\log r\right)^2.
\]
It has lower variance, but it’s biased. Intuitively, \(k_2\) feels nicer because each sample gives a nonnegative “distance” between \(p\) and \(q\), so it stays positive. Empirically, \(k_2\) really does have much lower variance than \(k_1\), and the bias can be quite small. As for why \(k_2\) enjoys such a variance drop compared with \(k_1\), the original post uses an \(f\)-divergence view to give an analytic explanation, and I won’t repeat that here.
Now, can we get an estimator that is both unbiased and low-variance? A general trick is to use a control variate: start from the unbiased \(k_1\) and add something whose expectation is zero and that is negatively correlated with it, in order to reduce variance. A very convenient zero-mean quantity here is \(r - 1\), since \(\mathbb{E}_{x\sim q}\!\left[\frac{p(x)}{q(x)} - 1\right] = 0\). Thus, for any \(\lambda\),
\[
-\log r + \lambda\,(r - 1)
\]
is still an unbiased KL estimator. In theory we could minimize the variance over \(\lambda\), but the closed-form optimum depends on \(p\) and \(q\) and isn’t easy to get. Notice, though, that \(\log x \le x - 1\) since \(\log\) is concave, so if we pick \(\lambda = 1\), the expression is guaranteed nonnegative. Here, \(x - 1\) is the tangent of \(\log x\) at \(x = 1\), so with \(\lambda = 1\) we’re really measuring the vertical gap between \(\log x\) and its tangent. This leads to the estimator
\[
k_3 = (r - 1) - \log r,
\]
which is always nonnegative. And \(k_3\) is exactly the piece where, in practice, GRPO differs from PPO in how KL is estimated (PPO uses \(k_1\)).
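To make the bias–variance trade-off concrete, here is a minimal sketch (my own, not code from the original post) that evaluates all three estimators with PyTorch on an arbitrary pair of Gaussians, chosen only because the true KL has a closed form there.

```python
# Minimal sketch: compare k1, k2, k3 numerically; the distributions are an arbitrary choice.
import torch

torch.manual_seed(0)

q = torch.distributions.Normal(loc=0.0, scale=1.0)   # sampling distribution
p = torch.distributions.Normal(loc=0.5, scale=1.0)   # the "other" distribution

true_kl = torch.distributions.kl_divergence(q, p)    # closed-form KL[q || p], for reference

x = q.sample((1_000_000,))                           # x ~ q
logr = p.log_prob(x) - q.log_prob(x)                 # log r = log p(x) - log q(x)

k1 = -logr                                           # unbiased, high variance
k2 = 0.5 * logr ** 2                                 # biased, low variance
k3 = logr.exp() - 1 - logr                           # unbiased, low variance

for name, k in [("k1", k1), ("k2", k2), ("k3", k3)]:
    print(f"{name}: mean={k.mean():.4f}  std={k.std():.4f}  (true KL = {true_kl:.4f})")
```

The means of \(k_1\) and \(k_3\) both land on the closed-form KL; the column to watch is the per-sample standard deviation, and \(k_2\)’s bias shows up as a small offset in its mean.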
Discussing KL Estimation from an RL-for-LLM Perspective
In RL (think PPO, GRPO, etc.), we often tack a KL divergence term onto the loss to keep the new policy from drifting too far from the old one. Here, \(q\) is the old policy distribution (\(\pi_{\text{old}}\)), \(p\) is the new policy distribution (\(\pi_{\text{new}}\)), and \(x\) is a complete action sample (in an LLM this means a token or a token sequence). We usually use \(s\) to denote the state (in an LLM, that’s the prompt or context), and \(a\) is a specific token generated in that context. When we compute KL, what we’re really doing is taking the KL over the action distribution given a state, and then averaging over states:
\[
\mathbb{E}_{s}\Big[\mathrm{KL}\big(\pi_{\text{old}}(\cdot \mid s)\,\big\|\,\pi_{\text{new}}(\cdot \mid s)\big)\Big]
= \mathbb{E}_{s}\!\left[\sum_{a} \pi_{\text{old}}(a \mid s)\,\log\frac{\pi_{\text{old}}(a \mid s)}{\pi_{\text{new}}(a \mid s)}\right].
\]
At sampling time, we typically fix a prompt (state) and then estimate this KL for that prompt.
So why can’t we just compute KL exactly instead of estimating it? The reasons are exactly those listed in the original blog post; in RL for LLMs, the main culprit is Reason #1: the action space (token space) is too large to sum/integrate over all possible \(x\). For example, if a tokenizer has 50,000 vocabulary entries, even computing the KL for a single token means summing over 50,000 actions; and in RL we’re usually doing multi-step (sequence) generation, so the space blows up exponentially, which is completely impractical. There’s also a pragmatic reason: during training we generally don’t store the full distribution (all token probabilities); we only keep the log-probs of the tokens actually generated along the trajectory, to save GPU memory and I/O. So we have to use Monte Carlo sampling: draw \(x\) from some distribution (usually \(q = \pi_{\text{old}}\), the old policy), and use those samples to approximate KL. And that drops us squarely into the territory the blog post is about.
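As a concrete illustration of that last point, here is a hedged sketch of how one might reduce full-vocabulary logits to just the per-token log-probs of the generated tokens; the tensor names and shapes are placeholders of my own, not from any particular training framework.

```python
# Sketch: keep only per-token log-probs of the tokens actually generated,
# instead of the full [batch, seq_len, vocab] distributions. Names are placeholders.
import torch
import torch.nn.functional as F

def gather_token_logprobs(logits: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor:
    """Log-probs of the generated tokens only, shape [batch, seq_len]."""
    logprobs = F.log_softmax(logits, dim=-1)                          # full distribution per step
    return logprobs.gather(-1, token_ids.unsqueeze(-1)).squeeze(-1)   # one number per step

batch, seq_len, vocab = 2, 8, 50_000
logits_old = torch.randn(batch, seq_len, vocab)          # placeholder: old-policy logits
logits_new = torch.randn(batch, seq_len, vocab)          # placeholder: new-policy logits
generated_ids = torch.randint(vocab, (batch, seq_len))   # tokens sampled along the trajectory

old_logprobs = gather_token_logprobs(logits_old, generated_ids)   # log q(x) per token
new_logprobs = gather_token_logprobs(logits_new, generated_ids)   # log p(x) per token
# These two tensors are all that any of the k1 / k2 / k3 estimators need.
```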
In that post, the estimator we keep talking about is really just a function of a sample: it takes \(q(x)\) and \(p(x)\) for some sampled \(x\) (or their ratio \(r = \frac{p(x)}{q(x)}\)) and spits out a number. We then take the average of those numbers over our samples to approximate KL. For example:
\[
k_1 = -\log r, \qquad k_2 = \tfrac{1}{2}\left(\log r\right)^2, \qquad k_3 = (r - 1) - \log r.
\]
These are just different KL-estimator formulas. They all approximate KL by averaging over samples, but differ in bias and variance. Once we pick an estimator, we’re really just committing to a specific formula for approximating KL. The process looks like this:
1. Sampling
   Sample a batch of tokens (or sequences) \(x_i\) from the old policy \(q = \pi_{\text{old}}\).
2. Compute log-probs
   For each sample, compute the log-probabilities under both the new and old policies, \(\log p(x_i)\) and \(\log q(x_i)\), and get \(\log r_i = \log p(x_i) - \log q(x_i)\) or \(r_i = \frac{p(x_i)}{q(x_i)}\).
3. Plug into the estimator formula
   For example, if we choose \(k_3\):
   \[
   k_3(x_i) = (r_i - 1) - \log r_i.
   \]
4. Average
   \[
   \widehat{\mathrm{KL}} = \frac{1}{N}\sum_{i=1}^{N} k_3(x_i).
   \]
   That’s the approximate KL value, standing in for the true KL (a code sketch of these steps follows below).
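Here is what steps 2–4 might look like in code, picking \(k_3\) as in step 3; the log-prob tensors are random placeholders (in a real loop they would come from the gathered per-token log-probs above), so the snippet runs on its own.

```python
# Sketch of steps 2-4: from per-sample log-probs to an averaged KL estimate.
import torch

torch.manual_seed(0)
old_logprobs = -torch.rand(4096)                        # stand-in for log q(x_i)
new_logprobs = old_logprobs + 0.05 * torch.randn(4096)  # stand-in for log p(x_i), a nearby policy

logr = new_logprobs - old_logprobs       # log r_i = log p(x_i) - log q(x_i)

k1 = -logr                               # PPO-style estimator values
k3 = logr.exp() - 1 - logr               # GRPO-style estimator values (always >= 0)

approx_kl_k1 = k1.mean()                 # step 4: average over the batch
approx_kl_k3 = k3.mean()
print(f"k1 estimate: {approx_kl_k1:.5f}   k3 estimate: {approx_kl_k3:.5f}")
```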
If we compare this to computing the true KL (no estimation) for a discrete probability distribution (an LLM single-token step), we’d need to iterate over every possible token \(a\) in the vocabulary \(\mathcal{V}\):
\[
\mathrm{KL}\big(\pi_{\text{old}}(\cdot \mid s)\,\big\|\,\pi_{\text{new}}(\cdot \mid s)\big)
= \sum_{a \in \mathcal{V}} \pi_{\text{old}}(a \mid s)\,\log\frac{\pi_{\text{old}}(a \mid s)}{\pi_{\text{new}}(a \mid s)}.
\]
You can see immediately that with an estimator, the computational load is much smaller than doing the full sum, especially in high-dimensional action spaces.
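For contrast, a small sketch of what the exact computation involves for a single token position next to a sampled estimate; the logits are random placeholders, and the 50,000-entry vocabulary echoes the example above.

```python
# Sketch: exact per-token KL needs the full distributions over the vocabulary,
# while the Monte Carlo estimate only needs log-probs at sampled tokens.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab = 50_000
logits_old = torch.randn(vocab)                        # placeholder for pi_old(. | s)
logits_new = logits_old + 0.1 * torch.randn(vocab)     # a nearby "new" policy

logq = F.log_softmax(logits_old, dim=-1)
logp = F.log_softmax(logits_new, dim=-1)

# Exact KL(q || p): a full sum over all 50,000 tokens.
exact_kl = (logq.exp() * (logq - logp)).sum()

# Monte Carlo: draw a few tokens from q and apply k3 to just those log-probs.
samples = torch.multinomial(logq.exp(), num_samples=256, replacement=True)
logr = logp[samples] - logq[samples]
mc_kl = (logr.exp() - 1 - logr).mean()

print(f"exact: {exact_kl:.6f}   k3 estimate (256 samples): {mc_kl:.6f}")
```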
Talking About Variance in Different KL Estimators
Important to note: the “variance” we’re talking about here is the variance of the values the estimator outputs over samples,
\[
\mathrm{Var}_{x\sim q}\big[k(x)\big],
\]
that is, how much \(k(x)\) fluctuates across the sample space. An unbiased estimator means that with infinitely many samples, its mean equals the true KL. But a high-variance estimator means that even if the mean is right (unbiased), with a small number of samples the average can be way off. In RL for LLMs, the KL term is often a regularization term in the loss (e.g., \(\mathcal{L} = \mathcal{L}_{\text{policy}} + \beta\,\widehat{\mathrm{KL}}\)). If the KL estimator’s variance is large, it makes the loss noisy, which in turn makes the gradients noisy and training unstable.
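As a (hypothetical) sketch of where that noise lands, the snippet below shows the estimate feeding a loss scaled by a coefficient \(\beta\); `policy_loss`, `beta`, and the per-sample values are placeholders for whatever a real training loop computes.

```python
# Sketch: the Monte Carlo KL estimate enters the loss scaled by beta, so a noisy
# estimator translates directly into a noisy loss (and noisy gradients).
import torch

k = 0.02 * torch.randn(1024).abs()       # placeholder per-sample estimator values for one batch
policy_loss = torch.tensor(1.0)          # placeholder policy objective
beta = 0.04                              # KL coefficient (placeholder value)

kl_estimate = k.mean()
loss = policy_loss + beta * kl_estimate

# Logging the per-sample spread alongside the mean is a cheap noise diagnostic.
print(f"KL ~ {kl_estimate:.4f} (per-sample std {k.std():.4f}), loss = {loss:.4f}")
```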
In the original post, to give readers an intuition for why \(k_1\) is not low-variance, the author writes:
However, it (\(k_1\)) has high-variance, as it’s negative for half of the samples, whereas KL is always positive.
The author points out that although \(k_1\) is unbiased, without prior constraints on \(p\) and \(q\), half the samples will have one probability bigger than the other, so half of the \(k_1\) values are positive and half negative. Up to here I’m fine. But then the author says: because KL is always greater than 0 (a basic inequality), \(k_1\) must therefore have high variance. And here I think the causal link doesn’t actually hold: you can’t use the sign of the expectation to dictate the sign of individual samples. A quick counterexample: in computing the expectation that defines the KL, the integrand is also sometimes positive and sometimes negative; that fact by itself tells you nothing about variance. In reality, a single-sample log ratio (whether \(\log\frac{q(x)}{p(x)}\) or \(\log\frac{p(x)}{q(x)}\)) can be positive or negative, just like that integrand, so sign-flipping alone is not the sole reason for high variance.
From the KL definition,
\[
\mathrm{KL}[q, p] = \mathbb{E}_{x\sim q}\left[\log\frac{q(x)}{p(x)}\right],
\]
the expectation is guaranteed nonnegative, but the integrand can be positive or negative for individual samples. And \(k_1\) is exactly that integrand:
\[
k_1(x) = \log\frac{q(x)}{p(x)} = -\log r.
\]
So each sample value can indeed be positive or negative, same as the integrand in the KL definition.
So why does \(k_1\) have high variance?
It’s not the mere “sign flipping.” The real reason is that \(k_1\)’s value distribution is often wide (heavy-tailed). For example, if \(p(x)\) or \(q(x)\) is tiny for some sample, then \(\log\frac{q(x)}{p(x)}\) can be huge in magnitude (positive or negative). These extreme values dominate the finite-sample average, pushing the variance up. In other words, it’s the combination of extreme values plus positive/negative cancellation: cancellation means you need more samples to converge to the true mean, and extreme values make the sample variance itself larger. So the “half negative” comment in the blog is more of an intuition hook than a complete explanation.
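To put numbers on this, here is a toy categorical example with made-up probabilities: \(q\) puts 1% of its mass on a token that \(p\) considers ten times less likely, and that single rare token dominates \(k_1\)’s sample spread even though the true KL is tiny.

```python
# Toy illustration of the heavy-tail point; the distributions are made up.
import torch

torch.manual_seed(0)

q = torch.tensor([0.495, 0.495, 0.010])   # sampling distribution
p = torch.tensor([0.550, 0.449, 0.001])   # rare third token is 10x less likely under p

x = torch.multinomial(q, num_samples=100_000, replacement=True)
logr = (p[x] / q[x]).log()

k1 = -logr
k3 = logr.exp() - 1 - logr

true_kl = (q * (q / p).log()).sum()
print(f"true KL        : {true_kl:.4f}")
print(f"k1 mean / std  : {k1.mean():.4f} / {k1.std():.4f}")   # std dominated by the rare token
print(f"k3 mean / std  : {k3.mean():.4f} / {k3.std():.4f}")
```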
From this perspective, if we look at the other estimators \(k_2\) and \(k_3\), we see: \(k_2\) is always positive, so there’s no cancellation, but this introduces bias; squaring also smooths the magnitude when \(p\) and \(q\) are close (for \(|\log r| < 1\) the square shrinks it), reducing variance. \(k_3\) uses a control variate to knock out part of the fluctuation source, lowering variance while keeping unbiasedness (details next).
In PPO/GRPO, if you use \(k_1\) and the batch is small or the distributions are far apart, the KL estimate will jump around (because a few extreme samples can swing the mean hard). That makes the effective strength of the KL penalty unstable: it might suddenly be way too strong or too weak. Switching to a lower-variance estimator (\(k_2\) or \(k_3\)) makes each sample’s KL contribution steadier, less likely to be dominated by a handful of extreme samples.
Why can \(k_3\) be both unbiased and low-variance?
At first glance, \(k_3\) is always positive, so you might think its mean must be larger than \(k_1\)’s.
But remember: \(k_3\) is derived from \(k_1\) via a control variate. The blog’s reasoning goes like this:
\[
k = -\log r + \lambda\,(r - 1),
\]
where \(r = \frac{p(x)}{q(x)}\), and under \(x \sim q\) the added term’s expectation is
\[
\mathbb{E}_{x\sim q}\left[r - 1\right] = \sum_x q(x)\left(\frac{p(x)}{q(x)} - 1\right) = \sum_x p(x) - \sum_x q(x) = 0.
\]
So adding any multiple of \(r - 1\) doesn’t change the expectation. When \(\lambda = 1\):
\[
k_3 = (r - 1) - \log r, \qquad \mathbb{E}_{x\sim q}\left[k_3\right] = \mathbb{E}_{x\sim q}\left[-\log r\right] = \mathrm{KL}[q, p].
\]
This explains why \(k_3\)’s expectation equals \(k_1\)’s expectation, and equals the KL, making it an unbiased estimator.
The reason \(k_3\) has lower variance than \(k_1\) is this: \(k_1\) only has \(-\log r\), which can swing wildly (both positive and negative, with occasional huge values). But \(-\log r\) and \(r - 1\) are numerically highly correlated (when one grows, the other shrinks), and that correlation is negative. Adding \(r - 1\) is like injecting a negatively correlated term to cancel fluctuations. After the cancellation, what’s left in \(k_3\) is tighter in range, always positive, and therefore lower in sample variance.
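To close the loop, here is a quick numerical check of this argument (again on an arbitrary pair of nearby Gaussians of my choosing): \(r - 1\) has mean roughly zero under \(q\), it is strongly negatively correlated with \(-\log r\), and adding it leaves the mean untouched while shrinking the variance.

```python
# Sketch: the control variate (r - 1) is negatively correlated with -log r,
# so k3 = k1 + (r - 1) keeps the mean but has much lower variance.
import torch

torch.manual_seed(0)
q = torch.distributions.Normal(0.0, 1.0)
p = torch.distributions.Normal(0.1, 1.0)   # close to q, the regime where k3 shines

x = q.sample((1_000_000,))
logr = p.log_prob(x) - q.log_prob(x)
r = logr.exp()

k1 = -logr                 # unbiased but noisy
cv = r - 1                 # control variate: zero mean under q
k3 = k1 + cv               # = (r - 1) - log r, still unbiased, always >= 0

corr = torch.corrcoef(torch.stack([k1, cv]))[0, 1]
true_kl = torch.distributions.kl_divergence(q, p)

print(f"E[r - 1]        ~ {cv.mean():.5f}")
print(f"corr(k1, r - 1) ~ {corr:.3f}")                       # close to -1 when p and q are close
print(f"Var[k1] = {k1.var():.6f}   Var[k3] = {k3.var():.6f}")
print(f"E[k1]   = {k1.mean():.5f}   E[k3]   = {k3.mean():.5f}   (true KL = {true_kl:.5f})")
```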