SGD and Weight Decay Secretly Compress Your Neural Network
- Brains, Minds and Machines Summer Course 2024
TOMER GALANTI: Hello, everyone. So my name is Tomer Galanti, and today, I will talk about some inductive biases in deep learning. So before I start, I just want to thank my colleagues who worked with me on this work.
And, yeah, so in deep learning, typically, we have some neural network, some model that we want to train to fit some data set. So we have the model, like this architecture. We have some training data. Let's say we have an ImageNet classification problem. And we want to train the neural network, let's say, optimize it on the data to minimize some objective function, some loss function, right?
And the most simplistic way of viewing this problem is empirical risk minimization. We pick this architecture, and we try to minimize the loss over its weights. And then the output of the algorithm is basically a minimizer of this objective function-- so some kind of a function that is able to correctly, let's say, classify all of the training samples in the training data.
But in reality, this is not exactly what we do. We have a very specific learning algorithm for optimizing our neural networks. We typically use gradient-based optimization techniques, like stochastic gradient descent plus regularization and so on. And we have different variations of these algorithms, like stochastic gradient descent and Adam and so on.
So there might be a difference between what we would theoretically learn with empirical risk minimization and what we learn with stochastic gradient-based methods. The problem can have many different solutions, and the algorithm influences which one we find.
And this leads to a variety of questions about optimization, about what we learn with neural networks, and so on. And the type of questions that I'm curious about in this work, and other works as well, is what kind of structure emerges during training.
So we learn some set of weights; these parameters here are what we call weights. And we learn representations of the data, which are the functions that we compute in the intermediate layers of the neural network.
And the question is, what is the structure here? Can we say something interesting about the weights that we learn? Can we say something interesting about the representations of the data in the intermediate layers?
And beyond that, maybe understanding these properties could lead to a better understanding of how to properly optimize neural networks and design architectures and things like that so we can have better compression or adaptation, transfer learning, and so on-- so all of the efficiency type of questions related to deep learning.
OK. So, for example, there is this question of compression in deep learning. Basically, when we think of neural networks, we try to train a very large neural network. In practice, we know that training large neural networks is a good idea. We're able to fit the data better. And there is also some kind of work related to generalization when it comes to large neural networks.
But they're very large. Large neural networks are large, which means that they're heavy computationally. And at inference time, maybe we want to apply the neural network many times. So if we think of something like an application like ChatGPT, where people use it a lot, we want a lightweight model instead. So at inference time, we might prefer to have a much smaller model.
And fortunately, in practice, we see that neural networks are highly compressible. There are many different ways in which neural networks tend to be compressible. For example, we know that when training neural networks, in many cases, many of the weights are very small, close to zero. So essentially, we can prune them and still get a very good neural network that approximates the solution that we found.
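A minimal sketch of the pruning idea, assuming plain unstructured magnitude pruning of a single NumPy weight matrix (the function name `magnitude_prune` and the pruning fraction are illustrative, not taken from the talk): zero out the smallest-magnitude entries and keep the rest.

```python
import numpy as np

def magnitude_prune(weights: np.ndarray, fraction: float) -> np.ndarray:
    """Zero out the `fraction` of entries with the smallest absolute value (hypothetical helper)."""
    if not 0.0 <= fraction <= 1.0:
        raise ValueError("fraction must be in [0, 1]")
    flat = np.abs(weights).ravel()
    n_prune = int(fraction * flat.size)
    if n_prune == 0:
        return weights.copy()
    # Indices of the n_prune entries closest to zero.
    idx = np.argsort(flat)[:n_prune]
    mask = np.ones(flat.size, dtype=bool)
    mask[idx] = False
    return weights * mask.reshape(weights.shape)

W = np.array([[3.0, 0.01], [-0.02, 2.0]])
print(magnitude_prune(W, 0.5))
```

In practice one would usually fine-tune after pruning; this sketch only shows the selection step.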
Another approach is knowledge distillation, where we take this large neural network, and we fit a much smaller neural network to mimic the large neural network. So we get compression this way, and that works as well.
And another approach is to take advantage of low-rank structure in the weight matrices. We know from empirical observations that the weight matrices in the different layers of the neural network are not necessarily full-rank. So essentially, we can take these matrices and represent them more succinctly as a product UV of two thin matrices, which saves parameters.
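One common way to obtain such a UV factorization is a truncated SVD, which gives the best rank-r approximation in Frobenius norm. The sketch below assumes NumPy and uses a synthetic matrix that is exactly rank 8; real trained weights are only approximately low-rank, so the choice of r trades parameters against approximation error.

```python
import numpy as np

def low_rank_factorize(W: np.ndarray, r: int):
    """Return U (m x r) and V (r x n) such that U @ V is the best rank-r approximation of W."""
    U_full, s, Vt = np.linalg.svd(W, full_matrices=False)
    U = U_full[:, :r] * s[:r]   # scale the top-r left singular vectors by their singular values
    V = Vt[:r, :]
    return U, V

rng = np.random.default_rng(0)
W = rng.standard_normal((256, 8)) @ rng.standard_normal((8, 256))  # synthetic rank-8 matrix
U, V = low_rank_factorize(W, r=8)
rel_err = np.linalg.norm(W - U @ V) / np.linalg.norm(W)
print(W.size, U.size + V.size)  # 65536 4096
```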
So these are all very interesting observations. And the first question I want to ask is, why should we care? Does anyone have any ideas? Not everyone at the same time.
OK, so I kind of mentioned it already. But basically, there is this efficiency question. We want to use as little memory as possible, and large models require lots of memory.
The second thing is faster inference. So when we have a pruned network, a much smaller network, we can run it much faster. And another possibility is faster training. If we're able to, for example, prune the network after 10 epochs of training, very quickly, then we might continue training from that point. So we have a much smaller model to train from, from epoch 10 instead of epoch 300.
So this is a very important question. And the question is, why should we have a theory to understand this phenomenon? Why should we care about understanding why that happens and when and all of these questions?
And the reason why I think it's important is because, OK, we have these observations. We know that neural networks are generally compressible, but can we control that? Can we design the optimization method, the architecture, the data, whatever, such that we can ensure that the model that we end up with will be more compressible or less compressible and so on?
So we want better insight into what's happening there. And the questions that come up are: why does something like that happen? Is it because of the optimizer? Is it the regularization? Would it change when we change the regularization? Maybe it's both, the interaction between the optimization process and the regularization?
Maybe it's the actual structure of the objective function, the critical points of the problem that we try to solve. And also, to what extent does it depend on the architecture and the data? These are the types of questions that I think are important for understanding how to control these kinds of properties.
And this is not the first work in this area. There are several papers that try to give a better theoretical understanding of these kinds of properties.
One, for example, from 2019 is the Ji and Telgarsky paper that tried to prove some kind of a low-rank bias in deep learning. So suppose we have a deep linear model. And the objective is to minimize some logistic loss or exponential loss with gradient flow, which is the continuous-time limit of gradient descent.
Then each one of the weight matrices converges to a rank-1 matrix. But the problem is that we have this result under the assumption that the data is linearly separable. So if the data is linearly separable, we get that each one of the layers is of rank 1.
So we train a model with large-dimensional weight matrices, but we end up with very low-rank matrices. So this is a very nice result, right? It's a very strong result. We get rank 1. We couldn't ask for more.
But the problem is that we make very strong assumptions. We assume that the neural network is linear; we don't have activation functions, so it's very inexpressive. And we also make a strong assumption that the data is linearly separable.
Another paper in this line is the Timor et al. paper from 2023. So suppose we have a ReLU neural network. So now it's a much more interesting model. And the objective now is to minimize the squared norm of the weights subject to fitting the data. So we want to find the solution to minimizing the squared norm of the weights subject to the model fitting each one of the training samples.
And we also make an assumption that with k layers, we can fit the data perfectly. Then, at the global optimum, all of the top layers above the kth layer are of rank at most 2. OK, so we get this nice result saying that the top layers are going to be low-rank as long as the bottom layers are sufficiently expressive.
And, again, this is also kind of an unrealistic setting. First of all, this is not something that we typically do in practice. It's not clear how to even optimize that objective.
AUDIENCE: Can you explain that? Because, for me, it seems like regularization plus MSE or something like cross-entropy.
TOMER GALANTI: Yeah, so if you train, let's say, I don't know, MSE plus regularization, you don't necessarily get that.
AUDIENCE: Yeah, OK. But the objective is very common. [INAUDIBLE].
TOMER GALANTI: So the objective here is to minimize this quantity.
AUDIENCE: Oh, OK, so it's the opposite of the normal objective.
TOMER GALANTI: Yeah. It's like the dual thing. Yeah. So, yeah, it's related to minimizing, let's say, MSE loss plus regularization, but it's not exactly the same thing. In practice, we minimize the loss plus regularization, not the constrained problem.
And also, we know nothing about the bottom layers here, right? It only says something interesting about the top layers.
And another reason why I think this is kind of an interesting question is that in the community, there is some discussion about the role of this behavior. We have this low-rank bias in neural networks. I'm not arguing whether I agree with this statement or not, but some people believe that it might be related to why neural networks generalize well or not and so on.
So I'm giving this citation just to give the impression that this is something that people care about. All right. So the type of questions that I want to talk about today are the following.
So what is the source of the low-rank bias in deep learning? Can we identify an inherent behavior towards low rank in neural networks? Is it a property induced by the optimization, by stochastic gradient descent or gradient descent in general?
To what extent does it depend on the training hyperparameters? OK, so if we know that it depends on the, let's say, batch size or the learning rate and so on, we might be able to control the degree of this bias, right? We can ensure that the model has a lower rank and so on.
How much does it depend on the structure of the data? Is it data-dependent or independent of the data? And, briefly, what is the relationship between low-rank bias and generalization, which I think Akshay will talk about a bit more in his presentation. So I'll show some experiments to discuss this.
OK, yeah, so because I'm curious about how optimization influences this bias, this behavior, I want to start with a simple experiment to see how, potentially, the hyperparameters could influence the rank of the weight matrices. So can we identify some kind of empirical observation here?
So we trained lots of ResNets for classification for a long time, and we varied the hyperparameters. So we trained with stochastic gradient descent, and we varied the batch size, the learning rate, and so on, to see whether the rank changes.
And to calculate the rank, we basically count how many singular values of the weight matrices, after normalizing them, cross a certain threshold. So we fix this threshold, and we count how many singular values are larger than that threshold for the normalized weight matrix.
And we averaged the rank across all of the layers just to get one number. But we could do the analysis per layer individually. It's also fine.
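A sketch of this rank measurement. The talk does not give the exact normalization or threshold, so this assumes normalization by the largest singular value and a threshold of 1e-3, and it assumes convolutional kernels are flattened to (out_channels, everything else); both helper names are hypothetical.

```python
import numpy as np

def thresholded_rank(W: np.ndarray, threshold: float = 1e-3) -> int:
    """Count singular values of W / sigma_max(W) that exceed `threshold`."""
    # Assumption: conv kernels (out, in, kh, kw) are flattened to (out, in*kh*kw).
    M = W.reshape(W.shape[0], -1)
    s = np.linalg.svd(M, compute_uv=False)   # returned in descending order
    if s[0] == 0.0:
        return 0
    return int(np.sum(s / s[0] > threshold))

def average_rank(weight_matrices, threshold: float = 1e-3) -> float:
    """Average the thresholded rank across layers to get a single number."""
    return float(np.mean([thresholded_rank(W, threshold) for W in weight_matrices]))
```

The threshold matters: a larger one reports a lower rank for the same matrix, so comparisons across hyperparameters should use a fixed threshold, as described above.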
And this is the type of experiments-- this is the type of results that we got. So let's say we fix the batch size to be 8, and we vary the learning rate and the degree of regularization. So if we fix the regularization here to be this value, and we increase the learning rate, then we see that the rank is shrinking. So the rank is getting smaller as we increase the learning rate.
AUDIENCE: What is n, just to give it like-- 200, is it almost full rank or it's also low rank, but relatively?
TOMER GALANTI: Yeah, so the full rank is, I think, 256 here; that's the dimension. Yeah. So--
AUDIENCE: Can you just remind which is which, the mu and the lambda?
TOMER GALANTI: Lambda is the regularization. Mu is the learning rate. And B is the batch size.
So let's say we fix the learning rate here and we increase lambda. Then the rank is also going down. So if we increase the regularization parameter, or the learning rate, then we see that the rank is going down. And a similar behavior also happens with the batch size.
So if we fix the learning rate here to be 0.6, and we increase the batch size, then the rank is actually increasing. So if we decrease the batch size, the rank is expected to go down.
AUDIENCE: And wouldn't you say that if you increase the weight decay, it makes a lot of sense that the rank would go down, regularization-wise?
TOMER GALANTI: Of the rank?
AUDIENCE: So maybe I'm wrong, but it makes a lot of sense that if the weight decay goes up, you have more regularization. So it makes sense that the rank of the matrices will go down just because you have more-- something pushes the whole matrix down.
TOMER GALANTI: Yeah, I mean--
AUDIENCE: In general.
TOMER GALANTI: I agree with you. It's like a form of regularization. And rank is kind of like a measure of complexity. But it's not necessarily obvious why the two would be specifically connected, right?
AUDIENCE: Like how? What for?
TOMER GALANTI: Yeah, because you could think of measures of complexity that L2 regularization would not be able to correlate with, right? Yeah, or another type of regularization that does correlate with rank and L2 won't. And all of these possibilities are on the table. Yeah.
AUDIENCE: Do these all have the same training? I don't really see references [INAUDIBLE].
TOMER GALANTI: Yeah, so in these experiments, I'm not showing the train and test performance. But of course, if you increase, let's say, the learning rate too much or the weight decay too much, then the training performance would deteriorate as well as the test. So I don't care about the train and test performance at this point. I only care about how the hyperparameters are going to influence the complexity. But absolutely, yeah, maybe you want to control both, the performance and the complexity.
AUDIENCE: Is there some intuition to why this range of variables, or what happens out of this range?
TOMER GALANTI: So it's a good question. We kind of picked these ranges to ensure that this is the range where things are happening. So if you go out of that range-- let's say the weight decay is too large-- you shrink to rank 0. It's going to be trivial in that sense.
But if you take it, let's say, too small, then you're in the range where you don't see any effect. So there is this transition period. But, yeah, it's not obvious where it's going to happen, depending on-- I think that depends on the data and things like that, the architecture. Yeah?
AUDIENCE: Yeah. I was wondering if you're going to talk more-- maybe you're going to talk about it. But what the intuition would be of, like, why are they having a small batch size plus a high learning rate, how that would be related to [INAUDIBLE], is that something you're going to talk about?
TOMER GALANTI: Yeah, I'll show a theorem with the proof and all that. So, yeah, I hope that it will be self-explanatory, but I'm happy to reiterate that question if it needs clarification.
Yeah, so this is the premise of this work, actually. So we try to give a very simple proof for this behavior, not just an empirical investigation, but also to be able to mathematically show that there is this tendency. And we argue that it applies to a very broad set of neural networks.
We consider a very standard setting, where we train the neural network with stochastic gradient descent and L2 regularization, and we don't make any assumptions about the structure of the data, the number of samples, and so on. And we don't make any strong assumptions about the convergence of this process. So we try to keep it as standard as possible, as faithful to practice as we can.
So just to give you some intuition of what kind of models this result is applied to, so the neural network can include linear layers, residual connections, convolutional layers, and like any type of reasonable activation function. So it can be ReLU, sigmoid, softmax, whatever, pooling, even self-attention layers-- so a very broad set of architectures.
And this is basically the objective. So this is what we want to minimize. As in the problem I described earlier, we have some neural network hW. Let's think of a ResNet or a transformer. And the goal is to minimize the average error across training samples plus L2 regularization.
And we train with stochastic gradient descent. So each iterate is equal to the previous one minus the learning rate times the gradient of the regularized objective over a random batch of samples. Here S tilde t is a randomly selected batch of size B from the set S.
OK, and when we want to write it more explicitly, it looks like this. We have Wt. We have the gradient across the training samples. And we have the gradient of the regularization. OK?
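For reference, a reconstruction of the objective and the update in standard notation. This is a sketch under stated assumptions: the training set S has m samples (x_j, y_j), the regularizer is λ times the sum of squared Frobenius norms of the layer matrices W^i (so its gradient is 2λW^i, which matches the 1 minus 2 mu lambda factor that appears when the iterations are unrolled), and μ is the learning rate.

```latex
\begin{aligned}
\mathcal{L}^{\lambda}_{S}(W) &= \frac{1}{m}\sum_{j=1}^{m} \ell\big(h_W(x_j), y_j\big) + \lambda \sum_{i} \lVert W^i \rVert_F^2, \\
\mathcal{L}_{\tilde S_t}(W) &= \frac{1}{B}\sum_{(x,y)\in \tilde S_t} \ell\big(h_W(x), y\big), \\
W^i_{t+1} &= W^i_t - \mu \nabla_{W^i}\mathcal{L}_{\tilde S_t}(W_t) - 2\mu\lambda W^i_t
           = (1 - 2\mu\lambda)\, W^i_t - \mu \nabla_{W^i}\mathcal{L}_{\tilde S_t}(W_t).
\end{aligned}
```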
OK, so this is basically the theoretical result here. So what do we have? Let's say we have some differentiable loss function-- so let's say MSE loss or cross-entropy or something like that.
We have some neural network h, which is some architecture from the family I described earlier. We assume that the weights are being updated using stochastic gradient descent. So we have a learning rate, we have some regularization coefficient lambda, and we have the batch size B.
So for a second, let's ignore this assumption. But intuitively, after enough training-- so after a long time of training, we can ensure that we have the following inequality. So basically, we can approximate the weights at layer i-- so i is the layer-- at iteration t with some low-rank matrix up to an approximation error epsilon. And the rank of the approximator is bounded by this quantity, some function of the batch size and the learning rate and the weight decay.
So I'll go over it again. We have the weights at layer i. After t iterations, we normalize them, and we check whether we can approximate them with a low-rank matrix. And we can get this kind of approximation with a matrix of rank at most B times log(2/epsilon), divided by 2 lambda mu.
And basically, this kind of predicts the observations that we saw earlier, right? If the batch size is smaller, then the rank is smaller. If the learning rate or the weight decay is increased, then the rank is going to shrink, and so on.
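The bound as a function of the hyperparameters, written as a small helper (the name `rank_bound` is hypothetical). It only makes the trend explicit: halving B halves the bound, and doubling μ or λ halves it as well. The bound is loose, so it says more about trends than about the rank one actually measures.

```python
import math

def rank_bound(batch_size: int, lr: float, weight_decay: float, eps: float) -> float:
    """B * log(2/eps) / (2 * lr * weight_decay), the rank bound stated in the talk."""
    if not 0.0 < eps < 2.0:
        raise ValueError("eps must be in (0, 2) so that log(2/eps) is positive")
    if lr <= 0.0 or weight_decay <= 0.0:
        raise ValueError("the bound needs a positive learning rate and weight decay")
    return batch_size * math.log(2.0 / eps) / (2.0 * lr * weight_decay)

for B in (8, 32, 128):
    print(B, round(rank_bound(B, lr=0.1, weight_decay=5e-3, eps=0.1)))
```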
AUDIENCE: In the experiment before, it seems like there was some cases that the rank goes to 0, right?
TOMER GALANTI: Right, yeah.
AUDIENCE: Here, 2 mu lambda is bigger than 1, right? Is it smaller than 1 because of your assumption?
TOMER GALANTI: Right.
AUDIENCE: Maybe it's not the same assumption, but I'm not sure. And B is bigger than 1, so I guess the minimum one should be 1, or maybe log 2?
TOMER GALANTI: Yeah. It's a bound. It's not necessarily tight. But, yeah, I mean, on the region where it applies, it's correct.
AUDIENCE: So here the minimal rank should be something like B log of 2? So it cannot go below it, right?
TOMER GALANTI: Right, yeah, yeah. So this result cannot say, for example, that if you increase the learning rate or the weight decay too much, you'll go to 0. It just tells you it's going to be at most something, something, which might be actually large. But, yeah, with this result, we cannot say everything.
I just want to emphasize that this is not a trivial thing. If I have a neural network of width, let's say, a million or a trillion, then because this quantity is independent of the dimensions, the bound is nontrivial in these cases. But generally, it's a loose bound.
AUDIENCE: Are we not going to see the condition?
TOMER GALANTI: What?
AUDIENCE: The "if" that you said-- the limit with time infinite.
TOMER GALANTI: Oh, that?
AUDIENCE: Yeah.
TOMER GALANTI: Yeah, so I don't know how to prove this thing.
AUDIENCE: Oh, OK.
TOMER GALANTI: But empirically, I'll show some experiments validating this equation.
AUDIENCE: OK.
TOMER GALANTI: The idea is quite simple. So basically, what it says is that the norms of the weight matrices kind of converge.
AUDIENCE: OK.
TOMER GALANTI: That's it. So it's kind of some weak notion of convergence.
AUDIENCE: OK.
TOMER GALANTI: Yeah. But you could potentially extend it to, let's say, bounded by some alpha, and then you'll have some dependence on that alpha as well.
AUDIENCE: OK.
TOMER GALANTI: Yeah, if you want to further relax it. Yeah, so let's try to prove this theorem. Let's see why something like that would be possible.
So we know empirically that weight decay is necessary for this low-rank bias, right? Empirically, we saw that when we increase the weight decay, the rank goes down. So why would something like that be possible? I think that the traditional intuition about weight decay is that it penalizes the norm of the weights. We just add a term to minimize that is proportional to the norm of the weights.
So obviously, we try to minimize that. But an alternative way to think about weight decay is that it changes the training dynamics of neural networks. We change the way we optimize.
And what I want to argue is that it controls how quickly we forget past iterations. So if we think about the optimization, then let's say we have iteration t. I want to argue that iteration t depends on a suffix of the past iterations. But it very weakly depends on what happened at the very beginning of training. So we quickly forget past iterations.
So let's try to see how that works out mathematically. So we have stochastic gradient descent for this objective function. And we can basically start by unrolling the iterations one by one. OK, so iteration t plus 1 is equal to that. We can take Wt and the gradient of the regularization and put them together.
So we have these two parts, right? And we can already start seeing what's happening here. The coefficient on Wt starts out as 1, and after we combine it with the gradient of the regularization, it becomes 1 minus 2 mu lambda, which is slightly smaller than 1.
So what I want to do is to do it again and again and so on and see what happens. So we'll unroll the optimization process and see if we can get something interesting out of it.
OK, so that was the first iteration. Now we'll apply the same process within this quantity. So we'll rewrite that using this equation when applied to t.
So this is exactly this thing. This 1 minus 2 times blah, blah, blah is this quantity. And here we have this gradient.
So when we put everything together, we get this equation. So we have a dependence on Wt minus 1. Now we have this coefficient, but it's squared instead of to the 1. And we have this kind of sum over many different gradients-- weighted gradients.
And if we do it k times-- so in general, when we unroll k iterations backwards, we get the following equation. So we can represent t plus 1 as a function of t minus k times some term that exponentially decays with the number of iterations we went backwards minus some mixture of gradients, the sum over many different gradients with different coefficients.
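Written out under the same assumptions (regularizer λ times the squared norms of the weights, so a single step is W_{t+1} = (1 − 2μλ) W_t − μ times the minibatch gradient at W_t), unrolling k steps gives the following. Setting k = 1 recovers the single step, and k = 2 gives the squared coefficient described above.

```latex
W^i_{t+1} = (1 - 2\mu\lambda)^{k}\, W^i_{t-k+1}
  \;-\; \mu \sum_{j=0}^{k-1} (1 - 2\mu\lambda)^{j}\, \nabla_{W^i}\mathcal{L}_{\tilde S_{t-j}}(W_{t-j})
```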
So if, for example, we assume that the norms of the weight matrices converge, which is this assumption that I mentioned-- so let's say it converges-- then from some iteration t onwards, if I fix k-- let's say k is 100-- then the norms of that and that should be very close to one another.
The ratio between them goes to 1. So this is much smaller than that in terms of the norm. So basically, if this is significantly smaller than that, then this thing should be approximately that. So we basically go k iterations backward.
And then we can say, well, this Wt plus 1 is approximately this mixture of many gradients. And this is exactly what we do here. And now the question is, OK, how should I choose k such that the approximation will be at most epsilon?
I want to get approximation error epsilon for that weight matrix. So I want to pick k such that we get epsilon distance. That first term will be significantly smaller than the weights themselves, small enough to ensure that the approximation error is epsilon. And for that purpose, we can just pick k to be this kind of quantity.
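One way to see why a k of order log(2/ε)/(2μλ) suffices, as a sketch: it uses 1 − x ≤ e^{−x}, assumes 0 < 2μλ < 1, and reads the ε/2 target for the decay factor off the bound, with the remaining slack covered by the norm ratio being close to 1. Multiplying this k by the batch size B, since each of the k·B per-sample gradients in the suffix has rank at most 1, gives the bound stated earlier.

```latex
(1 - 2\mu\lambda)^{k} \;\le\; e^{-2\mu\lambda k} \;\le\; \frac{\varepsilon}{2}
\qquad \text{whenever} \qquad
k \;\ge\; \frac{\log(2/\varepsilon)}{2\mu\lambda}, \quad 0 < 2\mu\lambda < 1 .
```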
And that gives us this first step, which basically says, well, the weights at iteration t can be approximated by the suffix of the iterations. So all of the training up to t minus k is essentially irrelevant. And we are left with the suffix of training.
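A small numerical check of this first step on a toy problem chosen only for illustration: a single linear layer trained with minibatch SGD, weight decay, and squared error, with arbitrary sizes and hyperparameters. It verifies the unrolled identity exactly, then confirms that dropping the head term (1 − 2μλ)^k W_{t−k} leaves a relative error below ε when k = ceil(log(2/ε)/(2μλ)). Here the weights converge, so the norm-ratio assumption holds.

```python
import math
import numpy as np

rng = np.random.default_rng(0)
n, d_in, d_out = 512, 20, 10
X = rng.standard_normal((n, d_in))
Y = X @ rng.standard_normal((d_in, d_out)) + 0.1 * rng.standard_normal((n, d_out))

lr, wd, batch, steps = 0.05, 0.05, 8, 2000   # mu, lambda, B, number of SGD steps
W = rng.standard_normal((d_out, d_in))
history, grads = [W.copy()], []               # history[t] = W_t, grads[t] = minibatch gradient at W_t
for t in range(steps):
    idx = rng.choice(n, size=batch, replace=False)
    xb, yb = X[idx], Y[idx]
    g = (xb @ W.T - yb).T @ xb / batch        # gradient of the batch mean of 0.5*||W x - y||^2
    grads.append(g)
    W = (1 - 2 * lr * wd) * W - lr * g        # SGD step on loss + wd * ||W||^2
    history.append(W.copy())

decay, eps, T = 1 - 2 * lr * wd, 1e-3, steps
k = math.ceil(math.log(2 / eps) / (2 * lr * wd))
head = decay**k * history[T - k]                                   # contribution of early training
suffix = -lr * sum(decay**j * grads[T - 1 - j] for j in range(k))  # last k minibatch gradients
rel_err = np.linalg.norm(history[T] - suffix) / np.linalg.norm(history[T])
print(k, rel_err)
```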
And now the question is, OK, what is the rank of this quantity? We know that our weights can be approximated by the suffix, the suffix of the training. What is the rank of this mixture of many gradients?
And I just want to make a remark. So far, we didn't make any assumptions about the data or the architecture or anything like that.
Yeah, and now the question is, what is the rank of this sum of gradients? And for that, basically, we proved that the gradient for a single training sample has rank at most 1. So if I take the gradient of the loss on one sample with respect to any weight matrix in the neural network, then its rank is bounded by 1.
And that applies for a very wide range of neural networks-- so any neural network that can be represented this way. So basically, let's say we have the l-th layer. So it's a matrix Wl times everything that comes before that layer. And then we have something on top of that.
So that's g. So we can rewrite h as kind of like a multiplication of the matrix by the layers below, and we apply some layers on top of it.
So for example, with fully connected neural networks, this could be the bottom l minus 1 fully connected layers, all of the layers below. Then we have the l-th weight matrix. That's Wl. And then we have, let's say, sigma, Wl plus 1, and so on and so on. And that is captured by this gW.
So that's the type of architectures that we have. But it can also be applied for residual neural networks and transformers and so on.
OK, and the proof of this lemma is also super simple. It's basically an application of the chain rule. We know that h can be written as gW applied to z, where z is whatever is inside-- so Wl times the output u of the previous layers.
So by the chain rule, we can take the gradient with respect to z and then the derivative of z with respect to Wl. And this is exactly what we get here. So we have the derivative with respect to z times the derivative of z with respect to Wl, which is u transposed.
And this is just the product of two vectors, which is a matrix of rank 1. So this way we get a rank 1 matrix for each one of these gradients.
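A numerical sanity check of the lemma, assuming a tiny network h(x) = W3 tanh(W2 tanh(W1 x)) with squared loss (our choice; tanh keeps finite differences smooth). The gradient with respect to the middle matrix W2 is estimated by central finite differences rather than by the chain-rule formula, so the rank-1 structure is observed rather than built in. Averaging over B samples gives rank at most B.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, width, d_out = 6, 12, 5
W1 = rng.standard_normal((width, d_in))
W2 = rng.standard_normal((width, width)) / np.sqrt(width)   # the layer "Wl" we differentiate
W3 = rng.standard_normal((d_out, width)) / np.sqrt(width)

def loss(W2_, x, y):
    u = np.tanh(W1 @ x)            # everything below layer l
    z = W2_ @ u                    # z = Wl u
    h = W3 @ np.tanh(z)            # g applied on top of z
    return 0.5 * np.sum((h - y) ** 2)

def numerical_grad(W2_, x, y, step=1e-6):
    """Central finite-difference estimate of d loss / d W2_."""
    G = np.zeros_like(W2_)
    for i in range(W2_.shape[0]):
        for j in range(W2_.shape[1]):
            E = np.zeros_like(W2_)
            E[i, j] = step
            G[i, j] = (loss(W2_ + E, x, y) - loss(W2_ - E, x, y)) / (2 * step)
    return G

def numerical_rank(G, rel_tol=1e-4):
    s = np.linalg.svd(G, compute_uv=False)
    return int(np.sum(s > rel_tol * s[0]))

xs, ys = rng.standard_normal((4, d_in)), rng.standard_normal((4, d_out))
r_single = numerical_rank(numerical_grad(W2, xs[0], ys[0]))
G_batch = sum(numerical_grad(W2, x, y) for x, y in zip(xs, ys)) / len(xs)
r_batch = numerical_rank(G_batch)
print(r_single, r_batch)
```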
I'll go back. Basically, what it says is that each one of these gradients is of rank 1. And we picked k to be this quantity. So we have k times B gradients. So k times B times 1 is the bound on the rank.
So that's why we got B times this quantity. And obviously, the rank there is 1. So we get the bound that I showed earlier, this thing. It's B times this k.
AUDIENCE: What happens if you have weight sharing over there?
TOMER GALANTI: So we avoid that for convolutional networks. You also have a factor of how many times you have the-- how many patches within that layer, how many shares, which is also some kind of a constant there.
AUDIENCE: I'm wondering if this general theorem gives some intuition why explicitly low-ranked fine-tuning approaches tend to work just as well as fully general fine-tuning.
TOMER GALANTI: It's a good question. I think the intuition for using LoRA and things like that is that the weight matrices that we learn are low-rank. So it makes sense that we do adaptation within a low-rank subspace because it's already low-rank. But I don't see a direct connection right now. But, yeah, it's a good question.
I think maybe there is this question of how the hyperparameters and the rank that they control, essentially, are going to influence the ability to do LoRA or not.
AUDIENCE: How does the norm of the weights come into play here if the norm would be [INAUDIBLE]?
TOMER GALANTI: You mean this thing?
AUDIENCE: Yes. The limit converges-- the difference-- the ratio between them converges to 1. But you have this alpha factor [INAUDIBLE].
TOMER GALANTI: You mean this?
AUDIENCE: Yeah, the norm-- you assume that the norm compared to this [INAUDIBLE].
TOMER GALANTI: Yeah, so basically, in the theorem presentation, I divide by the norm of that, right? And that is very close to alpha. So, yeah, if you formalize it, you'll get the theorem. Yeah, here. OK.
Yeah, so what are the implications of this result? What can we learn about neural networks from this observation? I think one interesting thing is that there is an inherent tendency of stochastic gradient descent together with regularization, which is a very standard learning algorithm, to compress the neural networks that we train.
So even if we don't necessarily get a highly compressible neural network during training, there is this inherent behavior that happens behind the scenes that has some influence. And the question is how, basically, the hyperparameters are influencing that, which is what we are showing here.
And another interesting thing is that it seems that stochastic gradient descent and stochastic gradient descent plus weight decay learn very different types of representations. So in one case, we know that the rank of the weight matrices is not necessarily full rank, right? The matrices are not full rank. So the dimensionality of the representations might be low.
And in the other case, we don't have this bias. So the learned representations are not necessarily biased to be low-dimensional.
So that is, I think, an important thing to note, because in the literature, we see that people sometimes analyze neural networks through the lens of gradient flow. And gradient flow is much closer to full-batch gradient descent than to stochastic gradient descent. We see that the batch size and the weight decay have a lot of influence on what we actually learn. So we have to be very careful when we analyze the representations and what we learn.
Another thing is that the ranks are controlled by the hyperparameters. We can actually tune the hyperparameters to ensure that the neural network is going to be more or less compressible. And empirically, we see that this effect seems to disappear when we don't use regularization. So it depends on regularization.
And that also relates somehow to scaling laws for neural networks. We know that these kinds of hyperparameters control the scale of the complexity of the network, so we can tune them to ensure that the complexity is bounded in a certain way.
So this kind of recaps. As we can see, this aligns with the theory that we developed. And maybe a different visualization is looking at how the ranks are evolving during training.
So we start at a certain point. At initialization, the matrices are full-rank. We train with different values of the weight decay. So we vary lambda. And we can see that when we train with zero weight decay, the rank does not change.
And the higher the weight decay, then the smaller the rank is going to be. And that is also reflected with the batch size. The smaller the batch size, the smaller the rank is going to be.
AUDIENCE: So because all of them seem to converge early during training, what's the trade-off between the rank and the performance? Let's say you stop at 450. Do you have a prediction that sometimes stopping early is kind of a low rank?
TOMER GALANTI: So this is a very good question. The theorem I showed is essentially oblivious to the performance. It just tells you how rank behaves as a function of hyperparameters.
And it's a very good question that I don't have a good answer to. Empirically, what we see here, for example, is the following. I fix everything, say the weight decay and the learning rate, and I vary only the batch size.
But I ensure that all of the models perform the same on the training set; let's say they all fit the training data perfectly. That is what we see here.
All of the models reach 100% training accuracy, but they use different batch sizes, and the one with the smallest batch size achieves the highest test performance.
So the rank is actually somewhat correlated with the test performance, but you have to ensure that all of the models fit the training data perfectly. If that weren't the case, say the smallest batch size here were 100, and I then trained with batch size 1 and, because of that, only reached 85% training accuracy, then you cannot expect the test performance to be very good, right?
So it's a balance. But the interesting observation is that rank is important for understanding performance, and it might be helpful to monitor it with respect to the hyperparameters, in order to tune them correctly and get better performance.
So these are more experiments of the same kind, not so insightful at this point. Now I also want to show an experiment on the limit assumption I made in the theorem.
Basically, we plot the ratio between the norm of the weights at iteration t and at iteration t plus 1, and we want to see whether it converges to 1, as we assume in the statement of the theorem. We look at each layer of the neural network separately and plot this ratio across epochs.
In fact, we didn't even use consecutive iterations. We took consecutive epochs, which have many iterations in between them, so there is room for even more variation.
And we can see that it converges to something very close to 1. And we tried doing that across many architectures and settings, and it seems to happen very generically.
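The convergence check described here is easy to reproduce from saved weights. A minimal sketch, assuming one snapshot per layer per epoch stored as a nested list of floats (a stand-in for real tensors) and using the Frobenius norm; `norm_ratios` and `ratios_converged` are hypothetical helper names, and the tolerance is an arbitrary choice.

```python
import math

def frobenius_norm(matrix):
    """Frobenius norm of a matrix given as a list of rows."""
    return math.sqrt(sum(x * x for row in matrix for x in row))

def norm_ratios(snapshots):
    """Ratios ||W_{e+1}|| / ||W_e|| between consecutive snapshots of one layer."""
    norms = [frobenius_norm(W) for W in snapshots]
    return [b / a for a, b in zip(norms, norms[1:])]

def ratios_converged(ratios, tol=1e-3, last=3):
    """True if the final `last` ratios are all within `tol` of 1."""
    tail = ratios[-last:]
    return len(tail) == last and all(abs(r - 1.0) <= tol for r in tail)

# Toy snapshots of a 2x2 layer whose norm shrinks and then stabilizes.
scales = [4.0, 2.0, 1.5, 1.5, 1.5, 1.5]
snapshots = [[[s, 0.0], [0.0, s]] for s in scales]
ratios = norm_ratios(snapshots)
print(ratios)
print(ratios_converged(ratios))
```

In a real run, one would call `norm_ratios` separately for each layer, matching the per-layer curves described here.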
So this is the plot of the test performance. As I mentioned, if we ensure that the models fit the training data, then picking the batch size, learning rate, or weight decay that gives the lowest rank empirically gives the best test performance.
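The selection heuristic stated here can be written as a filter-then-minimize rule. A sketch, assuming each finished run is summarized by its training accuracy and a scalar rank summary (here the mean numerical rank over layers, which is an assumption). The `Run` record and the example values are made up for illustration and are not results from the talk.

```python
from dataclasses import dataclass
from typing import Optional, Sequence

@dataclass
class Run:
    batch_size: int
    lr: float
    weight_decay: float
    train_acc: float   # fraction in [0, 1]
    mean_rank: float   # e.g. average numerical rank over layers

def pick_by_rank(runs: Sequence[Run], min_train_acc: float = 1.0) -> Optional[Run]:
    """Among runs that fit the training data, return the one with the lowest rank.

    Runs below min_train_acc are dropped first: comparing ranks only makes
    sense between models that all fit the training set.
    """
    fitted = [r for r in runs if r.train_acc >= min_train_acc]
    return min(fitted, key=lambda r: r.mean_rank) if fitted else None

# Hypothetical runs with everything fixed except the batch size.
runs = [
    Run(batch_size=1, lr=0.1, weight_decay=5e-4, train_acc=0.85, mean_rank=10.0),
    Run(batch_size=16, lr=0.1, weight_decay=5e-4, train_acc=1.0, mean_rank=40.0),
    Run(batch_size=128, lr=0.1, weight_decay=5e-4, train_acc=1.0, mean_rank=90.0),
]
best = pick_by_rank(runs)
print(best.batch_size)
```

The batch-size-1 run is excluded even though its rank is the lowest, mirroring the 85% training-accuracy caveat from the Q&A.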
So basically, we justified the empirical observation that there is a correlation between the batch size, learning rate, and weight decay on the one hand and performance on the other. At the same time, I would say that rank minimization cannot by itself explain generalization in deep learning.
The reason is that, empirically, this effect depends on having regularization, but we can still generalize quite well without L2 regularization. So I personally don't think that rank minimization is the reason why neural networks generalize well, even though that hypothesis has been raised in the community.
To conclude, we provided a theoretical framework for identifying inductive biases in neural networks, such as this low-rank bias. It would be interesting to see whether other properties can be proven with this kind of analysis.
We showed that the low-rank bias is, to some extent, a universal property of stochastic gradient descent with regularization. It is largely independent of the architecture and of the structure of the data.
And the ranks are essentially independent of the width: the bound we have does not depend on the width of the neural network. It's also related to implicit regularization, compression, and so on.
And I think there are some interesting open questions that follow from this. First, beyond other properties, other optimizers may give rise to different properties during training, or minimize rank at different rates. That would be interesting to explore.
Another question is how fast this low-rank bias emerges. When can we say that the rank has converged? Is it after 100 epochs or 500 epochs?
That could be crucial for, let's say, compressing the neural network at the right time and continuing training from that point onward. And that is also related to translating this kind of behavior into regularization and compression techniques.
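One simple way to act on a converged low rank, offered as an assumption rather than the method from this work, is truncated-SVD factorization: replace an m×n layer W with factors A (m×r) and B (r×n), which reduces the parameter count from mn to r(m+n). The sketch also includes a naive plateau test on a per-epoch rank history to decide when to compress. Both helper names and the patience value are hypothetical.

```python
import numpy as np

def rank_plateaued(rank_history, patience=5):
    """True if the recorded rank has not changed over the last `patience` epochs."""
    if len(rank_history) < patience:
        return False
    return len(set(rank_history[-patience:])) == 1

def factorize(W, r):
    """Best rank-r approximation of W (Eckart-Young), returned as factors A @ B."""
    U, s, Vt = np.linalg.svd(W, full_matrices=False)
    A = U[:, :r] * s[:r]   # first r left singular vectors scaled by singular values, (m, r)
    B = Vt[:r, :]          # first r right singular vectors, (r, n)
    return A, B

rng = np.random.default_rng(0)
W = rng.standard_normal((256, 4)) @ rng.standard_normal((4, 512))  # exact rank 4
if rank_plateaued([10, 6, 4, 4, 4, 4, 4], patience=5):
    A, B = factorize(W, r=4)
    print("params before:", W.size, "after:", A.size + B.size)  # 131072 vs 3072
    print("max reconstruction error:", np.abs(A @ B - W).max())
```

Training could then continue on A and B directly, though whether that preserves performance in practice is exactly the kind of open question raised here.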
So we saw that the test performance is somewhat influenced by this behavior. Can we take advantage of that to select hyperparameters in a smarter way? So, I'm done. Thank you.
[APPLAUSE]