Convert synapse.api to async/await (#8031)

This commit is contained in:
Patrick Cloke 2020-08-06 08:30:06 -04:00 committed by GitHub
parent c36228c403
commit d4a7829b12
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 171 additions and 159 deletions

View file

@ -13,12 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Optional
from typing import List, Optional, Tuple
import pymacaroons
from netaddr import IPAddress
from twisted.internet import defer
from twisted.web.server import Request
import synapse.types
@ -80,13 +79,14 @@ class Auth(object):
self._track_appservice_user_ips = hs.config.track_appservice_user_ips
self._macaroon_secret_key = hs.config.macaroon_secret_key
@defer.inlineCallbacks
def check_from_context(self, room_version: str, event, context, do_sig_check=True):
prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
auth_events_ids = yield self.compute_auth_events(
async def check_from_context(
self, room_version: str, event, context, do_sig_check=True
):
prev_state_ids = await context.get_prev_state_ids()
auth_events_ids = self.compute_auth_events(
event, prev_state_ids, for_verification=True
)
auth_events = yield self.store.get_events(auth_events_ids)
auth_events = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
@ -94,14 +94,13 @@ class Auth(object):
room_version_obj, event, auth_events=auth_events, do_sig_check=do_sig_check
)
@defer.inlineCallbacks
def check_user_in_room(
async def check_user_in_room(
self,
room_id: str,
user_id: str,
current_state: Optional[StateMap[EventBase]] = None,
allow_departed_users: bool = False,
):
) -> EventBase:
"""Check if the user is in the room, or was at some point.
Args:
room_id: The room to check.
@ -119,37 +118,35 @@ class Auth(object):
Raises:
AuthError if the user is/was not in the room.
Returns:
Deferred[Optional[EventBase]]:
Membership event for the user if the user was in the
room. This will be the join event if they are currently joined to
the room. This will be the leave event if they have left the room.
Membership event for the user if the user was in the
room. This will be the join event if they are currently joined to
the room. This will be the leave event if they have left the room.
"""
if current_state:
member = current_state.get((EventTypes.Member, user_id), None)
else:
member = yield defer.ensureDeferred(
self.state.get_current_state(
room_id=room_id, event_type=EventTypes.Member, state_key=user_id
)
member = await self.state.get_current_state(
room_id=room_id, event_type=EventTypes.Member, state_key=user_id
)
membership = member.membership if member else None
if membership == Membership.JOIN:
return member
if member:
membership = member.membership
# XXX this looks totally bogus. Why do we not allow users who have been banned,
# or those who were members previously and have been re-invited?
if allow_departed_users and membership == Membership.LEAVE:
forgot = yield self.store.did_forget(user_id, room_id)
if not forgot:
if membership == Membership.JOIN:
return member
# XXX this looks totally bogus. Why do we not allow users who have been banned,
# or those who were members previously and have been re-invited?
if allow_departed_users and membership == Membership.LEAVE:
forgot = await self.store.did_forget(user_id, room_id)
if not forgot:
return member
raise AuthError(403, "User %s not in room %s" % (user_id, room_id))
@defer.inlineCallbacks
def check_host_in_room(self, room_id, host):
async def check_host_in_room(self, room_id, host):
with Measure(self.clock, "check_host_in_room"):
latest_event_ids = yield self.store.is_host_joined(room_id, host)
latest_event_ids = await self.store.is_host_joined(room_id, host)
return latest_event_ids
def can_federate(self, event, auth_events):
@ -160,14 +157,13 @@ class Auth(object):
def get_public_keys(self, invite_event):
return event_auth.get_public_keys(invite_event)
@defer.inlineCallbacks
def get_user_by_req(
async def get_user_by_req(
self,
request: Request,
allow_guest: bool = False,
rights: str = "access",
allow_expired: bool = False,
):
) -> synapse.types.Requester:
""" Get a registered user's ID.
Args:
@ -180,7 +176,7 @@ class Auth(object):
/login will deliver access tokens regardless of expiration.
Returns:
defer.Deferred: resolves to a `synapse.types.Requester` object
Resolves to the requester
Raises:
InvalidClientCredentialsError if no user by that token exists or the token
is invalid.
@ -194,14 +190,14 @@ class Auth(object):
access_token = self.get_access_token_from_request(request)
user_id, app_service = yield self._get_appservice_user_id(request)
user_id, app_service = await self._get_appservice_user_id(request)
if user_id:
request.authenticated_entity = user_id
opentracing.set_tag("authenticated_entity", user_id)
opentracing.set_tag("appservice_id", app_service.id)
if ip_addr and self._track_appservice_user_ips:
yield self.store.insert_client_ip(
await self.store.insert_client_ip(
user_id=user_id,
access_token=access_token,
ip=ip_addr,
@ -211,7 +207,7 @@ class Auth(object):
return synapse.types.create_requester(user_id, app_service=app_service)
user_info = yield self.get_user_by_access_token(
user_info = await self.get_user_by_access_token(
access_token, rights, allow_expired=allow_expired
)
user = user_info["user"]
@ -221,7 +217,7 @@ class Auth(object):
# Deny the request if the user account has expired.
if self._account_validity.enabled and not allow_expired:
user_id = user.to_string()
expiration_ts = yield self.store.get_expiration_ts_for_user(user_id)
expiration_ts = await self.store.get_expiration_ts_for_user(user_id)
if (
expiration_ts is not None
and self.clock.time_msec() >= expiration_ts
@ -235,7 +231,7 @@ class Auth(object):
device_id = user_info.get("device_id")
if user and access_token and ip_addr:
yield self.store.insert_client_ip(
await self.store.insert_client_ip(
user_id=user.to_string(),
access_token=access_token,
ip=ip_addr,
@ -261,8 +257,7 @@ class Auth(object):
except KeyError:
raise MissingClientTokenError()
@defer.inlineCallbacks
def _get_appservice_user_id(self, request):
async def _get_appservice_user_id(self, request):
app_service = self.store.get_app_service_by_token(
self.get_access_token_from_request(request)
)
@ -283,14 +278,13 @@ class Auth(object):
if not app_service.is_interested_in_user(user_id):
raise AuthError(403, "Application service cannot masquerade as this user.")
if not (yield self.store.get_user_by_id(user_id)):
if not (await self.store.get_user_by_id(user_id)):
raise AuthError(403, "Application service has not registered this user")
return user_id, app_service
@defer.inlineCallbacks
def get_user_by_access_token(
async def get_user_by_access_token(
self, token: str, rights: str = "access", allow_expired: bool = False,
):
) -> dict:
""" Validate access token and get user_id from it
Args:
@ -300,7 +294,7 @@ class Auth(object):
allow_expired: If False, raises an InvalidClientTokenError
if the token is expired
Returns:
Deferred[dict]: dict that includes:
dict that includes:
`user` (UserID)
`is_guest` (bool)
`token_id` (int|None): access token id. May be None if guest
@ -314,7 +308,7 @@ class Auth(object):
if rights == "access":
# first look in the database
r = yield self._look_up_user_by_access_token(token)
r = await self._look_up_user_by_access_token(token)
if r:
valid_until_ms = r["valid_until_ms"]
if (
@ -352,7 +346,7 @@ class Auth(object):
# It would of course be much easier to store guest access
# tokens in the database as well, but that would break existing
# guest tokens.
stored_user = yield self.store.get_user_by_id(user_id)
stored_user = await self.store.get_user_by_id(user_id)
if not stored_user:
raise InvalidClientTokenError("Unknown user_id %s" % user_id)
if not stored_user["is_guest"]:
@ -482,9 +476,8 @@ class Auth(object):
now = self.hs.get_clock().time_msec()
return now < expiry
@defer.inlineCallbacks
def _look_up_user_by_access_token(self, token):
ret = yield self.store.get_user_by_access_token(token)
async def _look_up_user_by_access_token(self, token):
ret = await self.store.get_user_by_access_token(token)
if not ret:
return None
@ -507,7 +500,7 @@ class Auth(object):
logger.warning("Unrecognised appservice access token.")
raise InvalidClientTokenError()
request.authenticated_entity = service.sender
return defer.succeed(service)
return service
async def is_server_admin(self, user: UserID) -> bool:
""" Check if the given user is a local server admin.
@ -522,7 +515,7 @@ class Auth(object):
def compute_auth_events(
self, event, current_state_ids: StateMap[str], for_verification: bool = False,
):
) -> List[str]:
"""Given an event and current state return the list of event IDs used
to auth an event.
@ -530,11 +523,11 @@ class Auth(object):
should be added to the event's `auth_events`.
Returns:
defer.Deferred(list[str]): List of event IDs.
List of event IDs.
"""
if event.type == EventTypes.Create:
return defer.succeed([])
return []
# Currently we ignore the `for_verification` flag even though there are
# some situations where we can drop particular auth events when adding
@ -553,7 +546,7 @@ class Auth(object):
if auth_ev_id:
auth_ids.append(auth_ev_id)
return defer.succeed(auth_ids)
return auth_ids
async def check_can_change_room_list(self, room_id: str, user: UserID):
"""Determine whether the user is allowed to edit the room's entry in the
@ -636,10 +629,9 @@ class Auth(object):
return query_params[0].decode("ascii")
@defer.inlineCallbacks
def check_user_in_room_or_world_readable(
async def check_user_in_room_or_world_readable(
self, room_id: str, user_id: str, allow_departed_users: bool = False
):
) -> Tuple[str, Optional[str]]:
"""Checks that the user is or was in the room or the room is world
readable. If it isn't then an exception is raised.
@ -650,10 +642,9 @@ class Auth(object):
members but have now departed
Returns:
Deferred[tuple[str, str|None]]: Resolves to the current membership of
the user in the room and the membership event ID of the user. If
the user is not in the room and never has been, then
`(Membership.JOIN, None)` is returned.
Resolves to the current membership of the user in the room and the
membership event ID of the user. If the user is not in the room and
never has been, then `(Membership.JOIN, None)` is returned.
"""
try:
@ -662,15 +653,13 @@ class Auth(object):
# * The user is a non-guest user, and was ever in the room
# * The user is a guest user, and has joined the room
# else it will throw.
member_event = yield self.check_user_in_room(
member_event = await self.check_user_in_room(
room_id, user_id, allow_departed_users=allow_departed_users
)
return member_event.membership, member_event.event_id
except AuthError:
visibility = yield defer.ensureDeferred(
self.state.get_current_state(
room_id, EventTypes.RoomHistoryVisibility, ""
)
visibility = await self.state.get_current_state(
room_id, EventTypes.RoomHistoryVisibility, ""
)
if (
visibility