Finish implementing the new join dance.

This commit is contained in:
Erik Johnston 2014-10-17 15:04:17 +01:00
parent 1116f5330e
commit f71627567b
6 changed files with 225 additions and 129 deletions

View File

@ -48,6 +48,15 @@ class Auth(object):
""" """
try: try:
if hasattr(event, "room_id"): if hasattr(event, "room_id"):
if not event.old_state_events:
# Oh, we don't know what the state of the room was, so we
# are trusting that this is allowed (at least for now)
defer.returnValue(True)
if hasattr(event, "outlier") and event.outlier:
# TODO (erikj): Auth for outliers is done differently.
defer.returnValue(True)
is_state = hasattr(event, "state_key") is_state = hasattr(event, "state_key")
if event.type == RoomMemberEvent.TYPE: if event.type == RoomMemberEvent.TYPE:

View File

@ -51,12 +51,20 @@ class EventFactory(object):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.hs = hs self.hs = hs
self.event_id_count = 0
def create_event_id(self):
i = str(self.event_id_count)
self.event_id_count += 1
local_part = str(int(self.clock.time())) + i + random_string(5)
return "%s@%s" % (local_part, self.hs.hostname)
def create_event(self, etype=None, **kwargs): def create_event(self, etype=None, **kwargs):
kwargs["type"] = etype kwargs["type"] = etype
if "event_id" not in kwargs: if "event_id" not in kwargs:
kwargs["event_id"] = "%s@%s" % ( kwargs["event_id"] = self.create_event_id()
random_string(10), self.hs.hostname
)
if "ts" not in kwargs: if "ts" not in kwargs:
kwargs["ts"] = int(self.clock.time_msec()) kwargs["ts"] = int(self.clock.time_msec())

View File

@ -244,13 +244,14 @@ class ReplicationLayer(object):
pdu = None pdu = None
if pdu_list: if pdu_list:
pdu = pdu_list[0] pdu = pdu_list[0]
yield self._handle_new_pdu(pdu) yield self._handle_new_pdu(destination, pdu)
defer.returnValue(pdu) defer.returnValue(pdu)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def get_state_for_context(self, destination, context): def get_state_for_context(self, destination, context, pdu_id=None,
pdu_origin=None):
"""Requests all of the `current` state PDUs for a given context from """Requests all of the `current` state PDUs for a given context from
a remote home server. a remote home server.
@ -263,13 +264,14 @@ class ReplicationLayer(object):
""" """
transaction_data = yield self.transport_layer.get_context_state( transaction_data = yield self.transport_layer.get_context_state(
destination, context) destination, context, pdu_id=pdu_id, pdu_origin=pdu_origin,
)
transaction = Transaction(**transaction_data) transaction = Transaction(**transaction_data)
pdus = [Pdu(outlier=True, **p) for p in transaction.pdus] pdus = [Pdu(outlier=True, **p) for p in transaction.pdus]
for pdu in pdus: for pdu in pdus:
yield self._handle_new_pdu(pdu) yield self._handle_new_pdu(destination, pdu)
defer.returnValue(pdus) defer.returnValue(pdus)
@ -315,7 +317,7 @@ class ReplicationLayer(object):
dl = [] dl = []
for pdu in pdu_list: for pdu in pdu_list:
dl.append(self._handle_new_pdu(pdu)) dl.append(self._handle_new_pdu(transaction.origin, pdu))
if hasattr(transaction, "edus"): if hasattr(transaction, "edus"):
for edu in [Edu(**x) for x in transaction.edus]: for edu in [Edu(**x) for x in transaction.edus]:
@ -347,14 +349,19 @@ class ReplicationLayer(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def on_context_state_request(self, context): def on_context_state_request(self, context, pdu_id, pdu_origin):
results = yield self.store.get_current_state_for_context( if pdu_id and pdu_origin:
context pdus = yield self.handler.get_state_for_pdu(
) pdu_id, pdu_origin
)
else:
results = yield self.store.get_current_state_for_context(
context
)
pdus = [Pdu.from_pdu_tuple(p) for p in results]
logger.debug("Context returning %d results", len(results)) logger.debug("Context returning %d results", len(pdus))
pdus = [Pdu.from_pdu_tuple(p) for p in results]
defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict())) defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
@defer.inlineCallbacks @defer.inlineCallbacks
@ -396,9 +403,10 @@ class ReplicationLayer(object):
defer.returnValue( defer.returnValue(
(404, "No handler for Query type '%s'" % (query_type, )) (404, "No handler for Query type '%s'" % (query_type, ))
) )
@defer.inlineCallbacks
def on_make_join_request(self, context, user_id): def on_make_join_request(self, context, user_id):
return self.handler.on_make_join_request(context, user_id) pdu = yield self.handler.on_make_join_request(context, user_id)
defer.returnValue(pdu.get_dict())
@defer.inlineCallbacks @defer.inlineCallbacks
def on_send_join_request(self, origin, content): def on_send_join_request(self, origin, content):
@ -406,13 +414,27 @@ class ReplicationLayer(object):
state = yield self.handler.on_send_join_request(origin, pdu) state = yield self.handler.on_send_join_request(origin, pdu)
defer.returnValue((200, self._transaction_from_pdus(state).get_dict())) defer.returnValue((200, self._transaction_from_pdus(state).get_dict()))
@defer.inlineCallbacks
def make_join(self, destination, context, user_id): def make_join(self, destination, context, user_id):
return self.transport_layer.make_join( pdu_dict = yield self.transport_layer.make_join(
destination=destination, destination=destination,
context=context, context=context,
user_id=user_id, user_id=user_id,
) )
logger.debug("Got response to make_join: %s", pdu_dict)
defer.returnValue(Pdu(**pdu_dict))
def send_join(self, destination, pdu):
return self.transport_layer.send_join(
destination,
pdu.context,
pdu.pdu_id,
pdu.origin,
pdu.get_dict(),
)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def _get_persisted_pdu(self, pdu_id, pdu_origin): def _get_persisted_pdu(self, pdu_id, pdu_origin):
@ -443,7 +465,7 @@ class ReplicationLayer(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def _handle_new_pdu(self, pdu, backfilled=False): def _handle_new_pdu(self, origin, pdu, backfilled=False):
# We reprocess pdus when we have seen them only as outliers # We reprocess pdus when we have seen them only as outliers
existing = yield self._get_persisted_pdu(pdu.pdu_id, pdu.origin) existing = yield self._get_persisted_pdu(pdu.pdu_id, pdu.origin)
@ -452,6 +474,8 @@ class ReplicationLayer(object):
defer.returnValue({}) defer.returnValue({})
return return
state = None
# Get missing pdus if necessary. # Get missing pdus if necessary.
is_new = yield self.pdu_actions.is_new(pdu) is_new = yield self.pdu_actions.is_new(pdu)
if is_new and not pdu.outlier: if is_new and not pdu.outlier:
@ -475,12 +499,22 @@ class ReplicationLayer(object):
except: except:
# TODO(erikj): Do some more intelligent retries. # TODO(erikj): Do some more intelligent retries.
logger.exception("Failed to get PDU") logger.exception("Failed to get PDU")
else:
# We need to get the state at this event, since we have reached
# a backward extremity edge.
state = yield self.get_state_for_context(
origin, pdu.context, pdu.pdu_id, pdu.origin,
)
# Persist the Pdu, but don't mark it as processed yet. # Persist the Pdu, but don't mark it as processed yet.
yield self.store.persist_event(pdu=pdu) yield self.store.persist_event(pdu=pdu)
if not backfilled: if not backfilled:
ret = yield self.handler.on_receive_pdu(pdu, backfilled=backfilled) ret = yield self.handler.on_receive_pdu(
pdu,
backfilled=backfilled,
state=state,
)
else: else:
ret = None ret = None

View File

@ -72,7 +72,8 @@ class TransportLayer(object):
self.received_handler = None self.received_handler = None
@log_function @log_function
def get_context_state(self, destination, context): def get_context_state(self, destination, context, pdu_id=None,
pdu_origin=None):
""" Requests all state for a given context (i.e. room) from the """ Requests all state for a given context (i.e. room) from the
given server. given server.
@ -89,7 +90,14 @@ class TransportLayer(object):
subpath = "/state/%s/" % context subpath = "/state/%s/" % context
return self._do_request_for_transaction(destination, subpath) args = {}
if pdu_id and pdu_origin:
args["pdu_id"] = pdu_id
args["pdu_origin"] = pdu_origin
return self._do_request_for_transaction(
destination, subpath, args=args
)
@log_function @log_function
def get_pdu(self, destination, pdu_origin, pdu_id): def get_pdu(self, destination, pdu_origin, pdu_id):
@ -135,8 +143,10 @@ class TransportLayer(object):
subpath = "/backfill/%s/" % context subpath = "/backfill/%s/" % context
args = {"v": ["%s,%s" % (i, o) for i, o in pdu_tuples]} args = {
args["limit"] = limit "v": ["%s,%s" % (i, o) for i, o in pdu_tuples],
"limit": limit,
}
return self._do_request_for_transaction( return self._do_request_for_transaction(
dest, dest,
@ -210,6 +220,23 @@ class TransportLayer(object):
defer.returnValue(response) defer.returnValue(response)
@defer.inlineCallbacks
@log_function
def send_join(self, destination, context, pdu_id, origin, content):
path = PREFIX + "/send_join/%s/%s/%s" % (
context,
origin,
pdu_id,
)
response = yield self.client.put_json(
destination=destination,
path=path,
data=content,
)
defer.returnValue(response)
@defer.inlineCallbacks @defer.inlineCallbacks
def _authenticate_request(self, request): def _authenticate_request(self, request):
json_request = { json_request = {
@ -330,7 +357,11 @@ class TransportLayer(object):
re.compile("^" + PREFIX + "/state/([^/]*)/$"), re.compile("^" + PREFIX + "/state/([^/]*)/$"),
self._with_authentication( self._with_authentication(
lambda origin, content, query, context: lambda origin, content, query, context:
handler.on_context_state_request(context) handler.on_context_state_request(
context,
query.get("pdu_id", [None])[0],
query.get("pdu_origin", [None])[0]
)
) )
) )
@ -369,7 +400,23 @@ class TransportLayer(object):
self.server.register_path( self.server.register_path(
"GET", "GET",
re.compile("^" + PREFIX + "/make_join/([^/]*)/([^/]*)$"), re.compile("^" + PREFIX + "/make_join/([^/]*)/([^/]*)$"),
self._on_make_join_request self._with_authentication(
lambda origin, content, query, context, user_id:
self._on_make_join_request(
origin, content, query, context, user_id
)
)
)
self.server.register_path(
"PUT",
re.compile("^" + PREFIX + "/send_join/([^/]*)/([^/]*)/([^/]*)$"),
self._with_authentication(
lambda origin, content, query, context, pdu_origin, pdu_id:
self._on_send_join_request(
origin, content, query,
)
)
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -460,18 +507,23 @@ class TransportLayer(object):
context, versions, limit context, versions, limit
) )
@defer.inlineCallbacks
@log_function @log_function
def _on_make_join_request(self, origin, content, query, context, user_id): def _on_make_join_request(self, origin, content, query, context, user_id):
return self.request_handler.on_make_join_request( content = yield self.request_handler.on_make_join_request(
context, user_id, context, user_id,
) )
defer.returnValue((200, content))
@defer.inlineCallbacks
@log_function @log_function
def _on_send_join_request(self, origin, content, query): def _on_send_join_request(self, origin, content, query):
return self.request_handler.on_send_join_request( content = yield self.request_handler.on_send_join_request(
origin, content, origin, content,
) )
defer.returnValue((200, content))
class TransportReceivedHandler(object): class TransportReceivedHandler(object):
""" Callbacks used when we receive a transaction """ Callbacks used when we receive a transaction

View File

@ -20,7 +20,7 @@ from ._base import BaseHandler
from synapse.api.events.room import InviteJoinEvent, RoomMemberEvent from synapse.api.events.room import InviteJoinEvent, RoomMemberEvent
from synapse.api.constants import Membership from synapse.api.constants import Membership
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.federation.pdu_codec import PduCodec from synapse.federation.pdu_codec import PduCodec, encode_event_id
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
@ -87,7 +87,7 @@ class FederationHandler(BaseHandler):
@log_function @log_function
@defer.inlineCallbacks @defer.inlineCallbacks
def on_receive_pdu(self, pdu, backfilled): def on_receive_pdu(self, pdu, backfilled, state=None):
""" Called by the ReplicationLayer when we have a new pdu. We need to """ Called by the ReplicationLayer when we have a new pdu. We need to
do auth checks and put it through the StateHandler. do auth checks and put it through the StateHandler.
""" """
@ -95,7 +95,10 @@ class FederationHandler(BaseHandler):
logger.debug("Got event: %s", event.event_id) logger.debug("Got event: %s", event.event_id)
yield self.state_handler.annotate_state_groups(event) if state:
state = [self.pdu_codec.event_from_pdu(p) for p in state]
state = {(e.type, e.state_key): e for e in state}
yield self.state_handler.annotate_state_groups(event, state=state)
logger.debug("Event: %s", event) logger.debug("Event: %s", event)
@ -108,83 +111,55 @@ class FederationHandler(BaseHandler):
) )
else: else:
is_new_state = False is_new_state = False
# TODO: Implement something in federation that allows us to # TODO: Implement something in federation that allows us to
# respond to PDU. # respond to PDU.
target_is_mine = False with (yield self.room_lock.lock(event.room_id)):
if hasattr(event, "target_host"): yield self.store.persist_event(
target_is_mine = event.target_host == self.hs.hostname event,
backfilled,
if event.type == InviteJoinEvent.TYPE: is_new_state=is_new_state
if not target_is_mine:
logger.debug("Ignoring invite/join event %s", event)
return
# If we receive an invite/join event then we need to join the
# sender to the given room.
# TODO: We should probably auth this or some such
content = event.content
content.update({"membership": Membership.JOIN})
new_event = self.event_factory.create_event(
etype=RoomMemberEvent.TYPE,
state_key=event.user_id,
room_id=event.room_id,
user_id=event.user_id,
membership=Membership.JOIN,
content=content
) )
yield self.hs.get_handlers().room_member_handler.change_membership( room = yield self.store.get_room(event.room_id)
new_event,
do_auth=False, if not room:
# Huh, let's try and get the current state
try:
yield self.replication_layer.get_state_for_context(
event.origin, event.room_id, pdu.pdu_id, pdu.origin,
)
hosts = yield self.store.get_joined_hosts_for_room(
event.room_id
)
if self.hs.hostname in hosts:
try:
yield self.store.store_room(
room_id=event.room_id,
room_creator_user_id="",
is_public=False,
)
except:
pass
except:
logger.exception(
"Failed to get current state for room %s",
event.room_id
)
if not backfilled:
extra_users = []
if event.type == RoomMemberEvent.TYPE:
target_user_id = event.state_key
target_user = self.hs.parse_userid(target_user_id)
extra_users.append(target_user)
yield self.notifier.on_new_room_event(
event, extra_users=extra_users
) )
else:
with (yield self.room_lock.lock(event.room_id)):
yield self.store.persist_event(
event,
backfilled,
is_new_state=is_new_state
)
room = yield self.store.get_room(event.room_id)
if not room:
# Huh, let's try and get the current state
try:
yield self.replication_layer.get_state_for_context(
event.origin, event.room_id
)
hosts = yield self.store.get_joined_hosts_for_room(
event.room_id
)
if self.hs.hostname in hosts:
try:
yield self.store.store_room(
room_id=event.room_id,
room_creator_user_id="",
is_public=False,
)
except:
pass
except:
logger.exception(
"Failed to get current state for room %s",
event.room_id
)
if not backfilled:
extra_users = []
if event.type == RoomMemberEvent.TYPE:
target_user_id = event.state_key
target_user = self.hs.parse_userid(target_user_id)
extra_users.append(target_user)
yield self.notifier.on_new_room_event(
event, extra_users=extra_users
)
if event.type == RoomMemberEvent.TYPE: if event.type == RoomMemberEvent.TYPE:
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
user = self.hs.parse_userid(event.state_key) user = self.hs.parse_userid(event.state_key)
@ -214,40 +189,35 @@ class FederationHandler(BaseHandler):
@log_function @log_function
@defer.inlineCallbacks @defer.inlineCallbacks
def do_invite_join(self, target_host, room_id, joinee, content, snapshot): def do_invite_join(self, target_host, room_id, joinee, content, snapshot):
hosts = yield self.store.get_joined_hosts_for_room(room_id) hosts = yield self.store.get_joined_hosts_for_room(room_id)
if self.hs.hostname in hosts: if self.hs.hostname in hosts:
# We are already in the room. # We are already in the room.
logger.debug("We're already in the room apparently") logger.debug("We're already in the room apparently")
defer.returnValue(False) defer.returnValue(False)
# First get current state to see if we are already joined. pdu = yield self.replication_layer.make_join(
try: target_host,
yield self.replication_layer.get_state_for_context( room_id,
target_host, room_id joinee
)
hosts = yield self.store.get_joined_hosts_for_room(room_id)
if self.hs.hostname in hosts:
# Oh, we were actually in the room already.
logger.debug("We're already in the room apparently")
defer.returnValue(False)
except Exception:
logger.exception("Failed to get current state")
new_event = self.event_factory.create_event(
etype=InviteJoinEvent.TYPE,
target_host=target_host,
room_id=room_id,
user_id=joinee,
content=content
) )
new_event.destinations = [target_host] logger.debug("Got response to make_join: %s", pdu)
snapshot.fill_out_prev_events(new_event) event = self.pdu_codec.event_from_pdu(pdu)
yield self.state_handler.annotate_state_groups(new_event)
yield self.handle_new_event(new_event, snapshot) # We should assert some things.
assert(event.type == RoomMemberEvent.TYPE)
assert(event.user_id == joinee)
assert(event.state_key == joinee)
assert(event.room_id == room_id)
event.event_id = self.event_factory.create_event_id()
event.content = content
state = yield self.replication_layer.send_join(
target_host,
self.pdu_codec.pdu_from_event(event)
)
# TODO (erikj): Time out here. # TODO (erikj): Time out here.
d = defer.Deferred() d = defer.Deferred()
@ -326,14 +296,31 @@ class FederationHandler(BaseHandler):
"user_joined_room", user=user, room_id=event.room_id "user_joined_room", user=user, room_id=event.room_id
) )
pdu.destinations = yield self.store.get_joined_hosts_for_room( new_pdu = self.pdu_codec.pdu_from_event(event);
new_pdu.destinations = yield self.store.get_joined_hosts_for_room(
event.room_id event.room_id
) )
yield self.replication_layer.send_pdu(pdu) yield self.replication_layer.send_pdu(new_pdu)
defer.returnValue(event.state_events.values()) defer.returnValue(event.state_events.values())
@defer.inlineCallbacks
def get_state_for_pdu(self, pdu_id, pdu_origin):
state_groups = yield self.store.get_state_groups(
[encode_event_id(pdu_id, pdu_origin)]
)
if state_groups:
defer.returnValue(
[
self.pdu_codec.pdu_from_event(s)
for s in state_groups[0].state
]
)
else:
defer.returnValue([])
@log_function @log_function
def _on_user_joined(self, user, room_id): def _on_user_joined(self, user, room_id):
waiters = self.waiting_for_join_list.get((user.to_string(), room_id), []) waiters = self.waiting_for_join_list.get((user.to_string(), room_id), [])

View File

@ -130,7 +130,13 @@ class StateHandler(object):
defer.returnValue(is_new) defer.returnValue(is_new)
@defer.inlineCallbacks @defer.inlineCallbacks
def annotate_state_groups(self, event): def annotate_state_groups(self, event, state=None):
if state:
event.state_group = None
event.old_state_events = None
event.state_events = state
return
state_groups = yield self.store.get_state_groups( state_groups = yield self.store.get_state_groups(
event.prev_events event.prev_events
) )
@ -177,7 +183,7 @@ class StateHandler(object):
new_powers_deferreds = [] new_powers_deferreds = []
for e in curr_events: for e in curr_events:
new_powers_deferreds.append( new_powers_deferreds.append(
self.store.get_power_level(e.context, e.user_id) self.store.get_power_level(e.room_id, e.user_id)
) )
new_powers = yield defer.gatherResults( new_powers = yield defer.gatherResults(