Add type hints to synapse/storage/databases/main/events_worker.py (#11411)

Also refactor the stream ID trackers/generators a bit and try to
document them better.
This commit is contained in:
Sean Quah 2021-11-26 18:41:31 +00:00 committed by GitHub
parent 1d8b80b334
commit ffd858aa68
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 255 additions and 171 deletions

View file

@ -14,10 +14,18 @@
from typing import List, Optional, Tuple
from synapse.storage.database import LoggingDatabaseConnection
from synapse.storage.util.id_generators import _load_current_id
from synapse.storage.util.id_generators import AbstractStreamIdTracker, _load_current_id
class SlavedIdTracker:
class SlavedIdTracker(AbstractStreamIdTracker):
"""Tracks the "current" stream ID of a stream with a single writer.
See `AbstractStreamIdTracker` for more details.
Note that this class does not work correctly when there are multiple
writers.
"""
def __init__(
self,
db_conn: LoggingDatabaseConnection,
@ -36,17 +44,7 @@ class SlavedIdTracker:
self._current = (max if self.step > 0 else min)(self._current, new_id)
def get_current_token(self) -> int:
"""
Returns:
int
"""
return self._current
def get_current_token_for_writer(self, instance_name: str) -> int:
"""Returns the position of the given writer.
For streams with single writers this is equivalent to
`get_current_token`.
"""
return self.get_current_token()

View file

@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import PushRulesStream
from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
@ -25,9 +24,6 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
return self._push_rules_stream_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
# We assert this for the benefit of mypy
assert isinstance(self._push_rules_stream_id_gen, SlavedIdTracker)
if stream_name == PushRulesStream.NAME:
self._push_rules_stream_id_gen.advance(instance_name, token)
for row in rows: