Merge branch 'develop' into markjh/bearer_token

This commit is contained in:
Mark Haines 2016-09-12 11:14:56 +01:00
commit 4a32d25d4c
9 changed files with 107 additions and 44 deletions

View File

@ -242,6 +242,9 @@ class SynchrotronTyping(object):
self._room_typing = {} self._room_typing = {}
def stream_positions(self): def stream_positions(self):
# We must update this typing token from the response of the previous
# sync. In particular, the stream id may "reset" back to zero/a low
# value which we *must* use for the next replication request.
return {"typing": self._latest_room_serial} return {"typing": self._latest_room_serial}
def process_replication(self, result): def process_replication(self, result):

View File

@ -122,8 +122,12 @@ class FederationClient(FederationBase):
pdu.event_id pdu.event_id
) )
def send_presence(self, destination, states):
if destination != self.server_name:
self._transaction_queue.enqueue_presence(destination, states)
@log_function @log_function
def send_edu(self, destination, edu_type, content): def send_edu(self, destination, edu_type, content, key=None):
edu = Edu( edu = Edu(
origin=self.server_name, origin=self.server_name,
destination=destination, destination=destination,
@ -134,7 +138,7 @@ class FederationClient(FederationBase):
sent_edus_counter.inc() sent_edus_counter.inc()
# TODO, add errback, etc. # TODO, add errback, etc.
self._transaction_queue.enqueue_edu(edu) self._transaction_queue.enqueue_edu(edu, key=key)
return defer.succeed(None) return defer.succeed(None)
@log_function @log_function

View File

@ -26,6 +26,7 @@ from synapse.util.retryutils import (
get_retry_limiter, NotRetryingDestination, get_retry_limiter, NotRetryingDestination,
) )
from synapse.util.metrics import measure_func from synapse.util.metrics import measure_func
from synapse.handlers.presence import format_user_presence_state
import synapse.metrics import synapse.metrics
import logging import logging
@ -69,13 +70,21 @@ class TransactionQueue(object):
# destination -> list of tuple(edu, deferred) # destination -> list of tuple(edu, deferred)
self.pending_edus_by_dest = edus = {} self.pending_edus_by_dest = edus = {}
# Presence needs to be separate as we send single aggragate EDUs
self.pending_presence_by_dest = presence = {}
self.pending_edus_keyed_by_dest = edus_keyed = {}
metrics.register_callback( metrics.register_callback(
"pending_pdus", "pending_pdus",
lambda: sum(map(len, pdus.values())), lambda: sum(map(len, pdus.values())),
) )
metrics.register_callback( metrics.register_callback(
"pending_edus", "pending_edus",
lambda: sum(map(len, edus.values())), lambda: (
sum(map(len, edus.values()))
+ sum(map(len, presence.values()))
+ sum(map(len, edus_keyed.values()))
),
) )
# destination -> list of tuple(failure, deferred) # destination -> list of tuple(failure, deferred)
@ -130,12 +139,26 @@ class TransactionQueue(object):
self._attempt_new_transaction, destination self._attempt_new_transaction, destination
) )
def enqueue_edu(self, edu): def enqueue_presence(self, destination, states):
self.pending_presence_by_dest.setdefault(destination, {}).update({
state.user_id: state for state in states
})
preserve_context_over_fn(
self._attempt_new_transaction, destination
)
def enqueue_edu(self, edu, key=None):
destination = edu.destination destination = edu.destination
if not self.can_send_to(destination): if not self.can_send_to(destination):
return return
if key:
self.pending_edus_keyed_by_dest.setdefault(
destination, {}
)[(edu.edu_type, key)] = edu
else:
self.pending_edus_by_dest.setdefault(destination, []).append(edu) self.pending_edus_by_dest.setdefault(destination, []).append(edu)
preserve_context_over_fn( preserve_context_over_fn(
@ -190,8 +213,13 @@ class TransactionQueue(object):
while True: while True:
pending_pdus = self.pending_pdus_by_dest.pop(destination, []) pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
pending_edus = self.pending_edus_by_dest.pop(destination, []) pending_edus = self.pending_edus_by_dest.pop(destination, [])
pending_presence = self.pending_presence_by_dest.pop(destination, {})
pending_failures = self.pending_failures_by_dest.pop(destination, []) pending_failures = self.pending_failures_by_dest.pop(destination, [])
pending_edus.extend(
self.pending_edus_keyed_by_dest.pop(destination, {}).values()
)
limiter = yield get_retry_limiter( limiter = yield get_retry_limiter(
destination, destination,
self.clock, self.clock,
@ -203,6 +231,22 @@ class TransactionQueue(object):
) )
pending_edus.extend(device_message_edus) pending_edus.extend(device_message_edus)
if pending_presence:
pending_edus.append(
Edu(
origin=self.server_name,
destination=destination,
edu_type="m.presence",
content={
"push": [
format_user_presence_state(
presence, self.clock.time_msec()
)
for presence in pending_presence.values()
]
},
)
)
if pending_pdus: if pending_pdus:
logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d", logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",

View File

@ -625,18 +625,8 @@ class PresenceHandler(object):
Args: Args:
hosts_to_states (dict): Mapping `server_name` -> `[UserPresenceState]` hosts_to_states (dict): Mapping `server_name` -> `[UserPresenceState]`
""" """
now = self.clock.time_msec()
for host, states in hosts_to_states.items(): for host, states in hosts_to_states.items():
self.federation.send_edu( self.federation.send_presence(host, states)
destination=host,
edu_type="m.presence",
content={
"push": [
_format_user_presence_state(state, now)
for state in states
]
}
)
@defer.inlineCallbacks @defer.inlineCallbacks
def incoming_presence(self, origin, content): def incoming_presence(self, origin, content):
@ -723,13 +713,13 @@ class PresenceHandler(object):
defer.returnValue([ defer.returnValue([
{ {
"type": "m.presence", "type": "m.presence",
"content": _format_user_presence_state(state, now), "content": format_user_presence_state(state, now),
} }
for state in updates for state in updates
]) ])
else: else:
defer.returnValue([ defer.returnValue([
_format_user_presence_state(state, now) for state in updates format_user_presence_state(state, now) for state in updates
]) ])
@defer.inlineCallbacks @defer.inlineCallbacks
@ -988,7 +978,7 @@ def should_notify(old_state, new_state):
return False return False
def _format_user_presence_state(state, now): def format_user_presence_state(state, now):
"""Convert UserPresenceState to a format that can be sent down to clients """Convert UserPresenceState to a format that can be sent down to clients
and to other servers. and to other servers.
""" """
@ -1101,7 +1091,7 @@ class PresenceEventSource(object):
defer.returnValue(([ defer.returnValue(([
{ {
"type": "m.presence", "type": "m.presence",
"content": _format_user_presence_state(s, now), "content": format_user_presence_state(s, now),
} }
for s in updates.values() for s in updates.values()
if include_offline or s.state != PresenceState.OFFLINE if include_offline or s.state != PresenceState.OFFLINE

View File

@ -156,6 +156,7 @@ class ReceiptsHandler(BaseHandler):
} }
}, },
}, },
key=(room_id, receipt_type, user_id),
) )
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -187,6 +187,7 @@ class TypingHandler(object):
"user_id": user_id, "user_id": user_id,
"typing": typing, "typing": typing,
}, },
key=(room_id, user_id),
)) ))
yield preserve_context_over_deferred( yield preserve_context_over_deferred(

View File

@ -274,11 +274,18 @@ class ReplicationResource(Resource):
@defer.inlineCallbacks @defer.inlineCallbacks
def typing(self, writer, current_token, request_streams): def typing(self, writer, current_token, request_streams):
current_position = current_token.presence current_position = current_token.typing
request_typing = request_streams.get("typing") request_typing = request_streams.get("typing")
if request_typing is not None: if request_typing is not None:
# If they have a higher token than current max, we can assume that
# they had been talking to a previous instance of the master. Since
# we reset the token on restart, the best (but hacky) thing we can
# do is to simply resend down all the typing notifications.
if request_typing > current_position:
request_typing = 0
typing_rows = yield self.typing_handler.get_all_typing_updates( typing_rows = yield self.typing_handler.get_all_typing_updates(
request_typing, current_position request_typing, current_position
) )

View File

@ -318,7 +318,7 @@ class CasRedirectServlet(ClientV1RestServlet):
service_param = urllib.urlencode({ service_param = urllib.urlencode({
"service": "%s?%s" % (hs_redirect_url, client_redirect_url_param) "service": "%s?%s" % (hs_redirect_url, client_redirect_url_param)
}) })
request.redirect("%s?%s" % (self.cas_server_url, service_param)) request.redirect("%s/login?%s" % (self.cas_server_url, service_param))
finish_request(request) finish_request(request)
@ -385,7 +385,7 @@ class CasTicketServlet(ClientV1RestServlet):
def parse_cas_response(self, cas_response_body): def parse_cas_response(self, cas_response_body):
user = None user = None
attributes = None attributes = {}
try: try:
root = ET.fromstring(cas_response_body) root = ET.fromstring(cas_response_body)
if not root.tag.endswith("serviceResponse"): if not root.tag.endswith("serviceResponse"):
@ -395,7 +395,6 @@ class CasTicketServlet(ClientV1RestServlet):
if child.tag.endswith("user"): if child.tag.endswith("user"):
user = child.text user = child.text
if child.tag.endswith("attributes"): if child.tag.endswith("attributes"):
attributes = {}
for attribute in child: for attribute in child:
# ElementTree library expands the namespace in # ElementTree library expands the namespace in
# attribute tags to the full URL of the namespace. # attribute tags to the full URL of the namespace.
@ -407,8 +406,6 @@ class CasTicketServlet(ClientV1RestServlet):
attributes[tag] = attribute.text attributes[tag] = attribute.text
if user is None: if user is None:
raise Exception("CAS response does not contain user") raise Exception("CAS response does not contain user")
if attributes is None:
raise Exception("CAS response does not contain attributes")
except Exception: except Exception:
logger.error("Error parsing CAS response", exc_info=1) logger.error("Error parsing CAS response", exc_info=1)
raise LoginError(401, "Invalid CAS response", raise LoginError(401, "Invalid CAS response",

View File

@ -306,13 +306,6 @@ class StateStore(SQLBaseStore):
defer.returnValue(results) defer.returnValue(results)
def _get_state_groups_from_groups_txn(self, txn, groups, types=None): def _get_state_groups_from_groups_txn(self, txn, groups, types=None):
if types is not None:
where_clause = "AND (%s)" % (
" OR ".join(["(type = ? AND state_key = ?)"] * len(types)),
)
else:
where_clause = ""
results = {group: {} for group in groups} results = {group: {} for group in groups}
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, PostgresEngine):
# Temporarily disable sequential scans in this transaction. This is # Temporarily disable sequential scans in this transaction. This is
@ -342,20 +335,43 @@ class StateStore(SQLBaseStore):
WHERE state_group IN ( WHERE state_group IN (
SELECT state_group FROM state SELECT state_group FROM state
) )
%s; %s
""") % (where_clause,) """)
# Turns out that postgres doesn't like doing a list of OR's and
# is about 1000x slower, so we just issue a query for each specific
# type seperately.
if types:
clause_to_args = [
(
"AND type = ? AND state_key = ?",
(etype, state_key)
)
for etype, state_key in types
]
else:
# If types is None we fetch all the state, and so just use an
# empty where clause with no extra args.
clause_to_args = [("", [])]
for where_clause, where_args in clause_to_args:
for group in groups: for group in groups:
args = [group] args = [group]
if types is not None: args.extend(where_args)
args.extend([i for typ in types for i in typ])
txn.execute(sql, args) txn.execute(sql % (where_clause,), args)
rows = self.cursor_to_dict(txn) rows = self.cursor_to_dict(txn)
for row in rows: for row in rows:
key = (row["type"], row["state_key"]) key = (row["type"], row["state_key"])
results[group][key] = row["event_id"] results[group][key] = row["event_id"]
else: else:
if types is not None:
where_clause = "AND (%s)" % (
" OR ".join(["(type = ? AND state_key = ?)"] * len(types)),
)
else:
where_clause = ""
# We don't use WITH RECURSIVE on sqlite3 as there are distributions # We don't use WITH RECURSIVE on sqlite3 as there are distributions
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy) # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
for group in groups: for group in groups: