feat(ingest): snowflake - support for additional auth mechanisms (#4009)

This commit is contained in:
Michael A. Schlosser 2022-01-30 13:47:53 -06:00 committed by GitHub
parent c5ff486435
commit c36662f837
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 151 additions and 9 deletions

View File

@ -83,6 +83,14 @@ bigquery_common = {
"more-itertools>=8.12.0",
}
# Requirements shared by every Snowflake plugin: the generic SQL requirements
# (the Snowflake plugin builds on sql common) plus the Snowflake dialect and
# the cryptography package the Snowflake source imports to read private keys.
snowflake_common = sql_common | {
    # Required for all Snowflake sources
    "snowflake-sqlalchemy<=1.2.4",
    "cryptography==3.4.8",
}
# Note: for all of these, framework_common will be added.
plugins: Dict[str, Set[str]] = {
# Sink plugins.
@ -134,9 +142,8 @@ plugins: Dict[str, Set[str]] = {
"redshift-usage": sql_common
| {"sqlalchemy-redshift", "psycopg2-binary", "GeoAlchemy2"},
"sagemaker": aws_common,
"snowflake": sql_common | {"snowflake-sqlalchemy<=1.2.4"},
"snowflake-usage": sql_common
| {"snowflake-sqlalchemy<=1.2.4", "more-itertools>=8.12.0"},
"snowflake": snowflake_common,
"snowflake-usage": snowflake_common | {"more-itertools>=8.12.0"},
"sqlalchemy": sql_common,
"superset": {"requests"},
"trino": sql_common

View File

@ -95,8 +95,11 @@ Note that a `.` is used to denote nested fields in the YAML recipe.
| Field | Required | Default | Description |
| ----------------------------- | -------- | --------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `authentication_type` | | `"DEFAULT_AUTHENTICATOR"` | The type of authenticator to use when connecting to Snowflake. Supports `"DEFAULT_AUTHENTICATOR"`, `"EXTERNAL_BROWSER_AUTHENTICATOR"` and `"KEY_PAIR_AUTHENTICATOR"`. |
| `username` | | | Snowflake username. |
| `password` | | | Snowflake password. |
| `private_key_path` | | | Path to the PEM-encoded private key file. Required when `authentication_type` is `"KEY_PAIR_AUTHENTICATOR"`. See: https://docs.snowflake.com/en/user-guide/key-pair-auth.html |
| `private_key_password` | | | Password used to decrypt the private key. Required when `authentication_type` is `"KEY_PAIR_AUTHENTICATOR"`. |
| `host_port` | ✅ | | Snowflake host URL. |
| `warehouse` | | | Snowflake warehouse. |
| `role` | | | Snowflake role. |

View File

@ -7,6 +7,13 @@ import pydantic
# This import verifies that the dependencies are available.
import snowflake.sqlalchemy # noqa: F401
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from snowflake.connector.network import (
DEFAULT_AUTHENTICATOR,
EXTERNAL_BROWSER_AUTHENTICATOR,
KEY_PAIR_AUTHENTICATOR,
)
from snowflake.sqlalchemy import custom_types, snowdialect
from sqlalchemy import create_engine, inspect
from sqlalchemy.engine.reflection import Inspector
@ -55,11 +62,46 @@ class BaseSnowflakeConfig(BaseTimeWindowConfig):
username: Optional[str] = None
password: Optional[pydantic.SecretStr] = pydantic.Field(default=None, exclude=True)
private_key_path: Optional[str]
private_key_password: Optional[pydantic.SecretStr] = pydantic.Field(
default=None, exclude=True
)
authentication_type: Optional[str] = "DEFAULT_AUTHENTICATOR"
host_port: str
warehouse: Optional[str]
role: Optional[str]
include_table_lineage: Optional[bool] = True
connect_args: Optional[dict]
@pydantic.validator("authentication_type")
def authenticator_type_is_valid(cls, v, values, **kwargs):
    """Validate ``authentication_type`` and translate it for the connector.

    ``v`` is the name given in the recipe. It must be one of
    ``DEFAULT_AUTHENTICATOR``, ``EXTERNAL_BROWSER_AUTHENTICATOR`` or
    ``KEY_PAIR_AUTHENTICATOR``. The value returned (and therefore stored on
    the config and later placed in the ``authenticator`` URL option) is the
    matching constant imported from ``snowflake.connector.network``.

    Raises:
        ValueError: if ``v`` is not a supported name, or if key pair
            authentication is requested while ``private_key_path`` or
            ``private_key_password`` is unset.
    """
    # NOTE(review): the validator is registered without ``always=True``, so
    # it presumably does not run when the field is left at its default; in
    # that case the literal "DEFAULT_AUTHENTICATOR" would reach the URL
    # untranslated -- confirm this is intended.
    connector_value_by_name = {
        "DEFAULT_AUTHENTICATOR": DEFAULT_AUTHENTICATOR,
        "EXTERNAL_BROWSER_AUTHENTICATOR": EXTERNAL_BROWSER_AUTHENTICATOR,
        "KEY_PAIR_AUTHENTICATOR": KEY_PAIR_AUTHENTICATOR,
    }

    if v not in connector_value_by_name:
        raise ValueError(
            f"unsupported authenticator type '{v}' was provided,"
            f" use one of {list(connector_value_by_name.keys())}"
        )

    if v == "KEY_PAIR_AUTHENTICATOR":
        # Key pair auth needs both the private key path and its password.
        # Both fields are declared before ``authentication_type`` on the
        # model, so they are looked up in ``values``; the path is checked
        # first, then the password.
        for required_field in ("private_key_path", "private_key_password"):
            if values.get(required_field) is None:
                raise ValueError(
                    f"'{required_field}' was none "
                    f"but should be set when using {v} authentication"
                )

    logger.info(f"using authenticator type '{v}'")
    return connector_value_by_name[v]
def get_sql_alchemy_url(self, database=None):
return make_sqlalchemy_uri(
self.scheme,
@ -71,6 +113,7 @@ class BaseSnowflakeConfig(BaseTimeWindowConfig):
# Drop the options if value is None.
key: value
for (key, value) in {
"authenticator": self.authentication_type,
"warehouse": self.warehouse,
"role": self.role,
"application": APPLICATION_NAME,
@ -79,6 +122,29 @@ class BaseSnowflakeConfig(BaseTimeWindowConfig):
},
)
def get_sql_alchemy_connect_args(self) -> dict:
    """Return the ``connect_args`` to pass to ``create_engine``.

    For any authenticator other than key pair authentication an empty dict
    is returned. For key pair authentication the PEM private key at
    ``private_key_path`` is decrypted with ``private_key_password``,
    re-serialised as unencrypted DER/PKCS8 bytes and returned under the
    ``"private_key"`` key. The result is cached on ``self.connect_args`` so
    the key file is only read once.

    If ``connect_args`` was already populated but does not contain a
    ``"private_key"`` entry, the key is loaded and merged into the existing
    entries instead of being silently left out.

    Raises:
        ValueError: if key pair authentication is selected but
            ``private_key_path`` or ``private_key_password`` is unset.
    """
    if self.authentication_type != KEY_PAIR_AUTHENTICATOR:
        # NOTE(review): any ``connect_args`` already set on the config are
        # ignored on this path -- confirm that is intended.
        return {}

    # Reuse the cached arguments only when they already carry the key.
    if self.connect_args is not None and "private_key" in self.connect_args:
        return self.connect_args

    if self.private_key_path is None:
        raise ValueError("missing required private key path to read key from")
    if self.private_key_password is None:
        raise ValueError("missing required private key password")

    with open(self.private_key_path, "rb") as key:
        encrypted_pem = key.read()

    p_key = serialization.load_pem_private_key(
        encrypted_pem,
        password=self.private_key_password.get_secret_value().encode(),
        backend=default_backend(),
    )
    pkb = p_key.private_bytes(
        encoding=serialization.Encoding.DER,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )

    # NOTE(review): ``connect_args`` is declared without ``exclude=True``
    # (unlike the password fields), so the decrypted key cached here may be
    # included when the config is serialised -- verify.
    self.connect_args = {**(self.connect_args or {}), "private_key": pkb}
    return self.connect_args
class SnowflakeConfig(BaseSnowflakeConfig, SQLAlchemyConfig):
database_pattern: AllowDenyPattern = AllowDenyPattern(
@ -98,6 +164,9 @@ class SnowflakeConfig(BaseSnowflakeConfig, SQLAlchemyConfig):
# Build the SQLAlchemy URL by delegating to the parent class implementation,
# passing ``database`` through unchanged.
def get_sql_alchemy_url(self, database: str = None) -> str:
return super().get_sql_alchemy_url(database=database)
# Return the ``connect_args`` dict for ``create_engine`` by delegating to the
# parent class implementation.
def get_sql_alchemy_connect_args(self) -> dict:
return super().get_sql_alchemy_connect_args()
class SnowflakeSource(SQLAlchemySource):
config: SnowflakeConfig
@ -116,7 +185,12 @@ class SnowflakeSource(SQLAlchemySource):
def get_inspectors(self) -> Iterable[Inspector]:
url = self.config.get_sql_alchemy_url(database=None)
logger.debug(f"sql_alchemy_url={url}")
db_listing_engine = create_engine(url, **self.config.options)
db_listing_engine = create_engine(
url,
connect_args=self.config.get_sql_alchemy_connect_args(),
**self.config.options,
)
for db_row in db_listing_engine.execute(text("SHOW DATABASES")):
db = db_row.name
@ -125,7 +199,9 @@ class SnowflakeSource(SQLAlchemySource):
# they are isolated from each other.
self.current_database = db
engine = create_engine(
self.config.get_sql_alchemy_url(database=db), **self.config.options
self.config.get_sql_alchemy_url(database=db),
connect_args=self.config.get_sql_alchemy_connect_args(),
**self.config.options,
)
with engine.connect() as conn:
@ -182,7 +258,11 @@ QUALIFY ROW_NUMBER() OVER (PARTITION BY downstream_table_name ORDER BY query_sta
def _populate_lineage(self) -> None:
url = self.config.get_sql_alchemy_url()
logger.debug(f"sql_alchemy_url={url}")
engine = create_engine(url, **self.config.options)
engine = create_engine(
url,
connect_args=self.config.get_sql_alchemy_connect_args(),
**self.config.options,
)
query: str = """
WITH table_lineage_history AS (
SELECT

View File

@ -287,7 +287,11 @@ class SnowflakeUsageSource(StatefulIngestionSourceBase):
def _make_sql_engine(self) -> Engine:
url = self.config.get_sql_alchemy_url()
logger.debug(f"sql_alchemy_url={url}")
engine = create_engine(url, **self.config.options)
engine = create_engine(
url,
connect_args=self.config.get_sql_alchemy_connect_args(),
**self.config.options,
)
return engine
def _get_snowflake_history(self) -> Iterable[SnowflakeJoinedAccessEvent]:

View File

@ -2,7 +2,7 @@ import pytest
@pytest.mark.integration
def test_snowflake_uri():
def test_snowflake_uri_default_authentication():
from datahub.ingestion.source.sql.snowflake import SnowflakeConfig
config = SnowflakeConfig.parse_obj(
@ -15,7 +15,55 @@ def test_snowflake_uri():
"role": "sysadmin",
}
)
assert (
config.get_sql_alchemy_url()
== "snowflake://user:password@acctname/?warehouse=COMPUTE_WH&role=sysadmin&application=acryl_datahub"
== "snowflake://user:password@acctname/?authenticator=DEFAULT_AUTHENTICATOR&warehouse=COMPUTE_WH&role"
"=sysadmin&application=acryl_datahub"
)
@pytest.mark.integration
def test_snowflake_uri_external_browser_authentication():
    """The external browser authenticator is translated in the generated URL."""
    from datahub.ingestion.source.sql.snowflake import SnowflakeConfig

    recipe = {
        "username": "user",
        "host_port": "acctname",
        "database": "demo",
        "warehouse": "COMPUTE_WH",
        "role": "sysadmin",
        "authentication_type": "EXTERNAL_BROWSER_AUTHENTICATOR",
    }
    expected_url = (
        "snowflake://user@acctname/?authenticator=EXTERNALBROWSER&warehouse=COMPUTE_WH&role"
        "=sysadmin&application=acryl_datahub"
    )

    config = SnowflakeConfig.parse_obj(recipe)

    assert config.get_sql_alchemy_url() == expected_url
@pytest.mark.integration
def test_snowflake_uri_key_pair_authentication():
    """The key pair authenticator is translated in the generated URL.

    Only the URL is checked; the private key path is never opened here.
    """
    from datahub.ingestion.source.sql.snowflake import SnowflakeConfig

    recipe = {
        "username": "user",
        "host_port": "acctname",
        "database": "demo",
        "warehouse": "COMPUTE_WH",
        "role": "sysadmin",
        "authentication_type": "KEY_PAIR_AUTHENTICATOR",
        "private_key_path": "/a/random/path",
        "private_key_password": "a_random_password",
    }
    expected_url = (
        "snowflake://user@acctname/?authenticator=SNOWFLAKE_JWT&warehouse=COMPUTE_WH&role"
        "=sysadmin&application=acryl_datahub"
    )

    config = SnowflakeConfig.parse_obj(recipe)

    assert config.get_sql_alchemy_url() == expected_url