Add some type hints to datastore (#12248)

* inherit `MonthlyActiveUsersStore` from `RegistrationWorkerStore`

Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>
This commit is contained in:
Dirk Klimpel 2022-03-18 16:24:18 +01:00 committed by GitHub
parent 872dbb0181
commit c46065fa3d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 117 additions and 84 deletions

1
changelog.d/12248.misc Normal file
View File

@ -0,0 +1 @@
Add missing type hints for storage.

View File

@ -42,9 +42,6 @@ exclude = (?x)
|synapse/storage/databases/main/cache.py |synapse/storage/databases/main/cache.py
|synapse/storage/databases/main/devices.py |synapse/storage/databases/main/devices.py
|synapse/storage/databases/main/event_federation.py |synapse/storage/databases/main/event_federation.py
|synapse/storage/databases/main/group_server.py
|synapse/storage/databases/main/metrics.py
|synapse/storage/databases/main/monthly_active_users.py
|synapse/storage/databases/main/push_rule.py |synapse/storage/databases/main/push_rule.py
|synapse/storage/databases/main/receipts.py |synapse/storage/databases/main/receipts.py
|synapse/storage/databases/main/roommember.py |synapse/storage/databases/main/roommember.py
@ -87,9 +84,6 @@ exclude = (?x)
|tests/state/test_v2.py |tests/state/test_v2.py
|tests/storage/test_background_update.py |tests/storage/test_background_update.py
|tests/storage/test_base.py |tests/storage/test_base.py
|tests/storage/test_client_ips.py
|tests/storage/test_database.py
|tests/storage/test_event_federation.py
|tests/storage/test_id_generators.py |tests/storage/test_id_generators.py
|tests/storage/test_roommember.py |tests/storage/test_roommember.py
|tests/test_metrics.py |tests/test_metrics.py

View File

@ -13,13 +13,17 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast
from typing_extensions import TypedDict from typing_extensions import TypedDict
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import json_encoder from synapse.util import json_encoder
@ -75,7 +79,7 @@ class GroupServerWorkerStore(SQLBaseStore):
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
# TODO: Pagination # TODO: Pagination
keyvalues = {"group_id": group_id} keyvalues: JsonDict = {"group_id": group_id}
if not include_private: if not include_private:
keyvalues["is_public"] = True keyvalues["is_public"] = True
@ -117,7 +121,7 @@ class GroupServerWorkerStore(SQLBaseStore):
# TODO: Pagination # TODO: Pagination
def _get_rooms_in_group_txn(txn): def _get_rooms_in_group_txn(txn: LoggingTransaction) -> List[_RoomInGroup]:
sql = """ sql = """
SELECT room_id, is_public FROM group_rooms SELECT room_id, is_public FROM group_rooms
WHERE group_id = ? WHERE group_id = ?
@ -176,8 +180,10 @@ class GroupServerWorkerStore(SQLBaseStore):
* "order": int, the sort order of rooms in this category * "order": int, the sort order of rooms in this category
""" """
def _get_rooms_for_summary_txn(txn): def _get_rooms_for_summary_txn(
keyvalues = {"group_id": group_id} txn: LoggingTransaction,
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
keyvalues: JsonDict = {"group_id": group_id}
if not include_private: if not include_private:
keyvalues["is_public"] = True keyvalues["is_public"] = True
@ -241,7 +247,7 @@ class GroupServerWorkerStore(SQLBaseStore):
"get_rooms_for_summary", _get_rooms_for_summary_txn "get_rooms_for_summary", _get_rooms_for_summary_txn
) )
async def get_group_categories(self, group_id): async def get_group_categories(self, group_id: str) -> JsonDict:
rows = await self.db_pool.simple_select_list( rows = await self.db_pool.simple_select_list(
table="group_room_categories", table="group_room_categories",
keyvalues={"group_id": group_id}, keyvalues={"group_id": group_id},
@ -257,7 +263,7 @@ class GroupServerWorkerStore(SQLBaseStore):
for row in rows for row in rows
} }
async def get_group_category(self, group_id, category_id): async def get_group_category(self, group_id: str, category_id: str) -> JsonDict:
category = await self.db_pool.simple_select_one( category = await self.db_pool.simple_select_one(
table="group_room_categories", table="group_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id}, keyvalues={"group_id": group_id, "category_id": category_id},
@ -269,7 +275,7 @@ class GroupServerWorkerStore(SQLBaseStore):
return category return category
async def get_group_roles(self, group_id): async def get_group_roles(self, group_id: str) -> JsonDict:
rows = await self.db_pool.simple_select_list( rows = await self.db_pool.simple_select_list(
table="group_roles", table="group_roles",
keyvalues={"group_id": group_id}, keyvalues={"group_id": group_id},
@ -285,7 +291,7 @@ class GroupServerWorkerStore(SQLBaseStore):
for row in rows for row in rows
} }
async def get_group_role(self, group_id, role_id): async def get_group_role(self, group_id: str, role_id: str) -> JsonDict:
role = await self.db_pool.simple_select_one( role = await self.db_pool.simple_select_one(
table="group_roles", table="group_roles",
keyvalues={"group_id": group_id, "role_id": role_id}, keyvalues={"group_id": group_id, "role_id": role_id},
@ -311,15 +317,19 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_local_groups_for_room", desc="get_local_groups_for_room",
) )
async def get_users_for_summary_by_role(self, group_id, include_private=False): async def get_users_for_summary_by_role(
self, group_id: str, include_private: bool = False
) -> Tuple[List[JsonDict], JsonDict]:
"""Get the users and roles that should be included in a summary request """Get the users and roles that should be included in a summary request
Returns: Returns:
([users], [roles]) ([users], [roles])
""" """
def _get_users_for_summary_txn(txn): def _get_users_for_summary_txn(
keyvalues = {"group_id": group_id} txn: LoggingTransaction,
) -> Tuple[List[JsonDict], JsonDict]:
keyvalues: JsonDict = {"group_id": group_id}
if not include_private: if not include_private:
keyvalues["is_public"] = True keyvalues["is_public"] = True
@ -406,7 +416,9 @@ class GroupServerWorkerStore(SQLBaseStore):
allow_none=True, allow_none=True,
) )
async def get_users_membership_info_in_group(self, group_id, user_id): async def get_users_membership_info_in_group(
self, group_id: str, user_id: str
) -> JsonDict:
"""Get a dict describing the membership of a user in a group. """Get a dict describing the membership of a user in a group.
Example if joined: Example if joined:
@ -421,7 +433,7 @@ class GroupServerWorkerStore(SQLBaseStore):
An empty dict if the user is not join/invite/etc An empty dict if the user is not join/invite/etc
""" """
def _get_users_membership_in_group_txn(txn): def _get_users_membership_in_group_txn(txn: LoggingTransaction) -> JsonDict:
row = self.db_pool.simple_select_one_txn( row = self.db_pool.simple_select_one_txn(
txn, txn,
table="group_users", table="group_users",
@ -463,10 +475,14 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_publicised_groups_for_user", desc="get_publicised_groups_for_user",
) )
async def get_attestations_need_renewals(self, valid_until_ms): async def get_attestations_need_renewals(
self, valid_until_ms: int
) -> List[Dict[str, Any]]:
"""Get all attestations that need to be renewed until givent time""" """Get all attestations that need to be renewed until givent time"""
def _get_attestations_need_renewals_txn(txn): def _get_attestations_need_renewals_txn(
txn: LoggingTransaction,
) -> List[Dict[str, Any]]:
sql = """ sql = """
SELECT group_id, user_id FROM group_attestations_renewals SELECT group_id, user_id FROM group_attestations_renewals
WHERE valid_until_ms <= ? WHERE valid_until_ms <= ?
@ -478,7 +494,9 @@ class GroupServerWorkerStore(SQLBaseStore):
"get_attestations_need_renewals", _get_attestations_need_renewals_txn "get_attestations_need_renewals", _get_attestations_need_renewals_txn
) )
async def get_remote_attestation(self, group_id, user_id): async def get_remote_attestation(
self, group_id: str, user_id: str
) -> Optional[JsonDict]:
"""Get the attestation that proves the remote agrees that the user is """Get the attestation that proves the remote agrees that the user is
in the group. in the group.
""" """
@ -504,8 +522,8 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_joined_groups", desc="get_joined_groups",
) )
async def get_all_groups_for_user(self, user_id, now_token): async def get_all_groups_for_user(self, user_id, now_token) -> List[JsonDict]:
def _get_all_groups_for_user_txn(txn): def _get_all_groups_for_user_txn(txn: LoggingTransaction) -> List[JsonDict]:
sql = """ sql = """
SELECT group_id, type, membership, u.content SELECT group_id, type, membership, u.content
FROM local_group_updates AS u FROM local_group_updates AS u
@ -528,15 +546,16 @@ class GroupServerWorkerStore(SQLBaseStore):
"get_all_groups_for_user", _get_all_groups_for_user_txn "get_all_groups_for_user", _get_all_groups_for_user_txn
) )
async def get_groups_changes_for_user(self, user_id, from_token, to_token): async def get_groups_changes_for_user(
from_token = int(from_token) self, user_id: str, from_token: int, to_token: int
has_changed = self._group_updates_stream_cache.has_entity_changed( ) -> List[JsonDict]:
has_changed = self._group_updates_stream_cache.has_entity_changed( # type: ignore[attr-defined]
user_id, from_token user_id, from_token
) )
if not has_changed: if not has_changed:
return [] return []
def _get_groups_changes_for_user_txn(txn): def _get_groups_changes_for_user_txn(txn: LoggingTransaction) -> List[JsonDict]:
sql = """ sql = """
SELECT group_id, membership, type, u.content SELECT group_id, membership, type, u.content
FROM local_group_updates AS u FROM local_group_updates AS u
@ -583,12 +602,14 @@ class GroupServerWorkerStore(SQLBaseStore):
""" """
last_id = int(last_id) last_id = int(last_id)
has_changed = self._group_updates_stream_cache.has_any_entity_changed(last_id) has_changed = self._group_updates_stream_cache.has_any_entity_changed(last_id) # type: ignore[attr-defined]
if not has_changed: if not has_changed:
return [], current_id, False return [], current_id, False
def _get_all_groups_changes_txn(txn): def _get_all_groups_changes_txn(
txn: LoggingTransaction,
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
sql = """ sql = """
SELECT stream_id, group_id, user_id, type, content SELECT stream_id, group_id, user_id, type, content
FROM local_group_updates FROM local_group_updates
@ -596,10 +617,13 @@ class GroupServerWorkerStore(SQLBaseStore):
LIMIT ? LIMIT ?
""" """
txn.execute(sql, (last_id, current_id, limit)) txn.execute(sql, (last_id, current_id, limit))
updates = [ updates = cast(
(stream_id, (group_id, user_id, gtype, db_to_json(content_json))) List[Tuple[int, tuple]],
for stream_id, group_id, user_id, gtype, content_json in txn [
] (stream_id, (group_id, user_id, gtype, db_to_json(content_json)))
for stream_id, group_id, user_id, gtype, content_json in txn
],
)
limited = False limited = False
upto_token = current_id upto_token = current_id
@ -633,8 +657,8 @@ class GroupServerStore(GroupServerWorkerStore):
self, self,
group_id: str, group_id: str,
room_id: str, room_id: str,
category_id: str, category_id: Optional[str],
order: int, order: Optional[int],
is_public: Optional[bool], is_public: Optional[bool],
) -> None: ) -> None:
"""Add (or update) room's entry in summary. """Add (or update) room's entry in summary.
@ -661,11 +685,11 @@ class GroupServerStore(GroupServerWorkerStore):
def _add_room_to_summary_txn( def _add_room_to_summary_txn(
self, self,
txn, txn: LoggingTransaction,
group_id: str, group_id: str,
room_id: str, room_id: str,
category_id: str, category_id: Optional[str],
order: int, order: Optional[int],
is_public: Optional[bool], is_public: Optional[bool],
) -> None: ) -> None:
"""Add (or update) room's entry in summary. """Add (or update) room's entry in summary.
@ -750,7 +774,7 @@ class GroupServerStore(GroupServerWorkerStore):
WHERE group_id = ? AND category_id = ? WHERE group_id = ? AND category_id = ?
""" """
txn.execute(sql, (group_id, category_id)) txn.execute(sql, (group_id, category_id))
(order,) = txn.fetchone() (order,) = cast(Tuple[int], txn.fetchone())
if existing: if existing:
to_update = {} to_update = {}
@ -766,7 +790,7 @@ class GroupServerStore(GroupServerWorkerStore):
"category_id": category_id, "category_id": category_id,
"room_id": room_id, "room_id": room_id,
}, },
values=to_update, updatevalues=to_update,
) )
else: else:
if is_public is None: if is_public is None:
@ -785,7 +809,7 @@ class GroupServerStore(GroupServerWorkerStore):
) )
async def remove_room_from_summary( async def remove_room_from_summary(
self, group_id: str, room_id: str, category_id: str self, group_id: str, room_id: str, category_id: Optional[str]
) -> int: ) -> int:
if category_id is None: if category_id is None:
category_id = _DEFAULT_CATEGORY_ID category_id = _DEFAULT_CATEGORY_ID
@ -808,8 +832,8 @@ class GroupServerStore(GroupServerWorkerStore):
is_public: Optional[bool], is_public: Optional[bool],
) -> None: ) -> None:
"""Add/update room category for group""" """Add/update room category for group"""
insertion_values = {} insertion_values: JsonDict = {}
update_values = {"category_id": category_id} # This cannot be empty update_values: JsonDict = {"category_id": category_id} # This cannot be empty
if profile is None: if profile is None:
insertion_values["profile"] = "{}" insertion_values["profile"] = "{}"
@ -844,8 +868,8 @@ class GroupServerStore(GroupServerWorkerStore):
is_public: Optional[bool], is_public: Optional[bool],
) -> None: ) -> None:
"""Add/remove user role""" """Add/remove user role"""
insertion_values = {} insertion_values: JsonDict = {}
update_values = {"role_id": role_id} # This cannot be empty update_values: JsonDict = {"role_id": role_id} # This cannot be empty
if profile is None: if profile is None:
insertion_values["profile"] = "{}" insertion_values["profile"] = "{}"
@ -876,8 +900,8 @@ class GroupServerStore(GroupServerWorkerStore):
self, self,
group_id: str, group_id: str,
user_id: str, user_id: str,
role_id: str, role_id: Optional[str],
order: int, order: Optional[int],
is_public: Optional[bool], is_public: Optional[bool],
) -> None: ) -> None:
"""Add (or update) user's entry in summary. """Add (or update) user's entry in summary.
@ -904,13 +928,13 @@ class GroupServerStore(GroupServerWorkerStore):
def _add_user_to_summary_txn( def _add_user_to_summary_txn(
self, self,
txn, txn: LoggingTransaction,
group_id: str, group_id: str,
user_id: str, user_id: str,
role_id: str, role_id: Optional[str],
order: int, order: Optional[int],
is_public: Optional[bool], is_public: Optional[bool],
): ) -> None:
"""Add (or update) user's entry in summary. """Add (or update) user's entry in summary.
Args: Args:
@ -989,7 +1013,7 @@ class GroupServerStore(GroupServerWorkerStore):
WHERE group_id = ? AND role_id = ? WHERE group_id = ? AND role_id = ?
""" """
txn.execute(sql, (group_id, role_id)) txn.execute(sql, (group_id, role_id))
(order,) = txn.fetchone() (order,) = cast(Tuple[int], txn.fetchone())
if existing: if existing:
to_update = {} to_update = {}
@ -1005,7 +1029,7 @@ class GroupServerStore(GroupServerWorkerStore):
"role_id": role_id, "role_id": role_id,
"user_id": user_id, "user_id": user_id,
}, },
values=to_update, updatevalues=to_update,
) )
else: else:
if is_public is None: if is_public is None:
@ -1024,7 +1048,7 @@ class GroupServerStore(GroupServerWorkerStore):
) )
async def remove_user_from_summary( async def remove_user_from_summary(
self, group_id: str, user_id: str, role_id: str self, group_id: str, user_id: str, role_id: Optional[str]
) -> int: ) -> int:
if role_id is None: if role_id is None:
role_id = _DEFAULT_ROLE_ID role_id = _DEFAULT_ROLE_ID
@ -1065,7 +1089,7 @@ class GroupServerStore(GroupServerWorkerStore):
Optional if the user and group are on the same server Optional if the user and group are on the same server
""" """
def _add_user_to_group_txn(txn): def _add_user_to_group_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_insert_txn( self.db_pool.simple_insert_txn(
txn, txn,
table="group_users", table="group_users",
@ -1108,7 +1132,7 @@ class GroupServerStore(GroupServerWorkerStore):
await self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn) await self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn)
async def remove_user_from_group(self, group_id: str, user_id: str) -> None: async def remove_user_from_group(self, group_id: str, user_id: str) -> None:
def _remove_user_from_group_txn(txn): def _remove_user_from_group_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, txn,
table="group_users", table="group_users",
@ -1159,7 +1183,7 @@ class GroupServerStore(GroupServerWorkerStore):
) )
async def remove_room_from_group(self, group_id: str, room_id: str) -> None: async def remove_room_from_group(self, group_id: str, room_id: str) -> None:
def _remove_room_from_group_txn(txn): def _remove_room_from_group_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, txn,
table="group_rooms", table="group_rooms",
@ -1216,7 +1240,9 @@ class GroupServerStore(GroupServerWorkerStore):
content = content or {} content = content or {}
def _register_user_group_membership_txn(txn, next_id): def _register_user_group_membership_txn(
txn: LoggingTransaction, next_id: int
) -> int:
# TODO: Upsert? # TODO: Upsert?
self.db_pool.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, txn,
@ -1249,7 +1275,7 @@ class GroupServerStore(GroupServerWorkerStore):
), ),
}, },
) )
self._group_updates_stream_cache.entity_has_changed(user_id, next_id) self._group_updates_stream_cache.entity_has_changed(user_id, next_id) # type: ignore[attr-defined]
# TODO: Insert profile to ensure it comes down stream if its a join. # TODO: Insert profile to ensure it comes down stream if its a join.
@ -1289,7 +1315,7 @@ class GroupServerStore(GroupServerWorkerStore):
return next_id return next_id
async with self._group_updates_id_gen.get_next() as next_id: async with self._group_updates_id_gen.get_next() as next_id: # type: ignore[attr-defined]
res = await self.db_pool.runInteraction( res = await self.db_pool.runInteraction(
"register_user_group_membership", "register_user_group_membership",
_register_user_group_membership_txn, _register_user_group_membership_txn,
@ -1298,7 +1324,13 @@ class GroupServerStore(GroupServerWorkerStore):
return res return res
async def create_group( async def create_group(
self, group_id, user_id, name, avatar_url, short_description, long_description self,
group_id: str,
user_id: str,
name: str,
avatar_url: str,
short_description: str,
long_description: str,
) -> None: ) -> None:
await self.db_pool.simple_insert( await self.db_pool.simple_insert(
table="groups", table="groups",
@ -1313,7 +1345,7 @@ class GroupServerStore(GroupServerWorkerStore):
desc="create_group", desc="create_group",
) )
async def update_group_profile(self, group_id, profile): async def update_group_profile(self, group_id: str, profile: JsonDict) -> None:
await self.db_pool.simple_update_one( await self.db_pool.simple_update_one(
table="groups", table="groups",
keyvalues={"group_id": group_id}, keyvalues={"group_id": group_id},
@ -1361,8 +1393,8 @@ class GroupServerStore(GroupServerWorkerStore):
desc="remove_attestation_renewal", desc="remove_attestation_renewal",
) )
def get_group_stream_token(self): def get_group_stream_token(self) -> int:
return self._group_updates_id_gen.get_current_token() return self._group_updates_id_gen.get_current_token() # type: ignore[attr-defined]
async def delete_group(self, group_id: str) -> None: async def delete_group(self, group_id: str) -> None:
"""Deletes a group fully from the database. """Deletes a group fully from the database.
@ -1371,7 +1403,7 @@ class GroupServerStore(GroupServerWorkerStore):
group_id: The group ID to delete. group_id: The group ID to delete.
""" """
def _delete_group_txn(txn): def _delete_group_txn(txn: LoggingTransaction) -> None:
tables = [ tables = [
"groups", "groups",
"group_users", "group_users",

View File

@ -12,15 +12,17 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Dict, List, Optional from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.database import ( from synapse.storage.database import (
DatabasePool, DatabasePool,
LoggingDatabaseConnection, LoggingDatabaseConnection,
LoggingTransaction,
make_in_list_sql_clause, make_in_list_sql_clause,
) )
from synapse.storage.databases.main.registration import RegistrationWorkerStore
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
from synapse.util.threepids import canonicalise_email from synapse.util.threepids import canonicalise_email
@ -56,7 +58,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
Number of current monthly active users Number of current monthly active users
""" """
def _count_users(txn): def _count_users(txn: LoggingTransaction) -> int:
# Exclude app service users # Exclude app service users
sql = """ sql = """
SELECT COUNT(*) SELECT COUNT(*)
@ -66,7 +68,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
WHERE (users.appservice_id IS NULL OR users.appservice_id = ''); WHERE (users.appservice_id IS NULL OR users.appservice_id = '');
""" """
txn.execute(sql) txn.execute(sql)
(count,) = txn.fetchone() (count,) = cast(Tuple[int], txn.fetchone())
return count return count
return await self.db_pool.runInteraction("count_users", _count_users) return await self.db_pool.runInteraction("count_users", _count_users)
@ -84,7 +86,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
""" """
def _count_users_by_service(txn): def _count_users_by_service(txn: LoggingTransaction) -> Dict[str, int]:
sql = """ sql = """
SELECT COALESCE(appservice_id, 'native'), COUNT(*) SELECT COALESCE(appservice_id, 'native'), COUNT(*)
FROM monthly_active_users FROM monthly_active_users
@ -93,7 +95,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
""" """
txn.execute(sql) txn.execute(sql)
result = txn.fetchall() result = cast(List[Tuple[str, int]], txn.fetchall())
return dict(result) return dict(result)
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
@ -141,12 +143,12 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
) )
@wrap_as_background_process("reap_monthly_active_users") @wrap_as_background_process("reap_monthly_active_users")
async def reap_monthly_active_users(self): async def reap_monthly_active_users(self) -> None:
"""Cleans out monthly active user table to ensure that no stale """Cleans out monthly active user table to ensure that no stale
entries exist. entries exist.
""" """
def _reap_users(txn, reserved_users): def _reap_users(txn: LoggingTransaction, reserved_users: List[str]) -> None:
""" """
Args: Args:
reserved_users (tuple): reserved users to preserve reserved_users (tuple): reserved users to preserve
@ -210,10 +212,10 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
# is racy. # is racy.
# Have resolved to invalidate the whole cache for now and do # Have resolved to invalidate the whole cache for now and do
# something about it if and when the perf becomes significant # something about it if and when the perf becomes significant
self._invalidate_all_cache_and_stream( self._invalidate_all_cache_and_stream( # type: ignore[attr-defined]
txn, self.user_last_seen_monthly_active txn, self.user_last_seen_monthly_active
) )
self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ()) self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ()) # type: ignore[attr-defined]
reserved_users = await self.get_registered_reserved_users() reserved_users = await self.get_registered_reserved_users()
await self.db_pool.runInteraction( await self.db_pool.runInteraction(
@ -221,7 +223,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
) )
class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerStore):
def __init__( def __init__(
self, self,
database: DatabasePool, database: DatabasePool,
@ -242,13 +244,15 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value], hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value],
) )
def _initialise_reserved_users(self, txn, threepids): def _initialise_reserved_users(
self, txn: LoggingTransaction, threepids: List[dict]
) -> None:
"""Ensures that reserved threepids are accounted for in the MAU table, should """Ensures that reserved threepids are accounted for in the MAU table, should
be called on start up. be called on start up.
Args: Args:
txn (cursor): txn:
threepids (list[dict]): List of threepid dicts to reserve threepids: List of threepid dicts to reserve
""" """
# XXX what is this function trying to achieve? It upserts into # XXX what is this function trying to achieve? It upserts into
@ -299,7 +303,9 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
"upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id "upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id
) )
def upsert_monthly_active_user_txn(self, txn, user_id): def upsert_monthly_active_user_txn(
self, txn: LoggingTransaction, user_id: str
) -> None:
"""Updates or inserts monthly active user member """Updates or inserts monthly active user member
We consciously do not call is_support_txn from this method because it We consciously do not call is_support_txn from this method because it
@ -336,7 +342,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
txn, self.user_last_seen_monthly_active, (user_id,) txn, self.user_last_seen_monthly_active, (user_id,)
) )
async def populate_monthly_active_users(self, user_id): async def populate_monthly_active_users(self, user_id: str) -> None:
"""Checks on the state of monthly active user limits and optionally """Checks on the state of monthly active user limits and optionally
add the user to the monthly active tables add the user to the monthly active tables
@ -345,7 +351,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
""" """
if self._limit_usage_by_mau or self._mau_stats_only: if self._limit_usage_by_mau or self._mau_stats_only:
# Trial users and guests should not be included as part of MAU group # Trial users and guests should not be included as part of MAU group
is_guest = await self.is_guest(user_id) is_guest = await self.is_guest(user_id) # type: ignore[attr-defined]
if is_guest: if is_guest:
return return
is_trial = await self.is_trial_user(user_id) is_trial = await self.is_trial_user(user_id)