Add type hints to user admin API. (#9521)

This commit is contained in:
Dirk Klimpel 2021-03-03 14:09:39 +01:00 committed by GitHub
parent 0c330423bc
commit d790d0d314
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 63 additions and 35 deletions

1
changelog.d/9521.misc Normal file
View File

@ -0,0 +1 @@
Add type hints to user admin API.

View File

@ -16,7 +16,7 @@ import hashlib
import hmac import hmac
import logging import logging
from http import HTTPStatus from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from synapse.api.constants import UserTypes from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.api.errors import Codes, NotFoundError, SynapseError
@ -47,13 +47,15 @@ logger = logging.getLogger(__name__)
class UsersRestServlet(RestServlet): class UsersRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)$") PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.hs = hs self.hs = hs
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler() self.admin_handler = hs.get_admin_handler()
async def on_GET(self, request, user_id): async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, List[JsonDict]]:
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
@ -153,7 +155,7 @@ class UserRestServletV2(RestServlet):
otherwise an error. otherwise an error.
""" """
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler() self.admin_handler = hs.get_admin_handler()
@ -165,7 +167,9 @@ class UserRestServletV2(RestServlet):
self.registration_handler = hs.get_registration_handler() self.registration_handler = hs.get_registration_handler()
self.pusher_pool = hs.get_pusherpool() self.pusher_pool = hs.get_pusherpool()
async def on_GET(self, request, user_id): async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
@ -179,7 +183,9 @@ class UserRestServletV2(RestServlet):
return 200, ret return 200, ret
async def on_PUT(self, request, user_id): async def on_PUT(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user) await assert_user_is_admin(self.auth, requester.user)
@ -273,6 +279,8 @@ class UserRestServletV2(RestServlet):
) )
user = await self.admin_handler.get_user(target_user) user = await self.admin_handler.get_user(target_user)
assert user is not None
return 200, user return 200, user
else: # create user else: # create user
@ -330,9 +338,10 @@ class UserRestServletV2(RestServlet):
target_user, requester, body["avatar_url"], True target_user, requester, body["avatar_url"], True
) )
ret = await self.admin_handler.get_user(target_user) user = await self.admin_handler.get_user(target_user)
assert user is not None
return 201, ret return 201, user
class UserRegisterServlet(RestServlet): class UserRegisterServlet(RestServlet):
@ -346,10 +355,10 @@ class UserRegisterServlet(RestServlet):
PATTERNS = admin_patterns("/register") PATTERNS = admin_patterns("/register")
NONCE_TIMEOUT = 60 NONCE_TIMEOUT = 60
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
self.reactor = hs.get_reactor() self.reactor = hs.get_reactor()
self.nonces = {} self.nonces = {} # type: Dict[str, int]
self.hs = hs self.hs = hs
def _clear_old_nonces(self): def _clear_old_nonces(self):
@ -362,7 +371,7 @@ class UserRegisterServlet(RestServlet):
if now - v > self.NONCE_TIMEOUT: if now - v > self.NONCE_TIMEOUT:
del self.nonces[k] del self.nonces[k]
def on_GET(self, request): def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
""" """
Generate a new nonce. Generate a new nonce.
""" """
@ -372,7 +381,7 @@ class UserRegisterServlet(RestServlet):
self.nonces[nonce] = int(self.reactor.seconds()) self.nonces[nonce] = int(self.reactor.seconds())
return 200, {"nonce": nonce} return 200, {"nonce": nonce}
async def on_POST(self, request): async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
self._clear_old_nonces() self._clear_old_nonces()
if not self.hs.config.registration_shared_secret: if not self.hs.config.registration_shared_secret:
@ -478,12 +487,14 @@ class WhoisRestServlet(RestServlet):
client_patterns("/admin" + path_regex, v1=True) client_patterns("/admin" + path_regex, v1=True)
) )
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler() self.admin_handler = hs.get_admin_handler()
async def on_GET(self, request, user_id): async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
auth_user = requester.user auth_user = requester.user
@ -508,7 +519,9 @@ class DeactivateAccountRestServlet(RestServlet):
self.is_mine = hs.is_mine self.is_mine = hs.is_mine
self.store = hs.get_datastore() self.store = hs.get_datastore()
async def on_POST(self, request: str, target_user_id: str) -> Tuple[int, JsonDict]: async def on_POST(
self, request: SynapseRequest, target_user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user) await assert_user_is_admin(self.auth, requester.user)
@ -550,7 +563,7 @@ class AccountValidityRenewServlet(RestServlet):
self.account_activity_handler = hs.get_account_validity_handler() self.account_activity_handler = hs.get_account_validity_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
async def on_POST(self, request): async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
@ -584,14 +597,16 @@ class ResetPasswordRestServlet(RestServlet):
PATTERNS = admin_patterns("/reset_password/(?P<target_user_id>[^/]*)") PATTERNS = admin_patterns("/reset_password/(?P<target_user_id>[^/]*)")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
self._set_password_handler = hs.get_set_password_handler() self._set_password_handler = hs.get_set_password_handler()
async def on_POST(self, request, target_user_id): async def on_POST(
self, request: SynapseRequest, target_user_id: str
) -> Tuple[int, JsonDict]:
"""Post request to allow an administrator reset password for a user. """Post request to allow an administrator reset password for a user.
This needs user to have administrator access in Synapse. This needs user to have administrator access in Synapse.
""" """
@ -626,12 +641,14 @@ class SearchUsersRestServlet(RestServlet):
PATTERNS = admin_patterns("/search_users/(?P<target_user_id>[^/]*)") PATTERNS = admin_patterns("/search_users/(?P<target_user_id>[^/]*)")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.hs = hs self.hs = hs
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.auth = hs.get_auth() self.auth = hs.get_auth()
async def on_GET(self, request, target_user_id): async def on_GET(
self, request: SynapseRequest, target_user_id: str
) -> Tuple[int, Optional[List[JsonDict]]]:
"""Get request to search user table for specific users according to """Get request to search user table for specific users according to
search term. search term.
This needs user to have a administrator access in Synapse. This needs user to have a administrator access in Synapse.
@ -682,12 +699,14 @@ class UserAdminServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/admin$") PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/admin$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.hs = hs self.hs = hs
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.auth = hs.get_auth() self.auth = hs.get_auth()
async def on_GET(self, request, user_id): async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
@ -699,7 +718,9 @@ class UserAdminServlet(RestServlet):
return 200, {"admin": is_admin} return 200, {"admin": is_admin}
async def on_PUT(self, request, user_id): async def on_PUT(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user) await assert_user_is_admin(self.auth, requester.user)
auth_user = requester.user auth_user = requester.user
@ -730,12 +751,14 @@ class UserMembershipRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/joined_rooms$") PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/joined_rooms$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.is_mine = hs.is_mine self.is_mine = hs.is_mine
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()
async def on_GET(self, request, user_id): async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
room_ids = await self.store.get_rooms_for_user(user_id) room_ids = await self.store.get_rooms_for_user(user_id)
@ -758,7 +781,7 @@ class PushersRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/pushers$") PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/pushers$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.is_mine = hs.is_mine self.is_mine = hs.is_mine
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.auth = hs.get_auth() self.auth = hs.get_auth()
@ -799,7 +822,7 @@ class UserMediaRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/media$") PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/media$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.is_mine = hs.is_mine self.is_mine = hs.is_mine
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -891,7 +914,9 @@ class UserTokenRestServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
async def on_POST(self, request, user_id): async def on_POST(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user) await assert_user_is_admin(self.auth, requester.user)
auth_user = requester.user auth_user = requester.user
@ -943,7 +968,9 @@ class ShadowBanRestServlet(RestServlet):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.auth = hs.get_auth() self.auth = hs.get_auth()
async def on_POST(self, request, user_id): async def on_POST(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
if not self.hs.is_mine_id(user_id): if not self.hs.is_mine_id(user_id):

View File

@ -16,7 +16,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Any, Dict, List, Optional, Tuple from typing import List, Optional, Tuple
from synapse.api.constants import PresenceState from synapse.api.constants import PresenceState
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
@ -27,7 +27,7 @@ from synapse.storage.util.id_generators import (
MultiWriterIdGenerator, MultiWriterIdGenerator,
StreamIdGenerator, StreamIdGenerator,
) )
from synapse.types import get_domain_from_id from synapse.types import JsonDict, get_domain_from_id
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
from .account_data import AccountDataStore from .account_data import AccountDataStore
@ -264,7 +264,7 @@ class DataStore(
return [UserPresenceState(**row) for row in rows] return [UserPresenceState(**row) for row in rows]
async def get_users(self) -> List[Dict[str, Any]]: async def get_users(self) -> List[JsonDict]:
"""Function to retrieve a list of users in users table. """Function to retrieve a list of users in users table.
Returns: Returns:
@ -292,7 +292,7 @@ class DataStore(
name: Optional[str] = None, name: Optional[str] = None,
guests: bool = True, guests: bool = True,
deactivated: bool = False, deactivated: bool = False,
) -> Tuple[List[Dict[str, Any]], int]: ) -> Tuple[List[JsonDict], int]:
"""Function to retrieve a paginated list of users from """Function to retrieve a paginated list of users from
users list. This will return a json list of users and the users list. This will return a json list of users and the
total number of users matching the filter criteria. total number of users matching the filter criteria.
@ -353,7 +353,7 @@ class DataStore(
"get_users_paginate_txn", get_users_paginate_txn "get_users_paginate_txn", get_users_paginate_txn
) )
async def search_users(self, term: str) -> Optional[List[Dict[str, Any]]]: async def search_users(self, term: str) -> Optional[List[JsonDict]]:
"""Function to search users list for one or more users with """Function to search users list for one or more users with
the matched term. the matched term.

View File

@ -139,7 +139,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
start: int, start: int,
limit: int, limit: int,
user_id: str, user_id: str,
order_by: MediaSortOrder = MediaSortOrder.CREATED_TS.value, order_by: str = MediaSortOrder.CREATED_TS.value,
direction: str = "f", direction: str = "f",
) -> Tuple[List[Dict[str, Any]], int]: ) -> Tuple[List[Dict[str, Any]], int]:
"""Get a paginated list of metadata for a local piece of media """Get a paginated list of metadata for a local piece of media