Mirror of https://github.com/deepset-ai/haystack.git, synced 2025-06-26 22:00:13 +00:00
Create EntityExtractor (#1573)
* Create extractor/entity.py
* Aggregate NER words into entities
* Support indexing
* Add doc strings
* Add utility for printing
* Update signature of run() to match BaseComponent
* Add test
* Modify simplify_ner_for_qa to return the dictionary and add its test

Co-authored-by: brandenchan <brandenchan@icloud.com>
This commit is contained in:
parent 69a0c9f2ed
commit 25d76f508d
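For orientation, a minimal standalone sketch of the new node (not part of the diff below); it assumes the haystack.extractor package added in this commit is importable and that the default dslim/bert-base-NER model can be downloaded:

from haystack.extractor import EntityExtractor

extractor = EntityExtractor()  # loads dslim/bert-base-NER by default
entities = extractor.extract("Carla lives in Berlin.")
# With aggregation_strategy="simple", each entry should look roughly like
# {"entity_group": "PER", "word": "Carla", "start": 0, "end": 5, "score": ...}
print([entity["word"] for entity in entities])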
1  haystack/extractor/__init__.py  Normal file
@@ -0,0 +1 @@
from haystack.extractor.entity import EntityExtractor, simplify_ner_for_qa
75  haystack/extractor/entity.py  Normal file
@@ -0,0 +1,75 @@
from typing import List, Union, Dict, Optional, Tuple

import json

from haystack import BaseComponent, Document, MultiLabel
from transformers import AutoTokenizer, AutoModelForTokenClassification, TokenClassificationPipeline
from transformers import pipeline


class EntityExtractor(BaseComponent):
    """
    This node is used to extract entities out of documents.
    The most common use case for this would be as a named entity extractor.
    The default model used is dslim/bert-base-NER.
    This node can be placed in a querying pipeline to perform entity extraction on retrieved documents only,
    or it can be placed in an indexing pipeline so that all documents in the document store have extracted entities.
    The entities extracted by this Node will populate Document.entities
    """
    outgoing_edges = 1

    def __init__(self,
                 model_name_or_path="dslim/bert-base-NER"):

        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        token_classifier = AutoModelForTokenClassification.from_pretrained(model_name_or_path)
        self.model = pipeline("ner", model=token_classifier, tokenizer=tokenizer, aggregation_strategy="simple")

    def run(self, documents: Optional[Union[List[Document], List[dict]]] = None) -> Tuple[Dict, str]:  # type: ignore
        """
        This is the method called when this node is used in a pipeline
        """
        if documents:
            for doc in documents:
                # In a querying pipeline, doc is a haystack.schema.Document object
                try:
                    doc.meta["entities"] = self.extract(doc.text)  # type: ignore
                # In an indexing pipeline, doc is a dictionary
                except AttributeError:
                    doc["meta"]["entities"] = self.extract(doc["text"])  # type: ignore
        output = {"documents": documents}
        return output, "output_1"

    def extract(self, text):
        """
        This function can be called to perform entity extraction when using the node in isolation.
        """
        entities = self.model(text)
        return entities


def simplify_ner_for_qa(output):
    """
    Returns a simplified version of the output dictionary
    with the following structure:
    [
        {
            answer: { ... }
            entities: [ { ... }, {} ]
        }
    ]
    The entities included are only the ones that overlap with
    the answer itself.
    """
    compact_output = []
    for answer in output["answers"]:

        entities = []
        for entity in answer["meta"]["entities"]:
            if entity["start"] >= answer["offset_start_in_doc"] and entity["end"] <= answer["offset_end_in_doc"]:
                entities.append(entity["word"])

        compact_output.append({
            "answer": answer["answer"],
            "entities": entities
        })
    return compact_output
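A note on the run() method above: it accepts both haystack.schema.Document objects (the querying case) and plain dicts (the indexing case); the try/except on AttributeError is what switches between the two. A short sketch of the dict path, assuming documents shaped like those an indexing pipeline passes along (hypothetical text and meta):

extractor = EntityExtractor()
docs = [{"text": "Carla lives in Berlin.", "meta": {}}]
output, edge = extractor.run(documents=docs)
# Extracted entities are written into each document's meta:
print(output["documents"][0]["meta"]["entities"])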
57  test/test_extractor.py  Normal file
@@ -0,0 +1,57 @@
import pytest

from haystack.retriever.sparse import ElasticsearchRetriever
from haystack.reader import FARMReader
from haystack.pipeline import Pipeline

from haystack.extractor import EntityExtractor, simplify_ner_for_qa


@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
def test_extractor(document_store_with_docs):

    es_retriever = ElasticsearchRetriever(document_store=document_store_with_docs)
    ner = EntityExtractor()
    reader = FARMReader(model_name_or_path="deepset/roberta-base-squad2")

    pipeline = Pipeline()
    pipeline.add_node(component=es_retriever, name="ESRetriever", inputs=["Query"])
    pipeline.add_node(component=ner, name="NER", inputs=["ESRetriever"])
    pipeline.add_node(component=reader, name="Reader", inputs=["NER"])

    prediction = pipeline.run(
        query="Who lives in Berlin?",
        params={
            "ESRetriever": {"top_k": 1},
            "Reader": {"top_k": 1},
        }
    )
    entities = [entity["word"] for entity in prediction["answers"][0]["meta"]["entities"]]
    assert "Carla" in entities
    assert "Berlin" in entities


@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
def test_extractor_output_simplifier(document_store_with_docs):

    es_retriever = ElasticsearchRetriever(document_store=document_store_with_docs)
    ner = EntityExtractor()
    reader = FARMReader(model_name_or_path="deepset/roberta-base-squad2")

    pipeline = Pipeline()
    pipeline.add_node(component=es_retriever, name="ESRetriever", inputs=["Query"])
    pipeline.add_node(component=ner, name="NER", inputs=["ESRetriever"])
    pipeline.add_node(component=reader, name="Reader", inputs=["NER"])

    prediction = pipeline.run(
        query="Who lives in Berlin?",
        params={
            "ESRetriever": {"top_k": 1},
            "Reader": {"top_k": 1},
        }
    )
    simplified = simplify_ner_for_qa(prediction)
    assert simplified[0] == {
        "answer": "Carla",
        "entities": ["Carla"]
    }
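To see the overlap filtering of simplify_ner_for_qa without running Elasticsearch and the reader, here is a hand-built prediction dict (hypothetical offsets and scores) in the shape the function expects:

from haystack.extractor import simplify_ner_for_qa

prediction = {
    "answers": [{
        "answer": "Carla",
        "offset_start_in_doc": 0,
        "offset_end_in_doc": 5,
        "meta": {"entities": [
            {"word": "Carla", "start": 0, "end": 5, "entity_group": "PER", "score": 0.99},
            {"word": "Berlin", "start": 15, "end": 21, "entity_group": "LOC", "score": 0.99},
        ]},
    }]
}
# Only "Carla" falls inside the answer span, so "Berlin" is dropped:
assert simplify_ner_for_qa(prediction) == [{"answer": "Carla", "entities": ["Carla"]}]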