test: Added integration test for using EntityExtractor in query pipeline (#4117)

* Added new test for using EntityExtractor in query node and made some fixtures to reduce code duplication.

* Reuse ner_node fixture

* Added pytest unit markings and swapped over to in memory doc store.

* Change to integration tests
This commit is contained in:
Sebastian 2023-02-28 09:20:44 +01:00 committed by GitHub
parent 5678bb6375
commit 040d806b42
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -4,22 +4,32 @@ from haystack.nodes import TextConverter
from haystack.nodes.retriever.sparse import BM25Retriever
from haystack.nodes.reader import FARMReader
from haystack.pipelines import Pipeline
from haystack import Document
from haystack.nodes.extractor import EntityExtractor, simplify_ner_for_qa
from ..conftest import SAMPLES_PATH
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
def test_extractor(document_store_with_docs):
@pytest.fixture
def tiny_reader():
return FARMReader(model_name_or_path="deepset/tinyroberta-squad2", num_processes=0)
@pytest.fixture
def ner_node():
return EntityExtractor(model_name_or_path="elastic/distilbert-base-cased-finetuned-conll03-english")
@pytest.mark.integration
@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
def test_extractor(document_store_with_docs, tiny_reader, ner_node):
es_retriever = BM25Retriever(document_store=document_store_with_docs)
ner = EntityExtractor(model_name_or_path="elastic/distilbert-base-cased-finetuned-conll03-english")
reader = FARMReader(model_name_or_path="deepset/tinyroberta-squad2", num_processes=0)
pipeline = Pipeline()
pipeline.add_node(component=es_retriever, name="ESRetriever", inputs=["Query"])
pipeline.add_node(component=ner, name="NER", inputs=["ESRetriever"])
pipeline.add_node(component=reader, name="Reader", inputs=["NER"])
pipeline.add_node(component=ner_node, name="NER", inputs=["ESRetriever"])
pipeline.add_node(component=tiny_reader, name="Reader", inputs=["NER"])
prediction = pipeline.run(
query="Who lives in Berlin?", params={"ESRetriever": {"top_k": 1}, "Reader": {"top_k": 1}}
@ -29,16 +39,15 @@ def test_extractor(document_store_with_docs):
assert "Berlin" in entities
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
def test_extractor_batch_single_query(document_store_with_docs):
@pytest.mark.integration
@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
def test_extractor_batch_single_query(document_store_with_docs, tiny_reader, ner_node):
es_retriever = BM25Retriever(document_store=document_store_with_docs)
ner = EntityExtractor(model_name_or_path="elastic/distilbert-base-cased-finetuned-conll03-english")
reader = FARMReader(model_name_or_path="deepset/tinyroberta-squad2", num_processes=0)
pipeline = Pipeline()
pipeline.add_node(component=es_retriever, name="ESRetriever", inputs=["Query"])
pipeline.add_node(component=ner, name="NER", inputs=["ESRetriever"])
pipeline.add_node(component=reader, name="Reader", inputs=["NER"])
pipeline.add_node(component=ner_node, name="NER", inputs=["ESRetriever"])
pipeline.add_node(component=tiny_reader, name="Reader", inputs=["NER"])
prediction = pipeline.run_batch(
queries=["Who lives in Berlin?"], params={"ESRetriever": {"top_k": 1}, "Reader": {"top_k": 1}}
@ -48,16 +57,15 @@ def test_extractor_batch_single_query(document_store_with_docs):
assert "Berlin" in entities
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
def test_extractor_batch_multiple_queries(document_store_with_docs):
@pytest.mark.integration
@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
def test_extractor_batch_multiple_queries(document_store_with_docs, tiny_reader, ner_node):
es_retriever = BM25Retriever(document_store=document_store_with_docs)
ner = EntityExtractor(model_name_or_path="elastic/distilbert-base-cased-finetuned-conll03-english")
reader = FARMReader(model_name_or_path="deepset/tinyroberta-squad2", num_processes=0)
pipeline = Pipeline()
pipeline.add_node(component=es_retriever, name="ESRetriever", inputs=["Query"])
pipeline.add_node(component=ner, name="NER", inputs=["ESRetriever"])
pipeline.add_node(component=reader, name="Reader", inputs=["NER"])
pipeline.add_node(component=ner_node, name="NER", inputs=["ESRetriever"])
pipeline.add_node(component=tiny_reader, name="Reader", inputs=["NER"])
prediction = pipeline.run_batch(
queries=["Who lives in Berlin?", "Who lives in New York?"],
@ -71,16 +79,15 @@ def test_extractor_batch_multiple_queries(document_store_with_docs):
assert "New York" in entities_paul
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
def test_extractor_output_simplifier(document_store_with_docs):
@pytest.mark.integration
@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
def test_extractor_output_simplifier(document_store_with_docs, tiny_reader, ner_node):
es_retriever = BM25Retriever(document_store=document_store_with_docs)
ner = EntityExtractor(model_name_or_path="elastic/distilbert-base-cased-finetuned-conll03-english")
reader = FARMReader(model_name_or_path="deepset/tinyroberta-squad2", num_processes=0)
pipeline = Pipeline()
pipeline.add_node(component=es_retriever, name="ESRetriever", inputs=["Query"])
pipeline.add_node(component=ner, name="NER", inputs=["ESRetriever"])
pipeline.add_node(component=reader, name="Reader", inputs=["NER"])
pipeline.add_node(component=ner_node, name="NER", inputs=["ESRetriever"])
pipeline.add_node(component=tiny_reader, name="Reader", inputs=["NER"])
prediction = pipeline.run(
query="Who lives in Berlin?", params={"ESRetriever": {"top_k": 1}, "Reader": {"top_k": 1}}
@ -89,7 +96,8 @@ def test_extractor_output_simplifier(document_store_with_docs):
assert simplified[0] == {"answer": "Carla and I", "entities": ["Carla"]}
@pytest.mark.parametrize("document_store", ["elasticsearch"], indirect=True)
@pytest.mark.integration
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
def test_extractor_indexing(document_store):
doc_path = SAMPLES_PATH / "docs" / "doc_2.txt"
@ -109,6 +117,18 @@ def test_extractor_indexing(document_store):
assert "Haystack" in meta["entity_words"]
@pytest.mark.integration
def test_extractor_doc_query(ner_node):
pipeline = Pipeline()
pipeline.add_node(component=ner_node, name="NER", inputs=["Query"])
prediction = pipeline.run(query=None, documents=[Document(content="Carla lives in Berlin", content_type="text")])
entities = [x["word"] for x in prediction["documents"][0].meta["entities"]]
assert "Carla" in entities
assert "Berlin" in entities
@pytest.mark.integration
def test_extract_method():
ner = EntityExtractor(
model_name_or_path="Jean-Baptiste/camembert-ner", max_seq_len=12, aggregation_strategy="first"
@ -138,6 +158,7 @@ def test_extract_method():
]
@pytest.mark.integration
def test_extract_method_pre_split_text():
ner = EntityExtractor(
model_name_or_path="elastic/distilbert-base-cased-finetuned-conll03-english", max_seq_len=6, pre_split_text=True
@ -167,6 +188,7 @@ def test_extract_method_pre_split_text():
]
@pytest.mark.integration
def test_extract_method_unknown_token():
ner = EntityExtractor(
model_name_or_path="elastic/distilbert-base-cased-finetuned-conll03-english",
@ -196,6 +218,7 @@ def test_extract_method_unknown_token():
assert output == [{"entity_group": "O", "word": "Hi my name is JamesÐ.", "start": 0, "end": 21}]
@pytest.mark.integration
def test_extract_method_simple_aggregation():
ner = EntityExtractor(
model_name_or_path="elastic/distilbert-base-cased-finetuned-conll03-english",