mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Add type hints to synapse/storage/databases/main/room.py
(#11575)
This commit is contained in:
parent
f901f8b70e
commit
c7fe32edb4
1
changelog.d/11575.misc
Normal file
1
changelog.d/11575.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Add missing type hints to storage classes.
|
4
mypy.ini
4
mypy.ini
@ -37,7 +37,6 @@ exclude = (?x)
|
|||||||
|synapse/storage/databases/main/purge_events.py
|
|synapse/storage/databases/main/purge_events.py
|
||||||
|synapse/storage/databases/main/push_rule.py
|
|synapse/storage/databases/main/push_rule.py
|
||||||
|synapse/storage/databases/main/receipts.py
|
|synapse/storage/databases/main/receipts.py
|
||||||
|synapse/storage/databases/main/room.py
|
|
||||||
|synapse/storage/databases/main/roommember.py
|
|synapse/storage/databases/main/roommember.py
|
||||||
|synapse/storage/databases/main/search.py
|
|synapse/storage/databases/main/search.py
|
||||||
|synapse/storage/databases/main/state.py
|
|synapse/storage/databases/main/state.py
|
||||||
@ -205,6 +204,9 @@ disallow_untyped_defs = True
|
|||||||
[mypy-synapse.storage.databases.main.events_worker]
|
[mypy-synapse.storage.databases.main.events_worker]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
[mypy-synapse.storage.databases.main.room]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
[mypy-synapse.storage.databases.main.room_batch]
|
[mypy-synapse.storage.databases.main.room_batch]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
@ -1020,7 +1020,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
|||||||
# Add new room to the room directory if the old room was there
|
# Add new room to the room directory if the old room was there
|
||||||
# Remove old room from the room directory
|
# Remove old room from the room directory
|
||||||
old_room = await self.store.get_room(old_room_id)
|
old_room = await self.store.get_room(old_room_id)
|
||||||
if old_room and old_room["is_public"]:
|
if old_room is not None and old_room["is_public"]:
|
||||||
await self.store.set_room_is_public(old_room_id, False)
|
await self.store.set_room_is_public(old_room_id, False)
|
||||||
await self.store.set_room_is_public(room_id, True)
|
await self.store.set_room_is_public(room_id, True)
|
||||||
|
|
||||||
@ -1031,7 +1031,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
|||||||
local_group_ids = await self.store.get_local_groups_for_room(old_room_id)
|
local_group_ids = await self.store.get_local_groups_for_room(old_room_id)
|
||||||
for group_id in local_group_ids:
|
for group_id in local_group_ids:
|
||||||
# Add new the new room to those groups
|
# Add new the new room to those groups
|
||||||
await self.store.add_room_to_group(group_id, room_id, old_room["is_public"])
|
await self.store.add_room_to_group(
|
||||||
|
group_id, room_id, old_room is not None and old_room["is_public"]
|
||||||
|
)
|
||||||
|
|
||||||
# Remove the old room from those groups
|
# Remove the old room from those groups
|
||||||
await self.store.remove_room_from_group(group_id, old_room_id)
|
await self.store.remove_room_from_group(group_id, old_room_id)
|
||||||
|
@ -149,7 +149,6 @@ class DataStore(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
|
|
||||||
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
|
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
|
||||||
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
|
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
|
||||||
self._group_updates_id_gen = StreamIdGenerator(
|
self._group_updates_id_gen = StreamIdGenerator(
|
||||||
|
@ -17,7 +17,7 @@ import collections
|
|||||||
import logging
|
import logging
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple, cast
|
||||||
|
|
||||||
from synapse.api.constants import EventContentFields, EventTypes, JoinRules
|
from synapse.api.constants import EventContentFields, EventTypes, JoinRules
|
||||||
from synapse.api.errors import StoreError
|
from synapse.api.errors import StoreError
|
||||||
@ -29,8 +29,9 @@ from synapse.storage.database import (
|
|||||||
LoggingDatabaseConnection,
|
LoggingDatabaseConnection,
|
||||||
LoggingTransaction,
|
LoggingTransaction,
|
||||||
)
|
)
|
||||||
from synapse.storage.databases.main.search import SearchStore
|
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
|
||||||
from synapse.storage.types import Cursor
|
from synapse.storage.types import Cursor
|
||||||
|
from synapse.storage.util.id_generators import IdGenerator
|
||||||
from synapse.types import JsonDict, ThirdPartyInstanceID
|
from synapse.types import JsonDict, ThirdPartyInstanceID
|
||||||
from synapse.util import json_encoder
|
from synapse.util import json_encoder
|
||||||
from synapse.util.caches.descriptors import cached
|
from synapse.util.caches.descriptors import cached
|
||||||
@ -75,7 +76,7 @@ class RoomSortOrder(Enum):
|
|||||||
STATE_EVENTS = "state_events"
|
STATE_EVENTS = "state_events"
|
||||||
|
|
||||||
|
|
||||||
class RoomWorkerStore(SQLBaseStore):
|
class RoomWorkerStore(CacheInvalidationWorkerStore):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
database: DatabasePool,
|
database: DatabasePool,
|
||||||
@ -92,7 +93,7 @@ class RoomWorkerStore(SQLBaseStore):
|
|||||||
room_creator_user_id: str,
|
room_creator_user_id: str,
|
||||||
is_public: bool,
|
is_public: bool,
|
||||||
room_version: RoomVersion,
|
room_version: RoomVersion,
|
||||||
):
|
) -> None:
|
||||||
"""Stores a room.
|
"""Stores a room.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -120,7 +121,7 @@ class RoomWorkerStore(SQLBaseStore):
|
|||||||
logger.error("store_room with room_id=%s failed: %s", room_id, e)
|
logger.error("store_room with room_id=%s failed: %s", room_id, e)
|
||||||
raise StoreError(500, "Problem creating room.")
|
raise StoreError(500, "Problem creating room.")
|
||||||
|
|
||||||
async def get_room(self, room_id: str) -> dict:
|
async def get_room(self, room_id: str) -> Optional[Dict[str, Any]]:
|
||||||
"""Retrieve a room.
|
"""Retrieve a room.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -145,7 +146,9 @@ class RoomWorkerStore(SQLBaseStore):
|
|||||||
A dict containing the room information, or None if the room is unknown.
|
A dict containing the room information, or None if the room is unknown.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get_room_with_stats_txn(txn, room_id):
|
def get_room_with_stats_txn(
|
||||||
|
txn: LoggingTransaction, room_id: str
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
sql = """
|
sql = """
|
||||||
SELECT room_id, state.name, state.canonical_alias, curr.joined_members,
|
SELECT room_id, state.name, state.canonical_alias, curr.joined_members,
|
||||||
curr.local_users_in_room AS joined_local_members, rooms.room_version AS version,
|
curr.local_users_in_room AS joined_local_members, rooms.room_version AS version,
|
||||||
@ -194,7 +197,7 @@ class RoomWorkerStore(SQLBaseStore):
|
|||||||
ignore_non_federatable: If true filters out non-federatable rooms
|
ignore_non_federatable: If true filters out non-federatable rooms
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _count_public_rooms_txn(txn):
|
def _count_public_rooms_txn(txn: LoggingTransaction) -> int:
|
||||||
query_args = []
|
query_args = []
|
||||||
|
|
||||||
if network_tuple:
|
if network_tuple:
|
||||||
@ -235,7 +238,7 @@ class RoomWorkerStore(SQLBaseStore):
|
|||||||
}
|
}
|
||||||
|
|
||||||
txn.execute(sql, query_args)
|
txn.execute(sql, query_args)
|
||||||
return txn.fetchone()[0]
|
return cast(Tuple[int], txn.fetchone())[0]
|
||||||
|
|
||||||
return await self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"count_public_rooms", _count_public_rooms_txn
|
"count_public_rooms", _count_public_rooms_txn
|
||||||
@ -244,11 +247,11 @@ class RoomWorkerStore(SQLBaseStore):
|
|||||||
async def get_room_count(self) -> int:
|
async def get_room_count(self) -> int:
|
||||||
"""Retrieve the total number of rooms."""
|
"""Retrieve the total number of rooms."""
|
||||||
|
|
||||||
def f(txn):
|
def f(txn: LoggingTransaction) -> int:
|
||||||
sql = "SELECT count(*) FROM rooms"
|
sql = "SELECT count(*) FROM rooms"
|
||||||
txn.execute(sql)
|
txn.execute(sql)
|
||||||
row = txn.fetchone()
|
row = cast(Tuple[int], txn.fetchone())
|
||||||
return row[0] or 0
|
return row[0]
|
||||||
|
|
||||||
return await self.db_pool.runInteraction("get_rooms", f)
|
return await self.db_pool.runInteraction("get_rooms", f)
|
||||||
|
|
||||||
@ -260,7 +263,7 @@ class RoomWorkerStore(SQLBaseStore):
|
|||||||
bounds: Optional[Tuple[int, str]],
|
bounds: Optional[Tuple[int, str]],
|
||||||
forwards: bool,
|
forwards: bool,
|
||||||
ignore_non_federatable: bool = False,
|
ignore_non_federatable: bool = False,
|
||||||
):
|
) -> List[Dict[str, Any]]:
|
||||||
"""Gets the largest public rooms (where largest is in terms of joined
|
"""Gets the largest public rooms (where largest is in terms of joined
|
||||||
members, as tracked in the statistics table).
|
members, as tracked in the statistics table).
|
||||||
|
|
||||||
@ -381,7 +384,9 @@ class RoomWorkerStore(SQLBaseStore):
|
|||||||
LIMIT ?
|
LIMIT ?
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _get_largest_public_rooms_txn(txn):
|
def _get_largest_public_rooms_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
txn.execute(sql, query_args)
|
txn.execute(sql, query_args)
|
||||||
|
|
||||||
results = self.db_pool.cursor_to_dict(txn)
|
results = self.db_pool.cursor_to_dict(txn)
|
||||||
@ -444,7 +449,7 @@ class RoomWorkerStore(SQLBaseStore):
|
|||||||
"""
|
"""
|
||||||
# Filter room names by a string
|
# Filter room names by a string
|
||||||
where_statement = ""
|
where_statement = ""
|
||||||
search_pattern = []
|
search_pattern: List[object] = []
|
||||||
if search_term:
|
if search_term:
|
||||||
where_statement = """
|
where_statement = """
|
||||||
WHERE LOWER(state.name) LIKE ?
|
WHERE LOWER(state.name) LIKE ?
|
||||||
@ -552,7 +557,9 @@ class RoomWorkerStore(SQLBaseStore):
|
|||||||
where_statement,
|
where_statement,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_rooms_paginate_txn(txn):
|
def _get_rooms_paginate_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> Tuple[List[Dict[str, Any]], int]:
|
||||||
# Add the search term into the WHERE clause
|
# Add the search term into the WHERE clause
|
||||||
# and execute the data query
|
# and execute the data query
|
||||||
txn.execute(info_sql, search_pattern + [limit, start])
|
txn.execute(info_sql, search_pattern + [limit, start])
|
||||||
@ -584,7 +591,7 @@ class RoomWorkerStore(SQLBaseStore):
|
|||||||
# Add the search term into the WHERE clause if present
|
# Add the search term into the WHERE clause if present
|
||||||
txn.execute(count_sql, search_pattern)
|
txn.execute(count_sql, search_pattern)
|
||||||
|
|
||||||
room_count = txn.fetchone()
|
room_count = cast(Tuple[int], txn.fetchone())
|
||||||
return rooms, room_count[0]
|
return rooms, room_count[0]
|
||||||
|
|
||||||
return await self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
@ -629,7 +636,7 @@ class RoomWorkerStore(SQLBaseStore):
|
|||||||
burst_count: How many actions that can be performed before being limited.
|
burst_count: How many actions that can be performed before being limited.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def set_ratelimit_txn(txn):
|
def set_ratelimit_txn(txn: LoggingTransaction) -> None:
|
||||||
self.db_pool.simple_upsert_txn(
|
self.db_pool.simple_upsert_txn(
|
||||||
txn,
|
txn,
|
||||||
table="ratelimit_override",
|
table="ratelimit_override",
|
||||||
@ -652,7 +659,7 @@ class RoomWorkerStore(SQLBaseStore):
|
|||||||
user_id: user ID of the user
|
user_id: user ID of the user
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def delete_ratelimit_txn(txn):
|
def delete_ratelimit_txn(txn: LoggingTransaction) -> None:
|
||||||
row = self.db_pool.simple_select_one_txn(
|
row = self.db_pool.simple_select_one_txn(
|
||||||
txn,
|
txn,
|
||||||
table="ratelimit_override",
|
table="ratelimit_override",
|
||||||
@ -676,7 +683,7 @@ class RoomWorkerStore(SQLBaseStore):
|
|||||||
await self.db_pool.runInteraction("delete_ratelimit", delete_ratelimit_txn)
|
await self.db_pool.runInteraction("delete_ratelimit", delete_ratelimit_txn)
|
||||||
|
|
||||||
@cached()
|
@cached()
|
||||||
async def get_retention_policy_for_room(self, room_id):
|
async def get_retention_policy_for_room(self, room_id: str) -> Dict[str, int]:
|
||||||
"""Get the retention policy for a given room.
|
"""Get the retention policy for a given room.
|
||||||
|
|
||||||
If no retention policy has been found for this room, returns a policy defined
|
If no retention policy has been found for this room, returns a policy defined
|
||||||
@ -685,13 +692,15 @@ class RoomWorkerStore(SQLBaseStore):
|
|||||||
configuration).
|
configuration).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
room_id (str): The ID of the room to get the retention policy of.
|
room_id: The ID of the room to get the retention policy of.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict[int, int]: "min_lifetime" and "max_lifetime" for this room.
|
A dict containing "min_lifetime" and "max_lifetime" for this room.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get_retention_policy_for_room_txn(txn):
|
def get_retention_policy_for_room_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> List[Dict[str, Optional[int]]]:
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"""
|
"""
|
||||||
SELECT min_lifetime, max_lifetime FROM room_retention
|
SELECT min_lifetime, max_lifetime FROM room_retention
|
||||||
@ -716,19 +725,23 @@ class RoomWorkerStore(SQLBaseStore):
|
|||||||
"max_lifetime": self.config.retention.retention_default_max_lifetime,
|
"max_lifetime": self.config.retention.retention_default_max_lifetime,
|
||||||
}
|
}
|
||||||
|
|
||||||
row = ret[0]
|
min_lifetime = ret[0]["min_lifetime"]
|
||||||
|
max_lifetime = ret[0]["max_lifetime"]
|
||||||
|
|
||||||
# If one of the room's policy's attributes isn't defined, use the matching
|
# If one of the room's policy's attributes isn't defined, use the matching
|
||||||
# attribute from the default policy.
|
# attribute from the default policy.
|
||||||
# The default values will be None if no default policy has been defined, or if one
|
# The default values will be None if no default policy has been defined, or if one
|
||||||
# of the attributes is missing from the default policy.
|
# of the attributes is missing from the default policy.
|
||||||
if row["min_lifetime"] is None:
|
if min_lifetime is None:
|
||||||
row["min_lifetime"] = self.config.retention.retention_default_min_lifetime
|
min_lifetime = self.config.retention.retention_default_min_lifetime
|
||||||
|
|
||||||
if row["max_lifetime"] is None:
|
if max_lifetime is None:
|
||||||
row["max_lifetime"] = self.config.retention.retention_default_max_lifetime
|
max_lifetime = self.config.retention.retention_default_max_lifetime
|
||||||
|
|
||||||
return row
|
return {
|
||||||
|
"min_lifetime": min_lifetime,
|
||||||
|
"max_lifetime": max_lifetime,
|
||||||
|
}
|
||||||
|
|
||||||
async def get_media_mxcs_in_room(self, room_id: str) -> Tuple[List[str], List[str]]:
|
async def get_media_mxcs_in_room(self, room_id: str) -> Tuple[List[str], List[str]]:
|
||||||
"""Retrieves all the local and remote media MXC URIs in a given room
|
"""Retrieves all the local and remote media MXC URIs in a given room
|
||||||
@ -740,7 +753,9 @@ class RoomWorkerStore(SQLBaseStore):
|
|||||||
The local and remote media as a lists of the media IDs.
|
The local and remote media as a lists of the media IDs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _get_media_mxcs_in_room_txn(txn):
|
def _get_media_mxcs_in_room_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> Tuple[List[str], List[str]]:
|
||||||
local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
|
local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
|
||||||
local_media_mxcs = []
|
local_media_mxcs = []
|
||||||
remote_media_mxcs = []
|
remote_media_mxcs = []
|
||||||
@ -766,7 +781,7 @@ class RoomWorkerStore(SQLBaseStore):
|
|||||||
|
|
||||||
logger.info("Quarantining media in room: %s", room_id)
|
logger.info("Quarantining media in room: %s", room_id)
|
||||||
|
|
||||||
def _quarantine_media_in_room_txn(txn):
|
def _quarantine_media_in_room_txn(txn: LoggingTransaction) -> int:
|
||||||
local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
|
local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
|
||||||
return self._quarantine_media_txn(
|
return self._quarantine_media_txn(
|
||||||
txn, local_mxcs, remote_mxcs, quarantined_by
|
txn, local_mxcs, remote_mxcs, quarantined_by
|
||||||
@ -776,13 +791,11 @@ class RoomWorkerStore(SQLBaseStore):
|
|||||||
"quarantine_media_in_room", _quarantine_media_in_room_txn
|
"quarantine_media_in_room", _quarantine_media_in_room_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_media_mxcs_in_room_txn(self, txn, room_id):
|
def _get_media_mxcs_in_room_txn(
|
||||||
|
self, txn: LoggingTransaction, room_id: str
|
||||||
|
) -> Tuple[List[str], List[Tuple[str, str]]]:
|
||||||
"""Retrieves all the local and remote media MXC URIs in a given room
|
"""Retrieves all the local and remote media MXC URIs in a given room
|
||||||
|
|
||||||
Args:
|
|
||||||
txn (cursor)
|
|
||||||
room_id (str)
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The local and remote media as a lists of tuples where the key is
|
The local and remote media as a lists of tuples where the key is
|
||||||
the hostname and the value is the media ID.
|
the hostname and the value is the media ID.
|
||||||
@ -850,7 +863,7 @@ class RoomWorkerStore(SQLBaseStore):
|
|||||||
logger.info("Quarantining media: %s/%s", server_name, media_id)
|
logger.info("Quarantining media: %s/%s", server_name, media_id)
|
||||||
is_local = server_name == self.config.server.server_name
|
is_local = server_name == self.config.server.server_name
|
||||||
|
|
||||||
def _quarantine_media_by_id_txn(txn):
|
def _quarantine_media_by_id_txn(txn: LoggingTransaction) -> int:
|
||||||
local_mxcs = [media_id] if is_local else []
|
local_mxcs = [media_id] if is_local else []
|
||||||
remote_mxcs = [(server_name, media_id)] if not is_local else []
|
remote_mxcs = [(server_name, media_id)] if not is_local else []
|
||||||
|
|
||||||
@ -872,7 +885,7 @@ class RoomWorkerStore(SQLBaseStore):
|
|||||||
quarantined_by: The ID of the user who made the quarantine request
|
quarantined_by: The ID of the user who made the quarantine request
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _quarantine_media_by_user_txn(txn):
|
def _quarantine_media_by_user_txn(txn: LoggingTransaction) -> int:
|
||||||
local_media_ids = self._get_media_ids_by_user_txn(txn, user_id)
|
local_media_ids = self._get_media_ids_by_user_txn(txn, user_id)
|
||||||
return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by)
|
return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by)
|
||||||
|
|
||||||
@ -880,7 +893,9 @@ class RoomWorkerStore(SQLBaseStore):
|
|||||||
"quarantine_media_by_user", _quarantine_media_by_user_txn
|
"quarantine_media_by_user", _quarantine_media_by_user_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_media_ids_by_user_txn(self, txn, user_id: str, filter_quarantined=True):
|
def _get_media_ids_by_user_txn(
|
||||||
|
self, txn: LoggingTransaction, user_id: str, filter_quarantined: bool = True
|
||||||
|
) -> List[str]:
|
||||||
"""Retrieves local media IDs by a given user
|
"""Retrieves local media IDs by a given user
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -909,7 +924,7 @@ class RoomWorkerStore(SQLBaseStore):
|
|||||||
|
|
||||||
def _quarantine_media_txn(
|
def _quarantine_media_txn(
|
||||||
self,
|
self,
|
||||||
txn,
|
txn: LoggingTransaction,
|
||||||
local_mxcs: List[str],
|
local_mxcs: List[str],
|
||||||
remote_mxcs: List[Tuple[str, str]],
|
remote_mxcs: List[Tuple[str, str]],
|
||||||
quarantined_by: Optional[str],
|
quarantined_by: Optional[str],
|
||||||
@ -937,12 +952,15 @@ class RoomWorkerStore(SQLBaseStore):
|
|||||||
# set quarantine
|
# set quarantine
|
||||||
if quarantined_by is not None:
|
if quarantined_by is not None:
|
||||||
sql += "AND safe_from_quarantine = ?"
|
sql += "AND safe_from_quarantine = ?"
|
||||||
rows = [(quarantined_by, media_id, False) for media_id in local_mxcs]
|
txn.executemany(
|
||||||
|
sql, [(quarantined_by, media_id, False) for media_id in local_mxcs]
|
||||||
|
)
|
||||||
# remove from quarantine
|
# remove from quarantine
|
||||||
else:
|
else:
|
||||||
rows = [(quarantined_by, media_id) for media_id in local_mxcs]
|
txn.executemany(
|
||||||
|
sql, [(quarantined_by, media_id) for media_id in local_mxcs]
|
||||||
|
)
|
||||||
|
|
||||||
txn.executemany(sql, rows)
|
|
||||||
# Note that a rowcount of -1 can be used to indicate no rows were affected.
|
# Note that a rowcount of -1 can be used to indicate no rows were affected.
|
||||||
total_media_quarantined = txn.rowcount if txn.rowcount > 0 else 0
|
total_media_quarantined = txn.rowcount if txn.rowcount > 0 else 0
|
||||||
|
|
||||||
@ -960,7 +978,7 @@ class RoomWorkerStore(SQLBaseStore):
|
|||||||
|
|
||||||
async def get_rooms_for_retention_period_in_range(
|
async def get_rooms_for_retention_period_in_range(
|
||||||
self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False
|
self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False
|
||||||
) -> Dict[str, dict]:
|
) -> Dict[str, Dict[str, Optional[int]]]:
|
||||||
"""Retrieves all of the rooms within the given retention range.
|
"""Retrieves all of the rooms within the given retention range.
|
||||||
|
|
||||||
Optionally includes the rooms which don't have a retention policy.
|
Optionally includes the rooms which don't have a retention policy.
|
||||||
@ -980,7 +998,9 @@ class RoomWorkerStore(SQLBaseStore):
|
|||||||
"min_lifetime" (int|None), and "max_lifetime" (int|None).
|
"min_lifetime" (int|None), and "max_lifetime" (int|None).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get_rooms_for_retention_period_in_range_txn(txn):
|
def get_rooms_for_retention_period_in_range_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> Dict[str, Dict[str, Optional[int]]]:
|
||||||
range_conditions = []
|
range_conditions = []
|
||||||
args = []
|
args = []
|
||||||
|
|
||||||
@ -1067,8 +1087,6 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
|
|||||||
):
|
):
|
||||||
super().__init__(database, db_conn, hs)
|
super().__init__(database, db_conn, hs)
|
||||||
|
|
||||||
self.config = hs.config
|
|
||||||
|
|
||||||
self.db_pool.updates.register_background_update_handler(
|
self.db_pool.updates.register_background_update_handler(
|
||||||
"insert_room_retention",
|
"insert_room_retention",
|
||||||
self._background_insert_retention,
|
self._background_insert_retention,
|
||||||
@ -1099,7 +1117,9 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
|
|||||||
self._background_populate_rooms_creator_column,
|
self._background_populate_rooms_creator_column,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _background_insert_retention(self, progress, batch_size):
|
async def _background_insert_retention(
|
||||||
|
self, progress: JsonDict, batch_size: int
|
||||||
|
) -> int:
|
||||||
"""Retrieves a list of all rooms within a range and inserts an entry for each of
|
"""Retrieves a list of all rooms within a range and inserts an entry for each of
|
||||||
them into the room_retention table.
|
them into the room_retention table.
|
||||||
NULLs the property's columns if missing from the retention event in the room's
|
NULLs the property's columns if missing from the retention event in the room's
|
||||||
@ -1109,7 +1129,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
|
|||||||
|
|
||||||
last_room = progress.get("room_id", "")
|
last_room = progress.get("room_id", "")
|
||||||
|
|
||||||
def _background_insert_retention_txn(txn):
|
def _background_insert_retention_txn(txn: LoggingTransaction) -> bool:
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"""
|
"""
|
||||||
SELECT state.room_id, state.event_id, events.json
|
SELECT state.room_id, state.event_id, events.json
|
||||||
@ -1168,15 +1188,17 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
|
|||||||
return batch_size
|
return batch_size
|
||||||
|
|
||||||
async def _background_add_rooms_room_version_column(
|
async def _background_add_rooms_room_version_column(
|
||||||
self, progress: dict, batch_size: int
|
self, progress: JsonDict, batch_size: int
|
||||||
):
|
) -> int:
|
||||||
"""Background update to go and add room version information to `rooms`
|
"""Background update to go and add room version information to `rooms`
|
||||||
table from `current_state_events` table.
|
table from `current_state_events` table.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
last_room_id = progress.get("room_id", "")
|
last_room_id = progress.get("room_id", "")
|
||||||
|
|
||||||
def _background_add_rooms_room_version_column_txn(txn: LoggingTransaction):
|
def _background_add_rooms_room_version_column_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> bool:
|
||||||
sql = """
|
sql = """
|
||||||
SELECT room_id, json FROM current_state_events
|
SELECT room_id, json FROM current_state_events
|
||||||
INNER JOIN event_json USING (room_id, event_id)
|
INNER JOIN event_json USING (room_id, event_id)
|
||||||
@ -1237,7 +1259,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
|
|||||||
return batch_size
|
return batch_size
|
||||||
|
|
||||||
async def _remove_tombstoned_rooms_from_directory(
|
async def _remove_tombstoned_rooms_from_directory(
|
||||||
self, progress, batch_size
|
self, progress: JsonDict, batch_size: int
|
||||||
) -> int:
|
) -> int:
|
||||||
"""Removes any rooms with tombstone events from the room directory
|
"""Removes any rooms with tombstone events from the room directory
|
||||||
|
|
||||||
@ -1247,7 +1269,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
|
|||||||
|
|
||||||
last_room = progress.get("room_id", "")
|
last_room = progress.get("room_id", "")
|
||||||
|
|
||||||
def _get_rooms(txn):
|
def _get_rooms(txn: LoggingTransaction) -> List[str]:
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"""
|
"""
|
||||||
SELECT room_id
|
SELECT room_id
|
||||||
@ -1285,7 +1307,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
|
|||||||
return len(rooms)
|
return len(rooms)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def set_room_is_public(self, room_id, is_public):
|
def set_room_is_public(self, room_id: str, is_public: bool) -> Awaitable[None]:
|
||||||
# this will need to be implemented if a background update is performed with
|
# this will need to be implemented if a background update is performed with
|
||||||
# existing (tombstoned, public) rooms in the database.
|
# existing (tombstoned, public) rooms in the database.
|
||||||
#
|
#
|
||||||
@ -1332,7 +1354,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
|
|||||||
32-bit integer field.
|
32-bit integer field.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def process(txn: Cursor) -> int:
|
def process(txn: LoggingTransaction) -> int:
|
||||||
last_room = progress.get("last_room", "")
|
last_room = progress.get("last_room", "")
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"""
|
"""
|
||||||
@ -1389,15 +1411,17 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
|
|||||||
return 0
|
return 0
|
||||||
|
|
||||||
async def _background_populate_rooms_creator_column(
|
async def _background_populate_rooms_creator_column(
|
||||||
self, progress: dict, batch_size: int
|
self, progress: JsonDict, batch_size: int
|
||||||
):
|
) -> int:
|
||||||
"""Background update to go and add creator information to `rooms`
|
"""Background update to go and add creator information to `rooms`
|
||||||
table from `current_state_events` table.
|
table from `current_state_events` table.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
last_room_id = progress.get("room_id", "")
|
last_room_id = progress.get("room_id", "")
|
||||||
|
|
||||||
def _background_populate_rooms_creator_column_txn(txn: LoggingTransaction):
|
def _background_populate_rooms_creator_column_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> bool:
|
||||||
sql = """
|
sql = """
|
||||||
SELECT room_id, json FROM event_json
|
SELECT room_id, json FROM event_json
|
||||||
INNER JOIN rooms AS room USING (room_id)
|
INNER JOIN rooms AS room USING (room_id)
|
||||||
@ -1448,7 +1472,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
|
|||||||
return batch_size
|
return batch_size
|
||||||
|
|
||||||
|
|
||||||
class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
database: DatabasePool,
|
database: DatabasePool,
|
||||||
@ -1457,11 +1481,11 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
|||||||
):
|
):
|
||||||
super().__init__(database, db_conn, hs)
|
super().__init__(database, db_conn, hs)
|
||||||
|
|
||||||
self.config = hs.config
|
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
|
||||||
|
|
||||||
async def upsert_room_on_join(
|
async def upsert_room_on_join(
|
||||||
self, room_id: str, room_version: RoomVersion, auth_events: List[EventBase]
|
self, room_id: str, room_version: RoomVersion, auth_events: List[EventBase]
|
||||||
):
|
) -> None:
|
||||||
"""Ensure that the room is stored in the table
|
"""Ensure that the room is stored in the table
|
||||||
|
|
||||||
Called when we join a room over federation, and overwrites any room version
|
Called when we join a room over federation, and overwrites any room version
|
||||||
@ -1507,7 +1531,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
|||||||
|
|
||||||
async def maybe_store_room_on_outlier_membership(
|
async def maybe_store_room_on_outlier_membership(
|
||||||
self, room_id: str, room_version: RoomVersion
|
self, room_id: str, room_version: RoomVersion
|
||||||
):
|
) -> None:
|
||||||
"""
|
"""
|
||||||
When we receive an invite or any other event over federation that may relate to a room
|
When we receive an invite or any other event over federation that may relate to a room
|
||||||
we are not in, store the version of the room if we don't already know the room version.
|
we are not in, store the version of the room if we don't already know the room version.
|
||||||
@ -1547,8 +1571,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
|||||||
self.hs.get_notifier().on_new_replication_data()
|
self.hs.get_notifier().on_new_replication_data()
|
||||||
|
|
||||||
async def set_room_is_public_appservice(
|
async def set_room_is_public_appservice(
|
||||||
self, room_id, appservice_id, network_id, is_public
|
self, room_id: str, appservice_id: str, network_id: str, is_public: bool
|
||||||
):
|
) -> None:
|
||||||
"""Edit the appservice/network specific public room list.
|
"""Edit the appservice/network specific public room list.
|
||||||
|
|
||||||
Each appservice can have a number of published room lists associated
|
Each appservice can have a number of published room lists associated
|
||||||
@ -1557,11 +1581,10 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
|||||||
network.
|
network.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
room_id (str)
|
room_id
|
||||||
appservice_id (str)
|
appservice_id
|
||||||
network_id (str)
|
network_id
|
||||||
is_public (bool): Whether to publish or unpublish the room from the
|
is_public: Whether to publish or unpublish the room from the list.
|
||||||
list.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if is_public:
|
if is_public:
|
||||||
@ -1626,7 +1649,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
|||||||
event_report: json list of information from event report
|
event_report: json list of information from event report
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _get_event_report_txn(txn, report_id):
|
def _get_event_report_txn(
|
||||||
|
txn: LoggingTransaction, report_id: int
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
|
|
||||||
sql = """
|
sql = """
|
||||||
SELECT
|
SELECT
|
||||||
@ -1698,9 +1723,11 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
|||||||
count: total number of event reports matching the filter criteria
|
count: total number of event reports matching the filter criteria
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _get_event_reports_paginate_txn(txn):
|
def _get_event_reports_paginate_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> Tuple[List[Dict[str, Any]], int]:
|
||||||
filters = []
|
filters = []
|
||||||
args = []
|
args: List[object] = []
|
||||||
|
|
||||||
if user_id:
|
if user_id:
|
||||||
filters.append("er.user_id LIKE ?")
|
filters.append("er.user_id LIKE ?")
|
||||||
@ -1724,7 +1751,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
|||||||
where_clause
|
where_clause
|
||||||
)
|
)
|
||||||
txn.execute(sql, args)
|
txn.execute(sql, args)
|
||||||
count = txn.fetchone()[0]
|
count = cast(Tuple[int], txn.fetchone())[0]
|
||||||
|
|
||||||
sql = """
|
sql = """
|
||||||
SELECT
|
SELECT
|
||||||
|
Loading…
Reference in New Issue
Block a user