diff --git a/haystack/utils/auth.py b/haystack/utils/auth.py index d4f62842a..88d70dd48 100644 --- a/haystack/utils/auth.py +++ b/haystack/utils/auth.py @@ -1,6 +1,6 @@ from enum import Enum import os -from typing import Any, Dict, Iterable, List, Optional, Union +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from dataclasses import dataclass from abc import ABC, abstractmethod @@ -21,18 +21,11 @@ class SecretType(Enum): return type -@dataclass class Secret(ABC): """ Encapsulates a secret used for authentication. """ - _type: SecretType - - def __init__(self, type: SecretType): - super().__init__() - self._type = type - @staticmethod def from_token(token: str) -> "Secret": """ @@ -41,7 +34,7 @@ class Secret(ABC): :param token: The token to use for authentication. """ - return TokenSecret(token) + return TokenSecret(_token=token) @staticmethod def from_env_var(env_vars: Union[str, List[str]], *, strict: bool = True) -> "Secret": @@ -60,7 +53,7 @@ class Secret(ABC): """ if isinstance(env_vars, str): env_vars = [env_vars] - return EnvVarSecret(env_vars, strict=strict) + return EnvVarSecret(_env_vars=tuple(env_vars), _strict=strict) def to_dict(self) -> Dict[str, Any]: """ @@ -70,7 +63,7 @@ class Secret(ABC): :returns: The serialized policy. """ - out = {"type": self._type.value} + out = {"type": self.type.value} inner = self._to_dict() assert all(k not in inner for k in out.keys()) out.update(inner) @@ -101,6 +94,14 @@ class Secret(ABC): """ pass + @property + @abstractmethod + def type(self) -> SecretType: + """ + The type of the secret. + """ + pass + @abstractmethod def _to_dict(self) -> Dict[str, Any]: pass @@ -111,7 +112,7 @@ class Secret(ABC): pass -@dataclass +@dataclass(frozen=True) class TokenSecret(Secret): """ A secret that uses a string token/API key. @@ -119,18 +120,13 @@ class TokenSecret(Secret): """ _token: str + _type: SecretType = SecretType.TOKEN - def __init__(self, token: str): - """ - Create a token secret. + def __post_init__(self): + super().__init__() + assert self._type == SecretType.TOKEN - :param token: - The token to use for authentication. - """ - super().__init__(SecretType.TOKEN) - self._token = token - - if len(token) == 0: + if len(self._token) == 0: raise ValueError("Authentication token cannot be empty.") def _to_dict(self) -> Dict[str, Any]: @@ -147,8 +143,12 @@ class TokenSecret(Secret): def resolve_value(self) -> Optional[Any]: return self._token + @property + def type(self) -> SecretType: + return self._type -@dataclass + +@dataclass(frozen=True) class EnvVarSecret(Secret): """ A secret that accepts one or more environment variables. @@ -156,32 +156,23 @@ class EnvVarSecret(Secret): environment variable that is set. Can be serialized. """ - _env_vars: List[str] - _strict: bool + _env_vars: Tuple[str, ...] + _strict: bool = True + _type: SecretType = SecretType.ENV_VAR - def __init__(self, env_vars: List[str], *, strict: bool = True): - """ - Create an environment variable secret. + def __post_init__(self): + super().__init__() + assert self._type == SecretType.ENV_VAR - :param env_vars: - Ordered list of candidate environment variables. - :param strict: - Whether to raise an exception if none of the environment - variables are set. - """ - super().__init__(SecretType.ENV_VAR) - self._env_vars = list(env_vars) - self._strict = strict - - if len(env_vars) == 0: + if len(self._env_vars) == 0: raise ValueError("One or more environment variables must be provided for the secret.") def _to_dict(self) -> Dict[str, Any]: - return {"env_vars": self._env_vars, "strict": self._strict} + return {"env_vars": list(self._env_vars), "strict": self._strict} @staticmethod def _from_dict(dict: Dict[str, Any]) -> "Secret": - return EnvVarSecret(dict["env_vars"], strict=dict["strict"]) + return EnvVarSecret(tuple(dict["env_vars"]), _strict=dict["strict"]) def resolve_value(self) -> Optional[Any]: out = None @@ -194,6 +185,10 @@ class EnvVarSecret(Secret): raise ValueError(f"None of the following authentication environment variables are set: {self._env_vars}") return out + @property + def type(self) -> SecretType: + return self._type + def deserialize_secrets_inplace(data: Dict[str, Any], keys: Iterable[str], *, recursive: bool = False): """ diff --git a/test/utils/test_auth.py b/test/utils/test_auth.py index c9221aad9..1c02e8ef8 100644 --- a/test/utils/test_auth.py +++ b/test/utils/test_auth.py @@ -3,6 +3,7 @@ import os import pytest from haystack.utils.auth import Secret, EnvVarSecret, SecretType, TokenSecret +from dataclasses import FrozenInstanceError def test_secret_type(): @@ -15,7 +16,7 @@ def test_secret_type(): def test_token_secret(): secret = Secret.from_token("test-token") - assert secret._type == SecretType.TOKEN + assert secret.type == SecretType.TOKEN assert isinstance(secret, TokenSecret) assert secret._token == "test-token" assert secret.resolve_value() == "test-token" @@ -26,14 +27,19 @@ def test_token_secret(): with pytest.raises(ValueError, match="cannot be empty"): Secret.from_token("") + with pytest.raises(FrozenInstanceError): + secret._token = "abba" + with pytest.raises(FrozenInstanceError): + secret._type = SecretType.ENV_VAR + def test_env_var_secret(): secret = Secret.from_env_var("TEST_ENV_VAR1") os.environ["TEST_ENV_VAR1"] = "test-token" - assert secret._type == SecretType.ENV_VAR + assert secret.type == SecretType.ENV_VAR assert isinstance(secret, EnvVarSecret) - assert secret._env_vars == ["TEST_ENV_VAR1"] + assert secret._env_vars == ("TEST_ENV_VAR1",) assert secret._strict is True assert secret.resolve_value() == "test-token" @@ -46,7 +52,7 @@ def test_env_var_secret(): assert secret.resolve_value() == None secret = Secret.from_env_var(["TEST_ENV_VAR2", "TEST_ENV_VAR1"], strict=True) - assert secret._env_vars == ["TEST_ENV_VAR2", "TEST_ENV_VAR1"] + assert secret._env_vars == ("TEST_ENV_VAR2", "TEST_ENV_VAR1") with pytest.raises(ValueError, match="None of the following .* variables are set"): secret.resolve_value() os.environ["TEST_ENV_VAR1"] = "test-token-2" @@ -61,3 +67,10 @@ def test_env_var_secret(): assert ( Secret.from_dict({"type": "env_var", "env_vars": ["TEST_ENV_VAR2", "TEST_ENV_VAR1"], "strict": True}) == secret ) + + with pytest.raises(FrozenInstanceError): + secret._env_vars = ("A", "B") + with pytest.raises(FrozenInstanceError): + secret._strict = False + with pytest.raises(FrozenInstanceError): + secret._type = SecretType.TOKEN