mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-06 03:57:19 +00:00
feat: globally disable progress bars (#5207)
* add SilenceableTqdm and update usage * pylint * rename module * add tests
This commit is contained in:
parent
5ee393226d
commit
462f3a5c99
@ -5,6 +5,8 @@ from importlib import metadata
|
||||
|
||||
__version__: str = str(metadata.version("farm-haystack"))
|
||||
|
||||
|
||||
import haystack.silenceable_tqdm # Needs to be imported first to wrap TQDM for all following modules
|
||||
from haystack.schema import Document, Answer, Label, MultiLabel, Span, EvaluationResult, TableCell
|
||||
from haystack.nodes.base import BaseComponent
|
||||
from haystack.pipelines.base import Pipeline
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from typing import Dict, Optional, List, Union
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
from tqdm import tqdm
|
||||
|
||||
from haystack.schema import Document
|
||||
from haystack.document_stores.base import BaseDocumentStore
|
||||
|
||||
@ -9,7 +9,7 @@ from copy import deepcopy
|
||||
from inspect import Signature, signature
|
||||
|
||||
import numpy as np
|
||||
from tqdm.auto import tqdm
|
||||
from tqdm import tqdm
|
||||
|
||||
from haystack.schema import Document, FilterType
|
||||
from haystack.utils.batching import get_batches_from_generator
|
||||
|
||||
@ -8,7 +8,7 @@ from collections import defaultdict
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
from tqdm.auto import tqdm
|
||||
from tqdm import tqdm
|
||||
import rank_bm25
|
||||
import pandas as pd
|
||||
|
||||
|
||||
@ -3,7 +3,7 @@ from typing import List, Optional, Union, Dict, Any
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
from tqdm.auto import tqdm
|
||||
from tqdm import tqdm
|
||||
from tenacity import retry, wait_exponential, retry_if_not_result
|
||||
|
||||
from haystack.schema import Document, FilterType
|
||||
|
||||
@ -8,7 +8,7 @@ from functools import reduce
|
||||
import operator
|
||||
|
||||
import numpy as np
|
||||
from tqdm.auto import tqdm
|
||||
from tqdm import tqdm
|
||||
|
||||
from haystack.schema import Document, FilterType, Label, Answer, Span
|
||||
from haystack.document_stores import BaseDocumentStore
|
||||
|
||||
@ -10,7 +10,7 @@ import time
|
||||
from string import Template
|
||||
|
||||
import numpy as np
|
||||
from tqdm.auto import tqdm
|
||||
from tqdm import tqdm
|
||||
from pydantic.error_wrappers import ValidationError
|
||||
|
||||
from haystack.document_stores import KeywordDocumentStore
|
||||
|
||||
@ -6,7 +6,7 @@ import uuid
|
||||
from typing import Any, Dict, Generator, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from tqdm.auto import tqdm
|
||||
from tqdm import tqdm
|
||||
|
||||
from haystack.schema import Document, FilterType, Label
|
||||
from haystack.document_stores import KeywordDocumentStore
|
||||
|
||||
@ -7,7 +7,7 @@ import random
|
||||
from itertools import groupby
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
from tqdm.auto import tqdm
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
from torch.utils.data import ConcatDataset, Dataset
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
@ -14,7 +14,7 @@ from abc import ABC, abstractmethod
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
from tqdm.auto import tqdm
|
||||
from tqdm import tqdm
|
||||
from torch.utils.data import TensorDataset
|
||||
import transformers
|
||||
from transformers import PreTrainedTokenizer, AutoTokenizer
|
||||
|
||||
@ -5,7 +5,7 @@ import numbers
|
||||
import torch
|
||||
from torch.nn import DataParallel
|
||||
import numpy as np
|
||||
from tqdm.auto import tqdm
|
||||
from tqdm import tqdm
|
||||
|
||||
from haystack.modeling.evaluation.metrics import compute_metrics, compute_report_metrics
|
||||
from haystack.modeling.model.adaptive_model import AdaptiveModel
|
||||
|
||||
@ -2,7 +2,7 @@ from typing import List, Optional, Dict, Union, Set, Any
|
||||
|
||||
import os
|
||||
import logging
|
||||
from tqdm.auto import tqdm
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
from torch.utils.data.sampler import SequentialSampler
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
@ -6,7 +6,7 @@ import logging
|
||||
from pathlib import Path
|
||||
|
||||
import numpy
|
||||
from tqdm.auto import tqdm
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
from torch.nn import MSELoss, Linear, Module, ModuleList, DataParallel
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Any, List, Optional, Dict, Union
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
from tqdm import tqdm
|
||||
|
||||
from haystack.errors import HaystackError
|
||||
from haystack.schema import Answer, Document, MultiLabel
|
||||
|
||||
@ -2,7 +2,7 @@ import logging
|
||||
from typing import List, Optional, Union, Dict
|
||||
import itertools
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
from tqdm import tqdm
|
||||
|
||||
from haystack.nodes.base import Document
|
||||
from haystack.nodes.doc_language_classifier.base import BaseDocumentLanguageClassifier
|
||||
|
||||
@ -2,7 +2,7 @@ from typing import List, Optional, Union
|
||||
import logging
|
||||
import itertools
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
from tqdm import tqdm
|
||||
|
||||
from haystack.schema import Document
|
||||
from haystack.nodes.document_classifier.base import BaseDocumentClassifier
|
||||
|
||||
@ -24,7 +24,7 @@ import itertools
|
||||
import numpy as np
|
||||
|
||||
from tokenizers.pre_tokenizers import WhitespaceSplit
|
||||
from tqdm.auto import tqdm
|
||||
from tqdm import tqdm
|
||||
from haystack.schema import Document
|
||||
from haystack.nodes.base import BaseComponent
|
||||
from haystack.lazy_imports import LazyImport
|
||||
|
||||
@ -4,7 +4,7 @@ import logging
|
||||
from abc import abstractmethod
|
||||
from pathlib import Path
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
from tqdm import tqdm
|
||||
from haystack.nodes.base import BaseComponent
|
||||
from haystack.schema import Document
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@ from typing import List, Optional, Union
|
||||
|
||||
import logging
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
from tqdm import tqdm
|
||||
|
||||
from haystack.schema import Document
|
||||
from haystack.nodes.image_to_text.base import BaseImageToText
|
||||
|
||||
@ -2,7 +2,7 @@ import logging
|
||||
import random
|
||||
from typing import Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
from tqdm import tqdm
|
||||
|
||||
from haystack.nodes.base import BaseComponent
|
||||
from haystack.nodes.question_generator import QuestionGenerator
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from typing import List, Union, Dict
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
from tqdm import tqdm
|
||||
|
||||
from haystack.errors import HaystackError
|
||||
from haystack.schema import Document, Answer, Span
|
||||
|
||||
@ -9,7 +9,7 @@ import warnings
|
||||
from pathlib import Path
|
||||
from pickle import UnpicklingError
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
from tqdm import tqdm
|
||||
from more_itertools import windowed
|
||||
|
||||
from haystack.nodes.preprocessor.base import BasePreProcessor
|
||||
|
||||
@ -4,7 +4,7 @@ from typing import Union, Any, List, Optional, Iterator, Dict
|
||||
import pickle
|
||||
import urllib
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
from tqdm import tqdm
|
||||
from sklearn.ensemble._gb_losses import BinomialDeviance
|
||||
from sklearn.ensemble._gb import GradientBoostingClassifier
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@ import logging
|
||||
from pathlib import Path
|
||||
from typing import Union, List, Optional, Dict, Any
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
from tqdm import tqdm
|
||||
|
||||
from haystack.nodes.query_classifier.base import BaseQueryClassifier
|
||||
from haystack.lazy_imports import LazyImport
|
||||
|
||||
@ -2,7 +2,7 @@ import logging
|
||||
from typing import List, Union, Optional, Iterator
|
||||
import itertools
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
from tqdm import tqdm
|
||||
|
||||
from haystack.errors import HaystackError
|
||||
from haystack.schema import Document
|
||||
|
||||
@ -2,7 +2,7 @@ from typing import List, Optional, Union, Tuple, Iterator, Any
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
from tqdm import tqdm
|
||||
|
||||
from haystack.errors import HaystackError
|
||||
from haystack.schema import Document
|
||||
|
||||
@ -9,7 +9,7 @@ from tenacity import retry, retry_if_exception_type, wait_exponential, stop_afte
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
from tqdm.auto import tqdm
|
||||
from tqdm import tqdm
|
||||
|
||||
from haystack.environment import (
|
||||
HAYSTACK_REMOTE_API_BACKOFF_SEC,
|
||||
|
||||
@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from tiktoken.model import MODEL_TO_ENCODING
|
||||
from tqdm.auto import tqdm
|
||||
from tqdm import tqdm
|
||||
|
||||
from haystack.environment import HAYSTACK_REMOTE_API_TIMEOUT_SEC
|
||||
from haystack.nodes.retriever._base_embedding_encoder import _BaseEmbeddingEncoder
|
||||
|
||||
@ -5,7 +5,7 @@ from abc import abstractmethod
|
||||
from time import perf_counter
|
||||
from functools import wraps
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
from tqdm import tqdm
|
||||
|
||||
from haystack.schema import Document, MultiLabel
|
||||
from haystack.errors import HaystackError, PipelineError
|
||||
|
||||
@ -8,7 +8,7 @@ from copy import deepcopy
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
import numpy as np
|
||||
from tqdm.auto import tqdm
|
||||
from tqdm import tqdm
|
||||
|
||||
import pandas as pd
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
@ -3,7 +3,7 @@ from typing import Union, Optional, Dict, List, Any
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
@ -3,7 +3,7 @@ from typing import List, Optional, Set, Union
|
||||
|
||||
import logging
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
from tqdm import tqdm
|
||||
|
||||
from haystack.schema import Document
|
||||
from haystack.nodes.summarizer.base import BaseSummarizer
|
||||
|
||||
@ -2,7 +2,7 @@ import logging
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
from tqdm import tqdm
|
||||
|
||||
from haystack.errors import HaystackError
|
||||
from haystack.schema import Document, Answer
|
||||
|
||||
@ -20,7 +20,7 @@ import numpy as np
|
||||
import pandas as pd
|
||||
import networkx as nx
|
||||
from pandas.core.frame import DataFrame
|
||||
from tqdm.auto import tqdm
|
||||
from tqdm import tqdm
|
||||
from networkx import DiGraph
|
||||
from networkx.drawing.nx_agraph import to_agraph
|
||||
|
||||
|
||||
40
haystack/silenceable_tqdm.py
Normal file
40
haystack/silenceable_tqdm.py
Normal file
@ -0,0 +1,40 @@
|
||||
import os
|
||||
import tqdm
|
||||
|
||||
|
||||
class SilenceableTqdm(tqdm.tqdm):
|
||||
"""
|
||||
Wrapper for tqdm that disables all progress bars if HAYSTACK_PROGRESS_BARS is set to a falsey value
|
||||
("0", "False", "FALSE", "false").
|
||||
|
||||
Note: this check is done every time a tqdm iterator is initialized, so normally for each method run. Therefore
|
||||
progress bars can be enabled and disabled at runtime, but not during a specific iteration.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""
|
||||
Passes `disable=True` to tqdm if `self.no_progress_bars` is set to True.
|
||||
"""
|
||||
if self.no_progress_bars:
|
||||
kwargs["disable"] = True
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def no_progress_bars(self):
|
||||
"""
|
||||
Reads the HAYSTACK_PROGRESS_BARS env var to check if the progress bars should be disabled.
|
||||
"""
|
||||
return os.getenv("HAYSTACK_PROGRESS_BARS", "1") in ["0", "False", "FALSE", "false"]
|
||||
|
||||
@property
|
||||
def disable(self):
|
||||
return self.no_progress_bars or self._disable
|
||||
|
||||
@disable.setter
|
||||
def disable(self, value):
|
||||
self._disable = value
|
||||
|
||||
|
||||
tqdm.std.tqdm = SilenceableTqdm
|
||||
tqdm.tqdm = SilenceableTqdm
|
||||
@ -38,7 +38,7 @@ from torch.nn import functional as F
|
||||
from transformers import AutoModelForMaskedLM, AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase
|
||||
import requests
|
||||
import numpy as np
|
||||
from tqdm.auto import tqdm
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -6,7 +6,7 @@ from itertools import groupby
|
||||
from multiprocessing.pool import Pool
|
||||
from collections import namedtuple
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
from tqdm import tqdm
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
|
||||
|
||||
@ -11,7 +11,7 @@ from enum import Enum
|
||||
import pandas as pd
|
||||
import requests
|
||||
import yaml
|
||||
from tqdm.auto import tqdm
|
||||
from tqdm import tqdm
|
||||
|
||||
from haystack.schema import Answer, Document, EvaluationResult, FilterType, Label
|
||||
|
||||
|
||||
@ -4,7 +4,7 @@ import logging
|
||||
import json
|
||||
import random
|
||||
import pandas as pd
|
||||
from tqdm.auto import tqdm
|
||||
from tqdm import tqdm
|
||||
|
||||
from haystack.mmh3 import hash128
|
||||
from haystack.schema import Document, Label, Answer
|
||||
|
||||
@ -64,7 +64,7 @@ from time import sleep
|
||||
from pathlib import Path
|
||||
from itertools import islice
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
from tqdm import tqdm
|
||||
|
||||
from haystack.document_stores.base import BaseDocumentStore
|
||||
from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore # keep it here !
|
||||
|
||||
85
test/utils/test_tqdm.py
Normal file
85
test/utils/test_tqdm.py
Normal file
@ -0,0 +1,85 @@
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def test_silenceable_tqdm_not_disabled_by_default():
|
||||
progress_bar = tqdm(range(1))
|
||||
assert not progress_bar.disable
|
||||
|
||||
|
||||
def test_silenceable_tqdm_can_be_silenced_with_0(monkeypatch):
|
||||
monkeypatch.setenv("HAYSTACK_PROGRESS_BARS", "0")
|
||||
progress_bar = tqdm(range(1))
|
||||
assert progress_bar.disable
|
||||
|
||||
|
||||
def test_silenceable_tqdm_can_be_silenced_with_false(monkeypatch):
|
||||
monkeypatch.setenv("HAYSTACK_PROGRESS_BARS", "false")
|
||||
progress_bar = tqdm(range(1))
|
||||
assert progress_bar.disable
|
||||
|
||||
|
||||
def test_silenceable_tqdm_can_be_silenced_with_False(monkeypatch):
|
||||
monkeypatch.setenv("HAYSTACK_PROGRESS_BARS", "False")
|
||||
progress_bar = tqdm(range(1))
|
||||
assert progress_bar.disable
|
||||
|
||||
|
||||
def test_silenceable_tqdm_can_be_silenced_with_FALSE(monkeypatch):
|
||||
monkeypatch.setenv("HAYSTACK_PROGRESS_BARS", "FALSE")
|
||||
progress_bar = tqdm(range(1))
|
||||
assert progress_bar.disable
|
||||
|
||||
|
||||
def test_silenceable_tqdm_not_disabled_with_number_above_zero(monkeypatch):
|
||||
monkeypatch.setenv("HAYSTACK_PROGRESS_BARS", "1")
|
||||
progress_bar = tqdm(range(1))
|
||||
assert not progress_bar.disable
|
||||
|
||||
monkeypatch.setenv("HAYSTACK_PROGRESS_BARS", "10")
|
||||
progress_bar = tqdm(range(1))
|
||||
assert not progress_bar.disable
|
||||
|
||||
|
||||
def test_silenceable_tqdm_not_disabled_with_empty_string(monkeypatch):
|
||||
monkeypatch.setenv("HAYSTACK_PROGRESS_BARS", "")
|
||||
progress_bar = tqdm(range(1))
|
||||
assert not progress_bar.disable
|
||||
|
||||
|
||||
def test_silenceable_tqdm_not_disabled_with_other_string(monkeypatch):
|
||||
monkeypatch.setenv("HAYSTACK_PROGRESS_BARS", "true")
|
||||
progress_bar = tqdm(range(1))
|
||||
assert not progress_bar.disable
|
||||
|
||||
monkeypatch.setenv("HAYSTACK_PROGRESS_BARS", "don't print the progress bars please")
|
||||
progress_bar = tqdm(range(1))
|
||||
assert not progress_bar.disable
|
||||
|
||||
|
||||
def test_silenceable_tqdm_can_be_disabled_explicitly():
|
||||
progress_bar = tqdm(range(1), disable=True)
|
||||
assert progress_bar.disable
|
||||
|
||||
|
||||
def test_silenceable_tqdm_global_disable_overrides_local_enable(monkeypatch):
|
||||
monkeypatch.setenv("HAYSTACK_PROGRESS_BARS", "0")
|
||||
progress_bar = tqdm(range(1), disable=False)
|
||||
assert progress_bar.disable
|
||||
|
||||
|
||||
def test_silenceable_tqdm_global_enable_does_not_overrides_local_disable(monkeypatch):
|
||||
monkeypatch.setenv("HAYSTACK_PROGRESS_BARS", "1")
|
||||
progress_bar = tqdm(range(1), disable=True)
|
||||
assert progress_bar.disable
|
||||
|
||||
|
||||
def test_silenceable_tqdm_global_and_local_disable_do_not_clash(monkeypatch):
|
||||
monkeypatch.setenv("HAYSTACK_PROGRESS_BARS", "0")
|
||||
progress_bar = tqdm(range(1), disable=True)
|
||||
assert progress_bar.disable
|
||||
|
||||
|
||||
def test_silenceable_tqdm_global_and_local_enable_do_not_clash(monkeypatch):
|
||||
monkeypatch.setenv("HAYSTACK_PROGRESS_BARS", "1")
|
||||
progress_bar = tqdm(range(1), disable=False)
|
||||
assert not progress_bar.disable
|
||||
Loading…
x
Reference in New Issue
Block a user