Reference Matrix Home Server

This commit is contained in:
matrix.org 2014-08-12 15:10:52 +01:00
commit 4f475c7697
217 changed files with 48447 additions and 0 deletions

View file

@ -0,0 +1,29 @@
# -*- coding: utf-8 -*-
# Copyright 2014 matrix.org
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" This package includes all the federation specific logic.
"""
from .replication import ReplicationLayer
from .transport import TransportLayer
def initialize_http_replication(homeserver):
transport = TransportLayer(
homeserver.hostname,
server=homeserver.get_http_server(),
client=homeserver.get_http_client()
)
return ReplicationLayer(homeserver, transport)

View file

@ -0,0 +1,148 @@
# -*- coding: utf-8 -*-
# Copyright 2014 matrix.org
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from .pdu_codec import PduCodec
from synapse.api.errors import AuthError
from synapse.util.logutils import log_function
import logging
logger = logging.getLogger(__name__)
class FederationEventHandler(object):
""" Responsible for:
a) handling received Pdus before handing them on as Events to the rest
of the home server (including auth and state conflict resoultion)
b) converting events that were produced by local clients that may need
to be sent to remote home servers.
"""
def __init__(self, hs):
self.store = hs.get_datastore()
self.replication_layer = hs.get_replication_layer()
self.state_handler = hs.get_state_handler()
# self.auth_handler = gs.get_auth_handler()
self.event_handler = hs.get_handlers().federation_handler
self.server_name = hs.hostname
self.lock_manager = hs.get_room_lock_manager()
self.replication_layer.set_handler(self)
self.pdu_codec = PduCodec(hs)
@log_function
@defer.inlineCallbacks
def handle_new_event(self, event):
""" Takes in an event from the client to server side, that has already
been authed and handled by the state module, and sends it to any
remote home servers that may be interested.
Args:
event
Returns:
Deferred: Resolved when it has successfully been queued for
processing.
"""
yield self._fill_out_prev_events(event)
pdu = self.pdu_codec.pdu_from_event(event)
if not hasattr(pdu, "destinations") or not pdu.destinations:
pdu.destinations = []
yield self.replication_layer.send_pdu(pdu)
@log_function
@defer.inlineCallbacks
def backfill(self, room_id, limit):
# TODO: Work out which destinations to ask for pagination
# self.replication_layer.paginate(dest, room_id, limit)
pass
@log_function
def get_state_for_room(self, destination, room_id):
return self.replication_layer.get_state_for_context(
destination, room_id
)
@log_function
@defer.inlineCallbacks
def on_receive_pdu(self, pdu):
""" Called by the ReplicationLayer when we have a new pdu. We need to
do auth checks and put it throught the StateHandler.
"""
event = self.pdu_codec.event_from_pdu(pdu)
try:
with (yield self.lock_manager.lock(pdu.context)):
if event.is_state:
is_new_state = yield self.state_handler.handle_new_state(
pdu
)
if not is_new_state:
return
else:
is_new_state = False
yield self.event_handler.on_receive(event, is_new_state)
except AuthError:
# TODO: Implement something in federation that allows us to
# respond to PDU.
raise
return
@defer.inlineCallbacks
def _on_new_state(self, pdu, new_state_event):
# TODO: Do any store stuff here. Notifiy C2S about this new
# state.
yield self.store.update_current_state(
pdu_id=pdu.pdu_id,
origin=pdu.origin,
context=pdu.context,
pdu_type=pdu.pdu_type,
state_key=pdu.state_key
)
yield self.event_handler.on_receive(new_state_event)
@defer.inlineCallbacks
def _fill_out_prev_events(self, event):
if hasattr(event, "prev_events"):
return
results = yield self.store.get_latest_pdus_in_context(
event.room_id
)
es = [
"%s@%s" % (p_id, origin) for p_id, origin, _ in results
]
event.prev_events = [e for e in es if e != event.event_id]
if results:
event.depth = max([int(v) for _, _, v in results]) + 1
else:
event.depth = 0

View file

@ -0,0 +1,101 @@
# -*- coding: utf-8 -*-
# Copyright 2014 matrix.org
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .units import Pdu
import copy
def decode_event_id(event_id, server_name):
parts = event_id.split("@")
if len(parts) < 2:
return (event_id, server_name)
else:
return (parts[0], "".join(parts[1:]))
def encode_event_id(pdu_id, origin):
return "%s@%s" % (pdu_id, origin)
class PduCodec(object):
def __init__(self, hs):
self.server_name = hs.hostname
self.event_factory = hs.get_event_factory()
self.clock = hs.get_clock()
def event_from_pdu(self, pdu):
kwargs = {}
kwargs["event_id"] = encode_event_id(pdu.pdu_id, pdu.origin)
kwargs["room_id"] = pdu.context
kwargs["etype"] = pdu.pdu_type
kwargs["prev_events"] = [
encode_event_id(p[0], p[1]) for p in pdu.prev_pdus
]
if hasattr(pdu, "prev_state_id") and hasattr(pdu, "prev_state_origin"):
kwargs["prev_state"] = encode_event_id(
pdu.prev_state_id, pdu.prev_state_origin
)
kwargs.update({
k: v
for k, v in pdu.get_full_dict().items()
if k not in [
"pdu_id",
"context",
"pdu_type",
"prev_pdus",
"prev_state_id",
"prev_state_origin",
]
})
return self.event_factory.create_event(**kwargs)
def pdu_from_event(self, event):
d = event.get_full_dict()
d["pdu_id"], d["origin"] = decode_event_id(
event.event_id, self.server_name
)
d["context"] = event.room_id
d["pdu_type"] = event.type
if hasattr(event, "prev_events"):
d["prev_pdus"] = [
decode_event_id(e, self.server_name)
for e in event.prev_events
]
if hasattr(event, "prev_state"):
d["prev_state_id"], d["prev_state_origin"] = (
decode_event_id(event.prev_state, self.server_name)
)
if hasattr(event, "state_key"):
d["is_state"] = True
kwargs = copy.deepcopy(event.unrecognized_keys)
kwargs.update({
k: v for k, v in d.items()
if k not in ["event_id", "room_id", "type", "prev_events"]
})
if "ts" not in kwargs:
kwargs["ts"] = int(self.clock.time_msec())
return Pdu(**kwargs)

View file

@ -0,0 +1,240 @@
# -*- coding: utf-8 -*-
# Copyright 2014 matrix.org
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" This module contains all the persistence actions done by the federation
package.
These actions are mostly only used by the :py:mod:`.replication` module.
"""
from twisted.internet import defer
from .units import Pdu
from synapse.util.logutils import log_function
import copy
import json
import logging
logger = logging.getLogger(__name__)
class PduActions(object):
""" Defines persistence actions that relate to handling PDUs.
"""
def __init__(self, datastore):
self.store = datastore
@log_function
def persist_received(self, pdu):
""" Persists the given `Pdu` that was received from a remote home
server.
Returns:
Deferred
"""
return self._persist(pdu)
@defer.inlineCallbacks
@log_function
def persist_outgoing(self, pdu):
""" Persists the given `Pdu` that this home server created.
Returns:
Deferred
"""
ret = yield self._persist(pdu)
defer.returnValue(ret)
@log_function
def mark_as_processed(self, pdu):
""" Persist the fact that we have fully processed the given `Pdu`
Returns:
Deferred
"""
return self.store.mark_pdu_as_processed(pdu.pdu_id, pdu.origin)
@defer.inlineCallbacks
@log_function
def populate_previous_pdus(self, pdu):
""" Given an outgoing `Pdu` fill out its `prev_ids` key with the `Pdu`s
that we have received.
Returns:
Deferred
"""
results = yield self.store.get_latest_pdus_in_context(pdu.context)
pdu.prev_pdus = [(p_id, origin) for p_id, origin, _ in results]
vs = [int(v) for _, _, v in results]
if vs:
pdu.depth = max(vs) + 1
else:
pdu.depth = 0
@defer.inlineCallbacks
@log_function
def after_transaction(self, transaction_id, destination, origin):
""" Returns all `Pdu`s that we sent to the given remote home server
after a given transaction id.
Returns:
Deferred: Results in a list of `Pdu`s
"""
results = yield self.store.get_pdus_after_transaction(
transaction_id,
destination
)
defer.returnValue([Pdu.from_pdu_tuple(p) for p in results])
@defer.inlineCallbacks
@log_function
def get_all_pdus_from_context(self, context):
results = yield self.store.get_all_pdus_from_context(context)
defer.returnValue([Pdu.from_pdu_tuple(p) for p in results])
@defer.inlineCallbacks
@log_function
def paginate(self, context, pdu_list, limit):
""" For a given list of PDU id and origins return the proceeding
`limit` `Pdu`s in the given `context`.
Returns:
Deferred: Results in a list of `Pdu`s.
"""
results = yield self.store.get_pagination(
context, pdu_list, limit
)
defer.returnValue([Pdu.from_pdu_tuple(p) for p in results])
@log_function
def is_new(self, pdu):
""" When we receive a `Pdu` from a remote home server, we want to
figure out whether it is `new`, i.e. it is not some historic PDU that
we haven't seen simply because we haven't paginated back that far.
Returns:
Deferred: Results in a `bool`
"""
return self.store.is_pdu_new(
pdu_id=pdu.pdu_id,
origin=pdu.origin,
context=pdu.context,
depth=pdu.depth
)
@defer.inlineCallbacks
@log_function
def _persist(self, pdu):
kwargs = copy.copy(pdu.__dict__)
unrec_keys = copy.copy(pdu.unrecognized_keys)
del kwargs["content"]
kwargs["content_json"] = json.dumps(pdu.content)
kwargs["unrecognized_keys"] = json.dumps(unrec_keys)
logger.debug("Persisting: %s", repr(kwargs))
if pdu.is_state:
ret = yield self.store.persist_state(**kwargs)
else:
ret = yield self.store.persist_pdu(**kwargs)
yield self.store.update_min_depth_for_context(
pdu.context, pdu.depth
)
defer.returnValue(ret)
class TransactionActions(object):
""" Defines persistence actions that relate to handling Transactions.
"""
def __init__(self, datastore):
self.store = datastore
@log_function
def have_responded(self, transaction):
""" Have we already responded to a transaction with the same id and
origin?
Returns:
Deferred: Results in `None` if we have not previously responded to
this transaction or a 2-tuple of `(int, dict)` representing the
response code and response body.
"""
if not transaction.transaction_id:
raise RuntimeError("Cannot persist a transaction with no "
"transaction_id")
return self.store.get_received_txn_response(
transaction.transaction_id, transaction.origin
)
@log_function
def set_response(self, transaction, code, response):
""" Persist how we responded to a transaction.
Returns:
Deferred
"""
if not transaction.transaction_id:
raise RuntimeError("Cannot persist a transaction with no "
"transaction_id")
return self.store.set_received_txn_response(
transaction.transaction_id,
transaction.origin,
code,
json.dumps(response)
)
@defer.inlineCallbacks
@log_function
def prepare_to_send(self, transaction):
""" Persists the `Transaction` we are about to send and works out the
correct value for the `prev_ids` key.
Returns:
Deferred
"""
transaction.prev_ids = yield self.store.prep_send_transaction(
transaction.transaction_id,
transaction.destination,
transaction.ts,
[(p["pdu_id"], p["origin"]) for p in transaction.pdus]
)
@log_function
def delivered(self, transaction, response_code, response_dict):
""" Marks the given `Transaction` as having been successfully
delivered to the remote homeserver, and what the response was.
Returns:
Deferred
"""
return self.store.delivered_txn(
transaction.transaction_id,
transaction.destination,
response_code,
json.dumps(response_dict)
)

View file

@ -0,0 +1,582 @@
# -*- coding: utf-8 -*-
# Copyright 2014 matrix.org
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This layer is responsible for replicating with remote home servers using
a given transport.
"""
from twisted.internet import defer
from .units import Transaction, Pdu, Edu
from .persistence import PduActions, TransactionActions
from synapse.util.logutils import log_function
import logging
logger = logging.getLogger(__name__)
class ReplicationLayer(object):
"""This layer is responsible for replicating with remote home servers over
the given transport. I.e., does the sending and receiving of PDUs to
remote home servers.
The layer communicates with the rest of the server via a registered
ReplicationHandler.
In more detail, the layer:
* Receives incoming data and processes it into transactions and pdus.
* Fetches any PDUs it thinks it might have missed.
* Keeps the current state for contexts up to date by applying the
suitable conflict resolution.
* Sends outgoing pdus wrapped in transactions.
* Fills out the references to previous pdus/transactions appropriately
for outgoing data.
"""
def __init__(self, hs, transport_layer):
self.server_name = hs.hostname
self.transport_layer = transport_layer
self.transport_layer.register_received_handler(self)
self.transport_layer.register_request_handler(self)
self.store = hs.get_datastore()
self.pdu_actions = PduActions(self.store)
self.transaction_actions = TransactionActions(self.store)
self._transaction_queue = _TransactionQueue(
hs, self.transaction_actions, transport_layer
)
self.handler = None
self.edu_handlers = {}
self._order = 0
self._clock = hs.get_clock()
def set_handler(self, handler):
"""Sets the handler that the replication layer will use to communicate
receipt of new PDUs from other home servers. The required methods are
documented on :py:class:`.ReplicationHandler`.
"""
self.handler = handler
def register_edu_handler(self, edu_type, handler):
if edu_type in self.edu_handlers:
raise KeyError("Already have an EDU handler for %s" % (edu_type))
self.edu_handlers[edu_type] = handler
@defer.inlineCallbacks
@log_function
def send_pdu(self, pdu):
"""Informs the replication layer about a new PDU generated within the
home server that should be transmitted to others.
This will fill out various attributes on the PDU object, e.g. the
`prev_pdus` key.
*Note:* The home server should always call `send_pdu` even if it knows
that it does not need to be replicated to other home servers. This is
in case e.g. someone else joins via a remote home server and then
paginates.
TODO: Figure out when we should actually resolve the deferred.
Args:
pdu (Pdu): The new Pdu.
Returns:
Deferred: Completes when we have successfully processed the PDU
and replicated it to any interested remote home servers.
"""
order = self._order
self._order += 1
logger.debug("[%s] Persisting PDU", pdu.pdu_id)
#yield self.pdu_actions.populate_previous_pdus(pdu)
# Save *before* trying to send
yield self.pdu_actions.persist_outgoing(pdu)
logger.debug("[%s] Persisted PDU", pdu.pdu_id)
logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.pdu_id)
# TODO, add errback, etc.
self._transaction_queue.enqueue_pdu(pdu, order)
logger.debug("[%s] transaction_layer.enqueue_pdu... done", pdu.pdu_id)
@log_function
def send_edu(self, destination, edu_type, content):
edu = Edu(
origin=self.server_name,
destination=destination,
edu_type=edu_type,
content=content,
)
# TODO, add errback, etc.
self._transaction_queue.enqueue_edu(edu)
@defer.inlineCallbacks
@log_function
def paginate(self, dest, context, limit):
"""Requests some more historic PDUs for the given context from the
given destination server.
Args:
dest (str): The remote home server to ask.
context (str): The context to paginate back on.
limit (int): The maximum number of PDUs to return.
Returns:
Deferred: Results in the received PDUs.
"""
extremities = yield self.store.get_oldest_pdus_in_context(context)
logger.debug("paginate extrem=%s", extremities)
# If there are no extremeties then we've (probably) reached the start.
if not extremities:
return
transaction_data = yield self.transport_layer.paginate(
dest, context, extremities, limit)
logger.debug("paginate transaction_data=%s", repr(transaction_data))
transaction = Transaction(**transaction_data)
pdus = [Pdu(outlier=False, **p) for p in transaction.pdus]
for pdu in pdus:
yield self._handle_new_pdu(pdu)
defer.returnValue(pdus)
@defer.inlineCallbacks
@log_function
def get_pdu(self, destination, pdu_origin, pdu_id, outlier=False):
"""Requests the PDU with given origin and ID from the remote home
server.
This will persist the PDU locally upon receipt.
Args:
destination (str): Which home server to query
pdu_origin (str): The home server that originally sent the pdu.
pdu_id (str)
outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
it's from an arbitary point in the context as opposed to part
of the current block of PDUs. Defaults to `False`
Returns:
Deferred: Results in the requested PDU.
"""
transaction_data = yield self.transport_layer.get_pdu(
destination, pdu_origin, pdu_id)
transaction = Transaction(**transaction_data)
pdu_list = [Pdu(outlier=outlier, **p) for p in transaction.pdus]
pdu = None
if pdu_list:
pdu = pdu_list[0]
yield self._handle_new_pdu(pdu)
defer.returnValue(pdu)
@defer.inlineCallbacks
@log_function
def get_state_for_context(self, destination, context):
"""Requests all of the `current` state PDUs for a given context from
a remote home server.
Args:
destination (str): The remote homeserver to query for the state.
context (str): The context we're interested in.
Returns:
Deferred: Results in a list of PDUs.
"""
transaction_data = yield self.transport_layer.get_context_state(
destination, context)
transaction = Transaction(**transaction_data)
pdus = [Pdu(outlier=True, **p) for p in transaction.pdus]
for pdu in pdus:
yield self._handle_new_pdu(pdu)
defer.returnValue(pdus)
@defer.inlineCallbacks
@log_function
def on_context_pdus_request(self, context):
pdus = yield self.pdu_actions.get_all_pdus_from_context(
context
)
defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
@defer.inlineCallbacks
@log_function
def on_paginate_request(self, context, versions, limit):
pdus = yield self.pdu_actions.paginate(context, versions, limit)
defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
@defer.inlineCallbacks
@log_function
def on_incoming_transaction(self, transaction_data):
transaction = Transaction(**transaction_data)
logger.debug("[%s] Got transaction", transaction.transaction_id)
response = yield self.transaction_actions.have_responded(transaction)
if response:
logger.debug("[%s] We've already responed to this request",
transaction.transaction_id)
defer.returnValue(response)
return
logger.debug("[%s] Transacition is new", transaction.transaction_id)
pdu_list = [Pdu(**p) for p in transaction.pdus]
dl = []
for pdu in pdu_list:
dl.append(self._handle_new_pdu(pdu))
if hasattr(transaction, "edus"):
for edu in [Edu(**x) for x in transaction.edus]:
self.received_edu(edu.origin, edu.edu_type, edu.content)
results = yield defer.DeferredList(dl)
ret = []
for r in results:
if r[0]:
ret.append({})
else:
logger.exception(r[1])
ret.append({"error": str(r[1])})
logger.debug("Returning: %s", str(ret))
yield self.transaction_actions.set_response(
transaction,
200, response
)
defer.returnValue((200, response))
def received_edu(self, origin, edu_type, content):
if edu_type in self.edu_handlers:
self.edu_handlers[edu_type](origin, content)
else:
logger.warn("Received EDU of type %s with no handler", edu_type)
@defer.inlineCallbacks
@log_function
def on_context_state_request(self, context):
results = yield self.store.get_current_state_for_context(
context
)
logger.debug("Context returning %d results", len(results))
pdus = [Pdu.from_pdu_tuple(p) for p in results]
defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
@defer.inlineCallbacks
@log_function
def on_pdu_request(self, pdu_origin, pdu_id):
pdu = yield self._get_persisted_pdu(pdu_id, pdu_origin)
if pdu:
defer.returnValue(
(200, self._transaction_from_pdus([pdu]).get_dict())
)
else:
defer.returnValue((404, ""))
@defer.inlineCallbacks
@log_function
def on_pull_request(self, origin, versions):
transaction_id = max([int(v) for v in versions])
response = yield self.pdu_actions.after_transaction(
transaction_id,
origin,
self.server_name
)
if not response:
response = []
defer.returnValue(
(200, self._transaction_from_pdus(response).get_dict())
)
@defer.inlineCallbacks
@log_function
def _get_persisted_pdu(self, pdu_id, pdu_origin):
""" Get a PDU from the database with given origin and id.
Returns:
Deferred: Results in a `Pdu`.
"""
pdu_tuple = yield self.store.get_pdu(pdu_id, pdu_origin)
defer.returnValue(Pdu.from_pdu_tuple(pdu_tuple))
def _transaction_from_pdus(self, pdu_list):
"""Returns a new Transaction containing the given PDUs suitable for
transmission.
"""
return Transaction(
pdus=[p.get_dict() for p in pdu_list],
origin=self.server_name,
ts=int(self._clock.time_msec()),
destination=None,
)
@defer.inlineCallbacks
@log_function
def _handle_new_pdu(self, pdu):
# We reprocess pdus when we have seen them only as outliers
existing = yield self._get_persisted_pdu(pdu.pdu_id, pdu.origin)
if existing and (not existing.outlier or pdu.outlier):
logger.debug("Already seen pdu %s %s", pdu.pdu_id, pdu.origin)
defer.returnValue({})
return
# Get missing pdus if necessary.
is_new = yield self.pdu_actions.is_new(pdu)
if is_new and not pdu.outlier:
# We only paginate backwards to the min depth.
min_depth = yield self.store.get_min_depth_for_context(pdu.context)
if min_depth and pdu.depth > min_depth:
for pdu_id, origin in pdu.prev_pdus:
exists = yield self._get_persisted_pdu(pdu_id, origin)
if not exists:
logger.debug("Requesting pdu %s %s", pdu_id, origin)
try:
yield self.get_pdu(
pdu.origin,
pdu_id=pdu_id,
pdu_origin=origin
)
logger.debug("Processed pdu %s %s", pdu_id, origin)
except:
# TODO(erikj): Do some more intelligent retries.
logger.exception("Failed to get PDU")
# Persist the Pdu, but don't mark it as processed yet.
yield self.pdu_actions.persist_received(pdu)
ret = yield self.handler.on_receive_pdu(pdu)
yield self.pdu_actions.mark_as_processed(pdu)
defer.returnValue(ret)
def __str__(self):
return "<ReplicationLayer(%s)>" % self.server_name
class ReplicationHandler(object):
"""This defines the methods that the :py:class:`.ReplicationLayer` will
use to communicate with the rest of the home server.
"""
def on_receive_pdu(self, pdu):
raise NotImplementedError("on_receive_pdu")
class _TransactionQueue(object):
"""This class makes sure we only have one transaction in flight at
a time for a given destination.
It batches pending PDUs into single transactions.
"""
def __init__(self, hs, transaction_actions, transport_layer):
self.server_name = hs.hostname
self.transaction_actions = transaction_actions
self.transport_layer = transport_layer
self._clock = hs.get_clock()
# Is a mapping from destinations -> deferreds. Used to keep track
# of which destinations have transactions in flight and when they are
# done
self.pending_transactions = {}
# Is a mapping from destination -> list of
# tuple(pending pdus, deferred, order)
self.pending_pdus_by_dest = {}
# destination -> list of tuple(edu, deferred)
self.pending_edus_by_dest = {}
# HACK to get unique tx id
self._next_txn_id = int(self._clock.time_msec())
@defer.inlineCallbacks
@log_function
def enqueue_pdu(self, pdu, order):
# We loop through all destinations to see whether we already have
# a transaction in progress. If we do, stick it in the pending_pdus
# table and we'll get back to it later.
destinations = [
d for d in pdu.destinations
if d != self.server_name
]
logger.debug("Sending to: %s", str(destinations))
if not destinations:
return
deferreds = []
for destination in destinations:
deferred = defer.Deferred()
self.pending_pdus_by_dest.setdefault(destination, []).append(
(pdu, deferred, order)
)
self._attempt_new_transaction(destination)
deferreds.append(deferred)
yield defer.DeferredList(deferreds)
# NO inlineCallbacks
def enqueue_edu(self, edu):
destination = edu.destination
deferred = defer.Deferred()
self.pending_edus_by_dest.setdefault(destination, []).append(
(edu, deferred)
)
def eb(failure):
deferred.errback(failure)
self._attempt_new_transaction(destination).addErrback(eb)
return deferred
@defer.inlineCallbacks
@log_function
def _attempt_new_transaction(self, destination):
if destination in self.pending_transactions:
return
# list of (pending_pdu, deferred, order)
pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
pending_edus = self.pending_edus_by_dest.pop(destination, [])
if not pending_pdus and not pending_edus:
return
logger.debug("TX [%s] Attempting new transaction", destination)
# Sort based on the order field
pending_pdus.sort(key=lambda t: t[2])
pdus = [x[0] for x in pending_pdus]
edus = [x[0] for x in pending_edus]
deferreds = [x[1] for x in pending_pdus + pending_edus]
try:
self.pending_transactions[destination] = 1
logger.debug("TX [%s] Persisting transaction...", destination)
transaction = Transaction.create_new(
ts=self._clock.time_msec(),
transaction_id=self._next_txn_id,
origin=self.server_name,
destination=destination,
pdus=pdus,
edus=edus,
)
self._next_txn_id += 1
yield self.transaction_actions.prepare_to_send(transaction)
logger.debug("TX [%s] Persisted transaction", destination)
logger.debug("TX [%s] Sending transaction...", destination)
# Actually send the transaction
code, response = yield self.transport_layer.send_transaction(
transaction
)
logger.debug("TX [%s] Sent transaction", destination)
logger.debug("TX [%s] Marking as delivered...", destination)
yield self.transaction_actions.delivered(
transaction, code, response
)
logger.debug("TX [%s] Marked as delivered", destination)
logger.debug("TX [%s] Yielding to callbacks...", destination)
for deferred in deferreds:
if code == 200:
deferred.callback(None)
else:
deferred.errback(RuntimeError("Got status %d" % code))
# Ensures we don't continue until all callbacks on that
# deferred have fired
yield deferred
logger.debug("TX [%s] Yielded to callbacks", destination)
except Exception as e:
logger.error("TX Problem in _attempt_transaction")
# We capture this here as there as nothing actually listens
# for this finishing functions deferred.
logger.exception(e)
for deferred in deferreds:
deferred.errback(e)
yield deferred
finally:
# We want to be *very* sure we delete this after we stop processing
self.pending_transactions.pop(destination, None)
# Check to see if there is anything else to send.
self._attempt_new_transaction(destination)

View file

@ -0,0 +1,454 @@
# -*- coding: utf-8 -*-
# Copyright 2014 matrix.org
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The transport layer is responsible for both sending transactions to remote
home servers and receiving a variety of requests from other home servers.
Typically, this is done over HTTP (and all home servers are required to
support HTTP), however individual pairings of servers may decide to communicate
over a different (albeit still reliable) protocol.
"""
from twisted.internet import defer
from synapse.util.logutils import log_function
import logging
import json
import re
logger = logging.getLogger(__name__)
class TransportLayer(object):
"""This is a basic implementation of the transport layer that translates
transactions and other requests to/from HTTP.
Attributes:
server_name (str): Local home server host
server (synapse.http.server.HttpServer): the http server to
register listeners on
client (synapse.http.client.HttpClient): the http client used to
send requests
request_handler (TransportRequestHandler): The handler to fire when we
receive requests for data.
received_handler (TransportReceivedHandler): The handler to fire when
we receive data.
"""
def __init__(self, server_name, server, client):
"""
Args:
server_name (str): Local home server host
server (synapse.protocol.http.HttpServer): the http server to
register listeners on
client (synapse.protocol.http.HttpClient): the http client used to
send requests
"""
self.server_name = server_name
self.server = server
self.client = client
self.request_handler = None
self.received_handler = None
@log_function
def get_context_state(self, destination, context):
""" Requests all state for a given context (i.e. room) from the
given server.
Args:
destination (str): The host name of the remote home server we want
to get the state from.
context (str): The name of the context we want the state of
Returns:
Deferred: Results in a dict received from the remote homeserver.
"""
logger.debug("get_context_state dest=%s, context=%s",
destination, context)
path = "/state/%s/" % context
return self._do_request_for_transaction(destination, path)
@log_function
def get_pdu(self, destination, pdu_origin, pdu_id):
""" Requests the pdu with give id and origin from the given server.
Args:
destination (str): The host name of the remote home server we want
to get the state from.
pdu_origin (str): The home server which created the PDU.
pdu_id (str): The id of the PDU being requested.
Returns:
Deferred: Results in a dict received from the remote homeserver.
"""
logger.debug("get_pdu dest=%s, pdu_origin=%s, pdu_id=%s",
destination, pdu_origin, pdu_id)
path = "/pdu/%s/%s/" % (pdu_origin, pdu_id)
return self._do_request_for_transaction(destination, path)
@log_function
def paginate(self, dest, context, pdu_tuples, limit):
""" Requests `limit` previous PDUs in a given context before list of
PDUs.
Args:
dest (str)
context (str)
pdu_tuples (list)
limt (int)
Returns:
Deferred: Results in a dict received from the remote homeserver.
"""
logger.debug(
"paginate dest=%s, context=%s, pdu_tuples=%s, limit=%s",
dest, context, repr(pdu_tuples), str(limit)
)
if not pdu_tuples:
return
path = "/paginate/%s/" % context
args = {"v": ["%s,%s" % (i, o) for i, o in pdu_tuples]}
args["limit"] = limit
return self._do_request_for_transaction(
dest,
path,
args=args,
)
@defer.inlineCallbacks
@log_function
def send_transaction(self, transaction):
""" Sends the given Transaction to it's destination
Args:
transaction (Transaction)
Returns:
Deferred: Results of the deferred is a tuple in the form of
(response_code, response_body) where the response_body is a
python dict decoded from json
"""
logger.debug(
"send_data dest=%s, txid=%s",
transaction.destination, transaction.transaction_id
)
if transaction.destination == self.server_name:
raise RuntimeError("Transport layer cannot send to itself!")
data = transaction.get_dict()
code, response = yield self.client.put_json(
transaction.destination,
path="/send/%s/" % transaction.transaction_id,
data=data
)
logger.debug(
"send_data dest=%s, txid=%s, got response: %d",
transaction.destination, transaction.transaction_id, code
)
defer.returnValue((code, response))
@log_function
def register_received_handler(self, handler):
""" Register a handler that will be fired when we receive data.
Args:
handler (TransportReceivedHandler)
"""
self.received_handler = handler
# This is when someone is trying to send us a bunch of data.
self.server.register_path(
"PUT",
re.compile("^/send/([^/]*)/$"),
self._on_send_request
)
@log_function
def register_request_handler(self, handler):
""" Register a handler that will be fired when we get asked for data.
Args:
handler (TransportRequestHandler)
"""
self.request_handler = handler
# TODO(markjh): Namespace the federation URI paths
# This is for when someone asks us for everything since version X
self.server.register_path(
"GET",
re.compile("^/pull/$"),
lambda request: handler.on_pull_request(
request.args["origin"][0],
request.args["v"]
)
)
# This is when someone asks for a data item for a given server
# data_id pair.
self.server.register_path(
"GET",
re.compile("^/pdu/([^/]*)/([^/]*)/$"),
lambda request, pdu_origin, pdu_id: handler.on_pdu_request(
pdu_origin, pdu_id
)
)
# This is when someone asks for all data for a given context.
self.server.register_path(
"GET",
re.compile("^/state/([^/]*)/$"),
lambda request, context: handler.on_context_state_request(
context
)
)
self.server.register_path(
"GET",
re.compile("^/paginate/([^/]*)/$"),
lambda request, context: self._on_paginate_request(
context, request.args["v"],
request.args["limit"]
)
)
self.server.register_path(
"GET",
re.compile("^/context/([^/]*)/$"),
lambda request, context: handler.on_context_pdus_request(context)
)
@defer.inlineCallbacks
@log_function
def _on_send_request(self, request, transaction_id):
""" Called on PUT /send/<transaction_id>/
Args:
request (twisted.web.http.Request): The HTTP request.
transaction_id (str): The transaction_id associated with this
request. This is *not* None.
Returns:
Deferred: Results in a tuple of `(code, response)`, where
`response` is a python dict to be converted into JSON that is
used as the response body.
"""
# Parse the request
try:
data = request.content.read()
l = data[:20].encode("string_escape")
logger.debug("Got data: \"%s\"", l)
transaction_data = json.loads(data)
logger.debug(
"Decoded %s: %s",
transaction_id, str(transaction_data)
)
# We should ideally be getting this from the security layer.
# origin = body["origin"]
# Add some extra data to the transaction dict that isn't included
# in the request body.
transaction_data.update(
transaction_id=transaction_id,
destination=self.server_name
)
except Exception as e:
logger.exception(e)
defer.returnValue((400, {"error": "Invalid transaction"}))
return
code, response = yield self.received_handler.on_incoming_transaction(
transaction_data
)
defer.returnValue((code, response))
@defer.inlineCallbacks
@log_function
def _do_request_for_transaction(self, destination, path, args={}):
"""
Args:
destination (str)
path (str)
args (dict): This is parsed directly to the HttpClient.
Returns:
Deferred: Results in a dict.
"""
data = yield self.client.get_json(
destination,
path=path,
args=args,
)
# Add certain keys to the JSON, ready for decoding as a Transaction
data.update(
origin=destination,
destination=self.server_name,
transaction_id=None
)
defer.returnValue(data)
@log_function
def _on_paginate_request(self, context, v_list, limits):
if not limits:
return defer.succeed(
(400, {"error": "Did not include limit param"})
)
limit = int(limits[-1])
versions = [v.split(",", 1) for v in v_list]
return self.request_handler.on_paginate_request(
context, versions, limit)
class TransportReceivedHandler(object):
""" Callbacks used when we receive a transaction
"""
def on_incoming_transaction(self, transaction):
""" Called on PUT /send/<transaction_id>, or on response to a request
that we sent (e.g. a pagination request)
Args:
transaction (synapse.transaction.Transaction): The transaction that
was sent to us.
Returns:
twisted.internet.defer.Deferred: A deferred that get's fired when
the transaction has finished being processed.
The result should be a tuple in the form of
`(response_code, respond_body)`, where `response_body` is a python
dict that will get serialized to JSON.
On errors, the dict should have an `error` key with a brief message
of what went wrong.
"""
pass
class TransportRequestHandler(object):
""" Handlers used when someone want's data from us
"""
def on_pull_request(self, versions):
""" Called on GET /pull/?v=...
This is hit when a remote home server wants to get all data
after a given transaction. Mainly used when a home server comes back
online and wants to get everything it has missed.
Args:
versions (list): A list of transaction_ids that should be used to
determine what PDUs the remote side have not yet seen.
Returns:
Deferred: Resultsin a tuple in the form of
`(response_code, respond_body)`, where `response_body` is a python
dict that will get serialized to JSON.
On errors, the dict should have an `error` key with a brief message
of what went wrong.
"""
pass
def on_pdu_request(self, pdu_origin, pdu_id):
""" Called on GET /pdu/<pdu_origin>/<pdu_id>/
Someone wants a particular PDU. This PDU may or may not have originated
from us.
Args:
pdu_origin (str)
pdu_id (str)
Returns:
Deferred: Resultsin a tuple in the form of
`(response_code, respond_body)`, where `response_body` is a python
dict that will get serialized to JSON.
On errors, the dict should have an `error` key with a brief message
of what went wrong.
"""
pass
def on_context_state_request(self, context):
""" Called on GET /state/<context>/
Get's hit when someone wants all the *current* state for a given
contexts.
Args:
context (str): The name of the context that we're interested in.
Returns:
twisted.internet.defer.Deferred: A deferred that get's fired when
the transaction has finished being processed.
The result should be a tuple in the form of
`(response_code, respond_body)`, where `response_body` is a python
dict that will get serialized to JSON.
On errors, the dict should have an `error` key with a brief message
of what went wrong.
"""
pass
def on_paginate_request(self, context, versions, limit):
""" Called on GET /paginate/<context>/?v=...&limit=...
Get's hit when we want to paginate backwards on a given context from
the given point.
Args:
context (str): The context to paginate on
versions (list): A list of 2-tuple's representing where to paginate
from, in the form `(pdu_id, origin)`
limit (int): How many pdus to return.
Returns:
Deferred: Resultsin a tuple in the form of
`(response_code, respond_body)`, where `response_body` is a python
dict that will get serialized to JSON.
On errors, the dict should have an `error` key with a brief message
of what went wrong.
"""
pass

236
synapse/federation/units.py Normal file
View file

@ -0,0 +1,236 @@
# -*- coding: utf-8 -*-
# Copyright 2014 matrix.org
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Defines the JSON structure of the protocol units used by the server to
server protocol.
"""
from synapse.util.jsonobject import JsonEncodedObject
import logging
import json
import copy
logger = logging.getLogger(__name__)
class Pdu(JsonEncodedObject):
""" A Pdu represents a piece of data sent from a server and is associated
with a context.
A Pdu can be classified as "state". For a given context, we can efficiently
retrieve all state pdu's that haven't been clobbered. Clobbering is done
via a unique constraint on the tuple (context, pdu_type, state_key). A pdu
is a state pdu if `is_state` is True.
Example pdu::
{
"pdu_id": "78c",
"ts": 1404835423000,
"origin": "bar",
"prev_ids": [
["23b", "foo"],
["56a", "bar"],
],
"content": { ... },
}
"""
valid_keys = [
"pdu_id",
"context",
"origin",
"ts",
"pdu_type",
"destinations",
"transaction_id",
"prev_pdus",
"depth",
"content",
"outlier",
"is_state", # Below this are keys valid only for State Pdus.
"state_key",
"power_level",
"prev_state_id",
"prev_state_origin",
]
internal_keys = [
"destinations",
"transaction_id",
"outlier",
]
required_keys = [
"pdu_id",
"context",
"origin",
"ts",
"pdu_type",
"content",
]
# TODO: We need to make this properly load content rather than
# just leaving it as a dict. (OR DO WE?!)
def __init__(self, destinations=[], is_state=False, prev_pdus=[],
outlier=False, **kwargs):
if is_state:
for required_key in ["state_key"]:
if required_key not in kwargs:
raise RuntimeError("Key %s is required" % required_key)
super(Pdu, self).__init__(
destinations=destinations,
is_state=is_state,
prev_pdus=prev_pdus,
outlier=outlier,
**kwargs
)
@classmethod
def from_pdu_tuple(cls, pdu_tuple):
""" Converts a PduTuple to a Pdu
Args:
pdu_tuple (synapse.persistence.transactions.PduTuple): The tuple to
convert
Returns:
Pdu
"""
if pdu_tuple:
d = copy.copy(pdu_tuple.pdu_entry._asdict())
d["content"] = json.loads(d["content_json"])
del d["content_json"]
args = {f: d[f] for f in cls.valid_keys if f in d}
if "unrecognized_keys" in d and d["unrecognized_keys"]:
args.update(json.loads(d["unrecognized_keys"]))
return Pdu(
prev_pdus=pdu_tuple.prev_pdu_list,
**args
)
else:
return None
def __str__(self):
return "(%s, %s)" % (self.__class__.__name__, repr(self.__dict__))
def __repr__(self):
return "<%s, %s>" % (self.__class__.__name__, repr(self.__dict__))
class Edu(JsonEncodedObject):
""" An Edu represents a piece of data sent from one homeserver to another.
In comparison to Pdus, Edus are not persisted for a long time on disk, are
not meaningful beyond a given pair of homeservers, and don't have an
internal ID or previous references graph.
"""
valid_keys = [
"origin",
"destination",
"edu_type",
"content",
]
required_keys = [
"origin",
"destination",
"edu_type",
]
class Transaction(JsonEncodedObject):
""" A transaction is a list of Pdus and Edus to be sent to a remote home
server with some extra metadata.
Example transaction::
{
"origin": "foo",
"prev_ids": ["abc", "def"],
"pdus": [
...
],
}
"""
valid_keys = [
"transaction_id",
"origin",
"destination",
"ts",
"previous_ids",
"pdus",
"edus",
]
internal_keys = [
"transaction_id",
"destination",
]
required_keys = [
"transaction_id",
"origin",
"destination",
"ts",
"pdus",
]
def __init__(self, transaction_id=None, pdus=[], **kwargs):
""" If we include a list of pdus then we decode then as PDU's
automatically.
"""
# If there's no EDUs then remove the arg
if "edus" in kwargs and not kwargs["edus"]:
del kwargs["edus"]
super(Transaction, self).__init__(
transaction_id=transaction_id,
pdus=pdus,
**kwargs
)
@staticmethod
def create_new(pdus, **kwargs):
""" Used to create a new transaction. Will auto fill out
transaction_id and ts keys.
"""
if "ts" not in kwargs:
raise KeyError("Require 'ts' to construct a Transaction")
if "transaction_id" not in kwargs:
raise KeyError(
"Require 'transaction_id' to construct a Transaction"
)
for p in pdus:
p.transaction_id = kwargs["transaction_id"]
kwargs["pdus"] = [p.get_dict() for p in pdus]
return Transaction(**kwargs)