Explainability for Large Language Models: A Survey
The survey divides LLM training paradigms into two categories:
- Fine-tuning: The LM is first trained on a large corpus of unlabeled text and then fine-tuned on a set of labeled data from a specific domain. Typically, a feed-forward neural network (FFNN) is added on top of the encoder's final layer to learn this additional task. Explainability research for these models:
- Local-exploration: aims to understand how the model generated a prediction from a given input. Examples are:
- Feature Attribution: assign a relevance score to each input feature to reflect its contribution to the model prediction.
- Perturbation-based: Perturb inputs (remove, mask or alter them) and evaluate the change in model output. Issues: ignoring correlations between input features, high-confidence predictions on nonsensical inputs, and evaluation on out-of-distribution data.
- Gradient-based: The importance of an input feature is determined by the magnitude of the partial derivative of the output with respect to that input dimension. A major problem is gradient saturation, which masks out smaller gradients.
- Surrogate Models: Train a white-box model to explain model outputs. E.g. TransSHAP uses Shapley values from game theory: each feature is a player in a cooperative prediction game, every subset (coalition) of features is assigned a payoff, and each feature's Shapley value is its average marginal contribution to that payoff across coalitions.
- Attention-based Explanations: Use attention weights to explain model outputs. It is unclear whether attention weights are appropriate as explanations at all: some studies show that explanations given by attention-based methods do not correlate with those given by other methods. Two main local attention techniques:
- Visualizations: E.g. visualize the attention heads for a single input using bipartite graphs or heatmaps. An interesting approach is to track attention flow to trace the evolution of attention; see Attention Flows: Analyzing and Comparing Attention Mechanisms in Language Models.
- Function-based: Use partial derivatives of model outputs with respect to attention weights, or integrated versions of partial gradients, or a mixture of this and raw attention.
- Example-based Explanations: Study how a model's output changes in response to different inputs.
- Adversarial Examples: An interesting one is SemAttack which first transforms input tokens to embeddings, then perturbs these in an iterative manner to maximise the goal of the adversarial attack.
- Counterfactual Explanation: A type of causal explanation, produced in two stages: first, important tokens are selected; then they are edited or infilled.
- Data Influence: Measure the influence of training examples on the loss at test points. An important example of this line of research is Anthropic's work, which asks “how would the model’s behavior change if a given sequence were added to the training set?” This can be tackled with influence functions, a technique from statistics. The technique has two computational bottlenecks: computing the inverse-Hessian-vector product (IHVP), and computing gradients for every training example, separately for each query. The IHVP is approximated with EK-FAC, and the second problem is tackled with query batching.
- Natural Language Explanation: Typically done by training a separate model with data containing both text and human-annotated explanations. There are several approaches such as explain-then-predict, predict-then-explain or joint predict-explain.
- Global-exploration: Aim to offer insight into inner workings of language models, e.g. explain the knowledge and linguistic properties learned by individual components of the model.
- Probing-based Explanations: Methods to understand the knowledge that LLMs have learned.
- Classifier-based Probing: Train a classifier on the representations the model generates for input words to identify either certain linguistic properties or reasoning abilities. E.g. some probing methods examine vector representations to measure the knowledge embedded in the model.
- Parameter-free Probing: No probe is trained; instead, the language model itself is evaluated on a dataset of phrases with either correct or incorrect grammar, e.g. by checking whether it assigns a higher probability to the grammatical version.
- Neuron Activation Explanation:
- Ranking and Relating: Important neurons are identified by studying their activations using unsupervised learning, which reduces the computational burden by focusing on a subset of all neurons. The importance of neurons can be verified with ablation studies. Then, the relationships between individual neurons and linguistic properties are learned.
- Using GPT: The paper Language models can explain neurons in language models showed how GPT-4 can summarise the patterns in text that trigger high activation values. The process is simple: ask a language model to explain the neuron's activations; then, conditional on the explanation, simulate the neuron's activations; finally, score the explanation by comparing the simulated activations with the true ones. A challenge is that ground-truth explanations are not available.
- Concept-based Explanations: Map inputs to a set of concepts and measure the importance score of each pre-defined concept to model predictions. E.g. TCAV uses directional derivatives to determine how much the model predictions depend on the concepts. A drawback is that it requires additional data (a probing dataset), and it can be tricky to determine which concepts are useful for explainability.
- Mechanistic Interpretability: Mostly studies the connections between neurons, especially when these connections form circuits, as well as linear combinations of neuron activations. For instance, Decomposing Language Models Into Understandable Components studies “patterns of neural activation” and finds them to be more consistent and interpretable than individual neurons (based on blinded human feedback).
- Prompting: The model is pre-trained on sentences with blanks that it has to fill in, which enables zero- or few-shot learning. This yields a base (foundation) model. To turn it into an assistant model, it has to be trained on instruction-response examples via supervised fine-tuning and then aligned with human feedback, typically with Reinforcement Learning from Human Feedback (RLHF).
- Base Model Explanation:
- Understanding Base Models via Prompting: Work mostly centers around explaining In-Context Learning (ICL) and Chain-of-Thought (CoT) prompting. For the latter, one study focused on perturbing CoT to determine which aspects are important for generating high-performing explanations.
- Understanding Base Models via their Representation space: Typically done in two parts:
- Representation Reading: Identifies representations of high-level concepts and functions within a network.
- Representation Control: Manipulate representations of concepts/functions to meet some requirements (typically safety).
- Assistant Model Explanation: Base models undergo alignment fine-tuning via RLHF. Explainability research in this area focuses on:
- Pre-training vs Fine-tuning: Understanding whether knowledge comes from the initial pre-training or from the fine-tuning stage. For instance, LIMA: Less Is More for Alignment achieves performance competitive with GPT-4 in human preference evaluations by fine-tuning on only 1,000 well-crafted examples, without reinforcement learning. The authors conclude that almost all knowledge is learned during pre-training and that fine-tuning is a comparatively easy task that mainly teaches the model which response format to use. Another conclusion is that data quality matters more than data quantity in fine-tuning.
- Understanding Hallucinations: Hallucinations may be the product of a lack of data or repeated data.
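To make the perturbation-based attribution bullet above concrete, here is a minimal occlusion sketch. The classifier is a hypothetical bag-of-words logistic model with hand-set weights (`WEIGHTS`, `BIAS` and the `[MASK]` token are illustrative assumptions, not part of any cited method); each token's relevance is the drop in predicted probability when that token alone is replaced by the mask.

```python
import numpy as np

# Hypothetical toy classifier: bag-of-words logistic regression with hand-set weights.
WEIGHTS = {"great": 2.0, "boring": -1.5, "plot": 0.1}
BIAS = -0.2


def predict_proba(tokens):
    """Positive-class probability; unknown tokens (including the mask) contribute 0."""
    logit = BIAS + sum(WEIGHTS.get(t, 0.0) for t in tokens)
    return 1.0 / (1.0 + np.exp(-logit))


def occlusion_attribution(tokens, mask_token="[MASK]"):
    """Relevance of token i = drop in positive-class probability when it is masked."""
    base = predict_proba(tokens)
    scores = []
    for i in range(len(tokens)):
        perturbed = tokens[:i] + [mask_token] + tokens[i + 1:]
        scores.append(base - predict_proba(perturbed))
    return scores


tokens = ["great", "plot", "but", "boring"]
for tok, score in zip(tokens, occlusion_attribution(tokens)):
    print(f"{tok:>8} {score:+.3f}")
```

"great" gets a positive score, "boring" a negative one, and "but" exactly zero. With a real LLM, inserting a mask token is precisely the kind of edit that can push the input out of distribution, one of the issues listed for perturbation methods.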
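The gradient-saturation problem from the gradient-based bullet can be reproduced on a toy model f(x) = sigmoid(w · x), whose gradient f(1 - f)w is available in closed form. The weights `w` and the two inputs are made-up assumptions. Scaling the same input up pushes the sigmoid into its flat region, so every partial derivative shrinks even though the features now drive the prediction more strongly.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


# Hypothetical toy model f(x) = sigmoid(w . x); its gradient is f(x) * (1 - f(x)) * w.
w = np.array([3.0, 1.0, -0.5])


def saliency(x):
    """Gradient-based importance |df/dx_i| for every input dimension."""
    f = sigmoid(w @ x)
    return np.abs(f * (1.0 - f) * w)


x_small = np.array([0.1, 0.2, 0.3])
x_large = 10.0 * x_small  # same direction: the logit grows from 0.35 to 3.5

print("f:", sigmoid(w @ x_small), sigmoid(w @ x_large))
print("saliency (small input):", saliency(x_small))
print("saliency (large input):", saliency(x_large))
```

The model is more confident on the larger input, yet every gradient has shrunk by a factor of more than 8, which is why saturated features can look unimportant to purely gradient-based methods.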
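The Shapley values described in the surrogate-models bullet can be computed exactly when there are only a handful of features. The sketch below is the textbook formula, enumerating every coalition, not TransSHAP's implementation; practical tools approximate it by sampling, because exact enumeration needs on the order of 2^n model evaluations. The `toy_logit` "model", in which "not" flips the effect of "boring", is a hypothetical example, and absent tokens are simply treated as masked.

```python
from itertools import combinations
from math import factorial


def shapley_values(value_fn, n):
    """Exact Shapley values for n players.

    value_fn maps a frozenset of player indices (a coalition) to its payoff.
    Cost grows as O(2^n) calls to value_fn, so this only works for tiny n.
    """
    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for k in range(n):  # coalition sizes 0 .. n-1 drawn from the other players
            weight = factorial(k) * factorial(n - k - 1) / factorial(n)
            for coalition in combinations(others, k):
                s = frozenset(coalition)
                phi[i] += weight * (value_fn(s | {i}) - value_fn(s))
    return phi


# Hypothetical model: a logit where "not" flips the negative effect of "boring".
tokens = ["not", "boring", "plot"]


def toy_logit(present):
    score = 0.0
    if "boring" in present:
        score -= 1.5
        if "not" in present:
            score += 2.0  # interaction term: "not boring" is positive overall
    if "plot" in present:
        score += 0.1
    return score


def value(coalition):
    """Payoff of a coalition: model output with only its tokens kept."""
    return toy_logit({tokens[i] for i in coalition})


phi = shapley_values(value, len(tokens))
print([round(p, 3) for p in phi])
```

The interaction worth 2.0 between "not" and "boring" is split equally between the two tokens, so the values come out at roughly 1.0, -0.5 and 0.1, and they sum to the full-coalition output of 0.6 (the efficiency property).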
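Influence functions, mentioned in the data-influence bullet, can be tried out exactly on a model small enough to form its Hessian. The sketch below assumes L2-regularised linear regression on synthetic data (the data, `lam` and the query point are all assumptions) and computes the influence of up-weighting training example i on the test loss as -grad L(z_test)^T H^-1 grad L(z_i), solving the inverse-Hessian-vector product directly. For LLMs this direct solve is infeasible, which is why the Anthropic work approximates the IHVP with EK-FAC instead.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, lam = 50, 3, 0.1
X = rng.normal(size=(n, d))
y = X @ np.array([1.0, -2.0, 0.5]) + 0.1 * rng.normal(size=n)
x_test, y_test = rng.normal(size=d), 0.0  # hypothetical query point


def fit(X, y, lam, upweight=None):
    """Closed-form minimiser of (1/n) sum_j 0.5 (x_j.theta - y_j)^2 + lam/2 ||theta||^2.

    upweight=(x_i, y_i, eps) adds eps * 0.5 (x_i.theta - y_i)^2 to the objective.
    """
    n, d = X.shape
    A = X.T @ X / n + lam * np.eye(d)
    b = X.T @ y / n
    if upweight is not None:
        x_i, y_i, eps = upweight
        A = A + eps * np.outer(x_i, x_i)
        b = b + eps * y_i * x_i
    return np.linalg.solve(A, b)


def grad_loss(theta, x, y):
    """Gradient of the per-example loss 0.5 (x.theta - y)^2 with respect to theta."""
    return (x @ theta - y) * x


theta = fit(X, y, lam)
H = X.T @ X / n + lam * np.eye(d)  # exact Hessian of the objective: only feasible for tiny models


def influence(i):
    """d(test loss) / d(eps) when training example i is up-weighted by eps."""
    ihvp = np.linalg.solve(H, grad_loss(theta, X[i], y[i]))  # inverse-Hessian-vector product
    return -grad_loss(theta, x_test, y_test) @ ihvp


scores = np.array([influence(i) for i in range(n)])
print("most helpful example:", scores.argmin(), "most harmful example:", scores.argmax())
```

A negative score means up-weighting that example lowers the test loss (helpful); a positive score means it raises it (harmful). Note that the IHVP and the training gradients have to be recomputed for every new query point, the per-query cost the bullet refers to.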
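A classifier-based probe, as in the probing bullet, is usually a simple linear classifier trained on frozen representations. In this sketch the "representations" are synthetic random vectors in which a binary property (say, whether a word is plural) is linearly encoded along a hidden direction; this stands in, as an assumption, for hidden states extracted from one LLM layer.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 16
# Synthetic stand-ins for hidden states (an assumption): a binary property is
# linearly encoded along the hidden direction `direction`.
direction = rng.normal(size=d)
reps = rng.normal(size=(400, d))
labels = (reps @ direction > 0).astype(float)
X_train, X_test = reps[:300], reps[300:]
y_train, y_test = labels[:300], labels[300:]


def train_linear_probe(X, y, lr=0.5, steps=2000):
    """Logistic-regression probe fitted with full-batch gradient descent."""
    w, b = np.zeros(X.shape[1]), 0.0
    for _ in range(steps):
        p = 0.5 * (1.0 + np.tanh(0.5 * (X @ w + b)))  # numerically stable sigmoid
        w -= lr * X.T @ (p - y) / len(y)
        b -= lr * np.mean(p - y)
    return w, b


w, b = train_linear_probe(X_train, y_train)
test_acc = np.mean(((X_test @ w + b) > 0) == (y_test == 1.0))
print(f"probe test accuracy: {test_acc:.2f}")
```

High held-out accuracy shows that the property is decodable from the representations; on its own it does not show that the model actually uses that information for its predictions.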
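Parameter-free probing can be illustrated with minimal pairs: no probe is trained, and the model simply has to assign a higher probability to the grammatical sentence of each pair. To stay self-contained, this sketch uses a hypothetical add-one-smoothed bigram model trained on a five-sentence corpus in place of an LLM; with a real model, `log_prob` would instead sum the token log-probabilities the model returns.

```python
import math
from collections import Counter


def train_bigram(corpus):
    """Counts for an add-one-smoothed bigram model (a hypothetical stand-in for an LLM)."""
    context, pair, vocab = Counter(), Counter(), set()
    for sentence in corpus:
        toks = ["<s>"] + sentence.split() + ["</s>"]
        vocab.update(toks)
        context.update(toks[:-1])
        pair.update(zip(toks, toks[1:]))
    return context, pair, vocab


def log_prob(sentence, model):
    """Sum of log P(token | previous token) with add-one smoothing."""
    context, pair, vocab = model
    toks = ["<s>"] + sentence.split() + ["</s>"]
    V = len(vocab)
    return sum(math.log((pair[(a, b)] + 1) / (context[a] + V)) for a, b in zip(toks, toks[1:]))


def minimal_pair_accuracy(pairs, model):
    """Fraction of (grammatical, ungrammatical) pairs where the grammatical one scores higher."""
    return sum(log_prob(good, model) > log_prob(bad, model) for good, bad in pairs) / len(pairs)


corpus = ["the dog runs", "the dogs run", "a cat runs", "the cats run", "a dog sleeps"]
model = train_bigram(corpus)
pairs = [("the dog runs", "the dog run"), ("the cats run", "the cats runs")]
print(minimal_pair_accuracy(pairs, model))
```

Because the two sentences of a pair differ in a single word, any difference in probability can be attributed to the grammatical contrast rather than to sentence length or topic.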
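The explain, simulate and score loop from the Using GPT bullet can be sketched as follows, assuming the score is the correlation between simulated and true activations. Both the recorded activations and the keyword-matching `simulate` function are hypothetical stand-ins: in the described method, a language model writes the explanation and a language model also simulates activations from it.

```python
import numpy as np

# Hypothetical recorded activations of a single neuron, one value per token.
tokens = ["the", "cat", "dog", "sat", "on", "puppy", "mat"]
true_acts = np.array([0.0, 0.9, 1.0, 0.1, 0.0, 0.8, 0.2])


def simulate(explanation_keywords, tokens):
    """Hypothetical simulator: predicts 1 for tokens the explanation covers, else 0.

    In the described method this step is performed by a language model
    conditioned on the natural-language explanation.
    """
    return np.array([1.0 if t in explanation_keywords else 0.0 for t in tokens])


def score_explanation(explanation_keywords):
    """Correlation between simulated and true activations (1 = perfect simulation)."""
    simulated = simulate(explanation_keywords, tokens)
    return np.corrcoef(simulated, true_acts)[0, 1]


print("'fires on animals':       ", score_explanation({"cat", "dog", "puppy"}))
print("'fires on function words':", score_explanation({"the", "on"}))
```

The "animals" explanation scores close to 1, while the "function words" explanation scores negative and would be rejected. Scoring against true activations sidesteps the missing ground-truth explanations, although a high score only shows the explanation is predictive, not that it is complete.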
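TCAV, from the concept-based bullet, can be sketched in three steps: learn a concept activation vector (CAV) as the normal of a linear classifier separating concept examples from random examples in a layer's activation space; take the directional derivative of the class logit along the CAV; and report the fraction of class examples for which it is positive. The activations, the concept direction and the `tanh` head with weights `v` are synthetic assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8
# Synthetic layer activations (assumption): concept examples are shifted along axis 0.
concept_dir = np.eye(d)[0]
concept_acts = rng.normal(size=(100, d)) + 2.0 * concept_dir
random_acts = rng.normal(size=(100, d))
class_acts = rng.normal(size=(50, d))  # activations of inputs from the class being explained

# Hypothetical head from activations h to the class logit: f(h) = v . tanh(h).
v = np.array([1.5, 0.2, -0.3, 0.0, 0.1, 0.0, 0.4, -0.2])


def grad_logit(h):
    """Analytic gradient of f(h) = v . tanh(h) with respect to h."""
    return v * (1.0 - np.tanh(h) ** 2)


def train_cav(pos, neg, lr=0.1, steps=500):
    """CAV = unit normal of a logistic-regression boundary between concept and random activations."""
    X = np.vstack([pos, neg])
    y = np.concatenate([np.ones(len(pos)), np.zeros(len(neg))])
    w, b = np.zeros(X.shape[1]), 0.0
    for _ in range(steps):
        p = 0.5 * (1.0 + np.tanh(0.5 * (X @ w + b)))  # numerically stable sigmoid
        w -= lr * X.T @ (p - y) / len(y)
        b -= lr * np.mean(p - y)
    return w / np.linalg.norm(w)


def tcav_score(acts, cav):
    """Fraction of examples whose class logit increases along the concept direction."""
    directional = np.array([grad_logit(h) @ cav for h in acts])
    return float(np.mean(directional > 0))


cav = train_cav(concept_acts, random_acts)
print(f"TCAV score: {tcav_score(class_acts, cav):.2f}")
```

The full method also repeats this with several random counterexample sets and tests whether the scores differ significantly from those of random CAVs; that step is omitted here. The need for `concept_acts` is the extra probing data mentioned as a drawback.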
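Representation reading and control, from the base-model bullets, can be sketched with a concept direction estimated as the normalised difference between mean hidden states on contrasting inputs, and with control implemented by adding a multiple of that direction to a hidden state. Difference-of-means is one simple choice and an assumption here; the hidden states are synthetic, with the concept planted along a hypothetical unit vector `hidden_concept`.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 16
# Synthetic hidden states (assumption): states for contrasting prompts, e.g. honest
# vs dishonest instructions, differ along an unknown unit direction `hidden_concept`.
hidden_concept = rng.normal(size=d)
hidden_concept /= np.linalg.norm(hidden_concept)
pos_states = rng.normal(size=(64, d)) + 1.5 * hidden_concept
neg_states = rng.normal(size=(64, d)) - 1.5 * hidden_concept

# Representation reading: estimate the concept direction from the contrast.
reading_vec = pos_states.mean(axis=0) - neg_states.mean(axis=0)
reading_vec /= np.linalg.norm(reading_vec)


def concept_score(h):
    """Projection of a hidden state onto the estimated concept direction."""
    return h @ reading_vec


def steer(h, alpha):
    """Representation control: push a hidden state along the concept direction."""
    return h + alpha * reading_vec


h = rng.normal(size=d)
print("cosine with true direction:", reading_vec @ hidden_concept)
print("score before/after steering:", concept_score(h), concept_score(steer(h, 3.0)))
```

Reading and control share the same vector: `concept_score` detects the concept, and `steer` shifts the state by exactly `alpha` along it. In a real model the shifted state would be written back into the forward pass, which requires hooks that are not shown here.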
Finally, the survey discusses the following open challenges in explainability research:
- Sources of Emergent Abilities:
- Model Perspective: What architectural choices lead to these emergent abilities? What is the minimum complexity needed to achieve the observed strong performance?
- Data Perspective: Which subsets of the training data are responsible for particular model predictions? How does data quality/quantity affect pre-training and fine-tuning?
- Differences in reasoning between prompted and fine-tuned models.
- LLMs often predict via shortcuts rather than reasoning: understanding what causes this, and improving OOD performance, is an important task.
- It has been found (see On Attention Redundancy: A Comprehensive Study) that many attention heads are redundant and could be pruned without a large impact on model performance. This could lead to model-compression techniques.
- Exploring temporal analysis: i.e. understanding how the training dynamics evolve and identifying phase transitions. For instance, Sudden Drops in the Loss: Syntax Acquisition, Phase Transitions, and Simplicity Bias in MLMs found a phase transition during pre-training in which the model acquires Syntactic Attention Structure (SAS), leading to a sharp drop in loss.