Convert additional database stores to async/await (#8045)

This commit is contained in:
Patrick Cloke 2020-08-07 12:17:17 -04:00 committed by GitHub
parent 1048ed2afa
commit f3fe6961b2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 107 additions and 152 deletions

1
changelog.d/8045.misc Normal file
View File

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

View File

@ -14,8 +14,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Dict, Optional, Tuple
from twisted.internet import defer
from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
@ -82,21 +81,19 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
"devices_last_seen", self._devices_last_seen_update "devices_last_seen", self._devices_last_seen_update
) )
@defer.inlineCallbacks async def _remove_user_ip_nonunique(self, progress, batch_size):
def _remove_user_ip_nonunique(self, progress, batch_size):
def f(conn): def f(conn):
txn = conn.cursor() txn = conn.cursor()
txn.execute("DROP INDEX IF EXISTS user_ips_user_ip") txn.execute("DROP INDEX IF EXISTS user_ips_user_ip")
txn.close() txn.close()
yield self.db_pool.runWithConnection(f) await self.db_pool.runWithConnection(f)
yield self.db_pool.updates._end_background_update( await self.db_pool.updates._end_background_update(
"user_ips_drop_nonunique_index" "user_ips_drop_nonunique_index"
) )
return 1 return 1
@defer.inlineCallbacks async def _analyze_user_ip(self, progress, batch_size):
def _analyze_user_ip(self, progress, batch_size):
# Background update to analyze user_ips table before we run the # Background update to analyze user_ips table before we run the
# deduplication background update. The table may not have been analyzed # deduplication background update. The table may not have been analyzed
# for ages due to the table locks. # for ages due to the table locks.
@ -106,14 +103,13 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
def user_ips_analyze(txn): def user_ips_analyze(txn):
txn.execute("ANALYZE user_ips") txn.execute("ANALYZE user_ips")
yield self.db_pool.runInteraction("user_ips_analyze", user_ips_analyze) await self.db_pool.runInteraction("user_ips_analyze", user_ips_analyze)
yield self.db_pool.updates._end_background_update("user_ips_analyze") await self.db_pool.updates._end_background_update("user_ips_analyze")
return 1 return 1
@defer.inlineCallbacks async def _remove_user_ip_dupes(self, progress, batch_size):
def _remove_user_ip_dupes(self, progress, batch_size):
# This works function works by scanning the user_ips table in batches # This works function works by scanning the user_ips table in batches
# based on `last_seen`. For each row in a batch it searches the rest of # based on `last_seen`. For each row in a batch it searches the rest of
# the table to see if there are any duplicates, if there are then they # the table to see if there are any duplicates, if there are then they
@ -140,7 +136,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
return None return None
# Get a last seen that has roughly `batch_size` since `begin_last_seen` # Get a last seen that has roughly `batch_size` since `begin_last_seen`
end_last_seen = yield self.db_pool.runInteraction( end_last_seen = await self.db_pool.runInteraction(
"user_ips_dups_get_last_seen", get_last_seen "user_ips_dups_get_last_seen", get_last_seen
) )
@ -275,15 +271,14 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
txn, "user_ips_remove_dupes", {"last_seen": end_last_seen} txn, "user_ips_remove_dupes", {"last_seen": end_last_seen}
) )
yield self.db_pool.runInteraction("user_ips_dups_remove", remove) await self.db_pool.runInteraction("user_ips_dups_remove", remove)
if last: if last:
yield self.db_pool.updates._end_background_update("user_ips_remove_dupes") await self.db_pool.updates._end_background_update("user_ips_remove_dupes")
return batch_size return batch_size
@defer.inlineCallbacks async def _devices_last_seen_update(self, progress, batch_size):
def _devices_last_seen_update(self, progress, batch_size):
"""Background update to insert last seen info into devices table """Background update to insert last seen info into devices table
""" """
@ -346,12 +341,12 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
return len(rows) return len(rows)
updated = yield self.db_pool.runInteraction( updated = await self.db_pool.runInteraction(
"_devices_last_seen_update", _devices_last_seen_update_txn "_devices_last_seen_update", _devices_last_seen_update_txn
) )
if not updated: if not updated:
yield self.db_pool.updates._end_background_update("devices_last_seen") await self.db_pool.updates._end_background_update("devices_last_seen")
return updated return updated
@ -460,25 +455,25 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
# Failed to upsert, log and continue # Failed to upsert, log and continue
logger.error("Failed to insert client IP %r: %r", entry, e) logger.error("Failed to insert client IP %r: %r", entry, e)
@defer.inlineCallbacks async def get_last_client_ip_by_device(
def get_last_client_ip_by_device(self, user_id, device_id): self, user_id: str, device_id: Optional[str]
) -> Dict[Tuple[str, str], dict]:
"""For each device_id listed, give the user_ip it was last seen on """For each device_id listed, give the user_ip it was last seen on
Args: Args:
user_id (str) user_id: The user to fetch devices for.
device_id (str): If None fetches all devices for the user device_id: If None fetches all devices for the user
Returns: Returns:
defer.Deferred: resolves to a dict, where the keys A dictionary mapping a tuple of (user_id, device_id) to dicts, with
are (user_id, device_id) tuples. The values are also dicts, with keys giving the column names from the devices table.
keys giving the column names
""" """
keyvalues = {"user_id": user_id} keyvalues = {"user_id": user_id}
if device_id is not None: if device_id is not None:
keyvalues["device_id"] = device_id keyvalues["device_id"] = device_id
res = yield self.db_pool.simple_select_list( res = await self.db_pool.simple_select_list(
table="devices", table="devices",
keyvalues=keyvalues, keyvalues=keyvalues,
retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
@ -500,8 +495,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
} }
return ret return ret
@defer.inlineCallbacks async def get_user_ip_and_agents(self, user):
def get_user_ip_and_agents(self, user):
user_id = user.to_string() user_id = user.to_string()
results = {} results = {}
@ -511,7 +505,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
user_agent, _, last_seen = self._batch_row_update[key] user_agent, _, last_seen = self._batch_row_update[key]
results[(access_token, ip)] = (user_agent, last_seen) results[(access_token, ip)] = (user_agent, last_seen)
rows = yield self.db_pool.simple_select_list( rows = await self.db_pool.simple_select_list(
table="user_ips", table="user_ips",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
retcols=["access_token", "ip", "user_agent", "last_seen"], retcols=["access_token", "ip", "user_agent", "last_seen"],

View File

@ -16,8 +16,7 @@
import logging import logging
import re import re
from collections import namedtuple from collections import namedtuple
from typing import List, Optional
from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
@ -114,8 +113,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME, self._background_reindex_gin_search self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME, self._background_reindex_gin_search
) )
@defer.inlineCallbacks async def _background_reindex_search(self, progress, batch_size):
def _background_reindex_search(self, progress, batch_size):
# we work through the events table from highest stream id to lowest # we work through the events table from highest stream id to lowest
target_min_stream_id = progress["target_min_stream_id_inclusive"] target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"] max_stream_id = progress["max_stream_id_exclusive"]
@ -206,19 +204,18 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
return len(event_search_rows) return len(event_search_rows)
result = yield self.db_pool.runInteraction( result = await self.db_pool.runInteraction(
self.EVENT_SEARCH_UPDATE_NAME, reindex_search_txn self.EVENT_SEARCH_UPDATE_NAME, reindex_search_txn
) )
if not result: if not result:
yield self.db_pool.updates._end_background_update( await self.db_pool.updates._end_background_update(
self.EVENT_SEARCH_UPDATE_NAME self.EVENT_SEARCH_UPDATE_NAME
) )
return result return result
@defer.inlineCallbacks async def _background_reindex_gin_search(self, progress, batch_size):
def _background_reindex_gin_search(self, progress, batch_size):
"""This handles old synapses which used GIST indexes, if any; """This handles old synapses which used GIST indexes, if any;
converting them back to be GIN as per the actual schema. converting them back to be GIN as per the actual schema.
""" """
@ -255,15 +252,14 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
conn.set_session(autocommit=False) conn.set_session(autocommit=False)
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, PostgresEngine):
yield self.db_pool.runWithConnection(create_index) await self.db_pool.runWithConnection(create_index)
yield self.db_pool.updates._end_background_update( await self.db_pool.updates._end_background_update(
self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME
) )
return 1 return 1
@defer.inlineCallbacks async def _background_reindex_search_order(self, progress, batch_size):
def _background_reindex_search_order(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"] target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"] max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0) rows_inserted = progress.get("rows_inserted", 0)
@ -288,12 +284,12 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
) )
conn.set_session(autocommit=False) conn.set_session(autocommit=False)
yield self.db_pool.runWithConnection(create_index) await self.db_pool.runWithConnection(create_index)
pg = dict(progress) pg = dict(progress)
pg["have_added_indexes"] = True pg["have_added_indexes"] = True
yield self.db_pool.runInteraction( await self.db_pool.runInteraction(
self.EVENT_SEARCH_ORDER_UPDATE_NAME, self.EVENT_SEARCH_ORDER_UPDATE_NAME,
self.db_pool.updates._background_update_progress_txn, self.db_pool.updates._background_update_progress_txn,
self.EVENT_SEARCH_ORDER_UPDATE_NAME, self.EVENT_SEARCH_ORDER_UPDATE_NAME,
@ -331,12 +327,12 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
return len(rows), True return len(rows), True
num_rows, finished = yield self.db_pool.runInteraction( num_rows, finished = await self.db_pool.runInteraction(
self.EVENT_SEARCH_ORDER_UPDATE_NAME, reindex_search_txn self.EVENT_SEARCH_ORDER_UPDATE_NAME, reindex_search_txn
) )
if not finished: if not finished:
yield self.db_pool.updates._end_background_update( await self.db_pool.updates._end_background_update(
self.EVENT_SEARCH_ORDER_UPDATE_NAME self.EVENT_SEARCH_ORDER_UPDATE_NAME
) )
@ -347,8 +343,7 @@ class SearchStore(SearchBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super(SearchStore, self).__init__(database, db_conn, hs) super(SearchStore, self).__init__(database, db_conn, hs)
@defer.inlineCallbacks async def search_msgs(self, room_ids, search_term, keys):
def search_msgs(self, room_ids, search_term, keys):
"""Performs a full text search over events with given keys. """Performs a full text search over events with given keys.
Args: Args:
@ -425,7 +420,7 @@ class SearchStore(SearchBackgroundUpdateStore):
# entire table from the database. # entire table from the database.
sql += " ORDER BY rank DESC LIMIT 500" sql += " ORDER BY rank DESC LIMIT 500"
results = yield self.db_pool.execute( results = await self.db_pool.execute(
"search_msgs", self.db_pool.cursor_to_dict, sql, *args "search_msgs", self.db_pool.cursor_to_dict, sql, *args
) )
@ -433,7 +428,7 @@ class SearchStore(SearchBackgroundUpdateStore):
# We set redact_behaviour to BLOCK here to prevent redacted events being returned in # We set redact_behaviour to BLOCK here to prevent redacted events being returned in
# search results (which is a data leak) # search results (which is a data leak)
events = yield self.get_events_as_list( events = await self.get_events_as_list(
[r["event_id"] for r in results], [r["event_id"] for r in results],
redact_behaviour=EventRedactBehaviour.BLOCK, redact_behaviour=EventRedactBehaviour.BLOCK,
) )
@ -442,11 +437,11 @@ class SearchStore(SearchBackgroundUpdateStore):
highlights = None highlights = None
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, PostgresEngine):
highlights = yield self._find_highlights_in_postgres(search_query, events) highlights = await self._find_highlights_in_postgres(search_query, events)
count_sql += " GROUP BY room_id" count_sql += " GROUP BY room_id"
count_results = yield self.db_pool.execute( count_results = await self.db_pool.execute(
"search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args "search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args
) )
@ -462,19 +457,25 @@ class SearchStore(SearchBackgroundUpdateStore):
"count": count, "count": count,
} }
@defer.inlineCallbacks async def search_rooms(
def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None): self,
room_ids: List[str],
search_term: str,
keys: List[str],
limit,
pagination_token: Optional[str] = None,
) -> List[dict]:
"""Performs a full text search over events with given keys. """Performs a full text search over events with given keys.
Args: Args:
room_id (list): The room_ids to search in room_ids: The room_ids to search in
search_term (str): Search term to search for search_term: Search term to search for
keys (list): List of keys to search in, currently supports keys: List of keys to search in, currently supports "content.body",
"content.body", "content.name", "content.topic" "content.name", "content.topic"
pagination_token (str): A pagination token previously returned pagination_token: A pagination token previously returned
Returns: Returns:
list of dicts Each match as a dictionary.
""" """
clauses = [] clauses = []
@ -577,7 +578,7 @@ class SearchStore(SearchBackgroundUpdateStore):
args.append(limit) args.append(limit)
results = yield self.db_pool.execute( results = await self.db_pool.execute(
"search_rooms", self.db_pool.cursor_to_dict, sql, *args "search_rooms", self.db_pool.cursor_to_dict, sql, *args
) )
@ -585,7 +586,7 @@ class SearchStore(SearchBackgroundUpdateStore):
# We set redact_behaviour to BLOCK here to prevent redacted events being returned in # We set redact_behaviour to BLOCK here to prevent redacted events being returned in
# search results (which is a data leak) # search results (which is a data leak)
events = yield self.get_events_as_list( events = await self.get_events_as_list(
[r["event_id"] for r in results], [r["event_id"] for r in results],
redact_behaviour=EventRedactBehaviour.BLOCK, redact_behaviour=EventRedactBehaviour.BLOCK,
) )
@ -594,11 +595,11 @@ class SearchStore(SearchBackgroundUpdateStore):
highlights = None highlights = None
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, PostgresEngine):
highlights = yield self._find_highlights_in_postgres(search_query, events) highlights = await self._find_highlights_in_postgres(search_query, events)
count_sql += " GROUP BY room_id" count_sql += " GROUP BY room_id"
count_results = yield self.db_pool.execute( count_results = await self.db_pool.execute(
"search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args "search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args
) )

View File

@ -15,8 +15,6 @@
from unpaddedbase64 import encode_base64 from unpaddedbase64 import encode_base64
from twisted.internet import defer
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.descriptors import cached, cachedList
@ -40,9 +38,8 @@ class SignatureWorkerStore(SQLBaseStore):
return self.db_pool.runInteraction("get_event_reference_hashes", f) return self.db_pool.runInteraction("get_event_reference_hashes", f)
@defer.inlineCallbacks async def add_event_hashes(self, event_ids):
def add_event_hashes(self, event_ids): hashes = await self.get_event_reference_hashes(event_ids)
hashes = yield self.get_event_reference_hashes(event_ids)
hashes = { hashes = {
e_id: {k: encode_base64(v) for k, v in h.items() if k == "sha256"} e_id: {k: encode_base64(v) for k, v in h.items() if k == "sha256"}
for e_id, h in hashes.items() for e_id, h in hashes.items()

View File

@ -16,8 +16,6 @@
import logging import logging
import re import re
from twisted.internet import defer
from synapse.api.constants import EventTypes, JoinRules from synapse.api.constants import EventTypes, JoinRules
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.state import StateFilter from synapse.storage.databases.main.state import StateFilter
@ -59,8 +57,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
"populate_user_directory_cleanup", self._populate_user_directory_cleanup "populate_user_directory_cleanup", self._populate_user_directory_cleanup
) )
@defer.inlineCallbacks async def _populate_user_directory_createtables(self, progress, batch_size):
def _populate_user_directory_createtables(self, progress, batch_size):
# Get all the rooms that we want to process. # Get all the rooms that we want to process.
def _make_staging_area(txn): def _make_staging_area(txn):
@ -102,45 +99,43 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
self.db_pool.simple_insert_many_txn(txn, TEMP_TABLE + "_users", users) self.db_pool.simple_insert_many_txn(txn, TEMP_TABLE + "_users", users)
new_pos = yield self.get_max_stream_id_in_current_state_deltas() new_pos = await self.get_max_stream_id_in_current_state_deltas()
yield self.db_pool.runInteraction( await self.db_pool.runInteraction(
"populate_user_directory_temp_build", _make_staging_area "populate_user_directory_temp_build", _make_staging_area
) )
yield self.db_pool.simple_insert( await self.db_pool.simple_insert(
TEMP_TABLE + "_position", {"position": new_pos} TEMP_TABLE + "_position", {"position": new_pos}
) )
yield self.db_pool.updates._end_background_update( await self.db_pool.updates._end_background_update(
"populate_user_directory_createtables" "populate_user_directory_createtables"
) )
return 1 return 1
@defer.inlineCallbacks async def _populate_user_directory_cleanup(self, progress, batch_size):
def _populate_user_directory_cleanup(self, progress, batch_size):
""" """
Update the user directory stream position, then clean up the old tables. Update the user directory stream position, then clean up the old tables.
""" """
position = yield self.db_pool.simple_select_one_onecol( position = await self.db_pool.simple_select_one_onecol(
TEMP_TABLE + "_position", None, "position" TEMP_TABLE + "_position", None, "position"
) )
yield self.update_user_directory_stream_pos(position) await self.update_user_directory_stream_pos(position)
def _delete_staging_area(txn): def _delete_staging_area(txn):
txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_rooms") txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_rooms")
txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_users") txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_users")
txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position") txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position")
yield self.db_pool.runInteraction( await self.db_pool.runInteraction(
"populate_user_directory_cleanup", _delete_staging_area "populate_user_directory_cleanup", _delete_staging_area
) )
yield self.db_pool.updates._end_background_update( await self.db_pool.updates._end_background_update(
"populate_user_directory_cleanup" "populate_user_directory_cleanup"
) )
return 1 return 1
@defer.inlineCallbacks async def _populate_user_directory_process_rooms(self, progress, batch_size):
def _populate_user_directory_process_rooms(self, progress, batch_size):
""" """
Args: Args:
progress (dict) progress (dict)
@ -151,7 +146,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
# If we don't have progress filed, delete everything. # If we don't have progress filed, delete everything.
if not progress: if not progress:
yield self.delete_all_from_user_dir() await self.delete_all_from_user_dir()
def _get_next_batch(txn): def _get_next_batch(txn):
# Only fetch 250 rooms, so we don't fetch too many at once, even # Only fetch 250 rooms, so we don't fetch too many at once, even
@ -176,13 +171,13 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
return rooms_to_work_on return rooms_to_work_on
rooms_to_work_on = yield self.db_pool.runInteraction( rooms_to_work_on = await self.db_pool.runInteraction(
"populate_user_directory_temp_read", _get_next_batch "populate_user_directory_temp_read", _get_next_batch
) )
# No more rooms -- complete the transaction. # No more rooms -- complete the transaction.
if not rooms_to_work_on: if not rooms_to_work_on:
yield self.db_pool.updates._end_background_update( await self.db_pool.updates._end_background_update(
"populate_user_directory_process_rooms" "populate_user_directory_process_rooms"
) )
return 1 return 1
@ -195,21 +190,19 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
processed_event_count = 0 processed_event_count = 0
for room_id, event_count in rooms_to_work_on: for room_id, event_count in rooms_to_work_on:
is_in_room = yield self.is_host_joined(room_id, self.server_name) is_in_room = await self.is_host_joined(room_id, self.server_name)
if is_in_room: if is_in_room:
is_public = yield self.is_room_world_readable_or_publicly_joinable( is_public = await self.is_room_world_readable_or_publicly_joinable(
room_id room_id
) )
users_with_profile = yield defer.ensureDeferred( users_with_profile = await state.get_current_users_in_room(room_id)
state.get_current_users_in_room(room_id)
)
user_ids = set(users_with_profile) user_ids = set(users_with_profile)
# Update each user in the user directory. # Update each user in the user directory.
for user_id, profile in users_with_profile.items(): for user_id, profile in users_with_profile.items():
yield self.update_profile_in_user_dir( await self.update_profile_in_user_dir(
user_id, profile.display_name, profile.avatar_url user_id, profile.display_name, profile.avatar_url
) )
@ -223,7 +216,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
to_insert.add(user_id) to_insert.add(user_id)
if to_insert: if to_insert:
yield self.add_users_in_public_rooms(room_id, to_insert) await self.add_users_in_public_rooms(room_id, to_insert)
to_insert.clear() to_insert.clear()
else: else:
for user_id in user_ids: for user_id in user_ids:
@ -243,22 +236,22 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
# If it gets too big, stop and write to the database # If it gets too big, stop and write to the database
# to prevent storing too much in RAM. # to prevent storing too much in RAM.
if len(to_insert) >= self.SHARE_PRIVATE_WORKING_SET: if len(to_insert) >= self.SHARE_PRIVATE_WORKING_SET:
yield self.add_users_who_share_private_room( await self.add_users_who_share_private_room(
room_id, to_insert room_id, to_insert
) )
to_insert.clear() to_insert.clear()
if to_insert: if to_insert:
yield self.add_users_who_share_private_room(room_id, to_insert) await self.add_users_who_share_private_room(room_id, to_insert)
to_insert.clear() to_insert.clear()
# We've finished a room. Delete it from the table. # We've finished a room. Delete it from the table.
yield self.db_pool.simple_delete_one( await self.db_pool.simple_delete_one(
TEMP_TABLE + "_rooms", {"room_id": room_id} TEMP_TABLE + "_rooms", {"room_id": room_id}
) )
# Update the remaining counter. # Update the remaining counter.
progress["remaining"] -= 1 progress["remaining"] -= 1
yield self.db_pool.runInteraction( await self.db_pool.runInteraction(
"populate_user_directory", "populate_user_directory",
self.db_pool.updates._background_update_progress_txn, self.db_pool.updates._background_update_progress_txn,
"populate_user_directory_process_rooms", "populate_user_directory_process_rooms",
@ -273,13 +266,12 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
return processed_event_count return processed_event_count
@defer.inlineCallbacks async def _populate_user_directory_process_users(self, progress, batch_size):
def _populate_user_directory_process_users(self, progress, batch_size):
""" """
If search_all_users is enabled, add all of the users to the user directory. If search_all_users is enabled, add all of the users to the user directory.
""" """
if not self.hs.config.user_directory_search_all_users: if not self.hs.config.user_directory_search_all_users:
yield self.db_pool.updates._end_background_update( await self.db_pool.updates._end_background_update(
"populate_user_directory_process_users" "populate_user_directory_process_users"
) )
return 1 return 1
@ -305,13 +297,13 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
return users_to_work_on return users_to_work_on
users_to_work_on = yield self.db_pool.runInteraction( users_to_work_on = await self.db_pool.runInteraction(
"populate_user_directory_temp_read", _get_next_batch "populate_user_directory_temp_read", _get_next_batch
) )
# No more users -- complete the transaction. # No more users -- complete the transaction.
if not users_to_work_on: if not users_to_work_on:
yield self.db_pool.updates._end_background_update( await self.db_pool.updates._end_background_update(
"populate_user_directory_process_users" "populate_user_directory_process_users"
) )
return 1 return 1
@ -322,18 +314,18 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
) )
for user_id in users_to_work_on: for user_id in users_to_work_on:
profile = yield self.get_profileinfo(get_localpart_from_id(user_id)) profile = await self.get_profileinfo(get_localpart_from_id(user_id))
yield self.update_profile_in_user_dir( await self.update_profile_in_user_dir(
user_id, profile.display_name, profile.avatar_url user_id, profile.display_name, profile.avatar_url
) )
# We've finished processing a user. Delete it from the table. # We've finished processing a user. Delete it from the table.
yield self.db_pool.simple_delete_one( await self.db_pool.simple_delete_one(
TEMP_TABLE + "_users", {"user_id": user_id} TEMP_TABLE + "_users", {"user_id": user_id}
) )
# Update the remaining counter. # Update the remaining counter.
progress["remaining"] -= 1 progress["remaining"] -= 1
yield self.db_pool.runInteraction( await self.db_pool.runInteraction(
"populate_user_directory", "populate_user_directory",
self.db_pool.updates._background_update_progress_txn, self.db_pool.updates._background_update_progress_txn,
"populate_user_directory_process_users", "populate_user_directory_process_users",
@ -342,8 +334,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
return len(users_to_work_on) return len(users_to_work_on)
@defer.inlineCallbacks async def is_room_world_readable_or_publicly_joinable(self, room_id):
def is_room_world_readable_or_publicly_joinable(self, room_id):
"""Check if the room is either world_readable or publically joinable """Check if the room is either world_readable or publically joinable
""" """
@ -353,20 +344,20 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
(EventTypes.RoomHistoryVisibility, ""), (EventTypes.RoomHistoryVisibility, ""),
) )
current_state_ids = yield self.get_filtered_current_state_ids( current_state_ids = await self.get_filtered_current_state_ids(
room_id, StateFilter.from_types(types_to_filter) room_id, StateFilter.from_types(types_to_filter)
) )
join_rules_id = current_state_ids.get((EventTypes.JoinRules, "")) join_rules_id = current_state_ids.get((EventTypes.JoinRules, ""))
if join_rules_id: if join_rules_id:
join_rule_ev = yield self.get_event(join_rules_id, allow_none=True) join_rule_ev = await self.get_event(join_rules_id, allow_none=True)
if join_rule_ev: if join_rule_ev:
if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC: if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC:
return True return True
hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, "")) hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, ""))
if hist_vis_id: if hist_vis_id:
hist_vis_ev = yield self.get_event(hist_vis_id, allow_none=True) hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True)
if hist_vis_ev: if hist_vis_ev:
if hist_vis_ev.content.get("history_visibility") == "world_readable": if hist_vis_ev.content.get("history_visibility") == "world_readable":
return True return True
@ -590,19 +581,18 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
"remove_from_user_dir", _remove_from_user_dir_txn "remove_from_user_dir", _remove_from_user_dir_txn
) )
@defer.inlineCallbacks async def get_users_in_dir_due_to_room(self, room_id):
def get_users_in_dir_due_to_room(self, room_id):
"""Get all user_ids that are in the room directory because they're """Get all user_ids that are in the room directory because they're
in the given room_id in the given room_id
""" """
user_ids_share_pub = yield self.db_pool.simple_select_onecol( user_ids_share_pub = await self.db_pool.simple_select_onecol(
table="users_in_public_rooms", table="users_in_public_rooms",
keyvalues={"room_id": room_id}, keyvalues={"room_id": room_id},
retcol="user_id", retcol="user_id",
desc="get_users_in_dir_due_to_room", desc="get_users_in_dir_due_to_room",
) )
user_ids_share_priv = yield self.db_pool.simple_select_onecol( user_ids_share_priv = await self.db_pool.simple_select_onecol(
table="users_who_share_private_rooms", table="users_who_share_private_rooms",
keyvalues={"room_id": room_id}, keyvalues={"room_id": room_id},
retcol="other_user_id", retcol="other_user_id",
@ -645,8 +635,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
"remove_user_who_share_room", _remove_user_who_share_room_txn "remove_user_who_share_room", _remove_user_who_share_room_txn
) )
@defer.inlineCallbacks async def get_user_dir_rooms_user_is_in(self, user_id):
def get_user_dir_rooms_user_is_in(self, user_id):
""" """
Returns the rooms that a user is in. Returns the rooms that a user is in.
@ -656,14 +645,14 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
Returns: Returns:
list: user_id list: user_id
""" """
rows = yield self.db_pool.simple_select_onecol( rows = await self.db_pool.simple_select_onecol(
table="users_who_share_private_rooms", table="users_who_share_private_rooms",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
retcol="room_id", retcol="room_id",
desc="get_rooms_user_is_in", desc="get_rooms_user_is_in",
) )
pub_rows = yield self.db_pool.simple_select_onecol( pub_rows = await self.db_pool.simple_select_onecol(
table="users_in_public_rooms", table="users_in_public_rooms",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
retcol="room_id", retcol="room_id",
@ -674,32 +663,6 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
users.update(rows) users.update(rows)
return list(users) return list(users)
@defer.inlineCallbacks
def get_rooms_in_common_for_users(self, user_id, other_user_id):
"""Given two user_ids find out the list of rooms they share.
"""
sql = """
SELECT room_id FROM (
SELECT c.room_id FROM current_state_events AS c
INNER JOIN room_memberships AS m USING (event_id)
WHERE type = 'm.room.member'
AND m.membership = 'join'
AND state_key = ?
) AS f1 INNER JOIN (
SELECT c.room_id FROM current_state_events AS c
INNER JOIN room_memberships AS m USING (event_id)
WHERE type = 'm.room.member'
AND m.membership = 'join'
AND state_key = ?
) f2 USING (room_id)
"""
rows = yield self.db_pool.execute(
"get_rooms_in_common_for_users", None, sql, user_id, other_user_id
)
return [room_id for room_id, in rows]
def get_user_directory_stream_pos(self): def get_user_directory_stream_pos(self):
return self.db_pool.simple_select_one_onecol( return self.db_pool.simple_select_one_onecol(
table="user_directory_stream_pos", table="user_directory_stream_pos",
@ -708,8 +671,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
desc="get_user_directory_stream_pos", desc="get_user_directory_stream_pos",
) )
@defer.inlineCallbacks async def search_user_dir(self, user_id, search_term, limit):
def search_user_dir(self, user_id, search_term, limit):
"""Searches for users in directory """Searches for users in directory
Returns: Returns:
@ -806,7 +768,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
# This should be unreachable. # This should be unreachable.
raise Exception("Unrecognized database engine") raise Exception("Unrecognized database engine")
results = yield self.db_pool.execute( results = await self.db_pool.execute(
"search_user_dir", self.db_pool.cursor_to_dict, sql, *args "search_user_dir", self.db_pool.cursor_to_dict, sql, *args
) )

View File

@ -40,7 +40,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
def test_search_user_dir(self): def test_search_user_dir(self):
# normally when alice searches the directory she should just find # normally when alice searches the directory she should just find
# bob because bobby doesn't share a room with her. # bob because bobby doesn't share a room with her.
r = yield self.store.search_user_dir(ALICE, "bob", 10) r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "bob", 10))
self.assertFalse(r["limited"]) self.assertFalse(r["limited"])
self.assertEqual(1, len(r["results"])) self.assertEqual(1, len(r["results"]))
self.assertDictEqual( self.assertDictEqual(
@ -51,7 +51,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
def test_search_user_dir_all_users(self): def test_search_user_dir_all_users(self):
self.hs.config.user_directory_search_all_users = True self.hs.config.user_directory_search_all_users = True
try: try:
r = yield self.store.search_user_dir(ALICE, "bob", 10) r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "bob", 10))
self.assertFalse(r["limited"]) self.assertFalse(r["limited"])
self.assertEqual(2, len(r["results"])) self.assertEqual(2, len(r["results"]))
self.assertDictEqual( self.assertDictEqual(