fix(ingest/snowflake): fix type annotations + refactor get_connect_args (#7004)

This commit is contained in:
Harshal Sheth 2023-01-10 18:47:11 -08:00 committed by GitHub
parent acd2ba13fc
commit 93dd87a14b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,5 +1,5 @@
import logging
from typing import Dict, Optional
from typing import Any, Dict, Optional
import pydantic
import snowflake.connector
@ -153,7 +153,7 @@ class BaseSnowflakeConfig(BaseTimeWindowConfig):
default=True,
description="If enabled, populates the snowflake view->table and table->view lineages (no view->view lineage yet). Requires appropriate grants given to the role, and include_table_lineage to be True. view->table lineage requires Snowflake Enterprise Edition or above.",
)
connect_args: Optional[Dict] = pydantic.Field(
connect_args: Optional[Dict[str, Any]] = pydantic.Field(
default=None,
description="Connect args to pass to Snowflake SqlAlchemy driver",
exclude=True,
@ -297,28 +297,28 @@ class BaseSnowflakeConfig(BaseTimeWindowConfig):
},
)
_computed_connect_args: Optional[dict] = None
def get_connect_args(self) -> dict:
"""
Builds connect args and updates self.connect_args so that
Subsequent calls to this method are efficient, i.e. do not read files again
Builds connect args, adding defaults and reading a private key from the file if needed.
Caches the results in a private instance variable to avoid reading the file multiple times.
"""
base_connect_args = {
if self._computed_connect_args is not None:
return self._computed_connect_args
connect_args: dict = {
# Improves performance and avoids timeout errors for larger query result
CLIENT_PREFETCH_THREADS: 10,
CLIENT_SESSION_KEEP_ALIVE: True,
# Let user override the default config values
**(self.connect_args or {}),
}
if self.connect_args is None:
self.connect_args = base_connect_args
else:
# Let user override the default config values
base_connect_args.update(self.connect_args)
self.connect_args = base_connect_args
if (
self.authentication_type == "KEY_PAIR_AUTHENTICATOR"
and "private_key" not in self.connect_args.keys()
"private_key" not in connect_args
and self.authentication_type == "KEY_PAIR_AUTHENTICATOR"
):
if self.private_key is not None:
pkey_bytes = self.private_key.replace("\\n", "\n").encode()
@ -337,13 +337,16 @@ class BaseSnowflakeConfig(BaseTimeWindowConfig):
backend=default_backend(),
)
pkb = p_key.private_bytes(
pkb: bytes = p_key.private_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)
self.connect_args.update({"private_key": pkb})
return self.connect_args
connect_args["private_key"] = pkb
self._computed_connect_args = connect_args
return connect_args
class SnowflakeConfig(BaseSnowflakeConfig, SQLAlchemyConfig):