diff --git a/metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py b/metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py
index 421991a096..95ca10045f 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py
@@ -34,7 +34,7 @@ class AwsAssumeRoleConfig(PermissiveConfigModel):
 
 def assume_role(
     role: AwsAssumeRoleConfig,
-    aws_region: str,
+    aws_region: Optional[str],
     credentials: Optional[dict] = None,
 ) -> dict:
     credentials = credentials or {}
@@ -93,7 +93,7 @@ class AwsConnectionConfig(ConfigModel):
         default=None,
         description="Named AWS profile to use. Only used if access key / secret are unset. If not set the default will be used",
     )
-    aws_region: str = Field(description="AWS region code.")
+    aws_region: Optional[str] = Field(None, description="AWS region code.")
     aws_endpoint_url: Optional[str] = Field(
         default=None,
diff --git a/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker.py b/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker.py
index 6f6e8bbc05..e335174eeb 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker.py
@@ -82,7 +82,7 @@ class SagemakerSource(Source):
             env=self.env,
             report=self.report,
             job_type_filter=self.source_config.extract_jobs,
-            aws_region=self.source_config.aws_region,
+            aws_region=self.sagemaker_client.meta.region_name,
         )
         yield from job_processor.get_workunits()
 
@@ -98,7 +98,7 @@ class SagemakerSource(Source):
             model_image_to_jobs=model_image_to_jobs,
             model_name_to_jobs=model_name_to_jobs,
             lineage=lineage,
-            aws_region=self.source_config.aws_region,
+            aws_region=self.sagemaker_client.meta.region_name,
         )
         yield from model_processor.get_workunits()
diff --git a/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_core.py b/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_core.py
index 6fd3c5ba30..a2f96264b7 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_core.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_core.py
@@ -81,7 +81,7 @@ class DBTCoreConfig(DBTCommonConfig):
             if (values.get(f) or "").startswith("s3://")
         ]
 
-        if uri_containing_fields and not aws_connection:
+        if uri_containing_fields and aws_connection is None:
             raise ValueError(
                 f"Please provide aws_connection configuration, since s3 uris have been provided in fields {uri_containing_fields}"
             )
diff --git a/metadata-ingestion/tests/unit/test_dbt_source.py b/metadata-ingestion/tests/unit/test_dbt_source.py
index 0fbe9ecbcc..737cf6aca3 100644
--- a/metadata-ingestion/tests/unit/test_dbt_source.py
+++ b/metadata-ingestion/tests/unit/test_dbt_source.py
@@ -1,6 +1,7 @@
 from typing import Dict, List, Union
 from unittest import mock
 
+import pytest
 from pydantic import ValidationError
 
 from datahub.emitter import mce_builder
@@ -180,14 +181,12 @@ def test_dbt_entity_emission_configuration():
         "target_platform": "dummy_platform",
         "entities_enabled": {"models": "Only", "seeds": "Only"},
     }
-    try:
+    with pytest.raises(
+        ValidationError,
+        match="Cannot have more than 1 type of entity emission set to ONLY",
+    ):
         DBTCoreConfig.parse_obj(config_dict)
-    except ValidationError as ve:
-        assert len(ve.errors()) == 1
-        assert (
-            "Cannot have more than 1 type of entity emission set to ONLY"
-            in ve.errors()[0]["msg"]
-        )
+
     # valid config
     config_dict = {
         "manifest_path": "dummy_path",
@@ -198,6 +197,26 @@
     DBTCoreConfig.parse_obj(config_dict)
 
 
+def test_dbt_s3_config():
+    # test missing aws config
+    config_dict: dict = {
+        "manifest_path": "s3://dummy_path",
+        "catalog_path": "s3://dummy_path",
+        "target_platform": "dummy_platform",
+    }
+    with pytest.raises(ValidationError, match="provide aws_connection"):
+        DBTCoreConfig.parse_obj(config_dict)
+
+    # valid config
+    config_dict = {
+        "manifest_path": "s3://dummy_path",
+        "catalog_path": "s3://dummy_path",
+        "target_platform": "dummy_platform",
+        "aws_connection": {},
+    }
+    DBTCoreConfig.parse_obj(config_dict)
+
+
 def test_default_convert_column_urns_to_lowercase():
     config_dict = {
         "manifest_path": "dummy_path",
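Note (not part of the patch): a minimal sketch of the boto3 behaviour the sagemaker.py change above relies on. When aws_region is omitted, boto3 resolves a region from the environment (AWS_REGION / AWS_DEFAULT_REGION) or the active profile in ~/.aws/config, and exposes the resolved value on the client as client.meta.region_name; if no region can be resolved, client creation fails with botocore.exceptions.NoRegionError.

import boto3

# Region deliberately omitted: boto3 falls back to AWS_REGION /
# AWS_DEFAULT_REGION or the region of the active profile in ~/.aws/config.
session = boto3.session.Session()
sagemaker_client = session.client("sagemaker")

# This is the value the SageMaker source now reads instead of
# source_config.aws_region.
print(sagemaker_client.meta.region_name)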