[Paper] Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks (NeurIPS 2020)
05 Nov 2023 #llm #transformer #nlp
Lewis, Patrick, et al. “Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks.” Advances in Neural Information Processing Systems 33 (2020): 9459-9474.
Points
- The Retrieval-Augmented Generation (RAG) model combines a retriever and a generator for knowledge-intensive tasks.
- RAG variants: RAG-Sequence uses the same retrieved document for the entire output; RAG-Token can draw on a different document for each token.
- RAG models outperform baselines in open-domain QA, abstractive QA, Jeopardy question generation, and fact verification.
- RAG models demonstrate practical benefits with easy updates to the non-parametric memory.
Background
- Large pre-trained language models (LLMs) store factual knowledge in their parameters, functioning as an implicit knowledge base.
- LLMs, however, have limitations: they cannot easily expand or revise their memory, cannot provide insight into their predictions, and may produce ‘hallucinations’.
- Recently, hybrid models such as REALM and ORQA address these issues with a differentiable retriever, so that knowledge can be revised and expanded; they show promising results, primarily in open-domain question answering (QA).
Method
Retrieval-augmented generation (RAG) fine-tunes pre-trained generation models with a non-parametric memory, as a general-purpose recipe for knowledge-intensive tasks.
- Parametric memory: a pre-trained seq2seq transformer
- Non-parametric memory: a dense vector index of Wikipedia, accessed with a pre-trained neural retriever.
- Dense passage retriever (DPR): retrieves latent documents conditioned on the input.
- BART: the generator conditions on the latent documents together with the input to generate the output. Other seq2seq models like T5 can also be used and fine-tuned with the retriever.
- Latent documents: marginalized using a top-K approximation, either on a per-output basis or a per-token basis.
- RAG-Sequence Model: assumes the same document is responsible for all tokens.
- RAG-Token Model: considers different documents for different tokens.
Models
RAG models use the input sequence $x$ to retrieve text documents $z$ and use them as additional context when generating the target sequence $y$. RAG has two components:
- Retriever $p_\eta(z\mid x)$: returns distributions over text passages given a query $x$ with parameters $\eta$.
- The distribution is truncated to the top-K passages (top-K approximation).
- Generator $p_\theta(y_i\mid x,z,y_{1:i-1})$: generates a current token based on the previous $i-1$ tokens $y_{1:i-1}$, the input $x$, and a retrieved passage $z$ with parameters $\theta$.
The retriever and the generator are trained end-to-end, treating the retrieved document as a latent variable. To marginalize over the latent documents, two methods are proposed, RAG-Sequence and RAG-Token.
RAG-Sequence and RAG-Token
RAG-Sequence Model uses the same retrieved document to generate the complete sequence.
- The retrieved document is treated as a single latent variable that is marginalized to obtain the seq2seq probability $p(y\mid x)$ via a top-K approximation.
- The top-K documents are retrieved using the retriever, and the generator produces the output sequence probability for each document; these are then marginalized.
- Use cases: Better suited for tasks where the context of entire documents is crucial, like summarization tasks.
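As given in the paper, RAG-Sequence marginalizes over the top-K retrieved documents at the sequence level:
\[p_{\text{RAG-Sequence}}(y\mid x)\approx\sum_{z\in \text{top-}k(p_\eta(\cdot\mid x))}p_\eta(z\mid x)\prod_{i}p_\theta(y_i\mid x,z,y_{1:i-1})\]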
RAG-Token Model uses different latent documents for each target token.
- The generator chooses content from several documents for the answer.
- The top-K documents are retrieved using the retriever, and the generator produces a distribution for the next output token for each document before marginalizing.
- Use cases: More suitable for tasks that benefit from integrating detailed information from multiple sources, like open-domain QA.
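In contrast, RAG-Token marginalizes per token, so each output position can draw on a different document:
\[p_{\text{RAG-Token}}(y\mid x)\approx\prod_{i}\sum_{z\in \text{top-}k(p_\eta(\cdot\mid x))}p_\eta(z\mid x)p_\theta(y_i\mid x,z,y_{1:i-1})\]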
Retriever and Generator
Retriever $p_\eta(z\mid x)$ is based on DPR, which follows a bi-encoder architecture:
\[p_\eta(z\mid x)\propto \exp(\mathbf{d}(z)^\top \mathbf{q}(x)), \quad \mathbf{d}(z)=\mathrm{BERT}_d(z), \ \mathbf{q}(x)=\mathrm{BERT}_q(x)\]
- $\mathbf{d}(z)$: a dense representation of a document produced by a document encoder based on $\rm BERT_{BASE}$.
- $\mathbf{q}(x)$: a query representation produced by a query encoder based on $\rm BERT_{BASE}$.
- Maximum inner product search (MIPS): computes the top-k documents under $p_\eta(\cdot\mid x)$ approximately in sub-linear time (see the sketch below).
- Non-parametric memory: the document index. The retriever is initialized with a DPR retriever trained to retrieve documents containing answers to TriviaQA questions and Natural Questions.
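A minimal PyTorch sketch of the bi-encoder scoring and top-K selection (toy sizes and random vectors stand in for the BERT encoders; this is an illustration, not the paper's code):

```python
import torch

hidden_dim, num_docs, top_k = 768, 10_000, 5       # toy sizes, not the paper's ~21M passages

doc_embs = torch.randn(num_docs, hidden_dim)       # d(z) = BERT_d(z), precomputed offline
query_emb = torch.randn(hidden_dim)                # q(x) = BERT_q(x), computed per query

# p_eta(z|x) ∝ exp(d(z)^T q(x)); exact search here, FAISS/HNSW approximates it in sub-linear time
scores = doc_embs @ query_emb                      # (num_docs,)
top_scores, top_idx = scores.topk(top_k)
p_z_given_x = torch.softmax(top_scores, dim=-1)    # retrieval distribution over the top-K docs
```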
Generator $p_\theta(y_i\mid x,z,y_{1:i-1})$ can be any encoder-decoder model; the paper uses BART.
- $\rm BART_{large}$ is used: a pre-trained seq2seq transformer with 400M parameters, pre-trained using a denoising objective with various noising functions.
- The input $x$ and the retrieved document $z$ are simply concatenated and fed into the $\rm BART$ model to generate the output (see the sketch below).
- Parametric memory: $\rm BART$ generator parameters $\theta$.
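A rough illustration of this concatenate-and-generate step with the Hugging Face transformers BART checkpoint (the separator and prompt format are assumptions for illustration, not the paper's exact input layout, and the passage text is made up):

```python
from transformers import BartTokenizer, BartForConditionalGeneration

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
generator = BartForConditionalGeneration.from_pretrained("facebook/bart-large")

question = "Who composed the opera The Magic Flute?"
retrieved = "The Magic Flute is an opera composed in 1791 by Wolfgang Amadeus Mozart ..."  # top passage z

# Concatenate the retrieved passage z with the input x, then let BART generate y.
inputs = tokenizer(retrieved + " // " + question, return_tensors="pt", truncation=True)
output_ids = generator.generate(**inputs, num_beams=4, max_length=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```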
Training
The retriever and generator are trained jointly without direct supervision on which document should be retrieved.
- Objective: Minimize the negative marginal log-likelihood of each target with a corpus of input/output pairs $(x_j, y_j)$, $\sum_j-\log(p(y_j\mid x_j))$.
- Adam optimizer.
- Only the query encoder $\rm BERT_q$ and the generator $\rm BART$ are fine-tuned during training.
- Updating the document encoder $\rm BERT_d$ is costly and not needed:
- It requires periodically re-encoding documents and rebuilding the document index (as in REALM).
- It is not necessary for strong performance.
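A minimal PyTorch sketch of this objective for one example, using the RAG-Token marginalization (shapes, the pad id, and variable names are assumptions, not the paper's code):

```python
import torch
import torch.nn.functional as F

def rag_token_nll(doc_scores, token_logits, target_ids, pad_id=1):
    """Negative marginal log-likelihood -log p(y|x) for a single example.

    doc_scores:   (K,)      unnormalised retrieval scores q(x)^T d(z) for the top-K docs
    token_logits: (K, T, V) generator logits over the target, one sequence per document
    target_ids:   (T,)      gold output token ids
    """
    log_p_z = F.log_softmax(doc_scores, dim=-1)                 # log p_eta(z|x), shape (K,)
    log_p_y = F.log_softmax(token_logits, dim=-1)               # (K, T, V)
    tgt = target_ids.view(1, -1, 1).expand(log_p_y.size(0), -1, 1)
    log_p_tok = log_p_y.gather(-1, tgt).squeeze(-1)             # log p_theta(y_i|x,z,y_{1:i-1}), (K, T)
    # Marginalise over documents per token, then sum log-probabilities over the sequence.
    log_marginal = torch.logsumexp(log_p_z.unsqueeze(1) + log_p_tok, dim=0)  # (T,)
    mask = (target_ids != pad_id).float()
    return -(log_marginal * mask).sum()
```

Because the loss depends on `doc_scores`, gradients flow back into the query encoder, so the retriever and generator are trained jointly as described above.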
Decoding
At test time, RAG-Sequence and RAG-Token require different methods to approximate $\arg \max_y{p(y\mid x)}$.
RAG-Sequence model: the likelihood $p(y\mid x)$ does not factor into a conventional per-token likelihood, so it cannot be decoded with a single beam search; instead, beam search is run separately for each document $z$.
- Each hypothesis in document $z$'s beam is scored by $p_\theta(y_i\mid x,z,y_{1:i-1})$.
- A hypothesis $y$ in the resulting candidate set $Y$ may not have appeared in the beams of all documents.
- Thorough Decoding: To estimate the probability of such a $y$, (1) run an additional forward pass for each $z$ whose beam did not produce $y$, (2) multiply the generator probability by $p_\eta(z\mid x)$, and (3) sum the probabilities across beams.
- Fast Decoding: For efficiency, approximate $p_\theta(y\mid x,z_i) \approx 0$ whenever $y$ was not generated during beam search from $x, z_i$, avoiding additional forward passes once the candidate set $Y$ is built (sketched below).
- For longer output sequences, $\left\vert Y \right\vert$ can be large, so Thorough Decoding requires many forward passes.
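A toy sketch of the Fast Decoding scoring step (the data layout and names are assumptions; the paper's implementation differs):

```python
import torch

def fast_decoding_scores(candidate_log_probs, doc_log_priors):
    """Score each candidate y by log sum_z p_eta(z|x) p_theta(y|x,z) under Fast Decoding.

    candidate_log_probs: {candidate_text: {doc_idx: log p_theta(y|x,z)}} containing only
                         the documents whose beam actually produced the candidate
    doc_log_priors:      (K,) tensor of log p_eta(z|x) for the top-K documents
    """
    scores = {}
    for y, per_doc in candidate_log_probs.items():
        # Missing (y, z) pairs are treated as p_theta(y|x,z) ~ 0, so they simply drop out
        # of the sum; Thorough Decoding would instead run extra generator forward passes.
        terms = torch.stack([doc_log_priors[z] + lp for z, lp in per_doc.items()])
        scores[y] = torch.logsumexp(terms, dim=0)
    return scores
```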
RAG-Token model can be viewed as a standard autoregressive seq2seq generator with transition probability
\[p'_\theta(y_i\mid x,y_{1:i-1})=\sum_{z\in \text{top-}k(p(\cdot \mid x))}p_\eta(z_i \mid x)p_\theta(y_i\mid x,z_i,y_{1:i-1})\]
so it can be decoded with a standard beam decoder.
Experiments
The experiments were conducted on several datasets to evaluate the model’s performance in knowledge-intensive NLP tasks.
- The December 2018 Wikipedia dump was used as the non-parametric knowledge source.
- Wikipedia articles were split into 100-word chunks, totaling 21M documents.
- An embedding for each document was calculated by the document encoder $\rm BERT_d$, and a single MIPS index was built with Hierarchical Navigable Small World approximation for fast retrieval.
- When retrieving the top $k$ documents for each query, $k\in\{5,10\}$ was considered during training, and $k$ for test time was set using dev data.
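A sketch of this indexing step with FAISS (toy sizes and illustrative parameters; the paper's exact build settings are not reproduced here):

```python
import faiss
import numpy as np

dim, num_docs = 768, 100_000                        # toy size; the paper indexes ~21M chunks

doc_embs = np.random.rand(num_docs, dim).astype("float32")    # stand-in for BERT_d embeddings

# HNSW index over inner-product scores, approximating MIPS for fast retrieval.
index = faiss.IndexHNSWFlat(dim, 32, faiss.METRIC_INNER_PRODUCT)
index.add(doc_embs)

query = np.random.rand(1, dim).astype("float32")              # stand-in for BERT_q(x)
scores, doc_ids = index.search(query, 5)                      # top-5 documents for the query
```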
Tasks
- Open-domain Question Answering (QA): an important real-world application and common testbed for knowledge-intensive tasks.
- Text pairs $(x,y)$ are matched as questions and answers.
- RAG is trained to minimize the negative log-likelihood of answers.
- Closed-book QA is also compared: generating answers without retrieval, using only parametric knowledge.
- Datasets: Natural Questions, TriviaQA, WebQuestions, CuratedTREC
- Abstractive Question Answering: tests natural language generation (NLG) ability on free-form, abstractive answers.
- Uses the MSMARCO NLG task v2.1 with only the questions and answers, not the gold passages provided in the dataset, treating it as an open-domain abstractive QA task.
- Jeopardy Question Generation: evaluates the generation ability in a non-QA setting.
- Jeopardy: guessing an entity from a fact about that entity.
- e.g., “In 1986 Mexico scored as the first country to host this international sports competition twice.”, where the answer is “The World Cup”.
- Jeopardy questions are precise and factual, making it a challenging, knowledge-intensive task to generate them conditioned on the answer entities.
- Fact Verification (FEVER): a retrieval problem coupled with a challenging entailment reasoning task.
- Requires classifying whether a text is supported or refuted by Wikipedia or whether there’s not enough information to decide.
- Provides an appropriate testbed for exploring a model’s ability to handle classification rather than generation.
- Two variants: 3-way classification (supports/refutes/not enough info) and 2-way (supports/refutes).
Results
The results demonstrated that both RAG-Sequence and RAG-Token models outperformed baseline models across various datasets and tasks.
Open-Domain QA
- RAG models significantly outperformed the baselines, showing higher EM and F1 scores.
- The RAG-Token model, in particular, performed well due to its ability to integrate detailed information from multiple documents.
Abstractive Question Answering
- RAG models approached SOTA performance, even though many questions are unanswerable without the gold passages that the SOTA models rely on.
- RAG models hallucinated less and generated more factually correct and diverse text compared to BART (Table 3).
Jeopardy Question Generation
- Both RAG models outperformed BART on Q-BLEU-1 (Table 2).
- Human evaluators indicated that RAG-generated content was more factual in 42.7% of cases, demonstrating the effectiveness of RAG over the SOTA generation model (Table 4).
- RAG-Token model performed better than RAG-Sequence, combining content from several documents effectively (Fig 2).
- The generator’s parametric knowledge sufficed to complete the generation after initially referencing the document (Fig 2).
Fact Verification
- For 3-way classification, RAG achieved scores within 4.3% of SOTA models, which use domain-specific architectures and are trained with intermediate retrieval supervision.
- For 2-way classification, RAG performed within 2.7% of the SOTA model, which is trained to classify claims as true or false given the gold evidence.
- The documents retrieved by RAG overlap significantly with FEVER’s gold evidence.
Additional Results
- Generation Diversity: Measuring the ratio of distinct n-grams to total generated n-grams, RAG models produced more diverse outputs than BART, and RAG-Sequence was slightly more diverse than RAG-Token (Table 5).
- Retrieval Ablations: Freezing the retriever during training lowered performance compared to the full RAG models, and comparing against a BM25 retriever showed that learned retrieval improves performance on most tasks (Table 6).
- Index hot-swapping: Demonstrated the advantage of non-parametric memory by swapping in an index built from a December 2016 Wikipedia dump. RAG models still answered 70% of questions correctly, showing that knowledge can be updated simply by replacing the non-parametric memory.
- Effect of retrieving more documents: Adjusting the number of retrieved documents at test time improved performance up to a point, demonstrating the benefit of retrieving more relevant documents (Fig 3).