mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-08-31 11:56:35 +00:00
test: Added integration test for using EntityExtractor in query pipeline (#4117)
* Added new test for using EntityExtractor in query node and made some fixtures to reduce code duplication. * Reuse ner_node fixture * Added pytest unit markings and swapped over to in memory doc store. * Change to integration tests
This commit is contained in:
parent
5678bb6375
commit
040d806b42
@ -4,22 +4,32 @@ from haystack.nodes import TextConverter
|
||||
from haystack.nodes.retriever.sparse import BM25Retriever
|
||||
from haystack.nodes.reader import FARMReader
|
||||
from haystack.pipelines import Pipeline
|
||||
from haystack import Document
|
||||
|
||||
from haystack.nodes.extractor import EntityExtractor, simplify_ner_for_qa
|
||||
|
||||
from ..conftest import SAMPLES_PATH
|
||||
|
||||
|
||||
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
|
||||
def test_extractor(document_store_with_docs):
|
||||
@pytest.fixture
|
||||
def tiny_reader():
|
||||
return FARMReader(model_name_or_path="deepset/tinyroberta-squad2", num_processes=0)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ner_node():
|
||||
return EntityExtractor(model_name_or_path="elastic/distilbert-base-cased-finetuned-conll03-english")
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
|
||||
def test_extractor(document_store_with_docs, tiny_reader, ner_node):
|
||||
es_retriever = BM25Retriever(document_store=document_store_with_docs)
|
||||
ner = EntityExtractor(model_name_or_path="elastic/distilbert-base-cased-finetuned-conll03-english")
|
||||
reader = FARMReader(model_name_or_path="deepset/tinyroberta-squad2", num_processes=0)
|
||||
|
||||
pipeline = Pipeline()
|
||||
pipeline.add_node(component=es_retriever, name="ESRetriever", inputs=["Query"])
|
||||
pipeline.add_node(component=ner, name="NER", inputs=["ESRetriever"])
|
||||
pipeline.add_node(component=reader, name="Reader", inputs=["NER"])
|
||||
pipeline.add_node(component=ner_node, name="NER", inputs=["ESRetriever"])
|
||||
pipeline.add_node(component=tiny_reader, name="Reader", inputs=["NER"])
|
||||
|
||||
prediction = pipeline.run(
|
||||
query="Who lives in Berlin?", params={"ESRetriever": {"top_k": 1}, "Reader": {"top_k": 1}}
|
||||
@ -29,16 +39,15 @@ def test_extractor(document_store_with_docs):
|
||||
assert "Berlin" in entities
|
||||
|
||||
|
||||
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
|
||||
def test_extractor_batch_single_query(document_store_with_docs):
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
|
||||
def test_extractor_batch_single_query(document_store_with_docs, tiny_reader, ner_node):
|
||||
es_retriever = BM25Retriever(document_store=document_store_with_docs)
|
||||
ner = EntityExtractor(model_name_or_path="elastic/distilbert-base-cased-finetuned-conll03-english")
|
||||
reader = FARMReader(model_name_or_path="deepset/tinyroberta-squad2", num_processes=0)
|
||||
|
||||
pipeline = Pipeline()
|
||||
pipeline.add_node(component=es_retriever, name="ESRetriever", inputs=["Query"])
|
||||
pipeline.add_node(component=ner, name="NER", inputs=["ESRetriever"])
|
||||
pipeline.add_node(component=reader, name="Reader", inputs=["NER"])
|
||||
pipeline.add_node(component=ner_node, name="NER", inputs=["ESRetriever"])
|
||||
pipeline.add_node(component=tiny_reader, name="Reader", inputs=["NER"])
|
||||
|
||||
prediction = pipeline.run_batch(
|
||||
queries=["Who lives in Berlin?"], params={"ESRetriever": {"top_k": 1}, "Reader": {"top_k": 1}}
|
||||
@ -48,16 +57,15 @@ def test_extractor_batch_single_query(document_store_with_docs):
|
||||
assert "Berlin" in entities
|
||||
|
||||
|
||||
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
|
||||
def test_extractor_batch_multiple_queries(document_store_with_docs):
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
|
||||
def test_extractor_batch_multiple_queries(document_store_with_docs, tiny_reader, ner_node):
|
||||
es_retriever = BM25Retriever(document_store=document_store_with_docs)
|
||||
ner = EntityExtractor(model_name_or_path="elastic/distilbert-base-cased-finetuned-conll03-english")
|
||||
reader = FARMReader(model_name_or_path="deepset/tinyroberta-squad2", num_processes=0)
|
||||
|
||||
pipeline = Pipeline()
|
||||
pipeline.add_node(component=es_retriever, name="ESRetriever", inputs=["Query"])
|
||||
pipeline.add_node(component=ner, name="NER", inputs=["ESRetriever"])
|
||||
pipeline.add_node(component=reader, name="Reader", inputs=["NER"])
|
||||
pipeline.add_node(component=ner_node, name="NER", inputs=["ESRetriever"])
|
||||
pipeline.add_node(component=tiny_reader, name="Reader", inputs=["NER"])
|
||||
|
||||
prediction = pipeline.run_batch(
|
||||
queries=["Who lives in Berlin?", "Who lives in New York?"],
|
||||
@ -71,16 +79,15 @@ def test_extractor_batch_multiple_queries(document_store_with_docs):
|
||||
assert "New York" in entities_paul
|
||||
|
||||
|
||||
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
|
||||
def test_extractor_output_simplifier(document_store_with_docs):
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
|
||||
def test_extractor_output_simplifier(document_store_with_docs, tiny_reader, ner_node):
|
||||
es_retriever = BM25Retriever(document_store=document_store_with_docs)
|
||||
ner = EntityExtractor(model_name_or_path="elastic/distilbert-base-cased-finetuned-conll03-english")
|
||||
reader = FARMReader(model_name_or_path="deepset/tinyroberta-squad2", num_processes=0)
|
||||
|
||||
pipeline = Pipeline()
|
||||
pipeline.add_node(component=es_retriever, name="ESRetriever", inputs=["Query"])
|
||||
pipeline.add_node(component=ner, name="NER", inputs=["ESRetriever"])
|
||||
pipeline.add_node(component=reader, name="Reader", inputs=["NER"])
|
||||
pipeline.add_node(component=ner_node, name="NER", inputs=["ESRetriever"])
|
||||
pipeline.add_node(component=tiny_reader, name="Reader", inputs=["NER"])
|
||||
|
||||
prediction = pipeline.run(
|
||||
query="Who lives in Berlin?", params={"ESRetriever": {"top_k": 1}, "Reader": {"top_k": 1}}
|
||||
@ -89,7 +96,8 @@ def test_extractor_output_simplifier(document_store_with_docs):
|
||||
assert simplified[0] == {"answer": "Carla and I", "entities": ["Carla"]}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("document_store", ["elasticsearch"], indirect=True)
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
|
||||
def test_extractor_indexing(document_store):
|
||||
doc_path = SAMPLES_PATH / "docs" / "doc_2.txt"
|
||||
|
||||
@ -109,6 +117,18 @@ def test_extractor_indexing(document_store):
|
||||
assert "Haystack" in meta["entity_words"]
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_extractor_doc_query(ner_node):
|
||||
pipeline = Pipeline()
|
||||
pipeline.add_node(component=ner_node, name="NER", inputs=["Query"])
|
||||
|
||||
prediction = pipeline.run(query=None, documents=[Document(content="Carla lives in Berlin", content_type="text")])
|
||||
entities = [x["word"] for x in prediction["documents"][0].meta["entities"]]
|
||||
assert "Carla" in entities
|
||||
assert "Berlin" in entities
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_extract_method():
|
||||
ner = EntityExtractor(
|
||||
model_name_or_path="Jean-Baptiste/camembert-ner", max_seq_len=12, aggregation_strategy="first"
|
||||
@ -138,6 +158,7 @@ def test_extract_method():
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_extract_method_pre_split_text():
|
||||
ner = EntityExtractor(
|
||||
model_name_or_path="elastic/distilbert-base-cased-finetuned-conll03-english", max_seq_len=6, pre_split_text=True
|
||||
@ -167,6 +188,7 @@ def test_extract_method_pre_split_text():
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_extract_method_unknown_token():
|
||||
ner = EntityExtractor(
|
||||
model_name_or_path="elastic/distilbert-base-cased-finetuned-conll03-english",
|
||||
@ -196,6 +218,7 @@ def test_extract_method_unknown_token():
|
||||
assert output == [{"entity_group": "O", "word": "Hi my name is JamesÐ.", "start": 0, "end": 21}]
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_extract_method_simple_aggregation():
|
||||
ner = EntityExtractor(
|
||||
model_name_or_path="elastic/distilbert-base-cased-finetuned-conll03-english",
|
||||
|
Loading…
x
Reference in New Issue
Block a user