Refactor code for calculating registration flows (#6106)

because, frankly, it looked like it was written by an axe-murderer.

This should be a non-functional change, except that where `m.login.dummy` was
previously advertised *before* `m.login.terms`, it will now be advertised
afterwards. AFAICT that should have no effect, and will be more consistent with
the flows that involve passing a 3pid.
This commit is contained in:
Richard van der Hoff 2019-09-25 11:32:05 +01:00 committed by GitHub
parent f99a9c9cb0
commit 8004d6ca2f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 145 additions and 83 deletions

1
changelog.d/6106.misc Normal file
View File

@ -0,0 +1 @@
Refactor code for calculating registration flows.

View File

@ -16,6 +16,7 @@
import hmac import hmac
import logging import logging
from typing import List, Union
from six import string_types from six import string_types
@ -31,8 +32,11 @@ from synapse.api.errors import (
ThreepidValidationError, ThreepidValidationError,
UnrecognizedRequestError, UnrecognizedRequestError,
) )
from synapse.config.captcha import CaptchaConfig
from synapse.config.consent_config import ConsentConfig
from synapse.config.emailconfig import ThreepidBehaviour from synapse.config.emailconfig import ThreepidBehaviour
from synapse.config.ratelimiting import FederationRateLimitConfig from synapse.config.ratelimiting import FederationRateLimitConfig
from synapse.config.registration import RegistrationConfig
from synapse.config.server import is_threepid_reserved from synapse.config.server import is_threepid_reserved
from synapse.http.server import finish_request from synapse.http.server import finish_request
from synapse.http.servlet import ( from synapse.http.servlet import (
@ -371,6 +375,8 @@ class RegisterRestServlet(RestServlet):
self.ratelimiter = hs.get_registration_ratelimiter() self.ratelimiter = hs.get_registration_ratelimiter()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self._registration_flows = _calculate_registration_flows(hs.config)
@interactive_auth_handler @interactive_auth_handler
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
@ -491,69 +497,8 @@ class RegisterRestServlet(RestServlet):
assigned_user_id=registered_user_id, assigned_user_id=registered_user_id,
) )
# FIXME: need a better error than "no auth flow found" for scenarios
# where we required 3PID for registration but the user didn't give one
require_email = "email" in self.hs.config.registrations_require_3pid
require_msisdn = "msisdn" in self.hs.config.registrations_require_3pid
show_msisdn = True
if self.hs.config.disable_msisdn_registration:
show_msisdn = False
require_msisdn = False
flows = []
if self.hs.config.enable_registration_captcha:
# only support 3PIDless registration if no 3PIDs are required
if not require_email and not require_msisdn:
# Also add a dummy flow here, otherwise if a client completes
# recaptcha first we'll assume they were going for this flow
# and complete the request, when they could have been trying to
# complete one of the flows with email/msisdn auth.
flows.extend([[LoginType.RECAPTCHA, LoginType.DUMMY]])
# only support the email-only flow if we don't require MSISDN 3PIDs
if not require_msisdn:
flows.extend([[LoginType.RECAPTCHA, LoginType.EMAIL_IDENTITY]])
if show_msisdn:
# only support the MSISDN-only flow if we don't require email 3PIDs
if not require_email:
flows.extend([[LoginType.RECAPTCHA, LoginType.MSISDN]])
# always let users provide both MSISDN & email
flows.extend(
[[LoginType.RECAPTCHA, LoginType.MSISDN, LoginType.EMAIL_IDENTITY]]
)
else:
# only support 3PIDless registration if no 3PIDs are required
if not require_email and not require_msisdn:
flows.extend([[LoginType.DUMMY]])
# only support the email-only flow if we don't require MSISDN 3PIDs
if not require_msisdn:
flows.extend([[LoginType.EMAIL_IDENTITY]])
if show_msisdn:
# only support the MSISDN-only flow if we don't require email 3PIDs
if not require_email or require_msisdn:
flows.extend([[LoginType.MSISDN]])
# always let users provide both MSISDN & email
flows.extend([[LoginType.MSISDN, LoginType.EMAIL_IDENTITY]])
# Append m.login.terms to all flows if we're requiring consent
if self.hs.config.user_consent_at_registration:
new_flows = []
for flow in flows:
inserted = False
# m.login.terms should go near the end but before msisdn or email auth
for i, stage in enumerate(flow):
if stage == LoginType.EMAIL_IDENTITY or stage == LoginType.MSISDN:
flow.insert(i, LoginType.TERMS)
inserted = True
break
if not inserted:
flow.append(LoginType.TERMS)
flows.extend(new_flows)
auth_result, params, session_id = yield self.auth_handler.check_auth( auth_result, params, session_id = yield self.auth_handler.check_auth(
flows, body, self.hs.get_ip_from_request(request) self._registration_flows, body, self.hs.get_ip_from_request(request)
) )
# Check that we're not trying to register a denied 3pid. # Check that we're not trying to register a denied 3pid.
@ -716,6 +661,61 @@ class RegisterRestServlet(RestServlet):
) )
def _calculate_registration_flows(
# technically `config` has to provide *all* of these interfaces, not just one
config: Union[RegistrationConfig, ConsentConfig, CaptchaConfig],
) -> List[List[str]]:
"""Get a suitable flows list for registration
Args:
config: server configuration
Returns: a list of supported flows
"""
# FIXME: need a better error than "no auth flow found" for scenarios
# where we required 3PID for registration but the user didn't give one
require_email = "email" in config.registrations_require_3pid
require_msisdn = "msisdn" in config.registrations_require_3pid
show_msisdn = True
if config.disable_msisdn_registration:
show_msisdn = False
require_msisdn = False
flows = []
# only support 3PIDless registration if no 3PIDs are required
if not require_email and not require_msisdn:
# Add a dummy step here, otherwise if a client completes
# recaptcha first we'll assume they were going for this flow
# and complete the request, when they could have been trying to
# complete one of the flows with email/msisdn auth.
flows.append([LoginType.DUMMY])
# only support the email-only flow if we don't require MSISDN 3PIDs
if not require_msisdn:
flows.append([LoginType.EMAIL_IDENTITY])
# only support the MSISDN-only flow if we don't require email 3PIDs
if show_msisdn and not require_email:
flows.append([LoginType.MSISDN])
if show_msisdn:
flows.append([LoginType.MSISDN, LoginType.EMAIL_IDENTITY])
# Prepend m.login.terms to all flows if we're requiring consent
if config.user_consent_at_registration:
for flow in flows:
flow.insert(0, LoginType.TERMS)
# Prepend recaptcha to all flows if we're requiring captcha
if config.enable_registration_captcha:
for flow in flows:
flow.insert(0, LoginType.RECAPTCHA)
return flows
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
EmailRegisterRequestTokenRestServlet(hs).register(http_server) EmailRegisterRequestTokenRestServlet(hs).register(http_server)
MsisdnRegisterRequestTokenRestServlet(hs).register(http_server) MsisdnRegisterRequestTokenRestServlet(hs).register(http_server)

View File

@ -34,19 +34,12 @@ from tests import unittest
class RegisterRestServletTestCase(unittest.HomeserverTestCase): class RegisterRestServletTestCase(unittest.HomeserverTestCase):
servlets = [register.register_servlets] servlets = [register.register_servlets]
url = b"/_matrix/client/r0/register"
def make_homeserver(self, reactor, clock): def default_config(self, name="test"):
config = super().default_config(name)
self.url = b"/_matrix/client/r0/register" config["allow_guest_access"] = True
return config
self.hs = self.setup_test_homeserver()
self.hs.config.enable_registration = True
self.hs.config.registrations_require_3pid = []
self.hs.config.auto_join_rooms = []
self.hs.config.enable_registration_captcha = False
self.hs.config.allow_guest_access = True
return self.hs
def test_POST_appservice_registration_valid(self): def test_POST_appservice_registration_valid(self):
user_id = "@as_user_kermit:test" user_id = "@as_user_kermit:test"
@ -199,6 +192,68 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result) self.assertEquals(channel.result["code"], b"200", channel.result)
def test_advertised_flows(self):
request, channel = self.make_request(b"POST", self.url, b"{}")
self.render(request)
self.assertEquals(channel.result["code"], b"401", channel.result)
flows = channel.json_body["flows"]
# with the stock config, we expect all four combinations of 3pid
self.assertCountEqual(
[
["m.login.dummy"],
["m.login.email.identity"],
["m.login.msisdn"],
["m.login.msisdn", "m.login.email.identity"],
],
(f["stages"] for f in flows),
)
@unittest.override_config(
{
"enable_registration_captcha": True,
"user_consent": {
"version": "1",
"template_dir": "/",
"require_at_registration": True,
},
}
)
def test_advertised_flows_captcha_and_terms(self):
request, channel = self.make_request(b"POST", self.url, b"{}")
self.render(request)
self.assertEquals(channel.result["code"], b"401", channel.result)
flows = channel.json_body["flows"]
self.assertCountEqual(
[
["m.login.recaptcha", "m.login.terms", "m.login.dummy"],
["m.login.recaptcha", "m.login.terms", "m.login.email.identity"],
["m.login.recaptcha", "m.login.terms", "m.login.msisdn"],
[
"m.login.recaptcha",
"m.login.terms",
"m.login.msisdn",
"m.login.email.identity",
],
],
(f["stages"] for f in flows),
)
@unittest.override_config(
{"registrations_require_3pid": ["email"], "disable_msisdn_registration": True}
)
def test_advertised_flows_no_msisdn_email_required(self):
request, channel = self.make_request(b"POST", self.url, b"{}")
self.render(request)
self.assertEquals(channel.result["code"], b"401", channel.result)
flows = channel.json_body["flows"]
# with the stock config, we expect all four combinations of 3pid
self.assertCountEqual(
[["m.login.email.identity"]], (f["stages"] for f in flows)
)
class AccountValidityTestCase(unittest.HomeserverTestCase): class AccountValidityTestCase(unittest.HomeserverTestCase):

View File

@ -28,6 +28,21 @@ from tests import unittest
class TermsTestCase(unittest.HomeserverTestCase): class TermsTestCase(unittest.HomeserverTestCase):
servlets = [register_servlets] servlets = [register_servlets]
def default_config(self, name="test"):
config = super().default_config(name)
config.update(
{
"public_baseurl": "https://example.org/",
"user_consent": {
"version": "1.0",
"policy_name": "My Cool Privacy Policy",
"template_dir": "/",
"require_at_registration": True,
},
}
)
return config
def prepare(self, reactor, clock, hs): def prepare(self, reactor, clock, hs):
self.clock = MemoryReactorClock() self.clock = MemoryReactorClock()
self.hs_clock = Clock(self.clock) self.hs_clock = Clock(self.clock)
@ -35,17 +50,8 @@ class TermsTestCase(unittest.HomeserverTestCase):
self.registration_handler = Mock() self.registration_handler = Mock()
self.auth_handler = Mock() self.auth_handler = Mock()
self.device_handler = Mock() self.device_handler = Mock()
hs.config.enable_registration = True
hs.config.registrations_require_3pid = []
hs.config.auto_join_rooms = []
hs.config.enable_registration_captcha = False
def test_ui_auth(self): def test_ui_auth(self):
self.hs.config.user_consent_at_registration = True
self.hs.config.user_consent_policy_name = "My Cool Privacy Policy"
self.hs.config.public_baseurl = "https://example.org/"
self.hs.config.user_consent_version = "1.0"
# Do a UI auth request # Do a UI auth request
request, channel = self.make_request(b"POST", self.url, b"{}") request, channel = self.make_request(b"POST", self.url, b"{}")
self.render(request) self.render(request)