Merge pull request #3534 from krombel/use_parse_and_asserts_from_servlet

Use parse and asserts from http.servlet
This commit is contained in:
Amber Brown 2018-07-14 09:09:19 +10:00 committed by GitHub
commit 8532953c04
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 101 additions and 164 deletions

1
changelog.d/3534.misc Normal file
View File

@ -0,0 +1 @@
refactor: use parse_{string,integer} and assert's from http.servlet for deduplication

View File

@ -23,7 +23,7 @@ from synapse.api.errors import Codes, SynapseError
from synapse.crypto.event_signing import check_event_content_hash from synapse.crypto.event_signing import check_event_content_hash
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
from synapse.events.utils import prune_event from synapse.events.utils import prune_event
from synapse.http.servlet import assert_params_in_request from synapse.http.servlet import assert_params_in_dict
from synapse.util import logcontext, unwrapFirstError from synapse.util import logcontext, unwrapFirstError
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -199,7 +199,7 @@ def event_from_pdu_json(pdu_json, outlier=False):
""" """
# we could probably enforce a bunch of other fields here (room_id, sender, # we could probably enforce a bunch of other fields here (room_id, sender,
# origin, etc etc) # origin, etc etc)
assert_params_in_request(pdu_json, ('event_id', 'type', 'depth')) assert_params_in_dict(pdu_json, ('event_id', 'type', 'depth'))
depth = pdu_json['depth'] depth = pdu_json['depth']
if not isinstance(depth, six.integer_types): if not isinstance(depth, six.integer_types):

View File

@ -206,7 +206,7 @@ def parse_json_object_from_request(request, allow_empty_body=False):
return content return content
def assert_params_in_request(body, required): def assert_params_in_dict(body, required):
absent = [] absent = []
for k in required: for k in required:
if k not in body: if k not in body:

View File

@ -22,7 +22,12 @@ from twisted.internet import defer
from synapse.api.constants import Membership from synapse.api.constants import Membership
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import (
assert_params_in_dict,
parse_json_object_from_request,
parse_integer,
parse_string
)
from synapse.types import UserID, create_requester from synapse.types import UserID, create_requester
from .base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
@ -98,16 +103,8 @@ class PurgeMediaCacheRestServlet(ClientV1RestServlet):
if not is_admin: if not is_admin:
raise AuthError(403, "You are not a server admin") raise AuthError(403, "You are not a server admin")
before_ts = request.args.get("before_ts", None) before_ts = parse_integer(request, "before_ts", required=True)
if not before_ts: logger.info("before_ts: %r", before_ts)
raise SynapseError(400, "Missing 'before_ts' arg")
logger.info("before_ts: %r", before_ts[0])
try:
before_ts = int(before_ts[0])
except Exception:
raise SynapseError(400, "Invalid 'before_ts' arg")
ret = yield self.media_repository.delete_old_remote_media(before_ts) ret = yield self.media_repository.delete_old_remote_media(before_ts)
@ -300,10 +297,8 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
raise AuthError(403, "You are not a server admin") raise AuthError(403, "You are not a server admin")
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
assert_params_in_dict(content, ["new_room_user_id"])
new_room_user_id = content.get("new_room_user_id") new_room_user_id = content["new_room_user_id"]
if not new_room_user_id:
raise SynapseError(400, "Please provide field `new_room_user_id`")
room_creator_requester = create_requester(new_room_user_id) room_creator_requester = create_requester(new_room_user_id)
@ -464,9 +459,8 @@ class ResetPasswordRestServlet(ClientV1RestServlet):
raise AuthError(403, "You are not a server admin") raise AuthError(403, "You are not a server admin")
params = parse_json_object_from_request(request) params = parse_json_object_from_request(request)
assert_params_in_dict(params, ["new_password"])
new_password = params['new_password'] new_password = params['new_password']
if not new_password:
raise SynapseError(400, "Missing 'new_password' arg")
logger.info("new_password: %r", new_password) logger.info("new_password: %r", new_password)
@ -514,12 +508,9 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet):
raise SynapseError(400, "Can only users a local user") raise SynapseError(400, "Can only users a local user")
order = "name" # order by name in user table order = "name" # order by name in user table
start = request.args.get("start")[0] start = parse_integer(request, "start", required=True)
limit = request.args.get("limit")[0] limit = parse_integer(request, "limit", required=True)
if not limit:
raise SynapseError(400, "Missing 'limit' arg")
if not start:
raise SynapseError(400, "Missing 'start' arg")
logger.info("limit: %s, start: %s", limit, start) logger.info("limit: %s, start: %s", limit, start)
ret = yield self.handlers.admin_handler.get_users_paginate( ret = yield self.handlers.admin_handler.get_users_paginate(
@ -551,12 +542,9 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet):
order = "name" # order by name in user table order = "name" # order by name in user table
params = parse_json_object_from_request(request) params = parse_json_object_from_request(request)
assert_params_in_dict(params, ["limit", "start"])
limit = params['limit'] limit = params['limit']
start = params['start'] start = params['start']
if not limit:
raise SynapseError(400, "Missing 'limit' arg")
if not start:
raise SynapseError(400, "Missing 'start' arg")
logger.info("limit: %s, start: %s", limit, start) logger.info("limit: %s, start: %s", limit, start)
ret = yield self.handlers.admin_handler.get_users_paginate( ret = yield self.handlers.admin_handler.get_users_paginate(
@ -604,10 +592,7 @@ class SearchUsersRestServlet(ClientV1RestServlet):
if not self.hs.is_mine(target_user): if not self.hs.is_mine(target_user):
raise SynapseError(400, "Can only users a local user") raise SynapseError(400, "Can only users a local user")
term = request.args.get("term")[0] term = parse_string(request, "term", required=True)
if not term:
raise SynapseError(400, "Missing 'term' arg")
logger.info("term: %s ", term) logger.info("term: %s ", term)
ret = yield self.handlers.admin_handler.search_users( ret = yield self.handlers.admin_handler.search_users(

View File

@ -52,15 +52,14 @@ class ClientDirectoryServer(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, room_alias): def on_PUT(self, request, room_alias):
room_alias = RoomAlias.from_string(room_alias)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
if "room_id" not in content: if "room_id" not in content:
raise SynapseError(400, "Missing room_id key", raise SynapseError(400, 'Missing params: ["room_id"]',
errcode=Codes.BAD_JSON) errcode=Codes.BAD_JSON)
logger.debug("Got content: %s", content) logger.debug("Got content: %s", content)
room_alias = RoomAlias.from_string(room_alias)
logger.debug("Got room name: %s", room_alias.to_string()) logger.debug("Got room name: %s", room_alias.to_string())
room_id = content["room_id"] room_id = content["room_id"]

View File

@ -16,6 +16,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.http.servlet import parse_boolean
from .base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
@ -33,7 +34,7 @@ class InitialSyncRestServlet(ClientV1RestServlet):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
as_client_event = "raw" not in request.args as_client_event = "raw" not in request.args
pagination_config = PaginationConfig.from_request(request) pagination_config = PaginationConfig.from_request(request)
include_archived = request.args.get("archived", None) == ["true"] include_archived = parse_boolean(request, "archived", default=False)
content = yield self.initial_sync_handler.snapshot_all_rooms( content = yield self.initial_sync_handler.snapshot_all_rooms(
user_id=requester.user.to_string(), user_id=requester.user.to_string(),
pagin_config=pagination_config, pagin_config=pagination_config,

View File

@ -21,7 +21,7 @@ from synapse.api.errors import (
SynapseError, SynapseError,
UnrecognizedRequestError, UnrecognizedRequestError,
) )
from synapse.http.servlet import parse_json_value_from_request from synapse.http.servlet import parse_json_value_from_request, parse_string
from synapse.push.baserules import BASE_RULE_IDS from synapse.push.baserules import BASE_RULE_IDS
from synapse.push.clientformat import format_push_rules_for_user from synapse.push.clientformat import format_push_rules_for_user
from synapse.push.rulekinds import PRIORITY_CLASS_MAP from synapse.push.rulekinds import PRIORITY_CLASS_MAP
@ -75,11 +75,11 @@ class PushRuleRestServlet(ClientV1RestServlet):
except InvalidRuleException as e: except InvalidRuleException as e:
raise SynapseError(400, e.message) raise SynapseError(400, e.message)
before = request.args.get("before", None) before = parse_string(request, "before")
if before: if before:
before = _namespaced_rule_id(spec, before[0]) before = _namespaced_rule_id(spec, before[0])
after = request.args.get("after", None) after = parse_string(request, "after")
if after: if after:
after = _namespaced_rule_id(spec, after[0]) after = _namespaced_rule_id(spec, after[0])

View File

@ -21,6 +21,7 @@ from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.http.server import finish_request from synapse.http.server import finish_request
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, RestServlet,
assert_params_in_dict,
parse_json_object_from_request, parse_json_object_from_request,
parse_string, parse_string,
) )
@ -91,15 +92,11 @@ class PushersSetRestServlet(ClientV1RestServlet):
) )
defer.returnValue((200, {})) defer.returnValue((200, {}))
reqd = ['kind', 'app_id', 'app_display_name', assert_params_in_dict(
'device_display_name', 'pushkey', 'lang', 'data'] content,
missing = [] ['kind', 'app_id', 'app_display_name',
for i in reqd: 'device_display_name', 'pushkey', 'lang', 'data']
if i not in content: )
missing.append(i)
if len(missing):
raise SynapseError(400, "Missing parameters: " + ','.join(missing),
errcode=Codes.MISSING_PARAM)
logger.debug("set pushkey %s to kind %s", content['pushkey'], content['kind']) logger.debug("set pushkey %s to kind %s", content['pushkey'], content['kind'])
logger.debug("Got pushers request with body: %r", content) logger.debug("Got pushers request with body: %r", content)
@ -148,7 +145,7 @@ class PushersRemoveRestServlet(RestServlet):
SUCCESS_HTML = "<html><body>You have been unsubscribed</body><html>" SUCCESS_HTML = "<html><body>You have been unsubscribed</body><html>"
def __init__(self, hs): def __init__(self, hs):
super(RestServlet, self).__init__() super(PushersRemoveRestServlet, self).__init__()
self.hs = hs self.hs = hs
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.auth = hs.get_auth() self.auth = hs.get_auth()

View File

@ -18,14 +18,12 @@ import hmac
import logging import logging
from hashlib import sha1 from hashlib import sha1
from six import string_types
from twisted.internet import defer from twisted.internet import defer
import synapse.util.stringutils as stringutils import synapse.util.stringutils as stringutils
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request
from synapse.types import create_requester from synapse.types import create_requester
from .base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
@ -124,8 +122,7 @@ class RegisterRestServlet(ClientV1RestServlet):
session = (register_json["session"] session = (register_json["session"]
if "session" in register_json else None) if "session" in register_json else None)
login_type = None login_type = None
if "type" not in register_json: assert_params_in_dict(register_json, ["type"])
raise SynapseError(400, "Missing 'type' key.")
try: try:
login_type = register_json["type"] login_type = register_json["type"]
@ -312,9 +309,7 @@ class RegisterRestServlet(ClientV1RestServlet):
def _do_app_service(self, request, register_json, session): def _do_app_service(self, request, register_json, session):
as_token = self.auth.get_access_token_from_request(request) as_token = self.auth.get_access_token_from_request(request)
if "user" not in register_json: assert_params_in_dict(register_json, ["user"])
raise SynapseError(400, "Expected 'user' key.")
user_localpart = register_json["user"].encode("utf-8") user_localpart = register_json["user"].encode("utf-8")
handler = self.handlers.registration_handler handler = self.handlers.registration_handler
@ -331,12 +326,7 @@ class RegisterRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def _do_shared_secret(self, request, register_json, session): def _do_shared_secret(self, request, register_json, session):
if not isinstance(register_json.get("mac", None), string_types): assert_params_in_dict(register_json, ["mac", "user", "password"])
raise SynapseError(400, "Expected mac.")
if not isinstance(register_json.get("user", None), string_types):
raise SynapseError(400, "Expected 'user' key.")
if not isinstance(register_json.get("password", None), string_types):
raise SynapseError(400, "Expected 'password' key.")
if not self.hs.config.registration_shared_secret: if not self.hs.config.registration_shared_secret:
raise SynapseError(400, "Shared secret registration is not enabled") raise SynapseError(400, "Shared secret registration is not enabled")
@ -419,11 +409,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def _do_create(self, requester, user_json): def _do_create(self, requester, user_json):
if "localpart" not in user_json: assert_params_in_dict(user_json, ["localpart", "displayname"])
raise SynapseError(400, "Expected 'localpart' key.")
if "displayname" not in user_json:
raise SynapseError(400, "Expected 'displayname' key.")
localpart = user_json["localpart"].encode("utf-8") localpart = user_json["localpart"].encode("utf-8")
displayname = user_json["displayname"].encode("utf-8") displayname = user_json["displayname"].encode("utf-8")

View File

@ -28,6 +28,7 @@ from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.api.filtering import Filter from synapse.api.filtering import Filter
from synapse.events.utils import format_event_for_client_v2, serialize_event from synapse.events.utils import format_event_for_client_v2, serialize_event
from synapse.http.servlet import ( from synapse.http.servlet import (
assert_params_in_dict,
parse_integer, parse_integer,
parse_json_object_from_request, parse_json_object_from_request,
parse_string, parse_string,
@ -435,9 +436,9 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
request, default_limit=10, request, default_limit=10,
) )
as_client_event = "raw" not in request.args as_client_event = "raw" not in request.args
filter_bytes = request.args.get("filter", None) filter_bytes = parse_string(request, "filter")
if filter_bytes: if filter_bytes:
filter_json = urlparse.unquote(filter_bytes[-1]).decode("UTF-8") filter_json = urlparse.unquote(filter_bytes).decode("UTF-8")
event_filter = Filter(json.loads(filter_json)) event_filter = Filter(json.loads(filter_json))
else: else:
event_filter = None event_filter = None
@ -530,7 +531,7 @@ class RoomEventContextServlet(ClientV1RestServlet):
def on_GET(self, request, room_id, event_id): def on_GET(self, request, room_id, event_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True) requester = yield self.auth.get_user_by_req(request, allow_guest=True)
limit = int(request.args.get("limit", [10])[0]) limit = parse_integer(request, "limit", default=10)
results = yield self.handlers.room_context_handler.get_event_context( results = yield self.handlers.room_context_handler.get_event_context(
requester.user, requester.user,
@ -636,8 +637,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
target = requester.user target = requester.user
if membership_action in ["invite", "ban", "unban", "kick"]: if membership_action in ["invite", "ban", "unban", "kick"]:
if "user_id" not in content: assert_params_in_dict(content, ["user_id"])
raise SynapseError(400, "Missing user_id key.")
target = UserID.from_string(content["user_id"]) target = UserID.from_string(content["user_id"])
event_content = None event_content = None
@ -764,7 +764,7 @@ class SearchRestServlet(ClientV1RestServlet):
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
batch = request.args.get("next_batch", [None])[0] batch = parse_string(request, "next_batch")
results = yield self.handlers.search_handler.search( results = yield self.handlers.search_handler.search(
requester.user, requester.user,
content, content,

View File

@ -24,7 +24,7 @@ from synapse.api.constants import LoginType
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, RestServlet,
assert_params_in_request, assert_params_in_dict,
parse_json_object_from_request, parse_json_object_from_request,
) )
from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.msisdn import phone_number_to_msisdn
@ -47,7 +47,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
def on_POST(self, request): def on_POST(self, request):
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
assert_params_in_request(body, [ assert_params_in_dict(body, [
'id_server', 'client_secret', 'email', 'send_attempt' 'id_server', 'client_secret', 'email', 'send_attempt'
]) ])
@ -80,7 +80,7 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet):
def on_POST(self, request): def on_POST(self, request):
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
assert_params_in_request(body, [ assert_params_in_dict(body, [
'id_server', 'client_secret', 'id_server', 'client_secret',
'country', 'phone_number', 'send_attempt', 'country', 'phone_number', 'send_attempt',
]) ])
@ -159,11 +159,10 @@ class PasswordRestServlet(RestServlet):
raise SynapseError(404, "Email address not found", Codes.NOT_FOUND) raise SynapseError(404, "Email address not found", Codes.NOT_FOUND)
user_id = threepid_user_id user_id = threepid_user_id
else: else:
logger.error("Auth succeeded but no known type!", result.keys()) logger.error("Auth succeeded but no known type! %r", result.keys())
raise SynapseError(500, "", Codes.UNKNOWN) raise SynapseError(500, "", Codes.UNKNOWN)
if 'new_password' not in params: assert_params_in_dict(params, ["new_password"])
raise SynapseError(400, "", Codes.MISSING_PARAM)
new_password = params['new_password'] new_password = params['new_password']
yield self._set_password_handler.set_password( yield self._set_password_handler.set_password(
@ -228,15 +227,10 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
assert_params_in_dict(
required = ['id_server', 'client_secret', 'email', 'send_attempt'] body,
absent = [] ['id_server', 'client_secret', 'email', 'send_attempt'],
for k in required: )
if k not in body:
absent.append(k)
if absent:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
if not check_3pid_allowed(self.hs, "email", body['email']): if not check_3pid_allowed(self.hs, "email", body['email']):
raise SynapseError( raise SynapseError(
@ -266,18 +260,10 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
assert_params_in_dict(body, [
required = [
'id_server', 'client_secret', 'id_server', 'client_secret',
'country', 'phone_number', 'send_attempt', 'country', 'phone_number', 'send_attempt',
] ])
absent = []
for k in required:
if k not in body:
absent.append(k)
if absent:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
msisdn = phone_number_to_msisdn(body['country'], body['phone_number']) msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
@ -372,15 +358,7 @@ class ThreepidDeleteRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
assert_params_in_dict(body, ['medium', 'address'])
required = ['medium', 'address']
absent = []
for k in required:
if k not in body:
absent.append(k)
if absent:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string() user_id = requester.user.to_string()

View File

@ -18,14 +18,18 @@ import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.api import errors from synapse.api import errors
from synapse.http import servlet from synapse.http.servlet import (
assert_params_in_dict,
parse_json_object_from_request,
RestServlet
)
from ._base import client_v2_patterns, interactive_auth_handler from ._base import client_v2_patterns, interactive_auth_handler
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DevicesRestServlet(servlet.RestServlet): class DevicesRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/devices$", v2_alpha=False) PATTERNS = client_v2_patterns("/devices$", v2_alpha=False)
def __init__(self, hs): def __init__(self, hs):
@ -47,7 +51,7 @@ class DevicesRestServlet(servlet.RestServlet):
defer.returnValue((200, {"devices": devices})) defer.returnValue((200, {"devices": devices}))
class DeleteDevicesRestServlet(servlet.RestServlet): class DeleteDevicesRestServlet(RestServlet):
""" """
API for bulk deletion of devices. Accepts a JSON object with a devices API for bulk deletion of devices. Accepts a JSON object with a devices
key which lists the device_ids to delete. Requires user interactive auth. key which lists the device_ids to delete. Requires user interactive auth.
@ -67,19 +71,17 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
try: try:
body = servlet.parse_json_object_from_request(request) body = parse_json_object_from_request(request)
except errors.SynapseError as e: except errors.SynapseError as e:
if e.errcode == errors.Codes.NOT_JSON: if e.errcode == errors.Codes.NOT_JSON:
# deal with older clients which didn't pass a J*DELETESON dict # DELETE
# deal with older clients which didn't pass a JSON dict
# the same as those that pass an empty dict # the same as those that pass an empty dict
body = {} body = {}
else: else:
raise e raise e
if 'devices' not in body: assert_params_in_dict(body, ["devices"])
raise errors.SynapseError(
400, "No devices supplied", errcode=errors.Codes.MISSING_PARAM
)
yield self.auth_handler.validate_user_via_ui_auth( yield self.auth_handler.validate_user_via_ui_auth(
requester, body, self.hs.get_ip_from_request(request), requester, body, self.hs.get_ip_from_request(request),
@ -92,7 +94,7 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
defer.returnValue((200, {})) defer.returnValue((200, {}))
class DeviceRestServlet(servlet.RestServlet): class DeviceRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$", v2_alpha=False) PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$", v2_alpha=False)
def __init__(self, hs): def __init__(self, hs):
@ -121,7 +123,7 @@ class DeviceRestServlet(servlet.RestServlet):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
try: try:
body = servlet.parse_json_object_from_request(request) body = parse_json_object_from_request(request)
except errors.SynapseError as e: except errors.SynapseError as e:
if e.errcode == errors.Codes.NOT_JSON: if e.errcode == errors.Codes.NOT_JSON:
@ -144,7 +146,7 @@ class DeviceRestServlet(servlet.RestServlet):
def on_PUT(self, request, device_id): def on_PUT(self, request, device_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True) requester = yield self.auth.get_user_by_req(request, allow_guest=True)
body = servlet.parse_json_object_from_request(request) body = parse_json_object_from_request(request)
yield self.device_handler.update_device( yield self.device_handler.update_device(
requester.user.to_string(), requester.user.to_string(),
device_id, device_id,

View File

@ -28,7 +28,7 @@ from synapse.api.constants import LoginType
from synapse.api.errors import Codes, SynapseError, UnrecognizedRequestError from synapse.api.errors import Codes, SynapseError, UnrecognizedRequestError
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, RestServlet,
assert_params_in_request, assert_params_in_dict,
parse_json_object_from_request, parse_json_object_from_request,
parse_string, parse_string,
) )
@ -68,7 +68,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
def on_POST(self, request): def on_POST(self, request):
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
assert_params_in_request(body, [ assert_params_in_dict(body, [
'id_server', 'client_secret', 'email', 'send_attempt' 'id_server', 'client_secret', 'email', 'send_attempt'
]) ])
@ -104,7 +104,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
def on_POST(self, request): def on_POST(self, request):
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
assert_params_in_request(body, [ assert_params_in_dict(body, [
'id_server', 'client_secret', 'id_server', 'client_secret',
'country', 'phone_number', 'country', 'phone_number',
'send_attempt', 'send_attempt',
@ -386,9 +386,7 @@ class RegisterRestServlet(RestServlet):
add_msisdn = False add_msisdn = False
else: else:
# NB: This may be from the auth handler and NOT from the POST # NB: This may be from the auth handler and NOT from the POST
if 'password' not in params: assert_params_in_dict(params, ["password"])
raise SynapseError(400, "Missing password.",
Codes.MISSING_PARAM)
desired_username = params.get("username", None) desired_username = params.get("username", None)
new_password = params.get("password", None) new_password = params.get("password", None)
@ -565,11 +563,14 @@ class RegisterRestServlet(RestServlet):
Returns: Returns:
defer.Deferred: defer.Deferred:
""" """
reqd = ('medium', 'address', 'validated_at') try:
if any(x not in threepid for x in reqd): assert_params_in_dict(threepid, ['medium', 'address', 'validated_at'])
# This will only happen if the ID server returns a malformed response except SynapseError as ex:
logger.info("Can't add incomplete 3pid") if ex.errcode == Codes.MISSING_PARAM:
defer.returnValue() # This will only happen if the ID server returns a malformed response
logger.info("Can't add incomplete 3pid")
defer.returnValue(None)
raise
yield self.auth_handler.add_threepid( yield self.auth_handler.add_threepid(
user_id, user_id,

View File

@ -23,7 +23,7 @@ from twisted.internet import defer
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, RestServlet,
assert_params_in_request, assert_params_in_dict,
parse_json_object_from_request, parse_json_object_from_request,
) )
@ -50,7 +50,7 @@ class ReportEventRestServlet(RestServlet):
user_id = requester.user.to_string() user_id = requester.user.to_string()
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
assert_params_in_request(body, ("reason", "score")) assert_params_in_dict(body, ("reason", "score"))
if not isinstance(body["reason"], string_types): if not isinstance(body["reason"], string_types):
raise SynapseError( raise SynapseError(

View File

@ -14,6 +14,8 @@
from pydenticon import Generator from pydenticon import Generator
from synapse.http.servlet import parse_integer
from twisted.web.resource import Resource from twisted.web.resource import Resource
FOREGROUND = [ FOREGROUND = [
@ -56,8 +58,8 @@ class IdenticonResource(Resource):
def render_GET(self, request): def render_GET(self, request):
name = "/".join(request.postpath) name = "/".join(request.postpath)
width = int(request.args.get("width", [96])[0]) width = parse_integer(request, "width", default=96)
height = int(request.args.get("height", [96])[0]) height = parse_integer(request, "height", default=96)
identicon_bytes = self.generate_identicon(name, width, height) identicon_bytes = self.generate_identicon(name, width, height)
request.setHeader(b"Content-Type", b"image/png") request.setHeader(b"Content-Type", b"image/png")
request.setHeader( request.setHeader(

View File

@ -40,6 +40,7 @@ from synapse.http.server import (
respond_with_json_bytes, respond_with_json_bytes,
wrap_json_request_handler, wrap_json_request_handler,
) )
from synapse.http.servlet import parse_integer, parse_string
from synapse.util.async import ObservableDeferred from synapse.util.async import ObservableDeferred
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.util.logcontext import make_deferred_yieldable, run_in_background
@ -96,9 +97,9 @@ class PreviewUrlResource(Resource):
# XXX: if get_user_by_req fails, what should we do in an async render? # XXX: if get_user_by_req fails, what should we do in an async render?
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
url = request.args.get("url")[0] url = parse_string(request, "url")
if "ts" in request.args: if "ts" in request.args:
ts = int(request.args.get("ts")[0]) ts = parse_integer(request, "ts")
else: else:
ts = self.clock.time_msec() ts = self.clock.time_msec()

View File

@ -21,6 +21,7 @@ from twisted.web.server import NOT_DONE_YET
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.http.server import respond_with_json, wrap_json_request_handler from synapse.http.server import respond_with_json, wrap_json_request_handler
from synapse.http.servlet import parse_string
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -65,10 +66,10 @@ class UploadResource(Resource):
code=413, code=413,
) )
upload_name = request.args.get("filename", None) upload_name = parse_string(request, "filename")
if upload_name: if upload_name:
try: try:
upload_name = upload_name[0].decode('UTF-8') upload_name = upload_name.decode('UTF-8')
except UnicodeDecodeError: except UnicodeDecodeError:
raise SynapseError( raise SynapseError(
msg="Invalid UTF-8 filename parameter: %r" % (upload_name), msg="Invalid UTF-8 filename parameter: %r" % (upload_name),

View File

@ -16,6 +16,7 @@
import logging import logging
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.http.servlet import parse_integer, parse_string
from synapse.types import StreamToken from synapse.types import StreamToken
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -56,23 +57,10 @@ class PaginationConfig(object):
@classmethod @classmethod
def from_request(cls, request, raise_invalid_params=True, def from_request(cls, request, raise_invalid_params=True,
default_limit=None): default_limit=None):
def get_param(name, default=None): direction = parse_string(request, "dir", default='f', allowed_values=['f', 'b'])
lst = request.args.get(name, [])
if len(lst) > 1:
raise SynapseError(
400, "%s must be specified only once" % (name,)
)
elif len(lst) == 1:
return lst[0]
else:
return default
direction = get_param("dir", 'f') from_tok = parse_string(request, "from")
if direction not in ['f', 'b']: to_tok = parse_string(request, "to")
raise SynapseError(400, "'dir' parameter is invalid.")
from_tok = get_param("from")
to_tok = get_param("to")
try: try:
if from_tok == "END": if from_tok == "END":
@ -88,12 +76,7 @@ class PaginationConfig(object):
except Exception: except Exception:
raise SynapseError(400, "'to' paramater is invalid") raise SynapseError(400, "'to' paramater is invalid")
limit = get_param("limit", None) limit = parse_integer(request, "limit", default=default_limit)
if limit is not None and not limit.isdigit():
raise SynapseError(400, "'limit' parameter must be an integer.")
if limit is None:
limit = default_limit
try: try:
return PaginationConfig(from_tok, to_tok, direction, limit) return PaginationConfig(from_tok, to_tok, direction, limit)