mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-12-11 07:07:40 -05:00
Merge remote-tracking branch 'upstream/release-v1.38'
This commit is contained in:
commit
fa8ec8051b
88 changed files with 4940 additions and 2441 deletions
|
|
@ -111,7 +111,7 @@ def make_conn(
|
|||
db_config: DatabaseConnectionConfig,
|
||||
engine: BaseDatabaseEngine,
|
||||
default_txn_name: str,
|
||||
) -> Connection:
|
||||
) -> "LoggingDatabaseConnection":
|
||||
"""Make a new connection to the database and return it.
|
||||
|
||||
Returns:
|
||||
|
|
|
|||
|
|
@ -16,6 +16,8 @@ import logging
|
|||
from queue import Empty, PriorityQueue
|
||||
from typing import Collection, Dict, Iterable, List, Optional, Set, Tuple
|
||||
|
||||
from prometheus_client import Gauge
|
||||
|
||||
from synapse.api.constants import MAX_DEPTH
|
||||
from synapse.api.errors import StoreError
|
||||
from synapse.api.room_versions import RoomVersion
|
||||
|
|
@ -32,6 +34,16 @@ from synapse.util.caches.descriptors import cached
|
|||
from synapse.util.caches.lrucache import LruCache
|
||||
from synapse.util.iterutils import batch_iter
|
||||
|
||||
oldest_pdu_in_federation_staging = Gauge(
|
||||
"synapse_federation_server_oldest_inbound_pdu_in_staging",
|
||||
"The age in seconds since we received the oldest pdu in the federation staging area",
|
||||
)
|
||||
|
||||
number_pdus_in_federation_queue = Gauge(
|
||||
"synapse_federation_server_number_inbound_pdu_in_staging",
|
||||
"The total number of events in the inbound federation staging",
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
@ -54,6 +66,8 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
|||
500000, "_event_auth_cache", size_callback=len
|
||||
) # type: LruCache[str, List[Tuple[str, int]]]
|
||||
|
||||
self._clock.looping_call(self._get_stats_for_federation_staging, 30 * 1000)
|
||||
|
||||
async def get_auth_chain(
|
||||
self, room_id: str, event_ids: Collection[str], include_given: bool = False
|
||||
) -> List[EventBase]:
|
||||
|
|
@ -1075,16 +1089,62 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
|||
self,
|
||||
origin: str,
|
||||
event_id: str,
|
||||
) -> None:
|
||||
"""Remove the given event from the staging area"""
|
||||
await self.db_pool.simple_delete(
|
||||
table="federation_inbound_events_staging",
|
||||
keyvalues={
|
||||
"origin": origin,
|
||||
"event_id": event_id,
|
||||
},
|
||||
desc="remove_received_event_from_staging",
|
||||
)
|
||||
) -> Optional[int]:
|
||||
"""Remove the given event from the staging area.
|
||||
|
||||
Returns:
|
||||
The received_ts of the row that was deleted, if any.
|
||||
"""
|
||||
if self.db_pool.engine.supports_returning:
|
||||
|
||||
def _remove_received_event_from_staging_txn(txn):
|
||||
sql = """
|
||||
DELETE FROM federation_inbound_events_staging
|
||||
WHERE origin = ? AND event_id = ?
|
||||
RETURNING received_ts
|
||||
"""
|
||||
|
||||
txn.execute(sql, (origin, event_id))
|
||||
return txn.fetchone()
|
||||
|
||||
row = await self.db_pool.runInteraction(
|
||||
"remove_received_event_from_staging",
|
||||
_remove_received_event_from_staging_txn,
|
||||
db_autocommit=True,
|
||||
)
|
||||
if row is None:
|
||||
return None
|
||||
|
||||
return row[0]
|
||||
|
||||
else:
|
||||
|
||||
def _remove_received_event_from_staging_txn(txn):
|
||||
received_ts = self.db_pool.simple_select_one_onecol_txn(
|
||||
txn,
|
||||
table="federation_inbound_events_staging",
|
||||
keyvalues={
|
||||
"origin": origin,
|
||||
"event_id": event_id,
|
||||
},
|
||||
retcol="received_ts",
|
||||
allow_none=True,
|
||||
)
|
||||
self.db_pool.simple_delete_txn(
|
||||
txn,
|
||||
table="federation_inbound_events_staging",
|
||||
keyvalues={
|
||||
"origin": origin,
|
||||
"event_id": event_id,
|
||||
},
|
||||
)
|
||||
|
||||
return received_ts
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"remove_received_event_from_staging",
|
||||
_remove_received_event_from_staging_txn,
|
||||
)
|
||||
|
||||
async def get_next_staged_event_id_for_room(
|
||||
self,
|
||||
|
|
@ -1147,6 +1207,40 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
|||
|
||||
return origin, event
|
||||
|
||||
async def get_all_rooms_with_staged_incoming_events(self) -> List[str]:
|
||||
"""Get the room IDs of all events currently staged."""
|
||||
return await self.db_pool.simple_select_onecol(
|
||||
table="federation_inbound_events_staging",
|
||||
keyvalues={},
|
||||
retcol="DISTINCT room_id",
|
||||
desc="get_all_rooms_with_staged_incoming_events",
|
||||
)
|
||||
|
||||
@wrap_as_background_process("_get_stats_for_federation_staging")
|
||||
async def _get_stats_for_federation_staging(self):
|
||||
"""Update the prometheus metrics for the inbound federation staging area."""
|
||||
|
||||
def _get_stats_for_federation_staging_txn(txn):
|
||||
txn.execute(
|
||||
"SELECT coalesce(count(*), 0) FROM federation_inbound_events_staging"
|
||||
)
|
||||
(count,) = txn.fetchone()
|
||||
|
||||
txn.execute(
|
||||
"SELECT coalesce(min(received_ts), 0) FROM federation_inbound_events_staging"
|
||||
)
|
||||
|
||||
(age,) = txn.fetchone()
|
||||
|
||||
return count, age
|
||||
|
||||
count, age = await self.db_pool.runInteraction(
|
||||
"_get_stats_for_federation_staging", _get_stats_for_federation_staging_txn
|
||||
)
|
||||
|
||||
number_pdus_in_federation_queue.set(count)
|
||||
oldest_pdu_in_federation_staging.set(age)
|
||||
|
||||
|
||||
class EventFederationStore(EventFederationWorkerStore):
|
||||
"""Responsible for storing and serving up the various graphs associated
|
||||
|
|
|
|||
|
|
@ -29,6 +29,34 @@ from synapse.types import JsonDict
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_REPLACE_STREAM_ORDERING_SQL_COMMANDS = (
|
||||
# there should be no leftover rows without a stream_ordering2, but just in case...
|
||||
"UPDATE events SET stream_ordering2 = stream_ordering WHERE stream_ordering2 IS NULL",
|
||||
# now we can drop the rule and switch the columns
|
||||
"DROP RULE populate_stream_ordering2 ON events",
|
||||
"ALTER TABLE events DROP COLUMN stream_ordering",
|
||||
"ALTER TABLE events RENAME COLUMN stream_ordering2 TO stream_ordering",
|
||||
# ... and finally, rename the indexes into place for consistency with sqlite
|
||||
"ALTER INDEX event_contains_url_index2 RENAME TO event_contains_url_index",
|
||||
"ALTER INDEX events_order_room2 RENAME TO events_order_room",
|
||||
"ALTER INDEX events_room_stream2 RENAME TO events_room_stream",
|
||||
"ALTER INDEX events_ts2 RENAME TO events_ts",
|
||||
)
|
||||
|
||||
|
||||
class _BackgroundUpdates:
|
||||
EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
|
||||
EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
|
||||
DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities"
|
||||
POPULATE_STREAM_ORDERING2 = "populate_stream_ordering2"
|
||||
INDEX_STREAM_ORDERING2 = "index_stream_ordering2"
|
||||
INDEX_STREAM_ORDERING2_CONTAINS_URL = "index_stream_ordering2_contains_url"
|
||||
INDEX_STREAM_ORDERING2_ROOM_ORDER = "index_stream_ordering2_room_order"
|
||||
INDEX_STREAM_ORDERING2_ROOM_STREAM = "index_stream_ordering2_room_stream"
|
||||
INDEX_STREAM_ORDERING2_TS = "index_stream_ordering2_ts"
|
||||
REPLACE_STREAM_ORDERING_COLUMN = "replace_stream_ordering_column"
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True)
|
||||
class _CalculateChainCover:
|
||||
"""Return value for _calculate_chain_cover_txn."""
|
||||
|
|
@ -48,19 +76,15 @@ class _CalculateChainCover:
|
|||
|
||||
|
||||
class EventsBackgroundUpdatesStore(SQLBaseStore):
|
||||
|
||||
EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
|
||||
EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
|
||||
DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities"
|
||||
|
||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
||||
self.db_pool.updates.register_background_update_handler(
|
||||
self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts
|
||||
_BackgroundUpdates.EVENT_ORIGIN_SERVER_TS_NAME,
|
||||
self._background_reindex_origin_server_ts,
|
||||
)
|
||||
self.db_pool.updates.register_background_update_handler(
|
||||
self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME,
|
||||
_BackgroundUpdates.EVENT_FIELDS_SENDER_URL_UPDATE_NAME,
|
||||
self._background_reindex_fields_sender,
|
||||
)
|
||||
|
||||
|
|
@ -85,7 +109,8 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
|||
)
|
||||
|
||||
self.db_pool.updates.register_background_update_handler(
|
||||
self.DELETE_SOFT_FAILED_EXTREMITIES, self._cleanup_extremities_bg_update
|
||||
_BackgroundUpdates.DELETE_SOFT_FAILED_EXTREMITIES,
|
||||
self._cleanup_extremities_bg_update,
|
||||
)
|
||||
|
||||
self.db_pool.updates.register_background_update_handler(
|
||||
|
|
@ -139,6 +164,59 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
|||
self._purged_chain_cover_index,
|
||||
)
|
||||
|
||||
################################################################################
|
||||
|
||||
# bg updates for replacing stream_ordering with a BIGINT
|
||||
# (these only run on postgres.)
|
||||
|
||||
self.db_pool.updates.register_background_update_handler(
|
||||
_BackgroundUpdates.POPULATE_STREAM_ORDERING2,
|
||||
self._background_populate_stream_ordering2,
|
||||
)
|
||||
# CREATE UNIQUE INDEX events_stream_ordering ON events(stream_ordering2);
|
||||
self.db_pool.updates.register_background_index_update(
|
||||
_BackgroundUpdates.INDEX_STREAM_ORDERING2,
|
||||
index_name="events_stream_ordering",
|
||||
table="events",
|
||||
columns=["stream_ordering2"],
|
||||
unique=True,
|
||||
)
|
||||
# CREATE INDEX event_contains_url_index ON events(room_id, topological_ordering, stream_ordering) WHERE contains_url = true AND outlier = false;
|
||||
self.db_pool.updates.register_background_index_update(
|
||||
_BackgroundUpdates.INDEX_STREAM_ORDERING2_CONTAINS_URL,
|
||||
index_name="event_contains_url_index2",
|
||||
table="events",
|
||||
columns=["room_id", "topological_ordering", "stream_ordering2"],
|
||||
where_clause="contains_url = true AND outlier = false",
|
||||
)
|
||||
# CREATE INDEX events_order_room ON events(room_id, topological_ordering, stream_ordering);
|
||||
self.db_pool.updates.register_background_index_update(
|
||||
_BackgroundUpdates.INDEX_STREAM_ORDERING2_ROOM_ORDER,
|
||||
index_name="events_order_room2",
|
||||
table="events",
|
||||
columns=["room_id", "topological_ordering", "stream_ordering2"],
|
||||
)
|
||||
# CREATE INDEX events_room_stream ON events(room_id, stream_ordering);
|
||||
self.db_pool.updates.register_background_index_update(
|
||||
_BackgroundUpdates.INDEX_STREAM_ORDERING2_ROOM_STREAM,
|
||||
index_name="events_room_stream2",
|
||||
table="events",
|
||||
columns=["room_id", "stream_ordering2"],
|
||||
)
|
||||
# CREATE INDEX events_ts ON events(origin_server_ts, stream_ordering);
|
||||
self.db_pool.updates.register_background_index_update(
|
||||
_BackgroundUpdates.INDEX_STREAM_ORDERING2_TS,
|
||||
index_name="events_ts2",
|
||||
table="events",
|
||||
columns=["origin_server_ts", "stream_ordering2"],
|
||||
)
|
||||
self.db_pool.updates.register_background_update_handler(
|
||||
_BackgroundUpdates.REPLACE_STREAM_ORDERING_COLUMN,
|
||||
self._background_replace_stream_ordering_column,
|
||||
)
|
||||
|
||||
################################################################################
|
||||
|
||||
async def _background_reindex_fields_sender(self, progress, batch_size):
|
||||
target_min_stream_id = progress["target_min_stream_id_inclusive"]
|
||||
max_stream_id = progress["max_stream_id_exclusive"]
|
||||
|
|
@ -190,18 +268,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
|||
}
|
||||
|
||||
self.db_pool.updates._background_update_progress_txn(
|
||||
txn, self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress
|
||||
txn, _BackgroundUpdates.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress
|
||||
)
|
||||
|
||||
return len(rows)
|
||||
|
||||
result = await self.db_pool.runInteraction(
|
||||
self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn
|
||||
_BackgroundUpdates.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn
|
||||
)
|
||||
|
||||
if not result:
|
||||
await self.db_pool.updates._end_background_update(
|
||||
self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME
|
||||
_BackgroundUpdates.EVENT_FIELDS_SENDER_URL_UPDATE_NAME
|
||||
)
|
||||
|
||||
return result
|
||||
|
|
@ -264,18 +342,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
|||
}
|
||||
|
||||
self.db_pool.updates._background_update_progress_txn(
|
||||
txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress
|
||||
txn, _BackgroundUpdates.EVENT_ORIGIN_SERVER_TS_NAME, progress
|
||||
)
|
||||
|
||||
return len(rows_to_update)
|
||||
|
||||
result = await self.db_pool.runInteraction(
|
||||
self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn
|
||||
_BackgroundUpdates.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn
|
||||
)
|
||||
|
||||
if not result:
|
||||
await self.db_pool.updates._end_background_update(
|
||||
self.EVENT_ORIGIN_SERVER_TS_NAME
|
||||
_BackgroundUpdates.EVENT_ORIGIN_SERVER_TS_NAME
|
||||
)
|
||||
|
||||
return result
|
||||
|
|
@ -454,7 +532,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
|||
|
||||
if not num_handled:
|
||||
await self.db_pool.updates._end_background_update(
|
||||
self.DELETE_SOFT_FAILED_EXTREMITIES
|
||||
_BackgroundUpdates.DELETE_SOFT_FAILED_EXTREMITIES
|
||||
)
|
||||
|
||||
def _drop_table_txn(txn):
|
||||
|
|
@ -1009,3 +1087,81 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
|||
await self.db_pool.updates._end_background_update("purged_chain_cover")
|
||||
|
||||
return result
|
||||
|
||||
async def _background_populate_stream_ordering2(
|
||||
self, progress: JsonDict, batch_size: int
|
||||
) -> int:
|
||||
"""Populate events.stream_ordering2, then replace stream_ordering
|
||||
|
||||
This is to deal with the fact that stream_ordering was initially created as a
|
||||
32-bit integer field.
|
||||
"""
|
||||
batch_size = max(batch_size, 1)
|
||||
|
||||
def process(txn: Cursor) -> int:
|
||||
last_stream = progress.get("last_stream", -(1 << 31))
|
||||
txn.execute(
|
||||
"""
|
||||
UPDATE events SET stream_ordering2=stream_ordering
|
||||
WHERE stream_ordering IN (
|
||||
SELECT stream_ordering FROM events WHERE stream_ordering > ?
|
||||
ORDER BY stream_ordering LIMIT ?
|
||||
)
|
||||
RETURNING stream_ordering;
|
||||
""",
|
||||
(last_stream, batch_size),
|
||||
)
|
||||
row_count = txn.rowcount
|
||||
if row_count == 0:
|
||||
return 0
|
||||
last_stream = max(row[0] for row in txn)
|
||||
logger.info("populated stream_ordering2 up to %i", last_stream)
|
||||
|
||||
self.db_pool.updates._background_update_progress_txn(
|
||||
txn,
|
||||
_BackgroundUpdates.POPULATE_STREAM_ORDERING2,
|
||||
{"last_stream": last_stream},
|
||||
)
|
||||
return row_count
|
||||
|
||||
result = await self.db_pool.runInteraction(
|
||||
"_background_populate_stream_ordering2", process
|
||||
)
|
||||
|
||||
if result != 0:
|
||||
return result
|
||||
|
||||
await self.db_pool.updates._end_background_update(
|
||||
_BackgroundUpdates.POPULATE_STREAM_ORDERING2
|
||||
)
|
||||
return 0
|
||||
|
||||
async def _background_replace_stream_ordering_column(
|
||||
self, progress: JsonDict, batch_size: int
|
||||
) -> int:
|
||||
"""Drop the old 'stream_ordering' column and rename 'stream_ordering2' into its place."""
|
||||
|
||||
def process(txn: Cursor) -> None:
|
||||
for sql in _REPLACE_STREAM_ORDERING_SQL_COMMANDS:
|
||||
logger.info("completing stream_ordering migration: %s", sql)
|
||||
txn.execute(sql)
|
||||
|
||||
# ANALYZE the new column to build stats on it, to encourage PostgreSQL to use the
|
||||
# indexes on it.
|
||||
# We need to pass execute a dummy function to handle the txn's result otherwise
|
||||
# it tries to call fetchall() on it and fails because there's no result to fetch.
|
||||
await self.db_pool.execute(
|
||||
"background_analyze_new_stream_ordering_column",
|
||||
lambda txn: None,
|
||||
"ANALYZE events(stream_ordering2)",
|
||||
)
|
||||
|
||||
await self.db_pool.runInteraction(
|
||||
"_background_replace_stream_ordering_column", process
|
||||
)
|
||||
|
||||
await self.db_pool.updates._end_background_update(
|
||||
_BackgroundUpdates.REPLACE_STREAM_ORDERING_COLUMN
|
||||
)
|
||||
|
||||
return 0
|
||||
|
|
|
|||
|
|
@ -73,20 +73,20 @@ class ProfileWorkerStore(SQLBaseStore):
|
|||
async def set_profile_displayname(
|
||||
self, user_localpart: str, new_displayname: Optional[str]
|
||||
) -> None:
|
||||
await self.db_pool.simple_update_one(
|
||||
await self.db_pool.simple_upsert(
|
||||
table="profiles",
|
||||
keyvalues={"user_id": user_localpart},
|
||||
updatevalues={"displayname": new_displayname},
|
||||
values={"displayname": new_displayname},
|
||||
desc="set_profile_displayname",
|
||||
)
|
||||
|
||||
async def set_profile_avatar_url(
|
||||
self, user_localpart: str, new_avatar_url: Optional[str]
|
||||
) -> None:
|
||||
await self.db_pool.simple_update_one(
|
||||
await self.db_pool.simple_upsert(
|
||||
table="profiles",
|
||||
keyvalues={"user_id": user_localpart},
|
||||
updatevalues={"avatar_url": new_avatar_url},
|
||||
values={"avatar_url": new_avatar_url},
|
||||
desc="set_profile_avatar_url",
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -53,6 +53,9 @@ class TokenLookupResult:
|
|||
valid_until_ms: The timestamp the token expires, if any.
|
||||
token_owner: The "owner" of the token. This is either the same as the
|
||||
user, or a server admin who is logged in as the user.
|
||||
token_used: True if this token was used at least once in a request.
|
||||
This field can be out of date since `get_user_by_access_token` is
|
||||
cached.
|
||||
"""
|
||||
|
||||
user_id = attr.ib(type=str)
|
||||
|
|
@ -62,6 +65,7 @@ class TokenLookupResult:
|
|||
device_id = attr.ib(type=Optional[str], default=None)
|
||||
valid_until_ms = attr.ib(type=Optional[int], default=None)
|
||||
token_owner = attr.ib(type=str)
|
||||
token_used = attr.ib(type=bool, default=False)
|
||||
|
||||
# Make the token owner default to the user ID, which is the common case.
|
||||
@token_owner.default
|
||||
|
|
@ -69,6 +73,29 @@ class TokenLookupResult:
|
|||
return self.user_id
|
||||
|
||||
|
||||
@attr.s(frozen=True, slots=True)
|
||||
class RefreshTokenLookupResult:
|
||||
"""Result of looking up a refresh token."""
|
||||
|
||||
user_id = attr.ib(type=str)
|
||||
"""The user this token belongs to."""
|
||||
|
||||
device_id = attr.ib(type=str)
|
||||
"""The device associated with this refresh token."""
|
||||
|
||||
token_id = attr.ib(type=int)
|
||||
"""The ID of this refresh token."""
|
||||
|
||||
next_token_id = attr.ib(type=Optional[int])
|
||||
"""The ID of the refresh token which replaced this one."""
|
||||
|
||||
has_next_refresh_token_been_refreshed = attr.ib(type=bool)
|
||||
"""True if the next refresh token was used for another refresh."""
|
||||
|
||||
has_next_access_token_been_used = attr.ib(type=bool)
|
||||
"""True if the next access token was already used at least once."""
|
||||
|
||||
|
||||
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -441,7 +468,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
access_tokens.id as token_id,
|
||||
access_tokens.device_id,
|
||||
access_tokens.valid_until_ms,
|
||||
access_tokens.user_id as token_owner
|
||||
access_tokens.user_id as token_owner,
|
||||
access_tokens.used as token_used
|
||||
FROM users
|
||||
INNER JOIN access_tokens on users.name = COALESCE(puppets_user_id, access_tokens.user_id)
|
||||
WHERE token = ?
|
||||
|
|
@ -449,8 +477,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
|
||||
txn.execute(sql, (token,))
|
||||
rows = self.db_pool.cursor_to_dict(txn)
|
||||
|
||||
if rows:
|
||||
return TokenLookupResult(**rows[0])
|
||||
row = rows[0]
|
||||
|
||||
# This field is nullable, ensure it comes out as a boolean
|
||||
if row["token_used"] is None:
|
||||
row["token_used"] = False
|
||||
|
||||
return TokenLookupResult(**row)
|
||||
|
||||
return None
|
||||
|
||||
|
|
@ -1072,6 +1107,111 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
desc="update_access_token_last_validated",
|
||||
)
|
||||
|
||||
@cached()
|
||||
async def mark_access_token_as_used(self, token_id: int) -> None:
|
||||
"""
|
||||
Mark the access token as used, which invalidates the refresh token used
|
||||
to obtain it.
|
||||
|
||||
Because get_user_by_access_token is cached, this function might be
|
||||
called multiple times for the same token, effectively doing unnecessary
|
||||
SQL updates. Because updating the `used` field only goes one way (from
|
||||
False to True) it is safe to cache this function as well to avoid this
|
||||
issue.
|
||||
|
||||
Args:
|
||||
token_id: The ID of the access token to update.
|
||||
Raises:
|
||||
StoreError if there was a problem updating this.
|
||||
"""
|
||||
await self.db_pool.simple_update_one(
|
||||
"access_tokens",
|
||||
{"id": token_id},
|
||||
{"used": True},
|
||||
desc="mark_access_token_as_used",
|
||||
)
|
||||
|
||||
async def lookup_refresh_token(
|
||||
self, token: str
|
||||
) -> Optional[RefreshTokenLookupResult]:
|
||||
"""Lookup a refresh token with hints about its validity."""
|
||||
|
||||
def _lookup_refresh_token_txn(txn) -> Optional[RefreshTokenLookupResult]:
|
||||
txn.execute(
|
||||
"""
|
||||
SELECT
|
||||
rt.id token_id,
|
||||
rt.user_id,
|
||||
rt.device_id,
|
||||
rt.next_token_id,
|
||||
(nrt.next_token_id IS NOT NULL) has_next_refresh_token_been_refreshed,
|
||||
at.used has_next_access_token_been_used
|
||||
FROM refresh_tokens rt
|
||||
LEFT JOIN refresh_tokens nrt ON rt.next_token_id = nrt.id
|
||||
LEFT JOIN access_tokens at ON at.refresh_token_id = nrt.id
|
||||
WHERE rt.token = ?
|
||||
""",
|
||||
(token,),
|
||||
)
|
||||
row = txn.fetchone()
|
||||
|
||||
if row is None:
|
||||
return None
|
||||
|
||||
return RefreshTokenLookupResult(
|
||||
token_id=row[0],
|
||||
user_id=row[1],
|
||||
device_id=row[2],
|
||||
next_token_id=row[3],
|
||||
has_next_refresh_token_been_refreshed=row[4],
|
||||
# This column is nullable, ensure it's a boolean
|
||||
has_next_access_token_been_used=(row[5] or False),
|
||||
)
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"lookup_refresh_token", _lookup_refresh_token_txn
|
||||
)
|
||||
|
||||
async def replace_refresh_token(self, token_id: int, next_token_id: int) -> None:
|
||||
"""
|
||||
Set the successor of a refresh token, removing the existing successor
|
||||
if any.
|
||||
|
||||
Args:
|
||||
token_id: ID of the refresh token to update.
|
||||
next_token_id: ID of its successor.
|
||||
"""
|
||||
|
||||
def _replace_refresh_token_txn(txn) -> None:
|
||||
# First check if there was an existing refresh token
|
||||
old_next_token_id = self.db_pool.simple_select_one_onecol_txn(
|
||||
txn,
|
||||
"refresh_tokens",
|
||||
{"id": token_id},
|
||||
"next_token_id",
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
self.db_pool.simple_update_one_txn(
|
||||
txn,
|
||||
"refresh_tokens",
|
||||
{"id": token_id},
|
||||
{"next_token_id": next_token_id},
|
||||
)
|
||||
|
||||
# Delete the old "next" token if it exists. This should cascade and
|
||||
# delete the associated access_token
|
||||
if old_next_token_id is not None:
|
||||
self.db_pool.simple_delete_one_txn(
|
||||
txn,
|
||||
"refresh_tokens",
|
||||
{"id": old_next_token_id},
|
||||
)
|
||||
|
||||
await self.db_pool.runInteraction(
|
||||
"replace_refresh_token", _replace_refresh_token_txn
|
||||
)
|
||||
|
||||
|
||||
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
|
||||
def __init__(
|
||||
|
|
@ -1263,6 +1403,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
|||
self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
|
||||
|
||||
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
|
||||
self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
|
||||
|
||||
async def add_access_token_to_user(
|
||||
self,
|
||||
|
|
@ -1271,14 +1412,18 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
|||
device_id: Optional[str],
|
||||
valid_until_ms: Optional[int],
|
||||
puppets_user_id: Optional[str] = None,
|
||||
refresh_token_id: Optional[int] = None,
|
||||
) -> int:
|
||||
"""Adds an access token for the given user.
|
||||
|
||||
Args:
|
||||
user_id: The user ID.
|
||||
token: The new access token to add.
|
||||
device_id: ID of the device to associate with the access token
|
||||
device_id: ID of the device to associate with the access token.
|
||||
valid_until_ms: when the token is valid until. None for no expiry.
|
||||
puppets_user_id
|
||||
refresh_token_id: ID of the refresh token generated alongside this
|
||||
access token.
|
||||
Raises:
|
||||
StoreError if there was a problem adding this.
|
||||
Returns:
|
||||
|
|
@ -1297,12 +1442,47 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
|||
"valid_until_ms": valid_until_ms,
|
||||
"puppets_user_id": puppets_user_id,
|
||||
"last_validated": now,
|
||||
"refresh_token_id": refresh_token_id,
|
||||
"used": False,
|
||||
},
|
||||
desc="add_access_token_to_user",
|
||||
)
|
||||
|
||||
return next_id
|
||||
|
||||
async def add_refresh_token_to_user(
|
||||
self,
|
||||
user_id: str,
|
||||
token: str,
|
||||
device_id: Optional[str],
|
||||
) -> int:
|
||||
"""Adds a refresh token for the given user.
|
||||
|
||||
Args:
|
||||
user_id: The user ID.
|
||||
token: The new access token to add.
|
||||
device_id: ID of the device to associate with the refresh token.
|
||||
Raises:
|
||||
StoreError if there was a problem adding this.
|
||||
Returns:
|
||||
The token ID
|
||||
"""
|
||||
next_id = self._refresh_tokens_id_gen.get_next()
|
||||
|
||||
await self.db_pool.simple_insert(
|
||||
"refresh_tokens",
|
||||
{
|
||||
"id": next_id,
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
"token": token,
|
||||
"next_token_id": None,
|
||||
},
|
||||
desc="add_refresh_token_to_user",
|
||||
)
|
||||
|
||||
return next_id
|
||||
|
||||
def _set_device_for_access_token_txn(self, txn, token: str, device_id: str) -> str:
|
||||
old_device_id = self.db_pool.simple_select_one_onecol_txn(
|
||||
txn, "access_tokens", {"token": token}, "device_id"
|
||||
|
|
@ -1545,7 +1725,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
|||
device_id: Optional[str] = None,
|
||||
) -> List[Tuple[str, int, Optional[str]]]:
|
||||
"""
|
||||
Invalidate access tokens belonging to a user
|
||||
Invalidate access and refresh tokens belonging to a user
|
||||
|
||||
Args:
|
||||
user_id: ID of user the tokens belong to
|
||||
|
|
@ -1565,7 +1745,13 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
|||
items = keyvalues.items()
|
||||
where_clause = " AND ".join(k + " = ?" for k, _ in items)
|
||||
values = [v for _, v in items] # type: List[Union[str, int]]
|
||||
# Conveniently, refresh_tokens and access_tokens both use the user_id and device_id fields. Only caveat
|
||||
# is the `except_token_id` param that is tricky to get right, so for now we're just using the same where
|
||||
# clause and values before we handle that. This seems to be only used in the "set password" handler.
|
||||
refresh_where_clause = where_clause
|
||||
refresh_values = values.copy()
|
||||
if except_token_id:
|
||||
# TODO: support that for refresh tokens
|
||||
where_clause += " AND id != ?"
|
||||
values.append(except_token_id)
|
||||
|
||||
|
|
@ -1583,6 +1769,11 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
|||
|
||||
txn.execute("DELETE FROM access_tokens WHERE %s" % where_clause, values)
|
||||
|
||||
txn.execute(
|
||||
"DELETE FROM refresh_tokens WHERE %s" % refresh_where_clause,
|
||||
refresh_values,
|
||||
)
|
||||
|
||||
return tokens_and_devices
|
||||
|
||||
return await self.db_pool.runInteraction("user_delete_access_tokens", f)
|
||||
|
|
@ -1599,6 +1790,14 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
|||
|
||||
await self.db_pool.runInteraction("delete_access_token", f)
|
||||
|
||||
async def delete_refresh_token(self, refresh_token: str) -> None:
|
||||
def f(txn):
|
||||
self.db_pool.simple_delete_one_txn(
|
||||
txn, table="refresh_tokens", keyvalues={"token": refresh_token}
|
||||
)
|
||||
|
||||
await self.db_pool.runInteraction("delete_refresh_token", f)
|
||||
|
||||
async def add_user_pending_deactivation(self, user_id: str) -> None:
|
||||
"""
|
||||
Adds a user to the table of users who need to be parted from all the rooms they're
|
||||
|
|
|
|||
|
|
@ -49,6 +49,12 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
|
|||
"""
|
||||
...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def supports_returning(self) -> bool:
|
||||
"""Do we support the `RETURNING` clause in insert/update/delete?"""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def check_database(
|
||||
self, db_conn: ConnectionType, allow_outdated_version: bool = False
|
||||
|
|
|
|||
|
|
@ -133,6 +133,11 @@ class PostgresEngine(BaseDatabaseEngine):
|
|||
"""Do we support using `a = ANY(?)` and passing a list"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def supports_returning(self) -> bool:
|
||||
"""Do we support the `RETURNING` clause in insert/update/delete?"""
|
||||
return True
|
||||
|
||||
def is_deadlock(self, error):
|
||||
if isinstance(error, self.module.DatabaseError):
|
||||
# https://www.postgresql.org/docs/current/static/errcodes-appendix.html
|
||||
|
|
|
|||
|
|
@ -60,6 +60,11 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
|
|||
"""Do we support using `a = ANY(?)` and passing a list"""
|
||||
return False
|
||||
|
||||
@property
|
||||
def supports_returning(self) -> bool:
|
||||
"""Do we support the `RETURNING` clause in insert/update/delete?"""
|
||||
return self.module.sqlite_version_info >= (3, 35, 0)
|
||||
|
||||
def check_database(self, db_conn, allow_outdated_version: bool = False):
|
||||
if not allow_outdated_version:
|
||||
version = self.module.sqlite_version_info
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
SCHEMA_VERSION = 59
|
||||
SCHEMA_VERSION = 60
|
||||
"""Represents the expectations made by the codebase about the database schema
|
||||
|
||||
This should be incremented whenever the codebase changes its requirements on the
|
||||
|
|
|
|||
34
synapse/storage/schema/main/delta/59/14refresh_tokens.sql
Normal file
34
synapse/storage/schema/main/delta/59/14refresh_tokens.sql
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
/* Copyright 2021 The Matrix.org Foundation C.I.C
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
-- Holds MSC2918 refresh tokens
|
||||
CREATE TABLE refresh_tokens (
|
||||
id BIGINT PRIMARY KEY,
|
||||
user_id TEXT NOT NULL,
|
||||
device_id TEXT NOT NULL,
|
||||
token TEXT NOT NULL,
|
||||
-- When consumed, a new refresh token is generated, which is tracked by
|
||||
-- this foreign key
|
||||
next_token_id BIGINT REFERENCES refresh_tokens (id) ON DELETE CASCADE,
|
||||
UNIQUE(token)
|
||||
);
|
||||
|
||||
-- Add a reference to the refresh token generated alongside each access token
|
||||
ALTER TABLE "access_tokens"
|
||||
ADD COLUMN refresh_token_id BIGINT REFERENCES refresh_tokens (id) ON DELETE CASCADE;
|
||||
|
||||
-- Add a flag whether the token was already used or not
|
||||
ALTER TABLE "access_tokens"
|
||||
ADD COLUMN used BOOLEAN;
|
||||
|
|
@ -0,0 +1,45 @@
|
|||
/* Copyright 2021 The Matrix.org Foundation C.I.C
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
-- This migration handles the process of changing the type of `stream_ordering` to
|
||||
-- a BIGINT.
|
||||
--
|
||||
-- Note that this is only a problem on postgres as sqlite only has one "integer" type
|
||||
-- which can cope with values up to 2^63.
|
||||
|
||||
-- First add a new column to contain the bigger stream_ordering
|
||||
ALTER TABLE events ADD COLUMN stream_ordering2 BIGINT;
|
||||
|
||||
-- Create a rule which will populate it for new rows.
|
||||
CREATE OR REPLACE RULE "populate_stream_ordering2" AS
|
||||
ON INSERT TO events
|
||||
DO UPDATE events SET stream_ordering2=NEW.stream_ordering WHERE stream_ordering=NEW.stream_ordering;
|
||||
|
||||
-- Start a bg process to populate it for old events
|
||||
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
|
||||
(6001, 'populate_stream_ordering2', '{}');
|
||||
|
||||
-- ... and some more to build indexes on it. These aren't really interdependent
|
||||
-- but the backround_updates manager can only handle a single dependency per update.
|
||||
INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES
|
||||
(6001, 'index_stream_ordering2', '{}', 'populate_stream_ordering2'),
|
||||
(6001, 'index_stream_ordering2_room_order', '{}', 'index_stream_ordering2'),
|
||||
(6001, 'index_stream_ordering2_contains_url', '{}', 'index_stream_ordering2_room_order'),
|
||||
(6001, 'index_stream_ordering2_room_stream', '{}', 'index_stream_ordering2_contains_url'),
|
||||
(6001, 'index_stream_ordering2_ts', '{}', 'index_stream_ordering2_room_stream');
|
||||
|
||||
-- ... and another to do the switcheroo
|
||||
INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES
|
||||
(6001, 'replace_stream_ordering_column', '{}', 'index_stream_ordering2_ts');
|
||||
|
|
@ -0,0 +1,30 @@
|
|||
/* Copyright 2021 The Matrix.org Foundation C.I.C
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
-- This migration is closely related to '01recreate_stream_ordering.sql.postgres'.
|
||||
--
|
||||
-- It updates the other tables which use an INTEGER to refer to a stream ordering.
|
||||
-- These tables are all small enough that a re-create is tractable.
|
||||
ALTER TABLE pushers ALTER COLUMN last_stream_ordering SET DATA TYPE BIGINT;
|
||||
ALTER TABLE federation_stream_position ALTER COLUMN stream_id SET DATA TYPE BIGINT;
|
||||
|
||||
-- these aren't actually event stream orderings, but they are numbers where 2 billion
|
||||
-- is a bit limiting, application_services_state is tiny, and I don't want to ever have
|
||||
-- to do this again.
|
||||
ALTER TABLE application_services_state ALTER COLUMN last_txn SET DATA TYPE BIGINT;
|
||||
ALTER TABLE application_services_state ALTER COLUMN read_receipt_stream_id SET DATA TYPE BIGINT;
|
||||
ALTER TABLE application_services_state ALTER COLUMN presence_stream_id SET DATA TYPE BIGINT;
|
||||
|
||||
|
||||
Loading…
Add table
Add a link
Reference in a new issue