diff --git a/metadata-ingestion/src/datahub/utilities/file_backed_collections.py b/metadata-ingestion/src/datahub/utilities/file_backed_collections.py index 462595e4d1..fc00125661 100644 --- a/metadata-ingestion/src/datahub/utilities/file_backed_collections.py +++ b/metadata-ingestion/src/datahub/utilities/file_backed_collections.py @@ -64,6 +64,7 @@ class ConnectionWrapper: filename = pathlib.Path(self._directory.name) / _DEFAULT_FILE_NAME self.conn = sqlite3.connect(filename, isolation_level=None) + self.conn.row_factory = sqlite3.Row self.filename = filename # These settings are optimized for performance. @@ -314,7 +315,7 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT], Closeable): query: str, params: Tuple[Any, ...] = (), refs: Optional[List[Union["FileBackedList", "FileBackedDict"]]] = None, - ) -> List[Tuple[Any, ...]]: + ) -> List[sqlite3.Row]: return self._sql_query(query, params, refs).fetchall() def sql_query_iterator( @@ -322,7 +323,7 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT], Closeable): query: str, params: Tuple[Any, ...] = (), refs: Optional[List[Union["FileBackedList", "FileBackedDict"]]] = None, - ) -> Iterator[Tuple[Any, ...]]: + ) -> Iterator[sqlite3.Row]: return self._sql_query(query, params, refs) def _sql_query( @@ -422,7 +423,7 @@ class FileBackedList(Generic[_VT]): query: str, params: Tuple[Any, ...] = (), refs: Optional[List[Union["FileBackedList", "FileBackedDict"]]] = None, - ) -> List[Tuple[Any, ...]]: + ) -> List[sqlite3.Row]: return self._dict.sql_query(query, params, refs=refs) def close(self) -> None: diff --git a/metadata-ingestion/tests/unit/utilities/test_file_backed_collections.py b/metadata-ingestion/tests/unit/utilities/test_file_backed_collections.py index fc58954f04..d771feae7b 100644 --- a/metadata-ingestion/tests/unit/utilities/test_file_backed_collections.py +++ b/metadata-ingestion/tests/unit/utilities/test_file_backed_collections.py @@ -257,21 +257,19 @@ def test_shared_connection() -> None: f"SELECT y, sum(x) FROM {cache2.tablename} GROUP BY y ORDER BY y" ) assert type(iterator) == sqlite3.Cursor - assert list(iterator) == [("a", 15), ("b", 11)] + assert [tuple(r) for r in iterator] == [("a", 15), ("b", 11)] # Test joining between the two tables. - assert ( - cache2.sql_query( - f""" - SELECT cache2.y, sum(cache2.x * cache1.v) FROM {cache2.tablename} cache2 - LEFT JOIN {cache1.tablename} cache1 ON cache1.key = cache2.y - GROUP BY cache2.y - ORDER BY cache2.y - """, - refs=[cache1], - ) - == [("a", 45), ("b", 55)] + rows = cache2.sql_query( + f""" + SELECT cache2.y, sum(cache2.x * cache1.v) FROM {cache2.tablename} cache2 + LEFT JOIN {cache1.tablename} cache1 ON cache1.key = cache2.y + GROUP BY cache2.y + ORDER BY cache2.y + """, + refs=[cache1], ) + assert [tuple(row) for row in rows] == [("a", 45), ("b", 55)] assert list(cache2.items_snapshot('y = "a"')) == [ ("ref-a-1", Pair(7, "a")),