Paper Summary: An Introduction To Transformers - Turner (2023)
Summary of “An Introduction to Transformers” by Richard E. Turner (2023).
Input to the Transformer
The input to the transformer is a sequence \(X^{(0)}\in\mathbb{R}^{D\times N}\), where \(N\) is the length of the sequence and \(D\) is the dimensionality of each item. The items are known as tokens and are denoted \(\mathbf{x}_n^{(0)}\in\mathbb{R}^{D\times 1}\), so that \[ X^{(0)} = \left[\mathbf{x}_1^{(0)}, \ldots, \mathbf{x}_N^{(0)}\right]. \] The tokens are representations of the objects of interest. For instance, in language tasks a token is usually a unique vector representation of a word, whereas for an image it would be a vector representation of a patch. The representation can be given/fixed (e.g. Word2Vec) or learned during training.
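As a concrete illustration, here is a minimal numpy sketch of how a sequence of word indices might be mapped to the \(D\times N\) input matrix \(X^{(0)}\). The vocabulary size, the embedding table `E`, and the toy sequence are hypothetical; in practice the embeddings would be fixed (e.g. Word2Vec) or learned.

```python
import numpy as np

# Hypothetical sizes: a vocabulary of W words, D-dimensional token embeddings.
W, D = 10_000, 64
rng = np.random.default_rng(0)

# Embedding table: one D-dimensional vector per word (fixed or learned).
E = rng.normal(size=(W, D))

# A toy sequence of N = 5 word indices (e.g. produced by a tokenizer).
token_ids = np.array([42, 7, 1999, 7, 3])

# Stack the embeddings as columns to obtain X^(0) with shape D x N.
X0 = E[token_ids].T
print(X0.shape)  # (64, 5)
```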
Transformer Block
The output of the transformer is denoted \(X^{(M)}\), where \(M\) is the number of transformer blocks, each of which updates the representation via \[ X^{(m)} = \text{transformer-block}(X^{(m-1)}), \qquad m = 1, \ldots, M. \] Each transformer block consists of two stages, each with three steps, where the middle step is sandwiched between a standardisation and a residual connection.
- Stage 1:
  - Standardisation: Each token is standardised separately before being fed into the MHSA layer, to stabilise learning: from each token \(\mathbf{x}_n^{(m-1)}\) we subtract its mean and divide by its standard deviation (both computed over the \(D\) components), \[ \bar{\mathbf{x}}_n^{(m-1)} = \frac{\mathbf{x}_n^{(m-1)} - \text{mean}(\mathbf{x}_n^{(m-1)})}{\text{std}(\mathbf{x}_n^{(m-1)})} \qquad \qquad \forall n = 1, \ldots, N, \] so we form \(\bar{X}^{(m-1)} = [\bar{\mathbf{x}}_1^{(m-1)}, \ldots, \bar{\mathbf{x}}_N^{(m-1)}]\). Typically this is called LayerNorm, but the author calls it TokenNorm.
  - MHSA Layer: Consists of \(H\in\mathbb{Z}_+\) heads, meaning that we perform \(H\) operations in parallel. Each of the \(H\) operations, shown below, consists of multiplying the standardised sequence \(\bar{X}^{(m-1)}\) by the attention matrix \(A_h^{(m)}\), obtaining a \(D\times N\) matrix, and then linearly projecting the result with \(V_h\). The attention matrix is itself constructed from the input. \[ \begin{align} Y^{(m)} &= \text{MHSA}(\bar{X}^{(m-1)}) = \sum_{h=1}^H V_h \bar{X}^{(m-1)} A_h^{(m)} && \text{where }\, V_h\in\mathbb{R}^{D\times D} \text{ and } A_h^{(m)}\in\mathbb{R}^{N\times N}\\ (A_h^{(m)})_{n, n'} &= \frac{\displaystyle \exp\left\{\frac{\left\langle\mathbf{k}_{h, n}^{(m)}, \mathbf{q}_{h, n'}^{(m)} \right\rangle}{\sqrt{D}}\right\}}{\displaystyle \sum_{n''=1}^N \exp\left\{\frac{\left\langle\mathbf{k}_{h, n''}^{(m)}, \mathbf{q}_{h, n'}^{(m)} \right\rangle}{\sqrt{D}}\right\}} && \text{where } \, \mathbf{q}_{h, n}^{(m)} = U_{\mathbf{q}, h}^{(m)} \bar{\mathbf{x}}_n^{(m-1)} \text{ and } \mathbf{k}_{h, n}^{(m)} = U_{\mathbf{k}, h}^{(m)}\bar{\mathbf{x}}_n^{(m-1)}. \end{align} \] The vectors \(\mathbf{q}\) and \(\mathbf{k}\) are known as queries and keys respectively; the matrices \(U_{\mathbf{q}, h}\) and \(U_{\mathbf{k}, h}\), of shape \(K\times D\) with typically \(K < D\), allow the relationship between two tokens to be asymmetric. Each column of \(A_h^{(m)}\) is a softmax over these similarities (the normalisation runs over \(n\) for fixed \(n'\)), and we use multiple heads because each head can capture a different type of relationship. The parameters of this layer are \(U_{\mathbf{q}, h}, U_{\mathbf{k}, h}\) and \(V_h\) for every \(h=1, \ldots, H\).
  - Residual Connection: We use residual connections so that each transformation stays close to an identity operation: it only introduces a mild non-linearity. This stabilises training and, with enough layers, the model can still learn complex representations. The residual connection is simply \[ Z^{(m)} = X^{(m-1)} + Y^{(m)}, \] where we point out that it is the un-standardised \(X^{(m-1)}\) that is added. This ends the first stage.
- Stage 2:
  - Standardisation: We do the same as before and standardise each token in \(Z^{(m)}\) to obtain \(\bar{Z}^{(m)}\).
  - MLP Layer: A simple (usually quite shallow) MLP is applied to each standardised token \(\bar{\mathbf{z}}_n^{(m)}\) to obtain the corresponding representation \[ \boldsymbol{\varkappa}_n^{(m)} = \text{MLP}(\bar{\mathbf{z}}_n^{(m)}) \qquad n=1, \ldots, N \text{ and } m=1, \ldots, M. \] Notice that, within a block, a single MLP is shared across all the tokens.
  - Residual Connection: Finally, we apply one last residual connection to obtain \(X^{(m)}\), \[ \mathbf{x}_n^{(m)} = \mathbf{z}_n^{(m)} + \boldsymbol{\varkappa}_n^{(m)}, \] where we notice, once again, that we add the un-standardised output of the first stage, \(\mathbf{z}_n^{(m)}\) rather than \(\bar{\mathbf{z}}_n^{(m)}\). A numpy sketch of the complete block follows this list.
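To make the two stages concrete, here is a minimal numpy sketch of a complete transformer block in the \(D\times N\) convention used above. The parameter shapes (\(U_{\mathbf{q},h}, U_{\mathbf{k},h}\in\mathbb{R}^{K\times D}\), \(V_h\in\mathbb{R}^{D\times D}\)) follow the text, but the random initialisation, the single-hidden-layer ReLU MLP, and the toy dimensions are illustrative assumptions rather than the paper's prescription.

```python
import numpy as np

def token_norm(X, eps=1e-5):
    """Standardise each token (column) across its D components."""
    return (X - X.mean(axis=0, keepdims=True)) / (X.std(axis=0, keepdims=True) + eps)

def softmax_columns(S):
    """Softmax over n (rows) for each fixed n' (column), so every column sums to one."""
    S = S - S.max(axis=0, keepdims=True)  # numerical stability
    E = np.exp(S)
    return E / E.sum(axis=0, keepdims=True)

def mhsa(X_bar, Uq, Uk, V):
    """Multi-head self-attention on a standardised D x N input.
    Uq, Uk have shape (H, K, D) (query/key projections); V has shape (H, D, D)."""
    D, N = X_bar.shape
    Y = np.zeros((D, N))
    for Uq_h, Uk_h, V_h in zip(Uq, Uk, V):
        Q = Uq_h @ X_bar                  # (K, N): queries q_{h,n'}
        K_ = Uk_h @ X_bar                 # (K, N): keys    k_{h,n}
        S = K_.T @ Q / np.sqrt(D)         # S[n, n'] = <k_n, q_{n'}> / sqrt(D)
        A = softmax_columns(S)            # attention matrix A_h, shape (N, N)
        Y += V_h @ X_bar @ A              # accumulate this head's contribution
    return Y

def mlp(Z_bar, W1, b1, W2, b2):
    """The same shallow (one-hidden-layer, ReLU) MLP applied to every token (column)."""
    return W2 @ np.maximum(0.0, W1 @ Z_bar + b1[:, None]) + b2[:, None]

def transformer_block(X, p):
    # Stage 1: standardise -> MHSA -> residual (add the un-standardised input).
    Z = X + mhsa(token_norm(X), p["Uq"], p["Uk"], p["V"])
    # Stage 2: standardise -> MLP -> residual (again adding the un-standardised input).
    return Z + mlp(token_norm(Z), p["W1"], p["b1"], p["W2"], p["b2"])

# Toy dimensions (hypothetical): D-dim tokens, N of them, H heads, K-dim queries/keys.
D, N, H, K, D_hid, M = 64, 5, 4, 16, 128, 2
rng = np.random.default_rng(0)
params = [{"Uq": rng.normal(size=(H, K, D)), "Uk": rng.normal(size=(H, K, D)),
           "V": rng.normal(size=(H, D, D)) / np.sqrt(D),
           "W1": rng.normal(size=(D_hid, D)) / np.sqrt(D), "b1": np.zeros(D_hid),
           "W2": rng.normal(size=(D, D_hid)) / np.sqrt(D_hid), "b2": np.zeros(D)}
          for _ in range(M)]

X = rng.normal(size=(D, N))   # stands in for the input X^(0)
for p in params:              # X^(m) = transformer-block(X^(m-1)), m = 1, ..., M
    X = transformer_block(X, p)
print(X.shape)                # (64, 5): the output X^(M) keeps the D x N shape
```

Note that each block carries its own set of parameters, while within a block the MLP and the attention parameters are shared across all token positions.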
Transformers for Sequential Tasks
To use Transformers for sequential tasks we use masking, which means that the matrices \(A^{(m)}_h\) are all upper-triangular (notice that other references use \(N\times D\) inputs and so require lower-triangular matrices). This enforces an auto-regressive behaviour, since the representation at position \(n\) can only depend on positions \(1, \ldots, n\); it makes the model more computationally tractable for sequential problems, at the cost of losing some representational ability. To perform Language Modelling (i.e. predicting the next word), we use \(\mathbf{x}_{n-1}^{(M)}\) to predict the \(n\)-th word via a softmax over the vocabulary, \[ p(\text{word}_n = w \mid \mathbf{x}_{n-1}^{(M)}) = \frac{\displaystyle \exp\left(\mathbf{g}_w^\top \mathbf{x}_{n-1}^{(M)}\right)}{\displaystyle \sum_{w'=1}^W \exp\left(\mathbf{g}_{w'}^\top \mathbf{x}_{n-1}^{(M)}\right)}, \] where \(W\) is the size of the vocabulary, \(w\) indexes the words in this vocabulary, and \(\left\{\mathbf{g}_w\right\}_{w=1}^W\) are the weights of the softmax layer. A sketch of the masking and of this softmax is given below.
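Continuing the same numpy convention, the sketch below shows how the attention scores could be masked so that \(A_h^{(m)}\) is upper-triangular, and how the softmax over the vocabulary might be computed from \(\mathbf{x}_{n-1}^{(M)}\). The matrix `G` stacking the \(\mathbf{g}_w\) as rows, the toy sizes, and the random inputs are hypothetical.

```python
import numpy as np

def masked_softmax_columns(S):
    """Causal (masked) attention: entries A[n, n'] with n > n' are forced to zero,
    so the resulting A is upper-triangular and column n' only attends to n <= n'."""
    N = S.shape[0]
    S = np.where(np.triu(np.ones((N, N), dtype=bool)), S, -np.inf)
    S = S - S.max(axis=0, keepdims=True)  # numerical stability
    E = np.exp(S)                         # exp(-inf) = 0 below the diagonal
    return E / E.sum(axis=0, keepdims=True)

def next_word_probs(x_prev, G):
    """Softmax over the vocabulary; row w of G plays the role of g_w."""
    logits = G @ x_prev                   # scores g_w^T x_{n-1}^{(M)}, shape (W,)
    logits -= logits.max()                # numerical stability
    p = np.exp(logits)
    return p / p.sum()

# Toy example (hypothetical sizes): W words, D-dimensional final representations.
W, D, N = 10_000, 64, 5
rng = np.random.default_rng(0)

A = masked_softmax_columns(rng.normal(size=(N, N)))
print(A.round(2))                         # zeros below the diagonal, columns sum to one

G = rng.normal(size=(W, D))
x_prev = rng.normal(size=D)               # stands in for x_{n-1}^{(M)}
p = next_word_probs(x_prev, G)
print(p.argmax(), p.sum())                # index of the most likely next word; sums to 1
```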