mirror of https://github.com/deepset-ai/haystack.git
synced 2025-06-26 22:00:13 +00:00
Create EntityExtractor (#1573)
* Create extractor/entity.py
* Aggregate NER words into entities
* Support indexing
* Add doc strings
* Add utility for printing
* Update signature of run() to match BaseComponent
* Add test
* Modify simplify_ner_for_qa to return the dictionary and add its test

Co-authored-by: brandenchan <brandenchan@icloud.com>
This commit is contained in:
parent 69a0c9f2ed
commit 25d76f508d
haystack/extractor/__init__.py (new file)
@@ -0,0 +1 @@
from haystack.extractor.entity import EntityExtractor, simplify_ner_for_qa
haystack/extractor/entity.py (new file)
@@ -0,0 +1,75 @@
from typing import List, Union, Dict, Optional, Tuple

import json
from haystack import BaseComponent, Document, MultiLabel
from transformers import AutoTokenizer, AutoModelForTokenClassification, TokenClassificationPipeline
from transformers import pipeline


class EntityExtractor(BaseComponent):
    """
    This node is used to extract entities out of documents.
    The most common use case for this would be as a named entity extractor.
    The default model used is dslim/bert-base-NER.
    This node can be placed in a querying pipeline to perform entity extraction on retrieved documents only,
    or it can be placed in an indexing pipeline so that all documents in the document store have extracted entities.
    The entities extracted by this Node will populate Document.entities
    """
    outgoing_edges = 1

    def __init__(self,
                 model_name_or_path="dslim/bert-base-NER"):

        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        token_classifier = AutoModelForTokenClassification.from_pretrained(model_name_or_path)
        self.model = pipeline("ner", model=token_classifier, tokenizer=tokenizer, aggregation_strategy="simple")

    def run(self, documents: Optional[Union[List[Document], List[dict]]] = None) -> Tuple[Dict, str]:  # type: ignore
        """
        This is the method called when this node is used in a pipeline
        """
        if documents:
            for doc in documents:
                # In a querying pipeline, doc is a haystack.schema.Document object
                try:
                    doc.meta["entities"] = self.extract(doc.text)  # type: ignore
                # In an indexing pipeline, doc is a dictionary
                except AttributeError:
                    doc["meta"]["entities"] = self.extract(doc["text"])  # type: ignore
        output = {"documents": documents}
        return output, "output_1"

    def extract(self, text):
        """
        This function can be called to perform entity extraction when using the node in isolation.
        """
        entities = self.model(text)
        return entities


def simplify_ner_for_qa(output):
    """
    Returns a simplified version of the output dictionary
    with the following structure:
    [
        {
            answer: { ... }
            entities: [ { ... }, {} ]
        }
    ]
    The entities included are only the ones that overlap with
    the answer itself.
    """
    compact_output = []
    for answer in output["answers"]:

        entities = []
        for entity in answer["meta"]["entities"]:
            if entity["start"] >= answer["offset_start_in_doc"] and entity["end"] <= answer["offset_end_in_doc"]:
                entities.append(entity["word"])

        compact_output.append({
            "answer": answer["answer"],
            "entities": entities
        })
    return compact_output
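The extract() method above also makes the node usable outside a pipeline. A minimal standalone sketch (the sample sentence is illustrative; the output keys follow the transformers NER pipeline with aggregation_strategy="simple", which merges word pieces such as "Ber" + "##lin" back into whole entities):

# Minimal standalone usage sketch; the example sentence is illustrative.
from haystack.extractor import EntityExtractor

extractor = EntityExtractor()  # defaults to dslim/bert-base-NER

# With aggregation_strategy="simple", the underlying transformers pipeline
# groups word pieces into whole entities and returns one dict per entity.
for entity in extractor.extract("Carla lives in Berlin."):
    # Expected keys: "word", "entity_group", "score", "start", "end"
    print(entity["word"], entity["entity_group"])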
test/test_extractor.py (new file)
@@ -0,0 +1,57 @@
import pytest

from haystack.retriever.sparse import ElasticsearchRetriever
from haystack.reader import FARMReader
from haystack.pipeline import Pipeline

from haystack.extractor import EntityExtractor, simplify_ner_for_qa


@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
def test_extractor(document_store_with_docs):

    es_retriever = ElasticsearchRetriever(document_store=document_store_with_docs)
    ner = EntityExtractor()
    reader = FARMReader(model_name_or_path="deepset/roberta-base-squad2")

    pipeline = Pipeline()
    pipeline.add_node(component=es_retriever, name="ESRetriever", inputs=["Query"])
    pipeline.add_node(component=ner, name="NER", inputs=["ESRetriever"])
    pipeline.add_node(component=reader, name="Reader", inputs=["NER"])

    prediction = pipeline.run(
        query="Who lives in Berlin?",
        params={
            "ESRetriever": {"top_k": 1},
            "Reader": {"top_k": 1},
        }
    )
    entities = [entity["word"] for entity in prediction["answers"][0]["meta"]["entities"]]
    assert "Carla" in entities
    assert "Berlin" in entities


@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
def test_extractor_output_simplifier(document_store_with_docs):

    es_retriever = ElasticsearchRetriever(document_store=document_store_with_docs)
    ner = EntityExtractor()
    reader = FARMReader(model_name_or_path="deepset/roberta-base-squad2")

    pipeline = Pipeline()
    pipeline.add_node(component=es_retriever, name="ESRetriever", inputs=["Query"])
    pipeline.add_node(component=ner, name="NER", inputs=["ESRetriever"])
    pipeline.add_node(component=reader, name="Reader", inputs=["NER"])

    prediction = pipeline.run(
        query="Who lives in Berlin?",
        params={
            "ESRetriever": {"top_k": 1},
            "Reader": {"top_k": 1},
        }
    )
    simplified = simplify_ner_for_qa(prediction)
    assert simplified[0] == {
        "answer": "Carla",
        "entities": ["Carla"]
    }
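These tests exercise the querying pipeline only; the commit message also lists "Support indexing", which run() handles through its dict branch. A hypothetical sketch of that indexing path, assuming the pre-1.0 import layout used elsewhere in this diff and an Elasticsearch instance on localhost (the sample document and store settings are illustrative, not part of this commit):

# Hypothetical indexing-path sketch; the document store settings and the
# sample document are illustrative, not part of this commit.
from haystack.document_store import ElasticsearchDocumentStore
from haystack.extractor import EntityExtractor

document_store = ElasticsearchDocumentStore(host="localhost")
ner = EntityExtractor()

# At indexing time each doc is a plain dict, so run() takes the
# AttributeError branch and writes to doc["meta"]["entities"].
docs = [{"text": "Carla lives in Berlin.", "meta": {}}]
output, _ = ner.run(documents=docs)
document_store.write_documents(output["documents"])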