diff --git a/ingestion/examples/workflows/hive.json b/ingestion/examples/workflows/hive.json index 62fbd6e8af2..744a74a5c6c 100644 --- a/ingestion/examples/workflows/hive.json +++ b/ingestion/examples/workflows/hive.json @@ -1,20 +1,23 @@ { "source": { "type": "hive", - "config": { - "service_name": "local_hive", - "host_port": "localhost:10000" - } + "serviceName": "local_hive", + "serviceConnection": { + "config": { + "type":"Hive", + "hostPort": "localhost:10000" + } + }, + "sourceConfig": {"config": {"enableDataProfiler": false}} }, "sink": { "type": "metadata-rest", "config": {} }, - "metadata_server": { - "type": "metadata-server", - "config": { - "api_endpoint": "http://localhost:8585/api", - "auth_provider_type": "no-auth" + "workflowConfig": { + "openMetadataServerConfig": { + "hostPort": "http://localhost:8585/api", + "authProvider": "no-auth" } } } diff --git a/ingestion/src/metadata/ingestion/source/hive.py b/ingestion/src/metadata/ingestion/source/hive.py index 56d672374a9..93abd686831 100644 --- a/ingestion/src/metadata/ingestion/source/hive.py +++ b/ingestion/src/metadata/ingestion/source/hive.py @@ -15,7 +15,6 @@ from pyhive.sqlalchemy_hive import HiveDialect, _type_map from sqlalchemy import types, util from metadata.ingestion.source.sql_source import SQLSource -from metadata.ingestion.source.sql_source_common import SQLConnectionConfig complex_data_types = ["struct", "map", "array", "union"] @@ -58,21 +57,26 @@ HiveDialect.get_columns = get_columns from metadata.generated.schema.entity.services.connections.database.hiveConnection import ( HiveSQLConnection, ) - - -class HiveConfig(HiveSQLConnection, SQLConnectionConfig): - def get_connection_url(self): - url = super().get_connection_url() - if self.authOptions: - return f"{url};{self.authOptions}" - return url +from metadata.generated.schema.metadataIngestion.workflow import ( + OpenMetadataServerConfig, +) +from metadata.generated.schema.metadataIngestion.workflow import ( + Source as 
WorkflowSource, +) +from metadata.ingestion.api.source import InvalidSourceException class HiveSource(SQLSource): - def __init__(self, config, metadata_config): - super().__init__(config, metadata_config) + def prepare(self): + self.service_connection.database = "default" + return super().prepare() @classmethod def create(cls, config_dict, metadata_config: OpenMetadataServerConfig): - config = HiveConfig.parse_obj(config_dict) + config: WorkflowSource = WorkflowSource.parse_obj(config_dict) + connection: HiveSQLConnection = config.serviceConnection.__root__.config + if not isinstance(connection, HiveSQLConnection): + raise InvalidSourceException( + f"Expected HiveSQLConnection, but got {connection}" + ) return cls(config, metadata_config) diff --git a/ingestion/src/metadata/utils/source_connections.py b/ingestion/src/metadata/utils/source_connections.py index 87ec7cd3e63..046eb1f9e5f 100644 --- a/ingestion/src/metadata/utils/source_connections.py +++ b/ingestion/src/metadata/utils/source_connections.py @@ -22,6 +22,9 @@ from metadata.generated.schema.entity.services.connections.database.clickhouseCo from metadata.generated.schema.entity.services.connections.database.databricksConnection import ( DatabricksConnection, ) +from metadata.generated.schema.entity.services.connections.database.hiveConnection import ( + HiveSQLConnection, +) from metadata.generated.schema.entity.services.connections.database.mssqlConnection import ( MssqlConnection, ) @@ -192,4 +195,10 @@ def _(connection: SnowflakeConnection): ) url = f"{url}?{params}" return url +@get_connection_url.register +def _(connection: HiveSQLConnection): + url = get_connection_url_common(connection) + if connection.authOptions: + return f"{url};{connection.authOptions}" + return url diff --git a/ingestion/tests/unit/source_connection/test_databricks_connection.py b/ingestion/tests/unit/source_connection/test_databricks_connection.py deleted file mode 100644 index fe6ee0cb90a..00000000000 --- 
a/ingestion/tests/unit/source_connection/test_databricks_connection.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright 2021 Collate -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from unittest import TestCase - -from metadata.generated.schema.entity.services.connections.database.databricksConnection import ( - DatabricksConnection, - DatabricksScheme, -) -from metadata.utils.source_connections import get_connection_url - - -class DatabricksConnectionTest(TestCase): - def test_connection_url_without_db(self): - expected_result = ( - "databricks+connector://token:KlivDTACWXKmZVfN1qIM@1.1.1.1:443" - ) - databricks_conn_obj = DatabricksConnection( - scheme=DatabricksScheme.databricks_connector, - hostPort="1.1.1.1:443", - token="KlivDTACWXKmZVfN1qIM", - ) - assert expected_result == get_connection_url(databricks_conn_obj) - - def test_connection_url_with_db(self): - expected_result = ( - "databricks+connector://token:KlivDTACWXKmZVfN1qIM@1.1.1.1:443/default" - ) - databricks_conn_obj = DatabricksConnection( - scheme=DatabricksScheme.databricks_connector, - hostPort="1.1.1.1:443", - token="KlivDTACWXKmZVfN1qIM", - database="default", - ) - assert expected_result == get_connection_url(databricks_conn_obj) diff --git a/ingestion/tests/unit/source_connection/test_trino_connection.py b/ingestion/tests/unit/test_source_connection.py similarity index 50% rename from ingestion/tests/unit/source_connection/test_trino_connection.py rename to ingestion/tests/unit/test_source_connection.py 
index c59c0cf8890..b1a8c5d3a14 100644 --- a/ingestion/tests/unit/source_connection/test_trino_connection.py +++ b/ingestion/tests/unit/test_source_connection.py @@ -12,6 +12,14 @@ from unittest import TestCase +from metadata.generated.schema.entity.services.connections.database.databricksConnection import ( + DatabricksConnection, + DatabricksScheme, +) +from metadata.generated.schema.entity.services.connections.database.hiveConnection import ( + HiveScheme, + HiveSQLConnection, +) from metadata.generated.schema.entity.services.connections.database.trinoConnection import ( TrinoConnection, TrinoScheme, @@ -19,8 +27,48 @@ from metadata.generated.schema.entity.services.connections.database.trinoConnect from metadata.utils.source_connections import get_connection_args, get_connection_url -class TrinoConnectionTest(TestCase): - def test_connection_url_without_params(self): +class SourceConnectionTest(TestCase): + def test_databricks_url_without_db(self): + expected_result = ( + "databricks+connector://token:KlivDTACWXKmZVfN1qIM@1.1.1.1:443" + ) + databricks_conn_obj = DatabricksConnection( + scheme=DatabricksScheme.databricks_connector, + hostPort="1.1.1.1:443", + token="KlivDTACWXKmZVfN1qIM", + ) + assert expected_result == get_connection_url(databricks_conn_obj) + + def test_databricks_url_with_db(self): + expected_result = ( + "databricks+connector://token:KlivDTACWXKmZVfN1qIM@1.1.1.1:443/default" + ) + databricks_conn_obj = DatabricksConnection( + scheme=DatabricksScheme.databricks_connector, + hostPort="1.1.1.1:443", + token="KlivDTACWXKmZVfN1qIM", + database="default", + ) + assert expected_result == get_connection_url(databricks_conn_obj) + + def test_hive_url(self): + expected_result = "hive://localhost:10000/default" + databricks_conn_obj = HiveSQLConnection( + scheme=HiveScheme.hive, hostPort="localhost:10000", database="default" + ) + assert expected_result == get_connection_url(databricks_conn_obj) + + def test_hive_url_auth(self): + expected_result = 
"hive://localhost:10000/default;auth=CUSTOM" + databricks_conn_obj = HiveSQLConnection( + scheme=HiveScheme.hive, + hostPort="localhost:10000", + database="default", + authOptions="auth=CUSTOM", + ) + assert expected_result == get_connection_url(databricks_conn_obj) + + def test_trino_url_without_params(self): expected_url = "trino://username:pass@localhost:443/catalog" trino_conn_obj = TrinoConnection( scheme=TrinoScheme.trino, @@ -31,7 +79,7 @@ class TrinoConnectionTest(TestCase): ) assert expected_url == get_connection_url(trino_conn_obj) - def test_connection_url_with_params(self): + def test_trino_url_with_params(self): expected_url = "trino://username:pass@localhost:443/catalog?param=value" trino_conn_obj = TrinoConnection( scheme=TrinoScheme.trino, @@ -43,7 +91,7 @@ class TrinoConnectionTest(TestCase): ) assert expected_url == get_connection_url(trino_conn_obj) - def test_connection_with_proxies(self): + def test_trino_with_proxies(self): test_proxies = {"http": "http_proxy", "https": "https_proxy"} trino_conn_obj = TrinoConnection( scheme=TrinoScheme.trino,