mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
184 lines
6.3 KiB
Python
184 lines
6.3 KiB
Python
#
|
|
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
|
#
|
|
# Copyright (C) 2023 New Vector, Ltd
|
|
#
|
|
# This program is free software: you can redistribute it and/or modify
|
|
# it under the terms of the GNU Affero General Public License as
|
|
# published by the Free Software Foundation, either version 3 of the
|
|
# License, or (at your option) any later version.
|
|
#
|
|
# See the GNU Affero General Public License for more details:
|
|
# <https://www.gnu.org/licenses/agpl-3.0.html>.
|
|
#
|
|
# Originally licensed under the Apache License, Version 2.0:
|
|
# <http://www.apache.org/licenses/LICENSE-2.0>.
|
|
#
|
|
# [This file includes modifications made by New Vector Limited]
|
|
#
|
|
#
|
|
|
|
import logging
|
|
from typing import (
|
|
Awaitable,
|
|
Callable,
|
|
Dict,
|
|
Generic,
|
|
Hashable,
|
|
List,
|
|
Set,
|
|
Tuple,
|
|
TypeVar,
|
|
)
|
|
|
|
from prometheus_client import Gauge
|
|
|
|
from twisted.internet import defer
|
|
|
|
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
|
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
|
from synapse.util import Clock
|
|
|
|
logger = logging.getLogger(__name__)


# Type of the individual values submitted to a queue.
V = TypeVar("V")
# Type of the result produced by processing one batch of values.
R = TypeVar("R")

# All three gauges carry a single "name" label, which is filled in with the
# name of the queue reporting the measurement.

number_queued = Gauge(
    name="synapse_util_batching_queue_number_queued",
    documentation="The number of items waiting in the queue across all keys",
    labelnames=("name",),
)

number_in_flight = Gauge(
    name="synapse_util_batching_queue_number_pending",
    documentation=(
        "The number of items across all keys either being processed or waiting in a queue"
    ),
    labelnames=("name",),
)

number_of_keys = Gauge(
    name="synapse_util_batching_queue_number_of_keys",
    documentation="The number of distinct keys that have items queued",
    labelnames=("name",),
)
|
|
|
|
|
|
class BatchingQueue(Generic[V, R]):
    """Collects submitted values per key and hands them to a processing
    callback in batches.

    For any one key, at most one invocation of the processing callback is in
    progress at a time. After `add_to_queue` has been called the callback runs
    on the following reactor tick, and it keeps being re-run until nothing is
    left queued for that key.

    Every caller whose value was part of a batch receives the same outcome:
    the callback's return value on success, or the exception it raised on
    failure. Because the return value covers the whole batch, it will usually
    contain data relating to other callers' items as well.

    Args:
        name: Identifies the queue in logging contexts and metrics. It must be
            unique, otherwise the metrics will be wrong.
        clock: The clock used to schedule work.
        process_batch_callback: Called with the list of values making up a
            batch; its result is handed back to the waiting callers.
    """

    def __init__(
        self,
        name: str,
        clock: Clock,
        process_batch_callback: Callable[[List[V]], Awaitable[R]],
    ):
        self._name = name
        self._clock = clock

        # Called with each batch of values.
        self._process_batch_callback = process_batch_callback

        # Keys for which a `_process_queue` loop is currently running.
        self._processing_keys: Set[Hashable] = set()

        # Per key, the values waiting for the next batch. Each value is paired
        # with the Deferred that gets resolved with the outcome of the
        # `_process_batch_callback` call that handles it.
        self._next_values: Dict[Hashable, List[Tuple[V, defer.Deferred]]] = {}

        # These two gauges are computed on demand from `_next_values`.
        number_queued.labels(name).set_function(
            lambda: sum(map(len, self._next_values.values()))
        )
        number_of_keys.labels(name).set_function(lambda: len(self._next_values))

        self._number_in_flight_metric: Gauge = number_in_flight.labels(name)

    async def add_to_queue(self, value: V, key: Hashable = ()) -> R:
        """Queue `value` under `key` and wait for the batch containing it to
        be processed.

        Distinct keys form independent queues: values are only ever batched
        with other values submitted under the same key, and different keys may
        be processed in parallel.

        Returns:
            Whatever the processing callback returned for the batch that
            included `value`. If the callback raised, that exception is raised
            here instead.
        """
        # Pair the value with a fresh Deferred and park both on the pending
        # list for this key.
        waiter: defer.Deferred[R] = defer.Deferred()
        pending = self._next_values.setdefault(key, [])
        pending.append((value, waiter))

        # Kick off a background process for this key, unless one is already
        # draining it.
        if key not in self._processing_keys:
            run_as_background_process(self._name, self._process_queue, key)

        # The in-flight gauge is tracked for as long as we are waiting.
        with self._number_in_flight_metric.track_inprogress():
            return await make_deferred_yieldable(waiter)

    async def _process_queue(self, key: Hashable) -> None:
        """Background task that drains the queue for `key`, passing each batch
        of pending values to `self._process_batch_callback` until none remain.
        """
        if key in self._processing_keys:
            # This key is already being drained.
            return

        try:
            self._processing_keys.add(key)

            while True:
                # Deliberately yield for one reactor tick before collecting
                # the batch. Callers commonly invoke `add_to_queue` several
                # times in a row, and waiting lets all of those calls end up
                # in the same batch.
                await self._clock.sleep(0)

                batch = self._next_values.pop(key, [])
                if not batch:
                    # Nothing left queued for this key.
                    break

                waiters = [waiter for _, waiter in batch]

                try:
                    values = [item for item, _ in batch]
                    outcome = await self._process_batch_callback(values)

                    with PreserveLoggingContext():
                        for waiter in waiters:
                            waiter.callback(outcome)
                except Exception as exc:
                    # Pass the failure on to every waiter that has not been
                    # resolved yet.
                    with PreserveLoggingContext():
                        for waiter in waiters:
                            if not waiter.called:
                                waiter.errback(exc)
        finally:
            # Always release the key so a later `add_to_queue` can start a new
            # processing loop for it.
            self._processing_keys.discard(key)
|