Merge branch 'develop' into jaywink/admin-forward-extremities

This commit is contained in:
Jason Robinson 2021-01-09 22:00:04 +02:00
commit 2eb421b606
23 changed files with 385 additions and 168 deletions

View File

@ -15,6 +15,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from synapse.storage.engines import create_engine from synapse.storage.engines import create_engine
logger = logging.getLogger("create_postgres_db") logger = logging.getLogger("create_postgres_db")

1
changelog.d/9036.feature Normal file
View File

@ -0,0 +1 @@
Add support for multiple SSO Identity Providers.

1
changelog.d/9038.misc Normal file
View File

@ -0,0 +1 @@
Configure the linters to run on a consistent set of files.

1
changelog.d/9039.removal Normal file
View File

@ -0,0 +1 @@
Remove broken and unmaintained `demo/webserver.py` script.

1
changelog.d/9051.bugfix Normal file
View File

@ -0,0 +1 @@
Fix error handling during insertion of client IPs into the database.

1
changelog.d/9053.bugfix Normal file
View File

@ -0,0 +1 @@
Fix bug where we didn't correctly record CPU time spent in 'on_new_event' block.

1
changelog.d/9054.bugfix Normal file
View File

@ -0,0 +1 @@
Fix a minor bug which could cause confusing error messages from invalid configurations.

1
changelog.d/9057.doc Normal file
View File

@ -0,0 +1 @@
Add missing user_mapping_provider configuration to the Keycloak OIDC example. Contributed by @chris-ruecker.

View File

@ -1,59 +0,0 @@
import argparse
import BaseHTTPServer
import os
import SimpleHTTPServer
import cgi, logging
from daemonize import Daemonize
class SimpleHTTPRequestHandlerWithPOST(SimpleHTTPServer.SimpleHTTPRequestHandler):
UPLOAD_PATH = "upload"
"""
Accept all post request as file upload
"""
def do_POST(self):
path = os.path.join(self.UPLOAD_PATH, os.path.basename(self.path))
length = self.headers["content-length"]
data = self.rfile.read(int(length))
with open(path, "wb") as fh:
fh.write(data)
self.send_response(200)
self.send_header("Content-Type", "application/json")
self.end_headers()
# Return the absolute path of the uploaded file
self.wfile.write('{"url":"/%s"}' % path)
def setup():
parser = argparse.ArgumentParser()
parser.add_argument("directory")
parser.add_argument("-p", "--port", dest="port", type=int, default=8080)
parser.add_argument("-P", "--pid-file", dest="pid", default="web.pid")
args = parser.parse_args()
# Get absolute path to directory to serve, as daemonize changes to '/'
os.chdir(args.directory)
dr = os.getcwd()
httpd = BaseHTTPServer.HTTPServer(("", args.port), SimpleHTTPRequestHandlerWithPOST)
def run():
os.chdir(dr)
httpd.serve_forever()
daemon = Daemonize(
app="synapse-webclient", pid=args.pid, action=run, auto_close_fds=False
)
daemon.start()
if __name__ == "__main__":
setup()

View File

@ -158,6 +158,10 @@ oidc_config:
client_id: "synapse" client_id: "synapse"
client_secret: "copy secret generated from above" client_secret: "copy secret generated from above"
scopes: ["openid", "profile"] scopes: ["openid", "profile"]
user_mapping_provider:
config:
localpart_template: "{{ user.preferred_username }}"
display_name_template: "{{ user.name }}"
``` ```
### [Auth0][auth0] ### [Auth0][auth0]

View File

@ -103,6 +103,7 @@ files =
tests/replication, tests/replication,
tests/test_utils, tests/test_utils,
tests/handlers/test_password_providers.py, tests/handlers/test_password_providers.py,
tests/rest/client/v1/test_login.py,
tests/rest/client/v2_alpha/test_auth.py, tests/rest/client/v2_alpha/test_auth.py,
tests/util/test_stream_change_cache.py tests/util/test_stream_change_cache.py

View File

@ -15,16 +15,7 @@
# Stub for frozendict. # Stub for frozendict.
from typing import ( from typing import Any, Hashable, Iterable, Iterator, Mapping, Tuple, TypeVar, overload
Any,
Hashable,
Iterable,
Iterator,
Mapping,
overload,
Tuple,
TypeVar,
)
_KT = TypeVar("_KT", bound=Hashable) # Key type. _KT = TypeVar("_KT", bound=Hashable) # Key type.
_VT = TypeVar("_VT") # Value type. _VT = TypeVar("_VT") # Value type.

View File

@ -7,17 +7,17 @@ from typing import (
Callable, Callable,
Dict, Dict,
Hashable, Hashable,
Iterator,
Iterable,
ItemsView, ItemsView,
Iterable,
Iterator,
KeysView, KeysView,
List, List,
Mapping, Mapping,
Optional, Optional,
Sequence, Sequence,
Tuple,
Type, Type,
TypeVar, TypeVar,
Tuple,
Union, Union,
ValuesView, ValuesView,
overload, overload,

View File

@ -16,7 +16,7 @@
"""Contains *incomplete* type hints for txredisapi. """Contains *incomplete* type hints for txredisapi.
""" """
from typing import List, Optional, Union, Type from typing import List, Optional, Type, Union
class RedisProtocol: class RedisProtocol:
def publish(self, channel: str, message: bytes): ... def publish(self, channel: str, message: bytes): ...

View File

@ -56,7 +56,7 @@ def json_error_to_config_error(
""" """
# copy `config_path` before modifying it. # copy `config_path` before modifying it.
path = list(config_path) path = list(config_path)
for p in list(e.path): for p in list(e.absolute_path):
if isinstance(p, int): if isinstance(p, int):
path.append("<item %i>" % p) path.append("<item %i>" % p)
else: else:

View File

@ -396,31 +396,30 @@ class Notifier:
Will wake up all listeners for the given users and rooms. Will wake up all listeners for the given users and rooms.
""" """
with PreserveLoggingContext(): with Measure(self.clock, "on_new_event"):
with Measure(self.clock, "on_new_event"): user_streams = set()
user_streams = set()
for user in users: for user in users:
user_stream = self.user_to_user_stream.get(str(user)) user_stream = self.user_to_user_stream.get(str(user))
if user_stream is not None: if user_stream is not None:
user_streams.add(user_stream) user_streams.add(user_stream)
for room in rooms: for room in rooms:
user_streams |= self.room_to_user_streams.get(room, set()) user_streams |= self.room_to_user_streams.get(room, set())
time_now_ms = self.clock.time_msec() time_now_ms = self.clock.time_msec()
for user_stream in user_streams: for user_stream in user_streams:
try: try:
user_stream.notify(stream_key, new_token, time_now_ms) user_stream.notify(stream_key, new_token, time_now_ms)
except Exception: except Exception:
logger.exception("Failed to notify listener") logger.exception("Failed to notify listener")
self.notify_replication() self.notify_replication()
# Notify appservices # Notify appservices
self._notify_app_services_ephemeral( self._notify_app_services_ephemeral(
stream_key, new_token, users, stream_key, new_token, users,
) )
def on_new_replication_data(self) -> None: def on_new_replication_data(self) -> None:
"""Used to inform replication listeners that something has happened """Used to inform replication listeners that something has happened

View File

@ -319,9 +319,9 @@ class SsoRedirectServlet(RestServlet):
# register themselves with the main SSOHandler. # register themselves with the main SSOHandler.
if hs.config.cas_enabled: if hs.config.cas_enabled:
hs.get_cas_handler() hs.get_cas_handler()
elif hs.config.saml2_enabled: if hs.config.saml2_enabled:
hs.get_saml_handler() hs.get_saml_handler()
elif hs.config.oidc_enabled: if hs.config.oidc_enabled:
hs.get_oidc_handler() hs.get_oidc_handler()
self._sso_handler = hs.get_sso_handler() self._sso_handler = hs.get_sso_handler()

View File

@ -470,43 +470,35 @@ class ClientIpStore(ClientIpWorkerStore):
for entry in to_update.items(): for entry in to_update.items():
(user_id, access_token, ip), (user_agent, device_id, last_seen) = entry (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
try: self.db_pool.simple_upsert_txn(
self.db_pool.simple_upsert_txn( txn,
table="user_ips",
keyvalues={"user_id": user_id, "access_token": access_token, "ip": ip},
values={
"user_agent": user_agent,
"device_id": device_id,
"last_seen": last_seen,
},
lock=False,
)
# Technically an access token might not be associated with
# a device so we need to check.
if device_id:
# this is always an update rather than an upsert: the row should
# already exist, and if it doesn't, that may be because it has been
# deleted, and we don't want to re-create it.
self.db_pool.simple_update_txn(
txn, txn,
table="user_ips", table="devices",
keyvalues={ keyvalues={"user_id": user_id, "device_id": device_id},
"user_id": user_id, updatevalues={
"access_token": access_token, "user_agent": user_agent,
"last_seen": last_seen,
"ip": ip, "ip": ip,
}, },
values={
"user_agent": user_agent,
"device_id": device_id,
"last_seen": last_seen,
},
lock=False,
) )
# Technically an access token might not be associated with
# a device so we need to check.
if device_id:
# this is always an update rather than an upsert: the row should
# already exist, and if it doesn't, that may be because it has been
# deleted, and we don't want to re-create it.
self.db_pool.simple_update_txn(
txn,
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id},
updatevalues={
"user_agent": user_agent,
"last_seen": last_seen,
"ip": ip,
},
)
except Exception as e:
# Failed to upsert, log and continue
logger.error("Failed to insert client IP %r: %r", entry, e)
async def get_last_client_ip_by_device( async def get_last_client_ip_by_device(
self, user_id: str, device_id: Optional[str] self, user_id: str, device_id: Optional[str]
) -> Dict[Tuple[str, str], dict]: ) -> Dict[Tuple[str, str], dict]:

View File

@ -111,7 +111,8 @@ class Measure:
curr_context = current_context() curr_context = current_context()
if not curr_context: if not curr_context:
logger.warning( logger.warning(
"Starting metrics collection from sentinel context: metrics will be lost" "Starting metrics collection %r from sentinel context: metrics will be lost",
name,
) )
parent_context = None parent_context = None
else: else:

53
tests/config/test_util.py Normal file
View File

@ -0,0 +1,53 @@
# -*- coding: utf-8 -*-
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.config import ConfigError
from synapse.config._util import validate_config
from tests.unittest import TestCase
class ValidateConfigTestCase(TestCase):
"""Test cases for synapse.config._util.validate_config"""
def test_bad_object_in_array(self):
"""malformed objects within an array should be validated correctly"""
# consider a structure:
#
# array_of_objs:
# - r: 1
# foo: 2
#
# - r: 2
# bar: 3
#
# ... where each entry must contain an "r": check that the path
# to the required item is correclty reported.
schema = {
"type": "object",
"properties": {
"array_of_objs": {
"type": "array",
"items": {"type": "object", "required": ["r"]},
},
},
}
with self.assertRaises(ConfigError) as c:
validate_config(schema, {"array_of_objs": [{}]}, ("base",))
self.assertEqual(c.exception.path, ["base", "array_of_objs", "<item 0>"])

View File

@ -1,22 +1,67 @@
import json # -*- coding: utf-8 -*-
# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time import time
import urllib.parse import urllib.parse
from html.parser import HTMLParser
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from mock import Mock from mock import Mock
try: import pymacaroons
import jwt
except ImportError: from twisted.web.resource import Resource
jwt = None
import synapse.rest.admin import synapse.rest.admin
from synapse.appservice import ApplicationService from synapse.appservice import ApplicationService
from synapse.rest.client.v1 import login, logout from synapse.rest.client.v1 import login, logout
from synapse.rest.client.v2_alpha import devices, register from synapse.rest.client.v2_alpha import devices, register
from synapse.rest.client.v2_alpha.account import WhoamiRestServlet from synapse.rest.client.v2_alpha.account import WhoamiRestServlet
from synapse.rest.synapse.client.pick_idp import PickIdpResource
from tests import unittest from tests import unittest
from tests.unittest import override_config from tests.handlers.test_oidc import HAS_OIDC
from tests.handlers.test_saml import has_saml2
from tests.rest.client.v1.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG
from tests.unittest import override_config, skip_unless
try:
import jwt
HAS_JWT = True
except ImportError:
HAS_JWT = False
# public_base_url used in some tests
BASE_URL = "https://synapse/"
# CAS server used in some tests
CAS_SERVER = "https://fake.test"
# just enough to tell pysaml2 where to redirect to
SAML_SERVER = "https://test.saml.server/idp/sso"
TEST_SAML_METADATA = """
<md:EntityDescriptor xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata">
<md:IDPSSODescriptor protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
<md:SingleSignOnService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="%(SAML_SERVER)s"/>
</md:IDPSSODescriptor>
</md:EntityDescriptor>
""" % {
"SAML_SERVER": SAML_SERVER,
}
LOGIN_URL = b"/_matrix/client/r0/login" LOGIN_URL = b"/_matrix/client/r0/login"
TEST_URL = b"/_matrix/client/r0/account/whoami" TEST_URL = b"/_matrix/client/r0/account/whoami"
@ -314,6 +359,184 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result) self.assertEquals(channel.result["code"], b"200", channel.result)
@skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC")
class MultiSSOTestCase(unittest.HomeserverTestCase):
"""Tests for homeservers with multiple SSO providers enabled"""
servlets = [
login.register_servlets,
]
def default_config(self) -> Dict[str, Any]:
config = super().default_config()
config["public_baseurl"] = BASE_URL
config["cas_config"] = {
"enabled": True,
"server_url": CAS_SERVER,
"service_url": "https://matrix.goodserver.com:8448",
}
config["saml2_config"] = {
"sp_config": {
"metadata": {"inline": [TEST_SAML_METADATA]},
# use the XMLSecurity backend to avoid relying on xmlsec1
"crypto_backend": "XMLSecurity",
},
}
config["oidc_config"] = TEST_OIDC_CONFIG
return config
def create_resource_dict(self) -> Dict[str, Resource]:
d = super().create_resource_dict()
d["/_synapse/client/pick_idp"] = PickIdpResource(self.hs)
return d
def test_multi_sso_redirect(self):
"""/login/sso/redirect should redirect to an identity picker"""
client_redirect_url = "https://x?<abc>"
# first hit the redirect url, which should redirect to our idp picker
channel = self.make_request(
"GET",
"/_matrix/client/r0/login/sso/redirect?redirectUrl=" + client_redirect_url,
)
self.assertEqual(channel.code, 302, channel.result)
uri = channel.headers.getRawHeaders("Location")[0]
# hitting that picker should give us some HTML
channel = self.make_request("GET", uri)
self.assertEqual(channel.code, 200, channel.result)
# parse the form to check it has fields assumed elsewhere in this class
class FormPageParser(HTMLParser):
def __init__(self):
super().__init__()
# the values of the hidden inputs: map from name to value
self.hiddens = {} # type: Dict[str, Optional[str]]
# the values of the radio buttons
self.radios = [] # type: List[Optional[str]]
def handle_starttag(
self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]]
) -> None:
attr_dict = dict(attrs)
if tag == "input":
if attr_dict["type"] == "radio" and attr_dict["name"] == "idp":
self.radios.append(attr_dict["value"])
elif attr_dict["type"] == "hidden":
input_name = attr_dict["name"]
assert input_name
self.hiddens[input_name] = attr_dict["value"]
def error(_, message):
self.fail(message)
p = FormPageParser()
p.feed(channel.result["body"].decode("utf-8"))
p.close()
self.assertCountEqual(p.radios, ["cas", "oidc", "saml"])
self.assertEqual(p.hiddens["redirectUrl"], client_redirect_url)
def test_multi_sso_redirect_to_cas(self):
"""If CAS is chosen, should redirect to the CAS server"""
client_redirect_url = "https://x?<abc>"
channel = self.make_request(
"GET",
"/_synapse/client/pick_idp?redirectUrl=" + client_redirect_url + "&idp=cas",
shorthand=False,
)
self.assertEqual(channel.code, 302, channel.result)
cas_uri = channel.headers.getRawHeaders("Location")[0]
cas_uri_path, cas_uri_query = cas_uri.split("?", 1)
# it should redirect us to the login page of the cas server
self.assertEqual(cas_uri_path, CAS_SERVER + "/login")
# check that the redirectUrl is correctly encoded in the service param - ie, the
# place that CAS will redirect to
cas_uri_params = urllib.parse.parse_qs(cas_uri_query)
service_uri = cas_uri_params["service"][0]
_, service_uri_query = service_uri.split("?", 1)
service_uri_params = urllib.parse.parse_qs(service_uri_query)
self.assertEqual(service_uri_params["redirectUrl"][0], client_redirect_url)
def test_multi_sso_redirect_to_saml(self):
"""If SAML is chosen, should redirect to the SAML server"""
client_redirect_url = "https://x?<abc>"
channel = self.make_request(
"GET",
"/_synapse/client/pick_idp?redirectUrl="
+ client_redirect_url
+ "&idp=saml",
)
self.assertEqual(channel.code, 302, channel.result)
saml_uri = channel.headers.getRawHeaders("Location")[0]
saml_uri_path, saml_uri_query = saml_uri.split("?", 1)
# it should redirect us to the login page of the SAML server
self.assertEqual(saml_uri_path, SAML_SERVER)
# the RelayState is used to carry the client redirect url
saml_uri_params = urllib.parse.parse_qs(saml_uri_query)
relay_state_param = saml_uri_params["RelayState"][0]
self.assertEqual(relay_state_param, client_redirect_url)
def test_multi_sso_redirect_to_oidc(self):
"""If OIDC is chosen, should redirect to the OIDC auth endpoint"""
client_redirect_url = "https://x?<abc>"
channel = self.make_request(
"GET",
"/_synapse/client/pick_idp?redirectUrl="
+ client_redirect_url
+ "&idp=oidc",
)
self.assertEqual(channel.code, 302, channel.result)
oidc_uri = channel.headers.getRawHeaders("Location")[0]
oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
# it should redirect us to the auth page of the OIDC server
self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
# ... and should have set a cookie including the redirect url
cookies = dict(
h.split(";")[0].split("=", maxsplit=1)
for h in channel.headers.getRawHeaders("Set-Cookie")
)
oidc_session_cookie = cookies["oidc_session"]
macaroon = pymacaroons.Macaroon.deserialize(oidc_session_cookie)
self.assertEqual(
self._get_value_from_macaroon(macaroon, "client_redirect_url"),
client_redirect_url,
)
def test_multi_sso_redirect_to_unknown(self):
"""An unknown IdP should cause a 400"""
channel = self.make_request(
"GET", "/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz",
)
self.assertEqual(channel.code, 400, channel.result)
@staticmethod
def _get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
prefix = key + " = "
for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(prefix):
return caveat.caveat_id[len(prefix) :]
raise ValueError("No %s caveat in macaroon" % (key,))
class CASTestCase(unittest.HomeserverTestCase): class CASTestCase(unittest.HomeserverTestCase):
servlets = [ servlets = [
@ -327,7 +550,7 @@ class CASTestCase(unittest.HomeserverTestCase):
config = self.default_config() config = self.default_config()
config["cas_config"] = { config["cas_config"] = {
"enabled": True, "enabled": True,
"server_url": "https://fake.test", "server_url": CAS_SERVER,
"service_url": "https://matrix.goodserver.com:8448", "service_url": "https://matrix.goodserver.com:8448",
} }
@ -413,8 +636,7 @@ class CASTestCase(unittest.HomeserverTestCase):
} }
) )
def test_cas_redirect_whitelisted(self): def test_cas_redirect_whitelisted(self):
"""Tests that the SSO login flow serves a redirect to a whitelisted url """Tests that the SSO login flow serves a redirect to a whitelisted url"""
"""
self._test_redirect("https://legit-site.com/") self._test_redirect("https://legit-site.com/")
@override_config({"public_baseurl": "https://example.com"}) @override_config({"public_baseurl": "https://example.com"})
@ -462,10 +684,8 @@ class CASTestCase(unittest.HomeserverTestCase):
self.assertIn(b"SSO account deactivated", channel.result["body"]) self.assertIn(b"SSO account deactivated", channel.result["body"])
@skip_unless(HAS_JWT, "requires jwt")
class JWTTestCase(unittest.HomeserverTestCase): class JWTTestCase(unittest.HomeserverTestCase):
if not jwt:
skip = "requires jwt"
servlets = [ servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource, synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets, login.register_servlets,
@ -481,17 +701,17 @@ class JWTTestCase(unittest.HomeserverTestCase):
self.hs.config.jwt_algorithm = self.jwt_algorithm self.hs.config.jwt_algorithm = self.jwt_algorithm
return self.hs return self.hs
def jwt_encode(self, token: str, secret: str = jwt_secret) -> str: def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str:
# PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
result = jwt.encode(token, secret, self.jwt_algorithm) result = jwt.encode(
payload, secret, self.jwt_algorithm
) # type: Union[str, bytes]
if isinstance(result, bytes): if isinstance(result, bytes):
return result.decode("ascii") return result.decode("ascii")
return result return result
def jwt_login(self, *args): def jwt_login(self, *args):
params = json.dumps( params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
{"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
)
channel = self.make_request(b"POST", LOGIN_URL, params) channel = self.make_request(b"POST", LOGIN_URL, params)
return channel return channel
@ -623,7 +843,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
) )
def test_login_no_token(self): def test_login_no_token(self):
params = json.dumps({"type": "org.matrix.login.jwt"}) params = {"type": "org.matrix.login.jwt"}
channel = self.make_request(b"POST", LOGIN_URL, params) channel = self.make_request(b"POST", LOGIN_URL, params)
self.assertEqual(channel.result["code"], b"403", channel.result) self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
@ -633,10 +853,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
# The JWTPubKeyTestCase is a complement to JWTTestCase where we instead use # The JWTPubKeyTestCase is a complement to JWTTestCase where we instead use
# RSS256, with a public key configured in synapse as "jwt_secret", and tokens # RSS256, with a public key configured in synapse as "jwt_secret", and tokens
# signed by the private key. # signed by the private key.
@skip_unless(HAS_JWT, "requires jwt")
class JWTPubKeyTestCase(unittest.HomeserverTestCase): class JWTPubKeyTestCase(unittest.HomeserverTestCase):
if not jwt:
skip = "requires jwt"
servlets = [ servlets = [
login.register_servlets, login.register_servlets,
] ]
@ -693,17 +911,15 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
self.hs.config.jwt_algorithm = "RS256" self.hs.config.jwt_algorithm = "RS256"
return self.hs return self.hs
def jwt_encode(self, token: str, secret: str = jwt_privatekey) -> str: def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str:
# PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
result = jwt.encode(token, secret, "RS256") result = jwt.encode(payload, secret, "RS256") # type: Union[bytes,str]
if isinstance(result, bytes): if isinstance(result, bytes):
return result.decode("ascii") return result.decode("ascii")
return result return result
def jwt_login(self, *args): def jwt_login(self, *args):
params = json.dumps( params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
{"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
)
channel = self.make_request(b"POST", LOGIN_URL, params) channel = self.make_request(b"POST", LOGIN_URL, params)
return channel return channel
@ -773,8 +989,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
return self.hs return self.hs
def test_login_appservice_user(self): def test_login_appservice_user(self):
"""Test that an appservice user can use /login """Test that an appservice user can use /login"""
"""
self.register_as_user(AS_USER) self.register_as_user(AS_USER)
params = { params = {
@ -788,8 +1003,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result) self.assertEquals(channel.result["code"], b"200", channel.result)
def test_login_appservice_user_bot(self): def test_login_appservice_user_bot(self):
"""Test that the appservice bot can use /login """Test that the appservice bot can use /login"""
"""
self.register_as_user(AS_USER) self.register_as_user(AS_USER)
params = { params = {
@ -803,8 +1017,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result) self.assertEquals(channel.result["code"], b"200", channel.result)
def test_login_appservice_wrong_user(self): def test_login_appservice_wrong_user(self):
"""Test that non-as users cannot login with the as token """Test that non-as users cannot login with the as token"""
"""
self.register_as_user(AS_USER) self.register_as_user(AS_USER)
params = { params = {
@ -818,8 +1031,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals(channel.result["code"], b"403", channel.result)
def test_login_appservice_wrong_as(self): def test_login_appservice_wrong_as(self):
"""Test that as users cannot login with wrong as token """Test that as users cannot login with wrong as token"""
"""
self.register_as_user(AS_USER) self.register_as_user(AS_USER)
params = { params = {
@ -834,7 +1046,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
def test_login_appservice_no_token(self): def test_login_appservice_no_token(self):
"""Test that users must provide a token when using the appservice """Test that users must provide a token when using the appservice
login method login method
""" """
self.register_as_user(AS_USER) self.register_as_user(AS_USER)

View File

@ -444,6 +444,7 @@ class RestHelper:
# an 'oidc_config' suitable for login_via_oidc. # an 'oidc_config' suitable for login_via_oidc.
TEST_OIDC_AUTH_ENDPOINT = "https://issuer.test/auth"
TEST_OIDC_CONFIG = { TEST_OIDC_CONFIG = {
"enabled": True, "enabled": True,
"discover": False, "discover": False,
@ -451,7 +452,7 @@ TEST_OIDC_CONFIG = {
"client_id": "test-client-id", "client_id": "test-client-id",
"client_secret": "test-client-secret", "client_secret": "test-client-secret",
"scopes": ["profile"], "scopes": ["profile"],
"authorization_endpoint": "https://z", "authorization_endpoint": TEST_OIDC_AUTH_ENDPOINT,
"token_endpoint": "https://issuer.test/token", "token_endpoint": "https://issuer.test/token",
"userinfo_endpoint": "https://issuer.test/userinfo", "userinfo_endpoint": "https://issuer.test/userinfo",
"user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}}, "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}},

20
tox.ini
View File

@ -24,6 +24,20 @@ deps =
# install the "enum34" dependency of cryptography. # install the "enum34" dependency of cryptography.
pip>=10 pip>=10
# directories/files we run the linters on
lint_targets =
setup.py
synapse
tests
scripts
scripts-dev
stubs
contrib
synctl
synmark
.buildkite
docker
# default settings for all tox environments # default settings for all tox environments
[testenv] [testenv]
deps = deps =
@ -130,13 +144,13 @@ commands =
[testenv:check_codestyle] [testenv:check_codestyle]
extras = lint extras = lint
commands = commands =
python -m black --check --diff . python -m black --check --diff {[base]lint_targets}
/bin/sh -c "flake8 synapse tests scripts scripts-dev contrib synctl {env:PEP8SUFFIX:}" flake8 {[base]lint_targets} {env:PEP8SUFFIX:}
{toxinidir}/scripts-dev/config-lint.sh {toxinidir}/scripts-dev/config-lint.sh
[testenv:check_isort] [testenv:check_isort]
extras = lint extras = lint
commands = /bin/sh -c "isort -c --df --sp setup.cfg synapse tests scripts-dev scripts" commands = isort -c --df --sp setup.cfg {[base]lint_targets}
[testenv:check-newsfragment] [testenv:check-newsfragment]
skip_install = True skip_install = True