One of the great attractions of large language models (LLMs) is that they encode information about the real world. But the world is constantly changing, and an LLM’s information is only as fresh as the data it was trained on.
Training an LLM can take months, even when the task is parallelized across 1,000 servers, so AI researchers have sought alternative ways to update LLMs’ knowledge. One of these is to directly edit targeted layers of an LLM to improve its performance on a particular knowledge-based task. This is a task-specific rather than a general solution, but it takes hours to implement instead of months.
Existing techniques for direct layer editing generally require either manual selection of layers to be edited or a time-consuming procedure to determine the layers where editing will do the most good. Last week, at the 2024 meeting of the European Chapter of the Association for Computational Linguistics, we presented a new method for automatically selecting layers to be edited, which yields more-accurate updates than previous automated methods.
Compared to the prior method for manual layer selection, it also limits regression, or post-update backsliding on data that the model previously handled correctly. On some datasets, our method, which we call SaLEM (for salient-layers editing model), reduced regression by an order of magnitude, while offering equivalent accuracy on new data.
Identifying layers
We consider the case in which an LLM has been fine-tuned on a specific task, such as determining whether one input sentence logically entails or counts as evidence for or against another. In such cases, the model input is typically a pair of texts, and the output is a decision such as “entailed” or “supported”.
In the prior approach to manual layer selection, known as causal tracing, the first token of each training example is fed to the model, then the first and second tokens, then the first, second, and third, and so on. The process is then repeated with one of the model’s layers masked, and this two-step analysis must be carried out for every layer of the network, which is time-consuming.
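To see why this adds up, here is a rough count of forward passes for a single example. It is a sketch built on our own reading of the procedure above: we assume one unmasked sweep over every prefix plus one masked sweep per layer, and the model and example sizes are hypothetical.

```python
def causal_tracing_passes(num_tokens: int, num_layers: int) -> int:
    """Count forward passes for one example, under the stated assumptions."""
    # One pass per prefix: tokens [0:1], [0:2], ..., [0:num_tokens].
    passes_per_sweep = num_tokens
    # Assumption: one unmasked sweep, then one sweep with each layer masked.
    sweeps = 1 + num_layers
    return sweeps * passes_per_sweep


# Hypothetical sizes: 128-token examples and a 24-layer model.
per_example = causal_tracing_passes(num_tokens=128, num_layers=24)
print(per_example)           # 3200
print(per_example * 10_000)  # 32000000 forward passes for 10,000 examples
```

The count grows with both example length and model depth, whereas computing gradients needs roughly one forward and one backward pass per sample, however many layers the model has.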
In our approach, we instead prepare an “edit dataset” consisting of input-output pairs drawn from three groups: (1) the pass samples, for which the existing model outputs the correct answers; (2) the fail samples, for which the existing model outputs the wrong answers; and (3) the adapt samples, which are semantically equivalent to the fail samples but differently phrased.
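A minimal sketch of how such an edit dataset could be assembled for a sentence-pair task follows. The `predict` and `paraphrase` callables are hypothetical stand-ins for the fine-tuned model and for whatever process produces rephrased inputs; the toy lambdas exist only to make the example run.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple


@dataclass
class Example:
    premise: str
    hypothesis: str
    label: str


def build_edit_dataset(
    examples: List[Example],
    predict: Callable[[str, str], str],
    paraphrase: Callable[[str, str], List[Tuple[str, str]]],
) -> Dict[str, List[Example]]:
    """Sort examples into pass and fail groups; derive adapt samples from fails."""
    groups: Dict[str, List[Example]] = {"pass": [], "fail": [], "adapt": []}
    for ex in examples:
        if predict(ex.premise, ex.hypothesis) == ex.label:
            groups["pass"].append(ex)
        else:
            groups["fail"].append(ex)
            # Adapt samples rephrase a fail sample but keep its correct label.
            for p, h in paraphrase(ex.premise, ex.hypothesis):
                groups["adapt"].append(Example(p, h, ex.label))
    return groups


data = [
    Example("A dog runs.", "An animal moves.", "entailed"),
    Example("A dog sleeps.", "A dog runs.", "not entailed"),
]
groups = build_edit_dataset(
    data,
    predict=lambda p, h: "entailed",                           # toy model
    paraphrase=lambda p, h: [(p.replace("dog", "puppy"), h)],  # toy paraphraser
)
print({name: len(items) for name, items in groups.items()})
# {'pass': 1, 'fail': 1, 'adapt': 1}
```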
For each sample, we compute the loss between the existing model’s output and the target output, along with the corresponding gradients: the modifications of model weights that would make correct outputs more likely. We then average the gradients across each layer of the model and across all the samples. The layer with the highest average gradient, that is, the layer that requires the largest modification to accommodate new facts about the world, is the one we edit.
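The selection step can be sketched as follows, assuming per-sample gradients have already been computed (in practice by an autograd framework) and flattened per layer. Averaging absolute values is our assumption, since signed gradients could cancel; the layer names are hypothetical.

```python
from typing import Dict, List


def select_salient_layer(per_sample_grads: List[Dict[str, List[float]]]) -> str:
    """Return the layer with the largest gradient magnitude averaged over samples.

    per_sample_grads[i][layer] is the flattened gradient of sample i's loss
    with respect to that layer's weights.
    """
    totals: Dict[str, float] = {}
    for grads in per_sample_grads:
        for layer, g in grads.items():
            # Mean absolute gradient within the layer for this sample.
            totals[layer] = totals.get(layer, 0.0) + sum(abs(x) for x in g) / len(g)
    # Average those per-sample values across all samples.
    averages = {layer: t / len(per_sample_grads) for layer, t in totals.items()}
    return max(averages, key=averages.get)


grads = [
    {"layer_10": [0.1, -0.2], "layer_11": [0.5, -0.4]},
    {"layer_10": [0.0, 0.3], "layer_11": [-0.6, 0.2]},
]
print(select_salient_layer(grads))  # layer_11
```

Here `layer_10` averages 0.15 and `layer_11` averages 0.425, so `layer_11` would be the one edited.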
Layer editing
To edit the selected layer, we use the MEND method proposed by Stanford University researchers in 2022. With MEND, a second machine learning model, the editor model, is trained to take gradients as inputs and output parameter edits.
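The sketch below illustrates the basic idea with NumPy. For a linear layer y = W x, the gradient with respect to W for a single example is the outer product of the output gradient and the input, so it is rank 1; MEND exploits this factorization. `ToyEditor`, with its small residual linear maps and learning rate, is a hypothetical simplification for illustration, not MEND’s actual editor network.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out = 4, 3
W = rng.normal(size=(d_out, d_in))   # weights of the layer being edited
x = rng.normal(size=d_in)            # the layer's input for one example
target = rng.normal(size=d_out)      # desired layer output (toy stand-in)


def loss(weights: np.ndarray) -> float:
    """Squared-error loss 0.5 * ||weights @ x - target||^2."""
    r = weights @ x - target
    return 0.5 * float(r @ r)


# For y = W x, dL/dW = outer(dL/dy, x): a rank-1 matrix.
delta = W @ x - target
grad_W = np.outer(delta, x)


class ToyEditor:
    """Hypothetical editor: small residual linear maps on the two gradient factors."""

    def __init__(self, d_in: int, d_out: int, rng: np.random.Generator, lr: float = 0.01):
        self.A = 0.01 * rng.normal(size=(d_in, d_in))    # would be trained in practice
        self.B = 0.01 * rng.normal(size=(d_out, d_out))  # would be trained in practice
        self.lr = lr

    def __call__(self, x: np.ndarray, delta: np.ndarray) -> np.ndarray:
        x_t = x + self.A @ x              # transformed input factor
        delta_t = delta + self.B @ delta  # transformed output-gradient factor
        return -self.lr * np.outer(delta_t, x_t)  # the parameter edit


editor = ToyEditor(d_in, d_out, rng)
W_edited = W + editor(x, delta)
print(np.linalg.matrix_rank(grad_W), loss(W_edited) < loss(W))  # 1 True
```

Because the editor works on the two small factors rather than the full weight matrix, its size scales with the layer’s input and output dimensions instead of their product.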
Rather than raw gradients, however, the editor’s inputs are a low-rank approximation of the gradients, which reduces the dimension of the data by identifying the axes along which most of the variance occurs. This is something like teasing out the underlying causes of the larger gradients, which helps the editor generalize better. We also guard against overfitting by aggregating gradients in batches of 10 before computing their low-rank approximation.
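A minimal NumPy sketch of the batching and low-rank steps follows. Averaging (rather than summing) each batch and truncating a singular value decomposition (SVD) to a fixed `rank` are our assumptions; SVD is one standard way to find the directions that capture most of a matrix’s variance.

```python
import numpy as np


def batched_low_rank(grads, batch_size=10, rank=2):
    """Average gradients in batches, then keep each batch's top-`rank` SVD terms.

    Returns one (U, S, Vt) triple per batch; U @ np.diag(S) @ Vt is the best
    rank-`rank` approximation (in Frobenius norm) of that batch's mean gradient.
    """
    factors = []
    for start in range(0, len(grads), batch_size):
        batch_mean = np.mean(grads[start:start + batch_size], axis=0)
        U, S, Vt = np.linalg.svd(batch_mean, full_matrices=False)
        factors.append((U[:, :rank], S[:rank], Vt[:rank, :]))
    return factors


rng = np.random.default_rng(0)
# 25 toy per-sample gradients for a 6x8 weight matrix, each rank 1.
grads = [np.outer(rng.normal(size=6), rng.normal(size=8)) for _ in range(25)]
factors = batched_low_rank(grads, batch_size=10, rank=2)
print(len(factors), factors[0][0].shape, factors[0][2].shape)  # 3 (6, 2) (2, 8)
```

With 25 samples and batches of 10, the last batch holds only 5 gradients; each batch is still reduced to the same small set of factors.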
We use two training objectives to train the editor: one maximizes the likelihood of correct answers on inputs from the fail and adapt sets, and the other minimizes the divergence between the model’s pre-edit and post-edit outputs on inputs from the pass set. The second objective helps prevent regression.
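The two objectives can be combined into a single loss, sketched below on explicit probability vectors. Measuring divergence with Kullback-Leibler (KL) divergence and weighting the terms with a `locality_weight` are our assumptions; minimizing the negative log-likelihood term is equivalent to maximizing the likelihood of the correct answers.

```python
import math
from typing import List, Sequence


def nll(probs: Sequence[float], target: int) -> float:
    """Negative log-likelihood of the correct class."""
    return -math.log(probs[target])


def kl(p: Sequence[float], q: Sequence[float]) -> float:
    """KL(p || q): p is the pre-edit distribution, q the post-edit one."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def editor_loss(
    edit_probs: List[Sequence[float]],   # post-edit outputs on fail + adapt inputs
    edit_targets: List[int],             # correct class indices for those inputs
    pass_before: List[Sequence[float]],  # pre-edit outputs on pass inputs
    pass_after: List[Sequence[float]],   # post-edit outputs on pass inputs
    locality_weight: float = 1.0,        # hypothetical trade-off weight
) -> float:
    edit_term = sum(nll(p, t) for p, t in zip(edit_probs, edit_targets)) / len(edit_targets)
    locality_term = sum(kl(p, q) for p, q in zip(pass_before, pass_after)) / len(pass_before)
    return edit_term + locality_weight * locality_term


print(round(editor_loss([[0.2, 0.8]], [1], [[0.6, 0.4]], [[0.6, 0.4]]), 4))  # 0.2231
```

When the pass-set outputs are unchanged, the divergence term is zero and only the edit term remains; any shift in pass-set outputs adds a penalty.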
In the original MEND paper, the Stanford researchers used this approach to edit the top three layers of a fine-tuned LLM, a reasonable heuristic for trading off editing efficiency, correction of outputs, and prevention of regression. Because SaLEM identifies the one layer most implicated in the new model update, it can match MEND’s performance on new data. But because it modifies parameters in one layer rather than three, it reduces regression.
Experiments
We evaluated SaLEM on six datasets used to fine-tune LLMs on natural-language-processing tasks. Four of the datasets had to do with natural-language inference, one was a question-answering dataset, and one was a text-generation dataset for the standard LLM task of next-token prediction. For the question-answering and generation tasks, we compared SaLEM and the baselines on four different LLM architectures. We measured performance using two metrics: edit accuracy, or post-editing accuracy on the new data, and drawdown, which measures regression on the old data.
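A simple version of the two metrics is sketched below. Defining drawdown as the absolute drop in accuracy on old data is our assumption for illustration; the labels and predictions are toy values.

```python
from typing import Sequence


def accuracy(preds: Sequence[str], labels: Sequence[str]) -> float:
    return sum(p == y for p, y in zip(preds, labels)) / len(labels)


def edit_accuracy(post_preds_new: Sequence[str], labels_new: Sequence[str]) -> float:
    """Post-edit accuracy on the new data the edit targets."""
    return accuracy(post_preds_new, labels_new)


def drawdown(pre_preds_old: Sequence[str], post_preds_old: Sequence[str],
             labels_old: Sequence[str]) -> float:
    """Accuracy lost on old data after the edit (lower is better)."""
    return accuracy(pre_preds_old, labels_old) - accuracy(post_preds_old, labels_old)


labels_old = ["entailed", "not entailed", "entailed", "entailed"]
pre_old = ["entailed", "not entailed", "entailed", "entailed"]  # all correct before
post_old = ["entailed", "entailed", "entailed", "entailed"]     # one regression
print(edit_accuracy(["supported", "refuted"], ["supported", "refuted"]))  # 1.0
print(drawdown(pre_old, post_old, labels_old))  # 0.25
```

A method can score perfectly on edit accuracy while still showing drawdown, which is why the two are reported together.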
On the inference tasks, SaLEM matched the edit accuracy of the top performers but had significantly better drawdown — four and ten times better than the second-best performer on two of the datasets. On the other two tasks, SaLEM finished second on both measures to an approach called editable neural networks (ENN). But ENN requires two copies of an LLM to run simultaneously, which is resource intensive. Indeed, for two of the four LLM architectures we tested, we were unable to run ENN because of its computational demands.
In ongoing work, we are investigating (1) enriching the edit dataset with better fail samples and their semantic and counterfactual equivalents, (2) a better weight-update mechanism to inform the editor about the extent of updates for borderline instances, and (3) a method of performing edits without loading the full model into memory, which our current method requires.