Move the E2E key handling into the e2e handler

This commit is contained in:
Mark Haines 2016-09-13 11:35:35 +01:00
parent 76b09c29b0
commit 18ab019a4a
2 changed files with 118 additions and 115 deletions

View File

@ -13,9 +13,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import json import ujson as json
import logging import logging
from canonicaljson import encode_canonical_json
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import SynapseError, CodeMessageException from synapse.api.errors import SynapseError, CodeMessageException
@ -29,7 +30,9 @@ class E2eKeysHandler(object):
def __init__(self, hs): def __init__(self, hs):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.federation = hs.get_replication_layer() self.federation = hs.get_replication_layer()
self.device_handler = hs.get_device_handler()
self.is_mine_id = hs.is_mine_id self.is_mine_id = hs.is_mine_id
self.clock = hs.get_clock()
# doesn't really work as part of the generic query API, because the # doesn't really work as part of the generic query API, because the
# query request requires an object POST, but we abuse the # query request requires an object POST, but we abuse the
@ -103,9 +106,9 @@ class E2eKeysHandler(object):
for destination in remote_queries for destination in remote_queries
])) ]))
defer.returnValue((200, { defer.returnValue({
"device_keys": results, "failures": failures, "device_keys": results, "failures": failures,
})) })
@defer.inlineCallbacks @defer.inlineCallbacks
def query_local_devices(self, query): def query_local_devices(self, query):
@ -159,3 +162,99 @@ class E2eKeysHandler(object):
device_keys_query = query_body.get("device_keys", {}) device_keys_query = query_body.get("device_keys", {})
res = yield self.query_local_devices(device_keys_query) res = yield self.query_local_devices(device_keys_query)
defer.returnValue({"device_keys": res}) defer.returnValue({"device_keys": res})
@defer.inlineCallbacks
def claim_one_time_keys(self, query, timeout):
local_query = []
remote_queries = {}
for user_id, device_keys in query.get("one_time_keys", {}).items():
if self.is_mine_id(user_id):
for device_id, algorithm in device_keys.items():
local_query.append((user_id, device_id, algorithm))
else:
domain = get_domain_from_id(user_id)
remote_queries.setdefault(domain, {})[user_id] = device_keys
results = yield self.store.claim_e2e_one_time_keys(local_query)
json_result = {}
failures = {}
for user_id, device_keys in results.items():
for device_id, keys in device_keys.items():
for key_id, json_bytes in keys.items():
json_result.setdefault(user_id, {})[device_id] = {
key_id: json.loads(json_bytes)
}
@defer.inlineCallbacks
def claim_client_keys(destination):
device_keys = remote_queries[destination]
try:
remote_result = yield self.federation.claim_client_keys(
destination,
{"one_time_keys": device_keys},
timeout=timeout
)
for user_id, keys in remote_result["one_time_keys"].items():
if user_id in device_keys:
json_result[user_id] = keys
except CodeMessageException as e:
failures[destination] = {
"status": e.code, "message": e.message
}
yield preserve_context_over_deferred(defer.gatherResults([
preserve_fn(claim_client_keys)(destination)
for destination in remote_queries
]))
defer.returnValue({
"one_time_keys": json_result,
"failures": failures
})
@defer.inlineCallbacks
def upload_keys_for_user(self, user_id, device_id, keys):
time_now = self.clock.time_msec()
# TODO: Validate the JSON to make sure it has the right keys.
device_keys = keys.get("device_keys", None)
if device_keys:
logger.info(
"Updating device_keys for device %r for user %s at %d",
device_id, user_id, time_now
)
# TODO: Sign the JSON with the server key
yield self.store.set_e2e_device_keys(
user_id, device_id, time_now,
encode_canonical_json(device_keys)
)
one_time_keys = keys.get("one_time_keys", None)
if one_time_keys:
logger.info(
"Adding %d one_time_keys for device %r for user %r at %d",
len(one_time_keys), device_id, user_id, time_now
)
key_list = []
for key_id, key_json in one_time_keys.items():
algorithm, key_id = key_id.split(":")
key_list.append((
algorithm, key_id, encode_canonical_json(key_json)
))
yield self.store.add_e2e_one_time_keys(
user_id, device_id, time_now, key_list
)
# the device should have been registered already, but it may have been
# deleted due to a race with a DELETE request. Or we may be using an
# old access_token without an associated device_id. Either way, we
# need to double-check the device is registered to avoid ending up with
# keys without a corresponding device.
self.device_handler.check_device_registered(user_id, device_id)
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
defer.returnValue({"one_time_key_counts": result})

View File

@ -15,16 +15,12 @@
import logging import logging
import simplejson as json
from canonicaljson import encode_canonical_json
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import SynapseError, CodeMessageException from synapse.api.errors import SynapseError
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, parse_json_object_from_request, parse_integer RestServlet, parse_json_object_from_request, parse_integer
) )
from synapse.types import get_domain_from_id
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from ._base import client_v2_patterns from ._base import client_v2_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -64,17 +60,13 @@ class KeyUploadServlet(RestServlet):
hs (synapse.server.HomeServer): server hs (synapse.server.HomeServer): server
""" """
super(KeyUploadServlet, self).__init__() super(KeyUploadServlet, self).__init__()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler() self.e2e_keys_handler = hs.get_e2e_keys_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, device_id): def on_POST(self, request, device_id):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string() user_id = requester.user.to_string()
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
if device_id is not None: if device_id is not None:
@ -94,47 +86,10 @@ class KeyUploadServlet(RestServlet):
"To upload keys, you must pass device_id when authenticating" "To upload keys, you must pass device_id when authenticating"
) )
time_now = self.clock.time_msec() result = yield self.e2e_keys_handler.upload_keys_for_user(
user_id, device_id, body
# TODO: Validate the JSON to make sure it has the right keys. )
device_keys = body.get("device_keys", None) defer.returnValue((200, result))
if device_keys:
logger.info(
"Updating device_keys for device %r for user %s at %d",
device_id, user_id, time_now
)
# TODO: Sign the JSON with the server key
yield self.store.set_e2e_device_keys(
user_id, device_id, time_now,
encode_canonical_json(device_keys)
)
one_time_keys = body.get("one_time_keys", None)
if one_time_keys:
logger.info(
"Adding %d one_time_keys for device %r for user %r at %d",
len(one_time_keys), device_id, user_id, time_now
)
key_list = []
for key_id, key_json in one_time_keys.items():
algorithm, key_id = key_id.split(":")
key_list.append((
algorithm, key_id, encode_canonical_json(key_json)
))
yield self.store.add_e2e_one_time_keys(
user_id, device_id, time_now, key_list
)
# the device should have been registered already, but it may have been
# deleted due to a race with a DELETE request. Or we may be using an
# old access_token without an associated device_id. Either way, we
# need to double-check the device is registered to avoid ending up with
# keys without a corresponding device.
self.device_handler.check_device_registered(user_id, device_id)
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
defer.returnValue((200, {"one_time_key_counts": result}))
class KeyQueryServlet(RestServlet): class KeyQueryServlet(RestServlet):
@ -199,7 +154,7 @@ class KeyQueryServlet(RestServlet):
timeout = parse_integer(request, "timeout", 10 * 1000) timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
result = yield self.e2e_keys_handler.query_devices(body, timeout) result = yield self.e2e_keys_handler.query_devices(body, timeout)
defer.returnValue(result) defer.returnValue((200, result))
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id, device_id): def on_GET(self, request, user_id, device_id):
@ -212,7 +167,7 @@ class KeyQueryServlet(RestServlet):
{"device_keys": {user_id: device_ids}}, {"device_keys": {user_id: device_ids}},
timeout, timeout,
) )
defer.returnValue(result) defer.returnValue((200, result))
class OneTimeKeyServlet(RestServlet): class OneTimeKeyServlet(RestServlet):
@ -244,80 +199,29 @@ class OneTimeKeyServlet(RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(OneTimeKeyServlet, self).__init__() super(OneTimeKeyServlet, self).__init__()
self.store = hs.get_datastore()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.clock = hs.get_clock() self.e2e_keys_handler = hs.get_e2e_keys_handler()
self.federation = hs.get_replication_layer()
self.is_mine_id = hs.is_mine_id
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id, device_id, algorithm): def on_GET(self, request, user_id, device_id, algorithm):
yield self.auth.get_user_by_req(request) yield self.auth.get_user_by_req(request)
timeout = parse_integer(request, "timeout", 10 * 1000) timeout = parse_integer(request, "timeout", 10 * 1000)
result = yield self.handle_request( result = yield self.e2e_keys_handler.claim_one_time_keys(
{"one_time_keys": {user_id: {device_id: algorithm}}}, {"one_time_keys": {user_id: {device_id: algorithm}}},
timeout, timeout,
) )
defer.returnValue(result) defer.returnValue((200, result))
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, user_id, device_id, algorithm): def on_POST(self, request, user_id, device_id, algorithm):
yield self.auth.get_user_by_req(request) yield self.auth.get_user_by_req(request)
timeout = parse_integer(request, "timeout", 10 * 1000) timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
result = yield self.handle_request(body, timeout) result = yield self.e2e_keys_handler.claim_one_time_keys(
defer.returnValue(result) body,
timeout,
@defer.inlineCallbacks )
def handle_request(self, body, timeout): defer.returnValue((200, result))
local_query = []
remote_queries = {}
for user_id, device_keys in body.get("one_time_keys", {}).items():
if self.is_mine_id(user_id):
for device_id, algorithm in device_keys.items():
local_query.append((user_id, device_id, algorithm))
else:
domain = get_domain_from_id(user_id)
remote_queries.setdefault(domain, {})[user_id] = device_keys
results = yield self.store.claim_e2e_one_time_keys(local_query)
json_result = {}
failures = {}
for user_id, device_keys in results.items():
for device_id, keys in device_keys.items():
for key_id, json_bytes in keys.items():
json_result.setdefault(user_id, {})[device_id] = {
key_id: json.loads(json_bytes)
}
@defer.inlineCallbacks
def claim_client_keys(destination):
device_keys = remote_queries[destination]
try:
remote_result = yield self.federation.claim_client_keys(
destination,
{"one_time_keys": device_keys},
timeout=timeout
)
for user_id, keys in remote_result["one_time_keys"].items():
if user_id in device_keys:
json_result[user_id] = keys
except CodeMessageException as e:
failures[destination] = {
"status": e.code, "message": e.message
}
yield preserve_context_over_deferred(defer.gatherResults([
preserve_fn(claim_client_keys)(destination)
for destination in remote_queries
]))
defer.returnValue((200, {
"one_time_keys": json_result,
"failures": failures
}))
def register_servlets(hs, http_server): def register_servlets(hs, http_server):