AI Research Today

Generative Recursive Reasoning

Aaron Season 1 Episode 12

In this episode, we explore the paper "Generative Recursive Reasoning (GRAM)," a fascinating new approach to AI reasoning co-authored by Yoshua Bengio and researchers from Mila and Samsung AI.

Most modern AI systems reason by generating more tokens. GRAM takes a different approach: instead of extending a chain of thought, it repeatedly refines an internal latent state. The key innovation is introducing probabilistic reasoning trajectories, allowing the model to explore multiple possible solutions simultaneously rather than committing to a single deterministic path.

We discuss:

  • Recursive Reasoning Models (RRMs) and why they differ from traditional transformers
  • The limitations of deterministic latent reasoning
  • How GRAM introduces stochastic latent trajectories
  • Variational inference and the roles of pθ and qϕ
  • Multi-hypothesis reasoning and inference-time scaling
  • Results on Sudoku, ARC-AGI, N-Queens, and other structured reasoning benchmarks
  • Why latent-space reasoning may become an alternative to longer chain-of-thought prompting

The paper also demonstrates unconditional generation capabilities, suggesting a path toward reasoning systems that can both solve problems and generate structured outputs through recursive latent computation.

PDF:
Generative Recursive Reasoning

Arkitekt AI:
https://arkitekt-ai.com

Contact:
support@arkitekt-ai.com

Hello, welcome to another episode of AI Research Today. I'm your host, Aaron, from Arkitekt AI. We're an AI firm that generates software using our proprietary AI system: complex, enterprise-grade things like CRMs. Today I will again be covering a paper that I found particularly interesting. This paper is called Generative Recursive Reasoning. It's out of the Quebec AI Institute, NYU, University of Montreal, and KIST, which I'm not familiar with. The paper caught my interest mainly because it's a fairly novel approach, at least from what I can tell. Novel to me, at least. I haven't heard a lot about this approach. There is cited prior research, but it's not prevalent as far as I can tell. So there's a developing research thread along this trajectory, but it's still in its infancy, I would say.
It's a new generative type of model. Specifically, this paper adds a variant on this new type of generative model and makes it stochastic. So it becomes non-deterministic, sort of like an LLM where you're sampling from logits rather than producing some kind of deterministic output, broadly speaking. Like usual, I'll talk through each section of the paper.
Let's start off with the abstract, which I'll summarize a bit. Recursive reasoning models offer a promising alternative to autoregressive sequence extension by performing iterative latent state refinement with shared transition functions. So this, like I mentioned, is sort of an alternative to existing generative models. The applications in this paper aren't language specifically. They use things like Sudoku and chess and other things like that, but it could presumably be extended to language modeling tasks.
It functions sort of like a diffusion U-Net type of architecture, but it also has callbacks to an RNN-style recurrence mechanism. They call their models GRAM models, generative recursive reasoning models, so that's the title of the paper. GRAM models reasoning as a stochastic latent trajectory. Again, this is a refined latent state that's being iterated on.
The key concept here is that it's unlike a traditional transformer architecture, where we're stacking decoder layers on top of each other. In that case, you're increasing the amount of compute required to run, train, or host the model. Here they organize the compute differently: they're effectively reusing the same decoder block and just running it many times. As a simple example: I have an input, I put it through a decoder layer, get an output, cycle that back in, and repeat that as many times as I want until I'm ready to decode it. It's a little different than that, obviously more sophisticated, but that's a high-level visual intuition of what's going on.
Traditional recursive reasoning models are deterministic, meaning that there's no stochasticity introduced during inference or training. Think of it like a simple neural network: if I run the same network many times on the same input, I'm going to get the same output. It's just a set of weights and biases, and that gives me the same output vector for a given input vector. They want to make this a stochastic process. The benefit they propose in the abstract and the paper is that one input may have multiple valid outputs, or multiple ways to think about getting to a valid output. So constraining the model to generating a single output is constraining its thought space, I guess you could say.
So they train a conditional model, where you have X and then you sample Y based on X, and this is done through an iterative latent process.
Okay, let's move into the introduction. They mention a couple of prior works on these types of models, HRM and TRM, as providing evidence for this approach to structured reasoning. These models perform extended computation through iterative latent state refinement with deep supervision across steps, and this makes them well suited to problems requiring constraint propagation. They have a picture in Figure 1 of the paper where they compare deterministic RRMs against GRAM, their generative model. The point again is that for some X there may be multiple valid outputs, and if we constrain ourselves to a single latent trajectory, we're going to wind up at the same output each time, and we're not exploring these multiple valid ways to get to a viable solution.
The benefit of RRMs is that we are organizing compute in terms of recursive depth. This, like I mentioned, is different from a traditional transformer setup, where we're stacking layers, so we're increasing compute and running a single forward pass. That's the way these models have existed up until this paper. They mention, though, that they add an additional knob, which is width: the number of sampled trajectories. Like I mentioned, the internal latent state trajectory that's generated is non-deterministic. So every time you run this model, you're likely to get a different internal reasoning state, which could lead you to a different output state. So, in addition to being able to extend recursive depth, you can also increase the number of parallel samples, which is what they do in the paper.
They look at scaling inference wide and deep, and the pros, cons, and trade-offs of that. Obviously, there are some other mechanisms that we'll get to that you need to consider when doing that. But that's what they're arguing is the benefit of this method.
Okay, let's get into the actual architecture a bit. There's a lot of nesting here. It's like a Russian doll sort of thing. We take some input, and again I'll clarify that we're talking about general numeric input. We're not talking about tokenization or anything like that, just some vector of input numbers that could be an encoded image or, in their case, a Sudoku board. So we have some input vector of numbers, and we're going to do the standard thing of embedding it. That's just going to blow it up to some number of hidden dimensions. From this point forward, when I refer to the input, I'm not referring to the raw input, I'm talking about the encoded input.
Every time you run this model, you start off with some initialized latent state, which we'll call z_0. Your next iteration of the network is going to take in the embedding plus z_0. And you continue this: every time you generate a new latent state step, you're taking in z_{t-1} and the embedding. So the embedding is persistent, it's always being applied. This helps keep the model on track as it's thinking, quote unquote. Without the input embedding being applied in each latent transition step, you could just get way off task. They have a several-step architecture here, and they index latent transitions with little t.
So we have z_t, with t going from 1 to some final number of transitions, capital T. And we can do this N times. You could think of it like this: we take z_0, our raw initial state, we refine that T times, so now we have z_T, and then that becomes the input to the next supervision step, and we can do some number of supervision steps. So the total number of steps that we're taking in this model is T times N.
Within a specific latent transition t, there are several things going on. They make this reference to a hierarchical instantiation of z. Keep in mind, z is this internal latent state. z itself is composed of two subcomponents, which they call H and L. H is the high-level piece of information, more static across transition steps, and L is like the scratch pad within a transition step.
Within a transition, there are multiple steps taking place. Step one is the low-level component L. We have some function f_L, and in this case the paper uses a decoder block with SwiGLU and some other fairly standard transformer decoder components, so we have an attention-plus-MLP sort of mechanism. We input the encoding into that, and it runs K times within a single transition step to give us our final L value. The H value, the high level, is an input to this network as well. So for the low-level network, we have the low-level state, which is iterated K times, a fixed input from the high-level state, which helps keep it on track, and the encoding of the initial input. Okay, so that was step one: we refine L K times, which is our low-level, intra-step thinking mechanism.
Then that goes into our second step, which is f_H. This is another decoder-style block, but a different one, trained for high-level reasoning, which is the H component of the latent state. It takes in the refined L component and the previous step's H component and outputs u. Up to this point, this has all been deterministic, with the exception of H, which I'll sideline to avoid confusion. But within this step, everything has been deterministic so far.
The u that's produced from H and L is where the non-determinism comes in. We sample from either a posterior or a prior, depending on whether we're training, and I'll discuss that a bit more in a few minutes. Let's just say that for an inference pass, we sample from our prior, and the prior gives us some noise. This is Gaussian noise. Again, I'll give some more details in a minute, but I'm just thinking through the architecture at a high level. Our refined L and H are used to sample some noise, and this noise forms a perturbation around the deterministic answer. Then the noise and u are used to generate our final H value. So the output of this step is a new H and a new L, conditioned on previous steps' noise. I mentioned that I'm going to sideline H itself. I said that because H does include non-deterministic things from previous steps. But within a specific latent state refinement step, the operations are deterministic up to the point where we're sampling noise from the prior or the posterior.
Okay, let's see here. I'll just quote this line from the paper: unlike prior recursive reasoning models that update latent state deterministically, GRAM defines a stochastic transition, so repeated computation induces a distribution over latent reasoning trajectories.
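To make the nesting described above concrete, here is a minimal sketch of one stochastic transition and the outer loops around it. It is not the paper's implementation: `toy_block` is a hypothetical stand-in for the learned decoder blocks f_L and f_H, the zero-mean noise with a state-dependent scale is an assumed form of the "state-dependent Gaussian", and all function and parameter names are made up for illustration.

```python
import math
import random


def toy_block(*vectors):
    """Hypothetical stand-in for a learned block: average the inputs, then tanh."""
    return [math.tanh(sum(vals) / len(vals)) for vals in zip(*vectors)]


def transition(h, l, e, k_inner, rng, stochastic=True):
    """One latent transition: K low-level updates, one high-level update, optional noise."""
    for _ in range(k_inner):
        # f_L sees its own state, the fixed high-level state, and the embedding.
        l = toy_block(l, h, e)
    # f_H gives the deterministic proposal u from the previous h and the refined l.
    u = toy_block(h, l)
    if not stochastic:
        return u, l
    # Assumed noise model: zero-mean Gaussian whose scale depends on the current state.
    sigma = [0.05 + 0.05 * abs(ui) for ui in u]
    eps = [rng.gauss(0.0, s) for s in sigma]
    # New high-level state = deterministic update plus sampled perturbation.
    return [ui + ei for ui, ei in zip(u, eps)], l


def run(e, n_supervision=2, t_inner=3, k_inner=2, seed=0, stochastic=True):
    """Run N supervision steps of T transitions each and return the final high-level state."""
    rng = random.Random(seed)
    h = [0.0] * len(e)  # initial high-level state (part of z_0)
    l = [0.0] * len(e)  # initial low-level state (part of z_0)
    for _ in range(n_supervision):
        for _ in range(t_inner):
            h, l = transition(h, l, e, k_inner, rng, stochastic)
    return h


embedding = [0.5, -1.0, 0.25]  # pretend this is the embedded puzzle
print(run(embedding, seed=1))            # one sampled trajectory
print(run(embedding, seed=2))            # a different sampled trajectory
print(run(embedding, stochastic=False))  # deterministic RRM-style run
```

With `stochastic=False` the seed has no effect, which mirrors the deterministic models; with noise switched on, each seed gives a different trajectory from the same input.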
Latent reasoning trajectories are just our internal thought process, and we're creating a distribution over those rather than a single one. This is realized as a stochastic residual perturbation around a deterministic update, which is what I mentioned u and the epsilon noise sample form. At each transition, we first compute the deterministic update u, and then we sample a conditional perturbation from a state-dependent Gaussian. So this will change depending on the state that's put in, and that's where we can get different trajectories from. z, finally, is a combination of the sampled noise, which is state dependent, and the u value, which is deterministically calculated within a step. Again, there is noise and non-determinism in u itself, just based on prior updates, but within the step, z is a sum of the u value, which itself is a function of h and l, and the sampled noise epsilon.
All right. They mention again that the low-level component is updated K times within a single transition, so it's carrying some intermediate intra-step computation. And we sort of have this fast and slow thinking. This reminds me of a paper I covered from DeepMind, called Nested Learning: The Illusion of Deep Learning Architectures. You can go back and listen to that episode if you're curious. That episode discussed the high-level and low-level modes of an LLM, and how to make learning dynamic or adaptive during inference. The idea comes from a book written a while ago called Thinking, Fast and Slow, which categorizes human thought into two separate paradigms. You have fast thinking, which covers the things you can do on autopilot, maybe brushing your teeth or putting your pants on, that sort of thing.
That fast mode uses just enough energy to accomplish tasks that we are, in some sense, subconsciously confident we already know how to do. And then there's slow thinking. If you're going to sit down and do a math problem, that's the kind of thinking that would be engaged, because it requires you to step out of autopilot mode, turn on the old brain, and think through things deeply, and it takes energy, concentration, and focus to do that. That's not something we're doing the majority of the day. That paper and this paper both try to take something like that into account, modeling deep learning architectures along more of the same thread that a biological organism would use to think. That's a parallel you could draw with this high-level and low-level component. The H component is more slowly varying across latent transitions, and the low-level component is a rapid, transient update that's happening within the step itself.
They call out that stochasticity is only introduced at the high level, like I mentioned. The low level is deterministic. They do make a note that they tried adding stochasticity into the low-level update, but it only added complexity, and they didn't see any performance gains from it.
Okay, so how do you train a model like this? That's always the kicker once you understand the architecture. I remember reading through the Mamba paper the first time, and you think, okay, I get the architecture, but then you think through the actual training process, and it's sort of a separate beast. So I'll try to summarize here. There's a lot of complex math notation, with a lot of expectation values, as you can imagine.
So I do suggest that you go and review this yourself, and maybe use this recording as a guide to help understand what's going on. The sequence of latent variables, a full trajectory, they call it, runs over T total steps. T total, like I mentioned earlier, is T, the number of intra-step refinements, times N, the number of supervision steps, which is your outer iteration counter. So T times N is your T total.
The conditional likelihood of this is obtained by integrating over all possible trajectories. You can think of a trajectory sampled given x, and then, given x and our trajectory, what is our y? So you have two conditional probability distributions there. We're integrating over what they call tau, which is your full trajectory, including your T and your N steps all together as one. But direct maximum likelihood estimation of this is intractable, because we have to marginalize over all of these latent trajectories, of which there are infinitely many.
So they use a variational posterior trick, like the VAE trick, and use the ELBO, the evidence lower bound. Basically they lower-bound the log-likelihood of y given x with some expectation values, and those expectation values become what we're optimizing our objective toward. There are two terms in the lower bound that they're constructing. One of them is the expectation of the log-probability of y given tau and x, where tau is sampled from the posterior, q sub phi, which they introduce here. I know it's hard to think through in your head just listening, so I'll just talk about two different distributions here, the prior and the posterior. In the paper they use pθ and qϕ.
I won't go discussing Greek letters and so on, I'll just talk about the posterior and the prior. We sample a trajectory from the posterior, and then, given the trajectory and the input, we sample y from the prior side of the model. So we're sort of cheating here, right? Because we're generating a trajectory from a distribution that has already seen the answer. This is the variational trick. We generate y using a distribution that knows y in advance and just gives us the latent state that would lead to that y. And then we regularize that with a KL divergence term, the Kullback-Leibler divergence. This Kullback-Leibler term pulls the posterior and the prior together.
So we generate good trajectories from our posterior, that's used to generate our output, and we're maximizing the likelihood of that good output. Then we're regularizing it, so we're pulling the prior along with the posterior during training, and as a result we get a good prior. The prior itself doesn't need the answer to generate a trajectory, so it should just generate trajectories similar to the posterior's. That's the trick. When we're training, we're using a cheat code, the actual function starts to model the cheat code, and then during inference we scrap the cheat code, and the actual function should work as expected because it was trained using the answers.
They mention that the prior and the posterior are modeled as Markov processes over latent states. This is a reinforcement learning throwback, where we're just saying that the current state is a function only of the previous state, and not some bigger, more complex thing going back through time. So each distribution becomes a product over each t in T total.
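For readers following along without the paper open, the objective described above can be written out roughly as follows. This is a sketch in generic VAE-style notation reconstructed from the spoken description, not copied from the paper: the exact conditioning of each term is an assumption, with \tau = (z_1, \dots, z_{T_{\text{total}}}), T_{\text{total}} = T \cdot N, and z_0 a fixed initial state.

```latex
% Marginal likelihood: integrate over all latent trajectories (intractable)
p_\theta(y \mid x) = \int p_\theta(y \mid \tau, x)\, p_\theta(\tau \mid x)\, d\tau

% Evidence lower bound: reconstruction term minus KL regularizer
\log p_\theta(y \mid x) \;\ge\;
  \mathbb{E}_{q_\phi(\tau \mid x, y)}\big[\log p_\theta(y \mid \tau, x)\big]
  - \mathrm{KL}\big(q_\phi(\tau \mid x, y)\,\|\,p_\theta(\tau \mid x)\big)

% Markov factorization of prior and posterior over latent states
p_\theta(\tau \mid x) = \prod_{t=1}^{T_{\text{total}}} p_\theta(z_t \mid z_{t-1}, x),
\qquad
q_\phi(\tau \mid x, y) = \prod_{t=1}^{T_{\text{total}}} q_\phi(z_t \mid z_{t-1}, x, y)

% With both factorized, the trajectory-level KL splits into a sum of per-step KLs
\mathrm{KL}\big(q_\phi(\tau \mid x, y)\,\|\,p_\theta(\tau \mid x)\big)
  = \sum_{t=1}^{T_{\text{total}}}
    \mathbb{E}_{q_\phi}\Big[\mathrm{KL}\big(q_\phi(z_t \mid z_{t-1}, x, y)\,\|\,p_\theta(z_t \mid z_{t-1}, x)\big)\Big]
```

The last line is the standard consequence of taking the log of a product: the log of each factorized distribution becomes a sum, so the regularizer can be computed one transition at a time.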
We can plug that back into the expectation values I mentioned above, and it simplifies from a product, that giant pi notation, into a sum. I won't read that off for you. I don't think it would be helpful, so I'll let you take a look at it when you get a minute. But just as a note, they include this Markov assumption in the expectation values, and it simplifies things into summations. Practically, it is extremely difficult to propagate gradients back through all of those latent state updates, so within each block of inner transitions T they propagate gradients only through the final update. So you have this surrogate objective that's optimizing the last transition state for each T along your N supervision steps, and that becomes your training objective.
A couple of other things I want to mention, which they call out at the end. There are multiple axes we can scale during inference. Depth means recursing longer, so we can just do more transitions and steps before decoding. This is the same axis that all the prior research has, but those models are deterministic, so we're just walking further toward the fixed point the more recursions we do. They also use adaptive computation time, ACT, which they inherit from HRM. Instead of always running a fixed number of supervision steps, the model learns when to stop. There's an auxiliary halt head they mention in the appendix that does a Q-learning-style update on whether decoding now would be correct versus whether another step is worth it. So easy problems should halt earlier than harder ones. It's sort of an adaptive, dynamic, smart depth calculation.
They also mention width: we can sample many trajectories in parallel. This is the net new axis, which doesn't exist in other research and exists only because GRAM is stochastic.
So we run the model N independent times from the same input, draw some noise, get N different trajectories, decode each one's terminal state, and now we have N candidate solutions. Then the next step becomes, how do we choose which one? They mention a couple of strategies. One is a majority vote: just choose the answer that appears most often. They also talk about a latent process reward model. Instead of voting, you score each trajectory and pick the best one. This is a learned value head, so another small head bolted on, trained to predict how good a trajectory will turn out, reading from its latent state. So that's the high-level overview of the model itself.
They train this and test it on Sudoku, and they also look at the ARC-AGI-1 benchmark. These are puzzle benchmarks. They compare to previous research on recursive reasoning models. They also compare to some of the frontier models, like o3-mini, GPT-5.2 with low thinking, and Grok 4 Thinking. It is surpassed in performance by Grok 4 Thinking and some others on the ARC-AGI-1 benchmark. But the model is way smaller. I think they have something like a 10 million parameter model here. So for these specific puzzle tasks, they're able to outperform models that are way larger simply by running the recursion.
Now, thinking through usage of this, there's always a trade-off between test-time and training compute. Here inference is going to take longer, right? Because we're running a bunch of loops to generate and refine an answer, versus a traditional autoregressive transformer, which is just going to generate some output token by token. So it's the constant tug of war: are we offloading compute to test time, or are we just going to train a larger model?
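The two width-based selection strategies mentioned earlier, majority vote and scoring with a learned value head, can be sketched in a few lines. This is only an illustration of the selection logic: `toy_run` is a hypothetical stand-in for one stochastic model run plus decoding, and its answers and value scores are made-up numbers, not results from the paper.

```python
from collections import Counter


def sample_candidates(run_once, n_samples, base_seed=0):
    """Run the (stochastic) model N times with different seeds; collect (answer, value) pairs."""
    return [run_once(base_seed + i) for i in range(n_samples)]


def majority_vote(answers):
    """Pick the decoded answer that appears most often among the candidates."""
    return Counter(answers).most_common(1)[0][0]


def best_by_value(answers, values):
    """Pick the answer whose trajectory received the highest value-head score."""
    best_index = max(range(len(answers)), key=lambda i: values[i])
    return answers[best_index]


def toy_run(seed):
    """Hypothetical stand-in for one sampled trajectory: returns (decoded answer, value score)."""
    if seed % 3 == 1:
        return "1243", 0.9
    return "1234", 0.4


candidates = sample_candidates(toy_run, n_samples=5)
answers = [answer for answer, _ in candidates]
values = [value for _, value in candidates]

print("majority vote:", majority_vote(answers))
print("best by value:", best_by_value(answers, values))
```

In this toy setup the two strategies disagree on purpose, which shows why the choice of selector matters: voting rewards the most common answer, while a value head can prefer a rarer answer it scores highly.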
I really like this approach because, at least on these benchmarks, they show really good performance, and you don't need a massive amount of compute to run the model. The big trade-off with some of the larger frontier models is that there's no way you could run them on consumer-grade laptop hardware. So you're stuck pushing data via API to some off-site company that's running inference on some mega cluster and returning a response to you. Here, you could run this easily on consumer hardware. I would go so far as to say you probably don't even need a GPU for a model of this size.
So it's quite cool if you can get improved models like this for generative tasks and run them on edge devices. That would be a nice addition to a lot of use cases that require private inference and small compute. Think of IoT environments, like a manufacturing IoT environment, where I have some routers or small computers sitting throughout a factory, like a PLC sort of processor recording manufacturing data, and I need to run inference on something to predict whether this weld is good or whatever. There'd be a use case for having an on-prem device there. That use case is probably more of a traditional machine learning application, but there are language ones as well. For example, one client we were working with, a law firm, had large amounts of legal documents getting deposited. These had private client information and so couldn't be pushed off-prem. Then it gets into this thing of, okay, we're fine-tuning a small language model on tasks that we had to construct a fine-tuning set for, and then deploying that. And even then, performance is a little iffy. In that specific situation we were using a Mistral 7B.
This was a couple of years ago, when that was current, and it still requires something like 24 gigs of VRAM to run a BF16-precision model, if I remember correctly. So that's still a pretty hefty GPU, and it still requires a monthly spend. But if you could get comparable performance with a significantly smaller model like this, it's going to be quite beneficial from a cost perspective.
I didn't see, and someone please correct me if I'm wrong, inference time or latency comparisons in the paper. It is a pretty lengthy paper, so it's possible I just missed it. But your trade-off is that inference is probably going to take longer, and for tasks that run in batches, which covers a lot of industry use cases, that's probably fine. There's a lot of work in the edge-device, small language modeling space, and I think GRAM is an interesting proposal for how to solve generative tasks on small edge devices with minimal compute available. So I'm very interested to see where it goes.
They do link to their GitHub, and I'll include a link to the arXiv paper in the episode description. So take a minute and read through the paper. Really cool stuff. I like seeing novel things like this versus a tweak-on-another-attention-mechanism type of paper.
Other than that, I think I'll go ahead and wrap this episode up. So thanks for listening. Again, I'm your host, Aaron. You can email me if you want, at support@arkitekt-ai.com. I'll include that email in the description. Let me know if there are papers you'd like me to cover. Other than that, I hope you have a great rest of your week. Thanks.