mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-07-26 12:55:17 -04:00
Add some type hints to datastore (#12717)
This commit is contained in:
parent
942c30b16b
commit
6edefef602
10 changed files with 254 additions and 161 deletions
|
@ -14,14 +14,18 @@
|
|||
# limitations under the License.
|
||||
import abc
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Tuple, Union, cast
|
||||
|
||||
from synapse.api.errors import StoreError
|
||||
from synapse.config.homeserver import ExperimentalConfig
|
||||
from synapse.push.baserules import list_with_base_rules
|
||||
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
|
||||
from synapse.storage.database import (
|
||||
DatabasePool,
|
||||
LoggingDatabaseConnection,
|
||||
LoggingTransaction,
|
||||
)
|
||||
from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
|
||||
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
||||
from synapse.storage.databases.main.pusher import PusherWorkerStore
|
||||
|
@ -30,9 +34,12 @@ from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
|
|||
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
|
||||
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
|
||||
from synapse.storage.util.id_generators import (
|
||||
AbstractStreamIdGenerator,
|
||||
AbstractStreamIdTracker,
|
||||
IdGenerator,
|
||||
StreamIdGenerator,
|
||||
)
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import json_encoder
|
||||
from synapse.util.caches.descriptors import cached, cachedList
|
||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
|
@ -57,7 +64,11 @@ def _is_experimental_rule_enabled(
|
|||
return True
|
||||
|
||||
|
||||
def _load_rules(rawrules, enabled_map, experimental_config: ExperimentalConfig):
|
||||
def _load_rules(
|
||||
rawrules: List[JsonDict],
|
||||
enabled_map: Dict[str, bool],
|
||||
experimental_config: ExperimentalConfig,
|
||||
) -> List[JsonDict]:
|
||||
ruleslist = []
|
||||
for rawrule in rawrules:
|
||||
rule = dict(rawrule)
|
||||
|
@ -137,7 +148,7 @@ class PushRulesWorkerStore(
|
|||
)
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_max_push_rules_stream_id(self):
|
||||
def get_max_push_rules_stream_id(self) -> int:
|
||||
"""Get the position of the push rules stream.
|
||||
|
||||
Returns:
|
||||
|
@ -146,7 +157,7 @@ class PushRulesWorkerStore(
|
|||
raise NotImplementedError()
|
||||
|
||||
@cached(max_entries=5000)
|
||||
async def get_push_rules_for_user(self, user_id):
|
||||
async def get_push_rules_for_user(self, user_id: str) -> List[JsonDict]:
|
||||
rows = await self.db_pool.simple_select_list(
|
||||
table="push_rules",
|
||||
keyvalues={"user_name": user_id},
|
||||
|
@ -168,7 +179,7 @@ class PushRulesWorkerStore(
|
|||
return _load_rules(rows, enabled_map, self.hs.config.experimental)
|
||||
|
||||
@cached(max_entries=5000)
|
||||
async def get_push_rules_enabled_for_user(self, user_id) -> Dict[str, bool]:
|
||||
async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]:
|
||||
results = await self.db_pool.simple_select_list(
|
||||
table="push_rules_enable",
|
||||
keyvalues={"user_name": user_id},
|
||||
|
@ -184,13 +195,13 @@ class PushRulesWorkerStore(
|
|||
return False
|
||||
else:
|
||||
|
||||
def have_push_rules_changed_txn(txn):
|
||||
def have_push_rules_changed_txn(txn: LoggingTransaction) -> bool:
|
||||
sql = (
|
||||
"SELECT COUNT(stream_id) FROM push_rules_stream"
|
||||
" WHERE user_id = ? AND ? < stream_id"
|
||||
)
|
||||
txn.execute(sql, (user_id, last_id))
|
||||
(count,) = txn.fetchone()
|
||||
(count,) = cast(Tuple[int], txn.fetchone())
|
||||
return bool(count)
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
|
@ -202,11 +213,13 @@ class PushRulesWorkerStore(
|
|||
list_name="user_ids",
|
||||
num_args=1,
|
||||
)
|
||||
async def bulk_get_push_rules(self, user_ids):
|
||||
async def bulk_get_push_rules(
|
||||
self, user_ids: Collection[str]
|
||||
) -> Dict[str, List[JsonDict]]:
|
||||
if not user_ids:
|
||||
return {}
|
||||
|
||||
results = {user_id: [] for user_id in user_ids}
|
||||
results: Dict[str, List[JsonDict]] = {user_id: [] for user_id in user_ids}
|
||||
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
table="push_rules",
|
||||
|
@ -250,7 +263,7 @@ class PushRulesWorkerStore(
|
|||
condition["pattern"] = new_room_id
|
||||
|
||||
# Add the rule for the new room
|
||||
await self.add_push_rule(
|
||||
await self.add_push_rule( # type: ignore[attr-defined]
|
||||
user_id=user_id,
|
||||
rule_id=new_rule_id,
|
||||
priority_class=rule["priority_class"],
|
||||
|
@ -286,11 +299,13 @@ class PushRulesWorkerStore(
|
|||
list_name="user_ids",
|
||||
num_args=1,
|
||||
)
|
||||
async def bulk_get_push_rules_enabled(self, user_ids):
|
||||
async def bulk_get_push_rules_enabled(
|
||||
self, user_ids: Collection[str]
|
||||
) -> Dict[str, Dict[str, bool]]:
|
||||
if not user_ids:
|
||||
return {}
|
||||
|
||||
results = {user_id: {} for user_id in user_ids}
|
||||
results: Dict[str, Dict[str, bool]] = {user_id: {} for user_id in user_ids}
|
||||
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
table="push_rules_enable",
|
||||
|
@ -306,7 +321,7 @@ class PushRulesWorkerStore(
|
|||
|
||||
async def get_all_push_rule_updates(
|
||||
self, instance_name: str, last_id: int, current_id: int, limit: int
|
||||
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
|
||||
) -> Tuple[List[Tuple[int, Tuple[str]]], int, bool]:
|
||||
"""Get updates for push_rules replication stream.
|
||||
|
||||
Args:
|
||||
|
@ -331,7 +346,9 @@ class PushRulesWorkerStore(
|
|||
if last_id == current_id:
|
||||
return [], current_id, False
|
||||
|
||||
def get_all_push_rule_updates_txn(txn):
|
||||
def get_all_push_rule_updates_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Tuple[List[Tuple[int, Tuple[str]]], int, bool]:
|
||||
sql = """
|
||||
SELECT stream_id, user_id
|
||||
FROM push_rules_stream
|
||||
|
@ -340,7 +357,10 @@ class PushRulesWorkerStore(
|
|||
LIMIT ?
|
||||
"""
|
||||
txn.execute(sql, (last_id, current_id, limit))
|
||||
updates = [(stream_id, (user_id,)) for stream_id, user_id in txn]
|
||||
updates = cast(
|
||||
List[Tuple[int, Tuple[str]]],
|
||||
[(stream_id, (user_id,)) for stream_id, user_id in txn],
|
||||
)
|
||||
|
||||
limited = False
|
||||
upper_bound = current_id
|
||||
|
@ -356,15 +376,30 @@ class PushRulesWorkerStore(
|
|||
|
||||
|
||||
class PushRuleStore(PushRulesWorkerStore):
|
||||
# Because we have write access, this will be a StreamIdGenerator
|
||||
# (see PushRulesWorkerStore.__init__)
|
||||
_push_rules_stream_id_gen: AbstractStreamIdGenerator
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
database: DatabasePool,
|
||||
db_conn: LoggingDatabaseConnection,
|
||||
hs: "HomeServer",
|
||||
):
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
||||
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
|
||||
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
|
||||
|
||||
async def add_push_rule(
|
||||
self,
|
||||
user_id,
|
||||
rule_id,
|
||||
priority_class,
|
||||
conditions,
|
||||
actions,
|
||||
before=None,
|
||||
after=None,
|
||||
user_id: str,
|
||||
rule_id: str,
|
||||
priority_class: int,
|
||||
conditions: List[Dict[str, str]],
|
||||
actions: List[Union[JsonDict, str]],
|
||||
before: Optional[str] = None,
|
||||
after: Optional[str] = None,
|
||||
) -> None:
|
||||
conditions_json = json_encoder.encode(conditions)
|
||||
actions_json = json_encoder.encode(actions)
|
||||
|
@ -400,17 +435,17 @@ class PushRuleStore(PushRulesWorkerStore):
|
|||
|
||||
def _add_push_rule_relative_txn(
|
||||
self,
|
||||
txn,
|
||||
stream_id,
|
||||
event_stream_ordering,
|
||||
user_id,
|
||||
rule_id,
|
||||
priority_class,
|
||||
conditions_json,
|
||||
actions_json,
|
||||
before,
|
||||
after,
|
||||
):
|
||||
txn: LoggingTransaction,
|
||||
stream_id: int,
|
||||
event_stream_ordering: int,
|
||||
user_id: str,
|
||||
rule_id: str,
|
||||
priority_class: int,
|
||||
conditions_json: str,
|
||||
actions_json: str,
|
||||
before: str,
|
||||
after: str,
|
||||
) -> None:
|
||||
# Lock the table since otherwise we'll have annoying races between the
|
||||
# SELECT here and the UPSERT below.
|
||||
self.database_engine.lock_table(txn, "push_rules")
|
||||
|
@ -470,15 +505,15 @@ class PushRuleStore(PushRulesWorkerStore):
|
|||
|
||||
def _add_push_rule_highest_priority_txn(
|
||||
self,
|
||||
txn,
|
||||
stream_id,
|
||||
event_stream_ordering,
|
||||
user_id,
|
||||
rule_id,
|
||||
priority_class,
|
||||
conditions_json,
|
||||
actions_json,
|
||||
):
|
||||
txn: LoggingTransaction,
|
||||
stream_id: int,
|
||||
event_stream_ordering: int,
|
||||
user_id: str,
|
||||
rule_id: str,
|
||||
priority_class: int,
|
||||
conditions_json: str,
|
||||
actions_json: str,
|
||||
) -> None:
|
||||
# Lock the table since otherwise we'll have annoying races between the
|
||||
# SELECT here and the UPSERT below.
|
||||
self.database_engine.lock_table(txn, "push_rules")
|
||||
|
@ -510,17 +545,17 @@ class PushRuleStore(PushRulesWorkerStore):
|
|||
|
||||
def _upsert_push_rule_txn(
|
||||
self,
|
||||
txn,
|
||||
stream_id,
|
||||
event_stream_ordering,
|
||||
user_id,
|
||||
rule_id,
|
||||
priority_class,
|
||||
priority,
|
||||
conditions_json,
|
||||
actions_json,
|
||||
update_stream=True,
|
||||
):
|
||||
txn: LoggingTransaction,
|
||||
stream_id: int,
|
||||
event_stream_ordering: int,
|
||||
user_id: str,
|
||||
rule_id: str,
|
||||
priority_class: int,
|
||||
priority: int,
|
||||
conditions_json: str,
|
||||
actions_json: str,
|
||||
update_stream: bool = True,
|
||||
) -> None:
|
||||
"""Specialised version of simple_upsert_txn that picks a push_rule_id
|
||||
using the _push_rule_id_gen if it needs to insert the rule. It assumes
|
||||
that the "push_rules" table is locked"""
|
||||
|
@ -600,7 +635,11 @@ class PushRuleStore(PushRulesWorkerStore):
|
|||
rule_id: The rule_id of the rule to be deleted
|
||||
"""
|
||||
|
||||
def delete_push_rule_txn(txn, stream_id, event_stream_ordering):
|
||||
def delete_push_rule_txn(
|
||||
txn: LoggingTransaction,
|
||||
stream_id: int,
|
||||
event_stream_ordering: int,
|
||||
) -> None:
|
||||
# we don't use simple_delete_one_txn because that would fail if the
|
||||
# user did not have a push_rule_enable row.
|
||||
self.db_pool.simple_delete_txn(
|
||||
|
@ -661,14 +700,14 @@ class PushRuleStore(PushRulesWorkerStore):
|
|||
|
||||
def _set_push_rule_enabled_txn(
|
||||
self,
|
||||
txn,
|
||||
stream_id,
|
||||
event_stream_ordering,
|
||||
user_id,
|
||||
rule_id,
|
||||
enabled,
|
||||
is_default_rule,
|
||||
):
|
||||
txn: LoggingTransaction,
|
||||
stream_id: int,
|
||||
event_stream_ordering: int,
|
||||
user_id: str,
|
||||
rule_id: str,
|
||||
enabled: bool,
|
||||
is_default_rule: bool,
|
||||
) -> None:
|
||||
new_id = self._push_rules_enable_id_gen.get_next()
|
||||
|
||||
if not is_default_rule:
|
||||
|
@ -740,7 +779,11 @@ class PushRuleStore(PushRulesWorkerStore):
|
|||
"""
|
||||
actions_json = json_encoder.encode(actions)
|
||||
|
||||
def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering):
|
||||
def set_push_rule_actions_txn(
|
||||
txn: LoggingTransaction,
|
||||
stream_id: int,
|
||||
event_stream_ordering: int,
|
||||
) -> None:
|
||||
if is_default_rule:
|
||||
# Add a dummy rule to the rules table with the user specified
|
||||
# actions.
|
||||
|
@ -794,8 +837,15 @@ class PushRuleStore(PushRulesWorkerStore):
|
|||
)
|
||||
|
||||
def _insert_push_rules_update_txn(
|
||||
self, txn, stream_id, event_stream_ordering, user_id, rule_id, op, data=None
|
||||
):
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
stream_id: int,
|
||||
event_stream_ordering: int,
|
||||
user_id: str,
|
||||
rule_id: str,
|
||||
op: str,
|
||||
data: Optional[JsonDict] = None,
|
||||
) -> None:
|
||||
values = {
|
||||
"stream_id": stream_id,
|
||||
"event_stream_ordering": event_stream_ordering,
|
||||
|
@ -814,5 +864,5 @@ class PushRuleStore(PushRulesWorkerStore):
|
|||
self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
|
||||
)
|
||||
|
||||
def get_max_push_rules_stream_id(self):
|
||||
def get_max_push_rules_stream_id(self) -> int:
|
||||
return self._push_rules_stream_id_gen.get_current_token()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue