Derive current_state_events from state groups

This commit is contained in:
Erik Johnston 2017-01-20 11:52:51 +00:00
parent 97efe99ae9
commit 09eb08f910
4 changed files with 138 additions and 99 deletions

View File

@ -1319,7 +1319,6 @@ class FederationHandler(BaseHandler):
event_stream_id, max_stream_id = yield self.store.persist_event( event_stream_id, max_stream_id = yield self.store.persist_event(
event, new_event_context, event, new_event_context,
current_state=state,
) )
defer.returnValue((event_stream_id, max_stream_id)) defer.returnValue((event_stream_id, max_stream_id))

View File

@ -429,6 +429,9 @@ def resolve_events(state_sets, state_map_factory):
dict[(str, str), synapse.events.FrozenEvent] is a map from dict[(str, str), synapse.events.FrozenEvent] is a map from
(type, state_key) to event. (type, state_key) to event.
""" """
if len(state_sets) == 1:
return state_sets[0]
unconflicted_state, conflicted_state = _seperate( unconflicted_state, conflicted_state = _seperate(
state_sets, state_sets,
) )

View File

@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ._base import SQLBaseStore, _RollbackButIsFineException from ._base import SQLBaseStore
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
@ -27,6 +27,7 @@ from synapse.util.logutils import log_function
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.state import resolve_events
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
from collections import deque, namedtuple, OrderedDict from collections import deque, namedtuple, OrderedDict
@ -71,22 +72,19 @@ class _EventPeristenceQueue(object):
""" """
_EventPersistQueueItem = namedtuple("_EventPersistQueueItem", ( _EventPersistQueueItem = namedtuple("_EventPersistQueueItem", (
"events_and_contexts", "current_state", "backfilled", "deferred", "events_and_contexts", "backfilled", "deferred",
)) ))
def __init__(self): def __init__(self):
self._event_persist_queues = {} self._event_persist_queues = {}
self._currently_persisting_rooms = set() self._currently_persisting_rooms = set()
def add_to_queue(self, room_id, events_and_contexts, backfilled, current_state): def add_to_queue(self, room_id, events_and_contexts, backfilled):
"""Add events to the queue, with the given persist_event options. """Add events to the queue, with the given persist_event options.
""" """
queue = self._event_persist_queues.setdefault(room_id, deque()) queue = self._event_persist_queues.setdefault(room_id, deque())
if queue: if queue:
end_item = queue[-1] end_item = queue[-1]
if end_item.current_state or current_state:
# We perist events with current_state set to True one at a time
pass
if end_item.backfilled == backfilled: if end_item.backfilled == backfilled:
end_item.events_and_contexts.extend(events_and_contexts) end_item.events_and_contexts.extend(events_and_contexts)
return end_item.deferred.observe() return end_item.deferred.observe()
@ -96,7 +94,6 @@ class _EventPeristenceQueue(object):
queue.append(self._EventPersistQueueItem( queue.append(self._EventPersistQueueItem(
events_and_contexts=events_and_contexts, events_and_contexts=events_and_contexts,
backfilled=backfilled, backfilled=backfilled,
current_state=current_state,
deferred=deferred, deferred=deferred,
)) ))
@ -216,7 +213,6 @@ class EventsStore(SQLBaseStore):
d = preserve_fn(self._event_persist_queue.add_to_queue)( d = preserve_fn(self._event_persist_queue.add_to_queue)(
room_id, evs_ctxs, room_id, evs_ctxs,
backfilled=backfilled, backfilled=backfilled,
current_state=None,
) )
deferreds.append(d) deferreds.append(d)
@ -229,11 +225,10 @@ class EventsStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def persist_event(self, event, context, current_state=None, backfilled=False): def persist_event(self, event, context, backfilled=False):
deferred = self._event_persist_queue.add_to_queue( deferred = self._event_persist_queue.add_to_queue(
event.room_id, [(event, context)], event.room_id, [(event, context)],
backfilled=backfilled, backfilled=backfilled,
current_state=current_state,
) )
self._maybe_start_persisting(event.room_id) self._maybe_start_persisting(event.room_id)
@ -246,17 +241,6 @@ class EventsStore(SQLBaseStore):
def _maybe_start_persisting(self, room_id): def _maybe_start_persisting(self, room_id):
@defer.inlineCallbacks @defer.inlineCallbacks
def persisting_queue(item): def persisting_queue(item):
if item.current_state:
for event, context in item.events_and_contexts:
# There should only ever be one item in
# events_and_contexts when current_state is
# not None
yield self._persist_event(
event, context,
current_state=item.current_state,
backfilled=item.backfilled,
)
else:
yield self._persist_events( yield self._persist_events(
item.events_and_contexts, item.events_and_contexts,
backfilled=item.backfilled, backfilled=item.backfilled,
@ -294,36 +278,89 @@ class EventsStore(SQLBaseStore):
for chunk in chunks: for chunk in chunks:
# We can't easily parallelize these since different chunks # We can't easily parallelize these since different chunks
# might contain the same event. :( # might contain the same event. :(
current_state_for_room = {}
if not backfilled:
# Work out the new "current state" for each room.
# We do this by working out what the new extremities are and then
# calculating the state from that.
events_by_room = {}
for event, context in chunk:
events_by_room.setdefault(event.room_id, []).append(
(event, context)
)
for room_id, ev_ctx_rm in events_by_room.items():
# Work out new extremities by recursively adding and removing
# the new events.
latest_event_ids = yield self.get_latest_event_ids_in_room(
room_id
)
new_latest_event_ids = set(latest_event_ids)
for event, ctx in ev_ctx_rm:
if event.internal_metadata.is_outlier():
continue
new_latest_event_ids.difference_update(
e_id for e_id, _ in event.prev_events
)
new_latest_event_ids.add(event.event_id)
if new_latest_event_ids == set(latest_event_ids):
# No change in extremities, so no change in state
continue
# Now we need to work out the different state sets for
# each state extremities
state_sets = []
missing_event_ids = []
was_updated = False
for event_id in new_latest_event_ids:
# First search in the list of new events we're adding,
# and then use the current state from that
for ev, ctx in ev_ctx_rm:
if event_id == ev.event_id:
if ctx.current_state_ids is None:
raise Exception("Unknown current state")
state_sets.append(ctx.current_state_ids)
if ctx.delta_ids or hasattr(ev, "state_key"):
was_updated = True
break
else:
# If we couldn't find it, then we'll need to pull
# the state from the database
was_updated = True
missing_event_ids.append(event_id)
if missing_event_ids:
# Now pull out the state for any missing events from DB
event_to_groups = yield self._get_state_group_for_events(
missing_event_ids,
)
groups = set(event_to_groups.values())
group_to_state = yield self._get_state_for_groups(groups)
state_sets.extend(group_to_state.values())
if not new_latest_event_ids or was_updated:
current_state_for_room[room_id] = yield resolve_events(
state_sets,
state_map_factory=lambda ev_ids: self.get_events(
ev_ids, get_prev_content=False, check_redacted=False,
),
)
yield self.runInteraction( yield self.runInteraction(
"persist_events", "persist_events",
self._persist_events_txn, self._persist_events_txn,
events_and_contexts=chunk, events_and_contexts=chunk,
backfilled=backfilled, backfilled=backfilled,
delete_existing=delete_existing, delete_existing=delete_existing,
current_state_for_room=current_state_for_room,
) )
persist_event_counter.inc_by(len(chunk)) persist_event_counter.inc_by(len(chunk))
@_retry_on_integrity_error
@defer.inlineCallbacks
@log_function
def _persist_event(self, event, context, current_state=None, backfilled=False,
delete_existing=False):
try:
with self._stream_id_gen.get_next() as stream_ordering:
event.internal_metadata.stream_ordering = stream_ordering
yield self.runInteraction(
"persist_event",
self._persist_event_txn,
event=event,
context=context,
current_state=current_state,
backfilled=backfilled,
delete_existing=delete_existing,
)
persist_event_counter.inc()
except _RollbackButIsFineException:
pass
@defer.inlineCallbacks @defer.inlineCallbacks
def get_event(self, event_id, check_redacted=True, def get_event(self, event_id, check_redacted=True,
get_prev_content=False, allow_rejected=False, get_prev_content=False, allow_rejected=False,
@ -426,7 +463,7 @@ class EventsStore(SQLBaseStore):
@log_function @log_function
def _persist_events_txn(self, txn, events_and_contexts, backfilled, def _persist_events_txn(self, txn, events_and_contexts, backfilled,
delete_existing=False): delete_existing=False, current_state_for_room={}):
"""Insert some number of room events into the necessary database tables. """Insert some number of room events into the necessary database tables.
Rejected events are only inserted into the events table, the events_json table, Rejected events are only inserted into the events table, the events_json table,
@ -436,6 +473,40 @@ class EventsStore(SQLBaseStore):
If delete_existing is True then existing events will be purged from the If delete_existing is True then existing events will be purged from the
database before insertion. This is useful when retrying due to IntegrityError. database before insertion. This is useful when retrying due to IntegrityError.
""" """
for room_id, current_state in current_state_for_room.iteritems():
txn.call_after(self._get_current_state_for_key.invalidate_all)
txn.call_after(self.get_rooms_for_user.invalidate_all)
txn.call_after(self.get_users_in_room.invalidate, (room_id,))
# Add an entry to the current_state_resets table to record the point
# where we clobbered the current state
stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
self._simple_insert_txn(
txn,
table="current_state_resets",
values={"event_stream_ordering": stream_order}
)
self._simple_delete_txn(
txn,
table="current_state_events",
keyvalues={"room_id": room_id},
)
self._simple_insert_many_txn(
txn,
table="current_state_events",
values=[
{
"event_id": ev_id,
"room_id": room_id,
"type": key[0],
"state_key": key[1],
}
for key, ev_id in current_state.iteritems()
],
)
# Ensure that we don't have the same event twice. # Ensure that we don't have the same event twice.
# Pick the earliest non-outlier if there is one, else the earliest one. # Pick the earliest non-outlier if there is one, else the earliest one.
new_events_and_contexts = OrderedDict() new_events_and_contexts = OrderedDict()
@ -798,29 +869,6 @@ class EventsStore(SQLBaseStore):
# to update the current state table # to update the current state table
return return
for event, _ in state_events_and_contexts:
if event.internal_metadata.is_outlier():
# Outlier events shouldn't clobber the current state.
continue
txn.call_after(
self._get_current_state_for_key.invalidate,
(event.room_id, event.type, event.state_key,)
)
self._simple_upsert_txn(
txn,
"current_state_events",
keyvalues={
"room_id": event.room_id,
"type": event.type,
"state_key": event.state_key,
},
values={
"event_id": event.event_id,
}
)
return return
def _add_to_cache(self, txn, events_and_contexts): def _add_to_cache(self, txn, events_and_contexts):

View File

@ -60,7 +60,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_room_members(self): def test_room_members(self):
create = yield self.persist(type="m.room.create", key="", creator=USER_ID) yield self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.replicate() yield self.replicate()
yield self.check("get_rooms_for_user", (USER_ID,), []) yield self.check("get_rooms_for_user", (USER_ID,), [])
yield self.check("get_users_in_room", (ROOM_ID,), []) yield self.check("get_users_in_room", (ROOM_ID,), [])
@ -95,15 +95,11 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
)]) )])
yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID_2]) yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID_2])
# Join the room clobbering the state.
# This should remove any evidence of the other user being in the room.
yield self.persist( yield self.persist(
type="m.room.member", key=USER_ID, membership="join", type="m.room.member", key=USER_ID, membership="join",
reset_state=[create]
) )
yield self.replicate() yield self.replicate()
yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID]) yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID_2, USER_ID])
yield self.check("get_rooms_for_user", (USER_ID_2,), [])
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_latest_event_ids_in_room(self): def test_get_latest_event_ids_in_room(self):
@ -125,7 +121,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_current_state(self): def test_get_current_state(self):
# Create the room. # Create the room.
create = yield self.persist(type="m.room.create", key="", creator=USER_ID) yield self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.replicate() yield self.replicate()
yield self.check( yield self.check(
"get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID), [] "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID), []
@ -151,22 +147,6 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
[join2] [join2]
) )
# Leave the room, then rejoin the room clobbering state.
yield self.persist(type="m.room.member", key=USER_ID, membership="leave")
join3 = yield self.persist(
type="m.room.member", key=USER_ID, membership="join",
reset_state=[create]
)
yield self.replicate()
yield self.check(
"get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID_2),
[]
)
yield self.check(
"get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID),
[join3]
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_redactions(self): def test_redactions(self):
yield self.persist(type="m.room.create", key="", creator=USER_ID) yield self.persist(type="m.room.create", key="", creator=USER_ID)
@ -283,6 +263,12 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
if depth is None: if depth is None:
depth = self.event_id depth = self.event_id
if not prev_events:
latest_event_ids = yield self.master_store.get_latest_event_ids_in_room(
room_id
)
prev_events = [(ev_id, {}) for ev_id in latest_event_ids]
event_dict = { event_dict = {
"sender": sender, "sender": sender,
"type": type, "type": type,
@ -309,12 +295,15 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
state_ids = { state_ids = {
key: e.event_id for key, e in state.items() key: e.event_id for key, e in state.items()
} }
else:
state_ids = None
context = EventContext() context = EventContext()
context.current_state_ids = state_ids context.current_state_ids = state_ids
context.prev_state_ids = state_ids context.prev_state_ids = state_ids
elif not backfill:
state_handler = self.hs.get_state_handler()
context = yield state_handler.compute_event_context(event)
else:
context = EventContext()
context.push_actions = push_actions context.push_actions = push_actions
ordering = None ordering = None
@ -324,7 +313,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
) )
else: else:
ordering, _ = yield self.master_store.persist_event( ordering, _ = yield self.master_store.persist_event(
event, context, current_state=reset_state event, context,
) )
if ordering: if ordering: