mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-05-04 09:14:55 -04:00
Support trying multiple localparts for OpenID Connect. (#8801)
Abstracts the SAML and OpenID Connect code which attempts to regenerate the localpart of a matrix ID if it is already in use.
This commit is contained in:
parent
f38676d161
commit
4fd222ad70
6 changed files with 331 additions and 137 deletions
|
@ -12,6 +12,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import inspect
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar
|
||||
from urllib.parse import urlencode
|
||||
|
@ -35,15 +36,10 @@ from twisted.web.client import readBody
|
|||
|
||||
from synapse.config import ConfigError
|
||||
from synapse.handlers._base import BaseHandler
|
||||
from synapse.handlers.sso import MappingException
|
||||
from synapse.handlers.sso import MappingException, UserAttributes
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
from synapse.types import (
|
||||
JsonDict,
|
||||
UserID,
|
||||
contains_invalid_mxid_characters,
|
||||
map_username_to_mxid_localpart,
|
||||
)
|
||||
from synapse.types import JsonDict, map_username_to_mxid_localpart
|
||||
from synapse.util import json_decoder
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -869,73 +865,51 @@ class OidcHandler(BaseHandler):
|
|||
# to be strings.
|
||||
remote_user_id = str(remote_user_id)
|
||||
|
||||
# first of all, check if we already have a mapping for this user
|
||||
previously_registered_user_id = await self._sso_handler.get_sso_user_by_remote_user_id(
|
||||
self._auth_provider_id, remote_user_id,
|
||||
# Older mapping providers don't accept the `failures` argument, so we
|
||||
# try and detect support.
|
||||
mapper_signature = inspect.signature(
|
||||
self._user_mapping_provider.map_user_attributes
|
||||
)
|
||||
if previously_registered_user_id:
|
||||
return previously_registered_user_id
|
||||
supports_failures = "failures" in mapper_signature.parameters
|
||||
|
||||
# Otherwise, generate a new user.
|
||||
try:
|
||||
attributes = await self._user_mapping_provider.map_user_attributes(
|
||||
userinfo, token
|
||||
)
|
||||
except Exception as e:
|
||||
raise MappingException(
|
||||
"Could not extract user attributes from OIDC response: " + str(e)
|
||||
)
|
||||
async def oidc_response_to_user_attributes(failures: int) -> UserAttributes:
|
||||
"""
|
||||
Call the mapping provider to map the OIDC userinfo and token to user attributes.
|
||||
|
||||
logger.debug(
|
||||
"Retrieved user attributes from user mapping provider: %r", attributes
|
||||
)
|
||||
|
||||
localpart = attributes["localpart"]
|
||||
if not localpart:
|
||||
raise MappingException(
|
||||
"Error parsing OIDC response: OIDC mapping provider plugin "
|
||||
"did not return a localpart value"
|
||||
)
|
||||
|
||||
user_id = UserID(localpart, self.server_name).to_string()
|
||||
users = await self.store.get_users_by_id_case_insensitive(user_id)
|
||||
if users:
|
||||
if self._allow_existing_users:
|
||||
if len(users) == 1:
|
||||
registered_user_id = next(iter(users))
|
||||
elif user_id in users:
|
||||
registered_user_id = user_id
|
||||
else:
|
||||
raise MappingException(
|
||||
"Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
|
||||
user_id, list(users.keys())
|
||||
)
|
||||
)
|
||||
This is backwards compatibility for abstraction for the SSO handler.
|
||||
"""
|
||||
if supports_failures:
|
||||
attributes = await self._user_mapping_provider.map_user_attributes(
|
||||
userinfo, token, failures
|
||||
)
|
||||
else:
|
||||
# This mxid is taken
|
||||
raise MappingException("mxid '{}' is already taken".format(user_id))
|
||||
else:
|
||||
# Since the localpart is provided via a potentially untrusted module,
|
||||
# ensure the MXID is valid before registering.
|
||||
if contains_invalid_mxid_characters(localpart):
|
||||
raise MappingException("localpart is invalid: %s" % (localpart,))
|
||||
# If the mapping provider does not support processing failures,
|
||||
# do not continually generate the same Matrix ID since it will
|
||||
# continue to already be in use. Note that the error raised is
|
||||
# arbitrary and will get turned into a MappingException.
|
||||
if failures:
|
||||
raise RuntimeError(
|
||||
"Mapping provider does not support de-duplicating Matrix IDs"
|
||||
)
|
||||
|
||||
# It's the first time this user is logging in and the mapped mxid was
|
||||
# not taken, register the user
|
||||
registered_user_id = await self._registration_handler.register_user(
|
||||
localpart=localpart,
|
||||
default_display_name=attributes["display_name"],
|
||||
user_agent_ips=[(user_agent, ip_address)],
|
||||
)
|
||||
attributes = await self._user_mapping_provider.map_user_attributes( # type: ignore
|
||||
userinfo, token
|
||||
)
|
||||
|
||||
await self.store.record_user_external_id(
|
||||
self._auth_provider_id, remote_user_id, registered_user_id,
|
||||
return UserAttributes(**attributes)
|
||||
|
||||
return await self._sso_handler.get_mxid_from_sso(
|
||||
self._auth_provider_id,
|
||||
remote_user_id,
|
||||
user_agent,
|
||||
ip_address,
|
||||
oidc_response_to_user_attributes,
|
||||
self._allow_existing_users,
|
||||
)
|
||||
return registered_user_id
|
||||
|
||||
|
||||
UserAttribute = TypedDict(
|
||||
"UserAttribute", {"localpart": str, "display_name": Optional[str]}
|
||||
UserAttributeDict = TypedDict(
|
||||
"UserAttributeDict", {"localpart": str, "display_name": Optional[str]}
|
||||
)
|
||||
C = TypeVar("C")
|
||||
|
||||
|
@ -978,13 +952,15 @@ class OidcMappingProvider(Generic[C]):
|
|||
raise NotImplementedError()
|
||||
|
||||
async def map_user_attributes(
|
||||
self, userinfo: UserInfo, token: Token
|
||||
) -> UserAttribute:
|
||||
self, userinfo: UserInfo, token: Token, failures: int
|
||||
) -> UserAttributeDict:
|
||||
"""Map a `UserInfo` object into user attributes.
|
||||
|
||||
Args:
|
||||
userinfo: An object representing the user given by the OIDC provider
|
||||
token: A dict with the tokens returned by the provider
|
||||
failures: How many times a call to this function with this
|
||||
UserInfo has resulted in a failure.
|
||||
|
||||
Returns:
|
||||
A dict containing the ``localpart`` and (optionally) the ``display_name``
|
||||
|
@ -1084,13 +1060,17 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
|
|||
return userinfo[self._config.subject_claim]
|
||||
|
||||
async def map_user_attributes(
|
||||
self, userinfo: UserInfo, token: Token
|
||||
) -> UserAttribute:
|
||||
self, userinfo: UserInfo, token: Token, failures: int
|
||||
) -> UserAttributeDict:
|
||||
localpart = self._config.localpart_template.render(user=userinfo).strip()
|
||||
|
||||
# Ensure only valid characters are included in the MXID.
|
||||
localpart = map_username_to_mxid_localpart(localpart)
|
||||
|
||||
# Append suffix integer if last call to this function failed to produce
|
||||
# a usable mxid.
|
||||
localpart += str(failures) if failures else ""
|
||||
|
||||
display_name = None # type: Optional[str]
|
||||
if self._config.display_name_template is not None:
|
||||
display_name = self._config.display_name_template.render(
|
||||
|
@ -1100,7 +1080,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
|
|||
if display_name == "":
|
||||
display_name = None
|
||||
|
||||
return UserAttribute(localpart=localpart, display_name=display_name)
|
||||
return UserAttributeDict(localpart=localpart, display_name=display_name)
|
||||
|
||||
async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
|
||||
extras = {} # type: Dict[str, str]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue