Update LFQA with the latest LFQA seq2seq and retriever models (#2210)

* Register BartEli5Converter for vblagoje/bart_lfqa model

* Update LFQA unit tests

* Update LFQA tutorials
Vladimir Blagojevic 2022-03-08 15:11:41 +01:00 committed by GitHub
parent 255226f9fa
commit 6c0094b5ad
7 changed files with 68 additions and 60 deletions
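
Taken together, this commit swaps the LFQA tutorial's `RetribertRetriever` / `yjernite/bart_eli5` pair for LFQA-tuned DPR encoders and the `vblagoje/bart_lfqa` generator. Below is a minimal end-to-end sketch assembled from the snippets in this diff; the FAISS document store setup and the sample document are placeholders, not part of the commit:

```python
from haystack.document_stores import FAISSDocumentStore
from haystack.nodes import DensePassageRetriever, Seq2SeqGenerator
from haystack.pipelines import GenerativeQAPipeline

# Placeholder store and document; the tutorial indexes its own Game of Thrones dataset
document_store = FAISSDocumentStore(embedding_dim=768, faiss_index_factory_str="Flat")
document_store.write_documents([{"content": "Arya Stark is portrayed by Maisie Williams in HBO's Game of Thrones."}])

# New LFQA-tuned DPR encoders introduced by this commit
retriever = DensePassageRetriever(
    document_store=document_store,
    query_embedding_model="vblagoje/dpr-question_encoder-single-lfqa-wiki",
    passage_embedding_model="vblagoje/dpr-ctx_encoder-single-lfqa-wiki",
)
document_store.update_embeddings(retriever)

# New seq2seq LFQA model replacing yjernite/bart_eli5
generator = Seq2SeqGenerator(model_name_or_path="vblagoje/bart_lfqa")

pipe = GenerativeQAPipeline(generator, retriever)
result = pipe.run(
    query="How did Arya Stark's character get portrayed in a television adaptation?",
    params={"Retriever": {"top_k": 3}},
)
print(result["answers"][0].answer)
```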

View File

@ -32,7 +32,7 @@ Make sure you enable the GPU runtime to experience decent speed in this tutorial
# Install the latest master of Haystack
!pip install --upgrade pip
!pip install git+https://github.com/deepset-ai/haystack.git#egg=farm-haystack[colab,faiss]
!pip install -q git+https://github.com/deepset-ai/haystack.git#egg=farm-haystack[colab,faiss]
```
@ -80,22 +80,24 @@ document_store.write_documents(dicts)
#### Retriever
**Here:** We use a `RetribertRetriever` and we invoke `update_embeddings` to index the embeddings of documents in the `FAISSDocumentStore`
We use a `DensePassageRetriever` and we invoke `update_embeddings` to index the embeddings of documents in the `FAISSDocumentStore`
```python
from haystack.nodes import EmbeddingRetriever
from haystack.nodes import DensePassageRetriever
retriever = EmbeddingRetriever(
document_store=document_store, embedding_model="yjernite/retribert-base-uncased", model_format="retribert"
retriever = DensePassageRetriever(
document_store=document_store,
query_embedding_model="vblagoje/dpr-question_encoder-single-lfqa-wiki",
passage_embedding_model="vblagoje/dpr-ctx_encoder-single-lfqa-wiki",
)
document_store.update_embeddings(retriever)
```
Before we blindly use the `RetribertRetriever` let's empirically test it to make sure a simple search indeed finds the relevant documents.
Before we blindly use the `DensePassageRetriever` let's empirically test it to make sure a simple search indeed finds the relevant documents.
```python
@ -111,13 +113,13 @@ print_documents(res, max_text_len=512)
Similar to previous Tutorials we now initialize our reader/generator.
Here we use a `Seq2SeqGenerator` with the *yjernite/bart_eli5* model (see: https://huggingface.co/yjernite/bart_eli5)
Here we use a `Seq2SeqGenerator` with the *vblagoje/bart_lfqa* model (see: https://huggingface.co/vblagoje/bart_lfqa)
```python
generator = Seq2SeqGenerator(model_name_or_path="yjernite/bart_eli5")
generator = Seq2SeqGenerator(model_name_or_path="vblagoje/bart_lfqa")
```
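
Before wiring it into a pipeline, the generator can also be probed on its own with a few hand-made documents; a small sketch, assuming the generic `predict(query, documents, top_k)` interface of Haystack generators (the sample documents are invented for illustration):

```python
from haystack.schema import Document

generator = Seq2SeqGenerator(model_name_or_path="vblagoje/bart_lfqa")

docs = [
    Document(content="Arya Stark is a character in George R. R. Martin's A Song of Ice and Fire novels."),
    Document(content="In HBO's television adaptation Game of Thrones, Arya Stark is portrayed by Maisie Williams."),
]

# predict() takes the query plus already-retrieved documents and returns generated answers
result = generator.predict(query="Who portrays Arya Stark on television?", documents=docs, top_k=1)
print(result["answers"][0].answer)
```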
### Pipeline
@ -139,13 +141,13 @@ pipe = GenerativeQAPipeline(generator, retriever)
```python
pipe.run(
query="Why did Arya Stark's character get portrayed in a television adaptation?", params={"Retriever": {"top_k": 1}}
query="How did Arya Stark's character get portrayed in a television adaptation?", params={"Retriever": {"top_k": 3}}
)
```
```python
pipe.run(query="What kind of character does Arya Stark play?", params={"Retriever": {"top_k": 1}})
pipe.run(query="Why is Arya Stark an unusual character?", params={"Retriever": {"top_k": 3}})
```
## About us
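
The retrieval sanity check is truncated by the hunk boundary above (only the opening of the code cell and the trailing `print_documents(res, max_text_len=512)` context survive). A sketch of what that check amounts to, assuming the `retriever` and indexed `document_store` from the earlier cell; the variable name and query are illustrative:

```python
from haystack.pipelines import DocumentSearchPipeline
from haystack.utils import print_documents

# Wrap the retriever alone to inspect what it would hand to the generator
p_retrieval = DocumentSearchPipeline(retriever)
res = p_retrieval.run(query="Tell me something about Arya Stark?", params={"Retriever": {"top_k": 5}})
print_documents(res, max_text_len=512)
```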

View File

@ -386,7 +386,8 @@ class Seq2SeqGenerator(BaseGenerator):
def _register_converters(cls, model_name_or_path: str, custom_converter: Optional[Callable]):
# init if empty
if not cls._model_input_converters:
cls._model_input_converters["yjernite/bart_eli5"] = _BartEli5Converter()
for c in ["yjernite/bart_eli5", "vblagoje/bart_lfqa"]:
cls._model_input_converters[c] = _BartEli5Converter()
# register user provided custom converter
if model_name_or_path and custom_converter:
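
For models outside this built-in mapping, `Seq2SeqGenerator` also lets callers register their own converter, i.e. a callable that packs the query and the retrieved documents into whatever input the seq2seq model expects. A rough sketch, assuming the `(tokenizer, query, documents, top_k)` call signature used by `_BartEli5Converter` and the `input_converter` constructor argument; the model name is hypothetical:

```python
from typing import List, Optional

from haystack.nodes import Seq2SeqGenerator
from haystack.schema import Document


class MyLFQAConverter:
    """Hypothetical converter: joins documents with '<P>' and prefixes the query."""

    def __call__(self, tokenizer, query: str, documents: List[Document], top_k: Optional[int] = None):
        context = " <P> ".join(d.content for d in documents)
        model_input = f"question: {query} context: {context}"
        # Return the tokenized batch the underlying model will be fed
        return tokenizer([model_input], truncation=True, padding=True, return_tensors="pt")


# custom_converter ends up registered for this (hypothetical) model name via _register_converters
generator = Seq2SeqGenerator(model_name_or_path="my-org/my-lfqa-model", input_converter=MyLFQAConverter())
```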

View File

@ -1307,12 +1307,12 @@
},
"preceding_context_len": {
"title": "Preceding Context Len",
"default": 3,
"default": 1,
"type": "integer"
},
"following_context_len": {
"title": "Following Context Len",
"default": 3,
"default": 1,
"type": "integer"
},
"remove_page_headers": {

View File

@ -304,8 +304,8 @@ def question_generator():
@pytest.fixture(scope="function")
def eli5_generator():
return Seq2SeqGenerator(model_name_or_path="yjernite/bart_eli5", max_length=20)
def lfqa_generator(request):
return Seq2SeqGenerator(model_name_or_path=request.param, min_length=100, max_length=200)
@pytest.fixture(scope="function")
@ -509,6 +509,14 @@ def get_retriever(retriever_type, document_store):
model_format="retribert",
use_gpu=False,
)
elif retriever_type == "dpr_lfqa":
retriever = DensePassageRetriever(
document_store=document_store,
query_embedding_model="vblagoje/dpr-question_encoder-single-lfqa-wiki",
passage_embedding_model="vblagoje/dpr-ctx_encoder-single-lfqa-wiki",
use_gpu=False,
embed_title=True,
)
elif retriever_type == "elasticsearch":
retriever = ElasticsearchRetriever(document_store=document_store)
elif retriever_type == "es_filter_only":

View File

@ -60,9 +60,10 @@ def test_generator_pipeline(document_store, retriever, rag_generator):
@pytest.mark.slow
@pytest.mark.generator
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
@pytest.mark.parametrize("retriever", ["retribert"], indirect=True)
@pytest.mark.parametrize("retriever", ["retribert", "dpr_lfqa"], indirect=True)
@pytest.mark.parametrize("lfqa_generator", ["yjernite/bart_eli5", "vblagoje/bart_lfqa"], indirect=True)
@pytest.mark.embedding_dim(128)
def test_lfqa_pipeline(document_store, retriever, eli5_generator):
def test_lfqa_pipeline(document_store, retriever, lfqa_generator):
# reuse existing DOCS but regenerate embeddings with retribert
docs: List[Document] = []
for idx, d in enumerate(DOCS_WITH_EMBEDDINGS):
@ -70,11 +71,11 @@ def test_lfqa_pipeline(document_store, retriever, eli5_generator):
document_store.write_documents(docs)
document_store.update_embeddings(retriever)
query = "Tell me about Berlin?"
pipeline = GenerativeQAPipeline(retriever=retriever, generator=eli5_generator)
pipeline = GenerativeQAPipeline(generator=lfqa_generator, retriever=retriever)
output = pipeline.run(query=query, params={"top_k": 1})
answers = output["answers"]
assert len(answers) == 1
assert "Germany" in answers[0].answer
assert len(answers) == 1, answers
assert "Germany" in answers[0].answer, answers[0].answer
@pytest.mark.slow
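
Both new fixtures rely on pytest's indirect parametrization: with `indirect=True`, each value in the `parametrize` list is routed into the fixture as `request.param` (that is how `"dpr_lfqa"` reaches `get_retriever` and how the two model names reach `lfqa_generator`). A stripped-down, Haystack-free sketch of the mechanism, with illustrative names:

```python
import pytest


@pytest.fixture(scope="function")
def model(request):
    # With indirect=True, each value from parametrize arrives here as request.param
    return {"name": request.param}


@pytest.mark.parametrize("model", ["yjernite/bart_eli5", "vblagoje/bart_lfqa"], indirect=True)
def test_runs_once_per_model(model):
    # The test body runs twice, once per parametrized model name
    assert "/" in model["name"]
```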

View File

@ -51,7 +51,7 @@
"\n",
"# Install the latest master of Haystack\n",
"!pip install --upgrade pip\n",
"!pip install git+https://github.com/deepset-ai/haystack.git#egg=farm-haystack[colab,faiss]"
"!pip install -q git+https://github.com/deepset-ai/haystack.git#egg=farm-haystack[colab,faiss]"
]
},
{
@ -146,7 +146,7 @@
"\n",
"#### Retriever\n",
"\n",
"**Here:** We use a `RetribertRetriever` and we invoke `update_embeddings` to index the embeddings of documents in the `FAISSDocumentStore`\n",
"We use a `DensePassageRetriever` and we invoke `update_embeddings` to index the embeddings of documents in the `FAISSDocumentStore`\n",
"\n"
]
},
@ -161,10 +161,12 @@
},
"outputs": [],
"source": [
"from haystack.nodes import EmbeddingRetriever\n",
"from haystack.nodes import DensePassageRetriever\n",
"\n",
"retriever = EmbeddingRetriever(\n",
" document_store=document_store, embedding_model=\"yjernite/retribert-base-uncased\", model_format=\"retribert\"\n",
"retriever = DensePassageRetriever(\n",
" document_store=document_store,\n",
" query_embedding_model=\"vblagoje/dpr-question_encoder-single-lfqa-wiki\",\n",
" passage_embedding_model=\"vblagoje/dpr-ctx_encoder-single-lfqa-wiki\",\n",
")\n",
"\n",
"document_store.update_embeddings(retriever)"
@ -176,25 +178,16 @@
"id": "sMlVEnJ2NkZZ"
},
"source": [
"Before we blindly use the `RetribertRetriever` let's empirically test it to make sure a simple search indeed finds the relevant documents."
"Before we blindly use the `DensePassageRetriever` let's empirically test it to make sure a simple search indeed finds the relevant documents."
]
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {
"id": "qpu-t9rndgpe"
},
"outputs": [
{
"ename": "SyntaxError",
"evalue": "EOL while scanning string literal (<ipython-input-1-cc681f017dc5>, line 7)",
"output_type": "error",
"traceback": [
"\u001b[0;36m File \u001b[0;32m\"<ipython-input-1-cc681f017dc5>\"\u001b[0;36m, line \u001b[0;32m7\u001b[0m\n\u001b[0;31m params={\"top_k_retriever=5\u001b[0m\n\u001b[0m ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m EOL while scanning string literal\n"
]
}
],
"outputs": [],
"source": [
"from haystack.utils import print_documents\n",
"from haystack.pipelines import DocumentSearchPipeline\n",
@ -214,7 +207,7 @@
"\n",
"Similar to previous Tutorials we now initalize our reader/generator.\n",
"\n",
"Here we use a `Seq2SeqGenerator` with the *yjernite/bart_eli5* model (see: https://huggingface.co/yjernite/bart_eli5)\n",
"Here we use a `Seq2SeqGenerator` with the *vblagoje/bart_lfqa* model (see: https://huggingface.co/vblagoje/bart_lfqa)\n",
"\n"
]
},
@ -226,7 +219,7 @@
},
"outputs": [],
"source": [
"generator = Seq2SeqGenerator(model_name_or_path=\"yjernite/bart_eli5\")"
"generator = Seq2SeqGenerator(model_name_or_path=\"vblagoje/bart_lfqa\")"
]
},
{
@ -274,25 +267,26 @@
"outputs": [],
"source": [
"pipe.run(\n",
" query=\"Why did Arya Stark's character get portrayed in a television adaptation?\", params={\"Retriever\": {\"top_k\": 1}}\n",
" query=\"How did Arya Stark's character get portrayed in a television adaptation?\", params={\"Retriever\": {\"top_k\": 3}}\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "zvHb8SvMblw9"
},
"outputs": [],
"source": [
"pipe.run(query=\"What kind of character does Arya Stark play?\", params={\"Retriever\": {\"top_k\": 1}})"
]
"pipe.run(query=\"Why is Arya Stark an unusual character?\", params={\"Retriever\": {\"top_k\": 3}})"
],
"metadata": {
"id": "IfTP9BfFGOo6"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
"collapsed": false,
"id": "i88KdOc2wUXQ"
},
"source": [
"## About us\n",
@ -340,5 +334,5 @@
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 0
}

View File

@ -37,18 +37,20 @@ def tutorial12_lfqa():
"""
Initialize Retriever and Reader/Generator:
We use a `RetribertRetriever` and we invoke `update_embeddings` to index the embeddings of documents in the `FAISSDocumentStore`
We use a `DensePassageRetriever` and we invoke `update_embeddings` to index the embeddings of documents in the `FAISSDocumentStore`
"""
from haystack.nodes import EmbeddingRetriever
from haystack.nodes import DensePassageRetriever
retriever = EmbeddingRetriever(
document_store=document_store, embedding_model="yjernite/retribert-base-uncased", model_format="retribert"
retriever = DensePassageRetriever(
document_store=document_store,
query_embedding_model="vblagoje/dpr-question_encoder-single-lfqa-wiki",
passage_embedding_model="vblagoje/dpr-ctx_encoder-single-lfqa-wiki",
)
document_store.update_embeddings(retriever)
"""Before we blindly use the `RetribertRetriever` let's empirically test it to make sure a simple search indeed finds the relevant documents."""
"""Before we blindly use the `DensePassageRetriever` let's empirically test it to make sure a simple search indeed finds the relevant documents."""
from haystack.utils import print_documents
from haystack.pipelines import DocumentSearchPipeline
@ -59,10 +61,10 @@ def tutorial12_lfqa():
"""
Similar to previous Tutorials we now initialize our reader/generator.
Here we use a `Seq2SeqGenerator` with the *yjernite/bart_eli5* model (see: https://huggingface.co/yjernite/bart_eli5)
Here we use a `Seq2SeqGenerator` with the *vblagoje/bart_lfqa* model (see: https://huggingface.co/vblagoje/bart_lfqa)
"""
generator = Seq2SeqGenerator(model_name_or_path="yjernite/bart_eli5")
generator = Seq2SeqGenerator(model_name_or_path="vblagoje/bart_lfqa")
"""
Pipeline:
@ -78,14 +80,14 @@ def tutorial12_lfqa():
"""Voilà! Ask a question!"""
query_1 = "Why did Arya Stark's character get portrayed in a television adaptation?"
result_1 = pipe.run(query=query_1, params={"Retriever": {"top_k": 1}})
query_1 = "How did Arya Stark's character get portrayed in a television adaptation?"
result_1 = pipe.run(query=query_1, params={"Retriever": {"top_k": 3}})
print(f"Query: {query_1}")
print(f"Answer: {result_1['answers'][0]}")
print()
query_2 = "What kind of character does Arya Stark play?"
result_2 = pipe.run(query=query_2, params={"Retriever": {"top_k": 1}})
query_2 = "Why is Arya Stark an unusual character?"
result_2 = pipe.run(query=query_2, params={"Retriever": {"top_k": 3}})
print(f"Query: {query_2}")
print(f"Answer: {result_2['answers'][0]}")
print()
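
The script prints only the first Answer object of each result; if fuller output is wanted, `haystack.utils` also provides a `print_answers` helper, shown here on the result dicts above (an optional addition, not part of this commit):

```python
from haystack.utils import print_answers

# details="minimum" keeps only the generated answer text
print_answers(result_1, details="minimum")
print_answers(result_2, details="minimum")
```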