Merge remote-tracking branch 'upstream/release-v1.71'

Tulir Asokan 2022-11-01 16:18:58 +02:00
commit 2337ca829d
135 changed files with 5192 additions and 2356 deletions

View file

@@ -201,7 +201,7 @@ class DataStore(
name: Optional[str] = None,
guests: bool = True,
deactivated: bool = False,
order_by: str = UserSortOrder.USER_ID.value,
order_by: str = UserSortOrder.NAME.value,
direction: str = "f",
approved: bool = True,
) -> Tuple[List[JsonDict], int]:
@@ -261,6 +261,7 @@ class DataStore(
sql_base = f"""
FROM users as u
LEFT JOIN profiles AS p ON u.name = '@' || p.user_id || ':' || ?
LEFT JOIN erased_users AS eu ON u.name = eu.user_id
{where_clause}
"""
sql = "SELECT COUNT(*) as total_users " + sql_base
@@ -269,7 +270,8 @@ class DataStore(
sql = f"""
SELECT name, user_type, is_guest, admin, deactivated, shadow_banned,
displayname, avatar_url, creation_ts * 1000 as creation_ts, approved
displayname, avatar_url, creation_ts * 1000 as creation_ts, approved,
eu.user_id is not null as erased
{sql_base}
ORDER BY {order_by_column} {order}, u.name ASC
LIMIT ? OFFSET ?
@@ -277,6 +279,13 @@ class DataStore(
args += [limit, start]
txn.execute(sql, args)
users = self.db_pool.cursor_to_dict(txn)
# some of those boolean values are returned as integers when we're on SQLite
columns_to_boolify = ["erased"]
for user in users:
for column in columns_to_boolify:
user[column] = bool(user[column])
return users, count
return await self.db_pool.runInteraction(
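A standalone sketch (not part of the diff) of why the bool() coercion above is needed: SQLite has no native boolean type, so an expression such as `eu.user_id is not null` comes back as a 0/1 integer.

import sqlite3

con = sqlite3.connect(":memory:")
# Mirrors the new "eu.user_id is not null as erased" column: SQLite reports the
# boolean expression as an integer, which the store then coerces with bool().
(erased,) = con.execute("SELECT 'u1' IS NOT NULL AS erased").fetchone()
print(erased, bool(erased))  # 1 True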

View file

@@ -157,10 +157,23 @@ class ApplicationServiceWorkerStore(RoomMemberWorkerStore):
app_service: "ApplicationService",
cache_context: _CacheContext,
) -> List[str]:
users_in_room = await self.get_users_in_room(
"""
Get all users in a room that the appservice controls.
Args:
room_id: The room to check in.
app_service: The application service to check interest/control against
Returns:
List of user IDs that the appservice controls.
"""
# We can use `get_local_users_in_room(...)` here because an application service
# can only be interested in local users of the server it's on (ignore any remote
# users that might match the user namespace regex).
local_users_in_room = await self.get_local_users_in_room(
room_id, on_invalidate=cache_context.invalidate
)
return list(filter(app_service.is_interested_in_user, users_in_room))
return list(filter(app_service.is_interested_in_user, local_users_in_room))
class ApplicationServiceStore(ApplicationServiceWorkerStore):
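A minimal sketch of the filtering above, assuming a hypothetical appservice whose user namespace regex is `@irc_.*:example\.org` (the regex and user IDs are invented for illustration):

import re

# Stand-in for ApplicationService.is_interested_in_user, backed by a made-up
# user namespace regex from the appservice's registration.
ns_regex = re.compile(r"@irc_.*:example\.org")

def is_interested_in_user(user_id: str) -> bool:
    return bool(ns_regex.match(user_id))

# Only local members can ever match the namespace, which is why
# get_local_users_in_room(...) is sufficient (and cheaper) here.
local_users_in_room = ["@irc_alice:example.org", "@bob:example.org"]
print(list(filter(is_interested_in_user, local_users_in_room)))
# ['@irc_alice:example.org']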

View file

@@ -274,6 +274,13 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
destination, int(from_stream_id)
)
if not has_changed:
# debugging for https://github.com/matrix-org/synapse/issues/14251
issue_8631_logger.debug(
"%s: no change between %i and %i",
destination,
from_stream_id,
now_stream_id,
)
return now_stream_id, []
updates = await self.db_pool.runInteraction(
@@ -1848,7 +1855,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self,
txn: LoggingTransaction,
user_id: str,
device_ids: Iterable[str],
device_id: str,
hosts: Collection[str],
stream_ids: List[int],
context: Optional[Dict[str, str]],
@@ -1864,6 +1871,21 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
stream_id_iterator = iter(stream_ids)
encoded_context = json_encoder.encode(context)
mark_sent = not self.hs.is_mine_id(user_id)
values = [
(
destination,
next(stream_id_iterator),
user_id,
device_id,
mark_sent,
now,
encoded_context if whitelisted_homeserver(destination) else "{}",
)
for destination in hosts
]
self.db_pool.simple_insert_many_txn(
txn,
table="device_lists_outbound_pokes",
@@ -1876,23 +1898,21 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
"ts",
"opentracing_context",
),
values=[
(
destination,
next(stream_id_iterator),
user_id,
device_id,
not self.hs.is_mine_id(
user_id
), # We only need to send out update for *our* users
now,
encoded_context if whitelisted_homeserver(destination) else "{}",
)
for destination in hosts
for device_id in device_ids
],
values=values,
)
# debugging for https://github.com/matrix-org/synapse/issues/14251
if issue_8631_logger.isEnabledFor(logging.DEBUG):
issue_8631_logger.debug(
"Recorded outbound pokes for %s:%s with device stream ids %s",
user_id,
device_id,
{
stream_id: destination
for (destination, stream_id, _, _, _, _, _) in values
},
)
def _add_device_outbound_room_poke_txn(
self,
txn: LoggingTransaction,
@@ -1997,7 +2017,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self._add_device_outbound_poke_to_stream_txn(
txn,
user_id=user_id,
device_ids=[device_id],
device_id=device_id,
hosts=hosts,
stream_ids=stream_ids,
context=context,

View file

@@ -139,11 +139,15 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
@trace
@cancellable
async def get_e2e_device_keys_for_cs_api(
self, query_list: List[Tuple[str, Optional[str]]]
self,
query_list: List[Tuple[str, Optional[str]]],
include_displaynames: bool = True,
) -> Dict[str, Dict[str, JsonDict]]:
"""Fetch a list of device keys, formatted suitably for the C/S API.
Args:
query_list(list): List of pairs of user_ids and device_ids.
query_list: List of pairs of user_ids and device_ids.
include_displaynames: Whether to include the displayname of returned devices
(if one exists).
Returns:
Dict mapping from user-id to dict mapping from device_id to
key data. The key data will be a dict in the same format as the
@@ -166,9 +170,12 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
continue
r["unsigned"] = {}
display_name = device_info.display_name
if display_name is not None:
r["unsigned"]["device_display_name"] = display_name
if include_displaynames:
# Include the device's display name in the "unsigned" dictionary
display_name = device_info.display_name
if display_name is not None:
r["unsigned"]["device_display_name"] = display_name
rv[user_id][device_id] = r
return rv
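A hedged usage sketch of the new include_displaynames flag, assuming the store API shown above; the user ID and device ID are placeholders.

async def example(store: "EndToEndKeyWorkerStore") -> None:
    # A device ID of None asks for all of the user's devices.
    keys = await store.get_e2e_device_keys_for_cs_api(
        [("@alice:example.org", None)],
        include_displaynames=False,
    )
    device = keys["@alice:example.org"]["PLACEHOLDERDEVICE"]
    # With include_displaynames=False, "unsigned" no longer carries the display name.
    assert "device_display_name" not in device["unsigned"]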

View file

@@ -29,6 +29,7 @@ from typing import (
)
from synapse.api.errors import StoreError
from synapse.config.homeserver import ExperimentalConfig
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
@@ -62,7 +63,9 @@ logger = logging.getLogger(__name__)
def _load_rules(
rawrules: List[JsonDict], enabled_map: Dict[str, bool]
rawrules: List[JsonDict],
enabled_map: Dict[str, bool],
experimental_config: ExperimentalConfig,
) -> FilteredPushRules:
"""Take the DB rows returned from the DB and convert them into a full
`FilteredPushRules` object.
@@ -80,7 +83,9 @@ def _load_rules(
push_rules = PushRules(ruleslist)
filtered_rules = FilteredPushRules(push_rules, enabled_map)
filtered_rules = FilteredPushRules(
push_rules, enabled_map, msc3664_enabled=experimental_config.msc3664_enabled
)
return filtered_rules
@@ -160,7 +165,7 @@ class PushRulesWorkerStore(
enabled_map = await self.get_push_rules_enabled_for_user(user_id)
return _load_rules(rows, enabled_map)
return _load_rules(rows, enabled_map, self.hs.config.experimental)
async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]:
results = await self.db_pool.simple_select_list(
@@ -219,7 +224,9 @@ class PushRulesWorkerStore(
results: Dict[str, FilteredPushRules] = {}
for user_id, rules in raw_rules.items():
results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {}))
results[user_id] = _load_rules(
rules, enabled_map_by_user.get(user_id, {}), self.hs.config.experimental
)
return results

View file

@@ -21,7 +21,13 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
import attr
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
from synapse.api.errors import (
Codes,
NotFoundError,
StoreError,
SynapseError,
ThreepidValidationError,
)
from synapse.config.homeserver import HomeServerConfig
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage.database import (
@@ -50,6 +56,14 @@ class ExternalIDReuseException(Exception):
because this external id is given to an other user."""
class LoginTokenExpired(Exception):
"""Exception if the login token sent expired"""
class LoginTokenReused(Exception):
"""Exception if the login token sent was already used"""
@attr.s(frozen=True, slots=True, auto_attribs=True)
class TokenLookupResult:
"""Result of looking up an access token.
@@ -115,6 +129,20 @@ class RefreshTokenLookupResult:
If None, the session can be refreshed indefinitely."""
@attr.s(auto_attribs=True, frozen=True, slots=True)
class LoginTokenLookupResult:
"""Result of looking up a login token."""
user_id: str
"""The user this token belongs to."""
auth_provider_id: Optional[str]
"""The SSO Identity Provider that the user authenticated with, to get this token."""
auth_provider_session_id: Optional[str]
"""The session ID advertised by the SSO Identity Provider."""
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def __init__(
self,
@@ -1789,6 +1817,130 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"replace_refresh_token", _replace_refresh_token_txn
)
async def add_login_token_to_user(
self,
user_id: str,
token: str,
expiry_ts: int,
auth_provider_id: Optional[str],
auth_provider_session_id: Optional[str],
) -> None:
"""Adds a short-term login token for the given user.
Args:
user_id: The user ID.
token: The new login token to add.
expiry_ts (milliseconds since the epoch): Time after which the login token
cannot be used.
auth_provider_id: The SSO Identity Provider that the user authenticated with
to get this token, if any
auth_provider_session_id: The session ID advertised by the SSO Identity
Provider, if any.
"""
await self.db_pool.simple_insert(
"login_tokens",
{
"token": token,
"user_id": user_id,
"expiry_ts": expiry_ts,
"auth_provider_id": auth_provider_id,
"auth_provider_session_id": auth_provider_session_id,
},
desc="add_login_token_to_user",
)
def _consume_login_token(
self,
txn: LoggingTransaction,
token: str,
ts: int,
) -> LoginTokenLookupResult:
values = self.db_pool.simple_select_one_txn(
txn,
"login_tokens",
keyvalues={"token": token},
retcols=(
"user_id",
"expiry_ts",
"used_ts",
"auth_provider_id",
"auth_provider_session_id",
),
allow_none=True,
)
if values is None:
raise NotFoundError()
self.db_pool.simple_update_one_txn(
txn,
"login_tokens",
keyvalues={"token": token},
updatevalues={"used_ts": ts},
)
user_id = values["user_id"]
expiry_ts = values["expiry_ts"]
used_ts = values["used_ts"]
auth_provider_id = values["auth_provider_id"]
auth_provider_session_id = values["auth_provider_session_id"]
# Token was already used
if used_ts is not None:
raise LoginTokenReused()
# Token expired
if ts > int(expiry_ts):
raise LoginTokenExpired()
return LoginTokenLookupResult(
user_id=user_id,
auth_provider_id=auth_provider_id,
auth_provider_session_id=auth_provider_session_id,
)
async def consume_login_token(self, token: str) -> LoginTokenLookupResult:
"""Lookup a login token and consume it.
Args:
token: The login token.
Returns:
The data stored with that token, including the `user_id`.
Raises:
NotFoundError if the login token was not found in the database
LoginTokenExpired if the login token has expired
LoginTokenReused if the login token was already used
"""
return await self.db_pool.runInteraction(
"consume_login_token",
self._consume_login_token,
token,
self._clock.time_msec(),
)
async def invalidate_login_tokens_by_session_id(
self, auth_provider_id: str, auth_provider_session_id: str
) -> None:
"""Invalidate login tokens with the given IdP session ID.
Args:
auth_provider_id: The SSO Identity Provider that the user authenticated with
to get this token
auth_provider_session_id: The session ID advertised by the SSO Identity
Provider
"""
await self.db_pool.simple_update(
table="login_tokens",
keyvalues={
"auth_provider_id": auth_provider_id,
"auth_provider_session_id": auth_provider_session_id,
},
updatevalues={"used_ts": self._clock.time_msec()},
desc="invalidate_login_tokens_by_session_id",
)
@cached()
async def is_guest(self, user_id: str) -> bool:
res = await self.db_pool.simple_select_one_onecol(
@@ -2019,6 +2171,12 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
and hs.config.experimental.msc3866.require_approval_for_new_accounts
)
# Create a background job for removing expired login tokens
if hs.config.worker.run_background_tasks:
self._clock.looping_call(
self._delete_expired_login_tokens, THIRTY_MINUTES_IN_MS
)
async def add_access_token_to_user(
self,
user_id: str,
@@ -2617,6 +2775,23 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
approved,
)
@wrap_as_background_process("delete_expired_login_tokens")
async def _delete_expired_login_tokens(self) -> None:
"""Remove login tokens with expiry dates that have passed."""
def _delete_expired_login_tokens_txn(txn: LoggingTransaction, ts: int) -> None:
sql = "DELETE FROM login_tokens WHERE expiry_ts <= ?"
txn.execute(sql, (ts,))
# We keep the expired tokens for an extra 5 minutes so we can measure how many
# times a token is being used after its expiry
now = self._clock.time_msec()
await self.db_pool.runInteraction(
"delete_expired_login_tokens",
_delete_expired_login_tokens_txn,
now - (5 * 60 * 1000),
)
def find_max_generated_user_id_localpart(cur: Cursor) -> int:
"""

View file

@@ -152,6 +152,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
the forward extremities of those rooms will exclude most members. We may also
calculate room state incorrectly for such rooms and believe that a member is or
is not in the room when the opposite is true.
Note: If you only care about users in the room local to the homeserver, use
`get_local_users_in_room(...)` instead which will be more performant.
"""
return await self.db_pool.simple_select_onecol(
table="current_state_events",
@@ -707,8 +710,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# 250 users is pretty arbitrary but the data can be quite large if users
# are in many rooms.
for user_ids in batch_iter(user_ids, 250):
all_user_rooms.update(await self._get_rooms_for_users(user_ids))
for batch_user_ids in batch_iter(user_ids, 250):
all_user_rooms.update(await self._get_rooms_for_users(batch_user_ids))
return all_user_rooms
@@ -742,7 +745,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# user and the set of other users, and then checking if there is any
# overlap.
sql = f"""
SELECT b.state_key
SELECT DISTINCT b.state_key
FROM (
SELECT room_id FROM current_state_events
WHERE type = 'm.room.member' AND membership = 'join' AND state_key = ?
@@ -751,7 +754,6 @@ class RoomMemberWorkerStore(EventsWorkerStore):
SELECT room_id, state_key FROM current_state_events
WHERE type = 'm.room.member' AND membership = 'join' AND {clause}
) AS b using (room_id)
LIMIT 1
"""
txn.execute(sql, (user_id, *args))
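A rough equivalent, in Python set terms, of what the revised query computes (illustrative only; the real work happens in the database): with DISTINCT and no LIMIT 1 it returns every other user who shares at least one joined room with the given user, rather than stopping at the first hit.

rooms_of_user = {"!a:example.org", "!b:example.org"}
rooms_by_other_user = {
    "@x:example.org": {"!a:example.org"},
    "@y:example.org": {"!c:example.org"},
}
users_who_share = {
    other for other, rooms in rooms_by_other_user.items() if rooms & rooms_of_user
}
print(users_who_share)  # {'@x:example.org'}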

View file

@@ -11,10 +11,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
import logging
import re
from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set, Tuple
from collections import deque
from dataclasses import dataclass
from typing import (
TYPE_CHECKING,
Any,
Collection,
Iterable,
List,
Optional,
Set,
Tuple,
Union,
)
import attr
@@ -27,7 +39,7 @@ from synapse.storage.database import (
LoggingTransaction,
)
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.types import JsonDict
if TYPE_CHECKING:
@@ -68,11 +80,11 @@ class SearchWorkerStore(SQLBaseStore):
if not self.hs.config.server.enable_search:
return
if isinstance(self.database_engine, PostgresEngine):
sql = (
"INSERT INTO event_search"
" (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
" VALUES (?,?,?,to_tsvector('english', ?),?,?)"
)
sql = """
INSERT INTO event_search
(event_id, room_id, key, vector, stream_ordering, origin_server_ts)
VALUES (?,?,?,to_tsvector('english', ?),?,?)
"""
args1 = (
(
@@ -89,20 +101,20 @@ class SearchWorkerStore(SQLBaseStore):
txn.execute_batch(sql, args1)
elif isinstance(self.database_engine, Sqlite3Engine):
sql = (
"INSERT INTO event_search (event_id, room_id, key, value)"
" VALUES (?,?,?,?)"
self.db_pool.simple_insert_many_txn(
txn,
table="event_search",
keys=("event_id", "room_id", "key", "value"),
values=(
(
entry.event_id,
entry.room_id,
entry.key,
_clean_value_for_search(entry.value),
)
for entry in entries
),
)
args2 = (
(
entry.event_id,
entry.room_id,
entry.key,
_clean_value_for_search(entry.value),
)
for entry in entries
)
txn.execute_batch(sql, args2)
else:
# This should be unreachable.
@@ -150,15 +162,17 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
TYPES = ["m.room.name", "m.room.message", "m.room.topic"]
def reindex_search_txn(txn: LoggingTransaction) -> int:
sql = (
"SELECT stream_ordering, event_id, room_id, type, json, "
" origin_server_ts FROM events"
" JOIN event_json USING (room_id, event_id)"
" WHERE ? <= stream_ordering AND stream_ordering < ?"
" AND (%s)"
" ORDER BY stream_ordering DESC"
" LIMIT ?"
) % (" OR ".join("type = '%s'" % (t,) for t in TYPES),)
sql = """
SELECT stream_ordering, event_id, room_id, type, json, origin_server_ts
FROM events
JOIN event_json USING (room_id, event_id)
WHERE ? <= stream_ordering AND stream_ordering < ?
AND (%s)
ORDER BY stream_ordering DESC
LIMIT ?
""" % (
" OR ".join("type = '%s'" % (t,) for t in TYPES),
)
txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
@@ -272,8 +286,10 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
try:
c.execute(
"CREATE INDEX CONCURRENTLY event_search_fts_idx"
" ON event_search USING GIN (vector)"
"""
CREATE INDEX CONCURRENTLY event_search_fts_idx
ON event_search USING GIN (vector)
"""
)
except psycopg2.ProgrammingError as e:
logger.warning(
@@ -311,12 +327,16 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
# We create with NULLS FIRST so that when we search *backwards*
# we get the ones with non null origin_server_ts *first*
c.execute(
"CREATE INDEX CONCURRENTLY event_search_room_order ON event_search("
"room_id, origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST)"
"""
CREATE INDEX CONCURRENTLY event_search_room_order
ON event_search(room_id, origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST)
"""
)
c.execute(
"CREATE INDEX CONCURRENTLY event_search_order ON event_search("
"origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST)"
"""
CREATE INDEX CONCURRENTLY event_search_order
ON event_search(origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST)
"""
)
conn.set_session(autocommit=False)
@@ -333,14 +353,14 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
)
def reindex_search_txn(txn: LoggingTransaction) -> Tuple[int, bool]:
sql = (
"UPDATE event_search AS es SET stream_ordering = e.stream_ordering,"
" origin_server_ts = e.origin_server_ts"
" FROM events AS e"
" WHERE e.event_id = es.event_id"
" AND ? <= e.stream_ordering AND e.stream_ordering < ?"
" RETURNING es.stream_ordering"
)
sql = """
UPDATE event_search AS es
SET stream_ordering = e.stream_ordering, origin_server_ts = e.origin_server_ts
FROM events AS e
WHERE e.event_id = es.event_id
AND ? <= e.stream_ordering AND e.stream_ordering < ?
RETURNING es.stream_ordering
"""
min_stream_id = max_stream_id - batch_size
txn.execute(sql, (min_stream_id, max_stream_id))
@@ -421,8 +441,6 @@ class SearchStore(SearchBackgroundUpdateStore):
"""
clauses = []
search_query = _parse_query(self.database_engine, search_term)
args: List[Any] = []
# Make sure we don't explode because the person is in too many rooms.
@@ -444,32 +462,36 @@ class SearchStore(SearchBackgroundUpdateStore):
count_clauses = clauses
if isinstance(self.database_engine, PostgresEngine):
sql = (
"SELECT ts_rank_cd(vector, to_tsquery('english', ?)) AS rank,"
" room_id, event_id"
" FROM event_search"
" WHERE vector @@ to_tsquery('english', ?)"
)
search_query = search_term
tsquery_func = self.database_engine.tsquery_func
sql = f"""
SELECT ts_rank_cd(vector, {tsquery_func}('english', ?)) AS rank,
room_id, event_id
FROM event_search
WHERE vector @@ {tsquery_func}('english', ?)
"""
args = [search_query, search_query] + args
count_sql = (
"SELECT room_id, count(*) as count FROM event_search"
" WHERE vector @@ to_tsquery('english', ?)"
)
count_sql = f"""
SELECT room_id, count(*) as count FROM event_search
WHERE vector @@ {tsquery_func}('english', ?)
"""
count_args = [search_query] + count_args
elif isinstance(self.database_engine, Sqlite3Engine):
sql = (
"SELECT rank(matchinfo(event_search)) as rank, room_id, event_id"
" FROM event_search"
" WHERE value MATCH ?"
)
search_query = _parse_query_for_sqlite(search_term)
sql = """
SELECT rank(matchinfo(event_search)) as rank, room_id, event_id
FROM event_search
WHERE value MATCH ?
"""
args = [search_query] + args
count_sql = (
"SELECT room_id, count(*) as count FROM event_search"
" WHERE value MATCH ?"
)
count_args = [search_term] + count_args
count_sql = """
SELECT room_id, count(*) as count FROM event_search
WHERE value MATCH ?
"""
count_args = [search_query] + count_args
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
@@ -501,7 +523,9 @@ class SearchStore(SearchBackgroundUpdateStore):
highlights = None
if isinstance(self.database_engine, PostgresEngine):
highlights = await self._find_highlights_in_postgres(search_query, events)
highlights = await self._find_highlights_in_postgres(
search_query, events, tsquery_func
)
count_sql += " GROUP BY room_id"
@@ -510,7 +534,6 @@ class SearchStore(SearchBackgroundUpdateStore):
)
count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
return {
"results": [
{"event": event_map[r["event_id"]], "rank": r["rank"]}
@@ -542,9 +565,6 @@ class SearchStore(SearchBackgroundUpdateStore):
Each match as a dictionary.
"""
clauses = []
search_query = _parse_query(self.database_engine, search_term)
args: List[Any] = []
# Make sure we don't explode because the person is in too many rooms.
@@ -576,26 +596,30 @@ class SearchStore(SearchBackgroundUpdateStore):
raise SynapseError(400, "Invalid pagination token")
clauses.append(
"(origin_server_ts < ?"
" OR (origin_server_ts = ? AND stream_ordering < ?))"
"""
(origin_server_ts < ? OR (origin_server_ts = ? AND stream_ordering < ?))
"""
)
args.extend([origin_server_ts, origin_server_ts, stream])
if isinstance(self.database_engine, PostgresEngine):
sql = (
"SELECT ts_rank_cd(vector, to_tsquery('english', ?)) as rank,"
" origin_server_ts, stream_ordering, room_id, event_id"
" FROM event_search"
" WHERE vector @@ to_tsquery('english', ?) AND "
)
search_query = search_term
tsquery_func = self.database_engine.tsquery_func
sql = f"""
SELECT ts_rank_cd(vector, {tsquery_func}('english', ?)) as rank,
origin_server_ts, stream_ordering, room_id, event_id
FROM event_search
WHERE vector @@ {tsquery_func}('english', ?) AND
"""
args = [search_query, search_query] + args
count_sql = (
"SELECT room_id, count(*) as count FROM event_search"
" WHERE vector @@ to_tsquery('english', ?) AND "
)
count_sql = f"""
SELECT room_id, count(*) as count FROM event_search
WHERE vector @@ {tsquery_func}('english', ?) AND
"""
count_args = [search_query] + count_args
elif isinstance(self.database_engine, Sqlite3Engine):
# We use CROSS JOIN here to ensure we use the right indexes.
# https://sqlite.org/optoverview.html#crossjoin
#
@@ -604,23 +628,25 @@ class SearchStore(SearchBackgroundUpdateStore):
# in the events table to get the topological ordering. We need
# to use the indexes in this order because sqlite refuses to
# MATCH unless it uses the full text search index
sql = (
"SELECT rank(matchinfo) as rank, room_id, event_id,"
" origin_server_ts, stream_ordering"
" FROM (SELECT key, event_id, matchinfo(event_search) as matchinfo"
" FROM event_search"
" WHERE value MATCH ?"
" )"
" CROSS JOIN events USING (event_id)"
" WHERE "
sql = """
SELECT
rank(matchinfo) as rank, room_id, event_id, origin_server_ts, stream_ordering
FROM (
SELECT key, event_id, matchinfo(event_search) as matchinfo
FROM event_search
WHERE value MATCH ?
)
CROSS JOIN events USING (event_id)
WHERE
"""
search_query = _parse_query_for_sqlite(search_term)
args = [search_query] + args
count_sql = (
"SELECT room_id, count(*) as count FROM event_search"
" WHERE value MATCH ? AND "
)
count_args = [search_term] + count_args
count_sql = """
SELECT room_id, count(*) as count FROM event_search
WHERE value MATCH ? AND
"""
count_args = [search_query] + count_args
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
@@ -631,10 +657,10 @@ class SearchStore(SearchBackgroundUpdateStore):
# We add an arbitrary limit here to ensure we don't try to pull the
# entire table from the database.
if isinstance(self.database_engine, PostgresEngine):
sql += (
" ORDER BY origin_server_ts DESC NULLS LAST,"
" stream_ordering DESC NULLS LAST LIMIT ?"
)
sql += """
ORDER BY origin_server_ts DESC NULLS LAST, stream_ordering DESC NULLS LAST
LIMIT ?
"""
elif isinstance(self.database_engine, Sqlite3Engine):
sql += " ORDER BY origin_server_ts DESC, stream_ordering DESC LIMIT ?"
else:
@@ -660,7 +686,9 @@ class SearchStore(SearchBackgroundUpdateStore):
highlights = None
if isinstance(self.database_engine, PostgresEngine):
highlights = await self._find_highlights_in_postgres(search_query, events)
highlights = await self._find_highlights_in_postgres(
search_query, events, tsquery_func
)
count_sql += " GROUP BY room_id"
@@ -686,7 +714,7 @@ class SearchStore(SearchBackgroundUpdateStore):
}
async def _find_highlights_in_postgres(
self, search_query: str, events: List[EventBase]
self, search_query: str, events: List[EventBase], tsquery_func: str
) -> Set[str]:
"""Given a list of events and a search term, return a list of words
that match from the content of the event.
@@ -697,6 +725,7 @@ class SearchStore(SearchBackgroundUpdateStore):
Args:
search_query
events: A list of events
tsquery_func: The tsquery_* function to use when making queries
Returns:
A set of strings.
@@ -729,7 +758,7 @@ class SearchStore(SearchBackgroundUpdateStore):
while stop_sel in value:
stop_sel += ">"
query = "SELECT ts_headline(?, to_tsquery('english', ?), %s)" % (
query = f"SELECT ts_headline(?, {tsquery_func}('english', ?), %s)" % (
_to_postgres_options(
{
"StartSel": start_sel,
@@ -760,20 +789,127 @@ def _to_postgres_options(options_dict: JsonDict) -> str:
return "'%s'" % (",".join("%s=%s" % (k, v) for k, v in options_dict.items()),)
def _parse_query(database_engine: BaseDatabaseEngine, search_term: str) -> str:
"""Takes a plain unicode string from the user and converts it into a form
that can be passed to database.
We use this so that we can add prefix matching, which isn't something
that is supported by default.
@dataclass
class Phrase:
phrase: List[str]
class SearchToken(enum.Enum):
Not = enum.auto()
Or = enum.auto()
And = enum.auto()
Token = Union[str, Phrase, SearchToken]
TokenList = List[Token]
def _is_stop_word(word: str) -> bool:
# TODO Pull these out of the dictionary:
# https://github.com/postgres/postgres/blob/master/src/backend/snowball/stopwords/english.stop
return word in {"the", "a", "you", "me", "and", "but"}
def _tokenize_query(query: str) -> TokenList:
"""
Convert the user-supplied `query` into a TokenList, which can be translated into
some DB-specific syntax.
# Pull out the individual words, discarding any non-word characters.
results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
The following constructs are supported:
if isinstance(database_engine, PostgresEngine):
return " & ".join(result + ":*" for result in results)
elif isinstance(database_engine, Sqlite3Engine):
return " & ".join(result + "*" for result in results)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
- phrase queries using "double quotes"
- case-insensitive `or` and `and` operators
- negation of a keyword via a unary `-`, e.g. 'include -exclude'
The following differs from websearch_to_tsquery:
- Stop words are not removed.
- Unclosed phrases are treated differently.
"""
tokens: TokenList = []
# Find phrases.
in_phrase = False
parts = deque(query.split('"'))
for i, part in enumerate(parts):
# The contents inside double quotes is treated as a phrase.
in_phrase = bool(i % 2)
# Pull out the individual words, discarding any non-word characters.
words = deque(re.findall(r"([\w\-]+)", part, re.UNICODE))
# Phrases have simplified handling of words.
if in_phrase:
# Skip stop words.
phrase = [word for word in words if not _is_stop_word(word)]
# Consecutive words are implicitly ANDed together.
if tokens and tokens[-1] not in (SearchToken.Not, SearchToken.Or):
tokens.append(SearchToken.And)
# Add the phrase.
tokens.append(Phrase(phrase))
continue
# Otherwise, not in a phrase.
while words:
word = words.popleft()
if word.startswith("-"):
tokens.append(SearchToken.Not)
# If there's more word, put it back to be processed again.
word = word[1:]
if word:
words.appendleft(word)
elif word.lower() == "or":
tokens.append(SearchToken.Or)
else:
# Skip stop words.
if _is_stop_word(word):
continue
# Consecutive words are implicitly ANDed together.
if tokens and tokens[-1] not in (SearchToken.Not, SearchToken.Or):
tokens.append(SearchToken.And)
# Add the search term.
tokens.append(word)
return tokens
def _tokens_to_sqlite_match_query(tokens: TokenList) -> str:
"""
Convert the list of tokens to a string suitable for passing to sqlite's MATCH.
Assume sqlite was compiled with enhanced query syntax.
Ref: https://www.sqlite.org/fts3.html#full_text_index_queries
"""
match_query = []
for token in tokens:
if isinstance(token, str):
match_query.append(token)
elif isinstance(token, Phrase):
match_query.append('"' + " ".join(token.phrase) + '"')
elif token == SearchToken.Not:
# TODO: SQLite treats NOT as a *binary* operator. Hopefully a search
# term has already been added before this.
match_query.append(" NOT ")
elif token == SearchToken.Or:
match_query.append(" OR ")
elif token == SearchToken.And:
match_query.append(" AND ")
else:
raise ValueError(f"unknown token {token}")
return "".join(match_query)
def _parse_query_for_sqlite(search_term: str) -> str:
"""Takes a plain unicode string from the user and converts it into a form
that can be passed to SQLite's matchinfo().
"""
return _tokens_to_sqlite_match_query(_tokenize_query(search_term))
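A worked example of the new tokenizer, traced from the code above (the query string is arbitrary; the calls assume these helpers are importable):

tokens = _tokenize_query('"could not" -sqlite or postgres')
# => [Phrase(["could", "not"]), SearchToken.Not, "sqlite", SearchToken.Or, "postgres"]
print(_tokens_to_sqlite_match_query(tokens))
# => "could not" NOT sqlite OR postgres
print(_parse_query_for_sqlite('"could not" -sqlite or postgres'))
# => the same string, via the convenience wrapper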

View file

@@ -170,6 +170,22 @@ class PostgresEngine(
"""Do we support the `RETURNING` clause in insert/update/delete?"""
return True
@property
def tsquery_func(self) -> str:
"""
Selects a tsquery_* func to use.
Ref: https://www.postgresql.org/docs/current/textsearch-controls.html
Returns:
The function name.
"""
# Postgres 11 added support for websearch_to_tsquery.
assert self._version is not None
if self._version >= 110000:
return "websearch_to_tsquery"
return "plainto_tsquery"
def is_deadlock(self, error: Exception) -> bool:
if isinstance(error, psycopg2.DatabaseError):
# https://www.postgresql.org/docs/current/static/errcodes-appendix.html

View file

@@ -88,6 +88,10 @@ class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection, sqlite3.Cursor]):
db_conn.create_function("rank", 1, _rank)
db_conn.execute("PRAGMA foreign_keys = ON;")
# Enable WAL.
# see https://www.sqlite.org/wal.html
db_conn.execute("PRAGMA journal_mode = WAL;")
db_conn.commit()
def is_deadlock(self, error: Exception) -> bool:

View file

@@ -0,0 +1,62 @@
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from synapse.storage.engines import BaseDatabaseEngine, Sqlite3Engine
from synapse.storage.types import Cursor
def run_create(cur: Cursor, database_engine: BaseDatabaseEngine) -> None:
"""
Upgrade the event_search table to use the porter tokenizer if it isn't already
Applies only for sqlite.
"""
if not isinstance(database_engine, Sqlite3Engine):
return
# Rebuild the table event_search table with tokenize=porter configured.
cur.execute("DROP TABLE event_search")
cur.execute(
"""
CREATE VIRTUAL TABLE event_search
USING fts4 (tokenize=porter, event_id, room_id, sender, key, value )
"""
)
# Re-run the background job to re-populate the event_search table.
cur.execute("SELECT MIN(stream_ordering) FROM events")
row = cur.fetchone()
min_stream_id = row[0]
# If there are not any events, nothing to do.
if min_stream_id is None:
return
cur.execute("SELECT MAX(stream_ordering) FROM events")
row = cur.fetchone()
max_stream_id = row[0]
progress = {
"target_min_stream_id_inclusive": min_stream_id,
"max_stream_id_exclusive": max_stream_id + 1,
}
progress_json = json.dumps(progress)
sql = """
INSERT into background_updates (ordering, update_name, progress_json)
VALUES (?, ?, ?)
"""
cur.execute(sql, (7310, "event_search", progress_json))
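For context, a standalone sketch (not part of the migration) of what the porter tokenizer buys us, assuming an SQLite build with FTS4 enabled, as Synapse requires: terms are stemmed at index and query time, so a search for "run" matches rows containing "running".

import sqlite3

con = sqlite3.connect(":memory:")
con.execute("CREATE VIRTUAL TABLE t USING fts4(tokenize=porter, value)")
con.execute("INSERT INTO t(value) VALUES ('she was running home')")
print(con.execute("SELECT value FROM t WHERE value MATCH 'run'").fetchall())
# [('she was running home',)]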

View file

@@ -0,0 +1,35 @@
/*
* Copyright 2022 The Matrix.org Foundation C.I.C.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Login tokens are short-lived tokens that are used for the m.login.token
-- login method, mainly during SSO logins
CREATE TABLE login_tokens (
token TEXT PRIMARY KEY,
user_id TEXT NOT NULL,
expiry_ts BIGINT NOT NULL,
used_ts BIGINT,
auth_provider_id TEXT,
auth_provider_session_id TEXT
);
-- We sometimes query them by the session ID we got from the IdP
CREATE INDEX login_tokens_auth_provider_idx
ON login_tokens (auth_provider_id, auth_provider_session_id);
-- We're deleting them by their expiration time
CREATE INDEX login_tokens_expiry_time_idx
ON login_tokens (expiry_ts);
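For reference, a short sketch (in the style of the store code earlier in this diff) of the two queries these indexes serve; the parameter values are placeholders, and the UPDATE is the SQL equivalent of the simple_update call in invalidate_login_tokens_by_session_id.

def _login_token_maintenance_txn(txn, now_ms: int) -> None:
    # Served by login_tokens_auth_provider_idx: invalidate every token tied to
    # one IdP session.
    txn.execute(
        "UPDATE login_tokens SET used_ts = ? WHERE auth_provider_id = ? AND auth_provider_session_id = ?",
        (now_ms, "oidc-example", "sid-123"),
    )
    # Served by login_tokens_expiry_time_idx: drop tokens whose expiry has long passed.
    txn.execute(
        "DELETE FROM login_tokens WHERE expiry_ts <= ?",
        (now_ms - 5 * 60 * 1000,),
    )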