perf(ingest): Improve FileBackedDict iteration performance; minor refactoring (#7689)

- Adds a dirty bit to the cache; entries are only written back to SQLite when dirty
- Refactors `__iter__`
- Adds `sql_query_iterator`
- Adds `items_snapshot`, a more performant alternative to `items()` that supports filtering (see the usage sketch below)
- Renames `connection` -> `shared_connection`
- Removes an unnecessary flush during `close()` when the connection is not shared
- Adds `Closeable` mixin
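
Roughly how the new surface fits together. A minimal usage sketch, not taken from the commit itself: the import path and the setup values are assumptions for illustration; `shared_connection`, `items_snapshot`, `sql_query_iterator`, and `close` behave as in the diff below.

from datahub.utilities.file_backed_collections import (  # assumed import path
    ConnectionWrapper,
    FileBackedDict,
)

with ConnectionWrapper() as connection:
    # shared_connection (renamed from `connection`) lets several FileBacked*
    # objects share a single SQLite database.
    cache = FileBackedDict[int](
        shared_connection=connection,
        tablename="demo",
        extra_columns={"x": lambda v: v},
    )
    for i in range(1000):
        cache[f"key-{i}"] = i  # lands in the LRU cache with dirty=True

    # items_snapshot() flushes once, then streams rows straight from SQLite;
    # the optional argument is a raw SQL filter over the extra columns.
    small = dict(cache.items_snapshot("x < 10"))

    # sql_query_iterator() hands back the cursor instead of fetchall()'ing
    # every row into memory at once.
    for x, count in cache.sql_query_iterator(
        f"SELECT x, count(*) FROM {cache.tablename} GROUP BY x"
    ):
        pass

    cache.close()  # connection is shared, so close() only flushes it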
Andrew Sikowitz 2023-03-27 17:20:34 -04:00 committed by GitHub
parent 04f1b86d54
commit c7d35ffd66
2 changed files with 121 additions and 41 deletions

View File

@@ -5,6 +5,7 @@ import pickle
 import sqlite3
 import tempfile
 from dataclasses import dataclass, field
+from datetime import datetime
 from types import TracebackType
 from typing import (
     Any,
@@ -23,6 +24,8 @@ from typing import (
     Union,
 )
 
+from datahub.ingestion.api.closeable import Closeable
+
 logger: logging.Logger = logging.getLogger(__name__)
 
 _DEFAULT_FILE_NAME = "sqlite.db"
@@ -31,7 +34,8 @@ _DEFAULT_MEMORY_CACHE_MAX_SIZE = 2000
 _DEFAULT_MEMORY_CACHE_EVICTION_BATCH_SIZE = 200
 
 # https://docs.python.org/3/library/sqlite3.html#sqlite-and-python-types
-SqliteValue = Union[int, float, str, bytes, None]
+# Datetimes get converted to strings
+SqliteValue = Union[int, float, str, bytes, datetime, None]
 
 _VT = TypeVar("_VT")
@@ -130,7 +134,7 @@ def _default_deserializer(value: Any) -> Any:
 
 @dataclass(eq=False)
-class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
+class FileBackedDict(MutableMapping[str, _VT], Generic[_VT], Closeable):
     """
     A dict-like object that stores its data in a temporary SQLite database.
     This is useful for storing large amounts of data that don't fit in memory.
@@ -139,7 +143,7 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
     """
 
     # Use a predefined connection, able to be shared across multiple FileBacked* objects
-    connection: Optional[ConnectionWrapper] = None
+    shared_connection: Optional[ConnectionWrapper] = None
     tablename: str = _DEFAULT_TABLE_NAME
 
     serializer: Callable[[_VT], SqliteValue] = _default_serializer
@@ -152,7 +156,10 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
     _conn: ConnectionWrapper = field(init=False, repr=False)
 
     # To improve performance, we maintain an in-memory LRU cache using an OrderedDict.
-    _active_object_cache: OrderedDict[str, _VT] = field(init=False, repr=False)
+    # Maintains a dirty bit marking whether the value has been modified since it was persisted.
+    _active_object_cache: OrderedDict[str, Tuple[_VT, bool]] = field(
+        init=False, repr=False
+    )
 
     def __post_init__(self) -> None:
         assert (
@@ -162,8 +169,8 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
         assert "key" not in self.extra_columns, '"key" is a reserved column name'
         assert "value" not in self.extra_columns, '"value" is a reserved column name'
 
-        if self.connection:
-            self._conn = self.connection
+        if self.shared_connection:
+            self._conn = self.shared_connection
         else:
             self._conn = ConnectionWrapper()
@@ -189,8 +196,8 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
                 f"CREATE INDEX {self.tablename}_{column_name} ON {self.tablename} ({column_name})"
             )
 
-    def _add_to_cache(self, key: str, value: _VT) -> None:
-        self._active_object_cache[key] = value
+    def _add_to_cache(self, key: str, value: _VT, dirty: bool) -> None:
+        self._active_object_cache[key] = value, dirty
 
         if len(self._active_object_cache) > self.cache_max_size:
             # Try to prune in batches rather than one at a time.
@@ -202,8 +209,8 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
     def _prune_cache(self, num_items_to_prune: int) -> None:
         items_to_write: List[Tuple[SqliteValue, ...]] = []
         for _ in range(num_items_to_prune):
-            key, value = self._active_object_cache.popitem(last=False)
-            values = [key, self.serializer(value)]
-            for column_serializer in self.extra_columns.values():
-                values.append(column_serializer(value))
+            key, (value, dirty) = self._active_object_cache.popitem(last=False)
+            if dirty:
+                values = [key, self.serializer(value)]
+                for column_serializer in self.extra_columns.values():
+                    values.append(column_serializer(value))
@@ -226,7 +233,7 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
     def __getitem__(self, key: str) -> _VT:
         if key in self._active_object_cache:
             self._active_object_cache.move_to_end(key)
-            return self._active_object_cache[key]
+            return self._active_object_cache[key][0]
 
         cursor = self._conn.execute(
             f"SELECT value FROM {self.tablename} WHERE key = ?", (key,)
@@ -236,11 +243,11 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
             raise KeyError(key)
 
         deserialized_result = self.deserializer(result[0])
-        self._add_to_cache(key, deserialized_result)
+        self._add_to_cache(key, deserialized_result, False)
         return deserialized_result
 
     def __setitem__(self, key: str, value: _VT) -> None:
-        self._add_to_cache(key, value)
+        self._add_to_cache(key, value, True)
 
     def __delitem__(self, key: str) -> None:
         in_cache = False
@@ -254,17 +261,43 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
         if not in_cache and not n_deleted:
             raise KeyError(key)
 
+    def mark_dirty(self, key: str) -> None:
+        if key in self._active_object_cache and not self._active_object_cache[key][1]:
+            self._active_object_cache[key] = self._active_object_cache[key][0], True
+
     def __iter__(self) -> Iterator[str]:
+        # Cache should be small, so safe list cast to avoid mutation during iteration
+        cache_keys = set(self._active_object_cache.keys())
+        yield from cache_keys
+
         cursor = self._conn.execute(f"SELECT key FROM {self.tablename}")
         for row in cursor:
-            if row[0] in self._active_object_cache:
-                # If the key is in the active object cache, then SQL isn't the source of truth.
-                continue
-
-            yield row[0]
-
-        for key in self._active_object_cache:
-            yield key
+            if row[0] not in cache_keys:
+                yield row[0]
+
+    def items_snapshot(
+        self, cond_sql: Optional[str] = None
+    ) -> Iterator[Tuple[str, _VT]]:
+        """
+        Return a fixed snapshot, rather than a view, of the dictionary's items.
+
+        Flushes the cache and provides the option to filter the results.
+        Provides better performance over standard `items()` method.
+
+        Args:
+            cond_sql: Conditional expression for WHERE statement, e.g. `x = 0 AND y = "value"`
+
+        Returns:
+            Iterator of filtered (key, value) pairs.
+        """
+        self.flush()
+        sql = f"SELECT key, value FROM {self.tablename}"
+        if cond_sql:
+            sql += f" WHERE {cond_sql}"
+
+        cursor = self._conn.execute(sql)
+        for row in cursor:
+            yield row[0], self.deserializer(row[1])
+
     def __len__(self) -> int:
         cursor = self._conn.execute(
@@ -282,6 +315,22 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
         params: Tuple[Any, ...] = (),
         refs: Optional[List[Union["FileBackedList", "FileBackedDict"]]] = None,
     ) -> List[Tuple[Any, ...]]:
+        return self._sql_query(query, params, refs).fetchall()
+
+    def sql_query_iterator(
+        self,
+        query: str,
+        params: Tuple[Any, ...] = (),
+        refs: Optional[List[Union["FileBackedList", "FileBackedDict"]]] = None,
+    ) -> Iterator[Tuple[Any, ...]]:
+        return self._sql_query(query, params, refs)
+
+    def _sql_query(
+        self,
+        query: str,
+        params: Tuple[Any, ...] = (),
+        refs: Optional[List[Union["FileBackedList", "FileBackedDict"]]] = None,
+    ) -> sqlite3.Cursor:
         # We need to flush object and any objects the query references to ensure
         # that we don't miss objects that have been modified but not yet flushed.
         self.flush()
@@ -289,15 +338,13 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
         for referenced_table in refs:
             referenced_table.flush()
 
-        cursor = self._conn.execute(query, params)
-        return cursor.fetchall()
+        return self._conn.execute(query, params)
 
     def close(self) -> None:
         if self._conn:
-            # Ensure everything is written out.
-            self.flush()
-
-            if not self.connection:  # Connection created inside this class
+            if self.shared_connection:  # Connection not owned by this object
+                self.flush()  # Ensure everything is written out
+            else:
                 self._conn.close()
 
         # This forces all writes to go directly to the DB so they fail immediately.
@@ -330,7 +377,7 @@ class FileBackedList(Generic[_VT]):
     ) -> None:
         self._len = 0
         self._dict = FileBackedDict(
-            connection=connection,
+            shared_connection=connection,
             serializer=serializer,
             deserializer=deserializer,
             tablename=tablename,
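
A note on the cache semantics introduced above: `__getitem__` re-caches entries with dirty=False while `__setitem__` caches with dirty=True, so eviction only pays the serialize-and-INSERT cost for modified entries. The flip side is that in-place mutation of a cached value never goes through `__setitem__`, which is exactly what the new `mark_dirty` is for. An illustrative sketch, assuming (as the counter test below suggests) that the default serde round-trips `Counter` values and that the import path is as guessed earlier:

from typing import Counter

from datahub.utilities.file_backed_collections import FileBackedDict  # assumed path

cache = FileBackedDict[Counter[str]]()
cache["a"] = Counter()        # cached with dirty=True; persisted on eviction/flush
cache.flush()                 # writes "a" out; a later read re-caches it clean

cache["a"]["hits"] += 1       # read + in-place mutation: __setitem__ never runs,
                              # so the entry stays marked clean
cache.mark_dirty("a")         # without this, the increment could be dropped
                              # when the clean entry is evicted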

View File

@@ -1,6 +1,7 @@
 import dataclasses
 import json
 import pathlib
+import sqlite3
 from dataclasses import dataclass
 from typing import Counter, Dict
@@ -25,6 +26,8 @@ def test_file_dict() -> None:
     assert len(cache) == 100
     assert sorted(cache) == sorted([f"key-{i}" for i in range(100)])
+    assert sorted(cache.items()) == sorted([(f"key-{i}", i) for i in range(100)])
+    assert sorted(cache.values()) == sorted([i for i in range(100)])
 
     # Force eviction of everything.
     cache.flush()
@@ -140,7 +143,7 @@ def test_custom_serde() -> None:
     assert cache["second"] == second
     assert cache["first"] == first
-    assert serializer_calls == 4  # Items written to cache on every access
+    assert serializer_calls == 2
     assert deserializer_calls == 2
@@ -161,6 +164,7 @@ def test_file_dict_stores_counter() -> None:
             cache[str(i)][str(j)] += 100
             in_memory_counters[i][str(j)] += 100
             cache[str(i)][str(j)] += j
+            cache.mark_dirty(str(i))
             in_memory_counters[i][str(j)] += j
 
     for i in range(n):
@@ -174,43 +178,64 @@ class Pair:
     y: str
 
 
-def test_custom_column() -> None:
+@pytest.mark.parametrize("cache_max_size", [0, 1, 10])
+def test_custom_column(cache_max_size: int) -> None:
     cache = FileBackedDict[Pair](
         extra_columns={
             "x": lambda m: m.x,
         },
+        cache_max_size=cache_max_size,
     )
 
     cache["first"] = Pair(3, "a")
     cache["second"] = Pair(100, "b")
+    cache["third"] = Pair(27, "c")
+    cache["fourth"] = Pair(100, "d")
 
     # Verify that the extra column is present and has the correct values.
-    assert cache.sql_query(f"SELECT sum(x) FROM {cache.tablename}")[0][0] == 103
+    assert cache.sql_query(f"SELECT sum(x) FROM {cache.tablename}")[0][0] == 230
 
     # Verify that the extra column is updated when the value is updated.
-    cache["first"] = Pair(4, "c")
-    assert cache.sql_query(f"SELECT sum(x) FROM {cache.tablename}")[0][0] == 104
+    cache["first"] = Pair(4, "e")
+    assert cache.sql_query(f"SELECT sum(x) FROM {cache.tablename}")[0][0] == 231
 
     # Test param binding.
     assert (
         cache.sql_query(
             f"SELECT sum(x) FROM {cache.tablename} WHERE x < ?", params=(50,)
         )[0][0]
-        == 4
+        == 31
     )
 
+    assert sorted(list(cache.items())) == [
+        ("first", Pair(4, "e")),
+        ("fourth", Pair(100, "d")),
+        ("second", Pair(100, "b")),
+        ("third", Pair(27, "c")),
+    ]
+    assert sorted(list(cache.items_snapshot())) == [
+        ("first", Pair(4, "e")),
+        ("fourth", Pair(100, "d")),
+        ("second", Pair(100, "b")),
+        ("third", Pair(27, "c")),
+    ]
+    assert sorted(list(cache.items_snapshot("x < 50"))) == [
+        ("first", Pair(4, "e")),
+        ("third", Pair(27, "c")),
+    ]
+
 
 def test_shared_connection() -> None:
     with ConnectionWrapper() as connection:
         cache1 = FileBackedDict[int](
-            connection=connection,
+            shared_connection=connection,
            tablename="cache1",
             extra_columns={
                 "v": lambda v: v,
             },
         )
         cache2 = FileBackedDict[Pair](
-            connection=connection,
+            shared_connection=connection,
             tablename="cache2",
             extra_columns={
                 "x": lambda m: m.x,
@@ -227,10 +252,12 @@ def test_shared_connection() -> None:
         assert len(cache1) == 2
         assert len(cache2) == 3
 
-        # Test advanced SQL queries.
-        assert cache2.sql_query(
-            f"SELECT y, sum(x) FROM {cache2.tablename} GROUP BY y ORDER BY y"
-        ) == [("a", 15), ("b", 11)]
+        # Test advanced SQL queries and sql_query_iterator.
+        iterator = cache2.sql_query_iterator(
+            f"SELECT y, sum(x) FROM {cache2.tablename} GROUP BY y ORDER BY y"
+        )
+        assert type(iterator) == sqlite3.Cursor
+        assert list(iterator) == [("a", 15), ("b", 11)]
 
         # Test joining between the two tables.
         assert (
@@ -245,6 +272,12 @@ def test_shared_connection() -> None:
             )
             == [("a", 45), ("b", 55)]
         )
 
+        assert list(cache2.items_snapshot('y = "a"')) == [
+            ("ref-a-1", Pair(7, "a")),
+            ("ref-a-2", Pair(8, "a")),
+        ]
+
         cache2.close()
 
         # Check can still use cache1
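
A closing caveat on the new filtering hook, visible in both the `items_snapshot` implementation and the test directly above: `cond_sql` is interpolated into the SELECT statement as-is (`sql += f" WHERE {cond_sql}"`), with no parameter binding, so it should only carry trusted literal expressions over the key and extra columns. A sketch of the distinction, reusing `cache2` from the test purely for illustration:

# Trusted literal filter over an extra column, as in the test above:
rows = list(cache2.items_snapshot('y = "a"'))

# For anything derived from input, prefer the parameter-binding query API:
rows = cache2.sql_query(
    f"SELECT key, value FROM {cache2.tablename} WHERE y = ?", params=("a",)
)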