feat: adding secure loading of models by default for haystack (#3901)

* adding secure loading of models by default

* simplified set function

* testing import effect correctly

* added appropriate log line, adapted the test

* change log string formatting

Co-authored-by: Massimiliano Pippi <mpippi@gmail.com>

* remove extra closing bracket )

Co-authored-by: Julian Risch <julian.risch@deepset.ai>
Co-authored-by: Massimiliano Pippi <mpippi@gmail.com>
This commit is contained in:
Mayank Jobanputra 2023-01-24 23:01:20 +05:30 committed by GitHub
parent 739fc228c6
commit 5c53b2bd4a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 29 additions and 0 deletions

View File

@ -20,6 +20,8 @@ import pandas as pd
from haystack.schema import Document, Answer, Label, MultiLabel, Span, EvaluationResult
from haystack.nodes.base import BaseComponent
from haystack.pipelines.base import Pipeline
from haystack.environment import set_pytorch_secure_model_loading
pd.options.display.max_colwidth = 80
set_pytorch_secure_model_loading()

View File

@ -1,3 +1,4 @@
import logging
import os
import platform
import sys
@ -17,6 +18,18 @@ HAYSTACK_REMOTE_API_MAX_RETRIES = "HAYSTACK_REMOTE_API_MAX_RETRIES"
env_meta_data: Dict[str, Any] = {}
logger = logging.getLogger(__name__)
def set_pytorch_secure_model_loading(flag_val="1"):
# To load secure only model pytorch requires value of
# TORCH_FORCE_WEIGHTS_ONLY_LOAD to be ["1", "y", "yes", "true"]
os_flag_val = os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD")
if os_flag_val is None:
os.environ["TORCH_FORCE_WEIGHTS_ONLY_LOAD"] = flag_val
else:
logger.info("TORCH_FORCE_WEIGHTS_ONLY_LOAD is already set to %s, Haystack will use the same.", os_flag_val)
def get_or_create_env_meta_data() -> Dict[str, Any]:
"""

View File

@ -1,4 +1,6 @@
import importlib
import logging
import os
from random import random
from typing import List
@ -12,6 +14,7 @@ import _pytest
from ..conftest import fail_at_version, haystack_version
from haystack.errors import OpenAIRateLimitError
from haystack.environment import set_pytorch_secure_model_loading
from haystack.schema import Answer, Document, Span, Label
from haystack.utils.deepsetcloud import DeepsetCloud, DeepsetCloudExperiments
from haystack.utils.labels import aggregate_labels
@ -1245,6 +1248,17 @@ def test_exponential_backoff():
assert greet2("John") == "Hello John"
def test_secure_model_loading(monkeypatch, caplog):
caplog.set_level(logging.INFO)
monkeypatch.setenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0")
# now testing if just importing haystack is enough to enable secure loading of pytorch models
import haystack
importlib.reload(haystack)
assert "already set to" in caplog.text
class TestAggregateLabels:
@pytest.fixture
def standard_labels(self) -> List[Label]: