mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
6cfee09be9
Add db_conn parameters to the `__init__` methods of the *Store classes, so that they are all consistent, which makes the multiple inheritance work correctly (and so that we can later extract mixins which can be used in the slavedstores)
609 lines
22 KiB
Python
609 lines
22 KiB
Python
# -*- coding: utf-8 -*-
|
|
# Copyright 2015, 2016 OpenMarket Ltd
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from twisted.internet import defer
|
|
|
|
from .background_updates import BackgroundUpdateStore
|
|
from synapse.api.errors import SynapseError
|
|
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
|
|
|
|
import logging
|
|
import re
|
|
import ujson as json
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class SearchStore(BackgroundUpdateStore):
    """Storage for full text search over room events.

    Populates and maintains the ``event_search`` table via background updates
    and answers search queries against it. On Postgres matching is done with
    ``to_tsvector``/``to_tsquery``; on SQLite with ``MATCH`` on
    ``event_search``.
    """

    EVENT_SEARCH_UPDATE_NAME = "event_search"
    EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
    EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist"

    def __init__(self, db_conn, hs):
        super(SearchStore, self).__init__(db_conn, hs)
        self.register_background_update_handler(
            self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search
        )
        self.register_background_update_handler(
            self.EVENT_SEARCH_ORDER_UPDATE_NAME,
            self._background_reindex_search_order
        )
        self.register_background_update_handler(
            self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME,
            self._background_reindex_gist_search
        )

    @defer.inlineCallbacks
    def _background_reindex_search(self, progress, batch_size):
        """Background update which fills ``event_search`` from ``events``.

        Each call examines up to ``batch_size`` name/topic/message events,
        walking downwards from ``max_stream_id_exclusive`` towards
        ``target_min_stream_id_inclusive``. It inserts one search row for each
        event whose indexed field (body, topic or name) is a string.

        Args:
            progress (dict): persisted progress of this update. Reads
                ``target_min_stream_id_inclusive``, ``max_stream_id_exclusive``
                and, optionally, ``rows_inserted``.
            batch_size (int): maximum number of events to examine.

        Returns:
            Deferred[int]: the number of search rows inserted by this call.
            The update is ended once no events remain in the range.
        """
        target_min_stream_id = progress["target_min_stream_id_inclusive"]
        max_stream_id = progress["max_stream_id_exclusive"]
        rows_inserted = progress.get("rows_inserted", 0)

        INSERT_CLUMP_SIZE = 1000
        TYPES = ["m.room.name", "m.room.message", "m.room.topic"]

        def reindex_search_txn(txn):
            """Index one batch of events.

            Returns:
                tuple[int, bool]: the number of search rows inserted, and
                whether the update is finished (no events were left).
            """
            sql = (
                "SELECT stream_ordering, event_id, room_id, type, content FROM events"
                " WHERE ? <= stream_ordering AND stream_ordering < ?"
                " AND (%s)"
                " ORDER BY stream_ordering DESC"
                " LIMIT ?"
            ) % (" OR ".join("type = '%s'" % (t,) for t in TYPES),)

            txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))

            rows = self.cursor_to_dict(txn)
            if not rows:
                # Nothing left in the range: the update is complete.
                return 0, True

            # Rows come back in descending stream order, so this is the lowest
            # stream ordering examined. The next batch resumes below it.
            min_stream_id = rows[-1]["stream_ordering"]

            event_search_rows = []
            for row in rows:
                try:
                    event_id = row["event_id"]
                    room_id = row["room_id"]
                    etype = row["type"]
                    try:
                        content = json.loads(row["content"])
                    except Exception:
                        continue

                    if etype == "m.room.message":
                        key = "content.body"
                        value = content["body"]
                    elif etype == "m.room.topic":
                        key = "content.topic"
                        value = content["topic"]
                    elif etype == "m.room.name":
                        key = "content.name"
                        value = content["name"]
                    else:
                        # The query above only selects indexed types. Skip
                        # anything else rather than reuse the previous row's
                        # key/value.
                        continue
                except (KeyError, AttributeError, TypeError):
                    # If the event is missing a necessary field, or its content
                    # is not a JSON object (e.g. a list or null), skip over it.
                    continue

                if not isinstance(value, basestring):
                    # If the event body, name or topic isn't a string
                    # then skip over it
                    continue

                event_search_rows.append((event_id, room_id, key, value))

            if isinstance(self.database_engine, PostgresEngine):
                sql = (
                    "INSERT INTO event_search (event_id, room_id, key, vector)"
                    " VALUES (?,?,?,to_tsvector('english', ?))"
                )
            elif isinstance(self.database_engine, Sqlite3Engine):
                sql = (
                    "INSERT INTO event_search (event_id, room_id, key, value)"
                    " VALUES (?,?,?,?)"
                )
            else:
                # This should be unreachable.
                raise Exception("Unrecognized database engine")

            for index in range(0, len(event_search_rows), INSERT_CLUMP_SIZE):
                clump = event_search_rows[index:index + INSERT_CLUMP_SIZE]
                txn.executemany(sql, clump)

            progress = {
                "target_min_stream_id_inclusive": target_min_stream_id,
                "max_stream_id_exclusive": min_stream_id,
                "rows_inserted": rows_inserted + len(event_search_rows)
            }

            self._background_update_progress_txn(
                txn, self.EVENT_SEARCH_UPDATE_NAME, progress
            )

            # Not finished even if nothing in this batch was indexable: there
            # may still be events below min_stream_id.
            return len(event_search_rows), False

        num_inserted, finished = yield self.runInteraction(
            self.EVENT_SEARCH_UPDATE_NAME, reindex_search_txn
        )

        if finished:
            yield self._end_background_update(self.EVENT_SEARCH_UPDATE_NAME)

        defer.returnValue(num_inserted)

    @defer.inlineCallbacks
    def _background_reindex_gist_search(self, progress, batch_size):
        """Background update which adds a GIST index on ``event_search.vector``.

        On Postgres it creates ``event_search_fts_idx_gist`` and then drops
        ``event_search_fts_idx``. On other engines it does nothing. In both
        cases the update is then marked as complete.

        Returns:
            Deferred[int]: always 1.
        """
        def create_index(conn):
            # psycopg2 does not allow set_session inside a transaction, and
            # CREATE INDEX CONCURRENTLY cannot run inside a transaction block,
            # so end any open transaction and switch to autocommit.
            conn.rollback()
            conn.set_session(autocommit=True)
            try:
                c = conn.cursor()

                c.execute(
                    "CREATE INDEX CONCURRENTLY event_search_fts_idx_gist"
                    " ON event_search USING GIST (vector)"
                )

                c.execute("DROP INDEX event_search_fts_idx")
            finally:
                # Never hand the connection back in autocommit mode, even if
                # index creation failed.
                conn.set_session(autocommit=False)

        if isinstance(self.database_engine, PostgresEngine):
            yield self.runWithConnection(create_index)

        yield self._end_background_update(self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME)
        defer.returnValue(1)

    @defer.inlineCallbacks
    def _background_reindex_search_order(self, progress, batch_size):
        """Background update which copies ordering columns into ``event_search``.

        On its first run (``have_added_indexes`` false) it creates the
        ``event_search_room_order`` and ``event_search_order`` indexes. It then
        copies ``stream_ordering`` and ``origin_server_ts`` from ``events``
        onto matching ``event_search`` rows. Each call handles a window of
        ``batch_size`` stream positions, working downwards from
        ``max_stream_id_exclusive`` until it passes
        ``target_min_stream_id_inclusive``.

        Args:
            progress (dict): persisted progress of this update. Reads
                ``target_min_stream_id_inclusive``, ``max_stream_id_exclusive``,
                ``have_added_indexes`` and, optionally, ``rows_inserted``.
            batch_size (int): width of the stream-ordering window to process.

        Returns:
            Deferred[int]: the number of ``event_search`` rows updated.
        """
        target_min_stream_id = progress["target_min_stream_id_inclusive"]
        max_stream_id = progress["max_stream_id_exclusive"]
        rows_inserted = progress.get("rows_inserted", 0)
        have_added_index = progress['have_added_indexes']

        if not have_added_index:
            def create_index(conn):
                # See _background_reindex_gist_search: CONCURRENTLY needs
                # autocommit, which must be restored even on failure.
                conn.rollback()
                conn.set_session(autocommit=True)
                try:
                    c = conn.cursor()

                    # We create with NULLS FIRST so that when we search *backwards*
                    # we get the ones with non null origin_server_ts *first*
                    c.execute(
                        "CREATE INDEX CONCURRENTLY event_search_room_order ON event_search("
                        "room_id, origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST)"
                    )
                    c.execute(
                        "CREATE INDEX CONCURRENTLY event_search_order ON event_search("
                        "origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST)"
                    )
                finally:
                    conn.set_session(autocommit=False)

            yield self.runWithConnection(create_index)

            pg = dict(progress)
            pg["have_added_indexes"] = True

            yield self.runInteraction(
                self.EVENT_SEARCH_ORDER_UPDATE_NAME,
                self._background_update_progress_txn,
                self.EVENT_SEARCH_ORDER_UPDATE_NAME, pg,
            )

        def reindex_search_txn(txn):
            """Update one window of rows.

            Returns:
                tuple[int, bool]: the number of rows updated, and whether more
                windows remain to be processed.
            """
            sql = (
                "UPDATE event_search AS es SET stream_ordering = e.stream_ordering,"
                " origin_server_ts = e.origin_server_ts"
                " FROM events AS e"
                " WHERE e.event_id = es.event_id"
                " AND ? <= e.stream_ordering AND e.stream_ordering < ?"
                " RETURNING es.stream_ordering"
            )

            min_stream_id = max_stream_id - batch_size
            txn.execute(sql, (min_stream_id, max_stream_id))
            rows = txn.fetchall()

            if min_stream_id < target_min_stream_id:
                # We've reached the end.
                return len(rows), False

            progress = {
                "target_min_stream_id_inclusive": target_min_stream_id,
                "max_stream_id_exclusive": min_stream_id,
                "rows_inserted": rows_inserted + len(rows),
                "have_added_indexes": True,
            }

            self._background_update_progress_txn(
                txn, self.EVENT_SEARCH_ORDER_UPDATE_NAME, progress
            )

            return len(rows), True

        num_rows, more_to_do = yield self.runInteraction(
            self.EVENT_SEARCH_ORDER_UPDATE_NAME, reindex_search_txn
        )

        if not more_to_do:
            yield self._end_background_update(self.EVENT_SEARCH_ORDER_UPDATE_NAME)

        defer.returnValue(num_rows)

    def _empty_search_results(self):
        """Return the result of a search that cannot match anything.

        Has the same shape as the dicts returned by search_msgs and
        search_rooms. Highlights are only computed on Postgres, so they are an
        empty set there and None elsewhere.
        """
        if isinstance(self.database_engine, PostgresEngine):
            highlights = set()
        else:
            highlights = None
        return {"results": [], "highlights": highlights, "count": 0}

    @defer.inlineCallbacks
    def search_msgs(self, room_ids, search_term, keys):
        """Performs a full text search over events with given keys.

        Args:
            room_ids (list): List of room ids to search in
            search_term (str): Search term to search for
            keys (list): List of keys to search in, currently supports
                "content.body", "content.name", "content.topic"

        Returns:
            Deferred[dict]: with keys
                "results": list of {"event", "rank"} dicts, highest rank first
                    (at most 500 rows are fetched before filtering);
                "highlights": set of matched words on Postgres, else None;
                "count": total number of matching entries in room_ids.
        """
        if not room_ids:
            # Nothing can match, and "room_id IN ()" is a syntax error on
            # Postgres.
            defer.returnValue(self._empty_search_results())

        room_id_set = set(room_ids)

        clauses = []

        search_query = _parse_query(self.database_engine, search_term)

        args = []

        # Make sure we don't explode because the person is in too many rooms.
        # We filter the results below regardless.
        if len(room_ids) < 500:
            clauses.append(
                "room_id IN (%s)" % (",".join(["?"] * len(room_ids)),)
            )
            args.extend(room_ids)

        local_clauses = []
        for key in keys:
            local_clauses.append("key = ?")
            args.append(key)

        clauses.append(
            "(%s)" % (" OR ".join(local_clauses),)
        )

        # Take copies so the count query is independent of the main one.
        count_args = list(args)
        count_clauses = list(clauses)

        if isinstance(self.database_engine, PostgresEngine):
            sql = (
                "SELECT ts_rank_cd(vector, to_tsquery('english', ?)) AS rank,"
                " room_id, event_id"
                " FROM event_search"
                " WHERE vector @@ to_tsquery('english', ?)"
            )
            args = [search_query, search_query] + args

            count_sql = (
                "SELECT room_id, count(*) as count FROM event_search"
                " WHERE vector @@ to_tsquery('english', ?)"
            )
            count_args = [search_query] + count_args
        elif isinstance(self.database_engine, Sqlite3Engine):
            sql = (
                "SELECT rank(matchinfo(event_search)) as rank, room_id, event_id"
                " FROM event_search"
                " WHERE value MATCH ?"
            )
            args = [search_query] + args

            count_sql = (
                "SELECT room_id, count(*) as count FROM event_search"
                " WHERE value MATCH ?"
            )
            # Count with the same parsed query as the main search, so the
            # count agrees with the results and raw user input is never fed
            # to MATCH directly.
            count_args = [search_query] + count_args
        else:
            # This should be unreachable.
            raise Exception("Unrecognized database engine")

        for clause in clauses:
            sql += " AND " + clause

        for clause in count_clauses:
            count_sql += " AND " + clause

        # We add an arbitrary limit here to ensure we don't try to pull the
        # entire table from the database.
        sql += " ORDER BY rank DESC LIMIT 500"

        results = yield self._execute(
            "search_msgs", self.cursor_to_dict, sql, *args
        )

        # Drop rows from rooms we weren't asked about (possible when the
        # room_id clause was skipped). Build a list, not a lazy filter, since
        # it is iterated more than once below.
        results = [r for r in results if r["room_id"] in room_id_set]

        events = yield self._get_events([r["event_id"] for r in results])

        event_map = {
            ev.event_id: ev
            for ev in events
        }

        highlights = None
        if isinstance(self.database_engine, PostgresEngine):
            highlights = yield self._find_highlights_in_postgres(search_query, events)

        count_sql += " GROUP BY room_id"

        count_results = yield self._execute(
            "search_rooms_count", self.cursor_to_dict, count_sql, *count_args
        )

        count = sum(
            row["count"] for row in count_results if row["room_id"] in room_id_set
        )

        defer.returnValue({
            "results": [
                {
                    "event": event_map[r["event_id"]],
                    "rank": r["rank"],
                }
                for r in results
                if r["event_id"] in event_map
            ],
            "highlights": highlights,
            "count": count,
        })

    @defer.inlineCallbacks
    def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None):
        """Performs a full text search over events with given keys.

        Results are ordered by descending origin_server_ts, then descending
        stream_ordering.

        Args:
            room_ids (list): The room_ids to search in
            search_term (str): Search term to search for
            keys (list): List of keys to search in, currently supports
                "content.body", "content.name", "content.topic"
            limit (int): maximum number of rows to fetch from the database
                (fewer may be returned after filtering)
            pagination_token (str): A pagination token previously returned,
                of the form "<origin_server_ts>,<stream_ordering>"; only
                entries ordered strictly after it are returned.

        Returns:
            Deferred[dict]: with keys
                "results": list of {"event", "rank", "pagination_token"} dicts;
                "highlights": set of matched words on Postgres, else None;
                "count": total number of matching entries in room_ids.

        Raises:
            SynapseError: (400) if pagination_token is malformed.
        """
        # Validate the token first so a bad token is always rejected.
        if pagination_token:
            try:
                origin_server_ts, stream = pagination_token.split(",")
                origin_server_ts = int(origin_server_ts)
                stream = int(stream)
            except Exception:
                raise SynapseError(400, "Invalid pagination token")

        if not room_ids:
            # Nothing can match, and "room_id IN ()" is a syntax error on
            # Postgres.
            defer.returnValue(self._empty_search_results())

        room_id_set = set(room_ids)

        clauses = []

        search_query = _parse_query(self.database_engine, search_term)

        args = []

        # Make sure we don't explode because the person is in too many rooms.
        # We filter the results below regardless.
        if len(room_ids) < 500:
            clauses.append(
                "room_id IN (%s)" % (",".join(["?"] * len(room_ids)),)
            )
            args.extend(room_ids)

        local_clauses = []
        for key in keys:
            local_clauses.append("key = ?")
            args.append(key)

        clauses.append(
            "(%s)" % (" OR ".join(local_clauses),)
        )

        # take copies of the current args and clauses lists, before adding
        # pagination clauses to main query.
        count_args = list(args)
        count_clauses = list(clauses)

        if pagination_token:
            clauses.append(
                "(origin_server_ts < ?"
                " OR (origin_server_ts = ? AND stream_ordering < ?))"
            )
            args.extend([origin_server_ts, origin_server_ts, stream])

        if isinstance(self.database_engine, PostgresEngine):
            sql = (
                "SELECT ts_rank_cd(vector, to_tsquery('english', ?)) as rank,"
                " origin_server_ts, stream_ordering, room_id, event_id"
                " FROM event_search"
                " WHERE vector @@ to_tsquery('english', ?) AND "
            )
            args = [search_query, search_query] + args

            count_sql = (
                "SELECT room_id, count(*) as count FROM event_search"
                " WHERE vector @@ to_tsquery('english', ?) AND "
            )
            count_args = [search_query] + count_args
        elif isinstance(self.database_engine, Sqlite3Engine):
            # We use CROSS JOIN here to ensure we use the right indexes.
            # https://sqlite.org/optoverview.html#crossjoin
            #
            # We want to use the full text search index on event_search to
            # extract all possible matches first, then lookup those matches
            # in the events table to get the topological ordering. We need
            # to use the indexes in this order because sqlite refuses to
            # MATCH unless it uses the full text search index
            sql = (
                "SELECT rank(matchinfo) as rank, room_id, event_id,"
                " origin_server_ts, stream_ordering"
                " FROM (SELECT key, event_id, matchinfo(event_search) as matchinfo"
                " FROM event_search"
                " WHERE value MATCH ?"
                " )"
                " CROSS JOIN events USING (event_id)"
                " WHERE "
            )
            args = [search_query] + args

            count_sql = (
                "SELECT room_id, count(*) as count FROM event_search"
                " WHERE value MATCH ? AND "
            )
            # Count with the same parsed query as the main search.
            count_args = [search_query] + count_args
        else:
            # This should be unreachable.
            raise Exception("Unrecognized database engine")

        sql += " AND ".join(clauses)
        count_sql += " AND ".join(count_clauses)

        # We add an arbitrary limit here to ensure we don't try to pull the
        # entire table from the database.
        if isinstance(self.database_engine, PostgresEngine):
            sql += (
                " ORDER BY origin_server_ts DESC NULLS LAST,"
                " stream_ordering DESC NULLS LAST LIMIT ?"
            )
        elif isinstance(self.database_engine, Sqlite3Engine):
            sql += " ORDER BY origin_server_ts DESC, stream_ordering DESC LIMIT ?"
        else:
            raise Exception("Unrecognized database engine")

        args.append(limit)

        results = yield self._execute(
            "search_rooms", self.cursor_to_dict, sql, *args
        )

        # Drop rows from rooms we weren't asked about. Build a list, not a
        # lazy filter, since it is iterated more than once below.
        results = [r for r in results if r["room_id"] in room_id_set]

        events = yield self._get_events([r["event_id"] for r in results])

        event_map = {
            ev.event_id: ev
            for ev in events
        }

        highlights = None
        if isinstance(self.database_engine, PostgresEngine):
            highlights = yield self._find_highlights_in_postgres(search_query, events)

        count_sql += " GROUP BY room_id"

        count_results = yield self._execute(
            "search_rooms_count", self.cursor_to_dict, count_sql, *count_args
        )

        count = sum(
            row["count"] for row in count_results if row["room_id"] in room_id_set
        )

        defer.returnValue({
            "results": [
                {
                    "event": event_map[r["event_id"]],
                    "rank": r["rank"],
                    "pagination_token": "%s,%s" % (
                        r["origin_server_ts"], r["stream_ordering"]
                    ),
                }
                for r in results
                if r["event_id"] in event_map
            ],
            "highlights": highlights,
            "count": count,
        })

    def _find_highlights_in_postgres(self, search_query, events):
        """Given a list of events and a search term, return a list of words
        that match from the content of the event.

        This is used to give a list of words that clients can match against to
        highlight the matching parts.

        Args:
            search_query (str)
            events (list): A list of events

        Returns:
            deferred : A set of strings.
        """
        def f(txn):
            highlight_words = set()
            for event in events:
                # As a hack we simply join values of all possible keys. This is
                # fine since we're only using them to find possible highlights.
                values = []
                for key in ("body", "name", "topic"):
                    v = event.content.get(key, None)
                    # Only non-empty strings can be joined and highlighted;
                    # ignore malformed (non-string) content values.
                    if v and isinstance(v, basestring):
                        values.append(v)

                if not values:
                    continue

                value = " ".join(values)

                # We need to find some values for StartSel and StopSel that
                # aren't in the value so that we can pick results out.
                start_sel = "<"
                stop_sel = ">"

                while start_sel in value:
                    start_sel += "<"
                while stop_sel in value:
                    stop_sel += ">"

                query = "SELECT ts_headline(?, to_tsquery('english', ?), %s)" % (
                    _to_postgres_options({
                        "StartSel": start_sel,
                        "StopSel": stop_sel,
                        "MaxFragments": "50",
                    })
                )
                txn.execute(query, (value, search_query,))
                headline, = txn.fetchall()[0]

                # Now we need to pick the possible highlights out of the headline
                # result.
                matcher_regex = "%s(.*?)%s" % (
                    re.escape(start_sel),
                    re.escape(stop_sel),
                )

                res = re.findall(matcher_regex, headline)
                highlight_words.update([r.lower() for r in res])

            return highlight_words

        return self.runInteraction("_find_highlights", f)
|
|
|
|
|
|
def _to_postgres_options(options_dict):
|
|
return "'%s'" % (
|
|
",".join("%s=%s" % (k, v) for k, v in options_dict.items()),
|
|
)
|
|
|
|
|
|
def _parse_query(database_engine, search_term):
    """Convert a plain unicode search string into a database-specific query.

    Every word in ``search_term`` becomes a prefix match. On Postgres each
    word gets the ``:*`` suffix, and on SQLite the ``*`` suffix. The words are
    joined with `` & ``. We do this because prefix matching is not something
    either engine does by default.

    Raises:
        Exception: if ``database_engine`` is neither Postgres nor SQLite.
    """
    # Split the input into word tokens (letters, digits, underscores and
    # hyphens); any other characters are treated as separators and dropped.
    words = re.findall(r"([\w\-]+)", search_term, re.UNICODE)

    if isinstance(database_engine, PostgresEngine):
        prefix_marker = ":*"
    elif isinstance(database_engine, Sqlite3Engine):
        prefix_marker = "*"
    else:
        # This should be unreachable.
        raise Exception("Unrecognized database engine")

    return " & ".join(word + prefix_marker for word in words)
|