David Baker 2017-03-07 16:37:23 +00:00
parent b0effa2160
commit 00466e2feb

View File

@ -25,6 +25,7 @@ from .base import ClientV1RestServlet, client_path_patterns
import simplejson as json import simplejson as json
import urllib import urllib
import urlparse import urlparse
import phonenumbers
import logging import logging
from saml2 import BINDING_HTTP_POST from saml2 import BINDING_HTTP_POST
@ -37,6 +38,58 @@ import xml.etree.ElementTree as ET
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def login_submission_legacy_convert(submission):
"""
If the input login submission is an old style object
(ie. with top-level user / medium / address) convert it
to a typed object.
Returns: Typed-object style login identifier
"""
if "user" in submission:
submission["identifier"] = {
"type": "m.id.user",
"user": submission["user"],
}
del submission["user"]
if "medium" in submission and "address" in submission:
submission["identifier"] = {
"type": "m.id.thirdparty",
"medium": submission["medium"],
"address": submission["address"],
}
del submission["medium"]
del submission["address"]
return submission
def login_id_thirdparty_from_phone(identifier):
"""
Convert a phone login identifier type to a generic threepid identifier
Args:
identifier: Login identifier dict of type 'm.id.phone'
Returns: Login identifier dict of type 'm.id.threepid'
"""
if "country" not in identifier or "number" not in identifier:
raise SynapseError(400, "Invalid phone-type identifier")
phoneNumber = None
try:
phoneNumber = phonenumbers.parse(identifier["number"], identifier["country"])
except phonenumbers.NumberParseException:
raise SynapseError(400, "Unable to parse phone number")
msisdn = phonenumbers.format_number(
phoneNumber, phonenumbers.PhoneNumberFormat.E164
)[1:]
return {
"type": "m.id.thirdparty",
"medium": "msisdn",
"address": msisdn,
}
class LoginRestServlet(ClientV1RestServlet): class LoginRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/login$") PATTERNS = client_path_patterns("/login$")
PASS_TYPE = "m.login.password" PASS_TYPE = "m.login.password"
@ -117,20 +170,52 @@ class LoginRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def do_password_login(self, login_submission): def do_password_login(self, login_submission):
if 'medium' in login_submission and 'address' in login_submission: if "password" not in login_submission:
address = login_submission['address'] raise SynapseError(400, "Missing parameter: password")
if login_submission['medium'] == 'email':
login_submission = login_submission_legacy_convert(login_submission)
if "identifier" not in login_submission:
raise SynapseError(400, "Missing param: identifier")
identifier = login_submission["identifier"]
if "type" not in identifier:
raise SynapseError(400, "Login identifier has no type")
# convert phone type identifiers to geberic threepids
if identifier["type"] == "m.id.phone":
identifier = login_id_thirdparty_from_phone(identifier)
# convert threepid identifiers to user IDs
if identifier["type"] == "m.id.thirdparty":
if not 'medium' in identifier or not 'address' in identifier:
raise SynapseError(400, "Invalid thirdparty identifier")
address = identifier['address']
if identifier['medium'] == 'email':
# For emails, transform the address to lowercase. # For emails, transform the address to lowercase.
# We store all email addreses as lowercase in the DB. # We store all email addreses as lowercase in the DB.
# (See add_threepid in synapse/handlers/auth.py) # (See add_threepid in synapse/handlers/auth.py)
address = address.lower() address = address.lower()
user_id = yield self.hs.get_datastore().get_user_id_by_threepid( user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
login_submission['medium'], address identifier['medium'], address
) )
if not user_id: if not user_id:
raise LoginError(403, "", errcode=Codes.FORBIDDEN) raise LoginError(403, "", errcode=Codes.FORBIDDEN)
else:
user_id = login_submission['user'] identifier = {
"type": "m.id.user",
"user": user_id,
}
# by this point, the identifier should be an m.id.user: if it's anything
# else, we haven't understood it.
if identifier["type"] != "m.id.user":
raise SynapseError(400, "Unknown login identifier type")
if "user" not in identifier:
raise SynapseError(400, "User identifier is missing 'user' key")
user_id = identifier["user"]
if not user_id.startswith('@'): if not user_id.startswith('@'):
user_id = UserID.create( user_id = UserID.create(