Add a primitive helper script for listing worker endpoints. (#15243)

Co-authored-by: Patrick Cloke <patrickc@matrix.org>
This commit is contained in:
reivilibre 2023-03-23 12:11:14 +00:00 committed by GitHub
parent 3b0083c92a
commit 98fd558382
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
31 changed files with 424 additions and 12 deletions

View file

@ -43,19 +43,22 @@ def client_patterns(
Returns:
An iterable of patterns.
"""
patterns = []
versions = []
if unstable:
unstable_prefix = CLIENT_API_PREFIX + "/unstable"
patterns.append(re.compile("^" + unstable_prefix + path_regex))
if v1:
v1_prefix = CLIENT_API_PREFIX + "/api/v1"
patterns.append(re.compile("^" + v1_prefix + path_regex))
for release in releases:
new_prefix = CLIENT_API_PREFIX + f"/{release}"
patterns.append(re.compile("^" + new_prefix + path_regex))
versions.append("api/v1")
versions.extend(releases)
if unstable:
versions.append("unstable")
return patterns
if len(versions) == 1:
versions_str = versions[0]
elif len(versions) > 1:
versions_str = "(" + "|".join(versions) + ")"
else:
raise RuntimeError("Must have at least one version for a URL")
return [re.compile("^" + CLIENT_API_PREFIX + "/" + versions_str + path_regex)]
def set_timeline_upper_limit(filter_json: JsonDict, filter_timeline_limit: int) -> None:

View file

@ -576,6 +576,9 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet):
class ThreepidRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid$")
# This is used as a proxy for all the 3pid endpoints.
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -834,6 +837,7 @@ def assert_valid_next_link(hs: "HomeServer", next_link: str) -> None:
class WhoamiRestServlet(RestServlet):
PATTERNS = client_patterns("/account/whoami$")
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"):
super().__init__()

View file

@ -38,6 +38,7 @@ class AccountDataServlet(RestServlet):
PATTERNS = client_patterns(
"/user/(?P<user_id>[^/]*)/account_data/(?P<account_data_type>[^/]*)"
)
CATEGORY = "Account data requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -136,6 +137,7 @@ class RoomAccountDataServlet(RestServlet):
"/rooms/(?P<room_id>[^/]*)"
"/account_data/(?P<account_data_type>[^/]*)"
)
CATEGORY = "Account data requests"
def __init__(self, hs: "HomeServer"):
super().__init__()

View file

@ -40,6 +40,7 @@ logger = logging.getLogger(__name__)
class DevicesRestServlet(RestServlet):
PATTERNS = client_patterns("/devices$")
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -123,6 +124,7 @@ class DeleteDevicesRestServlet(RestServlet):
class DeviceRestServlet(RestServlet):
PATTERNS = client_patterns("/devices/(?P<device_id>[^/]*)$")
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"):
super().__init__()

View file

@ -33,6 +33,7 @@ logger = logging.getLogger(__name__)
class EventStreamRestServlet(RestServlet):
PATTERNS = client_patterns("/events$", v1=True)
CATEGORY = "Sync requests"
DEFAULT_LONGPOLL_TIME_MS = 30000
@ -76,6 +77,7 @@ class EventStreamRestServlet(RestServlet):
class EventRestServlet(RestServlet):
PATTERNS = client_patterns("/events/(?P<event_id>[^/]*)$", v1=True)
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"):
super().__init__()

View file

@ -31,6 +31,7 @@ logger = logging.getLogger(__name__)
class GetFilterRestServlet(RestServlet):
PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)")
CATEGORY = "Encryption requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -69,6 +70,7 @@ class GetFilterRestServlet(RestServlet):
class CreateFilterRestServlet(RestServlet):
PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/filter")
CATEGORY = "Encryption requests"
def __init__(self, hs: "HomeServer"):
super().__init__()

View file

@ -28,6 +28,7 @@ if TYPE_CHECKING:
# TODO: Needs unit testing
class InitialSyncRestServlet(RestServlet):
PATTERNS = client_patterns("/initialSync$", v1=True)
CATEGORY = "Sync requests"
def __init__(self, hs: "HomeServer"):
super().__init__()

View file

@ -89,6 +89,7 @@ class KeyUploadServlet(RestServlet):
"""
PATTERNS = client_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
CATEGORY = "Encryption requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -182,6 +183,7 @@ class KeyQueryServlet(RestServlet):
"""
PATTERNS = client_patterns("/keys/query$")
CATEGORY = "Encryption requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -225,6 +227,7 @@ class KeyChangesServlet(RestServlet):
"""
PATTERNS = client_patterns("/keys/changes$")
CATEGORY = "Encryption requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -274,6 +277,7 @@ class OneTimeKeyServlet(RestServlet):
"""
PATTERNS = client_patterns("/keys/claim$")
CATEGORY = "Encryption requests"
def __init__(self, hs: "HomeServer"):
super().__init__()

View file

@ -40,6 +40,7 @@ class KnockRoomAliasServlet(RestServlet):
"""
PATTERNS = client_patterns("/knock/(?P<room_identifier>[^/]*)")
CATEGORY = "Event sending requests"
def __init__(self, hs: "HomeServer"):
super().__init__()

View file

@ -72,6 +72,8 @@ class LoginResponse(TypedDict, total=False):
class LoginRestServlet(RestServlet):
PATTERNS = client_patterns("/login$", v1=True)
CATEGORY = "Registration/login requests"
CAS_TYPE = "m.login.cas"
SSO_TYPE = "m.login.sso"
TOKEN_TYPE = "m.login.token"
@ -537,6 +539,7 @@ def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict:
class RefreshTokenServlet(RestServlet):
PATTERNS = client_patterns("/refresh$")
CATEGORY = "Registration/login requests"
def __init__(self, hs: "HomeServer"):
self._auth_handler = hs.get_auth_handler()
@ -590,6 +593,7 @@ class SsoRedirectServlet(RestServlet):
+ "/(r0|v3)/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$"
)
]
CATEGORY = "SSO requests needed for all SSO providers"
def __init__(self, hs: "HomeServer"):
# make sure that the relevant handlers are instantiated, so that they

View file

@ -33,6 +33,7 @@ logger = logging.getLogger(__name__)
class PresenceStatusRestServlet(RestServlet):
PATTERNS = client_patterns("/presence/(?P<user_id>[^/]*)/status", v1=True)
CATEGORY = "Presence requests"
def __init__(self, hs: "HomeServer"):
super().__init__()

View file

@ -29,6 +29,7 @@ if TYPE_CHECKING:
class ProfileDisplaynameRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/displayname", v1=True)
CATEGORY = "Event sending requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -86,6 +87,7 @@ class ProfileDisplaynameRestServlet(RestServlet):
class ProfileAvatarURLRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/avatar_url", v1=True)
CATEGORY = "Event sending requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -142,6 +144,7 @@ class ProfileAvatarURLRestServlet(RestServlet):
class ProfileRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)", v1=True)
CATEGORY = "Event sending requests"
def __init__(self, hs: "HomeServer"):
super().__init__()

View file

@ -44,6 +44,9 @@ class PushRuleRestServlet(RestServlet):
"Unrecognised request: You probably wanted a trailing slash"
)
WORKERS_DENIED_METHODS = ["PUT", "DELETE"]
CATEGORY = "Push rule requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()

View file

@ -31,6 +31,7 @@ logger = logging.getLogger(__name__)
class ReadMarkerRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/read_markers$")
CATEGORY = "Receipts requests"
def __init__(self, hs: "HomeServer"):
super().__init__()

View file

@ -36,6 +36,7 @@ class ReceiptRestServlet(RestServlet):
"/receipt/(?P<receipt_type>[^/]*)"
"/(?P<event_id>[^/]*)$"
)
CATEGORY = "Receipts requests"
def __init__(self, hs: "HomeServer"):
super().__init__()

View file

@ -367,6 +367,7 @@ class RegistrationTokenValidityRestServlet(RestServlet):
f"/register/{LoginType.REGISTRATION_TOKEN}/validity",
releases=("v1",),
)
CATEGORY = "Registration/login requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -395,6 +396,7 @@ class RegistrationTokenValidityRestServlet(RestServlet):
class RegisterRestServlet(RestServlet):
PATTERNS = client_patterns("/register$")
CATEGORY = "Registration/login requests"
def __init__(self, hs: "HomeServer"):
super().__init__()

View file

@ -42,6 +42,7 @@ class RelationPaginationServlet(RestServlet):
"(/(?P<relation_type>[^/]*)(/(?P<event_type>[^/]*))?)?$",
releases=("v1",),
)
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -84,6 +85,7 @@ class RelationPaginationServlet(RestServlet):
class ThreadsServlet(RestServlet):
PATTERNS = (re.compile("^/_matrix/client/v1/rooms/(?P<room_id>[^/]*)/threads"),)
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"):
super().__init__()

View file

@ -140,7 +140,7 @@ class TransactionRestServlet(RestServlet):
class RoomCreateRestServlet(TransactionRestServlet):
# No PATTERN; we have custom dispatch rules here
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
@ -180,6 +180,8 @@ class RoomCreateRestServlet(TransactionRestServlet):
# TODO: Needs unit testing for generic events
class RoomStateEventRestServlet(RestServlet):
CATEGORY = "Event sending requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
self.event_creation_handler = hs.get_event_creation_handler()
@ -323,6 +325,8 @@ class RoomStateEventRestServlet(RestServlet):
# TODO: Needs unit testing for generic events + feedback
class RoomSendEventRestServlet(TransactionRestServlet):
CATEGORY = "Event sending requests"
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.event_creation_handler = hs.get_event_creation_handler()
@ -398,6 +402,8 @@ class RoomSendEventRestServlet(TransactionRestServlet):
# TODO: Needs unit testing for room ID + alias joins
class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
CATEGORY = "Event sending requests"
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
super(ResolveRoomIdMixin, self).__init__(hs) # ensure the Mixin is set up
@ -460,6 +466,7 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
# TODO: Needs unit testing
class PublicRoomListRestServlet(RestServlet):
PATTERNS = client_patterns("/publicRooms$", v1=True)
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -578,6 +585,7 @@ class PublicRoomListRestServlet(RestServlet):
# TODO: Needs unit testing
class RoomMemberListRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/members$", v1=True)
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -633,6 +641,7 @@ class RoomMemberListRestServlet(RestServlet):
# except it does custom AS logic and has a simpler return format
class JoinedRoomMemberListRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/joined_members$", v1=True)
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -654,6 +663,10 @@ class JoinedRoomMemberListRestServlet(RestServlet):
# TODO: Needs better unit testing
class RoomMessageListRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/messages$", v1=True)
# TODO The routing information should be exposed programatically.
# I want to do this but for now I felt bad about leaving this without
# at least a visible warning on it.
CATEGORY = "Client API requests (ALL FOR SAME ROOM MUST GO TO SAME WORKER)"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -720,6 +733,7 @@ class RoomMessageListRestServlet(RestServlet):
# TODO: Needs unit testing
class RoomStateRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/state$", v1=True)
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -742,6 +756,7 @@ class RoomStateRestServlet(RestServlet):
# TODO: Needs unit testing
class RoomInitialSyncRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$", v1=True)
CATEGORY = "Sync requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -766,6 +781,7 @@ class RoomEventServlet(RestServlet):
PATTERNS = client_patterns(
"/rooms/(?P<room_id>[^/]*)/event/(?P<event_id>[^/]*)$", v1=True
)
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -858,6 +874,7 @@ class RoomEventContextServlet(RestServlet):
PATTERNS = client_patterns(
"/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$", v1=True
)
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -958,6 +975,8 @@ class RoomForgetRestServlet(TransactionRestServlet):
# TODO: Needs unit testing
class RoomMembershipRestServlet(TransactionRestServlet):
CATEGORY = "Event sending requests"
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.room_member_handler = hs.get_room_member_handler()
@ -1071,6 +1090,8 @@ class RoomMembershipRestServlet(TransactionRestServlet):
class RoomRedactEventRestServlet(TransactionRestServlet):
CATEGORY = "Event sending requests"
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.event_creation_handler = hs.get_event_creation_handler()
@ -1164,6 +1185,7 @@ class RoomTypingRestServlet(RestServlet):
PATTERNS = client_patterns(
"/rooms/(?P<room_id>[^/]*)/typing/(?P<user_id>[^/]*)$", v1=True
)
CATEGORY = "The typing stream"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -1195,7 +1217,7 @@ class RoomTypingRestServlet(RestServlet):
# Limit timeout to stop people from setting silly typing timeouts.
timeout = min(content.get("timeout", 30000), 120000)
# Defer getting the typing handler since it will raise on workers.
# Defer getting the typing handler since it will raise on WORKER_PATTERNS.
typing_handler = self.hs.get_typing_writer_handler()
try:
@ -1224,6 +1246,7 @@ class RoomAliasListServlet(RestServlet):
r"/rooms/(?P<room_id>[^/]*)/aliases"
),
] + list(client_patterns("/rooms/(?P<room_id>[^/]*)/aliases$", unstable=False))
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -1244,6 +1267,7 @@ class RoomAliasListServlet(RestServlet):
class SearchRestServlet(RestServlet):
PATTERNS = client_patterns("/search$", v1=True)
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -1263,6 +1287,7 @@ class SearchRestServlet(RestServlet):
class JoinedRoomsRestServlet(RestServlet):
PATTERNS = client_patterns("/joined_rooms$", v1=True)
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -1334,6 +1359,7 @@ class TimestampLookupRestServlet(RestServlet):
PATTERNS = (
re.compile("^/_matrix/client/v1/rooms/(?P<room_id>[^/]*)/timestamp_to_event$"),
)
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -1365,6 +1391,8 @@ class TimestampLookupRestServlet(RestServlet):
class RoomHierarchyRestServlet(RestServlet):
PATTERNS = (re.compile("^/_matrix/client/v1/rooms/(?P<room_id>[^/]*)/hierarchy$"),)
WORKERS = PATTERNS
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -1405,6 +1433,7 @@ class RoomSummaryRestServlet(ResolveRoomIdMixin, RestServlet):
"/rooms/(?P<room_identifier>[^/]*)/summary$"
),
)
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"):
super().__init__(hs)

View file

@ -69,6 +69,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
"/rooms/(?P<room_id>[^/]*)/batch_send$"
),
)
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"):
super().__init__()

View file

@ -37,6 +37,7 @@ class RoomKeysServlet(RestServlet):
PATTERNS = client_patterns(
"/room_keys/keys(/(?P<room_id>[^/]+))?(/(?P<session_id>[^/]+))?$"
)
CATEGORY = "Encryption requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -253,6 +254,7 @@ class RoomKeysServlet(RestServlet):
class RoomKeysNewVersionServlet(RestServlet):
PATTERNS = client_patterns("/room_keys/version$")
CATEGORY = "Encryption requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -328,6 +330,7 @@ class RoomKeysNewVersionServlet(RestServlet):
class RoomKeysVersionServlet(RestServlet):
PATTERNS = client_patterns("/room_keys/version/(?P<version>[^/]+)$")
CATEGORY = "Encryption requests"
def __init__(self, hs: "HomeServer"):
super().__init__()

View file

@ -35,6 +35,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
PATTERNS = client_patterns(
"/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$"
)
CATEGORY = "The to_device stream"
def __init__(self, hs: "HomeServer"):
super().__init__()

View file

@ -87,6 +87,7 @@ class SyncRestServlet(RestServlet):
PATTERNS = client_patterns("/sync$")
ALLOWED_PRESENCE = {"online", "offline", "unavailable"}
CATEGORY = "Sync requests"
def __init__(self, hs: "HomeServer"):
super().__init__()

View file

@ -37,6 +37,7 @@ class TagListServlet(RestServlet):
PATTERNS = client_patterns(
"/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags$"
)
CATEGORY = "Account data requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
@ -64,6 +65,7 @@ class TagServlet(RestServlet):
PATTERNS = client_patterns(
"/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags/(?P<tag>[^/]*)"
)
CATEGORY = "Account data requests"
def __init__(self, hs: "HomeServer"):
super().__init__()

View file

@ -31,6 +31,7 @@ logger = logging.getLogger(__name__)
class UserDirectorySearchRestServlet(RestServlet):
PATTERNS = client_patterns("/user_directory/search$")
CATEGORY = "User directory search requests"
def __init__(self, hs: "HomeServer"):
super().__init__()

View file

@ -34,6 +34,7 @@ logger = logging.getLogger(__name__)
class VersionsRestServlet(RestServlet):
PATTERNS = [re.compile("^/_matrix/client/versions$")]
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"):
super().__init__()

View file

@ -29,6 +29,7 @@ if TYPE_CHECKING:
class VoipRestServlet(RestServlet):
PATTERNS = client_patterns("/voip/turnServer$", v1=True)
CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"):
super().__init__()

View file

@ -93,6 +93,8 @@ class RemoteKey(RestServlet):
}
"""
CATEGORY = "Federation requests"
def __init__(self, hs: "HomeServer"):
self.fetcher = ServerKeyFetcher(hs)
self.store = hs.get_datastores().main