chore: normalize more optional imports (#5251)

* docstore filters

* modeling metrics

* doc language classifier

* file converter

* docx converter

* tika

* preprocessor

* context matcher

* pylint
This commit is contained in:
ZanSara 2023-08-09 09:27:53 +02:00 committed by GitHub
parent 30e6c7ac43
commit c27622e1bc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 59 additions and 175 deletions

View File

@ -5,17 +5,13 @@ from collections import defaultdict
from haystack.document_stores.utils import convert_date_to_rfc3339
from haystack.errors import FilterError
from haystack.lazy_imports import LazyImport
logger = logging.getLogger(__file__)
try:
with LazyImport("Run 'pip install farm-haystack[sql]'") as sql_import:
from sqlalchemy.sql import select
from sqlalchemy import and_, or_
except ImportError as exc:
logger.debug("sqlalchemy could not be imported. Run 'pip install farm-haystack[sql]' to fix this issue.")
select = None
and_ = None
or_ = None
def nested_defaultdict() -> defaultdict:
@ -319,10 +315,7 @@ class NotOperation(LogicalFilterClause):
return {"bool": {"must_not": conditions}}
def convert_to_sql(self, meta_document_orm):
if not select:
raise ImportError(
"sqlalchemy could not be imported. Run 'pip install farm-haystack[sql]' to fix this issue."
)
sql_import.check()
conditions = [
meta_document_orm.document_id.in_(condition.convert_to_sql(meta_document_orm))
for condition in self.conditions
@ -370,10 +363,7 @@ class AndOperation(LogicalFilterClause):
return {"bool": {"must": conditions}}
def convert_to_sql(self, meta_document_orm):
if not select:
raise ImportError(
"sqlalchemy could not be imported. Run 'pip install farm-haystack[sql]' to fix this issue."
)
sql_import.check()
conditions = [
meta_document_orm.document_id.in_(condition.convert_to_sql(meta_document_orm))
for condition in self.conditions
@ -406,10 +396,7 @@ class OrOperation(LogicalFilterClause):
return {"bool": {"should": conditions}}
def convert_to_sql(self, meta_document_orm):
if not select:
raise ImportError(
"sqlalchemy could not be imported. Run 'pip install farm-haystack[sql]' to fix this issue."
)
sql_import.check()
conditions = [
meta_document_orm.document_id.in_(condition.convert_to_sql(meta_document_orm))
for condition in self.conditions
@ -455,10 +442,7 @@ class EqOperation(ComparisonOperation):
return {"term": {self.field_name: self.comparison_value}}
def convert_to_sql(self, meta_document_orm):
if not select:
raise ImportError(
"sqlalchemy could not be imported. Run 'pip install farm-haystack[sql]' to fix this issue."
)
sql_import.check()
return select([meta_document_orm.document_id]).where(
meta_document_orm.name == self.field_name, meta_document_orm.value == self.comparison_value
)
@ -498,10 +482,7 @@ class InOperation(ComparisonOperation):
return {"terms": {self.field_name: self.comparison_value}}
def convert_to_sql(self, meta_document_orm):
if not select:
raise ImportError(
"sqlalchemy could not be imported. Run 'pip install farm-haystack[sql]' to fix this issue."
)
sql_import.check()
return select([meta_document_orm.document_id]).where(
meta_document_orm.name == self.field_name, meta_document_orm.value.in_(self.comparison_value)
)
@ -544,10 +525,7 @@ class NeOperation(ComparisonOperation):
return {"bool": {"must_not": {"term": {self.field_name: self.comparison_value}}}}
def convert_to_sql(self, meta_document_orm):
if not select:
raise ImportError(
"sqlalchemy could not be imported. Run 'pip install farm-haystack[sql]' to fix this issue."
)
sql_import.check()
return select([meta_document_orm.document_id]).where(
meta_document_orm.name == self.field_name, meta_document_orm.value != self.comparison_value
)
@ -587,10 +565,7 @@ class NinOperation(ComparisonOperation):
return {"bool": {"must_not": {"terms": {self.field_name: self.comparison_value}}}}
def convert_to_sql(self, meta_document_orm):
if not select:
raise ImportError(
"sqlalchemy could not be imported. Run 'pip install farm-haystack[sql]' to fix this issue."
)
sql_import.check()
return select([meta_document_orm.document_id]).where(
meta_document_orm.name == self.field_name, meta_document_orm.value.notin_(self.comparison_value)
)
@ -638,10 +613,7 @@ class GtOperation(ComparisonOperation):
return {"range": {self.field_name: {"gt": self.comparison_value}}}
def convert_to_sql(self, meta_document_orm):
if not select:
raise ImportError(
"sqlalchemy could not be imported. Run 'pip install farm-haystack[sql]' to fix this issue."
)
sql_import.check()
return select([meta_document_orm.document_id]).where(
meta_document_orm.name == self.field_name, meta_document_orm.value > self.comparison_value
)
@ -682,10 +654,7 @@ class GteOperation(ComparisonOperation):
return {"range": {self.field_name: {"gte": self.comparison_value}}}
def convert_to_sql(self, meta_document_orm):
if not select:
raise ImportError(
"sqlalchemy could not be imported. Run 'pip install farm-haystack[sql]' to fix this issue."
)
sql_import.check()
return select([meta_document_orm.document_id]).where(
meta_document_orm.name == self.field_name, meta_document_orm.value >= self.comparison_value
)
@ -726,10 +695,7 @@ class LtOperation(ComparisonOperation):
return {"range": {self.field_name: {"lt": self.comparison_value}}}
def convert_to_sql(self, meta_document_orm):
if not select:
raise ImportError(
"sqlalchemy could not be imported. Run 'pip install farm-haystack[sql]' to fix this issue."
)
sql_import.check()
return select([meta_document_orm.document_id]).where(
meta_document_orm.name == self.field_name, meta_document_orm.value < self.comparison_value
)
@ -770,10 +736,7 @@ class LteOperation(ComparisonOperation):
return {"range": {self.field_name: {"lte": self.comparison_value}}}
def convert_to_sql(self, meta_document_orm):
if not select:
raise ImportError(
"sqlalchemy could not be imported. Run 'pip install farm-haystack[sql]' to fix this issue."
)
sql_import.check()
return select([meta_document_orm.document_id]).where(
meta_document_orm.name == self.field_name, meta_document_orm.value <= self.comparison_value
)

View File

@ -8,30 +8,15 @@ from transformers import AutoConfig
from haystack.modeling.model.prediction_head import PredictionHead
from haystack.modeling.utils import flatten_list
from haystack.lazy_imports import LazyImport
logger = logging.getLogger(__name__)
try:
with LazyImport("Run 'pip install farm-haystack[metrics]'") as metrics_import:
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import classification_report, f1_score, matthews_corrcoef, mean_squared_error, r2_score
from sklearn.metrics.pairwise import cosine_similarity
except ImportError as exc:
logger.debug("scipy or sklearn could not be imported. Run 'pip install farm-haystack[metrics]' to fix this issue.")
pearsonr = None
spearmanr = None
classification_report = None
f1_score = None
matthews_corrcoef = None
mean_squared_error = None
r2_score = None
cosine_similarity = None
try:
from seqeval.metrics import classification_report as token_classification_report
except ImportError as exc:
logger.debug("seqeval could not be imported. Run 'pip install farm-haystack[metrics]' to fix this issue.")
token_classification_report = None
registered_metrics = {}
@ -71,28 +56,19 @@ def simple_accuracy(preds, labels):
def acc_and_f1(preds, labels):
if not f1_score:
raise ImportError(
"scipy or sklearn could not be imported. Run 'pip install farm-haystack[metrics]' to fix this issue."
)
metrics_import.check()
acc = simple_accuracy(preds, labels)
f1 = f1_score(y_true=labels, y_pred=preds)
return {"acc": acc["acc"], "f1": f1, "acc_and_f1": (acc["acc"] + f1) / 2}
def f1_macro(preds, labels):
if not f1_score:
raise ImportError(
"scipy or sklearn could not be imported. Run 'pip install farm-haystack[metrics]' to fix this issue."
)
metrics_import.check()
return {"f1_macro": f1_score(y_true=labels, y_pred=preds, average="macro")}
def pearson_and_spearman(preds, labels):
if not pearsonr or not spearmanr:
raise ImportError(
"scipy or sklearn could not be imported. Run 'pip install farm-haystack[metrics]' to fix this issue."
)
metrics_import.check()
pearson_corr = pearsonr(preds, labels)[0]
spearman_corr = spearmanr(preds, labels)[0]
return {"pearson": pearson_corr, "spearman": spearman_corr, "corr": (pearson_corr + spearman_corr) / 2}
@ -124,14 +100,8 @@ def compute_metrics(metric: str, preds, labels):
}
assert len(preds) == len(labels)
if metric in FUNCTION_FOR_METRIC.keys():
if ( # pylint: disable=too-many-boolean-expressions
(metric == "mcc" and not matthews_corrcoef)
or (metric == "mse" and not mean_squared_error)
or (metric == "r2" and not r2_score)
):
raise ImportError(
"scipy or sklearn could not be imported. Run 'pip install farm-haystack[metrics]' to fix this issue."
)
if metric in ["mcc", "mse", "r2"]:
metrics_import.check()
return FUNCTION_FOR_METRIC[metric](preds, labels)
elif isinstance(metric, list):
ret = {}
@ -149,16 +119,10 @@ def compute_report_metrics(head: PredictionHead, preds, labels):
if head.ph_output_type in registered_reports:
report_fn = registered_reports[head.ph_output_type] # type: ignore [index]
elif head.ph_output_type == "per_token":
if not token_classification_report:
raise ImportError(
"seqeval could not be imported. Run 'pip install farm-haystack[metrics]' to fix this issue."
)
metrics_import.check()
report_fn = token_classification_report
elif head.ph_output_type == "per_sequence":
if not classification_report:
raise ImportError(
"scipy or sklearn could not be imported. Run 'pip install farm-haystack[metrics]' to fix this issue."
)
metrics_import.check()
report_fn = classification_report
elif head.ph_output_type == "per_token_squad":
report_fn = lambda *args, **kwargs: "Not Implemented" # pylint: disable=unnecessary-lambda-assignment
@ -453,11 +417,7 @@ def semantic_answer_similarity(
https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained
:return: top_1_sas, top_k_sas, pred_label_matrix
"""
if not cosine_similarity:
raise ImportError(
"scipy or sklearn could not be imported. Run 'pip install farm-haystack[metrics]' to fix this issue."
)
metrics_import.check()
assert len(predictions) == len(gold_labels)
config = AutoConfig.from_pretrained(sas_model_name_or_path, use_auth_token=use_auth_token)

View File

@ -1,21 +1,16 @@
import logging
from typing import List, Optional
from haystack.nodes.base import Document
from haystack.nodes.doc_language_classifier.base import BaseDocumentLanguageClassifier
from haystack.lazy_imports import LazyImport
logger = logging.getLogger(__name__)
try:
with LazyImport("Run 'pip install farm-haystack[preprocessing]' or 'pip install langdetect'") as langdetect_import:
import langdetect
except (ImportError, ModuleNotFoundError) as exc:
logger.debug(
"langdetect could not be imported. "
"Run 'pip install farm-haystack[preprocessing]' or 'pip install langdetect' to fix this issue."
)
langdetect = None
class LangdetectDocumentLanguageClassifier(BaseDocumentLanguageClassifier):
@ -59,11 +54,7 @@ class LangdetectDocumentLanguageClassifier(BaseDocumentLanguageClassifier):
:param languages_to_route: A list of languages in ISO code, each corresponding to a different output edge (see
[`langdetect` documentation](https://github.com/Mimino666/langdetect#languages)).
"""
if not langdetect:
raise ImportError(
"langdetect could not be imported. "
"Run 'pip install farm-haystack[file-conversion]' or 'pip install langdetect' to fix this issue."
)
langdetect_import.check()
super().__init__(route_by_language=route_by_language, languages_to_route=languages_to_route)
def predict(self, documents: List[Document], batch_size: Optional[int] = None) -> List[Document]:

View File

@ -7,17 +7,14 @@ from pathlib import Path
from tqdm import tqdm
from haystack.nodes.base import BaseComponent
from haystack.schema import Document
from haystack.lazy_imports import LazyImport
logger = logging.getLogger(__name__)
try:
with LazyImport("Run 'pip install farm-haystack[preprocessing]' or 'pip install langdetect'") as langdetect_import:
import langdetect
except (ImportError, ModuleNotFoundError) as exc:
logger.debug(
"langdetect could not be imported. "
"Run 'pip install farm-haystack[preprocessing]' or 'pip install langdetect' to fix this issue."
)
langdetect = None
# https://en.wikipedia.org/wiki/Ligature_(writing)
@ -138,13 +135,17 @@ class BaseConverter(BaseComponent):
return True
lang = None
if not langdetect:
logger.debug("langdetect could not be imported. Haystack won't try to guess the document language.")
else:
try:
lang = langdetect.detect(text)
except langdetect.lang_detect_exception.LangDetectException:
pass
try:
langdetect_import.check()
lang = langdetect.detect(text)
except langdetect.lang_detect_exception.LangDetectException:
pass
except ImportError as exc:
logger.debug(
"langdetect could not be imported. Haystack won't try to guess the document language. "
"Original error: %s",
exc,
)
return lang in valid_languages

View File

@ -5,19 +5,14 @@ from pathlib import Path
from haystack.nodes.file_converter.base import BaseConverter
from haystack.schema import Document
from haystack.lazy_imports import LazyImport
logger = logging.getLogger(__name__)
try:
with LazyImport("Run 'pip install farm-haystack[file-conversion]' or 'pip install docx'") as docx_import:
import docx
except ImportError as exc:
logger.debug(
"docx could not be imported. "
"Run 'pip install farm-haystack[file-conversion]' or 'pip install python-docx' to fix this issue."
)
docx = None
class DocxToTextConverter(BaseConverter):
@ -28,11 +23,7 @@ class DocxToTextConverter(BaseConverter):
id_hash_keys: Optional[List[str]] = None,
progress_bar: bool = True,
):
if not docx:
raise ImportError(
"docx could not be imported. "
"Run 'pip install farm-haystack[file-conversion]' or 'pip install python-docx' to fix this issue."
)
docx_import.check()
super().__init__(
remove_numeric_tables=remove_numeric_tables,
valid_languages=valid_languages,

View File

@ -10,25 +10,21 @@ import requests
from haystack.nodes.file_converter.base import BaseConverter
from haystack.schema import Document
from haystack.lazy_imports import LazyImport
logger = logging.getLogger(__name__)
try:
with LazyImport("Run 'pip install farm-haystack[file-conversion]' or 'pip install tika'") as tika_import:
from tika import parser as tika_parser
except ImportError as exc:
logger.debug(
"tika could not be imported. "
"Run 'pip install farm-haystack[file-conversion]' or 'pip install tika' to fix this issue."
)
tika_parser = None
TIKA_CONTAINER_NAME = "tika"
def launch_tika(sleep=15, delete_existing=False):
tika_import.check()
# Start a Tika server via Docker
logger.debug("Starting Tika ...")
@ -54,6 +50,7 @@ def launch_tika(sleep=15, delete_existing=False):
class TikaXHTMLParser(HTMLParser):
# Use the built-in HTML parser with minimum dependencies
def __init__(self):
tika_import.check()
self.ingest = True
self.page = ""
self.pages: List[str] = []
@ -107,11 +104,7 @@ class TikaConverter(BaseConverter):
as a float, or a :ref:`(connect timeout, read timeout) <timeouts>` tuple.
Defaults to 10 seconds.
"""
if not tika_parser:
raise ImportError(
"tika could not be imported. "
"Run 'pip install farm-haystack[file-conversion]' or 'pip install tika' to fix this issue."
)
tika_import.check()
super().__init__(
remove_numeric_tables=remove_numeric_tables, valid_languages=valid_languages, id_hash_keys=id_hash_keys
)

View File

@ -15,19 +15,14 @@ from more_itertools import windowed
from haystack.nodes.preprocessor.base import BasePreProcessor
from haystack.errors import HaystackError
from haystack.schema import Document
from haystack.lazy_imports import LazyImport
logger = logging.getLogger(__name__)
try:
with LazyImport("Run 'pip install farm-haystack[preprocessing]' or 'pip install nltk'") as nltk_import:
import nltk
except ImportError as exc:
logger.debug(
"nltk could not be imported. "
"Run 'pip install farm-haystack[preprocessing]' or 'pip install nltk' to fix this issue."
)
nltk = None
iso639_to_nltk = {
@ -108,17 +103,16 @@ class PreProcessor(BasePreProcessor):
`max_char_check`-th char, regardless of any other constraint. If the resulting documents are still too long,
they'll be cut again until all fragments are below the maximum allowed length.
"""
nltk_import.check()
if remove_substrings is None:
remove_substrings = []
super().__init__()
try:
if nltk:
nltk.data.find("tokenizers/punkt")
nltk.data.find("tokenizers/punkt")
except LookupError:
try:
if nltk:
nltk.download("punkt")
nltk.download("punkt")
except FileExistsError as error:
logger.debug("NLTK punkt tokenizer seems to be already downloaded. Error message: %s", error)
pass
@ -830,11 +824,6 @@ class PreProcessor(BasePreProcessor):
return sentences
def _load_sentence_tokenizer(self, language_name: Optional[str]) -> "nltk.tokenize.punkt.PunktSentenceTokenizer":
if not nltk:
raise ImportError(
"nltk could not be imported. "
"Run 'pip install farm-haystack[preprocessing]' or 'pip install nltk' to fix this issue."
)
# Try to load a custom model from 'tokenizer_model_path'
if self.tokenizer_model_folder is not None:
tokenizer_model_path = Path(self.tokenizer_model_folder).absolute() / f"{self.language}.pickle"

View File

@ -8,14 +8,13 @@ from collections import namedtuple
from tqdm import tqdm
from haystack.lazy_imports import LazyImport
logger = logging.getLogger(__file__)
try:
with LazyImport("Run 'pip install farm-haystack[metrics]' or 'pip install rapidfuzz'") as rapidfuzz_import:
from rapidfuzz import fuzz
except ImportError as exc:
logger.debug("rapidfuzz could not be imported. Run 'pip install farm-haystack[metrics]' to fix this issue.")
fuzz = None # type: ignore
_CandidateScore = namedtuple("_CandidateScore", ["context_id", "candidate_id", "score"])
@ -55,10 +54,7 @@ def calculate_context_similarity(
we cut the context on the same side, recalculate the score and take the mean of both.
Thus [AB] <-> [BC] (score ~50) gets recalculated with B <-> B (score ~100) scoring ~75 in total.
"""
if not fuzz:
raise ImportError(
"rapidfuzz could not be imported. Run 'pip install farm-haystack[metrics]' to fix this issue."
)
rapidfuzz_import.check()
# we need to handle short contexts/contents (e.g single word)
# as they produce high scores by matching if the chars of the word are contained in the other one
# this has to be done after normalizing