Type hints and validation improvements. (#9321)

* Adds type hints to the groups servlet and stringutils code.
* Assert the maximum length of some input values for spec compliance.
This commit is contained in:
Patrick Cloke 2021-02-08 13:59:54 -05:00 committed by GitHub
parent 0963d39ea6
commit 3f58fc848d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 177 additions and 79 deletions

1
changelog.d/9321.bugfix Normal file
View File

@ -0,0 +1 @@
Assert a maximum length for the `client_secret` parameter for spec compliance.

View File

@ -18,6 +18,7 @@
import logging import logging
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
from synapse.types import GroupID, RoomID, UserID, get_domain_from_id from synapse.types import GroupID, RoomID, UserID, get_domain_from_id
from synapse.util.async_helpers import concurrently_execute from synapse.util.async_helpers import concurrently_execute
@ -32,6 +33,11 @@ logger = logging.getLogger(__name__)
# TODO: Flairs # TODO: Flairs
# Note that the maximum lengths are somewhat arbitrary.
MAX_SHORT_DESC_LEN = 1000
MAX_LONG_DESC_LEN = 10000
class GroupsServerWorkerHandler: class GroupsServerWorkerHandler:
def __init__(self, hs): def __init__(self, hs):
self.hs = hs self.hs = hs
@ -508,11 +514,26 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
) )
profile = {} profile = {}
for keyname in ("name", "avatar_url", "short_description", "long_description"): for keyname, max_length in (
("name", MAX_DISPLAYNAME_LEN),
("avatar_url", MAX_AVATAR_URL_LEN),
("short_description", MAX_SHORT_DESC_LEN),
("long_description", MAX_LONG_DESC_LEN),
):
if keyname in content: if keyname in content:
value = content[keyname] value = content[keyname]
if not isinstance(value, str): if not isinstance(value, str):
raise SynapseError(400, "%r value is not a string" % (keyname,)) raise SynapseError(
400,
"%r value is not a string" % (keyname,),
errcode=Codes.INVALID_PARAM,
)
if len(value) > max_length:
raise SynapseError(
400,
"Invalid %s parameter" % (keyname,),
errcode=Codes.INVALID_PARAM,
)
profile[keyname] = value profile[keyname] = value
await self.store.update_group_profile(group_id, profile) await self.store.update_group_profile(group_id, profile)

View File

@ -16,13 +16,24 @@
import logging import logging
from functools import wraps from functools import wraps
from typing import TYPE_CHECKING, Tuple
from twisted.web.http import Request
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.handlers.groups_local import GroupsLocalHandler
from synapse.types import GroupID from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
)
from synapse.types import GroupID, JsonDict
from ._base import client_patterns from ._base import client_patterns
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -33,7 +44,7 @@ def _validate_group_id(f):
""" """
@wraps(f) @wraps(f)
def wrapper(self, request, group_id, *args, **kwargs): def wrapper(self, request: Request, group_id: str, *args, **kwargs):
if not GroupID.is_valid(group_id): if not GroupID.is_valid(group_id):
raise SynapseError(400, "%s is not a legal group ID" % (group_id,)) raise SynapseError(400, "%s is not a legal group ID" % (group_id,))
@ -48,14 +59,14 @@ class GroupServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/profile$") PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/profile$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id @_validate_group_id
async def on_GET(self, request, group_id): async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
@ -66,11 +77,15 @@ class GroupServlet(RestServlet):
return 200, group_description return 200, group_description
@_validate_group_id @_validate_group_id
async def on_POST(self, request, group_id): async def on_POST(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
assert_params_in_dict(
content, ("name", "avatar_url", "short_description", "long_description")
)
assert isinstance(self.groups_handler, GroupsLocalHandler)
await self.groups_handler.update_group_profile( await self.groups_handler.update_group_profile(
group_id, requester_user_id, content group_id, requester_user_id, content
) )
@ -84,14 +99,14 @@ class GroupSummaryServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/summary$") PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/summary$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id @_validate_group_id
async def on_GET(self, request, group_id): async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
@ -116,18 +131,21 @@ class GroupSummaryRoomsCatServlet(RestServlet):
"/rooms/(?P<room_id>[^/]*)$" "/rooms/(?P<room_id>[^/]*)$"
) )
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id @_validate_group_id
async def on_PUT(self, request, group_id, category_id, room_id): async def on_PUT(
self, request: Request, group_id: str, category_id: str, room_id: str
):
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
assert isinstance(self.groups_handler, GroupsLocalHandler)
resp = await self.groups_handler.update_group_summary_room( resp = await self.groups_handler.update_group_summary_room(
group_id, group_id,
requester_user_id, requester_user_id,
@ -139,10 +157,13 @@ class GroupSummaryRoomsCatServlet(RestServlet):
return 200, resp return 200, resp
@_validate_group_id @_validate_group_id
async def on_DELETE(self, request, group_id, category_id, room_id): async def on_DELETE(
self, request: Request, group_id: str, category_id: str, room_id: str
):
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
assert isinstance(self.groups_handler, GroupsLocalHandler)
resp = await self.groups_handler.delete_group_summary_room( resp = await self.groups_handler.delete_group_summary_room(
group_id, requester_user_id, room_id=room_id, category_id=category_id group_id, requester_user_id, room_id=room_id, category_id=category_id
) )
@ -158,14 +179,16 @@ class GroupCategoryServlet(RestServlet):
"/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$" "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$"
) )
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id @_validate_group_id
async def on_GET(self, request, group_id, category_id): async def on_GET(
self, request: Request, group_id: str, category_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
@ -176,11 +199,14 @@ class GroupCategoryServlet(RestServlet):
return 200, category return 200, category
@_validate_group_id @_validate_group_id
async def on_PUT(self, request, group_id, category_id): async def on_PUT(
self, request: Request, group_id: str, category_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
assert isinstance(self.groups_handler, GroupsLocalHandler)
resp = await self.groups_handler.update_group_category( resp = await self.groups_handler.update_group_category(
group_id, requester_user_id, category_id=category_id, content=content group_id, requester_user_id, category_id=category_id, content=content
) )
@ -188,10 +214,13 @@ class GroupCategoryServlet(RestServlet):
return 200, resp return 200, resp
@_validate_group_id @_validate_group_id
async def on_DELETE(self, request, group_id, category_id): async def on_DELETE(
self, request: Request, group_id: str, category_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
assert isinstance(self.groups_handler, GroupsLocalHandler)
resp = await self.groups_handler.delete_group_category( resp = await self.groups_handler.delete_group_category(
group_id, requester_user_id, category_id=category_id group_id, requester_user_id, category_id=category_id
) )
@ -205,14 +234,14 @@ class GroupCategoriesServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/categories/$") PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/categories/$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id @_validate_group_id
async def on_GET(self, request, group_id): async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
@ -229,14 +258,16 @@ class GroupRoleServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$") PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id @_validate_group_id
async def on_GET(self, request, group_id, role_id): async def on_GET(
self, request: Request, group_id: str, role_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
@ -247,11 +278,14 @@ class GroupRoleServlet(RestServlet):
return 200, category return 200, category
@_validate_group_id @_validate_group_id
async def on_PUT(self, request, group_id, role_id): async def on_PUT(
self, request: Request, group_id: str, role_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
assert isinstance(self.groups_handler, GroupsLocalHandler)
resp = await self.groups_handler.update_group_role( resp = await self.groups_handler.update_group_role(
group_id, requester_user_id, role_id=role_id, content=content group_id, requester_user_id, role_id=role_id, content=content
) )
@ -259,10 +293,13 @@ class GroupRoleServlet(RestServlet):
return 200, resp return 200, resp
@_validate_group_id @_validate_group_id
async def on_DELETE(self, request, group_id, role_id): async def on_DELETE(
self, request: Request, group_id: str, role_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
assert isinstance(self.groups_handler, GroupsLocalHandler)
resp = await self.groups_handler.delete_group_role( resp = await self.groups_handler.delete_group_role(
group_id, requester_user_id, role_id=role_id group_id, requester_user_id, role_id=role_id
) )
@ -276,14 +313,14 @@ class GroupRolesServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/$") PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id @_validate_group_id
async def on_GET(self, request, group_id): async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
@ -308,18 +345,21 @@ class GroupSummaryUsersRoleServlet(RestServlet):
"/users/(?P<user_id>[^/]*)$" "/users/(?P<user_id>[^/]*)$"
) )
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id @_validate_group_id
async def on_PUT(self, request, group_id, role_id, user_id): async def on_PUT(
self, request: Request, group_id: str, role_id: str, user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
assert isinstance(self.groups_handler, GroupsLocalHandler)
resp = await self.groups_handler.update_group_summary_user( resp = await self.groups_handler.update_group_summary_user(
group_id, group_id,
requester_user_id, requester_user_id,
@ -331,10 +371,13 @@ class GroupSummaryUsersRoleServlet(RestServlet):
return 200, resp return 200, resp
@_validate_group_id @_validate_group_id
async def on_DELETE(self, request, group_id, role_id, user_id): async def on_DELETE(
self, request: Request, group_id: str, role_id: str, user_id: str
):
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
assert isinstance(self.groups_handler, GroupsLocalHandler)
resp = await self.groups_handler.delete_group_summary_user( resp = await self.groups_handler.delete_group_summary_user(
group_id, requester_user_id, user_id=user_id, role_id=role_id group_id, requester_user_id, user_id=user_id, role_id=role_id
) )
@ -348,14 +391,14 @@ class GroupRoomServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/rooms$") PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/rooms$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id @_validate_group_id
async def on_GET(self, request, group_id): async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
@ -372,14 +415,14 @@ class GroupUsersServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/users$") PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/users$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id @_validate_group_id
async def on_GET(self, request, group_id): async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
@ -396,14 +439,14 @@ class GroupInvitedUsersServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/invited_users$") PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/invited_users$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id @_validate_group_id
async def on_GET(self, request, group_id): async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
@ -420,18 +463,19 @@ class GroupSettingJoinPolicyServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/settings/m.join_policy$") PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/settings/m.join_policy$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id @_validate_group_id
async def on_PUT(self, request, group_id): async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
assert isinstance(self.groups_handler, GroupsLocalHandler)
result = await self.groups_handler.set_group_join_policy( result = await self.groups_handler.set_group_join_policy(
group_id, requester_user_id, content group_id, requester_user_id, content
) )
@ -445,14 +489,14 @@ class GroupCreateServlet(RestServlet):
PATTERNS = client_patterns("/create_group$") PATTERNS = client_patterns("/create_group$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
self.server_name = hs.hostname self.server_name = hs.hostname
async def on_POST(self, request): async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
@ -461,6 +505,7 @@ class GroupCreateServlet(RestServlet):
localpart = content.pop("localpart") localpart = content.pop("localpart")
group_id = GroupID(localpart, self.server_name).to_string() group_id = GroupID(localpart, self.server_name).to_string()
assert isinstance(self.groups_handler, GroupsLocalHandler)
result = await self.groups_handler.create_group( result = await self.groups_handler.create_group(
group_id, requester_user_id, content group_id, requester_user_id, content
) )
@ -476,18 +521,21 @@ class GroupAdminRoomsServlet(RestServlet):
"/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)$" "/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)$"
) )
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id @_validate_group_id
async def on_PUT(self, request, group_id, room_id): async def on_PUT(
self, request: Request, group_id: str, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
assert isinstance(self.groups_handler, GroupsLocalHandler)
result = await self.groups_handler.add_room_to_group( result = await self.groups_handler.add_room_to_group(
group_id, requester_user_id, room_id, content group_id, requester_user_id, room_id, content
) )
@ -495,10 +543,13 @@ class GroupAdminRoomsServlet(RestServlet):
return 200, result return 200, result
@_validate_group_id @_validate_group_id
async def on_DELETE(self, request, group_id, room_id): async def on_DELETE(
self, request: Request, group_id: str, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
assert isinstance(self.groups_handler, GroupsLocalHandler)
result = await self.groups_handler.remove_room_from_group( result = await self.groups_handler.remove_room_from_group(
group_id, requester_user_id, room_id group_id, requester_user_id, room_id
) )
@ -515,18 +566,21 @@ class GroupAdminRoomsConfigServlet(RestServlet):
"/config/(?P<config_key>[^/]*)$" "/config/(?P<config_key>[^/]*)$"
) )
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id @_validate_group_id
async def on_PUT(self, request, group_id, room_id, config_key): async def on_PUT(
self, request: Request, group_id: str, room_id: str, config_key: str
):
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
assert isinstance(self.groups_handler, GroupsLocalHandler)
result = await self.groups_handler.update_room_in_group( result = await self.groups_handler.update_room_in_group(
group_id, requester_user_id, room_id, config_key, content group_id, requester_user_id, room_id, config_key, content
) )
@ -542,7 +596,7 @@ class GroupAdminUsersInviteServlet(RestServlet):
"/groups/(?P<group_id>[^/]*)/admin/users/invite/(?P<user_id>[^/]*)$" "/groups/(?P<group_id>[^/]*)/admin/users/invite/(?P<user_id>[^/]*)$"
) )
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.clock = hs.get_clock() self.clock = hs.get_clock()
@ -551,12 +605,13 @@ class GroupAdminUsersInviteServlet(RestServlet):
self.is_mine_id = hs.is_mine_id self.is_mine_id = hs.is_mine_id
@_validate_group_id @_validate_group_id
async def on_PUT(self, request, group_id, user_id): async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
config = content.get("config", {}) config = content.get("config", {})
assert isinstance(self.groups_handler, GroupsLocalHandler)
result = await self.groups_handler.invite( result = await self.groups_handler.invite(
group_id, user_id, requester_user_id, config group_id, user_id, requester_user_id, config
) )
@ -572,18 +627,19 @@ class GroupAdminUsersKickServlet(RestServlet):
"/groups/(?P<group_id>[^/]*)/admin/users/remove/(?P<user_id>[^/]*)$" "/groups/(?P<group_id>[^/]*)/admin/users/remove/(?P<user_id>[^/]*)$"
) )
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id @_validate_group_id
async def on_PUT(self, request, group_id, user_id): async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
assert isinstance(self.groups_handler, GroupsLocalHandler)
result = await self.groups_handler.remove_user_from_group( result = await self.groups_handler.remove_user_from_group(
group_id, user_id, requester_user_id, content group_id, user_id, requester_user_id, content
) )
@ -597,18 +653,19 @@ class GroupSelfLeaveServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/leave$") PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/leave$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id @_validate_group_id
async def on_PUT(self, request, group_id): async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
assert isinstance(self.groups_handler, GroupsLocalHandler)
result = await self.groups_handler.remove_user_from_group( result = await self.groups_handler.remove_user_from_group(
group_id, requester_user_id, requester_user_id, content group_id, requester_user_id, requester_user_id, content
) )
@ -622,18 +679,19 @@ class GroupSelfJoinServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/join$") PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/join$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id @_validate_group_id
async def on_PUT(self, request, group_id): async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
assert isinstance(self.groups_handler, GroupsLocalHandler)
result = await self.groups_handler.join_group( result = await self.groups_handler.join_group(
group_id, requester_user_id, content group_id, requester_user_id, content
) )
@ -647,18 +705,19 @@ class GroupSelfAcceptInviteServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/accept_invite$") PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/accept_invite$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id @_validate_group_id
async def on_PUT(self, request, group_id): async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
assert isinstance(self.groups_handler, GroupsLocalHandler)
result = await self.groups_handler.accept_invite( result = await self.groups_handler.accept_invite(
group_id, requester_user_id, content group_id, requester_user_id, content
) )
@ -672,14 +731,14 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/update_publicity$") PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/update_publicity$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.store = hs.get_datastore() self.store = hs.get_datastore()
@_validate_group_id @_validate_group_id
async def on_PUT(self, request, group_id): async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
@ -696,14 +755,14 @@ class PublicisedGroupsForUserServlet(RestServlet):
PATTERNS = client_patterns("/publicised_groups/(?P<user_id>[^/]*)$") PATTERNS = client_patterns("/publicised_groups/(?P<user_id>[^/]*)$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
async def on_GET(self, request, user_id): async def on_GET(self, request: Request, user_id: str) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True) await self.auth.get_user_by_req(request, allow_guest=True)
result = await self.groups_handler.get_publicised_groups_for_user(user_id) result = await self.groups_handler.get_publicised_groups_for_user(user_id)
@ -717,14 +776,14 @@ class PublicisedGroupsForUsersServlet(RestServlet):
PATTERNS = client_patterns("/publicised_groups$") PATTERNS = client_patterns("/publicised_groups$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
async def on_POST(self, request): async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True) await self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
@ -741,13 +800,13 @@ class GroupsForUserServlet(RestServlet):
PATTERNS = client_patterns("/joined_groups$") PATTERNS = client_patterns("/joined_groups$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
async def on_GET(self, request): async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
@ -756,7 +815,7 @@ class GroupsForUserServlet(RestServlet):
return 200, result return 200, result
def register_servlets(hs, http_server): def register_servlets(hs: "HomeServer", http_server):
GroupServlet(hs).register(http_server) GroupServlet(hs).register(http_server)
GroupSummaryServlet(hs).register(http_server) GroupSummaryServlet(hs).register(http_server)
GroupInvitedUsersServlet(hs).register(http_server) GroupInvitedUsersServlet(hs).register(http_server)

View File

@ -193,6 +193,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
body, ["client_secret", "country", "phone_number", "send_attempt"] body, ["client_secret", "country", "phone_number", "send_attempt"]
) )
client_secret = body["client_secret"] client_secret = body["client_secret"]
assert_valid_client_secret(client_secret)
country = body["country"] country = body["country"]
phone_number = body["phone_number"] phone_number = body["phone_number"]
send_attempt = body["send_attempt"] send_attempt = body["send_attempt"]
@ -293,6 +294,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
sid = parse_string(request, "sid", required=True) sid = parse_string(request, "sid", required=True)
client_secret = parse_string(request, "client_secret", required=True) client_secret = parse_string(request, "client_secret", required=True)
assert_valid_client_secret(client_secret)
token = parse_string(request, "token", required=True) token = parse_string(request, "token", required=True)
# Attempt to validate a 3PID session # Attempt to validate a 3PID session

View File

@ -25,7 +25,17 @@ import abc
import functools import functools
import logging import logging
import os import os
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar, cast from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
TypeVar,
Union,
cast,
)
import twisted.internet.base import twisted.internet.base
import twisted.internet.tcp import twisted.internet.tcp
@ -588,7 +598,9 @@ class HomeServer(metaclass=abc.ABCMeta):
return UserDirectoryHandler(self) return UserDirectoryHandler(self)
@cache_in_self @cache_in_self
def get_groups_local_handler(self): def get_groups_local_handler(
self,
) -> Union[GroupsLocalWorkerHandler, GroupsLocalHandler]:
if self.config.worker_app: if self.config.worker_app:
return GroupsLocalWorkerHandler(self) return GroupsLocalWorkerHandler(self)
else: else:

View File

@ -25,7 +25,7 @@ from synapse.api.errors import Codes, SynapseError
_string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@" _string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
# https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken # https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken
client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-]+$") CLIENT_SECRET_REGEX = re.compile(r"^[0-9a-zA-Z\.=_\-]+$")
# https://matrix.org/docs/spec/client_server/r0.6.1#matrix-content-mxc-uris, # https://matrix.org/docs/spec/client_server/r0.6.1#matrix-content-mxc-uris,
# together with https://github.com/matrix-org/matrix-doc/issues/2177 which basically # together with https://github.com/matrix-org/matrix-doc/issues/2177 which basically
@ -42,16 +42,15 @@ MXC_REGEX = re.compile("^mxc://([^/]+)/([^/#?]+)$")
rand = random.SystemRandom() rand = random.SystemRandom()
def random_string(length): def random_string(length: int) -> str:
return "".join(rand.choice(string.ascii_letters) for _ in range(length)) return "".join(rand.choice(string.ascii_letters) for _ in range(length))
def random_string_with_symbols(length): def random_string_with_symbols(length: int) -> str:
return "".join(rand.choice(_string_with_symbols) for _ in range(length)) return "".join(rand.choice(_string_with_symbols) for _ in range(length))
def is_ascii(s): def is_ascii(s: bytes) -> bool:
if isinstance(s, bytes):
try: try:
s.decode("ascii").encode("ascii") s.decode("ascii").encode("ascii")
except UnicodeDecodeError: except UnicodeDecodeError:
@ -61,9 +60,13 @@ def is_ascii(s):
return True return True
def assert_valid_client_secret(client_secret): def assert_valid_client_secret(client_secret: str) -> None:
"""Validate that a given string matches the client_secret regex defined by the spec""" """Validate that a given string matches the client_secret defined by the spec"""
if client_secret_regex.match(client_secret) is None: if (
len(client_secret) <= 0
or len(client_secret) > 255
or CLIENT_SECRET_REGEX.match(client_secret) is None
):
raise SynapseError( raise SynapseError(
400, "Invalid client_secret parameter", errcode=Codes.INVALID_PARAM 400, "Invalid client_secret parameter", errcode=Codes.INVALID_PARAM
) )