mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Make MultiWriterIDGenerator work for streams that use negative stream IDs (#8203)
This is so that we can use it for the backfill events stream.
This commit is contained in:
parent
318245eaa6
commit
bbb3c8641c
1
changelog.d/8203.misc
Normal file
1
changelog.d/8203.misc
Normal file
@ -0,0 +1 @@
|
||||
Make `MultiWriterIDGenerator` work for streams that use negative values.
|
@ -185,6 +185,8 @@ class MultiWriterIdGenerator:
|
||||
id_column: Column that stores the stream ID.
|
||||
sequence_name: The name of the postgres sequence used to generate new
|
||||
IDs.
|
||||
positive: Whether the IDs are positive (true) or negative (false).
|
||||
When using negative IDs we go backwards from -1 to -2, -3, etc.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -196,13 +198,19 @@ class MultiWriterIdGenerator:
|
||||
instance_column: str,
|
||||
id_column: str,
|
||||
sequence_name: str,
|
||||
positive: bool = True,
|
||||
):
|
||||
self._db = db
|
||||
self._instance_name = instance_name
|
||||
self._positive = positive
|
||||
self._return_factor = 1 if positive else -1
|
||||
|
||||
# We lock as some functions may be called from DB threads.
|
||||
self._lock = threading.Lock()
|
||||
|
||||
# Note: If we are a negative stream then we still store all the IDs as
|
||||
# positive to make life easier for us, and simply negate the IDs when we
|
||||
# return them.
|
||||
self._current_positions = self._load_current_ids(
|
||||
db_conn, table, instance_column, id_column
|
||||
)
|
||||
@ -233,13 +241,16 @@ class MultiWriterIdGenerator:
|
||||
def _load_current_ids(
|
||||
self, db_conn, table: str, instance_column: str, id_column: str
|
||||
) -> Dict[str, int]:
|
||||
# If positive stream aggregate via MAX. For negative stream use MIN
|
||||
# *and* negate the result to get a positive number.
|
||||
sql = """
|
||||
SELECT %(instance)s, MAX(%(id)s) FROM %(table)s
|
||||
SELECT %(instance)s, %(agg)s(%(id)s) FROM %(table)s
|
||||
GROUP BY %(instance)s
|
||||
""" % {
|
||||
"instance": instance_column,
|
||||
"id": id_column,
|
||||
"table": table,
|
||||
"agg": "MAX" if self._positive else "-MIN",
|
||||
}
|
||||
|
||||
cur = db_conn.cursor()
|
||||
@ -269,15 +280,16 @@ class MultiWriterIdGenerator:
|
||||
# Assert the fetched ID is actually greater than what we currently
|
||||
# believe the ID to be. If not, then the sequence and table have got
|
||||
# out of sync somehow.
|
||||
assert self.get_current_token_for_writer(self._instance_name) < next_id
|
||||
|
||||
with self._lock:
|
||||
assert self._current_positions.get(self._instance_name, 0) < next_id
|
||||
|
||||
self._unfinished_ids.add(next_id)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def manager():
|
||||
try:
|
||||
yield next_id
|
||||
# Multiply by the return factor so that the ID has correct sign.
|
||||
yield self._return_factor * next_id
|
||||
finally:
|
||||
self._mark_id_as_finished(next_id)
|
||||
|
||||
@ -296,15 +308,15 @@ class MultiWriterIdGenerator:
|
||||
# Assert the fetched ID is actually greater than any ID we've already
|
||||
# seen. If not, then the sequence and table have got out of sync
|
||||
# somehow.
|
||||
assert max(self.get_positions().values(), default=0) < min(next_ids)
|
||||
|
||||
with self._lock:
|
||||
assert max(self._current_positions.values(), default=0) < min(next_ids)
|
||||
|
||||
self._unfinished_ids.update(next_ids)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def manager():
|
||||
try:
|
||||
yield next_ids
|
||||
yield [self._return_factor * i for i in next_ids]
|
||||
finally:
|
||||
for i in next_ids:
|
||||
self._mark_id_as_finished(i)
|
||||
@ -327,7 +339,7 @@ class MultiWriterIdGenerator:
|
||||
txn.call_after(self._mark_id_as_finished, next_id)
|
||||
txn.call_on_exception(self._mark_id_as_finished, next_id)
|
||||
|
||||
return next_id
|
||||
return self._return_factor * next_id
|
||||
|
||||
def _mark_id_as_finished(self, next_id: int):
|
||||
"""The ID has finished being processed so we should advance the
|
||||
@ -359,20 +371,25 @@ class MultiWriterIdGenerator:
|
||||
"""
|
||||
|
||||
with self._lock:
|
||||
return self._current_positions.get(instance_name, 0)
|
||||
return self._return_factor * self._current_positions.get(instance_name, 0)
|
||||
|
||||
def get_positions(self) -> Dict[str, int]:
|
||||
"""Get a copy of the current positon map.
|
||||
"""
|
||||
|
||||
with self._lock:
|
||||
return dict(self._current_positions)
|
||||
return {
|
||||
name: self._return_factor * i
|
||||
for name, i in self._current_positions.items()
|
||||
}
|
||||
|
||||
def advance(self, instance_name: str, new_id: int):
|
||||
"""Advance the postion of the named writer to the given ID, if greater
|
||||
than existing entry.
|
||||
"""
|
||||
|
||||
new_id *= self._return_factor
|
||||
|
||||
with self._lock:
|
||||
self._current_positions[instance_name] = max(
|
||||
new_id, self._current_positions.get(instance_name, 0)
|
||||
@ -390,7 +407,7 @@ class MultiWriterIdGenerator:
|
||||
"""
|
||||
|
||||
with self._lock:
|
||||
return self._persisted_upto_position
|
||||
return self._return_factor * self._persisted_upto_position
|
||||
|
||||
def _add_persisted_position(self, new_id: int):
|
||||
"""Record that we have persisted a position.
|
||||
|
@ -264,3 +264,108 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||
# We assume that so long as `get_next` does correctly advance the
|
||||
# `persisted_upto_position` in this case, then it will be correct in the
|
||||
# other cases that are tested above (since they'll hit the same code).
|
||||
|
||||
|
||||
class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||
"""Tests MultiWriterIdGenerator that produce *negative* stream IDs.
|
||||
"""
|
||||
|
||||
if not USE_POSTGRES_FOR_TESTS:
|
||||
skip = "Requires Postgres"
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
self.store = hs.get_datastore()
|
||||
self.db_pool = self.store.db_pool # type: DatabasePool
|
||||
|
||||
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
|
||||
|
||||
def _setup_db(self, txn):
|
||||
txn.execute("CREATE SEQUENCE foobar_seq")
|
||||
txn.execute(
|
||||
"""
|
||||
CREATE TABLE foobar (
|
||||
stream_id BIGINT NOT NULL,
|
||||
instance_name TEXT NOT NULL,
|
||||
data TEXT
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator:
|
||||
def _create(conn):
|
||||
return MultiWriterIdGenerator(
|
||||
conn,
|
||||
self.db_pool,
|
||||
instance_name=instance_name,
|
||||
table="foobar",
|
||||
instance_column="instance_name",
|
||||
id_column="stream_id",
|
||||
sequence_name="foobar_seq",
|
||||
positive=False,
|
||||
)
|
||||
|
||||
return self.get_success(self.db_pool.runWithConnection(_create))
|
||||
|
||||
def _insert_row(self, instance_name: str, stream_id: int):
|
||||
"""Insert one row as the given instance with given stream_id.
|
||||
"""
|
||||
|
||||
def _insert(txn):
|
||||
txn.execute(
|
||||
"INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,),
|
||||
)
|
||||
|
||||
self.get_success(self.db_pool.runInteraction("_insert_row", _insert))
|
||||
|
||||
def test_single_instance(self):
|
||||
"""Test that reads and writes from a single process are handled
|
||||
correctly.
|
||||
"""
|
||||
id_gen = self._create_id_generator()
|
||||
|
||||
with self.get_success(id_gen.get_next()) as stream_id:
|
||||
self._insert_row("master", stream_id)
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"master": -1})
|
||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), -1)
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), -1)
|
||||
|
||||
with self.get_success(id_gen.get_next_mult(3)) as stream_ids:
|
||||
for stream_id in stream_ids:
|
||||
self._insert_row("master", stream_id)
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"master": -4})
|
||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), -4)
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), -4)
|
||||
|
||||
# Test loading from DB by creating a second ID gen
|
||||
second_id_gen = self._create_id_generator()
|
||||
|
||||
self.assertEqual(second_id_gen.get_positions(), {"master": -4})
|
||||
self.assertEqual(second_id_gen.get_current_token_for_writer("master"), -4)
|
||||
self.assertEqual(second_id_gen.get_persisted_upto_position(), -4)
|
||||
|
||||
def test_multiple_instance(self):
|
||||
"""Tests that having multiple instances that get advanced over
|
||||
federation works corretly.
|
||||
"""
|
||||
id_gen_1 = self._create_id_generator("first")
|
||||
id_gen_2 = self._create_id_generator("second")
|
||||
|
||||
with self.get_success(id_gen_1.get_next()) as stream_id:
|
||||
self._insert_row("first", stream_id)
|
||||
id_gen_2.advance("first", stream_id)
|
||||
|
||||
self.assertEqual(id_gen_1.get_positions(), {"first": -1})
|
||||
self.assertEqual(id_gen_2.get_positions(), {"first": -1})
|
||||
self.assertEqual(id_gen_1.get_persisted_upto_position(), -1)
|
||||
self.assertEqual(id_gen_2.get_persisted_upto_position(), -1)
|
||||
|
||||
with self.get_success(id_gen_2.get_next()) as stream_id:
|
||||
self._insert_row("second", stream_id)
|
||||
id_gen_1.advance("second", stream_id)
|
||||
|
||||
self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -2})
|
||||
self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2})
|
||||
self.assertEqual(id_gen_1.get_persisted_upto_position(), -2)
|
||||
self.assertEqual(id_gen_2.get_persisted_upto_position(), -2)
|
||||
|
Loading…
Reference in New Issue
Block a user