Retrieval Augmented Language Model Pre-Training (REALM)

Kartik Perisetla
11 min read, Feb 15, 2021


keywords: language modeling, question answering, passage retrieval, interpretable model, interpretable knowledge, T5, neural knowledge retriever

In this post, we will walk through the paper "REALM: Retrieval-Augmented Language Model Pre-Training" by Google Research.

TL;DR

  • Language model pre-training captures a good amount of world knowledge that is useful for NLP tasks such as question answering. However, this knowledge is stored implicitly in the parameters of a neural network, so storing more knowledge requires an even larger network, i.e. even more parameters.
  • Key contribution of this paper: a solution to the problem above. The approach outlined in this paper lets us build neural models with relatively few parameters that outperform the previous SOTA on downstream tasks such as question answering.
  • To capture knowledge in a more modular and interpretable way, language model pre-training is augmented with a knowledge retriever that lets the model retrieve and attend over documents from a large corpus such as Wikipedia. The retriever is used during pre-training, fine-tuning and inference.
  • The key idea of REALM is to train the retriever using a performance-based signal from unsupervised text: a retrieval that improves the language model's perplexity is helpful and should be rewarded, while an uninformative retrieval should be penalized.

We have all seen large language models with billions of parameters trained on huge corpora to achieve SOTA results. This paper instead tackles the question of how to build a lighter-weight neural model that achieves equal or better accuracy on downstream tasks. To do so, the authors treat the retrieved document as a latent variable: a learned retriever decides which knowledge the model draws on for each input.

Background & Experiment Setup

Language model pre-training is used to learn useful representations of language from unlabeled text. The pre-trained model is then fine-tuned to perform the downstream task; in this stage the model parameters are updated on top of the representations learned during pre-training. This paper uses the Masked Language Model (MLM) as its pre-training variant. One key extension is that the authors use a different kind of masking, salient span masking, which we discuss further below.
I assume readers are familiar with the task of open-domain question answering. The authors chose this task to probe what knowledge has been incorporated into the model. Typical question-answering systems use a two-stage approach: retrieve relevant documents, then extract an answer from them. The key idea in this paper is to extend this two-stage approach to language model pre-training.

Key Contributions

  • A new approach that augments language model pre-training with a textual knowledge retriever, together with a way to train this retriever in an unsupervised manner, using the masked language modeling objective as the learning signal and back-propagating through a retrieval step that considers millions of documents. In short, a retrieval that improves the language model's perplexity is rewarded, and an uninformative retrieval is penalized.
  • The effectiveness of REALM pre-training is demonstrated by fine-tuning on open-domain question answering and comparing against state-of-the-art models on 3 question-answering benchmarks. REALM outperforms all previous methods by 4–16% absolute accuracy.

Main Challenge

  • Incorporating a large-scale neural retrieval module during pre-training poses a significant computational challenge, since the retriever must consider millions of documents at each pre-training step and gradients must be back-propagated through its decisions.
  • To address this, the retriever is structured so that the computation performed for each document can be cached and asynchronously updated, and selecting the best documents can be formulated as Maximum Inner Product Search (MIPS).

How different is this approach from previous work

  • Prior work has added a discrete retrieval step to neural networks (e.g. Danqi Chen's DrQA), but did not apply it to LM pre-training and used non-learned retrievers.
  • kNN-LM (Khandelwal et al.) retrieves only over examples labeled for the target task and is not fine-tuned for downstream tasks. It also does not learn its retrieval during pre-training; instead, it makes a single pass over the data to build a key-value store of training contexts and targets.

Evaluation Strategy

  • REALM is pre-trained, then fine-tuned on open-domain question answering, and evaluated on 3 benchmark datasets: NaturalQuestions-Open, WebQuestions and CuratedTrec.
  • It is compared against SOTA open-domain question-answering models such as T5.
  • Exact Match is used as the evaluation metric.

Approach

  • In both pre-training and fine-tuning, REALM learns a distribution p(y|x) over possible outputs y given an input x. For pre-training, x is a sentence from the pre-training corpus X with some tokens or salient spans masked out, and y is the masked content. For fine-tuning, x is a question and y is the answer.
  • REALM decomposes p(y|x) into two steps: retrieve, then predict. Given input x, it first retrieves relevant documents z from a knowledge corpus Z, then conditions on both the input and the retrieved document to generate the output y. The document z is treated as a latent variable, and the overall likelihood of generating y is obtained by marginalizing over all possible documents z:

$$p(y \mid x) = \sum_{z \in \mathcal{Z}} p(y \mid z, x)\, p(z \mid x)$$
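To make the marginalization concrete, here is a minimal sketch in plain Python. It assumes we already have relevance scores f(x, z) for a handful of candidate documents (standing in for the whole corpus Z) and the reader's probabilities p(y|z,x) for those same documents; the function names are hypothetical, and the numbers are toy values, not results from the paper.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def marginal_likelihood(retrieval_scores, reader_probs):
    """p(y|x) = sum_z p(y|z,x) * p(z|x).

    retrieval_scores: relevance scores f(x, z), one per candidate document z
    reader_probs:     p(y|z,x) for the same documents, in the same order
    """
    p_z_given_x = softmax(retrieval_scores)
    return sum(p_yz * p_z for p_yz, p_z in zip(reader_probs, p_z_given_x))

# Toy example: three candidate documents; only the first one really helps predict y.
p_y = marginal_likelihood([2.0, 0.5, 0.1], [0.9, 0.1, 0.05])
```

Because p(z|x) sums to one, p(y|x) is a weighted average of the per-document reader probabilities, so a document only contributes much when the retriever also assigns it high probability.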

Two-stage approach

The model architecture consists of two components: a Neural Knowledge Retriever, which models p(z|x), and a Knowledge Augmented Encoder, which models p(y|z,x).

Neural Knowledge Retriever

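The retriever is defined using a dense inner product model:

$$p(z \mid x) = \frac{\exp f(x, z)}{\sum_{z'} \exp f(x, z')}, \qquad f(x, z) = \mathrm{Embed}_{\mathrm{input}}(x)^{\top}\, \mathrm{Embed}_{\mathrm{doc}}(z)$$

The relevance score f(x,z) between x and z is the inner product of their vector embeddings, and the retrieval distribution is the softmax over all relevance scores. The detailed diagram of how the knowledge retriever works is shown below: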

[Figure: Neural Knowledge Retriever]
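The retrieval distribution can be sketched in a few lines. This assumes the query embedding Embed_input(x) and the document embeddings Embed_doc(z) have already been computed; in REALM they come from BERT-based encoders, whereas here they are hypothetical hand-written 3-dimensional vectors, and `retrieval_distribution` is an illustrative name, not an API from the paper.

```python
import math

def dot(u, v):
    """Inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))

def retrieval_distribution(query_emb, doc_embs):
    """p(z|x) = softmax over z of f(x, z), with f(x, z) = <Embed_input(x), Embed_doc(z)>."""
    scores = [dot(query_emb, d) for d in doc_embs]
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical toy embeddings standing in for the BERT-based encoders.
query = [1.0, 0.0, 0.5]
docs = [[0.9, 0.1, 0.4], [0.0, 1.0, 0.0], [-0.5, 0.2, 0.1]]
p_z = retrieval_distribution(query, docs)
```

The first document points in nearly the same direction as the query, so it receives the largest inner product and therefore the highest retrieval probability.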

Knowledge Augmented Encoder

Given an input x and a retrieved document z, the Knowledge Augmented Encoder defines p(y|z,x). The input x and the retrieved document z are joined into a single sequence and fed into a separate BERT model (distinct from the one used by the retriever), whose output token representations are used for prediction: at the masked positions for MLM, and at the span boundaries for question answering. The key idea is to allow cross attention between the input x and the document z before predicting y.
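A minimal sketch of the "join into a single sequence" step, assuming the standard BERT sentence-pair layout ([CLS] x [SEP] z [SEP] with segment ids 0 and 1). The exact input format REALM uses (for example, how the document title is handled) is not shown here, and `join_bert` is a hypothetical helper operating on already-tokenized toy inputs.

```python
def join_bert(x_tokens, z_tokens):
    """Concatenate input x and document z into one BERT-style sequence.

    Returns (tokens, segment_ids): segment 0 covers "[CLS] x [SEP]",
    segment 1 covers "z [SEP]", following the usual BERT sentence-pair layout.
    """
    tokens = ["[CLS]"] + x_tokens + ["[SEP]"] + z_tokens + ["[SEP]"]
    segment_ids = [0] * (len(x_tokens) + 2) + [1] * (len(z_tokens) + 1)
    return tokens, segment_ids

tokens, segments = join_bert(["the", "[MASK]", "is", "in", "paris"],
                             ["eiffel", "tower", "paris", "france"])
```

Because x and z end up in one sequence, every layer of self-attention can mix information between the question tokens and the document tokens, which is what lets the encoder "read" the retrieved evidence.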

As a refresher, the figure below shows what cross attention does. In an encoder-decoder setting, at each decoding timestep the representation of every token decoded so far is recomputed with cross attention: each decoded token's representation serves as the query, and the encoder's last-layer representations serve as the keys and values. REALM's encoder has no separate decoder; because x and z share one input sequence, BERT's self-attention lets tokens of x attend to tokens of z and vice versa.
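The query/key/value roles described above can be sketched as single-head scaled dot-product attention in plain Python. This is a simplified illustration, assuming no learned projection matrices and a single head (real Transformers use both), and `cross_attention` is a hypothetical name.

```python
import math

def softmax(xs):
    m = max(xs)
    e = [math.exp(v - m) for v in xs]
    s = sum(e)
    return [v / s for v in e]

def cross_attention(decoder_states, encoder_states):
    """Single-head scaled dot-product cross attention without learned projections.

    Each decoder state is used as a query; the encoder's last-layer states act as
    both keys and values. Returns one recomputed vector per decoder token.
    """
    d = len(encoder_states[0])
    outputs = []
    for q in decoder_states:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in encoder_states]
        weights = softmax(scores)
        # Weighted average of the value vectors (here identical to the keys).
        outputs.append([sum(w * v[i] for w, v in zip(weights, encoder_states))
                        for i in range(d)])
    return outputs

out = cross_attention([[1.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]])
```

The query [1, 0] is more similar to the first encoder state, so the recomputed representation leans towards it.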

For the MLM pre-training task, the model has to predict the original value of each masked token in input x, using the same MLM objective as in the BERT paper.
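As a sketch of the MLM term for a single retrieved document, assuming (as in BERT) that masked tokens are predicted independently, log p(y|z,x) is the sum of log-probabilities of the original tokens at the masked positions. The predicted distributions below are hypothetical toy values, and `mlm_log_likelihood` is an illustrative name.

```python
import math

def mlm_log_likelihood(predicted, targets):
    """log p(y|z,x) = sum over masked positions j of log p(y_j|z,x).

    predicted: {position: {token: probability}} from the knowledge augmented encoder
    targets:   {position: original token} for every masked position
    """
    return sum(math.log(predicted[j][token]) for j, token in targets.items())

# Toy distributions at two masked positions (1 and 3).
predicted = {1: {"eiffel": 0.8, "big": 0.2}, 3: {"paris": 0.5, "london": 0.5}}
targets = {1: "eiffel", 3: "paris"}
log_p = mlm_log_likelihood(predicted, targets)
```

Here log_p equals log(0.8 × 0.5) = log(0.4).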

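For the open-domain question-answering fine-tuning task, we want the model to produce the answer y. The assumption is that y appears as a contiguous sequence of tokens in some document z. Let S(z, y) be the set of spans matching y in z. Then p(y|z,x) is defined as:

$$p(y \mid z, x) \propto \sum_{s \in S(z, y)} \exp\left(\mathrm{MLP}\left(\left[h_{\mathrm{START}(s)};\, h_{\mathrm{END}(s)}\right]\right)\right)$$

where h_START(s) and h_END(s) are the encoder's output representations of the start and end tokens of span s; they are concatenated and fed to a feed-forward network (MLP) to score the span.
[Figure: Knowledge Augmented Encoder]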
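A minimal sketch of this span-based likelihood, under stated assumptions: `span_score(start, end)` is a hypothetical stand-in for MLP([h_START(s); h_END(s)]), and the proportionality is turned into a probability by normalizing over all spans up to a maximum length, which is one plausible choice rather than the paper's exact normalizer.

```python
import math

def matching_spans(doc_tokens, answer_tokens):
    """S(z, y): all (start, end) spans (inclusive) in the document equal to the answer."""
    n = len(answer_tokens)
    return [(i, i + n - 1) for i in range(len(doc_tokens) - n + 1)
            if doc_tokens[i:i + n] == answer_tokens]

def answer_prob(doc_tokens, answer_tokens, span_score, max_span_len=10):
    """p(y|z,x): mass of spans matching y, normalized over all candidate spans.

    span_score(start, end) is a hypothetical stand-in for MLP([h_start; h_end]).
    """
    all_spans = [(i, j) for i in range(len(doc_tokens))
                 for j in range(i, min(i + max_span_len, len(doc_tokens)))]
    total = sum(math.exp(span_score(i, j)) for i, j in all_spans)
    matched = sum(math.exp(span_score(i, j))
                  for i, j in matching_spans(doc_tokens, answer_tokens))
    return matched / total

doc = ["paris", "is", "in", "france", "paris"]
p = answer_prob(doc, ["paris"], span_score=lambda i, j: 0.0)
```

Summing over every matching span means the model does not need to know which occurrence of the answer is the "right" one; any mention of y in z contributes.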

Injecting inductive biases into pre-training

The authors present a few strategies that further guide the model towards more meaningful retrievals:

Salient span masking

To make the model learn world knowledge when predicting missing tokens during MLM, the authors mask out salient spans corresponding to named entities. They use a BERT-based named-entity tagger to find entity spans, mask them out, and ask the model to predict the masked entities during REALM pre-training.
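A minimal sketch of the masking step, assuming entity spans have already been produced by an external tagger (the tagger itself is not shown). Masking exactly one span chosen at random is an assumption made here for illustration, and `salient_span_mask` is a hypothetical helper.

```python
import random

def salient_span_mask(tokens, entity_spans, rng, mask_token="[MASK]"):
    """Replace one tagged entity span (start, end inclusive) with mask tokens.

    entity_spans: spans assumed to come from a named-entity tagger.
    Returns the masked token list and the original tokens the model must predict.
    """
    start, end = rng.choice(entity_spans)
    masked = tokens[:start] + [mask_token] * (end - start + 1) + tokens[end + 1:]
    return masked, tokens[start:end + 1]

tokens = ["barack", "obama", "was", "born", "in", "hawaii"]
masked, targets = salient_span_mask(tokens, [(0, 1)], random.Random(0))
```

Unlike random token masking, which often hides easy tokens such as "was", this forces the model to recall an entity, which is the kind of prediction that retrieval can actually help with.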

Null document

Consider the case where no retrieval is needed to predict the masked tokens. This is modeled by adding an empty null document to the top-k retrieved documents.
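A tiny sketch of the candidate list with the null document appended. Representing the null document as an empty string, and scoring and marginalizing over it like any other candidate, are assumptions made here for illustration; `NULL_DOCUMENT` and `with_null_document` are hypothetical names.

```python
# Hypothetical representation of the empty null document.
NULL_DOCUMENT = ""

def with_null_document(top_k_docs):
    """Return the top-k retrieved documents plus the empty null document.

    The null document is then treated like any other candidate in the
    marginalization, so it can absorb probability mass when no retrieval helps.
    """
    return list(top_k_docs) + [NULL_DOCUMENT]

candidates = with_null_document(["doc about Paris", "doc about France"])
```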

Prohibiting trivial retrievals

This addresses the case where the pre-training corpus and the knowledge corpus are the same. If the masked sentence x comes from document z, the knowledge augmented encoder can trivially predict y by looking at the unmasked version of x in z. This produces a large positive gradient for retrieving z, and if it happens too often, the knowledge retriever ends up learning to look for exact string matches between x and z, which teaches it nothing about useful knowledge. Such trivial candidates are therefore excluded during pre-training.
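A sketch of the exclusion step, assuming each candidate document and the source of the masked sentence are identified by a document id; `drop_trivial_candidates` and the ids are hypothetical.

```python
def drop_trivial_candidates(candidate_ids, source_doc_id):
    """Remove the document the masked sentence x was taken from.

    candidate_ids: ids of documents returned by the retriever (e.g. top-k MIPS hits)
    source_doc_id: id of the knowledge-corpus document that contains x
    """
    return [doc_id for doc_id in candidate_ids if doc_id != source_doc_id]

kept = drop_trivial_candidates([17, 42, 3, 8], source_doc_id=42)
```

Filtering is applied only to the source document, so other documents that happen to mention the same fact remain valid (and useful) retrievals.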

Initialization — Warm start Embeddings

At the beginning of training, if the retriever does not yet have good embeddings for the input x and documents z, the retrieved documents will likely be unrelated to x. This causes the knowledge augmented encoder to learn to ignore the retrieved documents. Once that happens, the knowledge retriever never receives a meaningful gradient and cannot improve, creating a vicious cycle. To avoid this cold-start problem, the authors warm-start these embeddings with BERT trained on a simple objective, the Inverse Cloze Task (ICT): given a sentence, the model is trained to retrieve the context (document) it came from.
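A sketch of how one ICT training pair could be built from a passage: one sentence becomes the pseudo-query and the remaining sentences form the context it should retrieve. Always removing the query sentence from its context is a simplification assumed here, and `make_ict_example` is a hypothetical helper; the retriever training itself is not shown.

```python
def make_ict_example(sentences, query_index):
    """Inverse Cloze Task pair: (pseudo-query sentence, context it came from)."""
    query = sentences[query_index]
    context = sentences[:query_index] + sentences[query_index + 1:]
    return query, context

passage = ["Zebras have stripes.", "They live in Africa.", "They eat grass."]
query, context = make_ict_example(passage, 1)
```

Since the query never appears verbatim in its context, the retriever has to learn semantic relatedness rather than exact string overlap, which gives REALM a useful starting point.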

Training

The training objective for both pre-training and fine-tuning is to maximize the log-likelihood log p(y|x) of the correct output y. Since the Neural Knowledge Retriever (θ) and the Knowledge Augmented Encoder (ϕ) are both differentiable neural networks, we can compute gradients of this objective, back-propagate the errors and update the model parameters with stochastic gradient descent.

The key challenge is that computing the marginal probability p(y|x) involves a summation over all documents z in the knowledge corpus Z. The authors approximate this by summing only over the top-k documents with the highest probability under p(z|x). They use Maximum Inner Product Search (MIPS) algorithms to find the approximate top-k documents by relevance score f(x,z), the inner product between query and document embeddings; since the softmax is monotonic, the highest-scoring documents are also the most probable ones.
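The sketch below uses exact brute-force top-k search by inner product as a stand-in for an approximate MIPS index; a real system over millions of documents would use a dedicated MIPS library instead. The embeddings are hypothetical toy vectors and `top_k_inner_product` is an illustrative name.

```python
import heapq

def top_k_inner_product(query_emb, doc_embs, k):
    """Indices of the k documents with the largest f(x, z) = <query, doc>.

    Exact brute-force search, standing in for an approximate MIPS index.
    """
    scored = ((sum(a * b for a, b in zip(query_emb, d)), i) for i, d in enumerate(doc_embs))
    return [i for _, i in heapq.nlargest(k, scored)]

docs = [[0.1, 0.0], [0.9, 0.0], [0.5, 0.5], [-1.0, 0.0]]
top = top_k_inner_product([1.0, 0.0], docs, k=2)
```

Only these k documents are then passed to the marginalization over z, which keeps each training step tractable.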

To employ MIPS, a search index is built over the document embeddings, as shown in the figure above for the Neural Knowledge Retriever (θ). One issue is that the index goes stale as soon as the retriever's parameters are updated, which happens after every training step.

Addressing the stale MIPS search index issue with Asynchronous refresh

Asynchronous re-embedding and re-indexing

The solution the authors employ is to refresh the index by asynchronously re-embedding and re-indexing all documents with the latest model parameters every few hundred training steps. Even so, the index is slightly stale between refreshes, but the authors show empirically that this procedure results in stable optimization, provided the index is refreshed often enough.

Two jobs: trainer and index builder

The figure below shows REALM pre-training with asynchronous MIPS refreshes. Two jobs run in parallel: a primary trainer job, which performs gradient updates on the parameters, and a secondary index builder job, which embeds and indexes the documents. The trainer sends the index builder a snapshot of its parameters and then continues training, while the index builder uses this snapshot to construct a new index in the background. As soon as the new index is built, it is sent back to the trainer.
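A toy, single-process sketch of the trainer/index-builder interplay, assuming a background thread stands in for the separate index builder job. `build_index`, the parameter dictionary and the "embedding" (document length times a scale) are all hypothetical placeholders; the point is only the control flow: send a frozen snapshot, keep training, and swap in the new index when it is ready.

```python
from concurrent.futures import ThreadPoolExecutor

def build_index(param_snapshot, documents):
    """Hypothetical index builder: 'embeds' every document with a frozen snapshot."""
    scale = param_snapshot["scale"]
    return {"version": param_snapshot["step"],
            "embeddings": [scale * len(doc) for doc in documents]}

def train(documents, num_steps=50, refresh_every=10):
    params = {"scale": 1.0, "step": 0}
    index = build_index(dict(params), documents)  # initial (warm-start) index
    refreshes = 0
    with ThreadPoolExecutor(max_workers=1) as index_builder:
        pending = None
        for step in range(1, num_steps + 1):
            params["scale"] += 0.01  # stand-in for a gradient update
            params["step"] = step
            if pending is not None and pending.done():
                index = pending.result()  # swap in the fresher (still slightly stale) index
                refreshes += 1
                pending = None
            if pending is None and step % refresh_every == 0:
                # Send a copy so later updates do not leak into the snapshot being indexed.
                pending = index_builder.submit(build_index, dict(params), documents)
        if pending is not None:
            index = pending.result()
            refreshes += 1
    return params, index, refreshes

params, index, refreshes = train(["a", "bb", "ccc"])
```

The trainer never waits for the builder during the loop, so the index it searches always lags the current parameters by up to a refresh interval or so, matching the "slightly stale between refreshes" behavior described above.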

Choice of refresh

The authors use asynchronous refresh only for pre-training, although it could have been used for both pre-training and fine-tuning. For fine-tuning, the MIPS index is built once and the document embeddings are not updated.

One interesting experiment would be to measure how the MIPS index refresh rate affects model performance, and what impact using multiple sources for the knowledge corpus has.

What is Neural Knowledge Retriever Learning?

The authors clearly explain how the training objective encourages meaningful retrievals by rewarding relevant retrievals and penalizing irrelevant ones.

For a given query x and document z, the retriever assigns a relevance score f(x,z). The paper shows how a single step of gradient descent during REALM pre-training changes this score by looking at the gradient with respect to the parameters of the Neural Knowledge Retriever (θ):

$$\nabla \log p(y \mid x) = \sum_{z \in \mathcal{Z}} r(z)\, \nabla f(x, z), \qquad r(z) = \left[\frac{p(y \mid z, x)}{p(y \mid x)} - 1\right] p(z \mid x)$$

  • For each document z, the gradient encourages the retriever to change the score f(x,z) by r(z): increasing it if r(z) is positive and decreasing it if r(z) is negative.
  • r(z) is positive if and only if p(y|z,x) > p(y|x), i.e. the probability of predicting the correct output when document z is retrieved is greater than the expected probability of the correct output when a document is sampled at random from p(z|x). Thus, document z receives a positive update whenever it performs better than expected. The detailed derivation of the gradient is provided in the appendix of the paper.
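The reward r(z) can be computed directly from p(z|x) and p(y|z,x). This sketch uses hypothetical toy probabilities for three documents and an illustrative function name; it shows which documents get pushed up or down.

```python
def retriever_rewards(p_z_given_x, p_y_given_zx):
    """r(z) = [p(y|z,x) / p(y|x) - 1] * p(z|x), with p(y|x) = sum_z p(y|z,x) p(z|x).

    The retriever gradient is sum_z r(z) * grad f(x, z), so r(z) > 0 raises
    f(x, z) and r(z) < 0 lowers it.
    """
    p_y = sum(pz * py for pz, py in zip(p_z_given_x, p_y_given_zx))
    return [(py / p_y - 1.0) * pz for pz, py in zip(p_z_given_x, p_y_given_zx)]

# Document 0 helps predict y; documents 1 and 2 do worse than average.
r = retriever_rewards([0.5, 0.3, 0.2], [0.9, 0.2, 0.1])
```

Here p(y|x) = 0.53, so only document 0 (with p(y|z,x) = 0.9 > 0.53) gets a positive reward. The rewards always sum to zero, since the weighted average of p(y|z,x)/p(y|x) under p(z|x) is exactly one: probability mass is shifted between documents rather than added.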

Experiments & Results

The authors compare REALM with both retrieval-based and generation-based open-QA systems. Retrieval-based open-QA systems first retrieve relevant documents and then use a reading comprehension system to extract the answer from them. Generation-based open-QA systems model the task as sequence prediction: they encode the question and then decode the answer token by token from that encoding.

The authors reuse all hyperparameters from Lee et al. (2019). Refer to the paper for infrastructure-level details of training, such as the number of TPUs and the batch size.

Test results on Open-QA benchmarks
  • Table 1 shows the accuracy of different approaches on the three open-QA datasets, along with the number of parameters of each model.
  • As the table shows, generation-based open-QA systems built on T5 are powerful, and their performance improves with model size. In contrast, on NaturalQuestions-Open REALM (39.2, 40.4) outperforms the T5-11B model (34.5) while being 30 times smaller.
  • The most direct comparison is with ORQA, where the fine-tuning setup, hyperparameters and training data are identical, so REALM's improvement over ORQA can be attributed to its better pre-training method. The table also shows that REALM can be applied both in a single-corpus setting (the same corpus for pre-training and knowledge retrieval) and in a separate-corpus setting.

Ablation Study

  • The authors ablate critical components of REALM and report their impact. To understand whether REALM pre-training improves the retriever or the encoder, they reset the parameters of either the retriever or the encoder to their baseline state (as in the ORQA paper) before fine-tuning.
  • Resetting both the retriever and the encoder reduces the system to the ORQA baseline. The ablation shows that both components benefit from REALM pre-training, but the best performance is achieved when both are pre-trained with REALM and used together.

Adapting to new Knowledge

An explicit retrieval system lets the model adapt to new world knowledge simply by modifying the documents in the corpus. To demonstrate this, the authors replaced the knowledge corpus with a more recent version of Wikipedia after pre-training. When an input query asks about a fact on which the two corpora disagree, REALM can change its prediction to reflect the updated information. However, even with an explicit knowledge retrieval mechanism, the knowledge augmented encoder still memorizes some world knowledge, so predictions for some inputs are not updated by the new corpus.
