Design Memory Mechanism from Optimizers

From optimizer step to memory updates

Motivation

In my last blog post, we discussed how linear layers can be seen as fast-weight memory over the entire training dataset. Over the years, optimizers have made substantial progress, from momentum SGD to AdamW and, more recently, Muon. If this duality holds, why not copy optimizer steps directly as memory update mechanisms? In this blog post, we explore that possibility. All code and implementations can be found in this repo.

MSGD as Memory Updates

Momentum SGD (with optional Nesterov and weight decay). In standard momentum SGD we keep a velocity $v_t$ that accumulates an exponentially weighted sum of the gradients $g_t$:

$$\begin{aligned} v_t &= \mu\, v_{t-1} + g_t,\\ w_t &= w_{t-1} - \eta\, v_t. \end{aligned}$$

With decoupled weight decay (or simple $\ell_2$ decay) the parameters are additionally shrunk:

$$w_t \leftarrow (1-\eta\,\lambda)\, w_{t-1} - \eta\, v_t.$$

For Nesterov momentum, the gradient is evaluated at the lookahead point $w_{t-1}-\eta\mu\, v_{t-1}$ instead of at $w_{t-1}$.
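To make the three ingredients concrete, here is a minimal scalar sketch in plain Python of momentum SGD with decoupled weight decay and a Nesterov lookahead. The helper names `msgd_step` and `run` and the toy objective $f(w) = \tfrac{1}{2}(w-3)^2$ are hypothetical, chosen only for illustration. Because the step scales the velocity by the learning rate, the lookahead point is taken as `w - lr * mu * v`.

```python
def msgd_step(w, v, grad_fn, lr=0.1, mu=0.9, wd=0.0, nesterov=False):
    """One momentum-SGD step on a scalar parameter."""
    # Nesterov: evaluate the gradient at the lookahead point w - lr * mu * v
    g = grad_fn(w - lr * mu * v) if nesterov else grad_fn(w)
    v = mu * v + g                      # v_t = mu * v_{t-1} + g_t
    w = (1 - lr * wd) * w - lr * v      # w_t = (1 - lr * wd) * w_{t-1} - lr * v_t
    return w, v


def run(steps=500, **kwargs):
    # toy objective f(w) = 0.5 * (w - 3)^2, whose gradient is w - 3
    w, v = 0.0, 0.0
    for _ in range(steps):
        w, v = msgd_step(w, v, lambda x: x - 3.0, **kwargs)
    return w


plain = run()                    # heavy-ball momentum
nesterov = run(nesterov=True)    # Nesterov lookahead
decayed = run(wd=0.1)            # decoupled decay pulls the solution toward 0
print(plain, nesterov, decayed)
```

Both momentum variants converge to the minimizer 3.0. With decay, the fixed point solves $0 = -\eta\lambda w - \eta (w-3)/(1-\mu)$, giving $w^\star = 3/(1+\lambda(1-\mu)) = 3/1.01$. This shrinkage toward zero is the behavior the memory analogue borrows.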

Copying these steps into a memory update. Treat the per-timestep outer product $G_t = q_k\, (k_t v_t^\top)$ as the analogue of a "gradient" for the memory matrix, where $q_k = 1/\sqrt{d}$ is the usual attention scale and, from here on, $v_t$ denotes the value vector rather than the velocity. We then filter $G_t$ with the same momentum recurrence to get a memory velocity $\tilde{G}_t$ and apply an exponential decay to the update:

$$\begin{aligned} \tilde{G}_t &= \mu\, \tilde{G}_{t-1} + G_t,\\ \Delta M_t &= (1-\gamma)\, \Delta M_{t-1} + \alpha\, \tilde{G}_t,\\ M_t &= M_{t-1} + \Delta M_t. \end{aligned}$$

Operationally, this yields a momentum-smoothed sequence of memory slices whose cumulative sum forms $M_t$; querying uses the usual $y_t = q_t^\top M_t$ contraction.
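As a sanity check of the query step, consider only the plain accumulation $M_t = \sum_{s\le t} G_s$, with no momentum filtering or decay. Querying that memory reproduces causal linear attention without softmax. The small PyTorch check below assumes the (B, H, T, D) layout implied by the `"bhtd,bhtdv->bhtv"` einsum in the classes that follow.

```python
import math
import torch

torch.manual_seed(0)
B, H, T, D = 2, 3, 5, 4
q, k, v = (torch.randn(B, H, T, D, dtype=torch.float64) for _ in range(3))
scale = 1.0 / math.sqrt(D)  # q_k = 1 / sqrt(d)

# Memory view: G_t = scale * k_t v_t^T, M_t = sum_{s <= t} G_s, y_t = q_t^T M_t
G = scale * k.unsqueeze(-1) @ v.unsqueeze(-2)     # (B, H, T, D, D)
M = torch.cumsum(G, dim=2)
y_memory = torch.einsum("bhtd,bhtdv->bhtv", q, M)

# Attention view: y_t = sum_{s <= t} scale * (q_t . k_s) v_s, no softmax
scores = scale * q @ k.transpose(-2, -1)           # (B, H, T, T)
causal = torch.tril(torch.ones(T, T, dtype=torch.bool))
y_attention = scores.masked_fill(~causal, 0.0) @ v

print(torch.allclose(y_memory, y_attention))  # True
```

The memory view materializes a $D\times D$ matrix per token, while the attention view materializes a $T\times T$ score matrix. The optimizer-inspired variants below change only how the slices $G_t$ are filtered before accumulation.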

Show Me the Code
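import math
import torch
import torch.nn as nn

# compute_mem_momentum is defined in the repo (not shown here); it is assumed to run
# the recurrence m_t = beta * m_{t-1} + x_t along the time axis (dim=2).

class MSGDFastWeightMemory(nn.Module):
    def __init__(self, momentum=1.0, mem_lr=1.0, mem_decay=0.1, nesterov=False):
        super().__init__()
        self.momentum = momentum
        self.mem_lr = mem_lr
        self.mem_decay = mem_decay
        self.nesterov = nesterov

    def forward(self, qkv):
        # qkv: (B, S, 3, H, D); move heads before time so that dim=2 is the time axis
        q, k, v = (t.transpose(1, 2) for t in qkv.unbind(dim=2))
        qk_scale = 1.0 / math.sqrt(q.shape[-1])
        # per-token "gradient" G_t = qk_scale * k_t v_t^T, shape (B, H, S, D, D)
        mem_slice = qk_scale * k.unsqueeze(-1) @ v.unsqueeze(-2)
        # compute the momentum, that is the first layer of memory
        mem_momentum = compute_mem_momentum(mem_slice, self.momentum, nesterov=self.nesterov)
        mem_slice = self.mem_lr * mem_momentum
        # compute the weight decay (an exponential decay over the slices)
        mem_slice = compute_mem_momentum(mem_slice, 1 - self.mem_decay, nesterov=False)
        # accumulate the slices into M_t and read out y_t = q_t^T M_t
        mem = torch.cumsum(mem_slice, dim=2)
        output = torch.einsum("bhtd,bhtdv->bhtv", q, mem)
        return output.transpose(1, 2)  # back to (B, S, H, D)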
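The classes in this post call `compute_mem_momentum`, which lives in the repo and is not shown here. Below is a hypothetical, loop-based stand-in named `ema_scan`. It assumes the helper runs the unnormalized recurrence $m_t = \beta\, m_{t-1} + x_t$ along the time axis, matching the $\tilde{G}_t$ update above. The Nesterov branch is a guess modeled on the form PyTorch's SGD uses, not the repo's code.

```python
import torch

def ema_scan(x, beta, nesterov=False, dim=2):
    """Hypothetical stand-in for compute_mem_momentum (slow reference).

    Runs m_t = beta * m_{t-1} + x_t along `dim` and returns every m_t.
    With nesterov=True it returns x_t + beta * m_t instead (assumption).
    """
    x = x.movedim(dim, 0)              # put time first so we can loop over it
    m = torch.zeros_like(x[0])
    outs = []
    for x_t in x:
        m = beta * m + x_t
        outs.append(x_t + beta * m if nesterov else m)
    return torch.stack(outs).movedim(0, dim)

x = torch.randn(2, 3, 6, 4, 4, dtype=torch.float64)   # (B, H, T, D, D) slices
print(torch.allclose(ema_scan(x, 1.0), torch.cumsum(x, dim=2)))  # True
print(torch.allclose(ema_scan(x, 0.0), x))                        # True
```

With $\beta = 1$ the scan is a plain cumulative sum, and with $\beta = 0$ it passes the slices through unchanged. This is why the default `momentum=1.0` in `MSGDFastWeightMemory` means "no forgetting" in the first layer.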

Adam as Memory Updates

Adam. Adam keeps first and second moments of the gradient:

$$\begin{aligned} m_t &= \beta_1 m_{t-1} + (1-\beta_1) g_t,\\ v_t &= \beta_2 v_{t-1} + (1-\beta_2) g_t^{\odot 2},\\ \hat m_t &= m_t/(1-\beta_1^t),\quad \hat v_t = v_t/(1-\beta_2^t),\\ w_t &= w_{t-1} - \eta\, \hat m_t/(\sqrt{\hat v_t}+\varepsilon). \end{aligned}$$

(The memory variant below omits bias correction for simplicity.)
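A minimal scalar sketch of the Adam step above with bias correction, in plain Python. The name `adam_step` is hypothetical. It shows why Adam is often described as scale-invariant: the first bias-corrected update is approximately $\eta\,\operatorname{sign}(g)$ regardless of the gradient's magnitude.

```python
import math

def adam_step(w, m, v, g, t, lr=0.01, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam step on a scalar parameter (t counts from 1)."""
    m = beta1 * m + (1 - beta1) * g           # first moment
    v = beta2 * v + (1 - beta2) * g * g       # second moment
    m_hat = m / (1 - beta1 ** t)              # bias corrections
    v_hat = v / (1 - beta2 ** t)
    w = w - lr * m_hat / (math.sqrt(v_hat) + eps)
    return w, m, v

# The first update has magnitude ~lr whatever the gradient scale.
steps = [adam_step(0.0, 0.0, 0.0, g, t=1)[0] for g in (1e-3, 1.0, 1e3)]
print(steps)  # each close to -0.01
```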

Copying these steps into a memory update. Replace the gradient $g_t$ by the memory "signal" $G_t = q_k\,(k_t v_t^\top)$. Maintain moment estimates $(m_t, s_t)$ over $G_t$ and normalize the update slice elementwise. Up to $\varepsilon$, the ratio $m_t/\sqrt{s_t}$ is unchanged when $G_t$ is rescaled by a constant, which makes the memory step scale-invariant and robust to feature magnitude:

$$\begin{aligned} m_t &= \beta_1 m_{t-1} + (1-\beta_1) G_t,\\ s_t &= \beta_2 s_{t-1} + (1-\beta_2) G_t^{\odot 2},\\ \Delta M_t &= (1-\gamma)\, \Delta M_{t-1} + \alpha\, m_t/(\sqrt{s_t}+\varepsilon),\\ M_t &= M_{t-1}+\Delta M_t. \end{aligned}$$

As before, the running cumulative memory stores temporally integrated, per-dimension normalized associations.
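The "robust to feature magnitude" claim can be checked in isolation. The sketch below uses a hypothetical helper `adam_memory_slices` with a plain loop over time. It computes the normalized slices $\alpha\, m_t/(\sqrt{s_t}+\varepsilon)$ without bias correction, decay, or accumulation, and checks that rescaling $G_t$ leaves them unchanged. Here $\varepsilon$ is set to 0 so the invariance is exact. With $\varepsilon > 0$ it holds only while $\sqrt{s_t} \gg \varepsilon$.

```python
import torch

def adam_memory_slices(G, beta1=0.9, beta2=0.99, lr=1.0, eps=1e-8):
    """Normalized slices lr * m_t / (sqrt(s_t) + eps) for G of shape (T, D, D)."""
    m = torch.zeros_like(G[0])
    s = torch.zeros_like(G[0])
    out = []
    for G_t in G:                              # loop over time
        m = beta1 * m + (1 - beta1) * G_t      # first moment of the signal
        s = beta2 * s + (1 - beta2) * G_t ** 2 # elementwise second moment
        out.append(lr * m / (s.sqrt() + eps))
    return torch.stack(out)

torch.manual_seed(0)
T, D = 8, 4
k = torch.randn(T, D, dtype=torch.float64)
v = torch.randn(T, D, dtype=torch.float64)
G = k.unsqueeze(-1) * v.unsqueeze(-2) / D ** 0.5   # G_t = q_k k_t v_t^T

base = adam_memory_slices(G, eps=0.0)
scaled = adam_memory_slices(100.0 * G, eps=0.0)    # e.g. keys 100x larger
print(torch.allclose(base, scaled))  # True: slices ignore the overall scale
```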

Show Me the Code
import math
import torch
import torch.nn as nn

# compute_mem_momentum is defined in the repo (not shown here); it is assumed to run
# the recurrence m_t = beta * m_{t-1} + x_t along the time axis (dim=2).

class AdamFastWeightMemory(nn.Module):
    def __init__(self, beta1=0.9, beta2=0.99, mem_lr=1.0, mem_decay=0.1, nesterov=False):
        super().__init__()
        self.beta1 = beta1
        self.beta2 = beta2
        self.mem_lr = mem_lr
        self.mem_decay = mem_decay
        self.nesterov = nesterov

    def forward(self, qkv):
        """Fast-weight memory with Adam-style, elementwise-normalized updates.
        Arguments
        ---------
            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
        Returns
        -------
            The memory read-out y_t = q_t^T M_t. (B, S, H, D)
        """
        # move heads before time so that dim=2 is the time axis
        q, k, v = (t.transpose(1, 2) for t in qkv.unbind(dim=2))
        qk_scale = 1.0 / math.sqrt(q.shape[-1])
        mem_slice = qk_scale * k.unsqueeze(-1) @ v.unsqueeze(-2)
        # first and second moments of G_t, (1 - beta)-weighted as in the equations
        mem_momentum = compute_mem_momentum((1 - self.beta1) * mem_slice, self.beta1, nesterov=self.nesterov)
        mem_momentum_sq = compute_mem_momentum((1 - self.beta2) * mem_slice**2, self.beta2, nesterov=self.nesterov)
        mem_slice = self.mem_lr * mem_momentum / (torch.sqrt(mem_momentum_sq) + 1e-8)
        # maybe add bias corrections here too later
        # compute the weight decay (an exponential decay over the slices)
        mem_slice = compute_mem_momentum(mem_slice, 1 - self.mem_decay, nesterov=False)
        mem = torch.cumsum(mem_slice, dim=2)
        output = torch.einsum("bhtd,bhtdv->bhtv", q, mem)
        return output.transpose(1, 2)  # back to (B, S, H, D)

Muon as Memory Updates

Muon (MomentUm Orthogonalized by Newton–Schulz, applied per weight matrix). Muon can be viewed as applying a matrix inverse-square-root style preconditioner to a momentum-smoothed update. The update keeps the singular vectors of the momentum but sets its singular values to approximately one, so the step no longer depends on the overall scale of the gradient. Abstractly, given a momentum signal $u_t$, Muon forms a preconditioned update using a function $\Phi(\cdot)$ that behaves like $u \mapsto u\,(u^\top u)^{-1/2}$:

$$\begin{aligned} u_t &= \mu\, u_{t-1} + g_t,\\ \Delta w_t &= \alpha\, \Phi(u_t) \;\approx\; \alpha\, u_t\,(u_t^\top u_t)^{-1/2}. \end{aligned}$$

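In practice, $\Phi$ is implemented with a few steps of a Newton–Schulz iteration, which approximates the orthogonal polar factor $U V^\top$ of $u_t = U\Sigma V^\top$ (equivalently $u_t(u_t^\top u_t)^{-1/2}$) without an explicit SVD or matrix inverse. Pushing every singular value toward one makes steps "length-normalized" and better conditioned.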
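The sketch below illustrates $\Phi$ with the classic cubic Newton–Schulz iteration $X \leftarrow 1.5X - 0.5\,XX^\top X$, not Muon's tuned quintic. The coefficients in `zeropower_via_newtonschulz5` are chosen for speed and leave singular values only approximately equal to one. The function name `newton_schulz_orthogonalize` is hypothetical. The test matrix is built with known singular vectors so the exact polar factor $Q_1 Q_2^\top$ is available for comparison.

```python
import torch

def newton_schulz_orthogonalize(G, steps=30):
    """Approximate the orthogonal polar factor U V^T of G = U S V^T.

    Each step maps every singular value s to 1.5 s - 0.5 s^3, which converges
    to 1 for s in (0, sqrt(3)); zero singular values stay zero.
    """
    # Frobenius normalization puts all singular values in [0, 1]
    X = G / G.norm(dim=(-2, -1), keepdim=True)
    for _ in range(steps):
        X = 1.5 * X - 0.5 * X @ X.transpose(-2, -1) @ X
    return X

torch.manual_seed(0)
Q1, _ = torch.linalg.qr(torch.randn(6, 4, dtype=torch.float64))
Q2, _ = torch.linalg.qr(torch.randn(4, 4, dtype=torch.float64))
S = torch.diag(torch.tensor([3.0, 1.0, 0.5, 0.2], dtype=torch.float64))
G = Q1 @ S @ Q2.T                      # SVD known by construction
X = newton_schulz_orthogonalize(G)
print(torch.allclose(X, Q1 @ Q2.T, atol=1e-8))  # True
```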

Copying these steps into a memory update. We again take the association signal $G_t = q_k\,(k_t v_t^\top)$, smooth it with momentum to obtain $U_t$, and then apply a zeroth-power (inverse-square-root-like) normalization via Newton–Schulz, denoted $\operatorname{NS}_5(\cdot)$ below, to construct the memory slice before decay and accumulation:

$$\begin{aligned} U_t &= \mu\, U_{t-1} + G_t,\\ \Delta M_t &= (1-\gamma)\, \Delta M_{t-1} + \alpha\, \operatorname{NS}_5(U_t),\\ M_t &= M_{t-1}+\Delta M_t. \end{aligned}$$

Intuitively, this makes the memory updates shape-aware. Orthogonalization equalizes the singular values of $U_t$, so dominant directions are scaled down and weak directions are scaled up, which should give a more uniform and stable memory.
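Here is a toy illustration of that intuition. It uses an exact SVD-based orthogonalization in place of Newton–Schulz, and the helper `polar_factor` is hypothetical. A momentum signal with one strong association (singular value 10) and one weak one (0.1) becomes a matrix in which both associations carry equal weight. Null directions are dropped because Newton–Schulz also maps zero singular values to zero.

```python
import torch

def polar_factor(U, tol=1e-12):
    """Exact orthogonalization via SVD, keeping only nonzero singular directions."""
    Us, S, Vh = torch.linalg.svd(U)
    r = int((S > tol * S.max()).sum())
    return Us[:, :r] @ Vh[:r, :]

D = 4
e = torch.eye(D, dtype=torch.float64)
# momentum-accumulated signal: one strong and one weak key-value association
U = 10.0 * torch.outer(e[0], e[1]) + 0.1 * torch.outer(e[2], e[3])
P = polar_factor(U)
print(torch.linalg.svdvals(U))   # 10 and 0.1, plus zeros
print(torch.linalg.svdvals(P))   # 1 and 1, plus zeros
```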

Show Me the Code
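import math
import torch
import torch.nn as nn

# compute_mem_momentum is defined in the repo (not shown here); it is assumed to run
# the recurrence m_t = beta * m_{t-1} + x_t along the time axis (dim=2).
# zeropower_via_newtonschulz5 is the Newton-Schulz routine from the Muon reference
# implementation (Jordan et al., 2024); it must accept batched (..., D, D) inputs.

class MuonFastWeightMemory(nn.Module):
    def __init__(self, momentum=1.0, mem_lr=1.0, mem_decay=0.1, nesterov=False):
        super().__init__()
        self.momentum = momentum
        self.mem_lr = mem_lr
        self.mem_decay = mem_decay
        self.nesterov = nesterov

    def forward(self, qkv):
        # qkv: (B, S, 3, H, D); move heads before time so that dim=2 is the time axis
        q, k, v = (t.transpose(1, 2) for t in qkv.unbind(dim=2))
        qk_scale = 1.0 / math.sqrt(q.shape[-1])
        mem_slice = qk_scale * k.unsqueeze(-1) @ v.unsqueeze(-2)
        # compute the momentum, that is the first layer of memory
        mem_momentum = compute_mem_momentum(mem_slice, self.momentum, nesterov=self.nesterov)
        # orthogonalize each (D, D) momentum matrix with 5 Newton-Schulz steps;
        # the reference routine may return bfloat16, so cast back before mixing dtypes
        ortho = zeropower_via_newtonschulz5(mem_momentum, 5).to(mem_momentum.dtype)
        mem_slice = self.mem_lr * ortho
        # compute the weight decay (an exponential decay over the slices)
        mem_slice = compute_mem_momentum(mem_slice, 1 - self.mem_decay, nesterov=False)
        mem = torch.cumsum(mem_slice, dim=2)
        output = torch.einsum("bhtd,bhtdv->bhtv", q, mem)
        return output.transpose(1, 2)  # back to (B, S, H, D)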

Multi-Query Associative Recall (MQAR)

TBD

Main Difference from Existing Works

Many existing works (Titans, Gated DeltaNet, Longhorn, DeltaFormer, etc.) can be viewed as variants that solve the following online regression problem with SGD:

$$\mathcal{L}(M_{t-1}; x_t) = \|M_{t-1}(k_t) - v_t\|_2^2$$
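Taking one gradient step on this loss recovers the familiar delta rule, $M_t = M_{t-1} - \beta\,(M_{t-1}k_t - v_t)\,k_t^\top$. The quick PyTorch check below assumes the convention $M(k) = Mk$ with $M \in \mathbb{R}^{d_v\times d_k}$, which is the transpose of the $k v^\top$ layout used in the code above. It also halves the loss so the factor 2 folds into the step size $\beta$.

```python
import torch

torch.manual_seed(0)
d_k, d_v = 4, 3
M = torch.randn(d_v, d_k, dtype=torch.float64, requires_grad=True)  # M(k) = M k
k = torch.nn.functional.normalize(torch.randn(d_k, dtype=torch.float64), dim=0)
v = torch.randn(d_v, dtype=torch.float64)

# autograd gradient of the halved regression loss 0.5 * ||M k - v||^2
loss = 0.5 * (M @ k - v).pow(2).sum()
loss.backward()

# closed form: dL/dM = (M k - v) k^T, so one SGD step is the delta rule
beta = 1.0
with torch.no_grad():
    grad_closed = torch.outer(M @ k - v, k)
    M_new = M - beta * grad_closed

print(torch.allclose(M.grad, grad_closed))  # True
print(torch.allclose(M_new @ k, v))         # True: ||k|| = 1, beta = 1 overwrites v
```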

In most of these works, the update is plain gradient descent without momentum. Exceptions include Titans, which adds a momentum ("past surprise") term and a forgetting gate, and "Test-time training done right", whose authors apply Newton–Schulz orthogonalization to the gradient. Plain gradient descent differs from how linear layers are actually updated during pretraining, despite the duality.
Momentum is a powerful concept that shows up everywhere, from neural network architectures (residual connections and, more recently, TRM) to optimizers. There are always two streams (or more): one fast and one slow. The Rule of Two is not just a Code of the Sith.

Limitations

The current implementations are written in simple, naive PyTorch and materialize a $D \times D$ memory slice for every token, so activation memory grows as $O(T\,D^2)$ per head. Even with torch.compile, they are unlikely to scale. With some effort, custom CUDA kernels could make them much faster and more memory efficient.
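The classes above materialize a $(B, H, T, D, D)$ tensor of slices. As an illustration of what a fused kernel could compute instead, here is a hypothetical token-by-token formulation of the MSGD memory that keeps only three $D\times D$ states per head. It assumes `compute_mem_momentum` is the plain recurrence $m_t = \beta\, m_{t-1} + x_t$. It is checked against a materialized version written in the post's style. The Python loop is slow too, so the point is the memory footprint, not speed.

```python
import math
import torch

def msgd_memory_recurrent(q, k, v, momentum=0.9, mem_lr=1.0, mem_decay=0.1):
    """Token-by-token MSGD memory for q, k, v of shape (B, H, T, D)."""
    B, H, T, D = q.shape
    scale = 1.0 / math.sqrt(D)
    g_mom = q.new_zeros(B, H, D, D)    # momentum-filtered signal (first recurrence)
    slice_ = q.new_zeros(B, H, D, D)   # decayed memory slice (second recurrence)
    mem = q.new_zeros(B, H, D, D)      # cumulative memory M_t
    outs = []
    for t in range(T):
        G_t = scale * k[:, :, t, :, None] * v[:, :, t, None, :]   # outer product
        g_mom = momentum * g_mom + G_t
        slice_ = (1 - mem_decay) * slice_ + mem_lr * g_mom
        mem = mem + slice_
        outs.append(torch.einsum("bhd,bhdv->bhv", q[:, :, t], mem))
    return torch.stack(outs, dim=2)    # (B, H, T, D)


def msgd_memory_materialized(q, k, v, momentum=0.9, mem_lr=1.0, mem_decay=0.1):
    """Same computation in the post's style: every per-token slice is stored."""
    scale = 1.0 / math.sqrt(q.shape[-1])
    G = scale * k.unsqueeze(-1) @ v.unsqueeze(-2)   # (B, H, T, D, D)

    def scan(x, beta):                              # m_t = beta * m_{t-1} + x_t
        out, m = [], torch.zeros_like(x[:, :, 0])
        for t in range(x.shape[2]):
            m = beta * m + x[:, :, t]
            out.append(m)
        return torch.stack(out, dim=2)

    slices = scan(mem_lr * scan(G, momentum), 1 - mem_decay)
    mem = torch.cumsum(slices, dim=2)
    return torch.einsum("bhtd,bhtdv->bhtv", q, mem)


torch.manual_seed(0)
q, k, v = (torch.randn(2, 3, 7, 4, dtype=torch.float64) for _ in range(3))
same = torch.allclose(msgd_memory_recurrent(q, k, v), msgd_memory_materialized(q, k, v))
print(same)  # True
```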

References

  1. Arora, S., Eyuboglu, S., Timalsina, A., Johnson, I., Poli, M., Zou, J., ... & Ré, C. (2023). Zoology: Measuring and improving recall in efficient language models. arXiv preprint arXiv:2312.04927.
  2. Behrouz, A., Zhong, P., & Mirrokni, V. (2024). Titans: Learning to memorize at test time. arXiv preprint arXiv:2501.00663.
  3. Zhang, T., Bi, S., Hong, Y., Zhang, K., Luan, F., Yang, S., ... & Tan, H. (2025). Test-time training done right. arXiv preprint arXiv:2505.23884.
  4. Yang, S., Kautz, J., & Hatamizadeh, A. (2024). Gated delta networks: Improving mamba2 with delta rule. arXiv preprint arXiv:2412.06464.
  5. Liu, B., Wang, R., Wu, L., Feng, Y., Stone, P., & Liu, Q. (2024). Longhorn: State space models are amortized online learners. arXiv preprint arXiv:2407.14207.
  6. Jolicoeur-Martineau, A. (2025). Less is More: Recursive Reasoning with Tiny Networks. arXiv preprint arXiv:2510.04871.
  7. Jordan, K., Jin, Y., Boza, V., You, J., Cesista, F., Newhouse, L., & Bernstein, J. (2024). Muon: An optimizer for hidden layers in neural networks. Retrieved from https://kellerjordan.github.io/posts/muon/

Citation

If you would like to cite this blog post, you can use the following BibTeX entry:

@misc{liang2025fastweights,
  title = {Design Memory Mechanism from Optimizers},
  author = {Liang, Kaizhao},
  year = {2025},
  url = {https://kyleliang919.github.io/MEM-OPTIM}
}