Optimise state resolution

This commit is contained in:
Erik Johnston 2017-01-13 13:16:54 +00:00
parent beda469bc6
commit 5d6bad1b3c
7 changed files with 230 additions and 73 deletions

View File

@ -27,7 +27,7 @@ from synapse.types import UserID, get_domain_from_id
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def check(event, auth_events, do_sig_check=True): def check(event, auth_events, do_sig_check=True, do_size_check=True):
""" Checks if this event is correctly authed. """ Checks if this event is correctly authed.
Args: Args:
@ -38,6 +38,7 @@ def check(event, auth_events, do_sig_check=True):
Returns: Returns:
True if the auth checks pass. True if the auth checks pass.
""" """
if do_size_check:
_check_size_limits(event) _check_size_limits(event)
if not hasattr(event, "room_id"): if not hasattr(event, "room_id"):
@ -119,6 +120,7 @@ def check(event, auth_events, do_sig_check=True):
) )
return True return True
if logger.isEnabledFor(logging.DEBUG):
logger.debug( logger.debug(
"Auth events: %s", "Auth events: %s",
[a.event_id for a in auth_events.values()] [a.event_id for a in auth_events.values()]
@ -639,3 +641,38 @@ def get_public_keys(invite_event):
public_keys.append(o) public_keys.append(o)
public_keys.extend(invite_event.content.get("public_keys", [])) public_keys.extend(invite_event.content.get("public_keys", []))
return public_keys return public_keys
def auth_types_for_event(event):
    """Work out which pieces of room state may be needed to auth an event.

    The result is a list of (EventType, StateKey) tuples. It can be a
    superset of what is strictly required, since the exact requirements
    depend on the full state of the room. It exists so that callers can
    limit how many events they fetch from the database before running the
    auth checks.

    Args:
        event: the event that is going to be authed.

    Returns:
        list[(str, str)]: the (type, state_key) pairs; empty for a create
        event.
    """
    if event.type == EventTypes.Create:
        return []

    # Needed for every non-create event: the power levels, the sender's
    # own membership and the create event.
    auth_types = [
        (EventTypes.PowerLevels, ""),
        (EventTypes.Member, event.user_id),
        (EventTypes.Create, ""),
    ]

    if event.type != EventTypes.Member:
        return auth_types

    membership = event.content["membership"]
    if membership in (Membership.JOIN, Membership.INVITE):
        auth_types.append((EventTypes.JoinRules, ""))

    # The membership of the user this event is about (its state_key).
    auth_types.append((EventTypes.Member, event.state_key))

    if membership == Membership.INVITE and "third_party_invite" in event.content:
        token = event.content["third_party_invite"]["signed"]["token"]
        auth_types.append((EventTypes.ThirdPartyInvite, token))

    return auth_types

View File

@ -79,7 +79,6 @@ class EventBase(object):
auth_events = _event_dict_property("auth_events") auth_events = _event_dict_property("auth_events")
depth = _event_dict_property("depth") depth = _event_dict_property("depth")
content = _event_dict_property("content") content = _event_dict_property("content")
event_id = _event_dict_property("event_id")
hashes = _event_dict_property("hashes") hashes = _event_dict_property("hashes")
origin = _event_dict_property("origin") origin = _event_dict_property("origin")
origin_server_ts = _event_dict_property("origin_server_ts") origin_server_ts = _event_dict_property("origin_server_ts")
@ -88,8 +87,6 @@ class EventBase(object):
redacts = _event_dict_property("redacts") redacts = _event_dict_property("redacts")
room_id = _event_dict_property("room_id") room_id = _event_dict_property("room_id")
sender = _event_dict_property("sender") sender = _event_dict_property("sender")
state_key = _event_dict_property("state_key")
type = _event_dict_property("type")
user_id = _event_dict_property("sender") user_id = _event_dict_property("sender")
@property @property
@ -162,6 +159,11 @@ class FrozenEvent(EventBase):
else: else:
frozen_dict = event_dict frozen_dict = event_dict
self.event_id = event_dict["event_id"]
self.type = event_dict["type"]
if "state_key" in event_dict:
self.state_key = event_dict["state_key"]
super(FrozenEvent, self).__init__( super(FrozenEvent, self).__init__(
frozen_dict, frozen_dict,
signatures=signatures, signatures=signatures,

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from . import EventBase, FrozenEvent from . import EventBase, FrozenEvent, _event_dict_property
from synapse.types import EventID from synapse.types import EventID
@ -34,6 +34,10 @@ class EventBuilder(EventBase):
internal_metadata_dict=internal_metadata_dict, internal_metadata_dict=internal_metadata_dict,
) )
event_id = _event_dict_property("event_id")
state_key = _event_dict_property("state_key")
type = _event_dict_property("type")
def build(self): def build(self):
return FrozenEvent.from_event(self) return FrozenEvent.from_event(self)

View File

@ -1530,7 +1530,7 @@ class FederationHandler(BaseHandler):
(d.type, d.state_key): d for d in different_events if d (d.type, d.state_key): d for d in different_events if d
}) })
new_state, prev_state = self.state_handler.resolve_events( new_state = self.state_handler.resolve_events(
[local_view.values(), remote_view.values()], [local_view.values(), remote_view.values()],
event event
) )

View File

@ -22,11 +22,10 @@ from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.api.auth import AuthEventTypes
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.util.async import Linearizer from synapse.util.async import Linearizer
from collections import namedtuple from collections import namedtuple, defaultdict
from frozendict import frozendict from frozendict import frozendict
import logging import logging
@ -48,6 +47,8 @@ EVICTION_TIMEOUT_SECONDS = 60 * 60
_NEXT_STATE_ID = 1 _NEXT_STATE_ID = 1
POWER_KEY = (EventTypes.PowerLevels, "")
def _gen_state_id(): def _gen_state_id():
global _NEXT_STATE_ID global _NEXT_STATE_ID
@ -328,21 +329,13 @@ class StateHandler(object):
if conflicted_state: if conflicted_state:
logger.info("Resolving conflicted state for %r", room_id) logger.info("Resolving conflicted state for %r", room_id)
state_map = yield self.store.get_events(
[e_id for st in state_groups_ids.values() for e_id in st.values()],
get_prev_content=False
)
state_sets = [
[state_map[e_id] for key, e_id in st.items() if e_id in state_map]
for st in state_groups_ids.values()
]
with Measure(self.clock, "state._resolve_events"): with Measure(self.clock, "state._resolve_events"):
new_state, _ = resolve_events( new_state = yield resolve_events(
state_sets, event_type, state_key state_groups_ids.values(),
state_map_factory=lambda ev_ids: self.store.get_events(
ev_ids, get_prev_content=False
),
) )
new_state = {
key: e.event_id for key, e in new_state.items()
}
else: else:
new_state = { new_state = {
key: e_ids.pop() for key, e_ids in state.items() key: e_ids.pop() for key, e_ids in state.items()
@ -390,13 +383,25 @@ class StateHandler(object):
logger.info( logger.info(
"Resolving state for %s with %d groups", event.room_id, len(state_sets) "Resolving state for %s with %d groups", event.room_id, len(state_sets)
) )
state_set_ids = [{
(ev.type, ev.state_key): ev.event_id
for ev in st
} for st in state_sets]
state_map = {
ev.event_id: ev
for st in state_sets
for ev in st
}
with Measure(self.clock, "state._resolve_events"): with Measure(self.clock, "state._resolve_events"):
if event.is_state(): new_state = resolve_events(state_set_ids, state_map)
return resolve_events(
state_sets, event.type, event.state_key new_state = {
) key: state_map[ev_id] for key, ev_id in new_state.items()
else: }
return resolve_events(state_sets)
return new_state
def _ordered_events(events): def _ordered_events(events):
@ -406,43 +411,117 @@ def _ordered_events(events):
return sorted(events, key=key_func) return sorted(events, key=key_func)
def resolve_events(state_sets, event_type=None, state_key=""): def resolve_events(state_sets, state_map_factory):
""" """
Args:
state_sets(list): List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve.
state_map_factory(dict|callable): If callable, then will be called
with a list of event_ids that are needed, and should return with
a Deferred of dict of event_id to event. Otherwise, should be
a dict from event_id to event of all events in state_sets.
Returns Returns
(dict[(str, str), synapse.events.FrozenEvent], list[str]): a tuple dict[(str, str), synapse.events.FrozenEvent] is a map from
(new_state, prev_states). new_state is a map from (type, state_key) (type, state_key) to event.
to event. prev_states is a list of event_ids.
""" """
state = {} unconflicted_state, conflicted_state = _seperate(
for st in state_sets: state_sets,
for e in st:
state.setdefault(
(e.type, e.state_key),
{}
)[e.event_id] = e
unconflicted_state = {
k: v.values()[0] for k, v in state.items()
if len(v.values()) == 1
}
conflicted_state = {
k: v.values()
for k, v in state.items()
if len(v.values()) > 1
}
if event_type:
prev_states_events = conflicted_state.get(
(event_type, state_key), []
) )
prev_states = [s.event_id for s in prev_states_events]
else: if callable(state_map_factory):
prev_states = [] return _resolve_with_state_fac(
unconflicted_state, conflicted_state, state_map_factory
)
state_map = state_map_factory
auth_events = _create_auth_events_from_maps(
unconflicted_state, conflicted_state, state_map
)
return _resolve_with_state(
unconflicted_state, conflicted_state, auth_events, state_map
)
def _seperate(state_sets):
"""Takes the state_sets and figures out which keys are conflicted and
which aren't. i.e., which have multiple different event_ids associated
with them in different state sets.
"""
unconflicted_state = dict(state_sets[0])
conflicted_state = {}
full_states = defaultdict(
set,
{k: set((v,)) for k, v in state_sets[0].iteritems()}
)
for state_set in state_sets[1:]:
for key, value in state_set.iteritems():
ls = full_states[key]
if not ls:
ls.add(value)
unconflicted_state[key] = value
elif value not in ls:
ls.add(value)
if len(ls) == 2:
conflicted_state[key] = ls
unconflicted_state.pop(key, None)
return unconflicted_state, conflicted_state
@defer.inlineCallbacks
def _resolve_with_state_fac(unconflicted_state, conflicted_state,
                            state_map_factory):
    """Resolve state, fetching the events that are needed on demand.

    Args:
        unconflicted_state (dict): (type, state_key) -> event_id for keys
            that are not conflicted.
        conflicted_state (dict): (type, state_key) -> collection of
            event_ids for keys that are conflicted.
        state_map_factory (callable): called with a set of event_ids; the
            result is yielded and then used as a dict of event_id -> event.

    Returns:
        Via ``defer.returnValue``, whatever ``_resolve_with_state`` returns
        for the fetched events.
    """
    # First fetch every event that takes part in a conflict.
    conflicted_ids = {
        ev_id
        for ids in conflicted_state.itervalues()
        for ev_id in ids
    }
    state_map = yield state_map_factory(conflicted_ids)

    auth_events = _create_auth_events_from_maps(
        unconflicted_state, conflicted_state, state_map
    )

    # Then fetch the auth events that were not part of the first request.
    extra_ids = set(auth_events.itervalues()) - conflicted_ids
    extra_events = yield state_map_factory(extra_ids)
    state_map.update(extra_events)

    resolved = _resolve_with_state(
        unconflicted_state, conflicted_state, auth_events, state_map
    )
    defer.returnValue(resolved)
def _create_auth_events_from_maps(unconflicted_state, conflicted_state,
                                  state_map):
    """Collect the unconflicted events that may be needed for auth.

    For every conflicted event, the (type, state_key) pairs returned by
    ``event_auth.auth_types_for_event`` are looked up in the unconflicted
    state.

    Args:
        unconflicted_state (dict): (type, state_key) -> event_id.
        conflicted_state (dict): (type, state_key) -> collection of
            event_ids.
        state_map (dict): event_id -> event; must contain every event_id
            that appears in ``conflicted_state``.

    Returns:
        dict: (type, state_key) -> event_id, restricted to keys that have a
        truthy entry in ``unconflicted_state``.
    """
    auth_events = {}
    for conflicted_ids in conflicted_state.itervalues():
        for conflicted_id in conflicted_ids:
            conflicted_event = state_map[conflicted_id]
            for key in event_auth.auth_types_for_event(conflicted_event):
                if key in auth_events:
                    continue
                auth_event_id = unconflicted_state.get(key, None)
                if auth_event_id:
                    auth_events[key] = auth_event_id
    return auth_events
def _resolve_with_state(unconflicted_state, conflicted_state, auth_events,
state_map):
conflicted_state = {
key: [state_map[ev_id] for ev_id in event_ids]
for key, event_ids in conflicted_state.items()
}
auth_events = { auth_events = {
k: e for k, e in unconflicted_state.items() key: state_map[ev_id]
if k[0] in AuthEventTypes for key, ev_id in auth_events.items()
} }
try: try:
@ -454,9 +533,10 @@ def resolve_events(state_sets, event_type=None, state_key=""):
raise raise
new_state = unconflicted_state new_state = unconflicted_state
new_state.update(resolved_state) for key, event in resolved_state.iteritems():
new_state[key] = event.event_id
return new_state, prev_states return new_state
def _resolve_state_events(conflicted_state, auth_events): def _resolve_state_events(conflicted_state, auth_events):
@ -470,11 +550,10 @@ def _resolve_state_events(conflicted_state, auth_events):
4. other events. 4. other events.
""" """
resolved_state = {} resolved_state = {}
power_key = (EventTypes.PowerLevels, "") if POWER_KEY in conflicted_state:
if power_key in conflicted_state: events = conflicted_state[POWER_KEY]
events = conflicted_state[power_key]
logger.debug("Resolving conflicted power levels %r", events) logger.debug("Resolving conflicted power levels %r", events)
resolved_state[power_key] = _resolve_auth_events( resolved_state[POWER_KEY] = _resolve_auth_events(
events, auth_events) events, auth_events)
auth_events.update(resolved_state) auth_events.update(resolved_state)
@ -512,14 +591,26 @@ def _resolve_state_events(conflicted_state, auth_events):
def _resolve_auth_events(events, auth_events): def _resolve_auth_events(events, auth_events):
reverse = [i for i in reversed(_ordered_events(events))] reverse = [i for i in reversed(_ordered_events(events))]
auth_events = dict(auth_events) auth_keys = set(
key
for event in events
for key in event_auth.auth_types_for_event(event)
)
new_auth_events = {}
for key in auth_keys:
auth_event = auth_events.get(key, None)
if auth_event:
new_auth_events[key] = auth_event
auth_events = new_auth_events
prev_event = reverse[0] prev_event = reverse[0]
for event in reverse[1:]: for event in reverse[1:]:
auth_events[(prev_event.type, prev_event.state_key)] = prev_event auth_events[(prev_event.type, prev_event.state_key)] = prev_event
try: try:
# The signatures have already been checked at this point # The signatures have already been checked at this point
event_auth.check(event, auth_events, do_sig_check=False) event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False)
prev_event = event prev_event = event
except AuthError: except AuthError:
return prev_event return prev_event
@ -531,7 +622,7 @@ def _resolve_normal_events(events, auth_events):
for event in _ordered_events(events): for event in _ordered_events(events):
try: try:
# The signatures have already been checked at this point # The signatures have already been checked at this point
event_auth.check(event, auth_events, do_sig_check=False) event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False)
return event return event
except AuthError: except AuthError:
pass pass

View File

@ -25,10 +25,13 @@ from synapse.api.filtering import Filter
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
user_localpart = "test_user" user_localpart = "test_user"
# MockEvent = namedtuple("MockEvent", "sender type room_id")
def MockEvent(**kwargs): def MockEvent(**kwargs):
if "event_id" not in kwargs:
kwargs["event_id"] = "fake_event_id"
if "type" not in kwargs:
kwargs["type"] = "fake_type"
return FrozenEvent(kwargs) return FrozenEvent(kwargs)

View File

@ -21,6 +21,10 @@ from synapse.events.utils import prune_event, serialize_event
def MockEvent(**kwargs): def MockEvent(**kwargs):
if "event_id" not in kwargs:
kwargs["event_id"] = "fake_event_id"
if "type" not in kwargs:
kwargs["type"] = "fake_type"
return FrozenEvent(kwargs) return FrozenEvent(kwargs)
@ -35,9 +39,13 @@ class PruneEventTestCase(unittest.TestCase):
def test_minimal(self): def test_minimal(self):
self.run_test( self.run_test(
{'type': 'A'},
{ {
'type': 'A', 'type': 'A',
'event_id': '$test:domain',
},
{
'type': 'A',
'event_id': '$test:domain',
'content': {}, 'content': {},
'signatures': {}, 'signatures': {},
'unsigned': {}, 'unsigned': {},
@ -69,10 +77,12 @@ class PruneEventTestCase(unittest.TestCase):
self.run_test( self.run_test(
{ {
'type': 'B', 'type': 'B',
'event_id': '$test:domain',
'unsigned': {'age_ts': 20}, 'unsigned': {'age_ts': 20},
}, },
{ {
'type': 'B', 'type': 'B',
'event_id': '$test:domain',
'content': {}, 'content': {},
'signatures': {}, 'signatures': {},
'unsigned': {'age_ts': 20}, 'unsigned': {'age_ts': 20},
@ -82,10 +92,12 @@ class PruneEventTestCase(unittest.TestCase):
self.run_test( self.run_test(
{ {
'type': 'B', 'type': 'B',
'event_id': '$test:domain',
'unsigned': {'other_key': 'here'}, 'unsigned': {'other_key': 'here'},
}, },
{ {
'type': 'B', 'type': 'B',
'event_id': '$test:domain',
'content': {}, 'content': {},
'signatures': {}, 'signatures': {},
'unsigned': {}, 'unsigned': {},
@ -96,10 +108,12 @@ class PruneEventTestCase(unittest.TestCase):
self.run_test( self.run_test(
{ {
'type': 'C', 'type': 'C',
'event_id': '$test:domain',
'content': {'things': 'here'}, 'content': {'things': 'here'},
}, },
{ {
'type': 'C', 'type': 'C',
'event_id': '$test:domain',
'content': {}, 'content': {},
'signatures': {}, 'signatures': {},
'unsigned': {}, 'unsigned': {},
@ -109,10 +123,12 @@ class PruneEventTestCase(unittest.TestCase):
self.run_test( self.run_test(
{ {
'type': 'm.room.create', 'type': 'm.room.create',
'event_id': '$test:domain',
'content': {'creator': '@2:domain', 'other_field': 'here'}, 'content': {'creator': '@2:domain', 'other_field': 'here'},
}, },
{ {
'type': 'm.room.create', 'type': 'm.room.create',
'event_id': '$test:domain',
'content': {'creator': '@2:domain'}, 'content': {'creator': '@2:domain'},
'signatures': {}, 'signatures': {},
'unsigned': {}, 'unsigned': {},
@ -255,6 +271,8 @@ class SerializeEventTestCase(unittest.TestCase):
self.assertEquals( self.assertEquals(
self.serialize( self.serialize(
MockEvent( MockEvent(
type="foo",
event_id="test",
room_id="!foo:bar", room_id="!foo:bar",
content={ content={
"foo": "bar", "foo": "bar",
@ -263,6 +281,8 @@ class SerializeEventTestCase(unittest.TestCase):
[] []
), ),
{ {
"type": "foo",
"event_id": "test",
"room_id": "!foo:bar", "room_id": "!foo:bar",
"content": { "content": {
"foo": "bar", "foo": "bar",