Merge branch 'master' of github.com:matrix-org/synapse into sql_refactor

Conflicts:
	synapse/storage/_base.py
This commit is contained in:
Erik Johnston 2014-08-14 10:01:04 +01:00
commit 10294b6082
35 changed files with 412 additions and 188 deletions

View file

@ -14,6 +14,7 @@
# limitations under the License.
"""This module contains classes for authenticating the user."""
from twisted.internet import defer
from synapse.api.constants import Membership

View file

@ -75,8 +75,8 @@ class FederationEventHandler(object):
@log_function
@defer.inlineCallbacks
def backfill(self, room_id, limit):
# TODO: Work out which destinations to ask for pagination
# self.replication_layer.paginate(dest, room_id, limit)
# TODO: Work out which destinations to ask for backfill
# self.replication_layer.backfill(dest, room_id, limit)
pass
@log_function

View file

@ -114,14 +114,14 @@ class PduActions(object):
@defer.inlineCallbacks
@log_function
def paginate(self, context, pdu_list, limit):
def backfill(self, context, pdu_list, limit):
""" For a given list of PDU id and origins return the proceeding
`limit` `Pdu`s in the given `context`.
Returns:
Deferred: Results in a list of `Pdu`s.
"""
results = yield self.store.get_pagination(
results = yield self.store.get_backfill(
context, pdu_list, limit
)
@ -131,7 +131,7 @@ class PduActions(object):
def is_new(self, pdu):
""" When we receive a `Pdu` from a remote home server, we want to
figure out whether it is `new`, i.e. it is not some historic PDU that
we haven't seen simply because we haven't paginated back that far.
we haven't seen simply because we haven't backfilled back that far.
Returns:
Deferred: Results in a `bool`

View file

@ -118,7 +118,7 @@ class ReplicationLayer(object):
*Note:* The home server should always call `send_pdu` even if it knows
that it does not need to be replicated to other home servers. This is
in case e.g. someone else joins via a remote home server and then
paginates.
backfills.
TODO: Figure out when we should actually resolve the deferred.
@ -179,13 +179,13 @@ class ReplicationLayer(object):
@defer.inlineCallbacks
@log_function
def paginate(self, dest, context, limit):
def backfill(self, dest, context, limit):
"""Requests some more historic PDUs for the given context from the
given destination server.
Args:
dest (str): The remote home server to ask.
context (str): The context to paginate back on.
context (str): The context to backfill.
limit (int): The maximum number of PDUs to return.
Returns:
@ -193,16 +193,16 @@ class ReplicationLayer(object):
"""
extremities = yield self.store.get_oldest_pdus_in_context(context)
logger.debug("paginate extrem=%s", extremities)
logger.debug("backfill extrem=%s", extremities)
# If there are no extremeties then we've (probably) reached the start.
if not extremities:
return
transaction_data = yield self.transport_layer.paginate(
transaction_data = yield self.transport_layer.backfill(
dest, context, extremities, limit)
logger.debug("paginate transaction_data=%s", repr(transaction_data))
logger.debug("backfill transaction_data=%s", repr(transaction_data))
transaction = Transaction(**transaction_data)
@ -281,9 +281,9 @@ class ReplicationLayer(object):
@defer.inlineCallbacks
@log_function
def on_paginate_request(self, context, versions, limit):
def on_backfill_request(self, context, versions, limit):
pdus = yield self.pdu_actions.paginate(context, versions, limit)
pdus = yield self.pdu_actions.backfill(context, versions, limit)
defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
@ -427,7 +427,7 @@ class ReplicationLayer(object):
# Get missing pdus if necessary.
is_new = yield self.pdu_actions.is_new(pdu)
if is_new and not pdu.outlier:
# We only paginate backwards to the min depth.
# We only backfill backwards to the min depth.
min_depth = yield self.store.get_min_depth_for_context(pdu.context)
if min_depth and pdu.depth > min_depth:

View file

@ -112,7 +112,7 @@ class TransportLayer(object):
return self._do_request_for_transaction(destination, subpath)
@log_function
def paginate(self, dest, context, pdu_tuples, limit):
def backfill(self, dest, context, pdu_tuples, limit):
""" Requests `limit` previous PDUs in a given context before list of
PDUs.
@ -126,14 +126,14 @@ class TransportLayer(object):
Deferred: Results in a dict received from the remote homeserver.
"""
logger.debug(
"paginate dest=%s, context=%s, pdu_tuples=%s, limit=%s",
"backfill dest=%s, context=%s, pdu_tuples=%s, limit=%s",
dest, context, repr(pdu_tuples), str(limit)
)
if not pdu_tuples:
return
subpath = "/paginate/%s/" % context
subpath = "/backfill/%s/" % context
args = {"v": ["%s,%s" % (i, o) for i, o in pdu_tuples]}
args["limit"] = limit
@ -251,8 +251,8 @@ class TransportLayer(object):
self.server.register_path(
"GET",
re.compile("^" + PREFIX + "/paginate/([^/]*)/$"),
lambda request, context: self._on_paginate_request(
re.compile("^" + PREFIX + "/backfill/([^/]*)/$"),
lambda request, context: self._on_backfill_request(
context, request.args["v"],
request.args["limit"]
)
@ -352,7 +352,7 @@ class TransportLayer(object):
defer.returnValue(data)
@log_function
def _on_paginate_request(self, context, v_list, limits):
def _on_backfill_request(self, context, v_list, limits):
if not limits:
return defer.succeed(
(400, {"error": "Did not include limit param"})
@ -362,7 +362,7 @@ class TransportLayer(object):
versions = [v.split(",", 1) for v in v_list]
return self.request_handler.on_paginate_request(
return self.request_handler.on_backfill_request(
context, versions, limit)
@ -371,14 +371,14 @@ class TransportReceivedHandler(object):
"""
def on_incoming_transaction(self, transaction):
""" Called on PUT /send/<transaction_id>, or on response to a request
that we sent (e.g. a pagination request)
that we sent (e.g. a backfill request)
Args:
transaction (synapse.transaction.Transaction): The transaction that
was sent to us.
Returns:
twisted.internet.defer.Deferred: A deferred that get's fired when
twisted.internet.defer.Deferred: A deferred that gets fired when
the transaction has finished being processed.
The result should be a tuple in the form of
@ -438,14 +438,14 @@ class TransportRequestHandler(object):
def on_context_state_request(self, context):
""" Called on GET /state/<context>/
Get's hit when someone wants all the *current* state for a given
Gets hit when someone wants all the *current* state for a given
contexts.
Args:
context (str): The name of the context that we're interested in.
Returns:
twisted.internet.defer.Deferred: A deferred that get's fired when
twisted.internet.defer.Deferred: A deferred that gets fired when
the transaction has finished being processed.
The result should be a tuple in the form of
@ -457,20 +457,20 @@ class TransportRequestHandler(object):
"""
pass
def on_paginate_request(self, context, versions, limit):
""" Called on GET /paginate/<context>/?v=...&limit=...
def on_backfill_request(self, context, versions, limit):
""" Called on GET /backfill/<context>/?v=...&limit=...
Get's hit when we want to paginate backwards on a given context from
Gets hit when we want to backfill backwards on a given context from
the given point.
Args:
context (str): The context to paginate on
versions (list): A list of 2-tuple's representing where to paginate
context (str): The context to backfill
versions (list): A list of 2-tuples representing where to backfill
from, in the form `(pdu_id, origin)`
limit (int): How many pdus to return.
Returns:
Deferred: Resultsin a tuple in the form of
Deferred: Results in a tuple in the form of
`(response_code, respond_body)`, where `response_body` is a python
dict that will get serialized to JSON.

View file

@ -35,9 +35,11 @@ class DirectoryHandler(BaseHandler):
def __init__(self, hs):
super(DirectoryHandler, self).__init__(hs)
self.hs = hs
self.http_client = hs.get_http_client()
self.clock = hs.get_clock()
self.federation = hs.get_replication_layer()
self.federation.register_query_handler(
"directory", self.on_directory_query
)
@defer.inlineCallbacks
def create_association(self, room_alias, room_id, servers):
@ -58,9 +60,7 @@ class DirectoryHandler(BaseHandler):
)
@defer.inlineCallbacks
def get_association(self, room_alias, local_only=False):
# TODO(erikj): Do auth
def get_association(self, room_alias):
room_id = None
if room_alias.is_mine:
result = yield self.store.get_association_from_room_alias(
@ -70,22 +70,13 @@ class DirectoryHandler(BaseHandler):
if result:
room_id = result.room_id
servers = result.servers
elif not local_only:
path = "%s/ds/room/%s?local_only=1" % (
PREFIX,
urllib.quote(room_alias.to_string())
else:
result = yield self.federation.make_query(
destination=room_alias.domain,
query_type="directory",
args={"room_alias": room_alias.to_string()},
)
result = None
try:
result = yield self.http_client.get_json(
destination=room_alias.domain,
path=path,
)
except:
# TODO(erikj): Handle this better?
logger.exception("Failed to get remote room alias")
if result and "room_id" in result and "servers" in result:
room_id = result["room_id"]
servers = result["servers"]
@ -99,3 +90,20 @@ class DirectoryHandler(BaseHandler):
"servers": servers,
})
return
@defer.inlineCallbacks
def on_directory_query(self, args):
room_alias = self.hs.parse_roomalias(args["room_alias"])
if not room_alias.is_mine:
raise SynapseError(
400, "Room Alias is not hosted on this Home Server"
)
result = yield self.store.get_association_from_room_alias(
room_alias
)
defer.returnValue({
"room_id": result.room_id,
"servers": result.servers,
})

View file

@ -56,6 +56,8 @@ class PresenceHandler(BaseHandler):
self.homeserver = hs
self.clock = hs.get_clock()
distributor = hs.get_distributor()
distributor.observe("registered_user", self.registered_user)
@ -168,14 +170,15 @@ class PresenceHandler(BaseHandler):
state = yield self.store.get_presence_state(
target_user.localpart
)
defer.returnValue(state)
else:
raise SynapseError(404, "Presence information not visible")
else:
# TODO(paul): Have remote server send us permissions set
defer.returnValue(
self._get_or_offline_usercache(target_user).get_state()
)
state = self._get_or_offline_usercache(target_user).get_state()
if "mtime" in state:
state["mtime_age"] = self.clock.time_msec() - state.pop("mtime")
defer.returnValue(state)
@defer.inlineCallbacks
def set_state(self, target_user, auth_user, state):
@ -209,6 +212,8 @@ class PresenceHandler(BaseHandler):
),
])
state["mtime"] = self.clock.time_msec()
now_online = state["state"] != PresenceState.OFFLINE
was_polling = target_user in self._user_cachemap
@ -361,6 +366,8 @@ class PresenceHandler(BaseHandler):
observed_user = self.hs.parse_userid(p.pop("observed_user_id"))
p["observed_user"] = observed_user
p.update(self._get_or_offline_usercache(observed_user).get_state())
if "mtime" in p:
p["mtime_age"] = self.clock.time_msec() - p.pop("mtime")
defer.returnValue(presence)
@ -546,10 +553,15 @@ class PresenceHandler(BaseHandler):
def _push_presence_remote(self, user, destination, state=None):
if state is None:
state = yield self.store.get_presence_state(user.localpart)
yield self.distributor.fire(
"collect_presencelike_data", user, state
)
if "mtime" in state:
state = dict(state)
state["mtime_age"] = self.clock.time_msec() - state.pop("mtime")
yield self.federation.send_edu(
destination=destination,
edu_type="m.presence",
@ -585,6 +597,9 @@ class PresenceHandler(BaseHandler):
state = dict(push)
del state["user_id"]
if "mtime_age" in state:
state["mtime"] = self.clock.time_msec() - state.pop("mtime_age")
statuscache = self._get_or_make_usercache(user)
self._user_cachemap_latest_serial += 1
@ -631,9 +646,14 @@ class PresenceHandler(BaseHandler):
def push_update_to_clients(self, observer_user, observed_user,
statuscache):
state = statuscache.make_event(user=observed_user, clock=self.clock)
self.notifier.on_new_user_event(
observer_user.to_string(),
event_data=statuscache.make_event(user=observed_user),
event_data=statuscache.make_event(
user=observed_user,
clock=self.clock
),
stream_type=PresenceStreamData,
store_id=statuscache.serial
)
@ -652,8 +672,10 @@ class PresenceStreamData(StreamData):
if from_key < cachemap[k].serial <= to_key]
if updates:
clock = self.presence.clock
latest_serial = max([x[1].serial for x in updates])
data = [x[1].make_event(user=x[0]) for x in updates]
data = [x[1].make_event(user=x[0], clock=clock) for x in updates]
return ((data, latest_serial))
else:
return (([], self.presence._user_cachemap_latest_serial))
@ -674,6 +696,8 @@ class UserPresenceCache(object):
self.serial = None
def update(self, state, serial):
assert("mtime_age" not in state)
self.state.update(state)
# Delete keys that are now 'None'
for k in self.state.keys():
@ -691,8 +715,11 @@ class UserPresenceCache(object):
# clone it so caller can't break our cache
return dict(self.state)
def make_event(self, user):
def make_event(self, user, clock):
content = self.get_state()
content["user_id"] = user.to_string()
if "mtime" in content:
content["mtime_age"] = clock.time_msec() - content.pop("mtime")
return {"type": "m.presence", "content": content}

View file

@ -32,7 +32,7 @@ import urllib
logger = logging.getLogger(__name__)
# FIXME: SURELY these should be killed?!
_destination_mappings = {
"red": "localhost:8080",
"blue": "localhost:8081",
@ -147,7 +147,7 @@ class TwistedHttpClient(HttpClient):
destination.encode("ascii"),
"GET",
path.encode("ascii"),
query_bytes
query_bytes=query_bytes
)
body = yield readBody(response)

View file

@ -16,7 +16,6 @@
from twisted.internet import defer
from synapse.types import RoomAlias, RoomID
from base import RestServlet, client_path_pattern
import json
@ -36,17 +35,10 @@ class ClientDirectoryServer(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, room_alias):
# TODO(erikj): Handle request
local_only = "local_only" in request.args
room_alias = urllib.unquote(room_alias)
room_alias_obj = RoomAlias.from_string(room_alias, self.hs)
room_alias = self.hs.parse_roomalias(urllib.unquote(room_alias))
dir_handler = self.handlers.directory_handler
res = yield dir_handler.get_association(
room_alias_obj,
local_only=local_only
)
res = yield dir_handler.get_association(room_alias)
defer.returnValue((200, res))
@ -57,10 +49,9 @@ class ClientDirectoryServer(RestServlet):
logger.debug("Got content: %s", content)
room_alias = urllib.unquote(room_alias)
room_alias_obj = RoomAlias.from_string(room_alias, self.hs)
room_alias = self.hs.parse_roomalias(urllib.unquote(room_alias))
logger.debug("Got room name: %s", room_alias_obj.to_string())
logger.debug("Got room name: %s", room_alias.to_string())
room_id = content["room_id"]
servers = content["servers"]
@ -75,7 +66,7 @@ class ClientDirectoryServer(RestServlet):
try:
yield dir_handler.create_association(
room_alias_obj, room_id, servers
room_alias, room_id, servers
)
except:
logger.exception("Failed to create association")

View file

@ -22,7 +22,6 @@ from synapse.api.events.room import (RoomTopicEvent, MessageEvent,
RoomMemberEvent, FeedbackEvent)
from synapse.api.constants import Feedback, Membership
from synapse.api.streams import PaginationConfig
from synapse.types import RoomAlias
import json
import logging
@ -150,10 +149,7 @@ class JoinRoomAliasServlet(RestServlet):
logger.debug("room_alias: %s", room_alias)
room_alias = RoomAlias.from_string(
urllib.unquote(room_alias),
self.hs
)
room_alias = self.hs.parse_roomalias(urllib.unquote(room_alias))
handler = self.handlers.room_member_handler
ret_dict = yield handler.join_room_alias(user, room_alias)

View file

@ -28,7 +28,7 @@ from synapse.handlers import Handlers
from synapse.rest import RestServletFactory
from synapse.state import StateHandler
from synapse.storage import DataStore
from synapse.types import UserID
from synapse.types import UserID, RoomAlias
from synapse.util import Clock
from synapse.util.distributor import Distributor
from synapse.util.lockutils import LockManager
@ -120,6 +120,11 @@ class BaseHomeServer(object):
object."""
return UserID.from_string(s, hs=self)
def parse_roomalias(self, s):
"""Parse the string given by 's' as a Room Alias and return a RoomAlias
object."""
return RoomAlias.from_string(s, hs=self)
# Build magic accessors for every dependency
for depname in BaseHomeServer.DEPENDENCIES:
BaseHomeServer._make_dependency_method(depname)

View file

@ -44,7 +44,6 @@ class DataStore(RoomDataStore, RoomMemberStore, MessageStore, RoomStore,
def __init__(self, hs):
super(DataStore, self).__init__(hs)
self.event_factory = hs.get_event_factory()
self.hs = hs
@defer.inlineCallbacks
def persist_event(self, event):

View file

@ -28,8 +28,10 @@ logger = logging.getLogger(__name__)
class SQLBaseStore(object):
def __init__(self, hs):
self.hs = hs
self._db_pool = hs.get_db_pool()
self.event_factory = hs.get_event_factory()
self._clock = hs.get_clock()
def cursor_to_dict(self, cursor):
"""Converts a SQL cursor into an list of dicts.

View file

@ -168,7 +168,7 @@ class PduStore(SQLBaseStore):
return self._get_pdu_tuples(txn, txn.fetchall())
def get_pagination(self, context, pdu_list, limit):
def get_backfill(self, context, pdu_list, limit):
"""Get a list of Pdus for a given topic that occured before (and
including) the pdus in pdu_list. Return a list of max size `limit`.
@ -182,12 +182,12 @@ class PduStore(SQLBaseStore):
list: A list of PduTuples
"""
return self._db_pool.runInteraction(
self._get_paginate, context, pdu_list, limit
self._get_backfill, context, pdu_list, limit
)
def _get_paginate(self, txn, context, pdu_list, limit):
def _get_backfill(self, txn, context, pdu_list, limit):
logger.debug(
"paginate: %s, %s, %s",
"backfill: %s, %s, %s",
context, repr(pdu_list), limit
)
@ -213,7 +213,7 @@ class PduStore(SQLBaseStore):
new_front = []
for pdu_id, origin in front:
logger.debug(
"_paginate_interaction: i=%s, o=%s",
"_backfill_interaction: i=%s, o=%s",
pdu_id, origin
)
@ -224,7 +224,7 @@ class PduStore(SQLBaseStore):
for row in txn.fetchall():
logger.debug(
"_paginate_interaction: got i=%s, o=%s",
"_backfill_interaction: got i=%s, o=%s",
*row
)
new_front.append(row)
@ -262,7 +262,7 @@ class PduStore(SQLBaseStore):
def update_min_depth_for_context(self, context, depth):
"""Update the minimum `depth` of the given context, which is the line
where we stop paginating backwards on.
on which we stop backfilling backwards.
Args:
context (str)
@ -320,9 +320,9 @@ class PduStore(SQLBaseStore):
return [(row[0], row[1], row[2]) for row in results]
def get_oldest_pdus_in_context(self, context):
"""Get a list of Pdus that we paginated beyond yet (and haven't seen).
This list is used when we want to paginate backwards and is the list we
send to the remote server.
"""Get a list of Pdus that we haven't backfilled beyond yet (and haven't
seen). This list is used when we want to backfill backwards and is the
list we send to the remote server.
Args:
txn

View file

@ -35,7 +35,7 @@ class PresenceStore(SQLBaseStore):
return self._simple_select_one(
table="presence",
keyvalues={"user_id": user_localpart},
retcols=["state", "status_msg"],
retcols=["state", "status_msg", "mtime"],
)
def set_presence_state(self, user_localpart, new_state):
@ -43,7 +43,8 @@ class PresenceStore(SQLBaseStore):
table="presence",
keyvalues={"user_id": user_localpart},
updatevalues={"state": new_state["state"],
"status_msg": new_state["status_msg"]},
"status_msg": new_state["status_msg"],
"mtime": self._clock.time_msec()},
retcols=["state"],
)

View file

@ -16,6 +16,7 @@ CREATE TABLE IF NOT EXISTS presence(
user_id INTEGER NOT NULL,
state INTEGER,
status_msg TEXT,
mtime INTEGER, -- miliseconds since last state change
FOREIGN KEY(user_id) REFERENCES users(id)
);