From 75ca0a6168f92dab3255839cf85fb0df3a0076c3 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Wed, 27 Oct 2021 17:27:23 +0100 Subject: [PATCH] Annotate `log_function` decorator (#10943) Co-authored-by: Patrick Cloke --- changelog.d/10943.misc | 1 + synapse/federation/federation_client.py | 17 ++++++++++++-- synapse/federation/federation_server.py | 10 +++++---- .../federation/sender/transaction_manager.py | 1 - synapse/federation/transport/client.py | 22 +++++++++++++++---- synapse/handlers/directory.py | 2 +- synapse/handlers/federation_event.py | 2 +- synapse/handlers/presence.py | 2 ++ synapse/handlers/profile.py | 4 ++++ synapse/logging/utils.py | 8 +++++-- synapse/state/__init__.py | 5 +++-- synapse/storage/databases/main/profile.py | 2 +- 12 files changed, 58 insertions(+), 18 deletions(-) create mode 100644 changelog.d/10943.misc diff --git a/changelog.d/10943.misc b/changelog.d/10943.misc new file mode 100644 index 000000000..3ce28d1a6 --- /dev/null +++ b/changelog.d/10943.misc @@ -0,0 +1 @@ +Add type annotations for the `log_function` decorator. diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 2ab4dec88..670186f54 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -227,7 +227,7 @@ class FederationClient(FederationBase): ) async def backfill( - self, dest: str, room_id: str, limit: int, extremities: Iterable[str] + self, dest: str, room_id: str, limit: int, extremities: Collection[str] ) -> Optional[List[EventBase]]: """Requests some more historic PDUs for the given room from the given destination server. @@ -237,6 +237,8 @@ class FederationClient(FederationBase): room_id: The room_id to backfill. limit: The maximum number of events to return. extremities: our current backwards extremities, to backfill from + Must be a Collection that is falsy when empty. + (Iterable is not enough here!) """ logger.debug("backfill extrem=%s", extremities) @@ -250,11 +252,22 @@ class FederationClient(FederationBase): logger.debug("backfill transaction_data=%r", transaction_data) + if not isinstance(transaction_data, dict): + # TODO we probably want an exception type specific to federation + # client validation. + raise TypeError("Backfill transaction_data is not a dict.") + + transaction_data_pdus = transaction_data.get("pdus") + if not isinstance(transaction_data_pdus, list): + # TODO we probably want an exception type specific to federation + # client validation. + raise TypeError("transaction_data.pdus is not a list.") + room_version = await self.store.get_room_version(room_id) pdus = [ event_from_pdu_json(p, room_version, outlier=False) - for p in transaction_data["pdus"] + for p in transaction_data_pdus ] # Check signatures and hash of pdus, removing any from the list that fail checks diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 0d66034f4..32a75993d 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -295,14 +295,16 @@ class FederationServer(FederationBase): Returns: HTTP response code and body """ - response = await self.transaction_actions.have_responded(origin, transaction) + existing_response = await self.transaction_actions.have_responded( + origin, transaction + ) - if response: + if existing_response: logger.debug( "[%s] We've already responded to this request", transaction.transaction_id, ) - return response + return existing_response logger.debug("[%s] Transaction is new", transaction.transaction_id) @@ -632,7 +634,7 @@ class FederationServer(FederationBase): async def on_make_knock_request( self, origin: str, room_id: str, user_id: str, supported_versions: List[str] - ) -> Dict[str, Union[EventBase, str]]: + ) -> JsonDict: """We've received a /make_knock/ request, so we create a partial knock event for the room and hand that back, along with the room version, to the knocking homeserver. We do *not* persist or process this event until the other server has diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py index dc555cca0..ab935e5a7 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py @@ -149,7 +149,6 @@ class TransactionManager: ) except HttpResponseException as e: code = e.code - response = e.response set_tag(tags.ERROR, True) diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 8b247fe20..d96317883 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -15,7 +15,19 @@ import logging import urllib -from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Union +from typing import ( + Any, + Awaitable, + Callable, + Collection, + Dict, + Iterable, + List, + Mapping, + Optional, + Tuple, + Union, +) import attr import ijson @@ -100,7 +112,7 @@ class TransportLayerClient: @log_function async def backfill( - self, destination: str, room_id: str, event_tuples: Iterable[str], limit: int + self, destination: str, room_id: str, event_tuples: Collection[str], limit: int ) -> Optional[JsonDict]: """Requests `limit` previous PDUs in a given context before list of PDUs. @@ -108,7 +120,9 @@ class TransportLayerClient: Args: destination room_id - event_tuples + event_tuples: + Must be a Collection that is falsy when empty. + (Iterable is not enough here!) limit Returns: @@ -786,7 +800,7 @@ class TransportLayerClient: @log_function def join_group( self, destination: str, group_id: str, user_id: str, content: JsonDict - ) -> JsonDict: + ) -> Awaitable[JsonDict]: """Attempts to join a group""" path = _create_v1_path("/groups/%s/users/%s/join", group_id, user_id) diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 8567cb0e0..8ca5f60b1 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -245,7 +245,7 @@ class DirectoryHandler: servers = result.servers else: try: - fed_result = await self.federation.make_query( + fed_result: Optional[JsonDict] = await self.federation.make_query( destination=room_alias.domain, query_type="directory", args={"room_alias": room_alias.to_string()}, diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index bd1fa08ce..e617db4c0 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -477,7 +477,7 @@ class FederationEventHandler: @log_function async def backfill( - self, dest: str, room_id: str, limit: int, extremities: Iterable[str] + self, dest: str, room_id: str, limit: int, extremities: Collection[str] ) -> None: """Trigger a backfill request to `dest` for the given `room_id` diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index fdab50da3..3df872c57 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -52,6 +52,7 @@ import synapse.metrics from synapse.api.constants import EventTypes, Membership, PresenceState from synapse.api.errors import SynapseError from synapse.api.presence import UserPresenceState +from synapse.appservice import ApplicationService from synapse.events.presence_router import PresenceRouter from synapse.logging.context import run_in_background from synapse.logging.utils import log_function @@ -1551,6 +1552,7 @@ class PresenceEventSource(EventSource[int, UserPresenceState]): is_guest: bool = False, explicit_room_id: Optional[str] = None, include_offline: bool = True, + service: Optional[ApplicationService] = None, ) -> Tuple[List[UserPresenceState], int]: # The process for getting presence events are: # 1. Get the rooms the user is in. diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index e6c3cf585..6b5a6ded8 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -456,7 +456,11 @@ class ProfileHandler: continue new_name = profile.get("displayname") + if not isinstance(new_name, str): + new_name = None new_avatar = profile.get("avatar_url") + if not isinstance(new_avatar, str): + new_avatar = None # We always hit update to update the last_check timestamp await self.store.update_remote_profile_cache(user_id, new_name, new_avatar) diff --git a/synapse/logging/utils.py b/synapse/logging/utils.py index 08895e72e..4a01b902c 100644 --- a/synapse/logging/utils.py +++ b/synapse/logging/utils.py @@ -16,6 +16,7 @@ import logging from functools import wraps from inspect import getcallargs +from typing import Callable, TypeVar, cast _TIME_FUNC_ID = 0 @@ -41,7 +42,10 @@ def _log_debug_as_f(f, msg, msg_args): logger.handle(record) -def log_function(f): +F = TypeVar("F", bound=Callable) + + +def log_function(f: F) -> F: """Function decorator that logs every call to that function.""" func_name = f.__name__ @@ -69,4 +73,4 @@ def log_function(f): return f(*args, **kwargs) wrapped.__name__ = func_name - return wrapped + return cast(F, wrapped) diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 5cf2e1257..98a023975 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -26,6 +26,7 @@ from typing import ( FrozenSet, Iterable, List, + Mapping, Optional, Sequence, Set, @@ -519,7 +520,7 @@ class StateResolutionHandler: self, room_id: str, room_version: str, - state_groups_ids: Dict[int, StateMap[str]], + state_groups_ids: Mapping[int, StateMap[str]], event_map: Optional[Dict[str, EventBase]], state_res_store: "StateResolutionStore", ) -> _StateCacheEntry: @@ -703,7 +704,7 @@ class StateResolutionHandler: def _make_state_cache_entry( - new_state: StateMap[str], state_groups_ids: Dict[int, StateMap[str]] + new_state: StateMap[str], state_groups_ids: Mapping[int, StateMap[str]] ) -> _StateCacheEntry: """Given a resolved state, and a set of input state groups, pick one to base a new state group on (if any), and return an appropriately-constructed diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index ba7075caa..dd8e27e22 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -91,7 +91,7 @@ class ProfileWorkerStore(SQLBaseStore): ) async def update_remote_profile_cache( - self, user_id: str, displayname: str, avatar_url: str + self, user_id: str, displayname: Optional[str], avatar_url: Optional[str] ) -> int: return await self.db_pool.simple_update( table="remote_profile_cache",