Merge pull request #2797 from matrix-org/rav/user_id_checking

Sanity checking for user ids
This commit is contained in:
Richard van der Hoff 2018-01-17 14:29:12 +00:00 committed by GitHub
commit f884cfffb9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 19 additions and 8 deletions

View File

@ -17,7 +17,8 @@ import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.types import get_domain_from_id from synapse.api.errors import SynapseError
from synapse.types import get_domain_from_id, UserID
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
@ -33,7 +34,7 @@ class DeviceMessageHandler(object):
""" """
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.is_mine_id = hs.is_mine_id self.is_mine = hs.is_mine
self.federation = hs.get_federation_sender() self.federation = hs.get_federation_sender()
hs.get_replication_layer().register_edu_handler( hs.get_replication_layer().register_edu_handler(
@ -52,6 +53,12 @@ class DeviceMessageHandler(object):
message_type = content["type"] message_type = content["type"]
message_id = content["message_id"] message_id = content["message_id"]
for user_id, by_device in content["messages"].items(): for user_id, by_device in content["messages"].items():
# we use UserID.from_string to catch invalid user ids
if not self.is_mine(UserID.from_string(user_id)):
logger.warning("Request for keys for non-local user %s",
user_id)
raise SynapseError(400, "Not a user here")
messages_by_device = { messages_by_device = {
device_id: { device_id: {
"content": message_content, "content": message_content,
@ -77,7 +84,8 @@ class DeviceMessageHandler(object):
local_messages = {} local_messages = {}
remote_messages = {} remote_messages = {}
for user_id, by_device in messages.items(): for user_id, by_device in messages.items():
if self.is_mine_id(user_id): # we use UserID.from_string to catch invalid user ids
if self.is_mine(UserID.from_string(user_id)):
messages_by_device = { messages_by_device = {
device_id: { device_id: {
"content": message_content, "content": message_content,

View File

@ -20,7 +20,7 @@ from canonicaljson import encode_canonical_json
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import SynapseError, CodeMessageException from synapse.api.errors import SynapseError, CodeMessageException
from synapse.types import get_domain_from_id from synapse.types import get_domain_from_id, UserID
from synapse.util.logcontext import preserve_fn, make_deferred_yieldable from synapse.util.logcontext import preserve_fn, make_deferred_yieldable
from synapse.util.retryutils import NotRetryingDestination from synapse.util.retryutils import NotRetryingDestination
@ -32,7 +32,7 @@ class E2eKeysHandler(object):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.federation = hs.get_replication_layer() self.federation = hs.get_replication_layer()
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
self.is_mine_id = hs.is_mine_id self.is_mine = hs.is_mine
self.clock = hs.get_clock() self.clock = hs.get_clock()
# doesn't really work as part of the generic query API, because the # doesn't really work as part of the generic query API, because the
@ -70,7 +70,8 @@ class E2eKeysHandler(object):
remote_queries = {} remote_queries = {}
for user_id, device_ids in device_keys_query.items(): for user_id, device_ids in device_keys_query.items():
if self.is_mine_id(user_id): # we use UserID.from_string to catch invalid user ids
if self.is_mine(UserID.from_string(user_id)):
local_query[user_id] = device_ids local_query[user_id] = device_ids
else: else:
remote_queries[user_id] = device_ids remote_queries[user_id] = device_ids
@ -170,7 +171,8 @@ class E2eKeysHandler(object):
result_dict = {} result_dict = {}
for user_id, device_ids in query.items(): for user_id, device_ids in query.items():
if not self.is_mine_id(user_id): # we use UserID.from_string to catch invalid user ids
if not self.is_mine(UserID.from_string(user_id)):
logger.warning("Request for keys for non-local user %s", logger.warning("Request for keys for non-local user %s",
user_id) user_id)
raise SynapseError(400, "Not a user here") raise SynapseError(400, "Not a user here")
@ -213,7 +215,8 @@ class E2eKeysHandler(object):
remote_queries = {} remote_queries = {}
for user_id, device_keys in query.get("one_time_keys", {}).items(): for user_id, device_keys in query.get("one_time_keys", {}).items():
if self.is_mine_id(user_id): # we use UserID.from_string to catch invalid user ids
if self.is_mine(UserID.from_string(user_id)):
for device_id, algorithm in device_keys.items(): for device_id, algorithm in device_keys.items():
local_query.append((user_id, device_id, algorithm)) local_query.append((user_id, device_id, algorithm))
else: else: