fix(snowflake): passing connect args should not cause failures (#4764)

* fix(snowflake): passing connect args should not cause failures

Co-authored-by: Ravindra Lanka <rlanka@acryl.io>
This commit is contained in:
Aseem Bansal 2022-05-03 17:50:11 +05:30 committed by GitHub
parent 6828dc3d4c
commit 3ff53b417b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 29 additions and 7 deletions

View File

@ -89,8 +89,7 @@ class SnowflakeSource(SQLAlchemySource):
logger.debug(f"sql_alchemy_url={url}")
return create_engine(
url,
connect_args=self.config.get_sql_alchemy_connect_args(),
**self.config.options,
**self.config.get_options(),
)
def inspect_session_metadata(self) -> Any:

View File

@ -293,8 +293,7 @@ class SnowflakeUsageSource(StatefulIngestionSourceBase):
logger.debug(f"sql_alchemy_url={url}")
engine = create_engine(
url,
connect_args=self.config.get_sql_alchemy_connect_args(),
**self.config.options,
**self.config.get_options(),
)
return engine

View File

@ -215,5 +215,8 @@ class SnowflakeConfig(BaseSnowflakeConfig, SQLAlchemyConfig):
database=database, username=username, password=password, role=role
)
def get_sql_alchemy_connect_args(self) -> dict:
return super().get_sql_alchemy_connect_args()
def get_options(self) -> dict:
options_connect_args: Dict = super().get_sql_alchemy_connect_args()
options_connect_args.update(self.options.get("connect_args", {}))
self.options["connect_args"] = options_connect_args
return self.options

View File

@ -1,5 +1,5 @@
import logging
from typing import Optional
from typing import Dict, Optional
import pydantic
@ -38,6 +38,12 @@ class SnowflakeUsageConfig(
apply_view_usage_to_tables: bool = False
stateful_ingestion: Optional[SnowflakeStatefulIngestionConfig] = None
def get_options(self) -> dict:
options_connect_args: Dict = super().get_sql_alchemy_connect_args()
options_connect_args.update(self.options.get("connect_args", {}))
self.options["connect_args"] = options_connect_args
return self.options
def get_sql_alchemy_url(self):
return super().get_sql_alchemy_url(
database="snowflake",

View File

@ -88,3 +88,18 @@ def test_snowflake_uri_key_pair_authentication():
== "snowflake://user@acctname/?authenticator=SNOWFLAKE_JWT&warehouse=COMPUTE_WH&role"
"=sysadmin&application=acryl_datahub"
)
def test_options_contain_connect_args():
config = SnowflakeConfig.parse_obj(
{
"username": "user",
"password": "password",
"host_port": "acctname",
"database_pattern": {"allow": {"^demo$"}},
"warehouse": "COMPUTE_WH",
"role": "sysadmin",
}
)
connect_args = config.get_options().get("connect_args")
assert connect_args is not None