Merge pull request #1818 from matrix-org/erikj/state_auth_splitout_split

Optimise state resolution
This commit is contained in:
Erik Johnston 2017-01-18 10:53:00 +00:00 committed by GitHub
commit 15f012032c
8 changed files with 249 additions and 76 deletions

View File

@ -27,7 +27,7 @@ from synapse.types import UserID, get_domain_from_id
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def check(event, auth_events, do_sig_check=True): def check(event, auth_events, do_sig_check=True, do_size_check=True):
""" Checks if this event is correctly authed. """ Checks if this event is correctly authed.
Args: Args:
@ -38,6 +38,7 @@ def check(event, auth_events, do_sig_check=True):
Returns: Returns:
True if the auth checks pass. True if the auth checks pass.
""" """
if do_size_check:
_check_size_limits(event) _check_size_limits(event)
if not hasattr(event, "room_id"): if not hasattr(event, "room_id"):
@ -119,6 +120,7 @@ def check(event, auth_events, do_sig_check=True):
) )
return True return True
if logger.isEnabledFor(logging.DEBUG):
logger.debug( logger.debug(
"Auth events: %s", "Auth events: %s",
[a.event_id for a in auth_events.values()] [a.event_id for a in auth_events.values()]
@ -639,3 +641,38 @@ def get_public_keys(invite_event):
public_keys.append(o) public_keys.append(o)
public_keys.extend(invite_event.content.get("public_keys", [])) public_keys.extend(invite_event.content.get("public_keys", []))
return public_keys return public_keys
def auth_types_for_event(event):
"""Given an event, return a list of (EventType, StateKey) that may be
needed to auth the event. The returned list may be a superset of what
would actually be required depending on the full state of the room.
Used to limit the number of events to fetch from the database to
actually auth the event.
"""
if event.type == EventTypes.Create:
return []
auth_types = []
auth_types.append((EventTypes.PowerLevels, "", ))
auth_types.append((EventTypes.Member, event.user_id, ))
auth_types.append((EventTypes.Create, "", ))
if event.type == EventTypes.Member:
membership = event.content["membership"]
if membership in [Membership.JOIN, Membership.INVITE]:
auth_types.append((EventTypes.JoinRules, "", ))
auth_types.append((EventTypes.Member, event.state_key, ))
if membership == Membership.INVITE:
if "third_party_invite" in event.content:
key = (
EventTypes.ThirdPartyInvite,
event.content["third_party_invite"]["signed"]["token"]
)
auth_types.append(key)
return auth_types

View File

@ -79,7 +79,6 @@ class EventBase(object):
auth_events = _event_dict_property("auth_events") auth_events = _event_dict_property("auth_events")
depth = _event_dict_property("depth") depth = _event_dict_property("depth")
content = _event_dict_property("content") content = _event_dict_property("content")
event_id = _event_dict_property("event_id")
hashes = _event_dict_property("hashes") hashes = _event_dict_property("hashes")
origin = _event_dict_property("origin") origin = _event_dict_property("origin")
origin_server_ts = _event_dict_property("origin_server_ts") origin_server_ts = _event_dict_property("origin_server_ts")
@ -88,8 +87,6 @@ class EventBase(object):
redacts = _event_dict_property("redacts") redacts = _event_dict_property("redacts")
room_id = _event_dict_property("room_id") room_id = _event_dict_property("room_id")
sender = _event_dict_property("sender") sender = _event_dict_property("sender")
state_key = _event_dict_property("state_key")
type = _event_dict_property("type")
user_id = _event_dict_property("sender") user_id = _event_dict_property("sender")
@property @property
@ -162,6 +159,11 @@ class FrozenEvent(EventBase):
else: else:
frozen_dict = event_dict frozen_dict = event_dict
self.event_id = event_dict["event_id"]
self.type = event_dict["type"]
if "state_key" in event_dict:
self.state_key = event_dict["state_key"]
super(FrozenEvent, self).__init__( super(FrozenEvent, self).__init__(
frozen_dict, frozen_dict,
signatures=signatures, signatures=signatures,

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from . import EventBase, FrozenEvent from . import EventBase, FrozenEvent, _event_dict_property
from synapse.types import EventID from synapse.types import EventID
@ -34,6 +34,10 @@ class EventBuilder(EventBase):
internal_metadata_dict=internal_metadata_dict, internal_metadata_dict=internal_metadata_dict,
) )
event_id = _event_dict_property("event_id")
state_key = _event_dict_property("state_key")
type = _event_dict_property("type")
def build(self): def build(self):
return FrozenEvent.from_event(self) return FrozenEvent.from_event(self)

View File

@ -26,7 +26,7 @@ from synapse.util import unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.events import FrozenEvent from synapse.events import FrozenEvent, builder
import synapse.metrics import synapse.metrics
from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
@ -499,8 +499,10 @@ class FederationClient(FederationBase):
if "prev_state" not in pdu_dict: if "prev_state" not in pdu_dict:
pdu_dict["prev_state"] = [] pdu_dict["prev_state"] = []
ev = builder.EventBuilder(pdu_dict)
defer.returnValue( defer.returnValue(
(destination, self.event_from_pdu_json(pdu_dict)) (destination, ev)
) )
break break
except CodeMessageException as e: except CodeMessageException as e:

View File

@ -596,7 +596,7 @@ class FederationHandler(BaseHandler):
preserve_fn(self.state_handler.resolve_state_groups)(room_id, [e]) preserve_fn(self.state_handler.resolve_state_groups)(room_id, [e])
for e in event_ids for e in event_ids
])) ]))
states = dict(zip(event_ids, [s[1] for s in states])) states = dict(zip(event_ids, [s.state for s in states]))
state_map = yield self.store.get_events( state_map = yield self.store.get_events(
[e_id for ids in states.values() for e_id in ids], [e_id for ids in states.values() for e_id in ids],
@ -1530,7 +1530,7 @@ class FederationHandler(BaseHandler):
(d.type, d.state_key): d for d in different_events if d (d.type, d.state_key): d for d in different_events if d
}) })
new_state, prev_state = self.state_handler.resolve_events( new_state = self.state_handler.resolve_events(
[local_view.values(), remote_view.values()], [local_view.values(), remote_view.values()],
event event
) )

View File

@ -22,7 +22,6 @@ from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.api.auth import AuthEventTypes
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.util.async import Linearizer from synapse.util.async import Linearizer
@ -48,6 +47,8 @@ EVICTION_TIMEOUT_SECONDS = 60 * 60
_NEXT_STATE_ID = 1 _NEXT_STATE_ID = 1
POWER_KEY = (EventTypes.PowerLevels, "")
def _gen_state_id(): def _gen_state_id():
global _NEXT_STATE_ID global _NEXT_STATE_ID
@ -332,21 +333,13 @@ class StateHandler(object):
if conflicted_state: if conflicted_state:
logger.info("Resolving conflicted state for %r", room_id) logger.info("Resolving conflicted state for %r", room_id)
state_map = yield self.store.get_events(
[e_id for st in state_groups_ids.values() for e_id in st.values()],
get_prev_content=False
)
state_sets = [
[state_map[e_id] for key, e_id in st.items() if e_id in state_map]
for st in state_groups_ids.values()
]
with Measure(self.clock, "state._resolve_events"): with Measure(self.clock, "state._resolve_events"):
new_state, _ = resolve_events( new_state = yield resolve_events(
state_sets, event_type, state_key state_groups_ids.values(),
state_map_factory=lambda ev_ids: self.store.get_events(
ev_ids, get_prev_content=False, check_redacted=False,
),
) )
new_state = {
key: e.event_id for key, e in new_state.items()
}
else: else:
new_state = { new_state = {
key: e_ids.pop() for key, e_ids in state.items() key: e_ids.pop() for key, e_ids in state.items()
@ -394,13 +387,25 @@ class StateHandler(object):
logger.info( logger.info(
"Resolving state for %s with %d groups", event.room_id, len(state_sets) "Resolving state for %s with %d groups", event.room_id, len(state_sets)
) )
state_set_ids = [{
(ev.type, ev.state_key): ev.event_id
for ev in st
} for st in state_sets]
state_map = {
ev.event_id: ev
for st in state_sets
for ev in st
}
with Measure(self.clock, "state._resolve_events"): with Measure(self.clock, "state._resolve_events"):
if event.is_state(): new_state = resolve_events(state_set_ids, state_map)
return resolve_events(
state_sets, event.type, event.state_key new_state = {
) key: state_map[ev_id] for key, ev_id in new_state.items()
else: }
return resolve_events(state_sets)
return new_state
def _ordered_events(events): def _ordered_events(events):
@ -410,43 +415,131 @@ def _ordered_events(events):
return sorted(events, key=key_func) return sorted(events, key=key_func)
def resolve_events(state_sets, event_type=None, state_key=""): def resolve_events(state_sets, state_map_factory):
""" """
Args:
state_sets(list): List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve.
state_map_factory(dict|callable): If callable, then will be called
with a list of event_ids that are needed, and should return with
a Deferred of dict of event_id to event. Otherwise, should be
a dict from event_id to event of all events in state_sets.
Returns Returns
(dict[(str, str), synapse.events.FrozenEvent], list[str]): a tuple dict[(str, str), synapse.events.FrozenEvent] is a map from
(new_state, prev_states). new_state is a map from (type, state_key) (type, state_key) to event.
to event. prev_states is a list of event_ids.
""" """
state = {} unconflicted_state, conflicted_state = _seperate(
for st in state_sets: state_sets,
for e in st:
state.setdefault(
(e.type, e.state_key),
{}
)[e.event_id] = e
unconflicted_state = {
k: v.values()[0] for k, v in state.items()
if len(v.values()) == 1
}
conflicted_state = {
k: v.values()
for k, v in state.items()
if len(v.values()) > 1
}
if event_type:
prev_states_events = conflicted_state.get(
(event_type, state_key), []
) )
prev_states = [s.event_id for s in prev_states_events]
if callable(state_map_factory):
return _resolve_with_state_fac(
unconflicted_state, conflicted_state, state_map_factory
)
state_map = state_map_factory
auth_events = _create_auth_events_from_maps(
unconflicted_state, conflicted_state, state_map
)
return _resolve_with_state(
unconflicted_state, conflicted_state, auth_events, state_map
)
def _seperate(state_sets):
"""Takes the state_sets and figures out which keys are conflicted and
which aren't. i.e., which have multiple different event_ids associated
with them in different state sets.
"""
unconflicted_state = dict(state_sets[0])
conflicted_state = {}
for state_set in state_sets[1:]:
for key, value in state_set.iteritems():
# Check if there is an unconflicted entry for the state key.
unconflicted_value = unconflicted_state.get(key)
if unconflicted_value is None:
# There isn't an unconflicted entry so check if there is a
# conflicted entry.
ls = conflicted_state.get(key)
if ls is None:
# There wasn't a conflicted entry so haven't seen this key before.
# Therefore it isn't conflicted yet.
unconflicted_state[key] = value
else: else:
prev_states = [] # This key is already conflicted, add our value to the conflict set.
ls.add(value)
elif unconflicted_value != value:
# If the unconflicted value is not the same as our value then we
# have a new conflict. So move the key from the unconflicted_state
# to the conflicted state.
conflicted_state[key] = {value, unconflicted_value}
unconflicted_state.pop(key, None)
return unconflicted_state, conflicted_state
@defer.inlineCallbacks
def _resolve_with_state_fac(unconflicted_state, conflicted_state,
state_map_factory):
needed_events = set(
event_id
for event_ids in conflicted_state.itervalues()
for event_id in event_ids
)
logger.info("Asking for %d conflicted events", len(needed_events))
state_map = yield state_map_factory(needed_events)
auth_events = _create_auth_events_from_maps(
unconflicted_state, conflicted_state, state_map
)
new_needed_events = set(auth_events.itervalues())
new_needed_events -= needed_events
logger.info("Asking for %d auth events", len(new_needed_events))
state_map_new = yield state_map_factory(new_needed_events)
state_map.update(state_map_new)
defer.returnValue(_resolve_with_state(
unconflicted_state, conflicted_state, auth_events, state_map
))
def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map):
auth_events = {}
for event_ids in conflicted_state.itervalues():
for event_id in event_ids:
if event_id in state_map:
keys = event_auth.auth_types_for_event(state_map[event_id])
for key in keys:
if key not in auth_events:
event_id = unconflicted_state.get(key, None)
if event_id:
auth_events[key] = event_id
return auth_events
def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_ids,
state_map):
conflicted_state = {}
for key, event_ids in conflicted_state_ds.iteritems():
events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map]
if len(events) > 1:
conflicted_state[key] = events
elif len(events) == 1:
unconflicted_state_ids[key] = events[0].event_id
auth_events = { auth_events = {
k: e for k, e in unconflicted_state.items() key: state_map[ev_id]
if k[0] in AuthEventTypes for key, ev_id in auth_event_ids.items()
if ev_id in state_map
} }
try: try:
@ -457,10 +550,11 @@ def resolve_events(state_sets, event_type=None, state_key=""):
logger.exception("Failed to resolve state") logger.exception("Failed to resolve state")
raise raise
new_state = unconflicted_state new_state = unconflicted_state_ids
new_state.update(resolved_state) for key, event in resolved_state.iteritems():
new_state[key] = event.event_id
return new_state, prev_states return new_state
def _resolve_state_events(conflicted_state, auth_events): def _resolve_state_events(conflicted_state, auth_events):
@ -474,11 +568,10 @@ def _resolve_state_events(conflicted_state, auth_events):
4. other events. 4. other events.
""" """
resolved_state = {} resolved_state = {}
power_key = (EventTypes.PowerLevels, "") if POWER_KEY in conflicted_state:
if power_key in conflicted_state: events = conflicted_state[POWER_KEY]
events = conflicted_state[power_key]
logger.debug("Resolving conflicted power levels %r", events) logger.debug("Resolving conflicted power levels %r", events)
resolved_state[power_key] = _resolve_auth_events( resolved_state[POWER_KEY] = _resolve_auth_events(
events, auth_events) events, auth_events)
auth_events.update(resolved_state) auth_events.update(resolved_state)
@ -516,14 +609,26 @@ def _resolve_state_events(conflicted_state, auth_events):
def _resolve_auth_events(events, auth_events): def _resolve_auth_events(events, auth_events):
reverse = [i for i in reversed(_ordered_events(events))] reverse = [i for i in reversed(_ordered_events(events))]
auth_events = dict(auth_events) auth_keys = set(
key
for event in events
for key in event_auth.auth_types_for_event(event)
)
new_auth_events = {}
for key in auth_keys:
auth_event = auth_events.get(key, None)
if auth_event:
new_auth_events[key] = auth_event
auth_events = new_auth_events
prev_event = reverse[0] prev_event = reverse[0]
for event in reverse[1:]: for event in reverse[1:]:
auth_events[(prev_event.type, prev_event.state_key)] = prev_event auth_events[(prev_event.type, prev_event.state_key)] = prev_event
try: try:
# The signatures have already been checked at this point # The signatures have already been checked at this point
event_auth.check(event, auth_events, do_sig_check=False) event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False)
prev_event = event prev_event = event
except AuthError: except AuthError:
return prev_event return prev_event
@ -535,7 +640,7 @@ def _resolve_normal_events(events, auth_events):
for event in _ordered_events(events): for event in _ordered_events(events):
try: try:
# The signatures have already been checked at this point # The signatures have already been checked at this point
event_auth.check(event, auth_events, do_sig_check=False) event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False)
return event return event
except AuthError: except AuthError:
pass pass

View File

@ -25,10 +25,13 @@ from synapse.api.filtering import Filter
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
user_localpart = "test_user" user_localpart = "test_user"
# MockEvent = namedtuple("MockEvent", "sender type room_id")
def MockEvent(**kwargs): def MockEvent(**kwargs):
if "event_id" not in kwargs:
kwargs["event_id"] = "fake_event_id"
if "type" not in kwargs:
kwargs["type"] = "fake_type"
return FrozenEvent(kwargs) return FrozenEvent(kwargs)

View File

@ -21,6 +21,10 @@ from synapse.events.utils import prune_event, serialize_event
def MockEvent(**kwargs): def MockEvent(**kwargs):
if "event_id" not in kwargs:
kwargs["event_id"] = "fake_event_id"
if "type" not in kwargs:
kwargs["type"] = "fake_type"
return FrozenEvent(kwargs) return FrozenEvent(kwargs)
@ -35,9 +39,13 @@ class PruneEventTestCase(unittest.TestCase):
def test_minimal(self): def test_minimal(self):
self.run_test( self.run_test(
{'type': 'A'},
{ {
'type': 'A', 'type': 'A',
'event_id': '$test:domain',
},
{
'type': 'A',
'event_id': '$test:domain',
'content': {}, 'content': {},
'signatures': {}, 'signatures': {},
'unsigned': {}, 'unsigned': {},
@ -69,10 +77,12 @@ class PruneEventTestCase(unittest.TestCase):
self.run_test( self.run_test(
{ {
'type': 'B', 'type': 'B',
'event_id': '$test:domain',
'unsigned': {'age_ts': 20}, 'unsigned': {'age_ts': 20},
}, },
{ {
'type': 'B', 'type': 'B',
'event_id': '$test:domain',
'content': {}, 'content': {},
'signatures': {}, 'signatures': {},
'unsigned': {'age_ts': 20}, 'unsigned': {'age_ts': 20},
@ -82,10 +92,12 @@ class PruneEventTestCase(unittest.TestCase):
self.run_test( self.run_test(
{ {
'type': 'B', 'type': 'B',
'event_id': '$test:domain',
'unsigned': {'other_key': 'here'}, 'unsigned': {'other_key': 'here'},
}, },
{ {
'type': 'B', 'type': 'B',
'event_id': '$test:domain',
'content': {}, 'content': {},
'signatures': {}, 'signatures': {},
'unsigned': {}, 'unsigned': {},
@ -96,10 +108,12 @@ class PruneEventTestCase(unittest.TestCase):
self.run_test( self.run_test(
{ {
'type': 'C', 'type': 'C',
'event_id': '$test:domain',
'content': {'things': 'here'}, 'content': {'things': 'here'},
}, },
{ {
'type': 'C', 'type': 'C',
'event_id': '$test:domain',
'content': {}, 'content': {},
'signatures': {}, 'signatures': {},
'unsigned': {}, 'unsigned': {},
@ -109,10 +123,12 @@ class PruneEventTestCase(unittest.TestCase):
self.run_test( self.run_test(
{ {
'type': 'm.room.create', 'type': 'm.room.create',
'event_id': '$test:domain',
'content': {'creator': '@2:domain', 'other_field': 'here'}, 'content': {'creator': '@2:domain', 'other_field': 'here'},
}, },
{ {
'type': 'm.room.create', 'type': 'm.room.create',
'event_id': '$test:domain',
'content': {'creator': '@2:domain'}, 'content': {'creator': '@2:domain'},
'signatures': {}, 'signatures': {},
'unsigned': {}, 'unsigned': {},
@ -255,6 +271,8 @@ class SerializeEventTestCase(unittest.TestCase):
self.assertEquals( self.assertEquals(
self.serialize( self.serialize(
MockEvent( MockEvent(
type="foo",
event_id="test",
room_id="!foo:bar", room_id="!foo:bar",
content={ content={
"foo": "bar", "foo": "bar",
@ -263,6 +281,8 @@ class SerializeEventTestCase(unittest.TestCase):
[] []
), ),
{ {
"type": "foo",
"event_id": "test",
"room_id": "!foo:bar", "room_id": "!foo:bar",
"content": { "content": {
"foo": "bar", "foo": "bar",