MINOR: Update Snowflake Connection (#22167)

* Update Snowflake Connection

* Extracting needed methods
This commit is contained in:
IceS2 2025-07-07 16:38:22 +02:00 committed by GitHub
parent d12d93e0ea
commit 1260a0600a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 211 additions and 191 deletions

View File

@ -26,7 +26,7 @@ from metadata.generated.schema.entity.automations.workflow import (
Workflow as AutomationWorkflow, Workflow as AutomationWorkflow,
) )
from metadata.generated.schema.entity.services.connections.database.snowflakeConnection import ( from metadata.generated.schema.entity.services.connections.database.snowflakeConnection import (
SnowflakeConnection, SnowflakeConnection as SnowflakeConnectionConfig,
) )
from metadata.generated.schema.entity.services.connections.testConnectionResult import ( from metadata.generated.schema.entity.services.connections.testConnectionResult import (
TestConnectionResult, TestConnectionResult,
@ -37,6 +37,7 @@ from metadata.ingestion.connections.builders import (
get_connection_options_dict, get_connection_options_dict,
init_empty_connection_arguments, init_empty_connection_arguments,
) )
from metadata.ingestion.connections.connection import BaseConnection
from metadata.ingestion.connections.test_connections import ( from metadata.ingestion.connections.test_connections import (
test_connection_engine_step, test_connection_engine_step,
test_connection_steps, test_connection_steps,
@ -58,171 +59,11 @@ logger = ingestion_logger()
class SnowflakeEngineWrapper(BaseModel): class SnowflakeEngineWrapper(BaseModel):
service_connection: SnowflakeConnection service_connection: SnowflakeConnectionConfig
engine: Any engine: Any
database_name: Optional[str] = None database_name: Optional[str] = None
def get_connection_url(connection: SnowflakeConnection) -> str:
"""
Set the connection URL
"""
url = f"{connection.scheme.value}://"
if connection.username:
url += f"{quote_plus(connection.username)}"
if not connection.password:
connection.password = SecretStr("")
url += (
f":{quote_plus(connection.password.get_secret_value())}"
if connection
else ""
)
url += "@"
url += connection.account
url += f"/{connection.database}" if connection.database else ""
options = get_connection_options_dict(connection)
if options:
if not connection.database:
url += "/"
params = "&".join(
f"{key}={quote_plus(value)}" for (key, value) in options.items() if value
)
url = f"{url}?{params}"
options = {
"account": connection.account,
"warehouse": connection.warehouse,
"role": connection.role,
}
params = "&".join(f"{key}={value}" for (key, value) in options.items() if value)
if params:
url = f"{url}?{params}"
return url
def get_connection(connection: SnowflakeConnection) -> Engine:
"""
Create connection
"""
if not connection.connectionArguments:
connection.connectionArguments = init_empty_connection_arguments()
if connection.privateKey:
snowflake_private_key_passphrase = (
connection.snowflakePrivatekeyPassphrase.get_secret_value()
if connection.snowflakePrivatekeyPassphrase
else ""
)
if not snowflake_private_key_passphrase:
logger.warning(
"Snowflake Private Key Passphrase not found, replacing it with empty string"
)
p_key = serialization.load_pem_private_key(
bytes(connection.privateKey.get_secret_value(), "utf-8"),
password=snowflake_private_key_passphrase.encode() or None,
backend=default_backend(),
)
pkb = p_key.private_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)
connection.connectionArguments.root["private_key"] = pkb
if connection.clientSessionKeepAlive:
connection.connectionArguments.root[
"client_session_keep_alive"
] = connection.clientSessionKeepAlive
engine = create_generic_db_connection(
connection=connection,
get_connection_url_fn=get_connection_url,
get_connection_args_fn=get_connection_args_common,
)
if connection.connectionArguments.root and connection.connectionArguments.root.get(
"private_key"
):
del connection.connectionArguments.root["private_key"]
return engine
def test_connection(
metadata: OpenMetadata,
engine: Engine,
service_connection: SnowflakeConnection,
automation_workflow: Optional[AutomationWorkflow] = None,
timeout_seconds: Optional[int] = THREE_MIN,
) -> TestConnectionResult:
"""
Test connection. This can be executed either as part
of a metadata workflow or during an Automation Workflow.
Note how we run a custom GetTables query:
The default inspector `get_table_names` runs a SHOW which
has a limit on 10000 rows in the result set:
https://github.com/open-metadata/OpenMetadata/issues/12798
This can cause errors if we are running tests against schemas
with more tables than that. There is no issues during the metadata
ingestion since in metadata.py we are overriding the default
`get_table_names` function with our custom queries.
"""
engine_wrapper = SnowflakeEngineWrapper(
service_connection=service_connection, engine=engine, database_name=None
)
test_fn = {
"CheckAccess": partial(test_connection_engine_step, engine),
"GetDatabases": partial(
test_query, statement=SNOWFLAKE_GET_DATABASES, engine=engine
),
"GetSchemas": partial(
execute_inspector_func, engine_wrapper, "get_schema_names"
),
"GetTables": partial(
test_table_query,
statement=SNOWFLAKE_TEST_GET_TABLES,
engine_wrapper=engine_wrapper,
),
"GetViews": partial(
test_table_query,
statement=SNOWFLAKE_TEST_GET_VIEWS,
engine_wrapper=engine_wrapper,
),
"GetStreams": partial(
test_table_query,
statement=SNOWFLAKE_TEST_GET_STREAMS,
engine_wrapper=engine_wrapper,
),
"GetQueries": partial(
test_query,
statement=SNOWFLAKE_TEST_GET_QUERIES.format(
account_usage=service_connection.accountUsageSchema
),
engine=engine,
),
"GetTags": partial(
test_query,
statement=SNOWFLAKE_TEST_FETCH_TAG.format(
account_usage=service_connection.accountUsageSchema
),
engine=engine,
),
}
return test_connection_steps(
metadata=metadata,
test_fn=test_fn,
service_type=service_connection.type.value,
automation_workflow=automation_workflow,
timeout_seconds=timeout_seconds,
)
def _init_database(engine_wrapper: SnowflakeEngineWrapper): def _init_database(engine_wrapper: SnowflakeEngineWrapper):
""" """
Initialize database Initialize database
@ -259,3 +100,194 @@ def test_table_query(engine_wrapper: SnowflakeEngineWrapper, statement: str):
engine=engine_wrapper.engine, engine=engine_wrapper.engine,
statement=statement.format(database_name=engine_wrapper.database_name), statement=statement.format(database_name=engine_wrapper.database_name),
) )
class SnowflakeConnection(BaseConnection[SnowflakeConnectionConfig, Engine]):
def _get_client(self) -> Engine:
"""
Return the SQLAlchemy Engine for Snowflake.
"""
return self.get_connection()
@staticmethod
def get_connection_url(connection: SnowflakeConnectionConfig) -> str:
"""
Set the connection URL
"""
url = f"{connection.scheme.value}://"
if connection.username:
url += f"{quote_plus(connection.username)}"
if not connection.password:
connection.password = SecretStr("")
url += (
f":{quote_plus(connection.password.get_secret_value())}"
if connection
else ""
)
url += "@"
url += connection.account
url += f"/{connection.database}" if connection.database else ""
options = get_connection_options_dict(connection)
if options:
if not connection.database:
url += "/"
params = "&".join(
f"{key}={quote_plus(value)}"
for (key, value) in options.items()
if value
)
url = f"{url}?{params}"
options = {
"account": connection.account,
"warehouse": connection.warehouse,
"role": connection.role,
}
params = "&".join(f"{key}={value}" for (key, value) in options.items() if value)
if params:
url = f"{url}?{params}"
return url
def _get_private_key(self) -> Optional[bytes]:
connection = self.service_connection
if connection.privateKey:
snowflake_private_key_passphrase = (
connection.snowflakePrivatekeyPassphrase.get_secret_value()
if connection.snowflakePrivatekeyPassphrase
else ""
)
if not snowflake_private_key_passphrase:
logger.warning(
"Snowflake Private Key Passphrase not found, replacing it with empty string"
)
p_key = serialization.load_pem_private_key(
bytes(connection.privateKey.get_secret_value(), "utf-8"),
password=snowflake_private_key_passphrase.encode() or None,
backend=default_backend(),
)
pkb = p_key.private_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)
return pkb
return None
def _get_client_session_keep_alive(self) -> Optional[bool]:
connection = self.service_connection
if connection.clientSessionKeepAlive:
return connection.clientSessionKeepAlive
return None
def get_connection(self) -> Engine:
"""
Create connection
"""
connection = self.service_connection
if not connection.connectionArguments:
connection.connectionArguments = init_empty_connection_arguments()
if private_key := self._get_private_key():
connection.connectionArguments.root["private_key"] = private_key
if keep_alive := self._get_client_session_keep_alive():
connection.connectionArguments.root[
"client_session_keep_alive"
] = keep_alive
engine = create_generic_db_connection(
connection=connection,
get_connection_url_fn=self.get_connection_url,
get_connection_args_fn=get_connection_args_common,
)
if (
connection.connectionArguments.root
and connection.connectionArguments.root.get("private_key")
):
del connection.connectionArguments.root["private_key"]
return engine
def test_connection(
self,
metadata: OpenMetadata,
automation_workflow: Optional[AutomationWorkflow] = None,
timeout_seconds: Optional[int] = THREE_MIN,
) -> TestConnectionResult:
"""
Test connection. This can be executed either as part
of a metadata workflow or during an Automation Workflow.
Note how we run a custom GetTables query:
The default inspector `get_table_names` runs a SHOW which
has a limit on 10000 rows in the result set:
https://github.com/open-metadata/OpenMetadata/issues/12798
This can cause errors if we are running tests against schemas
with more tables than that. There is no issues during the metadata
ingestion since in metadata.py we are overriding the default
`get_table_names` function with our custom queries.
"""
engine_wrapper = SnowflakeEngineWrapper(
service_connection=self.service_connection,
engine=self.client,
database_name=None,
)
test_fn = {
"CheckAccess": partial(test_connection_engine_step, self.client),
"GetDatabases": partial(
test_query, statement=SNOWFLAKE_GET_DATABASES, engine=self.client
),
"GetSchemas": partial(
execute_inspector_func, engine_wrapper, "get_schema_names"
),
"GetTables": partial(
test_table_query,
statement=SNOWFLAKE_TEST_GET_TABLES,
engine_wrapper=engine_wrapper,
),
"GetViews": partial(
test_table_query,
statement=SNOWFLAKE_TEST_GET_VIEWS,
engine_wrapper=engine_wrapper,
),
"GetStreams": partial(
test_table_query,
statement=SNOWFLAKE_TEST_GET_STREAMS,
engine_wrapper=engine_wrapper,
),
"GetQueries": partial(
test_query,
statement=SNOWFLAKE_TEST_GET_QUERIES.format(
account_usage=self.service_connection.accountUsageSchema
),
engine=self.client,
),
"GetTags": partial(
test_query,
statement=SNOWFLAKE_TEST_FETCH_TAG.format(
account_usage=self.service_connection.accountUsageSchema
),
engine=self.client,
),
}
return test_connection_steps(
metadata=metadata,
test_fn=test_fn,
service_type=self.service_connection.type.value,
automation_workflow=automation_workflow,
timeout_seconds=timeout_seconds,
)
def get_connection_dict(self) -> dict:
"""
Return the connection dictionary for this service.
"""
raise NotImplementedError(
"get_connection_dict is not implemented for Snowflake"
)

View File

@ -1,6 +1,7 @@
from metadata.data_quality.interface.sqlalchemy.snowflake.test_suite_interface import ( from metadata.data_quality.interface.sqlalchemy.snowflake.test_suite_interface import (
SnowflakeTestSuiteInterface, SnowflakeTestSuiteInterface,
) )
from metadata.ingestion.source.database.snowflake.connection import SnowflakeConnection
from metadata.ingestion.source.database.snowflake.data_diff.data_diff import ( from metadata.ingestion.source.database.snowflake.data_diff.data_diff import (
SnowflakeTableParameter, SnowflakeTableParameter,
) )
@ -21,4 +22,5 @@ ServiceSpec = DefaultDatabaseSpec(
test_suite_class=SnowflakeTestSuiteInterface, test_suite_class=SnowflakeTestSuiteInterface,
sampler_class=SnowflakeSampler, sampler_class=SnowflakeSampler,
data_diff=SnowflakeTableParameter, data_diff=SnowflakeTableParameter,
connection_class=SnowflakeConnection,
) )

View File

@ -96,7 +96,9 @@ from metadata.generated.schema.entity.services.connections.database.singleStoreC
SingleStoreScheme, SingleStoreScheme,
) )
from metadata.generated.schema.entity.services.connections.database.snowflakeConnection import ( from metadata.generated.schema.entity.services.connections.database.snowflakeConnection import (
SnowflakeConnection, SnowflakeConnection as SnowflakeConnectionConfig,
)
from metadata.generated.schema.entity.services.connections.database.snowflakeConnection import (
SnowflakeScheme, SnowflakeScheme,
) )
from metadata.generated.schema.entity.services.connections.database.trinoConnection import ( from metadata.generated.schema.entity.services.connections.database.trinoConnection import (
@ -114,6 +116,7 @@ from metadata.ingestion.connections.builders import (
get_connection_args_common, get_connection_args_common,
get_connection_url_common, get_connection_url_common,
) )
from metadata.ingestion.source.database.snowflake.connection import SnowflakeConnection
from metadata.ingestion.source.database.trino.connection import TrinoConnection from metadata.ingestion.source.database.trino.connection import TrinoConnection
@ -782,31 +785,9 @@ class SourceConnectionTest(TestCase):
assert expected_url == get_connection_url_common(db2_conn_obj) assert expected_url == get_connection_url_common(db2_conn_obj)
def test_snowflake_url(self): def test_snowflake_url(self):
# connection arguments without db
from metadata.ingestion.source.database.snowflake.connection import (
get_connection_url,
)
expected_url = "snowflake://coding:Abhi@ue18849.us-east-2.aws?account=ue18849.us-east-2.aws&warehouse=COMPUTE_WH"
snowflake_conn_obj = SnowflakeConnection(
scheme=SnowflakeScheme.snowflake,
username="coding",
password="Abhi",
warehouse="COMPUTE_WH",
account="ue18849.us-east-2.aws",
)
assert expected_url == get_connection_url(snowflake_conn_obj)
def test_snowflake_url(self):
from metadata.ingestion.source.database.snowflake.connection import (
get_connection_url,
)
# Passing @ in username and password # Passing @ in username and password
expected_url = "snowflake://coding%40444:Abhi%40123@ue18849.us-east-2.aws?account=ue18849.us-east-2.aws&warehouse=COMPUTE_WH" expected_url = "snowflake://coding%40444:Abhi%40123@ue18849.us-east-2.aws?account=ue18849.us-east-2.aws&warehouse=COMPUTE_WH"
snowflake_conn_obj = SnowflakeConnection( snowflake_conn_obj = SnowflakeConnectionConfig(
scheme=SnowflakeScheme.snowflake, scheme=SnowflakeScheme.snowflake,
username="coding@444", username="coding@444",
password="Abhi@123", password="Abhi@123",
@ -814,11 +795,13 @@ class SourceConnectionTest(TestCase):
account="ue18849.us-east-2.aws", account="ue18849.us-east-2.aws",
) )
assert expected_url == get_connection_url(snowflake_conn_obj) assert expected_url == SnowflakeConnection.get_connection_url(
snowflake_conn_obj
)
# connection arguments with db # connection arguments with db
expected_url = "snowflake://coding:Abhi@ue18849.us-east-2.aws/testdb?account=ue18849.us-east-2.aws&warehouse=COMPUTE_WH" expected_url = "snowflake://coding:Abhi@ue18849.us-east-2.aws/testdb?account=ue18849.us-east-2.aws&warehouse=COMPUTE_WH"
snowflake_conn_obj = SnowflakeConnection( snowflake_conn_obj = SnowflakeConnectionConfig(
scheme=SnowflakeScheme.snowflake, scheme=SnowflakeScheme.snowflake,
username="coding", username="coding",
password="Abhi", password="Abhi",
@ -826,7 +809,10 @@ class SourceConnectionTest(TestCase):
warehouse="COMPUTE_WH", warehouse="COMPUTE_WH",
account="ue18849.us-east-2.aws", account="ue18849.us-east-2.aws",
) )
assert expected_url == get_connection_url(snowflake_conn_obj)
assert expected_url == SnowflakeConnection.get_connection_url(
snowflake_conn_obj
)
def test_mysql_conn_arguments(self): def test_mysql_conn_arguments(self):
# connection arguments without connectionArguments # connection arguments without connectionArguments
@ -1008,7 +994,7 @@ class SourceConnectionTest(TestCase):
def test_snowflake_conn_arguments(self): def test_snowflake_conn_arguments(self):
# connection arguments without connectionArguments # connection arguments without connectionArguments
expected_args = {} expected_args = {}
snowflake_conn_obj = SnowflakeConnection( snowflake_conn_obj = SnowflakeConnectionConfig(
username="user", username="user",
password="test-pwd", password="test-pwd",
database="tiny", database="tiny",
@ -1020,7 +1006,7 @@ class SourceConnectionTest(TestCase):
# connection arguments with connectionArguments # connection arguments with connectionArguments
expected_args = {"user": "user-to-be-impersonated"} expected_args = {"user": "user-to-be-impersonated"}
snowflake_conn_obj = SnowflakeConnection( snowflake_conn_obj = SnowflakeConnectionConfig(
username="user", username="user",
password="test-pwd", password="test-pwd",
database="tiny", database="tiny",