Source code for zict.async_buffer

from __future__ import annotations

import asyncio
import contextvars
from collections.abc import Callable, Collection
from concurrent.futures import Executor, ThreadPoolExecutor
from functools import wraps
from itertools import chain
from typing import Any, Literal

from zict.buffer import Buffer
from zict.common import KT, VT, T, locked


[docs] class AsyncBuffer(Buffer[KT, VT]): """Extension of :class:`~zict.Buffer` that allows offloading all reads and writes from/to slow to a separate worker thread. This requires ``fast`` to be fully thread-safe (e.g. a plain dict). ``slow.__setitem__`` and ``slow.__getitem__`` will be called from the offloaded thread, while all of its other methods (including, notably for the purpose of thread-safety consideration, ``__contains__`` and ``__delitem__``) will be called from the main thread. See Also -------- Buffer Parameters ---------- Same as in Buffer, plus: executor: concurrent.futures.Executor, optional An Executor instance to use for offloading. It must not pickle/unpickle. Defaults to an internal ThreadPoolExecutor. nthreads: int, optional Number of offloaded threads to run in parallel. Defaults to 1. Mutually exclusive with executor parameter. """ executor: Executor | None nthreads: int | None futures: set[asyncio.Future] evicting: dict[asyncio.Future, float] @wraps(Buffer.__init__) def __init__( self, *args: Any, executor: Executor | None = None, nthreads: int = 1, **kwargs: Any, ) -> None: super().__init__(*args, **kwargs) self.executor = executor self.nthreads = None if executor else nthreads self._internal_executor = executor is None self.futures = set() self.evicting = {}
[docs] def close(self) -> None: # Call LRU.close(), which stops LRU.evict_until_below_target() halfway through super().close() for future in self.futures: future.cancel() if self.executor is not None and self.nthreads is not None: self.executor.shutdown(wait=True) self.executor = None
def _offload(self, func: Callable[..., T], *args: Any) -> asyncio.Future[T]: if self.executor is None: assert self.nthreads self.executor = ThreadPoolExecutor( self.nthreads, thread_name_prefix="zict.AsyncBuffer offloader" ) loop = asyncio.get_running_loop() context = contextvars.copy_context() future = loop.run_in_executor(self.executor, context.run, func, *args) self.futures.add(future) future.add_done_callback(self.futures.remove) return future # type: ignore[return-value] # Return an asyncio.Future, instead of just writing it as an async function, to make # it easier for overriders to tell apart the use case when all keys were already # in fast
[docs] @locked def async_get( self, keys: Collection[KT], missing: Literal["raise", "omit"] = "raise" ) -> asyncio.Future[dict[KT, VT]]: """Fetch one or more key/value pairs. If not all keys are available in fast, offload to a worker thread moving keys from slow to fast, as well as possibly moving older keys from fast to slow. Parameters ---------- keys: collection of zero or more keys to get missing: raise or omit, optional raise (default) If any key is missing, raise KeyError. omit If a key is missing, return a dict with less keys than those requested. Notes ----- All keys may be present when you call ``async_get``, but ``__delitem__`` may be called on one of them before the actual data is fetched. ``__setitem__`` also internally calls ``__delitem__`` in a non-atomic way, so you may get ``KeyError`` when updating a value too. """ # This block avoids spawning a thread if keys are missing from both fast and # slow. It is otherwise just a performance optimization. if missing == "omit": keys = [key for key in keys if key in self] elif missing == "raise": for key in keys: if key not in self: raise KeyError(key) else: raise ValueError(f"missing: expected raise or omit; got {missing}") # End performance optimization try: # Do not pull keys towards the top of the LRU unless they are all available. # This matters when there is a very long queue of async_get futures. d = self.fast.get_all_or_nothing(keys) except KeyError: pass else: f: asyncio.Future[dict[KT, VT]] = asyncio.Future() f.set_result(d) return f def _async_get() -> dict[KT, VT]: d = {} for k in keys: if self.fast.closed: raise asyncio.CancelledError() try: # This can cause keys to be restored and older keys to be evicted d[k] = self[k] except KeyError: # Race condition: key was there when async_get was called, but got # deleted afterwards. if missing == "raise": raise return d return self._offload(_async_get)
def __setitem__(self, key: KT, value: VT) -> None: """Immediately set a key in fast. If this causes the total weight to exceed n, asynchronously start moving keys from fast to slow in a worker thread. """ self.set_noevict(key, value) self.async_evict_until_below_target()
[docs] @locked def async_evict_until_below_target(self, n: float | None = None) -> None: """If the total weight exceeds n, asynchronously start moving keys from fast to slow in a worker thread. """ if n is None: n = self.n n = max(0.0, n) weight = min(chain([self.fast.total_weight], self.evicting.values())) if weight <= n: return # Note: this can get cancelled by LRU.close(), which in turn is # triggered by Buffer.close() future = self._offload(self.evict_until_below_target, n) self.evicting[future] = n future.add_done_callback(self.evicting.__delitem__)