Merge branch 'develop' of github.com:matrix-org/synapse into matrix-org-hotfixes

Erik Johnston 2017-01-20 14:00:31 +00:00
commit e41f183aa8
47 changed files with 1458 additions and 1002 deletions

README.rst

@@ -138,6 +138,7 @@ Installing prerequisites on openSUSE::
     python-devel libffi-devel libopenssl-devel libjpeg62-devel

 Installing prerequisites on OpenBSD::

     doas pkg_add python libffi py-pip py-setuptools sqlite3 py-virtualenv \
         libxslt

synapse/api/auth.py

@@ -16,18 +16,14 @@
 import logging

 import pymacaroons
-from canonicaljson import encode_canonical_json
-from signedjson.key import decode_verify_key_bytes
-from signedjson.sign import verify_signed_json, SignatureVerifyException
 from twisted.internet import defer
-from unpaddedbase64 import decode_base64

 import synapse.types
+from synapse import event_auth
 from synapse.api.constants import EventTypes, Membership, JoinRules
-from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
+from synapse.api.errors import AuthError, Codes
-from synapse.types import UserID, get_domain_from_id
+from synapse.types import UserID
 from synapse.util.logcontext import preserve_context_over_fn
-from synapse.util.logutils import log_function
 from synapse.util.metrics import Measure

 logger = logging.getLogger(__name__)
@@ -78,147 +74,7 @@ class Auth(object):
             True if the auth checks pass.
         """
         with Measure(self.clock, "auth.check"):
-            self.check_size_limits(event)
-
-            if not hasattr(event, "room_id"):
-                raise AuthError(500, "Event has no room_id: %s" % event)
-
-            if do_sig_check:
-                sender_domain = get_domain_from_id(event.sender)
-                event_id_domain = get_domain_from_id(event.event_id)
-
-                is_invite_via_3pid = (
-                    event.type == EventTypes.Member
-                    and event.membership == Membership.INVITE
-                    and "third_party_invite" in event.content
-                )
-
-                # Check the sender's domain has signed the event
-                if not event.signatures.get(sender_domain):
-                    # We allow invites via 3pid to have a sender from a different
-                    # HS, as the sender must match the sender of the original
-                    # 3pid invite. This is checked further down with the
-                    # other dedicated membership checks.
-                    if not is_invite_via_3pid:
-                        raise AuthError(403, "Event not signed by sender's server")
-
-                # Check the event_id's domain has signed the event
-                if not event.signatures.get(event_id_domain):
-                    raise AuthError(403, "Event not signed by sending server")
-
-            if auth_events is None:
-                # Oh, we don't know what the state of the room was, so we
-                # are trusting that this is allowed (at least for now)
-                logger.warn("Trusting event: %s", event.event_id)
-                return True
-
-            if event.type == EventTypes.Create:
-                room_id_domain = get_domain_from_id(event.room_id)
-                if room_id_domain != sender_domain:
-                    raise AuthError(
-                        403,
-                        "Creation event's room_id domain does not match sender's"
-                    )
-                # FIXME
-                return True
-
-            creation_event = auth_events.get((EventTypes.Create, ""), None)
-
-            if not creation_event:
-                raise SynapseError(
-                    403,
-                    "Room %r does not exist" % (event.room_id,)
-                )
-
-            creating_domain = get_domain_from_id(event.room_id)
-            originating_domain = get_domain_from_id(event.sender)
-            if creating_domain != originating_domain:
-                if not self.can_federate(event, auth_events):
-                    raise AuthError(
-                        403,
-                        "This room has been marked as unfederatable."
-                    )
-
-            # FIXME: Temp hack
-            if event.type == EventTypes.Aliases:
-                if not event.is_state():
-                    raise AuthError(
-                        403,
-                        "Alias event must be a state event",
-                    )
-                if not event.state_key:
-                    raise AuthError(
-                        403,
-                        "Alias event must have non-empty state_key"
-                    )
-                sender_domain = get_domain_from_id(event.sender)
-                if event.state_key != sender_domain:
-                    raise AuthError(
-                        403,
-                        "Alias event's state_key does not match sender's domain"
-                    )
-                return True
-
-            logger.debug(
-                "Auth events: %s",
-                [a.event_id for a in auth_events.values()]
-            )
-
-            if event.type == EventTypes.Member:
-                allowed = self.is_membership_change_allowed(
-                    event, auth_events
-                )
-                if allowed:
-                    logger.debug("Allowing! %s", event)
-                else:
-                    logger.debug("Denying! %s", event)
-                return allowed
-
-            self.check_event_sender_in_room(event, auth_events)
-
-            # Special case to allow m.room.third_party_invite events wherever
-            # a user is allowed to issue invites. Fixes
-            # https://github.com/vector-im/vector-web/issues/1208 hopefully
-            if event.type == EventTypes.ThirdPartyInvite:
-                user_level = self._get_user_power_level(event.user_id, auth_events)
-                invite_level = self._get_named_level(auth_events, "invite", 0)
-
-                if user_level < invite_level:
-                    raise AuthError(
-                        403, (
-                            "You cannot issue a third party invite for %s." %
-                            (event.content.display_name,)
-                        )
-                    )
-                else:
-                    return True
-
-            self._can_send_event(event, auth_events)
-
-            if event.type == EventTypes.PowerLevels:
-                self._check_power_levels(event, auth_events)
-
-            if event.type == EventTypes.Redaction:
-                self.check_redaction(event, auth_events)
-
-            logger.debug("Allowing! %s", event)
-
-    def check_size_limits(self, event):
-        def too_big(field):
-            raise EventSizeError("%s too large" % (field,))
-
-        if len(event.user_id) > 255:
-            too_big("user_id")
-        if len(event.room_id) > 255:
-            too_big("room_id")
-        if event.is_state() and len(event.state_key) > 255:
-            too_big("state_key")
-        if len(event.type) > 255:
-            too_big("type")
-        if len(event.event_id) > 255:
-            too_big("event_id")
-        if len(encode_canonical_json(event.get_pdu_json())) > 65536:
-            too_big("event")
+            event_auth.check(event, auth_events, do_sig_check=do_sig_check)

     @defer.inlineCallbacks
     def check_joined_room(self, room_id, user_id, current_state=None):
@@ -290,7 +146,7 @@ class Auth(object):
         with Measure(self.clock, "check_host_in_room"):
             latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)

-            logger.info("calling resolve_state_groups from check_host_in_room")
+            logger.debug("calling resolve_state_groups from check_host_in_room")
             entry = yield self.state.resolve_state_groups(
                 room_id, latest_event_ids
             )
@@ -300,16 +156,6 @@ class Auth(object):
             )
             defer.returnValue(ret)

-    def check_event_sender_in_room(self, event, auth_events):
-        key = (EventTypes.Member, event.user_id, )
-        member_event = auth_events.get(key)
-
-        return self._check_joined_room(
-            member_event,
-            event.user_id,
-            event.room_id
-        )
-
     def _check_joined_room(self, member, user_id, room_id):
         if not member or member.membership != Membership.JOIN:
             raise AuthError(403, "User %s not in room %s (%s)" % (
@@ -321,267 +167,8 @@ class Auth(object):
         return creation_event.content.get("m.federate", True) is True

-    @log_function
-    def is_membership_change_allowed(self, event, auth_events):
-        membership = event.content["membership"]
-
-        # Check if this is the room creator joining:
-        if len(event.prev_events) == 1 and Membership.JOIN == membership:
-            # Get room creation event:
-            key = (EventTypes.Create, "", )
-            create = auth_events.get(key)
-            if create and event.prev_events[0][0] == create.event_id:
-                if create.content["creator"] == event.state_key:
-                    return True
-
-        target_user_id = event.state_key
-
-        creating_domain = get_domain_from_id(event.room_id)
-        target_domain = get_domain_from_id(target_user_id)
-        if creating_domain != target_domain:
-            if not self.can_federate(event, auth_events):
-                raise AuthError(
-                    403,
-                    "This room has been marked as unfederatable."
-                )
-
-        # get info about the caller
-        key = (EventTypes.Member, event.user_id, )
-        caller = auth_events.get(key)
-
-        caller_in_room = caller and caller.membership == Membership.JOIN
-        caller_invited = caller and caller.membership == Membership.INVITE
-
-        # get info about the target
-        key = (EventTypes.Member, target_user_id, )
-        target = auth_events.get(key)
-
-        target_in_room = target and target.membership == Membership.JOIN
-        target_banned = target and target.membership == Membership.BAN
-
-        key = (EventTypes.JoinRules, "", )
-        join_rule_event = auth_events.get(key)
-        if join_rule_event:
-            join_rule = join_rule_event.content.get(
-                "join_rule", JoinRules.INVITE
-            )
-        else:
-            join_rule = JoinRules.INVITE
-
-        user_level = self._get_user_power_level(event.user_id, auth_events)
-        target_level = self._get_user_power_level(
-            target_user_id, auth_events
-        )
-
-        # FIXME (erikj): What should we do here as the default?
-        ban_level = self._get_named_level(auth_events, "ban", 50)
-
-        logger.debug(
-            "is_membership_change_allowed: %s",
-            {
-                "caller_in_room": caller_in_room,
-                "caller_invited": caller_invited,
-                "target_banned": target_banned,
-                "target_in_room": target_in_room,
-                "membership": membership,
-                "join_rule": join_rule,
-                "target_user_id": target_user_id,
-                "event.user_id": event.user_id,
-            }
-        )
-
-        if Membership.INVITE == membership and "third_party_invite" in event.content:
-            if not self._verify_third_party_invite(event, auth_events):
-                raise AuthError(403, "You are not invited to this room.")
-            if target_banned:
-                raise AuthError(
-                    403, "%s is banned from the room" % (target_user_id,)
-                )
-            return True
-
-        if Membership.JOIN != membership:
-            if (caller_invited
-                    and Membership.LEAVE == membership
-                    and target_user_id == event.user_id):
-                return True
-
-            if not caller_in_room:  # caller isn't joined
-                raise AuthError(
-                    403,
-                    "%s not in room %s." % (event.user_id, event.room_id,)
-                )
-
-        if Membership.INVITE == membership:
-            # TODO (erikj): We should probably handle this more intelligently
-            # PRIVATE join rules.
-
-            # Invites are valid iff caller is in the room and target isn't.
-            if target_banned:
-                raise AuthError(
-                    403, "%s is banned from the room" % (target_user_id,)
-                )
-            elif target_in_room:  # the target is already in the room.
-                raise AuthError(403, "%s is already in the room." %
-                                target_user_id)
-            else:
-                invite_level = self._get_named_level(auth_events, "invite", 0)
-
-                if user_level < invite_level:
-                    raise AuthError(
-                        403, "You cannot invite user %s." % target_user_id
-                    )
-        elif Membership.JOIN == membership:
-            # Joins are valid iff caller == target and they were:
-            # invited: They are accepting the invitation
-            # joined: It's a NOOP
-            if event.user_id != target_user_id:
-                raise AuthError(403, "Cannot force another user to join.")
-            elif target_banned:
-                raise AuthError(403, "You are banned from this room")
-            elif join_rule == JoinRules.PUBLIC:
-                pass
-            elif join_rule == JoinRules.INVITE:
-                if not caller_in_room and not caller_invited:
-                    raise AuthError(403, "You are not invited to this room.")
-            else:
-                # TODO (erikj): may_join list
-                # TODO (erikj): private rooms
-                raise AuthError(403, "You are not allowed to join this room")
-        elif Membership.LEAVE == membership:
-            # TODO (erikj): Implement kicks.
-            if target_banned and user_level < ban_level:
-                raise AuthError(
-                    403, "You cannot unban user %s." % (target_user_id,)
-                )
-            elif target_user_id != event.user_id:
-                kick_level = self._get_named_level(auth_events, "kick", 50)
-
-                if user_level < kick_level or user_level <= target_level:
-                    raise AuthError(
-                        403, "You cannot kick user %s." % target_user_id
-                    )
-        elif Membership.BAN == membership:
-            if user_level < ban_level or user_level <= target_level:
-                raise AuthError(403, "You don't have permission to ban")
-        else:
-            raise AuthError(500, "Unknown membership %s" % membership)
-
-        return True
-
-    def _verify_third_party_invite(self, event, auth_events):
-        """
-        Validates that the invite event is authorized by a previous third-party invite.
-
-        Checks that the public key, and keyserver, match those in the third party invite,
-        and that the invite event has a signature issued using that public key.
-
-        Args:
-            event: The m.room.member join event being validated.
-            auth_events: All relevant previous context events which may be used
-                for authorization decisions.
-
-        Return:
-            True if the event fulfills the expectations of a previous third party
-            invite event.
-        """
-        if "third_party_invite" not in event.content:
-            return False
-        if "signed" not in event.content["third_party_invite"]:
-            return False
-        signed = event.content["third_party_invite"]["signed"]
-        for key in {"mxid", "token"}:
-            if key not in signed:
-                return False
-
-        token = signed["token"]
-
-        invite_event = auth_events.get(
-            (EventTypes.ThirdPartyInvite, token,)
-        )
-        if not invite_event:
-            return False
-
-        if invite_event.sender != event.sender:
-            return False
-
-        if event.user_id != invite_event.user_id:
-            return False
-
-        if signed["mxid"] != event.state_key:
-            return False
-        if signed["token"] != token:
-            return False
-
-        for public_key_object in self.get_public_keys(invite_event):
-            public_key = public_key_object["public_key"]
-            try:
-                for server, signature_block in signed["signatures"].items():
-                    for key_name, encoded_signature in signature_block.items():
-                        if not key_name.startswith("ed25519:"):
-                            continue
-                        verify_key = decode_verify_key_bytes(
-                            key_name,
-                            decode_base64(public_key)
-                        )
-                        verify_signed_json(signed, server, verify_key)
-
-                        # We got the public key from the invite, so we know that the
-                        # correct server signed the signed bundle.
-                        # The caller is responsible for checking that the signing
-                        # server has not revoked that public key.
-                        return True
-            except (KeyError, SignatureVerifyException,):
-                continue
-
-        return False
-
     def get_public_keys(self, invite_event):
-        public_keys = []
-        if "public_key" in invite_event.content:
-            o = {
-                "public_key": invite_event.content["public_key"],
-            }
-            if "key_validity_url" in invite_event.content:
-                o["key_validity_url"] = invite_event.content["key_validity_url"]
-            public_keys.append(o)
-        public_keys.extend(invite_event.content.get("public_keys", []))
-        return public_keys
+        return event_auth.get_public_keys(invite_event)
-
-    def _get_power_level_event(self, auth_events):
-        key = (EventTypes.PowerLevels, "", )
-        return auth_events.get(key)
-
-    def _get_user_power_level(self, user_id, auth_events):
-        power_level_event = self._get_power_level_event(auth_events)
-        if power_level_event:
-            level = power_level_event.content.get("users", {}).get(user_id)
-            if not level:
-                level = power_level_event.content.get("users_default", 0)
-
-            if level is None:
-                return 0
-            else:
-                return int(level)
-        else:
-            key = (EventTypes.Create, "", )
-            create_event = auth_events.get(key)
-            if (create_event is not None and
-                    create_event.content["creator"] == user_id):
-                return 100
-            else:
-                return 0
-
-    def _get_named_level(self, auth_events, name, default):
-        power_level_event = self._get_power_level_event(auth_events)
-
-        if not power_level_event:
-            return default
-
-        level = power_level_event.content.get(name, None)
-        if level is not None:
-            return int(level)
-        else:
-            return default

     @defer.inlineCallbacks
     def get_user_by_req(self, request, allow_guest=False, rights="access"):
@@ -974,56 +561,6 @@ class Auth(object):
         defer.returnValue(auth_ids)

-    def _get_send_level(self, etype, state_key, auth_events):
-        key = (EventTypes.PowerLevels, "", )
-        send_level_event = auth_events.get(key)
-        send_level = None
-        if send_level_event:
-            send_level = send_level_event.content.get("events", {}).get(
-                etype
-            )
-            if send_level is None:
-                if state_key is not None:
-                    send_level = send_level_event.content.get(
-                        "state_default", 50
-                    )
-                else:
-                    send_level = send_level_event.content.get(
-                        "events_default", 0
-                    )
-
-        if send_level:
-            send_level = int(send_level)
-        else:
-            send_level = 0
-
-        return send_level
-
-    @log_function
-    def _can_send_event(self, event, auth_events):
-        send_level = self._get_send_level(
-            event.type, event.get("state_key", None), auth_events
-        )
-        user_level = self._get_user_power_level(event.user_id, auth_events)
-
-        if user_level < send_level:
-            raise AuthError(
-                403,
-                "You don't have permission to post that to the room. " +
-                "user_level (%d) < send_level (%d)" % (user_level, send_level)
-            )
-
-        # Check state_key
-        if hasattr(event, "state_key"):
-            if event.state_key.startswith("@"):
-                if event.state_key != event.user_id:
-                    raise AuthError(
-                        403,
-                        "You are not allowed to set others state"
-                    )
-
-        return True
-
     def check_redaction(self, event, auth_events):
         """Check whether the event sender is allowed to redact the target event.
@@ -1037,107 +574,7 @@ class Auth(object):
            AuthError if the event sender is definitely not allowed to redact
            the target event.
         """
-        user_level = self._get_user_power_level(event.user_id, auth_events)
-
-        redact_level = self._get_named_level(auth_events, "redact", 50)
-
-        if user_level >= redact_level:
-            return False
-
-        redacter_domain = get_domain_from_id(event.event_id)
-        redactee_domain = get_domain_from_id(event.redacts)
-        if redacter_domain == redactee_domain:
-            return True
-
-        raise AuthError(
-            403,
-            "You don't have permission to redact events"
-        )
+        return event_auth.check_redaction(event, auth_events)
-
-    def _check_power_levels(self, event, auth_events):
-        user_list = event.content.get("users", {})
-        # Validate users
-        for k, v in user_list.items():
-            try:
-                UserID.from_string(k)
-            except:
-                raise SynapseError(400, "Not a valid user_id: %s" % (k,))
-
-            try:
-                int(v)
-            except:
-                raise SynapseError(400, "Not a valid power level: %s" % (v,))
-
-        key = (event.type, event.state_key, )
-        current_state = auth_events.get(key)
-
-        if not current_state:
-            return
-
-        user_level = self._get_user_power_level(event.user_id, auth_events)
-
-        # Check other levels:
-        levels_to_check = [
-            ("users_default", None),
-            ("events_default", None),
-            ("state_default", None),
-            ("ban", None),
-            ("redact", None),
-            ("kick", None),
-            ("invite", None),
-        ]
-
-        old_list = current_state.content.get("users")
-        for user in set(old_list.keys() + user_list.keys()):
-            levels_to_check.append(
-                (user, "users")
-            )
-
-        old_list = current_state.content.get("events")
-        new_list = event.content.get("events")
-        for ev_id in set(old_list.keys() + new_list.keys()):
-            levels_to_check.append(
-                (ev_id, "events")
-            )
-
-        old_state = current_state.content
-        new_state = event.content
-
-        for level_to_check, dir in levels_to_check:
-            old_loc = old_state
-            new_loc = new_state
-            if dir:
-                old_loc = old_loc.get(dir, {})
-                new_loc = new_loc.get(dir, {})
-
-            if level_to_check in old_loc:
-                old_level = int(old_loc[level_to_check])
-            else:
-                old_level = None
-
-            if level_to_check in new_loc:
-                new_level = int(new_loc[level_to_check])
-            else:
-                new_level = None
-
-            if new_level is not None and old_level is not None:
-                if new_level == old_level:
-                    continue
-
-            if dir == "users" and level_to_check != event.user_id:
-                if old_level == user_level:
-                    raise AuthError(
-                        403,
-                        "You don't have permission to remove ops level equal "
-                        "to your own"
-                    )
-
-            if old_level > user_level or new_level > user_level:
-                raise AuthError(
-                    403,
-                    "You don't have permission to add ops level greater "
-                    "than your own"
-                )

     @defer.inlineCallbacks
     def check_can_change_room_list(self, room_id, user):
@@ -1167,10 +604,10 @@ class Auth(object):
         if power_level_event:
             auth_events[(EventTypes.PowerLevels, "")] = power_level_event

-        send_level = self._get_send_level(
+        send_level = event_auth.get_send_level(
             EventTypes.Aliases, "", auth_events
         )
-        user_level = self._get_user_power_level(user_id, auth_events)
+        user_level = event_auth.get_user_power_level(user_id, auth_events)

         if user_level < send_level:
             raise AuthError(
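
The net effect of the hunks above is that synapse/api/auth.py keeps its instance context (clock, store, state handler) but delegates the pure event-authorization logic to the module-level functions in the new synapse/event_auth.py (shown later in this diff). A condensed sketch of the resulting wrapper shape, abridged from the hunks above rather than quoted verbatim:

    from synapse import event_auth
    from synapse.util.metrics import Measure

    class Auth(object):
        # ... constructor and the HTTP-facing helpers are unchanged ...

        def check(self, event, auth_events, do_sig_check=True):
            # Timing stays here; the actual auth rules live in event_auth.
            with Measure(self.clock, "auth.check"):
                event_auth.check(event, auth_events, do_sig_check=do_sig_check)

        def get_public_keys(self, invite_event):
            return event_auth.get_public_keys(invite_event)

        def check_redaction(self, event, auth_events):
            return event_auth.check_redaction(event, auth_events)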

synapse/app/appservice.py

@@ -76,8 +76,7 @@ class AppserviceServer(HomeServer):
     def _listen_http(self, listener_config):
         port = listener_config["port"]
-        bind_address = listener_config.get("bind_address", None)
-        bind_addresses = listener_config.get("bind_addresses", [])
+        bind_addresses = listener_config["bind_addresses"]
         site_tag = listener_config.get("tag", port)
         resources = {}
         for res in listener_config["resources"]:
@@ -87,9 +86,6 @@ class AppserviceServer(HomeServer):

         root_resource = create_resource_tree(resources, Resource())

-        if bind_address is not None:
-            bind_addresses.append(bind_address)
-
         for address in bind_addresses:
             reactor.listenTCP(
                 port,
@@ -109,11 +105,7 @@ class AppserviceServer(HomeServer):
             if listener["type"] == "http":
                 self._listen_http(listener)
             elif listener["type"] == "manhole":
-                bind_address = listener.get("bind_address", None)
-                bind_addresses = listener.get("bind_addresses", [])
-
-                if bind_address is not None:
-                    bind_addresses.append(bind_address)
-
+                bind_addresses = listener["bind_addresses"]
+
                 for address in bind_addresses:
                     reactor.listenTCP(

synapse/app/client_reader.py

@@ -90,8 +90,7 @@ class ClientReaderServer(HomeServer):
     def _listen_http(self, listener_config):
         port = listener_config["port"]
-        bind_address = listener_config.get("bind_address", None)
-        bind_addresses = listener_config.get("bind_addresses", [])
+        bind_addresses = listener_config["bind_addresses"]
         site_tag = listener_config.get("tag", port)
         resources = {}
         for res in listener_config["resources"]:
@@ -110,9 +109,6 @@ class ClientReaderServer(HomeServer):

         root_resource = create_resource_tree(resources, Resource())

-        if bind_address is not None:
-            bind_addresses.append(bind_address)
-
         for address in bind_addresses:
             reactor.listenTCP(
                 port,
@@ -132,11 +128,7 @@ class ClientReaderServer(HomeServer):
             if listener["type"] == "http":
                 self._listen_http(listener)
             elif listener["type"] == "manhole":
-                bind_address = listener.get("bind_address", None)
-                bind_addresses = listener.get("bind_addresses", [])
-
-                if bind_address is not None:
-                    bind_addresses.append(bind_address)
-
+                bind_addresses = listener["bind_addresses"]
+
                 for address in bind_addresses:
                     reactor.listenTCP(

synapse/app/federation_reader.py

@@ -86,8 +86,7 @@ class FederationReaderServer(HomeServer):
     def _listen_http(self, listener_config):
         port = listener_config["port"]
-        bind_address = listener_config.get("bind_address", None)
-        bind_addresses = listener_config.get("bind_addresses", [])
+        bind_addresses = listener_config["bind_addresses"]
         site_tag = listener_config.get("tag", port)
         resources = {}
         for res in listener_config["resources"]:
@@ -101,9 +100,6 @@ class FederationReaderServer(HomeServer):

         root_resource = create_resource_tree(resources, Resource())

-        if bind_address is not None:
-            bind_addresses.append(bind_address)
-
         for address in bind_addresses:
             reactor.listenTCP(
                 port,
@@ -123,11 +119,7 @@ class FederationReaderServer(HomeServer):
             if listener["type"] == "http":
                 self._listen_http(listener)
             elif listener["type"] == "manhole":
-                bind_address = listener.get("bind_address", None)
-                bind_addresses = listener.get("bind_addresses", [])
-
-                if bind_address is not None:
-                    bind_addresses.append(bind_address)
-
+                bind_addresses = listener["bind_addresses"]
+
                 for address in bind_addresses:
                     reactor.listenTCP(

synapse/app/federation_sender.py

@@ -82,8 +82,7 @@ class FederationSenderServer(HomeServer):
     def _listen_http(self, listener_config):
         port = listener_config["port"]
-        bind_address = listener_config.get("bind_address", None)
-        bind_addresses = listener_config.get("bind_addresses", [])
+        bind_addresses = listener_config["bind_addresses"]
         site_tag = listener_config.get("tag", port)
         resources = {}
         for res in listener_config["resources"]:
@@ -93,9 +92,6 @@ class FederationSenderServer(HomeServer):

         root_resource = create_resource_tree(resources, Resource())

-        if bind_address is not None:
-            bind_addresses.append(bind_address)
-
         for address in bind_addresses:
             reactor.listenTCP(
                 port,
@@ -115,11 +111,7 @@ class FederationSenderServer(HomeServer):
             if listener["type"] == "http":
                 self._listen_http(listener)
             elif listener["type"] == "manhole":
-                bind_address = listener.get("bind_address", None)
-                bind_addresses = listener.get("bind_addresses", [])
-
-                if bind_address is not None:
-                    bind_addresses.append(bind_address)
-
+                bind_addresses = listener["bind_addresses"]
+
                 for address in bind_addresses:
                     reactor.listenTCP(

synapse/app/homeserver.py

@@ -107,8 +107,7 @@ def build_resource_for_web_client(hs):
 class SynapseHomeServer(HomeServer):
     def _listener_http(self, config, listener_config):
         port = listener_config["port"]
-        bind_address = listener_config.get("bind_address", None)
-        bind_addresses = listener_config.get("bind_addresses", [])
+        bind_addresses = listener_config["bind_addresses"]
         tls = listener_config.get("tls", False)
         site_tag = listener_config.get("tag", port)
@@ -175,9 +174,6 @@ class SynapseHomeServer(HomeServer):

         root_resource = create_resource_tree(resources, root_resource)

-        if bind_address is not None:
-            bind_addresses.append(bind_address)
-
         if tls:
             for address in bind_addresses:
                 reactor.listenSSL(
@@ -212,11 +208,7 @@ class SynapseHomeServer(HomeServer):
             if listener["type"] == "http":
                 self._listener_http(config, listener)
             elif listener["type"] == "manhole":
-                bind_address = listener.get("bind_address", None)
-                bind_addresses = listener.get("bind_addresses", [])
-
-                if bind_address is not None:
-                    bind_addresses.append(bind_address)
-
+                bind_addresses = listener["bind_addresses"]
+
                 for address in bind_addresses:
                     reactor.listenTCP(

synapse/app/media_repository.py

@@ -87,8 +87,7 @@ class MediaRepositoryServer(HomeServer):
     def _listen_http(self, listener_config):
         port = listener_config["port"]
-        bind_address = listener_config.get("bind_address", None)
-        bind_addresses = listener_config.get("bind_addresses", [])
+        bind_addresses = listener_config["bind_addresses"]
         site_tag = listener_config.get("tag", port)
         resources = {}
         for res in listener_config["resources"]:
@@ -107,9 +106,6 @@ class MediaRepositoryServer(HomeServer):

         root_resource = create_resource_tree(resources, Resource())

-        if bind_address is not None:
-            bind_addresses.append(bind_address)
-
         for address in bind_addresses:
             reactor.listenTCP(
                 port,
@@ -129,11 +125,7 @@ class MediaRepositoryServer(HomeServer):
             if listener["type"] == "http":
                 self._listen_http(listener)
             elif listener["type"] == "manhole":
-                bind_address = listener.get("bind_address", None)
-                bind_addresses = listener.get("bind_addresses", [])
-
-                if bind_address is not None:
-                    bind_addresses.append(bind_address)
-
+                bind_addresses = listener["bind_addresses"]
+
                 for address in bind_addresses:
                     reactor.listenTCP(

synapse/app/pusher.py

@@ -121,8 +121,7 @@ class PusherServer(HomeServer):
     def _listen_http(self, listener_config):
         port = listener_config["port"]
-        bind_address = listener_config.get("bind_address", None)
-        bind_addresses = listener_config.get("bind_addresses", [])
+        bind_addresses = listener_config["bind_addresses"]
         site_tag = listener_config.get("tag", port)
         resources = {}
         for res in listener_config["resources"]:
@@ -132,9 +131,6 @@ class PusherServer(HomeServer):

         root_resource = create_resource_tree(resources, Resource())

-        if bind_address is not None:
-            bind_addresses.append(bind_address)
-
         for address in bind_addresses:
             reactor.listenTCP(
                 port,
@@ -154,11 +150,7 @@ class PusherServer(HomeServer):
             if listener["type"] == "http":
                 self._listen_http(listener)
             elif listener["type"] == "manhole":
-                bind_address = listener.get("bind_address", None)
-                bind_addresses = listener.get("bind_addresses", [])
-
-                if bind_address is not None:
-                    bind_addresses.append(bind_address)
-
+                bind_addresses = listener["bind_addresses"]
+
                 for address in bind_addresses:
                     reactor.listenTCP(

synapse/app/synchrotron.py

@@ -289,8 +289,7 @@ class SynchrotronServer(HomeServer):
     def _listen_http(self, listener_config):
         port = listener_config["port"]
-        bind_address = listener_config.get("bind_address", None)
-        bind_addresses = listener_config.get("bind_addresses", [])
+        bind_addresses = listener_config["bind_addresses"]
         site_tag = listener_config.get("tag", port)
         resources = {}
         for res in listener_config["resources"]:
@@ -312,9 +311,6 @@ class SynchrotronServer(HomeServer):

         root_resource = create_resource_tree(resources, Resource())

-        if bind_address is not None:
-            bind_addresses.append(bind_address)
-
         for address in bind_addresses:
             reactor.listenTCP(
                 port,
@@ -334,11 +330,7 @@ class SynchrotronServer(HomeServer):
             if listener["type"] == "http":
                 self._listen_http(listener)
             elif listener["type"] == "manhole":
-                bind_address = listener.get("bind_address", None)
-                bind_addresses = listener.get("bind_addresses", [])
-
-                if bind_address is not None:
-                    bind_addresses.append(bind_address)
-
+                bind_addresses = listener["bind_addresses"]
+
                 for address in bind_addresses:
                     reactor.listenTCP(

synapse/config/emailconfig.py

@@ -68,6 +68,9 @@ class EmailConfig(Config):
         self.email_notif_for_new_users = email_config.get(
             "notif_for_new_users", True
         )
+        self.email_riot_base_url = email_config.get(
+            "riot_base_url", None
+        )
         if "app_name" in email_config:
             self.email_app_name = email_config["app_name"]
         else:
@@ -85,6 +88,9 @@ class EmailConfig(Config):
     def default_config(self, config_dir_path, server_name, **kwargs):
         return """
         # Enable sending emails for notification events
+        # Defining a custom URL for Riot is only needed if email notifications
+        # should contain links to a self-hosted installation of Riot; when set
+        # the "app_name" setting is ignored.
         #email:
         #   enable_notifs: false
         #   smtp_host: "localhost"
@@ -95,4 +101,5 @@ class EmailConfig(Config):
         #   notif_template_html: notif_mail.html
         #   notif_template_text: notif_mail.txt
        #   notif_for_new_users: True
+        #   riot_base_url: "http://localhost/riot"
         """

synapse/config/logger.py

@@ -22,7 +22,6 @@ import yaml
 from string import Template
 import os
 import signal
-from synapse.util.debug import debug_deferreds

 DEFAULT_LOG_CONFIG = Template("""
@@ -71,8 +70,6 @@ class LoggingConfig(Config):
         self.verbosity = config.get("verbose", 0)
         self.log_config = self.abspath(config.get("log_config"))
         self.log_file = self.abspath(config.get("log_file"))
-        if config.get("full_twisted_stacktraces"):
-            debug_deferreds()

     def default_config(self, config_dir_path, server_name, **kwargs):
         log_file = self.abspath("homeserver.log")
@@ -88,11 +85,6 @@ class LoggingConfig(Config):

         # A yaml python logging config file
         log_config: "%(log_config)s"
-
-        # Stop twisted from discarding the stack traces of exceptions in
-        # deferreds by waiting a reactor tick before running a deferred's
-        # callbacks.
-        # full_twisted_stacktraces: true
         """ % locals()

     def read_arguments(self, args):

synapse/config/server.py

@@ -42,6 +42,15 @@ class ServerConfig(Config):

         self.listeners = config.get("listeners", [])

+        for listener in self.listeners:
+            bind_address = listener.pop("bind_address", None)
+            bind_addresses = listener.setdefault("bind_addresses", [])
+
+            if bind_address:
+                bind_addresses.append(bind_address)
+            elif not bind_addresses:
+                bind_addresses.append('')
+
         self.gc_thresholds = read_gc_thresholds(config.get("gc_thresholds", None))

         bind_port = config.get("bind_port")
@@ -54,7 +63,7 @@ class ServerConfig(Config):
             self.listeners.append({
                 "port": bind_port,
-                "bind_address": bind_host,
+                "bind_addresses": [bind_host],
                 "tls": True,
                 "type": "http",
                 "resources": [
@@ -73,7 +82,7 @@ class ServerConfig(Config):
         if unsecure_port:
             self.listeners.append({
                 "port": unsecure_port,
-                "bind_address": bind_host,
+                "bind_addresses": [bind_host],
                 "tls": False,
                 "type": "http",
                 "resources": [
@@ -92,7 +101,7 @@ class ServerConfig(Config):
         if manhole:
             self.listeners.append({
                 "port": manhole,
-                "bind_address": "127.0.0.1",
+                "bind_addresses": ["127.0.0.1"],
                 "type": "manhole",
             })
@@ -100,7 +109,7 @@ class ServerConfig(Config):
         if metrics_port:
             self.listeners.append({
                 "port": metrics_port,
-                "bind_address": config.get("metrics_bind_host", "127.0.0.1"),
+                "bind_addresses": [config.get("metrics_bind_host", "127.0.0.1")],
                 "tls": False,
                 "type": "http",
                 "resources": [

synapse/config/voip.py

@@ -19,7 +19,9 @@ class VoipConfig(Config):

     def read_config(self, config):
         self.turn_uris = config.get("turn_uris", [])
-        self.turn_shared_secret = config["turn_shared_secret"]
+        self.turn_shared_secret = config.get("turn_shared_secret")
+        self.turn_username = config.get("turn_username")
+        self.turn_password = config.get("turn_password")
         self.turn_user_lifetime = self.parse_duration(config["turn_user_lifetime"])

     def default_config(self, **kwargs):
@@ -32,6 +34,11 @@ class VoipConfig(Config):
         # The shared secret used to compute passwords for the TURN server
         turn_shared_secret: "YOUR_SHARED_SECRET"

+        # The Username and password if the TURN server needs them and
+        # does not use a token
+        #turn_username: "TURNSERVER_USERNAME"
+        #turn_password: "TURNSERVER_PASSWORD"
+
         # How long generated TURN credentials last
         turn_user_lifetime: "1h"
         """

synapse/config/workers.py

@@ -29,3 +29,13 @@ class WorkerConfig(Config):
         self.worker_log_file = config.get("worker_log_file")
         self.worker_log_config = config.get("worker_log_config")
         self.worker_replication_url = config.get("worker_replication_url")
+
+        if self.worker_listeners:
+            for listener in self.worker_listeners:
+                bind_address = listener.pop("bind_address", None)
+                bind_addresses = listener.setdefault("bind_addresses", [])
+
+                if bind_address:
+                    bind_addresses.append(bind_address)
+                elif not bind_addresses:
+                    bind_addresses.append('')

synapse/event_auth.py (new file, 678 lines)

@@ -0,0 +1,678 @@
# -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

from canonicaljson import encode_canonical_json
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json, SignatureVerifyException
from unpaddedbase64 import decode_base64

from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, SynapseError, EventSizeError
from synapse.types import UserID, get_domain_from_id

logger = logging.getLogger(__name__)


def check(event, auth_events, do_sig_check=True, do_size_check=True):
    """ Checks if this event is correctly authed.

    Args:
        event: the event being checked.
        auth_events (dict: event-key -> event): the existing room state.

    Returns:
        True if the auth checks pass.
    """
    if do_size_check:
        _check_size_limits(event)

    if not hasattr(event, "room_id"):
        raise AuthError(500, "Event has no room_id: %s" % event)

    if do_sig_check:
        sender_domain = get_domain_from_id(event.sender)
        event_id_domain = get_domain_from_id(event.event_id)

        is_invite_via_3pid = (
            event.type == EventTypes.Member
            and event.membership == Membership.INVITE
            and "third_party_invite" in event.content
        )

        # Check the sender's domain has signed the event
        if not event.signatures.get(sender_domain):
            # We allow invites via 3pid to have a sender from a different
            # HS, as the sender must match the sender of the original
            # 3pid invite. This is checked further down with the
            # other dedicated membership checks.
            if not is_invite_via_3pid:
                raise AuthError(403, "Event not signed by sender's server")

        # Check the event_id's domain has signed the event
        if not event.signatures.get(event_id_domain):
            raise AuthError(403, "Event not signed by sending server")

    if auth_events is None:
        # Oh, we don't know what the state of the room was, so we
        # are trusting that this is allowed (at least for now)
        logger.warn("Trusting event: %s", event.event_id)
        return True

    if event.type == EventTypes.Create:
        room_id_domain = get_domain_from_id(event.room_id)
        if room_id_domain != sender_domain:
            raise AuthError(
                403,
                "Creation event's room_id domain does not match sender's"
            )
        # FIXME
        return True

    creation_event = auth_events.get((EventTypes.Create, ""), None)

    if not creation_event:
        raise SynapseError(
            403,
            "Room %r does not exist" % (event.room_id,)
        )

    creating_domain = get_domain_from_id(event.room_id)
    originating_domain = get_domain_from_id(event.sender)
    if creating_domain != originating_domain:
        if not _can_federate(event, auth_events):
            raise AuthError(
                403,
                "This room has been marked as unfederatable."
            )

    # FIXME: Temp hack
    if event.type == EventTypes.Aliases:
        if not event.is_state():
            raise AuthError(
                403,
                "Alias event must be a state event",
            )
        if not event.state_key:
            raise AuthError(
                403,
                "Alias event must have non-empty state_key"
            )
        sender_domain = get_domain_from_id(event.sender)
        if event.state_key != sender_domain:
            raise AuthError(
                403,
                "Alias event's state_key does not match sender's domain"
            )
        return True

    if logger.isEnabledFor(logging.DEBUG):
        logger.debug(
            "Auth events: %s",
            [a.event_id for a in auth_events.values()]
        )

    if event.type == EventTypes.Member:
        allowed = _is_membership_change_allowed(
            event, auth_events
        )
        if allowed:
            logger.debug("Allowing! %s", event)
        else:
            logger.debug("Denying! %s", event)
        return allowed

    _check_event_sender_in_room(event, auth_events)

    # Special case to allow m.room.third_party_invite events wherever
    # a user is allowed to issue invites. Fixes
    # https://github.com/vector-im/vector-web/issues/1208 hopefully
    if event.type == EventTypes.ThirdPartyInvite:
        user_level = get_user_power_level(event.user_id, auth_events)
        invite_level = _get_named_level(auth_events, "invite", 0)

        if user_level < invite_level:
            raise AuthError(
                403, (
                    "You cannot issue a third party invite for %s." %
                    (event.content.display_name,)
                )
            )
        else:
            return True

    _can_send_event(event, auth_events)

    if event.type == EventTypes.PowerLevels:
        _check_power_levels(event, auth_events)

    if event.type == EventTypes.Redaction:
        check_redaction(event, auth_events)

    logger.debug("Allowing! %s", event)


def _check_size_limits(event):
    def too_big(field):
        raise EventSizeError("%s too large" % (field,))

    if len(event.user_id) > 255:
        too_big("user_id")
    if len(event.room_id) > 255:
        too_big("room_id")
    if event.is_state() and len(event.state_key) > 255:
        too_big("state_key")
    if len(event.type) > 255:
        too_big("type")
    if len(event.event_id) > 255:
        too_big("event_id")
    if len(encode_canonical_json(event.get_pdu_json())) > 65536:
        too_big("event")


def _can_federate(event, auth_events):
    creation_event = auth_events.get((EventTypes.Create, ""))

    return creation_event.content.get("m.federate", True) is True


def _is_membership_change_allowed(event, auth_events):
    membership = event.content["membership"]

    # Check if this is the room creator joining:
    if len(event.prev_events) == 1 and Membership.JOIN == membership:
        # Get room creation event:
        key = (EventTypes.Create, "", )
        create = auth_events.get(key)
        if create and event.prev_events[0][0] == create.event_id:
            if create.content["creator"] == event.state_key:
                return True

    target_user_id = event.state_key

    creating_domain = get_domain_from_id(event.room_id)
    target_domain = get_domain_from_id(target_user_id)
    if creating_domain != target_domain:
        if not _can_federate(event, auth_events):
            raise AuthError(
                403,
                "This room has been marked as unfederatable."
            )

    # get info about the caller
    key = (EventTypes.Member, event.user_id, )
    caller = auth_events.get(key)

    caller_in_room = caller and caller.membership == Membership.JOIN
    caller_invited = caller and caller.membership == Membership.INVITE

    # get info about the target
    key = (EventTypes.Member, target_user_id, )
    target = auth_events.get(key)

    target_in_room = target and target.membership == Membership.JOIN
    target_banned = target and target.membership == Membership.BAN

    key = (EventTypes.JoinRules, "", )
    join_rule_event = auth_events.get(key)
    if join_rule_event:
        join_rule = join_rule_event.content.get(
            "join_rule", JoinRules.INVITE
        )
    else:
        join_rule = JoinRules.INVITE

    user_level = get_user_power_level(event.user_id, auth_events)
    target_level = get_user_power_level(
        target_user_id, auth_events
    )

    # FIXME (erikj): What should we do here as the default?
    ban_level = _get_named_level(auth_events, "ban", 50)

    logger.debug(
        "_is_membership_change_allowed: %s",
        {
            "caller_in_room": caller_in_room,
            "caller_invited": caller_invited,
            "target_banned": target_banned,
            "target_in_room": target_in_room,
            "membership": membership,
            "join_rule": join_rule,
            "target_user_id": target_user_id,
            "event.user_id": event.user_id,
        }
    )

    if Membership.INVITE == membership and "third_party_invite" in event.content:
        if not _verify_third_party_invite(event, auth_events):
            raise AuthError(403, "You are not invited to this room.")
        if target_banned:
            raise AuthError(
                403, "%s is banned from the room" % (target_user_id,)
            )
        return True

    if Membership.JOIN != membership:
        if (caller_invited
                and Membership.LEAVE == membership
                and target_user_id == event.user_id):
            return True

        if not caller_in_room:  # caller isn't joined
            raise AuthError(
                403,
                "%s not in room %s." % (event.user_id, event.room_id,)
            )

    if Membership.INVITE == membership:
        # TODO (erikj): We should probably handle this more intelligently
        # PRIVATE join rules.

        # Invites are valid iff caller is in the room and target isn't.
        if target_banned:
            raise AuthError(
                403, "%s is banned from the room" % (target_user_id,)
            )
        elif target_in_room:  # the target is already in the room.
            raise AuthError(403, "%s is already in the room." %
                            target_user_id)
        else:
            invite_level = _get_named_level(auth_events, "invite", 0)

            if user_level < invite_level:
                raise AuthError(
                    403, "You cannot invite user %s." % target_user_id
                )
    elif Membership.JOIN == membership:
        # Joins are valid iff caller == target and they were:
        # invited: They are accepting the invitation
        # joined: It's a NOOP
        if event.user_id != target_user_id:
            raise AuthError(403, "Cannot force another user to join.")
        elif target_banned:
            raise AuthError(403, "You are banned from this room")
        elif join_rule == JoinRules.PUBLIC:
            pass
        elif join_rule == JoinRules.INVITE:
            if not caller_in_room and not caller_invited:
                raise AuthError(403, "You are not invited to this room.")
        else:
            # TODO (erikj): may_join list
            # TODO (erikj): private rooms
            raise AuthError(403, "You are not allowed to join this room")
    elif Membership.LEAVE == membership:
        # TODO (erikj): Implement kicks.
        if target_banned and user_level < ban_level:
            raise AuthError(
                403, "You cannot unban user %s." % (target_user_id,)
            )
        elif target_user_id != event.user_id:
            kick_level = _get_named_level(auth_events, "kick", 50)

            if user_level < kick_level or user_level <= target_level:
                raise AuthError(
                    403, "You cannot kick user %s." % target_user_id
                )
    elif Membership.BAN == membership:
        if user_level < ban_level or user_level <= target_level:
            raise AuthError(403, "You don't have permission to ban")
    else:
        raise AuthError(500, "Unknown membership %s" % membership)

    return True


def _check_event_sender_in_room(event, auth_events):
    key = (EventTypes.Member, event.user_id, )
    member_event = auth_events.get(key)

    return _check_joined_room(
        member_event,
        event.user_id,
        event.room_id
    )


def _check_joined_room(member, user_id, room_id):
    if not member or member.membership != Membership.JOIN:
        raise AuthError(403, "User %s not in room %s (%s)" % (
            user_id, room_id, repr(member)
        ))


def get_send_level(etype, state_key, auth_events):
    key = (EventTypes.PowerLevels, "", )
    send_level_event = auth_events.get(key)
    send_level = None
    if send_level_event:
        send_level = send_level_event.content.get("events", {}).get(
            etype
        )
        if send_level is None:
            if state_key is not None:
                send_level = send_level_event.content.get(
                    "state_default", 50
                )
            else:
                send_level = send_level_event.content.get(
                    "events_default", 0
                )

    if send_level:
        send_level = int(send_level)
    else:
        send_level = 0

    return send_level


def _can_send_event(event, auth_events):
    send_level = get_send_level(
        event.type, event.get("state_key", None), auth_events
    )
    user_level = get_user_power_level(event.user_id, auth_events)

    if user_level < send_level:
        raise AuthError(
            403,
            "You don't have permission to post that to the room. " +
            "user_level (%d) < send_level (%d)" % (user_level, send_level)
        )

    # Check state_key
    if hasattr(event, "state_key"):
        if event.state_key.startswith("@"):
            if event.state_key != event.user_id:
                raise AuthError(
                    403,
                    "You are not allowed to set others state"
                )

    return True


def check_redaction(event, auth_events):
    """Check whether the event sender is allowed to redact the target event.

    Returns:
        True if the sender is allowed to redact the target event if the
        target event was created by them.
        False if the sender is allowed to redact the target event with no
        further checks.

    Raises:
        AuthError if the event sender is definitely not allowed to redact
        the target event.
    """
    user_level = get_user_power_level(event.user_id, auth_events)

    redact_level = _get_named_level(auth_events, "redact", 50)

    if user_level >= redact_level:
        return False

    redacter_domain = get_domain_from_id(event.event_id)
    redactee_domain = get_domain_from_id(event.redacts)
    if redacter_domain == redactee_domain:
        return True

    raise AuthError(
        403,
        "You don't have permission to redact events"
    )


def _check_power_levels(event, auth_events):
    user_list = event.content.get("users", {})
    # Validate users
    for k, v in user_list.items():
        try:
            UserID.from_string(k)
        except:
            raise SynapseError(400, "Not a valid user_id: %s" % (k,))

        try:
            int(v)
        except:
            raise SynapseError(400, "Not a valid power level: %s" % (v,))

    key = (event.type, event.state_key, )
    current_state = auth_events.get(key)

    if not current_state:
        return

    user_level = get_user_power_level(event.user_id, auth_events)

    # Check other levels:
    levels_to_check = [
        ("users_default", None),
        ("events_default", None),
        ("state_default", None),
        ("ban", None),
        ("redact", None),
        ("kick", None),
        ("invite", None),
    ]

    old_list = current_state.content.get("users")
    for user in set(old_list.keys() + user_list.keys()):
        levels_to_check.append(
            (user, "users")
        )

    old_list = current_state.content.get("events")
    new_list = event.content.get("events")
    for ev_id in set(old_list.keys() + new_list.keys()):
        levels_to_check.append(
            (ev_id, "events")
        )

    old_state = current_state.content
    new_state = event.content

    for level_to_check, dir in levels_to_check:
        old_loc = old_state
        new_loc = new_state
        if dir:
            old_loc = old_loc.get(dir, {})
            new_loc = new_loc.get(dir, {})

        if level_to_check in old_loc:
            old_level = int(old_loc[level_to_check])
        else:
            old_level = None

        if level_to_check in new_loc:
            new_level = int(new_loc[level_to_check])
        else:
            new_level = None

        if new_level is not None and old_level is not None:
            if new_level == old_level:
                continue

        if dir == "users" and level_to_check != event.user_id:
            if old_level == user_level:
                raise AuthError(
                    403,
                    "You don't have permission to remove ops level equal "
                    "to your own"
                )

        if old_level > user_level or new_level > user_level:
            raise AuthError(
                403,
                "You don't have permission to add ops level greater "
                "than your own"
            )


def _get_power_level_event(auth_events):
    key = (EventTypes.PowerLevels, "", )
    return auth_events.get(key)


def get_user_power_level(user_id, auth_events):
    power_level_event = _get_power_level_event(auth_events)
    if power_level_event:
        level = power_level_event.content.get("users", {}).get(user_id)
        if not level:
            level = power_level_event.content.get("users_default", 0)

        if level is None:
            return 0
        else:
            return int(level)
    else:
        key = (EventTypes.Create, "", )
        create_event = auth_events.get(key)
        if (create_event is not None and
                create_event.content["creator"] == user_id):
            return 100
        else:
            return 0


def _get_named_level(auth_events, name, default):
    power_level_event = _get_power_level_event(auth_events)

    if not power_level_event:
        return default

    level = power_level_event.content.get(name, None)
    if level is not None:
        return int(level)
    else:
        return default


def _verify_third_party_invite(event, auth_events):
    """
    Validates that the invite event is authorized by a previous third-party invite.

    Checks that the public key, and keyserver, match those in the third party invite,
    and that the invite event has a signature issued using that public key.

    Args:
        event: The m.room.member join event being validated.
        auth_events: All relevant previous context events which may be used
            for authorization decisions.

    Return:
        True if the event fulfills the expectations of a previous third party
        invite event.
    """
    if "third_party_invite" not in event.content:
        return False
    if "signed" not in event.content["third_party_invite"]:
        return False
    signed = event.content["third_party_invite"]["signed"]
    for key in {"mxid", "token"}:
        if key not in signed:
            return False

    token = signed["token"]

    invite_event = auth_events.get(
        (EventTypes.ThirdPartyInvite, token,)
    )
    if not invite_event:
        return False

    if invite_event.sender != event.sender:
        return False

    if event.user_id != invite_event.user_id:
        return False

    if signed["mxid"] != event.state_key:
        return False
    if signed["token"] != token:
        return False

    for public_key_object in get_public_keys(invite_event):
        public_key = public_key_object["public_key"]
        try:
            for server, signature_block in signed["signatures"].items():
                for key_name, encoded_signature in signature_block.items():
                    if not key_name.startswith("ed25519:"):
                        continue
                    verify_key = decode_verify_key_bytes(
                        key_name,
                        decode_base64(public_key)
                    )
                    verify_signed_json(signed, server, verify_key)

                    # We got the public key from the invite, so we know that the
                    # correct server signed the signed bundle.
                    # The caller is responsible for checking that the signing
                    # server has not revoked that public key.
                    return True
        except (KeyError, SignatureVerifyException,):
            continue

    return False


def get_public_keys(invite_event):
    public_keys = []
    if "public_key" in invite_event.content:
        o = {
            "public_key": invite_event.content["public_key"],
        }
        if "key_validity_url" in invite_event.content:
            o["key_validity_url"] = invite_event.content["key_validity_url"]
        public_keys.append(o)
    public_keys.extend(invite_event.content.get("public_keys", []))
    return public_keys


def auth_types_for_event(event):
    """Given an event, return a list of (EventType, StateKey) that may be
    needed to auth the event. The returned list may be a superset of what
    would actually be required depending on the full state of the room.

    Used to limit the number of events to fetch from the database to
    actually auth the event.
    """
    if event.type == EventTypes.Create:
        return []

    auth_types = []

    auth_types.append((EventTypes.PowerLevels, "", ))
    auth_types.append((EventTypes.Member, event.user_id, ))
    auth_types.append((EventTypes.Create, "", ))

    if event.type == EventTypes.Member:
        membership = event.content["membership"]
        if membership in [Membership.JOIN, Membership.INVITE]:
            auth_types.append((EventTypes.JoinRules, "", ))

        auth_types.append((EventTypes.Member, event.state_key, ))

        if membership == Membership.INVITE:
            if "third_party_invite" in event.content:
                key = (
                    EventTypes.ThirdPartyInvite,
                    event.content["third_party_invite"]["signed"]["token"]
                )
                auth_types.append(key)

    return auth_types
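
A hypothetical usage sketch for the new module (auth_event_against_state and get_state are illustrative names, not part of the diff): auth_types_for_event narrows which state events need fetching, and check() then either returns quietly or raises AuthError/EventSizeError:

    from synapse import event_auth
    from synapse.api.errors import AuthError

    def auth_event_against_state(event, get_state):
        # get_state: (event_type, state_key) -> event or None, an assumed
        # callable standing in for a real room-state lookup.
        auth_events = {}
        for key in event_auth.auth_types_for_event(event):
            state_event = get_state(*key)
            if state_event is not None:
                auth_events[key] = state_event
        try:
            event_auth.check(event, auth_events)
        except AuthError as e:
            return False, e.msg
        return True, None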

synapse/events/__init__.py

@@ -79,7 +79,6 @@ class EventBase(object):
     auth_events = _event_dict_property("auth_events")
     depth = _event_dict_property("depth")
     content = _event_dict_property("content")
-    event_id = _event_dict_property("event_id")
     hashes = _event_dict_property("hashes")
     origin = _event_dict_property("origin")
     origin_server_ts = _event_dict_property("origin_server_ts")
@@ -88,8 +87,6 @@ class EventBase(object):
     redacts = _event_dict_property("redacts")
     room_id = _event_dict_property("room_id")
     sender = _event_dict_property("sender")
-    state_key = _event_dict_property("state_key")
-    type = _event_dict_property("type")
     user_id = _event_dict_property("sender")

     @property
@@ -162,6 +159,11 @@ class FrozenEvent(EventBase):
         else:
             frozen_dict = event_dict

+        self.event_id = event_dict["event_id"]
+        self.type = event_dict["type"]
+        if "state_key" in event_dict:
+            self.state_key = event_dict["state_key"]
+
         super(FrozenEvent, self).__init__(
             frozen_dict,
             signatures=signatures,


@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from . import EventBase, FrozenEvent
+from . import EventBase, FrozenEvent, _event_dict_property

 from synapse.types import EventID
@@ -34,6 +34,10 @@ class EventBuilder(EventBase):
             internal_metadata_dict=internal_metadata_dict,
         )

+    event_id = _event_dict_property("event_id")
+    state_key = _event_dict_property("state_key")
+    type = _event_dict_property("type")
+
     def build(self):
         return FrozenEvent.from_event(self)


@@ -26,7 +26,7 @@ from synapse.util import unwrapFirstError
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.logutils import log_function
 from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
-from synapse.events import FrozenEvent
+from synapse.events import FrozenEvent, builder
 import synapse.metrics

 from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
@@ -499,8 +499,10 @@ class FederationClient(FederationBase):
                 if "prev_state" not in pdu_dict:
                     pdu_dict["prev_state"] = []

+                ev = builder.EventBuilder(pdu_dict)
+
                 defer.returnValue(
-                    (destination, self.event_from_pdu_json(pdu_dict))
+                    (destination, ev)
                 )
                 break
             except CodeMessageException as e:


@@ -362,7 +362,7 @@ class TransactionQueue(object):
                 if not success:
                     break
             except NotRetryingDestination:
-                logger.info(
+                logger.debug(
                     "TX [%s] not ready for retry yet - "
                     "dropping transaction for now",
                     destination,


@@ -591,12 +591,12 @@ class FederationHandler(BaseHandler):
         event_ids = list(extremities.keys())

-        logger.info("calling resolve_state_groups in _maybe_backfill")
+        logger.debug("calling resolve_state_groups in _maybe_backfill")
         states = yield preserve_context_over_deferred(defer.gatherResults([
             preserve_fn(self.state_handler.resolve_state_groups)(room_id, [e])
             for e in event_ids
         ]))
-        states = dict(zip(event_ids, [s[1] for s in states]))
+        states = dict(zip(event_ids, [s.state for s in states]))

         state_map = yield self.store.get_events(
             [e_id for ids in states.values() for e_id in ids],
@@ -1530,7 +1530,7 @@ class FederationHandler(BaseHandler):
             (d.type, d.state_key): d for d in different_events if d
         })

-        new_state, prev_state = self.state_handler.resolve_events(
+        new_state = self.state_handler.resolve_events(
             [local_view.values(), remote_view.values()],
             event
         )


@@ -439,15 +439,23 @@ class Mailer(object):
         })

     def make_room_link(self, room_id):
-        # need /beta for Universal Links to work on iOS
-        if self.app_name == "Vector":
-            return "https://vector.im/beta/#/room/%s" % (room_id,)
+        if self.hs.config.email_riot_base_url:
+            base_url = self.hs.config.email_riot_base_url
+        elif self.app_name == "Vector":
+            # need /beta for Universal Links to work on iOS
+            base_url = "https://vector.im/beta/#/room"
         else:
-            return "https://matrix.to/#/%s" % (room_id,)
+            base_url = "https://matrix.to/#"
+        return "%s/%s" % (base_url, room_id)

     def make_notif_link(self, notif):
-        # need /beta for Universal Links to work on iOS
-        if self.app_name == "Vector":
+        if self.hs.config.email_riot_base_url:
+            return "%s/#/room/%s/%s" % (
+                self.hs.config.email_riot_base_url,
+                notif['room_id'], notif['event_id']
+            )
+        elif self.app_name == "Vector":
+            # need /beta for Universal Links to work on iOS
             return "https://vector.im/beta/#/room/%s/%s" % (
                 notif['room_id'], notif['event_id']
             )
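
Illustration (not part of the diff): the link precedence the new make_room_link encodes, condensed into a standalone sketch with assumed example values.

def room_link(riot_base_url, app_name, room_id):
    # Precedence: configured base URL first, then the vector.im /beta
    # path (needed for Universal Links on iOS), then matrix.to.
    if riot_base_url:
        base_url = riot_base_url
    elif app_name == "Vector":
        base_url = "https://vector.im/beta/#/room"
    else:
        base_url = "https://matrix.to/#"
    return "%s/%s" % (base_url, room_id)


assert room_link(None, "Matrix", "!r:hs") == "https://matrix.to/#/!r:hs"
assert room_link(None, "Vector", "!r:hs").startswith("https://vector.im/beta")
assert room_link("https://riot.example.com", "Vector", "!r:hs") == (
    "https://riot.example.com/!r:hs"
)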


@@ -52,7 +52,7 @@ def get_badge_count(store, user_id):
 def get_context_for_event(store, state_handler, ev, user_id):
     ctx = {}

-    room_state_ids = yield state_handler.get_current_state_ids(ev.room_id)
+    room_state_ids = yield store.get_state_ids_for_event(ev.event_id)

     # we no longer bother setting room_alias, and make room_name the
     # human-readable name instead, be that m.room.name, an alias or


@@ -86,7 +86,11 @@ class HttpTransactionCache(object):
             pass  # execute the function instead.

         deferred = fn(*args, **kwargs)
-        observable = ObservableDeferred(deferred)
+
+        # We don't add an errback to the raw deferred, so we ask ObservableDeferred
+        # to swallow the error. This is fine as the error will still be reported
+        # to the observers.
+        observable = ObservableDeferred(deferred, consumeErrors=True)
         self.transactions[txn_key] = (observable, self.clock.time_msec())
         return observable.observe()
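
Illustration (not part of the diff): a minimal sketch of the consumeErrors behaviour relied on here; the failure on the raw deferred is consumed, but each observer still receives it.

from twisted.internet import defer

from synapse.util.async import ObservableDeferred


def on_error(failure):
    # Observers are still notified of the failure and may handle it.
    print("observer saw: %r" % (failure.value,))


d = defer.Deferred()
# consumeErrors=True means the error on the underlying deferred (which
# nothing else errbacks) is swallowed once observers have been told, so
# Twisted does not log an "Unhandled error in Deferred" for it.
observable = ObservableDeferred(d, consumeErrors=True)
observable.observe().addErrback(on_error)
d.errback(RuntimeError("boom"))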


@@ -118,8 +118,14 @@ class LoginRestServlet(ClientV1RestServlet):
     @defer.inlineCallbacks
     def do_password_login(self, login_submission):
         if 'medium' in login_submission and 'address' in login_submission:
+            address = login_submission['address']
+            if login_submission['medium'] == 'email':
+                # For emails, transform the address to lowercase.
+                # We store all email addresses as lowercase in the DB.
+                # (See add_threepid in synapse/handlers/auth.py)
+                address = address.lower()
             user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
-                login_submission['medium'], login_submission['address']
+                login_submission['medium'], address
             )
             if not user_id:
                 raise LoginError(403, "", errcode=Codes.FORBIDDEN)
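
Illustration (not part of the diff): the lookup contract being assumed, written as a hypothetical helper; the same normalisation reappears in the password-reset servlet below.

def normalise_threepid_address(medium, address):
    # Email addresses are stored lowercased in the DB (see add_threepid
    # in synapse/handlers/auth.py), so lowercase before any lookup.
    # MSISDNs and other media are left untouched.
    if medium == 'email':
        return address.lower()
    return address


assert normalise_threepid_address('email', 'User@Example.COM') == 'user@example.com'
assert normalise_threepid_address('msisdn', '+440000000000') == '+440000000000'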


@@ -32,19 +32,27 @@ class VoipRestServlet(ClientV1RestServlet):
         turnUris = self.hs.config.turn_uris
         turnSecret = self.hs.config.turn_shared_secret
+        turnUsername = self.hs.config.turn_username
+        turnPassword = self.hs.config.turn_password
         userLifetime = self.hs.config.turn_user_lifetime
-        if not turnUris or not turnSecret or not userLifetime:
-            defer.returnValue((200, {}))

-        expiry = (self.hs.get_clock().time_msec() + userLifetime) / 1000
-        username = "%d:%s" % (expiry, requester.user.to_string())
-
-        mac = hmac.new(turnSecret, msg=username, digestmod=hashlib.sha1)
-        # We need to use standard padded base64 encoding here
-        # encode_base64 because we need to add the standard padding to get the
-        # same result as the TURN server.
-        password = base64.b64encode(mac.digest())
+        if turnUris and turnSecret and userLifetime:
+            expiry = (self.hs.get_clock().time_msec() + userLifetime) / 1000
+            username = "%d:%s" % (expiry, requester.user.to_string())
+
+            mac = hmac.new(turnSecret, msg=username, digestmod=hashlib.sha1)
+            # We need to use standard padded base64 encoding here
+            # encode_base64 because we need to add the standard padding to get the
+            # same result as the TURN server.
+            password = base64.b64encode(mac.digest())
+        elif turnUris and turnUsername and turnPassword and userLifetime:
+            username = turnUsername
+            password = turnPassword
+        else:
+            defer.returnValue((200, {}))

         defer.returnValue((200, {
             'username': username,
             'password': password,
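
Illustration (not part of the diff): the shared-secret branch follows the common TURN REST credential scheme; a self-contained sketch with hypothetical helper and argument names.

import base64
import hashlib
import hmac
import time


def turn_credentials(shared_secret, user_id, lifetime_ms=86400000):
    """Derive short-lived TURN credentials from a shared secret.

    The TURN server recomputes the same HMAC from the username, so no
    per-user state needs to be shared; standard padded base64 must be
    used so both sides produce the same encoding.
    """
    expiry = int(time.time()) + lifetime_ms / 1000
    username = "%d:%s" % (expiry, user_id)
    mac = hmac.new(
        shared_secret.encode("utf8"),
        msg=username.encode("utf8"),
        digestmod=hashlib.sha1,
    )
    password = base64.b64encode(mac.digest()).decode("ascii")
    return username, password


# Example: credentials valid for 24 hours.
print(turn_credentials("s3cret", "@alice:example.com"))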


@@ -96,6 +96,11 @@ class PasswordRestServlet(RestServlet):
             threepid = result[LoginType.EMAIL_IDENTITY]
             if 'medium' not in threepid or 'address' not in threepid:
                 raise SynapseError(500, "Malformed threepid")
+            if threepid['medium'] == 'email':
+                # For emails, transform the address to lowercase.
+                # We store all email addresses as lowercase in the DB.
+                # (See add_threepid in synapse/handlers/auth.py)
+                threepid['address'] = threepid['address'].lower()
             # if using email, we must know about the email they're authing with!
             threepid_user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
                 threepid['medium'], threepid['address']


@@ -16,12 +16,12 @@

 from twisted.internet import defer

+from synapse import event_auth
 from synapse.util.logutils import log_function
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.metrics import Measure
 from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError
-from synapse.api.auth import AuthEventTypes
 from synapse.events.snapshot import EventContext
 from synapse.util.async import Linearizer
@@ -41,12 +41,14 @@ KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))

 CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))

-SIZE_OF_CACHE = int(1000 * CACHE_SIZE_FACTOR)
+SIZE_OF_CACHE = int(100000 * CACHE_SIZE_FACTOR)
 EVICTION_TIMEOUT_SECONDS = 60 * 60

 _NEXT_STATE_ID = 1

+POWER_KEY = (EventTypes.PowerLevels, "")
+

 def _gen_state_id():
     global _NEXT_STATE_ID
@@ -77,6 +79,9 @@ class _StateCacheEntry(object):
         else:
             self.state_id = _gen_state_id()

+    def __len__(self):
+        return len(self.state)
+

 class StateHandler(object):
     """ Responsible for doing state conflict resolution.
@@ -99,6 +104,7 @@ class StateHandler(object):
             clock=self.clock,
             max_len=SIZE_OF_CACHE,
             expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
+            iterable=True,
             reset_expiry_on_get=True,
         )
@@ -123,7 +129,7 @@ class StateHandler(object):
         if not latest_event_ids:
             latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)

-        logger.info("calling resolve_state_groups from get_current_state")
+        logger.debug("calling resolve_state_groups from get_current_state")
         ret = yield self.resolve_state_groups(room_id, latest_event_ids)
         state = ret.state
@@ -148,7 +154,7 @@ class StateHandler(object):
         if not latest_event_ids:
             latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)

-        logger.info("calling resolve_state_groups from get_current_state_ids")
+        logger.debug("calling resolve_state_groups from get_current_state_ids")
         ret = yield self.resolve_state_groups(room_id, latest_event_ids)
         state = ret.state
@@ -162,7 +168,7 @@ class StateHandler(object):
     def get_current_user_in_room(self, room_id, latest_event_ids=None):
         if not latest_event_ids:
             latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
-        logger.info("calling resolve_state_groups from get_current_user_in_room")
+        logger.debug("calling resolve_state_groups from get_current_user_in_room")
         entry = yield self.resolve_state_groups(room_id, latest_event_ids)
         joined_users = yield self.store.get_joined_users_from_state(
             room_id, entry.state_id, entry.state
@@ -226,7 +232,7 @@ class StateHandler(object):
             context.prev_state_events = []
             defer.returnValue(context)

-        logger.info("calling resolve_state_groups from compute_event_context")
+        logger.debug("calling resolve_state_groups from compute_event_context")
         if event.is_state():
             entry = yield self.resolve_state_groups(
                 event.room_id, [e for e, _ in event.prev_events],
@@ -327,20 +333,13 @@ class StateHandler(object):
         if conflicted_state:
             logger.info("Resolving conflicted state for %r", room_id)
-            state_map = yield self.store.get_events(
-                [e_id for st in state_groups_ids.values() for e_id in st.values()],
-                get_prev_content=False
-            )
-            state_sets = [
-                [state_map[e_id] for key, e_id in st.items() if e_id in state_map]
-                for st in state_groups_ids.values()
-            ]
-            new_state, _ = self._resolve_events(
-                state_sets, event_type, state_key
-            )
-            new_state = {
-                key: e.event_id for key, e in new_state.items()
-            }
+            with Measure(self.clock, "state._resolve_events"):
+                new_state = yield resolve_events(
+                    state_groups_ids.values(),
+                    state_map_factory=lambda ev_ids: self.store.get_events(
+                        ev_ids, get_prev_content=False, check_redacted=False,
+                    ),
+                )
         else:
             new_state = {
                 key: e_ids.pop() for key, e_ids in state.items()
@@ -388,152 +387,264 @@ class StateHandler(object):
         logger.info(
             "Resolving state for %s with %d groups", event.room_id, len(state_sets)
         )

-        if event.is_state():
-            return self._resolve_events(
-                state_sets, event.type, event.state_key
-            )
-        else:
-            return self._resolve_events(state_sets)
-
-    def _resolve_events(self, state_sets, event_type=None, state_key=""):
-        """
-        Returns
-            (dict[(str, str), synapse.events.FrozenEvent], list[str]): a tuple
-            (new_state, prev_states). new_state is a map from (type, state_key)
-            to event. prev_states is a list of event_ids.
-        """
-        with Measure(self.clock, "state._resolve_events"):
-            state = {}
-            for st in state_sets:
-                for e in st:
-                    state.setdefault(
-                        (e.type, e.state_key),
-                        {}
-                    )[e.event_id] = e
-
-            unconflicted_state = {
-                k: v.values()[0] for k, v in state.items()
-                if len(v.values()) == 1
-            }
-
-            conflicted_state = {
-                k: v.values()
-                for k, v in state.items()
-                if len(v.values()) > 1
-            }
-
-            if event_type:
-                prev_states_events = conflicted_state.get(
-                    (event_type, state_key), []
-                )
-                prev_states = [s.event_id for s in prev_states_events]
-            else:
-                prev_states = []
-
-            auth_events = {
-                k: e for k, e in unconflicted_state.items()
-                if k[0] in AuthEventTypes
-            }
-
-            try:
-                resolved_state = self._resolve_state_events(
-                    conflicted_state, auth_events
-                )
-            except:
-                logger.exception("Failed to resolve state")
-                raise
-
-            new_state = unconflicted_state
-            new_state.update(resolved_state)
-
-        return new_state, prev_states
-
-    @log_function
-    def _resolve_state_events(self, conflicted_state, auth_events):
-        """ This is where we actually decide which of the conflicted state to
-        use.
-
-        We resolve conflicts in the following order:
-            1. power levels
-            2. join rules
-            3. memberships
-            4. other events.
-        """
-        resolved_state = {}
-        power_key = (EventTypes.PowerLevels, "")
-        if power_key in conflicted_state:
-            events = conflicted_state[power_key]
-            logger.debug("Resolving conflicted power levels %r", events)
-            resolved_state[power_key] = self._resolve_auth_events(
-                events, auth_events)
-
-        auth_events.update(resolved_state)
-
-        for key, events in conflicted_state.items():
-            if key[0] == EventTypes.JoinRules:
-                logger.debug("Resolving conflicted join rules %r", events)
-                resolved_state[key] = self._resolve_auth_events(
-                    events,
-                    auth_events
-                )
-
-        auth_events.update(resolved_state)
-
-        for key, events in conflicted_state.items():
-            if key[0] == EventTypes.Member:
-                logger.debug("Resolving conflicted member lists %r", events)
-                resolved_state[key] = self._resolve_auth_events(
-                    events,
-                    auth_events
-                )
-
-        auth_events.update(resolved_state)
-
-        for key, events in conflicted_state.items():
-            if key not in resolved_state:
-                logger.debug("Resolving conflicted state %r:%r", key, events)
-                resolved_state[key] = self._resolve_normal_events(
-                    events, auth_events
-                )
-
-        return resolved_state
-
-    def _resolve_auth_events(self, events, auth_events):
-        reverse = [i for i in reversed(self._ordered_events(events))]
-
-        auth_events = dict(auth_events)
-
-        prev_event = reverse[0]
-        for event in reverse[1:]:
-            auth_events[(prev_event.type, prev_event.state_key)] = prev_event
-            try:
-                # FIXME: hs.get_auth() is bad style, but we need to do it to
-                # get around circular deps.
-                # The signatures have already been checked at this point
-                self.hs.get_auth().check(event, auth_events, do_sig_check=False)
-                prev_event = event
-            except AuthError:
-                return prev_event
-
-        return event
-
-    def _resolve_normal_events(self, events, auth_events):
-        for event in self._ordered_events(events):
-            try:
-                # FIXME: hs.get_auth() is bad style, but we need to do it to
-                # get around circular deps.
-                # The signatures have already been checked at this point
-                self.hs.get_auth().check(event, auth_events, do_sig_check=False)
-                return event
-            except AuthError:
-                pass
-
-        # Use the last event (the one with the least depth) if they all fail
-        # the auth check.
-        return event
-
-    def _ordered_events(self, events):
-        def key_func(e):
-            return -int(e.depth), hashlib.sha1(e.event_id).hexdigest()
-
-        return sorted(events, key=key_func)
+        state_set_ids = [{
+            (ev.type, ev.state_key): ev.event_id
+            for ev in st
+        } for st in state_sets]
+
+        state_map = {
+            ev.event_id: ev
+            for st in state_sets
+            for ev in st
+        }
+
+        with Measure(self.clock, "state._resolve_events"):
+            new_state = resolve_events(state_set_ids, state_map)
+
+        new_state = {
+            key: state_map[ev_id] for key, ev_id in new_state.items()
+        }
+
+        return new_state
+
+
+def _ordered_events(events):
+    def key_func(e):
+        return -int(e.depth), hashlib.sha1(e.event_id).hexdigest()
+
+    return sorted(events, key=key_func)
+
+
+def resolve_events(state_sets, state_map_factory):
+    """
+    Args:
+        state_sets(list): List of dicts of (type, state_key) -> event_id,
+            which are the different state groups to resolve.
+        state_map_factory(dict|callable): If callable, then will be called
+            with a list of event_ids that are needed, and should return with
+            a Deferred of dict of event_id to event. Otherwise, should be
+            a dict from event_id to event of all events in state_sets.
+
+    Returns
+        dict[(str, str), synapse.events.FrozenEvent] is a map from
+        (type, state_key) to event.
+    """
+    unconflicted_state, conflicted_state = _seperate(
+        state_sets,
+    )
+
+    if callable(state_map_factory):
+        return _resolve_with_state_fac(
+            unconflicted_state, conflicted_state, state_map_factory
+        )
+
+    state_map = state_map_factory
+
+    auth_events = _create_auth_events_from_maps(
+        unconflicted_state, conflicted_state, state_map
+    )
+
+    return _resolve_with_state(
+        unconflicted_state, conflicted_state, auth_events, state_map
+    )
+
+
+def _seperate(state_sets):
+    """Takes the state_sets and figures out which keys are conflicted and
+    which aren't. i.e., which have multiple different event_ids associated
+    with them in different state sets.
+    """
+    unconflicted_state = dict(state_sets[0])
+    conflicted_state = {}
+
+    for state_set in state_sets[1:]:
+        for key, value in state_set.iteritems():
+            # Check if there is an unconflicted entry for the state key.
+            unconflicted_value = unconflicted_state.get(key)
+            if unconflicted_value is None:
+                # There isn't an unconflicted entry so check if there is a
+                # conflicted entry.
+                ls = conflicted_state.get(key)
+                if ls is None:
+                    # There wasn't a conflicted entry so haven't seen this key before.
+                    # Therefore it isn't conflicted yet.
+                    unconflicted_state[key] = value
+                else:
+                    # This key is already conflicted, add our value to the conflict set.
+                    ls.add(value)
+            elif unconflicted_value != value:
+                # If the unconflicted value is not the same as our value then we
+                # have a new conflict. So move the key from the unconflicted_state
+                # to the conflicted state.
+                conflicted_state[key] = {value, unconflicted_value}
+                unconflicted_state.pop(key, None)
+
+    return unconflicted_state, conflicted_state
+
+
+@defer.inlineCallbacks
+def _resolve_with_state_fac(unconflicted_state, conflicted_state,
+                            state_map_factory):
+    needed_events = set(
+        event_id
+        for event_ids in conflicted_state.itervalues()
+        for event_id in event_ids
+    )
+
+    logger.info("Asking for %d conflicted events", len(needed_events))
+
+    state_map = yield state_map_factory(needed_events)
+
+    auth_events = _create_auth_events_from_maps(
+        unconflicted_state, conflicted_state, state_map
+    )
+
+    new_needed_events = set(auth_events.itervalues())
+    new_needed_events -= needed_events
+
+    logger.info("Asking for %d auth events", len(new_needed_events))
+
+    state_map_new = yield state_map_factory(new_needed_events)
+    state_map.update(state_map_new)
+
+    defer.returnValue(_resolve_with_state(
+        unconflicted_state, conflicted_state, auth_events, state_map
+    ))
+
+
+def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map):
+    auth_events = {}
+    for event_ids in conflicted_state.itervalues():
+        for event_id in event_ids:
+            if event_id in state_map:
+                keys = event_auth.auth_types_for_event(state_map[event_id])
+                for key in keys:
+                    if key not in auth_events:
+                        event_id = unconflicted_state.get(key, None)
+                        if event_id:
+                            auth_events[key] = event_id
+    return auth_events
+
+
+def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_ids,
+                        state_map):
+    conflicted_state = {}
+    for key, event_ids in conflicted_state_ds.iteritems():
+        events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map]
+        if len(events) > 1:
+            conflicted_state[key] = events
+        elif len(events) == 1:
+            unconflicted_state_ids[key] = events[0].event_id
+
+    auth_events = {
+        key: state_map[ev_id]
+        for key, ev_id in auth_event_ids.items()
+        if ev_id in state_map
+    }
+
+    try:
+        resolved_state = _resolve_state_events(
+            conflicted_state, auth_events
+        )
+    except:
+        logger.exception("Failed to resolve state")
+        raise
+
+    new_state = unconflicted_state_ids
+    for key, event in resolved_state.iteritems():
+        new_state[key] = event.event_id
+
+    return new_state
+
+
+def _resolve_state_events(conflicted_state, auth_events):
+    """ This is where we actually decide which of the conflicted state to
+    use.
+
+    We resolve conflicts in the following order:
+        1. power levels
+        2. join rules
+        3. memberships
+        4. other events.
+    """
+    resolved_state = {}
+    if POWER_KEY in conflicted_state:
+        events = conflicted_state[POWER_KEY]
+        logger.debug("Resolving conflicted power levels %r", events)
+        resolved_state[POWER_KEY] = _resolve_auth_events(
+            events, auth_events)
+
+    auth_events.update(resolved_state)
+
+    for key, events in conflicted_state.items():
+        if key[0] == EventTypes.JoinRules:
+            logger.debug("Resolving conflicted join rules %r", events)
+            resolved_state[key] = _resolve_auth_events(
+                events,
+                auth_events
+            )
+
+    auth_events.update(resolved_state)
+
+    for key, events in conflicted_state.items():
+        if key[0] == EventTypes.Member:
+            logger.debug("Resolving conflicted member lists %r", events)
+            resolved_state[key] = _resolve_auth_events(
+                events,
+                auth_events
+            )
+
+    auth_events.update(resolved_state)
+
+    for key, events in conflicted_state.items():
+        if key not in resolved_state:
+            logger.debug("Resolving conflicted state %r:%r", key, events)
+            resolved_state[key] = _resolve_normal_events(
+                events, auth_events
+            )
+
+    return resolved_state
+
+
+def _resolve_auth_events(events, auth_events):
+    reverse = [i for i in reversed(_ordered_events(events))]
+
+    auth_keys = set(
+        key
+        for event in events
+        for key in event_auth.auth_types_for_event(event)
+    )
+
+    new_auth_events = {}
+    for key in auth_keys:
+        auth_event = auth_events.get(key, None)
+        if auth_event:
+            new_auth_events[key] = auth_event
+
+    auth_events = new_auth_events
+
+    prev_event = reverse[0]
+    for event in reverse[1:]:
+        auth_events[(prev_event.type, prev_event.state_key)] = prev_event
+        try:
+            # The signatures have already been checked at this point
+            event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False)
+            prev_event = event
+        except AuthError:
+            return prev_event
+
+    return event
+
+
+def _resolve_normal_events(events, auth_events):
+    for event in _ordered_events(events):
+        try:
+            # The signatures have already been checked at this point
+            event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False)
+            return event
+        except AuthError:
+            pass
+
+    # Use the last event (the one with the least depth) if they all fail
+    # the auth check.
+    return event
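
Illustration (not part of the diff): the bookkeeping _seperate performs, re-implemented as a small sketch with toy event IDs (dict-method names written portably; the diff itself uses iteritems).

def separate(state_sets):
    """Split (type, state_key) -> event_id maps into the entries all
    sets agree on and those with more than one candidate event."""
    unconflicted = dict(state_sets[0])
    conflicted = {}
    for state_set in state_sets[1:]:
        for key, value in state_set.items():
            if key in conflicted:
                conflicted[key].add(value)
            elif key not in unconflicted:
                unconflicted[key] = value
            elif unconflicted[key] != value:
                # First disagreement for this key: it moves from the
                # unconflicted map into the conflict set.
                conflicted[key] = {value, unconflicted.pop(key)}
    return unconflicted, conflicted


# Two state sets that agree on the create event but not the topic:
s1 = {("m.room.create", ""): "$create", ("m.room.topic", ""): "$topic1"}
s2 = {("m.room.create", ""): "$create", ("m.room.topic", ""): "$topic2"}
# unconflicted keeps the create event; the topic key is conflicted,
# carrying both candidate event IDs.
print(separate([s1, s2]))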


@@ -169,7 +169,7 @@ class SQLBaseStore(object):
                                      max_entries=hs.config.event_cache_size)

         self._state_group_cache = DictionaryCache(
-            "*stateGroupCache*", 2000 * CACHE_SIZE_FACTOR
+            "*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR
         )

         self._event_fetch_lock = threading.Condition()


@@ -18,13 +18,29 @@ import ujson

 from twisted.internet import defer

-from ._base import SQLBaseStore
+from .background_updates import BackgroundUpdateStore

 logger = logging.getLogger(__name__)


-class DeviceInboxStore(SQLBaseStore):
+class DeviceInboxStore(BackgroundUpdateStore):
+    DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
+
+    def __init__(self, hs):
+        super(DeviceInboxStore, self).__init__(hs)
+
+        self.register_background_index_update(
+            "device_inbox_stream_index",
+            index_name="device_inbox_stream_id_user_id",
+            table="device_inbox",
+            columns=["stream_id", "user_id"],
+        )
+
+        self.register_background_update_handler(
+            self.DEVICE_INBOX_STREAM_ID,
+            self._background_drop_index_device_inbox,
+        )

     @defer.inlineCallbacks
     def add_messages_to_device_inbox(self, local_messages_by_user_then_device,
@@ -368,3 +384,18 @@ class DeviceInboxStore(SQLBaseStore):
             "delete_device_msgs_for_remote",
             delete_messages_for_remote_destination_txn
         )
+
+    @defer.inlineCallbacks
+    def _background_drop_index_device_inbox(self, progress, batch_size):
+        def reindex_txn(conn):
+            txn = conn.cursor()
+            txn.execute(
+                "DROP INDEX IF EXISTS device_inbox_stream_id"
+            )
+            txn.close()
+
+        yield self.runWithConnection(reindex_txn)
+
+        yield self._end_background_update(self.DEVICE_INBOX_STREAM_ID)
+
+        defer.returnValue(1)


@@ -1090,10 +1090,10 @@ class EventsStore(SQLBaseStore):
                 self._do_fetch
             )

-        logger.info("Loading %d events", len(events))
+        logger.debug("Loading %d events", len(events))
         with PreserveLoggingContext():
             rows = yield events_d
-        logger.info("Loaded %d events (%d rows)", len(events), len(rows))
+        logger.debug("Loaded %d events (%d rows)", len(events), len(rows))

         if not allow_rejected:
             rows[:] = [r for r in rows if not r["rejects"]]


@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)

 # Remember to update this number every time a change is made to database
 # schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 39
+SCHEMA_VERSION = 40

 dir_path = os.path.abspath(os.path.dirname(__file__))


@@ -390,7 +390,8 @@ class RoomMemberStore(SQLBaseStore):
             room_id, state_group, state_ids,
         )

-    @cachedInlineCallbacks(num_args=2, cache_context=True)
+    @cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True,
+                           max_entries=100000)
     def _get_joined_users_from_context(self, room_id, state_group, current_state_ids,
                                        cache_context, event=None):
         # We don't use `state_group`, it's there so that we can cache based


@@ -0,0 +1,21 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- turn the pre-fill startup query into an index-only scan on postgresql.
+INSERT into background_updates (update_name, progress_json)
+    VALUES ('device_inbox_stream_index', '{}');
+
+INSERT into background_updates (update_name, progress_json, depends_on)
+    VALUES ('device_inbox_stream_drop', '{}', 'device_inbox_stream_index');


@@ -284,7 +284,7 @@ class StateStore(SQLBaseStore):
             return [r[0] for r in results]

         return self.runInteraction("get_current_state_for_key", f)

-    @cached(num_args=2, max_entries=1000)
+    @cached(num_args=2, max_entries=100000, iterable=True)
     def _get_state_group_from_group(self, group, types):
         raise NotImplementedError()


@@ -40,8 +40,8 @@ def register_cache(name, cache):
     )


-_string_cache = LruCache(int(5000 * CACHE_SIZE_FACTOR))
-caches_by_name["string_cache"] = _string_cache
+_string_cache = LruCache(int(100000 * CACHE_SIZE_FACTOR))
+_stirng_cache_metrics = register_cache("string_cache", _string_cache)


 KNOWN_KEYS = {
@@ -69,7 +69,12 @@ KNOWN_KEYS = {
 def intern_string(string):
     """Takes a (potentially) unicode string and interns using custom cache
     """
-    return _string_cache.setdefault(string, string)
+    new_str = _string_cache.setdefault(string, string)
+    if new_str is string:
+        _stirng_cache_metrics.inc_hits()
+    else:
+        _stirng_cache_metrics.inc_misses()
+    return new_str


 def intern_dict(dictionary):


@@ -17,7 +17,7 @@ import logging
 from synapse.util.async import ObservableDeferred
 from synapse.util import unwrapFirstError
 from synapse.util.caches.lrucache import LruCache
-from synapse.util.caches.treecache import TreeCache
+from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
 from synapse.util.logcontext import (
     PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn
 )
@@ -42,6 +42,25 @@ _CacheSentinel = object()

 CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))


+class CacheEntry(object):
+    __slots__ = [
+        "deferred", "sequence", "callbacks", "invalidated"
+    ]
+
+    def __init__(self, deferred, sequence, callbacks):
+        self.deferred = deferred
+        self.sequence = sequence
+        self.callbacks = set(callbacks)
+        self.invalidated = False
+
+    def invalidate(self):
+        if not self.invalidated:
+            self.invalidated = True
+            for callback in self.callbacks:
+                callback()
+            self.callbacks.clear()
+
+
 class Cache(object):
     __slots__ = (
         "cache",
@@ -51,12 +70,16 @@ class Cache(object):
         "sequence",
         "thread",
         "metrics",
+        "_pending_deferred_cache",
     )

-    def __init__(self, name, max_entries=1000, keylen=1, tree=False):
+    def __init__(self, name, max_entries=1000, keylen=1, tree=False, iterable=False):
         cache_type = TreeCache if tree else dict
+        self._pending_deferred_cache = cache_type()
+
         self.cache = LruCache(
-            max_size=max_entries, keylen=keylen, cache_type=cache_type
+            max_size=max_entries, keylen=keylen, cache_type=cache_type,
+            size_callback=(lambda d: len(d.result)) if iterable else None,
         )

         self.name = name
@@ -76,7 +99,15 @@ class Cache(object):
         )

     def get(self, key, default=_CacheSentinel, callback=None):
-        val = self.cache.get(key, _CacheSentinel, callback=callback)
+        callbacks = [callback] if callback else []
+        val = self._pending_deferred_cache.get(key, _CacheSentinel)
+        if val is not _CacheSentinel:
+            if val.sequence == self.sequence:
+                val.callbacks.update(callbacks)
+                self.metrics.inc_hits()
+                return val.deferred
+
+        val = self.cache.get(key, _CacheSentinel, callbacks=callbacks)
         if val is not _CacheSentinel:
             self.metrics.inc_hits()
             return val
@@ -88,15 +119,39 @@ class Cache(object):
         else:
             return default

-    def update(self, sequence, key, value, callback=None):
+    def set(self, key, value, callback=None):
+        callbacks = [callback] if callback else []
         self.check_thread()
-        if self.sequence == sequence:
-            # Only update the cache if the caches sequence number matches the
-            # number that the cache had before the SELECT was started (SYN-369)
-            self.prefill(key, value, callback=callback)
+        entry = CacheEntry(
+            deferred=value,
+            sequence=self.sequence,
+            callbacks=callbacks,
+        )
+
+        entry.callbacks.update(callbacks)
+
+        existing_entry = self._pending_deferred_cache.pop(key, None)
+        if existing_entry:
+            existing_entry.invalidate()
+
+        self._pending_deferred_cache[key] = entry
+
+        def shuffle(result):
+            if self.sequence == entry.sequence:
+                existing_entry = self._pending_deferred_cache.pop(key, None)
+                if existing_entry is entry:
+                    self.cache.set(key, entry.deferred, entry.callbacks)
+                else:
+                    entry.invalidate()
+            else:
+                entry.invalidate()
+            return result
+
+        entry.deferred.addCallback(shuffle)

     def prefill(self, key, value, callback=None):
-        self.cache.set(key, value, callback=callback)
+        callbacks = [callback] if callback else []
+        self.cache.set(key, value, callbacks=callbacks)

     def invalidate(self, key):
         self.check_thread()
@@ -108,6 +163,10 @@ class Cache(object):
         # Increment the sequence number so that any SELECT statements that
         # raced with the INSERT don't update the cache (SYN-369)
         self.sequence += 1
+
+        entry = self._pending_deferred_cache.pop(key, None)
+        if entry:
+            entry.invalidate()
         self.cache.pop(key, None)

     def invalidate_many(self, key):
@@ -119,6 +178,11 @@ class Cache(object):
         self.sequence += 1
         self.cache.del_multi(key)

+        entry_dict = self._pending_deferred_cache.pop(key, None)
+        if entry_dict is not None:
+            for entry in iterate_tree_cache_entry(entry_dict):
+                entry.invalidate()
+
     def invalidate_all(self):
         self.check_thread()
         self.sequence += 1
@@ -155,7 +219,7 @@ class CacheDescriptor(object):
     """
     def __init__(self, orig, max_entries=1000, num_args=1, tree=False,
-                 inlineCallbacks=False, cache_context=False):
+                 inlineCallbacks=False, cache_context=False, iterable=False):
         max_entries = int(max_entries * CACHE_SIZE_FACTOR)

         self.orig = orig
@@ -169,6 +233,8 @@ class CacheDescriptor(object):
         self.num_args = num_args
         self.tree = tree

+        self.iterable = iterable
+
         all_args = inspect.getargspec(orig)
         self.arg_names = all_args.args[1:num_args + 1]
@@ -203,6 +269,7 @@ class CacheDescriptor(object):
             max_entries=self.max_entries,
             keylen=self.num_args,
             tree=self.tree,
+            iterable=self.iterable,
         )

         @functools.wraps(self.orig)
@@ -243,11 +310,6 @@ class CacheDescriptor(object):
                 return preserve_context_over_deferred(observer)
             except KeyError:
-                # Get the sequence number of the cache before reading from the
-                # database so that we can tell if the cache is invalidated
-                # while the SELECT is executing (SYN-369)
-                sequence = cache.sequence
-
                 ret = defer.maybeDeferred(
                     preserve_context_over_fn,
                     self.function_to_call,
@@ -261,7 +323,7 @@ class CacheDescriptor(object):
                 ret.addErrback(onErr)

                 ret = ObservableDeferred(ret, consumeErrors=True)
-                cache.update(sequence, cache_key, ret, callback=invalidate_callback)
+                cache.set(cache_key, ret, callback=invalidate_callback)

                 return preserve_context_over_deferred(ret.observe())
@@ -359,7 +421,6 @@ class CacheListDescriptor(object):
                     missing.append(arg)

             if missing:
-                sequence = cache.sequence
                 args_to_call = dict(arg_dict)
                 args_to_call[self.list_name] = missing
@@ -382,8 +443,8 @@ class CacheListDescriptor(object):
                     key = list(keyargs)
                     key[self.list_pos] = arg

-                    cache.update(
-                        sequence, tuple(key), observer,
+                    cache.set(
+                        tuple(key), observer,
                         callback=invalidate_callback
                     )
@@ -421,17 +482,20 @@ class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
         self.cache.invalidate(self.key)


-def cached(max_entries=1000, num_args=1, tree=False, cache_context=False):
+def cached(max_entries=1000, num_args=1, tree=False, cache_context=False,
+           iterable=False):
     return lambda orig: CacheDescriptor(
         orig,
         max_entries=max_entries,
         num_args=num_args,
         tree=tree,
         cache_context=cache_context,
+        iterable=iterable,
     )


-def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False):
+def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False,
+                          iterable=False):
     return lambda orig: CacheDescriptor(
         orig,
         max_entries=max_entries,
@@ -439,6 +503,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_contex
         tree=tree,
         inlineCallbacks=True,
         cache_context=cache_context,
+        iterable=iterable,
     )
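
Illustration (not part of the diff): how the new iterable flag would be used on a cached method; the store class and query helper here are hypothetical.

from twisted.internet import defer

from synapse.util.caches.descriptors import cachedInlineCallbacks


class MyStore(object):
    # With iterable=True the cache charges each entry len(result)
    # against max_entries, so a handful of huge rooms can't pin the
    # whole cache the way they could under per-key accounting.
    @cachedInlineCallbacks(num_args=1, iterable=True, max_entries=100000)
    def get_users_in_room(self, room_id):
        rows = yield self._fetch_membership_rows(room_id)  # hypothetical
        defer.returnValue([row["user_id"] for row in rows])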


@@ -23,7 +23,9 @@ import logging
 logger = logging.getLogger(__name__)


-DictionaryEntry = namedtuple("DictionaryEntry", ("full", "value"))
+class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "value"))):
+    def __len__(self):
+        return len(self.value)


 class DictionaryCache(object):
@@ -32,7 +34,7 @@ class DictionaryCache(object):
     """

     def __init__(self, name, max_entries=1000):
-        self.cache = LruCache(max_size=max_entries)
+        self.cache = LruCache(max_size=max_entries, size_callback=len)

         self.name = name
         self.sequence = 0


@@ -15,6 +15,7 @@

 from synapse.util.caches import register_cache

+from collections import OrderedDict

 import logging

@@ -23,7 +24,7 @@ logger = logging.getLogger(__name__)

 class ExpiringCache(object):
     def __init__(self, cache_name, clock, max_len=0, expiry_ms=0,
-                 reset_expiry_on_get=False):
+                 reset_expiry_on_get=False, iterable=False):
         """
         Args:
             cache_name (str): Name of this cache, used for logging.
@@ -36,6 +37,8 @@ class ExpiringCache(object):
                 evicted based on time.
             reset_expiry_on_get (bool): If true, will reset the expiry time for
                 an item on access. Defaults to False.
+            iterable (bool): If true, the size is calculated by summing the
+                sizes of all entries, rather than the number of entries.
         """
         self._cache_name = cache_name

@@ -47,9 +50,13 @@ class ExpiringCache(object):

         self._reset_expiry_on_get = reset_expiry_on_get

-        self._cache = {}
+        self._cache = OrderedDict()

-        self.metrics = register_cache(cache_name, self._cache)
+        self.metrics = register_cache(cache_name, self)
+
+        self.iterable = iterable
+
+        self._size_estimate = 0

     def start(self):
         if not self._expiry_ms:
@@ -65,15 +72,14 @@ class ExpiringCache(object):
         now = self._clock.time_msec()
         self._cache[key] = _CacheEntry(now, value)

-        # Evict if there are now too many items
-        if self._max_len and len(self._cache.keys()) > self._max_len:
-            sorted_entries = sorted(
-                self._cache.items(),
-                key=lambda item: item[1].time,
-            )
+        if self.iterable:
+            self._size_estimate += len(value)

-            for k, _ in sorted_entries[self._max_len:]:
-                self._cache.pop(k)
+        # Evict if there are now too many items
+        while self._max_len and len(self) > self._max_len:
+            _key, value = self._cache.popitem(last=False)
+            if self.iterable:
+                self._size_estimate -= len(value.value)

     def __getitem__(self, key):
         try:
@@ -99,7 +105,7 @@ class ExpiringCache(object):
             # zero expiry time means don't expire. This should never get called
            # since we have this check in start too.
             return
-        begin_length = len(self._cache)
+        begin_length = len(self)

         now = self._clock.time_msec()

@@ -110,15 +116,20 @@ class ExpiringCache(object):
                 keys_to_delete.add(key)

         for k in keys_to_delete:
-            self._cache.pop(k)
+            value = self._cache.pop(k)
+            if self.iterable:
+                self._size_estimate -= len(value.value)

         logger.debug(
             "[%s] _prune_cache before: %d, after len: %d",
-            self._cache_name, begin_length, len(self._cache)
+            self._cache_name, begin_length, len(self)
         )

     def __len__(self):
-        return len(self._cache)
+        if self.iterable:
+            return self._size_estimate
+        else:
+            return len(self._cache)


 class _CacheEntry(object):
class _CacheEntry(object): class _CacheEntry(object):


@@ -49,7 +49,7 @@ class LruCache(object):
     Can also set callbacks on objects when getting/setting which are fired
     when that key gets invalidated/evicted.
     """
-    def __init__(self, max_size, keylen=1, cache_type=dict):
+    def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None):
         cache = cache_type()
         self.cache = cache  # Used for introspection.
         list_root = _Node(None, None, None, None)
@@ -58,6 +58,12 @@ class LruCache(object):

         lock = threading.Lock()

+        def evict():
+            while cache_len() > max_size:
+                todelete = list_root.prev_node
+                delete_node(todelete)
+                cache.pop(todelete.key, None)
+
         def synchronized(f):
             @wraps(f)
             def inner(*args, **kwargs):
@@ -66,6 +72,16 @@ class LruCache(object):

             return inner

+        cached_cache_len = [0]
+        if size_callback is not None:
+            def cache_len():
+                return cached_cache_len[0]
+        else:
+            def cache_len():
+                return len(cache)
+
+        self.len = synchronized(cache_len)
+
         def add_node(key, value, callbacks=set()):
             prev_node = list_root
             next_node = prev_node.next_node
@@ -74,6 +90,9 @@ class LruCache(object):
             next_node.prev_node = node
             cache[key] = node

+            if size_callback:
+                cached_cache_len[0] += size_callback(node.value)
+
         def move_node_to_front(node):
             prev_node = node.prev_node
             next_node = node.next_node
@@ -92,23 +111,25 @@ class LruCache(object):
             prev_node.next_node = next_node
             next_node.prev_node = prev_node

+            if size_callback:
+                cached_cache_len[0] -= size_callback(node.value)
+
             for cb in node.callbacks:
                 cb()
             node.callbacks.clear()

         @synchronized
-        def cache_get(key, default=None, callback=None):
+        def cache_get(key, default=None, callbacks=[]):
             node = cache.get(key, None)
             if node is not None:
                 move_node_to_front(node)
-                if callback:
-                    node.callbacks.add(callback)
+                node.callbacks.update(callbacks)
                 return node.value
             else:
                 return default

         @synchronized
-        def cache_set(key, value, callback=None):
+        def cache_set(key, value, callbacks=[]):
             node = cache.get(key, None)
             if node is not None:
                 if value != node.value:
@@ -116,21 +137,18 @@ class LruCache(object):
                         cb()
                     node.callbacks.clear()

-                if callback:
-                    node.callbacks.add(callback)
+                if size_callback:
+                    cached_cache_len[0] -= size_callback(node.value)
+                    cached_cache_len[0] += size_callback(value)
+
+                node.callbacks.update(callbacks)

                 move_node_to_front(node)
                 node.value = value
             else:
-                if callback:
-                    callbacks = set([callback])
-                else:
-                    callbacks = set()
-                add_node(key, value, callbacks)
-
-                if len(cache) > max_size:
-                    todelete = list_root.prev_node
-                    delete_node(todelete)
-                    cache.pop(todelete.key, None)
+                add_node(key, value, set(callbacks))
+
+                evict()

         @synchronized
         def cache_set_default(key, value):
@@ -139,10 +157,7 @@ class LruCache(object):
                 return node.value
             else:
                 add_node(key, value)
-                if len(cache) > max_size:
-                    todelete = list_root.prev_node
-                    delete_node(todelete)
-                    cache.pop(todelete.key, None)
+                evict()
                 return value

         @synchronized
@@ -174,10 +189,8 @@ class LruCache(object):
             for cb in node.callbacks:
                 cb()
             cache.clear()
-
-        @synchronized
-        def cache_len():
-            return len(cache)
+            if size_callback:
+                cached_cache_len[0] = 0

         @synchronized
         def cache_contains(key):
@@ -190,7 +203,7 @@ class LruCache(object):
         self.pop = cache_pop
         if cache_type is TreeCache:
             self.del_multi = cache_del_multi
-        self.len = cache_len
+        self.len = synchronized(cache_len)
         self.contains = cache_contains
         self.clear = cache_clear
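
Illustration (not part of the diff): with size_callback=len the cache is sized by total contained items rather than key count; a short sketch of the resulting eviction behaviour.

from synapse.util.caches.lrucache import LruCache

# Sized by contents: at most five items in total across all values.
cache = LruCache(max_size=5, size_callback=len)

cache.set("a", [1, 2])
cache.set("b", [3, 4, 5])
print(cache.len())          # 5: two keys, five contained items

# Going over budget evicts whole entries from the LRU end until the
# total fits again, so "a" (the oldest entry) is dropped.
cache.set("c", [6])
print(cache.get("a", "evicted"))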


@@ -65,12 +65,27 @@ class TreeCache(object):
         return popped

     def values(self):
-        return [e.value for e in self.root.values()]
+        return list(iterate_tree_cache_entry(self.root))

     def __len__(self):
         return self.size


+def iterate_tree_cache_entry(d):
+    """Helper function to iterate over the leaves of a tree, i.e. a dict that
+    can contain dicts.
+    """
+    if isinstance(d, dict):
+        for value_d in d.itervalues():
+            for value in iterate_tree_cache_entry(value_d):
+                yield value
+    else:
+        if isinstance(d, _Entry):
+            yield d.value
+        else:
+            yield d
+
+
 class _Entry(object):
     __slots__ = ["value"]


@@ -1,71 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from twisted.internet import defer, reactor
-from functools import wraps
-from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
-
-
-def debug_deferreds():
-    """Cause all deferreds to wait for a reactor tick before running their
-    callbacks. This increases the chance of getting a stack trace out of
-    a defer.inlineCallback since the code waiting on the deferred will get
-    a chance to add an errback before the deferred runs."""
-
-    # Helper method for retrieving and restoring the current logging context
-    # around a callback.
-    def with_logging_context(fn):
-        context = LoggingContext.current_context()
-
-        def restore_context_callback(x):
-            with PreserveLoggingContext(context):
-                return fn(x)
-
-        return restore_context_callback
-
-    # We are going to modify the __init__ method of defer.Deferred so we
-    # need to get a copy of the old method so we can still call it.
-    old__init__ = defer.Deferred.__init__
-
-    # We need to create a deferred to bounce the callbacks through the reactor
-    # but we don't want to add a callback when we create that deferred so we
-    # we create a new type of deferred that uses the old __init__ method.
-    # This is safe as long as the old __init__ method doesn't invoke an
-    # __init__ using super.
-    class Bouncer(defer.Deferred):
-        __init__ = old__init__
-
-    # We'll add this as a callback to all Deferreds. Twisted will wait until
-    # the bouncer deferred resolves before calling the callbacks of the
-    # original deferred.
-    def bounce_callback(x):
-        bouncer = Bouncer()
-        reactor.callLater(0, with_logging_context(bouncer.callback), x)
-        return bouncer
-
-    # We'll add this as an errback to all Deferreds. Twisted will wait until
-    # the bouncer deferred resolves before calling the errbacks of the
-    # original deferred.
-    def bounce_errback(x):
-        bouncer = Bouncer()
-        reactor.callLater(0, with_logging_context(bouncer.errback), x)
-        return bouncer
-
-    @wraps(old__init__)
-    def new__init__(self, *args, **kargs):
-        old__init__(self, *args, **kargs)
-        self.addCallbacks(bounce_callback, bounce_errback)
-
-    defer.Deferred.__init__ = new__init__


@@ -25,10 +25,13 @@ from synapse.api.filtering import Filter
 from synapse.events import FrozenEvent

 user_localpart = "test_user"
-# MockEvent = namedtuple("MockEvent", "sender type room_id")


 def MockEvent(**kwargs):
+    if "event_id" not in kwargs:
+        kwargs["event_id"] = "fake_event_id"
+    if "type" not in kwargs:
+        kwargs["type"] = "fake_type"
     return FrozenEvent(kwargs)


@@ -21,6 +21,10 @@ from synapse.events.utils import prune_event, serialize_event


 def MockEvent(**kwargs):
+    if "event_id" not in kwargs:
+        kwargs["event_id"] = "fake_event_id"
+    if "type" not in kwargs:
+        kwargs["type"] = "fake_type"
     return FrozenEvent(kwargs)

@@ -35,9 +39,13 @@ class PruneEventTestCase(unittest.TestCase):

     def test_minimal(self):
         self.run_test(
-            {'type': 'A'},
             {
                 'type': 'A',
+                'event_id': '$test:domain',
+            },
+            {
+                'type': 'A',
+                'event_id': '$test:domain',
                 'content': {},
                 'signatures': {},
                 'unsigned': {},
@@ -69,10 +77,12 @@ class PruneEventTestCase(unittest.TestCase):
         self.run_test(
             {
                 'type': 'B',
+                'event_id': '$test:domain',
                 'unsigned': {'age_ts': 20},
             },
             {
                 'type': 'B',
+                'event_id': '$test:domain',
                 'content': {},
                 'signatures': {},
                 'unsigned': {'age_ts': 20},
@@ -82,10 +92,12 @@ class PruneEventTestCase(unittest.TestCase):
         self.run_test(
             {
                 'type': 'B',
+                'event_id': '$test:domain',
                 'unsigned': {'other_key': 'here'},
             },
             {
                 'type': 'B',
+                'event_id': '$test:domain',
                 'content': {},
                 'signatures': {},
                 'unsigned': {},
@@ -96,10 +108,12 @@ class PruneEventTestCase(unittest.TestCase):
         self.run_test(
             {
                 'type': 'C',
+                'event_id': '$test:domain',
                 'content': {'things': 'here'},
             },
             {
                 'type': 'C',
+                'event_id': '$test:domain',
                 'content': {},
                 'signatures': {},
                 'unsigned': {},
@@ -109,10 +123,12 @@ class PruneEventTestCase(unittest.TestCase):
         self.run_test(
             {
                 'type': 'm.room.create',
+                'event_id': '$test:domain',
                 'content': {'creator': '@2:domain', 'other_field': 'here'},
             },
             {
                 'type': 'm.room.create',
+                'event_id': '$test:domain',
                 'content': {'creator': '@2:domain'},
                 'signatures': {},
                 'unsigned': {},
@@ -255,6 +271,8 @@ class SerializeEventTestCase(unittest.TestCase):
         self.assertEquals(
             self.serialize(
                 MockEvent(
+                    type="foo",
+                    event_id="test",
                     room_id="!foo:bar",
                     content={
                         "foo": "bar",
@@ -263,6 +281,8 @@ class SerializeEventTestCase(unittest.TestCase):
                 []
             ),
             {
+                "type": "foo",
+                "event_id": "test",
                 "room_id": "!foo:bar",
                 "content": {
                     "foo": "bar",


@@ -241,7 +241,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
         callcount2 = [0]

         class A(object):
-            @cached(max_entries=2)
+            @cached(max_entries=20)  # HACK: This makes it 2 due to cache factor
             def func(self, key):
                 callcount[0] += 1
                 return key
@@ -258,6 +258,10 @@ class CacheDecoratorTestCase(unittest.TestCase):
         self.assertEquals(callcount[0], 2)
         self.assertEquals(callcount2[0], 2)

+        yield a.func2("foo")
+
+        self.assertEquals(callcount[0], 2)
+        self.assertEquals(callcount2[0], 2)
+
         yield a.func("foo3")

         self.assertEquals(callcount[0], 3)


@@ -0,0 +1,84 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .. import unittest
+
+from synapse.util.caches.expiringcache import ExpiringCache
+
+from tests.utils import MockClock
+
+
+class ExpiringCacheTestCase(unittest.TestCase):
+    def test_get_set(self):
+        clock = MockClock()
+        cache = ExpiringCache("test", clock, max_len=1)
+
+        cache["key"] = "value"
+        self.assertEquals(cache.get("key"), "value")
+        self.assertEquals(cache["key"], "value")
+
+    def test_eviction(self):
+        clock = MockClock()
+        cache = ExpiringCache("test", clock, max_len=2)
+
+        cache["key"] = "value"
+        cache["key2"] = "value2"
+        self.assertEquals(cache.get("key"), "value")
+        self.assertEquals(cache.get("key2"), "value2")
+
+        cache["key3"] = "value3"
+        self.assertEquals(cache.get("key"), None)
+        self.assertEquals(cache.get("key2"), "value2")
+        self.assertEquals(cache.get("key3"), "value3")
+
+    def test_iterable_eviction(self):
+        clock = MockClock()
+        cache = ExpiringCache("test", clock, max_len=5, iterable=True)
+
+        cache["key"] = [1]
+        cache["key2"] = [2, 3]
+        cache["key3"] = [4, 5]
+
+        self.assertEquals(cache.get("key"), [1])
+        self.assertEquals(cache.get("key2"), [2, 3])
+        self.assertEquals(cache.get("key3"), [4, 5])
+
+        cache["key4"] = [6, 7]
+        self.assertEquals(cache.get("key"), None)
+        self.assertEquals(cache.get("key2"), None)
+        self.assertEquals(cache.get("key3"), [4, 5])
+        self.assertEquals(cache.get("key4"), [6, 7])
+
+    def test_time_eviction(self):
+        clock = MockClock()
+        cache = ExpiringCache("test", clock, expiry_ms=1000)
+        cache.start()
+
+        cache["key"] = 1
+        clock.advance_time(0.5)
+        cache["key2"] = 2
+
+        self.assertEquals(cache.get("key"), 1)
+        self.assertEquals(cache.get("key2"), 2)
+
+        clock.advance_time(0.9)
+        self.assertEquals(cache.get("key"), None)
+        self.assertEquals(cache.get("key2"), 2)
+
+        clock.advance_time(1)
+        self.assertEquals(cache.get("key"), None)
+        self.assertEquals(cache.get("key2"), None)
View file
@@ -93,7 +93,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
         cache.set("key", "value")
         self.assertFalse(m.called)

-        cache.get("key", callback=m)
+        cache.get("key", callbacks=[m])
         self.assertFalse(m.called)

         cache.get("key", "value")
@@ -112,10 +112,10 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
         cache.set("key", "value")
         self.assertFalse(m.called)

-        cache.get("key", callback=m)
+        cache.get("key", callbacks=[m])
         self.assertFalse(m.called)

-        cache.get("key", callback=m)
+        cache.get("key", callbacks=[m])
         self.assertFalse(m.called)

         cache.set("key", "value2")
@ -128,7 +128,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
m = Mock() m = Mock()
cache = LruCache(1) cache = LruCache(1)
cache.set("key", "value", m) cache.set("key", "value", callbacks=[m])
self.assertFalse(m.called) self.assertFalse(m.called)
cache.set("key", "value") cache.set("key", "value")
@@ -144,7 +144,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
         m = Mock()
         cache = LruCache(1)

-        cache.set("key", "value", m)
+        cache.set("key", "value", callbacks=[m])
         self.assertFalse(m.called)

         cache.pop("key")
@ -163,10 +163,10 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
m4 = Mock() m4 = Mock()
cache = LruCache(4, 2, cache_type=TreeCache) cache = LruCache(4, 2, cache_type=TreeCache)
cache.set(("a", "1"), "value", m1) cache.set(("a", "1"), "value", callbacks=[m1])
cache.set(("a", "2"), "value", m2) cache.set(("a", "2"), "value", callbacks=[m2])
cache.set(("b", "1"), "value", m3) cache.set(("b", "1"), "value", callbacks=[m3])
cache.set(("b", "2"), "value", m4) cache.set(("b", "2"), "value", callbacks=[m4])
self.assertEquals(m1.call_count, 0) self.assertEquals(m1.call_count, 0)
self.assertEquals(m2.call_count, 0) self.assertEquals(m2.call_count, 0)
@ -185,8 +185,8 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
m2 = Mock() m2 = Mock()
cache = LruCache(5) cache = LruCache(5)
cache.set("key1", "value", m1) cache.set("key1", "value", callbacks=[m1])
cache.set("key2", "value", m2) cache.set("key2", "value", callbacks=[m2])
self.assertEquals(m1.call_count, 0) self.assertEquals(m1.call_count, 0)
self.assertEquals(m2.call_count, 0) self.assertEquals(m2.call_count, 0)
@ -202,14 +202,14 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
m3 = Mock(name="m3") m3 = Mock(name="m3")
cache = LruCache(2) cache = LruCache(2)
cache.set("key1", "value", m1) cache.set("key1", "value", callbacks=[m1])
cache.set("key2", "value", m2) cache.set("key2", "value", callbacks=[m2])
self.assertEquals(m1.call_count, 0) self.assertEquals(m1.call_count, 0)
self.assertEquals(m2.call_count, 0) self.assertEquals(m2.call_count, 0)
self.assertEquals(m3.call_count, 0) self.assertEquals(m3.call_count, 0)
cache.set("key3", "value", m3) cache.set("key3", "value", callbacks=[m3])
self.assertEquals(m1.call_count, 1) self.assertEquals(m1.call_count, 1)
self.assertEquals(m2.call_count, 0) self.assertEquals(m2.call_count, 0)
@@ -227,8 +227,33 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
         self.assertEquals(m2.call_count, 0)
         self.assertEquals(m3.call_count, 0)

-        cache.set("key1", "value", m1)
+        cache.set("key1", "value", callbacks=[m1])

         self.assertEquals(m1.call_count, 1)
         self.assertEquals(m2.call_count, 0)
         self.assertEquals(m3.call_count, 1)
+
+
+class LruCacheSizedTestCase(unittest.TestCase):
+    def test_evict(self):
+        cache = LruCache(5, size_callback=len)
+
+        cache["key1"] = [0]
+        cache["key2"] = [1, 2]
+        cache["key3"] = [3]
+        cache["key4"] = [4]
+
+        self.assertEquals(cache["key1"], [0])
+        self.assertEquals(cache["key2"], [1, 2])
+        self.assertEquals(cache["key3"], [3])
+        self.assertEquals(cache["key4"], [4])
+        self.assertEquals(len(cache), 5)
+
+        cache["key5"] = [5, 6]
+
+        self.assertEquals(len(cache), 4)
+        self.assertEquals(cache.get("key1"), None)
+        self.assertEquals(cache.get("key2"), None)
+        self.assertEquals(cache["key3"], [3])
+        self.assertEquals(cache["key4"], [4])
+        self.assertEquals(cache["key5"], [5, 6])