Merge branch 'develop' of github.com:matrix-org/synapse into postgres

This commit is contained in:
Erik Johnston 2015-04-28 13:39:42 +01:00
commit 327ca883ec
37 changed files with 1361 additions and 304 deletions

31
CAPTCHA_SETUP Normal file
View File

@ -0,0 +1,31 @@
Captcha can be enabled for this home server. This file explains how to do that.
The captcha mechanism used is Google's ReCaptcha. This requires API keys from Google.
Getting keys
------------
Requires a public/private key pair from:
https://developers.google.com/recaptcha/
Setting ReCaptcha Keys
----------------------
The keys are a config option on the home server config. If they are not
visible, you can generate them via --generate-config. Set the following value:
recaptcha_public_key: YOUR_PUBLIC_KEY
recaptcha_private_key: YOUR_PRIVATE_KEY
In addition, you MUST enable captchas via:
enable_registration_captcha: true
Configuring IP used for auth
----------------------------
The ReCaptcha API requires that the IP address of the user who solved the
captcha is sent. If the client is connecting through a proxy or load balancer,
it may be required to use the X-Forwarded-For (XFF) header instead of the origin
IP address. This can be configured as an option on the home server like so:
captcha_ip_origin_is_x_forwarded: true

View File

@ -33,10 +33,9 @@ def request_registration(user, password, server_location, shared_secret):
).hexdigest() ).hexdigest()
data = { data = {
"user": user, "username": user,
"password": password, "password": password,
"mac": mac, "mac": mac,
"type": "org.matrix.login.shared_secret",
} }
server_location = server_location.rstrip("/") server_location = server_location.rstrip("/")
@ -44,7 +43,7 @@ def request_registration(user, password, server_location, shared_secret):
print "Sending registration request..." print "Sending registration request..."
req = urllib2.Request( req = urllib2.Request(
"%s/_matrix/client/api/v1/register" % (server_location,), "%s/_matrix/client/v2_alpha/register" % (server_location,),
data=json.dumps(data), data=json.dumps(data),
headers={'Content-Type': 'application/json'} headers={'Content-Type': 'application/json'}
) )

View File

@ -37,9 +37,13 @@ textarea, input {
margin: auto margin: auto
} }
.g-recaptcha div {
margin: auto;
}
#registrationForm { #registrationForm {
text-align: left; text-align: left;
padding: 1em; padding: 5px;
margin-bottom: 40px; margin-bottom: 40px;
display: inline-block; display: inline-block;

View File

@ -18,7 +18,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership, JoinRules from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, StoreError, Codes, SynapseError from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.types import UserID, ClientInfo from synapse.types import UserID, ClientInfo
@ -40,6 +40,7 @@ class Auth(object):
self.hs = hs self.hs = hs
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
def check(self, event, auth_events): def check(self, event, auth_events):
""" Checks if this event is correctly authed. """ Checks if this event is correctly authed.
@ -222,6 +223,13 @@ class Auth(object):
elif target_in_room: # the target is already in the room. elif target_in_room: # the target is already in the room.
raise AuthError(403, "%s is already in the room." % raise AuthError(403, "%s is already in the room." %
target_user_id) target_user_id)
else:
invite_level = self._get_named_level(auth_events, "invite", 0)
if user_level < invite_level:
raise AuthError(
403, "You cannot invite user %s." % target_user_id
)
elif Membership.JOIN == membership: elif Membership.JOIN == membership:
# Joins are valid iff caller == target and they were: # Joins are valid iff caller == target and they were:
# invited: They are accepting the invitation # invited: They are accepting the invitation
@ -362,7 +370,10 @@ class Auth(object):
defer.returnValue((user, ClientInfo(device_id, token_id))) defer.returnValue((user, ClientInfo(device_id, token_id)))
except KeyError: except KeyError:
raise AuthError(403, "Missing access token.") raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
errcode=Codes.MISSING_TOKEN
)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_user_by_token(self, token): def get_user_by_token(self, token):
@ -376,21 +387,20 @@ class Auth(object):
Raises: Raises:
AuthError if no user by that token exists or the token is invalid. AuthError if no user by that token exists or the token is invalid.
""" """
try: ret = yield self.store.get_user_by_token(token)
ret = yield self.store.get_user_by_token(token) if not ret:
if not ret: raise AuthError(
raise StoreError(400, "Unknown token") self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.",
user_info = { errcode=Codes.UNKNOWN_TOKEN
"admin": bool(ret.get("admin", False)), )
"device_id": ret.get("device_id"), user_info = {
"user": UserID.from_string(ret.get("name")), "admin": bool(ret.get("admin", False)),
"token_id": ret.get("token_id", None), "device_id": ret.get("device_id"),
} "user": UserID.from_string(ret.get("name")),
"token_id": ret.get("token_id", None),
}
defer.returnValue(user_info) defer.returnValue(user_info)
except StoreError:
raise AuthError(403, "Unrecognised access token.",
errcode=Codes.UNKNOWN_TOKEN)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_appservice_by_req(self, request): def get_appservice_by_req(self, request):
@ -398,11 +408,16 @@ class Auth(object):
token = request.args["access_token"][0] token = request.args["access_token"][0]
service = yield self.store.get_app_service_by_token(token) service = yield self.store.get_app_service_by_token(token)
if not service: if not service:
raise AuthError(403, "Unrecognised access token.", raise AuthError(
errcode=Codes.UNKNOWN_TOKEN) self.TOKEN_NOT_FOUND_HTTP_STATUS,
"Unrecognised access token.",
errcode=Codes.UNKNOWN_TOKEN
)
defer.returnValue(service) defer.returnValue(service)
except KeyError: except KeyError:
raise AuthError(403, "Missing access token.") raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token."
)
def is_server_admin(self, user): def is_server_admin(self, user):
return self.store.is_server_admin(user) return self.store.is_server_admin(user)
@ -561,6 +576,7 @@ class Auth(object):
("ban", []), ("ban", []),
("redact", []), ("redact", []),
("kick", []), ("kick", []),
("invite", []),
] ]
old_list = current_state.content.get("users") old_list = current_state.content.get("users")

View File

@ -59,6 +59,9 @@ class LoginType(object):
EMAIL_URL = u"m.login.email.url" EMAIL_URL = u"m.login.email.url"
EMAIL_IDENTITY = u"m.login.email.identity" EMAIL_IDENTITY = u"m.login.email.identity"
RECAPTCHA = u"m.login.recaptcha" RECAPTCHA = u"m.login.recaptcha"
DUMMY = u"m.login.dummy"
# Only for C/S API v1
APPLICATION_SERVICE = u"m.login.application_service" APPLICATION_SERVICE = u"m.login.application_service"
SHARED_SECRET = u"org.matrix.login.shared_secret" SHARED_SECRET = u"org.matrix.login.shared_secret"

View File

@ -31,6 +31,7 @@ class Codes(object):
BAD_PAGINATION = "M_BAD_PAGINATION" BAD_PAGINATION = "M_BAD_PAGINATION"
UNKNOWN = "M_UNKNOWN" UNKNOWN = "M_UNKNOWN"
NOT_FOUND = "M_NOT_FOUND" NOT_FOUND = "M_NOT_FOUND"
MISSING_TOKEN = "M_MISSING_TOKEN"
UNKNOWN_TOKEN = "M_UNKNOWN_TOKEN" UNKNOWN_TOKEN = "M_UNKNOWN_TOKEN"
LIMIT_EXCEEDED = "M_LIMIT_EXCEEDED" LIMIT_EXCEEDED = "M_LIMIT_EXCEEDED"
CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED" CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED"
@ -38,6 +39,7 @@ class Codes(object):
MISSING_PARAM = "M_MISSING_PARAM" MISSING_PARAM = "M_MISSING_PARAM"
TOO_LARGE = "M_TOO_LARGE" TOO_LARGE = "M_TOO_LARGE"
EXCLUSIVE = "M_EXCLUSIVE" EXCLUSIVE = "M_EXCLUSIVE"
THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
class CodeMessageException(RuntimeError): class CodeMessageException(RuntimeError):

View File

@ -17,8 +17,11 @@
import sys import sys
sys.dont_write_bytecode = True sys.dont_write_bytecode = True
from synapse.storage import UpgradeDatabaseException
from synapse.storage.engines import create_engine from synapse.storage.engines import create_engine
from synapse.storage import (
prepare_database, prepare_sqlite3_database, are_all_users_on_domain,
UpgradeDatabaseException,
)
from synapse.server import HomeServer from synapse.server import HomeServer
@ -238,6 +241,21 @@ class SynapseHomeServer(HomeServer):
) )
logger.info("Metrics now running on 127.0.0.1 port %d", config.metrics_port) logger.info("Metrics now running on 127.0.0.1 port %d", config.metrics_port)
def run_startup_checks(self, db_conn, database_engine):
all_users_native = are_all_users_on_domain(
db_conn.cursor(), database_engine, self.hostname
)
if not all_users_native:
sys.stderr.write(
"\n"
"******************************************************\n"
"Found users in database not native to %s!\n"
"You cannot changed a synapse server_name after it's been configured\n"
"******************************************************\n"
"\n" % (self.hostname,)
)
sys.exit(1)
def get_version_string(): def get_version_string():
try: try:
@ -382,6 +400,7 @@ def setup(config_options):
) )
database_engine.prepare_database(db_conn) database_engine.prepare_database(db_conn)
hs.run_startup_checks(db_conn, database_engine)
db_conn.commit() db_conn.commit()
except UpgradeDatabaseException: except UpgradeDatabaseException:

View File

@ -147,9 +147,10 @@ class Config(object):
and value is not None): and value is not None):
config[key] = value config[key] = value
with open(config_args.config_path, "w") as config_file: with open(config_args.config_path, "w") as config_file:
# TODO(paul) it would be lovely if we wrote out vim- and emacs- # TODO(mark/paul) We might want to output emacs-style mode
# style mode markers into the file, to hint to people that # markers as well as vim-style mode markers into the file,
# this is a YAML file. # to further hint to people this is a YAML file.
config_file.write("# vim:ft=yaml\n")
yaml.dump(config, config_file, default_flow_style=False) yaml.dump(config, config_file, default_flow_style=False)
print ( print (
"A config file has been generated in %s for server name" "A config file has been generated in %s for server name"

View File

@ -20,6 +20,7 @@ class CaptchaConfig(Config):
def __init__(self, args): def __init__(self, args):
super(CaptchaConfig, self).__init__(args) super(CaptchaConfig, self).__init__(args)
self.recaptcha_private_key = args.recaptcha_private_key self.recaptcha_private_key = args.recaptcha_private_key
self.recaptcha_public_key = args.recaptcha_public_key
self.enable_registration_captcha = args.enable_registration_captcha self.enable_registration_captcha = args.enable_registration_captcha
self.captcha_ip_origin_is_x_forwarded = ( self.captcha_ip_origin_is_x_forwarded = (
args.captcha_ip_origin_is_x_forwarded args.captcha_ip_origin_is_x_forwarded
@ -30,9 +31,13 @@ class CaptchaConfig(Config):
def add_arguments(cls, parser): def add_arguments(cls, parser):
super(CaptchaConfig, cls).add_arguments(parser) super(CaptchaConfig, cls).add_arguments(parser)
group = parser.add_argument_group("recaptcha") group = parser.add_argument_group("recaptcha")
group.add_argument(
"--recaptcha-public-key", type=str, default="YOUR_PUBLIC_KEY",
help="This Home Server's ReCAPTCHA public key."
)
group.add_argument( group.add_argument(
"--recaptcha-private-key", type=str, default="YOUR_PRIVATE_KEY", "--recaptcha-private-key", type=str, default="YOUR_PRIVATE_KEY",
help="The matching private key for the web client's public key." help="This Home Server's ReCAPTCHA private key."
) )
group.add_argument( group.add_argument(
"--enable-registration-captcha", type=bool, default=False, "--enable-registration-captcha", type=bool, default=False,

View File

@ -1,42 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config
class EmailConfig(Config):
def __init__(self, args):
super(EmailConfig, self).__init__(args)
self.email_from_address = args.email_from_address
self.email_smtp_server = args.email_smtp_server
@classmethod
def add_arguments(cls, parser):
super(EmailConfig, cls).add_arguments(parser)
email_group = parser.add_argument_group("email")
email_group.add_argument(
"--email-from-address",
default="FROM@EXAMPLE.COM",
help="The address to send emails from (e.g. for password resets)."
)
email_group.add_argument(
"--email-smtp-server",
default="",
help=(
"The SMTP server to send emails from (e.g. for password"
" resets)."
)
)

View File

@ -20,7 +20,6 @@ from .database import DatabaseConfig
from .ratelimiting import RatelimitConfig from .ratelimiting import RatelimitConfig
from .repository import ContentRepositoryConfig from .repository import ContentRepositoryConfig
from .captcha import CaptchaConfig from .captcha import CaptchaConfig
from .email import EmailConfig
from .voip import VoipConfig from .voip import VoipConfig
from .registration import RegistrationConfig from .registration import RegistrationConfig
from .metrics import MetricsConfig from .metrics import MetricsConfig
@ -29,7 +28,7 @@ from .appservice import AppServiceConfig
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
RatelimitConfig, ContentRepositoryConfig, CaptchaConfig, RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
EmailConfig, VoipConfig, RegistrationConfig, VoipConfig, RegistrationConfig,
MetricsConfig, AppServiceConfig,): MetricsConfig, AppServiceConfig,):
pass pass

View File

@ -30,6 +30,8 @@ from .typing import TypingNotificationHandler
from .admin import AdminHandler from .admin import AdminHandler
from .appservice import ApplicationServicesHandler from .appservice import ApplicationServicesHandler
from .sync import SyncHandler from .sync import SyncHandler
from .auth import AuthHandler
from .identity import IdentityHandler
class Handlers(object): class Handlers(object):
@ -64,3 +66,5 @@ class Handlers(object):
) )
) )
self.sync_handler = SyncHandler(hs) self.sync_handler = SyncHandler(hs)
self.auth_handler = AuthHandler(hs)
self.identity_handler = IdentityHandler(hs)

277
synapse/handlers/auth.py Normal file
View File

@ -0,0 +1,277 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from ._base import BaseHandler
from synapse.api.constants import LoginType
from synapse.types import UserID
from synapse.api.errors import LoginError, Codes
from synapse.http.client import SimpleHttpClient
from synapse.util.async import run_on_reactor
from twisted.web.client import PartialDownloadError
import logging
import bcrypt
import simplejson
import synapse.util.stringutils as stringutils
logger = logging.getLogger(__name__)
class AuthHandler(BaseHandler):
def __init__(self, hs):
super(AuthHandler, self).__init__(hs)
self.checkers = {
LoginType.PASSWORD: self._check_password_auth,
LoginType.RECAPTCHA: self._check_recaptcha,
LoginType.EMAIL_IDENTITY: self._check_email_identity,
LoginType.DUMMY: self._check_dummy_auth,
}
self.sessions = {}
@defer.inlineCallbacks
def check_auth(self, flows, clientdict, clientip=None):
"""
Takes a dictionary sent by the client in the login / registration
protocol and handles the login flow.
Args:
flows: list of list of stages
authdict: The dictionary from the client root level, not the
'auth' key: this method prompts for auth if none is sent.
Returns:
A tuple of authed, dict, dict where authed is true if the client
has successfully completed an auth flow. If it is true, the first
dict contains the authenticated credentials of each stage.
If authed is false, the first dictionary is the server response to
the login request and should be passed back to the client.
In either case, the second dict contains the parameters for this
request (which may have been given only in a previous call).
"""
authdict = None
sid = None
if clientdict and 'auth' in clientdict:
authdict = clientdict['auth']
del clientdict['auth']
if 'session' in authdict:
sid = authdict['session']
sess = self._get_session_info(sid)
if len(clientdict) > 0:
# This was designed to allow the client to omit the parameters
# and just supply the session in subsequent calls so it split
# auth between devices by just sharing the session, (eg. so you
# could continue registration from your phone having clicked the
# email auth link on there). It's probably too open to abuse
# because it lets unauthenticated clients store arbitrary objects
# on a home server.
# sess['clientdict'] = clientdict
# self._save_session(sess)
pass
elif 'clientdict' in sess:
clientdict = sess['clientdict']
if not authdict:
defer.returnValue(
(False, self._auth_dict_for_flows(flows, sess), clientdict)
)
if 'creds' not in sess:
sess['creds'] = {}
creds = sess['creds']
# check auth type currently being presented
if 'type' in authdict:
if authdict['type'] not in self.checkers:
raise LoginError(400, "", Codes.UNRECOGNIZED)
result = yield self.checkers[authdict['type']](authdict, clientip)
if result:
creds[authdict['type']] = result
self._save_session(sess)
for f in flows:
if len(set(f) - set(creds.keys())) == 0:
logger.info("Auth completed with creds: %r", creds)
self._remove_session(sess)
defer.returnValue((True, creds, clientdict))
ret = self._auth_dict_for_flows(flows, sess)
ret['completed'] = creds.keys()
defer.returnValue((False, ret, clientdict))
@defer.inlineCallbacks
def add_oob_auth(self, stagetype, authdict, clientip):
"""
Adds the result of out-of-band authentication into an existing auth
session. Currently used for adding the result of fallback auth.
"""
if stagetype not in self.checkers:
raise LoginError(400, "", Codes.MISSING_PARAM)
if 'session' not in authdict:
raise LoginError(400, "", Codes.MISSING_PARAM)
sess = self._get_session_info(
authdict['session']
)
if 'creds' not in sess:
sess['creds'] = {}
creds = sess['creds']
result = yield self.checkers[stagetype](authdict, clientip)
if result:
creds[stagetype] = result
self._save_session(sess)
defer.returnValue(True)
defer.returnValue(False)
@defer.inlineCallbacks
def _check_password_auth(self, authdict, _):
if "user" not in authdict or "password" not in authdict:
raise LoginError(400, "", Codes.MISSING_PARAM)
user = authdict["user"]
password = authdict["password"]
if not user.startswith('@'):
user = UserID.create(user, self.hs.hostname).to_string()
user_info = yield self.store.get_user_by_id(user_id=user)
if not user_info:
logger.warn("Attempted to login as %s but they do not exist", user)
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
stored_hash = user_info[0]["password_hash"]
if bcrypt.checkpw(password, stored_hash):
defer.returnValue(user)
else:
logger.warn("Failed password login for user %s", user)
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
@defer.inlineCallbacks
def _check_recaptcha(self, authdict, clientip):
try:
user_response = authdict["response"]
except KeyError:
# Client tried to provide captcha but didn't give the parameter:
# bad request.
raise LoginError(
400, "Captcha response is required",
errcode=Codes.CAPTCHA_NEEDED
)
logger.info(
"Submitting recaptcha response %s with remoteip %s",
user_response, clientip
)
# TODO: get this from the homeserver rather than creating a new one for
# each request
try:
client = SimpleHttpClient(self.hs)
data = yield client.post_urlencoded_get_json(
"https://www.google.com/recaptcha/api/siteverify",
args={
'secret': self.hs.config.recaptcha_private_key,
'response': user_response,
'remoteip': clientip,
}
)
except PartialDownloadError as pde:
# Twisted is silly
data = pde.response
resp_body = simplejson.loads(data)
if 'success' in resp_body and resp_body['success']:
defer.returnValue(True)
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
@defer.inlineCallbacks
def _check_email_identity(self, authdict, _):
yield run_on_reactor()
if 'threepid_creds' not in authdict:
raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)
threepid_creds = authdict['threepid_creds']
identity_handler = self.hs.get_handlers().identity_handler
logger.info("Getting validated threepid. threepidcreds: %r" % (threepid_creds,))
threepid = yield identity_handler.threepid_from_creds(threepid_creds)
if not threepid:
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
threepid['threepid_creds'] = authdict['threepid_creds']
defer.returnValue(threepid)
@defer.inlineCallbacks
def _check_dummy_auth(self, authdict, _):
yield run_on_reactor()
defer.returnValue(True)
def _get_params_recaptcha(self):
return {"public_key": self.hs.config.recaptcha_public_key}
def _auth_dict_for_flows(self, flows, session):
public_flows = []
for f in flows:
public_flows.append(f)
get_params = {
LoginType.RECAPTCHA: self._get_params_recaptcha,
}
params = {}
for f in public_flows:
for stage in f:
if stage in get_params and stage not in params:
params[stage] = get_params[stage]()
return {
"session": session['id'],
"flows": [{"stages": f} for f in public_flows],
"params": params
}
def _get_session_info(self, session_id):
if session_id not in self.sessions:
session_id = None
if not session_id:
# create a new session
while session_id is None or session_id in self.sessions:
session_id = stringutils.random_string(24)
self.sessions[session_id] = {
"id": session_id,
}
return self.sessions[session_id]
def _save_session(self, session):
# TODO: Persistent storage
logger.debug("Saving session %s", session)
self.sessions[session["id"]] = session
def _remove_session(self, session):
logger.debug("Removing session %s", session)
del self.sessions[session["id"]]

View File

@ -0,0 +1,88 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for interacting with Identity Servers"""
from twisted.internet import defer
from synapse.api.errors import (
CodeMessageException
)
from ._base import BaseHandler
from synapse.http.client import SimpleHttpClient
from synapse.util.async import run_on_reactor
import json
import logging
logger = logging.getLogger(__name__)
class IdentityHandler(BaseHandler):
def __init__(self, hs):
super(IdentityHandler, self).__init__(hs)
@defer.inlineCallbacks
def threepid_from_creds(self, creds):
yield run_on_reactor()
# TODO: get this from the homeserver rather than creating a new one for
# each request
http_client = SimpleHttpClient(self.hs)
# XXX: make this configurable!
# trustedIdServers = ['matrix.org', 'localhost:8090']
trustedIdServers = ['matrix.org']
if not creds['id_server'] in trustedIdServers:
logger.warn('%s is not a trusted ID server: rejecting 3pid ' +
'credentials', creds['id_server'])
defer.returnValue(None)
data = {}
try:
data = yield http_client.get_json(
"https://%s%s" % (
creds['id_server'],
"/_matrix/identity/api/v1/3pid/getValidated3pid"
),
{'sid': creds['sid'], 'client_secret': creds['client_secret']}
)
except CodeMessageException as e:
data = json.loads(e.msg)
if 'medium' in data:
defer.returnValue(data)
defer.returnValue(None)
@defer.inlineCallbacks
def bind_threepid(self, creds, mxid):
yield run_on_reactor()
logger.debug("binding threepid %r to %s", creds, mxid)
http_client = SimpleHttpClient(self.hs)
data = None
try:
data = yield http_client.post_urlencoded_get_json(
"https://%s%s" % (
creds['id_server'], "/_matrix/identity/api/v1/3pid/bind"
),
{
'sid': creds['sid'],
'client_secret': creds['client_secret'],
'mxid': mxid,
}
)
logger.debug("bound threepid %r to %s", creds, mxid)
except CodeMessageException as e:
data = json.loads(e.msg)
defer.returnValue(data)

View File

@ -16,13 +16,9 @@
from twisted.internet import defer from twisted.internet import defer
from ._base import BaseHandler from ._base import BaseHandler
from synapse.api.errors import LoginError, Codes, CodeMessageException from synapse.api.errors import LoginError, Codes
from synapse.http.client import SimpleHttpClient
from synapse.util.emailutils import EmailException
import synapse.util.emailutils as emailutils
import bcrypt import bcrypt
import json
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -69,48 +65,19 @@ class LoginHandler(BaseHandler):
raise LoginError(403, "", errcode=Codes.FORBIDDEN) raise LoginError(403, "", errcode=Codes.FORBIDDEN)
@defer.inlineCallbacks @defer.inlineCallbacks
def reset_password(self, user_id, email): def set_password(self, user_id, newpassword, token_id=None):
is_valid = yield self._check_valid_association(user_id, email) password_hash = bcrypt.hashpw(newpassword, bcrypt.gensalt())
logger.info("reset_password user=%s email=%s valid=%s", user_id, email,
is_valid) yield self.store.user_set_password_hash(user_id, password_hash)
if is_valid: yield self.store.user_delete_access_tokens_apart_from(user_id, token_id)
try: yield self.hs.get_pusherpool().remove_pushers_by_user_access_token(
# send an email out user_id, token_id
emailutils.send_email( )
smtp_server=self.hs.config.email_smtp_server, yield self.store.flush_user(user_id)
from_addr=self.hs.config.email_from_address,
to_addr=email,
subject="Password Reset",
body="TODO."
)
except EmailException as e:
logger.exception(e)
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_valid_association(self, user_id, email): def add_threepid(self, user_id, medium, address, validated_at):
identity = yield self._query_email(email) yield self.store.user_add_threepid(
if identity and "mxid" in identity: user_id, medium, address, validated_at,
if identity["mxid"] == user_id: self.hs.get_clock().time_msec()
defer.returnValue(True) )
return
defer.returnValue(False)
@defer.inlineCallbacks
def _query_email(self, email):
http_client = SimpleHttpClient(self.hs)
try:
data = yield http_client.get_json(
# TODO FIXME This should be configurable.
# XXX: ID servers need to use HTTPS
"http://%s%s" % (
"matrix.org:8090", "/_matrix/identity/api/v1/lookup"
),
{
'medium': 'email',
'address': email
}
)
defer.returnValue(data)
except CodeMessageException as e:
data = json.loads(e.msg)
defer.returnValue(data)

View File

@ -18,18 +18,15 @@ from twisted.internet import defer
from synapse.types import UserID from synapse.types import UserID
from synapse.api.errors import ( from synapse.api.errors import (
AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError, AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
CodeMessageException
) )
from ._base import BaseHandler from ._base import BaseHandler
import synapse.util.stringutils as stringutils import synapse.util.stringutils as stringutils
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.http.client import SimpleHttpClient
from synapse.http.client import CaptchaServerHttpClient from synapse.http.client import CaptchaServerHttpClient
import base64 import base64
import bcrypt import bcrypt
import json
import logging import logging
import urllib import urllib
@ -44,6 +41,30 @@ class RegistrationHandler(BaseHandler):
self.distributor = hs.get_distributor() self.distributor = hs.get_distributor()
self.distributor.declare("registered_user") self.distributor.declare("registered_user")
@defer.inlineCallbacks
def check_username(self, localpart):
yield run_on_reactor()
if urllib.quote(localpart) != localpart:
raise SynapseError(
400,
"User ID must only contain characters which do not"
" require URL encoding."
)
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
yield self.check_user_id_is_valid(user_id)
u = yield self.store.get_user_by_id(user_id)
if u:
raise SynapseError(
400,
"User ID already taken.",
errcode=Codes.USER_IN_USE,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def register(self, localpart=None, password=None): def register(self, localpart=None, password=None):
"""Registers a new client on the server. """Registers a new client on the server.
@ -64,18 +85,11 @@ class RegistrationHandler(BaseHandler):
password_hash = bcrypt.hashpw(password, bcrypt.gensalt()) password_hash = bcrypt.hashpw(password, bcrypt.gensalt())
if localpart: if localpart:
if localpart and urllib.quote(localpart) != localpart: yield self.check_username(localpart)
raise SynapseError(
400,
"User ID must only contain characters which do not"
" require URL encoding."
)
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
yield self.check_user_id_is_valid(user_id)
token = self._generate_token(user_id) token = self._generate_token(user_id)
yield self.store.register( yield self.store.register(
user_id=user_id, user_id=user_id,
@ -157,7 +171,11 @@ class RegistrationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def check_recaptcha(self, ip, private_key, challenge, response): def check_recaptcha(self, ip, private_key, challenge, response):
"""Checks a recaptcha is correct.""" """
Checks a recaptcha is correct.
Used only by c/s api v1
"""
captcha_response = yield self._validate_captcha( captcha_response = yield self._validate_captcha(
ip, ip,
@ -176,13 +194,18 @@ class RegistrationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def register_email(self, threepidCreds): def register_email(self, threepidCreds):
"""Registers emails with an identity server.""" """
Registers emails with an identity server.
Used only by c/s api v1
"""
for c in threepidCreds: for c in threepidCreds:
logger.info("validating theeepidcred sid %s on id server %s", logger.info("validating theeepidcred sid %s on id server %s",
c['sid'], c['idServer']) c['sid'], c['idServer'])
try: try:
threepid = yield self._threepid_from_creds(c) identity_handler = self.hs.get_handlers().identity_handler
threepid = yield identity_handler.threepid_from_creds(c)
except: except:
logger.exception("Couldn't validate 3pid") logger.exception("Couldn't validate 3pid")
raise RegistrationError(400, "Couldn't validate 3pid") raise RegistrationError(400, "Couldn't validate 3pid")
@ -194,12 +217,16 @@ class RegistrationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def bind_emails(self, user_id, threepidCreds): def bind_emails(self, user_id, threepidCreds):
"""Links emails with a user ID and informs an identity server.""" """Links emails with a user ID and informs an identity server.
Used only by c/s api v1
"""
# Now we have a matrix ID, bind it to the threepids we were given # Now we have a matrix ID, bind it to the threepids we were given
for c in threepidCreds: for c in threepidCreds:
identity_handler = self.hs.get_handlers().identity_handler
# XXX: This should be a deferred list, shouldn't it? # XXX: This should be a deferred list, shouldn't it?
yield self._bind_threepid(c, user_id) yield identity_handler.bind_threepid(c, user_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def check_user_id_is_valid(self, user_id): def check_user_id_is_valid(self, user_id):
@ -226,62 +253,12 @@ class RegistrationHandler(BaseHandler):
def _generate_user_id(self): def _generate_user_id(self):
return "-" + stringutils.random_string(18) return "-" + stringutils.random_string(18)
@defer.inlineCallbacks
def _threepid_from_creds(self, creds):
# TODO: get this from the homeserver rather than creating a new one for
# each request
http_client = SimpleHttpClient(self.hs)
# XXX: make this configurable!
trustedIdServers = ['matrix.org:8090', 'matrix.org']
if not creds['idServer'] in trustedIdServers:
logger.warn('%s is not a trusted ID server: rejecting 3pid ' +
'credentials', creds['idServer'])
defer.returnValue(None)
data = {}
try:
data = yield http_client.get_json(
# XXX: This should be HTTPS
"http://%s%s" % (
creds['idServer'],
"/_matrix/identity/api/v1/3pid/getValidated3pid"
),
{'sid': creds['sid'], 'clientSecret': creds['clientSecret']}
)
except CodeMessageException as e:
data = json.loads(e.msg)
if 'medium' in data:
defer.returnValue(data)
defer.returnValue(None)
@defer.inlineCallbacks
def _bind_threepid(self, creds, mxid):
yield
logger.debug("binding threepid")
http_client = SimpleHttpClient(self.hs)
data = None
try:
data = yield http_client.post_urlencoded_get_json(
# XXX: Change when ID servers are all HTTPS
"http://%s%s" % (
creds['idServer'], "/_matrix/identity/api/v1/3pid/bind"
),
{
'sid': creds['sid'],
'clientSecret': creds['clientSecret'],
'mxid': mxid,
}
)
logger.debug("bound threepid")
except CodeMessageException as e:
data = json.loads(e.msg)
defer.returnValue(data)
@defer.inlineCallbacks @defer.inlineCallbacks
def _validate_captcha(self, ip_addr, private_key, challenge, response): def _validate_captcha(self, ip_addr, private_key, challenge, response):
"""Validates the captcha provided. """Validates the captcha provided.
Used only by c/s api v1
Returns: Returns:
dict: Containing 'valid'(bool) and 'error_url'(str) if invalid. dict: Containing 'valid'(bool) and 'error_url'(str) if invalid.
@ -299,6 +276,9 @@ class RegistrationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _submit_captcha(self, ip_addr, private_key, challenge, response): def _submit_captcha(self, ip_addr, private_key, challenge, response):
"""
Used only by c/s api v1
"""
# TODO: get this from the homeserver rather than creating a new one for # TODO: get this from the homeserver rather than creating a new one for
# each request # each request
client = CaptchaServerHttpClient(self.hs) client = CaptchaServerHttpClient(self.hs)

View File

@ -213,7 +213,8 @@ class RoomCreationHandler(BaseHandler):
"state_default": 50, "state_default": 50,
"ban": 50, "ban": 50,
"kick": 50, "kick": 50,
"redact": 50 "redact": 50,
"invite": 0,
}, },
) )

View File

@ -200,6 +200,8 @@ class CaptchaServerHttpClient(SimpleHttpClient):
""" """
Separate HTTP client for talking to google's captcha servers Separate HTTP client for talking to google's captcha servers
Only slightly special because accepts partial download responses Only slightly special because accepts partial download responses
used only by c/s api v1
""" """
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -131,10 +131,10 @@ class HttpServer(object):
""" """
def register_path(self, method, path_pattern, callback): def register_path(self, method, path_pattern, callback):
""" Register a callback that get's fired if we receive a http request """ Register a callback that gets fired if we receive a http request
with the given method for a path that matches the given regex. with the given method for a path that matches the given regex.
If the regex contains groups these get's passed to the calback via If the regex contains groups these gets passed to the calback via
an unpacked tuple. an unpacked tuple.
Args: Args:
@ -153,6 +153,13 @@ class JsonResource(HttpServer, resource.Resource):
Resources. Resources.
Register callbacks via register_path() Register callbacks via register_path()
Callbacks can return a tuple of status code and a dict in which case the
the dict will automatically be sent to the client as a JSON object.
The JsonResource is primarily intended for returning JSON, but callbacks
may send something other than JSON, they may do so by using the methods
on the request object and instead returning None.
""" """
isLeaf = True isLeaf = True
@ -185,9 +192,8 @@ class JsonResource(HttpServer, resource.Resource):
interface=self.hs.config.bind_host interface=self.hs.config.bind_host
) )
# Gets called by twisted
def render(self, request): def render(self, request):
""" This get's called by twisted every time someone sends us a request. """ This gets called by twisted every time someone sends us a request.
""" """
self._async_render(request) self._async_render(request)
return server.NOT_DONE_YET return server.NOT_DONE_YET
@ -195,7 +201,7 @@ class JsonResource(HttpServer, resource.Resource):
@request_handler @request_handler
@defer.inlineCallbacks @defer.inlineCallbacks
def _async_render(self, request): def _async_render(self, request):
""" This get's called by twisted every time someone sends us a request. """ This gets called from render() every time someone sends us a request.
This checks if anyone has registered a callback for that method and This checks if anyone has registered a callback for that method and
path. path.
""" """
@ -227,9 +233,11 @@ class JsonResource(HttpServer, resource.Resource):
urllib.unquote(u).decode("UTF-8") for u in m.groups() urllib.unquote(u).decode("UTF-8") for u in m.groups()
] ]
code, response = yield callback(request, *args) callback_return = yield callback(request, *args)
if callback_return is not None:
code, response = callback_return
self._send_response(request, code, response)
self._send_response(request, code, response)
response_timer.inc_by( response_timer.inc_by(
self.clock.time_msec() - start, request.method, servlet_classname self.clock.time_msec() - start, request.method, servlet_classname
) )

View File

@ -253,7 +253,8 @@ class Pusher(object):
self.user_name, config, timeout=0) self.user_name, config, timeout=0)
self.last_token = chunk['end'] self.last_token = chunk['end']
self.store.update_pusher_last_token( self.store.update_pusher_last_token(
self.app_id, self.pushkey, self.last_token) self.app_id, self.pushkey, self.user_name, self.last_token
)
logger.info("Pusher %s for user %s starting from token %s", logger.info("Pusher %s for user %s starting from token %s",
self.pushkey, self.user_name, self.last_token) self.pushkey, self.user_name, self.last_token)
@ -314,7 +315,7 @@ class Pusher(object):
pk pk
) )
yield self.hs.get_pusherpool().remove_pusher( yield self.hs.get_pusherpool().remove_pusher(
self.app_id, pk self.app_id, pk, self.user_name
) )
if not self.alive: if not self.alive:
@ -326,6 +327,7 @@ class Pusher(object):
self.store.update_pusher_last_token_and_success( self.store.update_pusher_last_token_and_success(
self.app_id, self.app_id,
self.pushkey, self.pushkey,
self.user_name,
self.last_token, self.last_token,
self.clock.time_msec() self.clock.time_msec()
) )
@ -334,6 +336,7 @@ class Pusher(object):
self.store.update_pusher_failing_since( self.store.update_pusher_failing_since(
self.app_id, self.app_id,
self.pushkey, self.pushkey,
self.user_name,
self.failing_since) self.failing_since)
else: else:
if not self.failing_since: if not self.failing_since:
@ -341,6 +344,7 @@ class Pusher(object):
self.store.update_pusher_failing_since( self.store.update_pusher_failing_since(
self.app_id, self.app_id,
self.pushkey, self.pushkey,
self.user_name,
self.failing_since self.failing_since
) )
@ -358,6 +362,7 @@ class Pusher(object):
self.store.update_pusher_last_token( self.store.update_pusher_last_token(
self.app_id, self.app_id,
self.pushkey, self.pushkey,
self.user_name,
self.last_token self.last_token
) )
@ -365,6 +370,7 @@ class Pusher(object):
self.store.update_pusher_failing_since( self.store.update_pusher_failing_since(
self.app_id, self.app_id,
self.pushkey, self.pushkey,
self.user_name,
self.failing_since self.failing_since
) )
else: else:

View File

@ -52,7 +52,7 @@ class PusherPool:
self._start_pushers(pushers) self._start_pushers(pushers)
@defer.inlineCallbacks @defer.inlineCallbacks
def add_pusher(self, user_name, profile_tag, kind, app_id, def add_pusher(self, user_name, access_token, profile_tag, kind, app_id,
app_display_name, device_display_name, pushkey, lang, data): app_display_name, device_display_name, pushkey, lang, data):
# we try to create the pusher just to validate the config: it # we try to create the pusher just to validate the config: it
# will then get pulled out of the database, # will then get pulled out of the database,
@ -66,7 +66,7 @@ class PusherPool:
"app_display_name": app_display_name, "app_display_name": app_display_name,
"device_display_name": device_display_name, "device_display_name": device_display_name,
"pushkey": pushkey, "pushkey": pushkey,
"pushkey_ts": self.hs.get_clock().time_msec(), "ts": self.hs.get_clock().time_msec(),
"lang": lang, "lang": lang,
"data": data, "data": data,
"last_token": None, "last_token": None,
@ -74,17 +74,50 @@ class PusherPool:
"failing_since": None "failing_since": None
}) })
yield self._add_pusher_to_store( yield self._add_pusher_to_store(
user_name, profile_tag, kind, app_id, user_name, access_token, profile_tag, kind, app_id,
app_display_name, device_display_name, app_display_name, device_display_name,
pushkey, lang, data pushkey, lang, data
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def _add_pusher_to_store(self, user_name, profile_tag, kind, app_id, def remove_pushers_by_app_id_and_pushkey_not_user(self, app_id, pushkey,
app_display_name, device_display_name, not_user_id):
to_remove = yield self.store.get_pushers_by_app_id_and_pushkey(
app_id, pushkey
)
for p in to_remove:
if p['user_name'] != not_user_id:
logger.info(
"Removing pusher for app id %s, pushkey %s, user %s",
app_id, pushkey, p['user_name']
)
self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
@defer.inlineCallbacks
def remove_pushers_by_user_access_token(self, user_id, not_access_token_id):
all = yield self.store.get_all_pushers()
logger.info(
"Removing all pushers for user %s except access token %s",
user_id, not_access_token_id
)
for p in all:
if (
p['user_name'] == user_id and
p['access_token'] != not_access_token_id
):
logger.info(
"Removing pusher for app id %s, pushkey %s, user %s",
p['app_id'], p['pushkey'], p['user_name']
)
self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
@defer.inlineCallbacks
def _add_pusher_to_store(self, user_name, access_token, profile_tag, kind,
app_id, app_display_name, device_display_name,
pushkey, lang, data): pushkey, lang, data):
yield self.store.add_pusher( yield self.store.add_pusher(
user_name=user_name, user_name=user_name,
access_token=access_token,
profile_tag=profile_tag, profile_tag=profile_tag,
kind=kind, kind=kind,
app_id=app_id, app_id=app_id,
@ -95,7 +128,7 @@ class PusherPool:
lang=lang, lang=lang,
data=data, data=data,
) )
self._refresh_pusher((app_id, pushkey)) self._refresh_pusher(app_id, pushkey, user_name)
def _create_pusher(self, pusherdict): def _create_pusher(self, pusherdict):
if pusherdict['kind'] == 'http': if pusherdict['kind'] == 'http':
@ -107,7 +140,7 @@ class PusherPool:
app_display_name=pusherdict['app_display_name'], app_display_name=pusherdict['app_display_name'],
device_display_name=pusherdict['device_display_name'], device_display_name=pusherdict['device_display_name'],
pushkey=pusherdict['pushkey'], pushkey=pusherdict['pushkey'],
pushkey_ts=pusherdict['pushkey_ts'], pushkey_ts=pusherdict['ts'],
data=pusherdict['data'], data=pusherdict['data'],
last_token=pusherdict['last_token'], last_token=pusherdict['last_token'],
last_success=pusherdict['last_success'], last_success=pusherdict['last_success'],
@ -120,29 +153,42 @@ class PusherPool:
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def _refresh_pusher(self, app_id_pushkey): def _refresh_pusher(self, app_id, pushkey, user_name):
p = yield self.store.get_pushers_by_app_id_and_pushkey( resultlist = yield self.store.get_pushers_by_app_id_and_pushkey(
app_id_pushkey app_id, pushkey
) )
self._start_pushers([p]) p = None
for r in resultlist:
if r['user_name'] == user_name:
p = r
if p:
self._start_pushers([p])
def _start_pushers(self, pushers): def _start_pushers(self, pushers):
logger.info("Starting %d pushers", len(pushers)) logger.info("Starting %d pushers", len(pushers))
for pusherdict in pushers: for pusherdict in pushers:
p = self._create_pusher(pusherdict) p = self._create_pusher(pusherdict)
if p: if p:
fullid = "%s:%s" % (pusherdict['app_id'], pusherdict['pushkey']) fullid = "%s:%s:%s" % (
pusherdict['app_id'],
pusherdict['pushkey'],
pusherdict['user_name']
)
if fullid in self.pushers: if fullid in self.pushers:
self.pushers[fullid].stop() self.pushers[fullid].stop()
self.pushers[fullid] = p self.pushers[fullid] = p
p.start() p.start()
@defer.inlineCallbacks @defer.inlineCallbacks
def remove_pusher(self, app_id, pushkey): def remove_pusher(self, app_id, pushkey, user_name):
fullid = "%s:%s" % (app_id, pushkey) fullid = "%s:%s:%s" % (app_id, pushkey, user_name)
if fullid in self.pushers: if fullid in self.pushers:
logger.info("Stopping pusher %s", fullid) logger.info("Stopping pusher %s", fullid)
self.pushers[fullid].stop() self.pushers[fullid].stop()
del self.pushers[fullid] del self.pushers[fullid]
yield self.store.delete_pusher_by_app_id_pushkey(app_id, pushkey) yield self.store.delete_pusher_by_app_id_pushkey_user_name(
app_id, pushkey, user_name
)

View File

@ -48,5 +48,5 @@ class ClientV1RestServlet(RestServlet):
self.hs = hs self.hs = hs
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
self.builder_factory = hs.get_event_builder_factory() self.builder_factory = hs.get_event_builder_factory()
self.auth = hs.get_auth() self.auth = hs.get_v1auth()
self.txns = HttpTransactionStore() self.txns = HttpTransactionStore()

View File

@ -27,7 +27,7 @@ class PusherRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
user, _ = yield self.auth.get_user_by_req(request) user, client = yield self.auth.get_user_by_req(request)
content = _parse_json(request) content = _parse_json(request)
@ -37,7 +37,7 @@ class PusherRestServlet(ClientV1RestServlet):
and 'kind' in content and and 'kind' in content and
content['kind'] is None): content['kind'] is None):
yield pusher_pool.remove_pusher( yield pusher_pool.remove_pusher(
content['app_id'], content['pushkey'] content['app_id'], content['pushkey'], user_name=user.to_string()
) )
defer.returnValue((200, {})) defer.returnValue((200, {}))
@ -51,9 +51,21 @@ class PusherRestServlet(ClientV1RestServlet):
raise SynapseError(400, "Missing parameters: "+','.join(missing), raise SynapseError(400, "Missing parameters: "+','.join(missing),
errcode=Codes.MISSING_PARAM) errcode=Codes.MISSING_PARAM)
append = False
if 'append' in content:
append = content['append']
if not append:
yield pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user(
app_id=content['app_id'],
pushkey=content['pushkey'],
not_user_id=user.to_string()
)
try: try:
yield pusher_pool.add_pusher( yield pusher_pool.add_pusher(
user_name=user.to_string(), user_name=user.to_string(),
access_token=client.token_id,
profile_tag=content['profile_tag'], profile_tag=content['profile_tag'],
kind=content['kind'], kind=content['kind'],
app_id=content['app_id'], app_id=content['app_id'],

View File

@ -15,7 +15,10 @@
from . import ( from . import (
sync, sync,
filter filter,
account,
register,
auth
) )
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
@ -32,3 +35,6 @@ class ClientV2AlphaRestResource(JsonResource):
def register_servlets(client_resource, hs): def register_servlets(client_resource, hs):
sync.register_servlets(hs, client_resource) sync.register_servlets(hs, client_resource)
filter.register_servlets(hs, client_resource) filter.register_servlets(hs, client_resource)
account.register_servlets(hs, client_resource)
register.register_servlets(hs, client_resource)
auth.register_servlets(hs, client_resource)

View File

@ -17,9 +17,11 @@
""" """
from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
from synapse.api.errors import SynapseError
import re import re
import logging import logging
import simplejson
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -36,3 +38,23 @@ def client_v2_pattern(path_regex):
SRE_Pattern SRE_Pattern
""" """
return re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex) return re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex)
def parse_request_allow_empty(request):
content = request.content.read()
if content is None or content == '':
return None
try:
return simplejson.loads(content)
except simplejson.JSONDecodeError:
raise SynapseError(400, "Content not JSON.")
def parse_json_dict_from_request(request):
try:
content = simplejson.loads(request.content.read())
if type(content) != dict:
raise SynapseError(400, "Content must be a JSON object.")
return content
except simplejson.JSONDecodeError:
raise SynapseError(400, "Content not JSON.")

View File

@ -0,0 +1,159 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.constants import LoginType
from synapse.api.errors import LoginError, SynapseError, Codes
from synapse.http.servlet import RestServlet
from synapse.util.async import run_on_reactor
from ._base import client_v2_pattern, parse_json_dict_from_request
import logging
logger = logging.getLogger(__name__)
class PasswordRestServlet(RestServlet):
PATTERN = client_v2_pattern("/account/password")
def __init__(self, hs):
super(PasswordRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_handlers().auth_handler
self.login_handler = hs.get_handlers().login_handler
@defer.inlineCallbacks
def on_POST(self, request):
yield run_on_reactor()
body = parse_json_dict_from_request(request)
authed, result, params = yield self.auth_handler.check_auth([
[LoginType.PASSWORD],
[LoginType.EMAIL_IDENTITY]
], body)
if not authed:
defer.returnValue((401, result))
user_id = None
if LoginType.PASSWORD in result:
# if using password, they should also be logged in
auth_user, client = yield self.auth.get_user_by_req(request)
if auth_user.to_string() != result[LoginType.PASSWORD]:
raise LoginError(400, "", Codes.UNKNOWN)
user_id = auth_user.to_string()
elif LoginType.EMAIL_IDENTITY in result:
threepid = result[LoginType.EMAIL_IDENTITY]
if 'medium' not in threepid or 'address' not in threepid:
raise SynapseError(500, "Malformed threepid")
# if using email, we must know about the email they're authing with!
threepid_user = yield self.hs.get_datastore().get_user_by_threepid(
threepid['medium'], threepid['address']
)
if not threepid_user:
raise SynapseError(404, "Email address not found", Codes.NOT_FOUND)
user_id = threepid_user
else:
logger.error("Auth succeeded but no known type!", result.keys())
raise SynapseError(500, "", Codes.UNKNOWN)
if 'new_password' not in params:
raise SynapseError(400, "", Codes.MISSING_PARAM)
new_password = params['new_password']
yield self.login_handler.set_password(
user_id, new_password, None
)
defer.returnValue((200, {}))
def on_OPTIONS(self, _):
return 200, {}
class ThreepidRestServlet(RestServlet):
PATTERN = client_v2_pattern("/account/3pid")
def __init__(self, hs):
super(ThreepidRestServlet, self).__init__()
self.hs = hs
self.login_handler = hs.get_handlers().login_handler
self.identity_handler = hs.get_handlers().identity_handler
self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request):
yield run_on_reactor()
auth_user, _ = yield self.auth.get_user_by_req(request)
threepids = yield self.hs.get_datastore().user_get_threepids(
auth_user.to_string()
)
defer.returnValue((200, {'threepids': threepids}))
@defer.inlineCallbacks
def on_POST(self, request):
yield run_on_reactor()
body = parse_json_dict_from_request(request)
if 'threePidCreds' not in body:
raise SynapseError(400, "Missing param", Codes.MISSING_PARAM)
threePidCreds = body['threePidCreds']
auth_user, client = yield self.auth.get_user_by_req(request)
threepid = yield self.identity_handler.threepid_from_creds(threePidCreds)
if not threepid:
raise SynapseError(
400, "Failed to auth 3pid", Codes.THREEPID_AUTH_FAILED
)
for reqd in ['medium', 'address', 'validated_at']:
if reqd not in threepid:
logger.warn("Couldn't add 3pid: invalid response from ID sevrer")
raise SynapseError(500, "Invalid response from ID Server")
yield self.login_handler.add_threepid(
auth_user.to_string(),
threepid['medium'],
threepid['address'],
threepid['validated_at'],
)
if 'bind' in body and body['bind']:
logger.debug(
"Binding emails %s to %s",
threepid, auth_user.to_string()
)
yield self.identity_handler.bind_threepid(
threePidCreds, auth_user.to_string()
)
defer.returnValue((200, {}))
def register_servlets(hs, http_server):
PasswordRestServlet(hs).register(http_server)
ThreepidRestServlet(hs).register(http_server)

View File

@ -0,0 +1,190 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError
from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
from synapse.http.servlet import RestServlet
from ._base import client_v2_pattern
import logging
logger = logging.getLogger(__name__)
RECAPTCHA_TEMPLATE = """
<html>
<head>
<title>Authentication</title>
<meta name='viewport' content='width=device-width, initial-scale=1,
user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
<script src="https://www.google.com/recaptcha/api.js"
async defer></script>
<script src="//code.jquery.com/jquery-1.11.2.min.js"></script>
<link rel="stylesheet" href="/_matrix/static/client/register/style.css">
<script>
function captchaDone() {
$('#registrationForm').submit();
}
</script>
</head>
<body>
<form id="registrationForm" method="post" action="%(myurl)s">
<div>
<p>
Hello! We need to prevent computer programs and other automated
things from creating accounts on this server.
</p>
<p>
Please verify that you're not a robot.
</p>
<input type="hidden" name="session" value="%(session)s" />
<div class="g-recaptcha"
data-sitekey="%(sitekey)s"
data-callback="captchaDone">
</div>
<noscript>
<input type="submit" value="All Done" />
</noscript>
</div>
</div>
</form>
</body>
</html>
"""
SUCCESS_TEMPLATE = """
<html>
<head>
<title>Success!</title>
<meta name='viewport' content='width=device-width, initial-scale=1,
user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
<link rel="stylesheet" href="/_matrix/static/client/register/style.css">
<script>
if (window.onAuthDone != undefined) {
window.onAuthDone();
}
</script>
</head>
<body>
<div>
<p>Thank you</p>
<p>You may now close this window and return to the application</p>
</div>
</body>
</html>
"""
class AuthRestServlet(RestServlet):
"""
Handles Client / Server API authentication in any situations where it
cannot be handled in the normal flow (with requests to the same endpoint).
Current use is for web fallback auth.
"""
PATTERN = client_v2_pattern("/auth/(?P<stagetype>[\w\.]*)/fallback/web")
def __init__(self, hs):
super(AuthRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_handlers().auth_handler
self.registration_handler = hs.get_handlers().registration_handler
@defer.inlineCallbacks
def on_GET(self, request, stagetype):
yield
if stagetype == LoginType.RECAPTCHA:
if ('session' not in request.args or
len(request.args['session']) == 0):
raise SynapseError(400, "No session supplied")
session = request.args["session"][0]
html = RECAPTCHA_TEMPLATE % {
'session': session,
'myurl': "%s/auth/%s/fallback/web" % (
CLIENT_V2_ALPHA_PREFIX, LoginType.RECAPTCHA
),
'sitekey': self.hs.config.recaptcha_public_key,
}
html_bytes = html.encode("utf8")
request.setResponseCode(200)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
request.setHeader(b"Server", self.hs.version_string)
request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
request.write(html_bytes)
request.finish()
defer.returnValue(None)
else:
raise SynapseError(404, "Unknown auth stage type")
@defer.inlineCallbacks
def on_POST(self, request, stagetype):
yield
if stagetype == "m.login.recaptcha":
if ('g-recaptcha-response' not in request.args or
len(request.args['g-recaptcha-response'])) == 0:
raise SynapseError(400, "No captcha response supplied")
if ('session' not in request.args or
len(request.args['session'])) == 0:
raise SynapseError(400, "No session supplied")
session = request.args['session'][0]
authdict = {
'response': request.args['g-recaptcha-response'][0],
'session': session,
}
success = yield self.auth_handler.add_oob_auth(
LoginType.RECAPTCHA,
authdict,
self.hs.get_ip_from_request(request)
)
if success:
html = SUCCESS_TEMPLATE
else:
html = RECAPTCHA_TEMPLATE % {
'session': session,
'myurl': "%s/auth/%s/fallback/web" % (
CLIENT_V2_ALPHA_PREFIX, LoginType.RECAPTCHA
),
'sitekey': self.hs.config.recaptcha_public_key,
}
html_bytes = html.encode("utf8")
request.setResponseCode(200)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
request.setHeader(b"Server", self.hs.version_string)
request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
request.write(html_bytes)
request.finish()
defer.returnValue(None)
else:
raise SynapseError(404, "Unknown auth stage type")
def on_OPTIONS(self, _):
return 200, {}
def register_servlets(hs, http_server):
AuthRestServlet(hs).register(http_server)

View File

@ -0,0 +1,183 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError, Codes
from synapse.http.servlet import RestServlet
from ._base import client_v2_pattern, parse_request_allow_empty
import logging
import hmac
from hashlib import sha1
from synapse.util.async import run_on_reactor
# We ought to be using hmac.compare_digest() but on older pythons it doesn't
# exist. It's a _really minor_ security flaw to use plain string comparison
# because the timing attack is so obscured by all the other code here it's
# unlikely to make much difference
if hasattr(hmac, "compare_digest"):
compare_digest = hmac.compare_digest
else:
compare_digest = lambda a, b: a == b
logger = logging.getLogger(__name__)
class RegisterRestServlet(RestServlet):
PATTERN = client_v2_pattern("/register")
def __init__(self, hs):
super(RegisterRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_handlers().auth_handler
self.registration_handler = hs.get_handlers().registration_handler
self.identity_handler = hs.get_handlers().identity_handler
self.login_handler = hs.get_handlers().login_handler
@defer.inlineCallbacks
def on_POST(self, request):
yield run_on_reactor()
body = parse_request_allow_empty(request)
if 'password' not in body:
raise SynapseError(400, "", Codes.MISSING_PARAM)
if 'username' in body:
desired_username = body['username']
yield self.registration_handler.check_username(desired_username)
is_using_shared_secret = False
is_application_server = False
service = None
if 'access_token' in request.args:
service = yield self.auth.get_appservice_by_req(request)
if self.hs.config.enable_registration_captcha:
flows = [
[LoginType.RECAPTCHA],
[LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA]
]
else:
flows = [
[LoginType.DUMMY],
[LoginType.EMAIL_IDENTITY]
]
if service:
is_application_server = True
elif 'mac' in body:
# Check registration-specific shared secret auth
if 'username' not in body:
raise SynapseError(400, "", Codes.MISSING_PARAM)
self._check_shared_secret_auth(
body['username'], body['mac']
)
is_using_shared_secret = True
else:
authed, result, params = yield self.auth_handler.check_auth(
flows, body, self.hs.get_ip_from_request(request)
)
if not authed:
defer.returnValue((401, result))
can_register = (
not self.hs.config.disable_registration
or is_application_server
or is_using_shared_secret
)
if not can_register:
raise SynapseError(403, "Registration has been disabled")
if 'password' not in params:
raise SynapseError(400, "", Codes.MISSING_PARAM)
desired_username = params['username'] if 'username' in params else None
new_password = params['password']
(user_id, token) = yield self.registration_handler.register(
localpart=desired_username,
password=new_password
)
if LoginType.EMAIL_IDENTITY in result:
threepid = result[LoginType.EMAIL_IDENTITY]
for reqd in ['medium', 'address', 'validated_at']:
if reqd not in threepid:
logger.info("Can't add incomplete 3pid")
else:
yield self.login_handler.add_threepid(
user_id,
threepid['medium'],
threepid['address'],
threepid['validated_at'],
)
if 'bind_email' in params and params['bind_email']:
logger.info("bind_email specified: binding")
emailThreepid = result[LoginType.EMAIL_IDENTITY]
threepid_creds = emailThreepid['threepid_creds']
logger.debug("Binding emails %s to %s" % (
emailThreepid, user_id
))
yield self.identity_handler.bind_threepid(threepid_creds, user_id)
else:
logger.info("bind_email not specified: not binding email")
result = {
"user_id": user_id,
"access_token": token,
"home_server": self.hs.hostname,
}
defer.returnValue((200, result))
def on_OPTIONS(self, _):
return 200, {}
def _check_shared_secret_auth(self, username, mac):
if not self.hs.config.registration_shared_secret:
raise SynapseError(400, "Shared secret registration is not enabled")
user = username.encode("utf-8")
# str() because otherwise hmac complains that 'unicode' does not
# have the buffer interface
got_mac = str(mac)
want_mac = hmac.new(
key=self.hs.config.registration_shared_secret,
msg=user,
digestmod=sha1,
).hexdigest()
if compare_digest(want_mac, got_mac):
return True
else:
raise SynapseError(
403, "HMAC incorrect",
)
def register_servlets(hs, http_server):
RegisterRestServlet(hs).register(http_server)

View File

@ -65,6 +65,7 @@ class BaseHomeServer(object):
'replication_layer', 'replication_layer',
'datastore', 'datastore',
'handlers', 'handlers',
'v1auth',
'auth', 'auth',
'rest_servlet_factory', 'rest_servlet_factory',
'state_handler', 'state_handler',
@ -181,6 +182,15 @@ class HomeServer(BaseHomeServer):
def build_auth(self): def build_auth(self):
return Auth(self) return Auth(self)
def build_v1auth(self):
orf = Auth(self)
# Matrix spec makes no reference to what HTTP status code is returned,
# but the V1 API uses 403 where it means 401, and the webclient
# relies on this behaviour, so V1 gets its own copy of the auth
# with backwards compat behaviour.
orf.TOKEN_NOT_FOUND_HTTP_STATUS = 403
return orf
def build_state_handler(self): def build_state_handler(self):
return StateHandler(self) return StateHandler(self)

View File

@ -487,3 +487,15 @@ def prepare_sqlite3_database(db_conn):
" VALUES (?,?)", " VALUES (?,?)",
(row[0], False) (row[0], False)
) )
def are_all_users_on_domain(txn, database_engine, domain):
sql = database_engine.convert_param_style(
"SELECT COUNT(*) FROM users WHERE name NOT LIKE ?"
)
pat = "%:" + domain
txn.execute(sql, (pat,))
num_not_matching = txn.fetchall()[0][0]
if num_not_matching == 0:
return True
return False

View File

@ -21,82 +21,38 @@ from synapse.api.errors import StoreError
from syutil.jsonutil import encode_canonical_json from syutil.jsonutil import encode_canonical_json
import logging import logging
import simplejson as json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class PusherStore(SQLBaseStore): class PusherStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_pushers_by_app_id_and_pushkey(self, app_id_and_pushkey): def get_pushers_by_app_id_and_pushkey(self, app_id, pushkey):
sql = ( sql = (
"SELECT id, user_name, kind, profile_tag, app_id," "SELECT * FROM pushers "
"app_display_name, device_display_name, pushkey, ts, data, "
"last_token, last_success, failing_since "
"FROM pushers "
"WHERE app_id = ? AND pushkey = ?" "WHERE app_id = ? AND pushkey = ?"
) )
rows = yield self._execute( rows = yield self._execute_and_decode(
"get_pushers_by_app_id_and_pushkey", None, sql, "get_pushers_by_app_id_and_pushkey",
app_id_and_pushkey[0], app_id_and_pushkey[1] sql,
app_id, pushkey
) )
ret = [ defer.returnValue(rows)
{
"id": r[0],
"user_name": r[1],
"kind": r[2],
"profile_tag": r[3],
"app_id": r[4],
"app_display_name": r[5],
"device_display_name": r[6],
"pushkey": r[7],
"pushkey_ts": r[8],
"data": json.loads(r[9]),
"last_token": r[10],
"last_success": r[11],
"failing_since": r[12]
}
for r in rows
]
defer.returnValue(ret[0])
@defer.inlineCallbacks @defer.inlineCallbacks
def get_all_pushers(self): def get_all_pushers(self):
sql = ( sql = (
"SELECT id, user_name, kind, profile_tag, app_id," "SELECT * FROM pushers"
"app_display_name, device_display_name, pushkey, ts, data, "
"last_token, last_success, failing_since "
"FROM pushers"
) )
rows = yield self._execute("get_all_pushers", None, sql) rows = yield self._execute_and_decode("get_all_pushers", sql)
ret = [ defer.returnValue(rows)
{
"id": r[0],
"user_name": r[1],
"kind": r[2],
"profile_tag": r[3],
"app_id": r[4],
"app_display_name": r[5],
"device_display_name": r[6],
"pushkey": r[7],
"pushkey_ts": r[8],
"data": json.loads(r[9]),
"last_token": r[10],
"last_success": r[11],
"failing_since": r[12]
}
for r in rows
]
defer.returnValue(ret)
@defer.inlineCallbacks @defer.inlineCallbacks
def add_pusher(self, user_name, profile_tag, kind, app_id, def add_pusher(self, user_name, access_token, profile_tag, kind, app_id,
app_display_name, device_display_name, app_display_name, device_display_name,
pushkey, pushkey_ts, lang, data): pushkey, pushkey_ts, lang, data):
try: try:
@ -106,9 +62,10 @@ class PusherStore(SQLBaseStore):
dict( dict(
app_id=app_id, app_id=app_id,
pushkey=pushkey, pushkey=pushkey,
user_name=user_name,
), ),
dict( dict(
user_name=user_name, access_token=access_token,
kind=kind, kind=kind,
profile_tag=profile_tag, profile_tag=profile_tag,
app_display_name=app_display_name, app_display_name=app_display_name,
@ -127,37 +84,38 @@ class PusherStore(SQLBaseStore):
raise StoreError(500, "Problem creating pusher.") raise StoreError(500, "Problem creating pusher.")
@defer.inlineCallbacks @defer.inlineCallbacks
def delete_pusher_by_app_id_pushkey(self, app_id, pushkey): def delete_pusher_by_app_id_pushkey_user_name(self, app_id, pushkey, user_name):
yield self._simple_delete_one( yield self._simple_delete_one(
PushersTable.table_name, PushersTable.table_name,
{"app_id": app_id, "pushkey": pushkey}, {"app_id": app_id, "pushkey": pushkey, 'user_name': user_name},
desc="delete_pusher_by_app_id_pushkey", desc="delete_pusher_by_app_id_pushkey_user_name",
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def update_pusher_last_token(self, app_id, pushkey, last_token): def update_pusher_last_token(self, app_id, pushkey, user_name, last_token):
yield self._simple_update_one( yield self._simple_update_one(
PushersTable.table_name, PushersTable.table_name,
{'app_id': app_id, 'pushkey': pushkey}, {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_name},
{'last_token': last_token}, {'last_token': last_token},
desc="update_pusher_last_token", desc="update_pusher_last_token",
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def update_pusher_last_token_and_success(self, app_id, pushkey, def update_pusher_last_token_and_success(self, app_id, pushkey, user_name,
last_token, last_success): last_token, last_success):
yield self._simple_update_one( yield self._simple_update_one(
PushersTable.table_name, PushersTable.table_name,
{'app_id': app_id, 'pushkey': pushkey}, {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_name},
{'last_token': last_token, 'last_success': last_success}, {'last_token': last_token, 'last_success': last_success},
desc="update_pusher_last_token_and_success", desc="update_pusher_last_token_and_success",
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def update_pusher_failing_since(self, app_id, pushkey, failing_since): def update_pusher_failing_since(self, app_id, pushkey, user_name,
failing_since):
yield self._simple_update_one( yield self._simple_update_one(
PushersTable.table_name, PushersTable.table_name,
{'app_id': app_id, 'pushkey': pushkey}, {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_name},
{'failing_since': failing_since}, {'failing_since': failing_since},
desc="update_pusher_failing_since", desc="update_pusher_failing_since",
) )

View File

@ -87,9 +87,8 @@ class RegistrationStore(SQLBaseStore):
(next_id, user_id, token,) (next_id, user_id, token,)
) )
@defer.inlineCallbacks
def get_user_by_id(self, user_id): def get_user_by_id(self, user_id):
user_info = yield self._simple_select_one( return self._simple_select_one(
table="users", table="users",
keyvalues={ keyvalues={
"name": user_id, "name": user_id,
@ -98,13 +97,42 @@ class RegistrationStore(SQLBaseStore):
allow_none=True, allow_none=True,
) )
defer.returnValue(user_info) @defer.inlineCallbacks
def user_set_password_hash(self, user_id, password_hash):
"""
NB. This does *not* evict any cache because the one use for this
removes most of the entries subsequently anyway so it would be
pointless. Use flush_user separately.
"""
yield self._simple_update_one('users', {
'name': user_id
}, {
'password_hash': password_hash
})
@defer.inlineCallbacks
def user_delete_access_tokens_apart_from(self, user_id, token_id):
rows = yield self.get_user_by_id(user_id)
if len(rows) == 0:
raise Exception("No such user!")
yield self._execute(
"delete_access_tokens_apart_from", None,
"DELETE FROM access_tokens WHERE user_id = ? AND id != ?",
rows[0]['id'], token_id
)
@defer.inlineCallbacks
def flush_user(self, user_id):
rows = yield self._execute(
'flush_user', None,
"SELECT token FROM access_tokens WHERE user_id = ?",
user_id
)
for r in rows:
self.get_user_by_token.invalidate(r)
@cached() @cached()
# TODO(paul): Currently there's no code to invalidate this cache. That
# means if/when we ever add internal ways to invalidate access tokens or
# change whether a user is a server admin, those will need to invoke
# store.get_user_by_token.invalidate(token)
def get_user_by_token(self, token): def get_user_by_token(self, token):
"""Get a user from the given access token. """Get a user from the given access token.
@ -148,4 +176,40 @@ class RegistrationStore(SQLBaseStore):
if rows: if rows:
return rows[0] return rows[0]
raise StoreError(404, "Token not found.") return None
@defer.inlineCallbacks
def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
yield self._simple_upsert("user_threepids", {
"user": user_id,
"medium": medium,
"address": address,
}, {
"validated_at": validated_at,
"added_at": added_at,
})
@defer.inlineCallbacks
def user_get_threepids(self, user_id):
ret = yield self._simple_select_list(
"user_threepids", {
"user": user_id
},
['medium', 'address', 'validated_at', 'added_at'],
'user_get_threepids'
)
defer.returnValue(ret)
@defer.inlineCallbacks
def get_user_by_threepid(self, medium, address):
ret = yield self._simple_select_one(
"user_threepids",
{
"medium": medium,
"address": address
},
['user'], True, 'get_user_by_threepid'
)
if ret:
defer.returnValue(ret['user'])
defer.returnValue(None)

View File

@ -0,0 +1,25 @@
-- Drop, copy & recreate pushers table to change unique key
-- Also add access_token column at the same time
CREATE TABLE IF NOT EXISTS pushers2 (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_name TEXT NOT NULL,
access_token INTEGER DEFAULT NULL,
profile_tag varchar(32) NOT NULL,
kind varchar(8) NOT NULL,
app_id varchar(64) NOT NULL,
app_display_name varchar(64) NOT NULL,
device_display_name varchar(128) NOT NULL,
pushkey blob NOT NULL,
ts BIGINT NOT NULL,
lang varchar(8),
data blob,
last_token TEXT,
last_success BIGINT,
failing_since BIGINT,
FOREIGN KEY(user_name) REFERENCES users(name),
UNIQUE (app_id, pushkey, user_name)
);
INSERT INTO pushers2 (id, user_name, profile_tag, kind, app_id, app_display_name, device_display_name, pushkey, ts, lang, data, last_token, last_success, failing_since)
SELECT id, user_name, profile_tag, kind, app_id, app_display_name, device_display_name, pushkey, ts, lang, data, last_token, last_success, failing_since FROM pushers;
DROP TABLE pushers;
ALTER TABLE pushers2 RENAME TO pushers;

View File

@ -75,7 +75,7 @@ class PresenceStateTestCase(unittest.TestCase):
"token_id": 1, "token_id": 1,
} }
hs.get_auth().get_user_by_token = _get_user_by_token hs.get_v1auth().get_user_by_token = _get_user_by_token
room_member_handler = hs.handlers.room_member_handler = Mock( room_member_handler = hs.handlers.room_member_handler = Mock(
spec=[ spec=[
@ -170,7 +170,7 @@ class PresenceListTestCase(unittest.TestCase):
] ]
) )
hs.get_auth().get_user_by_token = _get_user_by_token hs.get_v1auth().get_user_by_token = _get_user_by_token
presence.register_servlets(hs, self.mock_resource) presence.register_servlets(hs, self.mock_resource)
@ -277,7 +277,7 @@ class PresenceEventStreamTestCase(unittest.TestCase):
def _get_user_by_req(req=None): def _get_user_by_req(req=None):
return (UserID.from_string(myid), "") return (UserID.from_string(myid), "")
hs.get_auth().get_user_by_req = _get_user_by_req hs.get_v1auth().get_user_by_req = _get_user_by_req
presence.register_servlets(hs, self.mock_resource) presence.register_servlets(hs, self.mock_resource)
events.register_servlets(hs, self.mock_resource) events.register_servlets(hs, self.mock_resource)

View File

@ -55,7 +55,7 @@ class ProfileTestCase(unittest.TestCase):
def _get_user_by_req(request=None): def _get_user_by_req(request=None):
return (UserID.from_string(myid), "") return (UserID.from_string(myid), "")
hs.get_auth().get_user_by_req = _get_user_by_req hs.get_v1auth().get_user_by_req = _get_user_by_req
hs.get_handlers().profile_handler = self.mock_handler hs.get_handlers().profile_handler = self.mock_handler

View File

@ -61,7 +61,7 @@ class RoomPermissionsTestCase(RestTestCase):
"device_id": None, "device_id": None,
"token_id": 1, "token_id": 1,
} }
hs.get_auth().get_user_by_token = _get_user_by_token hs.get_v1auth().get_user_by_token = _get_user_by_token
def _insert_client_ip(*args, **kwargs): def _insert_client_ip(*args, **kwargs):
return defer.succeed(None) return defer.succeed(None)
@ -71,7 +71,7 @@ class RoomPermissionsTestCase(RestTestCase):
synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource) synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource)
self.auth = hs.get_auth() self.auth = hs.get_v1auth()
# create some rooms under the name rmcreator_id # create some rooms under the name rmcreator_id
self.uncreated_rmid = "!aa:test" self.uncreated_rmid = "!aa:test"
@ -448,7 +448,7 @@ class RoomsMemberListTestCase(RestTestCase):
"device_id": None, "device_id": None,
"token_id": 1, "token_id": 1,
} }
hs.get_auth().get_user_by_token = _get_user_by_token hs.get_v1auth().get_user_by_token = _get_user_by_token
def _insert_client_ip(*args, **kwargs): def _insert_client_ip(*args, **kwargs):
return defer.succeed(None) return defer.succeed(None)
@ -528,7 +528,7 @@ class RoomsCreateTestCase(RestTestCase):
"device_id": None, "device_id": None,
"token_id": 1, "token_id": 1,
} }
hs.get_auth().get_user_by_token = _get_user_by_token hs.get_v1auth().get_user_by_token = _get_user_by_token
def _insert_client_ip(*args, **kwargs): def _insert_client_ip(*args, **kwargs):
return defer.succeed(None) return defer.succeed(None)
@ -622,7 +622,7 @@ class RoomTopicTestCase(RestTestCase):
"token_id": 1, "token_id": 1,
} }
hs.get_auth().get_user_by_token = _get_user_by_token hs.get_v1auth().get_user_by_token = _get_user_by_token
def _insert_client_ip(*args, **kwargs): def _insert_client_ip(*args, **kwargs):
return defer.succeed(None) return defer.succeed(None)
@ -728,7 +728,7 @@ class RoomMemberStateTestCase(RestTestCase):
"device_id": None, "device_id": None,
"token_id": 1, "token_id": 1,
} }
hs.get_auth().get_user_by_token = _get_user_by_token hs.get_v1auth().get_user_by_token = _get_user_by_token
def _insert_client_ip(*args, **kwargs): def _insert_client_ip(*args, **kwargs):
return defer.succeed(None) return defer.succeed(None)
@ -855,7 +855,7 @@ class RoomMessagesTestCase(RestTestCase):
"device_id": None, "device_id": None,
"token_id": 1, "token_id": 1,
} }
hs.get_auth().get_user_by_token = _get_user_by_token hs.get_v1auth().get_user_by_token = _get_user_by_token
def _insert_client_ip(*args, **kwargs): def _insert_client_ip(*args, **kwargs):
return defer.succeed(None) return defer.succeed(None)
@ -952,7 +952,7 @@ class RoomInitialSyncTestCase(RestTestCase):
"device_id": None, "device_id": None,
"token_id": 1, "token_id": 1,
} }
hs.get_auth().get_user_by_token = _get_user_by_token hs.get_v1auth().get_user_by_token = _get_user_by_token
def _insert_client_ip(*args, **kwargs): def _insert_client_ip(*args, **kwargs):
return defer.succeed(None) return defer.succeed(None)

View File

@ -69,7 +69,7 @@ class RoomTypingTestCase(RestTestCase):
"token_id": 1, "token_id": 1,
} }
hs.get_auth().get_user_by_token = _get_user_by_token hs.get_v1auth().get_user_by_token = _get_user_by_token
def _insert_client_ip(*args, **kwargs): def _insert_client_ip(*args, **kwargs):
return defer.succeed(None) return defer.succeed(None)