mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2025-01-05 15:30:49 -05:00
Factor SSO success handling out of CAS login (#4264)
This is mostly factoring out the post-CAS-login code to somewhere we can reuse it for other SSO flows, but it also fixes the userid mapping while we're at it.
This commit is contained in:
parent
b0c24a66ec
commit
c588b9b9e4
1
changelog.d/4264.bugfix
Normal file
1
changelog.d/4264.bugfix
Normal file
@ -0,0 +1 @@
|
|||||||
|
Fix CAS login when username is not valid in an MXID
|
@ -563,10 +563,10 @@ class AuthHandler(BaseHandler):
|
|||||||
insensitively, but return None if there are multiple inexact matches.
|
insensitively, but return None if there are multiple inexact matches.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
(str) user_id: complete @user:id
|
(unicode|bytes) user_id: complete @user:id
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
defer.Deferred: (str) canonical_user_id, or None if zero or
|
defer.Deferred: (unicode) canonical_user_id, or None if zero or
|
||||||
multiple matches
|
multiple matches
|
||||||
"""
|
"""
|
||||||
res = yield self._find_user_id_and_pwd_hash(user_id)
|
res = yield self._find_user_id_and_pwd_hash(user_id)
|
||||||
@ -954,6 +954,15 @@ class MacaroonGenerator(object):
|
|||||||
return macaroon.serialize()
|
return macaroon.serialize()
|
||||||
|
|
||||||
def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
|
def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
|
||||||
|
"""
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (unicode):
|
||||||
|
duration_in_ms (int):
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
unicode
|
||||||
|
"""
|
||||||
macaroon = self._generate_base_macaroon(user_id)
|
macaroon = self._generate_base_macaroon(user_id)
|
||||||
macaroon.add_first_party_caveat("type = login")
|
macaroon.add_first_party_caveat("type = login")
|
||||||
now = self.hs.get_clock().time_msec()
|
now = self.hs.get_clock().time_msec()
|
||||||
|
@ -23,8 +23,12 @@ from twisted.web.client import PartialDownloadError
|
|||||||
|
|
||||||
from synapse.api.errors import Codes, LoginError, SynapseError
|
from synapse.api.errors import Codes, LoginError, SynapseError
|
||||||
from synapse.http.server import finish_request
|
from synapse.http.server import finish_request
|
||||||
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
from synapse.http.servlet import (
|
||||||
from synapse.types import UserID
|
RestServlet,
|
||||||
|
parse_json_object_from_request,
|
||||||
|
parse_string,
|
||||||
|
)
|
||||||
|
from synapse.types import UserID, map_username_to_mxid_localpart
|
||||||
from synapse.util.msisdn import phone_number_to_msisdn
|
from synapse.util.msisdn import phone_number_to_msisdn
|
||||||
|
|
||||||
from .base import ClientV1RestServlet, client_path_patterns
|
from .base import ClientV1RestServlet, client_path_patterns
|
||||||
@ -358,17 +362,15 @@ class CasTicketServlet(ClientV1RestServlet):
|
|||||||
self.cas_server_url = hs.config.cas_server_url
|
self.cas_server_url = hs.config.cas_server_url
|
||||||
self.cas_service_url = hs.config.cas_service_url
|
self.cas_service_url = hs.config.cas_service_url
|
||||||
self.cas_required_attributes = hs.config.cas_required_attributes
|
self.cas_required_attributes = hs.config.cas_required_attributes
|
||||||
self.auth_handler = hs.get_auth_handler()
|
self._sso_auth_handler = SSOAuthHandler(hs)
|
||||||
self.handlers = hs.get_handlers()
|
|
||||||
self.macaroon_gen = hs.get_macaroon_generator()
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request):
|
def on_GET(self, request):
|
||||||
client_redirect_url = request.args[b"redirectUrl"][0]
|
client_redirect_url = parse_string(request, "redirectUrl", required=True)
|
||||||
http_client = self.hs.get_simple_http_client()
|
http_client = self.hs.get_simple_http_client()
|
||||||
uri = self.cas_server_url + "/proxyValidate"
|
uri = self.cas_server_url + "/proxyValidate"
|
||||||
args = {
|
args = {
|
||||||
"ticket": request.args[b"ticket"][0].decode('ascii'),
|
"ticket": parse_string(request, "ticket", required=True),
|
||||||
"service": self.cas_service_url
|
"service": self.cas_service_url
|
||||||
}
|
}
|
||||||
try:
|
try:
|
||||||
@ -380,7 +382,6 @@ class CasTicketServlet(ClientV1RestServlet):
|
|||||||
result = yield self.handle_cas_response(request, body, client_redirect_url)
|
result = yield self.handle_cas_response(request, body, client_redirect_url)
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def handle_cas_response(self, request, cas_response_body, client_redirect_url):
|
def handle_cas_response(self, request, cas_response_body, client_redirect_url):
|
||||||
user, attributes = self.parse_cas_response(cas_response_body)
|
user, attributes = self.parse_cas_response(cas_response_body)
|
||||||
|
|
||||||
@ -396,29 +397,10 @@ class CasTicketServlet(ClientV1RestServlet):
|
|||||||
if required_value != actual_value:
|
if required_value != actual_value:
|
||||||
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
|
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
|
||||||
|
|
||||||
user_id = UserID(user, self.hs.hostname).to_string()
|
return self._sso_auth_handler.on_successful_auth(
|
||||||
auth_handler = self.auth_handler
|
user, request, client_redirect_url,
|
||||||
registered_user_id = yield auth_handler.check_user_exists(user_id)
|
|
||||||
if not registered_user_id:
|
|
||||||
registered_user_id, _ = (
|
|
||||||
yield self.handlers.registration_handler.register(localpart=user)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
login_token = self.macaroon_gen.generate_short_term_login_token(
|
|
||||||
registered_user_id
|
|
||||||
)
|
|
||||||
redirect_url = self.add_login_token_to_redirect_url(client_redirect_url,
|
|
||||||
login_token)
|
|
||||||
request.redirect(redirect_url)
|
|
||||||
finish_request(request)
|
|
||||||
|
|
||||||
def add_login_token_to_redirect_url(self, url, token):
|
|
||||||
url_parts = list(urllib.parse.urlparse(url))
|
|
||||||
query = dict(urllib.parse.parse_qsl(url_parts[4]))
|
|
||||||
query.update({"loginToken": token})
|
|
||||||
url_parts[4] = urllib.parse.urlencode(query).encode('ascii')
|
|
||||||
return urllib.parse.urlunparse(url_parts)
|
|
||||||
|
|
||||||
def parse_cas_response(self, cas_response_body):
|
def parse_cas_response(self, cas_response_body):
|
||||||
user = None
|
user = None
|
||||||
attributes = {}
|
attributes = {}
|
||||||
@ -452,6 +434,71 @@ class CasTicketServlet(ClientV1RestServlet):
|
|||||||
return user, attributes
|
return user, attributes
|
||||||
|
|
||||||
|
|
||||||
|
class SSOAuthHandler(object):
|
||||||
|
"""
|
||||||
|
Utility class for Resources and Servlets which handle the response from a SSO
|
||||||
|
service
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hs (synapse.server.HomeServer)
|
||||||
|
"""
|
||||||
|
def __init__(self, hs):
|
||||||
|
self._hostname = hs.hostname
|
||||||
|
self._auth_handler = hs.get_auth_handler()
|
||||||
|
self._registration_handler = hs.get_handlers().registration_handler
|
||||||
|
self._macaroon_gen = hs.get_macaroon_generator()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_successful_auth(
|
||||||
|
self, username, request, client_redirect_url,
|
||||||
|
):
|
||||||
|
"""Called once the user has successfully authenticated with the SSO.
|
||||||
|
|
||||||
|
Registers the user if necessary, and then returns a redirect (with
|
||||||
|
a login token) to the client.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
username (unicode|bytes): the remote user id. We'll map this onto
|
||||||
|
something sane for a MXID localpath.
|
||||||
|
|
||||||
|
request (SynapseRequest): the incoming request from the browser. We'll
|
||||||
|
respond to it with a redirect.
|
||||||
|
|
||||||
|
client_redirect_url (unicode): the redirect_url the client gave us when
|
||||||
|
it first started the process.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred[none]: Completes once we have handled the request.
|
||||||
|
"""
|
||||||
|
localpart = map_username_to_mxid_localpart(username)
|
||||||
|
user_id = UserID(localpart, self._hostname).to_string()
|
||||||
|
registered_user_id = yield self._auth_handler.check_user_exists(user_id)
|
||||||
|
if not registered_user_id:
|
||||||
|
registered_user_id, _ = (
|
||||||
|
yield self._registration_handler.register(
|
||||||
|
localpart=localpart,
|
||||||
|
generate_token=False,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
login_token = self._macaroon_gen.generate_short_term_login_token(
|
||||||
|
registered_user_id
|
||||||
|
)
|
||||||
|
redirect_url = self._add_login_token_to_redirect_url(
|
||||||
|
client_redirect_url, login_token
|
||||||
|
)
|
||||||
|
request.redirect(redirect_url)
|
||||||
|
finish_request(request)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _add_login_token_to_redirect_url(url, token):
|
||||||
|
url_parts = list(urllib.parse.urlparse(url))
|
||||||
|
query = dict(urllib.parse.parse_qsl(url_parts[4]))
|
||||||
|
query.update({"loginToken": token})
|
||||||
|
url_parts[4] = urllib.parse.urlencode(query)
|
||||||
|
return urllib.parse.urlunparse(url_parts)
|
||||||
|
|
||||||
|
|
||||||
def register_servlets(hs, http_server):
|
def register_servlets(hs, http_server):
|
||||||
LoginRestServlet(hs).register(http_server)
|
LoginRestServlet(hs).register(http_server)
|
||||||
if hs.config.cas_enabled:
|
if hs.config.cas_enabled:
|
||||||
|
@ -12,6 +12,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import re
|
||||||
import string
|
import string
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
|
||||||
@ -228,6 +229,71 @@ def contains_invalid_mxid_characters(localpart):
|
|||||||
return any(c not in mxid_localpart_allowed_characters for c in localpart)
|
return any(c not in mxid_localpart_allowed_characters for c in localpart)
|
||||||
|
|
||||||
|
|
||||||
|
UPPER_CASE_PATTERN = re.compile(b"[A-Z_]")
|
||||||
|
|
||||||
|
# the following is a pattern which matches '=', and bytes which are not allowed in a mxid
|
||||||
|
# localpart.
|
||||||
|
#
|
||||||
|
# It works by:
|
||||||
|
# * building a string containing the allowed characters (excluding '=')
|
||||||
|
# * escaping every special character with a backslash (to stop '-' being interpreted as a
|
||||||
|
# range operator)
|
||||||
|
# * wrapping it in a '[^...]' regex
|
||||||
|
# * converting the whole lot to a 'bytes' sequence, so that we can use it to match
|
||||||
|
# bytes rather than strings
|
||||||
|
#
|
||||||
|
NON_MXID_CHARACTER_PATTERN = re.compile(
|
||||||
|
("[^%s]" % (
|
||||||
|
re.escape("".join(mxid_localpart_allowed_characters - {"="}),),
|
||||||
|
)).encode("ascii"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def map_username_to_mxid_localpart(username, case_sensitive=False):
|
||||||
|
"""Map a username onto a string suitable for a MXID
|
||||||
|
|
||||||
|
This follows the algorithm laid out at
|
||||||
|
https://matrix.org/docs/spec/appendices.html#mapping-from-other-character-sets.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
username (unicode|bytes): username to be mapped
|
||||||
|
case_sensitive (bool): true if TEST and test should be mapped
|
||||||
|
onto different mxids
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
unicode: string suitable for a mxid localpart
|
||||||
|
"""
|
||||||
|
if not isinstance(username, bytes):
|
||||||
|
username = username.encode('utf-8')
|
||||||
|
|
||||||
|
# first we sort out upper-case characters
|
||||||
|
if case_sensitive:
|
||||||
|
def f1(m):
|
||||||
|
return b"_" + m.group().lower()
|
||||||
|
|
||||||
|
username = UPPER_CASE_PATTERN.sub(f1, username)
|
||||||
|
else:
|
||||||
|
username = username.lower()
|
||||||
|
|
||||||
|
# then we sort out non-ascii characters
|
||||||
|
def f2(m):
|
||||||
|
g = m.group()[0]
|
||||||
|
if isinstance(g, str):
|
||||||
|
# on python 2, we need to do a ord(). On python 3, the
|
||||||
|
# byte itself will do.
|
||||||
|
g = ord(g)
|
||||||
|
return b"=%02x" % (g,)
|
||||||
|
|
||||||
|
username = NON_MXID_CHARACTER_PATTERN.sub(f2, username)
|
||||||
|
|
||||||
|
# we also do the =-escaping to mxids starting with an underscore.
|
||||||
|
username = re.sub(b'^_', b'=5f', username)
|
||||||
|
|
||||||
|
# we should now only have ascii bytes left, so can decode back to a
|
||||||
|
# unicode.
|
||||||
|
return username.decode('ascii')
|
||||||
|
|
||||||
|
|
||||||
class StreamToken(
|
class StreamToken(
|
||||||
namedtuple("Token", (
|
namedtuple("Token", (
|
||||||
"room_key",
|
"room_key",
|
||||||
|
@ -14,7 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
from synapse.types import GroupID, RoomAlias, UserID
|
from synapse.types import GroupID, RoomAlias, UserID, map_username_to_mxid_localpart
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.utils import TestHomeServer
|
from tests.utils import TestHomeServer
|
||||||
@ -79,3 +79,32 @@ class GroupIDTestCase(unittest.TestCase):
|
|||||||
except SynapseError as exc:
|
except SynapseError as exc:
|
||||||
self.assertEqual(400, exc.code)
|
self.assertEqual(400, exc.code)
|
||||||
self.assertEqual("M_UNKNOWN", exc.errcode)
|
self.assertEqual("M_UNKNOWN", exc.errcode)
|
||||||
|
|
||||||
|
|
||||||
|
class MapUsernameTestCase(unittest.TestCase):
|
||||||
|
def testPassThrough(self):
|
||||||
|
self.assertEqual(map_username_to_mxid_localpart("test1234"), "test1234")
|
||||||
|
|
||||||
|
def testUpperCase(self):
|
||||||
|
self.assertEqual(map_username_to_mxid_localpart("tEST_1234"), "test_1234")
|
||||||
|
self.assertEqual(
|
||||||
|
map_username_to_mxid_localpart("tEST_1234", case_sensitive=True),
|
||||||
|
"t_e_s_t__1234",
|
||||||
|
)
|
||||||
|
|
||||||
|
def testSymbols(self):
|
||||||
|
self.assertEqual(
|
||||||
|
map_username_to_mxid_localpart("test=$?_1234"),
|
||||||
|
"test=3d=24=3f_1234",
|
||||||
|
)
|
||||||
|
|
||||||
|
def testLeadingUnderscore(self):
|
||||||
|
self.assertEqual(map_username_to_mxid_localpart("_test_1234"), "=5ftest_1234")
|
||||||
|
|
||||||
|
def testNonAscii(self):
|
||||||
|
# this should work with either a unicode or a bytes
|
||||||
|
self.assertEqual(map_username_to_mxid_localpart(u'têst'), "t=c3=aast")
|
||||||
|
self.assertEqual(
|
||||||
|
map_username_to_mxid_localpart(u'têst'.encode('utf-8')),
|
||||||
|
"t=c3=aast",
|
||||||
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user