Log evaluation results to MLflow (#2337)

* track eval results in mlflow

* Update Documentation & Code Style

* add pipeline.yaml and environment info

* improve logging to mlflow

* Update Documentation & Code Style

* introduce ExperimentTracker

* Update Documentation & Code Style

* move modeling.utils.logger to utils.experiment_tracking

* renaming: tracker and TrackingHead

* Update Documentation & Code Style

* refactor env tracking

* fix pylint findings

* Update Documentation & Code Style

* rename MLFlowTrackingHead to MLflowTrackingHead

* implement dataset hash

* Update Documentation & Code Style

* set docstrings

* Update Documentation & Code Style

* introduce PipelineBundle and Corpus

* Update Documentation & Code Style

* support reusing index

* Update Documentation & Code Style

* rename Corpus to FileCorpus

* fix Corpus -> FileCorpus

* Update Documentation & Code Style

* resolve cyclic dependencies

* fix linter issues

* Update Documentation & Code Style

* remove helper classes

* Update Documentation & Code Style

* fix imports

* fix another unused import

* update docstrings

* Update Documentation & Code Style

* simplify usage of experiment tracking tools

* fix Literal import

* revert schema changes

* Update Documentation & Code Style

* always end run

* Update Documentation & Code Style

* fix mypy issue

* rename to execute_eval_run

* Update Documentation & Code Style

* fix merge of get_or_create_env_meta_data

* improve docstrings

* Update Documentation & Code Style

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
tstadel 2022-04-25 20:14:48 +02:00 committed by GitHub
parent c401e86099
commit 60ff46e4e1
21 changed files with 4322 additions and 250 deletions

View File

@ -464,6 +464,96 @@ If True the index will be kept after beir evaluation. Otherwise it will be delet
Returns a tuple containing the ncdg, map, recall and precision scores.
Each metric is represented by a dictionary containing the scores for each top_k value.
<a id="base.Pipeline.execute_eval_run"></a>
#### execute\_eval\_run
```python
@classmethod
def execute_eval_run(cls, index_pipeline: Pipeline, query_pipeline: Pipeline, evaluation_set_labels: List[MultiLabel], corpus_file_paths: List[str], experiment_name: str, experiment_run_name: str, experiment_tracking_tool: Literal["mlflow", None] = None, experiment_tracking_uri: Optional[str] = None, corpus_file_metas: List[Dict[str, Any]] = None, corpus_meta: Dict[str, Any] = {}, evaluation_set_meta: Dict[str, Any] = {}, pipeline_meta: Dict[str, Any] = {}, index_params: dict = {}, query_params: dict = {}, sas_model_name_or_path: str = None, sas_batch_size: int = 32, sas_use_gpu: bool = True, add_isolated_node_eval: bool = False, reuse_index: bool = False) -> EvaluationResult
```
Starts an experiment run that first indexes the specified files (forming a corpus) using the index pipeline
and subsequently evaluates the query pipeline on the provided labels (forming an evaluation set) using pipeline.eval().
Parameters and results (metrics and predictions) of the run are tracked by an experiment tracking tool for further analysis.
You can specify the experiment tracking tool by setting the params `experiment_tracking_tool` and `experiment_tracking_uri`
or by passing a (custom) tracking head to Tracker.set_tracking_head() (see the sketch after the example below).
Note that `experiment_tracking_tool` currently supports only `mlflow`.
For easier comparison you can pass additional metadata regarding corpus (corpus_meta), evaluation set (evaluation_set_meta) and pipelines (pipeline_meta).
E.g. you can give them names or ids to identify them across experiment runs.
This method executes an experiment run. Each experiment run is part of at least one experiment.
An experiment typically consists of multiple runs to be compared (e.g. using different retrievers in the query pipeline).
Experiment tracking tools usually share the same concepts of experiments and provide additional functionality to easily compare runs across experiments.
E.g. you can call execute_eval_run() multiple times with different retrievers in your query pipeline and compare the runs in mlflow:
```python
| for retriever_type, query_pipeline in zip(["sparse", "dpr", "embedding"], [sparse_pipe, dpr_pipe, embedding_pipe]):
| eval_result = Pipeline.execute_eval_run(
| index_pipeline=index_pipeline,
| query_pipeline=query_pipeline,
| evaluation_set_labels=labels,
| corpus_file_paths=file_paths,
| corpus_file_metas=file_metas,
| experiment_tracking_tool="mlflow",
| experiment_tracking_uri="http://localhost:5000",
| experiment_name="my-retriever-experiment",
| experiment_run_name=f"run_{retriever_type}",
| pipeline_meta={"name": f"my-pipeline-{retriever_type}"},
| evaluation_set_meta={"name": "my-evalset"},
|        corpus_meta={"name": "my-corpus"},
| reuse_index=False
| )
```
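Alternatively, instead of setting `experiment_tracking_tool` and `experiment_tracking_uri`, you can register a tracking head yourself before starting the run. A minimal sketch, assuming a local mlflow server at http://localhost:5000 (the same URI used in the example above):
```python
from haystack.utils.experiment_tracking import MLflowTrackingHead, Tracker

# Register the tracking head once; subsequent execute_eval_run() calls will use it.
Tracker.set_tracking_head(MLflowTrackingHead(tracking_uri="http://localhost:5000"))
```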
**Arguments**:
- `index_pipeline`: The indexing pipeline to use.
- `query_pipeline`: The query pipeline to evaluate.
- `evaluation_set_labels`: The labels to evaluate on, forming an evaluation set.
- `corpus_file_paths`: The files to be indexed and searched during evaluation, forming a corpus.
- `experiment_name`: The name of the experiment.
- `experiment_run_name`: The name of the experiment run.
- `experiment_tracking_tool`: The experiment tracking tool to be used. Currently we only support "mlflow".
If left unset, the current TrackingHead specified by Tracker.set_tracking_head() will be used.
- `experiment_tracking_uri`: The uri of the experiment tracking server to be used. Must be specified if experiment_tracking_tool is set.
You can use deepset's public mlflow server via https://public-mlflow.deepset.ai/.
Note that artifact logging (e.g. Pipeline YAML or evaluation result CSVs) is currently not allowed on deepset's public mlflow server, as this might expose sensitive data.
- `corpus_file_metas`: The optional metadata to be stored for each corpus file (e.g. title).
- `corpus_meta`: Metadata about the corpus to track (e.g. name, date, author, version).
- `evaluation_set_meta`: Metadata about the evalset to track (e.g. name, date, author, version).
- `pipeline_meta`: Metadata about the pipelines to track (e.g. name, author, version).
- `index_params`: The params to use during indexing (see pipeline.run's params).
- `query_params`: The params to use during querying (see pipeline.run's params).
- `sas_model_name_or_path`: Name or path of "Semantic Answer Similarity (SAS) model". When set, the model will be used to calculate similarity between predictions and labels and generate the SAS metric.
The SAS metric correlates better with human judgement of correct answers as it does not rely on string overlaps.
Example: Prediction = "30%", Label = "thirty percent", EM and F1 would be overly pessimistic with both being 0, while SAS paints a more realistic picture.
More info in the paper: https://arxiv.org/abs/2108.06130
Models:
- You can use Bi Encoders (sentence transformers) or cross encoders trained on Semantic Textual Similarity (STS) data.
Not all cross encoders can be used because of different return types.
If you use custom cross encoders, please make sure they work with the sentence_transformers.CrossEncoder class.
- Good default for multiple languages: "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
- Large, powerful, but slow model for English only: "cross-encoder/stsb-roberta-large"
- Large model for German only: "deepset/gbert-large-sts"
- `sas_batch_size`: Number of prediction label pairs to encode at once by CrossEncoder or SentenceTransformer while calculating SAS.
- `sas_use_gpu`: Whether to use a GPU or the CPU for calculating semantic answer similarity.
Falls back to CPU if no GPU is available.
- `add_isolated_node_eval`: If set to True, in addition to the integrated evaluation of the pipeline, each node is evaluated in isolated evaluation mode.
This mode helps to understand the bottlenecks of a pipeline in terms of output quality of each individual node.
If a node performs much better in the isolated evaluation than in the integrated evaluation, the previous node needs to be optimized to improve the pipeline's performance.
If a node's performance is similar in both modes, this node itself needs to be optimized to improve the pipeline's performance.
The isolated evaluation calculates the upper bound of each node's evaluation metrics under the assumption that it received perfect inputs from the previous node.
To this end, labels are used as input to the node instead of the output of the previous node in the pipeline.
The generated dataframes in the EvaluationResult then contain additional rows, which can be distinguished from the integrated evaluation results based on the
values "integrated" or "isolated" in the column "eval_mode". The evaluation report then additionally lists the upper bound of each node's evaluation metrics. A short sketch of reading both modes from the returned EvaluationResult follows this parameter list.
- `reuse_index`: Whether to reuse an existing non-empty index and keep it after evaluation.
If True, the index is kept after evaluation and no indexing takes place if the index already contains documents. Otherwise, the index is deleted immediately afterwards.
Defaults to False.
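As a usage note, the metrics tracked during the run can also be recomputed locally from the returned EvaluationResult. A short sketch, assuming `eval_result` was returned by a call with `add_isolated_node_eval=True` (mirroring what execute_eval_run() logs internally):
```python
# eval_result: EvaluationResult returned by Pipeline.execute_eval_run(..., add_isolated_node_eval=True)
integrated_metrics = eval_result.calculate_metrics()                    # eval_mode defaults to "integrated"
isolated_metrics = eval_result.calculate_metrics(eval_mode="isolated")  # per-node upper bound
print(integrated_metrics)
print(isolated_metrics)
```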
<a id="base.Pipeline.eval"></a>
#### eval

View File

@ -22,7 +22,7 @@ logging.getLogger("haystack").setLevel(logging.INFO)
import pandas as pd
from haystack.schema import Document, Answer, Label, MultiLabel, Span
from haystack.schema import Document, Answer, Label, MultiLabel, Span, EvaluationResult
from haystack.nodes.base import BaseComponent
from haystack.pipelines.base import Pipeline
@ -104,7 +104,6 @@ except ImportError:
pass
from haystack.modeling.evaluation import eval
from haystack.modeling.logger import MLFlowLogger, StdoutLogger, TensorBoardLogger
from haystack.nodes.other import JoinDocuments, Docs2Answers, JoinAnswers, RouteDocuments
from haystack.nodes.query_classifier import SklearnQueryClassifier, TransformersQueryClassifier
from haystack.nodes.file_classifier import FileTypeClassifier
@ -178,9 +177,6 @@ if graph_retriever:
# Adding them to sys.modules would enable `import haystack.pipelines.JoinDocuments`,
# which I believe is a very rare import style.
setattr(file_converter, "FileTypeClassifier", FileTypeClassifier)
setattr(modeling_utils, "MLFlowLogger", MLFlowLogger)
setattr(modeling_utils, "StdoutLogger", StdoutLogger)
setattr(modeling_utils, "TensorBoardLogger", TensorBoardLogger)
setattr(pipelines, "JoinDocuments", JoinDocuments)
setattr(pipelines, "Docs2Answers", Docs2Answers)
setattr(pipelines, "SklearnQueryClassifier", SklearnQueryClassifier)

haystack/environment.py (new file, 62 lines)
View File

@ -0,0 +1,62 @@
import os
import platform
import sys
from typing import Any, Dict
import torch
import transformers
from haystack import __version__
HAYSTACK_EXECUTION_CONTEXT = "HAYSTACK_EXECUTION_CONTEXT"
HAYSTACK_DOCKER_CONTAINER = "HAYSTACK_DOCKER_CONTAINER"
env_meta_data: Dict[str, Any] = {}
def get_or_create_env_meta_data() -> Dict[str, Any]:
"""
Collects meta data about the setup that is used with Haystack, such as: operating system, python version, Haystack version, transformers version, pytorch version, number of GPUs, execution environment, and the value stored in the env variable HAYSTACK_EXECUTION_CONTEXT.
"""
global env_meta_data # pylint: disable=global-statement
if not env_meta_data:
env_meta_data = {
"os_version": platform.release(),
"os_family": platform.system(),
"os_machine": platform.machine(),
"python_version": platform.python_version(),
"haystack_version": __version__,
"transformers_version": transformers.__version__,
"torch_version": torch.__version__,
"torch_cuda_version": torch.version.cuda if torch.cuda.is_available() else 0,
"n_gpu": torch.cuda.device_count() if torch.cuda.is_available() else 0,
"n_cpu": os.cpu_count(),
"context": os.environ.get(HAYSTACK_EXECUTION_CONTEXT),
"execution_env": _get_execution_environment(),
}
return env_meta_data
def _get_execution_environment():
"""
Identifies the execution environment that Haystack is running in.
Options are: colab notebook, kubernetes, CPU/GPU docker container, test environment, jupyter notebook, python script
"""
if os.environ.get("CI", "False").lower() == "true":
execution_env = "ci"
elif "google.colab" in sys.modules:
execution_env = "colab"
elif "KUBERNETES_SERVICE_HOST" in os.environ:
execution_env = "kubernetes"
elif HAYSTACK_DOCKER_CONTAINER in os.environ:
execution_env = os.environ.get(HAYSTACK_DOCKER_CONTAINER)
# check if pytest is imported
elif "pytest" in sys.modules:
execution_env = "test"
else:
try:
execution_env = get_ipython().__class__.__name__ # pylint: disable=undefined-variable
except NameError:
execution_env = "script"
return execution_env
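
For reference, a minimal usage sketch of the new helper; the keys match the dictionary built in get_or_create_env_meta_data() above:
```python
from haystack.environment import get_or_create_env_meta_data

env = get_or_create_env_meta_data()
print(env["haystack_version"], env["execution_env"], env["n_gpu"])
```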

View File

@ -1411,10 +1411,10 @@
"title": "Use Auth Token",
"anyOf": [
{
"type": "boolean"
"type": "string"
},
{
"type": "string"
"type": "boolean"
}
]
}
@ -1682,10 +1682,10 @@
"title": "Use Auth Token",
"anyOf": [
{
"type": "boolean"
"type": "string"
},
{
"type": "string"
"type": "boolean"
}
]
}
@ -1949,10 +1949,10 @@
"title": "Use Auth Token",
"anyOf": [
{
"type": "boolean"
"type": "string"
},
{
"type": "string"
"type": "boolean"
}
]
}
@ -3124,10 +3124,10 @@
"title": "Use Auth Token",
"anyOf": [
{
"type": "boolean"
"type": "string"
},
{
"type": "string"
"type": "boolean"
}
]
}

File diff suppressed because it is too large

View File

@ -18,7 +18,7 @@ from torch.utils.data.sampler import RandomSampler, SequentialSampler
from haystack.modeling.data_handler.dataloader import NamedDataLoader
from haystack.modeling.data_handler.processor import Processor
from haystack.modeling.logger import MLFlowLogger as MlLogger
from haystack.utils.experiment_tracking import Tracker as tracker
from haystack.modeling.utils import log_ascii_workers, grouper, calc_chunksize
from haystack.modeling.visual import TRACTOR_SMALL
@ -497,7 +497,7 @@ class DataSilo:
logger.info("Average passage length after clipping: {}".format(ave_len[1]))
logger.info("Proportion passages clipped: {}".format(clipped[1]))
MlLogger.log_params(
tracker.track_params(
{
"n_samples_train": self.counts["train"],
"n_samples_dev": self.counts["dev"],

View File

@ -31,7 +31,7 @@ from haystack.modeling.data_handler.samples import (
offset_to_token_idx_vecorized,
)
from haystack.modeling.data_handler.input_features import sample_to_features_text
from haystack.modeling.logger import MLFlowLogger as MlLogger
from haystack.utils.experiment_tracking import Tracker as tracker
DOWNSTREAM_TASK_MAP = {
@ -359,7 +359,7 @@ class Processor(ABC):
for name in names:
value = getattr(self, name)
params.update({name: str(value)})
MlLogger.log_params(params)
tracker.track_params(params)
class SquadProcessor(Processor):

View File

@ -8,7 +8,7 @@ from tqdm import tqdm
from haystack.modeling.evaluation.metrics import compute_metrics, compute_report_metrics
from haystack.modeling.model.adaptive_model import AdaptiveModel
from haystack.modeling.logger import MLFlowLogger as MlLogger
from haystack.utils.experiment_tracking import Tracker as tracker
from haystack.modeling.visual import BUSH_SEP
@ -157,11 +157,11 @@ class Evaluator:
for head_num, head in enumerate(results):
logger.info("\n _________ {} _________".format(head["task_name"]))
for metric_name, metric_val in head.items():
# log with ML framework (e.g. Mlflow)
# log with experiment tracking framework (e.g. Mlflow)
if logging:
if not metric_name in ["preds", "labels"] and not metric_name.startswith("_"):
if isinstance(metric_val, numbers.Number):
MlLogger.log_metrics(
tracker.track_metrics(
metrics={f"{dataset_name}_{metric_name}_{head['task_name']}": metric_val}, step=steps
)
# print via standard python logger

View File

@ -21,7 +21,6 @@ from haystack.modeling.utils import (
)
from haystack.modeling.data_handler.inputs import QAInput
from haystack.modeling.model.adaptive_model import AdaptiveModel, BaseAdaptiveModel
from haystack.modeling.logger import MLFlowLogger
from haystack.modeling.model.predictions import QAPred
@ -74,8 +73,6 @@ class Inferencer:
:return: An instance of the Inferencer.
"""
MLFlowLogger.disable()
# Init device and distributed settings
self.devices, n_gpu = initialize_device_settings(use_cuda=gpu, multi_gpu=False)

View File

@ -1,145 +0,0 @@
import logging
import mlflow
from requests.exceptions import ConnectionError
logger = logging.getLogger(__name__)
class BaseMLLogger:
"""
Base class for tracking experiments.
This class can be extended to implement custom logging backends like MLFlow, Tensorboard, or Sacred.
"""
disable_logging = False
def __init__(self, tracking_uri, **kwargs):
self.tracking_uri = tracking_uri
def init_experiment(self, tracking_uri):
raise NotImplementedError()
@classmethod
def log_metrics(cls, metrics, step):
raise NotImplementedError()
@classmethod
def log_artifacts(cls, self):
raise NotImplementedError()
@classmethod
def log_params(cls, params):
raise NotImplementedError()
class StdoutLogger(BaseMLLogger):
"""Minimal logger printing metrics and params to stdout.
Useful for services like AWS SageMaker, where you parse metrics from the actual logs"""
def init_experiment(self, experiment_name, run_name=None, nested=True):
logger.info(f"\n **** Starting experiment '{experiment_name}' (Run: {run_name}) ****")
@classmethod
def log_metrics(cls, metrics, step):
logger.info(f"Logged metrics at step {step}: \n {metrics}")
@classmethod
def log_params(cls, params):
logger.info(f"Logged parameters: \n {params}")
@classmethod
def log_artifacts(cls, dir_path, artifact_path=None):
raise NotImplementedError
@classmethod
def end_run(cls):
logger.info(f"**** End of Experiment **** ")
class MLFlowLogger(BaseMLLogger):
"""
Logger for MLFlow experiment tracking.
"""
def init_experiment(self, experiment_name, run_name=None, nested=True):
if not self.disable_logging:
try:
mlflow.set_tracking_uri(self.tracking_uri)
mlflow.set_experiment(experiment_name)
mlflow.start_run(run_name=run_name, nested=nested)
except ConnectionError:
raise Exception(
f"MLFlow cannot connect to the remote server at {self.tracking_uri}.\n"
f"MLFlow also supports logging runs locally to files. Set the MLFlowLogger "
f"tracking_uri to an empty string to use that."
)
@classmethod
def log_metrics(cls, metrics, step):
if not cls.disable_logging:
try:
mlflow.log_metrics(metrics, step=step)
except ConnectionError:
logger.warning(f"ConnectionError in logging metrics to MLFlow.")
except Exception as e:
logger.warning(f"Failed to log metrics: {e}")
@classmethod
def log_params(cls, params):
if not cls.disable_logging:
try:
mlflow.log_params(params)
except ConnectionError:
logger.warning("ConnectionError in logging params to MLFlow")
except Exception as e:
logger.warning(f"Failed to log params: {e}")
@classmethod
def log_artifacts(cls, dir_path, artifact_path=None):
if not cls.disable_logging:
try:
mlflow.log_artifacts(dir_path, artifact_path)
except ConnectionError:
logger.warning(f"ConnectionError in logging artifacts to MLFlow")
except Exception as e:
logger.warning(f"Failed to log artifacts: {e}")
@classmethod
def end_run(cls):
if not cls.disable_logging:
mlflow.end_run()
@classmethod
def disable(cls):
logger.info("ML Logging is turned off. No parameters, metrics or artifacts will be logged to MLFlow.")
cls.disable_logging = True
class TensorBoardLogger(BaseMLLogger):
"""
PyTorch TensorBoard Logger
"""
def __init__(self, **kwargs):
try:
from tensorboardX import SummaryWriter # pylint: disable=import-error
except (ImportError, ModuleNotFoundError):
logger.info(
"tensorboardX not found, can't initialize TensorBoardLogger. "
"Enable it with 'pip install tensorboardX'."
)
TensorBoardLogger.summary_writer = SummaryWriter()
super().__init__(**kwargs)
@classmethod
def log_metrics(cls, metrics, step):
for key, value in metrics.items():
TensorBoardLogger.summary_writer.add_scalar(tag=key, scalar_value=value, global_step=step)
@classmethod
def log_params(cls, params):
for key, value in params.items():
TensorBoardLogger.summary_writer.add_text(tag=key, text_string=str(value))

View File

@ -15,7 +15,7 @@ from transformers.convert_graph_to_onnx import convert, quantize as quantize_mod
from haystack.modeling.data_handler.processor import Processor
from haystack.modeling.model.language_model import LanguageModel
from haystack.modeling.model.prediction_head import PredictionHead, QuestionAnsweringHead
from haystack.modeling.logger import MLFlowLogger as MlLogger
from haystack.utils.experiment_tracking import Tracker as tracker
logger = logging.getLogger(__name__)
@ -556,7 +556,7 @@ class AdaptiveModel(nn.Module, BaseAdaptiveModel):
"lm_output_types": ",".join(self.lm_output_types),
}
try:
MlLogger.log_params(params)
tracker.track_params(params)
except Exception as e:
logger.warning(f"ML logging didn't work: {e}")

View File

@ -10,7 +10,7 @@ from torch import nn
from haystack.modeling.data_handler.processor import Processor
from haystack.modeling.model.language_model import LanguageModel
from haystack.modeling.model.prediction_head import PredictionHead, TextSimilarityHead
from haystack.modeling.logger import MLFlowLogger as MlLogger
from haystack.utils.experiment_tracking import Tracker as tracker
logger = logging.getLogger(__name__)
@ -335,7 +335,7 @@ class BiAdaptiveModel(nn.Module):
"prediction_heads": ",".join([head.__class__.__name__ for head in self.prediction_heads]),
}
try:
MlLogger.log_params(params)
tracker.track_params(params)
except Exception as e:
logger.warning(f"ML logging didn't work: {e}")

View File

@ -10,7 +10,7 @@ from torch.nn import DataParallel
from torch.nn.parallel import DistributedDataParallel
from haystack.modeling.model.adaptive_model import AdaptiveModel
from haystack.modeling.logger import MLFlowLogger as MlLogger
from haystack.utils.experiment_tracking import Tracker as tracker
logger = logging.getLogger(__name__)
@ -161,7 +161,7 @@ def initialize_optimizer(
schedule_opts["num_training_steps"] = num_train_optimization_steps
# Log params
MlLogger.log_params({"use_amp": use_amp, "num_train_optimization_steps": schedule_opts["num_training_steps"]})
tracker.track_params({"use_amp": use_amp, "num_train_optimization_steps": schedule_opts["num_training_steps"]})
# Get optimizer from pytorch, transformers or apex
optimizer = _get_optim(model, optimizer_opts)
@ -189,8 +189,8 @@ def _get_optim(model, opts: Dict):
# Logging
logger.info(f"Loading optimizer `{optimizer_name}`: '{opts}'")
MlLogger.log_params(opts)
MlLogger.log_params({"optimizer_name": optimizer_name})
tracker.track_params(opts)
tracker.track_params({"optimizer_name": optimizer_name})
weight_decay = opts.pop("weight_decay", None)
no_decay = opts.pop("no_decay", None)
@ -279,15 +279,15 @@ def get_scheduler(optimizer, opts):
# convert from warmup proportion to steps if required
if "num_warmup_steps" in allowed_args and "num_warmup_steps" not in opts and "warmup_proportion" in opts:
opts["num_warmup_steps"] = int(opts["warmup_proportion"] * opts["num_training_steps"])
MlLogger.log_params({"warmup_proportion": opts["warmup_proportion"]})
tracker.track_params({"warmup_proportion": opts["warmup_proportion"]})
# only pass args that are supported by the constructor
constructor_opts = {k: v for k, v in opts.items() if k in allowed_args}
# Logging
logger.info(f"Loading schedule `{schedule_name}`: '{constructor_opts}'")
MlLogger.log_params(constructor_opts)
MlLogger.log_params({"schedule_name": schedule_name})
tracker.track_params(constructor_opts)
tracker.track_params({"schedule_name": schedule_name})
scheduler = sched_constructor(optimizer, **constructor_opts)
scheduler.opts = opts # save the opts with the scheduler to use in load/save

View File

@ -9,7 +9,7 @@ from torch import nn
from haystack.modeling.data_handler.processor import Processor
from haystack.modeling.model.language_model import LanguageModel
from haystack.modeling.model.prediction_head import PredictionHead
from haystack.modeling.logger import MLFlowLogger as MlLogger
from haystack.utils.experiment_tracking import Tracker as tracker
logger = logging.getLogger(__name__)
@ -369,7 +369,7 @@ class TriAdaptiveModel(nn.Module):
"prediction_heads": ",".join([head.__class__.__name__ for head in self.prediction_heads]),
}
try:
MlLogger.log_params(params)
tracker.track_params(params)
except Exception as e:
logger.warning(f"ML logging didn't work: {e}")

View File

@ -19,7 +19,7 @@ from haystack.modeling.evaluation.eval import Evaluator
from haystack.modeling.model.adaptive_model import AdaptiveModel
from haystack.modeling.model.optimization import get_scheduler
from haystack.modeling.utils import GracefulKiller
from haystack.modeling.logger import MLFlowLogger as MlLogger
from haystack.utils.experiment_tracking import Tracker as tracker
try:
from apex import amp
@ -161,7 +161,7 @@ class Trainer:
Useful to achieve larger effective batch sizes that would not fit in GPU memory.
:param local_rank: Local rank of process when distributed training via DDP is used.
:param early_stopping: an initialized EarlyStopping object to control early stopping and saving of best models.
:param log_learning_rate: Whether to log learning rate to Mlflow
:param log_learning_rate: Whether to log learning rate to experiment tracker (e.g. Mlflow)
:param log_loss_every: Log current train loss after this many train steps.
:param checkpoint_on_sigterm: save a checkpoint for the Trainer when a SIGTERM signal is sent. The checkpoint
can be used to resume training. It is useful in frameworks like AWS SageMaker with Spot instances where
@ -377,9 +377,9 @@ class Trainer:
loss = self.adjust_loss(loss)
if self.global_step % self.log_loss_every == 0 and self.local_rank in [-1, 0]:
if self.local_rank in [-1, 0]:
MlLogger.log_metrics({"Train_loss_total": float(loss.detach().cpu().numpy())}, step=self.global_step)
tracker.track_metrics({"Train_loss_total": float(loss.detach().cpu().numpy())}, step=self.global_step)
if self.log_learning_rate:
MlLogger.log_metrics({"learning_rate": self.lr_schedule.get_last_lr()[0]}, step=self.global_step)
tracker.track_metrics({"learning_rate": self.lr_schedule.get_last_lr()[0]}, step=self.global_step)
if self.use_amp:
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
@ -406,7 +406,7 @@ class Trainer:
def log_params(self):
params = {"epochs": self.epochs, "n_gpu": self.n_gpu, "device": self.device}
MlLogger.log_params(params)
tracker.track_params(params)
@classmethod
def create_or_load_checkpoint(
@ -700,7 +700,7 @@ class DistillationTrainer(Trainer):
Useful to achieve larger effective batch sizes that would not fit in GPU memory.
:param local_rank: Local rank of process when distributed training via DDP is used.
:param early_stopping: an initialized EarlyStopping object to control early stopping and saving of best models.
:param log_learning_rate: Whether to log learning rate to Mlflow
:param log_learning_rate: Whether to log learning rate to experiment tracker (e.g. Mlflow)
:param log_loss_every: Log current train loss after this many train steps.
:param checkpoint_on_sigterm: save a checkpoint for the Trainer when a SIGTERM signal is sent. The checkpoint
can be used to resume training. It is useful in frameworks like AWS SageMaker with Spot instances where
@ -842,7 +842,7 @@ class TinyBERTDistillationTrainer(Trainer):
Useful to achieve larger effective batch sizes that would not fit in GPU memory.
:param local_rank: Local rank of process when distributed training via DDP is used.
:param early_stopping: an initialized EarlyStopping object to control early stopping and saving of best models.
:param log_learning_rate: Whether to log learning rate to Mlflow
:param log_learning_rate: Whether to log learning rate to experiment tracker (e.g. Mlflow)
:param log_loss_every: Log current train loss after this many train steps.
:param checkpoint_on_sigterm: save a checkpoint for the Trainer when a SIGTERM signal is sent. The checkpoint
can be used to resume training. It is useful in frameworks like AWS SageMaker with Spot instances where

View File

@ -455,7 +455,7 @@ class DensePassageRetriever(BaseRetriever):
use_amp=use_amp,
)
# 7. Let it grow! Watch the tracked metrics live on the public mlflow server: https://public-mlflow.deepset.ai
# 7. Let it grow! Watch the tracked metrics live on experiment tracker (e.g. Mlflow)
trainer.train()
self.model.save(Path(save_dir), lm1_name=query_encoder_save_dir, lm2_name=passage_encoder_save_dir)
@ -985,7 +985,7 @@ class TableTextRetriever(BaseRetriever):
use_amp=use_amp,
)
# 7. Let it grow! Watch the tracked metrics live on the public mlflow server: https://public-mlflow.deepset.ai
# 7. Let it grow! Watch the tracked metrics live on experiment tracker (e.g. Mlflow)
trainer.train()
self.model.save(

View File

@ -1,6 +1,11 @@
from __future__ import annotations
from typing import Dict, List, Optional, Any, Set, Tuple, Union
try:
from typing import Literal
except ImportError:
from typing_extensions import Literal # type: ignore
import copy
import json
import inspect
@ -46,6 +51,7 @@ from haystack.nodes.base import BaseComponent
from haystack.nodes.retriever.base import BaseRetriever
from haystack.document_stores.base import BaseDocumentStore
from haystack.telemetry import send_event
from haystack.utils.experiment_tracking import MLflowTrackingHead, Tracker as tracker
logger = logging.getLogger(__name__)
@ -53,6 +59,7 @@ logger = logging.getLogger(__name__)
ROOT_NODE_TO_PIPELINE_NAME = {"query": "query", "file": "indexing"}
CODE_GEN_DEFAULT_COMMENT = "This code has been generated."
TRACKING_TOOL_TO_HEAD = {"mlflow": MLflowTrackingHead}
class RootNode(BaseComponent):
@ -770,6 +777,201 @@ class Pipeline(BasePipeline):
ndcg, map_, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
return ndcg, map_, recall, precision
@classmethod
def execute_eval_run(
cls,
index_pipeline: Pipeline,
query_pipeline: Pipeline,
evaluation_set_labels: List[MultiLabel],
corpus_file_paths: List[str],
experiment_name: str,
experiment_run_name: str,
experiment_tracking_tool: Literal["mlflow", None] = None,
experiment_tracking_uri: Optional[str] = None,
corpus_file_metas: List[Dict[str, Any]] = None,
corpus_meta: Dict[str, Any] = {},
evaluation_set_meta: Dict[str, Any] = {},
pipeline_meta: Dict[str, Any] = {},
index_params: dict = {},
query_params: dict = {},
sas_model_name_or_path: str = None,
sas_batch_size: int = 32,
sas_use_gpu: bool = True,
add_isolated_node_eval: bool = False,
reuse_index: bool = False,
) -> EvaluationResult:
"""
Starts an experiment run that first indexes the specified files (forming a corpus) using the index pipeline
and subsequently evaluates the query pipeline on the provided labels (forming an evaluation set) using pipeline.eval().
Parameters and results (metrics and predictions) of the run are tracked by an experiment tracking tool for further analysis.
You can specify the experiment tracking tool by setting the params `experiment_tracking_tool` and `experiment_tracking_uri`
or by passing a (custom) tracking head to Tracker.set_tracking_head().
Note that `experiment_tracking_tool` currently supports only `mlflow`.
For easier comparison you can pass additional metadata regarding corpus (corpus_meta), evaluation set (evaluation_set_meta) and pipelines (pipeline_meta).
E.g. you can give them names or ids to identify them across experiment runs.
This method executes an experiment run. Each experiment run is part of at least one experiment.
An experiment typically consists of multiple runs to be compared (e.g. using different retrievers in the query pipeline).
Experiment tracking tools usually share the same concepts of experiments and provide additional functionality to easily compare runs across experiments.
E.g. you can call execute_eval_run() multiple times with different retrievers in your query pipeline and compare the runs in mlflow:
```python
| for retriever_type, query_pipeline in zip(["sparse", "dpr", "embedding"], [sparse_pipe, dpr_pipe, embedding_pipe]):
| eval_result = Pipeline.execute_eval_run(
| index_pipeline=index_pipeline,
| query_pipeline=query_pipeline,
| evaluation_set_labels=labels,
| corpus_file_paths=file_paths,
| corpus_file_metas=file_metas,
| experiment_tracking_tool="mlflow",
| experiment_tracking_uri="http://localhost:5000",
| experiment_name="my-retriever-experiment",
| experiment_run_name=f"run_{retriever_type}",
| pipeline_meta={"name": f"my-pipeline-{retriever_type}"},
| evaluation_set_meta={"name": "my-evalset"},
|        corpus_meta={"name": "my-corpus"},
| reuse_index=False
| )
```
:param index_pipeline: The indexing pipeline to use.
:param query_pipeline: The query pipeline to evaluate.
:param evaluation_set_labels: The labels to evaluate on, forming an evaluation set.
:param corpus_file_paths: The files to be indexed and searched during evaluation, forming a corpus.
:param experiment_name: The name of the experiment.
:param experiment_run_name: The name of the experiment run.
:param experiment_tracking_tool: The experiment tracking tool to be used. Currently we only support "mlflow".
If left unset, the current TrackingHead specified by Tracker.set_tracking_head() will be used.
:param experiment_tracking_uri: The uri of the experiment tracking server to be used. Must be specified if experiment_tracking_tool is set.
You can use deepset's public mlflow server via https://public-mlflow.deepset.ai/.
Note that artifact logging (e.g. Pipeline YAML or evaluation result CSVs) is currently not allowed on deepset's public mlflow server, as this might expose sensitive data.
:param corpus_file_metas: The optional metadata to be stored for each corpus file (e.g. title).
:param corpus_meta: Metadata about the corpus to track (e.g. name, date, author, version).
:param evaluation_set_meta: Metadata about the evalset to track (e.g. name, date, author, version).
:param pipeline_meta: Metadata about the pipelines to track (e.g. name, author, version).
:param index_params: The params to use during indexing (see pipeline.run's params).
:param query_params: The params to use during querying (see pipeline.run's params).
:param sas_model_name_or_path: Name or path of "Semantic Answer Similarity (SAS) model". When set, the model will be used to calculate similarity between predictions and labels and generate the SAS metric.
The SAS metric correlates better with human judgement of correct answers as it does not rely on string overlaps.
Example: Prediction = "30%", Label = "thirty percent", EM and F1 would be overly pessimistic with both being 0, while SAS paints a more realistic picture.
More info in the paper: https://arxiv.org/abs/2108.06130
Models:
- You can use Bi Encoders (sentence transformers) or cross encoders trained on Semantic Textual Similarity (STS) data.
Not all cross encoders can be used because of different return types.
If you use custom cross encoders, please make sure they work with the sentence_transformers.CrossEncoder class.
- Good default for multiple languages: "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
- Large, powerful, but slow model for English only: "cross-encoder/stsb-roberta-large"
- Large model for German only: "deepset/gbert-large-sts"
:param sas_batch_size: Number of prediction label pairs to encode at once by CrossEncoder or SentenceTransformer while calculating SAS.
:param sas_use_gpu: Whether to use a GPU or the CPU for calculating semantic answer similarity.
Falls back to CPU if no GPU is available.
:param add_isolated_node_eval: If set to True, in addition to the integrated evaluation of the pipeline, each node is evaluated in isolated evaluation mode.
This mode helps to understand the bottlenecks of a pipeline in terms of output quality of each individual node.
If a node performs much better in the isolated evaluation than in the integrated evaluation, the previous node needs to be optimized to improve the pipeline's performance.
If a node's performance is similar in both modes, this node itself needs to be optimized to improve the pipeline's performance.
The isolated evaluation calculates the upper bound of each node's evaluation metrics under the assumption that it received perfect inputs from the previous node.
To this end, labels are used as input to the node instead of the output of the previous node in the pipeline.
The generated dataframes in the EvaluationResult then contain additional rows, which can be distinguished from the integrated evaluation results based on the
values "integrated" or "isolated" in the column "eval_mode". The evaluation report then additionally lists the upper bound of each node's evaluation metrics.
:param reuse_index: Whether to reuse an existing non-empty index and keep it after evaluation.
If True, the index is kept after evaluation and no indexing takes place if the index already contains documents. Otherwise, the index is deleted immediately afterwards.
Defaults to False.
"""
if experiment_tracking_tool is not None:
tracking_head_cls = TRACKING_TOOL_TO_HEAD.get(experiment_tracking_tool, None)
if tracking_head_cls is None:
raise HaystackError(
f"Please specify a valid experiment_tracking_tool. Possible values are: {TRACKING_TOOL_TO_HEAD.keys()}"
)
if experiment_tracking_uri is None:
raise HaystackError(f"experiment_tracking_uri must be specified if experiment_tracking_tool is set.")
tracking_head = tracking_head_cls(tracking_uri=experiment_tracking_uri)
tracker.set_tracking_head(tracking_head)
try:
tracker.init_experiment(
experiment_name=experiment_name, run_name=experiment_run_name, tags={experiment_name: "True"}
)
tracker.track_params(
{
"dataset_label_count": len(evaluation_set_labels),
"dataset": evaluation_set_meta,
"sas_model_name_or_path": sas_model_name_or_path,
"sas_batch_size": sas_batch_size,
"sas_use_gpu": sas_use_gpu,
"pipeline_index_params": index_params,
"pipeline_query_params": query_params,
"pipeline": pipeline_meta,
"corpus_file_count": len(corpus_file_paths),
"corpus": corpus_meta,
"type": "offline/evaluation",
}
)
# check index before eval
document_store = index_pipeline.get_document_store()
if document_store is None:
raise HaystackError(f"Document store not found. Please provide pipelines with proper document store.")
document_count = document_store.get_document_count()
if document_count > 0:
if not reuse_index:
raise HaystackError(f"Index '{document_store.index}' is not empty. Please provide an empty index.")
else:
logger.info(f"indexing {len(corpus_file_paths)} documents...")
index_pipeline.run(file_paths=corpus_file_paths, meta=corpus_file_metas, params=index_params)
document_count = document_store.get_document_count()
logger.info(f"indexing {len(evaluation_set_labels)} files to {document_count} documents finished.")
tracker.track_params({"pipeline_index_document_count": document_count})
eval_result = query_pipeline.eval(
labels=evaluation_set_labels,
params=query_params,
sas_model_name_or_path=sas_model_name_or_path,
sas_batch_size=sas_batch_size,
sas_use_gpu=sas_use_gpu,
add_isolated_node_eval=add_isolated_node_eval,
)
integrated_metrics = eval_result.calculate_metrics()
integrated_top_1_metrics = eval_result.calculate_metrics(simulated_top_k_reader=1)
metrics = {"integrated": integrated_metrics, "integrated_top_1": integrated_top_1_metrics}
if add_isolated_node_eval:
isolated_metrics = eval_result.calculate_metrics(eval_mode="isolated")
isolated_top_1_metrics = eval_result.calculate_metrics(eval_mode="isolated", simulated_top_k_reader=1)
metrics["isolated"] = isolated_metrics
metrics["isolated_top_1"] = isolated_top_1_metrics
tracker.track_metrics(metrics, step=0)
with tempfile.TemporaryDirectory() as temp_dir:
eval_result_dir = Path(temp_dir) / "eval_result"
eval_result_dir.mkdir(exist_ok=True)
eval_result.save(out_dir=eval_result_dir)
tracker.track_artifacts(eval_result_dir, artifact_path="eval_result")
with open(Path(temp_dir) / "pipelines.yaml", "w") as outfile:
index_config = index_pipeline.get_config()
query_config = query_pipeline.get_config()
components = list(
{c["name"]: c for c in (index_config["components"] + query_config["components"])}.values()
)
pipelines = index_config["pipelines"] + query_config["pipelines"]
config = {"version": index_config["version"], "components": components, "pipelines": pipelines}
yaml.dump(config, outfile, default_flow_style=False)
tracker.track_artifacts(temp_dir)
# Clean up document store
if not reuse_index and document_store.index is not None:
logger.info(f"Cleaning up: deleting index '{document_store.index}'...")
document_store.delete_index(document_store.index)
finally:
tracker.end_run()
return eval_result
@send_event
def eval(
self,

View File

@ -5,34 +5,26 @@
You can opt-out of sharing usage statistics by calling disable_telemetry() or by manually setting the environment variable HAYSTACK_TELEMETRY_ENABLED as described for different operating systems on the documentation page.
You can log all events to the local file specified in LOG_PATH for inspection by setting the environment variable HAYSTACK_TELEMETRY_LOGGING_TO_FILE_ENABLED to "True".
"""
from typing import List, Dict, Any, Optional
import os
import sys
from typing import Any, Dict, List, Optional
import uuid
import logging
import platform
from enum import Enum
from functools import wraps
from pathlib import Path
import yaml
import torch
import posthog
import transformers
from haystack import __version__
from haystack.environment import HAYSTACK_EXECUTION_CONTEXT, get_or_create_env_meta_data
posthog.api_key = "phc_F5v11iI2YHkoP6Er3cPILWSrLhY3D6UY4dEMga4eoaa"
posthog.host = "https://tm.hs.deepset.ai"
HAYSTACK_TELEMETRY_ENABLED = "HAYSTACK_TELEMETRY_ENABLED"
HAYSTACK_TELEMETRY_LOGGING_TO_FILE_ENABLED = "HAYSTACK_TELEMETRY_LOGGING_TO_FILE_ENABLED"
HAYSTACK_EXECUTION_CONTEXT = "HAYSTACK_EXECUTION_CONTEXT"
HAYSTACK_DOCKER_CONTAINER = "HAYSTACK_DOCKER_CONTAINER"
CONFIG_PATH = Path("~/.haystack/config.yaml").expanduser()
LOG_PATH = Path("~/.haystack/telemetry.log").expanduser()
telemetry_meta_data: Dict[str, Any] = {}
user_id: Optional[str] = None
logger = logging.getLogger(__name__)
@ -49,7 +41,7 @@ def print_telemetry_report():
"""
if is_telemetry_enabled():
user_id = _get_or_create_user_id()
meta_data = _get_or_create_telemetry_meta_data()
meta_data = get_or_create_env_meta_data()
print({**{"user_id": user_id}, **meta_data})
else:
print("Telemetry is disabled.")
@ -152,7 +144,7 @@ def send_custom_event(event: str = "", payload: Dict[str, Any] = {}):
:param payload: A dictionary containing event meta data, e.g., parameter settings
"""
event_properties = {**(NonPrivateParameters.apply_filter(payload)), **_get_or_create_telemetry_meta_data()}
event_properties = {**(NonPrivateParameters.apply_filter(payload)), **get_or_create_env_meta_data()}
if user_id is None:
raise RuntimeError("User id was not initialized")
try:
@ -224,53 +216,6 @@ def _get_or_create_user_id() -> str:
return user_id
def _get_or_create_telemetry_meta_data() -> Dict[str, Any]:
"""
Collects meta data about the setup that is used with Haystack, such as: operating system, python version, Haystack version, transformers version, pytorch version, number of GPUs, execution environment, and the value stored in the env variable HAYSTACK_EXECUTION_CONTEXT.
"""
global telemetry_meta_data # pylint: disable=global-statement
if not telemetry_meta_data:
telemetry_meta_data = {
"os_version": platform.release(),
"os_family": platform.system(),
"os_machine": platform.machine(),
"python_version": platform.python_version(),
"haystack_version": __version__,
"transformers_version": transformers.__version__,
"torch_version": torch.__version__,
"torch_cuda_version": torch.version.cuda if torch.cuda.is_available() else 0,
"n_gpu": torch.cuda.device_count() if torch.cuda.is_available() else 0,
"n_cpu": os.cpu_count(),
"context": os.environ.get(HAYSTACK_EXECUTION_CONTEXT),
"execution_env": _get_execution_environment(),
}
return telemetry_meta_data
def _get_execution_environment():
"""
Identifies the execution environment that Haystack is running in.
Options are: colab notebook, kubernetes, CPU/GPU docker container, test environment, jupyter notebook, python script
"""
if os.environ.get("CI", "False").lower() == "true":
execution_env = "ci"
elif "google.colab" in sys.modules:
execution_env = "colab"
elif "KUBERNETES_SERVICE_HOST" in os.environ:
execution_env = "kubernetes"
elif HAYSTACK_DOCKER_CONTAINER in os.environ:
execution_env = os.environ.get(HAYSTACK_DOCKER_CONTAINER)
# check if pytest is imported
elif "pytest" in sys.modules:
execution_env = "test"
else:
try:
execution_env = get_ipython().__class__.__name__ # pylint: disable=undefined-variable
except NameError:
execution_env = "script"
return execution_env
def _read_telemetry_config():
"""
Loads the config from the file specified in CONFIG_PATH

View File

@ -19,3 +19,10 @@ from haystack.utils.export_utils import (
)
from haystack.utils.squad_data import SquadData
from haystack.utils.context_matching import calculate_context_similarity, match_context, match_contexts
from haystack.utils.experiment_tracking import (
Tracker,
NoTrackingHead,
BaseTrackingHead,
MLflowTrackingHead,
StdoutTrackingHead,
)

View File

@ -0,0 +1,188 @@
from abc import ABC, abstractmethod
import logging
from pathlib import Path
from typing import Any, Dict, Union
import mlflow
from requests.exceptions import ConnectionError
from haystack.environment import get_or_create_env_meta_data
logger = logging.getLogger(__name__)
def flatten_dict(dict_to_flatten: dict, prefix: str = ""):
flat_dict = {}
for k, v in dict_to_flatten.items():
if isinstance(v, dict):
flat_dict.update(flatten_dict(v, prefix + k + "_"))
else:
flat_dict[prefix + k] = v
return flat_dict
class BaseTrackingHead(ABC):
"""
Base class for tracking experiments.
This class can be extended to implement custom logging backends like MLflow, WandB, or TensorBoard.
"""
@abstractmethod
def init_experiment(
self, experiment_name: str, run_name: str = None, tags: Dict[str, Any] = None, nested: bool = False
):
raise NotImplementedError()
@abstractmethod
def track_metrics(self, metrics: Dict[str, Any], step: int):
raise NotImplementedError()
@abstractmethod
def track_artifacts(self, dir_path: Union[str, Path], artifact_path: str = None):
raise NotImplementedError()
@abstractmethod
def track_params(self, params: Dict[str, Any]):
raise NotImplementedError()
@abstractmethod
def end_run(self):
raise NotImplementedError()
class NoTrackingHead(BaseTrackingHead):
"""
Null object implementation of a tracking head: i.e. does nothing.
"""
def init_experiment(
self, experiment_name: str, run_name: str = None, tags: Dict[str, Any] = None, nested: bool = False
):
pass
def track_metrics(self, metrics: Dict[str, Any], step: int):
pass
def track_artifacts(self, dir_path: Union[str, Path], artifact_path: str = None):
pass
def track_params(self, params: Dict[str, Any]):
pass
def end_run(self):
pass
class Tracker:
"""
Facade for tracking experiments.
"""
tracker: BaseTrackingHead = NoTrackingHead()
@classmethod
def init_experiment(
cls, experiment_name: str, run_name: str = None, tags: Dict[str, Any] = None, nested: bool = False
):
cls.tracker.init_experiment(experiment_name=experiment_name, run_name=run_name, tags=tags, nested=nested)
@classmethod
def track_metrics(cls, metrics: Dict[str, Any], step: int):
cls.tracker.track_metrics(metrics=metrics, step=step)
@classmethod
def track_artifacts(cls, dir_path: Union[str, Path], artifact_path: str = None):
cls.tracker.track_artifacts(dir_path=dir_path, artifact_path=artifact_path)
@classmethod
def track_params(cls, params: Dict[str, Any]):
cls.tracker.track_params(params=params)
@classmethod
def end_run(cls):
cls.tracker.end_run()
@classmethod
def set_tracking_head(cls, tracker: BaseTrackingHead):
cls.tracker = tracker
class StdoutTrackingHead(BaseTrackingHead):
"""
Experiment tracking head printing metrics and params to stdout.
Useful for services like AWS SageMaker, where you parse metrics from the actual logs
"""
def init_experiment(
self, experiment_name: str, run_name: str = None, tags: Dict[str, Any] = None, nested: bool = False
):
logger.info(f"\n **** Starting experiment '{experiment_name}' (Run: {run_name}) ****")
def track_metrics(self, metrics: Dict[str, Any], step: int):
logger.info(f"Logged metrics at step {step}: \n {metrics}")
def track_params(self, params: Dict[str, Any]):
logger.info(f"Logged parameters: \n {params}")
def track_artifacts(self, dir_path: Union[str, Path], artifact_path: str = None):
logger.warning(f"Cannot log artifacts with StdoutLogger: \n {dir_path}")
def end_run(self):
logger.info(f"**** End of Experiment **** ")
class MLflowTrackingHead(BaseTrackingHead):
def __init__(self, tracking_uri: str, auto_track_environment: bool = True) -> None:
"""
Experiment tracking head for MLflow.
"""
super().__init__()
self.tracking_uri = tracking_uri
self.auto_track_environment = auto_track_environment
def init_experiment(
self, experiment_name: str, run_name: str = None, tags: Dict[str, Any] = None, nested: bool = False
):
try:
mlflow.set_tracking_uri(self.tracking_uri)
mlflow.set_experiment(experiment_name)
mlflow.start_run(run_name=run_name, nested=nested, tags=tags)
logger.info(f"Tracking run {run_name} of experiment {experiment_name} by mlflow under {self.tracking_uri}")
if self.auto_track_environment:
mlflow.log_params(flatten_dict({"environment": get_or_create_env_meta_data()}))
except ConnectionError:
raise Exception(
f"MLflow cannot connect to the remote server at {self.tracking_uri}.\n"
f"MLflow also supports logging runs locally to files. Set the MLflowTrackingHead "
f"tracking_uri to an empty string to use that."
)
def track_metrics(self, metrics: Dict[str, Any], step: int):
try:
metrics = flatten_dict(metrics)
mlflow.log_metrics(metrics, step=step)
except ConnectionError:
logger.warning(f"ConnectionError in logging metrics to MLflow.")
except Exception as e:
logger.warning(f"Failed to log metrics: {e}")
def track_params(self, params: Dict[str, Any]):
try:
params = flatten_dict(params)
mlflow.log_params(params)
except ConnectionError:
logger.warning("ConnectionError in logging params to MLflow")
except Exception as e:
logger.warning(f"Failed to log params: {e}")
def track_artifacts(self, dir_path: Union[str, Path], artifact_path: str = None):
try:
mlflow.log_artifacts(dir_path, artifact_path)
except ConnectionError:
logger.warning(f"ConnectionError in logging artifacts to MLflow")
except Exception as e:
logger.warning(f"Failed to log artifacts: {e}")
def end_run(self):
mlflow.end_run()
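
Beyond the heads shipped in this file, BaseTrackingHead can be subclassed for other backends. A hedged sketch of a custom head (the class name and its in-memory storage are illustrative only, not part of this PR):
```python
from pathlib import Path
from typing import Any, Dict, Union

from haystack.utils.experiment_tracking import BaseTrackingHead, Tracker


class InMemoryTrackingHead(BaseTrackingHead):
    """Illustrative tracking head that collects params and metrics in plain dicts."""

    def __init__(self):
        super().__init__()
        self.params: Dict[str, Any] = {}
        self.metrics: Dict[str, Any] = {}

    def init_experiment(
        self, experiment_name: str, run_name: str = None, tags: Dict[str, Any] = None, nested: bool = False
    ):
        self.experiment_name = experiment_name
        self.run_name = run_name

    def track_metrics(self, metrics: Dict[str, Any], step: int):
        self.metrics.update(metrics)

    def track_params(self, params: Dict[str, Any]):
        self.params.update(params)

    def track_artifacts(self, dir_path: Union[str, Path], artifact_path: str = None):
        pass  # this illustrative head does not store artifacts

    def end_run(self):
        pass


# Route all of Haystack's tracking calls to the custom head:
Tracker.set_tracking_head(InMemoryTrackingHead())
```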

View File

@ -91,7 +91,7 @@ install_requires =
# Metrics and logging
seqeval
mlflow<=1.13.1
mlflow
# Elasticsearch
elasticsearch>=7.7,<=7.10