refactor!: extract evaluation and statistical dependencies (#4457)

* try-catch sklearn and scipy

* haystack imports

* linting

* mypy

* try to import baseretriever

* remove typing

* unused import

* remove more typing

* pylint

* isolate sql imports for postgres, which we don't use anyway

* remove stats

* replace expit

* als inmemory

* mypy

* feedback

* docker

* expit

* re-add njit
This commit is contained in:
ZanSara 2023-04-12 15:38:56 +02:00 committed by GitHub
parent 5d41e60d89
commit ba11d1c2a8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 180 additions and 44 deletions

View File

@ -45,7 +45,7 @@ target "base-cpu" {
build_image = "python:3.10-slim"
base_image = "python:3.10-slim"
haystack_version = "${HAYSTACK_VERSION}"
haystack_extras = notequal("",HAYSTACK_EXTRAS) ? "${HAYSTACK_EXTRAS}" : "[docstores,crawler,preprocessing,ocr,onnx,beir]"
haystack_extras = notequal("",HAYSTACK_EXTRAS) ? "${HAYSTACK_EXTRAS}" : "[docstores,crawler,preprocessing,ocr,onnx,metrics,beir]"
}
platforms = ["linux/amd64", "linux/arm64"]
}
@ -59,7 +59,7 @@ target "base-gpu" {
build_image = "pytorch/pytorch:1.13.1-cuda11.6-cudnn8-runtime"
base_image = "pytorch/pytorch:1.13.1-cuda11.6-cudnn8-runtime"
haystack_version = "${HAYSTACK_VERSION}"
haystack_extras = notequal("",HAYSTACK_EXTRAS) ? "${HAYSTACK_EXTRAS}" : "[docstores-gpu,crawler,preprocessing,ocr,onnx-gpu]"
haystack_extras = notequal("",HAYSTACK_EXTRAS) ? "${HAYSTACK_EXTRAS}" : "[docstores-gpu,crawler,preprocessing,ocr,onnx-gpu,metrics]"
}
platforms = ["linux/amd64", "linux/arm64"]
}

View File

@ -17,6 +17,7 @@ from haystack.errors import DuplicateDocumentError, DocumentStoreError, Haystack
from haystack.nodes.preprocessor import PreProcessor
from haystack.document_stores.utils import eval_data_from_json, eval_data_from_jsonl, squad_json_to_jsonl
from haystack.utils.labels import aggregate_labels
from haystack.utils.scipy_utils import expit
logger = logging.getLogger(__name__)
@ -31,11 +32,6 @@ except (ImportError, ModuleNotFoundError):
return f
@njit # (fastmath=True)
def expit(x: float) -> float:
return 1 / (1 + np.exp(-x))
class BaseKnowledgeGraph(BaseComponent):
"""
Base class for implementing Knowledge Graphs.

View File

@ -1,13 +1,22 @@
import logging
from typing import Union, List, Dict, Optional, Tuple
from abc import ABC, abstractmethod
from collections import defaultdict
from sqlalchemy.sql import select
from sqlalchemy import and_, or_
from haystack.document_stores.utils import convert_date_to_rfc3339
from haystack.errors import FilterError
# Name the logger after the module (not its file path) so it takes part in the
# package's logger hierarchy, like the other modules do.
logger = logging.getLogger(__name__)

# sqlalchemy is optional: when it is missing, the names below are set to None and
# every `convert_to_sql` implementation checks `select` before using it.
try:
    from sqlalchemy.sql import select
    from sqlalchemy import and_, or_
except ImportError:
    logger.debug("sqlalchemy could not be imported. Run 'pip install farm-haystack[sql]' to fix this issue.")
    select = None
    and_ = None
    or_ = None
def nested_defaultdict() -> defaultdict:
"""
@ -310,6 +319,10 @@ class NotOperation(LogicalFilterClause):
return {"bool": {"must_not": conditions}}
def convert_to_sql(self, meta_document_orm):
if not select:
raise ImportError(
"sqlalchemy could not be imported. Run 'pip install farm-haystack[sql]' to fix this issue."
)
conditions = [
meta_document_orm.document_id.in_(condition.convert_to_sql(meta_document_orm))
for condition in self.conditions
@ -357,6 +370,10 @@ class AndOperation(LogicalFilterClause):
return {"bool": {"must": conditions}}
def convert_to_sql(self, meta_document_orm):
if not select:
raise ImportError(
"sqlalchemy could not be imported. Run 'pip install farm-haystack[sql]' to fix this issue."
)
conditions = [
meta_document_orm.document_id.in_(condition.convert_to_sql(meta_document_orm))
for condition in self.conditions
@ -389,6 +406,10 @@ class OrOperation(LogicalFilterClause):
return {"bool": {"should": conditions}}
def convert_to_sql(self, meta_document_orm):
if not select:
raise ImportError(
"sqlalchemy could not be imported. Run 'pip install farm-haystack[sql]' to fix this issue."
)
conditions = [
meta_document_orm.document_id.in_(condition.convert_to_sql(meta_document_orm))
for condition in self.conditions
@ -434,6 +455,10 @@ class EqOperation(ComparisonOperation):
return {"term": {self.field_name: self.comparison_value}}
def convert_to_sql(self, meta_document_orm):
    """
    Build a query selecting the document ids of meta rows whose name is
    `self.field_name` and whose value equals `self.comparison_value`.

    :param meta_document_orm: ORM class exposing `document_id`, `name` and `value`.
    :raises ImportError: if sqlalchemy is not available.
    """
    if not select:
        raise ImportError(
            "sqlalchemy could not be imported. Run 'pip install farm-haystack[sql]' to fix this issue."
        )
    id_query = select([meta_document_orm.document_id])
    name_filter = meta_document_orm.name == self.field_name
    value_filter = meta_document_orm.value == self.comparison_value
    return id_query.where(name_filter, value_filter)
@ -466,6 +491,10 @@ class InOperation(ComparisonOperation):
return {"terms": {self.field_name: self.comparison_value}}
def convert_to_sql(self, meta_document_orm):
    """
    Build a query selecting the document ids of meta rows whose name is
    `self.field_name` and whose value is contained in `self.comparison_value`.

    :param meta_document_orm: ORM class exposing `document_id`, `name` and `value`.
    :raises ImportError: if sqlalchemy is not available.
    """
    if not select:
        raise ImportError(
            "sqlalchemy could not be imported. Run 'pip install farm-haystack[sql]' to fix this issue."
        )
    id_query = select([meta_document_orm.document_id])
    name_filter = meta_document_orm.name == self.field_name
    value_filter = meta_document_orm.value.in_(self.comparison_value)
    return id_query.where(name_filter, value_filter)
@ -508,6 +537,10 @@ class NeOperation(ComparisonOperation):
return {"bool": {"must_not": {"term": {self.field_name: self.comparison_value}}}}
def convert_to_sql(self, meta_document_orm):
    """
    Build a query selecting the document ids of meta rows whose name is
    `self.field_name` and whose value differs from `self.comparison_value`.

    :param meta_document_orm: ORM class exposing `document_id`, `name` and `value`.
    :raises ImportError: if sqlalchemy is not available.
    """
    if not select:
        raise ImportError(
            "sqlalchemy could not be imported. Run 'pip install farm-haystack[sql]' to fix this issue."
        )
    id_query = select([meta_document_orm.document_id])
    name_filter = meta_document_orm.name == self.field_name
    value_filter = meta_document_orm.value != self.comparison_value
    return id_query.where(name_filter, value_filter)
@ -540,6 +573,10 @@ class NinOperation(ComparisonOperation):
return {"bool": {"must_not": {"terms": {self.field_name: self.comparison_value}}}}
def convert_to_sql(self, meta_document_orm):
    """
    Build a query selecting the document ids of meta rows whose name is
    `self.field_name` and whose value is not contained in `self.comparison_value`.

    :param meta_document_orm: ORM class exposing `document_id`, `name` and `value`.
    :raises ImportError: if sqlalchemy is not available.
    """
    if not select:
        raise ImportError(
            "sqlalchemy could not be imported. Run 'pip install farm-haystack[sql]' to fix this issue."
        )
    id_query = select([meta_document_orm.document_id])
    name_filter = meta_document_orm.name == self.field_name
    value_filter = meta_document_orm.value.notin_(self.comparison_value)
    return id_query.where(name_filter, value_filter)
@ -582,6 +619,10 @@ class GtOperation(ComparisonOperation):
return {"range": {self.field_name: {"gt": self.comparison_value}}}
def convert_to_sql(self, meta_document_orm):
    """
    Build a query selecting the document ids of meta rows whose name is
    `self.field_name` and whose value is greater than `self.comparison_value`.

    :param meta_document_orm: ORM class exposing `document_id`, `name` and `value`.
    :raises ImportError: if sqlalchemy is not available.
    """
    if not select:
        raise ImportError(
            "sqlalchemy could not be imported. Run 'pip install farm-haystack[sql]' to fix this issue."
        )
    id_query = select([meta_document_orm.document_id])
    name_filter = meta_document_orm.name == self.field_name
    value_filter = meta_document_orm.value > self.comparison_value
    return id_query.where(name_filter, value_filter)
@ -617,6 +658,10 @@ class GteOperation(ComparisonOperation):
return {"range": {self.field_name: {"gte": self.comparison_value}}}
def convert_to_sql(self, meta_document_orm):
    """
    Build a query selecting the document ids of meta rows whose name is
    `self.field_name` and whose value is greater than or equal to
    `self.comparison_value`.

    :param meta_document_orm: ORM class exposing `document_id`, `name` and `value`.
    :raises ImportError: if sqlalchemy is not available.
    """
    if not select:
        raise ImportError(
            "sqlalchemy could not be imported. Run 'pip install farm-haystack[sql]' to fix this issue."
        )
    id_query = select([meta_document_orm.document_id])
    name_filter = meta_document_orm.name == self.field_name
    value_filter = meta_document_orm.value >= self.comparison_value
    return id_query.where(name_filter, value_filter)
@ -652,6 +697,10 @@ class LtOperation(ComparisonOperation):
return {"range": {self.field_name: {"lt": self.comparison_value}}}
def convert_to_sql(self, meta_document_orm):
    """
    Build a query selecting the document ids of meta rows whose name is
    `self.field_name` and whose value is less than `self.comparison_value`.

    :param meta_document_orm: ORM class exposing `document_id`, `name` and `value`.
    :raises ImportError: if sqlalchemy is not available.
    """
    if not select:
        raise ImportError(
            "sqlalchemy could not be imported. Run 'pip install farm-haystack[sql]' to fix this issue."
        )
    id_query = select([meta_document_orm.document_id])
    name_filter = meta_document_orm.name == self.field_name
    value_filter = meta_document_orm.value < self.comparison_value
    return id_query.where(name_filter, value_filter)
@ -687,6 +736,10 @@ class LteOperation(ComparisonOperation):
return {"range": {self.field_name: {"lte": self.comparison_value}}}
def convert_to_sql(self, meta_document_orm):
    """
    Build a query selecting the document ids of meta rows whose name is
    `self.field_name` and whose value is less than or equal to
    `self.comparison_value`.

    :param meta_document_orm: ORM class exposing `document_id`, `name` and `value`.
    :raises ImportError: if sqlalchemy is not available.
    """
    if not select:
        raise ImportError(
            "sqlalchemy could not be imported. Run 'pip install farm-haystack[sql]' to fix this issue."
        )
    id_query = select([meta_document_orm.document_id])
    name_filter = meta_document_orm.name == self.field_name
    value_filter = meta_document_orm.value <= self.comparison_value
    return id_query.where(name_filter, value_filter)

View File

@ -16,7 +16,6 @@ import torch
from tqdm.auto import tqdm
import rank_bm25
import pandas as pd
from scipy.special import expit
from haystack.schema import Document, FilterType, Label
from haystack.errors import DuplicateDocumentError, DocumentStoreError
@ -24,7 +23,9 @@ from haystack.document_stores import KeywordDocumentStore
from haystack.document_stores.base import get_batches_from_generator
from haystack.modeling.utils import initialize_device_settings
from haystack.document_stores.filter_utils import LogicalFilterClause
from haystack.nodes.retriever import DenseRetriever
from haystack.nodes.retriever.dense import DenseRetriever
from haystack.utils.scipy_utils import expit
logger = logging.getLogger(__name__)

View File

@ -10,13 +10,13 @@ try:
from pymilvus import FieldSchema, CollectionSchema, Collection, connections, utility
from pymilvus.client.abstract import QueryResult
from pymilvus.client.types import DataType
from haystack.document_stores.sql import SQLDocumentStore # type: ignore
except (ImportError, ModuleNotFoundError) as ie:
from haystack.utils.import_utils import _optional_component_not_installed
_optional_component_not_installed(__name__, "milvus2", ie)
_optional_component_not_installed(__name__, "milvus", ie)
from haystack.schema import Document, FilterType
from haystack.document_stores import SQLDocumentStore
from haystack.document_stores.base import get_batches_from_generator
from haystack.nodes.retriever import DenseRetriever

View File

@ -10,7 +10,6 @@ import time
from string import Template
import numpy as np
from scipy.special import expit
from tqdm.auto import tqdm
from pydantic.error_wrappers import ValidationError
@ -20,6 +19,7 @@ from haystack.document_stores.base import get_batches_from_generator
from haystack.document_stores.filter_utils import LogicalFilterClause
from haystack.errors import DocumentStoreError, HaystackError
from haystack.nodes.retriever import DenseRetriever
from haystack.utils.scipy_utils import expit
logger = logging.getLogger(__name__)

View File

@ -3,18 +3,36 @@ from functools import reduce
from typing import Callable, Dict, List, Optional, Tuple, Union
import numpy as np
from scipy.stats import pearsonr, spearmanr
from sentence_transformers import CrossEncoder, SentenceTransformer
from seqeval.metrics import classification_report as token_classification_report
from sklearn.metrics import classification_report, f1_score, matthews_corrcoef, mean_squared_error, r2_score
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoConfig
from haystack.modeling.model.prediction_head import PredictionHead
from haystack.modeling.utils import flatten_list
logger = logging.getLogger(__name__)

# Each optional library is guarded on its own, so that a missing scipy does not
# also disable the sklearn-based metrics (and vice versa). When an import fails
# the affected names are set to None; the metric functions check them before use.
try:
    from scipy.stats import pearsonr, spearmanr
except ImportError:
    logger.debug("scipy could not be imported. Run 'pip install farm-haystack[metrics]' to fix this issue.")
    pearsonr = None
    spearmanr = None

try:
    from sklearn.metrics import classification_report, f1_score, matthews_corrcoef, mean_squared_error, r2_score
    from sklearn.metrics.pairwise import cosine_similarity
except ImportError:
    logger.debug("sklearn could not be imported. Run 'pip install scikit-learn' to fix this issue.")
    classification_report = None
    f1_score = None
    matthews_corrcoef = None
    mean_squared_error = None
    r2_score = None
    cosine_similarity = None

try:
    from seqeval.metrics import classification_report as token_classification_report
except ImportError:
    logger.debug("seqeval could not be imported. Run 'pip install farm-haystack[metrics]' to fix this issue.")
    token_classification_report = None
registered_metrics = {}
registered_reports = {}
@ -53,16 +71,28 @@ def simple_accuracy(preds, labels):
def acc_and_f1(preds, labels):
    """
    Compute accuracy, F1 and their arithmetic mean for the given predictions.

    :param preds: predicted labels.
    :param labels: gold labels.
    :return: dict with the keys "acc", "f1" and "acc_and_f1".
    :raises ImportError: if `f1_score` could not be imported.
    """
    if not f1_score:
        raise ImportError(
            "scipy or sklearn could not be imported. Run 'pip install farm-haystack[metrics]' to fix this issue."
        )
    accuracy = simple_accuracy(preds, labels)["acc"]
    f1 = f1_score(y_true=labels, y_pred=preds)
    return {"acc": accuracy, "f1": f1, "acc_and_f1": (accuracy + f1) / 2}
def f1_macro(preds, labels):
    """
    Compute the macro-averaged F1 score for the given predictions.

    :param preds: predicted labels.
    :param labels: gold labels.
    :return: dict with the single key "f1_macro".
    :raises ImportError: if `f1_score` could not be imported.
    """
    if not f1_score:
        raise ImportError(
            "scipy or sklearn could not be imported. Run 'pip install farm-haystack[metrics]' to fix this issue."
        )
    macro_f1 = f1_score(y_true=labels, y_pred=preds, average="macro")
    return {"f1_macro": macro_f1}
def pearson_and_spearman(preds, labels):
    """
    Compute the Pearson and Spearman correlations between predictions and labels,
    plus the arithmetic mean of the two.

    :param preds: predicted values.
    :param labels: gold values.
    :return: dict with the keys "pearson", "spearman" and "corr".
    :raises ImportError: if `pearsonr` or `spearmanr` could not be imported.
    """
    if not (pearsonr and spearmanr):
        raise ImportError(
            "scipy or sklearn could not be imported. Run 'pip install farm-haystack[metrics]' to fix this issue."
        )
    # Both functions return a tuple-like result; only the first element (the coefficient) is used.
    r_pearson = pearsonr(preds, labels)[0]
    r_spearman = spearmanr(preds, labels)[0]
    mean_corr = (r_pearson + r_spearman) / 2
    return {"pearson": r_pearson, "spearman": r_spearman, "corr": mean_corr}
@ -94,6 +124,14 @@ def compute_metrics(metric: str, preds, labels):
}
assert len(preds) == len(labels)
if metric in FUNCTION_FOR_METRIC.keys():
if ( # pylint: disable=too-many-boolean-expressions
(metric == "mcc" and not matthews_corrcoef)
or (metric == "mse" and not mean_squared_error)
or (metric == "r2" and not r2_score)
):
raise ImportError(
"scipy or sklearn could not be imported. Run 'pip install farm-haystack[metrics]' to fix this issue."
)
return FUNCTION_FOR_METRIC[metric](preds, labels)
elif isinstance(metric, list):
ret = {}
@ -111,8 +149,14 @@ def compute_report_metrics(head: PredictionHead, preds, labels):
if head.ph_output_type in registered_reports:
report_fn = registered_reports[head.ph_output_type] # type: ignore [index]
elif head.ph_output_type == "per_token":
if not token_classification_report:
raise ImportError("seqeval could not be imported. Run 'pip install farm-haystack[eval]' to fix this issue.")
report_fn = token_classification_report
elif head.ph_output_type == "per_sequence":
if not classification_report:
raise ImportError(
"scipy or sklearn could not be imported. Run 'pip install farm-haystack[metrics]' to fix this issue."
)
report_fn = classification_report
elif head.ph_output_type == "per_token_squad":
report_fn = lambda *args, **kwargs: "Not Implemented" # pylint: disable=unnecessary-lambda-assignment
@ -407,6 +451,11 @@ def semantic_answer_similarity(
https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained
:return: top_1_sas, top_k_sas, pred_label_matrix
"""
if not cosine_similarity:
raise ImportError(
"scipy or sklearn could not be imported. Run 'pip install farm-haystack[metrics]' to fix this issue."
)
assert len(predictions) == len(gold_labels)
config = AutoConfig.from_pretrained(sas_model_name_or_path, use_auth_token=use_auth_token)

View File

@ -10,11 +10,11 @@ from torch import nn
from torch import optim
from torch.nn import CrossEntropyLoss, NLLLoss
from transformers import AutoModelForQuestionAnswering
from scipy.special import expit
from haystack.modeling.data_handler.samples import SampleBasket
from haystack.modeling.model.predictions import QACandidate, QAPred
from haystack.modeling.utils import try_get, all_gather_list
from haystack.utils.scipy_utils import expit
logger = logging.getLogger(__name__)

View File

@ -15,7 +15,6 @@ from transformers import (
from haystack.schema import Document
from haystack.nodes.answer_generator.base import BaseGenerator
from haystack.nodes.retriever.dense import DensePassageRetriever
from haystack.modeling.utils import initialize_device_settings
@ -69,7 +68,7 @@ class RAGenerator(BaseGenerator):
self,
model_name_or_path: str = "facebook/rag-token-nq",
model_version: Optional[str] = None,
retriever: Optional[DensePassageRetriever] = None,
retriever=None,
generator_type: str = "token",
top_k: int = 2,
max_length: int = 200,

View File

@ -9,9 +9,9 @@ from tqdm.auto import tqdm
from haystack.modeling.utils import initialize_device_settings
from haystack.nodes.base import BaseComponent
from haystack.nodes.question_generator import QuestionGenerator
from haystack.nodes.retriever.base import BaseRetriever
from haystack.schema import Document
logger = logging.getLogger(__name__)
@ -60,7 +60,7 @@ class PseudoLabelGenerator(BaseComponent):
def __init__(
self,
question_producer: Union[QuestionGenerator, List[Dict[str, str]]],
retriever: BaseRetriever,
retriever,
cross_encoder_model_name_or_path: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
max_questions_per_document: int = 3,
top_k: int = 50,

View File

@ -1,5 +1,6 @@
import itertools
from typing import List, Optional, Sequence, Tuple, Union
import itertools
import logging
from abc import abstractmethod
from copy import deepcopy
@ -7,10 +8,13 @@ from functools import wraps
from time import perf_counter
import numpy as np
from scipy.special import expit
from haystack.schema import Document, Answer, Span, MultiLabel
from haystack.nodes.base import BaseComponent
from haystack.utils.scipy_utils import expit
logger = logging.getLogger(__name__)
class BaseReader(BaseComponent):

View File

@ -1,13 +1,22 @@
from typing import Generator, Iterable, Optional, Tuple, List, Union
import re
import logging
from itertools import groupby
from multiprocessing.pool import Pool
from collections import namedtuple
from rapidfuzz import fuzz
from tqdm.auto import tqdm
# Name the logger after the module (not its file path) so it takes part in the
# package's logger hierarchy.
logger = logging.getLogger(__name__)

# rapidfuzz is optional: when it is missing, `fuzz` is None and callers must
# check it before use.
try:
    from rapidfuzz import fuzz
except ImportError:
    logger.debug("rapidfuzz could not be imported. Run 'pip install farm-haystack[metrics]' to fix this issue.")
    fuzz = None  # type: ignore
_CandidateScore = namedtuple("_CandidateScore", ["context_id", "candidate_id", "score"])
@ -46,6 +55,8 @@ def calculate_context_similarity(
we cut the context on the same side, recalculate the score and take the mean of both.
Thus [AB] <-> [BC] (score ~50) gets recalculated with B <-> B (score ~100) scoring ~75 in total.
"""
if not fuzz:
raise ImportError("rapidfuzz could not be imported. Run 'pip install farm-haystack[eval]' to fix this issue.")
# we need to handle short contexts/contents (e.g single word)
# as they produce high scores by matching if the chars of the word are contained in the other one
# this has to be done after normalizing

View File

@ -9,14 +9,18 @@ import sys
import torch
import transformers
import mlflow
from requests.exceptions import ConnectionError
from haystack import __version__
logger = logging.getLogger(__name__)

# mlflow is optional: when it is missing, `mlflow` is None and callers must
# check it before use.
try:
    import mlflow
except ImportError:
    logger.debug("mlflow could not be imported. Run 'pip install farm-haystack[metrics]' to fix this issue.")
    mlflow = None
def flatten_dict(dict_to_flatten: dict, prefix: str = ""):
flat_dict = {}
@ -160,6 +164,8 @@ class MLflowTrackingHead(BaseTrackingHead):
"""
Experiment tracking head for MLflow.
"""
if not mlflow:
raise ImportError("mlflow could not be imported. Run 'pip install farm-haystack[eval]' to fix this issue.")
super().__init__()
self.tracking_uri = tracking_uri
self.auto_track_environment = auto_track_environment

View File

@ -8,7 +8,6 @@ from collections import defaultdict
import pandas as pd
from haystack.schema import Document, Answer
from haystack.document_stores.sql import DocumentORM # type: ignore[attr-defined]
logger = logging.getLogger(__name__)
@ -183,6 +182,8 @@ def convert_labels_to_squad(labels_file: str):
:param labels_file: The path to the file containing labels.
:return:
"""
from haystack.document_stores.sql import DocumentORM # type: ignore[attr-defined]
with open(labels_file, encoding="utf-8") as label_file:
labels = json.load(label_file)

View File

@ -0,0 +1,17 @@
import logging
import numpy as np
logger = logging.getLogger(__name__)

try:
    from numba import njit  # pylint: disable=import-error
except (ImportError, ModuleNotFoundError):
    logger.debug("Numba not found, replacing njit() with no-op implementation. Enable it with 'pip install numba'.")

    def njit(*args, **kwargs):
        """
        No-op stand-in for `numba.njit`.

        Supports both the bare form (`@njit`) and the parametrized form
        (`@njit(fastmath=True)`); in either case the decorated function is
        returned unchanged.
        """
        if len(args) == 1 and callable(args[0]) and not kwargs:
            # Bare usage: we were handed the function to decorate.
            return args[0]
        # Parametrized usage: return a decorator that leaves the function as it is.
        return lambda f: f


@njit  # (fastmath=True)
def expit(x: float) -> float:
    """
    Logistic sigmoid: `1 / (1 + exp(-x))`.

    :param x: input value.
    :return: the sigmoid of `x`, a value between 0 and 1.
    """
    return 1 / (1 + np.exp(-x))

View File

@ -53,6 +53,7 @@ dependencies = [
"nltk",
"pandas",
"rank_bm25",
"scikit-learn>=1.0.0", # TF-IDF, SklearnQueryClassifier and metrics
# Utils
"dill", # pickle extension for (de-)serialization
@ -79,23 +80,12 @@ dependencies = [
# See haystack/nodes/retriever/_embedding_encoder.py, _SentenceTransformersEmbeddingEncoder
"sentence-transformers>=2.2.0",
# for stats in run_classifier
"scipy>=1.3.2",
"scikit-learn>=1.0.0",
# Metrics and logging
"seqeval",
"mlflow",
# Elasticsearch
"elasticsearch>=7.7,<8",
# OpenAI tokenizer
"tiktoken>=0.3.0; python_version >= '3.8' and (platform_machine == 'AMD64' or platform_machine == 'amd64' or platform_machine == 'x86_64' or (platform_machine == 'arm64' and platform_system == 'Darwin'))",
# context matching
"rapidfuzz>=2.0.15,<2.8.0", # FIXME https://github.com/deepset-ai/haystack/pull/3199
# Schema validation
"jsonschema",
@ -188,6 +178,12 @@ onnx-gpu = [
"onnxruntime-gpu",
"onnxruntime_tools",
]
metrics = [ # for metrics
"scipy>=1.3.2",
"rapidfuzz>=2.0.15,<2.8.0", # FIXME https://github.com/deepset-ai/haystack/pull/3199
"seqeval",
"mlflow",
]
ray = [
"ray[serve]>=1.9.1,<2; platform_system != 'Windows'",
"ray[serve]>=1.9.1,<2,!=1.12.0; platform_system == 'Windows'", # Avoid 1.12.0 due to https://github.com/ray-project/ray/issues/24169 (fails on windows)
@ -228,11 +224,11 @@ formatting = [
]
all = [
"farm-haystack[docstores,audio,crawler,preprocessing,pdf,ocr,ray,dev,onnx,beir]",
"farm-haystack[docstores,audio,crawler,preprocessing,pdf,ocr,ray,dev,onnx,beir,metrics]",
]
all-gpu = [
# beir is incompatible with faiss-gpu: https://github.com/beir-cellar/beir/issues/71
"farm-haystack[docstores-gpu,audio,crawler,preprocessing,pdf,ocr,ray,dev,onnx-gpu]",
"farm-haystack[docstores-gpu,audio,crawler,preprocessing,pdf,ocr,ray,dev,onnx-gpu,metrics]",
]
[project.scripts]

View File

@ -11,7 +11,6 @@ from functools import wraps
import requests_cache
import responses
from sqlalchemy import create_engine, text
import posthog
import numpy as np
@ -670,6 +669,8 @@ def setup_postgres():
# logging.warning("Tried to start PostgreSQL through Docker but this failed. It is likely that there is already an existing instance running.")
# else:
# sleep(5)
from sqlalchemy import create_engine, text
engine = create_engine("postgresql://postgres:postgres@127.0.0.1/postgres", isolation_level="AUTOCOMMIT")
with engine.connect() as connection:
@ -683,6 +684,8 @@ def setup_postgres():
# TODO: Verify this is still necessary as it's called by no one
def teardown_postgres():
from sqlalchemy import create_engine, text
engine = create_engine("postgresql://postgres:postgres@127.0.0.1/postgres", isolation_level="AUTOCOMMIT")
with engine.connect() as connection:
connection.execute(text("DROP SCHEMA public CASCADE"))