diff --git a/haystack/extractor/entity.py b/haystack/extractor/entity.py index 7d908e75b..7cbeee937 100644 --- a/haystack/extractor/entity.py +++ b/haystack/extractor/entity.py @@ -19,7 +19,8 @@ class EntityExtractor(BaseComponent): def __init__(self, model_name_or_path="dslim/bert-base-NER"): - + + self.set_config(model_name_or_path=model_name_or_path) tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) token_classifier = AutoModelForTokenClassification.from_pretrained(model_name_or_path) self.model = pipeline("ner", model=token_classifier, tokenizer=tokenizer, aggregation_strategy="simple") diff --git a/haystack/pipeline.py b/haystack/pipeline.py index b9d247f50..5aad42945 100644 --- a/haystack/pipeline.py +++ b/haystack/pipeline.py @@ -35,6 +35,8 @@ from haystack.summarizer.base import BaseSummarizer from haystack.translator.base import BaseTranslator from haystack.document_store.base import BaseDocumentStore from haystack.question_generator import QuestionGenerator +from haystack.extractor import EntityExtractor + logger = logging.getLogger(__name__)