Remove many calls to cursor_to_dict (#16431)

This avoids calling cursor_to_dict and then immediately
unpacking the values in the dict for other users. By not
creating the intermediate dictionary we can avoid allocating
the dictionary and strings for the keys, which should generally
be more performant.

Additionally this improves type hints by avoiding Dict[str, Any]
dictionaries coming out of the database layer.
This commit is contained in:
Patrick Cloke 2023-10-05 11:07:38 -04:00 committed by GitHub
parent 4e302b30b6
commit fa907025f4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 319 additions and 227 deletions

View File

@ -1 +1 @@
Reduce the size of each replication command instance. Reduce memory allocations.

1
changelog.d/16431.misc Normal file
View File

@ -0,0 +1 @@
Reduce memory allocations.

View File

@ -101,7 +101,7 @@ if TYPE_CHECKING:
class PusherConfig: class PusherConfig:
"""Parameters necessary to configure a pusher.""" """Parameters necessary to configure a pusher."""
id: Optional[str] id: Optional[int]
user_name: str user_name: str
profile_tag: str profile_tag: str

View File

@ -151,10 +151,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
sql += " AND content != '{}'" sql += " AND content != '{}'"
txn.execute(sql, (user_id,)) txn.execute(sql, (user_id,))
rows = self.db_pool.cursor_to_dict(txn)
return { return {
row["account_data_type"]: db_to_json(row["content"]) for row in rows account_data_type: db_to_json(content)
for account_data_type, content in txn
} }
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
@ -196,13 +196,12 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
sql += " AND content != '{}'" sql += " AND content != '{}'"
txn.execute(sql, (user_id,)) txn.execute(sql, (user_id,))
rows = self.db_pool.cursor_to_dict(txn)
by_room: Dict[str, Dict[str, JsonDict]] = {} by_room: Dict[str, Dict[str, JsonDict]] = {}
for row in rows: for room_id, account_data_type, content in txn:
room_data = by_room.setdefault(row["room_id"], {}) room_data = by_room.setdefault(room_id, {})
room_data[row["account_data_type"]] = db_to_json(row["content"]) room_data[account_data_type] = db_to_json(content)
return by_room return by_room

View File

@ -14,17 +14,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
import re import re
from typing import ( from typing import TYPE_CHECKING, List, Optional, Pattern, Sequence, Tuple, cast
TYPE_CHECKING,
Any,
Dict,
List,
Optional,
Pattern,
Sequence,
Tuple,
cast,
)
from synapse.appservice import ( from synapse.appservice import (
ApplicationService, ApplicationService,
@ -353,21 +343,15 @@ class ApplicationServiceTransactionWorkerStore(
def _get_oldest_unsent_txn( def _get_oldest_unsent_txn(
txn: LoggingTransaction, txn: LoggingTransaction,
) -> Optional[Dict[str, Any]]: ) -> Optional[Tuple[int, str]]:
# Monotonically increasing txn ids, so just select the smallest # Monotonically increasing txn ids, so just select the smallest
# one in the txns table (we delete them when they are sent) # one in the txns table (we delete them when they are sent)
txn.execute( txn.execute(
"SELECT * FROM application_services_txns WHERE as_id=?" "SELECT txn_id, event_ids FROM application_services_txns WHERE as_id=?"
" ORDER BY txn_id ASC LIMIT 1", " ORDER BY txn_id ASC LIMIT 1",
(service.id,), (service.id,),
) )
rows = self.db_pool.cursor_to_dict(txn) return cast(Optional[Tuple[int, str]], txn.fetchone())
if not rows:
return None
entry = rows[0]
return entry
entry = await self.db_pool.runInteraction( entry = await self.db_pool.runInteraction(
"get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn "get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn
@ -376,8 +360,9 @@ class ApplicationServiceTransactionWorkerStore(
if not entry: if not entry:
return None return None
event_ids = db_to_json(entry["event_ids"]) txn_id, event_ids_str = entry
event_ids = db_to_json(event_ids_str)
events = await self.get_events_as_list(event_ids) events = await self.get_events_as_list(event_ids)
# TODO: to-device messages, one-time key counts, device list summaries and unused # TODO: to-device messages, one-time key counts, device list summaries and unused
@ -385,7 +370,7 @@ class ApplicationServiceTransactionWorkerStore(
# We likely want to populate those for reliability. # We likely want to populate those for reliability.
return AppServiceTransaction( return AppServiceTransaction(
service=service, service=service,
id=entry["txn_id"], id=txn_id,
events=events, events=events,
ephemeral=[], ephemeral=[],
to_device_messages=[], to_device_messages=[],

View File

@ -1413,13 +1413,13 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
def get_devices_not_accessed_since_txn( def get_devices_not_accessed_since_txn(
txn: LoggingTransaction, txn: LoggingTransaction,
) -> List[Dict[str, str]]: ) -> List[Tuple[str, str]]:
sql = """ sql = """
SELECT user_id, device_id SELECT user_id, device_id
FROM devices WHERE last_seen < ? AND hidden = FALSE FROM devices WHERE last_seen < ? AND hidden = FALSE
""" """
txn.execute(sql, (since_ms,)) txn.execute(sql, (since_ms,))
return self.db_pool.cursor_to_dict(txn) return cast(List[Tuple[str, str]], txn.fetchall())
rows = await self.db_pool.runInteraction( rows = await self.db_pool.runInteraction(
"get_devices_not_accessed_since", "get_devices_not_accessed_since",
@ -1427,11 +1427,11 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
) )
devices: Dict[str, List[str]] = {} devices: Dict[str, List[str]] = {}
for row in rows: for user_id, device_id in rows:
# Remote devices are never stale from our point of view. # Remote devices are never stale from our point of view.
if self.hs.is_mine_id(row["user_id"]): if self.hs.is_mine_id(user_id):
user_devices = devices.setdefault(row["user_id"], []) user_devices = devices.setdefault(user_id, [])
user_devices.append(row["device_id"]) user_devices.append(device_id)
return devices return devices

View File

@ -921,14 +921,10 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
} }
txn.execute(sql, params) txn.execute(sql, params)
rows = self.db_pool.cursor_to_dict(txn)
for row in rows: for user_id, key_type, key_data, _ in txn:
user_id = row["user_id"]
key_type = row["keytype"]
key = db_to_json(row["keydata"])
user_keys = result.setdefault(user_id, {}) user_keys = result.setdefault(user_id, {})
user_keys[key_type] = key user_keys[key_type] = db_to_json(key_data)
return result return result
@ -988,13 +984,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
query_params.extend(item) query_params.extend(item)
txn.execute(sql, query_params) txn.execute(sql, query_params)
rows = self.db_pool.cursor_to_dict(txn)
# and add the signatures to the appropriate keys # and add the signatures to the appropriate keys
for row in rows: for target_user_id, target_device_id, key_id, signature in txn:
key_id: str = row["key_id"]
target_user_id: str = row["target_user_id"]
target_device_id: str = row["target_device_id"]
key_type = devices[(target_user_id, target_device_id)] key_type = devices[(target_user_id, target_device_id)]
# We need to copy everything, because the result may have come # We need to copy everything, because the result may have come
# from the cache. dict.copy only does a shallow copy, so we # from the cache. dict.copy only does a shallow copy, so we
@ -1012,13 +1004,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
].copy() ].copy()
if from_user_id in signatures: if from_user_id in signatures:
user_sigs = signatures[from_user_id] = signatures[from_user_id] user_sigs = signatures[from_user_id] = signatures[from_user_id]
user_sigs[key_id] = row["signature"] user_sigs[key_id] = signature
else: else:
signatures[from_user_id] = {key_id: row["signature"]} signatures[from_user_id] = {key_id: signature}
else: else:
target_user_key["signatures"] = { target_user_key["signatures"] = {from_user_id: {key_id: signature}}
from_user_id: {key_id: row["signature"]}
}
return keys return keys

View File

@ -1654,8 +1654,6 @@ class PersistEventsStore:
) -> None: ) -> None:
to_prefill = [] to_prefill = []
rows = []
ev_map = {e.event_id: e for e, _ in events_and_contexts} ev_map = {e.event_id: e for e, _ in events_and_contexts}
if not ev_map: if not ev_map:
return return
@ -1676,10 +1674,9 @@ class PersistEventsStore:
) )
txn.execute(sql + clause, args) txn.execute(sql + clause, args)
rows = self.db_pool.cursor_to_dict(txn) for event_id, redacts, rejects in txn:
for row in rows: event = ev_map[event_id]
event = ev_map[row["event_id"]] if not rejects and not redacts:
if not row["rejects"] and not row["redacts"]:
to_prefill.append(EventCacheEntry(event=event, redacted_event=None)) to_prefill.append(EventCacheEntry(event=event, redacted_event=None))
async def external_prefill() -> None: async def external_prefill() -> None:

View File

@ -434,13 +434,21 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
txn = db_conn.cursor() txn = db_conn.cursor()
txn.execute(sql, (PresenceState.OFFLINE,)) txn.execute(sql, (PresenceState.OFFLINE,))
rows = self.db_pool.cursor_to_dict(txn) rows = txn.fetchall()
txn.close() txn.close()
for row in rows: return [
row["currently_active"] = bool(row["currently_active"]) UserPresenceState(
user_id=user_id,
return [UserPresenceState(**row) for row in rows] state=state,
last_active_ts=last_active_ts,
last_federation_update_ts=last_federation_update_ts,
last_user_sync_ts=last_user_sync_ts,
status_msg=status_msg,
currently_active=bool(currently_active),
)
for user_id, state, last_active_ts, last_federation_update_ts, last_user_sync_ts, status_msg, currently_active in rows
]
def take_presence_startup_info(self) -> List[UserPresenceState]: def take_presence_startup_info(self) -> List[UserPresenceState]:
active_on_startup = self._presence_on_startup active_on_startup = self._presence_on_startup

View File

@ -47,6 +47,27 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# The type of a row in the pushers table.
PusherRow = Tuple[
int, # id
str, # user_name
Optional[int], # access_token
str, # profile_tag
str, # kind
str, # app_id
str, # app_display_name
str, # device_display_name
str, # pushkey
int, # ts
str, # lang
str, # data
int, # last_stream_ordering
int, # last_success
int, # failing_since
bool, # enabled
str, # device_id
]
class PusherWorkerStore(SQLBaseStore): class PusherWorkerStore(SQLBaseStore):
def __init__( def __init__(
@ -83,30 +104,66 @@ class PusherWorkerStore(SQLBaseStore):
self._remove_deleted_email_pushers, self._remove_deleted_email_pushers,
) )
def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[PusherConfig]: def _decode_pushers_rows(
self,
rows: Iterable[PusherRow],
) -> Iterator[PusherConfig]:
"""JSON-decode the data in the rows returned from the `pushers` table """JSON-decode the data in the rows returned from the `pushers` table
Drops any rows whose data cannot be decoded Drops any rows whose data cannot be decoded
""" """
for r in rows: for (
data_json = r["data"] id,
user_name,
access_token,
profile_tag,
kind,
app_id,
app_display_name,
device_display_name,
pushkey,
ts,
lang,
data,
last_stream_ordering,
last_success,
failing_since,
enabled,
device_id,
) in rows:
try: try:
r["data"] = db_to_json(data_json) data_json = db_to_json(data)
except Exception as e: except Exception as e:
logger.warning( logger.warning(
"Invalid JSON in data for pusher %d: %s, %s", "Invalid JSON in data for pusher %d: %s, %s",
r["id"], id,
data_json, data,
e.args[0], e.args[0],
) )
continue continue
# If we're using SQLite, then boolean values are integers. This is yield PusherConfig(
# troublesome since some code using the return value of this method might id=id,
# expect it to be a boolean, or will expose it to clients (in responses). user_name=user_name,
r["enabled"] = bool(r["enabled"]) profile_tag=profile_tag,
kind=kind,
yield PusherConfig(**r) app_id=app_id,
app_display_name=app_display_name,
device_display_name=device_display_name,
pushkey=pushkey,
ts=ts,
lang=lang,
data=data_json,
last_stream_ordering=last_stream_ordering,
last_success=last_success,
failing_since=failing_since,
# If we're using SQLite, then boolean values are integers. This is
# troublesome since some code using the return value of this method might
# expect it to be a boolean, or will expose it to clients (in responses).
enabled=bool(enabled),
device_id=device_id,
access_token=access_token,
)
def get_pushers_stream_token(self) -> int: def get_pushers_stream_token(self) -> int:
return self._pushers_id_gen.get_current_token() return self._pushers_id_gen.get_current_token()
@ -136,7 +193,7 @@ class PusherWorkerStore(SQLBaseStore):
The pushers for which the given columns have the given values. The pushers for which the given columns have the given values.
""" """
def get_pushers_by_txn(txn: LoggingTransaction) -> List[Dict[str, Any]]: def get_pushers_by_txn(txn: LoggingTransaction) -> List[PusherRow]:
# We could technically use simple_select_list here, but we need to call # We could technically use simple_select_list here, but we need to call
# COALESCE on the 'enabled' column. While it is technically possible to give # COALESCE on the 'enabled' column. While it is technically possible to give
# simple_select_list the whole `COALESCE(...) AS ...` as a column name, it # simple_select_list the whole `COALESCE(...) AS ...` as a column name, it
@ -154,7 +211,7 @@ class PusherWorkerStore(SQLBaseStore):
txn.execute(sql, list(keyvalues.values())) txn.execute(sql, list(keyvalues.values()))
return self.db_pool.cursor_to_dict(txn) return cast(List[PusherRow], txn.fetchall())
ret = await self.db_pool.runInteraction( ret = await self.db_pool.runInteraction(
desc="get_pushers_by", desc="get_pushers_by",
@ -164,14 +221,22 @@ class PusherWorkerStore(SQLBaseStore):
return self._decode_pushers_rows(ret) return self._decode_pushers_rows(ret)
async def get_enabled_pushers(self) -> Iterator[PusherConfig]: async def get_enabled_pushers(self) -> Iterator[PusherConfig]:
def get_enabled_pushers_txn(txn: LoggingTransaction) -> Iterator[PusherConfig]: def get_enabled_pushers_txn(txn: LoggingTransaction) -> List[PusherRow]:
txn.execute("SELECT * FROM pushers WHERE COALESCE(enabled, TRUE)") txn.execute(
rows = self.db_pool.cursor_to_dict(txn) """
SELECT id, user_name, access_token, profile_tag, kind, app_id,
app_display_name, device_display_name, pushkey, ts, lang, data,
last_stream_ordering, last_success, failing_since,
enabled, device_id
FROM pushers WHERE COALESCE(enabled, TRUE)
"""
)
return cast(List[PusherRow], txn.fetchall())
return self._decode_pushers_rows(rows) return self._decode_pushers_rows(
await self.db_pool.runInteraction(
return await self.db_pool.runInteraction( "get_enabled_pushers", get_enabled_pushers_txn
"get_enabled_pushers", get_enabled_pushers_txn )
) )
async def get_all_updated_pushers_rows( async def get_all_updated_pushers_rows(
@ -304,7 +369,7 @@ class PusherWorkerStore(SQLBaseStore):
) )
async def get_throttle_params_by_room( async def get_throttle_params_by_room(
self, pusher_id: str self, pusher_id: int
) -> Dict[str, ThrottleParams]: ) -> Dict[str, ThrottleParams]:
res = await self.db_pool.simple_select_list( res = await self.db_pool.simple_select_list(
"pusher_throttle", "pusher_throttle",
@ -323,7 +388,7 @@ class PusherWorkerStore(SQLBaseStore):
return params_by_room return params_by_room
async def set_throttle_params( async def set_throttle_params(
self, pusher_id: str, room_id: str, params: ThrottleParams self, pusher_id: int, room_id: str, params: ThrottleParams
) -> None: ) -> None:
await self.db_pool.simple_upsert( await self.db_pool.simple_upsert(
"pusher_throttle", "pusher_throttle",
@ -534,7 +599,7 @@ class PusherBackgroundUpdatesStore(SQLBaseStore):
(last_pusher_id, batch_size), (last_pusher_id, batch_size),
) )
rows = self.db_pool.cursor_to_dict(txn) rows = txn.fetchall()
if len(rows) == 0: if len(rows) == 0:
return 0 return 0
@ -550,19 +615,19 @@ class PusherBackgroundUpdatesStore(SQLBaseStore):
txn=txn, txn=txn,
table="pushers", table="pushers",
key_names=("id",), key_names=("id",),
key_values=[(row["pusher_id"],) for row in rows], key_values=[row[0] for row in rows],
value_names=("device_id", "access_token"), value_names=("device_id", "access_token"),
# If there was already a device_id on the pusher, we only want to clear # If there was already a device_id on the pusher, we only want to clear
# the access_token column, so we keep the existing device_id. Otherwise, # the access_token column, so we keep the existing device_id. Otherwise,
# we set the device_id we got from joining the access_tokens table. # we set the device_id we got from joining the access_tokens table.
value_values=[ value_values=[
(row["pusher_device_id"] or row["token_device_id"], None) (pusher_device_id or token_device_id, None)
for row in rows for _, pusher_device_id, token_device_id in rows
], ],
) )
self.db_pool.updates._background_update_progress_txn( self.db_pool.updates._background_update_progress_txn(
txn, "set_device_id_for_pushers", {"pusher_id": rows[-1]["pusher_id"]} txn, "set_device_id_for_pushers", {"pusher_id": rows[-1][0]}
) )
return len(rows) return len(rows)

View File

@ -313,25 +313,25 @@ class ReceiptsWorkerStore(SQLBaseStore):
) -> Sequence[JsonMapping]: ) -> Sequence[JsonMapping]:
"""See get_linearized_receipts_for_room""" """See get_linearized_receipts_for_room"""
def f(txn: LoggingTransaction) -> List[Dict[str, Any]]: def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str]]:
if from_key: if from_key:
sql = ( sql = (
"SELECT * FROM receipts_linearized WHERE" "SELECT receipt_type, user_id, event_id, data"
" FROM receipts_linearized WHERE"
" room_id = ? AND stream_id > ? AND stream_id <= ?" " room_id = ? AND stream_id > ? AND stream_id <= ?"
) )
txn.execute(sql, (room_id, from_key, to_key)) txn.execute(sql, (room_id, from_key, to_key))
else: else:
sql = ( sql = (
"SELECT * FROM receipts_linearized WHERE" "SELECT receipt_type, user_id, event_id, data"
" FROM receipts_linearized WHERE"
" room_id = ? AND stream_id <= ?" " room_id = ? AND stream_id <= ?"
) )
txn.execute(sql, (room_id, to_key)) txn.execute(sql, (room_id, to_key))
rows = self.db_pool.cursor_to_dict(txn) return cast(List[Tuple[str, str, str, str]], txn.fetchall())
return rows
rows = await self.db_pool.runInteraction("get_linearized_receipts_for_room", f) rows = await self.db_pool.runInteraction("get_linearized_receipts_for_room", f)
@ -339,10 +339,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
return [] return []
content: JsonDict = {} content: JsonDict = {}
for row in rows: for receipt_type, user_id, event_id, data in rows:
content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[ content.setdefault(event_id, {}).setdefault(receipt_type, {})[
row["user_id"] user_id
] = db_to_json(row["data"]) ] = db_to_json(data)
return [{"type": EduTypes.RECEIPT, "room_id": room_id, "content": content}] return [{"type": EduTypes.RECEIPT, "room_id": room_id, "content": content}]
@ -357,10 +357,13 @@ class ReceiptsWorkerStore(SQLBaseStore):
if not room_ids: if not room_ids:
return {} return {}
def f(txn: LoggingTransaction) -> List[Dict[str, Any]]: def f(
txn: LoggingTransaction,
) -> List[Tuple[str, str, str, str, Optional[str], str]]:
if from_key: if from_key:
sql = """ sql = """
SELECT * FROM receipts_linearized WHERE SELECT room_id, receipt_type, user_id, event_id, thread_id, data
FROM receipts_linearized WHERE
stream_id > ? AND stream_id <= ? AND stream_id > ? AND stream_id <= ? AND
""" """
clause, args = make_in_list_sql_clause( clause, args = make_in_list_sql_clause(
@ -370,7 +373,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
txn.execute(sql + clause, [from_key, to_key] + list(args)) txn.execute(sql + clause, [from_key, to_key] + list(args))
else: else:
sql = """ sql = """
SELECT * FROM receipts_linearized WHERE SELECT room_id, receipt_type, user_id, event_id, thread_id, data
FROM receipts_linearized WHERE
stream_id <= ? AND stream_id <= ? AND
""" """
@ -380,29 +384,31 @@ class ReceiptsWorkerStore(SQLBaseStore):
txn.execute(sql + clause, [to_key] + list(args)) txn.execute(sql + clause, [to_key] + list(args))
return self.db_pool.cursor_to_dict(txn) return cast(
List[Tuple[str, str, str, str, Optional[str], str]], txn.fetchall()
)
txn_results = await self.db_pool.runInteraction( txn_results = await self.db_pool.runInteraction(
"_get_linearized_receipts_for_rooms", f "_get_linearized_receipts_for_rooms", f
) )
results: JsonDict = {} results: JsonDict = {}
for row in txn_results: for room_id, receipt_type, user_id, event_id, thread_id, data in txn_results:
# We want a single event per room, since we want to batch the # We want a single event per room, since we want to batch the
# receipts by room, event and type. # receipts by room, event and type.
room_event = results.setdefault( room_event = results.setdefault(
row["room_id"], room_id,
{"type": EduTypes.RECEIPT, "room_id": row["room_id"], "content": {}}, {"type": EduTypes.RECEIPT, "room_id": room_id, "content": {}},
) )
# The content is of the form: # The content is of the form:
# {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. } # {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
event_entry = room_event["content"].setdefault(row["event_id"], {}) event_entry = room_event["content"].setdefault(event_id, {})
receipt_type = event_entry.setdefault(row["receipt_type"], {}) receipt_type_dict = event_entry.setdefault(receipt_type, {})
receipt_type[row["user_id"]] = db_to_json(row["data"]) receipt_type_dict[user_id] = db_to_json(data)
if row["thread_id"]: if thread_id:
receipt_type[row["user_id"]]["thread_id"] = row["thread_id"] receipt_type_dict[user_id]["thread_id"] = thread_id
results = { results = {
room_id: [results[room_id]] if room_id in results else [] room_id: [results[room_id]] if room_id in results else []
@ -428,10 +434,11 @@ class ReceiptsWorkerStore(SQLBaseStore):
A dictionary of roomids to a list of receipts. A dictionary of roomids to a list of receipts.
""" """
def f(txn: LoggingTransaction) -> List[Dict[str, Any]]: def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str, str]]:
if from_key: if from_key:
sql = """ sql = """
SELECT * FROM receipts_linearized WHERE SELECT room_id, receipt_type, user_id, event_id, data
FROM receipts_linearized WHERE
stream_id > ? AND stream_id <= ? stream_id > ? AND stream_id <= ?
ORDER BY stream_id DESC ORDER BY stream_id DESC
LIMIT 100 LIMIT 100
@ -439,7 +446,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
txn.execute(sql, [from_key, to_key]) txn.execute(sql, [from_key, to_key])
else: else:
sql = """ sql = """
SELECT * FROM receipts_linearized WHERE SELECT room_id, receipt_type, user_id, event_id, data
FROM receipts_linearized WHERE
stream_id <= ? stream_id <= ?
ORDER BY stream_id DESC ORDER BY stream_id DESC
LIMIT 100 LIMIT 100
@ -447,27 +455,27 @@ class ReceiptsWorkerStore(SQLBaseStore):
txn.execute(sql, [to_key]) txn.execute(sql, [to_key])
return self.db_pool.cursor_to_dict(txn) return cast(List[Tuple[str, str, str, str, str]], txn.fetchall())
txn_results = await self.db_pool.runInteraction( txn_results = await self.db_pool.runInteraction(
"get_linearized_receipts_for_all_rooms", f "get_linearized_receipts_for_all_rooms", f
) )
results: JsonDict = {} results: JsonDict = {}
for row in txn_results: for room_id, receipt_type, user_id, event_id, data in txn_results:
# We want a single event per room, since we want to batch the # We want a single event per room, since we want to batch the
# receipts by room, event and type. # receipts by room, event and type.
room_event = results.setdefault( room_event = results.setdefault(
row["room_id"], room_id,
{"type": EduTypes.RECEIPT, "room_id": row["room_id"], "content": {}}, {"type": EduTypes.RECEIPT, "room_id": room_id, "content": {}},
) )
# The content is of the form: # The content is of the form:
# {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. } # {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
event_entry = room_event["content"].setdefault(row["event_id"], {}) event_entry = room_event["content"].setdefault(event_id, {})
receipt_type = event_entry.setdefault(row["receipt_type"], {}) receipt_type_dict = event_entry.setdefault(receipt_type, {})
receipt_type[row["user_id"]] = db_to_json(row["data"]) receipt_type_dict[user_id] = db_to_json(data)
return results return results

View File

@ -195,7 +195,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
async def get_user_by_id(self, user_id: str) -> Optional[UserInfo]: async def get_user_by_id(self, user_id: str) -> Optional[UserInfo]:
"""Returns info about the user account, if it exists.""" """Returns info about the user account, if it exists."""
def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]: def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[UserInfo]:
# We could technically use simple_select_one here, but it would not perform # We could technically use simple_select_one here, but it would not perform
# the COALESCEs (unless hacked into the column names), which could yield # the COALESCEs (unless hacked into the column names), which could yield
# confusing results. # confusing results.
@ -213,35 +213,46 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
(user_id,), (user_id,),
) )
rows = self.db_pool.cursor_to_dict(txn) row = txn.fetchone()
if not row:
if len(rows) == 0:
return None return None
return rows[0] (
name,
is_guest,
admin,
consent_version,
consent_ts,
consent_server_notice_sent,
appservice_id,
creation_ts,
user_type,
deactivated,
shadow_banned,
approved,
locked,
) = row
row = await self.db_pool.runInteraction( return UserInfo(
appservice_id=appservice_id,
consent_server_notice_sent=consent_server_notice_sent,
consent_version=consent_version,
consent_ts=consent_ts,
creation_ts=creation_ts,
is_admin=bool(admin),
is_deactivated=bool(deactivated),
is_guest=bool(is_guest),
is_shadow_banned=bool(shadow_banned),
user_id=UserID.from_string(name),
user_type=user_type,
approved=bool(approved),
locked=bool(locked),
)
return await self.db_pool.runInteraction(
desc="get_user_by_id", desc="get_user_by_id",
func=get_user_by_id_txn, func=get_user_by_id_txn,
) )
if row is None:
return None
return UserInfo(
appservice_id=row["appservice_id"],
consent_server_notice_sent=row["consent_server_notice_sent"],
consent_version=row["consent_version"],
consent_ts=row["consent_ts"],
creation_ts=row["creation_ts"],
is_admin=bool(row["admin"]),
is_deactivated=bool(row["deactivated"]),
is_guest=bool(row["is_guest"]),
is_shadow_banned=bool(row["shadow_banned"]),
user_id=UserID.from_string(row["name"]),
user_type=row["user_type"],
approved=bool(row["approved"]),
locked=bool(row["locked"]),
)
async def is_trial_user(self, user_id: str) -> bool: async def is_trial_user(self, user_id: str) -> bool:
"""Checks if user is in the "trial" period, i.e. within the first """Checks if user is in the "trial" period, i.e. within the first
@ -579,16 +590,31 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
""" """
txn.execute(sql, (token,)) txn.execute(sql, (token,))
rows = self.db_pool.cursor_to_dict(txn) row = txn.fetchone()
if rows: if row:
row = rows[0] (
user_id,
is_guest,
shadow_banned,
token_id,
device_id,
valid_until_ms,
token_owner,
token_used,
) = row
# This field is nullable, ensure it comes out as a boolean return TokenLookupResult(
if row["token_used"] is None: user_id=user_id,
row["token_used"] = False is_guest=is_guest,
shadow_banned=shadow_banned,
return TokenLookupResult(**row) token_id=token_id,
device_id=device_id,
valid_until_ms=valid_until_ms,
token_owner=token_owner,
# This field is nullable, ensure it comes out as a boolean
token_used=bool(token_used),
)
return None return None
@ -833,11 +859,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"""Counts all users registered on the homeserver.""" """Counts all users registered on the homeserver."""
def _count_users(txn: LoggingTransaction) -> int: def _count_users(txn: LoggingTransaction) -> int:
txn.execute("SELECT COUNT(*) AS users FROM users") txn.execute("SELECT COUNT(*) FROM users")
rows = self.db_pool.cursor_to_dict(txn) row = txn.fetchone()
if rows: assert row is not None
return rows[0]["users"] return row[0]
return 0
return await self.db_pool.runInteraction("count_users", _count_users) return await self.db_pool.runInteraction("count_users", _count_users)
@ -891,11 +916,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"""Counts all users without a special user_type registered on the homeserver.""" """Counts all users without a special user_type registered on the homeserver."""
def _count_users(txn: LoggingTransaction) -> int: def _count_users(txn: LoggingTransaction) -> int:
txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null") txn.execute("SELECT COUNT(*) FROM users where user_type is null")
rows = self.db_pool.cursor_to_dict(txn) row = txn.fetchone()
if rows: assert row is not None
return rows[0]["users"] return row[0]
return 0
return await self.db_pool.runInteraction("count_real_users", _count_users) return await self.db_pool.runInteraction("count_real_users", _count_users)
@ -1252,12 +1276,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
) )
txn.execute(sql, []) txn.execute(sql, [])
res = self.db_pool.cursor_to_dict(txn) for (name,) in txn.fetchall():
if res: self.set_expiration_date_for_user_txn(txn, name, use_delta=True)
for user in res:
self.set_expiration_date_for_user_txn(
txn, user["name"], use_delta=True
)
await self.db_pool.runInteraction( await self.db_pool.runInteraction(
"get_users_with_no_expiration_date", "get_users_with_no_expiration_date",
@ -1963,11 +1983,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
(user_id,), (user_id,),
) )
rows = self.db_pool.cursor_to_dict(txn) row = txn.fetchone()
assert row is not None
# We cast to bool because the value returned by the database engine might # We cast to bool because the value returned by the database engine might
# be an integer if we're using SQLite. # be an integer if we're using SQLite.
return bool(rows[0]["approved"]) return bool(row[0])
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
desc="is_user_pending_approval", desc="is_user_pending_approval",
@ -2045,22 +2066,22 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
(last_user, batch_size), (last_user, batch_size),
) )
rows = self.db_pool.cursor_to_dict(txn) rows = txn.fetchall()
if not rows: if not rows:
return True, 0 return True, 0
rows_processed_nb = 0 rows_processed_nb = 0
for user in rows: for name, count_tokens, count_threepids in rows:
if not user["count_tokens"] and not user["count_threepids"]: if not count_tokens and not count_threepids:
self.set_user_deactivated_status_txn(txn, user["name"], True) self.set_user_deactivated_status_txn(txn, name, True)
rows_processed_nb += 1 rows_processed_nb += 1
logger.info("Marked %d rows as deactivated", rows_processed_nb) logger.info("Marked %d rows as deactivated", rows_processed_nb)
self.db_pool.updates._background_update_progress_txn( self.db_pool.updates._background_update_progress_txn(
txn, "users_set_deactivated_flag", {"user_id": rows[-1]["name"]} txn, "users_set_deactivated_flag", {"user_id": rows[-1][0]}
) )
if batch_size > len(rows): if batch_size > len(rows):

View File

@ -831,7 +831,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
def get_retention_policy_for_room_txn( def get_retention_policy_for_room_txn(
txn: LoggingTransaction, txn: LoggingTransaction,
) -> List[Dict[str, Optional[int]]]: ) -> Optional[Tuple[Optional[int], Optional[int]]]:
txn.execute( txn.execute(
""" """
SELECT min_lifetime, max_lifetime FROM room_retention SELECT min_lifetime, max_lifetime FROM room_retention
@ -841,7 +841,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
(room_id,), (room_id,),
) )
return self.db_pool.cursor_to_dict(txn) return cast(Optional[Tuple[Optional[int], Optional[int]]], txn.fetchone())
ret = await self.db_pool.runInteraction( ret = await self.db_pool.runInteraction(
"get_retention_policy_for_room", "get_retention_policy_for_room",
@ -856,8 +856,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
max_lifetime=self.config.retention.retention_default_max_lifetime, max_lifetime=self.config.retention.retention_default_max_lifetime,
) )
min_lifetime = ret[0]["min_lifetime"] min_lifetime, max_lifetime = ret
max_lifetime = ret[0]["max_lifetime"]
# If one of the room's policy's attributes isn't defined, use the matching # If one of the room's policy's attributes isn't defined, use the matching
# attribute from the default policy. # attribute from the default policy.
@ -1162,14 +1161,13 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
txn.execute(sql, args) txn.execute(sql, args)
rows = self.db_pool.cursor_to_dict(txn) rooms_dict = {
rooms_dict = {} room_id: RetentionPolicy(
min_lifetime=min_lifetime,
for row in rows: max_lifetime=max_lifetime,
rooms_dict[row["room_id"]] = RetentionPolicy(
min_lifetime=row["min_lifetime"],
max_lifetime=row["max_lifetime"],
) )
for room_id, min_lifetime, max_lifetime in txn
}
if include_null: if include_null:
# If required, do a second query that retrieves all of the rooms we know # If required, do a second query that retrieves all of the rooms we know
@ -1178,13 +1176,11 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
txn.execute(sql) txn.execute(sql)
rows = self.db_pool.cursor_to_dict(txn)
# If a room isn't already in the dict (i.e. it doesn't have a retention # If a room isn't already in the dict (i.e. it doesn't have a retention
# policy in its state), add it with a null policy. # policy in its state), add it with a null policy.
for row in rows: for (room_id,) in txn:
if row["room_id"] not in rooms_dict: if room_id not in rooms_dict:
rooms_dict[row["room_id"]] = RetentionPolicy() rooms_dict[room_id] = RetentionPolicy()
return rooms_dict return rooms_dict
@ -1703,24 +1699,24 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
(last_room, batch_size), (last_room, batch_size),
) )
rows = self.db_pool.cursor_to_dict(txn) rows = txn.fetchall()
if not rows: if not rows:
return True return True
for row in rows: for room_id, event_id, json in rows:
if not row["json"]: if not json:
retention_policy = {} retention_policy = {}
else: else:
ev = db_to_json(row["json"]) ev = db_to_json(json)
retention_policy = ev["content"] retention_policy = ev["content"]
self.db_pool.simple_insert_txn( self.db_pool.simple_insert_txn(
txn=txn, txn=txn,
table="room_retention", table="room_retention",
values={ values={
"room_id": row["room_id"], "room_id": room_id,
"event_id": row["event_id"], "event_id": event_id,
"min_lifetime": retention_policy.get("min_lifetime"), "min_lifetime": retention_policy.get("min_lifetime"),
"max_lifetime": retention_policy.get("max_lifetime"), "max_lifetime": retention_policy.get("max_lifetime"),
}, },
@ -1729,7 +1725,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
logger.info("Inserted %d rows into room_retention", len(rows)) logger.info("Inserted %d rows into room_retention", len(rows))
self.db_pool.updates._background_update_progress_txn( self.db_pool.updates._background_update_progress_txn(
txn, "insert_room_retention", {"room_id": rows[-1]["room_id"]} txn, "insert_room_retention", {"room_id": rows[-1][0]}
) )
if batch_size > len(rows): if batch_size > len(rows):

View File

@ -1349,18 +1349,16 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
rows = self.db_pool.cursor_to_dict(txn) rows = txn.fetchall()
if not rows: if not rows:
return 0 return 0
min_stream_id = rows[-1]["stream_ordering"] min_stream_id = rows[-1][0]
to_update = [] to_update = []
for row in rows: for _, event_id, room_id, json in rows:
event_id = row["event_id"]
room_id = row["room_id"]
try: try:
event_json = db_to_json(row["json"]) event_json = db_to_json(json)
content = event_json["content"] content = event_json["content"]
except Exception: except Exception:
continue continue

View File

@ -179,22 +179,24 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
# store_search_entries_txn with a generator function, but that # store_search_entries_txn with a generator function, but that
# would mean having two cursors open on the database at once. # would mean having two cursors open on the database at once.
# Instead we just build a list of results. # Instead we just build a list of results.
rows = self.db_pool.cursor_to_dict(txn) rows = txn.fetchall()
if not rows: if not rows:
return 0 return 0
min_stream_id = rows[-1]["stream_ordering"] min_stream_id = rows[-1][0]
event_search_rows = [] event_search_rows = []
for row in rows: for (
stream_ordering,
event_id,
room_id,
etype,
json,
origin_server_ts,
) in rows:
try: try:
event_id = row["event_id"]
room_id = row["room_id"]
etype = row["type"]
stream_ordering = row["stream_ordering"]
origin_server_ts = row["origin_server_ts"]
try: try:
event_json = db_to_json(row["json"]) event_json = db_to_json(json)
content = event_json["content"] content = event_json["content"]
except Exception: except Exception:
continue continue

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING, Any, Dict, List, Optional from typing import TYPE_CHECKING, Any, List, Optional, Tuple, cast
from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import ( from synapse.storage.database import (
@ -27,6 +27,8 @@ from synapse.util import json_encoder
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
ScheduledTaskRow = Tuple[str, str, str, int, str, str, str, str]
class TaskSchedulerWorkerStore(SQLBaseStore): class TaskSchedulerWorkerStore(SQLBaseStore):
def __init__( def __init__(
@ -38,13 +40,18 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
@staticmethod @staticmethod
def _convert_row_to_task(row: Dict[str, Any]) -> ScheduledTask: def _convert_row_to_task(row: ScheduledTaskRow) -> ScheduledTask:
row["status"] = TaskStatus(row["status"]) task_id, action, status, timestamp, resource_id, params, result, error = row
if row["params"] is not None: return ScheduledTask(
row["params"] = db_to_json(row["params"]) id=task_id,
if row["result"] is not None: action=action,
row["result"] = db_to_json(row["result"]) status=TaskStatus(status),
return ScheduledTask(**row) timestamp=timestamp,
resource_id=resource_id,
params=db_to_json(params) if params is not None else None,
result=db_to_json(result) if result is not None else None,
error=error,
)
async def get_scheduled_tasks( async def get_scheduled_tasks(
self, self,
@ -68,7 +75,7 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
Returns: a list of `ScheduledTask`, ordered by increasing timestamps Returns: a list of `ScheduledTask`, ordered by increasing timestamps
""" """
def get_scheduled_tasks_txn(txn: LoggingTransaction) -> List[Dict[str, Any]]: def get_scheduled_tasks_txn(txn: LoggingTransaction) -> List[ScheduledTaskRow]:
clauses: List[str] = [] clauses: List[str] = []
args: List[Any] = [] args: List[Any] = []
if resource_id: if resource_id:
@ -101,7 +108,7 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
args.append(limit) args.append(limit)
txn.execute(sql, args) txn.execute(sql, args)
return self.db_pool.cursor_to_dict(txn) return cast(List[ScheduledTaskRow], txn.fetchall())
rows = await self.db_pool.runInteraction( rows = await self.db_pool.runInteraction(
"get_scheduled_tasks", get_scheduled_tasks_txn "get_scheduled_tasks", get_scheduled_tasks_txn
@ -193,7 +200,22 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
desc="get_scheduled_task", desc="get_scheduled_task",
) )
return TaskSchedulerWorkerStore._convert_row_to_task(row) if row else None return (
TaskSchedulerWorkerStore._convert_row_to_task(
(
row["id"],
row["action"],
row["status"],
row["timestamp"],
row["resource_id"],
row["params"],
row["result"],
row["error"],
)
)
if row
else None
)
async def delete_scheduled_task(self, id: str) -> None: async def delete_scheduled_task(self, id: str) -> None:
"""Delete a specific task from its id. """Delete a specific task from its id.