Add type hints to synapse/storage/databases/main/events_worker.py (#11411)

Also refactor the stream ID trackers/generators a bit and try to
document them better.
Sean Quah 2021-11-26 18:41:31 +00:00 committed by GitHub
parent 1d8b80b334
commit ffd858aa68
13 changed files with 255 additions and 171 deletions

synapse/storage/util/id_generators.py

@@ -89,31 +89,77 @@ def _load_current_id(
return (max if step > 0 else min)(current_id, step)
class AbstractStreamIdGenerator(metaclass=abc.ABCMeta):
@abc.abstractmethod
def get_next(self) -> AsyncContextManager[int]:
raise NotImplementedError()
class AbstractStreamIdTracker(metaclass=abc.ABCMeta):
"""Tracks the "current" stream ID of a stream that may have multiple writers.
Stream IDs are monotonically increasing or decreasing integers representing write
transactions. The "current" stream ID is the stream ID such that all transactions
with equal or smaller stream IDs have completed. Since transactions may complete out
of order, this is not the same as the stream ID of the last completed transaction.
Completed transactions include both committed transactions and transactions that
have been rolled back.
"""
@abc.abstractmethod
def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
def advance(self, instance_name: str, new_id: int) -> None:
"""Advance the position of the named writer to the given ID, if greater
than existing entry.
"""
raise NotImplementedError()
@abc.abstractmethod
def get_current_token(self) -> int:
"""Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted.
Returns:
The maximum stream id.
"""
raise NotImplementedError()
@abc.abstractmethod
def get_current_token_for_writer(self, instance_name: str) -> int:
"""Returns the position of the given writer.
For streams with single writers this is equivalent to `get_current_token`.
"""
raise NotImplementedError()
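
To make the semantics above concrete, consider a toy tracker for a single ascending stream (illustrative only; the class below is not part of Synapse). It hands out IDs 1, 2 and 3, and if 1 and 3 complete while 2 is still in flight, the "current" stream ID is 1, not 3:

    # Hypothetical tracker for a single ascending stream (step = 1), used only
    # to illustrate the "current" stream ID semantics described above.
    class ToyStreamIdTracker:
        def __init__(self) -> None:
            self._last_allocated = 0
            self._unfinished = set()

        def allocate(self) -> int:
            self._last_allocated += 1
            self._unfinished.add(self._last_allocated)
            return self._last_allocated

        def mark_completed(self, stream_id: int) -> None:
            # "Completed" covers both committed and rolled-back transactions.
            self._unfinished.discard(stream_id)

        def get_current_token(self) -> int:
            if self._unfinished:
                # Everything strictly below the oldest unfinished ID has completed.
                return min(self._unfinished) - 1
            return self._last_allocated

    tracker = ToyStreamIdTracker()
    first, second, third = (tracker.allocate() for _ in range(3))  # 1, 2, 3
    tracker.mark_completed(third)  # 3 finishes before 2
    tracker.mark_completed(first)
    assert tracker.get_current_token() == 1  # 2 is still in flight
    tracker.mark_completed(second)
    assert tracker.get_current_token() == 3
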
class AbstractStreamIdGenerator(AbstractStreamIdTracker):
"""Generates stream IDs for a stream that may have multiple writers.
Each stream ID represents a write transaction, whose completion is tracked
so that the "current" stream ID of the stream can be determined.
See `AbstractStreamIdTracker` for more details.
"""
@abc.abstractmethod
def get_next(self) -> AsyncContextManager[int]:
"""
Usage:
async with stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
raise NotImplementedError()
@abc.abstractmethod
def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
"""
Usage:
async with stream_id_gen.get_next_mult(n) as stream_ids:
# ... persist events ...
"""
raise NotImplementedError()
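
In use, the allocated IDs only count as completed once the async context manager exits, whether the persist succeeded or raised. A rough sketch, assuming `stream_id_gen` is some `AbstractStreamIdGenerator` implementation and `persist` is a caller-supplied coroutine (both names are illustrative):

    from typing import Awaitable, Callable, Sequence

    async def write_events(
        stream_id_gen,  # any AbstractStreamIdGenerator implementation
        events: Sequence[dict],
        persist: Callable[[dict, int], Awaitable[None]],  # caller-supplied persistence step
    ) -> None:
        async with stream_id_gen.get_next_mult(len(events)) as stream_ids:
            for event, stream_id in zip(events, stream_ids):
                await persist(event, stream_id)
        # Only after the block exits may get_current_token() move past these IDs,
        # and then only once every earlier transaction has completed as well.
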
class StreamIdGenerator(AbstractStreamIdGenerator):
"""Used to generate new stream ids when persisting events while keeping
track of which transactions have been completed.
"""Generates and tracks stream IDs for a stream with a single writer.
This allows us to get the "current" stream id, i.e. the stream id such that
all ids less than or equal to it have completed. This handles the fact that
persistence of events can complete out of order.
This class must only be used when the current Synapse process is the sole
writer for a stream.
Args:
db_conn(connection): A database connection to use to fetch the
@@ -157,12 +203,12 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
# The key and values are the same, but we never look at the values.
self._unfinished_ids: OrderedDict[int, int] = OrderedDict()
def advance(self, instance_name: str, new_id: int) -> None:
# `StreamIdGenerator` should only be used when there is a single writer,
# so replication should never happen.
raise Exception("Replication is not supported by StreamIdGenerator")
def get_next(self) -> AsyncContextManager[int]:
"""
Usage:
async with stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
with self._lock:
self._current += self._step
next_id = self._current
@@ -180,11 +226,6 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
return _AsyncCtxManagerWrapper(manager())
def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
"""
Usage:
async with stream_id_gen.get_next(n) as stream_ids:
# ... persist events ...
"""
with self._lock:
next_ids = range(
self._current + self._step,
@@ -208,12 +249,6 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
return _AsyncCtxManagerWrapper(manager())
def get_current_token(self) -> int:
"""Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted.
Returns:
The maximum stream id.
"""
with self._lock:
if self._unfinished_ids:
return next(iter(self._unfinished_ids)) - self._step
@@ -221,16 +256,11 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
return self._current
def get_current_token_for_writer(self, instance_name: str) -> int:
"""Returns the position of the given writer.
For streams with single writers this is equivalent to
`get_current_token`.
"""
return self.get_current_token()
class MultiWriterIdGenerator(AbstractStreamIdGenerator):
"""An ID generator that tracks a stream that can have multiple writers.
"""Generates and tracks stream IDs for a stream with multiple writers.
Uses a Postgres sequence to coordinate ID assignment, but positions of other
writers will only get updated when `advance` is called (by replication).
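
The sequence-based coordination can be pictured roughly as below; the sequence name and cursor interface are placeholders rather than Synapse's actual schema. Because every writer draws from the same Postgres sequence, two workers can never be handed the same ID, but each one only learns how far the others have got when `advance` arrives over replication:

    # Rough sketch of sequence-backed allocation; "stream_seq" is a placeholder
    # name and `txn` is assumed to be a DB-API cursor inside a transaction.
    def allocate_stream_ids(txn, n: int):
        txn.execute(
            "SELECT nextval('stream_seq') FROM generate_series(1, %s)", (n,)
        )
        return [row[0] for row in txn.fetchall()]
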
@@ -475,12 +505,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
return stream_ids
def get_next(self) -> AsyncContextManager[int]:
"""
Usage:
async with stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
# If we have a list of instances that are allowed to write to this
# stream, make sure we're in it.
if self._writers and self._instance_name not in self._writers:
@@ -492,12 +516,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
return cast(AsyncContextManager[int], _MultiWriterCtxManager(self))
def get_next_mult(self, n: int) -> AsyncContextManager[List[int]]:
"""
Usage:
async with stream_id_gen.get_next_mult(5) as stream_ids:
# ... persist events ...
"""
# If we have a list of instances that are allowed to write to this
# stream, make sure we're in it.
if self._writers and self._instance_name not in self._writers:
@@ -597,15 +615,9 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
self._add_persisted_position(next_id)
def get_current_token(self) -> int:
"""Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted.
"""
return self.get_persisted_upto_position()
def get_current_token_for_writer(self, instance_name: str) -> int:
"""Returns the position of the given writer."""
# If we don't have an entry for the given instance name, we assume it's a
# new writer.
#
@@ -631,10 +643,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
}
def advance(self, instance_name: str, new_id: int) -> None:
"""Advance the position of the named writer to the given ID, if greater
than existing entry.
"""
new_id *= self._return_factor
with self._lock:
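
For completeness, a sketch of the intended call pattern for `advance` (the handler name is illustrative): when a position for another writer arrives over replication, the receiving process feeds it in, after which that writer's per-instance position catches up while the overall current token only moves once every known writer has passed the point in question.

    # Illustrative only: applying a position received over replication for
    # another writer. `id_gen` is assumed to be a MultiWriterIdGenerator (or
    # any AbstractStreamIdTracker) and `instance_name` names the remote writer.
    def on_remote_position(id_gen, instance_name: str, new_token: int) -> None:
        id_gen.advance(instance_name, new_token)
        # get_current_token_for_writer(instance_name) now reflects new_token;
        # get_current_token() advances only once all writers have passed it.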