Update LFQA with the latest LFQA seq2seq and retriever models (#2210)
* Register BartEli5Converter for vblagoje/bart_lfqa model
* Update LFQA unit tests
* Update LFQA tutorials
parent 255226f9fa
commit 6c0094b5ad
@@ -32,7 +32,7 @@ Make sure you enable the GPU runtime to experience decent speed in this tutorial
 # Install the latest master of Haystack
 !pip install --upgrade pip
-!pip install git+https://github.com/deepset-ai/haystack.git#egg=farm-haystack[colab,faiss]
+!pip install -q git+https://github.com/deepset-ai/haystack.git#egg=farm-haystack[colab,faiss]
 ```
@@ -80,22 +80,24 @@ document_store.write_documents(dicts)
 #### Retriever
 
-**Here:** We use a `RetribertRetriever` and we invoke `update_embeddings` to index the embeddings of documents in the `FAISSDocumentStore`
+We use a `DensePassageRetriever` and we invoke `update_embeddings` to index the embeddings of documents in the `FAISSDocumentStore`
 
 
 ```python
-from haystack.nodes import EmbeddingRetriever
+from haystack.nodes import DensePassageRetriever
 
-retriever = EmbeddingRetriever(
-    document_store=document_store, embedding_model="yjernite/retribert-base-uncased", model_format="retribert"
+retriever = DensePassageRetriever(
+    document_store=document_store,
+    query_embedding_model="vblagoje/dpr-question_encoder-single-lfqa-wiki",
+    passage_embedding_model="vblagoje/dpr-ctx_encoder-single-lfqa-wiki",
 )
 
 document_store.update_embeddings(retriever)
 ```
 
-Before we blindly use the `RetribertRetriever` let's empirically test it to make sure a simple search indeed finds the relevant documents.
+Before we blindly use the `DensePassageRetriever` let's empirically test it to make sure a simple search indeed finds the relevant documents.
 
 
 ```python
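The hunk above swaps the `RetribertRetriever` setup for a `DensePassageRetriever` built on the two `vblagoje` LFQA checkpoints. Below is a minimal, self-contained sketch of sanity-checking the new retriever outside the tutorial; the in-memory store and toy document are illustrative, and the 128-dim setting mirrors the `embedding_dim(128)` marker used in the tests later in this diff.

```python
# Sanity-check sketch for the new LFQA DPR retriever (toy data, not the tutorial corpus).
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import DensePassageRetriever
from haystack.pipelines import DocumentSearchPipeline
from haystack.utils import print_documents

# These DPR checkpoints emit 128-dim vectors, hence embedding_dim=128.
document_store = InMemoryDocumentStore(embedding_dim=128)
document_store.write_documents([{"content": "Arya Stark is a fictional character in Game of Thrones."}])

retriever = DensePassageRetriever(
    document_store=document_store,
    query_embedding_model="vblagoje/dpr-question_encoder-single-lfqa-wiki",
    passage_embedding_model="vblagoje/dpr-ctx_encoder-single-lfqa-wiki",
)
document_store.update_embeddings(retriever)

# Check that a plain document search finds the relevant document before building the LFQA pipeline.
pipeline = DocumentSearchPipeline(retriever)
res = pipeline.run(query="Who is Arya Stark?", params={"Retriever": {"top_k": 3}})
print_documents(res, max_text_len=512)
```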
@@ -111,13 +113,13 @@ print_documents(res, max_text_len=512)
 Similar to previous Tutorials we now initalize our reader/generator.
 
-Here we use a `Seq2SeqGenerator` with the *yjernite/bart_eli5* model (see: https://huggingface.co/yjernite/bart_eli5)
+Here we use a `Seq2SeqGenerator` with the *vblagoje/bart_lfqa* model (see: https://huggingface.co/vblagoje/bart_lfqa)
 
 
 ```python
-generator = Seq2SeqGenerator(model_name_or_path="yjernite/bart_eli5")
+generator = Seq2SeqGenerator(model_name_or_path="vblagoje/bart_lfqa")
 ```
 
 ### Pipeline
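The new generator can also be smoke-tested on its own, without a retriever, by passing documents directly to `predict`. A small sketch; the toy document and query are illustrative.

```python
from haystack.nodes import Seq2SeqGenerator
from haystack.schema import Document

generator = Seq2SeqGenerator(model_name_or_path="vblagoje/bart_lfqa")

# Feed documents straight to the generator, bypassing any retriever.
docs = [Document(content="Berlin is the capital and largest city of Germany.")]
result = generator.predict(query="Tell me about Berlin?", documents=docs, top_k=1)
print(result["answers"][0].answer)
```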
@@ -139,13 +141,13 @@ pipe = GenerativeQAPipeline(generator, retriever)
 ```python
 pipe.run(
-    query="Why did Arya Stark's character get portrayed in a television adaptation?", params={"Retriever": {"top_k": 1}}
+    query="How did Arya Stark's character get portrayed in a television adaptation?", params={"Retriever": {"top_k": 3}}
 )
 ```
 
 
 ```python
-pipe.run(query="What kind of character does Arya Stark play?", params={"Retriever": {"top_k": 1}})
+pipe.run(query="Why is Arya Stark an unusual character?", params={"Retriever": {"top_k": 3}})
 ```
 
 ## About us
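To inspect the generated long-form answers rather than only returning the raw run output, the result can be handed to Haystack's answer printer. A short sketch that reuses the `pipe` built above; `print_answers` is the same helper the other tutorials rely on.

```python
from haystack.utils import print_answers

# Run one of the updated queries and pretty-print only the answer text.
result = pipe.run(
    query="Why is Arya Stark an unusual character?", params={"Retriever": {"top_k": 3}}
)
print_answers(result, details="minimum")
```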
@@ -386,7 +386,8 @@ class Seq2SeqGenerator(BaseGenerator):
     def _register_converters(cls, model_name_or_path: str, custom_converter: Optional[Callable]):
         # init if empty
         if not cls._model_input_converters:
-            cls._model_input_converters["yjernite/bart_eli5"] = _BartEli5Converter()
+            for c in ["yjernite/bart_eli5", "vblagoje/bart_lfqa"]:
+                cls._model_input_converters[c] = _BartEli5Converter()
 
         # register user provided custom converter
         if model_name_or_path and custom_converter:
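This change registers the built-in `_BartEli5Converter` for both `yjernite/bart_eli5` and `vblagoje/bart_lfqa`. For any other seq2seq checkpoint, a converter can still be supplied through the `custom_converter` argument handled by the following `if` branch. A hedged sketch of such a converter; the class name, placeholder model name, and prompt format are illustrative assumptions modelled on the ELI5-style converter, not part of this commit.

```python
from typing import List, Optional

from haystack.nodes import Seq2SeqGenerator
from haystack.schema import Document


class MyLFQAInputConverter:
    """Hypothetical converter: turns a query plus retrieved documents into tokenized model input."""

    def __call__(self, tokenizer, query: str, documents: List[Document], top_k: Optional[int] = None):
        # Assumed ELI5-style prompt: the question followed by <P>-separated passages.
        context = " <P> ".join(doc.content for doc in documents)
        prompt = f"question: {query} context: {context}"
        return tokenizer([prompt], truncation=True, padding=True, return_tensors="pt")


generator = Seq2SeqGenerator(
    model_name_or_path="my-org/my-lfqa-model",  # placeholder model name
    custom_converter=MyLFQAInputConverter(),
)
```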
@@ -1307,12 +1307,12 @@
         },
         "preceding_context_len": {
             "title": "Preceding Context Len",
-            "default": 3,
+            "default": 1,
             "type": "integer"
         },
         "following_context_len": {
             "title": "Following Context Len",
-            "default": 3,
+            "default": 1,
             "type": "integer"
         },
         "remove_page_headers": {
@@ -304,8 +304,8 @@ def question_generator():
 
 
 @pytest.fixture(scope="function")
-def eli5_generator():
-    return Seq2SeqGenerator(model_name_or_path="yjernite/bart_eli5", max_length=20)
+def lfqa_generator(request):
+    return Seq2SeqGenerator(model_name_or_path=request.param, min_length=100, max_length=200)
 
 
 @pytest.fixture(scope="function")
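The renamed `lfqa_generator` fixture reads its model name from `request.param`, which only takes effect together with pytest's `indirect=True` parametrization used further down in this diff. A standalone illustration of the mechanism, with purely illustrative names:

```python
import pytest


@pytest.fixture(scope="function")
def model_name(request):
    # Filled in by the parametrize decorator below because indirect=True.
    return request.param


@pytest.mark.parametrize("model_name", ["yjernite/bart_eli5", "vblagoje/bart_lfqa"], indirect=True)
def test_model_name_is_passed_through(model_name):
    assert model_name.endswith(("bart_eli5", "bart_lfqa"))
```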
@@ -509,6 +509,14 @@ def get_retriever(retriever_type, document_store):
             model_format="retribert",
             use_gpu=False,
         )
+    elif retriever_type == "dpr_lfqa":
+        retriever = DensePassageRetriever(
+            document_store=document_store,
+            query_embedding_model="vblagoje/dpr-question_encoder-single-lfqa-wiki",
+            passage_embedding_model="vblagoje/dpr-ctx_encoder-single-lfqa-wiki",
+            use_gpu=False,
+            embed_title=True,
+        )
     elif retriever_type == "elasticsearch":
         retriever = ElasticsearchRetriever(document_store=document_store)
     elif retriever_type == "es_filter_only":
@@ -60,9 +60,10 @@ def test_generator_pipeline(document_store, retriever, rag_generator):
 @pytest.mark.slow
 @pytest.mark.generator
 @pytest.mark.parametrize("document_store", ["memory"], indirect=True)
-@pytest.mark.parametrize("retriever", ["retribert"], indirect=True)
+@pytest.mark.parametrize("retriever", ["retribert", "dpr_lfqa"], indirect=True)
+@pytest.mark.parametrize("lfqa_generator", ["yjernite/bart_eli5", "vblagoje/bart_lfqa"], indirect=True)
 @pytest.mark.embedding_dim(128)
-def test_lfqa_pipeline(document_store, retriever, eli5_generator):
+def test_lfqa_pipeline(document_store, retriever, lfqa_generator):
     # reuse existing DOCS but regenerate embeddings with retribert
     docs: List[Document] = []
     for idx, d in enumerate(DOCS_WITH_EMBEDDINGS):
@@ -70,11 +71,11 @@ def test_lfqa_pipeline(document_store, retriever, eli5_generator):
     document_store.write_documents(docs)
     document_store.update_embeddings(retriever)
     query = "Tell me about Berlin?"
-    pipeline = GenerativeQAPipeline(retriever=retriever, generator=eli5_generator)
+    pipeline = GenerativeQAPipeline(generator=lfqa_generator, retriever=retriever)
     output = pipeline.run(query=query, params={"top_k": 1})
     answers = output["answers"]
-    assert len(answers) == 1
-    assert "Germany" in answers[0].answer
+    assert len(answers) == 1, answers
+    assert "Germany" in answers[0].answer, answers[0].answer
 
 
 @pytest.mark.slow
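Note that the test passes `params={"top_k": 1}` without naming a node: in a Haystack v1 pipeline a top-level parameter is broadcast to every node that accepts it, while node-scoped dictionaries target a single component. A short sketch of the two forms, assuming a `GenerativeQAPipeline` as in the test above:

```python
# Broadcast form: every node with a `top_k` argument (Retriever and Generator) gets top_k=1.
output = pipeline.run(query="Tell me about Berlin?", params={"top_k": 1})

# Node-scoped form: retrieve three documents but keep a single generated answer.
output = pipeline.run(
    query="Tell me about Berlin?",
    params={"Retriever": {"top_k": 3}, "Generator": {"top_k": 1}},
)
```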
@@ -51,7 +51,7 @@
     "\n",
     "# Install the latest master of Haystack\n",
     "!pip install --upgrade pip\n",
-    "!pip install git+https://github.com/deepset-ai/haystack.git#egg=farm-haystack[colab,faiss]"
+    "!pip install -q git+https://github.com/deepset-ai/haystack.git#egg=farm-haystack[colab,faiss]"
    ]
   },
   {
@@ -146,7 +146,7 @@
     "\n",
     "#### Retriever\n",
     "\n",
-    "**Here:** We use a `RetribertRetriever` and we invoke `update_embeddings` to index the embeddings of documents in the `FAISSDocumentStore`\n",
+    "We use a `DensePassageRetriever` and we invoke `update_embeddings` to index the embeddings of documents in the `FAISSDocumentStore`\n",
     "\n"
    ]
   },
@@ -161,10 +161,12 @@
    },
    "outputs": [],
    "source": [
-    "from haystack.nodes import EmbeddingRetriever\n",
+    "from haystack.nodes import DensePassageRetriever\n",
     "\n",
-    "retriever = EmbeddingRetriever(\n",
-    "    document_store=document_store, embedding_model=\"yjernite/retribert-base-uncased\", model_format=\"retribert\"\n",
+    "retriever = DensePassageRetriever(\n",
+    "    document_store=document_store,\n",
+    "    query_embedding_model=\"vblagoje/dpr-question_encoder-single-lfqa-wiki\",\n",
+    "    passage_embedding_model=\"vblagoje/dpr-ctx_encoder-single-lfqa-wiki\",\n",
     ")\n",
     "\n",
     "document_store.update_embeddings(retriever)"
@@ -176,25 +178,16 @@
     "id": "sMlVEnJ2NkZZ"
    },
    "source": [
-    "Before we blindly use the `RetribertRetriever` let's empirically test it to make sure a simple search indeed finds the relevant documents."
+    "Before we blindly use the `DensePassageRetriever` let's empirically test it to make sure a simple search indeed finds the relevant documents."
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": null,
    "metadata": {
     "id": "qpu-t9rndgpe"
    },
-   "outputs": [
-    {
-     "ename": "SyntaxError",
-     "evalue": "EOL while scanning string literal (<ipython-input-1-cc681f017dc5>, line 7)",
-     "output_type": "error",
-     "traceback": [
-      "\u001b[0;36m File \u001b[0;32m\"<ipython-input-1-cc681f017dc5>\"\u001b[0;36m, line \u001b[0;32m7\u001b[0m\n\u001b[0;31m params={\"top_k_retriever=5\u001b[0m\n\u001b[0m ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m EOL while scanning string literal\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "from haystack.utils import print_documents\n",
     "from haystack.pipelines import DocumentSearchPipeline\n",
@@ -214,7 +207,7 @@
     "\n",
     "Similar to previous Tutorials we now initalize our reader/generator.\n",
     "\n",
-    "Here we use a `Seq2SeqGenerator` with the *yjernite/bart_eli5* model (see: https://huggingface.co/yjernite/bart_eli5)\n",
+    "Here we use a `Seq2SeqGenerator` with the *vblagoje/bart_lfqa* model (see: https://huggingface.co/vblagoje/bart_lfqa)\n",
     "\n"
    ]
   },
@@ -226,7 +219,7 @@
    },
    "outputs": [],
    "source": [
-    "generator = Seq2SeqGenerator(model_name_or_path=\"yjernite/bart_eli5\")"
+    "generator = Seq2SeqGenerator(model_name_or_path=\"vblagoje/bart_lfqa\")"
    ]
   },
   {
@@ -274,25 +267,26 @@
    "outputs": [],
    "source": [
     "pipe.run(\n",
-    "    query=\"Why did Arya Stark's character get portrayed in a television adaptation?\", params={\"Retriever\": {\"top_k\": 1}}\n",
+    "    query=\"How did Arya Stark's character get portrayed in a television adaptation?\", params={\"Retriever\": {\"top_k\": 3}}\n",
     ")"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "id": "zvHb8SvMblw9"
-   },
-   "outputs": [],
    "source": [
-    "pipe.run(query=\"What kind of character does Arya Stark play?\", params={\"Retriever\": {\"top_k\": 1}})"
-   ]
+    "pipe.run(query=\"Why is Arya Stark an unusual character?\", params={\"Retriever\": {\"top_k\": 3}})"
+   ],
+   "metadata": {
+    "id": "IfTP9BfFGOo6"
+   },
+   "execution_count": null,
+   "outputs": []
   },
   {
    "cell_type": "markdown",
    "metadata": {
-    "collapsed": false
+    "collapsed": false,
+    "id": "i88KdOc2wUXQ"
    },
    "source": [
     "## About us\n",
@@ -340,5 +334,5 @@
   }
  },
  "nbformat": 4,
- "nbformat_minor": 2
+ "nbformat_minor": 0
 }
@@ -37,18 +37,20 @@ def tutorial12_lfqa():
 
     """
     Initalize Retriever and Reader/Generator:
-    We use a `RetribertRetriever` and we invoke `update_embeddings` to index the embeddings of documents in the `FAISSDocumentStore`
+    We use a `DensePassageRetriever` and we invoke `update_embeddings` to index the embeddings of documents in the `FAISSDocumentStore`
     """
 
-    from haystack.nodes import EmbeddingRetriever
+    from haystack.nodes import DensePassageRetriever
 
-    retriever = EmbeddingRetriever(
-        document_store=document_store, embedding_model="yjernite/retribert-base-uncased", model_format="retribert"
+    retriever = DensePassageRetriever(
+        document_store=document_store,
+        query_embedding_model="vblagoje/dpr-question_encoder-single-lfqa-wiki",
+        passage_embedding_model="vblagoje/dpr-ctx_encoder-single-lfqa-wiki",
     )
 
     document_store.update_embeddings(retriever)
 
-    """Before we blindly use the `RetribertRetriever` let's empirically test it to make sure a simple search indeed finds the relevant documents."""
+    """Before we blindly use the `DensePassageRetriever` let's empirically test it to make sure a simple search indeed finds the relevant documents."""
 
     from haystack.utils import print_documents
     from haystack.pipelines import DocumentSearchPipeline
@@ -59,10 +61,10 @@ def tutorial12_lfqa():
 
     """
     Similar to previous Tutorials we now initalize our reader/generator.
-    Here we use a `Seq2SeqGenerator` with the *yjernite/bart_eli5* model (see: https://huggingface.co/yjernite/bart_eli5)
+    Here we use a `Seq2SeqGenerator` with the *vblagoje/bart_lfqa* model (see: https://huggingface.co/vblagoje/bart_lfqa)
     """
 
-    generator = Seq2SeqGenerator(model_name_or_path="yjernite/bart_eli5")
+    generator = Seq2SeqGenerator(model_name_or_path="vblagoje/bart_lfqa")
 
     """
     Pipeline:
@@ -78,14 +80,14 @@ def tutorial12_lfqa():
 
     """Voilà! Ask a question!"""
 
-    query_1 = "Why did Arya Stark's character get portrayed in a television adaptation?"
-    result_1 = pipe.run(query=query_1, params={"Retriever": {"top_k": 1}})
+    query_1 = "How did Arya Stark's character get portrayed in a television adaptation?"
+    result_1 = pipe.run(query=query_1, params={"Retriever": {"top_k": 3}})
     print(f"Query: {query_1}")
     print(f"Answer: {result_1['answers'][0]}")
     print()
 
-    query_2 = "What kind of character does Arya Stark play?"
-    result_2 = pipe.run(query=query_2, params={"Retriever": {"top_k": 1}})
+    query_2 = "Why is Arya Stark an unusual character?"
+    result_2 = pipe.run(query=query_2, params={"Retriever": {"top_k": 3}})
     print(f"Query: {query_2}")
     print(f"Answer: {result_2['answers'][0]}")
     print()
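The tutorial script prints `result_1['answers'][0]`, which is an `Answer` object, the same type the updated test asserts on via `.answer`. If only the generated text is wanted, the attribute can be printed instead, as in this small sketch:

```python
# `answers` holds Answer objects; `.answer` is the generated text itself.
print(f"Answer: {result_1['answers'][0].answer}")
```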