mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-10-12 08:28:25 +00:00
Update LFQA with the latest LFQA seq2seq and retriever models (#2210)
* Register BartEli5Converter for vblagoje/bart_lfqa model * Update LFQA unit tests * Update LFQA tutorials
This commit is contained in:
parent
255226f9fa
commit
6c0094b5ad
@ -32,7 +32,7 @@ Make sure you enable the GPU runtime to experience decent speed in this tutorial
|
|||||||
|
|
||||||
# Install the latest master of Haystack
|
# Install the latest master of Haystack
|
||||||
!pip install --upgrade pip
|
!pip install --upgrade pip
|
||||||
!pip install git+https://github.com/deepset-ai/haystack.git#egg=farm-haystack[colab,faiss]
|
!pip install -q git+https://github.com/deepset-ai/haystack.git#egg=farm-haystack[colab,faiss]
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
@ -80,22 +80,24 @@ document_store.write_documents(dicts)
|
|||||||
|
|
||||||
#### Retriever
|
#### Retriever
|
||||||
|
|
||||||
**Here:** We use a `RetribertRetriever` and we invoke `update_embeddings` to index the embeddings of documents in the `FAISSDocumentStore`
|
We use a `DensePassageRetriever` and we invoke `update_embeddings` to index the embeddings of documents in the `FAISSDocumentStore`
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from haystack.nodes import EmbeddingRetriever
|
from haystack.nodes import DensePassageRetriever
|
||||||
|
|
||||||
retriever = EmbeddingRetriever(
|
retriever = DensePassageRetriever(
|
||||||
document_store=document_store, embedding_model="yjernite/retribert-base-uncased", model_format="retribert"
|
document_store=document_store,
|
||||||
|
query_embedding_model="vblagoje/dpr-question_encoder-single-lfqa-wiki",
|
||||||
|
passage_embedding_model="vblagoje/dpr-ctx_encoder-single-lfqa-wiki",
|
||||||
)
|
)
|
||||||
|
|
||||||
document_store.update_embeddings(retriever)
|
document_store.update_embeddings(retriever)
|
||||||
```
|
```
|
||||||
|
|
||||||
Before we blindly use the `RetribertRetriever` let's empirically test it to make sure a simple search indeed finds the relevant documents.
|
Before we blindly use the `DensePassageRetriever` let's empirically test it to make sure a simple search indeed finds the relevant documents.
|
||||||
|
|
||||||
|
|
||||||
```python
|
```python
|
||||||
@ -111,13 +113,13 @@ print_documents(res, max_text_len=512)
|
|||||||
|
|
||||||
Similar to previous Tutorials we now initialize our reader/generator.
|
Similar to previous Tutorials we now initialize our reader/generator.
|
||||||
|
|
||||||
Here we use a `Seq2SeqGenerator` with the *yjernite/bart_eli5* model (see: https://huggingface.co/yjernite/bart_eli5)
|
Here we use a `Seq2SeqGenerator` with the *vblagoje/bart_lfqa* model (see: https://huggingface.co/vblagoje/bart_lfqa)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
```python
|
```python
|
||||||
generator = Seq2SeqGenerator(model_name_or_path="yjernite/bart_eli5")
|
generator = Seq2SeqGenerator(model_name_or_path="vblagoje/bart_lfqa")
|
||||||
```
|
```
|
||||||
|
|
||||||
### Pipeline
|
### Pipeline
|
||||||
@ -139,13 +141,13 @@ pipe = GenerativeQAPipeline(generator, retriever)
|
|||||||
|
|
||||||
```python
|
```python
|
||||||
pipe.run(
|
pipe.run(
|
||||||
query="Why did Arya Stark's character get portrayed in a television adaptation?", params={"Retriever": {"top_k": 1}}
|
query="How did Arya Stark's character get portrayed in a television adaptation?", params={"Retriever": {"top_k": 3}}
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
```python
|
```python
|
||||||
pipe.run(query="What kind of character does Arya Stark play?", params={"Retriever": {"top_k": 1}})
|
pipe.run(query="Why is Arya Stark an unusual character?", params={"Retriever": {"top_k": 3}})
|
||||||
```
|
```
|
||||||
|
|
||||||
## About us
|
## About us
|
||||||
|
@ -386,7 +386,8 @@ class Seq2SeqGenerator(BaseGenerator):
|
|||||||
def _register_converters(cls, model_name_or_path: str, custom_converter: Optional[Callable]):
|
def _register_converters(cls, model_name_or_path: str, custom_converter: Optional[Callable]):
|
||||||
# init if empty
|
# init if empty
|
||||||
if not cls._model_input_converters:
|
if not cls._model_input_converters:
|
||||||
cls._model_input_converters["yjernite/bart_eli5"] = _BartEli5Converter()
|
for c in ["yjernite/bart_eli5", "vblagoje/bart_lfqa"]:
|
||||||
|
cls._model_input_converters[c] = _BartEli5Converter()
|
||||||
|
|
||||||
# register user provided custom converter
|
# register user provided custom converter
|
||||||
if model_name_or_path and custom_converter:
|
if model_name_or_path and custom_converter:
|
||||||
|
@ -1307,12 +1307,12 @@
|
|||||||
},
|
},
|
||||||
"preceding_context_len": {
|
"preceding_context_len": {
|
||||||
"title": "Preceding Context Len",
|
"title": "Preceding Context Len",
|
||||||
"default": 3,
|
"default": 1,
|
||||||
"type": "integer"
|
"type": "integer"
|
||||||
},
|
},
|
||||||
"following_context_len": {
|
"following_context_len": {
|
||||||
"title": "Following Context Len",
|
"title": "Following Context Len",
|
||||||
"default": 3,
|
"default": 1,
|
||||||
"type": "integer"
|
"type": "integer"
|
||||||
},
|
},
|
||||||
"remove_page_headers": {
|
"remove_page_headers": {
|
||||||
|
@ -304,8 +304,8 @@ def question_generator():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
@pytest.fixture(scope="function")
|
||||||
def eli5_generator():
|
def lfqa_generator(request):
|
||||||
return Seq2SeqGenerator(model_name_or_path="yjernite/bart_eli5", max_length=20)
|
return Seq2SeqGenerator(model_name_or_path=request.param, min_length=100, max_length=200)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
@pytest.fixture(scope="function")
|
||||||
@ -509,6 +509,14 @@ def get_retriever(retriever_type, document_store):
|
|||||||
model_format="retribert",
|
model_format="retribert",
|
||||||
use_gpu=False,
|
use_gpu=False,
|
||||||
)
|
)
|
||||||
|
elif retriever_type == "dpr_lfqa":
|
||||||
|
retriever = DensePassageRetriever(
|
||||||
|
document_store=document_store,
|
||||||
|
query_embedding_model="vblagoje/dpr-question_encoder-single-lfqa-wiki",
|
||||||
|
passage_embedding_model="vblagoje/dpr-ctx_encoder-single-lfqa-wiki",
|
||||||
|
use_gpu=False,
|
||||||
|
embed_title=True,
|
||||||
|
)
|
||||||
elif retriever_type == "elasticsearch":
|
elif retriever_type == "elasticsearch":
|
||||||
retriever = ElasticsearchRetriever(document_store=document_store)
|
retriever = ElasticsearchRetriever(document_store=document_store)
|
||||||
elif retriever_type == "es_filter_only":
|
elif retriever_type == "es_filter_only":
|
||||||
|
@ -60,9 +60,10 @@ def test_generator_pipeline(document_store, retriever, rag_generator):
|
|||||||
@pytest.mark.slow
|
@pytest.mark.slow
|
||||||
@pytest.mark.generator
|
@pytest.mark.generator
|
||||||
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
|
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
|
||||||
@pytest.mark.parametrize("retriever", ["retribert"], indirect=True)
|
@pytest.mark.parametrize("retriever", ["retribert", "dpr_lfqa"], indirect=True)
|
||||||
|
@pytest.mark.parametrize("lfqa_generator", ["yjernite/bart_eli5", "vblagoje/bart_lfqa"], indirect=True)
|
||||||
@pytest.mark.embedding_dim(128)
|
@pytest.mark.embedding_dim(128)
|
||||||
def test_lfqa_pipeline(document_store, retriever, eli5_generator):
|
def test_lfqa_pipeline(document_store, retriever, lfqa_generator):
|
||||||
# reuse existing DOCS but regenerate embeddings with retribert
|
# reuse existing DOCS but regenerate embeddings with retribert
|
||||||
docs: List[Document] = []
|
docs: List[Document] = []
|
||||||
for idx, d in enumerate(DOCS_WITH_EMBEDDINGS):
|
for idx, d in enumerate(DOCS_WITH_EMBEDDINGS):
|
||||||
@ -70,11 +71,11 @@ def test_lfqa_pipeline(document_store, retriever, eli5_generator):
|
|||||||
document_store.write_documents(docs)
|
document_store.write_documents(docs)
|
||||||
document_store.update_embeddings(retriever)
|
document_store.update_embeddings(retriever)
|
||||||
query = "Tell me about Berlin?"
|
query = "Tell me about Berlin?"
|
||||||
pipeline = GenerativeQAPipeline(retriever=retriever, generator=eli5_generator)
|
pipeline = GenerativeQAPipeline(generator=lfqa_generator, retriever=retriever)
|
||||||
output = pipeline.run(query=query, params={"top_k": 1})
|
output = pipeline.run(query=query, params={"top_k": 1})
|
||||||
answers = output["answers"]
|
answers = output["answers"]
|
||||||
assert len(answers) == 1
|
assert len(answers) == 1, answers
|
||||||
assert "Germany" in answers[0].answer
|
assert "Germany" in answers[0].answer, answers[0].answer
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
@pytest.mark.slow
|
||||||
|
@ -51,7 +51,7 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"# Install the latest master of Haystack\n",
|
"# Install the latest master of Haystack\n",
|
||||||
"!pip install --upgrade pip\n",
|
"!pip install --upgrade pip\n",
|
||||||
"!pip install git+https://github.com/deepset-ai/haystack.git#egg=farm-haystack[colab,faiss]"
|
"!pip install -q git+https://github.com/deepset-ai/haystack.git#egg=farm-haystack[colab,faiss]"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -146,7 +146,7 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"#### Retriever\n",
|
"#### Retriever\n",
|
||||||
"\n",
|
"\n",
|
||||||
"**Here:** We use a `RetribertRetriever` and we invoke `update_embeddings` to index the embeddings of documents in the `FAISSDocumentStore`\n",
|
"We use a `DensePassageRetriever` and we invoke `update_embeddings` to index the embeddings of documents in the `FAISSDocumentStore`\n",
|
||||||
"\n"
|
"\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@ -161,10 +161,12 @@
|
|||||||
},
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from haystack.nodes import EmbeddingRetriever\n",
|
"from haystack.nodes import DensePassageRetriever\n",
|
||||||
"\n",
|
"\n",
|
||||||
"retriever = EmbeddingRetriever(\n",
|
"retriever = DensePassageRetriever(\n",
|
||||||
" document_store=document_store, embedding_model=\"yjernite/retribert-base-uncased\", model_format=\"retribert\"\n",
|
" document_store=document_store,\n",
|
||||||
|
" query_embedding_model=\"vblagoje/dpr-question_encoder-single-lfqa-wiki\",\n",
|
||||||
|
" passage_embedding_model=\"vblagoje/dpr-ctx_encoder-single-lfqa-wiki\",\n",
|
||||||
")\n",
|
")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"document_store.update_embeddings(retriever)"
|
"document_store.update_embeddings(retriever)"
|
||||||
@ -176,25 +178,16 @@
|
|||||||
"id": "sMlVEnJ2NkZZ"
|
"id": "sMlVEnJ2NkZZ"
|
||||||
},
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"Before we blindly use the `RetribertRetriever` let's empirically test it to make sure a simple search indeed finds the relevant documents."
|
"Before we blindly use the `DensePassageRetriever` let's empirically test it to make sure a simple search indeed finds the relevant documents."
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 1,
|
"execution_count": null,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "qpu-t9rndgpe"
|
"id": "qpu-t9rndgpe"
|
||||||
},
|
},
|
||||||
"outputs": [
|
"outputs": [],
|
||||||
{
|
|
||||||
"ename": "SyntaxError",
|
|
||||||
"evalue": "EOL while scanning string literal (<ipython-input-1-cc681f017dc5>, line 7)",
|
|
||||||
"output_type": "error",
|
|
||||||
"traceback": [
|
|
||||||
"\u001b[0;36m File \u001b[0;32m\"<ipython-input-1-cc681f017dc5>\"\u001b[0;36m, line \u001b[0;32m7\u001b[0m\n\u001b[0;31m params={\"top_k_retriever=5\u001b[0m\n\u001b[0m ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m EOL while scanning string literal\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
"source": [
|
||||||
"from haystack.utils import print_documents\n",
|
"from haystack.utils import print_documents\n",
|
||||||
"from haystack.pipelines import DocumentSearchPipeline\n",
|
"from haystack.pipelines import DocumentSearchPipeline\n",
|
||||||
@ -214,7 +207,7 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"Similar to previous Tutorials we now initalize our reader/generator.\n",
|
"Similar to previous Tutorials we now initalize our reader/generator.\n",
|
||||||
"\n",
|
"\n",
|
||||||
"Here we use a `Seq2SeqGenerator` with the *yjernite/bart_eli5* model (see: https://huggingface.co/yjernite/bart_eli5)\n",
|
"Here we use a `Seq2SeqGenerator` with the *vblagoje/bart_lfqa* model (see: https://huggingface.co/vblagoje/bart_lfqa)\n",
|
||||||
"\n"
|
"\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@ -226,7 +219,7 @@
|
|||||||
},
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"generator = Seq2SeqGenerator(model_name_or_path=\"yjernite/bart_eli5\")"
|
"generator = Seq2SeqGenerator(model_name_or_path=\"vblagoje/bart_lfqa\")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -274,25 +267,26 @@
|
|||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"pipe.run(\n",
|
"pipe.run(\n",
|
||||||
" query=\"Why did Arya Stark's character get portrayed in a television adaptation?\", params={\"Retriever\": {\"top_k\": 1}}\n",
|
" query=\"How did Arya Stark's character get portrayed in a television adaptation?\", params={\"Retriever\": {\"top_k\": 3}}\n",
|
||||||
")"
|
")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"id": "zvHb8SvMblw9"
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
"source": [
|
||||||
"pipe.run(query=\"What kind of character does Arya Stark play?\", params={\"Retriever\": {\"top_k\": 1}})"
|
"pipe.run(query=\"Why is Arya Stark an unusual character?\", params={\"Retriever\": {\"top_k\": 3}})"
|
||||||
]
|
],
|
||||||
|
"metadata": {
|
||||||
|
"id": "IfTP9BfFGOo6"
|
||||||
|
},
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": []
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"collapsed": false
|
"collapsed": false,
|
||||||
|
"id": "i88KdOc2wUXQ"
|
||||||
},
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"## About us\n",
|
"## About us\n",
|
||||||
@ -340,5 +334,5 @@
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
"nbformat_minor": 2
|
"nbformat_minor": 0
|
||||||
}
|
}
|
||||||
|
@ -37,18 +37,20 @@ def tutorial12_lfqa():
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
Initialize Retriever and Reader/Generator:
|
Initialize Retriever and Reader/Generator:
|
||||||
We use a `RetribertRetriever` and we invoke `update_embeddings` to index the embeddings of documents in the `FAISSDocumentStore`
|
We use a `DensePassageRetriever` and we invoke `update_embeddings` to index the embeddings of documents in the `FAISSDocumentStore`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from haystack.nodes import EmbeddingRetriever
|
from haystack.nodes import DensePassageRetriever
|
||||||
|
|
||||||
retriever = EmbeddingRetriever(
|
retriever = DensePassageRetriever(
|
||||||
document_store=document_store, embedding_model="yjernite/retribert-base-uncased", model_format="retribert"
|
document_store=document_store,
|
||||||
|
query_embedding_model="vblagoje/dpr-question_encoder-single-lfqa-wiki",
|
||||||
|
passage_embedding_model="vblagoje/dpr-ctx_encoder-single-lfqa-wiki",
|
||||||
)
|
)
|
||||||
|
|
||||||
document_store.update_embeddings(retriever)
|
document_store.update_embeddings(retriever)
|
||||||
|
|
||||||
"""Before we blindly use the `RetribertRetriever` let's empirically test it to make sure a simple search indeed finds the relevant documents."""
|
"""Before we blindly use the `DensePassageRetriever` let's empirically test it to make sure a simple search indeed finds the relevant documents."""
|
||||||
|
|
||||||
from haystack.utils import print_documents
|
from haystack.utils import print_documents
|
||||||
from haystack.pipelines import DocumentSearchPipeline
|
from haystack.pipelines import DocumentSearchPipeline
|
||||||
@ -59,10 +61,10 @@ def tutorial12_lfqa():
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
Similar to previous Tutorials we now initialize our reader/generator.
|
Similar to previous Tutorials we now initialize our reader/generator.
|
||||||
Here we use a `Seq2SeqGenerator` with the *yjernite/bart_eli5* model (see: https://huggingface.co/yjernite/bart_eli5)
|
Here we use a `Seq2SeqGenerator` with the *vblagoje/bart_lfqa* model (see: https://huggingface.co/vblagoje/bart_lfqa)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
generator = Seq2SeqGenerator(model_name_or_path="yjernite/bart_eli5")
|
generator = Seq2SeqGenerator(model_name_or_path="vblagoje/bart_lfqa")
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Pipeline:
|
Pipeline:
|
||||||
@ -78,14 +80,14 @@ def tutorial12_lfqa():
|
|||||||
|
|
||||||
"""Voilà! Ask a question!"""
|
"""Voilà! Ask a question!"""
|
||||||
|
|
||||||
query_1 = "Why did Arya Stark's character get portrayed in a television adaptation?"
|
query_1 = "How did Arya Stark's character get portrayed in a television adaptation?"
|
||||||
result_1 = pipe.run(query=query_1, params={"Retriever": {"top_k": 1}})
|
result_1 = pipe.run(query=query_1, params={"Retriever": {"top_k": 3}})
|
||||||
print(f"Query: {query_1}")
|
print(f"Query: {query_1}")
|
||||||
print(f"Answer: {result_1['answers'][0]}")
|
print(f"Answer: {result_1['answers'][0]}")
|
||||||
print()
|
print()
|
||||||
|
|
||||||
query_2 = "What kind of character does Arya Stark play?"
|
query_2 = "Why is Arya Stark an unusual character?"
|
||||||
result_2 = pipe.run(query=query_2, params={"Retriever": {"top_k": 1}})
|
result_2 = pipe.run(query=query_2, params={"Retriever": {"top_k": 3}})
|
||||||
print(f"Query: {query_2}")
|
print(f"Query: {query_2}")
|
||||||
print(f"Answer: {result_2['answers'][0]}")
|
print(f"Answer: {result_2['answers'][0]}")
|
||||||
print()
|
print()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user