diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 76d2d2d64..90ea19bd4 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -23,11 +23,11 @@ from distutils.util import strtobool class RegistrationConfig(Config): def read_config(self, config): - self.disable_registration = not bool( + self.enable_registration = bool( strtobool(str(config["enable_registration"])) ) if "disable_registration" in config: - self.disable_registration = bool( + self.enable_registration = not bool( strtobool(str(config["disable_registration"])) ) @@ -78,6 +78,6 @@ class RegistrationConfig(Config): def read_arguments(self, args): if args.enable_registration is not None: - self.disable_registration = not bool( + self.enable_registration = bool( strtobool(str(args.enable_registration)) ) diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index 2bfd4d96b..6d6d03c34 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -59,7 +59,7 @@ class RegisterRestServlet(ClientV1RestServlet): # } # TODO: persistent storage self.sessions = {} - self.disable_registration = hs.config.disable_registration + self.enable_registration = hs.config.enable_registration def on_GET(self, request): if self.hs.config.enable_registration_captcha: @@ -113,7 +113,7 @@ class RegisterRestServlet(ClientV1RestServlet): is_using_shared_secret = login_type == LoginType.SHARED_SECRET can_register = ( - not self.disable_registration + self.enable_registration or is_application_server or is_using_shared_secret ) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 56a5bbec3..ec5c21fa1 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -117,7 +117,7 @@ class RegisterRestServlet(RestServlet): return # == Normal User Registration == (everyone else) - if self.hs.config.disable_registration: + if not self.hs.config.enable_registration: raise SynapseError(403, "Registration has been disabled") guest_access_token = body.get("guest_access_token", None) diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py index b260e269a..e9698bfdc 100644 --- a/tests/rest/client/v1/test_events.py +++ b/tests/rest/client/v1/test_events.py @@ -122,7 +122,7 @@ class EventStreamPermissionsTestCase(RestTestCase): self.ratelimiter = hs.get_ratelimiter() self.ratelimiter.send_message.return_value = (True, 0) hs.config.enable_registration_captcha = False - hs.config.disable_registration = False + hs.config.enable_registration = True hs.get_handlers().federation_handler = Mock() diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index f9a2b2248..df0841b0b 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -41,7 +41,7 @@ class RegisterRestServletTestCase(unittest.TestCase): self.hs.hostname = "superbig~testing~thing.com" self.hs.get_auth = Mock(return_value=self.auth) self.hs.get_handlers = Mock(return_value=self.handlers) - self.hs.config.disable_registration = False + self.hs.config.enable_registration = True # init the thing we're testing self.servlet = RegisterRestServlet(self.hs) @@ -120,7 +120,7 @@ class RegisterRestServletTestCase(unittest.TestCase): })) def test_POST_disabled_registration(self): - self.hs.config.disable_registration = True + self.hs.config.enable_registration = False self.request_data = json.dumps({ "username": "kermit", "password": "monkey" diff --git a/tests/utils.py b/tests/utils.py index 431252a6f..3b1eb50d8 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -46,7 +46,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): config = Mock() config.signing_key = [MockKey()] config.event_cache_size = 1 - config.disable_registration = False + config.enable_registration = True config.macaroon_secret_key = "not even a little secret" config.server_name = "server.under.test" config.trusted_third_party_id_servers = []