mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Remove more usages of cursor_to_dict. (#16551)
Mostly to improve type safety.
This commit is contained in:
parent
85e5f2dc25
commit
679c691f6f
1
changelog.d/16551.misc
Normal file
1
changelog.d/16551.misc
Normal file
@ -0,0 +1 @@
|
||||
Improve type hints.
|
@ -19,6 +19,8 @@ import logging
|
||||
import urllib.parse
|
||||
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import attr
|
||||
|
||||
from synapse.api.errors import (
|
||||
CodeMessageException,
|
||||
Codes,
|
||||
@ -357,9 +359,9 @@ class IdentityHandler:
|
||||
|
||||
# Check to see if a session already exists and that it is not yet
|
||||
# marked as validated
|
||||
if session and session.get("validated_at") is None:
|
||||
session_id = session["session_id"]
|
||||
last_send_attempt = session["last_send_attempt"]
|
||||
if session and session.validated_at is None:
|
||||
session_id = session.session_id
|
||||
last_send_attempt = session.last_send_attempt
|
||||
|
||||
# Check that the send_attempt is higher than previous attempts
|
||||
if send_attempt <= last_send_attempt:
|
||||
@ -480,7 +482,6 @@ class IdentityHandler:
|
||||
|
||||
# We don't actually know which medium this 3PID is. Thus we first assume it's email,
|
||||
# and if validation fails we try msisdn
|
||||
validation_session = None
|
||||
|
||||
# Try to validate as email
|
||||
if self.hs.config.email.can_verify_email:
|
||||
@ -488,19 +489,18 @@ class IdentityHandler:
|
||||
validation_session = await self.store.get_threepid_validation_session(
|
||||
"email", client_secret, sid=sid, validated=True
|
||||
)
|
||||
|
||||
if validation_session:
|
||||
return validation_session
|
||||
if validation_session:
|
||||
return attr.asdict(validation_session)
|
||||
|
||||
# Try to validate as msisdn
|
||||
if self.hs.config.registration.account_threepid_delegate_msisdn:
|
||||
# Ask our delegated msisdn identity server
|
||||
validation_session = await self.threepid_from_creds(
|
||||
return await self.threepid_from_creds(
|
||||
self.hs.config.registration.account_threepid_delegate_msisdn,
|
||||
threepid_creds,
|
||||
)
|
||||
|
||||
return validation_session
|
||||
return None
|
||||
|
||||
async def proxy_msisdn_submit_token(
|
||||
self, id_server: str, client_secret: str, sid: str, token: str
|
||||
|
@ -187,9 +187,9 @@ class _BaseThreepidAuthChecker:
|
||||
|
||||
if row:
|
||||
threepid = {
|
||||
"medium": row["medium"],
|
||||
"address": row["address"],
|
||||
"validated_at": row["validated_at"],
|
||||
"medium": row.medium,
|
||||
"address": row.address,
|
||||
"validated_at": row.validated_at,
|
||||
}
|
||||
|
||||
# Valid threepid returned, delete from the db
|
||||
|
@ -949,10 +949,7 @@ class MediaRepository:
|
||||
|
||||
deleted = 0
|
||||
|
||||
for media in old_media:
|
||||
origin = media["media_origin"]
|
||||
media_id = media["media_id"]
|
||||
file_id = media["filesystem_id"]
|
||||
for origin, media_id, file_id in old_media:
|
||||
key = (origin, media_id)
|
||||
|
||||
logger.info("Deleting: %r", key)
|
||||
|
@ -85,7 +85,19 @@ class ListDestinationsRestServlet(RestServlet):
|
||||
destinations, total = await self._store.get_destinations_paginate(
|
||||
start, limit, destination, order_by, direction
|
||||
)
|
||||
response = {"destinations": destinations, "total": total}
|
||||
response = {
|
||||
"destinations": [
|
||||
{
|
||||
"destination": r[0],
|
||||
"retry_last_ts": r[1],
|
||||
"retry_interval": r[2],
|
||||
"failure_ts": r[3],
|
||||
"last_successful_stream_ordering": r[4],
|
||||
}
|
||||
for r in destinations
|
||||
],
|
||||
"total": total,
|
||||
}
|
||||
if (start + limit) < total:
|
||||
response["next_token"] = str(start + len(destinations))
|
||||
|
||||
|
@ -724,7 +724,17 @@ class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet):
|
||||
room_id, _ = await self.resolve_room_id(room_identifier)
|
||||
|
||||
extremities = await self.store.get_forward_extremities_for_room(room_id)
|
||||
return HTTPStatus.OK, {"count": len(extremities), "results": extremities}
|
||||
result = [
|
||||
{
|
||||
"event_id": ex[0],
|
||||
"state_group": ex[1],
|
||||
"depth": ex[2],
|
||||
"received_ts": ex[3],
|
||||
}
|
||||
for ex in extremities
|
||||
]
|
||||
|
||||
return HTTPStatus.OK, {"count": len(extremities), "results": result}
|
||||
|
||||
|
||||
class RoomEventContextServlet(RestServlet):
|
||||
|
@ -108,7 +108,18 @@ class UserMediaStatisticsRestServlet(RestServlet):
|
||||
users_media, total = await self.store.get_users_media_usage_paginate(
|
||||
start, limit, from_ts, until_ts, order_by, direction, search_term
|
||||
)
|
||||
ret = {"users": users_media, "total": total}
|
||||
ret = {
|
||||
"users": [
|
||||
{
|
||||
"user_id": r[0],
|
||||
"displayname": r[1],
|
||||
"media_count": r[2],
|
||||
"media_length": r[3],
|
||||
}
|
||||
for r in users_media
|
||||
],
|
||||
"total": total,
|
||||
}
|
||||
if (start + limit) < total:
|
||||
ret["next_token"] = start + len(users_media)
|
||||
|
||||
|
@ -35,7 +35,6 @@ from typing import (
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
@ -1047,43 +1046,20 @@ class DatabasePool:
|
||||
results = [dict(zip(col_headers, row)) for row in cursor]
|
||||
return results
|
||||
|
||||
@overload
|
||||
async def execute(
|
||||
self, desc: str, decoder: Literal[None], query: str, *args: Any
|
||||
) -> List[Tuple[Any, ...]]:
|
||||
...
|
||||
|
||||
@overload
|
||||
async def execute(
|
||||
self, desc: str, decoder: Callable[[Cursor], R], query: str, *args: Any
|
||||
) -> R:
|
||||
...
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
desc: str,
|
||||
decoder: Optional[Callable[[Cursor], R]],
|
||||
query: str,
|
||||
*args: Any,
|
||||
) -> Union[List[Tuple[Any, ...]], R]:
|
||||
async def execute(self, desc: str, query: str, *args: Any) -> List[Tuple[Any, ...]]:
|
||||
"""Runs a single query for a result set.
|
||||
|
||||
Args:
|
||||
desc: description of the transaction, for logging and metrics
|
||||
decoder - The function which can resolve the cursor results to
|
||||
something meaningful.
|
||||
query - The query string to execute
|
||||
*args - Query args.
|
||||
Returns:
|
||||
The result of decoder(results)
|
||||
"""
|
||||
|
||||
def interaction(txn: LoggingTransaction) -> Union[List[Tuple[Any, ...]], R]:
|
||||
def interaction(txn: LoggingTransaction) -> List[Tuple[Any, ...]]:
|
||||
txn.execute(query, args)
|
||||
if decoder:
|
||||
return decoder(txn)
|
||||
else:
|
||||
return txn.fetchall()
|
||||
return txn.fetchall()
|
||||
|
||||
return await self.runInteraction(desc, interaction)
|
||||
|
||||
|
@ -93,7 +93,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
|
||||
"""
|
||||
|
||||
rows = await self.db_pool.execute(
|
||||
"_censor_redactions_fetch", None, sql, before_ts, 100
|
||||
"_censor_redactions_fetch", sql, before_ts, 100
|
||||
)
|
||||
|
||||
updates = []
|
||||
|
@ -894,7 +894,6 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
||||
|
||||
rows = await self.db_pool.execute(
|
||||
"get_all_devices_changed",
|
||||
None,
|
||||
sql,
|
||||
from_key,
|
||||
to_key,
|
||||
@ -978,7 +977,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
||||
WHERE from_user_id = ? AND stream_id > ?
|
||||
"""
|
||||
rows = await self.db_pool.execute(
|
||||
"get_users_whose_signatures_changed", None, sql, user_id, from_key
|
||||
"get_users_whose_signatures_changed", sql, user_id, from_key
|
||||
)
|
||||
return {user for row in rows for user in db_to_json(row[0])}
|
||||
else:
|
||||
|
@ -155,7 +155,6 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
||||
"""
|
||||
rows = await self.db_pool.execute(
|
||||
"get_e2e_device_keys_for_federation_query_check",
|
||||
None,
|
||||
sql,
|
||||
now_stream_id,
|
||||
user_id,
|
||||
|
@ -1310,12 +1310,9 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
||||
|
||||
# ANALYZE the new column to build stats on it, to encourage PostgreSQL to use the
|
||||
# indexes on it.
|
||||
# We need to pass execute a dummy function to handle the txn's result otherwise
|
||||
# it tries to call fetchall() on it and fails because there's no result to fetch.
|
||||
await self.db_pool.execute(
|
||||
await self.db_pool.runInteraction(
|
||||
"background_analyze_new_stream_ordering_column",
|
||||
lambda txn: None,
|
||||
"ANALYZE events(stream_ordering2)",
|
||||
lambda txn: txn.execute("ANALYZE events(stream_ordering2)"),
|
||||
)
|
||||
|
||||
await self.db_pool.runInteraction(
|
||||
|
@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List
|
||||
from typing import List, Optional, Tuple, cast
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.storage.database import LoggingTransaction
|
||||
@ -91,12 +91,17 @@ class EventForwardExtremitiesStore(
|
||||
|
||||
async def get_forward_extremities_for_room(
|
||||
self, room_id: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Get list of forward extremities for a room."""
|
||||
) -> List[Tuple[str, int, int, Optional[int]]]:
|
||||
"""
|
||||
Get list of forward extremities for a room.
|
||||
|
||||
Returns:
|
||||
A list of tuples of event_id, state_group, depth, and received_ts.
|
||||
"""
|
||||
|
||||
def get_forward_extremities_for_room_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> List[Tuple[str, int, int, Optional[int]]]:
|
||||
sql = """
|
||||
SELECT event_id, state_group, depth, received_ts
|
||||
FROM event_forward_extremities
|
||||
@ -106,7 +111,7 @@ class EventForwardExtremitiesStore(
|
||||
"""
|
||||
|
||||
txn.execute(sql, (room_id,))
|
||||
return self.db_pool.cursor_to_dict(txn)
|
||||
return cast(List[Tuple[str, int, int, Optional[int]]], txn.fetchall())
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_forward_extremities_for_room",
|
||||
|
@ -650,7 +650,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
||||
|
||||
async def get_remote_media_ids(
|
||||
self, before_ts: int, include_quarantined_media: bool
|
||||
) -> List[Dict[str, str]]:
|
||||
) -> List[Tuple[str, str, str]]:
|
||||
"""
|
||||
Retrieve a list of server name, media ID tuples from the remote media cache.
|
||||
|
||||
@ -664,12 +664,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
||||
A list of tuples containing:
|
||||
* The server name of homeserver where the media originates from,
|
||||
* The ID of the media.
|
||||
* The filesystem ID.
|
||||
"""
|
||||
|
||||
sql = """
|
||||
SELECT media_origin, media_id, filesystem_id
|
||||
FROM remote_media_cache
|
||||
WHERE last_access_ts < ?
|
||||
"""
|
||||
sql = (
|
||||
"SELECT media_origin, media_id, filesystem_id"
|
||||
" FROM remote_media_cache"
|
||||
" WHERE last_access_ts < ?"
|
||||
)
|
||||
|
||||
if include_quarantined_media is False:
|
||||
# Only include media that has not been quarantined
|
||||
@ -677,8 +679,9 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
||||
AND quarantined_by IS NULL
|
||||
"""
|
||||
|
||||
return await self.db_pool.execute(
|
||||
"get_remote_media_ids", self.db_pool.cursor_to_dict, sql, before_ts
|
||||
return cast(
|
||||
List[Tuple[str, str, str]],
|
||||
await self.db_pool.execute("get_remote_media_ids", sql, before_ts),
|
||||
)
|
||||
|
||||
async def delete_remote_media(self, media_origin: str, media_id: str) -> None:
|
||||
|
@ -151,6 +151,22 @@ class ThreepidResult:
|
||||
added_at: int
|
||||
|
||||
|
||||
@attr.s(frozen=True, slots=True, auto_attribs=True)
|
||||
class ThreepidValidationSession:
|
||||
address: str
|
||||
"""address of the 3pid"""
|
||||
medium: str
|
||||
"""medium of the 3pid"""
|
||||
client_secret: str
|
||||
"""a secret provided by the client for this validation session"""
|
||||
session_id: str
|
||||
"""ID of the validation session"""
|
||||
last_send_attempt: int
|
||||
"""a number serving to dedupe send attempts for this session"""
|
||||
validated_at: Optional[int]
|
||||
"""timestamp of when this session was validated if so"""
|
||||
|
||||
|
||||
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||
def __init__(
|
||||
self,
|
||||
@ -1172,7 +1188,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||
address: Optional[str] = None,
|
||||
sid: Optional[str] = None,
|
||||
validated: Optional[bool] = True,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
) -> Optional[ThreepidValidationSession]:
|
||||
"""Gets a session_id and last_send_attempt (if available) for a
|
||||
combination of validation metadata
|
||||
|
||||
@ -1187,15 +1203,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||
perform no filtering
|
||||
|
||||
Returns:
|
||||
A dict containing the following:
|
||||
* address - address of the 3pid
|
||||
* medium - medium of the 3pid
|
||||
* client_secret - a secret provided by the client for this validation session
|
||||
* session_id - ID of the validation session
|
||||
* send_attempt - a number serving to dedupe send attempts for this session
|
||||
* validated_at - timestamp of when this session was validated if so
|
||||
|
||||
Otherwise None if a validation session is not found
|
||||
A ThreepidValidationSession or None if a validation session is not found
|
||||
"""
|
||||
if not client_secret:
|
||||
raise SynapseError(
|
||||
@ -1214,7 +1222,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||
|
||||
def get_threepid_validation_session_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
) -> Optional[ThreepidValidationSession]:
|
||||
sql = """
|
||||
SELECT address, session_id, medium, client_secret,
|
||||
last_send_attempt, validated_at
|
||||
@ -1229,11 +1237,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||
sql += " LIMIT 1"
|
||||
|
||||
txn.execute(sql, list(keyvalues.values()))
|
||||
rows = self.db_pool.cursor_to_dict(txn)
|
||||
if not rows:
|
||||
row = txn.fetchone()
|
||||
if not row:
|
||||
return None
|
||||
|
||||
return rows[0]
|
||||
return ThreepidValidationSession(
|
||||
address=row[0],
|
||||
session_id=row[1],
|
||||
medium=row[2],
|
||||
client_secret=row[3],
|
||||
last_send_attempt=row[4],
|
||||
validated_at=row[5],
|
||||
)
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_threepid_validation_session", get_threepid_validation_session_txn
|
||||
|
@ -940,7 +940,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
|
||||
like_clause = "%:" + host
|
||||
|
||||
rows = await self.db_pool.execute(
|
||||
"is_host_joined", None, sql, membership, room_id, like_clause
|
||||
"is_host_joined", sql, membership, room_id, like_clause
|
||||
)
|
||||
|
||||
if not rows:
|
||||
@ -1168,7 +1168,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
|
||||
AND forgotten = 0;
|
||||
"""
|
||||
|
||||
rows = await self.db_pool.execute("is_forgotten_room", None, sql, room_id)
|
||||
rows = await self.db_pool.execute("is_forgotten_room", sql, room_id)
|
||||
|
||||
# `count(*)` returns always an integer
|
||||
# If any rows still exist it means someone has not forgotten this room yet
|
||||
|
@ -26,6 +26,7 @@ from typing import (
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
import attr
|
||||
@ -506,16 +507,18 @@ class SearchStore(SearchBackgroundUpdateStore):
|
||||
# entire table from the database.
|
||||
sql += " ORDER BY rank DESC LIMIT 500"
|
||||
|
||||
results = await self.db_pool.execute(
|
||||
"search_msgs", self.db_pool.cursor_to_dict, sql, *args
|
||||
# List of tuples of (rank, room_id, event_id).
|
||||
results = cast(
|
||||
List[Tuple[Union[int, float], str, str]],
|
||||
await self.db_pool.execute("search_msgs", sql, *args),
|
||||
)
|
||||
|
||||
results = list(filter(lambda row: row["room_id"] in room_ids, results))
|
||||
results = list(filter(lambda row: row[1] in room_ids, results))
|
||||
|
||||
# We set redact_behaviour to block here to prevent redacted events being returned in
|
||||
# search results (which is a data leak)
|
||||
events = await self.get_events_as_list( # type: ignore[attr-defined]
|
||||
[r["event_id"] for r in results],
|
||||
[r[2] for r in results],
|
||||
redact_behaviour=EventRedactBehaviour.block,
|
||||
)
|
||||
|
||||
@ -527,16 +530,18 @@ class SearchStore(SearchBackgroundUpdateStore):
|
||||
|
||||
count_sql += " GROUP BY room_id"
|
||||
|
||||
count_results = await self.db_pool.execute(
|
||||
"search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args
|
||||
# List of tuples of (room_id, count).
|
||||
count_results = cast(
|
||||
List[Tuple[str, int]],
|
||||
await self.db_pool.execute("search_rooms_count", count_sql, *count_args),
|
||||
)
|
||||
|
||||
count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
|
||||
count = sum(row[1] for row in count_results if row[0] in room_ids)
|
||||
return {
|
||||
"results": [
|
||||
{"event": event_map[r["event_id"]], "rank": r["rank"]}
|
||||
{"event": event_map[r[2]], "rank": r[0]}
|
||||
for r in results
|
||||
if r["event_id"] in event_map
|
||||
if r[2] in event_map
|
||||
],
|
||||
"highlights": highlights,
|
||||
"count": count,
|
||||
@ -604,7 +609,7 @@ class SearchStore(SearchBackgroundUpdateStore):
|
||||
search_query = search_term
|
||||
sql = """
|
||||
SELECT ts_rank_cd(vector, websearch_to_tsquery('english', ?)) as rank,
|
||||
origin_server_ts, stream_ordering, room_id, event_id
|
||||
room_id, event_id, origin_server_ts, stream_ordering
|
||||
FROM event_search
|
||||
WHERE vector @@ websearch_to_tsquery('english', ?) AND
|
||||
"""
|
||||
@ -665,16 +670,18 @@ class SearchStore(SearchBackgroundUpdateStore):
|
||||
# mypy expects to append only a `str`, not an `int`
|
||||
args.append(limit)
|
||||
|
||||
results = await self.db_pool.execute(
|
||||
"search_rooms", self.db_pool.cursor_to_dict, sql, *args
|
||||
# List of tuples of (rank, room_id, event_id, origin_server_ts, stream_ordering).
|
||||
results = cast(
|
||||
List[Tuple[Union[int, float], str, str, int, int]],
|
||||
await self.db_pool.execute("search_rooms", sql, *args),
|
||||
)
|
||||
|
||||
results = list(filter(lambda row: row["room_id"] in room_ids, results))
|
||||
results = list(filter(lambda row: row[1] in room_ids, results))
|
||||
|
||||
# We set redact_behaviour to block here to prevent redacted events being returned in
|
||||
# search results (which is a data leak)
|
||||
events = await self.get_events_as_list( # type: ignore[attr-defined]
|
||||
[r["event_id"] for r in results],
|
||||
[r[2] for r in results],
|
||||
redact_behaviour=EventRedactBehaviour.block,
|
||||
)
|
||||
|
||||
@ -686,22 +693,23 @@ class SearchStore(SearchBackgroundUpdateStore):
|
||||
|
||||
count_sql += " GROUP BY room_id"
|
||||
|
||||
count_results = await self.db_pool.execute(
|
||||
"search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args
|
||||
# List of tuples of (room_id, count).
|
||||
count_results = cast(
|
||||
List[Tuple[str, int]],
|
||||
await self.db_pool.execute("search_rooms_count", count_sql, *count_args),
|
||||
)
|
||||
|
||||
count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
|
||||
count = sum(row[1] for row in count_results if row[0] in room_ids)
|
||||
|
||||
return {
|
||||
"results": [
|
||||
{
|
||||
"event": event_map[r["event_id"]],
|
||||
"rank": r["rank"],
|
||||
"pagination_token": "%s,%s"
|
||||
% (r["origin_server_ts"], r["stream_ordering"]),
|
||||
"event": event_map[r[2]],
|
||||
"rank": r[0],
|
||||
"pagination_token": "%s,%s" % (r[3], r[4]),
|
||||
}
|
||||
for r in results
|
||||
if r["event_id"] in event_map
|
||||
if r[2] in event_map
|
||||
],
|
||||
"highlights": highlights,
|
||||
"count": count,
|
||||
|
@ -679,7 +679,7 @@ class StatsStore(StateDeltasStore):
|
||||
order_by: Optional[str] = UserSortOrder.USER_ID.value,
|
||||
direction: Direction = Direction.FORWARDS,
|
||||
search_term: Optional[str] = None,
|
||||
) -> Tuple[List[JsonDict], int]:
|
||||
) -> Tuple[List[Tuple[str, Optional[str], int, int]], int]:
|
||||
"""Function to retrieve a paginated list of users and their uploaded local media
|
||||
(size and number). This will return a json list of users and the
|
||||
total number of users matching the filter criteria.
|
||||
@ -692,14 +692,19 @@ class StatsStore(StateDeltasStore):
|
||||
order_by: the sort order of the returned list
|
||||
direction: sort ascending or descending
|
||||
search_term: a string to filter user names by
|
||||
|
||||
Returns:
|
||||
A list of user dicts and an integer representing the total number of
|
||||
users that exist given this query
|
||||
A tuple of:
|
||||
A list of tuples of user information (the user ID, displayname,
|
||||
total number of media, total length of media) and
|
||||
|
||||
An integer representing the total number of users that exist
|
||||
given this query
|
||||
"""
|
||||
|
||||
def get_users_media_usage_paginate_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Tuple[List[JsonDict], int]:
|
||||
) -> Tuple[List[Tuple[str, Optional[str], int, int]], int]:
|
||||
filters = []
|
||||
args: list = []
|
||||
|
||||
@ -773,7 +778,7 @@ class StatsStore(StateDeltasStore):
|
||||
|
||||
args += [limit, start]
|
||||
txn.execute(sql, args)
|
||||
users = self.db_pool.cursor_to_dict(txn)
|
||||
users = cast(List[Tuple[str, Optional[str], int, int]], txn.fetchall())
|
||||
|
||||
return users, count
|
||||
|
||||
|
@ -1078,7 +1078,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
"""
|
||||
|
||||
row = await self.db_pool.execute(
|
||||
"get_current_topological_token", None, sql, room_id, room_id, stream_key
|
||||
"get_current_topological_token", sql, room_id, room_id, stream_key
|
||||
)
|
||||
return row[0][0] if row else 0
|
||||
|
||||
@ -1636,7 +1636,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
|
||||
rows = await self.db_pool.execute(
|
||||
"get_timeline_gaps",
|
||||
None,
|
||||
sql,
|
||||
room_id,
|
||||
from_token.stream if from_token else 0,
|
||||
|
@ -478,7 +478,10 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
|
||||
destination: Optional[str] = None,
|
||||
order_by: str = DestinationSortOrder.DESTINATION.value,
|
||||
direction: Direction = Direction.FORWARDS,
|
||||
) -> Tuple[List[JsonDict], int]:
|
||||
) -> Tuple[
|
||||
List[Tuple[str, Optional[int], Optional[int], Optional[int], Optional[int]]],
|
||||
int,
|
||||
]:
|
||||
"""Function to retrieve a paginated list of destinations.
|
||||
This will return a json list of destinations and the
|
||||
total number of destinations matching the filter criteria.
|
||||
@ -490,13 +493,23 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
|
||||
order_by: the sort order of the returned list
|
||||
direction: sort ascending or descending
|
||||
Returns:
|
||||
A tuple of a list of mappings from destination to information
|
||||
A tuple of a list of tuples of destination information:
|
||||
* destination
|
||||
* retry_last_ts
|
||||
* retry_interval
|
||||
* failure_ts
|
||||
* last_successful_stream_ordering
|
||||
and a count of total destinations.
|
||||
"""
|
||||
|
||||
def get_destinations_paginate_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Tuple[List[JsonDict], int]:
|
||||
) -> Tuple[
|
||||
List[
|
||||
Tuple[str, Optional[int], Optional[int], Optional[int], Optional[int]]
|
||||
],
|
||||
int,
|
||||
]:
|
||||
order_by_column = DestinationSortOrder(order_by).value
|
||||
|
||||
if direction == Direction.BACKWARDS:
|
||||
@ -523,7 +536,14 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
|
||||
LIMIT ? OFFSET ?
|
||||
"""
|
||||
txn.execute(sql, args + [limit, start])
|
||||
destinations = self.db_pool.cursor_to_dict(txn)
|
||||
destinations = cast(
|
||||
List[
|
||||
Tuple[
|
||||
str, Optional[int], Optional[int], Optional[int], Optional[int]
|
||||
]
|
||||
],
|
||||
txn.fetchall(),
|
||||
)
|
||||
return destinations, count
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
|
@ -1145,15 +1145,19 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
|
||||
raise Exception("Unrecognized database engine")
|
||||
|
||||
results = cast(
|
||||
List[UserProfile],
|
||||
await self.db_pool.execute(
|
||||
"search_user_dir", self.db_pool.cursor_to_dict, sql, *args
|
||||
),
|
||||
List[Tuple[str, Optional[str], Optional[str]]],
|
||||
await self.db_pool.execute("search_user_dir", sql, *args),
|
||||
)
|
||||
|
||||
limited = len(results) > limit
|
||||
|
||||
return {"limited": limited, "results": results[0:limit]}
|
||||
return {
|
||||
"limited": limited,
|
||||
"results": [
|
||||
{"user_id": r[0], "display_name": r[1], "avatar_url": r[2]}
|
||||
for r in results[0:limit]
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def _filter_text_for_index(text: str) -> str:
|
||||
|
@ -359,7 +359,6 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
|
||||
if max_group is None:
|
||||
rows = await self.db_pool.execute(
|
||||
"_background_deduplicate_state",
|
||||
None,
|
||||
"SELECT coalesce(max(id), 0) FROM state_groups",
|
||||
)
|
||||
max_group = rows[0][0]
|
||||
|
@ -100,7 +100,6 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
|
||||
event_id, stream_ordering = self.get_success(
|
||||
self.hs.get_datastores().main.db_pool.execute(
|
||||
"test:get_destination_rooms",
|
||||
None,
|
||||
"""
|
||||
SELECT event_id, stream_ordering
|
||||
FROM destination_rooms dr
|
||||
|
@ -457,8 +457,8 @@ class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase):
|
||||
);
|
||||
"""
|
||||
self.get_success(
|
||||
self.store.db_pool.execute(
|
||||
"test_not_null_constraint", lambda _: None, table_sql
|
||||
self.store.db_pool.runInteraction(
|
||||
"test_not_null_constraint", lambda txn: txn.execute(table_sql)
|
||||
)
|
||||
)
|
||||
|
||||
@ -466,8 +466,8 @@ class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase):
|
||||
# using SQLite.
|
||||
index_sql = "CREATE INDEX test_index ON test_constraint(a)"
|
||||
self.get_success(
|
||||
self.store.db_pool.execute(
|
||||
"test_not_null_constraint", lambda _: None, index_sql
|
||||
self.store.db_pool.runInteraction(
|
||||
"test_not_null_constraint", lambda txn: txn.execute(index_sql)
|
||||
)
|
||||
)
|
||||
|
||||
@ -574,13 +574,13 @@ class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase):
|
||||
);
|
||||
"""
|
||||
self.get_success(
|
||||
self.store.db_pool.execute(
|
||||
"test_foreign_key_constraint", lambda _: None, base_sql
|
||||
self.store.db_pool.runInteraction(
|
||||
"test_foreign_key_constraint", lambda txn: txn.execute(base_sql)
|
||||
)
|
||||
)
|
||||
self.get_success(
|
||||
self.store.db_pool.execute(
|
||||
"test_foreign_key_constraint", lambda _: None, table_sql
|
||||
self.store.db_pool.runInteraction(
|
||||
"test_foreign_key_constraint", lambda txn: txn.execute(table_sql)
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -120,7 +120,7 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
res = self.get_success(
|
||||
self.store.db_pool.execute(
|
||||
"", None, "SELECT full_user_id from profiles ORDER BY full_user_id"
|
||||
"", "SELECT full_user_id from profiles ORDER BY full_user_id"
|
||||
)
|
||||
)
|
||||
self.assertEqual(len(res), len(expected_values))
|
||||
|
@ -87,7 +87,7 @@ class UserFiltersStoreTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
res = self.get_success(
|
||||
self.store.db_pool.execute(
|
||||
"", None, "SELECT full_user_id from user_filters ORDER BY full_user_id"
|
||||
"", "SELECT full_user_id from user_filters ORDER BY full_user_id"
|
||||
)
|
||||
)
|
||||
self.assertEqual(len(res), len(expected_values))
|
||||
|
Loading…
Reference in New Issue
Block a user