Merge branch 'develop' of github.com:matrix-org/synapse into erikj/backfill_fixes

This commit is contained in:
Erik Johnston 2015-05-22 16:10:42 +01:00
commit 74b7de83ec
26 changed files with 259 additions and 144 deletions

View File

@ -1,3 +1,9 @@
Changes in synapse v0.9.0-r5 (2015-05-21)
=========================================
* Add more database caches to reduce amount of work done for each pusher. This
radically reduces CPU usage when multiple pushers are set up in the same room.
Changes in synapse v0.9.0 (2015-05-07) Changes in synapse v0.9.0 (2015-05-07)
====================================== ======================================

View File

@ -33,9 +33,10 @@ def request_registration(user, password, server_location, shared_secret):
).hexdigest() ).hexdigest()
data = { data = {
"username": user, "user": user,
"password": password, "password": password,
"mac": mac, "mac": mac,
"type": "org.matrix.login.shared_secret",
} }
server_location = server_location.rstrip("/") server_location = server_location.rstrip("/")
@ -43,7 +44,7 @@ def request_registration(user, password, server_location, shared_secret):
print "Sending registration request..." print "Sending registration request..."
req = urllib2.Request( req = urllib2.Request(
"%s/_matrix/client/v2_alpha/register" % (server_location,), "%s/_matrix/client/api/v1/register" % (server_location,),
data=json.dumps(data), data=json.dumps(data),
headers={'Content-Type': 'application/json'} headers={'Content-Type': 'application/json'}
) )

View File

@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server. """ This is a reference implementation of a Matrix home server.
""" """
__version__ = "0.9.0-r4" __version__ = "0.9.0-r5"

View File

@ -277,9 +277,12 @@ class SynapseHomeServer(HomeServer):
config, config,
metrics_resource, metrics_resource,
), ),
interface="127.0.0.1", interface=config.metrics_bind_host,
)
logger.info(
"Metrics now running on %s port %d",
config.metrics_bind_host, config.metrics_port,
) )
logger.info("Metrics now running on 127.0.0.1 port %d", config.metrics_port)
def run_startup_checks(self, db_conn, database_engine): def run_startup_checks(self, db_conn, database_engine):
all_users_native = are_all_users_on_domain( all_users_native = are_all_users_on_domain(

View File

@ -148,8 +148,8 @@ class ApplicationService(object):
and self.is_interested_in_user(event.state_key)): and self.is_interested_in_user(event.state_key)):
return True return True
# check joined member events # check joined member events
for member in member_list: for user_id in member_list:
if self.is_interested_in_user(member.state_key): if self.is_interested_in_user(user_id):
return True return True
return False return False
@ -173,7 +173,7 @@ class ApplicationService(object):
restrict_to(str): The namespace to restrict regex tests to. restrict_to(str): The namespace to restrict regex tests to.
aliases_for_event(list): A list of all the known room aliases for aliases_for_event(list): A list of all the known room aliases for
this event. this event.
member_list(list): A list of all joined room members in this room. member_list(list): A list of all joined user_ids in this room.
Returns: Returns:
bool: True if this service would like to know about this event. bool: True if this service would like to know about this event.
""" """

View File

@ -20,6 +20,7 @@ class MetricsConfig(Config):
def read_config(self, config): def read_config(self, config):
self.enable_metrics = config["enable_metrics"] self.enable_metrics = config["enable_metrics"]
self.metrics_port = config.get("metrics_port") self.metrics_port = config.get("metrics_port")
self.metrics_bind_host = config.get("metrics_bind_host", "127.0.0.1")
def default_config(self, config_dir_path, server_name): def default_config(self, config_dir_path, server_name):
return """\ return """\
@ -28,6 +29,9 @@ class MetricsConfig(Config):
# Enable collection and rendering of performance metrics # Enable collection and rendering of performance metrics
enable_metrics: False enable_metrics: False
# Separate port to accept metrics requests on (on localhost) # Separate port to accept metrics requests on
# metrics_port: 8081 # metrics_port: 8081
# Which host to bind the metric listener to
# metrics_bind_host: 127.0.0.1
""" """

View File

@ -194,6 +194,8 @@ class FederationClient(FederationBase):
outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
it's from an arbitary point in the context as opposed to part it's from an arbitary point in the context as opposed to part
of the current block of PDUs. Defaults to `False` of the current block of PDUs. Defaults to `False`
timeout (int): How long to try (in ms) each destination for before
moving to the next destination. None indicates no timeout.
Returns: Returns:
Deferred: Results in the requested PDU. Deferred: Results in the requested PDU.

View File

@ -207,13 +207,13 @@ class TransactionQueue(object):
# request at which point pending_pdus_by_dest just keeps growing. # request at which point pending_pdus_by_dest just keeps growing.
# we need application-layer timeouts of some flavour of these # we need application-layer timeouts of some flavour of these
# requests # requests
logger.info( logger.debug(
"TX [%s] Transaction already in progress", "TX [%s] Transaction already in progress",
destination destination
) )
return return
logger.info("TX [%s] _attempt_new_transaction", destination) logger.debug("TX [%s] _attempt_new_transaction", destination)
# list of (pending_pdu, deferred, order) # list of (pending_pdu, deferred, order)
pending_pdus = self.pending_pdus_by_dest.pop(destination, []) pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
@ -221,11 +221,11 @@ class TransactionQueue(object):
pending_failures = self.pending_failures_by_dest.pop(destination, []) pending_failures = self.pending_failures_by_dest.pop(destination, [])
if pending_pdus: if pending_pdus:
logger.info("TX [%s] len(pending_pdus_by_dest[dest]) = %d", logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
destination, len(pending_pdus)) destination, len(pending_pdus))
if not pending_pdus and not pending_edus and not pending_failures: if not pending_pdus and not pending_edus and not pending_failures:
logger.info("TX [%s] Nothing to send", destination) logger.debug("TX [%s] Nothing to send", destination)
return return
# Sort based on the order field # Sort based on the order field
@ -242,6 +242,8 @@ class TransactionQueue(object):
try: try:
self.pending_transactions[destination] = 1 self.pending_transactions[destination] = 1
txn_id = str(self._next_txn_id)
limiter = yield get_retry_limiter( limiter = yield get_retry_limiter(
destination, destination,
self._clock, self._clock,
@ -249,9 +251,9 @@ class TransactionQueue(object):
) )
logger.debug( logger.debug(
"TX [%s] Attempting new transaction" "TX [%s] {%s} Attempting new transaction"
" (pdus: %d, edus: %d, failures: %d)", " (pdus: %d, edus: %d, failures: %d)",
destination, destination, txn_id,
len(pending_pdus), len(pending_pdus),
len(pending_edus), len(pending_edus),
len(pending_failures) len(pending_failures)
@ -261,7 +263,7 @@ class TransactionQueue(object):
transaction = Transaction.create_new( transaction = Transaction.create_new(
origin_server_ts=int(self._clock.time_msec()), origin_server_ts=int(self._clock.time_msec()),
transaction_id=str(self._next_txn_id), transaction_id=txn_id,
origin=self.server_name, origin=self.server_name,
destination=destination, destination=destination,
pdus=pdus, pdus=pdus,
@ -275,9 +277,13 @@ class TransactionQueue(object):
logger.debug("TX [%s] Persisted transaction", destination) logger.debug("TX [%s] Persisted transaction", destination)
logger.info( logger.info(
"TX [%s] Sending transaction [%s]", "TX [%s] {%s} Sending transaction [%s],"
destination, " (PDUs: %d, EDUs: %d, failures: %d)",
destination, txn_id,
transaction.transaction_id, transaction.transaction_id,
len(pending_pdus),
len(pending_edus),
len(pending_failures),
) )
with limiter: with limiter:
@ -313,7 +319,10 @@ class TransactionQueue(object):
code = e.code code = e.code
response = e.response response = e.response
logger.info("TX [%s] got %d response", destination, code) logger.info(
"TX [%s] {%s} got %d response",
destination, txn_id, code
)
logger.debug("TX [%s] Sent transaction", destination) logger.debug("TX [%s] Sent transaction", destination)
logger.debug("TX [%s] Marking as delivered...", destination) logger.debug("TX [%s] Marking as delivered...", destination)

View File

@ -57,6 +57,8 @@ class TransportLayerClient(object):
destination (str): The host name of the remote home server we want destination (str): The host name of the remote home server we want
to get the state from. to get the state from.
event_id (str): The id of the event being requested. event_id (str): The id of the event being requested.
timeout (int): How long to try (in ms) the destination for before
giving up. None indicates no timeout.
Returns: Returns:
Deferred: Results in a dict received from the remote homeserver. Deferred: Results in a dict received from the remote homeserver.

View File

@ -196,6 +196,14 @@ class FederationSendServlet(BaseFederationServlet):
transaction_id, str(transaction_data) transaction_id, str(transaction_data)
) )
logger.info(
"Received txn %s from %s. (PDUs: %d, EDUs: %d, failures: %d)",
transaction_id, origin,
len(transaction_data.get("pdus", [])),
len(transaction_data.get("edus", [])),
len(transaction_data.get("failures", [])),
)
# We should ideally be getting this from the security layer. # We should ideally be getting this from the security layer.
# origin = body["origin"] # origin = body["origin"]

View File

@ -15,7 +15,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes
from synapse.appservice import ApplicationService from synapse.appservice import ApplicationService
from synapse.types import UserID from synapse.types import UserID
@ -147,10 +147,7 @@ class ApplicationServicesHandler(object):
) )
# We need to know the members associated with this event.room_id, # We need to know the members associated with this event.room_id,
# if any. # if any.
member_list = yield self.store.get_room_members( member_list = yield self.store.get_users_in_room(event.room_id)
room_id=event.room_id,
membership=Membership.JOIN
)
services = yield self.store.get_app_services() services = yield self.store.get_app_services()
interested_list = [ interested_list = [

View File

@ -146,6 +146,10 @@ class PresenceHandler(BaseHandler):
self._user_cachemap = {} self._user_cachemap = {}
self._user_cachemap_latest_serial = 0 self._user_cachemap_latest_serial = 0
# map room_ids to the latest presence serial for a member of that
# room
self._room_serials = {}
metrics.register_callback( metrics.register_callback(
"userCachemap:size", "userCachemap:size",
lambda: len(self._user_cachemap), lambda: len(self._user_cachemap),
@ -297,13 +301,34 @@ class PresenceHandler(BaseHandler):
self.changed_presencelike_data(user, {"last_active": now}) self.changed_presencelike_data(user, {"last_active": now})
def get_joined_rooms_for_user(self, user):
"""Get the list of rooms a user is joined to.
Args:
user(UserID): The user.
Returns:
A Deferred of a list of room id strings.
"""
rm_handler = self.homeserver.get_handlers().room_member_handler
return rm_handler.get_joined_rooms_for_user(user)
def get_joined_users_for_room_id(self, room_id):
rm_handler = self.homeserver.get_handlers().room_member_handler
return rm_handler.get_room_members(room_id)
@defer.inlineCallbacks
def changed_presencelike_data(self, user, state): def changed_presencelike_data(self, user, state):
statuscache = self._get_or_make_usercache(user) """Updates the presence state of a local user.
Args:
user(UserID): The user being updated.
state(dict): The new presence state for the user.
Returns:
A Deferred
"""
self._user_cachemap_latest_serial += 1 self._user_cachemap_latest_serial += 1
statuscache.update(state, serial=self._user_cachemap_latest_serial) statuscache = yield self.update_presence_cache(user, state)
yield self.push_presence(user, statuscache=statuscache)
return self.push_presence(user, statuscache=statuscache)
@log_function @log_function
def started_user_eventstream(self, user): def started_user_eventstream(self, user):
@ -326,13 +351,12 @@ class PresenceHandler(BaseHandler):
room_id(str): The room id the user joined. room_id(str): The room id the user joined.
""" """
if self.hs.is_mine(user): if self.hs.is_mine(user):
statuscache = self._get_or_make_usercache(user)
# No actual update but we need to bump the serial anyway for the # No actual update but we need to bump the serial anyway for the
# event source # event source
self._user_cachemap_latest_serial += 1 self._user_cachemap_latest_serial += 1
statuscache.update({}, serial=self._user_cachemap_latest_serial) statuscache = yield self.update_presence_cache(
user, room_ids=[room_id]
)
self.push_update_to_local_and_remote( self.push_update_to_local_and_remote(
observed_user=user, observed_user=user,
room_ids=[room_id], room_ids=[room_id],
@ -340,16 +364,17 @@ class PresenceHandler(BaseHandler):
) )
# We also want to tell them about current presence of people. # We also want to tell them about current presence of people.
rm_handler = self.homeserver.get_handlers().room_member_handler curr_users = yield self.get_joined_users_for_room_id(room_id)
curr_users = yield rm_handler.get_room_members(room_id)
for local_user in [c for c in curr_users if self.hs.is_mine(c)]: for local_user in [c for c in curr_users if self.hs.is_mine(c)]:
statuscache = self._get_or_offline_usercache(local_user) statuscache = yield self.update_presence_cache(
statuscache.update({}, serial=self._user_cachemap_latest_serial) local_user, room_ids=[room_id], add_to_cache=False
)
self.push_update_to_local_and_remote( self.push_update_to_local_and_remote(
observed_user=local_user, observed_user=local_user,
users_to_push=[user], users_to_push=[user],
statuscache=self._get_or_offline_usercache(local_user), statuscache=statuscache,
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -546,8 +571,7 @@ class PresenceHandler(BaseHandler):
# Also include people in all my rooms # Also include people in all my rooms
rm_handler = self.homeserver.get_handlers().room_member_handler room_ids = yield self.get_joined_rooms_for_user(user)
room_ids = yield rm_handler.get_joined_rooms_for_user(user)
if state is None: if state is None:
state = yield self.store.get_presence_state(user.localpart) state = yield self.store.get_presence_state(user.localpart)
@ -747,8 +771,7 @@ class PresenceHandler(BaseHandler):
# and also user is informed of server-forced pushes # and also user is informed of server-forced pushes
localusers.add(user) localusers.add(user)
rm_handler = self.homeserver.get_handlers().room_member_handler room_ids = yield self.get_joined_rooms_for_user(user)
room_ids = yield rm_handler.get_joined_rooms_for_user(user)
if not localusers and not room_ids: if not localusers and not room_ids:
defer.returnValue(None) defer.returnValue(None)
@ -793,8 +816,7 @@ class PresenceHandler(BaseHandler):
" | %d interested local observers %r", len(observers), observers " | %d interested local observers %r", len(observers), observers
) )
rm_handler = self.homeserver.get_handlers().room_member_handler room_ids = yield self.get_joined_rooms_for_user(user)
room_ids = yield rm_handler.get_joined_rooms_for_user(user)
if room_ids: if room_ids:
logger.debug(" | %d interested room IDs %r", len(room_ids), room_ids) logger.debug(" | %d interested room IDs %r", len(room_ids), room_ids)
@ -813,10 +835,8 @@ class PresenceHandler(BaseHandler):
self.clock.time_msec() - state.pop("last_active_ago") self.clock.time_msec() - state.pop("last_active_ago")
) )
statuscache = self._get_or_make_usercache(user)
self._user_cachemap_latest_serial += 1 self._user_cachemap_latest_serial += 1
statuscache.update(state, serial=self._user_cachemap_latest_serial) yield self.update_presence_cache(user, state, room_ids=room_ids)
if not observers and not room_ids: if not observers and not room_ids:
logger.debug(" | no interested observers or room IDs") logger.debug(" | no interested observers or room IDs")
@ -874,6 +894,35 @@ class PresenceHandler(BaseHandler):
yield defer.DeferredList(deferreds, consumeErrors=True) yield defer.DeferredList(deferreds, consumeErrors=True)
@defer.inlineCallbacks
def update_presence_cache(self, user, state={}, room_ids=None,
add_to_cache=True):
"""Update the presence cache for a user with a new state and bump the
serial to the latest value.
Args:
user(UserID): The user being updated
state(dict): The presence state being updated
room_ids(None or list of str): A list of room_ids to update. If
room_ids is None then fetch the list of room_ids the user is
joined to.
add_to_cache: Whether to add an entry to the presence cache if the
user isn't already in the cache.
Returns:
A Deferred UserPresenceCache for the user being updated.
"""
if room_ids is None:
room_ids = yield self.get_joined_rooms_for_user(user)
for room_id in room_ids:
self._room_serials[room_id] = self._user_cachemap_latest_serial
if add_to_cache:
statuscache = self._get_or_make_usercache(user)
else:
statuscache = self._get_or_offline_usercache(user)
statuscache.update(state, serial=self._user_cachemap_latest_serial)
defer.returnValue(statuscache)
@defer.inlineCallbacks @defer.inlineCallbacks
def push_update_to_local_and_remote(self, observed_user, statuscache, def push_update_to_local_and_remote(self, observed_user, statuscache,
users_to_push=[], room_ids=[], users_to_push=[], room_ids=[],
@ -996,39 +1045,11 @@ class PresenceEventSource(object):
self.hs = hs self.hs = hs
self.clock = hs.get_clock() self.clock = hs.get_clock()
@defer.inlineCallbacks
def is_visible(self, observer_user, observed_user):
if observer_user == observed_user:
defer.returnValue(True)
presence = self.hs.get_handlers().presence_handler
if (yield presence.store.user_rooms_intersect(
[u.to_string() for u in observer_user, observed_user])):
defer.returnValue(True)
if self.hs.is_mine(observed_user):
pushmap = presence._local_pushmap
defer.returnValue(
observed_user.localpart in pushmap and
observer_user in pushmap[observed_user.localpart]
)
else:
recvmap = presence._remote_recvmap
defer.returnValue(
observed_user in recvmap and
observer_user in recvmap[observed_user]
)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def get_new_events_for_user(self, user, from_key, limit): def get_new_events_for_user(self, user, from_key, limit):
from_key = int(from_key) from_key = int(from_key)
observer_user = user
presence = self.hs.get_handlers().presence_handler presence = self.hs.get_handlers().presence_handler
cachemap = presence._user_cachemap cachemap = presence._user_cachemap
@ -1037,17 +1058,27 @@ class PresenceEventSource(object):
clock = self.clock clock = self.clock
latest_serial = 0 latest_serial = 0
user_ids_to_check = {user}
presence_list = yield presence.store.get_presence_list(
user.localpart, accepted=True
)
if presence_list is not None:
user_ids_to_check |= set(
UserID.from_string(p["observed_user_id"]) for p in presence_list
)
room_ids = yield presence.get_joined_rooms_for_user(user)
for room_id in set(room_ids) & set(presence._room_serials):
if presence._room_serials[room_id] > from_key:
joined = yield presence.get_joined_users_for_room_id(room_id)
user_ids_to_check |= set(joined)
updates = [] updates = []
# TODO(paul): use a DeferredList ? How to limit concurrency. for observed_user in user_ids_to_check & set(cachemap):
for observed_user in cachemap.keys():
cached = cachemap[observed_user] cached = cachemap[observed_user]
if cached.serial <= from_key or cached.serial > max_serial: if cached.serial <= from_key or cached.serial > max_serial:
continue continue
if not (yield self.is_visible(observer_user, observed_user)):
continue
latest_serial = max(cached.serial, latest_serial) latest_serial = max(cached.serial, latest_serial)
updates.append(cached.make_event(user=observed_user, clock=clock)) updates.append(cached.make_event(user=observed_user, clock=clock))
@ -1084,8 +1115,6 @@ class PresenceEventSource(object):
def get_pagination_rows(self, user, pagination_config, key): def get_pagination_rows(self, user, pagination_config, key):
# TODO (erikj): Does this make sense? Ordering? # TODO (erikj): Does this make sense? Ordering?
observer_user = user
from_key = int(pagination_config.from_key) from_key = int(pagination_config.from_key)
if pagination_config.to_key: if pagination_config.to_key:
@ -1096,14 +1125,26 @@ class PresenceEventSource(object):
presence = self.hs.get_handlers().presence_handler presence = self.hs.get_handlers().presence_handler
cachemap = presence._user_cachemap cachemap = presence._user_cachemap
user_ids_to_check = {user}
presence_list = yield presence.store.get_presence_list(
user.localpart, accepted=True
)
if presence_list is not None:
user_ids_to_check |= set(
UserID.from_string(p["observed_user_id"]) for p in presence_list
)
room_ids = yield presence.get_joined_rooms_for_user(user)
for room_id in set(room_ids) & set(presence._room_serials):
if presence._room_serials[room_id] >= from_key:
joined = yield presence.get_joined_users_for_room_id(room_id)
user_ids_to_check |= set(joined)
updates = [] updates = []
# TODO(paul): use a DeferredList ? How to limit concurrency. for observed_user in user_ids_to_check & set(cachemap):
for observed_user in cachemap.keys():
if not (to_key < cachemap[observed_user].serial <= from_key): if not (to_key < cachemap[observed_user].serial <= from_key):
continue continue
if (yield self.is_visible(observer_user, observed_user)): updates.append((observed_user, cachemap[observed_user]))
updates.append((observed_user, cachemap[observed_user]))
# TODO(paul): limit # TODO(paul): limit

View File

@ -536,9 +536,7 @@ class RoomListHandler(BaseHandler):
chunk = yield self.store.get_rooms(is_public=True) chunk = yield self.store.get_rooms(is_public=True)
results = yield defer.gatherResults( results = yield defer.gatherResults(
[ [
self.store.get_users_in_room( self.store.get_users_in_room(room["room_id"])
room_id=room["room_id"],
)
for room in chunk for room in chunk
], ],
consumeErrors=True, consumeErrors=True,

View File

@ -345,6 +345,9 @@ class MatrixFederationHttpClient(object):
path (str): The HTTP path. path (str): The HTTP path.
args (dict): A dictionary used to create query strings, defaults to args (dict): A dictionary used to create query strings, defaults to
None. None.
timeout (int): How long to try (in ms) the destination for before
giving up. None indicates no timeout and that the request will
be retried.
Returns: Returns:
Deferred: Succeeds when we get *any* HTTP response. Deferred: Succeeds when we get *any* HTTP response.

View File

@ -84,25 +84,20 @@ class Pusher(object):
rules = baserules.list_with_base_rules(rawrules, user) rules = baserules.list_with_base_rules(rawrules, user)
room_id = ev['room_id']
# get *our* member event for display name matching # get *our* member event for display name matching
member_events_for_room = yield self.store.get_current_state(
room_id=ev['room_id'],
event_type='m.room.member',
state_key=None
)
my_display_name = None my_display_name = None
room_member_count = 0 our_member_event = yield self.store.get_current_state(
for mev in member_events_for_room: room_id=room_id,
if mev.content['membership'] != 'join': event_type='m.room.member',
continue state_key=self.user_name,
)
if our_member_event:
my_display_name = our_member_event[0].content.get("displayname")
# This loop does two things: room_members = yield self.store.get_users_in_room(room_id)
# 1) Find our current display name room_member_count = len(room_members)
if mev.state_key == self.user_name and 'displayname' in mev.content:
my_display_name = mev.content['displayname']
# and 2) Get the number of people in that room
room_member_count += 1
for r in rules: for r in rules:
if r['rule_id'] in enabled_map: if r['rule_id'] in enabled_map:
@ -287,9 +282,11 @@ class Pusher(object):
if len(actions) == 0: if len(actions) == 0:
logger.warn("Empty actions! Using default action.") logger.warn("Empty actions! Using default action.")
actions = Pusher.DEFAULT_ACTIONS actions = Pusher.DEFAULT_ACTIONS
if 'notify' not in actions and 'dont_notify' not in actions: if 'notify' not in actions and 'dont_notify' not in actions:
logger.warn("Neither notify nor dont_notify in actions: adding default") logger.warn("Neither notify nor dont_notify in actions: adding default")
actions.extend(Pusher.DEFAULT_ACTIONS) actions.extend(Pusher.DEFAULT_ACTIONS)
if 'dont_notify' in actions: if 'dont_notify' in actions:
logger.debug( logger.debug(
"%s for %s: dont_notify", "%s for %s: dont_notify",

View File

@ -121,6 +121,11 @@ class Cache(object):
self.sequence += 1 self.sequence += 1
self.cache.pop(keyargs, None) self.cache.pop(keyargs, None)
def invalidate_all(self):
self.check_thread()
self.sequence += 1
self.cache.clear()
def cached(max_entries=1000, num_args=1, lru=False): def cached(max_entries=1000, num_args=1, lru=False):
""" A method decorator that applies a memoizing cache around the function. """ A method decorator that applies a memoizing cache around the function.
@ -172,6 +177,7 @@ def cached(max_entries=1000, num_args=1, lru=False):
defer.returnValue(ret) defer.returnValue(ret)
wrapped.invalidate = cache.invalidate wrapped.invalidate = cache.invalidate
wrapped.invalidate_all = cache.invalidate_all
wrapped.prefill = cache.prefill wrapped.prefill = cache.prefill
return wrapped return wrapped

View File

@ -489,3 +489,4 @@ class EventFederationStore(SQLBaseStore):
query = "DELETE FROM event_forward_extremities WHERE room_id = ?" query = "DELETE FROM event_forward_extremities WHERE room_id = ?"
txn.execute(query, (room_id,)) txn.execute(query, (room_id,))
txn.call_after(self.get_latest_event_ids_in_room.invalidate, room_id)

View File

@ -125,6 +125,12 @@ class EventsStore(SQLBaseStore):
# We purposefully do this first since if we include a `current_state` # We purposefully do this first since if we include a `current_state`
# key, we *want* to update the `current_state_events` table # key, we *want* to update the `current_state_events` table
if current_state: if current_state:
txn.call_after(self.get_current_state_for_key.invalidate_all)
txn.call_after(self.get_rooms_for_user.invalidate_all)
txn.call_after(self.get_users_in_room.invalidate, event.room_id)
txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id)
txn.call_after(self.get_room_name_and_aliases, event.room_id)
self._simple_delete_txn( self._simple_delete_txn(
txn, txn,
table="current_state_events", table="current_state_events",
@ -132,13 +138,6 @@ class EventsStore(SQLBaseStore):
) )
for s in current_state: for s in current_state:
if s.type == EventTypes.Member:
txn.call_after(
self.get_rooms_for_user.invalidate, s.state_key
)
txn.call_after(
self.get_joined_hosts_for_room.invalidate, s.room_id
)
self._simple_insert_txn( self._simple_insert_txn(
txn, txn,
"current_state_events", "current_state_events",
@ -356,6 +355,18 @@ class EventsStore(SQLBaseStore):
) )
if is_new_state and not context.rejected: if is_new_state and not context.rejected:
txn.call_after(
self.get_current_state_for_key.invalidate,
event.room_id, event.type, event.state_key
)
if (event.type == EventTypes.Name
or event.type == EventTypes.Aliases):
txn.call_after(
self.get_room_name_and_aliases.invalidate,
event.room_id
)
self._simple_upsert_txn( self._simple_upsert_txn(
txn, txn,
"current_state_events", "current_state_events",

View File

@ -13,9 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import collections from ._base import SQLBaseStore, cached
from ._base import SQLBaseStore, Table
from twisted.internet import defer from twisted.internet import defer
import logging import logging
@ -41,6 +39,7 @@ class PushRuleStore(SQLBaseStore):
defer.returnValue(rows) defer.returnValue(rows)
@cached()
@defer.inlineCallbacks @defer.inlineCallbacks
def get_push_rules_enabled_for_user(self, user_name): def get_push_rules_enabled_for_user(self, user_name):
results = yield self._simple_select_list( results = yield self._simple_select_list(
@ -151,6 +150,10 @@ class PushRuleStore(SQLBaseStore):
txn.execute(sql, (user_name, priority_class, new_rule_priority)) txn.execute(sql, (user_name, priority_class, new_rule_priority))
txn.call_after(
self.get_push_rules_enabled_for_user.invalidate, user_name
)
self._simple_insert_txn( self._simple_insert_txn(
txn, txn,
table=PushRuleTable.table_name, table=PushRuleTable.table_name,
@ -179,6 +182,10 @@ class PushRuleStore(SQLBaseStore):
new_rule['priority_class'] = priority_class new_rule['priority_class'] = priority_class
new_rule['priority'] = new_prio new_rule['priority'] = new_prio
txn.call_after(
self.get_push_rules_enabled_for_user.invalidate, user_name
)
self._simple_insert_txn( self._simple_insert_txn(
txn, txn,
table=PushRuleTable.table_name, table=PushRuleTable.table_name,
@ -201,6 +208,7 @@ class PushRuleStore(SQLBaseStore):
{'user_name': user_name, 'rule_id': rule_id}, {'user_name': user_name, 'rule_id': rule_id},
desc="delete_push_rule", desc="delete_push_rule",
) )
self.get_push_rules_enabled_for_user.invalidate(user_name)
@defer.inlineCallbacks @defer.inlineCallbacks
def set_push_rule_enabled(self, user_name, rule_id, enabled): def set_push_rule_enabled(self, user_name, rule_id, enabled):
@ -220,6 +228,7 @@ class PushRuleStore(SQLBaseStore):
{'enabled': 1 if enabled else 0}, {'enabled': 1 if enabled else 0},
{'id': new_id}, {'id': new_id},
) )
self.get_push_rules_enabled_for_user.invalidate(user_name)
class RuleNotFoundException(Exception): class RuleNotFoundException(Exception):
@ -230,7 +239,7 @@ class InconsistentRuleException(Exception):
pass pass
class PushRuleTable(Table): class PushRuleTable(object):
table_name = "push_rules" table_name = "push_rules"
fields = [ fields = [
@ -243,10 +252,8 @@ class PushRuleTable(Table):
"actions", "actions",
] ]
EntryType = collections.namedtuple("PushRuleEntry", fields)
class PushRuleEnableTable(object):
class PushRuleEnableTable(Table):
table_name = "push_rules_enable" table_name = "push_rules_enable"
fields = [ fields = [

View File

@ -17,7 +17,7 @@ from twisted.internet import defer
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from ._base import SQLBaseStore from ._base import SQLBaseStore, cached
import collections import collections
import logging import logging
@ -186,6 +186,7 @@ class RoomStore(SQLBaseStore):
} }
) )
@cached()
@defer.inlineCallbacks @defer.inlineCallbacks
def get_room_name_and_aliases(self, room_id): def get_room_name_and_aliases(self, room_id):
def f(txn): def f(txn):

View File

@ -66,6 +66,7 @@ class RoomMemberStore(SQLBaseStore):
txn.call_after(self.get_rooms_for_user.invalidate, target_user_id) txn.call_after(self.get_rooms_for_user.invalidate, target_user_id)
txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id) txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id)
txn.call_after(self.get_users_in_room.invalidate, event.room_id)
def get_room_member(self, user_id, room_id): def get_room_member(self, user_id, room_id):
"""Retrieve the current state of a room member. """Retrieve the current state of a room member.
@ -87,6 +88,7 @@ class RoomMemberStore(SQLBaseStore):
lambda events: events[0] if events else None lambda events: events[0] if events else None
) )
@cached()
def get_users_in_room(self, room_id): def get_users_in_room(self, room_id):
def f(txn): def f(txn):

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ._base import SQLBaseStore from ._base import SQLBaseStore, cached
from twisted.internet import defer from twisted.internet import defer
@ -143,6 +143,12 @@ class StateStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_current_state(self, room_id, event_type=None, state_key=""): def get_current_state(self, room_id, event_type=None, state_key=""):
if event_type and state_key is not None:
result = yield self.get_current_state_for_key(
room_id, event_type, state_key
)
defer.returnValue(result)
def f(txn): def f(txn):
sql = ( sql = (
"SELECT event_id FROM current_state_events" "SELECT event_id FROM current_state_events"
@ -167,6 +173,23 @@ class StateStore(SQLBaseStore):
events = yield self._get_events(event_ids, get_prev_content=False) events = yield self._get_events(event_ids, get_prev_content=False)
defer.returnValue(events) defer.returnValue(events)
@cached(num_args=3)
@defer.inlineCallbacks
def get_current_state_for_key(self, room_id, event_type, state_key):
def f(txn):
sql = (
"SELECT event_id FROM current_state_events"
" WHERE room_id = ? AND type = ? AND state_key = ?"
)
args = (room_id, event_type, state_key)
txn.execute(sql, args)
results = txn.fetchall()
return [r[0] for r in results]
event_ids = yield self.runInteraction("get_current_state_for_key", f)
events = yield self._get_events(event_ids, get_prev_content=False)
defer.returnValue(events)
def _make_group_id(clock): def _make_group_id(clock):
return str(int(clock.time_msec())) + random_string(5) return str(int(clock.time_msec())) + random_string(5)

View File

@ -20,7 +20,6 @@ import threading
class LruCache(object): class LruCache(object):
"""Least-recently-used cache.""" """Least-recently-used cache."""
# TODO(mjark) Add mutex for linked list for thread safety.
def __init__(self, max_size): def __init__(self, max_size):
cache = {} cache = {}
list_root = [] list_root = []
@ -105,6 +104,12 @@ class LruCache(object):
else: else:
return default return default
@synchronized
def cache_clear():
list_root[NEXT] = list_root
list_root[PREV] = list_root
cache.clear()
@synchronized @synchronized
def cache_len(): def cache_len():
return len(cache) return len(cache)
@ -120,6 +125,7 @@ class LruCache(object):
self.pop = cache_pop self.pop = cache_pop
self.len = cache_len self.len = cache_len
self.contains = cache_contains self.contains = cache_contains
self.clear = cache_clear
def __getitem__(self, key): def __getitem__(self, key):
result = self.get(key, self.sentinel) result = self.get(key, self.sentinel)

View File

@ -217,18 +217,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
_regex("@irc_.*") _regex("@irc_.*")
) )
join_list = [ join_list = [
Mock( "@alice:here",
type="m.room.member", room_id="!foo:bar", sender="@alice:here", "@irc_fo:here", # AS user
state_key="@alice:here" "@bob:here",
),
Mock(
type="m.room.member", room_id="!foo:bar", sender="@irc_fo:here",
state_key="@irc_fo:here" # AS user
),
Mock(
type="m.room.member", room_id="!foo:bar", sender="@bob:here",
state_key="@bob:here"
)
] ]
self.event.sender = "@xmpp_foobar:matrix.org" self.event.sender = "@xmpp_foobar:matrix.org"

View File

@ -624,6 +624,7 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
""" """
PRESENCE_LIST = { PRESENCE_LIST = {
'apple': [ "@banana:test", "@clementine:test" ], 'apple': [ "@banana:test", "@clementine:test" ],
'banana': [ "@apple:test" ],
} }
@defer.inlineCallbacks @defer.inlineCallbacks
@ -836,12 +837,7 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_recv_remote(self): def test_recv_remote(self):
# TODO(paul): Gut-wrenching self.room_members = [self.u_apple, self.u_banana, self.u_potato]
potato_set = self.handler._remote_recvmap.setdefault(self.u_potato,
set())
potato_set.add(self.u_apple)
self.room_members = [self.u_banana, self.u_potato]
self.assertEquals(self.event_source.get_current_key(), 0) self.assertEquals(self.event_source.get_current_key(), 0)
@ -886,11 +882,8 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_recv_remote_offline(self): def test_recv_remote_offline(self):
""" Various tests relating to SYN-261 """ """ Various tests relating to SYN-261 """
potato_set = self.handler._remote_recvmap.setdefault(self.u_potato,
set())
potato_set.add(self.u_apple)
self.room_members = [self.u_banana, self.u_potato] self.room_members = [self.u_apple, self.u_banana, self.u_potato]
self.assertEquals(self.event_source.get_current_key(), 0) self.assertEquals(self.event_source.get_current_key(), 0)

View File

@ -297,6 +297,9 @@ class PresenceEventStreamTestCase(unittest.TestCase):
else: else:
return [] return []
hs.handlers.room_member_handler.get_joined_rooms_for_user = get_rooms_for_user hs.handlers.room_member_handler.get_joined_rooms_for_user = get_rooms_for_user
hs.handlers.room_member_handler.get_room_members = (
lambda r: self.room_members if r == "a-room" else []
)
self.mock_datastore = hs.get_datastore() self.mock_datastore = hs.get_datastore()
self.mock_datastore.get_app_service_by_token = Mock(return_value=None) self.mock_datastore.get_app_service_by_token = Mock(return_value=None)