Use parse_{int,str} and assert from http.servlet

parse_integer and parse_string can take a request and raise errors
in case we have wrong or missing params.
This PR tries to use them more to deduplicate some code and make it
better readable
This commit is contained in:
Krombel 2018-07-13 21:40:14 +02:00
parent 2aba1f549c
commit 32fd6910d0
14 changed files with 90 additions and 154 deletions

View File

@ -22,7 +22,12 @@ from twisted.internet import defer
from synapse.api.constants import Membership
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.http.servlet import parse_json_object_from_request
from synapse.http.servlet import (
assert_params_in_request,
parse_json_object_from_request,
parse_integer,
parse_string
)
from synapse.types import UserID, create_requester
from .base import ClientV1RestServlet, client_path_patterns
@ -98,16 +103,8 @@ class PurgeMediaCacheRestServlet(ClientV1RestServlet):
if not is_admin:
raise AuthError(403, "You are not a server admin")
before_ts = request.args.get("before_ts", None)
if not before_ts:
raise SynapseError(400, "Missing 'before_ts' arg")
logger.info("before_ts: %r", before_ts[0])
try:
before_ts = int(before_ts[0])
except Exception:
raise SynapseError(400, "Invalid 'before_ts' arg")
before_ts = parse_integer(request, "before_ts", required=True)
logger.info("before_ts: %r", before_ts)
ret = yield self.media_repository.delete_old_remote_media(before_ts)
@ -300,10 +297,8 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
raise AuthError(403, "You are not a server admin")
content = parse_json_object_from_request(request)
new_room_user_id = content.get("new_room_user_id")
if not new_room_user_id:
raise SynapseError(400, "Please provide field `new_room_user_id`")
assert_params_in_request(content, ["new_room_user_id"])
new_room_user_id = content["new_room_user_id"]
room_creator_requester = create_requester(new_room_user_id)
@ -464,9 +459,8 @@ class ResetPasswordRestServlet(ClientV1RestServlet):
raise AuthError(403, "You are not a server admin")
params = parse_json_object_from_request(request)
assert_params_in_request(params, ["new_password"])
new_password = params['new_password']
if not new_password:
raise SynapseError(400, "Missing 'new_password' arg")
logger.info("new_password: %r", new_password)
@ -514,12 +508,9 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet):
raise SynapseError(400, "Can only users a local user")
order = "name" # order by name in user table
start = request.args.get("start")[0]
limit = request.args.get("limit")[0]
if not limit:
raise SynapseError(400, "Missing 'limit' arg")
if not start:
raise SynapseError(400, "Missing 'start' arg")
start = parse_integer(request, "start", required=True)
limit = parse_integer(request, "limit", required=True)
logger.info("limit: %s, start: %s", limit, start)
ret = yield self.handlers.admin_handler.get_users_paginate(
@ -551,12 +542,9 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet):
order = "name" # order by name in user table
params = parse_json_object_from_request(request)
assert_params_in_request(params, ["limit", "start"])
limit = params['limit']
start = params['start']
if not limit:
raise SynapseError(400, "Missing 'limit' arg")
if not start:
raise SynapseError(400, "Missing 'start' arg")
logger.info("limit: %s, start: %s", limit, start)
ret = yield self.handlers.admin_handler.get_users_paginate(
@ -604,10 +592,7 @@ class SearchUsersRestServlet(ClientV1RestServlet):
if not self.hs.is_mine(target_user):
raise SynapseError(400, "Can only users a local user")
term = request.args.get("term")[0]
if not term:
raise SynapseError(400, "Missing 'term' arg")
term = parse_string(request, "term", required=True)
logger.info("term: %s ", term)
ret = yield self.handlers.admin_handler.search_users(

View File

@ -18,8 +18,8 @@ import logging
from twisted.internet import defer
from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.http.servlet import parse_json_object_from_request
from synapse.api.errors import AuthError, SynapseError
from synapse.http.servlet import assert_params_in_request, parse_json_object_from_request
from synapse.types import RoomAlias
from .base import ClientV1RestServlet, client_path_patterns
@ -52,15 +52,14 @@ class ClientDirectoryServer(ClientV1RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, room_alias):
room_alias = RoomAlias.from_string(room_alias)
content = parse_json_object_from_request(request)
if "room_id" not in content:
raise SynapseError(400, "Missing room_id key",
raise SynapseError(400, 'Missing params: ["room_id"]',
errcode=Codes.BAD_JSON)
logger.debug("Got content: %s", content)
room_alias = RoomAlias.from_string(room_alias)
logger.debug("Got room name: %s", room_alias.to_string())
room_id = content["room_id"]

View File

@ -16,6 +16,7 @@
from twisted.internet import defer
from synapse.streams.config import PaginationConfig
from synapse.http.servlet import parse_boolean
from .base import ClientV1RestServlet, client_path_patterns
@ -33,7 +34,7 @@ class InitialSyncRestServlet(ClientV1RestServlet):
requester = yield self.auth.get_user_by_req(request)
as_client_event = "raw" not in request.args
pagination_config = PaginationConfig.from_request(request)
include_archived = request.args.get("archived", None) == ["true"]
include_archived = parse_boolean(request, "archived", default=False)
content = yield self.initial_sync_handler.snapshot_all_rooms(
user_id=requester.user.to_string(),
pagin_config=pagination_config,

View File

@ -21,7 +21,7 @@ from synapse.api.errors import (
SynapseError,
UnrecognizedRequestError,
)
from synapse.http.servlet import parse_json_value_from_request
from synapse.http.servlet import parse_json_value_from_request, parse_string
from synapse.push.baserules import BASE_RULE_IDS
from synapse.push.clientformat import format_push_rules_for_user
from synapse.push.rulekinds import PRIORITY_CLASS_MAP
@ -75,11 +75,11 @@ class PushRuleRestServlet(ClientV1RestServlet):
except InvalidRuleException as e:
raise SynapseError(400, e.message)
before = request.args.get("before", None)
before = parse_string(request, "before")
if before:
before = _namespaced_rule_id(spec, before[0])
after = request.args.get("after", None)
after = parse_string(request, "after")
if after:
after = _namespaced_rule_id(spec, after[0])

View File

@ -21,6 +21,7 @@ from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.http.server import finish_request
from synapse.http.servlet import (
RestServlet,
assert_params_in_request,
parse_json_object_from_request,
parse_string,
)
@ -91,15 +92,11 @@ class PushersSetRestServlet(ClientV1RestServlet):
)
defer.returnValue((200, {}))
reqd = ['kind', 'app_id', 'app_display_name',
'device_display_name', 'pushkey', 'lang', 'data']
missing = []
for i in reqd:
if i not in content:
missing.append(i)
if len(missing):
raise SynapseError(400, "Missing parameters: " + ','.join(missing),
errcode=Codes.MISSING_PARAM)
assert_params_in_request(
content,
['kind', 'app_id', 'app_display_name',
'device_display_name', 'pushkey', 'lang', 'data']
)
logger.debug("set pushkey %s to kind %s", content['pushkey'], content['kind'])
logger.debug("Got pushers request with body: %r", content)
@ -148,7 +145,7 @@ class PushersRemoveRestServlet(RestServlet):
SUCCESS_HTML = "<html><body>You have been unsubscribed</body><html>"
def __init__(self, hs):
super(RestServlet, self).__init__()
super(PushersRemoveRestServlet, self).__init__()
self.hs = hs
self.notifier = hs.get_notifier()
self.auth = hs.get_auth()

View File

@ -18,15 +18,13 @@ import hmac
import logging
from hashlib import sha1
from six import string_types
from twisted.internet import defer
import synapse.util.stringutils as stringutils
from synapse.api.auth import get_access_token_from_request
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import parse_json_object_from_request
from synapse.http.servlet import assert_params_in_request, parse_json_object_from_request
from synapse.types import create_requester
from .base import ClientV1RestServlet, client_path_patterns
@ -124,8 +122,7 @@ class RegisterRestServlet(ClientV1RestServlet):
session = (register_json["session"]
if "session" in register_json else None)
login_type = None
if "type" not in register_json:
raise SynapseError(400, "Missing 'type' key.")
assert_params_in_request(register_json, ["type"])
try:
login_type = register_json["type"]
@ -312,9 +309,7 @@ class RegisterRestServlet(ClientV1RestServlet):
def _do_app_service(self, request, register_json, session):
as_token = get_access_token_from_request(request)
if "user" not in register_json:
raise SynapseError(400, "Expected 'user' key.")
assert_params_in_request(register_json, ["user"])
user_localpart = register_json["user"].encode("utf-8")
handler = self.handlers.registration_handler
@ -331,12 +326,7 @@ class RegisterRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def _do_shared_secret(self, request, register_json, session):
if not isinstance(register_json.get("mac", None), string_types):
raise SynapseError(400, "Expected mac.")
if not isinstance(register_json.get("user", None), string_types):
raise SynapseError(400, "Expected 'user' key.")
if not isinstance(register_json.get("password", None), string_types):
raise SynapseError(400, "Expected 'password' key.")
assert_params_in_request(register_json, ["mac", "user", "password"])
if not self.hs.config.registration_shared_secret:
raise SynapseError(400, "Shared secret registration is not enabled")
@ -419,11 +409,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def _do_create(self, requester, user_json):
if "localpart" not in user_json:
raise SynapseError(400, "Expected 'localpart' key.")
if "displayname" not in user_json:
raise SynapseError(400, "Expected 'displayname' key.")
assert_params_in_request(user_json, ["localpart", "displayname"])
localpart = user_json["localpart"].encode("utf-8")
displayname = user_json["displayname"].encode("utf-8")

View File

@ -28,6 +28,7 @@ from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.api.filtering import Filter
from synapse.events.utils import format_event_for_client_v2, serialize_event
from synapse.http.servlet import (
assert_params_in_request,
parse_integer,
parse_json_object_from_request,
parse_string,
@ -435,7 +436,7 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
request, default_limit=10,
)
as_client_event = "raw" not in request.args
filter_bytes = request.args.get("filter", None)
filter_bytes = parse_string(request, "filter")
if filter_bytes:
filter_json = urlparse.unquote(filter_bytes[-1]).decode("UTF-8")
event_filter = Filter(json.loads(filter_json))
@ -530,7 +531,7 @@ class RoomEventContextServlet(ClientV1RestServlet):
def on_GET(self, request, room_id, event_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
limit = int(request.args.get("limit", [10])[0])
limit = parse_integer(request, "limit", default=10)
results = yield self.handlers.room_context_handler.get_event_context(
requester.user,
@ -636,8 +637,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
target = requester.user
if membership_action in ["invite", "ban", "unban", "kick"]:
if "user_id" not in content:
raise SynapseError(400, "Missing user_id key.")
assert_params_in_request(content, ["user_id"])
target = UserID.from_string(content["user_id"])
event_content = None
@ -764,7 +764,7 @@ class SearchRestServlet(ClientV1RestServlet):
content = parse_json_object_from_request(request)
batch = request.args.get("next_batch", [None])[0]
batch = parse_string(request, "next_batch")
results = yield self.handlers.search_handler.search(
requester.user,
content,

View File

@ -160,11 +160,10 @@ class PasswordRestServlet(RestServlet):
raise SynapseError(404, "Email address not found", Codes.NOT_FOUND)
user_id = threepid_user_id
else:
logger.error("Auth succeeded but no known type!", result.keys())
logger.error("Auth succeeded but no known type! %r", result.keys())
raise SynapseError(500, "", Codes.UNKNOWN)
if 'new_password' not in params:
raise SynapseError(400, "", Codes.MISSING_PARAM)
assert_params_in_request(params, ["new_password"])
new_password = params['new_password']
yield self._set_password_handler.set_password(
@ -229,15 +228,10 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
required = ['id_server', 'client_secret', 'email', 'send_attempt']
absent = []
for k in required:
if k not in body:
absent.append(k)
if absent:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
assert_params_in_request(
body,
['id_server', 'client_secret', 'email', 'send_attempt'],
)
if not check_3pid_allowed(self.hs, "email", body['email']):
raise SynapseError(
@ -267,18 +261,10 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
required = [
assert_params_in_request(body, [
'id_server', 'client_secret',
'country', 'phone_number', 'send_attempt',
]
absent = []
for k in required:
if k not in body:
absent.append(k)
if absent:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
])
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
@ -373,15 +359,7 @@ class ThreepidDeleteRestServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
required = ['medium', 'address']
absent = []
for k in required:
if k not in body:
absent.append(k)
if absent:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
assert_params_in_request(body, ['medium', 'address'])
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()

View File

@ -18,14 +18,18 @@ import logging
from twisted.internet import defer
from synapse.api import errors
from synapse.http import servlet
from synapse.http.servlet import (
assert_params_in_request,
parse_json_object_from_request,
RestServlet
)
from ._base import client_v2_patterns, interactive_auth_handler
logger = logging.getLogger(__name__)
class DevicesRestServlet(servlet.RestServlet):
class DevicesRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/devices$", v2_alpha=False)
def __init__(self, hs):
@ -47,7 +51,7 @@ class DevicesRestServlet(servlet.RestServlet):
defer.returnValue((200, {"devices": devices}))
class DeleteDevicesRestServlet(servlet.RestServlet):
class DeleteDevicesRestServlet(RestServlet):
"""
API for bulk deletion of devices. Accepts a JSON object with a devices
key which lists the device_ids to delete. Requires user interactive auth.
@ -67,19 +71,17 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
requester = yield self.auth.get_user_by_req(request)
try:
body = servlet.parse_json_object_from_request(request)
body = parse_json_object_from_request(request)
except errors.SynapseError as e:
if e.errcode == errors.Codes.NOT_JSON:
# deal with older clients which didn't pass a J*DELETESON dict
# DELETE
# deal with older clients which didn't pass a JSON dict
# the same as those that pass an empty dict
body = {}
else:
raise e
if 'devices' not in body:
raise errors.SynapseError(
400, "No devices supplied", errcode=errors.Codes.MISSING_PARAM
)
assert_params_in_request(body, ["devices"])
yield self.auth_handler.validate_user_via_ui_auth(
requester, body, self.hs.get_ip_from_request(request),
@ -92,7 +94,7 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
defer.returnValue((200, {}))
class DeviceRestServlet(servlet.RestServlet):
class DeviceRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$", v2_alpha=False)
def __init__(self, hs):
@ -121,7 +123,7 @@ class DeviceRestServlet(servlet.RestServlet):
requester = yield self.auth.get_user_by_req(request)
try:
body = servlet.parse_json_object_from_request(request)
body = parse_json_object_from_request(request)
except errors.SynapseError as e:
if e.errcode == errors.Codes.NOT_JSON:
@ -144,7 +146,7 @@ class DeviceRestServlet(servlet.RestServlet):
def on_PUT(self, request, device_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
body = servlet.parse_json_object_from_request(request)
body = parse_json_object_from_request(request)
yield self.device_handler.update_device(
requester.user.to_string(),
device_id,

View File

@ -387,9 +387,7 @@ class RegisterRestServlet(RestServlet):
add_msisdn = False
else:
# NB: This may be from the auth handler and NOT from the POST
if 'password' not in params:
raise SynapseError(400, "Missing password.",
Codes.MISSING_PARAM)
assert_params_in_request(params, ["password"])
desired_username = params.get("username", None)
new_password = params.get("password", None)
@ -566,11 +564,14 @@ class RegisterRestServlet(RestServlet):
Returns:
defer.Deferred:
"""
reqd = ('medium', 'address', 'validated_at')
if any(x not in threepid for x in reqd):
# This will only happen if the ID server returns a malformed response
logger.info("Can't add incomplete 3pid")
defer.returnValue()
try:
assert_params_in_request(threepid, ['medium', 'address', 'validated_at'])
except SynapseError as ex:
if ex.errcode == Codes.MISSING_PARAM:
# This will only happen if the ID server returns a malformed response
logger.info("Can't add incomplete 3pid")
defer.returnValue(None)
raise
yield self.auth_handler.add_threepid(
user_id,

View File

@ -14,6 +14,8 @@
from pydenticon import Generator
from synapse.http.servlet import parse_integer
from twisted.web.resource import Resource
FOREGROUND = [
@ -56,8 +58,8 @@ class IdenticonResource(Resource):
def render_GET(self, request):
name = "/".join(request.postpath)
width = int(request.args.get("width", [96])[0])
height = int(request.args.get("height", [96])[0])
width = parse_integer(request, "width", default=96)
height = parse_integer(request, "height", default=96)
identicon_bytes = self.generate_identicon(name, width, height)
request.setHeader(b"Content-Type", b"image/png")
request.setHeader(

View File

@ -40,6 +40,7 @@ from synapse.http.server import (
respond_with_json_bytes,
wrap_json_request_handler,
)
from synapse.http.servlet import parse_integer, parse_string
from synapse.util.async import ObservableDeferred
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
@ -96,9 +97,9 @@ class PreviewUrlResource(Resource):
# XXX: if get_user_by_req fails, what should we do in an async render?
requester = yield self.auth.get_user_by_req(request)
url = request.args.get("url")[0]
url = parse_string(request, "url")
if "ts" in request.args:
ts = int(request.args.get("ts")[0])
ts = parse_integer(request, "ts")
else:
ts = self.clock.time_msec()

View File

@ -21,6 +21,7 @@ from twisted.web.server import NOT_DONE_YET
from synapse.api.errors import SynapseError
from synapse.http.server import respond_with_json, wrap_json_request_handler
from synapse.http.servlet import parse_string
logger = logging.getLogger(__name__)
@ -65,10 +66,10 @@ class UploadResource(Resource):
code=413,
)
upload_name = request.args.get("filename", None)
upload_name = parse_string(request, "filename")
if upload_name:
try:
upload_name = upload_name[0].decode('UTF-8')
upload_name = upload_name.decode('UTF-8')
except UnicodeDecodeError:
raise SynapseError(
msg="Invalid UTF-8 filename parameter: %r" % (upload_name),

View File

@ -16,6 +16,7 @@
import logging
from synapse.api.errors import SynapseError
from synapse.http.servlet import parse_integer, parse_string
from synapse.types import StreamToken
logger = logging.getLogger(__name__)
@ -56,23 +57,10 @@ class PaginationConfig(object):
@classmethod
def from_request(cls, request, raise_invalid_params=True,
default_limit=None):
def get_param(name, default=None):
lst = request.args.get(name, [])
if len(lst) > 1:
raise SynapseError(
400, "%s must be specified only once" % (name,)
)
elif len(lst) == 1:
return lst[0]
else:
return default
direction = parse_string(request, "dir", default='f', allowed_values=['f', 'b'])
direction = get_param("dir", 'f')
if direction not in ['f', 'b']:
raise SynapseError(400, "'dir' parameter is invalid.")
from_tok = get_param("from")
to_tok = get_param("to")
from_tok = parse_string(request, "from")
to_tok = parse_string(request, "to")
try:
if from_tok == "END":
@ -88,12 +76,7 @@ class PaginationConfig(object):
except Exception:
raise SynapseError(400, "'to' paramater is invalid")
limit = get_param("limit", None)
if limit is not None and not limit.isdigit():
raise SynapseError(400, "'limit' parameter must be an integer.")
if limit is None:
limit = default_limit
limit = parse_integer(request, "limit", default=default_limit)
try:
return PaginationConfig(from_tok, to_tok, direction, limit)