Add type hints to synapse/storage/databases/main/room.py (#11575)

This commit is contained in:
Sean Quah 2021-12-15 18:00:48 +00:00 committed by GitHub
parent f901f8b70e
commit c7fe32edb4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 108 additions and 77 deletions

1
changelog.d/11575.misc Normal file
View File

@ -0,0 +1 @@
Add missing type hints to storage classes.

View File

@ -37,7 +37,6 @@ exclude = (?x)
|synapse/storage/databases/main/purge_events.py
|synapse/storage/databases/main/push_rule.py
|synapse/storage/databases/main/receipts.py
|synapse/storage/databases/main/room.py
|synapse/storage/databases/main/roommember.py
|synapse/storage/databases/main/search.py
|synapse/storage/databases/main/state.py
@ -205,6 +204,9 @@ disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.events_worker]
disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.room]
disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.room_batch]
disallow_untyped_defs = True

View File

@ -1020,7 +1020,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# Add new room to the room directory if the old room was there
# Remove old room from the room directory
old_room = await self.store.get_room(old_room_id)
if old_room and old_room["is_public"]:
if old_room is not None and old_room["is_public"]:
await self.store.set_room_is_public(old_room_id, False)
await self.store.set_room_is_public(room_id, True)
@ -1031,7 +1031,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
local_group_ids = await self.store.get_local_groups_for_room(old_room_id)
for group_id in local_group_ids:
# Add new the new room to those groups
await self.store.add_room_to_group(group_id, room_id, old_room["is_public"])
await self.store.add_room_to_group(
group_id, room_id, old_room is not None and old_room["is_public"]
)
# Remove the old room from those groups
await self.store.remove_room_from_group(group_id, old_room_id)

View File

@ -149,7 +149,6 @@ class DataStore(
],
)
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
self._group_updates_id_gen = StreamIdGenerator(

View File

@ -17,7 +17,7 @@ import collections
import logging
from abc import abstractmethod
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple, cast
from synapse.api.constants import EventContentFields, EventTypes, JoinRules
from synapse.api.errors import StoreError
@ -29,8 +29,9 @@ from synapse.storage.database import (
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main.search import SearchStore
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import IdGenerator
from synapse.types import JsonDict, ThirdPartyInstanceID
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
@ -75,7 +76,7 @@ class RoomSortOrder(Enum):
STATE_EVENTS = "state_events"
class RoomWorkerStore(SQLBaseStore):
class RoomWorkerStore(CacheInvalidationWorkerStore):
def __init__(
self,
database: DatabasePool,
@ -92,7 +93,7 @@ class RoomWorkerStore(SQLBaseStore):
room_creator_user_id: str,
is_public: bool,
room_version: RoomVersion,
):
) -> None:
"""Stores a room.
Args:
@ -120,7 +121,7 @@ class RoomWorkerStore(SQLBaseStore):
logger.error("store_room with room_id=%s failed: %s", room_id, e)
raise StoreError(500, "Problem creating room.")
async def get_room(self, room_id: str) -> dict:
async def get_room(self, room_id: str) -> Optional[Dict[str, Any]]:
"""Retrieve a room.
Args:
@ -145,7 +146,9 @@ class RoomWorkerStore(SQLBaseStore):
A dict containing the room information, or None if the room is unknown.
"""
def get_room_with_stats_txn(txn, room_id):
def get_room_with_stats_txn(
txn: LoggingTransaction, room_id: str
) -> Optional[Dict[str, Any]]:
sql = """
SELECT room_id, state.name, state.canonical_alias, curr.joined_members,
curr.local_users_in_room AS joined_local_members, rooms.room_version AS version,
@ -194,7 +197,7 @@ class RoomWorkerStore(SQLBaseStore):
ignore_non_federatable: If true filters out non-federatable rooms
"""
def _count_public_rooms_txn(txn):
def _count_public_rooms_txn(txn: LoggingTransaction) -> int:
query_args = []
if network_tuple:
@ -235,7 +238,7 @@ class RoomWorkerStore(SQLBaseStore):
}
txn.execute(sql, query_args)
return txn.fetchone()[0]
return cast(Tuple[int], txn.fetchone())[0]
return await self.db_pool.runInteraction(
"count_public_rooms", _count_public_rooms_txn
@ -244,11 +247,11 @@ class RoomWorkerStore(SQLBaseStore):
async def get_room_count(self) -> int:
"""Retrieve the total number of rooms."""
def f(txn):
def f(txn: LoggingTransaction) -> int:
sql = "SELECT count(*) FROM rooms"
txn.execute(sql)
row = txn.fetchone()
return row[0] or 0
row = cast(Tuple[int], txn.fetchone())
return row[0]
return await self.db_pool.runInteraction("get_rooms", f)
@ -260,7 +263,7 @@ class RoomWorkerStore(SQLBaseStore):
bounds: Optional[Tuple[int, str]],
forwards: bool,
ignore_non_federatable: bool = False,
):
) -> List[Dict[str, Any]]:
"""Gets the largest public rooms (where largest is in terms of joined
members, as tracked in the statistics table).
@ -381,7 +384,9 @@ class RoomWorkerStore(SQLBaseStore):
LIMIT ?
"""
def _get_largest_public_rooms_txn(txn):
def _get_largest_public_rooms_txn(
txn: LoggingTransaction,
) -> List[Dict[str, Any]]:
txn.execute(sql, query_args)
results = self.db_pool.cursor_to_dict(txn)
@ -444,7 +449,7 @@ class RoomWorkerStore(SQLBaseStore):
"""
# Filter room names by a string
where_statement = ""
search_pattern = []
search_pattern: List[object] = []
if search_term:
where_statement = """
WHERE LOWER(state.name) LIKE ?
@ -552,7 +557,9 @@ class RoomWorkerStore(SQLBaseStore):
where_statement,
)
def _get_rooms_paginate_txn(txn):
def _get_rooms_paginate_txn(
txn: LoggingTransaction,
) -> Tuple[List[Dict[str, Any]], int]:
# Add the search term into the WHERE clause
# and execute the data query
txn.execute(info_sql, search_pattern + [limit, start])
@ -584,7 +591,7 @@ class RoomWorkerStore(SQLBaseStore):
# Add the search term into the WHERE clause if present
txn.execute(count_sql, search_pattern)
room_count = txn.fetchone()
room_count = cast(Tuple[int], txn.fetchone())
return rooms, room_count[0]
return await self.db_pool.runInteraction(
@ -629,7 +636,7 @@ class RoomWorkerStore(SQLBaseStore):
burst_count: How many actions that can be performed before being limited.
"""
def set_ratelimit_txn(txn):
def set_ratelimit_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_upsert_txn(
txn,
table="ratelimit_override",
@ -652,7 +659,7 @@ class RoomWorkerStore(SQLBaseStore):
user_id: user ID of the user
"""
def delete_ratelimit_txn(txn):
def delete_ratelimit_txn(txn: LoggingTransaction) -> None:
row = self.db_pool.simple_select_one_txn(
txn,
table="ratelimit_override",
@ -676,7 +683,7 @@ class RoomWorkerStore(SQLBaseStore):
await self.db_pool.runInteraction("delete_ratelimit", delete_ratelimit_txn)
@cached()
async def get_retention_policy_for_room(self, room_id):
async def get_retention_policy_for_room(self, room_id: str) -> Dict[str, int]:
"""Get the retention policy for a given room.
If no retention policy has been found for this room, returns a policy defined
@ -685,13 +692,15 @@ class RoomWorkerStore(SQLBaseStore):
configuration).
Args:
room_id (str): The ID of the room to get the retention policy of.
room_id: The ID of the room to get the retention policy of.
Returns:
dict[int, int]: "min_lifetime" and "max_lifetime" for this room.
A dict containing "min_lifetime" and "max_lifetime" for this room.
"""
def get_retention_policy_for_room_txn(txn):
def get_retention_policy_for_room_txn(
txn: LoggingTransaction,
) -> List[Dict[str, Optional[int]]]:
txn.execute(
"""
SELECT min_lifetime, max_lifetime FROM room_retention
@ -716,19 +725,23 @@ class RoomWorkerStore(SQLBaseStore):
"max_lifetime": self.config.retention.retention_default_max_lifetime,
}
row = ret[0]
min_lifetime = ret[0]["min_lifetime"]
max_lifetime = ret[0]["max_lifetime"]
# If one of the room's policy's attributes isn't defined, use the matching
# attribute from the default policy.
# The default values will be None if no default policy has been defined, or if one
# of the attributes is missing from the default policy.
if row["min_lifetime"] is None:
row["min_lifetime"] = self.config.retention.retention_default_min_lifetime
if min_lifetime is None:
min_lifetime = self.config.retention.retention_default_min_lifetime
if row["max_lifetime"] is None:
row["max_lifetime"] = self.config.retention.retention_default_max_lifetime
if max_lifetime is None:
max_lifetime = self.config.retention.retention_default_max_lifetime
return row
return {
"min_lifetime": min_lifetime,
"max_lifetime": max_lifetime,
}
async def get_media_mxcs_in_room(self, room_id: str) -> Tuple[List[str], List[str]]:
"""Retrieves all the local and remote media MXC URIs in a given room
@ -740,7 +753,9 @@ class RoomWorkerStore(SQLBaseStore):
The local and remote media as a lists of the media IDs.
"""
def _get_media_mxcs_in_room_txn(txn):
def _get_media_mxcs_in_room_txn(
txn: LoggingTransaction,
) -> Tuple[List[str], List[str]]:
local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
local_media_mxcs = []
remote_media_mxcs = []
@ -766,7 +781,7 @@ class RoomWorkerStore(SQLBaseStore):
logger.info("Quarantining media in room: %s", room_id)
def _quarantine_media_in_room_txn(txn):
def _quarantine_media_in_room_txn(txn: LoggingTransaction) -> int:
local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
return self._quarantine_media_txn(
txn, local_mxcs, remote_mxcs, quarantined_by
@ -776,13 +791,11 @@ class RoomWorkerStore(SQLBaseStore):
"quarantine_media_in_room", _quarantine_media_in_room_txn
)
def _get_media_mxcs_in_room_txn(self, txn, room_id):
def _get_media_mxcs_in_room_txn(
self, txn: LoggingTransaction, room_id: str
) -> Tuple[List[str], List[Tuple[str, str]]]:
"""Retrieves all the local and remote media MXC URIs in a given room
Args:
txn (cursor)
room_id (str)
Returns:
The local and remote media as a lists of tuples where the key is
the hostname and the value is the media ID.
@ -850,7 +863,7 @@ class RoomWorkerStore(SQLBaseStore):
logger.info("Quarantining media: %s/%s", server_name, media_id)
is_local = server_name == self.config.server.server_name
def _quarantine_media_by_id_txn(txn):
def _quarantine_media_by_id_txn(txn: LoggingTransaction) -> int:
local_mxcs = [media_id] if is_local else []
remote_mxcs = [(server_name, media_id)] if not is_local else []
@ -872,7 +885,7 @@ class RoomWorkerStore(SQLBaseStore):
quarantined_by: The ID of the user who made the quarantine request
"""
def _quarantine_media_by_user_txn(txn):
def _quarantine_media_by_user_txn(txn: LoggingTransaction) -> int:
local_media_ids = self._get_media_ids_by_user_txn(txn, user_id)
return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by)
@ -880,7 +893,9 @@ class RoomWorkerStore(SQLBaseStore):
"quarantine_media_by_user", _quarantine_media_by_user_txn
)
def _get_media_ids_by_user_txn(self, txn, user_id: str, filter_quarantined=True):
def _get_media_ids_by_user_txn(
self, txn: LoggingTransaction, user_id: str, filter_quarantined: bool = True
) -> List[str]:
"""Retrieves local media IDs by a given user
Args:
@ -909,7 +924,7 @@ class RoomWorkerStore(SQLBaseStore):
def _quarantine_media_txn(
self,
txn,
txn: LoggingTransaction,
local_mxcs: List[str],
remote_mxcs: List[Tuple[str, str]],
quarantined_by: Optional[str],
@ -937,12 +952,15 @@ class RoomWorkerStore(SQLBaseStore):
# set quarantine
if quarantined_by is not None:
sql += "AND safe_from_quarantine = ?"
rows = [(quarantined_by, media_id, False) for media_id in local_mxcs]
txn.executemany(
sql, [(quarantined_by, media_id, False) for media_id in local_mxcs]
)
# remove from quarantine
else:
rows = [(quarantined_by, media_id) for media_id in local_mxcs]
txn.executemany(
sql, [(quarantined_by, media_id) for media_id in local_mxcs]
)
txn.executemany(sql, rows)
# Note that a rowcount of -1 can be used to indicate no rows were affected.
total_media_quarantined = txn.rowcount if txn.rowcount > 0 else 0
@ -960,7 +978,7 @@ class RoomWorkerStore(SQLBaseStore):
async def get_rooms_for_retention_period_in_range(
self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False
) -> Dict[str, dict]:
) -> Dict[str, Dict[str, Optional[int]]]:
"""Retrieves all of the rooms within the given retention range.
Optionally includes the rooms which don't have a retention policy.
@ -980,7 +998,9 @@ class RoomWorkerStore(SQLBaseStore):
"min_lifetime" (int|None), and "max_lifetime" (int|None).
"""
def get_rooms_for_retention_period_in_range_txn(txn):
def get_rooms_for_retention_period_in_range_txn(
txn: LoggingTransaction,
) -> Dict[str, Dict[str, Optional[int]]]:
range_conditions = []
args = []
@ -1067,8 +1087,6 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
):
super().__init__(database, db_conn, hs)
self.config = hs.config
self.db_pool.updates.register_background_update_handler(
"insert_room_retention",
self._background_insert_retention,
@ -1099,7 +1117,9 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
self._background_populate_rooms_creator_column,
)
async def _background_insert_retention(self, progress, batch_size):
async def _background_insert_retention(
self, progress: JsonDict, batch_size: int
) -> int:
"""Retrieves a list of all rooms within a range and inserts an entry for each of
them into the room_retention table.
NULLs the property's columns if missing from the retention event in the room's
@ -1109,7 +1129,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
last_room = progress.get("room_id", "")
def _background_insert_retention_txn(txn):
def _background_insert_retention_txn(txn: LoggingTransaction) -> bool:
txn.execute(
"""
SELECT state.room_id, state.event_id, events.json
@ -1168,15 +1188,17 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
return batch_size
async def _background_add_rooms_room_version_column(
self, progress: dict, batch_size: int
):
self, progress: JsonDict, batch_size: int
) -> int:
"""Background update to go and add room version information to `rooms`
table from `current_state_events` table.
"""
last_room_id = progress.get("room_id", "")
def _background_add_rooms_room_version_column_txn(txn: LoggingTransaction):
def _background_add_rooms_room_version_column_txn(
txn: LoggingTransaction,
) -> bool:
sql = """
SELECT room_id, json FROM current_state_events
INNER JOIN event_json USING (room_id, event_id)
@ -1237,7 +1259,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
return batch_size
async def _remove_tombstoned_rooms_from_directory(
self, progress, batch_size
self, progress: JsonDict, batch_size: int
) -> int:
"""Removes any rooms with tombstone events from the room directory
@ -1247,7 +1269,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
last_room = progress.get("room_id", "")
def _get_rooms(txn):
def _get_rooms(txn: LoggingTransaction) -> List[str]:
txn.execute(
"""
SELECT room_id
@ -1285,7 +1307,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
return len(rooms)
@abstractmethod
def set_room_is_public(self, room_id, is_public):
def set_room_is_public(self, room_id: str, is_public: bool) -> Awaitable[None]:
# this will need to be implemented if a background update is performed with
# existing (tombstoned, public) rooms in the database.
#
@ -1332,7 +1354,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
32-bit integer field.
"""
def process(txn: Cursor) -> int:
def process(txn: LoggingTransaction) -> int:
last_room = progress.get("last_room", "")
txn.execute(
"""
@ -1389,15 +1411,17 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
return 0
async def _background_populate_rooms_creator_column(
self, progress: dict, batch_size: int
):
self, progress: JsonDict, batch_size: int
) -> int:
"""Background update to go and add creator information to `rooms`
table from `current_state_events` table.
"""
last_room_id = progress.get("room_id", "")
def _background_populate_rooms_creator_column_txn(txn: LoggingTransaction):
def _background_populate_rooms_creator_column_txn(
txn: LoggingTransaction,
) -> bool:
sql = """
SELECT room_id, json FROM event_json
INNER JOIN rooms AS room USING (room_id)
@ -1448,7 +1472,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
return batch_size
class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
def __init__(
self,
database: DatabasePool,
@ -1457,11 +1481,11 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
):
super().__init__(database, db_conn, hs)
self.config = hs.config
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
async def upsert_room_on_join(
self, room_id: str, room_version: RoomVersion, auth_events: List[EventBase]
):
) -> None:
"""Ensure that the room is stored in the table
Called when we join a room over federation, and overwrites any room version
@ -1507,7 +1531,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
async def maybe_store_room_on_outlier_membership(
self, room_id: str, room_version: RoomVersion
):
) -> None:
"""
When we receive an invite or any other event over federation that may relate to a room
we are not in, store the version of the room if we don't already know the room version.
@ -1547,8 +1571,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
self.hs.get_notifier().on_new_replication_data()
async def set_room_is_public_appservice(
self, room_id, appservice_id, network_id, is_public
):
self, room_id: str, appservice_id: str, network_id: str, is_public: bool
) -> None:
"""Edit the appservice/network specific public room list.
Each appservice can have a number of published room lists associated
@ -1557,11 +1581,10 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
network.
Args:
room_id (str)
appservice_id (str)
network_id (str)
is_public (bool): Whether to publish or unpublish the room from the
list.
room_id
appservice_id
network_id
is_public: Whether to publish or unpublish the room from the list.
"""
if is_public:
@ -1626,7 +1649,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
event_report: json list of information from event report
"""
def _get_event_report_txn(txn, report_id):
def _get_event_report_txn(
txn: LoggingTransaction, report_id: int
) -> Optional[Dict[str, Any]]:
sql = """
SELECT
@ -1698,9 +1723,11 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
count: total number of event reports matching the filter criteria
"""
def _get_event_reports_paginate_txn(txn):
def _get_event_reports_paginate_txn(
txn: LoggingTransaction,
) -> Tuple[List[Dict[str, Any]], int]:
filters = []
args = []
args: List[object] = []
if user_id:
filters.append("er.user_id LIKE ?")
@ -1724,7 +1751,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
where_clause
)
txn.execute(sql, args)
count = txn.fetchone()[0]
count = cast(Tuple[int], txn.fetchone())[0]
sql = """
SELECT