From 12e8a1fcf61ad0cd0392a33a37acbf5dd2788287 Mon Sep 17 00:00:00 2001
From: Abhishek Pandey
Date: Mon, 30 May 2022 18:27:26 +0530
Subject: [PATCH] required-fields-updated-in-snowflake-and-athena (#5143)

required-fields-updated-in-snowflake-and-athena (#5143)
---
 .../connections/database/athenaConnection.json      |  8 ++------
 .../connections/database/snowflakeConnection.json   |  2 +-
 .../ingestion/source/database/sql_column_handler.py |  2 +-
 .../ingestion/source/database/sqlalchemy_source.py  | 13 +++++++++----
 ingestion/src/metadata/utils/source_connections.py  |  9 +++++----
 ingestion/tests/unit/test_source_connection.py      | 11 +++++------
 6 files changed, 23 insertions(+), 22 deletions(-)

diff --git a/catalog-rest-service/src/main/resources/json/schema/entity/services/connections/database/athenaConnection.json b/catalog-rest-service/src/main/resources/json/schema/entity/services/connections/database/athenaConnection.json
index 369a2ed11ba..8234ec980d7 100644
--- a/catalog-rest-service/src/main/resources/json/schema/entity/services/connections/database/athenaConnection.json
+++ b/catalog-rest-service/src/main/resources/json/schema/entity/services/connections/database/athenaConnection.json
@@ -41,11 +41,6 @@
       "description": "Host and port of the Athena service.",
       "type": "string"
     },
-    "database": {
-      "title": "Database",
-      "description": "Database of the data source. This is optional parameter, if you would like to restrict the metadata reading to a single database. When left blank, OpenMetadata Ingestion attempts to scan all the databases.",
-      "type": "string"
-    },
     "s3StagingDir": {
       "title": "S3 Staging Directory",
       "description": "S3 Staging Directory.",
@@ -73,5 +68,6 @@
       "$ref": "../connectionBasicType.json#/definitions/supportsProfiler"
     }
   },
-  "additionalProperties": false
+  "additionalProperties": false,
+  "required": ["s3StagingDir", "awsConfig", "workgroup"]
 }
diff --git a/catalog-rest-service/src/main/resources/json/schema/entity/services/connections/database/snowflakeConnection.json b/catalog-rest-service/src/main/resources/json/schema/entity/services/connections/database/snowflakeConnection.json
index 379ff893fe5..eeccd3aa8d0 100644
--- a/catalog-rest-service/src/main/resources/json/schema/entity/services/connections/database/snowflakeConnection.json
+++ b/catalog-rest-service/src/main/resources/json/schema/entity/services/connections/database/snowflakeConnection.json
@@ -96,5 +96,5 @@
     }
   },
   "additionalProperties": false,
-  "required": ["username", "account"]
+  "required": ["username", "account", "password", "warehouse"]
 }
diff --git a/ingestion/src/metadata/ingestion/source/database/sql_column_handler.py b/ingestion/src/metadata/ingestion/source/database/sql_column_handler.py
index 1a28484fcdd..3b83ddae751 100644
--- a/ingestion/src/metadata/ingestion/source/database/sql_column_handler.py
+++ b/ingestion/src/metadata/ingestion/source/database/sql_column_handler.py
@@ -137,7 +137,7 @@ class SqlColumnHandler:
         )
         table_columns = []
         columns = inspector.get_columns(
-            table, schema, db_name=self.service_connection.database
+            table, schema, db_name=self._get_database_name()
         )
         for column in columns:
             try:
diff --git a/ingestion/src/metadata/ingestion/source/database/sqlalchemy_source.py b/ingestion/src/metadata/ingestion/source/database/sqlalchemy_source.py
index 7d3b16e4c8d..08fb9eb41bf 100644
--- a/ingestion/src/metadata/ingestion/source/database/sqlalchemy_source.py
+++ b/ingestion/src/metadata/ingestion/source/database/sqlalchemy_source.py
@@ -127,12 +127,17 @@ class SqlAlchemySource(Source, ABC):
         Method to fetch tags associated with table
         """
 
-    def get_database_entity(self, database_name: Optional[str]) -> Database:
+    def _get_database_name(self) -> str:
+        if hasattr(self.service_connection, "database"):
+            return self.service_connection.database or "default"
+        return "default"
+
+    def get_database_entity(self) -> Database:
         """
         Method to get database enetity from db name
         """
         return Database(
-            name=database_name if database_name else "default",
+            name=self._get_database_name(),
             service=EntityReference(
                 id=self.service.id, type=self.service_connection.type.value
             ),
@@ -173,7 +178,7 @@ class SqlAlchemySource(Source, ABC):
                     self.metadata,
                     entity_type=DatabaseSchema,
                     service_name=self.config.serviceName,
-                    database_name=self.service_connection.database,
+                    database_name=self._get_database_name(),
                     schema_name=schema,
                 )
                 yield from self.delete_tables(schema_fqn)
@@ -253,7 +258,7 @@ class SqlAlchemySource(Source, ABC):
                     schema, table_name, table_type, inspector
                 )
 
-                database = self.get_database_entity(self.service_connection.database)
+                database = self.get_database_entity()
                 # check if we have any model to associate with
                 table_entity.dataModel = self.get_data_model(
                     database.name.__root__, schema, table_name
diff --git a/ingestion/src/metadata/utils/source_connections.py b/ingestion/src/metadata/utils/source_connections.py
index de8220eec86..25490a587d4 100644
--- a/ingestion/src/metadata/utils/source_connections.py
+++ b/ingestion/src/metadata/utils/source_connections.py
@@ -97,7 +97,8 @@ def get_connection_url_common(connection):
         url += "@"
 
     url += connection.hostPort
-    url += f"/{connection.database}" if connection.database else ""
+    if hasattr(connection, "database"):
+        url += f"/{connection.database}" if connection.database else ""
 
     options = (
         connection.connectionOptions.dict()
@@ -347,10 +348,10 @@ def _(connection: AthenaConnection):
     else:
         url += ":"
     url += f"@athena.{connection.awsConfig.awsRegion}.amazonaws.com:443"
-    if connection.database:
-        url += f"/{connection.database}"
+
     url += f"?s3_staging_dir={quote_plus(connection.s3StagingDir)}"
-    url += f"&work_group={connection.workgroup}"
+    if connection.workgroup:
+        url += f"&work_group={connection.workgroup}"
     return url
 
 
diff --git a/ingestion/tests/unit/test_source_connection.py b/ingestion/tests/unit/test_source_connection.py
index f7f80ad4ff5..2cb8ff5ac8b 100644
--- a/ingestion/tests/unit/test_source_connection.py
+++ b/ingestion/tests/unit/test_source_connection.py
@@ -625,9 +625,9 @@ class SouceConnectionTest(TestCase):
         expected_args = {}
         snowflake_conn_obj = SnowflakeConnection(
             username="user",
-            password=None,
+            password="test-pwd",
             database="tiny",
-            connectionArguments=None,
+            warehouse="COMPUTE_WH",
             scheme=SnowflakeScheme.snowflake,
             account="account.region_name.cloud_service",
         )
@@ -637,8 +637,9 @@ class SouceConnectionTest(TestCase):
         expected_args = {"user": "user-to-be-impersonated"}
         snowflake_conn_obj = SnowflakeConnection(
             username="user",
-            password=None,
+            password="test-pwd",
             database="tiny",
+            warehouse="COMPUTE_WH",
             connectionArguments={"user": "user-to-be-impersonated"},
             scheme=SnowflakeScheme.snowflake,
             account="account.region_name.cloud_service",
@@ -657,18 +658,16 @@ class SouceConnectionTest(TestCase):
             s3StagingDir="s3athena-postgres",
             workgroup="primary",
             scheme=AthenaScheme.awsathena_rest,
-            database=None,
         )
         assert expected_url == get_connection_url(athena_conn_obj)
 
         # connection arguments witho db
-        expected_url = "awsathena+rest://key:secret_key@athena.us-east-2.amazonaws.com:443/test?s3_staging_dir=s3athena-postgres&work_group=primary"
+        expected_url = "awsathena+rest://key:secret_key@athena.us-east-2.amazonaws.com:443?s3_staging_dir=s3athena-postgres&work_group=primary"
         athena_conn_obj = AthenaConnection(
             awsConfig=awsCreds,
             s3StagingDir="s3athena-postgres",
            workgroup="primary",
             scheme=AthenaScheme.awsathena_rest,
-            database="test",
         )
         assert expected_url == get_connection_url(athena_conn_obj)
 