Mirror of https://github.com/datahub-project/datahub.git (synced 2025-08-24 00:57:59 +00:00)
perf(ingest): Improve FileBackedDict iteration performance; minor refactoring (#7689)
- Adds a dirty bit to the cache; data is only written back if dirty
- Refactors `__iter__`
- Adds `sql_query_iterator`
- Adds `items_snapshot`, a more performant `items()` that allows for filtering
- Renames `connection` -> `shared_connection`
- Removes an unnecessary flush during close when the connection is not shared
- Adds the `Closeable` mixin
This commit is contained in:
parent 04f1b86d54
commit c7d35ffd66
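Taken together, the new surface looks roughly like this — a minimal sketch rather than code from the commit, assuming the existing import path `datahub.utilities.file_backed_collections` and an identity serde (plain ints are already valid SQLite values):

```python
from datahub.utilities.file_backed_collections import FileBackedDict

# Identity serde: ints can be stored in SQLite as-is.
cache = FileBackedDict[int](serializer=lambda v: v, deserializer=lambda v: v)
cache["a"] = 1  # cached in memory with the dirty bit set; persisted on eviction/flush

# items_snapshot() flushes, then reads straight from SQLite,
# optionally filtered by a raw WHERE-clause fragment.
for key, value in cache.items_snapshot("value > 0"):
    print(key, value)

# sql_query_iterator() streams rows instead of materializing a list.
for (key,) in cache.sql_query_iterator(f"SELECT key FROM {cache.tablename}"):
    print(key)

cache.close()
```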
@@ -5,6 +5,7 @@ import pickle
 import sqlite3
 import tempfile
 from dataclasses import dataclass, field
+from datetime import datetime
 from types import TracebackType
 from typing import (
     Any,
@@ -23,6 +24,8 @@ from typing import (
     Union,
 )
 
+from datahub.ingestion.api.closeable import Closeable
+
 logger: logging.Logger = logging.getLogger(__name__)
 
 _DEFAULT_FILE_NAME = "sqlite.db"
@@ -31,7 +34,8 @@ _DEFAULT_MEMORY_CACHE_MAX_SIZE = 2000
 _DEFAULT_MEMORY_CACHE_EVICTION_BATCH_SIZE = 200
 
 # https://docs.python.org/3/library/sqlite3.html#sqlite-and-python-types
-SqliteValue = Union[int, float, str, bytes, None]
+# Datetimes get converted to strings
+SqliteValue = Union[int, float, str, bytes, datetime, None]
 
 _VT = TypeVar("_VT")
 
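The new comment is worth unpacking: SQLite has no native datetime type, so `datetime` values are converted by the driver. A quick plain-`sqlite3` illustration of what actually lands in the database (this relies on the default adapters, which still apply here even though Python 3.12 deprecates them):

```python
import sqlite3
from datetime import datetime

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE t (v)")
conn.execute("INSERT INTO t VALUES (?)", (datetime(2023, 3, 1, 12, 0),))

# The datetime is stored as an ISO-8601 TEXT value, not a native type.
print(conn.execute("SELECT v, typeof(v) FROM t").fetchone())
# ('2023-03-01 12:00:00', 'text')
```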
@@ -130,7 +134,7 @@ def _default_deserializer(value: Any) -> Any:
 
 
 @dataclass(eq=False)
-class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
+class FileBackedDict(MutableMapping[str, _VT], Generic[_VT], Closeable):
     """
     A dict-like object that stores its data in a temporary SQLite database.
     This is useful for storing large amounts of data that don't fit in memory.
@@ -139,7 +143,7 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
     """
 
     # Use a predefined connection, able to be shared across multiple FileBacked* objects
-    connection: Optional[ConnectionWrapper] = None
+    shared_connection: Optional[ConnectionWrapper] = None
     tablename: str = _DEFAULT_TABLE_NAME
 
     serializer: Callable[[_VT], SqliteValue] = _default_serializer
@@ -152,7 +156,10 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
     _conn: ConnectionWrapper = field(init=False, repr=False)
 
     # To improve performance, we maintain an in-memory LRU cache using an OrderedDict.
-    _active_object_cache: OrderedDict[str, _VT] = field(init=False, repr=False)
+    # Maintains a dirty bit marking whether the value has been modified since it was persisted.
+    _active_object_cache: OrderedDict[str, Tuple[_VT, bool]] = field(
+        init=False, repr=False
+    )
 
     def __post_init__(self) -> None:
         assert (
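Each cache entry is now a `(value, dirty)` pair, and the `OrderedDict` continues to double as the LRU: `move_to_end` marks an entry as recently used, `popitem(last=False)` evicts the oldest. A standalone sketch of that bookkeeping, outside the class:

```python
from collections import OrderedDict
from typing import Tuple

lru: "OrderedDict[str, Tuple[int, bool]]" = OrderedDict()
lru["a"] = (1, True)   # freshly written value: dirty, must be persisted on eviction
lru["b"] = (2, False)  # value read back from SQLite: clean

lru.move_to_end("a")   # touching "a" makes "b" the least recently used
key, (value, dirty) = lru.popitem(last=False)
assert (key, value, dirty) == ("b", 2, False)  # clean entry: eviction needs no write
```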
@@ -162,8 +169,8 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
         assert "key" not in self.extra_columns, '"key" is a reserved column name'
         assert "value" not in self.extra_columns, '"value" is a reserved column name'
 
-        if self.connection:
-            self._conn = self.connection
+        if self.shared_connection:
+            self._conn = self.shared_connection
         else:
             self._conn = ConnectionWrapper()
 
@@ -189,8 +196,8 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
                 f"CREATE INDEX {self.tablename}_{column_name} ON {self.tablename} ({column_name})"
             )
 
-    def _add_to_cache(self, key: str, value: _VT) -> None:
-        self._active_object_cache[key] = value
+    def _add_to_cache(self, key: str, value: _VT, dirty: bool) -> None:
+        self._active_object_cache[key] = value, dirty
 
         if len(self._active_object_cache) > self.cache_max_size:
             # Try to prune in batches rather than one at a time.
@@ -202,8 +209,8 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
     def _prune_cache(self, num_items_to_prune: int) -> None:
         items_to_write: List[Tuple[SqliteValue, ...]] = []
         for _ in range(num_items_to_prune):
-            key, value = self._active_object_cache.popitem(last=False)
-            values = [key, self.serializer(value)]
-            for column_serializer in self.extra_columns.values():
-                values.append(column_serializer(value))
+            key, (value, dirty) = self._active_object_cache.popitem(last=False)
+            if dirty:
+                values = [key, self.serializer(value)]
+                for column_serializer in self.extra_columns.values():
+                    values.append(column_serializer(value))
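The payoff shows up here: entries that were only ever read enter the cache clean (see `__getitem__` below) and are now evicted without touching SQLite, whereas the old code re-serialized and rewrote every evicted entry. A rough standalone sketch of the effect — `writes` is a hypothetical stand-in for the real batched SQL upsert:

```python
from collections import OrderedDict

def prune(cache, num_items, serializer, writes):
    # Evict the least-recently-used entries; persist only the dirty ones.
    for _ in range(num_items):
        key, (value, dirty) = cache.popitem(last=False)
        if dirty:
            writes.append((key, serializer(value)))

cache = OrderedDict({"only-read": (1, False), "modified": (2, True)})
writes: list = []
prune(cache, 2, str, writes)
assert writes == [("modified", "2")]  # the clean entry cost no write at all
```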
@@ -226,7 +233,7 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
     def __getitem__(self, key: str) -> _VT:
         if key in self._active_object_cache:
             self._active_object_cache.move_to_end(key)
-            return self._active_object_cache[key]
+            return self._active_object_cache[key][0]
 
         cursor = self._conn.execute(
             f"SELECT value FROM {self.tablename} WHERE key = ?", (key,)
@@ -236,11 +243,11 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
             raise KeyError(key)
 
         deserialized_result = self.deserializer(result[0])
-        self._add_to_cache(key, deserialized_result)
+        self._add_to_cache(key, deserialized_result, False)
         return deserialized_result
 
     def __setitem__(self, key: str, value: _VT) -> None:
-        self._add_to_cache(key, value)
+        self._add_to_cache(key, value, True)
 
     def __delitem__(self, key: str) -> None:
         in_cache = False
@@ -254,17 +261,43 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
         if not in_cache and not n_deleted:
             raise KeyError(key)
 
+    def mark_dirty(self, key: str) -> None:
+        if key in self._active_object_cache and not self._active_object_cache[key][1]:
+            self._active_object_cache[key] = self._active_object_cache[key][0], True
+
     def __iter__(self) -> Iterator[str]:
+        # Cache should be small, so a set copy is safe and avoids mutation during iteration
+        cache_keys = set(self._active_object_cache.keys())
+        yield from cache_keys
+
         cursor = self._conn.execute(f"SELECT key FROM {self.tablename}")
         for row in cursor:
-            if row[0] in self._active_object_cache:
-                # If the key is in the active object cache, then SQL isn't the source of truth.
-                continue
-
-            yield row[0]
-
-        for key in self._active_object_cache:
-            yield key
+            if row[0] not in cache_keys:
+                yield row[0]
+
+    def items_snapshot(
+        self, cond_sql: Optional[str] = None
+    ) -> Iterator[Tuple[str, _VT]]:
+        """
+        Return a fixed snapshot, rather than a view, of the dictionary's items.
+
+        Flushes the cache and provides the option to filter the results.
+        Provides better performance than the standard `items()` method.
+
+        Args:
+            cond_sql: Conditional expression for the WHERE clause, e.g. `x = 0 AND y = "value"`
+
+        Returns:
+            Iterator of filtered (key, value) pairs.
+        """
+        self.flush()
+        sql = f"SELECT key, value FROM {self.tablename}"
+        if cond_sql:
+            sql += f" WHERE {cond_sql}"
+
+        cursor = self._conn.execute(sql)
+        for row in cursor:
+            yield row[0], self.deserializer(row[1])
+
     def __len__(self) -> int:
         cursor = self._conn.execute(
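`mark_dirty` closes a hole the dirty bit opens: only `__setitem__` sets the bit, so mutating a cached value in place goes through `__getitem__` alone and the change could be dropped on eviction. A hedged usage sketch (same import-path assumption as above, with explicit JSON serde to stay self-contained):

```python
import json
from typing import Counter

from datahub.utilities.file_backed_collections import FileBackedDict

cache = FileBackedDict[Counter[str]](
    serializer=json.dumps,
    deserializer=lambda raw: Counter(json.loads(raw)),
)
cache["clicks"] = Counter()    # __setitem__ sets the dirty bit

cache["clicks"]["home"] += 1   # in-place mutation: __getitem__ only, bit untouched
cache.mark_dirty("clicks")     # flag the entry by hand so the update survives eviction

cache.flush()                  # the mutated Counter is now persisted
```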
@@ -282,6 +315,22 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
         params: Tuple[Any, ...] = (),
         refs: Optional[List[Union["FileBackedList", "FileBackedDict"]]] = None,
     ) -> List[Tuple[Any, ...]]:
+        return self._sql_query(query, params, refs).fetchall()
+
+    def sql_query_iterator(
+        self,
+        query: str,
+        params: Tuple[Any, ...] = (),
+        refs: Optional[List[Union["FileBackedList", "FileBackedDict"]]] = None,
+    ) -> Iterator[Tuple[Any, ...]]:
+        return self._sql_query(query, params, refs)
+
+    def _sql_query(
+        self,
+        query: str,
+        params: Tuple[Any, ...] = (),
+        refs: Optional[List[Union["FileBackedList", "FileBackedDict"]]] = None,
+    ) -> sqlite3.Cursor:
         # We need to flush this object and any objects the query references to ensure
         # that we don't miss objects that have been modified but not yet flushed.
         self.flush()
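Both public entry points now funnel through `_sql_query`, which returns the live cursor: `sql_query` keeps its old list-returning contract via `fetchall()`, while `sql_query_iterator` hands the cursor back so rows stream out of SQLite as the caller iterates. A sketch of when the iterator form matters (identity serde again, illustrative sizes):

```python
from datahub.utilities.file_backed_collections import FileBackedDict

cache = FileBackedDict[int](serializer=lambda v: v, deserializer=lambda v: v)
for i in range(100_000):
    cache[f"key-{i}"] = i

# sql_query(): buffers every row into a Python list before returning.
all_rows = cache.sql_query(f"SELECT key FROM {cache.tablename}")

# sql_query_iterator(): rows are fetched lazily from the cursor,
# so peak memory stays flat no matter how large the table is.
count = 0
for _row in cache.sql_query_iterator(f"SELECT key FROM {cache.tablename}"):
    count += 1
assert count == len(all_rows)
```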
@@ -289,15 +338,13 @@ class FileBackedDict(MutableMapping[str, _VT], Generic[_VT]):
             for referenced_table in refs:
                 referenced_table.flush()
 
-        cursor = self._conn.execute(query, params)
-        return cursor.fetchall()
+        return self._conn.execute(query, params)
 
     def close(self) -> None:
         if self._conn:
-            # Ensure everything is written out.
-            self.flush()
-
-            if not self.connection:  # Connection created inside this class
+            if self.shared_connection:  # Connection not owned by this object
+                self.flush()  # Ensure everything is written out
+            else:
                 self._conn.close()
 
     # This forces all writes to go directly to the DB so they fail immediately.
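The ownership rule after the refactor: a dict handed a `shared_connection` must not tear it down, so its `close()` only flushes; a dict that created its own connection owns a throwaway temp database, so it closes outright and the old unconditional flush is skipped as pointless. A lifecycle sketch under the same assumptions:

```python
from datahub.utilities.file_backed_collections import (
    ConnectionWrapper,
    FileBackedDict,
)

# Owned connection: close() releases the underlying temp SQLite database.
solo = FileBackedDict[int](serializer=lambda v: v, deserializer=lambda v: v)
solo["k"] = 1
solo.close()

# Shared connection: each dict's close() just flushes; the ConnectionWrapper's
# context manager is responsible for the real teardown.
with ConnectionWrapper() as connection:
    first = FileBackedDict[int](
        shared_connection=connection,
        tablename="first",
        serializer=lambda v: v,
        deserializer=lambda v: v,
    )
    second = FileBackedDict[int](
        shared_connection=connection,
        tablename="second",
        serializer=lambda v: v,
        deserializer=lambda v: v,
    )
    first.close()    # flush only
    second["k"] = 1  # the shared connection is still usable
```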
@@ -330,7 +377,7 @@ class FileBackedList(Generic[_VT]):
     ) -> None:
         self._len = 0
         self._dict = FileBackedDict(
-            connection=connection,
+            shared_connection=connection,
             serializer=serializer,
             deserializer=deserializer,
             tablename=tablename,
@@ -1,6 +1,7 @@
 import dataclasses
 import json
 import pathlib
+import sqlite3
 from dataclasses import dataclass
 from typing import Counter, Dict
 
@@ -25,6 +26,8 @@ def test_file_dict() -> None:
 
     assert len(cache) == 100
     assert sorted(cache) == sorted([f"key-{i}" for i in range(100)])
+    assert sorted(cache.items()) == sorted([(f"key-{i}", i) for i in range(100)])
+    assert sorted(cache.values()) == sorted([i for i in range(100)])
 
     # Force eviction of everything.
     cache.flush()
@@ -140,7 +143,7 @@ def test_custom_serde() -> None:
 
     assert cache["second"] == second
     assert cache["first"] == first
-    assert serializer_calls == 4  # Items written to cache on every access
+    assert serializer_calls == 2
     assert deserializer_calls == 2
 
 
@@ -161,6 +164,7 @@ def test_file_dict_stores_counter() -> None:
             cache[str(i)][str(j)] += 100
             in_memory_counters[i][str(j)] += 100
             cache[str(i)][str(j)] += j
+            cache.mark_dirty(str(i))
             in_memory_counters[i][str(j)] += j
 
     for i in range(n):
@@ -174,43 +178,64 @@ class Pair:
     y: str
 
 
-def test_custom_column() -> None:
+@pytest.mark.parametrize("cache_max_size", [0, 1, 10])
+def test_custom_column(cache_max_size: int) -> None:
     cache = FileBackedDict[Pair](
         extra_columns={
             "x": lambda m: m.x,
         },
+        cache_max_size=cache_max_size,
     )
 
     cache["first"] = Pair(3, "a")
     cache["second"] = Pair(100, "b")
+    cache["third"] = Pair(27, "c")
+    cache["fourth"] = Pair(100, "d")
 
     # Verify that the extra column is present and has the correct values.
-    assert cache.sql_query(f"SELECT sum(x) FROM {cache.tablename}")[0][0] == 103
+    assert cache.sql_query(f"SELECT sum(x) FROM {cache.tablename}")[0][0] == 230
 
     # Verify that the extra column is updated when the value is updated.
-    cache["first"] = Pair(4, "c")
-    assert cache.sql_query(f"SELECT sum(x) FROM {cache.tablename}")[0][0] == 104
+    cache["first"] = Pair(4, "e")
+    assert cache.sql_query(f"SELECT sum(x) FROM {cache.tablename}")[0][0] == 231
 
     # Test param binding.
     assert (
         cache.sql_query(
             f"SELECT sum(x) FROM {cache.tablename} WHERE x < ?", params=(50,)
         )[0][0]
-        == 4
+        == 31
     )
 
+    assert sorted(list(cache.items())) == [
+        ("first", Pair(4, "e")),
+        ("fourth", Pair(100, "d")),
+        ("second", Pair(100, "b")),
+        ("third", Pair(27, "c")),
+    ]
+    assert sorted(list(cache.items_snapshot())) == [
+        ("first", Pair(4, "e")),
+        ("fourth", Pair(100, "d")),
+        ("second", Pair(100, "b")),
+        ("third", Pair(27, "c")),
+    ]
+    assert sorted(list(cache.items_snapshot("x < 50"))) == [
+        ("first", Pair(4, "e")),
+        ("third", Pair(27, "c")),
+    ]
+
 
 def test_shared_connection() -> None:
     with ConnectionWrapper() as connection:
         cache1 = FileBackedDict[int](
-            connection=connection,
+            shared_connection=connection,
             tablename="cache1",
             extra_columns={
                 "v": lambda v: v,
             },
         )
         cache2 = FileBackedDict[Pair](
-            connection=connection,
+            shared_connection=connection,
             tablename="cache2",
             extra_columns={
                 "x": lambda m: m.x,
@@ -227,10 +252,12 @@ def test_shared_connection() -> None:
         assert len(cache1) == 2
         assert len(cache2) == 3
 
-        # Test advanced SQL queries.
-        assert cache2.sql_query(
-            f"SELECT y, sum(x) FROM {cache2.tablename} GROUP BY y ORDER BY y"
-        ) == [("a", 15), ("b", 11)]
+        # Test advanced SQL queries and sql_query_iterator.
+        iterator = cache2.sql_query_iterator(
+            f"SELECT y, sum(x) FROM {cache2.tablename} GROUP BY y ORDER BY y"
+        )
+        assert type(iterator) == sqlite3.Cursor
+        assert list(iterator) == [("a", 15), ("b", 11)]
 
         # Test joining between the two tables.
         assert (
@@ -245,6 +272,12 @@ def test_shared_connection() -> None:
             )
             == [("a", 45), ("b", 55)]
         )
+
+        assert list(cache2.items_snapshot('y = "a"')) == [
+            ("ref-a-1", Pair(7, "a")),
+            ("ref-a-2", Pair(8, "a")),
+        ]
+
         cache2.close()
 
         # Check can still use cache1