mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2024-12-11 22:44:19 -05:00
Add missing type hints to synapse.replication.http. (#11856)
This commit is contained in:
parent
8b309adb43
commit
63d90f10ec
1
changelog.d/11856.misc
Normal file
1
changelog.d/11856.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Add missing type hints to replication code.
|
@ -40,7 +40,7 @@ class ReplicationRestResource(JsonResource):
|
|||||||
super().__init__(hs, canonical_json=False, extract_context=True)
|
super().__init__(hs, canonical_json=False, extract_context=True)
|
||||||
self.register_servlets(hs)
|
self.register_servlets(hs)
|
||||||
|
|
||||||
def register_servlets(self, hs: "HomeServer"):
|
def register_servlets(self, hs: "HomeServer") -> None:
|
||||||
send_event.register_servlets(hs, self)
|
send_event.register_servlets(hs, self)
|
||||||
federation.register_servlets(hs, self)
|
federation.register_servlets(hs, self)
|
||||||
presence.register_servlets(hs, self)
|
presence.register_servlets(hs, self)
|
||||||
|
@ -15,16 +15,20 @@
|
|||||||
import abc
|
import abc
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import urllib
|
import urllib.parse
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Tuple
|
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Tuple
|
||||||
|
|
||||||
from prometheus_client import Counter, Gauge
|
from prometheus_client import Counter, Gauge
|
||||||
|
|
||||||
|
from twisted.web.server import Request
|
||||||
|
|
||||||
from synapse.api.errors import HttpResponseException, SynapseError
|
from synapse.api.errors import HttpResponseException, SynapseError
|
||||||
from synapse.http import RequestTimedOutError
|
from synapse.http import RequestTimedOutError
|
||||||
|
from synapse.http.server import HttpServer
|
||||||
from synapse.logging import opentracing
|
from synapse.logging import opentracing
|
||||||
from synapse.logging.opentracing import trace
|
from synapse.logging.opentracing import trace
|
||||||
|
from synapse.types import JsonDict
|
||||||
from synapse.util.caches.response_cache import ResponseCache
|
from synapse.util.caches.response_cache import ResponseCache
|
||||||
from synapse.util.stringutils import random_string
|
from synapse.util.stringutils import random_string
|
||||||
|
|
||||||
@ -113,10 +117,12 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
|
|||||||
if hs.config.worker.worker_replication_secret:
|
if hs.config.worker.worker_replication_secret:
|
||||||
self._replication_secret = hs.config.worker.worker_replication_secret
|
self._replication_secret = hs.config.worker.worker_replication_secret
|
||||||
|
|
||||||
def _check_auth(self, request) -> None:
|
def _check_auth(self, request: Request) -> None:
|
||||||
# Get the authorization header.
|
# Get the authorization header.
|
||||||
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
|
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
|
||||||
|
|
||||||
|
if not auth_headers:
|
||||||
|
raise RuntimeError("Missing Authorization header.")
|
||||||
if len(auth_headers) > 1:
|
if len(auth_headers) > 1:
|
||||||
raise RuntimeError("Too many Authorization headers.")
|
raise RuntimeError("Too many Authorization headers.")
|
||||||
parts = auth_headers[0].split(b" ")
|
parts = auth_headers[0].split(b" ")
|
||||||
@ -129,7 +135,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
|
|||||||
raise RuntimeError("Invalid Authorization header.")
|
raise RuntimeError("Invalid Authorization header.")
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def _serialize_payload(**kwargs):
|
async def _serialize_payload(**kwargs) -> JsonDict:
|
||||||
"""Static method that is called when creating a request.
|
"""Static method that is called when creating a request.
|
||||||
|
|
||||||
Concrete implementations should have explicit parameters (rather than
|
Concrete implementations should have explicit parameters (rather than
|
||||||
@ -144,19 +150,20 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def _handle_request(self, request, **kwargs):
|
async def _handle_request(
|
||||||
|
self, request: Request, **kwargs: Any
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
"""Handle incoming request.
|
"""Handle incoming request.
|
||||||
|
|
||||||
This is called with the request object and PATH_ARGS.
|
This is called with the request object and PATH_ARGS.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tuple[int, dict]: HTTP status code and a JSON serialisable dict
|
HTTP status code and a JSON serialisable dict to be used as response
|
||||||
to be used as response body of request.
|
body of request.
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def make_client(cls, hs: "HomeServer"):
|
def make_client(cls, hs: "HomeServer") -> Callable:
|
||||||
"""Create a client that makes requests.
|
"""Create a client that makes requests.
|
||||||
|
|
||||||
Returns a callable that accepts the same parameters as
|
Returns a callable that accepts the same parameters as
|
||||||
@ -182,7 +189,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@trace(opname="outgoing_replication_request")
|
@trace(opname="outgoing_replication_request")
|
||||||
async def send_request(*, instance_name="master", **kwargs):
|
async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any:
|
||||||
with outgoing_gauge.track_inprogress():
|
with outgoing_gauge.track_inprogress():
|
||||||
if instance_name == local_instance_name:
|
if instance_name == local_instance_name:
|
||||||
raise Exception("Trying to send HTTP request to self")
|
raise Exception("Trying to send HTTP request to self")
|
||||||
@ -268,7 +275,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
|
|||||||
|
|
||||||
return send_request
|
return send_request
|
||||||
|
|
||||||
def register(self, http_server):
|
def register(self, http_server: HttpServer) -> None:
|
||||||
"""Called by the server to register this as a handler to the
|
"""Called by the server to register this as a handler to the
|
||||||
appropriate path.
|
appropriate path.
|
||||||
"""
|
"""
|
||||||
@ -289,7 +296,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
|
|||||||
self.__class__.__name__,
|
self.__class__.__name__,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _check_auth_and_handle(self, request, **kwargs):
|
async def _check_auth_and_handle(
|
||||||
|
self, request: Request, **kwargs: Any
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
"""Called on new incoming requests when caching is enabled. Checks
|
"""Called on new incoming requests when caching is enabled. Checks
|
||||||
if there is a cached response for the request and returns that,
|
if there is a cached response for the request and returns that,
|
||||||
otherwise calls `_handle_request` and caches its response.
|
otherwise calls `_handle_request` and caches its response.
|
||||||
|
@ -13,10 +13,14 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, Tuple
|
||||||
|
|
||||||
|
from twisted.web.server import Request
|
||||||
|
|
||||||
|
from synapse.http.server import HttpServer
|
||||||
from synapse.http.servlet import parse_json_object_from_request
|
from synapse.http.servlet import parse_json_object_from_request
|
||||||
from synapse.replication.http._base import ReplicationEndpoint
|
from synapse.replication.http._base import ReplicationEndpoint
|
||||||
|
from synapse.types import JsonDict
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
@ -48,14 +52,18 @@ class ReplicationUserAccountDataRestServlet(ReplicationEndpoint):
|
|||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _serialize_payload(user_id, account_data_type, content):
|
async def _serialize_payload( # type: ignore[override]
|
||||||
|
user_id: str, account_data_type: str, content: JsonDict
|
||||||
|
) -> JsonDict:
|
||||||
payload = {
|
payload = {
|
||||||
"content": content,
|
"content": content,
|
||||||
}
|
}
|
||||||
|
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
async def _handle_request(self, request, user_id, account_data_type):
|
async def _handle_request( # type: ignore[override]
|
||||||
|
self, request: Request, user_id: str, account_data_type: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
content = parse_json_object_from_request(request)
|
content = parse_json_object_from_request(request)
|
||||||
|
|
||||||
max_stream_id = await self.handler.add_account_data_for_user(
|
max_stream_id = await self.handler.add_account_data_for_user(
|
||||||
@ -89,14 +97,18 @@ class ReplicationRoomAccountDataRestServlet(ReplicationEndpoint):
|
|||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _serialize_payload(user_id, room_id, account_data_type, content):
|
async def _serialize_payload( # type: ignore[override]
|
||||||
|
user_id: str, room_id: str, account_data_type: str, content: JsonDict
|
||||||
|
) -> JsonDict:
|
||||||
payload = {
|
payload = {
|
||||||
"content": content,
|
"content": content,
|
||||||
}
|
}
|
||||||
|
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
async def _handle_request(self, request, user_id, room_id, account_data_type):
|
async def _handle_request( # type: ignore[override]
|
||||||
|
self, request: Request, user_id: str, room_id: str, account_data_type: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
content = parse_json_object_from_request(request)
|
content = parse_json_object_from_request(request)
|
||||||
|
|
||||||
max_stream_id = await self.handler.add_account_data_to_room(
|
max_stream_id = await self.handler.add_account_data_to_room(
|
||||||
@ -130,14 +142,18 @@ class ReplicationAddTagRestServlet(ReplicationEndpoint):
|
|||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _serialize_payload(user_id, room_id, tag, content):
|
async def _serialize_payload( # type: ignore[override]
|
||||||
|
user_id: str, room_id: str, tag: str, content: JsonDict
|
||||||
|
) -> JsonDict:
|
||||||
payload = {
|
payload = {
|
||||||
"content": content,
|
"content": content,
|
||||||
}
|
}
|
||||||
|
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
async def _handle_request(self, request, user_id, room_id, tag):
|
async def _handle_request( # type: ignore[override]
|
||||||
|
self, request: Request, user_id: str, room_id: str, tag: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
content = parse_json_object_from_request(request)
|
content = parse_json_object_from_request(request)
|
||||||
|
|
||||||
max_stream_id = await self.handler.add_tag_to_room(
|
max_stream_id = await self.handler.add_tag_to_room(
|
||||||
@ -173,11 +189,13 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint):
|
|||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _serialize_payload(user_id, room_id, tag):
|
async def _serialize_payload(user_id: str, room_id: str, tag: str) -> JsonDict: # type: ignore[override]
|
||||||
|
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
async def _handle_request(self, request, user_id, room_id, tag):
|
async def _handle_request( # type: ignore[override]
|
||||||
|
self, request: Request, user_id: str, room_id: str, tag: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
max_stream_id = await self.handler.remove_tag_from_room(
|
max_stream_id = await self.handler.remove_tag_from_room(
|
||||||
user_id,
|
user_id,
|
||||||
room_id,
|
room_id,
|
||||||
@ -187,7 +205,7 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint):
|
|||||||
return 200, {"max_stream_id": max_stream_id}
|
return 200, {"max_stream_id": max_stream_id}
|
||||||
|
|
||||||
|
|
||||||
def register_servlets(hs: "HomeServer", http_server):
|
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||||
ReplicationUserAccountDataRestServlet(hs).register(http_server)
|
ReplicationUserAccountDataRestServlet(hs).register(http_server)
|
||||||
ReplicationRoomAccountDataRestServlet(hs).register(http_server)
|
ReplicationRoomAccountDataRestServlet(hs).register(http_server)
|
||||||
ReplicationAddTagRestServlet(hs).register(http_server)
|
ReplicationAddTagRestServlet(hs).register(http_server)
|
||||||
|
@ -13,9 +13,13 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, Tuple
|
||||||
|
|
||||||
|
from twisted.web.server import Request
|
||||||
|
|
||||||
|
from synapse.http.server import HttpServer
|
||||||
from synapse.replication.http._base import ReplicationEndpoint
|
from synapse.replication.http._base import ReplicationEndpoint
|
||||||
|
from synapse.types import JsonDict
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
@ -63,14 +67,16 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
|
|||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _serialize_payload(user_id):
|
async def _serialize_payload(user_id: str) -> JsonDict: # type: ignore[override]
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
async def _handle_request(self, request, user_id):
|
async def _handle_request( # type: ignore[override]
|
||||||
|
self, request: Request, user_id: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
user_devices = await self.device_list_updater.user_device_resync(user_id)
|
user_devices = await self.device_list_updater.user_device_resync(user_id)
|
||||||
|
|
||||||
return 200, user_devices
|
return 200, user_devices
|
||||||
|
|
||||||
|
|
||||||
def register_servlets(hs: "HomeServer", http_server):
|
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||||
ReplicationUserDevicesResyncRestServlet(hs).register(http_server)
|
ReplicationUserDevicesResyncRestServlet(hs).register(http_server)
|
||||||
|
@ -13,17 +13,22 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, List, Tuple
|
||||||
|
|
||||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
from twisted.web.server import Request
|
||||||
from synapse.events import make_event_from_dict
|
|
||||||
|
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
|
||||||
|
from synapse.events import EventBase, make_event_from_dict
|
||||||
from synapse.events.snapshot import EventContext
|
from synapse.events.snapshot import EventContext
|
||||||
|
from synapse.http.server import HttpServer
|
||||||
from synapse.http.servlet import parse_json_object_from_request
|
from synapse.http.servlet import parse_json_object_from_request
|
||||||
from synapse.replication.http._base import ReplicationEndpoint
|
from synapse.replication.http._base import ReplicationEndpoint
|
||||||
|
from synapse.types import JsonDict
|
||||||
from synapse.util.metrics import Measure
|
from synapse.util.metrics import Measure
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
from synapse.storage.databases.main import DataStore
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -69,14 +74,18 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
|
|||||||
self.federation_event_handler = hs.get_federation_event_handler()
|
self.federation_event_handler = hs.get_federation_event_handler()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _serialize_payload(store, room_id, event_and_contexts, backfilled):
|
async def _serialize_payload( # type: ignore[override]
|
||||||
|
store: "DataStore",
|
||||||
|
room_id: str,
|
||||||
|
event_and_contexts: List[Tuple[EventBase, EventContext]],
|
||||||
|
backfilled: bool,
|
||||||
|
) -> JsonDict:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
store
|
store
|
||||||
room_id (str)
|
room_id
|
||||||
event_and_contexts (list[tuple[FrozenEvent, EventContext]])
|
event_and_contexts
|
||||||
backfilled (bool): Whether or not the events are the result of
|
backfilled: Whether or not the events are the result of backfilling
|
||||||
backfilling
|
|
||||||
"""
|
"""
|
||||||
event_payloads = []
|
event_payloads = []
|
||||||
for event, context in event_and_contexts:
|
for event, context in event_and_contexts:
|
||||||
@ -102,7 +111,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
|
|||||||
|
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
async def _handle_request(self, request):
|
async def _handle_request(self, request: Request) -> Tuple[int, JsonDict]: # type: ignore[override]
|
||||||
with Measure(self.clock, "repl_fed_send_events_parse"):
|
with Measure(self.clock, "repl_fed_send_events_parse"):
|
||||||
content = parse_json_object_from_request(request)
|
content = parse_json_object_from_request(request)
|
||||||
|
|
||||||
@ -163,10 +172,14 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
|
|||||||
self.registry = hs.get_federation_registry()
|
self.registry = hs.get_federation_registry()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _serialize_payload(edu_type, origin, content):
|
async def _serialize_payload( # type: ignore[override]
|
||||||
|
edu_type: str, origin: str, content: JsonDict
|
||||||
|
) -> JsonDict:
|
||||||
return {"origin": origin, "content": content}
|
return {"origin": origin, "content": content}
|
||||||
|
|
||||||
async def _handle_request(self, request, edu_type):
|
async def _handle_request( # type: ignore[override]
|
||||||
|
self, request: Request, edu_type: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
with Measure(self.clock, "repl_fed_send_edu_parse"):
|
with Measure(self.clock, "repl_fed_send_edu_parse"):
|
||||||
content = parse_json_object_from_request(request)
|
content = parse_json_object_from_request(request)
|
||||||
|
|
||||||
@ -175,9 +188,9 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
|
|||||||
|
|
||||||
logger.info("Got %r edu from %s", edu_type, origin)
|
logger.info("Got %r edu from %s", edu_type, origin)
|
||||||
|
|
||||||
result = await self.registry.on_edu(edu_type, origin, edu_content)
|
await self.registry.on_edu(edu_type, origin, edu_content)
|
||||||
|
|
||||||
return 200, result
|
return 200, {}
|
||||||
|
|
||||||
|
|
||||||
class ReplicationGetQueryRestServlet(ReplicationEndpoint):
|
class ReplicationGetQueryRestServlet(ReplicationEndpoint):
|
||||||
@ -206,15 +219,17 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint):
|
|||||||
self.registry = hs.get_federation_registry()
|
self.registry = hs.get_federation_registry()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _serialize_payload(query_type, args):
|
async def _serialize_payload(query_type: str, args: JsonDict) -> JsonDict: # type: ignore[override]
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
query_type (str)
|
query_type
|
||||||
args (dict): The arguments received for the given query type
|
args: The arguments received for the given query type
|
||||||
"""
|
"""
|
||||||
return {"args": args}
|
return {"args": args}
|
||||||
|
|
||||||
async def _handle_request(self, request, query_type):
|
async def _handle_request( # type: ignore[override]
|
||||||
|
self, request: Request, query_type: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
with Measure(self.clock, "repl_fed_query_parse"):
|
with Measure(self.clock, "repl_fed_query_parse"):
|
||||||
content = parse_json_object_from_request(request)
|
content = parse_json_object_from_request(request)
|
||||||
|
|
||||||
@ -248,14 +263,16 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
|
|||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _serialize_payload(room_id, args):
|
async def _serialize_payload(room_id: str) -> JsonDict: # type: ignore[override]
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
room_id (str)
|
room_id
|
||||||
"""
|
"""
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
async def _handle_request(self, request, room_id):
|
async def _handle_request( # type: ignore[override]
|
||||||
|
self, request: Request, room_id: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
await self.store.clean_room_for_join(room_id)
|
await self.store.clean_room_for_join(room_id)
|
||||||
|
|
||||||
return 200, {}
|
return 200, {}
|
||||||
@ -283,17 +300,19 @@ class ReplicationStoreRoomOnOutlierMembershipRestServlet(ReplicationEndpoint):
|
|||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _serialize_payload(room_id, room_version):
|
async def _serialize_payload(room_id: str, room_version: RoomVersion) -> JsonDict: # type: ignore[override]
|
||||||
return {"room_version": room_version.identifier}
|
return {"room_version": room_version.identifier}
|
||||||
|
|
||||||
async def _handle_request(self, request, room_id):
|
async def _handle_request( # type: ignore[override]
|
||||||
|
self, request: Request, room_id: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
content = parse_json_object_from_request(request)
|
content = parse_json_object_from_request(request)
|
||||||
room_version = KNOWN_ROOM_VERSIONS[content["room_version"]]
|
room_version = KNOWN_ROOM_VERSIONS[content["room_version"]]
|
||||||
await self.store.maybe_store_room_on_outlier_membership(room_id, room_version)
|
await self.store.maybe_store_room_on_outlier_membership(room_id, room_version)
|
||||||
return 200, {}
|
return 200, {}
|
||||||
|
|
||||||
|
|
||||||
def register_servlets(hs: "HomeServer", http_server):
|
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||||
ReplicationFederationSendEventsRestServlet(hs).register(http_server)
|
ReplicationFederationSendEventsRestServlet(hs).register(http_server)
|
||||||
ReplicationFederationSendEduRestServlet(hs).register(http_server)
|
ReplicationFederationSendEduRestServlet(hs).register(http_server)
|
||||||
ReplicationGetQueryRestServlet(hs).register(http_server)
|
ReplicationGetQueryRestServlet(hs).register(http_server)
|
||||||
|
@ -13,10 +13,14 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, Optional, Tuple, cast
|
||||||
|
|
||||||
|
from twisted.web.server import Request
|
||||||
|
|
||||||
|
from synapse.http.server import HttpServer
|
||||||
from synapse.http.servlet import parse_json_object_from_request
|
from synapse.http.servlet import parse_json_object_from_request
|
||||||
from synapse.replication.http._base import ReplicationEndpoint
|
from synapse.replication.http._base import ReplicationEndpoint
|
||||||
|
from synapse.types import JsonDict
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
@ -39,25 +43,24 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
|
|||||||
self.registration_handler = hs.get_registration_handler()
|
self.registration_handler = hs.get_registration_handler()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _serialize_payload(
|
async def _serialize_payload( # type: ignore[override]
|
||||||
user_id,
|
user_id: str,
|
||||||
device_id,
|
device_id: Optional[str],
|
||||||
initial_display_name,
|
initial_display_name: Optional[str],
|
||||||
is_guest,
|
is_guest: bool,
|
||||||
is_appservice_ghost,
|
is_appservice_ghost: bool,
|
||||||
should_issue_refresh_token,
|
should_issue_refresh_token: bool,
|
||||||
auth_provider_id,
|
auth_provider_id: Optional[str],
|
||||||
auth_provider_session_id,
|
auth_provider_session_id: Optional[str],
|
||||||
):
|
) -> JsonDict:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
user_id (int)
|
user_id
|
||||||
device_id (str|None): Device ID to use, if None a new one is
|
device_id: Device ID to use, if None a new one is generated.
|
||||||
generated.
|
initial_display_name
|
||||||
initial_display_name (str|None)
|
is_guest
|
||||||
is_guest (bool)
|
is_appservice_ghost
|
||||||
is_appservice_ghost (bool)
|
should_issue_refresh_token
|
||||||
should_issue_refresh_token (bool)
|
|
||||||
"""
|
"""
|
||||||
return {
|
return {
|
||||||
"device_id": device_id,
|
"device_id": device_id,
|
||||||
@ -69,7 +72,9 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
|
|||||||
"auth_provider_session_id": auth_provider_session_id,
|
"auth_provider_session_id": auth_provider_session_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def _handle_request(self, request, user_id):
|
async def _handle_request( # type: ignore[override]
|
||||||
|
self, request: Request, user_id: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
content = parse_json_object_from_request(request)
|
content = parse_json_object_from_request(request)
|
||||||
|
|
||||||
device_id = content["device_id"]
|
device_id = content["device_id"]
|
||||||
@ -91,8 +96,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
|
|||||||
auth_provider_session_id=auth_provider_session_id,
|
auth_provider_session_id=auth_provider_session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
return 200, res
|
return 200, cast(JsonDict, res)
|
||||||
|
|
||||||
|
|
||||||
def register_servlets(hs: "HomeServer", http_server):
|
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||||
RegisterDeviceReplicationServlet(hs).register(http_server)
|
RegisterDeviceReplicationServlet(hs).register(http_server)
|
||||||
|
@ -16,6 +16,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple
|
|||||||
|
|
||||||
from twisted.web.server import Request
|
from twisted.web.server import Request
|
||||||
|
|
||||||
|
from synapse.http.server import HttpServer
|
||||||
from synapse.http.servlet import parse_json_object_from_request
|
from synapse.http.servlet import parse_json_object_from_request
|
||||||
from synapse.http.site import SynapseRequest
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.replication.http._base import ReplicationEndpoint
|
from synapse.replication.http._base import ReplicationEndpoint
|
||||||
@ -53,7 +54,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
|
|||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _serialize_payload( # type: ignore
|
async def _serialize_payload( # type: ignore[override]
|
||||||
requester: Requester,
|
requester: Requester,
|
||||||
room_id: str,
|
room_id: str,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
@ -77,7 +78,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
|
|||||||
"content": content,
|
"content": content,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def _handle_request( # type: ignore
|
async def _handle_request( # type: ignore[override]
|
||||||
self, request: SynapseRequest, room_id: str, user_id: str
|
self, request: SynapseRequest, room_id: str, user_id: str
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
content = parse_json_object_from_request(request)
|
content = parse_json_object_from_request(request)
|
||||||
@ -122,13 +123,13 @@ class ReplicationRemoteKnockRestServlet(ReplicationEndpoint):
|
|||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _serialize_payload( # type: ignore
|
async def _serialize_payload( # type: ignore[override]
|
||||||
requester: Requester,
|
requester: Requester,
|
||||||
room_id: str,
|
room_id: str,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
remote_room_hosts: List[str],
|
remote_room_hosts: List[str],
|
||||||
content: JsonDict,
|
content: JsonDict,
|
||||||
):
|
) -> JsonDict:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
requester: The user making the request, according to the access token.
|
requester: The user making the request, according to the access token.
|
||||||
@ -143,12 +144,12 @@ class ReplicationRemoteKnockRestServlet(ReplicationEndpoint):
|
|||||||
"content": content,
|
"content": content,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def _handle_request( # type: ignore
|
async def _handle_request( # type: ignore[override]
|
||||||
self,
|
self,
|
||||||
request: SynapseRequest,
|
request: SynapseRequest,
|
||||||
room_id: str,
|
room_id: str,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
):
|
) -> Tuple[int, JsonDict]:
|
||||||
content = parse_json_object_from_request(request)
|
content = parse_json_object_from_request(request)
|
||||||
|
|
||||||
remote_room_hosts = content["remote_room_hosts"]
|
remote_room_hosts = content["remote_room_hosts"]
|
||||||
@ -192,7 +193,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
|
|||||||
self.member_handler = hs.get_room_member_handler()
|
self.member_handler = hs.get_room_member_handler()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _serialize_payload( # type: ignore
|
async def _serialize_payload( # type: ignore[override]
|
||||||
invite_event_id: str,
|
invite_event_id: str,
|
||||||
txn_id: Optional[str],
|
txn_id: Optional[str],
|
||||||
requester: Requester,
|
requester: Requester,
|
||||||
@ -215,7 +216,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
|
|||||||
"content": content,
|
"content": content,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def _handle_request( # type: ignore
|
async def _handle_request( # type: ignore[override]
|
||||||
self, request: SynapseRequest, invite_event_id: str
|
self, request: SynapseRequest, invite_event_id: str
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
content = parse_json_object_from_request(request)
|
content = parse_json_object_from_request(request)
|
||||||
@ -262,12 +263,12 @@ class ReplicationRemoteRescindKnockRestServlet(ReplicationEndpoint):
|
|||||||
self.member_handler = hs.get_room_member_handler()
|
self.member_handler = hs.get_room_member_handler()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _serialize_payload( # type: ignore
|
async def _serialize_payload( # type: ignore[override]
|
||||||
knock_event_id: str,
|
knock_event_id: str,
|
||||||
txn_id: Optional[str],
|
txn_id: Optional[str],
|
||||||
requester: Requester,
|
requester: Requester,
|
||||||
content: JsonDict,
|
content: JsonDict,
|
||||||
):
|
) -> JsonDict:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
knock_event_id: The ID of the knock to be rescinded.
|
knock_event_id: The ID of the knock to be rescinded.
|
||||||
@ -281,11 +282,11 @@ class ReplicationRemoteRescindKnockRestServlet(ReplicationEndpoint):
|
|||||||
"content": content,
|
"content": content,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def _handle_request( # type: ignore
|
async def _handle_request( # type: ignore[override]
|
||||||
self,
|
self,
|
||||||
request: SynapseRequest,
|
request: SynapseRequest,
|
||||||
knock_event_id: str,
|
knock_event_id: str,
|
||||||
):
|
) -> Tuple[int, JsonDict]:
|
||||||
content = parse_json_object_from_request(request)
|
content = parse_json_object_from_request(request)
|
||||||
|
|
||||||
txn_id = content["txn_id"]
|
txn_id = content["txn_id"]
|
||||||
@ -329,7 +330,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
|
|||||||
self.distributor = hs.get_distributor()
|
self.distributor = hs.get_distributor()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _serialize_payload( # type: ignore
|
async def _serialize_payload( # type: ignore[override]
|
||||||
room_id: str, user_id: str, change: str
|
room_id: str, user_id: str, change: str
|
||||||
) -> JsonDict:
|
) -> JsonDict:
|
||||||
"""
|
"""
|
||||||
@ -345,7 +346,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
|
|||||||
|
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
async def _handle_request( # type: ignore
|
async def _handle_request( # type: ignore[override]
|
||||||
self, request: Request, room_id: str, user_id: str, change: str
|
self, request: Request, room_id: str, user_id: str, change: str
|
||||||
) -> Tuple[int, JsonDict]:
|
) -> Tuple[int, JsonDict]:
|
||||||
logger.info("user membership change: %s in %s", user_id, room_id)
|
logger.info("user membership change: %s in %s", user_id, room_id)
|
||||||
@ -360,7 +361,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
|
|||||||
return 200, {}
|
return 200, {}
|
||||||
|
|
||||||
|
|
||||||
def register_servlets(hs: "HomeServer", http_server):
|
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||||
ReplicationRemoteJoinRestServlet(hs).register(http_server)
|
ReplicationRemoteJoinRestServlet(hs).register(http_server)
|
||||||
ReplicationRemoteRejectInviteRestServlet(hs).register(http_server)
|
ReplicationRemoteRejectInviteRestServlet(hs).register(http_server)
|
||||||
ReplicationUserJoinedLeftRoomRestServlet(hs).register(http_server)
|
ReplicationUserJoinedLeftRoomRestServlet(hs).register(http_server)
|
||||||
|
@ -13,11 +13,14 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, Tuple
|
||||||
|
|
||||||
|
from twisted.web.server import Request
|
||||||
|
|
||||||
|
from synapse.http.server import HttpServer
|
||||||
from synapse.http.servlet import parse_json_object_from_request
|
from synapse.http.servlet import parse_json_object_from_request
|
||||||
from synapse.replication.http._base import ReplicationEndpoint
|
from synapse.replication.http._base import ReplicationEndpoint
|
||||||
from synapse.types import UserID
|
from synapse.types import JsonDict, UserID
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
@ -49,18 +52,17 @@ class ReplicationBumpPresenceActiveTime(ReplicationEndpoint):
|
|||||||
self._presence_handler = hs.get_presence_handler()
|
self._presence_handler = hs.get_presence_handler()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _serialize_payload(user_id):
|
async def _serialize_payload(user_id: str) -> JsonDict: # type: ignore[override]
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
async def _handle_request(self, request, user_id):
|
async def _handle_request( # type: ignore[override]
|
||||||
|
self, request: Request, user_id: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
await self._presence_handler.bump_presence_active_time(
|
await self._presence_handler.bump_presence_active_time(
|
||||||
UserID.from_string(user_id)
|
UserID.from_string(user_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
return (
|
return (200, {})
|
||||||
200,
|
|
||||||
{},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ReplicationPresenceSetState(ReplicationEndpoint):
|
class ReplicationPresenceSetState(ReplicationEndpoint):
|
||||||
@ -92,16 +94,21 @@ class ReplicationPresenceSetState(ReplicationEndpoint):
|
|||||||
self._presence_handler = hs.get_presence_handler()
|
self._presence_handler = hs.get_presence_handler()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _serialize_payload(
|
async def _serialize_payload( # type: ignore[override]
|
||||||
user_id, state, ignore_status_msg=False, force_notify=False
|
user_id: str,
|
||||||
):
|
state: JsonDict,
|
||||||
|
ignore_status_msg: bool = False,
|
||||||
|
force_notify: bool = False,
|
||||||
|
) -> JsonDict:
|
||||||
return {
|
return {
|
||||||
"state": state,
|
"state": state,
|
||||||
"ignore_status_msg": ignore_status_msg,
|
"ignore_status_msg": ignore_status_msg,
|
||||||
"force_notify": force_notify,
|
"force_notify": force_notify,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def _handle_request(self, request, user_id):
|
async def _handle_request( # type: ignore[override]
|
||||||
|
self, request: Request, user_id: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
content = parse_json_object_from_request(request)
|
content = parse_json_object_from_request(request)
|
||||||
|
|
||||||
await self._presence_handler.set_state(
|
await self._presence_handler.set_state(
|
||||||
@ -111,12 +118,9 @@ class ReplicationPresenceSetState(ReplicationEndpoint):
|
|||||||
content["force_notify"],
|
content["force_notify"],
|
||||||
)
|
)
|
||||||
|
|
||||||
return (
|
return (200, {})
|
||||||
200,
|
|
||||||
{},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def register_servlets(hs: "HomeServer", http_server):
|
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||||
ReplicationBumpPresenceActiveTime(hs).register(http_server)
|
ReplicationBumpPresenceActiveTime(hs).register(http_server)
|
||||||
ReplicationPresenceSetState(hs).register(http_server)
|
ReplicationPresenceSetState(hs).register(http_server)
|
||||||
|
@ -13,10 +13,14 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, Tuple
|
||||||
|
|
||||||
|
from twisted.web.server import Request
|
||||||
|
|
||||||
|
from synapse.http.server import HttpServer
|
||||||
from synapse.http.servlet import parse_json_object_from_request
|
from synapse.http.servlet import parse_json_object_from_request
|
||||||
from synapse.replication.http._base import ReplicationEndpoint
|
from synapse.replication.http._base import ReplicationEndpoint
|
||||||
|
from synapse.types import JsonDict
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
@ -48,7 +52,7 @@ class ReplicationRemovePusherRestServlet(ReplicationEndpoint):
|
|||||||
self.pusher_pool = hs.get_pusherpool()
|
self.pusher_pool = hs.get_pusherpool()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _serialize_payload(app_id, pushkey, user_id):
|
async def _serialize_payload(app_id: str, pushkey: str, user_id: str) -> JsonDict: # type: ignore[override]
|
||||||
payload = {
|
payload = {
|
||||||
"app_id": app_id,
|
"app_id": app_id,
|
||||||
"pushkey": pushkey,
|
"pushkey": pushkey,
|
||||||
@ -56,7 +60,9 @@ class ReplicationRemovePusherRestServlet(ReplicationEndpoint):
|
|||||||
|
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
async def _handle_request(self, request, user_id):
|
async def _handle_request( # type: ignore[override]
|
||||||
|
self, request: Request, user_id: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
content = parse_json_object_from_request(request)
|
content = parse_json_object_from_request(request)
|
||||||
|
|
||||||
app_id = content["app_id"]
|
app_id = content["app_id"]
|
||||||
@ -67,5 +73,5 @@ class ReplicationRemovePusherRestServlet(ReplicationEndpoint):
|
|||||||
return 200, {}
|
return 200, {}
|
||||||
|
|
||||||
|
|
||||||
def register_servlets(hs: "HomeServer", http_server):
|
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||||
ReplicationRemovePusherRestServlet(hs).register(http_server)
|
ReplicationRemovePusherRestServlet(hs).register(http_server)
|
||||||
|
@ -13,10 +13,14 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, Optional, Tuple
|
||||||
|
|
||||||
|
from twisted.web.server import Request
|
||||||
|
|
||||||
|
from synapse.http.server import HttpServer
|
||||||
from synapse.http.servlet import parse_json_object_from_request
|
from synapse.http.servlet import parse_json_object_from_request
|
||||||
from synapse.replication.http._base import ReplicationEndpoint
|
from synapse.replication.http._base import ReplicationEndpoint
|
||||||
|
from synapse.types import JsonDict
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
@ -36,34 +40,34 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
|
|||||||
self.registration_handler = hs.get_registration_handler()
|
self.registration_handler = hs.get_registration_handler()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _serialize_payload(
|
async def _serialize_payload( # type: ignore[override]
|
||||||
user_id,
|
user_id: str,
|
||||||
password_hash,
|
password_hash: Optional[str],
|
||||||
was_guest,
|
was_guest: bool,
|
||||||
make_guest,
|
make_guest: bool,
|
||||||
appservice_id,
|
appservice_id: Optional[str],
|
||||||
create_profile_with_displayname,
|
create_profile_with_displayname: Optional[str],
|
||||||
admin,
|
admin: bool,
|
||||||
user_type,
|
user_type: Optional[str],
|
||||||
address,
|
address: Optional[str],
|
||||||
shadow_banned,
|
shadow_banned: bool,
|
||||||
):
|
) -> JsonDict:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
user_id (str): The desired user ID to register.
|
user_id: The desired user ID to register.
|
||||||
password_hash (str|None): Optional. The password hash for this user.
|
password_hash: Optional. The password hash for this user.
|
||||||
was_guest (bool): Optional. Whether this is a guest account being
|
was_guest: Optional. Whether this is a guest account being upgraded
|
||||||
upgraded to a non-guest account.
|
to a non-guest account.
|
||||||
make_guest (boolean): True if the the new user should be guest,
|
make_guest: True if the the new user should be guest, false to add a
|
||||||
false to add a regular user account.
|
regular user account.
|
||||||
appservice_id (str|None): The ID of the appservice registering the user.
|
appservice_id: The ID of the appservice registering the user.
|
||||||
create_profile_with_displayname (unicode|None): Optionally create a
|
create_profile_with_displayname: Optionally create a profile for the
|
||||||
profile for the user, setting their displayname to the given value
|
user, setting their displayname to the given value
|
||||||
admin (boolean): is an admin user?
|
admin: is an admin user?
|
||||||
user_type (str|None): type of user. One of the values from
|
user_type: type of user. One of the values from api.constants.UserTypes,
|
||||||
api.constants.UserTypes, or None for a normal user.
|
or None for a normal user.
|
||||||
address (str|None): the IP address used to perform the regitration.
|
address: the IP address used to perform the regitration.
|
||||||
shadow_banned (bool): Whether to shadow-ban the user
|
shadow_banned: Whether to shadow-ban the user
|
||||||
"""
|
"""
|
||||||
return {
|
return {
|
||||||
"password_hash": password_hash,
|
"password_hash": password_hash,
|
||||||
@ -77,7 +81,9 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
|
|||||||
"shadow_banned": shadow_banned,
|
"shadow_banned": shadow_banned,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def _handle_request(self, request, user_id):
|
async def _handle_request( # type: ignore[override]
|
||||||
|
self, request: Request, user_id: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
content = parse_json_object_from_request(request)
|
content = parse_json_object_from_request(request)
|
||||||
|
|
||||||
await self.registration_handler.check_registration_ratelimit(content["address"])
|
await self.registration_handler.check_registration_ratelimit(content["address"])
|
||||||
@ -110,18 +116,21 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint):
|
|||||||
self.registration_handler = hs.get_registration_handler()
|
self.registration_handler = hs.get_registration_handler()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _serialize_payload(user_id, auth_result, access_token):
|
async def _serialize_payload( # type: ignore[override]
|
||||||
|
user_id: str, auth_result: JsonDict, access_token: Optional[str]
|
||||||
|
) -> JsonDict:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
user_id (str): The user ID that consented
|
user_id: The user ID that consented
|
||||||
auth_result (dict): The authenticated credentials of the newly
|
auth_result: The authenticated credentials of the newly registered user.
|
||||||
registered user.
|
access_token: The access token of the newly logged in
|
||||||
access_token (str|None): The access token of the newly logged in
|
|
||||||
device, or None if `inhibit_login` enabled.
|
device, or None if `inhibit_login` enabled.
|
||||||
"""
|
"""
|
||||||
return {"auth_result": auth_result, "access_token": access_token}
|
return {"auth_result": auth_result, "access_token": access_token}
|
||||||
|
|
||||||
async def _handle_request(self, request, user_id):
|
async def _handle_request( # type: ignore[override]
|
||||||
|
self, request: Request, user_id: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
content = parse_json_object_from_request(request)
|
content = parse_json_object_from_request(request)
|
||||||
|
|
||||||
auth_result = content["auth_result"]
|
auth_result = content["auth_result"]
|
||||||
@ -134,6 +143,6 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint):
|
|||||||
return 200, {}
|
return 200, {}
|
||||||
|
|
||||||
|
|
||||||
def register_servlets(hs: "HomeServer", http_server):
|
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||||
ReplicationRegisterServlet(hs).register(http_server)
|
ReplicationRegisterServlet(hs).register(http_server)
|
||||||
ReplicationPostRegisterActionsServlet(hs).register(http_server)
|
ReplicationPostRegisterActionsServlet(hs).register(http_server)
|
||||||
|
@ -13,18 +13,22 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, List, Tuple
|
||||||
|
|
||||||
|
from twisted.web.server import Request
|
||||||
|
|
||||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||||
from synapse.events import make_event_from_dict
|
from synapse.events import EventBase, make_event_from_dict
|
||||||
from synapse.events.snapshot import EventContext
|
from synapse.events.snapshot import EventContext
|
||||||
|
from synapse.http.server import HttpServer
|
||||||
from synapse.http.servlet import parse_json_object_from_request
|
from synapse.http.servlet import parse_json_object_from_request
|
||||||
from synapse.replication.http._base import ReplicationEndpoint
|
from synapse.replication.http._base import ReplicationEndpoint
|
||||||
from synapse.types import Requester, UserID
|
from synapse.types import JsonDict, Requester, UserID
|
||||||
from synapse.util.metrics import Measure
|
from synapse.util.metrics import Measure
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
from synapse.storage.databases.main import DataStore
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -70,18 +74,24 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
|
|||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _serialize_payload(
|
async def _serialize_payload( # type: ignore[override]
|
||||||
event_id, store, event, context, requester, ratelimit, extra_users
|
event_id: str,
|
||||||
):
|
store: "DataStore",
|
||||||
|
event: EventBase,
|
||||||
|
context: EventContext,
|
||||||
|
requester: Requester,
|
||||||
|
ratelimit: bool,
|
||||||
|
extra_users: List[UserID],
|
||||||
|
) -> JsonDict:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
event_id (str)
|
event_id
|
||||||
store (DataStore)
|
store
|
||||||
requester (Requester)
|
requester
|
||||||
event (FrozenEvent)
|
event
|
||||||
context (EventContext)
|
context
|
||||||
ratelimit (bool)
|
ratelimit
|
||||||
extra_users (list(UserID)): Any extra users to notify about event
|
extra_users: Any extra users to notify about event
|
||||||
"""
|
"""
|
||||||
serialized_context = await context.serialize(event, store)
|
serialized_context = await context.serialize(event, store)
|
||||||
|
|
||||||
@ -100,7 +110,9 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
|
|||||||
|
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
async def _handle_request(self, request, event_id):
|
async def _handle_request( # type: ignore[override]
|
||||||
|
self, request: Request, event_id: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
with Measure(self.clock, "repl_send_event_parse"):
|
with Measure(self.clock, "repl_send_event_parse"):
|
||||||
content = parse_json_object_from_request(request)
|
content = parse_json_object_from_request(request)
|
||||||
|
|
||||||
@ -120,8 +132,6 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
|
|||||||
ratelimit = content["ratelimit"]
|
ratelimit = content["ratelimit"]
|
||||||
extra_users = [UserID.from_string(u) for u in content["extra_users"]]
|
extra_users = [UserID.from_string(u) for u in content["extra_users"]]
|
||||||
|
|
||||||
request.requester = requester
|
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Got event to send with ID: %s into room: %s", event.event_id, event.room_id
|
"Got event to send with ID: %s into room: %s", event.event_id, event.room_id
|
||||||
)
|
)
|
||||||
@ -139,5 +149,5 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def register_servlets(hs: "HomeServer", http_server):
|
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||||
ReplicationSendEventRestServlet(hs).register(http_server)
|
ReplicationSendEventRestServlet(hs).register(http_server)
|
||||||
|
@ -13,11 +13,15 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, Tuple
|
||||||
|
|
||||||
|
from twisted.web.server import Request
|
||||||
|
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
|
from synapse.http.server import HttpServer
|
||||||
from synapse.http.servlet import parse_integer
|
from synapse.http.servlet import parse_integer
|
||||||
from synapse.replication.http._base import ReplicationEndpoint
|
from synapse.replication.http._base import ReplicationEndpoint
|
||||||
|
from synapse.types import JsonDict
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
@ -57,10 +61,14 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
|
|||||||
self.streams = hs.get_replication_streams()
|
self.streams = hs.get_replication_streams()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _serialize_payload(stream_name, from_token, upto_token):
|
async def _serialize_payload( # type: ignore[override]
|
||||||
|
stream_name: str, from_token: int, upto_token: int
|
||||||
|
) -> JsonDict:
|
||||||
return {"from_token": from_token, "upto_token": upto_token}
|
return {"from_token": from_token, "upto_token": upto_token}
|
||||||
|
|
||||||
async def _handle_request(self, request, stream_name):
|
async def _handle_request( # type: ignore[override]
|
||||||
|
self, request: Request, stream_name: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
stream = self.streams.get(stream_name)
|
stream = self.streams.get(stream_name)
|
||||||
if stream is None:
|
if stream is None:
|
||||||
raise SynapseError(400, "Unknown stream")
|
raise SynapseError(400, "Unknown stream")
|
||||||
@ -78,5 +86,5 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def register_servlets(hs: "HomeServer", http_server):
|
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||||
ReplicationGetStreamUpdates(hs).register(http_server)
|
ReplicationGetStreamUpdates(hs).register(http_server)
|
||||||
|
Loading…
Reference in New Issue
Block a user