4.3. Variants and discussions of modern optimizers#
In this section we discuss some modern optimizers that have strong potential to replace Adam/AdamW as the default optimizer in the future.
4.3.1. Optimizer design as steepest descent with different norms#
Recall that SGD can be formulated as minimizing a local quadratic approximation of the loss:
It is easy to verify that this is exactly equivalent to \(\vw^{t+1}\leftarrow \vw^t - \eta_t \nabla\ell(\vw^t,\xi_t)\). Now, what happens if we use different norms on the right-hand side of (4.13)?
As we discussed in the previous section, SGD suffers from the imbalance in magnitude across different dimensions/entries, and Adam is essentially derived from sign SGD. In fact, one just needs to change the vector \(\ell_2\) norm in (4.13) into the \(\ell_\infty\) norm to recover sign SGD:
Now let us specialize the variables to matrices, since attention-based neural networks consist mostly of matrix-valued parameters. Denote the weight matrix by \(\vW\); then the update in (4.13) can be re-written as
where the inner product is the trace (Frobenius) inner product \(\langle A,B\rangle := \mathrm{tr}(A^\top B)\) and \(\|\cdot\|_{F}\) is the Frobenius norm of a matrix (\(\|A\|_{F}^2:=\sum_{i,j}A_{i,j}^2\)). It is also easy to verify that this is exactly equivalent to \(\vW^{t+1}\leftarrow \vW^t - \eta_t \nabla\ell(\vW^t,\xi_t)\).
Now you might ask what would happen if we have multiple layers of weight matrices rather than just a single matrix. Suppose we have \(L\) layers in total with weight matrices \(\vW_{1}, \vW_{2},\dots, \vW_{L}\). Then (4.15) generalizes to (for all \(l=1,2,\dots,L\))
Now we do a quick recap of possible choices of matrix norms.
Definition 4.1
For a matrix \(A\in\RR^{m\times n}\) and two norms \(\|\cdot\|_{\alpha}\) on \(\RR^{n}\) and \(\|\cdot\|_{\beta}\) on \(\RR^{m}\), the induced \(\alpha\)-to-\(\beta\) norm is given by (where \(\alpha\) and \(\beta\) can be any values in \([1,\infty]\))
The matrix induced norm allows us to do the same thing as in (4.14), i.e., plug in different norms and derive different forms of descent methods, but now for attention-based LLMs (which consist mostly of matrix-valued attention weights). The question is: how do we determine which norm is good for LLM training?
First, let us see how we can recover sign SGD, which is the fundamental building block of Adam. It turns out that the induced \(1\)-to-\(\infty\) norm is the correct choice to recover sign SGD for matrix variables, namely we have
where \(\dagger\) represents the dual norm.
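It is worth recording the general closed form behind these calculations once (a standard fact about steepest descent under a generic norm, stated in the notation of (4.13), with \(\|\cdot\|_{\dagger}\) denoting the dual norm as above): replacing the \(\ell_2\) norm in (4.13) by a generic norm \(\|\cdot\|\), the minimizer satisfies

\[
\vw^{t+1}-\vw^{t} \;=\; -\,\eta_t\,\big\|\nabla\ell(\vw^t,\xi_t)\big\|_{\dagger}\, d^{\star},
\qquad \text{where } d^{\star} \text{ maximizes } \langle \nabla\ell(\vw^t,\xi_t),\, d\rangle \text{ over } \|d\|\le 1 .
\]

With the \(\ell_2\) norm this recovers the plain SGD step; with the \(\ell_\infty\) (equivalently, \(1\)-to-\(\infty\) induced) norm, \(d^{\star}\) is the sign pattern of the gradient and the dual norm is the entrywise \(\ell_1\) norm; and, as discussed next, with the spectral norm \(d^{\star}=UV^\top\) and the dual norm is the nuclear norm.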
Next, the question is: what if we just use the simplest matrix induced norm, namely the \(\ell_2\rightarrow\ell_2\) norm? Note that the \(\ell_2\rightarrow\ell_2\) norm is also called the spectral norm, since its value can be shown to equal the largest singular value of the matrix. It is one of the most commonly used matrix norms, largely because it is easy to compute (a singular value decomposition or power iteration suffices).
It turns out that steepest descent under the spectral norm results in an orthogonalization of the gradient matrix:
where \(U\Sigma V^\top = \nabla\ell(\vW^t,\xi_t)\) is the singular value decomposition, and \(*\) stands for the dual norm of the spectral norm, namely the nuclear norm (the sum of all singular values, or \(\mathrm{tr}(\Sigma)\)).
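These closed forms are easy to check numerically. Below is a minimal PyTorch sketch (the function name `steepest_descent_direction` is mine) that returns the steepest-descent step directions under the three norms discussed so far, each scaled by the corresponding dual norm as in the formulas above; in practice the scalar factor is often folded into the learning rate.

```python
import torch

def steepest_descent_direction(G):
    """Steepest-descent directions of <G, D> + ||D||^2 / (2 * lr) under
    three matrix norms. Each direction equals -||G||_dual times the
    maximizer of <G, D> over the unit ball; the caller multiplies by lr."""
    # Frobenius norm: self-dual, recovers the plain (stochastic) gradient step.
    d_fro = -G
    # Entrywise max norm (the 1-to-infinity induced norm): the dual is the
    # entrywise l1 norm, so we recover sign SGD up to a scalar.
    d_sign = -G.abs().sum() * torch.sign(G)
    # Spectral norm (2-to-2 induced): the dual is the nuclear norm,
    # so we recover the orthogonalized update U V^T.
    U, S, Vh = torch.linalg.svd(G, full_matrices=False)
    d_spec = -S.sum() * (U @ Vh)
    return d_fro, d_sign, d_spec
```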
This motivates us to develop a matrix version of Adam by simply applying an EMA in (4.19). The resulting algorithm appears to be one of the strongest candidates to replace Adam for modern LLM training, namely the Muon algorithm [Jordan et al., 2024].
Algorithm 4.6 (Muon for matrix variable optimization)
Input: initial point \(\vW_{l}^0\), stepsize sequence \(\{ \eta_{t} \}\), max. no. of iterations \(T\).
For \(t=0,1,2,\dots, T-1\),
note that here the SVD is applied to the momentum \(\vM_{l}^{t}\) rather than to the raw gradient.
Output: last iterate \(\vW_{l}^T\) (for each layer \(l\)), or the solution sequence (checkpoints) \(\{ \vW_{l}^t \}_{t=1}^T\).
Muon performs surprisingly well for pre-training of LLMs, with faster convergence. For details see this blog post [Jordan et al., 2024] or this technical report [Liu et al., 2025].
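As a concrete reference, here is a minimal PyTorch sketch of the per-matrix Muon update in the spirit of Algorithm 4.6 (the function and variable names are mine; the exact momentum convention and scaling differ across implementations, and production code replaces the exact SVD with a few Newton-Schulz iterations):

```python
import torch

def muon_update(W, G, M, lr=0.02, beta=0.95):
    """One Muon-style step for a single weight matrix W with stochastic
    gradient G and momentum buffer M (all of shape (m, n))."""
    # EMA of the stochastic gradient (the momentum buffer).
    M = beta * M + (1 - beta) * G
    # Orthogonalize the momentum: replace M by U V^T from its SVD.
    U, _, Vh = torch.linalg.svd(M, full_matrices=False)
    W = W - lr * (U @ Vh)
    return W, M
```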
4.3.2. Approximate Hessian point of view for Shampoo and Muon#
Yet another line of work views the choice of optimizer in a different way, via approximating second-order information. Below, we first discuss how Adam can be derived from this perspective, and then show how this Hessian-approximation viewpoint leads to another algorithm, called Shampoo [Anil et al., 2020, Gupta et al., 2018, Shi et al., 2023].
Consider the situation where we are solving a binary classification problem with the cross-entropy loss (i.e., logistic regression; in fact, the following derivation also works for multi-class classification such as LLM pre-training):
where \(\mathcal{D}\) is the data distribution. We know that convergence would benefit if we could access Hessian information. The Hessian of (4.20) is
Note that (this can be verified simply from properties of the \(\log\) function)
therefore we get
Note that the first term \(I(\vw):=\EE_{(x,y)\sim\mathcal{D}} [\nabla\log\pi(y|x;\vw) (\nabla\log\pi(y|x;\vw))^\top]\) is also called the Fisher information matrix. Moreover, if the data distribution is identical to the model output distribution, i.e. \((x,y)\sim\mathcal{D}\) is the same distribution as \((x,y)\sim\mu\times\pi(\cdot|\cdot;\vw)\), then the second term vanishes (left as homework).
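For reference, the identity being invoked here is the standard chain-rule computation (writing \(\pi\) as shorthand for \(\pi(y|x;\vw)\)):

\[
\nabla^2\big(-\log \pi\big)
\;=\; -\,\frac{\nabla^2 \pi}{\pi} \;+\; \frac{\nabla \pi\,(\nabla \pi)^\top}{\pi^{2}}
\;=\; \nabla\log\pi\,\big(\nabla\log\pi\big)^\top \;-\; \frac{\nabla^2 \pi}{\pi},
\]

and taking expectations over \((x,y)\sim\mathcal{D}\) yields the Fisher information term plus the second term discussed above.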
Therefore we can legitimately approximate the second-order information via the Fisher information matrix \(I(\vw)\). We can even apply Newton's method approximately, as follows!
Algorithm 4.7 (Prototype of a second order method via Fisher information)
Input: initial point \(\vw^0\), stepsize sequence \(\{ \gamma_t \}_{t \geq 0}\), max. no. of iterations \(T\).
For \(t=0,1,2,\dots, T-1\),
Output: last iterate \(\vw^T\), or the solution sequence \(\{ \vw^t \}_{t=1}^T\).
Note that we use \(I(\vw)^{-1/2}\) instead of the inverse \(I(\vw)^{-1}\), which is common in practice. This operation can be thought of as a “whitening” step (each eigenvalue of the preconditioner is pulled closer to \(1\)), which makes the preconditioning milder and more numerically stable. Another common practice is to add \(\epsilon I\) (where \(\epsilon\) is a small constant and \(I\) is the identity matrix) to \(I(\vw)\) before taking the inverse, also for numerical stability.
There are two main issues with Algorithm 4.7: first, it is not clear how \(I(\vw)\) can be estimated; second, it is not clear how to compute its inverse efficiently.
There is one idea that solves both problems at once: a diagonal approximation! If we only keep the diagonal entries of \(I(\vw)\), they can be estimated cheaply from the current iteration's stochastic gradient, and the inverse simply amounts to inverting each diagonal entry. Keeping only the diagonal, we arrive at the update rule
and it is very straightforward to verify that
where the right-hand side uses element-wise multiplication and division.
If you recall the update form of Adam (Algorithm 4.4) and how we derived it from sign SGD, you will notice that this is almost identical! This is why Adam/AdamW is sometimes called a “second-order” method, even though it never computes the Hessian directly.
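To make the connection concrete, here is a minimal PyTorch sketch of the diagonal approximation (the names are mine; the EMA of the squared stochastic gradient serves as a running estimate of the diagonal of \(I(\vw)\), and Adam's first-moment EMA and bias corrections are omitted):

```python
import torch

def diagonal_fisher_step(w, g, v, lr=1e-3, beta2=0.999, eps=1e-8):
    """Adam-like step obtained from a diagonal Fisher approximation.

    w : parameter tensor, g : stochastic gradient, v : running estimate
    of the diagonal of I(w); all three have the same shape."""
    v = beta2 * v + (1 - beta2) * g * g      # EMA of squared gradients
    w = w - lr * g / (v.sqrt() + eps)        # element-wise I(w)^{-1/2} g
    return w, v
```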
Now let us move on to matrix variables and discuss how Shampoo is derived. The idea is the same as in Algorithm 4.7, but this time we have matrix variables. For a matrix weight \(W\in\mathbb{R}^{m\times n}\) with gradient \(\nabla \ell(W)=:G\in\mathbb{R}^{m\times n}\), computing the Fisher information matrix requires flattening the gradient to \(\vg=\mathrm{vec}(G)\in\mathbb{R}^{m n}\) and computing
However, this is far too expensive! Imagine a matrix of dimension \((m,n)=(4096,4096)\) (common for attention matrices); then the Fisher information matrix lives in \(\mathbb{R}^{16777216\times 16777216}\) (\(4096^2=16777216\)), which is far too large to form, let alone invert. In practice, such a memory overhead is not tolerable.
Fortunately, there is an efficient way to approximate the Fisher information without vectorizing the gradient \(G\). [Anil et al., 2020] provides the following useful lemma.
Lemma
Let \(G_1,\ldots,G_t \in \mathbb{R}^{m\times n}\) be matrices of rank at most \(r\). Let \(g_s=\operatorname{vec}(G_s)\) and define
Let \(L_t,R_t\) be defined as
Then for any \(p,q>0\) such that \(1/p+1/q=1\), we have
(see Lemma 1 in [Anil et al., 2020])
In the above lemma, \(I_m\) refers to the identity matrix in \(\mathbb{R}^{m\times m}\), and similarly for \(I_n\) and \(I_{mn}\). Also, \(\epsilon\) is a small constant that keeps \(\widehat H_t\) invertible (note that Algorithm 4.7 also needs such a safeguard to keep the Fisher information matrix invertible, but I omitted it for simplicity).
Here \(\otimes\) stands for the Kronecker product. The lemma indicates that instead of forming the full Fisher information matrix, we can maintain two small matrices \(L\in\mathbb{R}^{m\times m}\) and \(R\in\mathbb{R}^{n\times n}\) and use them to compute the update in Algorithm 4.7 efficiently. Another useful property of the Kronecker product (for symmetric positive definite \(A\), \(B\) and a matching vectorization convention) is that \((A\otimes B)^{-1/2}\,\mathrm{vec}(G) = \mathrm{vec}(A^{-1/2} G B^{-1/2})\). Combining these approximations and identities, we can design an algorithm that effectively approximates second-order information for a matrix variable without significantly increasing the memory overhead.
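The Kronecker-product identity is easy to verify numerically. Below is a small PyTorch check (a sketch: it uses the row-major \(\mathrm{vec}(\cdot)\) convention given by `reshape(-1)` and symmetric positive-definite factors):

```python
import torch

torch.manual_seed(0)
torch.set_default_dtype(torch.float64)

def inv_sqrt(M):
    # M^{-1/2} for a symmetric positive-definite matrix via eigendecomposition.
    evals, evecs = torch.linalg.eigh(M)
    return evecs @ torch.diag(evals.rsqrt()) @ evecs.T

m, n = 3, 4
A = torch.randn(m, m); A = A @ A.T + m * torch.eye(m)   # SPD, (m, m)
B = torch.randn(n, n); B = B @ B.T + n * torch.eye(n)   # SPD, (n, n)
G = torch.randn(m, n)

# Working in the big (m*n) x (m*n) Kronecker space ...
lhs = inv_sqrt(torch.kron(A, B)) @ G.reshape(-1)
# ... agrees with working only in the two small factor spaces.
rhs = (inv_sqrt(A) @ G @ inv_sqrt(B)).reshape(-1)
print(torch.allclose(lhs, rhs))  # True
```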
The resulting algorithm is the so-called Shampoo algorithm (Algorithm 4.8).
Algorithm 4.8 (Shampoo for matrix variable optimization)
Input: initial point \(W^0\), stepsize/lr sequence \(\{ \gamma_t \}_{t \geq 0}\), momentum parameter \(\beta\), max. no. of iterations \(T\).
Initialize \(L_{0}=\epsilon I_m\in\mathbb{R}^{m\times m}\) and \(R_{0}=\epsilon I_n\in\mathbb{R}^{n\times n}\)
For \(t=0,1,2,\dots, T-1\),
Output: last iterate \(W^T\), or the solution sequence (checkpoints) \(\{ W^t \}_{t=1}^T\).
The reader may verify that Algorithm 4.6 and Algorithm 4.8 are identical if we take \(\beta=0\) in both algorithms. This indicates that by carefully choosing the matrix norm in the steepest descent (4.13), second-order information can be partially recovered and the optimization can become even more efficient.
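Finally, here is a minimal PyTorch sketch of a Shampoo-style step, together with a numerical check of the connection to Muon mentioned above (the names `shampoo_step` and `inv_fourth_root` are mine; real implementations add momentum and grafting, and use iterative inverse-root solvers rather than an exact eigendecomposition):

```python
import torch

torch.manual_seed(0)
torch.set_default_dtype(torch.float64)

def inv_fourth_root(M, floor=1e-30):
    # M^{-1/4} for a symmetric PSD matrix via eigendecomposition.
    evals, evecs = torch.linalg.eigh(M)
    return evecs @ torch.diag(evals.clamp(min=floor) ** -0.25) @ evecs.T

def shampoo_step(W, G, L, R, lr=1e-3):
    """One Shampoo-style step for an m x n weight matrix W with gradient G.
    L (m x m) and R (n x n) are the accumulated Kronecker factors."""
    L = L + G @ G.T
    R = R + G.T @ G
    W = W - lr * inv_fourth_root(L) @ G @ inv_fourth_root(R)
    return W, L, R

# With "fresh" statistics (L, R ~ eps * I), the preconditioned gradient
# L^{-1/4} G R^{-1/4} reduces to the orthogonalized update U V^T used by Muon.
m, n, eps = 5, 3, 1e-12
G = torch.randn(m, n)
P = inv_fourth_root(eps * torch.eye(m) + G @ G.T) @ G @ inv_fourth_root(eps * torch.eye(n) + G.T @ G)
U, _, Vh = torch.linalg.svd(G, full_matrices=False)
print(torch.allclose(P, U @ Vh, atol=1e-6))  # True, up to the eps regularization
```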
4.3.3. Possible homework problems#
Implement Muon and Shampoo.
Discuss the efficient implementation of Muon: in particular, the Newton-Schulz (NS) orthogonalization step and other approximations.