Fix bugs with invites/joins across federatiom.

Both in terms of auth and not trying to fetch missing PDUs for invites,
joins etc.
This commit is contained in:
Erik Johnston 2014-11-12 11:22:51 +00:00
parent 2c400363e8
commit 6fea478d2e
7 changed files with 54 additions and 39 deletions

View File

@ -36,6 +36,7 @@ class Auth(object):
def __init__(self, hs): def __init__(self, hs):
self.hs = hs self.hs = hs
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.state = hs.get_state_handler()
def check(self, event, raises=False): def check(self, event, raises=False):
""" Checks if this event is correctly authed. """ Checks if this event is correctly authed.
@ -90,7 +91,7 @@ class Auth(object):
) )
logger.info("Denying! %s", event) logger.info("Denying! %s", event)
if raises: if raises:
raise e raise
return False return False
@ -109,9 +110,21 @@ class Auth(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def check_host_in_room(self, room_id, host): def check_host_in_room(self, room_id, host):
joined_hosts = yield self.store.get_joined_hosts_for_room(room_id) curr_state = yield self.state.get_current_state(room_id)
defer.returnValue(host in joined_hosts) for event in curr_state:
if event.type == RoomMemberEvent.TYPE:
try:
if self.hs.parse_userid(event.state_key).domain != host:
continue
except:
logger.warn("state_key not user_id: %s", event.state_key)
continue
if event.content["membership"] == Membership.JOIN:
defer.returnValue(True)
defer.returnValue(False)
def check_event_sender_in_room(self, event): def check_event_sender_in_room(self, event):
key = (RoomMemberEvent.TYPE, event.user_id, ) key = (RoomMemberEvent.TYPE, event.user_id, )

View File

@ -267,8 +267,6 @@ class ReplicationLayer(object):
transaction = Transaction(**transaction_data) transaction = Transaction(**transaction_data)
pdus = [Pdu(outlier=True, **p) for p in transaction.pdus] pdus = [Pdu(outlier=True, **p) for p in transaction.pdus]
for pdu in pdus:
yield self._handle_new_pdu(destination, pdu)
defer.returnValue(pdus) defer.returnValue(pdus)
@ -452,15 +450,12 @@ class ReplicationLayer(object):
) )
logger.debug("Got content: %s", content) logger.debug("Got content: %s", content)
state = [Pdu(outlier=True, **p) for p in content.get("state", [])] state = [Pdu(outlier=True, **p) for p in content.get("state", [])]
for pdu in state:
yield self._handle_new_pdu(destination, pdu)
auth_chain = [ auth_chain = [
Pdu(outlier=True, **p) for p in content.get("auth_chain", []) Pdu(outlier=True, **p) for p in content.get("auth_chain", [])
] ]
for pdu in auth_chain:
yield self._handle_new_pdu(destination, pdu)
defer.returnValue(state) defer.returnValue(state)

View File

@ -229,12 +229,6 @@ class FederationHandler(BaseHandler):
@log_function @log_function
@defer.inlineCallbacks @defer.inlineCallbacks
def do_invite_join(self, target_host, room_id, joinee, content, snapshot): def do_invite_join(self, target_host, room_id, joinee, content, snapshot):
hosts = yield self.store.get_joined_hosts_for_room(room_id)
if self.hs.hostname in hosts:
# We are already in the room.
logger.debug("We're already in the room apparently")
defer.returnValue(False)
pdu = yield self.replication_layer.make_join( pdu = yield self.replication_layer.make_join(
target_host, target_host,
room_id, room_id,
@ -268,7 +262,7 @@ class FederationHandler(BaseHandler):
logger.debug("do_invite_join state: %s", state) logger.debug("do_invite_join state: %s", state)
is_new_state = yield self.state_handler.annotate_event_with_state( yield self.state_handler.annotate_event_with_state(
event, event,
old_state=state old_state=state
) )
@ -296,13 +290,13 @@ class FederationHandler(BaseHandler):
yield self.store.persist_event( yield self.store.persist_event(
e, e,
backfilled=False, backfilled=False,
is_new_state=False is_new_state=True
) )
yield self.store.persist_event( yield self.store.persist_event(
event, event,
backfilled=False, backfilled=False,
is_new_state=is_new_state is_new_state=True
) )
finally: finally:
room_queue = self.room_queues[room_id] room_queue = self.room_queues[room_id]

View File

@ -24,6 +24,7 @@ from synapse.api.events.room import (
RoomTopicEvent, RoomNameEvent, RoomJoinRulesEvent, RoomTopicEvent, RoomNameEvent, RoomJoinRulesEvent,
) )
from synapse.util import stringutils from synapse.util import stringutils
from synapse.util.async import run_on_reactor
from ._base import BaseHandler from ._base import BaseHandler
import logging import logging
@ -432,9 +433,12 @@ class RoomMemberHandler(BaseHandler):
# that we are allowed to join when we decide whether or not we # that we are allowed to join when we decide whether or not we
# need to do the invite/join dance. # need to do the invite/join dance.
hosts = yield self.store.get_joined_hosts_for_room(room_id) is_host_in_room = yield self.auth.check_host_in_room(
event.room_id,
self.hs.hostname
)
if self.hs.hostname in hosts: if is_host_in_room:
should_do_dance = False should_do_dance = False
elif room_host: elif room_host:
should_do_dance = True should_do_dance = True
@ -517,6 +521,8 @@ class RoomMemberHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _do_local_membership_update(self, event, membership, snapshot, def _do_local_membership_update(self, event, membership, snapshot,
do_auth): do_auth):
yield run_on_reactor()
# If we're inviting someone, then we should also send it to that # If we're inviting someone, then we should also send it to that
# HS. # HS.
target_user_id = event.state_key target_user_id = event.state_key

View File

@ -186,6 +186,7 @@ class DataStore(RoomMemberStore, RoomStore,
"events", "events",
vals, vals,
or_replace=(not outlier), or_replace=(not outlier),
or_ignore=bool(outlier),
) )
except: except:
logger.warn( logger.warn(
@ -217,7 +218,12 @@ class DataStore(RoomMemberStore, RoomStore,
if hasattr(event, "replaces_state"): if hasattr(event, "replaces_state"):
vals["prev_state"] = event.replaces_state vals["prev_state"] = event.replaces_state
self._simple_insert_txn(txn, "state_events", vals) self._simple_insert_txn(
txn,
"state_events",
vals,
or_replace=True,
)
self._simple_insert_txn( self._simple_insert_txn(
txn, txn,
@ -227,7 +233,8 @@ class DataStore(RoomMemberStore, RoomStore,
"room_id": event.room_id, "room_id": event.room_id,
"type": event.type, "type": event.type,
"state_key": event.state_key, "state_key": event.state_key,
} },
or_replace=True,
) )
for e_id, h in event.prev_state: for e_id, h in event.prev_state:
@ -252,7 +259,8 @@ class DataStore(RoomMemberStore, RoomStore,
"room_id": event.room_id, "room_id": event.room_id,
"type": event.type, "type": event.type,
"state_key": event.state_key, "state_key": event.state_key,
} },
or_replace=True,
) )
for prev_state_id, _ in event.prev_state: for prev_state_id, _ in event.prev_state:

View File

@ -70,7 +70,8 @@ class StateStore(SQLBaseStore):
values={ values={
"room_id": event.room_id, "room_id": event.room_id,
"event_id": event.event_id, "event_id": event.event_id,
} },
or_ignore=True,
) )
for state in event.state_events.values(): for state in event.state_events.values():
@ -83,7 +84,8 @@ class StateStore(SQLBaseStore):
"type": state.type, "type": state.type,
"state_key": state.state_key, "state_key": state.state_key,
"event_id": state.event_id, "event_id": state.event_id,
} },
or_ignore=True,
) )
self._simple_insert_txn( self._simple_insert_txn(
@ -92,5 +94,6 @@ class StateStore(SQLBaseStore):
values={ values={
"state_group": state_group, "state_group": state_group,
"event_id": event.event_id, "event_id": event.event_id,
} },
or_replace=True,
) )

View File

@ -44,7 +44,6 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
]), ]),
datastore=NonCallableMock(spec_set=[ datastore=NonCallableMock(spec_set=[
"persist_event", "persist_event",
"get_joined_hosts_for_room",
"get_room_member", "get_room_member",
"get_room", "get_room",
"store_room", "store_room",
@ -58,9 +57,14 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
"profile_handler", "profile_handler",
"federation_handler", "federation_handler",
]), ]),
auth=NonCallableMock(spec_set=["check", "add_auth_events"]), auth=NonCallableMock(spec_set=[
"check",
"add_auth_events",
"check_host_in_room",
]),
state_handler=NonCallableMock(spec_set=[ state_handler=NonCallableMock(spec_set=[
"annotate_event_with_state", "annotate_event_with_state",
"get_current_state",
]), ]),
config=self.mock_config, config=self.mock_config,
) )
@ -76,6 +80,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.state_handler = hs.get_state_handler() self.state_handler = hs.get_state_handler()
self.distributor = hs.get_distributor() self.distributor = hs.get_distributor()
self.auth = hs.get_auth()
self.hs = hs self.hs = hs
self.handlers.federation_handler = self.federation self.handlers.federation_handler = self.federation
@ -108,11 +113,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
content=content, content=content,
) )
joined = ["red", "green"] self.auth.check_host_in_room.return_value = defer.succeed(True)
self.datastore.get_joined_hosts_for_room.return_value = (
defer.succeed(joined)
)
store_id = "store_id_fooo" store_id = "store_id_fooo"
self.datastore.persist_event.return_value = defer.succeed(store_id) self.datastore.persist_event.return_value = defer.succeed(store_id)
@ -164,12 +165,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
room_id=room_id, room_id=room_id,
) )
joined = ["red", "green"] self.auth.check_host_in_room.return_value = defer.succeed(True)
def get_joined(*args):
return defer.succeed(joined)
self.datastore.get_joined_hosts_for_room.side_effect = get_joined
store_id = "store_id_fooo" store_id = "store_id_fooo"
self.datastore.persist_event.return_value = defer.succeed(store_id) self.datastore.persist_event.return_value = defer.succeed(store_id)