Preparatory refactors of OidcHandler (#9067)

Some light refactoring of OidcHandler, in preparation for bigger things:

  * remove inheritance from deprecated BaseHandler
  * add an object to hold the things that go into a session cookie
  * factor out a separate class for manipulating said cookies
This commit is contained in:
Richard van der Hoff 2021-01-13 10:26:12 +00:00 committed by GitHub
parent 7a2e9b549d
commit bc4bf7b384
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 201 additions and 165 deletions

1
changelog.d/9067.feature Normal file
View File

@ -0,0 +1 @@
Add support for multiple SSO Identity Providers.

View File

@ -14,7 +14,7 @@
# limitations under the License.
import inspect
import logging
from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar
from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar
from urllib.parse import urlencode
import attr
@ -35,7 +35,6 @@ from typing_extensions import TypedDict
from twisted.web.client import readBody
from synapse.config import ConfigError
from synapse.handlers._base import BaseHandler
from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
@ -85,12 +84,15 @@ class OidcError(Exception):
return self.error
class OidcHandler(BaseHandler):
class OidcHandler:
"""Handles requests related to the OpenID Connect login flow.
"""
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self._store = hs.get_datastore()
self._token_generator = OidcSessionTokenGenerator(hs)
self._callback_url = hs.config.oidc_callback_url # type: str
self._scopes = hs.config.oidc_scopes # type: List[str]
self._user_profile_method = hs.config.oidc_user_profile_method # type: str
@ -116,7 +118,6 @@ class OidcHandler(BaseHandler):
self._http_client = hs.get_proxied_http_client()
self._server_name = hs.config.server_name # type: str
self._macaroon_secret_key = hs.config.macaroon_secret_key
# identifier for the external_ids table
self.idp_id = "oidc"
@ -519,11 +520,13 @@ class OidcHandler(BaseHandler):
if not client_redirect_url:
client_redirect_url = b""
cookie = self._generate_oidc_session_token(
cookie = self._token_generator.generate_oidc_session_token(
state=state,
nonce=nonce,
client_redirect_url=client_redirect_url.decode(),
ui_auth_session_id=ui_auth_session_id,
session_data=OidcSessionData(
nonce=nonce,
client_redirect_url=client_redirect_url.decode(),
ui_auth_session_id=ui_auth_session_id,
),
)
request.addCookie(
SESSION_COOKIE_NAME,
@ -628,11 +631,9 @@ class OidcHandler(BaseHandler):
# Deserialize the session token and verify it.
try:
(
nonce,
client_redirect_url,
ui_auth_session_id,
) = self._verify_oidc_session_token(session, state)
session_data = self._token_generator.verify_oidc_session_token(
session, state
)
except MacaroonDeserializationException as e:
logger.exception("Invalid session")
self._sso_handler.render_error(request, "invalid_session", str(e))
@ -674,14 +675,14 @@ class OidcHandler(BaseHandler):
else:
logger.debug("Extracting userinfo from id_token")
try:
userinfo = await self._parse_id_token(token, nonce=nonce)
userinfo = await self._parse_id_token(token, nonce=session_data.nonce)
except Exception as e:
logger.exception("Invalid id_token")
self._sso_handler.render_error(request, "invalid_token", str(e))
return
# first check if we're doing a UIA
if ui_auth_session_id:
if session_data.ui_auth_session_id:
try:
remote_user_id = self._remote_id_from_userinfo(userinfo)
except Exception as e:
@ -690,7 +691,7 @@ class OidcHandler(BaseHandler):
return
return await self._sso_handler.complete_sso_ui_auth_request(
self.idp_id, remote_user_id, ui_auth_session_id, request
self.idp_id, remote_user_id, session_data.ui_auth_session_id, request
)
# otherwise, it's a login
@ -698,133 +699,12 @@ class OidcHandler(BaseHandler):
# Call the mapper to register/login the user
try:
await self._complete_oidc_login(
userinfo, token, request, client_redirect_url
userinfo, token, request, session_data.client_redirect_url
)
except MappingException as e:
logger.exception("Could not map user")
self._sso_handler.render_error(request, "mapping_error", str(e))
def _generate_oidc_session_token(
self,
state: str,
nonce: str,
client_redirect_url: str,
ui_auth_session_id: Optional[str],
duration_in_ms: int = (60 * 60 * 1000),
) -> str:
"""Generates a signed token storing data about an OIDC session.
When Synapse initiates an authorization flow, it creates a random state
and a random nonce. Those parameters are given to the provider and
should be verified when the client comes back from the provider.
It is also used to store the client_redirect_url, which is used to
complete the SSO login flow.
Args:
state: The ``state`` parameter passed to the OIDC provider.
nonce: The ``nonce`` parameter passed to the OIDC provider.
client_redirect_url: The URL the client gave when it initiated the
flow.
ui_auth_session_id: The session ID of the ongoing UI Auth (or
None if this is a login).
duration_in_ms: An optional duration for the token in milliseconds.
Defaults to an hour.
Returns:
A signed macaroon token with the session information.
"""
macaroon = pymacaroons.Macaroon(
location=self._server_name, identifier="key", key=self._macaroon_secret_key,
)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = session")
macaroon.add_first_party_caveat("state = %s" % (state,))
macaroon.add_first_party_caveat("nonce = %s" % (nonce,))
macaroon.add_first_party_caveat(
"client_redirect_url = %s" % (client_redirect_url,)
)
if ui_auth_session_id:
macaroon.add_first_party_caveat(
"ui_auth_session_id = %s" % (ui_auth_session_id,)
)
now = self.clock.time_msec()
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
return macaroon.serialize()
def _verify_oidc_session_token(
self, session: bytes, state: str
) -> Tuple[str, str, Optional[str]]:
"""Verifies and extract an OIDC session token.
This verifies that a given session token was issued by this homeserver
and extract the nonce and client_redirect_url caveats.
Args:
session: The session token to verify
state: The state the OIDC provider gave back
Returns:
The nonce, client_redirect_url, and ui_auth_session_id for this session
"""
macaroon = pymacaroons.Macaroon.deserialize(session)
v = pymacaroons.Verifier()
v.satisfy_exact("gen = 1")
v.satisfy_exact("type = session")
v.satisfy_exact("state = %s" % (state,))
v.satisfy_general(lambda c: c.startswith("nonce = "))
v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
# Sometimes there's a UI auth session ID, it seems to be OK to attempt
# to always satisfy this.
v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
v.satisfy_general(self._verify_expiry)
v.verify(macaroon, self._macaroon_secret_key)
# Extract the `nonce`, `client_redirect_url`, and maybe the
# `ui_auth_session_id` from the token.
nonce = self._get_value_from_macaroon(macaroon, "nonce")
client_redirect_url = self._get_value_from_macaroon(
macaroon, "client_redirect_url"
)
try:
ui_auth_session_id = self._get_value_from_macaroon(
macaroon, "ui_auth_session_id"
) # type: Optional[str]
except ValueError:
ui_auth_session_id = None
return nonce, client_redirect_url, ui_auth_session_id
def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
"""Extracts a caveat value from a macaroon token.
Args:
macaroon: the token
key: the key of the caveat to extract
Returns:
The extracted value
Raises:
Exception: if the caveat was not in the macaroon
"""
prefix = key + " = "
for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(prefix):
return caveat.caveat_id[len(prefix) :]
raise ValueError("No %s caveat in macaroon" % (key,))
def _verify_expiry(self, caveat: str) -> bool:
prefix = "time < "
if not caveat.startswith(prefix):
return False
expiry = int(caveat[len(prefix) :])
now = self.clock.time_msec()
return now < expiry
async def _complete_oidc_login(
self,
userinfo: UserInfo,
@ -901,8 +781,8 @@ class OidcHandler(BaseHandler):
# and attempt to match it.
attributes = await oidc_response_to_user_attributes(failures=0)
user_id = UserID(attributes.localpart, self.server_name).to_string()
users = await self.store.get_users_by_id_case_insensitive(user_id)
user_id = UserID(attributes.localpart, self._server_name).to_string()
users = await self._store.get_users_by_id_case_insensitive(user_id)
if users:
# If an existing matrix ID is returned, then use it.
if len(users) == 1:
@ -954,6 +834,148 @@ class OidcHandler(BaseHandler):
return str(remote_user_id)
class OidcSessionTokenGenerator:
"""Methods for generating and checking OIDC Session cookies."""
def __init__(self, hs: "HomeServer"):
self._clock = hs.get_clock()
self._server_name = hs.hostname
self._macaroon_secret_key = hs.config.key.macaroon_secret_key
def generate_oidc_session_token(
self,
state: str,
session_data: "OidcSessionData",
duration_in_ms: int = (60 * 60 * 1000),
) -> str:
"""Generates a signed token storing data about an OIDC session.
When Synapse initiates an authorization flow, it creates a random state
and a random nonce. Those parameters are given to the provider and
should be verified when the client comes back from the provider.
It is also used to store the client_redirect_url, which is used to
complete the SSO login flow.
Args:
state: The ``state`` parameter passed to the OIDC provider.
session_data: data to include in the session token.
duration_in_ms: An optional duration for the token in milliseconds.
Defaults to an hour.
Returns:
A signed macaroon token with the session information.
"""
macaroon = pymacaroons.Macaroon(
location=self._server_name, identifier="key", key=self._macaroon_secret_key,
)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = session")
macaroon.add_first_party_caveat("state = %s" % (state,))
macaroon.add_first_party_caveat("nonce = %s" % (session_data.nonce,))
macaroon.add_first_party_caveat(
"client_redirect_url = %s" % (session_data.client_redirect_url,)
)
if session_data.ui_auth_session_id:
macaroon.add_first_party_caveat(
"ui_auth_session_id = %s" % (session_data.ui_auth_session_id,)
)
now = self._clock.time_msec()
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
return macaroon.serialize()
def verify_oidc_session_token(
self, session: bytes, state: str
) -> "OidcSessionData":
"""Verifies and extract an OIDC session token.
This verifies that a given session token was issued by this homeserver
and extract the nonce and client_redirect_url caveats.
Args:
session: The session token to verify
state: The state the OIDC provider gave back
Returns:
The data extracted from the session cookie
"""
macaroon = pymacaroons.Macaroon.deserialize(session)
v = pymacaroons.Verifier()
v.satisfy_exact("gen = 1")
v.satisfy_exact("type = session")
v.satisfy_exact("state = %s" % (state,))
v.satisfy_general(lambda c: c.startswith("nonce = "))
v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
# Sometimes there's a UI auth session ID, it seems to be OK to attempt
# to always satisfy this.
v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
v.satisfy_general(self._verify_expiry)
v.verify(macaroon, self._macaroon_secret_key)
# Extract the `nonce`, `client_redirect_url`, and maybe the
# `ui_auth_session_id` from the token.
nonce = self._get_value_from_macaroon(macaroon, "nonce")
client_redirect_url = self._get_value_from_macaroon(
macaroon, "client_redirect_url"
)
try:
ui_auth_session_id = self._get_value_from_macaroon(
macaroon, "ui_auth_session_id"
) # type: Optional[str]
except ValueError:
ui_auth_session_id = None
return OidcSessionData(
nonce=nonce,
client_redirect_url=client_redirect_url,
ui_auth_session_id=ui_auth_session_id,
)
def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
"""Extracts a caveat value from a macaroon token.
Args:
macaroon: the token
key: the key of the caveat to extract
Returns:
The extracted value
Raises:
Exception: if the caveat was not in the macaroon
"""
prefix = key + " = "
for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(prefix):
return caveat.caveat_id[len(prefix) :]
raise ValueError("No %s caveat in macaroon" % (key,))
def _verify_expiry(self, caveat: str) -> bool:
prefix = "time < "
if not caveat.startswith(prefix):
return False
expiry = int(caveat[len(prefix) :])
now = self._clock.time_msec()
return now < expiry
@attr.s(frozen=True, slots=True)
class OidcSessionData:
"""The attributes which are stored in a OIDC session cookie"""
# The `nonce` parameter passed to the OIDC provider.
nonce = attr.ib(type=str)
# The URL the client gave when it initiated the flow. ("" if this is a UI Auth)
client_redirect_url = attr.ib(type=str)
# The session ID of the ongoing UI Auth (None if this is a login)
ui_auth_session_id = attr.ib(type=Optional[str], default=None)
UserAttributeDict = TypedDict(
"UserAttributeDict", {"localpart": Optional[str], "display_name": Optional[str]}
)

View File

@ -14,7 +14,7 @@
# limitations under the License.
import json
import re
from typing import Dict
from typing import Dict, Optional
from urllib.parse import parse_qs, urlencode, urlparse
from mock import ANY, Mock, patch
@ -349,9 +349,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
cookie = args[1]
macaroon = pymacaroons.Macaroon.deserialize(cookie)
state = self.handler._get_value_from_macaroon(macaroon, "state")
nonce = self.handler._get_value_from_macaroon(macaroon, "nonce")
redirect = self.handler._get_value_from_macaroon(
state = self.handler._token_generator._get_value_from_macaroon(
macaroon, "state"
)
nonce = self.handler._token_generator._get_value_from_macaroon(
macaroon, "nonce"
)
redirect = self.handler._token_generator._get_value_from_macaroon(
macaroon, "client_redirect_url"
)
@ -411,12 +415,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
client_redirect_url = "http://client/redirect"
user_agent = "Browser"
ip_address = "10.0.0.1"
session = self.handler._generate_oidc_session_token(
state=state,
nonce=nonce,
client_redirect_url=client_redirect_url,
ui_auth_session_id=None,
)
session = self._generate_oidc_session_token(state, nonce, client_redirect_url)
request = _build_callback_request(
code, state, session, user_agent=user_agent, ip_address=ip_address
)
@ -500,11 +499,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertRenderedError("invalid_session")
# Mismatching session
session = self.handler._generate_oidc_session_token(
state="state",
nonce="nonce",
client_redirect_url="http://client/redirect",
ui_auth_session_id=None,
session = self._generate_oidc_session_token(
state="state", nonce="nonce", client_redirect_url="http://client/redirect",
)
request.args = {}
request.args[b"state"] = [b"mismatching state"]
@ -623,11 +619,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
state = "state"
client_redirect_url = "http://client/redirect"
session = self.handler._generate_oidc_session_token(
state=state,
nonce="nonce",
client_redirect_url=client_redirect_url,
ui_auth_session_id=None,
session = self._generate_oidc_session_token(
state=state, nonce="nonce", client_redirect_url=client_redirect_url,
)
request = _build_callback_request("code", state, session)
@ -841,6 +834,24 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
self.assertRenderedError("mapping_error", "localpart is invalid: ")
def _generate_oidc_session_token(
self,
state: str,
nonce: str,
client_redirect_url: str,
ui_auth_session_id: Optional[str] = None,
) -> str:
from synapse.handlers.oidc_handler import OidcSessionData
return self.handler._token_generator.generate_oidc_session_token(
state=state,
session_data=OidcSessionData(
nonce=nonce,
client_redirect_url=client_redirect_url,
ui_auth_session_id=ui_auth_session_id,
),
)
class UsernamePickerTestCase(HomeserverTestCase):
if not HAS_OIDC:
@ -965,17 +976,19 @@ async def _make_callback_with_userinfo(
userinfo: the OIDC userinfo dict
client_redirect_url: the URL to redirect to on success.
"""
from synapse.handlers.oidc_handler import OidcSessionData
handler = hs.get_oidc_handler()
handler._exchange_code = simple_async_mock(return_value={})
handler._parse_id_token = simple_async_mock(return_value=userinfo)
handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
state = "state"
session = handler._generate_oidc_session_token(
session = handler._token_generator.generate_oidc_session_token(
state=state,
nonce="nonce",
client_redirect_url=client_redirect_url,
ui_auth_session_id=None,
session_data=OidcSessionData(
nonce="nonce", client_redirect_url=client_redirect_url,
),
)
request = _build_callback_request("code", state, session)