Convert search code to async/await. (#7460)

This commit is contained in:
Patrick Cloke 2020-05-11 15:12:39 -04:00 committed by GitHub
parent 7cb8b4bc67
commit be309d99cf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 24 deletions

1
changelog.d/7460.misc Normal file
View File

@ -0,0 +1 @@
Convert the search handler to async/await.

View File

@ -18,8 +18,6 @@ import logging
from unpaddedbase64 import decode_base64, encode_base64 from unpaddedbase64 import decode_base64, encode_base64
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, SynapseError from synapse.api.errors import NotFoundError, SynapseError
from synapse.api.filtering import Filter from synapse.api.filtering import Filter
@ -39,8 +37,7 @@ class SearchHandler(BaseHandler):
self.state_store = self.storage.state self.state_store = self.storage.state
self.auth = hs.get_auth() self.auth = hs.get_auth()
@defer.inlineCallbacks async def get_old_rooms_from_upgraded_room(self, room_id):
def get_old_rooms_from_upgraded_room(self, room_id):
"""Retrieves room IDs of old rooms in the history of an upgraded room. """Retrieves room IDs of old rooms in the history of an upgraded room.
We do so by checking the m.room.create event of the room for a We do so by checking the m.room.create event of the room for a
@ -60,7 +57,7 @@ class SearchHandler(BaseHandler):
historical_room_ids = [] historical_room_ids = []
# The initial room must have been known for us to get this far # The initial room must have been known for us to get this far
predecessor = yield self.store.get_room_predecessor(room_id) predecessor = await self.store.get_room_predecessor(room_id)
while True: while True:
if not predecessor: if not predecessor:
@ -75,7 +72,7 @@ class SearchHandler(BaseHandler):
# Don't add it to the list until we have checked that we are in the room # Don't add it to the list until we have checked that we are in the room
try: try:
next_predecessor_room = yield self.store.get_room_predecessor( next_predecessor_room = await self.store.get_room_predecessor(
predecessor_room_id predecessor_room_id
) )
except NotFoundError: except NotFoundError:
@ -89,8 +86,7 @@ class SearchHandler(BaseHandler):
return historical_room_ids return historical_room_ids
@defer.inlineCallbacks async def search(self, user, content, batch=None):
def search(self, user, content, batch=None):
"""Performs a full text search for a user. """Performs a full text search for a user.
Args: Args:
@ -179,7 +175,7 @@ class SearchHandler(BaseHandler):
search_filter = Filter(filter_dict) search_filter = Filter(filter_dict)
# TODO: Search through left rooms too # TODO: Search through left rooms too
rooms = yield self.store.get_rooms_for_local_user_where_membership_is( rooms = await self.store.get_rooms_for_local_user_where_membership_is(
user.to_string(), user.to_string(),
membership_list=[Membership.JOIN], membership_list=[Membership.JOIN],
# membership_list=[Membership.JOIN, Membership.LEAVE, Membership.Ban], # membership_list=[Membership.JOIN, Membership.LEAVE, Membership.Ban],
@ -192,7 +188,7 @@ class SearchHandler(BaseHandler):
historical_room_ids = [] historical_room_ids = []
for room_id in search_filter.rooms: for room_id in search_filter.rooms:
# Add any previous rooms to the search if they exist # Add any previous rooms to the search if they exist
ids = yield self.get_old_rooms_from_upgraded_room(room_id) ids = await self.get_old_rooms_from_upgraded_room(room_id)
historical_room_ids += ids historical_room_ids += ids
# Prevent any historical events from being filtered # Prevent any historical events from being filtered
@ -223,7 +219,7 @@ class SearchHandler(BaseHandler):
count = None count = None
if order_by == "rank": if order_by == "rank":
search_result = yield self.store.search_msgs(room_ids, search_term, keys) search_result = await self.store.search_msgs(room_ids, search_term, keys)
count = search_result["count"] count = search_result["count"]
@ -238,7 +234,7 @@ class SearchHandler(BaseHandler):
filtered_events = search_filter.filter([r["event"] for r in results]) filtered_events = search_filter.filter([r["event"] for r in results])
events = yield filter_events_for_client( events = await filter_events_for_client(
self.storage, user.to_string(), filtered_events self.storage, user.to_string(), filtered_events
) )
@ -267,7 +263,7 @@ class SearchHandler(BaseHandler):
# But only go around 5 times since otherwise synapse will be sad. # But only go around 5 times since otherwise synapse will be sad.
while len(room_events) < search_filter.limit() and i < 5: while len(room_events) < search_filter.limit() and i < 5:
i += 1 i += 1
search_result = yield self.store.search_rooms( search_result = await self.store.search_rooms(
room_ids, room_ids,
search_term, search_term,
keys, keys,
@ -288,7 +284,7 @@ class SearchHandler(BaseHandler):
filtered_events = search_filter.filter([r["event"] for r in results]) filtered_events = search_filter.filter([r["event"] for r in results])
events = yield filter_events_for_client( events = await filter_events_for_client(
self.storage, user.to_string(), filtered_events self.storage, user.to_string(), filtered_events
) )
@ -343,11 +339,11 @@ class SearchHandler(BaseHandler):
# If client has asked for "context" for each event (i.e. some surrounding # If client has asked for "context" for each event (i.e. some surrounding
# events and state), fetch that # events and state), fetch that
if event_context is not None: if event_context is not None:
now_token = yield self.hs.get_event_sources().get_current_token() now_token = await self.hs.get_event_sources().get_current_token()
contexts = {} contexts = {}
for event in allowed_events: for event in allowed_events:
res = yield self.store.get_events_around( res = await self.store.get_events_around(
event.room_id, event.event_id, before_limit, after_limit event.room_id, event.event_id, before_limit, after_limit
) )
@ -357,11 +353,11 @@ class SearchHandler(BaseHandler):
len(res["events_after"]), len(res["events_after"]),
) )
res["events_before"] = yield filter_events_for_client( res["events_before"] = await filter_events_for_client(
self.storage, user.to_string(), res["events_before"] self.storage, user.to_string(), res["events_before"]
) )
res["events_after"] = yield filter_events_for_client( res["events_after"] = await filter_events_for_client(
self.storage, user.to_string(), res["events_after"] self.storage, user.to_string(), res["events_after"]
) )
@ -390,7 +386,7 @@ class SearchHandler(BaseHandler):
[(EventTypes.Member, sender) for sender in senders] [(EventTypes.Member, sender) for sender in senders]
) )
state = yield self.state_store.get_state_for_event( state = await self.state_store.get_state_for_event(
last_event_id, state_filter last_event_id, state_filter
) )
@ -412,10 +408,10 @@ class SearchHandler(BaseHandler):
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
for context in contexts.values(): for context in contexts.values():
context["events_before"] = yield self._event_serializer.serialize_events( context["events_before"] = await self._event_serializer.serialize_events(
context["events_before"], time_now context["events_before"], time_now
) )
context["events_after"] = yield self._event_serializer.serialize_events( context["events_after"] = await self._event_serializer.serialize_events(
context["events_after"], time_now context["events_after"], time_now
) )
@ -423,7 +419,7 @@ class SearchHandler(BaseHandler):
if include_state: if include_state:
rooms = {e.room_id for e in allowed_events} rooms = {e.room_id for e in allowed_events}
for room_id in rooms: for room_id in rooms:
state = yield self.state_handler.get_current_state(room_id) state = await self.state_handler.get_current_state(room_id)
state_results[room_id] = list(state.values()) state_results[room_id] = list(state.values())
state_results.values() state_results.values()
@ -437,7 +433,7 @@ class SearchHandler(BaseHandler):
{ {
"rank": rank_map[e.event_id], "rank": rank_map[e.event_id],
"result": ( "result": (
yield self._event_serializer.serialize_event(e, time_now) await self._event_serializer.serialize_event(e, time_now)
), ),
"context": contexts.get(e.event_id, {}), "context": contexts.get(e.event_id, {}),
} }
@ -452,7 +448,7 @@ class SearchHandler(BaseHandler):
if state_results: if state_results:
s = {} s = {}
for room_id, state in state_results.items(): for room_id, state in state_results.items():
s[room_id] = yield self._event_serializer.serialize_events( s[room_id] = await self._event_serializer.serialize_events(
state, time_now state, time_now
) )