fix(ingest): use contextvar for cooperative timeout (#10021)

This commit is contained in:
Harshal Sheth 2024-03-11 14:14:39 -07:00 committed by GitHub
parent bcae7acc51
commit 92a3ac6f11
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 57 additions and 11 deletions

View File

@@ -1,9 +1,10 @@
import contextlib
import threading
import contextvars
import time
from typing import Iterator, Optional
_cooperation = threading.local()
# The deadline is an int from time.perf_counter_ns().
_cooperation_deadline = contextvars.ContextVar[int]("cooperation_deadline")
class CooperativeTimeoutError(TimeoutError):
@@ -13,7 +14,7 @@ class CooperativeTimeoutError(TimeoutError):
def cooperate() -> None:
"""Method to be called periodically to cooperate with the timeout mechanism."""
deadline = getattr(_cooperation, "deadline", None)
deadline = _cooperation_deadline.get(None)
if deadline is not None and deadline < time.perf_counter_ns():
raise CooperativeTimeoutError("CooperativeTimeout deadline exceeded")
@@ -50,15 +51,18 @@ def cooperative_timeout(timeout: Optional[float] = None) -> Iterator[None]:
# (unless you're willing to use some hacks https://stackoverflow.com/a/61528202).
# Attempting to forcibly terminate a thread can deadlock on the GIL.
if hasattr(_cooperation, "deadline"):
deadline = _cooperation_deadline.get(None)
if deadline is not None:
raise RuntimeError("cooperative timeout already active")
if timeout is not None:
_cooperation.deadline = time.perf_counter_ns() + timeout * 1_000_000_000
token = _cooperation_deadline.set(
time.perf_counter_ns() + int(timeout * 1_000_000_000)
)
try:
yield
finally:
del _cooperation.deadline
_cooperation_deadline.reset(token)
else:
# No-op.

View File

@@ -1,5 +1,5 @@
import random
from typing import Dict, Iterator, List, Set, TypeVar, Union
from typing import Dict, Generic, Iterator, List, Set, TypeVar, Union
from datahub.configuration.pydantic_migration_helpers import PYDANTIC_VERSION_2
@@ -8,7 +8,7 @@ _KT = TypeVar("_KT")
_VT = TypeVar("_VT")
class LossyList(List[T]):
class LossyList(List[T], Generic[T]):
"""A list that performs reservoir sampling of a much larger list"""
def __init__(self, max_elements: int = 10) -> None:
@@ -60,7 +60,7 @@ class LossyList(List[T]):
return base_list
class LossySet(Set[T]):
class LossySet(Set[T], Generic[T]):
"""A set that only preserves a sample of elements in a set. Currently this is a very simple greedy sampling set"""
def __init__(self, max_elements: int = 10) -> None:
@@ -101,7 +101,7 @@ class LossySet(Set[T]):
return base_list
class LossyDict(Dict[_KT, _VT]):
class LossyDict(Dict[_KT, _VT], Generic[_KT, _VT]):
"""A structure that only preserves a sample of elements in a dictionary using reservoir sampling."""
def __init__(self, max_elements: int = 10) -> None:

View File

@@ -0,0 +1,42 @@
import time
import pytest
from datahub.utilities.cooperative_timeout import (
CooperativeTimeoutError,
cooperate,
cooperative_timeout,
)
def test_cooperate_no_timeout():
    """Outside any cooperative_timeout block, cooperate() is a no-op."""
    cooperate()
def test_cooperate_with_timeout():
    """A zero-second deadline must trip on the very first cooperate() call."""
    with pytest.raises(CooperativeTimeoutError), cooperative_timeout(0):
        cooperate()
def test_cooperative_timeout_no_timeout():
    """With timeout=None the context manager is inert and never raises."""
    with cooperative_timeout(timeout=None):
        iteration = 0
        while iteration < 15:
            time.sleep(0.01)
            cooperate()
            iteration += 1
def test_cooperative_timeout_with_timeout():
    """The deadline fires partway through the loop, but only after at
    least one full iteration has completed."""
    completed_iterations = 0
    with pytest.raises(CooperativeTimeoutError):
        with cooperative_timeout(0.5):
            for _ in range(51):
                time.sleep(0.01)
                cooperate()
                completed_iterations += 1
    assert completed_iterations >= 1

View File

@@ -55,7 +55,7 @@ def test_lossydict_sampling(length, sampling, sub_length):
for i in range(0, length):
list_length = random.choice(range(1, sub_length))
element_length_map[i] = 0
for num_elements in range(0, list_length):
for _num_elements in range(0, list_length):
if not l.get(i):
elements_added += 1
# reset to 0 until we get it back