Separate get_current_token into two. (#8113)

The function is used for two purposes: 1) for subscribers of streams to
get a token they can use to get further updates with, and 2) for
replication to track position of the writers of the stream.

For streams with a single writer the two scenarios produce the same
result, however the situation becomes complicated for streams with
multiple writers. The current `MultiWriterIdGenerator` does not
correctly handle the first case (which is not an issue as its only used
for the `caches` stream which nothing subscribes to outside of
replication).
This commit is contained in:
Erik Johnston 2020-08-19 10:39:31 +01:00 committed by GitHub
parent f40645e60b
commit 76d21d14a0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 47 additions and 20 deletions

1
changelog.d/8113.misc Normal file
View File

@ -0,0 +1 @@
Separate `get_current_token` into two since there are two different use cases for it.

View File

@ -33,3 +33,11 @@ class SlavedIdTracker(object):
int int
""" """
return self._current return self._current
def get_current_token_for_writer(self, instance_name: str) -> int:
"""Returns the position of the given writer.
For streams with single writers this is equivalent to
`get_current_token`.
"""
return self.get_current_token()

View File

@ -405,7 +405,7 @@ class CachesStream(Stream):
store = hs.get_datastore() store = hs.get_datastore()
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
store.get_cache_stream_token, store.get_cache_stream_token_for_writer,
store.get_all_updated_caches, store.get_all_updated_caches,
) )

View File

@ -299,8 +299,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
}, },
) )
def get_cache_stream_token(self, instance_name): def get_cache_stream_token_for_writer(self, instance_name: str) -> int:
if self._cache_id_gen: if self._cache_id_gen:
return self._cache_id_gen.get_current_token(instance_name) return self._cache_id_gen.get_current_token_for_writer(instance_name)
else: else:
return 0 return 0

View File

@ -158,6 +158,14 @@ class StreamIdGenerator(object):
return self._current return self._current
def get_current_token_for_writer(self, instance_name: str) -> int:
"""Returns the position of the given writer.
For streams with single writers this is equivalent to
`get_current_token`.
"""
return self.get_current_token()
class ChainedIdGenerator(object): class ChainedIdGenerator(object):
"""Used to generate new stream ids where the stream must be kept in sync """Used to generate new stream ids where the stream must be kept in sync
@ -216,6 +224,14 @@ class ChainedIdGenerator(object):
"Attempted to advance token on source for table %r", self._table "Attempted to advance token on source for table %r", self._table
) )
def get_current_token_for_writer(self, instance_name: str) -> Tuple[int, int]:
"""Returns the position of the given writer.
For streams with single writers this is equivalent to
`get_current_token`.
"""
return self.get_current_token()
class MultiWriterIdGenerator: class MultiWriterIdGenerator:
"""An ID generator that tracks a stream that can have multiple writers. """An ID generator that tracks a stream that can have multiple writers.
@ -298,7 +314,7 @@ class MultiWriterIdGenerator:
# Assert the fetched ID is actually greater than what we currently # Assert the fetched ID is actually greater than what we currently
# believe the ID to be. If not, then the sequence and table have got # believe the ID to be. If not, then the sequence and table have got
# out of sync somehow. # out of sync somehow.
assert self.get_current_token() < next_id assert self.get_current_token_for_writer(self._instance_name) < next_id
with self._lock: with self._lock:
self._unfinished_ids.add(next_id) self._unfinished_ids.add(next_id)
@ -344,16 +360,18 @@ class MultiWriterIdGenerator:
curr = self._current_positions.get(self._instance_name, 0) curr = self._current_positions.get(self._instance_name, 0)
self._current_positions[self._instance_name] = max(curr, next_id) self._current_positions[self._instance_name] = max(curr, next_id)
def get_current_token(self, instance_name: str = None) -> int: def get_current_token(self) -> int:
"""Gets the current position of a named writer (defaults to current """Returns the maximum stream id such that all stream ids less than or
instance). equal to it have been successfully persisted.
Returns 0 if we don't have a position for the named writer (likely due
to it being a new writer).
""" """
if instance_name is None: # Currently we don't support this operation, as it's not obvious how to
instance_name = self._instance_name # condense the stream positions of multiple writers into a single int.
raise NotImplementedError()
def get_current_token_for_writer(self, instance_name: str) -> int:
"""Returns the position of the given writer.
"""
with self._lock: with self._lock:
return self._current_positions.get(instance_name, 0) return self._current_positions.get(instance_name, 0)

View File

@ -88,7 +88,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
id_gen = self._create_id_generator() id_gen = self._create_id_generator()
self.assertEqual(id_gen.get_positions(), {"master": 7}) self.assertEqual(id_gen.get_positions(), {"master": 7})
self.assertEqual(id_gen.get_current_token("master"), 7) self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
# Try allocating a new ID gen and check that we only see position # Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager. # advanced after we leave the context manager.
@ -98,12 +98,12 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(stream_id, 8) self.assertEqual(stream_id, 8)
self.assertEqual(id_gen.get_positions(), {"master": 7}) self.assertEqual(id_gen.get_positions(), {"master": 7})
self.assertEqual(id_gen.get_current_token("master"), 7) self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
self.get_success(_get_next_async()) self.get_success(_get_next_async())
self.assertEqual(id_gen.get_positions(), {"master": 8}) self.assertEqual(id_gen.get_positions(), {"master": 8})
self.assertEqual(id_gen.get_current_token("master"), 8) self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
def test_multi_instance(self): def test_multi_instance(self):
"""Test that reads and writes from multiple processes are handled """Test that reads and writes from multiple processes are handled
@ -116,8 +116,8 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
second_id_gen = self._create_id_generator("second") second_id_gen = self._create_id_generator("second")
self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7}) self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
self.assertEqual(first_id_gen.get_current_token("first"), 3) self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3)
self.assertEqual(first_id_gen.get_current_token("second"), 7) self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
# Try allocating a new ID gen and check that we only see position # Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager. # advanced after we leave the context manager.
@ -166,7 +166,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
id_gen = self._create_id_generator() id_gen = self._create_id_generator()
self.assertEqual(id_gen.get_positions(), {"master": 7}) self.assertEqual(id_gen.get_positions(), {"master": 7})
self.assertEqual(id_gen.get_current_token("master"), 7) self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
# Try allocating a new ID gen and check that we only see position # Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager. # advanced after we leave the context manager.
@ -176,9 +176,9 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(stream_id, 8) self.assertEqual(stream_id, 8)
self.assertEqual(id_gen.get_positions(), {"master": 7}) self.assertEqual(id_gen.get_positions(), {"master": 7})
self.assertEqual(id_gen.get_current_token("master"), 7) self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
self.get_success(self.db_pool.runInteraction("test", _get_next_txn)) self.get_success(self.db_pool.runInteraction("test", _get_next_txn))
self.assertEqual(id_gen.get_positions(), {"master": 8}) self.assertEqual(id_gen.get_positions(), {"master": 8})
self.assertEqual(id_gen.get_current_token("master"), 8) self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)