Towards Understanding the Implicit Regularization Effect of SGD
Date Posted:
September 6, 2024
Date Recorded:
August 11, 2024
Speaker(s):
Pierfrancesco Beneventano, Princeton University
All Captioned Videos Brains, Minds and Machines Summer Course 2024
AUDIENCE: This morning, [INAUDIBLE] in the department of applied mathematics
PIERFRANCESCO BENEVENTANO: Operations research nomenclature, you know?
AUDIENCE: And he'll speak about an old puzzle in deep learning, which is the implicit biases of SGD [INAUDIBLE].
PIERFRANCESCO BENEVENTANO: Thank you. I think I have it. OK, so thank you so much, Tomi, for inviting me here. On one hand, I feel very connected with this group of people because you also show up late to the first events in the morning. On the other hand, I think it's going to be tough because, when I give this talk in Princeton or in a mathematics department, people still remember that, four years ago, they all worked on this kind of problem.
But you people at MIT and co, I feel like you are-- I don't know. Now, you realize you don't care a bit. Like, you don't care so much about this anymore, because, as you were hearing yesterday, SGD is now essentially online. And on my first slide of notation, there's going to be-- we have a finite data set, you know? But I'll try to convince you that this stuff is still somewhat relevant. And worst case scenario, it's going to be relevant when you have a small data set.
So you can call me Pier. Feel free to interrupt at any time. I hope I'm going to-- I'm going to finish five minutes early. But we'll see. I'm going to speak about a paper I did alone and a paper I did with Tomi and Andrea, who is a collaborator of ours. But actually, when I planned this talk, I thought it was a lecture, not a talk. So it's going to be half lecture, half talk. And so, it's going to start with checking in on my audience.
So I assume you all know what we're speaking about. That is training neural networks. But I want to check in with you-- or at least, that was the plan before coming yesterday and seeing that everyone was assuming you all know everything. And you probably all know everything. I want to know, for example, who does not know what gradient descent is? Cool, cool.
Don't be shy, don't be shy. Yesterday, I was very shy when Aaron was speaking. And I was like, who does not think-- I was actually on the yes side. So I thought that you can get there with scale. But then, everyone was saying, no, no, no. And so, I didn't feel like saying it. But, so, good. You're brave, braver than me, for sure. Who did not take a course that dealt with optimization? OK, cool, cool.
And who doesn't know that-- you have many optimization methods. But when things are not convex and the landscape you're going down is not nice, then you may have problems with convergence. Cool, cool, OK. And who does not know how to prove that ReLU networks converge to a global minimum? I hoped someone in the audience knew, but, OK, too bad.
OK, cool, so, I had some hopes for the audience. And I think you met my hope. So I hope the talk, or the lecture, it's good. So let's start very briefly with gradient descent. What is gradient descent? So essentially, we have a landscape. And we want to find the minimum. We have a function. We want to find the minimum of the function. And throughout the talk, you can think of-- I don't know. You're on the mountains. You see a river down there. And it's, like, the Wilson algorithm to get down to the river.
And gradient descent is the most basic way to get down to the river. So you have a function from Rd to R. In the landscape metaphor, d is 2, and R is the elevation. And you want to get to the lowest elevation. And you want to get there fast enough. So we want to find one of the minima. Maybe there's a river, or a lake. And the points on the lake are all at the same elevation.
So it's like, you want to get to the lake. Sometimes, you don't really care where you get to the lake. But sometimes, you care. And this is going to be the case later. And so, how can you do that? You're somewhere. And you say, oh, in what direction am I going down? And you do a step in that direction. This is gradient descent.
So it's like, you take the gradient. The gradient points towards going up. And so what you do is, I do a step backward, OK? And the size of the step is going to be called the step size, or learning rate. And it's going to be my eta. And so, again, my position at time i plus 1 is going to be my position at i, minus a step in the direction in which the function grows the most-- so, a step in the direction in which I go down the most, let's say.
And so, as an example, let's take a function that takes a vector. I'll call my vector theta, OK? So there is theta 1, theta 2, theta 3, up to theta d, in dimension d. And let's take the function that is the first component squared. So you know what the minimum-- or a minimum-- here is: I don't care about any component from the second to the d-th. But I care that the first component is 0. These are the minima of this function.
And so, let's do gradient descent. And let's see how gradient descent works here. And so, we want to find the minimum. We start from theta 0, somewhere, wherever I am during my hike. And I say, OK, I want to go down. So we said, gradient descent is-- you take the place where you were. You compute the gradient. And you do one step back. And that's going to be the place where you're going to be.
So it's like, OK, what's the gradient of this function in theta 1? Well, in theta 1, we have a quadratic. So we can compute this gradient. What's the gradient of this function in theta 2? Well, this function does not depend on theta 2. And we'll say later, during the talk, that theta 2 is essentially not in the support of this function. And so, actually, the gradient in theta 2 of this function is 0. And so, essentially, gradient descent on this function is keeping you where you are on the second, the third, the d-th component, and going down on the first component.
Does it make sense? So essentially, when you take the derivative there and you do this step, you realize that every component j is getting multiplied by 1 minus 2 eta delta. And this delta is some mathematician kind of thing. It means, if j is 1, this is 1. If j is different from 1, this is 0. So essentially, on the first component, you're shrinking by 1 minus 2 eta. And all the other ones, you're multiplying by 1. It's the same.
And the minimum we find, as we were saying, is exactly where we started on some components, and then 0 on the first one. Is everyone clear with this thing? Do you have a grasp of what gradient descent is now? Cool. And so, in a sense, we characterized the speed of convergence, because I told you, the first component is shrunk by 1 minus 2 eta at every step. And so, now, we have a speed. If you want to get to epsilon in the first component, you have the iterative formula to get how many steps you need. And we also found the location of convergence. And this location of convergence is essentially the closest point to your initialization that was lying on the beach of the lake.
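A minimal sketch of this example in Python (the step size and starting point are illustrative choices, not anything from the talk):

```python
import numpy as np

# f(theta) = theta_1^2: the minima are all points whose first component is 0.
def grad(theta):
    g = np.zeros_like(theta)
    g[0] = 2.0 * theta[0]            # only the first component is in the support of f
    return g

eta = 0.1                            # step size; needs eta < 1 so that |1 - 2*eta| < 1
theta = np.array([3.0, -1.5, 2.0])   # arbitrary initialization theta_0

for _ in range(50):
    theta = theta - eta * grad(theta)    # theta_j <- (1 - 2*eta*delta_{j,1}) * theta_j

print(theta)   # ~ [0.0, -1.5, 2.0]: the closest point of the minimum set to theta_0
```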
So it's like, if you look at this, this is essentially your landscape for this function. And so, on the first dimension, we're going down. And on all the other ones, our landscape is flat, OK? Because we're not gaining anything by moving in that direction. And essentially, what gradient descent is doing here is going down and touching the river-- now it's a river because it's two-dimensional-- in the place that is the closest to where we started, OK? Cool.
This is not the case for neural networks. So first of all, if we converge, we don't know at what speed we converge. But we also don't know to which minima we go-- to which minimum we go. We don't know where we touch the river. And we'll see later that, actually, if I massage my gradient descent a bit and take different versions of gradient descent, those different versions go to very different places along the river. And this is going to be essentially the pitch of my talk: trying to understand how the different versions of gradient descent pick the place on the beach where they go into the lake, OK?
So, now, the talk starts. And I'll give you some motivation, some motivation to study this problem. You've heard from many people now that want to understand, why neural networks generalize, or how neural networks generalize, or when neural networks generalize. And generalize means that it works-- the neural network, when you put it in production, works, essentially.
And people realized that-- for example, this is a matrix completion problem. It's not really relevant what problem this is. And this is a graph that I copy-pasted from a talk of Nathan Srebro. But I asked permission. And here, they show that, on the same problem, when you use different algorithms-- so you use either GD, or you use a regularized GD, or you rewrite the problem in some different way-- then you consistently go to very different minima with the different algorithms.
And from an optimization perspective-- so, getting down to the river-- you may think that all the points on the river are the same kind of minimum. But in practice, this is not the case. Some minima generalize well. And some minima generalize badly.
So the question about generalization, the way I see it-- and I'm sure people disagree with me in this audience-- at least when you have a finite data set and a model in which you can check this kind of thing, is about the location of convergence. So I'm an applied math person. I do optimization. The only thing I look at is the optimization problem behind the AI. And, essentially, "can we train" means: do I converge? How fast? But "does the minimum that I find generalize" is a question about the location of convergence for me, OK?
And what do the loss landscapes of neural networks look like? How is the river made? So there are many different minima. And this has been known for a long time. And as I was saying, different algorithms find different minima consistently. And so, the way I see this problem, and the way the community was seeing this problem in 2020-- maybe not anymore-- it's about how the algorithms travel those landscapes, OK?
So I'll show you an example real quick. And it's like, OK, let me take this amazing data set that is five data points in a W shape in one dimension. Let me initialize and take a certain learning rate and do gradient descent there. Well, very often, gradient descent really quickly finds a minimum. The problem is that this is a minimum. And this is actually a minimum that gradient descent finds. This minimum here is perfectly balanced.
The distance from here to here and from here to here is exactly one half of the distance from here to here. This is a perfect minimum from the optimization landscape perspective. But actually, from a machine learning perspective, this is a badly-generalizing minimum, because when I start sampling real-life data points, I'll sample data points here, data points here, data points here, and actually, the algorithm, my AI, doesn't properly predict those data points. But right there, I'm at this local minimum that GD found.

OK, let me start running stochastic gradient descent, whatever that is. Then what happens is that, actually, stochastic gradient descent starts moving around, moving around, moving around. And at some point, you see this function that oscillates around this first function. So you have some perturbation of this first function. At some point, you escape. And you escape real quick to this one, which is kind of an amazing global minimum.
Some people may tell me, oh, but actually, the data points on this line are not well-fitted. Yes, but your training data is well-fitted. And for sure, this works much better than this one here. So the takeaways-- yes, sorry, I didn't see questions. Please speak. I know it's usually not considered polite. But I don't care.
AUDIENCE: But if you just have the information of five data points, you don't know the underlying data. You don't know it's a W that you're trying to fit.
PIERFRANCESCO BENEVENTANO: Absolutely.
AUDIENCE: So then, there's little difference between them.
PIERFRANCESCO BENEVENTANO: Absolutely. But still, this one's wrong. And this one's correct. From your point of view, this one's wrong. And this one's correct. Does it make sense? Yeah, of course, I could have any data point in here, and also any data point in here. But this one fits the training set and this one does not. So before putting assumptions on this data set-- essentially, when I was telling you it's a W, I was assuming that this is the easiest possible data-generating process that generates this W, OK?
And this is what she's calling me out on. Yes, I actually don't know what's going on here in the real distribution. It may be that there is a sine, a sine wave, that goes like this. But still, here, my training data is fitted, and here not. So at least you can trust me on that, yes? Cool. So the first takeaway is: in general, we have many minima. And which one you find actually matters, OK? Yes.
AUDIENCE: Maybe it's related to the previous slide. But do you differentiate between global minima and local minima? Or do you care about not finding a local minimum, but finding a global minimum on the training set, or finding the best global minimum that will generalize better? Because this was only local versus global minima.
PIERFRANCESCO BENEVENTANO: Right. Later, I'll care about, within the global minima, finding one that is better represented in the parameter space. And now, this was a lecture-like example, yeah, sorry. He's the person I'm fearing the most, oops. OK, cool. So there are many minima. Which one you find matters. And also, different algorithms find different minima. I showed you with a very poor example.
And in this talk, I'll be interested in the implicit regularization of the optimizer. This has meant, in the literature, many different things. But what I mean by that is, essentially: can we try to characterize which minima algorithms find? Well, actually, we won't characterize that. But I'll tell you, when you're in these regions of minima, where SGD moves. This is going to be essentially what I'm going to tell you in the first 2/3 of this talk.
And so, how is the landscape? OK, wait, wait, wait. Essentially, we're not interested now in how cheap and how quick the convergence is, but-- in a highly non-convex landscape, where you don't have just a valley going down, but stuff is happening-- in where we're going. And again, I won't answer this question. But I'll show you a feature of mini-batch SGD, in the first part, that may actually lead us toward answering it. Cool.
So the minima in the landscape of neural networks actually lie on a high-dimensional manifold. What does this mathematical jargon mean? It means that there are some rivers and lakes in which all the points are minima. It's not that there's a minimum here and a minimum there, so that it depends on which valley I get to, OK? So, yes, from a certain point of view it does. But within every valley, there are, I don't know--
For example, in our previous case, there was a (d minus 1)-dimensional manifold of minima, OK? In that case, it was a linear space, since each of the last d minus 1 components could take any value, and that was still a minimum, OK? And so, essentially, the question of where the algorithm goes can be refined-- if we care about the setting, or at least about the second part of training, where we already picked our valley-- to: how do the algorithms pick a minimum once they're close to the coast?
So it's like, you're close to the coast. There are many places where you can put your umbrella, OK? And now, you're trying to pick the best spot on the beach. You're looking around. You're like, where on the beach should I go? This is essentially what the question somehow boils down to, once I massage it this way. Cool. So actually, this question, when you try to mathematize it, becomes: imagine I'm on a stationary manifold. So everything is pretty flat around me.
Yeah, I mean, I still can see the direction in which I go down. But all my landscape around is kind of flat, the same, like here. How do the different algorithms approach this problem of finding a minimum? And empirically, we realize, even in a toy model, they do it differently. And now, we'll get through it.
But let's speak about the algorithm first. So people work with mini-batch SGD. What does this mean? So I have a set of parametric functions, OK? And those are my loss on every data point. And the goal is finding the minimum of my average over the data points of the loss. So here, theta are the parameters on which I'm optimizing. The z's are indices, but are actually my data points.
And my functions are kind of smooth everywhere, for what this means. And essentially, what happens in real-world neural networks is that my model-- so the number of theta parameters that I have-- is too big. So, usually, I want gradient-based methods, for whoever understands this thing. I cannot do a second-order method because computing the Hessian is too much. And also, the loss is an average.
I have many data points. So I'm like, you know what? Instead of computing the gradient of a sum of functions-- you would have to compute the gradients of all those functions, and then sum them up-- actually, I'm taking just 10 data points at a time. I compute the gradient there. I do a step, then another 10 data points, then another 10 data points. So this is called mini-batching, OK? So we're not doing gradient descent. We're doing the most computationally feasible, if you want, version of gradient descent.
And so, to set notation again: the parameters at the i plus 1 step are the parameters at the i-th step, minus my step size times the average-- 1 over b, when I take b data points in my batch i plus 1-- of the gradients on the data points in that batch. Questions? Cool. And now, you may ask me, how do I sample those batches? Did anyone think about how to pick the subset of the data set to look at at each step? Well, there are essentially two common ways to do it. There are more.
But the important ones are two. And one is with replacement. This means that I'm sampling with replacement: as I take the ball out, I put it back in. So I sample a data point. I sample a batch. And then I put all those data points back in when I have to sample the next batch. That's very nice mathematically, because now, those batches are independent. Those steps are independent. I can use a lot of very nice mathematics that people developed for us. But then there's another way. I sampled this batch--
So I took out those b data points. And those, I don't put back in. I sample the next batch among the data points that are still in, so that I'm sure that I see all the data points by some time. And then, when I've seen them all, I put everything back in. And I restart with this thing. And those two algorithms may seem very similar. But the mathematical nature of how they behave is going to be different. And this is essentially what the first part of the talk is going to be about. Cool.
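A minimal sketch of the two sampling schemes and the update just described (the function names are illustrative, not from PyTorch or TensorFlow):

```python
import numpy as np

rng = np.random.default_rng(0)

def batches_with_replacement(n, b, steps):
    # Each batch is drawn independently; the whole batch goes "back in the urn",
    # so batches across steps are i.i.d.
    return [rng.choice(n, size=b, replace=False) for _ in range(steps)]

def batches_without_replacement(n, b):
    # One epoch: shuffle once and partition, so every data point is seen
    # exactly once per epoch, and batches are statistically dependent.
    perm = rng.permutation(n)
    return [perm[i:i + b] for i in range(0, n, b)]

def sgd_epoch(theta, grads, batches, eta):
    # grads[z](theta) is the gradient of the loss on data point z.
    for batch in batches:
        g = sum(grads[z](theta) for z in batch) / len(batch)
        theta = theta - eta * g   # theta_{i+1} = theta_i - (eta/b) * sum of batch gradients
    return theta

# Example with n = 10 data points and batches of size b = 2:
print(batches_with_replacement(10, 2, steps=5))   # i.i.d. batches; points recur across batches
print(batches_without_replacement(10, 2))         # a partition: each point appears exactly once
```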
So the thing is, SGD is without replacement in practice when we have a fixed data set. And the reasons are multiple. So first of all, it's because that's the default in TensorFlow and PyTorch. And most of you never touch those defaults, OK? But now, you should ask me, why is that the default? And it's the default for two different reasons. One, we had some work from 10 to 15 years ago telling us it generalizes better.
And two, we have some other work that tells us it's faster in every way. So, faster computationally, because, essentially, a way you can do this without replacement is by reshuffling your data set and partitioning it. You don't have to sample at every step. And sampling is-- I mean, I'm a numerical analysis person. And when we study numerical analysis in school, we learn that products between matrices and sampling are the two things we have to do all the time. So they're actually computationally expensive.
But also, it's faster from an optimization perspective. It converges faster. So essentially, we want to analyze this one. But now, the steps are not independent. So we cannot use all these magic tools that people produced for us. So, it's faster in every way: what was found out in 2009, in a paper of four pages, is that, actually, the slope of convergence is twice the slope of SGD with replacement. And then there has been a lot of work towards proving this for convex functions, which are the nice functions on which we know we can go down.
And in the end, I think the best result we have is by Mishchenko et al., 2020. But now, I care about the fact that, in practice, it generalizes better. And so, this is a question. And it's also taken its time. Essentially, to me, the fact that this one generalizes better is a question about the location of convergence, as I was saying, OK? So it means it may go to different minima. Or the mathematical nature of the effect that makes it pick a minimum within the river may be different from the other one. Is that clear? Cool.
OK, and so, actually, in this experiment here, let's try to look at some graphs to see what's going on, OK? So we start from this local minimum. We run SGD either with or without replacement. And, consistently, different stuff happened-- consistently, the same thing happened, but in a different way. So essentially, what happens is that those algorithms start oscillating, oscillating around. This is my loss curve.
So you see, at least the lower bound of it is staying flat. And they oscillate a lot. And then, at some point, they escape down. And that is the global minimum they found. But there are, consistently across the hyper-parameters, different behaviors. SGD with replacement oscillates a lot and escapes later. While SGD without replacement oscillates less and escapes faster. And so, essentially, I want to understand this phenomenon. What's going on here? And so, yes.
AUDIENCE: This is the loss on the training set?
PIERFRANCESCO BENEVENTANO: Sorry?
AUDIENCE: This is the loss on the training set?
PIERFRANCESCO BENEVENTANO: The training loss, yes. And so, let's take the easiest possible neural network, which is the product of two numbers, OK? So it's like, I have one number. I multiply it by one other number. And this is going to be my minimal classifier, which now is still linear in my data, but it's not linear anymore in my parameters. And why is this nice? Well, for no reason except that it has a two-dimensional landscape, so I can plot it on my slide.
But let's see what's going on here. And what's going on here is that they actually travel those manifolds of minima differently. So here, essentially, the manifold of minima is a hyperbola, because theta 1 times theta 2 has to be 1. And so, this can be 10 and 1/10, or 1 and 1. And essentially, when I do GD plus Gaussian noise, or SGD with replacement, or SGD without replacement, they all somehow touch the manifold of minima in the same way when I initialize them close enough. But then, some of them, even without replacement, start oscillating around very often.
And so, now, I'm sure that Aaron is going to ask me, wait, there's no mini-batch in this problem. Well, actually, I sampled some x's and some y's in a way that, when you write down the problem, it actually centers around there. So now, I have some data points that align along the same linear function, which is y equals x. And now, I'm doing stochastic gradient descent on this problem.
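A minimal sketch of this toy problem, assuming small label noise on y = x so that the per-sample gradients don't vanish on the manifold (all hyper-parameters here are illustrative, not the talk's):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy model f(x) = theta1 * theta2 * x, trained on y ~ x with squared loss.
# Every point with theta1 * theta2 = 1 minimizes the expected loss (a hyperbola).
theta = np.array([2.0, 0.5])   # start on the minimum manifold, but unbalanced
eta, sigma = 0.05, 0.5         # step size and label-noise level (illustrative)

for step in range(1, 200_001):
    x = rng.normal()
    y = x + sigma * rng.normal()     # noisy label: gradient noise survives on the manifold
    r = theta[0] * theta[1] * x - y  # residual
    theta -= eta * 2 * r * x * np.array([theta[1], theta[0]])   # one SGD step
    if step % 50_000 == 0:
        print(step, theta, theta[0] * theta[1])

# The iterates keep theta1*theta2 oscillating around 1 while theta1^2 + theta2^2
# slowly decreases: the oscillation drives a drift along the hyperbola toward
# the balanced, lowest-norm solution theta1 = theta2 = 1.
```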
And actually, what we see is that SGD with replacement oscillates around the manifold of minima and moves towards this amazing pink spot, which is our flattest solution, or lowest-norm solution, which is theta 1 equals theta 2 equals 1. And the same happens without replacement. But without replacement moves faster and has fewer oscillations. And this is what we saw already in that extremely [INAUDIBLE] example. Did I convince you guys? OK. Go first.
AUDIENCE: It looks like the SGD without replacement, even though it has fewer oscillations, it also looks like, in the last part, it travels a longer distance.
PIERFRANCESCO BENEVENTANO: Sorry?
AUDIENCE: It looks like it covers a longer distance over time in the end.
PIERFRANCESCO BENEVENTANO: Yeah, it shows-- you mean, it's getting closer and closer to the manifold? Yes, yes.
AUDIENCE: What is the red line?
PIERFRANCESCO BENEVENTANO: The red is a different learning rate. So red is very small learning rate. Yellow is very big learning rate. Yes.
AUDIENCE: [INAUDIBLE] must find the global minimum.
PIERFRANCESCO BENEVENTANO: All the points on this river are global minima, because theta 1 times theta 2, the product of them doesn't change.
AUDIENCE: So why did this decide that the pink one is the best?
PIERFRANCESCO BENEVENTANO: The pink one is, as I was saying, the best, a.k.a. the lowest-norm solution, or the flattest solution. So when you compute the highest eigenvalue of the Hessian along the whole manifold, in the pink one it is the lowest. And the pink place is the closest to the origin. The origin lies exactly here. So I mean, it's not the best solution from a generalization perspective in this problem. But some mathematical features of it are good. Yes.
AUDIENCE: What-- does it have anything to do with the initial condition?
PIERFRANCESCO BENEVENTANO: No, actually. You see, I'm starting at 6, 1. And wherever I start, if the learning rate is small enough to actually converge-- because, otherwise, it shoots me up-- then, as it reaches the manifold of minima, it starts going around. Cool. I think I'm going kind of slow. But it's good. So what's going on here? Actually, when you try to plot-- and here is going to be the magical thing.
So essentially, when we find the minimum of a function, there is an average of stuff. We want to find the zero of the gradient. And so, we want to find the place where the average gradient is 0. Let me now plot the trace of the covariance of the gradients. So, essentially, it's some form of second moment of that quantity that I'm zeroing out. Well, actually-- and this is, of course, the case for the linear network I showed you. But this plot is in the ReLU case.
Actually, here, this is the training-- the loss dynamics on the W data set for SGD without replacement. Actually, here, this trace of the covariance of the gradients is doing weird things in the phases in which I'm converging, because there, I care about getting to 0. But then, in the phases in which I'm kind of around 0, what's going on is that this trace of the covariance is going down, OK?
OK, and now we get back to the theorem-- not get back, we get to a theorem. It is, by far, not strong enough to explain this phenomenon. But at least, I think, it gives a flavor of what the phenomenon is. There is a point-wise bias. So the question is: imagine I'm somewhere during my hike. And I get hit by a mosquito. I want to go down. And I'm like-- OK, the mosquito was a very bad metaphor for doing mini-batch SGD. But I decide which mini-batch, which version of GD, to do, OK?
So, where can we expect to be after k steps? So I'm like, OK, I use algorithm number two. What is going to be the expectation, over the way I sample the batches, of my theta k given theta 0-- so, where I am at the k-th step, given where I am at the 0 step? And actually, I lied, also, for a second reason in the metaphor, because I'm doing this close to a manifold, a stationary area, OK? Because, actually, what I'm interested in, and trying to convince you about, is what I do in the stationary areas.
And what you will find is that SGD with replacement does steps at random that are at least centered around GD steps. But SGD without replacement actually does an additional step, in expectation, toward a certain direction. And so, how can I analyze this mathematically? I added these slides for the people that were craving mathematics. But there's not going to be much mathematics. Anyways, let me Taylor-expand my k-th step at initialization, OK?
And then, I have that I'm doing k steps of the gradient on batch 1, batch 2, batch 3, batch i, batch k. But actually, I'm also expanding those gradient steps at initialization. So I have a first-order part of those gradient steps, as if I were at theta 0. And then I have some additional terms that tell me how that gradient step actually changed. And the first of those additional terms is, essentially, the Hessian on the batch afterward-- that is, how much the landscape changed along the steps of the batches before-- times the gradient steps of the batches before.
And just staring at these horrible equations, we can see that, when you do SGD with replacement, all those gradients are centered around the gradient descent step. But they're noisy, OK? Whereas, when I'm doing SGD without replacement-- now allow me to say k is one epoch, so it's the number of batches in the data set-- this thing sums up to exactly k gradient steps, because I see every data point exactly once. And this is exactly k steps of gradient descent. And so, what I have is that this part is centered, with noise, for SGD with replacement. But this part is deterministic for SGD without replacement.
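A reconstruction of the expansion being sketched, in the notation above ($B_1, \dots, B_k$ are the batches of one epoch and $f_B$ is the average loss over batch $B$; this is the structure of the argument, not the paper's exact statement):

$$\theta_k \;=\; \theta_0 \;-\; \eta \sum_{i=1}^{k} \nabla f_{B_i}(\theta_0) \;+\; \eta^2 \sum_{1 \le i < j \le k} \nabla^2 f_{B_j}(\theta_0)\,\nabla f_{B_i}(\theta_0) \;+\; O(\eta^3),$$

and, for one epoch without replacement, $\sum_{i=1}^{k} \nabla f_{B_i}(\theta_0) = k\,\nabla f(\theta_0)$ exactly, since every data point appears exactly once.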
And now, let's look at the second part. For SGD with replacement, this is also noisy and centered around that same part of the Taylor expansion of gradient descent. But it's different for SGD without replacement. So, here, what's the point of SGD without replacement? I know that if a data point is in batch 1, it's not going to be in batch 2. And this means that those batches are statistically dependent.
And so, now, I have the expectation of a product of things. And this is not anymore the product of the expectations. Rather, there is the part that is the product of the expectations, and the part that is that thing minus the product of the expectations. And that part is hopefully small, because they're dependent, but not that dependent. But actually, here, I get something different in expectation.
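Schematically, the decomposition being used here (a sketch):

$$\mathbb{E}\big[\nabla^2 f_{B_j}\,\nabla f_{B_i}\big] \;=\; \mathbb{E}\big[\nabla^2 f_{B_j}\big]\,\mathbb{E}\big[\nabla f_{B_i}\big] \;+\; \operatorname{Cov}\!\big(\nabla^2 f_{B_j},\, \nabla f_{B_i}\big),$$

where the covariance term vanishes for sampling with replacement (independent batches) but not for sampling without replacement, and that is where the extra expected drift comes from.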
So essentially, what I'm saying here is that, in SGD with replacement, I have noisy, centered steps here. In SGD without replacement, this part is deterministic when I'm close to a manifold of minima. And here, not only do I have the oscillation coming from this part, but I also have a bias. So, in expectation, I'm moving due to that part. Is this clear? Yes.
AUDIENCE: Could you say a little bit more on the noisy, centered SGD with replacement?
PIERFRANCESCO BENEVENTANO: Yeah, it's centered in the sense that all those random variables are independent of each other. And their expectation-- I'm sub-sampling from the full-batch gradient. If I'm sub-sampling uniformly, their expectation is exactly the full-batch gradient. So they're centered around the full-batch gradient. They're independent. And there's some noise given by my sub-sampling. Cool?
OK, and so, what we found-- let me skip the analysis of this bias. But essentially, when I analyze this bias, I have a Hessian times a gradient. And this is, let me tell you, the derivative of the norm of the gradient squared. And so, essentially, I'm getting that there is the GD part, and then another part that is essentially the trace of the covariance of the gradients.
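The identity behind this step: for a smooth loss $f$,

$$\nabla^2 f(\theta)\,\nabla f(\theta) \;=\; \tfrac{1}{2}\,\nabla\,\big\|\nabla f(\theta)\big\|^2,$$

so the extra expected term is itself a gradient-- of (a weighted version of) the squared gradient norm. Near a manifold of minima, where the mean gradient vanishes, this second-moment quantity is what the talk identifies with the trace of the covariance of the per-sample gradients.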
And so, essentially, the behavior I'm saying is, in expectation, SGD is different, because I'm doing the GD steps. And then I'm adding up a small step in another direction. But I'm doing it in expectation. And essentially, what I want to convey as a message is that SGD with replacement very often goes in the same place where SGD without replacement goes, bad news, but at least bad news for me because my paper is less relevant. But SGD with replacement goes there by-- you oscillate a lot. And you end up in the place that is the widest because when you get there, then you oscillate less, OK? Mathematicians would kill me for this kind of reasoning. But this is essentially what it boils down to.
But SGD without replacement is oscillating less, because its oscillation comes from the second order. And there, you have a learning rate squared. And the learning rate may be small. So it's oscillating less. It has that effect, but less. But also, it has a bias, OK? So I'm moving where GD goes, and then, at second order, in another direction. And actually, that direction, magically enough, aligns with the flat directions. So you compute the math. You try to compute that math for two years.
And then, at some point, I realized-- I told you, a trace of the covariance. Actually, that's a weighted trace of the covariance. And those weights are along the eigenvectors of the small eigenvalues. And so, essentially, what's going on is that the GD part is taking me down. But this other thing is moving me across. And so, in expectation, I'm doing a small step at every epoch. So maybe that's not relevant for online SGD-- for online SGD, I believe it's the with-replacement kind of thing that is relevant. But I'm moving in this direction.
And what does it mean, I'm moving in this direction? It means that I'm doing a step there until I'm at the minimum of my trace of the covariance within this manifold of minima, because if I move away, then the GD part is much bigger and drags me down again. So essentially, what I'm saying is this thing: on the sharp directions, I follow SGD with replacement, or GD, in expectation. On the flat directions, I have this regularizer.
And then, actually, a thing I'm not saying, for time, is that, if the directions are too sharp, this regularizer actually shoots you away. And I thought-- I was very excited. I thought it could be an explanation for why SGD doesn't train at the edge of stability. Sadly, I actually don't think that's the explanation, but yes. And again, as I was saying, it's a weighted trace of the covariance. So this S is essentially a diagonal matrix. So I'm taking the trace-- the sum of the eigenvalues-- and I'm putting some weights in front of the eigenvalues. Yes.
AUDIENCE: Do you have a sense of what is the magnitude of the trace of the covariance compared to [INAUDIBLE]?
PIERFRANCESCO BENEVENTANO: Yeah, let's do it offline. But essentially, the multiplier around this thing-- so, essentially, for problems that are linear in the parameters, this trace of the covariance is the same everywhere. So in linear regression, you don't see this effect. But when it starts being like-- I don't know. The way I see it is from a very optimization-- bad optimization perspective.
When the manifold tilts, then this starts happening and moves in the direction in which the manifold opens up. And to come back to the size of this thing: well, it depends on the size of the opening up, of course, but also-- essentially, here, I have eta squared. But also, that sum has on the order of k squared terms, because I had every Hessian with the gradients of every batch before it.
So essentially, I have a "divided by data set size", because this comes out of the expectation. And then you multiply by the number of steps squared. And one of them kills the "divided by data set size". And the other one, which would be too big anyways, impractical, gets killed by the learning rate squared. But still, you have an O(1) step, in general. Did I answer your question? And also, I'm very proud to say that this analysis is very general.
So, to try to get results like that-- like, where do I go on the river-- people usually used to put more assumptions on the model, on the injected noise. And here, I have no assumptions like that. And I'm not limited to toy models: any function that is smooth enough. And also, previous analyses were saying, I'm analyzing k steps of time, so I need eta k to be very small and eta k times the Hessian to be very small. And I kind of deleted this assumption from the thing.
So implication number one is that, actually, this is the reason why I'm traveling the manifold. And I'm doing it faster and with less oscillation. And I hope I convinced you about it. And also, an implication that I always like to mention: well, escaping saddles. So people know that there are a lot of saddles, of every degree, in the landscape of neural networks. And people worked a lot to prove that SGD with replacement escapes those saddles.
And there are some super technical results-- or at least super technical to me-- that say, if lambda is the smallest eigenvalue, the speed at which SGD with replacement escapes is lambda to the minus 3.5. And here, I'm saying, listen: the GD part doesn't move. But my bias, my small bias, actually keeps moving at most of those saddles.
And so, under assumptions similar to those-- namely, that there are escaping directions-- under which one can say that SGD with replacement escapes with that speed, I can say it escapes with this speed, because, simply, the regularizer step doesn't stop moving. And so, the picture is: SGD with replacement is like, I shoot around, I shoot around, and at some point, I'm far enough from the saddle that the GD part kicks me out. Here, the picture is: I keep moving because of my regularizer. And then, at some point, the GD will kick me down. Yes.
AUDIENCE: Is this for functions with strict saddles?
PIERFRANCESCO BENEVENTANO: This is for strict saddles. I also analyze the non-strict saddles. But this speed is for strict saddles. But for-- sorry?
AUDIENCE: [INAUDIBLE] is the smallest?
PIERFRANCESCO BENEVENTANO: The smallest eigenvalue of the Hessian.
AUDIENCE: In terms of magnitude, or?
PIERFRANCESCO BENEVENTANO: In terms of the biggest of the negative ones, yeah, the biggest magnitude of the negative ones. And so, as I hopefully convinced you, there is a strong-- maybe they go to the same minima. But from a mathematical perspective, SGD with and without replacement-- and I put the [INAUDIBLE] line in between-- they actually act differently.
And SGD with replacement moves around the manifold of loss equals 0 by oscillating. It's a completely stochastic kind of thing. But SGD without replacement jumps around, and the regularizer points you in one direction. And GD just touches the manifold and stops there. Cool. So essentially, we understood what was going on here, I believe. Or this is kind of good enough for me. Yes?
AUDIENCE: How does this step depend on the size of the data set?
PIERFRANCESCO BENEVENTANO: The step, yes. So this is a good question, in the sense that the way that step depends on the size of the data set is the way I told you. But actually, if the data set is bigger, this is essentially one step of regularization at every epoch. So if the data set is too big, essentially, you only get to compute one step. And you're moving by that one step.
But that's it, because you see only one epoch. So I can also tell you, maybe we converge if I'm allowed infinite epochs. And I can prove some stuff like that. But actually, if you tell me, oh, my data set is one billion, and I see it once, or at most twice, then-- so, yeah.
But now, I want to get practical, because I assumed I would not have convinced you with this kind of thing, because you know that there is infinite data; you don't care anymore. The assumption on my first slide was: we have a data set of size n. And so, I want to show you that, actually, you can prove interesting stuff with it. And this is joint work with Tomi and with Andrea, who was a student at Tomi's lab for the past six months.
And the problem is the following. OK, vision data sets are usually huge in dimensionality. And language data are also huge in dimensionality, because, essentially, there is a dimension for every token, or a dimension for every pixel, or three dimensions for every pixel. But actually, not all those pixels are relevant. When we look at this, it's easy to notice that those pixels in the corner are not relevant. OK, so, probably, a good generalizing model will not depend on those pixels in the corner.
And when I look at this more difficult data set, is this still the case? Well, actually, it is. It's not individual pixels anymore. But there are some linear combinations of pixels that don't tell us anything about the class we are in. So, for example, if you take, I don't know, the upper corner minus the upper left, minus the upper right, plus the down right, minus the down left, this carries zero, or close to no, information about the class you are in.
And so, essentially, this doesn't just tell us that there's a linear combination that doesn't matter. It's telling us there is a non-linear combination that doesn't matter. But when you train a ResNet on CIFAR-10, you can then look at the shape of the input of the head MLP. So you're doing many convolutions, many residual things. And then, at the end, you're doing linear regression on top of that.
And what's the shape of the stuff you get in this final linear regression? Actually, you see that the spectrum-- and then, you fit this linear regression onto your problem of classifying. Then you see that there are, essentially, just a few directions that matter. And most of them do not. And so, it's like, OK, maybe you can argue that, in general, my target function depends on many, many directions.
But I'm doing a theory of MLPs. And those blocks within the models sometimes do not. And so the question is, how do we treat this mathematically? How do I understand what's going on? Well, we can say the support is the lowest-dimensional manifold on which I can find a model that generalizes 95% of the time.
And you can, say, identify the support. I say the model identifies the support if, when I move in the directions that are orthogonal to this manifold, the output of the model doesn't change. And then, to tell you more, I can say the first layer identifies the support if, when I take two points in this manifold that is, I don't know, nine-dimensional, and I add five in an orthogonal direction to one data point and five in another orthogonal direction to the other data point, the first layer actually considers those two data points the same.
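One way to write this definition (a sketch, assuming for simplicity that the support $M$ is a linear subspace and $W_1$ is the first-layer weight matrix):

$$W_1 v = 0 \quad \text{for every } v \perp M, \qquad \text{so} \qquad W_1(x + v) = W_1 x,$$

i.e., two inputs that differ only in directions orthogonal to the support are given identical first-layer features.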
So now, we're trying to understand: our data sometimes is very high-dimensional, but our target function depends on just a few dimensions. How does the neural network find those few dimensions? And empirically, Andrea found that, in most cases, the first layer actually is lower-dimensional-- lower-dimensional in a way that puts all the weights along, for example, this linear space.
And this is, for example, when you train MNIST-- I know it's very toy. But when you train MNIST with an MLP, what you find is that, at initialization, this is the distribution of the singular values of your first layer. Well, if you train it with SGD with a big batch, then you have, at convergence, a certain distribution. But if you train it with SGD with a very small batch, you have a very different distribution.
And already, where do I want to go with this? I want to say that maybe the size of the batch and the size of the step actually bias the way in which neural networks find the support of the function. And the way Tomi would phrase this is: a bigger batch size or a bigger-- sorry, a smaller batch size or a bigger learning rate actually biases towards low rank in the first layer.
And so, essentially, we ran some experiments. This is still the empirical part. And on models that we know are low-dimensional, because we crafted them this way, we actually noticed-- OK, I'm plotting the matrix of the first weights, the weights of the first layer. At initialization, it's like this. Well, with GD, I'm finding a model that fits the data exactly as well as [INAUDIBLE]-- actually better, or exactly as well as SGD. But it's very close to the initialization.
And then, as I drop the batch size, the second part-- the one that interacts with the features of the input that are not relevant-- actually goes down and down and down. And when I do weight decay, of course, they converge to zero. And actually, my mental picture for this is: you remember those landscapes, with that manifold, and GD that was going down and stopping there? And SGD that was going down and then started moving along? Well, this is essentially what happens here. Or at least, this is what I can prove happens when you do linear networks, because I'm not that good.
I cannot prove it for many, many models. But essentially, all these algorithms touch the manifold at a point that is the closest to initialization. And then, by moving around along the manifold of minima, what am I doing? I'm actually regularizing this way. That moving around the manifold of minima-- on the first layer in linear networks, and also empirically on every network; I can prove it on linear networks-- actually has this effect of zeroing out the stuff that we don't care about.
And so, essentially, I'm saying smaller batches do it better. And we can see it in practice. Now, I'm training an MLP on MNIST. And we spoke about the fact that those input pixels are not relevant. Well, when I train with big-batch SGD, this is the size of the weights that interact with that pixel there, OK?
Actually, my model is still interacting-- or at least the first layer is still interacting. The model, by switching off many ReLUs, is still cutting out the information that comes from here most of the time. But the first layer is still interacting with those pixels. And as I go down and down with the batch size, it actually stops interacting with them. Yes.
AUDIENCE: There must be, like, an optimal batch size, because, I guess, if you take the batch size to 1, you will start having--
PIERFRANCESCO BENEVENTANO: If you take the batch size to 1, you don't converge anymore. If you converge, you do better. You go to a better minimum. But you don't converge. Does it make sense?
AUDIENCE: Yeah, but which batch size is--
PIERFRANCESCO BENEVENTANO: It depends on the problem. It depends on how hard your landscape is to go down, essentially. And we can see it here. When you train-- now, I'm looking at the size of those weights that interact with those pixels we don't care about, OK? I'm looking exactly at that. And I'm training with GD and SGD and everything. The trajectory is kind of the same until some point-- some point at which they diverge.
And those pixels go a bit down. But then they stabilize for GD. And they go down much faster for either the smaller batch size, or for the same batch size with a bigger learning rate. And then there's weight decay, which does better than everyone. But this is another story. It's like, we're playing basketball, and then LeBron joins the game. And so, what we can see from this picture, and from my theorems on linear networks, is that weight decay is the best. GD does it a bit-- it gets those weights down. But at some point, it stops, and it stops always too early, or at least very often: when you take the probability over the initialization, it stops too early.
And the speed of the shrinking depends on the learning rate and the batch size. And there are two phases for SGD. And what are the two phases? I'm going down to the manifold. And then I'm oscillating and moving around parallel to the manifold. And now, for the last five minutes-- this is not going to be a talk on zoology, but I feel like there is a big elephant in the room. And this is my elephant in the room.
And if you like, the big elephant in the room is that, yesterday, during the panel, I got stabbed multiple times. And I got stabbed by Aaron and by Jacob-- I don't know if he's here. And, OK, the strongest of the stabbings was, like, we should move forward from gradient-based methods. And when that sentence arrived, Tomi and I kind of melted. But, actually, I agree. It's just that I worked on that for four years. So now, I'm like, what next? I'll do my startup.
It's maybe a stabbing that we can try to address. And, fun fact, I will not. Not really. It's like, oh, now we should move forward from implicit regularization-- they didn't say exactly this, but all those theories for GD or SGD, we should move forward from that, and blah, blah, blah. And so, now, I was planning to try to convince people, by working some math on the iPad, that maybe doing this implicit regularization thing is still relevant.
And this is the part that I added yesterday, from 1:00 AM to 3:00 AM, because of how the panel went. And this is also the reason why I switched from Camel to Marlboro. This is usually a very bad sign. And, essentially, I want you to think. So let's take this model again. And I'm not sure even I agree with the argument I'm trying to make-- I think I'll go two minutes over time-- the argument I'll try to make.
But I hope it's food for thought. And we may discuss it, and then write the paper. Do we still care about implicit regularization, or do we not care anymore? It's like, when you do online SGD in this model, there is one dimension. Maybe this is the reason why this is still the case; maybe everything breaks when it's multi-dimensional. Some questions? No, no questions? OK.
Well, here, if I sample new x's and new y's that align around 1, 1, then I still have, in expectation, this manifold, and also a loss in expectation. All those points here are minima. And what online SGD is doing is essentially sampling i.i.d. So from a mathematical perspective, the nature of this effect is exactly that of SGD with replacement. So maybe those things that I'm saying, and the thing that I'll prove in two minutes about the support, still matter now that we're doing online-- we have infinite data, et cetera, et cetera, et cetera.
Then, if we move on from gradient-based methods, well, I'll try to sell myself to mathematics departments, trying to convince them that I did some relevant mathematics instead of some relevant machine learning theory. But still, maybe, that stuff matters for online SGD. And, yes?
AUDIENCE: So the first one, and even the previous, yeah, the first one, each point here is one epoch? Or it's just, for example, for SGD, it's a batch?
PIERFRANCESCO BENEVENTANO: It is one epoch, I believe.
AUDIENCE: The second question is, we are trying-- we're training everything with the same learning rate?
PIERFRANCESCO BENEVENTANO: Yes, I'm keeping the learning rate constant, yes. And so, let me restart with: I have a data set. But now, I'm not saying the size of the data set, OK? I removed that yesterday at 2:46 AM. And let me say, oh, I have the following model, linear in the input but non-linear in the parameters. And I want to see-- I do one step of GD. What is one step of GD? Well, I differentiate this thing. And I get-- a 2 appears that, of course, I forgot about.
And then, I'm doing the derivative in theta i. So I get, essentially, the product of all of them but theta i. And so, now, I'm saying, wait, wait, wait. I'm speaking of irrelevant components. So what I want to prove is that the weights that interact with the components that are not relevant for the target function go to 0. So I'm saying, OK, actually, consider my y to be 0. Or the conditional expectation of my y given x along this direction being 0-- this is the way you frame it for multidimensional linear networks.
And now, I can write this thing, but instead of y, I put 0. And I forgot an x here; otherwise, there should be an x squared that multiplies. But essentially, now, I can rewrite this. This is the product of all the thetas except the i-th layer. So this is the product of all the other layers but the one I'm looking at, because every parameter is a layer now. And here, well, here I have the layer I'm looking at, times all the other ones.
And so, I can say, hey, wait, wait, wait a second. This gradient descent step is actually multiplying theta i by 1 minus eta times the product of all the other layers squared. This gets very ugly when the width is not 1. But anyways, this is the point-- times theta i. And so, GD is actually shrinking the parts of the layers that interact with the irrelevant components.
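Reconstructing the computation being sketched: for the depth-$L$ toy model $f(x) = \theta_1 \theta_2 \cdots \theta_L\, x$ with squared loss on an irrelevant component (target $y = 0$), the loss on a sample $x$ is $\ell(\theta) = \big(\prod_j \theta_j\big)^2 x^2$, and one GD step on layer $i$ reads

$$\theta_i \;\leftarrow\; \theta_i - \eta\,\frac{\partial \ell}{\partial \theta_i} \;=\; \theta_i \Big(1 - 2\eta\, x^2 \prod_{j \neq i} \theta_j^2\Big),$$

so the layer with the smallest $|\theta_i|$ faces the largest factor $\prod_{j\neq i}\theta_j^2$, hence the strongest shrinking-- which is the point made next.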
And if you look at this quantity-- now, let's say, OK, I randomly initialize those thetas. And theta 3 is the smallest of all of them. Well, when theta 3 is the smallest of all of them, it means that the product over j different from 3 of theta j squared is the biggest of the products of this kind. And so theta 3 gets the strongest shrinking. And what does this mean?
It means that the layer that is the smallest at initialization actually converges to 0 faster than the others. And what does this mean? This means that GD identifies the support in one of the layers. I can do this thing for ReLU, too. And for ReLU, what happens is that, if there's a switch of sign at some point, it's the ReLU way to identify the support-- by switching off. It's not even an intermediate layer.
But when I do SGD-- when I do SGD, and I have some assumption on the data, or whatever, that tells me that I'm oscillating. Those assumptions could be called edge of stability; they could be called, the linear model is mis-specified; they could be called, we have the catapult effect, a.k.a. the mini-batch matrix, the expectation of x x-transpose, [INAUDIBLE]. Then what's happening is, I'm oscillating. Yes, yes, it's the last slide.
And what's going on in this case for SGD is that, now, you have to imagine this blue line to be the axes. So instead of having 10 and 1/10, I have 10 and 0, or 0 and 10. And so, what happens also here is that-- well, I send to 0 the layer that is initialized the smallest. But then, oscillating, I move towards the origin.
And so, after a long time-- unfortunately, a long time-- I send to 0 the second smallest. And after an even longer time, I send to 0 the third smallest, et cetera, et cetera, until all of the layers are sent to 0. So if I have an assumption that tells me the SGD dynamics oscillate, then SGD is strongly regularizing in infinite time, because it's sending all of them to 0, not just the one that was initialized the smallest. End of the story.
[APPLAUSE]