mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2025-01-16 21:37:08 -05:00
Clarify list/set/dict/tuple comprehensions and enforce via flake8 (#6957)
Ensure good comprehension hygiene using flake8-comprehensions.
This commit is contained in:
parent
272eee1ae1
commit
509e381afa
@ -60,7 +60,7 @@ python 3.6 and to install each tool:
|
|||||||
|
|
||||||
```
|
```
|
||||||
# Install the dependencies
|
# Install the dependencies
|
||||||
pip install -U black flake8 isort
|
pip install -U black flake8 flake8-comprehensions isort
|
||||||
|
|
||||||
# Run the linter script
|
# Run the linter script
|
||||||
./scripts-dev/lint.sh
|
./scripts-dev/lint.sh
|
||||||
|
1
changelog.d/6957.misc
Normal file
1
changelog.d/6957.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Use flake8-comprehensions to enforce good hygiene of list/set/dict comprehensions.
|
@ -30,7 +30,7 @@ The necessary tools are detailed below.
|
|||||||
|
|
||||||
Install `flake8` with:
|
Install `flake8` with:
|
||||||
|
|
||||||
pip install --upgrade flake8
|
pip install --upgrade flake8 flake8-comprehensions
|
||||||
|
|
||||||
Check all application and test code with:
|
Check all application and test code with:
|
||||||
|
|
||||||
|
@ -103,7 +103,7 @@ def main():
|
|||||||
|
|
||||||
yaml.safe_dump(result, sys.stdout, default_flow_style=False)
|
yaml.safe_dump(result, sys.stdout, default_flow_style=False)
|
||||||
|
|
||||||
rows = list(row for server, json in result.items() for row in rows_v2(server, json))
|
rows = [row for server, json in result.items() for row in rows_v2(server, json)]
|
||||||
|
|
||||||
cursor = connection.cursor()
|
cursor = connection.cursor()
|
||||||
cursor.executemany(
|
cursor.executemany(
|
||||||
|
@ -141,7 +141,7 @@ def start_reactor(
|
|||||||
|
|
||||||
def quit_with_error(error_string):
|
def quit_with_error(error_string):
|
||||||
message_lines = error_string.split("\n")
|
message_lines = error_string.split("\n")
|
||||||
line_length = max([len(l) for l in message_lines if len(l) < 80]) + 2
|
line_length = max(len(l) for l in message_lines if len(l) < 80) + 2
|
||||||
sys.stderr.write("*" * line_length + "\n")
|
sys.stderr.write("*" * line_length + "\n")
|
||||||
for line in message_lines:
|
for line in message_lines:
|
||||||
sys.stderr.write(" %s\n" % (line.rstrip(),))
|
sys.stderr.write(" %s\n" % (line.rstrip(),))
|
||||||
|
@ -262,7 +262,7 @@ class FederationSenderHandler(object):
|
|||||||
|
|
||||||
# ... as well as device updates and messages
|
# ... as well as device updates and messages
|
||||||
elif stream_name == DeviceListsStream.NAME:
|
elif stream_name == DeviceListsStream.NAME:
|
||||||
hosts = set(row.destination for row in rows)
|
hosts = {row.destination for row in rows}
|
||||||
for host in hosts:
|
for host in hosts:
|
||||||
self.federation_sender.send_device_messages(host)
|
self.federation_sender.send_device_messages(host)
|
||||||
|
|
||||||
@ -270,7 +270,7 @@ class FederationSenderHandler(object):
|
|||||||
# The to_device stream includes stuff to be pushed to both local
|
# The to_device stream includes stuff to be pushed to both local
|
||||||
# clients and remote servers, so we ignore entities that start with
|
# clients and remote servers, so we ignore entities that start with
|
||||||
# '@' (since they'll be local users rather than destinations).
|
# '@' (since they'll be local users rather than destinations).
|
||||||
hosts = set(row.entity for row in rows if not row.entity.startswith("@"))
|
hosts = {row.entity for row in rows if not row.entity.startswith("@")}
|
||||||
for host in hosts:
|
for host in hosts:
|
||||||
self.federation_sender.send_device_messages(host)
|
self.federation_sender.send_device_messages(host)
|
||||||
|
|
||||||
|
@ -158,7 +158,7 @@ class PusherReplicationHandler(ReplicationClientHandler):
|
|||||||
yield self.pusher_pool.on_new_notifications(token, token)
|
yield self.pusher_pool.on_new_notifications(token, token)
|
||||||
elif stream_name == "receipts":
|
elif stream_name == "receipts":
|
||||||
yield self.pusher_pool.on_new_receipts(
|
yield self.pusher_pool.on_new_receipts(
|
||||||
token, token, set(row.room_id for row in rows)
|
token, token, {row.room_id for row in rows}
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error poking pushers")
|
logger.exception("Error poking pushers")
|
||||||
|
@ -1066,12 +1066,12 @@ KNOWN_RESOURCES = (
|
|||||||
|
|
||||||
|
|
||||||
def _check_resource_config(listeners):
|
def _check_resource_config(listeners):
|
||||||
resource_names = set(
|
resource_names = {
|
||||||
res_name
|
res_name
|
||||||
for listener in listeners
|
for listener in listeners
|
||||||
for res in listener.get("resources", [])
|
for res in listener.get("resources", [])
|
||||||
for res_name in res.get("names", [])
|
for res_name in res.get("names", [])
|
||||||
)
|
}
|
||||||
|
|
||||||
for resource in resource_names:
|
for resource in resource_names:
|
||||||
if resource not in KNOWN_RESOURCES:
|
if resource not in KNOWN_RESOURCES:
|
||||||
|
@ -260,7 +260,7 @@ class TlsConfig(Config):
|
|||||||
crypto.FILETYPE_ASN1, self.tls_certificate
|
crypto.FILETYPE_ASN1, self.tls_certificate
|
||||||
)
|
)
|
||||||
sha256_fingerprint = encode_base64(sha256(x509_certificate_bytes).digest())
|
sha256_fingerprint = encode_base64(sha256(x509_certificate_bytes).digest())
|
||||||
sha256_fingerprints = set(f["sha256"] for f in self.tls_fingerprints)
|
sha256_fingerprints = {f["sha256"] for f in self.tls_fingerprints}
|
||||||
if sha256_fingerprint not in sha256_fingerprints:
|
if sha256_fingerprint not in sha256_fingerprints:
|
||||||
self.tls_fingerprints.append({"sha256": sha256_fingerprint})
|
self.tls_fingerprints.append({"sha256": sha256_fingerprint})
|
||||||
|
|
||||||
|
@ -326,9 +326,7 @@ class Keyring(object):
|
|||||||
verify_requests (list[VerifyJsonRequest]): list of verify requests
|
verify_requests (list[VerifyJsonRequest]): list of verify requests
|
||||||
"""
|
"""
|
||||||
|
|
||||||
remaining_requests = set(
|
remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called}
|
||||||
(rq for rq in verify_requests if not rq.key_ready.called)
|
|
||||||
)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def do_iterations():
|
def do_iterations():
|
||||||
@ -396,7 +394,7 @@ class Keyring(object):
|
|||||||
|
|
||||||
results = yield fetcher.get_keys(missing_keys)
|
results = yield fetcher.get_keys(missing_keys)
|
||||||
|
|
||||||
completed = list()
|
completed = []
|
||||||
for verify_request in remaining_requests:
|
for verify_request in remaining_requests:
|
||||||
server_name = verify_request.server_name
|
server_name = verify_request.server_name
|
||||||
|
|
||||||
|
@ -129,9 +129,9 @@ class FederationRemoteSendQueue(object):
|
|||||||
for key in keys[:i]:
|
for key in keys[:i]:
|
||||||
del self.presence_changed[key]
|
del self.presence_changed[key]
|
||||||
|
|
||||||
user_ids = set(
|
user_ids = {
|
||||||
user_id for uids in self.presence_changed.values() for user_id in uids
|
user_id for uids in self.presence_changed.values() for user_id in uids
|
||||||
)
|
}
|
||||||
|
|
||||||
keys = self.presence_destinations.keys()
|
keys = self.presence_destinations.keys()
|
||||||
i = self.presence_destinations.bisect_left(position_to_delete)
|
i = self.presence_destinations.bisect_left(position_to_delete)
|
||||||
|
@ -608,7 +608,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
|
|||||||
user_results = yield self.store.get_users_in_group(
|
user_results = yield self.store.get_users_in_group(
|
||||||
group_id, include_private=True
|
group_id, include_private=True
|
||||||
)
|
)
|
||||||
if user_id in [user_result["user_id"] for user_result in user_results]:
|
if user_id in (user_result["user_id"] for user_result in user_results):
|
||||||
raise SynapseError(400, "User already in group")
|
raise SynapseError(400, "User already in group")
|
||||||
|
|
||||||
content = {
|
content = {
|
||||||
|
@ -742,6 +742,6 @@ class DeviceListUpdater(object):
|
|||||||
|
|
||||||
# We clobber the seen updates since we've re-synced from a given
|
# We clobber the seen updates since we've re-synced from a given
|
||||||
# point.
|
# point.
|
||||||
self._seen_updates[user_id] = set([stream_id])
|
self._seen_updates[user_id] = {stream_id}
|
||||||
|
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
@ -72,7 +72,7 @@ class DirectoryHandler(BaseHandler):
|
|||||||
# TODO(erikj): Check if there is a current association.
|
# TODO(erikj): Check if there is a current association.
|
||||||
if not servers:
|
if not servers:
|
||||||
users = yield self.state.get_current_users_in_room(room_id)
|
users = yield self.state.get_current_users_in_room(room_id)
|
||||||
servers = set(get_domain_from_id(u) for u in users)
|
servers = {get_domain_from_id(u) for u in users}
|
||||||
|
|
||||||
if not servers:
|
if not servers:
|
||||||
raise SynapseError(400, "Failed to get server list")
|
raise SynapseError(400, "Failed to get server list")
|
||||||
@ -255,7 +255,7 @@ class DirectoryHandler(BaseHandler):
|
|||||||
)
|
)
|
||||||
|
|
||||||
users = yield self.state.get_current_users_in_room(room_id)
|
users = yield self.state.get_current_users_in_room(room_id)
|
||||||
extra_servers = set(get_domain_from_id(u) for u in users)
|
extra_servers = {get_domain_from_id(u) for u in users}
|
||||||
servers = set(extra_servers) | set(servers)
|
servers = set(extra_servers) | set(servers)
|
||||||
|
|
||||||
# If this server is in the list of servers, return it first.
|
# If this server is in the list of servers, return it first.
|
||||||
|
@ -659,11 +659,11 @@ class FederationHandler(BaseHandler):
|
|||||||
# this can happen if a remote server claims that the state or
|
# this can happen if a remote server claims that the state or
|
||||||
# auth_events at an event in room A are actually events in room B
|
# auth_events at an event in room A are actually events in room B
|
||||||
|
|
||||||
bad_events = list(
|
bad_events = [
|
||||||
(event_id, event.room_id)
|
(event_id, event.room_id)
|
||||||
for event_id, event in fetched_events.items()
|
for event_id, event in fetched_events.items()
|
||||||
if event.room_id != room_id
|
if event.room_id != room_id
|
||||||
)
|
]
|
||||||
|
|
||||||
for bad_event_id, bad_room_id in bad_events:
|
for bad_event_id, bad_room_id in bad_events:
|
||||||
# This is a bogus situation, but since we may only discover it a long time
|
# This is a bogus situation, but since we may only discover it a long time
|
||||||
@ -856,7 +856,7 @@ class FederationHandler(BaseHandler):
|
|||||||
|
|
||||||
# Don't bother processing events we already have.
|
# Don't bother processing events we already have.
|
||||||
seen_events = await self.store.have_events_in_timeline(
|
seen_events = await self.store.have_events_in_timeline(
|
||||||
set(e.event_id for e in events)
|
{e.event_id for e in events}
|
||||||
)
|
)
|
||||||
|
|
||||||
events = [e for e in events if e.event_id not in seen_events]
|
events = [e for e in events if e.event_id not in seen_events]
|
||||||
@ -866,7 +866,7 @@ class FederationHandler(BaseHandler):
|
|||||||
|
|
||||||
event_map = {e.event_id: e for e in events}
|
event_map = {e.event_id: e for e in events}
|
||||||
|
|
||||||
event_ids = set(e.event_id for e in events)
|
event_ids = {e.event_id for e in events}
|
||||||
|
|
||||||
# build a list of events whose prev_events weren't in the batch.
|
# build a list of events whose prev_events weren't in the batch.
|
||||||
# (XXX: this will include events whose prev_events we already have; that doesn't
|
# (XXX: this will include events whose prev_events we already have; that doesn't
|
||||||
@ -892,13 +892,13 @@ class FederationHandler(BaseHandler):
|
|||||||
state_events.update({s.event_id: s for s in state})
|
state_events.update({s.event_id: s for s in state})
|
||||||
events_to_state[e_id] = state
|
events_to_state[e_id] = state
|
||||||
|
|
||||||
required_auth = set(
|
required_auth = {
|
||||||
a_id
|
a_id
|
||||||
for event in events
|
for event in events
|
||||||
+ list(state_events.values())
|
+ list(state_events.values())
|
||||||
+ list(auth_events.values())
|
+ list(auth_events.values())
|
||||||
for a_id in event.auth_event_ids()
|
for a_id in event.auth_event_ids()
|
||||||
)
|
}
|
||||||
auth_events.update(
|
auth_events.update(
|
||||||
{e_id: event_map[e_id] for e_id in required_auth if e_id in event_map}
|
{e_id: event_map[e_id] for e_id in required_auth if e_id in event_map}
|
||||||
)
|
)
|
||||||
@ -1247,7 +1247,7 @@ class FederationHandler(BaseHandler):
|
|||||||
async def on_event_auth(self, event_id: str) -> List[EventBase]:
|
async def on_event_auth(self, event_id: str) -> List[EventBase]:
|
||||||
event = await self.store.get_event(event_id)
|
event = await self.store.get_event(event_id)
|
||||||
auth = await self.store.get_auth_chain(
|
auth = await self.store.get_auth_chain(
|
||||||
[auth_id for auth_id in event.auth_event_ids()], include_given=True
|
list(event.auth_event_ids()), include_given=True
|
||||||
)
|
)
|
||||||
return list(auth)
|
return list(auth)
|
||||||
|
|
||||||
@ -2152,7 +2152,7 @@ class FederationHandler(BaseHandler):
|
|||||||
|
|
||||||
# Now get the current auth_chain for the event.
|
# Now get the current auth_chain for the event.
|
||||||
local_auth_chain = await self.store.get_auth_chain(
|
local_auth_chain = await self.store.get_auth_chain(
|
||||||
[auth_id for auth_id in event.auth_event_ids()], include_given=True
|
list(event.auth_event_ids()), include_given=True
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: Check if we would now reject event_id. If so we need to tell
|
# TODO: Check if we would now reject event_id. If so we need to tell
|
||||||
@ -2654,7 +2654,7 @@ class FederationHandler(BaseHandler):
|
|||||||
member_handler = self.hs.get_room_member_handler()
|
member_handler = self.hs.get_room_member_handler()
|
||||||
yield member_handler.send_membership_event(None, event, context)
|
yield member_handler.send_membership_event(None, event, context)
|
||||||
else:
|
else:
|
||||||
destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id))
|
destinations = {x.split(":", 1)[-1] for x in (sender_user_id, room_id)}
|
||||||
yield self.federation_client.forward_third_party_invite(
|
yield self.federation_client.forward_third_party_invite(
|
||||||
destinations, room_id, event_dict
|
destinations, room_id, event_dict
|
||||||
)
|
)
|
||||||
|
@ -313,7 +313,7 @@ class PresenceHandler(object):
|
|||||||
notified_presence_counter.inc(len(to_notify))
|
notified_presence_counter.inc(len(to_notify))
|
||||||
yield self._persist_and_notify(list(to_notify.values()))
|
yield self._persist_and_notify(list(to_notify.values()))
|
||||||
|
|
||||||
self.unpersisted_users_changes |= set(s.user_id for s in new_states)
|
self.unpersisted_users_changes |= {s.user_id for s in new_states}
|
||||||
self.unpersisted_users_changes -= set(to_notify.keys())
|
self.unpersisted_users_changes -= set(to_notify.keys())
|
||||||
|
|
||||||
to_federation_ping = {
|
to_federation_ping = {
|
||||||
@ -698,7 +698,7 @@ class PresenceHandler(object):
|
|||||||
updates = yield self.current_state_for_users(target_user_ids)
|
updates = yield self.current_state_for_users(target_user_ids)
|
||||||
updates = list(updates.values())
|
updates = list(updates.values())
|
||||||
|
|
||||||
for user_id in set(target_user_ids) - set(u.user_id for u in updates):
|
for user_id in set(target_user_ids) - {u.user_id for u in updates}:
|
||||||
updates.append(UserPresenceState.default(user_id))
|
updates.append(UserPresenceState.default(user_id))
|
||||||
|
|
||||||
now = self.clock.time_msec()
|
now = self.clock.time_msec()
|
||||||
@ -886,7 +886,7 @@ class PresenceHandler(object):
|
|||||||
hosts = yield self.state.get_current_hosts_in_room(room_id)
|
hosts = yield self.state.get_current_hosts_in_room(room_id)
|
||||||
|
|
||||||
# Filter out ourselves.
|
# Filter out ourselves.
|
||||||
hosts = set(host for host in hosts if host != self.server_name)
|
hosts = {host for host in hosts if host != self.server_name}
|
||||||
|
|
||||||
self.federation.send_presence_to_destinations(
|
self.federation.send_presence_to_destinations(
|
||||||
states=[state], destinations=hosts
|
states=[state], destinations=hosts
|
||||||
|
@ -94,7 +94,7 @@ class ReceiptsHandler(BaseHandler):
|
|||||||
# no new receipts
|
# no new receipts
|
||||||
return False
|
return False
|
||||||
|
|
||||||
affected_room_ids = list(set([r.room_id for r in receipts]))
|
affected_room_ids = list({r.room_id for r in receipts})
|
||||||
|
|
||||||
self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids)
|
self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids)
|
||||||
# Note that the min here shouldn't be relied upon to be accurate.
|
# Note that the min here shouldn't be relied upon to be accurate.
|
||||||
|
@ -355,7 +355,7 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
# If so, mark the new room as non-federatable as well
|
# If so, mark the new room as non-federatable as well
|
||||||
creation_content["m.federate"] = False
|
creation_content["m.federate"] = False
|
||||||
|
|
||||||
initial_state = dict()
|
initial_state = {}
|
||||||
|
|
||||||
# Replicate relevant room events
|
# Replicate relevant room events
|
||||||
types_to_copy = (
|
types_to_copy = (
|
||||||
|
@ -184,7 +184,7 @@ class SearchHandler(BaseHandler):
|
|||||||
membership_list=[Membership.JOIN],
|
membership_list=[Membership.JOIN],
|
||||||
# membership_list=[Membership.JOIN, Membership.LEAVE, Membership.Ban],
|
# membership_list=[Membership.JOIN, Membership.LEAVE, Membership.Ban],
|
||||||
)
|
)
|
||||||
room_ids = set(r.room_id for r in rooms)
|
room_ids = {r.room_id for r in rooms}
|
||||||
|
|
||||||
# If doing a subset of all rooms seearch, check if any of the rooms
|
# If doing a subset of all rooms seearch, check if any of the rooms
|
||||||
# are from an upgraded room, and search their contents as well
|
# are from an upgraded room, and search their contents as well
|
||||||
@ -374,12 +374,12 @@ class SearchHandler(BaseHandler):
|
|||||||
).to_string()
|
).to_string()
|
||||||
|
|
||||||
if include_profile:
|
if include_profile:
|
||||||
senders = set(
|
senders = {
|
||||||
ev.sender
|
ev.sender
|
||||||
for ev in itertools.chain(
|
for ev in itertools.chain(
|
||||||
res["events_before"], [event], res["events_after"]
|
res["events_before"], [event], res["events_after"]
|
||||||
)
|
)
|
||||||
)
|
}
|
||||||
|
|
||||||
if res["events_after"]:
|
if res["events_after"]:
|
||||||
last_event_id = res["events_after"][-1].event_id
|
last_event_id = res["events_after"][-1].event_id
|
||||||
@ -421,7 +421,7 @@ class SearchHandler(BaseHandler):
|
|||||||
|
|
||||||
state_results = {}
|
state_results = {}
|
||||||
if include_state:
|
if include_state:
|
||||||
rooms = set(e.room_id for e in allowed_events)
|
rooms = {e.room_id for e in allowed_events}
|
||||||
for room_id in rooms:
|
for room_id in rooms:
|
||||||
state = yield self.state_handler.get_current_state(room_id)
|
state = yield self.state_handler.get_current_state(room_id)
|
||||||
state_results[room_id] = list(state.values())
|
state_results[room_id] = list(state.values())
|
||||||
|
@ -682,11 +682,9 @@ class SyncHandler(object):
|
|||||||
|
|
||||||
# FIXME: order by stream ordering rather than as returned by SQL
|
# FIXME: order by stream ordering rather than as returned by SQL
|
||||||
if joined_user_ids or invited_user_ids:
|
if joined_user_ids or invited_user_ids:
|
||||||
summary["m.heroes"] = sorted(
|
summary["m.heroes"] = sorted(joined_user_ids + invited_user_ids)[0:5]
|
||||||
[user_id for user_id in (joined_user_ids + invited_user_ids)]
|
|
||||||
)[0:5]
|
|
||||||
else:
|
else:
|
||||||
summary["m.heroes"] = sorted([user_id for user_id in gone_user_ids])[0:5]
|
summary["m.heroes"] = sorted(gone_user_ids)[0:5]
|
||||||
|
|
||||||
if not sync_config.filter_collection.lazy_load_members():
|
if not sync_config.filter_collection.lazy_load_members():
|
||||||
return summary
|
return summary
|
||||||
@ -697,9 +695,9 @@ class SyncHandler(object):
|
|||||||
|
|
||||||
# track which members the client should already know about via LL:
|
# track which members the client should already know about via LL:
|
||||||
# Ones which are already in state...
|
# Ones which are already in state...
|
||||||
existing_members = set(
|
existing_members = {
|
||||||
user_id for (typ, user_id) in state.keys() if typ == EventTypes.Member
|
user_id for (typ, user_id) in state.keys() if typ == EventTypes.Member
|
||||||
)
|
}
|
||||||
|
|
||||||
# ...or ones which are in the timeline...
|
# ...or ones which are in the timeline...
|
||||||
for ev in batch.events:
|
for ev in batch.events:
|
||||||
@ -773,10 +771,10 @@ class SyncHandler(object):
|
|||||||
# We only request state for the members needed to display the
|
# We only request state for the members needed to display the
|
||||||
# timeline:
|
# timeline:
|
||||||
|
|
||||||
members_to_fetch = set(
|
members_to_fetch = {
|
||||||
event.sender # FIXME: we also care about invite targets etc.
|
event.sender # FIXME: we also care about invite targets etc.
|
||||||
for event in batch.events
|
for event in batch.events
|
||||||
)
|
}
|
||||||
|
|
||||||
if full_state:
|
if full_state:
|
||||||
# always make sure we LL ourselves so we know we're in the room
|
# always make sure we LL ourselves so we know we're in the room
|
||||||
@ -1993,10 +1991,10 @@ def _calculate_state(
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
c_ids = set(e for e in itervalues(current))
|
c_ids = set(itervalues(current))
|
||||||
ts_ids = set(e for e in itervalues(timeline_start))
|
ts_ids = set(itervalues(timeline_start))
|
||||||
p_ids = set(e for e in itervalues(previous))
|
p_ids = set(itervalues(previous))
|
||||||
tc_ids = set(e for e in itervalues(timeline_contains))
|
tc_ids = set(itervalues(timeline_contains))
|
||||||
|
|
||||||
# If we are lazyloading room members, we explicitly add the membership events
|
# If we are lazyloading room members, we explicitly add the membership events
|
||||||
# for the senders in the timeline into the state block returned by /sync,
|
# for the senders in the timeline into the state block returned by /sync,
|
||||||
|
@ -198,7 +198,7 @@ class TypingHandler(object):
|
|||||||
now=now, obj=member, then=now + FEDERATION_PING_INTERVAL
|
now=now, obj=member, then=now + FEDERATION_PING_INTERVAL
|
||||||
)
|
)
|
||||||
|
|
||||||
for domain in set(get_domain_from_id(u) for u in users):
|
for domain in {get_domain_from_id(u) for u in users}:
|
||||||
if domain != self.server_name:
|
if domain != self.server_name:
|
||||||
logger.debug("sending typing update to %s", domain)
|
logger.debug("sending typing update to %s", domain)
|
||||||
self.federation.build_and_send_edu(
|
self.federation.build_and_send_edu(
|
||||||
@ -231,7 +231,7 @@ class TypingHandler(object):
|
|||||||
return
|
return
|
||||||
|
|
||||||
users = yield self.state.get_current_users_in_room(room_id)
|
users = yield self.state.get_current_users_in_room(room_id)
|
||||||
domains = set(get_domain_from_id(u) for u in users)
|
domains = {get_domain_from_id(u) for u in users}
|
||||||
|
|
||||||
if self.server_name in domains:
|
if self.server_name in domains:
|
||||||
logger.info("Got typing update from %s: %r", user_id, content)
|
logger.info("Got typing update from %s: %r", user_id, content)
|
||||||
|
@ -148,7 +148,7 @@ def trace_function(f):
|
|||||||
pathname=pathname,
|
pathname=pathname,
|
||||||
lineno=lineno,
|
lineno=lineno,
|
||||||
msg=msg,
|
msg=msg,
|
||||||
args=tuple(),
|
args=(),
|
||||||
exc_info=None,
|
exc_info=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -240,7 +240,7 @@ class BucketCollector(object):
|
|||||||
res.append(["+Inf", sum(data.values())])
|
res.append(["+Inf", sum(data.values())])
|
||||||
|
|
||||||
metric = HistogramMetricFamily(
|
metric = HistogramMetricFamily(
|
||||||
self.name, "", buckets=res, sum_value=sum([x * y for x, y in data.items()])
|
self.name, "", buckets=res, sum_value=sum(x * y for x, y in data.items())
|
||||||
)
|
)
|
||||||
yield metric
|
yield metric
|
||||||
|
|
||||||
|
@ -80,13 +80,13 @@ _background_process_db_sched_duration = Counter(
|
|||||||
# map from description to a counter, so that we can name our logcontexts
|
# map from description to a counter, so that we can name our logcontexts
|
||||||
# incrementally. (It actually duplicates _background_process_start_count, but
|
# incrementally. (It actually duplicates _background_process_start_count, but
|
||||||
# it's much simpler to do so than to try to combine them.)
|
# it's much simpler to do so than to try to combine them.)
|
||||||
_background_process_counts = dict() # type: dict[str, int]
|
_background_process_counts = {} # type: dict[str, int]
|
||||||
|
|
||||||
# map from description to the currently running background processes.
|
# map from description to the currently running background processes.
|
||||||
#
|
#
|
||||||
# it's kept as a dict of sets rather than a big set so that we can keep track
|
# it's kept as a dict of sets rather than a big set so that we can keep track
|
||||||
# of process descriptions that no longer have any active processes.
|
# of process descriptions that no longer have any active processes.
|
||||||
_background_processes = dict() # type: dict[str, set[_BackgroundProcess]]
|
_background_processes = {} # type: dict[str, set[_BackgroundProcess]]
|
||||||
|
|
||||||
# A lock that covers the above dicts
|
# A lock that covers the above dicts
|
||||||
_bg_metrics_lock = threading.Lock()
|
_bg_metrics_lock = threading.Lock()
|
||||||
|
@ -400,11 +400,11 @@ class RulesForRoom(object):
|
|||||||
if logger.isEnabledFor(logging.DEBUG):
|
if logger.isEnabledFor(logging.DEBUG):
|
||||||
logger.debug("Found members %r: %r", self.room_id, members.values())
|
logger.debug("Found members %r: %r", self.room_id, members.values())
|
||||||
|
|
||||||
interested_in_user_ids = set(
|
interested_in_user_ids = {
|
||||||
user_id
|
user_id
|
||||||
for user_id, membership in itervalues(members)
|
for user_id, membership in itervalues(members)
|
||||||
if membership == Membership.JOIN
|
if membership == Membership.JOIN
|
||||||
)
|
}
|
||||||
|
|
||||||
logger.debug("Joined: %r", interested_in_user_ids)
|
logger.debug("Joined: %r", interested_in_user_ids)
|
||||||
|
|
||||||
@ -412,9 +412,9 @@ class RulesForRoom(object):
|
|||||||
interested_in_user_ids, on_invalidate=self.invalidate_all_cb
|
interested_in_user_ids, on_invalidate=self.invalidate_all_cb
|
||||||
)
|
)
|
||||||
|
|
||||||
user_ids = set(
|
user_ids = {
|
||||||
uid for uid, have_pusher in iteritems(if_users_with_pushers) if have_pusher
|
uid for uid, have_pusher in iteritems(if_users_with_pushers) if have_pusher
|
||||||
)
|
}
|
||||||
|
|
||||||
logger.debug("With pushers: %r", user_ids)
|
logger.debug("With pushers: %r", user_ids)
|
||||||
|
|
||||||
|
@ -204,7 +204,7 @@ class EmailPusher(object):
|
|||||||
yield self.send_notification(unprocessed, reason)
|
yield self.send_notification(unprocessed, reason)
|
||||||
|
|
||||||
yield self.save_last_stream_ordering_and_success(
|
yield self.save_last_stream_ordering_and_success(
|
||||||
max([ea["stream_ordering"] for ea in unprocessed])
|
max(ea["stream_ordering"] for ea in unprocessed)
|
||||||
)
|
)
|
||||||
|
|
||||||
# we update the throttle on all the possible unprocessed push actions
|
# we update the throttle on all the possible unprocessed push actions
|
||||||
|
@ -526,12 +526,10 @@ class Mailer(object):
|
|||||||
# If the room doesn't have a name, say who the messages
|
# If the room doesn't have a name, say who the messages
|
||||||
# are from explicitly to avoid, "messages in the Bob room"
|
# are from explicitly to avoid, "messages in the Bob room"
|
||||||
sender_ids = list(
|
sender_ids = list(
|
||||||
set(
|
{
|
||||||
[
|
|
||||||
notif_events[n["event_id"]].sender
|
notif_events[n["event_id"]].sender
|
||||||
for n in notifs_by_room[room_id]
|
for n in notifs_by_room[room_id]
|
||||||
]
|
}
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
member_events = yield self.store.get_events(
|
member_events = yield self.store.get_events(
|
||||||
@ -558,12 +556,10 @@ class Mailer(object):
|
|||||||
# If the reason room doesn't have a name, say who the messages
|
# If the reason room doesn't have a name, say who the messages
|
||||||
# are from explicitly to avoid, "messages in the Bob room"
|
# are from explicitly to avoid, "messages in the Bob room"
|
||||||
sender_ids = list(
|
sender_ids = list(
|
||||||
set(
|
{
|
||||||
[
|
|
||||||
notif_events[n["event_id"]].sender
|
notif_events[n["event_id"]].sender
|
||||||
for n in notifs_by_room[reason["room_id"]]
|
for n in notifs_by_room[reason["room_id"]]
|
||||||
]
|
}
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
member_events = yield self.store.get_events(
|
member_events = yield self.store.get_events(
|
||||||
|
@ -191,7 +191,7 @@ class PusherPool:
|
|||||||
min_stream_id - 1, max_stream_id
|
min_stream_id - 1, max_stream_id
|
||||||
)
|
)
|
||||||
# This returns a tuple, user_id is at index 3
|
# This returns a tuple, user_id is at index 3
|
||||||
users_affected = set([r[3] for r in updated_receipts])
|
users_affected = {r[3] for r in updated_receipts}
|
||||||
|
|
||||||
for u in users_affected:
|
for u in users_affected:
|
||||||
if u in self.pushers:
|
if u in self.pushers:
|
||||||
|
@ -29,7 +29,7 @@ def historical_admin_path_patterns(path_regex):
|
|||||||
Note that this should only be used for existing endpoints: new ones should just
|
Note that this should only be used for existing endpoints: new ones should just
|
||||||
register for the /_synapse/admin path.
|
register for the /_synapse/admin path.
|
||||||
"""
|
"""
|
||||||
return list(
|
return [
|
||||||
re.compile(prefix + path_regex)
|
re.compile(prefix + path_regex)
|
||||||
for prefix in (
|
for prefix in (
|
||||||
"^/_synapse/admin/v1",
|
"^/_synapse/admin/v1",
|
||||||
@ -37,7 +37,7 @@ def historical_admin_path_patterns(path_regex):
|
|||||||
"^/_matrix/client/unstable/admin",
|
"^/_matrix/client/unstable/admin",
|
||||||
"^/_matrix/client/r0/admin",
|
"^/_matrix/client/r0/admin",
|
||||||
)
|
)
|
||||||
)
|
]
|
||||||
|
|
||||||
|
|
||||||
def admin_patterns(path_regex: str):
|
def admin_patterns(path_regex: str):
|
||||||
|
@ -49,7 +49,7 @@ class PushRuleRestServlet(RestServlet):
|
|||||||
if self._is_worker:
|
if self._is_worker:
|
||||||
raise Exception("Cannot handle PUT /push_rules on worker")
|
raise Exception("Cannot handle PUT /push_rules on worker")
|
||||||
|
|
||||||
spec = _rule_spec_from_path([x for x in path.split("/")])
|
spec = _rule_spec_from_path(path.split("/"))
|
||||||
try:
|
try:
|
||||||
priority_class = _priority_class_from_spec(spec)
|
priority_class = _priority_class_from_spec(spec)
|
||||||
except InvalidRuleException as e:
|
except InvalidRuleException as e:
|
||||||
@ -110,7 +110,7 @@ class PushRuleRestServlet(RestServlet):
|
|||||||
if self._is_worker:
|
if self._is_worker:
|
||||||
raise Exception("Cannot handle DELETE /push_rules on worker")
|
raise Exception("Cannot handle DELETE /push_rules on worker")
|
||||||
|
|
||||||
spec = _rule_spec_from_path([x for x in path.split("/")])
|
spec = _rule_spec_from_path(path.split("/"))
|
||||||
|
|
||||||
requester = await self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
user_id = requester.user.to_string()
|
user_id = requester.user.to_string()
|
||||||
@ -138,7 +138,7 @@ class PushRuleRestServlet(RestServlet):
|
|||||||
|
|
||||||
rules = format_push_rules_for_user(requester.user, rules)
|
rules = format_push_rules_for_user(requester.user, rules)
|
||||||
|
|
||||||
path = [x for x in path.split("/")][1:]
|
path = path.split("/")[1:]
|
||||||
|
|
||||||
if path == []:
|
if path == []:
|
||||||
# we're a reference impl: pedantry is our job.
|
# we're a reference impl: pedantry is our job.
|
||||||
|
@ -54,9 +54,9 @@ class PushersRestServlet(RestServlet):
|
|||||||
|
|
||||||
pushers = await self.hs.get_datastore().get_pushers_by_user_id(user.to_string())
|
pushers = await self.hs.get_datastore().get_pushers_by_user_id(user.to_string())
|
||||||
|
|
||||||
filtered_pushers = list(
|
filtered_pushers = [
|
||||||
{k: v for k, v in p.items() if k in ALLOWED_KEYS} for p in pushers
|
{k: v for k, v in p.items() if k in ALLOWED_KEYS} for p in pushers
|
||||||
)
|
]
|
||||||
|
|
||||||
return 200, {"pushers": filtered_pushers}
|
return 200, {"pushers": filtered_pushers}
|
||||||
|
|
||||||
|
@ -72,7 +72,7 @@ class SyncRestServlet(RestServlet):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
PATTERNS = client_patterns("/sync$")
|
PATTERNS = client_patterns("/sync$")
|
||||||
ALLOWED_PRESENCE = set(["online", "offline", "unavailable"])
|
ALLOWED_PRESENCE = {"online", "offline", "unavailable"}
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(SyncRestServlet, self).__init__()
|
super(SyncRestServlet, self).__init__()
|
||||||
|
@ -149,7 +149,7 @@ class RemoteKey(DirectServeResource):
|
|||||||
|
|
||||||
time_now_ms = self.clock.time_msec()
|
time_now_ms = self.clock.time_msec()
|
||||||
|
|
||||||
cache_misses = dict() # type: Dict[str, Set[str]]
|
cache_misses = {} # type: Dict[str, Set[str]]
|
||||||
for (server_name, key_id, from_server), results in cached.items():
|
for (server_name, key_id, from_server), results in cached.items():
|
||||||
results = [(result["ts_added_ms"], result) for result in results]
|
results = [(result["ts_added_ms"], result) for result in results]
|
||||||
|
|
||||||
|
@ -135,8 +135,7 @@ def add_file_headers(request, media_type, file_size, upload_name):
|
|||||||
|
|
||||||
# separators as defined in RFC2616. SP and HT are handled separately.
|
# separators as defined in RFC2616. SP and HT are handled separately.
|
||||||
# see _can_encode_filename_as_token.
|
# see _can_encode_filename_as_token.
|
||||||
_FILENAME_SEPARATOR_CHARS = set(
|
_FILENAME_SEPARATOR_CHARS = {
|
||||||
(
|
|
||||||
"(",
|
"(",
|
||||||
")",
|
")",
|
||||||
"<",
|
"<",
|
||||||
@ -154,8 +153,7 @@ _FILENAME_SEPARATOR_CHARS = set(
|
|||||||
"=",
|
"=",
|
||||||
"{",
|
"{",
|
||||||
"}",
|
"}",
|
||||||
)
|
}
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _can_encode_filename_as_token(x):
|
def _can_encode_filename_as_token(x):
|
||||||
|
@ -69,9 +69,9 @@ def resolve_events_with_store(
|
|||||||
|
|
||||||
unconflicted_state, conflicted_state = _seperate(state_sets)
|
unconflicted_state, conflicted_state = _seperate(state_sets)
|
||||||
|
|
||||||
needed_events = set(
|
needed_events = {
|
||||||
event_id for event_ids in itervalues(conflicted_state) for event_id in event_ids
|
event_id for event_ids in itervalues(conflicted_state) for event_id in event_ids
|
||||||
)
|
}
|
||||||
needed_event_count = len(needed_events)
|
needed_event_count = len(needed_events)
|
||||||
if event_map is not None:
|
if event_map is not None:
|
||||||
needed_events -= set(iterkeys(event_map))
|
needed_events -= set(iterkeys(event_map))
|
||||||
@ -261,11 +261,11 @@ def _resolve_state_events(conflicted_state, auth_events):
|
|||||||
|
|
||||||
|
|
||||||
def _resolve_auth_events(events, auth_events):
|
def _resolve_auth_events(events, auth_events):
|
||||||
reverse = [i for i in reversed(_ordered_events(events))]
|
reverse = list(reversed(_ordered_events(events)))
|
||||||
|
|
||||||
auth_keys = set(
|
auth_keys = {
|
||||||
key for event in events for key in event_auth.auth_types_for_event(event)
|
key for event in events for key in event_auth.auth_types_for_event(event)
|
||||||
)
|
}
|
||||||
|
|
||||||
new_auth_events = {}
|
new_auth_events = {}
|
||||||
for key in auth_keys:
|
for key in auth_keys:
|
||||||
|
@ -105,7 +105,7 @@ def resolve_events_with_store(
|
|||||||
% (room_id, event.event_id, event.room_id,)
|
% (room_id, event.event_id, event.room_id,)
|
||||||
)
|
)
|
||||||
|
|
||||||
full_conflicted_set = set(eid for eid in full_conflicted_set if eid in event_map)
|
full_conflicted_set = {eid for eid in full_conflicted_set if eid in event_map}
|
||||||
|
|
||||||
logger.debug("%d full_conflicted_set entries", len(full_conflicted_set))
|
logger.debug("%d full_conflicted_set entries", len(full_conflicted_set))
|
||||||
|
|
||||||
@ -233,7 +233,7 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store):
|
|||||||
|
|
||||||
auth_sets = []
|
auth_sets = []
|
||||||
for state_set in state_sets:
|
for state_set in state_sets:
|
||||||
auth_ids = set(
|
auth_ids = {
|
||||||
eid
|
eid
|
||||||
for key, eid in iteritems(state_set)
|
for key, eid in iteritems(state_set)
|
||||||
if (
|
if (
|
||||||
@ -246,7 +246,7 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
and eid not in common
|
and eid not in common
|
||||||
)
|
}
|
||||||
|
|
||||||
auth_chain = yield state_res_store.get_auth_chain(auth_ids, common)
|
auth_chain = yield state_res_store.get_auth_chain(auth_ids, common)
|
||||||
auth_ids.update(auth_chain)
|
auth_ids.update(auth_chain)
|
||||||
@ -275,7 +275,7 @@ def _seperate(state_sets):
|
|||||||
conflicted_state = {}
|
conflicted_state = {}
|
||||||
|
|
||||||
for key in set(itertools.chain.from_iterable(state_sets)):
|
for key in set(itertools.chain.from_iterable(state_sets)):
|
||||||
event_ids = set(state_set.get(key) for state_set in state_sets)
|
event_ids = {state_set.get(key) for state_set in state_sets}
|
||||||
if len(event_ids) == 1:
|
if len(event_ids) == 1:
|
||||||
unconflicted_state[key] = event_ids.pop()
|
unconflicted_state[key] = event_ids.pop()
|
||||||
else:
|
else:
|
||||||
|
@ -56,7 +56,7 @@ class SQLBaseStore(metaclass=ABCMeta):
|
|||||||
members_changed (iterable[str]): The user_ids of members that have
|
members_changed (iterable[str]): The user_ids of members that have
|
||||||
changed
|
changed
|
||||||
"""
|
"""
|
||||||
for host in set(get_domain_from_id(u) for u in members_changed):
|
for host in {get_domain_from_id(u) for u in members_changed}:
|
||||||
self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))
|
self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))
|
||||||
self._attempt_to_invalidate_cache("was_host_joined", (room_id, host))
|
self._attempt_to_invalidate_cache("was_host_joined", (room_id, host))
|
||||||
|
|
||||||
|
@ -189,7 +189,7 @@ class BackgroundUpdater(object):
|
|||||||
keyvalues=None,
|
keyvalues=None,
|
||||||
retcols=("update_name", "depends_on"),
|
retcols=("update_name", "depends_on"),
|
||||||
)
|
)
|
||||||
in_flight = set(update["update_name"] for update in updates)
|
in_flight = {update["update_name"] for update in updates}
|
||||||
for update in updates:
|
for update in updates:
|
||||||
if update["depends_on"] not in in_flight:
|
if update["depends_on"] not in in_flight:
|
||||||
self._background_update_queue.append(update["update_name"])
|
self._background_update_queue.append(update["update_name"])
|
||||||
|
@ -135,7 +135,7 @@ class ApplicationServiceTransactionWorkerStore(
|
|||||||
may be empty.
|
may be empty.
|
||||||
"""
|
"""
|
||||||
results = yield self.db.simple_select_list(
|
results = yield self.db.simple_select_list(
|
||||||
"application_services_state", dict(state=state), ["as_id"]
|
"application_services_state", {"state": state}, ["as_id"]
|
||||||
)
|
)
|
||||||
# NB: This assumes this class is linked with ApplicationServiceStore
|
# NB: This assumes this class is linked with ApplicationServiceStore
|
||||||
as_list = self.get_app_services()
|
as_list = self.get_app_services()
|
||||||
@ -158,7 +158,7 @@ class ApplicationServiceTransactionWorkerStore(
|
|||||||
"""
|
"""
|
||||||
result = yield self.db.simple_select_one(
|
result = yield self.db.simple_select_one(
|
||||||
"application_services_state",
|
"application_services_state",
|
||||||
dict(as_id=service.id),
|
{"as_id": service.id},
|
||||||
["state"],
|
["state"],
|
||||||
allow_none=True,
|
allow_none=True,
|
||||||
desc="get_appservice_state",
|
desc="get_appservice_state",
|
||||||
@ -177,7 +177,7 @@ class ApplicationServiceTransactionWorkerStore(
|
|||||||
A Deferred which resolves when the state was set successfully.
|
A Deferred which resolves when the state was set successfully.
|
||||||
"""
|
"""
|
||||||
return self.db.simple_upsert(
|
return self.db.simple_upsert(
|
||||||
"application_services_state", dict(as_id=service.id), dict(state=state)
|
"application_services_state", {"as_id": service.id}, {"state": state}
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_appservice_txn(self, service, events):
|
def create_appservice_txn(self, service, events):
|
||||||
@ -253,13 +253,15 @@ class ApplicationServiceTransactionWorkerStore(
|
|||||||
self.db.simple_upsert_txn(
|
self.db.simple_upsert_txn(
|
||||||
txn,
|
txn,
|
||||||
"application_services_state",
|
"application_services_state",
|
||||||
dict(as_id=service.id),
|
{"as_id": service.id},
|
||||||
dict(last_txn=txn_id),
|
{"last_txn": txn_id},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Delete txn
|
# Delete txn
|
||||||
self.db.simple_delete_txn(
|
self.db.simple_delete_txn(
|
||||||
txn, "application_services_txns", dict(txn_id=txn_id, as_id=service.id)
|
txn,
|
||||||
|
"application_services_txns",
|
||||||
|
{"txn_id": txn_id, "as_id": service.id},
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.db.runInteraction(
|
return self.db.runInteraction(
|
||||||
|
@ -530,7 +530,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
|
|||||||
((row["access_token"], row["ip"]), (row["user_agent"], row["last_seen"]))
|
((row["access_token"], row["ip"]), (row["user_agent"], row["last_seen"]))
|
||||||
for row in rows
|
for row in rows
|
||||||
)
|
)
|
||||||
return list(
|
return [
|
||||||
{
|
{
|
||||||
"access_token": access_token,
|
"access_token": access_token,
|
||||||
"ip": ip,
|
"ip": ip,
|
||||||
@ -538,7 +538,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
|
|||||||
"last_seen": last_seen,
|
"last_seen": last_seen,
|
||||||
}
|
}
|
||||||
for (access_token, ip), (user_agent, last_seen) in iteritems(results)
|
for (access_token, ip), (user_agent, last_seen) in iteritems(results)
|
||||||
)
|
]
|
||||||
|
|
||||||
@wrap_as_background_process("prune_old_user_ips")
|
@wrap_as_background_process("prune_old_user_ips")
|
||||||
async def _prune_old_user_ips(self):
|
async def _prune_old_user_ips(self):
|
||||||
|
@ -137,7 +137,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
|
|
||||||
# get the cross-signing keys of the users in the list, so that we can
|
# get the cross-signing keys of the users in the list, so that we can
|
||||||
# determine which of the device changes were cross-signing keys
|
# determine which of the device changes were cross-signing keys
|
||||||
users = set(r[0] for r in updates)
|
users = {r[0] for r in updates}
|
||||||
master_key_by_user = {}
|
master_key_by_user = {}
|
||||||
self_signing_key_by_user = {}
|
self_signing_key_by_user = {}
|
||||||
for user in users:
|
for user in users:
|
||||||
@ -446,7 +446,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
a set of user_ids and results_map is a mapping of
|
a set of user_ids and results_map is a mapping of
|
||||||
user_id -> device_id -> device_info
|
user_id -> device_id -> device_info
|
||||||
"""
|
"""
|
||||||
user_ids = set(user_id for user_id, _ in query_list)
|
user_ids = {user_id for user_id, _ in query_list}
|
||||||
user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids))
|
user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids))
|
||||||
|
|
||||||
# We go and check if any of the users need to have their device lists
|
# We go and check if any of the users need to have their device lists
|
||||||
@ -454,10 +454,9 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
users_needing_resync = yield self.get_user_ids_requiring_device_list_resync(
|
users_needing_resync = yield self.get_user_ids_requiring_device_list_resync(
|
||||||
user_ids
|
user_ids
|
||||||
)
|
)
|
||||||
user_ids_in_cache = (
|
user_ids_in_cache = {
|
||||||
set(user_id for user_id, stream_id in user_map.items() if stream_id)
|
user_id for user_id, stream_id in user_map.items() if stream_id
|
||||||
- users_needing_resync
|
} - users_needing_resync
|
||||||
)
|
|
||||||
user_ids_not_in_cache = user_ids - user_ids_in_cache
|
user_ids_not_in_cache = user_ids - user_ids_in_cache
|
||||||
|
|
||||||
results = {}
|
results = {}
|
||||||
@ -604,7 +603,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
rows = yield self.db.execute(
|
rows = yield self.db.execute(
|
||||||
"get_users_whose_signatures_changed", None, sql, user_id, from_key
|
"get_users_whose_signatures_changed", None, sql, user_id, from_key
|
||||||
)
|
)
|
||||||
return set(user for row in rows for user in json.loads(row[0]))
|
return {user for row in rows for user in json.loads(row[0])}
|
||||||
else:
|
else:
|
||||||
return set()
|
return set()
|
||||||
|
|
||||||
|
@ -426,7 +426,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
|||||||
query, (room_id, event_id, False, limit - len(event_results))
|
query, (room_id, event_id, False, limit - len(event_results))
|
||||||
)
|
)
|
||||||
|
|
||||||
new_results = set(t[0] for t in txn) - seen_events
|
new_results = {t[0] for t in txn} - seen_events
|
||||||
|
|
||||||
new_front |= new_results
|
new_front |= new_results
|
||||||
seen_events |= new_results
|
seen_events |= new_results
|
||||||
|
@ -145,7 +145,7 @@ class EventsStore(
|
|||||||
return txn.fetchall()
|
return txn.fetchall()
|
||||||
|
|
||||||
res = yield self.db.runInteraction("read_forward_extremities", fetch)
|
res = yield self.db.runInteraction("read_forward_extremities", fetch)
|
||||||
self._current_forward_extremities_amount = c_counter(list(x[0] for x in res))
|
self._current_forward_extremities_amount = c_counter([x[0] for x in res])
|
||||||
|
|
||||||
@_retry_on_integrity_error
|
@_retry_on_integrity_error
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@ -598,11 +598,11 @@ class EventsStore(
|
|||||||
# We find out which membership events we may have deleted
|
# We find out which membership events we may have deleted
|
||||||
# and which we have added, then we invlidate the caches for all
|
# and which we have added, then we invlidate the caches for all
|
||||||
# those users.
|
# those users.
|
||||||
members_changed = set(
|
members_changed = {
|
||||||
state_key
|
state_key
|
||||||
for ev_type, state_key in itertools.chain(to_delete, to_insert)
|
for ev_type, state_key in itertools.chain(to_delete, to_insert)
|
||||||
if ev_type == EventTypes.Member
|
if ev_type == EventTypes.Member
|
||||||
)
|
}
|
||||||
|
|
||||||
for member in members_changed:
|
for member in members_changed:
|
||||||
txn.call_after(
|
txn.call_after(
|
||||||
@ -1615,7 +1615,7 @@ class EventsStore(
|
|||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
referenced_state_groups = set(sg for sg, in txn)
|
referenced_state_groups = {sg for sg, in txn}
|
||||||
logger.info(
|
logger.info(
|
||||||
"[purge] found %i referenced state groups", len(referenced_state_groups)
|
"[purge] found %i referenced state groups", len(referenced_state_groups)
|
||||||
)
|
)
|
||||||
|
@ -402,7 +402,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
|||||||
keyvalues={},
|
keyvalues={},
|
||||||
retcols=("room_id",),
|
retcols=("room_id",),
|
||||||
)
|
)
|
||||||
room_ids = set(row["room_id"] for row in rows)
|
room_ids = {row["room_id"] for row in rows}
|
||||||
for room_id in room_ids:
|
for room_id in room_ids:
|
||||||
txn.call_after(
|
txn.call_after(
|
||||||
self.get_latest_event_ids_in_room.invalidate, (room_id,)
|
self.get_latest_event_ids_in_room.invalidate, (room_id,)
|
||||||
|
@ -494,9 +494,9 @@ class EventsWorkerStore(SQLBaseStore):
|
|||||||
"""
|
"""
|
||||||
with Measure(self._clock, "_fetch_event_list"):
|
with Measure(self._clock, "_fetch_event_list"):
|
||||||
try:
|
try:
|
||||||
events_to_fetch = set(
|
events_to_fetch = {
|
||||||
event_id for events, _ in event_list for event_id in events
|
event_id for events, _ in event_list for event_id in events
|
||||||
)
|
}
|
||||||
|
|
||||||
row_dict = self.db.new_transaction(
|
row_dict = self.db.new_transaction(
|
||||||
conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch
|
conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch
|
||||||
@ -804,7 +804,7 @@ class EventsWorkerStore(SQLBaseStore):
|
|||||||
desc="have_events_in_timeline",
|
desc="have_events_in_timeline",
|
||||||
)
|
)
|
||||||
|
|
||||||
return set(r["event_id"] for r in rows)
|
return {r["event_id"] for r in rows}
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def have_seen_events(self, event_ids):
|
def have_seen_events(self, event_ids):
|
||||||
|
@ -276,21 +276,21 @@ class PushRulesWorkerStore(
|
|||||||
# We ignore app service users for now. This is so that we don't fill
|
# We ignore app service users for now. This is so that we don't fill
|
||||||
# up the `get_if_users_have_pushers` cache with AS entries that we
|
# up the `get_if_users_have_pushers` cache with AS entries that we
|
||||||
# know don't have pushers, nor even read receipts.
|
# know don't have pushers, nor even read receipts.
|
||||||
local_users_in_room = set(
|
local_users_in_room = {
|
||||||
u
|
u
|
||||||
for u in users_in_room
|
for u in users_in_room
|
||||||
if self.hs.is_mine_id(u)
|
if self.hs.is_mine_id(u)
|
||||||
and not self.get_if_app_services_interested_in_user(u)
|
and not self.get_if_app_services_interested_in_user(u)
|
||||||
)
|
}
|
||||||
|
|
||||||
# users in the room who have pushers need to get push rules run because
|
# users in the room who have pushers need to get push rules run because
|
||||||
# that's how their pushers work
|
# that's how their pushers work
|
||||||
if_users_with_pushers = yield self.get_if_users_have_pushers(
|
if_users_with_pushers = yield self.get_if_users_have_pushers(
|
||||||
local_users_in_room, on_invalidate=cache_context.invalidate
|
local_users_in_room, on_invalidate=cache_context.invalidate
|
||||||
)
|
)
|
||||||
user_ids = set(
|
user_ids = {
|
||||||
uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
|
uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
|
||||||
)
|
}
|
||||||
|
|
||||||
users_with_receipts = yield self.get_users_with_read_receipts_in_room(
|
users_with_receipts = yield self.get_users_with_read_receipts_in_room(
|
||||||
room_id, on_invalidate=cache_context.invalidate
|
room_id, on_invalidate=cache_context.invalidate
|
||||||
|
@ -58,7 +58,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
|||||||
@cachedInlineCallbacks()
|
@cachedInlineCallbacks()
|
||||||
def get_users_with_read_receipts_in_room(self, room_id):
|
def get_users_with_read_receipts_in_room(self, room_id):
|
||||||
receipts = yield self.get_receipts_for_room(room_id, "m.read")
|
receipts = yield self.get_receipts_for_room(room_id, "m.read")
|
||||||
return set(r["user_id"] for r in receipts)
|
return {r["user_id"] for r in receipts}
|
||||||
|
|
||||||
@cached(num_args=2)
|
@cached(num_args=2)
|
||||||
def get_receipts_for_room(self, room_id, receipt_type):
|
def get_receipts_for_room(self, room_id, receipt_type):
|
||||||
@ -283,7 +283,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
|||||||
args.append(limit)
|
args.append(limit)
|
||||||
txn.execute(sql, args)
|
txn.execute(sql, args)
|
||||||
|
|
||||||
return list(r[0:5] + (json.loads(r[5]),) for r in txn)
|
return [r[0:5] + (json.loads(r[5]),) for r in txn]
|
||||||
|
|
||||||
return self.db.runInteraction(
|
return self.db.runInteraction(
|
||||||
"get_all_updated_receipts", get_all_updated_receipts_txn
|
"get_all_updated_receipts", get_all_updated_receipts_txn
|
||||||
|
@ -465,7 +465,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||||||
|
|
||||||
txn.execute(sql % (clause,), args)
|
txn.execute(sql % (clause,), args)
|
||||||
|
|
||||||
return set(row[0] for row in txn)
|
return {row[0] for row in txn}
|
||||||
|
|
||||||
return await self.db.runInteraction(
|
return await self.db.runInteraction(
|
||||||
"get_users_server_still_shares_room_with",
|
"get_users_server_still_shares_room_with",
|
||||||
@ -826,7 +826,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||||||
GROUP BY room_id, user_id;
|
GROUP BY room_id, user_id;
|
||||||
"""
|
"""
|
||||||
txn.execute(sql, (user_id,))
|
txn.execute(sql, (user_id,))
|
||||||
return set(row[0] for row in txn if row[1] == 0)
|
return {row[0] for row in txn if row[1] == 0}
|
||||||
|
|
||||||
return self.db.runInteraction(
|
return self.db.runInteraction(
|
||||||
"get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn
|
"get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn
|
||||||
|
@ -321,7 +321,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||||||
desc="get_referenced_state_groups",
|
desc="get_referenced_state_groups",
|
||||||
)
|
)
|
||||||
|
|
||||||
return set(row["state_group"] for row in rows)
|
return {row["state_group"] for row in rows}
|
||||||
|
|
||||||
|
|
||||||
class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
|
class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
|
||||||
@ -367,7 +367,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
txn.execute(sql, (last_room_id, batch_size))
|
txn.execute(sql, (last_room_id, batch_size))
|
||||||
room_ids = list(row[0] for row in txn)
|
room_ids = [row[0] for row in txn]
|
||||||
if not room_ids:
|
if not room_ids:
|
||||||
return True, set()
|
return True, set()
|
||||||
|
|
||||||
@ -384,7 +384,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
|
|||||||
|
|
||||||
txn.execute(sql, (last_room_id, room_ids[-1], "%:" + self.server_name))
|
txn.execute(sql, (last_room_id, room_ids[-1], "%:" + self.server_name))
|
||||||
|
|
||||||
joined_room_ids = set(row[0] for row in txn)
|
joined_room_ids = {row[0] for row in txn}
|
||||||
|
|
||||||
left_rooms = set(room_ids) - joined_room_ids
|
left_rooms = set(room_ids) - joined_room_ids
|
||||||
|
|
||||||
@ -404,7 +404,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
|
|||||||
retcols=("state_key",),
|
retcols=("state_key",),
|
||||||
)
|
)
|
||||||
|
|
||||||
potentially_left_users = set(row["state_key"] for row in rows)
|
potentially_left_users = {row["state_key"] for row in rows}
|
||||||
|
|
||||||
# Now lets actually delete the rooms from the DB.
|
# Now lets actually delete the rooms from the DB.
|
||||||
self.db.simple_delete_many_txn(
|
self.db.simple_delete_many_txn(
|
||||||
|
@ -346,11 +346,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||||||
from_key (str): The room_key portion of a StreamToken
|
from_key (str): The room_key portion of a StreamToken
|
||||||
"""
|
"""
|
||||||
from_key = RoomStreamToken.parse_stream_token(from_key).stream
|
from_key = RoomStreamToken.parse_stream_token(from_key).stream
|
||||||
return set(
|
return {
|
||||||
room_id
|
room_id
|
||||||
for room_id in room_ids
|
for room_id in room_ids
|
||||||
if self._events_stream_cache.has_entity_changed(room_id, from_key)
|
if self._events_stream_cache.has_entity_changed(room_id, from_key)
|
||||||
)
|
}
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_room_events_stream_for_room(
|
def get_room_events_stream_for_room(
|
||||||
@ -679,11 +679,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||||||
)
|
)
|
||||||
|
|
||||||
events_before = yield self.get_events_as_list(
|
events_before = yield self.get_events_as_list(
|
||||||
[e for e in results["before"]["event_ids"]], get_prev_content=True
|
list(results["before"]["event_ids"]), get_prev_content=True
|
||||||
)
|
)
|
||||||
|
|
||||||
events_after = yield self.get_events_as_list(
|
events_after = yield self.get_events_as_list(
|
||||||
[e for e in results["after"]["event_ids"]], get_prev_content=True
|
list(results["after"]["event_ids"]), get_prev_content=True
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
@ -63,9 +63,9 @@ class UserErasureWorkerStore(SQLBaseStore):
|
|||||||
retcols=("user_id",),
|
retcols=("user_id",),
|
||||||
desc="are_users_erased",
|
desc="are_users_erased",
|
||||||
)
|
)
|
||||||
erased_users = set(row["user_id"] for row in rows)
|
erased_users = {row["user_id"] for row in rows}
|
||||||
|
|
||||||
res = dict((u, u in erased_users) for u in user_ids)
|
res = {u: u in erased_users for u in user_ids}
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
@ -520,11 +520,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
|||||||
retcols=("state_group",),
|
retcols=("state_group",),
|
||||||
)
|
)
|
||||||
|
|
||||||
remaining_state_groups = set(
|
remaining_state_groups = {
|
||||||
row["state_group"]
|
row["state_group"]
|
||||||
for row in rows
|
for row in rows
|
||||||
if row["state_group"] not in state_groups_to_delete
|
if row["state_group"] not in state_groups_to_delete
|
||||||
)
|
}
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"[purge] de-delta-ing %i remaining state groups",
|
"[purge] de-delta-ing %i remaining state groups",
|
||||||
|
@ -554,8 +554,8 @@ class Database(object):
|
|||||||
Returns:
|
Returns:
|
||||||
A list of dicts where the key is the column header.
|
A list of dicts where the key is the column header.
|
||||||
"""
|
"""
|
||||||
col_headers = list(intern(str(column[0])) for column in cursor.description)
|
col_headers = [intern(str(column[0])) for column in cursor.description]
|
||||||
results = list(dict(zip(col_headers, row)) for row in cursor)
|
results = [dict(zip(col_headers, row)) for row in cursor]
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def execute(self, desc, decoder, query, *args):
|
def execute(self, desc, decoder, query, *args):
|
||||||
|
@ -602,14 +602,14 @@ class EventsPersistenceStorage(object):
|
|||||||
event_id_to_state_group.update(event_to_groups)
|
event_id_to_state_group.update(event_to_groups)
|
||||||
|
|
||||||
# State groups of old_latest_event_ids
|
# State groups of old_latest_event_ids
|
||||||
old_state_groups = set(
|
old_state_groups = {
|
||||||
event_id_to_state_group[evid] for evid in old_latest_event_ids
|
event_id_to_state_group[evid] for evid in old_latest_event_ids
|
||||||
)
|
}
|
||||||
|
|
||||||
# State groups of new_latest_event_ids
|
# State groups of new_latest_event_ids
|
||||||
new_state_groups = set(
|
new_state_groups = {
|
||||||
event_id_to_state_group[evid] for evid in new_latest_event_ids
|
event_id_to_state_group[evid] for evid in new_latest_event_ids
|
||||||
)
|
}
|
||||||
|
|
||||||
# If they old and new groups are the same then we don't need to do
|
# If they old and new groups are the same then we don't need to do
|
||||||
# anything.
|
# anything.
|
||||||
|
@ -345,9 +345,9 @@ def _upgrade_existing_database(
|
|||||||
"Could not open delta dir for version %d: %s" % (v, directory)
|
"Could not open delta dir for version %d: %s" % (v, directory)
|
||||||
)
|
)
|
||||||
|
|
||||||
duplicates = set(
|
duplicates = {
|
||||||
file_name for file_name, count in file_name_counter.items() if count > 1
|
file_name for file_name, count in file_name_counter.items() if count > 1
|
||||||
)
|
}
|
||||||
if duplicates:
|
if duplicates:
|
||||||
# We don't support using the same file name in the same delta version.
|
# We don't support using the same file name in the same delta version.
|
||||||
raise PrepareDatabaseException(
|
raise PrepareDatabaseException(
|
||||||
@ -454,7 +454,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
|
|||||||
),
|
),
|
||||||
(modname,),
|
(modname,),
|
||||||
)
|
)
|
||||||
applied_deltas = set(d for d, in cur)
|
applied_deltas = {d for d, in cur}
|
||||||
for (name, stream) in names_and_streams:
|
for (name, stream) in names_and_streams:
|
||||||
if name in applied_deltas:
|
if name in applied_deltas:
|
||||||
continue
|
continue
|
||||||
|
@ -30,7 +30,7 @@ def freeze(o):
|
|||||||
return o
|
return o
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return tuple([freeze(i) for i in o])
|
return tuple(freeze(i) for i in o)
|
||||||
except TypeError:
|
except TypeError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -75,7 +75,7 @@ def filter_events_for_client(
|
|||||||
"""
|
"""
|
||||||
# Filter out events that have been soft failed so that we don't relay them
|
# Filter out events that have been soft failed so that we don't relay them
|
||||||
# to clients.
|
# to clients.
|
||||||
events = list(e for e in events if not e.internal_metadata.is_soft_failed())
|
events = [e for e in events if not e.internal_metadata.is_soft_failed()]
|
||||||
|
|
||||||
types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id))
|
types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id))
|
||||||
event_id_to_state = yield storage.state.get_state_for_events(
|
event_id_to_state = yield storage.state.get_state_for_events(
|
||||||
@ -97,7 +97,7 @@ def filter_events_for_client(
|
|||||||
erased_senders = yield storage.main.are_users_erased((e.sender for e in events))
|
erased_senders = yield storage.main.are_users_erased((e.sender for e in events))
|
||||||
|
|
||||||
if apply_retention_policies:
|
if apply_retention_policies:
|
||||||
room_ids = set(e.room_id for e in events)
|
room_ids = {e.room_id for e in events}
|
||||||
retention_policies = {}
|
retention_policies = {}
|
||||||
|
|
||||||
for room_id in room_ids:
|
for room_id in room_ids:
|
||||||
|
@ -48,7 +48,7 @@ class ConfigGenerationTestCase(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.assertSetEqual(
|
self.assertSetEqual(
|
||||||
set(["homeserver.yaml", "lemurs.win.log.config", "lemurs.win.signing.key"]),
|
{"homeserver.yaml", "lemurs.win.log.config", "lemurs.win.signing.key"},
|
||||||
set(os.listdir(self.dir)),
|
set(os.listdir(self.dir)),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -83,7 +83,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(members, set(["@user:other.example.com", u1]))
|
self.assertEqual(members, {"@user:other.example.com", u1})
|
||||||
self.assertEqual(len(channel.json_body["pdus"]), 6)
|
self.assertEqual(len(channel.json_body["pdus"]), 6)
|
||||||
|
|
||||||
def test_needs_to_be_in_room(self):
|
def test_needs_to_be_in_room(self):
|
||||||
|
@ -338,7 +338,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
new_state = handle_timeout(
|
new_state = handle_timeout(
|
||||||
state, is_mine=True, syncing_user_ids=set([user_id]), now=now
|
state, is_mine=True, syncing_user_ids={user_id}, now=now
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertIsNotNone(new_state)
|
self.assertIsNotNone(new_state)
|
||||||
@ -579,7 +579,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
|
|||||||
)
|
)
|
||||||
self.assertEqual(expected_state.state, PresenceState.ONLINE)
|
self.assertEqual(expected_state.state, PresenceState.ONLINE)
|
||||||
self.federation_sender.send_presence_to_destinations.assert_called_once_with(
|
self.federation_sender.send_presence_to_destinations.assert_called_once_with(
|
||||||
destinations=set(("server2", "server3")), states=[expected_state]
|
destinations={"server2", "server3"}, states=[expected_state]
|
||||||
)
|
)
|
||||||
|
|
||||||
def _add_new_user(self, room_id, user_id):
|
def _add_new_user(self, room_id, user_id):
|
||||||
|
@ -129,12 +129,12 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
|||||||
hs.get_auth().check_user_in_room = check_user_in_room
|
hs.get_auth().check_user_in_room = check_user_in_room
|
||||||
|
|
||||||
def get_joined_hosts_for_room(room_id):
|
def get_joined_hosts_for_room(room_id):
|
||||||
return set(member.domain for member in self.room_members)
|
return {member.domain for member in self.room_members}
|
||||||
|
|
||||||
self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
|
self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
|
||||||
|
|
||||||
def get_current_users_in_room(room_id):
|
def get_current_users_in_room(room_id):
|
||||||
return set(str(u) for u in self.room_members)
|
return {str(u) for u in self.room_members}
|
||||||
|
|
||||||
hs.get_state_handler().get_current_users_in_room = get_current_users_in_room
|
hs.get_state_handler().get_current_users_in_room = get_current_users_in_room
|
||||||
|
|
||||||
@ -257,7 +257,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
member = RoomMember(ROOM_ID, U_APPLE.to_string())
|
member = RoomMember(ROOM_ID, U_APPLE.to_string())
|
||||||
self.handler._member_typing_until[member] = 1002000
|
self.handler._member_typing_until[member] = 1002000
|
||||||
self.handler._room_typing[ROOM_ID] = set([U_APPLE.to_string()])
|
self.handler._room_typing[ROOM_ID] = {U_APPLE.to_string()}
|
||||||
|
|
||||||
self.assertEquals(self.event_source.get_current_key(), 0)
|
self.assertEquals(self.event_source.get_current_key(), 0)
|
||||||
|
|
||||||
|
@ -114,7 +114,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
|
|||||||
public_users = self.get_users_in_public_rooms()
|
public_users = self.get_users_in_public_rooms()
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
self._compress_shared(shares_private), set([(u1, u2, room), (u2, u1, room)])
|
self._compress_shared(shares_private), {(u1, u2, room), (u2, u1, room)}
|
||||||
)
|
)
|
||||||
self.assertEqual(public_users, [])
|
self.assertEqual(public_users, [])
|
||||||
|
|
||||||
@ -169,7 +169,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
|
|||||||
public_users = self.get_users_in_public_rooms()
|
public_users = self.get_users_in_public_rooms()
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
self._compress_shared(shares_private), set([(u1, u2, room), (u2, u1, room)])
|
self._compress_shared(shares_private), {(u1, u2, room), (u2, u1, room)}
|
||||||
)
|
)
|
||||||
self.assertEqual(public_users, [])
|
self.assertEqual(public_users, [])
|
||||||
|
|
||||||
@ -226,7 +226,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
|
|||||||
public_users = self.get_users_in_public_rooms()
|
public_users = self.get_users_in_public_rooms()
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
self._compress_shared(shares_private), set([(u1, u2, room), (u2, u1, room)])
|
self._compress_shared(shares_private), {(u1, u2, room), (u2, u1, room)}
|
||||||
)
|
)
|
||||||
self.assertEqual(public_users, [])
|
self.assertEqual(public_users, [])
|
||||||
|
|
||||||
@ -358,12 +358,12 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
|
|||||||
public_users = self.get_users_in_public_rooms()
|
public_users = self.get_users_in_public_rooms()
|
||||||
|
|
||||||
# User 1 and User 2 are in the same public room
|
# User 1 and User 2 are in the same public room
|
||||||
self.assertEqual(set(public_users), set([(u1, room), (u2, room)]))
|
self.assertEqual(set(public_users), {(u1, room), (u2, room)})
|
||||||
|
|
||||||
# User 1 and User 3 share private rooms
|
# User 1 and User 3 share private rooms
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
self._compress_shared(shares_private),
|
self._compress_shared(shares_private),
|
||||||
set([(u1, u3, private_room), (u3, u1, private_room)]),
|
{(u1, u3, private_room), (u3, u1, private_room)},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_initial_share_all_users(self):
|
def test_initial_share_all_users(self):
|
||||||
@ -398,7 +398,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
# No users share rooms
|
# No users share rooms
|
||||||
self.assertEqual(public_users, [])
|
self.assertEqual(public_users, [])
|
||||||
self.assertEqual(self._compress_shared(shares_private), set([]))
|
self.assertEqual(self._compress_shared(shares_private), set())
|
||||||
|
|
||||||
# Despite not sharing a room, search_all_users means we get a search
|
# Despite not sharing a room, search_all_users means we get a search
|
||||||
# result.
|
# result.
|
||||||
|
@ -163,7 +163,7 @@ class EmailPusherTests(HomeserverTestCase):
|
|||||||
|
|
||||||
# Get the stream ordering before it gets sent
|
# Get the stream ordering before it gets sent
|
||||||
pushers = self.get_success(
|
pushers = self.get_success(
|
||||||
self.hs.get_datastore().get_pushers_by(dict(user_name=self.user_id))
|
self.hs.get_datastore().get_pushers_by({"user_name": self.user_id})
|
||||||
)
|
)
|
||||||
pushers = list(pushers)
|
pushers = list(pushers)
|
||||||
self.assertEqual(len(pushers), 1)
|
self.assertEqual(len(pushers), 1)
|
||||||
@ -174,7 +174,7 @@ class EmailPusherTests(HomeserverTestCase):
|
|||||||
|
|
||||||
# It hasn't succeeded yet, so the stream ordering shouldn't have moved
|
# It hasn't succeeded yet, so the stream ordering shouldn't have moved
|
||||||
pushers = self.get_success(
|
pushers = self.get_success(
|
||||||
self.hs.get_datastore().get_pushers_by(dict(user_name=self.user_id))
|
self.hs.get_datastore().get_pushers_by({"user_name": self.user_id})
|
||||||
)
|
)
|
||||||
pushers = list(pushers)
|
pushers = list(pushers)
|
||||||
self.assertEqual(len(pushers), 1)
|
self.assertEqual(len(pushers), 1)
|
||||||
@ -192,7 +192,7 @@ class EmailPusherTests(HomeserverTestCase):
|
|||||||
|
|
||||||
# The stream ordering has increased
|
# The stream ordering has increased
|
||||||
pushers = self.get_success(
|
pushers = self.get_success(
|
||||||
self.hs.get_datastore().get_pushers_by(dict(user_name=self.user_id))
|
self.hs.get_datastore().get_pushers_by({"user_name": self.user_id})
|
||||||
)
|
)
|
||||||
pushers = list(pushers)
|
pushers = list(pushers)
|
||||||
self.assertEqual(len(pushers), 1)
|
self.assertEqual(len(pushers), 1)
|
||||||
|
@ -102,7 +102,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
|||||||
|
|
||||||
# Get the stream ordering before it gets sent
|
# Get the stream ordering before it gets sent
|
||||||
pushers = self.get_success(
|
pushers = self.get_success(
|
||||||
self.hs.get_datastore().get_pushers_by(dict(user_name=user_id))
|
self.hs.get_datastore().get_pushers_by({"user_name": user_id})
|
||||||
)
|
)
|
||||||
pushers = list(pushers)
|
pushers = list(pushers)
|
||||||
self.assertEqual(len(pushers), 1)
|
self.assertEqual(len(pushers), 1)
|
||||||
@ -113,7 +113,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
|||||||
|
|
||||||
# It hasn't succeeded yet, so the stream ordering shouldn't have moved
|
# It hasn't succeeded yet, so the stream ordering shouldn't have moved
|
||||||
pushers = self.get_success(
|
pushers = self.get_success(
|
||||||
self.hs.get_datastore().get_pushers_by(dict(user_name=user_id))
|
self.hs.get_datastore().get_pushers_by({"user_name": user_id})
|
||||||
)
|
)
|
||||||
pushers = list(pushers)
|
pushers = list(pushers)
|
||||||
self.assertEqual(len(pushers), 1)
|
self.assertEqual(len(pushers), 1)
|
||||||
@ -132,7 +132,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
|||||||
|
|
||||||
# The stream ordering has increased
|
# The stream ordering has increased
|
||||||
pushers = self.get_success(
|
pushers = self.get_success(
|
||||||
self.hs.get_datastore().get_pushers_by(dict(user_name=user_id))
|
self.hs.get_datastore().get_pushers_by({"user_name": user_id})
|
||||||
)
|
)
|
||||||
pushers = list(pushers)
|
pushers = list(pushers)
|
||||||
self.assertEqual(len(pushers), 1)
|
self.assertEqual(len(pushers), 1)
|
||||||
@ -152,7 +152,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
|||||||
|
|
||||||
# The stream ordering has increased, again
|
# The stream ordering has increased, again
|
||||||
pushers = self.get_success(
|
pushers = self.get_success(
|
||||||
self.hs.get_datastore().get_pushers_by(dict(user_name=user_id))
|
self.hs.get_datastore().get_pushers_by({"user_name": user_id})
|
||||||
)
|
)
|
||||||
pushers = list(pushers)
|
pushers = list(pushers)
|
||||||
self.assertEqual(len(pushers), 1)
|
self.assertEqual(len(pushers), 1)
|
||||||
|
@ -40,16 +40,14 @@ class FilterTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
self.assertEqual(channel.code, 200)
|
self.assertEqual(channel.code, 200)
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
set(
|
{
|
||||||
[
|
|
||||||
"next_batch",
|
"next_batch",
|
||||||
"rooms",
|
"rooms",
|
||||||
"presence",
|
"presence",
|
||||||
"account_data",
|
"account_data",
|
||||||
"to_device",
|
"to_device",
|
||||||
"device_lists",
|
"device_lists",
|
||||||
]
|
}.issubset(set(channel.json_body.keys()))
|
||||||
).issubset(set(channel.json_body.keys()))
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_sync_presence_disabled(self):
|
def test_sync_presence_disabled(self):
|
||||||
@ -63,9 +61,13 @@ class FilterTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
self.assertEqual(channel.code, 200)
|
self.assertEqual(channel.code, 200)
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
set(
|
{
|
||||||
["next_batch", "rooms", "account_data", "to_device", "device_lists"]
|
"next_batch",
|
||||||
).issubset(set(channel.json_body.keys()))
|
"rooms",
|
||||||
|
"account_data",
|
||||||
|
"to_device",
|
||||||
|
"device_lists",
|
||||||
|
}.issubset(set(channel.json_body.keys()))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -373,7 +373,7 @@ class UpsertManyTests(unittest.HomeserverTestCase):
|
|||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
set(self._dump_to_tuple(res)),
|
set(self._dump_to_tuple(res)),
|
||||||
set([(1, "user1", "hello"), (2, "user2", "there")]),
|
{(1, "user1", "hello"), (2, "user2", "there")},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update only user2
|
# Update only user2
|
||||||
@ -400,5 +400,5 @@ class UpsertManyTests(unittest.HomeserverTestCase):
|
|||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
set(self._dump_to_tuple(res)),
|
set(self._dump_to_tuple(res)),
|
||||||
set([(1, "user1", "hello"), (2, "user2", "bleb")]),
|
{(1, "user1", "hello"), (2, "user2", "bleb")},
|
||||||
)
|
)
|
||||||
|
@ -69,14 +69,14 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def _add_appservice(self, as_token, id, url, hs_token, sender):
|
def _add_appservice(self, as_token, id, url, hs_token, sender):
|
||||||
as_yaml = dict(
|
as_yaml = {
|
||||||
url=url,
|
"url": url,
|
||||||
as_token=as_token,
|
"as_token": as_token,
|
||||||
hs_token=hs_token,
|
"hs_token": hs_token,
|
||||||
id=id,
|
"id": id,
|
||||||
sender_localpart=sender,
|
"sender_localpart": sender,
|
||||||
namespaces={},
|
"namespaces": {},
|
||||||
)
|
}
|
||||||
# use the token as the filename
|
# use the token as the filename
|
||||||
with open(as_token, "w") as outfile:
|
with open(as_token, "w") as outfile:
|
||||||
outfile.write(yaml.dump(as_yaml))
|
outfile.write(yaml.dump(as_yaml))
|
||||||
@ -135,14 +135,14 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _add_service(self, url, as_token, id):
|
def _add_service(self, url, as_token, id):
|
||||||
as_yaml = dict(
|
as_yaml = {
|
||||||
url=url,
|
"url": url,
|
||||||
as_token=as_token,
|
"as_token": as_token,
|
||||||
hs_token="something",
|
"hs_token": "something",
|
||||||
id=id,
|
"id": id,
|
||||||
sender_localpart="a_sender",
|
"sender_localpart": "a_sender",
|
||||||
namespaces={},
|
"namespaces": {},
|
||||||
)
|
}
|
||||||
# use the token as the filename
|
# use the token as the filename
|
||||||
with open(as_token, "w") as outfile:
|
with open(as_token, "w") as outfile:
|
||||||
outfile.write(yaml.dump(as_yaml))
|
outfile.write(yaml.dump(as_yaml))
|
||||||
@ -384,8 +384,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
self.assertEquals(2, len(services))
|
self.assertEquals(2, len(services))
|
||||||
self.assertEquals(
|
self.assertEquals(
|
||||||
set([self.as_list[2]["id"], self.as_list[0]["id"]]),
|
{self.as_list[2]["id"], self.as_list[0]["id"]},
|
||||||
set([services[0].id, services[1].id]),
|
{services[0].id, services[1].id},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -134,7 +134,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
|
|||||||
latest_event_ids = self.get_success(
|
latest_event_ids = self.get_success(
|
||||||
self.store.get_latest_event_ids_in_room(self.room_id)
|
self.store.get_latest_event_ids_in_room(self.room_id)
|
||||||
)
|
)
|
||||||
self.assertEqual(set(latest_event_ids), set((event_id_a, event_id_b)))
|
self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b})
|
||||||
|
|
||||||
# Run the background update and check it did the right thing
|
# Run the background update and check it did the right thing
|
||||||
self.run_background_update()
|
self.run_background_update()
|
||||||
@ -172,7 +172,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
|
|||||||
latest_event_ids = self.get_success(
|
latest_event_ids = self.get_success(
|
||||||
self.store.get_latest_event_ids_in_room(self.room_id)
|
self.store.get_latest_event_ids_in_room(self.room_id)
|
||||||
)
|
)
|
||||||
self.assertEqual(set(latest_event_ids), set((event_id_a, event_id_b)))
|
self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b})
|
||||||
|
|
||||||
# Run the background update and check it did the right thing
|
# Run the background update and check it did the right thing
|
||||||
self.run_background_update()
|
self.run_background_update()
|
||||||
@ -227,9 +227,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
|
|||||||
latest_event_ids = self.get_success(
|
latest_event_ids = self.get_success(
|
||||||
self.store.get_latest_event_ids_in_room(self.room_id)
|
self.store.get_latest_event_ids_in_room(self.room_id)
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b, event_id_c})
|
||||||
set(latest_event_ids), set((event_id_a, event_id_b, event_id_c))
|
|
||||||
)
|
|
||||||
|
|
||||||
# Run the background update and check it did the right thing
|
# Run the background update and check it did the right thing
|
||||||
self.run_background_update()
|
self.run_background_update()
|
||||||
@ -237,7 +235,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
|
|||||||
latest_event_ids = self.get_success(
|
latest_event_ids = self.get_success(
|
||||||
self.store.get_latest_event_ids_in_room(self.room_id)
|
self.store.get_latest_event_ids_in_room(self.room_id)
|
||||||
)
|
)
|
||||||
self.assertEqual(set(latest_event_ids), set([event_id_b, event_id_c]))
|
self.assertEqual(set(latest_event_ids), {event_id_b, event_id_c})
|
||||||
|
|
||||||
|
|
||||||
class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
|
class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
|
||||||
|
@ -59,8 +59,7 @@ class ExtremStatisticsTestCase(HomeserverTestCase):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
expected = set(
|
expected = {
|
||||||
[
|
|
||||||
b'synapse_forward_extremities_bucket{le="1.0"} 0.0',
|
b'synapse_forward_extremities_bucket{le="1.0"} 0.0',
|
||||||
b'synapse_forward_extremities_bucket{le="2.0"} 2.0',
|
b'synapse_forward_extremities_bucket{le="2.0"} 2.0',
|
||||||
b'synapse_forward_extremities_bucket{le="3.0"} 2.0',
|
b'synapse_forward_extremities_bucket{le="3.0"} 2.0',
|
||||||
@ -76,7 +75,6 @@ class ExtremStatisticsTestCase(HomeserverTestCase):
|
|||||||
b'synapse_forward_extremities_bucket{le="+Inf"} 3.0',
|
b'synapse_forward_extremities_bucket{le="+Inf"} 3.0',
|
||||||
b"synapse_forward_extremities_count 3.0",
|
b"synapse_forward_extremities_count 3.0",
|
||||||
b"synapse_forward_extremities_sum 10.0",
|
b"synapse_forward_extremities_sum 10.0",
|
||||||
]
|
}
|
||||||
)
|
|
||||||
|
|
||||||
self.assertEqual(items, expected)
|
self.assertEqual(items, expected)
|
||||||
|
@ -394,7 +394,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
|
|||||||
) = self.state_datastore._state_group_cache.get(group)
|
) = self.state_datastore._state_group_cache.get(group)
|
||||||
|
|
||||||
self.assertEqual(is_all, False)
|
self.assertEqual(is_all, False)
|
||||||
self.assertEqual(known_absent, set([(e1.type, e1.state_key)]))
|
self.assertEqual(known_absent, {(e1.type, e1.state_key)})
|
||||||
self.assertDictEqual(state_dict_ids, {(e1.type, e1.state_key): e1.event_id})
|
self.assertDictEqual(state_dict_ids, {(e1.type, e1.state_key): e1.event_id})
|
||||||
|
|
||||||
############################################
|
############################################
|
||||||
|
@ -254,9 +254,7 @@ class StateTestCase(unittest.TestCase):
|
|||||||
ctx_d = context_store["D"]
|
ctx_d = context_store["D"]
|
||||||
|
|
||||||
prev_state_ids = yield ctx_d.get_prev_state_ids()
|
prev_state_ids = yield ctx_d.get_prev_state_ids()
|
||||||
self.assertSetEqual(
|
self.assertSetEqual({"START", "A", "C"}, set(prev_state_ids.values()))
|
||||||
{"START", "A", "C"}, {e_id for e_id in prev_state_ids.values()}
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
|
self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
|
||||||
self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)
|
self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)
|
||||||
@ -313,9 +311,7 @@ class StateTestCase(unittest.TestCase):
|
|||||||
ctx_e = context_store["E"]
|
ctx_e = context_store["E"]
|
||||||
|
|
||||||
prev_state_ids = yield ctx_e.get_prev_state_ids()
|
prev_state_ids = yield ctx_e.get_prev_state_ids()
|
||||||
self.assertSetEqual(
|
self.assertSetEqual({"START", "A", "B", "C"}, set(prev_state_ids.values()))
|
||||||
{"START", "A", "B", "C"}, {e for e in prev_state_ids.values()}
|
|
||||||
)
|
|
||||||
self.assertEqual(ctx_c.state_group, ctx_e.state_group_before_event)
|
self.assertEqual(ctx_c.state_group, ctx_e.state_group_before_event)
|
||||||
self.assertEqual(ctx_e.state_group_before_event, ctx_e.state_group)
|
self.assertEqual(ctx_e.state_group_before_event, ctx_e.state_group)
|
||||||
|
|
||||||
@ -388,9 +384,7 @@ class StateTestCase(unittest.TestCase):
|
|||||||
ctx_d = context_store["D"]
|
ctx_d = context_store["D"]
|
||||||
|
|
||||||
prev_state_ids = yield ctx_d.get_prev_state_ids()
|
prev_state_ids = yield ctx_d.get_prev_state_ids()
|
||||||
self.assertSetEqual(
|
self.assertSetEqual({"A1", "A2", "A3", "A5", "B"}, set(prev_state_ids.values()))
|
||||||
{"A1", "A2", "A3", "A5", "B"}, {e for e in prev_state_ids.values()}
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertEqual(ctx_b.state_group, ctx_d.state_group_before_event)
|
self.assertEqual(ctx_b.state_group, ctx_d.state_group_before_event)
|
||||||
self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)
|
self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)
|
||||||
@ -482,7 +476,7 @@ class StateTestCase(unittest.TestCase):
|
|||||||
current_state_ids = yield context.get_current_state_ids()
|
current_state_ids = yield context.get_current_state_ids()
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
set([e.event_id for e in old_state]), set(current_state_ids.values())
|
{e.event_id for e in old_state}, set(current_state_ids.values())
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(group_name, context.state_group)
|
self.assertEqual(group_name, context.state_group)
|
||||||
@ -513,9 +507,7 @@ class StateTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
prev_state_ids = yield context.get_prev_state_ids()
|
prev_state_ids = yield context.get_prev_state_ids()
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual({e.event_id for e in old_state}, set(prev_state_ids.values()))
|
||||||
set([e.event_id for e in old_state]), set(prev_state_ids.values())
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertIsNotNone(context.state_group)
|
self.assertIsNotNone(context.state_group)
|
||||||
|
|
||||||
|
@ -67,7 +67,7 @@ class StreamChangeCacheTests(unittest.TestCase):
|
|||||||
# If we update an existing entity, it keeps the two existing entities
|
# If we update an existing entity, it keeps the two existing entities
|
||||||
cache.entity_has_changed("bar@baz.net", 5)
|
cache.entity_has_changed("bar@baz.net", 5)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
set(["bar@baz.net", "user@elsewhere.org"]), set(cache._entity_to_key)
|
{"bar@baz.net", "user@elsewhere.org"}, set(cache._entity_to_key)
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_get_all_entities_changed(self):
|
def test_get_all_entities_changed(self):
|
||||||
@ -137,7 +137,7 @@ class StreamChangeCacheTests(unittest.TestCase):
|
|||||||
cache.get_entities_changed(
|
cache.get_entities_changed(
|
||||||
["user@foo.com", "bar@baz.net", "user@elsewhere.org"], stream_pos=2
|
["user@foo.com", "bar@baz.net", "user@elsewhere.org"], stream_pos=2
|
||||||
),
|
),
|
||||||
set(["bar@baz.net", "user@elsewhere.org"]),
|
{"bar@baz.net", "user@elsewhere.org"},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Query all the entries mid-way through the stream, but include one
|
# Query all the entries mid-way through the stream, but include one
|
||||||
@ -153,7 +153,7 @@ class StreamChangeCacheTests(unittest.TestCase):
|
|||||||
],
|
],
|
||||||
stream_pos=2,
|
stream_pos=2,
|
||||||
),
|
),
|
||||||
set(["bar@baz.net", "user@elsewhere.org"]),
|
{"bar@baz.net", "user@elsewhere.org"},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Query all the entries, but before the first known point. We will get
|
# Query all the entries, but before the first known point. We will get
|
||||||
@ -168,21 +168,13 @@ class StreamChangeCacheTests(unittest.TestCase):
|
|||||||
],
|
],
|
||||||
stream_pos=0,
|
stream_pos=0,
|
||||||
),
|
),
|
||||||
set(
|
{"user@foo.com", "bar@baz.net", "user@elsewhere.org", "not@here.website"},
|
||||||
[
|
|
||||||
"user@foo.com",
|
|
||||||
"bar@baz.net",
|
|
||||||
"user@elsewhere.org",
|
|
||||||
"not@here.website",
|
|
||||||
]
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Query a subset of the entries mid-way through the stream. We should
|
# Query a subset of the entries mid-way through the stream. We should
|
||||||
# only get back the subset.
|
# only get back the subset.
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
cache.get_entities_changed(["bar@baz.net"], stream_pos=2),
|
cache.get_entities_changed(["bar@baz.net"], stream_pos=2), {"bar@baz.net"},
|
||||||
set(["bar@baz.net"]),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_max_pos(self):
|
def test_max_pos(self):
|
||||||
|
1
tox.ini
1
tox.ini
@ -123,6 +123,7 @@ skip_install = True
|
|||||||
basepython = python3.6
|
basepython = python3.6
|
||||||
deps =
|
deps =
|
||||||
flake8
|
flake8
|
||||||
|
flake8-comprehensions
|
||||||
black==19.10b0 # We pin so that our tests don't start failing on new releases of black.
|
black==19.10b0 # We pin so that our tests don't start failing on new releases of black.
|
||||||
commands =
|
commands =
|
||||||
python -m black --check --diff .
|
python -m black --check --diff .
|
||||||
|
Loading…
Reference in New Issue
Block a user