Add early returns to _check_for_soft_fail (#7769)

my editor was complaining about unset variables, so let's add some early
returns to fix that and reduce indentation/cognitive load.
This commit is contained in:
Richard van der Hoff 2020-07-01 16:41:19 +01:00 committed by GitHub
parent f01e2ca039
commit e866512367
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 53 additions and 61 deletions

1
changelog.d/7769.misc Normal file
View File

@ -0,0 +1 @@
Add early returns to `_check_for_soft_fail`.

View File

@ -2061,76 +2061,67 @@ class FederationHandler(BaseHandler):
# For new (non-backfilled and non-outlier) events we check if the event # For new (non-backfilled and non-outlier) events we check if the event
# passes auth based on the current state. If it doesn't then we # passes auth based on the current state. If it doesn't then we
# "soft-fail" the event. # "soft-fail" the event.
do_soft_fail_check = not backfilled and not event.internal_metadata.is_outlier() if backfilled or event.internal_metadata.is_outlier():
if do_soft_fail_check: return
extrem_ids = await self.store.get_latest_event_ids_in_room(event.room_id)
extrem_ids = set(extrem_ids) extrem_ids = await self.store.get_latest_event_ids_in_room(event.room_id)
prev_event_ids = set(event.prev_event_ids()) extrem_ids = set(extrem_ids)
prev_event_ids = set(event.prev_event_ids())
if extrem_ids == prev_event_ids: if extrem_ids == prev_event_ids:
# If they're the same then the current state is the same as the # If they're the same then the current state is the same as the
# state at the event, so no point rechecking auth for soft fail. # state at the event, so no point rechecking auth for soft fail.
do_soft_fail_check = False return
if do_soft_fail_check: room_version = await self.store.get_room_version_id(event.room_id)
room_version = await self.store.get_room_version_id(event.room_id) room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
# Calculate the "current state". # Calculate the "current state".
if state is not None: if state is not None:
# If we're explicitly given the state then we won't have all the # If we're explicitly given the state then we won't have all the
# prev events, and so we have a gap in the graph. In this case # prev events, and so we have a gap in the graph. In this case
# we want to be a little careful as we might have been down for # we want to be a little careful as we might have been down for
# a while and have an incorrect view of the current state, # a while and have an incorrect view of the current state,
# however we still want to do checks as gaps are easy to # however we still want to do checks as gaps are easy to
# maliciously manufacture. # maliciously manufacture.
# #
# So we use a "current state" that is actually a state # So we use a "current state" that is actually a state
# resolution across the current forward extremities and the # resolution across the current forward extremities and the
# given state at the event. This should correctly handle cases # given state at the event. This should correctly handle cases
# like bans, especially with state res v2. # like bans, especially with state res v2.
state_sets = await self.state_store.get_state_groups( state_sets = await self.state_store.get_state_groups(
event.room_id, extrem_ids event.room_id, extrem_ids
) )
state_sets = list(state_sets.values()) state_sets = list(state_sets.values())
state_sets.append(state) state_sets.append(state)
current_state_ids = await self.state_handler.resolve_events( current_state_ids = await self.state_handler.resolve_events(
room_version, state_sets, event room_version, state_sets, event
) )
current_state_ids = { current_state_ids = {k: e.event_id for k, e in current_state_ids.items()}
k: e.event_id for k, e in current_state_ids.items() else:
} current_state_ids = await self.state_handler.get_current_state_ids(
else: event.room_id, latest_event_ids=extrem_ids
current_state_ids = await self.state_handler.get_current_state_ids(
event.room_id, latest_event_ids=extrem_ids
)
logger.debug(
"Doing soft-fail check for %s: state %s",
event.event_id,
current_state_ids,
) )
# Now check if event pass auth against said current state logger.debug(
auth_types = auth_types_for_event(event) "Doing soft-fail check for %s: state %s", event.event_id, current_state_ids,
current_state_ids = [ )
e for k, e in current_state_ids.items() if k in auth_types
]
current_auth_events = await self.store.get_events(current_state_ids) # Now check if event pass auth against said current state
current_auth_events = { auth_types = auth_types_for_event(event)
(e.type, e.state_key): e for e in current_auth_events.values() current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types]
}
try: current_auth_events = await self.store.get_events(current_state_ids)
event_auth.check( current_auth_events = {
room_version_obj, event, auth_events=current_auth_events (e.type, e.state_key): e for e in current_auth_events.values()
) }
except AuthError as e:
logger.warning("Soft-failing %r because %s", event, e) try:
event.internal_metadata.soft_failed = True event_auth.check(room_version_obj, event, auth_events=current_auth_events)
except AuthError as e:
logger.warning("Soft-failing %r because %s", event, e)
event.internal_metadata.soft_failed = True
async def on_query_auth( async def on_query_auth(
self, origin, event_id, room_id, remote_auth_chain, rejects, missing self, origin, event_id, room_id, remote_auth_chain, rejects, missing