Attention Approximates Sparse Distributed Memory
Date Posted:
October 20, 2021
Date Recorded:
October 19, 2021
Speaker(s):
Trenton Bricken, Harvard University
Description:
Abstract: While Attention has come to be an important mechanism in deep learning, it emerged out of a heuristic process of trial and error, providing limited intuition for why it works so well. Here, we show that Transformer Attention closely approximates Sparse Distributed Memory (SDM), a biologically plausible associative memory model, under certain data conditions. We confirm that these conditions are satisfied in pre-trained GPT2 Transformer models. We discuss the implications of the Attention-SDM map and provide new computational and biological interpretations of Attention.
TRENTON BRICKEN: A lot of this work was pursued during COVID isolation alone in my bedroom, and it's really exciting to now be sharing it in front of a lot of people in person. So the title of this work is "Attention Approximates Sparse Distributed Memory." It was done by myself in collaboration with Cengiz Pehlevan, and I'm now advised by Dr. Kreiman. And let's get into it.
So first of all, why should you care? We show that the heuristic Attention operation can be implemented with simple properties of high dimensional vectors in a biologically plausible fashion. The Transformer and its Attention operation are incredibly powerful, but they were heuristically developed, and the softmax operation in Attention is particularly important but also heuristic. Now, the intersection of hyperspheres that's used in Sparse Distributed Memory closely approximates the softmax and Attention operations, both in theory and in the trained Transformer models that we investigate. So SDM pre-empted Attention by approximately 30 years, having been developed back in 1988, and it meets a high bar for biological plausibility. In particular, it maps compellingly onto the unique wiring of the cerebellum.
So as an overview of this presentation, I'm just going to give a summary of the Sparse Distributed Memory, and I'm going to give a summary of Transformer Attention. I'll then show the relationship and how SDM can also interpret the Transformer more broadly. And then if there's time, which there probably won't be, I want to give a review of SDM's biological plausibility in the cerebellum. So we'll see how we do. I'm going to prioritize visual intuition and then try and get into some of the math and please ask questions whenever.
OK. First an overview of Sparse Distributed Memory. So the fundamental question or problem it's trying to solve is how can the brain write and read memories in order to retrieve the correct one later. And so there are a few considerations around this. First, we want high memory capacity. We also want robustness to query noise when we're trying to retrieve a memory later. We also in this case want our system to be biologically plausible. And we also want fault tolerance in that we're resistant to cell death-- neuron death.
What makes SDM unique versus other memory models? And this comes from its name. So it's sparse in that it works in a high dimensional vector space and neurons exist in only a fraction of possible locations in that space. Secondly, it's distributed in that the read and write operations you're doing apply to all nearby neurons in that vector space.
So first, getting into the SDM write operation, we're storing patterns in nearby neurons. So in green here we have our first pattern. It's appearing in our high dimensional binary vector space. We'll move to continuous later, but for now we're in a binary space where we're using Hamming distance. And the green pattern is going to activate neurons in this write radius around it. So the neurons are these hollow black circles. And it writes itself in. And just as a side note, this pattern can either write itself in, in an auto-associative setting, or it can write in a different pattern-- point to a different pattern-- in a hetero-associative setting. For example, if I'm remembering the alphabet, then my first pattern here would be the letter A. It would write in a pointer to the letter B. And then if I query for B, it would point to C, et cetera.
I clicked too far. OK. So now we've written that green pattern into nearby neurons, and so you can see that the green pattern is inside the hollow black circles, and we're also keeping track of the original pattern location. And that will be important in a second. Yep.
AUDIENCE: What space are these neurons, what physical space in the brain?
TRENTON BRICKEN: So I don't want to get too much into biological plausibility right now. But you could think the dendrites that the neuron has correspond to a vector, which represents a particular location in this high dimensional space. So the neurons all have an address in this space. They exist in a particular location. And then a pattern as it's processed by sensory stimuli kind of up the layers of the brain will have some vector representation that will also define a location in this space.
AUDIENCE: I had the same question. Are they physically near each other?
TRENTON BRICKEN: No. The neural population code for the two vectors would be similar. Yeah. Thanks for asking.
AUDIENCE: If their addresses are similar that means the neurons that they're synapsing from are similar? They have a high input similarity?
TRENTON BRICKEN: The dendrites of a neuron-- so what it is sensitive to, what it's activated by, will be close in space. So the pattern corresponds to a population code or a particular vector, and neurons that have a high similarity to or looking for patterns like that will be close in space, and therefore they'll be activated.
AUDIENCE: Mathematically, it's a binary effect?
TRENTON BRICKEN: Yes. High dimensional binary vectors.
AUDIENCE: When you say that they store--
TRENTON BRICKEN: I'll get more into the math in a second. Yeah. For now, you can just think of these neurons as if they have a storage vector, and they're storing this pattern inside of it. So here we're writing in a second pattern. It has a different location in the space, and it's also activating nearby neurons. And then it's writing itself in. And now note that some of the neurons are storing both the green and orange patterns. And you can think of this as it's storing a set of patterns, but in reality each neuron has one storage vector. And we're doing a summation of the patterns that we're storing. And because we're in a high dimensional space you can think of that as a superposition of the two patterns being stored with minimal crosstalk between.
So finally we're going to write in a third pattern. And you can see that it's stored again. And something that's important for later is, note that although these patterns are ephemeral-- so they've disappeared-- their original locations can be triangulated based upon the nearby neurons that are storing that pattern.
So now we're getting into the read operation. And so in this case, I have my pink query xi. And it is also going to be activating nearby neurons. Those neurons are going to output the patterns that they've stored. All of those patterns are going to be aggregated, and then we're doing a majority operation. So it will converge towards whatever the dominant pattern was. And so in this case, the query is getting one green, two orange, and four blue patterns, and the blue pattern dominates such that it will converge towards blue. Also that's fitting because if you look at the location of the query, it's closest to where the blue pattern was originally written in.
Another way of looking at this, where we abstract away the actual neurons is just consider the original pattern location. And then we can just look at the intersection between the original write circle and the read circle for our query. And so in the bottom right here I've replaced the number of patterns with the size of these circle intersections. And it's this circle intersection relationship here that will be crucial to mapping onto Attention later.
So now getting into more of the mathematical formalism, here I have my blue pattern and the intersection with the query. And I'm defining that intersection, and the number of neurons in it, with a cardinality operator: the neurons that are in the pattern's circle intersected with the neurons that are in the query's circle. And now I'm going to break down this equation for the read operation one step at a time. So first of all, for each pattern, which is denoted by the bold p at the far right, we're weighting it by the size of its circle intersection. We're then summing over all of the patterns. We're then normalizing by the circle intersection weights, because we ultimately want to be able to compute a majority, so we need that normalizing constant. And in this case, because we're in a binary space, we need to map back to 0 or 1, and so we have this majority rule function g.
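As a side note, the write and read operations described above can be sketched in a few lines of numpy. This is a minimal toy, not the paper's implementation: the dimension, neuron count, and radius are arbitrary choices, and bits are accumulated as +/-1 counters so that the majority rule g is just a threshold at zero.

```python
import numpy as np

rng = np.random.default_rng(0)
n, num_neurons, radius = 64, 2000, 26   # dimension, neuron count, write/read radius

# Each neuron has a fixed random binary address in {0,1}^n and a storage vector.
addresses = rng.integers(0, 2, size=(num_neurons, n))
storage = np.zeros((num_neurons, n))

def hamming(v, M):
    """Hamming distance from vector v to each row of M."""
    return (v != M).sum(axis=1)

def write(pattern):
    """Auto-associative write: every neuron within the radius stores the
    pattern, with bits accumulated as +/-1 counters (a superposition)."""
    active = hamming(pattern, addresses) <= radius
    storage[active] += 2 * pattern - 1

def read(query):
    """Sum the contents of all neurons within the radius of the query, then
    apply the majority rule g to map each bit back to 0 or 1."""
    active = hamming(query, addresses) <= radius
    totals = storage[active].sum(axis=0)
    return (totals > 0).astype(int)
```

With these parameters, a stored pattern can typically be recovered even from a query with several corrupted bits, because the corrupted query's read circle still overlaps the neurons that the original write activated.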
OK, I'm actually not sure if I addressed your question with this, because I'm focusing on the high level relationship to Attention. Sparse Distributed Memory talks more about the way the patterns are stored, but it is the superposition. So I'm doing a superposition of all patterns that activate that neuron. Yeah.
AUDIENCE: So does the storage move where the neuron is in space?
TRENTON BRICKEN: So it's a key value pairing. So the neuron has an address where it exists in the space, and then it has a storage vector that it will return. And you could think of that biologically as the dendrites that activate that neuron, and then the synapses it has with its efferent connections downstream.
AUDIENCE: [INAUDIBLE]
TRENTON BRICKEN: In order to show the relationship to Attention, I'm abstracting away the actual neurons. So this slide doesn't show the synaptic weights. But there's a one-to-one mapping between the neuron perspective and the pattern perspective. I just can't get into the SDM biological plausibility. I have slides on it at the very end that hopefully we can get to. But I want to focus on the relationship to a Transformer Attention here.
AUDIENCE: What is the background of this? Past memory?
TRENTON BRICKEN: Yeah. So it was invented by Kanerva. Yeah. So he published his book on it in 1988. There were a couple of extensions that were developed in the few years after, and then it kind of disappeared over time. And I'm still not sure why. Maybe you have a better idea for why that was.
AUDIENCE: There were earlier associative memory [INAUDIBLE].
TRENTON BRICKEN: Yes. Yeah, absolutely.
AUDIENCE: So this is kind of connecting them.
TRENTON BRICKEN: Yes. I don't talk about related work here, but Sparse Distributed Memory can actually be written as a generalization of [INAUDIBLE] networks. So they're a special case.
AUDIENCE: [INAUDIBLE] dynamics [INAUDIBLE].
TRENTON BRICKEN: Yes, there are still some differences, but they're very closely related. So now giving a short overview of Transformer Attention for those who are unfamiliar with it. And just quickly, Transformers are one of the state of the art deep learning models across many modalities right now. And so on the left here, the Transformer's being used to predict the next word and generate text. In this case, it has a wonderful story about unicorns that you can read on OpenAI's website. The Attention operation is also applied in three different locations in AlphaFold 2, which was recently used for protein folding prediction.
It's also used for almost every Google English search query now, and it doesn't just do language processing. It's also moved into image classification and generation tasks. And so as a kind of fun one here, you have this model input, which is half of an image and then it needs to predict what the other half of the image would look like. And you can see that here it's getting the shadows correct in some different examples. And the original's on the far right.
So the core thing that makes the Transformer unique is the Attention operation. And I have this slide here just to show where it's doing next word prediction. That's the example we're going to be working with. And in this case: the animal didn't cross the street because it blank. And our query to the system is the word it, and we're then looking back at the rest of the sequence, deciding how to pay attention to it to predict what word comes next. And so here our word it has connections with previous words that it's using to then predict what comes next.
And so diving more into how that works exactly, I'm going to work with a simpler phrase-- the cat sat on the blank. And there are four things that we do in order to predict the next word. So first, we take each of our input words, and we create what are called keys, values, and queries. Each word aside from the last one turns into both a key and a value, and then the last word becomes a query. We then take the query and compare it to each of our keys, and the way we do that is with a dot product operation. We then take the size of those dot products with each input, and we normalize. And we're using the softmax function for normalization, which is crucial to the relationship that I'll show later. We'll get into that. And then finally, based upon the weights-- so how much attention we decide to pay to each of our inputs, which is based on how similar the query was to the keys-- we're then going to take the corresponding value that's paired to each key and do a summation operation over each of those values. And then from that we'll do some projections and then predict the next word.
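The four steps above can be sketched directly in numpy. This is a toy single-query version: the random vectors stand in for the projected inputs (in a real Transformer these would be the learned projections of the word embeddings), and the dimensions are illustrative.

```python
import numpy as np

def softmax(x, beta=1.0):
    e = np.exp(beta * (x - x.max()))    # subtract max for numerical stability
    return e / e.sum()

def attention(query, keys, values, beta=1.0):
    """Single-query Attention: softmax over query-key dot products, then a
    weighted summation of the value vectors."""
    weights = softmax(keys @ query, beta)   # steps 2-3: dot products, softmax
    return weights @ values                 # step 4: weighted sum of values

# Toy stand-ins for the projected inputs (in a real Transformer these would be
# W_K y, W_V y, and W_Q y for learned projection matrices W).
rng = np.random.default_rng(0)
d = 64
keys = rng.normal(size=(4, d))          # "the", "cat", "sat", "on" as keys
values = rng.normal(size=(4, d))        # ...and their paired values
query = rng.normal(size=d)              # the final "the" as the query
out = attention(query, keys, values)
```

The output is a single d-dimensional vector: a superposition of the value vectors, weighted by how well each key matched the query.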
So just as a fictional example-- we're really going to work through this-- the cat sat on the blank. So our query word here is the. And you could hypothetically think that query is a vector that's going to have a high dot product with, or be similar to, keys that are either nouns or their associated verbs. And so in this case, you'd think it might have a high similarity with the words cat and sat-- their keys. So it gives a large weight to the cat and sat value vectors. And the cat value vector you might think of as containing some superposition of other animals that are related to cats, and maybe words that rhyme-- for example, the word mat. This is totally fictional; I'm just trying to give intuition, right?
Then the sat value vector that corresponds to its key-- we're paying attention to it-- contains things that are sat on, including mat. And so what you could have is, when you're paying attention to both the cat and sat value vectors, you're doing the summation of them, and then you have all of the different possible vectors in superposition. And in this case, I have a weight of 3 on mat and then a weight of 1 on mouse and sofa, and it goes on and on. But you can think that mat would dominate in the summation operation, allowing us to then correctly predict the next word that comes.
And I guess one piece of intuition here is that what you pay attention to and what you should extract from it are different. And so that's why we have the key value pairing: we're paying attention to the keys, then we're extracting the information that's in the values. So quickly, just to show you some of the notation for this, because we're going to map it onto SDM: here I have my query that's being updated, and it's being updated using this equation. And the Ws-- the projection matrices-- you can kind of ignore.
Yeah, I break this down more clearly here. So first we're doing a dot product between our keys and our queries. So that's shown by this operation here. And so the Ys are the actual inputs; we then do this projection of them, and then we do a dot product.
Then we have the softmax operation. And the way this is actually defined is an exponential over a sum of exponentials. And to give you some intuition for it, softmax normalizes the weights, but it makes large values larger. And this relationship is crucial to SDM, so that's why I'm spending a lot of time on it. So just as a demo here, I have these inputs on my x-axis, and they each have some value, ranging from 0 to 5. And on the second plot here, I just do a normal normalization, in which case that largest value, which has an index of 4 and a value of 5, will just have a normalized value of 0.3. But if I'm using a softmax operation, then depending on my beta coefficient in the softmax, it'll have a value of 0.6. So it becomes much pointier, peakier than it would if I was just doing an ordinary normalization. And so just relating all this back to the equations: again, I do my softmax operation, and then I use these normalized weights to weight the summation of the value vectors we talked about before. And that gives you that full equation and hopefully some intuition for it.
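The demo in that plot is easy to reproduce. For inputs valued 0 through 5, plain normalization gives the largest input a weight of 5/15, about 0.33, while softmax with beta = 1 gives it about 0.63, matching the "peakier" behavior described:

```python
import numpy as np

x = np.arange(6.0)          # the demo inputs, valued 0 through 5

plain = x / x.sum()         # ordinary normalization: largest input gets 5/15
beta = 1.0
e = np.exp(beta * x)
soft = e / e.sum()          # softmax: an exponential over a sum of exponentials

print(round(plain[-1], 2))  # 0.33
print(round(soft[-1], 2))   # 0.63 -- the largest value gets much peakier
```

Raising beta sharpens the peak further (approaching a hard max), while beta near 0 flattens the weights toward uniform.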
OK, so how does Transformer Attention approximate SDM? Well, it turns out that in a high dimensional space, if I have two hyperspheres, then as I pull apart those hyperspheres, the size of their circle intersection-- the number of neurons that they share-- will decay approximately exponentially. So in this figure on the right here, on the x-axis I'm showing the Hamming distance between those two circles as I pull them apart. And on the y-axis I'm showing, in log space, the number of neurons that exist in that circle intersection. And because in log space this plot is approximately linear, it means that the number of neurons in the circle intersection decays approximately exponentially.
This is just one set of SDM parameters, but I'm using n = 64 dimensions, which is the normal dimensionality used in Transformer Attention. And do note here that the exponential approximation doesn't hold for all Hamming distances. It works best for patterns and queries that are close to each other, when the circle intersection is large. But that's the regime that we care about, because when we do the softmax operation and then have our normalizing constant, anything that's far away basically drops to 0.
So here we have our equation for the circle intersection. We can write it as approximately exponential, with a coefficient outside of the exponent and then the beta coefficient inside of it. And there are two things that we need to do to make this relationship a good one. First, we need SDM to be continuous, and so what we need there is a mapping from our Hamming distance into cosine similarity, where we're L2 normalizing our vectors and then taking their dot product. And this equation here is just the linear mapping between Hamming distance and cosine similarity.
And then we also need the beta coefficient inside our exponential to be a correct value, so that it can fit our exponential decay. And the way we can do that is in closed form, with just a log linear regression on our circle intersection. And so here I'm just redefining the Attention operation we went through before in the SDM notation-- no other tricks. And this is the real money slide, where it's the relationship between SDM and Attention. So I've expanded out the softmax operation on the right here into the exponential over the sum of exponentials. I have the SDM equation presented before, and the extent to which the circle intersection in SDM approximates an exponential is the extent to which SDM and Attention converge.
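The computation just described can be sketched concretely. This is an illustrative reconstruction, not the paper's code: it counts the exact circle intersection combinatorially, fits the decay rate by log linear regression, and converts the per-bit rate into a coefficient on cosine similarity using the linear mapping cos = 1 - 2d/n. The radius r = 26 is an arbitrary choice for the sketch.

```python
import numpy as np
from math import comb

def circle_intersection(n, r, d):
    """Number of points in {0,1}^n within Hamming distance r of both of two
    centers that are themselves Hamming distance d apart. A candidate point
    differs from the first center on k of the d disagreeing bits and on j of
    the n-d agreeing bits."""
    total = 0
    for k in range(d + 1):
        for j in range(n - d + 1):
            if k + j <= r and (d - k) + j <= r:
                total += comb(d, k) * comb(n - d, j)
    return total

n, r = 64, 26
ds = np.arange(0, 21)       # the close-together regime where the fit is good
sizes = np.array([circle_intersection(n, r, int(d)) for d in ds], dtype=float)

# Log linear regression: log|intersection| ~ intercept - beta_d * d, so the
# intersection decays approximately as exp(-beta_d * d).
slope, intercept = np.polyfit(ds, np.log(sizes), 1)
beta_d = -slope

# Hamming distance maps linearly to cosine similarity, cos = 1 - 2d/n, so the
# equivalent coefficient on cosine similarity would be beta = beta_d * n / 2.
beta = beta_d * n / 2
```

The quality of the linear fit to log(sizes) over this regime is what makes the softmax a good approximation to the normalized circle intersection weights.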
So now I just have some results, first in theory. And so I have two plots with different SDM settings, for small and large Hamming distances that we're using as the circle radii. And in blue I have the actual circle intersection, which I'm normalizing-- just a basic normalization, right? And then in orange I'm fitting an exponential, with the beta coefficient from my log linear regression, to the circle intersection, and then I'm using the softmax equation. So you can see the quality of these approximations in these two different settings. And the subplots I have there are log plots. And so you can see that in this case, with a larger Hamming distance, the exponential approximation only holds for closer points. But by the time I'm at a Hamming distance of 20 here, where it drops off, you can see that my normalized weight is basically zero.
So we've talked about the relationship between SDM and Attention. But how does SDM relate to the Transformer more broadly? And so one way that we can look at this is: to what extent do Transformers use beta coefficients in their Attention operation that are similar to those of optimal versions of SDM? So depending on what I want my SDM system to do-- if I want to store the maximum number of memories possible, versus if I want to have my system very robust to query noise-- I will use different Hamming radii. And in order to compute this I need to assume that my patterns are random, and so this won't apply exactly to the real world, where of course data is on some lower dimensional correlated manifold. But I can still get these values with random patterns and see how they map onto the beta coefficients that Transformers decide to use.
And so in this case, I'm using the key query normalized Attention variant of the Transformer, which actually learns its beta coefficient, so it makes it very easy to look at this-- because normal Transformers don't learn their beta coefficient; you have to kind of infer what it would be from the size of the dot products between queries and patterns. And so this histogram shows the learned beta coefficients across Attention heads, across layers, for this Transformer model. And the vertical red lines are three different definitions of optimal SDM. So on the far left, we're maximizing for query noise: we want our queries to be as noisy as possible but still work. In the middle, we're maximizing signal to noise ratio. And on the far right, we're maximizing memory capacity. And you can see that the learned beta coefficients fall within these bounds, and also skew towards max query noise, which I think makes sense, because if you're maximizing memory capacity, you're assuming your queries are noise free, and if you're training a model in a deep learning environment with out-of-distribution data always appearing, of course, that's not going to be the case.
So beyond Attention, how can we interpret other parts of the Transformer? Well there's some interesting work showing-- yeah?
AUDIENCE: These values are the translation of the terms.
TRENTON BRICKEN: OK. Yeah. Thank you. SDM and Attention use different notation. So I just needed to--
AUDIENCE: Where do keys and queries come from? You're just using the information [INAUDIBLE].
TRENTON BRICKEN: Yeah.
AUDIENCE: Something historic.
TRENTON BRICKEN: Yeah.
AUDIENCE: So values are the patterns.
TRENTON BRICKEN: And again, with those pattern pointers you can be in an auto- or hetero-associative setting. So the pattern pointer can either equal the address-- it can point to itself-- or, exactly, yeah. And so in a Transformer setting, of course, where you're trying to predict the next thing, it would be hetero-associative.
AUDIENCE: Associate A to B.
TRENTON BRICKEN: Yes. Exactly. Yeah. And this work has been accepted to NeurIPS. We don't have a preprint yet, but the camera ready version of the paper will be out a week from today. So there's some interesting work that we cite in the paper that other people have done on the feedforward part of the Transformer. And I should have said before, this is the whole Transformer architecture with each of the operations laid out. And so we can actually interpret that feedforward layer as doing a long-term version of Attention, and so then we can interpret it as doing a long-term version of SDM. And by long term here, I mean that when I'm doing normal Attention, my keys and values are a function of my receptive field-- the current inputs that I'm looking at. This longer term Attention is instead independent of the particular inputs that I'm working with. It'll store longer term memories across the whole of training-- multiple epochs.
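That feedforward-as-long-term-Attention interpretation can be sketched as follows. This is an illustrative reconstruction of the idea, with arbitrary dimensions: the rows of the first weight matrix play the role of keys and the rows of the second play the role of values, but unlike in Attention they are fixed parameters rather than functions of the current inputs, and a ReLU stands in for the softmax.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_ff = 64, 256

# In Attention, keys and values are computed from the current inputs. In the
# feedforward layer they are learned parameters, fixed across inputs: rows of
# the first weight matrix act as keys, rows of the second as values -- a
# long-term memory written over the whole of training.
K = rng.normal(size=(d_ff, d_model))    # first-layer weights as "keys"
V = rng.normal(size=(d_ff, d_model))    # second-layer weights as "values"

def feedforward(x):
    """Transformer feedforward layer, relu(x W1^T) W2: attention-like, but
    with a ReLU in place of the softmax over the learned keys."""
    scores = K @ x                      # match the input against every key
    weights = np.maximum(scores, 0.0)   # ReLU instead of softmax
    return weights @ V                  # weighted sum of the stored values
```

Under this reading, each hidden unit is one stored key-value memory, and the layer retrieves a weighted superposition of the values whose keys match the input.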
AUDIENCE: [INAUDIBLE]
TRENTON BRICKEN: Exactly. Yeah. And that actually relates to the neuron versus pattern perspectives of SDM that I talked about a while back. We can also interpret layer norm, which has been shown to be really important when people have tried to get rid of it. And this is in the sense that, in order to do SDM, I need cosine similarity, so I need to L2 norm my vectors. And the key query normalized variant of Attention, which leads to some small improvements, actually uses L2 norm instead of layer norm. So you can kind of think of this work as retroactively predicting this improvement, with layer norm approximating the L2 norm.
AUDIENCE: [INAUDIBLE]
TRENTON BRICKEN: Yeah so basically I look at my vector, I compute for all the vectors a mean and standard deviation, and then I normalize by that. If you're familiar with the batch norm operation, it's kind of similar to that.
AUDIENCE: [INAUDIBLE]
TRENTON BRICKEN: But it's a function of-- I think you have a running average of all the things you've seen. It's not just within the batch. Like when I do batch norm, I compute the mean of everything in that training batch.
AUDIENCE: Normalized over [INAUDIBLE].
TRENTON BRICKEN: Yeah. But they're quite similar. And I think that they're putting everything on a similar scale the same way that L2 norm would. So beyond these connections SDM has a number of extensions that we think could be useful in further improving the Transformer. So one, SDM has some close relationships to vector symbolic architectures. There's also some work showing that you could have multiple value vectors corresponding to each key. There are variants of self-attention where you're not having every single input be its own query. And there are other forms of external memory storage techniques.
And so in summary, the intersection between two hyperspheres approximates an exponential, and this allows SDM's read and write operations to approximate Attention, in theory and in the tests that we run. And so as sort of big picture future research questions-- which we certainly don't have answers to yet, but I'm interested in exploring-- is the Transformer so successful because it's performing a key cognitive operation? And it's worth me pointing out that the cerebellum, or cerebellar-like architectures, are ubiquitous across a large number of organisms. And so maybe there's some key operation that it's doing. And given how successful the Transformer has been empirically across multiple modalities, is SDM the correct theory for how the cerebellum is functioning?
And I think I'm pretty much out of time and want to leave time for questions, so I won't get into biological plausibility. It's in the appendix of the paper that will be out soon. But it's quite exciting in how it maps to each of the cell types. And I'll just go to the thank you slide.