Port rest/ to Python 3 (#3823)

This commit is contained in:
Amber Brown 2018-09-12 20:41:31 +10:00 committed by GitHub
parent 8fd93b5eea
commit 02aa41809b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
18 changed files with 113 additions and 100 deletions

View file

@ -101,7 +101,7 @@ class UserRegisterServlet(ClientV1RestServlet):
nonce = self.hs.get_secrets().token_hex(64)
self.nonces[nonce] = int(self.reactor.seconds())
return (200, {"nonce": nonce.encode('ascii')})
return (200, {"nonce": nonce})
@defer.inlineCallbacks
def on_POST(self, request):
@ -164,7 +164,7 @@ class UserRegisterServlet(ClientV1RestServlet):
key=self.hs.config.registration_shared_secret.encode(),
digestmod=hashlib.sha1,
)
want_mac.update(nonce)
want_mac.update(nonce.encode('utf8'))
want_mac.update(b"\x00")
want_mac.update(username)
want_mac.update(b"\x00")
@ -173,7 +173,10 @@ class UserRegisterServlet(ClientV1RestServlet):
want_mac.update(b"admin" if admin else b"notadmin")
want_mac = want_mac.hexdigest()
if not hmac.compare_digest(want_mac, got_mac.encode('ascii')):
if not hmac.compare_digest(
want_mac.encode('ascii'),
got_mac.encode('ascii')
):
raise SynapseError(403, "HMAC incorrect")
# Reuse the parts of RegisterRestServlet to reduce code duplication

View file

@ -45,20 +45,20 @@ class EventStreamRestServlet(ClientV1RestServlet):
is_guest = requester.is_guest
room_id = None
if is_guest:
if "room_id" not in request.args:
if b"room_id" not in request.args:
raise SynapseError(400, "Guest users must specify room_id param")
if "room_id" in request.args:
room_id = request.args["room_id"][0]
if b"room_id" in request.args:
room_id = request.args[b"room_id"][0].decode('ascii')
pagin_config = PaginationConfig.from_request(request)
timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
if "timeout" in request.args:
if b"timeout" in request.args:
try:
timeout = int(request.args["timeout"][0])
timeout = int(request.args[b"timeout"][0])
except ValueError:
raise SynapseError(400, "timeout must be in milliseconds.")
as_client_event = "raw" not in request.args
as_client_event = b"raw" not in request.args
chunk = yield self.event_stream_handler.get_stream(
requester.user.to_string(),

View file

@ -32,7 +32,7 @@ class InitialSyncRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request)
as_client_event = "raw" not in request.args
as_client_event = b"raw" not in request.args
pagination_config = PaginationConfig.from_request(request)
include_archived = parse_boolean(request, "archived", default=False)
content = yield self.initial_sync_handler.snapshot_all_rooms(

View file

@ -14,10 +14,9 @@
# limitations under the License.
import logging
import urllib
import xml.etree.ElementTree as ET
from six.moves.urllib import parse as urlparse
from six.moves import urllib
from canonicaljson import json
from saml2 import BINDING_HTTP_POST, config
@ -134,7 +133,7 @@ class LoginRestServlet(ClientV1RestServlet):
LoginRestServlet.SAML2_TYPE):
relay_state = ""
if "relay_state" in login_submission:
relay_state = "&RelayState=" + urllib.quote(
relay_state = "&RelayState=" + urllib.parse.quote(
login_submission["relay_state"])
result = {
"uri": "%s%s" % (self.idp_redirect_url, relay_state)
@ -366,7 +365,7 @@ class SAML2RestServlet(ClientV1RestServlet):
(user_id, token) = yield handler.register_saml2(username)
# Forward to the RelayState callback along with ava
if 'RelayState' in request.args:
request.redirect(urllib.unquote(
request.redirect(urllib.parse.unquote(
request.args['RelayState'][0]) +
'?status=authenticated&access_token=' +
token + '&user_id=' + user_id + '&ava=' +
@ -377,7 +376,7 @@ class SAML2RestServlet(ClientV1RestServlet):
"user_id": user_id, "token": token,
"ava": saml2_auth.ava}))
elif 'RelayState' in request.args:
request.redirect(urllib.unquote(
request.redirect(urllib.parse.unquote(
request.args['RelayState'][0]) +
'?status=not_authenticated')
finish_request(request)
@ -390,21 +389,22 @@ class CasRedirectServlet(ClientV1RestServlet):
def __init__(self, hs):
super(CasRedirectServlet, self).__init__(hs)
self.cas_server_url = hs.config.cas_server_url
self.cas_service_url = hs.config.cas_service_url
self.cas_server_url = hs.config.cas_server_url.encode('ascii')
self.cas_service_url = hs.config.cas_service_url.encode('ascii')
def on_GET(self, request):
args = request.args
if "redirectUrl" not in args:
if b"redirectUrl" not in args:
return (400, "Redirect URL not specified for CAS auth")
client_redirect_url_param = urllib.urlencode({
"redirectUrl": args["redirectUrl"][0]
})
hs_redirect_url = self.cas_service_url + "/_matrix/client/api/v1/login/cas/ticket"
service_param = urllib.urlencode({
"service": "%s?%s" % (hs_redirect_url, client_redirect_url_param)
})
request.redirect("%s/login?%s" % (self.cas_server_url, service_param))
client_redirect_url_param = urllib.parse.urlencode({
b"redirectUrl": args[b"redirectUrl"][0]
}).encode('ascii')
hs_redirect_url = (self.cas_service_url +
b"/_matrix/client/api/v1/login/cas/ticket")
service_param = urllib.parse.urlencode({
b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)
}).encode('ascii')
request.redirect(b"%s/login?%s" % (self.cas_server_url, service_param))
finish_request(request)
@ -422,11 +422,11 @@ class CasTicketServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
client_redirect_url = request.args["redirectUrl"][0]
client_redirect_url = request.args[b"redirectUrl"][0]
http_client = self.hs.get_simple_http_client()
uri = self.cas_server_url + "/proxyValidate"
args = {
"ticket": request.args["ticket"],
"ticket": request.args[b"ticket"][0].decode('ascii'),
"service": self.cas_service_url
}
try:
@ -471,11 +471,11 @@ class CasTicketServlet(ClientV1RestServlet):
finish_request(request)
def add_login_token_to_redirect_url(self, url, token):
url_parts = list(urlparse.urlparse(url))
query = dict(urlparse.parse_qsl(url_parts[4]))
url_parts = list(urllib.parse.urlparse(url))
query = dict(urllib.parse.parse_qsl(url_parts[4]))
query.update({"loginToken": token})
url_parts[4] = urllib.urlencode(query)
return urlparse.urlunparse(url_parts)
url_parts[4] = urllib.parse.urlencode(query).encode('ascii')
return urllib.parse.urlunparse(url_parts)
def parse_cas_response(self, cas_response_body):
user = None

View file

@ -46,7 +46,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
try:
priority_class = _priority_class_from_spec(spec)
except InvalidRuleException as e:
raise SynapseError(400, e.message)
raise SynapseError(400, str(e))
requester = yield self.auth.get_user_by_req(request)
@ -73,7 +73,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
content,
)
except InvalidRuleException as e:
raise SynapseError(400, e.message)
raise SynapseError(400, str(e))
before = parse_string(request, "before")
if before:
@ -95,9 +95,9 @@ class PushRuleRestServlet(ClientV1RestServlet):
)
self.notify_user(user_id)
except InconsistentRuleException as e:
raise SynapseError(400, e.message)
raise SynapseError(400, str(e))
except RuleNotFoundException as e:
raise SynapseError(400, e.message)
raise SynapseError(400, str(e))
defer.returnValue((200, {}))
@ -142,10 +142,10 @@ class PushRuleRestServlet(ClientV1RestServlet):
PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
)
if path[0] == '':
if path[0] == b'':
defer.returnValue((200, rules))
elif path[0] == 'global':
path = path[1:]
elif path[0] == b'global':
path = [x.decode('ascii') for x in path[1:]]
result = _filter_ruleset_with_path(rules['global'], path)
defer.returnValue((200, result))
else:
@ -192,10 +192,10 @@ class PushRuleRestServlet(ClientV1RestServlet):
def _rule_spec_from_path(path):
if len(path) < 2:
raise UnrecognizedRequestError()
if path[0] != 'pushrules':
if path[0] != b'pushrules':
raise UnrecognizedRequestError()
scope = path[1]
scope = path[1].decode('ascii')
path = path[2:]
if scope != 'global':
raise UnrecognizedRequestError()
@ -203,13 +203,13 @@ def _rule_spec_from_path(path):
if len(path) == 0:
raise UnrecognizedRequestError()
template = path[0]
template = path[0].decode('ascii')
path = path[1:]
if len(path) == 0 or len(path[0]) == 0:
raise UnrecognizedRequestError()
rule_id = path[0]
rule_id = path[0].decode('ascii')
spec = {
'scope': scope,
@ -220,7 +220,7 @@ def _rule_spec_from_path(path):
path = path[1:]
if len(path) > 0 and len(path[0]) > 0:
spec['attr'] = path[0]
spec['attr'] = path[0].decode('ascii')
return spec

View file

@ -59,7 +59,7 @@ class PushersRestServlet(ClientV1RestServlet):
]
for p in pushers:
for k, v in p.items():
for k, v in list(p.items()):
if k not in allowed_keys:
del p[k]
@ -126,7 +126,7 @@ class PushersSetRestServlet(ClientV1RestServlet):
profile_tag=content.get('profile_tag', ""),
)
except PusherConfigException as pce:
raise SynapseError(400, "Config Error: " + pce.message,
raise SynapseError(400, "Config Error: " + str(pce),
errcode=Codes.MISSING_PARAM)
self.notifier.on_new_replication_data()

View file

@ -207,7 +207,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
"sender": requester.user.to_string(),
}
if 'ts' in request.args and requester.app_service:
if b'ts' in request.args and requester.app_service:
event_dict['origin_server_ts'] = parse_integer(request, "ts", 0)
event = yield self.event_creation_hander.create_and_send_nonmember_event(
@ -255,7 +255,9 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
if RoomID.is_valid(room_identifier):
room_id = room_identifier
try:
remote_room_hosts = request.args["server_name"]
remote_room_hosts = [
x.decode('ascii') for x in request.args[b"server_name"]
]
except Exception:
remote_room_hosts = None
elif RoomAlias.is_valid(room_identifier):
@ -461,10 +463,10 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
pagination_config = PaginationConfig.from_request(
request, default_limit=10,
)
as_client_event = "raw" not in request.args
filter_bytes = parse_string(request, "filter")
as_client_event = b"raw" not in request.args
filter_bytes = parse_string(request, b"filter", encoding=None)
if filter_bytes:
filter_json = urlparse.unquote(filter_bytes).decode("UTF-8")
filter_json = urlparse.unquote(filter_bytes.decode("UTF-8"))
event_filter = Filter(json.loads(filter_json))
else:
event_filter = None
@ -560,7 +562,7 @@ class RoomEventContextServlet(ClientV1RestServlet):
# picking the API shape for symmetry with /messages
filter_bytes = parse_string(request, "filter")
if filter_bytes:
filter_json = urlparse.unquote(filter_bytes).decode("UTF-8")
filter_json = urlparse.unquote(filter_bytes)
event_filter = Filter(json.loads(filter_json))
else:
event_filter = None

View file

@ -89,7 +89,7 @@ class SyncRestServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
if "from" in request.args:
if b"from" in request.args:
# /events used to use 'from', but /sync uses 'since'.
# Lets be helpful and whine if we see a 'from'.
raise SynapseError(

View file

@ -79,7 +79,7 @@ class ThirdPartyUserServlet(RestServlet):
yield self.auth.get_user_by_req(request, allow_guest=True)
fields = request.args
fields.pop("access_token", None)
fields.pop(b"access_token", None)
results = yield self.appservice_handler.query_3pe(
ThirdPartyEntityKind.USER, protocol, fields
@ -102,7 +102,7 @@ class ThirdPartyLocationServlet(RestServlet):
yield self.auth.get_user_by_req(request, allow_guest=True)
fields = request.args
fields.pop("access_token", None)
fields.pop(b"access_token", None)
results = yield self.appservice_handler.query_3pe(
ThirdPartyEntityKind.LOCATION, protocol, fields