diff --git a/ingestion/setup.py b/ingestion/setup.py
index 8123fa59569..7c48bf2225c 100644
--- a/ingestion/setup.py
+++ b/ingestion/setup.py
@@ -85,7 +85,7 @@ plugins: Dict[str, Set[str]] = {
     "glue": {"boto3~=1.19.12"},
     "dynamodb": {"boto3~=1.19.12"},
     "hive": {
-        "pyhive~=0.6.3",
+        "pyhive~=0.6.5",
         "thrift~=0.13.0",
         "sasl==0.3.1",
         "thrift-sasl==0.4.3",
diff --git a/ingestion/src/metadata/ingestion/source/database/hive.py b/ingestion/src/metadata/ingestion/source/database/hive.py
index 0a106c4ad6e..21064e5e0ad 100644
--- a/ingestion/src/metadata/ingestion/source/database/hive.py
+++ b/ingestion/src/metadata/ingestion/source/database/hive.py
@@ -10,10 +10,12 @@
 # limitations under the License.
 
 import re
+from typing import Tuple
 
 from pyhive.sqlalchemy_hive import HiveDialect, _type_map
 from sqlalchemy import types, util
 
+from metadata.generated.schema.entity.data.table import Table, TablePartition, TableType
 from metadata.generated.schema.entity.services.connections.database.hiveConnection import (
     HiveConnection,
 )
@@ -83,6 +85,23 @@ def get_columns(self, connection, table_name, schema=None, **kw):
     return result
 
 
+def get_table_names_older_versions(self, connection, schema=None, **kw):
+    query = "SHOW TABLES"
+    if schema:
+        query += " IN " + self.identifier_preparer.quote_identifier(schema)
+    tables_in_schema = connection.execute(query)
+    tables = []
+    for row in tables_in_schema:
+        # check number of columns in result
+        # if it is > 1, we use spark thrift server with 3 columns in the result (schema, table, is_temporary)
+        # else it is hive with 1 column in the result
+        if len(row) > 1:
+            tables.append(row[1])
+        else:
+            tables.append(row[0])
+    return tables
+
+
 def get_table_names(self, connection, schema=None, **kw):
     query = "SHOW TABLES"
     if schema:
@@ -97,7 +116,9 @@ def get_table_names(self, connection, schema=None, **kw):
             tables.append(row[1])
         else:
             tables.append(row[0])
-    views = get_view_names(self, connection, schema)
+    # "SHOW TABLES" command in hive also fetches view names
+    # Below code filters out view names from table names
+    views = self.get_view_names(connection, schema)
     return [table for table in tables if table not in views]
 
 
@@ -118,9 +139,16 @@ def get_view_names(self, connection, schema=None, **kw):
     return views
 
 
+def get_view_names_older_versions(self, connection, schema=None, **kw):
+    # Hive does not provide functionality to query tableType for older version
+    # This allows reflection to not crash at the cost of being inaccurate
+    return []
+
+
 HiveDialect.get_columns = get_columns
-HiveDialect.get_table_names = get_table_names
-HiveDialect.get_view_names = get_view_names
+
+
+HIVE_VERSION_WITH_VIEW_SUPPORT = "2.2.0"
 
 
 class HiveSource(CommonDbSourceService):
@@ -133,3 +161,38 @@ class HiveSource(CommonDbSourceService):
                 f"Expected HiveConnection, but got {connection}"
             )
         return cls(config, metadata_config)
+
+    def _parse_version(self, version: str) -> Tuple:
+        """
+        Turn the leading dotted-numeric part of a version string into a
+        tuple of ints, e.g. "2.1.1-cdh6.3.2" -> (2, 1, 1).
+        Returns an empty tuple when the string does not start with a digit.
+        """
+        match = re.match(r"\d+(\.\d+)*", version)
+        return tuple(map(int, match.group().split("."))) if match else ()
+
+    def prepare(self):
+        """
+        Based on the version of hive update the get_table_names method
+        Fetching views in hive server with query "SHOW VIEWS" was possible
+        only after hive 2.2.0 version
+        When the version cannot be determined the older-version handlers
+        are installed so that ingestion does not crash
+        """
+        from sqlalchemy.exc import SQLAlchemyError
+
+        try:
+            row = self.engine.execute("SELECT VERSION()").fetchone()
+        except SQLAlchemyError:
+            # older hive servers do not ship the version() UDF at all;
+            # treat a failing query as "version unknown"
+            row = None
+        # only the first whitespace-separated token carries the version number
+        tokens = str(row[0]).split() if row else []
+        version = self._parse_version(tokens[0]) if tokens else ()
+        if version >= self._parse_version(HIVE_VERSION_WITH_VIEW_SUPPORT):
+            HiveDialect.get_table_names = get_table_names
+            HiveDialect.get_view_names = get_view_names
+        else:
+            HiveDialect.get_table_names = get_table_names_older_versions
+            HiveDialect.get_view_names = get_view_names_older_versions