mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2025-01-16 22:57:06 -05:00
Improve typing in user_directory files (#10891)
* Improve typing in user_directory files This makes the user_directory.py in storage pass most of mypy's checks (including `no-untyped-defs`). Unfortunately that file is in the tangled web of Store class inheritance so doesn't pass mypy at the moment. The handlers directory has already been mypyed. Co-authored-by: reivilibre <olivier@librepush.net>
This commit is contained in:
parent
e704cc2a48
commit
7f3352743e
1
changelog.d/10891.misc
Normal file
1
changelog.d/10891.misc
Normal file
@ -0,0 +1 @@
|
||||
Improve type hinting in the user directory code.
|
2
mypy.ini
2
mypy.ini
@ -85,9 +85,11 @@ files =
|
||||
tests/handlers/test_room_summary.py,
|
||||
tests/handlers/test_send_email.py,
|
||||
tests/handlers/test_sync.py,
|
||||
tests/handlers/test_user_directory.py,
|
||||
tests/rest/client/test_login.py,
|
||||
tests/rest/client/test_auth.py,
|
||||
tests/storage/test_state.py,
|
||||
tests/storage/test_user_directory.py,
|
||||
tests/util/test_itertools.py,
|
||||
tests/util/test_stream_change_cache.py
|
||||
|
||||
|
@ -14,14 +14,28 @@
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Dict, Iterable, Optional, Set, Tuple
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
cast,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules
|
||||
from synapse.storage.database import DatabasePool
|
||||
from synapse.storage.database import DatabasePool, LoggingTransaction
|
||||
from synapse.storage.databases.main.state import StateFilter
|
||||
from synapse.storage.databases.main.state_deltas import StateDeltasStore
|
||||
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
|
||||
from synapse.types import get_domain_from_id, get_localpart_from_id
|
||||
from synapse.storage.types import Connection
|
||||
from synapse.types import JsonDict, get_domain_from_id, get_localpart_from_id
|
||||
from synapse.util.caches.descriptors import cached
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -36,7 +50,12 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
||||
# add_users_who_share_private_rooms?
|
||||
SHARE_PRIVATE_WORKING_SET = 500
|
||||
|
||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||
def __init__(
|
||||
self,
|
||||
database: DatabasePool,
|
||||
db_conn: Connection,
|
||||
hs: "HomeServer",
|
||||
):
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
||||
self.server_name = hs.hostname
|
||||
@ -57,10 +76,12 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
||||
"populate_user_directory_cleanup", self._populate_user_directory_cleanup
|
||||
)
|
||||
|
||||
async def _populate_user_directory_createtables(self, progress, batch_size):
|
||||
async def _populate_user_directory_createtables(
|
||||
self, progress: JsonDict, batch_size: int
|
||||
) -> int:
|
||||
|
||||
# Get all the rooms that we want to process.
|
||||
def _make_staging_area(txn):
|
||||
def _make_staging_area(txn: LoggingTransaction) -> None:
|
||||
sql = (
|
||||
"CREATE TABLE IF NOT EXISTS "
|
||||
+ TEMP_TABLE
|
||||
@ -110,16 +131,20 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
||||
)
|
||||
return 1
|
||||
|
||||
async def _populate_user_directory_cleanup(self, progress, batch_size):
|
||||
async def _populate_user_directory_cleanup(
|
||||
self,
|
||||
progress: JsonDict,
|
||||
batch_size: int,
|
||||
) -> int:
|
||||
"""
|
||||
Update the user directory stream position, then clean up the old tables.
|
||||
"""
|
||||
position = await self.db_pool.simple_select_one_onecol(
|
||||
TEMP_TABLE + "_position", None, "position"
|
||||
TEMP_TABLE + "_position", {}, "position"
|
||||
)
|
||||
await self.update_user_directory_stream_pos(position)
|
||||
|
||||
def _delete_staging_area(txn):
|
||||
def _delete_staging_area(txn: LoggingTransaction) -> None:
|
||||
txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_rooms")
|
||||
txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_users")
|
||||
txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position")
|
||||
@ -133,18 +158,32 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
||||
)
|
||||
return 1
|
||||
|
||||
async def _populate_user_directory_process_rooms(self, progress, batch_size):
|
||||
async def _populate_user_directory_process_rooms(
|
||||
self, progress: JsonDict, batch_size: int
|
||||
) -> int:
|
||||
"""
|
||||
Rescan the state of all rooms so we can track
|
||||
|
||||
- who's in a public room;
|
||||
- which local users share a private room with other users (local
|
||||
and remote); and
|
||||
- who should be in the user_directory.
|
||||
|
||||
Args:
|
||||
progress (dict)
|
||||
batch_size (int): Maximum number of state events to process
|
||||
per cycle.
|
||||
|
||||
Returns:
|
||||
number of events processed.
|
||||
"""
|
||||
# If we don't have progress filed, delete everything.
|
||||
if not progress:
|
||||
await self.delete_all_from_user_dir()
|
||||
|
||||
def _get_next_batch(txn):
|
||||
def _get_next_batch(
|
||||
txn: LoggingTransaction,
|
||||
) -> Optional[Sequence[Tuple[str, int]]]:
|
||||
# Only fetch 250 rooms, so we don't fetch too many at once, even
|
||||
# if those 250 rooms have less than batch_size state events.
|
||||
sql = """
|
||||
@ -155,7 +194,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
||||
TEMP_TABLE + "_rooms",
|
||||
)
|
||||
txn.execute(sql)
|
||||
rooms_to_work_on = txn.fetchall()
|
||||
rooms_to_work_on = cast(List[Tuple[str, int]], txn.fetchall())
|
||||
|
||||
if not rooms_to_work_on:
|
||||
return None
|
||||
@ -163,7 +202,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
||||
# Get how many are left to process, so we can give status on how
|
||||
# far we are in processing
|
||||
txn.execute("SELECT COUNT(*) FROM " + TEMP_TABLE + "_rooms")
|
||||
progress["remaining"] = txn.fetchone()[0]
|
||||
result = txn.fetchone()
|
||||
assert result is not None
|
||||
progress["remaining"] = result[0]
|
||||
|
||||
return rooms_to_work_on
|
||||
|
||||
@ -261,29 +302,33 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
||||
|
||||
return processed_event_count
|
||||
|
||||
async def _populate_user_directory_process_users(self, progress, batch_size):
|
||||
async def _populate_user_directory_process_users(
|
||||
self, progress: JsonDict, batch_size: int
|
||||
) -> int:
|
||||
"""
|
||||
Add all local users to the user directory.
|
||||
"""
|
||||
|
||||
def _get_next_batch(txn):
|
||||
def _get_next_batch(txn: LoggingTransaction) -> Optional[List[str]]:
|
||||
sql = "SELECT user_id FROM %s LIMIT %s" % (
|
||||
TEMP_TABLE + "_users",
|
||||
str(batch_size),
|
||||
)
|
||||
txn.execute(sql)
|
||||
users_to_work_on = txn.fetchall()
|
||||
user_result = cast(List[Tuple[str]], txn.fetchall())
|
||||
|
||||
if not users_to_work_on:
|
||||
if not user_result:
|
||||
return None
|
||||
|
||||
users_to_work_on = [x[0] for x in users_to_work_on]
|
||||
users_to_work_on = [x[0] for x in user_result]
|
||||
|
||||
# Get how many are left to process, so we can give status on how
|
||||
# far we are in processing
|
||||
sql = "SELECT COUNT(*) FROM " + TEMP_TABLE + "_users"
|
||||
txn.execute(sql)
|
||||
progress["remaining"] = txn.fetchone()[0]
|
||||
count_result = txn.fetchone()
|
||||
assert count_result is not None
|
||||
progress["remaining"] = count_result[0]
|
||||
|
||||
return users_to_work_on
|
||||
|
||||
@ -324,7 +369,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
||||
|
||||
return len(users_to_work_on)
|
||||
|
||||
async def is_room_world_readable_or_publicly_joinable(self, room_id):
|
||||
async def is_room_world_readable_or_publicly_joinable(self, room_id: str) -> bool:
|
||||
"""Check if the room is either world_readable or publically joinable"""
|
||||
|
||||
# Create a state filter that only queries join and history state event
|
||||
@ -368,7 +413,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
||||
if not isinstance(avatar_url, str):
|
||||
avatar_url = None
|
||||
|
||||
def _update_profile_in_user_dir_txn(txn):
|
||||
def _update_profile_in_user_dir_txn(txn: LoggingTransaction) -> None:
|
||||
self.db_pool.simple_upsert_txn(
|
||||
txn,
|
||||
table="user_directory",
|
||||
@ -435,7 +480,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
||||
for user_id, other_user_id in user_id_tuples
|
||||
],
|
||||
value_names=(),
|
||||
value_values=None,
|
||||
value_values=(),
|
||||
desc="add_users_who_share_room",
|
||||
)
|
||||
|
||||
@ -454,14 +499,14 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
||||
key_names=["user_id", "room_id"],
|
||||
key_values=[(user_id, room_id) for user_id in user_ids],
|
||||
value_names=(),
|
||||
value_values=None,
|
||||
value_values=(),
|
||||
desc="add_users_in_public_rooms",
|
||||
)
|
||||
|
||||
async def delete_all_from_user_dir(self) -> None:
|
||||
"""Delete the entire user directory"""
|
||||
|
||||
def _delete_all_from_user_dir_txn(txn):
|
||||
def _delete_all_from_user_dir_txn(txn: LoggingTransaction) -> None:
|
||||
txn.execute("DELETE FROM user_directory")
|
||||
txn.execute("DELETE FROM user_directory_search")
|
||||
txn.execute("DELETE FROM users_in_public_rooms")
|
||||
@ -473,7 +518,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
||||
)
|
||||
|
||||
@cached()
|
||||
async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, Any]]:
|
||||
async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, str]]:
|
||||
return await self.db_pool.simple_select_one(
|
||||
table="user_directory",
|
||||
keyvalues={"user_id": user_id},
|
||||
@ -497,7 +542,12 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
|
||||
# add_users_who_share_private_rooms?
|
||||
SHARE_PRIVATE_WORKING_SET = 500
|
||||
|
||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||
def __init__(
|
||||
self,
|
||||
database: DatabasePool,
|
||||
db_conn: Connection,
|
||||
hs: "HomeServer",
|
||||
) -> None:
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
||||
self._prefer_local_users_in_search = (
|
||||
@ -506,7 +556,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
|
||||
self._server_name = hs.config.server.server_name
|
||||
|
||||
async def remove_from_user_dir(self, user_id: str) -> None:
|
||||
def _remove_from_user_dir_txn(txn):
|
||||
def _remove_from_user_dir_txn(txn: LoggingTransaction) -> None:
|
||||
self.db_pool.simple_delete_txn(
|
||||
txn, table="user_directory", keyvalues={"user_id": user_id}
|
||||
)
|
||||
@ -532,7 +582,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
|
||||
"remove_from_user_dir", _remove_from_user_dir_txn
|
||||
)
|
||||
|
||||
async def get_users_in_dir_due_to_room(self, room_id):
|
||||
async def get_users_in_dir_due_to_room(self, room_id: str) -> Set[str]:
|
||||
"""Get all user_ids that are in the room directory because they're
|
||||
in the given room_id
|
||||
"""
|
||||
@ -565,7 +615,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
|
||||
room_id
|
||||
"""
|
||||
|
||||
def _remove_user_who_share_room_txn(txn):
|
||||
def _remove_user_who_share_room_txn(txn: LoggingTransaction) -> None:
|
||||
self.db_pool.simple_delete_txn(
|
||||
txn,
|
||||
table="users_who_share_private_rooms",
|
||||
@ -586,7 +636,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
|
||||
"remove_user_who_share_room", _remove_user_who_share_room_txn
|
||||
)
|
||||
|
||||
async def get_user_dir_rooms_user_is_in(self, user_id):
|
||||
async def get_user_dir_rooms_user_is_in(self, user_id: str) -> List[str]:
|
||||
"""
|
||||
Returns the rooms that a user is in.
|
||||
|
||||
@ -628,7 +678,9 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
|
||||
A set of room ID's that the users share.
|
||||
"""
|
||||
|
||||
def _get_shared_rooms_for_users_txn(txn):
|
||||
def _get_shared_rooms_for_users_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> List[Dict[str, str]]:
|
||||
txn.execute(
|
||||
"""
|
||||
SELECT p1.room_id
|
||||
@ -669,7 +721,9 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
|
||||
desc="get_user_directory_stream_pos",
|
||||
)
|
||||
|
||||
async def search_user_dir(self, user_id, search_term, limit):
|
||||
async def search_user_dir(
|
||||
self, user_id: str, search_term: str, limit: int
|
||||
) -> JsonDict:
|
||||
"""Searches for users in directory
|
||||
|
||||
Returns:
|
||||
@ -705,7 +759,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
|
||||
# We allow manipulating the ranking algorithm by injecting statements
|
||||
# based on config options.
|
||||
additional_ordering_statements = []
|
||||
ordering_arguments = ()
|
||||
ordering_arguments: Tuple[str, ...] = ()
|
||||
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
full_query, exact_query, prefix_query = _parse_query_postgres(search_term)
|
||||
@ -811,7 +865,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
|
||||
return {"limited": limited, "results": results}
|
||||
|
||||
|
||||
def _parse_query_sqlite(search_term):
|
||||
def _parse_query_sqlite(search_term: str) -> str:
|
||||
"""Takes a plain unicode string from the user and converts it into a form
|
||||
that can be passed to database.
|
||||
We use this so that we can add prefix matching, which isn't something
|
||||
@ -826,7 +880,7 @@ def _parse_query_sqlite(search_term):
|
||||
return " & ".join("(%s* OR %s)" % (result, result) for result in results)
|
||||
|
||||
|
||||
def _parse_query_postgres(search_term):
|
||||
def _parse_query_postgres(search_term: str) -> Tuple[str, str, str]:
|
||||
"""Takes a plain unicode string from the user and converts it into a form
|
||||
that can be passed to database.
|
||||
We use this so that we can add prefix matching, which isn't something
|
||||
|
@ -11,6 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import List, Tuple
|
||||
from unittest.mock import Mock
|
||||
from urllib.parse import quote
|
||||
|
||||
@ -325,7 +326,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
|
||||
r.add((i["user_id"], i["other_user_id"], i["room_id"]))
|
||||
return r
|
||||
|
||||
def get_users_in_public_rooms(self):
|
||||
def get_users_in_public_rooms(self) -> List[Tuple[str, str]]:
|
||||
r = self.get_success(
|
||||
self.store.db_pool.simple_select_list(
|
||||
"users_in_public_rooms", None, ("user_id", "room_id")
|
||||
@ -336,7 +337,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
|
||||
retval.append((i["user_id"], i["room_id"]))
|
||||
return retval
|
||||
|
||||
def get_users_who_share_private_rooms(self):
|
||||
def get_users_who_share_private_rooms(self) -> List[Tuple[str, str, str]]:
|
||||
return self.get_success(
|
||||
self.store.db_pool.simple_select_list(
|
||||
"users_who_share_private_rooms",
|
||||
|
Loading…
Reference in New Issue
Block a user