Clean NER Scanner imports (#11653)

This commit is contained in:
Pere Miquel Brull 2023-05-18 12:53:22 +02:00 committed by GitHub
parent 667706d09b
commit 8795337f88
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 27 additions and 20 deletions

View File

@ -48,20 +48,23 @@ class NERScanner:
Based on https://microsoft.github.io/presidio/
"""
import spacy
from presidio_analyzer import AnalyzerEngine
from presidio_analyzer.nlp_engine.spacy_nlp_engine import SpacyNlpEngine
def __init__(self):
import spacy
from presidio_analyzer import AnalyzerEngine
from presidio_analyzer.nlp_engine.spacy_nlp_engine import SpacyNlpEngine
try:
spacy.load(SPACY_EN_MODEL)
except OSError:
logger.warning("Downloading en_core_web_md language model for the spaCy")
from spacy.cli import download
try:
spacy.load(SPACY_EN_MODEL)
except OSError:
logger.warning("Downloading en_core_web_md language model for the spaCy")
from spacy.cli import download
download(SPACY_EN_MODEL)
spacy.load(SPACY_EN_MODEL)
download(SPACY_EN_MODEL)
spacy.load(SPACY_EN_MODEL)
analyzer = AnalyzerEngine(nlp_engine=SpacyNlpEngine(models={"en": SPACY_EN_MODEL}))
self.analyzer = AnalyzerEngine(
nlp_engine=SpacyNlpEngine(models={"en": SPACY_EN_MODEL})
)
@staticmethod
def get_highest_score_label(
@ -77,8 +80,7 @@ class NERScanner:
most_used_label_occurrence = score[1]
return label_score or (None, None)
@classmethod
def scan(cls, sample_data_rows: List[Any]) -> Optional[TagAndConfidence]:
def scan(self, sample_data_rows: List[Any]) -> Optional[TagAndConfidence]:
"""
Scan the column's sample data rows and look for PII
"""
@ -87,7 +89,7 @@ class NERScanner:
str_sample_data_rows = [str(row) for row in sample_data_rows if row is not None]
for row in str_sample_data_rows:
try:
results = cls.analyzer.analyze(row, language="en")
results = self.analyzer.analyze(row, language="en")
for result in results:
logger.debug("Found %s", result.entity_type)
tag = result.entity_type
@ -104,7 +106,7 @@ class NERScanner:
logger.warning(f"Unknown error while processing {row} - {exc}")
logger.debug(traceback.format_exc())
label, score = cls.get_highest_score_label(labels_score, str_sample_data_rows)
label, score = self.get_highest_score_label(labels_score, str_sample_data_rows)
if label and score:
tag_type = NEREntity.__members__.get(label, TagType.NONSENSITIVE).value
return TagAndConfidence(tag=tag_type, confidence=score)

View File

@ -34,6 +34,7 @@ class PIIProcessor:
def __init__(self, metadata: OpenMetadata):
self.metadata = metadata
self.ner_scanner = NERScanner()
def patch_column_tag(
self, tag_type: str, table_entity: Table, column_name: str
@ -81,7 +82,7 @@ class PIIProcessor:
# Scan by column name. If no results there, check the sample data, if any
tag_and_confidence = ColumnNameScanner.scan(column.name.__root__) or (
NERScanner.scan([row[idx] for row in table_data.rows])
self.ner_scanner.scan([row[idx] for row in table_data.rows])
if table_data
else None
)

View File

@ -22,10 +22,12 @@ class NERScannerTest(TestCase):
Validate various typical column names
"""
ner_scanner = NERScanner()
def test_scanner_none(self):
self.assertIsNone(NERScanner.scan(list(range(100))))
self.assertIsNone(self.ner_scanner.scan(list(range(100))))
self.assertIsNone(
NERScanner.scan(
self.ner_scanner.scan(
" ".split(
"Lorem ipsum dolor sit amet, consectetur adipiscing elit. Nam consequat quam sagittis convallis cursus."
)
@ -34,7 +36,7 @@ class NERScannerTest(TestCase):
def test_scanner_sensitive(self):
self.assertEqual(
NERScanner.scan(
self.ner_scanner.scan(
[
"geraldc@gmail.com",
"saratimithi@godesign.com",
@ -44,6 +46,8 @@ class NERScannerTest(TestCase):
TagType.SENSITIVE,
)
self.assertEqual(
NERScanner.scan(["im ok", "saratimithi@godesign.com", "not sensitive"]).tag,
self.ner_scanner.scan(
["im ok", "saratimithi@godesign.com", "not sensitive"]
).tag,
TagType.SENSITIVE,
)