Merge branch 'develop' into markjh/v2_sync_api

This commit is contained in:
Mark Haines 2015-10-13 10:33:00 +01:00
commit f96b480670
11 changed files with 245 additions and 37 deletions

View File

@ -44,4 +44,7 @@ Eric Myhre <hash at exultant.us>
repository API. repository API.
Muthu Subramanian <muthu.subramanian.karunanidhi at ericsson.com> Muthu Subramanian <muthu.subramanian.karunanidhi at ericsson.com>
* Add SAML2 support for registration and logins. * Add SAML2 support for registration and login.
Steven Hammerton <steven.hammerton at openmarket.com>
* Add CAS support for registration and login.

View File

@ -47,7 +47,6 @@ class CodeMessageException(RuntimeError):
"""An exception with integer code and message string attributes.""" """An exception with integer code and message string attributes."""
def __init__(self, code, msg): def __init__(self, code, msg):
logger.info("%s: %s, %s", type(self).__name__, code, msg)
super(CodeMessageException, self).__init__("%d: %s" % (code, msg)) super(CodeMessageException, self).__init__("%d: %s" % (code, msg))
self.code = code self.code = code
self.msg = msg self.msg = msg

43
synapse/config/cas.py Normal file
View File

@ -0,0 +1,43 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config
class CasConfig(Config):
"""Cas Configuration
cas_server_url: URL of CAS server
"""
def read_config(self, config):
cas_config = config.get("cas_config", None)
if cas_config:
self.cas_enabled = True
self.cas_server_url = cas_config["server_url"]
self.cas_required_attributes = cas_config.get("required_attributes", {})
else:
self.cas_enabled = False
self.cas_server_url = None
self.cas_required_attributes = {}
def default_config(self, config_dir_path, server_name, **kwargs):
return """
# Enable CAS for registration and login.
#cas_config:
# server_url: "https://cas-server.com"
# #required_attributes:
# # name: value
"""

View File

@ -26,12 +26,13 @@ from .metrics import MetricsConfig
from .appservice import AppServiceConfig from .appservice import AppServiceConfig
from .key import KeyConfig from .key import KeyConfig
from .saml2 import SAML2Config from .saml2 import SAML2Config
from .cas import CasConfig
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
RatelimitConfig, ContentRepositoryConfig, CaptchaConfig, RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
VoipConfig, RegistrationConfig, MetricsConfig, VoipConfig, RegistrationConfig, MetricsConfig,
AppServiceConfig, KeyConfig, SAML2Config, ): AppServiceConfig, KeyConfig, SAML2Config, CasConfig):
pass pass

View File

@ -26,6 +26,7 @@ class ServerConfig(Config):
self.soft_file_limit = config["soft_file_limit"] self.soft_file_limit = config["soft_file_limit"]
self.daemonize = config.get("daemonize") self.daemonize = config.get("daemonize")
self.print_pidfile = config.get("print_pidfile") self.print_pidfile = config.get("print_pidfile")
self.user_agent_suffix = config.get("user_agent_suffix")
self.use_frozen_dicts = config.get("use_frozen_dicts", True) self.use_frozen_dicts = config.get("use_frozen_dicts", True)
self.listeners = config.get("listeners", []) self.listeners = config.get("listeners", [])

View File

@ -295,6 +295,38 @@ class AuthHandler(BaseHandler):
refresh_token = yield self.issue_refresh_token(user_id) refresh_token = yield self.issue_refresh_token(user_id)
defer.returnValue((user_id, access_token, refresh_token)) defer.returnValue((user_id, access_token, refresh_token))
@defer.inlineCallbacks
def login_with_cas_user_id(self, user_id):
"""
Authenticates the user with the given user ID,
intended to have been captured from a CAS response
Args:
user_id (str): User ID
Returns:
A tuple of:
The user's ID.
The access token for the user's session.
The refresh token for the user's session.
Raises:
StoreError if there was a problem storing the token.
LoginError if there was an authentication problem.
"""
user_id, ignored = yield self._find_user_id_and_pwd_hash(user_id)
logger.info("Logging in user %s", user_id)
access_token = yield self.issue_access_token(user_id)
refresh_token = yield self.issue_refresh_token(user_id)
defer.returnValue((user_id, access_token, refresh_token))
@defer.inlineCallbacks
def does_user_exist(self, user_id):
try:
yield self._find_user_id_and_pwd_hash(user_id)
defer.returnValue(True)
except LoginError:
defer.returnValue(False)
@defer.inlineCallbacks @defer.inlineCallbacks
def _find_user_id_and_pwd_hash(self, user_id): def _find_user_id_and_pwd_hash(self, user_id):
"""Checks to see if a user with the given id exists. Will check case """Checks to see if a user with the given id exists. Will check case

View File

@ -324,7 +324,8 @@ class MessageHandler(BaseHandler):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def snapshot_all_rooms(self, user_id=None, pagin_config=None, as_client_event=True): def snapshot_all_rooms(self, user_id=None, pagin_config=None,
as_client_event=True, include_archived=False):
"""Retrieve a snapshot of all rooms the user is invited or has joined. """Retrieve a snapshot of all rooms the user is invited or has joined.
This snapshot may include messages for all rooms where the user is This snapshot may include messages for all rooms where the user is
@ -335,17 +336,19 @@ class MessageHandler(BaseHandler):
pagin_config (synapse.api.streams.PaginationConfig): The pagination pagin_config (synapse.api.streams.PaginationConfig): The pagination
config used to determine how many messages *PER ROOM* to return. config used to determine how many messages *PER ROOM* to return.
as_client_event (bool): True to get events in client-server format. as_client_event (bool): True to get events in client-server format.
include_archived (bool): True to get rooms that the user has left
Returns: Returns:
A list of dicts with "room_id" and "membership" keys for all rooms A list of dicts with "room_id" and "membership" keys for all rooms
the user is currently invited or joined in on. Rooms where the user the user is currently invited or joined in on. Rooms where the user
is joined on, may return a "messages" key with messages, depending is joined on, may return a "messages" key with messages, depending
on the specified PaginationConfig. on the specified PaginationConfig.
""" """
memberships = [Membership.INVITE, Membership.JOIN]
if include_archived:
memberships.append(Membership.LEAVE)
room_list = yield self.store.get_rooms_for_user_where_membership_is( room_list = yield self.store.get_rooms_for_user_where_membership_is(
user_id=user_id, user_id=user_id, membership_list=memberships
membership_list=[
Membership.INVITE, Membership.JOIN, Membership.LEAVE
]
) )
user = UserID.from_string(user_id) user = UserID.from_string(user_id)

View File

@ -67,7 +67,9 @@ class SimpleHttpClient(object):
connectTimeout=15, connectTimeout=15,
contextFactory=hs.get_http_client_context_factory() contextFactory=hs.get_http_client_context_factory()
) )
self.version_string = hs.version_string self.user_agent = hs.version_string
if hs.config.user_agent_suffix:
self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix,)
def request(self, method, uri, *args, **kwargs): def request(self, method, uri, *args, **kwargs):
# A small wrapper around self.agent.request() so we can easily attach # A small wrapper around self.agent.request() so we can easily attach
@ -112,7 +114,7 @@ class SimpleHttpClient(object):
uri.encode("ascii"), uri.encode("ascii"),
headers=Headers({ headers=Headers({
b"Content-Type": [b"application/x-www-form-urlencoded"], b"Content-Type": [b"application/x-www-form-urlencoded"],
b"User-Agent": [self.version_string], b"User-Agent": [self.user_agent],
}), }),
bodyProducer=FileBodyProducer(StringIO(query_bytes)) bodyProducer=FileBodyProducer(StringIO(query_bytes))
) )
@ -131,7 +133,8 @@ class SimpleHttpClient(object):
"POST", "POST",
uri.encode("ascii"), uri.encode("ascii"),
headers=Headers({ headers=Headers({
"Content-Type": ["application/json"] b"Content-Type": [b"application/json"],
b"User-Agent": [self.user_agent],
}), }),
bodyProducer=FileBodyProducer(StringIO(json_str)) bodyProducer=FileBodyProducer(StringIO(json_str))
) )
@ -157,27 +160,8 @@ class SimpleHttpClient(object):
On a non-2xx HTTP response. The response body will be used as the On a non-2xx HTTP response. The response body will be used as the
error message. error message.
""" """
if len(args): body = yield self.get_raw(uri, args)
query_bytes = urllib.urlencode(args, True) defer.returnValue(json.loads(body))
uri = "%s?%s" % (uri, query_bytes)
response = yield self.request(
"GET",
uri.encode("ascii"),
headers=Headers({
b"User-Agent": [self.version_string],
})
)
body = yield preserve_context_over_fn(readBody, response)
if 200 <= response.code < 300:
defer.returnValue(json.loads(body))
else:
# NB: This is explicitly not json.loads(body)'d because the contract
# of CodeMessageException is a *string* message. Callers can always
# load it into JSON if they want.
raise CodeMessageException(response.code, body)
@defer.inlineCallbacks @defer.inlineCallbacks
def put_json(self, uri, json_body, args={}): def put_json(self, uri, json_body, args={}):
@ -206,7 +190,7 @@ class SimpleHttpClient(object):
"PUT", "PUT",
uri.encode("ascii"), uri.encode("ascii"),
headers=Headers({ headers=Headers({
b"User-Agent": [self.version_string], b"User-Agent": [self.user_agent],
"Content-Type": ["application/json"] "Content-Type": ["application/json"]
}), }),
bodyProducer=FileBodyProducer(StringIO(json_str)) bodyProducer=FileBodyProducer(StringIO(json_str))
@ -222,6 +206,42 @@ class SimpleHttpClient(object):
# load it into JSON if they want. # load it into JSON if they want.
raise CodeMessageException(response.code, body) raise CodeMessageException(response.code, body)
@defer.inlineCallbacks
def get_raw(self, uri, args={}):
""" Gets raw text from the given URI.
Args:
uri (str): The URI to request, not including query parameters
args (dict): A dictionary used to create query strings, defaults to
None.
**Note**: The value of each key is assumed to be an iterable
and *not* a string.
Returns:
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
HTTP body at text.
Raises:
On a non-2xx HTTP response. The response body will be used as the
error message.
"""
if len(args):
query_bytes = urllib.urlencode(args, True)
uri = "%s?%s" % (uri, query_bytes)
response = yield self.request(
"GET",
uri.encode("ascii"),
headers=Headers({
b"User-Agent": [self.user_agent],
})
)
body = yield preserve_context_over_fn(readBody, response)
if 200 <= response.code < 300:
defer.returnValue(body)
else:
raise CodeMessageException(response.code, body)
class CaptchaServerHttpClient(SimpleHttpClient): class CaptchaServerHttpClient(SimpleHttpClient):
""" """
@ -241,7 +261,7 @@ class CaptchaServerHttpClient(SimpleHttpClient):
bodyProducer=FileBodyProducer(StringIO(query_bytes)), bodyProducer=FileBodyProducer(StringIO(query_bytes)),
headers=Headers({ headers=Headers({
b"Content-Type": [b"application/x-www-form-urlencoded"], b"Content-Type": [b"application/x-www-form-urlencoded"],
b"User-Agent": [self.version_string], b"User-Agent": [self.user_agent],
}) })
) )

View File

@ -186,7 +186,7 @@ class Pusher(object):
if not display_name: if not display_name:
return False return False
return re.search( return re.search(
"\b%s\b" % re.escape(display_name), ev['content']['body'], r"\b%s\b" % re.escape(display_name), ev['content']['body'],
flags=re.IGNORECASE flags=re.IGNORECASE
) is not None ) is not None

View File

@ -29,10 +29,12 @@ class InitialSyncRestServlet(ClientV1RestServlet):
as_client_event = "raw" not in request.args as_client_event = "raw" not in request.args
pagination_config = PaginationConfig.from_request(request) pagination_config = PaginationConfig.from_request(request)
handler = self.handlers.message_handler handler = self.handlers.message_handler
include_archived = request.args.get("archived", None) == ["true"]
content = yield handler.snapshot_all_rooms( content = yield handler.snapshot_all_rooms(
user_id=user.to_string(), user_id=user.to_string(),
pagin_config=pagination_config, pagin_config=pagination_config,
as_client_event=as_client_event as_client_event=as_client_event,
include_archived=include_archived,
) )
defer.returnValue((200, content)) defer.returnValue((200, content))

View File

@ -15,7 +15,8 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError, LoginError, Codes
from synapse.http.client import SimpleHttpClient
from synapse.types import UserID from synapse.types import UserID
from base import ClientV1RestServlet, client_path_pattern from base import ClientV1RestServlet, client_path_pattern
@ -27,6 +28,8 @@ from saml2 import BINDING_HTTP_POST
from saml2 import config from saml2 import config
from saml2.client import Saml2Client from saml2.client import Saml2Client
import xml.etree.ElementTree as ET
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -35,16 +38,23 @@ class LoginRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/login$") PATTERN = client_path_pattern("/login$")
PASS_TYPE = "m.login.password" PASS_TYPE = "m.login.password"
SAML2_TYPE = "m.login.saml2" SAML2_TYPE = "m.login.saml2"
CAS_TYPE = "m.login.cas"
def __init__(self, hs): def __init__(self, hs):
super(LoginRestServlet, self).__init__(hs) super(LoginRestServlet, self).__init__(hs)
self.idp_redirect_url = hs.config.saml2_idp_redirect_url self.idp_redirect_url = hs.config.saml2_idp_redirect_url
self.saml2_enabled = hs.config.saml2_enabled self.saml2_enabled = hs.config.saml2_enabled
self.cas_enabled = hs.config.cas_enabled
self.cas_server_url = hs.config.cas_server_url
self.cas_required_attributes = hs.config.cas_required_attributes
self.servername = hs.config.server_name
def on_GET(self, request): def on_GET(self, request):
flows = [{"type": LoginRestServlet.PASS_TYPE}] flows = [{"type": LoginRestServlet.PASS_TYPE}]
if self.saml2_enabled: if self.saml2_enabled:
flows.append({"type": LoginRestServlet.SAML2_TYPE}) flows.append({"type": LoginRestServlet.SAML2_TYPE})
if self.cas_enabled:
flows.append({"type": LoginRestServlet.CAS_TYPE})
return (200, {"flows": flows}) return (200, {"flows": flows})
def on_OPTIONS(self, request): def on_OPTIONS(self, request):
@ -67,6 +77,19 @@ class LoginRestServlet(ClientV1RestServlet):
"uri": "%s%s" % (self.idp_redirect_url, relay_state) "uri": "%s%s" % (self.idp_redirect_url, relay_state)
} }
defer.returnValue((200, result)) defer.returnValue((200, result))
elif self.cas_enabled and (login_submission["type"] ==
LoginRestServlet.CAS_TYPE):
# TODO: get this from the homeserver rather than creating a new one for
# each request
http_client = SimpleHttpClient(self.hs)
uri = "%s/proxyValidate" % (self.cas_server_url,)
args = {
"ticket": login_submission["ticket"],
"service": login_submission["service"]
}
body = yield http_client.get_raw(uri, args)
result = yield self.do_cas_login(body)
defer.returnValue(result)
else: else:
raise SynapseError(400, "Bad login type.") raise SynapseError(400, "Bad login type.")
except KeyError: except KeyError:
@ -100,6 +123,74 @@ class LoginRestServlet(ClientV1RestServlet):
defer.returnValue((200, result)) defer.returnValue((200, result))
@defer.inlineCallbacks
def do_cas_login(self, cas_response_body):
user, attributes = self.parse_cas_response(cas_response_body)
for required_attribute, required_value in self.cas_required_attributes.items():
# If required attribute was not in CAS Response - Forbidden
if required_attribute not in attributes:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
# Also need to check value
if required_value is not None:
actual_value = attributes[required_attribute]
# If required attribute value does not match expected - Forbidden
if required_value != actual_value:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
user_id = UserID.create(user, self.hs.hostname).to_string()
auth_handler = self.handlers.auth_handler
user_exists = yield auth_handler.does_user_exist(user_id)
if user_exists:
user_id, access_token, refresh_token = (
yield auth_handler.login_with_cas_user_id(user_id)
)
result = {
"user_id": user_id, # may have changed
"access_token": access_token,
"refresh_token": refresh_token,
"home_server": self.hs.hostname,
}
else:
user_id, access_token = (
yield self.handlers.registration_handler.register(localpart=user)
)
result = {
"user_id": user_id, # may have changed
"access_token": access_token,
"home_server": self.hs.hostname,
}
defer.returnValue((200, result))
def parse_cas_response(self, cas_response_body):
root = ET.fromstring(cas_response_body)
if not root.tag.endswith("serviceResponse"):
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
if not root[0].tag.endswith("authenticationSuccess"):
raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED)
for child in root[0]:
if child.tag.endswith("user"):
user = child.text
if child.tag.endswith("attributes"):
attributes = {}
for attribute in child:
# ElementTree library expands the namespace in attribute tags
# to the full URL of the namespace.
# See (https://docs.python.org/2/library/xml.etree.elementtree.html)
# We don't care about namespace here and it will always be encased in
# curly braces, so we remove them.
if "}" in attribute.tag:
attributes[attribute.tag.split("}")[1]] = attribute.text
else:
attributes[attribute.tag] = attribute.text
if user is None or attributes is None:
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
return (user, attributes)
class LoginFallbackRestServlet(ClientV1RestServlet): class LoginFallbackRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/login/fallback$") PATTERN = client_path_pattern("/login/fallback$")
@ -174,6 +265,17 @@ class SAML2RestServlet(ClientV1RestServlet):
defer.returnValue((200, {"status": "not_authenticated"})) defer.returnValue((200, {"status": "not_authenticated"}))
class CasRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/login/cas")
def __init__(self, hs):
super(CasRestServlet, self).__init__(hs)
self.cas_server_url = hs.config.cas_server_url
def on_GET(self, request):
return (200, {"serverUrl": self.cas_server_url})
def _parse_json(request): def _parse_json(request):
try: try:
content = json.loads(request.content.read()) content = json.loads(request.content.read())
@ -188,4 +290,6 @@ def register_servlets(hs, http_server):
LoginRestServlet(hs).register(http_server) LoginRestServlet(hs).register(http_server)
if hs.config.saml2_enabled: if hs.config.saml2_enabled:
SAML2RestServlet(hs).register(http_server) SAML2RestServlet(hs).register(http_server)
if hs.config.cas_enabled:
CasRestServlet(hs).register(http_server)
# TODO PasswordResetRestServlet(hs).register(http_server) # TODO PasswordResetRestServlet(hs).register(http_server)