
4. LLM Training as Large-scale Optimization

In this chapter we will discuss the training of LLMs as a large-scale optimization problem.

To begin with, suppose that you have already fixed the model architecture (MLP, RNN, MHA, etc.), so that we can denote all the trainable parameters by a single vector \(\vw\) (note that if we have multiple layers of weight matrices, we can always vectorize them and concatenate them together), and fixed the dataset, denoted \(\mathcal{D}=\{\xi\}\), where each data point \(\xi=[\xi_1,...,\xi_{T}]\) is a sequence of tokens (the tokenization of a sentence or paragraph). Here each \(\xi_i\) is a token and \(T\) is usually called the sequence length.

Now we want the model to learn to remember all the text, that is, given an unfinished sequence of tokens \([\xi_1,...,\xi_t]\), we want the model to predict the next token. Denote by \(\pi(\xi|[\xi_1,...,\xi_t];\vw)\) the probability that the model (with weights \(\vw\)) outputs \(\xi\) when the input is \([\xi_1,...,\xi_t]\); this is exactly the model output (remember that the last layer of the model is always a softmax layer, so the output is precisely a probability distribution over the next token). We then teach the model by maximizing the probability of outputting the correct token \(\xi_{t+1}\); mathematically,

\[\min_{\vw}\ -\log (\pi(\xi_{t+1}|[\xi_1,...,\xi_t];\vw)).\]

Here we minimize the negative log-likelihood, which is equivalent to maximizing \(\pi(\xi_{t+1}|[\xi_1,...,\xi_t];\vw)\) directly. Now imagine that, for the entire dataset, we want to do this next-token prediction at every position; mathematically:

\[\min_{\vw}\ \frac{1}{|\mathcal{D}|}\sum_{\xi\in\mathcal{D}} \frac{1}{T-1}\sum_{t=1}^{T-1}-\log (\pi(\xi_{t+1}|[\xi_1,...,\xi_t];\vw)).\]
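To make the inner sum concrete, here is a minimal sketch in PyTorch of the per-sequence average next-token negative log-likelihood. It assumes a hypothetical `model` that maps a token sequence of length \(T\) to next-token logits of shape `[T, vocab_size]`; the names `model`, `vocab_size`, and `sequence_nll` are illustrative choices, not fixed by the text above.

```python
import torch
import torch.nn.functional as F

def sequence_nll(model, xi: torch.Tensor) -> torch.Tensor:
    """Average next-token negative log-likelihood of one token sequence xi of shape [T].

    Assumption: model(xi) returns logits of shape [T, vocab_size], where
    softmax(logits[t]) is the model's distribution over the token that follows
    the prefix xi[:t+1] (a standard causal language-model convention).
    """
    logits = model(xi)        # [T, vocab_size]
    pred = logits[:-1]        # positions 0..T-2 each predict the *next* token
    target = xi[1:]           # the T-1 "next tokens"
    # cross_entropy averages -log softmax(pred)[target] over the T-1 positions,
    # which is exactly the inner sum above divided by T-1.
    return F.cross_entropy(pred, target)
```

Averaging this quantity over all sequences in \(\mathcal{D}\) then gives the outer sum in the objective above.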

We can simply denote the per-sample loss as \(\ell(\vw,\xi):=\frac{1}{T-1}\sum_{t=1}^{T-1}-\log (\pi(\xi_{t+1}|[\xi_1,...,\xi_t];\vw))\). Now imagine that we have an essentially unlimited amount of data (\(|\mathcal{D}|\rightarrow\infty\)); the above loss then boils down to the following stochastic optimization problem:

(4.1)\[\min_{\vw}\ f(\vw) := \EE_{\xi\sim\mathcal{D}}[\ell(\vw,\xi)]\]

where \(\mathcal{D}\) now stands for the underlying data distribution from which the sequences \(\xi\) are drawn. Essentially we want to minimize the expected prediction loss over this huge collection of data.
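Since \(f(\vw)\) is an expectation, we can never evaluate it exactly, but every sampled sequence gives an unbiased stochastic gradient, which is all that a stochastic gradient method needs. Below is a minimal sketch of plain SGD on (4.1), reusing `sequence_nll` from above and assuming a hypothetical sampler `sample_sequence()` that draws one sequence \(\xi\sim\mathcal{D}\) per call; this is a sketch of the idea, not a full training recipe.

```python
import torch

def sgd_on_stream(model, sample_sequence, steps: int = 1000, lr: float = 1e-3):
    """Plain SGD on the stochastic objective f(w) = E_{xi ~ D}[ell(w, xi)]."""
    opt = torch.optim.SGD(model.parameters(), lr=lr)
    for _ in range(steps):
        xi = sample_sequence()          # one sequence xi ~ D, shape [T], dtype long
        loss = sequence_nll(model, xi)  # ell(w, xi): an unbiased estimate of f(w)
        opt.zero_grad()
        loss.backward()                 # gradient of ell: unbiased estimate of grad f(w)
        opt.step()                      # w <- w - lr * stochastic gradient
```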

Note that for neural network training we certainly cannot collect all possible text data (well… companies such as OpenAI come close, as they scrape essentially all the data on the internet), but we can still collect a set of training data \(\mathcal{D}=\{\xi_i\}_{i=1,2,...,|\mathcal{D}|=:n}\) and optimize the training loss

(4.2)\[\min_{\vw}\ f(\vw) := \frac{1}{n}\sum_{i=1}^{n}\ell(\vw, \xi_i)\]

which is also known as the finite-sum setting of stochastic optimization.
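In the finite-sum setting the objective can, at least in principle, be evaluated exactly by averaging \(\ell\) over the \(n\) collected sequences. A small sketch under the same assumptions as above (in practice one would of course still work with minibatches rather than the full sum):

```python
import torch

def training_loss(model, dataset) -> torch.Tensor:
    """Finite-sum training loss f(w) = (1/n) * sum_i ell(w, xi_i).

    `dataset` is assumed to be a list of n token sequences [xi_1, ..., xi_n].
    """
    losses = [sequence_nll(model, xi) for xi in dataset]
    return torch.stack(losses).mean()
```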