Linear Analysis of RNN Dynamics
Date Posted: November 23, 2020
Date Recorded: November 19, 2020
Speaker(s): Eli Pollock
Description:
Recurrent neural networks (RNNs) are a powerful model for neural and cognitive phenomena. However, interpreting these models can be a challenge. In this tutorial, we will discuss how dynamical systems theory provides some tools for understanding RNNs. In particular, we will focus on the theory and application of linearizing RNN dynamics around fixed points. We will then look at some computational tools that have been developed around this framework.
Speaker Bio: Eli Pollock is a fifth year BCS PhD student in the Jazayeri lab. His research focuses on using RNNs to model cognitive processes.
NHAT LE: Today, we're very delighted to have Eli Pollock to tell us more about analysis of RNN dynamics. Eli is a fifth-year graduate student, currently in the Jazayeri Lab. And he's done some very amazing work on developing toolboxes and analysis techniques to help us understand how RNNs solve cognitive tasks.
And being a fellow graduate student, I can tell you that it's always a great pleasure to hear Eli explain these concepts in a very accessible and easy to understand way. And I hope that you can also find this session helpful and informative. So Eli, thanks so much for joining us today.
ELI POLLOCK: Yeah, thank you, Nhat, for that very kind introduction. Hello, everyone. As Nhat said, I'm Eli. So the topic of today's tutorial is going to be linear analysis of RNN dynamics.
So this is going to be broken up into a few parts. And for the sake of accessibility, I want to assume as little mathematical background as possible. And so there might be some people who already know a lot of the things in this talk. Feel free to tune out for those parts.
I'm going to be going through some fairly basic dynamical systems theory in the beginning, which relies a little bit on knowing somewhat basic calculus, taking derivatives of things like that. But these tools are really powerful, and they will help us a lot later on in the talk. I'm then going to get kind of into the meat of the talk, talking about recurrent neural networks, which are a pretty popular tool for modeling neural circuits these days. And I'm going to be talking about how we can apply the tools that we've talked about up here to RNNs.
I'm then going to introduce and go through two techniques that I think are pretty interesting that use these ideas. One I cannot take credit for. It's called FixedPointFinder.
It is by Matt Golub, who is at Stanford. And so we're going to just walk through the GitHub README of that, just to get a sense of what you could do with this tool. But since I didn't make the code, I don't really feel all that comfortable giving a very detailed tutorial of it.
I'm then going to end the talk, depending on how much time we have, with an approach that I have worked on and some code that I did write. And there will be a Google Colab notebook for people to follow along at the end. And I'll go through how we can flip a lot of the ideas that we've explored on their head: instead of going from RNNs to linearized dynamics, we'll go directly from linearized dynamics to RNNs. OK, so that's the overall summary of the talk.
So for the first part, our goal is going to be understanding the following sentence. Linearize around the fixed point by taking the eigenvalues of the Jacobian. Now, if you have no idea about dynamical systems theory, this might be a very confusing sentence. I'm going to be mostly borrowing directly from this book Nonlinear Dynamics and Chaos by Steven Strogatz, which I highly recommend to anybody who wants to get a more thorough introduction to the ideas that I'm going to be talking about today.
All right, so first of all, starting from the very basics, what is dynamical systems theory? The big idea here is that the change in state of a system depends on its current state and inputs. Put mathematically, that means we're dealing with equations of the general form dx/dt = f(x(t), u(t)): the state's change over time depends on some function of its current state x at a given point in time, and on some inputs u, which can be time-varying.
As a simple example, we can look at some one-dimensional nonlinear dynamics. Let's say that we have a system x dot = sin(x), where x dot is just shorthand for dx/dt. So what is the solution to a system like this?
Well, this is a nonlinear function of x. The solution is actually analytic-- you can work it out in closed form-- but it's pretty complicated. It ends up looking like this, where you get a pretty complicated relationship between the state of the system, its initial conditions, and time.
A much easier way to look at it is this graph right here, where we basically just have a plot of the sine function: x on the x-axis, and x dot on the y-axis.
And this reveals some really interesting features of the dynamics, where at intervals of pi along the x-axis, we have what are called fixed points. Some of these are stable fixed points: wherever x dot is negative, we're moving left, and wherever it's positive, we're moving right. So these stable fixed points are values of x where, if you let the system evolve from nearby initial conditions, it will fall into these points.
Likewise, we have unstable fixed points, where if you're perturbed a little bit away from one, you're going to fall away from it and into a stable fixed point. But if you're exactly on that unstable point, you're not going to move at all, because the value of x dot, the change in state, is 0. And this overall plot is called a phase portrait. It's going to be a pretty useful tool for understanding pretty much any dynamical system that has few enough dimensions that we can visualize it.
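To make the phase portrait idea concrete, here is a minimal numpy sketch (an added illustration, not code from the talk): Euler-integrating x dot = sin(x) from a few initial conditions shows every trajectory settling into the nearest stable fixed point, at odd multiples of pi.

```python
import numpy as np

def f(x):
    # 1D dynamics: dx/dt = sin(x)
    return np.sin(x)

# Euler-integrate several initial conditions at once
x = np.array([-2.0, 0.5, 2.0, 5.0, 7.0])
dt = 0.01
for _ in range(5000):  # 50 time units
    x = x + dt * f(x)

# Each trajectory lands on a stable fixed point (an odd multiple of pi)
print(np.round(x / np.pi, 3))  # -> [-1.  1.  1.  1.  3.]
```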
All right, so now we're going to get to the first part of that sentence that I mentioned, linearizing around the fixed point. So we've introduced what a fixed point is. Now we're going to talk about what it means to linearize around it.
So let's define this term eta, which is a perturbation around a fixed point x star: eta = x - x star. If x star represents one of these stable fixed points, we're just saying that this perturbation is the difference between the state of the system and the state at that fixed point. We can then take the time derivative of eta, and find that it's just equal to x dot.
And then we can do a pretty simple calculus trick, called the Taylor expansion, where if we're interested in the behavior of that dynamical function f around the fixed point, we can expand it as f(x star + eta) = f(x star) + eta f'(x star) plus some higher-order terms, which we're going to ignore. The value of our dynamical function at the fixed point is, by definition, 0.
So what we end up with is that eta dot is approximately equal to eta times f prime evaluated at the fixed point. What that means, basically, is that we're taking a small perturbation and saying that it's close enough to the fixed point that we can ignore the higher-order terms. We can pretend that we've created a linear system where x dot, or in this case eta dot, depends on eta scaled by some factor, and that factor, f prime evaluated at the fixed point, is just a constant number.
So what we're left with is something of the form, just using a different variable, dy/dt = ay, where a is a constant. And if you've done any differential equations whatsoever, you probably know that this has an exponential solution: y(t) = y(0) e^(at). So what it means is that if this constant a, which is basically just f prime at the fixed point, is greater than 0, we're going to see an exponential increase in the value of the perturbation, meaning that we have an unstable fixed point. And if it is less than 0, we're going to have a stable fixed point, because the perturbation is going to decrease exponentially. So we've covered that.
And going back to that 1D example, we can understand this just in terms of what we've already talked about. If the slope f prime of x evaluated at the fixed point is negative, we're always going to fall into that stable fixed point-- perturbations will decrease exponentially. And if there's a positive slope, then we have divergence from the fixed point, where any perturbation, any tiny little poke that you make to the state, is going to increase exponentially, at least right around the fixed point.
Once you move away from it, then the Taylor expansion stops working. So it's not going to increase exponentially forever. Eventually, it's just going to fall into this other fixed point. But the perturbation immediately around this will increase exponentially. OK, any questions about this pretty basic idea about just what it means to linearize around a fixed point?
AUDIENCE: I'm curious, like, when you do this procedure, what happens to the rest of the space? So you can think about it as being, like, linear around the fixed points. But as you move further out, before you get to another fixed point, is it also linearized, just not like [INAUDIBLE]?
ELI POLLOCK: Yeah, so this only holds up right around the fixed point, as I mentioned. And that's because of the nature of this Taylor expansion. Taylor expansions really only work for very small perturbations.
If you start talking about Taylor expansions at other values, this isn't going to work as well, mostly because this term, we can't just set it equal to 0 anymore. It's going to be equal to something else. So you end up with very different behavior.
I guess you could technically-- as long as the dynamics aren't too fast, you could look at [? the ?] instantaneous slope at other points and linearize there, and say like, all right, a decent approximation for this function right here is that it's behaving exponentially with this slope. But it works best around fixed points. Does that answer the question?
AUDIENCE: Yeah, I think so. I guess just maybe more directly, does this procedure change, like, the dynamics of the system away from the fixed point?
ELI POLLOCK: So if you're just looking at the linearized dynamics, then yeah, it's absolutely not going to work, because you're just describing an exponential function. If you have an unstable fixed point, what you're predicting is exponential growth that goes on forever-- you'll just keep accelerating along that direction. So yeah, it's not going to capture what happens over here.
AUDIENCE: OK, thank you.
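Before moving up in dimension, here is a small numerical check of the linearization idea (again an added sketch, reusing the x dot = sin(x) example): the slope a = f'(x*) = cos(x*) predicts exponential growth or decay of a small perturbation, and simulating the full nonlinear system agrees closely.

```python
import numpy as np

f = np.sin        # dx/dt = sin(x)
f_prime = np.cos  # its derivative

for x_star in [0.0, np.pi]:        # an unstable and a stable fixed point
    a = f_prime(x_star)            # slope at the fixed point
    eta0 = 1e-3                    # small initial perturbation
    dt, T = 0.001, 1.0
    x = x_star + eta0
    for _ in range(int(T / dt)):   # integrate the full nonlinear system
        x = x + dt * f(x)
    linear = eta0 * np.exp(a * T)  # linear prediction: eta(T) = eta0 * e^(aT)
    print(f"x*={x_star:.2f}  a={a:+.0f}  "
          f"nonlinear={x - x_star:.6f}  linear={linear:.6f}")
```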
ELI POLLOCK: Yep. So now I want to talk about-- we're going to move up in dimensionality. We've mastered one-dimensional nonlinear flows. There's not really anything else you can do in one dimension. So you're either exponentially falling into fixed points or being perturbed away from them.
For two-dimensional flows, you can get some more interesting behavior. We're going to start with linear systems, meaning that we have something of the form x dot = ax + by, and y dot = cx + dy, where a, b, c, and d are just constants. And x and y are now functions of each other.
So the nice thing about linear systems is that they can be described using matrix multiplication and linear algebra more generally: if we take A to be the matrix [[a, b], [c, d]] and x to be the vector (x, y), we get this very nice expression x dot = Ax, where if you multiply out the matrices, you end up exactly with those equations. So it's just a very nice, simple way to put things. And the behavior of these kinds of systems is pretty well-behaved, and you can describe it using some really interesting tools, which we're going to get into in a second.
So for now, let's just think about an example where this A matrix only has values on the diagonal-- say a in the top left and negative 1 in the bottom right. There are no crosstalk terms here: x dot is solely a function of x, and y dot is solely a function of y.
So if we do the matrix multiplication, the two equations are totally uncoupled. And once again, we have these very nice exponential solutions: x(t) is an initial value times e^(at), and y(t) is an initial value times e^(-t).
So the types of behavior that you can have here are limited to these different portraits, which is still a pretty interesting range of behavior. If a is less than negative 1, that means you have decay along both axes, but it's stronger along the x-axis. So all the states are going to converge first onto the x = 0 line, and then fall into the fixed point in the middle.
If a equals negative 1, then the amount of decay is equal in all directions. So no matter where you start, you're just going to move along a radial line into the fixed point. And conversely, if a is less than 0 but greater than negative 1, the decay is stronger in the y direction.
Now, we start to have very different behavior from decaying into a fixed point once a is equal to 0 or greater. If a equals 0, then nothing is happening in the x direction-- all of the movement is decay in the y direction. So we have, basically, a line of fixed points on the x-axis.
And finally, when a is greater than 0, we get this interesting behavior here, which is what we call a saddle point. Trajectories that start exactly on the y-axis are going to decay into this point, but any perturbation off of it is going to move exponentially in the x direction. And so eventually, the long-term behavior is that you're going to end up on the x-axis zooming out toward infinity.
So you can get an interesting range of dynamics here. But note the fact that, in all of these cases, if you're on the x-axis or on the y-axis, you're always going to be moving in a straight line either towards that fixed point or away from it. But either way, you have straight line trajectories.
Now, this is really important to what's about to come. We might want to ask, like, OK, so in this system where we only have this very nice diagonal structure of the dynamical matrix, the straight line trajectories are on the x and y-axis because these are decoupled. But what if we now have different values? What if it's a, b, c, d? Can we figure out what the straight line trajectories are going to be, and use that to translate these behaviors into any two-dimensional linear dynamical system?
So for general cases, we want solutions of the form x(t) = e^(lambda t) v. There's some vector v here-- some direction that isn't the x-axis or the y-axis, but a linear combination of them-- and we want there to be exponential behavior along it.
Now, if we plug this x(t) into the equation x dot = Ax, we end up with this very nice equation, Av = lambda v, where we're basically saying we want something where applying the matrix A to one of these special straight-line axes just scales our position along that axis. And if you have ever taken a linear algebra course, you might recognize exactly what this is-- the words I'm about to use are eigenvalues and eigenvectors, where these special axes are eigenvectors and these scaling values are eigenvalues. So basically, if we take the eigenvectors of A, meaning vectors for which applying the matrix just scales your position along that vector, the eigenvectors give the directions of straight-line trajectories, and the eigenvalues tell us how we're moving along them.
So remember, if an eigenvalue is negative, we have decay, and if it's positive, we have exponential growth. And eventually, you end up with solutions of the form x(t) = c1 e^(lambda1 t) v1 + c2 e^(lambda2 t) v2-- combinations of exponential growth or decay along the two eigenvector axes.
Now, if that was at all confusing, I think that this figure really shows what I mean by talking about all that, where now we can describe-- let's say we have some matrix A, and we take the eigenvectors of it, we might get that this is one eigenvector and this is another eigenvector. And they're pointing in two different directions. And maybe the corresponding eigenvalue for this direction is positive, meaning that we see this exponential growth. And the eigenvalue corresponding to this eigenvector is going to be negative.
So we're going to decay along that direction, and we're going to have this saddle point right here. So basically, under this framework, we can capture all of those different kinds of behavior. But instead of having everything centered around the x and y axes, everything is now going to be centered around the eigenvectors of the dynamical matrix.
So now we've filled in the second part of that sentence. We can linearize around the fixed point. And now we understand what I mean when I say eigenvalues.
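Numerically, this whole analysis is a single eigendecomposition (a sketch with an arbitrary coupled matrix, chosen for illustration):

```python
import numpy as np

# An arbitrary coupled 2D linear system: x_dot = A x
A = np.array([[1.0, 2.0],
              [2.0, 1.0]])

eigvals, eigvecs = np.linalg.eig(A)  # columns of eigvecs are eigenvectors
print(eigvals)  # 3 and -1: one growing and one decaying direction -> a saddle

# Starting on the decaying eigenvector gives a straight-line trajectory
i = np.argmin(eigvals)               # index of the negative eigenvalue
x = 0.5 * eigvecs[:, i]
dt = 0.01
for _ in range(1000):                # Euler steps of x_dot = A x
    x = x + dt * (A @ x)
print(x)  # still on the eigenvector, decayed by ~e^(-10) toward the origin
```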
Oh yeah, and one other thing. So this can be tricky to think about. But if you end up taking-- so one thing about eigenvalues and eigenvectors is that they can be complex, meaning that they have imaginary components-- things times the square root of negative 1.
If you end up with complex eigenvalues, that means that things are rotating, basically. If the eigenvalues are purely imaginary, you just get rotation around the fixed point. Or let's say you have a complex eigenvalue whose real part is negative-- meaning there's rotation, but also exponential decay into the fixed point-- then you can have spiral behavior like that.
All right, so that covers most of the story for linear dynamical systems in two dimensions. You can go up to three, four, or five dimensions, but the same results pretty much hold: you can always describe the system's behavior as a decomposition of exponential growth and decay along eigenvectors. So let's now talk--
AUDIENCE: Can I ask a quick question?
ELI POLLOCK: Yeah, sure.
AUDIENCE: Sorry, Eli. I was just wondering, maybe I missed it, but how many-- could you remind me how many eigenvalues and eigenvectors you should expect for these things? Like, do you ever have situations where things are singular?
ELI POLLOCK: Yeah, so in cases like this, there would be-- let me think. You do sometimes end up with degenerate cases, and those can describe systems where you just have rotation without anything happening in other dimensions.
And I think you can also end up with behavior like this, where you can describe it with just one value. But generally, you would expect the number of eigenvalue-eigenvector pairs to be equal to the number of dimensions of the system. So for a two-dimensional system, you're in most cases going to have two.
AUDIENCE: I think I want to just add that, even if you have a degenerate eigenvalue, you're going to potentially have multiple eigenvectors. And you can use those-- there's going to be some sort of special behavior, where it loops around. It's weird-looking.
ELI POLLOCK: Exactly, yeah. Yeah, feel free to also correct me if I make any mistakes here. OK.
AUDIENCE: Thanks.
ELI POLLOCK: Yeah, so now we've covered what you can do with these tools of linear analysis. Let's now talk about nonlinear systems, and how we can maybe use the tools of linear analysis in them. So a nonlinear system might look like this, where we have now a bunch of really crazy stuff going on.
In some places, we have something that looks like a saddle point, maybe another one over here. And here we have, like, an orbit. A lot of crazy stuff going on.
So let's talk again about this idea of linearization that we introduced earlier, and see if we can extend it to higher dimensions. So let's say now that we have some nonlinear system where dx/dt, x dot, is equal to some function f of both x and y, and dy/dt is a function g of x and y. At a fixed point, what happens if we perturb it? Instead of the eta that we had earlier, we're now going to have a perturbation u around x and a perturbation v around y.
So again, we can run through very similar steps as last time: if we take u dot, we find that it's equal to x dot, we can plug that into f, and then we can do, once again, the Taylor expansion that we had earlier.
Except now it's a little bit more complicated, because we have to account for both stuff happening with u and stuff happening with v. So we end up with an equation where u dot depends on u scaled by the partial derivative of f with respect to x evaluated at the fixed point, plus v scaled by the partial derivative of f with respect to y evaluated at the fixed point. And if we do the same thing for v dot, we get a very similar equation with g.
Now, this looks complicated because we have all these partial derivatives everywhere. But fortunately, once again, we can take the linear algebra approach: we can rearrange those equations, ignore the quadratic terms, and get this very nice form, where the vector (u dot, v dot) equals the matrix of partial derivatives [[∂f/∂x, ∂f/∂y], [∂g/∂x, ∂g/∂y]], evaluated at the fixed point, times the vector (u, v). We're now describing the behavior of these perturbations around the fixed point in terms of these partial derivatives.
Now, what this means is that we basically have a linear system, and this matrix of partial derivatives is our linear dynamical matrix. This matrix, which we'll call A, is called the Jacobian matrix. And that might sound a little bit familiar, because it's the last part of that sentence from earlier.
So if we want to evaluate this linear system, we can use the tools from before and take the eigenvalues and eigenvectors of it. That gives us the directions where we can expect exponential decay and growth. So we've now completed the idea of linearizing around the fixed point by taking the eigenvalues of the Jacobian. Even if it's impossible to get analytical solutions for the whole system, we can now understand what's happening locally around fixed points.
So going back to this crazy example, what might we expect the eigenvalues and eigenvectors of the Jacobian to be at points A, B, and C? At point A, we would expect one eigenvector along this direction, roughly the x-axis, with a positive eigenvalue, and another one here with a negative eigenvalue. Similarly for C, these would be the two eigenvectors. And at B, we might expect one component that is complex, producing these rotations, and another that is real but positive, so the spiral grows exponentially away from the fixed point.
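As an added worked example (a standard damped pendulum rather than the system in the slide), you can compute the Jacobian at each fixed point by finite differences and read the local behavior straight off the eigenvalues:

```python
import numpy as np

# x_dot = f(x, y) = y;  y_dot = g(x, y) = -sin(x) - 0.2*y  (damped pendulum)
def F(x, y):
    return np.array([y, -np.sin(x) - 0.2 * y])

def jacobian(x, y, h=1e-6):
    # Finite-difference Jacobian of F at (x, y)
    J = np.zeros((2, 2))
    J[:, 0] = (F(x + h, y) - F(x - h, y)) / (2 * h)
    J[:, 1] = (F(x, y + h) - F(x, y - h)) / (2 * h)
    return J

for x_star, y_star in [(0.0, 0.0), (np.pi, 0.0)]:  # the two fixed points
    print((x_star, y_star), np.linalg.eigvals(jacobian(x_star, y_star)))
# (0, 0):  ~ -0.1 +/- 0.995i -> complex pair, negative real part: decaying spiral
# (pi, 0): ~ +0.905, -1.105  -> opposite-sign real eigenvalues: a saddle
```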
OK, once again pausing for questions. This concludes our brief introduction to dynamical systems. So if there's any questions, feel free to ask now. Once again, I highly recommend that book by Steven Strogatz if you're interested in reading more about this stuff.
AUDIENCE: I have a quick question. You were saying that I linearize the system and compute the eigenvalues, which, in the complex case-- you have to switch to some complex space, right? It doesn't quite make sense to consider the diagonalized matrix with complex values.
But I assume you're just looking at the eigenvalues, and then you want to infer some information about the system from that, right? So you're not really using the linear map now to reconstruct your system? Or how would that work? Because otherwise, now you're switching to t complex vector space, not a real vector space anymore.
ELI POLLOCK: Right, so then it's going to be treated differently. Instead of having like growth or decay, complex eigenvalues are going to describe the frequency of oscillations around a fixed point, for example. So the information contained in the eigenvalues is still useful for describing the behavior of [? the ?] system.
AUDIENCE: Yeah, I mean, what I'm saying is, if you apply a matrix with complex values, then you end up in a different space, unless you do something to project back to real space, right? So you switch from R^n-- if I have complex values, then I switch to C^n. But I assume you mean that there's some information you can extract from that, right?
ELI POLLOCK: Yeah. I mean, I guess this stuff is really useful mostly for building a qualitative description of what's going on. Remember, we're not at this point talking specifically about analytic solutions. We're talking about approximate descriptions of the system locally. So yeah, if you were to do the math, you would end up with complicated things happening in the complex space. But if we're just going for a qualitative description, the eigenvalues give us, like I said, the frequency.
AUDIENCE: I don't think you'd ever get complex valued matrices, right? Because the functions themselves would be [? r to r, ?] and the matrix would be a real matrix. You could just have--
AUDIENCE: The matrix is real, sure. That's just the Jacobian. But the question is, if you compute the eigenvalues, they might be complex.
You're basically solving-- I mean, the eigenvalues [INAUDIBLE] basically the zeros of a polynomial. So you get complex values there. So all I'm saying is you're looking at those values, and then that tells you something about the nature of your fixed points. But anyway, yeah.
ELI POLLOCK: Yeah, I mean--
AUDIENCE: [INAUDIBLE] whatever. It's fine.
ELI POLLOCK: OK, yeah, I mean the complex values in the eigenvectors ensure that everything complex cancels [INAUDIBLE], and you end up with stuff happening in real space. But yeah, I'm going to [INAUDIBLE]. Yeah, I'll move on, unless there are any other questions about this stuff.
NHAT LE: There is a question in the chat, Eli, about how many fixed points does a linear system have. I assume--
ELI POLLOCK: Good question. So generally speaking, a two-dimensional linear system can only be described in terms of things that look like this. So it will always either have just one fixed point in the middle, or it will have infinitely many, all lying along one of the eigenvectors.
I'm trying to think if you have higher-dimensional systems, can you have multiple fixed points. I don't think so if we're just talking about linear dynamical systems. I think you generally have to go to nonlinear systems if you want to have more of them.
All right, so now we can get to the fun part. I mean, that first part was fun, too. But we're going to talk about recurrent neural networks viewed as dynamical systems.
So basically, the important idea here is that recurrent neural networks are just high-dimensional dynamical systems. That's just generally what they are. They are systems that have a state vector.
They might have some number of hidden units inside them, and those are going to change as a function of their current state and inputs in some way. I'm going to mostly be talking about ones that have this equation right here, where that f(x) dynamical function is given by dx/dt = -x + W phi(x) + Bu: the change in state depends on a membrane current leak, let's say, as well as a recurrent weight matrix W that maps some nonlinear function phi of all of the neural activity back onto each unit, plus an input.
Now, you could also have things that account for different kinds of ion channels and different excitatory or inhibitory neurons. You could go more in a machine learning direction and have LSTMs that have different terms here that introduce more complicated things. But no matter what you do, it's still going to be a dynamical system, and it can thus be studied as a nonlinear dynamical system.
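As a concrete reference point, a vanilla rate RNN of this form takes only a few lines to simulate (a sketch of the equation above; the network size, weight gain of 1.2, and constant input here are arbitrary choices, not values from the talk):

```python
import numpy as np

rng = np.random.default_rng(0)
N = 100                                      # hidden units
W = rng.normal(0, 1.2 / np.sqrt(N), (N, N))  # recurrent weight matrix
B = rng.normal(0, 1.0, (N, 1))               # input weights
phi = np.tanh                                # pointwise nonlinearity

def x_dot(x, u):
    # Continuous-time vanilla RNN: dx/dt = -x + W phi(x) + B u
    return -x + W @ phi(x) + B @ u

# Euler simulation from a random state with a constant scalar input
dt = 0.01
x = rng.normal(0, 1, N)
u = np.array([0.5])
traj = np.empty((2000, N))                   # (time, units) trajectory
for t in range(2000):
    x = x + dt * x_dot(x, u)
    traj[t] = x
```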
Another important thing about RNNs that I wanted to bring up here is that, with enough units, an RNN, depending on the exact kind of nonlinearity you're using, can approximate any dynamical system. What that means is that you can have a linear mapping from the neural activity to the state of an arbitrary dynamical system-- even a very nonlinear one. So that's like an extension of the universal approximation theorem for deep networks, where a sufficiently large network can approximate any function. All right--
AUDIENCE: [INAUDIBLE] [? here. ?] You're assuming just in terms of the universal approximation, do you [? need ?] [INAUDIBLE]? I mean, that's just a-- do you get what I'm saying?
ELI POLLOCK: No, could you be a-- I'm not [INAUDIBLE].
AUDIENCE: I mean, I get [? that ?] you can approximate any function from, let's say, x to hidden layer to x again, right? So that's basically what you would like to approximate. [? But ?] here, you're going from x to x without a hidden layer.
ELI POLLOCK: So there is-- I mean, the hidden layer is like the recurrent [INAUDIBLE]. I'm saying that you can accommodate any number of x(t) to x(t+1) mappings with enough units. All right, so how is it useful that we can describe them as a [INAUDIBLE]?
So they are very nonlinear. These dynamics are very non-linear. But if we can find fixed points, we can then linearize dynamics around them to understand the computational role of those fixed points if they're doing something to help an RNN solve a task.
So for example, one thing that we might look for is the idea of a line attractor. So this looks a lot like the stuff that we saw earlier with one type of linearized dynamical regime. There is evidence that in the brain, sometimes you have representations that exist on a line-- let's say like decision variables.
And you're integrating evidence along this line. And some evidence will push you towards one decision, and other evidence will push you towards another decision. But generally, you want to be confined to this line.
So again, we can ask, how would we find a structure like this? We can linearize around these fixed points and find that there is a very small, maybe 0, eigenvalue along the line, meaning that there aren't really any intrinsic dynamics in that direction, so inputs can push you along it. And then there are negative eigenvalues in all other directions, so that any perturbation away from the line will immediately decay, and you'll fall back onto it.
All right, so the first part, if we want to do this kind of analysis of RNN fixed points, is that we need to find the fixed points. Here I'm borrowing from the paper "Opening the Black Box: Low-Dimensional Dynamics in High-Dimensional Recurrent Neural Networks." It's a great paper describing how you can find fixed points and then linearize around them.
So in this particular paper, and in the tool that I'm going to be showing you in just a minute, the way they find these fixed points is that they first initialize at a bunch of points in network state space. They then define this function q(x) = 1/2 |f(x)|^2, which is analogous to a kinetic energy-- how fast is the network moving at that state-- where f is again just the function describing dx/dt.
We can then follow gradient descent on this, taking the partial derivatives of the q function with respect to x. Descending those gradients will lead to fixed points eventually, or at least to places where we have minima in the speed of trajectories through state space. And to compute this gradient, all we really need at each point is to evaluate x dot at that point and the Jacobian at that point. So once again, the Jacobian comes in handy.
All right, so that's how we're finding the fixed points: we're taking this function that describes how quickly the system is moving at various points, and we're following gradient descent to minimize it, finding places where, hopefully, f(x) = 0. And then we want to linearize the RNN around the fixed points, which, as we discussed earlier, requires taking the Jacobian matrix.
So given the equation for x dot-- instead of the matrix multiplication view from earlier, writing it as a sum of multiplications, which is the same thing-- if we take the Jacobian, we end up with a fairly simple expression: the (i, j) entry is minus delta_ij plus W_ij times phi prime evaluated at x_j. We can evaluate this at fixed points and find the eigenvectors and eigenvalues, and that tells us what happens to perturbations along different dimensions.
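Here is a compact sketch of that recipe on a small random network (my own illustration, using scipy's L-BFGS minimizer in place of hand-rolled gradient descent; the sizes and tolerance are arbitrary):

```python
import numpy as np
from scipy.optimize import minimize

rng = np.random.default_rng(1)
N = 20
W = rng.normal(0, 1.5 / np.sqrt(N), (N, N))

def f(x):
    # Autonomous RNN dynamics (no input): dx/dt = -x + W tanh(x)
    return -x + W @ np.tanh(x)

def q(x):
    # "Kinetic energy" q(x) = 1/2 ||f(x)||^2; its minima are slow or fixed points
    return 0.5 * np.sum(f(x) ** 2)

# Descend q from several random initial states
fps = []
for _ in range(10):
    res = minimize(q, rng.normal(0, 1, N), method="L-BFGS-B")
    if res.fun < 1e-8:              # keep only (near-)fixed points
        fps.append(res.x)

# Linearize around each one: J_ij = -delta_ij + W_ij * phi'(x_j)
for x_star in fps[:3]:
    J = -np.eye(N) + W * (1.0 / np.cosh(x_star) ** 2)  # tanh' = sech^2
    print(np.sort(np.linalg.eigvals(J).real)[-3:])     # leading eigenvalues
```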
So I went through that pretty quickly. So I'll describe an example that will hopefully clarify what I mean by that. So in this paper that I'm talking about this technique from, one experiment that they did was they trained an RNN to remember the sign of the last pulse it received in each of three different input channels.
So, like, if it received a positive pulse in channel 1, the corresponding output would stay at 1 until it receives a negative pulse, and then it goes back down. They trained an RNN to do this-- this is what they refer to as a 3-bit flip-flop task.
And what they did was visualize the dynamics in the first three principal components. They looked at all the different fixed points they could find, and then at what the eigendecompositions at each of those fixed points say about them. And they discovered this really interesting geometry of the representation that the network is using to solve the task, where basically, each of these corners is a stable fixed point.
The dark x's are stable fixed points that represent these areas where the network is remembering the state of the three bits. And then in between each one is an unstable fixed point. And every time a pulse is received, it knocks you in that direction.
And once you get over that, you go from the attractive dynamics around here to attractive dynamics towards this one. So again, they show that pulses result in movement along some direction. And then there's relaxation dynamics towards another fixed point that updates the state of the system in the way that you want.
Another example that they show is this input-dependent sine generator, where they have another network that they give different input values, and depending on the value of the input, it's supposed to generate sine waves of different frequencies. What they found was that there was basically a bunch of fixed points, each with an associated complex eigenvalue. And the rotation frequency implied by that eigenvalue corresponded almost perfectly with the frequency of the sine wave the network was supposed to generate at that input. So once again, they're discovering a computational role for the fixed points in this representation.
So now let's look at the tool I mentioned earlier, FixedPointFinder. I just want to point this out to people as something that they can use on recurrent neural networks to find and analyze fixed points. It's not a tool that I have extensively used myself, and it has a bunch of library dependencies, so we're not going to do a deep dive into it right now. But I want to briefly go through the documentation, just to give everybody a sense of how they might use it.
So it can be used on pretty much any RNN using TensorFlow. So if you're interested in experimenting with RNNs, TensorFlow is a pretty useful Python library for doing that. And it can work on different kinds of RNNs.
So I mentioned earlier, like, you can have a vanilla recurrent neural network that just has very simple recurrent weights. Or you could have something more complicated, like an LSTM, which has a memory component. You can still find the Jacobian of an LSTM, and that's something that this package can do for you.
So if you have LSTMs, this will work on it. So you build an RNN, you train it on a task-- whatever. And then you want to analyze it.
So you create a FixedPointFinder object by plugging in the RNN cell. And you can give it a bunch of initial states that you want to start from to find the fixed points. And once you do that, it does that gradient descent procedure that we talked about, and it will find the fixed points. And exactly like we've talked about, it'll do the linearization, and give you what the Jacobian eigenvalues and eigenvectors are at different points.
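Schematically, the workflow in the README looks roughly like the sketch below. A caution: the class name, constructor arguments, and method signature here are paraphrased from memory and may not match the current repository, so treat them as approximate and consult the actual README.

```python
# Hypothetical sketch of the FixedPointFinder workflow -- names and signatures
# are approximate, not copied from the repo.
import numpy as np
from FixedPointFinder import FixedPointFinder

# 1. Build and train a TensorFlow RNN cell (vanilla, GRU, or LSTM) elsewhere.
fpf = FixedPointFinder(trained_rnn_cell)          # wrap the trained cell

# 2. Seed the search with states the network actually visits on the task.
initial_states = hidden_states_from_task_trials   # (n_inits, n_units)
inputs = np.zeros((1, n_inputs))                  # fixed points at zero input

# 3. Gradient-descend q(x) from every seed; nearby duplicates get merged.
unique_fps, all_fps = fpf.find_fixed_points(initial_states, inputs)

# Each returned fixed point comes with its Jacobian and eigendecomposition.
```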
So for example, they use this to replicate the results from the 3-bit flip-flop task. Here's an example of the inputs and the network outputs. And they get this figure, which looks a lot like the one we encountered earlier-- they're basically reproducing those results.
AUDIENCE: I have a quick question.
ELI POLLOCK: Yeah, sure.
AUDIENCE: Is there a smart way to choose the initial conditions so that you know you're sampling the fixed points in a non-biased way?
ELI POLLOCK: Yeah, so you definitely want to-- I think the best way is to create a grid that tiles the overall region of state space that you're using. It would probably be good to run the network first to get a sense of the general area of state space that it occupies. Because it is very easy for this procedure to find fixed points that exist way off somewhere in state space but aren't computationally relevant to the task. So I think the best way to use it is, again, to keep the samples relatively close to the states the network traverses during the [INAUDIBLE].
AUDIENCE: [? Cool. ?] Thanks.
ELI POLLOCK: And I also want to just bring up these-- like this question right here. This is a really cool tool, I think. But it's also not a hammer that can be applied to any problem. So if you have a network that you want to analyze, you might be tempted to use this.
But I would definitely caution anybody to think about, like, is this tool going to be useful in this case? This is most useful when you suspect that there are fixed points that are helping with the computation. If your network isn't really making use of fixed points at all, or if you wouldn't expect there to be fixed points that are creating attractor states to help with memory or complex values that are helping with rotations-- if you don't suspect that those are going to be present, if you apply this tool, you might find fixed points. But they won't necessarily be doing anything that's super-relevant to your task.
It's likely that they will be there, but they just might not be computationally relevant. So I would definitely recommend using it, but also thinking really hard about a hypothesis beforehand. Like, what kind of fixed points would produce the behavior that you're seeing? And once you have that hypothesis, you can use this tool to see if your hypothesis is correct.
AUDIENCE: Are there any other tools or ideas to find other kinds of stable structures, like limit cycles or other interesting things? I feel like this method wouldn't work there, because it's trying to find where the trajectory slows down. For limit cycles and other structures, that wouldn't be a good proxy for reaching the stable thing.
ELI POLLOCK: Yeah, that's a really good question. I don't know of any. But that is, I think, very relevant. Because yeah, fixed points are certainly not the only computationally useful structure. So yeah, if you're talking about, I don't know, oscillations, limit cycles would probably be a more interesting thing to look for.
Yeah, I don't think this tool would necessarily help with that. And I don't know of other ones that would. I don't know. It's worth looking into, and might be an interesting project.
AUDIENCE: Yeah. That does sound interesting.
ELI POLLOCK: I mean, you would look for just, like, if you-- yeah. I don't know how you would exactly find a limit cycle. I guess it would just be, like, you let it run. And if you keep moving through the same loop, you would call that a limit cycle. But yeah, I'm not sure.
AUDIENCE: Then you got chaos. [INAUDIBLE].
ELI POLLOCK: Well, if it never repeats, then you have chaos, yeah. But I don't know what tasks that would be useful for exactly, though I'm sure there are some. All right, so to recap so far: we looked over the theory behind linearizing nonlinear dynamical systems, and we looked at how you can apply this analysis to RNNs to get a sense of how fixed points in a system can be computationally interesting.
I'm going to spend the next half hour or so talking about some work that I have done on an inverse approach: if we have a hypothesis or knowledge about x, the states that a network should take; about f(x), its movement through those states; and possibly about Jacobians, can we then synthesize networks? So why might you want to do this? Well, I think there is a case to be made for creating models that implement very specific dynamical hypotheses about how a task can be solved.
So we could, for example, start from this idea of breaking tasks into latent variables. In this case, this is like a ready-set-go task, let's say, where you're trying to produce a time interval of a certain length after measuring it. And this can be solved with a latent dynamical model, where you move along this axis while you're measuring an interval, and then you get kicked onto this rectangular manifold. Depending on where you get kicked onto it, you move at different speeds until you hit a threshold, at which point you produce a movement.
And this is a hypothesis that has come from studying neural data. But here, it's clearly shown. What if we want a network model that's able to do this, or some transformation of this?
So we might imagine that we want a model that is able to take this and embed it into some neural space-- maybe it's a linear embedding, maybe it's a nonlinear embedding. But we're basically taking this model, embedding it into a neural space, and then we want a network that solves the task in that exact way.
So we're basically saying we want this model. And then we can study other things about this model, like-- I don't know, how big does the network have to be, what sort of weights does it have to produce those dynamics, things like that. But basically, putting really heavy constraints on the dynamics of a network.
All right, so once again, here's the whole thing. We're going from the task, to latent task dynamics, to embedding those dynamics into neural dynamics to get an overall model.
So we can do this in three steps. We want to first describe the task, like in that first image. Then we want to sample those latent task states and dynamics.
And then finally, we're going to set constraints for the RNN connectivity, and solve for the connectivity of a network that will reproduce those. And the way to do this is that we're going to do exactly that inverse approach that I mentioned a few slides ago, where we're going to say, instead of finding the Jacobian at different points, we're going to set what we want the Jacobian to be at a bunch of different set points, and produce constraints that allow us to solve for a network. And this method is described in this paper that recently came out. I called the method embedding manifolds with population-level Jacobians, or EMPJ for short.
And in the interest of time, and so as not to confuse people with a lot of math, I'm just going to go through the equations really quickly. Basically, we're starting from the network equation. We take the Jacobian, which is equal to this mess right here. We come up with an intuitive decomposition of the Jacobian into different eigenvectors and eigenvalues. And we can rearrange the terms to end up with an equation that looks like this, where we have a matrix of eigenvectors in different directions.
We have a diagonal matrix with the derivative of the nonlinearity evaluated at a given point, a matrix of eigenvalues, and again a matrix of eigenvectors. And we have this weight matrix that we want to solve for. Now, importantly, this equation has the form AW = B. So it's a linear equation in the weights.
We can get an equation like this at a bunch of different set points in our state space, and create just one big linear system by stacking those constraints on top of each other and solving for W. I'll show you what I mean in this notebook, which I invite everybody to open right now. I'm going to go through and describe how you can use this method to create networks that implement known dynamical solutions.
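Schematically, the constraint-stacking step reduces to one least-squares call. Here is a toy sketch of just the flow constraints (my paraphrase with placeholder random data, not the paper's or the notebook's actual code):

```python
import numpy as np

# For dx/dt = -x + W tanh(x), requiring the flow at sampled state x_k to equal
# a desired f_k gives a constraint that is linear in W:
#     W tanh(x_k) = x_k + f_k
# Stacking K such rows yields A @ W.T ~= B, solvable by least squares.
K, N = 200, 50
X = np.random.randn(K, N)         # placeholder sampled states (rows x_k)
F = 0.1 * np.random.randn(K, N)   # placeholder desired flows (rows f_k)

A = np.tanh(X)                    # (K, N)
B = X + F                         # (K, N)
W = np.linalg.lstsq(A, B, rcond=None)[0].T

# Jacobian constraints (specified eigenvectors and eigenvalues at each point)
# would be stacked onto A and B as additional rows in the same way.
```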
NHAT LE: Yeah, I have a quick question, actually, Eli. Can you go to your diagrams? No, the ones with the different colors, where you--
ELI POLLOCK: Oh, the equations?
NHAT LE: Yeah. So then you said you constructed this linear system. And so what points are you sampling from to solve this system?
ELI POLLOCK: Right, so I'm assuming that there is a manifold that should contain all of the states during the task. And I'm sampling points from that manifold. And then at each point, I'm specifying what are the relevant eigenvectors and eigenvalues, and just [INAUDIBLE].
NHAT LE: [? I see. ?] OK.
ELI POLLOCK: Yeah. What are the dynamics at that point?
NHAT LE: Got you.
AUDIENCE: I have also, I guess, a naive question, as somebody who doesn't necessarily work in this field, which is, in what situation do you have-- I guess I was curious about a situation where you have the Jacobian and the x points and the eigenvectors, but you want to construct this network.
ELI POLLOCK: I mean, yeah, I don't know how broadly applicable this technique is. It might be a case where, like I mentioned, you have a hypothesis about the way that a network is doing something, and you want to create a network that does it in exactly that way-- you just want a network model that implements a known dynamical solution.
I don't know exactly how specific I can get with situations where that would be the case. But usually, it's going to be in cases where you have a relatively simple task. It might be something that you're training a monkey to do, where generally those are going to only involve a few different variables. And if you have an idea about how those variables are interacting in the brain, and you want a network model that implements that, this would be a way to do that.
AUDIENCE: Cool. Yeah, so it's like a strong way to test hypotheses.
ELI POLLOCK: Exactly. So you can always train a network to do a task, but you're never going to be sure like exactly how it's going to solve it. In this case, you know. You're going from the solution to having a model. And then you can ask questions about that model. OK--
[INTERPOSING VOICES]
AUDIENCE: I don't know if this makes any sense. But at a very abstract level, can you think about this as like learning dynamics in a linear state space, and then lifting those into a state space of an RNN?
ELI POLLOCK: Exactly. Yeah, that's pretty much what this is. It's just taking a bunch of transitions between variable states for a task, and just embedding that in some desired way into an RNN model. So I'm going to go through some code that I have made for people to experiment with for doing different things with this technique.
There's this Colab notebook. It uses pretty much standard Python libraries. There shouldn't be any issues with just running the notebook. So I'll just go through-- there's a few sections, and I'll go through them. And we can see a quick example at the end, and maybe people can play around with that for about 10 minutes and ask questions.
OK, so I'm going to be using a vanilla RNN class. This is just some custom code that I wrote to implement a recurrent neural network. It has a bunch of parameters.
And it also just needs to be able to take inputs and run, so there's this run function. This isn't really that important, so I'm going to hide that section for the most part. But when we want to run the network and visualize its activity, this is what it's going to be using.
Then this is the more interesting part of the method, where we have this manifold that we're sampling points from, and we plug those points into something that spits out a recurrent weight matrix as an answer. The way I'm doing this is, I'm creating a bunch of low-dimensional x values and low-dimensional f values defining the flow at all those points. I'm also creating Jacobian eigenvectors and Jacobian eigenvalues at each point. So when you initialize this manifold class, you start with an empty list of all of these things.
And then you can gradually add components of this manifold. So you might want to have a line. Like, you might want to construct a plane using different lines, or you might want to construct different ring attractors interacting in some way.
But basically, you can add components using this method of the manifold class by giving it a list of points, and a corresponding list of vectors describing what the flow over the manifold should be at all those different points. And then once you do that, you can manually add in, if you want, what are the Jacobian eigenvectors and eigenvalues going to be at each of those points. But that's optional, since obviously, you might not want to specify those in a whole bunch of different dimensions.
So then, once you have that manifold class, you can use this make-RNN function, which is going to take the manifold and create a bunch of linear constraints-- like I mentioned, those A and B matrices, which are stacks of different linear constraints. Then we just plug that into a linear algebra least-squares solver, which spits out a weight matrix. And then we can create an RNN with that weight matrix, initialize it at different points, and check whether it carries out the computations that we're interested in. So I'm going to then hide this.
So here's an example-- I'm just going to do a really simple one, where we have a ring-shaped manifold. Let's say that we want some one-dimensional dynamics around it, but we also want the ring to be stable: if you have perturbations off of the ring, they decay back onto it.
So here, we're defining this manifold using the class. We can set a bunch of parameters for the drift function and for how much time we want to simulate. And then we add a component, where we're taking points going around a ring, using a sine and cosine parameterization.
And then we define some sort of sinusoidal drift function over that ring. And then, if we want, we can define different eigenvalues and eigenvectors, and make an RNN that uses this manifold. And then we can use this view-manifold function to simulate the RNN, starting from a bunch of different points around the ring.
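For readers without the notebook open, here is a condensed stand-alone sketch of the ring example (my own reimplementation using only the flow constraints; the notebook's manifold class, Jacobian constraints, and parameter names are not reproduced):

```python
import numpy as np

rng = np.random.default_rng(0)
N, K, omega = 100, 256, 4
E = rng.normal(0, 1 / np.sqrt(N), (N, 2))  # random linear embedding of the ring plane

theta = np.linspace(0, 2 * np.pi, K, endpoint=False)
ring = np.stack([np.cos(theta), np.sin(theta)], axis=1)      # (K, 2) ring points
tangent = np.stack([-np.sin(theta), np.cos(theta)], axis=1)  # unit tangents
drift = np.sin(omega * theta)   # sinusoidal drift: 4 stable + 4 unstable zeros

X = ring @ E.T                        # (K, N) embedded ring states
F = (drift[:, None] * tangent) @ E.T  # (K, N) desired flow at each state

# Stack the linear constraints W tanh(x_k) = x_k + f_k and solve for W
W = np.linalg.lstsq(np.tanh(X), X + F, rcond=None)[0].T

# Simulate dx/dt = -x + W tanh(x) from a point on the ring, read out the angle
x = X[10].copy()
for _ in range(3000):
    x = x + 0.01 * (-x + W @ np.tanh(x))
z = np.linalg.pinv(E) @ x              # project back to the 2D ring plane
print(np.arctan2(z[1], z[0]) / np.pi)  # should drift toward a stable drift zero
# Without Jacobian constraints, stability off the ring isn't controlled --
# that's exactly what the eigenvalue specifications in the notebook add.
```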
So if we run everything-- how do we run [INAUDIBLE] all-- if we run everything, we get this. But then we can experiment with different initializations or different parameters.
So here, we have a case where there are, I think, six fixed points around the ring-- yeah, that's this omega parameter. We might want to set it to 4. And there's also this IC scale parameter, which [INAUDIBLE] sets where our initial points are-- whether they're on the ring that we're specifying, or beyond it. So I'm going to set this to 1, and I'm going to rerun this cell.
Yeah, so now these are unstable fixed points, and then there are stable fixed points in between, so that there are four unstable and four stable, where all the trajectories starting from these points are just falling into the [? middle ?] here. And we can also try something where we're starting from inside the ring, where now the points get sucked onto the ring, and then move around in the predicted way.
So something to try is, if we want to, for example, have eigenvalues in the radial direction that are unstable, we could make this positive. So instead of, like, negative 20, we could set it to something like 2. And what that means is that very weird stuff ends up happening here.
So yeah, I feel like I might be losing people, though. So I want to address any questions that are in the chat. It looks like the chat is-- OK, there are questions about other stuff. I don't know. Any other questions about this particular way of doing things?
AUDIENCE: Just a quick question about the stability. In your 3D picture, you had, I guess, four unstable equilibria-- you call them fixed points-- and then the equilibria between those are stable. I'm just curious if the stability of these equilibria is fixed. Are there any parameters you can vary to change the stability of the equilibria, or is that not really what you want to do?
ELI POLLOCK: So you can change the stability of the ring itself-- whether perturbations off of the ring fall back onto it, fall into the center, or zoom off toward other fixed points. So what's happening here is that I've set it so that you're initializing close to the ring. But because we're starting not quite on the ring, and the ring isn't stable, the trajectories are all falling towards the center. And there's going to be one global fixed point, where now, if we boost the simulation time to, like, 10 seconds instead of one second, we have something like this, where everything goes around there.
But you could do different things-- so right now, there is a drift function going around the outside of the ring, and you can change the number of fixed points on that. But with this particular implementation, it's just a sinusoidal drift function, so you're going to end up with an equal number of stable and unstable fixed points.
AUDIENCE: OK. Thank you. What is the delay thing? I see that you have something you call delay. What is that?
ELI POLLOCK: Yeah, that's just the simulation time. That's just the way that this code got written. That's just the name that the variable happened to take.
AUDIENCE: I see. And where do you go from here now, because this is just linearization around a fixed point?
ELI POLLOCK: Yeah, so this is the conclusion of the talk. So if people have questions here, you can-- I mean, where do you go from here? You can do a lot--
AUDIENCE: --picture.
ELI POLLOCK: Huh?
AUDIENCE: Global picture.
ELI POLLOCK: Right, so by setting linear constraints in a lot of different places, you can get some pretty complex global behavior. So if you're setting a bunch of linear constraints over manifolds with dynamics that carry out more complex tasks, you can create a lot more interesting models. And that's explored more in the paper.
I'm running out of time. But if you look up this paper, some of those applications are described in there, where instead of just having a ring that's sitting in two dimensions, you can think about, well, what happens if you want to bend the ring into higher dimensions? You can get more interesting representations that way.
Or let's say that you want to build in a way of controlling the speed at which you move around the ring with contextual inputs. That's something else I explored in the paper. So basically, what this means is that creating a network to solve any given task is now limited only by your creativity in thinking about different kinds of dynamical solutions. Yeah, thank you, guys, for the great questions. And thanks for having me today.