Learning to Reason, Insights from Language Modeling
Date Posted:
October 21, 2024
Date Recorded:
October 1, 2024
CBMM Speaker(s):
Noah Goodman
Brains, Minds and Machines Seminar Series
Description:
Noah Goodman, Stanford University
JOSHUA TENENBAUM: Welcome, everybody. Today, we are extremely fortunate to have Noah Goodman as our guest speaker. Some of you who've been around here as long as I have remember when Noah was here. He was here as a postdoc in our group and worked with a number of other groups many years ago. And before moving out to Stanford, where he's been on the faculty since, he's also founded a couple of companies, done some time in industry, and really-- I think that's an OK way to put it-- [LAUGHS]
--and has been one of the most really interesting and important figures in the space between cognitive science and AI. Never more so than now, I think. But over the years, he has been one of the pioneers in probabilistic programming, helping to create some of the first Turing-complete probabilistic programming languages and use those for modeling cognition.
For everybody who's ever taken any class from me here and used prob mods, of course, you know that all came from Noah and his students also. He used these and other tools to really introduce pioneering ideas in computational understanding of language for pragmatics and related inferences about minds from language, like the RSA theory.
He also has been, for a very long time, interested in the interaction between neural networks and learning approaches with reasoning approaches. You'll see some of that here, but he was one of the pioneers of what-- different people mean different things by this, but Noah was one of the originators of the term 'amortized inference' and ways you can use a machine learning system to learn to do various kinds of probabilistic inferences very efficiently. He's done all of these kinds of things.
Most recently, he and his group have been doing a lot of work with language models, both in AI and as models of human-like or maybe human-related inference. And I think this talk will cross all those lines. I think his work has been having quite a lot of impact in the broader world, in industry. Probably most of you know that-- was it, like, what, a year ago or something? Sam Altman was briefly fired from OpenAI.
There were lots of rumors swirling-- why was he fired? What was going on? There was-- one of the many rumors around there was that some top secret thing was going on that maybe he didn't tell the board about that had to do with ways in which they were building some new superhuman kind of reasoning.
No one knows exactly what that was about, but you also probably might have noticed-- was it a couple of weeks ago? When they released some new-- totally new-- kind of reasoning system in GPT that they had to start the numbering system over from scratch.
Anyway. Nobody quite knows what was going on then, nobody quite knows what's going on now. Definitely some interesting things have been going on, but I think there's good reason to think that one of the basic things that OpenAI has been scaling up in their latest systems for learning to reason are closely related to some of the ideas that I think Noah has developed in his group and that he'll tell us about here. So I'll just-- with that, I'll turn it over to Noah to talk to us about learning to reason. OK.
[APPLAUSE]
NOAH GOODMAN: Thanks, Josh. That was kind of epic. I don't know that I can live up to that. So yeah. So today, I want to talk about learning to reason. And I feel obligated to warn you that this is an AI talk, even though I'm a cognitive scientist, because it's harder to get behavioral data that actually matches the rapid progress of AI. So in another 10 years, I'll come back and I'll tell you all about the psychological side of this.
But for now, I'm going to warm up by reimagining my motivations when I got here almost 20 years ago as a postdoc. I was really interested in high-level cognition, the kinds of stuff we do in language, in reasoning and math and stuff. And there was really only one thing in the universe that you could study if you wanted to understand these, which was the human brain or mind.
For people who are lucky enough to be interested in vision, you could do these comparative studies. And this was really awesome. But since I was interested in language and high-level cognition, there were no comparative studies that were really open to me. I guess maybe young children versus adults was an example of this since I'm looking at Laura here, who I worked with on that back then. But that's not the point of this.
If the goal was to understand not just the human mind, but the space of all intelligent things that had this kind of high level intelligence that I was interested in-- back then, we could only study one corner.
I think we are currently living in a really exciting moment because there's this new set of systems in the world, and we can have lots of fun arguments about whether these are human-like or human-related or intelligent or not, but it's definitely the case that they are capable of a bunch of behaviors that, until recently, I would have said are the things that define this realm of stuff I'm interested in. Being able to fluently use language in interesting ways and so on.
And that's exciting because I think now, we can start to triangulate a little bit more and try to get at what's going on in the space of intelligence that is the-- well, wait, I have a slide on this. It's a slide that goes back to the cringey and trite metaphor that I still really like of understanding the space of flying things.
If you lived before airplanes and you were trying to understand the space of flying things, you would study birds. It would be great. You'd learn all sorts of things about birds. You would obsess a lot about details like feathers and how great they are. And you'd learn a lot. And then all of a sudden, somebody makes a jet airplane, and you're looking for the feathers and you're not finding them.
And the point of this analogy, of course, is that when you think about flying things, if you're interested in the space of flying things, you need to try to distinguish the things that are more universal principles that will apply to birds and airplanes, things like principles of lift and aerodynamics. At the other end, the things that are contingent-- they're just things that happen to be there because birds have a particular evolutionary history-- feathers are made of keratin.
And then there's an interesting space in the middle, which I'm calling-- although I'm not sure it's the right term-- the functional demands. They're not fully universal to the space of all the flying things, but they might characterize a broad subspace of flying things. So for me, a really good case of this is the fact that birds flap their wings.
And you might think, oh, this is a weird, like, just thing about birds. It's not, actually. It's a response to the particular functional demand of needing to take off at low airspeed. Airplanes don't, and so they don't have to flap their wings. And so if you're interested in taking off at low airspeeds, you might realize there's this principle that characterizes not all the flying things, but some of them in an important way.
So the project that I, I guess, have in the back of my head is, let's think about this kind of taxonomy with respect to the space of intelligent things now that we are starting to have a few more points in that space to compare to each other. Now, I probably won't make good on that in this talk, but keep this in mind as the program that's behind this, the set of questions.
Now specifically in this talk, I want to think about reasoning. And what I mean by reasoning, at least for this talk, is things like logical reasoning, like syllogistic reasoning that goes back to ancient Greece. Is Socrates mortal? Well, Socrates is a man and men are mortal, so Socrates is mortal. Or things like solving simple equations by working through these steps.
And for reasoning, there's a lot of questions you can have, but I think the most basic is, what is it? How can we think about it? Where does it come from? And are there sort of reasons, more general reasons that we have it in our heads?
So-- now, I didn't have great ways of thinking about this for a long time-- and some people will tell me that I still don't. But a thing happened-- what was this, maybe five or six years ago, that I found very striking. Large language models started happening, and that was kind of cool for all sorts of reasons.
But what really got me excited about them was when it was observed in a couple of papers that there was this chain of thought thing you can do. So you take a language model that's trained to predict the next token on a big corpus. And if you just ask it to immediately answer a word problem, a math word problem, it'll give you an answer, and it'll do OK. Not terrible, not great.
And the interesting observation is that if you give it room and you ask it, for instance, to think step by step or something like that in between, and you give it room to generate some tokens, some words, before it answers, all of a sudden, if the model is big enough and trained on enough data, it starts to do better due to that. It got called many things originally; these days, mostly, it's called chain of thought.
So this is a paper from one of the original chain of thought-- a picture from one of the original chain of thought papers showing that for big enough models, there's a benefit to a chain of thought over direct answering.
So I thought this was really cool because it's suggestive of the kind of extended reflective process that we go through when we're reasoning that makes our beliefs at the end of that process better than our beliefs were at the beginning of that process.
So my students and I naturally were asking, like, OK, but why do language models do this? It's not obvious to me at all that they should do this or that this should be a property of a language model. So that's this question, why do they do that?
And so my student, Ben Prystawski, had the idea that we can boil this down to a very, very simple toy setting in order to study it. And the toy setting is we set up a Bayes net, and this is a Bayes net with 100 Boolean variables and some causal structure that we just come up with, more or less, randomly. And we think of this Bayes net as defining the world.
Now it defines a world, which means that you can take a sample from the joint distribution specified by the Bayes net and say, OK, that's a particular state the world could be in. And then you can write that down as a sequence by just taking a random ordering of the variables and writing X12 equals 0, X4 equals 1, X11 equals 1, and so on. So this is just a string of characters now, or tokens.
And then you can ask yourself, OK, so what happens if you train a little toy brain-- but actually a transformer-- on these sequences of values in this Bayes net? Now, I'll come back to this in a minute, but we're using a transformer here, but I want you to think of it as just a little density-- sequential density estimator. So the important property it has is it's doing its best to be able to approximate the distribution over these sequences of variables.
JOSHUA TENENBAUM: Do you know, is the sequence based on the order of the Bayes net? Or what's the--
NOAH GOODMAN: So, so far, no, the sequence is random. So to generate the data, we sample from the Bayes net to fill in all the values of the variables, we choose a random permutation, and we show them in that order. And by the way, yes, you should follow Josh's lead and interrupt me if you have questions as I'm going along.
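[A minimal sketch of how training strings like this could be produced-- not the paper's code, and the tiny three-variable net and names below are made up for illustration: sample a joint state from a toy Boolean Bayes net, then write the variables out in a random order as one flat string.]

```python
import random

# Hypothetical tiny Bayes net over Booleans: variable -> (parents, CPT keyed by parent values).
NET = {
    "X1": ([], {(): 0.7}),                   # P(X1=1)
    "X2": (["X1"], {(0,): 0.2, (1,): 0.9}),  # P(X2=1 | X1)
    "X3": (["X2"], {(0,): 0.1, (1,): 0.8}),  # P(X3=1 | X2)
}
TOPO_ORDER = ["X1", "X2", "X3"]              # ancestral sampling order

def sample_world(rng):
    """Ancestral sample of a full joint assignment (one 'state of the world')."""
    world = {}
    for var in TOPO_ORDER:
        parents, cpt = NET[var]
        p_one = cpt[tuple(world[p] for p in parents)]
        world[var] = int(rng.random() < p_one)
    return world

def to_sequence(world, rng):
    """Serialize in a *random* order, as in the talk, e.g. 'X2=1 X1=0 X3=1'."""
    order = list(world)
    rng.shuffle(order)
    return " ".join(f"{v}={world[v]}" for v in order)

rng = random.Random(0)
corpus = [to_sequence(sample_world(rng), rng) for _ in range(5)]
print(corpus)  # strings a sequence model would be trained to predict token by token
```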
OK. Now, here's what's interesting about having this little transformer, a density estimator. There's two very natural ways you can use it to estimate a given conditional probability in that Bayes net. Oh, and I should point out, the nice thing about using a Bayes net for all of this is that we know the actual conditional probability distributions in the world, whereas like for all the text on the internet, we don't really know those probabilities.
So we know what the conditional probability of C given A is. And we can use our little transformer as an estimator for that conditional probability in two ways. The first is we can just directly use the transformer. We give it the string A equals-- let's say that's 0-- A equals 0, C equals, and we see what value it fills in next. And that gives us a probability that C is 0 or 1.
Or we can use the same thing as a Monte Carlo estimator where what we do is we start out the transformer with A equals 0, and then we let it just freely generate whatever sequence of tokens it wants to generate until it generates C equals in some value. And that gives us another estimator of this conditional probability where we've sort of simulated-- we've forward-sampled whatever variables the transformer wants to in between A and-- what did I say? A and C here.
So now we have two estimators, and we can just ask a simple statistical question, which is, when, if ever, is it the case that this Monte Carlo estimator has lower bias in estimating the true probability than the direct estimator? Now here, I'm calling it a Monte Carlo estimator, but it should be pretty obvious to you that this is chain of thought. This is where you let the model just freely generate until it decides to give you the answer. And so we're asking, when is the chain-of-thought-based way of getting an answer lower bias than the direct way of getting an answer?
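[The two estimators, as a hedged sketch: the model interface here is an assumption-- a hypothetical autoregressive wrapper `lm` with `next_token_probs(prefix)` and `sample_next(prefix)`-- and each 'Xi=v' chunk is treated as a single generation step for simplicity.]

```python
def direct_estimate(lm, a_value):
    """Direct read-off of P(C=1 | A=a): prompt with 'A=a C=' and inspect the next token."""
    probs = lm.next_token_probs(f"A={a_value} C=")
    return probs["1"] / (probs["0"] + probs["1"])

def free_generation_estimate(lm, a_value, n_samples=1000, max_len=50):
    """Chain-of-thought-style Monte Carlo estimate of P(C=1 | A=a): let the model
    freely generate intermediate variables until it emits C=..., then average."""
    hits = []
    for _ in range(n_samples):
        prefix = f"A={a_value}"
        for _ in range(max_len):
            chunk = lm.sample_next(prefix)          # e.g. " X7=1" or " C=0"
            prefix += chunk
            if chunk.strip().startswith("C="):
                hits.append(int(chunk.strip()[2:]))
                break
    return sum(hits) / max(len(hits), 1)
```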
OK, now before I give away the punchline and tell you something about this, let me introduce one other ingredient here, which is pseudo-motivated by psychological considerations. So humans do not have perfect knowledge of the full state of the world at any given moment. The world is big and we don't see most of it. What we have is an egocentric viewpoint on the world. And so we open our eyes and we only see little regions of all of the stuff that's true about the world at any moment.
And so this motivates adjusting that training information that we give the transformer. So one condition is the global observation. We give it everything that's true on a given instance. And the other is a local observation where we'll choose a variable in the Bayes net, and then we'll choose a local neighborhood of that variable, and we'll only show the transformer the value of that local neighborhood on each go. OK. So we have the global and the local training conditions, and then the direct and the Monte Carlo estimators for the probabilities.
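[A short sketch of the local observation condition, continuing the toy example above; how exactly a "local neighborhood" is defined is an assumption here (one graph hop around a randomly chosen center variable).]

```python
def local_sequence(world, adjacency, rng, radius=1):
    """Serialize only a local neighborhood of one randomly chosen variable.
    adjacency: dict var -> set of neighbors (parents and children) in the Bayes net."""
    center = rng.choice(list(world))
    visible = {center}
    frontier = {center}
    for _ in range(radius):                      # grow the neighborhood out `radius` hops
        frontier = set().union(*(adjacency[v] for v in frontier)) - visible
        visible |= frontier
    order = list(visible)
    rng.shuffle(order)
    return " ".join(f"{v}={world[v]}" for v in order)
```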
So here are the results. The blue lines are the fully observed training condition and the red lines are the local observation conditions. And the first thing to notice is that you learn much faster in the local observation cases. This is not surprising because we're asking the transformer to capture much less information. It only has to master these little local neighborhoods, not like all of the statistical relations in the full Bayes net. OK, so that's good to know.
The other thing is that in the local but not global observation conditions, there's a gap where the chain-of-thought estimator has lower bias than the direct estimator. So this shows that in this case, there's a benefit to reasoning, to chain of thought, if the data has this local structure as this local observation structure.
AUDIENCE: A system where you have this Monte Carlo sampling, that's the-- the result will be the same? Is it like a consistent result or--
NOAH GOODMAN: You mean across different instantiations of, like-- yes, OK, good. That's the next question. Thank you. Perfect segue. So, I showed you, basically, an existential proof that there exists some Bayes net and training setup where you get lower bias from the chain-of-thought estimator than the direct estimator. A natural question is whether this is universal. Like, is this generally true?
So we have a little math in this paper. We didn't do it in the-- well, first of all, I'll say empirically, we did it across, I don't know, five or 10 different Bayes nets and varied the size of the transformer that we trained and a few other kind of parameters like that that reviewers asked for. And it is consistent across all of those instantiations.
But I think the more compelling thing-- you don't actually have to read the details of this theorem. The more compelling thing is that, at least in the idealized case of a chain Bayes net where the proof is a little bit easier, if you imagine a chain Bayes net where you're training on adjacent pairs of variables-- that's the local observation-- then you can prove that the Monte Carlo estimator will always have lower bias than the direct estimator.
So this is for that class of Bayes nets. And, I guess, under a key assumption here, which is that the density estimator is a risk-minimizing-- OK, for those who like math, it's a risk-minimizing, entropy-regularized density estimator, then you will always have this phenomenon.
I do want to point out the interesting thing, that this theorem never says transformer anywhere, and it doesn't depend on any inductive biases of the transformer or something about attention mechanisms or something. This is actually something that is true of density estimators. So the important property here is that you're trying to do sequential density estimation, and you're doing it well enough, you're risk-minimizing. So you do it-- you capture the distribution well enough.
So that suggests that the phenomenon-- that chain of thought is helpful for getting better answers-- is actually very general. It has to do with not any specifics of the model, but the fact that it's a good enough density estimator and it has this local observation property of the data. So it's something about the data and the goodness of the estimator. Also, look, there's math. That means it's real, right?
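[A hedged sketch of the intuition for the chain case, written out-- this is an informal reconstruction, not the paper's exact theorem statement.]

```latex
% Suppose the world is a chain A \to B \to C and training only ever shows
% adjacent pairs, so A and C never co-occur in a training window. Then, roughly:
\[
  \hat{p}_{\mathrm{direct}}(C \mid A) \;\approx\; p(C)
  \quad\text{(nothing in the training windows ties $C$ to $A$, so the regularized fit falls back toward the marginal),}
\]
\[
  \hat{p}_{\mathrm{CoT}}(C \mid A) \;=\; \sum_{b} \hat{p}(B=b \mid A)\,\hat{p}(C \mid B=b)
  \;\approx\; \sum_{b} p(B=b \mid A)\,p(C \mid B=b) \;=\; p(C \mid A),
\]
% so chaining well-estimated local conditionals recovers the true conditional,
% while the direct read-off stays biased.
```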
OK. So even really big language models are not going to be able to learn all the relevant conditional distributions for the world, everything they might need. And equally true of even really smart people. Instead, what they can do is they can reason.
And what this toy model suggests is that reasoning is really about bridging together the local expertise that the model has by, in a metaphorical sense, wandering through the connected regions of conceptual space to get from what you know to what you're trying to figure out. And that this is a very general property when the experience has this particular structure, this local structure.
However, you might have noticed that chaining world states together doesn't really exactly have this flavor of figuring it out, of thinking hard and realizing that you're not quite doing it right and this sort of stuff. And relatedly, the chains that come out of this-- just the transformer forward-sampling-- they're pretty inefficient. I told you nothing about, oh yeah, it's going on the optimal path between A and C or anything like that. I just said you let it wander around and freely generate. And so it generates a pretty meandering conceptual path toward the answer.
So, the next parts are about-- more or less thinking about those properties. AI-generated art. Feast your eyes. OK. So I don't really need to introduce the game of 24 at MIT, I think, because Josh talks about it a lot, at least I think I first heard about it from Josh a long time ago. It's a fun game where you're given some input numbers, you have to combine them with arithmetic operations and get the output number.
And it's not particularly easy. It takes a lot of-- even if you're a sort of expert, it takes a lot of thinking. In my lab recently, we've been collecting verbal protocol data for this game and games like it-- but that's a different talk, and we don't have all the data yet, so this is maybe the one I'll give in a couple of years.
But verbal protocol is awesome. Newell and Simon and Ericsson knew this, but we sort of forgot it for a long time. But now we have all this new technology. So we can collect a lot of verbal protocol data from the internet, we can transcribe it really easily with modern tools, we can analyze it nicely. But that's not the point.
The point here is that when people go about solving this kind of problem, they go through this interesting, meandering path that has things like, oh, we're going to do plus 2 by order of operations, so we have doubles, three doubles. Oops, no, no, no, no. And then he backs up. So we have 6 minus 5, and then keeps going.
And this is kind of interesting because it's quite characteristic, I think, of us figuring out hard problems. But language models don't actually do this when they do chain of thought. Well, OK, I have to amend this. Language models, until like two weeks ago, didn't do this. o1, it turns out, does this a lot.
But in general, language models are really not-- they don't go through this kind of backtracking, realizing that they made a mistake, and going back and trying something else process. And when we were thinking about this, it seemed like there were three logically possible reasons that this could be. The first was that search can't be represented in a sequence, and language models can only-- like, chain of thought is a sequence, so that could have been.
The second is that it might be something about the transformer architecture itself that prevents it from being able to easily backtrack-- maybe it forgets where it is or something. And the third is that there are just not very many examples of this kind of thing in all the data that transformers are trained on.
So the first hypothesis, I think, is sort of obviously false because if you have a search algorithm, you can just put print statements in it and you'll get printed out a sequence which describes the search process. So it is at least possible to describe search in a sequence. The second could be the case.
The third seemed really promising because if you think about the kind of reasoning I showed you before, nobody's going to write that down and put it on the internet. They're going to figure out the answer and then just give you the answer. They're going to say, of course, the beautiful answer is, whatever it was, 3 plus 1 times 12.
So it seems entirely possible that transformers didn't exhibit this kind of backtracking reasoning property, because that's not what's on the internet. It's not the kind of thing we give to each other. And that's as simple as it's not in the data. So we decided to test that.
This is a paper led by my graduate student, Kanishk Gandhi. Gabe over there was one of the coauthors on it. Yay, Gabe. And what we did is we said, OK, let's take the game of 24, or countdown, and let's start with classical search strategies. Because it's a finite and enumerable search problem, we can just enumerate the search tree-- depth-first, breadth-first, and so on, different ordering strategies.
And then let's come up with a little language where we can write down the steps that you go through as you're traversing the tree. And that language should have things like states, describing the states you are at, explicitly describing moments when you backtrack, some stuff to talk about your search strategy, and your goals, and so on.
So having made a little language for expressing search in these combinatorial problems, we then have the ability to take a particular search in a particular problem and represent it as training data for a transformer in one of two different ways. We can either represent the full search, including all of the messy backtracking and the parts that are not the actual solution-- so the red bits; we can describe the full traversal that found the goal state. Or we can just describe the actual solution from the start state to the actual goal state. So this is the optimal path and that's the search path.
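[A minimal sketch of what turning one countdown/24 search into the two kinds of training strings could look like; the trace format and function names here are invented for illustration and are not the paper's actual search language.]

```python
from itertools import combinations

def dfs(nums, target, trace, path):
    """Plain depth-first search over countdown states, logging every try and backtrack."""
    trace.append(f"state {sorted(nums)}")
    if len(nums) == 1:
        if nums[0] == target:
            trace.append("goal reached")
            return True
        trace.append("dead end, backtrack")
        return False
    for i, j in combinations(range(len(nums)), 2):
        rest = [nums[k] for k in range(len(nums)) if k not in (i, j)]
        a, b = nums[i], nums[j]
        candidates = [(f"{a}+{b}", a + b), (f"{a}*{b}", a * b),
                      (f"{max(a,b)}-{min(a,b)}", abs(a - b))]
        if min(a, b) > 0 and max(a, b) % min(a, b) == 0:   # keep things integral
            candidates.append((f"{max(a,b)}/{min(a,b)}", max(a, b) // min(a, b)))
        for expr, val in candidates:
            trace.append(f"try {expr}={val}")
            path.append(f"{expr}={val}")
            if dfs(rest + [val], target, trace, path):
                return True
            path.pop()
            trace.append(f"undo {expr}, backtrack")
    return False

trace, path = [], []
dfs([1, 2, 3, 4], 24, trace, path)
stream_of_search = " ; ".join(trace)   # long string: every state, dead end, and backtrack
optimal_path = " ; ".join(path)        # just the moves on the successful branch
print(optimal_path)                    # "1+2=3 ; 3+3=6 ; 4*6=24" for this instance
```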
And so what we did is we took a big pile of data and we trained two conditions, one of them on the search trajectories and another one on the optimal paths. We matched the number of tokens. So there's actually more examples in the optimal path case because the full trajectories are, of course, much longer. And then we trained the transformer. And we looked at how good it was at solving held-out versions of these problems. We ignore the matrices for a second.
So this set of bars is showing you if the transformer is trained on the optimal path, and then you ask it to solve either new inputs for seen outputs-- like 24 is the output. We've seen 24 before. Or new outputs-- so further generalization. And it achieves whatever it achieves, 28% success rate or something.
And these bars are the exact same conditions when the transformer is trained with the stream of search. And it's doing vastly, vastly better at generalizing to both new inputs and new states.
Now so far-- well, let me actually say one or two qualitative things. So the first thing is just, like, OK, so it seems like the transformer is able to do something much better when you show it that it can think by backtracking and do this thing. Qualitatively, when you look at the traces, even though the transformer is not, after training, structured to do any kind of search, it is constructing a stream of search which corresponds to a valid search-- search through a valid tree in our language.
And so one thing we can do is say, OK, if you take the stream of search, parse it into the search tree, and ask how close that search is to the search from the classical reasoning strategies-- you can, say, compare it to various depth-first search strategies, various breadth-first search strategies that use different heuristics, what you find is that the stream of search is a valid tree, but it's not particularly similar to any one of those classical strategies. So it's not like it's latching onto a single strategy and using it.
What it has is like some kind of-- honestly, sort of bad mixture of heuristics that it's using, which you see by the fact that the stream of search, while it does way better than training on the optimal path, doesn't actually do better than the classical search strategies that it was trained on. By the way, you might wonder why the classical strategies are not at 100% since, of course, you can enumerate the entire tree. It's because we cut it off after a finite amount of search in both the training and the test.
Keep this in mind. We're going to come back to it in a minute. And the leading question here is, well, it's learned some kind of mixture of the heuristics in the training data. And also, it's learned to search at all, which is the interesting-- the remarkable part. But couldn't it do better-- like, couldn't it learn good heuristics?
So jumping back now a couple of years before we started doing this, when language models were first birthed and doing chain of thought and everybody was kind of agog, including us, we started wondering whether they could get better at chain of thought, become more directed in their chain of thought by teaching, and in particular, by self-teaching.
So my awesome graduate student, Eric Zelikman-- that's Eric up there-- he tried a very simple kind of self-teaching strategy. This is the simplest kind of reinforcement learning to think with chain of thought that you can imagine. It goes like this. You take a question and a correct answer from a data set, but it does not have the chain of thought, the rationale in it. It just has the question and the answer.
You give the language model the question. You ask it to think step by step with a few shots to show it what you mean by that. It generates a rationale, and then it generates an answer, and you check the answer. If it gets the answer right, you take that and you throw it into a data set that you use to fine-tune the same language model. So you're basically training it on its own reasoning if it got the answer right for that reasoning. And if it gets the answer wrong, you throw it out. We'll come back to that in a sec, but you don't train on the cases where it got the answer wrong.
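[The loop just described, as a hedged sketch-- the model wrapper with `generate` and `finetune` methods and the `parse_rationale_and_answer` helper are assumptions made for illustration, not the paper's actual interface.]

```python
def star(model, dataset, few_shot_prompt, n_iterations=10):
    """Self-taught reasoner loop: keep only rationales that led to correct answers."""
    for _ in range(n_iterations):
        keep = []
        for question, gold_answer in dataset:
            # Few-shot prompt the model to produce a rationale followed by an answer.
            output = model.generate(
                f"{few_shot_prompt}\nQ: {question}\nA: Let's think step by step.")
            rationale, answer = parse_rationale_and_answer(output)   # assumed helper
            if answer == gold_answer:
                # Keep successful reasoning as training data...
                keep.append({"prompt": f"Q: {question}\nA:",
                             "completion": f"{rationale} The answer is {answer}."})
            # ...and simply discard failures (no training on wrong answers).
        # Fine-tune the same model on its own successful reasoning, then repeat.
        model = model.finetune(keep)
    return model
```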
And remarkably, this actually leads to a pretty substantial bootstrapping. So this is showing a teaching-- or letting the model learn how to do multi-digit addition. So the input is like sum up those two digits. The-- I think three examples, I don't know-- few examples we give of reasoning are these kinds of structured scratchpad things inspired by a paper by Nye. And-- from MIT, but not at MIT anymore, I think. Is that right, Max Nye, who's not here? Yeah. Good. OK, fine.
And yeah. So what I'm showing you on the x-axis is the number of iterations of learning, the number of times that you try to solve problems and train yourself on your own successes. And this very satisfying thing happens where it very quickly gets good at doing one-digit arithmetic because it basically knew how to do that already. And then it gets good at two-digit arithmetic, and then three-digit, four-digit, five-digit arithmetic. And then if you look at transfer, it's like somewhat able to do six-digit arithmetic even though you never gave it those, so it generalizes a little bit.
One thing I wanted to just mention, because doing multi-digit arithmetic does not seem particularly striking in, whatever, October of 2024, but this was like two and a half, three years ago, we were using GPT-J, which is like abysmal at pretty much all the things, but especially it's abysmal at arithmetic, so it starts out not able to do this at all.
We also tried this on common sense reasoning with the CommonsenseQA data set. This is like a kind of strange common sense reasoning task, honestly, but it's what's there. It's these things like, where do you put grapes just before checking out? And models will do things like, oh, the answer should be the place where grocery items are placed before checking out, and et cetera, et cetera.
And ignoring the quirks of this particular task, you can compare different models trained in different ways. If you take GPT-J and just do few-shot chain-of-thought reasoning, it gets about 36%. If you take a much bigger model, GPT-3, and fine-tune it on the correct answers, but not any chain of thought, it does well. It gets, like-- well, sort of well, 73%. And when we do STaR in the simple way on GPT-J, the smaller model, it gets up to 68.8%, which is pretty good.
Now Eric was not satisfied with this. He said, but wait, we are not yet matching the state of the art, which was 73%. He's a computer scientist, in case you wondered. But also, we only ever actually used about 70% of the data for training the model because we never got a correct answer for the other 30% of the data. And that seemed very wasteful and unfortunate.
So Eric thought, can't we do something with the data that we don't get a correct answer for? And so he tried this kind of interesting, although slightly kludgy, thing where he said, OK, let's give a hint. Let's actually just tell the model which is the correct answer. So literally, the input has "correct" next to the correct answer. And then we ask it to think step by step and give an answer.
Now, you might be surprised or not surprised that it does not always give the correct answer even so. But what we did is we said, OK, when it does give the correct answer, then let's just take the rationale that it uses to get the correct answer and throw that into the training set as a sort of good reasoning trace for this problem. So we call this rationalization.
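[A sketch of the rationalization step as described-- same assumed interface and `parse_rationale_and_answer` helper as in the STaR sketch above; the hint format is made up.]

```python
def rationalize(model, failures, few_shot_prompt):
    """For problems the model could not solve, hint the correct answer; if the
    hinted generation reaches that answer, keep its rationale as extra training data."""
    extra = []
    for question, gold_answer in failures:
        hinted = model.generate(
            f"{few_shot_prompt}\nQ: {question} (the correct answer is {gold_answer})\n"
            "A: Let's think step by step.")
        rationale, answer = parse_rationale_and_answer(hinted)       # assumed helper
        if answer == gold_answer:
            # Train on the un-hinted prompt, as if the model had reasoned its way there.
            extra.append({"prompt": f"Q: {question}\nA:",
                          "completion": f"{rationale} The answer is {answer}."})
    return extra
```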
And for those who are a little bit more probabilistically inclined, I think you should think of this as an importance sampling way of estimating what the right thinking traces should be. So there's this importance distribution, which is off the target policy that we use. In principle, there should be a correction, an importance sampling correction, but that never actually helped when we tried it, so we didn't do it. And then we throw those back into the training data set. And we got our little boost. We got our boost up to almost 73% from doing this. Any questions so far?
AUDIENCE: Oh yeah. So what are the [INAUDIBLE] so far? Is that no true [INAUDIBLE] actually unknown or ambiguous. How do you ask this?
NOAH GOODMAN: Oh, like if you don't have answers? Yeah. This is a good question. There are several ways you can go about this. So you have to have some source of supervision. Having a correct answer is a very nice, strong way to have supervision. You could have weaker supervision, like having, I don't know, like a long gold response without having a single answer. That sort of is fine.
You can also do self-consistency. So we did this, actually, but didn't put it in the paper, and there's a follow-up-- somebody else has a follow-up paper that does this-- where you take n chain-of-thought reasoning samples from the model, look at the majority-vote answer, and just assume that's correct, and then do STaR on that. That doesn't work as well, but it does work; you do get some bootstrapping.
So I would say any source you have for mining some additional information about what the answer should be can be used, but of course, you need something.
OK, so this is a kind of interesting philosophical choice about what the rationales represent. So if you want the chain of thought, the reasoning, to be interpretable to humans and faithfully convey the reasoning to humans, then you should check whether they're interpretable. On the other hand, if you think of them as internal computational steps that the model uses to improve its own belief distribution, there's no particular reason that those should be necessarily interpretable to humans.
So in this paper, we were definitely exploring the latter. We weren't all that concerned if the chains of thought drifted off human-interpretable rationales. There's interesting follow-up work. There's a recent paper from OpenAI where they're interested in legibility. And so if you're interested in the human interpretability, take a look at this. It's called the "Prover Verifier Games" paper.
AUDIENCE: To sort of backtrack, going back to the--
NOAH GOODMAN: Backtracking is good. Helps you think.
AUDIENCE: Yeah. Going back to the multi-digit arithmetic, [INAUDIBLE] I noticed there was a linear-- like, just almost like a linear change in the number of iterations as the number of digits went up. And I was wondering if it's doing the exact same task, just with more digits and more data. Is it always just going to have that same linear trajectory of more iterations as more digits go on? Or how would you--
NOAH GOODMAN: I don't know. My guess would be that's somewhat accidental to the details of the task setup, but you're right that it is pretty strikingly regular. Good thing to explore.
AUDIENCE: Going back to the analogy between giving it the correct answer and viewing it as a form of importance sampling, you said incorporating the importance weight didn't help-- or that's what I understood you said.
NOAH GOODMAN: That's what I said.
AUDIENCE: Could you elaborate a little bit more? How did you incorporate it and why did you think it didn't help?
NOAH GOODMAN: OK. Let me give you a short answer, but this would be a good thing to talk about offline because there's a lot more to say about this. So we naively just-- you can write down the expectation. It doesn't matter if you're only taking one sample, but it matters if you're taking multiple samples-- so this was when we were exploring multi-sample estimators. And we added it, and it literally didn't do anything. It turns out that the importance weights were very extreme, and so it was equivalent to just taking the best trace, which is what one sample more or less does.
So I think-- it didn't help there, but it was not for reasons that I consider super fundamental. There's an interesting paper with a technique that's inspired by this, but much fancier, called "Trace," where they treat the same thing as a proper latent variable estimation problem.
And they-- so interestingly, they do use a much better multi-sample estimator. They get slightly better results from doing that. They talk about the importance-based estimators, but they don't use them because they also tried it and didn't get it to really work. So I feel like there might be something kind of deep in there, but I'm not sure. Yeah. Cool.
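[One way to write out the importance-sampling correction being discussed here-- an assumed framing, not a quote from either paper: rationales are drawn from the hinted proposal rather than the model's own policy, so each sampled trace would carry a weight.]

```latex
% Rationales z for question x are sampled from the hinted proposal q(z \mid x, y^\ast)
% rather than from the model's own policy p_\theta(z \mid x), so an importance-sampled
% estimate of any expectation under the policy would reweight each trace:
\[
  w(z) \;=\; \frac{p_\theta(z \mid x)}{q(z \mid x, y^\ast)},
  \qquad
  \mathbb{E}_{p_\theta}\!\left[f(z)\right] \;\approx\; \frac{1}{N}\sum_{i=1}^{N} w(z_i)\, f(z_i),
  \quad z_i \sim q(\cdot \mid x, y^\ast).
\]
% If a handful of traces get nearly all the weight, this collapses to keeping only the
% best trace, which would be consistent with the correction changing little in practice.
```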
OK. Coming back to stream of search for a second. So now you can do the same thing. You can say, great. So we trained up our initial policy on stream of search. I'm zooming in now on the action. And then we can do STaR on top of that where we say, OK, great, now go solve game of 24 problems and reinforce your own reasoning, your own traces when you get the right answer.
And we get a kind of small, but satisfying boost from doing this. So the model is able to refine its search heuristics a little bit to get better. We did the same thing also with another reinforcement learning algorithm. So STaR is kind of a batch algorithm. It does some solving and then it does some training.
APA is an online algorithm that tries to do better RL, update the policy as you go. It turns out, it does not actually help much, so you get about the same result. It also works, but you get about the same results if you do the fancy online RL thing. And you get improvement from both of them, but in both cases, it tops out at some point.
I think this is a kind of open, interesting problem in self-taught reasoning, that there tends to be a limit to the exploration that models do intrinsically. And I think to go beyond this, you need to do more aggressive forms of exploration as you're teaching it to reason. But that's speculation, I don't know that that's true for sure. OK. So, I'm going to switch to something else-- well, related, but a little bit different. Yeah.
AUDIENCE: Just wondering, does the model generalize across different tasks? So let's say if they were trained to do CommonsenseQA, can they generalize the reasoning strategy to 24?
NOAH GOODMAN: OK, so, I have eight GPUs in my lab, roughly speaking, but I have friends up in San Francisco who have approximately 6 million GPUs or something like that. And they tried this, and it does seem to generalize. And if you talk to o1 you will have stronger evidence of this claim. So, you know--
AUDIENCE: --about that now or later?
NOAH GOODMAN: How about later? Yeah.
AUDIENCE: Because I think that is kind of a key question.
NOAH GOODMAN: I mean-- so we have done this and found weak evidence of generalization across tasks, but also, I think circumstantially, pretty strong evidence that with more scale, we get more generalization, but we don't have the compute or the patience to test that. Let's come back and-- yeah.
But I want to talk about something else, which I find fascinating and, I don't know, it feels a little bit more MIT-flavored to me, which has to do with this other feature of human thinking, which is that we benefit a lot from having externalized ways of reasoning or formalized systems like logic.
So I considered giving the entire talk all about Gabe Poesia's work on math, but then it was pointed out to me that he's on the job market this year and I maybe shouldn't steal his thunder completely. So I'm just going to give a little bit of-- I'm just going to steal a little bit of his thunder.
So Gabriel-- we've been thinking about this problem of doing formal math and learning to do formal math. And in some sense, treating it as a game that you can play without a lot of human input a lot like AlphaZero did for Go. And there were a couple of obstacles.
The first and biggest obstacle was that if you look at formalized math systems like Lean and Coq, these theorem-proving environments, their action space-- you can type an infinite number of things at the prompt at any moment when you're trying to do a proof. And so it didn't match very well the reinforcement learning algorithms that people use for games.
And so the first thing that we did, which was in an earlier paper, is we said, if you take the basis of these systems, which is this dependently-typed lambda calculus-- which, somewhat magically to me, is pretty good at representing all the math that people care about-- you represent theorems as types and proofs as programs in this thing, and it is good at representing lots of interesting math.
So the first step was to say, oh, there's this cool trick from programming languages that we can use. In programming languages, for compilers, you will often do a first step where you take your program, at least a functional program, and you put it into A-normal form. Now A-normal form just means that you break down a really large function call into a series of steps where you assign each input to the function its own name and then just apply the function to each of those names. Details of this don't matter.
The important thing is that we can do the same transformation on the calculus of constructions, this dependently-typed lambda calculus, and it yields a finite action space, because at every step, the only thing we can do is apply a function to some already-created inputs, and since the set of already-created inputs is finite, there's only finitely many things you can do.
That's really cool because it turns math into a finite action game. It's a game where at every step, there's a finite set of possible actions. And so now there's a game tree. Now it's, of course, an infinite game tree, but at every node, it's just like, here are the things you could try.
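[A toy illustration of why naming every subterm makes the action space finite-- this is not the actual Peano environment, and the "rules" below are hypothetical placeholders: a proof state is a set of named terms, and every action applies one of finitely many rules to already-named terms.]

```python
from itertools import product

RULES = {               # hypothetical inference rules: name -> arity
    "eq_refl": 1,       # from a term t, conclude t = t
    "eq_symm": 1,       # from an equality, conclude the flipped equality
    "add_zero": 1,      # from a term t, conclude t + 0 = t
}

def legal_actions(named_terms):
    """Every action is (rule, argument names): a finite set, since both are finite."""
    actions = []
    for rule, arity in RULES.items():
        for args in product(named_terms, repeat=arity):
            actions.append((rule, args))
    return actions

state = ["x", "x + 0"]                 # names of terms created so far
print(legal_actions(state))            # 6 actions here; the game tree branches finitely
```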
OK. Now once you have this finite game tree, you can try all the tricks that people have been enjoying for learning to play games. Before I go on to the next slide-- so those tricks are like-- I'm not going to go into it in detail, but it's like you can use Monte Carlo tree search to explore how to try proving different theorems, and when you find a proof, you can update your policy to do more of that, and so on.
The one thing that makes this strikingly different from learning to play Go or chess or whatever is that Go always starts from the same board state. So does chess. And so there's really only one game in Go. It's big, but there's one game.
In math, every theorem is its own game. Every theorem is its own board that you have to start from. And this means that there's this problem in learning to play the game of math, of getting dense enough data so that when you see a new theorem, the things you've learned from previous theorems are relevant, not completely new.
And so one approach to this is, great, we're going to go get humans to write down a lot of theorems and try to learn how to solve them-- this is kind of the standard thing. We tried that and found that it was just extremely hard to get enough theorems, and we were lazy, so we decided that we would make the AI find its own theorems to prove.
And so it gave rise to this really nice setup where we are trying to co-train two things. The input is just the axioms. So we say, hey, you want to learn how to do arithmetic, here are the axioms of Peano arithmetic. We're not telling you anything else. Humans put in nothing else. And then we have two agents that are playing with each other-- not really against each other, but sort of.
So we have the prover agent that's doing the usual game thing. It takes a conjecture, tries to prove it, and if it wins, it learns some training data. And we have the conjecturer agent. Now, the conjecturer agent's job is just to sample a well-formed-- syntactically well-formed-- conjecture and hand it over to the prover.
Now this would not be an interesting game if the conjecture distribution was fixed because then the prover would just learn to prove those things and no more learning would happen. And so what we do is we co-train them with an objective-- the objective for the prover is just get a good proof. The objective for the conjecturer is this zone of proximal development objective where we reward it if it generates a conjecture which is provable by the prover agent, but not if it's too easy, measured by the probability of getting a proof.
So if it's too hard and the prover can't prove it, no reward. If it's too easy, if it's in the easiest 20 or 50 percentile, no reward. But if it's a conjecture that's provable but hard, then the conjecturer gets rewarded.
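[A sketch of that zone-of-proximal-development reward-- the `prover.try_prove` interface is an assumption, and the fixed success-rate cutoff here is a simplification of the percentile criterion mentioned in the talk.]

```python
def conjecturer_reward(conjecture, prover, too_easy_threshold=0.5, n_attempts=8):
    """Reward conjectures the current prover can prove, but not too easily."""
    successes = sum(prover.try_prove(conjecture) for _ in range(n_attempts))
    success_rate = successes / n_attempts        # proxy for the probability of a proof
    if success_rate == 0.0:
        return 0.0                               # too hard: the prover never proves it
    if success_rate > too_easy_threshold:
        return 0.0                               # too easy: little useful training signal
    return 1.0                                   # provable but hard: rewarded
```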
And so this sets up this really cool dynamic where the two agents are helping each other learn. And-- briefly, this is showing arithmetic, for instance, the conjecturer from the zeroth, first, second, third, fourth iteration of training. And then it's showing it against the prover's policy at the same 0, 1, 2, 3, 4.
And so if you look at the fourth conjecturer and cut a slice through here, you see that those theorems are much harder for the initial policy than the final policy. So it is learning to prove theorems better. And if you cut a slice through in the other direction, fixing, say, the first policy, you see that the conjectures are getting gradually harder. And so they are both improving at their task, proving better, conjecturing harder conjectures, and forcing each other to learn.
There's some analysis against human theorems in the paper. We took all the theorems from the first chapter of a textbook, and the prover also gets better at proving those theorems.
OK. I really just-- this is not somehow on the main path of the talk, but I like it so much, I couldn't not put it in. But let me come back to the main part of the talk and then I'll wrap up. I can-- can I go for a few more minutes? OK, cool. So something that you guys might know but I think is worth knowing about is content effects in human syllogistic reasoning.
All students read. Some people who read are professors, therefore, some students are professors. Does that sound like a good conclusion? No. It is logically valid, but it sounds terrible because it violates the content. The content violates our world knowledge. Whereas the other one, all students read, some people who read also write essays, therefore, some students write essays, that sounds great because it's consistent with our world knowledge.
So this belief bias effect, or content effect has been known since-- in reasoning, in psychology-- since the 80s. And a few years back, Ishita Dasgupta, Andrew Lampinen, and colleagues tried this with language models and found that language models very strongly have the same effect. So if you give language models either valid or invalid syllogisms, if they agree with world knowledge, they are strongly endorsed, and if they disagree with world knowledge, they are not strongly endorsed, independent of whether they're actually valid or not. Not totally independent, but approximately.
So we were basically wondering whether we could-- crucial thing, obvious thing, but crucial thing-- formal logic does not have that property. Formal logic just follows the rules of formal logic. And so there's a question about whether you can support language models to better follow the rules of formal logic by giving them a sort of externalization.
So we built a little system. And what this system does, very briefly, is it lets a language model go ahead and sample sequences just like it wants to, but we give it a special kind of bracket. And when it chooses to generate that special kind of bracket, we go into basically a product of experts mode where we start running the Peano logic system on the side, and constraining what you can generate to be something that is valid according to first-order logic in this case.
And so you might decide you're making an axiom, and then you can make an axiom like Jane is a friend, friend Jane, and that's fine. You might decide that you're doing an inference, and then you're constrained to only sample a sequence that is logically consistent. It's a logically consistent inference according to the rules of first-order logic.
Then you generate the closing bracket-- or you, the language model, generate the closing bracket. And then you're free. You're back, you're free to generate whatever you want until you decide to generate more brackets. So this basically gives a kind of hybrid system where, in this certain mode of thinking in between the brackets, you constrain it: the language model is still responsible for choosing which logically valid thing is generated, but you force it to only generate a logically valid thing.
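[A sketch of that guided-decoding loop: free generation outside the special brackets, and inside them the sampler is restricted to continuations the logic engine certifies as valid. The `lm` and `logic` interfaces and the bracket convention are assumptions, not the actual system's API.]

```python
def generate_with_guide(lm, logic, prompt, max_tokens=512):
    text, in_formal_block = prompt, False
    for _ in range(max_tokens):
        if in_formal_block:
            # Product-of-experts-style step: only sample among the finitely many
            # continuations the logic engine says are valid next inferences.
            options = logic.valid_continuations(text)     # includes the closing bracket
            token = lm.sample_from(text, allowed=options)
        else:
            token = lm.sample_next(text)                  # unconstrained free text
        text += token
        if token == "[[":                                 # assumed opening bracket
            in_formal_block = True
        elif token == "]]":                               # assumed closing bracket
            in_formal_block = False
        elif token == lm.eos_token:                       # assumed end-of-text marker
            break
    return text
```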
OK. It looks like this. There's some few-shot examples. And then we-- based on the few-shot examples, the language model will first choose to formalize the context where it intermixes some informal statements with these more formal statements. It formalizes the goal, and then it goes ahead and does some reasoning where it, again, mixes together the formal and informal.
And then at the end, it generates an answer, which is not constrained to follow from the logic, but all of the logical inference it's done is in context, so it will depend on that.
And this is the result on these belief bias-type tasks. So if you ignore LLaMA for a second, unconstrained GPT-3 and-- GPT-3 and 3.5 have belief bias, they don't do great. If you give it access to this thinking tool, now it's at ceiling and it has eliminated the belief bias effect for the most part.
This is LLaMA 1, and it turns out to be a much, much weaker model, so it doesn't get to the ceiling. Most of the mistakes here are formalization mistakes, so it's not good at translating the English into the formal language. LLaMA 2 came out right as this paper was accepted, so we didn't actually run it, but I suspect LLaMA 3 is vastly better at this.
You can do it for multi-step reasoning, and it helps even more for multi-step reasoning, but I'm going to skip that. And then you can do something interesting. You can say, OK, let's go back to self-taught reasoning. What if we decide we're going to do self-taught reasoning for these logical inference problems?
So if you do this naively-- you do STaR without anything fancy-- it collapses. And it collapses for a very simple reason, which is that for these logical tasks, there are two answers. It's Boolean. And so the model is right half the time, even if it's for the wrong reason. And if you're right half the time for the wrong reason and you reinforce that, you just end up collapsing and you get gibberish.
On the other hand, if we give the model the logic guide, even though the task is binary, when it's right, it has useful thoughts more often, and it starts to bootstrap. And the gold one is interesting because we give it the logic guide and we also check not that it got the right answer, but that it formulated for itself a goal, and that by the end of inference, it had proved the goal it formulated for itself. And that was strong enough for it to now start bootstrapping, and LLaMA learned how to think better.
And this turns out to transfer to legal reasoning as well. We do the same bootstrapping thing for GPT 3.5, train it on synthetic reasoning tasks, and then test it on legal reasoning benchmarks, and we get a boost from doing the same process.
OK. I'm over time. Reasoning is pretty cool. We should think about the mind in terms of this hierarchy of generality to specificity of principles. I won't claim that I have identified the principles at any level here, but I do think that this kind of comparative enterprise, while it looks like I was just doing AI, is a pretty fruitful way to generate ideas about what goes at these different levels. Big thanks to my students, most importantly. Thank you all.
[APPLAUSE]