Merge branch 'release-v0.5.1' of github.com:matrix-org/synapse

This commit is contained in:
Erik Johnston 2014-11-26 12:06:36 +00:00
commit 48ee9ddb22
57 changed files with 1040 additions and 735 deletions

View File

@ -1,3 +1,11 @@
Changes in synapse 0.5.1 (2014-11-26)
=====================================
See UPGRADES.rst for specific instructions on how to upgrade.
* Fix bug where we served up an Event that did not match its signatures.
* Fix regression where we no longer correctly handled the case where a
homeserver receives an event for a room it doesn't recognise (but is in.)
Changes in synapse 0.5.0 (2014-11-19)
=====================================
This release includes changes to the federation protocol and client-server API

View File

@ -69,8 +69,8 @@ command line utility which lets you easily see what the JSON APIs are up to).
Meanwhile, iOS and Android SDKs and clients are currently in development and available from:
* https://github.com/matrix-org/matrix-ios-sdk
* https://github.com/matrix-org/matrix-android-sdk
- https://github.com/matrix-org/matrix-ios-sdk
- https://github.com/matrix-org/matrix-android-sdk
We'd like to invite you to join #matrix:matrix.org (via http://matrix.org/alpha), run a homeserver, take a look at the Matrix spec at
http://matrix.org/docs/spec, experiment with the APIs and the demo
@ -94,7 +94,7 @@ header files for python C extensions.
Installing prerequisites on Ubuntu or Debian::
$ sudo apt-get install build-essential python2.7-dev libffi-dev \
python-pip python-setuptools
python-pip python-setuptools sqlite3
Installing prerequisites on Mac OS X::
@ -125,7 +125,7 @@ created. To reset the installation::
pip seems to leak *lots* of memory during installation. For instance, a Linux
host with 512MB of RAM may run out of memory whilst installing Twisted. If this
happens, you will have to individually install the dependencies which are
failing, e.g.:
failing, e.g.::
$ pip install --user twisted
@ -148,7 +148,7 @@ Troubleshooting Running
-----------------------
If ``synctl`` fails with ``pkg_resources.DistributionNotFound`` errors you may
need a newer version of setuptools than that provided by your OS.
need a newer version of setuptools than that provided by your OS.::
$ sudo pip install setuptools --upgrade
@ -172,7 +172,7 @@ Homeserver Development
======================
To check out a homeserver for development, clone the git repo into a working
directory of your choice:
directory of your choice::
$ git clone https://github.com/matrix-org/synapse.git
$ cd synapse

View File

@ -1,3 +1,12 @@
Upgrading to v0.5.1
===================
Depending on precisely when you installed v0.5.0 you may have ended up with
a stale release of the reference matrix webclient installed as a python module.
To uninstall it and ensure you are depending on the latest module, please run::
$ pip uninstall syweb
Upgrading to v0.5.0
===================

View File

@ -1 +1 @@
0.5.0
0.5.1

View File

@ -23,7 +23,7 @@ def get_targets(server_name):
for srv in answers:
yield (srv.target, srv.port)
except dns.resolver.NXDOMAIN:
yield (server_name, 8480)
yield (server_name, 8448)
def get_server_keys(server_name, target, port):
url = "https://%s:%i/_matrix/key/v1" % (target, port)

View File

@ -32,7 +32,7 @@ setup(
description="Reference Synapse Home Server",
install_requires=[
"syutil==0.0.2",
"matrix_angular_sdk==0.5.0",
"matrix_angular_sdk==0.5.1",
"Twisted>=14.0.0",
"service_identity>=1.0.0",
"pyopenssl>=0.14",
@ -45,7 +45,7 @@ setup(
dependency_links=[
"https://github.com/matrix-org/syutil/tarball/v0.0.2#egg=syutil-0.0.2",
"https://github.com/pyca/pynacl/tarball/52dbe2dc33f1#egg=pynacl-0.3.0",
"https://github.com/matrix-org/matrix-angular-sdk/tarball/v0.5.0/#egg=matrix_angular_sdk-0.5.0",
"https://github.com/matrix-org/matrix-angular-sdk/tarball/v0.5.1/#egg=matrix_angular_sdk-0.5.1",
],
setup_requires=[
"setuptools_trial",

View File

@ -16,4 +16,4 @@
""" This is a reference implementation of a synapse home server.
"""
__version__ = "0.5.0"
__version__ = "0.5.1"

View File

@ -38,27 +38,21 @@ class Auth(object):
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
def check(self, event, raises=False):
def check(self, event, auth_events):
""" Checks if this event is correctly authed.
Returns:
True if the auth checks pass.
Raises:
AuthError if there was a problem authorising this event. This will
be raised only if raises=True.
"""
try:
if hasattr(event, "room_id"):
if event.old_state_events is None:
if not hasattr(event, "room_id"):
raise AuthError(500, "Event has no room_id: %s" % event)
if auth_events is None:
# Oh, we don't know what the state of the room was, so we
# are trusting that this is allowed (at least for now)
logger.warn("Trusting event: %s", event.event_id)
return True
if hasattr(event, "outlier") and event.outlier is True:
# TODO (erikj): Auth for outliers is done differently.
return True
if event.type == RoomCreateEvent.TYPE:
# FIXME
return True
@ -68,49 +62,42 @@ class Auth(object):
return True
if event.type == RoomMemberEvent.TYPE:
allowed = self.is_membership_change_allowed(event)
allowed = self.is_membership_change_allowed(
event, auth_events
)
if allowed:
logger.debug("Allowing! %s", event)
else:
logger.debug("Denying! %s", event)
return allowed
self.check_event_sender_in_room(event)
self._can_send_event(event)
self.check_event_sender_in_room(event, auth_events)
self._can_send_event(event, auth_events)
if event.type == RoomPowerLevelsEvent.TYPE:
self._check_power_levels(event)
self._check_power_levels(event, auth_events)
if event.type == RoomRedactionEvent.TYPE:
self._check_redaction(event)
self._check_redaction(event, auth_events)
logger.debug("Allowing! %s", event)
return True
else:
raise AuthError(500, "Unknown event: %s" % event)
except AuthError as e:
logger.info(
"Event auth check failed on event %s with msg: %s",
event, e.msg
)
logger.info("Denying! %s", event)
if raises:
raise
return False
@defer.inlineCallbacks
def check_joined_room(self, room_id, user_id):
try:
member = yield self.store.get_room_member(
member = yield self.state.get_current_state(
room_id=room_id,
user_id=user_id
event_type=RoomMemberEvent.TYPE,
state_key=user_id
)
self._check_joined_room(member, user_id, room_id)
defer.returnValue(member)
except AttributeError:
pass
defer.returnValue(None)
@defer.inlineCallbacks
def check_host_in_room(self, room_id, host):
@ -130,9 +117,9 @@ class Auth(object):
defer.returnValue(False)
def check_event_sender_in_room(self, event):
def check_event_sender_in_room(self, event, auth_events):
key = (RoomMemberEvent.TYPE, event.user_id, )
member_event = event.state_events.get(key)
member_event = auth_events.get(key)
return self._check_joined_room(
member_event,
@ -147,15 +134,15 @@ class Auth(object):
))
@log_function
def is_membership_change_allowed(self, event):
def is_membership_change_allowed(self, event, auth_events):
membership = event.content["membership"]
# Check if this is the room creator joining:
if len(event.prev_events) == 1 and Membership.JOIN == membership:
# Get room creation event:
key = (RoomCreateEvent.TYPE, "", )
create = event.old_state_events.get(key)
if event.prev_events[0][0] == create.event_id:
create = auth_events.get(key)
if create and event.prev_events[0][0] == create.event_id:
if create.content["creator"] == event.state_key:
return True
@ -163,19 +150,19 @@ class Auth(object):
# get info about the caller
key = (RoomMemberEvent.TYPE, event.user_id, )
caller = event.old_state_events.get(key)
caller = auth_events.get(key)
caller_in_room = caller and caller.membership == Membership.JOIN
caller_invited = caller and caller.membership == Membership.INVITE
# get info about the target
key = (RoomMemberEvent.TYPE, target_user_id, )
target = event.old_state_events.get(key)
target = auth_events.get(key)
target_in_room = target and target.membership == Membership.JOIN
key = (RoomJoinRulesEvent.TYPE, "", )
join_rule_event = event.old_state_events.get(key)
join_rule_event = auth_events.get(key)
if join_rule_event:
join_rule = join_rule_event.content.get(
"join_rule", JoinRules.INVITE
@ -186,11 +173,13 @@ class Auth(object):
user_level = self._get_power_level_from_event_state(
event,
event.user_id,
auth_events,
)
ban_level, kick_level, redact_level = (
self._get_ops_level_from_event_state(
event
event,
auth_events,
)
)
@ -260,9 +249,9 @@ class Auth(object):
return True
def _get_power_level_from_event_state(self, event, user_id):
def _get_power_level_from_event_state(self, event, user_id, auth_events):
key = (RoomPowerLevelsEvent.TYPE, "", )
power_level_event = event.old_state_events.get(key)
power_level_event = auth_events.get(key)
level = None
if power_level_event:
level = power_level_event.content.get("users", {}).get(user_id)
@ -270,16 +259,16 @@ class Auth(object):
level = power_level_event.content.get("users_default", 0)
else:
key = (RoomCreateEvent.TYPE, "", )
create_event = event.old_state_events.get(key)
create_event = auth_events.get(key)
if (create_event is not None and
create_event.content["creator"] == user_id):
return 100
return level
def _get_ops_level_from_event_state(self, event):
def _get_ops_level_from_event_state(self, event, auth_events):
key = (RoomPowerLevelsEvent.TYPE, "", )
power_level_event = event.old_state_events.get(key)
power_level_event = auth_events.get(key)
if power_level_event:
return (
@ -375,6 +364,11 @@ class Auth(object):
key = (RoomMemberEvent.TYPE, event.user_id, )
member_event = event.old_state_events.get(key)
key = (RoomCreateEvent.TYPE, "", )
create_event = event.old_state_events.get(key)
if create_event:
auth_events.append(create_event.event_id)
if join_rule_event:
join_rule = join_rule_event.content.get("join_rule")
is_public = join_rule == JoinRules.PUBLIC if join_rule else False
@ -406,9 +400,9 @@ class Auth(object):
event.auth_events = zip(auth_events, hashes)
@log_function
def _can_send_event(self, event):
def _can_send_event(self, event, auth_events):
key = (RoomPowerLevelsEvent.TYPE, "", )
send_level_event = event.old_state_events.get(key)
send_level_event = auth_events.get(key)
send_level = None
if send_level_event:
send_level = send_level_event.content.get("events", {}).get(
@ -432,6 +426,7 @@ class Auth(object):
user_level = self._get_power_level_from_event_state(
event,
event.user_id,
auth_events,
)
if user_level:
@ -468,14 +463,16 @@ class Auth(object):
return True
def _check_redaction(self, event):
def _check_redaction(self, event, auth_events):
user_level = self._get_power_level_from_event_state(
event,
event.user_id,
auth_events,
)
_, _, redact_level = self._get_ops_level_from_event_state(
event
event,
auth_events,
)
if user_level < redact_level:
@ -484,7 +481,7 @@ class Auth(object):
"You don't have permission to redact events"
)
def _check_power_levels(self, event):
def _check_power_levels(self, event, auth_events):
user_list = event.content.get("users", {})
# Validate users
for k, v in user_list.items():
@ -499,7 +496,7 @@ class Auth(object):
raise SynapseError(400, "Not a valid power level: %s" % (v,))
key = (event.type, event.state_key, )
current_state = event.old_state_events.get(key)
current_state = auth_events.get(key)
if not current_state:
return
@ -507,6 +504,7 @@ class Auth(object):
user_level = self._get_power_level_from_event_state(
event,
event.user_id,
auth_events,
)
# Check other levels:

View File

@ -17,6 +17,8 @@
import logging
logger = logging.getLogger(__name__)
class Codes(object):
UNAUTHORIZED = "M_UNAUTHORIZED"
@ -38,7 +40,7 @@ class CodeMessageException(Exception):
"""An exception with integer code and message string attributes."""
def __init__(self, code, msg):
logging.error("%s: %s, %s", type(self).__name__, code, msg)
logger.info("%s: %s, %s", type(self).__name__, code, msg)
super(CodeMessageException, self).__init__("%d: %s" % (code, msg))
self.code = code
self.msg = msg
@ -140,7 +142,8 @@ def cs_exception(exception):
if isinstance(exception, CodeMessageException):
return exception.error_dict()
else:
logging.error("Unknown exception type: %s", type(exception))
logger.error("Unknown exception type: %s", type(exception))
return {}
def cs_error(msg, code=Codes.UNKNOWN, **kwargs):

View File

@ -83,6 +83,8 @@ class SynapseEvent(JsonEncodedObject):
"content",
]
outlier = False
def __init__(self, raises=True, **kwargs):
super(SynapseEvent, self).__init__(**kwargs)
# if "content" in kwargs:
@ -123,6 +125,7 @@ class SynapseEvent(JsonEncodedObject):
pdu_json.pop("outlier", None)
pdu_json.pop("replaces_state", None)
pdu_json.pop("redacted", None)
pdu_json.pop("prev_content", None)
state_hash = pdu_json.pop("state_hash", None)
if state_hash is not None:
pdu_json.setdefault("unsigned", {})["state_hash"] = state_hash

View File

@ -26,7 +26,7 @@ from twisted.web.server import Site
from synapse.http.server import JsonResource, RootRedirect
from synapse.http.content_repository import ContentRepoResource
from synapse.http.server_key_resource import LocalKey
from synapse.http.client import MatrixHttpClient
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.api.urls import (
CLIENT_PREFIX, FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX,
SERVER_KEY_PREFIX,
@ -51,7 +51,7 @@ logger = logging.getLogger(__name__)
class SynapseHomeServer(HomeServer):
def build_http_client(self):
return MatrixHttpClient(self)
return MatrixFederationHttpClient(self)
def build_resource_for_client(self):
return JsonResource()
@ -116,7 +116,7 @@ class SynapseHomeServer(HomeServer):
# extra resources to existing nodes. See self._resource_id for the key.
resource_mappings = {}
for (full_path, resource) in desired_tree:
logging.info("Attaching %s to path %s", resource, full_path)
logger.info("Attaching %s to path %s", resource, full_path)
last_resource = self.root_resource
for path_seg in full_path.split('/')[1:-1]:
if not path_seg in last_resource.listNames():
@ -221,12 +221,12 @@ def setup():
db_name = hs.get_db_name()
logging.info("Preparing database: %s...", db_name)
logger.info("Preparing database: %s...", db_name)
with sqlite3.connect(db_name) as db_conn:
prepare_database(db_conn)
logging.info("Database prepared in %s.", db_name)
logger.info("Database prepared in %s.", db_name)
hs.get_db_pool()
@ -257,13 +257,16 @@ def setup():
else:
reactor.run()
def run():
with LoggingContext("run"):
reactor.run()
def main():
with LoggingContext("main"):
setup()
if __name__ == '__main__':
main()

View File

@ -27,6 +27,7 @@ PIDFILE="homeserver.pid"
GREEN = "\x1b[1;32m"
NORMAL = "\x1b[m"
def start():
if not os.path.exists(CONFIGFILE):
sys.stderr.write(
@ -43,12 +44,14 @@ def start():
subprocess.check_call(args)
print GREEN + "started" + NORMAL
def stop():
if os.path.exists(PIDFILE):
pid = int(open(PIDFILE).read())
os.kill(pid, signal.SIGTERM)
print GREEN + "stopped" + NORMAL
def main():
action = sys.argv[1] if sys.argv[1:] else "usage"
if action == "start":
@ -62,5 +65,6 @@ def main():
sys.stderr.write("Usage: %s [start|stop|restart]\n" % (sys.argv[0],))
sys.exit(1)
if __name__=='__main__':
if __name__ == "__main__":
main()

View File

@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
"""Check whether the hash for this PDU matches the contents"""
computed_hash = _compute_content_hash(event, hash_algorithm)
logging.debug("Expecting hash: %s", encode_base64(computed_hash.digest()))
logger.debug("Expecting hash: %s", encode_base64(computed_hash.digest()))
if computed_hash.name not in event.hashes:
raise SynapseError(
400,

View File

@ -17,7 +17,7 @@
from twisted.web.http import HTTPClient
from twisted.internet.protocol import Factory
from twisted.internet import defer, reactor
from synapse.http.endpoint import matrix_endpoint
from synapse.http.endpoint import matrix_federation_endpoint
from synapse.util.logcontext import PreserveLoggingContext
import json
import logging
@ -31,7 +31,7 @@ def fetch_server_key(server_name, ssl_context_factory):
"""Fetch the keys for a remote server."""
factory = SynapseKeyClientFactory()
endpoint = matrix_endpoint(
endpoint = matrix_federation_endpoint(
reactor, server_name, ssl_context_factory, timeout=30
)
@ -48,7 +48,7 @@ def fetch_server_key(server_name, ssl_context_factory):
class SynapseKeyClientError(Exception):
"""The key wasn't retireved from the remote server."""
"""The key wasn't retrieved from the remote server."""
pass

View File

@ -135,7 +135,7 @@ class Keyring(object):
time_now_ms = self.clock.time_msec()
self.store.store_server_certificate(
yield self.store.store_server_certificate(
server_name,
server_name,
time_now_ms,
@ -143,7 +143,7 @@ class Keyring(object):
)
for key_id, key in verify_keys.items():
self.store.store_server_verify_key(
yield self.store.store_server_verify_key(
server_name, server_name, time_now_ms, key
)

View File

@ -24,6 +24,7 @@ from .units import Transaction, Edu
from .persistence import TransactionActions
from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext
import logging
@ -319,6 +320,7 @@ class ReplicationLayer(object):
logger.debug("[%s] Transacition is new", transaction.transaction_id)
with PreserveLoggingContext():
dl = []
for pdu in pdu_list:
dl.append(self._handle_new_pdu(transaction.origin, pdu))
@ -425,7 +427,9 @@ class ReplicationLayer(object):
time_now = self._clock.time_msec()
defer.returnValue((200, {
"state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
"auth_chain": [p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]],
"auth_chain": [
p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]
],
}))
@defer.inlineCallbacks
@ -436,7 +440,9 @@ class ReplicationLayer(object):
(
200,
{
"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus],
"auth_chain": [
a.get_pdu_json(time_now) for a in auth_pdus
],
}
)
)
@ -475,11 +481,17 @@ class ReplicationLayer(object):
# FIXME: We probably want to do something with the auth_chain given
# to us
# auth_chain = [
# Pdu(outlier=True, **p) for p in content.get("auth_chain", [])
# ]
auth_chain = [
self.event_from_pdu_json(p, outlier=True)
for p in content.get("auth_chain", [])
]
defer.returnValue(state)
auth_chain.sort(key=lambda e: e.depth)
defer.returnValue({
"state": state,
"auth_chain": auth_chain,
})
@defer.inlineCallbacks
def send_invite(self, destination, context, event_id, pdu):
@ -498,13 +510,15 @@ class ReplicationLayer(object):
defer.returnValue(self.event_from_pdu_json(pdu_dict))
@log_function
def _get_persisted_pdu(self, origin, event_id):
def _get_persisted_pdu(self, origin, event_id, do_auth=True):
""" Get a PDU from the database with given origin and id.
Returns:
Deferred: Results in a `Pdu`.
"""
return self.handler.get_persisted_pdu(origin, event_id)
return self.handler.get_persisted_pdu(
origin, event_id, do_auth=do_auth
)
def _transaction_from_pdus(self, pdu_list):
"""Returns a new Transaction containing the given PDUs suitable for
@ -523,7 +537,9 @@ class ReplicationLayer(object):
@log_function
def _handle_new_pdu(self, origin, pdu, backfilled=False):
# We reprocess pdus when we have seen them only as outliers
existing = yield self._get_persisted_pdu(origin, pdu.event_id)
existing = yield self._get_persisted_pdu(
origin, pdu.event_id, do_auth=False
)
if existing and (not existing.outlier or pdu.outlier):
logger.debug("Already seen pdu %s", pdu.event_id)
@ -532,6 +548,36 @@ class ReplicationLayer(object):
state = None
# We need to make sure we have all the auth events.
for e_id, _ in pdu.auth_events:
exists = yield self._get_persisted_pdu(
origin,
e_id,
do_auth=False
)
if not exists:
try:
logger.debug(
"_handle_new_pdu fetch missing auth event %s from %s",
e_id,
origin,
)
yield self.get_pdu(
origin,
event_id=e_id,
outlier=True,
)
logger.debug("Processed pdu %s", e_id)
except:
logger.warn(
"Failed to get auth event %s from %s",
e_id,
origin
)
# Get missing pdus if necessary.
if not pdu.outlier:
# We only backfill backwards to the min depth.
@ -539,16 +585,28 @@ class ReplicationLayer(object):
pdu.room_id
)
logger.debug(
"_handle_new_pdu min_depth for %s: %d",
pdu.room_id, min_depth
)
if min_depth and pdu.depth > min_depth:
for event_id, hashes in pdu.prev_events:
exists = yield self._get_persisted_pdu(origin, event_id)
exists = yield self._get_persisted_pdu(
origin,
event_id,
do_auth=False
)
if not exists:
logger.debug("Requesting pdu %s", event_id)
logger.debug(
"_handle_new_pdu requesting pdu %s",
event_id
)
try:
yield self.get_pdu(
pdu.origin,
origin,
event_id=event_id,
)
logger.debug("Processed pdu %s", event_id)
@ -558,6 +616,10 @@ class ReplicationLayer(object):
else:
# We need to get the state at this event, since we have reached
# a backward extremity edge.
logger.debug(
"_handle_new_pdu getting state for %s",
pdu.room_id
)
state = yield self.get_state_for_context(
origin, pdu.room_id, pdu.event_id,
)
@ -649,6 +711,7 @@ class _TransactionQueue(object):
(pdu, deferred, order)
)
with PreserveLoggingContext():
self._attempt_new_transaction(destination)
deferreds.append(deferred)
@ -669,6 +732,8 @@ class _TransactionQueue(object):
deferred.errback(failure)
else:
logger.exception("Failed to send edu", failure)
with PreserveLoggingContext():
self._attempt_new_transaction(destination).addErrback(eb)
return deferred

View File

@ -25,7 +25,6 @@ import logging
logger = logging.getLogger(__name__)
class Edu(JsonEncodedObject):
""" An Edu represents a piece of data sent from one homeserver to another.

View File

@ -78,7 +78,7 @@ class BaseHandler(object):
if not suppress_auth:
logger.debug("Authing...")
self.auth.check(event, raises=True)
self.auth.check(event, auth_events=event.old_state_events)
logger.debug("Authed")
else:
logger.debug("Suppressed auth.")
@ -112,7 +112,7 @@ class BaseHandler(object):
event.destinations = list(destinations)
self.notifier.on_new_room_event(event, extra_users=extra_users)
yield self.notifier.on_new_room_event(event, extra_users=extra_users)
federation_handler = self.hs.get_handlers().federation_handler
yield federation_handler.handle_new_event(event, snapshot)

View File

@ -17,7 +17,7 @@
from twisted.internet import defer
from ._base import BaseHandler
from synapse.api.errors import SynapseError
from synapse.api.errors import SynapseError, Codes, CodeMessageException
from synapse.api.events.room import RoomAliasesEvent
import logging
@ -84,6 +84,7 @@ class DirectoryHandler(BaseHandler):
room_id = result.room_id
servers = result.servers
else:
try:
result = yield self.federation.make_query(
destination=room_alias.domain,
query_type="directory",
@ -92,14 +93,23 @@ class DirectoryHandler(BaseHandler):
},
retry_on_dns_fail=False,
)
except CodeMessageException as e:
logging.warn("Error retrieving alias")
if e.code == 404:
result = None
else:
raise
if result and "room_id" in result and "servers" in result:
room_id = result["room_id"]
servers = result["servers"]
if not room_id:
defer.returnValue({})
return
raise SynapseError(
404,
"Room alias %r not found" % (room_alias.to_string(),),
Codes.NOT_FOUND
)
extra_servers = yield self.store.get_joined_hosts_for_room(room_id)
servers = list(set(extra_servers) | set(servers))
@ -128,8 +138,11 @@ class DirectoryHandler(BaseHandler):
"servers": result.servers,
})
else:
raise SynapseError(404, "Room alias \"%s\" not found", room_alias)
raise SynapseError(
404,
"Room alias %r not found" % (room_alias.to_string(),),
Codes.NOT_FOUND
)
@defer.inlineCallbacks
def send_room_alias_update_event(self, user_id, room_id):

View File

@ -56,7 +56,7 @@ class EventStreamHandler(BaseHandler):
self.clock.cancel_call_later(
self._stop_timer_per_user.pop(auth_user))
else:
self.distributor.fire(
yield self.distributor.fire(
"started_user_eventstream", auth_user
)
self._streams_per_user[auth_user] += 1
@ -65,8 +65,10 @@ class EventStreamHandler(BaseHandler):
pagin_config.from_token = None
rm_handler = self.hs.get_handlers().room_member_handler
logger.debug("BETA")
room_ids = yield rm_handler.get_rooms_for_user(auth_user)
logger.debug("ALPHA")
with PreserveLoggingContext():
events, tokens = yield self.notifier.get_events_for(
auth_user, room_ids, pagin_config, timeout
@ -93,7 +95,7 @@ class EventStreamHandler(BaseHandler):
logger.debug(
"_later stopped_user_eventstream %s", auth_user
)
self.distributor.fire(
yield self.distributor.fire(
"stopped_user_eventstream", auth_user
)
del self._stop_timer_per_user[auth_user]

View File

@ -24,7 +24,8 @@ from synapse.api.constants import Membership
from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor
from synapse.crypto.event_signing import (
compute_event_signature, check_event_content_hash
compute_event_signature, check_event_content_hash,
add_hashes_and_signatures,
)
from syutil.jsonutil import encode_canonical_json
@ -122,7 +123,8 @@ class FederationHandler(BaseHandler):
event.origin, redacted_pdu_json
)
except SynapseError as e:
logger.warn("Signature check failed for %s redacted to %s",
logger.warn(
"Signature check failed for %s redacted to %s",
encode_canonical_json(pdu.get_pdu_json()),
encode_canonical_json(redacted_pdu_json),
)
@ -140,15 +142,27 @@ class FederationHandler(BaseHandler):
)
event = redacted_event
is_new_state = yield self.state_handler.annotate_event_with_state(
event,
old_state=state
)
logger.debug("Event: %s", event)
# FIXME (erikj): Awful hack to make the case where we are not currently
# in the room work
current_state = None
if state:
is_in_room = yield self.auth.check_host_in_room(
event.room_id,
self.server_name
)
if not is_in_room:
logger.debug("Got event for room we're not in.")
current_state = state
try:
self.auth.check(event, raises=True)
yield self._handle_new_event(
event,
state=state,
backfilled=backfilled,
current_state=current_state,
)
except AuthError as e:
raise FederationError(
"ERROR",
@ -157,43 +171,14 @@ class FederationHandler(BaseHandler):
affected=event.event_id,
)
is_new_state = is_new_state and not backfilled
# TODO: Implement something in federation that allows us to
# respond to PDU.
yield self.store.persist_event(
event,
backfilled,
is_new_state=is_new_state
)
room = yield self.store.get_room(event.room_id)
if not room:
# Huh, let's try and get the current state
try:
yield self.replication_layer.get_state_for_context(
event.origin, event.room_id, event.event_id,
)
hosts = yield self.store.get_joined_hosts_for_room(
event.room_id
)
if self.hs.hostname in hosts:
try:
yield self.store.store_room(
room_id=event.room_id,
room_creator_user_id="",
is_public=False,
)
except:
pass
except:
logger.exception(
"Failed to get current state for room %s",
event.room_id
)
if not backfilled:
extra_users = []
@ -209,7 +194,7 @@ class FederationHandler(BaseHandler):
if event.type == RoomMemberEvent.TYPE:
if event.membership == Membership.JOIN:
user = self.hs.parse_userid(event.state_key)
self.distributor.fire(
yield self.distributor.fire(
"user_joined_room", user=user, room_id=event.room_id
)
@ -254,6 +239,8 @@ class FederationHandler(BaseHandler):
pdu=event
)
defer.returnValue(pdu)
@defer.inlineCallbacks
@ -275,6 +262,8 @@ class FederationHandler(BaseHandler):
We suspend processing of any received events from this room until we
have finished processing the join.
"""
logger.debug("Joining %s to %s", joinee, room_id)
pdu = yield self.replication_layer.make_join(
target_host,
room_id,
@ -297,19 +286,28 @@ class FederationHandler(BaseHandler):
try:
event.event_id = self.event_factory.create_event_id()
event.origin = self.hs.hostname
event.content = content
state = yield self.replication_layer.send_join(
if not hasattr(event, "signatures"):
event.signatures = {}
add_hashes_and_signatures(
event,
self.hs.hostname,
self.hs.config.signing_key[0],
)
ret = yield self.replication_layer.send_join(
target_host,
event
)
logger.debug("do_invite_join state: %s", state)
state = ret["state"]
auth_chain = ret["auth_chain"]
yield self.state_handler.annotate_event_with_state(
event,
old_state=state
)
logger.debug("do_invite_join auth_chain: %s", auth_chain)
logger.debug("do_invite_join state: %s", state)
logger.debug("do_invite_join event: %s", event)
@ -323,34 +321,41 @@ class FederationHandler(BaseHandler):
# FIXME
pass
for e in auth_chain:
e.outlier = True
yield self._handle_new_event(e)
yield self.notifier.on_new_room_event(
e, extra_users=[joinee]
)
for e in state:
# FIXME: Auth these.
e.outlier = True
yield self.state_handler.annotate_event_with_state(
e,
yield self._handle_new_event(e)
yield self.notifier.on_new_room_event(
e, extra_users=[joinee]
)
yield self.store.persist_event(
e,
backfilled=False,
is_new_state=True
)
yield self.store.persist_event(
yield self._handle_new_event(
event,
backfilled=False,
is_new_state=True
state=state,
current_state=state
)
yield self.notifier.on_new_room_event(
event, extra_users=[joinee]
)
logger.debug("Finished joining %s to %s", joinee, room_id)
finally:
room_queue = self.room_queues[room_id]
del self.room_queues[room_id]
for p in room_queue:
try:
yield self.on_receive_pdu(p, backfilled=False)
self.on_receive_pdu(p, backfilled=False)
except:
pass
logger.exception("Couldn't handle pdu")
defer.returnValue(True)
@ -374,7 +379,7 @@ class FederationHandler(BaseHandler):
yield self.state_handler.annotate_event_with_state(event)
yield self.auth.add_auth_events(event)
self.auth.check(event, raises=True)
self.auth.check(event, auth_events=event.old_state_events)
pdu = event
@ -390,16 +395,7 @@ class FederationHandler(BaseHandler):
event.outlier = False
is_new_state = yield self.state_handler.annotate_event_with_state(event)
self.auth.check(event, raises=True)
# FIXME (erikj): All this is duplicated above :(
yield self.store.persist_event(
event,
backfilled=False,
is_new_state=is_new_state
)
yield self._handle_new_event(event)
extra_users = []
if event.type == RoomMemberEvent.TYPE:
@ -412,9 +408,9 @@ class FederationHandler(BaseHandler):
)
if event.type == RoomMemberEvent.TYPE:
if event.membership == Membership.JOIN:
if event.content["membership"] == Membership.JOIN:
user = self.hs.parse_userid(event.state_key)
self.distributor.fire(
yield self.distributor.fire(
"user_joined_room", user=user, room_id=event.room_id
)
@ -527,7 +523,7 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
@log_function
def get_persisted_pdu(self, origin, event_id):
def get_persisted_pdu(self, origin, event_id, do_auth=True):
""" Get a PDU from the database with given origin and id.
Returns:
@ -539,6 +535,7 @@ class FederationHandler(BaseHandler):
)
if event:
if do_auth:
in_room = yield self.auth.check_host_in_room(
event.room_id,
origin
@ -562,3 +559,65 @@ class FederationHandler(BaseHandler):
)
while waiters:
waiters.pop().callback(None)
@defer.inlineCallbacks
def _handle_new_event(self, event, state=None, backfilled=False,
current_state=None):
if state:
for s in state:
yield self._handle_new_event(s)
is_new_state = yield self.state_handler.annotate_event_with_state(
event,
old_state=state
)
if event.old_state_events:
known_ids = set(
[s.event_id for s in event.old_state_events.values()]
)
for e_id, _ in event.auth_events:
if e_id not in known_ids:
e = yield self.store.get_event(
e_id,
allow_none=True,
)
if not e:
# TODO: Do some conflict res to make sure that we're
# not the ones who are wrong.
logger.info(
"Rejecting %s as %s not in %s",
event.event_id, e_id, known_ids,
)
raise AuthError(403, "Auth events are stale")
auth_events = event.old_state_events
else:
# We need to get the auth events from somewhere.
# TODO: Don't just hit the DBs?
auth_events = {}
for e_id, _ in event.auth_events:
e = yield self.store.get_event(
e_id,
allow_none=True,
)
if not e:
raise AuthError(
403,
"Can't find auth event %s." % (e_id, )
)
auth_events[(e.type, e.state_key)] = e
self.auth.check(event, auth_events=auth_events)
yield self.store.persist_event(
event,
backfilled=backfilled,
is_new_state=(is_new_state and not backfilled),
current_state=current_state,
)

View File

@ -17,13 +17,12 @@ from twisted.internet import defer
from ._base import BaseHandler
from synapse.api.errors import LoginError, Codes
from synapse.http.client import IdentityServerHttpClient
from synapse.http.client import SimpleHttpClient
from synapse.util.emailutils import EmailException
import synapse.util.emailutils as emailutils
import bcrypt
import logging
import urllib
logger = logging.getLogger(__name__)
@ -97,10 +96,16 @@ class LoginHandler(BaseHandler):
@defer.inlineCallbacks
def _query_email(self, email):
httpCli = IdentityServerHttpClient(self.hs)
httpCli = SimpleHttpClient(self.hs)
data = yield httpCli.get_json(
'matrix.org:8090', # TODO FIXME This should be configurable.
"/_matrix/identity/api/v1/lookup?medium=email&address=" +
"%s" % urllib.quote(email)
# TODO FIXME This should be configurable.
# XXX: ID servers need to use HTTPS
"http://%s%s" % (
"matrix.org:8090", "/_matrix/identity/api/v1/lookup"
),
{
'medium': 'email',
'address': email
}
)
defer.returnValue(data)

View File

@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.api.constants import Membership
from synapse.api.errors import RoomError
from synapse.streams.config import PaginationConfig
from synapse.util.logcontext import PreserveLoggingContext
from ._base import BaseHandler
import logging
@ -86,6 +87,7 @@ class MessageHandler(BaseHandler):
event, snapshot, suppress_auth=suppress_auth
)
with PreserveLoggingContext():
self.hs.get_handlers().presence_handler.bump_presence_active_time(
user
)
@ -241,7 +243,7 @@ class MessageHandler(BaseHandler):
public_room_ids = [r["room_id"] for r in public_rooms]
limit = pagin_config.limit
if not limit:
if limit is None:
limit = 10
for event in room_list:
@ -304,7 +306,7 @@ class MessageHandler(BaseHandler):
auth_user = self.hs.parse_userid(user_id)
# TODO: These concurrently
state_tuples = yield self.store.get_current_state(room_id)
state_tuples = yield self.state_handler.get_current_state(room_id)
state = [self.hs.serialize_event(x) for x in state_tuples]
member_event = (yield self.store.get_room_member(
@ -340,8 +342,8 @@ class MessageHandler(BaseHandler):
)
presence.append(member_presence)
except Exception:
logger.exception("Failed to get member presence of %r",
m.user_id
logger.exception(
"Failed to get member presence of %r", m.user_id
)
defer.returnValue({

View File

@ -19,6 +19,7 @@ from synapse.api.errors import SynapseError, AuthError
from synapse.api.constants import PresenceState
from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext
from ._base import BaseHandler
@ -142,7 +143,7 @@ class PresenceHandler(BaseHandler):
return UserPresenceCache()
def registered_user(self, user):
self.store.create_presence(user.localpart)
return self.store.create_presence(user.localpart)
@defer.inlineCallbacks
def is_presence_visible(self, observer_user, observed_user):
@ -241,14 +242,12 @@ class PresenceHandler(BaseHandler):
was_level = self.STATE_LEVELS[statuscache.get_state()["presence"]]
now_level = self.STATE_LEVELS[state["presence"]]
yield defer.DeferredList([
self.store.set_presence_state(
yield self.store.set_presence_state(
target_user.localpart, state_to_store
),
self.distributor.fire(
)
yield self.distributor.fire(
"collect_presencelike_data", target_user, state
),
])
)
if now_level > was_level:
state["last_active"] = self.clock.time_msec()
@ -256,6 +255,7 @@ class PresenceHandler(BaseHandler):
now_online = state["presence"] != PresenceState.OFFLINE
was_polling = target_user in self._user_cachemap
with PreserveLoggingContext():
if now_online and not was_polling:
self.start_polling_presence(target_user, state=state)
elif not now_online and was_polling:
@ -277,7 +277,7 @@ class PresenceHandler(BaseHandler):
self._user_cachemap_latest_serial += 1
statuscache.update(state, serial=self._user_cachemap_latest_serial)
self.push_presence(user, statuscache=statuscache)
return self.push_presence(user, statuscache=statuscache)
@log_function
def started_user_eventstream(self, user):
@ -381,8 +381,10 @@ class PresenceHandler(BaseHandler):
yield self.store.set_presence_list_accepted(
observer_user.localpart, observed_user.to_string()
)
self.start_polling_presence(observer_user, target_user=observed_user)
with PreserveLoggingContext():
self.start_polling_presence(
observer_user, target_user=observed_user
)
@defer.inlineCallbacks
def deny_presence(self, observed_user, observer_user):
@ -401,7 +403,10 @@ class PresenceHandler(BaseHandler):
observer_user.localpart, observed_user.to_string()
)
self.stop_polling_presence(observer_user, target_user=observed_user)
with PreserveLoggingContext():
self.stop_polling_presence(
observer_user, target_user=observed_user
)
@defer.inlineCallbacks
def get_presence_list(self, observer_user, accepted=None):
@ -710,6 +715,7 @@ class PresenceHandler(BaseHandler):
if not self._remote_sendmap[user]:
del self._remote_sendmap[user]
with PreserveLoggingContext():
yield defer.DeferredList(deferreds)
@defer.inlineCallbacks

View File

@ -17,6 +17,7 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, AuthError, CodeMessageException
from synapse.api.constants import Membership
from synapse.util.logcontext import PreserveLoggingContext
from ._base import BaseHandler
@ -46,7 +47,7 @@ class ProfileHandler(BaseHandler):
)
def registered_user(self, user):
self.store.create_profile(user.localpart)
return self.store.create_profile(user.localpart)
@defer.inlineCallbacks
def get_displayname(self, target_user):
@ -152,6 +153,7 @@ class ProfileHandler(BaseHandler):
if not user.is_mine:
defer.returnValue(None)
with PreserveLoggingContext():
(displayname, avatar_url) = yield defer.gatherResults(
[
self.store.get_profile_displayname(user.localpart),

View File

@ -22,7 +22,7 @@ from synapse.api.errors import (
)
from ._base import BaseHandler
import synapse.util.stringutils as stringutils
from synapse.http.client import IdentityServerHttpClient
from synapse.http.client import SimpleHttpClient
from synapse.http.client import CaptchaServerHttpClient
import base64
@ -69,7 +69,7 @@ class RegistrationHandler(BaseHandler):
password_hash=password_hash
)
self.distributor.fire("registered_user", user)
yield self.distributor.fire("registered_user", user)
else:
# autogen a random user ID
attempts = 0
@ -133,7 +133,7 @@ class RegistrationHandler(BaseHandler):
if not threepid:
raise RegistrationError(400, "Couldn't validate 3pid")
logger.info("got threepid medium %s address %s",
logger.info("got threepid with medium '%s' and address '%s'",
threepid['medium'], threepid['address'])
@defer.inlineCallbacks
@ -159,7 +159,7 @@ class RegistrationHandler(BaseHandler):
def _threepid_from_creds(self, creds):
# TODO: get this from the homeserver rather than creating a new one for
# each request
httpCli = IdentityServerHttpClient(self.hs)
httpCli = SimpleHttpClient(self.hs)
# XXX: make this configurable!
trustedIdServers = ['matrix.org:8090']
if not creds['idServer'] in trustedIdServers:
@ -167,8 +167,11 @@ class RegistrationHandler(BaseHandler):
'credentials', creds['idServer'])
defer.returnValue(None)
data = yield httpCli.get_json(
# XXX: This should be HTTPS
"http://%s%s" % (
creds['idServer'],
"/_matrix/identity/api/v1/3pid/getValidated3pid",
"/_matrix/identity/api/v1/3pid/getValidated3pid"
),
{'sid': creds['sid'], 'clientSecret': creds['clientSecret']}
)
@ -178,16 +181,21 @@ class RegistrationHandler(BaseHandler):
@defer.inlineCallbacks
def _bind_threepid(self, creds, mxid):
httpCli = IdentityServerHttpClient(self.hs)
yield
logger.debug("binding threepid")
httpCli = SimpleHttpClient(self.hs)
data = yield httpCli.post_urlencoded_get_json(
creds['idServer'],
"/_matrix/identity/api/v1/3pid/bind",
# XXX: Change when ID servers are all HTTPS
"http://%s%s" % (
creds['idServer'], "/_matrix/identity/api/v1/3pid/bind"
),
{
'sid': creds['sid'],
'clientSecret': creds['clientSecret'],
'mxid': mxid,
}
)
logger.debug("bound threepid")
defer.returnValue(data)
@defer.inlineCallbacks
@ -215,10 +223,7 @@ class RegistrationHandler(BaseHandler):
# each request
client = CaptchaServerHttpClient(self.hs)
data = yield client.post_urlencoded_get_raw(
"www.google.com:80",
"/recaptcha/api/verify",
# twisted dislikes google's response, no content length.
accept_partial=True,
"http://www.google.com:80/recaptcha/api/verify",
args={
'privatekey': private_key,
'remoteip': ip_addr,

View File

@ -178,7 +178,9 @@ class RoomCreationHandler(BaseHandler):
if room_alias:
result["room_alias"] = room_alias.to_string()
directory_handler.send_room_alias_update_event(user_id, room_id)
yield directory_handler.send_room_alias_update_event(
user_id, room_id
)
defer.returnValue(result)
@ -211,7 +213,6 @@ class RoomCreationHandler(BaseHandler):
**event_keys
)
power_levels_event = self.event_factory.create_event(
etype=RoomPowerLevelsEvent.TYPE,
content={
@ -480,7 +481,7 @@ class RoomMemberHandler(BaseHandler):
)
user = self.hs.parse_userid(event.user_id)
self.distributor.fire(
yield self.distributor.fire(
"user_joined_room", user=user, room_id=room_id
)

View File

@ -15,308 +15,45 @@
from twisted.internet import defer, reactor
from twisted.internet.error import DNSLookupError
from twisted.web.client import (
_AgentBase, _URI, readBody, FileBodyProducer, PartialDownloadError
Agent, readBody, FileBodyProducer, PartialDownloadError
)
from twisted.web.http_headers import Headers
from synapse.http.endpoint import matrix_endpoint
from synapse.util.async import sleep
from synapse.util.logcontext import PreserveLoggingContext
from syutil.jsonutil import encode_canonical_json
from synapse.api.errors import CodeMessageException, SynapseError
from syutil.crypto.jsonsign import sign_json
from StringIO import StringIO
import json
import logging
import urllib
import urlparse
logger = logging.getLogger(__name__)
class MatrixHttpAgent(_AgentBase):
def __init__(self, reactor, pool=None):
_AgentBase.__init__(self, reactor, pool)
def request(self, destination, endpoint, method, path, params, query,
headers, body_producer):
host = b""
port = 0
fragment = b""
parsed_URI = _URI(b"http", destination, host, port, path, params,
query, fragment)
# Set the connection pool key to be the destination.
key = destination
return self._requestWithEndpoint(key, endpoint, method, parsed_URI,
headers, body_producer,
parsed_URI.originForm)
class BaseHttpClient(object):
"""Base class for HTTP clients using twisted.
class SimpleHttpClient(object):
"""
A simple, no-frills HTTP client with methods that wrap up common ways of
using HTTP in Matrix
"""
def __init__(self, hs):
self.agent = MatrixHttpAgent(reactor)
self.hs = hs
# The default context factory in Twisted 14.0.0 (which we require) is
# BrowserLikePolicyForHTTPS which will do regular cert validation
# 'like a browser'
self.agent = Agent(reactor)
@defer.inlineCallbacks
def _create_request(self, destination, method, path_bytes,
body_callback, headers_dict={}, param_bytes=b"",
query_bytes=b"", retry_on_dns_fail=True):
""" Creates and sends a request to the given url
"""
headers_dict[b"User-Agent"] = [b"Synapse"]
headers_dict[b"Host"] = [destination]
url_bytes = urlparse.urlunparse(
("", "", path_bytes, param_bytes, query_bytes, "",)
)
logger.debug("Sending request to %s: %s %s",
destination, method, url_bytes)
logger.debug(
"Types: %s",
[
type(destination), type(method), type(path_bytes),
type(param_bytes),
type(query_bytes)
]
)
retries_left = 5
endpoint = self._getEndpoint(reactor, destination)
while True:
producer = None
if body_callback:
producer = body_callback(method, url_bytes, headers_dict)
try:
with PreserveLoggingContext():
response = yield self.agent.request(
destination,
endpoint,
method,
path_bytes,
param_bytes,
query_bytes,
Headers(headers_dict),
producer
)
logger.debug("Got response to %s", method)
break
except Exception as e:
if not retry_on_dns_fail and isinstance(e, DNSLookupError):
logger.warn("DNS Lookup failed to %s with %s", destination,
e)
raise SynapseError(400, "Domain specified not found.")
logger.exception("Got error in _create_request")
_print_ex(e)
if retries_left:
yield sleep(2 ** (5 - retries_left))
retries_left -= 1
else:
raise
if 200 <= response.code < 300:
# We need to update the transactions table to say it was sent?
pass
else:
# :'(
# Update transactions table?
logger.error(
"Got response %d %s", response.code, response.phrase
)
raise CodeMessageException(
response.code, response.phrase
)
defer.returnValue(response)
class MatrixHttpClient(BaseHttpClient):
""" Wrapper around the twisted HTTP client api. Implements
Attributes:
agent (twisted.web.client.Agent): The twisted Agent used to send the
requests.
"""
RETRY_DNS_LOOKUP_FAILURES = "__retry_dns"
def __init__(self, hs):
self.signing_key = hs.config.signing_key[0]
self.server_name = hs.hostname
BaseHttpClient.__init__(self, hs)
def sign_request(self, destination, method, url_bytes, headers_dict,
content=None):
request = {
"method": method,
"uri": url_bytes,
"origin": self.server_name,
"destination": destination,
}
if content is not None:
request["content"] = content
request = sign_json(request, self.server_name, self.signing_key)
auth_headers = []
for key, sig in request["signatures"][self.server_name].items():
auth_headers.append(bytes(
"X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (
self.server_name, key, sig,
)
))
headers_dict[b"Authorization"] = auth_headers
@defer.inlineCallbacks
def put_json(self, destination, path, data={}, json_data_callback=None):
""" Sends the specifed json data using PUT
Args:
destination (str): The remote server to send the HTTP request
to.
path (str): The HTTP path.
data (dict): A dict containing the data that will be used as
the request body. This will be encoded as JSON.
json_data_callback (callable): A callable returning the dict to
use as the request body.
Returns:
Deferred: Succeeds when we get a 2xx HTTP response. The result
will be the decoded JSON body. On a 4xx or 5xx error response a
CodeMessageException is raised.
"""
if not json_data_callback:
def json_data_callback():
return data
def body_callback(method, url_bytes, headers_dict):
json_data = json_data_callback()
self.sign_request(
destination, method, url_bytes, headers_dict, json_data
)
producer = _JsonProducer(json_data)
return producer
response = yield self._create_request(
destination.encode("ascii"),
"PUT",
path.encode("ascii"),
body_callback=body_callback,
headers_dict={"Content-Type": ["application/json"]},
)
logger.debug("Getting resp body")
body = yield readBody(response)
logger.debug("Got resp body")
defer.returnValue((response.code, body))
@defer.inlineCallbacks
def get_json(self, destination, path, args={}, retry_on_dns_fail=True):
""" Get's some json from the given host homeserver and path
Args:
destination (str): The remote server to send the HTTP request
to.
path (str): The HTTP path.
args (dict): A dictionary used to create query strings, defaults to
None.
**Note**: The value of each key is assumed to be an iterable
and *not* a string.
Returns:
Deferred: Succeeds when we get *any* HTTP response.
The result of the deferred is a tuple of `(code, response)`,
where `response` is a dict representing the decoded JSON body.
"""
logger.debug("get_json args: %s", args)
encoded_args = {}
for k, vs in args.items():
if isinstance(vs, basestring):
vs = [vs]
encoded_args[k] = [v.encode("UTF-8") for v in vs]
query_bytes = urllib.urlencode(encoded_args, True)
logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail)
def body_callback(method, url_bytes, headers_dict):
self.sign_request(destination, method, url_bytes, headers_dict)
return None
response = yield self._create_request(
destination.encode("ascii"),
"GET",
path.encode("ascii"),
query_bytes=query_bytes,
body_callback=body_callback,
retry_on_dns_fail=retry_on_dns_fail
)
body = yield readBody(response)
defer.returnValue(json.loads(body))
def _getEndpoint(self, reactor, destination):
return matrix_endpoint(
reactor, destination, timeout=10,
ssl_context_factory=self.hs.tls_context_factory
)
class IdentityServerHttpClient(BaseHttpClient):
"""Separate HTTP client for talking to the Identity servers since they
don't use SRV records and talk x-www-form-urlencoded rather than JSON.
"""
def _getEndpoint(self, reactor, destination):
#TODO: This should be talking TLS
return matrix_endpoint(reactor, destination, timeout=10)
@defer.inlineCallbacks
def post_urlencoded_get_json(self, destination, path, args={}):
def post_urlencoded_get_json(self, uri, args={}):
logger.debug("post_urlencoded_get_json args: %s", args)
query_bytes = urllib.urlencode(args, True)
def body_callback(method, url_bytes, headers_dict):
return FileBodyProducer(StringIO(query_bytes))
response = yield self._create_request(
destination.encode("ascii"),
response = yield self.agent.request(
"POST",
path.encode("ascii"),
body_callback=body_callback,
headers_dict={
uri.encode("ascii"),
headers=Headers({
"Content-Type": ["application/x-www-form-urlencoded"]
}
}),
bodyProducer=FileBodyProducer(StringIO(query_bytes))
)
body = yield readBody(response)
@ -324,13 +61,11 @@ class IdentityServerHttpClient(BaseHttpClient):
defer.returnValue(json.loads(body))
@defer.inlineCallbacks
def get_json(self, destination, path, args={}, retry_on_dns_fail=True):
""" Get's some json from the given host homeserver and path
def get_json(self, uri, args={}):
""" Get's some json from the given host and path
Args:
destination (str): The remote server to send the HTTP request
to.
path (str): The HTTP path.
uri (str): The URI to request, not including query parameters
args (dict): A dictionary used to create query strings, defaults to
None.
**Note**: The value of each key is assumed to be an iterable
@ -342,18 +77,15 @@ class IdentityServerHttpClient(BaseHttpClient):
The result of the deferred is a tuple of `(code, response)`,
where `response` is a dict representing the decoded JSON body.
"""
logger.debug("get_json args: %s", args)
yield
if len(args):
query_bytes = urllib.urlencode(args, True)
logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail)
uri = "%s?%s" % (uri, query_bytes)
response = yield self._create_request(
destination.encode("ascii"),
response = yield self.agent.request(
"GET",
path.encode("ascii"),
query_bytes=query_bytes,
retry_on_dns_fail=retry_on_dns_fail,
body_callback=None
uri.encode("ascii"),
)
body = yield readBody(response)
@ -361,38 +93,31 @@ class IdentityServerHttpClient(BaseHttpClient):
defer.returnValue(json.loads(body))
class CaptchaServerHttpClient(MatrixHttpClient):
"""Separate HTTP client for talking to google's captcha servers"""
def _getEndpoint(self, reactor, destination):
return matrix_endpoint(reactor, destination, timeout=10)
class CaptchaServerHttpClient(SimpleHttpClient):
"""
Separate HTTP client for talking to google's captcha servers
Only slightly special because accepts partial download responses
"""
@defer.inlineCallbacks
def post_urlencoded_get_raw(self, destination, path, accept_partial=False,
args={}):
def post_urlencoded_get_raw(self, url, args={}):
query_bytes = urllib.urlencode(args, True)
def body_callback(method, url_bytes, headers_dict):
return FileBodyProducer(StringIO(query_bytes))
response = yield self._create_request(
destination.encode("ascii"),
response = yield self.agent.request(
"POST",
path.encode("ascii"),
body_callback=body_callback,
headers_dict={
url.encode("ascii"),
bodyProducer=FileBodyProducer(StringIO(query_bytes)),
headers=Headers({
"Content-Type": ["application/x-www-form-urlencoded"]
}
})
)
try:
body = yield readBody(response)
defer.returnValue(body)
except PartialDownloadError as e:
if accept_partial:
# twisted dislikes google's response, no content length.
defer.returnValue(e.response)
else:
raise e
def _print_ex(e):
@ -401,24 +126,3 @@ def _print_ex(e):
_print_ex(ex)
else:
logger.exception(e)
class _JsonProducer(object):
""" Used by the twisted http client to create the HTTP body from json
"""
def __init__(self, jsn):
self.reset(jsn)
def reset(self, jsn):
self.body = encode_canonical_json(jsn)
self.length = len(self.body)
def startProducing(self, consumer):
consumer.write(self.body)
return defer.succeed(None)
def pauseProducing(self):
pass
def stopProducing(self):
pass

View File

@ -131,11 +131,13 @@ class ContentRepoResource(resource.Resource):
request.setHeader('Content-Type', content_type)
# cache for at least a day.
# XXX: we might want to turn this off for data we don't want to recommend
# caching as it's sensitive or private - or at least select private.
# don't bother setting Expires as all our matrix clients are smart enough to
# be happy with Cache-Control (right?)
request.setHeader('Cache-Control', 'public,max-age=86400,s-maxage=86400')
# XXX: we might want to turn this off for data we don't want to
# recommend caching as it's sensitive or private - or at least
# select private. don't bother setting Expires as all our matrix
# clients are smart enough to be happy with Cache-Control (right?)
request.setHeader(
"Cache-Control", "public,max-age=86400,s-maxage=86400"
)
d = FileSender().beginFileTransfer(f, request)
@ -179,7 +181,7 @@ class ContentRepoResource(resource.Resource):
fname = yield self.map_request_to_name(request)
# TODO I have a suspcious feeling this is just going to block
# TODO I have a suspicious feeling this is just going to block
with open(fname, "wb") as f:
f.write(request.content.read())
@ -188,7 +190,7 @@ class ContentRepoResource(resource.Resource):
# FIXME: we can't assume what the repo's public mounted path is
# ...plus self-signed SSL won't work to remote clients anyway
# ...and we can't assume that it's SSL anyway, as we might want to
# server it via the non-SSL listener...
# serve it via the non-SSL listener...
url = "%s/_matrix/content/%s" % (
self.external_addr, file_name
)

View File

@ -27,7 +27,7 @@ import random
logger = logging.getLogger(__name__)
def matrix_endpoint(reactor, destination, ssl_context_factory=None,
def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
timeout=None):
"""Construct an endpoint for the given matrix destination.

View File

@ -0,0 +1,308 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer, reactor
from twisted.internet.error import DNSLookupError
from twisted.web.client import readBody, _AgentBase, _URI
from twisted.web.http_headers import Headers
from synapse.http.endpoint import matrix_federation_endpoint
from synapse.util.async import sleep
from synapse.util.logcontext import PreserveLoggingContext
from syutil.jsonutil import encode_canonical_json
from synapse.api.errors import CodeMessageException, SynapseError
from syutil.crypto.jsonsign import sign_json
import json
import logging
import urllib
import urlparse
logger = logging.getLogger(__name__)
class MatrixFederationHttpAgent(_AgentBase):
def __init__(self, reactor, pool=None):
_AgentBase.__init__(self, reactor, pool)
def request(self, destination, endpoint, method, path, params, query,
headers, body_producer):
host = b""
port = 0
fragment = b""
parsed_URI = _URI(b"http", destination, host, port, path, params,
query, fragment)
# Set the connection pool key to be the destination.
key = destination
return self._requestWithEndpoint(key, endpoint, method, parsed_URI,
headers, body_producer,
parsed_URI.originForm)
class MatrixFederationHttpClient(object):
"""HTTP client used to talk to other homeservers over the federation
protocol. Send client certificates and signs requests.
Attributes:
agent (twisted.web.client.Agent): The twisted Agent used to send the
requests.
"""
def __init__(self, hs):
self.hs = hs
self.signing_key = hs.config.signing_key[0]
self.server_name = hs.hostname
self.agent = MatrixFederationHttpAgent(reactor)
@defer.inlineCallbacks
def _create_request(self, destination, method, path_bytes,
body_callback, headers_dict={}, param_bytes=b"",
query_bytes=b"", retry_on_dns_fail=True):
""" Creates and sends a request to the given url
"""
headers_dict[b"User-Agent"] = [b"Synapse"]
headers_dict[b"Host"] = [destination]
url_bytes = urlparse.urlunparse(
("", "", path_bytes, param_bytes, query_bytes, "",)
)
logger.debug("Sending request to %s: %s %s",
destination, method, url_bytes)
logger.debug(
"Types: %s",
[
type(destination), type(method), type(path_bytes),
type(param_bytes),
type(query_bytes)
]
)
retries_left = 5
endpoint = self._getEndpoint(reactor, destination)
while True:
producer = None
if body_callback:
producer = body_callback(method, url_bytes, headers_dict)
try:
with PreserveLoggingContext():
response = yield self.agent.request(
destination,
endpoint,
method,
path_bytes,
param_bytes,
query_bytes,
Headers(headers_dict),
producer
)
logger.debug("Got response to %s", method)
break
except Exception as e:
if not retry_on_dns_fail and isinstance(e, DNSLookupError):
logger.warn("DNS Lookup failed to %s with %s", destination,
e)
raise SynapseError(400, "Domain specified not found.")
logger.exception("Got error in _create_request")
_print_ex(e)
if retries_left:
yield sleep(2 ** (5 - retries_left))
retries_left -= 1
else:
raise
if 200 <= response.code < 300:
# We need to update the transactions table to say it was sent?
pass
else:
# :'(
# Update transactions table?
logger.error(
"Got response %d %s", response.code, response.phrase
)
raise CodeMessageException(
response.code, response.phrase
)
defer.returnValue(response)
def sign_request(self, destination, method, url_bytes, headers_dict,
content=None):
request = {
"method": method,
"uri": url_bytes,
"origin": self.server_name,
"destination": destination,
}
if content is not None:
request["content"] = content
request = sign_json(request, self.server_name, self.signing_key)
auth_headers = []
for key, sig in request["signatures"][self.server_name].items():
auth_headers.append(bytes(
"X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (
self.server_name, key, sig,
)
))
headers_dict[b"Authorization"] = auth_headers
@defer.inlineCallbacks
def put_json(self, destination, path, data={}, json_data_callback=None):
""" Sends the specifed json data using PUT
Args:
destination (str): The remote server to send the HTTP request
to.
path (str): The HTTP path.
data (dict): A dict containing the data that will be used as
the request body. This will be encoded as JSON.
json_data_callback (callable): A callable returning the dict to
use as the request body.
Returns:
Deferred: Succeeds when we get a 2xx HTTP response. The result
will be the decoded JSON body. On a 4xx or 5xx error response a
CodeMessageException is raised.
"""
if not json_data_callback:
def json_data_callback():
return data
def body_callback(method, url_bytes, headers_dict):
json_data = json_data_callback()
self.sign_request(
destination, method, url_bytes, headers_dict, json_data
)
producer = _JsonProducer(json_data)
return producer
response = yield self._create_request(
destination.encode("ascii"),
"PUT",
path.encode("ascii"),
body_callback=body_callback,
headers_dict={"Content-Type": ["application/json"]},
)
logger.debug("Getting resp body")
body = yield readBody(response)
logger.debug("Got resp body")
defer.returnValue((response.code, body))
@defer.inlineCallbacks
def get_json(self, destination, path, args={}, retry_on_dns_fail=True):
""" Get's some json from the given host homeserver and path
Args:
destination (str): The remote server to send the HTTP request
to.
path (str): The HTTP path.
args (dict): A dictionary used to create query strings, defaults to
None.
**Note**: The value of each key is assumed to be an iterable
and *not* a string.
Returns:
Deferred: Succeeds when we get *any* HTTP response.
The result of the deferred is a tuple of `(code, response)`,
where `response` is a dict representing the decoded JSON body.
"""
logger.debug("get_json args: %s", args)
encoded_args = {}
for k, vs in args.items():
if isinstance(vs, basestring):
vs = [vs]
encoded_args[k] = [v.encode("UTF-8") for v in vs]
query_bytes = urllib.urlencode(encoded_args, True)
logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail)
def body_callback(method, url_bytes, headers_dict):
self.sign_request(destination, method, url_bytes, headers_dict)
return None
response = yield self._create_request(
destination.encode("ascii"),
"GET",
path.encode("ascii"),
query_bytes=query_bytes,
body_callback=body_callback,
retry_on_dns_fail=retry_on_dns_fail
)
body = yield readBody(response)
defer.returnValue(json.loads(body))
def _getEndpoint(self, reactor, destination):
return matrix_federation_endpoint(
reactor, destination, timeout=10,
ssl_context_factory=self.hs.tls_context_factory
)
def _print_ex(e):
if hasattr(e, "reasons") and e.reasons:
for ex in e.reasons:
_print_ex(ex)
else:
logger.exception(e)
class _JsonProducer(object):
""" Used by the twisted http client to create the HTTP body from json
"""
def __init__(self, jsn):
self.reset(jsn)
def reset(self, jsn):
self.body = encode_canonical_json(jsn)
self.length = len(self.body)
def startProducing(self, consumer):
consumer.write(self.body)
return defer.succeed(None)
def pauseProducing(self):
pass
def stopProducing(self):
pass

View File

@ -138,8 +138,7 @@ class JsonResource(HttpServer, resource.Resource):
)
except CodeMessageException as e:
if isinstance(e, SynapseError):
logger.error("%s SynapseError: %s - %s", request, e.code,
e.msg)
logger.info("%s SynapseError: %s - %s", request, e.code, e.msg)
else:
logger.exception(e)
self._send_response(

View File

@ -17,6 +17,7 @@ from twisted.internet import defer
from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.async import run_on_reactor
import logging
@ -96,6 +97,7 @@ class Notifier(object):
listening to the room, and any listeners for the users in the
`extra_users` param.
"""
yield run_on_reactor()
room_id = event.room_id
room_source = self.event_sources.sources["room"]
@ -143,6 +145,7 @@ class Notifier(object):
Will wake up all listeners for the given users and rooms.
"""
yield run_on_reactor()
presence_source = self.event_sources.sources["presence"]
listeners = set()
@ -211,6 +214,7 @@ class Notifier(object):
timeout,
deferred,
)
def _timeout_listener():
# TODO (erikj): We should probably set to_token to the current
# max rather than reusing from_token.

View File

@ -26,7 +26,6 @@ import logging
logger = logging.getLogger(__name__)
class EventStreamRestServlet(RestServlet):
PATTERN = client_path_pattern("/events$")

View File

@ -117,8 +117,6 @@ class PresenceListRestServlet(RestServlet):
logger.exception("JSON parse error")
raise SynapseError(400, "Unable to parse content")
deferreds = []
if "invite" in content:
for u in content["invite"]:
if not isinstance(u, basestring):
@ -126,8 +124,9 @@ class PresenceListRestServlet(RestServlet):
if len(u) == 0:
continue
invited_user = self.hs.parse_userid(u)
deferreds.append(self.handlers.presence_handler.send_invite(
observer_user=user, observed_user=invited_user))
yield self.handlers.presence_handler.send_invite(
observer_user=user, observed_user=invited_user
)
if "drop" in content:
for u in content["drop"]:
@ -136,10 +135,9 @@ class PresenceListRestServlet(RestServlet):
if len(u) == 0:
continue
dropped_user = self.hs.parse_userid(u)
deferreds.append(self.handlers.presence_handler.drop(
observer_user=user, observed_user=dropped_user))
yield defer.DeferredList(deferreds)
yield self.handlers.presence_handler.drop(
observer_user=user, observed_user=dropped_user
)
defer.returnValue((200, {}))

View File

@ -222,6 +222,7 @@ class RegisterRestServlet(RestServlet):
threepidCreds = register_json['threepidCreds']
handler = self.handlers.registration_handler
logger.debug("Registering email. threepidcreds: %s" % (threepidCreds))
yield handler.register_email(threepidCreds)
session["threepidCreds"] = threepidCreds # store creds for next stage
session[LoginType.EMAIL_IDENTITY] = True # mark email as done
@ -232,6 +233,7 @@ class RegisterRestServlet(RestServlet):
@defer.inlineCallbacks
def _do_password(self, request, register_json, session):
yield
if (self.hs.config.enable_registration_captcha and
not session[LoginType.RECAPTCHA]):
# captcha should've been done by this stage!
@ -259,6 +261,9 @@ class RegisterRestServlet(RestServlet):
)
if session[LoginType.EMAIL_IDENTITY]:
logger.debug("Binding emails %s to %s" % (
session["threepidCreds"], user_id)
)
yield handler.bind_emails(user_id, session["threepidCreds"])
result = {

View File

@ -148,7 +148,7 @@ class RoomStateEventRestServlet(RestServlet):
content = _parse_json(request)
event = self.event_factory.create_event(
etype=urllib.unquote(event_type),
etype=event_type, # already urldecoded
content=content,
room_id=urllib.unquote(room_id),
user_id=user.to_string(),

View File

@ -82,7 +82,7 @@ class StateHandler(object):
if hasattr(event, "outlier") and event.outlier:
event.state_group = None
event.old_state_events = None
event.state_events = {}
event.state_events = None
defer.returnValue(False)
return

View File

@ -67,7 +67,7 @@ SCHEMAS = [
# Remember to update this number every time an incompatible change is made to
# database schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 7
SCHEMA_VERSION = 8
class _RollbackButIsFineException(Exception):
@ -93,7 +93,8 @@ class DataStore(RoomMemberStore, RoomStore,
@defer.inlineCallbacks
@log_function
def persist_event(self, event, backfilled=False, is_new_state=True):
def persist_event(self, event, backfilled=False, is_new_state=True,
current_state=None):
stream_ordering = None
if backfilled:
if not self.min_token_deferred.called:
@ -109,6 +110,7 @@ class DataStore(RoomMemberStore, RoomStore,
backfilled=backfilled,
stream_ordering=stream_ordering,
is_new_state=is_new_state,
current_state=current_state,
)
except _RollbackButIsFineException:
pass
@ -137,7 +139,7 @@ class DataStore(RoomMemberStore, RoomStore,
@log_function
def _persist_event_txn(self, txn, event, backfilled, stream_ordering=None,
is_new_state=True):
is_new_state=True, current_state=None):
if event.type == RoomMemberEvent.TYPE:
self._store_room_member_txn(txn, event)
elif event.type == FeedbackEvent.TYPE:
@ -206,8 +208,24 @@ class DataStore(RoomMemberStore, RoomStore,
self._store_state_groups_txn(txn, event)
if current_state:
txn.execute("DELETE FROM current_state_events")
for s in current_state:
self._simple_insert_txn(
txn,
"current_state_events",
{
"event_id": s.event_id,
"room_id": s.room_id,
"type": s.type,
"state_key": s.state_key,
},
or_replace=True,
)
is_state = hasattr(event, "state_key") and event.state_key is not None
if is_new_state and is_state:
if is_state:
vals = {
"event_id": event.event_id,
"room_id": event.room_id,
@ -225,6 +243,7 @@ class DataStore(RoomMemberStore, RoomStore,
or_replace=True,
)
if is_new_state:
self._simple_insert_txn(
txn,
"current_state_events",
@ -312,7 +331,12 @@ class DataStore(RoomMemberStore, RoomStore,
txn, event.event_id, ref_alg, ref_hash_bytes
)
self._update_min_depth_for_room_txn(txn, event.room_id, event.depth)
if not outlier:
self._update_min_depth_for_room_txn(
txn,
event.room_id,
event.depth
)
def _store_redaction(self, txn, event):
txn.execute(
@ -508,7 +532,7 @@ def prepare_database(db_conn):
"new for the server to understand"
)
elif user_version < SCHEMA_VERSION:
logging.info(
logger.info(
"Upgrading database from version %d",
user_version
)

View File

@ -57,7 +57,7 @@ class LoggingTransaction(object):
if args and args[0]:
values = args[0]
sql_logger.debug(
"[SQL values] {%s} " + ", ".join(("<%s>",) * len(values)),
"[SQL values] {%s} " + ", ".join(("<%r>",) * len(values)),
self.name,
*values
)
@ -91,6 +91,7 @@ class SQLBaseStore(object):
def runInteraction(self, desc, func, *args, **kwargs):
"""Wraps the .runInteraction() method on the underlying db_pool."""
current_context = LoggingContext.current_context()
def inner_func(txn, *args, **kwargs):
with LoggingContext("runInteraction") as context:
current_context.copy_to(context)
@ -115,7 +116,6 @@ class SQLBaseStore(object):
"[TXN END] {%s} %f",
name, end - start
)
with PreserveLoggingContext():
result = yield self._db_pool.runInteraction(
inner_func, *args, **kwargs
@ -246,7 +246,10 @@ class SQLBaseStore(object):
raise StoreError(404, "No row found")
def _simple_select_onecol_txn(self, txn, table, keyvalues, retcol):
sql = "SELECT %(retcol)s FROM %(table)s WHERE %(where)s" % {
sql = (
"SELECT %(retcol)s FROM %(table)s WHERE %(where)s "
"ORDER BY rowid asc"
) % {
"retcol": retcol,
"table": table,
"where": " AND ".join("%s = ?" % k for k in keyvalues.keys()),
@ -299,7 +302,7 @@ class SQLBaseStore(object):
keyvalues : dict of column names and values to select the rows with
retcols : list of strings giving the names of the columns to return
"""
sql = "SELECT %s FROM %s WHERE %s" % (
sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % (
", ".join(retcols),
table,
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
@ -334,7 +337,7 @@ class SQLBaseStore(object):
retcols=None, allow_none=False):
""" Combined SELECT then UPDATE."""
if retcols:
select_sql = "SELECT %s FROM %s WHERE %s" % (
select_sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % (
", ".join(retcols),
table,
" AND ".join("%s = ?" % (k) for k in keyvalues)
@ -461,7 +464,7 @@ class SQLBaseStore(object):
def _get_events_txn(self, txn, event_ids):
# FIXME (erikj): This should be batched?
sql = "SELECT * FROM events WHERE event_id = ?"
sql = "SELECT * FROM events WHERE event_id = ? ORDER BY rowid asc"
event_rows = []
for e_id in event_ids:
@ -478,7 +481,9 @@ class SQLBaseStore(object):
def _parse_events_txn(self, txn, rows):
events = [self._parse_event_from_row(r) for r in rows]
select_event_sql = "SELECT * FROM events WHERE event_id = ?"
select_event_sql = (
"SELECT * FROM events WHERE event_id = ? ORDER BY rowid asc"
)
for i, ev in enumerate(events):
signatures = self._get_event_signatures_txn(

View File

@ -75,7 +75,9 @@ class RegistrationStore(SQLBaseStore):
"VALUES (?,?,?)",
[user_id, password_hash, now])
except IntegrityError:
raise StoreError(400, "User ID already taken.", errcode=Codes.USER_IN_USE)
raise StoreError(
400, "User ID already taken.", errcode=Codes.USER_IN_USE
)
# it's possible for this to get a conflict, but only for a single user
# since tokens are namespaced based on their user ID

View File

@ -27,7 +27,9 @@ import logging
logger = logging.getLogger(__name__)
OpsLevel = collections.namedtuple("OpsLevel", ("ban_level", "kick_level", "redact_level"))
OpsLevel = collections.namedtuple("OpsLevel", (
"ban_level", "kick_level", "redact_level")
)
class RoomStore(SQLBaseStore):

View File

@ -177,8 +177,8 @@ class RoomMemberStore(SQLBaseStore):
return self._get_members_query(clause, vals)
def _get_members_query(self, where_clause, where_values):
return self._db_pool.runInteraction(
self._get_members_query_txn,
return self.runInteraction(
"get_members_query", self._get_members_query_txn,
where_clause, where_values
)

View File

@ -0,0 +1,34 @@
/* Copyright 2014 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS event_signatures_2 (
event_id TEXT,
signature_name TEXT,
key_id TEXT,
signature BLOB,
CONSTRAINT uniqueness UNIQUE (event_id, signature_name, key_id)
);
INSERT INTO event_signatures_2 (event_id, signature_name, key_id, signature)
SELECT event_id, signature_name, key_id, signature FROM event_signatures;
DROP TABLE event_signatures;
ALTER TABLE event_signatures_2 RENAME TO event_signatures;
CREATE INDEX IF NOT EXISTS event_signatures_id ON event_signatures (
event_id
);
PRAGMA user_version = 8;

View File

@ -42,7 +42,7 @@ CREATE TABLE IF NOT EXISTS event_signatures (
signature_name TEXT,
key_id TEXT,
signature BLOB,
CONSTRAINT uniqueness UNIQUE (event_id, key_id)
CONSTRAINT uniqueness UNIQUE (event_id, signature_name, key_id)
);
CREATE INDEX IF NOT EXISTS event_signatures_id ON event_signatures (

View File

@ -87,7 +87,7 @@ class StateStore(SQLBaseStore):
)
def _store_state_groups_txn(self, txn, event):
if not event.state_events:
if event.state_events is None:
return
state_group = event.state_group

View File

@ -237,7 +237,8 @@ class StreamStore(SQLBaseStore):
sql = (
"SELECT *, (%(redacted)s) AS redacted FROM events"
" WHERE outlier = 0 AND room_id = ? AND %(bounds)s"
"ORDER BY topological_ordering %(order)s, stream_ordering %(order)s %(limit)s "
" ORDER BY topological_ordering %(order)s,"
" stream_ordering %(order)s %(limit)s"
) % {
"redacted": del_sql,
"bounds": bounds,

View File

@ -28,11 +28,11 @@ class SourcePaginationConfig(object):
specific event source."""
def __init__(self, from_key=None, to_key=None, direction='f',
limit=0):
limit=None):
self.from_key = from_key
self.to_key = to_key
self.direction = 'f' if direction == 'f' else 'b'
self.limit = int(limit)
self.limit = int(limit) if limit is not None else None
class PaginationConfig(object):
@ -40,11 +40,11 @@ class PaginationConfig(object):
"""A configuration object which stores pagination parameters."""
def __init__(self, from_token=None, to_token=None, direction='f',
limit=0):
limit=None):
self.from_token = from_token
self.to_token = to_token
self.direction = 'f' if direction == 'f' else 'b'
self.limit = int(limit)
self.limit = int(limit) if limit is not None else None
@classmethod
def from_request(cls, request, raise_invalid_params=True):
@ -80,8 +80,8 @@ class PaginationConfig(object):
except:
raise SynapseError(400, "'to' paramater is invalid")
limit = get_param("limit", "0")
if not limit.isdigit():
limit = get_param("limit", None)
if limit is not None and not limit.isdigit():
raise SynapseError(400, "'limit' parameter must be an integer.")
try:

View File

@ -37,6 +37,7 @@ class Clock(object):
def call_later(self, delay, callback):
current_context = LoggingContext.current_context()
def wrapped_callback():
LoggingContext.thread_local.current_context = current_context
callback()

View File

@ -18,6 +18,7 @@ from twisted.internet import defer, reactor
from .logcontext import PreserveLoggingContext
@defer.inlineCallbacks
def sleep(seconds):
d = defer.Deferred()
@ -25,6 +26,7 @@ def sleep(seconds):
with PreserveLoggingContext():
yield d
def run_on_reactor():
""" This will cause the rest of the function to be invoked upon the next
iteration of the main loop

View File

@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.util.logcontext import PreserveLoggingContext
from twisted.internet import defer
import logging
@ -91,6 +93,7 @@ class Signal(object):
Each observer callable may return a Deferred."""
self.observers.append(observer)
@defer.inlineCallbacks
def fire(self, *args, **kwargs):
"""Invokes every callable in the observer list, passing in the args and
kwargs. Exceptions thrown by observers are logged but ignored. It is
@ -98,6 +101,7 @@ class Signal(object):
Returns a Deferred that will complete when all the observers have
completed."""
with PreserveLoggingContext():
deferreds = []
for observer in self.observers:
d = defer.maybeDeferred(observer, *args, **kwargs)
@ -114,6 +118,7 @@ class Signal(object):
raise failure
deferreds.append(d.addErrback(eb))
return defer.DeferredList(
result = yield defer.DeferredList(
deferreds, fireOnOneErrback=not self.suppress_failures
)
defer.returnValue(result)

View File

@ -1,6 +1,8 @@
import threading
import logging
logger = logging.getLogger(__name__)
class LoggingContext(object):
"""Additional context for log formatting. Contexts are scoped within a
@ -53,8 +55,11 @@ class LoggingContext(object):
None to avoid suppressing any exeptions that were thrown.
"""
if self.thread_local.current_context is not self:
logging.error(
"Current logging context %s is not the expected context %s",
if self.thread_local.current_context is self.sentinel:
logger.debug("Expected logging context %s has been lost", self)
else:
logger.warn(
"Current logging context %s is not expected context %s",
self.thread_local.current_context,
self
)

View File

@ -83,20 +83,22 @@ class FederationTestCase(unittest.TestCase):
event_id="$a:b",
user_id="@a:b",
origin="b",
auth_events=[],
hashes={"sha256":"AcLrgtUIqqwaGoHhrEvYG1YLDIsVPYJdSRGhkp3jJp8"},
)
self.datastore.persist_event.return_value = defer.succeed(None)
self.datastore.get_room.return_value = defer.succeed(True)
self.state_handler.annotate_event_with_state.return_value = (
defer.succeed(False)
)
def annotate(ev, old_state=None):
ev.old_state_events = []
return defer.succeed(False)
self.state_handler.annotate_event_with_state.side_effect = annotate
yield self.handlers.federation_handler.on_receive_pdu(pdu, False)
self.datastore.persist_event.assert_called_once_with(
ANY, False, is_new_state=False
ANY, is_new_state=False, backfilled=False, current_state=None
)
self.state_handler.annotate_event_with_state.assert_called_once_with(
@ -104,7 +106,7 @@ class FederationTestCase(unittest.TestCase):
old_state=None,
)
self.auth.check.assert_called_once_with(ANY, raises=True)
self.auth.check.assert_called_once_with(ANY, auth_events={})
self.notifier.on_new_room_event.assert_called_once_with(
ANY,

View File

@ -120,7 +120,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
self.datastore.get_room_member.return_value = defer.succeed(None)
event.state_events = {
event.old_state_events = {
(RoomMemberEvent.TYPE, "@alice:green"): self._create_member(
user_id="@alice:green",
room_id=room_id,
@ -129,9 +129,11 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
user_id="@bob:red",
room_id=room_id,
),
(RoomMemberEvent.TYPE, target_user_id): event,
}
event.state_events = event.old_state_events
event.state_events[(RoomMemberEvent.TYPE, target_user_id)] = event
# Actual invocation
yield self.room_member_handler.change_membership(event)
@ -187,6 +189,16 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
(RoomMemberEvent.TYPE, user_id): event,
}
event.old_state_events = {
(RoomMemberEvent.TYPE, "@alice:green"): self._create_member(
user_id="@alice:green",
room_id=room_id,
),
}
event.state_events = event.old_state_events
event.state_events[(RoomMemberEvent.TYPE, user_id)] = event
# Actual invocation
yield self.room_member_handler.change_membership(event)

View File

@ -84,7 +84,8 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.assertEquals("Value", value)
self.mock_txn.execute.assert_called_with(
"SELECT retcol FROM tablename WHERE keycol = ?",
"SELECT retcol FROM tablename WHERE keycol = ? "
"ORDER BY rowid asc",
["TheKey"]
)
@ -101,7 +102,8 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.assertEquals({"colA": 1, "colB": 2, "colC": 3}, ret)
self.mock_txn.execute.assert_called_with(
"SELECT colA, colB, colC FROM tablename WHERE keycol = ?",
"SELECT colA, colB, colC FROM tablename WHERE keycol = ? "
"ORDER BY rowid asc",
["TheKey"]
)
@ -135,7 +137,8 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.assertEquals([{"colA": 1}, {"colA": 2}, {"colA": 3}], ret)
self.mock_txn.execute.assert_called_with(
"SELECT colA FROM tablename WHERE keycol = ?",
"SELECT colA FROM tablename WHERE keycol = ? "
"ORDER BY rowid asc",
["A set"]
)
@ -184,7 +187,8 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.assertEquals({"columname": "Old Value"}, ret)
self.mock_txn.execute.assert_has_calls([
call('SELECT columname FROM tablename WHERE keycol = ?',
call('SELECT columname FROM tablename WHERE keycol = ? '
'ORDER BY rowid asc',
['TheKey']),
call("UPDATE tablename SET columname = ? WHERE keycol = ?",
["New Value", "TheKey"])