Add missing type hints to synapse.replication.http. (#11856)

This commit is contained in:
Patrick Cloke 2022-02-08 07:44:39 -05:00 committed by GitHub
parent 8b309adb43
commit 63d90f10ec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 258 additions and 162 deletions

1
changelog.d/11856.misc Normal file
View File

@ -0,0 +1 @@
Add missing type hints to replication code.

View File

@ -40,7 +40,7 @@ class ReplicationRestResource(JsonResource):
super().__init__(hs, canonical_json=False, extract_context=True) super().__init__(hs, canonical_json=False, extract_context=True)
self.register_servlets(hs) self.register_servlets(hs)
def register_servlets(self, hs: "HomeServer"): def register_servlets(self, hs: "HomeServer") -> None:
send_event.register_servlets(hs, self) send_event.register_servlets(hs, self)
federation.register_servlets(hs, self) federation.register_servlets(hs, self)
presence.register_servlets(hs, self) presence.register_servlets(hs, self)

View File

@ -15,16 +15,20 @@
import abc import abc
import logging import logging
import re import re
import urllib import urllib.parse
from inspect import signature from inspect import signature
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Tuple from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Tuple
from prometheus_client import Counter, Gauge from prometheus_client import Counter, Gauge
from twisted.web.server import Request
from synapse.api.errors import HttpResponseException, SynapseError from synapse.api.errors import HttpResponseException, SynapseError
from synapse.http import RequestTimedOutError from synapse.http import RequestTimedOutError
from synapse.http.server import HttpServer
from synapse.logging import opentracing from synapse.logging import opentracing
from synapse.logging.opentracing import trace from synapse.logging.opentracing import trace
from synapse.types import JsonDict
from synapse.util.caches.response_cache import ResponseCache from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
@ -113,10 +117,12 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
if hs.config.worker.worker_replication_secret: if hs.config.worker.worker_replication_secret:
self._replication_secret = hs.config.worker.worker_replication_secret self._replication_secret = hs.config.worker.worker_replication_secret
def _check_auth(self, request) -> None: def _check_auth(self, request: Request) -> None:
# Get the authorization header. # Get the authorization header.
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
if not auth_headers:
raise RuntimeError("Missing Authorization header.")
if len(auth_headers) > 1: if len(auth_headers) > 1:
raise RuntimeError("Too many Authorization headers.") raise RuntimeError("Too many Authorization headers.")
parts = auth_headers[0].split(b" ") parts = auth_headers[0].split(b" ")
@ -129,7 +135,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
raise RuntimeError("Invalid Authorization header.") raise RuntimeError("Invalid Authorization header.")
@abc.abstractmethod @abc.abstractmethod
async def _serialize_payload(**kwargs): async def _serialize_payload(**kwargs) -> JsonDict:
"""Static method that is called when creating a request. """Static method that is called when creating a request.
Concrete implementations should have explicit parameters (rather than Concrete implementations should have explicit parameters (rather than
@ -144,19 +150,20 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
return {} return {}
@abc.abstractmethod @abc.abstractmethod
async def _handle_request(self, request, **kwargs): async def _handle_request(
self, request: Request, **kwargs: Any
) -> Tuple[int, JsonDict]:
"""Handle incoming request. """Handle incoming request.
This is called with the request object and PATH_ARGS. This is called with the request object and PATH_ARGS.
Returns: Returns:
tuple[int, dict]: HTTP status code and a JSON serialisable dict HTTP status code and a JSON serialisable dict to be used as response
to be used as response body of request. body of request.
""" """
pass
@classmethod @classmethod
def make_client(cls, hs: "HomeServer"): def make_client(cls, hs: "HomeServer") -> Callable:
"""Create a client that makes requests. """Create a client that makes requests.
Returns a callable that accepts the same parameters as Returns a callable that accepts the same parameters as
@ -182,7 +189,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
) )
@trace(opname="outgoing_replication_request") @trace(opname="outgoing_replication_request")
async def send_request(*, instance_name="master", **kwargs): async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any:
with outgoing_gauge.track_inprogress(): with outgoing_gauge.track_inprogress():
if instance_name == local_instance_name: if instance_name == local_instance_name:
raise Exception("Trying to send HTTP request to self") raise Exception("Trying to send HTTP request to self")
@ -268,7 +275,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
return send_request return send_request
def register(self, http_server): def register(self, http_server: HttpServer) -> None:
"""Called by the server to register this as a handler to the """Called by the server to register this as a handler to the
appropriate path. appropriate path.
""" """
@ -289,7 +296,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
self.__class__.__name__, self.__class__.__name__,
) )
async def _check_auth_and_handle(self, request, **kwargs): async def _check_auth_and_handle(
self, request: Request, **kwargs: Any
) -> Tuple[int, JsonDict]:
"""Called on new incoming requests when caching is enabled. Checks """Called on new incoming requests when caching is enabled. Checks
if there is a cached response for the request and returns that, if there is a cached response for the request and returns that,
otherwise calls `_handle_request` and caches its response. otherwise calls `_handle_request` and caches its response.

View File

@ -13,10 +13,14 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Tuple
from twisted.web.server import Request
from synapse.http.server import HttpServer
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -48,14 +52,18 @@ class ReplicationUserAccountDataRestServlet(ReplicationEndpoint):
self.clock = hs.get_clock() self.clock = hs.get_clock()
@staticmethod @staticmethod
async def _serialize_payload(user_id, account_data_type, content): async def _serialize_payload( # type: ignore[override]
user_id: str, account_data_type: str, content: JsonDict
) -> JsonDict:
payload = { payload = {
"content": content, "content": content,
} }
return payload return payload
async def _handle_request(self, request, user_id, account_data_type): async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str, account_data_type: str
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
max_stream_id = await self.handler.add_account_data_for_user( max_stream_id = await self.handler.add_account_data_for_user(
@ -89,14 +97,18 @@ class ReplicationRoomAccountDataRestServlet(ReplicationEndpoint):
self.clock = hs.get_clock() self.clock = hs.get_clock()
@staticmethod @staticmethod
async def _serialize_payload(user_id, room_id, account_data_type, content): async def _serialize_payload( # type: ignore[override]
user_id: str, room_id: str, account_data_type: str, content: JsonDict
) -> JsonDict:
payload = { payload = {
"content": content, "content": content,
} }
return payload return payload
async def _handle_request(self, request, user_id, room_id, account_data_type): async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str, room_id: str, account_data_type: str
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
max_stream_id = await self.handler.add_account_data_to_room( max_stream_id = await self.handler.add_account_data_to_room(
@ -130,14 +142,18 @@ class ReplicationAddTagRestServlet(ReplicationEndpoint):
self.clock = hs.get_clock() self.clock = hs.get_clock()
@staticmethod @staticmethod
async def _serialize_payload(user_id, room_id, tag, content): async def _serialize_payload( # type: ignore[override]
user_id: str, room_id: str, tag: str, content: JsonDict
) -> JsonDict:
payload = { payload = {
"content": content, "content": content,
} }
return payload return payload
async def _handle_request(self, request, user_id, room_id, tag): async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str, room_id: str, tag: str
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
max_stream_id = await self.handler.add_tag_to_room( max_stream_id = await self.handler.add_tag_to_room(
@ -173,11 +189,13 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint):
self.clock = hs.get_clock() self.clock = hs.get_clock()
@staticmethod @staticmethod
async def _serialize_payload(user_id, room_id, tag): async def _serialize_payload(user_id: str, room_id: str, tag: str) -> JsonDict: # type: ignore[override]
return {} return {}
async def _handle_request(self, request, user_id, room_id, tag): async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str, room_id: str, tag: str
) -> Tuple[int, JsonDict]:
max_stream_id = await self.handler.remove_tag_from_room( max_stream_id = await self.handler.remove_tag_from_room(
user_id, user_id,
room_id, room_id,
@ -187,7 +205,7 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint):
return 200, {"max_stream_id": max_stream_id} return 200, {"max_stream_id": max_stream_id}
def register_servlets(hs: "HomeServer", http_server): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReplicationUserAccountDataRestServlet(hs).register(http_server) ReplicationUserAccountDataRestServlet(hs).register(http_server)
ReplicationRoomAccountDataRestServlet(hs).register(http_server) ReplicationRoomAccountDataRestServlet(hs).register(http_server)
ReplicationAddTagRestServlet(hs).register(http_server) ReplicationAddTagRestServlet(hs).register(http_server)

View File

@ -13,9 +13,13 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Tuple
from twisted.web.server import Request
from synapse.http.server import HttpServer
from synapse.replication.http._base import ReplicationEndpoint from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -63,14 +67,16 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
self.clock = hs.get_clock() self.clock = hs.get_clock()
@staticmethod @staticmethod
async def _serialize_payload(user_id): async def _serialize_payload(user_id: str) -> JsonDict: # type: ignore[override]
return {} return {}
async def _handle_request(self, request, user_id): async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str
) -> Tuple[int, JsonDict]:
user_devices = await self.device_list_updater.user_device_resync(user_id) user_devices = await self.device_list_updater.user_device_resync(user_id)
return 200, user_devices return 200, user_devices
def register_servlets(hs: "HomeServer", http_server): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReplicationUserDevicesResyncRestServlet(hs).register(http_server) ReplicationUserDevicesResyncRestServlet(hs).register(http_server)

View File

@ -13,17 +13,22 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, List, Tuple
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from twisted.web.server import Request
from synapse.events import make_event_from_dict
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.events import EventBase, make_event_from_dict
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.http.server import HttpServer
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -69,14 +74,18 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
self.federation_event_handler = hs.get_federation_event_handler() self.federation_event_handler = hs.get_federation_event_handler()
@staticmethod @staticmethod
async def _serialize_payload(store, room_id, event_and_contexts, backfilled): async def _serialize_payload( # type: ignore[override]
store: "DataStore",
room_id: str,
event_and_contexts: List[Tuple[EventBase, EventContext]],
backfilled: bool,
) -> JsonDict:
""" """
Args: Args:
store store
room_id (str) room_id
event_and_contexts (list[tuple[FrozenEvent, EventContext]]) event_and_contexts
backfilled (bool): Whether or not the events are the result of backfilled: Whether or not the events are the result of backfilling
backfilling
""" """
event_payloads = [] event_payloads = []
for event, context in event_and_contexts: for event, context in event_and_contexts:
@ -102,7 +111,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
return payload return payload
async def _handle_request(self, request): async def _handle_request(self, request: Request) -> Tuple[int, JsonDict]: # type: ignore[override]
with Measure(self.clock, "repl_fed_send_events_parse"): with Measure(self.clock, "repl_fed_send_events_parse"):
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
@ -163,10 +172,14 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
self.registry = hs.get_federation_registry() self.registry = hs.get_federation_registry()
@staticmethod @staticmethod
async def _serialize_payload(edu_type, origin, content): async def _serialize_payload( # type: ignore[override]
edu_type: str, origin: str, content: JsonDict
) -> JsonDict:
return {"origin": origin, "content": content} return {"origin": origin, "content": content}
async def _handle_request(self, request, edu_type): async def _handle_request( # type: ignore[override]
self, request: Request, edu_type: str
) -> Tuple[int, JsonDict]:
with Measure(self.clock, "repl_fed_send_edu_parse"): with Measure(self.clock, "repl_fed_send_edu_parse"):
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
@ -175,9 +188,9 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
logger.info("Got %r edu from %s", edu_type, origin) logger.info("Got %r edu from %s", edu_type, origin)
result = await self.registry.on_edu(edu_type, origin, edu_content) await self.registry.on_edu(edu_type, origin, edu_content)
return 200, result return 200, {}
class ReplicationGetQueryRestServlet(ReplicationEndpoint): class ReplicationGetQueryRestServlet(ReplicationEndpoint):
@ -206,15 +219,17 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint):
self.registry = hs.get_federation_registry() self.registry = hs.get_federation_registry()
@staticmethod @staticmethod
async def _serialize_payload(query_type, args): async def _serialize_payload(query_type: str, args: JsonDict) -> JsonDict: # type: ignore[override]
""" """
Args: Args:
query_type (str) query_type
args (dict): The arguments received for the given query type args: The arguments received for the given query type
""" """
return {"args": args} return {"args": args}
async def _handle_request(self, request, query_type): async def _handle_request( # type: ignore[override]
self, request: Request, query_type: str
) -> Tuple[int, JsonDict]:
with Measure(self.clock, "repl_fed_query_parse"): with Measure(self.clock, "repl_fed_query_parse"):
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
@ -248,14 +263,16 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
self.store = hs.get_datastore() self.store = hs.get_datastore()
@staticmethod @staticmethod
async def _serialize_payload(room_id, args): async def _serialize_payload(room_id: str) -> JsonDict: # type: ignore[override]
""" """
Args: Args:
room_id (str) room_id
""" """
return {} return {}
async def _handle_request(self, request, room_id): async def _handle_request( # type: ignore[override]
self, request: Request, room_id: str
) -> Tuple[int, JsonDict]:
await self.store.clean_room_for_join(room_id) await self.store.clean_room_for_join(room_id)
return 200, {} return 200, {}
@ -283,17 +300,19 @@ class ReplicationStoreRoomOnOutlierMembershipRestServlet(ReplicationEndpoint):
self.store = hs.get_datastore() self.store = hs.get_datastore()
@staticmethod @staticmethod
async def _serialize_payload(room_id, room_version): async def _serialize_payload(room_id: str, room_version: RoomVersion) -> JsonDict: # type: ignore[override]
return {"room_version": room_version.identifier} return {"room_version": room_version.identifier}
async def _handle_request(self, request, room_id): async def _handle_request( # type: ignore[override]
self, request: Request, room_id: str
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
room_version = KNOWN_ROOM_VERSIONS[content["room_version"]] room_version = KNOWN_ROOM_VERSIONS[content["room_version"]]
await self.store.maybe_store_room_on_outlier_membership(room_id, room_version) await self.store.maybe_store_room_on_outlier_membership(room_id, room_version)
return 200, {} return 200, {}
def register_servlets(hs: "HomeServer", http_server): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReplicationFederationSendEventsRestServlet(hs).register(http_server) ReplicationFederationSendEventsRestServlet(hs).register(http_server)
ReplicationFederationSendEduRestServlet(hs).register(http_server) ReplicationFederationSendEduRestServlet(hs).register(http_server)
ReplicationGetQueryRestServlet(hs).register(http_server) ReplicationGetQueryRestServlet(hs).register(http_server)

View File

@ -13,10 +13,14 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Optional, Tuple, cast
from twisted.web.server import Request
from synapse.http.server import HttpServer
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -39,25 +43,24 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
self.registration_handler = hs.get_registration_handler() self.registration_handler = hs.get_registration_handler()
@staticmethod @staticmethod
async def _serialize_payload( async def _serialize_payload( # type: ignore[override]
user_id, user_id: str,
device_id, device_id: Optional[str],
initial_display_name, initial_display_name: Optional[str],
is_guest, is_guest: bool,
is_appservice_ghost, is_appservice_ghost: bool,
should_issue_refresh_token, should_issue_refresh_token: bool,
auth_provider_id, auth_provider_id: Optional[str],
auth_provider_session_id, auth_provider_session_id: Optional[str],
): ) -> JsonDict:
""" """
Args: Args:
user_id (int) user_id
device_id (str|None): Device ID to use, if None a new one is device_id: Device ID to use, if None a new one is generated.
generated. initial_display_name
initial_display_name (str|None) is_guest
is_guest (bool) is_appservice_ghost
is_appservice_ghost (bool) should_issue_refresh_token
should_issue_refresh_token (bool)
""" """
return { return {
"device_id": device_id, "device_id": device_id,
@ -69,7 +72,9 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
"auth_provider_session_id": auth_provider_session_id, "auth_provider_session_id": auth_provider_session_id,
} }
async def _handle_request(self, request, user_id): async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
device_id = content["device_id"] device_id = content["device_id"]
@ -91,8 +96,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
auth_provider_session_id=auth_provider_session_id, auth_provider_session_id=auth_provider_session_id,
) )
return 200, res return 200, cast(JsonDict, res)
def register_servlets(hs: "HomeServer", http_server): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
RegisterDeviceReplicationServlet(hs).register(http_server) RegisterDeviceReplicationServlet(hs).register(http_server)

View File

@ -16,6 +16,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple
from twisted.web.server import Request from twisted.web.server import Request
from synapse.http.server import HttpServer
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import parse_json_object_from_request
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.replication.http._base import ReplicationEndpoint from synapse.replication.http._base import ReplicationEndpoint
@ -53,7 +54,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
self.clock = hs.get_clock() self.clock = hs.get_clock()
@staticmethod @staticmethod
async def _serialize_payload( # type: ignore async def _serialize_payload( # type: ignore[override]
requester: Requester, requester: Requester,
room_id: str, room_id: str,
user_id: str, user_id: str,
@ -77,7 +78,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
"content": content, "content": content,
} }
async def _handle_request( # type: ignore async def _handle_request( # type: ignore[override]
self, request: SynapseRequest, room_id: str, user_id: str self, request: SynapseRequest, room_id: str, user_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
@ -122,13 +123,13 @@ class ReplicationRemoteKnockRestServlet(ReplicationEndpoint):
self.clock = hs.get_clock() self.clock = hs.get_clock()
@staticmethod @staticmethod
async def _serialize_payload( # type: ignore async def _serialize_payload( # type: ignore[override]
requester: Requester, requester: Requester,
room_id: str, room_id: str,
user_id: str, user_id: str,
remote_room_hosts: List[str], remote_room_hosts: List[str],
content: JsonDict, content: JsonDict,
): ) -> JsonDict:
""" """
Args: Args:
requester: The user making the request, according to the access token. requester: The user making the request, according to the access token.
@ -143,12 +144,12 @@ class ReplicationRemoteKnockRestServlet(ReplicationEndpoint):
"content": content, "content": content,
} }
async def _handle_request( # type: ignore async def _handle_request( # type: ignore[override]
self, self,
request: SynapseRequest, request: SynapseRequest,
room_id: str, room_id: str,
user_id: str, user_id: str,
): ) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
remote_room_hosts = content["remote_room_hosts"] remote_room_hosts = content["remote_room_hosts"]
@ -192,7 +193,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
self.member_handler = hs.get_room_member_handler() self.member_handler = hs.get_room_member_handler()
@staticmethod @staticmethod
async def _serialize_payload( # type: ignore async def _serialize_payload( # type: ignore[override]
invite_event_id: str, invite_event_id: str,
txn_id: Optional[str], txn_id: Optional[str],
requester: Requester, requester: Requester,
@ -215,7 +216,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
"content": content, "content": content,
} }
async def _handle_request( # type: ignore async def _handle_request( # type: ignore[override]
self, request: SynapseRequest, invite_event_id: str self, request: SynapseRequest, invite_event_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
@ -262,12 +263,12 @@ class ReplicationRemoteRescindKnockRestServlet(ReplicationEndpoint):
self.member_handler = hs.get_room_member_handler() self.member_handler = hs.get_room_member_handler()
@staticmethod @staticmethod
async def _serialize_payload( # type: ignore async def _serialize_payload( # type: ignore[override]
knock_event_id: str, knock_event_id: str,
txn_id: Optional[str], txn_id: Optional[str],
requester: Requester, requester: Requester,
content: JsonDict, content: JsonDict,
): ) -> JsonDict:
""" """
Args: Args:
knock_event_id: The ID of the knock to be rescinded. knock_event_id: The ID of the knock to be rescinded.
@ -281,11 +282,11 @@ class ReplicationRemoteRescindKnockRestServlet(ReplicationEndpoint):
"content": content, "content": content,
} }
async def _handle_request( # type: ignore async def _handle_request( # type: ignore[override]
self, self,
request: SynapseRequest, request: SynapseRequest,
knock_event_id: str, knock_event_id: str,
): ) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
txn_id = content["txn_id"] txn_id = content["txn_id"]
@ -329,7 +330,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
self.distributor = hs.get_distributor() self.distributor = hs.get_distributor()
@staticmethod @staticmethod
async def _serialize_payload( # type: ignore async def _serialize_payload( # type: ignore[override]
room_id: str, user_id: str, change: str room_id: str, user_id: str, change: str
) -> JsonDict: ) -> JsonDict:
""" """
@ -345,7 +346,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
return {} return {}
async def _handle_request( # type: ignore async def _handle_request( # type: ignore[override]
self, request: Request, room_id: str, user_id: str, change: str self, request: Request, room_id: str, user_id: str, change: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
logger.info("user membership change: %s in %s", user_id, room_id) logger.info("user membership change: %s in %s", user_id, room_id)
@ -360,7 +361,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
return 200, {} return 200, {}
def register_servlets(hs: "HomeServer", http_server): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReplicationRemoteJoinRestServlet(hs).register(http_server) ReplicationRemoteJoinRestServlet(hs).register(http_server)
ReplicationRemoteRejectInviteRestServlet(hs).register(http_server) ReplicationRemoteRejectInviteRestServlet(hs).register(http_server)
ReplicationUserJoinedLeftRoomRestServlet(hs).register(http_server) ReplicationUserJoinedLeftRoomRestServlet(hs).register(http_server)

View File

@ -13,11 +13,14 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Tuple
from twisted.web.server import Request
from synapse.http.server import HttpServer
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import UserID from synapse.types import JsonDict, UserID
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -49,18 +52,17 @@ class ReplicationBumpPresenceActiveTime(ReplicationEndpoint):
self._presence_handler = hs.get_presence_handler() self._presence_handler = hs.get_presence_handler()
@staticmethod @staticmethod
async def _serialize_payload(user_id): async def _serialize_payload(user_id: str) -> JsonDict: # type: ignore[override]
return {} return {}
async def _handle_request(self, request, user_id): async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str
) -> Tuple[int, JsonDict]:
await self._presence_handler.bump_presence_active_time( await self._presence_handler.bump_presence_active_time(
UserID.from_string(user_id) UserID.from_string(user_id)
) )
return ( return (200, {})
200,
{},
)
class ReplicationPresenceSetState(ReplicationEndpoint): class ReplicationPresenceSetState(ReplicationEndpoint):
@ -92,16 +94,21 @@ class ReplicationPresenceSetState(ReplicationEndpoint):
self._presence_handler = hs.get_presence_handler() self._presence_handler = hs.get_presence_handler()
@staticmethod @staticmethod
async def _serialize_payload( async def _serialize_payload( # type: ignore[override]
user_id, state, ignore_status_msg=False, force_notify=False user_id: str,
): state: JsonDict,
ignore_status_msg: bool = False,
force_notify: bool = False,
) -> JsonDict:
return { return {
"state": state, "state": state,
"ignore_status_msg": ignore_status_msg, "ignore_status_msg": ignore_status_msg,
"force_notify": force_notify, "force_notify": force_notify,
} }
async def _handle_request(self, request, user_id): async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
await self._presence_handler.set_state( await self._presence_handler.set_state(
@ -111,12 +118,9 @@ class ReplicationPresenceSetState(ReplicationEndpoint):
content["force_notify"], content["force_notify"],
) )
return ( return (200, {})
200,
{},
)
def register_servlets(hs: "HomeServer", http_server): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReplicationBumpPresenceActiveTime(hs).register(http_server) ReplicationBumpPresenceActiveTime(hs).register(http_server)
ReplicationPresenceSetState(hs).register(http_server) ReplicationPresenceSetState(hs).register(http_server)

View File

@ -13,10 +13,14 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Tuple
from twisted.web.server import Request
from synapse.http.server import HttpServer
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -48,7 +52,7 @@ class ReplicationRemovePusherRestServlet(ReplicationEndpoint):
self.pusher_pool = hs.get_pusherpool() self.pusher_pool = hs.get_pusherpool()
@staticmethod @staticmethod
async def _serialize_payload(app_id, pushkey, user_id): async def _serialize_payload(app_id: str, pushkey: str, user_id: str) -> JsonDict: # type: ignore[override]
payload = { payload = {
"app_id": app_id, "app_id": app_id,
"pushkey": pushkey, "pushkey": pushkey,
@ -56,7 +60,9 @@ class ReplicationRemovePusherRestServlet(ReplicationEndpoint):
return payload return payload
async def _handle_request(self, request, user_id): async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
app_id = content["app_id"] app_id = content["app_id"]
@ -67,5 +73,5 @@ class ReplicationRemovePusherRestServlet(ReplicationEndpoint):
return 200, {} return 200, {}
def register_servlets(hs: "HomeServer", http_server): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReplicationRemovePusherRestServlet(hs).register(http_server) ReplicationRemovePusherRestServlet(hs).register(http_server)

View File

@ -13,10 +13,14 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Optional, Tuple
from twisted.web.server import Request
from synapse.http.server import HttpServer
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -36,34 +40,34 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
self.registration_handler = hs.get_registration_handler() self.registration_handler = hs.get_registration_handler()
@staticmethod @staticmethod
async def _serialize_payload( async def _serialize_payload( # type: ignore[override]
user_id, user_id: str,
password_hash, password_hash: Optional[str],
was_guest, was_guest: bool,
make_guest, make_guest: bool,
appservice_id, appservice_id: Optional[str],
create_profile_with_displayname, create_profile_with_displayname: Optional[str],
admin, admin: bool,
user_type, user_type: Optional[str],
address, address: Optional[str],
shadow_banned, shadow_banned: bool,
): ) -> JsonDict:
""" """
Args: Args:
user_id (str): The desired user ID to register. user_id: The desired user ID to register.
password_hash (str|None): Optional. The password hash for this user. password_hash: Optional. The password hash for this user.
was_guest (bool): Optional. Whether this is a guest account being was_guest: Optional. Whether this is a guest account being upgraded
upgraded to a non-guest account. to a non-guest account.
make_guest (boolean): True if the the new user should be guest, make_guest: True if the the new user should be guest, false to add a
false to add a regular user account. regular user account.
appservice_id (str|None): The ID of the appservice registering the user. appservice_id: The ID of the appservice registering the user.
create_profile_with_displayname (unicode|None): Optionally create a create_profile_with_displayname: Optionally create a profile for the
profile for the user, setting their displayname to the given value user, setting their displayname to the given value
admin (boolean): is an admin user? admin: is an admin user?
user_type (str|None): type of user. One of the values from user_type: type of user. One of the values from api.constants.UserTypes,
api.constants.UserTypes, or None for a normal user. or None for a normal user.
address (str|None): the IP address used to perform the regitration. address: the IP address used to perform the regitration.
shadow_banned (bool): Whether to shadow-ban the user shadow_banned: Whether to shadow-ban the user
""" """
return { return {
"password_hash": password_hash, "password_hash": password_hash,
@ -77,7 +81,9 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
"shadow_banned": shadow_banned, "shadow_banned": shadow_banned,
} }
async def _handle_request(self, request, user_id): async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
await self.registration_handler.check_registration_ratelimit(content["address"]) await self.registration_handler.check_registration_ratelimit(content["address"])
@ -110,18 +116,21 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint):
self.registration_handler = hs.get_registration_handler() self.registration_handler = hs.get_registration_handler()
@staticmethod @staticmethod
async def _serialize_payload(user_id, auth_result, access_token): async def _serialize_payload( # type: ignore[override]
user_id: str, auth_result: JsonDict, access_token: Optional[str]
) -> JsonDict:
""" """
Args: Args:
user_id (str): The user ID that consented user_id: The user ID that consented
auth_result (dict): The authenticated credentials of the newly auth_result: The authenticated credentials of the newly registered user.
registered user. access_token: The access token of the newly logged in
access_token (str|None): The access token of the newly logged in
device, or None if `inhibit_login` enabled. device, or None if `inhibit_login` enabled.
""" """
return {"auth_result": auth_result, "access_token": access_token} return {"auth_result": auth_result, "access_token": access_token}
async def _handle_request(self, request, user_id): async def _handle_request( # type: ignore[override]
self, request: Request, user_id: str
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
auth_result = content["auth_result"] auth_result = content["auth_result"]
@ -134,6 +143,6 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint):
return 200, {} return 200, {}
def register_servlets(hs: "HomeServer", http_server): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReplicationRegisterServlet(hs).register(http_server) ReplicationRegisterServlet(hs).register(http_server)
ReplicationPostRegisterActionsServlet(hs).register(http_server) ReplicationPostRegisterActionsServlet(hs).register(http_server)

View File

@ -13,18 +13,22 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, List, Tuple
from twisted.web.server import Request
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import make_event_from_dict from synapse.events import EventBase, make_event_from_dict
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.http.server import HttpServer
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import Requester, UserID from synapse.types import JsonDict, Requester, UserID
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -70,18 +74,24 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
self.clock = hs.get_clock() self.clock = hs.get_clock()
@staticmethod @staticmethod
async def _serialize_payload( async def _serialize_payload( # type: ignore[override]
event_id, store, event, context, requester, ratelimit, extra_users event_id: str,
): store: "DataStore",
event: EventBase,
context: EventContext,
requester: Requester,
ratelimit: bool,
extra_users: List[UserID],
) -> JsonDict:
""" """
Args: Args:
event_id (str) event_id
store (DataStore) store
requester (Requester) requester
event (FrozenEvent) event
context (EventContext) context
ratelimit (bool) ratelimit
extra_users (list(UserID)): Any extra users to notify about event extra_users: Any extra users to notify about event
""" """
serialized_context = await context.serialize(event, store) serialized_context = await context.serialize(event, store)
@ -100,7 +110,9 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
return payload return payload
async def _handle_request(self, request, event_id): async def _handle_request( # type: ignore[override]
self, request: Request, event_id: str
) -> Tuple[int, JsonDict]:
with Measure(self.clock, "repl_send_event_parse"): with Measure(self.clock, "repl_send_event_parse"):
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
@ -120,8 +132,6 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
ratelimit = content["ratelimit"] ratelimit = content["ratelimit"]
extra_users = [UserID.from_string(u) for u in content["extra_users"]] extra_users = [UserID.from_string(u) for u in content["extra_users"]]
request.requester = requester
logger.info( logger.info(
"Got event to send with ID: %s into room: %s", event.event_id, event.room_id "Got event to send with ID: %s into room: %s", event.event_id, event.room_id
) )
@ -139,5 +149,5 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
) )
def register_servlets(hs: "HomeServer", http_server): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReplicationSendEventRestServlet(hs).register(http_server) ReplicationSendEventRestServlet(hs).register(http_server)

View File

@ -13,11 +13,15 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Tuple
from twisted.web.server import Request
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import parse_integer from synapse.http.servlet import parse_integer
from synapse.replication.http._base import ReplicationEndpoint from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -57,10 +61,14 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
self.streams = hs.get_replication_streams() self.streams = hs.get_replication_streams()
@staticmethod @staticmethod
async def _serialize_payload(stream_name, from_token, upto_token): async def _serialize_payload( # type: ignore[override]
stream_name: str, from_token: int, upto_token: int
) -> JsonDict:
return {"from_token": from_token, "upto_token": upto_token} return {"from_token": from_token, "upto_token": upto_token}
async def _handle_request(self, request, stream_name): async def _handle_request( # type: ignore[override]
self, request: Request, stream_name: str
) -> Tuple[int, JsonDict]:
stream = self.streams.get(stream_name) stream = self.streams.get(stream_name)
if stream is None: if stream is None:
raise SynapseError(400, "Unknown stream") raise SynapseError(400, "Unknown stream")
@ -78,5 +86,5 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
) )
def register_servlets(hs: "HomeServer", http_server): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReplicationGetStreamUpdates(hs).register(http_server) ReplicationGetStreamUpdates(hs).register(http_server)