Initialise user displayname from SAML2 data (#4272)

When we register a new user from SAML2 data, initialise their displayname
correctly.
This commit is contained in:
Richard van der Hoff 2018-12-07 14:44:46 +01:00 committed by GitHub
parent 35e13477cf
commit 30da50a5b8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 39 additions and 15 deletions

1
changelog.d/4272.feature Normal file
View File

@ -0,0 +1 @@
SAML2 authentication: Initialise user display name from SAML2 data

View File

@ -126,6 +126,7 @@ class RegistrationHandler(BaseHandler):
make_guest=False, make_guest=False,
admin=False, admin=False,
threepid=None, threepid=None,
default_display_name=None,
): ):
"""Registers a new client on the server. """Registers a new client on the server.
@ -140,6 +141,8 @@ class RegistrationHandler(BaseHandler):
since it offers no means of associating a device_id with the since it offers no means of associating a device_id with the
access_token. Instead you should call auth_handler.issue_access_token access_token. Instead you should call auth_handler.issue_access_token
after registration. after registration.
default_display_name (unicode|None): if set, the new user's displayname
will be set to this. Defaults to 'localpart'.
Returns: Returns:
A tuple of (user_id, access_token). A tuple of (user_id, access_token).
Raises: Raises:
@ -169,6 +172,13 @@ class RegistrationHandler(BaseHandler):
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
if was_guest:
# If the user was a guest then they already have a profile
default_display_name = None
elif default_display_name is None:
default_display_name = localpart
token = None token = None
if generate_token: if generate_token:
token = self.macaroon_gen.generate_access_token(user_id) token = self.macaroon_gen.generate_access_token(user_id)
@ -178,10 +188,7 @@ class RegistrationHandler(BaseHandler):
password_hash=password_hash, password_hash=password_hash,
was_guest=was_guest, was_guest=was_guest,
make_guest=make_guest, make_guest=make_guest,
create_profile_with_localpart=( create_profile_with_displayname=default_display_name,
# If the user was a guest then they already have a profile
None if was_guest else user.localpart
),
admin=admin, admin=admin,
) )
@ -203,13 +210,15 @@ class RegistrationHandler(BaseHandler):
yield self.check_user_id_not_appservice_exclusive(user_id) yield self.check_user_id_not_appservice_exclusive(user_id)
if generate_token: if generate_token:
token = self.macaroon_gen.generate_access_token(user_id) token = self.macaroon_gen.generate_access_token(user_id)
if default_display_name is None:
default_display_name = localpart
try: try:
yield self.store.register( yield self.store.register(
user_id=user_id, user_id=user_id,
token=token, token=token,
password_hash=password_hash, password_hash=password_hash,
make_guest=make_guest, make_guest=make_guest,
create_profile_with_localpart=user.localpart, create_profile_with_displayname=default_display_name,
) )
except SynapseError: except SynapseError:
# if user id is taken, just generate another # if user id is taken, just generate another
@ -300,7 +309,7 @@ class RegistrationHandler(BaseHandler):
user_id=user_id, user_id=user_id,
password_hash="", password_hash="",
appservice_id=service_id, appservice_id=service_id,
create_profile_with_localpart=user.localpart, create_profile_with_displayname=user.localpart,
) )
defer.returnValue(user_id) defer.returnValue(user_id)
@ -478,7 +487,7 @@ class RegistrationHandler(BaseHandler):
user_id=user_id, user_id=user_id,
token=token, token=token,
password_hash=password_hash, password_hash=password_hash,
create_profile_with_localpart=user.localpart, create_profile_with_displayname=user.localpart,
) )
else: else:
yield self._auth_handler.delete_access_tokens_for_user(user_id) yield self._auth_handler.delete_access_tokens_for_user(user_id)

View File

@ -451,6 +451,7 @@ class SSOAuthHandler(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_successful_auth( def on_successful_auth(
self, username, request, client_redirect_url, self, username, request, client_redirect_url,
user_display_name=None,
): ):
"""Called once the user has successfully authenticated with the SSO. """Called once the user has successfully authenticated with the SSO.
@ -467,6 +468,9 @@ class SSOAuthHandler(object):
client_redirect_url (unicode): the redirect_url the client gave us when client_redirect_url (unicode): the redirect_url the client gave us when
it first started the process. it first started the process.
user_display_name (unicode|None): if set, and we have to register a new user,
we will set their displayname to this.
Returns: Returns:
Deferred[none]: Completes once we have handled the request. Deferred[none]: Completes once we have handled the request.
""" """
@ -478,6 +482,7 @@ class SSOAuthHandler(object):
yield self._registration_handler.register( yield self._registration_handler.register(
localpart=localpart, localpart=localpart,
generate_token=False, generate_token=False,
default_display_name=user_display_name,
) )
) )

View File

@ -66,6 +66,9 @@ class SAML2ResponseResource(Resource):
raise CodeMessageException(400, "uid not in SAML2 response") raise CodeMessageException(400, "uid not in SAML2 response")
username = saml2_auth.ava["uid"][0] username = saml2_auth.ava["uid"][0]
displayName = saml2_auth.ava.get("displayName", [None])[0]
return self._sso_auth_handler.on_successful_auth( return self._sso_auth_handler.on_successful_auth(
username, request, relay_state, username, request, relay_state,
user_display_name=displayName,
) )

View File

@ -22,6 +22,7 @@ from twisted.internet import defer
from synapse.api.errors import Codes, StoreError from synapse.api.errors import Codes, StoreError
from synapse.storage import background_updates from synapse.storage import background_updates
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.types import UserID
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
@ -167,7 +168,7 @@ class RegistrationStore(RegistrationWorkerStore,
def register(self, user_id, token=None, password_hash=None, def register(self, user_id, token=None, password_hash=None,
was_guest=False, make_guest=False, appservice_id=None, was_guest=False, make_guest=False, appservice_id=None,
create_profile_with_localpart=None, admin=False): create_profile_with_displayname=None, admin=False):
"""Attempts to register an account. """Attempts to register an account.
Args: Args:
@ -181,8 +182,8 @@ class RegistrationStore(RegistrationWorkerStore,
make_guest (boolean): True if the the new user should be guest, make_guest (boolean): True if the the new user should be guest,
false to add a regular user account. false to add a regular user account.
appservice_id (str): The ID of the appservice registering the user. appservice_id (str): The ID of the appservice registering the user.
create_profile_with_localpart (str): Optionally create a profile for create_profile_with_displayname (unicode): Optionally create a profile for
the given localpart. the user, setting their displayname to the given value
Raises: Raises:
StoreError if the user_id could not be registered. StoreError if the user_id could not be registered.
""" """
@ -195,7 +196,7 @@ class RegistrationStore(RegistrationWorkerStore,
was_guest, was_guest,
make_guest, make_guest,
appservice_id, appservice_id,
create_profile_with_localpart, create_profile_with_displayname,
admin admin
) )
@ -208,9 +209,11 @@ class RegistrationStore(RegistrationWorkerStore,
was_guest, was_guest,
make_guest, make_guest,
appservice_id, appservice_id,
create_profile_with_localpart, create_profile_with_displayname,
admin, admin,
): ):
user_id_obj = UserID.from_string(user_id)
now = int(self.clock.time()) now = int(self.clock.time())
next_id = self._access_tokens_id_gen.get_next() next_id = self._access_tokens_id_gen.get_next()
@ -273,12 +276,15 @@ class RegistrationStore(RegistrationWorkerStore,
(next_id, user_id, token,) (next_id, user_id, token,)
) )
if create_profile_with_localpart: if create_profile_with_displayname:
# set a default displayname serverside to avoid ugly race # set a default displayname serverside to avoid ugly race
# between auto-joins and clients trying to set displaynames # between auto-joins and clients trying to set displaynames
#
# *obviously* the 'profiles' table uses localpart for user_id
# while everything else uses the full mxid.
txn.execute( txn.execute(
"INSERT INTO profiles(user_id, displayname) VALUES (?,?)", "INSERT INTO profiles(user_id, displayname) VALUES (?,?)",
(create_profile_with_localpart, create_profile_with_localpart) (user_id_obj.localpart, create_profile_with_displayname)
) )
self._invalidate_cache_and_stream( self._invalidate_cache_and_stream(

View File

@ -149,7 +149,7 @@ class MonthlyActiveUsersTestCase(HomeserverTestCase):
def test_populate_monthly_users_is_guest(self): def test_populate_monthly_users_is_guest(self):
# Test that guest users are not added to mau list # Test that guest users are not added to mau list
user_id = "user_id" user_id = "@user_id:host"
self.store.register( self.store.register(
user_id=user_id, token="123", password_hash=None, make_guest=True user_id=user_id, token="123", password_hash=None, make_guest=True
) )