Merge branch 'develop' into dbkr/email_notifs_on_pusher

This commit is contained in:
Mark Haines 2016-05-13 16:59:56 +01:00
commit 1f71f386f6
19 changed files with 511 additions and 27 deletions

View File

@ -612,7 +612,8 @@ class Auth(object):
def get_user_from_macaroon(self, macaroon_str): def get_user_from_macaroon(self, macaroon_str):
try: try:
macaroon = pymacaroons.Macaroon.deserialize(macaroon_str) macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
self.validate_macaroon(macaroon, "access", False)
self.validate_macaroon(macaroon, "access", self.hs.config.expire_access_token)
user_prefix = "user_id = " user_prefix = "user_id = "
user = None user = None

View File

@ -57,6 +57,8 @@ class KeyConfig(Config):
seed = self.signing_key[0].seed seed = self.signing_key[0].seed
self.macaroon_secret_key = hashlib.sha256(seed) self.macaroon_secret_key = hashlib.sha256(seed)
self.expire_access_token = config.get("expire_access_token", False)
def default_config(self, config_dir_path, server_name, is_generating_file=False, def default_config(self, config_dir_path, server_name, is_generating_file=False,
**kwargs): **kwargs):
base_key_name = os.path.join(config_dir_path, server_name) base_key_name = os.path.join(config_dir_path, server_name)
@ -69,6 +71,9 @@ class KeyConfig(Config):
return """\ return """\
macaroon_secret_key: "%(macaroon_secret_key)s" macaroon_secret_key: "%(macaroon_secret_key)s"
# Used to enable access token expiration.
expire_access_token: False
## Signing Keys ## ## Signing Keys ##
# Path to the signing key to sign messages with # Path to the signing key to sign messages with

View File

@ -32,6 +32,7 @@ class RegistrationConfig(Config):
) )
self.registration_shared_secret = config.get("registration_shared_secret") self.registration_shared_secret = config.get("registration_shared_secret")
self.user_creation_max_duration = int(config["user_creation_max_duration"])
self.bcrypt_rounds = config.get("bcrypt_rounds", 12) self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
self.trusted_third_party_id_servers = config["trusted_third_party_id_servers"] self.trusted_third_party_id_servers = config["trusted_third_party_id_servers"]
@ -54,6 +55,11 @@ class RegistrationConfig(Config):
# secret, even if registration is otherwise disabled. # secret, even if registration is otherwise disabled.
registration_shared_secret: "%(registration_shared_secret)s" registration_shared_secret: "%(registration_shared_secret)s"
# Sets the expiry for the short term user creation in
# milliseconds. For instance the bellow duration is two weeks
# in milliseconds.
user_creation_max_duration: 1209600000
# Set the number of bcrypt rounds used to generate password hash. # Set the number of bcrypt rounds used to generate password hash.
# Larger numbers increase the work factor needed to generate the hash. # Larger numbers increase the work factor needed to generate the hash.
# The default number of rounds is 12. # The default number of rounds is 12.

View File

@ -521,11 +521,11 @@ class AuthHandler(BaseHandler):
)) ))
return m.serialize() return m.serialize()
def generate_short_term_login_token(self, user_id): def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
macaroon = self._generate_base_macaroon(user_id) macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = login") macaroon.add_first_party_caveat("type = login")
now = self.hs.get_clock().time_msec() now = self.hs.get_clock().time_msec()
expiry = now + (2 * 60 * 1000) expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,)) macaroon.add_first_party_caveat("time < %d" % (expiry,))
return macaroon.serialize() return macaroon.serialize()

View File

@ -358,6 +358,59 @@ class RegistrationHandler(BaseHandler):
) )
defer.returnValue(data) defer.returnValue(data)
@defer.inlineCallbacks
def get_or_create_user(self, localpart, displayname, duration_seconds):
"""Creates a new user or returns an access token for an existing one
Args:
localpart : The local part of the user ID to register. If None,
one will be randomly generated.
Returns:
A tuple of (user_id, access_token).
Raises:
RegistrationError if there was a problem registering.
"""
yield run_on_reactor()
if localpart is None:
raise SynapseError(400, "Request must include user id")
need_register = True
try:
yield self.check_username(localpart)
except SynapseError as e:
if e.errcode == Codes.USER_IN_USE:
need_register = False
else:
raise
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
auth_handler = self.hs.get_handlers().auth_handler
token = auth_handler.generate_short_term_login_token(user_id, duration_seconds)
if need_register:
yield self.store.register(
user_id=user_id,
token=token,
password_hash=None
)
yield registered_user(self.distributor, user)
else:
yield self.store.flush_user(user_id=user_id)
yield self.store.add_access_token_to_user(user_id=user_id, token=token)
if displayname is not None:
logger.info("setting user display name: %s -> %s", user_id, displayname)
profile_handler = self.hs.get_handlers().profile_handler
yield profile_handler.set_displayname(
user, user, displayname
)
defer.returnValue((user_id, token))
def auth_handler(self): def auth_handler(self):
return self.hs.get_handlers().auth_handler return self.hs.get_handlers().auth_handler

View File

@ -164,8 +164,8 @@ class ReplicationResource(Resource):
"Replicating %d rows of %s from %s -> %s", "Replicating %d rows of %s from %s -> %s",
len(stream_content["rows"]), len(stream_content["rows"]),
stream_name, stream_name,
stream_content["position"],
request_streams.get(stream_name), request_streams.get(stream_name),
stream_content["position"],
) )
request.write(json.dumps(result, ensure_ascii=False)) request.write(json.dumps(result, ensure_ascii=False))

View File

@ -0,0 +1,61 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage.account_data import AccountDataStore
class SlavedAccountDataStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedAccountDataStore, self).__init__(db_conn, hs)
self._account_data_id_gen = SlavedIdTracker(
db_conn, "account_data_max_stream_id", "stream_id",
)
get_global_account_data_by_type_for_users = (
AccountDataStore.__dict__["get_global_account_data_by_type_for_users"]
)
get_global_account_data_by_type_for_user = (
AccountDataStore.__dict__["get_global_account_data_by_type_for_user"]
)
def stream_positions(self):
result = super(SlavedAccountDataStore, self).stream_positions()
position = self._account_data_id_gen.get_current_token()
result["user_account_data"] = position
result["room_account_data"] = position
result["tag_account_data"] = position
return result
def process_replication(self, result):
stream = result.get("user_account_data")
if stream:
self._account_data_id_gen.advance(int(stream["position"]))
for row in stream["rows"]:
user_id, data_type = row[1:3]
self.get_global_account_data_by_type_for_user.invalidate(
(data_type, user_id,)
)
stream = result.get("room_account_data")
if stream:
self._account_data_id_gen.advance(int(stream["position"]))
stream = result.get("tag_account_data")
if stream:
self._account_data_id_gen.advance(int(stream["position"]))

View File

@ -165,12 +165,14 @@ class SlavedEventStore(BaseSlavedStore):
stream = result.get("forward_ex_outliers") stream = result.get("forward_ex_outliers")
if stream: if stream:
self._stream_id_gen.advance(stream["position"])
for row in stream["rows"]: for row in stream["rows"]:
event_id = row[1] event_id = row[1]
self._invalidate_get_event_cache(event_id) self._invalidate_get_event_cache(event_id)
stream = result.get("backward_ex_outliers") stream = result.get("backward_ex_outliers")
if stream: if stream:
self._backfill_id_gen.advance(-stream["position"])
for row in stream["rows"]: for row in stream["rows"]:
event_id = row[1] event_id = row[1]
self._invalidate_get_event_cache(event_id) self._invalidate_get_event_cache(event_id)

View File

@ -355,5 +355,76 @@ class RegisterRestServlet(ClientV1RestServlet):
) )
class CreateUserRestServlet(ClientV1RestServlet):
"""Handles user creation via a server-to-server interface
"""
PATTERNS = client_path_patterns("/createUser$", releases=())
def __init__(self, hs):
super(CreateUserRestServlet, self).__init__(hs)
self.store = hs.get_datastore()
self.direct_user_creation_max_duration = hs.config.user_creation_max_duration
@defer.inlineCallbacks
def on_POST(self, request):
user_json = parse_json_object_from_request(request)
if "access_token" not in request.args:
raise SynapseError(400, "Expected application service token.")
app_service = yield self.store.get_app_service_by_token(
request.args["access_token"][0]
)
if not app_service:
raise SynapseError(403, "Invalid application service token.")
logger.debug("creating user: %s", user_json)
response = yield self._do_create(user_json)
defer.returnValue((200, response))
def on_OPTIONS(self, request):
return 403, {}
@defer.inlineCallbacks
def _do_create(self, user_json):
yield run_on_reactor()
if "localpart" not in user_json:
raise SynapseError(400, "Expected 'localpart' key.")
if "displayname" not in user_json:
raise SynapseError(400, "Expected 'displayname' key.")
if "duration_seconds" not in user_json:
raise SynapseError(400, "Expected 'duration_seconds' key.")
localpart = user_json["localpart"].encode("utf-8")
displayname = user_json["displayname"].encode("utf-8")
duration_seconds = 0
try:
duration_seconds = int(user_json["duration_seconds"])
except ValueError:
raise SynapseError(400, "Failed to parse 'duration_seconds'")
if duration_seconds > self.direct_user_creation_max_duration:
duration_seconds = self.direct_user_creation_max_duration
handler = self.handlers.registration_handler
user_id, token = yield handler.get_or_create_user(
localpart=localpart,
displayname=displayname,
duration_seconds=duration_seconds
)
defer.returnValue({
"user_id": user_id,
"access_token": token,
"home_server": self.hs.hostname,
})
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
RegisterRestServlet(hs).register(http_server) RegisterRestServlet(hs).register(http_server)
CreateUserRestServlet(hs).register(http_server)

View File

@ -453,7 +453,9 @@ class SQLBaseStore(object):
keyvalues (dict): The unique key tables and their new values keyvalues (dict): The unique key tables and their new values
values (dict): The nonunique columns and their new values values (dict): The nonunique columns and their new values
insertion_values (dict): key/values to use when inserting insertion_values (dict): key/values to use when inserting
Returns: A deferred Returns:
Deferred(bool): True if a new entry was created, False if an
existing one was updated.
""" """
return self.runInteraction( return self.runInteraction(
desc, desc,
@ -498,6 +500,10 @@ class SQLBaseStore(object):
) )
txn.execute(sql, allvalues.values()) txn.execute(sql, allvalues.values())
return True
else:
return False
def _simple_select_one(self, table, keyvalues, retcols, def _simple_select_one(self, table, keyvalues, retcols,
allow_none=False, desc="_simple_select_one"): allow_none=False, desc="_simple_select_one"):
"""Executes a SELECT query on the named table, which is expected to """Executes a SELECT query on the named table, which is expected to

View File

@ -224,6 +224,18 @@ class EventPushActionsStore(SQLBaseStore):
(room_id, event_id) (room_id, event_id)
) )
def _remove_push_actions_before_txn(self, txn, room_id, user_id,
topological_ordering):
txn.call_after(
self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
(room_id, user_id, )
)
txn.execute(
"DELETE FROM event_push_actions"
" WHERE room_id = ? AND user_id = ? AND topological_ordering < ?",
(room_id, user_id, topological_ordering,)
)
def _action_has_highlight(actions): def _action_has_highlight(actions):
for action in actions: for action in actions:

View File

@ -156,8 +156,7 @@ class PusherStore(SQLBaseStore):
profile_tag=""): profile_tag=""):
with self._pushers_id_gen.get_next() as stream_id: with self._pushers_id_gen.get_next() as stream_id:
def f(txn): def f(txn):
txn.call_after(self.get_users_with_pushers_in_room.invalidate_all) newly_inserted = self._simple_upsert_txn(
return self._simple_upsert_txn(
txn, txn,
"pushers", "pushers",
{ {
@ -178,11 +177,18 @@ class PusherStore(SQLBaseStore):
"id": stream_id, "id": stream_id,
}, },
) )
defer.returnValue((yield self.runInteraction("add_pusher", f))) if newly_inserted:
# get_users_with_pushers_in_room only cares if the user has
# at least *one* pusher.
txn.call_after(self.get_users_with_pushers_in_room.invalidate_all)
yield self.runInteraction("add_pusher", f)
@defer.inlineCallbacks @defer.inlineCallbacks
def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id): def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id):
def delete_pusher_txn(txn, stream_id): def delete_pusher_txn(txn, stream_id):
txn.call_after(self.get_users_with_pushers_in_room.invalidate_all)
self._simple_delete_one_txn( self._simple_delete_one_txn(
txn, txn,
"pushers", "pushers",
@ -194,6 +200,7 @@ class PusherStore(SQLBaseStore):
{"app_id": app_id, "pushkey": pushkey, "user_id": user_id}, {"app_id": app_id, "pushkey": pushkey, "user_id": user_id},
{"stream_id": stream_id}, {"stream_id": stream_id},
) )
with self._pushers_id_gen.get_next() as stream_id: with self._pushers_id_gen.get_next() as stream_id:
yield self.runInteraction( yield self.runInteraction(
"delete_pusher", delete_pusher_txn, stream_id "delete_pusher", delete_pusher_txn, stream_id

View File

@ -100,7 +100,7 @@ class ReceiptsStore(SQLBaseStore):
defer.returnValue([ev for res in results.values() for ev in res]) defer.returnValue([ev for res in results.values() for ev in res])
@cachedInlineCallbacks(num_args=3, max_entries=5000) @cachedInlineCallbacks(num_args=3, max_entries=5000, lru=True, tree=True)
def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None): def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
"""Get receipts for a single room for sending to clients. """Get receipts for a single room for sending to clients.
@ -232,7 +232,7 @@ class ReceiptsStore(SQLBaseStore):
self.get_receipts_for_user.invalidate, (user_id, receipt_type) self.get_receipts_for_user.invalidate, (user_id, receipt_type)
) )
# FIXME: This shouldn't invalidate the whole cache # FIXME: This shouldn't invalidate the whole cache
txn.call_after(self.get_linearized_receipts_for_room.invalidate_all) txn.call_after(self.get_linearized_receipts_for_room.invalidate_many, (room_id,))
txn.call_after( txn.call_after(
self._receipts_stream_cache.entity_has_changed, self._receipts_stream_cache.entity_has_changed,
@ -244,6 +244,17 @@ class ReceiptsStore(SQLBaseStore):
(user_id, room_id, receipt_type) (user_id, room_id, receipt_type)
) )
res = self._simple_select_one_txn(
txn,
table="events",
retcols=["topological_ordering", "stream_ordering"],
keyvalues={"event_id": event_id},
allow_none=True
)
topological_ordering = int(res["topological_ordering"]) if res else None
stream_ordering = int(res["stream_ordering"]) if res else None
# We don't want to clobber receipts for more recent events, so we # We don't want to clobber receipts for more recent events, so we
# have to compare orderings of existing receipts # have to compare orderings of existing receipts
sql = ( sql = (
@ -255,16 +266,7 @@ class ReceiptsStore(SQLBaseStore):
txn.execute(sql, (room_id, receipt_type, user_id)) txn.execute(sql, (room_id, receipt_type, user_id))
results = txn.fetchall() results = txn.fetchall()
if results: if results and topological_ordering:
res = self._simple_select_one_txn(
txn,
table="events",
retcols=["topological_ordering", "stream_ordering"],
keyvalues={"event_id": event_id},
)
topological_ordering = int(res["topological_ordering"])
stream_ordering = int(res["stream_ordering"])
for to, so, _ in results: for to, so, _ in results:
if int(to) > topological_ordering: if int(to) > topological_ordering:
return False return False
@ -294,6 +296,14 @@ class ReceiptsStore(SQLBaseStore):
} }
) )
if receipt_type == "m.read" and topological_ordering:
self._remove_push_actions_before_txn(
txn,
room_id=room_id,
user_id=user_id,
topological_ordering=topological_ordering,
)
return True return True
@defer.inlineCallbacks @defer.inlineCallbacks
@ -367,7 +377,7 @@ class ReceiptsStore(SQLBaseStore):
self.get_receipts_for_user.invalidate, (user_id, receipt_type) self.get_receipts_for_user.invalidate, (user_id, receipt_type)
) )
# FIXME: This shouldn't invalidate the whole cache # FIXME: This shouldn't invalidate the whole cache
txn.call_after(self.get_linearized_receipts_for_room.invalidate_all) txn.call_after(self.get_linearized_receipts_for_room.invalidate_many, (room_id,))
self._simple_delete_txn( self._simple_delete_txn(
txn, txn,

View File

@ -0,0 +1,38 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- The following indices are redundant, other indices are equivalent or
-- supersets
DROP INDEX IF EXISTS events_room_id; -- Prefix of events_room_stream
DROP INDEX IF EXISTS events_order; -- Prefix of events_order_topo_stream_room
DROP INDEX IF EXISTS events_topological_ordering; -- Prefix of events_order_topo_stream_room
DROP INDEX IF EXISTS events_stream_ordering; -- Duplicate of PRIMARY KEY
DROP INDEX IF EXISTS state_groups_id; -- Duplicate of PRIMARY KEY
DROP INDEX IF EXISTS event_to_state_groups_id; -- Duplicate of PRIMARY KEY
DROP INDEX IF EXISTS event_push_actions_room_id_event_id_user_id_profile_tag; -- Duplicate of UNIQUE CONSTRAINT
DROP INDEX IF EXISTS event_destinations_id; -- Prefix of UNIQUE CONSTRAINT
DROP INDEX IF EXISTS st_extrem_id; -- Prefix of UNIQUE CONSTRAINT
DROP INDEX IF EXISTS event_content_hashes_id; -- Prefix of UNIQUE CONSTRAINT
DROP INDEX IF EXISTS event_signatures_id; -- Prefix of UNIQUE CONSTRAINT
DROP INDEX IF EXISTS event_edge_hashes_id; -- Prefix of UNIQUE CONSTRAINT
DROP INDEX IF EXISTS redactions_event_id; -- Duplicate of UNIQUE CONSTRAINT
DROP INDEX IF EXISTS room_hosts_room_id; -- Prefix of UNIQUE CONSTRAINT
-- The following indices were unused
DROP INDEX IF EXISTS remote_media_cache_thumbnails_media_id;
DROP INDEX IF EXISTS evauth_edges_auth_id;
DROP INDEX IF EXISTS presence_stream_state;

View File

@ -284,12 +284,12 @@ class AuthTestCase(unittest.TestCase):
macaroon.add_first_party_caveat("time < 1") # ms macaroon.add_first_party_caveat("time < 1") # ms
self.hs.clock.now = 5000 # seconds self.hs.clock.now = 5000 # seconds
self.hs.config.expire_access_token = True
yield self.auth.get_user_from_macaroon(macaroon.serialize()) # yield self.auth.get_user_from_macaroon(macaroon.serialize())
# TODO(daniel): Turn on the check that we validate expiration, when we # TODO(daniel): Turn on the check that we validate expiration, when we
# validate expiration (and remove the above line, which will start # validate expiration (and remove the above line, which will start
# throwing). # throwing).
# with self.assertRaises(AuthError) as cm: with self.assertRaises(AuthError) as cm:
# yield self.auth.get_user_from_macaroon(macaroon.serialize()) yield self.auth.get_user_from_macaroon(macaroon.serialize())
# self.assertEqual(401, cm.exception.code) self.assertEqual(401, cm.exception.code)
# self.assertIn("Invalid macaroon", cm.exception.msg) self.assertIn("Invalid macaroon", cm.exception.msg)

View File

@ -0,0 +1,67 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from .. import unittest
from synapse.handlers.register import RegistrationHandler
from tests.utils import setup_test_homeserver
from mock import Mock
class RegistrationHandlers(object):
def __init__(self, hs):
self.registration_handler = RegistrationHandler(hs)
class RegistrationTestCase(unittest.TestCase):
""" Tests the RegistrationHandler. """
@defer.inlineCallbacks
def setUp(self):
self.mock_distributor = Mock()
self.mock_distributor.declare("registered_user")
self.mock_captcha_client = Mock()
hs = yield setup_test_homeserver(
handlers=None,
http_client=None,
expire_access_token=True)
hs.handlers = RegistrationHandlers(hs)
self.handler = hs.get_handlers().registration_handler
hs.get_handlers().profile_handler = Mock()
self.mock_handler = Mock(spec=[
"generate_short_term_login_token",
])
hs.get_handlers().auth_handler = self.mock_handler
@defer.inlineCallbacks
def test_user_is_created_and_logged_in_if_doesnt_exist(self):
"""
Returns:
The user doess not exist in this case so it will register and log it in
"""
duration_ms = 200
local_part = "someone"
display_name = "someone"
user_id = "@someone:test"
mock_token = self.mock_handler.generate_short_term_login_token
mock_token.return_value = 'secret'
result_user_id, result_token = yield self.handler.get_or_create_user(
local_part, display_name, duration_ms)
self.assertEquals(result_user_id, user_id)
self.assertEquals(result_token, 'secret')

View File

@ -0,0 +1,56 @@
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStoreTestCase
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
from twisted.internet import defer
USER_ID = "@feeling:blue"
TYPE = "my.type"
class SlavedAccountDataStoreTestCase(BaseSlavedStoreTestCase):
STORE_TYPE = SlavedAccountDataStore
@defer.inlineCallbacks
def test_user_account_data(self):
yield self.master_store.add_account_data_for_user(
USER_ID, TYPE, {"a": 1}
)
yield self.replicate()
yield self.check(
"get_global_account_data_by_type_for_user",
[TYPE, USER_ID], {"a": 1}
)
yield self.check(
"get_global_account_data_by_type_for_users",
[TYPE, [USER_ID]], {USER_ID: {"a": 1}}
)
yield self.master_store.add_account_data_for_user(
USER_ID, TYPE, {"a": 2}
)
yield self.replicate()
yield self.check(
"get_global_account_data_by_type_for_user",
[TYPE, USER_ID], {"a": 2}
)
yield self.check(
"get_global_account_data_by_type_for_users",
[TYPE, [USER_ID]], {USER_ID: {"a": 2}}
)

View File

@ -0,0 +1,88 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.rest.client.v1.register import CreateUserRestServlet
from twisted.internet import defer
from mock import Mock
from tests import unittest
import json
class CreateUserServletTestCase(unittest.TestCase):
def setUp(self):
# do the dance to hook up request data to self.request_data
self.request_data = ""
self.request = Mock(
content=Mock(read=Mock(side_effect=lambda: self.request_data)),
path='/_matrix/client/api/v1/createUser'
)
self.request.args = {}
self.appservice = None
self.auth = Mock(get_appservice_by_req=Mock(
side_effect=lambda x: defer.succeed(self.appservice))
)
self.auth_result = (False, None, None, None)
self.auth_handler = Mock(
check_auth=Mock(side_effect=lambda x, y, z: self.auth_result),
get_session_data=Mock(return_value=None)
)
self.registration_handler = Mock()
self.identity_handler = Mock()
self.login_handler = Mock()
# do the dance to hook it up to the hs global
self.handlers = Mock(
auth_handler=self.auth_handler,
registration_handler=self.registration_handler,
identity_handler=self.identity_handler,
login_handler=self.login_handler
)
self.hs = Mock()
self.hs.hostname = "supergbig~testing~thing.com"
self.hs.get_auth = Mock(return_value=self.auth)
self.hs.get_handlers = Mock(return_value=self.handlers)
self.hs.config.enable_registration = True
# init the thing we're testing
self.servlet = CreateUserRestServlet(self.hs)
@defer.inlineCallbacks
def test_POST_createuser_with_valid_user(self):
user_id = "@someone:interesting"
token = "my token"
self.request.args = {
"access_token": "i_am_an_app_service"
}
self.request_data = json.dumps({
"localpart": "someone",
"displayname": "someone interesting",
"duration_seconds": 200
})
self.registration_handler.get_or_create_user = Mock(
return_value=(user_id, token)
)
(code, result) = yield self.servlet.on_POST(self.request)
self.assertEquals(code, 200)
det_data = {
"user_id": user_id,
"access_token": token,
"home_server": self.hs.hostname
}
self.assertDictContainsSubset(det_data, result)

View File

@ -49,6 +49,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
config.event_cache_size = 1 config.event_cache_size = 1
config.enable_registration = True config.enable_registration = True
config.macaroon_secret_key = "not even a little secret" config.macaroon_secret_key = "not even a little secret"
config.expire_access_token = False
config.server_name = "server.under.test" config.server_name = "server.under.test"
config.trusted_third_party_id_servers = [] config.trusted_third_party_id_servers = []
config.room_invite_state_types = [] config.room_invite_state_types = []