diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index f7b948300..939418ee2 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -44,6 +44,7 @@ from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet from synapse.rest.admin.users import UserAdminServlet from synapse.types import UserID, create_requester +from synapse.util.async_helpers import maybe_awaitable from synapse.util.versionstring import get_version_string logger = logging.getLogger(__name__) @@ -310,7 +311,7 @@ class PurgeHistoryRestServlet(RestServlet): errcode=Codes.BAD_JSON, ) - purge_id = await self.pagination_handler.start_purge_history( + purge_id = self.pagination_handler.start_purge_history( room_id, token, delete_local_events=delete_local_events ) @@ -480,7 +481,9 @@ class ShutdownRoomRestServlet(RestServlet): ratelimit=False, ) - aliases_for_room = await self.store.get_aliases_for_room(room_id) + aliases_for_room = await maybe_awaitable( + self.store.get_aliases_for_room(room_id) + ) await self.store.update_aliases_for_room( room_id, new_room_id, requester_user_id diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 0d3bdd88c..804dbca44 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -21,6 +21,8 @@ from typing import Dict, Sequence, Set, Union from six.moves import range +import attr + from twisted.internet import defer from twisted.internet.defer import CancelledError from twisted.python import failure @@ -483,3 +485,30 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None): deferred.addCallbacks(success_cb, failure_cb) return new_d + + +@attr.s(slots=True, frozen=True) +class DoneAwaitable(object): + """Simple awaitable that returns the provided value. + """ + + value = attr.ib() + + def __await__(self): + return self + + def __iter__(self): + return self + + def __next__(self): + raise StopIteration(self.value) + + +def maybe_awaitable(value): + """Convert a value to an awaitable if not already an awaitable. + """ + + if hasattr(value, "__await__"): + return value + + return DoneAwaitable(value)