Fix Clickhouse Types (#10295)

This commit is contained in:
Mayur Singal 2023-02-23 14:36:15 +05:30 committed by GitHub
parent ef30577ace
commit 50af4990e2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 63 additions and 24 deletions

View File

@ -9,7 +9,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Clickhouse source module"""
import enum
from clickhouse_sqlalchemy.drivers.base import ClickHouseDialect, ischema_names
from clickhouse_sqlalchemy.drivers.http.transport import RequestsTransport, _get_type
@ -50,16 +49,42 @@ class AggregateFunction(String):
__visit_name__ = "AggregateFunction"
class Map(sqltypes.UserDefinedType): # pylint: disable=abstract-method
    """User-defined SQLAlchemy type exposed under the visit name ``"Map"``.

    ``_get_column_type`` registers this class in ``ischema_names`` under the
    ``"Map"`` key. No behaviour is added beyond the visit name.
    """

    __visit_name__ = "Map"
class Array(sqltypes.UserDefinedType): # pylint: disable=abstract-method
    """User-defined SQLAlchemy type exposed under the visit name ``"Array"``.

    ``_get_column_type`` registers this class in ``ischema_names`` under the
    ``"Array"`` key. No behaviour is added beyond the visit name.
    """

    __visit_name__ = "Array"
class Tuple(sqltypes.UserDefinedType): # pylint: disable=abstract-method
    """User-defined SQLAlchemy type exposed under the visit name ``"Tuple"``.

    ``_get_column_type`` registers this class in ``ischema_names`` under the
    ``"Tuple"`` key. No behaviour is added beyond the visit name.
    """

    __visit_name__ = "Tuple"
class Enum(sqltypes.UserDefinedType): # pylint: disable=abstract-method
    """User-defined SQLAlchemy type exposed under the visit name ``"Enum"``.

    ``_get_column_type`` registers this class in ``ischema_names`` under the
    ``"Enum"`` key. No behaviour is added beyond the visit name.
    """

    __visit_name__ = "Enum"
@reflection.cache
def _get_column_type(
self, name, spec
): # pylint: disable=protected-access,too-many-branches,too-many-return-statements
ischema_names.update({"AggregateFunction": AggregateFunction})
ischema_names.update(
{
"AggregateFunction": AggregateFunction,
"Map": Map,
"Array": Array,
"Tuple": Tuple,
"Enum": Enum,
}
)
ClickHouseDialect.ischema_names = ischema_names
if spec.startswith("Array"):
inner = spec[6:-1]
coltype = self.ischema_names["_array"]
return coltype(self._get_column_type(name, inner))
return self.ischema_names["Array"]
if spec.startswith("FixedString"):
return self.ischema_names["FixedString"]
@ -75,29 +100,13 @@ def _get_column_type(
return coltype(self._get_column_type(name, inner))
if spec.startswith("Tuple"):
inner = spec[6:-1]
coltype = self.ischema_names["_tuple"]
inner_types = [self._get_column_type(name, t.strip()) for t in inner.split(",")]
return coltype(*inner_types)
return self.ischema_names["Tuple"]
if spec.startswith("Map"):
inner = spec[4:-1]
coltype = self.ischema_names["_map"]
inner_types = [self._get_column_type(name, t.strip()) for t in inner.split(",")]
return coltype(*inner_types)
return self.ischema_names["Map"]
if spec.startswith("Enum"):
pos = spec.find("(")
coltype = self.ischema_names[spec[:pos]]
options = {}
if pos >= 0:
options = self._parse_options(spec[pos + 1 : spec.rfind(")")])
if not options:
return sqltypes.NullType
type_enum = enum.Enum(f"{name}_enum", options)
return lambda: coltype(type_enum)
return self.ischema_names["Enum"]
if spec.startswith("DateTime64"):
return self.ischema_names["DateTime64"]
@ -189,6 +198,32 @@ def get_table_comment(
)
def _get_column_info(
    self, name, format_type, default_type, default_expression, comment
):
    """Assemble the reflection dictionary describing a single column.

    The column type is resolved through ``self._get_column_type`` and the
    default through ``self._get_column_default``.

    Returns a dict with the keys ``name``, ``type``, ``nullable``,
    ``default`` and ``comment``:

    * ``nullable`` is true when ``format_type`` starts with ``"Nullable("``.
    * ``comment`` is ``None`` whenever the incoming comment is falsy.
    * ``display_type`` is added when the resolved type is one of ``Map``,
      ``Array``, ``Tuple`` or ``Enum``; its value is the lower-cased
      ``format_type`` with parentheses swapped for angle brackets.
    * ``raw_data_type`` carries that same string, for ``Array`` only.
    """
    resolved_type = self._get_column_type(  # pylint: disable=protected-access
        name, format_type
    )
    resolved_default = self._get_column_default(  # pylint: disable=protected-access
        default_type, default_expression
    )

    column_info = {
        "name": name,
        "type": resolved_type,
        "nullable": format_type.startswith("Nullable("),
        "default": resolved_default,
        "comment": comment or None,
    }

    # e.g. "Array(String)" becomes "array<string>"
    bracket_swap = str.maketrans("()", "<>")
    angle_bracket_type = format_type.lower().translate(bracket_swap)

    if resolved_type in (Map, Array, Tuple, Enum):
        column_info["display_type"] = angle_bracket_type
    if resolved_type == Array:
        column_info["raw_data_type"] = angle_bracket_type

    return column_info
ClickHouseDialect.get_unique_constraints = get_unique_constraints
ClickHouseDialect.get_pk_constraint = get_pk_constraint
ClickHouseDialect._get_column_type = ( # pylint: disable=protected-access
@ -199,6 +234,9 @@ ClickHouseDialect.get_view_definition = get_view_definition
# Attach this module's functions to ClickHouseDialect as class attributes so
# that the dialect uses these implementations for the named reflection hooks.
ClickHouseDialect.get_table_comment = get_table_comment
ClickHouseDialect.get_all_view_definitions = get_all_view_definitions
ClickHouseDialect.get_all_table_comments = get_all_table_comments
# Replace the dialect's private column-info hook with the module-level
# ``_get_column_info`` function.
ClickHouseDialect._get_column_info = ( # pylint: disable=protected-access
_get_column_info
)
class ClickhouseSource(CommonDbSourceService):

View File

@ -94,6 +94,7 @@ class SqlColumnHandlerMixin:
data_type_display = column["type"]
if col_type == DataType.ARRAY.value and not arr_data_type:
arr_data_type = DataType.VARCHAR.value
data_type_display = data_type_display or column.get("display_type")
return data_type_display, arr_data_type, parsed_string
@staticmethod