Begin making auth use event.old_state_events

This commit is contained in:
Erik Johnston 2014-10-15 16:06:59 +01:00
parent 80472ac198
commit e7bc1291a0
10 changed files with 115 additions and 83 deletions

View File

@ -21,6 +21,7 @@ from synapse.api.constants import Membership, JoinRules
from synapse.api.errors import AuthError, StoreError, Codes, SynapseError from synapse.api.errors import AuthError, StoreError, Codes, SynapseError
from synapse.api.events.room import ( from synapse.api.events.room import (
RoomMemberEvent, RoomPowerLevelsEvent, RoomRedactionEvent, RoomMemberEvent, RoomPowerLevelsEvent, RoomRedactionEvent,
RoomJoinRulesEvent, RoomOpsPowerLevelsEvent,
) )
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
@ -55,11 +56,7 @@ class Auth(object):
defer.returnValue(allowed) defer.returnValue(allowed)
return return
self._check_joined_room( self.check_event_sender_in_room(event)
member=snapshot.membership_state,
user_id=snapshot.user_id,
room_id=snapshot.room_id,
)
if is_state: if is_state:
# TODO (erikj): This really only should be called for *new* # TODO (erikj): This really only should be called for *new*
@ -98,6 +95,16 @@ class Auth(object):
pass pass
defer.returnValue(None) defer.returnValue(None)
def check_event_sender_in_room(self, event):
key = (RoomMemberEvent.TYPE, event.user_id, )
member_event = event.state_events.get(key)
return self._check_joined_room(
member_event,
event.user_id,
event.room_id
)
def _check_joined_room(self, member, user_id, room_id): def _check_joined_room(self, member, user_id, room_id):
if not member or member.membership != Membership.JOIN: if not member or member.membership != Membership.JOIN:
raise AuthError(403, "User %s not in room %s (%s)" % ( raise AuthError(403, "User %s not in room %s (%s)" % (
@ -114,29 +121,39 @@ class Auth(object):
raise AuthError(403, "Room does not exist") raise AuthError(403, "Room does not exist")
# get info about the caller # get info about the caller
try: key = (RoomMemberEvent.TYPE, event.user_id, )
caller = yield self.store.get_room_member( caller = event.old_state_events.get(key)
user_id=event.user_id,
room_id=event.room_id)
except:
caller = None
caller_in_room = caller and caller.membership == "join" caller_in_room = caller and caller.membership == "join"
# get info about the target # get info about the target
try: key = (RoomMemberEvent.TYPE, target_user_id, )
target = yield self.store.get_room_member( target = event.old_state_events.get(key)
user_id=target_user_id,
room_id=event.room_id)
except:
target = None
target_in_room = target and target.membership == "join" target_in_room = target and target.membership == "join"
membership = event.content["membership"] membership = event.content["membership"]
join_rule = yield self.store.get_room_join_rule(event.room_id) key = (RoomJoinRulesEvent.TYPE, "", )
if not join_rule: join_rule_event = event.old_state_events.get(key)
if join_rule_event:
join_rule = join_rule_event.content.get(
"join_rule", JoinRules.INVITE
)
else:
join_rule = JoinRules.INVITE join_rule = JoinRules.INVITE
user_level = self._get_power_level_from_event_state(
event,
event.user_id,
)
ban_level, kick_level, redact_level = (
yield self._get_ops_level_from_event_state(
event
)
)
if Membership.INVITE == membership: if Membership.INVITE == membership:
# TODO (erikj): We should probably handle this more intelligently # TODO (erikj): We should probably handle this more intelligently
# PRIVATE join rules. # PRIVATE join rules.
@ -171,29 +188,16 @@ class Auth(object):
if not caller_in_room: # trying to leave a room you aren't joined if not caller_in_room: # trying to leave a room you aren't joined
raise AuthError(403, "You are not in room %s." % event.room_id) raise AuthError(403, "You are not in room %s." % event.room_id)
elif target_user_id != event.user_id: elif target_user_id != event.user_id:
user_level = yield self.store.get_power_level(
event.room_id,
event.user_id,
)
_, kick_level, _ = yield self.store.get_ops_levels(event.room_id)
if kick_level: if kick_level:
kick_level = int(kick_level) kick_level = int(kick_level)
else: else:
kick_level = 50 kick_level = 50 # FIXME (erikj): What should we do here?
if user_level < kick_level: if user_level < kick_level:
raise AuthError( raise AuthError(
403, "You cannot kick user %s." % target_user_id 403, "You cannot kick user %s." % target_user_id
) )
elif Membership.BAN == membership: elif Membership.BAN == membership:
user_level = yield self.store.get_power_level(
event.room_id,
event.user_id,
)
ban_level, _, _ = yield self.store.get_ops_levels(event.room_id)
if ban_level: if ban_level:
ban_level = int(ban_level) ban_level = int(ban_level)
else: else:
@ -206,6 +210,29 @@ class Auth(object):
defer.returnValue(True) defer.returnValue(True)
def _get_power_level_from_event_state(self, event, user_id):
key = (RoomPowerLevelsEvent.TYPE, "", )
power_level_event = event.old_state_events.get(key)
level = None
if power_level_event:
level = power_level_event.content[user_id]
if not level:
level = power_level_event.content["default"]
return level
def _get_ops_level_from_event_state(self, event):
key = (RoomOpsPowerLevelsEvent.TYPE, "", )
ops_event = event.old_state_events.get(key)
if ops_event:
return (
ops_event.content.get("ban_level"),
ops_event.content.get("kick_level"),
ops_event.content.get("redact_level"),
)
return None, None, None,
@defer.inlineCallbacks @defer.inlineCallbacks
def get_user_by_req(self, request): def get_user_by_req(self, request):
""" Get a registered user's ID. """ Get a registered user's ID.
@ -282,8 +309,8 @@ class Auth(object):
else: else:
send_level = 0 send_level = 0
user_level = yield self.store.get_power_level( user_level = self._get_power_level_from_event_state(
event.room_id, event,
event.user_id, event.user_id,
) )
@ -308,8 +335,8 @@ class Auth(object):
add_level = int(add_level) add_level = int(add_level)
user_level = yield self.store.get_power_level( user_level = self._get_power_level_from_event_state(
event.room_id, event,
event.user_id, event.user_id,
) )
@ -333,8 +360,8 @@ class Auth(object):
if current_state: if current_state:
current_state = current_state[0] current_state = current_state[0]
user_level = yield self.store.get_power_level( user_level = self._get_power_level_from_event_state(
event.room_id, event,
event.user_id, event.user_id,
) )
@ -363,10 +390,10 @@ class Auth(object):
event.user_id, event.user_id,
) )
if user_level: user_level = self._get_power_level_from_event_state(
user_level = int(user_level) event,
else: event.user_id,
user_level = 0 )
_, _, redact_level = yield self.store.get_ops_levels(event.room_id) _, _, redact_level = yield self.store.get_ops_levels(event.room_id)

View File

@ -44,9 +44,17 @@ class BaseHandler(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def _on_new_room_event(self, event, snapshot, extra_destinations=[], def _on_new_room_event(self, event, snapshot, extra_destinations=[],
extra_users=[]): extra_users=[], suppress_auth=False):
snapshot.fill_out_prev_events(event) snapshot.fill_out_prev_events(event)
yield self.state_handler.annotate_state_groups(event)
if not suppress_auth:
yield self.auth.check(event, snapshot, raises=True)
if hasattr(event, "state_key"):
yield self.state_handler.handle_new_event(event, snapshot)
yield self.store.persist_event(event) yield self.store.persist_event(event)
destinations = set(extra_destinations) destinations = set(extra_destinations)

View File

@ -152,5 +152,6 @@ class DirectoryHandler(BaseHandler):
user_id=user_id, user_id=user_id,
) )
yield self.state_handler.handle_new_event(event, snapshot) yield self._on_new_room_event(
yield self._on_new_room_event(event, snapshot, extra_users=[user_id]) event, snapshot, extra_users=[user_id], suppress_auth=True
)

View File

@ -95,6 +95,8 @@ class FederationHandler(BaseHandler):
logger.debug("Got event: %s", event.event_id) logger.debug("Got event: %s", event.event_id)
yield self.state_handler.annotate_state_groups(event)
with (yield self.lock_manager.lock(pdu.context)): with (yield self.lock_manager.lock(pdu.context)):
if event.is_state and not backfilled: if event.is_state and not backfilled:
is_new_state = yield self.state_handler.handle_new_state( is_new_state = yield self.state_handler.handle_new_state(
@ -195,7 +197,12 @@ class FederationHandler(BaseHandler):
for pdu in pdus: for pdu in pdus:
event = self.pdu_codec.event_from_pdu(pdu) event = self.pdu_codec.event_from_pdu(pdu)
# FIXME (erikj): Not sure this actually works :/
yield self.state_handler.annotate_state_groups(event)
events.append(event) events.append(event)
yield self.store.persist_event(event, backfilled=True) yield self.store.persist_event(event, backfilled=True)
defer.returnValue(events) defer.returnValue(events)
@ -235,6 +242,7 @@ class FederationHandler(BaseHandler):
new_event.destinations = [target_host] new_event.destinations = [target_host]
snapshot.fill_out_prev_events(new_event) snapshot.fill_out_prev_events(new_event)
yield self.state_handler.annotate_state_groups(new_event)
yield self.handle_new_event(new_event, snapshot) yield self.handle_new_event(new_event, snapshot)
# TODO (erikj): Time out here. # TODO (erikj): Time out here.
@ -254,12 +262,11 @@ class FederationHandler(BaseHandler):
is_public=False is_public=False
) )
except: except:
# FIXME
pass pass
defer.returnValue(True) defer.returnValue(True)
@log_function @log_function
def _on_user_joined(self, user, room_id): def _on_user_joined(self, user, room_id):
waiters = self.waiting_for_join_list.get((user.to_string(), room_id), []) waiters = self.waiting_for_join_list.get((user.to_string(), room_id), [])

View File

@ -87,10 +87,9 @@ class MessageHandler(BaseHandler):
snapshot = yield self.store.snapshot_room(event.room_id, event.user_id) snapshot = yield self.store.snapshot_room(event.room_id, event.user_id)
if not suppress_auth: yield self._on_new_room_event(
yield self.auth.check(event, snapshot, raises=True) event, snapshot, suppress_auth=suppress_auth
)
yield self._on_new_room_event(event, snapshot)
self.hs.get_handlers().presence_handler.bump_presence_active_time( self.hs.get_handlers().presence_handler.bump_presence_active_time(
user user
@ -149,13 +148,9 @@ class MessageHandler(BaseHandler):
state_key=event.state_key, state_key=event.state_key,
) )
yield self.auth.check(event, snapshot, raises=True)
if stamp_event: if stamp_event:
event.content["hsob_ts"] = int(self.clock.time_msec()) event.content["hsob_ts"] = int(self.clock.time_msec())
yield self.state_handler.handle_new_event(event, snapshot)
yield self._on_new_room_event(event, snapshot) yield self._on_new_room_event(event, snapshot)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -227,8 +222,6 @@ class MessageHandler(BaseHandler):
snapshot = yield self.store.snapshot_room(event.room_id, event.user_id) snapshot = yield self.store.snapshot_room(event.room_id, event.user_id)
yield self.auth.check(event, snapshot, raises=True)
# store message in db # store message in db
yield self._on_new_room_event(event, snapshot) yield self._on_new_room_event(event, snapshot)

View File

@ -218,5 +218,6 @@ class ProfileHandler(BaseHandler):
user_id=j.state_key, user_id=j.state_key,
) )
yield self.state_handler.handle_new_event(new_event, snapshot) yield self._on_new_room_event(
yield self._on_new_room_event(new_event, snapshot) new_event, snapshot, suppress_auth=True
)

View File

@ -129,8 +129,9 @@ class RoomCreationHandler(BaseHandler):
logger.debug("Event: %s", event) logger.debug("Event: %s", event)
yield self.state_handler.handle_new_event(event, snapshot) yield self._on_new_room_event(
yield self._on_new_room_event(event, snapshot, extra_users=[user]) event, snapshot, extra_users=[user], suppress_auth=True
)
for event in creation_events: for event in creation_events:
yield handle_event(event) yield handle_event(event)
@ -396,8 +397,6 @@ class RoomMemberHandler(BaseHandler):
yield self._do_join(event, snapshot, do_auth=do_auth) yield self._do_join(event, snapshot, do_auth=do_auth)
else: else:
# This is not a JOIN, so we can handle it normally. # This is not a JOIN, so we can handle it normally.
if do_auth:
yield self.auth.check(event, snapshot, raises=True)
# If we're banning someone, set a req power level # If we're banning someone, set a req power level
if event.membership == Membership.BAN: if event.membership == Membership.BAN:
@ -419,6 +418,7 @@ class RoomMemberHandler(BaseHandler):
event, event,
membership=event.content["membership"], membership=event.content["membership"],
snapshot=snapshot, snapshot=snapshot,
do_auth=do_auth,
) )
defer.returnValue({"room_id": room_id}) defer.returnValue({"room_id": room_id})
@ -507,14 +507,11 @@ class RoomMemberHandler(BaseHandler):
if not have_joined: if not have_joined:
logger.debug("Doing normal join") logger.debug("Doing normal join")
if do_auth:
yield self.auth.check(event, snapshot, raises=True)
yield self.state_handler.handle_new_event(event, snapshot)
yield self._do_local_membership_update( yield self._do_local_membership_update(
event, event,
membership=event.content["membership"], membership=event.content["membership"],
snapshot=snapshot, snapshot=snapshot,
do_auth=do_auth,
) )
user = self.hs.parse_userid(event.user_id) user = self.hs.parse_userid(event.user_id)
@ -558,7 +555,8 @@ class RoomMemberHandler(BaseHandler):
defer.returnValue([r.room_id for r in rooms]) defer.returnValue([r.room_id for r in rooms])
def _do_local_membership_update(self, event, membership, snapshot): def _do_local_membership_update(self, event, membership, snapshot,
do_auth):
destinations = [] destinations = []
# If we're inviting someone, then we should also send it to that # If we're inviting someone, then we should also send it to that
@ -575,9 +573,10 @@ class RoomMemberHandler(BaseHandler):
return self._on_new_room_event( return self._on_new_room_event(
event, snapshot, extra_destinations=destinations, event, snapshot, extra_destinations=destinations,
extra_users=[target_user] extra_users=[target_user], suppress_auth=(not do_auth),
) )
class RoomListHandler(BaseHandler): class RoomListHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -71,6 +71,7 @@ class StateHandler(object):
# (w.r.t. to power levels) # (w.r.t. to power levels)
snapshot.fill_out_prev_events(event) snapshot.fill_out_prev_events(event)
yield self.annotate_state_groups(event)
event.prev_events = [ event.prev_events = [
e for e in event.prev_events if e != event.event_id e for e in event.prev_events if e != event.event_id
@ -83,8 +84,6 @@ class StateHandler(object):
current_state.pdu_id, current_state.origin current_state.pdu_id, current_state.origin
) )
yield self.update_state_groups(event)
# TODO check current_state to see if the min power level is less # TODO check current_state to see if the min power level is less
# than the power level of the user # than the power level of the user
# power_level = self._get_power_level_for_event(event) # power_level = self._get_power_level_for_event(event)
@ -131,21 +130,16 @@ class StateHandler(object):
defer.returnValue(is_new) defer.returnValue(is_new)
@defer.inlineCallbacks @defer.inlineCallbacks
def update_state_groups(self, event): def annotate_state_groups(self, event):
state_groups = yield self.store.get_state_groups( state_groups = yield self.store.get_state_groups(
event.prev_events event.prev_events
) )
if len(state_groups) == 1 and not hasattr(event, "state_key"):
event.state_group = state_groups[0].group
event.current_state = state_groups[0].state
return
state = {} state = {}
state_sets = {} state_sets = {}
for group in state_groups: for group in state_groups:
for s in group.state: for s in group.state:
state.setdefault((s.type, s.state_key), []).add(s) state.setdefault((s.type, s.state_key), []).append(s)
state_sets.setdefault( state_sets.setdefault(
(s.type, s.state_key), (s.type, s.state_key),
@ -153,7 +147,7 @@ class StateHandler(object):
).add(s.event_id) ).add(s.event_id)
unconflicted_state = { unconflicted_state = {
k: v.pop() for k, v in state_sets.items() k: state[k].pop() for k, v in state_sets.items()
if len(v) == 1 if len(v) == 1
} }
@ -168,11 +162,13 @@ class StateHandler(object):
for key, events in conflicted_state.items(): for key, events in conflicted_state.items():
new_state[key] = yield self.resolve(events) new_state[key] = yield self.resolve(events)
event.old_state_events = new_state
if hasattr(event, "state_key"): if hasattr(event, "state_key"):
new_state[(event.type, event.state_key)] = event new_state[(event.type, event.state_key)] = event
event.state_group = None event.state_group = None
event.current_state = new_state.values() event.state_events = new_state
@defer.inlineCallbacks @defer.inlineCallbacks
def resolve(self, events): def resolve(self, events):

View File

@ -14,7 +14,7 @@
*/ */
CREATE TABLE IF NOT EXISTS state_groups( CREATE TABLE IF NOT EXISTS state_groups(
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY,
room_id TEXT NOT NULL, room_id TEXT NOT NULL,
event_id TEXT NOT NULL event_id TEXT NOT NULL
); );

View File

@ -74,7 +74,7 @@ class StateStore(SQLBaseStore):
} }
) )
for state in event.state_events: for state in event.state_events.values():
self._simple_insert_txn( self._simple_insert_txn(
txn, txn,
table="state_groups_state", table="state_groups_state",