mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-24 13:38:53 +00:00
fix: Ensure Secret types are immutable (#6994)
This commit is contained in:
parent
0aa788facc
commit
b552b0b37c
@ -1,6 +1,6 @@
|
||||
from enum import Enum
|
||||
import os
|
||||
from typing import Any, Dict, Iterable, List, Optional, Union
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||
from dataclasses import dataclass
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
@ -21,18 +21,11 @@ class SecretType(Enum):
|
||||
return type
|
||||
|
||||
|
||||
@dataclass
|
||||
class Secret(ABC):
|
||||
"""
|
||||
Encapsulates a secret used for authentication.
|
||||
"""
|
||||
|
||||
_type: SecretType
|
||||
|
||||
def __init__(self, type: SecretType):
|
||||
super().__init__()
|
||||
self._type = type
|
||||
|
||||
@staticmethod
|
||||
def from_token(token: str) -> "Secret":
|
||||
"""
|
||||
@ -41,7 +34,7 @@ class Secret(ABC):
|
||||
:param token:
|
||||
The token to use for authentication.
|
||||
"""
|
||||
return TokenSecret(token)
|
||||
return TokenSecret(_token=token)
|
||||
|
||||
@staticmethod
|
||||
def from_env_var(env_vars: Union[str, List[str]], *, strict: bool = True) -> "Secret":
|
||||
@ -60,7 +53,7 @@ class Secret(ABC):
|
||||
"""
|
||||
if isinstance(env_vars, str):
|
||||
env_vars = [env_vars]
|
||||
return EnvVarSecret(env_vars, strict=strict)
|
||||
return EnvVarSecret(_env_vars=tuple(env_vars), _strict=strict)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
@ -70,7 +63,7 @@ class Secret(ABC):
|
||||
:returns:
|
||||
The serialized policy.
|
||||
"""
|
||||
out = {"type": self._type.value}
|
||||
out = {"type": self.type.value}
|
||||
inner = self._to_dict()
|
||||
assert all(k not in inner for k in out.keys())
|
||||
out.update(inner)
|
||||
@ -101,6 +94,14 @@ class Secret(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def type(self) -> SecretType:
|
||||
"""
|
||||
The type of the secret.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _to_dict(self) -> Dict[str, Any]:
|
||||
pass
|
||||
@ -111,7 +112,7 @@ class Secret(ABC):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(frozen=True)
|
||||
class TokenSecret(Secret):
|
||||
"""
|
||||
A secret that uses a string token/API key.
|
||||
@ -119,18 +120,13 @@ class TokenSecret(Secret):
|
||||
"""
|
||||
|
||||
_token: str
|
||||
_type: SecretType = SecretType.TOKEN
|
||||
|
||||
def __init__(self, token: str):
|
||||
"""
|
||||
Create a token secret.
|
||||
def __post_init__(self):
|
||||
super().__init__()
|
||||
assert self._type == SecretType.TOKEN
|
||||
|
||||
:param token:
|
||||
The token to use for authentication.
|
||||
"""
|
||||
super().__init__(SecretType.TOKEN)
|
||||
self._token = token
|
||||
|
||||
if len(token) == 0:
|
||||
if len(self._token) == 0:
|
||||
raise ValueError("Authentication token cannot be empty.")
|
||||
|
||||
def _to_dict(self) -> Dict[str, Any]:
|
||||
@ -147,8 +143,12 @@ class TokenSecret(Secret):
|
||||
def resolve_value(self) -> Optional[Any]:
|
||||
return self._token
|
||||
|
||||
@property
|
||||
def type(self) -> SecretType:
|
||||
return self._type
|
||||
|
||||
@dataclass
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class EnvVarSecret(Secret):
|
||||
"""
|
||||
A secret that accepts one or more environment variables.
|
||||
@ -156,32 +156,23 @@ class EnvVarSecret(Secret):
|
||||
environment variable that is set. Can be serialized.
|
||||
"""
|
||||
|
||||
_env_vars: List[str]
|
||||
_strict: bool
|
||||
_env_vars: Tuple[str, ...]
|
||||
_strict: bool = True
|
||||
_type: SecretType = SecretType.ENV_VAR
|
||||
|
||||
def __init__(self, env_vars: List[str], *, strict: bool = True):
|
||||
"""
|
||||
Create an environment variable secret.
|
||||
def __post_init__(self):
|
||||
super().__init__()
|
||||
assert self._type == SecretType.ENV_VAR
|
||||
|
||||
:param env_vars:
|
||||
Ordered list of candidate environment variables.
|
||||
:param strict:
|
||||
Whether to raise an exception if none of the environment
|
||||
variables are set.
|
||||
"""
|
||||
super().__init__(SecretType.ENV_VAR)
|
||||
self._env_vars = list(env_vars)
|
||||
self._strict = strict
|
||||
|
||||
if len(env_vars) == 0:
|
||||
if len(self._env_vars) == 0:
|
||||
raise ValueError("One or more environment variables must be provided for the secret.")
|
||||
|
||||
def _to_dict(self) -> Dict[str, Any]:
|
||||
return {"env_vars": self._env_vars, "strict": self._strict}
|
||||
return {"env_vars": list(self._env_vars), "strict": self._strict}
|
||||
|
||||
@staticmethod
|
||||
def _from_dict(dict: Dict[str, Any]) -> "Secret":
|
||||
return EnvVarSecret(dict["env_vars"], strict=dict["strict"])
|
||||
return EnvVarSecret(tuple(dict["env_vars"]), _strict=dict["strict"])
|
||||
|
||||
def resolve_value(self) -> Optional[Any]:
|
||||
out = None
|
||||
@ -194,6 +185,10 @@ class EnvVarSecret(Secret):
|
||||
raise ValueError(f"None of the following authentication environment variables are set: {self._env_vars}")
|
||||
return out
|
||||
|
||||
@property
|
||||
def type(self) -> SecretType:
|
||||
return self._type
|
||||
|
||||
|
||||
def deserialize_secrets_inplace(data: Dict[str, Any], keys: Iterable[str], *, recursive: bool = False):
|
||||
"""
|
||||
|
||||
@ -3,6 +3,7 @@ import os
|
||||
import pytest
|
||||
|
||||
from haystack.utils.auth import Secret, EnvVarSecret, SecretType, TokenSecret
|
||||
from dataclasses import FrozenInstanceError
|
||||
|
||||
|
||||
def test_secret_type():
|
||||
@ -15,7 +16,7 @@ def test_secret_type():
|
||||
|
||||
def test_token_secret():
|
||||
secret = Secret.from_token("test-token")
|
||||
assert secret._type == SecretType.TOKEN
|
||||
assert secret.type == SecretType.TOKEN
|
||||
assert isinstance(secret, TokenSecret)
|
||||
assert secret._token == "test-token"
|
||||
assert secret.resolve_value() == "test-token"
|
||||
@ -26,14 +27,19 @@ def test_token_secret():
|
||||
with pytest.raises(ValueError, match="cannot be empty"):
|
||||
Secret.from_token("")
|
||||
|
||||
with pytest.raises(FrozenInstanceError):
|
||||
secret._token = "abba"
|
||||
with pytest.raises(FrozenInstanceError):
|
||||
secret._type = SecretType.ENV_VAR
|
||||
|
||||
|
||||
def test_env_var_secret():
|
||||
secret = Secret.from_env_var("TEST_ENV_VAR1")
|
||||
os.environ["TEST_ENV_VAR1"] = "test-token"
|
||||
|
||||
assert secret._type == SecretType.ENV_VAR
|
||||
assert secret.type == SecretType.ENV_VAR
|
||||
assert isinstance(secret, EnvVarSecret)
|
||||
assert secret._env_vars == ["TEST_ENV_VAR1"]
|
||||
assert secret._env_vars == ("TEST_ENV_VAR1",)
|
||||
assert secret._strict is True
|
||||
assert secret.resolve_value() == "test-token"
|
||||
|
||||
@ -46,7 +52,7 @@ def test_env_var_secret():
|
||||
assert secret.resolve_value() == None
|
||||
|
||||
secret = Secret.from_env_var(["TEST_ENV_VAR2", "TEST_ENV_VAR1"], strict=True)
|
||||
assert secret._env_vars == ["TEST_ENV_VAR2", "TEST_ENV_VAR1"]
|
||||
assert secret._env_vars == ("TEST_ENV_VAR2", "TEST_ENV_VAR1")
|
||||
with pytest.raises(ValueError, match="None of the following .* variables are set"):
|
||||
secret.resolve_value()
|
||||
os.environ["TEST_ENV_VAR1"] = "test-token-2"
|
||||
@ -61,3 +67,10 @@ def test_env_var_secret():
|
||||
assert (
|
||||
Secret.from_dict({"type": "env_var", "env_vars": ["TEST_ENV_VAR2", "TEST_ENV_VAR1"], "strict": True}) == secret
|
||||
)
|
||||
|
||||
with pytest.raises(FrozenInstanceError):
|
||||
secret._env_vars = ("A", "B")
|
||||
with pytest.raises(FrozenInstanceError):
|
||||
secret._strict = False
|
||||
with pytest.raises(FrozenInstanceError):
|
||||
secret._type = SecretType.TOKEN
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user