diff --git a/haystack/extractor/__init__.py b/haystack/extractor/__init__.py
new file mode 100644
index 000000000..0c6d3aa1d
--- /dev/null
+++ b/haystack/extractor/__init__.py
@@ -0,0 +1 @@
+from haystack.extractor.entity import EntityExtractor, simplify_ner_for_qa
\ No newline at end of file
diff --git a/haystack/extractor/entity.py b/haystack/extractor/entity.py
new file mode 100644
index 000000000..1128fc5cd
--- /dev/null
+++ b/haystack/extractor/entity.py
@@ -0,0 +1,75 @@
+from typing import List, Union, Dict, Optional, Tuple
+
+import json
+from haystack import BaseComponent, Document, MultiLabel
+from transformers import AutoTokenizer, AutoModelForTokenClassification, TokenClassificationPipeline
+from transformers import pipeline
+
+
+class EntityExtractor(BaseComponent):
+    """
+    This node is used to extract entities out of documents.
+    The most common use case for this would be as a named entity extractor.
+    The default model used is dslim/bert-base-NER.
+    This node can be placed in a querying pipeline to perform entity extraction on retrieved documents only,
+    or it can be placed in an indexing pipeline so that all documents in the document store have extracted entities.
+    The entities extracted by this Node will populate Document.meta["entities"].
+    """
+    outgoing_edges = 1
+
+    def __init__(self,
+                 model_name_or_path="dslim/bert-base-NER"):
+
+        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
+        token_classifier = AutoModelForTokenClassification.from_pretrained(model_name_or_path)
+        self.model = pipeline("ner", model=token_classifier, tokenizer=tokenizer, aggregation_strategy="simple")
+
+    def run(self, documents: Optional[Union[List[Document], List[dict]]] = None) -> Tuple[Dict, str]:  # type: ignore
+        """
+        This is the method called when this node is used in a pipeline.
+        """
+        if documents:
+            for doc in documents:
+                # In a querying pipeline, doc is a haystack.schema.Document object
+                try:
+                    doc.meta["entities"] = self.extract(doc.text)  # type: ignore
+                # In an indexing pipeline, doc is a dictionary
+                except AttributeError:
+                    doc["meta"]["entities"] = self.extract(doc["text"])  # type: ignore
+        output = {"documents": documents}
+        return output, "output_1"

+    def extract(self, text):
+        """
+        This function can be called to perform entity extraction when using the node in isolation.
+        """
+        entities = self.model(text)
+        return entities
+
+
+def simplify_ner_for_qa(output):
+    """
+    Returns a simplified version of the output dictionary
+    with the following structure:
+    [
+        {
+            "answer": "the answer text",
+            "entities": ["entity word", ...]
+        }
+    ]
+    The entities included are only the ones that overlap with
+    the answer itself.
+    """
+    compact_output = []
+    for answer in output["answers"]:
+
+        entities = []
+        for entity in answer["meta"]["entities"]:
+            if entity["start"] >= answer["offset_start_in_doc"] and entity["end"] <= answer["offset_end_in_doc"]:
+                entities.append(entity["word"])
+
+        compact_output.append({
+            "answer": answer["answer"],
+            "entities": entities
+        })
+    return compact_output
diff --git a/test/test_extractor.py b/test/test_extractor.py
new file mode 100644
index 000000000..6c13157a4
--- /dev/null
+++ b/test/test_extractor.py
@@ -0,0 +1,57 @@
+import pytest
+
+from haystack.retriever.sparse import ElasticsearchRetriever
+from haystack.reader import FARMReader
+from haystack.pipeline import Pipeline
+
+from haystack.extractor import EntityExtractor, simplify_ner_for_qa
+
+
+@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
+def test_extractor(document_store_with_docs):
+
+    es_retriever = ElasticsearchRetriever(document_store=document_store_with_docs)
+    ner = EntityExtractor()
+    reader = FARMReader(model_name_or_path="deepset/roberta-base-squad2")
+
+    pipeline = Pipeline()
+    pipeline.add_node(component=es_retriever, name="ESRetriever", inputs=["Query"])
+    pipeline.add_node(component=ner, name="NER", inputs=["ESRetriever"])
+    pipeline.add_node(component=reader, name="Reader", inputs=["NER"])
+
+    prediction = pipeline.run(
+        query="Who lives in Berlin?",
+        params={
+            "ESRetriever": {"top_k": 1},
+            "Reader": {"top_k": 1},
+        }
+    )
+    entities = [entity["word"] for entity in prediction["answers"][0]["meta"]["entities"]]
+    assert "Carla" in entities
+    assert "Berlin" in entities
+
+
+@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
+def test_extractor_output_simplifier(document_store_with_docs):
+
+    es_retriever = ElasticsearchRetriever(document_store=document_store_with_docs)
+    ner = EntityExtractor()
+    reader = FARMReader(model_name_or_path="deepset/roberta-base-squad2")
+
+    pipeline = Pipeline()
+    pipeline.add_node(component=es_retriever, name="ESRetriever", inputs=["Query"])
+    pipeline.add_node(component=ner, name="NER", inputs=["ESRetriever"])
+    pipeline.add_node(component=reader, name="Reader", inputs=["NER"])
+
+    prediction = pipeline.run(
+        query="Who lives in Berlin?",
+        params={
+            "ESRetriever": {"top_k": 1},
+            "Reader": {"top_k": 1},
+        }
+    )
+    simplified = simplify_ner_for_qa(prediction)
+    assert simplified[0] == {
+        "answer": "Carla",
+        "entities": ["Carla"]
+    }
\ No newline at end of file