mirror of https://mau.dev/maunium/synapse.git
Remove many calls to cursor_to_dict (#16431)
This avoids calling cursor_to_dict and then immediately unpacking the values in the dict for other uses. By not creating the intermediate dictionary we avoid allocating the dictionary and the strings for its keys, which should generally be more performant. Additionally, this improves type hints by avoiding Dict[str, Any] dictionaries coming out of the database layer.
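For illustration, a minimal sketch of the pattern this commit applies throughout the storage layer. It uses the stdlib sqlite3 module and a hypothetical devices(user_id, device_id) table rather than Synapse's own database helpers; cursor_to_dict below is a rough stand-in for the helper being removed, not the real implementation.

import sqlite3
from typing import Any, Dict, List

def cursor_to_dict(txn: sqlite3.Cursor) -> List[Dict[str, Any]]:
    # Stand-in for the removed helper: allocates one dict, plus the key
    # strings, for every row returned by the query.
    col_headers = [col[0] for col in txn.description]
    return [dict(zip(col_headers, row)) for row in txn.fetchall()]

def devices_by_user_old(txn: sqlite3.Cursor) -> Dict[str, List[str]]:
    # Old style: build dicts, then immediately pick the values back out.
    txn.execute("SELECT user_id, device_id FROM devices")
    devices: Dict[str, List[str]] = {}
    for row in cursor_to_dict(txn):  # rows are Dict[str, Any]
        devices.setdefault(row["user_id"], []).append(row["device_id"])
    return devices

def devices_by_user_new(txn: sqlite3.Cursor) -> Dict[str, List[str]]:
    # New style: unpack tuples straight off the cursor; no intermediate
    # dicts, and the row type can be annotated as Tuple[str, str].
    txn.execute("SELECT user_id, device_id FROM devices")
    devices: Dict[str, List[str]] = {}
    for user_id, device_id in txn:
        devices.setdefault(user_id, []).append(device_id)
    return devices

Note that the explicit SELECT column list matters: unpacking by position only works when the query names its columns, which is why the diff below also replaces several SELECT * queries.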
This commit is contained in:
parent 4e302b30b6
commit fa907025f4
@@ -1 +1 @@
-Reduce the size of each replication command instance.
+Reduce memory allocations.

changelog.d/16431.misc Normal file (1 line)
@@ -0,0 +1 @@
+Reduce memory allocations.
@@ -101,7 +101,7 @@ if TYPE_CHECKING:
 class PusherConfig:
     """Parameters necessary to configure a pusher."""
 
-    id: Optional[str]
+    id: Optional[int]
     user_name: str
 
     profile_tag: str
@@ -151,10 +151,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
                 sql += " AND content != '{}'"
 
             txn.execute(sql, (user_id,))
-            rows = self.db_pool.cursor_to_dict(txn)
 
             return {
-                row["account_data_type"]: db_to_json(row["content"]) for row in rows
+                account_data_type: db_to_json(content)
+                for account_data_type, content in txn
             }
 
         return await self.db_pool.runInteraction(
@@ -196,13 +196,12 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
                 sql += " AND content != '{}'"
 
             txn.execute(sql, (user_id,))
-            rows = self.db_pool.cursor_to_dict(txn)
 
             by_room: Dict[str, Dict[str, JsonDict]] = {}
-            for row in rows:
-                room_data = by_room.setdefault(row["room_id"], {})
+            for room_id, account_data_type, content in txn:
+                room_data = by_room.setdefault(room_id, {})
 
-                room_data[row["account_data_type"]] = db_to_json(row["content"])
+                room_data[account_data_type] = db_to_json(content)
 
             return by_room
 
@@ -14,17 +14,7 @@
 # limitations under the License.
 import logging
 import re
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Dict,
-    List,
-    Optional,
-    Pattern,
-    Sequence,
-    Tuple,
-    cast,
-)
+from typing import TYPE_CHECKING, List, Optional, Pattern, Sequence, Tuple, cast
 
 from synapse.appservice import (
     ApplicationService,
@@ -353,21 +343,15 @@ class ApplicationServiceTransactionWorkerStore(
 
         def _get_oldest_unsent_txn(
             txn: LoggingTransaction,
-        ) -> Optional[Dict[str, Any]]:
+        ) -> Optional[Tuple[int, str]]:
            # Monotonically increasing txn ids, so just select the smallest
             # one in the txns table (we delete them when they are sent)
             txn.execute(
-                "SELECT * FROM application_services_txns WHERE as_id=?"
+                "SELECT txn_id, event_ids FROM application_services_txns WHERE as_id=?"
                 " ORDER BY txn_id ASC LIMIT 1",
                 (service.id,),
             )
-            rows = self.db_pool.cursor_to_dict(txn)
-            if not rows:
-                return None
-
-            entry = rows[0]
-
-            return entry
+            return cast(Optional[Tuple[int, str]], txn.fetchone())
 
         entry = await self.db_pool.runInteraction(
             "get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn
@@ -376,8 +360,9 @@ class ApplicationServiceTransactionWorkerStore(
         if not entry:
             return None
 
-        event_ids = db_to_json(entry["event_ids"])
+        txn_id, event_ids_str = entry
 
+        event_ids = db_to_json(event_ids_str)
         events = await self.get_events_as_list(event_ids)
 
         # TODO: to-device messages, one-time key counts, device list summaries and unused
@@ -385,7 +370,7 @@ class ApplicationServiceTransactionWorkerStore(
         # We likely want to populate those for reliability.
         return AppServiceTransaction(
             service=service,
-            id=entry["txn_id"],
+            id=txn_id,
             events=events,
             ephemeral=[],
             to_device_messages=[],
@@ -1413,13 +1413,13 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
 
         def get_devices_not_accessed_since_txn(
             txn: LoggingTransaction,
-        ) -> List[Dict[str, str]]:
+        ) -> List[Tuple[str, str]]:
             sql = """
                 SELECT user_id, device_id
                 FROM devices WHERE last_seen < ? AND hidden = FALSE
             """
             txn.execute(sql, (since_ms,))
-            return self.db_pool.cursor_to_dict(txn)
+            return cast(List[Tuple[str, str]], txn.fetchall())
 
         rows = await self.db_pool.runInteraction(
             "get_devices_not_accessed_since",
@@ -1427,11 +1427,11 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
         )
 
         devices: Dict[str, List[str]] = {}
-        for row in rows:
+        for user_id, device_id in rows:
             # Remote devices are never stale from our point of view.
-            if self.hs.is_mine_id(row["user_id"]):
-                user_devices = devices.setdefault(row["user_id"], [])
-                user_devices.append(row["device_id"])
+            if self.hs.is_mine_id(user_id):
+                user_devices = devices.setdefault(user_id, [])
+                user_devices.append(device_id)
 
         return devices
 
@@ -921,14 +921,10 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorkerStore)
             }
 
             txn.execute(sql, params)
-            rows = self.db_pool.cursor_to_dict(txn)
 
-            for row in rows:
-                user_id = row["user_id"]
-                key_type = row["keytype"]
-                key = db_to_json(row["keydata"])
+            for user_id, key_type, key_data, _ in txn:
                 user_keys = result.setdefault(user_id, {})
-                user_keys[key_type] = key
+                user_keys[key_type] = db_to_json(key_data)
 
         return result
 
@@ -988,13 +984,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorkerStore)
             query_params.extend(item)
 
             txn.execute(sql, query_params)
-            rows = self.db_pool.cursor_to_dict(txn)
 
             # and add the signatures to the appropriate keys
-            for row in rows:
-                key_id: str = row["key_id"]
-                target_user_id: str = row["target_user_id"]
-                target_device_id: str = row["target_device_id"]
+            for target_user_id, target_device_id, key_id, signature in txn:
                 key_type = devices[(target_user_id, target_device_id)]
                 # We need to copy everything, because the result may have come
                 # from the cache. dict.copy only does a shallow copy, so we
@@ -1012,13 +1004,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorkerStore)
                     ].copy()
                     if from_user_id in signatures:
                         user_sigs = signatures[from_user_id] = signatures[from_user_id]
-                        user_sigs[key_id] = row["signature"]
+                        user_sigs[key_id] = signature
                     else:
-                        signatures[from_user_id] = {key_id: row["signature"]}
+                        signatures[from_user_id] = {key_id: signature}
                 else:
-                    target_user_key["signatures"] = {
-                        from_user_id: {key_id: row["signature"]}
-                    }
+                    target_user_key["signatures"] = {from_user_id: {key_id: signature}}
 
         return keys
 
@@ -1654,8 +1654,6 @@ class PersistEventsStore:
     ) -> None:
         to_prefill = []
 
-        rows = []
-
         ev_map = {e.event_id: e for e, _ in events_and_contexts}
         if not ev_map:
             return
@@ -1676,10 +1674,9 @@ class PersistEventsStore:
             )
 
             txn.execute(sql + clause, args)
-            rows = self.db_pool.cursor_to_dict(txn)
-            for row in rows:
-                event = ev_map[row["event_id"]]
-                if not row["rejects"] and not row["redacts"]:
+            for event_id, redacts, rejects in txn:
+                event = ev_map[event_id]
+                if not rejects and not redacts:
                     to_prefill.append(EventCacheEntry(event=event, redacted_event=None))
 
         async def external_prefill() -> None:
@@ -434,13 +434,21 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
 
         txn = db_conn.cursor()
         txn.execute(sql, (PresenceState.OFFLINE,))
-        rows = self.db_pool.cursor_to_dict(txn)
+        rows = txn.fetchall()
         txn.close()
 
-        for row in rows:
-            row["currently_active"] = bool(row["currently_active"])
-
-        return [UserPresenceState(**row) for row in rows]
+        return [
+            UserPresenceState(
+                user_id=user_id,
+                state=state,
+                last_active_ts=last_active_ts,
+                last_federation_update_ts=last_federation_update_ts,
+                last_user_sync_ts=last_user_sync_ts,
+                status_msg=status_msg,
+                currently_active=bool(currently_active),
+            )
+            for user_id, state, last_active_ts, last_federation_update_ts, last_user_sync_ts, status_msg, currently_active in rows
+        ]
 
     def take_presence_startup_info(self) -> List[UserPresenceState]:
         active_on_startup = self._presence_on_startup
@@ -47,6 +47,27 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
+# The type of a row in the pushers table.
+PusherRow = Tuple[
+    int,  # id
+    str,  # user_name
+    Optional[int],  # access_token
+    str,  # profile_tag
+    str,  # kind
+    str,  # app_id
+    str,  # app_display_name
+    str,  # device_display_name
+    str,  # pushkey
+    int,  # ts
+    str,  # lang
+    str,  # data
+    int,  # last_stream_ordering
+    int,  # last_success
+    int,  # failing_since
+    bool,  # enabled
+    str,  # device_id
+]
+
 
 class PusherWorkerStore(SQLBaseStore):
     def __init__(
@@ -83,30 +104,66 @@ class PusherWorkerStore(SQLBaseStore):
             self._remove_deleted_email_pushers,
         )
 
-    def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[PusherConfig]:
+    def _decode_pushers_rows(
+        self,
+        rows: Iterable[PusherRow],
+    ) -> Iterator[PusherConfig]:
         """JSON-decode the data in the rows returned from the `pushers` table
 
         Drops any rows whose data cannot be decoded
         """
-        for r in rows:
-            data_json = r["data"]
+        for (
+            id,
+            user_name,
+            access_token,
+            profile_tag,
+            kind,
+            app_id,
+            app_display_name,
+            device_display_name,
+            pushkey,
+            ts,
+            lang,
+            data,
+            last_stream_ordering,
+            last_success,
+            failing_since,
+            enabled,
+            device_id,
+        ) in rows:
             try:
-                r["data"] = db_to_json(data_json)
+                data_json = db_to_json(data)
             except Exception as e:
                 logger.warning(
                     "Invalid JSON in data for pusher %d: %s, %s",
-                    r["id"],
-                    data_json,
+                    id,
+                    data,
                     e.args[0],
                 )
                 continue
 
-            # If we're using SQLite, then boolean values are integers. This is
-            # troublesome since some code using the return value of this method might
-            # expect it to be a boolean, or will expose it to clients (in responses).
-            r["enabled"] = bool(r["enabled"])
-
-            yield PusherConfig(**r)
+            yield PusherConfig(
+                id=id,
+                user_name=user_name,
+                profile_tag=profile_tag,
+                kind=kind,
+                app_id=app_id,
+                app_display_name=app_display_name,
+                device_display_name=device_display_name,
+                pushkey=pushkey,
+                ts=ts,
+                lang=lang,
+                data=data_json,
+                last_stream_ordering=last_stream_ordering,
+                last_success=last_success,
+                failing_since=failing_since,
+                # If we're using SQLite, then boolean values are integers. This is
+                # troublesome since some code using the return value of this method might
+                # expect it to be a boolean, or will expose it to clients (in responses).
+                enabled=bool(enabled),
+                device_id=device_id,
+                access_token=access_token,
+            )
 
     def get_pushers_stream_token(self) -> int:
         return self._pushers_id_gen.get_current_token()
@@ -136,7 +193,7 @@ class PusherWorkerStore(SQLBaseStore):
             The pushers for which the given columns have the given values.
         """
 
-        def get_pushers_by_txn(txn: LoggingTransaction) -> List[Dict[str, Any]]:
+        def get_pushers_by_txn(txn: LoggingTransaction) -> List[PusherRow]:
             # We could technically use simple_select_list here, but we need to call
             # COALESCE on the 'enabled' column. While it is technically possible to give
             # simple_select_list the whole `COALESCE(...) AS ...` as a column name, it
@@ -154,7 +211,7 @@ class PusherWorkerStore(SQLBaseStore):
 
             txn.execute(sql, list(keyvalues.values()))
 
-            return self.db_pool.cursor_to_dict(txn)
+            return cast(List[PusherRow], txn.fetchall())
 
         ret = await self.db_pool.runInteraction(
             desc="get_pushers_by",
@@ -164,14 +221,22 @@ class PusherWorkerStore(SQLBaseStore):
         return self._decode_pushers_rows(ret)
 
     async def get_enabled_pushers(self) -> Iterator[PusherConfig]:
-        def get_enabled_pushers_txn(txn: LoggingTransaction) -> Iterator[PusherConfig]:
-            txn.execute("SELECT * FROM pushers WHERE COALESCE(enabled, TRUE)")
-            rows = self.db_pool.cursor_to_dict(txn)
-
-            return self._decode_pushers_rows(rows)
-
-        return await self.db_pool.runInteraction(
-            "get_enabled_pushers", get_enabled_pushers_txn
+        def get_enabled_pushers_txn(txn: LoggingTransaction) -> List[PusherRow]:
+            txn.execute(
+                """
+                SELECT id, user_name, access_token, profile_tag, kind, app_id,
+                    app_display_name, device_display_name, pushkey, ts, lang, data,
+                    last_stream_ordering, last_success, failing_since,
+                    enabled, device_id
+                FROM pushers WHERE COALESCE(enabled, TRUE)
+                """
+            )
+            return cast(List[PusherRow], txn.fetchall())
+
+        return self._decode_pushers_rows(
+            await self.db_pool.runInteraction(
+                "get_enabled_pushers", get_enabled_pushers_txn
+            )
         )
 
     async def get_all_updated_pushers_rows(
@@ -304,7 +369,7 @@ class PusherWorkerStore(SQLBaseStore):
         )
 
     async def get_throttle_params_by_room(
-        self, pusher_id: str
+        self, pusher_id: int
     ) -> Dict[str, ThrottleParams]:
         res = await self.db_pool.simple_select_list(
             "pusher_throttle",
@@ -323,7 +388,7 @@ class PusherWorkerStore(SQLBaseStore):
         return params_by_room
 
     async def set_throttle_params(
-        self, pusher_id: str, room_id: str, params: ThrottleParams
+        self, pusher_id: int, room_id: str, params: ThrottleParams
     ) -> None:
         await self.db_pool.simple_upsert(
             "pusher_throttle",
@@ -534,7 +599,7 @@ class PusherBackgroundUpdatesStore(SQLBaseStore):
                 (last_pusher_id, batch_size),
             )
 
-            rows = self.db_pool.cursor_to_dict(txn)
+            rows = txn.fetchall()
             if len(rows) == 0:
                 return 0
 
@@ -550,19 +615,19 @@ class PusherBackgroundUpdatesStore(SQLBaseStore):
                 txn=txn,
                 table="pushers",
                 key_names=("id",),
-                key_values=[(row["pusher_id"],) for row in rows],
+                key_values=[row[0] for row in rows],
                 value_names=("device_id", "access_token"),
                 # If there was already a device_id on the pusher, we only want to clear
                 # the access_token column, so we keep the existing device_id. Otherwise,
                 # we set the device_id we got from joining the access_tokens table.
                 value_values=[
-                    (row["pusher_device_id"] or row["token_device_id"], None)
-                    for row in rows
+                    (pusher_device_id or token_device_id, None)
+                    for _, pusher_device_id, token_device_id in rows
                 ],
             )
 
             self.db_pool.updates._background_update_progress_txn(
-                txn, "set_device_id_for_pushers", {"pusher_id": rows[-1]["pusher_id"]}
+                txn, "set_device_id_for_pushers", {"pusher_id": rows[-1][0]}
             )
 
             return len(rows)
@@ -313,25 +313,25 @@ class ReceiptsWorkerStore(SQLBaseStore):
     ) -> Sequence[JsonMapping]:
         """See get_linearized_receipts_for_room"""
 
-        def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
+        def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str]]:
             if from_key:
                 sql = (
-                    "SELECT * FROM receipts_linearized WHERE"
+                    "SELECT receipt_type, user_id, event_id, data"
+                    " FROM receipts_linearized WHERE"
                     " room_id = ? AND stream_id > ? AND stream_id <= ?"
                 )
 
                 txn.execute(sql, (room_id, from_key, to_key))
             else:
                 sql = (
-                    "SELECT * FROM receipts_linearized WHERE"
+                    "SELECT receipt_type, user_id, event_id, data"
+                    " FROM receipts_linearized WHERE"
                     " room_id = ? AND stream_id <= ?"
                 )
 
                 txn.execute(sql, (room_id, to_key))
 
-            rows = self.db_pool.cursor_to_dict(txn)
-
-            return rows
+            return cast(List[Tuple[str, str, str, str]], txn.fetchall())
 
         rows = await self.db_pool.runInteraction("get_linearized_receipts_for_room", f)
 
@@ -339,10 +339,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
             return []
 
         content: JsonDict = {}
-        for row in rows:
-            content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[
-                row["user_id"]
-            ] = db_to_json(row["data"])
+        for receipt_type, user_id, event_id, data in rows:
+            content.setdefault(event_id, {}).setdefault(receipt_type, {})[
+                user_id
+            ] = db_to_json(data)
 
         return [{"type": EduTypes.RECEIPT, "room_id": room_id, "content": content}]
 
@@ -357,10 +357,13 @@ class ReceiptsWorkerStore(SQLBaseStore):
         if not room_ids:
             return {}
 
-        def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
+        def f(
+            txn: LoggingTransaction,
+        ) -> List[Tuple[str, str, str, str, Optional[str], str]]:
             if from_key:
                 sql = """
-                    SELECT * FROM receipts_linearized WHERE
+                    SELECT room_id, receipt_type, user_id, event_id, thread_id, data
+                    FROM receipts_linearized WHERE
                     stream_id > ? AND stream_id <= ? AND
                 """
                 clause, args = make_in_list_sql_clause(
@@ -370,7 +373,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
                 txn.execute(sql + clause, [from_key, to_key] + list(args))
             else:
                 sql = """
-                    SELECT * FROM receipts_linearized WHERE
+                    SELECT room_id, receipt_type, user_id, event_id, thread_id, data
+                    FROM receipts_linearized WHERE
                     stream_id <= ? AND
                 """
 
@@ -380,29 +384,31 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
                 txn.execute(sql + clause, [to_key] + list(args))
 
-            return self.db_pool.cursor_to_dict(txn)
+            return cast(
+                List[Tuple[str, str, str, str, Optional[str], str]], txn.fetchall()
+            )
 
         txn_results = await self.db_pool.runInteraction(
             "_get_linearized_receipts_for_rooms", f
         )
 
         results: JsonDict = {}
-        for row in txn_results:
+        for room_id, receipt_type, user_id, event_id, thread_id, data in txn_results:
             # We want a single event per room, since we want to batch the
             # receipts by room, event and type.
             room_event = results.setdefault(
-                row["room_id"],
-                {"type": EduTypes.RECEIPT, "room_id": row["room_id"], "content": {}},
+                room_id,
+                {"type": EduTypes.RECEIPT, "room_id": room_id, "content": {}},
            )
 
             # The content is of the form:
             # {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
-            event_entry = room_event["content"].setdefault(row["event_id"], {})
-            receipt_type = event_entry.setdefault(row["receipt_type"], {})
+            event_entry = room_event["content"].setdefault(event_id, {})
+            receipt_type_dict = event_entry.setdefault(receipt_type, {})
 
-            receipt_type[row["user_id"]] = db_to_json(row["data"])
-            if row["thread_id"]:
-                receipt_type[row["user_id"]]["thread_id"] = row["thread_id"]
+            receipt_type_dict[user_id] = db_to_json(data)
+            if thread_id:
+                receipt_type_dict[user_id]["thread_id"] = thread_id
 
         results = {
             room_id: [results[room_id]] if room_id in results else []
@@ -428,10 +434,11 @@ class ReceiptsWorkerStore(SQLBaseStore):
             A dictionary of roomids to a list of receipts.
         """
 
-        def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
+        def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str, str]]:
             if from_key:
                 sql = """
-                    SELECT * FROM receipts_linearized WHERE
+                    SELECT room_id, receipt_type, user_id, event_id, data
+                    FROM receipts_linearized WHERE
                     stream_id > ? AND stream_id <= ?
                     ORDER BY stream_id DESC
                    LIMIT 100
@@ -439,7 +446,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
                 txn.execute(sql, [from_key, to_key])
             else:
                 sql = """
-                    SELECT * FROM receipts_linearized WHERE
+                    SELECT room_id, receipt_type, user_id, event_id, data
+                    FROM receipts_linearized WHERE
                     stream_id <= ?
                     ORDER BY stream_id DESC
                     LIMIT 100
@@ -447,27 +455,27 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
                 txn.execute(sql, [to_key])
 
-            return self.db_pool.cursor_to_dict(txn)
+            return cast(List[Tuple[str, str, str, str, str]], txn.fetchall())
 
         txn_results = await self.db_pool.runInteraction(
             "get_linearized_receipts_for_all_rooms", f
        )
 
         results: JsonDict = {}
-        for row in txn_results:
+        for room_id, receipt_type, user_id, event_id, data in txn_results:
             # We want a single event per room, since we want to batch the
             # receipts by room, event and type.
             room_event = results.setdefault(
-                row["room_id"],
-                {"type": EduTypes.RECEIPT, "room_id": row["room_id"], "content": {}},
+                room_id,
+                {"type": EduTypes.RECEIPT, "room_id": room_id, "content": {}},
             )
 
             # The content is of the form:
             # {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
-            event_entry = room_event["content"].setdefault(row["event_id"], {})
-            receipt_type = event_entry.setdefault(row["receipt_type"], {})
+            event_entry = room_event["content"].setdefault(event_id, {})
+            receipt_type_dict = event_entry.setdefault(receipt_type, {})
 
-            receipt_type[row["user_id"]] = db_to_json(row["data"])
+            receipt_type_dict[user_id] = db_to_json(data)
 
         return results
 
@@ -195,7 +195,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
     async def get_user_by_id(self, user_id: str) -> Optional[UserInfo]:
         """Returns info about the user account, if it exists."""
 
-        def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
+        def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[UserInfo]:
             # We could technically use simple_select_one here, but it would not perform
             # the COALESCEs (unless hacked into the column names), which could yield
             # confusing results.
@@ -213,35 +213,46 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                 (user_id,),
             )
 
-            rows = self.db_pool.cursor_to_dict(txn)
-
-            if len(rows) == 0:
+            row = txn.fetchone()
+            if not row:
                 return None
 
-            return rows[0]
+            (
+                name,
+                is_guest,
+                admin,
+                consent_version,
+                consent_ts,
+                consent_server_notice_sent,
+                appservice_id,
+                creation_ts,
+                user_type,
+                deactivated,
+                shadow_banned,
+                approved,
+                locked,
+            ) = row
+
+            return UserInfo(
+                appservice_id=appservice_id,
+                consent_server_notice_sent=consent_server_notice_sent,
+                consent_version=consent_version,
+                consent_ts=consent_ts,
+                creation_ts=creation_ts,
+                is_admin=bool(admin),
+                is_deactivated=bool(deactivated),
+                is_guest=bool(is_guest),
+                is_shadow_banned=bool(shadow_banned),
+                user_id=UserID.from_string(name),
+                user_type=user_type,
+                approved=bool(approved),
+                locked=bool(locked),
+            )
 
-        row = await self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             desc="get_user_by_id",
             func=get_user_by_id_txn,
         )
-        if row is None:
-            return None
-
-        return UserInfo(
-            appservice_id=row["appservice_id"],
-            consent_server_notice_sent=row["consent_server_notice_sent"],
-            consent_version=row["consent_version"],
-            consent_ts=row["consent_ts"],
-            creation_ts=row["creation_ts"],
-            is_admin=bool(row["admin"]),
-            is_deactivated=bool(row["deactivated"]),
-            is_guest=bool(row["is_guest"]),
-            is_shadow_banned=bool(row["shadow_banned"]),
-            user_id=UserID.from_string(row["name"]),
-            user_type=row["user_type"],
-            approved=bool(row["approved"]),
-            locked=bool(row["locked"]),
-        )
 
     async def is_trial_user(self, user_id: str) -> bool:
         """Checks if user is in the "trial" period, i.e. within the first
@@ -579,16 +590,31 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             """
 
            txn.execute(sql, (token,))
-            rows = self.db_pool.cursor_to_dict(txn)
+            row = txn.fetchone()
 
-            if rows:
-                row = rows[0]
+            if row:
+                (
+                    user_id,
+                    is_guest,
+                    shadow_banned,
+                    token_id,
+                    device_id,
+                    valid_until_ms,
+                    token_owner,
+                    token_used,
+                ) = row
 
-                # This field is nullable, ensure it comes out as a boolean
-                if row["token_used"] is None:
-                    row["token_used"] = False
-
-                return TokenLookupResult(**row)
+                return TokenLookupResult(
+                    user_id=user_id,
+                    is_guest=is_guest,
+                    shadow_banned=shadow_banned,
+                    token_id=token_id,
+                    device_id=device_id,
+                    valid_until_ms=valid_until_ms,
+                    token_owner=token_owner,
+                    # This field is nullable, ensure it comes out as a boolean
+                    token_used=bool(token_used),
+                )
 
             return None
 
@@ -833,11 +859,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         """Counts all users registered on the homeserver."""
 
         def _count_users(txn: LoggingTransaction) -> int:
-            txn.execute("SELECT COUNT(*) AS users FROM users")
-            rows = self.db_pool.cursor_to_dict(txn)
-            if rows:
-                return rows[0]["users"]
-            return 0
+            txn.execute("SELECT COUNT(*) FROM users")
+            row = txn.fetchone()
+            assert row is not None
+            return row[0]
 
         return await self.db_pool.runInteraction("count_users", _count_users)
 
@@ -891,11 +916,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         """Counts all users without a special user_type registered on the homeserver."""
 
         def _count_users(txn: LoggingTransaction) -> int:
-            txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null")
-            rows = self.db_pool.cursor_to_dict(txn)
-            if rows:
-                return rows[0]["users"]
-            return 0
+            txn.execute("SELECT COUNT(*) FROM users where user_type is null")
+            row = txn.fetchone()
+            assert row is not None
+            return row[0]
 
         return await self.db_pool.runInteraction("count_real_users", _count_users)
 
@@ -1252,12 +1276,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             )
             txn.execute(sql, [])
 
-            res = self.db_pool.cursor_to_dict(txn)
-            if res:
-                for user in res:
-                    self.set_expiration_date_for_user_txn(
-                        txn, user["name"], use_delta=True
-                    )
+            for (name,) in txn.fetchall():
+                self.set_expiration_date_for_user_txn(txn, name, use_delta=True)
 
         await self.db_pool.runInteraction(
             "get_users_with_no_expiration_date",
@@ -1963,11 +1983,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                 (user_id,),
            )
 
-            rows = self.db_pool.cursor_to_dict(txn)
+            row = txn.fetchone()
+            assert row is not None
 
             # We cast to bool because the value returned by the database engine might
             # be an integer if we're using SQLite.
-            return bool(rows[0]["approved"])
+            return bool(row[0])
 
         return await self.db_pool.runInteraction(
             desc="is_user_pending_approval",
@@ -2045,22 +2066,22 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
                 (last_user, batch_size),
            )
 
-            rows = self.db_pool.cursor_to_dict(txn)
+            rows = txn.fetchall()
 
             if not rows:
                 return True, 0
 
             rows_processed_nb = 0
 
-            for user in rows:
-                if not user["count_tokens"] and not user["count_threepids"]:
-                    self.set_user_deactivated_status_txn(txn, user["name"], True)
+            for name, count_tokens, count_threepids in rows:
+                if not count_tokens and not count_threepids:
+                    self.set_user_deactivated_status_txn(txn, name, True)
                     rows_processed_nb += 1
 
             logger.info("Marked %d rows as deactivated", rows_processed_nb)
 
             self.db_pool.updates._background_update_progress_txn(
-                txn, "users_set_deactivated_flag", {"user_id": rows[-1]["name"]}
+                txn, "users_set_deactivated_flag", {"user_id": rows[-1][0]}
             )
 
             if batch_size > len(rows):
@@ -831,7 +831,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
 
         def get_retention_policy_for_room_txn(
             txn: LoggingTransaction,
-        ) -> List[Dict[str, Optional[int]]]:
+        ) -> Optional[Tuple[Optional[int], Optional[int]]]:
             txn.execute(
                 """
                 SELECT min_lifetime, max_lifetime FROM room_retention
@@ -841,7 +841,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
                 (room_id,),
             )
 
-            return self.db_pool.cursor_to_dict(txn)
+            return cast(Optional[Tuple[Optional[int], Optional[int]]], txn.fetchone())
 
         ret = await self.db_pool.runInteraction(
             "get_retention_policy_for_room",
@@ -856,8 +856,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
                 max_lifetime=self.config.retention.retention_default_max_lifetime,
             )
 
-        min_lifetime = ret[0]["min_lifetime"]
-        max_lifetime = ret[0]["max_lifetime"]
+        min_lifetime, max_lifetime = ret
 
         # If one of the room's policy's attributes isn't defined, use the matching
         # attribute from the default policy.
@@ -1162,14 +1161,13 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
 
             txn.execute(sql, args)
 
-            rows = self.db_pool.cursor_to_dict(txn)
-            rooms_dict = {}
-
-            for row in rows:
-                rooms_dict[row["room_id"]] = RetentionPolicy(
-                    min_lifetime=row["min_lifetime"],
-                    max_lifetime=row["max_lifetime"],
+            rooms_dict = {
+                room_id: RetentionPolicy(
+                    min_lifetime=min_lifetime,
+                    max_lifetime=max_lifetime,
                 )
+                for room_id, min_lifetime, max_lifetime in txn
+            }
 
             if include_null:
                 # If required, do a second query that retrieves all of the rooms we know
@@ -1178,13 +1176,11 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
 
                 txn.execute(sql)
 
-                rows = self.db_pool.cursor_to_dict(txn)
-
                 # If a room isn't already in the dict (i.e. it doesn't have a retention
                 # policy in its state), add it with a null policy.
-                for row in rows:
-                    if row["room_id"] not in rooms_dict:
-                        rooms_dict[row["room_id"]] = RetentionPolicy()
+                for (room_id,) in txn:
+                    if room_id not in rooms_dict:
+                        rooms_dict[room_id] = RetentionPolicy()
 
             return rooms_dict
 
@@ -1703,24 +1699,24 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
                 (last_room, batch_size),
             )
 
-            rows = self.db_pool.cursor_to_dict(txn)
+            rows = txn.fetchall()
 
             if not rows:
                 return True
 
-            for row in rows:
-                if not row["json"]:
+            for room_id, event_id, json in rows:
+                if not json:
                     retention_policy = {}
                 else:
-                    ev = db_to_json(row["json"])
+                    ev = db_to_json(json)
                     retention_policy = ev["content"]
 
                 self.db_pool.simple_insert_txn(
                     txn=txn,
                     table="room_retention",
                     values={
-                        "room_id": row["room_id"],
-                        "event_id": row["event_id"],
+                        "room_id": room_id,
+                        "event_id": event_id,
                         "min_lifetime": retention_policy.get("min_lifetime"),
                         "max_lifetime": retention_policy.get("max_lifetime"),
                     },
@@ -1729,7 +1725,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
             logger.info("Inserted %d rows into room_retention", len(rows))
 
             self.db_pool.updates._background_update_progress_txn(
                txn, "insert_room_retention", {"room_id": rows[-1][0]}
             )
 
             if batch_size > len(rows):
@@ -1349,18 +1349,16 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
 
             txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
 
-            rows = self.db_pool.cursor_to_dict(txn)
+            rows = txn.fetchall()
             if not rows:
                 return 0
 
-            min_stream_id = rows[-1]["stream_ordering"]
+            min_stream_id = rows[-1][0]
 
             to_update = []
-            for row in rows:
-                event_id = row["event_id"]
-                room_id = row["room_id"]
+            for _, event_id, room_id, json in rows:
                 try:
-                    event_json = db_to_json(row["json"])
+                    event_json = db_to_json(json)
                     content = event_json["content"]
                 except Exception:
                     continue
@@ -179,22 +179,24 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
            # store_search_entries_txn with a generator function, but that
            # would mean having two cursors open on the database at once.
            # Instead we just build a list of results.
-            rows = self.db_pool.cursor_to_dict(txn)
+            rows = txn.fetchall()
            if not rows:
                return 0
 
-            min_stream_id = rows[-1]["stream_ordering"]
+            min_stream_id = rows[-1][0]
 
            event_search_rows = []
-            for row in rows:
+            for (
+                stream_ordering,
+                event_id,
+                room_id,
+                etype,
+                json,
+                origin_server_ts,
+            ) in rows:
                try:
-                    event_id = row["event_id"]
-                    room_id = row["room_id"]
-                    etype = row["type"]
-                    stream_ordering = row["stream_ordering"]
-                    origin_server_ts = row["origin_server_ts"]
                    try:
-                        event_json = db_to_json(row["json"])
+                        event_json = db_to_json(json)
                        content = event_json["content"]
                    except Exception:
                        continue
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, List, Optional, Tuple, cast
 
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import (
@@ -27,6 +27,8 @@ from synapse.util import json_encoder
 if TYPE_CHECKING:
     from synapse.server import HomeServer
 
+ScheduledTaskRow = Tuple[str, str, str, int, str, str, str, str]
+
 
 class TaskSchedulerWorkerStore(SQLBaseStore):
     def __init__(
@@ -38,13 +40,18 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
         super().__init__(database, db_conn, hs)
 
     @staticmethod
-    def _convert_row_to_task(row: Dict[str, Any]) -> ScheduledTask:
-        row["status"] = TaskStatus(row["status"])
-        if row["params"] is not None:
-            row["params"] = db_to_json(row["params"])
-        if row["result"] is not None:
-            row["result"] = db_to_json(row["result"])
-        return ScheduledTask(**row)
+    def _convert_row_to_task(row: ScheduledTaskRow) -> ScheduledTask:
+        task_id, action, status, timestamp, resource_id, params, result, error = row
+        return ScheduledTask(
+            id=task_id,
+            action=action,
+            status=TaskStatus(status),
+            timestamp=timestamp,
+            resource_id=resource_id,
+            params=db_to_json(params) if params is not None else None,
+            result=db_to_json(result) if result is not None else None,
+            error=error,
+        )
 
     async def get_scheduled_tasks(
         self,
@@ -68,7 +75,7 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
             Returns: a list of `ScheduledTask`, ordered by increasing timestamps
         """
 
-        def get_scheduled_tasks_txn(txn: LoggingTransaction) -> List[Dict[str, Any]]:
+        def get_scheduled_tasks_txn(txn: LoggingTransaction) -> List[ScheduledTaskRow]:
            clauses: List[str] = []
            args: List[Any] = []
            if resource_id:
@@ -101,7 +108,7 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
            args.append(limit)
 
            txn.execute(sql, args)
-            return self.db_pool.cursor_to_dict(txn)
+            return cast(List[ScheduledTaskRow], txn.fetchall())
 
        rows = await self.db_pool.runInteraction(
            "get_scheduled_tasks", get_scheduled_tasks_txn
@@ -193,7 +200,22 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
            desc="get_scheduled_task",
        )
 
-        return TaskSchedulerWorkerStore._convert_row_to_task(row) if row else None
+        return (
+            TaskSchedulerWorkerStore._convert_row_to_task(
+                (
+                    row["id"],
+                    row["action"],
+                    row["status"],
+                    row["timestamp"],
+                    row["resource_id"],
+                    row["params"],
+                    row["result"],
+                    row["error"],
+                )
+            )
+            if row
+            else None
+        )
 
    async def delete_scheduled_task(self, id: str) -> None:
        """Delete a specific task from its id.