AI Research Today
AI Research Today unpacks the latest advancements in artificial intelligence, one paper at a time. We go beyond abstracts and headlines, walking through architectures, experiments, training details, ablations, failure modes, and the implications for future work. Each episode will choose between one and three new, impactful research papers and go through them in depth. We will discuss the papers at the level of an industry practitioner or AI researcher. If you want to understand the newest topics in AI research but don't have the time to dig through the papers yourself, this is your solution.
GradMem: Teaching LLMs to Remember (Without Retraining)
In this episode, we break down GradMem, a new approach to memory in large language models:
https://arxiv.org/pdf/2603.13875v1
Instead of relying on the transformer KV cache or repeatedly reprocessing documents (like in RAG), GradMem introduces a different idea—learn a compact memory representation at inference time. Using a few steps of gradient descent, the model “writes” important information from a context into a small set of memory tokens, allowing it to answer future queries without needing the original context.
We cover:
- Why KV cache is a brute-force solution to long context
- How test-time optimization turns memory into something learnable
- The difference between storing text vs. storing information
- What this means for agents, RAG systems, and long-horizon tasks
Big takeaway:
Instead of reading context over and over, models can learn to compress and reuse it intelligently.
Hello, welcome to another episode of AI Research Today. I'm your host, Aaron McClendon from Architects AI, a company specializing in automatic implementation of enterprise-grade software across a pretty wide range of use cases: CRMs, machine learning pipelines, etc. I won't turn this into too much of a pitch, but that's a little bit about my background and where I'm coming from. So today the paper we wanted to discuss is called GradMem: Learning to Write Context into Memory with Test-Time Gradient Descent. The paper stood out to me mainly because it was an approach that I hadn't seen before, so it kind of caught my attention, no pun intended. It is leveraging a few different techniques that I've also found interesting to investigate, especially meta-learning, which was a technique that came out back in 2017. I've mentioned it in a couple of different episodes in the past, if you want to go back and listen to some of those. So I'll walk through primarily the implementation and theory behind the paper: what the technique is and how it works. And then I'll briefly mention some of the experimental procedures they went through, although it's kind of difficult, just because then I'm just quoting numbers to you, which isn't too intuitive. So I just recommend you go review the experimental results when you get a chance. I'll, as usual, link the arXiv paper in the episode description. So when you get a minute free, just go download the paper and give it a read. It's not too long; it's about 11 pages before the appendix and references. That being said, let me start off with the abstract. I'll just read a few excerpts from it and provide some commentary as we go through it. So: many large language model applications require conditioning on long context. This is like what you would see in a typical RAG application.
You might have some million-token documents that you prepend to a query, and then that prompt gets sent into the language model. So you're conditioning on a long context in that situation. Transformers typically support this by storing a large per-layer KV cache of past activations. So this is just your standard autoregressive recurrence in a transformer: you're storing, for each layer, the KV encoding for all the tokens that have been generated to that point, which increases your inference speed and reduces your latency, but it also requires some memory overhead to keep that. It's also transient storage. So once your query's done, the KV cache is typically discarded. They propose an alternative called compressive memory: read a context once, store it in a compact state, and then answer many queries from that state. They study this in a context removal setting where the model must generate an answer without access to the original context. So, what does that mean? Again, I'll draw this back to a RAG sort of setup. In typical RAG, a query comes in, we embed or vectorize the query, and then we run some kind of distance metric to identify relevant documents that have been encoded as vectors as well. And then we take the actual text representation of those documents, append it to the query, and then pass that through our transformer architecture. So this paper is not optimizing the retrieval portion, like what are the top five documents most related to my query? What it's optimizing is: once I have that document, is there a way to more efficiently pass that document through my transformer in the form of a compressed memory state? So I'm no longer taking the text representation of the document that was pulled using this cosine similarity. I'm now taking the compressed memory representation. I skip the tokenization and embedding step and just pass that directly through the transformer architecture.
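To put rough numbers on the memory overhead mentioned above, here is a back-of-the-envelope sketch in Python. The model shape (32 layers, 32 KV heads, head dimension 128, 16-bit values) and the memory size (8 vectors of width 4096) are hypothetical values chosen for illustration; they are not figures from the paper.

```python
def kv_cache_bytes(seq_len, num_layers, num_kv_heads, head_dim, bytes_per_value=2):
    """Bytes needed to cache keys and values for seq_len tokens."""
    # The leading 2 accounts for storing both a key and a value per head per layer.
    per_token = 2 * num_layers * num_kv_heads * head_dim * bytes_per_value
    return seq_len * per_token


def memory_state_bytes(num_vectors, hidden_dim, bytes_per_value=2):
    """Bytes needed for a compressed memory of num_vectors embeddings."""
    return num_vectors * hidden_dim * bytes_per_value


# Hypothetical model shape and a one-million-token context.
cache = kv_cache_bytes(seq_len=1_000_000, num_layers=32, num_kv_heads=32, head_dim=128)
# Hypothetical memory: 8 vectors at hidden size 4096.
memory = memory_state_bytes(num_vectors=8, hidden_dim=4096)

print(f"KV cache:     {cache / 2**30:.1f} GiB")
print(f"Memory state: {memory / 2**10:.0f} KiB")
```

Under these assumed shapes the cache comes to roughly 488 GiB while the memory state is 64 KiB, which is the gap the compressive-memory idea is aiming at. Whether a memory that small can actually hold the needed information is exactly what the paper tests.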
So given a context, GradMem performs a few steps of gradient descent on a small set of prefix memory tokens. I'll call out that this is a bit confusing. They refer to memory tokens here. A memory token is not some special character or an actual token. It's just referring to an embedding vector that they use to represent this document in its compressed state. And they do this while keeping the model frozen. So importantly, the base transformer architecture itself is never edited; the weights of that are never edited. The only thing we're editing is the compressed memory representation. So they treat the memory representation as a learnable parameter itself, and they perform this test-time gradient descent to figure out what that representation should be. And then they compare this to a baseline method, RMT, which is effectively just using the same memory tokens but without doing any kind of optimization. I'll describe the baseline a bit more when we get into the discussion here. But that's the abstract, that's the high-level contribution. So to summarize, they use a meta-learning setup to optimize test-time compute in order to effectively store a compressed representation of a reference document that can be reused for future queries. Okay. So I'll read a couple of snippets from the introduction here. Large language models are increasingly deployed in settings where task-relevant information resides in long context. In these regimes, the challenge is not to support long context, but to do this reusably. The dominant approach is retaining intermediate activations via the KV cache, which reduces recomputation but also increases memory overhead. And a complementary approach is to provide a compact memory state constructed from a context. This reduces the burden of the overhead from the KV cache because you're now working with a compressed representation.
And the question then becomes: how effectively can we compress the representation such that it preserves the original meaning of the reference document? They split this into two stages: a write stage and a read stage. The write stage is independent of downstream supervision, they state, meaning that the typical next-token loss objective for an ideal output is not affecting this write optimization. The write optimization is only looking at reconstruction error from the original context document. So they form a two-stage process here, and that's where their meta-learning comes in. Standard training writes data into the parameters of a model with gradient updates. So they treat the memory as a parameter-like state to store the current context, as opposed to a one-shot forward write process like the baseline they're comparing their method to. And they also do this with a very limited number of gradient update steps. They limit this to five update steps, and they discuss different numbers of update steps in the experimental results section later in the paper. They evaluate this on several tasks. One of them is a synthetic data task, so they just generate random characters that represent keys and values, and then they want to understand how well the model can reconstruct a key or value from a compressed state. And then they also evaluate on some standard benchmarks that are publicly available for information retrieval tasks. So the question is: this compression is inherently lossy to some degree, but how lossy is it? Does this compression allow us to preserve the information that would be needed to make contextual responses from the input documents? So let's move on to the actual problem setup here. They split the task into three sequences.
So there's the context, which, again tying this back to my RAG analogy, could be a reference document, like a PDF document or something. There's the query, which would be the question to be asked over the PDF document. And then the target is the desired response given the query and the context. So what we want to do is enable the prediction of y from the query without direct access to the context C. So the test becomes: we have the PDF, which is our ideal state. That's all the information possible. So if we have full access to the PDF, then we should have some response y. But if we can't access the PDF and can instead only access some compressed state, what's the error for predicting the target y? And that becomes the objective used to optimize the memory representation. So we let f of theta be a standard causal language model, so this is just a model parameterized by some weight set theta, and then they introduce the standard notation of conditioned output. So f of theta given x outputs y, where x is the input sequence and y is our output sequence. And in this case, the input sequence x is just composed of the compressed memory representation and the query. And this is where they break this into two steps. So the first step is the write stage, and they represent this as a function, but it's really an algorithm. I guess it is parameterized by the weights of the base transformer, but it seemed a little misleading when I was first reading the paper, at least. But anyway, they call this epsilon theta. Basically, we take the input context and we pass it through epsilon theta, which is not a model, it's just an optimization algorithm. But either way, it outputs a compressed representation of our context, which we dub M. And then the read stage is: given M and the query, can we produce the ideal output, or what is the error for that? That's the read phase. So they have this two-stage setup.
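For readers following along in text, the setup just described can be written compactly. This is a sketch reconstructed from the spoken description, not copied from the paper: the symbol for the write procedure (read aloud as "epsilon theta") and the name of the read loss are notational assumptions and may differ from the paper's typesetting.

```latex
\begin{aligned}
\text{write:}\quad     & M = \mathcal{E}_\theta(C) \\
\text{read:}\quad      & \hat{y} \sim f_\theta(\,\cdot \mid M,\, q) \\
\text{read loss:}\quad & \mathcal{L}_{\text{read}}(M;\, q,\, y) = -\log p_\theta(y \mid M,\, q)
\end{aligned}
```

Here C is the context, q the query, y the target, and M the compressed memory. The key constraint is that the read step sees only M and q, never C itself.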
Okay. So we'll talk a little bit more about the memory state, which I think is the more complex part of the procedure. So we have this memory state that we're going to represent by, let's say, m vectors. m is really the variable parameter. The dimension of the vectors is d, which is the dimension of the base transformer's hidden state. So think about your transformer. You have some input query, and it gets tokenized. That tokenized query then gets embedded into d dimensions. That's your hidden dimension. This is done by your embedding layer. But either way, if we want to add things to this embedded query, they also need to be of dimension d. But we can add as many or as few things to this prompt as we like. So the memory representation then has a variable m, which is the number of vectors we're using to represent the memory state, and each vector is of dimension d, which is dictated by the underlying transformer itself. So the question is, how do we come up with this memory representation? The naive approach, I suppose, would be to randomly initialize some representation and then optimize on the reconstruction loss to come up with your ideal vector. But this could take many, many iterations. And so this is where meta-learning comes in. Basically, meta-learning is learning to learn. Let's say we have three different tasks: one of them is writing documents, one of them is analyzing documents, and one of them is writing emails. So it's three kind of separate tasks. You could generate fine-tuning data sets for each of those tasks, take a base transformer, and then fine-tune it three times, once for each of those tasks.
Now, the number of steps it takes to fine-tune to each of those tasks, or whether it's even able to approach the optimal loss for each of those tasks, depends on where the base model sits in the overall loss basin, you could think of it. And so what meta-learning does is find the parameters of the base model that are sort of most equidistant to all of the tasks. And so when you come in to fine-tune your model to some new, unseen task, the number of updates required to fine-tune to that task is smaller. So your model is quicker to learn things. And so that's the trick they use here: rather than starting from scratch for each context that we're trying to compress, what if we instead start from a pre-optimized vector that we know is going to be generally close to the optimal compression for any given context? And so that's what we optimize in the outer loop of their meta-learning algorithm. And the inner loop of the meta-learning algorithm is optimizing on the reconstruction loss itself. Okay, so I'll talk through this again in a bit more depth. Let's say we have some context sequence C and it's n tokens long. So we have tokens t1 through tn. The write objective is task agnostic, and it only depends on the ability of the model to reconstruct the context when conditioned on the memory. So let's say we have some initialized representation of memory M, and we're going to construct an objective which is the autoregressive cross-entropy loss. So this is just the negative sum over the n tokens of the log of the probability the model assigns to the correct token, given the compressed representation M and all the tokens leading up to the current index.
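Written out, the reconstruction objective described above looks like the following. The loss name is an assumed label for illustration; the structure (negative log-likelihood of each context token given the memory and the preceding tokens) follows the spoken description.

```latex
\mathcal{L}_{\text{write}}(M;\, C) \;=\; -\sum_{i=1}^{n} \log p_\theta\!\left(t_i \mid M,\; t_1, \dots, t_{i-1}\right)
```

Note that theta stays fixed inside this loss; the only quantity being adjusted to lower it is M.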
So minimizing this obviously is going to push up the probability the model assigns to each correct token, and so is going to get the model closer and closer to producing the correct token given the input. So let's say the reference document is just a sentence, like "I took my dog to the park." We're going to represent that sentence with some numeric compressed representation. And then we're going to look at, okay, let's give the model this compressed representation. The first word of "I took my dog to the park" is "I". So if we give this model the representation, what's the probability it reproduces "I"? Now for step two, given the memory and the first token "I", what's the probability it predicts the token "took"? Obviously, words wouldn't directly map to tokens like that, but just as an illustrative example. So that's our inner loop, that's our reconstruction loss. Can we exactly reconstruct the reference document if we only provide the model with the compressed representation? So we do this for K steps. We just take the gradient of this loss times some learning rate alpha, and then we update our M_k vector. So we use our compressed memory representation, compute our loss, backprop the gradient, and update. Now we have a new memory representation M, and we do this four or five times until we arrive at what we say is our ideal compressed memory representation, M hat, they call it. Now, again, this is where they denote M hat by epsilon sub theta of C, where really epsilon is just this gradient descent process over the reconstruction cross-entropy loss. Okay, so now we get into the read phase. The model receives M hat and the query, and there's some ideal target y. So now we have the log loss of: can we predict y given this compressed representation and the actual query?
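The write loop described above can be sketched in a few lines of plain Python. Everything here is a toy stand-in rather than the paper's implementation: the "frozen model" is a hypothetical linear-softmax predictor with fixed random weights (`W` and `B` are invented names), the memory is flattened to a single small vector, and it starts from zeros instead of a meta-learned initialization. The point is only to show the shape of the procedure: the model never changes, and a handful of gradient steps on the memory lowers the reconstruction loss.

```python
import math
import random


def softmax(logits):
    top = max(logits)
    exps = [math.exp(z - top) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def make_frozen_toy_model(n_positions, vocab, mem_dim, seed=0):
    """Hypothetical frozen model: fixed random weights that are never updated."""
    rng = random.Random(seed)
    # W[i][v] maps the memory to the logit of token v at position i.
    W = [[[rng.uniform(-1, 1) for _ in range(mem_dim)] for _ in range(vocab)]
         for _ in range(n_positions)]
    # B[prev][v] conditions on the previous token; row `vocab` is start-of-sequence.
    B = [[rng.uniform(-1, 1) for _ in range(vocab)] for _ in range(vocab + 1)]
    return W, B


def write_loss_and_grad(model, memory, context):
    """Mean autoregressive cross-entropy and its gradient w.r.t. the memory."""
    W, B = model
    vocab = len(W[0])
    loss, grad, prev = 0.0, [0.0] * len(memory), vocab
    for i, tok in enumerate(context):
        logits = [sum(w * m for w, m in zip(W[i][v], memory)) + B[prev][v]
                  for v in range(vocab)]
        p = softmax(logits)
        loss -= math.log(p[tok])
        for v in range(vocab):
            err = p[v] - (1.0 if v == tok else 0.0)  # softmax cross-entropy gradient
            for j in range(len(memory)):
                grad[j] += err * W[i][v][j]
        prev = tok  # teacher forcing: condition on the true previous token
    n = len(context)
    return loss / n, [g / n for g in grad]


def write_memory(model, context, init_memory, steps=5, lr=0.1):
    """K gradient steps on the memory only; returns M hat and the loss trace."""
    memory, losses = list(init_memory), []
    for _ in range(steps):
        loss, grad = write_loss_and_grad(model, memory, context)
        losses.append(loss)
        memory = [m - lr * g for m, g in zip(memory, grad)]
    losses.append(write_loss_and_grad(model, memory, context)[0])
    return memory, losses


model = make_frozen_toy_model(n_positions=6, vocab=5, mem_dim=4)
context = [2, 0, 4, 1, 3, 2]  # toy token ids standing in for a document
memory_hat, losses = write_memory(model, context, init_memory=[0.0] * 4)
print([round(x, 4) for x in losses])
```

In a real setup the gradient would come from backpropagating through the transformer with the memory vectors prepended to the input embeddings; the hand-written gradient here only works because the toy model is linear in the memory.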
And optimizing that involves a second-order derivative, because the compressed memory representation itself is a dependent variable. So we have to pass that derivative through the inner loop, which is where this second-order derivative comes in, as in a typical MAML process. So the write phase performs a number of per-sample optimization steps on M, and this depends on the model parameters and the shared initialization M naught. So we update M using our standard gradient update, and then the task loss, which is the loss of the ideal output given the compressed representation and the query, is the outer objective used to learn theta. Note, not theta in the sense of the base model weights; we're not changing anything about the base model weights themselves. Theta here means the compressed initialization vector M naught. So what we're meta-learning in this case is the initialization of the memory vector itself. And that's kind of the key to why they only need to do five steps during an actual inference process: our starting point for the compressed representation is presumably so good that we only need to go through five gradient update steps on the reconstruction to get a good compressed representation of this context. So I think that's the main theoretical idea of the paper. After this, they get into the experimental results section. They talk about the different data sets they test. One of them, of course, is what I mentioned earlier, the synthetic data set, which is just synthetically generated key-value pairs, and they want to generate a compressed representation of this context and then try to reconstruct it exactly from the compressed representation.
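The outer loop described above, where the initialization M naught is learned by differentiating through the inner gradient steps, can be illustrated with a deliberately tiny example. This is a toy under loud assumptions, not the paper's training code: the memory is a single scalar, the write loss is a quadratic 0.5 * (m - c)^2 with each "context" c just a number, and the read target is taken to be the context itself. What it does show faithfully is the MAML mechanic: the outer gradient is multiplied by the derivative of the inner updates, which contains the Hessian of the inner loss.

```python
def inner_write(m0, c, steps, alpha):
    """Run the inner write loop and track d(m_hat)/d(m0) through it."""
    m, dm_dm0 = m0, 1.0
    for _ in range(steps):
        grad = m - c   # gradient of the toy write loss 0.5 * (m - c)**2
        hess = 1.0     # its second derivative (constant for a quadratic)
        m = m - alpha * grad
        # Chain rule through one update: d(m - alpha*grad)/dm = 1 - alpha * hess.
        # This factor is the second-order term the outer loop needs.
        dm_dm0 *= 1.0 - alpha * hess
    return m, dm_dm0


def post_write_loss(m0, contexts, steps, alpha):
    """Average toy read loss after writing from initialization m0."""
    total = 0.0
    for c in contexts:
        m_hat, _ = inner_write(m0, c, steps, alpha)
        total += 0.5 * (m_hat - c) ** 2
    return total / len(contexts)


def meta_train(contexts, steps=5, alpha=0.1, beta=0.5, epochs=200, m0=0.0):
    """Outer loop: learn the shared initialization m0 by gradient descent."""
    for _ in range(epochs):
        meta_grad = 0.0
        for c in contexts:
            m_hat, dm_dm0 = inner_write(m0, c, steps, alpha)
            # d(read loss)/d(m0) = d(read loss)/d(m_hat) * d(m_hat)/d(m0)
            meta_grad += (m_hat - c) * dm_dm0
        m0 -= beta * meta_grad / len(contexts)
    return m0


contexts = [1.0, 2.0, 6.0]  # three toy "documents"
m0_meta = meta_train(contexts)
print("meta-learned init:", round(m0_meta, 4))
print("loss from zero init:", round(post_write_loss(0.0, contexts, 5, 0.1), 4))
print("loss from meta init:", round(post_write_loss(m0_meta, contexts, 5, 0.1), 4))
```

In this toy the learned initialization lands at the mean of the contexts, which matches the "most equidistant to all the tasks" intuition from the episode: the same five inner steps get closer to every target when they start from a well-placed point. With a real model the Hessian is not a constant, and an autodiff framework would handle the second-order term.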
The benefit of this approach is that there's no way the underlying knowledge the pre-trained transformer has natively could affect its ability here, because this is effectively just random noise we're predicting. So this kind of becomes their baseline test. They also test a couple of other data sets: SQuAD, bAbI. They're mainly data sets that test information retrieval. So: I took my dog to the park, my dog chased a squirrel, and then the question would be, where did I take my dog? The answer being the park. Those sorts of things. Some of them are constructed from Wikipedia entries. And one of them, the bAbI set, has increasing levels of task difficulty. So level one is like, where did I take my dog? Level two is, where did I take my dog and what did he do? So you can just kind of increase the difficulty of the questions you're asking, up to what they call QA5, which is the hardest question retrieval difficulty. And they baseline this on several different architectures, one being the full-attention transformer, which is just your standard decoder-only transformer architecture. They also try this with the Mamba model, notably Mamba 2, which I guess is the default now. Mamba 2 is a different architecture than a standard transformer. Caveat: different in most senses, but in the Mamba 2 paper itself, I remember they showed that there is an equivalence between linear attention and some constrained implementation of Mamba 2. So there is a relationship between Mamba and transformers, but either way, the Mamba they're testing here is not that specific equivalent form. They test Recurrent Memory Transformer as well, as their baseline. So the baseline method they're comparing against, RMT, is a forward-only write procedure.
So this is another proposed procedure where, to get the compressed representation of your context, you take your m vectors, which are going to become your compressed representation, you take the context itself, and then you pass that through your transformer. Each layer performs the attention procedure, with the context attending to the compression and the compression attending to the context. And then at the last layer, you have this hidden state representation, and you take that representation at the indices of the compressed vectors, and that becomes your compressed representation. That's the baseline they compare against, so that would be no live test-time optimization; you're basically using the pre-trained transformer directly to generate the compressed representation. They try this on a few different tests for these different model architectures, and they try this for an increasing number of recall requests, so compressing a document of four KV pairs, eight KV pairs, up to 96, and they look at how each of them performs as they scale up the number of requirements and change the architecture. Their method performs well across the board. They note that the baseline procedure has a catastrophic forgetting sort of thing happen when they get to a certain size of the compressed representation, whereas their method stays at more or less 100% performance across scale. And they also look at the rate of exact matches and things like that. So I think that's where I'll cut off the discussion. We're pretty much at the end of the paper. Again, I would encourage you to go look at the actual results themselves. The plots are interesting, and they do have some useful diagrams showing the meta-learning process, which can be a little difficult to visualize. Hopefully, I did a decent job of explaining what it is.
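For a concrete feel of the synthetic key-value task and the exact-match metric mentioned above, here is a small Python sketch. The string format (`key:value` pairs separated by spaces), the key and value lengths, and the function names are all assumptions made for illustration; the paper's actual data format may differ.

```python
import random
import string


def make_kv_context(num_pairs, key_len=4, value_len=4, seed=0):
    """Generate random key-value pairs and a flat text context holding them."""
    rng = random.Random(seed)
    pairs = {}
    while len(pairs) < num_pairs:  # loop until we have enough distinct keys
        key = "".join(rng.choice(string.ascii_lowercase) for _ in range(key_len))
        value = "".join(rng.choice(string.ascii_lowercase) for _ in range(value_len))
        pairs[key] = value
    context = " ".join(f"{k}:{v}" for k, v in pairs.items())
    return pairs, context


def exact_match_rate(predictions, pairs):
    """Fraction of keys whose predicted value equals the true value exactly."""
    hits = sum(1 for key, value in pairs.items() if predictions.get(key) == value)
    return hits / len(pairs)


pairs, context = make_kv_context(num_pairs=8)
print(context)

# A hypothetical model output that gets one value wrong.
predictions = dict(pairs)
first_key = next(iter(pairs))
predictions[first_key] = "????"
print("exact match:", exact_match_rate(predictions, pairs))
```

Because the keys and values are random characters, a model cannot lean on pre-trained knowledge to guess them, so any correct answer has to come from what was written into the memory.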
But yeah, I recommend you take a look through the paper yourself. Take a look through the MAML paper as well. It came out around 2017. I don't remember who wrote it, but it should be easy to find if you just Google around for the arXiv release of the paper. So that being said, I'll thank you for tuning in. Again, I'm Aaron McClendon from Architect AI. If you have papers you would suggest I cover or any comments on the episodes, etc., feel free to email me at aaron.mcclendon at architect architect dashai.com. Other than that, hope you have a great rest of your week. Thanks.