Merge branch 'release-v0.7.1' of github.com:matrix-org/synapse

This commit is contained in:
Erik Johnston 2015-02-19 10:38:48 +00:00
commit 8321e8a2e0
55 changed files with 2544 additions and 304 deletions

View File

@ -1,3 +1,14 @@
Changes in synapse v0.7.1 (2015-02-19)
======================================
* Add cache when fetching events from remote servers to stop repeatedly
fetching events with bad signatures.
* Respect the per remote server retry scheme when fetching both events and
server keys to reduce the number of times we send requests to dead servers.
* Inform remote servers when the local server fails to handle a received event.
* Turn off python bytecode generation due to problems experienced when
upgrading from previous versions.
Changes in synapse v0.7.0 (2015-02-12) Changes in synapse v0.7.0 (2015-02-12)
====================================== ======================================

View File

@ -21,6 +21,7 @@ import datetime
import argparse import argparse
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
from synapse.util.frozenutils import unfreeze
def make_graph(db_name, room_id, file_prefix, limit): def make_graph(db_name, room_id, file_prefix, limit):
@ -70,7 +71,7 @@ def make_graph(db_name, room_id, file_prefix, limit):
float(event.origin_server_ts) / 1000 float(event.origin_server_ts) / 1000
).strftime('%Y-%m-%d %H:%M:%S,%f') ).strftime('%Y-%m-%d %H:%M:%S,%f')
content = json.dumps(event.get_dict()["content"]) content = json.dumps(unfreeze(event.get_dict()["content"]))
label = ( label = (
"<" "<"

View File

@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server. """ This is a reference implementation of a Matrix home server.
""" """
__version__ = "0.7.0f" __version__ = "0.7.1"

View File

@ -306,6 +306,34 @@ class Auth(object):
# Can optionally look elsewhere in the request (e.g. headers) # Can optionally look elsewhere in the request (e.g. headers)
try: try:
access_token = request.args["access_token"][0] access_token = request.args["access_token"][0]
# Check for application service tokens with a user_id override
try:
app_service = yield self.store.get_app_service_by_token(
access_token
)
if not app_service:
raise KeyError
user_id = app_service.sender
if "user_id" in request.args:
user_id = request.args["user_id"][0]
if not app_service.is_interested_in_user(user_id):
raise AuthError(
403,
"Application service cannot masquerade as this user."
)
if not user_id:
raise KeyError
defer.returnValue(
(UserID.from_string(user_id), ClientInfo("", ""))
)
return
except KeyError:
pass # normal users won't have this query parameter set
user_info = yield self.get_user_by_token(access_token) user_info = yield self.get_user_by_token(access_token)
user = user_info["user"] user = user_info["user"]
device_id = user_info["device_id"] device_id = user_info["device_id"]
@ -344,8 +372,7 @@ class Auth(object):
try: try:
ret = yield self.store.get_user_by_token(token=token) ret = yield self.store.get_user_by_token(token=token)
if not ret: if not ret:
raise StoreError() raise StoreError(400, "Unknown token")
user_info = { user_info = {
"admin": bool(ret.get("admin", False)), "admin": bool(ret.get("admin", False)),
"device_id": ret.get("device_id"), "device_id": ret.get("device_id"),
@ -358,6 +385,18 @@ class Auth(object):
raise AuthError(403, "Unrecognised access token.", raise AuthError(403, "Unrecognised access token.",
errcode=Codes.UNKNOWN_TOKEN) errcode=Codes.UNKNOWN_TOKEN)
@defer.inlineCallbacks
def get_appservice_by_req(self, request):
try:
token = request.args["access_token"][0]
service = yield self.store.get_app_service_by_token(token)
if not service:
raise AuthError(403, "Unrecognised access token.",
errcode=Codes.UNKNOWN_TOKEN)
defer.returnValue(service)
except KeyError:
raise AuthError(403, "Missing access token.")
def is_server_admin(self, user): def is_server_admin(self, user):
return self.store.is_server_admin(user) return self.store.is_server_admin(user)

View File

@ -59,6 +59,7 @@ class LoginType(object):
EMAIL_URL = u"m.login.email.url" EMAIL_URL = u"m.login.email.url"
EMAIL_IDENTITY = u"m.login.email.identity" EMAIL_IDENTITY = u"m.login.email.identity"
RECAPTCHA = u"m.login.recaptcha" RECAPTCHA = u"m.login.recaptcha"
APPLICATION_SERVICE = u"m.login.application_service"
class EventTypes(object): class EventTypes(object):

View File

@ -36,7 +36,8 @@ class Codes(object):
CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED" CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED"
CAPTCHA_INVALID = "M_CAPTCHA_INVALID" CAPTCHA_INVALID = "M_CAPTCHA_INVALID"
MISSING_PARAM = "M_MISSING_PARAM", MISSING_PARAM = "M_MISSING_PARAM",
TOO_LARGE = "M_TOO_LARGE" TOO_LARGE = "M_TOO_LARGE",
EXCLUSIVE = "M_EXCLUSIVE"
class CodeMessageException(RuntimeError): class CodeMessageException(RuntimeError):

View File

@ -22,3 +22,4 @@ WEB_CLIENT_PREFIX = "/_matrix/client"
CONTENT_REPO_PREFIX = "/_matrix/content" CONTENT_REPO_PREFIX = "/_matrix/content"
SERVER_KEY_PREFIX = "/_matrix/key/v1" SERVER_KEY_PREFIX = "/_matrix/key/v1"
MEDIA_PREFIX = "/_matrix/media/v1" MEDIA_PREFIX = "/_matrix/media/v1"
APP_SERVICE_PREFIX = "/_matrix/appservice/v1"

View File

@ -14,6 +14,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import sys
sys.dont_write_bytecode = True
from synapse.storage import prepare_database, UpgradeDatabaseException from synapse.storage import prepare_database, UpgradeDatabaseException
from synapse.server import HomeServer from synapse.server import HomeServer
@ -26,13 +29,14 @@ from twisted.web.resource import Resource
from twisted.web.static import File from twisted.web.static import File
from twisted.web.server import Site from twisted.web.server import Site
from synapse.http.server import JsonResource, RootRedirect from synapse.http.server import JsonResource, RootRedirect
from synapse.rest.appservice.v1 import AppServiceRestResource
from synapse.rest.media.v0.content_repository import ContentRepoResource from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.media.v1.media_repository import MediaRepositoryResource from synapse.rest.media.v1.media_repository import MediaRepositoryResource
from synapse.http.server_key_resource import LocalKey from synapse.http.server_key_resource import LocalKey
from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.api.urls import ( from synapse.api.urls import (
CLIENT_PREFIX, FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX, CLIENT_PREFIX, FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX,
SERVER_KEY_PREFIX, MEDIA_PREFIX, CLIENT_V2_ALPHA_PREFIX, SERVER_KEY_PREFIX, MEDIA_PREFIX, CLIENT_V2_ALPHA_PREFIX, APP_SERVICE_PREFIX
) )
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory from synapse.crypto import context_factory
@ -48,7 +52,7 @@ import synapse
import logging import logging
import os import os
import re import re
import sys import subprocess
import sqlite3 import sqlite3
import syweb import syweb
@ -69,6 +73,9 @@ class SynapseHomeServer(HomeServer):
def build_resource_for_federation(self): def build_resource_for_federation(self):
return JsonResource(self) return JsonResource(self)
def build_resource_for_app_services(self):
return AppServiceRestResource(self)
def build_resource_for_web_client(self): def build_resource_for_web_client(self):
syweb_path = os.path.dirname(syweb.__file__) syweb_path = os.path.dirname(syweb.__file__)
webclient_path = os.path.join(syweb_path, "webclient") webclient_path = os.path.join(syweb_path, "webclient")
@ -90,7 +97,9 @@ class SynapseHomeServer(HomeServer):
"sqlite3", self.get_db_name(), "sqlite3", self.get_db_name(),
check_same_thread=False, check_same_thread=False,
cp_min=1, cp_min=1,
cp_max=1 cp_max=1,
cp_openfun=prepare_database, # Prepare the database for each conn
# so that :memory: sqlite works
) )
def create_resource_tree(self, web_client, redirect_root_to_web_client): def create_resource_tree(self, web_client, redirect_root_to_web_client):
@ -114,6 +123,7 @@ class SynapseHomeServer(HomeServer):
(CONTENT_REPO_PREFIX, self.get_resource_for_content_repo()), (CONTENT_REPO_PREFIX, self.get_resource_for_content_repo()),
(SERVER_KEY_PREFIX, self.get_resource_for_server_key()), (SERVER_KEY_PREFIX, self.get_resource_for_server_key()),
(MEDIA_PREFIX, self.get_resource_for_media_repository()), (MEDIA_PREFIX, self.get_resource_for_media_repository()),
(APP_SERVICE_PREFIX, self.get_resource_for_app_services()),
] ]
if web_client: if web_client:
logger.info("Adding the web client.") logger.info("Adding the web client.")
@ -199,6 +209,66 @@ class SynapseHomeServer(HomeServer):
logger.info("Synapse now listening on port %d", unsecure_port) logger.info("Synapse now listening on port %d", unsecure_port)
def get_version_string():
null = open(os.devnull, 'w')
cwd = os.path.dirname(os.path.abspath(__file__))
try:
git_branch = subprocess.check_output(
['git', 'rev-parse', '--abbrev-ref', 'HEAD'],
stderr=null,
cwd=cwd,
).strip()
git_branch = "b=" + git_branch
except subprocess.CalledProcessError:
git_branch = ""
try:
git_tag = subprocess.check_output(
['git', 'describe', '--exact-match'],
stderr=null,
cwd=cwd,
).strip()
git_tag = "t=" + git_tag
except subprocess.CalledProcessError:
git_tag = ""
try:
git_commit = subprocess.check_output(
['git', 'rev-parse', '--short', 'HEAD'],
stderr=null,
cwd=cwd,
).strip()
except subprocess.CalledProcessError:
git_commit = ""
try:
dirty_string = "-this_is_a_dirty_checkout"
is_dirty = subprocess.check_output(
['git', 'describe', '--dirty=' + dirty_string],
stderr=null,
cwd=cwd,
).strip().endswith(dirty_string)
git_dirty = "dirty" if is_dirty else ""
except subprocess.CalledProcessError:
git_dirty = ""
if git_branch or git_tag or git_commit or git_dirty:
git_version = ",".join(
s for s in
(git_branch, git_tag, git_commit, git_dirty,)
if s
)
return (
"Synapse/%s (%s)" % (
synapse.__version__, git_version,
)
).encode("ascii")
return ("Synapse/%s" % (synapse.__version__,)).encode("ascii")
def setup(): def setup():
config = HomeServerConfig.load_config( config = HomeServerConfig.load_config(
"Synapse Homeserver", "Synapse Homeserver",
@ -210,8 +280,10 @@ def setup():
check_requirements() check_requirements()
version_string = get_version_string()
logger.info("Server hostname: %s", config.server_name) logger.info("Server hostname: %s", config.server_name)
logger.info("Server version: %s", synapse.__version__) logger.info("Server version: %s", version_string)
if re.search(":[0-9]+$", config.server_name): if re.search(":[0-9]+$", config.server_name):
domain_with_port = config.server_name domain_with_port = config.server_name
@ -228,6 +300,7 @@ def setup():
tls_context_factory=tls_context_factory, tls_context_factory=tls_context_factory,
config=config, config=config,
content_addr=config.content_addr, content_addr=config.content_addr,
version_string=version_string,
) )
hs.create_resource_tree( hs.create_resource_tree(
@ -252,14 +325,6 @@ def setup():
logger.info("Database prepared in %s.", db_name) logger.info("Database prepared in %s.", db_name)
db_pool = hs.get_db_pool()
if db_name == ":memory:":
# Memory databases will need to be setup each time they are opened.
reactor.callWhenRunning(
db_pool.runWithConnection, prepare_database
)
if config.manhole: if config.manhole:
f = twisted.manhole.telnet.ShellFactory() f = twisted.manhole.telnet.ShellFactory()
f.username = "matrix" f.username = "matrix"
@ -270,12 +335,13 @@ def setup():
bind_port = config.bind_port bind_port = config.bind_port
if config.no_tls: if config.no_tls:
bind_port = None bind_port = None
hs.start_listening(bind_port, config.unsecure_port) hs.start_listening(bind_port, config.unsecure_port)
hs.get_pusherpool().start() hs.get_pusherpool().start()
hs.get_state_handler().start_caching() hs.get_state_handler().start_caching()
hs.get_datastore().start_profiling() hs.get_datastore().start_profiling()
hs.get_replication_layer().start_get_pdu_cache()
if config.daemonize: if config.daemonize:
print config.pid_file print config.pid_file

View File

@ -19,7 +19,7 @@ import os
import subprocess import subprocess
import signal import signal
SYNAPSE = ["python", "-m", "synapse.app.homeserver"] SYNAPSE = ["python", "-B", "-m", "synapse.app.homeserver"]
CONFIGFILE = "homeserver.yaml" CONFIGFILE = "homeserver.yaml"
PIDFILE = "homeserver.pid" PIDFILE = "homeserver.pid"

View File

@ -0,0 +1,147 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.api.constants import EventTypes
import logging
import re
logger = logging.getLogger(__name__)
class ApplicationService(object):
"""Defines an application service. This definition is mostly what is
provided to the /register AS API.
Provides methods to check if this service is "interested" in events.
"""
NS_USERS = "users"
NS_ALIASES = "aliases"
NS_ROOMS = "rooms"
# The ordering here is important as it is used to map database values (which
# are stored as ints representing the position in this list) to namespace
# values.
NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS]
def __init__(self, token, url=None, namespaces=None, hs_token=None,
sender=None, txn_id=None):
self.token = token
self.url = url
self.hs_token = hs_token
self.sender = sender
self.namespaces = self._check_namespaces(namespaces)
self.txn_id = txn_id
def _check_namespaces(self, namespaces):
# Sanity check that it is of the form:
# {
# users: ["regex",...],
# aliases: ["regex",...],
# rooms: ["regex",...],
# }
if not namespaces:
return None
for ns in ApplicationService.NS_LIST:
if type(namespaces[ns]) != list:
raise ValueError("Bad namespace value for '%s'", ns)
for regex in namespaces[ns]:
if not isinstance(regex, basestring):
raise ValueError("Expected string regex for ns '%s'", ns)
return namespaces
def _matches_regex(self, test_string, namespace_key):
if not isinstance(test_string, basestring):
logger.error(
"Expected a string to test regex against, but got %s",
test_string
)
return False
for regex in self.namespaces[namespace_key]:
if re.match(regex, test_string):
return True
return False
def _matches_user(self, event, member_list):
if (hasattr(event, "sender") and
self.is_interested_in_user(event.sender)):
return True
# also check m.room.member state key
if (hasattr(event, "type") and event.type == EventTypes.Member
and hasattr(event, "state_key")
and self.is_interested_in_user(event.state_key)):
return True
# check joined member events
for member in member_list:
if self.is_interested_in_user(member.state_key):
return True
return False
def _matches_room_id(self, event):
if hasattr(event, "room_id"):
return self.is_interested_in_room(event.room_id)
return False
def _matches_aliases(self, event, alias_list):
for alias in alias_list:
if self.is_interested_in_alias(alias):
return True
return False
def is_interested(self, event, restrict_to=None, aliases_for_event=None,
member_list=None):
"""Check if this service is interested in this event.
Args:
event(Event): The event to check.
restrict_to(str): The namespace to restrict regex tests to.
aliases_for_event(list): A list of all the known room aliases for
this event.
member_list(list): A list of all joined room members in this room.
Returns:
bool: True if this service would like to know about this event.
"""
if aliases_for_event is None:
aliases_for_event = []
if member_list is None:
member_list = []
if restrict_to and restrict_to not in ApplicationService.NS_LIST:
# this is a programming error, so fail early and raise a general
# exception
raise Exception("Unexpected restrict_to value: %s". restrict_to)
if not restrict_to:
return (self._matches_user(event, member_list)
or self._matches_aliases(event, aliases_for_event)
or self._matches_room_id(event))
elif restrict_to == ApplicationService.NS_ALIASES:
return self._matches_aliases(event, aliases_for_event)
elif restrict_to == ApplicationService.NS_ROOMS:
return self._matches_room_id(event)
elif restrict_to == ApplicationService.NS_USERS:
return self._matches_user(event, member_list)
def is_interested_in_user(self, user_id):
return self._matches_regex(user_id, ApplicationService.NS_USERS)
def is_interested_in_alias(self, alias):
return self._matches_regex(alias, ApplicationService.NS_ALIASES)
def is_interested_in_room(self, room_id):
return self._matches_regex(room_id, ApplicationService.NS_ROOMS)
def __str__(self):
return "ApplicationService: %s" % (self.__dict__,)

108
synapse/appservice/api.py Normal file
View File

@ -0,0 +1,108 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.errors import CodeMessageException
from synapse.http.client import SimpleHttpClient
from synapse.events.utils import serialize_event
import logging
import urllib
logger = logging.getLogger(__name__)
class ApplicationServiceApi(SimpleHttpClient):
"""This class manages HS -> AS communications, including querying and
pushing.
"""
def __init__(self, hs):
super(ApplicationServiceApi, self).__init__(hs)
self.clock = hs.get_clock()
@defer.inlineCallbacks
def query_user(self, service, user_id):
uri = service.url + ("/users/%s" % urllib.quote(user_id))
response = None
try:
response = yield self.get_json(uri, {
"access_token": service.hs_token
})
if response is not None: # just an empty json object
defer.returnValue(True)
except CodeMessageException as e:
if e.code == 404:
defer.returnValue(False)
return
logger.warning("query_user to %s received %s", uri, e.code)
except Exception as ex:
logger.warning("query_user to %s threw exception %s", uri, ex)
defer.returnValue(False)
@defer.inlineCallbacks
def query_alias(self, service, alias):
uri = service.url + ("/rooms/%s" % urllib.quote(alias))
response = None
try:
response = yield self.get_json(uri, {
"access_token": service.hs_token
})
if response is not None: # just an empty json object
defer.returnValue(True)
except CodeMessageException as e:
logger.warning("query_alias to %s received %s", uri, e.code)
if e.code == 404:
defer.returnValue(False)
return
except Exception as ex:
logger.warning("query_alias to %s threw exception %s", uri, ex)
defer.returnValue(False)
@defer.inlineCallbacks
def push_bulk(self, service, events):
events = self._serialize(events)
uri = service.url + ("/transactions/%s" %
urllib.quote(str(0))) # TODO txn_ids
response = None
try:
response = yield self.put_json(
uri=uri,
json_body={
"events": events
},
args={
"access_token": service.hs_token
})
if response: # just an empty json object
# TODO: Mark txn as sent successfully
defer.returnValue(True)
except CodeMessageException as e:
logger.warning("push_bulk to %s received %s", uri, e.code)
except Exception as ex:
logger.warning("push_bulk to %s threw exception %s", uri, ex)
defer.returnValue(False)
@defer.inlineCallbacks
def push(self, service, event):
response = yield self.push_bulk(service, [event])
defer.returnValue(response)
def _serialize(self, events):
time_now = self.clock.time_msec()
return [
serialize_event(e, time_now, as_client_event=True) for e in events
]

View File

@ -22,6 +22,8 @@ from syutil.crypto.signing_key import (
from syutil.base64util import decode_base64, encode_base64 from syutil.base64util import decode_base64, encode_base64
from synapse.api.errors import SynapseError, Codes from synapse.api.errors import SynapseError, Codes
from synapse.util.retryutils import get_retry_limiter
from OpenSSL import crypto from OpenSSL import crypto
import logging import logging
@ -87,8 +89,14 @@ class Keyring(object):
return return
# Try to fetch the key from the remote server. # Try to fetch the key from the remote server.
# TODO(markjh): Ratelimit requests to a given server.
limiter = yield get_retry_limiter(
server_name,
self.clock,
self.store,
)
with limiter:
(response, tls_certificate) = yield fetch_server_key( (response, tls_certificate) = yield fetch_server_key(
server_name, self.hs.tls_context_factory server_name, self.hs.tls_context_factory
) )

View File

@ -19,10 +19,13 @@ from twisted.internet import defer
from .federation_base import FederationBase from .federation_base import FederationBase
from .units import Edu from .units import Edu
from synapse.api.errors import CodeMessageException from synapse.api.errors import CodeMessageException, SynapseError
from synapse.util.expiringcache import ExpiringCache
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
import logging import logging
@ -30,6 +33,20 @@ logger = logging.getLogger(__name__)
class FederationClient(FederationBase): class FederationClient(FederationBase):
def __init__(self):
self._get_pdu_cache = None
def start_get_pdu_cache(self):
self._get_pdu_cache = ExpiringCache(
cache_name="get_pdu_cache",
clock=self._clock,
max_len=1000,
expiry_ms=120*1000,
reset_expiry_on_get=False,
)
self._get_pdu_cache.start()
@log_function @log_function
def send_pdu(self, pdu, destinations): def send_pdu(self, pdu, destinations):
"""Informs the replication layer about a new PDU generated within the """Informs the replication layer about a new PDU generated within the
@ -160,9 +177,21 @@ class FederationClient(FederationBase):
# TODO: Rate limit the number of times we try and get the same event. # TODO: Rate limit the number of times we try and get the same event.
if self._get_pdu_cache:
e = self._get_pdu_cache.get(event_id)
if e:
defer.returnValue(e)
pdu = None pdu = None
for destination in destinations: for destination in destinations:
try: try:
limiter = yield get_retry_limiter(
destination,
self._clock,
self.store,
)
with limiter:
transaction_data = yield self.transport_layer.get_event( transaction_data = yield self.transport_layer.get_event(
destination, event_id destination, event_id
) )
@ -181,8 +210,25 @@ class FederationClient(FederationBase):
pdu = yield self._check_sigs_and_hash(pdu) pdu = yield self._check_sigs_and_hash(pdu)
break break
except CodeMessageException:
except SynapseError:
logger.info(
"Failed to get PDU %s from %s because %s",
event_id, destination, e,
)
continue
except CodeMessageException as e:
if 400 <= e.code < 500:
raise raise
logger.info(
"Failed to get PDU %s from %s because %s",
event_id, destination, e,
)
continue
except NotRetryingDestination as e:
logger.info(e.message)
continue
except Exception as e: except Exception as e:
logger.info( logger.info(
"Failed to get PDU %s from %s because %s", "Failed to get PDU %s from %s because %s",
@ -190,6 +236,9 @@ class FederationClient(FederationBase):
) )
continue continue
if self._get_pdu_cache is not None:
self._get_pdu_cache[event_id] = pdu
defer.returnValue(pdu) defer.returnValue(pdu)
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -114,7 +114,15 @@ class FederationServer(FederationBase):
with PreserveLoggingContext(): with PreserveLoggingContext():
dl = [] dl = []
for pdu in pdu_list: for pdu in pdu_list:
dl.append(self._handle_new_pdu(transaction.origin, pdu)) d = self._handle_new_pdu(transaction.origin, pdu)
def handle_failure(failure):
failure.trap(FederationError)
self.send_failure(failure.value, transaction.origin)
d.addErrback(handle_failure)
dl.append(d)
if hasattr(transaction, "edus"): if hasattr(transaction, "edus"):
for edu in [Edu(**x) for x in transaction.edus]: for edu in [Edu(**x) for x in transaction.edus]:
@ -124,7 +132,10 @@ class FederationServer(FederationBase):
edu.content edu.content
) )
results = yield defer.DeferredList(dl) for failure in getattr(transaction, "pdu_failures", []):
logger.info("Got failure %r", failure)
results = yield defer.DeferredList(dl, consumeErrors=True)
ret = [] ret = []
for r in results: for r in results:
@ -132,10 +143,16 @@ class FederationServer(FederationBase):
ret.append({}) ret.append({})
else: else:
logger.exception(r[1]) logger.exception(r[1])
ret.append({"error": str(r[1])}) ret.append({"error": str(r[1].value)})
logger.debug("Returning: %s", str(ret)) logger.debug("Returning: %s", str(ret))
response = {
"pdus": dict(zip(
(p.event_id for p in pdu_list), ret
)),
}
yield self.transaction_actions.set_response( yield self.transaction_actions.set_response(
transaction, transaction,
200, response 200, response
@ -331,7 +348,6 @@ class FederationServer(FederationBase):
) )
if already_seen: if already_seen:
logger.debug("Already seen pdu %s", pdu.event_id) logger.debug("Already seen pdu %s", pdu.event_id)
defer.returnValue({})
return return
# Check signature. # Check signature.
@ -367,7 +383,13 @@ class FederationServer(FederationBase):
pdu.room_id, min_depth pdu.room_id, min_depth
) )
if min_depth and pdu.depth > min_depth and max_recursion > 0: if min_depth and pdu.depth < min_depth:
# This is so that we don't notify the user about this
# message, to work around the fact that some events will
# reference really really old events we really don't want to
# send to the clients.
pdu.internal_metadata.outlier = True
elif min_depth and pdu.depth > min_depth and max_recursion > 0:
for event_id, hashes in pdu.prev_events: for event_id, hashes in pdu.prev_events:
if event_id not in have_seen: if event_id not in have_seen:
logger.debug( logger.debug(
@ -418,7 +440,7 @@ class FederationServer(FederationBase):
except: except:
logger.warn("Failed to get state for event: %s", pdu.event_id) logger.warn("Failed to get state for event: %s", pdu.event_id)
ret = yield self.handler.on_receive_pdu( yield self.handler.on_receive_pdu(
origin, origin,
pdu, pdu,
backfilled=False, backfilled=False,
@ -426,8 +448,6 @@ class FederationServer(FederationBase):
auth_chain=auth_chain, auth_chain=auth_chain,
) )
defer.returnValue(ret)
def __str__(self): def __str__(self):
return "<ReplicationLayer(%s)>" % self.server_name return "<ReplicationLayer(%s)>" % self.server_name

View File

@ -22,6 +22,9 @@ from .units import Transaction
from synapse.api.errors import HttpResponseException from synapse.api.errors import HttpResponseException
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.retryutils import (
get_retry_limiter, NotRetryingDestination,
)
import logging import logging
@ -63,6 +66,26 @@ class TransactionQueue(object):
# HACK to get unique tx id # HACK to get unique tx id
self._next_txn_id = int(self._clock.time_msec()) self._next_txn_id = int(self._clock.time_msec())
def can_send_to(self, destination):
"""Can we send messages to the given server?
We can't send messages to ourselves. If we are running on localhost
then we can only federation with other servers running on localhost.
Otherwise we only federate with servers on a public domain.
Args:
destination(str): The server we are possibly trying to send to.
Returns:
bool: True if we can send to the server.
"""
if destination == self.server_name:
return False
if self.server_name.startswith("localhost"):
return destination.startswith("localhost")
else:
return not destination.startswith("localhost")
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def enqueue_pdu(self, pdu, destinations, order): def enqueue_pdu(self, pdu, destinations, order):
@ -71,8 +94,9 @@ class TransactionQueue(object):
# table and we'll get back to it later. # table and we'll get back to it later.
destinations = set(destinations) destinations = set(destinations)
destinations.discard(self.server_name) destinations = set(
destinations.discard("localhost") dest for dest in destinations if self.can_send_to(dest)
)
logger.debug("Sending to: %s", str(destinations)) logger.debug("Sending to: %s", str(destinations))
@ -87,24 +111,27 @@ class TransactionQueue(object):
(pdu, deferred, order) (pdu, deferred, order)
) )
def eb(failure): def chain(failure):
if not deferred.called: if not deferred.called:
deferred.errback(failure) deferred.errback(failure)
else:
logger.warn("Failed to send pdu", failure) def log_failure(failure):
logger.warn("Failed to send pdu", failure.value)
deferred.addErrback(log_failure)
with PreserveLoggingContext(): with PreserveLoggingContext():
self._attempt_new_transaction(destination).addErrback(eb) self._attempt_new_transaction(destination).addErrback(chain)
deferreds.append(deferred) deferreds.append(deferred)
yield defer.DeferredList(deferreds) yield defer.DeferredList(deferreds, consumeErrors=True)
# NO inlineCallbacks # NO inlineCallbacks
def enqueue_edu(self, edu): def enqueue_edu(self, edu):
destination = edu.destination destination = edu.destination
if destination == self.server_name: if not self.can_send_to(destination):
return return
deferred = defer.Deferred() deferred = defer.Deferred()
@ -112,51 +139,53 @@ class TransactionQueue(object):
(edu, deferred) (edu, deferred)
) )
def eb(failure): def chain(failure):
if not deferred.called: if not deferred.called:
deferred.errback(failure) deferred.errback(failure)
else:
logger.warn("Failed to send edu", failure) def log_failure(failure):
logger.warn("Failed to send pdu", failure.value)
deferred.addErrback(log_failure)
with PreserveLoggingContext(): with PreserveLoggingContext():
self._attempt_new_transaction(destination).addErrback(eb) self._attempt_new_transaction(destination).addErrback(chain)
return deferred return deferred
@defer.inlineCallbacks @defer.inlineCallbacks
def enqueue_failure(self, failure, destination): def enqueue_failure(self, failure, destination):
if destination == self.server_name or destination == "localhost":
return
deferred = defer.Deferred() deferred = defer.Deferred()
if not self.can_send_to(destination):
return
self.pending_failures_by_dest.setdefault( self.pending_failures_by_dest.setdefault(
destination, [] destination, []
).append( ).append(
(failure, deferred) (failure, deferred)
) )
def chain(f):
if not deferred.called:
deferred.errback(f)
def log_failure(f):
logger.warn("Failed to send pdu", f.value)
deferred.addErrback(log_failure)
with PreserveLoggingContext():
self._attempt_new_transaction(destination).addErrback(chain)
yield deferred yield deferred
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def _attempt_new_transaction(self, destination): def _attempt_new_transaction(self, destination):
(retry_last_ts, retry_interval) = (0, 0)
retry_timings = yield self.store.get_destination_retry_timings(
destination
)
if retry_timings:
(retry_last_ts, retry_interval) = (
retry_timings.retry_last_ts, retry_timings.retry_interval
)
if retry_last_ts + retry_interval > int(self._clock.time_msec()):
logger.info(
"TX [%s] not ready for retry yet - "
"dropping transaction for now",
destination,
)
return
else:
logger.info("TX [%s] is ready for retry", destination)
if destination in self.pending_transactions: if destination in self.pending_transactions:
# XXX: pending_transactions can get stuck on by a never-ending # XXX: pending_transactions can get stuck on by a never-ending
# request at which point pending_pdus_by_dest just keeps growing. # request at which point pending_pdus_by_dest just keeps growing.
@ -183,15 +212,6 @@ class TransactionQueue(object):
logger.info("TX [%s] Nothing to send", destination) logger.info("TX [%s] Nothing to send", destination)
return return
logger.debug(
"TX [%s] Attempting new transaction"
" (pdus: %d, edus: %d, failures: %d)",
destination,
len(pending_pdus),
len(pending_edus),
len(pending_failures)
)
# Sort based on the order field # Sort based on the order field
pending_pdus.sort(key=lambda t: t[2]) pending_pdus.sort(key=lambda t: t[2])
@ -204,6 +224,21 @@ class TransactionQueue(object):
] ]
try: try:
limiter = yield get_retry_limiter(
destination,
self._clock,
self.store,
)
logger.debug(
"TX [%s] Attempting new transaction"
" (pdus: %d, edus: %d, failures: %d)",
destination,
len(pending_pdus),
len(pending_edus),
len(pending_failures)
)
self.pending_transactions[destination] = 1 self.pending_transactions[destination] = 1
logger.debug("TX [%s] Persisting transaction...", destination) logger.debug("TX [%s] Persisting transaction...", destination)
@ -229,6 +264,7 @@ class TransactionQueue(object):
transaction.transaction_id, transaction.transaction_id,
) )
with limiter:
# Actually send the transaction # Actually send the transaction
# FIXME (erikj): This is a bit of a hack to make the Pdu age # FIXME (erikj): This is a bit of a hack to make the Pdu age
@ -249,6 +285,14 @@ class TransactionQueue(object):
transaction, json_data_cb transaction, json_data_cb
) )
code = 200 code = 200
if response:
for e_id, r in getattr(response, "pdus", {}).items():
if "error" in r:
logger.warn(
"Transaction returned error for %s: %s",
e_id, r,
)
except HttpResponseException as e: except HttpResponseException as e:
code = e.code code = e.code
response = e.response response = e.response
@ -263,18 +307,13 @@ class TransactionQueue(object):
) )
logger.debug("TX [%s] Marked as delivered", destination) logger.debug("TX [%s] Marked as delivered", destination)
logger.debug("TX [%s] Yielding to callbacks...", destination) logger.debug("TX [%s] Yielding to callbacks...", destination)
for deferred in deferreds: for deferred in deferreds:
if code == 200: if code == 200:
if retry_last_ts:
# this host is alive! reset retry schedule
yield self.store.set_destination_retry_timings(
destination, 0, 0
)
deferred.callback(None) deferred.callback(None)
else: else:
self.set_retrying(destination, retry_interval)
deferred.errback(RuntimeError("Got status %d" % code)) deferred.errback(RuntimeError("Got status %d" % code))
# Ensures we don't continue until all callbacks on that # Ensures we don't continue until all callbacks on that
@ -285,6 +324,12 @@ class TransactionQueue(object):
pass pass
logger.debug("TX [%s] Yielded to callbacks", destination) logger.debug("TX [%s] Yielded to callbacks", destination)
except NotRetryingDestination:
logger.info(
"TX [%s] not ready for retry yet - "
"dropping transaction for now",
destination,
)
except RuntimeError as e: except RuntimeError as e:
# We capture this here as there as nothing actually listens # We capture this here as there as nothing actually listens
# for this finishing functions deferred. # for this finishing functions deferred.
@ -296,14 +341,12 @@ class TransactionQueue(object):
except Exception as e: except Exception as e:
# We capture this here as there as nothing actually listens # We capture this here as there as nothing actually listens
# for this finishing functions deferred. # for this finishing functions deferred.
logger.exception( logger.warn(
"TX [%s] Problem in _attempt_transaction: %s", "TX [%s] Problem in _attempt_transaction: %s",
destination, destination,
e, e,
) )
self.set_retrying(destination, retry_interval)
for deferred in deferreds: for deferred in deferreds:
if not deferred.called: if not deferred.called:
deferred.errback(e) deferred.errback(e)
@ -314,22 +357,3 @@ class TransactionQueue(object):
# Check to see if there is anything else to send. # Check to see if there is anything else to send.
self._attempt_new_transaction(destination) self._attempt_new_transaction(destination)
@defer.inlineCallbacks
def set_retrying(self, destination, retry_interval):
# track that this destination is having problems and we should
# give it a chance to recover before trying it again
if retry_interval:
retry_interval *= 2
# plateau at hourly retries for now
if retry_interval >= 60 * 60 * 1000:
retry_interval = 60 * 60 * 1000
else:
retry_interval = 2000 # try again at first after 2 seconds
yield self.store.set_destination_retry_timings(
destination,
int(self._clock.time_msec()),
retry_interval
)

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.appservice.api import ApplicationServiceApi
from .register import RegistrationHandler from .register import RegistrationHandler
from .room import ( from .room import (
RoomCreationHandler, RoomMemberHandler, RoomListHandler RoomCreationHandler, RoomMemberHandler, RoomListHandler
@ -26,6 +27,7 @@ from .presence import PresenceHandler
from .directory import DirectoryHandler from .directory import DirectoryHandler
from .typing import TypingNotificationHandler from .typing import TypingNotificationHandler
from .admin import AdminHandler from .admin import AdminHandler
from .appservice import ApplicationServicesHandler
from .sync import SyncHandler from .sync import SyncHandler
@ -52,4 +54,7 @@ class Handlers(object):
self.directory_handler = DirectoryHandler(hs) self.directory_handler = DirectoryHandler(hs)
self.typing_notification_handler = TypingNotificationHandler(hs) self.typing_notification_handler = TypingNotificationHandler(hs)
self.admin_handler = AdminHandler(hs) self.admin_handler = AdminHandler(hs)
self.appservice_handler = ApplicationServicesHandler(
hs, ApplicationServiceApi(hs)
)
self.sync_handler = SyncHandler(hs) self.sync_handler = SyncHandler(hs)

View File

@ -0,0 +1,211 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.appservice import ApplicationService
from synapse.types import UserID
import synapse.util.stringutils as stringutils
import logging
logger = logging.getLogger(__name__)
# NB: Purposefully not inheriting BaseHandler since that contains way too much
# setup code which this handler does not need or use. This makes testing a lot
# easier.
class ApplicationServicesHandler(object):
def __init__(self, hs, appservice_api):
self.store = hs.get_datastore()
self.hs = hs
self.appservice_api = appservice_api
@defer.inlineCallbacks
def register(self, app_service):
logger.info("Register -> %s", app_service)
# check the token is recognised
try:
stored_service = yield self.store.get_app_service_by_token(
app_service.token
)
if not stored_service:
raise StoreError(404, "Application service not found")
except StoreError:
raise SynapseError(
403, "Unrecognised application services token. "
"Consult the home server admin.",
errcode=Codes.FORBIDDEN
)
app_service.hs_token = self._generate_hs_token()
# create a sender for this application service which is used when
# creating rooms, etc..
account = yield self.hs.get_handlers().registration_handler.register()
app_service.sender = account[0]
yield self.store.update_app_service(app_service)
defer.returnValue(app_service)
@defer.inlineCallbacks
def unregister(self, token):
logger.info("Unregister as_token=%s", token)
yield self.store.unregister_app_service(token)
@defer.inlineCallbacks
def notify_interested_services(self, event):
"""Notifies (pushes) all application services interested in this event.
Pushing is done asynchronously, so this method won't block for any
prolonged length of time.
Args:
event(Event): The event to push out to interested services.
"""
# Gather interested services
services = yield self._get_services_for_event(event)
if len(services) == 0:
return # no services need notifying
# Do we know this user exists? If not, poke the user query API for
# all services which match that user regex. This needs to block as these
# user queries need to be made BEFORE pushing the event.
yield self._check_user_exists(event.sender)
if event.type == EventTypes.Member:
yield self._check_user_exists(event.state_key)
# Fork off pushes to these services - XXX First cut, best effort
for service in services:
self.appservice_api.push(service, event)
@defer.inlineCallbacks
def query_user_exists(self, user_id):
"""Check if any application service knows this user_id exists.
Args:
user_id(str): The user to query if they exist on any AS.
Returns:
True if this user exists on at least one application service.
"""
user_query_services = yield self._get_services_for_user(
user_id=user_id
)
for user_service in user_query_services:
is_known_user = yield self.appservice_api.query_user(
user_service, user_id
)
if is_known_user:
defer.returnValue(True)
defer.returnValue(False)
@defer.inlineCallbacks
def query_room_alias_exists(self, room_alias):
"""Check if an application service knows this room alias exists.
Args:
room_alias(RoomAlias): The room alias to query.
Returns:
namedtuple: with keys "room_id" and "servers" or None if no
association can be found.
"""
room_alias_str = room_alias.to_string()
alias_query_services = yield self._get_services_for_event(
event=None,
restrict_to=ApplicationService.NS_ALIASES,
alias_list=[room_alias_str]
)
for alias_service in alias_query_services:
is_known_alias = yield self.appservice_api.query_alias(
alias_service, room_alias_str
)
if is_known_alias:
# the alias exists now so don't query more ASes.
result = yield self.store.get_association_from_room_alias(
room_alias
)
defer.returnValue(result)
@defer.inlineCallbacks
def _get_services_for_event(self, event, restrict_to="", alias_list=None):
"""Retrieve a list of application services interested in this event.
Args:
event(Event): The event to check. Can be None if alias_list is not.
restrict_to(str): The namespace to restrict regex tests to.
alias_list: A list of aliases to get services for. If None, this
list is obtained from the database.
Returns:
list<ApplicationService>: A list of services interested in this
event based on the service regex.
"""
member_list = None
if hasattr(event, "room_id"):
# We need to know the aliases associated with this event.room_id,
# if any.
if not alias_list:
alias_list = yield self.store.get_aliases_for_room(
event.room_id
)
# We need to know the members associated with this event.room_id,
# if any.
member_list = yield self.store.get_room_members(
room_id=event.room_id,
membership=Membership.JOIN
)
services = yield self.store.get_app_services()
interested_list = [
s for s in services if (
s.is_interested(event, restrict_to, alias_list, member_list)
)
]
defer.returnValue(interested_list)
@defer.inlineCallbacks
def _get_services_for_user(self, user_id):
services = yield self.store.get_app_services()
interested_list = [
s for s in services if (
s.is_interested_in_user(user_id)
)
]
defer.returnValue(interested_list)
@defer.inlineCallbacks
def _is_unknown_user(self, user_id):
user = UserID.from_string(user_id)
if not self.hs.is_mine(user):
# we don't know if they are unknown or not since it isn't one of our
# users. We can't poke ASes.
defer.returnValue(False)
return
user_info = yield self.store.get_user_by_id(user_id)
defer.returnValue(len(user_info) == 0)
@defer.inlineCallbacks
def _check_user_exists(self, user_id):
unknown_user = yield self._is_unknown_user(user_id)
if unknown_user:
exists = yield self.query_user_exists(user_id)
defer.returnValue(exists)
defer.returnValue(True)
def _generate_hs_token(self):
return stringutils.random_string(24)

View File

@ -37,18 +37,15 @@ class DirectoryHandler(BaseHandler):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def create_association(self, user_id, room_alias, room_id, servers=None): def _create_association(self, room_alias, room_id, servers=None):
# general association creation for both human users and app services
# TODO(erikj): Do auth.
if not self.hs.is_mine(room_alias): if not self.hs.is_mine(room_alias):
raise SynapseError(400, "Room alias must be local") raise SynapseError(400, "Room alias must be local")
# TODO(erikj): Change this. # TODO(erikj): Change this.
# TODO(erikj): Add transactions. # TODO(erikj): Add transactions.
# TODO(erikj): Check if there is a current association. # TODO(erikj): Check if there is a current association.
if not servers: if not servers:
servers = yield self.store.get_joined_hosts_for_room(room_id) servers = yield self.store.get_joined_hosts_for_room(room_id)
@ -61,23 +58,78 @@ class DirectoryHandler(BaseHandler):
servers servers
) )
@defer.inlineCallbacks
def create_association(self, user_id, room_alias, room_id, servers=None):
# association creation for human users
# TODO(erikj): Do user auth.
can_create = yield self.can_modify_alias(
room_alias,
user_id=user_id
)
if not can_create:
raise SynapseError(
400, "This alias is reserved by an application service.",
errcode=Codes.EXCLUSIVE
)
yield self._create_association(room_alias, room_id, servers)
@defer.inlineCallbacks
def create_appservice_association(self, service, room_alias, room_id,
servers=None):
if not service.is_interested_in_alias(room_alias.to_string()):
raise SynapseError(
400, "This application service has not reserved"
" this kind of alias.", errcode=Codes.EXCLUSIVE
)
# association creation for app services
yield self._create_association(room_alias, room_id, servers)
@defer.inlineCallbacks @defer.inlineCallbacks
def delete_association(self, user_id, room_alias): def delete_association(self, user_id, room_alias):
# association deletion for human users
# TODO Check if server admin # TODO Check if server admin
can_delete = yield self.can_modify_alias(
room_alias,
user_id=user_id
)
if not can_delete:
raise SynapseError(
400, "This alias is reserved by an application service.",
errcode=Codes.EXCLUSIVE
)
yield self._delete_association(room_alias)
@defer.inlineCallbacks
def delete_appservice_association(self, service, room_alias):
if not service.is_interested_in_alias(room_alias.to_string()):
raise SynapseError(
400,
"This application service has not reserved this kind of alias",
errcode=Codes.EXCLUSIVE
)
yield self._delete_association(room_alias)
@defer.inlineCallbacks
def _delete_association(self, room_alias):
if not self.hs.is_mine(room_alias): if not self.hs.is_mine(room_alias):
raise SynapseError(400, "Room alias must be local") raise SynapseError(400, "Room alias must be local")
room_id = yield self.store.delete_room_alias(room_alias) yield self.store.delete_room_alias(room_alias)
if room_id: # TODO - Looks like _update_room_alias_event has never been implemented
yield self._update_room_alias_events(user_id, room_id) # if room_id:
# yield self._update_room_alias_events(user_id, room_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_association(self, room_alias): def get_association(self, room_alias):
room_id = None room_id = None
if self.hs.is_mine(room_alias): if self.hs.is_mine(room_alias):
result = yield self.store.get_association_from_room_alias( result = yield self.get_association_from_room_alias(
room_alias room_alias
) )
@ -138,7 +190,7 @@ class DirectoryHandler(BaseHandler):
400, "Room Alias is not hosted on this Home Server" 400, "Room Alias is not hosted on this Home Server"
) )
result = yield self.store.get_association_from_room_alias( result = yield self.get_association_from_room_alias(
room_alias room_alias
) )
@ -166,3 +218,27 @@ class DirectoryHandler(BaseHandler):
"sender": user_id, "sender": user_id,
"content": {"aliases": aliases}, "content": {"aliases": aliases},
}, ratelimit=False) }, ratelimit=False)
@defer.inlineCallbacks
def get_association_from_room_alias(self, room_alias):
result = yield self.store.get_association_from_room_alias(
room_alias
)
if not result:
# Query AS to see if it exists
as_handler = self.hs.get_handlers().appservice_handler
result = yield as_handler.query_room_alias_exists(room_alias)
defer.returnValue(result)
@defer.inlineCallbacks
def can_modify_alias(self, alias, user_id=None):
services = yield self.store.get_app_services()
interested_services = [
s for s in services if s.is_interested_in_alias(alias.to_string())
]
for service in interested_services:
if user_id == service.sender:
# this user IS the app service
defer.returnValue(True)
return
defer.returnValue(len(interested_services) == 0)

View File

@ -802,7 +802,7 @@ class FederationHandler(BaseHandler):
missing_auth = event_auth_events - seen_events missing_auth = event_auth_events - seen_events
if missing_auth: if missing_auth:
logger.debug("Missing auth: %s", missing_auth) logger.info("Missing auth: %s", missing_auth)
# If we don't have all the auth events, we need to get them. # If we don't have all the auth events, we need to get them.
try: try:
remote_auth_chain = yield self.replication_layer.get_event_auth( remote_auth_chain = yield self.replication_layer.get_event_auth(
@ -856,7 +856,7 @@ class FederationHandler(BaseHandler):
if different_auth and not event.internal_metadata.is_outlier(): if different_auth and not event.internal_metadata.is_outlier():
# Do auth conflict res. # Do auth conflict res.
logger.debug("Different auth: %s", different_auth) logger.info("Different auth: %s", different_auth)
different_events = yield defer.gatherResults( different_events = yield defer.gatherResults(
[ [
@ -892,6 +892,8 @@ class FederationHandler(BaseHandler):
context.state_group = None context.state_group = None
if different_auth and not event.internal_metadata.is_outlier(): if different_auth and not event.internal_metadata.is_outlier():
logger.info("Different auth after resolution: %s", different_auth)
# Only do auth resolution if we have something new to say. # Only do auth resolution if we have something new to say.
# We can't rove an auth failure. # We can't rove an auth failure.
do_resolution = False do_resolution = False

View File

@ -16,12 +16,13 @@
from twisted.internet import defer from twisted.internet import defer
from ._base import BaseHandler from ._base import BaseHandler
from synapse.api.errors import LoginError, Codes from synapse.api.errors import LoginError, Codes, CodeMessageException
from synapse.http.client import SimpleHttpClient from synapse.http.client import SimpleHttpClient
from synapse.util.emailutils import EmailException from synapse.util.emailutils import EmailException
import synapse.util.emailutils as emailutils import synapse.util.emailutils as emailutils
import bcrypt import bcrypt
import json
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -96,8 +97,9 @@ class LoginHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _query_email(self, email): def _query_email(self, email):
httpCli = SimpleHttpClient(self.hs) http_client = SimpleHttpClient(self.hs)
data = yield httpCli.get_json( try:
data = yield http_client.get_json(
# TODO FIXME This should be configurable. # TODO FIXME This should be configurable.
# XXX: ID servers need to use HTTPS # XXX: ID servers need to use HTTPS
"http://%s%s" % ( "http://%s%s" % (
@ -109,3 +111,6 @@ class LoginHandler(BaseHandler):
} }
) )
defer.returnValue(data) defer.returnValue(data)
except CodeMessageException as e:
data = json.loads(e.msg)
defer.returnValue(data)

View File

@ -492,7 +492,7 @@ class PresenceHandler(BaseHandler):
user, domain, remoteusers user, domain, remoteusers
)) ))
yield defer.DeferredList(deferreds) yield defer.DeferredList(deferreds, consumeErrors=True)
def _start_polling_local(self, user, target_user): def _start_polling_local(self, user, target_user):
target_localpart = target_user.localpart target_localpart = target_user.localpart
@ -548,7 +548,7 @@ class PresenceHandler(BaseHandler):
self._stop_polling_remote(user, domain, remoteusers) self._stop_polling_remote(user, domain, remoteusers)
) )
return defer.DeferredList(deferreds) return defer.DeferredList(deferreds, consumeErrors=True)
def _stop_polling_local(self, user, target_user): def _stop_polling_local(self, user, target_user):
for localpart in self._local_pushmap.keys(): for localpart in self._local_pushmap.keys():
@ -729,7 +729,7 @@ class PresenceHandler(BaseHandler):
del self._remote_sendmap[user] del self._remote_sendmap[user]
with PreserveLoggingContext(): with PreserveLoggingContext():
yield defer.DeferredList(deferreds) yield defer.DeferredList(deferreds, consumeErrors=True)
@defer.inlineCallbacks @defer.inlineCallbacks
def push_update_to_local_and_remote(self, observed_user, statuscache, def push_update_to_local_and_remote(self, observed_user, statuscache,
@ -768,7 +768,7 @@ class PresenceHandler(BaseHandler):
) )
) )
yield defer.DeferredList(deferreds) yield defer.DeferredList(deferreds, consumeErrors=True)
defer.returnValue((localusers, remote_domains)) defer.returnValue((localusers, remote_domains))

View File

@ -18,7 +18,8 @@ from twisted.internet import defer
from synapse.types import UserID from synapse.types import UserID
from synapse.api.errors import ( from synapse.api.errors import (
SynapseError, RegistrationError, InvalidCaptchaError AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError,
CodeMessageException
) )
from ._base import BaseHandler from ._base import BaseHandler
import synapse.util.stringutils as stringutils import synapse.util.stringutils as stringutils
@ -28,6 +29,7 @@ from synapse.http.client import CaptchaServerHttpClient
import base64 import base64
import bcrypt import bcrypt
import json
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -64,6 +66,8 @@ class RegistrationHandler(BaseHandler):
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
yield self.check_user_id_is_valid(user_id)
token = self._generate_token(user_id) token = self._generate_token(user_id)
yield self.store.register( yield self.store.register(
user_id=user_id, user_id=user_id,
@ -82,6 +86,7 @@ class RegistrationHandler(BaseHandler):
localpart = self._generate_user_id() localpart = self._generate_user_id()
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
yield self.check_user_id_is_valid(user_id)
token = self._generate_token(user_id) token = self._generate_token(user_id)
yield self.store.register( yield self.store.register(
@ -121,6 +126,27 @@ class RegistrationHandler(BaseHandler):
defer.returnValue((user_id, token)) defer.returnValue((user_id, token))
@defer.inlineCallbacks
def appservice_register(self, user_localpart, as_token):
user = UserID(user_localpart, self.hs.hostname)
user_id = user.to_string()
service = yield self.store.get_app_service_by_token(as_token)
if not service:
raise AuthError(403, "Invalid application service token.")
if not service.is_interested_in_user(user_id):
raise SynapseError(
400, "Invalid user localpart for this application service.",
errcode=Codes.EXCLUSIVE
)
token = self._generate_token(user_id)
yield self.store.register(
user_id=user_id,
token=token,
password_hash=""
)
self.distributor.fire("registered_user", user)
defer.returnValue((user_id, token))
@defer.inlineCallbacks @defer.inlineCallbacks
def check_recaptcha(self, ip, private_key, challenge, response): def check_recaptcha(self, ip, private_key, challenge, response):
"""Checks a recaptcha is correct.""" """Checks a recaptcha is correct."""
@ -167,6 +193,20 @@ class RegistrationHandler(BaseHandler):
# XXX: This should be a deferred list, shouldn't it? # XXX: This should be a deferred list, shouldn't it?
yield self._bind_threepid(c, user_id) yield self._bind_threepid(c, user_id)
@defer.inlineCallbacks
def check_user_id_is_valid(self, user_id):
# valid user IDs must not clash with any user ID namespaces claimed by
# application services.
services = yield self.store.get_app_services()
interested_services = [
s for s in services if s.is_interested_in_user(user_id)
]
if len(interested_services) > 0:
raise SynapseError(
400, "This user ID is reserved by an application service.",
errcode=Codes.EXCLUSIVE
)
def _generate_token(self, user_id): def _generate_token(self, user_id):
# urlsafe variant uses _ and - so use . as the separator and replace # urlsafe variant uses _ and - so use . as the separator and replace
# all =s with .s so http clients don't quote =s when it is used as # all =s with .s so http clients don't quote =s when it is used as
@ -181,14 +221,17 @@ class RegistrationHandler(BaseHandler):
def _threepid_from_creds(self, creds): def _threepid_from_creds(self, creds):
# TODO: get this from the homeserver rather than creating a new one for # TODO: get this from the homeserver rather than creating a new one for
# each request # each request
httpCli = SimpleHttpClient(self.hs) http_client = SimpleHttpClient(self.hs)
# XXX: make this configurable! # XXX: make this configurable!
trustedIdServers = ['matrix.org:8090', 'matrix.org'] trustedIdServers = ['matrix.org:8090', 'matrix.org']
if not creds['idServer'] in trustedIdServers: if not creds['idServer'] in trustedIdServers:
logger.warn('%s is not a trusted ID server: rejecting 3pid ' + logger.warn('%s is not a trusted ID server: rejecting 3pid ' +
'credentials', creds['idServer']) 'credentials', creds['idServer'])
defer.returnValue(None) defer.returnValue(None)
data = yield httpCli.get_json(
data = {}
try:
data = yield http_client.get_json(
# XXX: This should be HTTPS # XXX: This should be HTTPS
"http://%s%s" % ( "http://%s%s" % (
creds['idServer'], creds['idServer'],
@ -196,6 +239,8 @@ class RegistrationHandler(BaseHandler):
), ),
{'sid': creds['sid'], 'clientSecret': creds['clientSecret']} {'sid': creds['sid'], 'clientSecret': creds['clientSecret']}
) )
except CodeMessageException as e:
data = json.loads(e.msg)
if 'medium' in data: if 'medium' in data:
defer.returnValue(data) defer.returnValue(data)
@ -205,8 +250,10 @@ class RegistrationHandler(BaseHandler):
def _bind_threepid(self, creds, mxid): def _bind_threepid(self, creds, mxid):
yield yield
logger.debug("binding threepid") logger.debug("binding threepid")
httpCli = SimpleHttpClient(self.hs) http_client = SimpleHttpClient(self.hs)
data = yield httpCli.post_urlencoded_get_json( data = None
try:
data = yield http_client.post_urlencoded_get_json(
# XXX: Change when ID servers are all HTTPS # XXX: Change when ID servers are all HTTPS
"http://%s%s" % ( "http://%s%s" % (
creds['idServer'], "/_matrix/identity/api/v1/3pid/bind" creds['idServer'], "/_matrix/identity/api/v1/3pid/bind"
@ -218,6 +265,8 @@ class RegistrationHandler(BaseHandler):
} }
) )
logger.debug("bound threepid") logger.debug("bound threepid")
except CodeMessageException as e:
data = json.loads(e.msg)
defer.returnValue(data) defer.returnValue(data)
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -181,7 +181,7 @@ class TypingNotificationHandler(BaseHandler):
}, },
)) ))
yield defer.DeferredList(deferreds, consumeErrors=False) yield defer.DeferredList(deferreds, consumeErrors=True)
@defer.inlineCallbacks @defer.inlineCallbacks
def _recv_edu(self, origin, content): def _recv_edu(self, origin, content):

View File

@ -13,8 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.api.errors import CodeMessageException
from synapse.http.agent_name import AGENT_NAME
from syutil.jsonutil import encode_canonical_json from syutil.jsonutil import encode_canonical_json
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
@ -44,6 +43,7 @@ class SimpleHttpClient(object):
# BrowserLikePolicyForHTTPS which will do regular cert validation # BrowserLikePolicyForHTTPS which will do regular cert validation
# 'like a browser' # 'like a browser'
self.agent = Agent(reactor) self.agent = Agent(reactor)
self.version_string = hs.version_string
@defer.inlineCallbacks @defer.inlineCallbacks
def post_urlencoded_get_json(self, uri, args={}): def post_urlencoded_get_json(self, uri, args={}):
@ -55,7 +55,7 @@ class SimpleHttpClient(object):
uri.encode("ascii"), uri.encode("ascii"),
headers=Headers({ headers=Headers({
b"Content-Type": [b"application/x-www-form-urlencoded"], b"Content-Type": [b"application/x-www-form-urlencoded"],
b"User-Agent": [AGENT_NAME], b"User-Agent": [self.version_string],
}), }),
bodyProducer=FileBodyProducer(StringIO(query_bytes)) bodyProducer=FileBodyProducer(StringIO(query_bytes))
) )
@ -85,7 +85,7 @@ class SimpleHttpClient(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_json(self, uri, args={}): def get_json(self, uri, args={}):
""" Get's some json from the given host and path """ Gets some json from the given URI.
Args: Args:
uri (str): The URI to request, not including query parameters uri (str): The URI to request, not including query parameters
@ -93,15 +93,13 @@ class SimpleHttpClient(object):
None. None.
**Note**: The value of each key is assumed to be an iterable **Note**: The value of each key is assumed to be an iterable
and *not* a string. and *not* a string.
Returns: Returns:
Deferred: Succeeds when we get *any* HTTP response. Deferred: Succeeds when we get *any* 2xx HTTP response, with the
HTTP body as JSON.
The result of the deferred is a tuple of `(code, response)`, Raises:
where `response` is a dict representing the decoded JSON body. On a non-2xx HTTP response. The response body will be used as the
error message.
""" """
yield
if len(args): if len(args):
query_bytes = urllib.urlencode(args, True) query_bytes = urllib.urlencode(args, True)
uri = "%s?%s" % (uri, query_bytes) uri = "%s?%s" % (uri, query_bytes)
@ -110,13 +108,62 @@ class SimpleHttpClient(object):
"GET", "GET",
uri.encode("ascii"), uri.encode("ascii"),
headers=Headers({ headers=Headers({
b"User-Agent": [AGENT_NAME], b"User-Agent": [self.version_string],
}) })
) )
body = yield readBody(response) body = yield readBody(response)
if 200 <= response.code < 300:
defer.returnValue(json.loads(body)) defer.returnValue(json.loads(body))
else:
# NB: This is explicitly not json.loads(body)'d because the contract
# of CodeMessageException is a *string* message. Callers can always
# load it into JSON if they want.
raise CodeMessageException(response.code, body)
@defer.inlineCallbacks
def put_json(self, uri, json_body, args={}):
""" Puts some json to the given URI.
Args:
uri (str): The URI to request, not including query parameters
json_body (dict): The JSON to put in the HTTP body,
args (dict): A dictionary used to create query strings, defaults to
None.
**Note**: The value of each key is assumed to be an iterable
and *not* a string.
Returns:
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
HTTP body as JSON.
Raises:
On a non-2xx HTTP response.
"""
if len(args):
query_bytes = urllib.urlencode(args, True)
uri = "%s?%s" % (uri, query_bytes)
json_str = json.dumps(json_body)
response = yield self.agent.request(
"PUT",
uri.encode("ascii"),
headers=Headers({
b"User-Agent": [self.version_string],
"Content-Type": ["application/json"]
}),
bodyProducer=FileBodyProducer(StringIO(json_str))
)
body = yield readBody(response)
if 200 <= response.code < 300:
defer.returnValue(json.loads(body))
else:
# NB: This is explicitly not json.loads(body)'d because the contract
# of CodeMessageException is a *string* message. Callers can always
# load it into JSON if they want.
raise CodeMessageException(response.code, body)
class CaptchaServerHttpClient(SimpleHttpClient): class CaptchaServerHttpClient(SimpleHttpClient):
@ -135,7 +182,7 @@ class CaptchaServerHttpClient(SimpleHttpClient):
bodyProducer=FileBodyProducer(StringIO(query_bytes)), bodyProducer=FileBodyProducer(StringIO(query_bytes)),
headers=Headers({ headers=Headers({
b"Content-Type": [b"application/x-www-form-urlencoded"], b"Content-Type": [b"application/x-www-form-urlencoded"],
b"User-Agent": [AGENT_NAME], b"User-Agent": [self.version_string],
}) })
) )

View File

@ -20,7 +20,6 @@ from twisted.web.client import readBody, _AgentBase, _URI
from twisted.web.http_headers import Headers from twisted.web.http_headers import Headers
from twisted.web._newclient import ResponseDone from twisted.web._newclient import ResponseDone
from synapse.http.agent_name import AGENT_NAME
from synapse.http.endpoint import matrix_federation_endpoint from synapse.http.endpoint import matrix_federation_endpoint
from synapse.util.async import sleep from synapse.util.async import sleep
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
@ -80,6 +79,7 @@ class MatrixFederationHttpClient(object):
self.server_name = hs.hostname self.server_name = hs.hostname
self.agent = MatrixFederationHttpAgent(reactor) self.agent = MatrixFederationHttpAgent(reactor)
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.version_string = hs.version_string
@defer.inlineCallbacks @defer.inlineCallbacks
def _create_request(self, destination, method, path_bytes, def _create_request(self, destination, method, path_bytes,
@ -87,7 +87,7 @@ class MatrixFederationHttpClient(object):
query_bytes=b"", retry_on_dns_fail=True): query_bytes=b"", retry_on_dns_fail=True):
""" Creates and sends a request to the given url """ Creates and sends a request to the given url
""" """
headers_dict[b"User-Agent"] = [AGENT_NAME] headers_dict[b"User-Agent"] = [self.version_string]
headers_dict[b"Host"] = [destination] headers_dict[b"Host"] = [destination]
url_bytes = urlparse.urlunparse( url_bytes = urlparse.urlunparse(
@ -144,16 +144,16 @@ class MatrixFederationHttpClient(object):
destination, destination,
e e
) )
raise SynapseError(400, "Domain specified not found.") raise
logger.warn( logger.warn(
"Sending request failed to %s: %s %s : %s", "Sending request failed to %s: %s %s: %s - %s",
destination, destination,
method, method,
url_bytes, url_bytes,
e type(e).__name__,
_flatten_response_never_received(e),
) )
_print_ex(e)
if retries_left: if retries_left:
yield sleep(2 ** (5 - retries_left)) yield sleep(2 ** (5 - retries_left))
@ -447,14 +447,6 @@ def _readBodyToFile(response, stream, max_size):
return d return d
def _print_ex(e):
if hasattr(e, "reasons") and e.reasons:
for ex in e.reasons:
_print_ex(ex)
else:
logger.warn(e)
class _JsonProducer(object): class _JsonProducer(object):
""" Used by the twisted http client to create the HTTP body from json """ Used by the twisted http client to create the HTTP body from json
""" """
@ -474,3 +466,13 @@ class _JsonProducer(object):
def stopProducing(self): def stopProducing(self):
pass pass
def _flatten_response_never_received(e):
if hasattr(e, "reasons"):
return ", ".join(
_flatten_response_never_received(f.value)
for f in e.reasons
)
else:
return "%s: %s" % (type(e).__name__, e.message,)

View File

@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
from synapse.http.agent_name import AGENT_NAME
from synapse.api.errors import ( from synapse.api.errors import (
cs_exception, SynapseError, CodeMessageException, UnrecognizedRequestError cs_exception, SynapseError, CodeMessageException, UnrecognizedRequestError
) )
@ -74,6 +73,7 @@ class JsonResource(HttpServer, resource.Resource):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.path_regexs = {} self.path_regexs = {}
self.version_string = hs.version_string
def register_path(self, method, path_pattern, callback): def register_path(self, method, path_pattern, callback):
self.path_regexs.setdefault(method, []).append( self.path_regexs.setdefault(method, []).append(
@ -189,9 +189,13 @@ class JsonResource(HttpServer, resource.Resource):
return return
# TODO: Only enable CORS for the requests that need it. # TODO: Only enable CORS for the requests that need it.
respond_with_json(request, code, response_json_object, send_cors=True, respond_with_json(
request, code, response_json_object,
send_cors=True,
response_code_message=response_code_message, response_code_message=response_code_message,
pretty_print=self._request_user_agent_is_curl) pretty_print=self._request_user_agent_is_curl,
version_string=self.version_string,
)
@staticmethod @staticmethod
def _request_user_agent_is_curl(request): def _request_user_agent_is_curl(request):
@ -221,18 +225,23 @@ class RootRedirect(resource.Resource):
def respond_with_json(request, code, json_object, send_cors=False, def respond_with_json(request, code, json_object, send_cors=False,
response_code_message=None, pretty_print=False): response_code_message=None, pretty_print=False,
version_string=""):
if not pretty_print: if not pretty_print:
json_bytes = encode_pretty_printed_json(json_object) json_bytes = encode_pretty_printed_json(json_object)
else: else:
json_bytes = encode_canonical_json(json_object) json_bytes = encode_canonical_json(json_object)
return respond_with_json_bytes(request, code, json_bytes, send_cors, return respond_with_json_bytes(
response_code_message=response_code_message) request, code, json_bytes,
send_cors=send_cors,
response_code_message=response_code_message,
version_string=version_string
)
def respond_with_json_bytes(request, code, json_bytes, send_cors=False, def respond_with_json_bytes(request, code, json_bytes, send_cors=False,
response_code_message=None): version_string="", response_code_message=None):
"""Sends encoded JSON in response to the given request. """Sends encoded JSON in response to the given request.
Args: Args:
@ -246,7 +255,7 @@ def respond_with_json_bytes(request, code, json_bytes, send_cors=False,
request.setResponseCode(code, message=response_code_message) request.setResponseCode(code, message=response_code_message)
request.setHeader(b"Content-Type", b"application/json") request.setHeader(b"Content-Type", b"application/json")
request.setHeader(b"Server", AGENT_NAME) request.setHeader(b"Server", version_string)
request.setHeader(b"Content-Length", b"%d" % (len(json_bytes),)) request.setHeader(b"Content-Length", b"%d" % (len(json_bytes),))
if send_cors: if send_cors:

View File

@ -50,6 +50,7 @@ class LocalKey(Resource):
def __init__(self, hs): def __init__(self, hs):
self.hs = hs self.hs = hs
self.version_string = hs.version_string
self.response_body = encode_canonical_json( self.response_body = encode_canonical_json(
self.response_json_object(hs.config) self.response_json_object(hs.config)
) )
@ -82,7 +83,10 @@ class LocalKey(Resource):
return json_object return json_object
def render_GET(self, request): def render_GET(self, request):
return respond_with_json_bytes(request, 200, self.response_body) return respond_with_json_bytes(
request, 200, self.response_body,
version_string=self.version_string
)
def getChild(self, name, request): def getChild(self, name, request):
if name == '': if name == '':

View File

@ -99,6 +99,12 @@ class Notifier(object):
`extra_users` param. `extra_users` param.
""" """
yield run_on_reactor() yield run_on_reactor()
# poke any interested application service.
self.hs.get_handlers().appservice_handler.notify_interested_services(
event
)
room_id = event.room_id room_id = event.room_id
room_source = self.event_sources.sources["room"] room_source = self.event_sources.sources["room"]
@ -135,7 +141,8 @@ class Notifier(object):
with PreserveLoggingContext(): with PreserveLoggingContext():
yield defer.DeferredList( yield defer.DeferredList(
[notify(l).addErrback(eb) for l in listeners] [notify(l).addErrback(eb) for l in listeners],
consumeErrors=True,
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -203,7 +210,8 @@ class Notifier(object):
with PreserveLoggingContext(): with PreserveLoggingContext():
yield defer.DeferredList( yield defer.DeferredList(
[notify(l).addErrback(eb) for l in listeners] [notify(l).addErrback(eb) for l in listeners],
consumeErrors=True,
) )
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -66,6 +66,7 @@ class HttpPusher(Pusher):
d = { d = {
'notification': { 'notification': {
'id': event['event_id'], 'id': event['event_id'],
'room_id': event['room_id'],
'type': event['type'], 'type': event['type'],
'sender': event['user_id'], 'sender': event['user_id'],
'counts': { # -- we don't mark messages as read yet so 'counts': { # -- we don't mark messages as read yet so

View File

@ -5,7 +5,7 @@ logger = logging.getLogger(__name__)
REQUIREMENTS = { REQUIREMENTS = {
"syutil>=0.0.3": ["syutil"], "syutil>=0.0.3": ["syutil"],
"matrix_angular_sdk>=0.6.2": ["syweb>=0.6.2"], "matrix_angular_sdk>=0.6.3": ["syweb>=0.6.3"],
"Twisted==14.0.2": ["twisted==14.0.2"], "Twisted==14.0.2": ["twisted==14.0.2"],
"service_identity>=1.0.0": ["service_identity>=1.0.0"], "service_identity>=1.0.0": ["service_identity>=1.0.0"],
"pyopenssl>=0.14": ["OpenSSL>=0.14"], "pyopenssl>=0.14": ["OpenSSL>=0.14"],
@ -24,6 +24,11 @@ def github_link(project, version, egg):
return "https://github.com/%s/tarball/%s/#egg=%s" % (project, version, egg) return "https://github.com/%s/tarball/%s/#egg=%s" % (project, version, egg)
DEPENDENCY_LINKS = [ DEPENDENCY_LINKS = [
github_link(
project="pyca/pynacl",
version="d4d3175589b892f6ea7c22f466e0e223853516fa",
egg="pynacl-0.3.0",
),
github_link( github_link(
project="matrix-org/syutil", project="matrix-org/syutil",
version="v0.0.3", version="v0.0.3",
@ -31,14 +36,9 @@ DEPENDENCY_LINKS = [
), ),
github_link( github_link(
project="matrix-org/matrix-angular-sdk", project="matrix-org/matrix-angular-sdk",
version="v0.6.2", version="v0.6.3",
egg="matrix_angular_sdk-0.6.2", egg="matrix_angular_sdk-0.6.3",
), ),
github_link(
project="pyca/pynacl",
version="d4d3175589b892f6ea7c22f466e0e223853516fa",
egg="pynacl-0.3.0",
)
] ]

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2015 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -12,7 +12,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse import __version__
AGENT_NAME = ("Synapse/%s" % (__version__,)).encode("ascii")

View File

@ -0,0 +1,29 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import register
from synapse.http.server import JsonResource
class AppServiceRestResource(JsonResource):
"""A resource for version 1 of the matrix application service API."""
def __init__(self, hs):
JsonResource.__init__(self, hs)
self.register_servlets(self, hs)
@staticmethod
def register_servlets(appservice_resource, hs):
register.register_servlets(hs, appservice_resource)

View File

@ -0,0 +1,48 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains base REST classes for constructing client v1 servlets.
"""
from synapse.http.servlet import RestServlet
from synapse.api.urls import APP_SERVICE_PREFIX
import re
import logging
logger = logging.getLogger(__name__)
def as_path_pattern(path_regex):
"""Creates a regex compiled appservice path with the correct path
prefix.
Args:
path_regex (str): The regex string to match. This should NOT have a ^
as this will be prefixed.
Returns:
SRE_Pattern
"""
return re.compile("^" + APP_SERVICE_PREFIX + path_regex)
class AppServiceRestServlet(RestServlet):
"""A base Synapse REST Servlet for the application services version 1 API.
"""
def __init__(self, hs):
self.hs = hs
self.handler = hs.get_handlers().appservice_handler

View File

@ -0,0 +1,121 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains REST servlets to do with registration: /register"""
from twisted.internet import defer
from base import AppServiceRestServlet, as_path_pattern
from synapse.api.errors import CodeMessageException, SynapseError
from synapse.storage.appservice import ApplicationService
import json
import logging
logger = logging.getLogger(__name__)
class RegisterRestServlet(AppServiceRestServlet):
"""Handles AS registration with the home server.
"""
PATTERN = as_path_pattern("/register$")
@defer.inlineCallbacks
def on_POST(self, request):
params = _parse_json(request)
# sanity check required params
try:
as_token = params["as_token"]
as_url = params["url"]
if (not isinstance(as_token, basestring) or
not isinstance(as_url, basestring)):
raise ValueError
except (KeyError, ValueError):
raise SynapseError(
400, "Missed required keys: as_token(str) / url(str)."
)
namespaces = {
"users": [],
"rooms": [],
"aliases": []
}
if "namespaces" in params:
self._parse_namespace(namespaces, params["namespaces"], "users")
self._parse_namespace(namespaces, params["namespaces"], "rooms")
self._parse_namespace(namespaces, params["namespaces"], "aliases")
app_service = ApplicationService(as_token, as_url, namespaces)
app_service = yield self.handler.register(app_service)
hs_token = app_service.hs_token
defer.returnValue((200, {
"hs_token": hs_token
}))
def _parse_namespace(self, target_ns, origin_ns, ns):
if ns not in target_ns or ns not in origin_ns:
return # nothing to parse / map through to.
possible_regex_list = origin_ns[ns]
if not type(possible_regex_list) == list:
raise SynapseError(400, "Namespace %s isn't an array." % ns)
for regex in possible_regex_list:
if not isinstance(regex, basestring):
raise SynapseError(
400, "Regex '%s' isn't a string in namespace %s" %
(regex, ns)
)
target_ns[ns] = origin_ns[ns]
class UnregisterRestServlet(AppServiceRestServlet):
"""Handles AS registration with the home server.
"""
PATTERN = as_path_pattern("/unregister$")
def on_POST(self, request):
params = _parse_json(request)
try:
as_token = params["as_token"]
if not isinstance(as_token, basestring):
raise ValueError
except (KeyError, ValueError):
raise SynapseError(400, "Missing required key: as_token(str)")
yield self.handler.unregister(as_token)
raise CodeMessageException(500, "Not implemented")
def _parse_json(request):
try:
content = json.loads(request.content.read())
if type(content) != dict:
raise SynapseError(400, "Content must be a JSON object.")
return content
except ValueError:
raise SynapseError(400, "Content not JSON.")
def register_servlets(hs, http_server):
RegisterRestServlet(hs).register(http_server)
UnregisterRestServlet(hs).register(http_server)

View File

@ -45,8 +45,6 @@ class ClientDirectoryServer(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, room_alias): def on_PUT(self, request, room_alias):
user, client = yield self.auth.get_user_by_req(request)
content = _parse_json(request) content = _parse_json(request)
if "room_id" not in content: if "room_id" not in content:
raise SynapseError(400, "Missing room_id key", raise SynapseError(400, "Missing room_id key",
@ -69,6 +67,9 @@ class ClientDirectoryServer(ClientV1RestServlet):
dir_handler = self.handlers.directory_handler dir_handler = self.handlers.directory_handler
try:
# try to auth as a user
user, client = yield self.auth.get_user_by_req(request)
try: try:
user_id = user.to_string() user_id = user.to_string()
yield dir_handler.create_association( yield dir_handler.create_association(
@ -80,24 +81,57 @@ class ClientDirectoryServer(ClientV1RestServlet):
except: except:
logger.exception("Failed to create association") logger.exception("Failed to create association")
raise raise
except AuthError:
# try to auth as an application service
service = yield self.auth.get_appservice_by_req(request)
yield dir_handler.create_appservice_association(
service, room_alias, room_id, servers
)
logger.info(
"Application service at %s created alias %s pointing to %s",
service.url,
room_alias.to_string(),
room_id
)
defer.returnValue((200, {})) defer.returnValue((200, {}))
@defer.inlineCallbacks @defer.inlineCallbacks
def on_DELETE(self, request, room_alias): def on_DELETE(self, request, room_alias):
dir_handler = self.handlers.directory_handler
try:
service = yield self.auth.get_appservice_by_req(request)
room_alias = RoomAlias.from_string(room_alias)
yield dir_handler.delete_appservice_association(
service, room_alias
)
logger.info(
"Application service at %s deleted alias %s",
service.url,
room_alias.to_string()
)
defer.returnValue((200, {}))
except AuthError:
# fallback to default user behaviour if they aren't an AS
pass
user, client = yield self.auth.get_user_by_req(request) user, client = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(user) is_admin = yield self.auth.is_server_admin(user)
if not is_admin: if not is_admin:
raise AuthError(403, "You need to be a server admin") raise AuthError(403, "You need to be a server admin")
dir_handler = self.handlers.directory_handler
room_alias = RoomAlias.from_string(room_alias) room_alias = RoomAlias.from_string(room_alias)
yield dir_handler.delete_association( yield dir_handler.delete_association(
user.to_string(), room_alias user.to_string(), room_alias
) )
logger.info(
"User %s deleted alias %s",
user.to_string(),
room_alias.to_string()
)
defer.returnValue((200, {})) defer.returnValue((200, {}))

View File

@ -110,7 +110,8 @@ class RegisterRestServlet(ClientV1RestServlet):
stages = { stages = {
LoginType.RECAPTCHA: self._do_recaptcha, LoginType.RECAPTCHA: self._do_recaptcha,
LoginType.PASSWORD: self._do_password, LoginType.PASSWORD: self._do_password,
LoginType.EMAIL_IDENTITY: self._do_email_identity LoginType.EMAIL_IDENTITY: self._do_email_identity,
LoginType.APPLICATION_SERVICE: self._do_app_service
} }
session_info = self._get_session_info(request, session) session_info = self._get_session_info(request, session)
@ -276,6 +277,27 @@ class RegisterRestServlet(ClientV1RestServlet):
self._remove_session(session) self._remove_session(session)
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks
def _do_app_service(self, request, register_json, session):
if "access_token" not in request.args:
raise SynapseError(400, "Expected application service token.")
if "user" not in register_json:
raise SynapseError(400, "Expected 'user' key.")
as_token = request.args["access_token"][0]
user_localpart = register_json["user"].encode("utf-8")
handler = self.handlers.registration_handler
(user_id, token) = yield handler.appservice_register(
user_localpart, as_token
)
self._remove_session(session)
defer.returnValue({
"user_id": user_id,
"access_token": token,
"home_server": self.hs.hostname,
})
def _parse_json(request): def _parse_json(request):
try: try:

View File

@ -54,7 +54,7 @@ class BaseMediaResource(Resource):
try: try:
yield request_handler(self, request) yield request_handler(self, request)
except CodeMessageException as e: except CodeMessageException as e:
logger.exception(e) logger.info("Responding with error: %r", e)
respond_with_json( respond_with_json(
request, e.code, cs_exception(e), send_cors=True request, e.code, cs_exception(e), send_cors=True
) )

View File

@ -77,6 +77,7 @@ class BaseHomeServer(object):
'resource_for_content_repo', 'resource_for_content_repo',
'resource_for_server_key', 'resource_for_server_key',
'resource_for_media_repository', 'resource_for_media_repository',
'resource_for_app_services',
'event_sources', 'event_sources',
'ratelimiter', 'ratelimiter',
'keyring', 'keyring',

View File

@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.util.expiringcache import ExpiringCache
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
@ -51,7 +52,6 @@ class _StateCacheEntry(object):
def __init__(self, state, state_group, ts): def __init__(self, state, state_group, ts):
self.state = state self.state = state
self.state_group = state_group self.state_group = state_group
self.ts = ts
class StateHandler(object): class StateHandler(object):
@ -69,12 +69,15 @@ class StateHandler(object):
def start_caching(self): def start_caching(self):
logger.debug("start_caching") logger.debug("start_caching")
self._state_cache = {} self._state_cache = ExpiringCache(
cache_name="state_cache",
clock=self.clock,
max_len=SIZE_OF_CACHE,
expiry_ms=EVICTION_TIMEOUT_SECONDS*1000,
reset_expiry_on_get=True,
)
def f(): self._state_cache.start()
self._prune_cache()
self.clock.looping_call(f, 5*1000)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_current_state(self, room_id, event_type=None, state_key=""): def get_current_state(self, room_id, event_type=None, state_key=""):
@ -409,34 +412,3 @@ class StateHandler(object):
return -int(e.depth), hashlib.sha1(e.event_id).hexdigest() return -int(e.depth), hashlib.sha1(e.event_id).hexdigest()
return sorted(events, key=key_func) return sorted(events, key=key_func)
def _prune_cache(self):
logger.debug(
"_prune_cache. before len: %d",
len(self._state_cache.keys())
)
now = self.clock.time_msec()
if len(self._state_cache.keys()) > SIZE_OF_CACHE:
sorted_entries = sorted(
self._state_cache.items(),
key=lambda k, v: v.ts,
)
for k, _ in sorted_entries[SIZE_OF_CACHE:]:
self._state_cache.pop(k)
keys_to_delete = set()
for key, cache_entry in self._state_cache.items():
if now - cache_entry.ts > EVICTION_TIMEOUT_SECONDS*1000:
keys_to_delete.add(key)
for k in keys_to_delete:
self._state_cache.pop(k)
logger.debug(
"_prune_cache. after len: %d",
len(self._state_cache.keys())
)

View File

@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from .appservice import ApplicationServiceStore
from .directory import DirectoryStore from .directory import DirectoryStore
from .feedback import FeedbackStore from .feedback import FeedbackStore
from .presence import PresenceStore from .presence import PresenceStore
@ -65,6 +66,7 @@ SCHEMAS = [
"event_signatures", "event_signatures",
"pusher", "pusher",
"media_repository", "media_repository",
"application_services",
"filtering", "filtering",
"rejections", "rejections",
] ]
@ -72,7 +74,9 @@ SCHEMAS = [
# Remember to update this number every time an incompatible change is made to # Remember to update this number every time an incompatible change is made to
# database schema files, so the users will be informed on server restarts. # database schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 12 SCHEMA_VERSION = 13
dir_path = os.path.abspath(os.path.dirname(__file__))
class _RollbackButIsFineException(Exception): class _RollbackButIsFineException(Exception):
@ -86,6 +90,7 @@ class DataStore(RoomMemberStore, RoomStore,
RegistrationStore, StreamStore, ProfileStore, FeedbackStore, RegistrationStore, StreamStore, ProfileStore, FeedbackStore,
PresenceStore, TransactionStore, PresenceStore, TransactionStore,
DirectoryStore, KeyStore, StateStore, SignatureStore, DirectoryStore, KeyStore, StateStore, SignatureStore,
ApplicationServiceStore,
EventFederationStore, EventFederationStore,
MediaRepositoryStore, MediaRepositoryStore,
RejectionsStore, RejectionsStore,
@ -580,7 +585,6 @@ def schema_path(schema):
A filesystem path pointing at a ".sql" file. A filesystem path pointing at a ".sql" file.
""" """
dir_path = os.path.dirname(__file__)
schemaPath = os.path.join(dir_path, "schema", schema + ".sql") schemaPath = os.path.join(dir_path, "schema", schema + ".sql")
return schemaPath return schemaPath
@ -637,10 +641,13 @@ def prepare_database(db_conn):
c.executescript(sql_script) c.executescript(sql_script)
db_conn.commit() db_conn.commit()
else:
logger.info("Database is at version %r", user_version)
else: else:
sql_script = "BEGIN TRANSACTION;\n" sql_script = "BEGIN TRANSACTION;\n"
for sql_loc in SCHEMAS: for sql_loc in SCHEMAS:
logger.debug("Applying schema %r", sql_loc)
sql_script += read_schema(sql_loc) sql_script += read_schema(sql_loc)
sql_script += "\n" sql_script += "\n"
sql_script += "COMMIT TRANSACTION;" sql_script += "COMMIT TRANSACTION;"

View File

@ -0,0 +1,244 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.appservice import ApplicationService
from ._base import SQLBaseStore
logger = logging.getLogger(__name__)
class ApplicationServiceCache(object):
"""Caches ApplicationServices and provides utility functions on top.
This class is designed to be invoked on incoming events in order to avoid
hammering the database every time to extract a list of application service
regexes.
"""
def __init__(self):
self.services = []
class ApplicationServiceStore(SQLBaseStore):
def __init__(self, hs):
super(ApplicationServiceStore, self).__init__(hs)
self.cache = ApplicationServiceCache()
self.cache_defer = self._populate_cache()
@defer.inlineCallbacks
def unregister_app_service(self, token):
"""Unregisters this service.
This removes all AS specific regex and the base URL. The token is the
only thing preserved for future registration attempts.
"""
yield self.cache_defer # make sure the cache is ready
yield self.runInteraction(
"unregister_app_service",
self._unregister_app_service_txn,
token,
)
# update cache TODO: Should this be in the txn?
for service in self.cache.services:
if service.token == token:
service.url = None
service.namespaces = None
service.hs_token = None
def _unregister_app_service_txn(self, txn, token):
# kill the url to prevent pushes
txn.execute(
"UPDATE application_services SET url=NULL WHERE token=?",
(token,)
)
# cleanup regex
as_id = self._get_as_id_txn(txn, token)
if not as_id:
logger.warning(
"unregister_app_service_txn: Failed to find as_id for token=",
token
)
return False
txn.execute(
"DELETE FROM application_services_regex WHERE as_id=?",
(as_id,)
)
return True
@defer.inlineCallbacks
def update_app_service(self, service):
"""Update an application service, clobbering what was previously there.
Args:
service(ApplicationService): The updated service.
"""
yield self.cache_defer # make sure the cache is ready
# NB: There is no "insert" since we provide no public-facing API to
# allocate new ASes. It relies on the server admin inserting the AS
# token into the database manually.
if not service.token or not service.url:
raise StoreError(400, "Token and url must be specified.")
if not service.hs_token:
raise StoreError(500, "No HS token")
yield self.runInteraction(
"update_app_service",
self._update_app_service_txn,
service
)
# update cache TODO: Should this be in the txn?
for (index, cache_service) in enumerate(self.cache.services):
if service.token == cache_service.token:
self.cache.services[index] = service
logger.info("Updated: %s", service)
return
# new entry
self.cache.services.append(service)
logger.info("Updated(new): %s", service)
def _update_app_service_txn(self, txn, service):
as_id = self._get_as_id_txn(txn, service.token)
if not as_id:
logger.warning(
"update_app_service_txn: Failed to find as_id for token=",
service.token
)
return False
txn.execute(
"UPDATE application_services SET url=?, hs_token=?, sender=? "
"WHERE id=?",
(service.url, service.hs_token, service.sender, as_id,)
)
# cleanup regex
txn.execute(
"DELETE FROM application_services_regex WHERE as_id=?",
(as_id,)
)
for (ns_int, ns_str) in enumerate(ApplicationService.NS_LIST):
if ns_str in service.namespaces:
for regex in service.namespaces[ns_str]:
txn.execute(
"INSERT INTO application_services_regex("
"as_id, namespace, regex) values(?,?,?)",
(as_id, ns_int, regex)
)
return True
def _get_as_id_txn(self, txn, token):
cursor = txn.execute(
"SELECT id FROM application_services WHERE token=?",
(token,)
)
res = cursor.fetchone()
if res:
return res[0]
@defer.inlineCallbacks
def get_app_services(self):
yield self.cache_defer # make sure the cache is ready
defer.returnValue(self.cache.services)
@defer.inlineCallbacks
def get_app_service_by_token(self, token, from_cache=True):
"""Get the application service with the given token.
Args:
token (str): The application service token.
from_cache (bool): True to get this service from the cache, False to
check the database.
Raises:
StoreError if there was a problem retrieving this service.
"""
yield self.cache_defer # make sure the cache is ready
if from_cache:
for service in self.cache.services:
if service.token == token:
defer.returnValue(service)
return
defer.returnValue(None)
# TODO: The from_cache=False impl
# TODO: This should be JOINed with the application_services_regex table.
@defer.inlineCallbacks
def _populate_cache(self):
"""Populates the ApplicationServiceCache from the database."""
sql = ("SELECT * FROM application_services LEFT JOIN "
"application_services_regex ON application_services.id = "
"application_services_regex.as_id")
# SQL results in the form:
# [
# {
# 'regex': "something",
# 'url': "something",
# 'namespace': enum,
# 'as_id': 0,
# 'token': "something",
# 'hs_token': "otherthing",
# 'id': 0
# }
# ]
services = {}
results = yield self._execute_and_decode(sql)
for res in results:
as_token = res["token"]
if as_token not in services:
# add the service
services[as_token] = {
"url": res["url"],
"token": as_token,
"hs_token": res["hs_token"],
"sender": res["sender"],
"namespaces": {
ApplicationService.NS_USERS: [],
ApplicationService.NS_ALIASES: [],
ApplicationService.NS_ROOMS: []
}
}
# add the namespace regex if one exists
ns_int = res["namespace"]
if ns_int is None:
continue
try:
services[as_token]["namespaces"][
ApplicationService.NS_LIST[ns_int]].append(
res["regex"]
)
except IndexError:
logger.error("Bad namespace enum '%s'. %s", ns_int, res)
# TODO get last successful txn id f.e. service
for service in services.values():
logger.info("Found application service: %s", service)
self.cache.services.append(ApplicationService(
token=service["token"],
url=service["url"],
namespaces=service["namespaces"],
hs_token=service["hs_token"],
sender=service["sender"]
))

View File

@ -288,7 +288,7 @@ class RoomMemberStore(SQLBaseStore):
deferreds = [self.get_rooms_for_user(u) for u in user_id_list] deferreds = [self.get_rooms_for_user(u) for u in user_id_list]
results = yield defer.DeferredList(deferreds) results = yield defer.DeferredList(deferreds, consumeErrors=True)
# A list of sets of strings giving room IDs for each user # A list of sets of strings giving room IDs for each user
room_id_lists = [set([r.room_id for r in result[1]]) for result in results] room_id_lists = [set([r.room_id for r in result[1]]) for result in results]

View File

@ -0,0 +1,34 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS application_services(
id INTEGER PRIMARY KEY AUTOINCREMENT,
url TEXT,
token TEXT,
hs_token TEXT,
sender TEXT,
UNIQUE(token) ON CONFLICT ROLLBACK
);
CREATE TABLE IF NOT EXISTS application_services_regex(
id INTEGER PRIMARY KEY AUTOINCREMENT,
as_id INTEGER NOT NULL,
namespace INTEGER, /* enum[room_id|room_alias|user_id] */
regex TEXT,
FOREIGN KEY(as_id) REFERENCES application_services(id)
);

View File

@ -0,0 +1,34 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS application_services(
id INTEGER PRIMARY KEY AUTOINCREMENT,
url TEXT,
token TEXT,
hs_token TEXT,
sender TEXT,
UNIQUE(token) ON CONFLICT ROLLBACK
);
CREATE TABLE IF NOT EXISTS application_services_regex(
id INTEGER PRIMARY KEY AUTOINCREMENT,
as_id INTEGER NOT NULL,
namespace INTEGER, /* enum[room_id|room_alias|user_id] */
regex TEXT,
FOREIGN KEY(as_id) REFERENCES application_services(id)
);

View File

@ -99,8 +99,6 @@ class Clock(object):
except: except:
pass pass
return res
given_deferred.addCallbacks(callback=sucess, errback=err) given_deferred.addCallbacks(callback=sucess, errback=err)
timer = self.call_later(time_out, timed_out_fn) timer = self.call_later(time_out, timed_out_fn)

View File

@ -0,0 +1,115 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
logger = logging.getLogger(__name__)
class ExpiringCache(object):
def __init__(self, cache_name, clock, max_len=0, expiry_ms=0,
reset_expiry_on_get=False):
"""
Args:
cache_name (str): Name of this cache, used for logging.
clock (Clock)
max_len (int): Max size of dict. If the dict grows larger than this
then the oldest items get automatically evicted. Default is 0,
which indicates there is no max limit.
expiry_ms (int): How long before an item is evicted from the cache
in milliseconds. Default is 0, indicating items never get
evicted based on time.
reset_expiry_on_get (bool): If true, will reset the expiry time for
an item on access. Defaults to False.
"""
self._cache_name = cache_name
self._clock = clock
self._max_len = max_len
self._expiry_ms = expiry_ms
self._reset_expiry_on_get = reset_expiry_on_get
self._cache = {}
def start(self):
if not self._expiry_ms:
# Don't bother starting the loop if things never expire
return
def f():
self._prune_cache()
self._clock.looping_call(f, self._expiry_ms/2)
def __setitem__(self, key, value):
now = self._clock.time_msec()
self._cache[key] = _CacheEntry(now, value)
# Evict if there are now too many items
if self._max_len and len(self._cache.keys()) > self._max_len:
sorted_entries = sorted(
self._cache.items(),
key=lambda k, v: v.time,
)
for k, _ in sorted_entries[self._max_len:]:
self._cache.pop(k)
def __getitem__(self, key):
entry = self._cache[key]
if self._reset_expiry_on_get:
entry.time = self._clock.time_msec()
return entry.value
def get(self, key, default=None):
try:
return self[key]
except KeyError:
return default
def _prune_cache(self):
if not self._expiry_ms:
# zero expiry time means don't expire. This should never get called
# since we have this check in start too.
return
begin_length = len(self._cache)
now = self._clock.time_msec()
keys_to_delete = set()
for key, cache_entry in self._cache.items():
if now - cache_entry.time > self._expiry_ms:
keys_to_delete.add(key)
for k in keys_to_delete:
self._cache.pop(k)
logger.debug(
"[%s] _prune_cache before: %d, after len: %d",
self._cache_name, begin_length, len(self._cache.keys())
)
class _CacheEntry(object):
def __init__(self, time, value):
self.time = time
self.value = value

153
synapse/util/retryutils.py Normal file
View File

@ -0,0 +1,153 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.errors import CodeMessageException
import logging
logger = logging.getLogger(__name__)
class NotRetryingDestination(Exception):
def __init__(self, retry_last_ts, retry_interval, destination):
msg = "Not retrying server %s." % (destination,)
super(NotRetryingDestination, self).__init__(msg)
self.retry_last_ts = retry_last_ts
self.retry_interval = retry_interval
self.destination = destination
@defer.inlineCallbacks
def get_retry_limiter(destination, clock, store, **kwargs):
"""For a given destination check if we have previously failed to
send a request there and are waiting before retrying the destination.
If we are not ready to retry the destination, this will raise a
NotRetryingDestination exception. Otherwise, will return a Context Manager
that will mark the destination as down if an exception is thrown (excluding
CodeMessageException with code < 500)
Example usage:
try:
limiter = yield get_retry_limiter(destination, clock, store)
with limiter:
response = yield do_request()
except NotRetryingDestination:
# We aren't ready to retry that destination.
raise
"""
retry_last_ts, retry_interval = (0, 0)
retry_timings = yield store.get_destination_retry_timings(
destination
)
if retry_timings:
retry_last_ts, retry_interval = (
retry_timings.retry_last_ts, retry_timings.retry_interval
)
now = int(clock.time_msec())
if retry_last_ts + retry_interval > now:
raise NotRetryingDestination(
retry_last_ts=retry_last_ts,
retry_interval=retry_interval,
destination=destination,
)
defer.returnValue(
RetryDestinationLimiter(
destination,
clock,
store,
retry_interval,
**kwargs
)
)
class RetryDestinationLimiter(object):
def __init__(self, destination, clock, store, retry_interval,
min_retry_interval=5000, max_retry_interval=60 * 60 * 1000,
multiplier_retry_interval=2,):
"""Marks the destination as "down" if an exception is thrown in the
context, except for CodeMessageException with code < 500.
If no exception is raised, marks the destination as "up".
Args:
destination (str)
clock (Clock)
store (DataStore)
retry_interval (int): The next retry interval taken from the
database in milliseconds, or zero if the last request was
successful.
min_retry_interval (int): The minimum retry interval to use after
a failed request, in milliseconds.
max_retry_interval (int): The maximum retry interval to use after
a failed request, in milliseconds.
multiplier_retry_interval (int): The multiplier to use to increase
the retry interval after a failed request.
"""
self.clock = clock
self.store = store
self.destination = destination
self.retry_interval = retry_interval
self.min_retry_interval = min_retry_interval
self.max_retry_interval = max_retry_interval
self.multiplier_retry_interval = multiplier_retry_interval
def __enter__(self):
pass
def __exit__(self, exc_type, exc_val, exc_tb):
def err(failure):
logger.exception(
"Failed to store set_destination_retry_timings",
failure.value
)
valid_err_code = False
if exc_type is CodeMessageException:
valid_err_code = 0 <= exc_val.code < 500
if exc_type is None or valid_err_code:
# We connected successfully.
if not self.retry_interval:
return
retry_last_ts = 0
self.retry_interval = 0
else:
# We couldn't connect.
if self.retry_interval:
self.retry_interval *= self.multiplier_retry_interval
if self.retry_interval >= self.max_retry_interval:
self.retry_interval = self.max_retry_interval
else:
self.retry_interval = self.min_retry_interval
retry_last_ts = int(self.clock.time_msec())
self.store.set_destination_retry_timings(
self.destination, retry_last_ts, self.retry_interval
).addErrback(err)

139
tests/api/test_auth.py Normal file
View File

@ -0,0 +1,139 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tests import unittest
from twisted.internet import defer
from mock import Mock
from synapse.api.auth import Auth
from synapse.api.errors import AuthError
class AuthTestCase(unittest.TestCase):
def setUp(self):
self.state_handler = Mock()
self.store = Mock()
self.hs = Mock()
self.hs.get_datastore = Mock(return_value=self.store)
self.hs.get_state_handler = Mock(return_value=self.state_handler)
self.auth = Auth(self.hs)
self.test_user = "@foo:bar"
self.test_token = "_test_token_"
@defer.inlineCallbacks
def test_get_user_by_req_user_valid_token(self):
self.store.get_app_service_by_token = Mock(return_value=None)
user_info = {
"name": self.test_user,
"device_id": "nothing",
"token_id": "ditto",
"admin": False
}
self.store.get_user_by_token = Mock(return_value=user_info)
request = Mock(args={})
request.args["access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
(user, info) = yield self.auth.get_user_by_req(request)
self.assertEquals(user.to_string(), self.test_user)
def test_get_user_by_req_user_bad_token(self):
self.store.get_app_service_by_token = Mock(return_value=None)
self.store.get_user_by_token = Mock(return_value=None)
request = Mock(args={})
request.args["access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError)
def test_get_user_by_req_user_missing_token(self):
self.store.get_app_service_by_token = Mock(return_value=None)
user_info = {
"name": self.test_user,
"device_id": "nothing",
"token_id": "ditto",
"admin": False
}
self.store.get_user_by_token = Mock(return_value=user_info)
request = Mock(args={})
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError)
@defer.inlineCallbacks
def test_get_user_by_req_appservice_valid_token(self):
app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_token = Mock(return_value=None)
request = Mock(args={})
request.args["access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
(user, info) = yield self.auth.get_user_by_req(request)
self.assertEquals(user.to_string(), self.test_user)
def test_get_user_by_req_appservice_bad_token(self):
self.store.get_app_service_by_token = Mock(return_value=None)
self.store.get_user_by_token = Mock(return_value=None)
request = Mock(args={})
request.args["access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError)
def test_get_user_by_req_appservice_missing_token(self):
app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_token = Mock(return_value=None)
request = Mock(args={})
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError)
@defer.inlineCallbacks
def test_get_user_by_req_appservice_valid_token_valid_user_id(self):
masquerading_user_id = "@doppelganger:matrix.org"
app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
app_service.is_interested_in_user = Mock(return_value=True)
self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_token = Mock(return_value=None)
request = Mock(args={})
request.args["access_token"] = [self.test_token]
request.args["user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
(user, info) = yield self.auth.get_user_by_req(request)
self.assertEquals(user.to_string(), masquerading_user_id)
def test_get_user_by_req_appservice_valid_token_bad_user_id(self):
masquerading_user_id = "@doppelganger:matrix.org"
app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
app_service.is_interested_in_user = Mock(return_value=False)
self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_token = Mock(return_value=None)
request = Mock(args={})
request.args["access_token"] = [self.test_token]
request.args["user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError)

View File

@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,170 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.appservice import ApplicationService
from mock import Mock, PropertyMock
from tests import unittest
class ApplicationServiceTestCase(unittest.TestCase):
def setUp(self):
self.service = ApplicationService(
url="some_url",
token="some_token",
namespaces={
ApplicationService.NS_USERS: [],
ApplicationService.NS_ROOMS: [],
ApplicationService.NS_ALIASES: []
}
)
self.event = Mock(
type="m.something", room_id="!foo:bar", sender="@someone:somewhere"
)
def test_regex_user_id_prefix_match(self):
self.service.namespaces[ApplicationService.NS_USERS].append(
"@irc_.*"
)
self.event.sender = "@irc_foobar:matrix.org"
self.assertTrue(self.service.is_interested(self.event))
def test_regex_user_id_prefix_no_match(self):
self.service.namespaces[ApplicationService.NS_USERS].append(
"@irc_.*"
)
self.event.sender = "@someone_else:matrix.org"
self.assertFalse(self.service.is_interested(self.event))
def test_regex_room_member_is_checked(self):
self.service.namespaces[ApplicationService.NS_USERS].append(
"@irc_.*"
)
self.event.sender = "@someone_else:matrix.org"
self.event.type = "m.room.member"
self.event.state_key = "@irc_foobar:matrix.org"
self.assertTrue(self.service.is_interested(self.event))
def test_regex_room_id_match(self):
self.service.namespaces[ApplicationService.NS_ROOMS].append(
"!some_prefix.*some_suffix:matrix.org"
)
self.event.room_id = "!some_prefixs0m3th1nGsome_suffix:matrix.org"
self.assertTrue(self.service.is_interested(self.event))
def test_regex_room_id_no_match(self):
self.service.namespaces[ApplicationService.NS_ROOMS].append(
"!some_prefix.*some_suffix:matrix.org"
)
self.event.room_id = "!XqBunHwQIXUiqCaoxq:matrix.org"
self.assertFalse(self.service.is_interested(self.event))
def test_regex_alias_match(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append(
"#irc_.*:matrix.org"
)
self.assertTrue(self.service.is_interested(
self.event,
aliases_for_event=["#irc_foobar:matrix.org", "#athing:matrix.org"]
))
def test_regex_alias_no_match(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append(
"#irc_.*:matrix.org"
)
self.assertFalse(self.service.is_interested(
self.event,
aliases_for_event=["#xmpp_foobar:matrix.org", "#athing:matrix.org"]
))
def test_regex_multiple_matches(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append(
"#irc_.*:matrix.org"
)
self.service.namespaces[ApplicationService.NS_USERS].append(
"@irc_.*"
)
self.event.sender = "@irc_foobar:matrix.org"
self.assertTrue(self.service.is_interested(
self.event,
aliases_for_event=["#irc_barfoo:matrix.org"]
))
def test_restrict_to_rooms(self):
self.service.namespaces[ApplicationService.NS_ROOMS].append(
"!flibble_.*:matrix.org"
)
self.service.namespaces[ApplicationService.NS_USERS].append(
"@irc_.*"
)
self.event.sender = "@irc_foobar:matrix.org"
self.event.room_id = "!wibblewoo:matrix.org"
self.assertFalse(self.service.is_interested(
self.event,
restrict_to=ApplicationService.NS_ROOMS
))
def test_restrict_to_aliases(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append(
"#xmpp_.*:matrix.org"
)
self.service.namespaces[ApplicationService.NS_USERS].append(
"@irc_.*"
)
self.event.sender = "@irc_foobar:matrix.org"
self.assertFalse(self.service.is_interested(
self.event,
restrict_to=ApplicationService.NS_ALIASES,
aliases_for_event=["#irc_barfoo:matrix.org"]
))
def test_restrict_to_senders(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append(
"#xmpp_.*:matrix.org"
)
self.service.namespaces[ApplicationService.NS_USERS].append(
"@irc_.*"
)
self.event.sender = "@xmpp_foobar:matrix.org"
self.assertFalse(self.service.is_interested(
self.event,
restrict_to=ApplicationService.NS_USERS,
aliases_for_event=["#xmpp_barfoo:matrix.org"]
))
def test_member_list_match(self):
self.service.namespaces[ApplicationService.NS_USERS].append(
"@irc_.*"
)
join_list = [
Mock(
type="m.room.member", room_id="!foo:bar", sender="@alice:here",
state_key="@alice:here"
),
Mock(
type="m.room.member", room_id="!foo:bar", sender="@irc_fo:here",
state_key="@irc_fo:here" # AS user
),
Mock(
type="m.room.member", room_id="!foo:bar", sender="@bob:here",
state_key="@bob:here"
)
]
self.event.sender = "@xmpp_foobar:matrix.org"
self.assertTrue(self.service.is_interested(
event=self.event,
member_list=join_list
))

View File

@ -0,0 +1,93 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from .. import unittest
from synapse.handlers.appservice import ApplicationServicesHandler
from mock import Mock
class AppServiceHandlerTestCase(unittest.TestCase):
""" Tests the ApplicationServicesHandler. """
def setUp(self):
self.mock_store = Mock()
self.mock_as_api = Mock()
hs = Mock()
hs.get_datastore = Mock(return_value=self.mock_store)
self.handler = ApplicationServicesHandler(
hs, self.mock_as_api
)
@defer.inlineCallbacks
def test_notify_interested_services(self):
interested_service = self._mkservice(is_interested=True)
services = [
self._mkservice(is_interested=False),
interested_service,
self._mkservice(is_interested=False)
]
self.mock_store.get_app_services = Mock(return_value=services)
self.mock_store.get_user_by_id = Mock(return_value=[])
event = Mock(
sender="@someone:anywhere",
type="m.room.message",
room_id="!foo:bar"
)
self.mock_as_api.push = Mock()
yield self.handler.notify_interested_services(event)
self.mock_as_api.push.assert_called_once_with(interested_service, event)
@defer.inlineCallbacks
def test_query_room_alias_exists(self):
room_alias_str = "#foo:bar"
room_alias = Mock()
room_alias.to_string = Mock(return_value=room_alias_str)
room_id = "!alpha:bet"
servers = ["aperture"]
interested_service = self._mkservice(is_interested=True)
services = [
self._mkservice(is_interested=False),
interested_service,
self._mkservice(is_interested=False)
]
self.mock_store.get_app_services = Mock(return_value=services)
self.mock_store.get_association_from_room_alias = Mock(
return_value=Mock(room_id=room_id, servers=servers)
)
result = yield self.handler.query_room_alias_exists(room_alias)
self.mock_as_api.query_alias.assert_called_once_with(
interested_service,
room_alias_str
)
self.assertEquals(result.room_id, room_id)
self.assertEquals(result.servers, servers)
def _mkservice(self, is_interested):
service = Mock()
service.is_interested = Mock(return_value=is_interested)
service.token = "mock_service_token"
service.url = "mock_service_url"
return service

View File

@ -61,6 +61,7 @@ class PresenceStateTestCase(unittest.TestCase):
hs.handlers = JustPresenceHandlers(hs) hs.handlers = JustPresenceHandlers(hs)
self.datastore = hs.get_datastore() self.datastore = hs.get_datastore()
self.datastore.get_app_service_by_token = Mock(return_value=None)
def get_presence_list(*a, **kw): def get_presence_list(*a, **kw):
return defer.succeed([]) return defer.succeed([])
@ -147,6 +148,7 @@ class PresenceListTestCase(unittest.TestCase):
hs.handlers = JustPresenceHandlers(hs) hs.handlers = JustPresenceHandlers(hs)
self.datastore = hs.get_datastore() self.datastore = hs.get_datastore()
self.datastore.get_app_service_by_token = Mock(return_value=None)
def has_presence_state(user_localpart): def has_presence_state(user_localpart):
return defer.succeed( return defer.succeed(
@ -292,6 +294,7 @@ class PresenceEventStreamTestCase(unittest.TestCase):
hs.handlers.room_member_handler.get_rooms_for_user = get_rooms_for_user hs.handlers.room_member_handler.get_rooms_for_user = get_rooms_for_user
self.mock_datastore = hs.get_datastore() self.mock_datastore = hs.get_datastore()
self.mock_datastore.get_app_service_by_token = Mock(return_value=None)
def get_profile_displayname(user_id): def get_profile_displayname(user_id):
return defer.succeed("Frank") return defer.succeed("Frank")

View File

@ -56,6 +56,8 @@ class V2AlphaRestTestCase(unittest.TestCase):
r.register_servlets(hs, self.mock_resource) r.register_servlets(hs, self.mock_resource)
def make_datastore_mock(self): def make_datastore_mock(self):
return Mock(spec=[ store = Mock(spec=[
"insert_client_ip", "insert_client_ip",
]) ])
store.get_app_service_by_token = Mock(return_value=None)
return store

View File

@ -0,0 +1,110 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tests import unittest
from twisted.internet import defer
from synapse.appservice import ApplicationService
from synapse.server import HomeServer
from synapse.storage.appservice import ApplicationServiceStore
from mock import Mock
from tests.utils import SQLiteMemoryDbPool, MockClock
class ApplicationServiceStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
db_pool = SQLiteMemoryDbPool()
yield db_pool.prepare()
hs = HomeServer(
"test", db_pool=db_pool, clock=MockClock(), config=Mock()
)
self.as_token = "token1"
db_pool.runQuery(
"INSERT INTO application_services(token) VALUES(?)",
(self.as_token,)
)
db_pool.runQuery(
"INSERT INTO application_services(token) VALUES(?)", ("token2",)
)
db_pool.runQuery(
"INSERT INTO application_services(token) VALUES(?)", ("token3",)
)
# must be done after inserts
self.store = ApplicationServiceStore(hs)
@defer.inlineCallbacks
def test_update_and_retrieval_of_service(self):
url = "https://matrix.org/appservices/foobar"
hs_token = "hstok"
user_regex = ["@foobar_.*:matrix.org"]
alias_regex = ["#foobar_.*:matrix.org"]
room_regex = []
service = ApplicationService(
url=url, hs_token=hs_token, token=self.as_token, namespaces={
ApplicationService.NS_USERS: user_regex,
ApplicationService.NS_ALIASES: alias_regex,
ApplicationService.NS_ROOMS: room_regex
})
yield self.store.update_app_service(service)
stored_service = yield self.store.get_app_service_by_token(
self.as_token
)
self.assertEquals(stored_service.token, self.as_token)
self.assertEquals(stored_service.url, url)
self.assertEquals(
stored_service.namespaces[ApplicationService.NS_ALIASES],
alias_regex
)
self.assertEquals(
stored_service.namespaces[ApplicationService.NS_ROOMS],
room_regex
)
self.assertEquals(
stored_service.namespaces[ApplicationService.NS_USERS],
user_regex
)
@defer.inlineCallbacks
def test_retrieve_unknown_service_token(self):
service = yield self.store.get_app_service_by_token("invalid_token")
self.assertEquals(service, None)
@defer.inlineCallbacks
def test_retrieval_of_service(self):
stored_service = yield self.store.get_app_service_by_token(
self.as_token
)
self.assertEquals(stored_service.token, self.as_token)
self.assertEquals(stored_service.url, None)
self.assertEquals(
stored_service.namespaces[ApplicationService.NS_ALIASES],
[]
)
self.assertEquals(
stored_service.namespaces[ApplicationService.NS_ROOMS],
[]
)
self.assertEquals(
stored_service.namespaces[ApplicationService.NS_USERS],
[]
)
@defer.inlineCallbacks
def test_retrieval_of_all_services(self):
services = yield self.store.get_app_services()
self.assertEquals(len(services), 3)

View File

@ -46,10 +46,16 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
if datastore is None: if datastore is None:
db_pool = SQLiteMemoryDbPool() db_pool = SQLiteMemoryDbPool()
yield db_pool.prepare() yield db_pool.prepare()
hs = HomeServer(name, db_pool=db_pool, config=config, **kargs) hs = HomeServer(
name, db_pool=db_pool, config=config,
version_string="Synapse/tests",
**kargs
)
else: else:
hs = HomeServer( hs = HomeServer(
name, db_pool=None, datastore=datastore, config=config, **kargs name, db_pool=None, datastore=datastore, config=config,
version_string="Synapse/tests",
**kargs
) )
defer.returnValue(hs) defer.returnValue(hs)