test: Added integration test for using EntityExtractor in query pipeline (#4117)

* Added a new test for using EntityExtractor in a query pipeline and created fixtures to reduce code duplication.

* Reuse the ner_node fixture.

* Added pytest unit markers and swapped over to the in-memory document store.

* Changed to integration tests.
Sebastian committed 2023-02-28 09:20:44 +01:00 (via GitHub)
commit 040d806b42 (parent 5678bb6375)


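For context, here is the fixture-and-marker pattern the bullets describe, condensed from the diff below into a standalone sketch. One assumption is flagged in the comments: the document_store_with_docs fixture is defined in the suite's conftest, and parametrizing it with "memory" is assumed to resolve to Haystack's InMemoryDocumentStore preloaded with the sample documents.

import pytest

from haystack.nodes.reader import FARMReader
from haystack.nodes.extractor import EntityExtractor


# Shared fixtures replace the per-test model construction each test used to repeat.
@pytest.fixture
def tiny_reader():
    # A small extractive QA model keeps the integration tests reasonably fast.
    return FARMReader(model_name_or_path="deepset/tinyroberta-squad2", num_processes=0)


@pytest.fixture
def ner_node():
    return EntityExtractor(model_name_or_path="elastic/distilbert-base-cased-finetuned-conll03-english")


@pytest.mark.integration  # selectable via `pytest -m integration`, excludable via `-m "not integration"`
@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)  # assumed: conftest fixture yielding an in-memory store with sample docs
def test_extractor(document_store_with_docs, tiny_reader, ner_node):
    ...  # body shown in the diff below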
@@ -4,22 +4,32 @@ from haystack.nodes import TextConverter
 from haystack.nodes.retriever.sparse import BM25Retriever
 from haystack.nodes.reader import FARMReader
 from haystack.pipelines import Pipeline
+from haystack import Document
 from haystack.nodes.extractor import EntityExtractor, simplify_ner_for_qa
 
 from ..conftest import SAMPLES_PATH
 
 
-@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
-def test_extractor(document_store_with_docs):
+@pytest.fixture
+def tiny_reader():
+    return FARMReader(model_name_or_path="deepset/tinyroberta-squad2", num_processes=0)
+
+
+@pytest.fixture
+def ner_node():
+    return EntityExtractor(model_name_or_path="elastic/distilbert-base-cased-finetuned-conll03-english")
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
+def test_extractor(document_store_with_docs, tiny_reader, ner_node):
     es_retriever = BM25Retriever(document_store=document_store_with_docs)
-    ner = EntityExtractor(model_name_or_path="elastic/distilbert-base-cased-finetuned-conll03-english")
-    reader = FARMReader(model_name_or_path="deepset/tinyroberta-squad2", num_processes=0)
 
     pipeline = Pipeline()
     pipeline.add_node(component=es_retriever, name="ESRetriever", inputs=["Query"])
-    pipeline.add_node(component=ner, name="NER", inputs=["ESRetriever"])
-    pipeline.add_node(component=reader, name="Reader", inputs=["NER"])
+    pipeline.add_node(component=ner_node, name="NER", inputs=["ESRetriever"])
+    pipeline.add_node(component=tiny_reader, name="Reader", inputs=["NER"])
     prediction = pipeline.run(
         query="Who lives in Berlin?", params={"ESRetriever": {"top_k": 1}, "Reader": {"top_k": 1}}
@@ -29,16 +39,15 @@ def test_extractor(document_store_with_docs):
     assert "Berlin" in entities
 
 
-@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
-def test_extractor_batch_single_query(document_store_with_docs):
+@pytest.mark.integration
+@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
+def test_extractor_batch_single_query(document_store_with_docs, tiny_reader, ner_node):
     es_retriever = BM25Retriever(document_store=document_store_with_docs)
-    ner = EntityExtractor(model_name_or_path="elastic/distilbert-base-cased-finetuned-conll03-english")
-    reader = FARMReader(model_name_or_path="deepset/tinyroberta-squad2", num_processes=0)
 
     pipeline = Pipeline()
     pipeline.add_node(component=es_retriever, name="ESRetriever", inputs=["Query"])
-    pipeline.add_node(component=ner, name="NER", inputs=["ESRetriever"])
-    pipeline.add_node(component=reader, name="Reader", inputs=["NER"])
+    pipeline.add_node(component=ner_node, name="NER", inputs=["ESRetriever"])
+    pipeline.add_node(component=tiny_reader, name="Reader", inputs=["NER"])
     prediction = pipeline.run_batch(
         queries=["Who lives in Berlin?"], params={"ESRetriever": {"top_k": 1}, "Reader": {"top_k": 1}}
@@ -48,16 +57,15 @@ def test_extractor_batch_single_query(document_store_with_docs):
     assert "Berlin" in entities
 
 
-@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
-def test_extractor_batch_multiple_queries(document_store_with_docs):
+@pytest.mark.integration
+@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
+def test_extractor_batch_multiple_queries(document_store_with_docs, tiny_reader, ner_node):
     es_retriever = BM25Retriever(document_store=document_store_with_docs)
-    ner = EntityExtractor(model_name_or_path="elastic/distilbert-base-cased-finetuned-conll03-english")
-    reader = FARMReader(model_name_or_path="deepset/tinyroberta-squad2", num_processes=0)
 
     pipeline = Pipeline()
     pipeline.add_node(component=es_retriever, name="ESRetriever", inputs=["Query"])
-    pipeline.add_node(component=ner, name="NER", inputs=["ESRetriever"])
-    pipeline.add_node(component=reader, name="Reader", inputs=["NER"])
+    pipeline.add_node(component=ner_node, name="NER", inputs=["ESRetriever"])
+    pipeline.add_node(component=tiny_reader, name="Reader", inputs=["NER"])
     prediction = pipeline.run_batch(
         queries=["Who lives in Berlin?", "Who lives in New York?"],
@@ -71,16 +79,15 @@ def test_extractor_batch_multiple_queries(document_store_with_docs):
     assert "New York" in entities_paul
 
 
-@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
-def test_extractor_output_simplifier(document_store_with_docs):
+@pytest.mark.integration
+@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
+def test_extractor_output_simplifier(document_store_with_docs, tiny_reader, ner_node):
     es_retriever = BM25Retriever(document_store=document_store_with_docs)
-    ner = EntityExtractor(model_name_or_path="elastic/distilbert-base-cased-finetuned-conll03-english")
-    reader = FARMReader(model_name_or_path="deepset/tinyroberta-squad2", num_processes=0)
 
     pipeline = Pipeline()
     pipeline.add_node(component=es_retriever, name="ESRetriever", inputs=["Query"])
-    pipeline.add_node(component=ner, name="NER", inputs=["ESRetriever"])
-    pipeline.add_node(component=reader, name="Reader", inputs=["NER"])
+    pipeline.add_node(component=ner_node, name="NER", inputs=["ESRetriever"])
+    pipeline.add_node(component=tiny_reader, name="Reader", inputs=["NER"])
     prediction = pipeline.run(
         query="Who lives in Berlin?", params={"ESRetriever": {"top_k": 1}, "Reader": {"top_k": 1}}
@@ -89,7 +96,8 @@ def test_extractor_output_simplifier(document_store_with_docs):
     assert simplified[0] == {"answer": "Carla and I", "entities": ["Carla"]}
 
 
-@pytest.mark.parametrize("document_store", ["elasticsearch"], indirect=True)
+@pytest.mark.integration
+@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
 def test_extractor_indexing(document_store):
     doc_path = SAMPLES_PATH / "docs" / "doc_2.txt"
@@ -109,6 +117,18 @@ def test_extractor_indexing(document_store):
     assert "Haystack" in meta["entity_words"]
 
 
+@pytest.mark.integration
+def test_extractor_doc_query(ner_node):
+    pipeline = Pipeline()
+    pipeline.add_node(component=ner_node, name="NER", inputs=["Query"])
+
+    prediction = pipeline.run(query=None, documents=[Document(content="Carla lives in Berlin", content_type="text")])
+    entities = [x["word"] for x in prediction["documents"][0].meta["entities"]]
+    assert "Carla" in entities
+    assert "Berlin" in entities
+
+
+@pytest.mark.integration
 def test_extract_method():
     ner = EntityExtractor(
         model_name_or_path="Jean-Baptiste/camembert-ner", max_seq_len=12, aggregation_strategy="first"
@@ -138,6 +158,7 @@ def test_extract_method():
     ]
 
 
+@pytest.mark.integration
 def test_extract_method_pre_split_text():
     ner = EntityExtractor(
         model_name_or_path="elastic/distilbert-base-cased-finetuned-conll03-english", max_seq_len=6, pre_split_text=True
@@ -167,6 +188,7 @@ def test_extract_method_pre_split_text():
     ]
 
 
+@pytest.mark.integration
 def test_extract_method_unknown_token():
     ner = EntityExtractor(
         model_name_or_path="elastic/distilbert-base-cased-finetuned-conll03-english",
@@ -196,6 +218,7 @@ def test_extract_method_unknown_token():
     assert output == [{"entity_group": "O", "word": "Hi my name is JamesÐ.", "start": 0, "end": 21}]
 
 
+@pytest.mark.integration
 def test_extract_method_simple_aggregation():
     ner = EntityExtractor(
         model_name_or_path="elastic/distilbert-base-cased-finetuned-conll03-english",