Remove more usages of cursor_to_dict. (#16551)

Mostly to improve type safety.
This commit is contained in:
Patrick Cloke 2023-10-26 15:12:28 -04:00 committed by GitHub
parent 85e5f2dc25
commit 679c691f6f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
26 changed files with 193 additions and 134 deletions

1
changelog.d/16551.misc Normal file
View File

@ -0,0 +1 @@
Improve type hints.

View File

@ -19,6 +19,8 @@ import logging
import urllib.parse import urllib.parse
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Tuple from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Tuple
import attr
from synapse.api.errors import ( from synapse.api.errors import (
CodeMessageException, CodeMessageException,
Codes, Codes,
@ -357,9 +359,9 @@ class IdentityHandler:
# Check to see if a session already exists and that it is not yet # Check to see if a session already exists and that it is not yet
# marked as validated # marked as validated
if session and session.get("validated_at") is None: if session and session.validated_at is None:
session_id = session["session_id"] session_id = session.session_id
last_send_attempt = session["last_send_attempt"] last_send_attempt = session.last_send_attempt
# Check that the send_attempt is higher than previous attempts # Check that the send_attempt is higher than previous attempts
if send_attempt <= last_send_attempt: if send_attempt <= last_send_attempt:
@ -480,7 +482,6 @@ class IdentityHandler:
# We don't actually know which medium this 3PID is. Thus we first assume it's email, # We don't actually know which medium this 3PID is. Thus we first assume it's email,
# and if validation fails we try msisdn # and if validation fails we try msisdn
validation_session = None
# Try to validate as email # Try to validate as email
if self.hs.config.email.can_verify_email: if self.hs.config.email.can_verify_email:
@ -488,19 +489,18 @@ class IdentityHandler:
validation_session = await self.store.get_threepid_validation_session( validation_session = await self.store.get_threepid_validation_session(
"email", client_secret, sid=sid, validated=True "email", client_secret, sid=sid, validated=True
) )
if validation_session: if validation_session:
return validation_session return attr.asdict(validation_session)
# Try to validate as msisdn # Try to validate as msisdn
if self.hs.config.registration.account_threepid_delegate_msisdn: if self.hs.config.registration.account_threepid_delegate_msisdn:
# Ask our delegated msisdn identity server # Ask our delegated msisdn identity server
validation_session = await self.threepid_from_creds( return await self.threepid_from_creds(
self.hs.config.registration.account_threepid_delegate_msisdn, self.hs.config.registration.account_threepid_delegate_msisdn,
threepid_creds, threepid_creds,
) )
return validation_session return None
async def proxy_msisdn_submit_token( async def proxy_msisdn_submit_token(
self, id_server: str, client_secret: str, sid: str, token: str self, id_server: str, client_secret: str, sid: str, token: str

View File

@ -187,9 +187,9 @@ class _BaseThreepidAuthChecker:
if row: if row:
threepid = { threepid = {
"medium": row["medium"], "medium": row.medium,
"address": row["address"], "address": row.address,
"validated_at": row["validated_at"], "validated_at": row.validated_at,
} }
# Valid threepid returned, delete from the db # Valid threepid returned, delete from the db

View File

@ -949,10 +949,7 @@ class MediaRepository:
deleted = 0 deleted = 0
for media in old_media: for origin, media_id, file_id in old_media:
origin = media["media_origin"]
media_id = media["media_id"]
file_id = media["filesystem_id"]
key = (origin, media_id) key = (origin, media_id)
logger.info("Deleting: %r", key) logger.info("Deleting: %r", key)

View File

@ -85,7 +85,19 @@ class ListDestinationsRestServlet(RestServlet):
destinations, total = await self._store.get_destinations_paginate( destinations, total = await self._store.get_destinations_paginate(
start, limit, destination, order_by, direction start, limit, destination, order_by, direction
) )
response = {"destinations": destinations, "total": total} response = {
"destinations": [
{
"destination": r[0],
"retry_last_ts": r[1],
"retry_interval": r[2],
"failure_ts": r[3],
"last_successful_stream_ordering": r[4],
}
for r in destinations
],
"total": total,
}
if (start + limit) < total: if (start + limit) < total:
response["next_token"] = str(start + len(destinations)) response["next_token"] = str(start + len(destinations))

View File

@ -724,7 +724,17 @@ class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet):
room_id, _ = await self.resolve_room_id(room_identifier) room_id, _ = await self.resolve_room_id(room_identifier)
extremities = await self.store.get_forward_extremities_for_room(room_id) extremities = await self.store.get_forward_extremities_for_room(room_id)
return HTTPStatus.OK, {"count": len(extremities), "results": extremities} result = [
{
"event_id": ex[0],
"state_group": ex[1],
"depth": ex[2],
"received_ts": ex[3],
}
for ex in extremities
]
return HTTPStatus.OK, {"count": len(extremities), "results": result}
class RoomEventContextServlet(RestServlet): class RoomEventContextServlet(RestServlet):

View File

@ -108,7 +108,18 @@ class UserMediaStatisticsRestServlet(RestServlet):
users_media, total = await self.store.get_users_media_usage_paginate( users_media, total = await self.store.get_users_media_usage_paginate(
start, limit, from_ts, until_ts, order_by, direction, search_term start, limit, from_ts, until_ts, order_by, direction, search_term
) )
ret = {"users": users_media, "total": total} ret = {
"users": [
{
"user_id": r[0],
"displayname": r[1],
"media_count": r[2],
"media_length": r[3],
}
for r in users_media
],
"total": total,
}
if (start + limit) < total: if (start + limit) < total:
ret["next_token"] = start + len(users_media) ret["next_token"] = start + len(users_media)

View File

@ -35,7 +35,6 @@ from typing import (
Tuple, Tuple,
Type, Type,
TypeVar, TypeVar,
Union,
cast, cast,
overload, overload,
) )
@ -1047,42 +1046,19 @@ class DatabasePool:
results = [dict(zip(col_headers, row)) for row in cursor] results = [dict(zip(col_headers, row)) for row in cursor]
return results return results
@overload async def execute(self, desc: str, query: str, *args: Any) -> List[Tuple[Any, ...]]:
async def execute(
self, desc: str, decoder: Literal[None], query: str, *args: Any
) -> List[Tuple[Any, ...]]:
...
@overload
async def execute(
self, desc: str, decoder: Callable[[Cursor], R], query: str, *args: Any
) -> R:
...
async def execute(
self,
desc: str,
decoder: Optional[Callable[[Cursor], R]],
query: str,
*args: Any,
) -> Union[List[Tuple[Any, ...]], R]:
"""Runs a single query for a result set. """Runs a single query for a result set.
Args: Args:
desc: description of the transaction, for logging and metrics desc: description of the transaction, for logging and metrics
decoder - The function which can resolve the cursor results to
something meaningful.
query - The query string to execute query - The query string to execute
*args - Query args. *args - Query args.
Returns: Returns:
The result of decoder(results) The result of decoder(results)
""" """
def interaction(txn: LoggingTransaction) -> Union[List[Tuple[Any, ...]], R]: def interaction(txn: LoggingTransaction) -> List[Tuple[Any, ...]]:
txn.execute(query, args) txn.execute(query, args)
if decoder:
return decoder(txn)
else:
return txn.fetchall() return txn.fetchall()
return await self.runInteraction(desc, interaction) return await self.runInteraction(desc, interaction)

View File

@ -93,7 +93,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
""" """
rows = await self.db_pool.execute( rows = await self.db_pool.execute(
"_censor_redactions_fetch", None, sql, before_ts, 100 "_censor_redactions_fetch", sql, before_ts, 100
) )
updates = [] updates = []

View File

@ -894,7 +894,6 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
rows = await self.db_pool.execute( rows = await self.db_pool.execute(
"get_all_devices_changed", "get_all_devices_changed",
None,
sql, sql,
from_key, from_key,
to_key, to_key,
@ -978,7 +977,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
WHERE from_user_id = ? AND stream_id > ? WHERE from_user_id = ? AND stream_id > ?
""" """
rows = await self.db_pool.execute( rows = await self.db_pool.execute(
"get_users_whose_signatures_changed", None, sql, user_id, from_key "get_users_whose_signatures_changed", sql, user_id, from_key
) )
return {user for row in rows for user in db_to_json(row[0])} return {user for row in rows for user in db_to_json(row[0])}
else: else:

View File

@ -155,7 +155,6 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
""" """
rows = await self.db_pool.execute( rows = await self.db_pool.execute(
"get_e2e_device_keys_for_federation_query_check", "get_e2e_device_keys_for_federation_query_check",
None,
sql, sql,
now_stream_id, now_stream_id,
user_id, user_id,

View File

@ -1310,12 +1310,9 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
# ANALYZE the new column to build stats on it, to encourage PostgreSQL to use the # ANALYZE the new column to build stats on it, to encourage PostgreSQL to use the
# indexes on it. # indexes on it.
# We need to pass execute a dummy function to handle the txn's result otherwise await self.db_pool.runInteraction(
# it tries to call fetchall() on it and fails because there's no result to fetch.
await self.db_pool.execute(
"background_analyze_new_stream_ordering_column", "background_analyze_new_stream_ordering_column",
lambda txn: None, lambda txn: txn.execute("ANALYZE events(stream_ordering2)"),
"ANALYZE events(stream_ordering2)",
) )
await self.db_pool.runInteraction( await self.db_pool.runInteraction(

View File

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Any, Dict, List from typing import List, Optional, Tuple, cast
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.storage.database import LoggingTransaction from synapse.storage.database import LoggingTransaction
@ -91,12 +91,17 @@ class EventForwardExtremitiesStore(
async def get_forward_extremities_for_room( async def get_forward_extremities_for_room(
self, room_id: str self, room_id: str
) -> List[Dict[str, Any]]: ) -> List[Tuple[str, int, int, Optional[int]]]:
"""Get list of forward extremities for a room.""" """
Get list of forward extremities for a room.
Returns:
A list of tuples of event_id, state_group, depth, and received_ts.
"""
def get_forward_extremities_for_room_txn( def get_forward_extremities_for_room_txn(
txn: LoggingTransaction, txn: LoggingTransaction,
) -> List[Dict[str, Any]]: ) -> List[Tuple[str, int, int, Optional[int]]]:
sql = """ sql = """
SELECT event_id, state_group, depth, received_ts SELECT event_id, state_group, depth, received_ts
FROM event_forward_extremities FROM event_forward_extremities
@ -106,7 +111,7 @@ class EventForwardExtremitiesStore(
""" """
txn.execute(sql, (room_id,)) txn.execute(sql, (room_id,))
return self.db_pool.cursor_to_dict(txn) return cast(List[Tuple[str, int, int, Optional[int]]], txn.fetchall())
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_forward_extremities_for_room", "get_forward_extremities_for_room",

View File

@ -650,7 +650,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
async def get_remote_media_ids( async def get_remote_media_ids(
self, before_ts: int, include_quarantined_media: bool self, before_ts: int, include_quarantined_media: bool
) -> List[Dict[str, str]]: ) -> List[Tuple[str, str, str]]:
""" """
Retrieve a list of server name, media ID tuples from the remote media cache. Retrieve a list of server name, media ID tuples from the remote media cache.
@ -664,12 +664,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
A list of tuples containing: A list of tuples containing:
* The server name of homeserver where the media originates from, * The server name of homeserver where the media originates from,
* The ID of the media. * The ID of the media.
* The filesystem ID.
"""
sql = """
SELECT media_origin, media_id, filesystem_id
FROM remote_media_cache
WHERE last_access_ts < ?
""" """
sql = (
"SELECT media_origin, media_id, filesystem_id"
" FROM remote_media_cache"
" WHERE last_access_ts < ?"
)
if include_quarantined_media is False: if include_quarantined_media is False:
# Only include media that has not been quarantined # Only include media that has not been quarantined
@ -677,8 +679,9 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
AND quarantined_by IS NULL AND quarantined_by IS NULL
""" """
return await self.db_pool.execute( return cast(
"get_remote_media_ids", self.db_pool.cursor_to_dict, sql, before_ts List[Tuple[str, str, str]],
await self.db_pool.execute("get_remote_media_ids", sql, before_ts),
) )
async def delete_remote_media(self, media_origin: str, media_id: str) -> None: async def delete_remote_media(self, media_origin: str, media_id: str) -> None:

View File

@ -151,6 +151,22 @@ class ThreepidResult:
added_at: int added_at: int
@attr.s(frozen=True, slots=True, auto_attribs=True)
class ThreepidValidationSession:
address: str
"""address of the 3pid"""
medium: str
"""medium of the 3pid"""
client_secret: str
"""a secret provided by the client for this validation session"""
session_id: str
"""ID of the validation session"""
last_send_attempt: int
"""a number serving to dedupe send attempts for this session"""
validated_at: Optional[int]
"""timestamp of when this session was validated if so"""
class RegistrationWorkerStore(CacheInvalidationWorkerStore): class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def __init__( def __init__(
self, self,
@ -1172,7 +1188,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
address: Optional[str] = None, address: Optional[str] = None,
sid: Optional[str] = None, sid: Optional[str] = None,
validated: Optional[bool] = True, validated: Optional[bool] = True,
) -> Optional[Dict[str, Any]]: ) -> Optional[ThreepidValidationSession]:
"""Gets a session_id and last_send_attempt (if available) for a """Gets a session_id and last_send_attempt (if available) for a
combination of validation metadata combination of validation metadata
@ -1187,15 +1203,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
perform no filtering perform no filtering
Returns: Returns:
A dict containing the following: A ThreepidValidationSession or None if a validation session is not found
* address - address of the 3pid
* medium - medium of the 3pid
* client_secret - a secret provided by the client for this validation session
* session_id - ID of the validation session
* send_attempt - a number serving to dedupe send attempts for this session
* validated_at - timestamp of when this session was validated if so
Otherwise None if a validation session is not found
""" """
if not client_secret: if not client_secret:
raise SynapseError( raise SynapseError(
@ -1214,7 +1222,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def get_threepid_validation_session_txn( def get_threepid_validation_session_txn(
txn: LoggingTransaction, txn: LoggingTransaction,
) -> Optional[Dict[str, Any]]: ) -> Optional[ThreepidValidationSession]:
sql = """ sql = """
SELECT address, session_id, medium, client_secret, SELECT address, session_id, medium, client_secret,
last_send_attempt, validated_at last_send_attempt, validated_at
@ -1229,11 +1237,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
sql += " LIMIT 1" sql += " LIMIT 1"
txn.execute(sql, list(keyvalues.values())) txn.execute(sql, list(keyvalues.values()))
rows = self.db_pool.cursor_to_dict(txn) row = txn.fetchone()
if not rows: if not row:
return None return None
return rows[0] return ThreepidValidationSession(
address=row[0],
session_id=row[1],
medium=row[2],
client_secret=row[3],
last_send_attempt=row[4],
validated_at=row[5],
)
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_threepid_validation_session", get_threepid_validation_session_txn "get_threepid_validation_session", get_threepid_validation_session_txn

View File

@ -940,7 +940,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
like_clause = "%:" + host like_clause = "%:" + host
rows = await self.db_pool.execute( rows = await self.db_pool.execute(
"is_host_joined", None, sql, membership, room_id, like_clause "is_host_joined", sql, membership, room_id, like_clause
) )
if not rows: if not rows:
@ -1168,7 +1168,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
AND forgotten = 0; AND forgotten = 0;
""" """
rows = await self.db_pool.execute("is_forgotten_room", None, sql, room_id) rows = await self.db_pool.execute("is_forgotten_room", sql, room_id)
# `count(*)` returns always an integer # `count(*)` returns always an integer
# If any rows still exist it means someone has not forgotten this room yet # If any rows still exist it means someone has not forgotten this room yet

View File

@ -26,6 +26,7 @@ from typing import (
Set, Set,
Tuple, Tuple,
Union, Union,
cast,
) )
import attr import attr
@ -506,16 +507,18 @@ class SearchStore(SearchBackgroundUpdateStore):
# entire table from the database. # entire table from the database.
sql += " ORDER BY rank DESC LIMIT 500" sql += " ORDER BY rank DESC LIMIT 500"
results = await self.db_pool.execute( # List of tuples of (rank, room_id, event_id).
"search_msgs", self.db_pool.cursor_to_dict, sql, *args results = cast(
List[Tuple[Union[int, float], str, str]],
await self.db_pool.execute("search_msgs", sql, *args),
) )
results = list(filter(lambda row: row["room_id"] in room_ids, results)) results = list(filter(lambda row: row[1] in room_ids, results))
# We set redact_behaviour to block here to prevent redacted events being returned in # We set redact_behaviour to block here to prevent redacted events being returned in
# search results (which is a data leak) # search results (which is a data leak)
events = await self.get_events_as_list( # type: ignore[attr-defined] events = await self.get_events_as_list( # type: ignore[attr-defined]
[r["event_id"] for r in results], [r[2] for r in results],
redact_behaviour=EventRedactBehaviour.block, redact_behaviour=EventRedactBehaviour.block,
) )
@ -527,16 +530,18 @@ class SearchStore(SearchBackgroundUpdateStore):
count_sql += " GROUP BY room_id" count_sql += " GROUP BY room_id"
count_results = await self.db_pool.execute( # List of tuples of (room_id, count).
"search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args count_results = cast(
List[Tuple[str, int]],
await self.db_pool.execute("search_rooms_count", count_sql, *count_args),
) )
count = sum(row["count"] for row in count_results if row["room_id"] in room_ids) count = sum(row[1] for row in count_results if row[0] in room_ids)
return { return {
"results": [ "results": [
{"event": event_map[r["event_id"]], "rank": r["rank"]} {"event": event_map[r[2]], "rank": r[0]}
for r in results for r in results
if r["event_id"] in event_map if r[2] in event_map
], ],
"highlights": highlights, "highlights": highlights,
"count": count, "count": count,
@ -604,7 +609,7 @@ class SearchStore(SearchBackgroundUpdateStore):
search_query = search_term search_query = search_term
sql = """ sql = """
SELECT ts_rank_cd(vector, websearch_to_tsquery('english', ?)) as rank, SELECT ts_rank_cd(vector, websearch_to_tsquery('english', ?)) as rank,
origin_server_ts, stream_ordering, room_id, event_id room_id, event_id, origin_server_ts, stream_ordering
FROM event_search FROM event_search
WHERE vector @@ websearch_to_tsquery('english', ?) AND WHERE vector @@ websearch_to_tsquery('english', ?) AND
""" """
@ -665,16 +670,18 @@ class SearchStore(SearchBackgroundUpdateStore):
# mypy expects to append only a `str`, not an `int` # mypy expects to append only a `str`, not an `int`
args.append(limit) args.append(limit)
results = await self.db_pool.execute( # List of tuples of (rank, room_id, event_id, origin_server_ts, stream_ordering).
"search_rooms", self.db_pool.cursor_to_dict, sql, *args results = cast(
List[Tuple[Union[int, float], str, str, int, int]],
await self.db_pool.execute("search_rooms", sql, *args),
) )
results = list(filter(lambda row: row["room_id"] in room_ids, results)) results = list(filter(lambda row: row[1] in room_ids, results))
# We set redact_behaviour to block here to prevent redacted events being returned in # We set redact_behaviour to block here to prevent redacted events being returned in
# search results (which is a data leak) # search results (which is a data leak)
events = await self.get_events_as_list( # type: ignore[attr-defined] events = await self.get_events_as_list( # type: ignore[attr-defined]
[r["event_id"] for r in results], [r[2] for r in results],
redact_behaviour=EventRedactBehaviour.block, redact_behaviour=EventRedactBehaviour.block,
) )
@ -686,22 +693,23 @@ class SearchStore(SearchBackgroundUpdateStore):
count_sql += " GROUP BY room_id" count_sql += " GROUP BY room_id"
count_results = await self.db_pool.execute( # List of tuples of (room_id, count).
"search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args count_results = cast(
List[Tuple[str, int]],
await self.db_pool.execute("search_rooms_count", count_sql, *count_args),
) )
count = sum(row["count"] for row in count_results if row["room_id"] in room_ids) count = sum(row[1] for row in count_results if row[0] in room_ids)
return { return {
"results": [ "results": [
{ {
"event": event_map[r["event_id"]], "event": event_map[r[2]],
"rank": r["rank"], "rank": r[0],
"pagination_token": "%s,%s" "pagination_token": "%s,%s" % (r[3], r[4]),
% (r["origin_server_ts"], r["stream_ordering"]),
} }
for r in results for r in results
if r["event_id"] in event_map if r[2] in event_map
], ],
"highlights": highlights, "highlights": highlights,
"count": count, "count": count,

View File

@ -679,7 +679,7 @@ class StatsStore(StateDeltasStore):
order_by: Optional[str] = UserSortOrder.USER_ID.value, order_by: Optional[str] = UserSortOrder.USER_ID.value,
direction: Direction = Direction.FORWARDS, direction: Direction = Direction.FORWARDS,
search_term: Optional[str] = None, search_term: Optional[str] = None,
) -> Tuple[List[JsonDict], int]: ) -> Tuple[List[Tuple[str, Optional[str], int, int]], int]:
"""Function to retrieve a paginated list of users and their uploaded local media """Function to retrieve a paginated list of users and their uploaded local media
(size and number). This will return a json list of users and the (size and number). This will return a json list of users and the
total number of users matching the filter criteria. total number of users matching the filter criteria.
@ -692,14 +692,19 @@ class StatsStore(StateDeltasStore):
order_by: the sort order of the returned list order_by: the sort order of the returned list
direction: sort ascending or descending direction: sort ascending or descending
search_term: a string to filter user names by search_term: a string to filter user names by
Returns: Returns:
A list of user dicts and an integer representing the total number of A tuple of:
users that exist given this query A list of tuples of user information (the user ID, displayname,
total number of media, total length of media) and
An integer representing the total number of users that exist
given this query
""" """
def get_users_media_usage_paginate_txn( def get_users_media_usage_paginate_txn(
txn: LoggingTransaction, txn: LoggingTransaction,
) -> Tuple[List[JsonDict], int]: ) -> Tuple[List[Tuple[str, Optional[str], int, int]], int]:
filters = [] filters = []
args: list = [] args: list = []
@ -773,7 +778,7 @@ class StatsStore(StateDeltasStore):
args += [limit, start] args += [limit, start]
txn.execute(sql, args) txn.execute(sql, args)
users = self.db_pool.cursor_to_dict(txn) users = cast(List[Tuple[str, Optional[str], int, int]], txn.fetchall())
return users, count return users, count

View File

@ -1078,7 +1078,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
""" """
row = await self.db_pool.execute( row = await self.db_pool.execute(
"get_current_topological_token", None, sql, room_id, room_id, stream_key "get_current_topological_token", sql, room_id, room_id, stream_key
) )
return row[0][0] if row else 0 return row[0][0] if row else 0
@ -1636,7 +1636,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
rows = await self.db_pool.execute( rows = await self.db_pool.execute(
"get_timeline_gaps", "get_timeline_gaps",
None,
sql, sql,
room_id, room_id,
from_token.stream if from_token else 0, from_token.stream if from_token else 0,

View File

@ -478,7 +478,10 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
destination: Optional[str] = None, destination: Optional[str] = None,
order_by: str = DestinationSortOrder.DESTINATION.value, order_by: str = DestinationSortOrder.DESTINATION.value,
direction: Direction = Direction.FORWARDS, direction: Direction = Direction.FORWARDS,
) -> Tuple[List[JsonDict], int]: ) -> Tuple[
List[Tuple[str, Optional[int], Optional[int], Optional[int], Optional[int]]],
int,
]:
"""Function to retrieve a paginated list of destinations. """Function to retrieve a paginated list of destinations.
This will return a json list of destinations and the This will return a json list of destinations and the
total number of destinations matching the filter criteria. total number of destinations matching the filter criteria.
@ -490,13 +493,23 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
order_by: the sort order of the returned list order_by: the sort order of the returned list
direction: sort ascending or descending direction: sort ascending or descending
Returns: Returns:
A tuple of a list of mappings from destination to information A tuple of a list of tuples of destination information:
* destination
* retry_last_ts
* retry_interval
* failure_ts
* last_successful_stream_ordering
and a count of total destinations. and a count of total destinations.
""" """
def get_destinations_paginate_txn( def get_destinations_paginate_txn(
txn: LoggingTransaction, txn: LoggingTransaction,
) -> Tuple[List[JsonDict], int]: ) -> Tuple[
List[
Tuple[str, Optional[int], Optional[int], Optional[int], Optional[int]]
],
int,
]:
order_by_column = DestinationSortOrder(order_by).value order_by_column = DestinationSortOrder(order_by).value
if direction == Direction.BACKWARDS: if direction == Direction.BACKWARDS:
@ -523,7 +536,14 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
LIMIT ? OFFSET ? LIMIT ? OFFSET ?
""" """
txn.execute(sql, args + [limit, start]) txn.execute(sql, args + [limit, start])
destinations = self.db_pool.cursor_to_dict(txn) destinations = cast(
List[
Tuple[
str, Optional[int], Optional[int], Optional[int], Optional[int]
]
],
txn.fetchall(),
)
return destinations, count return destinations, count
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(

View File

@ -1145,15 +1145,19 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
raise Exception("Unrecognized database engine") raise Exception("Unrecognized database engine")
results = cast( results = cast(
List[UserProfile], List[Tuple[str, Optional[str], Optional[str]]],
await self.db_pool.execute( await self.db_pool.execute("search_user_dir", sql, *args),
"search_user_dir", self.db_pool.cursor_to_dict, sql, *args
),
) )
limited = len(results) > limit limited = len(results) > limit
return {"limited": limited, "results": results[0:limit]} return {
"limited": limited,
"results": [
{"user_id": r[0], "display_name": r[1], "avatar_url": r[2]}
for r in results[0:limit]
],
}
def _filter_text_for_index(text: str) -> str: def _filter_text_for_index(text: str) -> str:

View File

@ -359,7 +359,6 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
if max_group is None: if max_group is None:
rows = await self.db_pool.execute( rows = await self.db_pool.execute(
"_background_deduplicate_state", "_background_deduplicate_state",
None,
"SELECT coalesce(max(id), 0) FROM state_groups", "SELECT coalesce(max(id), 0) FROM state_groups",
) )
max_group = rows[0][0] max_group = rows[0][0]

View File

@ -100,7 +100,6 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
event_id, stream_ordering = self.get_success( event_id, stream_ordering = self.get_success(
self.hs.get_datastores().main.db_pool.execute( self.hs.get_datastores().main.db_pool.execute(
"test:get_destination_rooms", "test:get_destination_rooms",
None,
""" """
SELECT event_id, stream_ordering SELECT event_id, stream_ordering
FROM destination_rooms dr FROM destination_rooms dr

View File

@ -457,8 +457,8 @@ class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase):
); );
""" """
self.get_success( self.get_success(
self.store.db_pool.execute( self.store.db_pool.runInteraction(
"test_not_null_constraint", lambda _: None, table_sql "test_not_null_constraint", lambda txn: txn.execute(table_sql)
) )
) )
@ -466,8 +466,8 @@ class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase):
# using SQLite. # using SQLite.
index_sql = "CREATE INDEX test_index ON test_constraint(a)" index_sql = "CREATE INDEX test_index ON test_constraint(a)"
self.get_success( self.get_success(
self.store.db_pool.execute( self.store.db_pool.runInteraction(
"test_not_null_constraint", lambda _: None, index_sql "test_not_null_constraint", lambda txn: txn.execute(index_sql)
) )
) )
@ -574,13 +574,13 @@ class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase):
); );
""" """
self.get_success( self.get_success(
self.store.db_pool.execute( self.store.db_pool.runInteraction(
"test_foreign_key_constraint", lambda _: None, base_sql "test_foreign_key_constraint", lambda txn: txn.execute(base_sql)
) )
) )
self.get_success( self.get_success(
self.store.db_pool.execute( self.store.db_pool.runInteraction(
"test_foreign_key_constraint", lambda _: None, table_sql "test_foreign_key_constraint", lambda txn: txn.execute(table_sql)
) )
) )

View File

@ -120,7 +120,7 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase):
res = self.get_success( res = self.get_success(
self.store.db_pool.execute( self.store.db_pool.execute(
"", None, "SELECT full_user_id from profiles ORDER BY full_user_id" "", "SELECT full_user_id from profiles ORDER BY full_user_id"
) )
) )
self.assertEqual(len(res), len(expected_values)) self.assertEqual(len(res), len(expected_values))

View File

@ -87,7 +87,7 @@ class UserFiltersStoreTestCase(unittest.HomeserverTestCase):
res = self.get_success( res = self.get_success(
self.store.db_pool.execute( self.store.db_pool.execute(
"", None, "SELECT full_user_id from user_filters ORDER BY full_user_id" "", "SELECT full_user_id from user_filters ORDER BY full_user_id"
) )
) )
self.assertEqual(len(res), len(expected_values)) self.assertEqual(len(res), len(expected_values))