Merge branch 'develop' of github.com:matrix-org/synapse into erikj/public_room_fix

This commit is contained in:
Erik Johnston 2016-02-03 11:06:29 +00:00
commit 6f52e90065
36 changed files with 216 additions and 238 deletions

View File

@ -16,3 +16,4 @@ ignore =
[flake8] [flake8]
max-line-length = 90 max-line-length = 90
ignore = W503 ; W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it.

View File

@ -696,6 +696,7 @@ class Auth(object):
def _look_up_user_by_access_token(self, token): def _look_up_user_by_access_token(self, token):
ret = yield self.store.get_user_by_access_token(token) ret = yield self.store.get_user_by_access_token(token)
if not ret: if not ret:
logger.warn("Unrecognised access token - not in store: %s" % (token,))
raise AuthError( raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.", self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.",
errcode=Codes.UNKNOWN_TOKEN errcode=Codes.UNKNOWN_TOKEN
@ -713,6 +714,7 @@ class Auth(object):
token = request.args["access_token"][0] token = request.args["access_token"][0]
service = yield self.store.get_app_service_by_token(token) service = yield self.store.get_app_service_by_token(token)
if not service: if not service:
logger.warn("Unrecognised appservice access token: %s" % (token,))
raise AuthError( raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, self.TOKEN_NOT_FOUND_HTTP_STATUS,
"Unrecognised access token.", "Unrecognised access token.",

View File

@ -12,3 +12,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import sys
sys.dont_write_bytecode = True
from synapse.python_dependencies import (
check_requirements, MissingRequirementError
) # NOQA
try:
check_requirements()
except MissingRequirementError as e:
message = "\n".join([
"Missing Requirement: %s" % (e.message,),
"To install run:",
" pip install --upgrade --force \"%s\"" % (e.dependency,),
"",
])
sys.stderr.writelines(message)
sys.exit(1)

View File

@ -14,27 +14,22 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import sys import synapse
from synapse.rest import ClientRestResource
import contextlib
import logging
import os
import re
import resource
import subprocess
import sys
import time
sys.dont_write_bytecode = True
from synapse.python_dependencies import ( from synapse.python_dependencies import (
check_requirements, DEPENDENCY_LINKS, MissingRequirementError check_requirements, DEPENDENCY_LINKS
) )
if __name__ == '__main__': from synapse.rest import ClientRestResource
try:
check_requirements()
except MissingRequirementError as e:
message = "\n".join([
"Missing Requirement: %s" % (e.message,),
"To install run:",
" pip install --upgrade --force \"%s\"" % (e.dependency,),
"",
])
sys.stderr.writelines(message)
sys.exit(1)
from synapse.storage.engines import create_engine, IncorrectDatabaseSetup from synapse.storage.engines import create_engine, IncorrectDatabaseSetup
from synapse.storage import are_all_users_on_domain from synapse.storage import are_all_users_on_domain
from synapse.storage.prepare_database import UpgradeDatabaseException from synapse.storage.prepare_database import UpgradeDatabaseException
@ -73,17 +68,6 @@ from synapse import events
from daemonize import Daemonize from daemonize import Daemonize
import synapse
import contextlib
import logging
import os
import re
import resource
import subprocess
import time
logger = logging.getLogger("synapse.app.homeserver") logger = logging.getLogger("synapse.app.homeserver")

View File

@ -57,7 +57,7 @@ class FederationClient(FederationBase):
cache_name="get_pdu_cache", cache_name="get_pdu_cache",
clock=self._clock, clock=self._clock,
max_len=1000, max_len=1000,
expiry_ms=120*1000, expiry_ms=120 * 1000,
reset_expiry_on_get=False, reset_expiry_on_get=False,
) )

View File

@ -147,7 +147,7 @@ class BaseHandler(object):
) )
if not allowed: if not allowed:
raise LimitExceededError( raise LimitExceededError(
retry_after_ms=int(1000*(time_allowed - time_now)), retry_after_ms=int(1000 * (time_allowed - time_now)),
) )
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -175,8 +175,8 @@ class DirectoryHandler(BaseHandler):
# If this server is in the list of servers, return it first. # If this server is in the list of servers, return it first.
if self.server_name in servers: if self.server_name in servers:
servers = ( servers = (
[self.server_name] [self.server_name] +
+ [s for s in servers if s != self.server_name] [s for s in servers if s != self.server_name]
) )
else: else:
servers = list(servers) servers = list(servers)

View File

@ -130,7 +130,7 @@ class EventStreamHandler(BaseHandler):
# Add some randomness to this value to try and mitigate against # Add some randomness to this value to try and mitigate against
# thundering herds on restart. # thundering herds on restart.
timeout = random.randint(int(timeout*0.9), int(timeout*1.1)) timeout = random.randint(int(timeout * 0.9), int(timeout * 1.1))
events, tokens = yield self.notifier.get_events_for( events, tokens = yield self.notifier.get_events_for(
auth_user, pagin_config, timeout, auth_user, pagin_config, timeout,

View File

@ -34,7 +34,7 @@ metrics = synapse.metrics.get_metrics_for(__name__)
# Don't bother bumping "last active" time if it differs by less than 60 seconds # Don't bother bumping "last active" time if it differs by less than 60 seconds
LAST_ACTIVE_GRANULARITY = 60*1000 LAST_ACTIVE_GRANULARITY = 60 * 1000
# Keep no more than this number of offline serial revisions # Keep no more than this number of offline serial revisions
MAX_OFFLINE_SERIALS = 1000 MAX_OFFLINE_SERIALS = 1000

View File

@ -139,7 +139,9 @@ class RegistrationHandler(BaseHandler):
yield self.store.register( yield self.store.register(
user_id=user_id, user_id=user_id,
token=token, token=token,
password_hash=password_hash) password_hash=password_hash,
make_guest=make_guest
)
yield registered_user(self.distributor, user) yield registered_user(self.distributor, user)
except SynapseError: except SynapseError:

View File

@ -18,7 +18,7 @@ from twisted.internet import defer
from ._base import BaseHandler from ._base import BaseHandler
from synapse.types import UserID, RoomAlias, RoomID from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken
from synapse.api.constants import ( from synapse.api.constants import (
EventTypes, Membership, JoinRules, RoomCreationPreset, EventTypes, Membership, JoinRules, RoomCreationPreset,
) )
@ -967,7 +967,7 @@ class RoomContextHandler(BaseHandler):
Returns: Returns:
dict, or None if the event isn't found dict, or None if the event isn't found
""" """
before_limit = math.floor(limit/2.) before_limit = math.floor(limit / 2.)
after_limit = limit - before_limit after_limit = limit - before_limit
now_token = yield self.hs.get_event_sources().get_current_token() now_token = yield self.hs.get_event_sources().get_current_token()
@ -1037,6 +1037,11 @@ class RoomEventSource(object):
to_key = yield self.get_current_key() to_key = yield self.get_current_key()
from_token = RoomStreamToken.parse(from_key)
if from_token.topological:
logger.warn("Stream has topological part!!!! %r", from_key)
from_key = "s%s" % (from_token.stream,)
app_service = yield self.store.get_app_service_by_user_id( app_service = yield self.store.get_app_service_by_user_id(
user.to_string() user.to_string()
) )
@ -1048,7 +1053,7 @@ class RoomEventSource(object):
limit=limit, limit=limit,
) )
else: else:
room_events = yield self.store.get_room_changes_for_user( room_events = yield self.store.get_membership_changes_for_user(
user.to_string(), from_key, to_key user.to_string(), from_key, to_key
) )

View File

@ -23,6 +23,7 @@ from twisted.internet import defer
import collections import collections
import logging import logging
import itertools
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -478,7 +479,7 @@ class SyncHandler(BaseHandler):
) )
# Get a list of membership change events that have happened. # Get a list of membership change events that have happened.
rooms_changed = yield self.store.get_room_changes_for_user( rooms_changed = yield self.store.get_membership_changes_for_user(
user_id, since_token.room_key, now_token.room_key user_id, since_token.room_key, now_token.room_key
) )
@ -672,35 +673,10 @@ class SyncHandler(BaseHandler):
account_data_by_room, account_data_by_room,
all_ephemeral_by_room, all_ephemeral_by_room,
batch, full_state=False): batch, full_state=False):
if full_state:
state = yield self.get_state_at(room_id, now_token)
elif batch.limited:
current_state = yield self.get_state_at(room_id, now_token)
state_at_previous_sync = yield self.get_state_at(
room_id, stream_position=since_token
)
state = yield self.compute_state_delta( state = yield self.compute_state_delta(
since_token=since_token, room_id, batch, sync_config, since_token, now_token,
previous_state=state_at_previous_sync, full_state=full_state
current_state=current_state,
) )
else:
state = {
(event.type, event.state_key): event
for event in batch.events if event.is_state()
}
just_joined = yield self.check_joined_room(sync_config, state)
if just_joined:
state = yield self.get_state_at(room_id, now_token)
state = {
(e.type, e.state_key): e
for e in sync_config.filter_collection.filter_room_state(state.values())
}
account_data = self.account_data_for_room( account_data = self.account_data_for_room(
room_id, tags_by_room, account_data_by_room room_id, tags_by_room, account_data_by_room
@ -766,29 +742,10 @@ class SyncHandler(BaseHandler):
logger.debug("Recents %r", batch) logger.debug("Recents %r", batch)
state_events_at_leave = yield self.store.get_state_for_event(
leave_event_id
)
if not full_state:
state_at_previous_sync = yield self.get_state_at(
room_id, stream_position=since_token
)
state_events_delta = yield self.compute_state_delta( state_events_delta = yield self.compute_state_delta(
since_token=since_token, room_id, batch, sync_config, since_token, leave_token,
previous_state=state_at_previous_sync, full_state=full_state
current_state=state_events_at_leave,
) )
else:
state_events_delta = state_events_at_leave
state_events_delta = {
(e.type, e.state_key): e
for e in sync_config.filter_collection.filter_room_state(
state_events_delta.values()
)
}
account_data = self.account_data_for_room( account_data = self.account_data_for_room(
room_id, tags_by_room, account_data_by_room room_id, tags_by_room, account_data_by_room
@ -843,15 +800,19 @@ class SyncHandler(BaseHandler):
state = {} state = {}
defer.returnValue(state) defer.returnValue(state)
def compute_state_delta(self, since_token, previous_state, current_state): @defer.inlineCallbacks
""" Works out the differnce in state between the current state and the def compute_state_delta(self, room_id, batch, sync_config, since_token, now_token,
state the client got when it last performed a sync. full_state):
""" Works out the differnce in state between the start of the timeline
and the previous sync.
:param str since_token: the point we are comparing against :param str room_id
:param dict[(str,str), synapse.events.FrozenEvent] previous_state: the :param TimelineBatch batch: The timeline batch for the room that will
state to compare to be sent to the user.
:param dict[(str,str), synapse.events.FrozenEvent] current_state: the :param sync_config
new state :param str since_token: Token of the end of the previous batch. May be None.
:param str now_token: Token of the end of the current batch.
:param bool full_state: Whether to force returning the full state.
:returns A new event dictionary :returns A new event dictionary
""" """
@ -860,12 +821,50 @@ class SyncHandler(BaseHandler):
# updates even if they occured logically before the previous event. # updates even if they occured logically before the previous event.
# TODO(mjark) Check for new redactions in the state events. # TODO(mjark) Check for new redactions in the state events.
state_delta = {} if full_state:
for key, event in current_state.iteritems(): if batch:
if (key not in previous_state or state = yield self.store.get_state_for_event(batch.events[0].event_id)
previous_state[key].event_id != event.event_id): else:
state_delta[key] = event state = yield self.get_state_at(
return state_delta room_id, stream_position=now_token
)
timeline_state = {
(event.type, event.state_key): event
for event in batch.events if event.is_state()
}
state = _calculate_state(
timeline_contains=timeline_state,
timeline_start=state,
previous={},
)
elif batch.limited:
state_at_previous_sync = yield self.get_state_at(
room_id, stream_position=since_token
)
state_at_timeline_start = yield self.store.get_state_for_event(
batch.events[0].event_id
)
timeline_state = {
(event.type, event.state_key): event
for event in batch.events if event.is_state()
}
state = _calculate_state(
timeline_contains=timeline_state,
timeline_start=state_at_timeline_start,
previous=state_at_previous_sync,
)
else:
state = {}
defer.returnValue({
(e.type, e.state_key): e
for e in sync_config.filter_collection.filter_room_state(state.values())
})
def check_joined_room(self, sync_config, state_delta): def check_joined_room(self, sync_config, state_delta):
""" """
@ -912,3 +911,37 @@ def _action_has_highlight(actions):
pass pass
return False return False
def _calculate_state(timeline_contains, timeline_start, previous):
"""Works out what state to include in a sync response.
Args:
timeline_contains (dict): state in the timeline
timeline_start (dict): state at the start of the timeline
previous (dict): state at the end of the previous sync (or empty dict
if this is an initial sync)
Returns:
dict
"""
event_id_to_state = {
e.event_id: e
for e in itertools.chain(
timeline_contains.values(),
previous.values(),
timeline_start.values(),
)
}
tc_ids = set(e.event_id for e in timeline_contains.values())
p_ids = set(e.event_id for e in previous.values())
ts_ids = set(e.event_id for e in timeline_start.values())
state_ids = (ts_ids - p_ids) - tc_ids
evs = (event_id_to_state[e] for e in state_ids)
return {
(e.type, e.state_key): e
for e in evs
}

View File

@ -152,7 +152,7 @@ class MatrixFederationHttpClient(object):
return self.clock.time_bound_deferred( return self.clock.time_bound_deferred(
request_deferred, request_deferred,
time_out=timeout/1000. if timeout else 60, time_out=timeout / 1000. if timeout else 60,
) )
response = yield preserve_context_over_fn( response = yield preserve_context_over_fn(

View File

@ -308,7 +308,7 @@ class Notifier(object):
def timed_out(): def timed_out():
if listener: if listener:
listener.deferred.cancel() listener.deferred.cancel()
timer = self.clock.call_later(timeout/1000., timed_out) timer = self.clock.call_later(timeout / 1000., timed_out)
prev_token = from_token prev_token = from_token
while not result: while not result:

View File

@ -304,7 +304,7 @@ def _flatten_dict(d, prefix=[], result={}):
if isinstance(value, basestring): if isinstance(value, basestring):
result[".".join(prefix + [key])] = value.lower() result[".".join(prefix + [key])] = value.lower()
elif hasattr(value, "items"): elif hasattr(value, "items"):
_flatten_dict(value, prefix=(prefix+[key]), result=result) _flatten_dict(value, prefix=(prefix + [key]), result=result)
return result return result

View File

@ -89,7 +89,7 @@ class LoginRestServlet(ClientV1RestServlet):
LoginRestServlet.SAML2_TYPE): LoginRestServlet.SAML2_TYPE):
relay_state = "" relay_state = ""
if "relay_state" in login_submission: if "relay_state" in login_submission:
relay_state = "&RelayState="+urllib.quote( relay_state = "&RelayState=" + urllib.quote(
login_submission["relay_state"]) login_submission["relay_state"])
result = { result = {
"uri": "%s%s" % (self.idp_redirect_url, relay_state) "uri": "%s%s" % (self.idp_redirect_url, relay_state)

View File

@ -52,7 +52,7 @@ class PusherRestServlet(ClientV1RestServlet):
if i not in content: if i not in content:
missing.append(i) missing.append(i)
if len(missing): if len(missing):
raise SynapseError(400, "Missing parameters: "+','.join(missing), raise SynapseError(400, "Missing parameters: " + ','.join(missing),
errcode=Codes.MISSING_PARAM) errcode=Codes.MISSING_PARAM)
logger.debug("set pushkey %s to kind %s", content['pushkey'], content['kind']) logger.debug("set pushkey %s to kind %s", content['pushkey'], content['kind'])
@ -83,7 +83,7 @@ class PusherRestServlet(ClientV1RestServlet):
data=content['data'] data=content['data']
) )
except PusherConfigException as pce: except PusherConfigException as pce:
raise SynapseError(400, "Config Error: "+pce.message, raise SynapseError(400, "Config Error: " + pce.message,
errcode=Codes.MISSING_PARAM) errcode=Codes.MISSING_PARAM)
defer.returnValue((200, {})) defer.returnValue((200, {}))

View File

@ -38,7 +38,8 @@ logger = logging.getLogger(__name__)
if hasattr(hmac, "compare_digest"): if hasattr(hmac, "compare_digest"):
compare_digest = hmac.compare_digest compare_digest = hmac.compare_digest
else: else:
compare_digest = lambda a, b: a == b def compare_digest(a, b):
return a == b
class RegisterRestServlet(ClientV1RestServlet): class RegisterRestServlet(ClientV1RestServlet):

View File

@ -34,7 +34,8 @@ from synapse.util.async import run_on_reactor
if hasattr(hmac, "compare_digest"): if hasattr(hmac, "compare_digest"):
compare_digest = hmac.compare_digest compare_digest = hmac.compare_digest
else: else:
compare_digest = lambda a, b: a == b def compare_digest(a, b):
return a == b
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -152,6 +153,7 @@ class RegisterRestServlet(RestServlet):
desired_username = params.get("username", None) desired_username = params.get("username", None)
new_password = params.get("password", None) new_password = params.get("password", None)
guest_access_token = params.get("guest_access_token", None)
(user_id, token) = yield self.registration_handler.register( (user_id, token) = yield self.registration_handler.register(
localpart=desired_username, localpart=desired_username,

View File

@ -20,7 +20,6 @@ from synapse.http.servlet import (
) )
from synapse.handlers.sync import SyncConfig from synapse.handlers.sync import SyncConfig
from synapse.types import StreamToken from synapse.types import StreamToken
from synapse.events import FrozenEvent
from synapse.events.utils import ( from synapse.events.utils import (
serialize_event, format_event_for_client_v2_without_room_id, serialize_event, format_event_for_client_v2_without_room_id,
) )
@ -287,9 +286,6 @@ class SyncRestServlet(RestServlet):
state_dict = room.state state_dict = room.state
timeline_events = room.timeline.events timeline_events = room.timeline.events
state_dict = SyncRestServlet._rollback_state_for_timeline(
state_dict, timeline_events)
state_events = state_dict.values() state_events = state_dict.values()
serialized_state = [serialize(e) for e in state_events] serialized_state = [serialize(e) for e in state_events]
@ -314,77 +310,6 @@ class SyncRestServlet(RestServlet):
return result return result
@staticmethod
def _rollback_state_for_timeline(state, timeline):
"""
Wind the state dictionary backwards, so that it represents the
state at the start of the timeline, rather than at the end.
:param dict[(str, str), synapse.events.EventBase] state: the
state dictionary. Will be updated to the state before the timeline.
:param list[synapse.events.EventBase] timeline: the event timeline
:return: updated state dictionary
"""
result = state.copy()
for timeline_event in reversed(timeline):
if not timeline_event.is_state():
continue
event_key = (timeline_event.type, timeline_event.state_key)
logger.debug("Considering %s for removal", event_key)
state_event = result.get(event_key)
if (state_event is None or
state_event.event_id != timeline_event.event_id):
# the event in the timeline isn't present in the state
# dictionary.
#
# the most likely cause for this is that there was a fork in
# the event graph, and the state is no longer valid. Really,
# the event shouldn't be in the timeline. We're going to ignore
# it for now, however.
logger.debug("Found state event %r in timeline which doesn't "
"match state dictionary", timeline_event)
continue
prev_event_id = timeline_event.unsigned.get("replaces_state", None)
prev_content = timeline_event.unsigned.get('prev_content')
prev_sender = timeline_event.unsigned.get('prev_sender')
# Empircally it seems possible for the event to have a
# "replaces_state" key but not a prev_content or prev_sender
# markjh conjectures that it could be due to the server not
# having a copy of that event.
# If this is the case the we ignore the previous event. This will
# cause the displayname calculations on the client to be incorrect
if prev_event_id is None or not prev_content or not prev_sender:
logger.debug(
"Removing %r from the state dict, as it is missing"
" prev_content (prev_event_id=%r)",
timeline_event.event_id, prev_event_id
)
del result[event_key]
else:
logger.debug(
"Replacing %r with %r in state dict",
timeline_event.event_id, prev_event_id
)
result[event_key] = FrozenEvent({
"type": timeline_event.type,
"state_key": timeline_event.state_key,
"content": prev_content,
"sender": prev_sender,
"event_id": prev_event_id,
"room_id": timeline_event.room_id,
})
logger.debug("New value: %r", result.get(event_key))
return result
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
SyncRestServlet(hs).register(http_server) SyncRestServlet(hs).register(http_server)

View File

@ -26,9 +26,7 @@ class VersionsRestServlet(RestServlet):
def on_GET(self, request): def on_GET(self, request):
return (200, { return (200, {
"versions": [ "versions": ["r0.0.1"]
"r0.0.1",
]
}) })

View File

@ -63,7 +63,7 @@ class StateHandler(object):
cache_name="state_cache", cache_name="state_cache",
clock=self.clock, clock=self.clock,
max_len=SIZE_OF_CACHE, max_len=SIZE_OF_CACHE,
expiry_ms=EVICTION_TIMEOUT_SECONDS*1000, expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
reset_expiry_on_get=True, reset_expiry_on_get=True,
) )

View File

@ -59,7 +59,7 @@ logger = logging.getLogger(__name__)
# Number of msec of granularity to store the user IP 'last seen' time. Smaller # Number of msec of granularity to store the user IP 'last seen' time. Smaller
# times give more inserts into the database even for readonly API hits # times give more inserts into the database even for readonly API hits
# 120 seconds == 2 minutes # 120 seconds == 2 minutes
LAST_SEEN_GRANULARITY = 120*1000 LAST_SEEN_GRANULARITY = 120 * 1000
class DataStore(RoomMemberStore, RoomStore, class DataStore(RoomMemberStore, RoomStore,

View File

@ -185,7 +185,7 @@ class SQLBaseStore(object):
time_then = self._previous_loop_ts time_then = self._previous_loop_ts
self._previous_loop_ts = time_now self._previous_loop_ts = time_now
ratio = (curr - prev)/(time_now - time_then) ratio = (curr - prev) / (time_now - time_then)
top_three_counters = self._txn_perf_counters.interval( top_three_counters = self._txn_perf_counters.interval(
time_now - time_then, limit=3 time_now - time_then, limit=3
@ -643,7 +643,10 @@ class SQLBaseStore(object):
if not iterable: if not iterable:
defer.returnValue(results) defer.returnValue(results)
chunks = [iterable[i:i+batch_size] for i in xrange(0, len(iterable), batch_size)] chunks = [
iterable[i:i + batch_size]
for i in xrange(0, len(iterable), batch_size)
]
for chunk in chunks: for chunk in chunks:
rows = yield self.runInteraction( rows = yield self.runInteraction(
desc, desc,

View File

@ -54,7 +54,7 @@ class Sqlite3Engine(object):
def _parse_match_info(buf): def _parse_match_info(buf):
bufsize = len(buf) bufsize = len(buf)
return [struct.unpack('@I', buf[i:i+4])[0] for i in range(0, bufsize, 4)] return [struct.unpack('@I', buf[i:i + 4])[0] for i in range(0, bufsize, 4)]
def _rank(raw_match_info): def _rank(raw_match_info):

View File

@ -58,7 +58,7 @@ class EventFederationStore(SQLBaseStore):
new_front = set() new_front = set()
front_list = list(front) front_list = list(front)
chunks = [ chunks = [
front_list[x:x+100] front_list[x:x + 100]
for x in xrange(0, len(front), 100) for x in xrange(0, len(front), 100)
] ]
for chunk in chunks: for chunk in chunks:

View File

@ -84,7 +84,7 @@ class EventsStore(SQLBaseStore):
event.internal_metadata.stream_ordering = stream event.internal_metadata.stream_ordering = stream
chunks = [ chunks = [
events_and_contexts[x:x+100] events_and_contexts[x:x + 100]
for x in xrange(0, len(events_and_contexts), 100) for x in xrange(0, len(events_and_contexts), 100)
] ]
@ -740,7 +740,7 @@ class EventsStore(SQLBaseStore):
rows = [] rows = []
N = 200 N = 200
for i in range(1 + len(events) / N): for i in range(1 + len(events) / N):
evs = events[i*N:(i + 1)*N] evs = events[i * N:(i + 1) * N]
if not evs: if not evs:
break break
@ -755,7 +755,7 @@ class EventsStore(SQLBaseStore):
" LEFT JOIN rejections as rej USING (event_id)" " LEFT JOIN rejections as rej USING (event_id)"
" LEFT JOIN redactions as r ON e.event_id = r.redacts" " LEFT JOIN redactions as r ON e.event_id = r.redacts"
" WHERE e.event_id IN (%s)" " WHERE e.event_id IN (%s)"
) % (",".join(["?"]*len(evs)),) ) % (",".join(["?"] * len(evs)),)
txn.execute(sql, evs) txn.execute(sql, evs)
rows.extend(self.cursor_to_dict(txn)) rows.extend(self.cursor_to_dict(txn))

View File

@ -168,7 +168,7 @@ class StreamStore(SQLBaseStore):
results = {} results = {}
room_ids = list(room_ids) room_ids = list(room_ids)
for rm_ids in (room_ids[i:i+20] for i in xrange(0, len(room_ids), 20)): for rm_ids in (room_ids[i:i + 20] for i in xrange(0, len(room_ids), 20)):
res = yield defer.gatherResults([ res = yield defer.gatherResults([
self.get_room_events_stream_for_room( self.get_room_events_stream_for_room(
room_id, from_key, to_key, limit room_id, from_key, to_key, limit
@ -220,8 +220,11 @@ class StreamStore(SQLBaseStore):
rows = self.cursor_to_dict(txn) rows = self.cursor_to_dict(txn)
ret = self._get_events_txn( return rows
txn,
rows = yield self.runInteraction("get_room_events_stream_for_room", f)
ret = yield self._get_events(
[r["event_id"] for r in rows], [r["event_id"] for r in rows],
get_prev_content=True get_prev_content=True
) )
@ -237,11 +240,10 @@ class StreamStore(SQLBaseStore):
# get. # get.
key = from_key key = from_key
return ret, key defer.returnValue((ret, key))
res = yield self.runInteraction("get_room_events_stream_for_room", f)
defer.returnValue(res)
def get_room_changes_for_user(self, user_id, from_key, to_key): @defer.inlineCallbacks
def get_membership_changes_for_user(self, user_id, from_key, to_key):
if from_key is not None: if from_key is not None:
from_id = RoomStreamToken.parse_stream_token(from_key).stream from_id = RoomStreamToken.parse_stream_token(from_key).stream
else: else:
@ -249,14 +251,14 @@ class StreamStore(SQLBaseStore):
to_id = RoomStreamToken.parse_stream_token(to_key).stream to_id = RoomStreamToken.parse_stream_token(to_key).stream
if from_key == to_key: if from_key == to_key:
return defer.succeed([]) defer.returnValue([])
if from_id: if from_id:
has_changed = self._membership_stream_cache.has_entity_changed( has_changed = self._membership_stream_cache.has_entity_changed(
user_id, int(from_id) user_id, int(from_id)
) )
if not has_changed: if not has_changed:
return defer.succeed([]) defer.returnValue([])
def f(txn): def f(txn):
if from_id is not None: if from_id is not None:
@ -281,17 +283,18 @@ class StreamStore(SQLBaseStore):
txn.execute(sql, (user_id, to_id,)) txn.execute(sql, (user_id, to_id,))
rows = self.cursor_to_dict(txn) rows = self.cursor_to_dict(txn)
ret = self._get_events_txn( return rows
txn,
rows = yield self.runInteraction("get_membership_changes_for_user", f)
ret = yield self._get_events(
[r["event_id"] for r in rows], [r["event_id"] for r in rows],
get_prev_content=True get_prev_content=True
) )
self._set_before_and_after(ret, rows, topo_order=False) self._set_before_and_after(ret, rows, topo_order=False)
return ret defer.returnValue(ret)
return self.runInteraction("get_room_changes_for_user", f)
def get_room_events_stream( def get_room_events_stream(
self, self,

View File

@ -46,7 +46,7 @@ class Clock(object):
def looping_call(self, f, msec): def looping_call(self, f, msec):
l = task.LoopingCall(f) l = task.LoopingCall(f)
l.start(msec/1000.0, now=False) l.start(msec / 1000.0, now=False)
return l return l
def stop_looping_call(self, loop): def stop_looping_call(self, loop):

View File

@ -149,7 +149,7 @@ class CacheDescriptor(object):
self.lru = lru self.lru = lru
self.tree = tree self.tree = tree
self.arg_names = inspect.getargspec(orig).args[1:num_args+1] self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
if len(self.arg_names) < self.num_args: if len(self.arg_names) < self.num_args:
raise Exception( raise Exception(
@ -250,7 +250,7 @@ class CacheListDescriptor(object):
self.num_args = num_args self.num_args = num_args
self.list_name = list_name self.list_name = list_name
self.arg_names = inspect.getargspec(orig).args[1:num_args+1] self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
self.list_pos = self.arg_names.index(self.list_name) self.list_pos = self.arg_names.index(self.list_name)
self.cache = cache self.cache = cache

View File

@ -55,7 +55,7 @@ class ExpiringCache(object):
def f(): def f():
self._prune_cache() self._prune_cache()
self._clock.looping_call(f, self._expiry_ms/2) self._clock.looping_call(f, self._expiry_ms / 2)
def __setitem__(self, key, value): def __setitem__(self, key, value):
now = self._clock.time_msec() now = self._clock.time_msec()

View File

@ -58,7 +58,7 @@ class TreeCache(object):
if n: if n:
break break
node_and_keys[i+1][0].pop(k) node_and_keys[i + 1][0].pop(k)
popped, cnt = _strip_and_count_entires(popped) popped, cnt = _strip_and_count_entires(popped)
self.size -= cnt self.size -= cnt

View File

@ -111,7 +111,7 @@ def time_function(f):
_log_debug_as_f( _log_debug_as_f(
f, f,
"[FUNC END] {%s-%d} %f", "[FUNC END] {%s-%d} %f",
(func_name, id, end-start,), (func_name, id, end - start,),
) )
return r return r

View File

@ -163,7 +163,7 @@ class _PerHostRatelimiter(object):
"Ratelimit [%s]: sleeping req", "Ratelimit [%s]: sleeping req",
id(request_id), id(request_id),
) )
ret_defer = sleep(self.sleep_msec/1000.0) ret_defer = sleep(self.sleep_msec / 1000.0)
self.sleeping_requests.add(request_id) self.sleeping_requests.add(request_id)