FIX #19798 - Shortening SQA __tablename__ to avoid hitting errors in postgres (#19809)

* FIX #19798 - Shortening SQA __tablename__ to avoid hitting errors in postgres

* fix tests

---------

Co-authored-by: Sriharsha Chintalapani <harshach@users.noreply.github.com>
This commit is contained in:
Pere Miquel Brull 2025-02-17 09:37:06 +01:00 committed by GitHub
parent 3ae2ff1c1c
commit 91b62fdc32
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 41 additions and 17 deletions

View File

@ -12,6 +12,7 @@
Helper module to handle data sampling Helper module to handle data sampling
for the profiler for the profiler
""" """
import hashlib
import traceback import traceback
from typing import List, Optional, Union, cast from typing import List, Optional, Union, cast
@ -32,6 +33,7 @@ from metadata.profiler.orm.functions.modulo import ModuloFn
from metadata.profiler.orm.functions.random_num import RandomNumFn from metadata.profiler.orm.functions.random_num import RandomNumFn
from metadata.profiler.processor.handle_partition import build_partition_predicate from metadata.profiler.processor.handle_partition import build_partition_predicate
from metadata.sampler.sampler_interface import SamplerInterface from metadata.sampler.sampler_interface import SamplerInterface
from metadata.utils.constants import UTF_8
from metadata.utils.helpers import is_safe_sql_query from metadata.utils.helpers import is_safe_sql_query
from metadata.utils.logger import profiler_interface_registry_logger from metadata.utils.logger import profiler_interface_registry_logger
@ -109,17 +111,28 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin):
query = self.get_partitioned_query(query) query = self.get_partitioned_query(query)
return query return query
def get_sampler_table_name(self) -> str:
    """Return a short, fixed-length base name for the sampling CTEs.

    The raw dataset's ``__tablename__`` is encoded with ``UTF_8`` and hashed
    with MD5; the hexadecimal digest is returned. Callers append suffixes
    such as ``_rnd``, ``_sample``, ``_partitioned`` or ``_user_sampled`` to
    this value, so using the digest instead of the original table name keeps
    the final identifier length bounded and avoids it being truncated.

    MD5 is used here purely to derive a compact name, not for security.

    Returns:
        str: hexadecimal MD5 digest of the dataset table name.
    """
    # NOTE(review): on FIPS-restricted Python builds plain ``hashlib.md5`` may
    # be rejected; ``usedforsecurity=False`` (3.9+) would be the fix - confirm
    # the minimum supported Python version before changing.
    table_name_bytes = self.raw_dataset.__tablename__.encode(UTF_8)
    return hashlib.md5(table_name_bytes).hexdigest()
def get_sample_query(self, *, column=None) -> Query: def get_sample_query(self, *, column=None) -> Query:
"""get query for sample data""" """get query for sample data"""
if self.sample_config.profileSampleType == ProfileSampleType.PERCENTAGE: if self.sample_config.profileSampleType == ProfileSampleType.PERCENTAGE:
rnd = self._base_sample_query( rnd = self._base_sample_query(
column, column,
(ModuloFn(RandomNumFn(), 100)).label(RANDOM_LABEL), (ModuloFn(RandomNumFn(), 100)).label(RANDOM_LABEL),
).cte(f"{self.raw_dataset.__tablename__}_rnd") ).cte(f"{self.get_sampler_table_name()}_rnd")
session_query = self.client.query(rnd) session_query = self.client.query(rnd)
return session_query.where( return session_query.where(
rnd.c.random <= self.sample_config.profileSample rnd.c.random <= self.sample_config.profileSample
).cte(f"{self.raw_dataset.__tablename__}_sample") ).cte(f"{self.get_sampler_table_name()}_sample")
table_query = self.client.query(self.raw_dataset) table_query = self.client.query(self.raw_dataset)
session_query = self._base_sample_query( session_query = self._base_sample_query(
@ -129,7 +142,7 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin):
return ( return (
session_query.order_by(RANDOM_LABEL) session_query.order_by(RANDOM_LABEL)
.limit(self.sample_config.profileSample) .limit(self.sample_config.profileSample)
.cte(f"{self.raw_dataset.__tablename__}_rnd") .cte(f"{self.get_sampler_table_name()}_rnd")
) )
def get_dataset(self, column=None, **__) -> Union[DeclarativeMeta, AliasedClass]: def get_dataset(self, column=None, **__) -> Union[DeclarativeMeta, AliasedClass]:
@ -143,7 +156,7 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin):
if not self.sample_config.profileSample: if not self.sample_config.profileSample:
if self.partition_details: if self.partition_details:
partitioned = self._partitioned_table() partitioned = self._partitioned_table()
return partitioned.cte(f"{self.raw_dataset.__tablename__}_partitioned") return partitioned.cte(f"{self.get_sampler_table_name()}_partitioned")
return self.raw_dataset return self.raw_dataset
@ -224,7 +237,7 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin):
stmt = stmt.columns(*list(inspect(self.raw_dataset).c)) stmt = stmt.columns(*list(inspect(self.raw_dataset).c))
return self.client.query(stmt.subquery()).cte( return self.client.query(stmt.subquery()).cte(
f"{self.raw_dataset.__tablename__}_user_sampled" f"{self.get_sampler_table_name()}_user_sampled"
) )
def _partitioned_table(self) -> Query: def _partitioned_table(self) -> Query:

View File

@ -162,7 +162,7 @@ PROFILER_INGESTION_CONFIG_TEMPLATE = dedent(
"serviceConnection": {{ "serviceConnection": {{
"config": {service_config} "config": {service_config}
}}, }},
"sourceConfig": {{"config": {{"type":"Profiler"}}}} "sourceConfig": {{"config": {{"type":"Profiler", "profileSample": 100}}}}
}}, }},
"processor": {{"type": "orm-profiler", "config": {{}}}}, "processor": {{"type": "orm-profiler", "config": {{}}}},
"sink": {{"type": "metadata-rest", "config": {{}}}}, "sink": {{"type": "metadata-rest", "config": {{}}}},

View File

@ -149,9 +149,9 @@ class SampleTest(TestCase):
) )
query: CTE = sampler.get_sample_query() query: CTE = sampler.get_sample_query()
expected_query = ( expected_query = (
"WITH users_rnd AS \n(SELECT users.id AS id, ABS(RANDOM()) * 100 %% 100 AS random \n" 'WITH "9bc65c2abec141778ffaa729489f3e87_rnd" AS \n(SELECT users.id AS id, ABS(RANDOM()) * 100 %% 100 AS random \n'
"FROM users)\n SELECT users_rnd.id, users_rnd.random \n" 'FROM users)\n SELECT "9bc65c2abec141778ffaa729489f3e87_rnd".id, "9bc65c2abec141778ffaa729489f3e87_rnd".random \n'
"FROM users_rnd \nWHERE users_rnd.random <= 50.0" 'FROM "9bc65c2abec141778ffaa729489f3e87_rnd" \nWHERE "9bc65c2abec141778ffaa729489f3e87_rnd".random <= 50.0'
) )
assert ( assert (
expected_query.casefold() expected_query.casefold()
@ -191,9 +191,9 @@ class SampleTest(TestCase):
) )
query: CTE = sampler.get_sample_query() query: CTE = sampler.get_sample_query()
expected_query = ( expected_query = (
"WITH users_rnd AS \n(SELECT users.id AS id, ABS(RANDOM()) * 100 %% 100 AS random \n" 'WITH "9bc65c2abec141778ffaa729489f3e87_rnd" AS \n(SELECT users.id AS id, ABS(RANDOM()) * 100 %% 100 AS random \n'
"FROM users \nWHERE id in ('1', '2'))\n SELECT users_rnd.id, users_rnd.random \n" "FROM users \nWHERE id in ('1', '2'))\n SELECT \"9bc65c2abec141778ffaa729489f3e87_rnd\".id, \"9bc65c2abec141778ffaa729489f3e87_rnd\".random \n"
"FROM users_rnd \nWHERE users_rnd.random <= 50.0" 'FROM "9bc65c2abec141778ffaa729489f3e87_rnd" \nWHERE "9bc65c2abec141778ffaa729489f3e87_rnd".random <= 50.0'
) )
assert ( assert (
expected_query.casefold() expected_query.casefold()

View File

@ -18,6 +18,15 @@ class User(Base):
age = Column(Integer) age = Column(Integer)
# Test model with the same column names that load_user_data populates for User
# (name, fullname, nickname, age), but stored under a table name that is
# exactly 63 characters long.
class UserWithLongName(Base):
# "u" repeated 63 times: the longest table name Postgres accepts, so any suffix
# appended to it would push an identifier past that limit.
__tablename__ = "u" * 63 # Keep a max length name of 63 chars (max for Postgres)
# Integer primary key.
id = Column(Integer, primary_key=True)
# Free-text columns, each capped at 256 characters.
name = Column(String(256))
fullname = Column(String(256))
nickname = Column(String(256))
# Plain integer column.
age = Column(Integer)
class SQATestUtils: class SQATestUtils:
def __init__(self, connection_url: str): def __init__(self, connection_url: str):
self.connection_url = connection_url self.connection_url = connection_url
@ -34,14 +43,16 @@ class SQATestUtils:
self.session.commit() self.session.commit()
def load_user_data(self): def load_user_data(self):
for clz in (User, UserWithLongName):
data = [ data = [
User(name="John", fullname="John Doe", nickname="johnny b goode", age=30), # type: ignore clz(name="John", fullname="John Doe", nickname="johnny b goode", age=30), # type: ignore
User(name="Jane", fullname="Jone Doe", nickname=None, age=31), # type: ignore clz(name="Jane", fullname="Jone Doe", nickname=None, age=31), # type: ignore
] * 20 ] * 20
self.load_data(data) self.load_data(data)
def create_user_table(self): def create_user_table(self):
User.__table__.create(bind=self.session.get_bind()) User.__table__.create(bind=self.session.get_bind())
UserWithLongName.__table__.create(bind=self.session.get_bind())
def close(self): def close(self):
self.session.close() self.session.close()