4.3. Variants and discussions of modern optimizers
In this section we discuss some modern optimizers that have strong potential to replace Adam/AdamW as the default optimizer in the future.
4.3.1. Optimizer design as steepest descent with different norms
Recall that SGD can be formulated as minimizing a sequence of local quadratic approximations of the loss:
It is easy to verify that this is exactly equivalent to \(\vw^{t+1}\leftarrow \vw^t - \eta_t \nabla\ell(\vw^t,\xi_t)\). Now, what if we consider a different norm on the right-hand side of (4.13)?
As we discussed in the previous section, SGD suffers from the imbalance in magnitude across different dimensions/entries, and Adam is essentially derived from sign SGD. In fact, one only needs to change the vector \(\ell_2\) norm in (4.13) to the \(\ell_\infty\) norm to recover sign SGD:
Now let us specialize the variables to matrices, since attention-based neural networks consist mostly of matrix-valued parameters. Denote the weight matrix by \(\vW\); then the update in (4.13) can be rewritten as
where the inner product is the trace (Frobenius) inner product \(\langle A,B\rangle := \mathrm{tr}(A^\top B)\) and \(\|\cdot\|_{F}\) is the Frobenius norm of a matrix (\(\|A\|_{F}^2:=\sum_{i,j}A_{i,j}^2\)). It is also easy to verify that this is exactly equivalent to \(\vW^{t+1}\leftarrow \vW^t - \eta_t \nabla\ell(\vW^t,\xi_t)\).
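To make the contrast concrete, here is a minimal NumPy sketch of the two vector updates derived above (the function names are ours, and we omit the dual-norm scaling factor that may appear in front of the sign step depending on the exact normalization in (4.13)-(4.14)):

```python
import numpy as np

def sgd_step(w, g, lr):
    # Steepest descent under the l2 norm (or the Frobenius norm for matrices):
    # move along the negative gradient.
    return w - lr * g

def sign_sgd_step(w, g, lr):
    # Steepest descent under the l_infty norm: only the sign of each entry is kept,
    # so all coordinates move by the same amount regardless of gradient magnitude.
    return w - lr * np.sign(g)

w = np.array([1.0, -2.0, 0.5])
g = np.array([100.0, 0.01, -1.0])     # badly scaled gradient entries
print(sgd_step(w, g, lr=0.1))         # step dominated by the large entry
print(sign_sgd_step(w, g, lr=0.1))    # every entry moves by the same amount
```

This is exactly the magnitude-imbalance issue mentioned at the beginning of this subsection.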
Now you might ask what happens if we have multiple layers of weight matrices, rather than just the single matrix above. Suppose we have \(L\) layers in total with weight matrices \(\vW_{1}, \vW_{2},..., \vW_{L}\). Then (4.15) generalizes to (for all \(l=1,2,...,L\))
Now we do a quick recap of possible choices of matrix norms.
Definition 4.1
For a matrix \(A\in\RR^{m\times n}\) and two vector norms \(\|\cdot\|_{\alpha}\) on \(\RR^{n}\) and \(\|\cdot\|_{\beta}\) on \(\RR^{m}\), the induced \(\alpha\)-to-\(\beta\) norm is given by (here \(\alpha\) and \(\beta\) can be any values in \([1,\infty]\))
The matrix induced norm allows us to do the same thing as in (4.14), i.e., plug in different norms and derive different descent methods, but now for attention-based LLMs (which consist mostly of attention weight matrices). The question is: how do we determine which norm is good for LLM training?
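As a quick numerical sanity check of Definition 4.1 (a sketch; the two identities below are standard facts about induced norms, not something specific to this text): the \(1\)-to-\(\infty\) norm of a matrix equals its largest entry in absolute value, while the \(2\)-to-\(2\) norm equals its largest singular value.

```python
import numpy as np

a = np.random.default_rng(0).standard_normal((4, 3))

# 1 -> infinity: the supremum over the l1 unit ball is attained at a coordinate
# vector, so the norm is the largest entry in absolute value.
norm_1_to_inf = max(np.abs(a[:, j]).max() for j in range(a.shape[1]))
print(norm_1_to_inf, np.abs(a).max())

# 2 -> 2: the spectral norm, i.e. the largest singular value.
print(np.linalg.norm(a, 2), np.linalg.svd(a, compute_uv=False)[0])
```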
First, let us see how we can reconstruct sign SGD, which is the fundamental building block of Adam. It turns out that the induced \(1\)-to-\(\infty\) norm is the correct choice to recover sign SGD for matrix variables; namely, we have
where \(\dagger\) represents the dual norm.
Next, the question is: what if we use the simplest matrix induced norm, namely the \(\ell_2\rightarrow\ell_2\) norm? Note that the \(\ell_2\rightarrow\ell_2\) norm is also called the spectral norm, since its value can be shown to equal the largest singular value of the matrix. It is one of the most commonly used matrix norms, largely because it is easy to compute (it only requires a singular value decomposition or power iteration).
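Here is a minimal power-iteration sketch for the spectral norm (the iteration count and random seeds are arbitrary choices):

```python
import numpy as np

def spectral_norm(a, n_iters=50):
    """Estimate the largest singular value of `a` by power iteration on A^T A."""
    v = np.random.default_rng(0).standard_normal(a.shape[1])
    v /= np.linalg.norm(v)
    for _ in range(n_iters):
        v = a.T @ (a @ v)            # one application of A^T A
        v /= np.linalg.norm(v)       # keep the iterate on the unit sphere
    return np.linalg.norm(a @ v)     # ||A v|| -> sigma_max as v aligns with the top right singular vector

a = np.random.default_rng(1).standard_normal((4, 3))
print(spectral_norm(a), np.linalg.svd(a, compute_uv=False)[0])   # the two should agree closely
```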
It turns out that steepest descent under the spectral norm amounts to orthogonalizing the gradient matrix:
where \(U\Sigma V^\top = \nabla\ell(\vW^t,\xi_t)\) is the singular value decomposition, and \(*\) stands for the dual norm of the spectral norm, namely the nuclear norm (the sum of all singular values, or \(\mathrm{tr}(\Sigma)\)).
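A small numerical illustration of this fact (a sketch; the random matrix below merely stands in for a stochastic gradient): the direction \(UV^\top\) has all singular values equal to one, and its inner product with the gradient equals the nuclear norm.

```python
import numpy as np

g = np.random.default_rng(0).standard_normal((5, 3))   # stand-in for a gradient matrix
u, s, vt = np.linalg.svd(g, full_matrices=False)
ortho = u @ vt                                          # the orthogonalized direction U V^T

print(np.linalg.svd(ortho, compute_uv=False))           # all singular values equal 1
print(np.trace(g.T @ ortho), s.sum())                   # <G, U V^T> equals the nuclear norm of G
```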
This motivates us to develop a matrix version of Adam by simply maintaining an EMA of the gradients in (4.19). The resulting algorithm appears to be one of the strongest candidates to replace Adam for modern LLM training, namely the Muon algorithm [cite muon here].
Algorithm 4.6 (Muon for matrix variable optimization)
Input: initial point \(\vW_{l}^0\), stepsize sequence \(\{ \eta_{t} \}\), max. no. of iterations \(T\).
For \(t=0,1,2,..., T-1\),
Note that here the SVD is applied to the momentum matrix \(\vM_{l}^{t}\) (rather than directly to the gradient).
Output: last iterates \(\vW_l^T\), or the solution sequence (checkpoints) \(\{ \vW_l^t \}_{t=1}^T\).
Muon performs surprisingly well for LLM pre-training, often training faster than AdamW. For details see this blog [cite muon blog here] or this technical report [cite kimi muon paper here].
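Below is a minimal NumPy sketch of Algorithm 4.6 for a list of weight matrices. The hyperparameter values, the EMA convention for the momentum, and the use of an exact SVD are simplifying assumptions on our part; practical Muon implementations replace the SVD with a few Newton-Schulz iterations and typically keep AdamW for non-matrix parameters such as embeddings and normalization weights.

```python
import numpy as np

def orthogonalize(m):
    """Map M to U V^T: keep the singular vectors and set all singular values to 1."""
    u, _, vt = np.linalg.svd(m, full_matrices=False)
    return u @ vt

class Muon:
    """Sketch of the Muon update for matrix parameters W_1, ..., W_L."""

    def __init__(self, shapes, lr=0.02, beta=0.95):
        self.lr, self.beta = lr, beta
        self.momenta = [np.zeros(s) for s in shapes]   # M_l^0 = 0

    def step(self, weights, grads):
        for l, g in enumerate(grads):
            # EMA of the gradients, then a step along the orthogonalized momentum
            self.momenta[l] = self.beta * self.momenta[l] + (1 - self.beta) * g
            weights[l] = weights[l] - self.lr * orthogonalize(self.momenta[l])
        return weights
```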
4.3.2. Approximate Hessian point of view for Shampoo and Muon
Yet another line of work inspects the choice of optimizer in a different way, via approximating second-order information. Next, we discuss how Adam can be derived from this perspective, and also how this Hessian-approximation perspective leads to optimizers such as Shampoo and Muon.
Consider the situation where we are trying to solve a binary classification problem using the cross-entropy loss (i.e., logistic regression; in fact, the following derivation also works for multi-class classification such as LLM pre-training):
where \(\mathcal{D}\) is the data distribution. We know that convergence would benefit if we had access to Hessian information. The Hessian of (4.20) is
Note that (this can be verified simply from the properties of the \(\log\) function)
therefore we get
Note that the first term \(I(\theta):=\EE_{(x,y)\sim\mathcal{D}} [\nabla\log\pi(y|x;\theta) (\nabla\log\pi(y|x;\theta))^\top]\) is also called the Fisher information matrix. Moreover, if the data distribution is identical to the model output distribution, i.e., \((x,y)\sim\mathcal{D}\) follows the same distribution as \((x,y)\sim\mu\times\pi(\cdot|\cdot;\theta)\), then the second term vanishes ([should be a simple homework]).
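As a quick numerical check of this claim (a sketch for logistic regression; we deliberately sample the labels from the model itself so that the residual term vanishes), the empirical Fisher matrix closely matches the exact Hessian of the average negative log-likelihood:

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 5000, 3
X = rng.standard_normal((n, d))
theta = rng.standard_normal(d)

p = 1.0 / (1.0 + np.exp(-X @ theta))       # model probabilities pi(y=1 | x; theta)
y = (rng.random(n) < p).astype(float)      # labels sampled from the model itself

# Exact Hessian of the average negative log-likelihood for logistic regression
H = (X * (p * (1 - p))[:, None]).T @ X / n

# Empirical Fisher information: average outer product of the score vectors (y - p) x
scores = (y - p)[:, None] * X
F = scores.T @ scores / n

print(np.max(np.abs(H - F)))               # small, and shrinks further as n grows
```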
Therefore we can legitimately approximate the second-order information by the Fisher information matrix \(I(\theta)\). We can even apply Newton's method approximately, as follows!
Algorithm 4.7 (Prototype of a second order method via Fisher information)
Input: initial point \(\vw^0\), stepsize sequence \(\{ \gamma_k \}_{k \geq 0}\), max. no. of iterations \(K\).
For \(k=0,1,2,..., K-1\),
Output: last iterate \(\vw^K\), or the solution sequence \(\{ \vw^k \}_{k=1}^K\).
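A minimal sketch of one iteration of Algorithm 4.7 (the function name, the batch-based Fisher estimate, and the damping term `eps` are our own illustrative choices): estimate \(I(\theta)\) from a batch of per-sample gradients and solve a linear system rather than forming the inverse explicitly.

```python
import numpy as np

def fisher_newton_step(theta, loss_grads, lr, eps=1e-6):
    """One Fisher-preconditioned step.

    loss_grads: array of shape (batch, d) holding per-sample gradients of
    -log pi(y|x; theta); since each of these is minus a score vector, their
    outer products give a batch estimate of I(theta).
    """
    d = theta.shape[0]
    fisher = loss_grads.T @ loss_grads / loss_grads.shape[0]   # estimate of I(theta)
    g = loss_grads.mean(axis=0)                                # stochastic gradient of the loss
    return theta - lr * np.linalg.solve(fisher + eps * np.eye(d), g)
```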
Now, there are two main issues with Algorithm 4.7: the first is how to estimate \(I(\theta)\), and the second is how to compute its inverse efficiently.
There is one trick that addresses both problems at once: a diagonal approximation. If we keep only the diagonal entries of \(I(\theta)\), the estimation is greatly simplified, and inverting the matrix reduces to taking the reciprocal of each diagonal entry!
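A sketch of this diagonal shortcut, under the same assumptions as above; note the resemblance to Adam's second-moment denominator, which additionally applies an EMA and a square root to the squared gradients.

```python
import numpy as np

def diag_fisher_step(theta, loss_grads, lr, eps=1e-8):
    # Keep only the diagonal of I(theta): the mean of the squared per-sample gradients.
    fisher_diag = np.mean(loss_grads ** 2, axis=0)
    g = loss_grads.mean(axis=0)
    return theta - lr * g / (fisher_diag + eps)   # "inversion" is an elementwise division
```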
(TODO: continue the discussion)