mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2024-10-01 08:25:44 -04:00
Merge remote-tracking branch 'origin/develop' into markjh/direct_to_device
This commit is contained in:
commit
4bbef62124
@ -52,7 +52,7 @@ class Auth(object):
|
||||
self.state = hs.get_state_handler()
|
||||
self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
|
||||
# Docs for these currently lives at
|
||||
# https://github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst
|
||||
# github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst
|
||||
# In addition, we have type == delete_pusher which grants access only to
|
||||
# delete pushers.
|
||||
self._KNOWN_CAVEAT_PREFIXES = set([
|
||||
@ -63,6 +63,17 @@ class Auth(object):
|
||||
"user_id = ",
|
||||
])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check_from_context(self, event, context, do_sig_check=True):
|
||||
auth_events_ids = yield self.compute_auth_events(
|
||||
event, context.current_state_ids, for_verification=True,
|
||||
)
|
||||
auth_events = yield self.store.get_events(auth_events_ids)
|
||||
auth_events = {
|
||||
(e.type, e.state_key): e for e in auth_events.values()
|
||||
}
|
||||
self.check(event, auth_events=auth_events, do_sig_check=False)
|
||||
|
||||
def check(self, event, auth_events, do_sig_check=True):
|
||||
""" Checks if this event is correctly authed.
|
||||
|
||||
@ -267,21 +278,15 @@ class Auth(object):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check_host_in_room(self, room_id, host):
|
||||
curr_state = yield self.state.get_current_state(room_id)
|
||||
with Measure(self.clock, "check_host_in_room"):
|
||||
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
|
||||
|
||||
for event in curr_state.values():
|
||||
if event.type == EventTypes.Member:
|
||||
try:
|
||||
if get_domain_from_id(event.state_key) != host:
|
||||
continue
|
||||
except:
|
||||
logger.warn("state_key not user_id: %s", event.state_key)
|
||||
continue
|
||||
group, curr_state_ids = yield self.state.resolve_state_groups(
|
||||
room_id, latest_event_ids
|
||||
)
|
||||
|
||||
if event.content["membership"] == Membership.JOIN:
|
||||
defer.returnValue(True)
|
||||
|
||||
defer.returnValue(False)
|
||||
ret = yield self.store.is_host_joined(room_id, host, group, curr_state_ids)
|
||||
defer.returnValue(ret)
|
||||
|
||||
def check_event_sender_in_room(self, event, auth_events):
|
||||
key = (EventTypes.Member, event.user_id, )
|
||||
@ -847,7 +852,7 @@ class Auth(object):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def add_auth_events(self, builder, context):
|
||||
auth_ids = self.compute_auth_events(builder, context.current_state)
|
||||
auth_ids = yield self.compute_auth_events(builder, context.current_state_ids)
|
||||
|
||||
auth_events_entries = yield self.store.add_event_hashes(
|
||||
auth_ids
|
||||
@ -855,30 +860,32 @@ class Auth(object):
|
||||
|
||||
builder.auth_events = auth_events_entries
|
||||
|
||||
def compute_auth_events(self, event, current_state):
|
||||
@defer.inlineCallbacks
|
||||
def compute_auth_events(self, event, current_state_ids, for_verification=False):
|
||||
if event.type == EventTypes.Create:
|
||||
return []
|
||||
defer.returnValue([])
|
||||
|
||||
auth_ids = []
|
||||
|
||||
key = (EventTypes.PowerLevels, "", )
|
||||
power_level_event = current_state.get(key)
|
||||
power_level_event_id = current_state_ids.get(key)
|
||||
|
||||
if power_level_event:
|
||||
auth_ids.append(power_level_event.event_id)
|
||||
if power_level_event_id:
|
||||
auth_ids.append(power_level_event_id)
|
||||
|
||||
key = (EventTypes.JoinRules, "", )
|
||||
join_rule_event = current_state.get(key)
|
||||
join_rule_event_id = current_state_ids.get(key)
|
||||
|
||||
key = (EventTypes.Member, event.user_id, )
|
||||
member_event = current_state.get(key)
|
||||
member_event_id = current_state_ids.get(key)
|
||||
|
||||
key = (EventTypes.Create, "", )
|
||||
create_event = current_state.get(key)
|
||||
if create_event:
|
||||
auth_ids.append(create_event.event_id)
|
||||
create_event_id = current_state_ids.get(key)
|
||||
if create_event_id:
|
||||
auth_ids.append(create_event_id)
|
||||
|
||||
if join_rule_event:
|
||||
if join_rule_event_id:
|
||||
join_rule_event = yield self.store.get_event(join_rule_event_id)
|
||||
join_rule = join_rule_event.content.get("join_rule")
|
||||
is_public = join_rule == JoinRules.PUBLIC if join_rule else False
|
||||
else:
|
||||
@ -887,15 +894,21 @@ class Auth(object):
|
||||
if event.type == EventTypes.Member:
|
||||
e_type = event.content["membership"]
|
||||
if e_type in [Membership.JOIN, Membership.INVITE]:
|
||||
if join_rule_event:
|
||||
auth_ids.append(join_rule_event.event_id)
|
||||
if join_rule_event_id:
|
||||
auth_ids.append(join_rule_event_id)
|
||||
|
||||
if e_type == Membership.JOIN:
|
||||
if member_event and not is_public:
|
||||
auth_ids.append(member_event.event_id)
|
||||
if member_event_id and not is_public:
|
||||
auth_ids.append(member_event_id)
|
||||
else:
|
||||
if member_event:
|
||||
auth_ids.append(member_event.event_id)
|
||||
if member_event_id:
|
||||
auth_ids.append(member_event_id)
|
||||
|
||||
if for_verification:
|
||||
key = (EventTypes.Member, event.state_key, )
|
||||
existing_event_id = current_state_ids.get(key)
|
||||
if existing_event_id:
|
||||
auth_ids.append(existing_event_id)
|
||||
|
||||
if e_type == Membership.INVITE:
|
||||
if "third_party_invite" in event.content:
|
||||
@ -903,14 +916,15 @@ class Auth(object):
|
||||
EventTypes.ThirdPartyInvite,
|
||||
event.content["third_party_invite"]["signed"]["token"]
|
||||
)
|
||||
third_party_invite = current_state.get(key)
|
||||
if third_party_invite:
|
||||
auth_ids.append(third_party_invite.event_id)
|
||||
elif member_event:
|
||||
third_party_invite_id = current_state_ids.get(key)
|
||||
if third_party_invite_id:
|
||||
auth_ids.append(third_party_invite_id)
|
||||
elif member_event_id:
|
||||
member_event = yield self.store.get_event(member_event_id)
|
||||
if member_event.content["membership"] == Membership.JOIN:
|
||||
auth_ids.append(member_event.event_id)
|
||||
|
||||
return auth_ids
|
||||
defer.returnValue(auth_ids)
|
||||
|
||||
def _get_send_level(self, etype, state_key, auth_events):
|
||||
key = (EventTypes.PowerLevels, "", )
|
||||
|
@ -85,3 +85,8 @@ class RoomCreationPreset(object):
|
||||
PRIVATE_CHAT = "private_chat"
|
||||
PUBLIC_CHAT = "public_chat"
|
||||
TRUSTED_PRIVATE_CHAT = "trusted_private_chat"
|
||||
|
||||
|
||||
class ThirdPartyEntityKind(object):
|
||||
USER = "user"
|
||||
LOCATION = "location"
|
||||
|
@ -25,4 +25,3 @@ SERVER_KEY_PREFIX = "/_matrix/key/v1"
|
||||
SERVER_KEY_V2_PREFIX = "/_matrix/key/v2"
|
||||
MEDIA_PREFIX = "/_matrix/media/r0"
|
||||
LEGACY_MEDIA_PREFIX = "/_matrix/media/v1"
|
||||
APP_SERVICE_PREFIX = "/_matrix/appservice/v1"
|
||||
|
@ -14,11 +14,11 @@
|
||||
# limitations under the License.
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import ThirdPartyEntityKind
|
||||
from synapse.api.errors import CodeMessageException
|
||||
from synapse.http.client import SimpleHttpClient
|
||||
from synapse.events.utils import serialize_event
|
||||
from synapse.util.caches.response_cache import ResponseCache
|
||||
from synapse.types import ThirdPartyEntityKind
|
||||
|
||||
import logging
|
||||
import urllib
|
||||
@ -29,6 +29,9 @@ logger = logging.getLogger(__name__)
|
||||
HOUR_IN_MS = 60 * 60 * 1000
|
||||
|
||||
|
||||
APP_SERVICE_PREFIX = "/_matrix/app/unstable"
|
||||
|
||||
|
||||
def _is_valid_3pe_result(r, field):
|
||||
if not isinstance(r, dict):
|
||||
return False
|
||||
@ -103,16 +106,20 @@ class ApplicationServiceApi(SimpleHttpClient):
|
||||
@defer.inlineCallbacks
|
||||
def query_3pe(self, service, kind, protocol, fields):
|
||||
if kind == ThirdPartyEntityKind.USER:
|
||||
uri = "%s/thirdparty/user/%s" % (service.url, urllib.quote(protocol))
|
||||
required_field = "userid"
|
||||
elif kind == ThirdPartyEntityKind.LOCATION:
|
||||
uri = "%s/thirdparty/location/%s" % (service.url, urllib.quote(protocol))
|
||||
required_field = "alias"
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unrecognised 'kind' argument %r to query_3pe()", kind
|
||||
)
|
||||
|
||||
uri = "%s%s/thirdparty/%s/%s" % (
|
||||
service.url,
|
||||
APP_SERVICE_PREFIX,
|
||||
kind,
|
||||
urllib.quote(protocol)
|
||||
)
|
||||
try:
|
||||
response = yield self.get_json(uri, fields)
|
||||
if not isinstance(response, list):
|
||||
@ -140,7 +147,11 @@ class ApplicationServiceApi(SimpleHttpClient):
|
||||
def get_3pe_protocol(self, service, protocol):
|
||||
@defer.inlineCallbacks
|
||||
def _get():
|
||||
uri = "%s/thirdparty/protocol/%s" % (service.url, urllib.quote(protocol))
|
||||
uri = "%s%s/thirdparty/protocol/%s" % (
|
||||
service.url,
|
||||
APP_SERVICE_PREFIX,
|
||||
urllib.quote(protocol)
|
||||
)
|
||||
try:
|
||||
defer.returnValue((yield self.get_json(uri, {})))
|
||||
except Exception as ex:
|
||||
|
@ -99,7 +99,7 @@ class EventBase(object):
|
||||
|
||||
return d
|
||||
|
||||
def get(self, key, default):
|
||||
def get(self, key, default=None):
|
||||
return self._event_dict.get(key, default)
|
||||
|
||||
def get_internal_metadata_dict(self):
|
||||
|
@ -15,9 +15,8 @@
|
||||
|
||||
|
||||
class EventContext(object):
|
||||
|
||||
def __init__(self, current_state=None):
|
||||
self.current_state = current_state
|
||||
def __init__(self, current_state_ids=None):
|
||||
self.current_state_ids = current_state_ids
|
||||
self.state_group = None
|
||||
self.rejected = False
|
||||
self.push_actions = []
|
||||
|
@ -65,33 +65,21 @@ class BaseHandler(object):
|
||||
retry_after_ms=int(1000 * (time_allowed - time_now)),
|
||||
)
|
||||
|
||||
def is_host_in_room(self, current_state):
|
||||
room_members = [
|
||||
(state_key, event.membership)
|
||||
for ((event_type, state_key), event) in current_state.items()
|
||||
if event_type == EventTypes.Member
|
||||
]
|
||||
if len(room_members) == 0:
|
||||
# Have we just created the room, and is this about to be the very
|
||||
# first member event?
|
||||
create_event = current_state.get(("m.room.create", ""))
|
||||
if create_event:
|
||||
return True
|
||||
for (state_key, membership) in room_members:
|
||||
if (
|
||||
self.hs.is_mine_id(state_key)
|
||||
and membership == Membership.JOIN
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def maybe_kick_guest_users(self, event, current_state):
|
||||
def maybe_kick_guest_users(self, event, context=None):
|
||||
# Technically this function invalidates current_state by changing it.
|
||||
# Hopefully this isn't that important to the caller.
|
||||
if event.type == EventTypes.GuestAccess:
|
||||
guest_access = event.content.get("guest_access", "forbidden")
|
||||
if guest_access != "can_join":
|
||||
if context:
|
||||
current_state = yield self.store.get_events(
|
||||
context.current_state_ids.values()
|
||||
)
|
||||
current_state = current_state.values()
|
||||
else:
|
||||
current_state = yield self.store.get_current_state(event.room_id)
|
||||
logger.info("maybe_kick_guest_users %r", current_state)
|
||||
yield self.kick_guest_users(current_state)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -29,6 +29,7 @@ from synapse.util import unwrapFirstError
|
||||
from synapse.util.logcontext import (
|
||||
PreserveLoggingContext, preserve_fn, preserve_context_over_deferred
|
||||
)
|
||||
from synapse.util.metrics import measure_func
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.async import run_on_reactor
|
||||
from synapse.util.frozenutils import unfreeze
|
||||
@ -217,17 +218,28 @@ class FederationHandler(BaseHandler):
|
||||
|
||||
if event.type == EventTypes.Member:
|
||||
if event.membership == Membership.JOIN:
|
||||
prev_state = context.current_state.get((event.type, event.state_key))
|
||||
if not prev_state or prev_state.membership != Membership.JOIN:
|
||||
# Only fire user_joined_room if the user has acutally
|
||||
# joined the room. Don't bother if the user is just
|
||||
# changing their profile info.
|
||||
# Only fire user_joined_room if the user has acutally
|
||||
# joined the room. Don't bother if the user is just
|
||||
# changing their profile info.
|
||||
newly_joined = True
|
||||
prev_state_id = context.current_state_ids.get(
|
||||
(event.type, event.state_key)
|
||||
)
|
||||
if prev_state_id:
|
||||
prev_state = yield self.store.get_event(
|
||||
prev_state_id, allow_none=True,
|
||||
)
|
||||
if prev_state and prev_state.membership == Membership.JOIN:
|
||||
newly_joined = False
|
||||
|
||||
if newly_joined:
|
||||
user = UserID.from_string(event.state_key)
|
||||
yield user_joined_room(self.distributor, user, event.room_id)
|
||||
|
||||
@measure_func("_filter_events_for_server")
|
||||
@defer.inlineCallbacks
|
||||
def _filter_events_for_server(self, server_name, room_id, events):
|
||||
event_to_state = yield self.store.get_state_for_events(
|
||||
event_to_state_ids = yield self.store.get_state_ids_for_events(
|
||||
frozenset(e.event_id for e in events),
|
||||
types=(
|
||||
(EventTypes.RoomHistoryVisibility, ""),
|
||||
@ -235,6 +247,30 @@ class FederationHandler(BaseHandler):
|
||||
)
|
||||
)
|
||||
|
||||
# We only want to pull out member events that correspond to the
|
||||
# server's domain.
|
||||
|
||||
def check_match(id):
|
||||
try:
|
||||
return server_name == get_domain_from_id(id)
|
||||
except:
|
||||
return False
|
||||
|
||||
event_map = yield self.store.get_events([
|
||||
e_id for key_to_eid in event_to_state_ids.values()
|
||||
for key, e_id in key_to_eid
|
||||
if key[0] != EventTypes.Member or check_match(key[1])
|
||||
])
|
||||
|
||||
event_to_state = {
|
||||
e_id: {
|
||||
key: event_map[inner_e_id]
|
||||
for key, inner_e_id in key_to_eid.items()
|
||||
if inner_e_id in event_map
|
||||
}
|
||||
for e_id, key_to_eid in event_to_state_ids.items()
|
||||
}
|
||||
|
||||
def redact_disallowed(event, state):
|
||||
if not state:
|
||||
return event
|
||||
@ -377,7 +413,9 @@ class FederationHandler(BaseHandler):
|
||||
)).addErrback(unwrapFirstError)
|
||||
auth_events.update({a.event_id: a for a in results if a})
|
||||
required_auth.update(
|
||||
a_id for event in results for a_id, _ in event.auth_events if event
|
||||
a_id
|
||||
for event in results if event
|
||||
for a_id, _ in event.auth_events
|
||||
)
|
||||
missing_auth = required_auth - set(auth_events)
|
||||
|
||||
@ -560,6 +598,18 @@ class FederationHandler(BaseHandler):
|
||||
]))
|
||||
states = dict(zip(event_ids, [s[1] for s in states]))
|
||||
|
||||
state_map = yield self.store.get_events(
|
||||
[e_id for ids in states.values() for e_id in ids],
|
||||
get_prev_content=False
|
||||
)
|
||||
states = {
|
||||
key: {
|
||||
k: state_map[e_id]
|
||||
for k, e_id in state_dict.items()
|
||||
if e_id in state_map
|
||||
} for key, state_dict in states.items()
|
||||
}
|
||||
|
||||
for e_id, _ in sorted_extremeties_tuple:
|
||||
likely_domains = get_domains_from_state(states[e_id])
|
||||
|
||||
@ -722,7 +772,7 @@ class FederationHandler(BaseHandler):
|
||||
|
||||
# The remote hasn't signed it yet, obviously. We'll do the full checks
|
||||
# when we get the event back in `on_send_join_request`
|
||||
self.auth.check(event, auth_events=context.current_state, do_sig_check=False)
|
||||
yield self.auth.check_from_context(event, context, do_sig_check=False)
|
||||
|
||||
defer.returnValue(event)
|
||||
|
||||
@ -770,18 +820,11 @@ class FederationHandler(BaseHandler):
|
||||
|
||||
new_pdu = event
|
||||
|
||||
destinations = set()
|
||||
|
||||
for k, s in context.current_state.items():
|
||||
try:
|
||||
if k[0] == EventTypes.Member:
|
||||
if s.content["membership"] == Membership.JOIN:
|
||||
destinations.add(get_domain_from_id(s.state_key))
|
||||
except:
|
||||
logger.warn(
|
||||
"Failed to get destination from event %s", s.event_id
|
||||
)
|
||||
|
||||
message_handler = self.hs.get_handlers().message_handler
|
||||
destinations = yield message_handler.get_joined_hosts_for_room_from_state(
|
||||
context
|
||||
)
|
||||
destinations = set(destinations)
|
||||
destinations.discard(origin)
|
||||
|
||||
logger.debug(
|
||||
@ -792,13 +835,15 @@ class FederationHandler(BaseHandler):
|
||||
|
||||
self.replication_layer.send_pdu(new_pdu, destinations)
|
||||
|
||||
state_ids = [e.event_id for e in context.current_state.values()]
|
||||
state_ids = context.current_state_ids.values()
|
||||
auth_chain = yield self.store.get_auth_chain(set(
|
||||
[event.event_id] + state_ids
|
||||
))
|
||||
|
||||
state = yield self.store.get_events(context.current_state_ids.values())
|
||||
|
||||
defer.returnValue({
|
||||
"state": context.current_state.values(),
|
||||
"state": state.values(),
|
||||
"auth_chain": auth_chain,
|
||||
})
|
||||
|
||||
@ -954,7 +999,7 @@ class FederationHandler(BaseHandler):
|
||||
try:
|
||||
# The remote hasn't signed it yet, obviously. We'll do the full checks
|
||||
# when we get the event back in `on_send_leave_request`
|
||||
self.auth.check(event, auth_events=context.current_state, do_sig_check=False)
|
||||
yield self.auth.check_from_context(event, context, do_sig_check=False)
|
||||
except AuthError as e:
|
||||
logger.warn("Failed to create new leave %r because %s", event, e)
|
||||
raise e
|
||||
@ -998,18 +1043,11 @@ class FederationHandler(BaseHandler):
|
||||
|
||||
new_pdu = event
|
||||
|
||||
destinations = set()
|
||||
|
||||
for k, s in context.current_state.items():
|
||||
try:
|
||||
if k[0] == EventTypes.Member:
|
||||
if s.content["membership"] == Membership.LEAVE:
|
||||
destinations.add(get_domain_from_id(s.state_key))
|
||||
except:
|
||||
logger.warn(
|
||||
"Failed to get destination from event %s", s.event_id
|
||||
)
|
||||
|
||||
message_handler = self.hs.get_handlers().message_handler
|
||||
destinations = yield message_handler.get_joined_hosts_for_room_from_state(
|
||||
context
|
||||
)
|
||||
destinations = set(destinations)
|
||||
destinations.discard(origin)
|
||||
|
||||
logger.debug(
|
||||
@ -1294,7 +1332,13 @@ class FederationHandler(BaseHandler):
|
||||
)
|
||||
|
||||
if not auth_events:
|
||||
auth_events = context.current_state
|
||||
auth_events_ids = yield self.auth.compute_auth_events(
|
||||
event, context.current_state_ids, for_verification=True,
|
||||
)
|
||||
auth_events = yield self.store.get_events(auth_events_ids)
|
||||
auth_events = {
|
||||
(e.type, e.state_key): e for e in auth_events.values()
|
||||
}
|
||||
|
||||
# This is a hack to fix some old rooms where the initial join event
|
||||
# didn't reference the create event in its auth events.
|
||||
@ -1320,8 +1364,7 @@ class FederationHandler(BaseHandler):
|
||||
context.rejected = RejectedReason.AUTH_ERROR
|
||||
|
||||
if event.type == EventTypes.GuestAccess:
|
||||
full_context = yield self.store.get_current_state(room_id=event.room_id)
|
||||
yield self.maybe_kick_guest_users(event, full_context)
|
||||
yield self.maybe_kick_guest_users(event)
|
||||
|
||||
defer.returnValue(context)
|
||||
|
||||
@ -1492,7 +1535,9 @@ class FederationHandler(BaseHandler):
|
||||
current_state = set(e.event_id for e in auth_events.values())
|
||||
different_auth = event_auth_events - current_state
|
||||
|
||||
context.current_state.update(auth_events)
|
||||
context.current_state_ids.update({
|
||||
k: a.event_id for k, a in auth_events.items()
|
||||
})
|
||||
context.state_group = None
|
||||
|
||||
if different_auth and not event.internal_metadata.is_outlier():
|
||||
@ -1514,8 +1559,8 @@ class FederationHandler(BaseHandler):
|
||||
|
||||
if do_resolution:
|
||||
# 1. Get what we think is the auth chain.
|
||||
auth_ids = self.auth.compute_auth_events(
|
||||
event, context.current_state
|
||||
auth_ids = yield self.auth.compute_auth_events(
|
||||
event, context.current_state_ids
|
||||
)
|
||||
local_auth_chain = yield self.store.get_auth_chain(auth_ids)
|
||||
|
||||
@ -1571,7 +1616,9 @@ class FederationHandler(BaseHandler):
|
||||
# 4. Look at rejects and their proofs.
|
||||
# TODO.
|
||||
|
||||
context.current_state.update(auth_events)
|
||||
context.current_state_ids.update({
|
||||
k: a.event_id for k, a in auth_events.items()
|
||||
})
|
||||
context.state_group = None
|
||||
|
||||
try:
|
||||
@ -1758,12 +1805,12 @@ class FederationHandler(BaseHandler):
|
||||
)
|
||||
|
||||
try:
|
||||
self.auth.check(event, context.current_state)
|
||||
yield self.auth.check_from_context(event, context)
|
||||
except AuthError as e:
|
||||
logger.warn("Denying new third party invite %r because %s", event, e)
|
||||
raise e
|
||||
|
||||
yield self._check_signature(event, auth_events=context.current_state)
|
||||
yield self._check_signature(event, context)
|
||||
member_handler = self.hs.get_handlers().room_member_handler
|
||||
yield member_handler.send_membership_event(None, event, context)
|
||||
else:
|
||||
@ -1789,11 +1836,11 @@ class FederationHandler(BaseHandler):
|
||||
)
|
||||
|
||||
try:
|
||||
self.auth.check(event, auth_events=context.current_state)
|
||||
self.auth.check_from_context(event, context)
|
||||
except AuthError as e:
|
||||
logger.warn("Denying third party invite %r because %s", event, e)
|
||||
raise e
|
||||
yield self._check_signature(event, auth_events=context.current_state)
|
||||
yield self._check_signature(event, context)
|
||||
|
||||
returned_invite = yield self.send_invite(origin, event)
|
||||
# TODO: Make sure the signatures actually are correct.
|
||||
@ -1807,7 +1854,12 @@ class FederationHandler(BaseHandler):
|
||||
EventTypes.ThirdPartyInvite,
|
||||
event.content["third_party_invite"]["signed"]["token"]
|
||||
)
|
||||
original_invite = context.current_state.get(key)
|
||||
original_invite = None
|
||||
original_invite_id = context.current_state_ids.get(key)
|
||||
if original_invite_id:
|
||||
original_invite = yield self.store.get_event(
|
||||
original_invite_id, allow_none=True
|
||||
)
|
||||
if not original_invite:
|
||||
logger.info(
|
||||
"Could not find invite event for third_party_invite - "
|
||||
@ -1824,13 +1876,13 @@ class FederationHandler(BaseHandler):
|
||||
defer.returnValue((event, context))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_signature(self, event, auth_events):
|
||||
def _check_signature(self, event, context):
|
||||
"""
|
||||
Checks that the signature in the event is consistent with its invite.
|
||||
|
||||
Args:
|
||||
event (Event): The m.room.member event to check
|
||||
auth_events (dict<(event type, state_key), event>):
|
||||
context (EventContext):
|
||||
|
||||
Raises:
|
||||
AuthError: if signature didn't match any keys, or key has been
|
||||
@ -1841,10 +1893,14 @@ class FederationHandler(BaseHandler):
|
||||
signed = event.content["third_party_invite"]["signed"]
|
||||
token = signed["token"]
|
||||
|
||||
invite_event = auth_events.get(
|
||||
invite_event_id = context.current_state_ids.get(
|
||||
(EventTypes.ThirdPartyInvite, token,)
|
||||
)
|
||||
|
||||
invite_event = None
|
||||
if invite_event_id:
|
||||
invite_event = yield self.store.get_event(invite_event_id, allow_none=True)
|
||||
|
||||
if not invite_event:
|
||||
raise AuthError(403, "Could not find invite")
|
||||
|
||||
|
@ -30,6 +30,7 @@ from synapse.util.async import concurrently_execute, run_on_reactor, ReadWriteLo
|
||||
from synapse.util.caches.snapshot_cache import SnapshotCache
|
||||
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
||||
from synapse.util.metrics import measure_func
|
||||
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
||||
from synapse.visibility import filter_events_for_client
|
||||
|
||||
from ._base import BaseHandler
|
||||
@ -248,7 +249,7 @@ class MessageHandler(BaseHandler):
|
||||
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
|
||||
|
||||
if event.is_state():
|
||||
prev_state = self.deduplicate_state_event(event, context)
|
||||
prev_state = yield self.deduplicate_state_event(event, context)
|
||||
if prev_state is not None:
|
||||
defer.returnValue(prev_state)
|
||||
|
||||
@ -263,6 +264,7 @@ class MessageHandler(BaseHandler):
|
||||
presence = self.hs.get_presence_handler()
|
||||
yield presence.bump_presence_active_time(user)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def deduplicate_state_event(self, event, context):
|
||||
"""
|
||||
Checks whether event is in the latest resolved state in context.
|
||||
@ -270,13 +272,17 @@ class MessageHandler(BaseHandler):
|
||||
If so, returns the version of the event in context.
|
||||
Otherwise, returns None.
|
||||
"""
|
||||
prev_event = context.current_state.get((event.type, event.state_key))
|
||||
prev_event_id = context.current_state_ids.get((event.type, event.state_key))
|
||||
prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
|
||||
if not prev_event:
|
||||
return
|
||||
|
||||
if prev_event and event.user_id == prev_event.user_id:
|
||||
prev_content = encode_canonical_json(prev_event.content)
|
||||
next_content = encode_canonical_json(event.content)
|
||||
if prev_content == next_content:
|
||||
return prev_event
|
||||
return None
|
||||
defer.returnValue(prev_event)
|
||||
return
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def create_and_send_nonmember_event(
|
||||
@ -803,7 +809,7 @@ class MessageHandler(BaseHandler):
|
||||
|
||||
logger.debug(
|
||||
"Created event %s with current state: %s",
|
||||
event.event_id, context.current_state,
|
||||
event.event_id, context.current_state_ids,
|
||||
)
|
||||
|
||||
defer.returnValue(
|
||||
@ -826,12 +832,12 @@ class MessageHandler(BaseHandler):
|
||||
self.ratelimit(requester)
|
||||
|
||||
try:
|
||||
self.auth.check(event, auth_events=context.current_state)
|
||||
yield self.auth.check_from_context(event, context)
|
||||
except AuthError as err:
|
||||
logger.warn("Denying new event %r because %s", event, err)
|
||||
raise err
|
||||
|
||||
yield self.maybe_kick_guest_users(event, context.current_state.values())
|
||||
yield self.maybe_kick_guest_users(event, context)
|
||||
|
||||
if event.type == EventTypes.CanonicalAlias:
|
||||
# Check the alias is acually valid (at this time at least)
|
||||
@ -859,6 +865,15 @@ class MessageHandler(BaseHandler):
|
||||
e.sender == event.sender
|
||||
)
|
||||
|
||||
state_to_include_ids = [
|
||||
e_id
|
||||
for k, e_id in context.current_state_ids.items()
|
||||
if k[0] in self.hs.config.room_invite_state_types
|
||||
or k[0] == EventTypes.Member and k[1] == event.sender
|
||||
]
|
||||
|
||||
state_to_include = yield self.store.get_events(state_to_include_ids)
|
||||
|
||||
event.unsigned["invite_room_state"] = [
|
||||
{
|
||||
"type": e.type,
|
||||
@ -866,9 +881,7 @@ class MessageHandler(BaseHandler):
|
||||
"content": e.content,
|
||||
"sender": e.sender,
|
||||
}
|
||||
for k, e in context.current_state.items()
|
||||
if e.type in self.hs.config.room_invite_state_types
|
||||
or is_inviter_member_event(e)
|
||||
for e in state_to_include.values()
|
||||
]
|
||||
|
||||
invitee = UserID.from_string(event.state_key)
|
||||
@ -890,7 +903,14 @@ class MessageHandler(BaseHandler):
|
||||
)
|
||||
|
||||
if event.type == EventTypes.Redaction:
|
||||
if self.auth.check_redaction(event, auth_events=context.current_state):
|
||||
auth_events_ids = yield self.auth.compute_auth_events(
|
||||
event, context.current_state_ids, for_verification=True,
|
||||
)
|
||||
auth_events = yield self.store.get_events(auth_events_ids)
|
||||
auth_events = {
|
||||
(e.type, e.state_key): e for e in auth_events.values()
|
||||
}
|
||||
if self.auth.check_redaction(event, auth_events=auth_events):
|
||||
original_event = yield self.store.get_event(
|
||||
event.redacts,
|
||||
check_redacted=False,
|
||||
@ -904,7 +924,7 @@ class MessageHandler(BaseHandler):
|
||||
"You don't have permission to redact events"
|
||||
)
|
||||
|
||||
if event.type == EventTypes.Create and context.current_state:
|
||||
if event.type == EventTypes.Create and context.current_state_ids:
|
||||
raise AuthError(
|
||||
403,
|
||||
"Changing the room create event is forbidden",
|
||||
@ -925,16 +945,7 @@ class MessageHandler(BaseHandler):
|
||||
event_stream_id, max_stream_id
|
||||
)
|
||||
|
||||
destinations = set()
|
||||
for k, s in context.current_state.items():
|
||||
try:
|
||||
if k[0] == EventTypes.Member:
|
||||
if s.content["membership"] == Membership.JOIN:
|
||||
destinations.add(get_domain_from_id(s.state_key))
|
||||
except SynapseError:
|
||||
logger.warn(
|
||||
"Failed to get destination from event %s", s.event_id
|
||||
)
|
||||
destinations = yield self.get_joined_hosts_for_room_from_state(context)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _notify():
|
||||
@ -952,3 +963,39 @@ class MessageHandler(BaseHandler):
|
||||
preserve_fn(federation_handler.handle_new_event)(
|
||||
event, destinations=destinations,
|
||||
)
|
||||
|
||||
def get_joined_hosts_for_room_from_state(self, context):
|
||||
state_group = context.state_group
|
||||
if not state_group:
|
||||
# If state_group is None it means it has yet to be assigned a
|
||||
# state group, i.e. we need to make sure that calls with a state_group
|
||||
# of None don't hit previous cached calls with a None state_group.
|
||||
# To do this we set the state_group to a new object as object() != object()
|
||||
state_group = object()
|
||||
|
||||
return self._get_joined_hosts_for_room_from_state(
|
||||
state_group, context.current_state_ids
|
||||
)
|
||||
|
||||
@cachedInlineCallbacks(num_args=1, cache_context=True)
|
||||
def _get_joined_hosts_for_room_from_state(self, state_group, current_state_ids,
|
||||
cache_context):
|
||||
|
||||
# Don't bother getting state for people on the same HS
|
||||
current_state = yield self.store.get_events([
|
||||
e_id for key, e_id in current_state_ids.items()
|
||||
if key[0] == EventTypes.Member and not self.hs.is_mine_id(key[1])
|
||||
])
|
||||
|
||||
destinations = set()
|
||||
for e in current_state.itervalues():
|
||||
try:
|
||||
if e.type == EventTypes.Member:
|
||||
if e.content["membership"] == Membership.JOIN:
|
||||
destinations.add(get_domain_from_id(e.state_key))
|
||||
except SynapseError:
|
||||
logger.warn(
|
||||
"Failed to get destination from event %s", e.event_id
|
||||
)
|
||||
|
||||
defer.returnValue(destinations)
|
||||
|
@ -93,20 +93,26 @@ class RoomMemberHandler(BaseHandler):
|
||||
ratelimit=ratelimit,
|
||||
)
|
||||
|
||||
prev_member_event = context.current_state.get(
|
||||
prev_member_event_id = context.current_state_ids.get(
|
||||
(EventTypes.Member, target.to_string()),
|
||||
None
|
||||
)
|
||||
|
||||
if event.membership == Membership.JOIN:
|
||||
if not prev_member_event or prev_member_event.membership != Membership.JOIN:
|
||||
# Only fire user_joined_room if the user has acutally joined the
|
||||
# room. Don't bother if the user is just changing their profile
|
||||
# info.
|
||||
# Only fire user_joined_room if the user has acutally joined the
|
||||
# room. Don't bother if the user is just changing their profile
|
||||
# info.
|
||||
newly_joined = True
|
||||
if prev_member_event_id:
|
||||
prev_member_event = yield self.store.get_event(prev_member_event_id)
|
||||
newly_joined = prev_member_event.membership != Membership.JOIN
|
||||
if newly_joined:
|
||||
yield user_joined_room(self.distributor, target, room_id)
|
||||
elif event.membership == Membership.LEAVE:
|
||||
if prev_member_event and prev_member_event.membership == Membership.JOIN:
|
||||
user_left_room(self.distributor, target, room_id)
|
||||
if prev_member_event_id:
|
||||
prev_member_event = yield self.store.get_event(prev_member_event_id)
|
||||
if prev_member_event.membership == Membership.JOIN:
|
||||
user_left_room(self.distributor, target, room_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def remote_join(self, remote_room_hosts, room_id, user, content):
|
||||
@ -195,29 +201,32 @@ class RoomMemberHandler(BaseHandler):
|
||||
remote_room_hosts = []
|
||||
|
||||
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
|
||||
current_state = yield self.state_handler.get_current_state(
|
||||
current_state_ids = yield self.state_handler.get_current_state_ids(
|
||||
room_id, latest_event_ids=latest_event_ids,
|
||||
)
|
||||
|
||||
old_state = current_state.get((EventTypes.Member, target.to_string()))
|
||||
old_membership = old_state.content.get("membership") if old_state else None
|
||||
if action == "unban" and old_membership != "ban":
|
||||
raise SynapseError(
|
||||
403,
|
||||
"Cannot unban user who was not banned (membership=%s)" % old_membership,
|
||||
errcode=Codes.BAD_STATE
|
||||
)
|
||||
if old_membership == "ban" and action != "unban":
|
||||
raise SynapseError(
|
||||
403,
|
||||
"Cannot %s user who was banned" % (action,),
|
||||
errcode=Codes.BAD_STATE
|
||||
)
|
||||
old_state_id = current_state_ids.get((EventTypes.Member, target.to_string()))
|
||||
if old_state_id:
|
||||
old_state = yield self.store.get_event(old_state_id, allow_none=True)
|
||||
old_membership = old_state.content.get("membership") if old_state else None
|
||||
if action == "unban" and old_membership != "ban":
|
||||
raise SynapseError(
|
||||
403,
|
||||
"Cannot unban user who was not banned"
|
||||
" (membership=%s)" % old_membership,
|
||||
errcode=Codes.BAD_STATE
|
||||
)
|
||||
if old_membership == "ban" and action != "unban":
|
||||
raise SynapseError(
|
||||
403,
|
||||
"Cannot %s user who was banned" % (action,),
|
||||
errcode=Codes.BAD_STATE
|
||||
)
|
||||
|
||||
is_host_in_room = self.is_host_in_room(current_state)
|
||||
is_host_in_room = yield self._is_host_in_room(current_state_ids)
|
||||
|
||||
if effective_membership_state == Membership.JOIN:
|
||||
if requester.is_guest and not self._can_guest_join(current_state):
|
||||
if requester.is_guest and not self._can_guest_join(current_state_ids):
|
||||
# This should be an auth check, but guests are a local concept,
|
||||
# so don't really fit into the general auth process.
|
||||
raise AuthError(403, "Guest access not allowed")
|
||||
@ -326,15 +335,17 @@ class RoomMemberHandler(BaseHandler):
|
||||
requester = synapse.types.create_requester(target_user)
|
||||
|
||||
message_handler = self.hs.get_handlers().message_handler
|
||||
prev_event = message_handler.deduplicate_state_event(event, context)
|
||||
prev_event = yield message_handler.deduplicate_state_event(event, context)
|
||||
if prev_event is not None:
|
||||
return
|
||||
|
||||
if event.membership == Membership.JOIN:
|
||||
if requester.is_guest and not self._can_guest_join(context.current_state):
|
||||
# This should be an auth check, but guests are a local concept,
|
||||
# so don't really fit into the general auth process.
|
||||
raise AuthError(403, "Guest access not allowed")
|
||||
if requester.is_guest:
|
||||
guest_can_join = yield self._can_guest_join(context.current_state_ids)
|
||||
if not guest_can_join:
|
||||
# This should be an auth check, but guests are a local concept,
|
||||
# so don't really fit into the general auth process.
|
||||
raise AuthError(403, "Guest access not allowed")
|
||||
|
||||
yield message_handler.handle_new_client_event(
|
||||
requester,
|
||||
@ -344,27 +355,39 @@ class RoomMemberHandler(BaseHandler):
|
||||
ratelimit=ratelimit,
|
||||
)
|
||||
|
||||
prev_member_event = context.current_state.get(
|
||||
(EventTypes.Member, target_user.to_string()),
|
||||
prev_member_event_id = context.current_state_ids.get(
|
||||
(EventTypes.Member, event.state_key),
|
||||
None
|
||||
)
|
||||
|
||||
if event.membership == Membership.JOIN:
|
||||
if not prev_member_event or prev_member_event.membership != Membership.JOIN:
|
||||
# Only fire user_joined_room if the user has acutally joined the
|
||||
# room. Don't bother if the user is just changing their profile
|
||||
# info.
|
||||
# Only fire user_joined_room if the user has acutally joined the
|
||||
# room. Don't bother if the user is just changing their profile
|
||||
# info.
|
||||
newly_joined = True
|
||||
if prev_member_event_id:
|
||||
prev_member_event = yield self.store.get_event(prev_member_event_id)
|
||||
newly_joined = prev_member_event.membership != Membership.JOIN
|
||||
if newly_joined:
|
||||
yield user_joined_room(self.distributor, target_user, room_id)
|
||||
elif event.membership == Membership.LEAVE:
|
||||
if prev_member_event and prev_member_event.membership == Membership.JOIN:
|
||||
user_left_room(self.distributor, target_user, room_id)
|
||||
if prev_member_event_id:
|
||||
prev_member_event = yield self.store.get_event(prev_member_event_id)
|
||||
if prev_member_event.membership == Membership.JOIN:
|
||||
user_left_room(self.distributor, target_user, room_id)
|
||||
|
||||
def _can_guest_join(self, current_state):
|
||||
@defer.inlineCallbacks
|
||||
def _can_guest_join(self, current_state_ids):
|
||||
"""
|
||||
Returns whether a guest can join a room based on its current state.
|
||||
"""
|
||||
guest_access = current_state.get((EventTypes.GuestAccess, ""), None)
|
||||
return (
|
||||
guest_access_id = current_state_ids.get((EventTypes.GuestAccess, ""), None)
|
||||
if not guest_access_id:
|
||||
defer.returnValue(False)
|
||||
|
||||
guest_access = yield self.store.get_event(guest_access_id)
|
||||
|
||||
defer.returnValue(
|
||||
guest_access
|
||||
and guest_access.content
|
||||
and "guest_access" in guest_access.content
|
||||
@ -683,3 +706,24 @@ class RoomMemberHandler(BaseHandler):
|
||||
|
||||
if membership:
|
||||
yield self.store.forget(user_id, room_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _is_host_in_room(self, current_state_ids):
|
||||
# Have we just created the room, and is this about to be the very
|
||||
# first member event?
|
||||
create_event_id = current_state_ids.get(("m.room.create", ""))
|
||||
if len(current_state_ids) == 1 and create_event_id:
|
||||
defer.returnValue(self.hs.is_mine_id(create_event_id))
|
||||
|
||||
for (etype, state_key), event_id in current_state_ids.items():
|
||||
if etype != EventTypes.Member or not self.hs.is_mine_id(state_key):
|
||||
continue
|
||||
|
||||
event = yield self.store.get_event(event_id, allow_none=True)
|
||||
if not event:
|
||||
continue
|
||||
|
||||
if event.membership == Membership.JOIN:
|
||||
defer.returnValue(True)
|
||||
|
||||
defer.returnValue(False)
|
||||
|
@ -358,11 +358,11 @@ class SyncHandler(object):
|
||||
Returns:
|
||||
A Deferred map from ((type, state_key)->Event)
|
||||
"""
|
||||
state = yield self.store.get_state_for_event(event.event_id)
|
||||
state_ids = yield self.store.get_state_ids_for_event(event.event_id)
|
||||
if event.is_state():
|
||||
state = state.copy()
|
||||
state[(event.type, event.state_key)] = event
|
||||
defer.returnValue(state)
|
||||
state_ids = state_ids.copy()
|
||||
state_ids[(event.type, event.state_key)] = event.event_id
|
||||
defer.returnValue(state_ids)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_state_at(self, room_id, stream_position):
|
||||
@ -415,57 +415,61 @@ class SyncHandler(object):
|
||||
with Measure(self.clock, "compute_state_delta"):
|
||||
if full_state:
|
||||
if batch:
|
||||
current_state = yield self.store.get_state_for_event(
|
||||
current_state_ids = yield self.store.get_state_ids_for_event(
|
||||
batch.events[-1].event_id
|
||||
)
|
||||
|
||||
state = yield self.store.get_state_for_event(
|
||||
state_ids = yield self.store.get_state_ids_for_event(
|
||||
batch.events[0].event_id
|
||||
)
|
||||
else:
|
||||
current_state = yield self.get_state_at(
|
||||
current_state_ids = yield self.get_state_at(
|
||||
room_id, stream_position=now_token
|
||||
)
|
||||
|
||||
state = current_state
|
||||
state_ids = current_state_ids
|
||||
|
||||
timeline_state = {
|
||||
(event.type, event.state_key): event
|
||||
(event.type, event.state_key): event.event_id
|
||||
for event in batch.events if event.is_state()
|
||||
}
|
||||
|
||||
state = _calculate_state(
|
||||
state_ids = _calculate_state(
|
||||
timeline_contains=timeline_state,
|
||||
timeline_start=state,
|
||||
timeline_start=state_ids,
|
||||
previous={},
|
||||
current=current_state,
|
||||
current=current_state_ids,
|
||||
)
|
||||
elif batch.limited:
|
||||
state_at_previous_sync = yield self.get_state_at(
|
||||
room_id, stream_position=since_token
|
||||
)
|
||||
|
||||
current_state = yield self.store.get_state_for_event(
|
||||
current_state_ids = yield self.store.get_state_ids_for_event(
|
||||
batch.events[-1].event_id
|
||||
)
|
||||
|
||||
state_at_timeline_start = yield self.store.get_state_for_event(
|
||||
state_at_timeline_start = yield self.store.get_state_ids_for_event(
|
||||
batch.events[0].event_id
|
||||
)
|
||||
|
||||
timeline_state = {
|
||||
(event.type, event.state_key): event
|
||||
(event.type, event.state_key): event.event_id
|
||||
for event in batch.events if event.is_state()
|
||||
}
|
||||
|
||||
state = _calculate_state(
|
||||
state_ids = _calculate_state(
|
||||
timeline_contains=timeline_state,
|
||||
timeline_start=state_at_timeline_start,
|
||||
previous=state_at_previous_sync,
|
||||
current=current_state,
|
||||
current=current_state_ids,
|
||||
)
|
||||
else:
|
||||
state = {}
|
||||
state_ids = {}
|
||||
|
||||
state = {}
|
||||
if state_ids:
|
||||
state = yield self.store.get_events(state_ids.values())
|
||||
|
||||
defer.returnValue({
|
||||
(e.type, e.state_key): e
|
||||
@ -806,8 +810,13 @@ class SyncHandler(object):
|
||||
# the last sync (even if we have since left). This is to make sure
|
||||
# we do send down the room, and with full state, where necessary
|
||||
if room_id in joined_room_ids or has_join:
|
||||
old_state = yield self.get_state_at(room_id, since_token)
|
||||
old_mem_ev = old_state.get((EventTypes.Member, user_id), None)
|
||||
old_state_ids = yield self.get_state_at(room_id, since_token)
|
||||
old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None)
|
||||
old_mem_ev = None
|
||||
if old_mem_ev_id:
|
||||
old_mem_ev = yield self.store.get_event(
|
||||
old_mem_ev_id, allow_none=True
|
||||
)
|
||||
if not old_mem_ev or old_mem_ev.membership != Membership.JOIN:
|
||||
newly_joined_rooms.append(room_id)
|
||||
|
||||
@ -1099,27 +1108,25 @@ def _calculate_state(timeline_contains, timeline_start, previous, current):
|
||||
Returns:
|
||||
dict
|
||||
"""
|
||||
event_id_to_state = {
|
||||
e.event_id: e
|
||||
for e in itertools.chain(
|
||||
timeline_contains.values(),
|
||||
previous.values(),
|
||||
timeline_start.values(),
|
||||
current.values(),
|
||||
event_id_to_key = {
|
||||
e: key
|
||||
for key, e in itertools.chain(
|
||||
timeline_contains.items(),
|
||||
previous.items(),
|
||||
timeline_start.items(),
|
||||
current.items(),
|
||||
)
|
||||
}
|
||||
|
||||
c_ids = set(e.event_id for e in current.values())
|
||||
tc_ids = set(e.event_id for e in timeline_contains.values())
|
||||
p_ids = set(e.event_id for e in previous.values())
|
||||
ts_ids = set(e.event_id for e in timeline_start.values())
|
||||
c_ids = set(e for e in current.values())
|
||||
tc_ids = set(e for e in timeline_contains.values())
|
||||
p_ids = set(e for e in previous.values())
|
||||
ts_ids = set(e for e in timeline_start.values())
|
||||
|
||||
state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids
|
||||
|
||||
evs = (event_id_to_state[e] for e in state_ids)
|
||||
return {
|
||||
(e.type, e.state_key): e
|
||||
for e in evs
|
||||
event_id_to_key[e]: e for e in state_ids
|
||||
}
|
||||
|
||||
|
||||
|
@ -40,12 +40,12 @@ class ActionGenerator:
|
||||
def handle_push_actions_for_event(self, event, context):
|
||||
with Measure(self.clock, "evaluator_for_event"):
|
||||
bulk_evaluator = yield evaluator_for_event(
|
||||
event, self.hs, self.store, context.state_group, context.current_state
|
||||
event, self.hs, self.store, context
|
||||
)
|
||||
|
||||
with Measure(self.clock, "action_for_event_by_user"):
|
||||
actions_by_user = yield bulk_evaluator.action_for_event_by_user(
|
||||
event, context.current_state
|
||||
event, context
|
||||
)
|
||||
|
||||
context.push_actions = [
|
||||
|
@ -19,8 +19,8 @@ from twisted.internet import defer
|
||||
|
||||
from .push_rule_evaluator import PushRuleEvaluatorForEvent
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.visibility import filter_events_for_clients
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.visibility import filter_events_for_clients_context
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -36,9 +36,9 @@ def _get_rules(room_id, user_ids, store):
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def evaluator_for_event(event, hs, store, state_group, current_state):
|
||||
def evaluator_for_event(event, hs, store, context):
|
||||
rules_by_user = yield store.bulk_get_push_rules_for_room(
|
||||
event.room_id, state_group, current_state
|
||||
event.room_id, context
|
||||
)
|
||||
|
||||
# if this event is an invite event, we may need to run rules for the user
|
||||
@ -72,7 +72,7 @@ class BulkPushRuleEvaluator:
|
||||
self.store = store
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def action_for_event_by_user(self, event, current_state):
|
||||
def action_for_event_by_user(self, event, context):
|
||||
actions_by_user = {}
|
||||
|
||||
# None of these users can be peeking since this list of users comes
|
||||
@ -82,27 +82,25 @@ class BulkPushRuleEvaluator:
|
||||
(u, False) for u in self.rules_by_user.keys()
|
||||
]
|
||||
|
||||
filtered_by_user = yield filter_events_for_clients(
|
||||
self.store, user_tuples, [event], {event.event_id: current_state}
|
||||
filtered_by_user = yield filter_events_for_clients_context(
|
||||
self.store, user_tuples, [event], {event.event_id: context}
|
||||
)
|
||||
|
||||
room_members = set(
|
||||
e.state_key for e in current_state.values()
|
||||
if e.type == EventTypes.Member and e.membership == Membership.JOIN
|
||||
room_members = yield self.store.get_joined_users_from_context(
|
||||
event.room_id, context,
|
||||
)
|
||||
|
||||
evaluator = PushRuleEvaluatorForEvent(event, len(room_members))
|
||||
|
||||
condition_cache = {}
|
||||
|
||||
display_names = {}
|
||||
for ev in current_state.values():
|
||||
nm = ev.content.get("displayname", None)
|
||||
if nm and ev.type == EventTypes.Member:
|
||||
display_names[ev.state_key] = nm
|
||||
|
||||
for uid, rules in self.rules_by_user.items():
|
||||
display_name = display_names.get(uid, None)
|
||||
display_name = None
|
||||
member_ev_id = context.current_state_ids.get((EventTypes.Member, uid))
|
||||
if member_ev_id:
|
||||
member_ev = yield self.store.get_event(member_ev_id, allow_none=True)
|
||||
if member_ev:
|
||||
display_name = member_ev.content.get("displayname", None)
|
||||
|
||||
filtered = filtered_by_user[uid]
|
||||
if len(filtered) == 0:
|
||||
|
@ -245,7 +245,7 @@ class HttpPusher(object):
|
||||
@defer.inlineCallbacks
|
||||
def _build_notification_dict(self, event, tweaks, badge):
|
||||
ctx = yield push_tools.get_context_for_event(
|
||||
self.state_handler, event, self.user_id
|
||||
self.store, self.state_handler, event, self.user_id
|
||||
)
|
||||
|
||||
d = {
|
||||
|
@ -22,7 +22,7 @@ from email.mime.text import MIMEText
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
|
||||
from synapse.util.async import concurrently_execute
|
||||
from synapse.util.presentable_names import (
|
||||
from synapse.push.presentable_names import (
|
||||
calculate_room_name, name_from_member_event, descriptor_from_member_events
|
||||
)
|
||||
from synapse.types import UserID
|
||||
@ -139,7 +139,7 @@ class Mailer(object):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _fetch_room_state(room_id):
|
||||
room_state = yield self.state_handler.get_current_state(room_id)
|
||||
room_state = yield self.state_handler.get_current_state_ids(room_id)
|
||||
state_by_room[room_id] = room_state
|
||||
|
||||
# Run at most 3 of these at once: sync does 10 at a time but email
|
||||
@ -159,11 +159,12 @@ class Mailer(object):
|
||||
)
|
||||
rooms.append(roomvars)
|
||||
|
||||
reason['room_name'] = calculate_room_name(
|
||||
state_by_room[reason['room_id']], user_id, fallback_to_members=True
|
||||
reason['room_name'] = yield calculate_room_name(
|
||||
self.store, state_by_room[reason['room_id']], user_id,
|
||||
fallback_to_members=True
|
||||
)
|
||||
|
||||
summary_text = self.make_summary_text(
|
||||
summary_text = yield self.make_summary_text(
|
||||
notifs_by_room, state_by_room, notif_events, user_id, reason
|
||||
)
|
||||
|
||||
@ -203,12 +204,15 @@ class Mailer(object):
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_room_vars(self, room_id, user_id, notifs, notif_events, room_state):
|
||||
my_member_event = room_state[("m.room.member", user_id)]
|
||||
def get_room_vars(self, room_id, user_id, notifs, notif_events, room_state_ids):
|
||||
my_member_event_id = room_state_ids[("m.room.member", user_id)]
|
||||
my_member_event = yield self.store.get_event(my_member_event_id)
|
||||
is_invite = my_member_event.content["membership"] == "invite"
|
||||
|
||||
room_name = yield calculate_room_name(self.store, room_state_ids, user_id)
|
||||
|
||||
room_vars = {
|
||||
"title": calculate_room_name(room_state, user_id),
|
||||
"title": room_name,
|
||||
"hash": string_ordinal_total(room_id), # See sender avatar hash
|
||||
"notifs": [],
|
||||
"invite": is_invite,
|
||||
@ -218,7 +222,7 @@ class Mailer(object):
|
||||
if not is_invite:
|
||||
for n in notifs:
|
||||
notifvars = yield self.get_notif_vars(
|
||||
n, user_id, notif_events[n['event_id']], room_state
|
||||
n, user_id, notif_events[n['event_id']], room_state_ids
|
||||
)
|
||||
|
||||
# merge overlapping notifs together.
|
||||
@ -243,7 +247,7 @@ class Mailer(object):
|
||||
defer.returnValue(room_vars)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_notif_vars(self, notif, user_id, notif_event, room_state):
|
||||
def get_notif_vars(self, notif, user_id, notif_event, room_state_ids):
|
||||
results = yield self.store.get_events_around(
|
||||
notif['room_id'], notif['event_id'],
|
||||
before_limit=CONTEXT_BEFORE, after_limit=CONTEXT_AFTER
|
||||
@ -261,17 +265,19 @@ class Mailer(object):
|
||||
the_events.append(notif_event)
|
||||
|
||||
for event in the_events:
|
||||
messagevars = self.get_message_vars(notif, event, room_state)
|
||||
messagevars = yield self.get_message_vars(notif, event, room_state_ids)
|
||||
if messagevars is not None:
|
||||
ret['messages'].append(messagevars)
|
||||
|
||||
defer.returnValue(ret)
|
||||
|
||||
def get_message_vars(self, notif, event, room_state):
|
||||
@defer.inlineCallbacks
|
||||
def get_message_vars(self, notif, event, room_state_ids):
|
||||
if event.type != EventTypes.Message:
|
||||
return None
|
||||
return
|
||||
|
||||
sender_state_event = room_state[("m.room.member", event.sender)]
|
||||
sender_state_event_id = room_state_ids[("m.room.member", event.sender)]
|
||||
sender_state_event = yield self.store.get_event(sender_state_event_id)
|
||||
sender_name = name_from_member_event(sender_state_event)
|
||||
sender_avatar_url = sender_state_event.content.get("avatar_url")
|
||||
|
||||
@ -299,7 +305,7 @@ class Mailer(object):
|
||||
if "body" in event.content:
|
||||
ret["body_text_plain"] = event.content["body"]
|
||||
|
||||
return ret
|
||||
defer.returnValue(ret)
|
||||
|
||||
def add_text_message_vars(self, messagevars, event):
|
||||
msgformat = event.content.get("format")
|
||||
@ -321,6 +327,7 @@ class Mailer(object):
|
||||
|
||||
return messagevars
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def make_summary_text(self, notifs_by_room, state_by_room,
|
||||
notif_events, user_id, reason):
|
||||
if len(notifs_by_room) == 1:
|
||||
@ -330,7 +337,7 @@ class Mailer(object):
|
||||
# If the room has some kind of name, use it, but we don't
|
||||
# want the generated-from-names one here otherwise we'll
|
||||
# end up with, "new message from Bob in the Bob room"
|
||||
room_name = calculate_room_name(
|
||||
room_name = yield calculate_room_name(
|
||||
state_by_room[room_id], user_id, fallback_to_members=False
|
||||
)
|
||||
|
||||
@ -342,16 +349,16 @@ class Mailer(object):
|
||||
inviter_name = name_from_member_event(inviter_member_event)
|
||||
|
||||
if room_name is None:
|
||||
return INVITE_FROM_PERSON % {
|
||||
defer.returnValue(INVITE_FROM_PERSON % {
|
||||
"person": inviter_name,
|
||||
"app": self.app_name
|
||||
}
|
||||
})
|
||||
else:
|
||||
return INVITE_FROM_PERSON_TO_ROOM % {
|
||||
defer.returnValue(INVITE_FROM_PERSON_TO_ROOM % {
|
||||
"person": inviter_name,
|
||||
"room": room_name,
|
||||
"app": self.app_name,
|
||||
}
|
||||
})
|
||||
|
||||
sender_name = None
|
||||
if len(notifs_by_room[room_id]) == 1:
|
||||
@ -362,24 +369,24 @@ class Mailer(object):
|
||||
sender_name = name_from_member_event(state_event)
|
||||
|
||||
if sender_name is not None and room_name is not None:
|
||||
return MESSAGE_FROM_PERSON_IN_ROOM % {
|
||||
defer.returnValue(MESSAGE_FROM_PERSON_IN_ROOM % {
|
||||
"person": sender_name,
|
||||
"room": room_name,
|
||||
"app": self.app_name,
|
||||
}
|
||||
})
|
||||
elif sender_name is not None:
|
||||
return MESSAGE_FROM_PERSON % {
|
||||
defer.returnValue(MESSAGE_FROM_PERSON % {
|
||||
"person": sender_name,
|
||||
"app": self.app_name,
|
||||
}
|
||||
})
|
||||
else:
|
||||
# There's more than one notification for this room, so just
|
||||
# say there are several
|
||||
if room_name is not None:
|
||||
return MESSAGES_IN_ROOM % {
|
||||
defer.returnValue(MESSAGES_IN_ROOM % {
|
||||
"room": room_name,
|
||||
"app": self.app_name,
|
||||
}
|
||||
})
|
||||
else:
|
||||
# If the room doesn't have a name, say who the messages
|
||||
# are from explicitly to avoid, "messages in the Bob room"
|
||||
@ -388,22 +395,22 @@ class Mailer(object):
|
||||
for n in notifs_by_room[room_id]
|
||||
]))
|
||||
|
||||
return MESSAGES_FROM_PERSON % {
|
||||
defer.returnValue(MESSAGES_FROM_PERSON % {
|
||||
"person": descriptor_from_member_events([
|
||||
state_by_room[room_id][("m.room.member", s)]
|
||||
for s in sender_ids
|
||||
]),
|
||||
"app": self.app_name,
|
||||
}
|
||||
})
|
||||
else:
|
||||
# Stuff's happened in multiple different rooms
|
||||
|
||||
# ...but we still refer to the 'reason' room which triggered the mail
|
||||
if reason['room_name'] is not None:
|
||||
return MESSAGES_IN_ROOM_AND_OTHERS % {
|
||||
defer.returnValue(MESSAGES_IN_ROOM_AND_OTHERS % {
|
||||
"room": reason['room_name'],
|
||||
"app": self.app_name,
|
||||
}
|
||||
})
|
||||
else:
|
||||
# If the reason room doesn't have a name, say who the messages
|
||||
# are from explicitly to avoid, "messages in the Bob room"
|
||||
@ -412,13 +419,13 @@ class Mailer(object):
|
||||
for n in notifs_by_room[reason['room_id']]
|
||||
]))
|
||||
|
||||
return MESSAGES_FROM_PERSON_AND_OTHERS % {
|
||||
defer.returnValue(MESSAGES_FROM_PERSON_AND_OTHERS % {
|
||||
"person": descriptor_from_member_events([
|
||||
state_by_room[reason['room_id']][("m.room.member", s)]
|
||||
for s in sender_ids
|
||||
]),
|
||||
"app": self.app_name,
|
||||
}
|
||||
})
|
||||
|
||||
def make_room_link(self, room_id):
|
||||
# need /beta for Universal Links to work on iOS
|
||||
|
@ -13,6 +13,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
import re
|
||||
import logging
|
||||
|
||||
@ -25,7 +27,8 @@ ALIAS_RE = re.compile(r"^#.*:.+$")
|
||||
ALL_ALONE = "Empty Room"
|
||||
|
||||
|
||||
def calculate_room_name(room_state, user_id, fallback_to_members=True,
|
||||
@defer.inlineCallbacks
|
||||
def calculate_room_name(store, room_state_ids, user_id, fallback_to_members=True,
|
||||
fallback_to_single_member=True):
|
||||
"""
|
||||
Works out a user-facing name for the given room as per Matrix
|
||||
@ -42,59 +45,78 @@ def calculate_room_name(room_state, user_id, fallback_to_members=True,
|
||||
(string or None) A human readable name for the room.
|
||||
"""
|
||||
# does it have a name?
|
||||
if ("m.room.name", "") in room_state:
|
||||
m_room_name = room_state[("m.room.name", "")]
|
||||
if m_room_name.content and m_room_name.content["name"]:
|
||||
return m_room_name.content["name"]
|
||||
if ("m.room.name", "") in room_state_ids:
|
||||
m_room_name = yield store.get_event(
|
||||
room_state_ids[("m.room.name", "")], allow_none=True
|
||||
)
|
||||
if m_room_name and m_room_name.content and m_room_name.content["name"]:
|
||||
defer.returnValue(m_room_name.content["name"])
|
||||
|
||||
# does it have a canonical alias?
|
||||
if ("m.room.canonical_alias", "") in room_state:
|
||||
canon_alias = room_state[("m.room.canonical_alias", "")]
|
||||
if ("m.room.canonical_alias", "") in room_state_ids:
|
||||
canon_alias = yield store.get_event(
|
||||
room_state_ids[("m.room.canonical_alias", "")], allow_none=True
|
||||
)
|
||||
if (
|
||||
canon_alias.content and canon_alias.content["alias"] and
|
||||
canon_alias and canon_alias.content and canon_alias.content["alias"] and
|
||||
_looks_like_an_alias(canon_alias.content["alias"])
|
||||
):
|
||||
return canon_alias.content["alias"]
|
||||
defer.returnValue(canon_alias.content["alias"])
|
||||
|
||||
# at this point we're going to need to search the state by all state keys
|
||||
# for an event type, so rearrange the data structure
|
||||
room_state_bytype = _state_as_two_level_dict(room_state)
|
||||
room_state_bytype_ids = _state_as_two_level_dict(room_state_ids)
|
||||
|
||||
# right then, any aliases at all?
|
||||
if "m.room.aliases" in room_state_bytype:
|
||||
m_room_aliases = room_state_bytype["m.room.aliases"]
|
||||
if len(m_room_aliases.values()) > 0:
|
||||
first_alias_event = m_room_aliases.values()[0]
|
||||
if first_alias_event.content and first_alias_event.content["aliases"]:
|
||||
the_aliases = first_alias_event.content["aliases"]
|
||||
if "m.room.aliases" in room_state_bytype_ids:
|
||||
m_room_aliases = room_state_bytype_ids["m.room.aliases"]
|
||||
for alias_id in m_room_aliases.values():
|
||||
alias_event = yield store.get_event(
|
||||
alias_id, allow_none=True
|
||||
)
|
||||
if alias_event and alias_event.content and alias_event.get("aliases"):
|
||||
the_aliases = alias_event.content["aliases"]
|
||||
if len(the_aliases) > 0 and _looks_like_an_alias(the_aliases[0]):
|
||||
return the_aliases[0]
|
||||
defer.returnValue(the_aliases[0])
|
||||
|
||||
if not fallback_to_members:
|
||||
return None
|
||||
defer.returnValue(None)
|
||||
|
||||
my_member_event = None
|
||||
if ("m.room.member", user_id) in room_state:
|
||||
my_member_event = room_state[("m.room.member", user_id)]
|
||||
if ("m.room.member", user_id) in room_state_ids:
|
||||
my_member_event = yield store.get_event(
|
||||
room_state_ids[("m.room.member", user_id)], allow_none=True
|
||||
)
|
||||
|
||||
if (
|
||||
my_member_event is not None and
|
||||
my_member_event.content['membership'] == "invite"
|
||||
):
|
||||
if ("m.room.member", my_member_event.sender) in room_state:
|
||||
inviter_member_event = room_state[("m.room.member", my_member_event.sender)]
|
||||
if fallback_to_single_member:
|
||||
return "Invite from %s" % (name_from_member_event(inviter_member_event),)
|
||||
else:
|
||||
return None
|
||||
if ("m.room.member", my_member_event.sender) in room_state_ids:
|
||||
inviter_member_event = yield store.get_event(
|
||||
room_state_ids[("m.room.member", my_member_event.sender)],
|
||||
allow_none=True,
|
||||
)
|
||||
if inviter_member_event:
|
||||
if fallback_to_single_member:
|
||||
defer.returnValue(
|
||||
"Invite from %s" % (
|
||||
name_from_member_event(inviter_member_event),
|
||||
)
|
||||
)
|
||||
else:
|
||||
return
|
||||
else:
|
||||
return "Room Invite"
|
||||
defer.returnValue("Room Invite")
|
||||
|
||||
# we're going to have to generate a name based on who's in the room,
|
||||
# so find out who is in the room that isn't the user.
|
||||
if "m.room.member" in room_state_bytype:
|
||||
if "m.room.member" in room_state_bytype_ids:
|
||||
member_events = yield store.get_events(
|
||||
room_state_bytype_ids["m.room.member"].values()
|
||||
)
|
||||
all_members = [
|
||||
ev for ev in room_state_bytype["m.room.member"].values()
|
||||
ev for ev in member_events.values()
|
||||
if ev.content['membership'] == "join" or ev.content['membership'] == "invite"
|
||||
]
|
||||
# Sort the member events oldest-first so the we name people in the
|
||||
@ -111,9 +133,9 @@ def calculate_room_name(room_state, user_id, fallback_to_members=True,
|
||||
# self-chat, peeked room with 1 participant,
|
||||
# or inbound invite, or outbound 3PID invite.
|
||||
if all_members[0].sender == user_id:
|
||||
if "m.room.third_party_invite" in room_state_bytype:
|
||||
if "m.room.third_party_invite" in room_state_bytype_ids:
|
||||
third_party_invites = (
|
||||
room_state_bytype["m.room.third_party_invite"].values()
|
||||
room_state_bytype_ids["m.room.third_party_invite"].values()
|
||||
)
|
||||
|
||||
if len(third_party_invites) > 0:
|
||||
@ -126,17 +148,17 @@ def calculate_room_name(room_state, user_id, fallback_to_members=True,
|
||||
# return "Inviting %s" % (
|
||||
# descriptor_from_member_events(third_party_invites)
|
||||
# )
|
||||
return "Inviting email address"
|
||||
defer.returnValue("Inviting email address")
|
||||
else:
|
||||
return ALL_ALONE
|
||||
defer.returnValue(ALL_ALONE)
|
||||
else:
|
||||
return name_from_member_event(all_members[0])
|
||||
defer.returnValue(name_from_member_event(all_members[0]))
|
||||
else:
|
||||
return ALL_ALONE
|
||||
defer.returnValue(ALL_ALONE)
|
||||
elif len(other_members) == 1 and not fallback_to_single_member:
|
||||
return None
|
||||
return
|
||||
else:
|
||||
return descriptor_from_member_events(other_members)
|
||||
defer.returnValue(descriptor_from_member_events(other_members))
|
||||
|
||||
|
||||
def descriptor_from_member_events(member_events):
|
@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
from synapse.util.presentable_names import (
|
||||
from synapse.push.presentable_names import (
|
||||
calculate_room_name, name_from_member_event
|
||||
)
|
||||
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
||||
@ -49,21 +49,22 @@ def get_badge_count(store, user_id):
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_context_for_event(state_handler, ev, user_id):
|
||||
def get_context_for_event(store, state_handler, ev, user_id):
|
||||
ctx = {}
|
||||
|
||||
room_state = yield state_handler.get_current_state(ev.room_id)
|
||||
room_state_ids = yield state_handler.get_current_state_ids(ev.room_id)
|
||||
|
||||
# we no longer bother setting room_alias, and make room_name the
|
||||
# human-readable name instead, be that m.room.name, an alias or
|
||||
# a list of people in the room
|
||||
name = calculate_room_name(
|
||||
room_state, user_id, fallback_to_single_member=False
|
||||
name = yield calculate_room_name(
|
||||
store, room_state_ids, user_id, fallback_to_single_member=False
|
||||
)
|
||||
if name:
|
||||
ctx['name'] = name
|
||||
|
||||
sender_state_event = room_state[("m.room.member", ev.sender)]
|
||||
sender_state_event_id = room_state_ids[("m.room.member", ev.sender)]
|
||||
sender_state_event = yield store.get_event(sender_state_event_id)
|
||||
ctx['sender_display_name'] = name_from_member_event(sender_state_event)
|
||||
|
||||
defer.returnValue(ctx)
|
||||
|
@ -120,10 +120,15 @@ class SlavedEventStore(BaseSlavedStore):
|
||||
get_state_for_event = DataStore.get_state_for_event.__func__
|
||||
get_state_for_events = DataStore.get_state_for_events.__func__
|
||||
get_state_groups = DataStore.get_state_groups.__func__
|
||||
get_state_groups_ids = DataStore.get_state_groups_ids.__func__
|
||||
get_state_ids_for_event = DataStore.get_state_ids_for_event.__func__
|
||||
get_state_ids_for_events = DataStore.get_state_ids_for_events.__func__
|
||||
get_recent_events_for_room = DataStore.get_recent_events_for_room.__func__
|
||||
get_room_events_stream_for_rooms = (
|
||||
DataStore.get_room_events_stream_for_rooms.__func__
|
||||
)
|
||||
is_host_joined = DataStore.is_host_joined.__func__
|
||||
_is_host_joined = RoomMemberStore.__dict__["_is_host_joined"]
|
||||
get_stream_token_for_event = DataStore.get_stream_token_for_event.__func__
|
||||
|
||||
_set_before_and_after = staticmethod(DataStore._set_before_and_after)
|
||||
|
@ -18,8 +18,8 @@ import logging
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import ThirdPartyEntityKind
|
||||
from synapse.http.servlet import RestServlet
|
||||
from synapse.types import ThirdPartyEntityKind
|
||||
from ._base import client_v2_patterns
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
135
synapse/state.py
135
synapse/state.py
@ -93,8 +93,30 @@ class StateHandler(object):
|
||||
if not latest_event_ids:
|
||||
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
|
||||
|
||||
res = yield self.resolve_state_groups(room_id, latest_event_ids)
|
||||
state = res[1]
|
||||
_, state = yield self.resolve_state_groups(room_id, latest_event_ids)
|
||||
|
||||
if event_type:
|
||||
event_id = state.get((event_type, state_key))
|
||||
event = None
|
||||
if event_id:
|
||||
event = yield self.store.get_event(event_id, allow_none=True)
|
||||
defer.returnValue(event)
|
||||
return
|
||||
|
||||
state_map = yield self.store.get_events(state.values(), get_prev_content=False)
|
||||
state = {
|
||||
key: state_map[e_id] for key, e_id in state.items() if e_id in state_map
|
||||
}
|
||||
|
||||
defer.returnValue(state)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_current_state_ids(self, room_id, event_type=None, state_key="",
|
||||
latest_event_ids=None):
|
||||
if not latest_event_ids:
|
||||
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
|
||||
|
||||
_, state = yield self.resolve_state_groups(room_id, latest_event_ids)
|
||||
|
||||
if event_type:
|
||||
defer.returnValue(state.get((event_type, state_key)))
|
||||
@ -123,27 +145,27 @@ class StateHandler(object):
|
||||
# state. Certainly store.get_current_state won't return any, and
|
||||
# persisting the event won't store the state group.
|
||||
if old_state:
|
||||
context.current_state = {
|
||||
(s.type, s.state_key): s for s in old_state
|
||||
context.current_state_ids = {
|
||||
(s.type, s.state_key): s.event_id for s in old_state
|
||||
}
|
||||
else:
|
||||
context.current_state = {}
|
||||
context.current_state_ids = {}
|
||||
context.prev_state_events = []
|
||||
context.state_group = None
|
||||
defer.returnValue(context)
|
||||
|
||||
if old_state:
|
||||
context.current_state = {
|
||||
(s.type, s.state_key): s for s in old_state
|
||||
context.current_state_ids = {
|
||||
(s.type, s.state_key): s.event_id for s in old_state
|
||||
}
|
||||
context.state_group = None
|
||||
|
||||
if event.is_state():
|
||||
key = (event.type, event.state_key)
|
||||
if key in context.current_state:
|
||||
replaces = context.current_state[key]
|
||||
if replaces.event_id != event.event_id: # Paranoia check
|
||||
event.unsigned["replaces_state"] = replaces.event_id
|
||||
if key in context.current_state_ids:
|
||||
replaces = context.current_state_ids[key]
|
||||
if replaces != event.event_id: # Paranoia check
|
||||
event.unsigned["replaces_state"] = replaces
|
||||
|
||||
context.prev_state_events = []
|
||||
defer.returnValue(context)
|
||||
@ -159,18 +181,18 @@ class StateHandler(object):
|
||||
event.room_id, [e for e, _ in event.prev_events],
|
||||
)
|
||||
|
||||
group, curr_state, prev_state = ret
|
||||
group, curr_state = ret
|
||||
|
||||
context.current_state = curr_state
|
||||
context.current_state_ids = curr_state
|
||||
context.state_group = group if not event.is_state() else None
|
||||
|
||||
if event.is_state():
|
||||
key = (event.type, event.state_key)
|
||||
if key in context.current_state:
|
||||
replaces = context.current_state[key]
|
||||
event.unsigned["replaces_state"] = replaces.event_id
|
||||
if key in context.current_state_ids:
|
||||
replaces = context.current_state_ids[key]
|
||||
event.unsigned["replaces_state"] = replaces
|
||||
|
||||
context.prev_state_events = prev_state
|
||||
context.prev_state_events = []
|
||||
defer.returnValue(context)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@ -187,72 +209,83 @@ class StateHandler(object):
|
||||
"""
|
||||
logger.debug("resolve_state_groups event_ids %s", event_ids)
|
||||
|
||||
state_groups = yield self.store.get_state_groups(
|
||||
state_groups_ids = yield self.store.get_state_groups_ids(
|
||||
room_id, event_ids
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"resolve_state_groups state_groups %s",
|
||||
state_groups.keys()
|
||||
state_groups_ids.keys()
|
||||
)
|
||||
|
||||
group_names = frozenset(state_groups.keys())
|
||||
group_names = frozenset(state_groups_ids.keys())
|
||||
if len(group_names) == 1:
|
||||
name, state_list = state_groups.items().pop()
|
||||
state = {
|
||||
(e.type, e.state_key): e
|
||||
for e in state_list
|
||||
}
|
||||
prev_state = state.get((event_type, state_key), None)
|
||||
if prev_state:
|
||||
prev_state = prev_state.event_id
|
||||
prev_states = [prev_state]
|
||||
else:
|
||||
prev_states = []
|
||||
name, state_list = state_groups_ids.items().pop()
|
||||
|
||||
defer.returnValue((name, state, prev_states))
|
||||
defer.returnValue((name, state_list,))
|
||||
|
||||
if self._state_cache is not None:
|
||||
cache = self._state_cache.get(group_names, None)
|
||||
if cache:
|
||||
cache.ts = self.clock.time_msec()
|
||||
|
||||
event_dict = yield self.store.get_events(cache.state.values())
|
||||
state = {(e.type, e.state_key): e for e in event_dict.values()}
|
||||
|
||||
prev_state = state.get((event_type, state_key), None)
|
||||
if prev_state:
|
||||
prev_state = prev_state.event_id
|
||||
prev_states = [prev_state]
|
||||
else:
|
||||
prev_states = []
|
||||
defer.returnValue(
|
||||
(cache.state_group, state, prev_states)
|
||||
(cache.state_group, cache.state,)
|
||||
)
|
||||
|
||||
logger.info("Resolving state for %s with %d groups", room_id, len(state_groups))
|
||||
|
||||
new_state, prev_states = self._resolve_events(
|
||||
state_groups.values(), event_type, state_key
|
||||
logger.info(
|
||||
"Resolving state for %s with %d groups", room_id, len(state_groups_ids)
|
||||
)
|
||||
|
||||
state = {}
|
||||
for st in state_groups_ids.values():
|
||||
for key, e_id in st.items():
|
||||
state.setdefault(key, set()).add(e_id)
|
||||
|
||||
conflicted_state = {
|
||||
k: list(v)
|
||||
for k, v in state.items()
|
||||
if len(v) > 1
|
||||
}
|
||||
|
||||
if conflicted_state:
|
||||
logger.info("Resolving conflicted state for %r", room_id)
|
||||
state_map = yield self.store.get_events(
|
||||
[e_id for st in state_groups_ids.values() for e_id in st.values()],
|
||||
get_prev_content=False
|
||||
)
|
||||
state_sets = [
|
||||
[state_map[e_id] for key, e_id in st.items() if e_id in state_map]
|
||||
for st in state_groups_ids.values()
|
||||
]
|
||||
new_state, _ = self._resolve_events(
|
||||
state_sets, event_type, state_key
|
||||
)
|
||||
new_state = {
|
||||
key: e.event_id for key, e in new_state.items()
|
||||
}
|
||||
else:
|
||||
new_state = {
|
||||
key: e_ids.pop() for key, e_ids in state.items()
|
||||
}
|
||||
|
||||
state_group = None
|
||||
new_state_event_ids = frozenset(e.event_id for e in new_state.values())
|
||||
for sg, events in state_groups.items():
|
||||
if new_state_event_ids == frozenset(e.event_id for e in events):
|
||||
new_state_event_ids = frozenset(new_state.values())
|
||||
for sg, events in state_groups_ids.items():
|
||||
if new_state_event_ids == frozenset(e_id for e_id in events):
|
||||
state_group = sg
|
||||
break
|
||||
|
||||
if self._state_cache is not None:
|
||||
cache = _StateCacheEntry(
|
||||
state={key: event.event_id for key, event in new_state.items()},
|
||||
state=new_state,
|
||||
state_group=state_group,
|
||||
ts=self.clock.time_msec()
|
||||
)
|
||||
|
||||
self._state_cache[group_names] = cache
|
||||
|
||||
defer.returnValue((state_group, new_state, prev_states))
|
||||
defer.returnValue((state_group, new_state,))
|
||||
|
||||
def resolve_events(self, state_sets, event):
|
||||
logger.info(
|
||||
|
@ -124,7 +124,8 @@ class PushRuleStore(SQLBaseStore):
|
||||
|
||||
defer.returnValue(results)
|
||||
|
||||
def bulk_get_push_rules_for_room(self, room_id, state_group, current_state):
|
||||
def bulk_get_push_rules_for_room(self, room_id, context):
|
||||
state_group = context.state_group
|
||||
if not state_group:
|
||||
# If state_group is None it means it has yet to be assigned a
|
||||
# state group, i.e. we need to make sure that calls with a state_group
|
||||
@ -132,10 +133,12 @@ class PushRuleStore(SQLBaseStore):
|
||||
# To do this we set the state_group to a new object as object() != object()
|
||||
state_group = object()
|
||||
|
||||
return self._bulk_get_push_rules_for_room(room_id, state_group, current_state)
|
||||
return self._bulk_get_push_rules_for_room(
|
||||
room_id, state_group, context.current_state_ids
|
||||
)
|
||||
|
||||
@cachedInlineCallbacks(num_args=2, cache_context=True)
|
||||
def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state,
|
||||
def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state_ids,
|
||||
cache_context):
|
||||
# We don't use `state_group`, its there so that we can cache based
|
||||
# on it. However, its important that its never None, since two current_state's
|
||||
@ -147,10 +150,16 @@ class PushRuleStore(SQLBaseStore):
|
||||
# their unread countss are correct in the event stream, but to avoid
|
||||
# generating them for bot / AS users etc, we only do so for people who've
|
||||
# sent a read receipt into the room.
|
||||
local_user_member_ids = [
|
||||
e_id for (etype, state_key), e_id in current_state_ids.iteritems()
|
||||
if etype == EventTypes.Member and self.hs.is_mine_id(state_key)
|
||||
]
|
||||
|
||||
local_member_events = yield self._get_events(local_user_member_ids)
|
||||
|
||||
local_users_in_room = set(
|
||||
e.state_key for e in current_state.values()
|
||||
if e.type == EventTypes.Member and e.membership == Membership.JOIN
|
||||
and self.hs.is_mine_id(e.state_key)
|
||||
member_event.state_key for member_event in local_member_events
|
||||
if member_event.membership == Membership.JOIN
|
||||
)
|
||||
|
||||
# users in the room who have pushers need to get push rules run because
|
||||
|
@ -20,7 +20,7 @@ from collections import namedtuple
|
||||
from ._base import SQLBaseStore
|
||||
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
|
||||
|
||||
from synapse.api.constants import Membership
|
||||
from synapse.api.constants import Membership, EventTypes
|
||||
from synapse.types import get_domain_from_id
|
||||
|
||||
import logging
|
||||
@ -325,7 +325,8 @@ class RoomMemberStore(SQLBaseStore):
|
||||
|
||||
@cachedInlineCallbacks(num_args=3)
|
||||
def was_forgotten_at(self, user_id, room_id, event_id):
|
||||
"""Returns whether user_id has elected to discard history for room_id at event_id.
|
||||
"""Returns whether user_id has elected to discard history for room_id at
|
||||
event_id.
|
||||
|
||||
event_id must be a membership event."""
|
||||
def f(txn):
|
||||
@ -358,3 +359,80 @@ class RoomMemberStore(SQLBaseStore):
|
||||
},
|
||||
desc="who_forgot"
|
||||
)
|
||||
|
||||
def get_joined_users_from_context(self, room_id, context):
|
||||
state_group = context.state_group
|
||||
if not state_group:
|
||||
# If state_group is None it means it has yet to be assigned a
|
||||
# state group, i.e. we need to make sure that calls with a state_group
|
||||
# of None don't hit previous cached calls with a None state_group.
|
||||
# To do this we set the state_group to a new object as object() != object()
|
||||
state_group = object()
|
||||
|
||||
return self._get_joined_users_from_context(
|
||||
room_id, state_group, context.current_state_ids
|
||||
)
|
||||
|
||||
@cachedInlineCallbacks(num_args=2, cache_context=True)
|
||||
def _get_joined_users_from_context(self, room_id, state_group, current_state_ids,
|
||||
cache_context):
|
||||
# We don't use `state_group`, its there so that we can cache based
|
||||
# on it. However, its important that its never None, since two current_state's
|
||||
# with a state_group of None are likely to be different.
|
||||
# See bulk_get_push_rules_for_room for how we work around this.
|
||||
assert state_group is not None
|
||||
|
||||
member_event_ids = [
|
||||
e_id
|
||||
for key, e_id in current_state_ids.iteritems()
|
||||
if key[0] == EventTypes.Member
|
||||
]
|
||||
|
||||
rows = yield self._simple_select_many_batch(
|
||||
table="room_memberships",
|
||||
column="event_id",
|
||||
iterable=member_event_ids,
|
||||
retcols=['user_id'],
|
||||
keyvalues={
|
||||
"membership": Membership.JOIN,
|
||||
},
|
||||
batch_size=1000,
|
||||
desc="_get_joined_users_from_context",
|
||||
)
|
||||
|
||||
defer.returnValue(set(row["user_id"] for row in rows))
|
||||
|
||||
def is_host_joined(self, room_id, host, state_group, state_ids):
|
||||
if not state_group:
|
||||
# If state_group is None it means it has yet to be assigned a
|
||||
# state group, i.e. we need to make sure that calls with a state_group
|
||||
# of None don't hit previous cached calls with a None state_group.
|
||||
# To do this we set the state_group to a new object as object() != object()
|
||||
state_group = object()
|
||||
|
||||
return self._is_host_joined(
|
||||
room_id, host, state_group, state_ids
|
||||
)
|
||||
|
||||
@cachedInlineCallbacks(num_args=3)
|
||||
def _is_host_joined(self, room_id, host, state_group, current_state_ids):
|
||||
# We don't use `state_group`, its there so that we can cache based
|
||||
# on it. However, its important that its never None, since two current_state's
|
||||
# with a state_group of None are likely to be different.
|
||||
# See bulk_get_push_rules_for_room for how we work around this.
|
||||
assert state_group is not None
|
||||
|
||||
for (etype, state_key), event_id in current_state_ids.items():
|
||||
if etype == EventTypes.Member:
|
||||
try:
|
||||
if get_domain_from_id(state_key) != host:
|
||||
continue
|
||||
except:
|
||||
logger.warn("state_key not user_id: %s", state_key)
|
||||
continue
|
||||
|
||||
event = yield self.get_event(event_id, allow_none=True)
|
||||
if event and event.content["membership"] == Membership.JOIN:
|
||||
defer.returnValue(True)
|
||||
|
||||
defer.returnValue(False)
|
||||
|
@ -44,11 +44,7 @@ class StateStore(SQLBaseStore):
|
||||
"""
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_state_groups(self, room_id, event_ids):
|
||||
""" Get the state groups for the given list of event_ids
|
||||
|
||||
The return value is a dict mapping group names to lists of events.
|
||||
"""
|
||||
def get_state_groups_ids(self, room_id, event_ids):
|
||||
if not event_ids:
|
||||
defer.returnValue({})
|
||||
|
||||
@ -59,9 +55,32 @@ class StateStore(SQLBaseStore):
|
||||
groups = set(event_to_groups.values())
|
||||
group_to_state = yield self._get_state_for_groups(groups)
|
||||
|
||||
defer.returnValue(group_to_state)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_state_groups(self, room_id, event_ids):
|
||||
""" Get the state groups for the given list of event_ids
|
||||
|
||||
The return value is a dict mapping group names to lists of events.
|
||||
"""
|
||||
if not event_ids:
|
||||
defer.returnValue({})
|
||||
|
||||
group_to_ids = yield self.get_state_groups_ids(room_id, event_ids)
|
||||
|
||||
state_event_map = yield self.get_events(
|
||||
[
|
||||
ev_id for group_ids in group_to_ids.values()
|
||||
for ev_id in group_ids.values()
|
||||
],
|
||||
get_prev_content=False
|
||||
)
|
||||
|
||||
defer.returnValue({
|
||||
group: state_map.values()
|
||||
for group, state_map in group_to_state.items()
|
||||
group: [
|
||||
state_event_map[v] for v in event_id_map.values() if v in state_event_map
|
||||
]
|
||||
for group, event_id_map in group_to_ids.items()
|
||||
})
|
||||
|
||||
def _store_mult_state_groups_txn(self, txn, events_and_contexts):
|
||||
@ -70,17 +89,17 @@ class StateStore(SQLBaseStore):
|
||||
if event.internal_metadata.is_outlier():
|
||||
continue
|
||||
|
||||
if context.current_state is None:
|
||||
if context.current_state_ids is None:
|
||||
continue
|
||||
|
||||
if context.state_group is not None:
|
||||
state_groups[event.event_id] = context.state_group
|
||||
continue
|
||||
|
||||
state_events = dict(context.current_state)
|
||||
state_event_ids = dict(context.current_state_ids)
|
||||
|
||||
if event.is_state():
|
||||
state_events[(event.type, event.state_key)] = event
|
||||
state_event_ids[(event.type, event.state_key)] = event.event_id
|
||||
|
||||
state_group = context.new_state_group_id
|
||||
|
||||
@ -100,12 +119,12 @@ class StateStore(SQLBaseStore):
|
||||
values=[
|
||||
{
|
||||
"state_group": state_group,
|
||||
"room_id": state.room_id,
|
||||
"type": state.type,
|
||||
"state_key": state.state_key,
|
||||
"event_id": state.event_id,
|
||||
"room_id": event.room_id,
|
||||
"type": key[0],
|
||||
"state_key": key[1],
|
||||
"event_id": state_id,
|
||||
}
|
||||
for state in state_events.values()
|
||||
for key, state_id in state_event_ids.items()
|
||||
],
|
||||
)
|
||||
state_groups[event.event_id] = state_group
|
||||
@ -248,6 +267,31 @@ class StateStore(SQLBaseStore):
|
||||
groups = set(event_to_groups.values())
|
||||
group_to_state = yield self._get_state_for_groups(groups, types)
|
||||
|
||||
state_event_map = yield self.get_events(
|
||||
[ev_id for sd in group_to_state.values() for ev_id in sd.values()],
|
||||
get_prev_content=False
|
||||
)
|
||||
|
||||
event_to_state = {
|
||||
event_id: {
|
||||
k: state_event_map[v]
|
||||
for k, v in group_to_state[group].items()
|
||||
if v in state_event_map
|
||||
}
|
||||
for event_id, group in event_to_groups.items()
|
||||
}
|
||||
|
||||
defer.returnValue({event: event_to_state[event] for event in event_ids})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_state_ids_for_events(self, event_ids, types):
|
||||
event_to_groups = yield self._get_state_group_for_events(
|
||||
event_ids,
|
||||
)
|
||||
|
||||
groups = set(event_to_groups.values())
|
||||
group_to_state = yield self._get_state_for_groups(groups, types)
|
||||
|
||||
event_to_state = {
|
||||
event_id: group_to_state[group]
|
||||
for event_id, group in event_to_groups.items()
|
||||
@ -272,6 +316,23 @@ class StateStore(SQLBaseStore):
|
||||
state_map = yield self.get_state_for_events([event_id], types)
|
||||
defer.returnValue(state_map[event_id])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_state_ids_for_event(self, event_id, types=None):
|
||||
"""
|
||||
Get the state dict corresponding to a particular event
|
||||
|
||||
Args:
|
||||
event_id(str): event whose state should be returned
|
||||
types(list[(str, str)]|None): List of (type, state_key) tuples
|
||||
which are used to filter the state fetched. May be None, which
|
||||
matches any key
|
||||
|
||||
Returns:
|
||||
A deferred dict from (type, state_key) -> state_event
|
||||
"""
|
||||
state_map = yield self.get_state_ids_for_events([event_id], types)
|
||||
defer.returnValue(state_map[event_id])
|
||||
|
||||
@cached(num_args=2, max_entries=10000)
|
||||
def _get_state_group_for_event(self, room_id, event_id):
|
||||
return self._simple_select_one_onecol(
|
||||
@ -428,20 +489,13 @@ class StateStore(SQLBaseStore):
|
||||
full=(types is None),
|
||||
)
|
||||
|
||||
state_events = yield self._get_events(
|
||||
[ev_id for sd in results.values() for ev_id in sd.values()],
|
||||
get_prev_content=False
|
||||
)
|
||||
|
||||
state_events = {e.event_id: e for e in state_events}
|
||||
|
||||
# Remove all the entries with None values. The None values were just
|
||||
# used for bookkeeping in the cache.
|
||||
for group, state_dict in results.items():
|
||||
results[group] = {
|
||||
key: state_events[event_id]
|
||||
key: event_id
|
||||
for key, event_id in state_dict.items()
|
||||
if event_id and event_id in state_events
|
||||
if event_id
|
||||
}
|
||||
|
||||
defer.returnValue(results)
|
||||
|
@ -271,10 +271,3 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
|
||||
return "t%d-%d" % (self.topological, self.stream)
|
||||
else:
|
||||
return "s%d" % (self.stream,)
|
||||
|
||||
|
||||
# Some arbitrary constants used for internal API enumerations. Don't rely on
|
||||
# exact values; always pass or compare symbolically
|
||||
class ThirdPartyEntityKind(object):
|
||||
USER = 'user'
|
||||
LOCATION = 'location'
|
||||
|
@ -180,6 +180,25 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
|
||||
})
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def filter_events_for_clients_context(store, user_tuples, events, event_id_to_context):
|
||||
user_ids = set(u[0] for u in user_tuples)
|
||||
event_id_to_state = {}
|
||||
for event_id, context in event_id_to_context.items():
|
||||
state = yield store.get_events([
|
||||
e_id
|
||||
for key, e_id in context.current_state_ids.iteritems()
|
||||
if key == (EventTypes.RoomHistoryVisibility, "")
|
||||
or (key[0] == EventTypes.Member and key[1] in user_ids)
|
||||
])
|
||||
event_id_to_state[event_id] = state
|
||||
|
||||
res = yield filter_events_for_clients(
|
||||
store, user_tuples, events, event_id_to_state
|
||||
)
|
||||
defer.returnValue(res)
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def filter_events_for_client(store, user_id, events, is_peeking=False):
|
||||
"""
|
||||
|
@ -305,7 +305,14 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
|
||||
|
||||
self.event_id += 1
|
||||
|
||||
context = EventContext(current_state=state)
|
||||
if state is not None:
|
||||
state_ids = {
|
||||
key: e.event_id for key, e in state.items()
|
||||
}
|
||||
else:
|
||||
state_ids = None
|
||||
|
||||
context = EventContext(current_state_ids=state_ids)
|
||||
context.push_actions = push_actions
|
||||
|
||||
ordering = None
|
||||
|
@ -67,9 +67,11 @@ class StateGroupStore(object):
|
||||
self._event_to_state_group = {}
|
||||
self._group_to_state = {}
|
||||
|
||||
self._event_id_to_event = {}
|
||||
|
||||
self._next_group = 1
|
||||
|
||||
def get_state_groups(self, room_id, event_ids):
|
||||
def get_state_groups_ids(self, room_id, event_ids):
|
||||
groups = {}
|
||||
for event_id in event_ids:
|
||||
group = self._event_to_state_group.get(event_id)
|
||||
@ -79,23 +81,33 @@ class StateGroupStore(object):
|
||||
return defer.succeed(groups)
|
||||
|
||||
def store_state_groups(self, event, context):
|
||||
if context.current_state is None:
|
||||
if context.current_state_ids is None:
|
||||
return
|
||||
|
||||
state_events = context.current_state
|
||||
state_events = dict(context.current_state_ids)
|
||||
|
||||
if event.is_state():
|
||||
state_events[(event.type, event.state_key)] = event
|
||||
state_events[(event.type, event.state_key)] = event.event_id
|
||||
|
||||
state_group = context.state_group
|
||||
if not state_group:
|
||||
state_group = self._next_group
|
||||
self._next_group += 1
|
||||
|
||||
self._group_to_state[state_group] = state_events.values()
|
||||
self._group_to_state[state_group] = state_events
|
||||
|
||||
self._event_to_state_group[event.event_id] = state_group
|
||||
|
||||
def get_events(self, event_ids, **kwargs):
|
||||
return {
|
||||
e_id: self._event_id_to_event[e_id] for e_id in event_ids
|
||||
if e_id in self._event_id_to_event
|
||||
}
|
||||
|
||||
def register_events(self, events):
|
||||
for e in events:
|
||||
self._event_id_to_event[e.event_id] = e
|
||||
|
||||
|
||||
class DictObj(dict):
|
||||
def __init__(self, **kwargs):
|
||||
@ -136,8 +148,9 @@ class StateTestCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.store = Mock(
|
||||
spec_set=[
|
||||
"get_state_groups",
|
||||
"get_state_groups_ids",
|
||||
"add_event_hashes",
|
||||
"get_events",
|
||||
]
|
||||
)
|
||||
hs = Mock(spec_set=[
|
||||
@ -187,7 +200,7 @@ class StateTestCase(unittest.TestCase):
|
||||
)
|
||||
|
||||
store = StateGroupStore()
|
||||
self.store.get_state_groups.side_effect = store.get_state_groups
|
||||
self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
|
||||
|
||||
context_store = {}
|
||||
|
||||
@ -196,7 +209,7 @@ class StateTestCase(unittest.TestCase):
|
||||
store.store_state_groups(event, context)
|
||||
context_store[event.event_id] = context
|
||||
|
||||
self.assertEqual(2, len(context_store["D"].current_state))
|
||||
self.assertEqual(2, len(context_store["D"].current_state_ids))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_branch_basic_conflict(self):
|
||||
@ -239,7 +252,9 @@ class StateTestCase(unittest.TestCase):
|
||||
)
|
||||
|
||||
store = StateGroupStore()
|
||||
self.store.get_state_groups.side_effect = store.get_state_groups
|
||||
self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
|
||||
self.store.get_events = store.get_events
|
||||
store.register_events(graph.walk())
|
||||
|
||||
context_store = {}
|
||||
|
||||
@ -250,7 +265,7 @@ class StateTestCase(unittest.TestCase):
|
||||
|
||||
self.assertSetEqual(
|
||||
{"START", "A", "C"},
|
||||
{e.event_id for e in context_store["D"].current_state.values()}
|
||||
{e_id for e_id in context_store["D"].current_state_ids.values()}
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@ -303,7 +318,9 @@ class StateTestCase(unittest.TestCase):
|
||||
)
|
||||
|
||||
store = StateGroupStore()
|
||||
self.store.get_state_groups.side_effect = store.get_state_groups
|
||||
self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
|
||||
self.store.get_events = store.get_events
|
||||
store.register_events(graph.walk())
|
||||
|
||||
context_store = {}
|
||||
|
||||
@ -314,7 +331,7 @@ class StateTestCase(unittest.TestCase):
|
||||
|
||||
self.assertSetEqual(
|
||||
{"START", "A", "B", "C"},
|
||||
{e.event_id for e in context_store["E"].current_state.values()}
|
||||
{e for e in context_store["E"].current_state_ids.values()}
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@ -384,7 +401,9 @@ class StateTestCase(unittest.TestCase):
|
||||
graph = Graph(nodes, edges)
|
||||
|
||||
store = StateGroupStore()
|
||||
self.store.get_state_groups.side_effect = store.get_state_groups
|
||||
self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
|
||||
self.store.get_events = store.get_events
|
||||
store.register_events(graph.walk())
|
||||
|
||||
context_store = {}
|
||||
|
||||
@ -395,7 +414,7 @@ class StateTestCase(unittest.TestCase):
|
||||
|
||||
self.assertSetEqual(
|
||||
{"A1", "A2", "A3", "A5", "B"},
|
||||
{e.event_id for e in context_store["D"].current_state.values()}
|
||||
{e for e in context_store["D"].current_state_ids.values()}
|
||||
)
|
||||
|
||||
def _add_depths(self, nodes, edges):
|
||||
@ -424,13 +443,8 @@ class StateTestCase(unittest.TestCase):
|
||||
event, old_state=old_state
|
||||
)
|
||||
|
||||
for k, v in context.current_state.items():
|
||||
type, state_key = k
|
||||
self.assertEqual(type, v.type)
|
||||
self.assertEqual(state_key, v.state_key)
|
||||
|
||||
self.assertEqual(
|
||||
set(old_state), set(context.current_state.values())
|
||||
set(e.event_id for e in old_state), set(context.current_state_ids.values())
|
||||
)
|
||||
|
||||
self.assertIsNone(context.state_group)
|
||||
@ -449,14 +463,8 @@ class StateTestCase(unittest.TestCase):
|
||||
event, old_state=old_state
|
||||
)
|
||||
|
||||
for k, v in context.current_state.items():
|
||||
type, state_key = k
|
||||
self.assertEqual(type, v.type)
|
||||
self.assertEqual(state_key, v.state_key)
|
||||
|
||||
self.assertEqual(
|
||||
set(old_state),
|
||||
set(context.current_state.values())
|
||||
set(e.event_id for e in old_state), set(context.current_state_ids.values())
|
||||
)
|
||||
|
||||
self.assertIsNone(context.state_group)
|
||||
@ -473,20 +481,15 @@ class StateTestCase(unittest.TestCase):
|
||||
|
||||
group_name = "group_name_1"
|
||||
|
||||
self.store.get_state_groups.return_value = {
|
||||
group_name: old_state,
|
||||
self.store.get_state_groups_ids.return_value = {
|
||||
group_name: {(e.type, e.state_key): e.event_id for e in old_state},
|
||||
}
|
||||
|
||||
context = yield self.state.compute_event_context(event)
|
||||
|
||||
for k, v in context.current_state.items():
|
||||
type, state_key = k
|
||||
self.assertEqual(type, v.type)
|
||||
self.assertEqual(state_key, v.state_key)
|
||||
|
||||
self.assertEqual(
|
||||
set([e.event_id for e in old_state]),
|
||||
set([e.event_id for e in context.current_state.values()])
|
||||
set(context.current_state_ids.values())
|
||||
)
|
||||
|
||||
self.assertEqual(group_name, context.state_group)
|
||||
@ -503,20 +506,15 @@ class StateTestCase(unittest.TestCase):
|
||||
|
||||
group_name = "group_name_1"
|
||||
|
||||
self.store.get_state_groups.return_value = {
|
||||
group_name: old_state,
|
||||
self.store.get_state_groups_ids.return_value = {
|
||||
group_name: {(e.type, e.state_key): e.event_id for e in old_state},
|
||||
}
|
||||
|
||||
context = yield self.state.compute_event_context(event)
|
||||
|
||||
for k, v in context.current_state.items():
|
||||
type, state_key = k
|
||||
self.assertEqual(type, v.type)
|
||||
self.assertEqual(state_key, v.state_key)
|
||||
|
||||
self.assertEqual(
|
||||
set([e.event_id for e in old_state]),
|
||||
set([e.event_id for e in context.current_state.values()])
|
||||
set(context.current_state_ids.values())
|
||||
)
|
||||
|
||||
self.assertIsNone(context.state_group)
|
||||
@ -543,9 +541,14 @@ class StateTestCase(unittest.TestCase):
|
||||
create_event(type="test4", state_key=""),
|
||||
]
|
||||
|
||||
store = StateGroupStore()
|
||||
store.register_events(old_state_1)
|
||||
store.register_events(old_state_2)
|
||||
self.store.get_events = store.get_events
|
||||
|
||||
context = yield self._get_context(event, old_state_1, old_state_2)
|
||||
|
||||
self.assertEqual(len(context.current_state), 6)
|
||||
self.assertEqual(len(context.current_state_ids), 6)
|
||||
|
||||
self.assertIsNone(context.state_group)
|
||||
|
||||
@ -571,9 +574,14 @@ class StateTestCase(unittest.TestCase):
|
||||
create_event(type="test4", state_key=""),
|
||||
]
|
||||
|
||||
store = StateGroupStore()
|
||||
store.register_events(old_state_1)
|
||||
store.register_events(old_state_2)
|
||||
self.store.get_events = store.get_events
|
||||
|
||||
context = yield self._get_context(event, old_state_1, old_state_2)
|
||||
|
||||
self.assertEqual(len(context.current_state), 6)
|
||||
self.assertEqual(len(context.current_state_ids), 6)
|
||||
|
||||
self.assertIsNone(context.state_group)
|
||||
|
||||
@ -606,9 +614,16 @@ class StateTestCase(unittest.TestCase):
|
||||
create_event(type="test1", state_key="1", depth=2),
|
||||
]
|
||||
|
||||
store = StateGroupStore()
|
||||
store.register_events(old_state_1)
|
||||
store.register_events(old_state_2)
|
||||
self.store.get_events = store.get_events
|
||||
|
||||
context = yield self._get_context(event, old_state_1, old_state_2)
|
||||
|
||||
self.assertEqual(old_state_2[2], context.current_state[("test1", "1")])
|
||||
self.assertEqual(
|
||||
old_state_2[2].event_id, context.current_state_ids[("test1", "1")]
|
||||
)
|
||||
|
||||
# Reverse the depth to make sure we are actually using the depths
|
||||
# during state resolution.
|
||||
@ -625,17 +640,22 @@ class StateTestCase(unittest.TestCase):
|
||||
create_event(type="test1", state_key="1", depth=1),
|
||||
]
|
||||
|
||||
store.register_events(old_state_1)
|
||||
store.register_events(old_state_2)
|
||||
|
||||
context = yield self._get_context(event, old_state_1, old_state_2)
|
||||
|
||||
self.assertEqual(old_state_1[2], context.current_state[("test1", "1")])
|
||||
self.assertEqual(
|
||||
old_state_1[2].event_id, context.current_state_ids[("test1", "1")]
|
||||
)
|
||||
|
||||
def _get_context(self, event, old_state_1, old_state_2):
|
||||
group_name_1 = "group_name_1"
|
||||
group_name_2 = "group_name_2"
|
||||
|
||||
self.store.get_state_groups.return_value = {
|
||||
group_name_1: old_state_1,
|
||||
group_name_2: old_state_2,
|
||||
self.store.get_state_groups_ids.return_value = {
|
||||
group_name_1: {(e.type, e.state_key): e.event_id for e in old_state_1},
|
||||
group_name_2: {(e.type, e.state_key): e.event_id for e in old_state_2},
|
||||
}
|
||||
|
||||
return self.state.compute_event_context(event)
|
||||
|
Loading…
Reference in New Issue
Block a user