Mirror of https://github.com/deepset-ai/haystack.git, synced 2026-01-07 04:27:15 +00:00
refactor!: extract evaluation and statistical dependencies (#4457)
* try-catch sklearn and scipy
* haystack imports
* linting
* mypy
* try to import baseretriever
* remove typing
* unused import
* remove more typing
* pylint
* isolate sql imports for postgres, which we don't use anyway
* remove stats
* replace expit
* als inmemory
* mypy
* feedback
* docker
* expit
* re-add njit
Parent: 5d41e60d89
Commit: ba11d1c2a8
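The whole refactor follows one pattern: imports of optional dependencies (scipy, sklearn, seqeval, mlflow, rapidfuzz, sqlalchemy) are wrapped in try/except, the missing names are bound to None, and an ImportError with an install hint is raised only when the optional feature is actually called. A minimal sketch of that pattern, condensed from the metrics changes further down (illustrative, not a verbatim excerpt of the diff):

import logging

logger = logging.getLogger(__name__)

try:
    from sklearn.metrics import f1_score
except ImportError:
    # Import failure is only logged; nothing breaks until the metric is requested.
    logger.debug("sklearn could not be imported. Run 'pip install farm-haystack[metrics]' to fix this issue.")
    f1_score = None


def f1_macro(preds, labels):
    # Fail at call time with an actionable message instead of failing at import time.
    if not f1_score:
        raise ImportError(
            "scipy or sklearn could not be imported. Run 'pip install farm-haystack[metrics]' to fix this issue."
        )
    return {"f1_macro": f1_score(y_true=labels, y_pred=preds, average="macro")}

This keeps the base `farm-haystack` install lighter, while the new `metrics` extra and the existing `eval` and `sql` extras pull the heavy dependencies back in for users who need them.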
@@ -45,7 +45,7 @@ target "base-cpu" {
     build_image = "python:3.10-slim"
     base_image = "python:3.10-slim"
     haystack_version = "${HAYSTACK_VERSION}"
-    haystack_extras = notequal("",HAYSTACK_EXTRAS) ? "${HAYSTACK_EXTRAS}" : "[docstores,crawler,preprocessing,ocr,onnx,beir]"
+    haystack_extras = notequal("",HAYSTACK_EXTRAS) ? "${HAYSTACK_EXTRAS}" : "[docstores,crawler,preprocessing,ocr,onnx,metrics,beir]"
   }
   platforms = ["linux/amd64", "linux/arm64"]
 }
@@ -59,7 +59,7 @@ target "base-gpu" {
     build_image = "pytorch/pytorch:1.13.1-cuda11.6-cudnn8-runtime"
     base_image = "pytorch/pytorch:1.13.1-cuda11.6-cudnn8-runtime"
     haystack_version = "${HAYSTACK_VERSION}"
-    haystack_extras = notequal("",HAYSTACK_EXTRAS) ? "${HAYSTACK_EXTRAS}" : "[docstores-gpu,crawler,preprocessing,ocr,onnx-gpu]"
+    haystack_extras = notequal("",HAYSTACK_EXTRAS) ? "${HAYSTACK_EXTRAS}" : "[docstores-gpu,crawler,preprocessing,ocr,onnx-gpu,metrics]"
   }
   platforms = ["linux/amd64", "linux/arm64"]
 }
@@ -17,6 +17,7 @@ from haystack.errors import DuplicateDocumentError, DocumentStoreError, Haystack
 from haystack.nodes.preprocessor import PreProcessor
 from haystack.document_stores.utils import eval_data_from_json, eval_data_from_jsonl, squad_json_to_jsonl
 from haystack.utils.labels import aggregate_labels
+from haystack.utils.scipy_utils import expit


 logger = logging.getLogger(__name__)
@@ -31,11 +32,6 @@ except (ImportError, ModuleNotFoundError):
         return f


-@njit  # (fastmath=True)
-def expit(x: float) -> float:
-    return 1 / (1 + np.exp(-x))


 class BaseKnowledgeGraph(BaseComponent):
     """
     Base class for implementing Knowledge Graphs.
@@ -1,13 +1,22 @@
 import logging
 from typing import Union, List, Dict, Optional, Tuple
 from abc import ABC, abstractmethod
 from collections import defaultdict

-from sqlalchemy.sql import select
-from sqlalchemy import and_, or_

 from haystack.document_stores.utils import convert_date_to_rfc3339
 from haystack.errors import FilterError

+logger = logging.getLogger(__file__)

+try:
+    from sqlalchemy.sql import select
+    from sqlalchemy import and_, or_
+except ImportError as exc:
+    logger.debug("sqlalchemy could not be imported. Run 'pip install farm-haystack[sql]' to fix this issue.")
+    select = None
+    and_ = None
+    or_ = None


 def nested_defaultdict() -> defaultdict:
     """
@@ -310,6 +319,10 @@ class NotOperation(LogicalFilterClause):
         return {"bool": {"must_not": conditions}}

     def convert_to_sql(self, meta_document_orm):
+        if not select:
+            raise ImportError(
+                "sqlalchemy could not be imported. Run 'pip install farm-haystack[sql]' to fix this issue."
+            )
         conditions = [
             meta_document_orm.document_id.in_(condition.convert_to_sql(meta_document_orm))
             for condition in self.conditions
@@ -357,6 +370,10 @@ class AndOperation(LogicalFilterClause):
         return {"bool": {"must": conditions}}

     def convert_to_sql(self, meta_document_orm):
+        if not select:
+            raise ImportError(
+                "sqlalchemy could not be imported. Run 'pip install farm-haystack[sql]' to fix this issue."
+            )
         conditions = [
             meta_document_orm.document_id.in_(condition.convert_to_sql(meta_document_orm))
             for condition in self.conditions
@@ -389,6 +406,10 @@ class OrOperation(LogicalFilterClause):
         return {"bool": {"should": conditions}}

     def convert_to_sql(self, meta_document_orm):
+        if not select:
+            raise ImportError(
+                "sqlalchemy could not be imported. Run 'pip install farm-haystack[sql]' to fix this issue."
+            )
         conditions = [
             meta_document_orm.document_id.in_(condition.convert_to_sql(meta_document_orm))
             for condition in self.conditions
@@ -434,6 +455,10 @@ class EqOperation(ComparisonOperation):
         return {"term": {self.field_name: self.comparison_value}}

     def convert_to_sql(self, meta_document_orm):
+        if not select:
+            raise ImportError(
+                "sqlalchemy could not be imported. Run 'pip install farm-haystack[sql]' to fix this issue."
+            )
         return select([meta_document_orm.document_id]).where(
             meta_document_orm.name == self.field_name, meta_document_orm.value == self.comparison_value
         )
@@ -466,6 +491,10 @@ class InOperation(ComparisonOperation):
         return {"terms": {self.field_name: self.comparison_value}}

     def convert_to_sql(self, meta_document_orm):
+        if not select:
+            raise ImportError(
+                "sqlalchemy could not be imported. Run 'pip install farm-haystack[sql]' to fix this issue."
+            )
         return select([meta_document_orm.document_id]).where(
             meta_document_orm.name == self.field_name, meta_document_orm.value.in_(self.comparison_value)
         )
@@ -508,6 +537,10 @@ class NeOperation(ComparisonOperation):
         return {"bool": {"must_not": {"term": {self.field_name: self.comparison_value}}}}

     def convert_to_sql(self, meta_document_orm):
+        if not select:
+            raise ImportError(
+                "sqlalchemy could not be imported. Run 'pip install farm-haystack[sql]' to fix this issue."
+            )
         return select([meta_document_orm.document_id]).where(
             meta_document_orm.name == self.field_name, meta_document_orm.value != self.comparison_value
         )
@@ -540,6 +573,10 @@ class NinOperation(ComparisonOperation):
         return {"bool": {"must_not": {"terms": {self.field_name: self.comparison_value}}}}

     def convert_to_sql(self, meta_document_orm):
+        if not select:
+            raise ImportError(
+                "sqlalchemy could not be imported. Run 'pip install farm-haystack[sql]' to fix this issue."
+            )
         return select([meta_document_orm.document_id]).where(
             meta_document_orm.name == self.field_name, meta_document_orm.value.notin_(self.comparison_value)
         )
@@ -582,6 +619,10 @@ class GtOperation(ComparisonOperation):
         return {"range": {self.field_name: {"gt": self.comparison_value}}}

     def convert_to_sql(self, meta_document_orm):
+        if not select:
+            raise ImportError(
+                "sqlalchemy could not be imported. Run 'pip install farm-haystack[sql]' to fix this issue."
+            )
         return select([meta_document_orm.document_id]).where(
             meta_document_orm.name == self.field_name, meta_document_orm.value > self.comparison_value
         )
@@ -617,6 +658,10 @@ class GteOperation(ComparisonOperation):
         return {"range": {self.field_name: {"gte": self.comparison_value}}}

     def convert_to_sql(self, meta_document_orm):
+        if not select:
+            raise ImportError(
+                "sqlalchemy could not be imported. Run 'pip install farm-haystack[sql]' to fix this issue."
+            )
         return select([meta_document_orm.document_id]).where(
             meta_document_orm.name == self.field_name, meta_document_orm.value >= self.comparison_value
         )
@@ -652,6 +697,10 @@ class LtOperation(ComparisonOperation):
         return {"range": {self.field_name: {"lt": self.comparison_value}}}

     def convert_to_sql(self, meta_document_orm):
+        if not select:
+            raise ImportError(
+                "sqlalchemy could not be imported. Run 'pip install farm-haystack[sql]' to fix this issue."
+            )
         return select([meta_document_orm.document_id]).where(
             meta_document_orm.name == self.field_name, meta_document_orm.value < self.comparison_value
         )
@@ -687,6 +736,10 @@ class LteOperation(ComparisonOperation):
         return {"range": {self.field_name: {"lte": self.comparison_value}}}

     def convert_to_sql(self, meta_document_orm):
+        if not select:
+            raise ImportError(
+                "sqlalchemy could not be imported. Run 'pip install farm-haystack[sql]' to fix this issue."
+            )
         return select([meta_document_orm.document_id]).where(
             meta_document_orm.name == self.field_name, meta_document_orm.value <= self.comparison_value
         )
@@ -16,7 +16,6 @@ import torch
 from tqdm.auto import tqdm
 import rank_bm25
 import pandas as pd
-from scipy.special import expit

 from haystack.schema import Document, FilterType, Label
 from haystack.errors import DuplicateDocumentError, DocumentStoreError
@@ -24,7 +23,9 @@ from haystack.document_stores import KeywordDocumentStore
 from haystack.document_stores.base import get_batches_from_generator
 from haystack.modeling.utils import initialize_device_settings
 from haystack.document_stores.filter_utils import LogicalFilterClause
-from haystack.nodes.retriever import DenseRetriever
+from haystack.nodes.retriever.dense import DenseRetriever
+from haystack.utils.scipy_utils import expit


 logger = logging.getLogger(__name__)
@@ -10,13 +10,13 @@ try:
     from pymilvus import FieldSchema, CollectionSchema, Collection, connections, utility
     from pymilvus.client.abstract import QueryResult
     from pymilvus.client.types import DataType
+    from haystack.document_stores.sql import SQLDocumentStore  # type: ignore
 except (ImportError, ModuleNotFoundError) as ie:
     from haystack.utils.import_utils import _optional_component_not_installed

-    _optional_component_not_installed(__name__, "milvus2", ie)
+    _optional_component_not_installed(__name__, "milvus", ie)

 from haystack.schema import Document, FilterType
-from haystack.document_stores import SQLDocumentStore
 from haystack.document_stores.base import get_batches_from_generator
 from haystack.nodes.retriever import DenseRetriever
@@ -10,7 +10,6 @@ import time
 from string import Template

 import numpy as np
-from scipy.special import expit
 from tqdm.auto import tqdm
 from pydantic.error_wrappers import ValidationError

@@ -20,6 +19,7 @@ from haystack.document_stores.base import get_batches_from_generator
 from haystack.document_stores.filter_utils import LogicalFilterClause
 from haystack.errors import DocumentStoreError, HaystackError
 from haystack.nodes.retriever import DenseRetriever
+from haystack.utils.scipy_utils import expit


 logger = logging.getLogger(__name__)
@@ -3,18 +3,36 @@ from functools import reduce
 from typing import Callable, Dict, List, Optional, Tuple, Union

 import numpy as np
-from scipy.stats import pearsonr, spearmanr
 from sentence_transformers import CrossEncoder, SentenceTransformer
-from seqeval.metrics import classification_report as token_classification_report
-from sklearn.metrics import classification_report, f1_score, matthews_corrcoef, mean_squared_error, r2_score
-from sklearn.metrics.pairwise import cosine_similarity
 from transformers import AutoConfig

 from haystack.modeling.model.prediction_head import PredictionHead
 from haystack.modeling.utils import flatten_list


 logger = logging.getLogger(__name__)

+try:
+    from scipy.stats import pearsonr, spearmanr
+    from sklearn.metrics import classification_report, f1_score, matthews_corrcoef, mean_squared_error, r2_score
+    from sklearn.metrics.pairwise import cosine_similarity
+except ImportError as exc:
+    logger.debug("scipy or sklearn could not be imported. Run 'pip install farm-haystack[metrics]' to fix this issue.")
+    pearsonr = None
+    spearmanr = None
+    classification_report = None
+    f1_score = None
+    matthews_corrcoef = None
+    mean_squared_error = None
+    r2_score = None
+    cosine_similarity = None

+try:
+    from seqeval.metrics import classification_report as token_classification_report
+except ImportError as exc:
+    logger.debug("seqeval could not be imported. Run 'pip install farm-haystack[eval]' to fix this issue.")
+    token_classification_report = None


 registered_metrics = {}
 registered_reports = {}
@@ -53,16 +71,28 @@ def simple_accuracy(preds, labels):


 def acc_and_f1(preds, labels):
+    if not f1_score:
+        raise ImportError(
+            "scipy or sklearn could not be imported. Run 'pip install farm-haystack[metrics]' to fix this issue."
+        )
     acc = simple_accuracy(preds, labels)
     f1 = f1_score(y_true=labels, y_pred=preds)
     return {"acc": acc["acc"], "f1": f1, "acc_and_f1": (acc["acc"] + f1) / 2}


 def f1_macro(preds, labels):
+    if not f1_score:
+        raise ImportError(
+            "scipy or sklearn could not be imported. Run 'pip install farm-haystack[metrics]' to fix this issue."
+        )
     return {"f1_macro": f1_score(y_true=labels, y_pred=preds, average="macro")}


 def pearson_and_spearman(preds, labels):
+    if not pearsonr or not spearmanr:
+        raise ImportError(
+            "scipy or sklearn could not be imported. Run 'pip install farm-haystack[metrics]' to fix this issue."
+        )
     pearson_corr = pearsonr(preds, labels)[0]
     spearman_corr = spearmanr(preds, labels)[0]
     return {"pearson": pearson_corr, "spearman": spearman_corr, "corr": (pearson_corr + spearman_corr) / 2}
@@ -94,6 +124,14 @@ def compute_metrics(metric: str, preds, labels):
     }
     assert len(preds) == len(labels)
     if metric in FUNCTION_FOR_METRIC.keys():
+        if (  # pylint: disable=too-many-boolean-expressions
+            (metric == "mcc" and not matthews_corrcoef)
+            or (metric == "mse" and not mean_squared_error)
+            or (metric == "r2" and not r2_score)
+        ):
+            raise ImportError(
+                "scipy or sklearn could not be imported. Run 'pip install farm-haystack[metrics]' to fix this issue."
+            )
         return FUNCTION_FOR_METRIC[metric](preds, labels)
     elif isinstance(metric, list):
         ret = {}
@@ -111,8 +149,14 @@ def compute_report_metrics(head: PredictionHead, preds, labels):
     if head.ph_output_type in registered_reports:
         report_fn = registered_reports[head.ph_output_type]  # type: ignore [index]
     elif head.ph_output_type == "per_token":
+        if not token_classification_report:
+            raise ImportError("seqeval could not be imported. Run 'pip install farm-haystack[eval]' to fix this issue.")
         report_fn = token_classification_report
     elif head.ph_output_type == "per_sequence":
+        if not classification_report:
+            raise ImportError(
+                "scipy or sklearn could not be imported. Run 'pip install farm-haystack[metrics]' to fix this issue."
+            )
         report_fn = classification_report
     elif head.ph_output_type == "per_token_squad":
         report_fn = lambda *args, **kwargs: "Not Implemented"  # pylint: disable=unnecessary-lambda-assignment
@@ -407,6 +451,11 @@ def semantic_answer_similarity(
     https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained
     :return: top_1_sas, top_k_sas, pred_label_matrix
     """
+    if not cosine_similarity:
+        raise ImportError(
+            "scipy or sklearn could not be imported. Run 'pip install farm-haystack[metrics]' to fix this issue."
+        )

     assert len(predictions) == len(gold_labels)

     config = AutoConfig.from_pretrained(sas_model_name_or_path, use_auth_token=use_auth_token)
@@ -10,11 +10,11 @@ from torch import nn
 from torch import optim
 from torch.nn import CrossEntropyLoss, NLLLoss
 from transformers import AutoModelForQuestionAnswering
-from scipy.special import expit

 from haystack.modeling.data_handler.samples import SampleBasket
 from haystack.modeling.model.predictions import QACandidate, QAPred
 from haystack.modeling.utils import try_get, all_gather_list
+from haystack.utils.scipy_utils import expit


 logger = logging.getLogger(__name__)
@@ -15,7 +15,6 @@ from transformers import (

 from haystack.schema import Document
 from haystack.nodes.answer_generator.base import BaseGenerator
-from haystack.nodes.retriever.dense import DensePassageRetriever
 from haystack.modeling.utils import initialize_device_settings

@@ -69,7 +68,7 @@ class RAGenerator(BaseGenerator):
         self,
         model_name_or_path: str = "facebook/rag-token-nq",
         model_version: Optional[str] = None,
-        retriever: Optional[DensePassageRetriever] = None,
+        retriever=None,
         generator_type: str = "token",
         top_k: int = 2,
         max_length: int = 200,
@@ -9,9 +9,9 @@ from tqdm.auto import tqdm
 from haystack.modeling.utils import initialize_device_settings
 from haystack.nodes.base import BaseComponent
 from haystack.nodes.question_generator import QuestionGenerator
-from haystack.nodes.retriever.base import BaseRetriever
 from haystack.schema import Document


 logger = logging.getLogger(__name__)
@@ -60,7 +60,7 @@ class PseudoLabelGenerator(BaseComponent):
     def __init__(
         self,
         question_producer: Union[QuestionGenerator, List[Dict[str, str]]],
-        retriever: BaseRetriever,
+        retriever,
         cross_encoder_model_name_or_path: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
         max_questions_per_document: int = 3,
         top_k: int = 50,
@@ -1,5 +1,6 @@
+import itertools
 from typing import List, Optional, Sequence, Tuple, Union
-import itertools
 import logging

 from abc import abstractmethod
 from copy import deepcopy
@@ -7,10 +8,13 @@ from functools import wraps
 from time import perf_counter

 import numpy as np
-from scipy.special import expit

 from haystack.schema import Document, Answer, Span, MultiLabel
 from haystack.nodes.base import BaseComponent
+from haystack.utils.scipy_utils import expit


 logger = logging.getLogger(__name__)


 class BaseReader(BaseComponent):
@@ -1,13 +1,22 @@
 from typing import Generator, Iterable, Optional, Tuple, List, Union

 import re
 import logging
 from itertools import groupby
 from multiprocessing.pool import Pool
 from collections import namedtuple

-from rapidfuzz import fuzz
 from tqdm.auto import tqdm

+logger = logging.getLogger(__file__)


+try:
+    from rapidfuzz import fuzz
+except ImportError as exc:
+    logger.debug("rapidfuzz could not be imported. Run 'pip install farm-haystack[eval]' to fix this issue.")
+    fuzz = None  # type: ignore


 _CandidateScore = namedtuple("_CandidateScore", ["context_id", "candidate_id", "score"])
@@ -46,6 +55,8 @@ def calculate_context_similarity(
     we cut the context on the same side, recalculate the score and take the mean of both.
     Thus [AB] <-> [BC] (score ~50) gets recalculated with B <-> B (score ~100) scoring ~75 in total.
     """
+    if not fuzz:
+        raise ImportError("rapidfuzz could not be imported. Run 'pip install farm-haystack[eval]' to fix this issue.")
     # we need to handle short contexts/contents (e.g single word)
     # as they produce high scores by matching if the chars of the word are contained in the other one
     # this has to be done after normalizing
@@ -9,14 +9,18 @@ import sys

 import torch
 import transformers
-import mlflow
 from requests.exceptions import ConnectionError

 from haystack import __version__


 logger = logging.getLogger(__name__)

+try:
+    import mlflow
+except ImportError as exc:
+    logger.debug("mlflow could not be imported. Run 'pip install farm-haystack[eval]' to fix this issue.")
+    mlflow = None


 def flatten_dict(dict_to_flatten: dict, prefix: str = ""):
     flat_dict = {}
@@ -160,6 +164,8 @@ class MLflowTrackingHead(BaseTrackingHead):
         """
         Experiment tracking head for MLflow.
         """
+        if not mlflow:
+            raise ImportError("mlflow could not be imported. Run 'pip install farm-haystack[eval]' to fix this issue.")
         super().__init__()
         self.tracking_uri = tracking_uri
         self.auto_track_environment = auto_track_environment
@@ -8,7 +8,6 @@ from collections import defaultdict
 import pandas as pd

 from haystack.schema import Document, Answer
-from haystack.document_stores.sql import DocumentORM  # type: ignore[attr-defined]


 logger = logging.getLogger(__name__)
@@ -183,6 +182,8 @@ def convert_labels_to_squad(labels_file: str):
     :param labels_file: The path to the file containing labels.
     :return:
     """
+    from haystack.document_stores.sql import DocumentORM  # type: ignore[attr-defined]

     with open(labels_file, encoding="utf-8") as label_file:
         labels = json.load(label_file)
haystack/utils/scipy_utils.py (new file, 17 lines)
@@ -0,0 +1,17 @@
+import logging
+import numpy as np
+
+logger = logging.getLogger(__name__)
+
+try:
+    from numba import njit  # pylint: disable=import-error
+except (ImportError, ModuleNotFoundError):
+    logger.debug("Numba not found, replacing njit() with no-op implementation. Enable it with 'pip install numba'.")
+
+    def njit(f):
+        return f
+
+
+@njit  # (fastmath=True)
+def expit(x: float) -> float:
+    return 1 / (1 + np.exp(-x))
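A quick usage sketch for the new helper module (illustrative, not part of the commit): expit is the logistic sigmoid, so it behaves the same whether numba is installed (and the function is compiled via @njit) or the no-op fallback is used.

from haystack.utils.scipy_utils import expit

# Logistic sigmoid: 0.0 maps to exactly 0.5 and the function is monotonically increasing.
assert expit(0.0) == 0.5
assert expit(-2.0) < expit(0.0) < expit(2.0)
assert 0.73 < expit(1.0) < 0.74  # 1 / (1 + e**-1) is roughly 0.7311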
@@ -53,6 +53,7 @@ dependencies = [
     "nltk",
     "pandas",
     "rank_bm25",
+    "scikit-learn>=1.0.0",  # TF-IDF, SklearnQueryClassifier and metrics

     # Utils
     "dill",  # pickle extension for (de-)serialization
@@ -79,23 +80,12 @@ dependencies = [
     # See haystack/nodes/retriever/_embedding_encoder.py, _SentenceTransformersEmbeddingEncoder
     "sentence-transformers>=2.2.0",

-    # for stats in run_classifier
-    "scipy>=1.3.2",
-    "scikit-learn>=1.0.0",

-    # Metrics and logging
-    "seqeval",
-    "mlflow",

     # Elasticsearch
     "elasticsearch>=7.7,<8",

     # OpenAI tokenizer
     "tiktoken>=0.3.0; python_version >= '3.8' and (platform_machine == 'AMD64' or platform_machine == 'amd64' or platform_machine == 'x86_64' or (platform_machine == 'arm64' and platform_system == 'Darwin'))",

-    # context matching
-    "rapidfuzz>=2.0.15,<2.8.0",  # FIXME https://github.com/deepset-ai/haystack/pull/3199

     # Schema validation
     "jsonschema",
@@ -188,6 +178,12 @@ onnx-gpu = [
     "onnxruntime-gpu",
     "onnxruntime_tools",
 ]
+metrics = [  # for metrics
+    "scipy>=1.3.2",
+    "rapidfuzz>=2.0.15,<2.8.0",  # FIXME https://github.com/deepset-ai/haystack/pull/3199
+    "seqeval",
+    "mlflow",
+]
 ray = [
     "ray[serve]>=1.9.1,<2; platform_system != 'Windows'",
     "ray[serve]>=1.9.1,<2,!=1.12.0; platform_system == 'Windows'",  # Avoid 1.12.0 due to https://github.com/ray-project/ray/issues/24169 (fails on windows)
@@ -228,11 +224,11 @@ formatting = [
 ]

 all = [
-    "farm-haystack[docstores,audio,crawler,preprocessing,pdf,ocr,ray,dev,onnx,beir]",
+    "farm-haystack[docstores,audio,crawler,preprocessing,pdf,ocr,ray,dev,onnx,beir,metrics]",
 ]
 all-gpu = [
     # beir is incompatible with faiss-gpu: https://github.com/beir-cellar/beir/issues/71
-    "farm-haystack[docstores-gpu,audio,crawler,preprocessing,pdf,ocr,ray,dev,onnx-gpu]",
+    "farm-haystack[docstores-gpu,audio,crawler,preprocessing,pdf,ocr,ray,dev,onnx-gpu,metrics]",
 ]

 [project.scripts]
@@ -11,7 +11,6 @@ from functools import wraps

 import requests_cache
 import responses
-from sqlalchemy import create_engine, text
 import posthog

 import numpy as np
@@ -670,6 +669,8 @@ def setup_postgres():
     # logging.warning("Tried to start PostgreSQL through Docker but this failed. It is likely that there is already an existing instance running.")
     # else:
     #     sleep(5)
+    from sqlalchemy import create_engine, text

     engine = create_engine("postgresql://postgres:postgres@127.0.0.1/postgres", isolation_level="AUTOCOMMIT")

     with engine.connect() as connection:
@@ -683,6 +684,8 @@ def setup_postgres():

 # TODO: Verify this is still necessary as it's called by no one
 def teardown_postgres():
+    from sqlalchemy import create_engine, text

     engine = create_engine("postgresql://postgres:postgres@127.0.0.1/postgres", isolation_level="AUTOCOMMIT")
     with engine.connect() as connection:
         connection.execute(text("DROP SCHEMA public CASCADE"))