Replaces calls to fetch_room_distributions_into with get_joined_hosts_for_room

This commit is contained in:
Mark Haines 2016-05-16 19:48:07 +01:00
parent 1a3a2002ff
commit 821306120a
5 changed files with 33 additions and 112 deletions

View File

@ -29,6 +29,8 @@ class ReceiptsHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs):
super(ReceiptsHandler, self).__init__(hs) super(ReceiptsHandler, self).__init__(hs)
self.server_name = hs.config.server_name
self.store = hs.get_datastore()
self.hs = hs self.hs = hs
self.federation = hs.get_replication_layer() self.federation = hs.get_replication_layer()
self.federation.register_edu_handler( self.federation.register_edu_handler(
@ -131,12 +133,9 @@ class ReceiptsHandler(BaseHandler):
event_ids = receipt["event_ids"] event_ids = receipt["event_ids"]
data = receipt["data"] data = receipt["data"]
remotedomains = set() remotedomains = yield self.store.get_joined_hosts_for_room(room_id)
remotedomains = remotedomains.copy()
rm_handler = self.hs.get_handlers().room_member_handler remotedomains.discard(self.server_name)
yield rm_handler.fetch_room_distributions_into(
room_id, localusers=None, remotedomains=remotedomains
)
logger.debug("Sending receipt to: %r", remotedomains) logger.debug("Sending receipt to: %r", remotedomains)

View File

@ -55,35 +55,6 @@ class RoomMemberHandler(BaseHandler):
self.distributor.declare("user_joined_room") self.distributor.declare("user_joined_room")
self.distributor.declare("user_left_room") self.distributor.declare("user_left_room")
@defer.inlineCallbacks
def get_room_members(self, room_id):
users = yield self.store.get_users_in_room(room_id)
defer.returnValue([UserID.from_string(u) for u in users])
@defer.inlineCallbacks
def fetch_room_distributions_into(self, room_id, localusers=None,
remotedomains=None, ignore_user=None):
"""Fetch the distribution of a room, adding elements to either
'localusers' or 'remotedomains', which should be a set() if supplied.
If ignore_user is set, ignore that user.
This function returns nothing; its result is performed by the
side-effect on the two passed sets. This allows easy accumulation of
member lists of multiple rooms at once if required.
"""
members = yield self.get_room_members(room_id)
for member in members:
if ignore_user is not None and member == ignore_user:
continue
if self.hs.is_mine(member):
if localusers is not None:
localusers.add(member)
else:
if remotedomains is not None:
remotedomains.add(member.domain)
@defer.inlineCallbacks @defer.inlineCallbacks
def _local_membership_update( def _local_membership_update(
self, requester, target, room_id, membership, self, requester, target, room_id, membership,

View File

@ -39,7 +39,8 @@ class TypingNotificationHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs):
super(TypingNotificationHandler, self).__init__(hs) super(TypingNotificationHandler, self).__init__(hs)
self.homeserver = hs self.store = hs.get_datastore()
self.server_name = hs.config.server_name
self.clock = hs.get_clock() self.clock = hs.get_clock()
@ -157,32 +158,26 @@ class TypingNotificationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _push_update(self, room_id, user, typing): def _push_update(self, room_id, user, typing):
localusers = set() domains = yield self.store.get_joined_hosts_for_room(room_id)
remotedomains = set()
rm_handler = self.homeserver.get_handlers().room_member_handler
yield rm_handler.fetch_room_distributions_into(
room_id, localusers=localusers, remotedomains=remotedomains
)
if localusers:
self._push_update_local(
room_id=room_id,
user=user,
typing=typing
)
deferreds = [] deferreds = []
for domain in remotedomains: for domain in domains:
deferreds.append(self.federation.send_edu( if domain == self.server_name:
destination=domain, self._push_update_local(
edu_type="m.typing", room_id=room_id,
content={ user=user,
"room_id": room_id, typing=typing
"user_id": user.to_string(), )
"typing": typing, else:
}, deferreds.append(self.federation.send_edu(
)) destination=domain,
edu_type="m.typing",
content={
"room_id": room_id,
"user_id": user.to_string(),
"typing": typing,
},
))
yield defer.DeferredList(deferreds, consumeErrors=True) yield defer.DeferredList(deferreds, consumeErrors=True)
@ -191,14 +186,9 @@ class TypingNotificationHandler(BaseHandler):
room_id = content["room_id"] room_id = content["room_id"]
user = UserID.from_string(content["user_id"]) user = UserID.from_string(content["user_id"])
localusers = set() domains = yield self.store.get_joined_hosts_for_room(room_id)
rm_handler = self.homeserver.get_handlers().room_member_handler if self.server_name in domains:
yield rm_handler.fetch_room_distributions_into(
room_id, localusers=localusers
)
if localusers:
self._push_update_local( self._push_update_local(
room_id=room_id, room_id=room_id,
user=user, user=user,

View File

@ -71,6 +71,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
self.auth = Mock(spec=[]) self.auth = Mock(spec=[])
hs = yield setup_test_homeserver( hs = yield setup_test_homeserver(
"test",
auth=self.auth, auth=self.auth,
clock=self.clock, clock=self.clock,
datastore=Mock(spec=[ datastore=Mock(spec=[
@ -110,56 +111,16 @@ class TypingNotificationsTestCase(unittest.TestCase):
self.room_id = "a-room" self.room_id = "a-room"
# Mock the RoomMemberHandler
hs.handlers.room_member_handler = Mock(spec=[])
self.room_member_handler = hs.handlers.room_member_handler
self.room_members = [] self.room_members = []
def get_rooms_for_user(user):
if user in self.room_members:
return defer.succeed([self.room_id])
else:
return defer.succeed([])
self.room_member_handler.get_rooms_for_user = get_rooms_for_user
def get_room_members(room_id):
if room_id == self.room_id:
return defer.succeed(self.room_members)
else:
return defer.succeed([])
self.room_member_handler.get_room_members = get_room_members
def get_joined_rooms_for_user(user):
if user in self.room_members:
return defer.succeed([self.room_id])
else:
return defer.succeed([])
self.room_member_handler.get_joined_rooms_for_user = get_joined_rooms_for_user
@defer.inlineCallbacks
def fetch_room_distributions_into(
room_id, localusers=None, remotedomains=None, ignore_user=None
):
members = yield get_room_members(room_id)
for member in members:
if ignore_user is not None and member == ignore_user:
continue
if hs.is_mine(member):
if localusers is not None:
localusers.add(member)
else:
if remotedomains is not None:
remotedomains.add(member.domain)
self.room_member_handler.fetch_room_distributions_into = (
fetch_room_distributions_into
)
def check_joined_room(room_id, user_id): def check_joined_room(room_id, user_id):
if user_id not in [u.to_string() for u in self.room_members]: if user_id not in [u.to_string() for u in self.room_members]:
raise AuthError(401, "User is not in the room") raise AuthError(401, "User is not in the room")
def get_joined_hosts_for_room(room_id):
return set(member.domain for member in self.room_members)
self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
self.auth.check_joined_room = check_joined_room self.auth.check_joined_room = check_joined_room
# Some local users to test with # Some local users to test with

View File

@ -50,7 +50,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
config.enable_registration = True config.enable_registration = True
config.macaroon_secret_key = "not even a little secret" config.macaroon_secret_key = "not even a little secret"
config.expire_access_token = False config.expire_access_token = False
config.server_name = "server.under.test" config.server_name = name
config.trusted_third_party_id_servers = [] config.trusted_third_party_id_servers = []
config.room_invite_state_types = [] config.room_invite_state_types = []