Mirror of https://github.com/datahub-project/datahub.git
perf(ingest): Improve FileBackedDict iteration performance; minor refactoring (#7689)
- Adds dirty bit to cache, only writes data if dirty
- Refactors __iter__
- Adds sql_query_iterator
- Adds items_snapshot, a more performant alternative to `items()` that allows for filtering
- Renames connection -> shared_connection
- Removes unnecessary flush during close if connection is not shared
- Adds Closeable mixin
parent 04f1b86d54
commit c7d35ffd66
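For orientation before the diff, a minimal usage sketch of the reworked API. This is a hedged illustration: the import path (datahub.utilities.file_backed_collections), the table name, and the keys are assumptions, not taken from this commit.

from datahub.utilities.file_backed_collections import ConnectionWrapper, FileBackedDict

with ConnectionWrapper() as connection:
    # shared_connection (renamed from connection) lets several FileBacked* objects
    # share one SQLite connection.
    cache = FileBackedDict[int](
        shared_connection=connection,
        tablename="counts",
        serializer=lambda v: v,
        deserializer=lambda v: v,
        extra_columns={"v": lambda v: v},
    )
    cache["a"] = 1
    cache["b"] = 2

    # items_snapshot flushes the in-memory cache, then streams rows from SQLite,
    # optionally filtered with a SQL condition on the extra columns.
    assert dict(cache.items_snapshot("v > 1")) == {"b": 2}

    # sql_query_iterator returns the cursor, so large result sets stream lazily.
    for key, v in cache.sql_query_iterator(f"SELECT key, v FROM {cache.tablename}"):
        print(key, v)

    cache.close()  # flushes this dict; the shared connection stays open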
@@ -5,6 +5,7 @@ import pickle
 import sqlite3
 import tempfile
 from dataclasses import dataclass, field
+from datetime import datetime
 from types import TracebackType
 from typing import (
     Any,
@@ -23,6 +24,8 @@ from typing import (
     Union,
 )

+from datahub.ingestion.api.closeable import Closeable
+
 logger: logging.Logger = logging.getLogger(__name__)

 _DEFAULT_FILE_NAME = "sqlite.db"
@@ -31,7 +34,8 @@ _DEFAULT_MEMORY_CACHE_MAX_SIZE = 2000
 _DEFAULT_MEMORY_CACHE_EVICTION_BATCH_SIZE = 200

 # https://docs.python.org/3/library/sqlite3.html#sqlite-and-python-types
-SqliteValue = Union[int, float, str, bytes, None]
+# Datetimes get converted to strings
+SqliteValue = Union[int, float, str, bytes, datetime, None]

 _VT = TypeVar("_VT")

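The added datetime member relies on Python's sqlite3 default adapters, which store datetime.datetime values as ISO-8601 text (hence the "converted to strings" comment). A standalone sketch of that behavior, independent of this diff; note the default adapters are deprecated in newer Python versions, so this is an assumption about the interpreter in use:

import sqlite3
from datetime import datetime

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE t (key TEXT, value TEXT)")
# With the default adapter, the datetime is written as ISO-8601 text and read back
# as a plain string unless a converter is registered.
conn.execute("INSERT INTO t VALUES (?, ?)", ("ts", datetime(2023, 3, 27, 12, 0)))
print(conn.execute("SELECT value FROM t WHERE key = 'ts'").fetchone()[0])
# -> '2023-03-27 12:00:00'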
@@ -130,7 +134,7 @@ def _default_deserializer(value: Any) -> Any:


 @dataclass(eq=False)
-class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
+class FileBackedDict(MutableMapping[str, _VT], Generic[_VT], Closeable):
     """
     A dict-like object that stores its data in a temporary SQLite database.
     This is useful for storing large amounts of data that don't fit in memory.
@@ -139,7 +143,7 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
     """

     # Use a predefined connection, able to be shared across multiple FileBacked* objects
-    connection: Optional[ConnectionWrapper] = None
+    shared_connection: Optional[ConnectionWrapper] = None
     tablename: str = _DEFAULT_TABLE_NAME

     serializer: Callable[[_VT], SqliteValue] = _default_serializer
@@ -152,7 +156,10 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
     _conn: ConnectionWrapper = field(init=False, repr=False)

     # To improve performance, we maintain an in-memory LRU cache using an OrderedDict.
-    _active_object_cache: OrderedDict[str, _VT] = field(init=False, repr=False)
+    # Maintains a dirty bit marking whether the value has been modified since it was persisted.
+    _active_object_cache: OrderedDict[str, Tuple[_VT, bool]] = field(
+        init=False, repr=False
+    )

     def __post_init__(self) -> None:
         assert (
@@ -162,8 +169,8 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
         assert "key" not in self.extra_columns, '"key" is a reserved column name'
         assert "value" not in self.extra_columns, '"value" is a reserved column name'

-        if self.connection:
-            self._conn = self.connection
+        if self.shared_connection:
+            self._conn = self.shared_connection
         else:
             self._conn = ConnectionWrapper()

@@ -189,8 +196,8 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
                 f"CREATE INDEX {self.tablename}_{column_name} ON {self.tablename} ({column_name})"
             )

-    def _add_to_cache(self, key: str, value: _VT) -> None:
-        self._active_object_cache[key] = value
+    def _add_to_cache(self, key: str, value: _VT, dirty: bool) -> None:
+        self._active_object_cache[key] = value, dirty

         if len(self._active_object_cache) > self.cache_max_size:
             # Try to prune in batches rather than one at a time.
@@ -202,8 +209,8 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
     def _prune_cache(self, num_items_to_prune: int) -> None:
         items_to_write: List[Tuple[SqliteValue, ...]] = []
         for _ in range(num_items_to_prune):
-            key, value = self._active_object_cache.popitem(last=False)
-
+            key, (value, dirty) = self._active_object_cache.popitem(last=False)
+            if dirty:
                 values = [key, self.serializer(value)]
                 for column_serializer in self.extra_columns.values():
                     values.append(column_serializer(value))
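The two hunks above are the core of the perf change: entries loaded from SQLite enter the cache clean, entries written through __setitem__ enter dirty, and eviction only pays the serialization and write cost for dirty entries. A self-contained sketch of the idea; the class and field names here are hypothetical, not the DataHub implementation:

from collections import OrderedDict
from typing import Tuple


class DirtyLRU:
    """Toy dirty-tracking LRU: only modified entries are written back on eviction."""

    def __init__(self, max_size: int) -> None:
        self.max_size = max_size
        self.cache: "OrderedDict[str, Tuple[int, bool]]" = OrderedDict()
        self.writes = 0  # simulated flushes to the backing store

    def _evict(self) -> None:
        while len(self.cache) > self.max_size:
            key, (value, dirty) = self.cache.popitem(last=False)
            if dirty:
                self.writes += 1  # clean entries are dropped for free

    def load(self, key: str, value: int) -> None:
        self.cache[key] = (value, False)  # came from storage: clean
        self._evict()

    def put(self, key: str, value: int) -> None:
        self.cache[key] = (value, True)  # modified by the caller: dirty
        self._evict()


lru = DirtyLRU(max_size=2)
lru.load("a", 1)   # read-only entry
lru.put("b", 2)
lru.put("c", 3)    # evicts "a"; it was never modified, so no write happens
assert lru.writes == 0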
@@ -226,7 +233,7 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
     def __getitem__(self, key: str) -> _VT:
         if key in self._active_object_cache:
             self._active_object_cache.move_to_end(key)
-            return self._active_object_cache[key]
+            return self._active_object_cache[key][0]

         cursor = self._conn.execute(
             f"SELECT value FROM {self.tablename} WHERE key = ?", (key,)
@@ -236,11 +243,11 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
             raise KeyError(key)

         deserialized_result = self.deserializer(result[0])
-        self._add_to_cache(key, deserialized_result)
+        self._add_to_cache(key, deserialized_result, False)
         return deserialized_result

     def __setitem__(self, key: str, value: _VT) -> None:
-        self._add_to_cache(key, value)
+        self._add_to_cache(key, value, True)

     def __delitem__(self, key: str) -> None:
         in_cache = False
@@ -254,17 +261,43 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
         if not in_cache and not n_deleted:
             raise KeyError(key)

+    def mark_dirty(self, key: str) -> None:
+        if key in self._active_object_cache and not self._active_object_cache[key][1]:
+            self._active_object_cache[key] = self._active_object_cache[key][0], True
+
     def __iter__(self) -> Iterator[str]:
+        # Cache should be small, so safe list cast to avoid mutation during iteration
+        cache_keys = set(self._active_object_cache.keys())
+        yield from cache_keys
+
         cursor = self._conn.execute(f"SELECT key FROM {self.tablename}")
         for row in cursor:
-            if row[0] in self._active_object_cache:
-                # If the key is in the active object cache, then SQL isn't the source of truth.
-                continue
-
-            yield row[0]
-
-        for key in self._active_object_cache:
-            yield key
+            if row[0] not in cache_keys:
+                yield row[0]
+
+    def items_snapshot(
+        self, cond_sql: Optional[str] = None
+    ) -> Iterator[Tuple[str, _VT]]:
+        """
+        Return a fixed snapshot, rather than a view, of the dictionary's items.
+
+        Flushes the cache and provides the option to filter the results.
+        Provides better performance over standard `items()` method.
+
+        Args:
+            cond_sql: Conditional expression for WHERE statement, e.g. `x = 0 AND y = "value"`
+
+        Returns:
+            Iterator of filtered (key, value) pairs.
+        """
+        self.flush()
+        sql = f"SELECT key, value FROM {self.tablename}"
+        if cond_sql:
+            sql += f" WHERE {cond_sql}"
+
+        cursor = self._conn.execute(sql)
+        for row in cursor:
+            yield row[0], self.deserializer(row[1])

     def __len__(self) -> int:
         cursor = self._conn.execute(
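Because the dirty bit is only set through __setitem__, values mutated in place (for example a Counter that gets incremented) need mark_dirty to survive eviction; items_snapshot then reads a flushed, filterable snapshot straight from SQLite. A hedged usage sketch; the module path, key, and serializers are assumptions:

import json
from typing import Counter

from datahub.utilities.file_backed_collections import FileBackedDict

cache = FileBackedDict[Counter[str]](
    serializer=json.dumps,
    deserializer=lambda raw: Counter(json.loads(raw)),
)

cache["urn:example"] = Counter()      # __setitem__ marks the entry dirty
cache["urn:example"]["reads"] += 1    # in-place mutation bypasses __setitem__ ...
cache.mark_dirty("urn:example")       # ... so mark_dirty ensures it is persisted

# items_snapshot flushes the cache, then yields (key, value) pairs from SQLite.
for key, value in cache.items_snapshot():
    print(key, dict(value))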
@@ -282,6 +315,22 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
         params: Tuple[Any, ...] = (),
         refs: Optional[List[Union["FileBackedList", "FileBackedDict"]]] = None,
     ) -> List[Tuple[Any, ...]]:
+        return self._sql_query(query, params, refs).fetchall()
+
+    def sql_query_iterator(
+        self,
+        query: str,
+        params: Tuple[Any, ...] = (),
+        refs: Optional[List[Union["FileBackedList", "FileBackedDict"]]] = None,
+    ) -> Iterator[Tuple[Any, ...]]:
+        return self._sql_query(query, params, refs)
+
+    def _sql_query(
+        self,
+        query: str,
+        params: Tuple[Any, ...] = (),
+        refs: Optional[List[Union["FileBackedList", "FileBackedDict"]]] = None,
+    ) -> sqlite3.Cursor:
         # We need to flush object and any objects the query references to ensure
         # that we don't miss objects that have been modified but not yet flushed.
         self.flush()
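sql_query now delegates to _sql_query and materializes results with fetchall(), while sql_query_iterator exposes the cursor directly so callers can stream large result sets. A sketch under the same assumptions as above (illustrative extra column and identity serializers):

from datahub.utilities.file_backed_collections import FileBackedDict

cache = FileBackedDict[int](
    serializer=lambda v: v,
    deserializer=lambda v: v,
    extra_columns={"v": lambda v: v},
)
for i in range(1000):
    cache[f"key-{i}"] = i

# Materializes every row in memory.
total = cache.sql_query(f"SELECT sum(v) FROM {cache.tablename}")[0][0]

# Streams rows lazily from the underlying sqlite3.Cursor.
for key, v in cache.sql_query_iterator(
    f"SELECT key, v FROM {cache.tablename} WHERE v >= ?", params=(990,)
):
    print(key, v)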
@@ -289,15 +338,13 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
             for referenced_table in refs:
                 referenced_table.flush()

-        cursor = self._conn.execute(query, params)
-        return cursor.fetchall()
+        return self._conn.execute(query, params)

     def close(self) -> None:
         if self._conn:
-            # Ensure everything is written out.
-            self.flush()
-
-            if not self.connection:  # Connection created inside this class
+            if self.shared_connection:  # Connection not owned by this object
+                self.flush()  # Ensure everything is written out
+            else:
                 self._conn.close()

             # This forces all writes to go directly to the DB so they fail immediately.
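The close() change encodes connection ownership: a dict handed a shared_connection only flushes its table, while a dict that created its own connection closes it outright, with no flush, since its private temporary database is being discarded anyway. A sketch of the shared case, with illustrative table names and identity serializers:

from datahub.utilities.file_backed_collections import ConnectionWrapper, FileBackedDict

with ConnectionWrapper() as connection:
    counts = FileBackedDict[int](
        shared_connection=connection,
        tablename="counts",
        serializer=lambda v: v,
        deserializer=lambda v: v,
    )
    names = FileBackedDict[str](
        shared_connection=connection,
        tablename="names",
        serializer=lambda v: v,
        deserializer=lambda v: v,
    )

    counts["a"] = 1
    names["a"] = "alpha"

    counts.close()                  # flushes "counts", leaves the connection open
    assert names["a"] == "alpha"    # the sibling dict keeps working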
@@ -330,7 +377,7 @@ class FileBackedList(Generic[_VT]):
     ) -> None:
         self._len = 0
         self._dict = FileBackedDict(
-            connection=connection,
+            shared_connection=connection,
             serializer=serializer,
             deserializer=deserializer,
             tablename=tablename,
@@ -1,6 +1,7 @@
 import dataclasses
 import json
 import pathlib
+import sqlite3
 from dataclasses import dataclass
 from typing import Counter, Dict

@@ -25,6 +26,8 @@ def test_file_dict() -> None:

     assert len(cache) == 100
     assert sorted(cache) == sorted([f"key-{i}" for i in range(100)])
+    assert sorted(cache.items()) == sorted([(f"key-{i}", i) for i in range(100)])
+    assert sorted(cache.values()) == sorted([i for i in range(100)])

     # Force eviction of everything.
     cache.flush()
@@ -140,7 +143,7 @@ def test_custom_serde() -> None:

     assert cache["second"] == second
     assert cache["first"] == first
-    assert serializer_calls == 4  # Items written to cache on every access
+    assert serializer_calls == 2
     assert deserializer_calls == 2


@@ -161,6 +164,7 @@ def test_file_dict_stores_counter() -> None:
             cache[str(i)][str(j)] += 100
             in_memory_counters[i][str(j)] += 100
             cache[str(i)][str(j)] += j
+            cache.mark_dirty(str(i))
             in_memory_counters[i][str(j)] += j

     for i in range(n):
@@ -174,43 +178,64 @@ class Pair:
     y: str


-def test_custom_column() -> None:
+@pytest.mark.parametrize("cache_max_size", [0, 1, 10])
+def test_custom_column(cache_max_size: int) -> None:
     cache = FileBackedDict[Pair](
         extra_columns={
             "x": lambda m: m.x,
         },
+        cache_max_size=cache_max_size,
     )

     cache["first"] = Pair(3, "a")
     cache["second"] = Pair(100, "b")
+    cache["third"] = Pair(27, "c")
+    cache["fourth"] = Pair(100, "d")

     # Verify that the extra column is present and has the correct values.
-    assert cache.sql_query(f"SELECT sum(x) FROM {cache.tablename}")[0][0] == 103
+    assert cache.sql_query(f"SELECT sum(x) FROM {cache.tablename}")[0][0] == 230

     # Verify that the extra column is updated when the value is updated.
-    cache["first"] = Pair(4, "c")
-    assert cache.sql_query(f"SELECT sum(x) FROM {cache.tablename}")[0][0] == 104
+    cache["first"] = Pair(4, "e")
+    assert cache.sql_query(f"SELECT sum(x) FROM {cache.tablename}")[0][0] == 231

     # Test param binding.
     assert (
         cache.sql_query(
             f"SELECT sum(x) FROM {cache.tablename} WHERE x < ?", params=(50,)
         )[0][0]
-        == 4
+        == 31
     )

+    assert sorted(list(cache.items())) == [
+        ("first", Pair(4, "e")),
+        ("fourth", Pair(100, "d")),
+        ("second", Pair(100, "b")),
+        ("third", Pair(27, "c")),
+    ]
+    assert sorted(list(cache.items_snapshot())) == [
+        ("first", Pair(4, "e")),
+        ("fourth", Pair(100, "d")),
+        ("second", Pair(100, "b")),
+        ("third", Pair(27, "c")),
+    ]
+    assert sorted(list(cache.items_snapshot("x < 50"))) == [
+        ("first", Pair(4, "e")),
+        ("third", Pair(27, "c")),
+    ]
+

 def test_shared_connection() -> None:
     with ConnectionWrapper() as connection:
         cache1 = FileBackedDict[int](
-            connection=connection,
+            shared_connection=connection,
             tablename="cache1",
             extra_columns={
                 "v": lambda v: v,
             },
         )
         cache2 = FileBackedDict[Pair](
-            connection=connection,
+            shared_connection=connection,
             tablename="cache2",
             extra_columns={
                 "x": lambda m: m.x,
@@ -227,10 +252,12 @@ def test_shared_connection() -> None:
         assert len(cache1) == 2
         assert len(cache2) == 3

-        # Test advanced SQL queries.
-        assert cache2.sql_query(
+        # Test advanced SQL queries and sql_query_iterator.
+        iterator = cache2.sql_query_iterator(
             f"SELECT y, sum(x) FROM {cache2.tablename} GROUP BY y ORDER BY y"
-        ) == [("a", 15), ("b", 11)]
+        )
+        assert type(iterator) == sqlite3.Cursor
+        assert list(iterator) == [("a", 15), ("b", 11)]

         # Test joining between the two tables.
         assert (
@@ -245,6 +272,12 @@ def test_shared_connection() -> None:
             )
             == [("a", 45), ("b", 55)]
         )

+        assert list(cache2.items_snapshot('y = "a"')) == [
+            ("ref-a-1", Pair(7, "a")),
+            ("ref-a-2", Pair(8, "a")),
+        ]
+
         cache2.close()

         # Check can still use cache1