mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2025-01-18 16:47:07 -05:00
Add some type hints to datastore (#12485)
This commit is contained in:
parent
63ba9ba38b
commit
b76f1a4d5f
1
changelog.d/12485.misc
Normal file
1
changelog.d/12485.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Add some type hints to datastore.
|
@ -15,12 +15,17 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
from typing import TYPE_CHECKING, List, Optional, Tuple, cast
|
||||||
|
|
||||||
from synapse.config.homeserver import HomeServerConfig
|
from synapse.config.homeserver import HomeServerConfig
|
||||||
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
|
from synapse.storage.database import (
|
||||||
|
DatabasePool,
|
||||||
|
LoggingDatabaseConnection,
|
||||||
|
LoggingTransaction,
|
||||||
|
)
|
||||||
from synapse.storage.databases.main.stats import UserSortOrder
|
from synapse.storage.databases.main.stats import UserSortOrder
|
||||||
from synapse.storage.engines import PostgresEngine
|
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
|
||||||
|
from synapse.storage.types import Cursor
|
||||||
from synapse.storage.util.id_generators import (
|
from synapse.storage.util.id_generators import (
|
||||||
IdGenerator,
|
IdGenerator,
|
||||||
MultiWriterIdGenerator,
|
MultiWriterIdGenerator,
|
||||||
@ -266,7 +271,9 @@ class DataStore(
|
|||||||
A tuple of a list of mappings from user to information and a count of total users.
|
A tuple of a list of mappings from user to information and a count of total users.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get_users_paginate_txn(txn):
|
def get_users_paginate_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> Tuple[List[JsonDict], int]:
|
||||||
filters = []
|
filters = []
|
||||||
args = [self.hs.config.server.server_name]
|
args = [self.hs.config.server.server_name]
|
||||||
|
|
||||||
@ -301,7 +308,7 @@ class DataStore(
|
|||||||
"""
|
"""
|
||||||
sql = "SELECT COUNT(*) as total_users " + sql_base
|
sql = "SELECT COUNT(*) as total_users " + sql_base
|
||||||
txn.execute(sql, args)
|
txn.execute(sql, args)
|
||||||
count = txn.fetchone()[0]
|
count = cast(Tuple[int], txn.fetchone())[0]
|
||||||
|
|
||||||
sql = f"""
|
sql = f"""
|
||||||
SELECT name, user_type, is_guest, admin, deactivated, shadow_banned,
|
SELECT name, user_type, is_guest, admin, deactivated, shadow_banned,
|
||||||
@ -338,7 +345,9 @@ class DataStore(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def check_database_before_upgrade(cur, database_engine, config: HomeServerConfig):
|
def check_database_before_upgrade(
|
||||||
|
cur: Cursor, database_engine: BaseDatabaseEngine, config: HomeServerConfig
|
||||||
|
) -> None:
|
||||||
"""Called before upgrading an existing database to check that it is broadly sane
|
"""Called before upgrading an existing database to check that it is broadly sane
|
||||||
compared with the configuration.
|
compared with the configuration.
|
||||||
"""
|
"""
|
||||||
|
@ -14,7 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern, Tuple
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern, Tuple, cast
|
||||||
|
|
||||||
from synapse.appservice import (
|
from synapse.appservice import (
|
||||||
ApplicationService,
|
ApplicationService,
|
||||||
@ -83,7 +83,7 @@ class ApplicationServiceWorkerStore(RoomMemberWorkerStore):
|
|||||||
txn.execute(
|
txn.execute(
|
||||||
"SELECT COALESCE(max(txn_id), 0) FROM application_services_txns"
|
"SELECT COALESCE(max(txn_id), 0) FROM application_services_txns"
|
||||||
)
|
)
|
||||||
return txn.fetchone()[0] # type: ignore
|
return cast(Tuple[int], txn.fetchone())[0]
|
||||||
|
|
||||||
self._as_txn_seq_gen = build_sequence_generator(
|
self._as_txn_seq_gen = build_sequence_generator(
|
||||||
db_conn,
|
db_conn,
|
||||||
|
@ -14,7 +14,17 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple, cast
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Collection,
|
||||||
|
Dict,
|
||||||
|
Iterable,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Set,
|
||||||
|
Tuple,
|
||||||
|
cast,
|
||||||
|
)
|
||||||
|
|
||||||
from synapse.logging import issue9533_logger
|
from synapse.logging import issue9533_logger
|
||||||
from synapse.logging.opentracing import log_kv, set_tag, trace
|
from synapse.logging.opentracing import log_kv, set_tag, trace
|
||||||
@ -118,7 +128,13 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
|||||||
prefilled_cache=device_outbox_prefill,
|
prefilled_cache=device_outbox_prefill,
|
||||||
)
|
)
|
||||||
|
|
||||||
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
def process_replication_rows(
|
||||||
|
self,
|
||||||
|
stream_name: str,
|
||||||
|
instance_name: str,
|
||||||
|
token: int,
|
||||||
|
rows: Iterable[ToDeviceStream.ToDeviceStreamRow],
|
||||||
|
) -> None:
|
||||||
if stream_name == ToDeviceStream.NAME:
|
if stream_name == ToDeviceStream.NAME:
|
||||||
# If replication is happening than postgres must be being used.
|
# If replication is happening than postgres must be being used.
|
||||||
assert isinstance(self._device_inbox_id_gen, MultiWriterIdGenerator)
|
assert isinstance(self._device_inbox_id_gen, MultiWriterIdGenerator)
|
||||||
@ -134,7 +150,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
|||||||
)
|
)
|
||||||
return super().process_replication_rows(stream_name, instance_name, token, rows)
|
return super().process_replication_rows(stream_name, instance_name, token, rows)
|
||||||
|
|
||||||
def get_to_device_stream_token(self):
|
def get_to_device_stream_token(self) -> int:
|
||||||
return self._device_inbox_id_gen.get_current_token()
|
return self._device_inbox_id_gen.get_current_token()
|
||||||
|
|
||||||
async def get_messages_for_user_devices(
|
async def get_messages_for_user_devices(
|
||||||
@ -301,7 +317,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
|||||||
if not user_ids_to_query:
|
if not user_ids_to_query:
|
||||||
return {}, to_stream_id
|
return {}, to_stream_id
|
||||||
|
|
||||||
def get_device_messages_txn(txn: LoggingTransaction):
|
def get_device_messages_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]:
|
||||||
# Build a query to select messages from any of the given devices that
|
# Build a query to select messages from any of the given devices that
|
||||||
# are between the given stream id bounds.
|
# are between the given stream id bounds.
|
||||||
|
|
||||||
@ -428,7 +446,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
|||||||
log_kv({"message": "No changes in cache since last check"})
|
log_kv({"message": "No changes in cache since last check"})
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
def delete_messages_for_device_txn(txn):
|
def delete_messages_for_device_txn(txn: LoggingTransaction) -> int:
|
||||||
sql = (
|
sql = (
|
||||||
"DELETE FROM device_inbox"
|
"DELETE FROM device_inbox"
|
||||||
" WHERE user_id = ? AND device_id = ?"
|
" WHERE user_id = ? AND device_id = ?"
|
||||||
@ -455,15 +473,14 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
|||||||
|
|
||||||
@trace
|
@trace
|
||||||
async def get_new_device_msgs_for_remote(
|
async def get_new_device_msgs_for_remote(
|
||||||
self, destination, last_stream_id, current_stream_id, limit
|
self, destination: str, last_stream_id: int, current_stream_id: int, limit: int
|
||||||
) -> Tuple[List[dict], int]:
|
) -> Tuple[List[JsonDict], int]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
destination(str): The name of the remote server.
|
destination: The name of the remote server.
|
||||||
last_stream_id(int|long): The last position of the device message stream
|
last_stream_id: The last position of the device message stream
|
||||||
that the server sent up to.
|
that the server sent up to.
|
||||||
current_stream_id(int|long): The current position of the device
|
current_stream_id: The current position of the device message stream.
|
||||||
message stream.
|
|
||||||
Returns:
|
Returns:
|
||||||
A list of messages for the device and where in the stream the messages got to.
|
A list of messages for the device and where in the stream the messages got to.
|
||||||
"""
|
"""
|
||||||
@ -485,7 +502,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
|||||||
return [], last_stream_id
|
return [], last_stream_id
|
||||||
|
|
||||||
@trace
|
@trace
|
||||||
def get_new_messages_for_remote_destination_txn(txn):
|
def get_new_messages_for_remote_destination_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> Tuple[List[JsonDict], int]:
|
||||||
sql = (
|
sql = (
|
||||||
"SELECT stream_id, messages_json FROM device_federation_outbox"
|
"SELECT stream_id, messages_json FROM device_federation_outbox"
|
||||||
" WHERE destination = ?"
|
" WHERE destination = ?"
|
||||||
@ -527,7 +546,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
|||||||
up_to_stream_id: Where to delete messages up to.
|
up_to_stream_id: Where to delete messages up to.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def delete_messages_for_remote_destination_txn(txn):
|
def delete_messages_for_remote_destination_txn(txn: LoggingTransaction) -> None:
|
||||||
sql = (
|
sql = (
|
||||||
"DELETE FROM device_federation_outbox"
|
"DELETE FROM device_federation_outbox"
|
||||||
" WHERE destination = ?"
|
" WHERE destination = ?"
|
||||||
@ -566,7 +585,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
|||||||
if last_id == current_id:
|
if last_id == current_id:
|
||||||
return [], current_id, False
|
return [], current_id, False
|
||||||
|
|
||||||
def get_all_new_device_messages_txn(txn):
|
def get_all_new_device_messages_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
|
||||||
# We limit like this as we might have multiple rows per stream_id, and
|
# We limit like this as we might have multiple rows per stream_id, and
|
||||||
# we want to make sure we always get all entries for any stream_id
|
# we want to make sure we always get all entries for any stream_id
|
||||||
# we return.
|
# we return.
|
||||||
@ -607,8 +628,8 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
|||||||
@trace
|
@trace
|
||||||
async def add_messages_to_device_inbox(
|
async def add_messages_to_device_inbox(
|
||||||
self,
|
self,
|
||||||
local_messages_by_user_then_device: dict,
|
local_messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
|
||||||
remote_messages_by_destination: dict,
|
remote_messages_by_destination: Dict[str, JsonDict],
|
||||||
) -> int:
|
) -> int:
|
||||||
"""Used to send messages from this server.
|
"""Used to send messages from this server.
|
||||||
|
|
||||||
@ -624,7 +645,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
|||||||
|
|
||||||
assert self._can_write_to_device
|
assert self._can_write_to_device
|
||||||
|
|
||||||
def add_messages_txn(txn, now_ms, stream_id):
|
def add_messages_txn(
|
||||||
|
txn: LoggingTransaction, now_ms: int, stream_id: int
|
||||||
|
) -> None:
|
||||||
# Add the local messages directly to the local inbox.
|
# Add the local messages directly to the local inbox.
|
||||||
self._add_messages_to_local_device_inbox_txn(
|
self._add_messages_to_local_device_inbox_txn(
|
||||||
txn, stream_id, local_messages_by_user_then_device
|
txn, stream_id, local_messages_by_user_then_device
|
||||||
@ -677,11 +700,16 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
|||||||
return self._device_inbox_id_gen.get_current_token()
|
return self._device_inbox_id_gen.get_current_token()
|
||||||
|
|
||||||
async def add_messages_from_remote_to_device_inbox(
|
async def add_messages_from_remote_to_device_inbox(
|
||||||
self, origin: str, message_id: str, local_messages_by_user_then_device: dict
|
self,
|
||||||
|
origin: str,
|
||||||
|
message_id: str,
|
||||||
|
local_messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
|
||||||
) -> int:
|
) -> int:
|
||||||
assert self._can_write_to_device
|
assert self._can_write_to_device
|
||||||
|
|
||||||
def add_messages_txn(txn, now_ms, stream_id):
|
def add_messages_txn(
|
||||||
|
txn: LoggingTransaction, now_ms: int, stream_id: int
|
||||||
|
) -> None:
|
||||||
# Check if we've already inserted a matching message_id for that
|
# Check if we've already inserted a matching message_id for that
|
||||||
# origin. This can happen if the origin doesn't receive our
|
# origin. This can happen if the origin doesn't receive our
|
||||||
# acknowledgement from the first time we received the message.
|
# acknowledgement from the first time we received the message.
|
||||||
@ -727,8 +755,11 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
|||||||
return stream_id
|
return stream_id
|
||||||
|
|
||||||
def _add_messages_to_local_device_inbox_txn(
|
def _add_messages_to_local_device_inbox_txn(
|
||||||
self, txn, stream_id, messages_by_user_then_device
|
self,
|
||||||
):
|
txn: LoggingTransaction,
|
||||||
|
stream_id: int,
|
||||||
|
messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
|
||||||
|
) -> None:
|
||||||
assert self._can_write_to_device
|
assert self._can_write_to_device
|
||||||
|
|
||||||
local_by_user_then_device = {}
|
local_by_user_then_device = {}
|
||||||
@ -840,8 +871,10 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
|
|||||||
self._remove_dead_devices_from_device_inbox,
|
self._remove_dead_devices_from_device_inbox,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _background_drop_index_device_inbox(self, progress, batch_size):
|
async def _background_drop_index_device_inbox(
|
||||||
def reindex_txn(conn):
|
self, progress: JsonDict, batch_size: int
|
||||||
|
) -> int:
|
||||||
|
def reindex_txn(conn: LoggingDatabaseConnection) -> None:
|
||||||
txn = conn.cursor()
|
txn = conn.cursor()
|
||||||
txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
|
txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
|
||||||
txn.close()
|
txn.close()
|
||||||
|
@ -25,6 +25,7 @@ from typing import (
|
|||||||
Optional,
|
Optional,
|
||||||
Set,
|
Set,
|
||||||
Tuple,
|
Tuple,
|
||||||
|
cast,
|
||||||
)
|
)
|
||||||
|
|
||||||
from synapse.api.errors import Codes, StoreError
|
from synapse.api.errors import Codes, StoreError
|
||||||
@ -136,7 +137,9 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
Number of devices of this users.
|
Number of devices of this users.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def count_devices_by_users_txn(txn, user_ids):
|
def count_devices_by_users_txn(
|
||||||
|
txn: LoggingTransaction, user_ids: List[str]
|
||||||
|
) -> int:
|
||||||
sql = """
|
sql = """
|
||||||
SELECT count(*)
|
SELECT count(*)
|
||||||
FROM devices
|
FROM devices
|
||||||
@ -149,7 +152,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
)
|
)
|
||||||
|
|
||||||
txn.execute(sql + clause, args)
|
txn.execute(sql + clause, args)
|
||||||
return txn.fetchone()[0]
|
return cast(Tuple[int], txn.fetchone())[0]
|
||||||
|
|
||||||
if not user_ids:
|
if not user_ids:
|
||||||
return 0
|
return 0
|
||||||
@ -468,7 +471,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
"""
|
"""
|
||||||
txn.execute(sql, (destination, from_stream_id, now_stream_id, limit))
|
txn.execute(sql, (destination, from_stream_id, now_stream_id, limit))
|
||||||
|
|
||||||
return list(txn)
|
return cast(List[Tuple[str, str, int, Optional[str]]], txn.fetchall())
|
||||||
|
|
||||||
async def _get_device_update_edus_by_remote(
|
async def _get_device_update_edus_by_remote(
|
||||||
self,
|
self,
|
||||||
@ -549,7 +552,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
async def _get_last_device_update_for_remote_user(
|
async def _get_last_device_update_for_remote_user(
|
||||||
self, destination: str, user_id: str, from_stream_id: int
|
self, destination: str, user_id: str, from_stream_id: int
|
||||||
) -> int:
|
) -> int:
|
||||||
def f(txn):
|
def f(txn: LoggingTransaction) -> int:
|
||||||
prev_sent_id_sql = """
|
prev_sent_id_sql = """
|
||||||
SELECT coalesce(max(stream_id), 0) as stream_id
|
SELECT coalesce(max(stream_id), 0) as stream_id
|
||||||
FROM device_lists_outbound_last_success
|
FROM device_lists_outbound_last_success
|
||||||
@ -767,7 +770,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
if not user_ids_to_check:
|
if not user_ids_to_check:
|
||||||
return set()
|
return set()
|
||||||
|
|
||||||
def _get_users_whose_devices_changed_txn(txn):
|
def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]:
|
||||||
changes = set()
|
changes = set()
|
||||||
|
|
||||||
stream_id_where_clause = "stream_id > ?"
|
stream_id_where_clause = "stream_id > ?"
|
||||||
@ -966,7 +969,9 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None:
|
async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None:
|
||||||
"""Mark that we no longer track device lists for remote user."""
|
"""Mark that we no longer track device lists for remote user."""
|
||||||
|
|
||||||
def _mark_remote_user_device_list_as_unsubscribed_txn(txn):
|
def _mark_remote_user_device_list_as_unsubscribed_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> None:
|
||||||
self.db_pool.simple_delete_txn(
|
self.db_pool.simple_delete_txn(
|
||||||
txn,
|
txn,
|
||||||
table="device_lists_remote_extremeties",
|
table="device_lists_remote_extremeties",
|
||||||
@ -1004,7 +1009,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _store_dehydrated_device_txn(
|
def _store_dehydrated_device_txn(
|
||||||
self, txn, user_id: str, device_id: str, device_data: str
|
self, txn: LoggingTransaction, user_id: str, device_id: str, device_data: str
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
old_device_id = self.db_pool.simple_select_one_onecol_txn(
|
old_device_id = self.db_pool.simple_select_one_onecol_txn(
|
||||||
txn,
|
txn,
|
||||||
@ -1081,7 +1086,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
"""
|
"""
|
||||||
yesterday = self._clock.time_msec() - prune_age
|
yesterday = self._clock.time_msec() - prune_age
|
||||||
|
|
||||||
def _prune_txn(txn):
|
def _prune_txn(txn: LoggingTransaction) -> None:
|
||||||
# look for (user, destination) pairs which have an update older than
|
# look for (user, destination) pairs which have an update older than
|
||||||
# the cutoff.
|
# the cutoff.
|
||||||
#
|
#
|
||||||
@ -1204,8 +1209,10 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
|
|||||||
"drop_device_lists_outbound_last_success_non_unique_idx",
|
"drop_device_lists_outbound_last_success_non_unique_idx",
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
|
async def _drop_device_list_streams_non_unique_indexes(
|
||||||
def f(conn):
|
self, progress: JsonDict, batch_size: int
|
||||||
|
) -> int:
|
||||||
|
def f(conn: LoggingDatabaseConnection) -> None:
|
||||||
txn = conn.cursor()
|
txn = conn.cursor()
|
||||||
txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id")
|
txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id")
|
||||||
txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
|
txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
|
||||||
@ -1217,7 +1224,9 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
|
|||||||
)
|
)
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
async def _remove_duplicate_outbound_pokes(self, progress, batch_size):
|
async def _remove_duplicate_outbound_pokes(
|
||||||
|
self, progress: JsonDict, batch_size: int
|
||||||
|
) -> int:
|
||||||
# for some reason, we have accumulated duplicate entries in
|
# for some reason, we have accumulated duplicate entries in
|
||||||
# device_lists_outbound_pokes, which makes prune_outbound_device_list_pokes less
|
# device_lists_outbound_pokes, which makes prune_outbound_device_list_pokes less
|
||||||
# efficient.
|
# efficient.
|
||||||
@ -1230,7 +1239,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
|
|||||||
{"stream_id": 0, "destination": "", "user_id": "", "device_id": ""},
|
{"stream_id": 0, "destination": "", "user_id": "", "device_id": ""},
|
||||||
)
|
)
|
||||||
|
|
||||||
def _txn(txn):
|
def _txn(txn: LoggingTransaction) -> int:
|
||||||
clause, args = make_tuple_comparison_clause(
|
clause, args = make_tuple_comparison_clause(
|
||||||
[(x, last_row[x]) for x in KEY_COLS]
|
[(x, last_row[x]) for x in KEY_COLS]
|
||||||
)
|
)
|
||||||
@ -1602,7 +1611,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||||||
|
|
||||||
context = get_active_span_text_map()
|
context = get_active_span_text_map()
|
||||||
|
|
||||||
def add_device_changes_txn(txn, stream_ids):
|
def add_device_changes_txn(
|
||||||
|
txn: LoggingTransaction, stream_ids: List[int]
|
||||||
|
) -> None:
|
||||||
self._add_device_change_to_stream_txn(
|
self._add_device_change_to_stream_txn(
|
||||||
txn,
|
txn,
|
||||||
user_id,
|
user_id,
|
||||||
@ -1635,8 +1646,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||||||
txn: LoggingTransaction,
|
txn: LoggingTransaction,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
device_ids: Collection[str],
|
device_ids: Collection[str],
|
||||||
stream_ids: List[str],
|
stream_ids: List[int],
|
||||||
):
|
) -> None:
|
||||||
txn.call_after(
|
txn.call_after(
|
||||||
self._device_list_stream_cache.entity_has_changed,
|
self._device_list_stream_cache.entity_has_changed,
|
||||||
user_id,
|
user_id,
|
||||||
@ -1720,7 +1731,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||||||
user_id: str,
|
user_id: str,
|
||||||
device_ids: Iterable[str],
|
device_ids: Iterable[str],
|
||||||
room_ids: Collection[str],
|
room_ids: Collection[str],
|
||||||
stream_ids: List[str],
|
stream_ids: List[int],
|
||||||
context: Dict[str, str],
|
context: Dict[str, str],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Record the user in the room has updated their device."""
|
"""Record the user in the room has updated their device."""
|
||||||
@ -1775,7 +1786,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||||||
LIMIT ?
|
LIMIT ?
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get_uncoverted_outbound_room_pokes_txn(txn):
|
def get_uncoverted_outbound_room_pokes_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]:
|
||||||
txn.execute(sql, (limit,))
|
txn.execute(sql, (limit,))
|
||||||
|
|
||||||
return [
|
return [
|
||||||
@ -1808,7 +1821,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||||||
Marks the associated row in `device_lists_changes_in_room` as handled.
|
Marks the associated row in `device_lists_changes_in_room` as handled.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def add_device_list_outbound_pokes_txn(txn, stream_ids: List[int]):
|
def add_device_list_outbound_pokes_txn(
|
||||||
|
txn: LoggingTransaction, stream_ids: List[int]
|
||||||
|
) -> None:
|
||||||
if hosts:
|
if hosts:
|
||||||
self._add_device_outbound_poke_to_stream_txn(
|
self._add_device_outbound_poke_to_stream_txn(
|
||||||
txn,
|
txn,
|
||||||
|
@ -522,7 +522,9 @@ class GroupServerWorkerStore(SQLBaseStore):
|
|||||||
desc="get_joined_groups",
|
desc="get_joined_groups",
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_all_groups_for_user(self, user_id, now_token) -> List[JsonDict]:
|
async def get_all_groups_for_user(
|
||||||
|
self, user_id: str, now_token: int
|
||||||
|
) -> List[JsonDict]:
|
||||||
def _get_all_groups_for_user_txn(txn: LoggingTransaction) -> List[JsonDict]:
|
def _get_all_groups_for_user_txn(txn: LoggingTransaction) -> List[JsonDict]:
|
||||||
sql = """
|
sql = """
|
||||||
SELECT group_id, type, membership, u.content
|
SELECT group_id, type, membership, u.content
|
||||||
|
@ -15,11 +15,12 @@
|
|||||||
|
|
||||||
import itertools
|
import itertools
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, Iterable, List, Optional, Tuple
|
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
from signedjson.key import decode_verify_key_bytes
|
from signedjson.key import decode_verify_key_bytes
|
||||||
|
|
||||||
from synapse.storage._base import SQLBaseStore
|
from synapse.storage._base import SQLBaseStore
|
||||||
|
from synapse.storage.database import LoggingTransaction
|
||||||
from synapse.storage.keys import FetchKeyResult
|
from synapse.storage.keys import FetchKeyResult
|
||||||
from synapse.storage.types import Cursor
|
from synapse.storage.types import Cursor
|
||||||
from synapse.util.caches.descriptors import cached, cachedList
|
from synapse.util.caches.descriptors import cached, cachedList
|
||||||
@ -35,7 +36,9 @@ class KeyStore(SQLBaseStore):
|
|||||||
"""Persistence for signature verification keys"""
|
"""Persistence for signature verification keys"""
|
||||||
|
|
||||||
@cached()
|
@cached()
|
||||||
def _get_server_verify_key(self, server_name_and_key_id):
|
def _get_server_verify_key(
|
||||||
|
self, server_name_and_key_id: Tuple[str, str]
|
||||||
|
) -> FetchKeyResult:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@cachedList(
|
@cachedList(
|
||||||
@ -179,19 +182,21 @@ class KeyStore(SQLBaseStore):
|
|||||||
|
|
||||||
async def get_server_keys_json(
|
async def get_server_keys_json(
|
||||||
self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]]
|
self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]]
|
||||||
) -> Dict[Tuple[str, Optional[str], Optional[str]], List[dict]]:
|
) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]:
|
||||||
"""Retrieve the key json for a list of server_keys and key ids.
|
"""Retrieve the key json for a list of server_keys and key ids.
|
||||||
If no keys are found for a given server, key_id and source then
|
If no keys are found for a given server, key_id and source then
|
||||||
that server, key_id, and source triplet entry will be an empty list.
|
that server, key_id, and source triplet entry will be an empty list.
|
||||||
The JSON is returned as a byte array so that it can be efficiently
|
The JSON is returned as a byte array so that it can be efficiently
|
||||||
used in an HTTP response.
|
used in an HTTP response.
|
||||||
Args:
|
Args:
|
||||||
server_keys (list): List of (server_name, key_id, source) triplets.
|
server_keys: List of (server_name, key_id, source) triplets.
|
||||||
Returns:
|
Returns:
|
||||||
A mapping from (server_name, key_id, source) triplets to a list of dicts
|
A mapping from (server_name, key_id, source) triplets to a list of dicts
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _get_server_keys_json_txn(txn):
|
def _get_server_keys_json_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]:
|
||||||
results = {}
|
results = {}
|
||||||
for server_name, key_id, from_server in server_keys:
|
for server_name, key_id, from_server in server_keys:
|
||||||
keyvalues = {"server_name": server_name}
|
keyvalues = {"server_name": server_name}
|
||||||
|
@ -388,7 +388,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||||||
return await self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
|
return await self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
|
||||||
|
|
||||||
async def store_url_cache(
|
async def store_url_cache(
|
||||||
self, url, response_code, etag, expires_ts, og, media_id, download_ts
|
self,
|
||||||
|
url: str,
|
||||||
|
response_code: int,
|
||||||
|
etag: Optional[str],
|
||||||
|
expires_ts: int,
|
||||||
|
og: Optional[str],
|
||||||
|
media_id: str,
|
||||||
|
download_ts: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
await self.db_pool.simple_insert(
|
await self.db_pool.simple_insert(
|
||||||
"local_media_repository_url_cache",
|
"local_media_repository_url_cache",
|
||||||
@ -441,7 +448,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def get_cached_remote_media(
|
async def get_cached_remote_media(
|
||||||
self, origin, media_id: str
|
self, origin: str, media_id: str
|
||||||
) -> Optional[Dict[str, Any]]:
|
) -> Optional[Dict[str, Any]]:
|
||||||
return await self.db_pool.simple_select_one(
|
return await self.db_pool.simple_select_one(
|
||||||
"remote_media_cache",
|
"remote_media_cache",
|
||||||
@ -608,7 +615,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def delete_remote_media(self, media_origin: str, media_id: str) -> None:
|
async def delete_remote_media(self, media_origin: str, media_id: str) -> None:
|
||||||
def delete_remote_media_txn(txn):
|
def delete_remote_media_txn(txn: LoggingTransaction) -> None:
|
||||||
self.db_pool.simple_delete_txn(
|
self.db_pool.simple_delete_txn(
|
||||||
txn,
|
txn,
|
||||||
"remote_media_cache",
|
"remote_media_cache",
|
||||||
|
@ -12,7 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple, cast
|
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Tuple, cast
|
||||||
|
|
||||||
from synapse.api.presence import PresenceState, UserPresenceState
|
from synapse.api.presence import PresenceState, UserPresenceState
|
||||||
from synapse.replication.tcp.streams import PresenceStream
|
from synapse.replication.tcp.streams import PresenceStream
|
||||||
@ -103,7 +103,9 @@ class PresenceStore(PresenceBackgroundUpdateStore):
|
|||||||
prefilled_cache=presence_cache_prefill,
|
prefilled_cache=presence_cache_prefill,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def update_presence(self, presence_states) -> Tuple[int, int]:
|
async def update_presence(
|
||||||
|
self, presence_states: List[UserPresenceState]
|
||||||
|
) -> Tuple[int, int]:
|
||||||
assert self._can_persist_presence
|
assert self._can_persist_presence
|
||||||
|
|
||||||
stream_ordering_manager = self._presence_id_gen.get_next_mult(
|
stream_ordering_manager = self._presence_id_gen.get_next_mult(
|
||||||
@ -121,7 +123,10 @@ class PresenceStore(PresenceBackgroundUpdateStore):
|
|||||||
return stream_orderings[-1], self._presence_id_gen.get_current_token()
|
return stream_orderings[-1], self._presence_id_gen.get_current_token()
|
||||||
|
|
||||||
def _update_presence_txn(
|
def _update_presence_txn(
|
||||||
self, txn: LoggingTransaction, stream_orderings, presence_states
|
self,
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
stream_orderings: List[int],
|
||||||
|
presence_states: List[UserPresenceState],
|
||||||
) -> None:
|
) -> None:
|
||||||
for stream_id, state in zip(stream_orderings, presence_states):
|
for stream_id, state in zip(stream_orderings, presence_states):
|
||||||
txn.call_after(
|
txn.call_after(
|
||||||
@ -405,7 +410,13 @@ class PresenceStore(PresenceBackgroundUpdateStore):
|
|||||||
self._presence_on_startup = []
|
self._presence_on_startup = []
|
||||||
return active_on_startup
|
return active_on_startup
|
||||||
|
|
||||||
def process_replication_rows(self, stream_name, instance_name, token, rows) -> None:
|
def process_replication_rows(
|
||||||
|
self,
|
||||||
|
stream_name: str,
|
||||||
|
instance_name: str,
|
||||||
|
token: int,
|
||||||
|
rows: Iterable[Any],
|
||||||
|
) -> None:
|
||||||
if stream_name == PresenceStream.NAME:
|
if stream_name == PresenceStream.NAME:
|
||||||
self._presence_id_gen.advance(instance_name, token)
|
self._presence_id_gen.advance(instance_name, token)
|
||||||
for row in rows:
|
for row in rows:
|
||||||
|
@ -14,11 +14,25 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Tuple
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Dict,
|
||||||
|
Iterable,
|
||||||
|
Iterator,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Tuple,
|
||||||
|
cast,
|
||||||
|
)
|
||||||
|
|
||||||
from synapse.push import PusherConfig, ThrottleParams
|
from synapse.push import PusherConfig, ThrottleParams
|
||||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||||
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
|
from synapse.storage.database import (
|
||||||
|
DatabasePool,
|
||||||
|
LoggingDatabaseConnection,
|
||||||
|
LoggingTransaction,
|
||||||
|
)
|
||||||
from synapse.storage.util.id_generators import StreamIdGenerator
|
from synapse.storage.util.id_generators import StreamIdGenerator
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
from synapse.util import json_encoder
|
from synapse.util import json_encoder
|
||||||
@ -117,7 +131,7 @@ class PusherWorkerStore(SQLBaseStore):
|
|||||||
return self._decode_pushers_rows(ret)
|
return self._decode_pushers_rows(ret)
|
||||||
|
|
||||||
async def get_all_pushers(self) -> Iterator[PusherConfig]:
|
async def get_all_pushers(self) -> Iterator[PusherConfig]:
|
||||||
def get_pushers(txn):
|
def get_pushers(txn: LoggingTransaction) -> Iterator[PusherConfig]:
|
||||||
txn.execute("SELECT * FROM pushers")
|
txn.execute("SELECT * FROM pushers")
|
||||||
rows = self.db_pool.cursor_to_dict(txn)
|
rows = self.db_pool.cursor_to_dict(txn)
|
||||||
|
|
||||||
@ -152,7 +166,9 @@ class PusherWorkerStore(SQLBaseStore):
|
|||||||
if last_id == current_id:
|
if last_id == current_id:
|
||||||
return [], current_id, False
|
return [], current_id, False
|
||||||
|
|
||||||
def get_all_updated_pushers_rows_txn(txn):
|
def get_all_updated_pushers_rows_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
|
||||||
sql = """
|
sql = """
|
||||||
SELECT id, user_name, app_id, pushkey
|
SELECT id, user_name, app_id, pushkey
|
||||||
FROM pushers
|
FROM pushers
|
||||||
@ -160,10 +176,13 @@ class PusherWorkerStore(SQLBaseStore):
|
|||||||
ORDER BY id ASC LIMIT ?
|
ORDER BY id ASC LIMIT ?
|
||||||
"""
|
"""
|
||||||
txn.execute(sql, (last_id, current_id, limit))
|
txn.execute(sql, (last_id, current_id, limit))
|
||||||
updates = [
|
updates = cast(
|
||||||
(stream_id, (user_name, app_id, pushkey, False))
|
List[Tuple[int, tuple]],
|
||||||
for stream_id, user_name, app_id, pushkey in txn
|
[
|
||||||
]
|
(stream_id, (user_name, app_id, pushkey, False))
|
||||||
|
for stream_id, user_name, app_id, pushkey in txn
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
sql = """
|
sql = """
|
||||||
SELECT stream_id, user_id, app_id, pushkey
|
SELECT stream_id, user_id, app_id, pushkey
|
||||||
@ -192,12 +211,12 @@ class PusherWorkerStore(SQLBaseStore):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@cached(num_args=1, max_entries=15000)
|
@cached(num_args=1, max_entries=15000)
|
||||||
async def get_if_user_has_pusher(self, user_id: str):
|
async def get_if_user_has_pusher(self, user_id: str) -> None:
|
||||||
# This only exists for the cachedList decorator
|
# This only exists for the cachedList decorator
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
async def update_pusher_last_stream_ordering(
|
async def update_pusher_last_stream_ordering(
|
||||||
self, app_id, pushkey, user_id, last_stream_ordering
|
self, app_id: str, pushkey: str, user_id: str, last_stream_ordering: int
|
||||||
) -> None:
|
) -> None:
|
||||||
await self.db_pool.simple_update_one(
|
await self.db_pool.simple_update_one(
|
||||||
"pushers",
|
"pushers",
|
||||||
@ -291,7 +310,7 @@ class PusherWorkerStore(SQLBaseStore):
|
|||||||
|
|
||||||
last_user = progress.get("last_user", "")
|
last_user = progress.get("last_user", "")
|
||||||
|
|
||||||
def _delete_pushers(txn) -> int:
|
def _delete_pushers(txn: LoggingTransaction) -> int:
|
||||||
|
|
||||||
sql = """
|
sql = """
|
||||||
SELECT name FROM users
|
SELECT name FROM users
|
||||||
@ -339,7 +358,7 @@ class PusherWorkerStore(SQLBaseStore):
|
|||||||
|
|
||||||
last_pusher = progress.get("last_pusher", 0)
|
last_pusher = progress.get("last_pusher", 0)
|
||||||
|
|
||||||
def _delete_pushers(txn) -> int:
|
def _delete_pushers(txn: LoggingTransaction) -> int:
|
||||||
|
|
||||||
sql = """
|
sql = """
|
||||||
SELECT p.id, access_token FROM pushers AS p
|
SELECT p.id, access_token FROM pushers AS p
|
||||||
@ -396,7 +415,7 @@ class PusherWorkerStore(SQLBaseStore):
|
|||||||
|
|
||||||
last_pusher = progress.get("last_pusher", 0)
|
last_pusher = progress.get("last_pusher", 0)
|
||||||
|
|
||||||
def _delete_pushers(txn) -> int:
|
def _delete_pushers(txn: LoggingTransaction) -> int:
|
||||||
|
|
||||||
sql = """
|
sql = """
|
||||||
SELECT p.id, p.user_name, p.app_id, p.pushkey
|
SELECT p.id, p.user_name, p.app_id, p.pushkey
|
||||||
@ -502,7 +521,7 @@ class PusherStore(PusherWorkerStore):
|
|||||||
async def delete_pusher_by_app_id_pushkey_user_id(
|
async def delete_pusher_by_app_id_pushkey_user_id(
|
||||||
self, app_id: str, pushkey: str, user_id: str
|
self, app_id: str, pushkey: str, user_id: str
|
||||||
) -> None:
|
) -> None:
|
||||||
def delete_pusher_txn(txn, stream_id):
|
def delete_pusher_txn(txn: LoggingTransaction, stream_id: int) -> None:
|
||||||
self._invalidate_cache_and_stream( # type: ignore[attr-defined]
|
self._invalidate_cache_and_stream( # type: ignore[attr-defined]
|
||||||
txn, self.get_if_user_has_pusher, (user_id,)
|
txn, self.get_if_user_has_pusher, (user_id,)
|
||||||
)
|
)
|
||||||
@ -547,7 +566,7 @@ class PusherStore(PusherWorkerStore):
|
|||||||
# account.
|
# account.
|
||||||
pushers = list(await self.get_pushers_by_user_id(user_id))
|
pushers = list(await self.get_pushers_by_user_id(user_id))
|
||||||
|
|
||||||
def delete_pushers_txn(txn, stream_ids):
|
def delete_pushers_txn(txn: LoggingTransaction, stream_ids: List[int]) -> None:
|
||||||
self._invalidate_cache_and_stream( # type: ignore[attr-defined]
|
self._invalidate_cache_and_stream( # type: ignore[attr-defined]
|
||||||
txn, self.get_if_user_has_pusher, (user_id,)
|
txn, self.get_if_user_has_pusher, (user_id,)
|
||||||
)
|
)
|
||||||
|
@ -370,10 +370,10 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||||||
|
|
||||||
def _update_state_for_partial_state_event_txn(
|
def _update_state_for_partial_state_event_txn(
|
||||||
self,
|
self,
|
||||||
txn,
|
txn: LoggingTransaction,
|
||||||
event: EventBase,
|
event: EventBase,
|
||||||
context: EventContext,
|
context: EventContext,
|
||||||
):
|
) -> None:
|
||||||
# we shouldn't have any outliers here
|
# we shouldn't have any outliers here
|
||||||
assert not event.internal_metadata.is_outlier()
|
assert not event.internal_metadata.is_outlier()
|
||||||
|
|
||||||
|
@ -131,7 +131,7 @@ class UIAuthWorkerStore(SQLBaseStore):
|
|||||||
session_id: str,
|
session_id: str,
|
||||||
stage_type: str,
|
stage_type: str,
|
||||||
result: Union[str, bool, JsonDict],
|
result: Union[str, bool, JsonDict],
|
||||||
):
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Mark a session stage as completed.
|
Mark a session stage as completed.
|
||||||
|
|
||||||
@ -200,7 +200,9 @@ class UIAuthWorkerStore(SQLBaseStore):
|
|||||||
desc="set_ui_auth_client_dict",
|
desc="set_ui_auth_client_dict",
|
||||||
)
|
)
|
||||||
|
|
||||||
async def set_ui_auth_session_data(self, session_id: str, key: str, value: Any):
|
async def set_ui_auth_session_data(
|
||||||
|
self, session_id: str, key: str, value: Any
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Store a key-value pair into the sessions data associated with this
|
Store a key-value pair into the sessions data associated with this
|
||||||
request. This data is stored server-side and cannot be modified by
|
request. This data is stored server-side and cannot be modified by
|
||||||
@ -223,7 +225,7 @@ class UIAuthWorkerStore(SQLBaseStore):
|
|||||||
|
|
||||||
def _set_ui_auth_session_data_txn(
|
def _set_ui_auth_session_data_txn(
|
||||||
self, txn: LoggingTransaction, session_id: str, key: str, value: Any
|
self, txn: LoggingTransaction, session_id: str, key: str, value: Any
|
||||||
):
|
) -> None:
|
||||||
# Get the current value.
|
# Get the current value.
|
||||||
result = cast(
|
result = cast(
|
||||||
Dict[str, Any],
|
Dict[str, Any],
|
||||||
@ -275,7 +277,7 @@ class UIAuthWorkerStore(SQLBaseStore):
|
|||||||
session_id: str,
|
session_id: str,
|
||||||
user_agent: str,
|
user_agent: str,
|
||||||
ip: str,
|
ip: str,
|
||||||
):
|
) -> None:
|
||||||
"""Add the given user agent / IP to the tracking table"""
|
"""Add the given user agent / IP to the tracking table"""
|
||||||
await self.db_pool.simple_upsert(
|
await self.db_pool.simple_upsert(
|
||||||
table="ui_auth_sessions_ips",
|
table="ui_auth_sessions_ips",
|
||||||
@ -318,7 +320,7 @@ class UIAuthWorkerStore(SQLBaseStore):
|
|||||||
|
|
||||||
def _delete_old_ui_auth_sessions_txn(
|
def _delete_old_ui_auth_sessions_txn(
|
||||||
self, txn: LoggingTransaction, expiration_time: int
|
self, txn: LoggingTransaction, expiration_time: int
|
||||||
):
|
) -> None:
|
||||||
# Get the expired sessions.
|
# Get the expired sessions.
|
||||||
sql = "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?"
|
sql = "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?"
|
||||||
txn.execute(sql, [expiration_time])
|
txn.execute(sql, [expiration_time])
|
||||||
|
Loading…
Reference in New Issue
Block a user