Retrieval Augmented Language Model Pre-Training (REALM)
keywords: language modeling, question answering, passage retrieval, interpretable model, interpretable knowledge, T5, neural knowledge retriever
In this post, we will walk through the paper REALM: Retrieval-Augmented Language Model Pre-Training by Google Research.
TL;DR
- Language model pre-training captures a good amount of world knowledge for NLP tasks such as question answering. However, this knowledge is stored in the parameters of the neural network, so storing more knowledge requires an ever larger network, i.e. even more parameters.
- Key contribution of this paper: a solution to the above problem. The approach outlined here lets us build neural models with relatively few parameters that perform better than the previous SOTA on downstream tasks such as question answering.
- To capture knowledge in a modular and interpretable way, language model pre-training is augmented with a knowledge retriever that allows the model to retrieve and attend over documents from a large corpus such as Wikipedia. The retriever is used during pre-training, fine-tuning, and inference.
- The key idea of REALM is to train the retriever using a performance-based signal from unsupervised text: a retrieval that improves the language model's perplexity is helpful and should be rewarded, while an uninformative retrieval should be penalized.
We have all seen large language models with billions of parameters trained on huge corpora of data to achieve SOTA results. This paper tackles the question of how to build a lighter-weight neural model that achieves equal or better accuracy on downstream tasks. To that end, the authors present an approach built around a latent retrieval model that is responsible for deciding which knowledge the model learns from.
Background & Experiment Setup
Language model pre-training is used to learn useful representations of language from unlabeled text. The pre-trained model is then fine-tuned to perform the downstream task; the representations learned during pre-training are updated further in this fine-tuning stage. Masked Language Modeling (MLM) is the pre-training objective used in this paper. One key extension is that the authors use a different masking variant, salient span masking, which we discuss further below.
I assume readers are familiar with the task of open-domain question answering. The authors chose this task to probe what knowledge has been incorporated into the model's parameters. Question-answering systems typically use a two-stage architecture: retrieve relevant documents, then extract an answer from them. The key idea in this paper is to extend this two-stage approach to language model pre-training.
Key Contributions
- A new approach that augments language model pre-training with a textual knowledge retriever, together with a recipe for training such a retriever in an unsupervised manner, using the masked language model as a learning signal and back-propagating through a retrieval step that considers millions of documents. Essentially, a retrieval that improves the language model's perplexity is rewarded, and an uninformative retrieval is penalized.
- The effectiveness of REALM pre-training is demonstrated by fine-tuning on open-domain question answering and comparing against state-of-the-art models on three question-answering benchmarks. REALM outperforms all previous methods by 4-16% absolute accuracy.
Main Challenge
- Incorporating a large-scale neural retrieval module during pre-training poses a significant computational challenge, since the retriever must consider millions of documents at each pre-training step and must be learned through back-propagation.
- To address this, the retriever is structured so that the computation performed for each document can be cached and asynchronously updated, and selecting the best documents can be formulated as Maximum Inner Product Search (MIPS).
How this approach differs from previous work
- Prior work has added a discrete retrieval step to neural networks (e.g. Danqi Chen's DrQA), but did not apply it to language model pre-training and used non-learned retrievers.
- kNN-LM (Khandelwal et al.) uses only examples labeled for the target task and is not fine-tuned for downstream tasks. It also does not use any pre-training; it makes a single pass over the data to build a key-value store of training contexts and targets.
Evaluation Strategy
- REALM is pre-trained and then fine-tuned on open-domain question answering, and is evaluated on three benchmark datasets: NaturalQuestions-Open, WebQuestions, and CuratedTrec.
- It is compared with SOTA open-domain question-answering models such as T5.
- The Exact Match metric is used for evaluation.
Approach
- In both pre-training and fine-tuning, REALM learns a probability distribution p(y|x) over possible outputs y for an input x. For pre-training, x is a sentence from the pre-training corpus X with masked tokens or masked salient spans, and y consists of the masked-out tokens. For fine-tuning, x is a question and y is the answer.
- REALM decomposes p(y|x) into two steps: retrieve, then predict. For an input x, it retrieves possibly relevant documents z from a knowledge corpus Z, then conditions on both the input and the retrieved document to generate the output y. Here z is treated as a latent variable, and the overall likelihood of generating y is computed by marginalizing over all possible documents z:
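p(y | x) = Σ_{z ∈ Z} p(y | z, x) · p(z | x)

where p(z|x) is the retriever's distribution over documents and p(y|z,x) is the encoder's prediction given the retrieved document.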
Two-stage approach
The model architecture consists of two components: a Neural Knowledge Retriever, which models p(z|x), and a Knowledge Augmented Encoder, which models p(y|z,x).
Neural Knowledge Retriever
The retriever is defined using a dense inner product model:
The relevance score f(x,z) between x and z is defined as the inner product of their vector embeddings, and the retrieval distribution is the softmax over all relevance scores. A detailed diagram of how the knowledge retriever works is shown below.
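Following the paper, the two quantities can be written as:

f(x, z) = Embed_input(x)ᵀ Embed_doc(z)
p(z | x) = exp f(x, z) / Σ_{z′} exp f(x, z′)

where Embed_input and Embed_doc are BERT-style encoders that map the input and the document to dense vectors. Below is a minimal numpy sketch of the scoring step only; it assumes the embeddings have already been computed, and the array names and dimensions are illustrative rather than taken from the paper.

```python
import numpy as np

def retrieval_distribution(query_emb: np.ndarray, doc_embs: np.ndarray) -> np.ndarray:
    """Compute p(z|x) as a softmax over inner-product relevance scores f(x, z).

    query_emb: (d,)          embedding of the input x
    doc_embs:  (num_docs, d) embeddings of every document z in the corpus
    """
    scores = doc_embs @ query_emb           # relevance score f(x, z) for each document
    scores -= scores.max()                  # subtract the max for numerical stability
    probs = np.exp(scores)
    return probs / probs.sum()

# toy usage with random embeddings
rng = np.random.default_rng(0)
p_z_given_x = retrieval_distribution(rng.normal(size=128), rng.normal(size=(1000, 128)))
print(p_z_given_x.shape, round(float(p_z_given_x.sum()), 6))  # (1000,) 1.0
```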
Knowledge Augmented Encoder
Given an input x and a retrieved document z, the Knowledge Augmented Encoder defines p(y|z,x). The input x and the retrieved document z are joined into a single sequence and fed into a second BERT-style Transformer, distinct from the one used in the retriever (in the retriever, the [CLS] token representation is used as the pooled embedding of a sequence). The key idea is to allow rich cross attention between the input x and the document z before predicting y.
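In the paper's notation the joining is done BERT-style, i.e. [CLS] x [SEP] z_body [SEP]. A tiny sketch of that joining step, using plain token lists instead of a real tokenizer (the example tokens are made up):

```python
def join_for_encoder(x_tokens, z_tokens):
    """Join the input x and a retrieved document z into one BERT-style sequence.

    The Knowledge Augmented Encoder runs self-attention over this joined sequence,
    which is what gives cross attention between x and z.
    """
    return ["[CLS]"] + list(x_tokens) + ["[SEP]"] + list(z_tokens) + ["[SEP]"]

print(join_for_encoder(["the", "[MASK]", "is", "the", "currency", "of", "the", "uk"],
                       ["the", "pound", "is", "the", "currency", "of", "the", "united", "kingdom"]))
```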
Just to refresh, the figure below shows what cross attention does. In an encoder-decoder setting, at each decoding timestep the representation of every token decoded so far is recomputed using cross attention: each decoded token acts as a query, the encoder's last-layer representations act as keys and values, and the attention output becomes the token's updated decoder-side representation.
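Here is a minimal numpy sketch of single-head scaled dot-product cross attention as just described: the decoder-side tokens are the queries, and the encoder's last-layer outputs are the keys and values. The dimensions are illustrative only.

```python
import numpy as np

def cross_attention(queries: np.ndarray, keys: np.ndarray, values: np.ndarray) -> np.ndarray:
    """Single-head scaled dot-product cross attention.

    queries: (n_q, d)   decoder-side token representations (tokens decoded so far)
    keys:    (n_kv, d)  encoder last-layer outputs
    values:  (n_kv, d)  encoder last-layer outputs
    returns: (n_q, d)   recomputed decoder-side representations
    """
    d = queries.shape[-1]
    scores = queries @ keys.T / np.sqrt(d)           # (n_q, n_kv) attention logits
    scores -= scores.max(axis=-1, keepdims=True)     # stabilize the softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)   # attention weights per query token
    return weights @ values

# three decoded tokens attending over ten encoder outputs
out = cross_attention(np.random.randn(3, 64), np.random.randn(10, 64), np.random.randn(10, 64))
print(out.shape)  # (3, 64)
```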
For the masked-language-model pre-training task, the model has to predict the original value of each masked token in the input x, using the same MLM objective as presented in the BERT paper.
For the open-domain question answering fine-tuning task, we want the model to produce the answer y. The assumption is that the answer y can be found as a contiguous span of tokens in some document z. Let S(z, y) be the set of spans matching y in z. Then p(y|z,x) can be defined as:
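p(y | z, x) ∝ Σ_{s ∈ S(z,y)} exp( MLP( [ h_START(s) ; h_END(s) ] ) )

where h_START(s) and h_END(s) are the Transformer output vectors corresponding to the start and end tokens of span s in the joined sequence of x and z, and MLP is a feed-forward network.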
Injecting inductive biases into pre-training
The authors present a few strategies that further guide the model towards more meaningful retrievals:
Salient span masking
To make the model learn world knowledge when predicting missing tokens during MLM, the authors mask out salient spans corresponding to named entities. They use a BERT-based named-entity tagger to identify entity spans, mask them out, and ask the model during REALM pre-training to predict the masked entities.
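As a rough illustration, here is a minimal sketch of the masking step, assuming an NER tagger has already produced token-level entity spans; the function name and span format are hypothetical, not the paper's code.

```python
def mask_salient_spans(tokens, entity_spans, mask_token="[MASK]"):
    """Replace every token inside a tagged entity span with the mask token.

    tokens:       list of tokens, e.g. ["einstein", "was", "born", "in", "ulm"]
    entity_spans: list of (start, end) token indices tagged as entities, end exclusive
    """
    masked = list(tokens)
    for start, end in entity_spans:
        for i in range(start, end):
            masked[i] = mask_token
    return masked

print(mask_salient_spans(["einstein", "was", "born", "in", "ulm"], [(0, 1), (4, 5)]))
# ['[MASK]', 'was', 'born', 'in', '[MASK]']
```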
Null document
Sometimes no document needs to be retrieved to predict the masked tokens; this case is modeled by adding an empty null document to the top-k retrieved documents.
Prohibiting trivial retrievals
Here the authors address the issue that arises when the pre-training corpus and the knowledge corpus are the same. If the masked sentence x comes from document z, the knowledge-augmented encoder can trivially predict y by looking at the unmasked version of x in document z. This produces a large positive gradient; if it happens too often, the knowledge retriever ends up learning to look for exact string matches between the input x and documents z. Such trivial candidates are therefore excluded during pre-training.
Initialization — Warm start Embeddings
At the beginning of training, if the retriever does not have good embeddings for the input x and documents z, the retrieved documents z will likely be unrelated to x. The knowledge-augmented encoder then learns to ignore the retrieved documents; once this happens, the knowledge retriever never receives a meaningful gradient and cannot improve, creating a vicious cycle. To avoid this cold-start problem, the authors warm-start the embeddings with a BERT model trained on a simple objective, the Inverse Cloze Task (ICT), where given a sentence the model is trained to retrieve the document/context it came from.
Training
The training objective for both pre-training and fine-tuning is to maximize the log-likelihood log p(y|x) of the correct output y. The Neural Knowledge Retriever (parameters θ) and the Knowledge Augmented Encoder (parameters ϕ) are both differentiable neural networks, so we can compute gradients, back-propagate the error, and update the model parameters using stochastic gradient descent.
The key challenge is that computing the marginal probability p(y|x) involves a summation over all documents z in the knowledge corpus Z. The authors approximate this by summing only over the top-k documents with highest probability under p(z|x), and they use Maximum Inner Product Search (MIPS) algorithms to find the approximate top-k documents by the relevance score f(x,z), the inner product between the query and document embeddings.
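A minimal sketch of this approximation, using a brute-force inner-product top-k in place of a real MIPS index and a placeholder callable for the encoder's p(y|z,x); everything here is illustrative, not the paper's implementation.

```python
import numpy as np

def approx_marginal_likelihood(query_emb, doc_embs, p_y_given_zx, k=5):
    """Approximate p(y|x) = sum_z p(y|z,x) p(z|x) using only the top-k documents.

    query_emb:    (d,) embedding of the input x
    doc_embs:     (num_docs, d) document embeddings (the contents of the MIPS index)
    p_y_given_zx: callable mapping a document index to p(y|z,x) from the encoder
    """
    scores = doc_embs @ query_emb                 # relevance scores f(x, z)
    topk = np.argsort(-scores)[:k]                # brute-force stand-in for a MIPS lookup
    topk_scores = scores[topk] - scores[topk].max()
    p_z = np.exp(topk_scores)
    p_z /= p_z.sum()                              # p(z|x) renormalized over the top-k only
    return float(sum(p_z[i] * p_y_given_zx(int(z)) for i, z in enumerate(topk)))

# toy usage: a dummy encoder that happens to prefer low document ids
rng = np.random.default_rng(0)
p_y_given_x = approx_marginal_likelihood(rng.normal(size=64), rng.normal(size=(500, 64)),
                                         p_y_given_zx=lambda z: 1.0 / (1 + z))
print(p_y_given_x)
```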
To employ MIPS, a search index is built over the document embeddings, as shown in the Neural Knowledge Retriever figure above. One issue is that this search index goes stale every time the retriever parameters θ are updated, i.e. after every gradient step.
Addressing the stale MIPS search index issue with Asynchronous refresh
Asynchronous re-embedding and re-indexing
The solution the authors employ is to refresh the search index by asynchronously re-embedding and re-indexing all documents with the latest model parameters every few hundred steps. Even so, the index is slightly stale between refreshes, but the authors show empirically that this procedure results in stable optimization, provided the index is refreshed at a sufficiently frequent rate.
Two jobs: trainer and index builder
The figure below shows REALM pre-training with asynchronous MIPS refreshes. Two jobs run at any given time: a primary trainer job, which performs gradient updates on the parameters, and a secondary index-builder job, which embeds and indexes the documents. As the figure shows, the trainer sends the index builder a snapshot of its parameters and then continues training, while the index builder uses the latest parameter snapshot to construct a new index in the background. As soon as the new index is built, it is sent back to the trainer.
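Below is a schematic, single-process sketch of that loop; in reality the index builder is a separate job running in parallel, and all names and numbers here are hypothetical stand-ins.

```python
import numpy as np

def refresh_index(doc_features, retriever_params):
    """Stand-in for the index builder: re-embed every document with a parameter snapshot."""
    return doc_features @ retriever_params            # (num_docs, d) document embeddings

def train_with_async_refresh(doc_features, num_steps=1000, refresh_every=500):
    """Caricature of the trainer / index-builder interaction.

    The index is only rebuilt every `refresh_every` steps, so retrieval always
    runs against a slightly stale index, as described above.
    """
    rng = np.random.default_rng(0)
    retriever_params = rng.normal(size=(doc_features.shape[1], 32)) * 0.1
    index = refresh_index(doc_features, retriever_params)
    for step in range(1, num_steps + 1):
        query = rng.normal(size=32)                   # stand-in for an embedded masked sentence
        top_doc = int(np.argmax(index @ query))       # MIPS lookup against the stale index
        # (the retrieved document `top_doc` would feed the encoder here)
        retriever_params *= 0.999                     # stand-in for a gradient update
        if step % refresh_every == 0:                 # snapshot params, rebuild, swap in
            index = refresh_index(doc_features, retriever_params)
    return retriever_params

train_with_async_refresh(np.random.default_rng(1).normal(size=(200, 64)))
```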
Choice on refresh
The authors use asynchronous refresh only for pre-training, although it could have been used for both pre-training and fine-tuning. For fine-tuning, the MIPS index is built once with the pre-trained retriever and the document embeddings are not updated afterwards.
One interesting experiment would be to see how the MIPS index refresh rate affects model performance, as well as the impact of using multiple sources for the knowledge corpus.
What is Neural Knowledge Retriever Learning?
The authors explain how the training objective encourages meaningful retrievals: relevant retrievals are rewarded and irrelevant ones are penalized.
For a given query x and document z, the retriever assigns the relevance score f(x,z) to document z. The paper shows how a single step of gradient descent during REALM pre-training alters this score, by examining the gradient of the log-likelihood with respect to the parameters θ of the Neural Knowledge Retriever:
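∇ log p(y | x) = Σ_{z ∈ Z} r(z) ∇ f(x, z),   where   r(z) = [ p(y | z, x) / p(y | x) - 1 ] · p(z | x)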
- For each document z, the gradient encourages the retriever to change the score f(x,z) by r(z): increasing it if r(z) is positive and decreasing it if r(z) is negative.
- r(z) is positive iff p(y|z,x) > p(y|x), i.e. the probability of predicting the correct output given document z is greater than the expected probability of the correct output when sampling a document from p(z|x). Thus, document z receives a positive update when it performs better than expected. The detailed derivation of the gradient is provided in the appendix of the paper.
Experiments & Results
The authors compare REALM with retrieval-based and generation-based open-QA systems. Retrieval-based open-QA systems first retrieve relevant documents and then use a reading-comprehension model to extract the answer from them. Generation-based open-QA systems treat the problem as sequence prediction: encode the question, then decode the answer token by token based on the encoding.
The authors reuse all hyperparameters from Lee et al. (2019). Refer to the paper for infrastructure-level training details such as the number of TPUs used, batch size, etc.
- Table 1 shows the accuracy of the different approaches on the three open-QA datasets, along with the number of parameters for each model.
- As the table shows, generation-based open-QA systems built on T5 are powerful and their performance improves with model size. In contrast, REALM (39.2, 40.4) outperforms the T5-11B model (34.5) while being roughly 30 times smaller.
- The most direct comparison is between REALM and ORQA, whose fine-tuning setup, hyperparameters, and training data are identical; the improvement of REALM over ORQA is therefore due to better pre-training. The table also shows that REALM works in both the single-corpus setting (pre-training corpus and knowledge corpus are the same) and the separate-corpus setting.
Ablation Study
- The authors ablate critical components of REALM and present the impact. To understand whether the REALM pre-training (MLM) task improves the retriever or the encoder, they reset the parameters of either the retriever or the encoder to their baseline state (as presented in the ORQA paper) before pre-training, and feed the result into fine-tuning.
- Resetting both the retriever and the encoder reduces the system to the ORQA baseline. The conclusion of the ablation study is that both components benefit from REALM pre-training, and the best performance is achieved when both are pre-trained with REALM and used together.
Adapting to new Knowledge
An explicit retrieval system allows the model to adapt to new world knowledge simply by modifying the corpus documents. To demonstrate this, the authors replace the knowledge corpus with a more recent Wikipedia snapshot after pre-training. When the input query is about a fact on which the two corpora disagree, REALM can change its prediction to reflect the updated information. However, even with the explicit retrieval mechanism, the knowledge-augmented encoder ends up memorizing some world knowledge, so the predictions for some inputs are not updated with the new corpus.