perf(ingest): Improve FileBackedDict iteration performance; minor refactoring (#7689)

- Adds a dirty bit to the cache; data is only written back to SQLite when dirty
- Refactors `__iter__`
- Adds `sql_query_iterator`
- Adds `items_snapshot`, a more performant alternative to `items()` that allows filtering (see the usage sketch below)
- Renames `connection` -> `shared_connection`
- Removes an unnecessary flush during `close()` when the connection is not shared
- Adds `Closeable` mixin
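
For orientation, here is a usage sketch of the pieces this commit adds: the shared_connection parameter, mark_dirty, items_snapshot with a SQL filter, sql_query_iterator, and the new close() behavior. It is a sketch rather than code from the commit; the import path, the table name "pairs", the keys, and the extra column "x" are assumptions for illustration.

# Minimal usage sketch (not part of the commit); import path is assumed.
from dataclasses import dataclass

from datahub.utilities.file_backed_collections import (  # assumed module path
    ConnectionWrapper,
    FileBackedDict,
)


@dataclass
class Pair:
    x: int
    y: str


with ConnectionWrapper() as connection:
    cache = FileBackedDict[Pair](
        shared_connection=connection,  # renamed from `connection`
        tablename="pairs",
        extra_columns={"x": lambda p: p.x},
    )

    cache["a"] = Pair(1, "a")  # cached as dirty; written to SQLite on eviction/flush
    cache["b"] = Pair(20, "b")

    # In-place mutation bypasses __setitem__; mark_dirty ensures the change is persisted.
    cache["a"].x += 10
    cache.mark_dirty("a")

    # Fixed snapshot of (key, value) pairs, optionally filtered on extra columns.
    for key, pair in cache.items_snapshot("x < 15"):
        print(key, pair)

    # Stream query results from a cursor instead of materializing a list.
    for row in cache.sql_query_iterator(f"SELECT key, x FROM {cache.tablename}"):
        print(row)

    # With a shared connection, close() only flushes; the connection stays open.
    cache.close()

The performance win of items_snapshot over items() comes from flushing once and doing a single, optionally filtered, SQL scan, at the cost of returning a fixed snapshot rather than a live view.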
Andrew Sikowitz, 2023-03-27 17:20:34 -04:00, committed by GitHub
parent 04f1b86d54
commit c7d35ffd66
2 changed files with 121 additions and 41 deletions

@@ -5,6 +5,7 @@ import pickle
import sqlite3
import tempfile
from dataclasses import dataclass, field
from datetime import datetime
from types import TracebackType
from typing import (
Any,
@@ -23,6 +24,8 @@ from typing import (
Union,
)
from datahub.ingestion.api.closeable import Closeable
logger: logging.Logger = logging.getLogger(__name__)
_DEFAULT_FILE_NAME = "sqlite.db"
@@ -31,7 +34,8 @@ _DEFAULT_MEMORY_CACHE_MAX_SIZE = 2000
_DEFAULT_MEMORY_CACHE_EVICTION_BATCH_SIZE = 200
# https://docs.python.org/3/library/sqlite3.html#sqlite-and-python-types
SqliteValue = Union[int, float, str, bytes, None]
# Datetimes get converted to strings
SqliteValue = Union[int, float, str, bytes, datetime, None]
_VT = TypeVar("_VT")
@@ -130,7 +134,7 @@ def _default_deserializer(value: Any) -> Any:
@dataclass(eq=False)
class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
class FileBackedDict(MutableMapping[str, _VT], Generic[_VT], Closeable):
"""
A dict-like object that stores its data in a temporary SQLite database.
This is useful for storing large amounts of data that don't fit in memory.
@@ -139,7 +143,7 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
"""
# Use a predefined connection, able to be shared across multiple FileBacked* objects
connection: Optional[ConnectionWrapper] = None
shared_connection: Optional[ConnectionWrapper] = None
tablename: str = _DEFAULT_TABLE_NAME
serializer: Callable[[_VT], SqliteValue] = _default_serializer
@@ -152,7 +156,10 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
_conn: ConnectionWrapper = field(init=False, repr=False)
# To improve performance, we maintain an in-memory LRU cache using an OrderedDict.
_active_object_cache: OrderedDict[str, _VT] = field(init=False, repr=False)
# Maintains a dirty bit marking whether the value has been modified since it was persisted.
_active_object_cache: OrderedDict[str, Tuple[_VT, bool]] = field(
init=False, repr=False
)
def __post_init__(self) -> None:
assert (
@@ -162,8 +169,8 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
assert "key" not in self.extra_columns, '"key" is a reserved column name'
assert "value" not in self.extra_columns, '"value" is a reserved column name'
if self.connection:
self._conn = self.connection
if self.shared_connection:
self._conn = self.shared_connection
else:
self._conn = ConnectionWrapper()
@@ -189,8 +196,8 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
f"CREATE INDEX {self.tablename}_{column_name} ON {self.tablename} ({column_name})"
)
def _add_to_cache(self, key: str, value: _VT) -> None:
self._active_object_cache[key] = value
def _add_to_cache(self, key: str, value: _VT, dirty: bool) -> None:
self._active_object_cache[key] = value, dirty
if len(self._active_object_cache) > self.cache_max_size:
# Try to prune in batches rather than one at a time.
@@ -202,8 +209,8 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
def _prune_cache(self, num_items_to_prune: int) -> None:
items_to_write: List[Tuple[SqliteValue, ...]] = []
for _ in range(num_items_to_prune):
key, value = self._active_object_cache.popitem(last=False)
key, (value, dirty) = self._active_object_cache.popitem(last=False)
if dirty:
values = [key, self.serializer(value)]
for column_serializer in self.extra_columns.values():
values.append(column_serializer(value))
@@ -226,7 +233,7 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
def __getitem__(self, key: str) -> _VT:
if key in self._active_object_cache:
self._active_object_cache.move_to_end(key)
return self._active_object_cache[key]
return self._active_object_cache[key][0]
cursor = self._conn.execute(
f"SELECT value FROM {self.tablename} WHERE key = ?", (key,)
@@ -236,11 +243,11 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
raise KeyError(key)
deserialized_result = self.deserializer(result[0])
self._add_to_cache(key, deserialized_result)
self._add_to_cache(key, deserialized_result, False)
return deserialized_result
def __setitem__(self, key: str, value: _VT) -> None:
self._add_to_cache(key, value)
self._add_to_cache(key, value, True)
def __delitem__(self, key: str) -> None:
in_cache = False
@@ -254,17 +261,43 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
if not in_cache and not n_deleted:
raise KeyError(key)
def mark_dirty(self, key: str) -> None:
if key in self._active_object_cache and not self._active_object_cache[key][1]:
self._active_object_cache[key] = self._active_object_cache[key][0], True
def __iter__(self) -> Iterator[str]:
# The cache should be small, so it's cheap to copy its keys into a set; this avoids mutation during iteration.
cache_keys = set(self._active_object_cache.keys())
yield from cache_keys
cursor = self._conn.execute(f"SELECT key FROM {self.tablename}")
for row in cursor:
if row[0] in self._active_object_cache:
# If the key is in the active object cache, then SQL isn't the source of truth.
continue
if row[0] not in cache_keys:
yield row[0]
for key in self._active_object_cache:
yield key
def items_snapshot(
self, cond_sql: Optional[str] = None
) -> Iterator[Tuple[str, _VT]]:
"""
Return a fixed snapshot, rather than a view, of the dictionary's items.
Flushes the cache and provides the option to filter the results.
Provides better performance than the standard `items()` method.
Args:
cond_sql: Conditional expression for WHERE statement, e.g. `x = 0 AND y = "value"`
Returns:
Iterator of filtered (key, value) pairs.
"""
self.flush()
sql = f"SELECT key, value FROM {self.tablename}"
if cond_sql:
sql += f" WHERE {cond_sql}"
cursor = self._conn.execute(sql)
for row in cursor:
yield row[0], self.deserializer(row[1])
def __len__(self) -> int:
cursor = self._conn.execute(
@@ -282,6 +315,22 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
params: Tuple[Any, ...] = (),
refs: Optional[List[Union["FileBackedList", "FileBackedDict"]]] = None,
) -> List[Tuple[Any, ...]]:
return self._sql_query(query, params, refs).fetchall()
def sql_query_iterator(
self,
query: str,
params: Tuple[Any, ...] = (),
refs: Optional[List[Union["FileBackedList", "FileBackedDict"]]] = None,
) -> Iterator[Tuple[Any, ...]]:
return self._sql_query(query, params, refs)
def _sql_query(
self,
query: str,
params: Tuple[Any, ...] = (),
refs: Optional[List[Union["FileBackedList", "FileBackedDict"]]] = None,
) -> sqlite3.Cursor:
# We need to flush this object and any objects the query references to ensure
# that we don't miss objects that have been modified but not yet flushed.
self.flush()
@@ -289,15 +338,13 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
for referenced_table in refs:
referenced_table.flush()
cursor = self._conn.execute(query, params)
return cursor.fetchall()
return self._conn.execute(query, params)
def close(self) -> None:
if self._conn:
# Ensure everything is written out.
self.flush()
if not self.connection: # Connection created inside this class
if self.shared_connection: # Connection not owned by this object
self.flush() # Ensure everything is written out
else:
self._conn.close()
# This forces all writes to go directly to the DB so they fail immediately.
@@ -330,7 +377,7 @@ class FileBackedList(Generic[_VT]):
) -> None:
self._len = 0
self._dict = FileBackedDict(
connection=connection,
shared_connection=connection,
serializer=serializer,
deserializer=deserializer,
tablename=tablename,
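
Before the updated tests below, a short sketch of why the expected serializer call count in test_custom_serde drops from 4 to 2: values read back from SQLite are now cached with the dirty bit cleared, so evicting or flushing them again does not re-serialize them. The counting helper, the cache_max_size value, and the import path are assumptions, not the actual test body.

# Sketch of the dirty-bit effect on write-backs; import path is assumed.
from datahub.utilities.file_backed_collections import FileBackedDict

serializer_calls = 0


def counting_serializer(value: int) -> int:
    global serializer_calls
    serializer_calls += 1
    return value


cache = FileBackedDict[int](
    serializer=counting_serializer,
    deserializer=lambda v: v,
    cache_max_size=1,  # force evictions so values are written to SQLite
)
cache["first"] = 1
cache["second"] = 2
cache.flush()
assert serializer_calls == 2  # both entries were dirty and got written

# Reads repopulate the cache as clean entries, so nothing is re-serialized
# when they are evicted or flushed again.
assert cache["first"] == 1
assert cache["second"] == 2
cache.flush()
assert serializer_calls == 2  # previously these reads would count as writes again
cache.close()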


@@ -1,6 +1,7 @@
import dataclasses
import json
import pathlib
import sqlite3
from dataclasses import dataclass
from typing import Counter, Dict
@@ -25,6 +26,8 @@ def test_file_dict() -> None:
assert len(cache) == 100
assert sorted(cache) == sorted([f"key-{i}" for i in range(100)])
assert sorted(cache.items()) == sorted([(f"key-{i}", i) for i in range(100)])
assert sorted(cache.values()) == sorted([i for i in range(100)])
# Force eviction of everything.
cache.flush()
@@ -140,7 +143,7 @@ def test_custom_serde() -> None:
assert cache["second"] == second
assert cache["first"] == first
assert serializer_calls == 4 # Items written to cache on every access
assert serializer_calls == 2
assert deserializer_calls == 2
@@ -161,6 +164,7 @@ def test_file_dict_stores_counter() -> None:
cache[str(i)][str(j)] += 100
in_memory_counters[i][str(j)] += 100
cache[str(i)][str(j)] += j
cache.mark_dirty(str(i))
in_memory_counters[i][str(j)] += j
for i in range(n):
@@ -174,43 +178,64 @@ class Pair:
y: str
def test_custom_column() -> None:
@pytest.mark.parametrize("cache_max_size", [0, 1, 10])
def test_custom_column(cache_max_size: int) -> None:
cache = FileBackedDict[Pair](
extra_columns={
"x": lambda m: m.x,
},
cache_max_size=cache_max_size,
)
cache["first"] = Pair(3, "a")
cache["second"] = Pair(100, "b")
cache["third"] = Pair(27, "c")
cache["fourth"] = Pair(100, "d")
# Verify that the extra column is present and has the correct values.
assert cache.sql_query(f"SELECT sum(x) FROM {cache.tablename}")[0][0] == 103
assert cache.sql_query(f"SELECT sum(x) FROM {cache.tablename}")[0][0] == 230
# Verify that the extra column is updated when the value is updated.
cache["first"] = Pair(4, "c")
assert cache.sql_query(f"SELECT sum(x) FROM {cache.tablename}")[0][0] == 104
cache["first"] = Pair(4, "e")
assert cache.sql_query(f"SELECT sum(x) FROM {cache.tablename}")[0][0] == 231
# Test param binding.
assert (
cache.sql_query(
f"SELECT sum(x) FROM {cache.tablename} WHERE x < ?", params=(50,)
)[0][0]
== 4
== 31
)
assert sorted(list(cache.items())) == [
("first", Pair(4, "e")),
("fourth", Pair(100, "d")),
("second", Pair(100, "b")),
("third", Pair(27, "c")),
]
assert sorted(list(cache.items_snapshot())) == [
("first", Pair(4, "e")),
("fourth", Pair(100, "d")),
("second", Pair(100, "b")),
("third", Pair(27, "c")),
]
assert sorted(list(cache.items_snapshot("x < 50"))) == [
("first", Pair(4, "e")),
("third", Pair(27, "c")),
]
def test_shared_connection() -> None:
with ConnectionWrapper() as connection:
cache1 = FileBackedDict[int](
connection=connection,
shared_connection=connection,
tablename="cache1",
extra_columns={
"v": lambda v: v,
},
)
cache2 = FileBackedDict[Pair](
connection=connection,
shared_connection=connection,
tablename="cache2",
extra_columns={
"x": lambda m: m.x,
@@ -227,10 +252,12 @@ def test_shared_connection() -> None:
assert len(cache1) == 2
assert len(cache2) == 3
# Test advanced SQL queries.
assert cache2.sql_query(
# Test advanced SQL queries and sql_query_iterator.
iterator = cache2.sql_query_iterator(
f"SELECT y, sum(x) FROM {cache2.tablename} GROUP BY y ORDER BY y"
) == [("a", 15), ("b", 11)]
)
assert type(iterator) == sqlite3.Cursor
assert list(iterator) == [("a", 15), ("b", 11)]
# Test joining between the two tables.
assert (
@@ -245,6 +272,12 @@ def test_shared_connection() -> None:
)
== [("a", 45), ("b", 55)]
)
assert list(cache2.items_snapshot('y = "a"')) == [
("ref-a-1", Pair(7, "a")),
("ref-a-2", Pair(8, "a")),
]
cache2.close()
# Check can still use cache1