Add management endpoints for account validity

This commit is contained in:
Brendan Abolivier 2019-04-16 20:13:59 +01:00 committed by Erik Johnston
parent 20f0617e87
commit eaf41a943b
8 changed files with 246 additions and 26 deletions

1
changelog.d/5073.feature Normal file
View File

@ -0,0 +1 @@
Add time-based account expiration.

View File

@ -0,0 +1,42 @@
Account validity API
====================
This API allows a server administrator to manage the validity of an account. To
use it, you must enable the account validity feature (under
``account_validity``) in Synapse's configuration.
Renew account
-------------
This API extends the validity of an account by as much time as configured in the
``period`` parameter from the ``account_validity`` configuration.
The API is::
POST /_matrix/client/unstable/account_validity/send_mail
with the following body:
.. code:: json
{
"user_id": "<user ID for the account to renew>",
"expiration_ts": 0,
"enable_renewal_emails": true
}
``expiration_ts`` is an optional parameter and overrides the expiration date,
which otherwise defaults to now + validity period.
``enable_renewal_emails`` is also an optional parameter and enables/disables
sending renewal emails to the user. Defaults to true.
The API returns with the new expiration date for this account, as a timestamp in
milliseconds since epoch:
.. code:: json
{
"expiration_ts": 0
}

View File

@ -232,7 +232,7 @@ class Auth(object):
if self._account_validity.enabled: if self._account_validity.enabled:
user_id = user.to_string() user_id = user.to_string()
expiration_ts = yield self.store.get_expiration_ts_for_user(user_id) expiration_ts = yield self.store.get_expiration_ts_for_user(user_id)
if expiration_ts and self.clock.time_msec() >= expiration_ts: if expiration_ts is not None and self.clock.time_msec() >= expiration_ts:
raise AuthError( raise AuthError(
403, 403,
"User account has expired", "User account has expired",

View File

@ -90,6 +90,11 @@ class AccountValidityHandler(object):
expiration_ts=user["expiration_ts_ms"], expiration_ts=user["expiration_ts_ms"],
) )
@defer.inlineCallbacks
def send_renewal_email_to_user(self, user_id):
expiration_ts = yield self.store.get_expiration_ts_for_user(user_id)
yield self._send_renewal_email(user_id, expiration_ts)
@defer.inlineCallbacks @defer.inlineCallbacks
def _send_renewal_email(self, user_id, expiration_ts): def _send_renewal_email(self, user_id, expiration_ts):
"""Sends out a renewal email to every email address attached to the given user """Sends out a renewal email to every email address attached to the given user
@ -217,12 +222,32 @@ class AccountValidityHandler(object):
renewal_token (str): Token sent with the renewal request. renewal_token (str): Token sent with the renewal request.
""" """
user_id = yield self.store.get_user_from_renewal_token(renewal_token) user_id = yield self.store.get_user_from_renewal_token(renewal_token)
logger.debug("Renewing an account for user %s", user_id) logger.debug("Renewing an account for user %s", user_id)
yield self.renew_account_for_user(user_id)
new_expiration_date = self.clock.time_msec() + self._account_validity.period @defer.inlineCallbacks
def renew_account_for_user(self, user_id, expiration_ts=None, email_sent=False):
"""Renews the account attached to a given user by pushing back the
expiration date by the current validity period in the server's
configuration.
yield self.store.renew_account_for_user( Args:
renewal_token (str): Token sent with the renewal request.
expiration_ts (int): New expiration date. Defaults to now + validity period.
email_sent (bool): Whether an email has been sent for this validity period.
Defaults to False.
Returns:
defer.Deferred[int]: New expiration date for this account, as a timestamp
in milliseconds since epoch.
"""
if expiration_ts is None:
expiration_ts = self.clock.time_msec() + self._account_validity.period
yield self.store.set_account_validity_for_user(
user_id=user_id, user_id=user_id,
new_expiration_ts=new_expiration_date, expiration_ts=expiration_ts,
email_sent=email_sent,
) )
defer.returnValue(expiration_ts)

View File

@ -786,6 +786,44 @@ class SearchUsersRestServlet(ClientV1RestServlet):
defer.returnValue((200, ret)) defer.returnValue((200, ret))
class AccountValidityRenewServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/admin/account_validity/validity$")
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super(AccountValidityRenewServlet, self).__init__(hs)
self.hs = hs
self.account_activity_handler = hs.get_account_validity_handler()
self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_POST(self, request):
requester = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(requester.user)
if not is_admin:
raise AuthError(403, "You are not a server admin")
body = parse_json_object_from_request(request)
if "user_id" not in body:
raise SynapseError(400, "Missing property 'user_id' in the request body")
expiration_ts = yield self.account_activity_handler.renew_account_for_user(
body["user_id"], body.get("expiration_ts"),
not body.get("enable_renewal_emails", True),
)
res = {
"expiration_ts": expiration_ts,
}
defer.returnValue((200, res))
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
WhoisRestServlet(hs).register(http_server) WhoisRestServlet(hs).register(http_server)
PurgeMediaCacheRestServlet(hs).register(http_server) PurgeMediaCacheRestServlet(hs).register(http_server)
@ -801,3 +839,4 @@ def register_servlets(hs, http_server):
ListMediaInRoom(hs).register(http_server) ListMediaInRoom(hs).register(http_server)
UserRegisterServlet(hs).register(http_server) UserRegisterServlet(hs).register(http_server)
VersionServlet(hs).register(http_server) VersionServlet(hs).register(http_server)
AccountValidityRenewServlet(hs).register(http_server)

View File

@ -17,7 +17,7 @@ import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import AuthError, SynapseError
from synapse.http.server import finish_request from synapse.http.server import finish_request
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
@ -39,6 +39,7 @@ class AccountValidityRenewServlet(RestServlet):
self.hs = hs self.hs = hs
self.account_activity_handler = hs.get_account_validity_handler() self.account_activity_handler = hs.get_account_validity_handler()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
@ -58,5 +59,33 @@ class AccountValidityRenewServlet(RestServlet):
defer.returnValue(None) defer.returnValue(None)
class AccountValiditySendMailServlet(RestServlet):
PATTERNS = client_v2_patterns("/account_validity/send_mail$")
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super(AccountValiditySendMailServlet, self).__init__()
self.hs = hs
self.account_activity_handler = hs.get_account_validity_handler()
self.auth = hs.get_auth()
self.account_validity = self.hs.config.account_validity
@defer.inlineCallbacks
def on_POST(self, request):
if not self.account_validity.renew_by_email_enabled:
raise AuthError(403, "Account renewal via email is disabled on this server.")
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
yield self.account_activity_handler.send_renewal_email_to_user(user_id)
defer.returnValue((200, {}))
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
AccountValidityRenewServlet(hs).register(http_server) AccountValidityRenewServlet(hs).register(http_server)
AccountValiditySendMailServlet(hs).register(http_server)

View File

@ -108,25 +108,30 @@ class RegistrationWorkerStore(SQLBaseStore):
defer.returnValue(res) defer.returnValue(res)
@defer.inlineCallbacks @defer.inlineCallbacks
def renew_account_for_user(self, user_id, new_expiration_ts): def set_account_validity_for_user(self, user_id, expiration_ts, email_sent,
"""Updates the account validity table with a new timestamp for a given renewal_token=None):
user, removes the existing renewal token from this user, and unsets the """Updates the account validity properties of the given account, with the
flag indicating that an email has been sent for renewing this account. given values.
Args: Args:
user_id (str): ID of the user whose account validity to renew. user_id (str): ID of the account to update properties for.
new_expiration_ts: New expiration date, as a timestamp in milliseconds expiration_ts (int): New expiration date, as a timestamp in milliseconds
since epoch. since epoch.
email_sent (bool): True means a renewal email has been sent for this
account and there's no need to send another one for the current validity
period.
renewal_token (str): Renewal token the user can use to extend the validity
of their account. Defaults to no token.
""" """
def renew_account_for_user_txn(txn): def set_account_validity_for_user_txn(txn):
self._simple_update_txn( self._simple_update_txn(
txn=txn, txn=txn,
table="account_validity", table="account_validity",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
updatevalues={ updatevalues={
"expiration_ts_ms": new_expiration_ts, "expiration_ts_ms": expiration_ts,
"email_sent": False, "email_sent": email_sent,
"renewal_token": None, "renewal_token": renewal_token,
}, },
) )
self._invalidate_cache_and_stream( self._invalidate_cache_and_stream(
@ -134,8 +139,8 @@ class RegistrationWorkerStore(SQLBaseStore):
) )
yield self.runInteraction( yield self.runInteraction(
"renew_account_for_user", "set_account_validity_for_user",
renew_account_for_user_txn, set_account_validity_for_user_txn,
) )
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -201,6 +201,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
admin.register_servlets, admin.register_servlets,
login.register_servlets, login.register_servlets,
sync.register_servlets, sync.register_servlets,
account_validity.register_servlets,
] ]
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor, clock):
@ -238,6 +239,68 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result, channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result,
) )
def test_manual_renewal(self):
user_id = self.register_user("kermit", "monkey")
tok = self.login("kermit", "monkey")
self.reactor.advance(datetime.timedelta(weeks=1).total_seconds())
# If we register the admin user at the beginning of the test, it will
# expire at the same time as the normal user and the renewal request
# will be denied.
self.register_user("admin", "adminpassword", admin=True)
admin_tok = self.login("admin", "adminpassword")
url = "/_matrix/client/unstable/admin/account_validity/validity"
params = {
"user_id": user_id,
}
request_data = json.dumps(params)
request, channel = self.make_request(
b"POST", url, request_data, access_token=admin_tok,
)
self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
# The specific endpoint doesn't matter, all we need is an authenticated
# endpoint.
request, channel = self.make_request(
b"GET", "/sync", access_token=tok,
)
self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
def test_manual_expire(self):
user_id = self.register_user("kermit", "monkey")
tok = self.login("kermit", "monkey")
self.register_user("admin", "adminpassword", admin=True)
admin_tok = self.login("admin", "adminpassword")
url = "/_matrix/client/unstable/admin/account_validity/validity"
params = {
"user_id": user_id,
"expiration_ts": 0,
"enable_renewal_emails": False,
}
request_data = json.dumps(params)
request, channel = self.make_request(
b"POST", url, request_data, access_token=admin_tok,
)
self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
# The specific endpoint doesn't matter, all we need is an authenticated
# endpoint.
request, channel = self.make_request(
b"GET", "/sync", access_token=tok,
)
self.render(request)
self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEquals(
channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result,
)
class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
@ -287,6 +350,8 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
return self.hs return self.hs
def test_renewal_email(self): def test_renewal_email(self):
self.email_attempts = []
user_id = self.register_user("kermit", "monkey") user_id = self.register_user("kermit", "monkey")
tok = self.login("kermit", "monkey") tok = self.login("kermit", "monkey")
# We need to manually add an email address otherwise the handler will do # We need to manually add an email address otherwise the handler will do
@ -297,14 +362,6 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
validated_at=now, added_at=now, validated_at=now, added_at=now,
)) ))
# The specific endpoint doesn't matter, all we need is an authenticated
# endpoint.
request, channel = self.make_request(
b"GET", "/sync", access_token=tok,
)
self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
# Move 6 days forward. This should trigger a renewal email to be sent. # Move 6 days forward. This should trigger a renewal email to be sent.
self.reactor.advance(datetime.timedelta(days=6).total_seconds()) self.reactor.advance(datetime.timedelta(days=6).total_seconds())
self.assertEqual(len(self.email_attempts), 1) self.assertEqual(len(self.email_attempts), 1)
@ -326,3 +383,25 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
) )
self.render(request) self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result) self.assertEquals(channel.result["code"], b"200", channel.result)
def test_manual_email_send(self):
self.email_attempts = []
user_id = self.register_user("kermit", "monkey")
tok = self.login("kermit", "monkey")
# We need to manually add an email address otherwise the handler will do
# nothing.
now = self.hs.clock.time_msec()
self.get_success(self.store.user_add_threepid(
user_id=user_id, medium="email", address="kermit@example.com",
validated_at=now, added_at=now,
))
request, channel = self.make_request(
b"POST", "/_matrix/client/unstable/account_validity/send_mail",
access_token=tok,
)
self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
self.assertEqual(len(self.email_attempts), 1)