fix: Ensure Secret types are immutable (#6994)

This commit is contained in:
Madeesh Kannan 2024-02-16 12:46:38 +01:00 committed by GitHub
parent 0aa788facc
commit b552b0b37c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 53 additions and 45 deletions

View File

@ -1,6 +1,6 @@
from enum import Enum
import os
from typing import Any, Dict, Iterable, List, Optional, Union
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from dataclasses import dataclass
from abc import ABC, abstractmethod
@ -21,18 +21,11 @@ class SecretType(Enum):
return type
@dataclass
class Secret(ABC):
"""
Encapsulates a secret used for authentication.
"""
_type: SecretType
def __init__(self, type: SecretType):
super().__init__()
self._type = type
@staticmethod
def from_token(token: str) -> "Secret":
"""
@ -41,7 +34,7 @@ class Secret(ABC):
:param token:
The token to use for authentication.
"""
return TokenSecret(token)
return TokenSecret(_token=token)
@staticmethod
def from_env_var(env_vars: Union[str, List[str]], *, strict: bool = True) -> "Secret":
@ -60,7 +53,7 @@ class Secret(ABC):
"""
if isinstance(env_vars, str):
env_vars = [env_vars]
return EnvVarSecret(env_vars, strict=strict)
return EnvVarSecret(_env_vars=tuple(env_vars), _strict=strict)
def to_dict(self) -> Dict[str, Any]:
"""
@ -70,7 +63,7 @@ class Secret(ABC):
:returns:
The serialized policy.
"""
out = {"type": self._type.value}
out = {"type": self.type.value}
inner = self._to_dict()
assert all(k not in inner for k in out.keys())
out.update(inner)
@ -101,6 +94,14 @@ class Secret(ABC):
"""
pass
@property
@abstractmethod
def type(self) -> SecretType:
"""
The type of the secret.
"""
pass
@abstractmethod
def _to_dict(self) -> Dict[str, Any]:
pass
@ -111,7 +112,7 @@ class Secret(ABC):
pass
@dataclass
@dataclass(frozen=True)
class TokenSecret(Secret):
"""
A secret that uses a string token/API key.
@ -119,18 +120,13 @@ class TokenSecret(Secret):
"""
_token: str
_type: SecretType = SecretType.TOKEN
def __init__(self, token: str):
"""
Create a token secret.
def __post_init__(self):
super().__init__()
assert self._type == SecretType.TOKEN
:param token:
The token to use for authentication.
"""
super().__init__(SecretType.TOKEN)
self._token = token
if len(token) == 0:
if len(self._token) == 0:
raise ValueError("Authentication token cannot be empty.")
def _to_dict(self) -> Dict[str, Any]:
@ -147,8 +143,12 @@ class TokenSecret(Secret):
def resolve_value(self) -> Optional[Any]:
return self._token
@property
def type(self) -> SecretType:
return self._type
@dataclass
@dataclass(frozen=True)
class EnvVarSecret(Secret):
"""
A secret that accepts one or more environment variables.
@ -156,32 +156,23 @@ class EnvVarSecret(Secret):
environment variable that is set. Can be serialized.
"""
_env_vars: List[str]
_strict: bool
_env_vars: Tuple[str, ...]
_strict: bool = True
_type: SecretType = SecretType.ENV_VAR
def __init__(self, env_vars: List[str], *, strict: bool = True):
"""
Create an environment variable secret.
def __post_init__(self):
super().__init__()
assert self._type == SecretType.ENV_VAR
:param env_vars:
Ordered list of candidate environment variables.
:param strict:
Whether to raise an exception if none of the environment
variables are set.
"""
super().__init__(SecretType.ENV_VAR)
self._env_vars = list(env_vars)
self._strict = strict
if len(env_vars) == 0:
if len(self._env_vars) == 0:
raise ValueError("One or more environment variables must be provided for the secret.")
def _to_dict(self) -> Dict[str, Any]:
return {"env_vars": self._env_vars, "strict": self._strict}
return {"env_vars": list(self._env_vars), "strict": self._strict}
@staticmethod
def _from_dict(dict: Dict[str, Any]) -> "Secret":
return EnvVarSecret(dict["env_vars"], strict=dict["strict"])
return EnvVarSecret(tuple(dict["env_vars"]), _strict=dict["strict"])
def resolve_value(self) -> Optional[Any]:
out = None
@ -194,6 +185,10 @@ class EnvVarSecret(Secret):
raise ValueError(f"None of the following authentication environment variables are set: {self._env_vars}")
return out
@property
def type(self) -> SecretType:
return self._type
def deserialize_secrets_inplace(data: Dict[str, Any], keys: Iterable[str], *, recursive: bool = False):
"""

View File

@ -3,6 +3,7 @@ import os
import pytest
from haystack.utils.auth import Secret, EnvVarSecret, SecretType, TokenSecret
from dataclasses import FrozenInstanceError
def test_secret_type():
@ -15,7 +16,7 @@ def test_secret_type():
def test_token_secret():
secret = Secret.from_token("test-token")
assert secret._type == SecretType.TOKEN
assert secret.type == SecretType.TOKEN
assert isinstance(secret, TokenSecret)
assert secret._token == "test-token"
assert secret.resolve_value() == "test-token"
@ -26,14 +27,19 @@ def test_token_secret():
with pytest.raises(ValueError, match="cannot be empty"):
Secret.from_token("")
with pytest.raises(FrozenInstanceError):
secret._token = "abba"
with pytest.raises(FrozenInstanceError):
secret._type = SecretType.ENV_VAR
def test_env_var_secret():
secret = Secret.from_env_var("TEST_ENV_VAR1")
os.environ["TEST_ENV_VAR1"] = "test-token"
assert secret._type == SecretType.ENV_VAR
assert secret.type == SecretType.ENV_VAR
assert isinstance(secret, EnvVarSecret)
assert secret._env_vars == ["TEST_ENV_VAR1"]
assert secret._env_vars == ("TEST_ENV_VAR1",)
assert secret._strict is True
assert secret.resolve_value() == "test-token"
@ -46,7 +52,7 @@ def test_env_var_secret():
assert secret.resolve_value() == None
secret = Secret.from_env_var(["TEST_ENV_VAR2", "TEST_ENV_VAR1"], strict=True)
assert secret._env_vars == ["TEST_ENV_VAR2", "TEST_ENV_VAR1"]
assert secret._env_vars == ("TEST_ENV_VAR2", "TEST_ENV_VAR1")
with pytest.raises(ValueError, match="None of the following .* variables are set"):
secret.resolve_value()
os.environ["TEST_ENV_VAR1"] = "test-token-2"
@ -61,3 +67,10 @@ def test_env_var_secret():
assert (
Secret.from_dict({"type": "env_var", "env_vars": ["TEST_ENV_VAR2", "TEST_ENV_VAR1"], "strict": True}) == secret
)
with pytest.raises(FrozenInstanceError):
secret._env_vars = ("A", "B")
with pytest.raises(FrozenInstanceError):
secret._strict = False
with pytest.raises(FrozenInstanceError):
secret._type = SecretType.TOKEN