mirror of
https://github.com/open-metadata/OpenMetadata.git
synced 2025-09-22 23:43:09 +00:00
MINOR: Update Snowflake Connection (#22167)
* Update Snowflake Connection * Extracting needed methods
This commit is contained in:
parent
d12d93e0ea
commit
1260a0600a
@ -26,7 +26,7 @@ from metadata.generated.schema.entity.automations.workflow import (
|
||||
Workflow as AutomationWorkflow,
|
||||
)
|
||||
from metadata.generated.schema.entity.services.connections.database.snowflakeConnection import (
|
||||
SnowflakeConnection,
|
||||
SnowflakeConnection as SnowflakeConnectionConfig,
|
||||
)
|
||||
from metadata.generated.schema.entity.services.connections.testConnectionResult import (
|
||||
TestConnectionResult,
|
||||
@ -37,6 +37,7 @@ from metadata.ingestion.connections.builders import (
|
||||
get_connection_options_dict,
|
||||
init_empty_connection_arguments,
|
||||
)
|
||||
from metadata.ingestion.connections.connection import BaseConnection
|
||||
from metadata.ingestion.connections.test_connections import (
|
||||
test_connection_engine_step,
|
||||
test_connection_steps,
|
||||
@ -58,171 +59,11 @@ logger = ingestion_logger()
|
||||
|
||||
|
||||
class SnowflakeEngineWrapper(BaseModel):
    """
    Pair a Snowflake connection configuration with the engine built from it.

    The test-connection steps receive this wrapper so they can share the
    engine and the database name used to format the per-database statements.
    """

    # Snowflake connection configuration (the generated schema model).
    # Declared exactly once: a second annotation with the same field name
    # would silently override the first.
    service_connection: SnowflakeConnectionConfig
    # Engine used to run the test statements; typed as Any so that pydantic
    # does not try to validate the engine object.
    engine: Any
    # Database name formatted into the table/view/stream test statements.
    # The test-connection entry points create the wrapper with None.
    database_name: Optional[str] = None
|
||||
|
||||
|
||||
def get_connection_url(connection: SnowflakeConnection) -> str:
    """
    Build the SQLAlchemy connection URL for a Snowflake service connection.

    The URL has the shape
    `<scheme>://[<user>:<password>@]<account>[/<database>][?<query>]`.
    The query string carries the user-supplied connection options followed
    by the `account`, `warehouse` and `role` settings. Entries with an empty
    value are skipped and everything is joined into a single query string.

    Args:
        connection: Snowflake connection configuration. If a username is
            set but no password, `connection.password` is replaced in place
            with an empty secret.

    Returns:
        The connection URL as a string.
    """
    url = f"{connection.scheme.value}://"

    if connection.username:
        url += quote_plus(connection.username)
        if not connection.password:
            # No password configured (e.g. when a private key is used):
            # fall back to an empty secret so the URL keeps the
            # `user:@account` shape.
            connection.password = SecretStr("")
        url += f":{quote_plus(connection.password.get_secret_value())}"
        url += "@"

    url += connection.account
    if connection.database:
        url += f"/{connection.database}"

    query_params = []

    options = get_connection_options_dict(connection)
    if options:
        if not connection.database:
            # Keep an explicit (empty) path segment in front of the options.
            url += "/"
        query_params.extend(
            f"{key}={quote_plus(value)}" for (key, value) in options.items() if value
        )

    snowflake_params = {
        "account": connection.account,
        "warehouse": connection.warehouse,
        "role": connection.role,
    }
    query_params.extend(
        f"{key}={value}" for (key, value) in snowflake_params.items() if value
    )

    if query_params:
        # A URL may contain only one `?`. Appending the Snowflake settings
        # with a second `?` would fold them into the value of the last
        # connection option, so every parameter is joined with `&` instead.
        url += "?" + "&".join(query_params)
    return url
|
||||
|
||||
|
||||
def get_connection(connection: SnowflakeConnection) -> Engine:
    """
    Create the SQLAlchemy engine for a Snowflake service connection.

    When a private key is configured it is loaded from PEM (using the
    optional passphrase), re-serialized as unencrypted PKCS8 DER bytes and
    handed to the engine through the `private_key` connection argument.
    That argument is removed from the configuration again before returning,
    also when the engine creation fails.

    Args:
        connection: Snowflake connection configuration. Its
            `connectionArguments` are initialized if missing and updated in
            place with `client_session_keep_alive` when that option is set.

    Returns:
        The engine returned by `create_generic_db_connection`.
    """
    if not connection.connectionArguments:
        connection.connectionArguments = init_empty_connection_arguments()

    if connection.privateKey:
        snowflake_private_key_passphrase = (
            connection.snowflakePrivatekeyPassphrase.get_secret_value()
            if connection.snowflakePrivatekeyPassphrase
            else ""
        )

        if not snowflake_private_key_passphrase:
            logger.warning(
                "Snowflake Private Key Passphrase not found, replacing it with empty string"
            )
        p_key = serialization.load_pem_private_key(
            bytes(connection.privateKey.get_secret_value(), "utf-8"),
            # An empty passphrase encodes to b"", which is passed as None.
            password=snowflake_private_key_passphrase.encode() or None,
            backend=default_backend(),
        )
        pkb = p_key.private_bytes(
            encoding=serialization.Encoding.DER,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption(),
        )

        connection.connectionArguments.root["private_key"] = pkb

    if connection.clientSessionKeepAlive:
        connection.connectionArguments.root[
            "client_session_keep_alive"
        ] = connection.clientSessionKeepAlive

    try:
        engine = create_generic_db_connection(
            connection=connection,
            get_connection_url_fn=get_connection_url,
            get_connection_args_fn=get_connection_args_common,
        )
    finally:
        # Never leave the unencrypted key material on the configuration
        # object, not even when the engine could not be created.
        connection_args = connection.connectionArguments.root
        if connection_args and connection_args.get("private_key"):
            del connection_args["private_key"]
    return engine
|
||||
|
||||
|
||||
def test_connection(
    metadata: OpenMetadata,
    engine: Engine,
    service_connection: SnowflakeConnection,
    automation_workflow: Optional[AutomationWorkflow] = None,
    timeout_seconds: Optional[int] = THREE_MIN,
) -> TestConnectionResult:
    """
    Run the Snowflake test-connection steps, either from a metadata
    workflow or from an Automation Workflow.

    Tables are checked with a custom GetTables query rather than the
    default inspector `get_table_names`: the inspector issues a SHOW, whose
    result set is limited to 10000 rows
    (https://github.com/open-metadata/OpenMetadata/issues/12798), so testing
    schemas with more tables than that can fail. Metadata ingestion is not
    affected, since metadata.py overrides the default `get_table_names`
    with our custom queries.
    """
    wrapper = SnowflakeEngineWrapper(
        service_connection=service_connection, engine=engine, database_name=None
    )

    def _wrapper_step(statement):
        # Step whose statement is run through the shared engine wrapper.
        return partial(test_table_query, statement=statement, engine_wrapper=wrapper)

    def _account_usage_step(template):
        # Step whose statement targets the configured account usage schema.
        return partial(
            test_query,
            statement=template.format(
                account_usage=service_connection.accountUsageSchema
            ),
            engine=engine,
        )

    # Steps are registered in the same order as they are named here.
    steps = {}
    steps["CheckAccess"] = partial(test_connection_engine_step, engine)
    steps["GetDatabases"] = partial(
        test_query, statement=SNOWFLAKE_GET_DATABASES, engine=engine
    )
    steps["GetSchemas"] = partial(execute_inspector_func, wrapper, "get_schema_names")
    steps["GetTables"] = _wrapper_step(SNOWFLAKE_TEST_GET_TABLES)
    steps["GetViews"] = _wrapper_step(SNOWFLAKE_TEST_GET_VIEWS)
    steps["GetStreams"] = _wrapper_step(SNOWFLAKE_TEST_GET_STREAMS)
    steps["GetQueries"] = _account_usage_step(SNOWFLAKE_TEST_GET_QUERIES)
    steps["GetTags"] = _account_usage_step(SNOWFLAKE_TEST_FETCH_TAG)

    return test_connection_steps(
        metadata=metadata,
        test_fn=steps,
        service_type=service_connection.type.value,
        automation_workflow=automation_workflow,
        timeout_seconds=timeout_seconds,
    )
|
||||
|
||||
|
||||
def _init_database(engine_wrapper: SnowflakeEngineWrapper):
|
||||
"""
|
||||
Initialize database
|
||||
@ -259,3 +100,194 @@ def test_table_query(engine_wrapper: SnowflakeEngineWrapper, statement: str):
|
||||
engine=engine_wrapper.engine,
|
||||
statement=statement.format(database_name=engine_wrapper.database_name),
|
||||
)
|
||||
|
||||
|
||||
class SnowflakeConnection(BaseConnection[SnowflakeConnectionConfig, Engine]):
    """
    Snowflake connection built on top of `BaseConnection`.

    Creates the SQLAlchemy engine from `self.service_connection` (a
    `SnowflakeConnectionConfig`) and runs the Snowflake test-connection
    steps against `self.client`.
    """

    def _get_client(self) -> Engine:
        """
        Return the SQLAlchemy Engine for Snowflake.
        """
        return self.get_connection()

    @staticmethod
    def get_connection_url(connection: SnowflakeConnectionConfig) -> str:
        """
        Build the SQLAlchemy connection URL for Snowflake.

        The URL has the shape
        `<scheme>://[<user>:<password>@]<account>[/<database>][?<query>]`.
        The query string carries the user-supplied connection options
        followed by the `account`, `warehouse` and `role` settings. Entries
        with an empty value are skipped and everything is joined into a
        single query string.

        Args:
            connection: Snowflake connection configuration. If a username
                is set but no password, `connection.password` is replaced
                in place with an empty secret.

        Returns:
            The connection URL as a string.
        """
        url = f"{connection.scheme.value}://"

        if connection.username:
            url += quote_plus(connection.username)
            if not connection.password:
                # No password configured (e.g. when a private key is used):
                # fall back to an empty secret so the URL keeps the
                # `user:@account` shape.
                connection.password = SecretStr("")
            url += f":{quote_plus(connection.password.get_secret_value())}"
            url += "@"

        url += connection.account
        if connection.database:
            url += f"/{connection.database}"

        query_params = []

        options = get_connection_options_dict(connection)
        if options:
            if not connection.database:
                # Keep an explicit (empty) path segment before the options.
                url += "/"
            query_params.extend(
                f"{key}={quote_plus(value)}"
                for (key, value) in options.items()
                if value
            )

        snowflake_params = {
            "account": connection.account,
            "warehouse": connection.warehouse,
            "role": connection.role,
        }
        query_params.extend(
            f"{key}={value}" for (key, value) in snowflake_params.items() if value
        )

        if query_params:
            # A URL may contain only one `?`. Appending the Snowflake
            # settings with a second `?` would fold them into the value of
            # the last connection option, so everything is joined with `&`.
            url += "?" + "&".join(query_params)
        return url

    def _get_private_key(self) -> Optional[bytes]:
        """
        Return the configured private key as unencrypted PKCS8 DER bytes.

        The key is loaded from PEM using the optional passphrase. Returns
        None when the service connection has no private key.
        """
        connection = self.service_connection
        if not connection.privateKey:
            return None

        snowflake_private_key_passphrase = (
            connection.snowflakePrivatekeyPassphrase.get_secret_value()
            if connection.snowflakePrivatekeyPassphrase
            else ""
        )

        if not snowflake_private_key_passphrase:
            logger.warning(
                "Snowflake Private Key Passphrase not found, replacing it with empty string"
            )
        p_key = serialization.load_pem_private_key(
            bytes(connection.privateKey.get_secret_value(), "utf-8"),
            # An empty passphrase encodes to b"", which is passed as None.
            password=snowflake_private_key_passphrase.encode() or None,
            backend=default_backend(),
        )
        return p_key.private_bytes(
            encoding=serialization.Encoding.DER,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption(),
        )

    def _get_client_session_keep_alive(self) -> Optional[bool]:
        """
        Return the `clientSessionKeepAlive` setting when it is enabled,
        None otherwise.
        """
        connection = self.service_connection
        if connection.clientSessionKeepAlive:
            return connection.clientSessionKeepAlive
        return None

    def get_connection(self) -> Engine:
        """
        Create the SQLAlchemy engine for the service connection.

        The private key, when configured, is passed to the engine through
        the `private_key` connection argument and removed from the
        configuration again before returning, also when the engine creation
        fails. `client_session_keep_alive` is added to the connection
        arguments when enabled.
        """
        connection = self.service_connection
        if not connection.connectionArguments:
            connection.connectionArguments = init_empty_connection_arguments()

        if private_key := self._get_private_key():
            connection.connectionArguments.root["private_key"] = private_key

        if keep_alive := self._get_client_session_keep_alive():
            connection.connectionArguments.root[
                "client_session_keep_alive"
            ] = keep_alive

        try:
            engine = create_generic_db_connection(
                connection=connection,
                get_connection_url_fn=self.get_connection_url,
                get_connection_args_fn=get_connection_args_common,
            )
        finally:
            # Never leave the unencrypted key material on the configuration
            # object, not even when the engine could not be created.
            connection_args = connection.connectionArguments.root
            if connection_args and connection_args.get("private_key"):
                del connection_args["private_key"]
        return engine

    def test_connection(
        self,
        metadata: OpenMetadata,
        automation_workflow: Optional[AutomationWorkflow] = None,
        timeout_seconds: Optional[int] = THREE_MIN,
    ) -> TestConnectionResult:
        """
        Test connection. This can be executed either as part
        of a metadata workflow or during an Automation Workflow.

        Note how we run a custom GetTables query:

        The default inspector `get_table_names` runs a SHOW which
        has a limit of 10000 rows in the result set:
        https://github.com/open-metadata/OpenMetadata/issues/12798

        This can cause errors if we are running tests against schemas
        with more tables than that. There are no issues during the metadata
        ingestion since in metadata.py we are overriding the default
        `get_table_names` function with our custom queries.
        """
        engine_wrapper = SnowflakeEngineWrapper(
            service_connection=self.service_connection,
            engine=self.client,
            database_name=None,
        )
        test_fn = {
            "CheckAccess": partial(test_connection_engine_step, self.client),
            "GetDatabases": partial(
                test_query, statement=SNOWFLAKE_GET_DATABASES, engine=self.client
            ),
            "GetSchemas": partial(
                execute_inspector_func, engine_wrapper, "get_schema_names"
            ),
            "GetTables": partial(
                test_table_query,
                statement=SNOWFLAKE_TEST_GET_TABLES,
                engine_wrapper=engine_wrapper,
            ),
            "GetViews": partial(
                test_table_query,
                statement=SNOWFLAKE_TEST_GET_VIEWS,
                engine_wrapper=engine_wrapper,
            ),
            "GetStreams": partial(
                test_table_query,
                statement=SNOWFLAKE_TEST_GET_STREAMS,
                engine_wrapper=engine_wrapper,
            ),
            "GetQueries": partial(
                test_query,
                statement=SNOWFLAKE_TEST_GET_QUERIES.format(
                    account_usage=self.service_connection.accountUsageSchema
                ),
                engine=self.client,
            ),
            "GetTags": partial(
                test_query,
                statement=SNOWFLAKE_TEST_FETCH_TAG.format(
                    account_usage=self.service_connection.accountUsageSchema
                ),
                engine=self.client,
            ),
        }

        return test_connection_steps(
            metadata=metadata,
            test_fn=test_fn,
            service_type=self.service_connection.type.value,
            automation_workflow=automation_workflow,
            timeout_seconds=timeout_seconds,
        )

    def get_connection_dict(self) -> dict:
        """
        Return the connection dictionary for this service.

        Raises:
            NotImplementedError: always; not implemented for Snowflake.
        """
        raise NotImplementedError(
            "get_connection_dict is not implemented for Snowflake"
        )
|
||||
|
@ -1,6 +1,7 @@
|
||||
from metadata.data_quality.interface.sqlalchemy.snowflake.test_suite_interface import (
|
||||
SnowflakeTestSuiteInterface,
|
||||
)
|
||||
from metadata.ingestion.source.database.snowflake.connection import SnowflakeConnection
|
||||
from metadata.ingestion.source.database.snowflake.data_diff.data_diff import (
|
||||
SnowflakeTableParameter,
|
||||
)
|
||||
@ -21,4 +22,5 @@ ServiceSpec = DefaultDatabaseSpec(
|
||||
test_suite_class=SnowflakeTestSuiteInterface,
|
||||
sampler_class=SnowflakeSampler,
|
||||
data_diff=SnowflakeTableParameter,
|
||||
connection_class=SnowflakeConnection,
|
||||
)
|
||||
|
@ -96,7 +96,9 @@ from metadata.generated.schema.entity.services.connections.database.singleStoreC
|
||||
SingleStoreScheme,
|
||||
)
|
||||
from metadata.generated.schema.entity.services.connections.database.snowflakeConnection import (
|
||||
SnowflakeConnection,
|
||||
SnowflakeConnection as SnowflakeConnectionConfig,
|
||||
)
|
||||
from metadata.generated.schema.entity.services.connections.database.snowflakeConnection import (
|
||||
SnowflakeScheme,
|
||||
)
|
||||
from metadata.generated.schema.entity.services.connections.database.trinoConnection import (
|
||||
@ -114,6 +116,7 @@ from metadata.ingestion.connections.builders import (
|
||||
get_connection_args_common,
|
||||
get_connection_url_common,
|
||||
)
|
||||
from metadata.ingestion.source.database.snowflake.connection import SnowflakeConnection
|
||||
from metadata.ingestion.source.database.trino.connection import TrinoConnection
|
||||
|
||||
|
||||
@ -782,31 +785,9 @@ class SourceConnectionTest(TestCase):
|
||||
assert expected_url == get_connection_url_common(db2_conn_obj)
|
||||
|
||||
def test_snowflake_url(self):
    """Snowflake URL for a connection configured without a database."""
    from metadata.ingestion.source.database.snowflake.connection import (
        get_connection_url,
    )

    # connection arguments without db
    config = SnowflakeConnection(
        scheme=SnowflakeScheme.snowflake,
        username="coding",
        password="Abhi",
        warehouse="COMPUTE_WH",
        account="ue18849.us-east-2.aws",
    )
    url_without_db = "snowflake://coding:Abhi@ue18849.us-east-2.aws?account=ue18849.us-east-2.aws&warehouse=COMPUTE_WH"

    assert url_without_db == get_connection_url(config)
|
||||
|
||||
def test_snowflake_url(self):
|
||||
from metadata.ingestion.source.database.snowflake.connection import (
|
||||
get_connection_url,
|
||||
)
|
||||
|
||||
# Passing @ in username and password
|
||||
expected_url = "snowflake://coding%40444:Abhi%40123@ue18849.us-east-2.aws?account=ue18849.us-east-2.aws&warehouse=COMPUTE_WH"
|
||||
snowflake_conn_obj = SnowflakeConnection(
|
||||
snowflake_conn_obj = SnowflakeConnectionConfig(
|
||||
scheme=SnowflakeScheme.snowflake,
|
||||
username="coding@444",
|
||||
password="Abhi@123",
|
||||
@ -814,11 +795,13 @@ class SourceConnectionTest(TestCase):
|
||||
account="ue18849.us-east-2.aws",
|
||||
)
|
||||
|
||||
assert expected_url == get_connection_url(snowflake_conn_obj)
|
||||
assert expected_url == SnowflakeConnection.get_connection_url(
|
||||
snowflake_conn_obj
|
||||
)
|
||||
|
||||
# connection arguments with db
|
||||
expected_url = "snowflake://coding:Abhi@ue18849.us-east-2.aws/testdb?account=ue18849.us-east-2.aws&warehouse=COMPUTE_WH"
|
||||
snowflake_conn_obj = SnowflakeConnection(
|
||||
snowflake_conn_obj = SnowflakeConnectionConfig(
|
||||
scheme=SnowflakeScheme.snowflake,
|
||||
username="coding",
|
||||
password="Abhi",
|
||||
@ -826,7 +809,10 @@ class SourceConnectionTest(TestCase):
|
||||
warehouse="COMPUTE_WH",
|
||||
account="ue18849.us-east-2.aws",
|
||||
)
|
||||
assert expected_url == get_connection_url(snowflake_conn_obj)
|
||||
|
||||
assert expected_url == SnowflakeConnection.get_connection_url(
|
||||
snowflake_conn_obj
|
||||
)
|
||||
|
||||
def test_mysql_conn_arguments(self):
|
||||
# connection arguments without connectionArguments
|
||||
@ -1008,7 +994,7 @@ class SourceConnectionTest(TestCase):
|
||||
def test_snowflake_conn_arguments(self):
|
||||
# connection arguments without connectionArguments
|
||||
expected_args = {}
|
||||
snowflake_conn_obj = SnowflakeConnection(
|
||||
snowflake_conn_obj = SnowflakeConnectionConfig(
|
||||
username="user",
|
||||
password="test-pwd",
|
||||
database="tiny",
|
||||
@ -1020,7 +1006,7 @@ class SourceConnectionTest(TestCase):
|
||||
|
||||
# connection arguments with connectionArguments
|
||||
expected_args = {"user": "user-to-be-impersonated"}
|
||||
snowflake_conn_obj = SnowflakeConnection(
|
||||
snowflake_conn_obj = SnowflakeConnectionConfig(
|
||||
username="user",
|
||||
password="test-pwd",
|
||||
database="tiny",
|
||||
|
Loading…
x
Reference in New Issue
Block a user