Create EntityExtractor (#1573)

* Create extractor/entity.py

* Aggregate NER words into entities

* Support indexing

* Add docstrings

* Add utility for printing

* Update signature of run() to match BaseComponent

* Add test

* Modify simplify_ner_for_qa to return the dictionary and add its test

Co-authored-by: brandenchan <brandenchan@icloud.com>
Sara Zan 2021-10-11 11:04:11 +02:00 committed by GitHub
parent 69a0c9f2ed
commit 25d76f508d
3 changed files with 133 additions and 0 deletions

haystack/extractor/__init__.py Normal file

@@ -0,0 +1 @@
from haystack.extractor.entity import EntityExtractor, simplify_ner_for_qa

haystack/extractor/entity.py Normal file

@@ -0,0 +1,75 @@
from typing import List, Union, Dict, Optional, Tuple

from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline

from haystack import BaseComponent, Document


class EntityExtractor(BaseComponent):
    """
    This node is used to extract entities out of documents.
    The most common use case for this would be as a named entity extractor.
    The default model used is dslim/bert-base-NER.
    This node can be placed in a querying pipeline to perform entity extraction on retrieved documents only,
    or it can be placed in an indexing pipeline so that all documents in the document store have extracted entities.
    The entities extracted by this node populate Document.meta["entities"].
    """

    outgoing_edges = 1

    def __init__(self, model_name_or_path: str = "dslim/bert-base-NER"):
        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        token_classifier = AutoModelForTokenClassification.from_pretrained(model_name_or_path)
        self.model = pipeline("ner", model=token_classifier, tokenizer=tokenizer, aggregation_strategy="simple")

    def run(self, documents: Optional[Union[List[Document], List[dict]]] = None) -> Tuple[Dict, str]:  # type: ignore
        """
        This is the method called when this node is used in a pipeline.
        """
        if documents:
            for doc in documents:
                # In a querying pipeline, doc is a haystack.schema.Document object
                try:
                    doc.meta["entities"] = self.extract(doc.text)  # type: ignore
                # In an indexing pipeline, doc is a dictionary
                except AttributeError:
                    doc["meta"]["entities"] = self.extract(doc["text"])  # type: ignore
        output = {"documents": documents}
        return output, "output_1"

    def extract(self, text: str):
        """
        This function can be called to perform entity extraction when using the node in isolation.
        """
        entities = self.model(text)
        return entities


def simplify_ner_for_qa(output):
    """
    Returns a simplified version of the QA pipeline output with the following structure:
    [
        {
            "answer": str,
            "entities": [str, ...]
        },
        ...
    ]
    Only the entities whose spans fall within the answer itself are included.
    """
    compact_output = []
    for answer in output["answers"]:
        entities = []
        for entity in answer["meta"]["entities"]:
            if entity["start"] >= answer["offset_start_in_doc"] and entity["end"] <= answer["offset_end_in_doc"]:
                entities.append(entity["word"])
        compact_output.append({
            "answer": answer["answer"],
            "entities": entities,
        })
    return compact_output
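
For reference, a minimal sketch of using the node in isolation via extract(); the sample sentence and printed fields are illustrative, and the output keys follow the transformers NER pipeline with aggregation_strategy="simple":

from haystack.extractor import EntityExtractor

# Minimal standalone-usage sketch; the sample sentence is illustrative.
extractor = EntityExtractor()  # downloads dslim/bert-base-NER on first use
entities = extractor.extract("Carla lives in Berlin.")

# With aggregation_strategy="simple", each entry is an aggregated entity with
# keys such as "entity_group", "word", "score", "start" and "end".
for entity in entities:
    print(entity["word"], entity["entity_group"])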

test/test_extractor.py Normal file

@@ -0,0 +1,57 @@
import pytest

from haystack.retriever.sparse import ElasticsearchRetriever
from haystack.reader import FARMReader
from haystack.pipeline import Pipeline
from haystack.extractor import EntityExtractor, simplify_ner_for_qa


@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
def test_extractor(document_store_with_docs):
    es_retriever = ElasticsearchRetriever(document_store=document_store_with_docs)
    ner = EntityExtractor()
    reader = FARMReader(model_name_or_path="deepset/roberta-base-squad2")

    pipeline = Pipeline()
    pipeline.add_node(component=es_retriever, name="ESRetriever", inputs=["Query"])
    pipeline.add_node(component=ner, name="NER", inputs=["ESRetriever"])
    pipeline.add_node(component=reader, name="Reader", inputs=["NER"])

    prediction = pipeline.run(
        query="Who lives in Berlin?",
        params={
            "ESRetriever": {"top_k": 1},
            "Reader": {"top_k": 1},
        },
    )
    entities = [entity["word"] for entity in prediction["answers"][0]["meta"]["entities"]]
    assert "Carla" in entities
    assert "Berlin" in entities


@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
def test_extractor_output_simplifier(document_store_with_docs):
    es_retriever = ElasticsearchRetriever(document_store=document_store_with_docs)
    ner = EntityExtractor()
    reader = FARMReader(model_name_or_path="deepset/roberta-base-squad2")

    pipeline = Pipeline()
    pipeline.add_node(component=es_retriever, name="ESRetriever", inputs=["Query"])
    pipeline.add_node(component=ner, name="NER", inputs=["ESRetriever"])
    pipeline.add_node(component=reader, name="Reader", inputs=["NER"])

    prediction = pipeline.run(
        query="Who lives in Berlin?",
        params={
            "ESRetriever": {"top_k": 1},
            "Reader": {"top_k": 1},
        },
    )
    simplified = simplify_ner_for_qa(prediction)
    assert simplified[0] == {
        "answer": "Carla",
        "entities": ["Carla"]
    }
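
The tests above only exercise the querying path. For the indexing path mentioned in the commit message, a hedged sketch of what run() does with plain dictionaries (the document content here is illustrative): dictionary documents take the AttributeError fallback, so the entities land in doc["meta"]["entities"].

from haystack.extractor import EntityExtractor

# Sketch of the indexing-side behaviour: documents arrive as plain dicts,
# so run() falls back to dictionary access and writes into doc["meta"].
ner = EntityExtractor()
docs = [{"text": "Carla lives in Berlin.", "meta": {}}]
output, edge = ner.run(documents=docs)

print(output["documents"][0]["meta"]["entities"])  # aggregated NER entities
print(edge)  # "output_1"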