Dense Associative Memory in Machine Learning
Date Posted:
February 5, 2024
Date Recorded:
August 12, 2023
Speaker(s):
Dmitry Krotov, MIT / IBM Research
Brains, Minds and Machines Summer Course 2023
AKSHAY: Now we'll hear from Dmitri Krotov. He's a physicist working on neural networks and machine learning, currently at the MIT-IBM Watson AI Lab. And before that, he worked with John Hopfield at the Institute for Advanced Study in Princeton. He has a lot of interesting work on associative memories and learning, so I think it'll be good to hear from him.
DMITRI KROTOV: Thanks so much for the invitation. It's honestly my first time in Woods Hole, and the nature is gorgeous and the scientific discussions are fantastic. So thanks a lot for the invitation. And the thing that I want to say is that I really enjoyed how Tommy Poggio put forward, in his introductory lecture, this call for a physics-inspired understanding of what is happening in machine learning.
And I'm a theoretical physicist myself, as Akshay just said. And as you will see, the models that I'm going to tell you about today, they are heavily inspired by many of the ideas from statistical physics, such as spin glasses, Ising models, et cetera, et cetera. But there is a little bit of a difference between these kinds of models and a-- more conventional feed-forward neural networks. And the difference is somewhat historic.
So typically when we talk about typically-used models in AI, like transformers, they first gained empirical success. And only after that they attracted attention of theoreticians, who studied them in-depth, tried to understand, et cetera, et cetera. When it comes to associative memories, and even more so to Hopfield Network, the situation is a little bit of the opposite.
There has been a lot of theoretical research that has been put into these ideas in the '80s and '90s. But, somehow, for a long time, there was this perception that these are nice toy models that we can write some interesting formulas about. We can maybe solve some simple tasks. But people thought that they are not really particularly useful for solving large-scale machine learning problems.
And in some sense, what is interesting about the time that we are living through today is that this situation is beginning to change. And we are seeing more and more examples of meaningful and successful use cases of these ideas for solving some interesting machine learning problems. So let me start with a definition. So what is a Hopfield Network of associative memory?
And the right way to think about it is as a recurrent neural network. So something like an LSTM is a good example. And what that means is that, at any given moment of time, there is a state vector, and that state vector evolves in time. The state vector can be continuous or binary; it doesn't really matter, these are details. The time can also be either discrete or continuous; that is also somewhat of a technical detail. Both options are possible.
But what is important is that there is this state vector that changes in time according to some highly nonlinear dynamics. And these nonlinear dynamics become extremely important in what I'm about to tell you. So what is cool and distinct about Hopfield Networks-- as opposed to LSTMs, or vanilla RNNs, or gated recurrent units-- is that Hopfield Networks have a single scalar quantity, which is called the energy function.
And whatever this nonlinear dynamics is doing, the energy function is only allowed to decrease in time. It is never allowed to go up. So essentially, you can think about the operation of this network as the state of this red ball that is located somewhere high up on this energy landscape. And I have projected this complicated multi-dimensional energy landscape onto a two-dimensional plane.
And the energy landscape has a bunch of local minima. The minima here are denoted by these letters psi 1, psi 2, psi 3, and psi 4. And in the language of Hopfield Networks, these vectors-- and each of them is a high-dimensional vector, by the way-- are called memories.
Now, the reason why they are called memories is because you can think about this whole temporal dynamics-- when you start from some high-energy state and you go down along the hill of the energy landscape-- as a process of recovering the memory. So imagine that I give you some hint, like I-- for example, I take ChatGPT. I put a little prompt, which is like the first few sentences from, I don't know, Crime and Punishment, or something like that.
And then I let ChatGPT run, and it recalls the whole few paragraphs after that. So in some sense, the initial state of this network is a little bit like a prompt. And then this whole dynamical output that the model produces is a little bit like a process of recalling what the model has been exposed to in the past.
And for memories, there are, again, many choices here. So you can treat memories as individual instances of events from the past. For example, you know, I see this image of this room. That's one option. But you can also think about them as more consolidated representations.
So I have seen all kinds of things in the past, all kinds of rooms. And the memory could be some kind of consolidated version of a general concept of a room. So there is a little bit of versatility here, in terms of how we use these models.
So what is important, however, is that there are many of those memories. And typically, one of the objectives in improving these models is to be able to store as many high-dimensional vectors in the energy landscape as possible. And essentially, by giving a sufficiently rich prompt, we expect this network to recover one of those memories. So-- and by the way, one more thing.
The reason why this idea is called associative memory is because you can think about the initial state where this red ball starts and the final point where this red ball ends up, in the limit of t going to infinity, as a kind of association between the initial state and the final state. So that's the rationale behind calling this idea associative memory.
So let me show you a couple of conventional use cases of this idea. So the first use case is completion of incomplete patterns. And this notion has just been discussed in the previous talk in a very nice and concise mathematical model. And what it means is the following, that if you start from some initial state, which, say, represents 3/4 of the pattern, and then if I let this red ball roll down the hill along the energy, the dynamics essentially recovers the missing parts, right?
That's the pattern completion task. So that's one conventional use case of associative memory models. The second equally conventional use case is denoising. And in the denoising setting, we again start with some prompt, but we add a huge amount of noise to it. And, again, we run temporal dynamics. And you can see that associative memory manages to recover the correct pattern from the noise.
You always set where you initially position the point-- that is the initial state. After you release the point, it only goes down. So here, for example, you're asking what happened between this example and that one. Yeah, it recovers the whole image. And then you have to reset it.
If you want to recover the next image, then you have to reset it. Again, you give it a new prompt. You reinitialize it, maybe, if you need to, but usually it's not necessary. And then it goes down again. And the problem of local minima-- I'm going to discuss that a lot. In fact, local minima are the thing that prevented Hopfield Networks from being successful in the past.
But now I believe we have a solution to the local minima problem. So in the modern reincarnation of Hopfield Networks, there are essentially no spurious local minima-- there are only good minima. But let me get to that. Is there any randomness?
AUDIENCE: [INAUDIBLE]
DMITRI KROTOV: Right. So you have the right to inject randomness at pretty much any step. You can start with some prompt, and you can add noise to that prompt. You can also add Langevin noise at every step of the dynamics because, eventually, the system is going to be described by some differential equation, and you are fully entitled to add random noise to it.
And maybe it's not even Gaussian noise. It's completely within your hands. And of course, if you add a random noise, then, with small probability, you can actually go up, right? So again, at this level, I understand these models in a pretty broad sense. So generally, it goes down. But occasionally, if you introduce enough noise, it can jump up a little bit. OK.
And so here, to conclude with this slide, hopefully the second example immediately reminds you of diffusion models again, right? Because in diffusion models, you solve a little bit of a similar task: you start with a noisy image, and you then progress to the end. So what I'm going to try to do-- I don't know. I guess I cannot do this.
But anyway-- what I usually try to emphasize in this slide, also, is that if you pay careful attention to the second simulation on the right, you might be able to see the same bird that you can see on the left somewhere in the middle of the dynamics. Yeah. And that's kind of, indeed, how this network works, because it's not just that the network selects the best match.
It's actually a quite sophisticated dynamical trajectory that goes around the whole space and explores multiple memories. So it's a highly collective property of many memories and many neurons. And that's what, essentially, associative memories are. So now, let's make this idea little bit more mathematically defined. And this idea has been famously formalized by John Hopfield in the '80s.
And what he proposed is the following-- first of all, let's talk about the state space. What he proposed is that you can treat the activities of the neurons as binary vectors, and that's pretty much in accord with how a similar kind of model was introduced in the previous talk. So essentially, we have the vector sigma i, where the index i enumerates the neurons, and in total I have capital N neurons.
And essentially, every neuron-- every value of sigma i is equal to either minus 1 or plus 1. Minus 1, in neuroscience terms, corresponds to a silent neuron, and plus 1 corresponds to the neuron that fires action potential. So it's completely like discrete system so far.
So what we can do is the following. We can take memories-- remember, these psis are memories. The upper index mu, which runs from 1 to K, indicates which pattern we write into the memory. So essentially, mu equals 1 is the bird, mu equals 2 is the boat, et cetera. The index i here is the index that runs from 1 to N. It's this guy.
And essentially, this index i indicates the projection of the memory vector into the internal space in which this model operates. And essentially, what Hopfield said is the following. Let's construct this symmetric matrix. So essentially, I'm going to take each of those vectors psi, and I'm going to compute an outer product of those vectors and sum over mu, right?
And if we take that matrix and we plug it into the energy function, which is defined this way-- so it's a quadratic form in terms of sigma-- then, essentially, by flipping those spins between minus 1 and plus 1 at random, in such a way that every spin flip is only accepted if the energy goes down, you can recover a local minimum. And that's what the standard Hopfield Network of associative memory is.
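To make this concrete, here is a minimal sketch of the construction just described-- the Hebbian matrix T_ij = sum_mu psi_i^mu psi_j^mu, the quadratic energy, and greedy spin flips (sequential sweeps rather than random order, for brevity)-- written in JAX. The sizes, the random choice of patterns, and the variable names are illustrative assumptions, not anything from the talk.

```python
# Minimal sketch of the classical (quadratic) Hopfield network described above.
# Sizes and random binary patterns are illustrative assumptions.
import jax
import jax.numpy as jnp

N, K = 200, 20                                    # neurons and stored patterns (K well below ~0.14 N)
psi = jnp.sign(jax.random.normal(jax.random.PRNGKey(0), (K, N)))   # memories psi^mu in {-1, +1}^N

T = psi.T @ psi                                   # Hebbian matrix T_ij = sum_mu psi_i^mu psi_j^mu

def energy(sigma):
    """Quadratic Hopfield energy E(sigma) = -1/2 sigma^T T sigma."""
    return -0.5 * sigma @ T @ sigma

def retrieve(sigma, n_sweeps=3):
    """Sweep over the spins, accepting a flip only if it lowers the energy."""
    for _ in range(n_sweeps):
        for i in range(N):
            flipped = sigma.at[i].multiply(-1.0)
            if energy(flipped) < energy(sigma):
                sigma = flipped
    return sigma

# Corrupt 10% of the spins of memory 0 and let the dynamics clean the state up.
flip = jnp.where(jax.random.uniform(jax.random.PRNGKey(1), (N,)) < 0.1, -1.0, 1.0)
recovered = retrieve(psi[0] * flip)
print(jnp.mean(recovered == psi[0]))              # should be close to 1.0 while K stays small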
So it turns out that it works as I have described, more or less. But only if the number of patterns-- the number K that we are trying to put into the network-- is sufficiently small. So specifically-- and this has been shown by numerous people, including Haim Sompolinsky, who is in this-- I don't know if he's right here. But he's somewhere around in a parallel course.
It has been shown that this only works if the number of patterns that we encode this way is smaller than, roughly speaking, the number of neurons-- a linear bound. And there is a precise coefficient, which people have derived using multiple techniques, of about 0.14. So essentially, there is no way that you can store more than a linear number of memories if you are given a certain budget of neurons.
And if you think about it, it's a little bit problematic from the perspective of practical machine learning applications for the following reason. Because imagine, for example, that you are trying to store, say, Japanese kanji characters, and you have somehow digitized them on a small grid of 28-by-28 pixels, right-- like black and white images.
So what this model tells you is that, if you treat every pixel as a neuron, then we have 784 neurons, right? And what this bound tells you is that you can store at most, say, on the order of a hundred patterns. And that's clearly absurd. Because if you talk to, say, a Japanese person, in order to meaningfully read a newspaper, they need to know a couple of thousand of those characters. And they can meaningfully distinguish them.
So there is definitely a big discrepancy between this old Hopfield Network and the psychological expectations of what a good model should be doing. And a couple of years ago, we started to think about how to improve this bound. But in contrast to what other people have tried before, we are not going after this coefficient-- we are going after the scaling relationship. So we want to really dramatically scale up the memory storage capacity of this system.
Yeah. So this bound specifically allows you a small number of errors-- a fixed percentage of errors in the final recovery. If you want perfect recovery, you need to divide that by log N. So it's going to be N divided by log N, and there will be a slightly different coefficient in front-- not 0.14. Yeah. Yeah, right. Correct.
There are lots of neurons. But if you take this model in a naive way-- so each pixel correspond to a neuron, right? You would like to be able to store more patterns than you have neurons, right? So it would be a nice property to have. Of course, I'm not saying that this is the model of a Japanese person's brain, right? That's not what I'm saying.
But what I'm saying is that there are many machine learning applications and many examples in biology where a simple network should be able to store more patterns than the number of neurons. So it's quite a nice property to have. And there is a mathematical question: can we do it or not? And what I'm about to tell you is that, in fact, there is an easy way to do that.
Because, of course-- look at this formula, right? You give me any number K of vectors psi, right? I can write the top formula. There is no problem, right? The question is, can I extract them? So if I start with a random state in the space of sigmas, and I start flipping spins one at a time, always trying to decrease the energy, the question is: eventually, will I end up in one of the psis or not?
And the answer is that if K is small, then you will. And if K is bigger than this bound, then you will end up-- I'm not afraid to say this word to you-- in a spin glass state. [CHUCKLING] Yes. So, indeed, you will end up in some kind of state which looks very irregular and has nothing to do with the patterns that you are trying to store. The overlap with the pattern that you're trying to store will be of order square root of N, while, for perfect recovery, you need order N.
Right. So moving on. So how can we improve on this result? And the answer-- well, it's one of the answers; there could probably be other answers-- is that we need to introduce more nonlinear interactions here. We want to extend these energy functions to non-quadratic forms.
So essentially, we want to introduce cubic interactions between the sigmas, quartic interactions, et cetera, et cetera. Yes. So let me get to F first, but you're right: F could be different for every mu, although for normal applications you would not want to do that. But let me get to F, OK?
So let's talk about F. So essentially-- let's actually go back for a second to this model first. Let's take this expression for T_ij, plug it into this formula, and rearrange the sums over i, j, and mu. So if you do that, then you can hopefully see that you can easily rewrite this energy in this form. So it's essentially a dot product between sigma and each individual memory.
You sum over the neurons, and then you square that overlap. And you sum those overlaps over all the memories. So that's-- these three formulas are equivalent, right? So now, let's imagine that instead of putting a quadratic function here, we put something more aggressively growing, like exponential function or power function. And let's start with the power function.
Let's put in a function, F of x, that scales like x to the power n. And the little parameter n-- let's say that it's an integer, and that this integer is greater than or equal to 2. So if little n is equal to 2, then we recover the old Hopfield Network. However, if n is bigger than 2-- for example, if n is equal to 3-- then we get a cubic interaction in the energy function. And that's, of course, an entirely different world compared to the quadratic system.
So it turns out that if you take this idea and you repeat the calculation that leads to this result, essentially you can prove that the maximum number of memories now scales like a power law in the number of neurons, and the power is equal to this parameter n minus 1. So if little n is equal to 2, then you recover the old linear scaling relationship. But if n is bigger than 2, then the memory storage capacity grows much more rapidly as a function of the number of neurons. And that's what we want.
So moreover, you can go even further. You can go to an exponential function. And for an exponential function, you can even get an exponential scaling. So essentially, you can get a lot of memories and pack them into the same amount of configuration space, which is 2 to the power N. And as you can see, all the operations here are pretty straightforward.
It's the same thing. You take the dot product between the two, and then you just pass it through some nonlinearity other than the quadratic one. So, to the extent that-- it takes a few more flops to compute a cubic function than a quadratic one, of course, if you go down to the hardware level. But it's not a huge amount.
Because the main problem with the initial system was that it got stuck in local minima. And what is cool about the system on the right is that it doesn't. So the convergence speed is actually faster.
AUDIENCE: [INAUDIBLE]
DMITRI KROTOV: Yes. Once you are in the local-- inside the basin of attraction, you will go down. So it's hard to explore, but let's get to that; we'll get around to that point. In fact, there are many examples where you can construct, say, the psis in a certain way, and then you can show that you can store many more memories than linear, even in the quadratic case. But this, what I'm telling you about right now, is valid for arbitrary psis, right?
So-- with a caveat. This precise scaling relationship is derived for random psis. But even if the psis are extremely heavily correlated-- take, I don't know, binarized MNIST, or something like that-- you would still have superlinear behavior. It would probably not be as rapid as predicted by this formula, because the formula is derived for random patterns, but it would still be noticeably superlinear.
So it depends on what you call a strong and a weak memory. Because notice that, since the psis are made of minus 1s and plus 1s, their L2 lengths are all equal, right? So one is not longer than another in absolute value. But I guess you might mean that one of the memories is closer to the initial state than everything else, right? Yes.
Yes, and that's exactly the thing that we are using. We are trying to penalize the contribution of memories that are far away and select the memory that is the closest one. It depends on the task that you're trying to solve. It's possible to design a machine learning task where you would not want to go to extremely crazily-growing activation functions.
But what I'm saying is that there is an interesting regime between the quadratic function and the most rapidly-growing function that you can imagine, and that regime is sufficiently broad. Within that regime, you can find a function that is likely to suit the needs of the specific application you have in mind.
Yeah. So in some sense, think about this function as a hyperparameter, if you wish. OK, fine. So let's move on. So sort of the-- and I guess the last thing that I want to say here, maybe you have heard about words like dense associative memories. Sometimes people call them modern Hopfield Networks.
So this is all the same thing. This is the idea that is described on the right, and it has multiple names. And, by the way, the rationale for why we call it dense associative memory is the following. The network operates in the same amount of configuration space: for N neurons, there are 2 to the power N possible configurations.
But in the model on the left, the memories are sparse in that space. And in the model on the right, you pack a lot more memories in the same amount of configuration space. So they are much more dense. That's the rationale for the word dense. OK. Let me move on.
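To see the "dense" part concretely, the sketch below swaps the quadratic energy for E(sigma) = -sum_mu F(psi^mu . sigma) with a steeper separation function F, but keeps exactly the same greedy spin-flip retrieval as before. The sizes, the choice F(x) = x^4, and the corruption level are assumptions for illustration.

```python
# Sketch of a binary dense associative memory: same memories and dynamics, steeper F.
# Sizes and the particular choice of F are illustrative assumptions.
import jax
import jax.numpy as jnp

N, K = 200, 1000                                  # K = 5 N, far beyond the ~0.14 N quadratic bound
psi = jnp.sign(jax.random.normal(jax.random.PRNGKey(0), (K, N)))

def F(x, n=4):                                    # n = 2 recovers the classical network;
    return x ** n                                 # larger n (or jnp.exp) packs in many more memories

def energy(sigma):
    """E(sigma) = -sum_mu F(psi^mu . sigma): overlap with every memory, passed through F, summed."""
    return -jnp.sum(F(psi @ sigma))

def retrieve(sigma, n_sweeps=2):
    for _ in range(n_sweeps):
        for i in range(N):
            flipped = sigma.at[i].multiply(-1.0)
            if energy(flipped) < energy(sigma):   # same rule as before: flip only if E goes down
                sigma = flipped
    return sigma

# Corrupt 15% of the spins of memory 0; retrieval should typically still succeed despite K >> N.
flip = jnp.where(jax.random.uniform(jax.random.PRNGKey(1), (N,)) < 0.15, -1.0, 1.0)
print(jnp.mean(retrieve(psi[0] * flip) == psi[0]))   # overlap with the stored pattern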
So of course, when we do deep learning, it would be a terrible thing to do deep learning with binary patterns. We like everything to be fully differentiable; we like to do backpropagation. So somehow, things need to be more continuous. And it turns out that it is possible to take this general rationale about making certain activation functions steeper, and to rewrite the system as a set of coupled nonlinear differential equations.
So essentially, here the v_i's are the internal states of the neurons, and the f_mu's are the activations. So for people who do computational neuroscience, you can think about v_i, for example, as the current that runs through the neuron, and about f_mu as the firing rate of that neuron. So there is some transfer function, and that transfer function can be highly nonlinear.
But essentially, without going too much into detail, let me state the following: by and large, the phenomenon of increasing the memory storage capacity of this network can be translated from binary variables to continuous variables. There is a lot of work that goes into that statement, and there are many papers about it. But essentially, roughly speaking, that statement holds.
So essentially, whatever improvements in capacity you can imagine in the binary world, you can improve-- you can imagine, roughly speaking, equal improvement in the capacity in the continuous variables. So now, let me fast forward a few years. And what I want to say, that, of course, everyone heard about transformers, right? Like, transformers are everywhere-- ChatGPT, et cetera. Very important model.
So around 2020, a group led by Sepp Hochreiter, who is a co-inventor of LSTM, noticed that if you take the dense associative memory model and you replace this activation function, which is like a power function, with a softmax, then, essentially, dense associative memory turns into the self-attention operation in transformers.
And so essentially, it's almost a tautological statement. Because a softmax is an example of a very rapidly-growing activation function that is very sharply peaked around the memories. And the memories, in this correspondence, would correspond to the keys in the transformer. And the vector sigma i, which is the state vector, would correspond to the queries.
And the value matrix would also correspond to memories, as well, in some-- a little bit convoluted way. But essentially, the main theory that people have developed for dense associative memories can be to some extent translated to the world of self-attention. And there is a couple of reasons why this is an interesting statement.
The reason number one is that you can think about pre-trained models. So you can take, I don't know, BERT or some other, say, vision transformer. And you can map it back onto the language of Hopfield Networks. And you can ask, where are memories there? Where are basins of attraction? How can we think about this, right?
So it sort of presents a certain theoretical mindset for thinking about transformers more in the language of theoretical concepts that existed in a different field. So that's one reason why this is interesting. So the second reason why this is interesting, that it sort of opens up a little bit of an area of search for new architectures.
So you can think the following way. What if, indeed, transformers are just sophisticated models of associative memory? Then can we build better transformers? Can we somehow improve self-attention? Can we pack more memories in them, et cetera, et cetera?
And that is-- the second question is something that I'm very interested in. And essentially, on the next few slides, I'm going to tell you about our efforts on that front-- how can we build better transformers from the theoretical perspective of dense associative memories? So you see this function F. So this function F needs to be the integral of the softmax. So it needs to be log of the sum of the exponents.
So if you take this function to be log of the sum of the exponents, and then you differentiate that function with respect to sigma-- and close your eyes, and assume that sigma is a continuous variable for a second-- then you will get exactly the self-attention operation. Indeed, for correlated patterns, the memory retrieval is always worse than for uncorrelated. Because the patterns confuse the network.
If they're too correlated, then it's hard to distinguish between them. And that's bad for retrieval. However, although the scaling might be different from the theoretically-derived one for the random patterns, for correlated patterns, it is still approximately the same, right? Like maybe the power is not quite n, but a little bit lower.
But what is important is that it's still noticeably superlinear. Because essentially, the function F sort of picks the best match. So think about it like this. Let's say that you have one of the psis, and I give you a sigma. So essentially, when you compute the overlaps with all the memories, there will be one which is the largest.
And then, if I pick this activation function to be very rapidly growing, then, essentially, that largest one will win, and everyone else will be suppressed. And that's essentially the idea here. So we want to do something of that sort. And in some sense, self-attention does exactly that. Yeah.
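Concretely-- and this is only a toy illustration of the correspondence, not the exact construction in the Hochreiter group's paper-- if the separation function is taken to be F(x) = (1/beta) log sum exp(beta x), then minus the gradient of the energy with respect to a continuous state is a softmax-weighted sum of the memories, which has exactly the shape of a self-attention read-out with the memories playing the role of keys and values.

```python
# With F = log-sum-exp, one energy-descent step on the state is a softmax "attention" read-out.
# beta, the sizes, and the names are illustrative assumptions.
import jax
import jax.numpy as jnp

D, K, beta = 64, 16, 8.0
psi = jax.random.normal(jax.random.PRNGKey(0), (K, D))      # memories ~ keys (and values)

def energy(sigma):
    # E(sigma) = -(1/beta) log sum_mu exp(beta * psi^mu . sigma)
    return -jax.nn.logsumexp(beta * psi @ sigma) / beta

def update(sigma):
    # -dE/dsigma = sum_mu softmax(beta * psi sigma)_mu psi^mu : an attention-style read-out
    return jax.nn.softmax(beta * psi @ sigma) @ psi

sigma = psi[3] + 0.3 * jax.random.normal(jax.random.PRNGKey(1), (D,))   # noisy "query"
for _ in range(3):
    sigma = update(sigma)
print(jnp.argmax(psi @ sigma))                               # should point back to memory 3
print(jnp.allclose(update(sigma), -jax.grad(energy)(sigma))) # the update really is -grad(E)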
OK. So let me move on. So let's go to transformers. I'm going to introduce a new architecture that we called energy transformer. And as you will see, this energy transformer is fundamentally designed around the ideas that I have just described. So essentially, I'm going to model the output as a gradient descent dynamics.
So I'm going to define the energy function, and I'm going to update my tokens. So x will now be tokens, like in a conventional language transformer or vision transformer, that are updated in time. So what is not conventional about this transformer is that it is a recurrent neural network. So essentially, instead of having multiple transformer layers stacked on top of each other, I'm going to take the output after each update, and I'm going to send it back as the input to the same block, right?
And essentially, I'm going to try to design the energy function in the following way. So first of all, my energy function should decrease as the dynamics of this token update progresses in time. That's one desideratum. The second desideratum is that the update of these x's should look more or less like a transformer. So it needs to have self-attention. It needs to have an MLP. It needs to have skip connections. It needs to have layer normalizations-- all the stuff that we like in transformers.
So the question is, can I design such an energy function or not? And it turns out that I can. And, by the way, for the use case of this model, I'm going to focus on images. So I'm going to use the typical VIT settings. So essentially, you give me an image. I split the image into patches in a completely like obvious way, and then I tokenize every patch by passing it through some simple linear transformation.
And then I add to those tokens positional embeddings. Then some of the tokens are replaced with the mask tokens, and that would be the initial state to my energy transformer. Then I follow the energy descent dynamics. And this is the energy, and on the x-axis is time. And you can see that, as dynamics progresses, the energy transformer is expected to recover the missing patches. So that's the desired behavior that we want to have.
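For reference, the front end just described looks roughly like this (patch size, embedding width, masking ratio, and the randomly initialized matrices are all assumptions for illustration).

```python
# Sketch of the ViT-style front end: patchify, linearly embed, add positional embeddings,
# and replace a subset of tokens with a mask token. All sizes and parameters are assumptions.
import jax
import jax.numpy as jnp

H = W = 224; P = 16; D = 256                       # image size, patch size, token dimension
n_tokens = (H // P) * (W // P)                     # 14 x 14 = 196 patches

k1, k2, k3, k4, k5 = jax.random.split(jax.random.PRNGKey(0), 5)
W_embed   = 0.02 * jax.random.normal(k1, (P * P * 3, D))   # linear patch embedding
pos_embed = 0.02 * jax.random.normal(k2, (n_tokens, D))    # positional embeddings
mask_tok  = 0.02 * jax.random.normal(k3, (D,))             # learned mask token

def tokenize(image, mask_ratio=0.5):
    patches = image.reshape(H // P, P, W // P, P, 3).transpose(0, 2, 1, 3, 4)
    patches = patches.reshape(n_tokens, -1)                  # (196, 16*16*3)
    x = patches @ W_embed + pos_embed                        # tokens with position information
    masked = jax.random.uniform(k4, (n_tokens,)) < mask_ratio
    x = jnp.where(masked[:, None], mask_tok, x)              # occlude a subset of the tokens
    return x, masked                                         # initial state for the energy descent

x0, masked = tokenize(jax.random.uniform(k5, (H, W, 3)))
print(x0.shape, int(masked.sum()))                           # (196, 256) and roughly 98 masked tokens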
Again, I should probably make this stronger: this model not only reminds us of transformers, but also of diffusion. You're 100% correct, yes. Yes, everything here is continuous. So, essentially, you can write this update as dx/dt equals minus the gradient of E with respect to-- not quite x, but another variable that I'll explain in a moment. It's precisely a masked autoencoder.
But the only difference is that masked autoencoders use a conventional transformer block, and my block will be slightly different, because I want the energy. And actually, let me emphasize this very strongly. Because whenever I show the energy to a machine learning audience, the biological associative memory of the listener immediately associates it with the loss.
And energy and loss have nothing to do with each other. Loss is defined in the space of weights. We optimize the loss with respect to the weights. We optimize energy with respect to the states of the neural network, right? So everything that I have described so far is inference time.
And there is a huge question: how do we train these systems? Now I'm answering Boris's question-- how do we train this thing? And essentially, we train in the following way. We start with some image from the ImageNet data set, right? We run these dynamics. We get something.
Then we take that something, we compare it with the ground truth. We compute the L2 loss, and then we back propagate that L2 loss with respect to memories now, right? So the forward pass in this network is described by the energy descent dynamics. And the training pass is described by the loss descent dynamics, right?
So we now have two functions. One is being optimized in the inference time, and the other one is being optimized in the training time. Then the number of parameters is exactly the same. And to understand that-- to understand that, let's go back to this slide. Sorry.
So, essentially, think about this formula versus this formula. So my parameters are the psis. Sometimes people write Hopfield Networks in terms of T_ij, and then it naively looks like there are more parameters or fewer parameters, depending on how you count. But if you write it in terms of memories, then the number of parameters is exactly the same. And in some sense, the Ts are not fundamental-- what is fundamental are the psis.
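The distinction between the two descents can be written in a few lines. The sketch below uses a toy dense-associative-memory-style energy (not the actual Energy Transformer energy): the forward pass unrolls gradient descent on the energy with respect to the state, and training backpropagates an L2 reconstruction loss into the memories, which are the only parameters. The step sizes, the dimensions, and the added quadratic term that keeps the state bounded are assumptions.

```python
# Two nested optimizations: energy descent over the STATE at inference time,
# loss descent over the MEMORIES (the parameters) at training time. Toy energy, assumed sizes.
import jax
import jax.numpy as jnp

D, K, beta, step, T = 32, 8, 4.0, 0.5, 12

def energy(x, memories):
    # A quadratic term plus a log-sum-exp term; the quadratic term keeps the state bounded.
    return 0.5 * x @ x - jax.nn.logsumexp(beta * memories @ x) / beta

def forward(x0, memories):
    """Inference: unrolled gradient descent on the ENERGY with respect to the state x."""
    x = x0
    for _ in range(T):
        x = x - step * jax.grad(energy)(x, memories)
    return x

def loss(memories, x0, target):
    """Training objective: L2 between the relaxed state and the ground truth."""
    return jnp.sum((forward(x0, memories) - target) ** 2)

memories = 0.1 * jax.random.normal(jax.random.PRNGKey(0), (K, D))
target = jax.random.normal(jax.random.PRNGKey(1), (D,))
x0 = target.at[D // 2:].set(0.0)                   # "occluded" prompt: half the entries are missing

grads = jax.grad(loss)(memories, x0, target)       # backprop THROUGH the unrolled energy descent
memories = memories - 0.1 * grads                  # one step of loss descent on the parameters
print(loss(memories, x0, target))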
OK. So let me move on here. So now I'm going to explain to you how I constructed energy transformer block. And to explain that, I'm going to contrast it with the conventional transformer block that you can see on the left. So the conventional one, probably everyone knows, right?
There is a multi-head attention operation. There is a multilayer perceptron. Multilayer perceptron operates on token-wise, so it does not mix tokens. And the attention operation mixes the tokens. In addition to that, we have a couple of skip connections and a couple of layer normalization steps. So this is the conventional transformer block.
So now, we are going to move to the energy transformer block. And as you can see, it is somewhat similar, but meaningfully different. So, first of all, the analog of the multi-head attention is the multi-head energy attention. And, as you can see, the analog of the MLP is the Hopfield Network, right?
So one difference that you immediately can see here is that they operate in parallel as opposed to consecutively. And that is needed for the validity of this equation. So essentially, I want my dynamics of the tokens. And, by the way, here x is the token.
I want the tokens to evolve on the energy landscape in an energy-descending way. And it turns out that you need to arrange them in parallel in order to do that. But computationally, it's not really such a big deal, because you stack these layers on top of each other many times anyway, right? And there are skip connections. So it probably doesn't matter that much whether you put them in parallel or consecutively.
So another important thing is that, instead of feedforward MLP, we are going to have the Hopfield Network here. And the Hopfield Network here is the old Hopfield Network. It's not dense associative memory. The dense associative memory is going to live right here.
So what is the difference between old Hopfield Networks and an MLP in transformers? The only difference is that a Hopfield Network has two sets of weights that go from the tokens into the internal space, and then the other matrix that goes from the internal space to the token space. They need to be related by the transposition operation.
So if these two matrices are not arbitrary, but are related by the transposition operation, then you can show that you can write down an energy function for the MLP. So far, what I have explained is that you can easily convert an MLP into an energy-based Hopfield Network just by tying the weights inside that MLP a little bit. All right?
Now we're going to do the same thing with the multi-head attention operation. And, probably, it would be easiest if I just show you the equation of how this is done. So essentially, these are the two energies that we are trying to minimize. So the energy number one is energy attention, which corresponds to this block. And the energy number two is the energy that corresponds to this block.
And essentially, the whole token dynamics is doing gradient descent in the inference time on the sum of these two energies. So let's take a look at the energy attention operation. So index H here corresponds to different heads of attention. So let's ignore it for a second. Let's imagine that we have only one head.
Indices B, C correspond to tokens. So essentially, I tokenized my image in a standard way by splitting it into patches. And I enumerated different tokens. But what is most important here is that you can see this kind of general function log of the sum of the exponents that I alluded to earlier in response to Magnus's question. And essentially, if you take the derivative of that function, you will get a softmax, right? You can see that.
So underneath that energy function you, have two tensors-- K and Q, which are pretty much defined in the conventional transformer-like way. So you take the tokens x. You pass them through a layer normalization step, which is denoted here by g. So g is a layer normalized version of x. And then you multiply them by some matrices. And the matrices for keys and queries, WK and WQ, can be different in this model, right?
And that's it. So it turns out that with these very few modifications-- and hopefully you can see that it's almost looking like a proper transformer that we all love-- we can transform the transformer into the energy-based model. So suddenly, this new transformer can now do gradient descent on the energy in the inference time.
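Schematically, the two energies and the token update can be put together as below. This is a simplified single-head sketch written from the description above-- the layer norm is a plain normalization, the sizes are assumed, and the actual Energy Transformer differs in details (for instance, it descends with respect to the layer-normalized variable, as mentioned a moment ago).

```python
# Schematic energy transformer block: E = E_attention + E_Hopfield, tokens descend the energy.
# Single head, simplified layer norm, assumed sizes; not the exact published implementation.
import jax
import jax.numpy as jnp

D, Dk, M, n_tok = 64, 32, 128, 16
beta = 1.0 / jnp.sqrt(Dk * 1.0)
kq, kk, km, kx = jax.random.split(jax.random.PRNGKey(0), 4)
W_Q = 0.05 * jax.random.normal(kq, (D, Dk))
W_K = 0.05 * jax.random.normal(kk, (D, Dk))        # keys and queries may use different matrices
xi  = 0.05 * jax.random.normal(km, (M, D))         # memories of the Hopfield ("MLP") part, tied weights

def layer_norm(x):
    return (x - x.mean(-1, keepdims=True)) / (x.std(-1, keepdims=True) + 1e-5)

def energy(x):
    g = layer_norm(x)                              # g_A = LayerNorm(x_A), shape (n_tok, D)
    Q, Kmat = g @ W_Q, g @ W_K
    # Attention energy: -(1/beta) sum_C log sum_B exp(beta * K_B . Q_C)
    e_att = -jnp.sum(jax.nn.logsumexp(beta * Q @ Kmat.T, axis=1)) / beta
    # Hopfield energy with tied in/out weights and a rectified-linear separation function
    e_hn = -0.5 * jnp.sum(jax.nn.relu(g @ xi.T) ** 2)
    return e_att + e_hn

def block_step(x, dt=0.1):
    # The same recurrent block, applied over and over: tokens take a small energy-descent step.
    return x - dt * jax.grad(energy)(x)

x = jax.random.normal(kx, (n_tok, D))
for _ in range(12):
    x = block_step(x)
print(energy(x))                                   # decreases along the trajectory for small dt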
So r is the rectified linear unit. In the classical Hopfield Network it would be linear. You can probably do linear as well, but it's better to do rectified linear here. Yeah.
So Qs I should have put in the formula. So again, let's ignore index H, which corresponds to heads, right? So let's take tokens. Tokens are xiA. So A enumerates different patches. And i enumerates the vector dimension of the token, right?
You take that matrix, and you multiply it according to the vector index by a matrix of parameters-- WQ. And that's what gives you Q alpha C. Linear transformation. Correct. OK. So the question is-- Let's put all of these ingredients together. Let's train it on ImageNet. Does it work or not?
And let me show you what it does. So here, you can see several columns. So essentially, in the third column right here, you can see the ground truth images from the ImageNet. In the first column, you can see the initial states that we give to the energy transformer. And in the middle columns, we are going to see a movie. I'm going to play it for you in a second. And you will see how the initially occluded patches evolve in time, right? So let's take a look.
And the movie is going to loop. So once it gets to the end-- we essentially train it with the dynamics unfolded in time for 12 steps, and then we backprop through it, right? So you can only see 12 frames here, and then it goes back to the initial frame.
So hopefully you will see-- you would agree with me-- and, by the way, these are images from the validation set, right? So they have not been used in training. These are completely new images from the same distribution, but the model has never seen them before.
And you can see that, on the one hand, the auto-completions are somewhat meaningful, right? But you can see that sometimes the model makes interesting errors. Let's take a look, for example, at this plot.
Like here, the ground truth is the very irregular pattern of bricks. And please focus at this part of the image, where you can see that, because the initial boundaries were completely occluded, the model essentially extrapolates one gigantic blue brick in a very long stripe, right? And there are these kinds of mistakes. But at the same time, if you look at this middle image, it's a reasonable autocompletion of the pattern that you could imagine by masking some of the tokens here.
Another interesting example is in this one. So the ground truth here corresponds to two different papers. And you can see that there is an appreciable gap right here. But in the initial masked image, that gap is completely occluded. Yet you can see that right here, the model actually does reconstruct a gap.
So it doesn't just reconstruct a blob of green color-- it does have some kind of notion of proportions, a little bit. It would not work as a generative model, though. It would extrapolate things inside the blob in a more naive way than you would. So, indeed, it only works nicely if you don't occlude too much. The model does need some kind of proximal cues to autocomplete those images.
So if you just show a blue sky, it would not like draw a whale in between. Honestly, all these models operate in the embedding space. And I don't really think that there is anything special about bright pixels as opposed to extremely dark pixels. So yeah.
But indeed, you're right that it seems to extrapolate certain rules, like straight lines tend to continue through space, right? Like if there is a left portion of the face, there's got to be symmetric right portion of the face, et cetera, et cetera. So it learns these kinds of memories. And in some sense, it is these kinds of rules that are embedded in the attention operation-- the dense associative memories.
Because the memory here is how tokens are put together in the image plane. It's not what individual token looks like, right? OK. So now let me tell you a story that is related to what I was talking about before, but slightly different.
So we really are interested in scaling these sort of ideas up to more interesting applications, like maybe language, maybe bigger image data sets-- something like really ambitious. And when we started doing this, we realized one important aspect of our models that is not typically present in most of feedforward neural networks in deep learning. And that aspect is the following.
Think about how we used to train neural networks, say, 15 years ago. What we would do is we would define a feedforward neural network, then we would write some kind of loss function. Then we would take a piece of paper and a pencil and calculate derivatives manually, right-- and then put them back into MATLAB. And MATLAB on GPU could do like a little bit of speed up so we could train those neural networks.
Of course, today nobody does that because we have autograd, right? And autograd speeds things up dramatically. So essentially, you never really write the backward pass. So the cool thing about our networks is that even the forward pass, in principle, could be computed through autograd, right? Because if we know the energy, we don't even need to call the forward pass-- because it also can be done through autograd.
So we thought, wouldn't it be cool to somehow design a proper software engineering library and framework that does that? And that's how we came up with this idea of HAMUX, which stands for Hierarchical Associative Memory User Experience. So essentially, HAMUX does two kinds of autograds for you-- the autograd on the energy function in the forward pass and the autograd on the loss function in the training pass.
There is one extra twist that is present in HAMUX, and I didn't have time to explain it properly today because I would need a little bit more time. And that aspect is that it turns out that even the energy function is not the most fundamental object for this model. Rather, the most fundamental objects for these models is what we call a Lagrangian function.
So essentially, you can start the whole definition of these models with this special auxiliary function, and then the energy function is computed from it. And essentially, that's what HAMUX does. You start with a Lagrangian function for every layer, and then you press a button, and it automatically generates the forward pass for you.
It generates the energy function. It generates the backward pass-- everything that you need, without you ever taking a piece of paper and computing derivatives manually. So it's open source. If you are at all into software engineering, please take a look at the GitHub repo.
Yes. OK. And now-- so 15 minutes. Perfect. Then now, since I have another 15 minutes, maybe I could actually try to explain to you the Lagrangian. Because I really think that's a neat part of the story. And I sort of didn't want to start with it because I wanted to explain bigger ideas first. But now, since we have a little bit of time left, maybe I can explain to you some underlying math.
Yes. So let's take a look at this system of equations again. And you have seen them in one of the previous slides, but I didn't really explain them in any detail. So imagine a very simple Hopfield Network which consists of two layers of neurons, the blue guys and the green guys. Each of those neurons would be described by two sets of variables, the internal state of the neuron and the output of the neuron.
Again, for neuroscientists, think about internal state as a current inside the neuron, and the output as the axonal output-- like firing rate. And then you can clearly recognize firing rate models here, right? So essentially, the blue guys would have internal variables that are denoted by letters vi. And they would have the axonal outputs denoted by gi.
And for the green guys, the internal variables would be denoted by h mu, and the external outputs would be denoted by f mu. So now imagine that I'm going to design a very special pattern of connectivity here. So it's a little bit like a restricted Boltzmann machine.
So there are feedforward weights, and there are feedback weights. But there are no weights in between the green guys or the blue guys. And also, let me assume that the forward weights and the backward weights are equal to each other. Because that's what you need to do to get the energy function.
So now, you can write the system of equations. And hopefully you can see that this coupled nonlinear differential equations, they pretty much describe this picture. So if I take a blue neuron, vi, it receives inputs from the green neurons that are enumerated by the index mu, right? And the functions f mu and gi here are some nonlinear transfer functions.
So now, where do Lagrangians come into the game? So it turns out that the right way to think about the system is to define an auxiliary function, L, which is the integral of the activation function. And let's call it Lagrangian for a moment, and you will see why in a second.
So if we lift our understanding of these coupled nonlinear differential equations from the perspective of activation functions to the perspective of these auxiliary functions, L, It turns out that you can write down a global Lyapunov function for this system using Legendre transform of the Lagrangian function.
So essentially, the energy function turns out to be Legendre transform of these auxiliary functions, L. And for those of you who are familiar with the classical mechanics, you immediately understand why we have called them Lagrangians, right? Because in classical mechanics, if you have a system and you have a Lagrangian for that system in order to compute the Hamiltonian or energy, you need to do Legendre transform.
So let me state here that the analogy stops there. I'm not claiming that this is an Euler-Lagrange system, because these equations are clearly not derived from an action principle, and there is none of the other machinery of Lagrangian dynamics that we like. But this is the rationale for calling these auxiliary functions the Lagrangians.
So the cool part is that if we want to prove that this energy function is decreasing, essentially the only thing that we need to do-- and that's a simple exercise. So essentially, I take this energy function, and then I differentiate. I use a chain rule, and I differentiate pretty much every letter that you see in the right-hand side.
So if I do that and I use the equations of motion-- these nonlinear equations-- I can show that the time derivative of the energy is equal to a quadratic form defined on the velocities, the dv_i/dt's. And the matrix that appears here is nothing but the matrix of second derivatives of the Lagrangian function.
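Written out in the notation of this slide (with time constants added as a convention, and with the same matrix psi appearing, transposed, in both equations), the construction is:

```latex
% Two-layer continuous network: dynamics, Lagrangians, Legendre-transform energy, Lyapunov property.
\begin{aligned}
\tau_v \frac{dv_i}{dt} &= \sum_{\mu} \psi_{i\mu}\, f_\mu - v_i ,
\qquad
\tau_h \frac{dh_\mu}{dt} = \sum_{i} \psi_{\mu i}\, g_i - h_\mu ,
\qquad
g_i = \frac{\partial L_v}{\partial v_i} ,
\quad
f_\mu = \frac{\partial L_h}{\partial h_\mu} ,
\\[6pt]
E &= \Big[\sum_i v_i\, g_i - L_v\Big]
  + \Big[\sum_\mu h_\mu\, f_\mu - L_h\Big]
  - \sum_{\mu, i} f_\mu\, \psi_{\mu i}\, g_i ,
\\[6pt]
\frac{dE}{dt} &= -\,\tau_v \sum_{i,j} \frac{dv_i}{dt}\,
      \frac{\partial^2 L_v}{\partial v_i \partial v_j}\, \frac{dv_j}{dt}
  \;-\; \tau_h \sum_{\mu,\nu} \frac{dh_\mu}{dt}\,
      \frac{\partial^2 L_h}{\partial h_\mu \partial h_\nu}\, \frac{dh_\nu}{dt}
  \;\le\; 0 \quad \text{whenever } L_v,\, L_h \text{ are convex.}
\end{aligned}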
So in order to prove that the energy function is guaranteed to go down, the only thing that you need to check is that the Lagrangians are convex. So essentially, the right thing to think about this-- and I'm coming back to the previous slide on HAMUX-- that the total energy of this whole system comes in three pieces. Piece number one is the energy of the blue layer. Piece number two is the energy of the green layer. And the piece number three is the interaction energy between the layers.
And essentially, this is the energy of the blue layer, this is the energy of the green layer, and this is the energy of the interactions, right? So the total energy is not just the simple dynamics of the individual layers-- we also need to include some kind of coupling between the layers.
What I'm saying is that the energy of each individual layer is a Legendre transform-- and hopefully you would agree with that-- and then you need to add an interaction energy on top of that. Yeah, yeah. So that one does not come through a Legendre transform. Correct.
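And this pipeline is exactly what makes "autograd even for the forward pass" possible. The snippet below is a from-scratch illustration in JAX of the general recipe-- it is not the HAMUX API: choose convex Lagrangians for the two layers, obtain the activation functions as their gradients, assemble the Legendre-transform energy plus the interaction term, and integrate the dynamics. The particular Lagrangians, sizes, and step size are assumptions.

```python
# From-scratch sketch (not the HAMUX API): Lagrangians -> activations via autograd ->
# Legendre-transform energy -> energy-descending dynamics. The choices below are assumptions.
import jax
import jax.numpy as jnp

N, M = 32, 16
psi = 0.3 * jax.random.normal(jax.random.PRNGKey(0), (M, N))   # the only parameters: the weights/memories

def L_v(v):                      # convex Lagrangian of the feature layer -> g(v) = v (identity)
    return 0.5 * jnp.sum(v ** 2)

def L_h(h):                      # convex Lagrangian of the memory layer -> f(h) = softmax(h)
    return jax.nn.logsumexp(h)

g = jax.grad(L_v)                # activations are gradients of the Lagrangians
f = jax.grad(L_h)

def energy(v, h):
    # Legendre transforms of the two Lagrangians, plus the interaction energy between the layers.
    return (v @ g(v) - L_v(v)) + (h @ f(h) - L_h(h)) - f(h) @ psi @ g(v)

def step(v, h, dt=0.1):
    # Euler step of:  tau dv/dt = psi^T f(h) - v,   tau dh/dt = psi g(v) - h
    return v + dt * (psi.T @ f(h) - v), h + dt * (psi @ g(v) - h)

v = jax.random.normal(jax.random.PRNGKey(1), (N,))
h = jnp.zeros(M)
for _ in range(200):
    v, h = step(v, h)
print(energy(v, h))              # decreases monotonically along the trajectory for small enough dt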
Yeah. And essentially, the last thing that I want to say here is that you can generalize this simple two-layer neural network to arbitrarily complicated recurrent neural networks. So I'm going to give you a couple of examples here. This is the simple two-layer neural network, right? You can add more layers. You can add 10 layers, and it's going to be fully recurrent.
And you can use exactly the same theory with Lagrangians to derive the global energy functions for this multi-layer hierarchical Hopfield Networks. You can make it convolutional, if you want. You can say, I'm not going to use fully-connected layers. I'm going to use convolutions. And you can also write down an energy function using Lagrangians and Legendre transform for deconvolutional operation.
You can also include attention. So essentially, all the conventional inductive biases that we like so much in machine learning, they can be formulated in these terms. If you pick the Lagrangians properly-- so Lagrangians are convex and the energy function is bounded from below-- then, because of the convexity of Lagrangians, you are guaranteed that the energy function always decreases.
Because the energy is bounded from below, you cannot shoot through the floor, right? You have to stop somewhere. And when you stop, the network is guaranteed to arrive at a fixed point. And computationally, this is kind of a cool property of these systems, that they never really run away.
Like, for instance, if you've ever tried to train deep equilibrium models, for example, that sort of operate in the same paradigm-- continuous time differential equations, right? We run until the fixed point, and then we do something. Those models typically do not have those properties. They do not have energy functions. And if you play with them, that would be the main problem that you would encounter, that they tend to run away.
So they're not really guaranteed to go to fixed points, and they don't. These models are guaranteed to go to fixed points because of the properties of the Lagrangian functions that define them. Let me slightly make your question even stronger.
If I write down an arbitrary system of nonlinear coupled differential equations, I can have all kinds of crazy behaviors. I can have fixed points, limit cycles, strange attractors like chaotic behavior-- anything you want. But it is precisely because this is not an arbitrary system of equations-- it was carefully engineered so that it does gradient descent on a bounded energy function.
These are those properties that guarantee that no matter what psis you pick here-- and psis are picked by gradient descent, you have no control over them, right? The learning algorithm can change them in an arbitrary direction.
Despite that, no matter what psis are, it will always go to a fixed point. You see this in the following way. You derive a global Lyapunov function, which is given by this formula. And then you take the derivative of this Lyapunov function. Yes, you need to do a little bit of computation.
And I kind of-- if you're interested in doing this computation, please take a look at this paper. Literally, you need to take this E of T, you need to use the chain rule, and you need to use these nonlinear equations. Let me give you one example of how you can break that property.
Replace this matrix psi by a different matrix A, and the theorem falls apart. You do not require the psis to be symmetric, but you do require the same psi to appear in the top equation and in the bottom equation. Correct. They cannot be symmetric because they are not even square, right? They are rectangular.
Because the number of green neurons can be different from the number of blue neurons. That's why it has two different indices, i and mu. So, indeed, if we go back in history to the conventional use of Hopfield Networks, the typical framework is you give me the patterns. I encode them in the network, and then I extract exactly the same patterns.
So you can do it that way, but you can also work with consolidated patterns. You can take the same architecture, but now you can play a different game. You can say: instead of giving the patterns, I'm going to give a big data set, and I'm going to make the network extract the meaningful patterns from the data set using some kind of learning algorithm. And that's the setting that we are in.
So think about it like this. The memories would not be individual instances; they would be some kind of clusters of instances, right? And if you design the front end of this network in a good way, then the inductive biases of that front end would consolidate the clusters in the way you want them to be consolidated-- not based on which data point has more bright pixels, but rather based on whether it's a cat or a dog.
Yeah. So in some sense, 90% of what I've been talking about today is about consolidated patterns, not about individual instances. Yeah. So in some sense, memories here are a little bit like knowledge. They are not instances of data. That's like what the network has learned by looking at a lot of different kinds of instances of data.
OK. Fine. So I think I'm going to wrap up here. And let me just quickly jump to conclusions. So at this point, the problem of the amount of information that we can write down in the Hopfield Network is by and large solved. So for whatever practical application that you may have in mind, if you think that Hopfield Network is the right tool, the right data science tool for you, you can find one of the models from the Hopfield family that would quite likely satisfy your needs.
So, indeed, there are no really practical limits on the amount of information that you can memorize in these networks. So they can be continuous. They can be binary. That's, again, your choice. There are models of both sorts. They can have all kinds of inductive biases, including convolutions, attention.
There is a little bit of subtlety with pooling-- you can do average pooling, but not max pooling-- but this is a technical detail. And I really believe that these ideas of energy-based dynamics, formulated around memories and basins of attraction, are a very interesting, novel direction that we may want to explore.
And what we're trying to do, with some folks who are interested in these ideas, is to build some kind of meaningful AI applications of this class of models. So, yeah, thanks for your attention.
[APPLAUSE]