Merge pull request #47 from matrix-org/signature_failures

Federation fixes.
This commit is contained in:
Erik Johnston 2015-02-05 14:03:00 +00:00
commit f08bd95880
9 changed files with 427 additions and 270 deletions

View File

@ -102,8 +102,6 @@ class Auth(object):
def check_host_in_room(self, room_id, host): def check_host_in_room(self, room_id, host):
curr_state = yield self.state.get_current_state(room_id) curr_state = yield self.state.get_current_state(room_id)
logger.debug("Got curr_state %s", curr_state)
for event in curr_state: for event in curr_state:
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
try: try:
@ -360,7 +358,7 @@ class Auth(object):
def add_auth_events(self, builder, context): def add_auth_events(self, builder, context):
yield run_on_reactor() yield run_on_reactor()
auth_ids = self.compute_auth_events(builder, context) auth_ids = self.compute_auth_events(builder, context.current_state)
auth_events_entries = yield self.store.add_event_hashes( auth_events_entries = yield self.store.add_event_hashes(
auth_ids auth_ids
@ -374,26 +372,26 @@ class Auth(object):
if v.event_id in auth_ids if v.event_id in auth_ids
} }
def compute_auth_events(self, event, context): def compute_auth_events(self, event, current_state):
if event.type == EventTypes.Create: if event.type == EventTypes.Create:
return [] return []
auth_ids = [] auth_ids = []
key = (EventTypes.PowerLevels, "", ) key = (EventTypes.PowerLevels, "", )
power_level_event = context.current_state.get(key) power_level_event = current_state.get(key)
if power_level_event: if power_level_event:
auth_ids.append(power_level_event.event_id) auth_ids.append(power_level_event.event_id)
key = (EventTypes.JoinRules, "", ) key = (EventTypes.JoinRules, "", )
join_rule_event = context.current_state.get(key) join_rule_event = current_state.get(key)
key = (EventTypes.Member, event.user_id, ) key = (EventTypes.Member, event.user_id, )
member_event = context.current_state.get(key) member_event = current_state.get(key)
key = (EventTypes.Create, "", ) key = (EventTypes.Create, "", )
create_event = context.current_state.get(key) create_event = current_state.get(key)
if create_event: if create_event:
auth_ids.append(create_event.event_id) auth_ids.append(create_event.event_id)

View File

@ -77,7 +77,7 @@ class EventBase(object):
return self.content["membership"] return self.content["membership"]
def is_state(self): def is_state(self):
return hasattr(self, "state_key") return hasattr(self, "state_key") and self.state_key is not None
def get_dict(self): def get_dict(self):
d = dict(self._event_dict) d = dict(self._event_dict)

View File

@ -0,0 +1,118 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.events.utils import prune_event
from syutil.jsonutil import encode_canonical_json
from synapse.crypto.event_signing import check_event_content_hash
from synapse.api.errors import SynapseError
import logging
logger = logging.getLogger(__name__)
class FederationBase(object):
    """Shared signature/hash-checking logic for the federation client and
    server (both inherit from this class).
    """

    @defer.inlineCallbacks
    def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False):
        """Takes a list of PDUs and checks the signatures and hashes of each
        one. If a PDU fails its signature check then we check if we have it in
        the database, and if not then request it from the originating server
        of that PDU.

        If a PDU fails its content hash check then it is redacted.

        The given list of PDUs is not modified; instead the function returns
        a new list.

        Args:
            origin (str): The server we received the PDUs from.
            pdus (list): The PDUs to check.
            outlier (bool): Whether any PDUs re-fetched from their origin
                should be marked as outliers.

        Returns:
            Deferred: A list of PDUs that have valid signatures and hashes.
        """
        signed_pdus = []
        for pdu in pdus:
            try:
                new_pdu = yield self._check_sigs_and_hash(pdu)
                signed_pdus.append(new_pdu)
            except SynapseError:
                # FIXME: We should handle signature failures more gracefully.

                # Check local db. `allow_none=True` is required: without it
                # get_event raises RuntimeError for a missing event, which
                # would abort the whole batch instead of letting us fall
                # through to re-fetching the PDU from its origin below.
                new_pdu = yield self.store.get_event(
                    pdu.event_id,
                    allow_rejected=True,
                    allow_none=True,
                )
                if new_pdu:
                    signed_pdus.append(new_pdu)
                    continue

                # If the PDU was relayed to us (its origin is not the server
                # we got it from), try fetching a properly signed copy
                # directly from the originating server.
                if pdu.origin != origin:
                    new_pdu = yield self.get_pdu(
                        destinations=[pdu.origin],
                        event_id=pdu.event_id,
                        outlier=outlier,
                    )

                    if new_pdu:
                        signed_pdus.append(new_pdu)
                        continue

                # Supply the event id argument for the %s placeholder; the
                # original call passed no args, so the log line never said
                # which event failed.
                logger.warn(
                    "Failed to find copy of %s with valid signature",
                    pdu.event_id,
                )

        defer.returnValue(signed_pdus)

    @defer.inlineCallbacks
    def _check_sigs_and_hash(self, pdu):
        """Throws a SynapseError if the PDU does not have the correct
        signatures.

        Args:
            pdu (FrozenEvent): The PDU to check.

        Returns:
            FrozenEvent: Either the given event, or a redacted copy of it if
            it failed the content hash check.
        """
        # Check signatures are correct. Signatures are checked against the
        # redacted (pruned) form of the event, since that is what is signed.
        redacted_event = prune_event(pdu)
        redacted_pdu_json = redacted_event.get_pdu_json()
        try:
            yield self.keyring.verify_json_for_server(
                pdu.origin, redacted_pdu_json
            )
        except SynapseError:
            logger.warn(
                "Signature check failed for %s redacted to %s",
                encode_canonical_json(pdu.get_pdu_json()),
                encode_canonical_json(redacted_pdu_json),
            )
            raise

        # The signature covers only the redacted event, so a tampered
        # content field can still verify: detect that via the content hash
        # and serve the redacted copy instead of the tampered one.
        if not check_event_content_hash(pdu):
            logger.warn(
                "Event content has been tampered, redacting %s, %s",
                pdu.event_id, encode_canonical_json(pdu.get_dict())
            )
            defer.returnValue(redacted_event)

        defer.returnValue(pdu)

View File

@ -16,17 +16,11 @@
from twisted.internet import defer from twisted.internet import defer
from .federation_base import FederationBase
from .units import Edu from .units import Edu
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
from synapse.events.utils import prune_event
from syutil.jsonutil import encode_canonical_json
from synapse.crypto.event_signing import check_event_content_hash
from synapse.api.errors import SynapseError
import logging import logging
@ -34,7 +28,7 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class FederationClient(object): class FederationClient(FederationBase):
@log_function @log_function
def send_pdu(self, pdu, destinations): def send_pdu(self, pdu, destinations):
"""Informs the replication layer about a new PDU generated within the """Informs the replication layer about a new PDU generated within the
@ -224,17 +218,17 @@ class FederationClient(object):
for p in result.get("auth_chain", []) for p in result.get("auth_chain", [])
] ]
for i, pdu in enumerate(pdus): signed_pdus = yield self._check_sigs_and_hash_and_fetch(
pdus[i] = yield self._check_sigs_and_hash(pdu) destination, pdus, outlier=True
)
# FIXME: We should handle signature failures more gracefully. signed_auth = yield self._check_sigs_and_hash_and_fetch(
destination, auth_chain, outlier=True
)
for i, pdu in enumerate(auth_chain): signed_auth.sort(key=lambda e: e.depth)
auth_chain[i] = yield self._check_sigs_and_hash(pdu)
# FIXME: We should handle signature failures more gracefully. defer.returnValue((signed_pdus, signed_auth))
defer.returnValue((pdus, auth_chain))
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
@ -248,14 +242,13 @@ class FederationClient(object):
for p in res["auth_chain"] for p in res["auth_chain"]
] ]
for i, pdu in enumerate(auth_chain): signed_auth = yield self._check_sigs_and_hash_and_fetch(
auth_chain[i] = yield self._check_sigs_and_hash(pdu) destination, auth_chain, outlier=True
)
# FIXME: We should handle signature failures more gracefully. signed_auth.sort(key=lambda e: e.depth)
auth_chain.sort(key=lambda e: e.depth) defer.returnValue(signed_auth)
defer.returnValue(auth_chain)
@defer.inlineCallbacks @defer.inlineCallbacks
def make_join(self, destination, room_id, user_id): def make_join(self, destination, room_id, user_id):
@ -291,21 +284,19 @@ class FederationClient(object):
for p in content.get("auth_chain", []) for p in content.get("auth_chain", [])
] ]
for i, pdu in enumerate(state): signed_state = yield self._check_sigs_and_hash_and_fetch(
state[i] = yield self._check_sigs_and_hash(pdu) destination, state, outlier=True
)
# FIXME: We should handle signature failures more gracefully. signed_auth = yield self._check_sigs_and_hash_and_fetch(
destination, auth_chain, outlier=True
for i, pdu in enumerate(auth_chain): )
auth_chain[i] = yield self._check_sigs_and_hash(pdu)
# FIXME: We should handle signature failures more gracefully.
auth_chain.sort(key=lambda e: e.depth) auth_chain.sort(key=lambda e: e.depth)
defer.returnValue({ defer.returnValue({
"state": state, "state": signed_state,
"auth_chain": auth_chain, "auth_chain": signed_auth,
}) })
@defer.inlineCallbacks @defer.inlineCallbacks
@ -353,12 +344,18 @@ class FederationClient(object):
) )
auth_chain = [ auth_chain = [
(yield self._check_sigs_and_hash(self.event_from_pdu_json(e))) self.event_from_pdu_json(e)
for e in content["auth_chain"] for e in content["auth_chain"]
] ]
signed_auth = yield self._check_sigs_and_hash_and_fetch(
destination, auth_chain, outlier=True
)
signed_auth.sort(key=lambda e: e.depth)
ret = { ret = {
"auth_chain": auth_chain, "auth_chain": signed_auth,
"rejects": content.get("rejects", []), "rejects": content.get("rejects", []),
"missing": content.get("missing", []), "missing": content.get("missing", []),
} }
@ -373,37 +370,3 @@ class FederationClient(object):
event.internal_metadata.outlier = outlier event.internal_metadata.outlier = outlier
return event return event
@defer.inlineCallbacks
def _check_sigs_and_hash(self, pdu):
"""Throws a SynapseError if the PDU does not have the correct
signatures.
Returns:
FrozenEvent: Either the given event or it redacted if it failed the
content hash check.
"""
# Check signatures are correct.
redacted_event = prune_event(pdu)
redacted_pdu_json = redacted_event.get_pdu_json()
try:
yield self.keyring.verify_json_for_server(
pdu.origin, redacted_pdu_json
)
except SynapseError:
logger.warn(
"Signature check failed for %s redacted to %s",
encode_canonical_json(pdu.get_pdu_json()),
encode_canonical_json(redacted_pdu_json),
)
raise
if not check_event_content_hash(pdu):
logger.warn(
"Event content has been tampered, redacting %s, %s",
pdu.event_id, encode_canonical_json(pdu.get_dict())
)
defer.returnValue(redacted_event)
defer.returnValue(pdu)

View File

@ -16,16 +16,12 @@
from twisted.internet import defer from twisted.internet import defer
from .federation_base import FederationBase
from .units import Transaction, Edu from .units import Transaction, Edu
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
from synapse.events.utils import prune_event
from syutil.jsonutil import encode_canonical_json
from synapse.crypto.event_signing import check_event_content_hash
from synapse.api.errors import FederationError, SynapseError from synapse.api.errors import FederationError, SynapseError
@ -35,7 +31,7 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class FederationServer(object): class FederationServer(FederationBase):
def set_handler(self, handler): def set_handler(self, handler):
"""Sets the handler that the replication layer will use to communicate """Sets the handler that the replication layer will use to communicate
receipt of new PDUs from other home servers. The required methods are receipt of new PDUs from other home servers. The required methods are
@ -251,17 +247,20 @@ class FederationServer(object):
Deferred: Results in `dict` with the same format as `content` Deferred: Results in `dict` with the same format as `content`
""" """
auth_chain = [ auth_chain = [
(yield self._check_sigs_and_hash(self.event_from_pdu_json(e))) self.event_from_pdu_json(e)
for e in content["auth_chain"] for e in content["auth_chain"]
] ]
missing = [ signed_auth = yield self._check_sigs_and_hash_and_fetch(
(yield self._check_sigs_and_hash(self.event_from_pdu_json(e))) origin, auth_chain, outlier=True
for e in content.get("missing", []) )
]
ret = yield self.handler.on_query_auth( ret = yield self.handler.on_query_auth(
origin, event_id, auth_chain, content.get("rejects", []), missing origin,
event_id,
signed_auth,
content.get("rejects", []),
content.get("missing", []),
) )
time_now = self._clock.time_msec() time_now = self._clock.time_msec()
@ -426,37 +425,3 @@ class FederationServer(object):
event.internal_metadata.outlier = outlier event.internal_metadata.outlier = outlier
return event return event
@defer.inlineCallbacks
def _check_sigs_and_hash(self, pdu):
"""Throws a SynapseError if the PDU does not have the correct
signatures.
Returns:
FrozenEvent: Either the given event or it redacted if it failed the
content hash check.
"""
# Check signatures are correct.
redacted_event = prune_event(pdu)
redacted_pdu_json = redacted_event.get_pdu_json()
try:
yield self.keyring.verify_json_for_server(
pdu.origin, redacted_pdu_json
)
except SynapseError:
logger.warn(
"Signature check failed for %s redacted to %s",
encode_canonical_json(pdu.get_pdu_json()),
encode_canonical_json(redacted_pdu_json),
)
raise
if not check_event_content_hash(pdu):
logger.warn(
"Event content has been tampered, redacting %s, %s",
pdu.event_id, encode_canonical_json(pdu.get_dict())
)
defer.returnValue(redacted_event)
defer.returnValue(pdu)

View File

@ -30,6 +30,7 @@ from synapse.types import UserID
from twisted.internet import defer from twisted.internet import defer
import itertools
import logging import logging
@ -123,8 +124,21 @@ class FederationHandler(BaseHandler):
logger.debug("Got event for room we're not in.") logger.debug("Got event for room we're not in.")
current_state = state current_state = state
event_ids = set()
if state:
event_ids |= {e.event_id for e in state}
if auth_chain:
event_ids |= {e.event_id for e in auth_chain}
seen_ids = (yield self.store.have_events(event_ids)).keys()
if state and auth_chain is not None: if state and auth_chain is not None:
for e in state: # If we have any state or auth_chain given to us by the replication
# layer, then we should handle them (if we haven't before.)
for e in itertools.chain(auth_chain, state):
if e.event_id in seen_ids:
continue
e.internal_metadata.outlier = True e.internal_metadata.outlier = True
try: try:
auth_ids = [e_id for e_id, _ in e.auth_events] auth_ids = [e_id for e_id, _ in e.auth_events]
@ -132,7 +146,10 @@ class FederationHandler(BaseHandler):
(e.type, e.state_key): e for e in auth_chain (e.type, e.state_key): e for e in auth_chain
if e.event_id in auth_ids if e.event_id in auth_ids
} }
yield self._handle_new_event(origin, e, auth_events=auth) yield self._handle_new_event(
origin, e, auth_events=auth
)
seen_ids.add(e.event_id)
except: except:
logger.exception( logger.exception(
"Failed to handle state event %s", "Failed to handle state event %s",
@ -498,6 +515,8 @@ class FederationHandler(BaseHandler):
"Failed to get destination from event %s", s.event_id "Failed to get destination from event %s", s.event_id
) )
destinations.remove(origin)
logger.debug( logger.debug(
"on_send_join_request: Sending event: %s, signatures: %s", "on_send_join_request: Sending event: %s, signatures: %s",
event.event_id, event.event_id,
@ -618,6 +637,7 @@ class FederationHandler(BaseHandler):
event = yield self.store.get_event( event = yield self.store.get_event(
event_id, event_id,
allow_none=True, allow_none=True,
allow_rejected=True,
) )
if event: if event:
@ -701,6 +721,8 @@ class FederationHandler(BaseHandler):
context.rejected = RejectedReason.AUTH_ERROR context.rejected = RejectedReason.AUTH_ERROR
# FIXME: Don't store as rejected with AUTH_ERROR if we haven't
# seen all the auth events.
yield self.store.persist_event( yield self.store.persist_event(
event, event,
context=context, context=context,
@ -750,7 +772,7 @@ class FederationHandler(BaseHandler):
) )
) )
logger.debug("on_query_auth reutrning: %s", ret) logger.debug("on_query_auth returning: %s", ret)
defer.returnValue(ret) defer.returnValue(ret)
@ -770,6 +792,7 @@ class FederationHandler(BaseHandler):
if missing_auth: if missing_auth:
logger.debug("Missing auth: %s", missing_auth) logger.debug("Missing auth: %s", missing_auth)
# If we don't have all the auth events, we need to get them. # If we don't have all the auth events, we need to get them.
try:
remote_auth_chain = yield self.replication_layer.get_event_auth( remote_auth_chain = yield self.replication_layer.get_event_auth(
origin, event.room_id, event.event_id origin, event.room_id, event.event_id
) )
@ -805,6 +828,9 @@ class FederationHandler(BaseHandler):
auth_events[(e.type, e.state_key)] = e auth_events[(e.type, e.state_key)] = e
except AuthError: except AuthError:
pass pass
except:
# FIXME:
logger.exception("Failed to get auth chain")
# FIXME: Assumes we have and stored all the state for all the # FIXME: Assumes we have and stored all the state for all the
# prev_events # prev_events
@ -816,9 +842,12 @@ class FederationHandler(BaseHandler):
logger.debug("Different auth: %s", different_auth) logger.debug("Different auth: %s", different_auth)
# 1. Get what we think is the auth chain. # 1. Get what we think is the auth chain.
auth_ids = self.auth.compute_auth_events(event, context) auth_ids = self.auth.compute_auth_events(
event, context.current_state
)
local_auth_chain = yield self.store.get_auth_chain(auth_ids) local_auth_chain = yield self.store.get_auth_chain(auth_ids)
try:
# 2. Get remote difference. # 2. Get remote difference.
result = yield self.replication_layer.query_auth( result = yield self.replication_layer.query_auth(
origin, origin,
@ -861,6 +890,10 @@ class FederationHandler(BaseHandler):
except AuthError: except AuthError:
pass pass
except:
# FIXME:
logger.exception("Failed to query auth chain")
# 4. Look at rejects and their proofs. # 4. Look at rejects and their proofs.
# TODO. # TODO.
@ -983,7 +1016,7 @@ class FederationHandler(BaseHandler):
if reason is None: if reason is None:
# FIXME: ERRR?! # FIXME: ERRR?!
logger.warn("Could not find reason for %s", e.event_id) logger.warn("Could not find reason for %s", e.event_id)
raise RuntimeError("") raise RuntimeError("Could not find reason for %s" % e.event_id)
reason_map[e.event_id] = reason reason_map[e.event_id] = reason

View File

@ -37,7 +37,10 @@ def _get_state_key_from_event(event):
KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key")) KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
AuthEventTypes = (EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,) AuthEventTypes = (
EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,
EventTypes.JoinRules,
)
class StateHandler(object): class StateHandler(object):
@ -100,7 +103,9 @@ class StateHandler(object):
context.state_group = None context.state_group = None
if hasattr(event, "auth_events") and event.auth_events: if hasattr(event, "auth_events") and event.auth_events:
auth_ids = zip(*event.auth_events)[0] auth_ids = self.hs.get_auth().compute_auth_events(
event, context.current_state
)
context.auth_events = { context.auth_events = {
k: v k: v
for k, v in context.current_state.items() for k, v in context.current_state.items()
@ -146,7 +151,9 @@ class StateHandler(object):
event.unsigned["replaces_state"] = replaces.event_id event.unsigned["replaces_state"] = replaces.event_id
if hasattr(event, "auth_events") and event.auth_events: if hasattr(event, "auth_events") and event.auth_events:
auth_ids = zip(*event.auth_events)[0] auth_ids = self.hs.get_auth().compute_auth_events(
event, context.current_state
)
context.auth_events = { context.auth_events = {
k: v k: v
for k, v in context.current_state.items() for k, v in context.current_state.items()
@ -258,6 +265,15 @@ class StateHandler(object):
auth_events.update(resolved_state) auth_events.update(resolved_state)
for key, events in conflicted_state.items():
if key[0] == EventTypes.JoinRules:
resolved_state[key] = self._resolve_auth_events(
events,
auth_events
)
auth_events.update(resolved_state)
for key, events in conflicted_state.items(): for key, events in conflicted_state.items():
if key[0] == EventTypes.Member: if key[0] == EventTypes.Member:
resolved_state[key] = self._resolve_auth_events( resolved_state[key] = self._resolve_auth_events(

View File

@ -128,21 +128,144 @@ class DataStore(RoomMemberStore, RoomStore,
pass pass
@defer.inlineCallbacks @defer.inlineCallbacks
def get_event(self, event_id, allow_none=False): def get_event(self, event_id, check_redacted=True,
events = yield self._get_events([event_id]) get_prev_content=False, allow_rejected=False,
allow_none=False):
"""Get an event from the database by event_id.
if not events: Args:
if allow_none: event_id (str): The event_id of the event to fetch
defer.returnValue(None) check_redacted (bool): If True, check if event has been redacted
else: and redact it.
get_prev_content (bool): If True and event is a state event,
include the previous states content in the unsigned field.
allow_rejected (bool): If True return rejected events.
allow_none (bool): If True, return None if no event found, if
False throw an exception.
Returns:
Deferred : A FrozenEvent.
"""
event = yield self.runInteraction(
"get_event", self._get_event_txn,
event_id,
check_redacted=check_redacted,
get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
)
if not event and not allow_none:
raise RuntimeError("Could not find event %s" % (event_id,)) raise RuntimeError("Could not find event %s" % (event_id,))
defer.returnValue(events[0]) defer.returnValue(event)
@log_function @log_function
def _persist_event_txn(self, txn, event, context, backfilled, def _persist_event_txn(self, txn, event, context, backfilled,
stream_ordering=None, is_new_state=True, stream_ordering=None, is_new_state=True,
current_state=None): current_state=None):
# We purposefully do this first since if we include a `current_state`
# key, we *want* to update the `current_state_events` table
if current_state:
txn.execute(
"DELETE FROM current_state_events WHERE room_id = ?",
(event.room_id,)
)
for s in current_state:
self._simple_insert_txn(
txn,
"current_state_events",
{
"event_id": s.event_id,
"room_id": s.room_id,
"type": s.type,
"state_key": s.state_key,
},
or_replace=True,
)
if event.is_state() and is_new_state:
if not backfilled and not context.rejected:
self._simple_insert_txn(
txn,
table="state_forward_extremities",
values={
"event_id": event.event_id,
"room_id": event.room_id,
"type": event.type,
"state_key": event.state_key,
},
or_replace=True,
)
for prev_state_id, _ in event.prev_state:
self._simple_delete_txn(
txn,
table="state_forward_extremities",
keyvalues={
"event_id": prev_state_id,
}
)
outlier = event.internal_metadata.is_outlier()
if not outlier:
self._store_state_groups_txn(txn, event, context)
self._update_min_depth_for_room_txn(
txn,
event.room_id,
event.depth
)
self._handle_prev_events(
txn,
outlier=outlier,
event_id=event.event_id,
prev_events=event.prev_events,
room_id=event.room_id,
)
have_persisted = self._simple_select_one_onecol_txn(
txn,
table="event_json",
keyvalues={"event_id": event.event_id},
retcol="event_id",
allow_none=True,
)
metadata_json = encode_canonical_json(
event.internal_metadata.get_dict()
)
# If we have already persisted this event, we don't need to do any
# more processing.
# The processing above must be done on every call to persist event,
# since they might not have happened on previous calls. For example,
# if we are persisting an event that we had persisted as an outlier,
# but is no longer one.
if have_persisted:
if not outlier:
sql = (
"UPDATE event_json SET internal_metadata = ?"
" WHERE event_id = ?"
)
txn.execute(
sql,
(metadata_json.decode("UTF-8"), event.event_id,)
)
sql = (
"UPDATE events SET outlier = 0"
" WHERE event_id = ?"
)
txn.execute(
sql,
(event.event_id,)
)
return
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
self._store_room_member_txn(txn, event) self._store_room_member_txn(txn, event)
elif event.type == EventTypes.Feedback: elif event.type == EventTypes.Feedback:
@ -154,8 +277,6 @@ class DataStore(RoomMemberStore, RoomStore,
elif event.type == EventTypes.Redaction: elif event.type == EventTypes.Redaction:
self._store_redaction(txn, event) self._store_redaction(txn, event)
outlier = event.internal_metadata.is_outlier()
event_dict = { event_dict = {
k: v k: v
for k, v in event.get_dict().items() for k, v in event.get_dict().items()
@ -165,10 +286,6 @@ class DataStore(RoomMemberStore, RoomStore,
] ]
} }
metadata_json = encode_canonical_json(
event.internal_metadata.get_dict()
)
self._simple_insert_txn( self._simple_insert_txn(
txn, txn,
table="event_json", table="event_json",
@ -224,41 +341,10 @@ class DataStore(RoomMemberStore, RoomStore,
) )
raise _RollbackButIsFineException("_persist_event") raise _RollbackButIsFineException("_persist_event")
self._handle_prev_events(
txn,
outlier=outlier,
event_id=event.event_id,
prev_events=event.prev_events,
room_id=event.room_id,
)
if not outlier:
self._store_state_groups_txn(txn, event, context)
if context.rejected: if context.rejected:
self._store_rejections_txn(txn, event.event_id, context.rejected) self._store_rejections_txn(txn, event.event_id, context.rejected)
if current_state: if event.is_state():
txn.execute(
"DELETE FROM current_state_events WHERE room_id = ?",
(event.room_id,)
)
for s in current_state:
self._simple_insert_txn(
txn,
"current_state_events",
{
"event_id": s.event_id,
"room_id": s.room_id,
"type": s.type,
"state_key": s.state_key,
},
or_replace=True,
)
is_state = hasattr(event, "state_key") and event.state_key is not None
if is_state:
vals = { vals = {
"event_id": event.event_id, "event_id": event.event_id,
"room_id": event.room_id, "room_id": event.room_id,
@ -266,6 +352,7 @@ class DataStore(RoomMemberStore, RoomStore,
"state_key": event.state_key, "state_key": event.state_key,
} }
# TODO: How does this work with backfilling?
if hasattr(event, "replaces_state"): if hasattr(event, "replaces_state"):
vals["prev_state"] = event.replaces_state vals["prev_state"] = event.replaces_state
@ -302,28 +389,6 @@ class DataStore(RoomMemberStore, RoomStore,
or_ignore=True, or_ignore=True,
) )
if not backfilled and not context.rejected:
self._simple_insert_txn(
txn,
table="state_forward_extremities",
values={
"event_id": event.event_id,
"room_id": event.room_id,
"type": event.type,
"state_key": event.state_key,
},
or_replace=True,
)
for prev_state_id, _ in event.prev_state:
self._simple_delete_txn(
txn,
table="state_forward_extremities",
keyvalues={
"event_id": prev_state_id,
}
)
for hash_alg, hash_base64 in event.hashes.items(): for hash_alg, hash_base64 in event.hashes.items():
hash_bytes = decode_base64(hash_base64) hash_bytes = decode_base64(hash_base64)
self._store_event_content_hash_txn( self._store_event_content_hash_txn(
@ -354,13 +419,6 @@ class DataStore(RoomMemberStore, RoomStore,
txn, event.event_id, ref_alg, ref_hash_bytes txn, event.event_id, ref_alg, ref_hash_bytes
) )
if not outlier:
self._update_min_depth_for_room_txn(
txn,
event.room_id,
event.depth
)
def _store_redaction(self, txn, event): def _store_redaction(self, txn, event):
txn.execute( txn.execute(
"INSERT OR IGNORE INTO redactions " "INSERT OR IGNORE INTO redactions "
@ -477,6 +535,9 @@ class DataStore(RoomMemberStore, RoomStore,
the rejected reason string if we rejected the event, else maps to the rejected reason string if we rejected the event, else maps to
None. None.
""" """
if not event_ids:
return defer.succeed({})
def f(txn): def f(txn):
sql = ( sql = (
"SELECT e.event_id, reason FROM events as e " "SELECT e.event_id, reason FROM events as e "

View File

@ -91,7 +91,10 @@ class FederationTestCase(unittest.TestCase):
self.datastore.persist_event.return_value = defer.succeed(None) self.datastore.persist_event.return_value = defer.succeed(None)
self.datastore.get_room.return_value = defer.succeed(True) self.datastore.get_room.return_value = defer.succeed(True)
self.auth.check_host_in_room.return_value = defer.succeed(True) self.auth.check_host_in_room.return_value = defer.succeed(True)
self.datastore.have_events.return_value = defer.succeed({})
def have_events(event_ids):
return defer.succeed({})
self.datastore.have_events.side_effect = have_events
def annotate(ev, old_state=None): def annotate(ev, old_state=None):
context = Mock() context = Mock()