Merge remote-tracking branch 'origin/develop' into markjh/direct_to_device

This commit is contained in:
Mark Haines 2016-08-26 14:35:31 +01:00
commit 4bbef62124
27 changed files with 866 additions and 450 deletions

View File

@ -52,7 +52,7 @@ class Auth(object):
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.TOKEN_NOT_FOUND_HTTP_STATUS = 401 self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
# Docs for these currently lives at # Docs for these currently lives at
# https://github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst # github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst
# In addition, we have type == delete_pusher which grants access only to # In addition, we have type == delete_pusher which grants access only to
# delete pushers. # delete pushers.
self._KNOWN_CAVEAT_PREFIXES = set([ self._KNOWN_CAVEAT_PREFIXES = set([
@ -63,6 +63,17 @@ class Auth(object):
"user_id = ", "user_id = ",
]) ])
@defer.inlineCallbacks
def check_from_context(self, event, context, do_sig_check=True):
auth_events_ids = yield self.compute_auth_events(
event, context.current_state_ids, for_verification=True,
)
auth_events = yield self.store.get_events(auth_events_ids)
auth_events = {
(e.type, e.state_key): e for e in auth_events.values()
}
self.check(event, auth_events=auth_events, do_sig_check=False)
def check(self, event, auth_events, do_sig_check=True): def check(self, event, auth_events, do_sig_check=True):
""" Checks if this event is correctly authed. """ Checks if this event is correctly authed.
@ -267,21 +278,15 @@ class Auth(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def check_host_in_room(self, room_id, host): def check_host_in_room(self, room_id, host):
curr_state = yield self.state.get_current_state(room_id) with Measure(self.clock, "check_host_in_room"):
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
for event in curr_state.values(): group, curr_state_ids = yield self.state.resolve_state_groups(
if event.type == EventTypes.Member: room_id, latest_event_ids
try: )
if get_domain_from_id(event.state_key) != host:
continue
except:
logger.warn("state_key not user_id: %s", event.state_key)
continue
if event.content["membership"] == Membership.JOIN: ret = yield self.store.is_host_joined(room_id, host, group, curr_state_ids)
defer.returnValue(True) defer.returnValue(ret)
defer.returnValue(False)
def check_event_sender_in_room(self, event, auth_events): def check_event_sender_in_room(self, event, auth_events):
key = (EventTypes.Member, event.user_id, ) key = (EventTypes.Member, event.user_id, )
@ -847,7 +852,7 @@ class Auth(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def add_auth_events(self, builder, context): def add_auth_events(self, builder, context):
auth_ids = self.compute_auth_events(builder, context.current_state) auth_ids = yield self.compute_auth_events(builder, context.current_state_ids)
auth_events_entries = yield self.store.add_event_hashes( auth_events_entries = yield self.store.add_event_hashes(
auth_ids auth_ids
@ -855,30 +860,32 @@ class Auth(object):
builder.auth_events = auth_events_entries builder.auth_events = auth_events_entries
def compute_auth_events(self, event, current_state): @defer.inlineCallbacks
def compute_auth_events(self, event, current_state_ids, for_verification=False):
if event.type == EventTypes.Create: if event.type == EventTypes.Create:
return [] defer.returnValue([])
auth_ids = [] auth_ids = []
key = (EventTypes.PowerLevels, "", ) key = (EventTypes.PowerLevels, "", )
power_level_event = current_state.get(key) power_level_event_id = current_state_ids.get(key)
if power_level_event: if power_level_event_id:
auth_ids.append(power_level_event.event_id) auth_ids.append(power_level_event_id)
key = (EventTypes.JoinRules, "", ) key = (EventTypes.JoinRules, "", )
join_rule_event = current_state.get(key) join_rule_event_id = current_state_ids.get(key)
key = (EventTypes.Member, event.user_id, ) key = (EventTypes.Member, event.user_id, )
member_event = current_state.get(key) member_event_id = current_state_ids.get(key)
key = (EventTypes.Create, "", ) key = (EventTypes.Create, "", )
create_event = current_state.get(key) create_event_id = current_state_ids.get(key)
if create_event: if create_event_id:
auth_ids.append(create_event.event_id) auth_ids.append(create_event_id)
if join_rule_event: if join_rule_event_id:
join_rule_event = yield self.store.get_event(join_rule_event_id)
join_rule = join_rule_event.content.get("join_rule") join_rule = join_rule_event.content.get("join_rule")
is_public = join_rule == JoinRules.PUBLIC if join_rule else False is_public = join_rule == JoinRules.PUBLIC if join_rule else False
else: else:
@ -887,15 +894,21 @@ class Auth(object):
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
e_type = event.content["membership"] e_type = event.content["membership"]
if e_type in [Membership.JOIN, Membership.INVITE]: if e_type in [Membership.JOIN, Membership.INVITE]:
if join_rule_event: if join_rule_event_id:
auth_ids.append(join_rule_event.event_id) auth_ids.append(join_rule_event_id)
if e_type == Membership.JOIN: if e_type == Membership.JOIN:
if member_event and not is_public: if member_event_id and not is_public:
auth_ids.append(member_event.event_id) auth_ids.append(member_event_id)
else: else:
if member_event: if member_event_id:
auth_ids.append(member_event.event_id) auth_ids.append(member_event_id)
if for_verification:
key = (EventTypes.Member, event.state_key, )
existing_event_id = current_state_ids.get(key)
if existing_event_id:
auth_ids.append(existing_event_id)
if e_type == Membership.INVITE: if e_type == Membership.INVITE:
if "third_party_invite" in event.content: if "third_party_invite" in event.content:
@ -903,14 +916,15 @@ class Auth(object):
EventTypes.ThirdPartyInvite, EventTypes.ThirdPartyInvite,
event.content["third_party_invite"]["signed"]["token"] event.content["third_party_invite"]["signed"]["token"]
) )
third_party_invite = current_state.get(key) third_party_invite_id = current_state_ids.get(key)
if third_party_invite: if third_party_invite_id:
auth_ids.append(third_party_invite.event_id) auth_ids.append(third_party_invite_id)
elif member_event: elif member_event_id:
member_event = yield self.store.get_event(member_event_id)
if member_event.content["membership"] == Membership.JOIN: if member_event.content["membership"] == Membership.JOIN:
auth_ids.append(member_event.event_id) auth_ids.append(member_event.event_id)
return auth_ids defer.returnValue(auth_ids)
def _get_send_level(self, etype, state_key, auth_events): def _get_send_level(self, etype, state_key, auth_events):
key = (EventTypes.PowerLevels, "", ) key = (EventTypes.PowerLevels, "", )

View File

@ -85,3 +85,8 @@ class RoomCreationPreset(object):
PRIVATE_CHAT = "private_chat" PRIVATE_CHAT = "private_chat"
PUBLIC_CHAT = "public_chat" PUBLIC_CHAT = "public_chat"
TRUSTED_PRIVATE_CHAT = "trusted_private_chat" TRUSTED_PRIVATE_CHAT = "trusted_private_chat"
class ThirdPartyEntityKind(object):
USER = "user"
LOCATION = "location"

View File

@ -25,4 +25,3 @@ SERVER_KEY_PREFIX = "/_matrix/key/v1"
SERVER_KEY_V2_PREFIX = "/_matrix/key/v2" SERVER_KEY_V2_PREFIX = "/_matrix/key/v2"
MEDIA_PREFIX = "/_matrix/media/r0" MEDIA_PREFIX = "/_matrix/media/r0"
LEGACY_MEDIA_PREFIX = "/_matrix/media/v1" LEGACY_MEDIA_PREFIX = "/_matrix/media/v1"
APP_SERVICE_PREFIX = "/_matrix/appservice/v1"

View File

@ -14,11 +14,11 @@
# limitations under the License. # limitations under the License.
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import ThirdPartyEntityKind
from synapse.api.errors import CodeMessageException from synapse.api.errors import CodeMessageException
from synapse.http.client import SimpleHttpClient from synapse.http.client import SimpleHttpClient
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.util.caches.response_cache import ResponseCache from synapse.util.caches.response_cache import ResponseCache
from synapse.types import ThirdPartyEntityKind
import logging import logging
import urllib import urllib
@ -29,6 +29,9 @@ logger = logging.getLogger(__name__)
HOUR_IN_MS = 60 * 60 * 1000 HOUR_IN_MS = 60 * 60 * 1000
APP_SERVICE_PREFIX = "/_matrix/app/unstable"
def _is_valid_3pe_result(r, field): def _is_valid_3pe_result(r, field):
if not isinstance(r, dict): if not isinstance(r, dict):
return False return False
@ -103,16 +106,20 @@ class ApplicationServiceApi(SimpleHttpClient):
@defer.inlineCallbacks @defer.inlineCallbacks
def query_3pe(self, service, kind, protocol, fields): def query_3pe(self, service, kind, protocol, fields):
if kind == ThirdPartyEntityKind.USER: if kind == ThirdPartyEntityKind.USER:
uri = "%s/thirdparty/user/%s" % (service.url, urllib.quote(protocol))
required_field = "userid" required_field = "userid"
elif kind == ThirdPartyEntityKind.LOCATION: elif kind == ThirdPartyEntityKind.LOCATION:
uri = "%s/thirdparty/location/%s" % (service.url, urllib.quote(protocol))
required_field = "alias" required_field = "alias"
else: else:
raise ValueError( raise ValueError(
"Unrecognised 'kind' argument %r to query_3pe()", kind "Unrecognised 'kind' argument %r to query_3pe()", kind
) )
uri = "%s%s/thirdparty/%s/%s" % (
service.url,
APP_SERVICE_PREFIX,
kind,
urllib.quote(protocol)
)
try: try:
response = yield self.get_json(uri, fields) response = yield self.get_json(uri, fields)
if not isinstance(response, list): if not isinstance(response, list):
@ -140,7 +147,11 @@ class ApplicationServiceApi(SimpleHttpClient):
def get_3pe_protocol(self, service, protocol): def get_3pe_protocol(self, service, protocol):
@defer.inlineCallbacks @defer.inlineCallbacks
def _get(): def _get():
uri = "%s/thirdparty/protocol/%s" % (service.url, urllib.quote(protocol)) uri = "%s%s/thirdparty/protocol/%s" % (
service.url,
APP_SERVICE_PREFIX,
urllib.quote(protocol)
)
try: try:
defer.returnValue((yield self.get_json(uri, {}))) defer.returnValue((yield self.get_json(uri, {})))
except Exception as ex: except Exception as ex:

View File

@ -99,7 +99,7 @@ class EventBase(object):
return d return d
def get(self, key, default): def get(self, key, default=None):
return self._event_dict.get(key, default) return self._event_dict.get(key, default)
def get_internal_metadata_dict(self): def get_internal_metadata_dict(self):

View File

@ -15,9 +15,8 @@
class EventContext(object): class EventContext(object):
def __init__(self, current_state_ids=None):
def __init__(self, current_state=None): self.current_state_ids = current_state_ids
self.current_state = current_state
self.state_group = None self.state_group = None
self.rejected = False self.rejected = False
self.push_actions = [] self.push_actions = []

View File

@ -65,33 +65,21 @@ class BaseHandler(object):
retry_after_ms=int(1000 * (time_allowed - time_now)), retry_after_ms=int(1000 * (time_allowed - time_now)),
) )
def is_host_in_room(self, current_state):
room_members = [
(state_key, event.membership)
for ((event_type, state_key), event) in current_state.items()
if event_type == EventTypes.Member
]
if len(room_members) == 0:
# Have we just created the room, and is this about to be the very
# first member event?
create_event = current_state.get(("m.room.create", ""))
if create_event:
return True
for (state_key, membership) in room_members:
if (
self.hs.is_mine_id(state_key)
and membership == Membership.JOIN
):
return True
return False
@defer.inlineCallbacks @defer.inlineCallbacks
def maybe_kick_guest_users(self, event, current_state): def maybe_kick_guest_users(self, event, context=None):
# Technically this function invalidates current_state by changing it. # Technically this function invalidates current_state by changing it.
# Hopefully this isn't that important to the caller. # Hopefully this isn't that important to the caller.
if event.type == EventTypes.GuestAccess: if event.type == EventTypes.GuestAccess:
guest_access = event.content.get("guest_access", "forbidden") guest_access = event.content.get("guest_access", "forbidden")
if guest_access != "can_join": if guest_access != "can_join":
if context:
current_state = yield self.store.get_events(
context.current_state_ids.values()
)
current_state = current_state.values()
else:
current_state = yield self.store.get_current_state(event.room_id)
logger.info("maybe_kick_guest_users %r", current_state)
yield self.kick_guest_users(current_state) yield self.kick_guest_users(current_state)
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -29,6 +29,7 @@ from synapse.util import unwrapFirstError
from synapse.util.logcontext import ( from synapse.util.logcontext import (
PreserveLoggingContext, preserve_fn, preserve_context_over_deferred PreserveLoggingContext, preserve_fn, preserve_context_over_deferred
) )
from synapse.util.metrics import measure_func
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.util.frozenutils import unfreeze from synapse.util.frozenutils import unfreeze
@ -217,17 +218,28 @@ class FederationHandler(BaseHandler):
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
prev_state = context.current_state.get((event.type, event.state_key))
if not prev_state or prev_state.membership != Membership.JOIN:
# Only fire user_joined_room if the user has acutally # Only fire user_joined_room if the user has acutally
# joined the room. Don't bother if the user is just # joined the room. Don't bother if the user is just
# changing their profile info. # changing their profile info.
newly_joined = True
prev_state_id = context.current_state_ids.get(
(event.type, event.state_key)
)
if prev_state_id:
prev_state = yield self.store.get_event(
prev_state_id, allow_none=True,
)
if prev_state and prev_state.membership == Membership.JOIN:
newly_joined = False
if newly_joined:
user = UserID.from_string(event.state_key) user = UserID.from_string(event.state_key)
yield user_joined_room(self.distributor, user, event.room_id) yield user_joined_room(self.distributor, user, event.room_id)
@measure_func("_filter_events_for_server")
@defer.inlineCallbacks @defer.inlineCallbacks
def _filter_events_for_server(self, server_name, room_id, events): def _filter_events_for_server(self, server_name, room_id, events):
event_to_state = yield self.store.get_state_for_events( event_to_state_ids = yield self.store.get_state_ids_for_events(
frozenset(e.event_id for e in events), frozenset(e.event_id for e in events),
types=( types=(
(EventTypes.RoomHistoryVisibility, ""), (EventTypes.RoomHistoryVisibility, ""),
@ -235,6 +247,30 @@ class FederationHandler(BaseHandler):
) )
) )
# We only want to pull out member events that correspond to the
# server's domain.
def check_match(id):
try:
return server_name == get_domain_from_id(id)
except:
return False
event_map = yield self.store.get_events([
e_id for key_to_eid in event_to_state_ids.values()
for key, e_id in key_to_eid
if key[0] != EventTypes.Member or check_match(key[1])
])
event_to_state = {
e_id: {
key: event_map[inner_e_id]
for key, inner_e_id in key_to_eid.items()
if inner_e_id in event_map
}
for e_id, key_to_eid in event_to_state_ids.items()
}
def redact_disallowed(event, state): def redact_disallowed(event, state):
if not state: if not state:
return event return event
@ -377,7 +413,9 @@ class FederationHandler(BaseHandler):
)).addErrback(unwrapFirstError) )).addErrback(unwrapFirstError)
auth_events.update({a.event_id: a for a in results if a}) auth_events.update({a.event_id: a for a in results if a})
required_auth.update( required_auth.update(
a_id for event in results for a_id, _ in event.auth_events if event a_id
for event in results if event
for a_id, _ in event.auth_events
) )
missing_auth = required_auth - set(auth_events) missing_auth = required_auth - set(auth_events)
@ -560,6 +598,18 @@ class FederationHandler(BaseHandler):
])) ]))
states = dict(zip(event_ids, [s[1] for s in states])) states = dict(zip(event_ids, [s[1] for s in states]))
state_map = yield self.store.get_events(
[e_id for ids in states.values() for e_id in ids],
get_prev_content=False
)
states = {
key: {
k: state_map[e_id]
for k, e_id in state_dict.items()
if e_id in state_map
} for key, state_dict in states.items()
}
for e_id, _ in sorted_extremeties_tuple: for e_id, _ in sorted_extremeties_tuple:
likely_domains = get_domains_from_state(states[e_id]) likely_domains = get_domains_from_state(states[e_id])
@ -722,7 +772,7 @@ class FederationHandler(BaseHandler):
# The remote hasn't signed it yet, obviously. We'll do the full checks # The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_join_request` # when we get the event back in `on_send_join_request`
self.auth.check(event, auth_events=context.current_state, do_sig_check=False) yield self.auth.check_from_context(event, context, do_sig_check=False)
defer.returnValue(event) defer.returnValue(event)
@ -770,18 +820,11 @@ class FederationHandler(BaseHandler):
new_pdu = event new_pdu = event
destinations = set() message_handler = self.hs.get_handlers().message_handler
destinations = yield message_handler.get_joined_hosts_for_room_from_state(
for k, s in context.current_state.items(): context
try:
if k[0] == EventTypes.Member:
if s.content["membership"] == Membership.JOIN:
destinations.add(get_domain_from_id(s.state_key))
except:
logger.warn(
"Failed to get destination from event %s", s.event_id
) )
destinations = set(destinations)
destinations.discard(origin) destinations.discard(origin)
logger.debug( logger.debug(
@ -792,13 +835,15 @@ class FederationHandler(BaseHandler):
self.replication_layer.send_pdu(new_pdu, destinations) self.replication_layer.send_pdu(new_pdu, destinations)
state_ids = [e.event_id for e in context.current_state.values()] state_ids = context.current_state_ids.values()
auth_chain = yield self.store.get_auth_chain(set( auth_chain = yield self.store.get_auth_chain(set(
[event.event_id] + state_ids [event.event_id] + state_ids
)) ))
state = yield self.store.get_events(context.current_state_ids.values())
defer.returnValue({ defer.returnValue({
"state": context.current_state.values(), "state": state.values(),
"auth_chain": auth_chain, "auth_chain": auth_chain,
}) })
@ -954,7 +999,7 @@ class FederationHandler(BaseHandler):
try: try:
# The remote hasn't signed it yet, obviously. We'll do the full checks # The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_leave_request` # when we get the event back in `on_send_leave_request`
self.auth.check(event, auth_events=context.current_state, do_sig_check=False) yield self.auth.check_from_context(event, context, do_sig_check=False)
except AuthError as e: except AuthError as e:
logger.warn("Failed to create new leave %r because %s", event, e) logger.warn("Failed to create new leave %r because %s", event, e)
raise e raise e
@ -998,18 +1043,11 @@ class FederationHandler(BaseHandler):
new_pdu = event new_pdu = event
destinations = set() message_handler = self.hs.get_handlers().message_handler
destinations = yield message_handler.get_joined_hosts_for_room_from_state(
for k, s in context.current_state.items(): context
try:
if k[0] == EventTypes.Member:
if s.content["membership"] == Membership.LEAVE:
destinations.add(get_domain_from_id(s.state_key))
except:
logger.warn(
"Failed to get destination from event %s", s.event_id
) )
destinations = set(destinations)
destinations.discard(origin) destinations.discard(origin)
logger.debug( logger.debug(
@ -1294,7 +1332,13 @@ class FederationHandler(BaseHandler):
) )
if not auth_events: if not auth_events:
auth_events = context.current_state auth_events_ids = yield self.auth.compute_auth_events(
event, context.current_state_ids, for_verification=True,
)
auth_events = yield self.store.get_events(auth_events_ids)
auth_events = {
(e.type, e.state_key): e for e in auth_events.values()
}
# This is a hack to fix some old rooms where the initial join event # This is a hack to fix some old rooms where the initial join event
# didn't reference the create event in its auth events. # didn't reference the create event in its auth events.
@ -1320,8 +1364,7 @@ class FederationHandler(BaseHandler):
context.rejected = RejectedReason.AUTH_ERROR context.rejected = RejectedReason.AUTH_ERROR
if event.type == EventTypes.GuestAccess: if event.type == EventTypes.GuestAccess:
full_context = yield self.store.get_current_state(room_id=event.room_id) yield self.maybe_kick_guest_users(event)
yield self.maybe_kick_guest_users(event, full_context)
defer.returnValue(context) defer.returnValue(context)
@ -1492,7 +1535,9 @@ class FederationHandler(BaseHandler):
current_state = set(e.event_id for e in auth_events.values()) current_state = set(e.event_id for e in auth_events.values())
different_auth = event_auth_events - current_state different_auth = event_auth_events - current_state
context.current_state.update(auth_events) context.current_state_ids.update({
k: a.event_id for k, a in auth_events.items()
})
context.state_group = None context.state_group = None
if different_auth and not event.internal_metadata.is_outlier(): if different_auth and not event.internal_metadata.is_outlier():
@ -1514,8 +1559,8 @@ class FederationHandler(BaseHandler):
if do_resolution: if do_resolution:
# 1. Get what we think is the auth chain. # 1. Get what we think is the auth chain.
auth_ids = self.auth.compute_auth_events( auth_ids = yield self.auth.compute_auth_events(
event, context.current_state event, context.current_state_ids
) )
local_auth_chain = yield self.store.get_auth_chain(auth_ids) local_auth_chain = yield self.store.get_auth_chain(auth_ids)
@ -1571,7 +1616,9 @@ class FederationHandler(BaseHandler):
# 4. Look at rejects and their proofs. # 4. Look at rejects and their proofs.
# TODO. # TODO.
context.current_state.update(auth_events) context.current_state_ids.update({
k: a.event_id for k, a in auth_events.items()
})
context.state_group = None context.state_group = None
try: try:
@ -1758,12 +1805,12 @@ class FederationHandler(BaseHandler):
) )
try: try:
self.auth.check(event, context.current_state) yield self.auth.check_from_context(event, context)
except AuthError as e: except AuthError as e:
logger.warn("Denying new third party invite %r because %s", event, e) logger.warn("Denying new third party invite %r because %s", event, e)
raise e raise e
yield self._check_signature(event, auth_events=context.current_state) yield self._check_signature(event, context)
member_handler = self.hs.get_handlers().room_member_handler member_handler = self.hs.get_handlers().room_member_handler
yield member_handler.send_membership_event(None, event, context) yield member_handler.send_membership_event(None, event, context)
else: else:
@ -1789,11 +1836,11 @@ class FederationHandler(BaseHandler):
) )
try: try:
self.auth.check(event, auth_events=context.current_state) self.auth.check_from_context(event, context)
except AuthError as e: except AuthError as e:
logger.warn("Denying third party invite %r because %s", event, e) logger.warn("Denying third party invite %r because %s", event, e)
raise e raise e
yield self._check_signature(event, auth_events=context.current_state) yield self._check_signature(event, context)
returned_invite = yield self.send_invite(origin, event) returned_invite = yield self.send_invite(origin, event)
# TODO: Make sure the signatures actually are correct. # TODO: Make sure the signatures actually are correct.
@ -1807,7 +1854,12 @@ class FederationHandler(BaseHandler):
EventTypes.ThirdPartyInvite, EventTypes.ThirdPartyInvite,
event.content["third_party_invite"]["signed"]["token"] event.content["third_party_invite"]["signed"]["token"]
) )
original_invite = context.current_state.get(key) original_invite = None
original_invite_id = context.current_state_ids.get(key)
if original_invite_id:
original_invite = yield self.store.get_event(
original_invite_id, allow_none=True
)
if not original_invite: if not original_invite:
logger.info( logger.info(
"Could not find invite event for third_party_invite - " "Could not find invite event for third_party_invite - "
@ -1824,13 +1876,13 @@ class FederationHandler(BaseHandler):
defer.returnValue((event, context)) defer.returnValue((event, context))
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_signature(self, event, auth_events): def _check_signature(self, event, context):
""" """
Checks that the signature in the event is consistent with its invite. Checks that the signature in the event is consistent with its invite.
Args: Args:
event (Event): The m.room.member event to check event (Event): The m.room.member event to check
auth_events (dict<(event type, state_key), event>): context (EventContext):
Raises: Raises:
AuthError: if signature didn't match any keys, or key has been AuthError: if signature didn't match any keys, or key has been
@ -1841,10 +1893,14 @@ class FederationHandler(BaseHandler):
signed = event.content["third_party_invite"]["signed"] signed = event.content["third_party_invite"]["signed"]
token = signed["token"] token = signed["token"]
invite_event = auth_events.get( invite_event_id = context.current_state_ids.get(
(EventTypes.ThirdPartyInvite, token,) (EventTypes.ThirdPartyInvite, token,)
) )
invite_event = None
if invite_event_id:
invite_event = yield self.store.get_event(invite_event_id, allow_none=True)
if not invite_event: if not invite_event:
raise AuthError(403, "Could not find invite") raise AuthError(403, "Could not find invite")

View File

@ -30,6 +30,7 @@ from synapse.util.async import concurrently_execute, run_on_reactor, ReadWriteLo
from synapse.util.caches.snapshot_cache import SnapshotCache from synapse.util.caches.snapshot_cache import SnapshotCache
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.util.metrics import measure_func from synapse.util.metrics import measure_func
from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
from ._base import BaseHandler from ._base import BaseHandler
@ -248,7 +249,7 @@ class MessageHandler(BaseHandler):
assert self.hs.is_mine(user), "User must be our own: %s" % (user,) assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
if event.is_state(): if event.is_state():
prev_state = self.deduplicate_state_event(event, context) prev_state = yield self.deduplicate_state_event(event, context)
if prev_state is not None: if prev_state is not None:
defer.returnValue(prev_state) defer.returnValue(prev_state)
@ -263,6 +264,7 @@ class MessageHandler(BaseHandler):
presence = self.hs.get_presence_handler() presence = self.hs.get_presence_handler()
yield presence.bump_presence_active_time(user) yield presence.bump_presence_active_time(user)
@defer.inlineCallbacks
def deduplicate_state_event(self, event, context): def deduplicate_state_event(self, event, context):
""" """
Checks whether event is in the latest resolved state in context. Checks whether event is in the latest resolved state in context.
@ -270,13 +272,17 @@ class MessageHandler(BaseHandler):
If so, returns the version of the event in context. If so, returns the version of the event in context.
Otherwise, returns None. Otherwise, returns None.
""" """
prev_event = context.current_state.get((event.type, event.state_key)) prev_event_id = context.current_state_ids.get((event.type, event.state_key))
prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
if not prev_event:
return
if prev_event and event.user_id == prev_event.user_id: if prev_event and event.user_id == prev_event.user_id:
prev_content = encode_canonical_json(prev_event.content) prev_content = encode_canonical_json(prev_event.content)
next_content = encode_canonical_json(event.content) next_content = encode_canonical_json(event.content)
if prev_content == next_content: if prev_content == next_content:
return prev_event defer.returnValue(prev_event)
return None return
@defer.inlineCallbacks @defer.inlineCallbacks
def create_and_send_nonmember_event( def create_and_send_nonmember_event(
@ -803,7 +809,7 @@ class MessageHandler(BaseHandler):
logger.debug( logger.debug(
"Created event %s with current state: %s", "Created event %s with current state: %s",
event.event_id, context.current_state, event.event_id, context.current_state_ids,
) )
defer.returnValue( defer.returnValue(
@ -826,12 +832,12 @@ class MessageHandler(BaseHandler):
self.ratelimit(requester) self.ratelimit(requester)
try: try:
self.auth.check(event, auth_events=context.current_state) yield self.auth.check_from_context(event, context)
except AuthError as err: except AuthError as err:
logger.warn("Denying new event %r because %s", event, err) logger.warn("Denying new event %r because %s", event, err)
raise err raise err
yield self.maybe_kick_guest_users(event, context.current_state.values()) yield self.maybe_kick_guest_users(event, context)
if event.type == EventTypes.CanonicalAlias: if event.type == EventTypes.CanonicalAlias:
# Check the alias is acually valid (at this time at least) # Check the alias is acually valid (at this time at least)
@ -859,6 +865,15 @@ class MessageHandler(BaseHandler):
e.sender == event.sender e.sender == event.sender
) )
state_to_include_ids = [
e_id
for k, e_id in context.current_state_ids.items()
if k[0] in self.hs.config.room_invite_state_types
or k[0] == EventTypes.Member and k[1] == event.sender
]
state_to_include = yield self.store.get_events(state_to_include_ids)
event.unsigned["invite_room_state"] = [ event.unsigned["invite_room_state"] = [
{ {
"type": e.type, "type": e.type,
@ -866,9 +881,7 @@ class MessageHandler(BaseHandler):
"content": e.content, "content": e.content,
"sender": e.sender, "sender": e.sender,
} }
for k, e in context.current_state.items() for e in state_to_include.values()
if e.type in self.hs.config.room_invite_state_types
or is_inviter_member_event(e)
] ]
invitee = UserID.from_string(event.state_key) invitee = UserID.from_string(event.state_key)
@ -890,7 +903,14 @@ class MessageHandler(BaseHandler):
) )
if event.type == EventTypes.Redaction: if event.type == EventTypes.Redaction:
if self.auth.check_redaction(event, auth_events=context.current_state): auth_events_ids = yield self.auth.compute_auth_events(
event, context.current_state_ids, for_verification=True,
)
auth_events = yield self.store.get_events(auth_events_ids)
auth_events = {
(e.type, e.state_key): e for e in auth_events.values()
}
if self.auth.check_redaction(event, auth_events=auth_events):
original_event = yield self.store.get_event( original_event = yield self.store.get_event(
event.redacts, event.redacts,
check_redacted=False, check_redacted=False,
@ -904,7 +924,7 @@ class MessageHandler(BaseHandler):
"You don't have permission to redact events" "You don't have permission to redact events"
) )
if event.type == EventTypes.Create and context.current_state: if event.type == EventTypes.Create and context.current_state_ids:
raise AuthError( raise AuthError(
403, 403,
"Changing the room create event is forbidden", "Changing the room create event is forbidden",
@ -925,16 +945,7 @@ class MessageHandler(BaseHandler):
event_stream_id, max_stream_id event_stream_id, max_stream_id
) )
destinations = set() destinations = yield self.get_joined_hosts_for_room_from_state(context)
for k, s in context.current_state.items():
try:
if k[0] == EventTypes.Member:
if s.content["membership"] == Membership.JOIN:
destinations.add(get_domain_from_id(s.state_key))
except SynapseError:
logger.warn(
"Failed to get destination from event %s", s.event_id
)
@defer.inlineCallbacks @defer.inlineCallbacks
def _notify(): def _notify():
@ -952,3 +963,39 @@ class MessageHandler(BaseHandler):
preserve_fn(federation_handler.handle_new_event)( preserve_fn(federation_handler.handle_new_event)(
event, destinations=destinations, event, destinations=destinations,
) )
def get_joined_hosts_for_room_from_state(self, context):
state_group = context.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group
# of None don't hit previous cached calls with a None state_group.
# To do this we set the state_group to a new object as object() != object()
state_group = object()
return self._get_joined_hosts_for_room_from_state(
state_group, context.current_state_ids
)
@cachedInlineCallbacks(num_args=1, cache_context=True)
def _get_joined_hosts_for_room_from_state(self, state_group, current_state_ids,
cache_context):
# Don't bother getting state for people on the same HS
current_state = yield self.store.get_events([
e_id for key, e_id in current_state_ids.items()
if key[0] == EventTypes.Member and not self.hs.is_mine_id(key[1])
])
destinations = set()
for e in current_state.itervalues():
try:
if e.type == EventTypes.Member:
if e.content["membership"] == Membership.JOIN:
destinations.add(get_domain_from_id(e.state_key))
except SynapseError:
logger.warn(
"Failed to get destination from event %s", e.event_id
)
defer.returnValue(destinations)

View File

@ -93,19 +93,25 @@ class RoomMemberHandler(BaseHandler):
ratelimit=ratelimit, ratelimit=ratelimit,
) )
prev_member_event = context.current_state.get( prev_member_event_id = context.current_state_ids.get(
(EventTypes.Member, target.to_string()), (EventTypes.Member, target.to_string()),
None None
) )
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
if not prev_member_event or prev_member_event.membership != Membership.JOIN:
# Only fire user_joined_room if the user has acutally joined the # Only fire user_joined_room if the user has acutally joined the
# room. Don't bother if the user is just changing their profile # room. Don't bother if the user is just changing their profile
# info. # info.
newly_joined = True
if prev_member_event_id:
prev_member_event = yield self.store.get_event(prev_member_event_id)
newly_joined = prev_member_event.membership != Membership.JOIN
if newly_joined:
yield user_joined_room(self.distributor, target, room_id) yield user_joined_room(self.distributor, target, room_id)
elif event.membership == Membership.LEAVE: elif event.membership == Membership.LEAVE:
if prev_member_event and prev_member_event.membership == Membership.JOIN: if prev_member_event_id:
prev_member_event = yield self.store.get_event(prev_member_event_id)
if prev_member_event.membership == Membership.JOIN:
user_left_room(self.distributor, target, room_id) user_left_room(self.distributor, target, room_id)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -195,16 +201,19 @@ class RoomMemberHandler(BaseHandler):
remote_room_hosts = [] remote_room_hosts = []
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
current_state = yield self.state_handler.get_current_state( current_state_ids = yield self.state_handler.get_current_state_ids(
room_id, latest_event_ids=latest_event_ids, room_id, latest_event_ids=latest_event_ids,
) )
old_state = current_state.get((EventTypes.Member, target.to_string())) old_state_id = current_state_ids.get((EventTypes.Member, target.to_string()))
if old_state_id:
old_state = yield self.store.get_event(old_state_id, allow_none=True)
old_membership = old_state.content.get("membership") if old_state else None old_membership = old_state.content.get("membership") if old_state else None
if action == "unban" and old_membership != "ban": if action == "unban" and old_membership != "ban":
raise SynapseError( raise SynapseError(
403, 403,
"Cannot unban user who was not banned (membership=%s)" % old_membership, "Cannot unban user who was not banned"
" (membership=%s)" % old_membership,
errcode=Codes.BAD_STATE errcode=Codes.BAD_STATE
) )
if old_membership == "ban" and action != "unban": if old_membership == "ban" and action != "unban":
@ -214,10 +223,10 @@ class RoomMemberHandler(BaseHandler):
errcode=Codes.BAD_STATE errcode=Codes.BAD_STATE
) )
is_host_in_room = self.is_host_in_room(current_state) is_host_in_room = yield self._is_host_in_room(current_state_ids)
if effective_membership_state == Membership.JOIN: if effective_membership_state == Membership.JOIN:
if requester.is_guest and not self._can_guest_join(current_state): if requester.is_guest and not self._can_guest_join(current_state_ids):
# This should be an auth check, but guests are a local concept, # This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process. # so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed") raise AuthError(403, "Guest access not allowed")
@ -326,12 +335,14 @@ class RoomMemberHandler(BaseHandler):
requester = synapse.types.create_requester(target_user) requester = synapse.types.create_requester(target_user)
message_handler = self.hs.get_handlers().message_handler message_handler = self.hs.get_handlers().message_handler
prev_event = message_handler.deduplicate_state_event(event, context) prev_event = yield message_handler.deduplicate_state_event(event, context)
if prev_event is not None: if prev_event is not None:
return return
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
if requester.is_guest and not self._can_guest_join(context.current_state): if requester.is_guest:
guest_can_join = yield self._can_guest_join(context.current_state_ids)
if not guest_can_join:
# This should be an auth check, but guests are a local concept, # This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process. # so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed") raise AuthError(403, "Guest access not allowed")
@ -344,27 +355,39 @@ class RoomMemberHandler(BaseHandler):
ratelimit=ratelimit, ratelimit=ratelimit,
) )
prev_member_event = context.current_state.get( prev_member_event_id = context.current_state_ids.get(
(EventTypes.Member, target_user.to_string()), (EventTypes.Member, event.state_key),
None None
) )
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
if not prev_member_event or prev_member_event.membership != Membership.JOIN:
# Only fire user_joined_room if the user has acutally joined the # Only fire user_joined_room if the user has acutally joined the
# room. Don't bother if the user is just changing their profile # room. Don't bother if the user is just changing their profile
# info. # info.
newly_joined = True
if prev_member_event_id:
prev_member_event = yield self.store.get_event(prev_member_event_id)
newly_joined = prev_member_event.membership != Membership.JOIN
if newly_joined:
yield user_joined_room(self.distributor, target_user, room_id) yield user_joined_room(self.distributor, target_user, room_id)
elif event.membership == Membership.LEAVE: elif event.membership == Membership.LEAVE:
if prev_member_event and prev_member_event.membership == Membership.JOIN: if prev_member_event_id:
prev_member_event = yield self.store.get_event(prev_member_event_id)
if prev_member_event.membership == Membership.JOIN:
user_left_room(self.distributor, target_user, room_id) user_left_room(self.distributor, target_user, room_id)
def _can_guest_join(self, current_state): @defer.inlineCallbacks
def _can_guest_join(self, current_state_ids):
""" """
Returns whether a guest can join a room based on its current state. Returns whether a guest can join a room based on its current state.
""" """
guest_access = current_state.get((EventTypes.GuestAccess, ""), None) guest_access_id = current_state_ids.get((EventTypes.GuestAccess, ""), None)
return ( if not guest_access_id:
defer.returnValue(False)
guest_access = yield self.store.get_event(guest_access_id)
defer.returnValue(
guest_access guest_access
and guest_access.content and guest_access.content
and "guest_access" in guest_access.content and "guest_access" in guest_access.content
@ -683,3 +706,24 @@ class RoomMemberHandler(BaseHandler):
if membership: if membership:
yield self.store.forget(user_id, room_id) yield self.store.forget(user_id, room_id)
@defer.inlineCallbacks
def _is_host_in_room(self, current_state_ids):
# Have we just created the room, and is this about to be the very
# first member event?
create_event_id = current_state_ids.get(("m.room.create", ""))
if len(current_state_ids) == 1 and create_event_id:
defer.returnValue(self.hs.is_mine_id(create_event_id))
for (etype, state_key), event_id in current_state_ids.items():
if etype != EventTypes.Member or not self.hs.is_mine_id(state_key):
continue
event = yield self.store.get_event(event_id, allow_none=True)
if not event:
continue
if event.membership == Membership.JOIN:
defer.returnValue(True)
defer.returnValue(False)

View File

@ -358,11 +358,11 @@ class SyncHandler(object):
Returns: Returns:
A Deferred map from ((type, state_key)->Event) A Deferred map from ((type, state_key)->Event)
""" """
state = yield self.store.get_state_for_event(event.event_id) state_ids = yield self.store.get_state_ids_for_event(event.event_id)
if event.is_state(): if event.is_state():
state = state.copy() state_ids = state_ids.copy()
state[(event.type, event.state_key)] = event state_ids[(event.type, event.state_key)] = event.event_id
defer.returnValue(state) defer.returnValue(state_ids)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_at(self, room_id, stream_position): def get_state_at(self, room_id, stream_position):
@ -415,57 +415,61 @@ class SyncHandler(object):
with Measure(self.clock, "compute_state_delta"): with Measure(self.clock, "compute_state_delta"):
if full_state: if full_state:
if batch: if batch:
current_state = yield self.store.get_state_for_event( current_state_ids = yield self.store.get_state_ids_for_event(
batch.events[-1].event_id batch.events[-1].event_id
) )
state = yield self.store.get_state_for_event( state_ids = yield self.store.get_state_ids_for_event(
batch.events[0].event_id batch.events[0].event_id
) )
else: else:
current_state = yield self.get_state_at( current_state_ids = yield self.get_state_at(
room_id, stream_position=now_token room_id, stream_position=now_token
) )
state = current_state state_ids = current_state_ids
timeline_state = { timeline_state = {
(event.type, event.state_key): event (event.type, event.state_key): event.event_id
for event in batch.events if event.is_state() for event in batch.events if event.is_state()
} }
state = _calculate_state( state_ids = _calculate_state(
timeline_contains=timeline_state, timeline_contains=timeline_state,
timeline_start=state, timeline_start=state_ids,
previous={}, previous={},
current=current_state, current=current_state_ids,
) )
elif batch.limited: elif batch.limited:
state_at_previous_sync = yield self.get_state_at( state_at_previous_sync = yield self.get_state_at(
room_id, stream_position=since_token room_id, stream_position=since_token
) )
current_state = yield self.store.get_state_for_event( current_state_ids = yield self.store.get_state_ids_for_event(
batch.events[-1].event_id batch.events[-1].event_id
) )
state_at_timeline_start = yield self.store.get_state_for_event( state_at_timeline_start = yield self.store.get_state_ids_for_event(
batch.events[0].event_id batch.events[0].event_id
) )
timeline_state = { timeline_state = {
(event.type, event.state_key): event (event.type, event.state_key): event.event_id
for event in batch.events if event.is_state() for event in batch.events if event.is_state()
} }
state = _calculate_state( state_ids = _calculate_state(
timeline_contains=timeline_state, timeline_contains=timeline_state,
timeline_start=state_at_timeline_start, timeline_start=state_at_timeline_start,
previous=state_at_previous_sync, previous=state_at_previous_sync,
current=current_state, current=current_state_ids,
) )
else: else:
state_ids = {}
state = {} state = {}
if state_ids:
state = yield self.store.get_events(state_ids.values())
defer.returnValue({ defer.returnValue({
(e.type, e.state_key): e (e.type, e.state_key): e
@ -806,8 +810,13 @@ class SyncHandler(object):
# the last sync (even if we have since left). This is to make sure # the last sync (even if we have since left). This is to make sure
# we do send down the room, and with full state, where necessary # we do send down the room, and with full state, where necessary
if room_id in joined_room_ids or has_join: if room_id in joined_room_ids or has_join:
old_state = yield self.get_state_at(room_id, since_token) old_state_ids = yield self.get_state_at(room_id, since_token)
old_mem_ev = old_state.get((EventTypes.Member, user_id), None) old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None)
old_mem_ev = None
if old_mem_ev_id:
old_mem_ev = yield self.store.get_event(
old_mem_ev_id, allow_none=True
)
if not old_mem_ev or old_mem_ev.membership != Membership.JOIN: if not old_mem_ev or old_mem_ev.membership != Membership.JOIN:
newly_joined_rooms.append(room_id) newly_joined_rooms.append(room_id)
@ -1099,27 +1108,25 @@ def _calculate_state(timeline_contains, timeline_start, previous, current):
Returns: Returns:
dict dict
""" """
event_id_to_state = { event_id_to_key = {
e.event_id: e e: key
for e in itertools.chain( for key, e in itertools.chain(
timeline_contains.values(), timeline_contains.items(),
previous.values(), previous.items(),
timeline_start.values(), timeline_start.items(),
current.values(), current.items(),
) )
} }
c_ids = set(e.event_id for e in current.values()) c_ids = set(e for e in current.values())
tc_ids = set(e.event_id for e in timeline_contains.values()) tc_ids = set(e for e in timeline_contains.values())
p_ids = set(e.event_id for e in previous.values()) p_ids = set(e for e in previous.values())
ts_ids = set(e.event_id for e in timeline_start.values()) ts_ids = set(e for e in timeline_start.values())
state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids
evs = (event_id_to_state[e] for e in state_ids)
return { return {
(e.type, e.state_key): e event_id_to_key[e]: e for e in state_ids
for e in evs
} }

View File

@ -40,12 +40,12 @@ class ActionGenerator:
def handle_push_actions_for_event(self, event, context): def handle_push_actions_for_event(self, event, context):
with Measure(self.clock, "evaluator_for_event"): with Measure(self.clock, "evaluator_for_event"):
bulk_evaluator = yield evaluator_for_event( bulk_evaluator = yield evaluator_for_event(
event, self.hs, self.store, context.state_group, context.current_state event, self.hs, self.store, context
) )
with Measure(self.clock, "action_for_event_by_user"): with Measure(self.clock, "action_for_event_by_user"):
actions_by_user = yield bulk_evaluator.action_for_event_by_user( actions_by_user = yield bulk_evaluator.action_for_event_by_user(
event, context.current_state event, context
) )
context.push_actions = [ context.push_actions = [

View File

@ -19,8 +19,8 @@ from twisted.internet import defer
from .push_rule_evaluator import PushRuleEvaluatorForEvent from .push_rule_evaluator import PushRuleEvaluatorForEvent
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes
from synapse.visibility import filter_events_for_clients from synapse.visibility import filter_events_for_clients_context
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -36,9 +36,9 @@ def _get_rules(room_id, user_ids, store):
@defer.inlineCallbacks @defer.inlineCallbacks
def evaluator_for_event(event, hs, store, state_group, current_state): def evaluator_for_event(event, hs, store, context):
rules_by_user = yield store.bulk_get_push_rules_for_room( rules_by_user = yield store.bulk_get_push_rules_for_room(
event.room_id, state_group, current_state event.room_id, context
) )
# if this event is an invite event, we may need to run rules for the user # if this event is an invite event, we may need to run rules for the user
@ -72,7 +72,7 @@ class BulkPushRuleEvaluator:
self.store = store self.store = store
@defer.inlineCallbacks @defer.inlineCallbacks
def action_for_event_by_user(self, event, current_state): def action_for_event_by_user(self, event, context):
actions_by_user = {} actions_by_user = {}
# None of these users can be peeking since this list of users comes # None of these users can be peeking since this list of users comes
@ -82,27 +82,25 @@ class BulkPushRuleEvaluator:
(u, False) for u in self.rules_by_user.keys() (u, False) for u in self.rules_by_user.keys()
] ]
filtered_by_user = yield filter_events_for_clients( filtered_by_user = yield filter_events_for_clients_context(
self.store, user_tuples, [event], {event.event_id: current_state} self.store, user_tuples, [event], {event.event_id: context}
) )
room_members = set( room_members = yield self.store.get_joined_users_from_context(
e.state_key for e in current_state.values() event.room_id, context,
if e.type == EventTypes.Member and e.membership == Membership.JOIN
) )
evaluator = PushRuleEvaluatorForEvent(event, len(room_members)) evaluator = PushRuleEvaluatorForEvent(event, len(room_members))
condition_cache = {} condition_cache = {}
display_names = {}
for ev in current_state.values():
nm = ev.content.get("displayname", None)
if nm and ev.type == EventTypes.Member:
display_names[ev.state_key] = nm
for uid, rules in self.rules_by_user.items(): for uid, rules in self.rules_by_user.items():
display_name = display_names.get(uid, None) display_name = None
member_ev_id = context.current_state_ids.get((EventTypes.Member, uid))
if member_ev_id:
member_ev = yield self.store.get_event(member_ev_id, allow_none=True)
if member_ev:
display_name = member_ev.content.get("displayname", None)
filtered = filtered_by_user[uid] filtered = filtered_by_user[uid]
if len(filtered) == 0: if len(filtered) == 0:

View File

@ -245,7 +245,7 @@ class HttpPusher(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def _build_notification_dict(self, event, tweaks, badge): def _build_notification_dict(self, event, tweaks, badge):
ctx = yield push_tools.get_context_for_event( ctx = yield push_tools.get_context_for_event(
self.state_handler, event, self.user_id self.store, self.state_handler, event, self.user_id
) )
d = { d = {

View File

@ -22,7 +22,7 @@ from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart from email.mime.multipart import MIMEMultipart
from synapse.util.async import concurrently_execute from synapse.util.async import concurrently_execute
from synapse.util.presentable_names import ( from synapse.push.presentable_names import (
calculate_room_name, name_from_member_event, descriptor_from_member_events calculate_room_name, name_from_member_event, descriptor_from_member_events
) )
from synapse.types import UserID from synapse.types import UserID
@ -139,7 +139,7 @@ class Mailer(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def _fetch_room_state(room_id): def _fetch_room_state(room_id):
room_state = yield self.state_handler.get_current_state(room_id) room_state = yield self.state_handler.get_current_state_ids(room_id)
state_by_room[room_id] = room_state state_by_room[room_id] = room_state
# Run at most 3 of these at once: sync does 10 at a time but email # Run at most 3 of these at once: sync does 10 at a time but email
@ -159,11 +159,12 @@ class Mailer(object):
) )
rooms.append(roomvars) rooms.append(roomvars)
reason['room_name'] = calculate_room_name( reason['room_name'] = yield calculate_room_name(
state_by_room[reason['room_id']], user_id, fallback_to_members=True self.store, state_by_room[reason['room_id']], user_id,
fallback_to_members=True
) )
summary_text = self.make_summary_text( summary_text = yield self.make_summary_text(
notifs_by_room, state_by_room, notif_events, user_id, reason notifs_by_room, state_by_room, notif_events, user_id, reason
) )
@ -203,12 +204,15 @@ class Mailer(object):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def get_room_vars(self, room_id, user_id, notifs, notif_events, room_state): def get_room_vars(self, room_id, user_id, notifs, notif_events, room_state_ids):
my_member_event = room_state[("m.room.member", user_id)] my_member_event_id = room_state_ids[("m.room.member", user_id)]
my_member_event = yield self.store.get_event(my_member_event_id)
is_invite = my_member_event.content["membership"] == "invite" is_invite = my_member_event.content["membership"] == "invite"
room_name = yield calculate_room_name(self.store, room_state_ids, user_id)
room_vars = { room_vars = {
"title": calculate_room_name(room_state, user_id), "title": room_name,
"hash": string_ordinal_total(room_id), # See sender avatar hash "hash": string_ordinal_total(room_id), # See sender avatar hash
"notifs": [], "notifs": [],
"invite": is_invite, "invite": is_invite,
@ -218,7 +222,7 @@ class Mailer(object):
if not is_invite: if not is_invite:
for n in notifs: for n in notifs:
notifvars = yield self.get_notif_vars( notifvars = yield self.get_notif_vars(
n, user_id, notif_events[n['event_id']], room_state n, user_id, notif_events[n['event_id']], room_state_ids
) )
# merge overlapping notifs together. # merge overlapping notifs together.
@ -243,7 +247,7 @@ class Mailer(object):
defer.returnValue(room_vars) defer.returnValue(room_vars)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_notif_vars(self, notif, user_id, notif_event, room_state): def get_notif_vars(self, notif, user_id, notif_event, room_state_ids):
results = yield self.store.get_events_around( results = yield self.store.get_events_around(
notif['room_id'], notif['event_id'], notif['room_id'], notif['event_id'],
before_limit=CONTEXT_BEFORE, after_limit=CONTEXT_AFTER before_limit=CONTEXT_BEFORE, after_limit=CONTEXT_AFTER
@ -261,17 +265,19 @@ class Mailer(object):
the_events.append(notif_event) the_events.append(notif_event)
for event in the_events: for event in the_events:
messagevars = self.get_message_vars(notif, event, room_state) messagevars = yield self.get_message_vars(notif, event, room_state_ids)
if messagevars is not None: if messagevars is not None:
ret['messages'].append(messagevars) ret['messages'].append(messagevars)
defer.returnValue(ret) defer.returnValue(ret)
def get_message_vars(self, notif, event, room_state): @defer.inlineCallbacks
def get_message_vars(self, notif, event, room_state_ids):
if event.type != EventTypes.Message: if event.type != EventTypes.Message:
return None return
sender_state_event = room_state[("m.room.member", event.sender)] sender_state_event_id = room_state_ids[("m.room.member", event.sender)]
sender_state_event = yield self.store.get_event(sender_state_event_id)
sender_name = name_from_member_event(sender_state_event) sender_name = name_from_member_event(sender_state_event)
sender_avatar_url = sender_state_event.content.get("avatar_url") sender_avatar_url = sender_state_event.content.get("avatar_url")
@ -299,7 +305,7 @@ class Mailer(object):
if "body" in event.content: if "body" in event.content:
ret["body_text_plain"] = event.content["body"] ret["body_text_plain"] = event.content["body"]
return ret defer.returnValue(ret)
def add_text_message_vars(self, messagevars, event): def add_text_message_vars(self, messagevars, event):
msgformat = event.content.get("format") msgformat = event.content.get("format")
@ -321,6 +327,7 @@ class Mailer(object):
return messagevars return messagevars
@defer.inlineCallbacks
def make_summary_text(self, notifs_by_room, state_by_room, def make_summary_text(self, notifs_by_room, state_by_room,
notif_events, user_id, reason): notif_events, user_id, reason):
if len(notifs_by_room) == 1: if len(notifs_by_room) == 1:
@ -330,7 +337,7 @@ class Mailer(object):
# If the room has some kind of name, use it, but we don't # If the room has some kind of name, use it, but we don't
# want the generated-from-names one here otherwise we'll # want the generated-from-names one here otherwise we'll
# end up with, "new message from Bob in the Bob room" # end up with, "new message from Bob in the Bob room"
room_name = calculate_room_name( room_name = yield calculate_room_name(
state_by_room[room_id], user_id, fallback_to_members=False state_by_room[room_id], user_id, fallback_to_members=False
) )
@ -342,16 +349,16 @@ class Mailer(object):
inviter_name = name_from_member_event(inviter_member_event) inviter_name = name_from_member_event(inviter_member_event)
if room_name is None: if room_name is None:
return INVITE_FROM_PERSON % { defer.returnValue(INVITE_FROM_PERSON % {
"person": inviter_name, "person": inviter_name,
"app": self.app_name "app": self.app_name
} })
else: else:
return INVITE_FROM_PERSON_TO_ROOM % { defer.returnValue(INVITE_FROM_PERSON_TO_ROOM % {
"person": inviter_name, "person": inviter_name,
"room": room_name, "room": room_name,
"app": self.app_name, "app": self.app_name,
} })
sender_name = None sender_name = None
if len(notifs_by_room[room_id]) == 1: if len(notifs_by_room[room_id]) == 1:
@ -362,24 +369,24 @@ class Mailer(object):
sender_name = name_from_member_event(state_event) sender_name = name_from_member_event(state_event)
if sender_name is not None and room_name is not None: if sender_name is not None and room_name is not None:
return MESSAGE_FROM_PERSON_IN_ROOM % { defer.returnValue(MESSAGE_FROM_PERSON_IN_ROOM % {
"person": sender_name, "person": sender_name,
"room": room_name, "room": room_name,
"app": self.app_name, "app": self.app_name,
} })
elif sender_name is not None: elif sender_name is not None:
return MESSAGE_FROM_PERSON % { defer.returnValue(MESSAGE_FROM_PERSON % {
"person": sender_name, "person": sender_name,
"app": self.app_name, "app": self.app_name,
} })
else: else:
# There's more than one notification for this room, so just # There's more than one notification for this room, so just
# say there are several # say there are several
if room_name is not None: if room_name is not None:
return MESSAGES_IN_ROOM % { defer.returnValue(MESSAGES_IN_ROOM % {
"room": room_name, "room": room_name,
"app": self.app_name, "app": self.app_name,
} })
else: else:
# If the room doesn't have a name, say who the messages # If the room doesn't have a name, say who the messages
# are from explicitly to avoid, "messages in the Bob room" # are from explicitly to avoid, "messages in the Bob room"
@ -388,22 +395,22 @@ class Mailer(object):
for n in notifs_by_room[room_id] for n in notifs_by_room[room_id]
])) ]))
return MESSAGES_FROM_PERSON % { defer.returnValue(MESSAGES_FROM_PERSON % {
"person": descriptor_from_member_events([ "person": descriptor_from_member_events([
state_by_room[room_id][("m.room.member", s)] state_by_room[room_id][("m.room.member", s)]
for s in sender_ids for s in sender_ids
]), ]),
"app": self.app_name, "app": self.app_name,
} })
else: else:
# Stuff's happened in multiple different rooms # Stuff's happened in multiple different rooms
# ...but we still refer to the 'reason' room which triggered the mail # ...but we still refer to the 'reason' room which triggered the mail
if reason['room_name'] is not None: if reason['room_name'] is not None:
return MESSAGES_IN_ROOM_AND_OTHERS % { defer.returnValue(MESSAGES_IN_ROOM_AND_OTHERS % {
"room": reason['room_name'], "room": reason['room_name'],
"app": self.app_name, "app": self.app_name,
} })
else: else:
# If the reason room doesn't have a name, say who the messages # If the reason room doesn't have a name, say who the messages
# are from explicitly to avoid, "messages in the Bob room" # are from explicitly to avoid, "messages in the Bob room"
@ -412,13 +419,13 @@ class Mailer(object):
for n in notifs_by_room[reason['room_id']] for n in notifs_by_room[reason['room_id']]
])) ]))
return MESSAGES_FROM_PERSON_AND_OTHERS % { defer.returnValue(MESSAGES_FROM_PERSON_AND_OTHERS % {
"person": descriptor_from_member_events([ "person": descriptor_from_member_events([
state_by_room[reason['room_id']][("m.room.member", s)] state_by_room[reason['room_id']][("m.room.member", s)]
for s in sender_ids for s in sender_ids
]), ]),
"app": self.app_name, "app": self.app_name,
} })
def make_room_link(self, room_id): def make_room_link(self, room_id):
# need /beta for Universal Links to work on iOS # need /beta for Universal Links to work on iOS

View File

@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
import re import re
import logging import logging
@ -25,7 +27,8 @@ ALIAS_RE = re.compile(r"^#.*:.+$")
ALL_ALONE = "Empty Room" ALL_ALONE = "Empty Room"
def calculate_room_name(room_state, user_id, fallback_to_members=True, @defer.inlineCallbacks
def calculate_room_name(store, room_state_ids, user_id, fallback_to_members=True,
fallback_to_single_member=True): fallback_to_single_member=True):
""" """
Works out a user-facing name for the given room as per Matrix Works out a user-facing name for the given room as per Matrix
@ -42,59 +45,78 @@ def calculate_room_name(room_state, user_id, fallback_to_members=True,
(string or None) A human readable name for the room. (string or None) A human readable name for the room.
""" """
# does it have a name? # does it have a name?
if ("m.room.name", "") in room_state: if ("m.room.name", "") in room_state_ids:
m_room_name = room_state[("m.room.name", "")] m_room_name = yield store.get_event(
if m_room_name.content and m_room_name.content["name"]: room_state_ids[("m.room.name", "")], allow_none=True
return m_room_name.content["name"] )
if m_room_name and m_room_name.content and m_room_name.content["name"]:
defer.returnValue(m_room_name.content["name"])
# does it have a canonical alias? # does it have a canonical alias?
if ("m.room.canonical_alias", "") in room_state: if ("m.room.canonical_alias", "") in room_state_ids:
canon_alias = room_state[("m.room.canonical_alias", "")] canon_alias = yield store.get_event(
room_state_ids[("m.room.canonical_alias", "")], allow_none=True
)
if ( if (
canon_alias.content and canon_alias.content["alias"] and canon_alias and canon_alias.content and canon_alias.content["alias"] and
_looks_like_an_alias(canon_alias.content["alias"]) _looks_like_an_alias(canon_alias.content["alias"])
): ):
return canon_alias.content["alias"] defer.returnValue(canon_alias.content["alias"])
# at this point we're going to need to search the state by all state keys # at this point we're going to need to search the state by all state keys
# for an event type, so rearrange the data structure # for an event type, so rearrange the data structure
room_state_bytype = _state_as_two_level_dict(room_state) room_state_bytype_ids = _state_as_two_level_dict(room_state_ids)
# right then, any aliases at all? # right then, any aliases at all?
if "m.room.aliases" in room_state_bytype: if "m.room.aliases" in room_state_bytype_ids:
m_room_aliases = room_state_bytype["m.room.aliases"] m_room_aliases = room_state_bytype_ids["m.room.aliases"]
if len(m_room_aliases.values()) > 0: for alias_id in m_room_aliases.values():
first_alias_event = m_room_aliases.values()[0] alias_event = yield store.get_event(
if first_alias_event.content and first_alias_event.content["aliases"]: alias_id, allow_none=True
the_aliases = first_alias_event.content["aliases"] )
if alias_event and alias_event.content and alias_event.get("aliases"):
the_aliases = alias_event.content["aliases"]
if len(the_aliases) > 0 and _looks_like_an_alias(the_aliases[0]): if len(the_aliases) > 0 and _looks_like_an_alias(the_aliases[0]):
return the_aliases[0] defer.returnValue(the_aliases[0])
if not fallback_to_members: if not fallback_to_members:
return None defer.returnValue(None)
my_member_event = None my_member_event = None
if ("m.room.member", user_id) in room_state: if ("m.room.member", user_id) in room_state_ids:
my_member_event = room_state[("m.room.member", user_id)] my_member_event = yield store.get_event(
room_state_ids[("m.room.member", user_id)], allow_none=True
)
if ( if (
my_member_event is not None and my_member_event is not None and
my_member_event.content['membership'] == "invite" my_member_event.content['membership'] == "invite"
): ):
if ("m.room.member", my_member_event.sender) in room_state: if ("m.room.member", my_member_event.sender) in room_state_ids:
inviter_member_event = room_state[("m.room.member", my_member_event.sender)] inviter_member_event = yield store.get_event(
room_state_ids[("m.room.member", my_member_event.sender)],
allow_none=True,
)
if inviter_member_event:
if fallback_to_single_member: if fallback_to_single_member:
return "Invite from %s" % (name_from_member_event(inviter_member_event),) defer.returnValue(
"Invite from %s" % (
name_from_member_event(inviter_member_event),
)
)
else: else:
return None return
else: else:
return "Room Invite" defer.returnValue("Room Invite")
# we're going to have to generate a name based on who's in the room, # we're going to have to generate a name based on who's in the room,
# so find out who is in the room that isn't the user. # so find out who is in the room that isn't the user.
if "m.room.member" in room_state_bytype: if "m.room.member" in room_state_bytype_ids:
member_events = yield store.get_events(
room_state_bytype_ids["m.room.member"].values()
)
all_members = [ all_members = [
ev for ev in room_state_bytype["m.room.member"].values() ev for ev in member_events.values()
if ev.content['membership'] == "join" or ev.content['membership'] == "invite" if ev.content['membership'] == "join" or ev.content['membership'] == "invite"
] ]
# Sort the member events oldest-first so the we name people in the # Sort the member events oldest-first so the we name people in the
@ -111,9 +133,9 @@ def calculate_room_name(room_state, user_id, fallback_to_members=True,
# self-chat, peeked room with 1 participant, # self-chat, peeked room with 1 participant,
# or inbound invite, or outbound 3PID invite. # or inbound invite, or outbound 3PID invite.
if all_members[0].sender == user_id: if all_members[0].sender == user_id:
if "m.room.third_party_invite" in room_state_bytype: if "m.room.third_party_invite" in room_state_bytype_ids:
third_party_invites = ( third_party_invites = (
room_state_bytype["m.room.third_party_invite"].values() room_state_bytype_ids["m.room.third_party_invite"].values()
) )
if len(third_party_invites) > 0: if len(third_party_invites) > 0:
@ -126,17 +148,17 @@ def calculate_room_name(room_state, user_id, fallback_to_members=True,
# return "Inviting %s" % ( # return "Inviting %s" % (
# descriptor_from_member_events(third_party_invites) # descriptor_from_member_events(third_party_invites)
# ) # )
return "Inviting email address" defer.returnValue("Inviting email address")
else: else:
return ALL_ALONE defer.returnValue(ALL_ALONE)
else: else:
return name_from_member_event(all_members[0]) defer.returnValue(name_from_member_event(all_members[0]))
else: else:
return ALL_ALONE defer.returnValue(ALL_ALONE)
elif len(other_members) == 1 and not fallback_to_single_member: elif len(other_members) == 1 and not fallback_to_single_member:
return None return
else: else:
return descriptor_from_member_events(other_members) defer.returnValue(descriptor_from_member_events(other_members))
def descriptor_from_member_events(member_events): def descriptor_from_member_events(member_events):

View File

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
from twisted.internet import defer from twisted.internet import defer
from synapse.util.presentable_names import ( from synapse.push.presentable_names import (
calculate_room_name, name_from_member_event calculate_room_name, name_from_member_event
) )
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
@ -49,21 +49,22 @@ def get_badge_count(store, user_id):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_context_for_event(state_handler, ev, user_id): def get_context_for_event(store, state_handler, ev, user_id):
ctx = {} ctx = {}
room_state = yield state_handler.get_current_state(ev.room_id) room_state_ids = yield state_handler.get_current_state_ids(ev.room_id)
# we no longer bother setting room_alias, and make room_name the # we no longer bother setting room_alias, and make room_name the
# human-readable name instead, be that m.room.name, an alias or # human-readable name instead, be that m.room.name, an alias or
# a list of people in the room # a list of people in the room
name = calculate_room_name( name = yield calculate_room_name(
room_state, user_id, fallback_to_single_member=False store, room_state_ids, user_id, fallback_to_single_member=False
) )
if name: if name:
ctx['name'] = name ctx['name'] = name
sender_state_event = room_state[("m.room.member", ev.sender)] sender_state_event_id = room_state_ids[("m.room.member", ev.sender)]
sender_state_event = yield store.get_event(sender_state_event_id)
ctx['sender_display_name'] = name_from_member_event(sender_state_event) ctx['sender_display_name'] = name_from_member_event(sender_state_event)
defer.returnValue(ctx) defer.returnValue(ctx)

View File

@ -120,10 +120,15 @@ class SlavedEventStore(BaseSlavedStore):
get_state_for_event = DataStore.get_state_for_event.__func__ get_state_for_event = DataStore.get_state_for_event.__func__
get_state_for_events = DataStore.get_state_for_events.__func__ get_state_for_events = DataStore.get_state_for_events.__func__
get_state_groups = DataStore.get_state_groups.__func__ get_state_groups = DataStore.get_state_groups.__func__
get_state_groups_ids = DataStore.get_state_groups_ids.__func__
get_state_ids_for_event = DataStore.get_state_ids_for_event.__func__
get_state_ids_for_events = DataStore.get_state_ids_for_events.__func__
get_recent_events_for_room = DataStore.get_recent_events_for_room.__func__ get_recent_events_for_room = DataStore.get_recent_events_for_room.__func__
get_room_events_stream_for_rooms = ( get_room_events_stream_for_rooms = (
DataStore.get_room_events_stream_for_rooms.__func__ DataStore.get_room_events_stream_for_rooms.__func__
) )
is_host_joined = DataStore.is_host_joined.__func__
_is_host_joined = RoomMemberStore.__dict__["_is_host_joined"]
get_stream_token_for_event = DataStore.get_stream_token_for_event.__func__ get_stream_token_for_event = DataStore.get_stream_token_for_event.__func__
_set_before_and_after = staticmethod(DataStore._set_before_and_after) _set_before_and_after = staticmethod(DataStore._set_before_and_after)

View File

@ -18,8 +18,8 @@ import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import ThirdPartyEntityKind
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from synapse.types import ThirdPartyEntityKind
from ._base import client_v2_patterns from ._base import client_v2_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -93,8 +93,30 @@ class StateHandler(object):
if not latest_event_ids: if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
res = yield self.resolve_state_groups(room_id, latest_event_ids) _, state = yield self.resolve_state_groups(room_id, latest_event_ids)
state = res[1]
if event_type:
event_id = state.get((event_type, state_key))
event = None
if event_id:
event = yield self.store.get_event(event_id, allow_none=True)
defer.returnValue(event)
return
state_map = yield self.store.get_events(state.values(), get_prev_content=False)
state = {
key: state_map[e_id] for key, e_id in state.items() if e_id in state_map
}
defer.returnValue(state)
@defer.inlineCallbacks
def get_current_state_ids(self, room_id, event_type=None, state_key="",
latest_event_ids=None):
if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
_, state = yield self.resolve_state_groups(room_id, latest_event_ids)
if event_type: if event_type:
defer.returnValue(state.get((event_type, state_key))) defer.returnValue(state.get((event_type, state_key)))
@ -123,27 +145,27 @@ class StateHandler(object):
# state. Certainly store.get_current_state won't return any, and # state. Certainly store.get_current_state won't return any, and
# persisting the event won't store the state group. # persisting the event won't store the state group.
if old_state: if old_state:
context.current_state = { context.current_state_ids = {
(s.type, s.state_key): s for s in old_state (s.type, s.state_key): s.event_id for s in old_state
} }
else: else:
context.current_state = {} context.current_state_ids = {}
context.prev_state_events = [] context.prev_state_events = []
context.state_group = None context.state_group = None
defer.returnValue(context) defer.returnValue(context)
if old_state: if old_state:
context.current_state = { context.current_state_ids = {
(s.type, s.state_key): s for s in old_state (s.type, s.state_key): s.event_id for s in old_state
} }
context.state_group = None context.state_group = None
if event.is_state(): if event.is_state():
key = (event.type, event.state_key) key = (event.type, event.state_key)
if key in context.current_state: if key in context.current_state_ids:
replaces = context.current_state[key] replaces = context.current_state_ids[key]
if replaces.event_id != event.event_id: # Paranoia check if replaces != event.event_id: # Paranoia check
event.unsigned["replaces_state"] = replaces.event_id event.unsigned["replaces_state"] = replaces
context.prev_state_events = [] context.prev_state_events = []
defer.returnValue(context) defer.returnValue(context)
@ -159,18 +181,18 @@ class StateHandler(object):
event.room_id, [e for e, _ in event.prev_events], event.room_id, [e for e, _ in event.prev_events],
) )
group, curr_state, prev_state = ret group, curr_state = ret
context.current_state = curr_state context.current_state_ids = curr_state
context.state_group = group if not event.is_state() else None context.state_group = group if not event.is_state() else None
if event.is_state(): if event.is_state():
key = (event.type, event.state_key) key = (event.type, event.state_key)
if key in context.current_state: if key in context.current_state_ids:
replaces = context.current_state[key] replaces = context.current_state_ids[key]
event.unsigned["replaces_state"] = replaces.event_id event.unsigned["replaces_state"] = replaces
context.prev_state_events = prev_state context.prev_state_events = []
defer.returnValue(context) defer.returnValue(context)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -187,72 +209,83 @@ class StateHandler(object):
""" """
logger.debug("resolve_state_groups event_ids %s", event_ids) logger.debug("resolve_state_groups event_ids %s", event_ids)
state_groups = yield self.store.get_state_groups( state_groups_ids = yield self.store.get_state_groups_ids(
room_id, event_ids room_id, event_ids
) )
logger.debug( logger.debug(
"resolve_state_groups state_groups %s", "resolve_state_groups state_groups %s",
state_groups.keys() state_groups_ids.keys()
) )
group_names = frozenset(state_groups.keys()) group_names = frozenset(state_groups_ids.keys())
if len(group_names) == 1: if len(group_names) == 1:
name, state_list = state_groups.items().pop() name, state_list = state_groups_ids.items().pop()
state = {
(e.type, e.state_key): e
for e in state_list
}
prev_state = state.get((event_type, state_key), None)
if prev_state:
prev_state = prev_state.event_id
prev_states = [prev_state]
else:
prev_states = []
defer.returnValue((name, state, prev_states)) defer.returnValue((name, state_list,))
if self._state_cache is not None: if self._state_cache is not None:
cache = self._state_cache.get(group_names, None) cache = self._state_cache.get(group_names, None)
if cache: if cache:
cache.ts = self.clock.time_msec() cache.ts = self.clock.time_msec()
event_dict = yield self.store.get_events(cache.state.values())
state = {(e.type, e.state_key): e for e in event_dict.values()}
prev_state = state.get((event_type, state_key), None)
if prev_state:
prev_state = prev_state.event_id
prev_states = [prev_state]
else:
prev_states = []
defer.returnValue( defer.returnValue(
(cache.state_group, state, prev_states) (cache.state_group, cache.state,)
) )
logger.info("Resolving state for %s with %d groups", room_id, len(state_groups)) logger.info(
"Resolving state for %s with %d groups", room_id, len(state_groups_ids)
new_state, prev_states = self._resolve_events(
state_groups.values(), event_type, state_key
) )
state = {}
for st in state_groups_ids.values():
for key, e_id in st.items():
state.setdefault(key, set()).add(e_id)
conflicted_state = {
k: list(v)
for k, v in state.items()
if len(v) > 1
}
if conflicted_state:
logger.info("Resolving conflicted state for %r", room_id)
state_map = yield self.store.get_events(
[e_id for st in state_groups_ids.values() for e_id in st.values()],
get_prev_content=False
)
state_sets = [
[state_map[e_id] for key, e_id in st.items() if e_id in state_map]
for st in state_groups_ids.values()
]
new_state, _ = self._resolve_events(
state_sets, event_type, state_key
)
new_state = {
key: e.event_id for key, e in new_state.items()
}
else:
new_state = {
key: e_ids.pop() for key, e_ids in state.items()
}
state_group = None state_group = None
new_state_event_ids = frozenset(e.event_id for e in new_state.values()) new_state_event_ids = frozenset(new_state.values())
for sg, events in state_groups.items(): for sg, events in state_groups_ids.items():
if new_state_event_ids == frozenset(e.event_id for e in events): if new_state_event_ids == frozenset(e_id for e_id in events):
state_group = sg state_group = sg
break break
if self._state_cache is not None: if self._state_cache is not None:
cache = _StateCacheEntry( cache = _StateCacheEntry(
state={key: event.event_id for key, event in new_state.items()}, state=new_state,
state_group=state_group, state_group=state_group,
ts=self.clock.time_msec() ts=self.clock.time_msec()
) )
self._state_cache[group_names] = cache self._state_cache[group_names] = cache
defer.returnValue((state_group, new_state, prev_states)) defer.returnValue((state_group, new_state,))
def resolve_events(self, state_sets, event): def resolve_events(self, state_sets, event):
logger.info( logger.info(

View File

@ -124,7 +124,8 @@ class PushRuleStore(SQLBaseStore):
defer.returnValue(results) defer.returnValue(results)
def bulk_get_push_rules_for_room(self, room_id, state_group, current_state): def bulk_get_push_rules_for_room(self, room_id, context):
state_group = context.state_group
if not state_group: if not state_group:
# If state_group is None it means it has yet to be assigned a # If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group # state group, i.e. we need to make sure that calls with a state_group
@ -132,10 +133,12 @@ class PushRuleStore(SQLBaseStore):
# To do this we set the state_group to a new object as object() != object() # To do this we set the state_group to a new object as object() != object()
state_group = object() state_group = object()
return self._bulk_get_push_rules_for_room(room_id, state_group, current_state) return self._bulk_get_push_rules_for_room(
room_id, state_group, context.current_state_ids
)
@cachedInlineCallbacks(num_args=2, cache_context=True) @cachedInlineCallbacks(num_args=2, cache_context=True)
def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state, def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state_ids,
cache_context): cache_context):
# We don't use `state_group`, its there so that we can cache based # We don't use `state_group`, its there so that we can cache based
# on it. However, its important that its never None, since two current_state's # on it. However, its important that its never None, since two current_state's
@ -147,10 +150,16 @@ class PushRuleStore(SQLBaseStore):
# their unread countss are correct in the event stream, but to avoid # their unread countss are correct in the event stream, but to avoid
# generating them for bot / AS users etc, we only do so for people who've # generating them for bot / AS users etc, we only do so for people who've
# sent a read receipt into the room. # sent a read receipt into the room.
local_user_member_ids = [
e_id for (etype, state_key), e_id in current_state_ids.iteritems()
if etype == EventTypes.Member and self.hs.is_mine_id(state_key)
]
local_member_events = yield self._get_events(local_user_member_ids)
local_users_in_room = set( local_users_in_room = set(
e.state_key for e in current_state.values() member_event.state_key for member_event in local_member_events
if e.type == EventTypes.Member and e.membership == Membership.JOIN if member_event.membership == Membership.JOIN
and self.hs.is_mine_id(e.state_key)
) )
# users in the room who have pushers need to get push rules run because # users in the room who have pushers need to get push rules run because

View File

@ -20,7 +20,7 @@ from collections import namedtuple
from ._base import SQLBaseStore from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.api.constants import Membership from synapse.api.constants import Membership, EventTypes
from synapse.types import get_domain_from_id from synapse.types import get_domain_from_id
import logging import logging
@ -325,7 +325,8 @@ class RoomMemberStore(SQLBaseStore):
@cachedInlineCallbacks(num_args=3) @cachedInlineCallbacks(num_args=3)
def was_forgotten_at(self, user_id, room_id, event_id): def was_forgotten_at(self, user_id, room_id, event_id):
"""Returns whether user_id has elected to discard history for room_id at event_id. """Returns whether user_id has elected to discard history for room_id at
event_id.
event_id must be a membership event.""" event_id must be a membership event."""
def f(txn): def f(txn):
@ -358,3 +359,80 @@ class RoomMemberStore(SQLBaseStore):
}, },
desc="who_forgot" desc="who_forgot"
) )
def get_joined_users_from_context(self, room_id, context):
state_group = context.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group
# of None don't hit previous cached calls with a None state_group.
# To do this we set the state_group to a new object as object() != object()
state_group = object()
return self._get_joined_users_from_context(
room_id, state_group, context.current_state_ids
)
@cachedInlineCallbacks(num_args=2, cache_context=True)
def _get_joined_users_from_context(self, room_id, state_group, current_state_ids,
cache_context):
# We don't use `state_group`, its there so that we can cache based
# on it. However, its important that its never None, since two current_state's
# with a state_group of None are likely to be different.
# See bulk_get_push_rules_for_room for how we work around this.
assert state_group is not None
member_event_ids = [
e_id
for key, e_id in current_state_ids.iteritems()
if key[0] == EventTypes.Member
]
rows = yield self._simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=member_event_ids,
retcols=['user_id'],
keyvalues={
"membership": Membership.JOIN,
},
batch_size=1000,
desc="_get_joined_users_from_context",
)
defer.returnValue(set(row["user_id"] for row in rows))
def is_host_joined(self, room_id, host, state_group, state_ids):
if not state_group:
# If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group
# of None don't hit previous cached calls with a None state_group.
# To do this we set the state_group to a new object as object() != object()
state_group = object()
return self._is_host_joined(
room_id, host, state_group, state_ids
)
@cachedInlineCallbacks(num_args=3)
def _is_host_joined(self, room_id, host, state_group, current_state_ids):
# We don't use `state_group`, its there so that we can cache based
# on it. However, its important that its never None, since two current_state's
# with a state_group of None are likely to be different.
# See bulk_get_push_rules_for_room for how we work around this.
assert state_group is not None
for (etype, state_key), event_id in current_state_ids.items():
if etype == EventTypes.Member:
try:
if get_domain_from_id(state_key) != host:
continue
except:
logger.warn("state_key not user_id: %s", state_key)
continue
event = yield self.get_event(event_id, allow_none=True)
if event and event.content["membership"] == Membership.JOIN:
defer.returnValue(True)
defer.returnValue(False)

View File

@ -44,11 +44,7 @@ class StateStore(SQLBaseStore):
""" """
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_groups(self, room_id, event_ids): def get_state_groups_ids(self, room_id, event_ids):
""" Get the state groups for the given list of event_ids
The return value is a dict mapping group names to lists of events.
"""
if not event_ids: if not event_ids:
defer.returnValue({}) defer.returnValue({})
@ -59,9 +55,32 @@ class StateStore(SQLBaseStore):
groups = set(event_to_groups.values()) groups = set(event_to_groups.values())
group_to_state = yield self._get_state_for_groups(groups) group_to_state = yield self._get_state_for_groups(groups)
defer.returnValue(group_to_state)
@defer.inlineCallbacks
def get_state_groups(self, room_id, event_ids):
""" Get the state groups for the given list of event_ids
The return value is a dict mapping group names to lists of events.
"""
if not event_ids:
defer.returnValue({})
group_to_ids = yield self.get_state_groups_ids(room_id, event_ids)
state_event_map = yield self.get_events(
[
ev_id for group_ids in group_to_ids.values()
for ev_id in group_ids.values()
],
get_prev_content=False
)
defer.returnValue({ defer.returnValue({
group: state_map.values() group: [
for group, state_map in group_to_state.items() state_event_map[v] for v in event_id_map.values() if v in state_event_map
]
for group, event_id_map in group_to_ids.items()
}) })
def _store_mult_state_groups_txn(self, txn, events_and_contexts): def _store_mult_state_groups_txn(self, txn, events_and_contexts):
@ -70,17 +89,17 @@ class StateStore(SQLBaseStore):
if event.internal_metadata.is_outlier(): if event.internal_metadata.is_outlier():
continue continue
if context.current_state is None: if context.current_state_ids is None:
continue continue
if context.state_group is not None: if context.state_group is not None:
state_groups[event.event_id] = context.state_group state_groups[event.event_id] = context.state_group
continue continue
state_events = dict(context.current_state) state_event_ids = dict(context.current_state_ids)
if event.is_state(): if event.is_state():
state_events[(event.type, event.state_key)] = event state_event_ids[(event.type, event.state_key)] = event.event_id
state_group = context.new_state_group_id state_group = context.new_state_group_id
@ -100,12 +119,12 @@ class StateStore(SQLBaseStore):
values=[ values=[
{ {
"state_group": state_group, "state_group": state_group,
"room_id": state.room_id, "room_id": event.room_id,
"type": state.type, "type": key[0],
"state_key": state.state_key, "state_key": key[1],
"event_id": state.event_id, "event_id": state_id,
} }
for state in state_events.values() for key, state_id in state_event_ids.items()
], ],
) )
state_groups[event.event_id] = state_group state_groups[event.event_id] = state_group
@ -248,6 +267,31 @@ class StateStore(SQLBaseStore):
groups = set(event_to_groups.values()) groups = set(event_to_groups.values())
group_to_state = yield self._get_state_for_groups(groups, types) group_to_state = yield self._get_state_for_groups(groups, types)
state_event_map = yield self.get_events(
[ev_id for sd in group_to_state.values() for ev_id in sd.values()],
get_prev_content=False
)
event_to_state = {
event_id: {
k: state_event_map[v]
for k, v in group_to_state[group].items()
if v in state_event_map
}
for event_id, group in event_to_groups.items()
}
defer.returnValue({event: event_to_state[event] for event in event_ids})
@defer.inlineCallbacks
def get_state_ids_for_events(self, event_ids, types):
event_to_groups = yield self._get_state_group_for_events(
event_ids,
)
groups = set(event_to_groups.values())
group_to_state = yield self._get_state_for_groups(groups, types)
event_to_state = { event_to_state = {
event_id: group_to_state[group] event_id: group_to_state[group]
for event_id, group in event_to_groups.items() for event_id, group in event_to_groups.items()
@ -272,6 +316,23 @@ class StateStore(SQLBaseStore):
state_map = yield self.get_state_for_events([event_id], types) state_map = yield self.get_state_for_events([event_id], types)
defer.returnValue(state_map[event_id]) defer.returnValue(state_map[event_id])
@defer.inlineCallbacks
def get_state_ids_for_event(self, event_id, types=None):
"""
Get the state dict corresponding to a particular event
Args:
event_id(str): event whose state should be returned
types(list[(str, str)]|None): List of (type, state_key) tuples
which are used to filter the state fetched. May be None, which
matches any key
Returns:
A deferred dict from (type, state_key) -> state_event
"""
state_map = yield self.get_state_ids_for_events([event_id], types)
defer.returnValue(state_map[event_id])
@cached(num_args=2, max_entries=10000) @cached(num_args=2, max_entries=10000)
def _get_state_group_for_event(self, room_id, event_id): def _get_state_group_for_event(self, room_id, event_id):
return self._simple_select_one_onecol( return self._simple_select_one_onecol(
@ -428,20 +489,13 @@ class StateStore(SQLBaseStore):
full=(types is None), full=(types is None),
) )
state_events = yield self._get_events(
[ev_id for sd in results.values() for ev_id in sd.values()],
get_prev_content=False
)
state_events = {e.event_id: e for e in state_events}
# Remove all the entries with None values. The None values were just # Remove all the entries with None values. The None values were just
# used for bookkeeping in the cache. # used for bookkeeping in the cache.
for group, state_dict in results.items(): for group, state_dict in results.items():
results[group] = { results[group] = {
key: state_events[event_id] key: event_id
for key, event_id in state_dict.items() for key, event_id in state_dict.items()
if event_id and event_id in state_events if event_id
} }
defer.returnValue(results) defer.returnValue(results)

View File

@ -271,10 +271,3 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
return "t%d-%d" % (self.topological, self.stream) return "t%d-%d" % (self.topological, self.stream)
else: else:
return "s%d" % (self.stream,) return "s%d" % (self.stream,)
# Some arbitrary constants used for internal API enumerations. Don't rely on
# exact values; always pass or compare symbolically
class ThirdPartyEntityKind(object):
USER = 'user'
LOCATION = 'location'

View File

@ -180,6 +180,25 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
}) })
@defer.inlineCallbacks
def filter_events_for_clients_context(store, user_tuples, events, event_id_to_context):
user_ids = set(u[0] for u in user_tuples)
event_id_to_state = {}
for event_id, context in event_id_to_context.items():
state = yield store.get_events([
e_id
for key, e_id in context.current_state_ids.iteritems()
if key == (EventTypes.RoomHistoryVisibility, "")
or (key[0] == EventTypes.Member and key[1] in user_ids)
])
event_id_to_state[event_id] = state
res = yield filter_events_for_clients(
store, user_tuples, events, event_id_to_state
)
defer.returnValue(res)
@defer.inlineCallbacks @defer.inlineCallbacks
def filter_events_for_client(store, user_id, events, is_peeking=False): def filter_events_for_client(store, user_id, events, is_peeking=False):
""" """

View File

@ -305,7 +305,14 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.event_id += 1 self.event_id += 1
context = EventContext(current_state=state) if state is not None:
state_ids = {
key: e.event_id for key, e in state.items()
}
else:
state_ids = None
context = EventContext(current_state_ids=state_ids)
context.push_actions = push_actions context.push_actions = push_actions
ordering = None ordering = None

View File

@ -67,9 +67,11 @@ class StateGroupStore(object):
self._event_to_state_group = {} self._event_to_state_group = {}
self._group_to_state = {} self._group_to_state = {}
self._event_id_to_event = {}
self._next_group = 1 self._next_group = 1
def get_state_groups(self, room_id, event_ids): def get_state_groups_ids(self, room_id, event_ids):
groups = {} groups = {}
for event_id in event_ids: for event_id in event_ids:
group = self._event_to_state_group.get(event_id) group = self._event_to_state_group.get(event_id)
@ -79,23 +81,33 @@ class StateGroupStore(object):
return defer.succeed(groups) return defer.succeed(groups)
def store_state_groups(self, event, context): def store_state_groups(self, event, context):
if context.current_state is None: if context.current_state_ids is None:
return return
state_events = context.current_state state_events = dict(context.current_state_ids)
if event.is_state(): if event.is_state():
state_events[(event.type, event.state_key)] = event state_events[(event.type, event.state_key)] = event.event_id
state_group = context.state_group state_group = context.state_group
if not state_group: if not state_group:
state_group = self._next_group state_group = self._next_group
self._next_group += 1 self._next_group += 1
self._group_to_state[state_group] = state_events.values() self._group_to_state[state_group] = state_events
self._event_to_state_group[event.event_id] = state_group self._event_to_state_group[event.event_id] = state_group
def get_events(self, event_ids, **kwargs):
return {
e_id: self._event_id_to_event[e_id] for e_id in event_ids
if e_id in self._event_id_to_event
}
def register_events(self, events):
for e in events:
self._event_id_to_event[e.event_id] = e
class DictObj(dict): class DictObj(dict):
def __init__(self, **kwargs): def __init__(self, **kwargs):
@ -136,8 +148,9 @@ class StateTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.store = Mock( self.store = Mock(
spec_set=[ spec_set=[
"get_state_groups", "get_state_groups_ids",
"add_event_hashes", "add_event_hashes",
"get_events",
] ]
) )
hs = Mock(spec_set=[ hs = Mock(spec_set=[
@ -187,7 +200,7 @@ class StateTestCase(unittest.TestCase):
) )
store = StateGroupStore() store = StateGroupStore()
self.store.get_state_groups.side_effect = store.get_state_groups self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
context_store = {} context_store = {}
@ -196,7 +209,7 @@ class StateTestCase(unittest.TestCase):
store.store_state_groups(event, context) store.store_state_groups(event, context)
context_store[event.event_id] = context context_store[event.event_id] = context
self.assertEqual(2, len(context_store["D"].current_state)) self.assertEqual(2, len(context_store["D"].current_state_ids))
@defer.inlineCallbacks @defer.inlineCallbacks
def test_branch_basic_conflict(self): def test_branch_basic_conflict(self):
@ -239,7 +252,9 @@ class StateTestCase(unittest.TestCase):
) )
store = StateGroupStore() store = StateGroupStore()
self.store.get_state_groups.side_effect = store.get_state_groups self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
self.store.get_events = store.get_events
store.register_events(graph.walk())
context_store = {} context_store = {}
@ -250,7 +265,7 @@ class StateTestCase(unittest.TestCase):
self.assertSetEqual( self.assertSetEqual(
{"START", "A", "C"}, {"START", "A", "C"},
{e.event_id for e in context_store["D"].current_state.values()} {e_id for e_id in context_store["D"].current_state_ids.values()}
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -303,7 +318,9 @@ class StateTestCase(unittest.TestCase):
) )
store = StateGroupStore() store = StateGroupStore()
self.store.get_state_groups.side_effect = store.get_state_groups self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
self.store.get_events = store.get_events
store.register_events(graph.walk())
context_store = {} context_store = {}
@ -314,7 +331,7 @@ class StateTestCase(unittest.TestCase):
self.assertSetEqual( self.assertSetEqual(
{"START", "A", "B", "C"}, {"START", "A", "B", "C"},
{e.event_id for e in context_store["E"].current_state.values()} {e for e in context_store["E"].current_state_ids.values()}
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -384,7 +401,9 @@ class StateTestCase(unittest.TestCase):
graph = Graph(nodes, edges) graph = Graph(nodes, edges)
store = StateGroupStore() store = StateGroupStore()
self.store.get_state_groups.side_effect = store.get_state_groups self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
self.store.get_events = store.get_events
store.register_events(graph.walk())
context_store = {} context_store = {}
@ -395,7 +414,7 @@ class StateTestCase(unittest.TestCase):
self.assertSetEqual( self.assertSetEqual(
{"A1", "A2", "A3", "A5", "B"}, {"A1", "A2", "A3", "A5", "B"},
{e.event_id for e in context_store["D"].current_state.values()} {e for e in context_store["D"].current_state_ids.values()}
) )
def _add_depths(self, nodes, edges): def _add_depths(self, nodes, edges):
@ -424,13 +443,8 @@ class StateTestCase(unittest.TestCase):
event, old_state=old_state event, old_state=old_state
) )
for k, v in context.current_state.items():
type, state_key = k
self.assertEqual(type, v.type)
self.assertEqual(state_key, v.state_key)
self.assertEqual( self.assertEqual(
set(old_state), set(context.current_state.values()) set(e.event_id for e in old_state), set(context.current_state_ids.values())
) )
self.assertIsNone(context.state_group) self.assertIsNone(context.state_group)
@ -449,14 +463,8 @@ class StateTestCase(unittest.TestCase):
event, old_state=old_state event, old_state=old_state
) )
for k, v in context.current_state.items():
type, state_key = k
self.assertEqual(type, v.type)
self.assertEqual(state_key, v.state_key)
self.assertEqual( self.assertEqual(
set(old_state), set(e.event_id for e in old_state), set(context.current_state_ids.values())
set(context.current_state.values())
) )
self.assertIsNone(context.state_group) self.assertIsNone(context.state_group)
@ -473,20 +481,15 @@ class StateTestCase(unittest.TestCase):
group_name = "group_name_1" group_name = "group_name_1"
self.store.get_state_groups.return_value = { self.store.get_state_groups_ids.return_value = {
group_name: old_state, group_name: {(e.type, e.state_key): e.event_id for e in old_state},
} }
context = yield self.state.compute_event_context(event) context = yield self.state.compute_event_context(event)
for k, v in context.current_state.items():
type, state_key = k
self.assertEqual(type, v.type)
self.assertEqual(state_key, v.state_key)
self.assertEqual( self.assertEqual(
set([e.event_id for e in old_state]), set([e.event_id for e in old_state]),
set([e.event_id for e in context.current_state.values()]) set(context.current_state_ids.values())
) )
self.assertEqual(group_name, context.state_group) self.assertEqual(group_name, context.state_group)
@ -503,20 +506,15 @@ class StateTestCase(unittest.TestCase):
group_name = "group_name_1" group_name = "group_name_1"
self.store.get_state_groups.return_value = { self.store.get_state_groups_ids.return_value = {
group_name: old_state, group_name: {(e.type, e.state_key): e.event_id for e in old_state},
} }
context = yield self.state.compute_event_context(event) context = yield self.state.compute_event_context(event)
for k, v in context.current_state.items():
type, state_key = k
self.assertEqual(type, v.type)
self.assertEqual(state_key, v.state_key)
self.assertEqual( self.assertEqual(
set([e.event_id for e in old_state]), set([e.event_id for e in old_state]),
set([e.event_id for e in context.current_state.values()]) set(context.current_state_ids.values())
) )
self.assertIsNone(context.state_group) self.assertIsNone(context.state_group)
@ -543,9 +541,14 @@ class StateTestCase(unittest.TestCase):
create_event(type="test4", state_key=""), create_event(type="test4", state_key=""),
] ]
store = StateGroupStore()
store.register_events(old_state_1)
store.register_events(old_state_2)
self.store.get_events = store.get_events
context = yield self._get_context(event, old_state_1, old_state_2) context = yield self._get_context(event, old_state_1, old_state_2)
self.assertEqual(len(context.current_state), 6) self.assertEqual(len(context.current_state_ids), 6)
self.assertIsNone(context.state_group) self.assertIsNone(context.state_group)
@ -571,9 +574,14 @@ class StateTestCase(unittest.TestCase):
create_event(type="test4", state_key=""), create_event(type="test4", state_key=""),
] ]
store = StateGroupStore()
store.register_events(old_state_1)
store.register_events(old_state_2)
self.store.get_events = store.get_events
context = yield self._get_context(event, old_state_1, old_state_2) context = yield self._get_context(event, old_state_1, old_state_2)
self.assertEqual(len(context.current_state), 6) self.assertEqual(len(context.current_state_ids), 6)
self.assertIsNone(context.state_group) self.assertIsNone(context.state_group)
@ -606,9 +614,16 @@ class StateTestCase(unittest.TestCase):
create_event(type="test1", state_key="1", depth=2), create_event(type="test1", state_key="1", depth=2),
] ]
store = StateGroupStore()
store.register_events(old_state_1)
store.register_events(old_state_2)
self.store.get_events = store.get_events
context = yield self._get_context(event, old_state_1, old_state_2) context = yield self._get_context(event, old_state_1, old_state_2)
self.assertEqual(old_state_2[2], context.current_state[("test1", "1")]) self.assertEqual(
old_state_2[2].event_id, context.current_state_ids[("test1", "1")]
)
# Reverse the depth to make sure we are actually using the depths # Reverse the depth to make sure we are actually using the depths
# during state resolution. # during state resolution.
@ -625,17 +640,22 @@ class StateTestCase(unittest.TestCase):
create_event(type="test1", state_key="1", depth=1), create_event(type="test1", state_key="1", depth=1),
] ]
store.register_events(old_state_1)
store.register_events(old_state_2)
context = yield self._get_context(event, old_state_1, old_state_2) context = yield self._get_context(event, old_state_1, old_state_2)
self.assertEqual(old_state_1[2], context.current_state[("test1", "1")]) self.assertEqual(
old_state_1[2].event_id, context.current_state_ids[("test1", "1")]
)
def _get_context(self, event, old_state_1, old_state_2): def _get_context(self, event, old_state_1, old_state_2):
group_name_1 = "group_name_1" group_name_1 = "group_name_1"
group_name_2 = "group_name_2" group_name_2 = "group_name_2"
self.store.get_state_groups.return_value = { self.store.get_state_groups_ids.return_value = {
group_name_1: old_state_1, group_name_1: {(e.type, e.state_key): e.event_id for e in old_state_1},
group_name_2: old_state_2, group_name_2: {(e.type, e.state_key): e.event_id for e in old_state_2},
} }
return self.state.compute_event_context(event) return self.state.compute_event_context(event)