Merge branch 'develop' into csauth

David Baker 2015-04-17 13:51:10 +01:00
commit cb03fafdf1
45 changed files with 1817 additions and 537 deletions

AUTHORS.rst (new file)

@ -0,0 +1,37 @@
Erik Johnston <erik at matrix.org>
* HS core
* Federation API impl
Mark Haines <mark at matrix.org>
* HS core
* Crypto
* Content repository
* CS v2 API impl
Kegan Dougal <kegan at matrix.org>
* HS core
* CS v1 API impl
* AS API impl
Paul "LeoNerd" Evans <paul at matrix.org>
* HS core
* Presence
* Typing Notifications
* Performance metrics and caching layer
Dave Baker <dave at matrix.org>
* Push notifications
* Auth CS v2 impl
Matthew Hodgson <matthew at matrix.org>
* General doc & housekeeping
* Vertobot/vertobridge matrix<->verto PoC
Emmanuel Rohee <manu at matrix.org>
* Supporting iOS clients (testability and fallback registration)
Turned to Dust <dwinslow86 at gmail.com>
* ArchLinux installation instructions
Brabo <brabo at riseup.net>
* Installation instruction fixes


@ -1,3 +1,10 @@
+Changes in synapse vX
+=====================
+* Changed config option from ``disable_registration`` to
+  ``enable_registration``. Old option will be ignored.
Changes in synapse v0.8.1 (2015-03-18)
======================================

CONTRIBUTING.rst (new file)

@ -0,0 +1,118 @@
Contributing code to Matrix
===========================
Everyone is welcome to contribute code to Matrix
(https://github.com/matrix-org), provided that they are willing to license
their contributions under the same license as the project itself. We follow a
simple 'inbound=outbound' model for contributions: the act of submitting an
'inbound' contribution means that the contributor agrees to license the code
under the same terms as the project's overall 'outbound' license - in our
case, this is almost always Apache Software License v2 (see LICENSE).
How to contribute
~~~~~~~~~~~~~~~~~
The preferred and easiest way to contribute changes to Matrix is to fork the
relevant project on github, and then create a pull request to ask us to pull
your changes into our repo
(https://help.github.com/articles/using-pull-requests/)
**The single biggest thing you need to know is: please base your changes on
the develop branch - /not/ master.**
We use the master branch to track the most recent release, so that folks who
blindly clone the repo and automatically check out master get something that
works. Develop is the unstable branch where all the development actually
happens: the workflow is that contributors should fork the develop branch to
make a 'feature' branch for a particular contribution, and then make a pull
request to merge this back into the matrix.org 'official' develop branch. We
use github's pull request workflow to review the contribution, and either ask
you to make any refinements needed or merge it and make them ourselves. The
changes will then land on master when we next do a release.
We use Jenkins for continuous integration (http://matrix.org/jenkins), and
typically all pull requests get automatically tested by Jenkins: if your change breaks the build, Jenkins will yell about it in #matrix-dev:matrix.org so please lurk there and keep an eye open.
Code style
~~~~~~~~~~
All Matrix projects have a well-defined code-style - and sometimes we've even
got as far as documenting it... For instance, synapse's code style doc lives
at https://github.com/matrix-org/synapse/tree/master/docs/code_style.rst.
Please ensure your changes match the cosmetic style of the existing project,
and **never** mix cosmetic and functional changes in the same commit, as it
makes it horribly hard to review otherwise.
Attribution
~~~~~~~~~~~
Everyone who contributes anything to Matrix is welcome to be listed in the
AUTHORS.rst file for the project in question. Please feel free to include a
change to AUTHORS.rst in your pull request to list yourself and a short
description of the area(s) you've worked on. Also, we sometimes have swag to
give away to contributors - if you feel that Matrix-branded apparel is missing
from your life, please mail us your shipping address to matrix at matrix.org and we'll try to fix it :)
Sign off
~~~~~~~~
In order to have a concrete record that your contribution is intentional
and you agree to license it under the same terms as the project's license, we've adopted the
same lightweight approach that the Linux Kernel
(https://www.kernel.org/doc/Documentation/SubmittingPatches), Docker
(https://github.com/docker/docker/blob/master/CONTRIBUTING.md), and many other
projects use: the DCO (Developer Certificate of Origin:
http://developercertificate.org/). This is a simple declaration that you wrote
the contribution or otherwise have the right to contribute it to Matrix::
Developer Certificate of Origin
Version 1.1
Copyright (C) 2004, 2006 The Linux Foundation and its contributors.
660 York Street, Suite 102,
San Francisco, CA 94110 USA
Everyone is permitted to copy and distribute verbatim copies of this
license document, but changing it is not allowed.
Developer's Certificate of Origin 1.1
By making a contribution to this project, I certify that:
(a) The contribution was created in whole or in part by me and I
have the right to submit it under the open source license
indicated in the file; or
(b) The contribution is based upon previous work that, to the best
of my knowledge, is covered under an appropriate open source
license and I have the right under that license to submit that
work with modifications, whether created in whole or in part
by me, under the same open source license (unless I am
permitted to submit under a different license), as indicated
in the file; or
(c) The contribution was provided directly to me by some other
person who certified (a), (b) or (c) and I have not modified
it.
(d) I understand and agree that this project and the contribution
are public and that a record of the contribution (including all
personal information I submit with it, including my sign-off) is
maintained indefinitely and may be redistributed consistent with
this project or the open source license(s) involved.
If you agree to this for your contribution, then all that's needed is to
include the line in your commit or pull request comment::
Signed-off-by: Your Name <your@email.example.org>
...using your real name; unfortunately pseudonyms and anonymous contributions
can't be accepted. Git makes this trivial - just use the -s flag when you do
``git commit``, having first set ``user.name`` and ``user.email`` git configs
(which you should have done anyway :)
Conclusion
~~~~~~~~~~
That's it! Matrix is a very open and collaborative project as you might expect given our obsession with open communication. If we're going to successfully matrix together all the fragmented communication technologies out there we are reliant on contributions and collaboration from the community to do so. So please get involved - and we hope you have as much fun hacking on Matrix as we do!


@ -129,7 +129,8 @@ To set up your homeserver, run (in your virtualenv, as before)::
Substituting your host and domain name as appropriate.
By default, registration of new users is disabled. You can either enable
-registration in the config (it is then recommended to also set up CAPTCHA), or
+registration in the config by specifying ``enable_registration: true``
+(it is then recommended to also set up CAPTCHA), or
you can use the command line to register new users::
    $ source ~/.synapse/bin/activate
@ -348,7 +349,7 @@ and port where the server is running. (At the current time synapse does not
support clustering multiple servers into a single logical homeserver). The DNS
record would then look something like::
-    $ dig -t srv _matrix._tcp.machine.my.domaine.name
+    $ dig -t srv _matrix._tcp.machine.my.domain.name
    _matrix._tcp IN SRV 10 0 8448 machine.my.domain.name.


@ -7,6 +7,9 @@ matrix:
  matrix-bot:
    user_id: '@vertobot:matrix.org'
    password: ''
+    domain: 'matrix.org'
+    as_url: 'http://localhost:8009'
+    as_token: 'vertobot123'
verto-bot:
  host: webrtc.freeswitch.org


@ -33,7 +33,8 @@ for port in 8080 8081 8082; do
        --manhole $((port + 1000)) \
        --tls-dh-params-path "demo/demo.tls.dh" \
        --media-store-path "demo/media_store.$port" \
        $PARAMS $SYNAPSE_PARAMS \
+        --enable-registration
    python -m synapse.app.homeserver \
        --config-path "demo/etc/$port.config" \


@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server.
"""
-__version__ = "0.8.1-r2"
+__version__ = "0.8.1-r3"


@ -216,17 +216,20 @@ class Auth(object):
        else:
            ban_level = 50  # FIXME (erikj): What should we do here?
-        if Membership.INVITE == membership:
-            # TODO (erikj): We should probably handle this more intelligently
-            #  PRIVATE join rules.
-            # Invites are valid iff caller is in the room and target isn't.
+        if Membership.JOIN != membership:
+            # JOIN is the only action you can perform if you're not in the room
            if not caller_in_room:  # caller isn't joined
                raise AuthError(
                    403,
                    "%s not in room %s." % (event.user_id, event.room_id,)
                )
-        elif target_banned:
+        if Membership.INVITE == membership:
+            # TODO (erikj): We should probably handle this more intelligently
+            #  PRIVATE join rules.
+            # Invites are valid iff caller is in the room and target isn't.
+            if target_banned:
                raise AuthError(
                    403, "%s is banned from the room" % (target_user_id,)
                )
@ -252,13 +255,7 @@ class Auth(object):
            raise AuthError(403, "You are not allowed to join this room")
        elif Membership.LEAVE == membership:
            # TODO (erikj): Implement kicks.
-            if not caller_in_room:  # trying to leave a room you aren't joined
-                raise AuthError(
-                    403,
-                    "%s not in room %s." % (target_user_id, event.room_id,)
-                )
-            elif target_banned and user_level < ban_level:
+            if target_banned and user_level < ban_level:
                raise AuthError(
                    403, "You cannot unban user %s." % (target_user_id,)
                )
@ -493,7 +490,7 @@ class Auth(object):
        send_level = send_level_event.content.get("events", {}).get(
            event.type
        )
-        if not send_level:
+        if send_level is None:
            if hasattr(event, "state_key"):
                send_level = send_level_event.content.get(
                    "state_default", 50

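A note on why ``is None`` replaces the truthiness test: a power level of 0 is
falsy in Python, so the old check treated an explicitly configured level of 0
as missing and fell through to the defaults. A minimal sketch (hypothetical
event type)::

    content = {"events": {"m.room.message": 0}}
    send_level = content.get("events", {}).get("m.room.message")
    # old: `not send_level` is True for 0, so the explicit 0 was ignored
    # new: `send_level is None` is False, so the explicit 0 is honoured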

@ -32,15 +32,13 @@ from twisted.web.resource import Resource
from twisted.web.static import File
from twisted.web.server import Site
from synapse.http.server import JsonResource, RootRedirect
-from synapse.rest.appservice.v1 import AppServiceRestResource
from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
-from synapse.http.server_key_resource import LocalKey
+from synapse.rest.key.v1.server_key_resource import LocalKey
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.api.urls import (
    CLIENT_PREFIX, FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX,
-    SERVER_KEY_PREFIX, MEDIA_PREFIX, CLIENT_V2_ALPHA_PREFIX, APP_SERVICE_PREFIX,
-    STATIC_PREFIX
+    SERVER_KEY_PREFIX, MEDIA_PREFIX, CLIENT_V2_ALPHA_PREFIX, STATIC_PREFIX
)
from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory
@ -78,9 +76,6 @@ class SynapseHomeServer(HomeServer):
    def build_resource_for_federation(self):
        return JsonResource(self)
-    def build_resource_for_app_services(self):
-        return AppServiceRestResource(self)
    def build_resource_for_web_client(self):
        import syweb
        syweb_path = os.path.dirname(syweb.__file__)
@ -141,7 +136,6 @@ class SynapseHomeServer(HomeServer):
            (CONTENT_REPO_PREFIX, self.get_resource_for_content_repo()),
            (SERVER_KEY_PREFIX, self.get_resource_for_server_key()),
            (MEDIA_PREFIX, self.get_resource_for_media_repository()),
-            (APP_SERVICE_PREFIX, self.get_resource_for_app_services()),
            (STATIC_PREFIX, self.get_resource_for_static_content()),
        ]


@ -20,6 +20,50 @@ import re
logger = logging.getLogger(__name__)
+class ApplicationServiceState(object):
+    DOWN = "down"
+    UP = "up"
+class AppServiceTransaction(object):
+    """Represents an application service transaction."""
+    def __init__(self, service, id, events):
+        self.service = service
+        self.id = id
+        self.events = events
+    def send(self, as_api):
+        """Sends this transaction using the provided AS API interface.
+        Args:
+            as_api(ApplicationServiceApi): The API to use to send.
+        Returns:
+            A Deferred which resolves to True if the transaction was sent.
+        """
+        return as_api.push_bulk(
+            service=self.service,
+            events=self.events,
+            txn_id=self.id
+        )
+    def complete(self, store):
+        """Completes this transaction as successful.
+        Marks this transaction ID on the application service and removes the
+        transaction contents from the database.
+        Args:
+            store: The database store to operate on.
+        Returns:
+            A Deferred which resolves to True if the transaction was completed.
+        """
+        return store.complete_appservice_txn(
+            service=self.service,
+            txn_id=self.id
+        )
class ApplicationService(object):
    """Defines an application service. This definition is mostly what is
    provided to the /register AS API.
@ -35,13 +79,13 @@ class ApplicationService(object):
    NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS]
    def __init__(self, token, url=None, namespaces=None, hs_token=None,
-                 sender=None, txn_id=None):
+                 sender=None, id=None):
        self.token = token
        self.url = url
        self.hs_token = hs_token
        self.sender = sender
        self.namespaces = self._check_namespaces(namespaces)
-        self.txn_id = txn_id
+        self.id = id
    def _check_namespaces(self, namespaces):
        # Sanity check that it is of the form:
@ -51,7 +95,7 @@ class ApplicationService(object):
        #   rooms: [ {regex: "[A-z]+.*", exclusive: true}, ...],
        # }
        if not namespaces:
-            return None
+            namespaces = {}
        for ns in ApplicationService.NS_LIST:
            if ns not in namespaces:
@ -155,7 +199,10 @@ class ApplicationService(object):
            return self._matches_user(event, member_list)
    def is_interested_in_user(self, user_id):
-        return self._matches_regex(user_id, ApplicationService.NS_USERS)
+        return (
+            self._matches_regex(user_id, ApplicationService.NS_USERS)
+            or user_id == self.sender
+        )
    def is_interested_in_alias(self, alias):
        return self._matches_regex(alias, ApplicationService.NS_ALIASES)
@ -164,7 +211,10 @@ class ApplicationService(object):
        return self._matches_regex(room_id, ApplicationService.NS_ROOMS)
    def is_exclusive_user(self, user_id):
-        return self._is_exclusive(ApplicationService.NS_USERS, user_id)
+        return (
+            self._is_exclusive(ApplicationService.NS_USERS, user_id)
+            or user_id == self.sender
+        )
    def is_exclusive_alias(self, alias):
        return self._is_exclusive(ApplicationService.NS_ALIASES, alias)
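Both ``is_interested_in_user`` and ``is_exclusive_user`` now also match the
application service's own sender user, not just the registered namespaces. A
minimal sketch using the constructor above (all values hypothetical)::

    service = ApplicationService(
        token="some_token", url="http://localhost:8009",
        namespaces={"users": [{"regex": "@irc_.*", "exclusive": True}]},
        sender="@bridge_bot:example.org", id=1,
    )
    service.is_interested_in_user("@irc_alice:example.org")   # True, via the regex
    service.is_interested_in_user("@bridge_bot:example.org")  # True, via the sender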


@ -72,14 +72,19 @@ class ApplicationServiceApi(SimpleHttpClient):
            defer.returnValue(False)
    @defer.inlineCallbacks
-    def push_bulk(self, service, events):
+    def push_bulk(self, service, events, txn_id=None):
        events = self._serialize(events)
+        if txn_id is None:
+            logger.warning("push_bulk: Missing txn ID sending events to %s",
+                           service.url)
+            txn_id = str(0)
+        txn_id = str(txn_id)
        uri = service.url + ("/transactions/%s" %
-                            urllib.quote(str(0)))  # TODO txn_ids
+                            urllib.quote(txn_id))
-        response = None
        try:
-            response = yield self.put_json(
+            yield self.put_json(
                uri=uri,
                json_body={
                    "events": events
@ -87,9 +92,8 @@ class ApplicationServiceApi(SimpleHttpClient):
                args={
                    "access_token": service.hs_token
                })
-            if response:  # just an empty json object
-                # TODO: Mark txn as sent successfully
-                defer.returnValue(True)
+            defer.returnValue(True)
+            return
        except CodeMessageException as e:
            logger.warning("push_bulk to %s received %s", uri, e.code)
        except Exception as ex:
@ -97,8 +101,8 @@ class ApplicationServiceApi(SimpleHttpClient):
        defer.returnValue(False)
    @defer.inlineCallbacks
-    def push(self, service, event):
-        response = yield self.push_bulk(service, [event])
+    def push(self, service, event, txn_id=None):
+        response = yield self.push_bulk(service, [event], txn_id)
        defer.returnValue(response)
    def _serialize(self, events):
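For reference, a sketch of the request this now produces, assuming a service
registered with url ``http://localhost:8009`` and transaction ID 42 (both
hypothetical)::

    import urllib
    uri = "http://localhost:8009" + ("/transactions/%s" % urllib.quote(str(42)))
    # -> PUT http://localhost:8009/transactions/42?access_token=<service.hs_token>
    #    with body {"events": [...]}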


@ -0,0 +1,254 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module controls the reliability for application service transactions.

The nominal flow through this module looks like:
              __________
1---ASa[e]-->|  Service |--> Queue ASa[f]
2----ASb[e]->|  Queuer  |
3--ASa[f]--->|__________|-----------+ ASa[e], ASb[e]
                                    V
      -````````-            +------------+
      |````````|<--StoreTxn-|Transaction |
      |Database|            | Controller |---> SEND TO AS
      `--------`            +------------+

What happens on SEND TO AS depends on the state of the Application Service:
 - If the AS is marked as DOWN, do nothing.
 - If the AS is marked as UP, send the transaction.
     * SUCCESS : Increment where the AS is up to txn-wise and nuke the txn
                 contents from the db.
     * FAILURE : Marked AS as DOWN and start Recoverer.

Recoverer attempts to recover ASes who have died. The flow for this looks like:
                ,--------------------- backoff++ --------------.
               V                                               |
  START ---> Wait exp ------> Get oldest txn ID from ----> FAILURE
             backoff          DB and try to send it
                                 ^                |___________
Mark AS as                       |                            V
UP & quit           +---------- YES                       SUCCESS
    |               |                                         |
    NO <--- Have more txns? <------ Mark txn success & nuke <-+
                                  from db; incr AS pos.
                                  Reset backoff.

This is all tied together by the AppServiceScheduler which DIs the required
components.
"""
from synapse.appservice import ApplicationServiceState
from twisted.internet import defer

import logging

logger = logging.getLogger(__name__)


class AppServiceScheduler(object):
    """ Public facing API for this module. Does the required DI to tie the
    components together. This also serves as the "event_pool", which in this
    case is a simple array.
    """

    def __init__(self, clock, store, as_api):
        self.clock = clock
        self.store = store
        self.as_api = as_api

        def create_recoverer(service, callback):
            return _Recoverer(clock, store, as_api, service, callback)

        self.txn_ctrl = _TransactionController(
            clock, store, as_api, create_recoverer
        )
        self.queuer = _ServiceQueuer(self.txn_ctrl)

    @defer.inlineCallbacks
    def start(self):
        logger.info("Starting appservice scheduler")
        # check for any DOWN ASes and start recoverers for them.
        recoverers = yield _Recoverer.start(
            self.clock, self.store, self.as_api, self.txn_ctrl.on_recovered
        )
        self.txn_ctrl.add_recoverers(recoverers)

    def submit_event_for_as(self, service, event):
        self.queuer.enqueue(service, event)


class _ServiceQueuer(object):
    """Queues events for the same application service together, sending
    transactions as soon as possible. Once a transaction is sent successfully,
    this schedules any other events in the queue to run.
    """

    def __init__(self, txn_ctrl):
        self.queued_events = {}  # dict of {service_id: [events]}
        self.pending_requests = {}  # dict of {service_id: Deferred}
        self.txn_ctrl = txn_ctrl

    def enqueue(self, service, event):
        # if this service isn't being sent something
        if not self.pending_requests.get(service.id):
            self._send_request(service, [event])
        else:
            # add to queue for this service
            if service.id not in self.queued_events:
                self.queued_events[service.id] = []
            self.queued_events[service.id].append(event)

    def _send_request(self, service, events):
        # send request and add callbacks
        d = self.txn_ctrl.send(service, events)
        d.addBoth(self._on_request_finish)
        d.addErrback(self._on_request_fail)
        self.pending_requests[service.id] = d

    def _on_request_finish(self, service):
        self.pending_requests[service.id] = None
        # if there are queued events, then send them.
        if (service.id in self.queued_events
                and len(self.queued_events[service.id]) > 0):
            self._send_request(service, self.queued_events[service.id])
            self.queued_events[service.id] = []

    def _on_request_fail(self, err):
        logger.error("AS request failed: %s", err)


class _TransactionController(object):

    def __init__(self, clock, store, as_api, recoverer_fn):
        self.clock = clock
        self.store = store
        self.as_api = as_api
        self.recoverer_fn = recoverer_fn
        # keep track of how many recoverers there are
        self.recoverers = []

    @defer.inlineCallbacks
    def send(self, service, events):
        try:
            txn = yield self.store.create_appservice_txn(
                service=service,
                events=events
            )
            service_is_up = yield self._is_service_up(service)
            if service_is_up:
                sent = yield txn.send(self.as_api)
                if sent:
                    txn.complete(self.store)
                else:
                    self._start_recoverer(service)
        except Exception as e:
            logger.exception(e)
            self._start_recoverer(service)
        # request has finished
        defer.returnValue(service)

    @defer.inlineCallbacks
    def on_recovered(self, recoverer):
        self.recoverers.remove(recoverer)
        logger.info("Successfully recovered application service AS ID %s",
                    recoverer.service.id)
        logger.info("Remaining active recoverers: %s", len(self.recoverers))
        yield self.store.set_appservice_state(
            recoverer.service,
            ApplicationServiceState.UP
        )

    def add_recoverers(self, recoverers):
        for r in recoverers:
            self.recoverers.append(r)
        if len(recoverers) > 0:
            logger.info("New active recoverers: %s", len(self.recoverers))

    @defer.inlineCallbacks
    def _start_recoverer(self, service):
        yield self.store.set_appservice_state(
            service,
            ApplicationServiceState.DOWN
        )
        logger.info(
            "Application service falling behind. Starting recoverer. AS ID %s",
            service.id
        )
        recoverer = self.recoverer_fn(service, self.on_recovered)
        self.add_recoverers([recoverer])
        recoverer.recover()

    @defer.inlineCallbacks
    def _is_service_up(self, service):
        state = yield self.store.get_appservice_state(service)
        defer.returnValue(state == ApplicationServiceState.UP or state is None)


class _Recoverer(object):

    @staticmethod
    @defer.inlineCallbacks
    def start(clock, store, as_api, callback):
        services = yield store.get_appservices_by_state(
            ApplicationServiceState.DOWN
        )
        recoverers = [
            _Recoverer(clock, store, as_api, s, callback) for s in services
        ]
        for r in recoverers:
            logger.info("Starting recoverer for AS ID %s which was marked as "
                        "DOWN", r.service.id)
            r.recover()
        defer.returnValue(recoverers)

    def __init__(self, clock, store, as_api, service, callback):
        self.clock = clock
        self.store = store
        self.as_api = as_api
        self.service = service
        self.callback = callback
        self.backoff_counter = 1

    def recover(self):
        self.clock.call_later((2 ** self.backoff_counter), self.retry)

    def _backoff(self):
        # cap the backoff to be around 18h => (2^16) = 65536 secs
        if self.backoff_counter < 16:
            self.backoff_counter += 1
        self.recover()

    @defer.inlineCallbacks
    def retry(self):
        try:
            txn = yield self.store.get_oldest_unsent_txn(self.service)
            if txn:
                logger.info("Retrying transaction %s for AS ID %s",
                            txn.id, txn.service.id)
                sent = yield txn.send(self.as_api)
                if sent:
                    yield txn.complete(self.store)
                    # reset the backoff counter and retry immediately
                    self.backoff_counter = 1
                    yield self.retry()
                else:
                    self._backoff()
            else:
                self._set_service_recovered()
        except Exception as e:
            logger.exception(e)
            self._backoff()

    def _set_service_recovered(self):
        self.callback(self)
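As a usage sketch, this mirrors the wiring that synapse/handlers/__init__.py
performs later in this diff; the retry delay grows as 2**backoff_counter
seconds (2, 4, 8, ...), capped at 2**16 = 65536 s, roughly 18 hours::

    asapi = ApplicationServiceApi(hs)
    scheduler = AppServiceScheduler(
        clock=hs.get_clock(),
        store=hs.get_datastore(),
        as_api=asapi,
    )
    scheduler.start()  # resumes a _Recoverer for any AS marked DOWN
    scheduler.submit_event_for_as(service, event)  # queues, batches, sends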


@ -0,0 +1,31 @@
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config


class AppServiceConfig(Config):

    def __init__(self, args):
        super(AppServiceConfig, self).__init__(args)
        self.app_service_config_files = args.app_service_config_files

    @classmethod
    def add_arguments(cls, parser):
        super(AppServiceConfig, cls).add_arguments(parser)
        group = parser.add_argument_group("appservice")
        group.add_argument(
            "--app-service-config-files", type=str, nargs='+',
            help="A list of application service config files to use."
        )


@ -24,12 +24,13 @@ from .email import EmailConfig
from .voip import VoipConfig
from .registration import RegistrationConfig
from .metrics import MetricsConfig
+from .appservice import AppServiceConfig
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
                       RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
                       EmailConfig, VoipConfig, RegistrationConfig,
-                       MetricsConfig,):
+                       MetricsConfig, AppServiceConfig,):
    pass


@ -25,11 +25,11 @@ class RegistrationConfig(Config):
    def __init__(self, args):
        super(RegistrationConfig, self).__init__(args)
-        # `args.disable_registration` may either be a bool or a string depending
-        # on if the option was given a value (e.g. --disable-registration=false
-        # would set `args.disable_registration` to "false" not False.)
-        self.disable_registration = bool(
-            distutils.util.strtobool(str(args.disable_registration))
+        # `args.enable_registration` may either be a bool or a string depending
+        # on if the option was given a value (e.g. --enable-registration=true
+        # would set `args.enable_registration` to "true" not True.)
+        self.disable_registration = not bool(
+            distutils.util.strtobool(str(args.enable_registration))
        )
        self.registration_shared_secret = args.registration_shared_secret
@ -39,11 +39,11 @@ class RegistrationConfig(Config):
        reg_group = parser.add_argument_group("registration")
        reg_group.add_argument(
-            "--disable-registration",
+            "--enable-registration",
            const=True,
-            default=True,
+            default=False,
            nargs='?',
-            help="Disable registration of new users.",
+            help="Enable registration for new users.",
        )
        reg_group.add_argument(
            "--registration-shared-secret", type=str,
@ -53,8 +53,8 @@ class RegistrationConfig(Config):
    @classmethod
    def generate_config(cls, args, config_dir_path):
-        if args.disable_registration is None:
-            args.disable_registration = True
+        if args.enable_registration is None:
+            args.enable_registration = False
        if args.registration_shared_secret is None:
            args.registration_shared_secret = random_string_with_symbols(50)
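The double negation keeps the old internal flag intact: argparse leaves
``args.enable_registration`` as ``False`` (the default), ``True`` (a bare
``--enable-registration``), or a string such as ``"true"``, and
``strtobool(str(...))`` normalises all three. A small sketch::

    from distutils.util import strtobool
    for value in (False, True, "true", "false"):
        disable_registration = not bool(strtobool(str(value)))
        # False/"false" -> registration stays disabled
        # True/"true"   -> registration is enabled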


@ -110,7 +110,7 @@ class ServerConfig(Config):
            with open(args.signing_key_path, "w") as signing_key_file:
                syutil.crypto.signing_key.write_signing_keys(
                    signing_key_file,
-                    (syutil.crypto.signing_key.generate_singing_key("auto"),),
+                    (syutil.crypto.signing_key.generate_signing_key("auto"),),
                )
        else:
            signing_keys = cls.read_file(args.signing_key_path, "signing_key")


@ -46,9 +46,10 @@ def _event_dict_property(key):
class EventBase(object):
    def __init__(self, event_dict, signatures={}, unsigned={},
-                 internal_metadata_dict={}):
+                 internal_metadata_dict={}, rejected_reason=None):
        self.signatures = signatures
        self.unsigned = unsigned
+        self.rejected_reason = rejected_reason
        self._event_dict = event_dict
@ -109,7 +110,7 @@ class EventBase(object):
class FrozenEvent(EventBase):
-    def __init__(self, event_dict, internal_metadata_dict={}):
+    def __init__(self, event_dict, internal_metadata_dict={}, rejected_reason=None):
        event_dict = dict(event_dict)
        # Signatures is a dict of dicts, and this is faster than doing a
@ -128,6 +129,7 @@ class FrozenEvent(EventBase):
            signatures=signatures,
            unsigned=unsigned,
            internal_metadata_dict=internal_metadata_dict,
+            rejected_reason=rejected_reason,
        )
    @staticmethod


@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from synapse.appservice.scheduler import AppServiceScheduler
from synapse.appservice.api import ApplicationServiceApi
from .register import RegistrationHandler
from .room import (
@ -56,8 +57,13 @@ class Handlers(object):
        self.directory_handler = DirectoryHandler(hs)
        self.typing_notification_handler = TypingNotificationHandler(hs)
        self.admin_handler = AdminHandler(hs)
+        asapi = ApplicationServiceApi(hs)
        self.appservice_handler = ApplicationServicesHandler(
-            hs, ApplicationServiceApi(hs)
+            hs, asapi, AppServiceScheduler(
+                clock=hs.get_clock(),
+                store=hs.get_datastore(),
+                as_api=asapi
+            )
        )
        self.sync_handler = SyncHandler(hs)
        self.auth_handler = AuthHandler(hs)


@ -16,57 +16,36 @@
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
-from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.appservice import ApplicationService
from synapse.types import UserID
-import synapse.util.stringutils as stringutils
import logging
logger = logging.getLogger(__name__)
+def log_failure(failure):
+    logger.error(
+        "Application Services Failure",
+        exc_info=(
+            failure.type,
+            failure.value,
+            failure.getTracebackObject()
+        )
+    )
# NB: Purposefully not inheriting BaseHandler since that contains way too much
# setup code which this handler does not need or use. This makes testing a lot
# easier.
class ApplicationServicesHandler(object):
-    def __init__(self, hs, appservice_api):
+    def __init__(self, hs, appservice_api, appservice_scheduler):
        self.store = hs.get_datastore()
        self.hs = hs
        self.appservice_api = appservice_api
+        self.scheduler = appservice_scheduler
+        self.started_scheduler = False
-    @defer.inlineCallbacks
-    def register(self, app_service):
-        logger.info("Register -> %s", app_service)
-        # check the token is recognised
-        try:
-            stored_service = yield self.store.get_app_service_by_token(
-                app_service.token
-            )
-            if not stored_service:
-                raise StoreError(404, "Application service not found")
-        except StoreError:
-            raise SynapseError(
-                403, "Unrecognised application services token. "
-                "Consult the home server admin.",
-                errcode=Codes.FORBIDDEN
-            )
-        app_service.hs_token = self._generate_hs_token()
-        # create a sender for this application service which is used when
-        # creating rooms, etc..
-        account = yield self.hs.get_handlers().registration_handler.register()
-        app_service.sender = account[0]
-        yield self.store.update_app_service(app_service)
-        defer.returnValue(app_service)
-    @defer.inlineCallbacks
-    def unregister(self, token):
-        logger.info("Unregister as_token=%s", token)
-        yield self.store.unregister_app_service(token)
    @defer.inlineCallbacks
    def notify_interested_services(self, event):
@ -90,9 +69,13 @@ class ApplicationServicesHandler(object):
        if event.type == EventTypes.Member:
            yield self._check_user_exists(event.state_key)
-        # Fork off pushes to these services - XXX First cut, best effort
+        if not self.started_scheduler:
+            self.scheduler.start().addErrback(log_failure)
+            self.started_scheduler = True
+        # Fork off pushes to these services
        for service in services:
-            self.appservice_api.push(service, event)
+            self.scheduler.submit_event_for_as(service, event)
    @defer.inlineCallbacks
    def query_user_exists(self, user_id):
@ -197,7 +180,14 @@ class ApplicationServicesHandler(object):
            return
        user_info = yield self.store.get_user_by_id(user_id)
-        defer.returnValue(len(user_info) == 0)
+        if len(user_info) > 0:
+            defer.returnValue(False)
+            return
+        # user not found; could be the AS though, so check.
+        services = yield self.store.get_app_services()
+        service_list = [s for s in services if s.sender == user_id]
+        defer.returnValue(len(service_list) == 0)
    @defer.inlineCallbacks
    def _check_user_exists(self, user_id):
@ -206,6 +196,3 @@ class ApplicationServicesHandler(object):
            exists = yield self.query_user_exists(user_id)
            defer.returnValue(exists)
        defer.returnValue(True)
-    def _generate_hs_token(self):
-        return stringutils.random_string(24)


@ -201,10 +201,18 @@ class FederationHandler(BaseHandler):
            target_user = UserID.from_string(target_user_id)
            extra_users.append(target_user)
-        yield self.notifier.on_new_room_event(
+        d = self.notifier.on_new_room_event(
            event, extra_users=extra_users
        )
+        def log_failure(f):
+            logger.warn(
+                "Failed to notify about %s: %s",
+                event.event_id, f.value
+            )
+        d.addErrback(log_failure)
        if event.type == EventTypes.Member:
            if event.membership == Membership.JOIN:
                user = UserID.from_string(event.state_key)
@ -427,10 +435,18 @@ class FederationHandler(BaseHandler):
                auth_events=auth_events,
            )
-            yield self.notifier.on_new_room_event(
+            d = self.notifier.on_new_room_event(
                new_event, extra_users=[joinee]
            )
+            def log_failure(f):
+                logger.warn(
+                    "Failed to notify about %s: %s",
+                    new_event.event_id, f.value
+                )
+            d.addErrback(log_failure)
            logger.debug("Finished joining %s to %s", joinee, room_id)
        finally:
            room_queue = self.room_queues[room_id]
@ -500,10 +516,18 @@ class FederationHandler(BaseHandler):
            target_user = UserID.from_string(target_user_id)
            extra_users.append(target_user)
-        yield self.notifier.on_new_room_event(
+        d = self.notifier.on_new_room_event(
            event, extra_users=extra_users
        )
+        def log_failure(f):
+            logger.warn(
+                "Failed to notify about %s: %s",
+                event.event_id, f.value
+            )
+        d.addErrback(log_failure)
        if event.type == EventTypes.Member:
            if event.content["membership"] == Membership.JOIN:
                user = UserID.from_string(event.state_key)
@ -574,10 +598,18 @@ class FederationHandler(BaseHandler):
        )
        target_user = UserID.from_string(event.state_key)
-        yield self.notifier.on_new_room_event(
+        d = self.notifier.on_new_room_event(
            event, extra_users=[target_user],
        )
+        def log_failure(f):
+            logger.warn(
+                "Failed to notify about %s: %s",
+                event.event_id, f.value
+            )
+        d.addErrback(log_failure)
        defer.returnValue(event)
    @defer.inlineCallbacks


@ -33,6 +33,10 @@ logger = logging.getLogger(__name__)
metrics = synapse.metrics.get_metrics_for(__name__)
+# Don't bother bumping "last active" time if it differs by less than 60 seconds
+LAST_ACTIVE_GRANULARITY = 60*1000
# TODO(paul): Maybe there's one of these I can steal from somewhere
def partition(l, func):
    """Partition the list by the result of func applied to each element."""
@ -282,6 +286,10 @@ class PresenceHandler(BaseHandler):
        if now is None:
            now = self.clock.time_msec()
+        prev_state = self._get_or_make_usercache(user)
+        if now - prev_state.state.get("last_active", 0) < LAST_ACTIVE_GRANULARITY:
+            return
        self.changed_presencelike_data(user, {"last_active": now})
    def changed_presencelike_data(self, user, state):
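A sketch of the guard's arithmetic: with ``LAST_ACTIVE_GRANULARITY`` at
60*1000 ms, two activity bumps less than a minute apart collapse into a
single presence update::

    LAST_ACTIVE_GRANULARITY = 60 * 1000  # ms, i.e. 60 seconds
    def should_persist(now_ms, last_active_ms):
        # mirrors the check above: skip bumps < 60 s after the previous one
        return now_ms - last_active_ms >= LAST_ACTIVE_GRANULARITY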


@ -223,6 +223,7 @@ class TypingNotificationEventSource(object):
    def __init__(self, hs):
        self.hs = hs
        self._handler = None
+        self._room_member_handler = None
    def handler(self):
        # Avoid cyclic dependency in handler setup
@ -230,6 +231,11 @@ class TypingNotificationEventSource(object):
            self._handler = self.hs.get_handlers().typing_notification_handler
        return self._handler
+    def room_member_handler(self):
+        if not self._room_member_handler:
+            self._room_member_handler = self.hs.get_handlers().room_member_handler
+        return self._room_member_handler
    def _make_event_for(self, room_id):
        typing = self.handler()._room_typing[room_id]
        return {
@ -240,19 +246,25 @@ class TypingNotificationEventSource(object):
            },
        }
+    @defer.inlineCallbacks
    def get_new_events_for_user(self, user, from_key, limit):
        from_key = int(from_key)
        handler = self.handler()
+        joined_room_ids = (
+            yield self.room_member_handler().get_joined_rooms_for_user(user)
+        )
        events = []
        for room_id in handler._room_serials:
+            if room_id not in joined_room_ids:
+                continue
            if handler._room_serials[room_id] <= from_key:
                continue
-            # TODO: check if user is in room
            events.append(self._make_event_for(room_id))
-        return (events, handler._latest_room_serial)
+        defer.returnValue((events, handler._latest_room_serial))
    def get_current_key(self):
        return self.handler()._latest_room_serial


@ -18,6 +18,8 @@ from __future__ import absolute_import
import logging
from resource import getrusage, getpagesize, RUSAGE_SELF
+import os
+import stat
from .metric import (
    CounterMetric, CallbackMetric, DistributionMetric, CacheMetric
@ -109,3 +111,36 @@ resource_metrics.register_callback("stime", lambda: rusage.ru_stime * 1000)
# pages
resource_metrics.register_callback("maxrss", lambda: rusage.ru_maxrss * PAGE_SIZE)
+TYPES = {
+    stat.S_IFSOCK: "SOCK",
+    stat.S_IFLNK: "LNK",
+    stat.S_IFREG: "REG",
+    stat.S_IFBLK: "BLK",
+    stat.S_IFDIR: "DIR",
+    stat.S_IFCHR: "CHR",
+    stat.S_IFIFO: "FIFO",
+}
+def _process_fds():
+    counts = {(k,): 0 for k in TYPES.values()}
+    counts[("other",)] = 0
+    for fd in os.listdir("/proc/self/fd"):
+        try:
+            s = os.stat("/proc/self/fd/%s" % (fd))
+            fmt = stat.S_IFMT(s.st_mode)
+            if fmt in TYPES:
+                t = TYPES[fmt]
+            else:
+                t = "other"
+            counts[(t,)] += 1
+        except OSError:
+            # the dirh itself used by listdir() is usually missing by now
+            pass
+    return counts
+get_metrics_for("process").register_callback("fds", _process_fds, labels=["type"])


@ -59,10 +59,11 @@ class _NotificationListener(object):
        self.limit = limit
        self.timeout = timeout
        self.deferred = deferred
        self.rooms = rooms
+        self.timer = None
-        self.pending_notifications = []
+    def notified(self):
+        return self.deferred.called
    def notify(self, notifier, events, start_token, end_token):
        """ Inform whoever is listening about the new events. This will
@ -78,16 +79,27 @@ class _NotificationListener(object):
        except defer.AlreadyCalledError:
            pass
+        # Should the following be done be using intrusively linked lists?
+        # -- erikj
        for room in self.rooms:
            lst = notifier.room_to_listeners.get(room, set())
            lst.discard(self)
        notifier.user_to_listeners.get(self.user, set()).discard(self)
        if self.appservice:
            notifier.appservice_to_listeners.get(
                self.appservice, set()
            ).discard(self)
+        # Cancel the timeout for this notifer if one exists.
+        if self.timer is not None:
+            try:
+                notifier.clock.cancel_call_later(self.timer)
+            except:
+                logger.exception("Failed to cancel notifier timer")
class Notifier(object):
    """ This class is responsible for notifying any listeners when there are
@ -161,10 +173,18 @@ class Notifier(object):
        room_source = self.event_sources.sources["room"]
-        listeners = self.room_to_listeners.get(room_id, set()).copy()
+        room_listeners = self.room_to_listeners.get(room_id, set())
+        _discard_if_notified(room_listeners)
+        listeners = room_listeners.copy()
        for user in extra_users:
-            listeners |= self.user_to_listeners.get(user, set()).copy()
+            user_listeners = self.user_to_listeners.get(user, set())
+            _discard_if_notified(user_listeners)
+            listeners |= user_listeners
        for appservice in self.appservice_to_listeners:
            # TODO (kegan): Redundant appservice listener checks?
@ -173,9 +193,13 @@ class Notifier(object):
            # receive *invites* for users they are interested in. Does this
            # make the room_to_listeners check somewhat obselete?
            if appservice.is_interested(event):
-                listeners |= self.appservice_to_listeners.get(
-                    appservice, set()
-                ).copy()
+                app_listeners = self.appservice_to_listeners.get(
+                    appservice, set()
+                )
+                _discard_if_notified(app_listeners)
+                listeners |= app_listeners
        logger.debug("on_new_room_event listeners %s", listeners)
@ -226,10 +250,18 @@ class Notifier(object):
        listeners = set()
        for user in users:
-            listeners |= self.user_to_listeners.get(user, set()).copy()
+            user_listeners = self.user_to_listeners.get(user, set())
+            _discard_if_notified(user_listeners)
+            listeners |= user_listeners
        for room in rooms:
-            listeners |= self.room_to_listeners.get(room, set()).copy()
+            room_listeners = self.room_to_listeners.get(room, set())
+            _discard_if_notified(room_listeners)
+            listeners |= room_listeners
        @defer.inlineCallbacks
        def notify(listener):
@ -300,14 +332,20 @@ class Notifier(object):
            self._register_with_keys(listener[0])
        result = yield callback()
+        timer = [None]
        if timeout:
            timed_out = [False]
            def _timeout_listener():
                timed_out[0] = True
+                timer[0] = None
                listener[0].notify(self, [], from_token, from_token)
-            self.clock.call_later(timeout/1000., _timeout_listener)
+            # We create multiple notification listeners so we have to manage
+            # canceling the timeout ourselves.
+            timer[0] = self.clock.call_later(timeout/1000., _timeout_listener)
        while not result and not timed_out[0]:
            yield deferred
            deferred = defer.Deferred()
@ -322,6 +360,12 @@ class Notifier(object):
                self._register_with_keys(listener[0])
            result = yield callback()
+        if timer[0] is not None:
+            try:
+                self.clock.cancel_call_later(timer[0])
+            except:
+                logger.exception("Failed to cancel notifer timer")
        defer.returnValue(result)
    def get_events_for(self, user, rooms, pagination_config, timeout):
@ -360,6 +404,8 @@ class Notifier(object):
        def _timeout_listener():
            # TODO (erikj): We should probably set to_token to the current
            # max rather than reusing from_token.
+            # Remove the timer from the listener so we don't try to cancel it.
+            listener.timer = None
            listener.notify(
                self,
                [],
@ -375,8 +421,11 @@ class Notifier(object):
        if not timeout:
            _timeout_listener()
        else:
-            self.clock.call_later(timeout/1000.0, _timeout_listener)
+            # Only add the timer if the listener hasn't been notified
+            if not listener.notified():
+                listener.timer = self.clock.call_later(
+                    timeout/1000.0, _timeout_listener
+                )
        return
    @log_function
@ -427,3 +476,17 @@ class Notifier(object):
        listeners = self.room_to_listeners.setdefault(room_id, set())
        listeners |= new_listeners
+        for l in new_listeners:
+            l.rooms.add(room_id)
+def _discard_if_notified(listener_set):
+    """Remove any 'stale' listeners from the given set.
+    """
+    to_discard = set()
+    for l in listener_set:
+        if l.notified():
+            to_discard.add(l)
+    listener_set -= to_discard


@ -4,7 +4,7 @@ from distutils.version import LooseVersion
logger = logging.getLogger(__name__)
REQUIREMENTS = {
-    "syutil>=0.0.3": ["syutil"],
+    "syutil>=0.0.4": ["syutil"],
    "Twisted==14.0.2": ["twisted==14.0.2"],
    "service_identity>=1.0.0": ["service_identity>=1.0.0"],
    "pyopenssl>=0.14": ["OpenSSL>=0.14"],
@ -43,8 +43,8 @@ DEPENDENCY_LINKS = [
    ),
    github_link(
        project="matrix-org/syutil",
-        version="v0.0.3",
-        egg="syutil-0.0.3",
+        version="v0.0.4",
+        egg="syutil-0.0.4",
    ),
    github_link(
        project="matrix-org/matrix-angular-sdk",


@ -1,48 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains base REST classes for constructing client v1 servlets.
"""
from synapse.http.servlet import RestServlet
from synapse.api.urls import APP_SERVICE_PREFIX
import re
import logging
logger = logging.getLogger(__name__)
def as_path_pattern(path_regex):
    """Creates a regex compiled appservice path with the correct path
    prefix.
    Args:
        path_regex (str): The regex string to match. This should NOT have a ^
        as this will be prefixed.
    Returns:
        SRE_Pattern
    """
    return re.compile("^" + APP_SERVICE_PREFIX + path_regex)
class AppServiceRestServlet(RestServlet):
    """A base Synapse REST Servlet for the application services version 1 API.
    """
    def __init__(self, hs):
        self.hs = hs
        self.handler = hs.get_handlers().appservice_handler


@ -1,99 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains REST servlets to do with registration: /register"""
from twisted.internet import defer
from base import AppServiceRestServlet, as_path_pattern
from synapse.api.errors import CodeMessageException, SynapseError
from synapse.storage.appservice import ApplicationService
import json
import logging
logger = logging.getLogger(__name__)
class RegisterRestServlet(AppServiceRestServlet):
    """Handles AS registration with the home server.
    """
    PATTERN = as_path_pattern("/register$")
    @defer.inlineCallbacks
    def on_POST(self, request):
        params = _parse_json(request)
        # sanity check required params
        try:
            as_token = params["as_token"]
            as_url = params["url"]
            if (not isinstance(as_token, basestring) or
                    not isinstance(as_url, basestring)):
                raise ValueError
        except (KeyError, ValueError):
            raise SynapseError(
                400, "Missed required keys: as_token(str) / url(str)."
            )
        try:
            app_service = ApplicationService(
                as_token, as_url, params["namespaces"]
            )
        except ValueError as e:
            raise SynapseError(400, e.message)
        app_service = yield self.handler.register(app_service)
        hs_token = app_service.hs_token
        defer.returnValue((200, {
            "hs_token": hs_token
        }))
class UnregisterRestServlet(AppServiceRestServlet):
    """Handles AS registration with the home server.
    """
    PATTERN = as_path_pattern("/unregister$")
    def on_POST(self, request):
        params = _parse_json(request)
        try:
            as_token = params["as_token"]
            if not isinstance(as_token, basestring):
                raise ValueError
        except (KeyError, ValueError):
            raise SynapseError(400, "Missing required key: as_token(str)")
        yield self.handler.unregister(as_token)
        raise CodeMessageException(500, "Not implemented")
def _parse_json(request):
    try:
        content = json.loads(request.content.read())
        if type(content) != dict:
            raise SynapseError(400, "Content must be a JSON object.")
        return content
    except ValueError as e:
        logger.warn(e)
        raise SynapseError(400, "Content not JSON.")
def register_servlets(hs, http_server):
    RegisterRestServlet(hs).register(http_server)
    UnregisterRestServlet(hs).register(http_server)


@ -12,18 +12,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from . import register
-from synapse.http.server import JsonResource
-class AppServiceRestResource(JsonResource):
-    """A resource for version 1 of the matrix application service API."""
-    def __init__(self, hs):
-        JsonResource.__init__(self, hs)
-        self.register_servlets(self, hs)
-    @staticmethod
-    def register_servlets(appservice_resource, hs):
-        register.register_servlets(hs, appservice_resource)

View File

@ -80,7 +80,6 @@ class BaseHomeServer(object):
'resource_for_content_repo', 'resource_for_content_repo',
'resource_for_server_key', 'resource_for_server_key',
'resource_for_media_repository', 'resource_for_media_repository',
'resource_for_app_services',
'resource_for_metrics', 'resource_for_metrics',
'event_sources', 'event_sources',
'ratelimiter', 'ratelimiter',

View File

@ -14,9 +14,10 @@
# limitations under the License. # limitations under the License.
from twisted.internet import defer from twisted.internet import defer
from .appservice import (
ApplicationServiceStore, ApplicationServiceTransactionStore
)
from ._base import Cache from ._base import Cache
from .appservice import ApplicationServiceStore
from .directory import DirectoryStore from .directory import DirectoryStore
from .events import EventsStore from .events import EventsStore
from .presence import PresenceStore from .presence import PresenceStore
@ -50,7 +51,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database # Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts. # schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 14 SCHEMA_VERSION = 15
dir_path = os.path.abspath(os.path.dirname(__file__)) dir_path = os.path.abspath(os.path.dirname(__file__))
@ -71,6 +72,7 @@ class DataStore(RoomMemberStore, RoomStore,
FilteringStore, FilteringStore,
PusherStore, PusherStore,
PushRuleStore, PushRuleStore,
ApplicationServiceTransactionStore,
EventsStore, EventsStore,
): ):

View File

@ -35,6 +35,7 @@ logger = logging.getLogger(__name__)
sql_logger = logging.getLogger("synapse.storage.SQL") sql_logger = logging.getLogger("synapse.storage.SQL")
transaction_logger = logging.getLogger("synapse.storage.txn") transaction_logger = logging.getLogger("synapse.storage.txn")
perf_logger = logging.getLogger("synapse.storage.TIME")
metrics = synapse.metrics.get_metrics_for("synapse.storage") metrics = synapse.metrics.get_metrics_for("synapse.storage")
@ -55,10 +56,14 @@ cache_counter = metrics.register_cache(
class Cache(object): class Cache(object):
def __init__(self, name, max_entries=1000, keylen=1): def __init__(self, name, max_entries=1000, keylen=1, lru=False):
self.cache = OrderedDict() if lru:
self.cache = LruCache(max_size=max_entries)
self.max_entries = None
else:
self.cache = OrderedDict()
self.max_entries = max_entries
self.max_entries = max_entries
self.name = name self.name = name
self.keylen = keylen self.keylen = keylen
@ -82,8 +87,9 @@ class Cache(object):
if len(keyargs) != self.keylen: if len(keyargs) != self.keylen:
raise ValueError("Expected a key to have %d items", self.keylen) raise ValueError("Expected a key to have %d items", self.keylen)
while len(self.cache) > self.max_entries: if self.max_entries is not None:
self.cache.popitem(last=False) while len(self.cache) >= self.max_entries:
self.cache.popitem(last=False)
self.cache[keyargs] = value self.cache[keyargs] = value
@ -94,9 +100,7 @@ class Cache(object):
self.cache.pop(keyargs, None) self.cache.pop(keyargs, None)
# TODO(paul): def cached(max_entries=1000, num_args=1, lru=False):
# * consider other eviction strategies - LRU?
def cached(max_entries=1000, num_args=1):
""" A method decorator that applies a memoizing cache around the function. """ A method decorator that applies a memoizing cache around the function.
The function is presumed to take zero or more arguments, which are used in The function is presumed to take zero or more arguments, which are used in
@ -115,6 +119,7 @@ def cached(max_entries=1000, num_args=1):
name=orig.__name__, name=orig.__name__,
max_entries=max_entries, max_entries=max_entries,
keylen=num_args, keylen=num_args,
lru=lru,
) )
@functools.wraps(orig) @functools.wraps(orig)
@ -237,10 +242,8 @@ class SQLBaseStore(object):
self._txn_perf_counters = PerformanceCounters() self._txn_perf_counters = PerformanceCounters()
self._get_event_counters = PerformanceCounters() self._get_event_counters = PerformanceCounters()
self._get_event_cache = LruCache(hs.config.event_cache_size) self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True,
max_entries=hs.config.event_cache_size)
# Pretend the getEventCache is just another named cache
caches_by_name["*getEvent*"] = self._get_event_cache
def start_profiling(self): def start_profiling(self):
self._previous_loop_ts = self._clock.time_msec() self._previous_loop_ts = self._clock.time_msec()
@ -264,7 +267,7 @@ class SQLBaseStore(object):
time_now - time_then, limit=3 time_now - time_then, limit=3
) )
logger.info( perf_logger.info(
"Total database time: %.3f%% {%s} {%s}", "Total database time: %.3f%% {%s} {%s}",
ratio * 100, top_three_counters, top_3_event_counters ratio * 100, top_three_counters, top_3_event_counters
) )
@ -728,6 +731,12 @@ class SQLBaseStore(object):
return [e for e in events if e] return [e for e in events if e]
def _invalidate_get_event_cache(self, event_id):
for check_redacted in (False, True):
for get_prev_content in (False, True):
self._get_event_cache.invalidate(event_id, check_redacted,
get_prev_content)
def _get_event_txn(self, txn, event_id, check_redacted=True, def _get_event_txn(self, txn, event_id, check_redacted=True,
get_prev_content=False, allow_rejected=False): get_prev_content=False, allow_rejected=False):
@ -738,16 +747,14 @@ class SQLBaseStore(object):
sql_getevents_timer.inc_by(curr_time - last_time, desc) sql_getevents_timer.inc_by(curr_time - last_time, desc)
return curr_time return curr_time
cache = self._get_event_cache.setdefault(event_id, {})
try: try:
# Separate cache entries for each way to invoke _get_event_txn ret = self._get_event_cache.get(event_id, check_redacted, get_prev_content)
ret = cache[(check_redacted, get_prev_content, allow_rejected)]
cache_counter.inc_hits("*getEvent*") if allow_rejected or not ret.rejected_reason:
return ret return ret
else:
return None
except KeyError: except KeyError:
cache_counter.inc_misses("*getEvent*")
pass pass
finally: finally:
start_time = update_counter("event_cache", start_time) start_time = update_counter("event_cache", start_time)
@ -772,19 +779,22 @@ class SQLBaseStore(object):
start_time = update_counter("select_event", start_time) start_time = update_counter("select_event", start_time)
result = self._get_event_from_row_txn(
txn, internal_metadata, js, redacted,
check_redacted=check_redacted,
get_prev_content=get_prev_content,
rejected_reason=rejected_reason,
)
self._get_event_cache.prefill(event_id, check_redacted, get_prev_content, result)
if allow_rejected or not rejected_reason: if allow_rejected or not rejected_reason:
result = self._get_event_from_row_txn(
txn, internal_metadata, js, redacted,
check_redacted=check_redacted,
get_prev_content=get_prev_content,
)
cache[(check_redacted, get_prev_content, allow_rejected)] = result
return result return result
else: else:
return None return None
def _get_event_from_row_txn(self, txn, internal_metadata, js, redacted, def _get_event_from_row_txn(self, txn, internal_metadata, js, redacted,
check_redacted=True, get_prev_content=False): check_redacted=True, get_prev_content=False,
rejected_reason=None):
start_time = time.time() * 1000 start_time = time.time() * 1000
@ -799,7 +809,11 @@ class SQLBaseStore(object):
internal_metadata = json.loads(internal_metadata) internal_metadata = json.loads(internal_metadata)
start_time = update_counter("decode_internal", start_time) start_time = update_counter("decode_internal", start_time)
ev = FrozenEvent(d, internal_metadata_dict=internal_metadata) ev = FrozenEvent(
d,
internal_metadata_dict=internal_metadata,
rejected_reason=rejected_reason,
)
start_time = update_counter("build_frozen_event", start_time) start_time = update_counter("build_frozen_event", start_time)
if check_redacted and redacted: if check_redacted and redacted:
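The Cache changes above boil down to: FIFO eviction via an OrderedDict by default, delegation to LruCache when lru=True. A self-contained sketch of that LRU eviction behaviour, using a simplified stand-in rather than synapse's actual Cache/LruCache classes:

from collections import OrderedDict

class SimpleLruDict(object):
    """Minimal LRU map mirroring what Cache(lru=True) gets from LruCache."""
    def __init__(self, max_entries):
        self.max_entries = max_entries
        self.data = OrderedDict()

    def get(self, key):
        value = self.data.pop(key)  # KeyError on miss, like Cache.get
        self.data[key] = value      # re-insert: key becomes most-recently-used
        return value

    def set(self, key, value):
        self.data.pop(key, None)
        while len(self.data) >= self.max_entries:
            self.data.popitem(last=False)  # evict the least-recently-used entry
        self.data[key] = value

cache = SimpleLruDict(max_entries=2)
cache.set(1, "one")
cache.set(2, "two")
cache.get(1)           # touch 1, so 2 is now least-recently-used
cache.set(3, "three")  # evicts 2, as in test_eviction_lru later in this commit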

View File

@ -13,154 +13,35 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
import simplejson import urllib
import yaml
from simplejson import JSONDecodeError from simplejson import JSONDecodeError
import simplejson as json
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import Membership from synapse.api.constants import Membership
from synapse.api.errors import StoreError from synapse.appservice import ApplicationService, AppServiceTransaction
from synapse.appservice import ApplicationService
from synapse.storage.roommember import RoomsForUser from synapse.storage.roommember import RoomsForUser
from synapse.types import UserID
from ._base import SQLBaseStore from ._base import SQLBaseStore
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def log_failure(failure):
logger.error("Failed to detect application services: %s", failure.value)
logger.error(failure.getTraceback())
class ApplicationServiceStore(SQLBaseStore): class ApplicationServiceStore(SQLBaseStore):
def __init__(self, hs): def __init__(self, hs):
super(ApplicationServiceStore, self).__init__(hs) super(ApplicationServiceStore, self).__init__(hs)
self.hostname = hs.hostname
self.services_cache = [] self.services_cache = []
self.cache_defer = self._populate_cache() self._populate_appservice_cache(
self.cache_defer.addErrback(log_failure) hs.config.app_service_config_files
@defer.inlineCallbacks
def unregister_app_service(self, token):
"""Unregisters this service.
This removes all AS specific regex and the base URL. The token is the
only thing preserved for future registration attempts.
"""
yield self.cache_defer # make sure the cache is ready
yield self.runInteraction(
"unregister_app_service",
self._unregister_app_service_txn,
token,
)
# update cache TODO: Should this be in the txn?
for service in self.services_cache:
if service.token == token:
service.url = None
service.namespaces = None
service.hs_token = None
def _unregister_app_service_txn(self, txn, token):
# kill the url to prevent pushes
txn.execute(
"UPDATE application_services SET url=NULL WHERE token=?",
(token,)
) )
# cleanup regex
as_id = self._get_as_id_txn(txn, token)
if not as_id:
logger.warning(
"unregister_app_service_txn: Failed to find as_id for token=",
token
)
return False
txn.execute(
"DELETE FROM application_services_regex WHERE as_id=?",
(as_id,)
)
return True
@defer.inlineCallbacks
def update_app_service(self, service):
"""Update an application service, clobbering what was previously there.
Args:
service(ApplicationService): The updated service.
"""
yield self.cache_defer # make sure the cache is ready
# NB: There is no "insert" since we provide no public-facing API to
# allocate new ASes. It relies on the server admin inserting the AS
# token into the database manually.
if not service.token or not service.url:
raise StoreError(400, "Token and url must be specified.")
if not service.hs_token:
raise StoreError(500, "No HS token")
yield self.runInteraction(
"update_app_service",
self._update_app_service_txn,
service
)
# update cache TODO: Should this be in the txn?
for (index, cache_service) in enumerate(self.services_cache):
if service.token == cache_service.token:
self.services_cache[index] = service
logger.info("Updated: %s", service)
return
# new entry
self.services_cache.append(service)
logger.info("Updated(new): %s", service)
def _update_app_service_txn(self, txn, service):
as_id = self._get_as_id_txn(txn, service.token)
if not as_id:
logger.warning(
"update_app_service_txn: Failed to find as_id for token=",
service.token
)
return False
txn.execute(
"UPDATE application_services SET url=?, hs_token=?, sender=? "
"WHERE id=?",
(service.url, service.hs_token, service.sender, as_id,)
)
# cleanup regex
txn.execute(
"DELETE FROM application_services_regex WHERE as_id=?",
(as_id,)
)
for (ns_int, ns_str) in enumerate(ApplicationService.NS_LIST):
if ns_str in service.namespaces:
for regex_obj in service.namespaces[ns_str]:
txn.execute(
"INSERT INTO application_services_regex("
"as_id, namespace, regex) values(?,?,?)",
(as_id, ns_int, simplejson.dumps(regex_obj))
)
return True
def _get_as_id_txn(self, txn, token):
cursor = txn.execute(
"SELECT id FROM application_services WHERE token=?",
(token,)
)
res = cursor.fetchone()
if res:
return res[0]
@defer.inlineCallbacks
def get_app_services(self): def get_app_services(self):
yield self.cache_defer # make sure the cache is ready return defer.succeed(self.services_cache)
defer.returnValue(self.services_cache)
@defer.inlineCallbacks
def get_app_service_by_user_id(self, user_id): def get_app_service_by_user_id(self, user_id):
"""Retrieve an application service from their user ID. """Retrieve an application service from their user ID.
@ -174,37 +55,23 @@ class ApplicationServiceStore(SQLBaseStore):
Returns: Returns:
synapse.appservice.ApplicationService or None. synapse.appservice.ApplicationService or None.
""" """
yield self.cache_defer # make sure the cache is ready
for service in self.services_cache: for service in self.services_cache:
if service.sender == user_id: if service.sender == user_id:
defer.returnValue(service) return defer.succeed(service)
return return defer.succeed(None)
defer.returnValue(None)
@defer.inlineCallbacks def get_app_service_by_token(self, token):
def get_app_service_by_token(self, token, from_cache=True):
"""Get the application service with the given appservice token. """Get the application service with the given appservice token.
Args: Args:
token (str): The application service token. token (str): The application service token.
from_cache (bool): True to get this service from the cache, False to Returns:
check the database. synapse.appservice.ApplicationService or None.
Raises:
StoreError if there was a problem retrieving this service.
""" """
yield self.cache_defer # make sure the cache is ready for service in self.services_cache:
if service.token == token:
if from_cache: return defer.succeed(service)
for service in self.services_cache: return defer.succeed(None)
if service.token == token:
defer.returnValue(service)
return
defer.returnValue(None)
# TODO: The from_cache=False impl
# TODO: This should be JOINed with the application_services_regex table.
def get_app_service_rooms(self, service): def get_app_service_rooms(self, service):
"""Get a list of RoomsForUser for this application service. """Get a list of RoomsForUser for this application service.
@ -277,12 +144,7 @@ class ApplicationServiceStore(SQLBaseStore):
return rooms_for_user_matching_user_id return rooms_for_user_matching_user_id
@defer.inlineCallbacks def _parse_services_dict(self, results):
def _populate_cache(self):
"""Populates the ApplicationServiceCache from the database."""
sql = ("SELECT * FROM application_services LEFT JOIN "
"application_services_regex ON application_services.id = "
"application_services_regex.as_id")
# SQL results in the form: # SQL results in the form:
# [ # [
# { # {
@ -296,12 +158,14 @@ class ApplicationServiceStore(SQLBaseStore):
# } # }
# ] # ]
services = {} services = {}
results = yield self._execute_and_decode("_populate_cache", sql)
for res in results: for res in results:
as_token = res["token"] as_token = res["token"]
if as_token is None:
continue
if as_token not in services: if as_token not in services:
# add the service # add the service
services[as_token] = { services[as_token] = {
"id": res["id"],
"url": res["url"], "url": res["url"],
"token": as_token, "token": as_token,
"hs_token": res["hs_token"], "hs_token": res["hs_token"],
@ -319,20 +183,287 @@ class ApplicationServiceStore(SQLBaseStore):
try: try:
services[as_token]["namespaces"][ services[as_token]["namespaces"][
ApplicationService.NS_LIST[ns_int]].append( ApplicationService.NS_LIST[ns_int]].append(
simplejson.loads(res["regex"]) json.loads(res["regex"])
) )
except IndexError: except IndexError:
logger.error("Bad namespace enum '%s'. %s", ns_int, res) logger.error("Bad namespace enum '%s'. %s", ns_int, res)
except JSONDecodeError: except JSONDecodeError:
logger.error("Bad regex object '%s'", res["regex"]) logger.error("Bad regex object '%s'", res["regex"])
# TODO get last successful txn id f.e. service service_list = []
for service in services.values(): for service in services.values():
logger.info("Found application service: %s", service) service_list.append(ApplicationService(
self.services_cache.append(ApplicationService(
token=service["token"], token=service["token"],
url=service["url"], url=service["url"],
namespaces=service["namespaces"], namespaces=service["namespaces"],
hs_token=service["hs_token"], hs_token=service["hs_token"],
sender=service["sender"] sender=service["sender"],
id=service["id"]
)) ))
return service_list
def _load_appservice(self, as_info):
required_string_fields = [
"url", "as_token", "hs_token", "sender_localpart"
]
for field in required_string_fields:
if not isinstance(as_info.get(field), basestring):
raise KeyError("Required string field: '%s'", field)
localpart = as_info["sender_localpart"]
if urllib.quote(localpart) != localpart:
raise ValueError(
"sender_localpart needs characters which are not URL encoded."
)
user = UserID(localpart, self.hostname)
user_id = user.to_string()
# namespace checks
if not isinstance(as_info.get("namespaces"), dict):
raise KeyError("Requires 'namespaces' object.")
for ns in ApplicationService.NS_LIST:
# specific namespaces are optional
if ns in as_info["namespaces"]:
# expect a list of dicts with exclusive and regex keys
for regex_obj in as_info["namespaces"][ns]:
if not isinstance(regex_obj, dict):
raise ValueError(
"Expected namespace entry in %s to be an object,"
" but got %s", ns, regex_obj
)
if not isinstance(regex_obj.get("regex"), basestring):
raise ValueError(
"Missing/bad type 'regex' key in %s", regex_obj
)
if not isinstance(regex_obj.get("exclusive"), bool):
raise ValueError(
"Missing/bad type 'exclusive' key in %s", regex_obj
)
return ApplicationService(
token=as_info["as_token"],
url=as_info["url"],
namespaces=as_info["namespaces"],
hs_token=as_info["hs_token"],
sender=user_id,
id=as_info["as_token"] # the token is the only unique thing here
)
def _populate_appservice_cache(self, config_files):
"""Populates a cache of Application Services from the config files."""
if not isinstance(config_files, list):
logger.warning(
"Expected %s to be a list of AS config files.", config_files
)
return
for config_file in config_files:
try:
with open(config_file, 'r') as f:
appservice = self._load_appservice(yaml.load(f))
logger.info("Loaded application service: %s", appservice)
self.services_cache.append(appservice)
except Exception as e:
logger.error("Failed to load appservice from '%s'", config_file)
logger.exception(e)
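# Registration now comes from config files rather than the database: each
# file must be one YAML document with the four required string fields plus
# a 'namespaces' object, exactly as validated by _load_appservice above.
# A hedged sketch of producing such a file (all values are placeholders):
import yaml
as_config = {
    "url": "http://localhost:9000",   # where the HS should push events
    "as_token": "some_as_token",
    "hs_token": "some_hs_token",
    "sender_localpart": "my_bridge",  # must survive urllib.quote() unchanged
    "namespaces": {
        # each entry needs both a 'regex' string and an 'exclusive' bool
        "users": [{"regex": "@irc_.*", "exclusive": True}],
        # 'aliases' and 'rooms' may be omitted entirely
    },
}
with open("my_bridge.yaml", "w") as f:
    yaml.dump(as_config, f)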
class ApplicationServiceTransactionStore(SQLBaseStore):
def __init__(self, hs):
super(ApplicationServiceTransactionStore, self).__init__(hs)
@defer.inlineCallbacks
def get_appservices_by_state(self, state):
"""Get a list of application services based on their state.
Args:
state(ApplicationServiceState): The state to filter on.
Returns:
A Deferred which resolves to a list of ApplicationServices, which
may be empty.
"""
results = yield self._simple_select_list(
"application_services_state",
dict(state=state),
["as_id"]
)
# NB: This assumes this class is linked with ApplicationServiceStore
as_list = yield self.get_app_services()
services = []
for res in results:
for service in as_list:
if service.id == res["as_id"]:
services.append(service)
defer.returnValue(services)
@defer.inlineCallbacks
def get_appservice_state(self, service):
"""Get the application service state.
Args:
service(ApplicationService): The service whose state to retrieve.
Returns:
A Deferred which resolves to ApplicationServiceState.
"""
result = yield self._simple_select_one(
"application_services_state",
dict(as_id=service.id),
["state"],
allow_none=True
)
if result:
defer.returnValue(result.get("state"))
return
defer.returnValue(None)
def set_appservice_state(self, service, state):
"""Set the application service state.
Args:
service(ApplicationService): The service whose state to set.
state(ApplicationServiceState): The connectivity state to apply.
Returns:
A Deferred which resolves when the state was set successfully.
"""
return self._simple_upsert(
"application_services_state",
dict(as_id=service.id),
dict(state=state)
)
def create_appservice_txn(self, service, events):
"""Atomically creates a new transaction for this application service
with the given list of events.
Args:
service(ApplicationService): The service the transaction is for.
events(list<Event>): A list of events to put in the transaction.
Returns:
AppServiceTransaction: A new transaction.
"""
return self.runInteraction(
"create_appservice_txn",
self._create_appservice_txn,
service, events
)
def _create_appservice_txn(self, txn, service, events):
# work out new txn id (highest txn id for this service += 1)
# The highest id may be the last one sent (in which case it is last_txn)
# or it may be the highest in the txns list (which are waiting to be/are
# being sent)
last_txn_id = self._get_last_txn(txn, service.id)
result = txn.execute(
"SELECT MAX(txn_id) FROM application_services_txns WHERE as_id=?",
(service.id,)
)
highest_txn_id = result.fetchone()[0]
if highest_txn_id is None:
highest_txn_id = 0
new_txn_id = max(highest_txn_id, last_txn_id) + 1
# Insert new txn into txn table
event_ids = [e.event_id for e in events]
txn.execute(
"INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
"VALUES(?,?,?)",
(service.id, new_txn_id, json.dumps(event_ids))
)
return AppServiceTransaction(
service=service, id=new_txn_id, events=events
)
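# A concrete instance of the id rule above, using the numbers from
# test_create_appservice_txn_older_last_txn later in this commit:
last_txn_id = 9643     # last txn actually delivered to this AS
highest_txn_id = 9645  # MAX(txn_id) over its queued, still-unsent txns
new_txn_id = max(highest_txn_id, last_txn_id) + 1
assert new_txn_id == 9646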
def complete_appservice_txn(self, txn_id, service):
"""Completes an application service transaction.
Args:
txn_id(str): The transaction ID being completed.
service(ApplicationService): The application service which was sent
this transaction.
Returns:
A Deferred which resolves when this transaction has been
completed successfully.
"""
return self.runInteraction(
"complete_appservice_txn",
self._complete_appservice_txn,
txn_id, service
)
def _complete_appservice_txn(self, txn, txn_id, service):
txn_id = int(txn_id)
# Debugging query: Make sure the txn being completed is EXACTLY +1 from
# what was there before. If it isn't, we've got problems (e.g. the AS
# has probably missed some events), so whine loudly but still continue,
# since it shouldn't fail completion of the transaction.
last_txn_id = self._get_last_txn(txn, service.id)
if (last_txn_id + 1) != txn_id:
logger.error(
"appservice: Completing a transaction which has an ID > 1 from "
"the last ID sent to this AS. We've either dropped events or "
"sent it to the AS out of order. FIX ME. last_txn=%s "
"completing_txn=%s service_id=%s", last_txn_id, txn_id,
service.id
)
# Set current txn_id for AS to 'txn_id'
self._simple_upsert_txn(
txn, "application_services_state", dict(as_id=service.id),
dict(last_txn=txn_id)
)
# Delete txn
self._simple_delete_txn(
txn, "application_services_txns",
dict(txn_id=txn_id, as_id=service.id)
)
def get_oldest_unsent_txn(self, service):
"""Get the oldest transaction which has not been sent for this
service.
Args:
service(ApplicationService): The app service whose oldest unsent txn is wanted.
Returns:
A Deferred which resolves to an AppServiceTransaction or
None.
"""
return self.runInteraction(
"get_oldest_unsent_appservice_txn",
self._get_oldest_unsent_txn,
service
)
def _get_oldest_unsent_txn(self, txn, service):
# Monotonically increasing txn ids, so just select the smallest
# one in the txns table (we delete them when they are sent)
result = txn.execute(
"SELECT MIN(txn_id), * FROM application_services_txns WHERE as_id=?",
(service.id,)
)
entry = self.cursor_to_dict(result)[0]
if not entry or entry["txn_id"] is None:
# the MIN(txn_id) aggregate forces a row even when there are no txns,
# so entry itself is never None; txn_id will be NULL instead
return None
event_ids = json.loads(entry["event_ids"])
events = self._get_events_txn(txn, event_ids)
return AppServiceTransaction(
service=service, id=entry["txn_id"], events=events
)
def _get_last_txn(self, txn, service_id):
result = txn.execute(
"SELECT last_txn FROM application_services_state WHERE as_id=?",
(service_id,)
)
last_txn_id = result.fetchone()
if last_txn_id is None or last_txn_id[0] is None: # no row exists
return 0
else:
return int(last_txn_id[0]) # select 'last_txn' col

View File

@ -94,7 +94,7 @@ class EventsStore(SQLBaseStore):
current_state=None): current_state=None):
# Remove any existing cache entries for the event_id # Remove any existing cache entries for the event_id
self._get_event_cache.pop(event.event_id) self._invalidate_get_event_cache(event.event_id)
# We purposefully do this first since if we include a `current_state` # We purposefully do this first since if we include a `current_state`
# key, we *want* to update the `current_state_events` table # key, we *want* to update the `current_state_events` table
@ -356,7 +356,7 @@ class EventsStore(SQLBaseStore):
def _store_redaction(self, txn, event): def _store_redaction(self, txn, event):
# invalidate the cache for the redacted event # invalidate the cache for the redacted event
self._get_event_cache.pop(event.redacts) self._invalidate_get_event_cache(event.redacts)
txn.execute( txn.execute(
"INSERT INTO redactions (event_id, redacts) VALUES (?,?)", "INSERT INTO redactions (event_id, redacts) VALUES (?,?)",
(event.event_id, event.redacts) (event.event_id, event.redacts)

View File

@ -0,0 +1,30 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS application_services_state(
as_id TEXT PRIMARY KEY,
state TEXT,
last_txn TEXT
);
CREATE TABLE IF NOT EXISTS application_services_txns(
as_id TEXT NOT NULL,
txn_id INTEGER NOT NULL,
event_ids TEXT NOT NULL,
UNIQUE(as_id, txn_id) ON CONFLICT ROLLBACK
);
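A standalone sqlite3 sketch of how these two tables are meant to be used (schema copied from above; the token and event ids are made up):

import json
import sqlite3

db = sqlite3.connect(":memory:")
db.executescript("""
    CREATE TABLE IF NOT EXISTS application_services_state(
        as_id TEXT PRIMARY KEY,
        state TEXT,
        last_txn TEXT
    );
    CREATE TABLE IF NOT EXISTS application_services_txns(
        as_id TEXT NOT NULL,
        txn_id INTEGER NOT NULL,
        event_ids TEXT NOT NULL,
        UNIQUE(as_id, txn_id) ON CONFLICT ROLLBACK
    );
""")
db.execute(
    "INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
    "VALUES(?,?,?)",
    ("token1", 1, json.dumps(["$e1:hs", "$e2:hs"]))
)
# Oldest unsent txn: ids are monotonic and rows are deleted once sent,
# so MIN(txn_id) is always the next txn to (re)send.
row = db.execute(
    "SELECT MIN(txn_id), event_ids FROM application_services_txns "
    "WHERE as_id=?", ("token1",)
).fetchone()
print(row)  # (1, '["$e1:hs", "$e2:hs"]')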

View File

@ -65,7 +65,7 @@ class ExpiringCache(object):
if self._max_len and len(self._cache.keys()) > self._max_len: if self._max_len and len(self._cache.keys()) > self._max_len:
sorted_entries = sorted( sorted_entries = sorted(
self._cache.items(), self._cache.items(),
key=lambda k, v: v.time, key=lambda (k, v): v.time,
) )
for k, _ in sorted_entries[self._max_len:]: for k, _ in sorted_entries[self._max_len:]:

View File

@ -90,12 +90,16 @@ class LruCache(object):
def cache_len(): def cache_len():
return len(cache) return len(cache)
def cache_contains(key):
return key in cache
self.sentinel = object() self.sentinel = object()
self.get = cache_get self.get = cache_get
self.set = cache_set self.set = cache_set
self.setdefault = cache_set_default self.setdefault = cache_set_default
self.pop = cache_pop self.pop = cache_pop
self.len = cache_len self.len = cache_len
self.contains = cache_contains
def __getitem__(self, key): def __getitem__(self, key):
result = self.get(key, self.sentinel) result = self.get(key, self.sentinel)
@ -114,3 +118,6 @@ class LruCache(object):
def __len__(self): def __len__(self):
return self.len() return self.len()
def __contains__(self, key):
return self.contains(key)
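The contains/cache_contains plumbing just makes the `in` operator work on an LruCache. A minimal usage sketch (constructed with max_size as SQLBaseStore does earlier in this commit; the import path is an assumption):

from synapse.util.lrucache import LruCache  # import path assumed

cache = LruCache(max_size=2)
cache.set("event_id", "event")
if "event_id" in cache:  # dispatches to cache_contains() via __contains__
    print(cache.get("event_id"))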

View File

@ -199,6 +199,19 @@ class ApplicationServiceTestCase(unittest.TestCase):
aliases_for_event=["#xmpp_barfoo:matrix.org"] aliases_for_event=["#xmpp_barfoo:matrix.org"]
)) ))
def test_interested_in_self(self):
# make sure invites get through
self.service.sender = "@appservice:name"
self.service.namespaces[ApplicationService.NS_USERS].append(
_regex("@irc_.*")
)
self.event.type = "m.room.member"
self.event.content = {
"membership": "invite"
}
self.event.state_key = self.service.sender
self.assertTrue(self.service.is_interested(self.event))
def test_member_list_match(self): def test_member_list_match(self):
self.service.namespaces[ApplicationService.NS_USERS].append( self.service.namespaces[ApplicationService.NS_USERS].append(
_regex("@irc_.*") _regex("@irc_.*")

View File

@ -0,0 +1,252 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.appservice import ApplicationServiceState, AppServiceTransaction
from synapse.appservice.scheduler import (
_ServiceQueuer, _TransactionController, _Recoverer
)
from twisted.internet import defer
from ..utils import MockClock
from mock import Mock
from tests import unittest
class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
def setUp(self):
self.clock = MockClock()
self.store = Mock()
self.as_api = Mock()
self.recoverer = Mock()
self.recoverer_fn = Mock(return_value=self.recoverer)
self.txnctrl = _TransactionController(
clock=self.clock, store=self.store, as_api=self.as_api,
recoverer_fn=self.recoverer_fn
)
def test_single_service_up_txn_sent(self):
# Test: The AS is up and the txn is successfully sent.
service = Mock()
events = [Mock(), Mock()]
txn_id = "foobar"
txn = Mock(id=txn_id, service=service, events=events)
# mock methods
self.store.get_appservice_state = Mock(
return_value=defer.succeed(ApplicationServiceState.UP)
)
txn.send = Mock(return_value=defer.succeed(True))
self.store.create_appservice_txn = Mock(
return_value=defer.succeed(txn)
)
# actual call
self.txnctrl.send(service, events)
self.store.create_appservice_txn.assert_called_once_with(
service=service, events=events # txn made and saved
)
self.assertEquals(0, len(self.txnctrl.recoverers)) # no recoverer made
txn.complete.assert_called_once_with(self.store) # txn completed
def test_single_service_down(self):
# Test: The AS is down so it shouldn't push; Recoverers will do it.
# It should still make a transaction though.
service = Mock()
events = [Mock(), Mock()]
txn = Mock(id="idhere", service=service, events=events)
self.store.get_appservice_state = Mock(
return_value=defer.succeed(ApplicationServiceState.DOWN)
)
self.store.create_appservice_txn = Mock(
return_value=defer.succeed(txn)
)
# actual call
self.txnctrl.send(service, events)
self.store.create_appservice_txn.assert_called_once_with(
service=service, events=events # txn made and saved
)
self.assertEquals(0, txn.send.call_count) # txn not sent though
self.assertEquals(0, txn.complete.call_count) # or completed
def test_single_service_up_txn_not_sent(self):
# Test: The AS is up and the txn is not sent. A Recoverer is made and
# started.
service = Mock()
events = [Mock(), Mock()]
txn_id = "foobar"
txn = Mock(id=txn_id, service=service, events=events)
# mock methods
self.store.get_appservice_state = Mock(
return_value=defer.succeed(ApplicationServiceState.UP)
)
self.store.set_appservice_state = Mock(return_value=defer.succeed(True))
txn.send = Mock(return_value=defer.succeed(False)) # fails to send
self.store.create_appservice_txn = Mock(
return_value=defer.succeed(txn)
)
# actual call
self.txnctrl.send(service, events)
self.store.create_appservice_txn.assert_called_once_with(
service=service, events=events
)
self.assertEquals(1, self.recoverer_fn.call_count) # recoverer made
self.assertEquals(1, self.recoverer.recover.call_count) # and invoked
self.assertEquals(1, len(self.txnctrl.recoverers)) # and stored
self.assertEquals(0, txn.complete.call_count) # txn not completed
self.store.set_appservice_state.assert_called_once_with(
service, ApplicationServiceState.DOWN # service marked as down
)
class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
def setUp(self):
self.clock = MockClock()
self.as_api = Mock()
self.store = Mock()
self.service = Mock()
self.callback = Mock()
self.recoverer = _Recoverer(
clock=self.clock,
as_api=self.as_api,
store=self.store,
service=self.service,
callback=self.callback,
)
def test_recover_single_txn(self):
txn = Mock()
# return one txn to send, then no more old txns
txns = [txn, None]
def take_txn(*args, **kwargs):
return defer.succeed(txns.pop(0))
self.store.get_oldest_unsent_txn = Mock(side_effect=take_txn)
self.recoverer.recover()
# shouldn't have called anything prior to waiting for exp backoff
self.assertEquals(0, self.store.get_oldest_unsent_txn.call_count)
txn.send = Mock(return_value=True)
# wait for exp backoff
self.clock.advance_time(2)
self.assertEquals(1, txn.send.call_count)
self.assertEquals(1, txn.complete.call_count)
# 2 because it needs to get None to know there are no more txns
self.assertEquals(2, self.store.get_oldest_unsent_txn.call_count)
self.callback.assert_called_once_with(self.recoverer)
self.assertEquals(self.recoverer.service, self.service)
def test_recover_retry_txn(self):
txn = Mock()
txns = [txn, None]
pop_txn = False
def take_txn(*args, **kwargs):
if pop_txn:
return defer.succeed(txns.pop(0))
else:
return defer.succeed(txn)
self.store.get_oldest_unsent_txn = Mock(side_effect=take_txn)
self.recoverer.recover()
self.assertEquals(0, self.store.get_oldest_unsent_txn.call_count)
txn.send = Mock(return_value=False)
self.clock.advance_time(2)
self.assertEquals(1, txn.send.call_count)
self.assertEquals(0, txn.complete.call_count)
self.assertEquals(0, self.callback.call_count)
self.clock.advance_time(4)
self.assertEquals(2, txn.send.call_count)
self.assertEquals(0, txn.complete.call_count)
self.assertEquals(0, self.callback.call_count)
self.clock.advance_time(8)
self.assertEquals(3, txn.send.call_count)
self.assertEquals(0, txn.complete.call_count)
self.assertEquals(0, self.callback.call_count)
txn.send = Mock(return_value=True) # successfully send the txn
pop_txn = True # returns the txn the first time, then no more.
self.clock.advance_time(16)
self.assertEquals(1, txn.send.call_count) # new mock reset call count
self.assertEquals(1, txn.complete.call_count)
self.callback.assert_called_once_with(self.recoverer)
class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
def setUp(self):
self.txn_ctrl = Mock()
self.queuer = _ServiceQueuer(self.txn_ctrl)
def test_send_single_event_no_queue(self):
# Expect the event to be sent immediately.
service = Mock(id=4)
event = Mock()
self.queuer.enqueue(service, event)
self.txn_ctrl.send.assert_called_once_with(service, [event])
def test_send_single_event_with_queue(self):
d = defer.Deferred()
self.txn_ctrl.send = Mock(return_value=d)
service = Mock(id=4)
event = Mock(event_id="first")
event2 = Mock(event_id="second")
event3 = Mock(event_id="third")
# Send an event and don't resolve it just yet.
self.queuer.enqueue(service, event)
# Send more events: expect send() to NOT be called multiple times.
self.queuer.enqueue(service, event2)
self.queuer.enqueue(service, event3)
self.txn_ctrl.send.assert_called_with(service, [event])
self.assertEquals(1, self.txn_ctrl.send.call_count)
# Resolve the send event: expect the queued events to be sent
d.callback(service)
self.txn_ctrl.send.assert_called_with(service, [event2, event3])
self.assertEquals(2, self.txn_ctrl.send.call_count)
def test_multiple_service_queues(self):
# Tests that each service has its own queue, and that they don't block
# on each other.
srv1 = Mock(id=4)
srv_1_defer = defer.Deferred()
srv_1_event = Mock(event_id="srv1a")
srv_1_event2 = Mock(event_id="srv1b")
srv2 = Mock(id=6)
srv_2_defer = defer.Deferred()
srv_2_event = Mock(event_id="srv2a")
srv_2_event2 = Mock(event_id="srv2b")
send_return_list = [srv_1_defer, srv_2_defer]
self.txn_ctrl.send = Mock(side_effect=lambda x, y: send_return_list.pop(0))
# send events for different ASes and make sure they are sent
self.queuer.enqueue(srv1, srv_1_event)
self.queuer.enqueue(srv1, srv_1_event2)
self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event])
self.queuer.enqueue(srv2, srv_2_event)
self.queuer.enqueue(srv2, srv_2_event2)
self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event])
# make sure callbacks for a service only send queued events for THAT
# service
srv_2_defer.callback(srv2)
self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2])
self.assertEquals(3, self.txn_ctrl.send.call_count)
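For readers of these tests: the queueing contract they pin down is one in-flight transaction per service, with later events accumulating until the previous send resolves, and independent queues per service. A toy re-sketch of that contract, not synapse's _ServiceQueuer, assuming send_fn returns a twisted Deferred as the mocked txn_ctrl.send above does:

class ToyQueuer(object):
    def __init__(self, send_fn):
        self.send_fn = send_fn   # called as send_fn(service, [event, ...])
        self.queues = {}         # service.id -> events awaiting the next txn
        self.in_flight = set()   # service.ids with a txn currently being sent

    def enqueue(self, service, event):
        self.queues.setdefault(service.id, []).append(event)
        if service.id not in self.in_flight:
            self._send_next(service)

    def _send_next(self, service):
        events = self.queues.get(service.id, [])
        self.queues[service.id] = []
        if not events:
            self.in_flight.discard(service.id)
            return
        self.in_flight.add(service.id)
        d = self.send_fn(service, events)  # one txn for everything queued
        d.addCallback(lambda _: self._send_next(service))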

View File

@ -27,10 +27,11 @@ class AppServiceHandlerTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.mock_store = Mock() self.mock_store = Mock()
self.mock_as_api = Mock() self.mock_as_api = Mock()
self.mock_scheduler = Mock()
hs = Mock() hs = Mock()
hs.get_datastore = Mock(return_value=self.mock_store) hs.get_datastore = Mock(return_value=self.mock_store)
self.handler = ApplicationServicesHandler( self.handler = ApplicationServicesHandler(
hs, self.mock_as_api hs, self.mock_as_api, self.mock_scheduler
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -52,7 +53,9 @@ class AppServiceHandlerTestCase(unittest.TestCase):
) )
self.mock_as_api.push = Mock() self.mock_as_api.push = Mock()
yield self.handler.notify_interested_services(event) yield self.handler.notify_interested_services(event)
self.mock_as_api.push.assert_called_once_with(interested_service, event) self.mock_scheduler.submit_event_for_as.assert_called_once_with(
interested_service, event
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_query_room_alias_exists(self): def test_query_room_alias_exists(self):

View File

@ -126,6 +126,13 @@ class TypingNotificationsTestCase(unittest.TestCase):
return defer.succeed([]) return defer.succeed([])
self.room_member_handler.get_room_members = get_room_members self.room_member_handler.get_room_members = get_room_members
def get_joined_rooms_for_user(user):
if user in self.room_members:
return defer.succeed([self.room_id])
else:
return defer.succeed([])
self.room_member_handler.get_joined_rooms_for_user = get_joined_rooms_for_user
@defer.inlineCallbacks @defer.inlineCallbacks
def fetch_room_distributions_into(room_id, localusers=None, def fetch_room_distributions_into(room_id, localusers=None,
remotedomains=None, ignore_user=None): remotedomains=None, ignore_user=None):
@ -175,8 +182,9 @@ class TypingNotificationsTestCase(unittest.TestCase):
]) ])
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None)
self.assertEquals( self.assertEquals(
self.event_source.get_new_events_for_user(self.u_apple, 0, None)[0], events[0],
[ [
{"type": "m.typing", {"type": "m.typing",
"room_id": self.room_id, "room_id": self.room_id,
@ -237,8 +245,9 @@ class TypingNotificationsTestCase(unittest.TestCase):
]) ])
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None)
self.assertEquals( self.assertEquals(
self.event_source.get_new_events_for_user(self.u_apple, 0, None)[0], events[0],
[ [
{"type": "m.typing", {"type": "m.typing",
"room_id": self.room_id, "room_id": self.room_id,
@ -292,8 +301,9 @@ class TypingNotificationsTestCase(unittest.TestCase):
yield put_json.await_calls() yield put_json.await_calls()
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None)
self.assertEquals( self.assertEquals(
self.event_source.get_new_events_for_user(self.u_apple, 0, None)[0], events[0],
[ [
{"type": "m.typing", {"type": "m.typing",
"room_id": self.room_id, "room_id": self.room_id,
@ -322,8 +332,9 @@ class TypingNotificationsTestCase(unittest.TestCase):
self.on_new_user_event.reset_mock() self.on_new_user_event.reset_mock()
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None)
self.assertEquals( self.assertEquals(
self.event_source.get_new_events_for_user(self.u_apple, 0, None)[0], events[0],
[ [
{"type": "m.typing", {"type": "m.typing",
"room_id": self.room_id, "room_id": self.room_id,
@ -340,8 +351,9 @@ class TypingNotificationsTestCase(unittest.TestCase):
]) ])
self.assertEquals(self.event_source.get_current_key(), 2) self.assertEquals(self.event_source.get_current_key(), 2)
events = yield self.event_source.get_new_events_for_user(self.u_apple, 1, None)
self.assertEquals( self.assertEquals(
self.event_source.get_new_events_for_user(self.u_apple, 1, None)[0], events[0],
[ [
{"type": "m.typing", {"type": "m.typing",
"room_id": self.room_id, "room_id": self.room_id,
@ -366,8 +378,9 @@ class TypingNotificationsTestCase(unittest.TestCase):
self.on_new_user_event.reset_mock() self.on_new_user_event.reset_mock()
self.assertEquals(self.event_source.get_current_key(), 3) self.assertEquals(self.event_source.get_current_key(), 3)
events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None)
self.assertEquals( self.assertEquals(
self.event_source.get_new_events_for_user(self.u_apple, 0, None)[0], events[0],
[ [
{"type": "m.typing", {"type": "m.typing",
"room_id": self.room_id, "room_id": self.room_id,

View File

@ -34,6 +34,8 @@ class RoomTypingTestCase(RestTestCase):
""" Tests /rooms/$room_id/typing/$user_id REST API. """ """ Tests /rooms/$room_id/typing/$user_id REST API. """
user_id = "@sid:red" user_id = "@sid:red"
user = UserID.from_string(user_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def setUp(self): def setUp(self):
self.clock = MockClock() self.clock = MockClock()
@ -75,7 +77,7 @@ class RoomTypingTestCase(RestTestCase):
def get_room_members(room_id): def get_room_members(room_id):
if room_id == self.room_id: if room_id == self.room_id:
return defer.succeed([UserID.from_string(self.user_id)]) return defer.succeed([self.user])
else: else:
return defer.succeed([]) return defer.succeed([])
@ -115,8 +117,9 @@ class RoomTypingTestCase(RestTestCase):
self.assertEquals(200, code) self.assertEquals(200, code)
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
events = yield self.event_source.get_new_events_for_user(self.user, 0, None)
self.assertEquals( self.assertEquals(
self.event_source.get_new_events_for_user(self.user_id, 0, None)[0], events[0],
[ [
{"type": "m.typing", {"type": "m.typing",
"room_id": self.room_id, "room_id": self.room_id,

View File

@ -51,6 +51,46 @@ class CacheTestCase(unittest.TestCase):
self.assertTrue(failed) self.assertTrue(failed)
def test_eviction(self):
cache = Cache("test", max_entries=2)
cache.prefill(1, "one")
cache.prefill(2, "two")
cache.prefill(3, "three") # 1 will be evicted
failed = False
try:
cache.get(1)
except KeyError:
failed = True
self.assertTrue(failed)
cache.get(2)
cache.get(3)
def test_eviction_lru(self):
cache = Cache("test", max_entries=2, lru=True)
cache.prefill(1, "one")
cache.prefill(2, "two")
# Now access 1 again, thus causing 2 to be least-recently used
cache.get(1)
cache.prefill(3, "three")
failed = False
try:
cache.get(2)
except KeyError:
failed = True
self.assertTrue(failed)
cache.get(1)
cache.get(3)
class CacheDecoratorTestCase(unittest.TestCase): class CacheDecoratorTestCase(unittest.TestCase):

View File

@ -15,10 +15,15 @@
from tests import unittest from tests import unittest
from twisted.internet import defer from twisted.internet import defer
from synapse.appservice import ApplicationService from synapse.appservice import ApplicationService, ApplicationServiceState
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.appservice import ApplicationServiceStore from synapse.storage.appservice import (
ApplicationServiceStore, ApplicationServiceTransactionStore
)
import json
import os
import yaml
from mock import Mock from mock import Mock
from tests.utils import SQLiteMemoryDbPool, MockClock from tests.utils import SQLiteMemoryDbPool, MockClock
@ -27,63 +32,39 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def setUp(self): def setUp(self):
self.as_yaml_files = []
db_pool = SQLiteMemoryDbPool() db_pool = SQLiteMemoryDbPool()
yield db_pool.prepare() yield db_pool.prepare()
hs = HomeServer( hs = HomeServer(
"test", db_pool=db_pool, clock=MockClock(), config=Mock() "test", db_pool=db_pool, clock=MockClock(),
config=Mock(
app_service_config_files=self.as_yaml_files
)
) )
self.as_token = "token1" self.as_token = "token1"
db_pool.runQuery( self.as_url = "some_url"
"INSERT INTO application_services(token) VALUES(?)", self._add_appservice(self.as_token, self.as_url, "some_hs_token", "bob")
(self.as_token,) self._add_appservice("token2", "some_url", "some_hs_token", "bob")
) self._add_appservice("token3", "some_url", "some_hs_token", "bob")
db_pool.runQuery(
"INSERT INTO application_services(token) VALUES(?)", ("token2",)
)
db_pool.runQuery(
"INSERT INTO application_services(token) VALUES(?)", ("token3",)
)
# must be done after inserts # must be done after inserts
self.store = ApplicationServiceStore(hs) self.store = ApplicationServiceStore(hs)
@defer.inlineCallbacks def tearDown(self):
def test_update_and_retrieval_of_service(self): # TODO: suboptimal that we need to create files for tests!
url = "https://matrix.org/appservices/foobar" for f in self.as_yaml_files:
hs_token = "hstok" try:
user_regex = [ os.remove(f)
{"regex": "@foobar_.*:matrix.org", "exclusive": True} except:
] pass
alias_regex = [
{"regex": "#foobar_.*:matrix.org", "exclusive": False}
]
room_regex = [
] def _add_appservice(self, as_token, url, hs_token, sender):
service = ApplicationService( as_yaml = dict(url=url, as_token=as_token, hs_token=hs_token,
url=url, hs_token=hs_token, token=self.as_token, namespaces={ sender_localpart=sender, namespaces={})
ApplicationService.NS_USERS: user_regex, # use the token as the filename
ApplicationService.NS_ALIASES: alias_regex, with open(as_token, 'w') as outfile:
ApplicationService.NS_ROOMS: room_regex outfile.write(yaml.dump(as_yaml))
}) self.as_yaml_files.append(as_token)
yield self.store.update_app_service(service)
stored_service = yield self.store.get_app_service_by_token(
self.as_token
)
self.assertEquals(stored_service.token, self.as_token)
self.assertEquals(stored_service.url, url)
self.assertEquals(
stored_service.namespaces[ApplicationService.NS_ALIASES],
alias_regex
)
self.assertEquals(
stored_service.namespaces[ApplicationService.NS_ROOMS],
room_regex
)
self.assertEquals(
stored_service.namespaces[ApplicationService.NS_USERS],
user_regex
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_retrieve_unknown_service_token(self): def test_retrieve_unknown_service_token(self):
@ -96,7 +77,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
self.as_token self.as_token
) )
self.assertEquals(stored_service.token, self.as_token) self.assertEquals(stored_service.token, self.as_token)
self.assertEquals(stored_service.url, None) self.assertEquals(stored_service.url, self.as_url)
self.assertEquals( self.assertEquals(
stored_service.namespaces[ApplicationService.NS_ALIASES], stored_service.namespaces[ApplicationService.NS_ALIASES],
[] []
@ -114,3 +95,314 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
def test_retrieval_of_all_services(self): def test_retrieval_of_all_services(self):
services = yield self.store.get_app_services() services = yield self.store.get_app_services()
self.assertEquals(len(services), 3) self.assertEquals(len(services), 3)
class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.as_yaml_files = []
self.db_pool = SQLiteMemoryDbPool()
yield self.db_pool.prepare()
self.as_list = [
{
"token": "token1",
"url": "https://matrix-as.org",
"id": "token1"
},
{
"token": "alpha_tok",
"url": "https://alpha.com",
"id": "alpha_tok"
},
{
"token": "beta_tok",
"url": "https://beta.com",
"id": "beta_tok"
},
{
"token": "delta_tok",
"url": "https://delta.com",
"id": "delta_tok"
},
]
for s in self.as_list:
yield self._add_service(s["url"], s["token"])
hs = HomeServer(
"test", db_pool=self.db_pool, clock=MockClock(), config=Mock(
app_service_config_files=self.as_yaml_files
)
)
self.store = TestTransactionStore(hs)
def _add_service(self, url, as_token):
as_yaml = dict(url=url, as_token=as_token, hs_token="something",
sender_localpart="a_sender", namespaces={})
# use the token as the filename
with open(as_token, 'w') as outfile:
outfile.write(yaml.dump(as_yaml))
self.as_yaml_files.append(as_token)
def _set_state(self, id, state, txn=None):
return self.db_pool.runQuery(
"INSERT INTO application_services_state(as_id, state, last_txn) "
"VALUES(?,?,?)",
(id, state, txn)
)
def _insert_txn(self, as_id, txn_id, events):
return self.db_pool.runQuery(
"INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
"VALUES(?,?,?)",
(as_id, txn_id, json.dumps([e.event_id for e in events]))
)
def _set_last_txn(self, as_id, txn_id):
return self.db_pool.runQuery(
"INSERT INTO application_services_state(as_id, last_txn, state) "
"VALUES(?,?,?)",
(as_id, txn_id, ApplicationServiceState.UP)
)
@defer.inlineCallbacks
def test_get_appservice_state_none(self):
service = Mock(id=999)
state = yield self.store.get_appservice_state(service)
self.assertEquals(None, state)
@defer.inlineCallbacks
def test_get_appservice_state_up(self):
yield self._set_state(
self.as_list[0]["id"], ApplicationServiceState.UP
)
service = Mock(id=self.as_list[0]["id"])
state = yield self.store.get_appservice_state(service)
self.assertEquals(ApplicationServiceState.UP, state)
@defer.inlineCallbacks
def test_get_appservice_state_down(self):
yield self._set_state(
self.as_list[0]["id"], ApplicationServiceState.UP
)
yield self._set_state(
self.as_list[1]["id"], ApplicationServiceState.DOWN
)
yield self._set_state(
self.as_list[2]["id"], ApplicationServiceState.DOWN
)
service = Mock(id=self.as_list[1]["id"])
state = yield self.store.get_appservice_state(service)
self.assertEquals(ApplicationServiceState.DOWN, state)
@defer.inlineCallbacks
def test_get_appservices_by_state_none(self):
services = yield self.store.get_appservices_by_state(
ApplicationServiceState.DOWN
)
self.assertEquals(0, len(services))
@defer.inlineCallbacks
def test_set_appservices_state_down(self):
service = Mock(id=self.as_list[1]["id"])
yield self.store.set_appservice_state(
service,
ApplicationServiceState.DOWN
)
rows = yield self.db_pool.runQuery(
"SELECT as_id FROM application_services_state WHERE state=?",
(ApplicationServiceState.DOWN,)
)
self.assertEquals(service.id, rows[0][0])
@defer.inlineCallbacks
def test_set_appservices_state_multiple_up(self):
service = Mock(id=self.as_list[1]["id"])
yield self.store.set_appservice_state(
service,
ApplicationServiceState.UP
)
yield self.store.set_appservice_state(
service,
ApplicationServiceState.DOWN
)
yield self.store.set_appservice_state(
service,
ApplicationServiceState.UP
)
rows = yield self.db_pool.runQuery(
"SELECT as_id FROM application_services_state WHERE state=?",
(ApplicationServiceState.UP,)
)
self.assertEquals(service.id, rows[0][0])
@defer.inlineCallbacks
def test_create_appservice_txn_first(self):
service = Mock(id=self.as_list[0]["id"])
events = [Mock(event_id="e1"), Mock(event_id="e2")]
txn = yield self.store.create_appservice_txn(service, events)
self.assertEquals(txn.id, 1)
self.assertEquals(txn.events, events)
self.assertEquals(txn.service, service)
@defer.inlineCallbacks
def test_create_appservice_txn_older_last_txn(self):
service = Mock(id=self.as_list[0]["id"])
events = [Mock(event_id="e1"), Mock(event_id="e2")]
yield self._set_last_txn(service.id, 9643) # AS is falling behind
yield self._insert_txn(service.id, 9644, events)
yield self._insert_txn(service.id, 9645, events)
txn = yield self.store.create_appservice_txn(service, events)
self.assertEquals(txn.id, 9646)
self.assertEquals(txn.events, events)
self.assertEquals(txn.service, service)
@defer.inlineCallbacks
def test_create_appservice_txn_up_to_date_last_txn(self):
service = Mock(id=self.as_list[0]["id"])
events = [Mock(event_id="e1"), Mock(event_id="e2")]
yield self._set_last_txn(service.id, 9643)
txn = yield self.store.create_appservice_txn(service, events)
self.assertEquals(txn.id, 9644)
self.assertEquals(txn.events, events)
self.assertEquals(txn.service, service)
@defer.inlineCallbacks
def test_create_appservice_txn_up_fuzzing(self):
service = Mock(id=self.as_list[0]["id"])
events = [Mock(event_id="e1"), Mock(event_id="e2")]
yield self._set_last_txn(service.id, 9643)
# dump in rows with higher IDs to make sure the queries aren't wrong.
yield self._set_last_txn(self.as_list[1]["id"], 119643)
yield self._set_last_txn(self.as_list[2]["id"], 9)
yield self._set_last_txn(self.as_list[3]["id"], 9643)
yield self._insert_txn(self.as_list[1]["id"], 119644, events)
yield self._insert_txn(self.as_list[1]["id"], 119645, events)
yield self._insert_txn(self.as_list[1]["id"], 119646, events)
yield self._insert_txn(self.as_list[2]["id"], 10, events)
yield self._insert_txn(self.as_list[3]["id"], 9643, events)
txn = yield self.store.create_appservice_txn(service, events)
self.assertEquals(txn.id, 9644)
self.assertEquals(txn.events, events)
self.assertEquals(txn.service, service)
@defer.inlineCallbacks
def test_complete_appservice_txn_first_txn(self):
service = Mock(id=self.as_list[0]["id"])
events = [Mock(event_id="e1"), Mock(event_id="e2")]
txn_id = 1
yield self._insert_txn(service.id, txn_id, events)
yield self.store.complete_appservice_txn(txn_id=txn_id, service=service)
res = yield self.db_pool.runQuery(
"SELECT last_txn FROM application_services_state WHERE as_id=?",
(service.id,)
)
self.assertEquals(1, len(res))
self.assertEquals(str(txn_id), res[0][0])
res = yield self.db_pool.runQuery(
"SELECT * FROM application_services_txns WHERE txn_id=?",
(txn_id,)
)
self.assertEquals(0, len(res))
@defer.inlineCallbacks
def test_complete_appservice_txn_existing_in_state_table(self):
service = Mock(id=self.as_list[0]["id"])
events = [Mock(event_id="e1"), Mock(event_id="e2")]
txn_id = 5
yield self._set_last_txn(service.id, 4)
yield self._insert_txn(service.id, txn_id, events)
yield self.store.complete_appservice_txn(txn_id=txn_id, service=service)
res = yield self.db_pool.runQuery(
"SELECT last_txn, state FROM application_services_state WHERE "
"as_id=?",
(service.id,)
)
self.assertEquals(1, len(res))
self.assertEquals(str(txn_id), res[0][0])
self.assertEquals(ApplicationServiceState.UP, res[0][1])
res = yield self.db_pool.runQuery(
"SELECT * FROM application_services_txns WHERE txn_id=?",
(txn_id,)
)
self.assertEquals(0, len(res))
@defer.inlineCallbacks
def test_get_oldest_unsent_txn_none(self):
service = Mock(id=self.as_list[0]["id"])
txn = yield self.store.get_oldest_unsent_txn(service)
self.assertEquals(None, txn)
@defer.inlineCallbacks
def test_get_oldest_unsent_txn(self):
service = Mock(id=self.as_list[0]["id"])
events = [Mock(event_id="e1"), Mock(event_id="e2")]
other_events = [Mock(event_id="e5"), Mock(event_id="e6")]
# we aren't testing store._base stuff here, so mock this out
self.store._get_events_txn = Mock(return_value=events)
yield self._insert_txn(self.as_list[1]["id"], 9, other_events)
yield self._insert_txn(service.id, 10, events)
yield self._insert_txn(service.id, 11, other_events)
yield self._insert_txn(service.id, 12, other_events)
txn = yield self.store.get_oldest_unsent_txn(service)
self.assertEquals(service, txn.service)
self.assertEquals(10, txn.id)
self.assertEquals(events, txn.events)
@defer.inlineCallbacks
def test_get_appservices_by_state_single(self):
yield self._set_state(
self.as_list[0]["id"], ApplicationServiceState.DOWN
)
yield self._set_state(
self.as_list[1]["id"], ApplicationServiceState.UP
)
services = yield self.store.get_appservices_by_state(
ApplicationServiceState.DOWN
)
self.assertEquals(1, len(services))
self.assertEquals(self.as_list[0]["id"], services[0].id)
@defer.inlineCallbacks
def test_get_appservices_by_state_multiple(self):
yield self._set_state(
self.as_list[0]["id"], ApplicationServiceState.DOWN
)
yield self._set_state(
self.as_list[1]["id"], ApplicationServiceState.UP
)
yield self._set_state(
self.as_list[2]["id"], ApplicationServiceState.DOWN
)
yield self._set_state(
self.as_list[3]["id"], ApplicationServiceState.UP
)
services = yield self.store.get_appservices_by_state(
ApplicationServiceState.DOWN
)
self.assertEquals(2, len(services))
self.assertEquals(
set([self.as_list[2]["id"], self.as_list[0]["id"]]),
set([services[0].id, services[1].id])
)
# required for ApplicationServiceTransactionStoreTestCase tests
class TestTransactionStore(ApplicationServiceTransactionStore,
ApplicationServiceStore):
def __init__(self, hs):
super(TestTransactionStore, self).__init__(hs)