Apical dendrites as a site for gradient calculations
Date Posted:
April 26, 2019
Date Recorded:
April 26, 2019
Speaker(s):
Blake Richards
Brains, Minds and Machines Seminar Series
Description:
Blake Richards, Assistant Professor, Associate Fellow of the Canadian Institute for Advanced Research (CIFAR)
Abstract:
Theoretical and empirical results in the neural networks literature demonstrate that effective learning at a real-world scale requires changes to synaptic weights that approximate the gradient of a global loss function. For neuroscientists, this means that the brain must have mechanisms for communicating loss gradients between regions, either explicitly or implicitly. Here, I describe our research into potential means of communicating loss gradients using the unique properties of apical dendrites in pyramidal neurons. I will present modelling work showing that, in principle, ensembles of pyramidal neurons could use the temporal derivative of their activity to estimate cost gradients. I will also show how this can be learned using the discontinuities that spikes induce. Finally, I will discuss specific experimental predictions that arise from these theories.
MARTIN SCHRIMPF: My name is Martin Schrimpf. I'm a grad student with Jim DiCarlo and part of the Center for Brains, Minds, and Machines. And it's my great pleasure to welcome Blake Richards today.
Blake received his bachelor's degree in cognitive science and artificial intelligence from the University of Toronto in 2004, working, among others, with Geoffrey Hinton. After completing his PhD at Oxford with Colin Akerman, he went to SickKids Hospital for his post doc with Paul Frankland. And he is now an assistant professor at the University of Toronto and will relocate to McGill in the summer.
On top of numerous awards, Blake has recently been named one of 29 CIFAR AI chairs in the Learning in Machines and Brains Program. Please join me in welcoming Blake Richards.
[APPLAUSE]
BLAKE RICHARDS: Thank you very much, Martin. And could we just get the lights down a bit here? Thank you. Perfect. So thanks for having me and for being here, everyone. It's my first time both at MIT and in Boston and the surrounding area altogether, which has been a real gap in my travels. So glad to have made it here.
So I'm going to tell you a little bit today about some of the modeling work that my group has been doing in collaboration with Richard Naud to look at how you could potentially use the physiology of apical dendrites to engage in gradient calculations for learning.
Now, let me unpack this a little bit. Let's start with the question of, why gradients? Come on, you can do it. There we go.
To get the question of why we would use gradients, we have to start with a very high level question of, what does it mean to learn? So learning implies something normative, necessarily. When you learn, you get better at something. If you're not getting better at something, you're not learning. You're just changing.
So when we say you're getting better at something, ideally, as scientists, we would have some way of quantifying that. And certainly, as computational modelers, we're particularly interested in providing some quantification for that learning.
Now, within the computational neuroscience perspective, but particularly even more so within machine learning, the way that we engage with this normative perspective is through the use of loss functions. So for perhaps some of the physiologists in the crowd, a loss function is just a function that tells you, for each phenotype of the system-- and to clarify, that could be any component of the phenotype-- but let's say for sake of simplicity that we're talking about synaptic connections-- so for each setting of your synaptic connections-- so here we have a two synapse network and different weights on those synapses.
You're going to be better or worse at some task. And the loss function measures how bad you are. So your goal in learning, then, is to find a setting of your synaptic weights that minimizes your loss function. And that is just how we define learning-- how we get at it in a quantitative manner.
Now, if loss goes down over time-- so if your loss function is reducing over time, then you know that at least some component of the direction that you're going in your phenotype space-- and let's, just again, say this is synaptic weight space-- some component of the direction you're going in synaptic weight space is in the direction of the gradient.
So if I've got, say, three synapses here and I look at the setting of the synapses at some time step, and then some other time step T seconds later, they've changed over the course of that T to arrive at that point. And if I'm actually learning-- even if the path is all twisty and weird, this vector of my change, on some level, has to not be orthogonal to the gradient of my loss function. If it was, I'm not getting better. So learning necessarily means that I'm at least going, to some extent, in the same direction as my loss gradient.
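To make that orthogonality condition concrete, here's a minimal sketch (not from the talk; the quadratic loss and the numbers are purely illustrative): to first order, the loss only goes down if the net weight change has a negative dot product with the loss gradient.

```python
import numpy as np

# Toy quadratic loss over three synaptic weights (illustrative only).
def loss(w):
    return 0.5 * np.sum(w ** 2)

def grad(w):
    return w  # gradient of the quadratic loss above

w_start = np.array([1.0, -2.0, 0.5])
delta_w = np.array([-0.3, 0.5, -0.1])  # net change after a twisty path
w_end = w_start + delta_w

# First-order condition: the loss only decreases if the net change has a
# component along the negative gradient, i.e. dot(delta_w, grad) < 0.
alignment = np.dot(delta_w, grad(w_start))
print(alignment < 0)                # a negative dot product: "learning"
print(loss(w_end) < loss(w_start))  # and indeed the loss went down
```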
Now, the other thing that we can say is this can be both shown analytically within certain, restricted cases, but also, empirically, we have found in machine learning that, generally, algorithms which adhere more closely to the loss gradient will do better and learn faster than algorithms that don't. So here you're looking at plots from a paper by Xie and Seung that was analyzing gradient descent.
And two other methods-- node perturbation and weight perturbation-- both of these methods actually provide an estimate of the gradient of the loss function. But they do so with higher variance than true gradient descent. And what you find, lo and behold, is that with the higher variance of your gradient estimate, it takes you longer to learn.
So any algorithm, we can say broadly, as I said, empirically, this is the experience-- any algorithm that is better at following the loss gradient is going to learn more quickly. And that's why modern machine learning is so dependent on gradient descent and related methods. Pretty much all machine learning nowadays-- well, not all-- but all of the kind of high impact machine learning relies on gradient descent techniques.
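A rough sketch of why perturbation methods count as gradient estimators, under my own toy assumptions (a quadratic loss; the specifics are mine, not from the Xie and Seung analysis): jittering the weights and watching the loss gives you the gradient on average, but with much higher variance than computing it directly, and that variance is what slows learning down.

```python
import numpy as np

rng = np.random.default_rng(0)

def loss(w):
    # Toy quadratic loss with its minimum at w = 1 (illustrative only).
    return 0.5 * np.sum((w - 1.0) ** 2)

w = np.zeros(4)
true_grad = w - 1.0  # analytic gradient, for comparison

# Weight perturbation: jitter all the weights at once and use the
# resulting change in loss to form a gradient estimate.
sigma = 1e-3
estimates = []
for _ in range(20000):
    xi = rng.normal(0.0, sigma, size=w.shape)
    estimates.append((loss(w + xi) - loss(w)) / sigma**2 * xi)

g_mean = np.mean(estimates, axis=0)
# Averaged over many perturbations the estimate matches the true
# gradient, but any single estimate is far noisier than the real thing.
print(np.allclose(g_mean, true_grad, atol=0.2))
```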
So my bold or ridiculous claim, depending on your perspective, is the following. I would suggest that, given the fact that humans and animals actually seem to be quite effective learners and able to learn fairly quickly, I think it is more parsimonious an explanation to assume, at first, that we might have some mechanism in our brains for calculating loss gradients and for using those loss gradients to update our synaptic weights.
If someone comes along later and shows how you can learn better without following loss gradients, great. Then I'm happy. That's how science progresses. But until I see a demonstration of that, I think it's a perfectly reasonable question to ask, how might the brain calculate gradients of loss functions that it might be interested in? And so that's been a goal of my research in my lab for the last few years.
So my talk today is going to describe to you our work on exploring potentially the apical dendrites of pyramidal neurons as a site for these gradient calculations. And I'll show you modeling work that I think at least makes a plausible case that it's worth doing some experiments to test these ideas.
OK, so the problem with calculating gradients. So the reason this isn't an easy problem and something that actually requires a lab to focus on-- or many labs, as it were-- is that the most direct way to calculate the gradient of a loss function if you have a multilayer network-- so here, I'm illustrating a multi-layer neural network. Each of these little circles is supposed to represent a neuron. And each of the lines is a synaptic connection.
If you have a multilayered network like this-- and for the record, in this talk, I'm going to ignore the question of gradient descent through time. That's a whole other kettle of fish that I don't have time or intelligence to deal with. But if we have a hierarchical network nonetheless, so even ignoring the difficulties of time, the way that you're going to calculate the gradient of a loss function is ultimately by using the chain rule.
So you calculate the gradient in one layer, and then you propagate it backwards. And you use that calculation to calculate the gradient at another layer. And you propagate that backwards. And that's just the back propagation algorithm.
So when we actually lay this out mathematically in the kind of simplest case-- let's say we have a neural network as structured here with input x, output y, and a hidden layer h. We've got standard linear, nonlinear unit activations and a standard squared error loss function. So there's our loss function. L is equal to 1/2 the error squared.
When we calculate the partial derivatives of this loss function with respect to our synaptic weights in the hidden layer, which, of course, we're going to use to calculate our gradient, we end up getting this, courtesy of the chain rule. And that gives us this equation here. So this is just all the partial derivatives chained out from the loss function, all the way down to these synaptic weights here.
Now, there are a few difficulties with this-- from a physiological perspective, that is. There's nothing difficult about this from a computer science perspective. You can do this on a computer, no problem. We do it all the time.
But from a physiology perspective, we have the following problems. The first is that we need separate forward and backward passes. So you need to send activity through the network. And then you need to do your gradient propagations backwards. So you need to calculate the gradients at the top, then calculate the gradients of the next layer, then so on and so forth, and go backwards like this.
And even though there's maybe some evidence for waves of activity, there is no clear evidence of completely separate forward and backward passes in the brain-- certainly not throughout the entire hierarchy of the cortex.
Another issue is that you need to control your plasticity with error-- this error term that's been propagated from the top or the gradient of the error, anyway. And that is also potentially problematic, because we don't see any evidence for non-local error terms impacting synaptic plasticity in the brain.
Generally, synaptic plasticity is a local phenomenon. And we don't really have any reason to believe that something like this error term in this equation is ultimately determining the synaptic changes in a particular layer of cortex.
And another key problem here is that, in working out these gradients, we're left with this term here. To go back to the diagram for a second, so these are the synapses from my input to my hidden layer. And W1 here is the synapses from my hidden layer to my output layer. And what that term tells us is that to update my upstream synapses, I need to know my downstream synapses. And that itself is also problematic, biologically. There's no reason to believe that a neuron in V1, say, knows something about the downstream synapses to V2 or V4.
So because of all these reasons, the idea that the brain does back propagation is a non-starter. We know the brain doesn't do back propagation. And I want to say that really clearly, because I get a lot of people asking me, do you really think the brain does back propagation?
No, I don't think the brain does back propagation. But estimating a gradient of a loss function is not equal to back propagation. Back propagation is a particular way of calculating the gradient of a loss function using the chain rule. And so what I'm arguing is that, though the brain can't do back propagation, it might have other ways of estimating these gradients and getting fairly tight estimates on them.
OK. So if my computer will listen to me-- all right. So here are the problems that I've identified for you for doing a gradient calculation if you do it in the straightforward way. And what I'm going to try to do, is I'm going to try to sequentially knock out each of these problems and show how the physiology of pyramidal neurons can actually permit you to learn without having to deal with each of these issues.
OK. So to start with, I'm going to describe a phenomenon that my collaborator, Richard Naud, describes as burst multiplexing, which can knock out that first issue for us. So just to recognize everyone who's worked on this, the main student behind all this in my lab and all the work I'm going to show you here has been Jordan Guerguiev, a fantastic student who will be finishing up soon and looking for a post doc. And I really can't recommend him highly enough.
And we've been doing this in collaboration with Alexandre Payeur, Friedemann Zenke, and this is Richard. And this team has done all the work that you'll see in the following slides.
OK, so let's talk briefly about the apical dendrites of pyramidal neurons. So here we've got a sort of cartoon of a pyramidal neuron with basal dendrites that are coming out of the perisomatic region. And then this long, apical shaft that reaches up to the superficial layers of the neocortex, or similar things happen in the hippocampus with, now, a tuft of dendrites here. Here's our apical tuft.
And this is an illustration of data from an old paper by Matthew Larkum, which has been very influential in my thinking. And he followed it up with a series of papers.
I think there's some interesting discussion about whether this happens in all pyramidal neurons. I saw a presentation at-- or sorry, no, I didn't see it at Cosyne, but it was sent to me by someone from Cosyne-- suggesting that maybe this doesn't happen in V2.
But I've seen this in [? my hands ?] and somatosensory cortex and in V1. And Matt did it in somatosensory cortex. And people have described it in hippocampus, as well.
Basically what happens is, this shaft of the apical dendrite has a region that's very rich in voltage-gated calcium channels. And as a result, there is this interesting non-linear behavior in the apical dendrite. So here in the experiment that Matt did-- Matt Larkum-- he had three patch electrodes attached to different parts of the same pyramidal neuron.
And he would inject current into this red electrode here, attached to the apical tuft. And if he just injected a little bit of current, he would see the response. And it would very quickly dissipate in its propagation down to the soma.
But if he generated enough current into the apical dendrite, then he would get this massive non-linear response in the apical dendrite, courtesy of these voltage-gated calcium channels. And this would generate a prolonged depolarization in the cell, which is why we call it a "plateau potential." And it can drive burst firing in the pyramidal neuron.
Now, this burst firing driven by the plateau potential, then, is basically a way of signaling that you've got simultaneous basal and apical input. This is something that has been shown experimentally, but there's also a very nice demonstration of it in an in silico biophysical model by Adam Shai from when he was in Christof Koch's lab.
Basically, here they've got a multi-compartment simulation of a pyramidal neuron receiving basal synaptic inputs and apical tuft synaptic inputs. And if you send in the basal synaptic inputs by themselves, you can drive a single spike from the neuron. If you send in just apical inputs, you don't get much of anything because of the dissipation due to the electrotonic distance of the apical dendrite.
But if you give basal and tuft input, you get a plateau potential and burst firing. So the burst firing is basically a way of saying, in addition to the basal input, I got some apical input as well. It's effectively a coincidence detector for the cell.
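The coincidence-detector logic can be caricatured in a few lines (this is my abstraction of the described behavior, not the Shai et al. biophysical model itself):

```python
# Toy coincidence-detector rule abstracted from the behavior described:
# basal input alone -> single spike; apical alone -> nothing (it dissipates
# over the electrotonic distance); basal + apical -> plateau and a burst.
def pyramidal_response(basal: bool, apical: bool) -> str:
    if basal and apical:
        return "burst"   # plateau potential drives burst firing
    if basal:
        return "spike"
    return "silent"      # apical input alone dissipates

print(pyramidal_response(True, False))   # spike
print(pyramidal_response(False, True))   # silent
print(pyramidal_response(True, True))    # burst
```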
Now, this is a fairly long story. This just shows that if you remove the calcium channels, you don't get that behavior. What's interesting, though-- so Richard, when he was in Henning Sprekeler's lab a few years ago, ran some biophysical simulations where he took ensembles of simulated pyramidal neurons that he had fit to Matthew Larkum's data. And he provided them with different inputs to the apical dendrite compartment or to a perisomatic compartment that's also intended to capture the basal dendrites.
And what he did then in his analysis-- which is a very clever analysis I think, that really got me thinking-- is, rather than just examining the spiking behavior of the cell, he analyzed what he called the event rate and the burst probability. So the event rate is basically just if you look across the ensemble-- so here we've got three pyramidal neurons we're plotting out.
So here's their voltages with spiking. If you look across the ensemble, you can ask how many of these pyramidal neurons engaged in either a spike or a burst. We call that an "event." It can be either a spike or a burst. The event rate is just the percentage of neurons in the ensemble that are engaged in either a spike or a burst.
You can then also ask, of the events that occurred, how many of them were bursts? And that, he called the "burst probability." And these two things can be decoupled from one another. So here, in this first time bin, we've got a very low event rate, because only one of the neurons was actually engaged in either a burst or a spike. But we've got a very high burst probability, because the only event was a burst.
In contrast, in this third bin, we've got a very high event rate, because all three neurons engaged in an event. But we've got a lower burst probability, because only one of those events was a burst.
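Here's a small sketch of those two definitions, using the two time bins just described (the encoding of each neuron's activity as a string is mine):

```python
# Per-neuron activity in one time bin: "burst", "spike", or None.
# Event rate = fraction of neurons with any event (spike or burst);
# burst probability = fraction of those events that were bursts.
def event_rate_and_burst_prob(activity):
    events = [a for a in activity if a in ("spike", "burst")]
    event_rate = len(events) / len(activity)
    burst_prob = (sum(a == "burst" for a in events) / len(events)
                  if events else 0.0)
    return event_rate, burst_prob

# First bin from the talk: one burst, two silent neurons ->
# low event rate, but a burst probability of 1.
print(event_rate_and_burst_prob(["burst", None, None]))
# Third bin: all three neurons have an event, only one is a burst ->
# high event rate, lower burst probability.
print(event_rate_and_burst_prob(["spike", "burst", "spike"]))
```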
So we've got these two different signals in the system. And what Richard showed, which was cool, is that if you look at just the firing rate of the neuron when you're injecting current into these two different compartments, it's hard to interpret what's going on. You kind of see some responses to changes in the currents, whatever.
But if you, instead, analyze the event rate and the burst probability across the population, the event rate fairly closely matches the current that is being received by the perisomatic compartments. And the burst probability fairly closely matches the current that is being received by the apical dendritic compartment.
And this is happening simultaneously. So in other words, you have a multiplex signal. This ensemble of neurons is simultaneously communicating two different streams of information-- one about what it's receiving at its perisomatic compartments, and one about what it's receiving at its distal apical compartments.
So Richard and I got talking shortly after he published this paper. Oh, this is just to show that you get this same effect even when you've got other inputs.
So if you've got a whole other set of Poisson inputs coming into your ensemble and you look at what happens when you inject-- here, blue is the perisomatic compartment, and red is the apical compartment again-- again, you can see the event rate tracking the perisomatic current, and the burst probability tracking, fairly closely, the apical dendrite input, even though there's all these other Poisson inputs.
OK. So what Richard and I got talking about was that, potentially what you could do, is you could use this to simultaneously broadcast a bottom-up and top-down stream of information. Because one thing that I failed to mention but was on my slide a few slides ago, is that it's biology, so there are no hard rules here.
But often what you see in terms of the micro-circuit structure of the neocortex is that a lot of the top-down feedback from higher order regions of the cortex impinges upon the apical dendrites of pyramidal neurons in lower regions of cortex. And a lot of the bottom-up information, either from primary thalamic afferents passing through layer 4 or from lower cortical regions, arrives at the basal compartments.
So theoretically what you could have is a sort of bottom-up signal via the basal compartments and a top-down signal via the apical compartments. And courtesy of this burst multiplexing, you could communicate these two signals-- the bottom-up and top-down signals-- simultaneously.
So what we're showing here: we simulated an ensemble of pyramidal neurons that was receiving an input to its apical dendrites, and which was then projecting to the apical dendrites of another pyramidal ensemble.
That pyramidal ensemble was then projecting to the perisomatic compartment of these dendrites and receiving, in its perisomatic compartment, another current. And then in order to get this to work, we find we have to introduce some interneurons that also target the perisomatic and distal apical compartments differently, and which also use short-term depressing versus short-term facilitating synapses distinctly.
But if we set up this micro-circuit in this way, which is actually not a horrible approximation of the micro-circuits that we see in the neocortex, what we find is that this ensemble-- so here we're injecting input to its apical dendrite-- its burst probability is going to match that current. But what's interesting is that, then, the apical dendrites of these guys are also not doing a bad job of tracking that signal.
Meanwhile, their event rate is encoding the signal that's coming into their perisomatic region. And we see that same signal get propagated up to this next ensemble here.
So theoretically, what you could have is layers of pyramidal ensembles that are simultaneously broadcasting these two signals to each other at the same time. And so this has been where we've started to go with these models, is to think about how we could use this to eliminate the need for separate forward and backwards passes.
Because if you could encode some information about the error gradients-- the loss gradients-- in that top-down signal coming into the apical dendrites, then they could be communicating your gradients for you simultaneously, as the rest of the system is communicating whatever sensory information it has up through the hierarchy via the event rates. So this is ultimately where we're going to go.
Now, one additional, kind of interesting component to this, which I'll touch on again later, is that you also can control the non-linearity in the apical dendrites using apical dendritic targeting inhibition. And you can do that in a variety of different ways.
So for example, if you have some kind of simulated somatostatin interneurons-- so this is the burst probability as a function of the amount of dendritic input. You can alter the slope of that non-linearity. It has to be non-linear, because it has to range between 0 and 1, or 0% and 100%.
And you can alter that slope and make it more or less linear, depending upon how much inhibition you apply. And you can do that either directly with somatostatin interneurons, or indirectly via VIP-mediated disinhibition, or even just altering the release probability onto the somatostatin interneurons, or even more directly, I suppose, by altering the dendritic excitability. Though the dendritic excitability, it should be said, doesn't actually alter the slope so much as just shift your function for you.
So really, the take-home message here I want you to get is that you can control the slope of this non-linearity with somatostatin inhibition or, alternatively, with VIP-mediated disinhibition. And that's going to come up again in another few slides.
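As a hedged sketch, the slope-versus-shift distinction can be written as a parameterized sigmoid, where `beta` and `theta` are my stand-ins for the inhibition level and the dendritic excitability; this is not the fitted model from the talk.

```python
import numpy as np

# Hypothetical parameterization: burst probability as a sigmoid of apical
# input. beta (slope) stands in for somatostatin inhibition / VIP
# disinhibition; theta (shift) stands in for dendritic excitability.
def burst_probability(apical_input, beta=1.0, theta=0.0):
    return 1.0 / (1.0 + np.exp(-beta * (apical_input - theta)))

x = np.linspace(-3.0, 3.0, 13)
shallow = burst_probability(x, beta=0.5)  # shallower, more linear regime
steep = burst_probability(x, beta=3.0)    # steeper, more step-like
shifted = burst_probability(x, beta=0.5, theta=1.0)  # shifted, same slope

print(np.all(np.diff(shallow) > 0))  # monotonic, bounded in (0, 1)
print(steep[-1] - steep[0] > shallow[-1] - shallow[0])
```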
But the take-home message from this section is, theoretically, with ensembles of pyramidal neurons and incorporating the non-linear properties of apical dendrites that we see, you should be able to simultaneously communicate bottom-up and top-down signals in the neocortex. So on some level, we can maybe cross this one off our list. And we don't have to worry about having to do separate forward and backward passes so much.
So what about controlling plasticity with error? OK. So I'm going to look at supervised learning here. And for the record, that's just for convenience, because it's easier. Because no one has necessarily demonstrated really good unsupervised learning.
Or alternatively, another thing I'll say about that, is a lot of the unsupervised learning that has turned out to be particularly effective is something we call "natural supervision," where basically, you're still doing supervised learning. It's just that the targets are no longer labels or anything like that, but are something like predictions of the sensory inputs you're going to get.
So I think the question of supervised versus unsupervised learning is an interesting one. But for this talk, I'm going to stick to showing you how we can do supervised learning with this burst multiplexing system.
OK. So one of the things that is also interesting about the physiology of apical dendrites this way is that, when we look at the impact of bursts on plasticity, they seem to be able to regulate the sign of synaptic changes. So this was a set of experiments from Letzkus back in 2006. But there have been a series of other experiments that have demonstrated more recently in Jeff Magee's lab, and also from Anthony Holtmaat's lab, that non-linear plateau potentials really are good at regulating plasticity and pyramidal neurons.
Here are the experiments from the Letzkus lab. He was doing, again, multiple patches on a layer 5 pyramidal neuron. And then he was looking at-- in this case, it was actually a layer 2/3 input to a part of the apical dendrite that was not quite in the tuft region. So this doesn't quite fit 100% with what I'm about to show you. But the broad idea is there.
So basically what he found was, that if he triggered the presynaptic input and then triggered a single spike in the layer 5 cell, he would get long-term depression at his synapse. Whereas if he triggered a presynaptic input and then got a non-linear response, courtesy of some additional dendritic input, and triggered a burst, he would get long-term potentiation. So he could control the sign of plasticity of the synapses by whether or not there was a burst present.
So Richard and I have been working with this idea. So the fundamental idea is that the learning rule at the biophysical level-- maybe "biophysical" is the wrong word-- at the cellular level would look something like this.
So you have an ensemble of pyramidal neurons here. And this is a presynaptic ensemble, a postsynaptic ensemble. They've all got multiple synapses to each other.
And we have the following weight update rule. We say that the weight change for a particular weight between neuron i and j is going to be a function of an eligibility trace, which is this here. And that's basically just a convolved spike train, so indicating when spikes have occurred in the past.
And these two terms here-- so we've got a term indicating whether or not the postsynaptic neuron has engaged in an event, and a term indicating whether or not that postsynaptic event was a burst. The event term induces depression. So if you just get an event, you will get depression. Whereas the burst term switches that to long-term potentiation.
Now, so basically, bursts are going to tell the neuron, you should be potentiating rather than depressing.
AUDIENCE: [INAUDIBLE]
BLAKE RICHARDS: Yes. Because what you can show is that, across the ensemble, what this turns into is the following. So this is the average across the ensemble. And you end up getting basically a Hebbian term, which is the correlation in the event rates across the ensemble.
And then the difference between the burst probability in the ensemble and some baseline burst probability. So what you end up having is a Hebbian term with an additional error term that's regulating the direction of that Hebbian plasticity. So it's akin to a three-factor learning rule.
Now, what that gives you, effectively, is a frequency-dependent STDP as well. So here, if we just do an STDP protocol, depending upon the frequency with which we stimulate the cells, we will either get depression or potentiation. Both the simulations and the analytical theory agree as to what's going to happen there. And we see similar things with a Poisson protocol or a burst Poisson protocol, where we intentionally induce bursts sometimes.
In either case, what you get is a situation where, at low frequencies, you get depression, and at high frequencies, you get potentiation. And this is the kind of thing that experimentalists have known about for a long time. So this is all just to say that this matches a lot of what we know about synaptic plasticity from the experimental literature.
But what it gives you is the ability to control plasticity with this error term-- so effectively, a supervised learning rule. So what you can do is, basically-- so here in the red, we're plotting the postsynaptic burst probability in these guys. So this is the burst probability in this ensemble. That burst probability is ultimately what determines the sign of plasticity at these synapses.
So when the burst probability rises really high, the synapses onto the perisomatic region here increase. If we increase that burst probability again, it increases further. If we then decrease the burst probability down to 0, we start to get depression. And if we leave it at the baseline burst probability, we have stable weights.
So you can basically regulate the plasticity in the network by altering the burst probability in your ensemble. And that doesn't depend on the burst probability of these guys. So here, we raise the burst probability here. They're not determining these guys' synaptic weight changes. It's the burst probability in this ensemble that determines the direction of change.
OK. Now, the other thing we can do then is, how do we calculate the difference from a baseline? One of the ways we can do that is with a temporal difference. So if you have some initial baseline activity and then you push it in the direction of a target, that difference over time is going to be your sort of difference from the baseline.
So here's some activity, which I'm just denoting x, because this is very abstract here. So this is some activity vector x at time t star minus 1. And then at time t star, we get pushed towards this target.
And here's the difference between these two. You can show that if you have a squared error loss function, basically this temporal difference is proportional to the gradient of your loss function at time t star minus 1.
So if you have some kind of baseline activity and then you push it in that direction of the target, you're getting out of that difference your gradient. So we're going to use that gradient, then-- this temporal difference gradient-- along with that burst probability mechanism that I just showed you, to actually propagate gradients through these ensembles.
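That proportionality is easy to verify numerically (a toy check with arbitrary numbers; the nudge size `eta` is my own parameter):

```python
import numpy as np

# For a squared-error loss L(x) = 0.5 * ||y_star - x||^2, the gradient is
# dL/dx = -(y_star - x). So if the activity gets nudged toward the
# target, the temporal difference in activity is proportional to the
# negative of the loss gradient at the baseline.
y_star = np.array([1.0, 0.0, 0.5])    # target
x_before = np.array([0.2, 0.6, 0.4])  # baseline activity at t* - 1
eta = 0.1                             # size of the nudge
x_after = x_before + eta * (y_star - x_before)  # activity at t*

temporal_diff = x_after - x_before
grad = -(y_star - x_before)           # gradient of the squared-error loss

print(np.allclose(temporal_diff, -eta * grad))  # True
```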
So here, we're moving to an ensemble-level model, because we don't yet have the capability of running a million biophysical neurons in a deep network.
So just to be clear, we're now approximating the activity of each ensemble with some relatively simple equations. We say that the event rates at layer i are a ReLU function of a linear transformation of the event rates at the layer below. The burst probabilities are a sigmoid function applied to the burst rates from the layer above, where the burst rates are just the product of the burst probabilities with the event rates.
In addition to taking into account the top-down input in the burst probabilities, we also have a recurrent inhibition term that's introduced by our kind of very loose approximation of somatostatin inhibition, which we're going to use to regulate the linearity-- the non-linearity, I should say-- of our burst probability.
So here are the equations that are governing this. And then we're going to use the following synaptic weight update, which is a high-level approximation of the synaptic weight update I just showed you.
So we're going to say that the weight changes in a particular layer are determined as follows: if we get a target at time t star, the extent to which the burst probabilities diverge from the baseline probability they were at prior to the target being received is our error term. And then we have a Hebbian term that is just the product of the event rates across the two layers.
So there's our synaptic weight update. It's based, loosely but nonetheless, on biophysical principles. And what we can show is that it actually gives us a decent approximation of the mean squared error gradient.
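As a rough sketch of these ensemble-level equations and the update, here is one way the forward pass and the weight change could be written down. The layer sizes, the fixed feedback matrix `Y`, the baseline burst probability, and the size of the target nudge are all illustrative assumptions, and this omits the recurrent somatostatin-inhibition term:

```python
import numpy as np

def relu(x): return np.maximum(x, 0.0)
def sigmoid(x): return 1.0 / (1.0 + np.exp(-x))

rng = np.random.default_rng(1)

# Hypothetical two-layer stack of ensembles; all shapes are illustrative
n_in, n_hid, n_out = 4, 6, 3
W1 = rng.normal(scale=0.5, size=(n_hid, n_in))   # feedforward weights
W2 = rng.normal(scale=0.5, size=(n_out, n_hid))
Y = rng.normal(scale=0.5, size=(n_hid, n_out))   # fixed feedback weights

x = rng.random(n_in)

# Forward pass: event rates are a ReLU of a linear transform of the layer below
e1 = relu(W1 @ x)
e2 = relu(W2 @ e1)

# Baseline: burst rate from above = baseline burst prob * event rate,
# and burst probabilities are a sigmoid of the top-down burst rates
p_baseline = 0.2
b2_base = p_baseline * e2
p1_base = sigmoid(Y @ b2_base)

# A target nudges the output ensemble's burst probability, which propagates down
p2_target = p_baseline + 0.1                     # illustrative nudge
b2_target = p2_target * e2
p1_target = sigmoid(Y @ b2_target)

# Weight update: (burst-prob deviation from baseline) x (presynaptic event rate)
lr = 0.1
dW1 = lr * np.outer(p1_target - p1_base, x)
W1 += dW1
```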
So here what I'm plotting is the angle between the weight update that back propagation would tell us to do-- which is technically the true gradient-- and the weight update that our algorithm tells us to do. And you can see that, over time, we get this drop in that angle.
That's because one thing I didn't mention here, is our feedback weights right now in this model are fixed. And we're relying on a phenomenon called feedback alignment that I'll touch on in a moment.
But basically, the network has to learn how to deal with its feedback weights appropriately. But if it does so, it can actually get down to a pretty reasonable approximation of the gradient-- about a 40-to-50-degree difference, though that's a function of the non-linearity.
So here, these betas are the slope of the non-linearity in the burst probability. So if we make that slope less-- we make it a more linear function-- we can get a better approximation of the gradient.
So one of the other things we do in the model is, we introduce learning in the somatostatin interneurons to encourage linearity in the apical compartment. And if we do that, we can then, without having to fuss around with the slope, just encourage a better approximation of the gradient. So red here is where we're updating our somatostatin connections to encourage linearity in the apical compartment.
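The alignment measure in these plots can be computed as the angle between the two flattened weight updates. A small helper like this, with illustrative vectors, captures it:

```python
import numpy as np

def angle_deg(u, v):
    # Angle between two flattened weight-update vectors, in degrees
    cos = np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v))
    return np.degrees(np.arccos(np.clip(cos, -1.0, 1.0)))

# Identical updates are 0 degrees apart; orthogonal ones are 90 degrees
u = np.array([1.0, 2.0, 3.0])
assert np.isclose(angle_deg(u, 2.0 * u), 0.0)
assert np.isclose(angle_deg(np.array([1.0, 0.0]), np.array([0.0, 1.0])), 90.0)
```

Anything below 90 degrees still has a positive projection onto the true gradient, which is why a 40-to-50-degree update can still descend the loss.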
AUDIENCE: [INAUDIBLE]
BLAKE RICHARDS: Hm?
AUDIENCE: Can you just very quickly elaborate-- so it's not clear to me what each layer, what kind of information it has. So initially, you said with this burst rate and [? end ?] rate, it's possible that each unit in each layer kind of knows what the target [INAUDIBLE]
BLAKE RICHARDS: So what's happening is-- sorry if it wasn't clear-- in each layer, the apical compartments are receiving information about the burst rate in the layer above.
AUDIENCE: Yeah.
BLAKE RICHARDS: And the perisomatic compartments are receiving information about the event rate in the layer below. OK?
AUDIENCE: [INAUDIBLE]
BLAKE RICHARDS: So that's the information that gets propagated through the network.
AUDIENCE: [INAUDIBLE] different from layer to layer, you're saying? Or is the burst rate [INAUDIBLE]
BLAKE RICHARDS: It's all happening simultaneously. So we run this in time. So you'll notice there are t's here. So we assume a synaptic delay of one time-step. And so we just run it in time-- chunk, chunk, chunk, chunk, chunk, every time step.
And at each time-step, the layer below communicates its event rate up to the next layer, the layer above communicates its burst rate down to the lower layer. And this all just runs through time.
And then at some point during that simulation, we give a target that just nudges the output layer towards the right answer. It doesn't even set it towards the right answer. It just pushes it in the direction of the right answer.
That, then, alters the burst probabilities in that output layer. And that, then, percolates through the network. And it's that percolation through the network that moves all of the burst probabilities off of their baseline burst probability and induces the error signal that drives learning. Yeah?
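A rough sketch of that simulation loop: a one-time-step synaptic delay, event rates passed up and burst rates passed down at every step, and a target that nudges the output burst probabilities partway through the trial. The sizes, constants, and constant input are all illustrative assumptions:

```python
import numpy as np

def relu(x): return np.maximum(x, 0.0)
def sigmoid(x): return 1.0 / (1.0 + np.exp(-x))

rng = np.random.default_rng(2)
W1 = rng.normal(size=(5, 3))      # feedforward weights, layer 1
W2 = rng.normal(size=(2, 5))      # feedforward weights, layer 2
Y = rng.normal(size=(5, 2))       # feedback weights

x = rng.random(3)                 # input held constant throughout the trial
e1, e2 = np.zeros(5), np.zeros(2)

T, t_star = 20, 10                # target arrives at step t_star
for t in range(T):
    # One-time-step synaptic delay: this step uses last step's rates
    e1_new = relu(W1 @ x)
    e2_new = relu(W2 @ e1)
    p2 = np.full(2, 0.5)          # baseline output burst probability
    if t >= t_star:
        p2 = p2 + 0.1             # target nudges output burst probs
    b2 = p2 * e2                  # burst rate = burst prob * event rate
    p1 = sigmoid(Y @ b2)          # top-down burst rates set lower burst probs
    e1, e2 = e1_new, e2_new
```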
AUDIENCE: At the [INAUDIBLE] also change every time-step?
BLAKE RICHARDS: No. So that's a good question, and no. For this to work the way it is right now, we need to assume that the input layer is constant throughout that time. Yes. Yeah.
AUDIENCE: [INAUDIBLE] alignment?
BLAKE RICHARDS: Yes.
AUDIENCE: We set the y to be doubly transposed.
BLAKE RICHARDS: Yeah.
AUDIENCE: Then how close to the [INAUDIBLE]?
BLAKE RICHARDS: Pretty close. Like, it's not quite perfect. But it's just a few degrees off, basically. Yeah. Yeah?
AUDIENCE: Is it possible that you have a loop or anything? Like, say if your [INAUDIBLE] model, all the connections are not in a loop, but if you add a loop, will there be any [INAUDIBLE]
BLAKE RICHARDS: I mean, to be honest with you, I don't know. My guess would be that, from enough experience of having played with recurrent networks, that if all I did was add a loop and I didn't do any training on it or anything, (LAUGHING) yes, I'd get a problem.
So what we'd have to do is also consider appropriate loss functions for the recurrent connections there. I mean, we've got some recurrent connections, but they're just inhibitory, so. Yeah.
OK, so with this, we can do pretty well on a variety of test sets. So one of the ways we test it, in fact, is we set up this structure. When you move to more difficult data sets-- not MNIST, but the CIFAR-10 data set, for example-- you really need convolutional layers in the early stages. But convolutional layers can't be learned via this burst multiplexing.
But in order to test whether or not we're communicating the gradients appropriately, what we do is-- so we have this structure of neural network. We've got an input, a set of convolution and pooling layers, and then some of our burstprop layers.
And what we do is, we estimate the gradients using our burstprop algorithm. And then, the difference in the burst probabilities that we get at the bottom layer of this component of the network is what we feed into the chain-rule calculations for the convolutional layers at the bottom.
So if we're actually doing a decent job of estimating the gradients, which is what our comparison to backprop suggests we are, then the convolutional layers should be receiving the information they need to actually do their gradient updates. And, in fact, they do seem to be, because learning works well.
So let's maybe just jump to the larger networks. With four hidden layers, we can crush MNIST down to around 0.9% error on the test set. With five hidden layers, we can get down to around 33%, 34% error on CIFAR-10. And that's about what you get with back propagation with a similar network architecture as well.
So now we would need to push this to even harder data sets, ideally to ImageNet, or even more ideally to non-image categorization data sets, god forbid. We'll get there, but there are difficult compute issues when you start pushing into more interesting stuff. This will be helped by my move to [INAUDIBLE].
AUDIENCE: [INAUDIBLE]
BLAKE RICHARDS: Yeah.
AUDIENCE: [INAUDIBLE] I guess, what is called cyan color, right? Five hidden layers--
BLAKE RICHARDS: Yes.
AUDIENCE: [INAUDIBLE]
BLAKE RICHARDS: Yes.
AUDIENCE: [INAUDIBLE] backprop at all.
BLAKE RICHARDS: Well, no, it sort of does in the convolution layers, because we pass our estimate of the gradient to those layers. And then they do backprop.
AUDIENCE: [INAUDIBLE]
This one.
BLAKE RICHARDS: The red one? The red one is where it's pure backprop.
AUDIENCE: [INAUDIBLE]
BLAKE RICHARDS: Through all the layers, yeah. You got it. OK.
AUDIENCE: I'm sorry. Why are these numbers so high? [INAUDIBLE] so many convolutional layers. And they do backprop through all the layers. [INAUDIBLE] you can get like [? 16%. ?]
BLAKE RICHARDS: Well, it's also the number of neurons that we're using. And ideally, you would go even farther. Like, if we used seven hidden layers on CIFAR-10 and had more neurons than we're using, we could probably get this down further. But that's about what you get.
AUDIENCE: [INAUDIBLE]
BLAKE RICHARDS: And we're not doing any data augmentation.
AUDIENCE: [INAUDIBLE] baseline that they were trying to fight with dropout was below 20%, as far as I remember.
BLAKE RICHARDS: Yeah. But they're optimizing a whole set of their architecture that we're not. We've done almost no architectural optimization here, to be clear.
AUDIENCE: But how many [INAUDIBLE] per layer do you have?
BLAKE RICHARDS: If I recall correctly, I think once we get up to here, it's like 100, 200. And the number of filters here, [SIGHS] I'd have to ask my student again. But it's not a lot, which is part of the reason we're only getting down to 33%.
AUDIENCE: [INAUDIBLE] like 64 or [INAUDIBLE]
BLAKE RICHARDS: Something on the order of a couple dozen, a few dozen, yeah.
AUDIENCE: [INAUDIBLE]
BLAKE RICHARDS: I mean, I don't know what to tell you. I have trained convolutional neural networks on CIFAR-10 with these types of architectures. And this is what you get out of backprop if you don't do much work to get it down.
OK. So I don't have very much time here. But I'm just going to briefly touch on the question of learning the feedback weights. So we've been relying on this feedback alignment effect.
So one of the things that my friend, Tim [INAUDIBLE] showed a few years ago-- and I take it he came and presented here in Boston recently on some of the updates to this-- is that if you have a neural network where you replace the transpose of your feedforward weights with a random fixed feedback matrix, then learning actually seems to proceed well enough. And it's because the feedforward weights align themselves to the feedback weights such that you end up actually approximating the gradient in the way that I showed you earlier.
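Here is a toy demonstration of the feedback alignment effect in about the simplest setting possible: a linear two-layer network trained on a random linear teacher, where the backward pass uses a fixed random matrix `B` in place of the transpose of the forward weights. Everything here (sizes, scales, learning rate) is an illustrative assumption, not the network from the talk:

```python
import numpy as np

rng = np.random.default_rng(3)

# Toy teacher: y = W_true @ x
n_in, n_hid, n_out = 10, 20, 5
W_true = rng.normal(size=(n_out, n_in))

W1 = rng.normal(scale=0.1, size=(n_hid, n_in))
W2 = rng.normal(scale=0.1, size=(n_out, n_hid))
B = rng.normal(scale=0.1, size=(n_hid, n_out))  # fixed random feedback matrix

lr = 0.01
losses = []
for step in range(2000):
    x = rng.normal(size=n_in)
    y = W_true @ x
    h = W1 @ x                     # linear hidden layer, for simplicity
    y_hat = W2 @ h
    err = y_hat - y
    losses.append(float(err @ err))
    # Backprop would use W2.T @ err; feedback alignment uses B @ err instead
    W2 -= lr * np.outer(err, h)
    W1 -= lr * np.outer(B @ err, x)

assert np.mean(losses[-100:]) < np.mean(losses[:100])
```

Even though `B` is random and never changes, the loss still falls, because the forward weights come to align with the feedback weights over training.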
The difficulty is that we've found that this doesn't work well when you scale up. So even on CIFAR-10, you can kind of get it. But once you push it to like [? ImageNet, ?] it just doesn't work.
So we find that with feedback alignment, or even alternatives like difference target propagation that try to learn something on the backwards weights, we get nowhere near what we get with back propagation on a ConvNet. So this feedback alignment effect doesn't really solve it altogether for you. But there have been a few papers that have come out recently showing that even if the weights aren't perfectly symmetric, you can actually do fairly well if the signs of the weights are aligned.
So really, all you need is a system that's going to learn to approximate the signs of the weights for you. And if you can do that, you should be good in your gradient estimates.
So we've been thinking about how to try to potentially learn these. And what I've been thinking about, kind of inspired by some of the older STDP work, is thinking about this in a causal perspective.
So let's think about it this way. Let's say I'm neuron A here. And I need to somehow have feedback synapses that match my feedforward synapses.
Another way of phrasing that is that, if I have a positive causal impact on B, then I want B to have a positive causal impact on me. And if I have a negative causal impact on C, I want C to have a negative causal impact on me.
So we can view this as a sort of causal inference problem. And a paper came out last year from my friend Konrad Kording's lab that showed that one of the things you can do with spikes is use them to learn causal relationships. And he did this with a trick he borrowed from economics called regression discontinuity design.
So let me give you an example of that. Let's imagine students taking a first-year neuroscience exam; they pass if they get a score of 50% or higher. Now, the question is, what is the causal effect of passing that exam on the student's final GPA? Like, how does it somehow impact their later studies?
Now, if we just look at the relationship between exam score and final GPA, presumably we're going to see a positive correlation between those two things. And as a result, if we do the obvious thing to try to estimate the causal impact-- taking the average GPA of those who passed the exam versus those who didn't-- we would incorrectly infer a causal impact there when, in fact, there was none. It's just that there is a correlation between these two variables.
And economists are constantly faced with this kind of thing, right? If you give someone a particular loan or something like that, how does it impact their economic performance? And you've got all these other variables and all sorts of correlations you can't control.
So they came up with this method called regression discontinuity design, where what you do is you fit linear regressions across either side of this discontinuity that you're interested in. And then you look at the difference between those linear regressions. And if there's actually a difference between those, then you've got some evidence of a causal impact.
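To make the exam example concrete, here is a small simulation with entirely made-up numbers in which passing has no causal effect on GPA: a naive passers-versus-failers comparison looks like a large effect, while a regression discontinuity estimate at the cutoff comes out near zero:

```python
import numpy as np

rng = np.random.default_rng(4)

# Simulated students: ability drives both exam score and GPA; passing the
# exam (score >= 50) has NO causal effect on GPA in this toy world
n = 2000
ability = rng.normal(size=n)
score = 50 + 15 * ability + rng.normal(scale=5, size=n)
gpa = 3.0 + 0.4 * ability + rng.normal(scale=0.2, size=n)

passed = score >= 50

# Naive estimate: difference in mean GPA between passers and failers
naive = gpa[passed].mean() - gpa[~passed].mean()

# RDD estimate: fit a line on each side of the cutoff, compare them at 50
window = np.abs(score - 50) < 10
left = window & ~passed
right = window & passed
fit_l = np.polyfit(score[left], gpa[left], 1)
fit_r = np.polyfit(score[right], gpa[right], 1)
rdd = np.polyval(fit_r, 50.0) - np.polyval(fit_l, 50.0)

# The naive estimate is badly confounded; the RDD estimate is much smaller
assert abs(rdd) < abs(naive)
```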
So inspired by their proposal-- because they basically said, look, you can do the same thing with spikes. Spikes are a discontinuity. So neurons could learn about their causal impact on things by using information about what happens when they did or didn't spike. So we took Konrad and Ben's idea. And we've been applying it to learning the feedback pathways in our networks.
I've got to finish up. So I'm going to go a little bit quickly through this. Basically, in the first pass, we've got two neurons. These are leaky integrate-and-fire neurons. If they pass threshold, they fire a spike. And we've got a feedforward synapse that determines neuron 1's impact on neuron 2, and then a feedback synapse.
And what we're going to do now, is we're going to look at what happens-- so neuron 1, to go back here, is going to try to estimate its causal impact on neuron 2 by doing a regression discontinuity estimate on the postsynaptic potentials in its apical dendrite. So it's basically going to say, what do the postsynaptic potentials from this higher order neuron look like when I spike and when I don't spike?
So here we've got the voltage of neuron 2, the voltage of neuron 1. And basically, whenever the neuron gets close to its threshold or just passes it, then we've got a chance to look at the correlation between the postsynaptic potential in the apical dendrite and the voltage in neuron 1.
And that's going to be how we do our regression discontinuity design. And we're going to use that to estimate the causal effect beta of the input neuron on the output neuron.
I don't have time to go through the math. But basically, we're just fitting a piecewise linear regression. And we're updating the synapse using the information from those piecewise linear regressions. And we're fitting those piecewise linear regressions with gradient descent.
Anyway, long story short-- here is the feedforward synapse. Here is what is effectively the feedback synapse, the estimate of the causal impact. And it does a pretty good job of estimating its causal impact on the downstream neuron.
And it does so all in real time, just using spike information and just using local information. And this is after about 100 seconds of simulated running of the neurons. And this even works when you've got correlated inputs to the neurons up to a certain degree of correlation. If they're correlated 100%, it doesn't work. But even up to like 0.9 correlation, it works. [CHUCKLES]
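Here is a stripped-down, offline toy of that idea (the actual model fits the piecewise regressions online, with gradient descent, on spike trains): neuron 1 spikes when its drive crosses threshold, the downstream response mixes a confound with a true causal kick per spike, and a regression discontinuity at the threshold recovers the causal weight. All numbers here are illustrative assumptions:

```python
import numpy as np

rng = np.random.default_rng(5)

# Neuron 1's drive u determines whether it spikes (u >= 1.0). Neuron 2's
# response depends on a confound (correlated with u) plus a true causal
# kick w_true per spike of neuron 1.
n = 5000
u = rng.normal(loc=1.0, scale=0.5, size=n)       # neuron 1 drive per trial
spiked = u >= 1.0
w_true = 0.7
response = 0.5 * u + w_true * spiked + rng.normal(scale=0.1, size=n)

# Only trials where the drive lands near threshold inform the estimate
near = np.abs(u - 1.0) < 0.2
lo = near & ~spiked
hi = near & spiked
fit_lo = np.polyfit(u[lo], response[lo], 1)
fit_hi = np.polyfit(u[hi], response[hi], 1)

# Causal-effect estimate: the jump between the two fits at threshold
w_est = np.polyval(fit_hi, 1.0) - np.polyval(fit_lo, 1.0)
assert abs(w_est - w_true) < 0.1
```

The confounded term `0.5 * u` contributes equally to both fits, so the jump at threshold isolates the causal weight, which is the same logic the neuron would use to set its feedback synapse.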
Now, the other question we had is, can this scale to thousands of neurons? So here we run a simulation where we've got 1,000 input neurons and then one output neuron. And we're looking at learning these feedback synapses that way.
And it doesn't do as well, but it still manages to kind of capture it. This is with average firing rates of around 10 hertz, and that's probably part of the reason-- we're only running it for 100 seconds, so it's getting some data points, but not a huge amount.
But one of the interesting things is, obviously, in our previous model, we were particularly interested in the ensemble-level models. And there, when we look at sort of ensembles of the input neurons, the actual estimate of the sign of the weights is quite good.
So that's potentially one way of getting around this noise problem. And basically, our next step is to try to implement this in our burstprop model and see if we can reduce the bias in our gradient estimates using this mechanism.
Another thing that we'd like to look at is using this for direct feedback, where the feedback pathways skip multiple layers. Because we think it might even work in that situation. But that's another conversation.
So let me finish up. Roughly, I've shown you this third one. Maybe that should be sketchy-- not quite crossed out. But to conclude, I think it's safe to say that many species, especially humans, can learn many things surprisingly efficiently.
And I always find it a kind of funny discussion where people say to me, well, but why would you think the brain does gradient descent? Because neural networks are really poor at learning. Well, you know what's even worse at learning? Things that don't have an estimate of the gradient.
So I think it's reasonable to assume, until proven otherwise, that our brains might be doing something like a gradient calculation, especially because any species that evolved the ability to estimate loss gradients would probably outcompete other species, because they'd be just that much better at learning. And if it's a biologically plausible thing to do, then at least it's worth exploring experimentally.
So the next steps are to explore this experimentally. I actually think some of the work that's been done on comparing deep neural networks to the real brain, including by people here at MIT, is most interesting because it suggests that the brain is, indeed, doing end-to-end optimization.
That, for me, is actually the take-home message-- is that, if you do a better job of approximating the representational geometry in the brain with a neural network trained on a particular loss function in an end-to-end fashion, it suggests that the brain at least has a loss function that, in some way, cares about similar things as the loss function you've trained that neural network on, and that the network in the brain has been trained in an end-to-end fashion as well, using gradient information.
So our work provides a sort of proof of principle as to how loss gradients may be estimated using the unique properties of pyramidal neurons. And now we want to start generating some predictions with this model.
I'll give you just one taste of a prediction. We're working on a whole host of them. But one of the predictions that comes out of the burstprop model is a relationship between burst variance and error.
So what we should see is that, as the error increases, there should be greater variance in the burst probabilities. And across training, the variance in the burst probabilities should decrease across all layers of a network.
So if you look in the brain, what we would predict is that when an animal is doing some task and it makes an error, you would see increased variance in the burstiness of the neurons across different regions of the brain. Maybe it's bullshit. But at least it's a concrete prediction.
And we're going to try to generate some more concrete predictions like this going forward with this model, because I think the nice thing is, unlike standard deep neural networks, we have a very clear biophysical mapping to what's going on in the brain. We're making specific postulates about physiological sites for these calculations. And that's why we can generate clear predictions with it.
OK. So in other words, I think we should kind of try to do the top-down and bottom-up thing in science simultaneously, as it were. Thanks very much for listening.
[APPLAUSE]