TurboQuant, or: why are random compression schemes often nearly as good as carefully-designed ones?

Google’s TurboQuant paper introduces a new algorithm for compressing large language models, based on an old-but-still-cool “trick” of high-dimensional probabilistic geometry, the Johnson-Lindenstrauss (JL) lemma. Normally, this blog is about text editors and stuff, but the trick is cool enough that I couldn’t help writing about it.

Roughly, the JL lemma states that if you have n points in a high-dimensional space (dimension d), a completely random projection down to a much smaller k-dimensional subspace will be (in some sense) “nearly optimal” “in the general case.” Or, more specifically: with high probability, every pairwise distance between the points is preserved to within a factor of (1 ± ε), provided k is large enough relative to log(n)/ε².
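To make that concrete, here is a minimal sketch of a Gaussian random projection in plain Python (standard library only, so it is slow for large d). The k×d matrix has i.i.d. N(0, 1/k) entries, which keeps squared distances correct in expectation. This is one common construction; I'm assuming, not claiming, that it is close to what the notebook does, and all helper names are hypothetical.

```python
import math
import random


def gaussian_points(n, d, rng):
    """n points in R^d with i.i.d. standard normal coordinates."""
    return [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]


def random_projection_matrix(d, k, rng):
    """k x d matrix with i.i.d. N(0, 1/k) entries, so E||Rx||^2 = ||x||^2."""
    scale = 1.0 / math.sqrt(k)
    return [[rng.gauss(0.0, 1.0) * scale for _ in range(d)] for _ in range(k)]


def project(points, R):
    """Map each point x to R @ x."""
    return [[sum(r * xi for r, xi in zip(row, x)) for row in R] for x in points]


def dist(a, b):
    """Euclidean distance."""
    return math.sqrt(sum((ai - bi) ** 2 for ai, bi in zip(a, b)))


def distortion_ratios(points, projected):
    """Projected distance / original distance for every pair i < j."""
    n = len(points)
    return [
        dist(projected[i], projected[j]) / dist(points[i], points[j])
        for i in range(n)
        for j in range(i + 1, n)
    ]


rng = random.Random(0)
X = gaussian_points(n=40, d=400, rng=rng)
R = random_projection_matrix(d=400, k=200, rng=rng)
ratios = distortion_ratios(X, project(X, R))
print(f"pairs={len(ratios)} min={min(ratios):.3f} max={max(ratios):.3f}")
```

Note that nothing about the data went into R: the matrix is drawn before we ever look at the points.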

But, wait. Completely random?? Nearly optimal??? Really???? Yes. But it’s not as useful as it seems, and you can very often do much better, e.g., through Principal Component Analysis (PCA).

The goal of this article is to explain why, using interactive examples to help build intuition for where it breaks down.

Disclaimers: There are probably some errors in the code here, since (1) I do a bunch of tricks to make this fast and (2) it’s been a looooong time since I’ve hand-coded Jacobi eigendecomposition. Should be enough to help you build intuition though.

Histogram showing which projections have many “badly distorted” points

The JL lemma says that, with high probability, all ratios of projected distance to original distance stay within the (1 ± ε) band, provided your choice of k satisfies k ≥ C·log(n)/ε² for some constant C (usually written k = O(log(n)/ε²)).
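One explicit version of that bound comes from Dasgupta and Gupta's proof of the lemma: k ≥ 4·ln(n) / (ε²/2 − ε³/3). It is stated for squared distances, so treat this as a sketch of the order of magnitude rather than the exact band the histogram uses. The function name `jl_min_dim` is my own.

```python
import math


def jl_min_dim(n, eps):
    """Smallest k allowed by the Dasgupta-Gupta form of the JL bound.

    The bound is for (1 - eps)||u - v||^2 <= ||f(u) - f(v)||^2 <= (1 + eps)||u - v||^2.
    Note that the original dimension d does not appear at all.
    """
    if not 0.0 < eps < 1.0:
        raise ValueError("eps must be in (0, 1)")
    return math.ceil(4.0 * math.log(n) / (eps ** 2 / 2.0 - eps ** 3 / 3.0))


for n in (100, 1_000, 1_000_000):
    print(n, [jl_min_dim(n, eps) for eps in (0.5, 0.2, 0.1)])
```

For n = 1000 and ε = 0.1 this comes out a bit over 5,900 dimensions, which is larger than the original d of many datasets. That's one concrete way the guarantee is “not as useful as it seems.”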

But this depends on n, d, k, and ε. Play around with each to see what happens! Notice that lowering k precipitously produces badly distorted points.
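A rough, non-interactive stand-in for that slider experiment: fix one set of points, then measure what fraction of pairs falls outside the (1 ± ε) band as k shrinks. The sizes, ε = 0.2, and the list of k values are arbitrary choices of mine, and the helpers are repeated so the block runs on its own.

```python
import math
import random


def gaussian_points(n, d, rng):
    return [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]


def random_projection_matrix(d, k, rng):
    scale = 1.0 / math.sqrt(k)  # N(0, 1/k) entries
    return [[rng.gauss(0.0, 1.0) * scale for _ in range(d)] for _ in range(k)]


def project(points, R):
    return [[sum(r * xi for r, xi in zip(row, x)) for row in R] for x in points]


def dist(a, b):
    return math.sqrt(sum((ai - bi) ** 2 for ai, bi in zip(a, b)))


def badly_distorted_fraction(points, k, eps, rng):
    """Fraction of pairs whose distance ratio falls outside [1 - eps, 1 + eps]."""
    d = len(points[0])
    Y = project(points, random_projection_matrix(d, k, rng))
    n = len(points)
    bad = total = 0
    for i in range(n):
        for j in range(i + 1, n):
            ratio = dist(Y[i], Y[j]) / dist(points[i], points[j])
            total += 1
            if abs(ratio - 1.0) > eps:
                bad += 1
    return bad / total


rng = random.Random(1)
X = gaussian_points(30, 300, rng)
fractions = {k: badly_distorted_fraction(X, k, eps=0.2, rng=rng) for k in (5, 20, 80, 200)}
for k, frac in fractions.items():
    print(f"k={k:>3}: {frac:.1%} of pairs outside the (1 ± 0.2) band")
```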

Heatmap of “badly distorted” points

Same thing as above, basically, but as heatmaps. Here you can see which pairs of points get distorted.

Left heatmap is the original pairwise distances, middle is the projected pairwise distances, and the right is the empirical distortion ratio. If JL is working, the first two should look nearly identical and the third should be almost uniformly green.
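The three matrices behind those heatmaps can be sketched like this. The diagonal of the ratio matrix is set to 1.0 to avoid 0/0, and I'm assuming “green” means “inside the (1 ± ε) band”; the sizes and the helper names `pairwise_distances` and `ratio_matrix` are hypothetical.

```python
import math
import random


def dist(a, b):
    return math.sqrt(sum((ai - bi) ** 2 for ai, bi in zip(a, b)))


def pairwise_distances(points):
    """Symmetric n x n distance matrix (left and middle heatmaps)."""
    n = len(points)
    D = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1, n):
            D[i][j] = D[j][i] = dist(points[i], points[j])
    return D


def ratio_matrix(D, D_proj):
    """Projected / original distance per cell (right heatmap); diagonal fixed at 1.0."""
    n = len(D)
    return [[1.0 if i == j else D_proj[i][j] / D[i][j] for j in range(n)] for i in range(n)]


rng = random.Random(2)
n, d, k, eps = 12, 200, 50, 0.2
X = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]
R = [[rng.gauss(0.0, 1.0) / math.sqrt(k) for _ in range(d)] for _ in range(k)]
Y = [[sum(r * xi for r, xi in zip(row, x)) for row in R] for x in X]

D, D_proj = pairwise_distances(X), pairwise_distances(Y)
ratios = ratio_matrix(D, D_proj)
# Off-diagonal cells inside the band (the assumed "green" cells).
in_band = sum(abs(ratios[i][j] - 1.0) <= eps for i in range(n) for j in range(n) if i != j)
print(f"{in_band} of {n * (n - 1)} off-diagonal cells are inside the band")
```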


PCA vs Random Projection: FIGHT

Ok. So now the bad news. We said before that random projection performs nearly as well as a carefully selected projection “in the general case.” But we also said that PCA can perform better in other, specific cases. Which cases? Well, that’s the rub. When the data lies in (or near) a low-dimensional subspace (as nearly all datasets meaningful to humans do), PCA will perform much better.
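A small sketch of why the subspace matters, assuming we already know an orthonormal basis B of the subspace the data lives in (PCA's job is to estimate that basis from the data). Projecting onto B is an exact isometry on such data, while a random projection to the same k = 3 dimensions is not. All sizes and helper names here are my own choices.

```python
import math
import random


def dist(a, b):
    return math.sqrt(sum((ai - bi) ** 2 for ai, bi in zip(a, b)))


def project(points, R):
    return [[sum(r * xi for r, xi in zip(row, x)) for row in R] for x in points]


def orthonormal_basis(r, d, rng):
    """r orthonormal vectors in R^d via Gram-Schmidt on random Gaussian vectors."""
    basis = []
    while len(basis) < r:
        v = [rng.gauss(0.0, 1.0) for _ in range(d)]
        for b in basis:
            c = sum(vi * bi for vi, bi in zip(v, b))
            v = [vi - c * bi for vi, bi in zip(v, b)]
        norm = math.sqrt(sum(vi * vi for vi in v))
        if norm > 1e-12:
            basis.append([vi / norm for vi in v])
    return basis


def max_distortion(points, projected):
    """Largest |projected / original - 1| over all pairs."""
    n = len(points)
    return max(
        abs(dist(projected[i], projected[j]) / dist(points[i], points[j]) - 1.0)
        for i in range(n)
        for j in range(i + 1, n)
    )


rng = random.Random(3)
n, d, r = 20, 50, 3
B = orthonormal_basis(r, d, rng)  # the hidden 3-dimensional subspace
# Each point is a random combination of the r basis vectors, so it lies exactly in span(B).
coeffs = [[rng.gauss(0.0, 1.0) for _ in range(r)] for _ in range(n)]
X = [[sum(z[m] * B[m][i] for m in range(r)) for i in range(d)] for z in coeffs]

oracle = max_distortion(X, project(X, B))  # project onto the subspace itself
R = [[rng.gauss(0.0, 1.0) / math.sqrt(r) for _ in range(d)] for _ in range(r)]
random_proj = max_distortion(X, project(X, R))
print(f"subspace projection: {oracle:.2e}, random projection: {random_proj:.2f}")
```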

Though this visualization is not perfect (e.g., the JL lemma is not meant for projections down to 2 or 3 dimensions, I made it as a student in 2011, and the projections are not orthogonal to the plane, lol), you can kind of get a sense for how it works.

Now, let’s see how it performs in practice by doing the projections and comparing it to PCA.

5a. Random Data: Random Projection ≈ PCA

When the data is “isotropic Gaussian” (i.e., basically random points with no preferred direction), random projection is about as good as PCA. So both point clouds are kinda spread out in similar ways.

5b. Low-Rank Subspace: PCA ≫ RP

When the data lives on a low-rank subspace, PCA finds it perfectly and, as long as k is at least the rank of that subspace, preserves distances more or less exactly, while the random projection's cloud spreads out. RP still satisfies the JL bound, but PCA dominates. Notice the orange points are now on the y=x line, which indicates essentially no distortion.
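Here is a rough, standard-library-only reproduction of the 5a/5b comparison. PCA uses a cyclic Jacobi eigensolver: my own version, not the notebook's. PCA is a pure orthogonal projection (no rescaling), while the random projection scales by 1/√k. So I compare the two by the relative spread of the distance ratios (standard deviation over mean) rather than by raw ratios. All sizes are arbitrary.

```python
import math
import random
import statistics


def project(points, R):
    return [[sum(r * xi for r, xi in zip(row, x)) for row in R] for x in points]


def dist(a, b):
    return math.sqrt(sum((ai - bi) ** 2 for ai, bi in zip(a, b)))


def jacobi_eigh(A, max_sweeps=50, tol=1e-20):
    """Eigendecompose a symmetric matrix with cyclic Jacobi rotations.

    Returns (eigenvalues, V) where column c of V is the eigenvector for eigenvalues[c].
    """
    n = len(A)
    A = [row[:] for row in A]
    V = [[float(i == j) for j in range(n)] for i in range(n)]
    for _ in range(max_sweeps):
        off = sum(A[i][j] ** 2 for i in range(n) for j in range(n) if i != j)
        if off < tol:
            break
        for p in range(n - 1):
            for q in range(p + 1, n):
                if abs(A[p][q]) < 1e-300:
                    continue
                # Rotation angle chosen so that the new A[p][q] becomes zero.
                theta = (A[q][q] - A[p][p]) / (2.0 * A[p][q])
                t = math.copysign(1.0, theta) / (abs(theta) + math.sqrt(theta * theta + 1.0))
                c = 1.0 / math.sqrt(t * t + 1.0)
                s = t * c
                for m in range(n):  # A <- A J (columns p and q)
                    amp, amq = A[m][p], A[m][q]
                    A[m][p], A[m][q] = c * amp - s * amq, s * amp + c * amq
                for m in range(n):  # A <- J^T A (rows p and q)
                    apm, aqm = A[p][m], A[q][m]
                    A[p][m], A[q][m] = c * apm - s * aqm, s * apm + c * aqm
                for m in range(n):  # accumulate eigenvectors: V <- V J
                    vmp, vmq = V[m][p], V[m][q]
                    V[m][p], V[m][q] = c * vmp - s * vmq, s * vmp + c * vmq
    return [A[i][i] for i in range(n)], V


def pca_basis(points, k):
    """Top-k principal directions (as rows) of the sample covariance."""
    n, d = len(points), len(points[0])
    mean = [sum(p[i] for p in points) / n for i in range(d)]
    Xc = [[p[i] - mean[i] for i in range(d)] for p in points]
    C = [[sum(x[i] * x[j] for x in Xc) / n for j in range(d)] for i in range(d)]
    eigenvalues, V = jacobi_eigh(C)
    top = sorted(range(d), key=lambda c: eigenvalues[c], reverse=True)[:k]
    return [[V[i][c] for i in range(d)] for c in top]


def relative_spread(points, projected):
    """std / mean of projected-to-original distance ratios (a scale-free distortion measure)."""
    n = len(points)
    ratios = [dist(projected[i], projected[j]) / dist(points[i], points[j])
              for i in range(n) for j in range(i + 1, n)]
    return statistics.pstdev(ratios) / statistics.fmean(ratios)


rng = random.Random(4)
n, d, k = 40, 20, 3
isotropic = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]
W = [[rng.gauss(0.0, 1.0) for _ in range(k)] for _ in range(d)]  # d x k mixing matrix
latent = [[rng.gauss(0.0, 1.0) for _ in range(k)] for _ in range(n)]
low_rank = [[sum(W[i][m] * z[m] for m in range(k)) for i in range(d)] for z in latent]

results = {}
for name, X in (("isotropic", isotropic), ("low-rank", low_rank)):
    R = [[rng.gauss(0.0, 1.0) / math.sqrt(k) for _ in range(d)] for _ in range(k)]
    results[name] = {
        "rp": relative_spread(X, project(X, R)),
        "pca": relative_spread(X, project(X, pca_basis(X, k))),
    }
    print(f"{name:>9}: RP spread={results[name]['rp']:.3f}  PCA spread={results[name]['pca']:.2e}")
```

On the isotropic data both spreads should land in the same ballpark. On the rank-3 data PCA's spread collapses to roughly floating-point noise, because k equals the rank.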

Empirical Failure Probability

So, where does it break down in practice? Let’s take a look. For each target dimension k, we run 50 random projections and count how often any pair violates the (1±ε) bound. The JL lemma predicts this probability drops exponentially in k. The curve below confirms it.
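The experiment can be sketched like this. I use 10 trials instead of 50 to keep pure Python fast, and the point count, d, ε = 0.3, and the k values are my own arbitrary choices. Roughly speaking, a union bound over the ~n²/2 pairs puts the failure probability at most around n²·exp(−c·k·ε²) for some constant c, which is where the exponential decay in k comes from.

```python
import math
import random


def project(points, R):
    return [[sum(r * xi for r, xi in zip(row, x)) for row in R] for x in points]


def dist(a, b):
    return math.sqrt(sum((ai - bi) ** 2 for ai, bi in zip(a, b)))


def failure_rate(points, k, eps, trials, rng):
    """Fraction of independent random projections in which at least one pair
    leaves the (1 +/- eps) band on the distance ratio."""
    n, d = len(points), len(points[0])
    original = {(i, j): dist(points[i], points[j]) for i in range(n) for j in range(i + 1, n)}
    failures = 0
    for _ in range(trials):
        R = [[rng.gauss(0.0, 1.0) / math.sqrt(k) for _ in range(d)] for _ in range(k)]
        Y = project(points, R)
        if any(abs(dist(Y[i], Y[j]) / d_ij - 1.0) > eps for (i, j), d_ij in original.items()):
            failures += 1
    return failures / trials


rng = random.Random(5)
X = [[rng.gauss(0.0, 1.0) for _ in range(200)] for _ in range(20)]  # fixed points, 20 x 200
rates = {k: failure_rate(X, k, eps=0.3, trials=10, rng=rng) for k in (5, 20, 100)}
for k, rate in rates.items():
    print(f"k={k:>3}: failure rate {rate:.0%}")
```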

Conclusions

And that’s about it. If you notice bugs, let me know; I did this whole thing in about 30 minutes by hand. I’m sure there are problems in here.

Appendix

The jlCore cell provides all shared utilities: seeded RNG (mulberry32), Gaussian point generation, random projection, distance computation, and a global state store.
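For anyone who wants similar building blocks outside the notebook, here is a hypothetical Python port of two of them. The first is mulberry32 as it is commonly published in JavaScript, with 32-bit arithmetic emulated by masking. The second is a Gaussian generator built on it. I'm assuming Box-Muller for the Gaussians; I haven't checked what the jlCore cell actually uses, nor verified that this port reproduces the notebook's exact sequence.

```python
import math


def mulberry32(seed):
    """Port of the mulberry32 generator: returns a function yielding floats in [0, 1)."""
    state = seed & 0xFFFFFFFF

    def next_float():
        nonlocal state
        state = (state + 0x6D2B79F5) & 0xFFFFFFFF
        t = state
        # Math.imul(t ^ t >>> 15, t | 1)
        t = ((t ^ (t >> 15)) * (t | 1)) & 0xFFFFFFFF
        # t ^= t + Math.imul(t ^ t >>> 7, t | 61)
        t ^= (t + (((t ^ (t >> 7)) * (t | 61)) & 0xFFFFFFFF)) & 0xFFFFFFFF
        # ((t ^ t >>> 14) >>> 0) / 2^32
        return ((t ^ (t >> 14)) & 0xFFFFFFFF) / 4294967296.0

    return next_float


def gaussian(rand):
    """One standard normal sample via the Box-Muller transform (assumed method)."""
    u1 = 1.0 - rand()  # in (0, 1], so log(u1) is finite
    u2 = rand()
    return math.sqrt(-2.0 * math.log(u1)) * math.cos(2.0 * math.pi * u2)


def gaussian_points(n, d, seed):
    """n seeded Gaussian points in R^d (hypothetical helper)."""
    rand = mulberry32(seed)
    return [[gaussian(rand) for _ in range(d)] for _ in range(n)]


pts = gaussian_points(n=5, d=3, seed=42)
print(pts[0])
```

Seeding everything from one integer is what lets the interactive plots redraw the same points when you move a slider.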