Merge pull request #1060 from matrix-org/erikj/state_ids

Assign state groups in state handler.
This commit is contained in:
Erik Johnston 2016-09-01 14:20:42 +01:00 committed by GitHub
commit 44982606ee
16 changed files with 216 additions and 172 deletions

View File

@ -66,7 +66,7 @@ class Auth(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def check_from_context(self, event, context, do_sig_check=True): def check_from_context(self, event, context, do_sig_check=True):
auth_events_ids = yield self.compute_auth_events( auth_events_ids = yield self.compute_auth_events(
event, context.current_state_ids, for_verification=True, event, context.prev_state_ids, for_verification=True,
) )
auth_events = yield self.store.get_events(auth_events_ids) auth_events = yield self.store.get_events(auth_events_ids)
auth_events = { auth_events = {
@ -281,11 +281,13 @@ class Auth(object):
with Measure(self.clock, "check_host_in_room"): with Measure(self.clock, "check_host_in_room"):
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
group, curr_state_ids = yield self.state.resolve_state_groups( entry = yield self.state.resolve_state_groups(
room_id, latest_event_ids room_id, latest_event_ids
) )
ret = yield self.store.is_host_joined(room_id, host, group, curr_state_ids) ret = yield self.store.is_host_joined(
room_id, host, entry.state_group, entry.state
)
defer.returnValue(ret) defer.returnValue(ret)
def check_event_sender_in_room(self, event, auth_events): def check_event_sender_in_room(self, event, auth_events):
@ -852,7 +854,7 @@ class Auth(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def add_auth_events(self, builder, context): def add_auth_events(self, builder, context):
auth_ids = yield self.compute_auth_events(builder, context.current_state_ids) auth_ids = yield self.compute_auth_events(builder, context.prev_state_ids)
auth_events_entries = yield self.store.add_event_hashes( auth_events_entries = yield self.store.add_event_hashes(
auth_ids auth_ids

View File

@ -15,8 +15,9 @@
class EventContext(object): class EventContext(object):
def __init__(self, current_state_ids=None): def __init__(self):
self.current_state_ids = current_state_ids self.current_state_ids = None
self.prev_state_ids = None
self.state_group = None self.state_group = None
self.rejected = False self.rejected = False
self.push_actions = [] self.push_actions = []

View File

@ -222,7 +222,7 @@ class FederationHandler(BaseHandler):
# joined the room. Don't bother if the user is just # joined the room. Don't bother if the user is just
# changing their profile info. # changing their profile info.
newly_joined = True newly_joined = True
prev_state_id = context.current_state_ids.get( prev_state_id = context.prev_state_ids.get(
(event.type, event.state_key) (event.type, event.state_key)
) )
if prev_state_id: if prev_state_id:
@ -835,12 +835,12 @@ class FederationHandler(BaseHandler):
self.replication_layer.send_pdu(new_pdu, destinations) self.replication_layer.send_pdu(new_pdu, destinations)
state_ids = context.current_state_ids.values() state_ids = context.prev_state_ids.values()
auth_chain = yield self.store.get_auth_chain(set( auth_chain = yield self.store.get_auth_chain(set(
[event.event_id] + state_ids [event.event_id] + state_ids
)) ))
state = yield self.store.get_events(context.current_state_ids.values()) state = yield self.store.get_events(context.prev_state_ids.values())
defer.returnValue({ defer.returnValue({
"state": state.values(), "state": state.values(),
@ -1333,7 +1333,7 @@ class FederationHandler(BaseHandler):
if not auth_events: if not auth_events:
auth_events_ids = yield self.auth.compute_auth_events( auth_events_ids = yield self.auth.compute_auth_events(
event, context.current_state_ids, for_verification=True, event, context.prev_state_ids, for_verification=True,
) )
auth_events = yield self.store.get_events(auth_events_ids) auth_events = yield self.store.get_events(auth_events_ids)
auth_events = { auth_events = {
@ -1432,6 +1432,11 @@ class FederationHandler(BaseHandler):
current_state = set(e.event_id for e in auth_events.values()) current_state = set(e.event_id for e in auth_events.values())
event_auth_events = set(e_id for e_id, _ in event.auth_events) event_auth_events = set(e_id for e_id, _ in event.auth_events)
if event.is_state():
event_key = (event.type, event.state_key)
else:
event_key = None
if event_auth_events - current_state: if event_auth_events - current_state:
have_events = yield self.store.have_events( have_events = yield self.store.have_events(
event_auth_events - current_state event_auth_events - current_state
@ -1537,8 +1542,12 @@ class FederationHandler(BaseHandler):
context.current_state_ids.update({ context.current_state_ids.update({
k: a.event_id for k, a in auth_events.items() k: a.event_id for k, a in auth_events.items()
if k != event_key
}) })
context.state_group = None context.prev_state_ids.update({
k: a.event_id for k, a in auth_events.items()
})
context.state_group = self.store.get_next_state_group()
if different_auth and not event.internal_metadata.is_outlier(): if different_auth and not event.internal_metadata.is_outlier():
logger.info("Different auth after resolution: %s", different_auth) logger.info("Different auth after resolution: %s", different_auth)
@ -1560,7 +1569,7 @@ class FederationHandler(BaseHandler):
if do_resolution: if do_resolution:
# 1. Get what we think is the auth chain. # 1. Get what we think is the auth chain.
auth_ids = yield self.auth.compute_auth_events( auth_ids = yield self.auth.compute_auth_events(
event, context.current_state_ids event, context.prev_state_ids
) )
local_auth_chain = yield self.store.get_auth_chain(auth_ids) local_auth_chain = yield self.store.get_auth_chain(auth_ids)
@ -1618,8 +1627,12 @@ class FederationHandler(BaseHandler):
context.current_state_ids.update({ context.current_state_ids.update({
k: a.event_id for k, a in auth_events.items() k: a.event_id for k, a in auth_events.items()
if k != event_key
}) })
context.state_group = None context.prev_state_ids.update({
k: a.event_id for k, a in auth_events.items()
})
context.state_group = self.store.get_next_state_group()
try: try:
self.auth.check(event, auth_events=auth_events) self.auth.check(event, auth_events=auth_events)
@ -1855,7 +1868,7 @@ class FederationHandler(BaseHandler):
event.content["third_party_invite"]["signed"]["token"] event.content["third_party_invite"]["signed"]["token"]
) )
original_invite = None original_invite = None
original_invite_id = context.current_state_ids.get(key) original_invite_id = context.prev_state_ids.get(key)
if original_invite_id: if original_invite_id:
original_invite = yield self.store.get_event( original_invite = yield self.store.get_event(
original_invite_id, allow_none=True original_invite_id, allow_none=True
@ -1893,7 +1906,7 @@ class FederationHandler(BaseHandler):
signed = event.content["third_party_invite"]["signed"] signed = event.content["third_party_invite"]["signed"]
token = signed["token"] token = signed["token"]
invite_event_id = context.current_state_ids.get( invite_event_id = context.prev_state_ids.get(
(EventTypes.ThirdPartyInvite, token,) (EventTypes.ThirdPartyInvite, token,)
) )

View File

@ -272,7 +272,7 @@ class MessageHandler(BaseHandler):
If so, returns the version of the event in context. If so, returns the version of the event in context.
Otherwise, returns None. Otherwise, returns None.
""" """
prev_event_id = context.current_state_ids.get((event.type, event.state_key)) prev_event_id = context.prev_state_ids.get((event.type, event.state_key))
prev_event = yield self.store.get_event(prev_event_id, allow_none=True) prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
if not prev_event: if not prev_event:
return return
@ -808,8 +808,8 @@ class MessageHandler(BaseHandler):
event = builder.build() event = builder.build()
logger.debug( logger.debug(
"Created event %s with current state: %s", "Created event %s with state: %s",
event.event_id, context.current_state_ids, event.event_id, context.prev_state_ids,
) )
defer.returnValue( defer.returnValue(
@ -904,7 +904,7 @@ class MessageHandler(BaseHandler):
if event.type == EventTypes.Redaction: if event.type == EventTypes.Redaction:
auth_events_ids = yield self.auth.compute_auth_events( auth_events_ids = yield self.auth.compute_auth_events(
event, context.current_state_ids, for_verification=True, event, context.prev_state_ids, for_verification=True,
) )
auth_events = yield self.store.get_events(auth_events_ids) auth_events = yield self.store.get_events(auth_events_ids)
auth_events = { auth_events = {
@ -924,7 +924,7 @@ class MessageHandler(BaseHandler):
"You don't have permission to redact events" "You don't have permission to redact events"
) )
if event.type == EventTypes.Create and context.current_state_ids: if event.type == EventTypes.Create and context.prev_state_ids:
raise AuthError( raise AuthError(
403, 403,
"Changing the room create event is forbidden", "Changing the room create event is forbidden",

View File

@ -93,7 +93,7 @@ class RoomMemberHandler(BaseHandler):
ratelimit=ratelimit, ratelimit=ratelimit,
) )
prev_member_event_id = context.current_state_ids.get( prev_member_event_id = context.prev_state_ids.get(
(EventTypes.Member, target.to_string()), (EventTypes.Member, target.to_string()),
None None
) )
@ -341,7 +341,7 @@ class RoomMemberHandler(BaseHandler):
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
if requester.is_guest: if requester.is_guest:
guest_can_join = yield self._can_guest_join(context.current_state_ids) guest_can_join = yield self._can_guest_join(context.prev_state_ids)
if not guest_can_join: if not guest_can_join:
# This should be an auth check, but guests are a local concept, # This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process. # so don't really fit into the general auth process.
@ -355,7 +355,7 @@ class RoomMemberHandler(BaseHandler):
ratelimit=ratelimit, ratelimit=ratelimit,
) )
prev_member_event_id = context.current_state_ids.get( prev_member_event_id = context.prev_state_ids.get(
(EventTypes.Member, event.state_key), (EventTypes.Member, event.state_key),
None None
) )

View File

@ -87,7 +87,7 @@ class BulkPushRuleEvaluator:
) )
room_members = yield self.store.get_joined_users_from_context( room_members = yield self.store.get_joined_users_from_context(
event.room_id, context.state_group, context.current_state_ids event, context
) )
evaluator = PushRuleEvaluatorForEvent(event, len(room_members)) evaluator = PushRuleEvaluatorForEvent(event, len(room_members))

View File

@ -40,7 +40,6 @@ STREAM_NAMES = (
("backfill",), ("backfill",),
("push_rules",), ("push_rules",),
("pushers",), ("pushers",),
("state",),
("caches",), ("caches",),
("to_device",), ("to_device",),
) )
@ -131,7 +130,6 @@ class ReplicationResource(Resource):
backfill_token = yield self.store.get_current_backfill_token() backfill_token = yield self.store.get_current_backfill_token()
push_rules_token, room_stream_token = self.store.get_push_rules_stream_token() push_rules_token, room_stream_token = self.store.get_push_rules_stream_token()
pushers_token = self.store.get_pushers_stream_token() pushers_token = self.store.get_pushers_stream_token()
state_token = self.store.get_state_stream_token()
caches_token = self.store.get_cache_stream_token() caches_token = self.store.get_cache_stream_token()
defer.returnValue(_ReplicationToken( defer.returnValue(_ReplicationToken(
@ -143,7 +141,7 @@ class ReplicationResource(Resource):
backfill_token, backfill_token,
push_rules_token, push_rules_token,
pushers_token, pushers_token,
state_token, 0, # State stream is no longer a thing
caches_token, caches_token,
int(stream_token.to_device_key), int(stream_token.to_device_key),
)) ))
@ -193,7 +191,6 @@ class ReplicationResource(Resource):
yield self.receipts(writer, current_token, limit, request_streams) yield self.receipts(writer, current_token, limit, request_streams)
yield self.push_rules(writer, current_token, limit, request_streams) yield self.push_rules(writer, current_token, limit, request_streams)
yield self.pushers(writer, current_token, limit, request_streams) yield self.pushers(writer, current_token, limit, request_streams)
yield self.state(writer, current_token, limit, request_streams)
yield self.caches(writer, current_token, limit, request_streams) yield self.caches(writer, current_token, limit, request_streams)
yield self.to_device(writer, current_token, limit, request_streams) yield self.to_device(writer, current_token, limit, request_streams)
self.streams(writer, current_token, request_streams) self.streams(writer, current_token, request_streams)
@ -368,25 +365,6 @@ class ReplicationResource(Resource):
"position", "user_id", "app_id", "pushkey" "position", "user_id", "app_id", "pushkey"
)) ))
@defer.inlineCallbacks
def state(self, writer, current_token, limit, request_streams):
current_position = current_token.state
state = request_streams.get("state")
if state is not None:
state_groups, state_group_state = (
yield self.store.get_all_new_state_groups(
state, current_position, limit
)
)
writer.write_header_and_rows("state_groups", state_groups, (
"position", "room_id", "event_id"
))
writer.write_header_and_rows("state_group_state", state_group_state, (
"position", "type", "state_key", "event_id"
))
@defer.inlineCallbacks @defer.inlineCallbacks
def caches(self, writer, current_token, limit, request_streams): def caches(self, writer, current_token, limit, request_streams):
current_position = current_token.caches current_position = current_token.caches

View File

@ -123,6 +123,7 @@ class SlavedEventStore(BaseSlavedStore):
get_state_groups_ids = DataStore.get_state_groups_ids.__func__ get_state_groups_ids = DataStore.get_state_groups_ids.__func__
get_state_ids_for_event = DataStore.get_state_ids_for_event.__func__ get_state_ids_for_event = DataStore.get_state_ids_for_event.__func__
get_state_ids_for_events = DataStore.get_state_ids_for_events.__func__ get_state_ids_for_events = DataStore.get_state_ids_for_events.__func__
get_joined_users_from_state = DataStore.get_joined_users_from_state.__func__
get_joined_users_from_context = DataStore.get_joined_users_from_context.__func__ get_joined_users_from_context = DataStore.get_joined_users_from_context.__func__
_get_joined_users_from_context = ( _get_joined_users_from_context = (
RoomMemberStore.__dict__["_get_joined_users_from_context"] RoomMemberStore.__dict__["_get_joined_users_from_context"]

View File

@ -43,11 +43,35 @@ SIZE_OF_CACHE = int(1000 * CACHE_SIZE_FACTOR)
EVICTION_TIMEOUT_SECONDS = 60 * 60 EVICTION_TIMEOUT_SECONDS = 60 * 60
_NEXT_STATE_ID = 1
def _gen_state_id():
global _NEXT_STATE_ID
s = "X%d" % (_NEXT_STATE_ID,)
_NEXT_STATE_ID += 1
return s
class _StateCacheEntry(object): class _StateCacheEntry(object):
def __init__(self, state, state_group, ts): __slots__ = ["state", "state_group", "state_id"]
def __init__(self, state, state_group):
self.state = state self.state = state
self.state_group = state_group self.state_group = state_group
# The `state_id` is a unique ID we generate that can be used as ID for
# this collection of state. Usually this would be the same as the
# state group, but on worker instances we can't generate a new state
# group each time we resolve state, so we generate a separate one that
# isn't persisted and is used solely for caches.
# `state_id` is either a state_group (and so an int) or a string. This
# ensures we don't accidentally persist a state_id as a stateg_group
if state_group:
self.state_id = state_group
else:
self.state_id = _gen_state_id()
class StateHandler(object): class StateHandler(object):
""" Responsible for doing state conflict resolution. """ Responsible for doing state conflict resolution.
@ -93,7 +117,8 @@ class StateHandler(object):
if not latest_event_ids: if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
_, state = yield self.resolve_state_groups(room_id, latest_event_ids) ret = yield self.resolve_state_groups(room_id, latest_event_ids)
state = ret.state
if event_type: if event_type:
event_id = state.get((event_type, state_key)) event_id = state.get((event_type, state_key))
@ -116,7 +141,8 @@ class StateHandler(object):
if not latest_event_ids: if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
_, state = yield self.resolve_state_groups(room_id, latest_event_ids) ret = yield self.resolve_state_groups(room_id, latest_event_ids)
state = ret.state
if event_type: if event_type:
defer.returnValue(state.get((event_type, state_key))) defer.returnValue(state.get((event_type, state_key)))
@ -127,9 +153,9 @@ class StateHandler(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_current_user_in_room(self, room_id): def get_current_user_in_room(self, room_id):
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
group, state_ids = yield self.resolve_state_groups(room_id, latest_event_ids) entry = yield self.resolve_state_groups(room_id, latest_event_ids)
joined_users = yield self.store.get_joined_users_from_context( joined_users = yield self.store.get_joined_users_from_state(
room_id, group, state_ids room_id, entry.state_id, entry.state
) )
defer.returnValue(joined_users) defer.returnValue(joined_users)
@ -154,52 +180,73 @@ class StateHandler(object):
# state. Certainly store.get_current_state won't return any, and # state. Certainly store.get_current_state won't return any, and
# persisting the event won't store the state group. # persisting the event won't store the state group.
if old_state: if old_state:
context.current_state_ids = { context.prev_state_ids = {
(s.type, s.state_key): s.event_id for s in old_state (s.type, s.state_key): s.event_id for s in old_state
} }
if event.is_state():
context.current_state_events = dict(context.prev_state_ids)
key = (event.type, event.state_key)
context.current_state_events[key] = event.event_id
else:
context.current_state_events = context.prev_state_ids
else: else:
context.current_state_ids = {} context.current_state_ids = {}
context.prev_state_ids = {}
context.prev_state_events = [] context.prev_state_events = []
context.state_group = None context.state_group = self.store.get_next_state_group()
defer.returnValue(context) defer.returnValue(context)
if old_state: if old_state:
context.current_state_ids = { context.prev_state_ids = {
(s.type, s.state_key): s.event_id for s in old_state (s.type, s.state_key): s.event_id for s in old_state
} }
context.state_group = None context.state_group = self.store.get_next_state_group()
if event.is_state(): if event.is_state():
key = (event.type, event.state_key) key = (event.type, event.state_key)
if key in context.current_state_ids: if key in context.prev_state_ids:
replaces = context.current_state_ids[key] replaces = context.prev_state_ids[key]
if replaces != event.event_id: # Paranoia check if replaces != event.event_id: # Paranoia check
event.unsigned["replaces_state"] = replaces event.unsigned["replaces_state"] = replaces
context.current_state_ids = dict(context.prev_state_ids)
context.current_state_ids[key] = event.event_id
else:
context.current_state_ids = context.prev_state_ids
context.prev_state_events = [] context.prev_state_events = []
defer.returnValue(context) defer.returnValue(context)
if event.is_state(): if event.is_state():
ret = yield self.resolve_state_groups( entry = yield self.resolve_state_groups(
event.room_id, [e for e, _ in event.prev_events], event.room_id, [e for e, _ in event.prev_events],
event_type=event.type, event_type=event.type,
state_key=event.state_key, state_key=event.state_key,
) )
else: else:
ret = yield self.resolve_state_groups( entry = yield self.resolve_state_groups(
event.room_id, [e for e, _ in event.prev_events], event.room_id, [e for e, _ in event.prev_events],
) )
group, curr_state = ret curr_state = entry.state
context.current_state_ids = curr_state context.prev_state_ids = curr_state
context.state_group = group if not event.is_state() else None if event.is_state():
context.state_group = self.store.get_next_state_group()
else:
if entry.state_group is None:
entry.state_group = self.store.get_next_state_group()
entry.state_id = entry.state_group
context.state_group = entry.state_group
if event.is_state(): if event.is_state():
key = (event.type, event.state_key) key = (event.type, event.state_key)
if key in context.current_state_ids: if key in context.prev_state_ids:
replaces = context.current_state_ids[key] replaces = context.prev_state_ids[key]
event.unsigned["replaces_state"] = replaces event.unsigned["replaces_state"] = replaces
context.current_state_ids = dict(context.prev_state_ids)
context.current_state_ids[key] = event.event_id
else:
context.current_state_ids = context.prev_state_ids
context.prev_state_events = [] context.prev_state_events = []
defer.returnValue(context) defer.returnValue(context)
@ -231,16 +278,15 @@ class StateHandler(object):
if len(group_names) == 1: if len(group_names) == 1:
name, state_list = state_groups_ids.items().pop() name, state_list = state_groups_ids.items().pop()
defer.returnValue((name, state_list,)) defer.returnValue(_StateCacheEntry(
state=state_list,
state_group=name,
))
if self._state_cache is not None: if self._state_cache is not None:
cache = self._state_cache.get(group_names, None) cache = self._state_cache.get(group_names, None)
if cache: if cache:
cache.ts = self.clock.time_msec() defer.returnValue(cache)
defer.returnValue(
(cache.state_group, cache.state,)
)
logger.info( logger.info(
"Resolving state for %s with %d groups", room_id, len(state_groups_ids) "Resolving state for %s with %d groups", room_id, len(state_groups_ids)
@ -284,17 +330,22 @@ class StateHandler(object):
if new_state_event_ids == frozenset(e_id for e_id in events): if new_state_event_ids == frozenset(e_id for e_id in events):
state_group = sg state_group = sg
break break
if state_group is None:
# Worker instances don't have access to this method, but we want
# to set the state_group on the main instance to increase cache
# hits.
if hasattr(self.store, "get_next_state_group"):
state_group = self.store.get_next_state_group()
if self._state_cache is not None:
cache = _StateCacheEntry( cache = _StateCacheEntry(
state=new_state, state=new_state,
state_group=state_group, state_group=state_group,
ts=self.clock.time_msec()
) )
if self._state_cache is not None:
self._state_cache[group_names] = cache self._state_cache[group_names] = cache
defer.returnValue((state_group, new_state,)) defer.returnValue(cache)
def resolve_events(self, state_sets, event): def resolve_events(self, state_sets, event):
logger.info( logger.info(

View File

@ -115,7 +115,7 @@ class DataStore(RoomMemberStore, RoomStore,
) )
self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id") self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
self._state_groups_id_gen = StreamIdGenerator(db_conn, "state_groups", "id") self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id") self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")

View File

@ -271,22 +271,11 @@ class EventsStore(SQLBaseStore):
len(events_and_contexts) len(events_and_contexts)
) )
state_group_id_manager = self._state_groups_id_gen.get_next_mult(
len(events_and_contexts)
)
with stream_ordering_manager as stream_orderings: with stream_ordering_manager as stream_orderings:
with state_group_id_manager as state_group_ids: for (event, context), stream, in zip(
for (event, context), stream, state_group_id in zip( events_and_contexts, stream_orderings
events_and_contexts, stream_orderings, state_group_ids
): ):
event.internal_metadata.stream_ordering = stream event.internal_metadata.stream_ordering = stream
# Assign a state group_id in case a new id is needed for
# this context. In theory we only need to assign this
# for contexts that have current_state and aren't outliers
# but that make the code more complicated. Assigning an ID
# per event only causes the state_group_ids to grow as fast
# as the stream_ordering so in practise shouldn't be a problem.
context.new_state_group_id = state_group_id
chunks = [ chunks = [
events_and_contexts[x:x + 100] events_and_contexts[x:x + 100]
@ -312,9 +301,7 @@ class EventsStore(SQLBaseStore):
delete_existing=False): delete_existing=False):
try: try:
with self._stream_id_gen.get_next() as stream_ordering: with self._stream_id_gen.get_next() as stream_ordering:
with self._state_groups_id_gen.get_next() as state_group_id:
event.internal_metadata.stream_ordering = stream_ordering event.internal_metadata.stream_ordering = stream_ordering
context.new_state_group_id = state_group_id
yield self.runInteraction( yield self.runInteraction(
"persist_event", "persist_event",
self._persist_event_txn, self._persist_event_txn,
@ -528,7 +515,7 @@ class EventsStore(SQLBaseStore):
# Add an entry to the ex_outlier_stream table to replicate the # Add an entry to the ex_outlier_stream table to replicate the
# change in outlier status to our workers. # change in outlier status to our workers.
stream_order = event.internal_metadata.stream_ordering stream_order = event.internal_metadata.stream_ordering
state_group_id = context.state_group or context.new_state_group_id state_group_id = context.state_group
self._simple_insert_txn( self._simple_insert_txn(
txn, txn,
table="ex_outlier_stream", table="ex_outlier_stream",

View File

@ -354,7 +354,8 @@ class RoomMemberStore(SQLBaseStore):
desc="who_forgot" desc="who_forgot"
) )
def get_joined_users_from_context(self, room_id, state_group, state_ids): def get_joined_users_from_context(self, event, context):
state_group = context.state_group
if not state_group: if not state_group:
# If state_group is None it means it has yet to be assigned a # If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group # state group, i.e. we need to make sure that calls with a state_group
@ -363,12 +364,24 @@ class RoomMemberStore(SQLBaseStore):
state_group = object() state_group = object()
return self._get_joined_users_from_context( return self._get_joined_users_from_context(
room_id, state_group, state_ids event.room_id, state_group, context.current_state_ids, event=event,
)
def get_joined_users_from_state(self, room_id, state_group, state_ids):
if not state_group:
# If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group
# of None don't hit previous cached calls with a None state_group.
# To do this we set the state_group to a new object as object() != object()
state_group = object()
return self._get_joined_users_from_context(
room_id, state_group, state_ids,
) )
@cachedInlineCallbacks(num_args=2, cache_context=True) @cachedInlineCallbacks(num_args=2, cache_context=True)
def _get_joined_users_from_context(self, room_id, state_group, current_state_ids, def _get_joined_users_from_context(self, room_id, state_group, current_state_ids,
cache_context): cache_context, event=None):
# We don't use `state_group`, its there so that we can cache based # We don't use `state_group`, its there so that we can cache based
# on it. However, its important that its never None, since two current_state's # on it. However, its important that its never None, since two current_state's
# with a state_group of None are likely to be different. # with a state_group of None are likely to be different.
@ -393,7 +406,13 @@ class RoomMemberStore(SQLBaseStore):
desc="_get_joined_users_from_context", desc="_get_joined_users_from_context",
) )
defer.returnValue(set(row["user_id"] for row in rows)) users_in_room = set(row["user_id"] for row in rows)
if event is not None and event.type == EventTypes.Member:
if event.membership == Membership.JOIN:
if event.event_id in member_event_ids:
users_in_room.add(event.state_key)
defer.returnValue(users_in_room)
def is_host_joined(self, room_id, host, state_group, state_ids): def is_host_joined(self, room_id, host, state_group, state_ids):
if not state_group: if not state_group:

View File

@ -83,6 +83,14 @@ class StateStore(SQLBaseStore):
for group, event_id_map in group_to_ids.items() for group, event_id_map in group_to_ids.items()
}) })
def _have_persisted_state_group_txn(self, txn, state_group):
txn.execute(
"SELECT count(*) FROM state_groups WHERE id = ?",
(state_group,)
)
row = txn.fetchone()
return row and row[0]
def _store_mult_state_groups_txn(self, txn, events_and_contexts): def _store_mult_state_groups_txn(self, txn, events_and_contexts):
state_groups = {} state_groups = {}
for event, context in events_and_contexts: for event, context in events_and_contexts:
@ -92,22 +100,19 @@ class StateStore(SQLBaseStore):
if context.current_state_ids is None: if context.current_state_ids is None:
continue continue
if context.state_group is not None:
state_groups[event.event_id] = context.state_group state_groups[event.event_id] = context.state_group
if self._have_persisted_state_group_txn(txn, context.state_group):
logger.info("Already persisted state_group: %r", context.state_group)
continue continue
state_event_ids = dict(context.current_state_ids) state_event_ids = dict(context.current_state_ids)
if event.is_state():
state_event_ids[(event.type, event.state_key)] = event.event_id
state_group = context.new_state_group_id
self._simple_insert_txn( self._simple_insert_txn(
txn, txn,
table="state_groups", table="state_groups",
values={ values={
"id": state_group, "id": context.state_group,
"room_id": event.room_id, "room_id": event.room_id,
"event_id": event.event_id, "event_id": event.event_id,
}, },
@ -118,7 +123,7 @@ class StateStore(SQLBaseStore):
table="state_groups_state", table="state_groups_state",
values=[ values=[
{ {
"state_group": state_group, "state_group": context.state_group,
"room_id": event.room_id, "room_id": event.room_id,
"type": key[0], "type": key[0],
"state_key": key[1], "state_key": key[1],
@ -127,7 +132,6 @@ class StateStore(SQLBaseStore):
for key, state_id in state_event_ids.items() for key, state_id in state_event_ids.items()
], ],
) )
state_groups[event.event_id] = state_group
self._simple_insert_many_txn( self._simple_insert_many_txn(
txn, txn,
@ -527,5 +531,5 @@ class StateStore(SQLBaseStore):
"get_all_new_state_groups", get_all_new_state_groups_txn "get_all_new_state_groups", get_all_new_state_groups_txn
) )
def get_state_stream_token(self): def get_next_state_group(self):
return self._state_groups_id_gen.get_current_token() return self._state_groups_id_gen.get_next()

View File

@ -312,7 +312,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
else: else:
state_ids = None state_ids = None
context = EventContext(current_state_ids=state_ids) context = EventContext()
context.current_state_ids = state_ids
context.prev_state_ids = state_ids
context.push_actions = push_actions context.push_actions = push_actions
ordering = None ordering = None

View File

@ -60,8 +60,8 @@ class ReplicationResourceCase(unittest.TestCase):
self.assertEquals(body, {}) self.assertEquals(body, {})
@defer.inlineCallbacks @defer.inlineCallbacks
def test_events_and_state(self): def test_events(self):
get = self.get(events="-1", state="-1", timeout="0") get = self.get(events="-1", timeout="0")
yield self.hs.get_handlers().room_creation_handler.create_room( yield self.hs.get_handlers().room_creation_handler.create_room(
synapse.types.create_requester(self.user), {} synapse.types.create_requester(self.user), {}
) )
@ -70,12 +70,6 @@ class ReplicationResourceCase(unittest.TestCase):
self.assertEquals(body["events"]["field_names"], [ self.assertEquals(body["events"]["field_names"], [
"position", "internal", "json", "state_group" "position", "internal", "json", "state_group"
]) ])
self.assertEquals(body["state_groups"]["field_names"], [
"position", "room_id", "event_id"
])
self.assertEquals(body["state_group_state"]["field_names"], [
"position", "type", "state_key", "event_id"
])
@defer.inlineCallbacks @defer.inlineCallbacks
def test_presence(self): def test_presence(self):

View File

@ -86,17 +86,8 @@ class StateGroupStore(object):
state_events = dict(context.current_state_ids) state_events = dict(context.current_state_ids)
if event.is_state(): self._group_to_state[context.state_group] = state_events
state_events[(event.type, event.state_key)] = event.event_id self._event_to_state_group[event.event_id] = context.state_group
state_group = context.state_group
if not state_group:
state_group = self._next_group
self._next_group += 1
self._group_to_state[state_group] = state_events
self._event_to_state_group[event.event_id] = state_group
def get_events(self, event_ids, **kwargs): def get_events(self, event_ids, **kwargs):
return { return {
@ -151,6 +142,7 @@ class StateTestCase(unittest.TestCase):
"get_state_groups_ids", "get_state_groups_ids",
"add_event_hashes", "add_event_hashes",
"get_events", "get_events",
"get_next_state_group",
] ]
) )
hs = Mock(spec_set=[ hs = Mock(spec_set=[
@ -161,6 +153,8 @@ class StateTestCase(unittest.TestCase):
hs.get_clock.return_value = MockClock() hs.get_clock.return_value = MockClock()
hs.get_auth.return_value = Auth(hs) hs.get_auth.return_value = Auth(hs)
self.store.get_next_state_group.side_effect = Mock
self.state = StateHandler(hs) self.state = StateHandler(hs)
self.event_id = 0 self.event_id = 0
@ -209,7 +203,7 @@ class StateTestCase(unittest.TestCase):
store.store_state_groups(event, context) store.store_state_groups(event, context)
context_store[event.event_id] = context context_store[event.event_id] = context
self.assertEqual(2, len(context_store["D"].current_state_ids)) self.assertEqual(2, len(context_store["D"].prev_state_ids))
@defer.inlineCallbacks @defer.inlineCallbacks
def test_branch_basic_conflict(self): def test_branch_basic_conflict(self):
@ -265,7 +259,7 @@ class StateTestCase(unittest.TestCase):
self.assertSetEqual( self.assertSetEqual(
{"START", "A", "C"}, {"START", "A", "C"},
{e_id for e_id in context_store["D"].current_state_ids.values()} {e_id for e_id in context_store["D"].prev_state_ids.values()}
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -331,7 +325,7 @@ class StateTestCase(unittest.TestCase):
self.assertSetEqual( self.assertSetEqual(
{"START", "A", "B", "C"}, {"START", "A", "B", "C"},
{e for e in context_store["E"].current_state_ids.values()} {e for e in context_store["E"].prev_state_ids.values()}
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -414,7 +408,7 @@ class StateTestCase(unittest.TestCase):
self.assertSetEqual( self.assertSetEqual(
{"A1", "A2", "A3", "A5", "B"}, {"A1", "A2", "A3", "A5", "B"},
{e for e in context_store["D"].current_state_ids.values()} {e for e in context_store["D"].prev_state_ids.values()}
) )
def _add_depths(self, nodes, edges): def _add_depths(self, nodes, edges):
@ -447,7 +441,7 @@ class StateTestCase(unittest.TestCase):
set(e.event_id for e in old_state), set(context.current_state_ids.values()) set(e.event_id for e in old_state), set(context.current_state_ids.values())
) )
self.assertIsNone(context.state_group) self.assertIsNotNone(context.state_group)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_annotate_with_old_state(self): def test_annotate_with_old_state(self):
@ -464,11 +458,9 @@ class StateTestCase(unittest.TestCase):
) )
self.assertEqual( self.assertEqual(
set(e.event_id for e in old_state), set(context.current_state_ids.values()) set(e.event_id for e in old_state), set(context.prev_state_ids.values())
) )
self.assertIsNone(context.state_group)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_trivial_annotate_message(self): def test_trivial_annotate_message(self):
event = create_event(type="test_message", name="event") event = create_event(type="test_message", name="event")
@ -514,10 +506,10 @@ class StateTestCase(unittest.TestCase):
self.assertEqual( self.assertEqual(
set([e.event_id for e in old_state]), set([e.event_id for e in old_state]),
set(context.current_state_ids.values()) set(context.prev_state_ids.values())
) )
self.assertIsNone(context.state_group) self.assertIsNotNone(context.state_group)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_resolve_message_conflict(self): def test_resolve_message_conflict(self):
@ -550,7 +542,7 @@ class StateTestCase(unittest.TestCase):
self.assertEqual(len(context.current_state_ids), 6) self.assertEqual(len(context.current_state_ids), 6)
self.assertIsNone(context.state_group) self.assertIsNotNone(context.state_group)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_resolve_state_conflict(self): def test_resolve_state_conflict(self):
@ -583,7 +575,7 @@ class StateTestCase(unittest.TestCase):
self.assertEqual(len(context.current_state_ids), 6) self.assertEqual(len(context.current_state_ids), 6)
self.assertIsNone(context.state_group) self.assertIsNotNone(context.state_group)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_standard_depth_conflict(self): def test_standard_depth_conflict(self):