Merge branch 'release-v0.5.1' of github.com:matrix-org/synapse

This commit is contained in:
Erik Johnston 2014-11-26 12:06:36 +00:00
commit 48ee9ddb22
57 changed files with 1040 additions and 735 deletions

View File

@ -1,3 +1,11 @@
Changes in synapse 0.5.1 (2014-11-26)
=====================================
See UPGRADES.rst for specific instructions on how to upgrade.
* Fix bug where we served up an Event that did not match its signatures.
* Fix regression where we no longer correctly handled the case where a
homeserver receives an event for a room it doesn't recognise (but is in.)
Changes in synapse 0.5.0 (2014-11-19) Changes in synapse 0.5.0 (2014-11-19)
===================================== =====================================
This release includes changes to the federation protocol and client-server API This release includes changes to the federation protocol and client-server API

View File

@ -69,8 +69,8 @@ command line utility which lets you easily see what the JSON APIs are up to).
Meanwhile, iOS and Android SDKs and clients are currently in development and available from: Meanwhile, iOS and Android SDKs and clients are currently in development and available from:
* https://github.com/matrix-org/matrix-ios-sdk - https://github.com/matrix-org/matrix-ios-sdk
* https://github.com/matrix-org/matrix-android-sdk - https://github.com/matrix-org/matrix-android-sdk
We'd like to invite you to join #matrix:matrix.org (via http://matrix.org/alpha), run a homeserver, take a look at the Matrix spec at We'd like to invite you to join #matrix:matrix.org (via http://matrix.org/alpha), run a homeserver, take a look at the Matrix spec at
http://matrix.org/docs/spec, experiment with the APIs and the demo http://matrix.org/docs/spec, experiment with the APIs and the demo
@ -94,7 +94,7 @@ header files for python C extensions.
Installing prerequisites on Ubuntu or Debian:: Installing prerequisites on Ubuntu or Debian::
$ sudo apt-get install build-essential python2.7-dev libffi-dev \ $ sudo apt-get install build-essential python2.7-dev libffi-dev \
python-pip python-setuptools python-pip python-setuptools sqlite3
Installing prerequisites on Mac OS X:: Installing prerequisites on Mac OS X::
@ -125,7 +125,7 @@ created. To reset the installation::
pip seems to leak *lots* of memory during installation. For instance, a Linux pip seems to leak *lots* of memory during installation. For instance, a Linux
host with 512MB of RAM may run out of memory whilst installing Twisted. If this host with 512MB of RAM may run out of memory whilst installing Twisted. If this
happens, you will have to individually install the dependencies which are happens, you will have to individually install the dependencies which are
failing, e.g.: failing, e.g.::
$ pip install --user twisted $ pip install --user twisted
@ -148,7 +148,7 @@ Troubleshooting Running
----------------------- -----------------------
If ``synctl`` fails with ``pkg_resources.DistributionNotFound`` errors you may If ``synctl`` fails with ``pkg_resources.DistributionNotFound`` errors you may
need a newer version of setuptools than that provided by your OS. need a newer version of setuptools than that provided by your OS.::
$ sudo pip install setuptools --upgrade $ sudo pip install setuptools --upgrade
@ -172,7 +172,7 @@ Homeserver Development
====================== ======================
To check out a homeserver for development, clone the git repo into a working To check out a homeserver for development, clone the git repo into a working
directory of your choice: directory of your choice::
$ git clone https://github.com/matrix-org/synapse.git $ git clone https://github.com/matrix-org/synapse.git
$ cd synapse $ cd synapse

View File

@ -1,3 +1,12 @@
Upgrading to v0.5.1
===================
Depending on precisely when you installed v0.5.0 you may have ended up with
a stale release of the reference matrix webclient installed as a python module.
To uninstall it and ensure you are depending on the latest module, please run::
$ pip uninstall syweb
Upgrading to v0.5.0 Upgrading to v0.5.0
=================== ===================

View File

@ -1 +1 @@
0.5.0 0.5.1

View File

@ -23,7 +23,7 @@ def get_targets(server_name):
for srv in answers: for srv in answers:
yield (srv.target, srv.port) yield (srv.target, srv.port)
except dns.resolver.NXDOMAIN: except dns.resolver.NXDOMAIN:
yield (server_name, 8480) yield (server_name, 8448)
def get_server_keys(server_name, target, port): def get_server_keys(server_name, target, port):
url = "https://%s:%i/_matrix/key/v1" % (target, port) url = "https://%s:%i/_matrix/key/v1" % (target, port)

View File

@ -32,7 +32,7 @@ setup(
description="Reference Synapse Home Server", description="Reference Synapse Home Server",
install_requires=[ install_requires=[
"syutil==0.0.2", "syutil==0.0.2",
"matrix_angular_sdk==0.5.0", "matrix_angular_sdk==0.5.1",
"Twisted>=14.0.0", "Twisted>=14.0.0",
"service_identity>=1.0.0", "service_identity>=1.0.0",
"pyopenssl>=0.14", "pyopenssl>=0.14",
@ -45,7 +45,7 @@ setup(
dependency_links=[ dependency_links=[
"https://github.com/matrix-org/syutil/tarball/v0.0.2#egg=syutil-0.0.2", "https://github.com/matrix-org/syutil/tarball/v0.0.2#egg=syutil-0.0.2",
"https://github.com/pyca/pynacl/tarball/52dbe2dc33f1#egg=pynacl-0.3.0", "https://github.com/pyca/pynacl/tarball/52dbe2dc33f1#egg=pynacl-0.3.0",
"https://github.com/matrix-org/matrix-angular-sdk/tarball/v0.5.0/#egg=matrix_angular_sdk-0.5.0", "https://github.com/matrix-org/matrix-angular-sdk/tarball/v0.5.1/#egg=matrix_angular_sdk-0.5.1",
], ],
setup_requires=[ setup_requires=[
"setuptools_trial", "setuptools_trial",

View File

@ -16,4 +16,4 @@
""" This is a reference implementation of a synapse home server. """ This is a reference implementation of a synapse home server.
""" """
__version__ = "0.5.0" __version__ = "0.5.1"

View File

@ -38,79 +38,66 @@ class Auth(object):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
def check(self, event, raises=False): def check(self, event, auth_events):
""" Checks if this event is correctly authed. """ Checks if this event is correctly authed.
Returns: Returns:
True if the auth checks pass. True if the auth checks pass.
Raises:
AuthError if there was a problem authorising this event. This will
be raised only if raises=True.
""" """
try: try:
if hasattr(event, "room_id"): if not hasattr(event, "room_id"):
if event.old_state_events is None: raise AuthError(500, "Event has no room_id: %s" % event)
# Oh, we don't know what the state of the room was, so we if auth_events is None:
# are trusting that this is allowed (at least for now) # Oh, we don't know what the state of the room was, so we
logger.warn("Trusting event: %s", event.event_id) # are trusting that this is allowed (at least for now)
return True logger.warn("Trusting event: %s", event.event_id)
if hasattr(event, "outlier") and event.outlier is True:
# TODO (erikj): Auth for outliers is done differently.
return True
if event.type == RoomCreateEvent.TYPE:
# FIXME
return True
# FIXME: Temp hack
if event.type == RoomAliasesEvent.TYPE:
return True
if event.type == RoomMemberEvent.TYPE:
allowed = self.is_membership_change_allowed(event)
if allowed:
logger.debug("Allowing! %s", event)
else:
logger.debug("Denying! %s", event)
return allowed
self.check_event_sender_in_room(event)
self._can_send_event(event)
if event.type == RoomPowerLevelsEvent.TYPE:
self._check_power_levels(event)
if event.type == RoomRedactionEvent.TYPE:
self._check_redaction(event)
logger.debug("Allowing! %s", event)
return True return True
else:
raise AuthError(500, "Unknown event: %s" % event) if event.type == RoomCreateEvent.TYPE:
# FIXME
return True
# FIXME: Temp hack
if event.type == RoomAliasesEvent.TYPE:
return True
if event.type == RoomMemberEvent.TYPE:
allowed = self.is_membership_change_allowed(
event, auth_events
)
if allowed:
logger.debug("Allowing! %s", event)
else:
logger.debug("Denying! %s", event)
return allowed
self.check_event_sender_in_room(event, auth_events)
self._can_send_event(event, auth_events)
if event.type == RoomPowerLevelsEvent.TYPE:
self._check_power_levels(event, auth_events)
if event.type == RoomRedactionEvent.TYPE:
self._check_redaction(event, auth_events)
logger.debug("Allowing! %s", event)
except AuthError as e: except AuthError as e:
logger.info( logger.info(
"Event auth check failed on event %s with msg: %s", "Event auth check failed on event %s with msg: %s",
event, e.msg event, e.msg
) )
logger.info("Denying! %s", event) logger.info("Denying! %s", event)
if raises: raise
raise
return False
@defer.inlineCallbacks @defer.inlineCallbacks
def check_joined_room(self, room_id, user_id): def check_joined_room(self, room_id, user_id):
try: member = yield self.state.get_current_state(
member = yield self.store.get_room_member( room_id=room_id,
room_id=room_id, event_type=RoomMemberEvent.TYPE,
user_id=user_id state_key=user_id
) )
self._check_joined_room(member, user_id, room_id) self._check_joined_room(member, user_id, room_id)
defer.returnValue(member) defer.returnValue(member)
except AttributeError:
pass
defer.returnValue(None)
@defer.inlineCallbacks @defer.inlineCallbacks
def check_host_in_room(self, room_id, host): def check_host_in_room(self, room_id, host):
@ -130,9 +117,9 @@ class Auth(object):
defer.returnValue(False) defer.returnValue(False)
def check_event_sender_in_room(self, event): def check_event_sender_in_room(self, event, auth_events):
key = (RoomMemberEvent.TYPE, event.user_id, ) key = (RoomMemberEvent.TYPE, event.user_id, )
member_event = event.state_events.get(key) member_event = auth_events.get(key)
return self._check_joined_room( return self._check_joined_room(
member_event, member_event,
@ -147,15 +134,15 @@ class Auth(object):
)) ))
@log_function @log_function
def is_membership_change_allowed(self, event): def is_membership_change_allowed(self, event, auth_events):
membership = event.content["membership"] membership = event.content["membership"]
# Check if this is the room creator joining: # Check if this is the room creator joining:
if len(event.prev_events) == 1 and Membership.JOIN == membership: if len(event.prev_events) == 1 and Membership.JOIN == membership:
# Get room creation event: # Get room creation event:
key = (RoomCreateEvent.TYPE, "", ) key = (RoomCreateEvent.TYPE, "", )
create = event.old_state_events.get(key) create = auth_events.get(key)
if event.prev_events[0][0] == create.event_id: if create and event.prev_events[0][0] == create.event_id:
if create.content["creator"] == event.state_key: if create.content["creator"] == event.state_key:
return True return True
@ -163,19 +150,19 @@ class Auth(object):
# get info about the caller # get info about the caller
key = (RoomMemberEvent.TYPE, event.user_id, ) key = (RoomMemberEvent.TYPE, event.user_id, )
caller = event.old_state_events.get(key) caller = auth_events.get(key)
caller_in_room = caller and caller.membership == Membership.JOIN caller_in_room = caller and caller.membership == Membership.JOIN
caller_invited = caller and caller.membership == Membership.INVITE caller_invited = caller and caller.membership == Membership.INVITE
# get info about the target # get info about the target
key = (RoomMemberEvent.TYPE, target_user_id, ) key = (RoomMemberEvent.TYPE, target_user_id, )
target = event.old_state_events.get(key) target = auth_events.get(key)
target_in_room = target and target.membership == Membership.JOIN target_in_room = target and target.membership == Membership.JOIN
key = (RoomJoinRulesEvent.TYPE, "", ) key = (RoomJoinRulesEvent.TYPE, "", )
join_rule_event = event.old_state_events.get(key) join_rule_event = auth_events.get(key)
if join_rule_event: if join_rule_event:
join_rule = join_rule_event.content.get( join_rule = join_rule_event.content.get(
"join_rule", JoinRules.INVITE "join_rule", JoinRules.INVITE
@ -186,11 +173,13 @@ class Auth(object):
user_level = self._get_power_level_from_event_state( user_level = self._get_power_level_from_event_state(
event, event,
event.user_id, event.user_id,
auth_events,
) )
ban_level, kick_level, redact_level = ( ban_level, kick_level, redact_level = (
self._get_ops_level_from_event_state( self._get_ops_level_from_event_state(
event event,
auth_events,
) )
) )
@ -260,9 +249,9 @@ class Auth(object):
return True return True
def _get_power_level_from_event_state(self, event, user_id): def _get_power_level_from_event_state(self, event, user_id, auth_events):
key = (RoomPowerLevelsEvent.TYPE, "", ) key = (RoomPowerLevelsEvent.TYPE, "", )
power_level_event = event.old_state_events.get(key) power_level_event = auth_events.get(key)
level = None level = None
if power_level_event: if power_level_event:
level = power_level_event.content.get("users", {}).get(user_id) level = power_level_event.content.get("users", {}).get(user_id)
@ -270,16 +259,16 @@ class Auth(object):
level = power_level_event.content.get("users_default", 0) level = power_level_event.content.get("users_default", 0)
else: else:
key = (RoomCreateEvent.TYPE, "", ) key = (RoomCreateEvent.TYPE, "", )
create_event = event.old_state_events.get(key) create_event = auth_events.get(key)
if (create_event is not None and if (create_event is not None and
create_event.content["creator"] == user_id): create_event.content["creator"] == user_id):
return 100 return 100
return level return level
def _get_ops_level_from_event_state(self, event): def _get_ops_level_from_event_state(self, event, auth_events):
key = (RoomPowerLevelsEvent.TYPE, "", ) key = (RoomPowerLevelsEvent.TYPE, "", )
power_level_event = event.old_state_events.get(key) power_level_event = auth_events.get(key)
if power_level_event: if power_level_event:
return ( return (
@ -375,6 +364,11 @@ class Auth(object):
key = (RoomMemberEvent.TYPE, event.user_id, ) key = (RoomMemberEvent.TYPE, event.user_id, )
member_event = event.old_state_events.get(key) member_event = event.old_state_events.get(key)
key = (RoomCreateEvent.TYPE, "", )
create_event = event.old_state_events.get(key)
if create_event:
auth_events.append(create_event.event_id)
if join_rule_event: if join_rule_event:
join_rule = join_rule_event.content.get("join_rule") join_rule = join_rule_event.content.get("join_rule")
is_public = join_rule == JoinRules.PUBLIC if join_rule else False is_public = join_rule == JoinRules.PUBLIC if join_rule else False
@ -406,9 +400,9 @@ class Auth(object):
event.auth_events = zip(auth_events, hashes) event.auth_events = zip(auth_events, hashes)
@log_function @log_function
def _can_send_event(self, event): def _can_send_event(self, event, auth_events):
key = (RoomPowerLevelsEvent.TYPE, "", ) key = (RoomPowerLevelsEvent.TYPE, "", )
send_level_event = event.old_state_events.get(key) send_level_event = auth_events.get(key)
send_level = None send_level = None
if send_level_event: if send_level_event:
send_level = send_level_event.content.get("events", {}).get( send_level = send_level_event.content.get("events", {}).get(
@ -432,6 +426,7 @@ class Auth(object):
user_level = self._get_power_level_from_event_state( user_level = self._get_power_level_from_event_state(
event, event,
event.user_id, event.user_id,
auth_events,
) )
if user_level: if user_level:
@ -468,14 +463,16 @@ class Auth(object):
return True return True
def _check_redaction(self, event): def _check_redaction(self, event, auth_events):
user_level = self._get_power_level_from_event_state( user_level = self._get_power_level_from_event_state(
event, event,
event.user_id, event.user_id,
auth_events,
) )
_, _, redact_level = self._get_ops_level_from_event_state( _, _, redact_level = self._get_ops_level_from_event_state(
event event,
auth_events,
) )
if user_level < redact_level: if user_level < redact_level:
@ -484,7 +481,7 @@ class Auth(object):
"You don't have permission to redact events" "You don't have permission to redact events"
) )
def _check_power_levels(self, event): def _check_power_levels(self, event, auth_events):
user_list = event.content.get("users", {}) user_list = event.content.get("users", {})
# Validate users # Validate users
for k, v in user_list.items(): for k, v in user_list.items():
@ -499,7 +496,7 @@ class Auth(object):
raise SynapseError(400, "Not a valid power level: %s" % (v,)) raise SynapseError(400, "Not a valid power level: %s" % (v,))
key = (event.type, event.state_key, ) key = (event.type, event.state_key, )
current_state = event.old_state_events.get(key) current_state = auth_events.get(key)
if not current_state: if not current_state:
return return
@ -507,6 +504,7 @@ class Auth(object):
user_level = self._get_power_level_from_event_state( user_level = self._get_power_level_from_event_state(
event, event,
event.user_id, event.user_id,
auth_events,
) )
# Check other levels: # Check other levels:

View File

@ -17,6 +17,8 @@
import logging import logging
logger = logging.getLogger(__name__)
class Codes(object): class Codes(object):
UNAUTHORIZED = "M_UNAUTHORIZED" UNAUTHORIZED = "M_UNAUTHORIZED"
@ -38,7 +40,7 @@ class CodeMessageException(Exception):
"""An exception with integer code and message string attributes.""" """An exception with integer code and message string attributes."""
def __init__(self, code, msg): def __init__(self, code, msg):
logging.error("%s: %s, %s", type(self).__name__, code, msg) logger.info("%s: %s, %s", type(self).__name__, code, msg)
super(CodeMessageException, self).__init__("%d: %s" % (code, msg)) super(CodeMessageException, self).__init__("%d: %s" % (code, msg))
self.code = code self.code = code
self.msg = msg self.msg = msg
@ -140,7 +142,8 @@ def cs_exception(exception):
if isinstance(exception, CodeMessageException): if isinstance(exception, CodeMessageException):
return exception.error_dict() return exception.error_dict()
else: else:
logging.error("Unknown exception type: %s", type(exception)) logger.error("Unknown exception type: %s", type(exception))
return {}
def cs_error(msg, code=Codes.UNKNOWN, **kwargs): def cs_error(msg, code=Codes.UNKNOWN, **kwargs):

View File

@ -83,6 +83,8 @@ class SynapseEvent(JsonEncodedObject):
"content", "content",
] ]
outlier = False
def __init__(self, raises=True, **kwargs): def __init__(self, raises=True, **kwargs):
super(SynapseEvent, self).__init__(**kwargs) super(SynapseEvent, self).__init__(**kwargs)
# if "content" in kwargs: # if "content" in kwargs:
@ -123,6 +125,7 @@ class SynapseEvent(JsonEncodedObject):
pdu_json.pop("outlier", None) pdu_json.pop("outlier", None)
pdu_json.pop("replaces_state", None) pdu_json.pop("replaces_state", None)
pdu_json.pop("redacted", None) pdu_json.pop("redacted", None)
pdu_json.pop("prev_content", None)
state_hash = pdu_json.pop("state_hash", None) state_hash = pdu_json.pop("state_hash", None)
if state_hash is not None: if state_hash is not None:
pdu_json.setdefault("unsigned", {})["state_hash"] = state_hash pdu_json.setdefault("unsigned", {})["state_hash"] = state_hash

View File

@ -26,7 +26,7 @@ from twisted.web.server import Site
from synapse.http.server import JsonResource, RootRedirect from synapse.http.server import JsonResource, RootRedirect
from synapse.http.content_repository import ContentRepoResource from synapse.http.content_repository import ContentRepoResource
from synapse.http.server_key_resource import LocalKey from synapse.http.server_key_resource import LocalKey
from synapse.http.client import MatrixHttpClient from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.api.urls import ( from synapse.api.urls import (
CLIENT_PREFIX, FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX, CLIENT_PREFIX, FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX,
SERVER_KEY_PREFIX, SERVER_KEY_PREFIX,
@ -51,7 +51,7 @@ logger = logging.getLogger(__name__)
class SynapseHomeServer(HomeServer): class SynapseHomeServer(HomeServer):
def build_http_client(self): def build_http_client(self):
return MatrixHttpClient(self) return MatrixFederationHttpClient(self)
def build_resource_for_client(self): def build_resource_for_client(self):
return JsonResource() return JsonResource()
@ -116,7 +116,7 @@ class SynapseHomeServer(HomeServer):
# extra resources to existing nodes. See self._resource_id for the key. # extra resources to existing nodes. See self._resource_id for the key.
resource_mappings = {} resource_mappings = {}
for (full_path, resource) in desired_tree: for (full_path, resource) in desired_tree:
logging.info("Attaching %s to path %s", resource, full_path) logger.info("Attaching %s to path %s", resource, full_path)
last_resource = self.root_resource last_resource = self.root_resource
for path_seg in full_path.split('/')[1:-1]: for path_seg in full_path.split('/')[1:-1]:
if not path_seg in last_resource.listNames(): if not path_seg in last_resource.listNames():
@ -221,12 +221,12 @@ def setup():
db_name = hs.get_db_name() db_name = hs.get_db_name()
logging.info("Preparing database: %s...", db_name) logger.info("Preparing database: %s...", db_name)
with sqlite3.connect(db_name) as db_conn: with sqlite3.connect(db_name) as db_conn:
prepare_database(db_conn) prepare_database(db_conn)
logging.info("Database prepared in %s.", db_name) logger.info("Database prepared in %s.", db_name)
hs.get_db_pool() hs.get_db_pool()
@ -257,13 +257,16 @@ def setup():
else: else:
reactor.run() reactor.run()
def run(): def run():
with LoggingContext("run"): with LoggingContext("run"):
reactor.run() reactor.run()
def main(): def main():
with LoggingContext("main"): with LoggingContext("main"):
setup() setup()
if __name__ == '__main__': if __name__ == '__main__':
main() main()

View File

@ -21,11 +21,12 @@ import signal
SYNAPSE = ["python", "-m", "synapse.app.homeserver"] SYNAPSE = ["python", "-m", "synapse.app.homeserver"]
CONFIGFILE="homeserver.yaml" CONFIGFILE = "homeserver.yaml"
PIDFILE="homeserver.pid" PIDFILE = "homeserver.pid"
GREEN = "\x1b[1;32m"
NORMAL = "\x1b[m"
GREEN="\x1b[1;32m"
NORMAL="\x1b[m"
def start(): def start():
if not os.path.exists(CONFIGFILE): if not os.path.exists(CONFIGFILE):
@ -43,12 +44,14 @@ def start():
subprocess.check_call(args) subprocess.check_call(args)
print GREEN + "started" + NORMAL print GREEN + "started" + NORMAL
def stop(): def stop():
if os.path.exists(PIDFILE): if os.path.exists(PIDFILE):
pid = int(open(PIDFILE).read()) pid = int(open(PIDFILE).read())
os.kill(pid, signal.SIGTERM) os.kill(pid, signal.SIGTERM)
print GREEN + "stopped" + NORMAL print GREEN + "stopped" + NORMAL
def main(): def main():
action = sys.argv[1] if sys.argv[1:] else "usage" action = sys.argv[1] if sys.argv[1:] else "usage"
if action == "start": if action == "start":
@ -62,5 +65,6 @@ def main():
sys.stderr.write("Usage: %s [start|stop|restart]\n" % (sys.argv[0],)) sys.stderr.write("Usage: %s [start|stop|restart]\n" % (sys.argv[0],))
sys.exit(1) sys.exit(1)
if __name__=='__main__':
if __name__ == "__main__":
main() main()

View File

@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
def check_event_content_hash(event, hash_algorithm=hashlib.sha256): def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
"""Check whether the hash for this PDU matches the contents""" """Check whether the hash for this PDU matches the contents"""
computed_hash = _compute_content_hash(event, hash_algorithm) computed_hash = _compute_content_hash(event, hash_algorithm)
logging.debug("Expecting hash: %s", encode_base64(computed_hash.digest())) logger.debug("Expecting hash: %s", encode_base64(computed_hash.digest()))
if computed_hash.name not in event.hashes: if computed_hash.name not in event.hashes:
raise SynapseError( raise SynapseError(
400, 400,

View File

@ -17,7 +17,7 @@
from twisted.web.http import HTTPClient from twisted.web.http import HTTPClient
from twisted.internet.protocol import Factory from twisted.internet.protocol import Factory
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from synapse.http.endpoint import matrix_endpoint from synapse.http.endpoint import matrix_federation_endpoint
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
import json import json
import logging import logging
@ -31,7 +31,7 @@ def fetch_server_key(server_name, ssl_context_factory):
"""Fetch the keys for a remote server.""" """Fetch the keys for a remote server."""
factory = SynapseKeyClientFactory() factory = SynapseKeyClientFactory()
endpoint = matrix_endpoint( endpoint = matrix_federation_endpoint(
reactor, server_name, ssl_context_factory, timeout=30 reactor, server_name, ssl_context_factory, timeout=30
) )
@ -48,7 +48,7 @@ def fetch_server_key(server_name, ssl_context_factory):
class SynapseKeyClientError(Exception): class SynapseKeyClientError(Exception):
"""The key wasn't retireved from the remote server.""" """The key wasn't retrieved from the remote server."""
pass pass

View File

@ -135,7 +135,7 @@ class Keyring(object):
time_now_ms = self.clock.time_msec() time_now_ms = self.clock.time_msec()
self.store.store_server_certificate( yield self.store.store_server_certificate(
server_name, server_name,
server_name, server_name,
time_now_ms, time_now_ms,
@ -143,7 +143,7 @@ class Keyring(object):
) )
for key_id, key in verify_keys.items(): for key_id, key in verify_keys.items():
self.store.store_server_verify_key( yield self.store.store_server_verify_key(
server_name, server_name, time_now_ms, key server_name, server_name, time_now_ms, key
) )

View File

@ -24,6 +24,7 @@ from .units import Transaction, Edu
from .persistence import TransactionActions from .persistence import TransactionActions
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext
import logging import logging
@ -319,19 +320,20 @@ class ReplicationLayer(object):
logger.debug("[%s] Transacition is new", transaction.transaction_id) logger.debug("[%s] Transacition is new", transaction.transaction_id)
dl = [] with PreserveLoggingContext():
for pdu in pdu_list: dl = []
dl.append(self._handle_new_pdu(transaction.origin, pdu)) for pdu in pdu_list:
dl.append(self._handle_new_pdu(transaction.origin, pdu))
if hasattr(transaction, "edus"): if hasattr(transaction, "edus"):
for edu in [Edu(**x) for x in transaction.edus]: for edu in [Edu(**x) for x in transaction.edus]:
self.received_edu( self.received_edu(
transaction.origin, transaction.origin,
edu.edu_type, edu.edu_type,
edu.content edu.content
) )
results = yield defer.DeferredList(dl) results = yield defer.DeferredList(dl)
ret = [] ret = []
for r in results: for r in results:
@ -425,7 +427,9 @@ class ReplicationLayer(object):
time_now = self._clock.time_msec() time_now = self._clock.time_msec()
defer.returnValue((200, { defer.returnValue((200, {
"state": [p.get_pdu_json(time_now) for p in res_pdus["state"]], "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
"auth_chain": [p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]], "auth_chain": [
p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]
],
})) }))
@defer.inlineCallbacks @defer.inlineCallbacks
@ -436,7 +440,9 @@ class ReplicationLayer(object):
( (
200, 200,
{ {
"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus], "auth_chain": [
a.get_pdu_json(time_now) for a in auth_pdus
],
} }
) )
) )
@ -457,7 +463,7 @@ class ReplicationLayer(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def send_join(self, destination, pdu): def send_join(self, destination, pdu):
time_now = self._clock.time_msec() time_now = self._clock.time_msec()
_, content = yield self.transport_layer.send_join( _, content = yield self.transport_layer.send_join(
destination, destination,
pdu.room_id, pdu.room_id,
@ -475,11 +481,17 @@ class ReplicationLayer(object):
# FIXME: We probably want to do something with the auth_chain given # FIXME: We probably want to do something with the auth_chain given
# to us # to us
# auth_chain = [ auth_chain = [
# Pdu(outlier=True, **p) for p in content.get("auth_chain", []) self.event_from_pdu_json(p, outlier=True)
# ] for p in content.get("auth_chain", [])
]
defer.returnValue(state) auth_chain.sort(key=lambda e: e.depth)
defer.returnValue({
"state": state,
"auth_chain": auth_chain,
})
@defer.inlineCallbacks @defer.inlineCallbacks
def send_invite(self, destination, context, event_id, pdu): def send_invite(self, destination, context, event_id, pdu):
@ -498,13 +510,15 @@ class ReplicationLayer(object):
defer.returnValue(self.event_from_pdu_json(pdu_dict)) defer.returnValue(self.event_from_pdu_json(pdu_dict))
@log_function @log_function
def _get_persisted_pdu(self, origin, event_id): def _get_persisted_pdu(self, origin, event_id, do_auth=True):
""" Get a PDU from the database with given origin and id. """ Get a PDU from the database with given origin and id.
Returns: Returns:
Deferred: Results in a `Pdu`. Deferred: Results in a `Pdu`.
""" """
return self.handler.get_persisted_pdu(origin, event_id) return self.handler.get_persisted_pdu(
origin, event_id, do_auth=do_auth
)
def _transaction_from_pdus(self, pdu_list): def _transaction_from_pdus(self, pdu_list):
"""Returns a new Transaction containing the given PDUs suitable for """Returns a new Transaction containing the given PDUs suitable for
@ -523,7 +537,9 @@ class ReplicationLayer(object):
@log_function @log_function
def _handle_new_pdu(self, origin, pdu, backfilled=False): def _handle_new_pdu(self, origin, pdu, backfilled=False):
# We reprocess pdus when we have seen them only as outliers # We reprocess pdus when we have seen them only as outliers
existing = yield self._get_persisted_pdu(origin, pdu.event_id) existing = yield self._get_persisted_pdu(
origin, pdu.event_id, do_auth=False
)
if existing and (not existing.outlier or pdu.outlier): if existing and (not existing.outlier or pdu.outlier):
logger.debug("Already seen pdu %s", pdu.event_id) logger.debug("Already seen pdu %s", pdu.event_id)
@ -532,6 +548,36 @@ class ReplicationLayer(object):
state = None state = None
# We need to make sure we have all the auth events.
for e_id, _ in pdu.auth_events:
exists = yield self._get_persisted_pdu(
origin,
e_id,
do_auth=False
)
if not exists:
try:
logger.debug(
"_handle_new_pdu fetch missing auth event %s from %s",
e_id,
origin,
)
yield self.get_pdu(
origin,
event_id=e_id,
outlier=True,
)
logger.debug("Processed pdu %s", e_id)
except:
logger.warn(
"Failed to get auth event %s from %s",
e_id,
origin
)
# Get missing pdus if necessary. # Get missing pdus if necessary.
if not pdu.outlier: if not pdu.outlier:
# We only backfill backwards to the min depth. # We only backfill backwards to the min depth.
@ -539,16 +585,28 @@ class ReplicationLayer(object):
pdu.room_id pdu.room_id
) )
logger.debug(
"_handle_new_pdu min_depth for %s: %d",
pdu.room_id, min_depth
)
if min_depth and pdu.depth > min_depth: if min_depth and pdu.depth > min_depth:
for event_id, hashes in pdu.prev_events: for event_id, hashes in pdu.prev_events:
exists = yield self._get_persisted_pdu(origin, event_id) exists = yield self._get_persisted_pdu(
origin,
event_id,
do_auth=False
)
if not exists: if not exists:
logger.debug("Requesting pdu %s", event_id) logger.debug(
"_handle_new_pdu requesting pdu %s",
event_id
)
try: try:
yield self.get_pdu( yield self.get_pdu(
pdu.origin, origin,
event_id=event_id, event_id=event_id,
) )
logger.debug("Processed pdu %s", event_id) logger.debug("Processed pdu %s", event_id)
@ -558,6 +616,10 @@ class ReplicationLayer(object):
else: else:
# We need to get the state at this event, since we have reached # We need to get the state at this event, since we have reached
# a backward extremity edge. # a backward extremity edge.
logger.debug(
"_handle_new_pdu getting state for %s",
pdu.room_id
)
state = yield self.get_state_for_context( state = yield self.get_state_for_context(
origin, pdu.room_id, pdu.event_id, origin, pdu.room_id, pdu.event_id,
) )
@ -649,7 +711,8 @@ class _TransactionQueue(object):
(pdu, deferred, order) (pdu, deferred, order)
) )
self._attempt_new_transaction(destination) with PreserveLoggingContext():
self._attempt_new_transaction(destination)
deferreds.append(deferred) deferreds.append(deferred)
@ -669,7 +732,9 @@ class _TransactionQueue(object):
deferred.errback(failure) deferred.errback(failure)
else: else:
logger.exception("Failed to send edu", failure) logger.exception("Failed to send edu", failure)
self._attempt_new_transaction(destination).addErrback(eb)
with PreserveLoggingContext():
self._attempt_new_transaction(destination).addErrback(eb)
return deferred return deferred

View File

@ -25,7 +25,6 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class Edu(JsonEncodedObject): class Edu(JsonEncodedObject):
""" An Edu represents a piece of data sent from one homeserver to another. """ An Edu represents a piece of data sent from one homeserver to another.

View File

@ -78,7 +78,7 @@ class BaseHandler(object):
if not suppress_auth: if not suppress_auth:
logger.debug("Authing...") logger.debug("Authing...")
self.auth.check(event, raises=True) self.auth.check(event, auth_events=event.old_state_events)
logger.debug("Authed") logger.debug("Authed")
else: else:
logger.debug("Suppressed auth.") logger.debug("Suppressed auth.")
@ -112,7 +112,7 @@ class BaseHandler(object):
event.destinations = list(destinations) event.destinations = list(destinations)
self.notifier.on_new_room_event(event, extra_users=extra_users) yield self.notifier.on_new_room_event(event, extra_users=extra_users)
federation_handler = self.hs.get_handlers().federation_handler federation_handler = self.hs.get_handlers().federation_handler
yield federation_handler.handle_new_event(event, snapshot) yield federation_handler.handle_new_event(event, snapshot)

View File

@ -17,7 +17,7 @@
from twisted.internet import defer from twisted.internet import defer
from ._base import BaseHandler from ._base import BaseHandler
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError, Codes, CodeMessageException
from synapse.api.events.room import RoomAliasesEvent from synapse.api.events.room import RoomAliasesEvent
import logging import logging
@ -84,22 +84,32 @@ class DirectoryHandler(BaseHandler):
room_id = result.room_id room_id = result.room_id
servers = result.servers servers = result.servers
else: else:
result = yield self.federation.make_query( try:
destination=room_alias.domain, result = yield self.federation.make_query(
query_type="directory", destination=room_alias.domain,
args={ query_type="directory",
"room_alias": room_alias.to_string(), args={
}, "room_alias": room_alias.to_string(),
retry_on_dns_fail=False, },
) retry_on_dns_fail=False,
)
except CodeMessageException as e:
logging.warn("Error retrieving alias")
if e.code == 404:
result = None
else:
raise
if result and "room_id" in result and "servers" in result: if result and "room_id" in result and "servers" in result:
room_id = result["room_id"] room_id = result["room_id"]
servers = result["servers"] servers = result["servers"]
if not room_id: if not room_id:
defer.returnValue({}) raise SynapseError(
return 404,
"Room alias %r not found" % (room_alias.to_string(),),
Codes.NOT_FOUND
)
extra_servers = yield self.store.get_joined_hosts_for_room(room_id) extra_servers = yield self.store.get_joined_hosts_for_room(room_id)
servers = list(set(extra_servers) | set(servers)) servers = list(set(extra_servers) | set(servers))
@ -128,8 +138,11 @@ class DirectoryHandler(BaseHandler):
"servers": result.servers, "servers": result.servers,
}) })
else: else:
raise SynapseError(404, "Room alias \"%s\" not found", room_alias) raise SynapseError(
404,
"Room alias %r not found" % (room_alias.to_string(),),
Codes.NOT_FOUND
)
@defer.inlineCallbacks @defer.inlineCallbacks
def send_room_alias_update_event(self, user_id, room_id): def send_room_alias_update_event(self, user_id, room_id):

View File

@ -56,7 +56,7 @@ class EventStreamHandler(BaseHandler):
self.clock.cancel_call_later( self.clock.cancel_call_later(
self._stop_timer_per_user.pop(auth_user)) self._stop_timer_per_user.pop(auth_user))
else: else:
self.distributor.fire( yield self.distributor.fire(
"started_user_eventstream", auth_user "started_user_eventstream", auth_user
) )
self._streams_per_user[auth_user] += 1 self._streams_per_user[auth_user] += 1
@ -65,8 +65,10 @@ class EventStreamHandler(BaseHandler):
pagin_config.from_token = None pagin_config.from_token = None
rm_handler = self.hs.get_handlers().room_member_handler rm_handler = self.hs.get_handlers().room_member_handler
logger.debug("BETA")
room_ids = yield rm_handler.get_rooms_for_user(auth_user) room_ids = yield rm_handler.get_rooms_for_user(auth_user)
logger.debug("ALPHA")
with PreserveLoggingContext(): with PreserveLoggingContext():
events, tokens = yield self.notifier.get_events_for( events, tokens = yield self.notifier.get_events_for(
auth_user, room_ids, pagin_config, timeout auth_user, room_ids, pagin_config, timeout
@ -93,7 +95,7 @@ class EventStreamHandler(BaseHandler):
logger.debug( logger.debug(
"_later stopped_user_eventstream %s", auth_user "_later stopped_user_eventstream %s", auth_user
) )
self.distributor.fire( yield self.distributor.fire(
"stopped_user_eventstream", auth_user "stopped_user_eventstream", auth_user
) )
del self._stop_timer_per_user[auth_user] del self._stop_timer_per_user[auth_user]

View File

@ -24,7 +24,8 @@ from synapse.api.constants import Membership
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.crypto.event_signing import ( from synapse.crypto.event_signing import (
compute_event_signature, check_event_content_hash compute_event_signature, check_event_content_hash,
add_hashes_and_signatures,
) )
from syutil.jsonutil import encode_canonical_json from syutil.jsonutil import encode_canonical_json
@ -122,7 +123,8 @@ class FederationHandler(BaseHandler):
event.origin, redacted_pdu_json event.origin, redacted_pdu_json
) )
except SynapseError as e: except SynapseError as e:
logger.warn("Signature check failed for %s redacted to %s", logger.warn(
"Signature check failed for %s redacted to %s",
encode_canonical_json(pdu.get_pdu_json()), encode_canonical_json(pdu.get_pdu_json()),
encode_canonical_json(redacted_pdu_json), encode_canonical_json(redacted_pdu_json),
) )
@ -140,15 +142,27 @@ class FederationHandler(BaseHandler):
) )
event = redacted_event event = redacted_event
is_new_state = yield self.state_handler.annotate_event_with_state(
event,
old_state=state
)
logger.debug("Event: %s", event) logger.debug("Event: %s", event)
# FIXME (erikj): Awful hack to make the case where we are not currently
# in the room work
current_state = None
if state:
is_in_room = yield self.auth.check_host_in_room(
event.room_id,
self.server_name
)
if not is_in_room:
logger.debug("Got event for room we're not in.")
current_state = state
try: try:
self.auth.check(event, raises=True) yield self._handle_new_event(
event,
state=state,
backfilled=backfilled,
current_state=current_state,
)
except AuthError as e: except AuthError as e:
raise FederationError( raise FederationError(
"ERROR", "ERROR",
@ -157,43 +171,14 @@ class FederationHandler(BaseHandler):
affected=event.event_id, affected=event.event_id,
) )
is_new_state = is_new_state and not backfilled
# TODO: Implement something in federation that allows us to
# respond to PDU.
yield self.store.persist_event(
event,
backfilled,
is_new_state=is_new_state
)
room = yield self.store.get_room(event.room_id) room = yield self.store.get_room(event.room_id)
if not room: if not room:
# Huh, let's try and get the current state yield self.store.store_room(
try: room_id=event.room_id,
yield self.replication_layer.get_state_for_context( room_creator_user_id="",
event.origin, event.room_id, event.event_id, is_public=False,
) )
hosts = yield self.store.get_joined_hosts_for_room(
event.room_id
)
if self.hs.hostname in hosts:
try:
yield self.store.store_room(
room_id=event.room_id,
room_creator_user_id="",
is_public=False,
)
except:
pass
except:
logger.exception(
"Failed to get current state for room %s",
event.room_id
)
if not backfilled: if not backfilled:
extra_users = [] extra_users = []
@ -209,7 +194,7 @@ class FederationHandler(BaseHandler):
if event.type == RoomMemberEvent.TYPE: if event.type == RoomMemberEvent.TYPE:
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
user = self.hs.parse_userid(event.state_key) user = self.hs.parse_userid(event.state_key)
self.distributor.fire( yield self.distributor.fire(
"user_joined_room", user=user, room_id=event.room_id "user_joined_room", user=user, room_id=event.room_id
) )
@ -254,6 +239,8 @@ class FederationHandler(BaseHandler):
pdu=event pdu=event
) )
defer.returnValue(pdu) defer.returnValue(pdu)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -275,6 +262,8 @@ class FederationHandler(BaseHandler):
We suspend processing of any received events from this room until we We suspend processing of any received events from this room until we
have finished processing the join. have finished processing the join.
""" """
logger.debug("Joining %s to %s", joinee, room_id)
pdu = yield self.replication_layer.make_join( pdu = yield self.replication_layer.make_join(
target_host, target_host,
room_id, room_id,
@ -297,19 +286,28 @@ class FederationHandler(BaseHandler):
try: try:
event.event_id = self.event_factory.create_event_id() event.event_id = self.event_factory.create_event_id()
event.origin = self.hs.hostname
event.content = content event.content = content
state = yield self.replication_layer.send_join( if not hasattr(event, "signatures"):
event.signatures = {}
add_hashes_and_signatures(
event,
self.hs.hostname,
self.hs.config.signing_key[0],
)
ret = yield self.replication_layer.send_join(
target_host, target_host,
event event
) )
logger.debug("do_invite_join state: %s", state) state = ret["state"]
auth_chain = ret["auth_chain"]
yield self.state_handler.annotate_event_with_state( logger.debug("do_invite_join auth_chain: %s", auth_chain)
event, logger.debug("do_invite_join state: %s", state)
old_state=state
)
logger.debug("do_invite_join event: %s", event) logger.debug("do_invite_join event: %s", event)
@ -323,34 +321,41 @@ class FederationHandler(BaseHandler):
# FIXME # FIXME
pass pass
for e in auth_chain:
e.outlier = True
yield self._handle_new_event(e)
yield self.notifier.on_new_room_event(
e, extra_users=[joinee]
)
for e in state: for e in state:
# FIXME: Auth these. # FIXME: Auth these.
e.outlier = True e.outlier = True
yield self._handle_new_event(e)
yield self.state_handler.annotate_event_with_state( yield self.notifier.on_new_room_event(
e, e, extra_users=[joinee]
) )
yield self.store.persist_event( yield self._handle_new_event(
e,
backfilled=False,
is_new_state=True
)
yield self.store.persist_event(
event, event,
backfilled=False, state=state,
is_new_state=True current_state=state
) )
yield self.notifier.on_new_room_event(
event, extra_users=[joinee]
)
logger.debug("Finished joining %s to %s", joinee, room_id)
finally: finally:
room_queue = self.room_queues[room_id] room_queue = self.room_queues[room_id]
del self.room_queues[room_id] del self.room_queues[room_id]
for p in room_queue: for p in room_queue:
try: try:
yield self.on_receive_pdu(p, backfilled=False) self.on_receive_pdu(p, backfilled=False)
except: except:
pass logger.exception("Couldn't handle pdu")
defer.returnValue(True) defer.returnValue(True)
@ -374,7 +379,7 @@ class FederationHandler(BaseHandler):
yield self.state_handler.annotate_event_with_state(event) yield self.state_handler.annotate_event_with_state(event)
yield self.auth.add_auth_events(event) yield self.auth.add_auth_events(event)
self.auth.check(event, raises=True) self.auth.check(event, auth_events=event.old_state_events)
pdu = event pdu = event
@ -390,16 +395,7 @@ class FederationHandler(BaseHandler):
event.outlier = False event.outlier = False
is_new_state = yield self.state_handler.annotate_event_with_state(event) yield self._handle_new_event(event)
self.auth.check(event, raises=True)
# FIXME (erikj): All this is duplicated above :(
yield self.store.persist_event(
event,
backfilled=False,
is_new_state=is_new_state
)
extra_users = [] extra_users = []
if event.type == RoomMemberEvent.TYPE: if event.type == RoomMemberEvent.TYPE:
@ -412,9 +408,9 @@ class FederationHandler(BaseHandler):
) )
if event.type == RoomMemberEvent.TYPE: if event.type == RoomMemberEvent.TYPE:
if event.membership == Membership.JOIN: if event.content["membership"] == Membership.JOIN:
user = self.hs.parse_userid(event.state_key) user = self.hs.parse_userid(event.state_key)
self.distributor.fire( yield self.distributor.fire(
"user_joined_room", user=user, room_id=event.room_id "user_joined_room", user=user, room_id=event.room_id
) )
@ -527,7 +523,7 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def get_persisted_pdu(self, origin, event_id): def get_persisted_pdu(self, origin, event_id, do_auth=True):
""" Get a PDU from the database with given origin and id. """ Get a PDU from the database with given origin and id.
Returns: Returns:
@ -539,12 +535,13 @@ class FederationHandler(BaseHandler):
) )
if event: if event:
in_room = yield self.auth.check_host_in_room( if do_auth:
event.room_id, in_room = yield self.auth.check_host_in_room(
origin event.room_id,
) origin
if not in_room: )
raise AuthError(403, "Host not in room.") if not in_room:
raise AuthError(403, "Host not in room.")
defer.returnValue(event) defer.returnValue(event)
else: else:
@ -562,3 +559,65 @@ class FederationHandler(BaseHandler):
) )
while waiters: while waiters:
waiters.pop().callback(None) waiters.pop().callback(None)
@defer.inlineCallbacks
def _handle_new_event(self, event, state=None, backfilled=False,
current_state=None):
if state:
for s in state:
yield self._handle_new_event(s)
is_new_state = yield self.state_handler.annotate_event_with_state(
event,
old_state=state
)
if event.old_state_events:
known_ids = set(
[s.event_id for s in event.old_state_events.values()]
)
for e_id, _ in event.auth_events:
if e_id not in known_ids:
e = yield self.store.get_event(
e_id,
allow_none=True,
)
if not e:
# TODO: Do some conflict res to make sure that we're
# not the ones who are wrong.
logger.info(
"Rejecting %s as %s not in %s",
event.event_id, e_id, known_ids,
)
raise AuthError(403, "Auth events are stale")
auth_events = event.old_state_events
else:
# We need to get the auth events from somewhere.
# TODO: Don't just hit the DBs?
auth_events = {}
for e_id, _ in event.auth_events:
e = yield self.store.get_event(
e_id,
allow_none=True,
)
if not e:
raise AuthError(
403,
"Can't find auth event %s." % (e_id, )
)
auth_events[(e.type, e.state_key)] = e
self.auth.check(event, auth_events=auth_events)
yield self.store.persist_event(
event,
backfilled=backfilled,
is_new_state=(is_new_state and not backfilled),
current_state=current_state,
)

View File

@ -17,13 +17,12 @@ from twisted.internet import defer
from ._base import BaseHandler from ._base import BaseHandler
from synapse.api.errors import LoginError, Codes from synapse.api.errors import LoginError, Codes
from synapse.http.client import IdentityServerHttpClient from synapse.http.client import SimpleHttpClient
from synapse.util.emailutils import EmailException from synapse.util.emailutils import EmailException
import synapse.util.emailutils as emailutils import synapse.util.emailutils as emailutils
import bcrypt import bcrypt
import logging import logging
import urllib
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -97,10 +96,16 @@ class LoginHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _query_email(self, email): def _query_email(self, email):
httpCli = IdentityServerHttpClient(self.hs) httpCli = SimpleHttpClient(self.hs)
data = yield httpCli.get_json( data = yield httpCli.get_json(
'matrix.org:8090', # TODO FIXME This should be configurable. # TODO FIXME This should be configurable.
"/_matrix/identity/api/v1/lookup?medium=email&address=" + # XXX: ID servers need to use HTTPS
"%s" % urllib.quote(email) "http://%s%s" % (
"matrix.org:8090", "/_matrix/identity/api/v1/lookup"
),
{
'medium': 'email',
'address': email
}
) )
defer.returnValue(data) defer.returnValue(data)

View File

@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.api.constants import Membership from synapse.api.constants import Membership
from synapse.api.errors import RoomError from synapse.api.errors import RoomError
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.util.logcontext import PreserveLoggingContext
from ._base import BaseHandler from ._base import BaseHandler
import logging import logging
@ -86,9 +87,10 @@ class MessageHandler(BaseHandler):
event, snapshot, suppress_auth=suppress_auth event, snapshot, suppress_auth=suppress_auth
) )
self.hs.get_handlers().presence_handler.bump_presence_active_time( with PreserveLoggingContext():
user self.hs.get_handlers().presence_handler.bump_presence_active_time(
) user
)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_messages(self, user_id=None, room_id=None, pagin_config=None, def get_messages(self, user_id=None, room_id=None, pagin_config=None,
@ -241,7 +243,7 @@ class MessageHandler(BaseHandler):
public_room_ids = [r["room_id"] for r in public_rooms] public_room_ids = [r["room_id"] for r in public_rooms]
limit = pagin_config.limit limit = pagin_config.limit
if not limit: if limit is None:
limit = 10 limit = 10
for event in room_list: for event in room_list:
@ -296,7 +298,7 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def room_initial_sync(self, user_id, room_id, pagin_config=None, def room_initial_sync(self, user_id, room_id, pagin_config=None,
feedback=False): feedback=False):
yield self.auth.check_joined_room(room_id, user_id) yield self.auth.check_joined_room(room_id, user_id)
# TODO(paul): I wish I was called with user objects not user_id # TODO(paul): I wish I was called with user objects not user_id
@ -304,7 +306,7 @@ class MessageHandler(BaseHandler):
auth_user = self.hs.parse_userid(user_id) auth_user = self.hs.parse_userid(user_id)
# TODO: These concurrently # TODO: These concurrently
state_tuples = yield self.store.get_current_state(room_id) state_tuples = yield self.state_handler.get_current_state(room_id)
state = [self.hs.serialize_event(x) for x in state_tuples] state = [self.hs.serialize_event(x) for x in state_tuples]
member_event = (yield self.store.get_room_member( member_event = (yield self.store.get_room_member(
@ -340,8 +342,8 @@ class MessageHandler(BaseHandler):
) )
presence.append(member_presence) presence.append(member_presence)
except Exception: except Exception:
logger.exception("Failed to get member presence of %r", logger.exception(
m.user_id "Failed to get member presence of %r", m.user_id
) )
defer.returnValue({ defer.returnValue({

View File

@ -19,6 +19,7 @@ from synapse.api.errors import SynapseError, AuthError
from synapse.api.constants import PresenceState from synapse.api.constants import PresenceState
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext
from ._base import BaseHandler from ._base import BaseHandler
@ -142,7 +143,7 @@ class PresenceHandler(BaseHandler):
return UserPresenceCache() return UserPresenceCache()
def registered_user(self, user): def registered_user(self, user):
self.store.create_presence(user.localpart) return self.store.create_presence(user.localpart)
@defer.inlineCallbacks @defer.inlineCallbacks
def is_presence_visible(self, observer_user, observed_user): def is_presence_visible(self, observer_user, observed_user):
@ -241,14 +242,12 @@ class PresenceHandler(BaseHandler):
was_level = self.STATE_LEVELS[statuscache.get_state()["presence"]] was_level = self.STATE_LEVELS[statuscache.get_state()["presence"]]
now_level = self.STATE_LEVELS[state["presence"]] now_level = self.STATE_LEVELS[state["presence"]]
yield defer.DeferredList([ yield self.store.set_presence_state(
self.store.set_presence_state( target_user.localpart, state_to_store
target_user.localpart, state_to_store )
), yield self.distributor.fire(
self.distributor.fire( "collect_presencelike_data", target_user, state
"collect_presencelike_data", target_user, state )
),
])
if now_level > was_level: if now_level > was_level:
state["last_active"] = self.clock.time_msec() state["last_active"] = self.clock.time_msec()
@ -256,14 +255,15 @@ class PresenceHandler(BaseHandler):
now_online = state["presence"] != PresenceState.OFFLINE now_online = state["presence"] != PresenceState.OFFLINE
was_polling = target_user in self._user_cachemap was_polling = target_user in self._user_cachemap
if now_online and not was_polling: with PreserveLoggingContext():
self.start_polling_presence(target_user, state=state) if now_online and not was_polling:
elif not now_online and was_polling: self.start_polling_presence(target_user, state=state)
self.stop_polling_presence(target_user) elif not now_online and was_polling:
self.stop_polling_presence(target_user)
# TODO(paul): perform a presence push as part of start/stop poll so # TODO(paul): perform a presence push as part of start/stop poll so
# we don't have to do this all the time # we don't have to do this all the time
self.changed_presencelike_data(target_user, state) self.changed_presencelike_data(target_user, state)
def bump_presence_active_time(self, user, now=None): def bump_presence_active_time(self, user, now=None):
if now is None: if now is None:
@ -277,7 +277,7 @@ class PresenceHandler(BaseHandler):
self._user_cachemap_latest_serial += 1 self._user_cachemap_latest_serial += 1
statuscache.update(state, serial=self._user_cachemap_latest_serial) statuscache.update(state, serial=self._user_cachemap_latest_serial)
self.push_presence(user, statuscache=statuscache) return self.push_presence(user, statuscache=statuscache)
@log_function @log_function
def started_user_eventstream(self, user): def started_user_eventstream(self, user):
@ -381,8 +381,10 @@ class PresenceHandler(BaseHandler):
yield self.store.set_presence_list_accepted( yield self.store.set_presence_list_accepted(
observer_user.localpart, observed_user.to_string() observer_user.localpart, observed_user.to_string()
) )
with PreserveLoggingContext():
self.start_polling_presence(observer_user, target_user=observed_user) self.start_polling_presence(
observer_user, target_user=observed_user
)
@defer.inlineCallbacks @defer.inlineCallbacks
def deny_presence(self, observed_user, observer_user): def deny_presence(self, observed_user, observer_user):
@ -401,7 +403,10 @@ class PresenceHandler(BaseHandler):
observer_user.localpart, observed_user.to_string() observer_user.localpart, observed_user.to_string()
) )
self.stop_polling_presence(observer_user, target_user=observed_user) with PreserveLoggingContext():
self.stop_polling_presence(
observer_user, target_user=observed_user
)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_presence_list(self, observer_user, accepted=None): def get_presence_list(self, observer_user, accepted=None):
@ -710,7 +715,8 @@ class PresenceHandler(BaseHandler):
if not self._remote_sendmap[user]: if not self._remote_sendmap[user]:
del self._remote_sendmap[user] del self._remote_sendmap[user]
yield defer.DeferredList(deferreds) with PreserveLoggingContext():
yield defer.DeferredList(deferreds)
@defer.inlineCallbacks @defer.inlineCallbacks
def push_update_to_local_and_remote(self, observed_user, statuscache, def push_update_to_local_and_remote(self, observed_user, statuscache,

View File

@ -17,6 +17,7 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, AuthError, CodeMessageException from synapse.api.errors import SynapseError, AuthError, CodeMessageException
from synapse.api.constants import Membership from synapse.api.constants import Membership
from synapse.util.logcontext import PreserveLoggingContext
from ._base import BaseHandler from ._base import BaseHandler
@ -46,7 +47,7 @@ class ProfileHandler(BaseHandler):
) )
def registered_user(self, user): def registered_user(self, user):
self.store.create_profile(user.localpart) return self.store.create_profile(user.localpart)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_displayname(self, target_user): def get_displayname(self, target_user):
@ -152,13 +153,14 @@ class ProfileHandler(BaseHandler):
if not user.is_mine: if not user.is_mine:
defer.returnValue(None) defer.returnValue(None)
(displayname, avatar_url) = yield defer.gatherResults( with PreserveLoggingContext():
[ (displayname, avatar_url) = yield defer.gatherResults(
self.store.get_profile_displayname(user.localpart), [
self.store.get_profile_avatar_url(user.localpart), self.store.get_profile_displayname(user.localpart),
], self.store.get_profile_avatar_url(user.localpart),
consumeErrors=True ],
) consumeErrors=True
)
state["displayname"] = displayname state["displayname"] = displayname
state["avatar_url"] = avatar_url state["avatar_url"] = avatar_url

View File

@ -22,7 +22,7 @@ from synapse.api.errors import (
) )
from ._base import BaseHandler from ._base import BaseHandler
import synapse.util.stringutils as stringutils import synapse.util.stringutils as stringutils
from synapse.http.client import IdentityServerHttpClient from synapse.http.client import SimpleHttpClient
from synapse.http.client import CaptchaServerHttpClient from synapse.http.client import CaptchaServerHttpClient
import base64 import base64
@ -69,7 +69,7 @@ class RegistrationHandler(BaseHandler):
password_hash=password_hash password_hash=password_hash
) )
self.distributor.fire("registered_user", user) yield self.distributor.fire("registered_user", user)
else: else:
# autogen a random user ID # autogen a random user ID
attempts = 0 attempts = 0
@ -133,7 +133,7 @@ class RegistrationHandler(BaseHandler):
if not threepid: if not threepid:
raise RegistrationError(400, "Couldn't validate 3pid") raise RegistrationError(400, "Couldn't validate 3pid")
logger.info("got threepid medium %s address %s", logger.info("got threepid with medium '%s' and address '%s'",
threepid['medium'], threepid['address']) threepid['medium'], threepid['address'])
@defer.inlineCallbacks @defer.inlineCallbacks
@ -159,7 +159,7 @@ class RegistrationHandler(BaseHandler):
def _threepid_from_creds(self, creds): def _threepid_from_creds(self, creds):
# TODO: get this from the homeserver rather than creating a new one for # TODO: get this from the homeserver rather than creating a new one for
# each request # each request
httpCli = IdentityServerHttpClient(self.hs) httpCli = SimpleHttpClient(self.hs)
# XXX: make this configurable! # XXX: make this configurable!
trustedIdServers = ['matrix.org:8090'] trustedIdServers = ['matrix.org:8090']
if not creds['idServer'] in trustedIdServers: if not creds['idServer'] in trustedIdServers:
@ -167,8 +167,11 @@ class RegistrationHandler(BaseHandler):
'credentials', creds['idServer']) 'credentials', creds['idServer'])
defer.returnValue(None) defer.returnValue(None)
data = yield httpCli.get_json( data = yield httpCli.get_json(
creds['idServer'], # XXX: This should be HTTPS
"/_matrix/identity/api/v1/3pid/getValidated3pid", "http://%s%s" % (
creds['idServer'],
"/_matrix/identity/api/v1/3pid/getValidated3pid"
),
{'sid': creds['sid'], 'clientSecret': creds['clientSecret']} {'sid': creds['sid'], 'clientSecret': creds['clientSecret']}
) )
@ -178,16 +181,21 @@ class RegistrationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _bind_threepid(self, creds, mxid): def _bind_threepid(self, creds, mxid):
httpCli = IdentityServerHttpClient(self.hs) yield
logger.debug("binding threepid")
httpCli = SimpleHttpClient(self.hs)
data = yield httpCli.post_urlencoded_get_json( data = yield httpCli.post_urlencoded_get_json(
creds['idServer'], # XXX: Change when ID servers are all HTTPS
"/_matrix/identity/api/v1/3pid/bind", "http://%s%s" % (
creds['idServer'], "/_matrix/identity/api/v1/3pid/bind"
),
{ {
'sid': creds['sid'], 'sid': creds['sid'],
'clientSecret': creds['clientSecret'], 'clientSecret': creds['clientSecret'],
'mxid': mxid, 'mxid': mxid,
} }
) )
logger.debug("bound threepid")
defer.returnValue(data) defer.returnValue(data)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -215,10 +223,7 @@ class RegistrationHandler(BaseHandler):
# each request # each request
client = CaptchaServerHttpClient(self.hs) client = CaptchaServerHttpClient(self.hs)
data = yield client.post_urlencoded_get_raw( data = yield client.post_urlencoded_get_raw(
"www.google.com:80", "http://www.google.com:80/recaptcha/api/verify",
"/recaptcha/api/verify",
# twisted dislikes google's response, no content length.
accept_partial=True,
args={ args={
'privatekey': private_key, 'privatekey': private_key,
'remoteip': ip_addr, 'remoteip': ip_addr,

View File

@ -178,7 +178,9 @@ class RoomCreationHandler(BaseHandler):
if room_alias: if room_alias:
result["room_alias"] = room_alias.to_string() result["room_alias"] = room_alias.to_string()
directory_handler.send_room_alias_update_event(user_id, room_id) yield directory_handler.send_room_alias_update_event(
user_id, room_id
)
defer.returnValue(result) defer.returnValue(result)
@ -211,7 +213,6 @@ class RoomCreationHandler(BaseHandler):
**event_keys **event_keys
) )
power_levels_event = self.event_factory.create_event( power_levels_event = self.event_factory.create_event(
etype=RoomPowerLevelsEvent.TYPE, etype=RoomPowerLevelsEvent.TYPE,
content={ content={
@ -480,7 +481,7 @@ class RoomMemberHandler(BaseHandler):
) )
user = self.hs.parse_userid(event.user_id) user = self.hs.parse_userid(event.user_id)
self.distributor.fire( yield self.distributor.fire(
"user_joined_room", user=user, room_id=room_id "user_joined_room", user=user, room_id=room_id
) )

View File

@ -15,308 +15,45 @@
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from twisted.internet.error import DNSLookupError
from twisted.web.client import ( from twisted.web.client import (
_AgentBase, _URI, readBody, FileBodyProducer, PartialDownloadError Agent, readBody, FileBodyProducer, PartialDownloadError
) )
from twisted.web.http_headers import Headers from twisted.web.http_headers import Headers
from synapse.http.endpoint import matrix_endpoint
from synapse.util.async import sleep
from synapse.util.logcontext import PreserveLoggingContext
from syutil.jsonutil import encode_canonical_json
from synapse.api.errors import CodeMessageException, SynapseError
from syutil.crypto.jsonsign import sign_json
from StringIO import StringIO from StringIO import StringIO
import json import json
import logging import logging
import urllib import urllib
import urlparse
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class MatrixHttpAgent(_AgentBase): class SimpleHttpClient(object):
"""
def __init__(self, reactor, pool=None): A simple, no-frills HTTP client with methods that wrap up common ways of
_AgentBase.__init__(self, reactor, pool) using HTTP in Matrix
def request(self, destination, endpoint, method, path, params, query,
headers, body_producer):
host = b""
port = 0
fragment = b""
parsed_URI = _URI(b"http", destination, host, port, path, params,
query, fragment)
# Set the connection pool key to be the destination.
key = destination
return self._requestWithEndpoint(key, endpoint, method, parsed_URI,
headers, body_producer,
parsed_URI.originForm)
class BaseHttpClient(object):
"""Base class for HTTP clients using twisted.
""" """
def __init__(self, hs): def __init__(self, hs):
self.agent = MatrixHttpAgent(reactor)
self.hs = hs self.hs = hs
# The default context factory in Twisted 14.0.0 (which we require) is
# BrowserLikePolicyForHTTPS which will do regular cert validation
# 'like a browser'
self.agent = Agent(reactor)
@defer.inlineCallbacks @defer.inlineCallbacks
def _create_request(self, destination, method, path_bytes, def post_urlencoded_get_json(self, uri, args={}):
body_callback, headers_dict={}, param_bytes=b"",
query_bytes=b"", retry_on_dns_fail=True):
""" Creates and sends a request to the given url
"""
headers_dict[b"User-Agent"] = [b"Synapse"]
headers_dict[b"Host"] = [destination]
url_bytes = urlparse.urlunparse(
("", "", path_bytes, param_bytes, query_bytes, "",)
)
logger.debug("Sending request to %s: %s %s",
destination, method, url_bytes)
logger.debug(
"Types: %s",
[
type(destination), type(method), type(path_bytes),
type(param_bytes),
type(query_bytes)
]
)
retries_left = 5
endpoint = self._getEndpoint(reactor, destination)
while True:
producer = None
if body_callback:
producer = body_callback(method, url_bytes, headers_dict)
try:
with PreserveLoggingContext():
response = yield self.agent.request(
destination,
endpoint,
method,
path_bytes,
param_bytes,
query_bytes,
Headers(headers_dict),
producer
)
logger.debug("Got response to %s", method)
break
except Exception as e:
if not retry_on_dns_fail and isinstance(e, DNSLookupError):
logger.warn("DNS Lookup failed to %s with %s", destination,
e)
raise SynapseError(400, "Domain specified not found.")
logger.exception("Got error in _create_request")
_print_ex(e)
if retries_left:
yield sleep(2 ** (5 - retries_left))
retries_left -= 1
else:
raise
if 200 <= response.code < 300:
# We need to update the transactions table to say it was sent?
pass
else:
# :'(
# Update transactions table?
logger.error(
"Got response %d %s", response.code, response.phrase
)
raise CodeMessageException(
response.code, response.phrase
)
defer.returnValue(response)
class MatrixHttpClient(BaseHttpClient):
""" Wrapper around the twisted HTTP client api. Implements
Attributes:
agent (twisted.web.client.Agent): The twisted Agent used to send the
requests.
"""
RETRY_DNS_LOOKUP_FAILURES = "__retry_dns"
def __init__(self, hs):
self.signing_key = hs.config.signing_key[0]
self.server_name = hs.hostname
BaseHttpClient.__init__(self, hs)
def sign_request(self, destination, method, url_bytes, headers_dict,
content=None):
request = {
"method": method,
"uri": url_bytes,
"origin": self.server_name,
"destination": destination,
}
if content is not None:
request["content"] = content
request = sign_json(request, self.server_name, self.signing_key)
auth_headers = []
for key, sig in request["signatures"][self.server_name].items():
auth_headers.append(bytes(
"X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (
self.server_name, key, sig,
)
))
headers_dict[b"Authorization"] = auth_headers
@defer.inlineCallbacks
def put_json(self, destination, path, data={}, json_data_callback=None):
""" Sends the specifed json data using PUT
Args:
destination (str): The remote server to send the HTTP request
to.
path (str): The HTTP path.
data (dict): A dict containing the data that will be used as
the request body. This will be encoded as JSON.
json_data_callback (callable): A callable returning the dict to
use as the request body.
Returns:
Deferred: Succeeds when we get a 2xx HTTP response. The result
will be the decoded JSON body. On a 4xx or 5xx error response a
CodeMessageException is raised.
"""
if not json_data_callback:
def json_data_callback():
return data
def body_callback(method, url_bytes, headers_dict):
json_data = json_data_callback()
self.sign_request(
destination, method, url_bytes, headers_dict, json_data
)
producer = _JsonProducer(json_data)
return producer
response = yield self._create_request(
destination.encode("ascii"),
"PUT",
path.encode("ascii"),
body_callback=body_callback,
headers_dict={"Content-Type": ["application/json"]},
)
logger.debug("Getting resp body")
body = yield readBody(response)
logger.debug("Got resp body")
defer.returnValue((response.code, body))
@defer.inlineCallbacks
def get_json(self, destination, path, args={}, retry_on_dns_fail=True):
""" Get's some json from the given host homeserver and path
Args:
destination (str): The remote server to send the HTTP request
to.
path (str): The HTTP path.
args (dict): A dictionary used to create query strings, defaults to
None.
**Note**: The value of each key is assumed to be an iterable
and *not* a string.
Returns:
Deferred: Succeeds when we get *any* HTTP response.
The result of the deferred is a tuple of `(code, response)`,
where `response` is a dict representing the decoded JSON body.
"""
logger.debug("get_json args: %s", args)
encoded_args = {}
for k, vs in args.items():
if isinstance(vs, basestring):
vs = [vs]
encoded_args[k] = [v.encode("UTF-8") for v in vs]
query_bytes = urllib.urlencode(encoded_args, True)
logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail)
def body_callback(method, url_bytes, headers_dict):
self.sign_request(destination, method, url_bytes, headers_dict)
return None
response = yield self._create_request(
destination.encode("ascii"),
"GET",
path.encode("ascii"),
query_bytes=query_bytes,
body_callback=body_callback,
retry_on_dns_fail=retry_on_dns_fail
)
body = yield readBody(response)
defer.returnValue(json.loads(body))
def _getEndpoint(self, reactor, destination):
return matrix_endpoint(
reactor, destination, timeout=10,
ssl_context_factory=self.hs.tls_context_factory
)
class IdentityServerHttpClient(BaseHttpClient):
"""Separate HTTP client for talking to the Identity servers since they
don't use SRV records and talk x-www-form-urlencoded rather than JSON.
"""
def _getEndpoint(self, reactor, destination):
#TODO: This should be talking TLS
return matrix_endpoint(reactor, destination, timeout=10)
@defer.inlineCallbacks
def post_urlencoded_get_json(self, destination, path, args={}):
logger.debug("post_urlencoded_get_json args: %s", args) logger.debug("post_urlencoded_get_json args: %s", args)
query_bytes = urllib.urlencode(args, True) query_bytes = urllib.urlencode(args, True)
def body_callback(method, url_bytes, headers_dict): response = yield self.agent.request(
return FileBodyProducer(StringIO(query_bytes))
response = yield self._create_request(
destination.encode("ascii"),
"POST", "POST",
path.encode("ascii"), uri.encode("ascii"),
body_callback=body_callback, headers=Headers({
headers_dict={
"Content-Type": ["application/x-www-form-urlencoded"] "Content-Type": ["application/x-www-form-urlencoded"]
} }),
bodyProducer=FileBodyProducer(StringIO(query_bytes))
) )
body = yield readBody(response) body = yield readBody(response)
@ -324,13 +61,11 @@ class IdentityServerHttpClient(BaseHttpClient):
defer.returnValue(json.loads(body)) defer.returnValue(json.loads(body))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_json(self, destination, path, args={}, retry_on_dns_fail=True): def get_json(self, uri, args={}):
""" Get's some json from the given host homeserver and path """ Get's some json from the given host and path
Args: Args:
destination (str): The remote server to send the HTTP request uri (str): The URI to request, not including query parameters
to.
path (str): The HTTP path.
args (dict): A dictionary used to create query strings, defaults to args (dict): A dictionary used to create query strings, defaults to
None. None.
**Note**: The value of each key is assumed to be an iterable **Note**: The value of each key is assumed to be an iterable
@ -342,18 +77,15 @@ class IdentityServerHttpClient(BaseHttpClient):
The result of the deferred is a tuple of `(code, response)`, The result of the deferred is a tuple of `(code, response)`,
where `response` is a dict representing the decoded JSON body. where `response` is a dict representing the decoded JSON body.
""" """
logger.debug("get_json args: %s", args)
query_bytes = urllib.urlencode(args, True) yield
logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail) if len(args):
query_bytes = urllib.urlencode(args, True)
uri = "%s?%s" % (uri, query_bytes)
response = yield self._create_request( response = yield self.agent.request(
destination.encode("ascii"),
"GET", "GET",
path.encode("ascii"), uri.encode("ascii"),
query_bytes=query_bytes,
retry_on_dns_fail=retry_on_dns_fail,
body_callback=None
) )
body = yield readBody(response) body = yield readBody(response)
@ -361,38 +93,31 @@ class IdentityServerHttpClient(BaseHttpClient):
defer.returnValue(json.loads(body)) defer.returnValue(json.loads(body))
class CaptchaServerHttpClient(MatrixHttpClient): class CaptchaServerHttpClient(SimpleHttpClient):
"""Separate HTTP client for talking to google's captcha servers""" """
Separate HTTP client for talking to google's captcha servers
def _getEndpoint(self, reactor, destination): Only slightly special because accepts partial download responses
return matrix_endpoint(reactor, destination, timeout=10) """
@defer.inlineCallbacks @defer.inlineCallbacks
def post_urlencoded_get_raw(self, destination, path, accept_partial=False, def post_urlencoded_get_raw(self, url, args={}):
args={}):
query_bytes = urllib.urlencode(args, True) query_bytes = urllib.urlencode(args, True)
def body_callback(method, url_bytes, headers_dict): response = yield self.agent.request(
return FileBodyProducer(StringIO(query_bytes))
response = yield self._create_request(
destination.encode("ascii"),
"POST", "POST",
path.encode("ascii"), url.encode("ascii"),
body_callback=body_callback, bodyProducer=FileBodyProducer(StringIO(query_bytes)),
headers_dict={ headers=Headers({
"Content-Type": ["application/x-www-form-urlencoded"] "Content-Type": ["application/x-www-form-urlencoded"]
} })
) )
try: try:
body = yield readBody(response) body = yield readBody(response)
defer.returnValue(body) defer.returnValue(body)
except PartialDownloadError as e: except PartialDownloadError as e:
if accept_partial: # twisted dislikes google's response, no content length.
defer.returnValue(e.response) defer.returnValue(e.response)
else:
raise e
def _print_ex(e): def _print_ex(e):
@ -401,24 +126,3 @@ def _print_ex(e):
_print_ex(ex) _print_ex(ex)
else: else:
logger.exception(e) logger.exception(e)
class _JsonProducer(object):
""" Used by the twisted http client to create the HTTP body from json
"""
def __init__(self, jsn):
self.reset(jsn)
def reset(self, jsn):
self.body = encode_canonical_json(jsn)
self.length = len(self.body)
def startProducing(self, consumer):
consumer.write(self.body)
return defer.succeed(None)
def pauseProducing(self):
pass
def stopProducing(self):
pass

View File

@ -131,11 +131,13 @@ class ContentRepoResource(resource.Resource):
request.setHeader('Content-Type', content_type) request.setHeader('Content-Type', content_type)
# cache for at least a day. # cache for at least a day.
# XXX: we might want to turn this off for data we don't want to recommend # XXX: we might want to turn this off for data we don't want to
# caching as it's sensitive or private - or at least select private. # recommend caching as it's sensitive or private - or at least
# don't bother setting Expires as all our matrix clients are smart enough to # select private. don't bother setting Expires as all our matrix
# be happy with Cache-Control (right?) # clients are smart enough to be happy with Cache-Control (right?)
request.setHeader('Cache-Control', 'public,max-age=86400,s-maxage=86400') request.setHeader(
"Cache-Control", "public,max-age=86400,s-maxage=86400"
)
d = FileSender().beginFileTransfer(f, request) d = FileSender().beginFileTransfer(f, request)
@ -179,7 +181,7 @@ class ContentRepoResource(resource.Resource):
fname = yield self.map_request_to_name(request) fname = yield self.map_request_to_name(request)
# TODO I have a suspcious feeling this is just going to block # TODO I have a suspicious feeling this is just going to block
with open(fname, "wb") as f: with open(fname, "wb") as f:
f.write(request.content.read()) f.write(request.content.read())
@ -188,7 +190,7 @@ class ContentRepoResource(resource.Resource):
# FIXME: we can't assume what the repo's public mounted path is # FIXME: we can't assume what the repo's public mounted path is
# ...plus self-signed SSL won't work to remote clients anyway # ...plus self-signed SSL won't work to remote clients anyway
# ...and we can't assume that it's SSL anyway, as we might want to # ...and we can't assume that it's SSL anyway, as we might want to
# server it via the non-SSL listener... # serve it via the non-SSL listener...
url = "%s/_matrix/content/%s" % ( url = "%s/_matrix/content/%s" % (
self.external_addr, file_name self.external_addr, file_name
) )

View File

@ -27,8 +27,8 @@ import random
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def matrix_endpoint(reactor, destination, ssl_context_factory=None, def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
timeout=None): timeout=None):
"""Construct an endpoint for the given matrix destination. """Construct an endpoint for the given matrix destination.
Args: Args:

View File

@ -0,0 +1,308 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer, reactor
from twisted.internet.error import DNSLookupError
from twisted.web.client import readBody, _AgentBase, _URI
from twisted.web.http_headers import Headers
from synapse.http.endpoint import matrix_federation_endpoint
from synapse.util.async import sleep
from synapse.util.logcontext import PreserveLoggingContext
from syutil.jsonutil import encode_canonical_json
from synapse.api.errors import CodeMessageException, SynapseError
from syutil.crypto.jsonsign import sign_json
import json
import logging
import urllib
import urlparse
logger = logging.getLogger(__name__)
class MatrixFederationHttpAgent(_AgentBase):
def __init__(self, reactor, pool=None):
_AgentBase.__init__(self, reactor, pool)
def request(self, destination, endpoint, method, path, params, query,
headers, body_producer):
host = b""
port = 0
fragment = b""
parsed_URI = _URI(b"http", destination, host, port, path, params,
query, fragment)
# Set the connection pool key to be the destination.
key = destination
return self._requestWithEndpoint(key, endpoint, method, parsed_URI,
headers, body_producer,
parsed_URI.originForm)
class MatrixFederationHttpClient(object):
"""HTTP client used to talk to other homeservers over the federation
protocol. Send client certificates and signs requests.
Attributes:
agent (twisted.web.client.Agent): The twisted Agent used to send the
requests.
"""
def __init__(self, hs):
self.hs = hs
self.signing_key = hs.config.signing_key[0]
self.server_name = hs.hostname
self.agent = MatrixFederationHttpAgent(reactor)
@defer.inlineCallbacks
def _create_request(self, destination, method, path_bytes,
body_callback, headers_dict={}, param_bytes=b"",
query_bytes=b"", retry_on_dns_fail=True):
""" Creates and sends a request to the given url
"""
headers_dict[b"User-Agent"] = [b"Synapse"]
headers_dict[b"Host"] = [destination]
url_bytes = urlparse.urlunparse(
("", "", path_bytes, param_bytes, query_bytes, "",)
)
logger.debug("Sending request to %s: %s %s",
destination, method, url_bytes)
logger.debug(
"Types: %s",
[
type(destination), type(method), type(path_bytes),
type(param_bytes),
type(query_bytes)
]
)
retries_left = 5
endpoint = self._getEndpoint(reactor, destination)
while True:
producer = None
if body_callback:
producer = body_callback(method, url_bytes, headers_dict)
try:
with PreserveLoggingContext():
response = yield self.agent.request(
destination,
endpoint,
method,
path_bytes,
param_bytes,
query_bytes,
Headers(headers_dict),
producer
)
logger.debug("Got response to %s", method)
break
except Exception as e:
if not retry_on_dns_fail and isinstance(e, DNSLookupError):
logger.warn("DNS Lookup failed to %s with %s", destination,
e)
raise SynapseError(400, "Domain specified not found.")
logger.exception("Got error in _create_request")
_print_ex(e)
if retries_left:
yield sleep(2 ** (5 - retries_left))
retries_left -= 1
else:
raise
if 200 <= response.code < 300:
# We need to update the transactions table to say it was sent?
pass
else:
# :'(
# Update transactions table?
logger.error(
"Got response %d %s", response.code, response.phrase
)
raise CodeMessageException(
response.code, response.phrase
)
defer.returnValue(response)
def sign_request(self, destination, method, url_bytes, headers_dict,
content=None):
request = {
"method": method,
"uri": url_bytes,
"origin": self.server_name,
"destination": destination,
}
if content is not None:
request["content"] = content
request = sign_json(request, self.server_name, self.signing_key)
auth_headers = []
for key, sig in request["signatures"][self.server_name].items():
auth_headers.append(bytes(
"X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (
self.server_name, key, sig,
)
))
headers_dict[b"Authorization"] = auth_headers
@defer.inlineCallbacks
def put_json(self, destination, path, data={}, json_data_callback=None):
""" Sends the specifed json data using PUT
Args:
destination (str): The remote server to send the HTTP request
to.
path (str): The HTTP path.
data (dict): A dict containing the data that will be used as
the request body. This will be encoded as JSON.
json_data_callback (callable): A callable returning the dict to
use as the request body.
Returns:
Deferred: Succeeds when we get a 2xx HTTP response. The result
will be the decoded JSON body. On a 4xx or 5xx error response a
CodeMessageException is raised.
"""
if not json_data_callback:
def json_data_callback():
return data
def body_callback(method, url_bytes, headers_dict):
json_data = json_data_callback()
self.sign_request(
destination, method, url_bytes, headers_dict, json_data
)
producer = _JsonProducer(json_data)
return producer
response = yield self._create_request(
destination.encode("ascii"),
"PUT",
path.encode("ascii"),
body_callback=body_callback,
headers_dict={"Content-Type": ["application/json"]},
)
logger.debug("Getting resp body")
body = yield readBody(response)
logger.debug("Got resp body")
defer.returnValue((response.code, body))
@defer.inlineCallbacks
def get_json(self, destination, path, args={}, retry_on_dns_fail=True):
""" Get's some json from the given host homeserver and path
Args:
destination (str): The remote server to send the HTTP request
to.
path (str): The HTTP path.
args (dict): A dictionary used to create query strings, defaults to
None.
**Note**: The value of each key is assumed to be an iterable
and *not* a string.
Returns:
Deferred: Succeeds when we get *any* HTTP response.
The result of the deferred is a tuple of `(code, response)`,
where `response` is a dict representing the decoded JSON body.
"""
logger.debug("get_json args: %s", args)
encoded_args = {}
for k, vs in args.items():
if isinstance(vs, basestring):
vs = [vs]
encoded_args[k] = [v.encode("UTF-8") for v in vs]
query_bytes = urllib.urlencode(encoded_args, True)
logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail)
def body_callback(method, url_bytes, headers_dict):
self.sign_request(destination, method, url_bytes, headers_dict)
return None
response = yield self._create_request(
destination.encode("ascii"),
"GET",
path.encode("ascii"),
query_bytes=query_bytes,
body_callback=body_callback,
retry_on_dns_fail=retry_on_dns_fail
)
body = yield readBody(response)
defer.returnValue(json.loads(body))
def _getEndpoint(self, reactor, destination):
return matrix_federation_endpoint(
reactor, destination, timeout=10,
ssl_context_factory=self.hs.tls_context_factory
)
def _print_ex(e):
if hasattr(e, "reasons") and e.reasons:
for ex in e.reasons:
_print_ex(ex)
else:
logger.exception(e)
class _JsonProducer(object):
""" Used by the twisted http client to create the HTTP body from json
"""
def __init__(self, jsn):
self.reset(jsn)
def reset(self, jsn):
self.body = encode_canonical_json(jsn)
self.length = len(self.body)
def startProducing(self, consumer):
consumer.write(self.body)
return defer.succeed(None)
def pauseProducing(self):
pass
def stopProducing(self):
pass

View File

@ -138,8 +138,7 @@ class JsonResource(HttpServer, resource.Resource):
) )
except CodeMessageException as e: except CodeMessageException as e:
if isinstance(e, SynapseError): if isinstance(e, SynapseError):
logger.error("%s SynapseError: %s - %s", request, e.code, logger.info("%s SynapseError: %s - %s", request, e.code, e.msg)
e.msg)
else: else:
logger.exception(e) logger.exception(e)
self._send_response( self._send_response(

View File

@ -17,6 +17,7 @@ from twisted.internet import defer
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.async import run_on_reactor
import logging import logging
@ -96,6 +97,7 @@ class Notifier(object):
listening to the room, and any listeners for the users in the listening to the room, and any listeners for the users in the
`extra_users` param. `extra_users` param.
""" """
yield run_on_reactor()
room_id = event.room_id room_id = event.room_id
room_source = self.event_sources.sources["room"] room_source = self.event_sources.sources["room"]
@ -143,6 +145,7 @@ class Notifier(object):
Will wake up all listeners for the given users and rooms. Will wake up all listeners for the given users and rooms.
""" """
yield run_on_reactor()
presence_source = self.event_sources.sources["presence"] presence_source = self.event_sources.sources["presence"]
listeners = set() listeners = set()
@ -211,6 +214,7 @@ class Notifier(object):
timeout, timeout,
deferred, deferred,
) )
def _timeout_listener(): def _timeout_listener():
# TODO (erikj): We should probably set to_token to the current # TODO (erikj): We should probably set to_token to the current
# max rather than reusing from_token. # max rather than reusing from_token.

View File

@ -26,7 +26,6 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class EventStreamRestServlet(RestServlet): class EventStreamRestServlet(RestServlet):
PATTERN = client_path_pattern("/events$") PATTERN = client_path_pattern("/events$")

View File

@ -117,8 +117,6 @@ class PresenceListRestServlet(RestServlet):
logger.exception("JSON parse error") logger.exception("JSON parse error")
raise SynapseError(400, "Unable to parse content") raise SynapseError(400, "Unable to parse content")
deferreds = []
if "invite" in content: if "invite" in content:
for u in content["invite"]: for u in content["invite"]:
if not isinstance(u, basestring): if not isinstance(u, basestring):
@ -126,8 +124,9 @@ class PresenceListRestServlet(RestServlet):
if len(u) == 0: if len(u) == 0:
continue continue
invited_user = self.hs.parse_userid(u) invited_user = self.hs.parse_userid(u)
deferreds.append(self.handlers.presence_handler.send_invite( yield self.handlers.presence_handler.send_invite(
observer_user=user, observed_user=invited_user)) observer_user=user, observed_user=invited_user
)
if "drop" in content: if "drop" in content:
for u in content["drop"]: for u in content["drop"]:
@ -136,10 +135,9 @@ class PresenceListRestServlet(RestServlet):
if len(u) == 0: if len(u) == 0:
continue continue
dropped_user = self.hs.parse_userid(u) dropped_user = self.hs.parse_userid(u)
deferreds.append(self.handlers.presence_handler.drop( yield self.handlers.presence_handler.drop(
observer_user=user, observed_user=dropped_user)) observer_user=user, observed_user=dropped_user
)
yield defer.DeferredList(deferreds)
defer.returnValue((200, {})) defer.returnValue((200, {}))

View File

@ -222,6 +222,7 @@ class RegisterRestServlet(RestServlet):
threepidCreds = register_json['threepidCreds'] threepidCreds = register_json['threepidCreds']
handler = self.handlers.registration_handler handler = self.handlers.registration_handler
logger.debug("Registering email. threepidcreds: %s" % (threepidCreds))
yield handler.register_email(threepidCreds) yield handler.register_email(threepidCreds)
session["threepidCreds"] = threepidCreds # store creds for next stage session["threepidCreds"] = threepidCreds # store creds for next stage
session[LoginType.EMAIL_IDENTITY] = True # mark email as done session[LoginType.EMAIL_IDENTITY] = True # mark email as done
@ -232,6 +233,7 @@ class RegisterRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def _do_password(self, request, register_json, session): def _do_password(self, request, register_json, session):
yield
if (self.hs.config.enable_registration_captcha and if (self.hs.config.enable_registration_captcha and
not session[LoginType.RECAPTCHA]): not session[LoginType.RECAPTCHA]):
# captcha should've been done by this stage! # captcha should've been done by this stage!
@ -259,6 +261,9 @@ class RegisterRestServlet(RestServlet):
) )
if session[LoginType.EMAIL_IDENTITY]: if session[LoginType.EMAIL_IDENTITY]:
logger.debug("Binding emails %s to %s" % (
session["threepidCreds"], user_id)
)
yield handler.bind_emails(user_id, session["threepidCreds"]) yield handler.bind_emails(user_id, session["threepidCreds"])
result = { result = {

View File

@ -148,7 +148,7 @@ class RoomStateEventRestServlet(RestServlet):
content = _parse_json(request) content = _parse_json(request)
event = self.event_factory.create_event( event = self.event_factory.create_event(
etype=urllib.unquote(event_type), etype=event_type, # already urldecoded
content=content, content=content,
room_id=urllib.unquote(room_id), room_id=urllib.unquote(room_id),
user_id=user.to_string(), user_id=user.to_string(),

View File

@ -82,7 +82,7 @@ class StateHandler(object):
if hasattr(event, "outlier") and event.outlier: if hasattr(event, "outlier") and event.outlier:
event.state_group = None event.state_group = None
event.old_state_events = None event.old_state_events = None
event.state_events = {} event.state_events = None
defer.returnValue(False) defer.returnValue(False)
return return

View File

@ -67,7 +67,7 @@ SCHEMAS = [
# Remember to update this number every time an incompatible change is made to # Remember to update this number every time an incompatible change is made to
# database schema files, so the users will be informed on server restarts. # database schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 7 SCHEMA_VERSION = 8
class _RollbackButIsFineException(Exception): class _RollbackButIsFineException(Exception):
@ -93,7 +93,8 @@ class DataStore(RoomMemberStore, RoomStore,
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def persist_event(self, event, backfilled=False, is_new_state=True): def persist_event(self, event, backfilled=False, is_new_state=True,
current_state=None):
stream_ordering = None stream_ordering = None
if backfilled: if backfilled:
if not self.min_token_deferred.called: if not self.min_token_deferred.called:
@ -109,6 +110,7 @@ class DataStore(RoomMemberStore, RoomStore,
backfilled=backfilled, backfilled=backfilled,
stream_ordering=stream_ordering, stream_ordering=stream_ordering,
is_new_state=is_new_state, is_new_state=is_new_state,
current_state=current_state,
) )
except _RollbackButIsFineException: except _RollbackButIsFineException:
pass pass
@ -137,7 +139,7 @@ class DataStore(RoomMemberStore, RoomStore,
@log_function @log_function
def _persist_event_txn(self, txn, event, backfilled, stream_ordering=None, def _persist_event_txn(self, txn, event, backfilled, stream_ordering=None,
is_new_state=True): is_new_state=True, current_state=None):
if event.type == RoomMemberEvent.TYPE: if event.type == RoomMemberEvent.TYPE:
self._store_room_member_txn(txn, event) self._store_room_member_txn(txn, event)
elif event.type == FeedbackEvent.TYPE: elif event.type == FeedbackEvent.TYPE:
@ -206,8 +208,24 @@ class DataStore(RoomMemberStore, RoomStore,
self._store_state_groups_txn(txn, event) self._store_state_groups_txn(txn, event)
if current_state:
txn.execute("DELETE FROM current_state_events")
for s in current_state:
self._simple_insert_txn(
txn,
"current_state_events",
{
"event_id": s.event_id,
"room_id": s.room_id,
"type": s.type,
"state_key": s.state_key,
},
or_replace=True,
)
is_state = hasattr(event, "state_key") and event.state_key is not None is_state = hasattr(event, "state_key") and event.state_key is not None
if is_new_state and is_state: if is_state:
vals = { vals = {
"event_id": event.event_id, "event_id": event.event_id,
"room_id": event.room_id, "room_id": event.room_id,
@ -225,17 +243,18 @@ class DataStore(RoomMemberStore, RoomStore,
or_replace=True, or_replace=True,
) )
self._simple_insert_txn( if is_new_state:
txn, self._simple_insert_txn(
"current_state_events", txn,
{ "current_state_events",
"event_id": event.event_id, {
"room_id": event.room_id, "event_id": event.event_id,
"type": event.type, "room_id": event.room_id,
"state_key": event.state_key, "type": event.type,
}, "state_key": event.state_key,
or_replace=True, },
) or_replace=True,
)
for e_id, h in event.prev_state: for e_id, h in event.prev_state:
self._simple_insert_txn( self._simple_insert_txn(
@ -312,7 +331,12 @@ class DataStore(RoomMemberStore, RoomStore,
txn, event.event_id, ref_alg, ref_hash_bytes txn, event.event_id, ref_alg, ref_hash_bytes
) )
self._update_min_depth_for_room_txn(txn, event.room_id, event.depth) if not outlier:
self._update_min_depth_for_room_txn(
txn,
event.room_id,
event.depth
)
def _store_redaction(self, txn, event): def _store_redaction(self, txn, event):
txn.execute( txn.execute(
@ -508,7 +532,7 @@ def prepare_database(db_conn):
"new for the server to understand" "new for the server to understand"
) )
elif user_version < SCHEMA_VERSION: elif user_version < SCHEMA_VERSION:
logging.info( logger.info(
"Upgrading database from version %d", "Upgrading database from version %d",
user_version user_version
) )

View File

@ -57,7 +57,7 @@ class LoggingTransaction(object):
if args and args[0]: if args and args[0]:
values = args[0] values = args[0]
sql_logger.debug( sql_logger.debug(
"[SQL values] {%s} " + ", ".join(("<%s>",) * len(values)), "[SQL values] {%s} " + ", ".join(("<%r>",) * len(values)),
self.name, self.name,
*values *values
) )
@ -91,6 +91,7 @@ class SQLBaseStore(object):
def runInteraction(self, desc, func, *args, **kwargs): def runInteraction(self, desc, func, *args, **kwargs):
"""Wraps the .runInteraction() method on the underlying db_pool.""" """Wraps the .runInteraction() method on the underlying db_pool."""
current_context = LoggingContext.current_context() current_context = LoggingContext.current_context()
def inner_func(txn, *args, **kwargs): def inner_func(txn, *args, **kwargs):
with LoggingContext("runInteraction") as context: with LoggingContext("runInteraction") as context:
current_context.copy_to(context) current_context.copy_to(context)
@ -115,7 +116,6 @@ class SQLBaseStore(object):
"[TXN END] {%s} %f", "[TXN END] {%s} %f",
name, end - start name, end - start
) )
with PreserveLoggingContext(): with PreserveLoggingContext():
result = yield self._db_pool.runInteraction( result = yield self._db_pool.runInteraction(
inner_func, *args, **kwargs inner_func, *args, **kwargs
@ -246,7 +246,10 @@ class SQLBaseStore(object):
raise StoreError(404, "No row found") raise StoreError(404, "No row found")
def _simple_select_onecol_txn(self, txn, table, keyvalues, retcol): def _simple_select_onecol_txn(self, txn, table, keyvalues, retcol):
sql = "SELECT %(retcol)s FROM %(table)s WHERE %(where)s" % { sql = (
"SELECT %(retcol)s FROM %(table)s WHERE %(where)s "
"ORDER BY rowid asc"
) % {
"retcol": retcol, "retcol": retcol,
"table": table, "table": table,
"where": " AND ".join("%s = ?" % k for k in keyvalues.keys()), "where": " AND ".join("%s = ?" % k for k in keyvalues.keys()),
@ -299,7 +302,7 @@ class SQLBaseStore(object):
keyvalues : dict of column names and values to select the rows with keyvalues : dict of column names and values to select the rows with
retcols : list of strings giving the names of the columns to return retcols : list of strings giving the names of the columns to return
""" """
sql = "SELECT %s FROM %s WHERE %s" % ( sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % (
", ".join(retcols), ", ".join(retcols),
table, table,
" AND ".join("%s = ?" % (k, ) for k in keyvalues) " AND ".join("%s = ?" % (k, ) for k in keyvalues)
@ -334,7 +337,7 @@ class SQLBaseStore(object):
retcols=None, allow_none=False): retcols=None, allow_none=False):
""" Combined SELECT then UPDATE.""" """ Combined SELECT then UPDATE."""
if retcols: if retcols:
select_sql = "SELECT %s FROM %s WHERE %s" % ( select_sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % (
", ".join(retcols), ", ".join(retcols),
table, table,
" AND ".join("%s = ?" % (k) for k in keyvalues) " AND ".join("%s = ?" % (k) for k in keyvalues)
@ -461,7 +464,7 @@ class SQLBaseStore(object):
def _get_events_txn(self, txn, event_ids): def _get_events_txn(self, txn, event_ids):
# FIXME (erikj): This should be batched? # FIXME (erikj): This should be batched?
sql = "SELECT * FROM events WHERE event_id = ?" sql = "SELECT * FROM events WHERE event_id = ? ORDER BY rowid asc"
event_rows = [] event_rows = []
for e_id in event_ids: for e_id in event_ids:
@ -478,7 +481,9 @@ class SQLBaseStore(object):
def _parse_events_txn(self, txn, rows): def _parse_events_txn(self, txn, rows):
events = [self._parse_event_from_row(r) for r in rows] events = [self._parse_event_from_row(r) for r in rows]
select_event_sql = "SELECT * FROM events WHERE event_id = ?" select_event_sql = (
"SELECT * FROM events WHERE event_id = ? ORDER BY rowid asc"
)
for i, ev in enumerate(events): for i, ev in enumerate(events):
signatures = self._get_event_signatures_txn( signatures = self._get_event_signatures_txn(

View File

@ -75,7 +75,9 @@ class RegistrationStore(SQLBaseStore):
"VALUES (?,?,?)", "VALUES (?,?,?)",
[user_id, password_hash, now]) [user_id, password_hash, now])
except IntegrityError: except IntegrityError:
raise StoreError(400, "User ID already taken.", errcode=Codes.USER_IN_USE) raise StoreError(
400, "User ID already taken.", errcode=Codes.USER_IN_USE
)
# it's possible for this to get a conflict, but only for a single user # it's possible for this to get a conflict, but only for a single user
# since tokens are namespaced based on their user ID # since tokens are namespaced based on their user ID
@ -83,8 +85,8 @@ class RegistrationStore(SQLBaseStore):
"VALUES (?,?)", [txn.lastrowid, token]) "VALUES (?,?)", [txn.lastrowid, token])
def get_user_by_id(self, user_id): def get_user_by_id(self, user_id):
query = ("SELECT users.name, users.password_hash FROM users " query = ("SELECT users.name, users.password_hash FROM users"
"WHERE users.name = ?") " WHERE users.name = ?")
return self._execute( return self._execute(
self.cursor_to_dict, self.cursor_to_dict,
query, user_id query, user_id
@ -120,10 +122,10 @@ class RegistrationStore(SQLBaseStore):
def _query_for_auth(self, txn, token): def _query_for_auth(self, txn, token):
sql = ( sql = (
"SELECT users.name, users.admin, access_tokens.device_id " "SELECT users.name, users.admin, access_tokens.device_id"
"FROM users " " FROM users"
"INNER JOIN access_tokens on users.id = access_tokens.user_id " " INNER JOIN access_tokens on users.id = access_tokens.user_id"
"WHERE token = ?" " WHERE token = ?"
) )
cursor = txn.execute(sql, (token,)) cursor = txn.execute(sql, (token,))

View File

@ -27,7 +27,9 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
OpsLevel = collections.namedtuple("OpsLevel", ("ban_level", "kick_level", "redact_level")) OpsLevel = collections.namedtuple("OpsLevel", (
"ban_level", "kick_level", "redact_level")
)
class RoomStore(SQLBaseStore): class RoomStore(SQLBaseStore):

View File

@ -177,8 +177,8 @@ class RoomMemberStore(SQLBaseStore):
return self._get_members_query(clause, vals) return self._get_members_query(clause, vals)
def _get_members_query(self, where_clause, where_values): def _get_members_query(self, where_clause, where_values):
return self._db_pool.runInteraction( return self.runInteraction(
self._get_members_query_txn, "get_members_query", self._get_members_query_txn,
where_clause, where_values where_clause, where_values
) )

View File

@ -0,0 +1,34 @@
/* Copyright 2014 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS event_signatures_2 (
event_id TEXT,
signature_name TEXT,
key_id TEXT,
signature BLOB,
CONSTRAINT uniqueness UNIQUE (event_id, signature_name, key_id)
);
INSERT INTO event_signatures_2 (event_id, signature_name, key_id, signature)
SELECT event_id, signature_name, key_id, signature FROM event_signatures;
DROP TABLE event_signatures;
ALTER TABLE event_signatures_2 RENAME TO event_signatures;
CREATE INDEX IF NOT EXISTS event_signatures_id ON event_signatures (
event_id
);
PRAGMA user_version = 8;

View File

@ -42,7 +42,7 @@ CREATE TABLE IF NOT EXISTS event_signatures (
signature_name TEXT, signature_name TEXT,
key_id TEXT, key_id TEXT,
signature BLOB, signature BLOB,
CONSTRAINT uniqueness UNIQUE (event_id, key_id) CONSTRAINT uniqueness UNIQUE (event_id, signature_name, key_id)
); );
CREATE INDEX IF NOT EXISTS event_signatures_id ON event_signatures ( CREATE INDEX IF NOT EXISTS event_signatures_id ON event_signatures (

View File

@ -36,7 +36,7 @@ class SignatureStore(SQLBaseStore):
return dict(txn.fetchall()) return dict(txn.fetchall())
def _store_event_content_hash_txn(self, txn, event_id, algorithm, def _store_event_content_hash_txn(self, txn, event_id, algorithm,
hash_bytes): hash_bytes):
"""Store a hash for a Event """Store a hash for a Event
Args: Args:
txn (cursor): txn (cursor):
@ -84,7 +84,7 @@ class SignatureStore(SQLBaseStore):
return dict(txn.fetchall()) return dict(txn.fetchall())
def _store_event_reference_hash_txn(self, txn, event_id, algorithm, def _store_event_reference_hash_txn(self, txn, event_id, algorithm,
hash_bytes): hash_bytes):
"""Store a hash for a PDU """Store a hash for a PDU
Args: Args:
txn (cursor): txn (cursor):
@ -127,7 +127,7 @@ class SignatureStore(SQLBaseStore):
return res return res
def _store_event_signature_txn(self, txn, event_id, signature_name, key_id, def _store_event_signature_txn(self, txn, event_id, signature_name, key_id,
signature_bytes): signature_bytes):
"""Store a signature from the origin server for a PDU. """Store a signature from the origin server for a PDU.
Args: Args:
txn (cursor): txn (cursor):
@ -169,7 +169,7 @@ class SignatureStore(SQLBaseStore):
return results return results
def _store_prev_event_hash_txn(self, txn, event_id, prev_event_id, def _store_prev_event_hash_txn(self, txn, event_id, prev_event_id,
algorithm, hash_bytes): algorithm, hash_bytes):
self._simple_insert_txn( self._simple_insert_txn(
txn, txn,
"event_edge_hashes", "event_edge_hashes",

View File

@ -87,7 +87,7 @@ class StateStore(SQLBaseStore):
) )
def _store_state_groups_txn(self, txn, event): def _store_state_groups_txn(self, txn, event):
if not event.state_events: if event.state_events is None:
return return
state_group = event.state_group state_group = event.state_group

View File

@ -213,8 +213,8 @@ class StreamStore(SQLBaseStore):
# Tokens really represent positions between elements, but we use # Tokens really represent positions between elements, but we use
# the convention of pointing to the event before the gap. Hence # the convention of pointing to the event before the gap. Hence
# we have a bit of asymmetry when it comes to equalities. # we have a bit of asymmetry when it comes to equalities.
from_comp = '<=' if direction =='b' else '>' from_comp = '<=' if direction == 'b' else '>'
to_comp = '>' if direction =='b' else '<=' to_comp = '>' if direction == 'b' else '<='
order = "DESC" if direction == 'b' else "ASC" order = "DESC" if direction == 'b' else "ASC"
args = [room_id] args = [room_id]
@ -235,9 +235,10 @@ class StreamStore(SQLBaseStore):
) )
sql = ( sql = (
"SELECT *, (%(redacted)s) AS redacted FROM events " "SELECT *, (%(redacted)s) AS redacted FROM events"
"WHERE outlier = 0 AND room_id = ? AND %(bounds)s " " WHERE outlier = 0 AND room_id = ? AND %(bounds)s"
"ORDER BY topological_ordering %(order)s, stream_ordering %(order)s %(limit)s " " ORDER BY topological_ordering %(order)s,"
" stream_ordering %(order)s %(limit)s"
) % { ) % {
"redacted": del_sql, "redacted": del_sql,
"bounds": bounds, "bounds": bounds,

View File

@ -28,11 +28,11 @@ class SourcePaginationConfig(object):
specific event source.""" specific event source."""
def __init__(self, from_key=None, to_key=None, direction='f', def __init__(self, from_key=None, to_key=None, direction='f',
limit=0): limit=None):
self.from_key = from_key self.from_key = from_key
self.to_key = to_key self.to_key = to_key
self.direction = 'f' if direction == 'f' else 'b' self.direction = 'f' if direction == 'f' else 'b'
self.limit = int(limit) self.limit = int(limit) if limit is not None else None
class PaginationConfig(object): class PaginationConfig(object):
@ -40,11 +40,11 @@ class PaginationConfig(object):
"""A configuration object which stores pagination parameters.""" """A configuration object which stores pagination parameters."""
def __init__(self, from_token=None, to_token=None, direction='f', def __init__(self, from_token=None, to_token=None, direction='f',
limit=0): limit=None):
self.from_token = from_token self.from_token = from_token
self.to_token = to_token self.to_token = to_token
self.direction = 'f' if direction == 'f' else 'b' self.direction = 'f' if direction == 'f' else 'b'
self.limit = int(limit) self.limit = int(limit) if limit is not None else None
@classmethod @classmethod
def from_request(cls, request, raise_invalid_params=True): def from_request(cls, request, raise_invalid_params=True):
@ -80,8 +80,8 @@ class PaginationConfig(object):
except: except:
raise SynapseError(400, "'to' paramater is invalid") raise SynapseError(400, "'to' paramater is invalid")
limit = get_param("limit", "0") limit = get_param("limit", None)
if not limit.isdigit(): if limit is not None and not limit.isdigit():
raise SynapseError(400, "'limit' parameter must be an integer.") raise SynapseError(400, "'limit' parameter must be an integer.")
try: try:

View File

@ -37,6 +37,7 @@ class Clock(object):
def call_later(self, delay, callback): def call_later(self, delay, callback):
current_context = LoggingContext.current_context() current_context = LoggingContext.current_context()
def wrapped_callback(): def wrapped_callback():
LoggingContext.thread_local.current_context = current_context LoggingContext.thread_local.current_context = current_context
callback() callback()

View File

@ -18,6 +18,7 @@ from twisted.internet import defer, reactor
from .logcontext import PreserveLoggingContext from .logcontext import PreserveLoggingContext
@defer.inlineCallbacks @defer.inlineCallbacks
def sleep(seconds): def sleep(seconds):
d = defer.Deferred() d = defer.Deferred()
@ -25,6 +26,7 @@ def sleep(seconds):
with PreserveLoggingContext(): with PreserveLoggingContext():
yield d yield d
def run_on_reactor(): def run_on_reactor():
""" This will cause the rest of the function to be invoked upon the next """ This will cause the rest of the function to be invoked upon the next
iteration of the main loop iteration of the main loop

View File

@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.util.logcontext import PreserveLoggingContext
from twisted.internet import defer from twisted.internet import defer
import logging import logging
@ -91,6 +93,7 @@ class Signal(object):
Each observer callable may return a Deferred.""" Each observer callable may return a Deferred."""
self.observers.append(observer) self.observers.append(observer)
@defer.inlineCallbacks
def fire(self, *args, **kwargs): def fire(self, *args, **kwargs):
"""Invokes every callable in the observer list, passing in the args and """Invokes every callable in the observer list, passing in the args and
kwargs. Exceptions thrown by observers are logged but ignored. It is kwargs. Exceptions thrown by observers are logged but ignored. It is
@ -98,22 +101,24 @@ class Signal(object):
Returns a Deferred that will complete when all the observers have Returns a Deferred that will complete when all the observers have
completed.""" completed."""
deferreds = [] with PreserveLoggingContext():
for observer in self.observers: deferreds = []
d = defer.maybeDeferred(observer, *args, **kwargs) for observer in self.observers:
d = defer.maybeDeferred(observer, *args, **kwargs)
def eb(failure): def eb(failure):
logger.warning( logger.warning(
"%s signal observer %s failed: %r", "%s signal observer %s failed: %r",
self.name, observer, failure, self.name, observer, failure,
exc_info=( exc_info=(
failure.type, failure.type,
failure.value, failure.value,
failure.getTracebackObject())) failure.getTracebackObject()))
if not self.suppress_failures: if not self.suppress_failures:
raise failure raise failure
deferreds.append(d.addErrback(eb)) deferreds.append(d.addErrback(eb))
return defer.DeferredList( result = yield defer.DeferredList(
deferreds, fireOnOneErrback=not self.suppress_failures deferreds, fireOnOneErrback=not self.suppress_failures
) )
defer.returnValue(result)

View File

@ -1,6 +1,8 @@
import threading import threading
import logging import logging
logger = logging.getLogger(__name__)
class LoggingContext(object): class LoggingContext(object):
"""Additional context for log formatting. Contexts are scoped within a """Additional context for log formatting. Contexts are scoped within a
@ -53,11 +55,14 @@ class LoggingContext(object):
None to avoid suppressing any exeptions that were thrown. None to avoid suppressing any exeptions that were thrown.
""" """
if self.thread_local.current_context is not self: if self.thread_local.current_context is not self:
logging.error( if self.thread_local.current_context is self.sentinel:
"Current logging context %s is not the expected context %s", logger.debug("Expected logging context %s has been lost", self)
self.thread_local.current_context, else:
self logger.warn(
) "Current logging context %s is not expected context %s",
self.thread_local.current_context,
self
)
self.thread_local.current_context = self.parent_context self.thread_local.current_context = self.parent_context
self.parent_context = None self.parent_context = None

View File

@ -83,20 +83,22 @@ class FederationTestCase(unittest.TestCase):
event_id="$a:b", event_id="$a:b",
user_id="@a:b", user_id="@a:b",
origin="b", origin="b",
auth_events=[],
hashes={"sha256":"AcLrgtUIqqwaGoHhrEvYG1YLDIsVPYJdSRGhkp3jJp8"}, hashes={"sha256":"AcLrgtUIqqwaGoHhrEvYG1YLDIsVPYJdSRGhkp3jJp8"},
) )
self.datastore.persist_event.return_value = defer.succeed(None) self.datastore.persist_event.return_value = defer.succeed(None)
self.datastore.get_room.return_value = defer.succeed(True) self.datastore.get_room.return_value = defer.succeed(True)
self.state_handler.annotate_event_with_state.return_value = ( def annotate(ev, old_state=None):
defer.succeed(False) ev.old_state_events = []
) return defer.succeed(False)
self.state_handler.annotate_event_with_state.side_effect = annotate
yield self.handlers.federation_handler.on_receive_pdu(pdu, False) yield self.handlers.federation_handler.on_receive_pdu(pdu, False)
self.datastore.persist_event.assert_called_once_with( self.datastore.persist_event.assert_called_once_with(
ANY, False, is_new_state=False ANY, is_new_state=False, backfilled=False, current_state=None
) )
self.state_handler.annotate_event_with_state.assert_called_once_with( self.state_handler.annotate_event_with_state.assert_called_once_with(
@ -104,7 +106,7 @@ class FederationTestCase(unittest.TestCase):
old_state=None, old_state=None,
) )
self.auth.check.assert_called_once_with(ANY, raises=True) self.auth.check.assert_called_once_with(ANY, auth_events={})
self.notifier.on_new_room_event.assert_called_once_with( self.notifier.on_new_room_event.assert_called_once_with(
ANY, ANY,

View File

@ -120,7 +120,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
self.datastore.get_room_member.return_value = defer.succeed(None) self.datastore.get_room_member.return_value = defer.succeed(None)
event.state_events = { event.old_state_events = {
(RoomMemberEvent.TYPE, "@alice:green"): self._create_member( (RoomMemberEvent.TYPE, "@alice:green"): self._create_member(
user_id="@alice:green", user_id="@alice:green",
room_id=room_id, room_id=room_id,
@ -129,9 +129,11 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
user_id="@bob:red", user_id="@bob:red",
room_id=room_id, room_id=room_id,
), ),
(RoomMemberEvent.TYPE, target_user_id): event,
} }
event.state_events = event.old_state_events
event.state_events[(RoomMemberEvent.TYPE, target_user_id)] = event
# Actual invocation # Actual invocation
yield self.room_member_handler.change_membership(event) yield self.room_member_handler.change_membership(event)
@ -187,6 +189,16 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
(RoomMemberEvent.TYPE, user_id): event, (RoomMemberEvent.TYPE, user_id): event,
} }
event.old_state_events = {
(RoomMemberEvent.TYPE, "@alice:green"): self._create_member(
user_id="@alice:green",
room_id=room_id,
),
}
event.state_events = event.old_state_events
event.state_events[(RoomMemberEvent.TYPE, user_id)] = event
# Actual invocation # Actual invocation
yield self.room_member_handler.change_membership(event) yield self.room_member_handler.change_membership(event)

View File

@ -84,7 +84,8 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.assertEquals("Value", value) self.assertEquals("Value", value)
self.mock_txn.execute.assert_called_with( self.mock_txn.execute.assert_called_with(
"SELECT retcol FROM tablename WHERE keycol = ?", "SELECT retcol FROM tablename WHERE keycol = ? "
"ORDER BY rowid asc",
["TheKey"] ["TheKey"]
) )
@ -101,7 +102,8 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.assertEquals({"colA": 1, "colB": 2, "colC": 3}, ret) self.assertEquals({"colA": 1, "colB": 2, "colC": 3}, ret)
self.mock_txn.execute.assert_called_with( self.mock_txn.execute.assert_called_with(
"SELECT colA, colB, colC FROM tablename WHERE keycol = ?", "SELECT colA, colB, colC FROM tablename WHERE keycol = ? "
"ORDER BY rowid asc",
["TheKey"] ["TheKey"]
) )
@ -135,7 +137,8 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.assertEquals([{"colA": 1}, {"colA": 2}, {"colA": 3}], ret) self.assertEquals([{"colA": 1}, {"colA": 2}, {"colA": 3}], ret)
self.mock_txn.execute.assert_called_with( self.mock_txn.execute.assert_called_with(
"SELECT colA FROM tablename WHERE keycol = ?", "SELECT colA FROM tablename WHERE keycol = ? "
"ORDER BY rowid asc",
["A set"] ["A set"]
) )
@ -184,7 +187,8 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.assertEquals({"columname": "Old Value"}, ret) self.assertEquals({"columname": "Old Value"}, ret)
self.mock_txn.execute.assert_has_calls([ self.mock_txn.execute.assert_has_calls([
call('SELECT columname FROM tablename WHERE keycol = ?', call('SELECT columname FROM tablename WHERE keycol = ? '
'ORDER BY rowid asc',
['TheKey']), ['TheKey']),
call("UPDATE tablename SET columname = ? WHERE keycol = ?", call("UPDATE tablename SET columname = ? WHERE keycol = ?",
["New Value", "TheKey"]) ["New Value", "TheKey"])