Separate out requestTokens to separate handlers

This commit is contained in:
David Baker 2016-07-11 09:57:07 +01:00
parent 9c491366c5
commit a5db0026ed
2 changed files with 93 additions and 65 deletions

View File

@ -28,24 +28,54 @@ import logging
logger = logging.getLogger(__name__)
class PasswordRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/password/email/requestToken$")
def __init__(self, hs):
super(PasswordRequestTokenRestServlet, self).__init__()
self.hs = hs
self.identity_handler = hs.get_handlers().identity_handler
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
required = ['id_server', 'client_secret', 'email', 'send_attempt']
absent = []
for k in required:
if k not in body:
absent.append(k)
if len(absent) > 0:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
'email', body['email']
)
if existingUid is None:
raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
ret = yield self.identity_handler.requestEmailToken(**body)
defer.returnValue((200, ret))
def on_OPTIONS(self, _):
return 200, {}
class PasswordRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/password")
PATTERNS = client_v2_patterns("/account/password$")
def __init__(self, hs):
super(PasswordRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
self.identity_handler = hs.get_handlers().identity_handler
@defer.inlineCallbacks
def on_POST(self, request):
yield run_on_reactor()
if '/account/password/email/requestToken' in request.path:
ret = yield self.onPasswordEmailTokenRequest(request)
defer.returnValue(ret)
body = parse_json_object_from_request(request)
authed, result, params, _ = yield self.auth_handler.check_auth([
@ -90,8 +120,20 @@ class PasswordRestServlet(RestServlet):
defer.returnValue((200, {}))
def on_OPTIONS(self, _):
return 200, {}
class ThreepidRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/3pid/email/requestToken$")
def __init__(self, hs):
self.hs = hs
super(ThreepidRequestTokenRestServlet, self).__init__()
self.identity_handler = hs.get_handlers().identity_handler
@defer.inlineCallbacks
def onPasswordEmailTokenRequest(self, request):
def on_POST(self, request):
body = parse_json_object_from_request(request)
required = ['id_server', 'client_secret', 'email', 'send_attempt']
@ -107,8 +149,10 @@ class PasswordRestServlet(RestServlet):
'email', body['email']
)
if existingUid is None:
raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
logger.error("existing %r", existingUid)
if existingUid is not None:
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
ret = yield self.identity_handler.requestEmailToken(**body)
defer.returnValue((200, ret))
@ -118,7 +162,7 @@ class PasswordRestServlet(RestServlet):
class ThreepidRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/3pid")
PATTERNS = client_v2_patterns("/account/3pid$")
def __init__(self, hs):
super(ThreepidRestServlet, self).__init__()
@ -143,10 +187,6 @@ class ThreepidRestServlet(RestServlet):
def on_POST(self, request):
yield run_on_reactor()
if '/account/3pid/email/requestToken' in request.path:
ret = yield self.onThreepidEmailTokenRequest(request)
defer.returnValue(ret)
body = parse_json_object_from_request(request)
threePidCreds = body.get('threePidCreds')
@ -187,30 +227,9 @@ class ThreepidRestServlet(RestServlet):
defer.returnValue((200, {}))
@defer.inlineCallbacks
def onThreepidEmailTokenRequest(self, request):
body = parse_json_object_from_request(request)
required = ['id_server', 'client_secret', 'email', 'send_attempt']
absent = []
for k in required:
if k not in body:
absent.append(k)
if len(absent) > 0:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
'email', body['email']
)
if existingUid is not None:
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
ret = yield self.identity_handler.requestEmailToken(**body)
defer.returnValue((200, ret))
def register_servlets(hs, http_server):
PasswordRequestTokenRestServlet(hs).register(http_server)
PasswordRestServlet(hs).register(http_server)
ThreepidRequestTokenRestServlet(hs).register(http_server)
ThreepidRestServlet(hs).register(http_server)

View File

@ -41,8 +41,43 @@ else:
logger = logging.getLogger(__name__)
class RegisterRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/register/email/requestToken$")
def __init__(self, hs):
super(RegisterRequestTokenRestServlet, self).__init__()
self.hs = hs
self.identity_handler = hs.get_handlers().identity_handler
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
required = ['id_server', 'client_secret', 'email', 'send_attempt']
absent = []
for k in required:
if k not in body:
absent.append(k)
if len(absent) > 0:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
'email', body['email']
)
if existingUid is not None:
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
ret = yield self.identity_handler.requestEmailToken(**body)
defer.returnValue((200, ret))
def on_OPTIONS(self, _):
return 200, {}
class RegisterRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/register")
PATTERNS = client_v2_patterns("/register$")
def __init__(self, hs):
super(RegisterRestServlet, self).__init__()
@ -70,10 +105,6 @@ class RegisterRestServlet(RestServlet):
"Do not understand membership kind: %s" % (kind,)
)
if '/register/email/requestToken' in request.path:
ret = yield self.onEmailTokenRequest(request)
defer.returnValue(ret)
body = parse_json_object_from_request(request)
# we do basic sanity checks here because the auth layer will store these
@ -305,29 +336,6 @@ class RegisterRestServlet(RestServlet):
"refresh_token": refresh_token,
})
@defer.inlineCallbacks
def onEmailTokenRequest(self, request):
body = parse_json_object_from_request(request)
required = ['id_server', 'client_secret', 'email', 'send_attempt']
absent = []
for k in required:
if k not in body:
absent.append(k)
if len(absent) > 0:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
'email', body['email']
)
if existingUid is not None:
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
ret = yield self.identity_handler.requestEmailToken(**body)
defer.returnValue((200, ret))
@defer.inlineCallbacks
def _do_guest_registration(self):
if not self.hs.config.allow_guest_access:
@ -345,4 +353,5 @@ class RegisterRestServlet(RestServlet):
def register_servlets(hs, http_server):
RegisterRequestTokenRestServlet(hs).register(http_server)
RegisterRestServlet(hs).register(http_server)