feat: globally disable progress bars (#5207)

* add SilenceableTqdm and update usage

* pylint

* rename module

* add tests
This commit is contained in:
ZanSara 2023-06-27 11:45:17 +02:00 committed by GitHub
parent 5ee393226d
commit 462f3a5c99
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
41 changed files with 165 additions and 38 deletions

View File

@ -5,6 +5,8 @@ from importlib import metadata
__version__: str = str(metadata.version("farm-haystack"))
import haystack.silenceable_tqdm # Needs to be imported first to wrap TQDM for all following modules
from haystack.schema import Document, Answer, Label, MultiLabel, Span, EvaluationResult, TableCell
from haystack.nodes.base import BaseComponent
from haystack.pipelines.base import Pipeline

View File

@ -1,6 +1,6 @@
from typing import Dict, Optional, List, Union
from tqdm.auto import tqdm
from tqdm import tqdm
from haystack.schema import Document
from haystack.document_stores.base import BaseDocumentStore

View File

@ -9,7 +9,7 @@ from copy import deepcopy
from inspect import Signature, signature
import numpy as np
from tqdm.auto import tqdm
from tqdm import tqdm
from haystack.schema import Document, FilterType
from haystack.utils.batching import get_batches_from_generator

View File

@ -8,7 +8,7 @@ from collections import defaultdict
import re
import numpy as np
from tqdm.auto import tqdm
from tqdm import tqdm
import rank_bm25
import pandas as pd

View File

@ -3,7 +3,7 @@ from typing import List, Optional, Union, Dict, Any
import logging
import numpy as np
from tqdm.auto import tqdm
from tqdm import tqdm
from tenacity import retry, wait_exponential, retry_if_not_result
from haystack.schema import Document, FilterType

View File

@ -8,7 +8,7 @@ from functools import reduce
import operator
import numpy as np
from tqdm.auto import tqdm
from tqdm import tqdm
from haystack.schema import Document, FilterType, Label, Answer, Span
from haystack.document_stores import BaseDocumentStore

View File

@ -10,7 +10,7 @@ import time
from string import Template
import numpy as np
from tqdm.auto import tqdm
from tqdm import tqdm
from pydantic.error_wrappers import ValidationError
from haystack.document_stores import KeywordDocumentStore

View File

@ -6,7 +6,7 @@ import uuid
from typing import Any, Dict, Generator, List, Optional, Union
import numpy as np
from tqdm.auto import tqdm
from tqdm import tqdm
from haystack.schema import Document, FilterType, Label
from haystack.document_stores import KeywordDocumentStore

View File

@ -7,7 +7,7 @@ import random
from itertools import groupby
from pathlib import Path
import numpy as np
from tqdm.auto import tqdm
from tqdm import tqdm
import torch
from torch.utils.data import ConcatDataset, Dataset
from torch.utils.data.distributed import DistributedSampler

View File

@ -14,7 +14,7 @@ from abc import ABC, abstractmethod
import numpy as np
import requests
from tqdm.auto import tqdm
from tqdm import tqdm
from torch.utils.data import TensorDataset
import transformers
from transformers import PreTrainedTokenizer, AutoTokenizer

View File

@ -5,7 +5,7 @@ import numbers
import torch
from torch.nn import DataParallel
import numpy as np
from tqdm.auto import tqdm
from tqdm import tqdm
from haystack.modeling.evaluation.metrics import compute_metrics, compute_report_metrics
from haystack.modeling.model.adaptive_model import AdaptiveModel

View File

@ -2,7 +2,7 @@ from typing import List, Optional, Dict, Union, Set, Any
import os
import logging
from tqdm.auto import tqdm
from tqdm import tqdm
import torch
from torch.utils.data.sampler import SequentialSampler
from torch.utils.data import Dataset

View File

@ -6,7 +6,7 @@ import logging
from pathlib import Path
import numpy
from tqdm.auto import tqdm
from tqdm import tqdm
import torch
from torch.optim.lr_scheduler import _LRScheduler
from torch.nn import MSELoss, Linear, Module, ModuleList, DataParallel

View File

@ -1,7 +1,7 @@
from abc import abstractmethod
from typing import Any, List, Optional, Dict, Union
from tqdm.auto import tqdm
from tqdm import tqdm
from haystack.errors import HaystackError
from haystack.schema import Answer, Document, MultiLabel

View File

@ -2,7 +2,7 @@ import logging
from typing import List, Optional, Union, Dict
import itertools
from tqdm.auto import tqdm
from tqdm import tqdm
from haystack.nodes.base import Document
from haystack.nodes.doc_language_classifier.base import BaseDocumentLanguageClassifier

View File

@ -2,7 +2,7 @@ from typing import List, Optional, Union
import logging
import itertools
from tqdm.auto import tqdm
from tqdm import tqdm
from haystack.schema import Document
from haystack.nodes.document_classifier.base import BaseDocumentClassifier

View File

@ -24,7 +24,7 @@ import itertools
import numpy as np
from tokenizers.pre_tokenizers import WhitespaceSplit
from tqdm.auto import tqdm
from tqdm import tqdm
from haystack.schema import Document
from haystack.nodes.base import BaseComponent
from haystack.lazy_imports import LazyImport

View File

@ -4,7 +4,7 @@ import logging
from abc import abstractmethod
from pathlib import Path
from tqdm.auto import tqdm
from tqdm import tqdm
from haystack.nodes.base import BaseComponent
from haystack.schema import Document

View File

@ -2,7 +2,7 @@ from typing import List, Optional, Union
import logging
from tqdm.auto import tqdm
from tqdm import tqdm
from haystack.schema import Document
from haystack.nodes.image_to_text.base import BaseImageToText

View File

@ -2,7 +2,7 @@ import logging
import random
from typing import Dict, Iterable, List, Optional, Tuple, Union
from tqdm.auto import tqdm
from tqdm import tqdm
from haystack.nodes.base import BaseComponent
from haystack.nodes.question_generator import QuestionGenerator

View File

@ -1,6 +1,6 @@
from typing import List, Union, Dict
from tqdm.auto import tqdm
from tqdm import tqdm
from haystack.errors import HaystackError
from haystack.schema import Document, Answer, Span

View File

@ -9,7 +9,7 @@ import warnings
from pathlib import Path
from pickle import UnpicklingError
from tqdm.auto import tqdm
from tqdm import tqdm
from more_itertools import windowed
from haystack.nodes.preprocessor.base import BasePreProcessor

View File

@ -4,7 +4,7 @@ from typing import Union, Any, List, Optional, Iterator, Dict
import pickle
import urllib
from tqdm.auto import tqdm
from tqdm import tqdm
from sklearn.ensemble._gb_losses import BinomialDeviance
from sklearn.ensemble._gb import GradientBoostingClassifier

View File

@ -2,7 +2,7 @@ import logging
from pathlib import Path
from typing import Union, List, Optional, Dict, Any
from tqdm.auto import tqdm
from tqdm import tqdm
from haystack.nodes.query_classifier.base import BaseQueryClassifier
from haystack.lazy_imports import LazyImport

View File

@ -2,7 +2,7 @@ import logging
from typing import List, Union, Optional, Iterator
import itertools
from tqdm.auto import tqdm
from tqdm import tqdm
from haystack.errors import HaystackError
from haystack.schema import Document

View File

@ -2,7 +2,7 @@ from typing import List, Optional, Union, Tuple, Iterator, Any
import logging
from pathlib import Path
from tqdm.auto import tqdm
from tqdm import tqdm
from haystack.errors import HaystackError
from haystack.schema import Document

View File

@ -9,7 +9,7 @@ from tenacity import retry, retry_if_exception_type, wait_exponential, stop_afte
import numpy as np
import requests
from tqdm.auto import tqdm
from tqdm import tqdm
from haystack.environment import (
HAYSTACK_REMOTE_API_BACKOFF_SEC,

View File

@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union
import numpy as np
from tiktoken.model import MODEL_TO_ENCODING
from tqdm.auto import tqdm
from tqdm import tqdm
from haystack.environment import HAYSTACK_REMOTE_API_TIMEOUT_SEC
from haystack.nodes.retriever._base_embedding_encoder import _BaseEmbeddingEncoder

View File

@ -5,7 +5,7 @@ from abc import abstractmethod
from time import perf_counter
from functools import wraps
from tqdm.auto import tqdm
from tqdm import tqdm
from haystack.schema import Document, MultiLabel
from haystack.errors import HaystackError, PipelineError

View File

@ -8,7 +8,7 @@ from copy import deepcopy
from requests.exceptions import HTTPError
import numpy as np
from tqdm.auto import tqdm
from tqdm import tqdm
import pandas as pd
from huggingface_hub import hf_hub_download

View File

@ -3,7 +3,7 @@ from typing import Union, Optional, Dict, List, Any
import logging
from pathlib import Path
from tqdm.auto import tqdm
from tqdm import tqdm
import numpy as np
from PIL import Image

View File

@ -3,7 +3,7 @@ from typing import List, Optional, Set, Union
import logging
from tqdm.auto import tqdm
from tqdm import tqdm
from haystack.schema import Document
from haystack.nodes.summarizer.base import BaseSummarizer

View File

@ -2,7 +2,7 @@ import logging
from copy import deepcopy
from typing import Any, Dict, List, Optional, Union
from tqdm.auto import tqdm
from tqdm import tqdm
from haystack.errors import HaystackError
from haystack.schema import Document, Answer

View File

@ -20,7 +20,7 @@ import numpy as np
import pandas as pd
import networkx as nx
from pandas.core.frame import DataFrame
from tqdm.auto import tqdm
from tqdm import tqdm
from networkx import DiGraph
from networkx.drawing.nx_agraph import to_agraph

View File

@ -0,0 +1,40 @@
import os
import tqdm
class SilenceableTqdm(tqdm.tqdm):
"""
Wrapper for tqdm that disables all progress bars if HAYSTACK_PROGRESS_BARS is set to a falsey value
("0", "False", "FALSE", "false").
Note: this check is done every time a tqdm iterator is initialized, so normally for each method run. Therefore
progress bars can be enabled and disabled at runtime, but not during a specific iteration.
"""
def __init__(self, *args, **kwargs):
"""
Passes `disable=True` to tqdm if `self.no_progress_bars` is set to True.
"""
if self.no_progress_bars:
kwargs["disable"] = True
super().__init__(*args, **kwargs)
@property
def no_progress_bars(self):
"""
Reads the HAYSTACK_PROGRESS_BARS env var to check if the progress bars should be disabled.
"""
return os.getenv("HAYSTACK_PROGRESS_BARS", "1") in ["0", "False", "FALSE", "false"]
@property
def disable(self):
return self.no_progress_bars or self._disable
@disable.setter
def disable(self, value):
self._disable = value
tqdm.std.tqdm = SilenceableTqdm
tqdm.tqdm = SilenceableTqdm

View File

@ -38,7 +38,7 @@ from torch.nn import functional as F
from transformers import AutoModelForMaskedLM, AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase
import requests
import numpy as np
from tqdm.auto import tqdm
from tqdm import tqdm
logger = logging.getLogger(__name__)

View File

@ -6,7 +6,7 @@ from itertools import groupby
from multiprocessing.pool import Pool
from collections import namedtuple
from tqdm.auto import tqdm
from tqdm import tqdm
logger = logging.getLogger(__file__)

View File

@ -11,7 +11,7 @@ from enum import Enum
import pandas as pd
import requests
import yaml
from tqdm.auto import tqdm
from tqdm import tqdm
from haystack.schema import Answer, Document, EvaluationResult, FilterType, Label

View File

@ -4,7 +4,7 @@ import logging
import json
import random
import pandas as pd
from tqdm.auto import tqdm
from tqdm import tqdm
from haystack.mmh3 import hash128
from haystack.schema import Document, Label, Answer

View File

@ -64,7 +64,7 @@ from time import sleep
from pathlib import Path
from itertools import islice
from tqdm.auto import tqdm
from tqdm import tqdm
from haystack.document_stores.base import BaseDocumentStore
from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore # keep it here !

85
test/utils/test_tqdm.py Normal file
View File

@ -0,0 +1,85 @@
from tqdm import tqdm
def test_silenceable_tqdm_not_disabled_by_default():
progress_bar = tqdm(range(1))
assert not progress_bar.disable
def test_silenceable_tqdm_can_be_silenced_with_0(monkeypatch):
monkeypatch.setenv("HAYSTACK_PROGRESS_BARS", "0")
progress_bar = tqdm(range(1))
assert progress_bar.disable
def test_silenceable_tqdm_can_be_silenced_with_false(monkeypatch):
monkeypatch.setenv("HAYSTACK_PROGRESS_BARS", "false")
progress_bar = tqdm(range(1))
assert progress_bar.disable
def test_silenceable_tqdm_can_be_silenced_with_False(monkeypatch):
monkeypatch.setenv("HAYSTACK_PROGRESS_BARS", "False")
progress_bar = tqdm(range(1))
assert progress_bar.disable
def test_silenceable_tqdm_can_be_silenced_with_FALSE(monkeypatch):
monkeypatch.setenv("HAYSTACK_PROGRESS_BARS", "FALSE")
progress_bar = tqdm(range(1))
assert progress_bar.disable
def test_silenceable_tqdm_not_disabled_with_number_above_zero(monkeypatch):
monkeypatch.setenv("HAYSTACK_PROGRESS_BARS", "1")
progress_bar = tqdm(range(1))
assert not progress_bar.disable
monkeypatch.setenv("HAYSTACK_PROGRESS_BARS", "10")
progress_bar = tqdm(range(1))
assert not progress_bar.disable
def test_silenceable_tqdm_not_disabled_with_empty_string(monkeypatch):
monkeypatch.setenv("HAYSTACK_PROGRESS_BARS", "")
progress_bar = tqdm(range(1))
assert not progress_bar.disable
def test_silenceable_tqdm_not_disabled_with_other_string(monkeypatch):
monkeypatch.setenv("HAYSTACK_PROGRESS_BARS", "true")
progress_bar = tqdm(range(1))
assert not progress_bar.disable
monkeypatch.setenv("HAYSTACK_PROGRESS_BARS", "don't print the progress bars please")
progress_bar = tqdm(range(1))
assert not progress_bar.disable
def test_silenceable_tqdm_can_be_disabled_explicitly():
progress_bar = tqdm(range(1), disable=True)
assert progress_bar.disable
def test_silenceable_tqdm_global_disable_overrides_local_enable(monkeypatch):
monkeypatch.setenv("HAYSTACK_PROGRESS_BARS", "0")
progress_bar = tqdm(range(1), disable=False)
assert progress_bar.disable
def test_silenceable_tqdm_global_enable_does_not_overrides_local_disable(monkeypatch):
monkeypatch.setenv("HAYSTACK_PROGRESS_BARS", "1")
progress_bar = tqdm(range(1), disable=True)
assert progress_bar.disable
def test_silenceable_tqdm_global_and_local_disable_do_not_clash(monkeypatch):
monkeypatch.setenv("HAYSTACK_PROGRESS_BARS", "0")
progress_bar = tqdm(range(1), disable=True)
assert progress_bar.disable
def test_silenceable_tqdm_global_and_local_enable_do_not_clash(monkeypatch):
monkeypatch.setenv("HAYSTACK_PROGRESS_BARS", "1")
progress_bar = tqdm(range(1), disable=False)
assert not progress_bar.disable