diff --git a/ingestion/src/metadata/data_quality/validations/runtime_param_setter/table_diff_params_setter.py b/ingestion/src/metadata/data_quality/validations/runtime_param_setter/table_diff_params_setter.py index 1ddfcd1bba6..734f866c443 100644 --- a/ingestion/src/metadata/data_quality/validations/runtime_param_setter/table_diff_params_setter.py +++ b/ingestion/src/metadata/data_quality/validations/runtime_param_setter/table_diff_params_setter.py @@ -74,7 +74,7 @@ class TableDiffParamsSetter(RuntimeParameterSetter): - table2: The table path for the second service (only schema and table name) - keyColumns: If not defined, construct the key columns based on primary key or unique constraint. - extraColumns: If not defined, construct the extra columns as all columns except the key columns. - - whereClause: Exrtact where clause based on partitioning and user input + - whereClause: Extract where clause based on partitioning and user input """ def __init__( diff --git a/ingestion/src/metadata/data_quality/validations/table/sqlalchemy/tableDiff.py b/ingestion/src/metadata/data_quality/validations/table/sqlalchemy/tableDiff.py index f4ca09fa356..1d1d137f2ee 100644 --- a/ingestion/src/metadata/data_quality/validations/table/sqlalchemy/tableDiff.py +++ b/ingestion/src/metadata/data_quality/validations/table/sqlalchemy/tableDiff.py @@ -55,6 +55,7 @@ from metadata.profiler.orm.functions.md5 import MD5 from metadata.profiler.orm.functions.substr import Substr from metadata.profiler.orm.registry import Dialects, PythonDialects from metadata.utils.collections import CaseInsensitiveList +from metadata.utils.credentials import normalize_pem_string from metadata.utils.logger import test_suite_logger logger = test_suite_logger() @@ -283,7 +284,9 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin): self.runtime_params.table1.key_columns, extra_columns=self.runtime_params.extraColumns, case_sensitive=self.get_case_sensitive(), - key_content=self.runtime_params.table1.privateKey.get_secret_value() + key_content=normalize_pem_string( + self.runtime_params.table1.privateKey.get_secret_value() + ) if self.runtime_params.table1.privateKey else None, private_key_passphrase=self.runtime_params.table1.passPhrase.get_secret_value() @@ -296,7 +299,9 @@ class TableDiffValidator(BaseTestValidator, SQAValidatorMixin): self.runtime_params.table2.key_columns, extra_columns=self.runtime_params.extraColumns, case_sensitive=self.get_case_sensitive(), - key_content=self.runtime_params.table2.privateKey.get_secret_value() + key_content=normalize_pem_string( + self.runtime_params.table2.privateKey.get_secret_value() + ) if self.runtime_params.table2.privateKey else None, private_key_passphrase=self.runtime_params.table2.passPhrase.get_secret_value() diff --git a/ingestion/src/metadata/ingestion/source/database/snowflake/connection.py b/ingestion/src/metadata/ingestion/source/database/snowflake/connection.py index 3bdab7a39f4..1a1d674f43a 100644 --- a/ingestion/src/metadata/ingestion/source/database/snowflake/connection.py +++ b/ingestion/src/metadata/ingestion/source/database/snowflake/connection.py @@ -53,6 +53,7 @@ from metadata.ingestion.source.database.snowflake.queries import ( SNOWFLAKE_TEST_GET_VIEWS, ) from metadata.utils.constants import THREE_MIN +from metadata.utils.credentials import normalize_pem_string from metadata.utils.filters import filter_by_database from metadata.utils.logger import ingestion_logger @@ -173,8 +174,16 @@ class SnowflakeConnection(BaseConnection[SnowflakeConnectionConfig, Engine]): logger.warning( "Snowflake Private Key Passphrase not found, replacing it with empty string" ) + + encrypted_private_key = normalize_pem_string( + connection.privateKey.get_secret_value() + ) + p_key = serialization.load_pem_private_key( - bytes(connection.privateKey.get_secret_value(), "utf-8"), + bytes( + encrypted_private_key, + "utf-8", + ), password=snowflake_private_key_passphrase.encode() or None, backend=default_backend(), ) diff --git a/ingestion/src/metadata/utils/credentials.py b/ingestion/src/metadata/utils/credentials.py index 45a4caf1a97..d590fa94d13 100644 --- a/ingestion/src/metadata/utils/credentials.py +++ b/ingestion/src/metadata/utils/credentials.py @@ -73,6 +73,43 @@ def validate_private_key(private_key: str) -> None: raise InvalidPrivateKeyException(msg) from err +def normalize_pem_string(value: str) -> str: + """ + Normalize a PEM-encoded private key, public key, or certificate string. + + This covers edge cases where getting private keys from the server end up with + escaped newlines for whatever reason. e.g: private key came from a JSON response like + + `{"private_key": "-----BEGIN PRIVATE KEY-----\\nABC\\n-----END PRIVATE KEY-----"}` + + - If the string looks like a PEM (contains BEGIN/END headers) + and has literal '\\n' sequences instead of real newlines, + convert them to real newlines. + - Otherwise, return the string unchanged. + + Example: + >>> normalize_pem_string("-----BEGIN PRIVATE KEY-----\\nABC\\n-----END PRIVATE KEY-----") + '-----BEGIN PRIVATE KEY-----\nABC\n-----END PRIVATE KEY-----' + """ + if not isinstance(value, str): + return value + + pem_headers = ( + "-----BEGIN RSA PRIVATE KEY-----", + "-----BEGIN ENCRYPTED PRIVATE KEY-----", + "-----BEGIN PRIVATE KEY-----", + "-----BEGIN OPENSSH PRIVATE KEY-----", + "-----BEGIN CERTIFICATE-----", + ) + + # Only normalize if it looks like PEM and is all on one line (escaped newlines) + if any(h in value for h in pem_headers): + if "\\n" in value and "\n" not in value: + return value.replace("\\n", "\n") + + return value + + def create_credential_tmp_file(credentials: dict) -> str: """ Given a credentials' dict, store it in a tmp file diff --git a/ingestion/tests/cli_e2e/base/test_cli.py b/ingestion/tests/cli_e2e/base/test_cli.py index a25f235ef1b..0e3a2df8478 100644 --- a/ingestion/tests/cli_e2e/base/test_cli.py +++ b/ingestion/tests/cli_e2e/base/test_cli.py @@ -17,11 +17,14 @@ import re import subprocess from abc import ABC, abstractmethod from ast import literal_eval +from copy import deepcopy from pathlib import Path +from typing import Any, Optional import yaml from metadata.config.common import load_config_file +from metadata.generated.schema.entity.teams.user import AuthenticationMechanism, User from metadata.ingestion.api.status import Status from metadata.ingestion.ometa.ometa_api import OpenMetadata from metadata.utils.constants import UTF_8 @@ -43,6 +46,7 @@ class CliBase(ABC): openmetadata: OpenMetadata test_file_path: str config_file_path: str + ingestion_bot_jwt_token: Optional[str] = None def run_command(self, command: str = "ingest", test_file_path=None) -> str: file_path = ( @@ -70,11 +74,34 @@ class CliBase(ABC): f"/lineage/table/name/{entity_fqn}?upstreamDepth=3&downstreamDepth=3" ) + @classmethod + def set_ingestion_bot_jwt_token(cls) -> None: + ingestion_bot: User = cls.openmetadata.get_by_name(User, "ingestion-bot") + ingestion_bot_auth: AuthenticationMechanism = cls.openmetadata.get_by_id( + AuthenticationMechanism, ingestion_bot.id + ) + cls.ingestion_bot_jwt_token = ( + ingestion_bot_auth.config.JWTToken.get_secret_value() + ) + + def patch_server_security_config(self, config: dict[str, Any]) -> dict[str, Any]: + if self.ingestion_bot_jwt_token is None: + return config + + server_config = deepcopy(config) + server_config["workflowConfig"]["openMetadataServerConfig"][ + "securityConfig" + ] = { + "jwtToken": self.ingestion_bot_jwt_token, + } + return server_config + def build_config_file( self, test_type: E2EType = E2EType.INGEST, extra_args: dict = None ) -> None: config_yaml = load_config_file(Path(self.config_file_path)) config_yaml = self.build_yaml(config_yaml, test_type, extra_args) + config_yaml = self.patch_server_security_config(config_yaml) with open(self.test_file_path, "w", encoding=UTF_8) as test_file: yaml.dump(config_yaml, test_file) diff --git a/ingestion/tests/cli_e2e/common/test_cli_db.py b/ingestion/tests/cli_e2e/common/test_cli_db.py index 5a060b7a89f..b00a0ad9182 100644 --- a/ingestion/tests/cli_e2e/common/test_cli_db.py +++ b/ingestion/tests/cli_e2e/common/test_cli_db.py @@ -42,7 +42,10 @@ class CliCommonDB: connector, cls.get_test_type() ) cls.engine = workflow.source.engine + cls.openmetadata = workflow.source.metadata + cls.set_ingestion_bot_jwt_token() + cls.config_file_path = str( Path(PATH_TO_RESOURCES + f"/database/{connector}/{connector}.yaml") ) diff --git a/ingestion/tests/unit/data_quality/validations/runtime_param_setter/test_table_diff_params_setter.py b/ingestion/tests/unit/data_quality/validations/runtime_param_setter/test_table_diff_params_setter.py index 221f91b12ae..6f1d66b9e1c 100644 --- a/ingestion/tests/unit/data_quality/validations/runtime_param_setter/test_table_diff_params_setter.py +++ b/ingestion/tests/unit/data_quality/validations/runtime_param_setter/test_table_diff_params_setter.py @@ -23,6 +23,9 @@ from metadata.generated.schema.entity.data.table import ( from metadata.generated.schema.entity.services.connections.database.postgresConnection import ( PostgresConnection, ) +from metadata.generated.schema.entity.services.connections.database.snowflakeConnection import ( + SnowflakeConnection, +) from metadata.generated.schema.entity.services.databaseService import ( DatabaseConnection, DatabaseService, @@ -63,8 +66,8 @@ def metadata( @pytest.fixture -def service_connection_config() -> DatabaseConnection: - return create_autospec(DatabaseConnection, spec_set=True, instance=True) +def service_connection_config() -> PostgresConnection: + return PostgresConnection.model_construct() @pytest.fixture @@ -75,13 +78,13 @@ def sampler() -> SamplerInterface: @pytest.fixture -def service1() -> DatabaseService: +def service1(service_connection_config: PostgresConnection) -> DatabaseService: return DatabaseService.model_construct( id=uuid.uuid4(), name="TestService1", fullyQualifiedName="TestService1", serviceType=DatabaseServiceType.Postgres, - connection=DatabaseConnection(config=PostgresConnection.model_construct()), + connection=DatabaseConnection(config=service_connection_config), ) @@ -224,3 +227,73 @@ def test_setter_gets_per_table_key_columns( key_columns=["table_id"], ), ) + + +class TestForSnowflake: + @pytest.fixture + def service_connection_config(self) -> SnowflakeConnection: + return SnowflakeConnection( + account="account", + username="username", + warehouse="warehouse", + privateKey="-----BEGIN ENCRYPTED PRIVATE KEY-----\\nMIIFNTBfBgkqhkiG9w0BBQ0wUjAxBgkqhkiG9w0BBQwwJAQQkygnWhWG1aAiElog\\n0itnbwICCAAwDAYIKoZIhvcNAgkFADAdBglghkgBZQMEASoEEHwOOuGPCXoQiqPd\\ntg/fkPAEggTQqUObUeUhiBSJNVZ2hF5O/oK2glaT7gAsXG6FB56GD09KjdNE+KTk\\nuEMmQgKN1oYdlY6NVJ7zDak6a/fn79jWHN0jTEODJKoo+2sD4QvJxFqqxp008mYS\\n9HTJhlwmfM4cqCLaIbAvDG74s8+48Hq5n71gA91RdPHxtE/La91hOCS+UVRjAuXZ\\nJ2bEYuoWrP6FSTysIDNFhI3SshzrP+SJ7rGY1ahkhHu5kfActy1ATr9288vWKiHv\\n564GOq85Vt8QGcq6dM3vClKEAhljS35TMs2LlM3cP+sFCO4PYRaOtrH7ENuusaOU\\nvfEpo41W53uVP6hGMU8LuWzdDjVZUqNJdcnlAIdUkI8XG70IlMyGAna7Y5UyB5Xn\\nXlivhvvJSHly9pj00QWI5uiSY08cDDqvLmyg5Vmqr2lINfn5kMjtoeVF4T9UoxXc\\nLrCLQmYqhUYtBBJh7i1IxepqI69KaotgZsDwV5oJz2+GofVo/O0kXq3/JWvlpQ8o\\nkisZiWSpld/NFeJdxCuE308zdLb5D5aeJbcyHM0ldZ2+zH6+ERCs134bkEJFeBmC\\nmNJ++DPfZJGe6AWqM9qrBr2UZyZhLg5VV2MzDB3YBkI+FxSVnRZNu9WWreLw9+k5\\n6LJ4Mnwrw+jGdPBXf5sqEoCmem85N1IKtJMXl5BNHE8V2MZm+xPLRypoFH2ipNEL\\nBu8XFaqxe+4cTA/eYoyf0DEzGYY/x1PMy5y3EYJu7xhkkCjzX09ZKkM3EdycCEvY\\nAdIKhXdKphe88WBzDtssjBtEJGjgZkX5JioW0VrMOlQBXA3xS/vdRemBXwTM2Fmy\\nuZajbWQq1yBtlpKtRFF9Yj2QJinumjoiRCWIcNcEN6/V+5IETClzBOYgwpZHwSIv\\nZGbIXPHrmbk2GIJXtRXjnXGVIrcgUOJfrZmpvpBhpcbAIoRUwCj7iSgBMOhKth8Y\\nk3uc8ZTXDdKAayxo1USG87tWojeyu0rRJxCiu4WuAQgHnUYRpViOrPGO7msKPPhd\\nZLO+x428W2myXHw/ZsLZoM2AyK4h6M0m647L9+lbrurGkTHwDs35RuNeflyTvGkF\\nOfTN9xYgeBXi99TdLmo0G1giKqKp6Gq1h+iTXbqbqJiqS1wzS5duvLA53uojkHIC\\n2/fCnANUhMKtGUCyHZ8Lr6FLYQiBDmCQwq1buEKHLgA7uap6WNVLnSAvRmPWGwn3\\nmZxuVBBX2uDBkZgBbVE19kSAWjFjfAGr6+LCZpHHcUWP+LiV/Qpbbrg2j3xcI7d8\\ncwjON0uR7DU10i3gWncsPUCACs44O86OHVJTFUqrZAMjnSdXuSmiIHzTaOY0QhYn\\n/K35NknBplnD3bw89by0vfFbGsvm19jTawzLVhmGBLnQAB780vODdKjMKgUfzW9t\\nsDO2+gdo7vO5Ep6xh+UVzakAY+JD6Z0qDnM8KNURo9iku06Ctroyf7drHq5rqb3A\\nSLsYtMImlPbHLGX62lNqs9016h6QoDCazxW1Ef/B5/gnLfCeiW4rTMemZ6Nlzu+8\\ntDMxQrRpo5tGdhZgfiEIfFUlZTMJWmjHzZw5z4LYvxCKBPabUSxPSeuTi8ll2ljF\\n8fGq0P3vYJbQ0SIw20Srmqdoj1g3HJP4D+a0iUlMpr7wdkP3sAgy7so=\\n-----END ENCRYPTED PRIVATE KEY-----", + snowflakePrivatekeyPassphrase="passphrase", + ) + + @pytest.fixture + def service1( + self, service_connection_config: SnowflakeConnection + ) -> DatabaseService: + return DatabaseService.model_construct( + id=uuid.uuid4(), + name="TestService1", + fullyQualifiedName="TestService1", + serviceType=DatabaseServiceType.Snowflake, + connection=DatabaseConnection(config=service_connection_config), + ) + + @pytest.fixture + def service2( + self, service_connection_config: SnowflakeConnection + ) -> DatabaseService: + return DatabaseService.model_construct( + id=uuid.uuid4(), + name="TestService2", + fullyQualifiedName="TestService2", + serviceType=DatabaseServiceType.Snowflake, + connection=DatabaseConnection(config=service_connection_config), + ) + + @pytest.fixture + def setter( + self, + metadata: OpenMetadata, + service_connection_config: DatabaseConnection, + sampler: SamplerInterface, + table1: Table, + ) -> TableDiffParamsSetter: + return TableDiffParamsSetter( + ometa_client=metadata, + service_connection_config=service_connection_config, + sampler=sampler, + table_entity=table1, + ) + + def test_setter_gets_parameters_for_snowflake( + self, + setter: TableDiffParamsSetter, + parameter_values: List[TestCaseParameterValue], + ) -> None: + test_case = TestCase.model_construct( + parameterValues=[ + *parameter_values, + TestCaseParameterValue(name="keyColumns", value=json.dumps(["id"])), + TestCaseParameterValue( + name="table2.keyColumns", value=json.dumps(["table_id"]) + ), + ], + ) + + assert setter.get_parameters(test_case) == IsInstance( + TableDiffRuntimeParameters + ) diff --git a/ingestion/tests/unit/utils/test_credentials.py b/ingestion/tests/unit/utils/test_credentials.py new file mode 100644 index 00000000000..92e4f673a95 --- /dev/null +++ b/ingestion/tests/unit/utils/test_credentials.py @@ -0,0 +1,50 @@ +import pytest + +from metadata.utils.credentials import normalize_pem_string + + +def test_normalizes_escaped_newlines_for_pem(): + """It should replace literal '\\n' with actual newlines for PEM-like strings.""" + pem = "-----BEGIN PRIVATE KEY-----\\nABCDEF\\n-----END PRIVATE KEY-----" + result = normalize_pem_string(pem) + # Should contain actual newlines, not literal \n + assert "\n" in result + assert "\\n" not in result + # Should start/end correctly + assert result.startswith("-----BEGIN PRIVATE KEY-----") + assert result.endswith("-----END PRIVATE KEY-----") + + +def test_does_not_change_already_correct_pem(): + """It should leave PEMs with real newlines unchanged.""" + pem = "-----BEGIN PRIVATE KEY-----\nABCDEF\n-----END PRIVATE KEY-----" + result = normalize_pem_string(pem) + assert result == pem + + +def test_ignores_non_pem_strings(): + """It should not touch non-PEM strings, even if they contain '\\n'.""" + s = "password\\nwith\\nnewlines" + result = normalize_pem_string(s) + assert result == s # unchanged + + +def test_handles_other_pem_types(): + """It should detect and normalize other PEM headers like certificates.""" + cert = "-----BEGIN CERTIFICATE-----\\nXYZ\\n-----END CERTIFICATE-----" + result = normalize_pem_string(cert) + assert "\n" in result and "\\n" not in result + + +def test_mixed_case_is_left_unchanged(): + """If both literal and real newlines exist, don't double-convert.""" + mixed = "-----BEGIN PRIVATE KEY-----\\nABC\nDEF\\n-----END PRIVATE KEY-----" + result = normalize_pem_string(mixed) + # It should be left unchanged, since it has both kinds of newlines + assert result == mixed + + +@pytest.mark.parametrize("invalid", [None, 123, b"not a string"]) +def test_non_string_inputs_return_untouched(invalid): + """Non-string inputs should be returned as-is (no crash).""" + assert normalize_pem_string(invalid) == invalid