From c7d35ffd6609d0ae79a2b1151a2221086ed4d8c5 Mon Sep 17 00:00:00 2001
From: Andrew Sikowitz
Date: Mon, 27 Mar 2023 17:20:34 -0400
Subject: [PATCH] perf(ingest): Improve FileBackedDict iteration performance;
 minor refactoring (#7689)

- Adds dirty bit to cache, only writes data if dirty
- Refactors __iter__
- Adds sql_query_iterator
- Adds items_snapshot, more performant `items()` that allows for filtering
- Renames connection -> shared_connection
- Removes unnecessary flush during close if connection is not shared
- Adds Closeable mixin
---
 .../utilities/file_backed_collections.py      | 107 +++++++++++++-----
 .../utilities/test_file_backed_collections.py |  55 +++++++--
 2 files changed, 121 insertions(+), 41 deletions(-)

diff --git a/metadata-ingestion/src/datahub/utilities/file_backed_collections.py b/metadata-ingestion/src/datahub/utilities/file_backed_collections.py
index 89fb7d10df..462595e4d1 100644
--- a/metadata-ingestion/src/datahub/utilities/file_backed_collections.py
+++ b/metadata-ingestion/src/datahub/utilities/file_backed_collections.py
@@ -5,6 +5,7 @@ import pickle
 import sqlite3
 import tempfile
 from dataclasses import dataclass, field
+from datetime import datetime
 from types import TracebackType
 from typing import (
     Any,
@@ -23,6 +24,8 @@ from typing import (
     Union,
 )
 
+from datahub.ingestion.api.closeable import Closeable
+
 logger: logging.Logger = logging.getLogger(__name__)
 
 _DEFAULT_FILE_NAME = "sqlite.db"
@@ -31,7 +34,8 @@ _DEFAULT_MEMORY_CACHE_MAX_SIZE = 2000
 _DEFAULT_MEMORY_CACHE_EVICTION_BATCH_SIZE = 200
 
 # https://docs.python.org/3/library/sqlite3.html#sqlite-and-python-types
-SqliteValue = Union[int, float, str, bytes, None]
+# Datetimes get converted to strings
+SqliteValue = Union[int, float, str, bytes, datetime, None]
 
 _VT = TypeVar("_VT")
 
@@ -130,7 +134,7 @@ def _default_deserializer(value: Any) -> Any:
 
 
 @dataclass(eq=False)
-class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
+class FileBackedDict(MutableMapping[str, _VT], Generic[_VT], Closeable):
     """
     A dict-like object that stores its data in a temporary SQLite database.
     This is useful for storing large amounts of data that don't fit in memory.
@@ -139,7 +143,7 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
     """
 
     # Use a predefined connection, able to be shared across multiple FileBacked* objects
-    connection: Optional[ConnectionWrapper] = None
+    shared_connection: Optional[ConnectionWrapper] = None
     tablename: str = _DEFAULT_TABLE_NAME
 
     serializer: Callable[[_VT], SqliteValue] = _default_serializer
@@ -152,7 +156,10 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
     _conn: ConnectionWrapper = field(init=False, repr=False)
 
     # To improve performance, we maintain an in-memory LRU cache using an OrderedDict.
-    _active_object_cache: OrderedDict[str, _VT] = field(init=False, repr=False)
+    # Maintains a dirty bit marking whether the value has been modified since it was persisted.
+    _active_object_cache: OrderedDict[str, Tuple[_VT, bool]] = field(
+        init=False, repr=False
+    )
 
     def __post_init__(self) -> None:
         assert (
@@ -162,8 +169,8 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
         assert "key" not in self.extra_columns, '"key" is a reserved column name'
         assert "value" not in self.extra_columns, '"value" is a reserved column name'
 
-        if self.connection:
-            self._conn = self.connection
+        if self.shared_connection:
+            self._conn = self.shared_connection
         else:
             self._conn = ConnectionWrapper()
 
@@ -189,8 +196,8 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
                 f"CREATE INDEX {self.tablename}_{column_name} ON {self.tablename} ({column_name})"
             )
 
-    def _add_to_cache(self, key: str, value: _VT) -> None:
-        self._active_object_cache[key] = value
+    def _add_to_cache(self, key: str, value: _VT, dirty: bool) -> None:
+        self._active_object_cache[key] = value, dirty
 
         if len(self._active_object_cache) > self.cache_max_size:
             # Try to prune in batches rather than one at a time.
@@ -202,12 +209,12 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
     def _prune_cache(self, num_items_to_prune: int) -> None:
         items_to_write: List[Tuple[SqliteValue, ...]] = []
         for _ in range(num_items_to_prune):
-            key, value = self._active_object_cache.popitem(last=False)
-
-            values = [key, self.serializer(value)]
-            for column_serializer in self.extra_columns.values():
-                values.append(column_serializer(value))
-            items_to_write.append(tuple(values))
+            key, (value, dirty) = self._active_object_cache.popitem(last=False)
+            if dirty:
+                values = [key, self.serializer(value)]
+                for column_serializer in self.extra_columns.values():
+                    values.append(column_serializer(value))
+                items_to_write.append(tuple(values))
 
         if items_to_write:
             self._conn.executemany(
@@ -226,7 +233,7 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
     def __getitem__(self, key: str) -> _VT:
         if key in self._active_object_cache:
             self._active_object_cache.move_to_end(key)
-            return self._active_object_cache[key]
+            return self._active_object_cache[key][0]
 
         cursor = self._conn.execute(
             f"SELECT value FROM {self.tablename} WHERE key = ?", (key,)
@@ -236,11 +243,11 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
             raise KeyError(key)
 
         deserialized_result = self.deserializer(result[0])
-        self._add_to_cache(key, deserialized_result)
+        self._add_to_cache(key, deserialized_result, False)
         return deserialized_result
 
     def __setitem__(self, key: str, value: _VT) -> None:
-        self._add_to_cache(key, value)
+        self._add_to_cache(key, value, True)
 
     def __delitem__(self, key: str) -> None:
         in_cache = False
@@ -254,17 +261,43 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
         if not in_cache and not n_deleted:
             raise KeyError(key)
 
+    def mark_dirty(self, key: str) -> None:
+        if key in self._active_object_cache and not self._active_object_cache[key][1]:
+            self._active_object_cache[key] = self._active_object_cache[key][0], True
+
     def __iter__(self) -> Iterator[str]:
+        # Cache should be small, so safe list cast to avoid mutation during iteration
+        cache_keys = set(self._active_object_cache.keys())
+        yield from cache_keys
+
         cursor = self._conn.execute(f"SELECT key FROM {self.tablename}")
         for row in cursor:
-            if row[0] in self._active_object_cache:
-                # If the key is in the active object cache, then SQL isn't the source of truth.
-                continue
+            if row[0] not in cache_keys:
+                yield row[0]
 
-            yield row[0]
+    def items_snapshot(
+        self, cond_sql: Optional[str] = None
+    ) -> Iterator[Tuple[str, _VT]]:
+        """
+        Return a fixed snapshot, rather than a view, of the dictionary's items.
 
-        for key in self._active_object_cache:
-            yield key
+        Flushes the cache and provides the option to filter the results.
+        Provides better performance over standard `items()` method.
+
+        Args:
+            cond_sql: Conditional expression for WHERE statement, e.g. `x = 0 AND y = "value"`
+
+        Returns:
+            Iterator of filtered (key, value) pairs.
+        """
+        self.flush()
+        sql = f"SELECT key, value FROM {self.tablename}"
+        if cond_sql:
+            sql += f" WHERE {cond_sql}"
+
+        cursor = self._conn.execute(sql)
+        for row in cursor:
+            yield row[0], self.deserializer(row[1])
 
     def __len__(self) -> int:
         cursor = self._conn.execute(
@@ -282,6 +315,22 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
         params: Tuple[Any, ...] = (),
         refs: Optional[List[Union["FileBackedList", "FileBackedDict"]]] = None,
     ) -> List[Tuple[Any, ...]]:
+        return self._sql_query(query, params, refs).fetchall()
+
+    def sql_query_iterator(
+        self,
+        query: str,
+        params: Tuple[Any, ...] = (),
+        refs: Optional[List[Union["FileBackedList", "FileBackedDict"]]] = None,
+    ) -> Iterator[Tuple[Any, ...]]:
+        return self._sql_query(query, params, refs)
+
+    def _sql_query(
+        self,
+        query: str,
+        params: Tuple[Any, ...] = (),
+        refs: Optional[List[Union["FileBackedList", "FileBackedDict"]]] = None,
+    ) -> sqlite3.Cursor:
         # We need to flush object and any objects the query references to ensure
         # that we don't miss objects that have been modified but not yet flushed.
         self.flush()
@@ -289,15 +338,13 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
             for referenced_table in refs:
                 referenced_table.flush()
 
-        cursor = self._conn.execute(query, params)
-        return cursor.fetchall()
+        return self._conn.execute(query, params)
 
     def close(self) -> None:
         if self._conn:
-            # Ensure everything is written out.
-            self.flush()
-
-            if not self.connection:  # Connection created inside this class
+            if self.shared_connection:  # Connection not owned by this object
+                self.flush()  # Ensure everything is written out
+            else:
                 self._conn.close()
 
     # This forces all writes to go directly to the DB so they fail immediately.
@@ -330,7 +377,7 @@ class FileBackedList(Generic[_VT]):
     ) -> None:
         self._len = 0
         self._dict = FileBackedDict(
-            connection=connection,
+            shared_connection=connection,
             serializer=serializer,
             deserializer=deserializer,
             tablename=tablename,
diff --git a/metadata-ingestion/tests/unit/utilities/test_file_backed_collections.py b/metadata-ingestion/tests/unit/utilities/test_file_backed_collections.py
index 826dc9e653..fc58954f04 100644
--- a/metadata-ingestion/tests/unit/utilities/test_file_backed_collections.py
+++ b/metadata-ingestion/tests/unit/utilities/test_file_backed_collections.py
@@ -1,6 +1,7 @@
 import dataclasses
 import json
 import pathlib
+import sqlite3
 from dataclasses import dataclass
 from typing import Counter, Dict
 
@@ -25,6 +26,8 @@ def test_file_dict() -> None:
 
     assert len(cache) == 100
     assert sorted(cache) == sorted([f"key-{i}" for i in range(100)])
+    assert sorted(cache.items()) == sorted([(f"key-{i}", i) for i in range(100)])
+    assert sorted(cache.values()) == sorted([i for i in range(100)])
 
     # Force eviction of everything.
     cache.flush()
 
 
@@ -140,7 +143,7 @@ def test_custom_serde() -> None:
 
     assert cache["second"] == second
     assert cache["first"] == first
-    assert serializer_calls == 4  # Items written to cache on every access
+    assert serializer_calls == 2
     assert deserializer_calls == 2
 
 
@@ -161,6 +164,7 @@ def test_file_dict_stores_counter() -> None:
             cache[str(i)][str(j)] += 100
             in_memory_counters[i][str(j)] += 100
             cache[str(i)][str(j)] += j
+            cache.mark_dirty(str(i))
             in_memory_counters[i][str(j)] += j
 
     for i in range(n):
@@ -174,43 +178,64 @@ class Pair:
     y: str
 
 
-def test_custom_column() -> None:
+@pytest.mark.parametrize("cache_max_size", [0, 1, 10])
+def test_custom_column(cache_max_size: int) -> None:
     cache = FileBackedDict[Pair](
         extra_columns={
             "x": lambda m: m.x,
         },
+        cache_max_size=cache_max_size,
     )
 
     cache["first"] = Pair(3, "a")
     cache["second"] = Pair(100, "b")
+    cache["third"] = Pair(27, "c")
+    cache["fourth"] = Pair(100, "d")
 
     # Verify that the extra column is present and has the correct values.
-    assert cache.sql_query(f"SELECT sum(x) FROM {cache.tablename}")[0][0] == 103
+    assert cache.sql_query(f"SELECT sum(x) FROM {cache.tablename}")[0][0] == 230
 
     # Verify that the extra column is updated when the value is updated.
-    cache["first"] = Pair(4, "c")
-    assert cache.sql_query(f"SELECT sum(x) FROM {cache.tablename}")[0][0] == 104
+    cache["first"] = Pair(4, "e")
+    assert cache.sql_query(f"SELECT sum(x) FROM {cache.tablename}")[0][0] == 231
 
     # Test param binding.
     assert (
         cache.sql_query(
            f"SELECT sum(x) FROM {cache.tablename} WHERE x < ?", params=(50,)
        )[0][0]
-        == 4
+        == 31
    )
 
+    assert sorted(list(cache.items())) == [
+        ("first", Pair(4, "e")),
+        ("fourth", Pair(100, "d")),
+        ("second", Pair(100, "b")),
+        ("third", Pair(27, "c")),
+    ]
+    assert sorted(list(cache.items_snapshot())) == [
+        ("first", Pair(4, "e")),
+        ("fourth", Pair(100, "d")),
+        ("second", Pair(100, "b")),
+        ("third", Pair(27, "c")),
+    ]
+    assert sorted(list(cache.items_snapshot("x < 50"))) == [
+        ("first", Pair(4, "e")),
+        ("third", Pair(27, "c")),
+    ]
+
 
 def test_shared_connection() -> None:
     with ConnectionWrapper() as connection:
         cache1 = FileBackedDict[int](
-            connection=connection,
+            shared_connection=connection,
             tablename="cache1",
             extra_columns={
                 "v": lambda v: v,
             },
         )
         cache2 = FileBackedDict[Pair](
-            connection=connection,
+            shared_connection=connection,
             tablename="cache2",
             extra_columns={
                 "x": lambda m: m.x,
@@ -227,10 +252,10 @@ def test_shared_connection() -> None:
         assert len(cache1) == 2
         assert len(cache2) == 3
 
-        # Test advanced SQL queries.
-        assert cache2.sql_query(
+        # Test advanced SQL queries and sql_query_iterator.
+        iterator = cache2.sql_query_iterator(
             f"SELECT y, sum(x) FROM {cache2.tablename} GROUP BY y ORDER BY y"
-        ) == [("a", 15), ("b", 11)]
+        )
+        assert type(iterator) == sqlite3.Cursor
+        assert list(iterator) == [("a", 15), ("b", 11)]
 
         # Test joining between the two tables.
         assert (
@@ -245,6 +272,12 @@ def test_shared_connection() -> None:
             == [("a", 45), ("b", 55)]
         )
+
+        assert list(cache2.items_snapshot('y = "a"')) == [
+            ("ref-a-1", Pair(7, "a")),
+            ("ref-a-2", Pair(8, "a")),
+        ]
+
         cache2.close()
 
         # Check can still use cache1
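
A minimal usage sketch of the APIs this patch adds (items_snapshot, sql_query_iterator, mark_dirty). The keys, values, and the "platform" extra column below are made up for illustration; only the FileBackedDict calls themselves come from the diff above.

# Sketch only; assumes the post-patch FileBackedDict API shown in the diff.
from datahub.utilities.file_backed_collections import FileBackedDict

cache = FileBackedDict[dict](
    extra_columns={"platform": lambda v: v["platform"]},  # hypothetical extra column
)
cache["urn-1"] = {"platform": "snowflake", "count": 1}
cache["urn-2"] = {"platform": "bigquery", "count": 5}

# items_snapshot() flushes the in-memory cache, then streams (key, value) pairs
# straight from SQLite; cond_sql becomes the WHERE clause.
for key, value in cache.items_snapshot('platform = "snowflake"'):
    print(key, value)

# sql_query_iterator() returns the sqlite3 cursor instead of fetchall()-ing a list.
for row in cache.sql_query_iterator(f"SELECT key FROM {cache.tablename}"):
    print(row[0])

# In-place mutations are not detected automatically; mark_dirty() sets the dirty
# bit so the change is written back when the entry is evicted or flushed.
cache["urn-1"]["count"] += 1
cache.mark_dirty("urn-1")

cache.close()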