Add missing type hints to synapse.replication. (#11938)

Patrick Cloke 2022-02-08 11:03:08 -05:00 committed by GitHub
parent 8c94b3abe9
commit d0e78af35e
19 changed files with 209 additions and 147 deletions


@@ -40,7 +40,7 @@ class SlavedIdTracker(AbstractStreamIdTracker):
         for table, column in extra_tables:
             self.advance(None, _load_current_id(db_conn, table, column))

-    def advance(self, instance_name: Optional[str], new_id: int):
+    def advance(self, instance_name: Optional[str], new_id: int) -> None:
         self._current = (max if self.step > 0 else min)(self._current, new_id)

     def get_current_token(self) -> int:

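For reference, the `advance` method gaining the `-> None` return type here keeps `_current` at the furthest position seen on a stream, where the direction depends on the sign of the tracker's `step`. A minimal standalone sketch of that behaviour (an illustration, not the `SlavedIdTracker` class itself; the constructor is simplified):

```python
from typing import Optional


class IdTrackerSketch:
    """Minimal model of a stream ID tracker with a configurable direction."""

    def __init__(self, current: int, step: int = 1) -> None:
        self._current = current
        self.step = step  # +1 for forward streams, -1 for backwards streams

    def advance(self, instance_name: Optional[str], new_id: int) -> None:
        # Forward streams keep the largest ID seen; backward streams keep the
        # smallest. `instance_name` is unused here and only mirrors the signature.
        self._current = (max if self.step > 0 else min)(self._current, new_id)

    def get_current_token(self) -> int:
        return self._current


tracker = IdTrackerSketch(current=10)
tracker.advance(None, 15)
tracker.advance(None, 12)  # no effect: 15 is already the furthest position
assert tracker.get_current_token() == 15
```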

@@ -37,7 +37,9 @@ class SlavedClientIpStore(BaseSlavedStore):
             cache_name="client_ip_last_seen", max_size=50000
         )

-    async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
+    async def insert_client_ip(
+        self, user_id: str, access_token: str, ip: str, user_agent: str, device_id: str
+    ) -> None:
         now = int(self._clock.time_msec())
         key = (user_id, access_token, ip)


@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Iterable

 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
@@ -60,7 +60,9 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
     def get_device_stream_token(self) -> int:
         return self._device_list_id_gen.get_current_token()

-    def process_replication_rows(self, stream_name, instance_name, token, rows):
+    def process_replication_rows(
+        self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
+    ) -> None:
         if stream_name == DeviceListsStream.NAME:
             self._device_list_id_gen.advance(instance_name, token)
             self._invalidate_caches_for_devices(token, rows)
@@ -70,7 +72,9 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
                 self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
         return super().process_replication_rows(stream_name, instance_name, token, rows)

-    def _invalidate_caches_for_devices(self, token, rows):
+    def _invalidate_caches_for_devices(
+        self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow]
+    ) -> None:
         for row in rows:
             # The entities are either user IDs (starting with '@') whose devices
             # have changed, or remote servers that we need to tell about

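`process_replication_rows` follows the same shape in each of these stores: match on the stream name, advance the store's ID generator to the new token, invalidate caches for the affected rows, then hand off to the superclass. The rows are attribute-style objects (e.g. `DeviceListsStream.DeviceListsStreamRow`), which is why `row.user_id` works even though the parameter is only typed as `Iterable[Any]`. A self-contained sketch of the pattern, with hypothetical names standing in for the real Synapse classes:

```python
from typing import Any, Iterable, NamedTuple, Set


class DeviceListsRow(NamedTuple):
    """Stand-in for a stream row such as DeviceListsStream.DeviceListsStreamRow."""

    entity: str  # a user ID ("@user:server") or a remote server name


class BaseStoreSketch:
    def process_replication_rows(
        self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
    ) -> None:
        pass  # base class: nothing to invalidate here


class DeviceStoreSketch(BaseStoreSketch):
    def __init__(self) -> None:
        self._device_list_token = 0
        self._changed_entities: Set[str] = set()

    def process_replication_rows(
        self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
    ) -> None:
        if stream_name == "device_lists":
            # Advance our view of the stream, then invalidate per-row state.
            self._device_list_token = max(self._device_list_token, token)
            for row in rows:
                self._changed_entities.add(row.entity)
        # Always give parent classes a chance to process the rows as well.
        return super().process_replication_rows(stream_name, instance_name, token, rows)


store = DeviceStoreSketch()
store.process_replication_rows(
    "device_lists", "worker1", token=5, rows=[DeviceListsRow("@alice:example.com")]
)
assert store._device_list_token == 5
```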

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Iterable

 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
@@ -44,10 +44,12 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
             self._group_updates_id_gen.get_current_token(),
         )

-    def get_group_stream_token(self):
+    def get_group_stream_token(self) -> int:
         return self._group_updates_id_gen.get_current_token()

-    def process_replication_rows(self, stream_name, instance_name, token, rows):
+    def process_replication_rows(
+        self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
+    ) -> None:
         if stream_name == GroupServerStream.NAME:
             self._group_updates_id_gen.advance(instance_name, token)
             for row in rows:


@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Any, Iterable

 from synapse.replication.tcp.streams import PushRulesStream
 from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
@@ -20,10 +21,12 @@ from .events import SlavedEventStore


 class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
-    def get_max_push_rules_stream_id(self):
+    def get_max_push_rules_stream_id(self) -> int:
         return self._push_rules_stream_id_gen.get_current_token()

-    def process_replication_rows(self, stream_name, instance_name, token, rows):
+    def process_replication_rows(
+        self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
+    ) -> None:
         if stream_name == PushRulesStream.NAME:
             self._push_rules_stream_id_gen.advance(instance_name, token)
             for row in rows:


@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Iterable

 from synapse.replication.tcp.streams import PushersStream
 from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
@@ -41,8 +41,8 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
         return self._pushers_id_gen.get_current_token()

     def process_replication_rows(
-        self, stream_name: str, instance_name: str, token, rows
+        self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
     ) -> None:
         if stream_name == PushersStream.NAME:
-            self._pushers_id_gen.advance(instance_name, token)  # type: ignore
+            self._pushers_id_gen.advance(instance_name, token)
         return super().process_replication_rows(stream_name, instance_name, token, rows)
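The last hunk also drops a `# type: ignore` from the `advance` call: with `token` now annotated as `int` (matching `new_id: int` in the first file), the suppression evidently no longer hides a real error. As a general point (not a claim about Synapse's specific mypy configuration), projects that enable mypy's `warn_unused_ignores` have to delete such stale suppressions, because they become errors in their own right. A tiny illustration with made-up names:

```python
# check_example.py -- run with: mypy --warn-unused-ignores check_example.py
def stream_position(token: int) -> int:
    return token


# The annotation on `token` means this call already type-checks, so the
# suppression below is stale; mypy reports: Unused "type: ignore" comment
pos = stream_position(7)  # type: ignore[arg-type]
```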