Mirror of https://github.com/deepset-ai/haystack.git, synced 2025-10-25 14:59:01 +00:00
test: Added integration test for using EntityExtractor in query pipeline (#4117)
* Added new test for using EntityExtractor in query node and made some fixtures to reduce code duplication.
* Reuse ner_node fixture
* Added pytest unit markings and swapped over to in memory doc store.
* Change to integration tests
This commit is contained in:
parent 5678bb6375
commit 040d806b42
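The tests below resolve the document_store_with_docs argument through pytest's indirect parametrization: the store name in each parametrize call is passed to a fixture of the same name in conftest.py via request.param, and the fixture's return value is what the test receives. That conftest is not part of this diff; the following is a minimal sketch of how such a fixture could look for the "memory" case, where the InMemoryDocumentStore arguments and the sample documents are assumptions for illustration, not the repository's actual fixture:

import pytest

from haystack import Document
from haystack.document_stores import InMemoryDocumentStore


@pytest.fixture
def document_store_with_docs(request):
    # request.param carries the value from parametrize(..., indirect=True);
    # this sketch only handles the "memory" case used by the new tests.
    assert request.param == "memory"
    # BM25Retriever needs a BM25-enabled store (use_bm25 flag assumed here).
    document_store = InMemoryDocumentStore(use_bm25=True)
    document_store.write_documents(
        [
            Document(content="My name is Carla and I live in Berlin"),
            Document(content="My name is Paul and I live in New York"),
        ]
    )
    return document_store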
@@ -4,22 +4,32 @@ from haystack.nodes import TextConverter
 from haystack.nodes.retriever.sparse import BM25Retriever
 from haystack.nodes.reader import FARMReader
 from haystack.pipelines import Pipeline
+from haystack import Document
 from haystack.nodes.extractor import EntityExtractor, simplify_ner_for_qa
 
 from ..conftest import SAMPLES_PATH
 
 
-@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
-def test_extractor(document_store_with_docs):
+@pytest.fixture
+def tiny_reader():
+    return FARMReader(model_name_or_path="deepset/tinyroberta-squad2", num_processes=0)
+
+
+@pytest.fixture
+def ner_node():
+    return EntityExtractor(model_name_or_path="elastic/distilbert-base-cased-finetuned-conll03-english")
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
+def test_extractor(document_store_with_docs, tiny_reader, ner_node):
     es_retriever = BM25Retriever(document_store=document_store_with_docs)
-    ner = EntityExtractor(model_name_or_path="elastic/distilbert-base-cased-finetuned-conll03-english")
-    reader = FARMReader(model_name_or_path="deepset/tinyroberta-squad2", num_processes=0)
 
     pipeline = Pipeline()
     pipeline.add_node(component=es_retriever, name="ESRetriever", inputs=["Query"])
-    pipeline.add_node(component=ner, name="NER", inputs=["ESRetriever"])
-    pipeline.add_node(component=reader, name="Reader", inputs=["NER"])
+    pipeline.add_node(component=ner_node, name="NER", inputs=["ESRetriever"])
+    pipeline.add_node(component=tiny_reader, name="Reader", inputs=["NER"])
 
     prediction = pipeline.run(
         query="Who lives in Berlin?", params={"ESRetriever": {"top_k": 1}, "Reader": {"top_k": 1}}
@@ -29,16 +39,15 @@ def test_extractor(document_store_with_docs):
     assert "Berlin" in entities
 
 
-@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
-def test_extractor_batch_single_query(document_store_with_docs):
+@pytest.mark.integration
+@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
+def test_extractor_batch_single_query(document_store_with_docs, tiny_reader, ner_node):
     es_retriever = BM25Retriever(document_store=document_store_with_docs)
-    ner = EntityExtractor(model_name_or_path="elastic/distilbert-base-cased-finetuned-conll03-english")
-    reader = FARMReader(model_name_or_path="deepset/tinyroberta-squad2", num_processes=0)
 
     pipeline = Pipeline()
     pipeline.add_node(component=es_retriever, name="ESRetriever", inputs=["Query"])
-    pipeline.add_node(component=ner, name="NER", inputs=["ESRetriever"])
-    pipeline.add_node(component=reader, name="Reader", inputs=["NER"])
+    pipeline.add_node(component=ner_node, name="NER", inputs=["ESRetriever"])
+    pipeline.add_node(component=tiny_reader, name="Reader", inputs=["NER"])
 
     prediction = pipeline.run_batch(
         queries=["Who lives in Berlin?"], params={"ESRetriever": {"top_k": 1}, "Reader": {"top_k": 1}}
@@ -48,16 +57,15 @@ def test_extractor_batch_single_query(document_store_with_docs):
     assert "Berlin" in entities
 
 
-@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
-def test_extractor_batch_multiple_queries(document_store_with_docs):
+@pytest.mark.integration
+@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
+def test_extractor_batch_multiple_queries(document_store_with_docs, tiny_reader, ner_node):
     es_retriever = BM25Retriever(document_store=document_store_with_docs)
-    ner = EntityExtractor(model_name_or_path="elastic/distilbert-base-cased-finetuned-conll03-english")
-    reader = FARMReader(model_name_or_path="deepset/tinyroberta-squad2", num_processes=0)
 
     pipeline = Pipeline()
     pipeline.add_node(component=es_retriever, name="ESRetriever", inputs=["Query"])
-    pipeline.add_node(component=ner, name="NER", inputs=["ESRetriever"])
-    pipeline.add_node(component=reader, name="Reader", inputs=["NER"])
+    pipeline.add_node(component=ner_node, name="NER", inputs=["ESRetriever"])
+    pipeline.add_node(component=tiny_reader, name="Reader", inputs=["NER"])
 
     prediction = pipeline.run_batch(
         queries=["Who lives in Berlin?", "Who lives in New York?"],
@@ -71,16 +79,15 @@ def test_extractor_batch_multiple_queries(document_store_with_docs):
     assert "New York" in entities_paul
 
 
-@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
-def test_extractor_output_simplifier(document_store_with_docs):
+@pytest.mark.integration
+@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
+def test_extractor_output_simplifier(document_store_with_docs, tiny_reader, ner_node):
     es_retriever = BM25Retriever(document_store=document_store_with_docs)
-    ner = EntityExtractor(model_name_or_path="elastic/distilbert-base-cased-finetuned-conll03-english")
-    reader = FARMReader(model_name_or_path="deepset/tinyroberta-squad2", num_processes=0)
 
     pipeline = Pipeline()
     pipeline.add_node(component=es_retriever, name="ESRetriever", inputs=["Query"])
-    pipeline.add_node(component=ner, name="NER", inputs=["ESRetriever"])
-    pipeline.add_node(component=reader, name="Reader", inputs=["NER"])
+    pipeline.add_node(component=ner_node, name="NER", inputs=["ESRetriever"])
+    pipeline.add_node(component=tiny_reader, name="Reader", inputs=["NER"])
 
     prediction = pipeline.run(
         query="Who lives in Berlin?", params={"ESRetriever": {"top_k": 1}, "Reader": {"top_k": 1}}
@@ -89,7 +96,8 @@ def test_extractor_output_simplifier(document_store_with_docs):
     assert simplified[0] == {"answer": "Carla and I", "entities": ["Carla"]}
 
 
-@pytest.mark.parametrize("document_store", ["elasticsearch"], indirect=True)
+@pytest.mark.integration
+@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
 def test_extractor_indexing(document_store):
     doc_path = SAMPLES_PATH / "docs" / "doc_2.txt"
 
@@ -109,6 +117,18 @@ def test_extractor_indexing(document_store):
     assert "Haystack" in meta["entity_words"]
 
 
+@pytest.mark.integration
+def test_extractor_doc_query(ner_node):
+    pipeline = Pipeline()
+    pipeline.add_node(component=ner_node, name="NER", inputs=["Query"])
+
+    prediction = pipeline.run(query=None, documents=[Document(content="Carla lives in Berlin", content_type="text")])
+    entities = [x["word"] for x in prediction["documents"][0].meta["entities"]]
+    assert "Carla" in entities
+    assert "Berlin" in entities
+
+
+@pytest.mark.integration
 def test_extract_method():
     ner = EntityExtractor(
         model_name_or_path="Jean-Baptiste/camembert-ner", max_seq_len=12, aggregation_strategy="first"
@@ -138,6 +158,7 @@ def test_extract_method():
     ]
 
 
+@pytest.mark.integration
 def test_extract_method_pre_split_text():
     ner = EntityExtractor(
         model_name_or_path="elastic/distilbert-base-cased-finetuned-conll03-english", max_seq_len=6, pre_split_text=True
@@ -167,6 +188,7 @@ def test_extract_method_pre_split_text():
     ]
 
 
+@pytest.mark.integration
 def test_extract_method_unknown_token():
     ner = EntityExtractor(
         model_name_or_path="elastic/distilbert-base-cased-finetuned-conll03-english",
@@ -196,6 +218,7 @@ def test_extract_method_unknown_token():
     assert output == [{"entity_group": "O", "word": "Hi my name is JamesÐ.", "start": 0, "end": 21}]
 
 
+@pytest.mark.integration
 def test_extract_method_simple_aggregation():
     ner = EntityExtractor(
         model_name_or_path="elastic/distilbert-base-cased-finetuned-conll03-english",
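Since every test in the module now carries @pytest.mark.integration, the suite can be selected with pytest -m integration or skipped with -m "not integration". The query=None pattern exercised by the new test_extractor_doc_query also works outside the test suite for tagging documents directly; a minimal sketch reusing the model name from the ner_node fixture (not part of this diff):

from haystack import Document
from haystack.nodes.extractor import EntityExtractor
from haystack.pipelines import Pipeline

ner = EntityExtractor(model_name_or_path="elastic/distilbert-base-cased-finetuned-conll03-english")
pipeline = Pipeline()
pipeline.add_node(component=ner, name="NER", inputs=["Query"])

# With query=None the pipeline only runs the extractor over the documents;
# each document's meta then carries the recognized entities.
prediction = pipeline.run(query=None, documents=[Document(content="Carla lives in Berlin", content_type="text")])
print([entity["word"] for entity in prediction["documents"][0].meta["entities"]])  # e.g. ["Carla", "Berlin"]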