Fix up BatchingQueue (#10078)

Fixes #10068
Erik Johnston 2021-05-27 14:32:31 +01:00, committed by GitHub
parent d9f44fd0b9
commit 78b5102ae7
3 changed files with 125 additions and 24 deletions

changelog.d/10078.misc (new file)

@@ -0,0 +1 @@
+Fix up `BatchingQueue` implementation.

synapse/util/batching_queue.py

@@ -25,10 +25,11 @@ from typing import (
     TypeVar,
 )
 
+from prometheus_client import Gauge
+
 from twisted.internet import defer
 
 from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
-from synapse.metrics import LaterGauge
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.util import Clock
@@ -38,6 +39,24 @@ logger = logging.getLogger(__name__)
 V = TypeVar("V")
 R = TypeVar("R")
 
+number_queued = Gauge(
+    "synapse_util_batching_queue_number_queued",
+    "The number of items waiting in the queue across all keys",
+    labelnames=("name",),
+)
+
+number_in_flight = Gauge(
+    "synapse_util_batching_queue_number_pending",
+    "The number of items across all keys either being processed or waiting in a queue",
+    labelnames=("name",),
+)
+
+number_of_keys = Gauge(
+    "synapse_util_batching_queue_number_of_keys",
+    "The number of distinct keys that have items queued",
+    labelnames=("name",),
+)
+
 
 class BatchingQueue(Generic[V, R]):
     """A queue that batches up work, calling the provided processing function
@@ -48,10 +67,20 @@ class BatchingQueue(Generic[V, R]):
     called, and will keep being called until the queue has been drained (for the
     given key).
 
+    If the processing function raises an exception then the exception is proxied
+    through to the callers waiting on that batch of work.
+
     Note that the return value of `add_to_queue` will be the return value of the
     processing function that processed the given item. This means that the
     returned value will likely include data for other items that were in the
     batch.
+
+    Args:
+        name: A name for the queue, used for logging contexts and metrics.
+            This must be unique, otherwise the metrics will be wrong.
+        clock: The clock to use to schedule work.
+        process_batch_callback: The callback to be run to process a batch of
+            work.
     """
 
     def __init__(
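(Annotation: taken together, the expanded docstring implies usage along these lines. A minimal sketch, not from the commit; `clock` is assumed to be a `synapse.util.Clock` obtained from the homeserver:)

    from synapse.util.batching_queue import BatchingQueue

    async def process(values):
        # Called once per batch, with every value queued for a key since the
        # last call; the same return value goes back to every waiting caller.
        return {v: v.upper() for v in values}

    queue = BatchingQueue("example_queue", clock, process)  # `clock` assumed

    async def caller():
        # Concurrent calls with the same key are coalesced into one `process`
        # call, so the result covers the whole batch, not just "foo".
        return await queue.add_to_queue("foo")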
@@ -73,19 +102,15 @@ class BatchingQueue(Generic[V, R]):
         # The function to call with batches of values.
         self._process_batch_callback = process_batch_callback
 
-        LaterGauge(
-            "synapse_util_batching_queue_number_queued",
-            "The number of items waiting in the queue across all keys",
-            labels=("name",),
-            caller=lambda: sum(len(v) for v in self._next_values.values()),
-        )
-
-        LaterGauge(
-            "synapse_util_batching_queue_number_of_keys",
-            "The number of distinct keys that have items queued",
-            labels=("name",),
-            caller=lambda: len(self._next_values),
-        )
+        number_queued.labels(self._name).set_function(
+            lambda: sum(len(q) for q in self._next_values.values())
+        )
+
+        number_of_keys.labels(self._name).set_function(lambda: len(self._next_values))
+
+        self._number_in_flight_metric = number_in_flight.labels(
+            self._name
+        )  # type: Gauge
 
     async def add_to_queue(self, value: V, key: Hashable = ()) -> R:
         """Adds the value to the queue with the given key, returning the result
@@ -107,6 +132,7 @@ class BatchingQueue(Generic[V, R]):
         if key not in self._processing_keys:
             run_as_background_process(self._name, self._process_queue, key)
 
-        return await make_deferred_yieldable(d)
+        with self._number_in_flight_metric.track_inprogress():
+            return await make_deferred_yieldable(d)
 
     async def _process_queue(self, key: Hashable) -> None:
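(Annotation: `track_inprogress()` is a standard Gauge helper: a context manager that increments the gauge on entry and decrements it on exit, so the in-flight count stays accurate even if the awaited deferred fails. A standalone sketch:)

    from prometheus_client import Gauge

    demo_in_flight = Gauge("demo_in_flight", "Demo in-flight items", labelnames=("name",))

    def current():
        # Demo helper: read the single sample's value via the collect() API.
        return demo_in_flight.collect()[0].samples[0].value

    with demo_in_flight.labels("demo").track_inprogress():
        assert current() == 1  # incremented on entry
    assert current() == 0  # decremented on exit, even if the body raises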
@@ -114,10 +140,10 @@
         given key and call the `self._process_batch_callback` with the values.
         """
 
-        try:
-            if key in self._processing_keys:
-                return
+        if key in self._processing_keys:
+            return
 
+        try:
             self._processing_keys.add(key)
 
             while True:
@@ -137,16 +163,16 @@
                     values = [value for value, _ in next_values]
                     results = await self._process_batch_callback(values)
 
-                    for _, deferred in next_values:
-                        with PreserveLoggingContext():
+                    with PreserveLoggingContext():
+                        for _, deferred in next_values:
                             deferred.callback(results)
 
                 except Exception as e:
-                    for _, deferred in next_values:
-                        if deferred.called:
-                            continue
-
-                        with PreserveLoggingContext():
-                            deferred.errback(e)
+                    with PreserveLoggingContext():
+                        for _, deferred in next_values:
+                            if deferred.called:
+                                continue
+
+                            deferred.errback(e)
 
         finally:
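(Annotation: this reshuffle appears to hoist `PreserveLoggingContext` out of the loops so the logcontext switch happens once per batch rather than once per waiter; the resolution pattern itself is plain Twisted. A minimal sketch, not Synapse code, of proxying one failure to several waiters:)

    from twisted.internet import defer

    # Three callers waiting on the same batch.
    waiters = [defer.Deferred() for _ in range(3)]
    seen = []
    for d in waiters:
        d.addErrback(lambda f: seen.append(f.value))

    try:
        raise RuntimeError("batch failed")
    except Exception as e:
        for d in waiters:
            if d.called:
                continue  # skip any deferred that already has a result
            d.errback(e)  # Twisted wraps the exception in a Failure

    assert len(seen) == 3  # every waiter saw the same exception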

tests/util/test_batching_queue.py

@@ -14,7 +14,12 @@
 from twisted.internet import defer
 
 from synapse.logging.context import make_deferred_yieldable
-from synapse.util.batching_queue import BatchingQueue
+from synapse.util.batching_queue import (
+    BatchingQueue,
+    number_in_flight,
+    number_of_keys,
+    number_queued,
+)
 
 from tests.server import get_clock
 from tests.unittest import TestCase
@@ -24,6 +29,14 @@ class BatchingQueueTestCase(TestCase):
     def setUp(self):
         self.clock, hs_clock = get_clock()
 
+        # We ensure that we remove any existing metrics for "test_queue".
+        try:
+            number_queued.remove("test_queue")
+            number_of_keys.remove("test_queue")
+            number_in_flight.remove("test_queue")
+        except KeyError:
+            pass
+
         self._pending_calls = []
         self.queue = BatchingQueue("test_queue", hs_clock, self._process_queue)
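(Annotation: the `remove` calls are needed because prometheus_client metrics live in a process-global registry, so a labelled child created by one test run would otherwise carry its value into the next. Roughly, with an illustrative gauge:)

    from prometheus_client import Gauge

    demo = Gauge("demo_cleanup", "Demo cleanup gauge", labelnames=("name",))
    demo.labels("test_queue").set(1)

    demo.remove("test_queue")  # drops the child for this label value
    try:
        demo.remove("test_queue")  # removing a missing child raises KeyError,
    except KeyError:  # hence the try/except in setUp
        pass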
@@ -32,6 +45,41 @@ class BatchingQueueTestCase(TestCase):
         self._pending_calls.append((values, d))
         return await make_deferred_yieldable(d)
 
+    def _assert_metrics(self, queued, keys, in_flight):
+        """Assert that the metrics are correct"""
+
+        self.assertEqual(len(number_queued.collect()), 1)
+        self.assertEqual(len(number_queued.collect()[0].samples), 1)
+        self.assertEqual(
+            number_queued.collect()[0].samples[0].labels,
+            {"name": self.queue._name},
+        )
+        self.assertEqual(
+            number_queued.collect()[0].samples[0].value,
+            queued,
+            "number_queued",
+        )
+
+        self.assertEqual(len(number_of_keys.collect()), 1)
+        self.assertEqual(len(number_of_keys.collect()[0].samples), 1)
+        self.assertEqual(
+            number_of_keys.collect()[0].samples[0].labels, {"name": self.queue._name}
+        )
+        self.assertEqual(
+            number_of_keys.collect()[0].samples[0].value, keys, "number_of_keys"
+        )
+
+        self.assertEqual(len(number_in_flight.collect()), 1)
+        self.assertEqual(len(number_in_flight.collect()[0].samples), 1)
+        self.assertEqual(
+            number_in_flight.collect()[0].samples[0].labels, {"name": self.queue._name}
+        )
+        self.assertEqual(
+            number_in_flight.collect()[0].samples[0].value,
+            in_flight,
+            "number_in_flight",
+        )
+
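(Annotation: `_assert_metrics` leans on the prometheus_client collection API: `collect()` returns one metric family per gauge, and `family.samples` holds one `(name, labels, value, ...)` tuple per label combination. For example, with a hypothetical gauge:)

    from prometheus_client import Gauge

    demo = Gauge("demo_collected", "Demo collected gauge", labelnames=("name",))
    demo.labels("test_queue").set(2)

    [family] = demo.collect()  # exactly one family for this gauge
    [sample] = family.samples  # one sample per label combination
    assert sample.labels == {"name": "test_queue"}
    assert sample.value == 2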
     def test_simple(self):
         """Tests the basic case of calling `add_to_queue` once and having
         `_process_queue` return.
@@ -41,6 +89,8 @@ class BatchingQueueTestCase(TestCase):
         queue_d = defer.ensureDeferred(self.queue.add_to_queue("foo"))
 
+        self._assert_metrics(queued=1, keys=1, in_flight=1)
+
         # The queue should wait a reactor tick before calling the processing
         # function.
         self.assertFalse(self._pending_calls)
@@ -52,12 +102,15 @@ class BatchingQueueTestCase(TestCase):
         self.assertEqual(len(self._pending_calls), 1)
         self.assertEqual(self._pending_calls[0][0], ["foo"])
         self.assertFalse(queue_d.called)
+        self._assert_metrics(queued=0, keys=0, in_flight=1)
 
         # Return value of the `_process_queue` should be propagated back.
         self._pending_calls.pop()[1].callback("bar")
 
         self.assertEqual(self.successResultOf(queue_d), "bar")
 
+        self._assert_metrics(queued=0, keys=0, in_flight=0)
+
     def test_batching(self):
         """Test that multiple calls at the same time get batched up into one
         call to `_process_queue`.
@@ -68,6 +121,8 @@ class BatchingQueueTestCase(TestCase):
         queue_d1 = defer.ensureDeferred(self.queue.add_to_queue("foo1"))
         queue_d2 = defer.ensureDeferred(self.queue.add_to_queue("foo2"))
 
+        self._assert_metrics(queued=2, keys=1, in_flight=2)
+
         self.clock.pump([0])
 
         # We should see only *one* call to `_process_queue`
@@ -75,12 +130,14 @@ class BatchingQueueTestCase(TestCase):
         self.assertEqual(self._pending_calls[0][0], ["foo1", "foo2"])
         self.assertFalse(queue_d1.called)
         self.assertFalse(queue_d2.called)
+        self._assert_metrics(queued=0, keys=0, in_flight=2)
 
         # Return value of the `_process_queue` should be propagated back to both.
         self._pending_calls.pop()[1].callback("bar")
 
         self.assertEqual(self.successResultOf(queue_d1), "bar")
         self.assertEqual(self.successResultOf(queue_d2), "bar")
 
+        self._assert_metrics(queued=0, keys=0, in_flight=0)
+
     def test_queuing(self):
         """Test that we queue up requests while a `_process_queue` is being
@@ -92,13 +149,20 @@ class BatchingQueueTestCase(TestCase):
         queue_d1 = defer.ensureDeferred(self.queue.add_to_queue("foo1"))
         self.clock.pump([0])
 
+        self.assertEqual(len(self._pending_calls), 1)
+
+        # We queue up work after the process function has been called, testing
+        # that they get correctly queued up.
         queue_d2 = defer.ensureDeferred(self.queue.add_to_queue("foo2"))
+        queue_d3 = defer.ensureDeferred(self.queue.add_to_queue("foo3"))
 
         # We should see only *one* call to `_process_queue`
         self.assertEqual(len(self._pending_calls), 1)
         self.assertEqual(self._pending_calls[0][0], ["foo1"])
         self.assertFalse(queue_d1.called)
         self.assertFalse(queue_d2.called)
+        self.assertFalse(queue_d3.called)
+        self._assert_metrics(queued=2, keys=1, in_flight=3)
 
         # Return value of the `_process_queue` should be propagated back to the
         # first.
@@ -106,18 +170,24 @@ class BatchingQueueTestCase(TestCase):
         self.assertEqual(self.successResultOf(queue_d1), "bar1")
 
         self.assertFalse(queue_d2.called)
+        self.assertFalse(queue_d3.called)
+        self._assert_metrics(queued=2, keys=1, in_flight=2)
 
         # We should now see a second call to `_process_queue`
         self.clock.pump([0])
         self.assertEqual(len(self._pending_calls), 1)
-        self.assertEqual(self._pending_calls[0][0], ["foo2"])
+        self.assertEqual(self._pending_calls[0][0], ["foo2", "foo3"])
         self.assertFalse(queue_d2.called)
+        self.assertFalse(queue_d3.called)
+        self._assert_metrics(queued=0, keys=0, in_flight=2)
 
         # Return value of the `_process_queue` should be propagated back to the
         # second.
         self._pending_calls.pop()[1].callback("bar2")
 
         self.assertEqual(self.successResultOf(queue_d2), "bar2")
+        self.assertEqual(self.successResultOf(queue_d3), "bar2")
+        self._assert_metrics(queued=0, keys=0, in_flight=0)
 
     def test_different_keys(self):
         """Test that calls to different keys get processed in parallel."""
@@ -140,6 +210,7 @@ class BatchingQueueTestCase(TestCase):
         self.assertFalse(queue_d1.called)
         self.assertFalse(queue_d2.called)
         self.assertFalse(queue_d3.called)
+        self._assert_metrics(queued=1, keys=1, in_flight=3)
 
         # Return value of the `_process_queue` should be propagated back to the
         # first.
@@ -148,6 +219,7 @@ class BatchingQueueTestCase(TestCase):
         self.assertEqual(self.successResultOf(queue_d1), "bar1")
         self.assertFalse(queue_d2.called)
         self.assertFalse(queue_d3.called)
+        self._assert_metrics(queued=1, keys=1, in_flight=2)
 
         # Return value of the `_process_queue` should be propagated back to the
         # second.
@@ -161,9 +233,11 @@ class BatchingQueueTestCase(TestCase):
         self.assertEqual(len(self._pending_calls), 1)
         self.assertEqual(self._pending_calls[0][0], ["foo3"])
         self.assertFalse(queue_d3.called)
+        self._assert_metrics(queued=0, keys=0, in_flight=1)
 
         # Return value of the `_process_queue` should be propagated back to the
         # third deferred.
         self._pending_calls.pop()[1].callback("bar4")
 
         self.assertEqual(self.successResultOf(queue_d3), "bar4")
+        self._assert_metrics(queued=0, keys=0, in_flight=0)