mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-26 14:38:36 +00:00
feat: adding secure loading of models by default for haystack (#3901)
* adding secure loading of models by default * simplified set function * testing import effect correctly * added appropriate log line, adapted the test * change log string formatting Co-authored-by: Massimiliano Pippi <mpippi@gmail.com> * remove extra closing bracket ) Co-authored-by: Julian Risch <julian.risch@deepset.ai> Co-authored-by: Massimiliano Pippi <mpippi@gmail.com>
This commit is contained in:
parent
739fc228c6
commit
5c53b2bd4a
@ -20,6 +20,8 @@ import pandas as pd
|
||||
from haystack.schema import Document, Answer, Label, MultiLabel, Span, EvaluationResult
|
||||
from haystack.nodes.base import BaseComponent
|
||||
from haystack.pipelines.base import Pipeline
|
||||
from haystack.environment import set_pytorch_secure_model_loading
|
||||
|
||||
|
||||
pd.options.display.max_colwidth = 80
|
||||
set_pytorch_secure_model_loading()
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import sys
|
||||
@ -17,6 +18,18 @@ HAYSTACK_REMOTE_API_MAX_RETRIES = "HAYSTACK_REMOTE_API_MAX_RETRIES"
|
||||
|
||||
env_meta_data: Dict[str, Any] = {}
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def set_pytorch_secure_model_loading(flag_val="1"):
|
||||
# To load secure only model pytorch requires value of
|
||||
# TORCH_FORCE_WEIGHTS_ONLY_LOAD to be ["1", "y", "yes", "true"]
|
||||
os_flag_val = os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD")
|
||||
if os_flag_val is None:
|
||||
os.environ["TORCH_FORCE_WEIGHTS_ONLY_LOAD"] = flag_val
|
||||
else:
|
||||
logger.info("TORCH_FORCE_WEIGHTS_ONLY_LOAD is already set to %s, Haystack will use the same.", os_flag_val)
|
||||
|
||||
|
||||
def get_or_create_env_meta_data() -> Dict[str, Any]:
|
||||
"""
|
||||
|
||||
@ -1,4 +1,6 @@
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
from random import random
|
||||
from typing import List
|
||||
|
||||
@ -12,6 +14,7 @@ import _pytest
|
||||
from ..conftest import fail_at_version, haystack_version
|
||||
|
||||
from haystack.errors import OpenAIRateLimitError
|
||||
from haystack.environment import set_pytorch_secure_model_loading
|
||||
from haystack.schema import Answer, Document, Span, Label
|
||||
from haystack.utils.deepsetcloud import DeepsetCloud, DeepsetCloudExperiments
|
||||
from haystack.utils.labels import aggregate_labels
|
||||
@ -1245,6 +1248,17 @@ def test_exponential_backoff():
|
||||
assert greet2("John") == "Hello John"
|
||||
|
||||
|
||||
def test_secure_model_loading(monkeypatch, caplog):
|
||||
caplog.set_level(logging.INFO)
|
||||
monkeypatch.setenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0")
|
||||
|
||||
# now testing if just importing haystack is enough to enable secure loading of pytorch models
|
||||
import haystack
|
||||
|
||||
importlib.reload(haystack)
|
||||
assert "already set to" in caplog.text
|
||||
|
||||
|
||||
class TestAggregateLabels:
|
||||
@pytest.fixture
|
||||
def standard_labels(self) -> List[Label]:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user