mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2024-12-27 01:29:24 -05:00
Add functions to MultiWriterIdGen
used by events stream (#8164)
This commit is contained in:
parent
5099bd68da
commit
eba98fb024
1
changelog.d/8164.misc
Normal file
1
changelog.d/8164.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Add functions to `MultiWriterIdGen` used by events stream.
|
@ -14,9 +14,10 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
|
import heapq
|
||||||
import threading
|
import threading
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from typing import Dict, Set
|
from typing import Dict, List, Set
|
||||||
|
|
||||||
from typing_extensions import Deque
|
from typing_extensions import Deque
|
||||||
|
|
||||||
@ -210,6 +211,23 @@ class MultiWriterIdGenerator:
|
|||||||
# should be less than the minimum of this set (if not empty).
|
# should be less than the minimum of this set (if not empty).
|
||||||
self._unfinished_ids = set() # type: Set[int]
|
self._unfinished_ids = set() # type: Set[int]
|
||||||
|
|
||||||
|
# We track the max position where we know everything before has been
|
||||||
|
# persisted. This is done by a) looking at the min across all instances
|
||||||
|
# and b) noting that if we have seen a run of persisted positions
|
||||||
|
# without gaps (e.g. 5, 6, 7) then we can skip forward (e.g. to 7).
|
||||||
|
#
|
||||||
|
# Note: There is no guarentee that the IDs generated by the sequence
|
||||||
|
# will be gapless; gaps can form when e.g. a transaction was rolled
|
||||||
|
# back. This means that sometimes we won't be able to skip forward the
|
||||||
|
# position even though everything has been persisted. However, since
|
||||||
|
# gaps should be relatively rare it's still worth doing the book keeping
|
||||||
|
# that allows us to skip forwards when there are gapless runs of
|
||||||
|
# positions.
|
||||||
|
self._persisted_upto_position = (
|
||||||
|
min(self._current_positions.values()) if self._current_positions else 0
|
||||||
|
)
|
||||||
|
self._known_persisted_positions = [] # type: List[int]
|
||||||
|
|
||||||
self._sequence_gen = PostgresSequenceGenerator(sequence_name)
|
self._sequence_gen = PostgresSequenceGenerator(sequence_name)
|
||||||
|
|
||||||
def _load_current_ids(
|
def _load_current_ids(
|
||||||
@ -234,9 +252,12 @@ class MultiWriterIdGenerator:
|
|||||||
|
|
||||||
return current_positions
|
return current_positions
|
||||||
|
|
||||||
def _load_next_id_txn(self, txn):
|
def _load_next_id_txn(self, txn) -> int:
|
||||||
return self._sequence_gen.get_next_id_txn(txn)
|
return self._sequence_gen.get_next_id_txn(txn)
|
||||||
|
|
||||||
|
def _load_next_mult_id_txn(self, txn, n: int) -> List[int]:
|
||||||
|
return self._sequence_gen.get_next_mult_txn(txn, n)
|
||||||
|
|
||||||
async def get_next(self):
|
async def get_next(self):
|
||||||
"""
|
"""
|
||||||
Usage:
|
Usage:
|
||||||
@ -262,6 +283,34 @@ class MultiWriterIdGenerator:
|
|||||||
|
|
||||||
return manager()
|
return manager()
|
||||||
|
|
||||||
|
async def get_next_mult(self, n: int):
|
||||||
|
"""
|
||||||
|
Usage:
|
||||||
|
with await stream_id_gen.get_next_mult(5) as stream_ids:
|
||||||
|
# ... persist events ...
|
||||||
|
"""
|
||||||
|
next_ids = await self._db.runInteraction(
|
||||||
|
"_load_next_mult_id", self._load_next_mult_id_txn, n
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert the fetched ID is actually greater than any ID we've already
|
||||||
|
# seen. If not, then the sequence and table have got out of sync
|
||||||
|
# somehow.
|
||||||
|
assert max(self.get_positions().values(), default=0) < min(next_ids)
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
self._unfinished_ids.update(next_ids)
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def manager():
|
||||||
|
try:
|
||||||
|
yield next_ids
|
||||||
|
finally:
|
||||||
|
for i in next_ids:
|
||||||
|
self._mark_id_as_finished(i)
|
||||||
|
|
||||||
|
return manager()
|
||||||
|
|
||||||
def get_next_txn(self, txn: LoggingTransaction):
|
def get_next_txn(self, txn: LoggingTransaction):
|
||||||
"""
|
"""
|
||||||
Usage:
|
Usage:
|
||||||
@ -326,3 +375,53 @@ class MultiWriterIdGenerator:
|
|||||||
self._current_positions[instance_name] = max(
|
self._current_positions[instance_name] = max(
|
||||||
new_id, self._current_positions.get(instance_name, 0)
|
new_id, self._current_positions.get(instance_name, 0)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._add_persisted_position(new_id)
|
||||||
|
|
||||||
|
def get_persisted_upto_position(self) -> int:
|
||||||
|
"""Get the max position where all previous positions have been
|
||||||
|
persisted.
|
||||||
|
|
||||||
|
Note: In the worst case scenario this will be equal to the minimum
|
||||||
|
position across writers. This means that the returned position here can
|
||||||
|
lag if one writer doesn't write very often.
|
||||||
|
"""
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
return self._persisted_upto_position
|
||||||
|
|
||||||
|
def _add_persisted_position(self, new_id: int):
|
||||||
|
"""Record that we have persisted a position.
|
||||||
|
|
||||||
|
This is used to keep the `_current_positions` up to date.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# We require that the lock is locked by caller
|
||||||
|
assert self._lock.locked()
|
||||||
|
|
||||||
|
heapq.heappush(self._known_persisted_positions, new_id)
|
||||||
|
|
||||||
|
# We move the current min position up if the minimum current positions
|
||||||
|
# of all instances is higher (since by definition all positions less
|
||||||
|
# that that have been persisted).
|
||||||
|
min_curr = min(self._current_positions.values())
|
||||||
|
self._persisted_upto_position = max(min_curr, self._persisted_upto_position)
|
||||||
|
|
||||||
|
# We now iterate through the seen positions, discarding those that are
|
||||||
|
# less than the current min positions, and incrementing the min position
|
||||||
|
# if its exactly one greater.
|
||||||
|
#
|
||||||
|
# This is also where we discard items from `_known_persisted_positions`
|
||||||
|
# (to ensure the list doesn't infinitely grow).
|
||||||
|
while self._known_persisted_positions:
|
||||||
|
if self._known_persisted_positions[0] <= self._persisted_upto_position:
|
||||||
|
heapq.heappop(self._known_persisted_positions)
|
||||||
|
elif (
|
||||||
|
self._known_persisted_positions[0] == self._persisted_upto_position + 1
|
||||||
|
):
|
||||||
|
heapq.heappop(self._known_persisted_positions)
|
||||||
|
self._persisted_upto_position += 1
|
||||||
|
else:
|
||||||
|
# There was a gap in seen positions, so there is nothing more to
|
||||||
|
# do.
|
||||||
|
break
|
||||||
|
@ -14,7 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import abc
|
import abc
|
||||||
import threading
|
import threading
|
||||||
from typing import Callable, Optional
|
from typing import Callable, List, Optional
|
||||||
|
|
||||||
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
|
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
|
||||||
from synapse.storage.types import Cursor
|
from synapse.storage.types import Cursor
|
||||||
@ -39,6 +39,12 @@ class PostgresSequenceGenerator(SequenceGenerator):
|
|||||||
txn.execute("SELECT nextval(?)", (self._sequence_name,))
|
txn.execute("SELECT nextval(?)", (self._sequence_name,))
|
||||||
return txn.fetchone()[0]
|
return txn.fetchone()[0]
|
||||||
|
|
||||||
|
def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
|
||||||
|
txn.execute(
|
||||||
|
"SELECT nextval(?) FROM generate_series(1, ?)", (self._sequence_name, n)
|
||||||
|
)
|
||||||
|
return [i for (i,) in txn]
|
||||||
|
|
||||||
|
|
||||||
GetFirstCallbackType = Callable[[Cursor], int]
|
GetFirstCallbackType = Callable[[Cursor], int]
|
||||||
|
|
||||||
|
@ -182,3 +182,39 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
|||||||
|
|
||||||
self.assertEqual(id_gen.get_positions(), {"master": 8})
|
self.assertEqual(id_gen.get_positions(), {"master": 8})
|
||||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
|
self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
|
||||||
|
|
||||||
|
def test_get_persisted_upto_position(self):
|
||||||
|
"""Test that `get_persisted_upto_position` correctly tracks updates to
|
||||||
|
positions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
self._insert_rows("first", 3)
|
||||||
|
self._insert_rows("second", 5)
|
||||||
|
|
||||||
|
id_gen = self._create_id_generator("first")
|
||||||
|
|
||||||
|
# Min is 3 and there is a gap between 5, so we expect it to be 3.
|
||||||
|
self.assertEqual(id_gen.get_persisted_upto_position(), 3)
|
||||||
|
|
||||||
|
# We advance "first" straight to 6. Min is now 5 but there is no gap so
|
||||||
|
# we expect it to be 6
|
||||||
|
id_gen.advance("first", 6)
|
||||||
|
self.assertEqual(id_gen.get_persisted_upto_position(), 6)
|
||||||
|
|
||||||
|
# No gap, so we expect 7.
|
||||||
|
id_gen.advance("second", 7)
|
||||||
|
self.assertEqual(id_gen.get_persisted_upto_position(), 7)
|
||||||
|
|
||||||
|
# We haven't seen 8 yet, so we expect 7 still.
|
||||||
|
id_gen.advance("second", 9)
|
||||||
|
self.assertEqual(id_gen.get_persisted_upto_position(), 7)
|
||||||
|
|
||||||
|
# Now that we've seen 7, 8 and 9 we can got straight to 9.
|
||||||
|
id_gen.advance("first", 8)
|
||||||
|
self.assertEqual(id_gen.get_persisted_upto_position(), 9)
|
||||||
|
|
||||||
|
# Jump forward with gaps. The minimum is 11, even though we haven't seen
|
||||||
|
# 10 we know that everything before 11 must be persisted.
|
||||||
|
id_gen.advance("first", 11)
|
||||||
|
id_gen.advance("second", 15)
|
||||||
|
self.assertEqual(id_gen.get_persisted_upto_position(), 11)
|
||||||
|
Loading…
Reference in New Issue
Block a user