Add is_guest flag to users db to track whether a user is a guest user or not. Use this so we can run _filter_events_for_client when calculating event_push_actions.

This commit is contained in:
David Baker 2016-01-06 11:38:09 +00:00
parent eb03625626
commit c79f221192
9 changed files with 69 additions and 31 deletions

View File

@ -23,8 +23,6 @@ from synapse.push.action_generator import ActionGenerator
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
from synapse.events.utils import serialize_event
import logging import logging
@ -256,9 +254,9 @@ class BaseHandler(object):
) )
action_generator = ActionGenerator(self.store) action_generator = ActionGenerator(self.store)
yield action_generator.handle_push_actions_for_event(serialize_event( yield action_generator.handle_push_actions_for_event(
event, self.clock.time_msec() event, self
)) )
destinations = set(extra_destinations) destinations = set(extra_destinations)
for k, s in context.current_state.items(): for k, s in context.current_state.items():

View File

@ -32,7 +32,7 @@ from synapse.crypto.event_signing import (
) )
from synapse.types import UserID from synapse.types import UserID
from synapse.events.utils import prune_event, serialize_event from synapse.events.utils import prune_event
from synapse.util.retryutils import NotRetryingDestination from synapse.util.retryutils import NotRetryingDestination
@ -246,8 +246,8 @@ class FederationHandler(BaseHandler):
if not backfilled and not event.internal_metadata.is_outlier(): if not backfilled and not event.internal_metadata.is_outlier():
action_generator = ActionGenerator(self.store) action_generator = ActionGenerator(self.store)
yield action_generator.handle_push_actions_for_event(serialize_event( yield action_generator.handle_push_actions_for_event(
event, self.clock.time_msec()) event, self
) )
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -84,7 +84,8 @@ class RegistrationHandler(BaseHandler):
localpart=None, localpart=None,
password=None, password=None,
generate_token=True, generate_token=True,
guest_access_token=None guest_access_token=None,
make_guest=False
): ):
"""Registers a new client on the server. """Registers a new client on the server.
@ -118,6 +119,7 @@ class RegistrationHandler(BaseHandler):
token=token, token=token,
password_hash=password_hash, password_hash=password_hash,
was_guest=guest_access_token is not None, was_guest=guest_access_token is not None,
make_guest=make_guest
) )
yield registered_user(self.distributor, user) yield registered_user(self.distributor, user)

View File

@ -33,12 +33,12 @@ class ActionGenerator:
# tag (ie. we just need all the users). # tag (ie. we just need all the users).
@defer.inlineCallbacks @defer.inlineCallbacks
def handle_push_actions_for_event(self, event): def handle_push_actions_for_event(self, event, handler):
bulk_evaluator = yield bulk_push_rule_evaluator.evaluator_for_room_id( bulk_evaluator = yield bulk_push_rule_evaluator.evaluator_for_room_id(
event['room_id'], self.store event.room_id, self.store
) )
actions_by_user = bulk_evaluator.action_for_event_by_user(event) actions_by_user = yield bulk_evaluator.action_for_event_by_user(event, handler)
yield self.store.set_push_actions_for_event_and_users( yield self.store.set_push_actions_for_event_and_users(
event, event,

View File

@ -23,6 +23,8 @@ from synapse.types import UserID
import baserules import baserules
from push_rule_evaluator import PushRuleEvaluator from push_rule_evaluator import PushRuleEvaluator
from synapse.events.utils import serialize_event
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -54,7 +56,7 @@ def evaluator_for_room_id(room_id, store):
display_names[ev.state_key] = ev.content.get("displayname") display_names[ev.state_key] = ev.content.get("displayname")
defer.returnValue(BulkPushRuleEvaluator( defer.returnValue(BulkPushRuleEvaluator(
room_id, rules_by_user, display_names, users room_id, rules_by_user, display_names, users, store
)) ))
@ -67,13 +69,15 @@ class BulkPushRuleEvaluator:
the same logic to run the actual rules, but could be optimised further the same logic to run the actual rules, but could be optimised further
(see https://matrix.org/jira/browse/SYN-562) (see https://matrix.org/jira/browse/SYN-562)
""" """
def __init__(self, room_id, rules_by_user, display_names, users_in_room): def __init__(self, room_id, rules_by_user, display_names, users_in_room, store):
self.room_id = room_id self.room_id = room_id
self.rules_by_user = rules_by_user self.rules_by_user = rules_by_user
self.display_names = display_names self.display_names = display_names
self.users_in_room = users_in_room self.users_in_room = users_in_room
self.store = store
def action_for_event_by_user(self, event): @defer.inlineCallbacks
def action_for_event_by_user(self, event, handler):
actions_by_user = {} actions_by_user = {}
for uid, rules in self.rules_by_user.items(): for uid, rules in self.rules_by_user.items():
@ -81,6 +85,13 @@ class BulkPushRuleEvaluator:
if uid in self.display_names: if uid in self.display_names:
display_name = self.display_names[uid] display_name = self.display_names[uid]
is_guest = yield self.store.is_guest(UserID.from_string(uid))
filtered = yield handler._filter_events_for_client(
uid, [event], is_guest=is_guest
)
if len(filtered) == 0:
continue
for rule in rules: for rule in rules:
if 'enabled' in rule and not rule['enabled']: if 'enabled' in rule and not rule['enabled']:
continue continue
@ -94,14 +105,20 @@ class BulkPushRuleEvaluator:
if len(actions) > 0: if len(actions) > 0:
actions_by_user[uid] = actions actions_by_user[uid] = actions
break break
return actions_by_user defer.returnValue(actions_by_user)
@staticmethod @staticmethod
def event_matches_rule(event, rule, def event_matches_rule(event, rule,
display_name, room_member_count, profile_tag): display_name, room_member_count, profile_tag):
matches = True matches = True
# passing the clock all the way into here is extremely awkward and push
# rules do not care about any of the relative timestamps, so we just
# pass 0 for the current time.
client_event = serialize_event(event, 0)
for cond in rule['conditions']: for cond in rule['conditions']:
matches &= PushRuleEvaluator._event_fulfills_condition( matches &= PushRuleEvaluator._event_fulfills_condition(
event, cond, display_name, room_member_count, profile_tag client_event, cond, display_name, room_member_count, profile_tag
) )
return matches return matches

View File

@ -259,7 +259,10 @@ class RegisterRestServlet(RestServlet):
def _do_guest_registration(self): def _do_guest_registration(self):
if not self.hs.config.allow_guest_access: if not self.hs.config.allow_guest_access:
defer.returnValue((403, "Guest access is disabled")) defer.returnValue((403, "Guest access is disabled"))
user_id, _ = yield self.registration_handler.register(generate_token=False) user_id, _ = yield self.registration_handler.register(
generate_token=False,
make_guest=True
)
access_token = self.auth_handler.generate_access_token(user_id, ["guest = true"]) access_token = self.auth_handler.generate_access_token(user_id, ["guest = true"])
defer.returnValue((200, { defer.returnValue((200, {
"user_id": user_id, "user_id": user_id,

View File

@ -32,8 +32,8 @@ class EventPushActionsStore(SQLBaseStore):
values = [] values = []
for uid, profile_tag, actions in tuples: for uid, profile_tag, actions in tuples:
values.append({ values.append({
'room_id': event['room_id'], 'room_id': event.room_id,
'event_id': event['event_id'], 'event_id': event.event_id,
'user_id': uid, 'user_id': uid,
'profile_tag': profile_tag, 'profile_tag': profile_tag,
'actions': json.dumps(actions) 'actions': json.dumps(actions)

View File

@ -18,7 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import StoreError, Codes from synapse.api.errors import StoreError, Codes
from ._base import SQLBaseStore from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
class RegistrationStore(SQLBaseStore): class RegistrationStore(SQLBaseStore):
@ -73,7 +73,8 @@ class RegistrationStore(SQLBaseStore):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def register(self, user_id, token, password_hash, was_guest=False): def register(self, user_id, token, password_hash,
was_guest=False, make_guest=False):
"""Attempts to register an account. """Attempts to register an account.
Args: Args:
@ -82,15 +83,18 @@ class RegistrationStore(SQLBaseStore):
password_hash (str): Optional. The password hash for this user. password_hash (str): Optional. The password hash for this user.
was_guest (bool): Optional. Whether this is a guest account being was_guest (bool): Optional. Whether this is a guest account being
upgraded to a non-guest account. upgraded to a non-guest account.
make_guest (boolean): True if the the new user should be guest,
false to add a regular user account.
Raises: Raises:
StoreError if the user_id could not be registered. StoreError if the user_id could not be registered.
""" """
yield self.runInteraction( yield self.runInteraction(
"register", "register",
self._register, user_id, token, password_hash, was_guest self._register, user_id, token, password_hash, was_guest, make_guest
) )
self.is_guest.invalidate((user_id,))
def _register(self, txn, user_id, token, password_hash, was_guest): def _register(self, txn, user_id, token, password_hash, was_guest, make_guest):
now = int(self.clock.time()) now = int(self.clock.time())
next_id = self._access_tokens_id_gen.get_next_txn(txn) next_id = self._access_tokens_id_gen.get_next_txn(txn)
@ -100,12 +104,14 @@ class RegistrationStore(SQLBaseStore):
txn.execute("UPDATE users SET" txn.execute("UPDATE users SET"
" password_hash = ?," " password_hash = ?,"
" upgrade_ts = ?" " upgrade_ts = ?"
" is_guest = ?"
" WHERE name = ?", " WHERE name = ?",
[password_hash, now, user_id]) [password_hash, now, make_guest, user_id])
else: else:
txn.execute("INSERT INTO users(name, password_hash, creation_ts) " txn.execute("INSERT INTO users "
"VALUES (?,?,?)", "(name, password_hash, creation_ts, is_guest) "
[user_id, password_hash, now]) "VALUES (?,?,?,?)",
[user_id, password_hash, now, make_guest])
except self.database_engine.module.IntegrityError: except self.database_engine.module.IntegrityError:
raise StoreError( raise StoreError(
400, "User ID already taken.", errcode=Codes.USER_IN_USE 400, "User ID already taken.", errcode=Codes.USER_IN_USE
@ -126,7 +132,7 @@ class RegistrationStore(SQLBaseStore):
keyvalues={ keyvalues={
"name": user_id, "name": user_id,
}, },
retcols=["name", "password_hash"], retcols=["name", "password_hash", "is_guest"],
allow_none=True, allow_none=True,
) )
@ -136,7 +142,7 @@ class RegistrationStore(SQLBaseStore):
""" """
def f(txn): def f(txn):
sql = ( sql = (
"SELECT name, password_hash FROM users" "SELECT name, password_hash, is_guest FROM users"
" WHERE lower(name) = lower(?)" " WHERE lower(name) = lower(?)"
) )
txn.execute(sql, (user_id,)) txn.execute(sql, (user_id,))
@ -249,9 +255,21 @@ class RegistrationStore(SQLBaseStore):
defer.returnValue(res if res else False) defer.returnValue(res if res else False)
@cachedInlineCallbacks()
def is_guest(self, user):
res = yield self._simple_select_one_onecol(
table="users",
keyvalues={"name": user.to_string()},
retcol="is_guest",
allow_none=True,
desc="is_guest",
)
defer.returnValue(res if res else False)
def _query_for_auth(self, txn, token): def _query_for_auth(self, txn, token):
sql = ( sql = (
"SELECT users.name, access_tokens.id as token_id" "SELECT users.name, users.is_guest, access_tokens.id as token_id"
" FROM users" " FROM users"
" INNER JOIN access_tokens on users.name = access_tokens.user_id" " INNER JOIN access_tokens on users.name = access_tokens.user_id"
" WHERE token = ?" " WHERE token = ?"