
From Memorization to Reasoning in the Spectrum of Loss Curvature

kingstnap

A very similar idea is presented in the first 5 minutes of this recent talk, though there it comes more from observing a kink in loss curves.

https://youtu.be/UyK3DgWY7yw?si=NN3f9Erik8o_Nfbs

andy12_

Very concise summary of the procedure described in this paper:

1. Run the model once across a dataset to estimate loss curvature per MLP weight matrix via K-FAC (activation/gradient covariances).

2. Decompose each weight matrix into curvature-ordered components; low-curvature directions correspond most closely to verbatim memorization, higher curvature to shared/general mechanisms.

3. Edit by dropping the low-curvature subspace and keeping only the top directions.
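The three steps above can be sketched for a single linear layer. This is a minimal NumPy sketch of the general K-FAC (Kronecker-factored approximate curvature) idea, not the paper's code: it assumes the layer computes y = W @ a, that `acts` and `grads` (hypothetical names) hold per-token input activations and loss gradients with respect to the layer output, and that components are ranked by the K-FAC eigenvalue product. Whether gradients use true or model-sampled labels, and how the keep fraction is chosen, are glossed over.

```python
import numpy as np


def kfac_factors(acts, grads):
    """K-FAC factors for one linear layer y = W @ a.

    acts:  (n, d_in)  input activations a, one row per token/example
    grads: (n, d_out) loss gradients w.r.t. the layer output y
    Returns A = E[a a^T] (d_in x d_in) and G = E[g g^T] (d_out x d_out).
    """
    n = acts.shape[0]
    return acts.T @ acts / n, grads.T @ grads / n


def curvature_prune(W, A, G, keep_frac=0.5):
    """Keep only the highest-curvature components of W (d_out x d_in).

    K-FAC approximates the curvature as A (x) G, whose eigenvectors are the
    rank-1 matrices outer(U_G[:, i], U_A[:, j]) with eigenvalue lam_G[i] * lam_A[j].
    """
    lam_A, U_A = np.linalg.eigh(A)
    lam_G, U_G = np.linalg.eigh(G)
    C = U_G.T @ W @ U_A               # coefficients of W in that eigenbasis
    curv = np.outer(lam_G, lam_A)     # curvature of each rank-1 component
    k = int(round(keep_frac * curv.size))
    order = np.argsort(curv, axis=None)[::-1]  # flat indices, sharpest first
    mask = np.zeros(curv.size, dtype=bool)
    mask[order[:k]] = True
    # Zero the low-curvature coefficients and map back to weight space
    return U_G @ (C * mask.reshape(curv.shape)) @ U_A.T


# Usage on random stand-in data (a real run would collect acts/grads from the model)
rng = np.random.default_rng(0)
A, G = kfac_factors(rng.normal(size=(1000, 8)), rng.normal(size=(1000, 4)))
W_edited = curvature_prune(rng.normal(size=(4, 8)), A, G, keep_frac=0.5)
```

The data pass only enters through A and G; the edit itself is a fixed linear projection per weight matrix, which is why it can be applied after training without any gradient steps.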

getnormality

Thank you!

I think you may have accidentally switched low and high in #2, no? The abstract speaks of high curvature as associated with memorization:

> curvature for memorized training points is much sharper than non memorized

andy12_

Actually, no! Look at this passage in the paper:

> In extending from studying per-example to bulk memorization, we propose a novel inversion of the previous interpretation of loss curvature: while individual memorized points are associated with high curvature, the direction of curvature varies across examples, meaning that, averaged across multiple examples, memorization directions are actually flatter than generalizing directions, which maintain a consistent moderate curvature across points
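The inversion in that passage can be checked with a toy model. This sketch is an illustration under loud assumptions, not anything from the paper: each memorized example contributes a sharp rank-1 curvature `h_mem` along its own random direction, and every example shares a moderate curvature `h_gen` along one common direction `u`. Per example, memorization is sharper (`h_mem > h_gen`), but after averaging the Hessians the shared direction comes out sharper than a typical memorization direction.

```python
import numpy as np


def averaged_curvatures(d=500, n=200, h_mem=100.0, h_gen=5.0, seed=0):
    """Toy model: average per-example Hessians, then measure curvature.

    Example i's Hessian is h_mem * v_i v_i^T (sharp, example-specific
    direction v_i) plus h_gen * u u^T (moderate, shared direction u).
    Returns (curvature along u, mean curvature along the v_i) of the average.
    """
    rng = np.random.default_rng(seed)
    u = np.zeros(d)
    u[0] = 1.0                                     # shared "generalizing" direction
    V = rng.normal(size=(n, d))
    V /= np.linalg.norm(V, axis=1, keepdims=True)  # random unit memorization directions
    H = h_mem * (V.T @ V) / n + h_gen * np.outer(u, u)  # mean of the n Hessians
    gen_curv = u @ H @ u
    mem_curv = np.mean(np.sum((V @ H) * V, axis=1))     # v_i^T H v_i, averaged
    return gen_curv, mem_curv


gen, mem = averaged_curvatures()
print(f"shared direction: {gen:.2f}, memorization directions: {mem:.2f}")
```

With these defaults the shared direction should come out around 5 and the memorization directions below 1, even though each memorized point on its own is 20 times sharper: each example-specific direction is diluted by roughly a factor of n in the average, while the shared one is not.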

radarsat1

This sounds more correct to me. I've read somewhere that better generalization is usually associated with wider, smoother minima, and that this is why regularization is important: it has a smoothing effect on the loss landscape.

getnormality

Yes. This is also not hard to see intuitively from scratch.

Say you have a smooth but highly flexible model y = f(x) and some data points you are fitting with a machine learning algorithm. For whatever reason, the algorithm decides it wants to reduce training error by interpolating one specific point, (x0, y0), without hurting training error on nearby points. The most direct way, guaranteed to succeed, is to make f(x0) = y0 exactly by adding a spike of height y0 - f(x0) at x0 (informally, a Dirac delta) and leaving the rest of f exactly as-is. But a differentiable model cannot do this, because the spike would create a discontinuity. The next best thing such a model can do is replace the spike with a smooth but very narrow bump (e.g. a Gaussian). That narrow bump will inevitably have extremely high curvature at x0: its slope is zero at the peak, and it has to bend back to meet the rest of f within a very short distance.

Think of driving: if you have to change lanes in a very short distance, you're going to have to steer hard. Steering is curvature.
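The bump argument can be made quantitative. Assuming the Gaussian bump from the example, b(x) = h * exp(-(x - x0)^2 / (2 sigma^2)) with height h = y0 - f(x0) and width sigma, its second derivative at the peak is b''(x0) = -h / sigma^2, so halving the width quadruples the curvature. The sketch below (function names are hypothetical) checks this with a central finite difference.

```python
import numpy as np


def bump(x, x0, h, sigma):
    """Narrow Gaussian bump of height h centered at x0 with width sigma."""
    return h * np.exp(-((x - x0) ** 2) / (2 * sigma ** 2))


def curvature_at_peak(h, sigma, x0=0.0):
    """Central finite-difference estimate of the bump's second derivative at x0."""
    eps = 1e-3 * sigma  # step small relative to the bump width
    up = bump(x0 + eps, x0, h, sigma)
    mid = bump(x0, x0, h, sigma)
    down = bump(x0 - eps, x0, h, sigma)
    return (up - 2 * mid + down) / eps ** 2


h = 1.0  # height needed to hit y0 exactly
for sigma in [1.0, 0.1, 0.01]:
    est = curvature_at_peak(h, sigma)
    print(f"sigma={sigma:<5} numeric={est:.4g} analytic={-h / sigma ** 2:.4g}")
```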

woadwarrior01

That's very reminiscent of the idea behind the SAM (Sharpness-Aware Minimization) family of optimizers.
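For reference, a minimal sketch of the basic SAM update on a hypothetical toy quadratic loss: first step to the approximate worst-case point within an L2 ball of radius rho along the normalized gradient, then apply the gradient computed there to the original weights. A narrow, high-curvature minimum loses much more within that radius, so SAM tends to steer away from it. `sam_step`, `lr`, and `rho` are illustrative names, not any library's API.

```python
import numpy as np


def sam_step(w, grad_fn, lr=0.01, rho=0.05):
    """One Sharpness-Aware Minimization step.

    1. Perturb w toward the (first-order) worst case within an L2 ball of
       radius rho: eps = rho * g / ||g||.
    2. Descend using the gradient evaluated at the perturbed point w + eps.
    """
    g = grad_fn(w)
    eps = rho * g / (np.linalg.norm(g) + 1e-12)  # guard against a zero gradient
    return w - lr * grad_fn(w + eps)


H = np.diag([10.0, 1.0])  # toy Hessian: sharp along axis 0, flat along axis 1


def quad_grad(w):
    """Gradient of the toy loss L(w) = 0.5 * w^T H w."""
    return H @ w


w = np.array([1.0, 1.0])
for _ in range(50):
    w = sam_step(w, quad_grad, lr=0.05, rho=0.05)
print(w)
```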

vessenes

Thank you for this huge time saver.

Now, about the paper: that’s super interesting. I imagine the dream here is to distill the model down into a “reasoning” core. Or maybe to reclaim space for more generalization. Lots of interesting use cases.