mirror of
https://github.com/datahub-project/datahub.git
synced 2025-12-28 02:17:53 +00:00
fix(ingest/snowflake): fix type annotations + refactor get_connect_args (#7004)
This commit is contained in:
parent
acd2ba13fc
commit
93dd87a14b
@ -1,5 +1,5 @@
|
||||
import logging
|
||||
from typing import Dict, Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import pydantic
|
||||
import snowflake.connector
|
||||
@ -153,7 +153,7 @@ class BaseSnowflakeConfig(BaseTimeWindowConfig):
|
||||
default=True,
|
||||
description="If enabled, populates the snowflake view->table and table->view lineages (no view->view lineage yet). Requires appropriate grants given to the role, and include_table_lineage to be True. view->table lineage requires Snowflake Enterprise Edition or above.",
|
||||
)
|
||||
connect_args: Optional[Dict] = pydantic.Field(
|
||||
connect_args: Optional[Dict[str, Any]] = pydantic.Field(
|
||||
default=None,
|
||||
description="Connect args to pass to Snowflake SqlAlchemy driver",
|
||||
exclude=True,
|
||||
@ -297,28 +297,28 @@ class BaseSnowflakeConfig(BaseTimeWindowConfig):
|
||||
},
|
||||
)
|
||||
|
||||
_computed_connect_args: Optional[dict] = None
|
||||
|
||||
def get_connect_args(self) -> dict:
|
||||
"""
|
||||
Builds connect args and updates self.connect_args so that
|
||||
Subsequent calls to this method are efficient, i.e. do not read files again
|
||||
Builds connect args, adding defaults and reading a private key from the file if needed.
|
||||
Caches the results in a private instance variable to avoid reading the file multiple times.
|
||||
"""
|
||||
|
||||
base_connect_args = {
|
||||
if self._computed_connect_args is not None:
|
||||
return self._computed_connect_args
|
||||
|
||||
connect_args: dict = {
|
||||
# Improves performance and avoids timeout errors for larger query result
|
||||
CLIENT_PREFETCH_THREADS: 10,
|
||||
CLIENT_SESSION_KEEP_ALIVE: True,
|
||||
# Let user override the default config values
|
||||
**(self.connect_args or {}),
|
||||
}
|
||||
|
||||
if self.connect_args is None:
|
||||
self.connect_args = base_connect_args
|
||||
else:
|
||||
# Let user override the default config values
|
||||
base_connect_args.update(self.connect_args)
|
||||
self.connect_args = base_connect_args
|
||||
|
||||
if (
|
||||
self.authentication_type == "KEY_PAIR_AUTHENTICATOR"
|
||||
and "private_key" not in self.connect_args.keys()
|
||||
"private_key" not in connect_args
|
||||
and self.authentication_type == "KEY_PAIR_AUTHENTICATOR"
|
||||
):
|
||||
if self.private_key is not None:
|
||||
pkey_bytes = self.private_key.replace("\\n", "\n").encode()
|
||||
@ -337,13 +337,16 @@ class BaseSnowflakeConfig(BaseTimeWindowConfig):
|
||||
backend=default_backend(),
|
||||
)
|
||||
|
||||
pkb = p_key.private_bytes(
|
||||
pkb: bytes = p_key.private_bytes(
|
||||
encoding=serialization.Encoding.DER,
|
||||
format=serialization.PrivateFormat.PKCS8,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
self.connect_args.update({"private_key": pkb})
|
||||
return self.connect_args
|
||||
|
||||
connect_args["private_key"] = pkb
|
||||
|
||||
self._computed_connect_args = connect_args
|
||||
return connect_args
|
||||
|
||||
|
||||
class SnowflakeConfig(BaseSnowflakeConfig, SQLAlchemyConfig):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user