feat(secret): FileSecretStore and EnvironmentSecretStore (#14882)

This commit is contained in:
Sergio Gómez Villamor 2025-09-30 09:30:56 +02:00 committed by GitHub
parent 795a6828e8
commit e9e18e4705
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 516 additions and 0 deletions

View File

@ -65,3 +65,6 @@ class DataHubSecretStore(SecretStore):
def create(cls, config: Any) -> "DataHubSecretStore":
config = DataHubSecretStoreConfig.parse_obj(config)
return cls(config)
def close(self) -> None:
self.client.graph.close()

View File

@ -0,0 +1,29 @@
import os
from typing import Dict, List, Union
from datahub.secret.secret_store import SecretStore
# Simple SecretStore implementation that fetches Secret values from the local environment.
class EnvironmentSecretStore(SecretStore):
def __init__(self, config):
pass
def close(self) -> None:
return
def get_secret_values(self, secret_names: List[str]) -> Dict[str, Union[str, None]]:
values = {}
for secret_name in secret_names:
values[secret_name] = os.getenv(secret_name)
return values
def get_secret_value(self, secret_name: str) -> Union[str, None]:
return os.getenv(secret_name)
def get_id(self) -> str:
return "env"
@classmethod
def create(cls, config: Dict) -> "EnvironmentSecretStore":
return cls(config)

View File

@ -0,0 +1,49 @@
import logging
import os
from typing import Any, Dict, List, Union
from pydantic import BaseModel
from datahub.secret.secret_store import SecretStore
logger = logging.getLogger(__name__)
class FileSecretStoreConfig(BaseModel):
basedir: str = "/mnt/secrets"
max_length: int = 1024768
# Simple SecretStore implementation that fetches Secret values from the local files.
class FileSecretStore(SecretStore):
def __init__(self, config: FileSecretStoreConfig):
self.config = config
def get_secret_values(self, secret_names: List[str]) -> Dict[str, Union[str, None]]:
values = {}
for secret_name in secret_names:
values[secret_name] = self.get_secret_value(secret_name)
return values
def get_secret_value(self, secret_name: str) -> Union[str, None]:
secret_path = os.path.join(self.config.basedir, secret_name)
if os.path.exists(secret_path):
with open(secret_path, "r") as f:
secret_value = f.read(self.config.max_length + 1)
if len(secret_value) > self.config.max_length:
logger.warning(
f"Secret {secret_name} is longer than {self.config.max_length} and will be truncated."
)
return secret_value[: self.config.max_length].rstrip()
return None
def get_id(self) -> str:
return "file"
def close(self) -> None:
return
@classmethod
def create(cls, config: Any) -> "FileSecretStore":
config = FileSecretStoreConfig.parse_obj(config)
return cls(config)

View File

@ -0,0 +1,177 @@
from unittest.mock import Mock, patch
import pytest
from datahub.ingestion.graph.client import DataHubGraph
from datahub.ingestion.graph.config import DatahubClientConfig
from datahub.secret.datahub_secret_store import (
DataHubSecretStore,
DataHubSecretStoreConfig,
)
from datahub.secret.datahub_secrets_client import DataHubSecretsClient
class TestDataHubSecretStore:
def test_init_with_graph_client(self):
mock_graph = Mock(spec=DataHubGraph)
mock_graph.test_connection.return_value = True
config = DataHubSecretStoreConfig(graph_client=mock_graph)
store = DataHubSecretStore(config)
assert store.client is not None
assert isinstance(store.client, DataHubSecretsClient)
mock_graph.test_connection.assert_called_once()
def test_init_with_graph_client_config(self):
mock_client_config = Mock(spec=DatahubClientConfig)
with patch(
"datahub.secret.datahub_secret_store.DataHubGraph"
) as mock_graph_class:
mock_graph = Mock(spec=DataHubGraph)
mock_graph_class.return_value = mock_graph
config = DataHubSecretStoreConfig(graph_client_config=mock_client_config)
store = DataHubSecretStore(config)
assert store.client is not None
mock_graph_class.assert_called_once_with(mock_client_config)
def test_init_with_no_config_raises_exception(self):
config = DataHubSecretStoreConfig()
with pytest.raises(Exception, match="Invalid configuration provided"):
DataHubSecretStore(config)
def test_get_secret_values_success(self):
mock_graph = Mock(spec=DataHubGraph)
mock_graph.test_connection.return_value = True
expected_secrets = {"secret1": "value1", "secret2": "value2"}
with patch(
"datahub.secret.datahub_secret_store.DataHubSecretsClient"
) as mock_client_class:
mock_client = Mock(spec=DataHubSecretsClient)
mock_client.get_secret_values.return_value = expected_secrets
mock_client_class.return_value = mock_client
config = DataHubSecretStoreConfig(graph_client=mock_graph)
store = DataHubSecretStore(config)
result = store.get_secret_values(["secret1", "secret2"])
assert result == expected_secrets
mock_client.get_secret_values.assert_called_once_with(
["secret1", "secret2"]
)
def test_get_secret_values_exception_handling(self):
mock_graph = Mock(spec=DataHubGraph)
mock_graph.test_connection.return_value = True
with patch(
"datahub.secret.datahub_secret_store.DataHubSecretsClient"
) as mock_client_class:
mock_client = Mock(spec=DataHubSecretsClient)
mock_client.get_secret_values.side_effect = Exception("Connection failed")
mock_client_class.return_value = mock_client
config = DataHubSecretStoreConfig(graph_client=mock_graph)
store = DataHubSecretStore(config)
with patch("datahub.secret.datahub_secret_store.logger") as mock_logger:
result = store.get_secret_values(["secret1"])
assert result == {}
mock_logger.exception.assert_called_once()
def test_get_secret_value(self):
mock_graph = Mock(spec=DataHubGraph)
mock_graph.test_connection.return_value = True
with patch(
"datahub.secret.datahub_secret_store.DataHubSecretsClient"
) as mock_client_class:
mock_client = Mock(spec=DataHubSecretsClient)
mock_client.get_secret_values.return_value = {"secret1": "value1"}
mock_client_class.return_value = mock_client
config = DataHubSecretStoreConfig(graph_client=mock_graph)
store = DataHubSecretStore(config)
result = store.get_secret_value("secret1")
assert result == "value1"
mock_client.get_secret_values.assert_called_once_with(["secret1"])
def test_get_secret_value_not_found(self):
mock_graph = Mock(spec=DataHubGraph)
mock_graph.test_connection.return_value = True
with patch(
"datahub.secret.datahub_secret_store.DataHubSecretsClient"
) as mock_client_class:
mock_client = Mock(spec=DataHubSecretsClient)
mock_client.get_secret_values.return_value = {}
mock_client_class.return_value = mock_client
config = DataHubSecretStoreConfig(graph_client=mock_graph)
store = DataHubSecretStore(config)
result = store.get_secret_value("nonexistent")
assert result is None
def test_get_id(self):
mock_graph = Mock(spec=DataHubGraph)
mock_graph.test_connection.return_value = True
config = DataHubSecretStoreConfig(graph_client=mock_graph)
store = DataHubSecretStore(config)
assert store.get_id() == "datahub"
def test_create_classmethod(self):
mock_graph = Mock(spec=DataHubGraph)
mock_graph.test_connection.return_value = True
config_dict = {"graph_client": mock_graph}
store = DataHubSecretStore.create(config_dict)
assert isinstance(store, DataHubSecretStore)
assert store.client is not None
def test_close(self):
mock_graph = Mock(spec=DataHubGraph)
mock_graph.test_connection.return_value = True
with patch(
"datahub.secret.datahub_secret_store.DataHubSecretsClient"
) as mock_client_class:
mock_client = Mock(spec=DataHubSecretsClient)
mock_client.graph = mock_graph
mock_client_class.return_value = mock_client
config = DataHubSecretStoreConfig(graph_client=mock_graph)
store = DataHubSecretStore(config)
store.close()
mock_graph.close.assert_called_once()
def test_config_validator_with_working_connection(self):
mock_graph = Mock(spec=DataHubGraph)
mock_graph.test_connection.return_value = True
config = DataHubSecretStoreConfig(graph_client=mock_graph)
assert config.graph_client == mock_graph
mock_graph.test_connection.assert_called_once()
def test_config_validator_with_none_graph_client(self):
config = DataHubSecretStoreConfig(graph_client=None)
assert config.graph_client is None

View File

@ -0,0 +1,83 @@
import os
from unittest.mock import patch
from datahub.secret.environment_secret_store import EnvironmentSecretStore
class TestEnvironmentSecretStore:
def test_init(self):
config: dict = {}
store = EnvironmentSecretStore(config)
assert store is not None
def test_get_secret_values_with_existing_env_vars(self):
store = EnvironmentSecretStore({})
with patch.dict(os.environ, {"SECRET1": "value1", "SECRET2": "value2"}):
result = store.get_secret_values(["SECRET1", "SECRET2"])
assert result == {"SECRET1": "value1", "SECRET2": "value2"}
def test_get_secret_values_with_missing_env_vars(self):
store = EnvironmentSecretStore({})
with patch.dict(os.environ, {}, clear=True):
result = store.get_secret_values(["NONEXISTENT1", "NONEXISTENT2"])
assert result == {"NONEXISTENT1": None, "NONEXISTENT2": None}
def test_get_secret_values_mixed_existing_and_missing(self):
store = EnvironmentSecretStore({})
with patch.dict(os.environ, {"SECRET1": "value1"}, clear=True):
result = store.get_secret_values(["SECRET1", "NONEXISTENT"])
assert result == {"SECRET1": "value1", "NONEXISTENT": None}
def test_get_secret_value_existing(self):
store = EnvironmentSecretStore({})
with patch.dict(os.environ, {"SECRET1": "value1"}):
result = store.get_secret_value("SECRET1")
assert result == "value1"
def test_get_secret_value_nonexistent(self):
store = EnvironmentSecretStore({})
with patch.dict(os.environ, {}, clear=True):
result = store.get_secret_value("NONEXISTENT")
assert result is None
def test_get_secret_value_empty_string(self):
store = EnvironmentSecretStore({})
with patch.dict(os.environ, {"EMPTY_SECRET": ""}):
result = store.get_secret_value("EMPTY_SECRET")
assert result == ""
def test_get_id(self):
store = EnvironmentSecretStore({})
assert store.get_id() == "env"
def test_create_classmethod(self):
config = {"some_key": "some_value"}
store = EnvironmentSecretStore.create(config)
assert isinstance(store, EnvironmentSecretStore)
def test_get_secret_values_empty_list(self):
store = EnvironmentSecretStore({})
result = store.get_secret_values([])
assert result == {}
def test_get_secret_values_with_special_characters(self):
store = EnvironmentSecretStore({})
with patch.dict(os.environ, {"SECRET_WITH_SPECIAL": "value!@#$%^&*()"}):
result = store.get_secret_values(["SECRET_WITH_SPECIAL"])
assert result == {"SECRET_WITH_SPECIAL": "value!@#$%^&*()"}

View File

@ -0,0 +1,175 @@
import os
import tempfile
from unittest.mock import patch
from datahub.secret.file_secret_store import FileSecretStore, FileSecretStoreConfig
class TestFileSecretStore:
def test_init_with_default_config(self):
config = FileSecretStoreConfig()
store = FileSecretStore(config)
assert store.config.basedir == "/mnt/secrets"
assert store.config.max_length == 1024768
def test_init_with_custom_config(self):
config = FileSecretStoreConfig(basedir="/custom/path", max_length=512)
store = FileSecretStore(config)
assert store.config.basedir == "/custom/path"
assert store.config.max_length == 512
def test_get_secret_value_file_exists(self):
with tempfile.TemporaryDirectory() as temp_dir:
# Create a test secret file
secret_file = os.path.join(temp_dir, "test_secret")
with open(secret_file, "w") as f:
f.write("secret_value")
config = FileSecretStoreConfig(basedir=temp_dir)
store = FileSecretStore(config)
result = store.get_secret_value("test_secret")
assert result == "secret_value"
def test_get_secret_value_file_not_exists(self):
with tempfile.TemporaryDirectory() as temp_dir:
config = FileSecretStoreConfig(basedir=temp_dir)
store = FileSecretStore(config)
result = store.get_secret_value("nonexistent_secret")
assert result is None
def test_get_secret_value_with_trailing_whitespace(self):
with tempfile.TemporaryDirectory() as temp_dir:
secret_file = os.path.join(temp_dir, "test_secret")
with open(secret_file, "w") as f:
f.write("secret_value\n\t ")
config = FileSecretStoreConfig(basedir=temp_dir)
store = FileSecretStore(config)
result = store.get_secret_value("test_secret")
assert result == "secret_value"
def test_get_secret_value_exceeds_max_length(self):
with tempfile.TemporaryDirectory() as temp_dir:
secret_file = os.path.join(temp_dir, "large_secret")
large_content = "a" * 100
with open(secret_file, "w") as f:
f.write(large_content)
config = FileSecretStoreConfig(basedir=temp_dir, max_length=50)
store = FileSecretStore(config)
with patch("datahub.secret.file_secret_store.logger") as mock_logger:
result = store.get_secret_value("large_secret")
assert result == "a" * 50
mock_logger.warning.assert_called_once()
assert "longer than 50" in mock_logger.warning.call_args[0][0]
def test_get_secret_values(self):
with tempfile.TemporaryDirectory() as temp_dir:
# Create test secret files
for i, content in enumerate(["value1", "value2"], 1):
secret_file = os.path.join(temp_dir, f"secret{i}")
with open(secret_file, "w") as f:
f.write(content)
config = FileSecretStoreConfig(basedir=temp_dir)
store = FileSecretStore(config)
result = store.get_secret_values(["secret1", "secret2", "nonexistent"])
assert result == {
"secret1": "value1",
"secret2": "value2",
"nonexistent": None,
}
def test_get_secret_values_empty_list(self):
with tempfile.TemporaryDirectory() as temp_dir:
config = FileSecretStoreConfig(basedir=temp_dir)
store = FileSecretStore(config)
result = store.get_secret_values([])
assert result == {}
def test_get_id(self):
config = FileSecretStoreConfig()
store = FileSecretStore(config)
assert store.get_id() == "file"
def test_close(self):
config = FileSecretStoreConfig()
store = FileSecretStore(config)
# Should not raise an exception
store.close()
def test_create_classmethod(self):
config_dict = {"basedir": "/test/path", "max_length": 2048}
store = FileSecretStore.create(config_dict)
assert isinstance(store, FileSecretStore)
assert store.config.basedir == "/test/path"
assert store.config.max_length == 2048
def test_create_classmethod_with_invalid_config(self):
config_dict = {"invalid_field": "value"}
# Pydantic will ignore unknown fields by default, so this creates a store with defaults
store = FileSecretStore.create(config_dict)
assert isinstance(store, FileSecretStore)
assert store.config.basedir == "/mnt/secrets" # Default value
assert store.config.max_length == 1024768 # Default value
def test_get_secret_value_empty_file(self):
with tempfile.TemporaryDirectory() as temp_dir:
secret_file = os.path.join(temp_dir, "empty_secret")
with open(secret_file, "w") as f:
f.write("")
config = FileSecretStoreConfig(basedir=temp_dir)
store = FileSecretStore(config)
result = store.get_secret_value("empty_secret")
assert result == ""
def test_get_secret_value_exactly_max_length(self):
with tempfile.TemporaryDirectory() as temp_dir:
secret_file = os.path.join(temp_dir, "exact_length_secret")
content = "a" * 100
with open(secret_file, "w") as f:
f.write(content)
config = FileSecretStoreConfig(basedir=temp_dir, max_length=100)
store = FileSecretStore(config)
with patch("datahub.secret.file_secret_store.logger") as mock_logger:
result = store.get_secret_value("exact_length_secret")
assert result == content
# Should not log warning for exact length
mock_logger.warning.assert_not_called()
def test_file_secret_store_config_defaults(self):
config = FileSecretStoreConfig()
assert config.basedir == "/mnt/secrets"
assert config.max_length == 1024768
def test_file_secret_store_config_custom_values(self):
config = FileSecretStoreConfig(basedir="/custom", max_length=512)
assert config.basedir == "/custom"
assert config.max_length == 512