Fix missing _add_persisted_position (#8179)

This was forgotten in #8164.
This commit is contained in:
Erik Johnston 2020-08-27 13:20:34 +01:00 committed by GitHub
parent 30426c7063
commit 5649b7f3d0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 52 additions and 3 deletions

1
changelog.d/8179.misc Normal file
View File

@ -0,0 +1 @@
Add functions to `MultiWriterIdGen` used by events stream.

View File

@ -343,6 +343,8 @@ class MultiWriterIdGenerator:
curr = self._current_positions.get(self._instance_name, 0) curr = self._current_positions.get(self._instance_name, 0)
self._current_positions[self._instance_name] = max(curr, next_id) self._current_positions[self._instance_name] = max(curr, next_id)
self._add_persisted_position(next_id)
def get_current_token(self) -> int: def get_current_token(self) -> int:
"""Returns the maximum stream id such that all stream ids less than or """Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted. equal to it have been successfully persisted.

View File

@ -58,6 +58,10 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
return self.get_success(self.db_pool.runWithConnection(_create)) return self.get_success(self.db_pool.runWithConnection(_create))
def _insert_rows(self, instance_name: str, number: int): def _insert_rows(self, instance_name: str, number: int):
"""Insert N rows as the given instance, inserting with stream IDs pulled
from the postgres sequence.
"""
def _insert(txn): def _insert(txn):
for _ in range(number): for _ in range(number):
txn.execute( txn.execute(
@ -65,7 +69,20 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
(instance_name,), (instance_name,),
) )
self.get_success(self.db_pool.runInteraction("test_single_instance", _insert)) self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
def _insert_row_with_id(self, instance_name: str, stream_id: int):
"""Insert one row as the given instance with given stream_id, updating
the postgres sequence position to match.
"""
def _insert(txn):
txn.execute(
"INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,),
)
txn.execute("SELECT setval('foobar_seq', ?)", (stream_id,))
self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert))
def test_empty(self): def test_empty(self):
"""Test an ID generator against an empty database gives sensible """Test an ID generator against an empty database gives sensible
@ -188,11 +205,17 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
positions. positions.
""" """
self._insert_rows("first", 3) # The following tests are a bit cheeky in that we notify about new
self._insert_rows("second", 5) # positions via `advance` without *actually* advancing the postgres
# sequence.
self._insert_row_with_id("first", 3)
self._insert_row_with_id("second", 5)
id_gen = self._create_id_generator("first") id_gen = self._create_id_generator("first")
self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
# Min is 3 and there is a gap between 5, so we expect it to be 3. # Min is 3 and there is a gap between 5, so we expect it to be 3.
self.assertEqual(id_gen.get_persisted_upto_position(), 3) self.assertEqual(id_gen.get_persisted_upto_position(), 3)
@ -218,3 +241,26 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
id_gen.advance("first", 11) id_gen.advance("first", 11)
id_gen.advance("second", 15) id_gen.advance("second", 15)
self.assertEqual(id_gen.get_persisted_upto_position(), 11) self.assertEqual(id_gen.get_persisted_upto_position(), 11)
def test_get_persisted_upto_position_get_next(self):
"""Test that `get_persisted_upto_position` correctly tracks updates to
positions when `get_next` is called.
"""
self._insert_row_with_id("first", 3)
self._insert_row_with_id("second", 5)
id_gen = self._create_id_generator("first")
self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
self.assertEqual(id_gen.get_persisted_upto_position(), 3)
with self.get_success(id_gen.get_next()) as stream_id:
self.assertEqual(stream_id, 6)
self.assertEqual(id_gen.get_persisted_upto_position(), 3)
self.assertEqual(id_gen.get_persisted_upto_position(), 6)
# We assume that so long as `get_next` does correctly advance the
# `persisted_upto_position` in this case, then it will be correct in the
# other cases that are tested above (since they'll hit the same code).