chore: normalize more optional imports (#5251)

* docstore filters

* modeling metrics

* doc language classifier

* file converter

* docx converter

* tika

* preprocessor

* context matcher

* pylint
This commit is contained in:
ZanSara 2023-08-09 09:27:53 +02:00 committed by GitHub
parent 30e6c7ac43
commit c27622e1bc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 59 additions and 175 deletions

View File

@ -5,17 +5,13 @@ from collections import defaultdict
from haystack.document_stores.utils import convert_date_to_rfc3339 from haystack.document_stores.utils import convert_date_to_rfc3339
from haystack.errors import FilterError from haystack.errors import FilterError
from haystack.lazy_imports import LazyImport
logger = logging.getLogger(__file__) logger = logging.getLogger(__file__)
try: with LazyImport("Run 'pip install farm-haystack[sql]'") as sql_import:
from sqlalchemy.sql import select from sqlalchemy.sql import select
from sqlalchemy import and_, or_ from sqlalchemy import and_, or_
except ImportError as exc:
logger.debug("sqlalchemy could not be imported. Run 'pip install farm-haystack[sql]' to fix this issue.")
select = None
and_ = None
or_ = None
def nested_defaultdict() -> defaultdict: def nested_defaultdict() -> defaultdict:
@ -319,10 +315,7 @@ class NotOperation(LogicalFilterClause):
return {"bool": {"must_not": conditions}} return {"bool": {"must_not": conditions}}
def convert_to_sql(self, meta_document_orm): def convert_to_sql(self, meta_document_orm):
if not select: sql_import.check()
raise ImportError(
"sqlalchemy could not be imported. Run 'pip install farm-haystack[sql]' to fix this issue."
)
conditions = [ conditions = [
meta_document_orm.document_id.in_(condition.convert_to_sql(meta_document_orm)) meta_document_orm.document_id.in_(condition.convert_to_sql(meta_document_orm))
for condition in self.conditions for condition in self.conditions
@ -370,10 +363,7 @@ class AndOperation(LogicalFilterClause):
return {"bool": {"must": conditions}} return {"bool": {"must": conditions}}
def convert_to_sql(self, meta_document_orm): def convert_to_sql(self, meta_document_orm):
if not select: sql_import.check()
raise ImportError(
"sqlalchemy could not be imported. Run 'pip install farm-haystack[sql]' to fix this issue."
)
conditions = [ conditions = [
meta_document_orm.document_id.in_(condition.convert_to_sql(meta_document_orm)) meta_document_orm.document_id.in_(condition.convert_to_sql(meta_document_orm))
for condition in self.conditions for condition in self.conditions
@ -406,10 +396,7 @@ class OrOperation(LogicalFilterClause):
return {"bool": {"should": conditions}} return {"bool": {"should": conditions}}
def convert_to_sql(self, meta_document_orm): def convert_to_sql(self, meta_document_orm):
if not select: sql_import.check()
raise ImportError(
"sqlalchemy could not be imported. Run 'pip install farm-haystack[sql]' to fix this issue."
)
conditions = [ conditions = [
meta_document_orm.document_id.in_(condition.convert_to_sql(meta_document_orm)) meta_document_orm.document_id.in_(condition.convert_to_sql(meta_document_orm))
for condition in self.conditions for condition in self.conditions
@ -455,10 +442,7 @@ class EqOperation(ComparisonOperation):
return {"term": {self.field_name: self.comparison_value}} return {"term": {self.field_name: self.comparison_value}}
def convert_to_sql(self, meta_document_orm): def convert_to_sql(self, meta_document_orm):
if not select: sql_import.check()
raise ImportError(
"sqlalchemy could not be imported. Run 'pip install farm-haystack[sql]' to fix this issue."
)
return select([meta_document_orm.document_id]).where( return select([meta_document_orm.document_id]).where(
meta_document_orm.name == self.field_name, meta_document_orm.value == self.comparison_value meta_document_orm.name == self.field_name, meta_document_orm.value == self.comparison_value
) )
@ -498,10 +482,7 @@ class InOperation(ComparisonOperation):
return {"terms": {self.field_name: self.comparison_value}} return {"terms": {self.field_name: self.comparison_value}}
def convert_to_sql(self, meta_document_orm): def convert_to_sql(self, meta_document_orm):
if not select: sql_import.check()
raise ImportError(
"sqlalchemy could not be imported. Run 'pip install farm-haystack[sql]' to fix this issue."
)
return select([meta_document_orm.document_id]).where( return select([meta_document_orm.document_id]).where(
meta_document_orm.name == self.field_name, meta_document_orm.value.in_(self.comparison_value) meta_document_orm.name == self.field_name, meta_document_orm.value.in_(self.comparison_value)
) )
@ -544,10 +525,7 @@ class NeOperation(ComparisonOperation):
return {"bool": {"must_not": {"term": {self.field_name: self.comparison_value}}}} return {"bool": {"must_not": {"term": {self.field_name: self.comparison_value}}}}
def convert_to_sql(self, meta_document_orm): def convert_to_sql(self, meta_document_orm):
if not select: sql_import.check()
raise ImportError(
"sqlalchemy could not be imported. Run 'pip install farm-haystack[sql]' to fix this issue."
)
return select([meta_document_orm.document_id]).where( return select([meta_document_orm.document_id]).where(
meta_document_orm.name == self.field_name, meta_document_orm.value != self.comparison_value meta_document_orm.name == self.field_name, meta_document_orm.value != self.comparison_value
) )
@ -587,10 +565,7 @@ class NinOperation(ComparisonOperation):
return {"bool": {"must_not": {"terms": {self.field_name: self.comparison_value}}}} return {"bool": {"must_not": {"terms": {self.field_name: self.comparison_value}}}}
def convert_to_sql(self, meta_document_orm): def convert_to_sql(self, meta_document_orm):
if not select: sql_import.check()
raise ImportError(
"sqlalchemy could not be imported. Run 'pip install farm-haystack[sql]' to fix this issue."
)
return select([meta_document_orm.document_id]).where( return select([meta_document_orm.document_id]).where(
meta_document_orm.name == self.field_name, meta_document_orm.value.notin_(self.comparison_value) meta_document_orm.name == self.field_name, meta_document_orm.value.notin_(self.comparison_value)
) )
@ -638,10 +613,7 @@ class GtOperation(ComparisonOperation):
return {"range": {self.field_name: {"gt": self.comparison_value}}} return {"range": {self.field_name: {"gt": self.comparison_value}}}
def convert_to_sql(self, meta_document_orm): def convert_to_sql(self, meta_document_orm):
if not select: sql_import.check()
raise ImportError(
"sqlalchemy could not be imported. Run 'pip install farm-haystack[sql]' to fix this issue."
)
return select([meta_document_orm.document_id]).where( return select([meta_document_orm.document_id]).where(
meta_document_orm.name == self.field_name, meta_document_orm.value > self.comparison_value meta_document_orm.name == self.field_name, meta_document_orm.value > self.comparison_value
) )
@ -682,10 +654,7 @@ class GteOperation(ComparisonOperation):
return {"range": {self.field_name: {"gte": self.comparison_value}}} return {"range": {self.field_name: {"gte": self.comparison_value}}}
def convert_to_sql(self, meta_document_orm): def convert_to_sql(self, meta_document_orm):
if not select: sql_import.check()
raise ImportError(
"sqlalchemy could not be imported. Run 'pip install farm-haystack[sql]' to fix this issue."
)
return select([meta_document_orm.document_id]).where( return select([meta_document_orm.document_id]).where(
meta_document_orm.name == self.field_name, meta_document_orm.value >= self.comparison_value meta_document_orm.name == self.field_name, meta_document_orm.value >= self.comparison_value
) )
@ -726,10 +695,7 @@ class LtOperation(ComparisonOperation):
return {"range": {self.field_name: {"lt": self.comparison_value}}} return {"range": {self.field_name: {"lt": self.comparison_value}}}
def convert_to_sql(self, meta_document_orm): def convert_to_sql(self, meta_document_orm):
if not select: sql_import.check()
raise ImportError(
"sqlalchemy could not be imported. Run 'pip install farm-haystack[sql]' to fix this issue."
)
return select([meta_document_orm.document_id]).where( return select([meta_document_orm.document_id]).where(
meta_document_orm.name == self.field_name, meta_document_orm.value < self.comparison_value meta_document_orm.name == self.field_name, meta_document_orm.value < self.comparison_value
) )
@ -770,10 +736,7 @@ class LteOperation(ComparisonOperation):
return {"range": {self.field_name: {"lte": self.comparison_value}}} return {"range": {self.field_name: {"lte": self.comparison_value}}}
def convert_to_sql(self, meta_document_orm): def convert_to_sql(self, meta_document_orm):
if not select: sql_import.check()
raise ImportError(
"sqlalchemy could not be imported. Run 'pip install farm-haystack[sql]' to fix this issue."
)
return select([meta_document_orm.document_id]).where( return select([meta_document_orm.document_id]).where(
meta_document_orm.name == self.field_name, meta_document_orm.value <= self.comparison_value meta_document_orm.name == self.field_name, meta_document_orm.value <= self.comparison_value
) )

View File

@ -8,30 +8,15 @@ from transformers import AutoConfig
from haystack.modeling.model.prediction_head import PredictionHead from haystack.modeling.model.prediction_head import PredictionHead
from haystack.modeling.utils import flatten_list from haystack.modeling.utils import flatten_list
from haystack.lazy_imports import LazyImport
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
try: with LazyImport("Run 'pip install farm-haystack[metrics]'") as metrics_import:
from scipy.stats import pearsonr, spearmanr from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import classification_report, f1_score, matthews_corrcoef, mean_squared_error, r2_score from sklearn.metrics import classification_report, f1_score, matthews_corrcoef, mean_squared_error, r2_score
from sklearn.metrics.pairwise import cosine_similarity from sklearn.metrics.pairwise import cosine_similarity
except ImportError as exc:
logger.debug("scipy or sklearn could not be imported. Run 'pip install farm-haystack[metrics]' to fix this issue.")
pearsonr = None
spearmanr = None
classification_report = None
f1_score = None
matthews_corrcoef = None
mean_squared_error = None
r2_score = None
cosine_similarity = None
try:
from seqeval.metrics import classification_report as token_classification_report from seqeval.metrics import classification_report as token_classification_report
except ImportError as exc:
logger.debug("seqeval could not be imported. Run 'pip install farm-haystack[metrics]' to fix this issue.")
token_classification_report = None
registered_metrics = {} registered_metrics = {}
@ -71,28 +56,19 @@ def simple_accuracy(preds, labels):
def acc_and_f1(preds, labels): def acc_and_f1(preds, labels):
if not f1_score: metrics_import.check()
raise ImportError(
"scipy or sklearn could not be imported. Run 'pip install farm-haystack[metrics]' to fix this issue."
)
acc = simple_accuracy(preds, labels) acc = simple_accuracy(preds, labels)
f1 = f1_score(y_true=labels, y_pred=preds) f1 = f1_score(y_true=labels, y_pred=preds)
return {"acc": acc["acc"], "f1": f1, "acc_and_f1": (acc["acc"] + f1) / 2} return {"acc": acc["acc"], "f1": f1, "acc_and_f1": (acc["acc"] + f1) / 2}
def f1_macro(preds, labels): def f1_macro(preds, labels):
if not f1_score: metrics_import.check()
raise ImportError(
"scipy or sklearn could not be imported. Run 'pip install farm-haystack[metrics]' to fix this issue."
)
return {"f1_macro": f1_score(y_true=labels, y_pred=preds, average="macro")} return {"f1_macro": f1_score(y_true=labels, y_pred=preds, average="macro")}
def pearson_and_spearman(preds, labels): def pearson_and_spearman(preds, labels):
if not pearsonr or not spearmanr: metrics_import.check()
raise ImportError(
"scipy or sklearn could not be imported. Run 'pip install farm-haystack[metrics]' to fix this issue."
)
pearson_corr = pearsonr(preds, labels)[0] pearson_corr = pearsonr(preds, labels)[0]
spearman_corr = spearmanr(preds, labels)[0] spearman_corr = spearmanr(preds, labels)[0]
return {"pearson": pearson_corr, "spearman": spearman_corr, "corr": (pearson_corr + spearman_corr) / 2} return {"pearson": pearson_corr, "spearman": spearman_corr, "corr": (pearson_corr + spearman_corr) / 2}
@ -124,14 +100,8 @@ def compute_metrics(metric: str, preds, labels):
} }
assert len(preds) == len(labels) assert len(preds) == len(labels)
if metric in FUNCTION_FOR_METRIC.keys(): if metric in FUNCTION_FOR_METRIC.keys():
if ( # pylint: disable=too-many-boolean-expressions if metric in ["mcc", "mse", "r2"]:
(metric == "mcc" and not matthews_corrcoef) metrics_import.check()
or (metric == "mse" and not mean_squared_error)
or (metric == "r2" and not r2_score)
):
raise ImportError(
"scipy or sklearn could not be imported. Run 'pip install farm-haystack[metrics]' to fix this issue."
)
return FUNCTION_FOR_METRIC[metric](preds, labels) return FUNCTION_FOR_METRIC[metric](preds, labels)
elif isinstance(metric, list): elif isinstance(metric, list):
ret = {} ret = {}
@ -149,16 +119,10 @@ def compute_report_metrics(head: PredictionHead, preds, labels):
if head.ph_output_type in registered_reports: if head.ph_output_type in registered_reports:
report_fn = registered_reports[head.ph_output_type] # type: ignore [index] report_fn = registered_reports[head.ph_output_type] # type: ignore [index]
elif head.ph_output_type == "per_token": elif head.ph_output_type == "per_token":
if not token_classification_report: metrics_import.check()
raise ImportError(
"seqeval could not be imported. Run 'pip install farm-haystack[metrics]' to fix this issue."
)
report_fn = token_classification_report report_fn = token_classification_report
elif head.ph_output_type == "per_sequence": elif head.ph_output_type == "per_sequence":
if not classification_report: metrics_import.check()
raise ImportError(
"scipy or sklearn could not be imported. Run 'pip install farm-haystack[metrics]' to fix this issue."
)
report_fn = classification_report report_fn = classification_report
elif head.ph_output_type == "per_token_squad": elif head.ph_output_type == "per_token_squad":
report_fn = lambda *args, **kwargs: "Not Implemented" # pylint: disable=unnecessary-lambda-assignment report_fn = lambda *args, **kwargs: "Not Implemented" # pylint: disable=unnecessary-lambda-assignment
@ -453,11 +417,7 @@ def semantic_answer_similarity(
https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained
:return: top_1_sas, top_k_sas, pred_label_matrix :return: top_1_sas, top_k_sas, pred_label_matrix
""" """
if not cosine_similarity: metrics_import.check()
raise ImportError(
"scipy or sklearn could not be imported. Run 'pip install farm-haystack[metrics]' to fix this issue."
)
assert len(predictions) == len(gold_labels) assert len(predictions) == len(gold_labels)
config = AutoConfig.from_pretrained(sas_model_name_or_path, use_auth_token=use_auth_token) config = AutoConfig.from_pretrained(sas_model_name_or_path, use_auth_token=use_auth_token)

View File

@ -1,21 +1,16 @@
import logging import logging
from typing import List, Optional from typing import List, Optional
from haystack.nodes.base import Document from haystack.nodes.base import Document
from haystack.nodes.doc_language_classifier.base import BaseDocumentLanguageClassifier from haystack.nodes.doc_language_classifier.base import BaseDocumentLanguageClassifier
from haystack.lazy_imports import LazyImport
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
try: with LazyImport("Run 'pip install farm-haystack[preprocessing]' or 'pip install langdetect'") as langdetect_import:
import langdetect import langdetect
except (ImportError, ModuleNotFoundError) as exc:
logger.debug(
"langdetect could not be imported. "
"Run 'pip install farm-haystack[preprocessing]' or 'pip install langdetect' to fix this issue."
)
langdetect = None
class LangdetectDocumentLanguageClassifier(BaseDocumentLanguageClassifier): class LangdetectDocumentLanguageClassifier(BaseDocumentLanguageClassifier):
@ -59,11 +54,7 @@ class LangdetectDocumentLanguageClassifier(BaseDocumentLanguageClassifier):
:param languages_to_route: A list of languages in ISO code, each corresponding to a different output edge (see :param languages_to_route: A list of languages in ISO code, each corresponding to a different output edge (see
[langdetect` documentation](https://github.com/Mimino666/langdetect#languages)). [langdetect` documentation](https://github.com/Mimino666/langdetect#languages)).
""" """
if not langdetect: langdetect_import.check()
raise ImportError(
"langdetect could not be imported. "
"Run 'pip install farm-haystack[file-conversion]' or 'pip install langdetect' to fix this issue."
)
super().__init__(route_by_language=route_by_language, languages_to_route=languages_to_route) super().__init__(route_by_language=route_by_language, languages_to_route=languages_to_route)
def predict(self, documents: List[Document], batch_size: Optional[int] = None) -> List[Document]: def predict(self, documents: List[Document], batch_size: Optional[int] = None) -> List[Document]:

View File

@ -7,17 +7,14 @@ from pathlib import Path
from tqdm import tqdm from tqdm import tqdm
from haystack.nodes.base import BaseComponent from haystack.nodes.base import BaseComponent
from haystack.schema import Document from haystack.schema import Document
from haystack.lazy_imports import LazyImport
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
try:
with LazyImport("Run 'pip install farm-haystack[preprocessing]' or 'pip install langdetect'") as langdetect_import:
import langdetect import langdetect
except (ImportError, ModuleNotFoundError) as exc:
logger.debug(
"langdetect could not be imported. "
"Run 'pip install farm-haystack[preprocessing]' or 'pip install langdetect' to fix this issue."
)
langdetect = None
# https://en.wikipedia.org/wiki/Ligature_(writing) # https://en.wikipedia.org/wiki/Ligature_(writing)
@ -138,13 +135,17 @@ class BaseConverter(BaseComponent):
return True return True
lang = None lang = None
if not langdetect: try:
logger.debug("langdetect could not be imported. Haystack won't try to guess the document language.") langdetect_import.check()
else: lang = langdetect.detect(text)
try: except langdetect.lang_detect_exception.LangDetectException:
lang = langdetect.detect(text) pass
except langdetect.lang_detect_exception.LangDetectException: except ImportError as exc:
pass logger.debug(
"langdetect could not be imported. Haystack won't try to guess the document language. "
"Original error: %s",
exc,
)
return lang in valid_languages return lang in valid_languages

View File

@ -5,19 +5,14 @@ from pathlib import Path
from haystack.nodes.file_converter.base import BaseConverter from haystack.nodes.file_converter.base import BaseConverter
from haystack.schema import Document from haystack.schema import Document
from haystack.lazy_imports import LazyImport
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
try: with LazyImport("Run 'pip install farm-haystack[file-conversion]' or 'pip install docx'") as docx_import:
import docx import docx
except ImportError as exc:
logger.debug(
"docx could not be imported. "
"Run 'pip install farm-haystack[file-conversion]' or 'pip install python-docx' to fix this issue."
)
docx = None
class DocxToTextConverter(BaseConverter): class DocxToTextConverter(BaseConverter):
@ -28,11 +23,7 @@ class DocxToTextConverter(BaseConverter):
id_hash_keys: Optional[List[str]] = None, id_hash_keys: Optional[List[str]] = None,
progress_bar: bool = True, progress_bar: bool = True,
): ):
if not docx: docx_import.check()
raise ImportError(
"docx could not be imported. "
"Run 'pip install farm-haystack[file-conversion]' or 'pip install python-docx' to fix this issue."
)
super().__init__( super().__init__(
remove_numeric_tables=remove_numeric_tables, remove_numeric_tables=remove_numeric_tables,
valid_languages=valid_languages, valid_languages=valid_languages,

View File

@ -10,25 +10,21 @@ import requests
from haystack.nodes.file_converter.base import BaseConverter from haystack.nodes.file_converter.base import BaseConverter
from haystack.schema import Document from haystack.schema import Document
from haystack.lazy_imports import LazyImport
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
try: with LazyImport("Run 'pip install farm-haystack[file-conversion]' or 'pip install tika'") as tika_import:
from tika import parser as tika_parser from tika import parser as tika_parser
except ImportError as exc:
logger.debug(
"tika could not be imported. "
"Run 'pip install farm-haystack[file-conversion]' or 'pip install tika' to fix this issue."
)
tika_parser = None
TIKA_CONTAINER_NAME = "tika" TIKA_CONTAINER_NAME = "tika"
def launch_tika(sleep=15, delete_existing=False): def launch_tika(sleep=15, delete_existing=False):
tika_import.check()
# Start a Tika server via Docker # Start a Tika server via Docker
logger.debug("Starting Tika ...") logger.debug("Starting Tika ...")
@ -54,6 +50,7 @@ def launch_tika(sleep=15, delete_existing=False):
class TikaXHTMLParser(HTMLParser): class TikaXHTMLParser(HTMLParser):
# Use the built-in HTML parser with minimum dependencies # Use the built-in HTML parser with minimum dependencies
def __init__(self): def __init__(self):
tika_import.check()
self.ingest = True self.ingest = True
self.page = "" self.page = ""
self.pages: List[str] = [] self.pages: List[str] = []
@ -107,11 +104,7 @@ class TikaConverter(BaseConverter):
as a float, or a :ref:`(connect timeout, read timeout) <timeouts>` tuple. as a float, or a :ref:`(connect timeout, read timeout) <timeouts>` tuple.
Defaults to 10 seconds. Defaults to 10 seconds.
""" """
if not tika_parser: tika_import.check()
raise ImportError(
"tika could not be imported. "
"Run 'pip install farm-haystack[file-conversion]' or 'pip install tika' to fix this issue."
)
super().__init__( super().__init__(
remove_numeric_tables=remove_numeric_tables, valid_languages=valid_languages, id_hash_keys=id_hash_keys remove_numeric_tables=remove_numeric_tables, valid_languages=valid_languages, id_hash_keys=id_hash_keys
) )

View File

@ -15,19 +15,14 @@ from more_itertools import windowed
from haystack.nodes.preprocessor.base import BasePreProcessor from haystack.nodes.preprocessor.base import BasePreProcessor
from haystack.errors import HaystackError from haystack.errors import HaystackError
from haystack.schema import Document from haystack.schema import Document
from haystack.lazy_imports import LazyImport
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
try: with LazyImport("Run 'pip install farm-haystack[preprocessing]' or 'pip install nltk'") as nltk_import:
import nltk import nltk
except ImportError as exc:
logger.debug(
"nltk could not be imported. "
"Run 'pip install farm-haystack[preprocessing]' or 'pip install nltk' to fix this issue."
)
nltk = None
iso639_to_nltk = { iso639_to_nltk = {
@ -108,17 +103,16 @@ class PreProcessor(BasePreProcessor):
`max_char_check`-th char, regardless of any other constraint. If the resulting documents are still too long, `max_char_check`-th char, regardless of any other constraint. If the resulting documents are still too long,
they'll be cut again until all fragments are below the maximum allowed length. they'll be cut again until all fragments are below the maximum allowed length.
""" """
nltk_import.check()
if remove_substrings is None: if remove_substrings is None:
remove_substrings = [] remove_substrings = []
super().__init__() super().__init__()
try: try:
if nltk: nltk.data.find("tokenizers/punkt")
nltk.data.find("tokenizers/punkt")
except LookupError: except LookupError:
try: try:
if nltk: nltk.download("punkt")
nltk.download("punkt")
except FileExistsError as error: except FileExistsError as error:
logger.debug("NLTK punkt tokenizer seems to be already downloaded. Error message: %s", error) logger.debug("NLTK punkt tokenizer seems to be already downloaded. Error message: %s", error)
pass pass
@ -830,11 +824,6 @@ class PreProcessor(BasePreProcessor):
return sentences return sentences
def _load_sentence_tokenizer(self, language_name: Optional[str]) -> "nltk.tokenize.punkt.PunktSentenceTokenizer": def _load_sentence_tokenizer(self, language_name: Optional[str]) -> "nltk.tokenize.punkt.PunktSentenceTokenizer":
if not nltk:
raise ImportError(
"nltk could not be imported. "
"Run 'pip install farm-haystack[preprocessing]' or 'pip install nltk' to fix this issue."
)
# Try to load a custom model from 'tokenizer_model_path' # Try to load a custom model from 'tokenizer_model_path'
if self.tokenizer_model_folder is not None: if self.tokenizer_model_folder is not None:
tokenizer_model_path = Path(self.tokenizer_model_folder).absolute() / f"{self.language}.pickle" tokenizer_model_path = Path(self.tokenizer_model_folder).absolute() / f"{self.language}.pickle"

View File

@ -8,14 +8,13 @@ from collections import namedtuple
from tqdm import tqdm from tqdm import tqdm
from haystack.lazy_imports import LazyImport
logger = logging.getLogger(__file__) logger = logging.getLogger(__file__)
try: with LazyImport("Run 'pip install farm-haystack[metrics]' or 'pip install rapidfuzz'") as rapidfuzz_import:
from rapidfuzz import fuzz from rapidfuzz import fuzz
except ImportError as exc:
logger.debug("rapidfuzz could not be imported. Run 'pip install farm-haystack[metrics]' to fix this issue.")
fuzz = None # type: ignore
_CandidateScore = namedtuple("_CandidateScore", ["context_id", "candidate_id", "score"]) _CandidateScore = namedtuple("_CandidateScore", ["context_id", "candidate_id", "score"])
@ -55,10 +54,7 @@ def calculate_context_similarity(
we cut the context on the same side, recalculate the score and take the mean of both. we cut the context on the same side, recalculate the score and take the mean of both.
Thus [AB] <-> [BC] (score ~50) gets recalculated with B <-> B (score ~100) scoring ~75 in total. Thus [AB] <-> [BC] (score ~50) gets recalculated with B <-> B (score ~100) scoring ~75 in total.
""" """
if not fuzz: rapidfuzz_import.check()
raise ImportError(
"rapidfuzz could not be imported. Run 'pip install farm-haystack[metrics]' to fix this issue."
)
# we need to handle short contexts/contents (e.g single word) # we need to handle short contexts/contents (e.g single word)
# as they produce high scores by matching if the chars of the word are contained in the other one # as they produce high scores by matching if the chars of the word are contained in the other one
# this has to be done after normalizing # this has to be done after normalizing