mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2024-10-01 08:25:44 -04:00
Merge branch 'develop' into markjh/direct_to_device_synchrotron
This commit is contained in:
commit
965168a842
@ -134,6 +134,12 @@ Installing prerequisites on Raspbian::
|
|||||||
sudo pip install --upgrade ndg-httpsclient
|
sudo pip install --upgrade ndg-httpsclient
|
||||||
sudo pip install --upgrade virtualenv
|
sudo pip install --upgrade virtualenv
|
||||||
|
|
||||||
|
Installing prerequisites on openSUSE::
|
||||||
|
|
||||||
|
sudo zypper in -t pattern devel_basis
|
||||||
|
sudo zypper in python-pip python-setuptools sqlite3 python-virtualenv \
|
||||||
|
python-devel libffi-devel libopenssl-devel libjpeg62-devel
|
||||||
|
|
||||||
To install the synapse homeserver run::
|
To install the synapse homeserver run::
|
||||||
|
|
||||||
virtualenv -p python2.7 ~/.synapse
|
virtualenv -p python2.7 ~/.synapse
|
||||||
@ -230,9 +236,6 @@ The advantages of Postgres include:
|
|||||||
pointing at the same DB master, as well as enabling DB replication in
|
pointing at the same DB master, as well as enabling DB replication in
|
||||||
synapse itself.
|
synapse itself.
|
||||||
|
|
||||||
The only disadvantage is that the code is relatively new as of April 2015 and
|
|
||||||
may have a few regressions relative to SQLite.
|
|
||||||
|
|
||||||
For information on how to install and use PostgreSQL, please see
|
For information on how to install and use PostgreSQL, please see
|
||||||
`docs/postgres.rst <docs/postgres.rst>`_.
|
`docs/postgres.rst <docs/postgres.rst>`_.
|
||||||
|
|
||||||
|
@ -66,7 +66,7 @@ class Auth(object):
|
|||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def check_from_context(self, event, context, do_sig_check=True):
|
def check_from_context(self, event, context, do_sig_check=True):
|
||||||
auth_events_ids = yield self.compute_auth_events(
|
auth_events_ids = yield self.compute_auth_events(
|
||||||
event, context.current_state_ids, for_verification=True,
|
event, context.prev_state_ids, for_verification=True,
|
||||||
)
|
)
|
||||||
auth_events = yield self.store.get_events(auth_events_ids)
|
auth_events = yield self.store.get_events(auth_events_ids)
|
||||||
auth_events = {
|
auth_events = {
|
||||||
@ -281,11 +281,13 @@ class Auth(object):
|
|||||||
with Measure(self.clock, "check_host_in_room"):
|
with Measure(self.clock, "check_host_in_room"):
|
||||||
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
|
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
|
||||||
|
|
||||||
group, curr_state_ids = yield self.state.resolve_state_groups(
|
entry = yield self.state.resolve_state_groups(
|
||||||
room_id, latest_event_ids
|
room_id, latest_event_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
ret = yield self.store.is_host_joined(room_id, host, group, curr_state_ids)
|
ret = yield self.store.is_host_joined(
|
||||||
|
room_id, host, entry.state_group, entry.state
|
||||||
|
)
|
||||||
defer.returnValue(ret)
|
defer.returnValue(ret)
|
||||||
|
|
||||||
def check_event_sender_in_room(self, event, auth_events):
|
def check_event_sender_in_room(self, event, auth_events):
|
||||||
@ -852,7 +854,7 @@ class Auth(object):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def add_auth_events(self, builder, context):
|
def add_auth_events(self, builder, context):
|
||||||
auth_ids = yield self.compute_auth_events(builder, context.current_state_ids)
|
auth_ids = yield self.compute_auth_events(builder, context.prev_state_ids)
|
||||||
|
|
||||||
auth_events_entries = yield self.store.add_event_hashes(
|
auth_events_entries = yield self.store.add_event_hashes(
|
||||||
auth_ids
|
auth_ids
|
||||||
|
@ -67,6 +67,8 @@ class ApplicationServiceApi(SimpleHttpClient):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def query_user(self, service, user_id):
|
def query_user(self, service, user_id):
|
||||||
|
if service.url is None:
|
||||||
|
defer.returnValue(False)
|
||||||
uri = service.url + ("/users/%s" % urllib.quote(user_id))
|
uri = service.url + ("/users/%s" % urllib.quote(user_id))
|
||||||
response = None
|
response = None
|
||||||
try:
|
try:
|
||||||
@ -86,6 +88,8 @@ class ApplicationServiceApi(SimpleHttpClient):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def query_alias(self, service, alias):
|
def query_alias(self, service, alias):
|
||||||
|
if service.url is None:
|
||||||
|
defer.returnValue(False)
|
||||||
uri = service.url + ("/rooms/%s" % urllib.quote(alias))
|
uri = service.url + ("/rooms/%s" % urllib.quote(alias))
|
||||||
response = None
|
response = None
|
||||||
try:
|
try:
|
||||||
@ -113,6 +117,8 @@ class ApplicationServiceApi(SimpleHttpClient):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Unrecognised 'kind' argument %r to query_3pe()", kind
|
"Unrecognised 'kind' argument %r to query_3pe()", kind
|
||||||
)
|
)
|
||||||
|
if service.url is None:
|
||||||
|
defer.returnValue([])
|
||||||
|
|
||||||
uri = "%s%s/thirdparty/%s/%s" % (
|
uri = "%s%s/thirdparty/%s/%s" % (
|
||||||
service.url,
|
service.url,
|
||||||
@ -145,6 +151,9 @@ class ApplicationServiceApi(SimpleHttpClient):
|
|||||||
defer.returnValue([])
|
defer.returnValue([])
|
||||||
|
|
||||||
def get_3pe_protocol(self, service, protocol):
|
def get_3pe_protocol(self, service, protocol):
|
||||||
|
if service.url is None:
|
||||||
|
defer.returnValue({})
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _get():
|
def _get():
|
||||||
uri = "%s%s/thirdparty/protocol/%s" % (
|
uri = "%s%s/thirdparty/protocol/%s" % (
|
||||||
@ -166,6 +175,9 @@ class ApplicationServiceApi(SimpleHttpClient):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def push_bulk(self, service, events, txn_id=None):
|
def push_bulk(self, service, events, txn_id=None):
|
||||||
|
if service.url is None:
|
||||||
|
defer.returnValue(True)
|
||||||
|
|
||||||
events = self._serialize(events)
|
events = self._serialize(events)
|
||||||
|
|
||||||
if txn_id is None:
|
if txn_id is None:
|
||||||
|
@ -86,7 +86,7 @@ def load_appservices(hostname, config_files):
|
|||||||
|
|
||||||
def _load_appservice(hostname, as_info, config_filename):
|
def _load_appservice(hostname, as_info, config_filename):
|
||||||
required_string_fields = [
|
required_string_fields = [
|
||||||
"id", "url", "as_token", "hs_token", "sender_localpart"
|
"id", "as_token", "hs_token", "sender_localpart"
|
||||||
]
|
]
|
||||||
for field in required_string_fields:
|
for field in required_string_fields:
|
||||||
if not isinstance(as_info.get(field), basestring):
|
if not isinstance(as_info.get(field), basestring):
|
||||||
@ -94,6 +94,14 @@ def _load_appservice(hostname, as_info, config_filename):
|
|||||||
field, config_filename,
|
field, config_filename,
|
||||||
))
|
))
|
||||||
|
|
||||||
|
# 'url' must either be a string or explicitly null, not missing
|
||||||
|
# to avoid accidentally turning off push for ASes.
|
||||||
|
if (not isinstance(as_info.get("url"), basestring) and
|
||||||
|
as_info.get("url", "") is not None):
|
||||||
|
raise KeyError(
|
||||||
|
"Required string field or explicit null: 'url' (%s)" % (config_filename,)
|
||||||
|
)
|
||||||
|
|
||||||
localpart = as_info["sender_localpart"]
|
localpart = as_info["sender_localpart"]
|
||||||
if urllib.quote(localpart) != localpart:
|
if urllib.quote(localpart) != localpart:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -132,6 +140,13 @@ def _load_appservice(hostname, as_info, config_filename):
|
|||||||
for p in protocols:
|
for p in protocols:
|
||||||
if not isinstance(p, str):
|
if not isinstance(p, str):
|
||||||
raise KeyError("Bad value for 'protocols' item")
|
raise KeyError("Bad value for 'protocols' item")
|
||||||
|
|
||||||
|
if as_info["url"] is None:
|
||||||
|
logger.info(
|
||||||
|
"(%s) Explicitly empty 'url' provided. This application service"
|
||||||
|
" will not receive events or queries.",
|
||||||
|
config_filename,
|
||||||
|
)
|
||||||
return ApplicationService(
|
return ApplicationService(
|
||||||
token=as_info["as_token"],
|
token=as_info["as_token"],
|
||||||
url=as_info["url"],
|
url=as_info["url"],
|
||||||
|
@ -15,8 +15,9 @@
|
|||||||
|
|
||||||
|
|
||||||
class EventContext(object):
|
class EventContext(object):
|
||||||
def __init__(self, current_state_ids=None):
|
def __init__(self):
|
||||||
self.current_state_ids = current_state_ids
|
self.current_state_ids = None
|
||||||
|
self.prev_state_ids = None
|
||||||
self.state_group = None
|
self.state_group = None
|
||||||
self.rejected = False
|
self.rejected = False
|
||||||
self.push_actions = []
|
self.push_actions = []
|
||||||
|
@ -269,7 +269,7 @@ class FederationClient(FederationBase):
|
|||||||
|
|
||||||
pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {})
|
pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {})
|
||||||
|
|
||||||
pdu = None
|
signed_pdu = None
|
||||||
for destination in destinations:
|
for destination in destinations:
|
||||||
now = self._clock.time_msec()
|
now = self._clock.time_msec()
|
||||||
last_attempt = pdu_attempts.get(destination, 0)
|
last_attempt = pdu_attempts.get(destination, 0)
|
||||||
@ -299,7 +299,7 @@ class FederationClient(FederationBase):
|
|||||||
pdu = pdu_list[0]
|
pdu = pdu_list[0]
|
||||||
|
|
||||||
# Check signatures are correct.
|
# Check signatures are correct.
|
||||||
pdu = yield self._check_sigs_and_hashes([pdu])[0]
|
signed_pdu = yield self._check_sigs_and_hashes([pdu])[0]
|
||||||
|
|
||||||
break
|
break
|
||||||
|
|
||||||
@ -322,10 +322,10 @@ class FederationClient(FederationBase):
|
|||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if self._get_pdu_cache is not None and pdu:
|
if self._get_pdu_cache is not None and signed_pdu:
|
||||||
self._get_pdu_cache[event_id] = pdu
|
self._get_pdu_cache[event_id] = signed_pdu
|
||||||
|
|
||||||
defer.returnValue(pdu)
|
defer.returnValue(signed_pdu)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
|
@ -222,7 +222,7 @@ class FederationHandler(BaseHandler):
|
|||||||
# joined the room. Don't bother if the user is just
|
# joined the room. Don't bother if the user is just
|
||||||
# changing their profile info.
|
# changing their profile info.
|
||||||
newly_joined = True
|
newly_joined = True
|
||||||
prev_state_id = context.current_state_ids.get(
|
prev_state_id = context.prev_state_ids.get(
|
||||||
(event.type, event.state_key)
|
(event.type, event.state_key)
|
||||||
)
|
)
|
||||||
if prev_state_id:
|
if prev_state_id:
|
||||||
@ -835,12 +835,12 @@ class FederationHandler(BaseHandler):
|
|||||||
|
|
||||||
self.replication_layer.send_pdu(new_pdu, destinations)
|
self.replication_layer.send_pdu(new_pdu, destinations)
|
||||||
|
|
||||||
state_ids = context.current_state_ids.values()
|
state_ids = context.prev_state_ids.values()
|
||||||
auth_chain = yield self.store.get_auth_chain(set(
|
auth_chain = yield self.store.get_auth_chain(set(
|
||||||
[event.event_id] + state_ids
|
[event.event_id] + state_ids
|
||||||
))
|
))
|
||||||
|
|
||||||
state = yield self.store.get_events(context.current_state_ids.values())
|
state = yield self.store.get_events(context.prev_state_ids.values())
|
||||||
|
|
||||||
defer.returnValue({
|
defer.returnValue({
|
||||||
"state": state.values(),
|
"state": state.values(),
|
||||||
@ -1333,7 +1333,7 @@ class FederationHandler(BaseHandler):
|
|||||||
|
|
||||||
if not auth_events:
|
if not auth_events:
|
||||||
auth_events_ids = yield self.auth.compute_auth_events(
|
auth_events_ids = yield self.auth.compute_auth_events(
|
||||||
event, context.current_state_ids, for_verification=True,
|
event, context.prev_state_ids, for_verification=True,
|
||||||
)
|
)
|
||||||
auth_events = yield self.store.get_events(auth_events_ids)
|
auth_events = yield self.store.get_events(auth_events_ids)
|
||||||
auth_events = {
|
auth_events = {
|
||||||
@ -1432,6 +1432,11 @@ class FederationHandler(BaseHandler):
|
|||||||
current_state = set(e.event_id for e in auth_events.values())
|
current_state = set(e.event_id for e in auth_events.values())
|
||||||
event_auth_events = set(e_id for e_id, _ in event.auth_events)
|
event_auth_events = set(e_id for e_id, _ in event.auth_events)
|
||||||
|
|
||||||
|
if event.is_state():
|
||||||
|
event_key = (event.type, event.state_key)
|
||||||
|
else:
|
||||||
|
event_key = None
|
||||||
|
|
||||||
if event_auth_events - current_state:
|
if event_auth_events - current_state:
|
||||||
have_events = yield self.store.have_events(
|
have_events = yield self.store.have_events(
|
||||||
event_auth_events - current_state
|
event_auth_events - current_state
|
||||||
@ -1537,8 +1542,12 @@ class FederationHandler(BaseHandler):
|
|||||||
|
|
||||||
context.current_state_ids.update({
|
context.current_state_ids.update({
|
||||||
k: a.event_id for k, a in auth_events.items()
|
k: a.event_id for k, a in auth_events.items()
|
||||||
|
if k != event_key
|
||||||
})
|
})
|
||||||
context.state_group = None
|
context.prev_state_ids.update({
|
||||||
|
k: a.event_id for k, a in auth_events.items()
|
||||||
|
})
|
||||||
|
context.state_group = self.store.get_next_state_group()
|
||||||
|
|
||||||
if different_auth and not event.internal_metadata.is_outlier():
|
if different_auth and not event.internal_metadata.is_outlier():
|
||||||
logger.info("Different auth after resolution: %s", different_auth)
|
logger.info("Different auth after resolution: %s", different_auth)
|
||||||
@ -1560,7 +1569,7 @@ class FederationHandler(BaseHandler):
|
|||||||
if do_resolution:
|
if do_resolution:
|
||||||
# 1. Get what we think is the auth chain.
|
# 1. Get what we think is the auth chain.
|
||||||
auth_ids = yield self.auth.compute_auth_events(
|
auth_ids = yield self.auth.compute_auth_events(
|
||||||
event, context.current_state_ids
|
event, context.prev_state_ids
|
||||||
)
|
)
|
||||||
local_auth_chain = yield self.store.get_auth_chain(auth_ids)
|
local_auth_chain = yield self.store.get_auth_chain(auth_ids)
|
||||||
|
|
||||||
@ -1618,8 +1627,12 @@ class FederationHandler(BaseHandler):
|
|||||||
|
|
||||||
context.current_state_ids.update({
|
context.current_state_ids.update({
|
||||||
k: a.event_id for k, a in auth_events.items()
|
k: a.event_id for k, a in auth_events.items()
|
||||||
|
if k != event_key
|
||||||
})
|
})
|
||||||
context.state_group = None
|
context.prev_state_ids.update({
|
||||||
|
k: a.event_id for k, a in auth_events.items()
|
||||||
|
})
|
||||||
|
context.state_group = self.store.get_next_state_group()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.auth.check(event, auth_events=auth_events)
|
self.auth.check(event, auth_events=auth_events)
|
||||||
@ -1855,7 +1868,7 @@ class FederationHandler(BaseHandler):
|
|||||||
event.content["third_party_invite"]["signed"]["token"]
|
event.content["third_party_invite"]["signed"]["token"]
|
||||||
)
|
)
|
||||||
original_invite = None
|
original_invite = None
|
||||||
original_invite_id = context.current_state_ids.get(key)
|
original_invite_id = context.prev_state_ids.get(key)
|
||||||
if original_invite_id:
|
if original_invite_id:
|
||||||
original_invite = yield self.store.get_event(
|
original_invite = yield self.store.get_event(
|
||||||
original_invite_id, allow_none=True
|
original_invite_id, allow_none=True
|
||||||
@ -1893,7 +1906,7 @@ class FederationHandler(BaseHandler):
|
|||||||
signed = event.content["third_party_invite"]["signed"]
|
signed = event.content["third_party_invite"]["signed"]
|
||||||
token = signed["token"]
|
token = signed["token"]
|
||||||
|
|
||||||
invite_event_id = context.current_state_ids.get(
|
invite_event_id = context.prev_state_ids.get(
|
||||||
(EventTypes.ThirdPartyInvite, token,)
|
(EventTypes.ThirdPartyInvite, token,)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -272,7 +272,7 @@ class MessageHandler(BaseHandler):
|
|||||||
If so, returns the version of the event in context.
|
If so, returns the version of the event in context.
|
||||||
Otherwise, returns None.
|
Otherwise, returns None.
|
||||||
"""
|
"""
|
||||||
prev_event_id = context.current_state_ids.get((event.type, event.state_key))
|
prev_event_id = context.prev_state_ids.get((event.type, event.state_key))
|
||||||
prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
|
prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
|
||||||
if not prev_event:
|
if not prev_event:
|
||||||
return
|
return
|
||||||
@ -808,8 +808,8 @@ class MessageHandler(BaseHandler):
|
|||||||
event = builder.build()
|
event = builder.build()
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Created event %s with current state: %s",
|
"Created event %s with state: %s",
|
||||||
event.event_id, context.current_state_ids,
|
event.event_id, context.prev_state_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue(
|
defer.returnValue(
|
||||||
@ -904,7 +904,7 @@ class MessageHandler(BaseHandler):
|
|||||||
|
|
||||||
if event.type == EventTypes.Redaction:
|
if event.type == EventTypes.Redaction:
|
||||||
auth_events_ids = yield self.auth.compute_auth_events(
|
auth_events_ids = yield self.auth.compute_auth_events(
|
||||||
event, context.current_state_ids, for_verification=True,
|
event, context.prev_state_ids, for_verification=True,
|
||||||
)
|
)
|
||||||
auth_events = yield self.store.get_events(auth_events_ids)
|
auth_events = yield self.store.get_events(auth_events_ids)
|
||||||
auth_events = {
|
auth_events = {
|
||||||
@ -924,7 +924,7 @@ class MessageHandler(BaseHandler):
|
|||||||
"You don't have permission to redact events"
|
"You don't have permission to redact events"
|
||||||
)
|
)
|
||||||
|
|
||||||
if event.type == EventTypes.Create and context.current_state_ids:
|
if event.type == EventTypes.Create and context.prev_state_ids:
|
||||||
raise AuthError(
|
raise AuthError(
|
||||||
403,
|
403,
|
||||||
"Changing the room create event is forbidden",
|
"Changing the room create event is forbidden",
|
||||||
|
@ -191,6 +191,13 @@ class PresenceHandler(object):
|
|||||||
5000,
|
5000,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.clock.call_later(
|
||||||
|
60,
|
||||||
|
self.clock.looping_call,
|
||||||
|
self._persist_unpersisted_changes,
|
||||||
|
60 * 1000,
|
||||||
|
)
|
||||||
|
|
||||||
metrics.register_callback("wheel_timer_size", lambda: len(self.wheel_timer))
|
metrics.register_callback("wheel_timer_size", lambda: len(self.wheel_timer))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@ -216,6 +223,27 @@ class PresenceHandler(object):
|
|||||||
])
|
])
|
||||||
logger.info("Finished _on_shutdown")
|
logger.info("Finished _on_shutdown")
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _persist_unpersisted_changes(self):
|
||||||
|
"""We periodically persist the unpersisted changes, as otherwise they
|
||||||
|
may stack up and slow down shutdown times.
|
||||||
|
"""
|
||||||
|
logger.info(
|
||||||
|
"Performing _persist_unpersisted_changes. Persiting %d unpersisted changes",
|
||||||
|
len(self.unpersisted_users_changes)
|
||||||
|
)
|
||||||
|
|
||||||
|
unpersisted = self.unpersisted_users_changes
|
||||||
|
self.unpersisted_users_changes = set()
|
||||||
|
|
||||||
|
if unpersisted:
|
||||||
|
yield self.store.update_presence([
|
||||||
|
self.user_to_current_state[user_id]
|
||||||
|
for user_id in unpersisted
|
||||||
|
])
|
||||||
|
|
||||||
|
logger.info("Finished _persist_unpersisted_changes")
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _update_states(self, new_states):
|
def _update_states(self, new_states):
|
||||||
"""Updates presence of users. Sets the appropriate timeouts. Pokes
|
"""Updates presence of users. Sets the appropriate timeouts. Pokes
|
||||||
@ -922,7 +950,12 @@ def should_notify(old_state, new_state):
|
|||||||
if new_state.currently_active != old_state.currently_active:
|
if new_state.currently_active != old_state.currently_active:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY:
|
if new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY:
|
||||||
|
# Only notify about last active bumps if we're not currently acive
|
||||||
|
if not (old_state.currently_active and new_state.currently_active):
|
||||||
|
return True
|
||||||
|
|
||||||
|
elif new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY:
|
||||||
# Always notify for a transition where last active gets bumped.
|
# Always notify for a transition where last active gets bumped.
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@ -93,7 +93,7 @@ class RoomMemberHandler(BaseHandler):
|
|||||||
ratelimit=ratelimit,
|
ratelimit=ratelimit,
|
||||||
)
|
)
|
||||||
|
|
||||||
prev_member_event_id = context.current_state_ids.get(
|
prev_member_event_id = context.prev_state_ids.get(
|
||||||
(EventTypes.Member, target.to_string()),
|
(EventTypes.Member, target.to_string()),
|
||||||
None
|
None
|
||||||
)
|
)
|
||||||
@ -341,7 +341,7 @@ class RoomMemberHandler(BaseHandler):
|
|||||||
|
|
||||||
if event.membership == Membership.JOIN:
|
if event.membership == Membership.JOIN:
|
||||||
if requester.is_guest:
|
if requester.is_guest:
|
||||||
guest_can_join = yield self._can_guest_join(context.current_state_ids)
|
guest_can_join = yield self._can_guest_join(context.prev_state_ids)
|
||||||
if not guest_can_join:
|
if not guest_can_join:
|
||||||
# This should be an auth check, but guests are a local concept,
|
# This should be an auth check, but guests are a local concept,
|
||||||
# so don't really fit into the general auth process.
|
# so don't really fit into the general auth process.
|
||||||
@ -355,7 +355,7 @@ class RoomMemberHandler(BaseHandler):
|
|||||||
ratelimit=ratelimit,
|
ratelimit=ratelimit,
|
||||||
)
|
)
|
||||||
|
|
||||||
prev_member_event_id = context.current_state_ids.get(
|
prev_member_event_id = context.prev_state_ids.get(
|
||||||
(EventTypes.Member, event.state_key),
|
(EventTypes.Member, event.state_key),
|
||||||
None
|
None
|
||||||
)
|
)
|
||||||
|
@ -565,21 +565,26 @@ class SyncHandler(object):
|
|||||||
if sync_result_builder.since_token is not None:
|
if sync_result_builder.since_token is not None:
|
||||||
since_stream_id = int(sync_result_builder.since_token.to_device_key)
|
since_stream_id = int(sync_result_builder.since_token.to_device_key)
|
||||||
|
|
||||||
if since_stream_id:
|
if since_stream_id != int(now_token.to_device_key):
|
||||||
|
# We only delete messages when a new message comes in, but that's
|
||||||
|
# fine so long as we delete them at some point.
|
||||||
|
|
||||||
logger.debug("Deleting messages up to %d", since_stream_id)
|
logger.debug("Deleting messages up to %d", since_stream_id)
|
||||||
yield self.store.delete_messages_for_device(
|
yield self.store.delete_messages_for_device(
|
||||||
user_id, device_id, since_stream_id
|
user_id, device_id, since_stream_id
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug("Getting messages up to %d", now_token.to_device_key)
|
logger.debug("Getting messages up to %d", now_token.to_device_key)
|
||||||
messages, stream_id = yield self.store.get_new_messages_for_device(
|
messages, stream_id = yield self.store.get_new_messages_for_device(
|
||||||
user_id, device_id, now_token.to_device_key
|
user_id, device_id, now_token.to_device_key
|
||||||
)
|
)
|
||||||
logger.debug("Got messages up to %d: %r", stream_id, messages)
|
logger.debug("Got messages up to %d: %r", stream_id, messages)
|
||||||
sync_result_builder.now_token = now_token.copy_and_replace(
|
sync_result_builder.now_token = now_token.copy_and_replace(
|
||||||
"to_device_key", stream_id
|
"to_device_key", stream_id
|
||||||
)
|
)
|
||||||
sync_result_builder.to_device = messages
|
sync_result_builder.to_device = messages
|
||||||
|
else:
|
||||||
|
sync_result_builder.to_device = []
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _generate_sync_entry_for_account_data(self, sync_result_builder):
|
def _generate_sync_entry_for_account_data(self, sync_result_builder):
|
||||||
|
@ -87,7 +87,7 @@ class BulkPushRuleEvaluator:
|
|||||||
)
|
)
|
||||||
|
|
||||||
room_members = yield self.store.get_joined_users_from_context(
|
room_members = yield self.store.get_joined_users_from_context(
|
||||||
event.room_id, context.state_group, context.current_state_ids
|
event, context
|
||||||
)
|
)
|
||||||
|
|
||||||
evaluator = PushRuleEvaluatorForEvent(event, len(room_members))
|
evaluator = PushRuleEvaluatorForEvent(event, len(room_members))
|
||||||
|
@ -338,7 +338,7 @@ class Mailer(object):
|
|||||||
# want the generated-from-names one here otherwise we'll
|
# want the generated-from-names one here otherwise we'll
|
||||||
# end up with, "new message from Bob in the Bob room"
|
# end up with, "new message from Bob in the Bob room"
|
||||||
room_name = yield calculate_room_name(
|
room_name = yield calculate_room_name(
|
||||||
state_by_room[room_id], user_id, fallback_to_members=False
|
self.store, state_by_room[room_id], user_id, fallback_to_members=False
|
||||||
)
|
)
|
||||||
|
|
||||||
my_member_event = state_by_room[room_id][("m.room.member", user_id)]
|
my_member_event = state_by_room[room_id][("m.room.member", user_id)]
|
||||||
|
@ -74,7 +74,7 @@ def calculate_room_name(store, room_state_ids, user_id, fallback_to_members=True
|
|||||||
alias_event = yield store.get_event(
|
alias_event = yield store.get_event(
|
||||||
alias_id, allow_none=True
|
alias_id, allow_none=True
|
||||||
)
|
)
|
||||||
if alias_event and alias_event.content and alias_event.get("aliases"):
|
if alias_event and alias_event.content.get("aliases"):
|
||||||
the_aliases = alias_event.content["aliases"]
|
the_aliases = alias_event.content["aliases"]
|
||||||
if len(the_aliases) > 0 and _looks_like_an_alias(the_aliases[0]):
|
if len(the_aliases) > 0 and _looks_like_an_alias(the_aliases[0]):
|
||||||
defer.returnValue(the_aliases[0])
|
defer.returnValue(the_aliases[0])
|
||||||
|
@ -40,7 +40,6 @@ STREAM_NAMES = (
|
|||||||
("backfill",),
|
("backfill",),
|
||||||
("push_rules",),
|
("push_rules",),
|
||||||
("pushers",),
|
("pushers",),
|
||||||
("state",),
|
|
||||||
("caches",),
|
("caches",),
|
||||||
("to_device",),
|
("to_device",),
|
||||||
)
|
)
|
||||||
@ -131,7 +130,6 @@ class ReplicationResource(Resource):
|
|||||||
backfill_token = yield self.store.get_current_backfill_token()
|
backfill_token = yield self.store.get_current_backfill_token()
|
||||||
push_rules_token, room_stream_token = self.store.get_push_rules_stream_token()
|
push_rules_token, room_stream_token = self.store.get_push_rules_stream_token()
|
||||||
pushers_token = self.store.get_pushers_stream_token()
|
pushers_token = self.store.get_pushers_stream_token()
|
||||||
state_token = self.store.get_state_stream_token()
|
|
||||||
caches_token = self.store.get_cache_stream_token()
|
caches_token = self.store.get_cache_stream_token()
|
||||||
|
|
||||||
defer.returnValue(_ReplicationToken(
|
defer.returnValue(_ReplicationToken(
|
||||||
@ -143,7 +141,7 @@ class ReplicationResource(Resource):
|
|||||||
backfill_token,
|
backfill_token,
|
||||||
push_rules_token,
|
push_rules_token,
|
||||||
pushers_token,
|
pushers_token,
|
||||||
state_token,
|
0, # State stream is no longer a thing
|
||||||
caches_token,
|
caches_token,
|
||||||
int(stream_token.to_device_key),
|
int(stream_token.to_device_key),
|
||||||
))
|
))
|
||||||
@ -193,7 +191,6 @@ class ReplicationResource(Resource):
|
|||||||
yield self.receipts(writer, current_token, limit, request_streams)
|
yield self.receipts(writer, current_token, limit, request_streams)
|
||||||
yield self.push_rules(writer, current_token, limit, request_streams)
|
yield self.push_rules(writer, current_token, limit, request_streams)
|
||||||
yield self.pushers(writer, current_token, limit, request_streams)
|
yield self.pushers(writer, current_token, limit, request_streams)
|
||||||
yield self.state(writer, current_token, limit, request_streams)
|
|
||||||
yield self.caches(writer, current_token, limit, request_streams)
|
yield self.caches(writer, current_token, limit, request_streams)
|
||||||
yield self.to_device(writer, current_token, limit, request_streams)
|
yield self.to_device(writer, current_token, limit, request_streams)
|
||||||
self.streams(writer, current_token, request_streams)
|
self.streams(writer, current_token, request_streams)
|
||||||
@ -368,25 +365,6 @@ class ReplicationResource(Resource):
|
|||||||
"position", "user_id", "app_id", "pushkey"
|
"position", "user_id", "app_id", "pushkey"
|
||||||
))
|
))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def state(self, writer, current_token, limit, request_streams):
|
|
||||||
current_position = current_token.state
|
|
||||||
|
|
||||||
state = request_streams.get("state")
|
|
||||||
|
|
||||||
if state is not None:
|
|
||||||
state_groups, state_group_state = (
|
|
||||||
yield self.store.get_all_new_state_groups(
|
|
||||||
state, current_position, limit
|
|
||||||
)
|
|
||||||
)
|
|
||||||
writer.write_header_and_rows("state_groups", state_groups, (
|
|
||||||
"position", "room_id", "event_id"
|
|
||||||
))
|
|
||||||
writer.write_header_and_rows("state_group_state", state_group_state, (
|
|
||||||
"position", "type", "state_key", "event_id"
|
|
||||||
))
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def caches(self, writer, current_token, limit, request_streams):
|
def caches(self, writer, current_token, limit, request_streams):
|
||||||
current_position = current_token.caches
|
current_position = current_token.caches
|
||||||
|
@ -123,6 +123,7 @@ class SlavedEventStore(BaseSlavedStore):
|
|||||||
get_state_groups_ids = DataStore.get_state_groups_ids.__func__
|
get_state_groups_ids = DataStore.get_state_groups_ids.__func__
|
||||||
get_state_ids_for_event = DataStore.get_state_ids_for_event.__func__
|
get_state_ids_for_event = DataStore.get_state_ids_for_event.__func__
|
||||||
get_state_ids_for_events = DataStore.get_state_ids_for_events.__func__
|
get_state_ids_for_events = DataStore.get_state_ids_for_events.__func__
|
||||||
|
get_joined_users_from_state = DataStore.get_joined_users_from_state.__func__
|
||||||
get_joined_users_from_context = DataStore.get_joined_users_from_context.__func__
|
get_joined_users_from_context = DataStore.get_joined_users_from_context.__func__
|
||||||
_get_joined_users_from_context = (
|
_get_joined_users_from_context = (
|
||||||
RoomMemberStore.__dict__["_get_joined_users_from_context"]
|
RoomMemberStore.__dict__["_get_joined_users_from_context"]
|
||||||
|
198
synapse/state.py
198
synapse/state.py
@ -23,6 +23,7 @@ from synapse.api.constants import EventTypes
|
|||||||
from synapse.api.errors import AuthError
|
from synapse.api.errors import AuthError
|
||||||
from synapse.api.auth import AuthEventTypes
|
from synapse.api.auth import AuthEventTypes
|
||||||
from synapse.events.snapshot import EventContext
|
from synapse.events.snapshot import EventContext
|
||||||
|
from synapse.util.async import Linearizer
|
||||||
|
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
|
||||||
@ -43,11 +44,35 @@ SIZE_OF_CACHE = int(1000 * CACHE_SIZE_FACTOR)
|
|||||||
EVICTION_TIMEOUT_SECONDS = 60 * 60
|
EVICTION_TIMEOUT_SECONDS = 60 * 60
|
||||||
|
|
||||||
|
|
||||||
|
_NEXT_STATE_ID = 1
|
||||||
|
|
||||||
|
|
||||||
|
def _gen_state_id():
|
||||||
|
global _NEXT_STATE_ID
|
||||||
|
s = "X%d" % (_NEXT_STATE_ID,)
|
||||||
|
_NEXT_STATE_ID += 1
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
class _StateCacheEntry(object):
|
class _StateCacheEntry(object):
|
||||||
def __init__(self, state, state_group, ts):
|
__slots__ = ["state", "state_group", "state_id"]
|
||||||
|
|
||||||
|
def __init__(self, state, state_group):
|
||||||
self.state = state
|
self.state = state
|
||||||
self.state_group = state_group
|
self.state_group = state_group
|
||||||
|
|
||||||
|
# The `state_id` is a unique ID we generate that can be used as ID for
|
||||||
|
# this collection of state. Usually this would be the same as the
|
||||||
|
# state group, but on worker instances we can't generate a new state
|
||||||
|
# group each time we resolve state, so we generate a separate one that
|
||||||
|
# isn't persisted and is used solely for caches.
|
||||||
|
# `state_id` is either a state_group (and so an int) or a string. This
|
||||||
|
# ensures we don't accidentally persist a state_id as a stateg_group
|
||||||
|
if state_group:
|
||||||
|
self.state_id = state_group
|
||||||
|
else:
|
||||||
|
self.state_id = _gen_state_id()
|
||||||
|
|
||||||
|
|
||||||
class StateHandler(object):
|
class StateHandler(object):
|
||||||
""" Responsible for doing state conflict resolution.
|
""" Responsible for doing state conflict resolution.
|
||||||
@ -60,6 +85,7 @@ class StateHandler(object):
|
|||||||
|
|
||||||
# dict of set of event_ids -> _StateCacheEntry.
|
# dict of set of event_ids -> _StateCacheEntry.
|
||||||
self._state_cache = None
|
self._state_cache = None
|
||||||
|
self.resolve_linearizer = Linearizer()
|
||||||
|
|
||||||
def start_caching(self):
|
def start_caching(self):
|
||||||
logger.debug("start_caching")
|
logger.debug("start_caching")
|
||||||
@ -93,7 +119,8 @@ class StateHandler(object):
|
|||||||
if not latest_event_ids:
|
if not latest_event_ids:
|
||||||
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
|
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
|
||||||
|
|
||||||
_, state = yield self.resolve_state_groups(room_id, latest_event_ids)
|
ret = yield self.resolve_state_groups(room_id, latest_event_ids)
|
||||||
|
state = ret.state
|
||||||
|
|
||||||
if event_type:
|
if event_type:
|
||||||
event_id = state.get((event_type, state_key))
|
event_id = state.get((event_type, state_key))
|
||||||
@ -116,7 +143,8 @@ class StateHandler(object):
|
|||||||
if not latest_event_ids:
|
if not latest_event_ids:
|
||||||
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
|
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
|
||||||
|
|
||||||
_, state = yield self.resolve_state_groups(room_id, latest_event_ids)
|
ret = yield self.resolve_state_groups(room_id, latest_event_ids)
|
||||||
|
state = ret.state
|
||||||
|
|
||||||
if event_type:
|
if event_type:
|
||||||
defer.returnValue(state.get((event_type, state_key)))
|
defer.returnValue(state.get((event_type, state_key)))
|
||||||
@ -127,9 +155,9 @@ class StateHandler(object):
|
|||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_current_user_in_room(self, room_id):
|
def get_current_user_in_room(self, room_id):
|
||||||
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
|
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
|
||||||
group, state_ids = yield self.resolve_state_groups(room_id, latest_event_ids)
|
entry = yield self.resolve_state_groups(room_id, latest_event_ids)
|
||||||
joined_users = yield self.store.get_joined_users_from_context(
|
joined_users = yield self.store.get_joined_users_from_state(
|
||||||
room_id, group, state_ids
|
room_id, entry.state_id, entry.state
|
||||||
)
|
)
|
||||||
defer.returnValue(joined_users)
|
defer.returnValue(joined_users)
|
||||||
|
|
||||||
@ -154,52 +182,73 @@ class StateHandler(object):
|
|||||||
# state. Certainly store.get_current_state won't return any, and
|
# state. Certainly store.get_current_state won't return any, and
|
||||||
# persisting the event won't store the state group.
|
# persisting the event won't store the state group.
|
||||||
if old_state:
|
if old_state:
|
||||||
context.current_state_ids = {
|
context.prev_state_ids = {
|
||||||
(s.type, s.state_key): s.event_id for s in old_state
|
(s.type, s.state_key): s.event_id for s in old_state
|
||||||
}
|
}
|
||||||
|
if event.is_state():
|
||||||
|
context.current_state_events = dict(context.prev_state_ids)
|
||||||
|
key = (event.type, event.state_key)
|
||||||
|
context.current_state_events[key] = event.event_id
|
||||||
|
else:
|
||||||
|
context.current_state_events = context.prev_state_ids
|
||||||
else:
|
else:
|
||||||
context.current_state_ids = {}
|
context.current_state_ids = {}
|
||||||
|
context.prev_state_ids = {}
|
||||||
context.prev_state_events = []
|
context.prev_state_events = []
|
||||||
context.state_group = None
|
context.state_group = self.store.get_next_state_group()
|
||||||
defer.returnValue(context)
|
defer.returnValue(context)
|
||||||
|
|
||||||
if old_state:
|
if old_state:
|
||||||
context.current_state_ids = {
|
context.prev_state_ids = {
|
||||||
(s.type, s.state_key): s.event_id for s in old_state
|
(s.type, s.state_key): s.event_id for s in old_state
|
||||||
}
|
}
|
||||||
context.state_group = None
|
context.state_group = self.store.get_next_state_group()
|
||||||
|
|
||||||
if event.is_state():
|
if event.is_state():
|
||||||
key = (event.type, event.state_key)
|
key = (event.type, event.state_key)
|
||||||
if key in context.current_state_ids:
|
if key in context.prev_state_ids:
|
||||||
replaces = context.current_state_ids[key]
|
replaces = context.prev_state_ids[key]
|
||||||
if replaces != event.event_id: # Paranoia check
|
if replaces != event.event_id: # Paranoia check
|
||||||
event.unsigned["replaces_state"] = replaces
|
event.unsigned["replaces_state"] = replaces
|
||||||
|
context.current_state_ids = dict(context.prev_state_ids)
|
||||||
|
context.current_state_ids[key] = event.event_id
|
||||||
|
else:
|
||||||
|
context.current_state_ids = context.prev_state_ids
|
||||||
|
|
||||||
context.prev_state_events = []
|
context.prev_state_events = []
|
||||||
defer.returnValue(context)
|
defer.returnValue(context)
|
||||||
|
|
||||||
if event.is_state():
|
if event.is_state():
|
||||||
ret = yield self.resolve_state_groups(
|
entry = yield self.resolve_state_groups(
|
||||||
event.room_id, [e for e, _ in event.prev_events],
|
event.room_id, [e for e, _ in event.prev_events],
|
||||||
event_type=event.type,
|
event_type=event.type,
|
||||||
state_key=event.state_key,
|
state_key=event.state_key,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
ret = yield self.resolve_state_groups(
|
entry = yield self.resolve_state_groups(
|
||||||
event.room_id, [e for e, _ in event.prev_events],
|
event.room_id, [e for e, _ in event.prev_events],
|
||||||
)
|
)
|
||||||
|
|
||||||
group, curr_state = ret
|
curr_state = entry.state
|
||||||
|
|
||||||
context.current_state_ids = curr_state
|
context.prev_state_ids = curr_state
|
||||||
context.state_group = group if not event.is_state() else None
|
if event.is_state():
|
||||||
|
context.state_group = self.store.get_next_state_group()
|
||||||
|
else:
|
||||||
|
if entry.state_group is None:
|
||||||
|
entry.state_group = self.store.get_next_state_group()
|
||||||
|
entry.state_id = entry.state_group
|
||||||
|
context.state_group = entry.state_group
|
||||||
|
|
||||||
if event.is_state():
|
if event.is_state():
|
||||||
key = (event.type, event.state_key)
|
key = (event.type, event.state_key)
|
||||||
if key in context.current_state_ids:
|
if key in context.prev_state_ids:
|
||||||
replaces = context.current_state_ids[key]
|
replaces = context.prev_state_ids[key]
|
||||||
event.unsigned["replaces_state"] = replaces
|
event.unsigned["replaces_state"] = replaces
|
||||||
|
context.current_state_ids = dict(context.prev_state_ids)
|
||||||
|
context.current_state_ids[key] = event.event_id
|
||||||
|
else:
|
||||||
|
context.current_state_ids = context.prev_state_ids
|
||||||
|
|
||||||
context.prev_state_events = []
|
context.prev_state_events = []
|
||||||
defer.returnValue(context)
|
defer.returnValue(context)
|
||||||
@ -231,70 +280,75 @@ class StateHandler(object):
|
|||||||
if len(group_names) == 1:
|
if len(group_names) == 1:
|
||||||
name, state_list = state_groups_ids.items().pop()
|
name, state_list = state_groups_ids.items().pop()
|
||||||
|
|
||||||
defer.returnValue((name, state_list,))
|
defer.returnValue(_StateCacheEntry(
|
||||||
|
state=state_list,
|
||||||
|
state_group=name,
|
||||||
|
))
|
||||||
|
|
||||||
if self._state_cache is not None:
|
with (yield self.resolve_linearizer.queue(group_names)):
|
||||||
cache = self._state_cache.get(group_names, None)
|
if self._state_cache is not None:
|
||||||
if cache:
|
cache = self._state_cache.get(group_names, None)
|
||||||
cache.ts = self.clock.time_msec()
|
if cache:
|
||||||
|
defer.returnValue(cache)
|
||||||
|
|
||||||
defer.returnValue(
|
logger.info(
|
||||||
(cache.state_group, cache.state,)
|
"Resolving state for %s with %d groups", room_id, len(state_groups_ids)
|
||||||
|
)
|
||||||
|
|
||||||
|
state = {}
|
||||||
|
for st in state_groups_ids.values():
|
||||||
|
for key, e_id in st.items():
|
||||||
|
state.setdefault(key, set()).add(e_id)
|
||||||
|
|
||||||
|
conflicted_state = {
|
||||||
|
k: list(v)
|
||||||
|
for k, v in state.items()
|
||||||
|
if len(v) > 1
|
||||||
|
}
|
||||||
|
|
||||||
|
if conflicted_state:
|
||||||
|
logger.info("Resolving conflicted state for %r", room_id)
|
||||||
|
state_map = yield self.store.get_events(
|
||||||
|
[e_id for st in state_groups_ids.values() for e_id in st.values()],
|
||||||
|
get_prev_content=False
|
||||||
)
|
)
|
||||||
|
state_sets = [
|
||||||
|
[state_map[e_id] for key, e_id in st.items() if e_id in state_map]
|
||||||
|
for st in state_groups_ids.values()
|
||||||
|
]
|
||||||
|
new_state, _ = self._resolve_events(
|
||||||
|
state_sets, event_type, state_key
|
||||||
|
)
|
||||||
|
new_state = {
|
||||||
|
key: e.event_id for key, e in new_state.items()
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
new_state = {
|
||||||
|
key: e_ids.pop() for key, e_ids in state.items()
|
||||||
|
}
|
||||||
|
|
||||||
logger.info(
|
state_group = None
|
||||||
"Resolving state for %s with %d groups", room_id, len(state_groups_ids)
|
new_state_event_ids = frozenset(new_state.values())
|
||||||
)
|
for sg, events in state_groups_ids.items():
|
||||||
|
if new_state_event_ids == frozenset(e_id for e_id in events):
|
||||||
|
state_group = sg
|
||||||
|
break
|
||||||
|
if state_group is None:
|
||||||
|
# Worker instances don't have access to this method, but we want
|
||||||
|
# to set the state_group on the main instance to increase cache
|
||||||
|
# hits.
|
||||||
|
if hasattr(self.store, "get_next_state_group"):
|
||||||
|
state_group = self.store.get_next_state_group()
|
||||||
|
|
||||||
state = {}
|
|
||||||
for st in state_groups_ids.values():
|
|
||||||
for key, e_id in st.items():
|
|
||||||
state.setdefault(key, set()).add(e_id)
|
|
||||||
|
|
||||||
conflicted_state = {
|
|
||||||
k: list(v)
|
|
||||||
for k, v in state.items()
|
|
||||||
if len(v) > 1
|
|
||||||
}
|
|
||||||
|
|
||||||
if conflicted_state:
|
|
||||||
logger.info("Resolving conflicted state for %r", room_id)
|
|
||||||
state_map = yield self.store.get_events(
|
|
||||||
[e_id for st in state_groups_ids.values() for e_id in st.values()],
|
|
||||||
get_prev_content=False
|
|
||||||
)
|
|
||||||
state_sets = [
|
|
||||||
[state_map[e_id] for key, e_id in st.items() if e_id in state_map]
|
|
||||||
for st in state_groups_ids.values()
|
|
||||||
]
|
|
||||||
new_state, _ = self._resolve_events(
|
|
||||||
state_sets, event_type, state_key
|
|
||||||
)
|
|
||||||
new_state = {
|
|
||||||
key: e.event_id for key, e in new_state.items()
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
new_state = {
|
|
||||||
key: e_ids.pop() for key, e_ids in state.items()
|
|
||||||
}
|
|
||||||
|
|
||||||
state_group = None
|
|
||||||
new_state_event_ids = frozenset(new_state.values())
|
|
||||||
for sg, events in state_groups_ids.items():
|
|
||||||
if new_state_event_ids == frozenset(e_id for e_id in events):
|
|
||||||
state_group = sg
|
|
||||||
break
|
|
||||||
|
|
||||||
if self._state_cache is not None:
|
|
||||||
cache = _StateCacheEntry(
|
cache = _StateCacheEntry(
|
||||||
state=new_state,
|
state=new_state,
|
||||||
state_group=state_group,
|
state_group=state_group,
|
||||||
ts=self.clock.time_msec()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self._state_cache[group_names] = cache
|
if self._state_cache is not None:
|
||||||
|
self._state_cache[group_names] = cache
|
||||||
|
|
||||||
defer.returnValue((state_group, new_state,))
|
defer.returnValue(cache)
|
||||||
|
|
||||||
def resolve_events(self, state_sets, event):
|
def resolve_events(self, state_sets, event):
|
||||||
logger.info(
|
logger.info(
|
||||||
|
@ -115,7 +115,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
|||||||
)
|
)
|
||||||
|
|
||||||
self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
|
self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
|
||||||
self._state_groups_id_gen = StreamIdGenerator(db_conn, "state_groups", "id")
|
self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
|
||||||
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
|
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
|
||||||
self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
|
self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
|
||||||
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
|
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
|
||||||
|
@ -271,39 +271,28 @@ class EventsStore(SQLBaseStore):
|
|||||||
len(events_and_contexts)
|
len(events_and_contexts)
|
||||||
)
|
)
|
||||||
|
|
||||||
state_group_id_manager = self._state_groups_id_gen.get_next_mult(
|
|
||||||
len(events_and_contexts)
|
|
||||||
)
|
|
||||||
with stream_ordering_manager as stream_orderings:
|
with stream_ordering_manager as stream_orderings:
|
||||||
with state_group_id_manager as state_group_ids:
|
for (event, context), stream, in zip(
|
||||||
for (event, context), stream, state_group_id in zip(
|
events_and_contexts, stream_orderings
|
||||||
events_and_contexts, stream_orderings, state_group_ids
|
):
|
||||||
):
|
event.internal_metadata.stream_ordering = stream
|
||||||
event.internal_metadata.stream_ordering = stream
|
|
||||||
# Assign a state group_id in case a new id is needed for
|
|
||||||
# this context. In theory we only need to assign this
|
|
||||||
# for contexts that have current_state and aren't outliers
|
|
||||||
# but that make the code more complicated. Assigning an ID
|
|
||||||
# per event only causes the state_group_ids to grow as fast
|
|
||||||
# as the stream_ordering so in practise shouldn't be a problem.
|
|
||||||
context.new_state_group_id = state_group_id
|
|
||||||
|
|
||||||
chunks = [
|
chunks = [
|
||||||
events_and_contexts[x:x + 100]
|
events_and_contexts[x:x + 100]
|
||||||
for x in xrange(0, len(events_and_contexts), 100)
|
for x in xrange(0, len(events_and_contexts), 100)
|
||||||
]
|
]
|
||||||
|
|
||||||
for chunk in chunks:
|
for chunk in chunks:
|
||||||
# We can't easily parallelize these since different chunks
|
# We can't easily parallelize these since different chunks
|
||||||
# might contain the same event. :(
|
# might contain the same event. :(
|
||||||
yield self.runInteraction(
|
yield self.runInteraction(
|
||||||
"persist_events",
|
"persist_events",
|
||||||
self._persist_events_txn,
|
self._persist_events_txn,
|
||||||
events_and_contexts=chunk,
|
events_and_contexts=chunk,
|
||||||
backfilled=backfilled,
|
backfilled=backfilled,
|
||||||
delete_existing=delete_existing,
|
delete_existing=delete_existing,
|
||||||
)
|
)
|
||||||
persist_event_counter.inc_by(len(chunk))
|
persist_event_counter.inc_by(len(chunk))
|
||||||
|
|
||||||
@_retry_on_integrity_error
|
@_retry_on_integrity_error
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@ -312,19 +301,17 @@ class EventsStore(SQLBaseStore):
|
|||||||
delete_existing=False):
|
delete_existing=False):
|
||||||
try:
|
try:
|
||||||
with self._stream_id_gen.get_next() as stream_ordering:
|
with self._stream_id_gen.get_next() as stream_ordering:
|
||||||
with self._state_groups_id_gen.get_next() as state_group_id:
|
event.internal_metadata.stream_ordering = stream_ordering
|
||||||
event.internal_metadata.stream_ordering = stream_ordering
|
yield self.runInteraction(
|
||||||
context.new_state_group_id = state_group_id
|
"persist_event",
|
||||||
yield self.runInteraction(
|
self._persist_event_txn,
|
||||||
"persist_event",
|
event=event,
|
||||||
self._persist_event_txn,
|
context=context,
|
||||||
event=event,
|
current_state=current_state,
|
||||||
context=context,
|
backfilled=backfilled,
|
||||||
current_state=current_state,
|
delete_existing=delete_existing,
|
||||||
backfilled=backfilled,
|
)
|
||||||
delete_existing=delete_existing,
|
persist_event_counter.inc()
|
||||||
)
|
|
||||||
persist_event_counter.inc()
|
|
||||||
except _RollbackButIsFineException:
|
except _RollbackButIsFineException:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -528,7 +515,7 @@ class EventsStore(SQLBaseStore):
|
|||||||
# Add an entry to the ex_outlier_stream table to replicate the
|
# Add an entry to the ex_outlier_stream table to replicate the
|
||||||
# change in outlier status to our workers.
|
# change in outlier status to our workers.
|
||||||
stream_order = event.internal_metadata.stream_ordering
|
stream_order = event.internal_metadata.stream_ordering
|
||||||
state_group_id = context.state_group or context.new_state_group_id
|
state_group_id = context.state_group
|
||||||
self._simple_insert_txn(
|
self._simple_insert_txn(
|
||||||
txn,
|
txn,
|
||||||
table="ex_outlier_stream",
|
table="ex_outlier_stream",
|
||||||
|
@ -145,7 +145,7 @@ class ReceiptsStore(SQLBaseStore):
|
|||||||
|
|
||||||
defer.returnValue([ev for res in results.values() for ev in res])
|
defer.returnValue([ev for res in results.values() for ev in res])
|
||||||
|
|
||||||
@cachedInlineCallbacks(num_args=3, max_entries=5000, tree=True)
|
@cachedInlineCallbacks(num_args=3, tree=True)
|
||||||
def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
|
def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
|
||||||
"""Get receipts for a single room for sending to clients.
|
"""Get receipts for a single room for sending to clients.
|
||||||
|
|
||||||
|
@ -354,7 +354,8 @@ class RoomMemberStore(SQLBaseStore):
|
|||||||
desc="who_forgot"
|
desc="who_forgot"
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_joined_users_from_context(self, room_id, state_group, state_ids):
|
def get_joined_users_from_context(self, event, context):
|
||||||
|
state_group = context.state_group
|
||||||
if not state_group:
|
if not state_group:
|
||||||
# If state_group is None it means it has yet to be assigned a
|
# If state_group is None it means it has yet to be assigned a
|
||||||
# state group, i.e. we need to make sure that calls with a state_group
|
# state group, i.e. we need to make sure that calls with a state_group
|
||||||
@ -363,12 +364,24 @@ class RoomMemberStore(SQLBaseStore):
|
|||||||
state_group = object()
|
state_group = object()
|
||||||
|
|
||||||
return self._get_joined_users_from_context(
|
return self._get_joined_users_from_context(
|
||||||
room_id, state_group, state_ids
|
event.room_id, state_group, context.current_state_ids, event=event,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_joined_users_from_state(self, room_id, state_group, state_ids):
|
||||||
|
if not state_group:
|
||||||
|
# If state_group is None it means it has yet to be assigned a
|
||||||
|
# state group, i.e. we need to make sure that calls with a state_group
|
||||||
|
# of None don't hit previous cached calls with a None state_group.
|
||||||
|
# To do this we set the state_group to a new object as object() != object()
|
||||||
|
state_group = object()
|
||||||
|
|
||||||
|
return self._get_joined_users_from_context(
|
||||||
|
room_id, state_group, state_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
@cachedInlineCallbacks(num_args=2, cache_context=True)
|
@cachedInlineCallbacks(num_args=2, cache_context=True)
|
||||||
def _get_joined_users_from_context(self, room_id, state_group, current_state_ids,
|
def _get_joined_users_from_context(self, room_id, state_group, current_state_ids,
|
||||||
cache_context):
|
cache_context, event=None):
|
||||||
# We don't use `state_group`, its there so that we can cache based
|
# We don't use `state_group`, its there so that we can cache based
|
||||||
# on it. However, its important that its never None, since two current_state's
|
# on it. However, its important that its never None, since two current_state's
|
||||||
# with a state_group of None are likely to be different.
|
# with a state_group of None are likely to be different.
|
||||||
@ -393,7 +406,13 @@ class RoomMemberStore(SQLBaseStore):
|
|||||||
desc="_get_joined_users_from_context",
|
desc="_get_joined_users_from_context",
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue(set(row["user_id"] for row in rows))
|
users_in_room = set(row["user_id"] for row in rows)
|
||||||
|
if event is not None and event.type == EventTypes.Member:
|
||||||
|
if event.membership == Membership.JOIN:
|
||||||
|
if event.event_id in member_event_ids:
|
||||||
|
users_in_room.add(event.state_key)
|
||||||
|
|
||||||
|
defer.returnValue(users_in_room)
|
||||||
|
|
||||||
def is_host_joined(self, room_id, host, state_group, state_ids):
|
def is_host_joined(self, room_id, host, state_group, state_ids):
|
||||||
if not state_group:
|
if not state_group:
|
||||||
|
32
synapse/storage/schema/delta/34/sent_txn_purge.py
Normal file
32
synapse/storage/schema/delta/34/sent_txn_purge.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
# Copyright 2016 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from synapse.storage.engines import PostgresEngine
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def run_create(cur, database_engine, *args, **kwargs):
|
||||||
|
if isinstance(database_engine, PostgresEngine):
|
||||||
|
cur.execute("TRUNCATE sent_transactions")
|
||||||
|
else:
|
||||||
|
cur.execute("DELETE FROM sent_transactions")
|
||||||
|
|
||||||
|
cur.execute("CREATE INDEX sent_transactions_ts ON sent_transactions(ts)")
|
||||||
|
|
||||||
|
|
||||||
|
def run_upgrade(cur, database_engine, *args, **kwargs):
|
||||||
|
pass
|
@ -83,6 +83,14 @@ class StateStore(SQLBaseStore):
|
|||||||
for group, event_id_map in group_to_ids.items()
|
for group, event_id_map in group_to_ids.items()
|
||||||
})
|
})
|
||||||
|
|
||||||
|
def _have_persisted_state_group_txn(self, txn, state_group):
|
||||||
|
txn.execute(
|
||||||
|
"SELECT count(*) FROM state_groups WHERE id = ?",
|
||||||
|
(state_group,)
|
||||||
|
)
|
||||||
|
row = txn.fetchone()
|
||||||
|
return row and row[0]
|
||||||
|
|
||||||
def _store_mult_state_groups_txn(self, txn, events_and_contexts):
|
def _store_mult_state_groups_txn(self, txn, events_and_contexts):
|
||||||
state_groups = {}
|
state_groups = {}
|
||||||
for event, context in events_and_contexts:
|
for event, context in events_and_contexts:
|
||||||
@ -92,22 +100,19 @@ class StateStore(SQLBaseStore):
|
|||||||
if context.current_state_ids is None:
|
if context.current_state_ids is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if context.state_group is not None:
|
state_groups[event.event_id] = context.state_group
|
||||||
state_groups[event.event_id] = context.state_group
|
|
||||||
|
if self._have_persisted_state_group_txn(txn, context.state_group):
|
||||||
|
logger.info("Already persisted state_group: %r", context.state_group)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
state_event_ids = dict(context.current_state_ids)
|
state_event_ids = dict(context.current_state_ids)
|
||||||
|
|
||||||
if event.is_state():
|
|
||||||
state_event_ids[(event.type, event.state_key)] = event.event_id
|
|
||||||
|
|
||||||
state_group = context.new_state_group_id
|
|
||||||
|
|
||||||
self._simple_insert_txn(
|
self._simple_insert_txn(
|
||||||
txn,
|
txn,
|
||||||
table="state_groups",
|
table="state_groups",
|
||||||
values={
|
values={
|
||||||
"id": state_group,
|
"id": context.state_group,
|
||||||
"room_id": event.room_id,
|
"room_id": event.room_id,
|
||||||
"event_id": event.event_id,
|
"event_id": event.event_id,
|
||||||
},
|
},
|
||||||
@ -118,7 +123,7 @@ class StateStore(SQLBaseStore):
|
|||||||
table="state_groups_state",
|
table="state_groups_state",
|
||||||
values=[
|
values=[
|
||||||
{
|
{
|
||||||
"state_group": state_group,
|
"state_group": context.state_group,
|
||||||
"room_id": event.room_id,
|
"room_id": event.room_id,
|
||||||
"type": key[0],
|
"type": key[0],
|
||||||
"state_key": key[1],
|
"state_key": key[1],
|
||||||
@ -127,7 +132,6 @@ class StateStore(SQLBaseStore):
|
|||||||
for key, state_id in state_event_ids.items()
|
for key, state_id in state_event_ids.items()
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
state_groups[event.event_id] = state_group
|
|
||||||
|
|
||||||
self._simple_insert_many_txn(
|
self._simple_insert_many_txn(
|
||||||
txn,
|
txn,
|
||||||
@ -527,5 +531,5 @@ class StateStore(SQLBaseStore):
|
|||||||
"get_all_new_state_groups", get_all_new_state_groups_txn
|
"get_all_new_state_groups", get_all_new_state_groups_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_state_stream_token(self):
|
def get_next_state_group(self):
|
||||||
return self._state_groups_id_gen.get_current_token()
|
return self._state_groups_id_gen.get_next()
|
||||||
|
@ -387,8 +387,10 @@ class TransactionStore(SQLBaseStore):
|
|||||||
def _cleanup_transactions(self):
|
def _cleanup_transactions(self):
|
||||||
now = self._clock.time_msec()
|
now = self._clock.time_msec()
|
||||||
month_ago = now - 30 * 24 * 60 * 60 * 1000
|
month_ago = now - 30 * 24 * 60 * 60 * 1000
|
||||||
|
six_hours_ago = now - 6 * 60 * 60 * 1000
|
||||||
|
|
||||||
def _cleanup_transactions_txn(txn):
|
def _cleanup_transactions_txn(txn):
|
||||||
txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
|
txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
|
||||||
|
txn.execute("DELETE FROM sent_transactions WHERE ts < ?", (six_hours_ago,))
|
||||||
|
|
||||||
return self.runInteraction("_persist_in_mem_txns", _cleanup_transactions_txn)
|
return self.runInteraction("_persist_in_mem_txns", _cleanup_transactions_txn)
|
||||||
|
@ -115,6 +115,53 @@ class PresenceUpdateTestCase(unittest.TestCase):
|
|||||||
),
|
),
|
||||||
], any_order=True)
|
], any_order=True)
|
||||||
|
|
||||||
|
def test_online_to_online_last_active_noop(self):
|
||||||
|
wheel_timer = Mock()
|
||||||
|
user_id = "@foo:bar"
|
||||||
|
now = 5000000
|
||||||
|
|
||||||
|
prev_state = UserPresenceState.default(user_id)
|
||||||
|
prev_state = prev_state.copy_and_replace(
|
||||||
|
state=PresenceState.ONLINE,
|
||||||
|
last_active_ts=now - LAST_ACTIVE_GRANULARITY - 10,
|
||||||
|
currently_active=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
new_state = prev_state.copy_and_replace(
|
||||||
|
state=PresenceState.ONLINE,
|
||||||
|
last_active_ts=now,
|
||||||
|
)
|
||||||
|
|
||||||
|
state, persist_and_notify, federation_ping = handle_update(
|
||||||
|
prev_state, new_state, is_mine=True, wheel_timer=wheel_timer, now=now
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertFalse(persist_and_notify)
|
||||||
|
self.assertTrue(federation_ping)
|
||||||
|
self.assertTrue(state.currently_active)
|
||||||
|
self.assertEquals(new_state.state, state.state)
|
||||||
|
self.assertEquals(new_state.status_msg, state.status_msg)
|
||||||
|
self.assertEquals(state.last_federation_update_ts, now)
|
||||||
|
|
||||||
|
self.assertEquals(wheel_timer.insert.call_count, 3)
|
||||||
|
wheel_timer.insert.assert_has_calls([
|
||||||
|
call(
|
||||||
|
now=now,
|
||||||
|
obj=user_id,
|
||||||
|
then=new_state.last_active_ts + IDLE_TIMER
|
||||||
|
),
|
||||||
|
call(
|
||||||
|
now=now,
|
||||||
|
obj=user_id,
|
||||||
|
then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT
|
||||||
|
),
|
||||||
|
call(
|
||||||
|
now=now,
|
||||||
|
obj=user_id,
|
||||||
|
then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY
|
||||||
|
),
|
||||||
|
], any_order=True)
|
||||||
|
|
||||||
def test_online_to_online_last_active(self):
|
def test_online_to_online_last_active(self):
|
||||||
wheel_timer = Mock()
|
wheel_timer = Mock()
|
||||||
user_id = "@foo:bar"
|
user_id = "@foo:bar"
|
||||||
|
@ -312,7 +312,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
|
|||||||
else:
|
else:
|
||||||
state_ids = None
|
state_ids = None
|
||||||
|
|
||||||
context = EventContext(current_state_ids=state_ids)
|
context = EventContext()
|
||||||
|
context.current_state_ids = state_ids
|
||||||
|
context.prev_state_ids = state_ids
|
||||||
context.push_actions = push_actions
|
context.push_actions = push_actions
|
||||||
|
|
||||||
ordering = None
|
ordering = None
|
||||||
|
@ -60,8 +60,8 @@ class ReplicationResourceCase(unittest.TestCase):
|
|||||||
self.assertEquals(body, {})
|
self.assertEquals(body, {})
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_events_and_state(self):
|
def test_events(self):
|
||||||
get = self.get(events="-1", state="-1", timeout="0")
|
get = self.get(events="-1", timeout="0")
|
||||||
yield self.hs.get_handlers().room_creation_handler.create_room(
|
yield self.hs.get_handlers().room_creation_handler.create_room(
|
||||||
synapse.types.create_requester(self.user), {}
|
synapse.types.create_requester(self.user), {}
|
||||||
)
|
)
|
||||||
@ -70,12 +70,6 @@ class ReplicationResourceCase(unittest.TestCase):
|
|||||||
self.assertEquals(body["events"]["field_names"], [
|
self.assertEquals(body["events"]["field_names"], [
|
||||||
"position", "internal", "json", "state_group"
|
"position", "internal", "json", "state_group"
|
||||||
])
|
])
|
||||||
self.assertEquals(body["state_groups"]["field_names"], [
|
|
||||||
"position", "room_id", "event_id"
|
|
||||||
])
|
|
||||||
self.assertEquals(body["state_group_state"]["field_names"], [
|
|
||||||
"position", "type", "state_key", "event_id"
|
|
||||||
])
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_presence(self):
|
def test_presence(self):
|
||||||
|
@ -86,17 +86,8 @@ class StateGroupStore(object):
|
|||||||
|
|
||||||
state_events = dict(context.current_state_ids)
|
state_events = dict(context.current_state_ids)
|
||||||
|
|
||||||
if event.is_state():
|
self._group_to_state[context.state_group] = state_events
|
||||||
state_events[(event.type, event.state_key)] = event.event_id
|
self._event_to_state_group[event.event_id] = context.state_group
|
||||||
|
|
||||||
state_group = context.state_group
|
|
||||||
if not state_group:
|
|
||||||
state_group = self._next_group
|
|
||||||
self._next_group += 1
|
|
||||||
|
|
||||||
self._group_to_state[state_group] = state_events
|
|
||||||
|
|
||||||
self._event_to_state_group[event.event_id] = state_group
|
|
||||||
|
|
||||||
def get_events(self, event_ids, **kwargs):
|
def get_events(self, event_ids, **kwargs):
|
||||||
return {
|
return {
|
||||||
@ -151,6 +142,7 @@ class StateTestCase(unittest.TestCase):
|
|||||||
"get_state_groups_ids",
|
"get_state_groups_ids",
|
||||||
"add_event_hashes",
|
"add_event_hashes",
|
||||||
"get_events",
|
"get_events",
|
||||||
|
"get_next_state_group",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
hs = Mock(spec_set=[
|
hs = Mock(spec_set=[
|
||||||
@ -161,6 +153,8 @@ class StateTestCase(unittest.TestCase):
|
|||||||
hs.get_clock.return_value = MockClock()
|
hs.get_clock.return_value = MockClock()
|
||||||
hs.get_auth.return_value = Auth(hs)
|
hs.get_auth.return_value = Auth(hs)
|
||||||
|
|
||||||
|
self.store.get_next_state_group.side_effect = Mock
|
||||||
|
|
||||||
self.state = StateHandler(hs)
|
self.state = StateHandler(hs)
|
||||||
self.event_id = 0
|
self.event_id = 0
|
||||||
|
|
||||||
@ -209,7 +203,7 @@ class StateTestCase(unittest.TestCase):
|
|||||||
store.store_state_groups(event, context)
|
store.store_state_groups(event, context)
|
||||||
context_store[event.event_id] = context
|
context_store[event.event_id] = context
|
||||||
|
|
||||||
self.assertEqual(2, len(context_store["D"].current_state_ids))
|
self.assertEqual(2, len(context_store["D"].prev_state_ids))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_branch_basic_conflict(self):
|
def test_branch_basic_conflict(self):
|
||||||
@ -265,7 +259,7 @@ class StateTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertSetEqual(
|
self.assertSetEqual(
|
||||||
{"START", "A", "C"},
|
{"START", "A", "C"},
|
||||||
{e_id for e_id in context_store["D"].current_state_ids.values()}
|
{e_id for e_id in context_store["D"].prev_state_ids.values()}
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@ -331,7 +325,7 @@ class StateTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertSetEqual(
|
self.assertSetEqual(
|
||||||
{"START", "A", "B", "C"},
|
{"START", "A", "B", "C"},
|
||||||
{e for e in context_store["E"].current_state_ids.values()}
|
{e for e in context_store["E"].prev_state_ids.values()}
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@ -414,7 +408,7 @@ class StateTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertSetEqual(
|
self.assertSetEqual(
|
||||||
{"A1", "A2", "A3", "A5", "B"},
|
{"A1", "A2", "A3", "A5", "B"},
|
||||||
{e for e in context_store["D"].current_state_ids.values()}
|
{e for e in context_store["D"].prev_state_ids.values()}
|
||||||
)
|
)
|
||||||
|
|
||||||
def _add_depths(self, nodes, edges):
|
def _add_depths(self, nodes, edges):
|
||||||
@ -447,7 +441,7 @@ class StateTestCase(unittest.TestCase):
|
|||||||
set(e.event_id for e in old_state), set(context.current_state_ids.values())
|
set(e.event_id for e in old_state), set(context.current_state_ids.values())
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertIsNone(context.state_group)
|
self.assertIsNotNone(context.state_group)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_annotate_with_old_state(self):
|
def test_annotate_with_old_state(self):
|
||||||
@ -464,11 +458,9 @@ class StateTestCase(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
set(e.event_id for e in old_state), set(context.current_state_ids.values())
|
set(e.event_id for e in old_state), set(context.prev_state_ids.values())
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertIsNone(context.state_group)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_trivial_annotate_message(self):
|
def test_trivial_annotate_message(self):
|
||||||
event = create_event(type="test_message", name="event")
|
event = create_event(type="test_message", name="event")
|
||||||
@ -514,10 +506,10 @@ class StateTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
set([e.event_id for e in old_state]),
|
set([e.event_id for e in old_state]),
|
||||||
set(context.current_state_ids.values())
|
set(context.prev_state_ids.values())
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertIsNone(context.state_group)
|
self.assertIsNotNone(context.state_group)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_resolve_message_conflict(self):
|
def test_resolve_message_conflict(self):
|
||||||
@ -550,7 +542,7 @@ class StateTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertEqual(len(context.current_state_ids), 6)
|
self.assertEqual(len(context.current_state_ids), 6)
|
||||||
|
|
||||||
self.assertIsNone(context.state_group)
|
self.assertIsNotNone(context.state_group)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_resolve_state_conflict(self):
|
def test_resolve_state_conflict(self):
|
||||||
@ -583,7 +575,7 @@ class StateTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertEqual(len(context.current_state_ids), 6)
|
self.assertEqual(len(context.current_state_ids), 6)
|
||||||
|
|
||||||
self.assertIsNone(context.state_group)
|
self.assertIsNotNone(context.state_group)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_standard_depth_conflict(self):
|
def test_standard_depth_conflict(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user