mirror of
https://github.com/datahub-project/datahub.git
synced 2025-12-12 10:35:51 +00:00
feat(ingest): snowflake - support for additional auth mechanisms (#4009)
This commit is contained in:
parent
c5ff486435
commit
c36662f837
@ -83,6 +83,14 @@ bigquery_common = {
|
||||
"more-itertools>=8.12.0",
|
||||
}
|
||||
|
||||
# Dependencies shared by every Snowflake source: the generic SQL
# requirements plus the pinned Snowflake dialect and cryptography packages.
snowflake_common = sql_common | {
    "snowflake-sqlalchemy<=1.2.4",
    "cryptography==3.4.8",
}
|
||||
|
||||
# Note: for all of these, framework_common will be added.
|
||||
plugins: Dict[str, Set[str]] = {
|
||||
# Sink plugins.
|
||||
@ -134,9 +142,8 @@ plugins: Dict[str, Set[str]] = {
|
||||
"redshift-usage": sql_common
|
||||
| {"sqlalchemy-redshift", "psycopg2-binary", "GeoAlchemy2"},
|
||||
"sagemaker": aws_common,
|
||||
"snowflake": sql_common | {"snowflake-sqlalchemy<=1.2.4"},
|
||||
"snowflake-usage": sql_common
|
||||
| {"snowflake-sqlalchemy<=1.2.4", "more-itertools>=8.12.0"},
|
||||
"snowflake": snowflake_common,
|
||||
"snowflake-usage": snowflake_common | {"more-itertools>=8.12.0"},
|
||||
"sqlalchemy": sql_common,
|
||||
"superset": {"requests"},
|
||||
"trino": sql_common
|
||||
|
||||
@ -95,8 +95,11 @@ Note that a `.` is used to denote nested fields in the YAML recipe.
|
||||
|
||||
| Field | Required | Default | Description |
|
||||
| ----------------------------- | -------- | --------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `authentication_type` | | `"DEFAULT_AUTHENTICATOR"` | The type of authenticator to use when connecting to Snowflake. Supports `"DEFAULT_AUTHENTICATOR"`, `"EXTERNAL_BROWSER_AUTHENTICATOR"` and `"KEY_PAIR_AUTHENTICATOR"`. |
|
||||
| `username` | | | Snowflake username. |
|
||||
| `password` | | | Snowflake password. |
|
||||
| `private_key_path` | | | Path to the PEM-encoded private key file. Required when `authentication_type` is `"KEY_PAIR_AUTHENTICATOR"`. See: https://docs.snowflake.com/en/user-guide/key-pair-auth.html |
|
||||
| `private_key_password` | | | Password that decrypts the private key. Required when `authentication_type` is `"KEY_PAIR_AUTHENTICATOR"`. |
|
||||
| `host_port` | ✅ | | Snowflake host URL. |
|
||||
| `warehouse` | | | Snowflake warehouse. |
|
||||
| `role` | | | Snowflake role. |
|
||||
|
||||
@ -7,6 +7,13 @@ import pydantic
|
||||
|
||||
# This import verifies that the dependencies are available.
|
||||
import snowflake.sqlalchemy # noqa: F401
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from snowflake.connector.network import (
|
||||
DEFAULT_AUTHENTICATOR,
|
||||
EXTERNAL_BROWSER_AUTHENTICATOR,
|
||||
KEY_PAIR_AUTHENTICATOR,
|
||||
)
|
||||
from snowflake.sqlalchemy import custom_types, snowdialect
|
||||
from sqlalchemy import create_engine, inspect
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
@ -55,11 +62,46 @@ class BaseSnowflakeConfig(BaseTimeWindowConfig):
|
||||
|
||||
username: Optional[str] = None
|
||||
password: Optional[pydantic.SecretStr] = pydantic.Field(default=None, exclude=True)
|
||||
private_key_path: Optional[str]
|
||||
private_key_password: Optional[pydantic.SecretStr] = pydantic.Field(
|
||||
default=None, exclude=True
|
||||
)
|
||||
authentication_type: Optional[str] = "DEFAULT_AUTHENTICATOR"
|
||||
host_port: str
|
||||
warehouse: Optional[str]
|
||||
role: Optional[str]
|
||||
include_table_lineage: Optional[bool] = True
|
||||
|
||||
connect_args: Optional[dict]
|
||||
|
||||
@pydantic.validator("authentication_type")
def authenticator_type_is_valid(cls, v, values, **kwargs):
    """Check the configured authentication type and translate it.

    Args:
        v: The ``authentication_type`` value taken from the recipe.
        values: Mapping of other config fields; ``private_key_path`` and
            ``private_key_password`` are looked up in it.
        **kwargs: Extra validator arguments; unused.

    Returns:
        The ``snowflake.connector.network`` constant registered for ``v``.

    Raises:
        ValueError: If ``v`` is not a supported authenticator name, or if
            key pair authentication is selected while the private key path
            or the private key password is unset.
    """
    # Recipe-facing names mapped to the constants the connector expects.
    connector_authenticators = {
        "DEFAULT_AUTHENTICATOR": DEFAULT_AUTHENTICATOR,
        "EXTERNAL_BROWSER_AUTHENTICATOR": EXTERNAL_BROWSER_AUTHENTICATOR,
        "KEY_PAIR_AUTHENTICATOR": KEY_PAIR_AUTHENTICATOR,
    }

    if v not in connector_authenticators:
        raise ValueError(
            f"unsupported authenticator type '{v}' was provided,"
            f" use one of {list(connector_authenticators)}"
        )

    if v == "KEY_PAIR_AUTHENTICATOR":
        # Key pair auth needs both the key file location and its password.
        key_path = values.get("private_key_path")
        if key_path is None:
            raise ValueError(
                "'private_key_path' was none "
                f"but should be set when using {v} authentication"
            )
        key_password = values.get("private_key_password")
        if key_password is None:
            raise ValueError(
                "'private_key_password' was none "
                f"but should be set when using {v} authentication"
            )

    logger.info(f"using authenticator type '{v}'")
    return connector_authenticators[v]
|
||||
|
||||
def get_sql_alchemy_url(self, database=None):
|
||||
return make_sqlalchemy_uri(
|
||||
self.scheme,
|
||||
@ -71,6 +113,7 @@ class BaseSnowflakeConfig(BaseTimeWindowConfig):
|
||||
# Drop the options if value is None.
|
||||
key: value
|
||||
for (key, value) in {
|
||||
"authenticator": self.authentication_type,
|
||||
"warehouse": self.warehouse,
|
||||
"role": self.role,
|
||||
"application": APPLICATION_NAME,
|
||||
@ -79,6 +122,29 @@ class BaseSnowflakeConfig(BaseTimeWindowConfig):
|
||||
},
|
||||
)
|
||||
|
||||
def get_sql_alchemy_connect_args(self) -> dict:
    """Build the ``connect_args`` passed to ``create_engine``.

    Returns:
        An empty dict unless key pair authentication is configured.
        Otherwise ``self.connect_args``, which is filled on first use with
        ``{"private_key": <DER-encoded PKCS8 bytes>}`` and reused afterwards.

    Raises:
        ValueError: If key pair authentication is configured but the
            private key path or the private key password is missing.
    """
    if self.authentication_type != KEY_PAIR_AUTHENTICATOR:
        return {}

    # Reuse arguments that are already present instead of re-reading the key.
    if self.connect_args is not None:
        return self.connect_args

    if self.private_key_path is None:
        raise ValueError("missing required private key path to read key from")
    if self.private_key_password is None:
        raise ValueError("missing required private key password")

    with open(self.private_key_path, "rb") as key_file:
        pem_bytes = key_file.read()
    passphrase = self.private_key_password.get_secret_value().encode()

    private_key = serialization.load_pem_private_key(
        pem_bytes,
        password=passphrase,
        backend=default_backend(),
    )

    # Re-serialize the decrypted key as unencrypted DER/PKCS8 bytes.
    der_bytes = private_key.private_bytes(
        encoding=serialization.Encoding.DER,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )

    self.connect_args = {"private_key": der_bytes}
    return self.connect_args
|
||||
|
||||
|
||||
class SnowflakeConfig(BaseSnowflakeConfig, SQLAlchemyConfig):
|
||||
database_pattern: AllowDenyPattern = AllowDenyPattern(
|
||||
@ -98,6 +164,9 @@ class SnowflakeConfig(BaseSnowflakeConfig, SQLAlchemyConfig):
|
||||
def get_sql_alchemy_url(self, database: str = None) -> str:
|
||||
return super().get_sql_alchemy_url(database=database)
|
||||
|
||||
def get_sql_alchemy_connect_args(self) -> dict:
    """Return the connect args produced by the parent implementation."""
    inherited_args = super().get_sql_alchemy_connect_args()
    return inherited_args
|
||||
|
||||
|
||||
class SnowflakeSource(SQLAlchemySource):
|
||||
config: SnowflakeConfig
|
||||
@ -116,7 +185,12 @@ class SnowflakeSource(SQLAlchemySource):
|
||||
def get_inspectors(self) -> Iterable[Inspector]:
|
||||
url = self.config.get_sql_alchemy_url(database=None)
|
||||
logger.debug(f"sql_alchemy_url={url}")
|
||||
db_listing_engine = create_engine(url, **self.config.options)
|
||||
|
||||
db_listing_engine = create_engine(
|
||||
url,
|
||||
connect_args=self.config.get_sql_alchemy_connect_args(),
|
||||
**self.config.options,
|
||||
)
|
||||
|
||||
for db_row in db_listing_engine.execute(text("SHOW DATABASES")):
|
||||
db = db_row.name
|
||||
@ -125,7 +199,9 @@ class SnowflakeSource(SQLAlchemySource):
|
||||
# they are isolated from each other.
|
||||
self.current_database = db
|
||||
engine = create_engine(
|
||||
self.config.get_sql_alchemy_url(database=db), **self.config.options
|
||||
self.config.get_sql_alchemy_url(database=db),
|
||||
connect_args=self.config.get_sql_alchemy_connect_args(),
|
||||
**self.config.options,
|
||||
)
|
||||
|
||||
with engine.connect() as conn:
|
||||
@ -182,7 +258,11 @@ QUALIFY ROW_NUMBER() OVER (PARTITION BY downstream_table_name ORDER BY query_sta
|
||||
def _populate_lineage(self) -> None:
|
||||
url = self.config.get_sql_alchemy_url()
|
||||
logger.debug(f"sql_alchemy_url={url}")
|
||||
engine = create_engine(url, **self.config.options)
|
||||
engine = create_engine(
|
||||
url,
|
||||
connect_args=self.config.get_sql_alchemy_connect_args(),
|
||||
**self.config.options,
|
||||
)
|
||||
query: str = """
|
||||
WITH table_lineage_history AS (
|
||||
SELECT
|
||||
|
||||
@ -287,7 +287,11 @@ class SnowflakeUsageSource(StatefulIngestionSourceBase):
|
||||
def _make_sql_engine(self) -> Engine:
|
||||
url = self.config.get_sql_alchemy_url()
|
||||
logger.debug(f"sql_alchemy_url={url}")
|
||||
engine = create_engine(url, **self.config.options)
|
||||
engine = create_engine(
|
||||
url,
|
||||
connect_args=self.config.get_sql_alchemy_connect_args(),
|
||||
**self.config.options,
|
||||
)
|
||||
return engine
|
||||
|
||||
def _get_snowflake_history(self) -> Iterable[SnowflakeJoinedAccessEvent]:
|
||||
|
||||
@ -2,7 +2,7 @@ import pytest
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_snowflake_uri():
|
||||
def test_snowflake_uri_default_authentication():
|
||||
from datahub.ingestion.source.sql.snowflake import SnowflakeConfig
|
||||
|
||||
config = SnowflakeConfig.parse_obj(
|
||||
@ -15,7 +15,55 @@ def test_snowflake_uri():
|
||||
"role": "sysadmin",
|
||||
}
|
||||
)
|
||||
|
||||
assert (
|
||||
config.get_sql_alchemy_url()
|
||||
== "snowflake://user:password@acctname/?warehouse=COMPUTE_WH&role=sysadmin&application=acryl_datahub"
|
||||
== "snowflake://user:password@acctname/?authenticator=DEFAULT_AUTHENTICATOR&warehouse=COMPUTE_WH&role"
|
||||
"=sysadmin&application=acryl_datahub"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.integration
def test_snowflake_uri_external_browser_authentication():
    """The URL for external browser auth carries authenticator=EXTERNALBROWSER."""
    from datahub.ingestion.source.sql.snowflake import SnowflakeConfig

    recipe = {
        "username": "user",
        "host_port": "acctname",
        "database": "demo",
        "warehouse": "COMPUTE_WH",
        "role": "sysadmin",
        "authentication_type": "EXTERNAL_BROWSER_AUTHENTICATOR",
    }
    expected_url = (
        "snowflake://user@acctname/?authenticator=EXTERNALBROWSER&warehouse=COMPUTE_WH&role"
        "=sysadmin&application=acryl_datahub"
    )

    actual_url = SnowflakeConfig.parse_obj(recipe).get_sql_alchemy_url()

    assert actual_url == expected_url
|
||||
|
||||
|
||||
@pytest.mark.integration
def test_snowflake_uri_key_pair_authentication():
    """The URL for key pair auth carries authenticator=SNOWFLAKE_JWT."""
    from datahub.ingestion.source.sql.snowflake import SnowflakeConfig

    recipe = {
        "username": "user",
        "host_port": "acctname",
        "database": "demo",
        "warehouse": "COMPUTE_WH",
        "role": "sysadmin",
        "authentication_type": "KEY_PAIR_AUTHENTICATOR",
        "private_key_path": "/a/random/path",
        "private_key_password": "a_random_password",
    }
    expected_url = (
        "snowflake://user@acctname/?authenticator=SNOWFLAKE_JWT&warehouse=COMPUTE_WH&role"
        "=sysadmin&application=acryl_datahub"
    )

    actual_url = SnowflakeConfig.parse_obj(recipe).get_sql_alchemy_url()

    assert actual_url == expected_url
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user