Merge branch 'develop' into email_login

This commit is contained in:
David Baker 2015-08-20 10:16:01 +01:00
commit c50ad14bae
60 changed files with 1759 additions and 714 deletions

View file

@ -85,9 +85,8 @@ class LoginRestServlet(ClientV1RestServlet):
user_id = UserID.create(
user_id, self.hs.hostname).to_string()
handler = self.handlers.login_handler
token = yield handler.login(
user=user_id,
token = yield self.handlers.auth_handler.login_with_password(
user_id=user_id
password=login_submission["password"])
result = {

View file

@ -36,7 +36,6 @@ class PasswordRestServlet(RestServlet):
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_handlers().auth_handler
self.login_handler = hs.get_handlers().login_handler
@defer.inlineCallbacks
def on_POST(self, request):
@ -47,7 +46,7 @@ class PasswordRestServlet(RestServlet):
authed, result, params = yield self.auth_handler.check_auth([
[LoginType.PASSWORD],
[LoginType.EMAIL_IDENTITY]
], body)
], body, self.hs.get_ip_from_request(request))
if not authed:
defer.returnValue((401, result))
@ -79,7 +78,7 @@ class PasswordRestServlet(RestServlet):
raise SynapseError(400, "", Codes.MISSING_PARAM)
new_password = params['new_password']
yield self.login_handler.set_password(
yield self.auth_handler.set_password(
user_id, new_password, None
)
@ -95,7 +94,6 @@ class ThreepidRestServlet(RestServlet):
def __init__(self, hs):
super(ThreepidRestServlet, self).__init__()
self.hs = hs
self.login_handler = hs.get_handlers().login_handler
self.identity_handler = hs.get_handlers().identity_handler
self.auth = hs.get_auth()
@ -135,7 +133,7 @@ class ThreepidRestServlet(RestServlet):
logger.warn("Couldn't add 3pid: invalid response from ID sevrer")
raise SynapseError(500, "Invalid response from ID Server")
yield self.login_handler.add_threepid(
yield self.auth_handler.add_threepid(
auth_user.to_string(),
threepid['medium'],
threepid['address'],

View file

@ -17,6 +17,7 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet
from synapse.types import UserID
from syutil.jsonutil import encode_canonical_json
from ._base import client_v2_pattern
@ -164,45 +165,63 @@ class KeyQueryServlet(RestServlet):
super(KeyQueryServlet, self).__init__()
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.federation = hs.get_replication_layer()
self.is_mine = hs.is_mine
@defer.inlineCallbacks
def on_POST(self, request, user_id, device_id):
logger.debug("onPOST")
yield self.auth.get_user_by_req(request)
try:
body = json.loads(request.content.read())
except:
raise SynapseError(400, "Invalid key JSON")
query = []
for user_id, device_ids in body.get("device_keys", {}).items():
if not device_ids:
query.append((user_id, None))
else:
for device_id in device_ids:
query.append((user_id, device_id))
results = yield self.store.get_e2e_device_keys(query)
defer.returnValue(self.json_result(request, results))
result = yield self.handle_request(body)
defer.returnValue(result)
@defer.inlineCallbacks
def on_GET(self, request, user_id, device_id):
auth_user, client_info = yield self.auth.get_user_by_req(request)
auth_user_id = auth_user.to_string()
if not user_id:
user_id = auth_user_id
if not device_id:
device_id = None
# Returns a map of user_id->device_id->json_bytes.
results = yield self.store.get_e2e_device_keys([(user_id, device_id)])
defer.returnValue(self.json_result(request, results))
user_id = user_id if user_id else auth_user_id
device_ids = [device_id] if device_id else []
result = yield self.handle_request(
{"device_keys": {user_id: device_ids}}
)
defer.returnValue(result)
@defer.inlineCallbacks
def handle_request(self, body):
local_query = []
remote_queries = {}
for user_id, device_ids in body.get("device_keys", {}).items():
user = UserID.from_string(user_id)
if self.is_mine(user):
if not device_ids:
local_query.append((user_id, None))
else:
for device_id in device_ids:
local_query.append((user_id, device_id))
else:
remote_queries.setdefault(user.domain, {})[user_id] = list(
device_ids
)
results = yield self.store.get_e2e_device_keys(local_query)
def json_result(self, request, results):
json_result = {}
for user_id, device_keys in results.items():
for device_id, json_bytes in device_keys.items():
json_result.setdefault(user_id, {})[device_id] = json.loads(
json_bytes
)
return (200, {"device_keys": json_result})
for destination, device_keys in remote_queries.items():
remote_result = yield self.federation.query_client_keys(
destination, {"device_keys": device_keys}
)
for user_id, keys in remote_result["device_keys"].items():
if user_id in device_keys:
json_result[user_id] = keys
defer.returnValue((200, {"device_keys": json_result}))
class OneTimeKeyServlet(RestServlet):
@ -236,14 +255,16 @@ class OneTimeKeyServlet(RestServlet):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.federation = hs.get_replication_layer()
self.is_mine = hs.is_mine
@defer.inlineCallbacks
def on_GET(self, request, user_id, device_id, algorithm):
yield self.auth.get_user_by_req(request)
results = yield self.store.claim_e2e_one_time_keys(
[(user_id, device_id, algorithm)]
result = yield self.handle_request(
{"one_time_keys": {user_id: {device_id: algorithm}}}
)
defer.returnValue(self.json_result(request, results))
defer.returnValue(result)
@defer.inlineCallbacks
def on_POST(self, request, user_id, device_id, algorithm):
@ -252,14 +273,24 @@ class OneTimeKeyServlet(RestServlet):
body = json.loads(request.content.read())
except:
raise SynapseError(400, "Invalid key JSON")
query = []
for user_id, device_keys in body.get("one_time_keys", {}).items():
for device_id, algorithm in device_keys.items():
query.append((user_id, device_id, algorithm))
results = yield self.store.claim_e2e_one_time_keys(query)
defer.returnValue(self.json_result(request, results))
result = yield self.handle_request(body)
defer.returnValue(result)
@defer.inlineCallbacks
def handle_request(self, body):
local_query = []
remote_queries = {}
for user_id, device_keys in body.get("one_time_keys", {}).items():
user = UserID.from_string(user_id)
if self.is_mine(user):
for device_id, algorithm in device_keys.items():
local_query.append((user_id, device_id, algorithm))
else:
remote_queries.setdefault(user.domain, {})[user_id] = (
device_keys
)
results = yield self.store.claim_e2e_one_time_keys(local_query)
def json_result(self, request, results):
json_result = {}
for user_id, device_keys in results.items():
for device_id, keys in device_keys.items():
@ -267,7 +298,16 @@ class OneTimeKeyServlet(RestServlet):
json_result.setdefault(user_id, {})[device_id] = {
key_id: json.loads(json_bytes)
}
return (200, {"one_time_keys": json_result})
for destination, device_keys in remote_queries.items():
remote_result = yield self.federation.claim_client_keys(
destination, {"one_time_keys": device_keys}
)
for user_id, keys in remote_result["one_time_keys"].items():
if user_id in device_keys:
json_result[user_id] = keys
defer.returnValue((200, {"one_time_keys": json_result}))
def register_servlets(hs, http_server):

View file

@ -41,7 +41,7 @@ logger = logging.getLogger(__name__)
class RegisterRestServlet(RestServlet):
PATTERN = client_v2_pattern("/register*")
PATTERN = client_v2_pattern("/register")
def __init__(self, hs):
super(RegisterRestServlet, self).__init__()
@ -50,7 +50,6 @@ class RegisterRestServlet(RestServlet):
self.auth_handler = hs.get_handlers().auth_handler
self.registration_handler = hs.get_handlers().registration_handler
self.identity_handler = hs.get_handlers().identity_handler
self.login_handler = hs.get_handlers().login_handler
@defer.inlineCallbacks
def on_POST(self, request):
@ -148,7 +147,7 @@ class RegisterRestServlet(RestServlet):
if reqd not in threepid:
logger.info("Can't add incomplete 3pid")
else:
yield self.login_handler.add_threepid(
yield self.auth_handler.add_threepid(
user_id,
threepid['medium'],
threepid['address'],
@ -224,6 +223,9 @@ class RegisterRestServlet(RestServlet):
if k not in body:
absent.append(k)
if len(absent) > 0:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
'email', body['email']
)
@ -231,9 +233,6 @@ class RegisterRestServlet(RestServlet):
if existingUid is not None:
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
if len(absent) > 0:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
ret = yield self.identity_handler.requestEmailToken(**body)
defer.returnValue((200, ret))