Mirror of https://github.com/deepset-ai/haystack.git
Log evaluation results to MLflow (#2337)
* track eval results in mlflow
* Update Documentation & Code Style
* add pipeline.yaml and environment info
* improve logging to mlflow
* Update Documentation & Code Style
* introduce ExperimentTracker
* Update Documentation & Code Style
* move modeling.utils.logger to utils.experiment_tracking
* renaming: tracker and TrackingHead
* Update Documentation & Code Style
* refactor env tracking
* fix pylint findings
* Update Documentation & Code Style
* rename MLFlowTrackingHead to MLflowTrackingHead
* implement dataset hash
* Update Documentation & Code Style
* set docstrings
* Update Documentation & Code Style
* introduce PipelineBundle and Corpus
* Update Documentation & Code Style
* support reusing index
* Update Documentation & Code Style
* rename Corpus to FileCorpus
* fix Corpus -> FileCorpus
* Update Documentation & Code Style
* resolve cyclic dependencies
* fix linter issues
* Update Documentation & Code Style
* remove helper classes
* Update Documentation & Code Style
* fix imports
* fix another unused import
* update docstrings
* Update Documentation & Code Style
* simplify usage of experiment tracking tools
* fix Literal import
* revert schema changes
* Update Documentation & Code Style
* always end run
* Update Documentation & Code Style
* fix mypy issue
* rename to execute_eval_run
* Update Documentation & Code Style
* fix merge of get_or_create_env_meta_data
* improve docstrings
* Update Documentation & Code Style

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
parent c401e86099
commit 60ff46e4e1
@@ -464,6 +464,96 @@ If True the index will be kept after beir evaluation. Otherwise it will be deleted

Returns a tuple containing the ndcg, map, recall and precision scores.
Each metric is represented by a dictionary containing the scores for each top_k value.

<a id="base.Pipeline.execute_eval_run"></a>

#### execute\_eval\_run

```python
@classmethod
def execute_eval_run(cls, index_pipeline: Pipeline, query_pipeline: Pipeline, evaluation_set_labels: List[MultiLabel], corpus_file_paths: List[str], experiment_name: str, experiment_run_name: str, experiment_tracking_tool: Literal["mlflow", None] = None, experiment_tracking_uri: Optional[str] = None, corpus_file_metas: List[Dict[str, Any]] = None, corpus_meta: Dict[str, Any] = {}, evaluation_set_meta: Dict[str, Any] = {}, pipeline_meta: Dict[str, Any] = {}, index_params: dict = {}, query_params: dict = {}, sas_model_name_or_path: str = None, sas_batch_size: int = 32, sas_use_gpu: bool = True, add_isolated_node_eval: bool = False, reuse_index: bool = False) -> EvaluationResult
```

Starts an experiment run that first indexes the specified files (forming a corpus) using the index pipeline
and subsequently evaluates the query pipeline on the provided labels (forming an evaluation set) using pipeline.eval().
Parameters and results (metrics and predictions) of the run are tracked by an experiment tracking tool for further analysis.
You can specify the experiment tracking tool by setting the params `experiment_tracking_tool` and `experiment_tracking_uri`
or by passing a (custom) tracking head to Tracker.set_tracking_head().
Note that `experiment_tracking_tool` currently only supports `mlflow`.

For easier comparison you can pass additional metadata regarding the corpus (corpus_meta), evaluation set (evaluation_set_meta) and pipelines (pipeline_meta).
E.g. you can give them names or ids to identify them across experiment runs.

This method executes an experiment run. Each experiment run is part of at least one experiment.
An experiment typically consists of multiple runs to be compared (e.g. using different retrievers in the query pipeline).
Experiment tracking tools usually share the same concepts of experiments and runs and provide additional functionality to easily compare runs across experiments.

E.g. you can call execute_eval_run() multiple times with different retrievers in your query pipeline and compare the runs in mlflow:

```python
|   for retriever_type, query_pipeline in zip(["sparse", "dpr", "embedding"], [sparse_pipe, dpr_pipe, embedding_pipe]):
|       eval_result = Pipeline.execute_eval_run(
|           index_pipeline=index_pipeline,
|           query_pipeline=query_pipeline,
|           evaluation_set_labels=labels,
|           corpus_file_paths=file_paths,
|           corpus_file_metas=file_metas,
|           experiment_tracking_tool="mlflow",
|           experiment_tracking_uri="http://localhost:5000",
|           experiment_name="my-retriever-experiment",
|           experiment_run_name=f"run_{retriever_type}",
|           pipeline_meta={"name": f"my-pipeline-{retriever_type}"},
|           evaluation_set_meta={"name": "my-evalset"},
|           corpus_meta={"name": "my-corpus"},
|           reuse_index=False
|       )
```

**Arguments**:

- `index_pipeline`: The indexing pipeline to use.
- `query_pipeline`: The query pipeline to evaluate.
- `evaluation_set_labels`: The labels to evaluate on, forming an evaluation set.
- `corpus_file_paths`: The files to be indexed and searched during evaluation, forming a corpus.
- `experiment_name`: The name of the experiment.
- `experiment_run_name`: The name of the experiment run.
- `experiment_tracking_tool`: The experiment tracking tool to be used. Currently we only support "mlflow".
If left unset, the current TrackingHead specified by Tracker.set_tracking_head() will be used.
- `experiment_tracking_uri`: The URI of the experiment tracking server to be used. Must be specified if experiment_tracking_tool is set.
You can use deepset's public mlflow server via https://public-mlflow.deepset.ai/.
Note that artifact logging (e.g. Pipeline YAML or evaluation result CSVs) is currently not allowed on deepset's public mlflow server as this might expose sensitive data.
- `corpus_file_metas`: The optional metadata to be stored for each corpus file (e.g. title).
- `corpus_meta`: Metadata about the corpus to track (e.g. name, date, author, version).
- `evaluation_set_meta`: Metadata about the evalset to track (e.g. name, date, author, version).
- `pipeline_meta`: Metadata about the pipelines to track (e.g. name, author, version).
- `index_params`: The params to use during indexing (see pipeline.run's params).
- `query_params`: The params to use during querying (see pipeline.run's params).
- `sas_model_name_or_path`: Name or path of a "Semantic Answer Similarity (SAS)" model. When set, the model will be used to calculate the similarity between predictions and labels and generate the SAS metric.
The SAS metric correlates better with human judgement of correct answers as it does not rely on string overlaps.
Example: Prediction = "30%", Label = "thirty percent": EM and F1 would be overly pessimistic with both being 0, while SAS paints a more realistic picture.
More info in the paper: https://arxiv.org/abs/2108.06130
Models:
- You can use Bi-Encoders (sentence transformers) or Cross-Encoders trained on Semantic Textual Similarity (STS) data.
Not all Cross-Encoders can be used because of different return types.
If you use custom Cross-Encoders, please make sure they work with the sentence_transformers.CrossEncoder class.
- Good default for multiple languages: "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
- Large, powerful, but slow model for English only: "cross-encoder/stsb-roberta-large"
- Large model for German only: "deepset/gbert-large-sts"
- `sas_batch_size`: Number of prediction-label pairs to encode at once by CrossEncoder or SentenceTransformer while calculating SAS.
- `sas_use_gpu`: Whether to use a GPU or the CPU for calculating semantic answer similarity.
Falls back to CPU if no GPU is available.
- `add_isolated_node_eval`: If set to True, in addition to the integrated evaluation of the pipeline, each node is evaluated in isolated evaluation mode.
This mode helps to understand the bottlenecks of a pipeline in terms of output quality of each individual node.
If a node performs much better in the isolated evaluation than in the integrated evaluation, the previous node needs to be optimized to improve the pipeline's performance.
If a node's performance is similar in both modes, this node itself needs to be optimized to improve the pipeline's performance.
The isolated evaluation calculates the upper bound of each node's evaluation metrics under the assumption that it received perfect inputs from the previous node.
To this end, labels are used as input to the node instead of the output of the previous node in the pipeline.
The generated dataframes in the EvaluationResult then contain additional rows, which can be distinguished from the integrated evaluation results based on the
values "integrated" or "isolated" in the column "eval_mode", and the evaluation report then additionally lists the upper bound of each node's evaluation metrics.
- `reuse_index`: Whether to reuse an existing non-empty index and to keep the index after evaluation.
If True, the index will be kept after evaluation and no indexing will take place if the index already contains documents. Otherwise it will be deleted immediately afterwards.
Defaults to False.

<a id="base.Pipeline.eval"></a>

#### eval
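The documentation above mentions that, instead of setting `experiment_tracking_tool`, a (custom) tracking head can be registered via Tracker.set_tracking_head(). As a rough, hedged sketch of what that could look like (not part of this commit; the `PrintTrackingHead` class and its print-based backend are made up for illustration, while the method names follow the BaseTrackingHead interface added in haystack/utils/experiment_tracking.py further down):

```python
from pathlib import Path
from typing import Any, Dict, Union

from haystack.utils.experiment_tracking import BaseTrackingHead, Tracker


class PrintTrackingHead(BaseTrackingHead):
    """Hypothetical tracking head that simply prints everything it is asked to track."""

    def init_experiment(self, experiment_name: str, run_name: str = None, tags: Dict[str, Any] = None, nested: bool = False):
        print(f"experiment={experiment_name} run={run_name} tags={tags}")

    def track_metrics(self, metrics: Dict[str, Any], step: int):
        print(f"step={step} metrics={metrics}")

    def track_params(self, params: Dict[str, Any]):
        print(f"params={params}")

    def track_artifacts(self, dir_path: Union[str, Path], artifact_path: str = None):
        print(f"artifacts in {dir_path}")

    def end_run(self):
        print("run finished")


# Register the custom head; execute_eval_run() then uses it because
# experiment_tracking_tool is left unset.
Tracker.set_tracking_head(PrintTrackingHead())
```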
@@ -22,7 +22,7 @@ logging.getLogger("haystack").setLevel(logging.INFO)

import pandas as pd

-from haystack.schema import Document, Answer, Label, MultiLabel, Span
+from haystack.schema import Document, Answer, Label, MultiLabel, Span, EvaluationResult
from haystack.nodes.base import BaseComponent
from haystack.pipelines.base import Pipeline

@@ -104,7 +104,6 @@ except ImportError:
pass

from haystack.modeling.evaluation import eval
-from haystack.modeling.logger import MLFlowLogger, StdoutLogger, TensorBoardLogger
from haystack.nodes.other import JoinDocuments, Docs2Answers, JoinAnswers, RouteDocuments
from haystack.nodes.query_classifier import SklearnQueryClassifier, TransformersQueryClassifier
from haystack.nodes.file_classifier import FileTypeClassifier

@@ -178,9 +177,6 @@ if graph_retriever:
# Adding them to sys.modules would enable `import haystack.pipelines.JoinDocuments`,
# which I believe it's a very rare import style.
setattr(file_converter, "FileTypeClassifier", FileTypeClassifier)
-setattr(modeling_utils, "MLFlowLogger", MLFlowLogger)
-setattr(modeling_utils, "StdoutLogger", StdoutLogger)
-setattr(modeling_utils, "TensorBoardLogger", TensorBoardLogger)
setattr(pipelines, "JoinDocuments", JoinDocuments)
setattr(pipelines, "Docs2Answers", Docs2Answers)
setattr(pipelines, "SklearnQueryClassifier", SklearnQueryClassifier)
haystack/environment.py (new file, 62 lines)

@@ -0,0 +1,62 @@
import os
import platform
import sys
from typing import Any, Dict
import torch
import transformers

from haystack import __version__


HAYSTACK_EXECUTION_CONTEXT = "HAYSTACK_EXECUTION_CONTEXT"
HAYSTACK_DOCKER_CONTAINER = "HAYSTACK_DOCKER_CONTAINER"


env_meta_data: Dict[str, Any] = {}


def get_or_create_env_meta_data() -> Dict[str, Any]:
    """
    Collects meta data about the setup that is used with Haystack, such as: operating system, python version, Haystack version, transformers version, pytorch version, number of GPUs, execution environment, and the value stored in the env variable HAYSTACK_EXECUTION_CONTEXT.
    """
    global env_meta_data  # pylint: disable=global-statement
    if not env_meta_data:
        env_meta_data = {
            "os_version": platform.release(),
            "os_family": platform.system(),
            "os_machine": platform.machine(),
            "python_version": platform.python_version(),
            "haystack_version": __version__,
            "transformers_version": transformers.__version__,
            "torch_version": torch.__version__,
            "torch_cuda_version": torch.version.cuda if torch.cuda.is_available() else 0,
            "n_gpu": torch.cuda.device_count() if torch.cuda.is_available() else 0,
            "n_cpu": os.cpu_count(),
            "context": os.environ.get(HAYSTACK_EXECUTION_CONTEXT),
            "execution_env": _get_execution_environment(),
        }
    return env_meta_data


def _get_execution_environment():
    """
    Identifies the execution environment that Haystack is running in.
    Options are: colab notebook, kubernetes, CPU/GPU docker container, test environment, jupyter notebook, python script
    """
    if os.environ.get("CI", "False").lower() == "true":
        execution_env = "ci"
    elif "google.colab" in sys.modules:
        execution_env = "colab"
    elif "KUBERNETES_SERVICE_HOST" in os.environ:
        execution_env = "kubernetes"
    elif HAYSTACK_DOCKER_CONTAINER in os.environ:
        execution_env = os.environ.get(HAYSTACK_DOCKER_CONTAINER)
    # check if pytest is imported
    elif "pytest" in sys.modules:
        execution_env = "test"
    else:
        try:
            execution_env = get_ipython().__class__.__name__  # pylint: disable=undefined-variable
        except NameError:
            execution_env = "script"
    return execution_env
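To make the new module concrete, here is a minimal sketch of how it can be exercised (not part of the commit; the printed example values are placeholders, and the prefixing behaviour is the one implemented by MLflowTrackingHead in haystack/utils/experiment_tracking.py below):

```python
from haystack.environment import get_or_create_env_meta_data
from haystack.utils.experiment_tracking import flatten_dict

# Collected once per process and cached in the module-level dict.
meta = get_or_create_env_meta_data()
print(meta["haystack_version"], meta["execution_env"])  # e.g. something like "1.3.1rc0", "script"

# MLflowTrackingHead logs this metadata with an "environment_" prefix:
print(flatten_dict({"environment": meta}))  # keys become "environment_os_family", "environment_n_gpu", ...
```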
@@ -1411,10 +1411,10 @@
"title": "Use Auth Token",
"anyOf": [
{
-"type": "boolean"
+"type": "string"
},
{
-"type": "string"
+"type": "boolean"
}
]
}
@@ -1682,10 +1682,10 @@
"title": "Use Auth Token",
"anyOf": [
{
-"type": "boolean"
+"type": "string"
},
{
-"type": "string"
+"type": "boolean"
}
]
}
@@ -1949,10 +1949,10 @@
"title": "Use Auth Token",
"anyOf": [
{
-"type": "boolean"
+"type": "string"
},
{
-"type": "string"
+"type": "boolean"
}
]
}
@@ -3124,10 +3124,10 @@
"title": "Use Auth Token",
"anyOf": [
{
-"type": "boolean"
+"type": "string"
},
{
-"type": "string"
+"type": "boolean"
}
]
}

haystack/json-schemas/haystack-pipeline-1.3.1rc0.schema.json (new file, 3730 lines)
File diff suppressed because it is too large
@ -18,7 +18,7 @@ from torch.utils.data.sampler import RandomSampler, SequentialSampler
|
||||
|
||||
from haystack.modeling.data_handler.dataloader import NamedDataLoader
|
||||
from haystack.modeling.data_handler.processor import Processor
|
||||
from haystack.modeling.logger import MLFlowLogger as MlLogger
|
||||
from haystack.utils.experiment_tracking import Tracker as tracker
|
||||
from haystack.modeling.utils import log_ascii_workers, grouper, calc_chunksize
|
||||
from haystack.modeling.visual import TRACTOR_SMALL
|
||||
|
||||
@ -497,7 +497,7 @@ class DataSilo:
|
||||
logger.info("Average passage length after clipping: {}".format(ave_len[1]))
|
||||
logger.info("Proportion passages clipped: {}".format(clipped[1]))
|
||||
|
||||
MlLogger.log_params(
|
||||
tracker.track_params(
|
||||
{
|
||||
"n_samples_train": self.counts["train"],
|
||||
"n_samples_dev": self.counts["dev"],
|
||||
|
@ -31,7 +31,7 @@ from haystack.modeling.data_handler.samples import (
|
||||
offset_to_token_idx_vecorized,
|
||||
)
|
||||
from haystack.modeling.data_handler.input_features import sample_to_features_text
|
||||
from haystack.modeling.logger import MLFlowLogger as MlLogger
|
||||
from haystack.utils.experiment_tracking import Tracker as tracker
|
||||
|
||||
|
||||
DOWNSTREAM_TASK_MAP = {
|
||||
@ -359,7 +359,7 @@ class Processor(ABC):
|
||||
for name in names:
|
||||
value = getattr(self, name)
|
||||
params.update({name: str(value)})
|
||||
MlLogger.log_params(params)
|
||||
tracker.track_params(params)
|
||||
|
||||
|
||||
class SquadProcessor(Processor):
|
||||
|
@ -8,7 +8,7 @@ from tqdm import tqdm
|
||||
|
||||
from haystack.modeling.evaluation.metrics import compute_metrics, compute_report_metrics
|
||||
from haystack.modeling.model.adaptive_model import AdaptiveModel
|
||||
from haystack.modeling.logger import MLFlowLogger as MlLogger
|
||||
from haystack.utils.experiment_tracking import Tracker as tracker
|
||||
from haystack.modeling.visual import BUSH_SEP
|
||||
|
||||
|
||||
@ -157,11 +157,11 @@ class Evaluator:
|
||||
for head_num, head in enumerate(results):
|
||||
logger.info("\n _________ {} _________".format(head["task_name"]))
|
||||
for metric_name, metric_val in head.items():
|
||||
# log with ML framework (e.g. Mlflow)
|
||||
# log with experiment tracking framework (e.g. Mlflow)
|
||||
if logging:
|
||||
if not metric_name in ["preds", "labels"] and not metric_name.startswith("_"):
|
||||
if isinstance(metric_val, numbers.Number):
|
||||
MlLogger.log_metrics(
|
||||
tracker.track_metrics(
|
||||
metrics={f"{dataset_name}_{metric_name}_{head['task_name']}": metric_val}, step=steps
|
||||
)
|
||||
# print via standard python logger
|
||||
|
@ -21,7 +21,6 @@ from haystack.modeling.utils import (
|
||||
)
|
||||
from haystack.modeling.data_handler.inputs import QAInput
|
||||
from haystack.modeling.model.adaptive_model import AdaptiveModel, BaseAdaptiveModel
|
||||
from haystack.modeling.logger import MLFlowLogger
|
||||
from haystack.modeling.model.predictions import QAPred
|
||||
|
||||
|
||||
@ -74,8 +73,6 @@ class Inferencer:
|
||||
:return: An instance of the Inferencer.
|
||||
|
||||
"""
|
||||
MLFlowLogger.disable()
|
||||
|
||||
# Init device and distributed settings
|
||||
self.devices, n_gpu = initialize_device_settings(use_cuda=gpu, multi_gpu=False)
|
||||
|
||||
|
@@ -1,145 +0,0 @@
import logging
import mlflow
from requests.exceptions import ConnectionError


logger = logging.getLogger(__name__)


class BaseMLLogger:
    """
    Base class for tracking experiments.

    This class can be extended to implement custom logging backends like MLFlow, Tensorboard, or Sacred.
    """

    disable_logging = False

    def __init__(self, tracking_uri, **kwargs):
        self.tracking_uri = tracking_uri

    def init_experiment(self, tracking_uri):
        raise NotImplementedError()

    @classmethod
    def log_metrics(cls, metrics, step):
        raise NotImplementedError()

    @classmethod
    def log_artifacts(cls, self):
        raise NotImplementedError()

    @classmethod
    def log_params(cls, params):
        raise NotImplementedError()


class StdoutLogger(BaseMLLogger):
    """Minimal logger printing metrics and params to stdout.
    Useful for services like AWS SageMaker, where you parse metrics from the actual logs"""

    def init_experiment(self, experiment_name, run_name=None, nested=True):
        logger.info(f"\n **** Starting experiment '{experiment_name}' (Run: {run_name}) ****")

    @classmethod
    def log_metrics(cls, metrics, step):
        logger.info(f"Logged metrics at step {step}: \n {metrics}")

    @classmethod
    def log_params(cls, params):
        logger.info(f"Logged parameters: \n {params}")

    @classmethod
    def log_artifacts(cls, dir_path, artifact_path=None):
        raise NotImplementedError

    @classmethod
    def end_run(cls):
        logger.info(f"**** End of Experiment **** ")


class MLFlowLogger(BaseMLLogger):
    """
    Logger for MLFlow experiment tracking.
    """

    def init_experiment(self, experiment_name, run_name=None, nested=True):
        if not self.disable_logging:
            try:
                mlflow.set_tracking_uri(self.tracking_uri)
                mlflow.set_experiment(experiment_name)
                mlflow.start_run(run_name=run_name, nested=nested)
            except ConnectionError:
                raise Exception(
                    f"MLFlow cannot connect to the remote server at {self.tracking_uri}.\n"
                    f"MLFlow also supports logging runs locally to files. Set the MLFlowLogger "
                    f"tracking_uri to an empty string to use that."
                )

    @classmethod
    def log_metrics(cls, metrics, step):
        if not cls.disable_logging:
            try:
                mlflow.log_metrics(metrics, step=step)
            except ConnectionError:
                logger.warning(f"ConnectionError in logging metrics to MLFlow.")
            except Exception as e:
                logger.warning(f"Failed to log metrics: {e}")

    @classmethod
    def log_params(cls, params):
        if not cls.disable_logging:
            try:
                mlflow.log_params(params)
            except ConnectionError:
                logger.warning("ConnectionError in logging params to MLFlow")
            except Exception as e:
                logger.warning(f"Failed to log params: {e}")

    @classmethod
    def log_artifacts(cls, dir_path, artifact_path=None):
        if not cls.disable_logging:
            try:
                mlflow.log_artifacts(dir_path, artifact_path)
            except ConnectionError:
                logger.warning(f"ConnectionError in logging artifacts to MLFlow")
            except Exception as e:
                logger.warning(f"Failed to log artifacts: {e}")

    @classmethod
    def end_run(cls):
        if not cls.disable_logging:
            mlflow.end_run()

    @classmethod
    def disable(cls):
        logger.info("ML Logging is turned off. No parameters, metrics or artifacts will be logged to MLFlow.")
        cls.disable_logging = True


class TensorBoardLogger(BaseMLLogger):
    """
    PyTorch TensorBoard Logger
    """

    def __init__(self, **kwargs):
        try:
            from tensorboardX import SummaryWriter  # pylint: disable=import-error
        except (ImportError, ModuleNotFoundError):
            logger.info(
                "tensorboardX not found, can't initialize TensorBoardLogger. "
                "Enable it with 'pip install tensorboardX'."
            )

        TensorBoardLogger.summary_writer = SummaryWriter()
        super().__init__(**kwargs)

    @classmethod
    def log_metrics(cls, metrics, step):
        for key, value in metrics.items():
            TensorBoardLogger.summary_writer.add_scalar(tag=key, scalar_value=value, global_step=step)

    @classmethod
    def log_params(cls, params):
        for key, value in params.items():
            TensorBoardLogger.summary_writer.add_text(tag=key, text_string=str(value))
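The hunks that follow replace calls to the classmethods of this deleted logger with the new Tracker facade. As a rough before/after sketch of that mapping (illustrative only, with made-up parameter values; the real call sites are shown in the diffs below):

```python
# Before: module-level classmethods on the MLFlow-specific logger (deleted above).
from haystack.modeling.logger import MLFlowLogger as MlLogger

MlLogger.log_params({"epochs": 3})
MlLogger.log_metrics({"Train_loss_total": 0.42}, step=100)

# After: the backend-agnostic Tracker facade (added in haystack/utils/experiment_tracking.py).
from haystack.utils.experiment_tracking import Tracker as tracker

tracker.track_params({"epochs": 3})
tracker.track_metrics({"Train_loss_total": 0.42}, step=100)
```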
@ -15,7 +15,7 @@ from transformers.convert_graph_to_onnx import convert, quantize as quantize_mod
|
||||
from haystack.modeling.data_handler.processor import Processor
|
||||
from haystack.modeling.model.language_model import LanguageModel
|
||||
from haystack.modeling.model.prediction_head import PredictionHead, QuestionAnsweringHead
|
||||
from haystack.modeling.logger import MLFlowLogger as MlLogger
|
||||
from haystack.utils.experiment_tracking import Tracker as tracker
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -556,7 +556,7 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel):
|
||||
"lm_output_types": ",".join(self.lm_output_types),
|
||||
}
|
||||
try:
|
||||
MlLogger.log_params(params)
|
||||
tracker.track_params(params)
|
||||
except Exception as e:
|
||||
logger.warning(f"ML logging didn't work: {e}")
|
||||
|
||||
|
@ -10,7 +10,7 @@ from torch import nn
|
||||
from haystack.modeling.data_handler.processor import Processor
|
||||
from haystack.modeling.model.language_model import LanguageModel
|
||||
from haystack.modeling.model.prediction_head import PredictionHead, TextSimilarityHead
|
||||
from haystack.modeling.logger import MLFlowLogger as MlLogger
|
||||
from haystack.utils.experiment_tracking import Tracker as tracker
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -335,7 +335,7 @@ class BiAdaptiveModel(nn.Module):
|
||||
"prediction_heads": ",".join([head.__class__.__name__ for head in self.prediction_heads]),
|
||||
}
|
||||
try:
|
||||
MlLogger.log_params(params)
|
||||
tracker.track_params(params)
|
||||
except Exception as e:
|
||||
logger.warning(f"ML logging didn't work: {e}")
|
||||
|
||||
|
@ -10,7 +10,7 @@ from torch.nn import DataParallel
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
|
||||
from haystack.modeling.model.adaptive_model import AdaptiveModel
|
||||
from haystack.modeling.logger import MLFlowLogger as MlLogger
|
||||
from haystack.utils.experiment_tracking import Tracker as tracker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -161,7 +161,7 @@ def initialize_optimizer(
|
||||
schedule_opts["num_training_steps"] = num_train_optimization_steps
|
||||
|
||||
# Log params
|
||||
MlLogger.log_params({"use_amp": use_amp, "num_train_optimization_steps": schedule_opts["num_training_steps"]})
|
||||
tracker.track_params({"use_amp": use_amp, "num_train_optimization_steps": schedule_opts["num_training_steps"]})
|
||||
|
||||
# Get optimizer from pytorch, transformers or apex
|
||||
optimizer = _get_optim(model, optimizer_opts)
|
||||
@ -189,8 +189,8 @@ def _get_optim(model, opts: Dict):
|
||||
|
||||
# Logging
|
||||
logger.info(f"Loading optimizer `{optimizer_name}`: '{opts}'")
|
||||
MlLogger.log_params(opts)
|
||||
MlLogger.log_params({"optimizer_name": optimizer_name})
|
||||
tracker.track_params(opts)
|
||||
tracker.track_params({"optimizer_name": optimizer_name})
|
||||
|
||||
weight_decay = opts.pop("weight_decay", None)
|
||||
no_decay = opts.pop("no_decay", None)
|
||||
@ -279,15 +279,15 @@ def get_scheduler(optimizer, opts):
|
||||
# convert from warmup proportion to steps if required
|
||||
if "num_warmup_steps" in allowed_args and "num_warmup_steps" not in opts and "warmup_proportion" in opts:
|
||||
opts["num_warmup_steps"] = int(opts["warmup_proportion"] * opts["num_training_steps"])
|
||||
MlLogger.log_params({"warmup_proportion": opts["warmup_proportion"]})
|
||||
tracker.track_params({"warmup_proportion": opts["warmup_proportion"]})
|
||||
|
||||
# only pass args that are supported by the constructor
|
||||
constructor_opts = {k: v for k, v in opts.items() if k in allowed_args}
|
||||
|
||||
# Logging
|
||||
logger.info(f"Loading schedule `{schedule_name}`: '{constructor_opts}'")
|
||||
MlLogger.log_params(constructor_opts)
|
||||
MlLogger.log_params({"schedule_name": schedule_name})
|
||||
tracker.track_params(constructor_opts)
|
||||
tracker.track_params({"schedule_name": schedule_name})
|
||||
|
||||
scheduler = sched_constructor(optimizer, **constructor_opts)
|
||||
scheduler.opts = opts # save the opts with the scheduler to use in load/save
|
||||
|
@ -9,7 +9,7 @@ from torch import nn
|
||||
from haystack.modeling.data_handler.processor import Processor
|
||||
from haystack.modeling.model.language_model import LanguageModel
|
||||
from haystack.modeling.model.prediction_head import PredictionHead
|
||||
from haystack.modeling.logger import MLFlowLogger as MlLogger
|
||||
from haystack.utils.experiment_tracking import Tracker as tracker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -369,7 +369,7 @@ class TriAdaptiveModel(nn.Module):
|
||||
"prediction_heads": ",".join([head.__class__.__name__ for head in self.prediction_heads]),
|
||||
}
|
||||
try:
|
||||
MlLogger.log_params(params)
|
||||
tracker.track_params(params)
|
||||
except Exception as e:
|
||||
logger.warning(f"ML logging didn't work: {e}")
|
||||
|
||||
|
@ -19,7 +19,7 @@ from haystack.modeling.evaluation.eval import Evaluator
|
||||
from haystack.modeling.model.adaptive_model import AdaptiveModel
|
||||
from haystack.modeling.model.optimization import get_scheduler
|
||||
from haystack.modeling.utils import GracefulKiller
|
||||
from haystack.modeling.logger import MLFlowLogger as MlLogger
|
||||
from haystack.utils.experiment_tracking import Tracker as tracker
|
||||
|
||||
try:
|
||||
from apex import amp
|
||||
@ -161,7 +161,7 @@ class Trainer:
|
||||
Useful to achieve larger effective batch sizes that would not fit in GPU memory.
|
||||
:param local_rank: Local rank of process when distributed training via DDP is used.
|
||||
:param early_stopping: an initialized EarlyStopping object to control early stopping and saving of best models.
|
||||
:param log_learning_rate: Whether to log learning rate to Mlflow
|
||||
:param log_learning_rate: Whether to log learning rate to experiment tracker (e.g. Mlflow)
|
||||
:param log_loss_every: Log current train loss after this many train steps.
|
||||
:param checkpoint_on_sigterm: save a checkpoint for the Trainer when a SIGTERM signal is sent. The checkpoint
|
||||
can be used to resume training. It is useful in frameworks like AWS SageMaker with Spot instances where
|
||||
@ -377,9 +377,9 @@ class Trainer:
|
||||
loss = self.adjust_loss(loss)
|
||||
if self.global_step % self.log_loss_every == 0 and self.local_rank in [-1, 0]:
|
||||
if self.local_rank in [-1, 0]:
|
||||
MlLogger.log_metrics({"Train_loss_total": float(loss.detach().cpu().numpy())}, step=self.global_step)
|
||||
tracker.track_metrics({"Train_loss_total": float(loss.detach().cpu().numpy())}, step=self.global_step)
|
||||
if self.log_learning_rate:
|
||||
MlLogger.log_metrics({"learning_rate": self.lr_schedule.get_last_lr()[0]}, step=self.global_step)
|
||||
tracker.track_metrics({"learning_rate": self.lr_schedule.get_last_lr()[0]}, step=self.global_step)
|
||||
if self.use_amp:
|
||||
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
|
||||
scaled_loss.backward()
|
||||
@ -406,7 +406,7 @@ class Trainer:
|
||||
|
||||
def log_params(self):
|
||||
params = {"epochs": self.epochs, "n_gpu": self.n_gpu, "device": self.device}
|
||||
MlLogger.log_params(params)
|
||||
tracker.track_params(params)
|
||||
|
||||
@classmethod
|
||||
def create_or_load_checkpoint(
|
||||
@ -700,7 +700,7 @@ class DistillationTrainer(Trainer):
|
||||
Useful to achieve larger effective batch sizes that would not fit in GPU memory.
|
||||
:param local_rank: Local rank of process when distributed training via DDP is used.
|
||||
:param early_stopping: an initialized EarlyStopping object to control early stopping and saving of best models.
|
||||
:param log_learning_rate: Whether to log learning rate to Mlflow
|
||||
:param log_learning_rate: Whether to log learning rate to experiment tracker (e.g. Mlflow)
|
||||
:param log_loss_every: Log current train loss after this many train steps.
|
||||
:param checkpoint_on_sigterm: save a checkpoint for the Trainer when a SIGTERM signal is sent. The checkpoint
|
||||
can be used to resume training. It is useful in frameworks like AWS SageMaker with Spot instances where
|
||||
@ -842,7 +842,7 @@ class TinyBERTDistillationTrainer(Trainer):
|
||||
Useful to achieve larger effective batch sizes that would not fit in GPU memory.
|
||||
:param local_rank: Local rank of process when distributed training via DDP is used.
|
||||
:param early_stopping: an initialized EarlyStopping object to control early stopping and saving of best models.
|
||||
:param log_learning_rate: Whether to log learning rate to Mlflow
|
||||
:param log_learning_rate: Whether to log learning rate to experiment tracker (e.g. Mlflow)
|
||||
:param log_loss_every: Log current train loss after this many train steps.
|
||||
:param checkpoint_on_sigterm: save a checkpoint for the Trainer when a SIGTERM signal is sent. The checkpoint
|
||||
can be used to resume training. It is useful in frameworks like AWS SageMaker with Spot instances where
|
||||
|
@ -455,7 +455,7 @@ class DensePassageRetriever(BaseRetriever):
|
||||
use_amp=use_amp,
|
||||
)
|
||||
|
||||
# 7. Let it grow! Watch the tracked metrics live on the public mlflow server: https://public-mlflow.deepset.ai
|
||||
# 7. Let it grow! Watch the tracked metrics live on experiment tracker (e.g. Mlflow)
|
||||
trainer.train()
|
||||
|
||||
self.model.save(Path(save_dir), lm1_name=query_encoder_save_dir, lm2_name=passage_encoder_save_dir)
|
||||
@ -985,7 +985,7 @@ class TableTextRetriever(BaseRetriever):
|
||||
use_amp=use_amp,
|
||||
)
|
||||
|
||||
# 7. Let it grow! Watch the tracked metrics live on the public mlflow server: https://public-mlflow.deepset.ai
|
||||
# 7. Let it grow! Watch the tracked metrics live on experiment tracker (e.g. Mlflow)
|
||||
trainer.train()
|
||||
|
||||
self.model.save(
|
||||
|
@ -1,6 +1,11 @@
|
||||
from __future__ import annotations
|
||||
from typing import Dict, List, Optional, Any, Set, Tuple, Union
|
||||
|
||||
try:
|
||||
from typing import Literal
|
||||
except ImportError:
|
||||
from typing_extensions import Literal # type: ignore
|
||||
|
||||
import copy
|
||||
import json
|
||||
import inspect
|
||||
@ -46,6 +51,7 @@ from haystack.nodes.base import BaseComponent
|
||||
from haystack.nodes.retriever.base import BaseRetriever
|
||||
from haystack.document_stores.base import BaseDocumentStore
|
||||
from haystack.telemetry import send_event
|
||||
from haystack.utils.experiment_tracking import MLflowTrackingHead, Tracker as tracker
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -53,6 +59,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
ROOT_NODE_TO_PIPELINE_NAME = {"query": "query", "file": "indexing"}
|
||||
CODE_GEN_DEFAULT_COMMENT = "This code has been generated."
|
||||
TRACKING_TOOL_TO_HEAD = {"mlflow": MLflowTrackingHead}
|
||||
|
||||
|
||||
class RootNode(BaseComponent):
|
||||
@ -770,6 +777,201 @@ class Pipeline(BasePipeline):
|
||||
ndcg, map_, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
|
||||
return ndcg, map_, recall, precision
|
||||
|
||||
@classmethod
|
||||
def execute_eval_run(
|
||||
cls,
|
||||
index_pipeline: Pipeline,
|
||||
query_pipeline: Pipeline,
|
||||
evaluation_set_labels: List[MultiLabel],
|
||||
corpus_file_paths: List[str],
|
||||
experiment_name: str,
|
||||
experiment_run_name: str,
|
||||
experiment_tracking_tool: Literal["mlflow", None] = None,
|
||||
experiment_tracking_uri: Optional[str] = None,
|
||||
corpus_file_metas: List[Dict[str, Any]] = None,
|
||||
corpus_meta: Dict[str, Any] = {},
|
||||
evaluation_set_meta: Dict[str, Any] = {},
|
||||
pipeline_meta: Dict[str, Any] = {},
|
||||
index_params: dict = {},
|
||||
query_params: dict = {},
|
||||
sas_model_name_or_path: str = None,
|
||||
sas_batch_size: int = 32,
|
||||
sas_use_gpu: bool = True,
|
||||
add_isolated_node_eval: bool = False,
|
||||
reuse_index: bool = False,
|
||||
) -> EvaluationResult:
|
||||
"""
|
||||
Starts an experiment run that first indexes the specified files (forming a corpus) using the index pipeline
|
||||
and subsequently evaluates the query pipeline on the provided labels (forming an evaluation set) using pipeline.eval().
|
||||
Parameters and results (metrics and predictions) of the run are tracked by an experiment tracking tool for further analysis.
|
||||
You can specify the experiment tracking tool by setting the params `experiment_tracking_tool` and `experiment_tracking_uri`
|
||||
or by passing a (custom) tracking head to Tracker.set_tracking_head().
|
||||
Note that `experiment_tracking_tool` currently only supports `mlflow`.
|
||||
|
||||
For easier comparison you can pass additional metadata regarding corpus (corpus_meta), evaluation set (evaluation_set_meta) and pipelines (pipeline_meta).
|
||||
E.g. you can give them names or ids to identify them across experiment runs.
|
||||
|
||||
This method executes an experiment run. Each experiment run is part of at least one experiment.
|
||||
An experiment typically consists of multiple runs to be compared (e.g. using different retrievers in query pipeline).
|
||||
Experiment tracking tools usually share the same concepts of experiments and provide additional functionality to easily compare runs across experiments.
|
||||
|
||||
E.g. you can call execute_eval_run() multiple times with different retrievers in your query pipeline and compare the runs in mlflow:
|
||||
|
||||
```python
|
||||
| for retriever_type, query_pipeline in zip(["sparse", "dpr", "embedding"], [sparse_pipe, dpr_pipe, embedding_pipe]):
|
||||
| eval_result = Pipeline.execute_eval_run(
|
||||
| index_pipeline=index_pipeline,
|
||||
| query_pipeline=query_pipeline,
|
||||
| evaluation_set_labels=labels,
|
||||
| corpus_file_paths=file_paths,
|
||||
| corpus_file_metas=file_metas,
|
||||
| experiment_tracking_tool="mlflow",
|
||||
| experiment_tracking_uri="http://localhost:5000",
|
||||
| experiment_name="my-retriever-experiment",
|
||||
| experiment_run_name=f"run_{retriever_type}",
|
||||
| pipeline_meta={"name": f"my-pipeline-{retriever_type}"},
|
||||
| evaluation_set_meta={"name": "my-evalset"},
|
||||
|           corpus_meta={"name": "my-corpus"},
|
||||
| reuse_index=False
|
||||
| )
|
||||
```
|
||||
|
||||
:param index_pipeline: The indexing pipeline to use.
|
||||
:param query_pipeline: The query pipeline to evaluate.
|
||||
:param evaluation_set_labels: The labels to evaluate on, forming an evaluation set.
|
||||
:param corpus_file_paths: The files to be indexed and searched during evaluation forming a corpus.
|
||||
:param experiment_name: The name of the experiment
|
||||
:param experiment_run_name: The name of the experiment run
|
||||
:param experiment_tracking_tool: The experiment tracking tool to be used. Currently we only support "mlflow".
|
||||
If left unset the current TrackingHead specified by Tracker.set_tracking_head() will be used.
|
||||
:param experiment_tracking_uri: The uri of the experiment tracking server to be used. Must be specified if experiment_tracking_tool is set.
|
||||
You can use deepset's public mlflow server via https://public-mlflow.deepset.ai/.
|
||||
Note that artifact logging (e.g. Pipeline YAML or evaluation result CSVs) is currently not allowed on deepset's public mlflow server as this might expose sensitive data.
|
||||
:param corpus_file_metas: The optional metadata to be stored for each corpus file (e.g. title).
|
||||
:param corpus_meta: Metadata about the corpus to track (e.g. name, date, author, version).
|
||||
:param evaluation_set_meta: Metadata about the evalset to track (e.g. name, date, author, version).
|
||||
:param pipeline_meta: Metadata about the pipelines to track (e.g. name, author, version).
|
||||
:param index_params: The params to use during indexing (see pipeline.run's params).
|
||||
:param query_params: The params to use during querying (see pipeline.run's params).
|
||||
:param sas_model_name_or_path: Name or path of "Semantic Answer Similarity (SAS) model". When set, the model will be used to calculate similarity between predictions and labels and generate the SAS metric.
|
||||
The SAS metric correlates better with human judgement of correct answers as it does not rely on string overlaps.
|
||||
Example: Prediction = "30%", Label = "thirty percent", EM and F1 would be overly pessimistic with both being 0, while SAS paints a more realistic picture.
|
||||
More info in the paper: https://arxiv.org/abs/2108.06130
|
||||
Models:
|
||||
- You can use Bi Encoders (sentence transformers) or cross encoders trained on Semantic Textual Similarity (STS) data.
|
||||
Not all cross encoders can be used because of different return types.
|
||||
If you use custom cross encoders please make sure they work with sentence_transformers.CrossEncoder class
|
||||
- Good default for multiple languages: "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
|
||||
- Large, powerful, but slow model for English only: "cross-encoder/stsb-roberta-large"
|
||||
- Large model for German only: "deepset/gbert-large-sts"
|
||||
:param sas_batch_size: Number of prediction label pairs to encode at once by CrossEncoder or SentenceTransformer while calculating SAS.
|
||||
:param sas_use_gpu: Whether to use a GPU or the CPU for calculating semantic answer similarity.
|
||||
Falls back to CPU if no GPU is available.
|
||||
:param add_isolated_node_eval: If set to True, in addition to the integrated evaluation of the pipeline, each node is evaluated in isolated evaluation mode.
|
||||
This mode helps to understand the bottlenecks of a pipeline in terms of output quality of each individual node.
|
||||
If a node performs much better in the isolated evaluation than in the integrated evaluation, the previous node needs to be optimized to improve the pipeline's performance.
|
||||
If a node's performance is similar in both modes, this node itself needs to be optimized to improve the pipeline's performance.
|
||||
The isolated evaluation calculates the upper bound of each node's evaluation metrics under the assumption that it received perfect inputs from the previous node.
|
||||
To this end, labels are used as input to the node instead of the output of the previous node in the pipeline.
|
||||
The generated dataframes in the EvaluationResult then contain additional rows, which can be distinguished from the integrated evaluation results based on the
|
||||
values "integrated" or "isolated" in the column "eval_mode" and the evaluation report then additionally lists the upper bound of each node's evaluation metrics.
|
||||
:param reuse_index: Whether to reuse existing non-empty index and to keep the index after evaluation.
|
||||
If True, the index is kept after evaluation and no indexing takes place if the index already contains documents. Otherwise the index is deleted immediately afterwards.
|
||||
Defaults to False.
|
||||
"""
|
||||
if experiment_tracking_tool is not None:
|
||||
tracking_head_cls = TRACKING_TOOL_TO_HEAD.get(experiment_tracking_tool, None)
|
||||
if tracking_head_cls is None:
|
||||
raise HaystackError(
|
||||
f"Please specify a valid experiment_tracking_tool. Possible values are: {TRACKING_TOOL_TO_HEAD.keys()}"
|
||||
)
|
||||
if experiment_tracking_uri is None:
|
||||
raise HaystackError(f"experiment_tracking_uri must be specified if experiment_tracking_tool is set.")
|
||||
tracking_head = tracking_head_cls(tracking_uri=experiment_tracking_uri)
|
||||
tracker.set_tracking_head(tracking_head)
|
||||
|
||||
try:
|
||||
tracker.init_experiment(
|
||||
experiment_name=experiment_name, run_name=experiment_run_name, tags={experiment_name: "True"}
|
||||
)
|
||||
tracker.track_params(
|
||||
{
|
||||
"dataset_label_count": len(evaluation_set_labels),
|
||||
"dataset": evaluation_set_meta,
|
||||
"sas_model_name_or_path": sas_model_name_or_path,
|
||||
"sas_batch_size": sas_batch_size,
|
||||
"sas_use_gpu": sas_use_gpu,
|
||||
"pipeline_index_params": index_params,
|
||||
"pipeline_query_params": query_params,
|
||||
"pipeline": pipeline_meta,
|
||||
"corpus_file_count": len(corpus_file_paths),
|
||||
"corpus": corpus_meta,
|
||||
"type": "offline/evaluation",
|
||||
}
|
||||
)
|
||||
|
||||
# check index before eval
|
||||
document_store = index_pipeline.get_document_store()
|
||||
if document_store is None:
|
||||
raise HaystackError(f"Document store not found. Please provide pipelines with proper document store.")
|
||||
document_count = document_store.get_document_count()
|
||||
|
||||
if document_count > 0:
|
||||
if not reuse_index:
|
||||
raise HaystackError(f"Index '{document_store.index}' is not empty. Please provide an empty index.")
|
||||
else:
|
||||
logger.info(f"indexing {len(corpus_file_paths)} documents...")
|
||||
index_pipeline.run(file_paths=corpus_file_paths, meta=corpus_file_metas, params=index_params)
|
||||
document_count = document_store.get_document_count()
|
||||
logger.info(f"indexing {len(evaluation_set_labels)} files to {document_count} documents finished.")
|
||||
|
||||
tracker.track_params({"pipeline_index_document_count": document_count})
|
||||
|
||||
eval_result = query_pipeline.eval(
|
||||
labels=evaluation_set_labels,
|
||||
params=query_params,
|
||||
sas_model_name_or_path=sas_model_name_or_path,
|
||||
sas_batch_size=sas_batch_size,
|
||||
sas_use_gpu=sas_use_gpu,
|
||||
add_isolated_node_eval=add_isolated_node_eval,
|
||||
)
|
||||
|
||||
integrated_metrics = eval_result.calculate_metrics()
|
||||
integrated_top_1_metrics = eval_result.calculate_metrics(simulated_top_k_reader=1)
|
||||
metrics = {"integrated": integrated_metrics, "integrated_top_1": integrated_top_1_metrics}
|
||||
if add_isolated_node_eval:
|
||||
isolated_metrics = eval_result.calculate_metrics(eval_mode="isolated")
|
||||
isolated_top_1_metrics = eval_result.calculate_metrics(eval_mode="isolated", simulated_top_k_reader=1)
|
||||
metrics["isolated"] = isolated_metrics
|
||||
metrics["isolated_top_1"] = isolated_top_1_metrics
|
||||
tracker.track_metrics(metrics, step=0)
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
eval_result_dir = Path(temp_dir) / "eval_result"
|
||||
eval_result_dir.mkdir(exist_ok=True)
|
||||
eval_result.save(out_dir=eval_result_dir)
|
||||
tracker.track_artifacts(eval_result_dir, artifact_path="eval_result")
|
||||
with open(Path(temp_dir) / "pipelines.yaml", "w") as outfile:
|
||||
index_config = index_pipeline.get_config()
|
||||
query_config = query_pipeline.get_config()
|
||||
components = list(
|
||||
{c["name"]: c for c in (index_config["components"] + query_config["components"])}.values()
|
||||
)
|
||||
pipelines = index_config["pipelines"] + query_config["pipelines"]
|
||||
config = {"version": index_config["version"], "components": components, "pipelines": pipelines}
|
||||
yaml.dump(config, outfile, default_flow_style=False)
|
||||
tracker.track_artifacts(temp_dir)
|
||||
|
||||
# Clean up document store
|
||||
if not reuse_index and document_store.index is not None:
|
||||
logger.info(f"Cleaning up: deleting index '{document_store.index}'...")
|
||||
document_store.delete_index(document_store.index)
|
||||
|
||||
finally:
|
||||
tracker.end_run()
|
||||
|
||||
return eval_result
|
||||
|
||||
@send_event
|
||||
def eval(
|
||||
self,
|
||||
|
@ -5,34 +5,26 @@
|
||||
You can opt-out of sharing usage statistics by calling disable_telemetry() or by manually setting the environment variable HAYSTACK_TELEMETRY_ENABLED as described for different operating systems on the documentation page.
|
||||
You can log all events to the local file specified in LOG_PATH for inspection by setting the environment variable HAYSTACK_TELEMETRY_LOGGING_TO_FILE_ENABLED to "True".
|
||||
"""
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
import os
|
||||
import sys
|
||||
from typing import Any, Dict, List, Optional
|
||||
import uuid
|
||||
import logging
|
||||
import platform
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
import torch
|
||||
import posthog
|
||||
import transformers
|
||||
|
||||
from haystack import __version__
|
||||
from haystack.environment import HAYSTACK_EXECUTION_CONTEXT, get_or_create_env_meta_data
|
||||
|
||||
posthog.api_key = "phc_F5v11iI2YHkoP6Er3cPILWSrLhY3D6UY4dEMga4eoaa"
|
||||
posthog.host = "https://tm.hs.deepset.ai"
|
||||
HAYSTACK_TELEMETRY_ENABLED = "HAYSTACK_TELEMETRY_ENABLED"
|
||||
HAYSTACK_TELEMETRY_LOGGING_TO_FILE_ENABLED = "HAYSTACK_TELEMETRY_LOGGING_TO_FILE_ENABLED"
|
||||
HAYSTACK_EXECUTION_CONTEXT = "HAYSTACK_EXECUTION_CONTEXT"
|
||||
HAYSTACK_DOCKER_CONTAINER = "HAYSTACK_DOCKER_CONTAINER"
|
||||
CONFIG_PATH = Path("~/.haystack/config.yaml").expanduser()
|
||||
LOG_PATH = Path("~/.haystack/telemetry.log").expanduser()
|
||||
|
||||
telemetry_meta_data: Dict[str, Any] = {}
|
||||
user_id: Optional[str] = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -49,7 +41,7 @@ def print_telemetry_report():
|
||||
"""
|
||||
if is_telemetry_enabled():
|
||||
user_id = _get_or_create_user_id()
|
||||
meta_data = _get_or_create_telemetry_meta_data()
|
||||
meta_data = get_or_create_env_meta_data()
|
||||
print({**{"user_id": user_id}, **meta_data})
|
||||
else:
|
||||
print("Telemetry is disabled.")
|
||||
@ -152,7 +144,7 @@ def send_custom_event(event: str = "", payload: Dict[str, Any] = {}):
|
||||
|
||||
:param payload: A dictionary containing event meta data, e.g., parameter settings
|
||||
"""
|
||||
event_properties = {**(NonPrivateParameters.apply_filter(payload)), **_get_or_create_telemetry_meta_data()}
|
||||
event_properties = {**(NonPrivateParameters.apply_filter(payload)), **get_or_create_env_meta_data()}
|
||||
if user_id is None:
|
||||
raise RuntimeError("User id was not initialized")
|
||||
try:
|
||||
@ -224,53 +216,6 @@ def _get_or_create_user_id() -> str:
|
||||
return user_id
|
||||
|
||||
|
||||
def _get_or_create_telemetry_meta_data() -> Dict[str, Any]:
|
||||
"""
|
||||
Collects meta data about the setup that is used with Haystack, such as: operating system, python version, Haystack version, transformers version, pytorch version, number of GPUs, execution environment, and the value stored in the env variable HAYSTACK_EXECUTION_CONTEXT.
|
||||
"""
|
||||
global telemetry_meta_data # pylint: disable=global-statement
|
||||
if not telemetry_meta_data:
|
||||
telemetry_meta_data = {
|
||||
"os_version": platform.release(),
|
||||
"os_family": platform.system(),
|
||||
"os_machine": platform.machine(),
|
||||
"python_version": platform.python_version(),
|
||||
"haystack_version": __version__,
|
||||
"transformers_version": transformers.__version__,
|
||||
"torch_version": torch.__version__,
|
||||
"torch_cuda_version": torch.version.cuda if torch.cuda.is_available() else 0,
|
||||
"n_gpu": torch.cuda.device_count() if torch.cuda.is_available() else 0,
|
||||
"n_cpu": os.cpu_count(),
|
||||
"context": os.environ.get(HAYSTACK_EXECUTION_CONTEXT),
|
||||
"execution_env": _get_execution_environment(),
|
||||
}
|
||||
return telemetry_meta_data
|
||||
|
||||
|
||||
def _get_execution_environment():
|
||||
"""
|
||||
Identifies the execution environment that Haystack is running in.
|
||||
Options are: colab notebook, kubernetes, CPU/GPU docker container, test environment, jupyter notebook, python script
|
||||
"""
|
||||
if os.environ.get("CI", "False").lower() == "true":
|
||||
execution_env = "ci"
|
||||
elif "google.colab" in sys.modules:
|
||||
execution_env = "colab"
|
||||
elif "KUBERNETES_SERVICE_HOST" in os.environ:
|
||||
execution_env = "kubernetes"
|
||||
elif HAYSTACK_DOCKER_CONTAINER in os.environ:
|
||||
execution_env = os.environ.get(HAYSTACK_DOCKER_CONTAINER)
|
||||
# check if pytest is imported
|
||||
elif "pytest" in sys.modules:
|
||||
execution_env = "test"
|
||||
else:
|
||||
try:
|
||||
execution_env = get_ipython().__class__.__name__ # pylint: disable=undefined-variable
|
||||
except NameError:
|
||||
execution_env = "script"
|
||||
return execution_env
|
||||
|
||||
|
||||
def _read_telemetry_config():
|
||||
"""
|
||||
Loads the config from the file specified in CONFIG_PATH
|
||||
|
@ -19,3 +19,10 @@ from haystack.utils.export_utils import (
|
||||
)
|
||||
from haystack.utils.squad_data import SquadData
|
||||
from haystack.utils.context_matching import calculate_context_similarity, match_context, match_contexts
|
||||
from haystack.utils.experiment_tracking import (
|
||||
Tracker,
|
||||
NoTrackingHead,
|
||||
BaseTrackingHead,
|
||||
MLflowTrackingHead,
|
||||
StdoutTrackingHead,
|
||||
)
|
||||
|
haystack/utils/experiment_tracking.py (new file, 188 lines)

@@ -0,0 +1,188 @@
from abc import ABC, abstractmethod
import logging
from pathlib import Path
from typing import Any, Dict, Union
import mlflow
from requests.exceptions import ConnectionError

from haystack.environment import get_or_create_env_meta_data


logger = logging.getLogger(__name__)


def flatten_dict(dict_to_flatten: dict, prefix: str = ""):
    flat_dict = {}
    for k, v in dict_to_flatten.items():
        if isinstance(v, dict):
            flat_dict.update(flatten_dict(v, prefix + k + "_"))
        else:
            flat_dict[prefix + k] = v
    return flat_dict


class BaseTrackingHead(ABC):
    """
    Base class for tracking experiments.

    This class can be extended to implement custom logging backends like MLflow, WandB, or TensorBoard.
    """

    @abstractmethod
    def init_experiment(
        self, experiment_name: str, run_name: str = None, tags: Dict[str, Any] = None, nested: bool = False
    ):
        raise NotImplementedError()

    @abstractmethod
    def track_metrics(self, metrics: Dict[str, Any], step: int):
        raise NotImplementedError()

    @abstractmethod
    def track_artifacts(self, dir_path: Union[str, Path], artifact_path: str = None):
        raise NotImplementedError()

    @abstractmethod
    def track_params(self, params: Dict[str, Any]):
        raise NotImplementedError()

    @abstractmethod
    def end_run(self):
        raise NotImplementedError()


class NoTrackingHead(BaseTrackingHead):
    """
    Null object implementation of a tracking head: i.e. does nothing.
    """

    def init_experiment(
        self, experiment_name: str, run_name: str = None, tags: Dict[str, Any] = None, nested: bool = False
    ):
        pass

    def track_metrics(self, metrics: Dict[str, Any], step: int):
        pass

    def track_artifacts(self, dir_path: Union[str, Path], artifact_path: str = None):
        pass

    def track_params(self, params: Dict[str, Any]):
        pass

    def end_run(self):
        pass


class Tracker:
    """
    Facade for tracking experiments.
    """

    tracker: BaseTrackingHead = NoTrackingHead()

    @classmethod
    def init_experiment(
        cls, experiment_name: str, run_name: str = None, tags: Dict[str, Any] = None, nested: bool = False
    ):
        cls.tracker.init_experiment(experiment_name=experiment_name, run_name=run_name, tags=tags, nested=nested)

    @classmethod
    def track_metrics(cls, metrics: Dict[str, Any], step: int):
        cls.tracker.track_metrics(metrics=metrics, step=step)

    @classmethod
    def track_artifacts(cls, dir_path: Union[str, Path], artifact_path: str = None):
        cls.tracker.track_artifacts(dir_path=dir_path, artifact_path=artifact_path)

    @classmethod
    def track_params(cls, params: Dict[str, Any]):
        cls.tracker.track_params(params=params)

    @classmethod
    def end_run(cls):
        cls.tracker.end_run()

    @classmethod
    def set_tracking_head(cls, tracker: BaseTrackingHead):
        cls.tracker = tracker


class StdoutTrackingHead(BaseTrackingHead):
    """
    Experiment tracking head printing metrics and params to stdout.
    Useful for services like AWS SageMaker, where you parse metrics from the actual logs
    """

    def init_experiment(
        self, experiment_name: str, run_name: str = None, tags: Dict[str, Any] = None, nested: bool = False
    ):
        logger.info(f"\n **** Starting experiment '{experiment_name}' (Run: {run_name}) ****")

    def track_metrics(self, metrics: Dict[str, Any], step: int):
        logger.info(f"Logged metrics at step {step}: \n {metrics}")

    def track_params(self, params: Dict[str, Any]):
        logger.info(f"Logged parameters: \n {params}")

    def track_artifacts(self, dir_path: Union[str, Path], artifact_path: str = None):
        logger.warning(f"Cannot log artifacts with StdoutLogger: \n {dir_path}")

    def end_run(self):
        logger.info(f"**** End of Experiment **** ")


class MLflowTrackingHead(BaseTrackingHead):
    def __init__(self, tracking_uri: str, auto_track_environment: bool = True) -> None:
        """
        Experiment tracking head for MLflow.
        """
        super().__init__()
        self.tracking_uri = tracking_uri
        self.auto_track_environment = auto_track_environment

    def init_experiment(
        self, experiment_name: str, run_name: str = None, tags: Dict[str, Any] = None, nested: bool = False
    ):
        try:
            mlflow.set_tracking_uri(self.tracking_uri)
            mlflow.set_experiment(experiment_name)
            mlflow.start_run(run_name=run_name, nested=nested, tags=tags)
            logger.info(f"Tracking run {run_name} of experiment {experiment_name} by mlflow under {self.tracking_uri}")
            if self.auto_track_environment:
                mlflow.log_params(flatten_dict({"environment": get_or_create_env_meta_data()}))
        except ConnectionError:
            raise Exception(
                f"MLflow cannot connect to the remote server at {self.tracking_uri}.\n"
                f"MLflow also supports logging runs locally to files. Set the MLflowTrackingHead "
                f"tracking_uri to an empty string to use that."
            )

    def track_metrics(self, metrics: Dict[str, Any], step: int):
        try:
            metrics = flatten_dict(metrics)
            mlflow.log_metrics(metrics, step=step)
        except ConnectionError:
            logger.warning(f"ConnectionError in logging metrics to MLflow.")
        except Exception as e:
            logger.warning(f"Failed to log metrics: {e}")

    def track_params(self, params: Dict[str, Any]):
        try:
            params = flatten_dict(params)
            mlflow.log_params(params)
        except ConnectionError:
            logger.warning("ConnectionError in logging params to MLflow")
        except Exception as e:
            logger.warning(f"Failed to log params: {e}")

    def track_artifacts(self, dir_path: Union[str, Path], artifact_path: str = None):
        try:
            mlflow.log_artifacts(dir_path, artifact_path)
        except ConnectionError:
            logger.warning(f"ConnectionError in logging artifacts to MLflow")
        except Exception as e:
            logger.warning(f"Failed to log artifacts: {e}")

    def end_run(self):
        mlflow.end_run()
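Taken together, a minimal usage sketch of the new module, mirroring what Pipeline.execute_eval_run() does internally (not part of the commit; the tracking URI, the metric names, the values, and the artifact directory are placeholders):

```python
from haystack.utils.experiment_tracking import MLflowTrackingHead, StdoutTrackingHead, Tracker

# Choose a backend: MLflow for a real tracking server, Stdout for plain log output.
Tracker.set_tracking_head(MLflowTrackingHead(tracking_uri="http://localhost:5000"))
# Tracker.set_tracking_head(StdoutTrackingHead())

try:
    Tracker.init_experiment(experiment_name="my-retriever-experiment", run_name="run_sparse")
    Tracker.track_params({"pipeline": {"name": "my-pipeline-sparse"}})  # nested dicts get flattened by the MLflow head
    Tracker.track_metrics({"integrated": {"some_metric": 0.9}}, step=0)  # placeholder metric
    Tracker.track_artifacts("eval_result")  # a local directory, e.g. containing EvaluationResult CSVs
finally:
    Tracker.end_run()
```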