Design Memory Mechanism from Optimizers
From optimizer steps to memory updates
Motivation
Continuing from my last blog post, where we discussed how linear layers can be seen as fast-weight memory built on top of the whole training dataset: over the years we have made substantial progress in optimizers, from momentum SGD to AdamW and, more recently, Muon. You might ask, if such a duality exists, why not copy the optimizers' update steps as the memory update mechanisms? In this blog post, we explore that possibility. All code and implementations can be found under this repo.
MSGD as Memory Updates
Momentum SGD (with optional Nesterov and weight decay). In standard momentum SGD we keep a velocity $v_t$ that exponentially averages gradients $g_t$:
$$v_t = \mu\, v_{t-1} + g_t, \qquad w_t = w_{t-1} - \eta\, v_t.$$
With decoupled weight decay (or simple $\ell_2$ decay) the parameters are additionally shrunk:
$$w_t = (1 - \eta\lambda)\, w_{t-1} - \eta\, v_t.$$
For Nesterov momentum, the gradient is evaluated at the lookahead point $w_{t-1}-\mu v_{t-1}$.
Copying these steps into a memory update. Treat the per-timestep outer product $G_t = q_k\, (k_t v_t^\top)$ (where $q_k = 1/\sqrt{d}$ is the usual attention scale) as an analogue of a "gradient" for the memory matrix. We then filter $G_t$ with a momentum-like EMA to get a memory velocity $\tilde{G}_t$ and apply an exponential decay to the update:
$$\tilde{G}_t = \mu\, \tilde{G}_{t-1} + G_t, \qquad \Delta M_t = (1-\lambda)\, \Delta M_{t-1} + \eta_{\text{mem}}\, \tilde{G}_t, \qquad M_t = \sum_{s \le t} \Delta M_s,$$
where $\mu$, $\eta_{\text{mem}}$, and $\lambda$ correspond to momentum, mem_lr, and mem_decay in the code below.
Operationally, this yields a momentum-smoothed sequence of memory slices whose cumulative sum forms $M_t$; querying uses the usual $y_t = q_t^\top M_t$ contraction.
Show Me the Code
import math
import torch
import torch.nn as nn

# `compute_mem_momentum` runs a momentum recurrence over the sequence
# dimension; it is defined in the accompanying repo.

class MSGDFastWeightMemory(nn.Module):
    def __init__(self, momentum=1.0, mem_lr=1.0, mem_decay=0.1, nesterov=False):
        super().__init__()
        self.momentum = momentum
        self.mem_lr = mem_lr
        self.mem_decay = mem_decay
        self.nesterov = nesterov

    def forward(self, qkv):
        """Fast-weight memory with momentum-SGD-style updates.

        Arguments
        ---------
            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
        """
        q, k, v = qkv.unbind(dim=2)                       # each (B, S, H, D)
        q, k, v = [x.transpose(1, 2) for x in (q, k, v)]  # -> (B, H, S, D)
        qk_scale = 1.0 / math.sqrt(q.shape[-1])
        # per-timestep "gradient" G_t = qk_scale * k_t v_t^T, shape (B, H, S, D, D)
        mem_slice = qk_scale * k.unsqueeze(-1) @ v.unsqueeze(-2)
        # compute the momentum, that is the first layer of memory
        mem_momentum = compute_mem_momentum(mem_slice, self.momentum, nesterov=self.nesterov)
        mem_slice = self.mem_lr * mem_momentum
        # apply the exponential decay (weight-decay analogue) to the update
        mem_slice = compute_mem_momentum(mem_slice, 1 - self.mem_decay, nesterov=False)
        # cumulative memory M_t over the sequence dimension
        mem = torch.cumsum(mem_slice, dim=2)
        # query the memory: y_t = q_t^T M_t
        output = torch.einsum("bhtd,bhtdv->bhtv", q, mem)
        return output
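The helper compute_mem_momentum above is defined in the repo; as a point of reference only, here is a minimal sketch of what it plausibly does, assuming a PyTorch-style recurrence $m_t = \beta\, m_{t-1} + g_t$ along the sequence dimension with an optional Nesterov lookahead (the repo's actual implementation may differ, e.g. it is likely parallelized rather than a Python loop):

import torch

def compute_mem_momentum(mem_slice, beta, nesterov=False, seq_dim=2):
    # mem_slice: per-timestep memory "gradients", shape (B, H, S, D, D)
    # Hypothetical reference implementation of the repo's helper.
    m = torch.zeros_like(mem_slice.select(seq_dim, 0))
    outs = []
    for g in mem_slice.unbind(dim=seq_dim):
        m = beta * m + g                              # momentum buffer update
        outs.append(g + beta * m if nesterov else m)  # Nesterov lookahead mix
    return torch.stack(outs, dim=seq_dim)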
Adam as Memory Updates
Adam. Adam keeps first and second moments of the gradient:
$$m_t = \beta_1\, m_{t-1} + (1-\beta_1)\, g_t, \qquad v_t = \beta_2\, v_{t-1} + (1-\beta_2)\, g_t^{2}, \qquad w_t = w_{t-1} - \eta\, \frac{m_t}{\sqrt{v_t} + \epsilon}.$$
(Bias correction, $\hat m_t = m_t/(1-\beta_1^t)$ and $\hat v_t = v_t/(1-\beta_2^t)$, may be omitted in practice for simplicity.)
Copying these steps into a memory update. Replace the gradient $g_t$ by the memory "signal" $G_t = q_k\,(k_t v_t^\top)$. Maintain moment estimates $(m_t, s_t)$ over $G_t$ (writing $s_t$ for the second moment to avoid clashing with the value vector $v_t$) and normalize the update slice elementwise, which makes the memory step scale aware and robust to feature magnitude:
$$m_t = \beta_1\, m_{t-1} + G_t, \qquad s_t = \beta_2\, s_{t-1} + G_t^{\odot 2}, \qquad \Delta M_t = (1-\lambda)\, \Delta M_{t-1} + \eta_{\text{mem}}\, \frac{m_t}{\sqrt{s_t} + \epsilon}, \qquad M_t = \sum_{u \le t} \Delta M_u.$$
As before, the running cumulative memory stores temporally integrated, per-dimension normalized associations.
Show Me the Code
class AdamFastWeightMemory(nn.Module):
    def __init__(self, beta1=0.9, beta2=0.99, mem_lr=1.0, mem_decay=0.1, nesterov=False):
        super().__init__()
        self.beta1 = beta1
        self.beta2 = beta2
        self.mem_lr = mem_lr
        self.mem_decay = mem_decay
        self.nesterov = nesterov

    def forward(self, qkv):
        """Fast-weight memory with Adam-style (elementwise-normalized) updates.

        Arguments
        ---------
            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
        """
        q, k, v = qkv.unbind(dim=2)                       # each (B, S, H, D)
        q, k, v = [x.transpose(1, 2) for x in (q, k, v)]  # -> (B, H, S, D)
        qk_scale = 1.0 / math.sqrt(q.shape[-1])
        # per-timestep "gradient" G_t = qk_scale * k_t v_t^T, shape (B, H, S, D, D)
        mem_slice = qk_scale * k.unsqueeze(-1) @ v.unsqueeze(-2)
        # first and second moments of the memory "gradient"
        mem_momentum = compute_mem_momentum(mem_slice, self.beta1, nesterov=self.nesterov)
        mem_momentum_sq = compute_mem_momentum(mem_slice**2, self.beta2, nesterov=self.nesterov)
        # elementwise-normalized update, as in Adam
        mem_slice = self.mem_lr * mem_momentum / (torch.sqrt(mem_momentum_sq) + 1e-8)
        # TODO: optionally add bias correction here
        # apply the exponential decay (weight-decay analogue) to the update
        mem_slice = compute_mem_momentum(mem_slice, 1 - self.mem_decay, nesterov=False)
        # cumulative memory M_t over the sequence dimension
        mem = torch.cumsum(mem_slice, dim=2)
        # query the memory: y_t = q_t^T M_t
        output = torch.einsum("bhtd,bhtdv->bhtv", q, mem)
        return output
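As a quick sanity check (hypothetical, assuming the helpers above are available), the modules can be exercised with random tensors in the (B, S, 3, H, D) layout from the docstring:

import torch

B, S, H, D = 2, 16, 4, 32
qkv = torch.randn(B, S, 3, H, D)
memory = AdamFastWeightMemory(beta1=0.9, beta2=0.99, mem_lr=1.0, mem_decay=0.1)
out = memory(qkv)
print(out.shape)  # expected: torch.Size([2, 4, 16, 32]), i.e. (B, H, S, D)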
Muon as Memory Updates
Muon (layerwise, scale-normalized steps via matrix preconditioning). Muon can be viewed as applying a matrix inverse-square-root style preconditioner to a momentum-smoothed update so that the effective step is invariant to isotropic rescalings of the hidden representation. Abstractly, given a momentum-like signal $u_t$, Muon forms a preconditioned update using a function $\Phi(\cdot)$ that behaves like $(\cdot)^{-1/2}$:
$$w_t = w_{t-1} - \eta\, \Phi(u_t u_t^\top)\, u_t, \qquad \Phi(A) \approx A^{-1/2}.$$
In practice, $\Phi$ is never formed explicitly: a few steps of the Newton–Schulz iteration approximate the combined map $u_t \mapsto (u_t u_t^\top)^{-1/2} u_t$, i.e. they replace $u_t$ by its nearest semi-orthogonal matrix $U V^\top$ (for the SVD $u_t = U \Sigma V^\top$). This makes steps "length-normalized" and better conditioned.
Copying these steps into a memory update. We again take the association signal $G_t = q_k\,(k_t v_t^\top)$, smooth it with momentum to obtain $U_t$, then apply a zeroth-power (inverse square-root–like) normalization via Newton–Schulz, denoted $\operatorname{NS}_5(\cdot)$ below, to construct the memory slice before decay and accumulation:
$$U_t = \mu\, U_{t-1} + G_t, \qquad \Delta M_t = (1-\lambda)\, \Delta M_{t-1} + \eta_{\text{mem}}\, \operatorname{NS}_5(U_t), \qquad M_t = \sum_{s \le t} \Delta M_s.$$
Intuitively, this makes the memory updates shape aware: high-variance directions are downscaled while weak directions are amplified, leading to a more uniform and stable memory.
Show Me the Code
class MuonFastWeightMemory(nn.Module):
    def __init__(self, momentum=1.0, mem_lr=1.0, mem_decay=0.1, nesterov=False):
        super().__init__()
        self.momentum = momentum
        self.mem_lr = mem_lr
        self.mem_decay = mem_decay
        self.nesterov = nesterov

    def forward(self, qkv):
        """Fast-weight memory with Muon-style (Newton-Schulz normalized) updates.

        Arguments
        ---------
            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
        """
        q, k, v = qkv.unbind(dim=2)                       # each (B, S, H, D)
        q, k, v = [x.transpose(1, 2) for x in (q, k, v)]  # -> (B, H, S, D)
        qk_scale = 1.0 / math.sqrt(q.shape[-1])
        # per-timestep "gradient" G_t = qk_scale * k_t v_t^T, shape (B, H, S, D, D)
        mem_slice = qk_scale * k.unsqueeze(-1) @ v.unsqueeze(-2)
        # compute the momentum, that is the first layer of memory
        mem_momentum = compute_mem_momentum(mem_slice, self.momentum, nesterov=self.nesterov)
        # Newton-Schulz "zeroth power" normalization of the smoothed update
        mem_slice = self.mem_lr * zeropower_via_newtonschulz5(mem_momentum, 5)
        # apply the exponential decay (weight-decay analogue) to the update
        mem_slice = compute_mem_momentum(mem_slice, 1 - self.mem_decay, nesterov=False)
        # cumulative memory M_t over the sequence dimension
        mem = torch.cumsum(mem_slice, dim=2)
        # query the memory: y_t = q_t^T M_t
        output = torch.einsum("bhtd,bhtdv->bhtv", q, mem)
        return output
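The zeropower_via_newtonschulz5 helper comes from the Muon reference implementation (Jordan et al., 2024); for completeness, here is a simplified sketch of that quintic Newton–Schulz iteration, applied batched over the last two (matrix) dimensions (omitting, e.g., the bfloat16 cast used in the reference code):

import torch

def zeropower_via_newtonschulz5(G, steps):
    # Quintic Newton-Schulz iteration driving G toward its orthogonal polar
    # factor U V^T (the "zeroth power"); coefficients follow the Muon post.
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G / (G.norm(dim=(-2, -1), keepdim=True) + 1e-7)  # normalize spectral scale
    transposed = X.size(-2) > X.size(-1)
    if transposed:
        X = X.mT
    for _ in range(steps):
        A = X @ X.mT
        B = b * A + c * A @ A
        X = a * X + B @ X
    return X.mT if transposed else X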
Multi-Query Associative Recall (MQAR)
TBD
Main Difference from Existing Works
Many existing works (Titans, Gated DeltaNet, Longhorn, DeltaFormer, etc.) are variants that solve the following regression problem with SGD: $$\mathcal{L}(M_{t-1}; x_t) = \|M_{t-1}(k_t) - v_t\|_2^2$$
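For concreteness, a single gradient step on this loss, for a linear memory map $M_{t-1}(k_t) = M_{t-1} k_t$ and with the factor of 2 absorbed into the learning rate, is exactly the familiar delta rule:
$$M_t = M_{t-1} - \eta\,(M_{t-1} k_t - v_t)\, k_t^\top.$$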
The algorithm they run is almost always gradient descent without momentum; one exception is the paper "Test-time training done right", where the authors apply Newton–Schulz to the gradient. This is clearly different from how linear layers are actually updated during pretraining, despite the duality. Momentum is a powerful concept that shows up everywhere, from neural network architectures (residual connections and, more recently, TRM) to optimizers. There are always two streams (or more): one fast and one slow. The Rule of Two is not just a Code of the Sith.
Limitations
The current implementations are simple, naive PyTorch. Even with torch.compile, they are unlikely to scale. However, with a bit of effort, we could make them much faster and more memory efficient by writing custom CUDA kernels.
References
- Arora, S., Eyuboglu, S., Timalsina, A., Johnson, I., Poli, M., Zou, J., ... & Ré, C. (2023). Zoology: Measuring and improving recall in efficient language models. arXiv preprint arXiv:2312.04927.
- Behrouz, A., Zhong, P., & Mirrokni, V. (2024). Titans: Learning to memorize at test time. arXiv preprint arXiv:2501.00663.
- Zhang, T., Bi, S., Hong, Y., Zhang, K., Luan, F., Yang, S., ... & Tan, H. (2025). Test-time training done right. arXiv preprint arXiv:2505.23884.
- Yang, S., Kautz, J., & Hatamizadeh, A. (2024). Gated delta networks: Improving mamba2 with delta rule. arXiv preprint arXiv:2412.06464.
- Liu, B., Wang, R., Wu, L., Feng, Y., Stone, P., & Liu, Q. (2024). Longhorn: State space models are amortized online learners. arXiv preprint arXiv:2407.14207.
- Jolicoeur-Martineau, A. (2025). Less is more: Recursive reasoning with tiny networks. arXiv preprint arXiv:2510.04871.
- Jordan, K., Jin, Y., Boza, V., Jiacheng, Y., Cesista, F., Newhouse, L., & Bernstein, J. (2024). Muon: An optimizer for hidden layers in neural networks. Retrieved from https://kellerjordan.github.io/posts/muon/
Citation
If you would like to cite this blog post, you can use the following BibTeX entry:
@misc{liang2025fastweights,
title = {Design Memory Mechanism from Optimizers},
author = {Liang, Kaizhao},
year = {2025},
url = {https://kyleliang919.github.io/MEM-OPTIM}
}