Make the state resolution use actual power levels rather than taking them from a Pdu key.

This commit is contained in:
Erik Johnston 2014-09-12 17:11:00 +01:00
parent b42fe05c51
commit 39e3fc69e5
5 changed files with 194 additions and 127 deletions

View file

@ -69,6 +69,7 @@ class Pdu(JsonEncodedObject):
"prev_state_id",
"prev_state_origin",
"required_power_level",
"user_id",
]
internal_keys = [

View file

@ -115,6 +115,8 @@ class StateHandler(object):
is_new = yield self._handle_new_state(new_pdu)
logger.debug("is_new: %s %s %s", is_new, new_pdu.pdu_id, new_pdu.origin)
if is_new:
yield self.store.update_current_state(
pdu_id=new_pdu.pdu_id,
@ -187,11 +189,12 @@ class StateHandler(object):
# We didn't find a common ancestor. This is probably fine.
pass
result = self._do_conflict_res(
result = yield self._do_conflict_res(
new_branch, current_branch, common_ancestor
)
defer.returnValue(result)
@defer.inlineCallbacks
def _do_conflict_res(self, new_branch, current_branch, common_ancestor):
conflict_res = [
self._do_power_level_conflict_res,
@ -200,7 +203,8 @@ class StateHandler(object):
]
for algo in conflict_res:
new_res, curr_res = algo(
new_res, curr_res = yield defer.maybeDeferred(
algo,
new_branch, current_branch, common_ancestor
)
@ -211,19 +215,39 @@ class StateHandler(object):
raise Exception("Conflict resolution failed.")
@defer.inlineCallbacks
def _do_power_level_conflict_res(self, new_branch, current_branch,
common_ancestor):
max_power_new = max(
new_branch[:-1] if common_ancestor else new_branch,
key=lambda t: t.power_level
).power_level
new_powers_deferreds = []
for e in new_branch[:-1] if common_ancestor else new_branch:
if hasattr(e, "user_id"):
new_powers_deferreds.append(
self.store.get_power_level(e.context, e.user_id)
)
max_power_current = max(
current_branch[:-1] if common_ancestor else current_branch,
key=lambda t: t.power_level
).power_level
current_powers_deferreds = []
for e in current_branch[:-1] if common_ancestor else current_branch:
if hasattr(e, "user_id"):
current_powers_deferreds.append(
self.store.get_power_level(e.context, e.user_id)
)
return (max_power_new, max_power_current)
new_powers = yield defer.gatherResults(
new_powers_deferreds,
consumeErrors=True
)
current_powers = yield defer.gatherResults(
current_powers_deferreds,
consumeErrors=True
)
max_power_new = max(new_powers)
max_power_current = max(current_powers)
defer.returnValue(
(max_power_new, max_power_current)
)
def _do_chain_length_conflict_res(self, new_branch, current_branch,
common_ancestor):

View file

@ -17,6 +17,7 @@ import logging
from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.util.logutils import log_function
import collections
import copy
@ -91,6 +92,7 @@ class SQLBaseStore(object):
self._simple_insert_txn, table, values, or_replace=or_replace
)
@log_function
def _simple_insert_txn(self, txn, table, values, or_replace=False):
sql = "%s INTO %s (%s) VALUES(%s)" % (
("INSERT OR REPLACE" if or_replace else "INSERT"),
@ -98,6 +100,12 @@ class SQLBaseStore(object):
", ".join(k for k in values),
", ".join("?" for k in values)
)
logger.debug(
"[SQL] %s Args=%s Func=%s",
sql, values.values(),
)
txn.execute(sql, values.values())
return txn.lastrowid

View file

@ -17,6 +17,7 @@ from twisted.internet import defer
from ._base import SQLBaseStore, Table, JoinHelper
from synapse.federation.units import Pdu
from synapse.util.logutils import log_function
from collections import namedtuple
@ -625,53 +626,6 @@ class StatePduStore(SQLBaseStore):
return result
def get_next_missing_pdu(self, new_pdu):
"""When we get a new state pdu we need to check whether we need to do
any conflict resolution, if we do then we need to check if we need
to go back and request some more state pdus that we haven't seen yet.
Args:
txn
new_pdu
Returns:
PduIdTuple: A pdu that we are missing, or None if we have all the
pdus required to do the conflict resolution.
"""
return self._db_pool.runInteraction(
self._get_next_missing_pdu, new_pdu
)
def _get_next_missing_pdu(self, txn, new_pdu):
logger.debug(
"get_next_missing_pdu %s %s",
new_pdu.pdu_id, new_pdu.origin
)
current = self._get_current_interaction(
txn,
new_pdu.context, new_pdu.pdu_type, new_pdu.state_key
)
if (not current or not current.prev_state_id
or not current.prev_state_origin):
return None
# Oh look, it's a straight clobber, so wooooo almost no-op.
if (new_pdu.prev_state_id == current.pdu_id
and new_pdu.prev_state_origin == current.origin):
return None
enum_branches = self._enumerate_state_branches(txn, new_pdu, current)
for branch, prev_state, state in enum_branches:
if not state:
return PduIdTuple(
prev_state.prev_state_id,
prev_state.prev_state_origin
)
return None
def handle_new_state(self, new_pdu):
"""Actually perform conflict resolution on the new_pdu on the
assumption we have all the pdus required to perform it.
@ -755,24 +709,11 @@ class StatePduStore(SQLBaseStore):
return is_current
@classmethod
@log_function
def _enumerate_state_branches(cls, txn, pdu_a, pdu_b):
def _enumerate_state_branches(self, txn, pdu_a, pdu_b):
branch_a = pdu_a
branch_b = pdu_b
get_query = (
"SELECT %(fields)s FROM %(pdus)s as p "
"LEFT JOIN %(state)s as s "
"ON p.pdu_id = s.pdu_id AND p.origin = s.origin "
"WHERE p.pdu_id = ? AND p.origin = ? "
) % {
"fields": _pdu_state_joiner.get_fields(
PdusTable="p", StatePdusTable="s"),
"pdus": PdusTable.table_name,
"state": StatePdusTable.table_name,
}
while True:
if (branch_a.pdu_id == branch_b.pdu_id
and branch_a.origin == branch_b.origin):
@ -804,13 +745,12 @@ class StatePduStore(SQLBaseStore):
branch_a.prev_state_origin
)
logger.debug("getting branch_a prev %s", pdu_tuple)
txn.execute(get_query, pdu_tuple)
prev_branch = branch_a
res = txn.fetchone()
branch_a = PduEntry(*res) if res else None
logger.debug("getting branch_a prev %s", pdu_tuple)
branch_a = self._get_pdu_tuple(txn, *pdu_tuple)
if branch_a:
branch_a = Pdu.from_pdu_tuple(branch_a)
logger.debug("branch_a=%s", branch_a)
@ -823,14 +763,13 @@ class StatePduStore(SQLBaseStore):
branch_b.prev_state_id,
branch_b.prev_state_origin
)
txn.execute(get_query, pdu_tuple)
logger.debug("getting branch_b prev %s", pdu_tuple)
prev_branch = branch_b
res = txn.fetchone()
branch_b = PduEntry(*res) if res else None
logger.debug("getting branch_b prev %s", pdu_tuple)
branch_b = self._get_pdu_tuple(txn, *pdu_tuple)
if branch_b:
branch_b = Pdu.from_pdu_tuple(branch_b)
logger.debug("branch_b=%s", branch_b)