diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index a8f8989e3..c20ff3a57 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -299,6 +299,10 @@ class SQLBaseStore(object): self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True, max_entries=hs.config.event_cache_size) + self._event_fetch_lock = threading.Lock() + self._event_fetch_list = [] + self._event_fetch_ongoing = False + self.database_engine = hs.database_engine self._stream_id_gen = StreamIdGenerator() diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 0aa4e0d44..be88328ce 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -15,7 +15,7 @@ from _base import SQLBaseStore, _RollbackButIsFineException -from twisted.internet import defer +from twisted.internet import defer, reactor from synapse.events import FrozenEvent from synapse.events.utils import prune_event @@ -89,18 +89,17 @@ class EventsStore(SQLBaseStore): Returns: Deferred : A FrozenEvent. """ - event = yield self.runInteraction( - "get_event", self._get_event_txn, - event_id, - check_redacted=check_redacted, - get_prev_content=get_prev_content, - allow_rejected=allow_rejected, + events = yield self._get_events( + [event_id], + check_redacted=True, + get_prev_content=False, + allow_rejected=False, ) - if not event and not allow_none: + if not events and not allow_none: raise RuntimeError("Could not find event %s" % (event_id,)) - defer.returnValue(event) + defer.returnValue(events[0] if events else None) @log_function def _persist_event_txn(self, txn, event, context, backfilled, @@ -420,13 +419,21 @@ class EventsStore(SQLBaseStore): if e_id in event_map and event_map[e_id] ]) - missing_events = yield self._fetch_events( - txn, - missing_events_ids, - check_redacted=check_redacted, - get_prev_content=get_prev_content, - allow_rejected=allow_rejected, - ) + if not txn: + missing_events = yield self._enqueue_events( + missing_events_ids, + check_redacted=check_redacted, + get_prev_content=get_prev_content, + allow_rejected=allow_rejected, + ) + else: + missing_events = self._fetch_events_txn( + txn, + missing_events_ids, + check_redacted=check_redacted, + get_prev_content=get_prev_content, + allow_rejected=allow_rejected, + ) event_map.update(missing_events) @@ -492,11 +499,82 @@ class EventsStore(SQLBaseStore): )) @defer.inlineCallbacks - def _fetch_events(self, txn, events, check_redacted=True, - get_prev_content=False, allow_rejected=False): + def _enqueue_events(self, events, check_redacted=True, + get_prev_content=False, allow_rejected=False): if not events: defer.returnValue({}) + def do_fetch(txn): + event_list = [] + try: + with self._event_fetch_lock: + event_list = self._event_fetch_list + self._event_fetch_list = [] + + if not event_list: + return + + event_id_lists = zip(*event_list)[0] + event_ids = [ + item for sublist in event_id_lists for item in sublist + ] + rows = self._fetch_event_rows(txn, event_ids) + + row_dict = { + r["event_id"]: r + for r in rows + } + + for ids, d in event_list: + d.callback( + [ + row_dict[i] for i in ids + if i in row_dict + ] + ) + except Exception as e: + for _, d in event_list: + try: + reactor.callFromThread(d.errback, e) + except: + pass + finally: + with self._event_fetch_lock: + self._event_fetch_ongoing = False + + def cb(rows): + return defer.gatherResults([ + self._get_event_from_row( + None, + row["internal_metadata"], row["json"], row["redacts"], + check_redacted=check_redacted, + get_prev_content=get_prev_content, + rejected_reason=row["rejects"], + ) + for row in rows + ]) + + d = defer.Deferred() + d.addCallback(cb) + with self._event_fetch_lock: + self._event_fetch_list.append( + (events, d) + ) + + if not self._event_fetch_ongoing: + self.runInteraction( + "do_fetch", + do_fetch + ) + + res = yield d + + defer.returnValue({ + e.event_id: e + for e in res if e + }) + + def _fetch_event_rows(self, txn, events): rows = [] N = 200 for i in range(1 + len(events) / N): @@ -505,43 +583,56 @@ class EventsStore(SQLBaseStore): break sql = ( - "SELECT e.internal_metadata, e.json, r.redacts, rej.event_id " + "SELECT " + " e.event_id as event_id, " + " e.internal_metadata," + " e.json," + " r.redacts as redacts," + " rej.event_id as rejects " " FROM event_json as e" " LEFT JOIN rejections as rej USING (event_id)" " LEFT JOIN redactions as r ON e.event_id = r.redacts" " WHERE e.event_id IN (%s)" ) % (",".join(["?"]*len(evs)),) - if txn: - txn.execute(sql, evs) - rows.extend(txn.fetchall()) - else: - res = yield self._execute("_fetch_events", None, sql, *evs) - rows.extend(res) + txn.execute(sql, evs) + rows.extend(self.cursor_to_dict(txn)) + + return rows + + @defer.inlineCallbacks + def _fetch_events(self, txn, events, check_redacted=True, + get_prev_content=False, allow_rejected=False): + if not events: + defer.returnValue({}) + + if txn: + rows = self._fetch_event_rows( + txn, events, + ) + else: + rows = yield self.runInteraction( + self._fetch_event_rows, + events, + ) res = yield defer.gatherResults( [ defer.maybeDeferred( self._get_event_from_row, txn, - row[0], row[1], row[2], + row["internal_metadata"], row["json"], row["redacts"], check_redacted=check_redacted, get_prev_content=get_prev_content, - rejected_reason=row[3], + rejected_reason=row["rejects"], ) for row in rows - ], - consumeErrors=True, + ] ) - for e in res: - self._get_event_cache.prefill( - e.event_id, check_redacted, get_prev_content, e - ) - defer.returnValue({ - e.event_id: e - for e in res if e + r.event_id: r + for r in res }) @defer.inlineCallbacks @@ -611,6 +702,10 @@ class EventsStore(SQLBaseStore): if prev: ev.unsigned["prev_content"] = prev.get_dict()["content"] + self._get_event_cache.prefill( + ev.event_id, check_redacted, get_prev_content, ev + ) + defer.returnValue(ev) def _parse_events(self, rows): diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index b9afb3364..260714ccc 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -80,16 +80,16 @@ class Clock(object): def stop_looping_call(self, loop): loop.stop() - def call_later(self, delay, callback): + def call_later(self, delay, callback, *args, **kwargs): current_context = LoggingContext.current_context() - def wrapped_callback(): + def wrapped_callback(*args, **kwargs): with PreserveLoggingContext(): LoggingContext.thread_local.current_context = current_context - callback() + callback(*args, **kwargs) with PreserveLoggingContext(): - return reactor.callLater(delay, wrapped_callback) + return reactor.callLater(delay, wrapped_callback, *args, **kwargs) def cancel_call_later(self, timer): timer.cancel()