Pass since/from parameters over federation

This commit is contained in:
Erik Johnston 2016-09-15 10:36:19 +01:00
parent f3eead0660
commit 5810cffd33
6 changed files with 63 additions and 57 deletions

View File

@ -24,7 +24,6 @@ from synapse.api.errors import (
CodeMessageException, HttpResponseException, SynapseError, CodeMessageException, HttpResponseException, SynapseError,
) )
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.async import concurrently_execute
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
@ -719,24 +718,11 @@ class FederationClient(FederationBase):
raise RuntimeError("Failed to send to any server.") raise RuntimeError("Failed to send to any server.")
@defer.inlineCallbacks def get_public_rooms(self, destination, limit=None, since_token=None):
def get_public_rooms(self, destinations): if destination == self.server_name:
results_by_server = {} return
@defer.inlineCallbacks return self.transport_layer.get_public_rooms(destination, limit, since_token)
def _get_result(s):
if s == self.server_name:
defer.returnValue()
try:
result = yield self.transport_layer.get_public_rooms(s)
results_by_server[s] = result
except:
logger.exception("Error getting room list from server %r", s)
yield concurrently_execute(_get_result, destinations, 3)
defer.returnValue(results_by_server)
@defer.inlineCallbacks @defer.inlineCallbacks
def query_auth(self, destination, room_id, event_id, local_auth): def query_auth(self, destination, room_id, event_id, local_auth):

View File

@ -248,12 +248,19 @@ class TransportLayerClient(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def get_public_rooms(self, remote_server): def get_public_rooms(self, remote_server, limit, since_token):
path = PREFIX + "/publicRooms" path = PREFIX + "/publicRooms"
args = {}
if limit:
args["limit"] = [str(limit)]
if since_token:
args["since"] = [since_token]
response = yield self.client.get_json( response = yield self.client.get_json(
destination=remote_server, destination=remote_server,
path=path, path=path,
args=args,
) )
defer.returnValue(response) defer.returnValue(response)

View File

@ -18,7 +18,9 @@ from twisted.internet import defer
from synapse.api.urls import FEDERATION_PREFIX as PREFIX from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import (
parse_json_object_from_request, parse_integer_from_args, parse_string_from_args,
)
from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.versionstring import get_version_string from synapse.util.versionstring import get_version_string
@ -554,7 +556,11 @@ class PublicRoomList(BaseFederationServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, origin, content, query): def on_GET(self, origin, content, query):
data = yield self.room_list_handler.get_local_public_room_list() limit = parse_integer_from_args(query, "limit", 0)
since_token = parse_string_from_args(query, "since", None)
data = yield self.room_list_handler.get_local_public_room_list(
limit, since_token
)
defer.returnValue((200, data)) defer.returnValue((200, data))

View File

@ -20,7 +20,6 @@ from ._base import BaseHandler
from synapse.api.constants import ( from synapse.api.constants import (
EventTypes, JoinRules, EventTypes, JoinRules,
) )
from synapse.api.errors import SynapseError
from synapse.util.async import concurrently_execute from synapse.util.async import concurrently_execute
from synapse.util.caches.response_cache import ResponseCache from synapse.util.caches.response_cache import ResponseCache
@ -40,21 +39,21 @@ class RoomListHandler(BaseHandler):
super(RoomListHandler, self).__init__(hs) super(RoomListHandler, self).__init__(hs)
self.response_cache = ResponseCache(hs) self.response_cache = ResponseCache(hs)
def get_local_public_room_list(self, limit=None, next_batch=None): def get_local_public_room_list(self, limit=None, since_token=None):
result = self.response_cache.get((limit, next_batch)) result = self.response_cache.get((limit, since_token))
if not result: if not result:
result = self.response_cache.set( result = self.response_cache.set(
(limit, next_batch), (limit, since_token),
self._get_public_room_list(limit, next_batch) self._get_public_room_list(limit, since_token)
) )
return result return result
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_public_room_list(self, limit=None, next_batch=None): def _get_public_room_list(self, limit=None, since_token=None):
if next_batch and next_batch != "END": if since_token and since_token != "END":
next_batch = RoomListNextBatch.from_token(next_batch) since_token = RoomListNextBatch.from_token(since_token)
else: else:
next_batch = None since_token = None
room_ids = yield self.store.get_public_room_ids() room_ids = yield self.store.get_public_room_ids()
@ -62,8 +61,8 @@ class RoomListHandler(BaseHandler):
rooms_to_num_joined = {} rooms_to_num_joined = {}
rooms_to_latest_event_ids = {} rooms_to_latest_event_ids = {}
if next_batch: if since_token:
current_stream_token = next_batch.stream_ordering current_stream_token = since_token.stream_ordering
else: else:
current_stream_token = yield self.store.get_room_max_stream_ordering() current_stream_token = yield self.store.get_room_max_stream_ordering()
@ -99,22 +98,22 @@ class RoomListHandler(BaseHandler):
sorted_entries = sorted(rooms_to_order_value.items(), key=lambda e: e[1]) sorted_entries = sorted(rooms_to_order_value.items(), key=lambda e: e[1])
sorted_rooms = [room_id for room_id, _ in sorted_entries] sorted_rooms = [room_id for room_id, _ in sorted_entries]
if next_batch: if since_token:
if next_batch.direction_is_forward: if since_token.direction_is_forward:
sorted_rooms = sorted_rooms[next_batch.current_limit:] sorted_rooms = sorted_rooms[since_token.current_limit:]
else: else:
sorted_rooms = sorted_rooms[:next_batch.current_limit] sorted_rooms = sorted_rooms[:since_token.current_limit]
sorted_rooms.reverse() sorted_rooms.reverse()
new_limit = None new_limit = None
if limit: if limit:
if sorted_rooms[limit:]: if sorted_rooms[limit:]:
new_limit = limit new_limit = limit
if next_batch: if since_token:
if next_batch.direction_is_forward: if since_token.direction_is_forward:
new_limit += next_batch.current_limit new_limit += since_token.current_limit
else: else:
new_limit = next_batch.current_limit - new_limit new_limit = since_token.current_limit - new_limit
new_limit = max(0, new_limit) new_limit = max(0, new_limit)
sorted_rooms = sorted_rooms[:limit] sorted_rooms = sorted_rooms[:limit]
@ -208,7 +207,7 @@ class RoomListHandler(BaseHandler):
"chunk": chunk, "chunk": chunk,
} }
if not next_batch or next_batch.direction_is_forward: if not since_token or since_token.direction_is_forward:
if new_limit: if new_limit:
results["next_batch"] = RoomListNextBatch( results["next_batch"] = RoomListNextBatch(
stream_ordering=current_stream_token, stream_ordering=current_stream_token,
@ -216,8 +215,8 @@ class RoomListHandler(BaseHandler):
direction_is_forward=True, direction_is_forward=True,
).to_token() ).to_token()
if next_batch: if since_token:
results["prev_batch"] = next_batch.copy_and_replace( results["prev_batch"] = since_token.copy_and_replace(
direction_is_forward=False, direction_is_forward=False,
).to_token() ).to_token()
else: else:
@ -228,22 +227,20 @@ class RoomListHandler(BaseHandler):
direction_is_forward=False, direction_is_forward=False,
).to_token() ).to_token()
if next_batch: if since_token:
results["next_batch"] = next_batch.copy_and_replace( results["next_batch"] = since_token.copy_and_replace(
direction_is_forward=True, direction_is_forward=True,
).to_token() ).to_token()
defer.returnValue(results) defer.returnValue(results)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_remote_public_room_list(self, server_name, limit=None, next_batch=None): def get_remote_public_room_list(self, server_name, limit=None, since_token=None):
res = yield self.hs.get_replication_layer().get_public_rooms( res = yield self.hs.get_replication_layer().get_public_rooms(
[server_name] server_name, limit=limit, since_token=since_token,
) )
if server_name not in res: defer.returnValue(res)
raise SynapseError(404, "Server not found")
defer.returnValue(res[server_name])
class RoomListNextBatch(namedtuple("RoomListNextBatch", ( class RoomListNextBatch(namedtuple("RoomListNextBatch", (

View File

@ -41,9 +41,13 @@ def parse_integer(request, name, default=None, required=False):
SynapseError: if the parameter is absent and required, or if the SynapseError: if the parameter is absent and required, or if the
parameter is present and not an integer. parameter is present and not an integer.
""" """
if name in request.args: return parse_integer_from_args(request.args, name, default, required)
def parse_integer_from_args(args, name, default=None, required=False):
if name in args:
try: try:
return int(request.args[name][0]) return int(args[name][0])
except: except:
message = "Query parameter %r must be an integer" % (name,) message = "Query parameter %r must be an integer" % (name,)
raise SynapseError(400, message) raise SynapseError(400, message)
@ -116,9 +120,15 @@ def parse_string(request, name, default=None, required=False,
parameter is present, must be one of a list of allowed values and parameter is present, must be one of a list of allowed values and
is not one of those allowed values. is not one of those allowed values.
""" """
return parse_string_from_args(
request.args, name, default, required, allowed_values, param_type,
)
if name in request.args:
value = request.args[name][0] def parse_string_from_args(args, name, default=None, required=False,
allowed_values=None, param_type="string"):
if name in args:
value = args[name][0]
if allowed_values is not None and value not in allowed_values: if allowed_values is not None and value not in allowed_values:
message = "Query parameter %r must be one of [%s]" % ( message = "Query parameter %r must be one of [%s]" % (
name, ", ".join(repr(v) for v in allowed_values) name, ", ".join(repr(v) for v in allowed_values)

View File

@ -320,19 +320,19 @@ class PublicRoomListRestServlet(ClientV1RestServlet):
pass pass
limit = parse_integer(request, "limit", 0) limit = parse_integer(request, "limit", 0)
next_batch = parse_string(request, "since", None) since_token = parse_string(request, "since", None)
handler = self.hs.get_room_list_handler() handler = self.hs.get_room_list_handler()
if server: if server:
data = yield handler.get_remote_public_room_list( data = yield handler.get_remote_public_room_list(
server, server,
limit=limit, limit=limit,
next_batch=next_batch, since_token=since_token,
) )
else: else:
data = yield handler.get_local_public_room_list( data = yield handler.get_local_public_room_list(
limit=limit, limit=limit,
next_batch=next_batch, since_token=since_token,
) )
defer.returnValue((200, data)) defer.returnValue((200, data))