Add bulk insert events API

This commit is contained in:
Erik Johnston 2015-06-25 17:18:19 +01:00
parent 6924852592
commit 5130d80d79
8 changed files with 521 additions and 374 deletions

View File

@ -327,6 +327,9 @@ class FederationClient(FederationBase):
@defer.inlineCallbacks @defer.inlineCallbacks
def make_join(self, destinations, room_id, user_id): def make_join(self, destinations, room_id, user_id):
for destination in destinations: for destination in destinations:
if destination == self.server_name:
continue
try: try:
ret = yield self.transport_layer.make_join( ret = yield self.transport_layer.make_join(
destination, room_id, user_id destination, room_id, user_id
@ -353,6 +356,9 @@ class FederationClient(FederationBase):
@defer.inlineCallbacks @defer.inlineCallbacks
def send_join(self, destinations, pdu): def send_join(self, destinations, pdu):
for destination in destinations: for destination in destinations:
if destination == self.server_name:
continue
try: try:
time_now = self._clock.time_msec() time_now = self._clock.time_msec()
_, content = yield self.transport_layer.send_join( _, content = yield self.transport_layer.send_join(

View File

@ -138,26 +138,29 @@ class FederationHandler(BaseHandler):
if state and auth_chain is not None: if state and auth_chain is not None:
# If we have any state or auth_chain given to us by the replication # If we have any state or auth_chain given to us by the replication
# layer, then we should handle them (if we haven't before.) # layer, then we should handle them (if we haven't before.)
event_infos = []
for e in itertools.chain(auth_chain, state): for e in itertools.chain(auth_chain, state):
if e.event_id in seen_ids: if e.event_id in seen_ids:
continue continue
e.internal_metadata.outlier = True e.internal_metadata.outlier = True
try: auth_ids = [e_id for e_id, _ in e.auth_events]
auth_ids = [e_id for e_id, _ in e.auth_events] auth = {
auth = { (e.type, e.state_key): e for e in auth_chain
(e.type, e.state_key): e for e in auth_chain if e.event_id in auth_ids
if e.event_id in auth_ids }
} event_infos.append({
yield self._handle_new_event( "event": e,
origin, e, auth_events=auth "auth_events": auth,
) })
seen_ids.add(e.event_id) seen_ids.add(e.event_id)
except:
logger.exception( yield self._handle_new_events(
"Failed to handle state event %s", origin,
e.event_id, event_infos,
) outliers=True
)
try: try:
_, event_stream_id, max_stream_id = yield self._handle_new_event( _, event_stream_id, max_stream_id = yield self._handle_new_event(
@ -292,38 +295,29 @@ class FederationHandler(BaseHandler):
).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
auth_events.update({a.event_id: a for a in results}) auth_events.update({a.event_id: a for a in results})
yield defer.gatherResults( ev_infos = []
[ for a in auth_events.values():
self._handle_new_event( if a.event_id in seen_events:
dest, a, continue
auth_events={ ev_infos.append({
(auth_events[a_id].type, auth_events[a_id].state_key): "event": a,
auth_events[a_id] "auth_events": {
for a_id, _ in a.auth_events (auth_events[a_id].type, auth_events[a_id].state_key):
}, auth_events[a_id]
) for a_id, _ in a.auth_events
for a in auth_events.values() }
if a.event_id not in seen_events })
],
consumeErrors=True,
).addErrback(unwrapFirstError)
yield defer.gatherResults( for e_id in events_to_state:
[ ev_infos.append({
self._handle_new_event( "event": event_map[e_id],
dest, event_map[e_id], "state": events_to_state[e_id],
state=events_to_state[e_id], "auth_events": {
backfilled=True, (auth_events[a_id].type, auth_events[a_id].state_key):
auth_events={ auth_events[a_id]
(auth_events[a_id].type, auth_events[a_id].state_key): for a_id, _ in event_map[e_id].auth_events
auth_events[a_id] }
for a_id, _ in event_map[e_id].auth_events })
},
)
for e_id in events_to_state
],
consumeErrors=True
).addErrback(unwrapFirstError)
events.sort(key=lambda e: e.depth) events.sort(key=lambda e: e.depth)
@ -331,10 +325,14 @@ class FederationHandler(BaseHandler):
if event in events_to_state: if event in events_to_state:
continue continue
yield self._handle_new_event( ev_infos.append({
dest, event, "event": event,
backfilled=True, })
)
yield self._handle_new_events(
dest, ev_infos,
backfilled=True,
)
defer.returnValue(events) defer.returnValue(events)
@ -600,32 +598,22 @@ class FederationHandler(BaseHandler):
# FIXME # FIXME
pass pass
yield self._handle_auth_events( ev_infos = []
origin, [e for e in auth_chain if e.event_id != event.event_id] for e in itertools.chain(state, auth_chain):
)
@defer.inlineCallbacks
def handle_state(e):
if e.event_id == event.event_id: if e.event_id == event.event_id:
return continue
e.internal_metadata.outlier = True e.internal_metadata.outlier = True
try: auth_ids = [e_id for e_id, _ in e.auth_events]
auth_ids = [e_id for e_id, _ in e.auth_events] ev_infos.append({
auth = { "event": e,
"auth_events": {
(e.type, e.state_key): e for e in auth_chain (e.type, e.state_key): e for e in auth_chain
if e.event_id in auth_ids if e.event_id in auth_ids
} }
yield self._handle_new_event( })
origin, e, auth_events=auth
)
except:
logger.exception(
"Failed to handle state event %s",
e.event_id,
)
yield defer.DeferredList([handle_state(e) for e in state]) yield self._handle_new_events(origin, ev_infos, outliers=True)
auth_ids = [e_id for e_id, _ in event.auth_events] auth_ids = [e_id for e_id, _ in event.auth_events]
auth_events = { auth_events = {
@ -940,11 +928,54 @@ class FederationHandler(BaseHandler):
def _handle_new_event(self, origin, event, state=None, backfilled=False, def _handle_new_event(self, origin, event, state=None, backfilled=False,
current_state=None, auth_events=None): current_state=None, auth_events=None):
logger.debug( outlier = event.internal_metadata.is_outlier()
"_handle_new_event: %s, sigs: %s",
event.event_id, event.signatures, context = yield self._prep_event(
origin, event,
state=state,
backfilled=backfilled,
current_state=current_state,
auth_events=auth_events,
) )
event_stream_id, max_stream_id = yield self.store.persist_event(
event,
context=context,
backfilled=backfilled,
is_new_state=(not outlier and not backfilled),
current_state=current_state,
)
defer.returnValue((context, event_stream_id, max_stream_id))
@defer.inlineCallbacks
def _handle_new_events(self, origin, event_infos, backfilled=False,
outliers=False):
contexts = yield defer.gatherResults(
[
self._prep_event(
origin,
ev_info["event"],
state=ev_info.get("state"),
backfilled=backfilled,
auth_events=ev_info.get("auth_events"),
)
for ev_info in event_infos
]
)
yield self.store.persist_events(
[
(ev_info["event"], context)
for ev_info, context in itertools.izip(event_infos, contexts)
],
backfilled=backfilled,
is_new_state=(not outliers and not backfilled),
)
@defer.inlineCallbacks
def _prep_event(self, origin, event, state=None, backfilled=False,
current_state=None, auth_events=None):
outlier = event.internal_metadata.is_outlier() outlier = event.internal_metadata.is_outlier()
context = yield self.state_handler.compute_event_context( context = yield self.state_handler.compute_event_context(
@ -954,13 +985,6 @@ class FederationHandler(BaseHandler):
if not auth_events: if not auth_events:
auth_events = context.current_state auth_events = context.current_state
logger.debug(
"_handle_new_event: %s, auth_events: %s",
event.event_id, auth_events,
)
is_new_state = not outlier
# This is a hack to fix some old rooms where the initial join event # This is a hack to fix some old rooms where the initial join event
# didn't reference the create event in its auth events. # didn't reference the create event in its auth events.
if event.type == EventTypes.Member and not event.auth_events: if event.type == EventTypes.Member and not event.auth_events:
@ -984,26 +1008,7 @@ class FederationHandler(BaseHandler):
context.rejected = RejectedReason.AUTH_ERROR context.rejected = RejectedReason.AUTH_ERROR
# FIXME: Don't store as rejected with AUTH_ERROR if we haven't defer.returnValue(context)
# seen all the auth events.
yield self.store.persist_event(
event,
context=context,
backfilled=backfilled,
is_new_state=False,
current_state=current_state,
)
raise
event_stream_id, max_stream_id = yield self.store.persist_event(
event,
context=context,
backfilled=backfilled,
is_new_state=(is_new_state and not backfilled),
current_state=current_state,
)
defer.returnValue((context, event_stream_id, max_stream_id))
@defer.inlineCallbacks @defer.inlineCallbacks
def on_query_auth(self, origin, event_id, remote_auth_chain, rejects, def on_query_auth(self, origin, event_id, remote_auth_chain, rejects,
@ -1066,14 +1071,24 @@ class FederationHandler(BaseHandler):
@log_function @log_function
def do_auth(self, origin, event, context, auth_events): def do_auth(self, origin, event, context, auth_events):
# Check if we have all the auth events. # Check if we have all the auth events.
have_events = yield self.store.have_events( current_state = set(e.event_id for e in auth_events.values())
[e_id for e_id, _ in event.auth_events]
)
event_auth_events = set(e_id for e_id, _ in event.auth_events) event_auth_events = set(e_id for e_id, _ in event.auth_events)
if event_auth_events - current_state:
have_events = yield self.store.have_events(
event_auth_events - current_state
)
else:
have_events = {}
have_events.update({
e.event_id: ""
for e in auth_events.values()
})
seen_events = set(have_events.keys()) seen_events = set(have_events.keys())
missing_auth = event_auth_events - seen_events missing_auth = event_auth_events - seen_events - current_state
if missing_auth: if missing_auth:
logger.info("Missing auth: %s", missing_auth) logger.info("Missing auth: %s", missing_auth)

View File

@ -282,8 +282,7 @@ class EventFederationStore(SQLBaseStore):
}, },
) )
def _handle_prev_events(self, txn, outlier, event_id, prev_events, def _handle_mult_prev_events(self, txn, events):
room_id):
""" """
For the given event, update the event edges table and forward and For the given event, update the event edges table and forward and
backward extremities tables. backward extremities tables.
@ -293,68 +292,75 @@ class EventFederationStore(SQLBaseStore):
table="event_edges", table="event_edges",
values=[ values=[
{ {
"event_id": event_id, "event_id": ev.event_id,
"prev_event_id": e_id, "prev_event_id": e_id,
"room_id": room_id, "room_id": ev.room_id,
"is_state": False, "is_state": False,
} }
for e_id, _ in prev_events for ev in events
for e_id, _ in ev.prev_events
], ],
) )
# Update the extremities table if this is not an outlier. events_by_room = {}
if not outlier: for ev in events:
for e_id, _ in prev_events: events_by_room.setdefault(ev.room_id, []).append(ev)
# TODO (erikj): This could be done as a bulk insert
self._simple_delete_txn( for room_id, room_events in events_by_room.items():
txn, prevs = [
table="event_forward_extremities", e_id for ev in room_events for e_id, _ in ev.prev_events
keyvalues={ if not ev.internal_metadata.is_outlier()
"event_id": e_id, ]
"room_id": room_id, if prevs:
} txn.execute(
"DELETE FROM event_forward_extremities"
" WHERE room_id = ?"
" AND event_id in (%s)" % (
",".join(["?"] * len(prevs)),
),
[room_id] + prevs,
) )
# We only insert as a forward extremity the new event if there are query = (
# no other events that reference it as a prev event "INSERT INTO event_forward_extremities (event_id, room_id)"
query = ( " SELECT ?, ? WHERE NOT EXISTS ("
"SELECT 1 FROM event_edges WHERE prev_event_id = ?" " SELECT 1 FROM event_edges WHERE prev_event_id = ?"
) " )"
)
txn.execute(query, (event_id,)) txn.executemany(
query,
[(ev.event_id, ev.room_id, ev.event_id) for ev in events]
)
if not txn.fetchone(): query = (
query = ( "INSERT INTO event_backward_extremities (event_id, room_id)"
"INSERT INTO event_forward_extremities" " SELECT ?, ? WHERE NOT EXISTS ("
" (event_id, room_id)" " SELECT 1 FROM event_backward_extremities"
" VALUES (?, ?)" " WHERE event_id = ? AND room_id = ?"
) " )"
" AND NOT EXISTS ("
" SELECT 1 FROM events WHERE event_id = ? AND room_id = ? "
" AND outlier = ?"
" )"
)
txn.execute(query, (event_id, room_id)) txn.executemany(query, [
(e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
for ev in events for e_id, _ in ev.prev_events
if not ev.internal_metadata.is_outlier()
])
query = ( query = (
"INSERT INTO event_backward_extremities (event_id, room_id)" "DELETE FROM event_backward_extremities"
" SELECT ?, ? WHERE NOT EXISTS (" " WHERE event_id = ? AND room_id = ?"
" SELECT 1 FROM event_backward_extremities" )
" WHERE event_id = ? AND room_id = ?" txn.executemany(
" )" query,
" AND NOT EXISTS (" [(ev.event_id, ev.room_id) for ev in events]
" SELECT 1 FROM events WHERE event_id = ? AND room_id = ? " )
" AND outlier = ?"
" )"
)
txn.executemany(query, [
(e_id, room_id, e_id, room_id, e_id, room_id, False)
for e_id, _ in prev_events
])
query = (
"DELETE FROM event_backward_extremities"
" WHERE event_id = ? AND room_id = ?"
)
txn.execute(query, (event_id, room_id))
for room_id in events_by_room:
txn.call_after( txn.call_after(
self.get_latest_event_ids_in_room.invalidate, room_id self.get_latest_event_ids_in_room.invalidate, room_id
) )

View File

@ -23,9 +23,7 @@ from synapse.events.utils import prune_event
from synapse.util.logcontext import preserve_context_over_deferred from synapse.util.logcontext import preserve_context_over_deferred
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.crypto.event_signing import compute_event_reference_hash
from syutil.base64util import decode_base64
from syutil.jsonutil import encode_json from syutil.jsonutil import encode_json
from contextlib import contextmanager from contextlib import contextmanager
@ -46,6 +44,48 @@ EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events
class EventsStore(SQLBaseStore): class EventsStore(SQLBaseStore):
@defer.inlineCallbacks
def persist_events(self, events_and_contexts, backfilled=False,
is_new_state=True):
if not events_and_contexts:
return
if backfilled:
if not self.min_token_deferred.called:
yield self.min_token_deferred
start = self.min_token - 1
self.min_token -= len(events_and_contexts) + 1
stream_orderings = range(start, self.min_token, -1)
@contextmanager
def stream_ordering_manager():
yield stream_orderings
stream_ordering_manager = stream_ordering_manager()
else:
stream_ordering_manager = yield self._stream_id_gen.get_next_mult(
self, len(events_and_contexts)
)
with stream_ordering_manager as stream_orderings:
for (event, _), stream in zip(events_and_contexts, stream_orderings):
event.internal_metadata.stream_ordering = stream
chunks = [
events_and_contexts[x:x+100]
for x in xrange(0, len(events_and_contexts), 100)
]
for chunk in chunks:
# We can't easily parallelize these since different chunks
# might contain the same event. :(
yield self.runInteraction(
"persist_events",
self._persist_events_txn,
events_and_contexts=chunk,
backfilled=backfilled,
is_new_state=is_new_state,
)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def persist_event(self, event, context, backfilled=False, def persist_event(self, event, context, backfilled=False,
@ -67,13 +107,13 @@ class EventsStore(SQLBaseStore):
try: try:
with stream_ordering_manager as stream_ordering: with stream_ordering_manager as stream_ordering:
event.internal_metadata.stream_ordering = stream_ordering
yield self.runInteraction( yield self.runInteraction(
"persist_event", "persist_event",
self._persist_event_txn, self._persist_event_txn,
event=event, event=event,
context=context, context=context,
backfilled=backfilled, backfilled=backfilled,
stream_ordering=stream_ordering,
is_new_state=is_new_state, is_new_state=is_new_state,
current_state=current_state, current_state=current_state,
) )
@ -116,12 +156,7 @@ class EventsStore(SQLBaseStore):
@log_function @log_function
def _persist_event_txn(self, txn, event, context, backfilled, def _persist_event_txn(self, txn, event, context, backfilled,
stream_ordering=None, is_new_state=True, is_new_state=True, current_state=None):
current_state=None):
# Remove the any existing cache entries for the event_id
txn.call_after(self._invalidate_get_event_cache, event.event_id)
# We purposefully do this first since if we include a `current_state` # We purposefully do this first since if we include a `current_state`
# key, we *want* to update the `current_state_events` table # key, we *want* to update the `current_state_events` table
if current_state: if current_state:
@ -149,37 +184,78 @@ class EventsStore(SQLBaseStore):
} }
) )
outlier = event.internal_metadata.is_outlier() return self._persist_events_txn(
if not outlier:
self._update_min_depth_for_room_txn(
txn,
event.room_id,
event.depth
)
have_persisted = self._simple_select_one_txn(
txn, txn,
table="events", [(event, context)],
keyvalues={"event_id": event.event_id}, backfilled=backfilled,
retcols=["event_id", "outlier"], is_new_state=is_new_state,
allow_none=True,
) )
metadata_json = encode_json( @log_function
event.internal_metadata.get_dict(), def _persist_events_txn(self, txn, events_and_contexts, backfilled,
using_frozen_dicts=USE_FROZEN_DICTS is_new_state=True):
).decode("UTF-8")
# If we have already persisted this event, we don't need to do any # Remove the any existing cache entries for the event_ids
# more processing. for event, _ in events_and_contexts:
# The processing above must be done on every call to persist event, txn.call_after(self._invalidate_get_event_cache, event.event_id)
# since they might not have happened on previous calls. For example,
# if we are persisting an event that we had persisted as an outlier, depth_updates = {}
# but is no longer one. for event, _ in events_and_contexts:
if have_persisted: if event.internal_metadata.is_outlier():
if not outlier and have_persisted["outlier"]: continue
self._store_state_groups_txn(txn, event, context) depth_updates[event.room_id] = max(
event.depth, depth_updates.get(event.room_id, event.depth)
)
for room_id, depth in depth_updates.items():
self._update_min_depth_for_room_txn(txn, room_id, depth)
txn.execute(
"SELECT event_id, outlier FROM events WHERE event_id in (%s)" % (
",".join(["?"] * len(events_and_contexts)),
),
[event.event_id for event, _ in events_and_contexts]
)
have_persisted = {
event_id: outlier
for event_id, outlier in txn.fetchall()
}
event_map = {}
to_remove = set()
for event, context in events_and_contexts:
# Handle the case of the list including the same event multiple
# times. The tricky thing here is when they differ by whether
# they are an outlier.
if event.event_id in event_map:
other = event_map[event.event_id]
if not other.internal_metadata.is_outlier():
to_remove.add(event)
continue
elif not event.internal_metadata.is_outlier():
to_remove.add(event)
continue
else:
to_remove.add(other)
event_map[event.event_id] = event
if event.event_id not in have_persisted:
continue
to_remove.add(event)
outlier_persisted = have_persisted[event.event_id]
if not event.internal_metadata.is_outlier() and outlier_persisted:
self._store_state_groups_txn(
txn, event, context,
)
metadata_json = encode_json(
event.internal_metadata.get_dict(),
using_frozen_dicts=USE_FROZEN_DICTS
).decode("UTF-8")
sql = ( sql = (
"UPDATE event_json SET internal_metadata = ?" "UPDATE event_json SET internal_metadata = ?"
@ -198,94 +274,91 @@ class EventsStore(SQLBaseStore):
sql, sql,
(False, event.event_id,) (False, event.event_id,)
) )
events_and_contexts = filter(
lambda ec: ec[0] not in to_remove,
events_and_contexts
)
if not events_and_contexts:
return return
if not outlier: self._store_mult_state_groups_txn(txn, [
self._store_state_groups_txn(txn, event, context) (event, context)
for event, context in events_and_contexts
if not event.internal_metadata.is_outlier()
])
self._handle_prev_events( self._handle_mult_prev_events(
txn, txn,
outlier=outlier, events=[event for event, _ in events_and_contexts],
event_id=event.event_id,
prev_events=event.prev_events,
room_id=event.room_id,
) )
if event.type == EventTypes.Member: for event, _ in events_and_contexts:
self._store_room_member_txn(txn, event) if event.type == EventTypes.Name:
elif event.type == EventTypes.Name: self._store_room_name_txn(txn, event)
self._store_room_name_txn(txn, event) elif event.type == EventTypes.Topic:
elif event.type == EventTypes.Topic: self._store_room_topic_txn(txn, event)
self._store_room_topic_txn(txn, event) elif event.type == EventTypes.Redaction:
elif event.type == EventTypes.Redaction: self._store_redaction(txn, event)
self._store_redaction(txn, event)
event_dict = { self._store_room_members_txn(
k: v txn,
for k, v in event.get_dict().items() [
if k not in [ event
"redacted", for event, _ in events_and_contexts
"redacted_because", if event.type == EventTypes.Member
] ]
} )
self._simple_insert_txn( def event_dict(event):
return {
k: v
for k, v in event.get_dict().items()
if k not in [
"redacted",
"redacted_because",
]
}
self._simple_insert_many_txn(
txn, txn,
table="event_json", table="event_json",
values={ values=[
"event_id": event.event_id, {
"room_id": event.room_id, "event_id": event.event_id,
"internal_metadata": metadata_json, "room_id": event.room_id,
"json": encode_json( "internal_metadata": encode_json(
event_dict, using_frozen_dicts=USE_FROZEN_DICTS event.internal_metadata.get_dict(),
).decode("UTF-8"), using_frozen_dicts=USE_FROZEN_DICTS
}, ).decode("UTF-8"),
"json": encode_json(
event_dict(event), using_frozen_dicts=USE_FROZEN_DICTS
).decode("UTF-8"),
}
for event, _ in events_and_contexts
],
) )
content = encode_json( self._simple_insert_many_txn(
event.content, using_frozen_dicts=USE_FROZEN_DICTS txn,
).decode("UTF-8") table="events",
values=[
vals = { {
"topological_ordering": event.depth, "stream_ordering": event.internal_metadata.stream_ordering,
"event_id": event.event_id, "topological_ordering": event.depth,
"type": event.type, "depth": event.depth,
"room_id": event.room_id, "event_id": event.event_id,
"content": content, "room_id": event.room_id,
"processed": True, "type": event.type,
"outlier": outlier, "processed": True,
"depth": event.depth, "outlier": event.internal_metadata.is_outlier(),
} "content": encode_json(
event.content, using_frozen_dicts=USE_FROZEN_DICTS
unrec = { ).decode("UTF-8"),
k: v }
for k, v in event.get_dict().items() for event, _ in events_and_contexts
if k not in vals.keys() and k not in [ ],
"redacted",
"redacted_because",
"signatures",
"hashes",
"prev_events",
]
}
vals["unrecognized_keys"] = encode_json(
unrec, using_frozen_dicts=USE_FROZEN_DICTS
).decode("UTF-8")
sql = (
"INSERT INTO events"
" (stream_ordering, topological_ordering, event_id, type,"
" room_id, content, processed, outlier, depth)"
" VALUES (?,?,?,?,?,?,?,?,?)"
)
txn.execute(
sql,
(
stream_ordering, event.depth, event.event_id, event.type,
event.room_id, content, True, outlier, event.depth
)
) )
if context.rejected: if context.rejected:
@ -293,19 +366,19 @@ class EventsStore(SQLBaseStore):
txn, event.event_id, context.rejected txn, event.event_id, context.rejected
) )
for hash_alg, hash_base64 in event.hashes.items(): # for hash_alg, hash_base64 in event.hashes.items():
hash_bytes = decode_base64(hash_base64) # hash_bytes = decode_base64(hash_base64)
self._store_event_content_hash_txn( # self._store_event_content_hash_txn(
txn, event.event_id, hash_alg, hash_bytes, # txn, event.event_id, hash_alg, hash_bytes,
) # )
for prev_event_id, prev_hashes in event.prev_events: # for prev_event_id, prev_hashes in event.prev_events:
for alg, hash_base64 in prev_hashes.items(): # for alg, hash_base64 in prev_hashes.items():
hash_bytes = decode_base64(hash_base64) # hash_bytes = decode_base64(hash_base64)
self._store_prev_event_hash_txn( # self._store_prev_event_hash_txn(
txn, event.event_id, prev_event_id, alg, # txn, event.event_id, prev_event_id, alg,
hash_bytes # hash_bytes
) # )
self._simple_insert_many_txn( self._simple_insert_many_txn(
txn, txn,
@ -316,16 +389,22 @@ class EventsStore(SQLBaseStore):
"room_id": event.room_id, "room_id": event.room_id,
"auth_id": auth_id, "auth_id": auth_id,
} }
for event, _ in events_and_contexts
for auth_id, _ in event.auth_events for auth_id, _ in event.auth_events
], ],
) )
(ref_alg, ref_hash_bytes) = compute_event_reference_hash(event) self._store_event_reference_hashes_txn(
self._store_event_reference_hash_txn( txn, [event for event, _ in events_and_contexts]
txn, event.event_id, ref_alg, ref_hash_bytes
) )
if event.is_state(): state_events_and_contexts = filter(
lambda i: i[0].is_state(),
events_and_contexts,
)
state_values = []
for event, context in state_events_and_contexts:
vals = { vals = {
"event_id": event.event_id, "event_id": event.event_id,
"room_id": event.room_id, "room_id": event.room_id,
@ -337,51 +416,55 @@ class EventsStore(SQLBaseStore):
if hasattr(event, "replaces_state"): if hasattr(event, "replaces_state"):
vals["prev_state"] = event.replaces_state vals["prev_state"] = event.replaces_state
self._simple_insert_txn( state_values.append(vals)
txn,
"state_events",
vals,
)
self._simple_insert_many_txn( self._simple_insert_many_txn(
txn, txn,
table="event_edges", table="state_events",
values=[ values=state_values,
{ )
"event_id": event.event_id,
"prev_event_id": e_id,
"room_id": event.room_id,
"is_state": True,
}
for e_id, h in event.prev_state
],
)
if is_new_state and not context.rejected: self._simple_insert_many_txn(
txn.call_after( txn,
self.get_current_state_for_key.invalidate, table="event_edges",
event.room_id, event.type, event.state_key values=[
) {
"event_id": event.event_id,
"prev_event_id": prev_id,
"room_id": event.room_id,
"is_state": True,
}
for event, _ in state_events_and_contexts
for prev_id, _ in event.prev_state
],
)
if (event.type == EventTypes.Name if is_new_state:
or event.type == EventTypes.Aliases): for event, _ in state_events_and_contexts:
if not context.rejected:
txn.call_after( txn.call_after(
self.get_room_name_and_aliases.invalidate, self.get_current_state_for_key.invalidate,
event.room_id event.room_id, event.type, event.state_key
) )
self._simple_upsert_txn( if event.type in [EventTypes.Name, EventTypes.Aliases]:
txn, txn.call_after(
"current_state_events", self.get_room_name_and_aliases.invalidate,
keyvalues={ event.room_id
"room_id": event.room_id, )
"type": event.type,
"state_key": event.state_key, self._simple_upsert_txn(
}, txn,
values={ "current_state_events",
"event_id": event.event_id, keyvalues={
} "room_id": event.room_id,
) "type": event.type,
"state_key": event.state_key,
},
values={
"event_id": event.event_id,
}
)
return return

View File

@ -35,38 +35,28 @@ RoomsForUser = namedtuple(
class RoomMemberStore(SQLBaseStore): class RoomMemberStore(SQLBaseStore):
def _store_room_member_txn(self, txn, event): def _store_room_members_txn(self, txn, events):
"""Store a room member in the database. """Store a room member in the database.
""" """
try: self._simple_insert_many_txn(
target_user_id = event.state_key
except:
logger.exception(
"Failed to parse target_user_id=%s", target_user_id
)
raise
logger.debug(
"_store_room_member_txn: target_user_id=%s, membership=%s",
target_user_id,
event.membership,
)
self._simple_insert_txn(
txn, txn,
"room_memberships", table="room_memberships",
{ values=[
"event_id": event.event_id, {
"user_id": target_user_id, "event_id": event.event_id,
"sender": event.user_id, "user_id": event.state_key,
"room_id": event.room_id, "sender": event.user_id,
"membership": event.membership, "room_id": event.room_id,
} "membership": event.membership,
}
for event in events
]
) )
txn.call_after(self.get_rooms_for_user.invalidate, target_user_id) for event in events:
txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id) txn.call_after(self.get_rooms_for_user.invalidate, event.state_key)
txn.call_after(self.get_users_in_room.invalidate, event.room_id) txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id)
txn.call_after(self.get_users_in_room.invalidate, event.room_id)
def get_room_member(self, user_id, room_id): def get_room_member(self, user_id, room_id):
"""Retrieve the current state of a room member. """Retrieve the current state of a room member.

View File

@ -18,6 +18,7 @@ from twisted.internet import defer
from _base import SQLBaseStore from _base import SQLBaseStore
from syutil.base64util import encode_base64 from syutil.base64util import encode_base64
from synapse.crypto.event_signing import compute_event_reference_hash
class SignatureStore(SQLBaseStore): class SignatureStore(SQLBaseStore):
@ -101,23 +102,26 @@ class SignatureStore(SQLBaseStore):
txn.execute(query, (event_id, )) txn.execute(query, (event_id, ))
return {k: v for k, v in txn.fetchall()} return {k: v for k, v in txn.fetchall()}
def _store_event_reference_hash_txn(self, txn, event_id, algorithm, def _store_event_reference_hashes_txn(self, txn, events):
hash_bytes):
"""Store a hash for a PDU """Store a hash for a PDU
Args: Args:
txn (cursor): txn (cursor):
event_id (str): Id for the Event. events (list): list of Events.
algorithm (str): Hashing algorithm.
hash_bytes (bytes): Hash function output bytes.
""" """
self._simple_insert_txn(
vals = []
for event in events:
ref_alg, ref_hash_bytes = compute_event_reference_hash(event)
vals.append({
"event_id": event.event_id,
"algorithm": ref_alg,
"hash": buffer(ref_hash_bytes),
})
self._simple_insert_many_txn(
txn, txn,
"event_reference_hashes", table="event_reference_hashes",
{ values=vals,
"event_id": event_id,
"algorithm": algorithm,
"hash": buffer(hash_bytes),
},
) )
def _get_event_signatures_txn(self, txn, event_id): def _get_event_signatures_txn(self, txn, event_id):

View File

@ -100,16 +100,23 @@ class StateStore(SQLBaseStore):
) )
def _store_state_groups_txn(self, txn, event, context): def _store_state_groups_txn(self, txn, event, context):
if context.current_state is None: return self._store_mult_state_groups_txn(txn, [(event, context)])
return
state_events = dict(context.current_state) def _store_mult_state_groups_txn(self, txn, events_and_contexts):
state_groups = {}
for event, context in events_and_contexts:
if context.current_state is None:
continue
if event.is_state(): if context.state_group is not None:
state_events[(event.type, event.state_key)] = event state_groups[event.event_id] = context.state_group
continue
state_events = dict(context.current_state)
if event.is_state():
state_events[(event.type, event.state_key)] = event
state_group = context.state_group
if not state_group:
state_group = self._state_groups_id_gen.get_next_txn(txn) state_group = self._state_groups_id_gen.get_next_txn(txn)
self._simple_insert_txn( self._simple_insert_txn(
txn, txn,
@ -135,14 +142,19 @@ class StateStore(SQLBaseStore):
for state in state_events.values() for state in state_events.values()
], ],
) )
state_groups[event.event_id] = state_group
self._simple_insert_txn( self._simple_insert_many_txn(
txn, txn,
table="event_to_state_groups", table="event_to_state_groups",
values={ values=[
"state_group": state_group, {
"event_id": event.event_id, "state_group": state_groups[event.event_id],
}, "event_id": event.event_id,
}
for event, context in events_and_contexts
if context.current_state is not None
],
) )
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -107,6 +107,37 @@ class StreamIdGenerator(object):
defer.returnValue(manager()) defer.returnValue(manager())
@defer.inlineCallbacks
def get_next_mult(self, store, n):
"""
Usage:
with yield stream_id_gen.get_next(store, n) as stream_ids:
# ... persist events ...
"""
if not self._current_max:
yield store.runInteraction(
"_compute_current_max",
self._get_or_compute_current_max,
)
with self._lock:
next_ids = range(self._current_max + 1, self._current_max + n + 1)
self._current_max += n
for next_id in next_ids:
self._unfinished_ids.append(next_id)
@contextlib.contextmanager
def manager():
try:
yield next_ids
finally:
with self._lock:
for next_id in next_ids:
self._unfinished_ids.remove(next_id)
defer.returnValue(manager())
@defer.inlineCallbacks @defer.inlineCallbacks
def get_max_token(self, store): def get_max_token(self, store):
"""Returns the maximum stream id such that all stream ids less than or """Returns the maximum stream id such that all stream ids less than or