Stochastic collapse AND (explainable AI/neuro OR beating power laws)
Date Posted:
September 25, 2023
Date Recorded:
September 25, 2023
Speaker(s):
Surya Ganguli, Stanford University
PRESENTER: Now, we'll have a talk by Surya Ganguli. He's a physicist, theoretical neuroscientist, and professor of Applied Physics at Stanford. And he leads the neural dynamics and computation lab. Yeah, won a lot of awards-- let him talk about collapse now.
SURYA GANGULI: OK, thanks. Yeah, it sounds like today is the day of collapse. I'm going to talk about a different type of collapse. And my title is a Boolean logic statement. We'll talk about collapse and explainable AI or beating power laws-- beating neural power laws. I know this is kind of a mixed crowd that's interested in theory and also the brain and so forth, so it'll be kind of a choose your own adventure talk.
So, basically, what's on the menu? The first part, which I'll definitely talk about, is some of the latest work I'm excited about: what kinds of implicit biases in SGD in actual deep networks-- not just in toy settings-- can combat overparameterization to aid generalization. And we'll uncover a phenomenon of stochastic collapse-- slightly different from what you've been hearing about today-- onto much simpler subnetworks with many either dead or redundant neurons.
We'll show that collapse of such networks is prevalent and can help generalization. And it provides one explanation for why training with large learning rates early in training for a long time-- even after the training loss has just plateaued-- can be beneficial for generalization. And then we won't have time to cover both of these second parts, but I'll give you a chance to vote on which part you want to hear about. I'll take the vote after we cover the first part.
So option 1 deals with the question-- it's at the intersection of AI and neuroscience-- if we have a really good predictive model of the circuit, what the heck do we do with it? For example, when we do this in neuroscience, are we replacing something we don't understand-- i.e., the brain-- with something else we don't understand, just a good predictive model of it? So we'll show how, for example, to apply explainable AI based methods to model reduce a very good model of the retina. And we'll explain how this one model reproduces decades worth of experiments. So that's option 1.
The other option-- more theoretical-- is these neural scaling laws, where error drops off as a power law with, say, the amount of compute, the amount of data, or model size, have captured the imagination of industry and academia. But they're kind of unsustainable. The exponents are very shallow. So can we beat these neural scaling laws? And we'll use statistical mechanics to develop a theory, at least for data selection, to beat neural scaling laws with respect to data. And we'll apply these algorithms to actually do that.
So you can start to think about which one you'd rather hear about, but I'll take the vote when the time comes. OK, so the first part-- this implicit bias of SGD. So this is some of our recent work to come out of our lab done with a set of really talented grad students. So a classic question that showed up even today is, why do overparametrized deep neural networks generalize? They have so many parameters. They're so flexible. Why don't they overfit the training data?
The key idea is perhaps SGD has implicit biases that don't allow these networks to explore all of function space. So they're not as flexible as you might think they are. At least, SGD doesn't allow them to be as flexible as their parameter count might suggest. There's lots of work on this topic in toy settings and so forth. We're going to provide a new additional implicit bias driven by symmetry that applies very generally to many different deep networks that are not toy models. And we'll show evidence for it empirically.
So what's the idea? The idea is that, in weight space, there are certain invariant sets, which, by definition, are subsets of parameter space-- in our case, they'll be linear subspaces-- such that any initialization within the subspace will remain within that subspace for all future iterates of SGD, no matter what mini-batches you choose, no matter what learning rates you choose, no matter what hyperparameters you choose, as long as you're doing SGD updates.
So if such an invariant set were to exist, it would be a powerful constraint. So what are some invariant sets that do indeed exist? One is a sign-invariant set: if you have a neuron whose nonlinear activation function goes through the origin, there's a sign symmetry such that if you kill the incoming weights and the outgoing weights of that neuron, they will never recover under future iterates of SGD.
That's a linear subspace of weight space where the incoming and outgoing weights are 0. And once they become 0, they will remain 0 forever. Another invariant set, called a permutation-invariant set, is where you have a pair of neurons, say, that have the same incoming weight vector, the same outgoing weight vector, and the same biases. It turns out that once that condition holds, no future iterate of SGD can break it.
The way you prove that is super simple. You just start within the invariant set, so the invariant condition holds. You compute the gradient with respect to any example. And you show that the gradient lies within the subspace that obeys either constraint. So you can show that the gradient with respect to every single example lies within the subspace. And then, once you have that, you know that all future iterates of SGD will remain in that subspace no matter what mini-batch you choose, and no matter what learning rate you choose, because it's a linear subspace.
So, basically, yeah, this is the dynamics-- the dynamical update equation for SGD. Theta is your parameters. L is the loss with respect to one example. This is the gradient. Because every gradient with respect to every example lies within the subspace, the projection of the gradient onto the orthogonal complement is 0. For example, in the case of permutation-invariant sets, it's a codimension-3 subspace of parameter space because we have three conditions-- incoming weights equal, outgoing weights equal, biases equal. So it's a codimension-3 subspace.
And in any given layer-- let's say you have n neurons in the layer. You have n choose two such codimension 3 subspaces. So it's as if your parameter space is riddled with these invariant subsets. Now, the next question you're probably thinking of is, so what? If I start in the subset, which is a codimension 3 subset for any one of these pairs, I'll stay within it. And that's great. But will I be attracted to it from a random initial condition? That's what the rest of this part of the talk is going to be about.
The gradient can never break that condition. So what that means is the projection of the gradient onto the three-dimensional subspace perpendicular to this codimension 3 subspace is 0. Or a simpler way to say that is the gradient lies within the subspace. Or another way to say that is once two neurons become equivalent, SGD can't break that equivalence. Those are all three ways to say the same thing. And the math is just you compute the gradient with respect to every individual example. And you show it lies within the subspace. It's just a couple of lines to check that.
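To write that check out in symbols (my own notation, not necessarily what's on the slide), the SGD update and the invariance statement are

$$\theta_{t+1} \;=\; \theta_t \;-\; \eta\,\nabla_\theta \ell(\theta_t;\, x_{i_t}), \qquad \theta \in S \;\Longrightarrow\; P_{S^{\perp}}\,\nabla_\theta \ell(\theta;\, x) \;=\; 0 \;\text{ for every example } x,$$

where S is the invariant linear subspace and P_{S-perp} projects onto its orthogonal complement. Because the condition holds example by example, it holds for every mini-batch gradient, so once an iterate is in S it never leaves, whatever the learning rate schedule.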
Yeah, this is without momentum. But I think you can generalize it. So then as long as the initial momentum vector lies within the subspace, then everything follows. Yeah, so you're never going to initialize on any invariant subspace. But you can get attracted to it. And that's what the rest of this-- yes, we'll talk about the attraction. That's the major question that's the subject of this talk.
OK, so how are we going to prove that, at least locally, you can get attracted to it? And then in simulations we'll show that globally you're actually attracted to it. So what we're going to do is we're going to model SGD using stochastic gradient flow. So what we're going to do is we're going to model the dynamics of SGD as a stochastic differential equation in continuous time. SGD is a discrete time update process. But to the extent that the learning rate eta is not that large, this can be well approximated by this gradient flow-- this continuous time stochastic differential equation.
And if you haven't encountered these stochastic differential equations, don't worry about it. I'll try to explain all of it intuitively. There's a drift term, which is just gradient descent on your full batch loss landscape. So this models the average of a mini-batch. So if I average over all the data points, I get the full batch gradient. And that's what the drift term is in the stochastic process. That's the first moment of the random update of diffusion. But then I have to model the second moment.
So you're going to get some variability in the mini-batch to mini-batch gradient. And that variability is nothing more than the covariance of these gradient terms with respect to all the examples in the entire data set. Yeah? You mean in this one? There's a couple of ways to take the scaling limit. And we took it in a way that we put all of the learning rate in the second term. There's a small reparametrization. But you could have done it in another way where there would be an eta here and a square root of eta here and so forth.
So, basically, what you have here is you have-- this is a random vector whose covariance matrix is just the covariance of the gradients across the entire data set. So you're going to have a non-trivial correlated noise driving the update. And that's what's showing up here. So this is a standard white noise process. So just think of it as independent noise coming into every parameter. But it's pre-multiplied by this matrix.
And this matrix times its transpose is what's called the diffusion matrix. And this diffusion matrix is nothing other than the covariance matrix of the gradients-- the covariance of the gradients. So this is now a continuous time process that matches the first and second moments of this discrete time process. And the critical issue-- so lots of people have modeled stochastic gradient descent with this stochastic differential equation.
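In symbols-- again my own notation, and with the learning rate placed according to the particular scaling limit being used-- the stochastic gradient flow is roughly

$$d\theta \;=\; -\nabla_\theta L(\theta)\,dt \;+\; \sqrt{\eta}\,\Sigma(\theta)^{1/2}\,dW_t, \qquad D(\theta) \;=\; \Sigma(\theta) \;=\; \mathrm{Cov}_x\!\left[\nabla_\theta \ell(\theta;\,x)\right],$$

where L is the full-batch loss, W_t is a standard Wiener process, and the diffusion matrix D is the covariance of the per-example gradients over the data set.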
But the key thing to think about is the non-trivial, non-white structure of the covariance. And what we've shown based on the existence of these invariant subspaces is that the gradient of every single example lies within the subspace, which means the covariance of the gradients has no power in the dimensions perpendicular to the subspace. All the gradients lie here, so there can be no covariance along this direction if you're on the subspace.
So what that means is if I take this diffusion operator and apply a projection operator perpendicular to the subspace-- a projection operator onto the subspace that's the orthogonal complement of the invariant subspace-- on each side, then I'll get 0. So this is just-- intuitively, what this is saying is the diffusion perpendicular to the subspace vanishes on the subspace. That's really important. In physical systems, diffusion is proportional to temperature, roughly.
So what that's saying is that the dynamics is infinitely cold or 0 temperature on the subspace, at least in dimensions perpendicular to it. And that will have important implications for the attractivity of stochastic gradient descent to the subspace. Yeah. It's an approximation that holds for small learning rates. That's right. So all the theorems we're going to prove will be about SGF. We'll derive predictions. And we'll test those predictions for SGD. And we'll see that the approximation is good enough in practice.
So, now, let's give you a very simple example of what can happen when you have a position-dependent diffusion. So let's think about a position-dependent diffusion in a double-well potential. I'll explain what the GIF is doing later. But you have a double-well potential. And just think of this as your loss landscape for a neural network with one parameter. And you're doing gradient descent on the loss landscape. So your drift term is just the negative gradient of the loss, as it usually is.
But now let's imagine, for whatever reason-- sorry. It's going to just go away in one sec. It'll come back. Let's say, for whatever reason, the diffusion term is growing. There should be an absolute value here. But the diffusion term is growing as you get further away from the origin. So this double-well potential has a maximum at the origin and two minima over here. But the diffusion is large here and large here. And it vanishes to exactly 0 here.
So, roughly in physical terms, this is like a landscape with a mountain here. And the top of the mountain is really cold. But it's really hot here. So then there are certain phase transitions in the dynamics of this position-dependent diffusion problem. And it has to do with this parameter-- sorry, this parameter zeta. So let's imagine that zeta is small. So the diffusion term is small. So everything's dominated by the drift.
Then what happens is the steady-state distribution of this process is it looks like this distribution. Let's make zeta larger. Then the steady-state probability distribution spreads out. Now, the really interesting thing is if zeta is above a certain threshold, this is what happens. Everything concentrates at the origin. This is when this position-dependent diffusion is dominating. The entire process gets stuck at the maximum. It doesn't go to the minimum. And that's what you see here.
Now, what's the intuition for that? The intuition for that is if zeta is large, this position-dependent diffusion is dominating everything. It's really hot here. So if the particle ends up here, it's going to be buffeted around by the high-temperature diffusion. But it's really cold here. If it gets to here, it's not going to be buffeted around. So even though the top of the mountain is giving you drift away from the maximum, the coldness at the top of the mountain freezes you there. And it pins you there because it's too hot elsewhere.
So if you enter a hotter region, you'll get buffeted back to the top. OK, that's intuition. I'll give you a mathematical formula about that. Yeah. So for stochastic gradient flow there will not be, because this is the steady-state distribution. For stochastic gradient descent, because of the finite update, it'll bounce around near the top. And you'll see simulations where that happens. In fact, the position of these points-- these specific points-- corresponds to a stochastic gradient descent, actually.
Oh, so the simulations of these dots are SGD at some finite eta. Right. So we know that SGF is a good approximation to SGD at small eta. OK. OK, so we can prove theorems about this rather than just wave our hands. We borrow a concept from stochastic control theory. We have to define whether or not a fixed point is stable, or whether it's locally attractive.
But this is a stochastic process. We need a stochastic notion of attractivity. So there is such a notion from stochastic control theory called stochastic attractivity. It has some epsilons and deltas and so forth. But, roughly, all you need to worry about is this informal statement. The basic condition for stochastic attractivity is if you want all future iterates to remain close to a set with high probability, there always exists an initial condition that achieves that.
So, roughly, it's saying if you want to stay close to it, as long as you're this close to begin with, you will stay as close as you want for all future iterates. So that's a stochastic attractivity condition. And we can prove a theorem for when you satisfy this stochastic attractivity condition in the double-well potential. And this will formalize some of the intuition that I was giving you. So, because it's a local condition, we only have to worry about the shape of the loss landscape in the vicinity of the fixed point-- this local maximum.
So let's zoom in here and Taylor expand. The Taylor expansion is all we need to prove the local condition. So here's our loss landscape locally near theta. And it's a negative curvature maximum. So we just have a negative-- we have a negative quadratic here. So this quadratic has a second order-- its first non-zero coefficient is second order. So it has a second derivative here that's just a number. And that number will play a key role. The more negative it is, the more repulsive the drift term is.
Similarly, the diffusion-- we're assuming that the diffusion vanishes at this point. And then it's going to grow away from the point. So the first non-zero term in the Taylor expansion is going to be quadratic. So this is the growth, and this is the second-order term in the Taylor expansion. The condition for stochastic attractivity is that the curvature of the loss function plus the curvature of the diffusion should be bigger than 0.
And this is a necessary and sufficient condition in one dimension. So if, and only if, it's bigger than 0, you'll stochastically collapse to the maximum. So, basically, what is this saying? If the second derivative of the loss function is negative, then the deterministic drift is pushing you away. But as long as the diffusion is growing faster than the drift pushes you away, the diffusion will dominate, and it'll push you to the origin.
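Schematically, with my own symbols and up to convention-dependent constants, the one-dimensional condition is

$$\underbrace{L''(\theta^{*})}_{\text{curvature of the loss}} \;+\; \underbrace{D''(\theta^{*})}_{\text{curvature of the diffusion}} \;>\; 0,$$

where theta-star is the fixed point, L''(theta-star) is negative because it's a maximum, and D''(theta-star) is positive and measures how quickly the diffusion grows quadratically away from it.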
So that's a really funny non-intuitive fact about stochastic processes with position-dependent diffusion. Or another way to think about it is spatially dependent temperature. We're so used to thinking about temperature being constant in a room. But just imagine if that part of the room were, like, I don't know, really hot-- 130 degrees. And this part of the room were, like, I don't know, 60 degrees. We would all be down here, but for different reasons, because we prefer this location. But if we were particles, if this was 0 temperature and that was really high, eventually we'd all get stuck here.
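Here's a minimal simulation sketch of that whole double-well picture-- purely illustrative, with my own choice of potential, noise scale, and step size rather than the speaker's. The potential has a maximum at the origin, the diffusion vanishes there and grows linearly away from it, and above a critical noise strength the particles freeze at the maximum instead of settling into the wells.

```python
# Illustrative sketch (not the speaker's code): Euler-Maruyama simulation of a
# double-well potential U(x) = x**4/4 - x**2/2 (maximum at 0, minima at +-1)
# with position-dependent diffusion sigma(x) = zeta * |x| that vanishes at 0.
import numpy as np

def simulate(zeta, n_particles=2000, n_steps=20000, dt=1e-3, seed=0):
    rng = np.random.default_rng(seed)
    x = rng.uniform(-1.5, 1.5, size=n_particles)      # random initial positions
    for _ in range(n_steps):
        drift = -(x**3 - x)                            # -U'(x)
        noise = zeta * np.abs(x) * np.sqrt(dt) * rng.standard_normal(n_particles)
        x = x + drift * dt + noise                     # Euler-Maruyama step
    return x

for zeta in [0.5, 3.0]:                                # weak vs. strong diffusion
    x_final = simulate(zeta)
    frac_origin = np.mean(np.abs(x_final) < 0.1)
    print(f"zeta = {zeta}: fraction of particles stuck near the maximum = {frac_origin:.2f}")
```

With the weak diffusion the particles settle into the wells; with the strong diffusion they collapse onto the cold maximum, which is the phase transition being described.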
We can generalize this to a sufficient condition in higher dimensions-- a sufficient local condition in higher dimensions. So, roughly, this theorem-- we don't need to digest it carefully-- is a generalization of what we saw in one dimension. Again, in any direction normal to the subspace, there's a term coming from the second derivative of the loss function and terms related to the strength with which the diffusion grows perpendicular to the subspace.
As long as the sum of these terms is bigger than 0, then you'll be locally attracted to the subspace. So, roughly, the intuition is in higher dimensions where you have a subspace and a higher dimension orthogonal complement, as long as the diffusion tensor in every direction perpendicular to the invariant subspace-- as long as it grows sufficiently quickly to counteract the drift terms perpendicular to the subspace, you'll be attracted to the subspace.
OK, now, you might be asking me, great, is this stochastic collapse still just a theorist's fantasy? You've shown they exist. You've shown that under certain conditions you can be locally attracted to it. Do those conditions hold in practice? And B, even if you can be locally attracted to it, if you start from a completely random initial condition, will you actually get attracted to it in practice? Can you see it? So we did some simulations to check that. And, even to our surprise, we saw strong stochastic collapse to permutation-invariant subsets.
So what we did is we trained a VGG-16 on CIFAR-10 and a ResNet-18 on CIFAR-100. So these two rows are two different experiments. And we checked for stochastic collapse to permutation-invariant subsets. What we did was we took all the incoming weights to each neuron, and we hierarchically clustered them according to their cosine similarity. And this is the resulting similarity matrix, ordered by the hierarchical clustering.
And you can see there are lots of blocks of neurons that all have very similar incoming weight vectors. Then we took the outgoing weight vectors and plotted their similarity matrix in the same order that we got from the hierarchical clustering of the incoming weight vectors. And we saw a very similar pattern here, which tells us that pairs of neurons that have similar incoming weight vectors also have very similar outgoing weight vectors. And it turns out they have very similar biases as well.
And there are many, many blocks of similar neurons. So this is collapse close to a permutation-invariant subset. Now, each row corresponds to a different layer. This is conv11 of the VGG-16, conv12, and so forth. Early layers do not collapse. Now, why is that? It turns out that there's an expansion in the VGG-16 architecture where at some point you go from 256 channels to 512 channels. And it turns out all the layers that do collapse are after this expansion. The ones before don't collapse.
So, to come up with a convenient scalar measure of the stochastic collapse, we computed the fraction of independent neurons in a layer. What we mean by that is we set a very small tolerance-- say, a 1% tolerance-- and we asked whether there is any pair of neurons whose incoming and outgoing weight vectors are equal to each other within 1%. If so, we just call them the same neuron. So all of these neurons would be considered the same neuron under this measure-- they'd just count as one independent neuron.
And so we computed the number of independent neurons and divided by the total number of neurons. That gives the fraction of independent neurons. So if you have a small fraction, then you have strong stochastic collapse. OK, so we plotted the fraction of independent neurons as a function of layer for the VGG-16. And you see exactly that: right between conv7 and conv8, where the number of channels goes from 256 to 512, is when the stochastic collapse starts happening.
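Just to make that measure concrete, here is a minimal sketch-- my own illustrative code, not the paper's-- of how you could compute the fraction of independent neurons in one layer by merging any pair of neurons whose incoming and outgoing weight vectors agree within a small relative tolerance.

```python
# Illustrative sketch of the "fraction of independent neurons" measure.
import numpy as np

def fraction_independent(W_in, W_out, tol=0.01):
    """W_in: (n_neurons, fan_in) incoming weights; W_out: (fan_out, n_neurons) outgoing weights."""
    n = W_in.shape[0]
    merged = np.zeros(n, dtype=bool)
    for i in range(n):
        if merged[i]:
            continue
        for j in range(i + 1, n):
            if merged[j]:
                continue
            close_in = np.linalg.norm(W_in[i] - W_in[j]) <= tol * np.linalg.norm(W_in[i])
            close_out = np.linalg.norm(W_out[:, i] - W_out[:, j]) <= tol * np.linalg.norm(W_out[:, i])
            if close_in and close_out:
                merged[j] = True          # neuron j is redundant with neuron i
    return (~merged).sum() / n

# Random weights should give a fraction near 1.0, i.e., no stochastic collapse:
rng = np.random.default_rng(0)
print(fraction_independent(rng.standard_normal((64, 128)), rng.standard_normal((32, 64))))
```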
For the ResNet, there's interesting oscillations in which layers you get more stochastic collapse. And that has to do with the residual connections, which we have a theory for as well. But I won't go into the details of that. So at least what this is showing is that stochastic collapse is not a theorist fantasy. It actually happens in practice.
All right. So, now, so what is it good for? Yeah, did you have a question? Yeah, so what happens is the permutation-invariant condition becomes more stringent if you have the residual connections because you have to take those into account as well. But there still are certain layers in which you get it. Yeah, so if you want details about when that happens and why, it's in the supplemental material of the paper. It's a bit technical, so I don't want to spend too much time on it.
So the question is, what is stochastic collapse good for? It turns out it can help with generalization. Basically, these permutation-invariant sets are actually saddle points in the weight space of a neural network. They're higher training error saddle points, because you can always get to lower training error by having the neurons specialize. But it turns out that these higher training error saddle points have better generalization error. We can show that in a linear student-teacher setting, and then we can also test predictions of that for full neural networks.
Yeah, so let me discuss the case of a linear student-teacher-- the linear student-teacher setting. So I'm going to review some old work by Andrew Saxe, Andrew Lampinen, myself, and Jay on understanding generalization in the linear student-teacher setting for full batch gradient descent. So that work was all for full batch gradient descent. We're going to generalize it to stochastic gradient descent. And then we'll see that there's very different dynamics there between stochastic and full batch gradient descent. And there'll be some generalizable lessons for more general networks that we can extract from that analysis.
OK, so let's review some old work. Imagine that we just have a two-weight-layer linear neural network. It can be generalized to multi-layer linear networks, but the essential features show up in the two-weight-layer linear case. The learning dynamics is nonlinear because the error function-- just squared loss-- is quartic in the weights, so the gradient is cubic in the weights. If you write down the gradient descent equations-- again, full-batch gradient descent equations-- you get these nonlinear ODEs that, as promised, are cubic in the weights.
They're driven by the second-order statistics of your data-- in particular, sigma 11, which is the correlation matrix of the inputs, and sigma 31, which is the correlation matrix between inputs and outputs. So sigma 31 is a matrix that tells you how correlated is one input neuron with one output neuron across the training set. So it's the expected value, literally, of yx transpose.
We're going to work in a simplified setting where we assume that our inputs are white, so sigma 11 is the identity. So then everything that drives the dynamics of learning is sigma 31-- the input-output correlation matrix. For this setting, we can derive exact solutions to the learning dynamics of full-batch gradient descent, but for a very special class of initial conditions. It turns out that class of initial conditions is attracting for the full set of initial conditions.
But, basically, what we can show is that for this special class of initial conditions, where the singular vectors of the student are aligned with the singular vectors of sigma 31-- the training data-- all you have to worry about is the singular values of the student and how they evolve in time. So this is how the singular values evolve. They follow a particular function that rises like a sigmoid. And so, basically, what happens is the learning dynamics learns the singular modes of the training data-- of sigma 31-- mode by mode until, eventually, the product of weight matrices of the student equals sigma 31.
So let's imagine that sigma 31-- this matrix-- has a certain singular value decomposition with singular values given by s-- s1, s2, s3, and so forth. What do the singular modes of the student look like? If you start them from small initial conditions, the singular values of the student will look small. And they'll rise like a sigmoid. And they'll eventually equal the corresponding singular value of the training data, sigma 31. The time of this rise occurs at 1 over the singular value of the training data.
So that's just what falls out of the solutions to the learning dynamic. So sigma 31, which drives the learning, has a certain singular value decomposition. The product of the weight matrices-- so W32 times W21-- that input-output map has a certain singular value decomposition that's time dependent because it evolves over training time. What I'm describing for you is, how do the singular values of this product matrix evolve over time?
The answer is they start small because your initialization is small. And they rise to the corresponding data singular value of sigma 31. But they rise in a sigmoidal fashion. And this is what that rise looks like. So for a particular singular mode of the training data that has a singular value s, it'll eventually rise to s. And this transition will happen at time 1 over s. So a kind of a poetic way to state it is for this setting, stronger statistical structure in the data is learned earlier.
That's probably also true in nonlinear deep learning. The problem is we don't know how to quantify the strength of statistical structure in data. And we don't know how to predict the learning time as a function of the strength of that statistical structure. But in this linear setting, we know how to answer both of those questions. The statistical structure in question are the singular values of the input-output correlation matrix of the training data. And the time of learning is 1 over the singular value.
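If I'm remembering the exact solutions from that earlier work correctly-- so treat the precise form as a sketch-- each student singular value a(t) associated with a data singular value s evolves as

$$a(t) \;=\; \frac{s\,e^{2st/\tau}}{\,e^{2st/\tau} - 1 + s/a_0\,},$$

where a_0 is the small initial value and tau is a time constant set by the learning rate. This rises sigmoidally from a_0, saturates at s, and makes its transition on a timescale of order 1/s-- that's the learning time being one over the singular value.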
Another way to think about it is the learning mode. The modes of the student learn the modes of the training data independently. And so you could think of this as a single mode or just a single network. And you could think of the loss landscape as just the singular value of the training data minus a times b-- the weight of the first layer and the weight of the second layer. This is your loss landscape. If you just think about a single mode, or, equivalently, a two-layer network with one neuron in each layer-- two-weight layer network with one neuron in each layer-- and this is your loss landscape, it's easy to see what the global minima of these loss landscape are.
They satisfy the equation ab equals s. So that's a hyperbola with these two branches. These are the minima. It turns out ab equals 0 is a saddle point-- actually, a maximum. And the learning dynamics in ab space look like this. If you start here, you'll evolve to here. If you start with a equals b, you'll have a equals b forever and so forth. So this is the loss landscape for a single mode. Yeah? Why do the smaller singular values take more time to learn?
What happens is if you start from small initial conditions, you'll be stuck near the saddle point. And the escape time from that saddle point corresponds to 1 over the singular value. So if a singular value is large, the curvature of the loss landscape around the saddle point is large, and so the escape is fast. If the singular value is small, the curvature around the saddle point is small, and so the escape is slow. That's literally it. Of course, you can also just analytically solve the nonlinear ODE, get this formula, and see that that's the case.
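A minimal sketch of that single-mode picture-- illustrative only, with my own parameter choices: full-batch gradient descent on L(a, b) = (s - ab)^2 / 2 from a small initialization near the saddle at the origin, where the number of steps needed to escape scales roughly like 1 over s.

```python
# Illustrative sketch: saddle escape time in the single-mode toy landscape.
import numpy as np

def escape_time(s, a0=1e-3, lr=1e-2, thresh=0.5, max_steps=200000):
    a = b = a0                                       # small initialization near the saddle
    for t in range(max_steps):
        err = s - a * b
        a, b = a + lr * err * b, b + lr * err * a    # gradient descent on (s - a*b)^2 / 2
        if a * b > thresh * s:                       # mode is "half learned"
            return t
    return max_steps

for s in [4.0, 2.0, 1.0, 0.5]:
    print(f"s = {s}: escape time ~ {escape_time(s)} steps")
```

Larger data singular values escape the saddle sooner, which is the "stronger structure is learned earlier" statement in miniature.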
OK, this intuition about this interesting loss landscape for a very simple network will become more important-- even more important when we discuss stochastic collapse in the setting when we go beyond full batch gradient descent to stochastic gradient descent. So are there any questions about this? At least if you forget everything I said, if you think about just this, like, three neurons, two weights, and this loss landscape, this should be completely clear because a lot of the intuition for what I'm talking about can just be obtained from this left column.
In fact, this solution is a solution to the dynamics of this trajectory when a equals b, it turns out. All right, so another way to think about this is that here we were talking about each mode. When does it get learned? It gets learned at some time given by 1 over the singular value associated with that mode in the data. But you could also ask at any instant of time, which modes in the data have I learned? And which ones have I not yet learned?
So if I fix an instant of time and ask, how much has a mode learned? So what I can do is I can think about the singular value of the teacher-- sorry, of the student, which I'll call s-- divided by the singular value of the data, which I'm now going to call s hat here, and I look at that fraction, this is a measure of what fraction of the data you've learned already. And you can think of this learning process as a singular mode detection wave that sweeps in from large singular values to small singular values.
And as this wave sweeps in, you've learned all of the singular mode structure in your training data whose singular values are larger than roughly 1 over t-- the training time so far-- and those that are smaller are not yet learned. So that kind of makes sense. If you train for a certain amount of time, you've learned the large modes, and you haven't learned the small modes yet. So that's what's happening in the learning dynamics.
The singular mode detection wave will become important for trying to understand generalization. Yeah? OK, we haven't discussed the teacher yet. We've only discussed the dynamics of the student for an arbitrary training data set. And the sufficient statistics of that training data set that drive learning are sigma 31. And sigma 31 is just this correlation matrix between inputs and outputs, i.e., the expected value of yx transpose.
So we haven't discussed a teacher. The teacher is now going to generate training data that will eventually determine a sigma 31. We're going to shift to that now. So, literally, all I've told you is that if I have an arbitrary training set, this is what the dynamics of the student will be. Yeah? So I've kind of divided the problem up into, what is the dynamics of a student learning for any data set? And then I'm going to now create the data set from a teacher. And then I can discuss generalization error. So far, I can't discuss generalization error, because I haven't discussed the model for the data.
So now, this is a student-teacher setting. We're going to imagine that we have a low-rank teacher with some ground truth weight matrices-- these bars. The overbars always represent the teacher. All that really matters is the composite weight matrix of the teacher. It's a low-rank matrix, W-bar. I'm going to generate labels by taking random inputs, passing them through the teacher, and then corrupting the output with IID noise.
So that noise will be important because what it does is it buries the teacher in the training data. Now, remember, we know that the dynamics of the student is governed by the input-output correlation matrix of the training data, sigma 31. So what is sigma 31 in a situation where you have a teacher where you feed in random inputs. You get the outputs, and you corrupt it by noise. It turns out a simple calculation shows you that sigma 31 is nothing other than the teacher plus a noise matrix.
So this is exactly how the teacher is buried within the training data. The statistics of training data that drives learning contains within it the teacher. But it has a perturbation by a noise matrix. So this is an interesting random matrix. And now random matrix theory starts to play a role in the analysis. It plays an important role because what we have is we have a low-rank matrix perturbed by random noise.
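Schematically, in my own notation and ignoring exact normalizations, the input-output correlations of the training data look like

$$\Sigma^{31} \;=\; \widehat{\mathbb{E}}\left[\,y\,x^{\top}\right] \;\approx\; \overline{W} \;+\; Z,$$

where W-bar is the teacher's composite low-rank map and Z is an effective noise matrix that comes from averaging the label noise against the random inputs over the finite training set; its scale shrinks as the number of training examples grows.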
And this is a famous problem in random matrix theory. And the spectrum-- the singular value spectrum and structure of sigma 31 and this relationship with the teacher is well understood from existing random matrix theory. So, now, let me review that. OK, so here's the basic idea. Remember, sigma 31-- the correlation matrix of the training data-- is what drives learning. But the teacher is what determines generalization. So what is the relationship between the singular value structure of sigma 31 and the singular value structure of the teacher?
Well, the teacher is going to have some singular value decomposition. That's a specification of the problem. Its singular values are the s bars. The training data correlation matrix has some singular value decomposition given by the s hats and the u hats. So the training data has the hats on it, and the teacher has the bars on it. The question is, what is the relationship between the singular vectors of the training data and the singular vectors of the teacher?
It turns out it depends on the strength of the teacher, i.e., the singular values of the teacher, and in particular their relationship to the variance of the noise, which we've taken to be 1 here. So what I'm plotting here is the overlap between u hat alpha and u bar alpha-- the training data dotted with the teacher-- as a function of the strength of the teacher. What's interesting is if the singular value of the teacher is small, then what happens is the training data can't detect that mode of the teacher.
And so the overlap between the training data singular vector and the teacher singular vector will actually be 0. But then there's a sudden phase transition: if the singular value of the teacher is above a certain threshold, there will be a non-zero overlap between the singular vector of the training data and the singular vector of the teacher. And that overlap becomes larger and larger as the singular value of the teacher grows.
Yeah? No, the singular value of the teacher is the signal because, remember, the training data is W bar plus noise. So the singular value of the teacher is a signal. I want the teacher to have large singular values so I can see it buried within the additive noise. What really matters is the ratio of the singular value of the teacher to the variance of the noise. And I'm working in units where the variance of the noise is 1. So I don't have to worry about that.
When I analytically compute the generalization error, I can compute it for any choice of s bars you give me. What I'm showing you here is the dependence of the dot product of u hat alpha with u bar alpha as a function of s bar alpha. So this formula holds for every single singular mode independently. Now, if you have multiple singular modes, they'll have different singular values. And so then there will be an interesting learning dynamics that happens, which we're going to talk about next.
But, by the way, what about the singular values? It turns out if you're beyond this phase transition, then the singular value spectrum of your data matrix sigma 31 will have an outlier singular value corresponding to the signal. And the associated singular vector will have some nontrivial overlap. But if you're in this regime, then the training data singular values will be buried within this noise spectrum. So, for example, if the teacher was 0, then you just have the singular value spectrum of a noise matrix.
And that has a well-known Marchenko-Pastur distribution. And that's this bulk. So this is all you would see if you had a 0 teacher. If you have a teacher with a small singular value, that's, again, all you'd see. If you have a teacher with one large singular value, suddenly you see an outlier singular value in the training data. Yeah? So because these are unit vectors, the dot product is between negative 1 and 1. But you always resolve the sign ambiguity so it's positive.
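Here's a minimal numerical sketch of that low-rank-plus-noise picture-- illustrative only, with my own normalizations: below a critical teacher strength the top empirical singular value sits at the edge of the noise bulk, and above it an outlier pops out.

```python
# Illustrative sketch: outlier singular value of a rank-1 "teacher" buried in noise.
import numpy as np

rng = np.random.default_rng(0)
n = 500
noise = rng.standard_normal((n, n)) / np.sqrt(n)     # noise bulk edge is at roughly 2
u = rng.standard_normal(n); u /= np.linalg.norm(u)
v = rng.standard_normal(n); v /= np.linalg.norm(v)

for s_teacher in [0.5, 1.0, 2.0, 4.0]:
    M = s_teacher * np.outer(u, v) + noise            # teacher plus noise matrix
    svals = np.linalg.svd(M, compute_uv=False)
    print(f"teacher strength {s_teacher}: top singular value = {svals[0]:.2f}, "
          f"next (near bulk edge) = {svals[1]:.2f}")
```

Weak teachers are invisible in the spectrum; strong teachers produce a clear outlier, which is the phase transition in the overlap plot.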
OK, so this is how the teacher is buried in the data. Now, with these formulas we can analytically write down the training error as a function of time and the test error as a function of time for an arbitrary teacher. Yeah? OK, so here's an example of some learning curves. Here's the training error. Here's the test error. The solid curves are theory. The triangles are experiments. There's a nice match between theory and experiment. And this is for a situation where you have a rank 1 teacher.
And that rank 1 teacher has a large enough singular value that it pops out of the noise, and there's a non-trivial overlap between the training data singular vectors and the teacher singular vectors. So ignore, for example, these singular modes. So what does the training process look like? Can you get a gut, intuitive, visceral sense of what's happening during training? You can, through the singular mode detection wave picture. So now you have a singular mode detection wave that's sweeping in from small times to large times.
Every time it sweeps over a singular value of the training data, it detects it. And it incorporates it into the student. Now, this first singular mode, when it detects it, that singular mode, singular vector structure is correlated with the teacher. So you learn something both about the teacher and about your training data. So when the singular mode detection wave sweeps over this thing, the training error drops, and the test error drops.
Then nothing happens for a while. Nothing happens for a while because the singular mode detection wave is still traveling. Then it starts to enter the noise associated with the details of the training data. So now, as it penetrates the noise, it continuously starts picking up information about the training data. But it picks up nothing about the teacher because this is all noise. So then what happens is the training error drops continuously, but the test error rises. And that's the nature of the overfitting.
You can combat this overfitting in full batch gradient descent only by early stopping. So you can stop early here. And that corresponds to terminating the singular mode detection wave such that you've detected the large singular values, but not detected the small singular values. And that's the right thing to do because the large singular values contain information about teacher, and the small singular values only contain information about the noise. Yeah?
In the full learning dynamics, the singular vectors of the student start to align with the singular vectors of the training data. They never see the teacher. The teacher is just implicit in there, right? So you align with the singular vectors of the training data. And then the singular value of the student approaches the singular value of the training data. It's a cooperative effect, where they both go up together.
OK, so that was a lot to talk about. You can generalize this to five layers. It all matches. You can even put ReLUs in the simulations, and the global structure of these learning curves don't change. This all survives changing linear to ReLU. You get the same structure. All right, so now-- So that's old work, by the way. So now, what happens when you go from full batch gradient descent to stochastic gradient descent?
So now we have a saddle point here. And just remember that double-well potential I showed you. If the diffusion is strong, you can collapse to the maximum. How do you make the diffusion strong? You make the learning rate large, and you make the mini-batch size small. That's how the diffusion becomes strong. So then eta is your learning rate. And beta is your mini-batch size.
So now, if you work out the curvature conditions that I talked about and turn those conditions into the parameters of this problem-- this should be an s hat here, and s hat is the singular value of the training data-- we get a condition that as long as s hat is less than the learning rate over beta, then this maximum here is actually stable. For full-batch gradient descent, if you start close to here-- i.e., with small initial weights-- you will always go to the global minimum.
But at high learning rates and small mini-batch sizes, you will instead stochastically collapse to 0. Now, this is actually fantastic. This will help with generalization, because of the following. Remember what happens with full-batch gradient descent? At any given time t, the large singular values of the training data are learned, and the small singular values are not learned. So if you want to control where you terminate the learning, the only thing you can play with is the optimal early stopping time.
But for stochastic gradient descent, it doesn't matter how long you train for. All the modes that are small-- where small is set by learning rate over batch size-- never get learned. So you can control how large or small the learned modes are by tuning the learning rate or the mini-batch size in a controllable way. OK, now, this can be a good thing for generalization because, remember, the small singular values are the noise singular values.
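So the rough rule of thumb coming out of the linear theory-- up to order-one constants I won't vouch for-- is that a training-data mode stays stochastically collapsed at zero whenever

$$\hat{s} \;\lesssim\; \frac{\eta}{\beta},$$

i.e., whenever its singular value s-hat is small compared to the learning rate eta over the mini-batch size beta.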
You don't want to learn the small ones. They're just noise. You want to learn the big ones. They're the signal. So it's actually beneficial to set the learning rate and the mini-batch size such that all the singular modes of the training data that are small are set to 0 in the student. It's like a low-rank regularization. Now, this has interesting implications and, I think, a more generalizable take-home message in terms of the loss landscape geometry.
The loss landscape geometry of this two-weight layer problem has been well understood for many years. It turns out that this is kind of the picture. There are many manifolds of critical points-- both saddle points and minima. In this two-weight layer setting, you have a manifold of minima-- these two hyperbolas. And you have one saddle. For the full n-dimensional problem, you have many, many saddle points-- manifolds of many different indices, where index is the number of negative curvature directions-- but you only have one manifold of minima.
Why is that? It turns out that every assignment in which some of the student's singular values equal some subset of the data singular values, and the rest of the student's singular values are 0, is a critical point. So imagine you have a full-rank student, and let's say k of its singular values equal k of the data singular values, but the rest are 0. It turns out that's a saddle point. And it's a high training error saddle point, because you can always do better by turning on the singular values of the student that are still zero.
So almost every single extremum point is a saddle point in this problem. It turns out there are no local minima that are not also global minima. What is the global minimum? The assignment in which the top singular values of the student are equal to the top singular values of the data-- so, i.e., all the student singular values are turned on. That's the global minimum. So what this is saying is that the best generalizer for this setting of low-rank teacher plus noise is actually a saddle.
You don't want to get to the global minimum. You want to get stuck at a saddle. And stochastic gradient descent will help you get stuck exactly there. So this leads to a very interesting, non-intuitive statement: higher training error saddle points can achieve lower test error than lower training error minima. And stochastic gradient descent will get attracted to these high training error saddle points. And, therefore, you'll have good generalization there.
Now, we've looked at Hessians of various neural networks after you train them for a long time. And you always get a small number of negative eigenvalues of the Hessian, even after the training is done. I was always puzzled by that. I was like, did the training not work? Did I never find a local minimum? But, actually, this kind of explains why. If you're doing stochastic gradient descent, you're much more likely to be attracted to the saddle. Yeah?
Yeah, so I'm going to get to that. Can we now exploit this to our benefit? The key idea is that changing the learning rate changes which saddles are attractive or repulsive. And we can engineer that. And that's actually what we have been doing all along as a community-- and I'll explain that-- by training at high learning rates early in training.
So we quantitatively tested our theory in deep linear networks. And we got a really nice quantitative match between theory and experiment. The solid curves are theory. The dots are experiment. Everything matches up. But let me actually move to the result. So this made a prediction. What prediction does it make? It says that if I work at high eta, then directions with very negative curvature will be attractive. But if I drop eta, those same very negative curvature directions will no longer be attractive, while shallower ones will be.
So I can control which saddles I get attracted to by changing eta. So it turns out a very, very common practice in deep learning-- sorry, and we verified that in the linear networks. And this made a prediction. If you train at high learning rates initially, you'll stochastically collapse to the saddle with many permutation-invariant neurons. And then if you drop the learning rate, it'll be hard to escape from that saddle. So you'll be stuck with those permutation-invariant neurons. You control the capacity of the network, and you get better generalization.
And that's actually what happens in practice. So here's an example. So now these are real experiments. You don't have to trust me as a theorist anymore. Now you have to trust me as an experimentalist. Or more precisely, you have to trust my students as experimentalists. And just so I tell you, I trust them. They're really good. Of course, I debug their experiments and stuff. But yeah, they all make sense.
So, again, we're back to training a VGG-16 on CIFAR-10, or a ResNet-18 on CIFAR-100. OK, so this blue curve is the test accuracy. This gray curve is the training loss. We're training at a high learning rate, and we're training at that high learning rate for a very long period of time, so the training loss has plateaued. By the way, if nothing I've said so far makes sense, you can reset here and hopefully still understand this slide.
So the training loss has plateaued. Now, oftentimes we train at a large learning rate for a long period of time, even after the training loss is plateaued. And then we drop the learning rate. Here are three possible times at which we could drop the learning rate-- this orange time, this green time, and this red time. What happens to the learning curves after the learning rate drop? This learning curve is all before any learning rate drops.
These curves are what happens around the time of the learning rate drops. So now this is the origin of time. And this is what happens after the learning rate drop for the test accuracy. And what you can see is that if you drop it early-- i.e., this orange time-- you're stuck at a certain test accuracy before the learning rate drop, but you get higher test accuracy after the learning rate drop. But if you drop it even later at the green time, you get even higher test accuracy. And if you drop it even later, you get even higher test accuracy.
This is part of the reason why people do this early training at large learning rates. Now, what could be a potential mechanism for it? Well, our theory from the linear network suggests that if you train at a high learning rate, the diffusion is strong. You'll have a lot of stochastic collapse. So maybe what's happening is you get more and more stochastic collapse as you train for longer. And that's exactly what we see. So on these same exact runs, we again computed the fraction of independent neurons.
If the fraction is small, you have a lot of stochastic collapse where you have many, many neurons that have the same incoming and outgoing weights. And you see that at the early learning rate drop, there's not much stochastic collapse. At the green learning rate drop, there's more. And at the red learning rate drop that's latest, you have even more. So you see-- you see a correlation, at least, between more stochastic collapse and more improvement in test error.
A similar thing happens for the ResNet. But, again, the ResNet is more complicated because you get these oscillations in which layer gives you more collapse. But you do see that at longer learning rates-- sorry, for a longer period of early large learning rate training, you get both better generalization after the learning rate drop and more stochastic collapse. Yeah? I think of this more as a scientific explanation of why the annealing works.
But now you can connect two very different quantities to a third quantity, which is curvature of the Hessian. So I suspect that you can relate the best learning rate and batch size ratio to the curvature of the Hessian around permutation-invariant sets. We haven't done that. But that's something you could relate. So the curvature of the loss landscape will depend on both the architecture and the data set.
If you have some predictions about the curvature of the loss landscape as a function of architecture and data set for some future data set or future architecture, then you could use this. But in practice, you'd have to measure it first. More generally, I just think of it as a mystery: you train longer, you generalize better, even though the training error has plateaued. What the heck is going on? By the way, we saw this same phenomenon in the linear network, and we never would have made this measurement had we not seen it in the theory. And then you see that there's this dramatic difference in the degree of stochastic collapse. So that might be a mechanism.
This degree of stochastic collapse is all measured before we drop the gradient-- before we drop the learning rate. So these are properties of the network at, say, this time right before this, this time right before this, and this time right before this. Again, when the learning rate is high, that will help with stochastic collapse. Why you need the warm-up I'm not sure. We do know a lot about the very early phase of training.
It's very chaotic. So we have other work showing that where you end up in the loss landscape is very, very sensitive to your mini-batch selection for the first couple of epochs. So there's something very chaotic going on in the first couple of epochs. My mental picture for that is-- oh, by the way, this picture of loss landscapes is from a review article I wrote with my colleagues at Google, in the Annual Review of Condensed Matter Physics, called "Statistical Mechanics of Deep Learning." I think that's a good way to start to get an overview of physicists' approach to understanding deep learning.
And what happens is high in the loss landscape you're riddled with saddles. So if you're riddled with saddles, what happens is you can get to the saddle and you have to decide, do I go to the left or the right? So, roughly, the Lyapunov exponent or degree of divergence of the training trajectories should be really large, high in the loss landscape where you're dominated by saddles of high index. That's my mental picture for what's going on there.
But when we're talking about stochastic collapse, the ones that-- the saddle points that help generalization are usually low-index saddle points that have a small number of negative curvature directions. So you don't want to collapse to the saddles up here. Actually, I just made up a theory for why that would help. If you start with a random initialization, you're likely to start off high here. But you don't want to get stochastically attracted to these high-index saddle points. You keep the learning rate low.
Then, as you go lower and lower in the loss landscape, you would more and more like to get attracted to these saddle points. You increase the learning rate. It turns out there's a strong correlation between the index of the saddle point and its error. That was another paper that I wrote with Yoshua Bengio many years ago, identifying and attacking the saddle point problem in high-dimensional non-convex optimization. Back then, people were wondering, why don't you get stuck in local minima? And our answer was local minima at high error don't exist in high-dimensional spaces, typically.
Early in training, the learning rate is high. So that promotes a stochastic collapse. You do shift it down. So then the question is, the initial saddle point that you get stuck at, is that the best one? Or is it better to get stuck at a saddle point slightly lower? So that's why the learning rate drop might help. But yeah, anyways. I mean, the only thing you can really take home with you is that the stochastic collapse phenomenon is correlated with the generalization improvement after learning rate drop. That's the only experimental statement I can tell you. And then we understand why for deep linear networks. And we just have a story for why.
So let me just summarize. So here are the take-home messages. Because position-- so that's the key take-home message. Position-dependent diffusion is an exotic phenomenon in physics. But it turns out to be a ubiquitous phenomenon in deep learning because the covariance of gradients with respect to examples depends on where you are in weight space. And, in particular, these position-dependent diffusion constants vanish on invariant sets.
So these are very, very cold sets that can trap SGD dynamics but not gradient descent dynamics. And these invariant sets correspond, for example, to permutation-invariant subnetworks or subnetworks with dead neurons. SGD will be attracted to these invariant sets. We found local conditions for this attractivity. The basic idea is the diffusion must grow faster, perpendicular to these invariant sets, than the drift away from them.
These invariant sets correspond to saddle points, not minima. In two-weight layer networks, there are many saddles and only one minimum. All of them generalize differently. And typically, the saddles-- the higher training error saddles-- will generalize better than the lower training error minima. Yeah, you can tune the learning rate and batch size to select which saddles you find attractive.
And in the wild-- in real networks, ResNets, VGGs, whatever-- stochastic collapse is prevalent. And prolonged early learning with high learning rates enhances both stochastic collapse and subsequent generalization. So I was even too ambitious. I had promised you a vote, but I think we're done, right?
[LAUGHTER]
I'm actually just curious to get a pulse on the audience. If we were to go forward with the second part of the talk, there are two options. One is neuro-AI: using explainable AI to really understand conceptually how complex models of the retina work and extract testable predictions. The other is how we beat neural power-law scaling, using stat mech to bend the power laws down-- sometimes even making them exponentials.
I'm just curious. If we were to continue, how many would go for option 1? Interesting. OK, and then how many would go for option 2? It was roughly split but slightly more on the neuro side. So people still care about neuro.
[LAUGHTER]
That's good. Good. OK, all right, but anyways, it's a moot point. But I'm actually glad we spent a lot of time in the first part because, hopefully, I went slowly enough that something-- that I could convey some of the ideas. All right, thanks. I'll stop here.
[APPLAUSE]