mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2024-12-11 03:04:19 -05:00
Merge remote-tracking branch 'origin/develop' into rav/remove_who_forgot_in_room
This commit is contained in:
commit
4f5cc8e4e7
1
changelog.d/3520.bugfix
Normal file
1
changelog.d/3520.bugfix
Normal file
@ -0,0 +1 @@
|
|||||||
|
Correctly announce deleted devices over federation
|
0
changelog.d/3562.misc
Normal file
0
changelog.d/3562.misc
Normal file
1
changelog.d/3572.misc
Normal file
1
changelog.d/3572.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Merge Linearizer and Limiter
|
0
changelog.d/3577.misc
Normal file
0
changelog.d/3577.misc
Normal file
1
changelog.d/3579.misc
Normal file
1
changelog.d/3579.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Lazily load state on master process when using workers to reduce DB consumption
|
1
changelog.d/3581.misc
Normal file
1
changelog.d/3581.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Lazily load state on master process when using workers to reduce DB consumption
|
1
changelog.d/3582.misc
Normal file
1
changelog.d/3582.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Lazily load state on master process when using workers to reduce DB consumption
|
@ -65,8 +65,9 @@ class Auth(object):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def check_from_context(self, event, context, do_sig_check=True):
|
def check_from_context(self, event, context, do_sig_check=True):
|
||||||
|
prev_state_ids = yield context.get_prev_state_ids(self.store)
|
||||||
auth_events_ids = yield self.compute_auth_events(
|
auth_events_ids = yield self.compute_auth_events(
|
||||||
event, context.prev_state_ids, for_verification=True,
|
event, prev_state_ids, for_verification=True,
|
||||||
)
|
)
|
||||||
auth_events = yield self.store.get_events(auth_events_ids)
|
auth_events = yield self.store.get_events(auth_events_ids)
|
||||||
auth_events = {
|
auth_events = {
|
||||||
@ -544,7 +545,8 @@ class Auth(object):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def add_auth_events(self, builder, context):
|
def add_auth_events(self, builder, context):
|
||||||
auth_ids = yield self.compute_auth_events(builder, context.prev_state_ids)
|
prev_state_ids = yield context.get_prev_state_ids(self.store)
|
||||||
|
auth_ids = yield self.compute_auth_events(builder, prev_state_ids)
|
||||||
|
|
||||||
auth_events_entries = yield self.store.add_event_hashes(
|
auth_events_entries = yield self.store.add_event_hashes(
|
||||||
auth_ids
|
auth_ids
|
||||||
|
@ -18,6 +18,8 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
from six import iteritems
|
||||||
|
|
||||||
from twisted.application import service
|
from twisted.application import service
|
||||||
from twisted.internet import defer, reactor
|
from twisted.internet import defer, reactor
|
||||||
from twisted.web.resource import EncodingResourceWrapper, NoResource
|
from twisted.web.resource import EncodingResourceWrapper, NoResource
|
||||||
@ -442,7 +444,7 @@ def run(hs):
|
|||||||
stats["total_nonbridged_users"] = total_nonbridged_users
|
stats["total_nonbridged_users"] = total_nonbridged_users
|
||||||
|
|
||||||
daily_user_type_results = yield hs.get_datastore().count_daily_user_type()
|
daily_user_type_results = yield hs.get_datastore().count_daily_user_type()
|
||||||
for name, count in daily_user_type_results.iteritems():
|
for name, count in iteritems(daily_user_type_results):
|
||||||
stats["daily_user_type_" + name] = count
|
stats["daily_user_type_" + name] = count
|
||||||
|
|
||||||
room_count = yield hs.get_datastore().get_room_count()
|
room_count = yield hs.get_datastore().get_room_count()
|
||||||
@ -453,7 +455,7 @@ def run(hs):
|
|||||||
stats["daily_messages"] = yield hs.get_datastore().count_daily_messages()
|
stats["daily_messages"] = yield hs.get_datastore().count_daily_messages()
|
||||||
|
|
||||||
r30_results = yield hs.get_datastore().count_r30_users()
|
r30_results = yield hs.get_datastore().count_r30_users()
|
||||||
for name, count in r30_results.iteritems():
|
for name, count in iteritems(r30_results):
|
||||||
stats["r30_users_" + name] = count
|
stats["r30_users_" + name] = count
|
||||||
|
|
||||||
daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages()
|
daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages()
|
||||||
|
@ -25,6 +25,8 @@ import subprocess
|
|||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
from six import iteritems
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
SYNAPSE = [sys.executable, "-B", "-m", "synapse.app.homeserver"]
|
SYNAPSE = [sys.executable, "-B", "-m", "synapse.app.homeserver"]
|
||||||
@ -173,7 +175,7 @@ def main():
|
|||||||
os.environ["SYNAPSE_CACHE_FACTOR"] = str(cache_factor)
|
os.environ["SYNAPSE_CACHE_FACTOR"] = str(cache_factor)
|
||||||
|
|
||||||
cache_factors = config.get("synctl_cache_factors", {})
|
cache_factors = config.get("synctl_cache_factors", {})
|
||||||
for cache_name, factor in cache_factors.iteritems():
|
for cache_name, factor in iteritems(cache_factors):
|
||||||
os.environ["SYNAPSE_CACHE_FACTOR_" + cache_name.upper()] = str(factor)
|
os.environ["SYNAPSE_CACHE_FACTOR_" + cache_name.upper()] = str(factor)
|
||||||
|
|
||||||
worker_configfiles = []
|
worker_configfiles = []
|
||||||
|
@ -13,22 +13,18 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from six import iteritems
|
||||||
|
|
||||||
from frozendict import frozendict
|
from frozendict import frozendict
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
|
||||||
|
|
||||||
|
|
||||||
class EventContext(object):
|
class EventContext(object):
|
||||||
"""
|
"""
|
||||||
Attributes:
|
Attributes:
|
||||||
current_state_ids (dict[(str, str), str]):
|
|
||||||
The current state map including the current event.
|
|
||||||
(type, state_key) -> event_id
|
|
||||||
|
|
||||||
prev_state_ids (dict[(str, str), str]):
|
|
||||||
The current state map excluding the current event.
|
|
||||||
(type, state_key) -> event_id
|
|
||||||
|
|
||||||
state_group (int|None): state group id, if the state has been stored
|
state_group (int|None): state group id, if the state has been stored
|
||||||
as a state group. This is usually only None if e.g. the event is
|
as a state group. This is usually only None if e.g. the event is
|
||||||
an outlier.
|
an outlier.
|
||||||
@ -45,38 +41,77 @@ class EventContext(object):
|
|||||||
|
|
||||||
prev_state_events (?): XXX: is this ever set to anything other than
|
prev_state_events (?): XXX: is this ever set to anything other than
|
||||||
the empty list?
|
the empty list?
|
||||||
|
|
||||||
|
_current_state_ids (dict[(str, str), str]|None):
|
||||||
|
The current state map including the current event. None if outlier
|
||||||
|
or we haven't fetched the state from DB yet.
|
||||||
|
(type, state_key) -> event_id
|
||||||
|
|
||||||
|
_prev_state_ids (dict[(str, str), str]|None):
|
||||||
|
The current state map excluding the current event. None if outlier
|
||||||
|
or we haven't fetched the state from DB yet.
|
||||||
|
(type, state_key) -> event_id
|
||||||
|
|
||||||
|
_fetching_state_deferred (Deferred|None): Resolves when *_state_ids have
|
||||||
|
been calculated. None if we haven't started calculating yet
|
||||||
|
|
||||||
|
_event_type (str): The type of the event the context is associated with.
|
||||||
|
Only set when state has not been fetched yet.
|
||||||
|
|
||||||
|
_event_state_key (str|None): The state_key of the event the context is
|
||||||
|
associated with. Only set when state has not been fetched yet.
|
||||||
|
|
||||||
|
_prev_state_id (str|None): If the event associated with the context is
|
||||||
|
a state event, then `_prev_state_id` is the event_id of the state
|
||||||
|
that was replaced.
|
||||||
|
Only set when state has not been fetched yet.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__slots__ = [
|
__slots__ = [
|
||||||
"current_state_ids",
|
|
||||||
"prev_state_ids",
|
|
||||||
"state_group",
|
"state_group",
|
||||||
"rejected",
|
"rejected",
|
||||||
"prev_group",
|
"prev_group",
|
||||||
"delta_ids",
|
"delta_ids",
|
||||||
"prev_state_events",
|
"prev_state_events",
|
||||||
"app_service",
|
"app_service",
|
||||||
|
"_current_state_ids",
|
||||||
|
"_prev_state_ids",
|
||||||
|
"_prev_state_id",
|
||||||
|
"_event_type",
|
||||||
|
"_event_state_key",
|
||||||
|
"_fetching_state_deferred",
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# The current state including the current event
|
self.prev_state_events = []
|
||||||
self.current_state_ids = None
|
|
||||||
# The current state excluding the current event
|
|
||||||
self.prev_state_ids = None
|
|
||||||
self.state_group = None
|
|
||||||
|
|
||||||
self.rejected = False
|
self.rejected = False
|
||||||
|
self.app_service = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def with_state(state_group, current_state_ids, prev_state_ids,
|
||||||
|
prev_group=None, delta_ids=None):
|
||||||
|
context = EventContext()
|
||||||
|
|
||||||
|
# The current state including the current event
|
||||||
|
context._current_state_ids = current_state_ids
|
||||||
|
# The current state excluding the current event
|
||||||
|
context._prev_state_ids = prev_state_ids
|
||||||
|
context.state_group = state_group
|
||||||
|
|
||||||
|
context._prev_state_id = None
|
||||||
|
context._event_type = None
|
||||||
|
context._event_state_key = None
|
||||||
|
context._fetching_state_deferred = defer.succeed(None)
|
||||||
|
|
||||||
# A previously persisted state group and a delta between that
|
# A previously persisted state group and a delta between that
|
||||||
# and this state.
|
# and this state.
|
||||||
self.prev_group = None
|
context.prev_group = prev_group
|
||||||
self.delta_ids = None
|
context.delta_ids = delta_ids
|
||||||
|
|
||||||
self.prev_state_events = None
|
return context
|
||||||
|
|
||||||
self.app_service = None
|
@defer.inlineCallbacks
|
||||||
|
def serialize(self, event, store):
|
||||||
def serialize(self, event):
|
|
||||||
"""Converts self to a type that can be serialized as JSON, and then
|
"""Converts self to a type that can be serialized as JSON, and then
|
||||||
deserialized by `deserialize`
|
deserialized by `deserialize`
|
||||||
|
|
||||||
@ -92,11 +127,12 @@ class EventContext(object):
|
|||||||
# the prev_state_ids, so if we're a state event we include the event
|
# the prev_state_ids, so if we're a state event we include the event
|
||||||
# id that we replaced in the state.
|
# id that we replaced in the state.
|
||||||
if event.is_state():
|
if event.is_state():
|
||||||
prev_state_id = self.prev_state_ids.get((event.type, event.state_key))
|
prev_state_ids = yield self.get_prev_state_ids(store)
|
||||||
|
prev_state_id = prev_state_ids.get((event.type, event.state_key))
|
||||||
else:
|
else:
|
||||||
prev_state_id = None
|
prev_state_id = None
|
||||||
|
|
||||||
return {
|
defer.returnValue({
|
||||||
"prev_state_id": prev_state_id,
|
"prev_state_id": prev_state_id,
|
||||||
"event_type": event.type,
|
"event_type": event.type,
|
||||||
"event_state_key": event.state_key if event.is_state() else None,
|
"event_state_key": event.state_key if event.is_state() else None,
|
||||||
@ -106,10 +142,9 @@ class EventContext(object):
|
|||||||
"delta_ids": _encode_state_dict(self.delta_ids),
|
"delta_ids": _encode_state_dict(self.delta_ids),
|
||||||
"prev_state_events": self.prev_state_events,
|
"prev_state_events": self.prev_state_events,
|
||||||
"app_service_id": self.app_service.id if self.app_service else None
|
"app_service_id": self.app_service.id if self.app_service else None
|
||||||
}
|
})
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@defer.inlineCallbacks
|
|
||||||
def deserialize(store, input):
|
def deserialize(store, input):
|
||||||
"""Converts a dict that was produced by `serialize` back into a
|
"""Converts a dict that was produced by `serialize` back into a
|
||||||
EventContext.
|
EventContext.
|
||||||
@ -122,32 +157,100 @@ class EventContext(object):
|
|||||||
EventContext
|
EventContext
|
||||||
"""
|
"""
|
||||||
context = EventContext()
|
context = EventContext()
|
||||||
context.state_group = input["state_group"]
|
|
||||||
context.rejected = input["rejected"]
|
|
||||||
context.prev_group = input["prev_group"]
|
|
||||||
context.delta_ids = _decode_state_dict(input["delta_ids"])
|
|
||||||
context.prev_state_events = input["prev_state_events"]
|
|
||||||
|
|
||||||
# We use the state_group and prev_state_id stuff to pull the
|
# We use the state_group and prev_state_id stuff to pull the
|
||||||
# current_state_ids out of the DB and construct prev_state_ids.
|
# current_state_ids out of the DB and construct prev_state_ids.
|
||||||
prev_state_id = input["prev_state_id"]
|
context._prev_state_id = input["prev_state_id"]
|
||||||
event_type = input["event_type"]
|
context._event_type = input["event_type"]
|
||||||
event_state_key = input["event_state_key"]
|
context._event_state_key = input["event_state_key"]
|
||||||
|
context._fetching_state_deferred = None
|
||||||
|
|
||||||
context.current_state_ids = yield store.get_state_ids_for_group(
|
context.state_group = input["state_group"]
|
||||||
context.state_group,
|
context.prev_group = input["prev_group"]
|
||||||
)
|
context.delta_ids = _decode_state_dict(input["delta_ids"])
|
||||||
if prev_state_id and event_state_key:
|
|
||||||
context.prev_state_ids = dict(context.current_state_ids)
|
context.rejected = input["rejected"]
|
||||||
context.prev_state_ids[(event_type, event_state_key)] = prev_state_id
|
context.prev_state_events = input["prev_state_events"]
|
||||||
else:
|
|
||||||
context.prev_state_ids = context.current_state_ids
|
|
||||||
|
|
||||||
app_service_id = input["app_service_id"]
|
app_service_id = input["app_service_id"]
|
||||||
if app_service_id:
|
if app_service_id:
|
||||||
context.app_service = store.get_app_service_by_id(app_service_id)
|
context.app_service = store.get_app_service_by_id(app_service_id)
|
||||||
|
|
||||||
defer.returnValue(context)
|
return context
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_current_state_ids(self, store):
|
||||||
|
"""Gets the current state IDs
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred[dict[(str, str), str]|None]: Returns None if state_group
|
||||||
|
is None, which happens when the associated event is an outlier.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not self._fetching_state_deferred:
|
||||||
|
self._fetching_state_deferred = run_in_background(
|
||||||
|
self._fill_out_state, store,
|
||||||
|
)
|
||||||
|
|
||||||
|
yield make_deferred_yieldable(self._fetching_state_deferred)
|
||||||
|
|
||||||
|
defer.returnValue(self._current_state_ids)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_prev_state_ids(self, store):
|
||||||
|
"""Gets the prev state IDs
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred[dict[(str, str), str]|None]: Returns None if state_group
|
||||||
|
is None, which happens when the associated event is an outlier.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not self._fetching_state_deferred:
|
||||||
|
self._fetching_state_deferred = run_in_background(
|
||||||
|
self._fill_out_state, store,
|
||||||
|
)
|
||||||
|
|
||||||
|
yield make_deferred_yieldable(self._fetching_state_deferred)
|
||||||
|
|
||||||
|
defer.returnValue(self._prev_state_ids)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _fill_out_state(self, store):
|
||||||
|
"""Called to populate the _current_state_ids and _prev_state_ids
|
||||||
|
attributes by loading from the database.
|
||||||
|
"""
|
||||||
|
if self.state_group is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._current_state_ids = yield store.get_state_ids_for_group(
|
||||||
|
self.state_group,
|
||||||
|
)
|
||||||
|
if self._prev_state_id and self._event_state_key is not None:
|
||||||
|
self._prev_state_ids = dict(self._current_state_ids)
|
||||||
|
|
||||||
|
key = (self._event_type, self._event_state_key)
|
||||||
|
self._prev_state_ids[key] = self._prev_state_id
|
||||||
|
else:
|
||||||
|
self._prev_state_ids = self._current_state_ids
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def update_state(self, state_group, prev_state_ids, current_state_ids,
|
||||||
|
delta_ids):
|
||||||
|
"""Replace the state in the context
|
||||||
|
"""
|
||||||
|
|
||||||
|
# We need to make sure we wait for any ongoing fetching of state
|
||||||
|
# to complete so that the updated state doesn't get clobbered
|
||||||
|
if self._fetching_state_deferred:
|
||||||
|
yield make_deferred_yieldable(self._fetching_state_deferred)
|
||||||
|
|
||||||
|
self.state_group = state_group
|
||||||
|
self._prev_state_ids = prev_state_ids
|
||||||
|
self._current_state_ids = current_state_ids
|
||||||
|
self.delta_ids = delta_ids
|
||||||
|
|
||||||
|
# We need to ensure that that we've marked as having fetched the state
|
||||||
|
self._fetching_state_deferred = defer.succeed(None)
|
||||||
|
|
||||||
|
|
||||||
def _encode_state_dict(state_dict):
|
def _encode_state_dict(state_dict):
|
||||||
@ -159,7 +262,7 @@ def _encode_state_dict(state_dict):
|
|||||||
|
|
||||||
return [
|
return [
|
||||||
(etype, state_key, v)
|
(etype, state_key, v)
|
||||||
for (etype, state_key), v in state_dict.iteritems()
|
for (etype, state_key), v in iteritems(state_dict)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -112,8 +112,9 @@ class BaseHandler(object):
|
|||||||
guest_access = event.content.get("guest_access", "forbidden")
|
guest_access = event.content.get("guest_access", "forbidden")
|
||||||
if guest_access != "can_join":
|
if guest_access != "can_join":
|
||||||
if context:
|
if context:
|
||||||
|
current_state_ids = yield context.get_current_state_ids(self.store)
|
||||||
current_state = yield self.store.get_events(
|
current_state = yield self.store.get_events(
|
||||||
list(context.current_state_ids.values())
|
list(current_state_ids.values())
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
current_state = yield self.state_handler.get_current_state(
|
current_state = yield self.state_handler.get_current_state(
|
||||||
|
@ -21,8 +21,8 @@ import logging
|
|||||||
import sys
|
import sys
|
||||||
|
|
||||||
import six
|
import six
|
||||||
from six import iteritems
|
from six import iteritems, itervalues
|
||||||
from six.moves import http_client
|
from six.moves import http_client, zip
|
||||||
|
|
||||||
from signedjson.key import decode_verify_key_bytes
|
from signedjson.key import decode_verify_key_bytes
|
||||||
from signedjson.sign import verify_signed_json
|
from signedjson.sign import verify_signed_json
|
||||||
@ -486,7 +486,10 @@ class FederationHandler(BaseHandler):
|
|||||||
# joined the room. Don't bother if the user is just
|
# joined the room. Don't bother if the user is just
|
||||||
# changing their profile info.
|
# changing their profile info.
|
||||||
newly_joined = True
|
newly_joined = True
|
||||||
prev_state_id = context.prev_state_ids.get(
|
|
||||||
|
prev_state_ids = yield context.get_prev_state_ids(self.store)
|
||||||
|
|
||||||
|
prev_state_id = prev_state_ids.get(
|
||||||
(event.type, event.state_key)
|
(event.type, event.state_key)
|
||||||
)
|
)
|
||||||
if prev_state_id:
|
if prev_state_id:
|
||||||
@ -731,7 +734,7 @@ class FederationHandler(BaseHandler):
|
|||||||
"""
|
"""
|
||||||
joined_users = [
|
joined_users = [
|
||||||
(state_key, int(event.depth))
|
(state_key, int(event.depth))
|
||||||
for (e_type, state_key), event in state.iteritems()
|
for (e_type, state_key), event in iteritems(state)
|
||||||
if e_type == EventTypes.Member
|
if e_type == EventTypes.Member
|
||||||
and event.membership == Membership.JOIN
|
and event.membership == Membership.JOIN
|
||||||
]
|
]
|
||||||
@ -748,7 +751,7 @@ class FederationHandler(BaseHandler):
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return sorted(joined_domains.iteritems(), key=lambda d: d[1])
|
return sorted(joined_domains.items(), key=lambda d: d[1])
|
||||||
|
|
||||||
curr_domains = get_domains_from_state(curr_state)
|
curr_domains = get_domains_from_state(curr_state)
|
||||||
|
|
||||||
@ -811,7 +814,7 @@ class FederationHandler(BaseHandler):
|
|||||||
tried_domains = set(likely_domains)
|
tried_domains = set(likely_domains)
|
||||||
tried_domains.add(self.server_name)
|
tried_domains.add(self.server_name)
|
||||||
|
|
||||||
event_ids = list(extremities.iterkeys())
|
event_ids = list(extremities.keys())
|
||||||
|
|
||||||
logger.debug("calling resolve_state_groups in _maybe_backfill")
|
logger.debug("calling resolve_state_groups in _maybe_backfill")
|
||||||
resolve = logcontext.preserve_fn(
|
resolve = logcontext.preserve_fn(
|
||||||
@ -827,15 +830,15 @@ class FederationHandler(BaseHandler):
|
|||||||
states = dict(zip(event_ids, [s.state for s in states]))
|
states = dict(zip(event_ids, [s.state for s in states]))
|
||||||
|
|
||||||
state_map = yield self.store.get_events(
|
state_map = yield self.store.get_events(
|
||||||
[e_id for ids in states.itervalues() for e_id in ids.itervalues()],
|
[e_id for ids in itervalues(states) for e_id in itervalues(ids)],
|
||||||
get_prev_content=False
|
get_prev_content=False
|
||||||
)
|
)
|
||||||
states = {
|
states = {
|
||||||
key: {
|
key: {
|
||||||
k: state_map[e_id]
|
k: state_map[e_id]
|
||||||
for k, e_id in state_dict.iteritems()
|
for k, e_id in iteritems(state_dict)
|
||||||
if e_id in state_map
|
if e_id in state_map
|
||||||
} for key, state_dict in states.iteritems()
|
} for key, state_dict in iteritems(states)
|
||||||
}
|
}
|
||||||
|
|
||||||
for e_id, _ in sorted_extremeties_tuple:
|
for e_id, _ in sorted_extremeties_tuple:
|
||||||
@ -1106,10 +1109,12 @@ class FederationHandler(BaseHandler):
|
|||||||
user = UserID.from_string(event.state_key)
|
user = UserID.from_string(event.state_key)
|
||||||
yield user_joined_room(self.distributor, user, event.room_id)
|
yield user_joined_room(self.distributor, user, event.room_id)
|
||||||
|
|
||||||
state_ids = list(context.prev_state_ids.values())
|
prev_state_ids = yield context.get_prev_state_ids(self.store)
|
||||||
|
|
||||||
|
state_ids = list(prev_state_ids.values())
|
||||||
auth_chain = yield self.store.get_auth_chain(state_ids)
|
auth_chain = yield self.store.get_auth_chain(state_ids)
|
||||||
|
|
||||||
state = yield self.store.get_events(list(context.prev_state_ids.values()))
|
state = yield self.store.get_events(list(prev_state_ids.values()))
|
||||||
|
|
||||||
defer.returnValue({
|
defer.returnValue({
|
||||||
"state": list(state.values()),
|
"state": list(state.values()),
|
||||||
@ -1515,7 +1520,7 @@ class FederationHandler(BaseHandler):
|
|||||||
yield self.store.persist_events(
|
yield self.store.persist_events(
|
||||||
[
|
[
|
||||||
(ev_info["event"], context)
|
(ev_info["event"], context)
|
||||||
for ev_info, context in itertools.izip(event_infos, contexts)
|
for ev_info, context in zip(event_infos, contexts)
|
||||||
],
|
],
|
||||||
backfilled=backfilled,
|
backfilled=backfilled,
|
||||||
)
|
)
|
||||||
@ -1635,8 +1640,9 @@ class FederationHandler(BaseHandler):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if not auth_events:
|
if not auth_events:
|
||||||
|
prev_state_ids = yield context.get_prev_state_ids(self.store)
|
||||||
auth_events_ids = yield self.auth.compute_auth_events(
|
auth_events_ids = yield self.auth.compute_auth_events(
|
||||||
event, context.prev_state_ids, for_verification=True,
|
event, prev_state_ids, for_verification=True,
|
||||||
)
|
)
|
||||||
auth_events = yield self.store.get_events(auth_events_ids)
|
auth_events = yield self.store.get_events(auth_events_ids)
|
||||||
auth_events = {
|
auth_events = {
|
||||||
@ -1876,9 +1882,10 @@ class FederationHandler(BaseHandler):
|
|||||||
break
|
break
|
||||||
|
|
||||||
if do_resolution:
|
if do_resolution:
|
||||||
|
prev_state_ids = yield context.get_prev_state_ids(self.store)
|
||||||
# 1. Get what we think is the auth chain.
|
# 1. Get what we think is the auth chain.
|
||||||
auth_ids = yield self.auth.compute_auth_events(
|
auth_ids = yield self.auth.compute_auth_events(
|
||||||
event, context.prev_state_ids
|
event, prev_state_ids
|
||||||
)
|
)
|
||||||
local_auth_chain = yield self.store.get_auth_chain(
|
local_auth_chain = yield self.store.get_auth_chain(
|
||||||
auth_ids, include_given=True
|
auth_ids, include_given=True
|
||||||
@ -1968,21 +1975,35 @@ class FederationHandler(BaseHandler):
|
|||||||
k: a.event_id for k, a in iteritems(auth_events)
|
k: a.event_id for k, a in iteritems(auth_events)
|
||||||
if k != event_key
|
if k != event_key
|
||||||
}
|
}
|
||||||
context.current_state_ids = dict(context.current_state_ids)
|
current_state_ids = yield context.get_current_state_ids(self.store)
|
||||||
context.current_state_ids.update(state_updates)
|
current_state_ids = dict(current_state_ids)
|
||||||
|
|
||||||
|
current_state_ids.update(state_updates)
|
||||||
|
|
||||||
if context.delta_ids is not None:
|
if context.delta_ids is not None:
|
||||||
context.delta_ids = dict(context.delta_ids)
|
delta_ids = dict(context.delta_ids)
|
||||||
context.delta_ids.update(state_updates)
|
delta_ids.update(state_updates)
|
||||||
context.prev_state_ids = dict(context.prev_state_ids)
|
|
||||||
context.prev_state_ids.update({
|
prev_state_ids = yield context.get_prev_state_ids(self.store)
|
||||||
|
prev_state_ids = dict(prev_state_ids)
|
||||||
|
|
||||||
|
prev_state_ids.update({
|
||||||
k: a.event_id for k, a in iteritems(auth_events)
|
k: a.event_id for k, a in iteritems(auth_events)
|
||||||
})
|
})
|
||||||
context.state_group = yield self.store.store_state_group(
|
|
||||||
|
state_group = yield self.store.store_state_group(
|
||||||
event.event_id,
|
event.event_id,
|
||||||
event.room_id,
|
event.room_id,
|
||||||
prev_group=context.prev_group,
|
prev_group=context.prev_group,
|
||||||
delta_ids=context.delta_ids,
|
delta_ids=delta_ids,
|
||||||
current_state_ids=context.current_state_ids,
|
current_state_ids=current_state_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
yield context.update_state(
|
||||||
|
state_group=state_group,
|
||||||
|
current_state_ids=current_state_ids,
|
||||||
|
prev_state_ids=prev_state_ids,
|
||||||
|
delta_ids=delta_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@ -2222,7 +2243,8 @@ class FederationHandler(BaseHandler):
|
|||||||
event.content["third_party_invite"]["signed"]["token"]
|
event.content["third_party_invite"]["signed"]["token"]
|
||||||
)
|
)
|
||||||
original_invite = None
|
original_invite = None
|
||||||
original_invite_id = context.prev_state_ids.get(key)
|
prev_state_ids = yield context.get_prev_state_ids(self.store)
|
||||||
|
original_invite_id = prev_state_ids.get(key)
|
||||||
if original_invite_id:
|
if original_invite_id:
|
||||||
original_invite = yield self.store.get_event(
|
original_invite = yield self.store.get_event(
|
||||||
original_invite_id, allow_none=True
|
original_invite_id, allow_none=True
|
||||||
@ -2264,7 +2286,8 @@ class FederationHandler(BaseHandler):
|
|||||||
signed = event.content["third_party_invite"]["signed"]
|
signed = event.content["third_party_invite"]["signed"]
|
||||||
token = signed["token"]
|
token = signed["token"]
|
||||||
|
|
||||||
invite_event_id = context.prev_state_ids.get(
|
prev_state_ids = yield context.get_prev_state_ids(self.store)
|
||||||
|
invite_event_id = prev_state_ids.get(
|
||||||
(EventTypes.ThirdPartyInvite, token,)
|
(EventTypes.ThirdPartyInvite, token,)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -630,7 +630,8 @@ class EventCreationHandler(object):
|
|||||||
If so, returns the version of the event in context.
|
If so, returns the version of the event in context.
|
||||||
Otherwise, returns None.
|
Otherwise, returns None.
|
||||||
"""
|
"""
|
||||||
prev_event_id = context.prev_state_ids.get((event.type, event.state_key))
|
prev_state_ids = yield context.get_prev_state_ids(self.store)
|
||||||
|
prev_event_id = prev_state_ids.get((event.type, event.state_key))
|
||||||
prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
|
prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
|
||||||
if not prev_event:
|
if not prev_event:
|
||||||
return
|
return
|
||||||
@ -752,8 +753,8 @@ class EventCreationHandler(object):
|
|||||||
event = builder.build()
|
event = builder.build()
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Created event %s with state: %s",
|
"Created event %s",
|
||||||
event.event_id, context.prev_state_ids,
|
event.event_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue(
|
defer.returnValue(
|
||||||
@ -806,8 +807,9 @@ class EventCreationHandler(object):
|
|||||||
# If we're a worker we need to hit out to the master.
|
# If we're a worker we need to hit out to the master.
|
||||||
if self.config.worker_app:
|
if self.config.worker_app:
|
||||||
yield send_event_to_master(
|
yield send_event_to_master(
|
||||||
self.hs.get_clock(),
|
clock=self.hs.get_clock(),
|
||||||
self.http_client,
|
store=self.store,
|
||||||
|
client=self.http_client,
|
||||||
host=self.config.worker_replication_host,
|
host=self.config.worker_replication_host,
|
||||||
port=self.config.worker_replication_http_port,
|
port=self.config.worker_replication_http_port,
|
||||||
requester=requester,
|
requester=requester,
|
||||||
@ -884,9 +886,11 @@ class EventCreationHandler(object):
|
|||||||
e.sender == event.sender
|
e.sender == event.sender
|
||||||
)
|
)
|
||||||
|
|
||||||
|
current_state_ids = yield context.get_current_state_ids(self.store)
|
||||||
|
|
||||||
state_to_include_ids = [
|
state_to_include_ids = [
|
||||||
e_id
|
e_id
|
||||||
for k, e_id in iteritems(context.current_state_ids)
|
for k, e_id in iteritems(current_state_ids)
|
||||||
if k[0] in self.hs.config.room_invite_state_types
|
if k[0] in self.hs.config.room_invite_state_types
|
||||||
or k == (EventTypes.Member, event.sender)
|
or k == (EventTypes.Member, event.sender)
|
||||||
]
|
]
|
||||||
@ -922,8 +926,9 @@ class EventCreationHandler(object):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if event.type == EventTypes.Redaction:
|
if event.type == EventTypes.Redaction:
|
||||||
|
prev_state_ids = yield context.get_prev_state_ids(self.store)
|
||||||
auth_events_ids = yield self.auth.compute_auth_events(
|
auth_events_ids = yield self.auth.compute_auth_events(
|
||||||
event, context.prev_state_ids, for_verification=True,
|
event, prev_state_ids, for_verification=True,
|
||||||
)
|
)
|
||||||
auth_events = yield self.store.get_events(auth_events_ids)
|
auth_events = yield self.store.get_events(auth_events_ids)
|
||||||
auth_events = {
|
auth_events = {
|
||||||
@ -943,7 +948,9 @@ class EventCreationHandler(object):
|
|||||||
"You don't have permission to redact events"
|
"You don't have permission to redact events"
|
||||||
)
|
)
|
||||||
|
|
||||||
if event.type == EventTypes.Create and context.prev_state_ids:
|
if event.type == EventTypes.Create:
|
||||||
|
prev_state_ids = yield context.get_prev_state_ids(self.store)
|
||||||
|
if prev_state_ids:
|
||||||
raise AuthError(
|
raise AuthError(
|
||||||
403,
|
403,
|
||||||
"Changing the room create event is forbidden",
|
"Changing the room create event is forbidden",
|
||||||
|
@ -201,7 +201,9 @@ class RoomMemberHandler(object):
|
|||||||
ratelimit=ratelimit,
|
ratelimit=ratelimit,
|
||||||
)
|
)
|
||||||
|
|
||||||
prev_member_event_id = context.prev_state_ids.get(
|
prev_state_ids = yield context.get_prev_state_ids(self.store)
|
||||||
|
|
||||||
|
prev_member_event_id = prev_state_ids.get(
|
||||||
(EventTypes.Member, target.to_string()),
|
(EventTypes.Member, target.to_string()),
|
||||||
None
|
None
|
||||||
)
|
)
|
||||||
@ -496,9 +498,10 @@ class RoomMemberHandler(object):
|
|||||||
if prev_event is not None:
|
if prev_event is not None:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
prev_state_ids = yield context.get_prev_state_ids(self.store)
|
||||||
if event.membership == Membership.JOIN:
|
if event.membership == Membership.JOIN:
|
||||||
if requester.is_guest:
|
if requester.is_guest:
|
||||||
guest_can_join = yield self._can_guest_join(context.prev_state_ids)
|
guest_can_join = yield self._can_guest_join(prev_state_ids)
|
||||||
if not guest_can_join:
|
if not guest_can_join:
|
||||||
# This should be an auth check, but guests are a local concept,
|
# This should be an auth check, but guests are a local concept,
|
||||||
# so don't really fit into the general auth process.
|
# so don't really fit into the general auth process.
|
||||||
@ -517,7 +520,7 @@ class RoomMemberHandler(object):
|
|||||||
ratelimit=ratelimit,
|
ratelimit=ratelimit,
|
||||||
)
|
)
|
||||||
|
|
||||||
prev_member_event_id = context.prev_state_ids.get(
|
prev_member_event_id = prev_state_ids.get(
|
||||||
(EventTypes.Member, event.state_key),
|
(EventTypes.Member, event.state_key),
|
||||||
None
|
None
|
||||||
)
|
)
|
||||||
|
@ -274,7 +274,7 @@ class Notifier(object):
|
|||||||
logger.exception("Error notifying application services of event")
|
logger.exception("Error notifying application services of event")
|
||||||
|
|
||||||
def on_new_event(self, stream_key, new_token, users=[], rooms=[]):
|
def on_new_event(self, stream_key, new_token, users=[], rooms=[]):
|
||||||
""" Used to inform listeners that something has happend event wise.
|
""" Used to inform listeners that something has happened event wise.
|
||||||
|
|
||||||
Will wake up all listeners for the given users and rooms.
|
Will wake up all listeners for the given users and rooms.
|
||||||
"""
|
"""
|
||||||
|
@ -112,7 +112,8 @@ class BulkPushRuleEvaluator(object):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _get_power_levels_and_sender_level(self, event, context):
|
def _get_power_levels_and_sender_level(self, event, context):
|
||||||
pl_event_id = context.prev_state_ids.get(POWER_KEY)
|
prev_state_ids = yield context.get_prev_state_ids(self.store)
|
||||||
|
pl_event_id = prev_state_ids.get(POWER_KEY)
|
||||||
if pl_event_id:
|
if pl_event_id:
|
||||||
# fastpath: if there's a power level event, that's all we need, and
|
# fastpath: if there's a power level event, that's all we need, and
|
||||||
# not having a power level event is an extreme edge case
|
# not having a power level event is an extreme edge case
|
||||||
@ -120,7 +121,7 @@ class BulkPushRuleEvaluator(object):
|
|||||||
auth_events = {POWER_KEY: pl_event}
|
auth_events = {POWER_KEY: pl_event}
|
||||||
else:
|
else:
|
||||||
auth_events_ids = yield self.auth.compute_auth_events(
|
auth_events_ids = yield self.auth.compute_auth_events(
|
||||||
event, context.prev_state_ids, for_verification=False,
|
event, prev_state_ids, for_verification=False,
|
||||||
)
|
)
|
||||||
auth_events = yield self.store.get_events(auth_events_ids)
|
auth_events = yield self.store.get_events(auth_events_ids)
|
||||||
auth_events = {
|
auth_events = {
|
||||||
@ -304,7 +305,7 @@ class RulesForRoom(object):
|
|||||||
|
|
||||||
push_rules_delta_state_cache_metric.inc_hits()
|
push_rules_delta_state_cache_metric.inc_hits()
|
||||||
else:
|
else:
|
||||||
current_state_ids = context.current_state_ids
|
current_state_ids = yield context.get_current_state_ids(self.store)
|
||||||
push_rules_delta_state_cache_metric.inc_misses()
|
push_rules_delta_state_cache_metric.inc_misses()
|
||||||
|
|
||||||
push_rules_state_size_counter.inc(len(current_state_ids))
|
push_rules_state_size_counter.inc(len(current_state_ids))
|
||||||
|
@ -34,12 +34,13 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def send_event_to_master(clock, client, host, port, requester, event, context,
|
def send_event_to_master(clock, store, client, host, port, requester, event, context,
|
||||||
ratelimit, extra_users):
|
ratelimit, extra_users):
|
||||||
"""Send event to be handled on the master
|
"""Send event to be handled on the master
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
clock (synapse.util.Clock)
|
clock (synapse.util.Clock)
|
||||||
|
store (DataStore)
|
||||||
client (SimpleHttpClient)
|
client (SimpleHttpClient)
|
||||||
host (str): host of master
|
host (str): host of master
|
||||||
port (int): port on master listening for HTTP replication
|
port (int): port on master listening for HTTP replication
|
||||||
@ -53,11 +54,13 @@ def send_event_to_master(clock, client, host, port, requester, event, context,
|
|||||||
host, port, event.event_id,
|
host, port, event.event_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
serialized_context = yield context.serialize(event, store)
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"event": event.get_pdu_json(),
|
"event": event.get_pdu_json(),
|
||||||
"internal_metadata": event.internal_metadata.get_dict(),
|
"internal_metadata": event.internal_metadata.get_dict(),
|
||||||
"rejected_reason": event.rejected_reason,
|
"rejected_reason": event.rejected_reason,
|
||||||
"context": context.serialize(event),
|
"context": serialized_context,
|
||||||
"requester": requester.serialize(),
|
"requester": requester.serialize(),
|
||||||
"ratelimit": ratelimit,
|
"ratelimit": ratelimit,
|
||||||
"extra_users": [u.to_string() for u in extra_users],
|
"extra_users": [u.to_string() for u in extra_users],
|
||||||
|
103
synapse/state.py
103
synapse/state.py
@ -18,7 +18,7 @@ import hashlib
|
|||||||
import logging
|
import logging
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
|
||||||
from six import iteritems, itervalues
|
from six import iteritems, iterkeys, itervalues
|
||||||
|
|
||||||
from frozendict import frozendict
|
from frozendict import frozendict
|
||||||
|
|
||||||
@ -203,25 +203,27 @@ class StateHandler(object):
|
|||||||
# If this is an outlier, then we know it shouldn't have any current
|
# If this is an outlier, then we know it shouldn't have any current
|
||||||
# state. Certainly store.get_current_state won't return any, and
|
# state. Certainly store.get_current_state won't return any, and
|
||||||
# persisting the event won't store the state group.
|
# persisting the event won't store the state group.
|
||||||
context = EventContext()
|
|
||||||
if old_state:
|
if old_state:
|
||||||
context.prev_state_ids = {
|
prev_state_ids = {
|
||||||
(s.type, s.state_key): s.event_id for s in old_state
|
(s.type, s.state_key): s.event_id for s in old_state
|
||||||
}
|
}
|
||||||
if event.is_state():
|
if event.is_state():
|
||||||
context.current_state_ids = dict(context.prev_state_ids)
|
current_state_ids = dict(prev_state_ids)
|
||||||
key = (event.type, event.state_key)
|
key = (event.type, event.state_key)
|
||||||
context.current_state_ids[key] = event.event_id
|
current_state_ids[key] = event.event_id
|
||||||
else:
|
else:
|
||||||
context.current_state_ids = context.prev_state_ids
|
current_state_ids = prev_state_ids
|
||||||
else:
|
else:
|
||||||
context.current_state_ids = {}
|
current_state_ids = {}
|
||||||
context.prev_state_ids = {}
|
prev_state_ids = {}
|
||||||
context.prev_state_events = []
|
|
||||||
|
|
||||||
# We don't store state for outliers, so we don't generate a state
|
# We don't store state for outliers, so we don't generate a state
|
||||||
# froup for it.
|
# group for it.
|
||||||
context.state_group = None
|
context = EventContext.with_state(
|
||||||
|
state_group=None,
|
||||||
|
current_state_ids=current_state_ids,
|
||||||
|
prev_state_ids=prev_state_ids,
|
||||||
|
)
|
||||||
|
|
||||||
defer.returnValue(context)
|
defer.returnValue(context)
|
||||||
|
|
||||||
@ -230,31 +232,35 @@ class StateHandler(object):
|
|||||||
# Let's just correctly fill out the context and create a
|
# Let's just correctly fill out the context and create a
|
||||||
# new state group for it.
|
# new state group for it.
|
||||||
|
|
||||||
context = EventContext()
|
prev_state_ids = {
|
||||||
context.prev_state_ids = {
|
|
||||||
(s.type, s.state_key): s.event_id for s in old_state
|
(s.type, s.state_key): s.event_id for s in old_state
|
||||||
}
|
}
|
||||||
|
|
||||||
if event.is_state():
|
if event.is_state():
|
||||||
key = (event.type, event.state_key)
|
key = (event.type, event.state_key)
|
||||||
if key in context.prev_state_ids:
|
if key in prev_state_ids:
|
||||||
replaces = context.prev_state_ids[key]
|
replaces = prev_state_ids[key]
|
||||||
if replaces != event.event_id: # Paranoia check
|
if replaces != event.event_id: # Paranoia check
|
||||||
event.unsigned["replaces_state"] = replaces
|
event.unsigned["replaces_state"] = replaces
|
||||||
context.current_state_ids = dict(context.prev_state_ids)
|
current_state_ids = dict(prev_state_ids)
|
||||||
context.current_state_ids[key] = event.event_id
|
current_state_ids[key] = event.event_id
|
||||||
else:
|
else:
|
||||||
context.current_state_ids = context.prev_state_ids
|
current_state_ids = prev_state_ids
|
||||||
|
|
||||||
context.state_group = yield self.store.store_state_group(
|
state_group = yield self.store.store_state_group(
|
||||||
event.event_id,
|
event.event_id,
|
||||||
event.room_id,
|
event.room_id,
|
||||||
prev_group=None,
|
prev_group=None,
|
||||||
delta_ids=None,
|
delta_ids=None,
|
||||||
current_state_ids=context.current_state_ids,
|
current_state_ids=current_state_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
context = EventContext.with_state(
|
||||||
|
state_group=state_group,
|
||||||
|
current_state_ids=current_state_ids,
|
||||||
|
prev_state_ids=prev_state_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
context.prev_state_events = []
|
|
||||||
defer.returnValue(context)
|
defer.returnValue(context)
|
||||||
|
|
||||||
logger.debug("calling resolve_state_groups from compute_event_context")
|
logger.debug("calling resolve_state_groups from compute_event_context")
|
||||||
@ -262,47 +268,47 @@ class StateHandler(object):
|
|||||||
event.room_id, [e for e, _ in event.prev_events],
|
event.room_id, [e for e, _ in event.prev_events],
|
||||||
)
|
)
|
||||||
|
|
||||||
curr_state = entry.state
|
prev_state_ids = entry.state
|
||||||
|
prev_group = None
|
||||||
|
delta_ids = None
|
||||||
|
|
||||||
context = EventContext()
|
|
||||||
context.prev_state_ids = curr_state
|
|
||||||
if event.is_state():
|
if event.is_state():
|
||||||
# If this is a state event then we need to create a new state
|
# If this is a state event then we need to create a new state
|
||||||
# group for the state after this event.
|
# group for the state after this event.
|
||||||
|
|
||||||
key = (event.type, event.state_key)
|
key = (event.type, event.state_key)
|
||||||
if key in context.prev_state_ids:
|
if key in prev_state_ids:
|
||||||
replaces = context.prev_state_ids[key]
|
replaces = prev_state_ids[key]
|
||||||
event.unsigned["replaces_state"] = replaces
|
event.unsigned["replaces_state"] = replaces
|
||||||
|
|
||||||
context.current_state_ids = dict(context.prev_state_ids)
|
current_state_ids = dict(prev_state_ids)
|
||||||
context.current_state_ids[key] = event.event_id
|
current_state_ids[key] = event.event_id
|
||||||
|
|
||||||
if entry.state_group:
|
if entry.state_group:
|
||||||
# If the state at the event has a state group assigned then
|
# If the state at the event has a state group assigned then
|
||||||
# we can use that as the prev group
|
# we can use that as the prev group
|
||||||
context.prev_group = entry.state_group
|
prev_group = entry.state_group
|
||||||
context.delta_ids = {
|
delta_ids = {
|
||||||
key: event.event_id
|
key: event.event_id
|
||||||
}
|
}
|
||||||
elif entry.prev_group:
|
elif entry.prev_group:
|
||||||
# If the state at the event only has a prev group, then we can
|
# If the state at the event only has a prev group, then we can
|
||||||
# use that as a prev group too.
|
# use that as a prev group too.
|
||||||
context.prev_group = entry.prev_group
|
prev_group = entry.prev_group
|
||||||
context.delta_ids = dict(entry.delta_ids)
|
delta_ids = dict(entry.delta_ids)
|
||||||
context.delta_ids[key] = event.event_id
|
delta_ids[key] = event.event_id
|
||||||
|
|
||||||
context.state_group = yield self.store.store_state_group(
|
state_group = yield self.store.store_state_group(
|
||||||
event.event_id,
|
event.event_id,
|
||||||
event.room_id,
|
event.room_id,
|
||||||
prev_group=context.prev_group,
|
prev_group=prev_group,
|
||||||
delta_ids=context.delta_ids,
|
delta_ids=delta_ids,
|
||||||
current_state_ids=context.current_state_ids,
|
current_state_ids=current_state_ids,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
context.current_state_ids = context.prev_state_ids
|
current_state_ids = prev_state_ids
|
||||||
context.prev_group = entry.prev_group
|
prev_group = entry.prev_group
|
||||||
context.delta_ids = entry.delta_ids
|
delta_ids = entry.delta_ids
|
||||||
|
|
||||||
if entry.state_group is None:
|
if entry.state_group is None:
|
||||||
entry.state_group = yield self.store.store_state_group(
|
entry.state_group = yield self.store.store_state_group(
|
||||||
@ -310,13 +316,20 @@ class StateHandler(object):
|
|||||||
event.room_id,
|
event.room_id,
|
||||||
prev_group=entry.prev_group,
|
prev_group=entry.prev_group,
|
||||||
delta_ids=entry.delta_ids,
|
delta_ids=entry.delta_ids,
|
||||||
current_state_ids=context.current_state_ids,
|
current_state_ids=current_state_ids,
|
||||||
)
|
)
|
||||||
entry.state_id = entry.state_group
|
entry.state_id = entry.state_group
|
||||||
|
|
||||||
context.state_group = entry.state_group
|
state_group = entry.state_group
|
||||||
|
|
||||||
|
context = EventContext.with_state(
|
||||||
|
state_group=state_group,
|
||||||
|
current_state_ids=current_state_ids,
|
||||||
|
prev_state_ids=prev_state_ids,
|
||||||
|
prev_group=prev_group,
|
||||||
|
delta_ids=delta_ids,
|
||||||
|
)
|
||||||
|
|
||||||
context.prev_state_events = []
|
|
||||||
defer.returnValue(context)
|
defer.returnValue(context)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@ -647,7 +660,7 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory):
|
|||||||
for event_id in event_ids
|
for event_id in event_ids
|
||||||
)
|
)
|
||||||
if event_map is not None:
|
if event_map is not None:
|
||||||
needed_events -= set(event_map.iterkeys())
|
needed_events -= set(iterkeys(event_map))
|
||||||
|
|
||||||
logger.info("Asking for %d conflicted events", len(needed_events))
|
logger.info("Asking for %d conflicted events", len(needed_events))
|
||||||
|
|
||||||
@ -668,7 +681,7 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory):
|
|||||||
new_needed_events = set(itervalues(auth_events))
|
new_needed_events = set(itervalues(auth_events))
|
||||||
new_needed_events -= needed_events
|
new_needed_events -= needed_events
|
||||||
if event_map is not None:
|
if event_map is not None:
|
||||||
new_needed_events -= set(event_map.iterkeys())
|
new_needed_events -= set(iterkeys(event_map))
|
||||||
|
|
||||||
logger.info("Asking for %d auth events", len(new_needed_events))
|
logger.info("Asking for %d auth events", len(new_needed_events))
|
||||||
|
|
||||||
|
@ -248,6 +248,20 @@ class DeviceStore(SQLBaseStore):
|
|||||||
|
|
||||||
def _update_remote_device_list_cache_entry_txn(self, txn, user_id, device_id,
|
def _update_remote_device_list_cache_entry_txn(self, txn, user_id, device_id,
|
||||||
content, stream_id):
|
content, stream_id):
|
||||||
|
if content.get("deleted"):
|
||||||
|
self._simple_delete_txn(
|
||||||
|
txn,
|
||||||
|
table="device_lists_remote_cache",
|
||||||
|
keyvalues={
|
||||||
|
"user_id": user_id,
|
||||||
|
"device_id": device_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
txn.call_after(
|
||||||
|
self.device_id_exists_cache.invalidate, (user_id, device_id,)
|
||||||
|
)
|
||||||
|
else:
|
||||||
self._simple_upsert_txn(
|
self._simple_upsert_txn(
|
||||||
txn,
|
txn,
|
||||||
table="device_lists_remote_cache",
|
table="device_lists_remote_cache",
|
||||||
@ -366,7 +380,7 @@ class DeviceStore(SQLBaseStore):
|
|||||||
now_stream_id = max(stream_id for stream_id in itervalues(query_map))
|
now_stream_id = max(stream_id for stream_id in itervalues(query_map))
|
||||||
|
|
||||||
devices = self._get_e2e_device_keys_txn(
|
devices = self._get_e2e_device_keys_txn(
|
||||||
txn, query_map.keys(), include_all_devices=True
|
txn, query_map.keys(), include_all_devices=True, include_deleted_devices=True
|
||||||
)
|
)
|
||||||
|
|
||||||
prev_sent_id_sql = """
|
prev_sent_id_sql = """
|
||||||
@ -393,12 +407,15 @@ class DeviceStore(SQLBaseStore):
|
|||||||
|
|
||||||
prev_id = stream_id
|
prev_id = stream_id
|
||||||
|
|
||||||
|
if device is not None:
|
||||||
key_json = device.get("key_json", None)
|
key_json = device.get("key_json", None)
|
||||||
if key_json:
|
if key_json:
|
||||||
result["keys"] = json.loads(key_json)
|
result["keys"] = json.loads(key_json)
|
||||||
device_display_name = device.get("device_display_name", None)
|
device_display_name = device.get("device_display_name", None)
|
||||||
if device_display_name:
|
if device_display_name:
|
||||||
result["device_display_name"] = device_display_name
|
result["device_display_name"] = device_display_name
|
||||||
|
else:
|
||||||
|
result["deleted"] = True
|
||||||
|
|
||||||
results.append(result)
|
results.append(result)
|
||||||
|
|
||||||
|
@ -64,12 +64,18 @@ class EndToEndKeyStore(SQLBaseStore):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_e2e_device_keys(self, query_list, include_all_devices=False):
|
def get_e2e_device_keys(
|
||||||
|
self, query_list, include_all_devices=False,
|
||||||
|
include_deleted_devices=False,
|
||||||
|
):
|
||||||
"""Fetch a list of device keys.
|
"""Fetch a list of device keys.
|
||||||
Args:
|
Args:
|
||||||
query_list(list): List of pairs of user_ids and device_ids.
|
query_list(list): List of pairs of user_ids and device_ids.
|
||||||
include_all_devices (bool): whether to include entries for devices
|
include_all_devices (bool): whether to include entries for devices
|
||||||
that don't have device keys
|
that don't have device keys
|
||||||
|
include_deleted_devices (bool): whether to include null entries for
|
||||||
|
devices which no longer exist (but were in the query_list).
|
||||||
|
This option only takes effect if include_all_devices is true.
|
||||||
Returns:
|
Returns:
|
||||||
Dict mapping from user-id to dict mapping from device_id to
|
Dict mapping from user-id to dict mapping from device_id to
|
||||||
dict containing "key_json", "device_display_name".
|
dict containing "key_json", "device_display_name".
|
||||||
@ -79,7 +85,7 @@ class EndToEndKeyStore(SQLBaseStore):
|
|||||||
|
|
||||||
results = yield self.runInteraction(
|
results = yield self.runInteraction(
|
||||||
"get_e2e_device_keys", self._get_e2e_device_keys_txn,
|
"get_e2e_device_keys", self._get_e2e_device_keys_txn,
|
||||||
query_list, include_all_devices,
|
query_list, include_all_devices, include_deleted_devices,
|
||||||
)
|
)
|
||||||
|
|
||||||
for user_id, device_keys in iteritems(results):
|
for user_id, device_keys in iteritems(results):
|
||||||
@ -88,10 +94,19 @@ class EndToEndKeyStore(SQLBaseStore):
|
|||||||
|
|
||||||
defer.returnValue(results)
|
defer.returnValue(results)
|
||||||
|
|
||||||
def _get_e2e_device_keys_txn(self, txn, query_list, include_all_devices):
|
def _get_e2e_device_keys_txn(
|
||||||
|
self, txn, query_list, include_all_devices=False,
|
||||||
|
include_deleted_devices=False,
|
||||||
|
):
|
||||||
query_clauses = []
|
query_clauses = []
|
||||||
query_params = []
|
query_params = []
|
||||||
|
|
||||||
|
if include_all_devices is False:
|
||||||
|
include_deleted_devices = False
|
||||||
|
|
||||||
|
if include_deleted_devices:
|
||||||
|
deleted_devices = set(query_list)
|
||||||
|
|
||||||
for (user_id, device_id) in query_list:
|
for (user_id, device_id) in query_list:
|
||||||
query_clause = "user_id = ?"
|
query_clause = "user_id = ?"
|
||||||
query_params.append(user_id)
|
query_params.append(user_id)
|
||||||
@ -119,8 +134,14 @@ class EndToEndKeyStore(SQLBaseStore):
|
|||||||
|
|
||||||
result = {}
|
result = {}
|
||||||
for row in rows:
|
for row in rows:
|
||||||
|
if include_deleted_devices:
|
||||||
|
deleted_devices.remove((row["user_id"], row["device_id"]))
|
||||||
result.setdefault(row["user_id"], {})[row["device_id"]] = row
|
result.setdefault(row["user_id"], {})[row["device_id"]] = row
|
||||||
|
|
||||||
|
if include_deleted_devices:
|
||||||
|
for user_id, device_id in deleted_devices:
|
||||||
|
result.setdefault(user_id, {})[device_id] = None
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -549,7 +549,7 @@ class EventsStore(EventsWorkerStore):
|
|||||||
if ctx.state_group in state_groups_map:
|
if ctx.state_group in state_groups_map:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
state_groups_map[ctx.state_group] = ctx.current_state_ids
|
state_groups_map[ctx.state_group] = yield ctx.get_current_state_ids(self)
|
||||||
|
|
||||||
# We need to map the event_ids to their state groups. First, let's
|
# We need to map the event_ids to their state groups. First, let's
|
||||||
# check if the event is one we're persisting, in which case we can
|
# check if the event is one we're persisting, in which case we can
|
||||||
|
@ -185,6 +185,7 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
|
|||||||
|
|
||||||
defer.returnValue(results)
|
defer.returnValue(results)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
def bulk_get_push_rules_for_room(self, event, context):
|
def bulk_get_push_rules_for_room(self, event, context):
|
||||||
state_group = context.state_group
|
state_group = context.state_group
|
||||||
if not state_group:
|
if not state_group:
|
||||||
@ -194,9 +195,11 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
|
|||||||
# To do this we set the state_group to a new object as object() != object()
|
# To do this we set the state_group to a new object as object() != object()
|
||||||
state_group = object()
|
state_group = object()
|
||||||
|
|
||||||
return self._bulk_get_push_rules_for_room(
|
current_state_ids = yield context.get_current_state_ids(self)
|
||||||
event.room_id, state_group, context.current_state_ids, event=event
|
result = yield self._bulk_get_push_rules_for_room(
|
||||||
|
event.room_id, state_group, current_state_ids, event=event
|
||||||
)
|
)
|
||||||
|
defer.returnValue(result)
|
||||||
|
|
||||||
@cachedInlineCallbacks(num_args=2, cache_context=True)
|
@cachedInlineCallbacks(num_args=2, cache_context=True)
|
||||||
def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state_ids,
|
def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state_ids,
|
||||||
|
@ -232,6 +232,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||||||
|
|
||||||
defer.returnValue(user_who_share_room)
|
defer.returnValue(user_who_share_room)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
def get_joined_users_from_context(self, event, context):
|
def get_joined_users_from_context(self, event, context):
|
||||||
state_group = context.state_group
|
state_group = context.state_group
|
||||||
if not state_group:
|
if not state_group:
|
||||||
@ -241,11 +242,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||||||
# To do this we set the state_group to a new object as object() != object()
|
# To do this we set the state_group to a new object as object() != object()
|
||||||
state_group = object()
|
state_group = object()
|
||||||
|
|
||||||
return self._get_joined_users_from_context(
|
current_state_ids = yield context.get_current_state_ids(self)
|
||||||
event.room_id, state_group, context.current_state_ids,
|
result = yield self._get_joined_users_from_context(
|
||||||
|
event.room_id, state_group, current_state_ids,
|
||||||
event=event,
|
event=event,
|
||||||
context=context,
|
context=context,
|
||||||
)
|
)
|
||||||
|
defer.returnValue(result)
|
||||||
|
|
||||||
def get_joined_users_from_state(self, room_id, state_entry):
|
def get_joined_users_from_state(self, room_id, state_entry):
|
||||||
state_group = state_entry.state_group
|
state_group = state_entry.state_group
|
||||||
|
@ -184,13 +184,13 @@ class Linearizer(object):
|
|||||||
|
|
||||||
# key_to_defer is a map from the key to a 2 element list where
|
# key_to_defer is a map from the key to a 2 element list where
|
||||||
# the first element is the number of things executing, and
|
# the first element is the number of things executing, and
|
||||||
# the second element is a deque of deferreds for the things blocked from
|
# the second element is an OrderedDict, where the keys are deferreds for the
|
||||||
# executing.
|
# things blocked from executing.
|
||||||
self.key_to_defer = {}
|
self.key_to_defer = {}
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def queue(self, key):
|
def queue(self, key):
|
||||||
entry = self.key_to_defer.setdefault(key, [0, collections.deque()])
|
entry = self.key_to_defer.setdefault(key, [0, collections.OrderedDict()])
|
||||||
|
|
||||||
# If the number of things executing is greater than the maximum
|
# If the number of things executing is greater than the maximum
|
||||||
# then add a deferred to the list of blocked items
|
# then add a deferred to the list of blocked items
|
||||||
@ -198,12 +198,28 @@ class Linearizer(object):
|
|||||||
# this item so that it can continue executing.
|
# this item so that it can continue executing.
|
||||||
if entry[0] >= self.max_count:
|
if entry[0] >= self.max_count:
|
||||||
new_defer = defer.Deferred()
|
new_defer = defer.Deferred()
|
||||||
entry[1].append(new_defer)
|
entry[1][new_defer] = 1
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Waiting to acquire linearizer lock %r for key %r", self.name, key,
|
"Waiting to acquire linearizer lock %r for key %r", self.name, key,
|
||||||
)
|
)
|
||||||
|
try:
|
||||||
yield make_deferred_yieldable(new_defer)
|
yield make_deferred_yieldable(new_defer)
|
||||||
|
except Exception as e:
|
||||||
|
if isinstance(e, CancelledError):
|
||||||
|
logger.info(
|
||||||
|
"Cancelling wait for linearizer lock %r for key %r",
|
||||||
|
self.name, key,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warn(
|
||||||
|
"Unexpected exception waiting for linearizer lock %r for key %r",
|
||||||
|
self.name, key,
|
||||||
|
)
|
||||||
|
|
||||||
|
# we just have to take ourselves back out of the queue.
|
||||||
|
del entry[1][new_defer]
|
||||||
|
raise
|
||||||
|
|
||||||
logger.info("Acquired linearizer lock %r for key %r", self.name, key)
|
logger.info("Acquired linearizer lock %r for key %r", self.name, key)
|
||||||
entry[0] += 1
|
entry[0] += 1
|
||||||
@ -238,7 +254,7 @@ class Linearizer(object):
|
|||||||
entry[0] -= 1
|
entry[0] -= 1
|
||||||
|
|
||||||
if entry[1]:
|
if entry[1]:
|
||||||
next_def = entry[1].popleft()
|
(next_def, _) = entry[1].popitem(last=False)
|
||||||
|
|
||||||
# we need to run the next thing in the sentinel context.
|
# we need to run the next thing in the sentinel context.
|
||||||
with PreserveLoggingContext():
|
with PreserveLoggingContext():
|
||||||
|
@ -12,11 +12,12 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import itertools
|
|
||||||
import logging
|
import logging
|
||||||
import operator
|
import operator
|
||||||
|
|
||||||
import six
|
from six import iteritems, itervalues
|
||||||
|
from six.moves import map
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
@ -204,7 +205,7 @@ def filter_events_for_client(store, user_id, events, is_peeking=False,
|
|||||||
return event
|
return event
|
||||||
|
|
||||||
# check each event: gives an iterable[None|EventBase]
|
# check each event: gives an iterable[None|EventBase]
|
||||||
filtered_events = itertools.imap(allowed, events)
|
filtered_events = map(allowed, events)
|
||||||
|
|
||||||
# remove the None entries
|
# remove the None entries
|
||||||
filtered_events = filter(operator.truth, filtered_events)
|
filtered_events = filter(operator.truth, filtered_events)
|
||||||
@ -244,7 +245,7 @@ def filter_events_for_server(store, server_name, events):
|
|||||||
# membership states for the requesting server to determine
|
# membership states for the requesting server to determine
|
||||||
# if the server is either in the room or has been invited
|
# if the server is either in the room or has been invited
|
||||||
# into the room.
|
# into the room.
|
||||||
for ev in state.itervalues():
|
for ev in itervalues(state):
|
||||||
if ev.type != EventTypes.Member:
|
if ev.type != EventTypes.Member:
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
@ -278,7 +279,7 @@ def filter_events_for_server(store, server_name, events):
|
|||||||
)
|
)
|
||||||
|
|
||||||
visibility_ids = set()
|
visibility_ids = set()
|
||||||
for sids in event_to_state_ids.itervalues():
|
for sids in itervalues(event_to_state_ids):
|
||||||
hist = sids.get((EventTypes.RoomHistoryVisibility, ""))
|
hist = sids.get((EventTypes.RoomHistoryVisibility, ""))
|
||||||
if hist:
|
if hist:
|
||||||
visibility_ids.add(hist)
|
visibility_ids.add(hist)
|
||||||
@ -291,7 +292,7 @@ def filter_events_for_server(store, server_name, events):
|
|||||||
event_map = yield store.get_events(visibility_ids)
|
event_map = yield store.get_events(visibility_ids)
|
||||||
all_open = all(
|
all_open = all(
|
||||||
e.content.get("history_visibility") in (None, "shared", "world_readable")
|
e.content.get("history_visibility") in (None, "shared", "world_readable")
|
||||||
for e in event_map.itervalues()
|
for e in itervalues(event_map)
|
||||||
)
|
)
|
||||||
|
|
||||||
if all_open:
|
if all_open:
|
||||||
@ -329,7 +330,7 @@ def filter_events_for_server(store, server_name, events):
|
|||||||
#
|
#
|
||||||
state_key_to_event_id_set = {
|
state_key_to_event_id_set = {
|
||||||
e
|
e
|
||||||
for key_to_eid in six.itervalues(event_to_state_ids)
|
for key_to_eid in itervalues(event_to_state_ids)
|
||||||
for e in key_to_eid.items()
|
for e in key_to_eid.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -352,10 +353,10 @@ def filter_events_for_server(store, server_name, events):
|
|||||||
event_to_state = {
|
event_to_state = {
|
||||||
e_id: {
|
e_id: {
|
||||||
key: event_map[inner_e_id]
|
key: event_map[inner_e_id]
|
||||||
for key, inner_e_id in key_to_eid.iteritems()
|
for key, inner_e_id in iteritems(key_to_eid)
|
||||||
if inner_e_id in event_map
|
if inner_e_id in event_map
|
||||||
}
|
}
|
||||||
for e_id, key_to_eid in event_to_state_ids.iteritems()
|
for e_id, key_to_eid in iteritems(event_to_state_ids)
|
||||||
}
|
}
|
||||||
|
|
||||||
defer.returnValue([
|
defer.returnValue([
|
||||||
|
@ -222,9 +222,11 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
|
|||||||
state_ids = {
|
state_ids = {
|
||||||
key: e.event_id for key, e in state.items()
|
key: e.event_id for key, e in state.items()
|
||||||
}
|
}
|
||||||
context = EventContext()
|
context = EventContext.with_state(
|
||||||
context.current_state_ids = state_ids
|
state_group=None,
|
||||||
context.prev_state_ids = state_ids
|
current_state_ids=state_ids,
|
||||||
|
prev_state_ids=state_ids
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
state_handler = self.hs.get_state_handler()
|
state_handler = self.hs.get_state_handler()
|
||||||
context = yield state_handler.compute_event_context(event)
|
context = yield state_handler.compute_event_context(event)
|
||||||
|
@ -137,7 +137,6 @@ class MessageAcceptTests(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv")
|
self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv")
|
||||||
|
|
||||||
@unittest.DEBUG
|
|
||||||
def test_cant_hide_past_history(self):
|
def test_cant_hide_past_history(self):
|
||||||
"""
|
"""
|
||||||
If you send a message, you must be able to provide the direct
|
If you send a message, you must be able to provide the direct
|
||||||
@ -178,7 +177,7 @@ class MessageAcceptTests(unittest.TestCase):
|
|||||||
for x, y in d.items()
|
for x, y in d.items()
|
||||||
if x == ("m.room.member", "@us:test")
|
if x == ("m.room.member", "@us:test")
|
||||||
],
|
],
|
||||||
"auth_chain_ids": d.values(),
|
"auth_chain_ids": list(d.values()),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -204,7 +204,8 @@ class StateTestCase(unittest.TestCase):
|
|||||||
self.store.register_event_context(event, context)
|
self.store.register_event_context(event, context)
|
||||||
context_store[event.event_id] = context
|
context_store[event.event_id] = context
|
||||||
|
|
||||||
self.assertEqual(2, len(context_store["D"].prev_state_ids))
|
prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
|
||||||
|
self.assertEqual(2, len(prev_state_ids))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_branch_basic_conflict(self):
|
def test_branch_basic_conflict(self):
|
||||||
@ -255,9 +256,11 @@ class StateTestCase(unittest.TestCase):
|
|||||||
self.store.register_event_context(event, context)
|
self.store.register_event_context(event, context)
|
||||||
context_store[event.event_id] = context
|
context_store[event.event_id] = context
|
||||||
|
|
||||||
|
prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
|
||||||
|
|
||||||
self.assertSetEqual(
|
self.assertSetEqual(
|
||||||
{"START", "A", "C"},
|
{"START", "A", "C"},
|
||||||
{e_id for e_id in context_store["D"].prev_state_ids.values()}
|
{e_id for e_id in prev_state_ids.values()}
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@ -318,9 +321,11 @@ class StateTestCase(unittest.TestCase):
|
|||||||
self.store.register_event_context(event, context)
|
self.store.register_event_context(event, context)
|
||||||
context_store[event.event_id] = context
|
context_store[event.event_id] = context
|
||||||
|
|
||||||
|
prev_state_ids = yield context_store["E"].get_prev_state_ids(self.store)
|
||||||
|
|
||||||
self.assertSetEqual(
|
self.assertSetEqual(
|
||||||
{"START", "A", "B", "C"},
|
{"START", "A", "B", "C"},
|
||||||
{e for e in context_store["E"].prev_state_ids.values()}
|
{e for e in prev_state_ids.values()}
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@ -398,9 +403,11 @@ class StateTestCase(unittest.TestCase):
|
|||||||
self.store.register_event_context(event, context)
|
self.store.register_event_context(event, context)
|
||||||
context_store[event.event_id] = context
|
context_store[event.event_id] = context
|
||||||
|
|
||||||
|
prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
|
||||||
|
|
||||||
self.assertSetEqual(
|
self.assertSetEqual(
|
||||||
{"A1", "A2", "A3", "A5", "B"},
|
{"A1", "A2", "A3", "A5", "B"},
|
||||||
{e for e in context_store["D"].prev_state_ids.values()}
|
{e for e in prev_state_ids.values()}
|
||||||
)
|
)
|
||||||
|
|
||||||
def _add_depths(self, nodes, edges):
|
def _add_depths(self, nodes, edges):
|
||||||
@ -429,8 +436,10 @@ class StateTestCase(unittest.TestCase):
|
|||||||
event, old_state=old_state
|
event, old_state=old_state
|
||||||
)
|
)
|
||||||
|
|
||||||
|
current_state_ids = yield context.get_current_state_ids(self.store)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
set(e.event_id for e in old_state), set(context.current_state_ids.values())
|
set(e.event_id for e in old_state), set(current_state_ids.values())
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertIsNotNone(context.state_group)
|
self.assertIsNotNone(context.state_group)
|
||||||
@ -449,8 +458,10 @@ class StateTestCase(unittest.TestCase):
|
|||||||
event, old_state=old_state
|
event, old_state=old_state
|
||||||
)
|
)
|
||||||
|
|
||||||
|
prev_state_ids = yield context.get_prev_state_ids(self.store)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
set(e.event_id for e in old_state), set(context.prev_state_ids.values())
|
set(e.event_id for e in old_state), set(prev_state_ids.values())
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@ -475,9 +486,11 @@ class StateTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
context = yield self.state.compute_event_context(event)
|
context = yield self.state.compute_event_context(event)
|
||||||
|
|
||||||
|
current_state_ids = yield context.get_current_state_ids(self.store)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
set([e.event_id for e in old_state]),
|
set([e.event_id for e in old_state]),
|
||||||
set(context.current_state_ids.values())
|
set(current_state_ids.values())
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(group_name, context.state_group)
|
self.assertEqual(group_name, context.state_group)
|
||||||
@ -504,9 +517,11 @@ class StateTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
context = yield self.state.compute_event_context(event)
|
context = yield self.state.compute_event_context(event)
|
||||||
|
|
||||||
|
prev_state_ids = yield context.get_prev_state_ids(self.store)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
set([e.event_id for e in old_state]),
|
set([e.event_id for e in old_state]),
|
||||||
set(context.prev_state_ids.values())
|
set(prev_state_ids.values())
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertIsNotNone(context.state_group)
|
self.assertIsNotNone(context.state_group)
|
||||||
@ -545,7 +560,9 @@ class StateTestCase(unittest.TestCase):
|
|||||||
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
|
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(len(context.current_state_ids), 6)
|
current_state_ids = yield context.get_current_state_ids(self.store)
|
||||||
|
|
||||||
|
self.assertEqual(len(current_state_ids), 6)
|
||||||
|
|
||||||
self.assertIsNotNone(context.state_group)
|
self.assertIsNotNone(context.state_group)
|
||||||
|
|
||||||
@ -585,7 +602,9 @@ class StateTestCase(unittest.TestCase):
|
|||||||
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
|
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(len(context.current_state_ids), 6)
|
current_state_ids = yield context.get_current_state_ids(self.store)
|
||||||
|
|
||||||
|
self.assertEqual(len(current_state_ids), 6)
|
||||||
|
|
||||||
self.assertIsNotNone(context.state_group)
|
self.assertIsNotNone(context.state_group)
|
||||||
|
|
||||||
@ -642,8 +661,10 @@ class StateTestCase(unittest.TestCase):
|
|||||||
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
|
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
current_state_ids = yield context.get_current_state_ids(self.store)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
old_state_2[3].event_id, context.current_state_ids[("test1", "1")]
|
old_state_2[3].event_id, current_state_ids[("test1", "1")]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Reverse the depth to make sure we are actually using the depths
|
# Reverse the depth to make sure we are actually using the depths
|
||||||
@ -670,8 +691,10 @@ class StateTestCase(unittest.TestCase):
|
|||||||
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
|
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
current_state_ids = yield context.get_current_state_ids(self.store)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
old_state_1[3].event_id, context.current_state_ids[("test1", "1")]
|
old_state_1[3].event_id, current_state_ids[("test1", "1")]
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_context(self, event, prev_event_id_1, old_state_1, prev_event_id_2,
|
def _get_context(self, event, prev_event_id_1, old_state_1, prev_event_id_2,
|
||||||
|
@ -17,6 +17,7 @@
|
|||||||
from six.moves import range
|
from six.moves import range
|
||||||
|
|
||||||
from twisted.internet import defer, reactor
|
from twisted.internet import defer, reactor
|
||||||
|
from twisted.internet.defer import CancelledError
|
||||||
|
|
||||||
from synapse.util import Clock, logcontext
|
from synapse.util import Clock, logcontext
|
||||||
from synapse.util.async import Linearizer
|
from synapse.util.async import Linearizer
|
||||||
@ -112,3 +113,33 @@ class LinearizerTestCase(unittest.TestCase):
|
|||||||
d6 = limiter.queue(key)
|
d6 = limiter.queue(key)
|
||||||
with (yield d6):
|
with (yield d6):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_cancellation(self):
|
||||||
|
linearizer = Linearizer()
|
||||||
|
|
||||||
|
key = object()
|
||||||
|
|
||||||
|
d1 = linearizer.queue(key)
|
||||||
|
cm1 = yield d1
|
||||||
|
|
||||||
|
d2 = linearizer.queue(key)
|
||||||
|
self.assertFalse(d2.called)
|
||||||
|
|
||||||
|
d3 = linearizer.queue(key)
|
||||||
|
self.assertFalse(d3.called)
|
||||||
|
|
||||||
|
d2.cancel()
|
||||||
|
|
||||||
|
with cm1:
|
||||||
|
pass
|
||||||
|
|
||||||
|
self.assertTrue(d2.called)
|
||||||
|
try:
|
||||||
|
yield d2
|
||||||
|
self.fail("Expected d2 to raise CancelledError")
|
||||||
|
except CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
with (yield d3):
|
||||||
|
pass
|
||||||
|
Loading…
Reference in New Issue
Block a user