Merge remote-tracking branch 'origin/develop' into rav/remove_who_forgot_in_room

This commit is contained in:
Richard van der Hoff 2018-07-23 17:15:12 +01:00
commit 4f5cc8e4e7
30 changed files with 480 additions and 199 deletions

1
changelog.d/3520.bugfix Normal file
View File

@ -0,0 +1 @@
Correctly announce deleted devices over federation

0
changelog.d/3562.misc Normal file
View File

1
changelog.d/3572.misc Normal file
View File

@ -0,0 +1 @@
Merge Linearizer and Limiter

0
changelog.d/3577.misc Normal file
View File

1
changelog.d/3579.misc Normal file
View File

@ -0,0 +1 @@
Lazily load state on master process when using workers to reduce DB consumption

1
changelog.d/3581.misc Normal file
View File

@ -0,0 +1 @@
Lazily load state on master process when using workers to reduce DB consumption

1
changelog.d/3582.misc Normal file
View File

@ -0,0 +1 @@
Lazily load state on master process when using workers to reduce DB consumption

View File

@ -65,8 +65,9 @@ class Auth(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def check_from_context(self, event, context, do_sig_check=True): def check_from_context(self, event, context, do_sig_check=True):
prev_state_ids = yield context.get_prev_state_ids(self.store)
auth_events_ids = yield self.compute_auth_events( auth_events_ids = yield self.compute_auth_events(
event, context.prev_state_ids, for_verification=True, event, prev_state_ids, for_verification=True,
) )
auth_events = yield self.store.get_events(auth_events_ids) auth_events = yield self.store.get_events(auth_events_ids)
auth_events = { auth_events = {
@ -544,7 +545,8 @@ class Auth(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def add_auth_events(self, builder, context): def add_auth_events(self, builder, context):
auth_ids = yield self.compute_auth_events(builder, context.prev_state_ids) prev_state_ids = yield context.get_prev_state_ids(self.store)
auth_ids = yield self.compute_auth_events(builder, prev_state_ids)
auth_events_entries = yield self.store.add_event_hashes( auth_events_entries = yield self.store.add_event_hashes(
auth_ids auth_ids

View File

@ -18,6 +18,8 @@ import logging
import os import os
import sys import sys
from six import iteritems
from twisted.application import service from twisted.application import service
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from twisted.web.resource import EncodingResourceWrapper, NoResource from twisted.web.resource import EncodingResourceWrapper, NoResource
@ -442,7 +444,7 @@ def run(hs):
stats["total_nonbridged_users"] = total_nonbridged_users stats["total_nonbridged_users"] = total_nonbridged_users
daily_user_type_results = yield hs.get_datastore().count_daily_user_type() daily_user_type_results = yield hs.get_datastore().count_daily_user_type()
for name, count in daily_user_type_results.iteritems(): for name, count in iteritems(daily_user_type_results):
stats["daily_user_type_" + name] = count stats["daily_user_type_" + name] = count
room_count = yield hs.get_datastore().get_room_count() room_count = yield hs.get_datastore().get_room_count()
@ -453,7 +455,7 @@ def run(hs):
stats["daily_messages"] = yield hs.get_datastore().count_daily_messages() stats["daily_messages"] = yield hs.get_datastore().count_daily_messages()
r30_results = yield hs.get_datastore().count_r30_users() r30_results = yield hs.get_datastore().count_r30_users()
for name, count in r30_results.iteritems(): for name, count in iteritems(r30_results):
stats["r30_users_" + name] = count stats["r30_users_" + name] = count
daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages() daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages()

View File

@ -25,6 +25,8 @@ import subprocess
import sys import sys
import time import time
from six import iteritems
import yaml import yaml
SYNAPSE = [sys.executable, "-B", "-m", "synapse.app.homeserver"] SYNAPSE = [sys.executable, "-B", "-m", "synapse.app.homeserver"]
@ -173,7 +175,7 @@ def main():
os.environ["SYNAPSE_CACHE_FACTOR"] = str(cache_factor) os.environ["SYNAPSE_CACHE_FACTOR"] = str(cache_factor)
cache_factors = config.get("synctl_cache_factors", {}) cache_factors = config.get("synctl_cache_factors", {})
for cache_name, factor in cache_factors.iteritems(): for cache_name, factor in iteritems(cache_factors):
os.environ["SYNAPSE_CACHE_FACTOR_" + cache_name.upper()] = str(factor) os.environ["SYNAPSE_CACHE_FACTOR_" + cache_name.upper()] = str(factor)
worker_configfiles = [] worker_configfiles = []

View File

@ -13,22 +13,18 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from six import iteritems
from frozendict import frozendict from frozendict import frozendict
from twisted.internet import defer from twisted.internet import defer
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
class EventContext(object): class EventContext(object):
""" """
Attributes: Attributes:
current_state_ids (dict[(str, str), str]):
The current state map including the current event.
(type, state_key) -> event_id
prev_state_ids (dict[(str, str), str]):
The current state map excluding the current event.
(type, state_key) -> event_id
state_group (int|None): state group id, if the state has been stored state_group (int|None): state group id, if the state has been stored
as a state group. This is usually only None if e.g. the event is as a state group. This is usually only None if e.g. the event is
an outlier. an outlier.
@ -45,38 +41,77 @@ class EventContext(object):
prev_state_events (?): XXX: is this ever set to anything other than prev_state_events (?): XXX: is this ever set to anything other than
the empty list? the empty list?
_current_state_ids (dict[(str, str), str]|None):
The current state map including the current event. None if outlier
or we haven't fetched the state from DB yet.
(type, state_key) -> event_id
_prev_state_ids (dict[(str, str), str]|None):
The current state map excluding the current event. None if outlier
or we haven't fetched the state from DB yet.
(type, state_key) -> event_id
_fetching_state_deferred (Deferred|None): Resolves when *_state_ids have
been calculated. None if we haven't started calculating yet
_event_type (str): The type of the event the context is associated with.
Only set when state has not been fetched yet.
_event_state_key (str|None): The state_key of the event the context is
associated with. Only set when state has not been fetched yet.
_prev_state_id (str|None): If the event associated with the context is
a state event, then `_prev_state_id` is the event_id of the state
that was replaced.
Only set when state has not been fetched yet.
""" """
__slots__ = [ __slots__ = [
"current_state_ids",
"prev_state_ids",
"state_group", "state_group",
"rejected", "rejected",
"prev_group", "prev_group",
"delta_ids", "delta_ids",
"prev_state_events", "prev_state_events",
"app_service", "app_service",
"_current_state_ids",
"_prev_state_ids",
"_prev_state_id",
"_event_type",
"_event_state_key",
"_fetching_state_deferred",
] ]
def __init__(self): def __init__(self):
# The current state including the current event self.prev_state_events = []
self.current_state_ids = None
# The current state excluding the current event
self.prev_state_ids = None
self.state_group = None
self.rejected = False self.rejected = False
self.app_service = None
@staticmethod
def with_state(state_group, current_state_ids, prev_state_ids,
prev_group=None, delta_ids=None):
context = EventContext()
# The current state including the current event
context._current_state_ids = current_state_ids
# The current state excluding the current event
context._prev_state_ids = prev_state_ids
context.state_group = state_group
context._prev_state_id = None
context._event_type = None
context._event_state_key = None
context._fetching_state_deferred = defer.succeed(None)
# A previously persisted state group and a delta between that # A previously persisted state group and a delta between that
# and this state. # and this state.
self.prev_group = None context.prev_group = prev_group
self.delta_ids = None context.delta_ids = delta_ids
self.prev_state_events = None return context
self.app_service = None @defer.inlineCallbacks
def serialize(self, event, store):
def serialize(self, event):
"""Converts self to a type that can be serialized as JSON, and then """Converts self to a type that can be serialized as JSON, and then
deserialized by `deserialize` deserialized by `deserialize`
@ -92,11 +127,12 @@ class EventContext(object):
# the prev_state_ids, so if we're a state event we include the event # the prev_state_ids, so if we're a state event we include the event
# id that we replaced in the state. # id that we replaced in the state.
if event.is_state(): if event.is_state():
prev_state_id = self.prev_state_ids.get((event.type, event.state_key)) prev_state_ids = yield self.get_prev_state_ids(store)
prev_state_id = prev_state_ids.get((event.type, event.state_key))
else: else:
prev_state_id = None prev_state_id = None
return { defer.returnValue({
"prev_state_id": prev_state_id, "prev_state_id": prev_state_id,
"event_type": event.type, "event_type": event.type,
"event_state_key": event.state_key if event.is_state() else None, "event_state_key": event.state_key if event.is_state() else None,
@ -106,10 +142,9 @@ class EventContext(object):
"delta_ids": _encode_state_dict(self.delta_ids), "delta_ids": _encode_state_dict(self.delta_ids),
"prev_state_events": self.prev_state_events, "prev_state_events": self.prev_state_events,
"app_service_id": self.app_service.id if self.app_service else None "app_service_id": self.app_service.id if self.app_service else None
} })
@staticmethod @staticmethod
@defer.inlineCallbacks
def deserialize(store, input): def deserialize(store, input):
"""Converts a dict that was produced by `serialize` back into a """Converts a dict that was produced by `serialize` back into a
EventContext. EventContext.
@ -122,32 +157,100 @@ class EventContext(object):
EventContext EventContext
""" """
context = EventContext() context = EventContext()
context.state_group = input["state_group"]
context.rejected = input["rejected"]
context.prev_group = input["prev_group"]
context.delta_ids = _decode_state_dict(input["delta_ids"])
context.prev_state_events = input["prev_state_events"]
# We use the state_group and prev_state_id stuff to pull the # We use the state_group and prev_state_id stuff to pull the
# current_state_ids out of the DB and construct prev_state_ids. # current_state_ids out of the DB and construct prev_state_ids.
prev_state_id = input["prev_state_id"] context._prev_state_id = input["prev_state_id"]
event_type = input["event_type"] context._event_type = input["event_type"]
event_state_key = input["event_state_key"] context._event_state_key = input["event_state_key"]
context._fetching_state_deferred = None
context.current_state_ids = yield store.get_state_ids_for_group( context.state_group = input["state_group"]
context.state_group, context.prev_group = input["prev_group"]
) context.delta_ids = _decode_state_dict(input["delta_ids"])
if prev_state_id and event_state_key:
context.prev_state_ids = dict(context.current_state_ids) context.rejected = input["rejected"]
context.prev_state_ids[(event_type, event_state_key)] = prev_state_id context.prev_state_events = input["prev_state_events"]
else:
context.prev_state_ids = context.current_state_ids
app_service_id = input["app_service_id"] app_service_id = input["app_service_id"]
if app_service_id: if app_service_id:
context.app_service = store.get_app_service_by_id(app_service_id) context.app_service = store.get_app_service_by_id(app_service_id)
defer.returnValue(context) return context
@defer.inlineCallbacks
def get_current_state_ids(self, store):
"""Gets the current state IDs
Returns:
Deferred[dict[(str, str), str]|None]: Returns None if state_group
is None, which happens when the associated event is an outlier.
"""
if not self._fetching_state_deferred:
self._fetching_state_deferred = run_in_background(
self._fill_out_state, store,
)
yield make_deferred_yieldable(self._fetching_state_deferred)
defer.returnValue(self._current_state_ids)
@defer.inlineCallbacks
def get_prev_state_ids(self, store):
"""Gets the prev state IDs
Returns:
Deferred[dict[(str, str), str]|None]: Returns None if state_group
is None, which happens when the associated event is an outlier.
"""
if not self._fetching_state_deferred:
self._fetching_state_deferred = run_in_background(
self._fill_out_state, store,
)
yield make_deferred_yieldable(self._fetching_state_deferred)
defer.returnValue(self._prev_state_ids)
@defer.inlineCallbacks
def _fill_out_state(self, store):
"""Called to populate the _current_state_ids and _prev_state_ids
attributes by loading from the database.
"""
if self.state_group is None:
return
self._current_state_ids = yield store.get_state_ids_for_group(
self.state_group,
)
if self._prev_state_id and self._event_state_key is not None:
self._prev_state_ids = dict(self._current_state_ids)
key = (self._event_type, self._event_state_key)
self._prev_state_ids[key] = self._prev_state_id
else:
self._prev_state_ids = self._current_state_ids
@defer.inlineCallbacks
def update_state(self, state_group, prev_state_ids, current_state_ids,
delta_ids):
"""Replace the state in the context
"""
# We need to make sure we wait for any ongoing fetching of state
# to complete so that the updated state doesn't get clobbered
if self._fetching_state_deferred:
yield make_deferred_yieldable(self._fetching_state_deferred)
self.state_group = state_group
self._prev_state_ids = prev_state_ids
self._current_state_ids = current_state_ids
self.delta_ids = delta_ids
# We need to ensure that that we've marked as having fetched the state
self._fetching_state_deferred = defer.succeed(None)
def _encode_state_dict(state_dict): def _encode_state_dict(state_dict):
@ -159,7 +262,7 @@ def _encode_state_dict(state_dict):
return [ return [
(etype, state_key, v) (etype, state_key, v)
for (etype, state_key), v in state_dict.iteritems() for (etype, state_key), v in iteritems(state_dict)
] ]

View File

@ -112,8 +112,9 @@ class BaseHandler(object):
guest_access = event.content.get("guest_access", "forbidden") guest_access = event.content.get("guest_access", "forbidden")
if guest_access != "can_join": if guest_access != "can_join":
if context: if context:
current_state_ids = yield context.get_current_state_ids(self.store)
current_state = yield self.store.get_events( current_state = yield self.store.get_events(
list(context.current_state_ids.values()) list(current_state_ids.values())
) )
else: else:
current_state = yield self.state_handler.get_current_state( current_state = yield self.state_handler.get_current_state(

View File

@ -21,8 +21,8 @@ import logging
import sys import sys
import six import six
from six import iteritems from six import iteritems, itervalues
from six.moves import http_client from six.moves import http_client, zip
from signedjson.key import decode_verify_key_bytes from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json from signedjson.sign import verify_signed_json
@ -486,7 +486,10 @@ class FederationHandler(BaseHandler):
# joined the room. Don't bother if the user is just # joined the room. Don't bother if the user is just
# changing their profile info. # changing their profile info.
newly_joined = True newly_joined = True
prev_state_id = context.prev_state_ids.get(
prev_state_ids = yield context.get_prev_state_ids(self.store)
prev_state_id = prev_state_ids.get(
(event.type, event.state_key) (event.type, event.state_key)
) )
if prev_state_id: if prev_state_id:
@ -731,7 +734,7 @@ class FederationHandler(BaseHandler):
""" """
joined_users = [ joined_users = [
(state_key, int(event.depth)) (state_key, int(event.depth))
for (e_type, state_key), event in state.iteritems() for (e_type, state_key), event in iteritems(state)
if e_type == EventTypes.Member if e_type == EventTypes.Member
and event.membership == Membership.JOIN and event.membership == Membership.JOIN
] ]
@ -748,7 +751,7 @@ class FederationHandler(BaseHandler):
except Exception: except Exception:
pass pass
return sorted(joined_domains.iteritems(), key=lambda d: d[1]) return sorted(joined_domains.items(), key=lambda d: d[1])
curr_domains = get_domains_from_state(curr_state) curr_domains = get_domains_from_state(curr_state)
@ -811,7 +814,7 @@ class FederationHandler(BaseHandler):
tried_domains = set(likely_domains) tried_domains = set(likely_domains)
tried_domains.add(self.server_name) tried_domains.add(self.server_name)
event_ids = list(extremities.iterkeys()) event_ids = list(extremities.keys())
logger.debug("calling resolve_state_groups in _maybe_backfill") logger.debug("calling resolve_state_groups in _maybe_backfill")
resolve = logcontext.preserve_fn( resolve = logcontext.preserve_fn(
@ -827,15 +830,15 @@ class FederationHandler(BaseHandler):
states = dict(zip(event_ids, [s.state for s in states])) states = dict(zip(event_ids, [s.state for s in states]))
state_map = yield self.store.get_events( state_map = yield self.store.get_events(
[e_id for ids in states.itervalues() for e_id in ids.itervalues()], [e_id for ids in itervalues(states) for e_id in itervalues(ids)],
get_prev_content=False get_prev_content=False
) )
states = { states = {
key: { key: {
k: state_map[e_id] k: state_map[e_id]
for k, e_id in state_dict.iteritems() for k, e_id in iteritems(state_dict)
if e_id in state_map if e_id in state_map
} for key, state_dict in states.iteritems() } for key, state_dict in iteritems(states)
} }
for e_id, _ in sorted_extremeties_tuple: for e_id, _ in sorted_extremeties_tuple:
@ -1106,10 +1109,12 @@ class FederationHandler(BaseHandler):
user = UserID.from_string(event.state_key) user = UserID.from_string(event.state_key)
yield user_joined_room(self.distributor, user, event.room_id) yield user_joined_room(self.distributor, user, event.room_id)
state_ids = list(context.prev_state_ids.values()) prev_state_ids = yield context.get_prev_state_ids(self.store)
state_ids = list(prev_state_ids.values())
auth_chain = yield self.store.get_auth_chain(state_ids) auth_chain = yield self.store.get_auth_chain(state_ids)
state = yield self.store.get_events(list(context.prev_state_ids.values())) state = yield self.store.get_events(list(prev_state_ids.values()))
defer.returnValue({ defer.returnValue({
"state": list(state.values()), "state": list(state.values()),
@ -1515,7 +1520,7 @@ class FederationHandler(BaseHandler):
yield self.store.persist_events( yield self.store.persist_events(
[ [
(ev_info["event"], context) (ev_info["event"], context)
for ev_info, context in itertools.izip(event_infos, contexts) for ev_info, context in zip(event_infos, contexts)
], ],
backfilled=backfilled, backfilled=backfilled,
) )
@ -1635,8 +1640,9 @@ class FederationHandler(BaseHandler):
) )
if not auth_events: if not auth_events:
prev_state_ids = yield context.get_prev_state_ids(self.store)
auth_events_ids = yield self.auth.compute_auth_events( auth_events_ids = yield self.auth.compute_auth_events(
event, context.prev_state_ids, for_verification=True, event, prev_state_ids, for_verification=True,
) )
auth_events = yield self.store.get_events(auth_events_ids) auth_events = yield self.store.get_events(auth_events_ids)
auth_events = { auth_events = {
@ -1876,9 +1882,10 @@ class FederationHandler(BaseHandler):
break break
if do_resolution: if do_resolution:
prev_state_ids = yield context.get_prev_state_ids(self.store)
# 1. Get what we think is the auth chain. # 1. Get what we think is the auth chain.
auth_ids = yield self.auth.compute_auth_events( auth_ids = yield self.auth.compute_auth_events(
event, context.prev_state_ids event, prev_state_ids
) )
local_auth_chain = yield self.store.get_auth_chain( local_auth_chain = yield self.store.get_auth_chain(
auth_ids, include_given=True auth_ids, include_given=True
@ -1968,21 +1975,35 @@ class FederationHandler(BaseHandler):
k: a.event_id for k, a in iteritems(auth_events) k: a.event_id for k, a in iteritems(auth_events)
if k != event_key if k != event_key
} }
context.current_state_ids = dict(context.current_state_ids) current_state_ids = yield context.get_current_state_ids(self.store)
context.current_state_ids.update(state_updates) current_state_ids = dict(current_state_ids)
current_state_ids.update(state_updates)
if context.delta_ids is not None: if context.delta_ids is not None:
context.delta_ids = dict(context.delta_ids) delta_ids = dict(context.delta_ids)
context.delta_ids.update(state_updates) delta_ids.update(state_updates)
context.prev_state_ids = dict(context.prev_state_ids)
context.prev_state_ids.update({ prev_state_ids = yield context.get_prev_state_ids(self.store)
prev_state_ids = dict(prev_state_ids)
prev_state_ids.update({
k: a.event_id for k, a in iteritems(auth_events) k: a.event_id for k, a in iteritems(auth_events)
}) })
context.state_group = yield self.store.store_state_group(
state_group = yield self.store.store_state_group(
event.event_id, event.event_id,
event.room_id, event.room_id,
prev_group=context.prev_group, prev_group=context.prev_group,
delta_ids=context.delta_ids, delta_ids=delta_ids,
current_state_ids=context.current_state_ids, current_state_ids=current_state_ids,
)
yield context.update_state(
state_group=state_group,
current_state_ids=current_state_ids,
prev_state_ids=prev_state_ids,
delta_ids=delta_ids,
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -2222,7 +2243,8 @@ class FederationHandler(BaseHandler):
event.content["third_party_invite"]["signed"]["token"] event.content["third_party_invite"]["signed"]["token"]
) )
original_invite = None original_invite = None
original_invite_id = context.prev_state_ids.get(key) prev_state_ids = yield context.get_prev_state_ids(self.store)
original_invite_id = prev_state_ids.get(key)
if original_invite_id: if original_invite_id:
original_invite = yield self.store.get_event( original_invite = yield self.store.get_event(
original_invite_id, allow_none=True original_invite_id, allow_none=True
@ -2264,7 +2286,8 @@ class FederationHandler(BaseHandler):
signed = event.content["third_party_invite"]["signed"] signed = event.content["third_party_invite"]["signed"]
token = signed["token"] token = signed["token"]
invite_event_id = context.prev_state_ids.get( prev_state_ids = yield context.get_prev_state_ids(self.store)
invite_event_id = prev_state_ids.get(
(EventTypes.ThirdPartyInvite, token,) (EventTypes.ThirdPartyInvite, token,)
) )

View File

@ -630,7 +630,8 @@ class EventCreationHandler(object):
If so, returns the version of the event in context. If so, returns the version of the event in context.
Otherwise, returns None. Otherwise, returns None.
""" """
prev_event_id = context.prev_state_ids.get((event.type, event.state_key)) prev_state_ids = yield context.get_prev_state_ids(self.store)
prev_event_id = prev_state_ids.get((event.type, event.state_key))
prev_event = yield self.store.get_event(prev_event_id, allow_none=True) prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
if not prev_event: if not prev_event:
return return
@ -752,8 +753,8 @@ class EventCreationHandler(object):
event = builder.build() event = builder.build()
logger.debug( logger.debug(
"Created event %s with state: %s", "Created event %s",
event.event_id, context.prev_state_ids, event.event_id,
) )
defer.returnValue( defer.returnValue(
@ -806,8 +807,9 @@ class EventCreationHandler(object):
# If we're a worker we need to hit out to the master. # If we're a worker we need to hit out to the master.
if self.config.worker_app: if self.config.worker_app:
yield send_event_to_master( yield send_event_to_master(
self.hs.get_clock(), clock=self.hs.get_clock(),
self.http_client, store=self.store,
client=self.http_client,
host=self.config.worker_replication_host, host=self.config.worker_replication_host,
port=self.config.worker_replication_http_port, port=self.config.worker_replication_http_port,
requester=requester, requester=requester,
@ -884,9 +886,11 @@ class EventCreationHandler(object):
e.sender == event.sender e.sender == event.sender
) )
current_state_ids = yield context.get_current_state_ids(self.store)
state_to_include_ids = [ state_to_include_ids = [
e_id e_id
for k, e_id in iteritems(context.current_state_ids) for k, e_id in iteritems(current_state_ids)
if k[0] in self.hs.config.room_invite_state_types if k[0] in self.hs.config.room_invite_state_types
or k == (EventTypes.Member, event.sender) or k == (EventTypes.Member, event.sender)
] ]
@ -922,8 +926,9 @@ class EventCreationHandler(object):
) )
if event.type == EventTypes.Redaction: if event.type == EventTypes.Redaction:
prev_state_ids = yield context.get_prev_state_ids(self.store)
auth_events_ids = yield self.auth.compute_auth_events( auth_events_ids = yield self.auth.compute_auth_events(
event, context.prev_state_ids, for_verification=True, event, prev_state_ids, for_verification=True,
) )
auth_events = yield self.store.get_events(auth_events_ids) auth_events = yield self.store.get_events(auth_events_ids)
auth_events = { auth_events = {
@ -943,7 +948,9 @@ class EventCreationHandler(object):
"You don't have permission to redact events" "You don't have permission to redact events"
) )
if event.type == EventTypes.Create and context.prev_state_ids: if event.type == EventTypes.Create:
prev_state_ids = yield context.get_prev_state_ids(self.store)
if prev_state_ids:
raise AuthError( raise AuthError(
403, 403,
"Changing the room create event is forbidden", "Changing the room create event is forbidden",

View File

@ -201,7 +201,9 @@ class RoomMemberHandler(object):
ratelimit=ratelimit, ratelimit=ratelimit,
) )
prev_member_event_id = context.prev_state_ids.get( prev_state_ids = yield context.get_prev_state_ids(self.store)
prev_member_event_id = prev_state_ids.get(
(EventTypes.Member, target.to_string()), (EventTypes.Member, target.to_string()),
None None
) )
@ -496,9 +498,10 @@ class RoomMemberHandler(object):
if prev_event is not None: if prev_event is not None:
return return
prev_state_ids = yield context.get_prev_state_ids(self.store)
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
if requester.is_guest: if requester.is_guest:
guest_can_join = yield self._can_guest_join(context.prev_state_ids) guest_can_join = yield self._can_guest_join(prev_state_ids)
if not guest_can_join: if not guest_can_join:
# This should be an auth check, but guests are a local concept, # This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process. # so don't really fit into the general auth process.
@ -517,7 +520,7 @@ class RoomMemberHandler(object):
ratelimit=ratelimit, ratelimit=ratelimit,
) )
prev_member_event_id = context.prev_state_ids.get( prev_member_event_id = prev_state_ids.get(
(EventTypes.Member, event.state_key), (EventTypes.Member, event.state_key),
None None
) )

View File

@ -274,7 +274,7 @@ class Notifier(object):
logger.exception("Error notifying application services of event") logger.exception("Error notifying application services of event")
def on_new_event(self, stream_key, new_token, users=[], rooms=[]): def on_new_event(self, stream_key, new_token, users=[], rooms=[]):
""" Used to inform listeners that something has happend event wise. """ Used to inform listeners that something has happened event wise.
Will wake up all listeners for the given users and rooms. Will wake up all listeners for the given users and rooms.
""" """

View File

@ -112,7 +112,8 @@ class BulkPushRuleEvaluator(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_power_levels_and_sender_level(self, event, context): def _get_power_levels_and_sender_level(self, event, context):
pl_event_id = context.prev_state_ids.get(POWER_KEY) prev_state_ids = yield context.get_prev_state_ids(self.store)
pl_event_id = prev_state_ids.get(POWER_KEY)
if pl_event_id: if pl_event_id:
# fastpath: if there's a power level event, that's all we need, and # fastpath: if there's a power level event, that's all we need, and
# not having a power level event is an extreme edge case # not having a power level event is an extreme edge case
@ -120,7 +121,7 @@ class BulkPushRuleEvaluator(object):
auth_events = {POWER_KEY: pl_event} auth_events = {POWER_KEY: pl_event}
else: else:
auth_events_ids = yield self.auth.compute_auth_events( auth_events_ids = yield self.auth.compute_auth_events(
event, context.prev_state_ids, for_verification=False, event, prev_state_ids, for_verification=False,
) )
auth_events = yield self.store.get_events(auth_events_ids) auth_events = yield self.store.get_events(auth_events_ids)
auth_events = { auth_events = {
@ -304,7 +305,7 @@ class RulesForRoom(object):
push_rules_delta_state_cache_metric.inc_hits() push_rules_delta_state_cache_metric.inc_hits()
else: else:
current_state_ids = context.current_state_ids current_state_ids = yield context.get_current_state_ids(self.store)
push_rules_delta_state_cache_metric.inc_misses() push_rules_delta_state_cache_metric.inc_misses()
push_rules_state_size_counter.inc(len(current_state_ids)) push_rules_state_size_counter.inc(len(current_state_ids))

View File

@ -34,12 +34,13 @@ logger = logging.getLogger(__name__)
@defer.inlineCallbacks @defer.inlineCallbacks
def send_event_to_master(clock, client, host, port, requester, event, context, def send_event_to_master(clock, store, client, host, port, requester, event, context,
ratelimit, extra_users): ratelimit, extra_users):
"""Send event to be handled on the master """Send event to be handled on the master
Args: Args:
clock (synapse.util.Clock) clock (synapse.util.Clock)
store (DataStore)
client (SimpleHttpClient) client (SimpleHttpClient)
host (str): host of master host (str): host of master
port (int): port on master listening for HTTP replication port (int): port on master listening for HTTP replication
@ -53,11 +54,13 @@ def send_event_to_master(clock, client, host, port, requester, event, context,
host, port, event.event_id, host, port, event.event_id,
) )
serialized_context = yield context.serialize(event, store)
payload = { payload = {
"event": event.get_pdu_json(), "event": event.get_pdu_json(),
"internal_metadata": event.internal_metadata.get_dict(), "internal_metadata": event.internal_metadata.get_dict(),
"rejected_reason": event.rejected_reason, "rejected_reason": event.rejected_reason,
"context": context.serialize(event), "context": serialized_context,
"requester": requester.serialize(), "requester": requester.serialize(),
"ratelimit": ratelimit, "ratelimit": ratelimit,
"extra_users": [u.to_string() for u in extra_users], "extra_users": [u.to_string() for u in extra_users],

View File

@ -18,7 +18,7 @@ import hashlib
import logging import logging
from collections import namedtuple from collections import namedtuple
from six import iteritems, itervalues from six import iteritems, iterkeys, itervalues
from frozendict import frozendict from frozendict import frozendict
@ -203,25 +203,27 @@ class StateHandler(object):
# If this is an outlier, then we know it shouldn't have any current # If this is an outlier, then we know it shouldn't have any current
# state. Certainly store.get_current_state won't return any, and # state. Certainly store.get_current_state won't return any, and
# persisting the event won't store the state group. # persisting the event won't store the state group.
context = EventContext()
if old_state: if old_state:
context.prev_state_ids = { prev_state_ids = {
(s.type, s.state_key): s.event_id for s in old_state (s.type, s.state_key): s.event_id for s in old_state
} }
if event.is_state(): if event.is_state():
context.current_state_ids = dict(context.prev_state_ids) current_state_ids = dict(prev_state_ids)
key = (event.type, event.state_key) key = (event.type, event.state_key)
context.current_state_ids[key] = event.event_id current_state_ids[key] = event.event_id
else: else:
context.current_state_ids = context.prev_state_ids current_state_ids = prev_state_ids
else: else:
context.current_state_ids = {} current_state_ids = {}
context.prev_state_ids = {} prev_state_ids = {}
context.prev_state_events = []
# We don't store state for outliers, so we don't generate a state # We don't store state for outliers, so we don't generate a state
# froup for it. # group for it.
context.state_group = None context = EventContext.with_state(
state_group=None,
current_state_ids=current_state_ids,
prev_state_ids=prev_state_ids,
)
defer.returnValue(context) defer.returnValue(context)
@ -230,31 +232,35 @@ class StateHandler(object):
# Let's just correctly fill out the context and create a # Let's just correctly fill out the context and create a
# new state group for it. # new state group for it.
context = EventContext() prev_state_ids = {
context.prev_state_ids = {
(s.type, s.state_key): s.event_id for s in old_state (s.type, s.state_key): s.event_id for s in old_state
} }
if event.is_state(): if event.is_state():
key = (event.type, event.state_key) key = (event.type, event.state_key)
if key in context.prev_state_ids: if key in prev_state_ids:
replaces = context.prev_state_ids[key] replaces = prev_state_ids[key]
if replaces != event.event_id: # Paranoia check if replaces != event.event_id: # Paranoia check
event.unsigned["replaces_state"] = replaces event.unsigned["replaces_state"] = replaces
context.current_state_ids = dict(context.prev_state_ids) current_state_ids = dict(prev_state_ids)
context.current_state_ids[key] = event.event_id current_state_ids[key] = event.event_id
else: else:
context.current_state_ids = context.prev_state_ids current_state_ids = prev_state_ids
context.state_group = yield self.store.store_state_group( state_group = yield self.store.store_state_group(
event.event_id, event.event_id,
event.room_id, event.room_id,
prev_group=None, prev_group=None,
delta_ids=None, delta_ids=None,
current_state_ids=context.current_state_ids, current_state_ids=current_state_ids,
)
context = EventContext.with_state(
state_group=state_group,
current_state_ids=current_state_ids,
prev_state_ids=prev_state_ids,
) )
context.prev_state_events = []
defer.returnValue(context) defer.returnValue(context)
logger.debug("calling resolve_state_groups from compute_event_context") logger.debug("calling resolve_state_groups from compute_event_context")
@ -262,47 +268,47 @@ class StateHandler(object):
event.room_id, [e for e, _ in event.prev_events], event.room_id, [e for e, _ in event.prev_events],
) )
curr_state = entry.state prev_state_ids = entry.state
prev_group = None
delta_ids = None
context = EventContext()
context.prev_state_ids = curr_state
if event.is_state(): if event.is_state():
# If this is a state event then we need to create a new state # If this is a state event then we need to create a new state
# group for the state after this event. # group for the state after this event.
key = (event.type, event.state_key) key = (event.type, event.state_key)
if key in context.prev_state_ids: if key in prev_state_ids:
replaces = context.prev_state_ids[key] replaces = prev_state_ids[key]
event.unsigned["replaces_state"] = replaces event.unsigned["replaces_state"] = replaces
context.current_state_ids = dict(context.prev_state_ids) current_state_ids = dict(prev_state_ids)
context.current_state_ids[key] = event.event_id current_state_ids[key] = event.event_id
if entry.state_group: if entry.state_group:
# If the state at the event has a state group assigned then # If the state at the event has a state group assigned then
# we can use that as the prev group # we can use that as the prev group
context.prev_group = entry.state_group prev_group = entry.state_group
context.delta_ids = { delta_ids = {
key: event.event_id key: event.event_id
} }
elif entry.prev_group: elif entry.prev_group:
# If the state at the event only has a prev group, then we can # If the state at the event only has a prev group, then we can
# use that as a prev group too. # use that as a prev group too.
context.prev_group = entry.prev_group prev_group = entry.prev_group
context.delta_ids = dict(entry.delta_ids) delta_ids = dict(entry.delta_ids)
context.delta_ids[key] = event.event_id delta_ids[key] = event.event_id
context.state_group = yield self.store.store_state_group( state_group = yield self.store.store_state_group(
event.event_id, event.event_id,
event.room_id, event.room_id,
prev_group=context.prev_group, prev_group=prev_group,
delta_ids=context.delta_ids, delta_ids=delta_ids,
current_state_ids=context.current_state_ids, current_state_ids=current_state_ids,
) )
else: else:
context.current_state_ids = context.prev_state_ids current_state_ids = prev_state_ids
context.prev_group = entry.prev_group prev_group = entry.prev_group
context.delta_ids = entry.delta_ids delta_ids = entry.delta_ids
if entry.state_group is None: if entry.state_group is None:
entry.state_group = yield self.store.store_state_group( entry.state_group = yield self.store.store_state_group(
@ -310,13 +316,20 @@ class StateHandler(object):
event.room_id, event.room_id,
prev_group=entry.prev_group, prev_group=entry.prev_group,
delta_ids=entry.delta_ids, delta_ids=entry.delta_ids,
current_state_ids=context.current_state_ids, current_state_ids=current_state_ids,
) )
entry.state_id = entry.state_group entry.state_id = entry.state_group
context.state_group = entry.state_group state_group = entry.state_group
context = EventContext.with_state(
state_group=state_group,
current_state_ids=current_state_ids,
prev_state_ids=prev_state_ids,
prev_group=prev_group,
delta_ids=delta_ids,
)
context.prev_state_events = []
defer.returnValue(context) defer.returnValue(context)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -647,7 +660,7 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory):
for event_id in event_ids for event_id in event_ids
) )
if event_map is not None: if event_map is not None:
needed_events -= set(event_map.iterkeys()) needed_events -= set(iterkeys(event_map))
logger.info("Asking for %d conflicted events", len(needed_events)) logger.info("Asking for %d conflicted events", len(needed_events))
@ -668,7 +681,7 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory):
new_needed_events = set(itervalues(auth_events)) new_needed_events = set(itervalues(auth_events))
new_needed_events -= needed_events new_needed_events -= needed_events
if event_map is not None: if event_map is not None:
new_needed_events -= set(event_map.iterkeys()) new_needed_events -= set(iterkeys(event_map))
logger.info("Asking for %d auth events", len(new_needed_events)) logger.info("Asking for %d auth events", len(new_needed_events))

View File

@ -248,6 +248,20 @@ class DeviceStore(SQLBaseStore):
def _update_remote_device_list_cache_entry_txn(self, txn, user_id, device_id, def _update_remote_device_list_cache_entry_txn(self, txn, user_id, device_id,
content, stream_id): content, stream_id):
if content.get("deleted"):
self._simple_delete_txn(
txn,
table="device_lists_remote_cache",
keyvalues={
"user_id": user_id,
"device_id": device_id,
},
)
txn.call_after(
self.device_id_exists_cache.invalidate, (user_id, device_id,)
)
else:
self._simple_upsert_txn( self._simple_upsert_txn(
txn, txn,
table="device_lists_remote_cache", table="device_lists_remote_cache",
@ -366,7 +380,7 @@ class DeviceStore(SQLBaseStore):
now_stream_id = max(stream_id for stream_id in itervalues(query_map)) now_stream_id = max(stream_id for stream_id in itervalues(query_map))
devices = self._get_e2e_device_keys_txn( devices = self._get_e2e_device_keys_txn(
txn, query_map.keys(), include_all_devices=True txn, query_map.keys(), include_all_devices=True, include_deleted_devices=True
) )
prev_sent_id_sql = """ prev_sent_id_sql = """
@ -393,12 +407,15 @@ class DeviceStore(SQLBaseStore):
prev_id = stream_id prev_id = stream_id
if device is not None:
key_json = device.get("key_json", None) key_json = device.get("key_json", None)
if key_json: if key_json:
result["keys"] = json.loads(key_json) result["keys"] = json.loads(key_json)
device_display_name = device.get("device_display_name", None) device_display_name = device.get("device_display_name", None)
if device_display_name: if device_display_name:
result["device_display_name"] = device_display_name result["device_display_name"] = device_display_name
else:
result["deleted"] = True
results.append(result) results.append(result)

View File

@ -64,12 +64,18 @@ class EndToEndKeyStore(SQLBaseStore):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def get_e2e_device_keys(self, query_list, include_all_devices=False): def get_e2e_device_keys(
self, query_list, include_all_devices=False,
include_deleted_devices=False,
):
"""Fetch a list of device keys. """Fetch a list of device keys.
Args: Args:
query_list(list): List of pairs of user_ids and device_ids. query_list(list): List of pairs of user_ids and device_ids.
include_all_devices (bool): whether to include entries for devices include_all_devices (bool): whether to include entries for devices
that don't have device keys that don't have device keys
include_deleted_devices (bool): whether to include null entries for
devices which no longer exist (but were in the query_list).
This option only takes effect if include_all_devices is true.
Returns: Returns:
Dict mapping from user-id to dict mapping from device_id to Dict mapping from user-id to dict mapping from device_id to
dict containing "key_json", "device_display_name". dict containing "key_json", "device_display_name".
@ -79,7 +85,7 @@ class EndToEndKeyStore(SQLBaseStore):
results = yield self.runInteraction( results = yield self.runInteraction(
"get_e2e_device_keys", self._get_e2e_device_keys_txn, "get_e2e_device_keys", self._get_e2e_device_keys_txn,
query_list, include_all_devices, query_list, include_all_devices, include_deleted_devices,
) )
for user_id, device_keys in iteritems(results): for user_id, device_keys in iteritems(results):
@ -88,10 +94,19 @@ class EndToEndKeyStore(SQLBaseStore):
defer.returnValue(results) defer.returnValue(results)
def _get_e2e_device_keys_txn(self, txn, query_list, include_all_devices): def _get_e2e_device_keys_txn(
self, txn, query_list, include_all_devices=False,
include_deleted_devices=False,
):
query_clauses = [] query_clauses = []
query_params = [] query_params = []
if include_all_devices is False:
include_deleted_devices = False
if include_deleted_devices:
deleted_devices = set(query_list)
for (user_id, device_id) in query_list: for (user_id, device_id) in query_list:
query_clause = "user_id = ?" query_clause = "user_id = ?"
query_params.append(user_id) query_params.append(user_id)
@ -119,8 +134,14 @@ class EndToEndKeyStore(SQLBaseStore):
result = {} result = {}
for row in rows: for row in rows:
if include_deleted_devices:
deleted_devices.remove((row["user_id"], row["device_id"]))
result.setdefault(row["user_id"], {})[row["device_id"]] = row result.setdefault(row["user_id"], {})[row["device_id"]] = row
if include_deleted_devices:
for user_id, device_id in deleted_devices:
result.setdefault(user_id, {})[device_id] = None
return result return result
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -549,7 +549,7 @@ class EventsStore(EventsWorkerStore):
if ctx.state_group in state_groups_map: if ctx.state_group in state_groups_map:
continue continue
state_groups_map[ctx.state_group] = ctx.current_state_ids state_groups_map[ctx.state_group] = yield ctx.get_current_state_ids(self)
# We need to map the event_ids to their state groups. First, let's # We need to map the event_ids to their state groups. First, let's
# check if the event is one we're persisting, in which case we can # check if the event is one we're persisting, in which case we can

View File

@ -185,6 +185,7 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
defer.returnValue(results) defer.returnValue(results)
@defer.inlineCallbacks
def bulk_get_push_rules_for_room(self, event, context): def bulk_get_push_rules_for_room(self, event, context):
state_group = context.state_group state_group = context.state_group
if not state_group: if not state_group:
@ -194,9 +195,11 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
# To do this we set the state_group to a new object as object() != object() # To do this we set the state_group to a new object as object() != object()
state_group = object() state_group = object()
return self._bulk_get_push_rules_for_room( current_state_ids = yield context.get_current_state_ids(self)
event.room_id, state_group, context.current_state_ids, event=event result = yield self._bulk_get_push_rules_for_room(
event.room_id, state_group, current_state_ids, event=event
) )
defer.returnValue(result)
@cachedInlineCallbacks(num_args=2, cache_context=True) @cachedInlineCallbacks(num_args=2, cache_context=True)
def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state_ids, def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state_ids,

View File

@ -232,6 +232,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
defer.returnValue(user_who_share_room) defer.returnValue(user_who_share_room)
@defer.inlineCallbacks
def get_joined_users_from_context(self, event, context): def get_joined_users_from_context(self, event, context):
state_group = context.state_group state_group = context.state_group
if not state_group: if not state_group:
@ -241,11 +242,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# To do this we set the state_group to a new object as object() != object() # To do this we set the state_group to a new object as object() != object()
state_group = object() state_group = object()
return self._get_joined_users_from_context( current_state_ids = yield context.get_current_state_ids(self)
event.room_id, state_group, context.current_state_ids, result = yield self._get_joined_users_from_context(
event.room_id, state_group, current_state_ids,
event=event, event=event,
context=context, context=context,
) )
defer.returnValue(result)
def get_joined_users_from_state(self, room_id, state_entry): def get_joined_users_from_state(self, room_id, state_entry):
state_group = state_entry.state_group state_group = state_entry.state_group

View File

@ -184,13 +184,13 @@ class Linearizer(object):
# key_to_defer is a map from the key to a 2 element list where # key_to_defer is a map from the key to a 2 element list where
# the first element is the number of things executing, and # the first element is the number of things executing, and
# the second element is a deque of deferreds for the things blocked from # the second element is an OrderedDict, where the keys are deferreds for the
# executing. # things blocked from executing.
self.key_to_defer = {} self.key_to_defer = {}
@defer.inlineCallbacks @defer.inlineCallbacks
def queue(self, key): def queue(self, key):
entry = self.key_to_defer.setdefault(key, [0, collections.deque()]) entry = self.key_to_defer.setdefault(key, [0, collections.OrderedDict()])
# If the number of things executing is greater than the maximum # If the number of things executing is greater than the maximum
# then add a deferred to the list of blocked items # then add a deferred to the list of blocked items
@ -198,12 +198,28 @@ class Linearizer(object):
# this item so that it can continue executing. # this item so that it can continue executing.
if entry[0] >= self.max_count: if entry[0] >= self.max_count:
new_defer = defer.Deferred() new_defer = defer.Deferred()
entry[1].append(new_defer) entry[1][new_defer] = 1
logger.info( logger.info(
"Waiting to acquire linearizer lock %r for key %r", self.name, key, "Waiting to acquire linearizer lock %r for key %r", self.name, key,
) )
try:
yield make_deferred_yieldable(new_defer) yield make_deferred_yieldable(new_defer)
except Exception as e:
if isinstance(e, CancelledError):
logger.info(
"Cancelling wait for linearizer lock %r for key %r",
self.name, key,
)
else:
logger.warn(
"Unexpected exception waiting for linearizer lock %r for key %r",
self.name, key,
)
# we just have to take ourselves back out of the queue.
del entry[1][new_defer]
raise
logger.info("Acquired linearizer lock %r for key %r", self.name, key) logger.info("Acquired linearizer lock %r for key %r", self.name, key)
entry[0] += 1 entry[0] += 1
@ -238,7 +254,7 @@ class Linearizer(object):
entry[0] -= 1 entry[0] -= 1
if entry[1]: if entry[1]:
next_def = entry[1].popleft() (next_def, _) = entry[1].popitem(last=False)
# we need to run the next thing in the sentinel context. # we need to run the next thing in the sentinel context.
with PreserveLoggingContext(): with PreserveLoggingContext():

View File

@ -12,11 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import itertools
import logging import logging
import operator import operator
import six from six import iteritems, itervalues
from six.moves import map
from twisted.internet import defer from twisted.internet import defer
@ -204,7 +205,7 @@ def filter_events_for_client(store, user_id, events, is_peeking=False,
return event return event
# check each event: gives an iterable[None|EventBase] # check each event: gives an iterable[None|EventBase]
filtered_events = itertools.imap(allowed, events) filtered_events = map(allowed, events)
# remove the None entries # remove the None entries
filtered_events = filter(operator.truth, filtered_events) filtered_events = filter(operator.truth, filtered_events)
@ -244,7 +245,7 @@ def filter_events_for_server(store, server_name, events):
# membership states for the requesting server to determine # membership states for the requesting server to determine
# if the server is either in the room or has been invited # if the server is either in the room or has been invited
# into the room. # into the room.
for ev in state.itervalues(): for ev in itervalues(state):
if ev.type != EventTypes.Member: if ev.type != EventTypes.Member:
continue continue
try: try:
@ -278,7 +279,7 @@ def filter_events_for_server(store, server_name, events):
) )
visibility_ids = set() visibility_ids = set()
for sids in event_to_state_ids.itervalues(): for sids in itervalues(event_to_state_ids):
hist = sids.get((EventTypes.RoomHistoryVisibility, "")) hist = sids.get((EventTypes.RoomHistoryVisibility, ""))
if hist: if hist:
visibility_ids.add(hist) visibility_ids.add(hist)
@ -291,7 +292,7 @@ def filter_events_for_server(store, server_name, events):
event_map = yield store.get_events(visibility_ids) event_map = yield store.get_events(visibility_ids)
all_open = all( all_open = all(
e.content.get("history_visibility") in (None, "shared", "world_readable") e.content.get("history_visibility") in (None, "shared", "world_readable")
for e in event_map.itervalues() for e in itervalues(event_map)
) )
if all_open: if all_open:
@ -329,7 +330,7 @@ def filter_events_for_server(store, server_name, events):
# #
state_key_to_event_id_set = { state_key_to_event_id_set = {
e e
for key_to_eid in six.itervalues(event_to_state_ids) for key_to_eid in itervalues(event_to_state_ids)
for e in key_to_eid.items() for e in key_to_eid.items()
} }
@ -352,10 +353,10 @@ def filter_events_for_server(store, server_name, events):
event_to_state = { event_to_state = {
e_id: { e_id: {
key: event_map[inner_e_id] key: event_map[inner_e_id]
for key, inner_e_id in key_to_eid.iteritems() for key, inner_e_id in iteritems(key_to_eid)
if inner_e_id in event_map if inner_e_id in event_map
} }
for e_id, key_to_eid in event_to_state_ids.iteritems() for e_id, key_to_eid in iteritems(event_to_state_ids)
} }
defer.returnValue([ defer.returnValue([

View File

@ -222,9 +222,11 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
state_ids = { state_ids = {
key: e.event_id for key, e in state.items() key: e.event_id for key, e in state.items()
} }
context = EventContext() context = EventContext.with_state(
context.current_state_ids = state_ids state_group=None,
context.prev_state_ids = state_ids current_state_ids=state_ids,
prev_state_ids=state_ids
)
else: else:
state_handler = self.hs.get_state_handler() state_handler = self.hs.get_state_handler()
context = yield state_handler.compute_event_context(event) context = yield state_handler.compute_event_context(event)

View File

@ -137,7 +137,6 @@ class MessageAcceptTests(unittest.TestCase):
) )
self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv") self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv")
@unittest.DEBUG
def test_cant_hide_past_history(self): def test_cant_hide_past_history(self):
""" """
If you send a message, you must be able to provide the direct If you send a message, you must be able to provide the direct
@ -178,7 +177,7 @@ class MessageAcceptTests(unittest.TestCase):
for x, y in d.items() for x, y in d.items()
if x == ("m.room.member", "@us:test") if x == ("m.room.member", "@us:test")
], ],
"auth_chain_ids": d.values(), "auth_chain_ids": list(d.values()),
} }
) )

View File

@ -204,7 +204,8 @@ class StateTestCase(unittest.TestCase):
self.store.register_event_context(event, context) self.store.register_event_context(event, context)
context_store[event.event_id] = context context_store[event.event_id] = context
self.assertEqual(2, len(context_store["D"].prev_state_ids)) prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
self.assertEqual(2, len(prev_state_ids))
@defer.inlineCallbacks @defer.inlineCallbacks
def test_branch_basic_conflict(self): def test_branch_basic_conflict(self):
@ -255,9 +256,11 @@ class StateTestCase(unittest.TestCase):
self.store.register_event_context(event, context) self.store.register_event_context(event, context)
context_store[event.event_id] = context context_store[event.event_id] = context
prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
self.assertSetEqual( self.assertSetEqual(
{"START", "A", "C"}, {"START", "A", "C"},
{e_id for e_id in context_store["D"].prev_state_ids.values()} {e_id for e_id in prev_state_ids.values()}
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -318,9 +321,11 @@ class StateTestCase(unittest.TestCase):
self.store.register_event_context(event, context) self.store.register_event_context(event, context)
context_store[event.event_id] = context context_store[event.event_id] = context
prev_state_ids = yield context_store["E"].get_prev_state_ids(self.store)
self.assertSetEqual( self.assertSetEqual(
{"START", "A", "B", "C"}, {"START", "A", "B", "C"},
{e for e in context_store["E"].prev_state_ids.values()} {e for e in prev_state_ids.values()}
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -398,9 +403,11 @@ class StateTestCase(unittest.TestCase):
self.store.register_event_context(event, context) self.store.register_event_context(event, context)
context_store[event.event_id] = context context_store[event.event_id] = context
prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
self.assertSetEqual( self.assertSetEqual(
{"A1", "A2", "A3", "A5", "B"}, {"A1", "A2", "A3", "A5", "B"},
{e for e in context_store["D"].prev_state_ids.values()} {e for e in prev_state_ids.values()}
) )
def _add_depths(self, nodes, edges): def _add_depths(self, nodes, edges):
@ -429,8 +436,10 @@ class StateTestCase(unittest.TestCase):
event, old_state=old_state event, old_state=old_state
) )
current_state_ids = yield context.get_current_state_ids(self.store)
self.assertEqual( self.assertEqual(
set(e.event_id for e in old_state), set(context.current_state_ids.values()) set(e.event_id for e in old_state), set(current_state_ids.values())
) )
self.assertIsNotNone(context.state_group) self.assertIsNotNone(context.state_group)
@ -449,8 +458,10 @@ class StateTestCase(unittest.TestCase):
event, old_state=old_state event, old_state=old_state
) )
prev_state_ids = yield context.get_prev_state_ids(self.store)
self.assertEqual( self.assertEqual(
set(e.event_id for e in old_state), set(context.prev_state_ids.values()) set(e.event_id for e in old_state), set(prev_state_ids.values())
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -475,9 +486,11 @@ class StateTestCase(unittest.TestCase):
context = yield self.state.compute_event_context(event) context = yield self.state.compute_event_context(event)
current_state_ids = yield context.get_current_state_ids(self.store)
self.assertEqual( self.assertEqual(
set([e.event_id for e in old_state]), set([e.event_id for e in old_state]),
set(context.current_state_ids.values()) set(current_state_ids.values())
) )
self.assertEqual(group_name, context.state_group) self.assertEqual(group_name, context.state_group)
@ -504,9 +517,11 @@ class StateTestCase(unittest.TestCase):
context = yield self.state.compute_event_context(event) context = yield self.state.compute_event_context(event)
prev_state_ids = yield context.get_prev_state_ids(self.store)
self.assertEqual( self.assertEqual(
set([e.event_id for e in old_state]), set([e.event_id for e in old_state]),
set(context.prev_state_ids.values()) set(prev_state_ids.values())
) )
self.assertIsNotNone(context.state_group) self.assertIsNotNone(context.state_group)
@ -545,7 +560,9 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2, event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
) )
self.assertEqual(len(context.current_state_ids), 6) current_state_ids = yield context.get_current_state_ids(self.store)
self.assertEqual(len(current_state_ids), 6)
self.assertIsNotNone(context.state_group) self.assertIsNotNone(context.state_group)
@ -585,7 +602,9 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2, event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
) )
self.assertEqual(len(context.current_state_ids), 6) current_state_ids = yield context.get_current_state_ids(self.store)
self.assertEqual(len(current_state_ids), 6)
self.assertIsNotNone(context.state_group) self.assertIsNotNone(context.state_group)
@ -642,8 +661,10 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2, event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
) )
current_state_ids = yield context.get_current_state_ids(self.store)
self.assertEqual( self.assertEqual(
old_state_2[3].event_id, context.current_state_ids[("test1", "1")] old_state_2[3].event_id, current_state_ids[("test1", "1")]
) )
# Reverse the depth to make sure we are actually using the depths # Reverse the depth to make sure we are actually using the depths
@ -670,8 +691,10 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2, event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
) )
current_state_ids = yield context.get_current_state_ids(self.store)
self.assertEqual( self.assertEqual(
old_state_1[3].event_id, context.current_state_ids[("test1", "1")] old_state_1[3].event_id, current_state_ids[("test1", "1")]
) )
def _get_context(self, event, prev_event_id_1, old_state_1, prev_event_id_2, def _get_context(self, event, prev_event_id_1, old_state_1, prev_event_id_2,

View File

@ -17,6 +17,7 @@
from six.moves import range from six.moves import range
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from twisted.internet.defer import CancelledError
from synapse.util import Clock, logcontext from synapse.util import Clock, logcontext
from synapse.util.async import Linearizer from synapse.util.async import Linearizer
@ -112,3 +113,33 @@ class LinearizerTestCase(unittest.TestCase):
d6 = limiter.queue(key) d6 = limiter.queue(key)
with (yield d6): with (yield d6):
pass pass
@defer.inlineCallbacks
def test_cancellation(self):
linearizer = Linearizer()
key = object()
d1 = linearizer.queue(key)
cm1 = yield d1
d2 = linearizer.queue(key)
self.assertFalse(d2.called)
d3 = linearizer.queue(key)
self.assertFalse(d3.called)
d2.cancel()
with cm1:
pass
self.assertTrue(d2.called)
try:
yield d2
self.fail("Expected d2 to raise CancelledError")
except CancelledError:
pass
with (yield d3):
pass