feat(ingest/dbt): support aws config without region (#9650)

Co-authored-by: Tamas Nemeth <treff7es@gmail.com>
This commit is contained in:
Harshal Sheth 2024-01-24 14:29:41 -08:00 committed by GitHub
parent c80383dd1a
commit 9b051e38d6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 31 additions and 12 deletions

View File

@@ -34,7 +34,7 @@ class AwsAssumeRoleConfig(PermissiveConfigModel):
def assume_role( def assume_role(
role: AwsAssumeRoleConfig, role: AwsAssumeRoleConfig,
aws_region: str, aws_region: Optional[str],
credentials: Optional[dict] = None, credentials: Optional[dict] = None,
) -> dict: ) -> dict:
credentials = credentials or {} credentials = credentials or {}
@@ -93,7 +93,7 @@ class AwsConnectionConfig(ConfigModel):
default=None, default=None,
description="Named AWS profile to use. Only used if access key / secret are unset. If not set the default will be used", description="Named AWS profile to use. Only used if access key / secret are unset. If not set the default will be used",
) )
aws_region: str = Field(description="AWS region code.") aws_region: Optional[str] = Field(None, description="AWS region code.")
aws_endpoint_url: Optional[str] = Field( aws_endpoint_url: Optional[str] = Field(
default=None, default=None,

View File

@@ -82,7 +82,7 @@ class SagemakerSource(Source):
env=self.env, env=self.env,
report=self.report, report=self.report,
job_type_filter=self.source_config.extract_jobs, job_type_filter=self.source_config.extract_jobs,
aws_region=self.source_config.aws_region, aws_region=self.sagemaker_client.meta.region_name,
) )
yield from job_processor.get_workunits() yield from job_processor.get_workunits()
@@ -98,7 +98,7 @@ class SagemakerSource(Source):
model_image_to_jobs=model_image_to_jobs, model_image_to_jobs=model_image_to_jobs,
model_name_to_jobs=model_name_to_jobs, model_name_to_jobs=model_name_to_jobs,
lineage=lineage, lineage=lineage,
aws_region=self.source_config.aws_region, aws_region=self.sagemaker_client.meta.region_name,
) )
yield from model_processor.get_workunits() yield from model_processor.get_workunits()

View File

@@ -81,7 +81,7 @@ class DBTCoreConfig(DBTCommonConfig):
if (values.get(f) or "").startswith("s3://") if (values.get(f) or "").startswith("s3://")
] ]
if uri_containing_fields and not aws_connection: if uri_containing_fields and aws_connection is None:
raise ValueError( raise ValueError(
f"Please provide aws_connection configuration, since s3 uris have been provided in fields {uri_containing_fields}" f"Please provide aws_connection configuration, since s3 uris have been provided in fields {uri_containing_fields}"
) )

View File

@@ -1,6 +1,7 @@
from typing import Dict, List, Union from typing import Dict, List, Union
from unittest import mock from unittest import mock
import pytest
from pydantic import ValidationError from pydantic import ValidationError
from datahub.emitter import mce_builder from datahub.emitter import mce_builder
@@ -180,14 +181,12 @@ def test_dbt_entity_emission_configuration():
"target_platform": "dummy_platform", "target_platform": "dummy_platform",
"entities_enabled": {"models": "Only", "seeds": "Only"}, "entities_enabled": {"models": "Only", "seeds": "Only"},
} }
try: with pytest.raises(
ValidationError,
match="Cannot have more than 1 type of entity emission set to ONLY",
):
DBTCoreConfig.parse_obj(config_dict) DBTCoreConfig.parse_obj(config_dict)
except ValidationError as ve:
assert len(ve.errors()) == 1
assert (
"Cannot have more than 1 type of entity emission set to ONLY"
in ve.errors()[0]["msg"]
)
# valid config # valid config
config_dict = { config_dict = {
"manifest_path": "dummy_path", "manifest_path": "dummy_path",
@@ -198,6 +197,26 @@ def test_dbt_entity_emission_configuration():
DBTCoreConfig.parse_obj(config_dict) DBTCoreConfig.parse_obj(config_dict)
def test_dbt_s3_config():
# test missing aws config
config_dict: dict = {
"manifest_path": "s3://dummy_path",
"catalog_path": "s3://dummy_path",
"target_platform": "dummy_platform",
}
with pytest.raises(ValidationError, match="provide aws_connection"):
DBTCoreConfig.parse_obj(config_dict)
# valid config
config_dict = {
"manifest_path": "s3://dummy_path",
"catalog_path": "s3://dummy_path",
"target_platform": "dummy_platform",
"aws_connection": {},
}
DBTCoreConfig.parse_obj(config_dict)
def test_default_convert_column_urns_to_lowercase(): def test_default_convert_column_urns_to_lowercase():
config_dict = { config_dict = {
"manifest_path": "dummy_path", "manifest_path": "dummy_path",