Merge pull request #1818 from matrix-org/erikj/state_auth_splitout_split

Optimise state resolution
This commit is contained in:
Erik Johnston 2017-01-18 10:53:00 +00:00 committed by GitHub
commit 15f012032c
8 changed files with 249 additions and 76 deletions

View File

@ -27,7 +27,7 @@ from synapse.types import UserID, get_domain_from_id
logger = logging.getLogger(__name__)
def check(event, auth_events, do_sig_check=True):
def check(event, auth_events, do_sig_check=True, do_size_check=True):
""" Checks if this event is correctly authed.
Args:
@ -38,6 +38,7 @@ def check(event, auth_events, do_sig_check=True):
Returns:
True if the auth checks pass.
"""
if do_size_check:
_check_size_limits(event)
if not hasattr(event, "room_id"):
@ -119,6 +120,7 @@ def check(event, auth_events, do_sig_check=True):
)
return True
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
"Auth events: %s",
[a.event_id for a in auth_events.values()]
@ -639,3 +641,38 @@ def get_public_keys(invite_event):
public_keys.append(o)
public_keys.extend(invite_event.content.get("public_keys", []))
return public_keys
def auth_types_for_event(event):
"""Given an event, return a list of (EventType, StateKey) that may be
needed to auth the event. The returned list may be a superset of what
would actually be required depending on the full state of the room.
Used to limit the number of events to fetch from the database to
actually auth the event.
"""
if event.type == EventTypes.Create:
return []
auth_types = []
auth_types.append((EventTypes.PowerLevels, "", ))
auth_types.append((EventTypes.Member, event.user_id, ))
auth_types.append((EventTypes.Create, "", ))
if event.type == EventTypes.Member:
membership = event.content["membership"]
if membership in [Membership.JOIN, Membership.INVITE]:
auth_types.append((EventTypes.JoinRules, "", ))
auth_types.append((EventTypes.Member, event.state_key, ))
if membership == Membership.INVITE:
if "third_party_invite" in event.content:
key = (
EventTypes.ThirdPartyInvite,
event.content["third_party_invite"]["signed"]["token"]
)
auth_types.append(key)
return auth_types

View File

@ -79,7 +79,6 @@ class EventBase(object):
auth_events = _event_dict_property("auth_events")
depth = _event_dict_property("depth")
content = _event_dict_property("content")
event_id = _event_dict_property("event_id")
hashes = _event_dict_property("hashes")
origin = _event_dict_property("origin")
origin_server_ts = _event_dict_property("origin_server_ts")
@ -88,8 +87,6 @@ class EventBase(object):
redacts = _event_dict_property("redacts")
room_id = _event_dict_property("room_id")
sender = _event_dict_property("sender")
state_key = _event_dict_property("state_key")
type = _event_dict_property("type")
user_id = _event_dict_property("sender")
@property
@ -162,6 +159,11 @@ class FrozenEvent(EventBase):
else:
frozen_dict = event_dict
self.event_id = event_dict["event_id"]
self.type = event_dict["type"]
if "state_key" in event_dict:
self.state_key = event_dict["state_key"]
super(FrozenEvent, self).__init__(
frozen_dict,
signatures=signatures,

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from . import EventBase, FrozenEvent
from . import EventBase, FrozenEvent, _event_dict_property
from synapse.types import EventID
@ -34,6 +34,10 @@ class EventBuilder(EventBase):
internal_metadata_dict=internal_metadata_dict,
)
event_id = _event_dict_property("event_id")
state_key = _event_dict_property("state_key")
type = _event_dict_property("type")
def build(self):
return FrozenEvent.from_event(self)

View File

@ -26,7 +26,7 @@ from synapse.util import unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.events import FrozenEvent
from synapse.events import FrozenEvent, builder
import synapse.metrics
from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
@ -499,8 +499,10 @@ class FederationClient(FederationBase):
if "prev_state" not in pdu_dict:
pdu_dict["prev_state"] = []
ev = builder.EventBuilder(pdu_dict)
defer.returnValue(
(destination, self.event_from_pdu_json(pdu_dict))
(destination, ev)
)
break
except CodeMessageException as e:

View File

@ -596,7 +596,7 @@ class FederationHandler(BaseHandler):
preserve_fn(self.state_handler.resolve_state_groups)(room_id, [e])
for e in event_ids
]))
states = dict(zip(event_ids, [s[1] for s in states]))
states = dict(zip(event_ids, [s.state for s in states]))
state_map = yield self.store.get_events(
[e_id for ids in states.values() for e_id in ids],
@ -1530,7 +1530,7 @@ class FederationHandler(BaseHandler):
(d.type, d.state_key): d for d in different_events if d
})
new_state, prev_state = self.state_handler.resolve_events(
new_state = self.state_handler.resolve_events(
[local_view.values(), remote_view.values()],
event
)

View File

@ -22,7 +22,6 @@ from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.metrics import Measure
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.api.auth import AuthEventTypes
from synapse.events.snapshot import EventContext
from synapse.util.async import Linearizer
@ -48,6 +47,8 @@ EVICTION_TIMEOUT_SECONDS = 60 * 60
_NEXT_STATE_ID = 1
POWER_KEY = (EventTypes.PowerLevels, "")
def _gen_state_id():
global _NEXT_STATE_ID
@ -332,21 +333,13 @@ class StateHandler(object):
if conflicted_state:
logger.info("Resolving conflicted state for %r", room_id)
state_map = yield self.store.get_events(
[e_id for st in state_groups_ids.values() for e_id in st.values()],
get_prev_content=False
)
state_sets = [
[state_map[e_id] for key, e_id in st.items() if e_id in state_map]
for st in state_groups_ids.values()
]
with Measure(self.clock, "state._resolve_events"):
new_state, _ = resolve_events(
state_sets, event_type, state_key
new_state = yield resolve_events(
state_groups_ids.values(),
state_map_factory=lambda ev_ids: self.store.get_events(
ev_ids, get_prev_content=False, check_redacted=False,
),
)
new_state = {
key: e.event_id for key, e in new_state.items()
}
else:
new_state = {
key: e_ids.pop() for key, e_ids in state.items()
@ -394,13 +387,25 @@ class StateHandler(object):
logger.info(
"Resolving state for %s with %d groups", event.room_id, len(state_sets)
)
state_set_ids = [{
(ev.type, ev.state_key): ev.event_id
for ev in st
} for st in state_sets]
state_map = {
ev.event_id: ev
for st in state_sets
for ev in st
}
with Measure(self.clock, "state._resolve_events"):
if event.is_state():
return resolve_events(
state_sets, event.type, event.state_key
)
else:
return resolve_events(state_sets)
new_state = resolve_events(state_set_ids, state_map)
new_state = {
key: state_map[ev_id] for key, ev_id in new_state.items()
}
return new_state
def _ordered_events(events):
@ -410,43 +415,131 @@ def _ordered_events(events):
return sorted(events, key=key_func)
def resolve_events(state_sets, event_type=None, state_key=""):
def resolve_events(state_sets, state_map_factory):
"""
Args:
state_sets(list): List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve.
state_map_factory(dict|callable): If callable, then will be called
with a list of event_ids that are needed, and should return with
a Deferred of dict of event_id to event. Otherwise, should be
a dict from event_id to event of all events in state_sets.
Returns
(dict[(str, str), synapse.events.FrozenEvent], list[str]): a tuple
(new_state, prev_states). new_state is a map from (type, state_key)
to event. prev_states is a list of event_ids.
dict[(str, str), synapse.events.FrozenEvent] is a map from
(type, state_key) to event.
"""
state = {}
for st in state_sets:
for e in st:
state.setdefault(
(e.type, e.state_key),
{}
)[e.event_id] = e
unconflicted_state = {
k: v.values()[0] for k, v in state.items()
if len(v.values()) == 1
}
conflicted_state = {
k: v.values()
for k, v in state.items()
if len(v.values()) > 1
}
if event_type:
prev_states_events = conflicted_state.get(
(event_type, state_key), []
unconflicted_state, conflicted_state = _seperate(
state_sets,
)
prev_states = [s.event_id for s in prev_states_events]
if callable(state_map_factory):
return _resolve_with_state_fac(
unconflicted_state, conflicted_state, state_map_factory
)
state_map = state_map_factory
auth_events = _create_auth_events_from_maps(
unconflicted_state, conflicted_state, state_map
)
return _resolve_with_state(
unconflicted_state, conflicted_state, auth_events, state_map
)
def _seperate(state_sets):
"""Takes the state_sets and figures out which keys are conflicted and
which aren't. i.e., which have multiple different event_ids associated
with them in different state sets.
"""
unconflicted_state = dict(state_sets[0])
conflicted_state = {}
for state_set in state_sets[1:]:
for key, value in state_set.iteritems():
# Check if there is an unconflicted entry for the state key.
unconflicted_value = unconflicted_state.get(key)
if unconflicted_value is None:
# There isn't an unconflicted entry so check if there is a
# conflicted entry.
ls = conflicted_state.get(key)
if ls is None:
# There wasn't a conflicted entry so haven't seen this key before.
# Therefore it isn't conflicted yet.
unconflicted_state[key] = value
else:
prev_states = []
# This key is already conflicted, add our value to the conflict set.
ls.add(value)
elif unconflicted_value != value:
# If the unconflicted value is not the same as our value then we
# have a new conflict. So move the key from the unconflicted_state
# to the conflicted state.
conflicted_state[key] = {value, unconflicted_value}
unconflicted_state.pop(key, None)
return unconflicted_state, conflicted_state
@defer.inlineCallbacks
def _resolve_with_state_fac(unconflicted_state, conflicted_state,
state_map_factory):
needed_events = set(
event_id
for event_ids in conflicted_state.itervalues()
for event_id in event_ids
)
logger.info("Asking for %d conflicted events", len(needed_events))
state_map = yield state_map_factory(needed_events)
auth_events = _create_auth_events_from_maps(
unconflicted_state, conflicted_state, state_map
)
new_needed_events = set(auth_events.itervalues())
new_needed_events -= needed_events
logger.info("Asking for %d auth events", len(new_needed_events))
state_map_new = yield state_map_factory(new_needed_events)
state_map.update(state_map_new)
defer.returnValue(_resolve_with_state(
unconflicted_state, conflicted_state, auth_events, state_map
))
def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map):
auth_events = {}
for event_ids in conflicted_state.itervalues():
for event_id in event_ids:
if event_id in state_map:
keys = event_auth.auth_types_for_event(state_map[event_id])
for key in keys:
if key not in auth_events:
event_id = unconflicted_state.get(key, None)
if event_id:
auth_events[key] = event_id
return auth_events
def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_ids,
state_map):
conflicted_state = {}
for key, event_ids in conflicted_state_ds.iteritems():
events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map]
if len(events) > 1:
conflicted_state[key] = events
elif len(events) == 1:
unconflicted_state_ids[key] = events[0].event_id
auth_events = {
k: e for k, e in unconflicted_state.items()
if k[0] in AuthEventTypes
key: state_map[ev_id]
for key, ev_id in auth_event_ids.items()
if ev_id in state_map
}
try:
@ -457,10 +550,11 @@ def resolve_events(state_sets, event_type=None, state_key=""):
logger.exception("Failed to resolve state")
raise
new_state = unconflicted_state
new_state.update(resolved_state)
new_state = unconflicted_state_ids
for key, event in resolved_state.iteritems():
new_state[key] = event.event_id
return new_state, prev_states
return new_state
def _resolve_state_events(conflicted_state, auth_events):
@ -474,11 +568,10 @@ def _resolve_state_events(conflicted_state, auth_events):
4. other events.
"""
resolved_state = {}
power_key = (EventTypes.PowerLevels, "")
if power_key in conflicted_state:
events = conflicted_state[power_key]
if POWER_KEY in conflicted_state:
events = conflicted_state[POWER_KEY]
logger.debug("Resolving conflicted power levels %r", events)
resolved_state[power_key] = _resolve_auth_events(
resolved_state[POWER_KEY] = _resolve_auth_events(
events, auth_events)
auth_events.update(resolved_state)
@ -516,14 +609,26 @@ def _resolve_state_events(conflicted_state, auth_events):
def _resolve_auth_events(events, auth_events):
reverse = [i for i in reversed(_ordered_events(events))]
auth_events = dict(auth_events)
auth_keys = set(
key
for event in events
for key in event_auth.auth_types_for_event(event)
)
new_auth_events = {}
for key in auth_keys:
auth_event = auth_events.get(key, None)
if auth_event:
new_auth_events[key] = auth_event
auth_events = new_auth_events
prev_event = reverse[0]
for event in reverse[1:]:
auth_events[(prev_event.type, prev_event.state_key)] = prev_event
try:
# The signatures have already been checked at this point
event_auth.check(event, auth_events, do_sig_check=False)
event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False)
prev_event = event
except AuthError:
return prev_event
@ -535,7 +640,7 @@ def _resolve_normal_events(events, auth_events):
for event in _ordered_events(events):
try:
# The signatures have already been checked at this point
event_auth.check(event, auth_events, do_sig_check=False)
event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False)
return event
except AuthError:
pass

View File

@ -25,10 +25,13 @@ from synapse.api.filtering import Filter
from synapse.events import FrozenEvent
user_localpart = "test_user"
# MockEvent = namedtuple("MockEvent", "sender type room_id")
def MockEvent(**kwargs):
if "event_id" not in kwargs:
kwargs["event_id"] = "fake_event_id"
if "type" not in kwargs:
kwargs["type"] = "fake_type"
return FrozenEvent(kwargs)

View File

@ -21,6 +21,10 @@ from synapse.events.utils import prune_event, serialize_event
def MockEvent(**kwargs):
if "event_id" not in kwargs:
kwargs["event_id"] = "fake_event_id"
if "type" not in kwargs:
kwargs["type"] = "fake_type"
return FrozenEvent(kwargs)
@ -35,9 +39,13 @@ class PruneEventTestCase(unittest.TestCase):
def test_minimal(self):
self.run_test(
{'type': 'A'},
{
'type': 'A',
'event_id': '$test:domain',
},
{
'type': 'A',
'event_id': '$test:domain',
'content': {},
'signatures': {},
'unsigned': {},
@ -69,10 +77,12 @@ class PruneEventTestCase(unittest.TestCase):
self.run_test(
{
'type': 'B',
'event_id': '$test:domain',
'unsigned': {'age_ts': 20},
},
{
'type': 'B',
'event_id': '$test:domain',
'content': {},
'signatures': {},
'unsigned': {'age_ts': 20},
@ -82,10 +92,12 @@ class PruneEventTestCase(unittest.TestCase):
self.run_test(
{
'type': 'B',
'event_id': '$test:domain',
'unsigned': {'other_key': 'here'},
},
{
'type': 'B',
'event_id': '$test:domain',
'content': {},
'signatures': {},
'unsigned': {},
@ -96,10 +108,12 @@ class PruneEventTestCase(unittest.TestCase):
self.run_test(
{
'type': 'C',
'event_id': '$test:domain',
'content': {'things': 'here'},
},
{
'type': 'C',
'event_id': '$test:domain',
'content': {},
'signatures': {},
'unsigned': {},
@ -109,10 +123,12 @@ class PruneEventTestCase(unittest.TestCase):
self.run_test(
{
'type': 'm.room.create',
'event_id': '$test:domain',
'content': {'creator': '@2:domain', 'other_field': 'here'},
},
{
'type': 'm.room.create',
'event_id': '$test:domain',
'content': {'creator': '@2:domain'},
'signatures': {},
'unsigned': {},
@ -255,6 +271,8 @@ class SerializeEventTestCase(unittest.TestCase):
self.assertEquals(
self.serialize(
MockEvent(
type="foo",
event_id="test",
room_id="!foo:bar",
content={
"foo": "bar",
@ -263,6 +281,8 @@ class SerializeEventTestCase(unittest.TestCase):
[]
),
{
"type": "foo",
"event_id": "test",
"room_id": "!foo:bar",
"content": {
"foo": "bar",