Implement new state resolution algorithm

This commit is contained in:
Erik Johnston 2015-01-21 16:27:04 +00:00
parent dc70d1fef8
commit 6dcade97be
2 changed files with 430 additions and 105 deletions

View File

@ -19,6 +19,7 @@ from twisted.internet import defer
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from collections import namedtuple from collections import namedtuple
@ -42,6 +43,8 @@ class StateHandler(object):
def __init__(self, hs): def __init__(self, hs):
self.store = hs.get_datastore() self.store = hs.get_datastore()
# self.auth = hs.get_auth()
self.hs = hs
@defer.inlineCallbacks @defer.inlineCallbacks
def get_current_state(self, room_id, event_type=None, state_key=""): def get_current_state(self, room_id, event_type=None, state_key=""):
@ -210,15 +213,22 @@ class StateHandler(object):
else: else:
prev_states = [] prev_states = []
auth_events = {
k: e for k, e in unconflicted_state.items()
if k[0] in (EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,)
}
try: try:
new_state = {} resolved_state = self._resolve_state_events(
new_state.update(unconflicted_state) conflicted_state, auth_events
for key, events in conflicted_state.items(): )
new_state[key] = self._resolve_state_events(events)
except: except:
logger.exception("Failed to resolve state") logger.exception("Failed to resolve state")
raise raise
new_state = unconflicted_state
new_state.update(resolved_state)
defer.returnValue((None, new_state, prev_states)) defer.returnValue((None, new_state, prev_states))
def _get_power_level_from_event_state(self, event, user_id): def _get_power_level_from_event_state(self, event, user_id):
@ -238,36 +248,65 @@ class StateHandler(object):
return 0 return 0
@log_function @log_function
def _resolve_state_events(self, events): def _resolve_state_events(self, conflicted_state, auth_events):
curr_events = events resolved_state = {}
power_key = (EventTypes.PowerLevels, "")
if power_key in conflicted_state.items():
power_levels = conflicted_state[power_key]
resolved_state[power_key] = self._resolve_auth_events(power_levels)
new_powers = [ auth_events.update(resolved_state)
self._get_power_level_from_event_state(e, e.user_id)
for e in curr_events
]
new_powers = [ for key, events in conflicted_state.items():
int(p) if p else 0 for p in new_powers if key[0] == EventTypes.Member:
] resolved_state[key] = self._resolve_auth_events(
events,
auth_events
)
max_power = max(new_powers) auth_events.update(resolved_state)
curr_events = [ for key, events in conflicted_state.items():
z[0] for z in zip(curr_events, new_powers) if key not in resolved_state:
if z[1] == max_power resolved_state[key] = self._resolve_normal_events(
] events, auth_events
)
if not curr_events: return resolved_state
raise RuntimeError("Max didn't get a max?")
elif len(curr_events) == 1:
return curr_events[0]
# TODO: For now, just choose the one with the largest event_id. def _resolve_auth_events(self, events, auth_events):
return ( reverse = [i for i in reversed(self._ordered_events(events))]
sorted(
curr_events, auth_events = dict(auth_events)
key=lambda e: hashlib.sha1(
e.event_id + e.user_id + e.room_id + e.type prev_event = reverse[0]
).hexdigest() for event in reverse[1:]:
)[0] auth_events[(prev_event.type, prev_event.state_key)] = prev_event
) try:
# FIXME: hs.get_auth() is bad style, but we need to do it to
# get around circular deps.
self.hs.get_auth().check(event, auth_events)
prev_event = event
except AuthError:
return prev_event
return event
def _resolve_normal_events(self, events, auth_events):
for event in self._ordered_events(events):
try:
# FIXME: hs.get_auth() is bad style, but we need to do it to
# get around circular deps.
self.hs.get_auth().check(event, auth_events)
return event
except AuthError as e:
pass
# Oh dear.
return event
def _ordered_events(self, events):
def key_func(e):
return -int(e.depth), hashlib.sha1(e.event_id).hexdigest()
return sorted(events, key=key_func)

View File

@ -16,11 +16,120 @@
from tests import unittest from tests import unittest
from twisted.internet import defer from twisted.internet import defer
from synapse.events import FrozenEvent
from synapse.api.auth import Auth
from synapse.api.constants import EventTypes, Membership
from synapse.state import StateHandler from synapse.state import StateHandler
from mock import Mock from mock import Mock
_next_event_id = 1000
def create_event(name=None, type=None, state_key=None, depth=2, event_id=None,
prev_events=[], **kwargs):
global _next_event_id
if not event_id:
_next_event_id += 1
event_id = str(_next_event_id)
if not name:
if state_key is not None:
name = "<%s-%s, %s>" % (type, state_key, event_id,)
else:
name = "<%s, %s>" % (type, event_id,)
d = {
"event_id": event_id,
"type": type,
"sender": "@user_id:example.com",
"room_id": "!room_id:example.com",
"depth": depth,
"prev_events": prev_events,
}
if state_key is not None:
d["state_key"] = state_key
d.update(kwargs)
event = FrozenEvent(d)
return event
class StateGroupStore(object):
def __init__(self):
self._event_to_state_group = {}
self._group_to_state = {}
self._next_group = 1
def get_state_groups(self, event_ids):
groups = {}
for event_id in event_ids:
group = self._event_to_state_group.get(event_id)
if group:
groups[group] = self._group_to_state[group]
return defer.succeed(groups)
def store_state_groups(self, event, context):
if context.current_state is None:
return
state_events = context.current_state
if event.is_state():
state_events[(event.type, event.state_key)] = event
state_group = context.state_group
if not state_group:
state_group = self._next_group
self._next_group += 1
self._group_to_state[state_group] = state_events.values()
self._event_to_state_group[event.event_id] = state_group
class DictObj(dict):
def __init__(self, **kwargs):
super(DictObj, self).__init__(kwargs)
self.__dict__ = self
class Graph(object):
def __init__(self, nodes, edges):
events = {}
clobbered = set(events.keys())
for event_id, fields in nodes.items():
refs = edges.get(event_id)
if refs:
clobbered.difference_update(refs)
prev_events = [(r, {}) for r in refs]
else:
prev_events = []
events[event_id] = create_event(
event_id=event_id,
prev_events=prev_events,
**fields
)
self._leaves = clobbered
self._events = sorted(events.values(), key=lambda e: e.depth)
def walk(self):
return iter(self._events)
def get_leaves(self):
return (self._events[i] for i in self._leaves)
class StateTestCase(unittest.TestCase): class StateTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.store = Mock( self.store = Mock(
@ -29,20 +138,188 @@ class StateTestCase(unittest.TestCase):
"add_event_hashes", "add_event_hashes",
] ]
) )
hs = Mock(spec=["get_datastore"]) hs = Mock(spec=["get_datastore", "get_auth", "get_state_handler"])
hs.get_datastore.return_value = self.store hs.get_datastore.return_value = self.store
hs.get_state_handler.return_value = None
hs.get_auth.return_value = Auth(hs)
self.state = StateHandler(hs) self.state = StateHandler(hs)
self.event_id = 0 self.event_id = 0
@defer.inlineCallbacks
def test_branch_no_conflict(self):
graph = Graph(
nodes={
"START": DictObj(
type=EventTypes.Create,
state_key="",
depth=1,
),
"A": DictObj(
type=EventTypes.Message,
depth=2,
),
"B": DictObj(
type=EventTypes.Message,
depth=3,
),
"C": DictObj(
type=EventTypes.Name,
state_key="",
depth=3,
),
"D": DictObj(
type=EventTypes.Message,
depth=4,
),
},
edges={
"A": ["START"],
"B": ["A"],
"C": ["A"],
"D": ["B", "C"]
}
)
store = StateGroupStore()
self.store.get_state_groups.side_effect = store.get_state_groups
context_store = {}
for event in graph.walk():
context = yield self.state.compute_event_context(event)
store.store_state_groups(event, context)
context_store[event.event_id] = context
self.assertEqual(2, len(context_store["D"].current_state))
@defer.inlineCallbacks
def test_branch_basic_conflict(self):
graph = Graph(
nodes={
"START": DictObj(
type=EventTypes.Create,
state_key="creator",
content={"membership": "@user_id:example.com"},
depth=1,
),
"A": DictObj(
type=EventTypes.Member,
state_key="@user_id:example.com",
content={"membership": Membership.JOIN},
membership=Membership.JOIN,
depth=2,
),
"B": DictObj(
type=EventTypes.Name,
state_key="",
depth=3,
),
"C": DictObj(
type=EventTypes.Name,
state_key="",
depth=4,
),
"D": DictObj(
type=EventTypes.Message,
depth=5,
),
},
edges={
"A": ["START"],
"B": ["A"],
"C": ["A"],
"D": ["B", "C"]
}
)
store = StateGroupStore()
self.store.get_state_groups.side_effect = store.get_state_groups
context_store = {}
for event in graph.walk():
context = yield self.state.compute_event_context(event)
store.store_state_groups(event, context)
context_store[event.event_id] = context
self.assertSetEqual(
{"START", "A", "C"},
{e.event_id for e in context_store["D"].current_state.values()}
)
@defer.inlineCallbacks
def test_branch_have_banned_conflict(self):
graph = Graph(
nodes={
"START": DictObj(
type=EventTypes.Create,
state_key="creator",
content={"membership": "@user_id:example.com"},
depth=1,
),
"A": DictObj(
type=EventTypes.Member,
state_key="@user_id:example.com",
content={"membership": Membership.JOIN},
membership=Membership.JOIN,
depth=2,
),
"B": DictObj(
type=EventTypes.Name,
state_key="",
depth=3,
),
"C": DictObj(
type=EventTypes.Member,
state_key="@user_id_2:example.com",
content={"membership": Membership.BAN},
membership=Membership.BAN,
depth=4,
),
"D": DictObj(
type=EventTypes.Name,
state_key="",
depth=4,
sender="@user_id_2:example.com",
),
"E": DictObj(
type=EventTypes.Message,
depth=5,
),
},
edges={
"A": ["START"],
"B": ["A"],
"C": ["B"],
"D": ["B"],
"E": ["C", "D"]
}
)
store = StateGroupStore()
self.store.get_state_groups.side_effect = store.get_state_groups
context_store = {}
for event in graph.walk():
context = yield self.state.compute_event_context(event)
store.store_state_groups(event, context)
context_store[event.event_id] = context
self.assertSetEqual(
{"START", "A", "B", "C"},
{e.event_id for e in context_store["E"].current_state.values()}
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_annotate_with_old_message(self): def test_annotate_with_old_message(self):
event = self.create_event(type="test_message", name="event") event = create_event(type="test_message", name="event")
old_state = [ old_state = [
self.create_event(type="test1", state_key="1"), create_event(type="test1", state_key="1"),
self.create_event(type="test1", state_key="2"), create_event(type="test1", state_key="2"),
self.create_event(type="test2", state_key=""), create_event(type="test2", state_key=""),
] ]
context = yield self.state.compute_event_context( context = yield self.state.compute_event_context(
@ -62,12 +339,12 @@ class StateTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_annotate_with_old_state(self): def test_annotate_with_old_state(self):
event = self.create_event(type="state", state_key="", name="event") event = create_event(type="state", state_key="", name="event")
old_state = [ old_state = [
self.create_event(type="test1", state_key="1"), create_event(type="test1", state_key="1"),
self.create_event(type="test1", state_key="2"), create_event(type="test1", state_key="2"),
self.create_event(type="test2", state_key=""), create_event(type="test2", state_key=""),
] ]
context = yield self.state.compute_event_context( context = yield self.state.compute_event_context(
@ -88,13 +365,12 @@ class StateTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_trivial_annotate_message(self): def test_trivial_annotate_message(self):
event = self.create_event(type="test_message", name="event") event = create_event(type="test_message", name="event")
event.prev_events = []
old_state = [ old_state = [
self.create_event(type="test1", state_key="1"), create_event(type="test1", state_key="1"),
self.create_event(type="test1", state_key="2"), create_event(type="test1", state_key="2"),
self.create_event(type="test2", state_key=""), create_event(type="test2", state_key=""),
] ]
group_name = "group_name_1" group_name = "group_name_1"
@ -119,13 +395,12 @@ class StateTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_trivial_annotate_state(self): def test_trivial_annotate_state(self):
event = self.create_event(type="state", state_key="", name="event") event = create_event(type="state", state_key="", name="event")
event.prev_events = []
old_state = [ old_state = [
self.create_event(type="test1", state_key="1"), create_event(type="test1", state_key="1"),
self.create_event(type="test1", state_key="2"), create_event(type="test1", state_key="2"),
self.create_event(type="test2", state_key=""), create_event(type="test2", state_key=""),
] ]
group_name = "group_name_1" group_name = "group_name_1"
@ -150,30 +425,21 @@ class StateTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_resolve_message_conflict(self): def test_resolve_message_conflict(self):
event = self.create_event(type="test_message", name="event") event = create_event(type="test_message", name="event")
event.prev_events = []
old_state_1 = [ old_state_1 = [
self.create_event(type="test1", state_key="1"), create_event(type="test1", state_key="1"),
self.create_event(type="test1", state_key="2"), create_event(type="test1", state_key="2"),
self.create_event(type="test2", state_key=""), create_event(type="test2", state_key=""),
] ]
old_state_2 = [ old_state_2 = [
self.create_event(type="test1", state_key="1"), create_event(type="test1", state_key="1"),
self.create_event(type="test3", state_key="2"), create_event(type="test3", state_key="2"),
self.create_event(type="test4", state_key=""), create_event(type="test4", state_key=""),
] ]
group_name_1 = "group_name_1" context = yield self._get_context(event, old_state_1, old_state_2)
group_name_2 = "group_name_2"
self.store.get_state_groups.return_value = {
group_name_1: old_state_1,
group_name_2: old_state_2,
}
context = yield self.state.compute_event_context(event)
self.assertEqual(len(context.current_state), 5) self.assertEqual(len(context.current_state), 5)
@ -181,21 +447,70 @@ class StateTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_resolve_state_conflict(self): def test_resolve_state_conflict(self):
event = self.create_event(type="test4", state_key="", name="event") event = create_event(type="test4", state_key="", name="event")
event.prev_events = []
old_state_1 = [ old_state_1 = [
self.create_event(type="test1", state_key="1"), create_event(type="test1", state_key="1"),
self.create_event(type="test1", state_key="2"), create_event(type="test1", state_key="2"),
self.create_event(type="test2", state_key=""), create_event(type="test2", state_key=""),
] ]
old_state_2 = [ old_state_2 = [
self.create_event(type="test1", state_key="1"), create_event(type="test1", state_key="1"),
self.create_event(type="test3", state_key="2"), create_event(type="test3", state_key="2"),
self.create_event(type="test4", state_key=""), create_event(type="test4", state_key=""),
] ]
context = yield self._get_context(event, old_state_1, old_state_2)
self.assertEqual(len(context.current_state), 5)
self.assertIsNone(context.state_group)
@defer.inlineCallbacks
def test_standard_depth_conflict(self):
event = create_event(type="test4", name="event")
member_event = create_event(
type=EventTypes.Member,
state_key="@user_id:example.com",
content={
"membership": Membership.JOIN,
}
)
old_state_1 = [
member_event,
create_event(type="test1", state_key="1", depth=1),
]
old_state_2 = [
member_event,
create_event(type="test1", state_key="1", depth=2),
]
context = yield self._get_context(event, old_state_1, old_state_2)
self.assertEqual(old_state_2[1], context.current_state[("test1", "1")])
# Reverse the depth to make sure we are actually using the depths
# during state resolution.
old_state_1 = [
member_event,
create_event(type="test1", state_key="1", depth=2),
]
old_state_2 = [
member_event,
create_event(type="test1", state_key="1", depth=1),
]
context = yield self._get_context(event, old_state_1, old_state_2)
self.assertEqual(old_state_1[1], context.current_state[("test1", "1")])
def _get_context(self, event, old_state_1, old_state_2):
group_name_1 = "group_name_1" group_name_1 = "group_name_1"
group_name_2 = "group_name_2" group_name_2 = "group_name_2"
@ -204,33 +519,4 @@ class StateTestCase(unittest.TestCase):
group_name_2: old_state_2, group_name_2: old_state_2,
} }
context = yield self.state.compute_event_context(event) return self.state.compute_event_context(event)
self.assertEqual(len(context.current_state), 5)
self.assertIsNone(context.state_group)
def create_event(self, name=None, type=None, state_key=None):
self.event_id += 1
event_id = str(self.event_id)
if not name:
if state_key is not None:
name = "<%s-%s>" % (type, state_key)
else:
name = "<%s>" % (type, )
event = Mock(name=name, spec=[])
event.type = type
if state_key is not None:
event.state_key = state_key
event.event_id = event_id
event.is_state = lambda: (state_key is not None)
event.unsigned = {}
event.user_id = "@user_id:example.com"
event.room_id = "!room_id:example.com"
return event