Pass through room version to event auth

This commit is contained in:
Erik Johnston 2019-01-25 18:31:41 +00:00
parent b6dce9b9fd
commit ae2a957dba
10 changed files with 69 additions and 27 deletions

View File

@ -65,7 +65,7 @@ class Auth(object):
register_cache("cache", "token_cache", self.token_cache) register_cache("cache", "token_cache", self.token_cache)
@defer.inlineCallbacks @defer.inlineCallbacks
def check_from_context(self, event, context, do_sig_check=True): def check_from_context(self, room_version, event, context, do_sig_check=True):
prev_state_ids = yield context.get_prev_state_ids(self.store) prev_state_ids = yield context.get_prev_state_ids(self.store)
auth_events_ids = yield self.compute_auth_events( auth_events_ids = yield self.compute_auth_events(
event, prev_state_ids, for_verification=True, event, prev_state_ids, for_verification=True,
@ -74,12 +74,16 @@ class Auth(object):
auth_events = { auth_events = {
(e.type, e.state_key): e for e in itervalues(auth_events) (e.type, e.state_key): e for e in itervalues(auth_events)
} }
self.check(event, auth_events=auth_events, do_sig_check=do_sig_check) self.check(
room_version, event,
auth_events=auth_events, do_sig_check=do_sig_check,
)
def check(self, event, auth_events, do_sig_check=True): def check(self, room_version, event, auth_events, do_sig_check=True):
""" Checks if this event is correctly authed. """ Checks if this event is correctly authed.
Args: Args:
room_version (str): version of the room
event: the event being checked. event: the event being checked.
auth_events (dict: event-key -> event): the existing room state. auth_events (dict: event-key -> event): the existing room state.
@ -88,7 +92,9 @@ class Auth(object):
True if the auth checks pass. True if the auth checks pass.
""" """
with Measure(self.clock, "auth.check"): with Measure(self.clock, "auth.check"):
event_auth.check(event, auth_events, do_sig_check=do_sig_check) event_auth.check(
room_version, event, auth_events, do_sig_check=do_sig_check
)
@defer.inlineCallbacks @defer.inlineCallbacks
def check_joined_room(self, room_id, user_id, current_state=None): def check_joined_room(self, room_id, user_id, current_state=None):

View File

@ -27,10 +27,11 @@ from synapse.types import UserID, get_domain_from_id
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def check(event, auth_events, do_sig_check=True, do_size_check=True): def check(room_version, event, auth_events, do_sig_check=True, do_size_check=True):
""" Checks if this event is correctly authed. """ Checks if this event is correctly authed.
Args: Args:
room_version (str): the version of the room
event: the event being checked. event: the event being checked.
auth_events (dict: event-key -> event): the existing room state. auth_events (dict: event-key -> event): the existing room state.

View File

@ -1189,7 +1189,9 @@ class FederationHandler(BaseHandler):
# The remote hasn't signed it yet, obviously. We'll do the full checks # The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_join_request` # when we get the event back in `on_send_join_request`
yield self.auth.check_from_context(event, context, do_sig_check=False) yield self.auth.check_from_context(
room_version, event, context, do_sig_check=False,
)
defer.returnValue(event) defer.returnValue(event)
@ -1388,7 +1390,9 @@ class FederationHandler(BaseHandler):
try: try:
# The remote hasn't signed it yet, obviously. We'll do the full checks # The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_leave_request` # when we get the event back in `on_send_leave_request`
yield self.auth.check_from_context(event, context, do_sig_check=False) yield self.auth.check_from_context(
room_version, event, context, do_sig_check=False,
)
except AuthError as e: except AuthError as e:
logger.warn("Failed to create new leave %r because %s", event, e) logger.warn("Failed to create new leave %r because %s", event, e)
raise e raise e
@ -1683,7 +1687,7 @@ class FederationHandler(BaseHandler):
auth_for_e[(EventTypes.Create, "")] = create_event auth_for_e[(EventTypes.Create, "")] = create_event
try: try:
self.auth.check(e, auth_events=auth_for_e) self.auth.check(room_version, e, auth_events=auth_for_e)
except SynapseError as err: except SynapseError as err:
# we may get SynapseErrors here as well as AuthErrors. For # we may get SynapseErrors here as well as AuthErrors. For
# instance, there are a couple of (ancient) events in some # instance, there are a couple of (ancient) events in some
@ -1927,6 +1931,8 @@ class FederationHandler(BaseHandler):
current_state = set(e.event_id for e in auth_events.values()) current_state = set(e.event_id for e in auth_events.values())
different_auth = event_auth_events - current_state different_auth = event_auth_events - current_state
room_version = yield self.store.get_room_version(event.room_id)
if different_auth and not event.internal_metadata.is_outlier(): if different_auth and not event.internal_metadata.is_outlier():
# Do auth conflict res. # Do auth conflict res.
logger.info("Different auth: %s", different_auth) logger.info("Different auth: %s", different_auth)
@ -1951,8 +1957,6 @@ class FederationHandler(BaseHandler):
(d.type, d.state_key): d for d in different_events if d (d.type, d.state_key): d for d in different_events if d
}) })
room_version = yield self.store.get_room_version(event.room_id)
new_state = yield self.state_handler.resolve_events( new_state = yield self.state_handler.resolve_events(
room_version, room_version,
[list(local_view.values()), list(remote_view.values())], [list(local_view.values()), list(remote_view.values())],
@ -2052,7 +2056,7 @@ class FederationHandler(BaseHandler):
) )
try: try:
self.auth.check(event, auth_events=auth_events) self.auth.check(room_version, event, auth_events=auth_events)
except AuthError as e: except AuthError as e:
logger.warn("Failed auth resolution for %r because %s", event, e) logger.warn("Failed auth resolution for %r because %s", event, e)
raise e raise e
@ -2288,7 +2292,7 @@ class FederationHandler(BaseHandler):
) )
try: try:
yield self.auth.check_from_context(event, context) yield self.auth.check_from_context(room_version, event, context)
except AuthError as e: except AuthError as e:
logger.warn("Denying new third party invite %r because %s", event, e) logger.warn("Denying new third party invite %r because %s", event, e)
raise e raise e
@ -2330,7 +2334,7 @@ class FederationHandler(BaseHandler):
) )
try: try:
self.auth.check_from_context(event, context) self.auth.check_from_context(room_version, event, context)
except AuthError as e: except AuthError as e:
logger.warn("Denying third party invite %r because %s", event, e) logger.warn("Denying third party invite %r because %s", event, e)
raise e raise e

View File

@ -611,8 +611,13 @@ class EventCreationHandler(object):
extra_users (list(UserID)): Any extra users to notify about event extra_users (list(UserID)): Any extra users to notify about event
""" """
if event.is_state() and (event.type, event.state_key) == (EventTypes.Create, ""):
room_version = event.content["room_version"]
else:
room_version = yield self.store.get_room_version(event.room_id)
try: try:
yield self.auth.check_from_context(event, context) yield self.auth.check_from_context(room_version, event, context)
except AuthError as err: except AuthError as err:
logger.warn("Denying new event %r because %s", event, err) logger.warn("Denying new event %r because %s", event, err)
raise err raise err

View File

@ -123,7 +123,10 @@ class RoomCreationHandler(BaseHandler):
token_id=requester.access_token_id, token_id=requester.access_token_id,
) )
) )
yield self.auth.check_from_context(tombstone_event, tombstone_context) old_room_version = yield self.store.get_room_version(old_room_id)
yield self.auth.check_from_context(
old_room_version, tombstone_event, tombstone_context,
)
yield self.clone_exiting_room( yield self.clone_exiting_room(
requester, requester,

View File

@ -611,7 +611,7 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
RoomVersions.VDH_TEST, RoomVersions.STATE_V2_TEST, RoomVersions.V2, RoomVersions.VDH_TEST, RoomVersions.STATE_V2_TEST, RoomVersions.V2,
): ):
return v2.resolve_events_with_store( return v2.resolve_events_with_store(
state_sets, event_map, state_res_store, room_version, state_sets, event_map, state_res_store,
) )
else: else:
# This should only happen if we added a version but forgot to add it to # This should only happen if we added a version but forgot to add it to

View File

@ -21,7 +21,7 @@ from six import iteritems, iterkeys, itervalues
from twisted.internet import defer from twisted.internet import defer
from synapse import event_auth from synapse import event_auth
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes, RoomVersions
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -274,7 +274,11 @@ def _resolve_auth_events(events, auth_events):
auth_events[(prev_event.type, prev_event.state_key)] = prev_event auth_events[(prev_event.type, prev_event.state_key)] = prev_event
try: try:
# The signatures have already been checked at this point # The signatures have already been checked at this point
event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False) event_auth.check(
RoomVersions.V1, event, auth_events,
do_sig_check=False,
do_size_check=False,
)
prev_event = event prev_event = event
except AuthError: except AuthError:
return prev_event return prev_event
@ -286,7 +290,11 @@ def _resolve_normal_events(events, auth_events):
for event in _ordered_events(events): for event in _ordered_events(events):
try: try:
# The signatures have already been checked at this point # The signatures have already been checked at this point
event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False) event_auth.check(
RoomVersions.V1, event, auth_events,
do_sig_check=False,
do_size_check=False,
)
return event return event
except AuthError: except AuthError:
pass pass

View File

@ -29,10 +29,12 @@ logger = logging.getLogger(__name__)
@defer.inlineCallbacks @defer.inlineCallbacks
def resolve_events_with_store(state_sets, event_map, state_res_store): def resolve_events_with_store(room_version, state_sets, event_map, state_res_store):
"""Resolves the state using the v2 state resolution algorithm """Resolves the state using the v2 state resolution algorithm
Args: Args:
room_version (str): The room version
state_sets(list): List of dicts of (type, state_key) -> event_id, state_sets(list): List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve. which are the different state groups to resolve.
@ -104,7 +106,7 @@ def resolve_events_with_store(state_sets, event_map, state_res_store):
# Now sequentially auth each one # Now sequentially auth each one
resolved_state = yield _iterative_auth_checks( resolved_state = yield _iterative_auth_checks(
sorted_power_events, unconflicted_state, event_map, room_version, sorted_power_events, unconflicted_state, event_map,
state_res_store, state_res_store,
) )
@ -129,7 +131,7 @@ def resolve_events_with_store(state_sets, event_map, state_res_store):
logger.debug("resolving remaining events") logger.debug("resolving remaining events")
resolved_state = yield _iterative_auth_checks( resolved_state = yield _iterative_auth_checks(
leftover_events, resolved_state, event_map, room_version, leftover_events, resolved_state, event_map,
state_res_store, state_res_store,
) )
@ -350,11 +352,13 @@ def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_
@defer.inlineCallbacks @defer.inlineCallbacks
def _iterative_auth_checks(event_ids, base_state, event_map, state_res_store): def _iterative_auth_checks(room_version, event_ids, base_state, event_map,
state_res_store):
"""Sequentially apply auth checks to each event in given list, updating the """Sequentially apply auth checks to each event in given list, updating the
state as it goes along. state as it goes along.
Args: Args:
room_version (str)
event_ids (list[str]): Ordered list of events to apply auth checks to event_ids (list[str]): Ordered list of events to apply auth checks to
base_state (dict[tuple[str, str], str]): The set of state to start with base_state (dict[tuple[str, str], str]): The set of state to start with
event_map (dict[str,FrozenEvent]) event_map (dict[str,FrozenEvent])
@ -385,7 +389,7 @@ def _iterative_auth_checks(event_ids, base_state, event_map, state_res_store):
try: try:
event_auth.check( event_auth.check(
event, auth_events, room_version, event, auth_events,
do_sig_check=False, do_sig_check=False,
do_size_check=False do_size_check=False
) )

View File

@ -19,7 +19,7 @@ from six.moves import zip
import attr import attr
from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.api.constants import EventTypes, JoinRules, Membership, RoomVersions
from synapse.event_auth import auth_types_for_event from synapse.event_auth import auth_types_for_event
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
from synapse.state.v2 import lexicographical_topological_sort, resolve_events_with_store from synapse.state.v2 import lexicographical_topological_sort, resolve_events_with_store
@ -539,6 +539,7 @@ class StateTestCase(unittest.TestCase):
state_before = dict(state_at_event[prev_events[0]]) state_before = dict(state_at_event[prev_events[0]])
else: else:
state_d = resolve_events_with_store( state_d = resolve_events_with_store(
RoomVersions.V2,
[state_at_event[n] for n in prev_events], [state_at_event[n] for n in prev_events],
event_map=event_map, event_map=event_map,
state_res_store=TestStateResolutionStore(event_map), state_res_store=TestStateResolutionStore(event_map),
@ -685,6 +686,7 @@ class SimpleParamStateTestCase(unittest.TestCase):
# Test that we correctly handle passing `None` as the event_map # Test that we correctly handle passing `None` as the event_map
state_d = resolve_events_with_store( state_d = resolve_events_with_store(
RoomVersions.V2,
[self.state_at_bob, self.state_at_charlie], [self.state_at_bob, self.state_at_charlie],
event_map=None, event_map=None,
state_res_store=TestStateResolutionStore(self.event_map), state_res_store=TestStateResolutionStore(self.event_map),

View File

@ -16,6 +16,7 @@
import unittest import unittest
from synapse import event_auth from synapse import event_auth
from synapse.api.constants import RoomVersions
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
@ -35,12 +36,16 @@ class EventAuthTestCase(unittest.TestCase):
} }
# creator should be able to send state # creator should be able to send state
event_auth.check(_random_state_event(creator), auth_events, do_sig_check=False) event_auth.check(
RoomVersions.V1, _random_state_event(creator), auth_events,
do_sig_check=False,
)
# joiner should not be able to send state # joiner should not be able to send state
self.assertRaises( self.assertRaises(
AuthError, AuthError,
event_auth.check, event_auth.check,
RoomVersions.V1,
_random_state_event(joiner), _random_state_event(joiner),
auth_events, auth_events,
do_sig_check=False, do_sig_check=False,
@ -69,13 +74,17 @@ class EventAuthTestCase(unittest.TestCase):
self.assertRaises( self.assertRaises(
AuthError, AuthError,
event_auth.check, event_auth.check,
RoomVersions.V1,
_random_state_event(pleb), _random_state_event(pleb),
auth_events, auth_events,
do_sig_check=False, do_sig_check=False,
), ),
# king should be able to send state # king should be able to send state
event_auth.check(_random_state_event(king), auth_events, do_sig_check=False) event_auth.check(
RoomVersions.V1, _random_state_event(king), auth_events,
do_sig_check=False,
)
# helpers for making events # helpers for making events