improve mxid & displayname selection for register_mxid_from_3pid

* [x] strip invalid characters from generated mxid
* [x] append numbers to disambiguate clashing mxids
* [x] generate displayanames from 3pids using a dodgy heuristic
* [x] get rid of the create_profile_with_localpart and instead
      explicitly set displaynames so they propagate correctly
This commit is contained in:
Matthew Hodgson 2018-05-03 04:20:25 +01:00
parent 79b2583f1b
commit 32e4420a66
5 changed files with 90 additions and 33 deletions

View File

@ -208,7 +208,7 @@ class ProfileHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def set_displayname(self, target_user, requester, new_displayname, by_admin=False): def set_displayname(self, target_user, requester, new_displayname, by_admin=False):
"""target_user is the user whose displayname is to be changed; """target_user is the user whose displayname is to be changed;
auth_user is the user attempting to make this change.""" requester is the user attempting to make this change."""
if not self.hs.is_mine(target_user): if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this Home Server") raise SynapseError(400, "User is not hosted on this Home Server")

View File

@ -113,6 +113,7 @@ class RegistrationHandler(BaseHandler):
generate_token=True, generate_token=True,
guest_access_token=None, guest_access_token=None,
make_guest=False, make_guest=False,
display_name=None,
admin=False, admin=False,
): ):
"""Registers a new client on the server. """Registers a new client on the server.
@ -128,6 +129,7 @@ class RegistrationHandler(BaseHandler):
since it offers no means of associating a device_id with the since it offers no means of associating a device_id with the
access_token. Instead you should call auth_handler.issue_access_token access_token. Instead you should call auth_handler.issue_access_token
after registration. after registration.
display_name (str): The displayname to set for this user, if any
Returns: Returns:
A tuple of (user_id, access_token). A tuple of (user_id, access_token).
Raises: Raises:
@ -165,13 +167,20 @@ class RegistrationHandler(BaseHandler):
password_hash=password_hash, password_hash=password_hash,
was_guest=was_guest, was_guest=was_guest,
make_guest=make_guest, make_guest=make_guest,
create_profile_with_localpart=(
# If the user was a guest then they already have a profile
None if was_guest else user.localpart
),
admin=admin, admin=admin,
) )
if display_name is not None:
display_name = (
# If the user was a guest then they already have a profile
None if was_guest else user.localpart
)
if display_name:
yield self.profile_handler.set_displayname(
user_id, user_id, display_name, by_admin=True,
)
if self.hs.config.user_directory_search_all_users: if self.hs.config.user_directory_search_all_users:
profile = yield self.store.get_profileinfo(localpart) profile = yield self.store.get_profileinfo(localpart)
yield self.user_directory_handler.handle_local_profile_change( yield self.user_directory_handler.handle_local_profile_change(
@ -196,8 +205,12 @@ class RegistrationHandler(BaseHandler):
token=token, token=token,
password_hash=password_hash, password_hash=password_hash,
make_guest=make_guest, make_guest=make_guest,
create_profile_with_localpart=user.localpart,
) )
yield self.profile_handler.set_displayname(
user_id, user_id, user.localpart, by_admin=True,
)
except SynapseError: except SynapseError:
# if user id is taken, just generate another # if user id is taken, just generate another
user = None user = None
@ -241,8 +254,12 @@ class RegistrationHandler(BaseHandler):
user_id=user_id, user_id=user_id,
password_hash="", password_hash="",
appservice_id=service_id, appservice_id=service_id,
create_profile_with_localpart=user.localpart,
) )
yield self.profile_handler.set_displayname(
user_id, user_id, user.localpart, by_admin=True,
)
defer.returnValue(user_id) defer.returnValue(user_id)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -288,7 +305,10 @@ class RegistrationHandler(BaseHandler):
user_id=user_id, user_id=user_id,
token=token, token=token,
password_hash=None, password_hash=None,
create_profile_with_localpart=user.localpart, )
yield self.profile_handler.set_displayname(
user_id, user_id, user.localpart, by_admin=True,
) )
except Exception as e: except Exception as e:
yield self.store.add_access_token_to_user(user_id, token) yield self.store.add_access_token_to_user(user_id, token)
@ -443,18 +463,15 @@ class RegistrationHandler(BaseHandler):
user_id=user_id, user_id=user_id,
token=token, token=token,
password_hash=password_hash, password_hash=password_hash,
create_profile_with_localpart=user.localpart,
) )
if displayname is not None:
yield self.profile_handler.set_displayname(
user_id, user_id, displayname, by_admin=True,
)
else: else:
yield self._auth_handler.delete_access_tokens_for_user(user_id) yield self._auth_handler.delete_access_tokens_for_user(user_id)
yield self.store.add_access_token_to_user(user_id=user_id, token=token) yield self.store.add_access_token_to_user(user_id=user_id, token=token)
if displayname is not None:
logger.info("setting user display name: %s -> %s", user_id, displayname)
yield self.profile_handler.set_displayname(
user, requester, displayname, by_admin=True,
)
defer.returnValue((user_id, token)) defer.returnValue((user_id, token))
def auth_handler(self): def auth_handler(self):

View File

@ -17,7 +17,7 @@
from twisted.internet import defer from twisted.internet import defer
import synapse import synapse
import synapse.types from synapse import types
from synapse.api.auth import get_access_token_from_request, has_access_token from synapse.api.auth import get_access_token_from_request, has_access_token
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
@ -31,6 +31,8 @@ from ._base import client_v2_patterns, interactive_auth_handler
import logging import logging
import hmac import hmac
import re
from string import capwords
from hashlib import sha1 from hashlib import sha1
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.ratelimitutils import FederationRateLimiter
@ -222,6 +224,8 @@ class RegisterRestServlet(RestServlet):
raise SynapseError(400, "Invalid username") raise SynapseError(400, "Invalid username")
desired_username = body['username'] desired_username = body['username']
desired_display_name = None
appservice = None appservice = None
if has_access_token(request): if has_access_token(request):
appservice = yield self.auth.get_appservice_by_req(request) appservice = yield self.auth.get_appservice_by_req(request)
@ -374,9 +378,10 @@ class RegisterRestServlet(RestServlet):
# reset it first to avoid folks picking their own username. # reset it first to avoid folks picking their own username.
desired_username = None desired_username = None
# we should always have an auth_result if we're going to progress # we should have an auth_result at this point if we're going to progress
# to register the user (i.e. we haven't picked up a registered_user_id) # to register the user (i.e. we haven't picked up a registered_user_id
# from our session store # from our session store), in which case get ready and gen the
# desired_username
if auth_result: if auth_result:
if ( if (
( (
@ -388,7 +393,41 @@ class RegisterRestServlet(RestServlet):
) )
): ):
address = auth_result[login_type]['address'] address = auth_result[login_type]['address']
desired_username = address.replace('@', '-').lower() desired_username = types.strip_invalid_mxid_characters(
address.replace('@', '-').lower()
)
# find a unique mxid for the account, suffixing numbers
# if needed
while True:
try:
yield self.registration_handler.check_username(
desired_username,
guest_access_token=guest_access_token,
assigned_user_id=registered_user_id,
)
# if we got this far we passed the check.
break
except SynapseError as e:
if e.errcode == Codes.USER_IN_USE:
m = re.match(r'^(.*)(\d+)$', desired_username)
if m:
desired_username = m.group(1) + str(
int(m.group(2)) + 1
)
else:
desired_username += "1"
else:
# something else went wrong.
break
# XXX: a nasty heuristic to turn an email address into
# a displayname, as part of register_mxid_from_3pid
parts = address.replace('.', ' ').split('@')
desired_display_name = (
capwords(parts[0]) +
" [" + capwords(parts[1].split(' ')[0]) + "]"
)
if desired_username is not None: if desired_username is not None:
yield self.registration_handler.check_username( yield self.registration_handler.check_username(
@ -431,6 +470,7 @@ class RegisterRestServlet(RestServlet):
password=new_password, password=new_password,
guest_access_token=guest_access_token, guest_access_token=guest_access_token,
generate_token=False, generate_token=False,
display_name=desired_display_name,
) )
# remember that we've now registered that user account, and with # remember that we've now registered that user account, and with

View File

@ -128,7 +128,7 @@ class RegistrationStore(RegistrationWorkerStore,
def register(self, user_id, token=None, password_hash=None, def register(self, user_id, token=None, password_hash=None,
was_guest=False, make_guest=False, appservice_id=None, was_guest=False, make_guest=False, appservice_id=None,
create_profile_with_localpart=None, admin=False): admin=False):
"""Attempts to register an account. """Attempts to register an account.
Args: Args:
@ -142,8 +142,6 @@ class RegistrationStore(RegistrationWorkerStore,
make_guest (boolean): True if the the new user should be guest, make_guest (boolean): True if the the new user should be guest,
false to add a regular user account. false to add a regular user account.
appservice_id (str): The ID of the appservice registering the user. appservice_id (str): The ID of the appservice registering the user.
create_profile_with_localpart (str): Optionally create a profile for
the given localpart.
Raises: Raises:
StoreError if the user_id could not be registered. StoreError if the user_id could not be registered.
""" """
@ -156,7 +154,6 @@ class RegistrationStore(RegistrationWorkerStore,
was_guest, was_guest,
make_guest, make_guest,
appservice_id, appservice_id,
create_profile_with_localpart,
admin admin
) )
@ -169,7 +166,6 @@ class RegistrationStore(RegistrationWorkerStore,
was_guest, was_guest,
make_guest, make_guest,
appservice_id, appservice_id,
create_profile_with_localpart,
admin, admin,
): ):
now = int(self.clock.time()) now = int(self.clock.time())
@ -234,14 +230,6 @@ class RegistrationStore(RegistrationWorkerStore,
(next_id, user_id, token,) (next_id, user_id, token,)
) )
if create_profile_with_localpart:
# set a default displayname serverside to avoid ugly race
# between auto-joins and clients trying to set displaynames
txn.execute(
"INSERT INTO profiles(user_id, displayname) VALUES (?,?)",
(create_profile_with_localpart, create_profile_with_localpart)
)
self._invalidate_cache_and_stream( self._invalidate_cache_and_stream(
txn, self.get_user_by_id, (user_id,) txn, self.get_user_by_id, (user_id,)
) )

View File

@ -229,6 +229,18 @@ def contains_invalid_mxid_characters(localpart):
return any(c not in mxid_localpart_allowed_characters for c in localpart) return any(c not in mxid_localpart_allowed_characters for c in localpart)
def strip_invalid_mxid_characters(localpart):
"""Removes any invalid characters from an mxid
Args:
localpart (basestring): the localpart to be stripped
Returns:
localpart (basestring): the localpart having been stripped
"""
return filter(lambda c: c not in mxid_localpart_allowed_characters, localpart)
class StreamToken( class StreamToken(
namedtuple("Token", ( namedtuple("Token", (
"room_key", "room_key",