feat(ingest/datahub): use stream_results with mysql (#12278)

This commit is contained in:
Harshal Sheth 2025-01-06 18:29:51 -05:00 committed by GitHub
parent 30a77c022a
commit a06a229499
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 15 additions and 19 deletions

View File

@@ -461,7 +461,7 @@ plugins: Dict[str, Set[str]] = {
"mssql-odbc": sql_common | mssql_common | {"pyodbc"},
"mysql": mysql,
# mariadb should have same dependency as mysql
"mariadb": sql_common | {"pymysql>=1.0.2"},
"mariadb": sql_common | mysql,
"okta": {"okta~=1.7.0", "nest-asyncio"},
"oracle": sql_common | {"oracledb"},
"postgres": sql_common | postgres_common,

View File

@@ -1,6 +1,7 @@
import os
from typing import Optional, Set
import pydantic
from pydantic import Field, root_validator
from datahub.configuration.common import AllowDenyPattern
@@ -119,3 +120,12 @@ class DataHubSourceConfig(StatefulIngestionConfigBase):
" Please specify at least one of `database_connection` or `kafka_connection`, ideally both."
)
return values
@pydantic.validator("database_connection")
def validate_mysql_scheme(
cls, v: SQLAlchemyConnectionConfig
) -> SQLAlchemyConnectionConfig:
if "mysql" in v.scheme:
if v.scheme != "mysql+pymysql":
raise ValueError("For MySQL, the scheme must be mysql+pymysql.")
return v

View File

@@ -151,8 +151,10 @@ class DataHubDatabaseReader:
self, query: str, params: Dict[str, Any]
) -> Iterable[Dict[str, Any]]:
with self.engine.connect() as conn:
if self.engine.dialect.name == "postgresql":
if self.engine.dialect.name in ["postgresql", "mysql", "mariadb"]:
with conn.begin(): # Transaction required for PostgreSQL server-side cursor
# Note that stream_results=True is mainly supported by PostgreSQL and MySQL-based dialects.
# https://docs.sqlalchemy.org/en/14/core/connections.html#sqlalchemy.engine.Connection.execution_options.params.stream_results
conn = conn.execution_options(
stream_results=True,
yield_per=self.config.database_query_batch_size,
@@ -160,22 +162,6 @@ class DataHubDatabaseReader:
result = conn.execute(query, params)
for row in result:
yield dict(row)
elif self.engine.dialect.name == "mysql": # MySQL
import MySQLdb
with contextlib.closing(
conn.connection.cursor(MySQLdb.cursors.SSCursor)
) as cursor:
logger.debug(f"Using Cursor type: {cursor.__class__.__name__}")
cursor.execute(query, params)
columns = [desc[0] for desc in cursor.description]
while True:
rows = cursor.fetchmany(self.config.database_query_batch_size)
if not rows:
break # Use break instead of return in generator
for row in rows:
yield dict(zip(columns, row))
else:
raise ValueError(f"Unsupported dialect: {self.engine.dialect.name}")

View File

@@ -130,7 +130,7 @@ class DataHubSource(StatefulIngestionSourceBase):
self._commit_progress(i)
def _get_kafka_workunits(
self, from_offsets: Dict[int, int], soft_deleted_urns: List[str] = []
self, from_offsets: Dict[int, int], soft_deleted_urns: List[str]
) -> Iterable[MetadataWorkUnit]:
if self.config.kafka_connection is None:
return