FIX #19798 - Shortening SQA __tablename__ to avoid hitting errors in … (#19809)

* FIX #19798 - Shortening SQA __tablename__ to avoid hitting errors in postgres

* fix tests

---------

Co-authored-by: Sriharsha Chintalapani <harshach@users.noreply.github.com>
This commit is contained in:
Pere Miquel Brull 2025-02-17 09:37:06 +01:00 committed by GitHub
parent 3ae2ff1c1c
commit 91b62fdc32
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 41 additions and 17 deletions

View File

@ -12,6 +12,7 @@
Helper module to handle data sampling
for the profiler
"""
import hashlib
import traceback
from typing import List, Optional, Union, cast
@ -32,6 +33,7 @@ from metadata.profiler.orm.functions.modulo import ModuloFn
from metadata.profiler.orm.functions.random_num import RandomNumFn
from metadata.profiler.processor.handle_partition import build_partition_predicate
from metadata.sampler.sampler_interface import SamplerInterface
from metadata.utils.constants import UTF_8
from metadata.utils.helpers import is_safe_sql_query
from metadata.utils.logger import profiler_interface_registry_logger
@ -109,17 +111,28 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin):
query = self.get_partitioned_query(query)
return query
def get_sampler_table_name(self) -> str:
"""Get the base name of the SQA table for sampling.
We use MD5 as a hashing algorithm to generate a unique name for the table
keeping its length controlled. Otherwise, we ended up having issues
with names getting truncated when we add the suffixes to the identifiers
such as _sample, or _rnd.
"""
encoded_name = self.raw_dataset.__tablename__.encode(UTF_8)
hash_object = hashlib.md5(encoded_name)
return hash_object.hexdigest()
def get_sample_query(self, *, column=None) -> Query:
"""get query for sample data"""
if self.sample_config.profileSampleType == ProfileSampleType.PERCENTAGE:
rnd = self._base_sample_query(
column,
(ModuloFn(RandomNumFn(), 100)).label(RANDOM_LABEL),
).cte(f"{self.raw_dataset.__tablename__}_rnd")
).cte(f"{self.get_sampler_table_name()}_rnd")
session_query = self.client.query(rnd)
return session_query.where(
rnd.c.random <= self.sample_config.profileSample
).cte(f"{self.raw_dataset.__tablename__}_sample")
).cte(f"{self.get_sampler_table_name()}_sample")
table_query = self.client.query(self.raw_dataset)
session_query = self._base_sample_query(
@ -129,7 +142,7 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin):
return (
session_query.order_by(RANDOM_LABEL)
.limit(self.sample_config.profileSample)
.cte(f"{self.raw_dataset.__tablename__}_rnd")
.cte(f"{self.get_sampler_table_name()}_rnd")
)
def get_dataset(self, column=None, **__) -> Union[DeclarativeMeta, AliasedClass]:
@ -143,7 +156,7 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin):
if not self.sample_config.profileSample:
if self.partition_details:
partitioned = self._partitioned_table()
return partitioned.cte(f"{self.raw_dataset.__tablename__}_partitioned")
return partitioned.cte(f"{self.get_sampler_table_name()}_partitioned")
return self.raw_dataset
@ -224,7 +237,7 @@ class SQASampler(SamplerInterface, SQAInterfaceMixin):
stmt = stmt.columns(*list(inspect(self.raw_dataset).c))
return self.client.query(stmt.subquery()).cte(
f"{self.raw_dataset.__tablename__}_user_sampled"
f"{self.get_sampler_table_name()}_user_sampled"
)
def _partitioned_table(self) -> Query:

View File

@ -162,7 +162,7 @@ PROFILER_INGESTION_CONFIG_TEMPLATE = dedent(
"serviceConnection": {{
"config": {service_config}
}},
"sourceConfig": {{"config": {{"type":"Profiler"}}}}
"sourceConfig": {{"config": {{"type":"Profiler", "profileSample": 100}}}}
}},
"processor": {{"type": "orm-profiler", "config": {{}}}},
"sink": {{"type": "metadata-rest", "config": {{}}}},

View File

@ -149,9 +149,9 @@ class SampleTest(TestCase):
)
query: CTE = sampler.get_sample_query()
expected_query = (
"WITH users_rnd AS \n(SELECT users.id AS id, ABS(RANDOM()) * 100 %% 100 AS random \n"
"FROM users)\n SELECT users_rnd.id, users_rnd.random \n"
"FROM users_rnd \nWHERE users_rnd.random <= 50.0"
'WITH "9bc65c2abec141778ffaa729489f3e87_rnd" AS \n(SELECT users.id AS id, ABS(RANDOM()) * 100 %% 100 AS random \n'
'FROM users)\n SELECT "9bc65c2abec141778ffaa729489f3e87_rnd".id, "9bc65c2abec141778ffaa729489f3e87_rnd".random \n'
'FROM "9bc65c2abec141778ffaa729489f3e87_rnd" \nWHERE "9bc65c2abec141778ffaa729489f3e87_rnd".random <= 50.0'
)
assert (
expected_query.casefold()
@ -191,9 +191,9 @@ class SampleTest(TestCase):
)
query: CTE = sampler.get_sample_query()
expected_query = (
"WITH users_rnd AS \n(SELECT users.id AS id, ABS(RANDOM()) * 100 %% 100 AS random \n"
"FROM users \nWHERE id in ('1', '2'))\n SELECT users_rnd.id, users_rnd.random \n"
"FROM users_rnd \nWHERE users_rnd.random <= 50.0"
'WITH "9bc65c2abec141778ffaa729489f3e87_rnd" AS \n(SELECT users.id AS id, ABS(RANDOM()) * 100 %% 100 AS random \n'
"FROM users \nWHERE id in ('1', '2'))\n SELECT \"9bc65c2abec141778ffaa729489f3e87_rnd\".id, \"9bc65c2abec141778ffaa729489f3e87_rnd\".random \n"
'FROM "9bc65c2abec141778ffaa729489f3e87_rnd" \nWHERE "9bc65c2abec141778ffaa729489f3e87_rnd".random <= 50.0'
)
assert (
expected_query.casefold()

View File

@ -18,6 +18,15 @@ class User(Base):
age = Column(Integer)
class UserWithLongName(Base):
__tablename__ = "u" * 63 # Keep a max length name of 63 chars (max for Postgres)
id = Column(Integer, primary_key=True)
name = Column(String(256))
fullname = Column(String(256))
nickname = Column(String(256))
age = Column(Integer)
class SQATestUtils:
def __init__(self, connection_url: str):
self.connection_url = connection_url
@ -34,14 +43,16 @@ class SQATestUtils:
self.session.commit()
def load_user_data(self):
for clz in (User, UserWithLongName):
data = [
User(name="John", fullname="John Doe", nickname="johnny b goode", age=30), # type: ignore
User(name="Jane", fullname="Jone Doe", nickname=None, age=31), # type: ignore
clz(name="John", fullname="John Doe", nickname="johnny b goode", age=30), # type: ignore
clz(name="Jane", fullname="Jone Doe", nickname=None, age=31), # type: ignore
] * 20
self.load_data(data)
def create_user_table(self):
User.__table__.create(bind=self.session.get_bind())
UserWithLongName.__table__.create(bind=self.session.get_bind())
def close(self):
self.session.close()