diff --git a/haystack/__init__.py b/haystack/__init__.py index 5ce8338b7..9eeaa7db7 100644 --- a/haystack/__init__.py +++ b/haystack/__init__.py @@ -20,6 +20,8 @@ import pandas as pd from haystack.schema import Document, Answer, Label, MultiLabel, Span, EvaluationResult from haystack.nodes.base import BaseComponent from haystack.pipelines.base import Pipeline +from haystack.environment import set_pytorch_secure_model_loading pd.options.display.max_colwidth = 80 +set_pytorch_secure_model_loading() diff --git a/haystack/environment.py b/haystack/environment.py index a0488c709..1eefd7a50 100644 --- a/haystack/environment.py +++ b/haystack/environment.py @@ -1,3 +1,4 @@ +import logging import os import platform import sys @@ -17,6 +18,18 @@ HAYSTACK_REMOTE_API_MAX_RETRIES = "HAYSTACK_REMOTE_API_MAX_RETRIES" env_meta_data: Dict[str, Any] = {} +logger = logging.getLogger(__name__) + + +def set_pytorch_secure_model_loading(flag_val="1"): + # To load secure only model pytorch requires value of + # TORCH_FORCE_WEIGHTS_ONLY_LOAD to be ["1", "y", "yes", "true"] + os_flag_val = os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD") + if os_flag_val is None: + os.environ["TORCH_FORCE_WEIGHTS_ONLY_LOAD"] = flag_val + else: + logger.info("TORCH_FORCE_WEIGHTS_ONLY_LOAD is already set to %s, Haystack will use the same.", os_flag_val) + def get_or_create_env_meta_data() -> Dict[str, Any]: """ diff --git a/test/others/test_utils.py b/test/others/test_utils.py index 68449bcb2..2d16993d4 100644 --- a/test/others/test_utils.py +++ b/test/others/test_utils.py @@ -1,4 +1,6 @@ +import importlib import logging +import os from random import random from typing import List @@ -12,6 +14,7 @@ import _pytest from ..conftest import fail_at_version, haystack_version from haystack.errors import OpenAIRateLimitError +from haystack.environment import set_pytorch_secure_model_loading from haystack.schema import Answer, Document, Span, Label from haystack.utils.deepsetcloud import DeepsetCloud, DeepsetCloudExperiments from haystack.utils.labels import aggregate_labels @@ -1245,6 +1248,17 @@ def test_exponential_backoff(): assert greet2("John") == "Hello John" +def test_secure_model_loading(monkeypatch, caplog): + caplog.set_level(logging.INFO) + monkeypatch.setenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0") + + # now testing if just importing haystack is enough to enable secure loading of pytorch models + import haystack + + importlib.reload(haystack) + assert "already set to" in caplog.text + + class TestAggregateLabels: @pytest.fixture def standard_labels(self) -> List[Label]: