Merge pull request #5047 from matrix-org/babolivier/account_expiration

Send out emails with links to extend an account's validity period
This commit is contained in:
Brendan Abolivier 2019-04-17 14:57:39 +01:00 committed by GitHub
commit 91934025b9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 699 additions and 43 deletions

1
changelog.d/5047.feature Normal file
View File

@ -0,0 +1 @@
Add time-based account expiration.

View File

@ -646,11 +646,31 @@ uploads_path: "DATADIR/uploads"
# #
#enable_registration: false #enable_registration: false
# Optional account validity parameter. This allows for, e.g., accounts to # Optional account validity configuration. This allows for accounts to be denied
# be denied any request after a given period. # any request after a given period.
#
# ``enabled`` defines whether the account validity feature is enabled. Defaults
# to False.
#
# ``period`` allows setting the period after which an account is valid
# after its registration. When renewing the account, its validity period
# will be extended by this amount of time. This parameter is required when using
# the account validity feature.
#
# ``renew_at`` is the amount of time before an account's expiry date at which
# Synapse will send an email to the account's email address with a renewal link.
# This needs the ``email`` and ``public_baseurl`` configuration sections to be
# filled.
#
# ``renew_email_subject`` is the subject of the email sent out with the renewal
# link. ``%(app)s`` can be used as a placeholder for the ``app_name`` parameter
# from the ``email`` section.
# #
#account_validity: #account_validity:
# enabled: True
# period: 6w # period: 6w
# renew_at: 1w
# renew_email_subject: "Renew your %(app)s account"
# The user must provide all of the below types of 3PID when registering. # The user must provide all of the below types of 3PID when registering.
# #
@ -897,7 +917,7 @@ password_config:
# Enable sending emails for notification events # Enable sending emails for notification events or expiry notices
# Defining a custom URL for Riot is only needed if email notifications # Defining a custom URL for Riot is only needed if email notifications
# should contain links to a self-hosted installation of Riot; when set # should contain links to a self-hosted installation of Riot; when set
# the "app_name" setting is ignored. # the "app_name" setting is ignored.
@ -919,6 +939,9 @@ password_config:
# #template_dir: res/templates # #template_dir: res/templates
# notif_template_html: notif_mail.html # notif_template_html: notif_mail.html
# notif_template_text: notif_mail.txt # notif_template_text: notif_mail.txt
# # Templates for account expiry notices.
# expiry_template_html: notice_expiry.html
# expiry_template_text: notice_expiry.txt
# notif_for_new_users: True # notif_for_new_users: True
# riot_base_url: "http://localhost/riot" # riot_base_url: "http://localhost/riot"

View File

@ -230,8 +230,9 @@ class Auth(object):
# Deny the request if the user account has expired. # Deny the request if the user account has expired.
if self._account_validity.enabled: if self._account_validity.enabled:
expiration_ts = yield self.store.get_expiration_ts_for_user(user) user_id = user.to_string()
if self.clock.time_msec() >= expiration_ts: expiration_ts = yield self.store.get_expiration_ts_for_user(user_id)
if expiration_ts and self.clock.time_msec() >= expiration_ts:
raise AuthError( raise AuthError(
403, 403,
"User account has expired", "User account has expired",

View File

@ -71,6 +71,8 @@ class EmailConfig(Config):
self.email_notif_from = email_config["notif_from"] self.email_notif_from = email_config["notif_from"]
self.email_notif_template_html = email_config["notif_template_html"] self.email_notif_template_html = email_config["notif_template_html"]
self.email_notif_template_text = email_config["notif_template_text"] self.email_notif_template_text = email_config["notif_template_text"]
self.email_expiry_template_html = email_config["expiry_template_html"]
self.email_expiry_template_text = email_config["expiry_template_text"]
template_dir = email_config.get("template_dir") template_dir = email_config.get("template_dir")
# we need an absolute path, because we change directory after starting (and # we need an absolute path, because we change directory after starting (and
@ -120,7 +122,7 @@ class EmailConfig(Config):
def default_config(self, config_dir_path, server_name, **kwargs): def default_config(self, config_dir_path, server_name, **kwargs):
return """ return """
# Enable sending emails for notification events # Enable sending emails for notification events or expiry notices
# Defining a custom URL for Riot is only needed if email notifications # Defining a custom URL for Riot is only needed if email notifications
# should contain links to a self-hosted installation of Riot; when set # should contain links to a self-hosted installation of Riot; when set
# the "app_name" setting is ignored. # the "app_name" setting is ignored.
@ -142,6 +144,9 @@ class EmailConfig(Config):
# #template_dir: res/templates # #template_dir: res/templates
# notif_template_html: notif_mail.html # notif_template_html: notif_mail.html
# notif_template_text: notif_mail.txt # notif_template_text: notif_mail.txt
# # Templates for account expiry notices.
# expiry_template_html: notice_expiry.html
# expiry_template_text: notice_expiry.txt
# notif_for_new_users: True # notif_for_new_users: True
# riot_base_url: "http://localhost/riot" # riot_base_url: "http://localhost/riot"
""" """

View File

@ -21,12 +21,26 @@ from synapse.util.stringutils import random_string_with_symbols
class AccountValidityConfig(Config): class AccountValidityConfig(Config):
def __init__(self, config): def __init__(self, config, synapse_config):
self.enabled = (len(config) > 0) self.enabled = config.get("enabled", False)
self.renew_by_email_enabled = ("renew_at" in config)
period = config.get("period", None) if self.enabled:
if period: if "period" in config:
self.period = self.parse_duration(period) self.period = self.parse_duration(config["period"])
else:
raise ConfigError("'period' is required when using account validity")
if "renew_at" in config:
self.renew_at = self.parse_duration(config["renew_at"])
if "renew_email_subject" in config:
self.renew_email_subject = config["renew_email_subject"]
else:
self.renew_email_subject = "Renew your %(app)s account"
if self.renew_by_email_enabled and "public_baseurl" not in synapse_config:
raise ConfigError("Can't send renewal emails without 'public_baseurl'")
class RegistrationConfig(Config): class RegistrationConfig(Config):
@ -40,7 +54,9 @@ class RegistrationConfig(Config):
strtobool(str(config["disable_registration"])) strtobool(str(config["disable_registration"]))
) )
self.account_validity = AccountValidityConfig(config.get("account_validity", {})) self.account_validity = AccountValidityConfig(
config.get("account_validity", {}), config,
)
self.registrations_require_3pid = config.get("registrations_require_3pid", []) self.registrations_require_3pid = config.get("registrations_require_3pid", [])
self.allowed_local_3pids = config.get("allowed_local_3pids", []) self.allowed_local_3pids = config.get("allowed_local_3pids", [])
@ -87,11 +103,31 @@ class RegistrationConfig(Config):
# #
#enable_registration: false #enable_registration: false
# Optional account validity parameter. This allows for, e.g., accounts to # Optional account validity configuration. This allows for accounts to be denied
# be denied any request after a given period. # any request after a given period.
#
# ``enabled`` defines whether the account validity feature is enabled. Defaults
# to False.
#
# ``period`` allows setting the period after which an account is valid
# after its registration. When renewing the account, its validity period
# will be extended by this amount of time. This parameter is required when using
# the account validity feature.
#
# ``renew_at`` is the amount of time before an account's expiry date at which
# Synapse will send an email to the account's email address with a renewal link.
# This needs the ``email`` and ``public_baseurl`` configuration sections to be
# filled.
#
# ``renew_email_subject`` is the subject of the email sent out with the renewal
# link. ``%%(app)s`` can be used as a placeholder for the ``app_name`` parameter
# from the ``email`` section.
# #
#account_validity: #account_validity:
# enabled: True
# period: 6w # period: 6w
# renew_at: 1w
# renew_email_subject: "Renew your %%(app)s account"
# The user must provide all of the below types of 3PID when registering. # The user must provide all of the below types of 3PID when registering.
# #

View File

@ -0,0 +1,228 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import email.mime.multipart
import email.utils
import logging
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.types import UserID
from synapse.util import stringutils
from synapse.util.logcontext import make_deferred_yieldable
try:
from synapse.push.mailer import load_jinja2_templates
except ImportError:
load_jinja2_templates = None
logger = logging.getLogger(__name__)
class AccountValidityHandler(object):
def __init__(self, hs):
self.hs = hs
self.store = self.hs.get_datastore()
self.sendmail = self.hs.get_sendmail()
self.clock = self.hs.get_clock()
self._account_validity = self.hs.config.account_validity
if self._account_validity.renew_by_email_enabled and load_jinja2_templates:
# Don't do email-specific configuration if renewal by email is disabled.
try:
app_name = self.hs.config.email_app_name
self._subject = self._account_validity.renew_email_subject % {
"app": app_name,
}
self._from_string = self.hs.config.email_notif_from % {
"app": app_name,
}
except Exception:
# If substitution failed, fall back to the bare strings.
self._subject = self._account_validity.renew_email_subject
self._from_string = self.hs.config.email_notif_from
self._raw_from = email.utils.parseaddr(self._from_string)[1]
self._template_html, self._template_text = load_jinja2_templates(
config=self.hs.config,
template_html_name=self.hs.config.email_expiry_template_html,
template_text_name=self.hs.config.email_expiry_template_text,
)
# Check the renewal emails to send and send them every 30min.
self.clock.looping_call(
self.send_renewal_emails,
30 * 60 * 1000,
)
@defer.inlineCallbacks
def send_renewal_emails(self):
"""Gets the list of users whose account is expiring in the amount of time
configured in the ``renew_at`` parameter from the ``account_validity``
configuration, and sends renewal emails to all of these users as long as they
have an email 3PID attached to their account.
"""
expiring_users = yield self.store.get_users_expiring_soon()
if expiring_users:
for user in expiring_users:
yield self._send_renewal_email(
user_id=user["user_id"],
expiration_ts=user["expiration_ts_ms"],
)
@defer.inlineCallbacks
def _send_renewal_email(self, user_id, expiration_ts):
"""Sends out a renewal email to every email address attached to the given user
with a unique link allowing them to renew their account.
Args:
user_id (str): ID of the user to send email(s) to.
expiration_ts (int): Timestamp in milliseconds for the expiration date of
this user's account (used in the email templates).
"""
addresses = yield self._get_email_addresses_for_user(user_id)
# Stop right here if the user doesn't have at least one email address.
# In this case, they will have to ask their server admin to renew their
# account manually.
if not addresses:
return
try:
user_display_name = yield self.store.get_profile_displayname(
UserID.from_string(user_id).localpart
)
if user_display_name is None:
user_display_name = user_id
except StoreError:
user_display_name = user_id
renewal_token = yield self._get_renewal_token(user_id)
url = "%s_matrix/client/unstable/account_validity/renew?token=%s" % (
self.hs.config.public_baseurl,
renewal_token,
)
template_vars = {
"display_name": user_display_name,
"expiration_ts": expiration_ts,
"url": url,
}
html_text = self._template_html.render(**template_vars)
html_part = MIMEText(html_text, "html", "utf8")
plain_text = self._template_text.render(**template_vars)
text_part = MIMEText(plain_text, "plain", "utf8")
for address in addresses:
raw_to = email.utils.parseaddr(address)[1]
multipart_msg = MIMEMultipart('alternative')
multipart_msg['Subject'] = self._subject
multipart_msg['From'] = self._from_string
multipart_msg['To'] = address
multipart_msg['Date'] = email.utils.formatdate()
multipart_msg['Message-ID'] = email.utils.make_msgid()
multipart_msg.attach(text_part)
multipart_msg.attach(html_part)
logger.info("Sending renewal email to %s", address)
yield make_deferred_yieldable(self.sendmail(
self.hs.config.email_smtp_host,
self._raw_from, raw_to, multipart_msg.as_string().encode('utf8'),
reactor=self.hs.get_reactor(),
port=self.hs.config.email_smtp_port,
requireAuthentication=self.hs.config.email_smtp_user is not None,
username=self.hs.config.email_smtp_user,
password=self.hs.config.email_smtp_pass,
requireTransportSecurity=self.hs.config.require_transport_security
))
yield self.store.set_renewal_mail_status(
user_id=user_id,
email_sent=True,
)
@defer.inlineCallbacks
def _get_email_addresses_for_user(self, user_id):
"""Retrieve the list of email addresses attached to a user's account.
Args:
user_id (str): ID of the user to lookup email addresses for.
Returns:
defer.Deferred[list[str]]: Email addresses for this account.
"""
threepids = yield self.store.user_get_threepids(user_id)
addresses = []
for threepid in threepids:
if threepid["medium"] == "email":
addresses.append(threepid["address"])
defer.returnValue(addresses)
@defer.inlineCallbacks
def _get_renewal_token(self, user_id):
"""Generates a 32-byte long random string that will be inserted into the
user's renewal email's unique link, then saves it into the database.
Args:
user_id (str): ID of the user to generate a string for.
Returns:
defer.Deferred[str]: The generated string.
Raises:
StoreError(500): Couldn't generate a unique string after 5 attempts.
"""
attempts = 0
while attempts < 5:
try:
renewal_token = stringutils.random_string(32)
yield self.store.set_renewal_token_for_user(user_id, renewal_token)
defer.returnValue(renewal_token)
except StoreError:
attempts += 1
raise StoreError(500, "Couldn't generate a unique string as refresh string.")
@defer.inlineCallbacks
def renew_account(self, renewal_token):
"""Renews the account attached to a given renewal token by pushing back the
expiration date by the current validity period in the server's configuration.
Args:
renewal_token (str): Token sent with the renewal request.
"""
user_id = yield self.store.get_user_from_renewal_token(renewal_token)
logger.debug("Renewing an account for user %s", user_id)
new_expiration_date = self.clock.time_msec() + self._account_validity.period
yield self.store.renew_account_for_user(
user_id=user_id,
new_expiration_ts=new_expiration_date,
)

View File

@ -521,11 +521,11 @@ def format_ts_filter(value, format):
return time.strftime(format, time.localtime(value / 1000)) return time.strftime(format, time.localtime(value / 1000))
def load_jinja2_templates(config): def load_jinja2_templates(config, template_html_name, template_text_name):
"""Load the jinja2 email templates from disk """Load the jinja2 email templates from disk
Returns: Returns:
(notif_template_html, notif_template_text) (template_html, template_text)
""" """
logger.info("loading email templates from '%s'", config.email_template_dir) logger.info("loading email templates from '%s'", config.email_template_dir)
loader = jinja2.FileSystemLoader(config.email_template_dir) loader = jinja2.FileSystemLoader(config.email_template_dir)
@ -533,14 +533,10 @@ def load_jinja2_templates(config):
env.filters["format_ts"] = format_ts_filter env.filters["format_ts"] = format_ts_filter
env.filters["mxc_to_http"] = _create_mxc_to_http_filter(config) env.filters["mxc_to_http"] = _create_mxc_to_http_filter(config)
notif_template_html = env.get_template( template_html = env.get_template(template_html_name)
config.email_notif_template_html template_text = env.get_template(template_text_name)
)
notif_template_text = env.get_template(
config.email_notif_template_text
)
return notif_template_html, notif_template_text return template_html, template_text
def _create_mxc_to_http_filter(config): def _create_mxc_to_http_filter(config):

View File

@ -44,7 +44,11 @@ class PusherFactory(object):
if hs.config.email_enable_notifs: if hs.config.email_enable_notifs:
self.mailers = {} # app_name -> Mailer self.mailers = {} # app_name -> Mailer
templates = load_jinja2_templates(hs.config) templates = load_jinja2_templates(
config=hs.config,
template_html_name=hs.config.email_notif_template_html,
template_text_name=hs.config.email_notif_template_text,
)
self.notif_template_html, self.notif_template_text = templates self.notif_template_html, self.notif_template_text = templates
self.pusher_types["email"] = self._create_email_pusher self.pusher_types["email"] = self._create_email_pusher

View File

@ -0,0 +1,4 @@
.noticetext {
margin-top: 10px;
margin-bottom: 10px;
}

View File

@ -0,0 +1,43 @@
<!doctype html>
<html lang="en">
<head>
<style type="text/css">
{% include 'mail.css' without context %}
{% include "mail-%s.css" % app_name ignore missing without context %}
{% include 'mail-expiry.css' without context %}
</style>
</head>
<body>
<table id="page">
<tr>
<td> </td>
<td id="inner">
<table class="header">
<tr>
<td>
<div class="salutation">Hi {{ display_name }},</div>
</td>
<td class="logo">
{% if app_name == "Riot" %}
<img src="http://riot.im/img/external/riot-logo-email.png" width="83" height="83" alt="[Riot]"/>
{% elif app_name == "Vector" %}
<img src="http://matrix.org/img/vector-logo-email.png" width="64" height="83" alt="[Vector]"/>
{% else %}
<img src="http://matrix.org/img/matrix-120x51.png" width="120" height="51" alt="[matrix]"/>
{% endif %}
</td>
</tr>
<tr>
<td colspan="2">
<div class="noticetext">Your account will expire on {{ expiration_ts|format_ts("%d-%m-%Y") }}. This means that you will lose access to your account after this date.</div>
<div class="noticetext">To extend the validity of your account, please click on the link bellow (or copy and paste it into a new browser tab):</div>
<div class="noticetext"><a href="{{ url }}">{{ url }}</a></div>
</td>
</tr>
</table>
</td>
<td> </td>
</tr>
</table>
</body>
</html>

View File

@ -0,0 +1,7 @@
Hi {{ display_name }},
Your account will expire on {{ expiration_ts|format_ts("%d-%m-%Y") }}. This means that you will lose access to your account after this date.
To extend the validity of your account, please click on the link bellow (or copy and paste it to a new browser tab):
{{ url }}

View File

@ -33,6 +33,7 @@ from synapse.rest.client.v1 import (
from synapse.rest.client.v2_alpha import ( from synapse.rest.client.v2_alpha import (
account, account,
account_data, account_data,
account_validity,
auth, auth,
capabilities, capabilities,
devices, devices,
@ -109,3 +110,4 @@ class ClientRestResource(JsonResource):
groups.register_servlets(hs, client_resource) groups.register_servlets(hs, client_resource)
room_upgrade_rest_servlet.register_servlets(hs, client_resource) room_upgrade_rest_servlet.register_servlets(hs, client_resource)
capabilities.register_servlets(hs, client_resource) capabilities.register_servlets(hs, client_resource)
account_validity.register_servlets(hs, client_resource)

View File

@ -0,0 +1,62 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.http.server import finish_request
from synapse.http.servlet import RestServlet
from ._base import client_v2_patterns
logger = logging.getLogger(__name__)
class AccountValidityRenewServlet(RestServlet):
PATTERNS = client_v2_patterns("/account_validity/renew$")
SUCCESS_HTML = b"<html><body>Your account has been successfully renewed.</body><html>"
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super(AccountValidityRenewServlet, self).__init__()
self.hs = hs
self.account_activity_handler = hs.get_account_validity_handler()
@defer.inlineCallbacks
def on_GET(self, request):
if b"token" not in request.args:
raise SynapseError(400, "Missing renewal token")
renewal_token = request.args[b"token"][0]
yield self.account_activity_handler.renew_account(renewal_token.decode('utf8'))
request.setResponseCode(200)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
request.setHeader(b"Content-Length", b"%d" % (
len(AccountValidityRenewServlet.SUCCESS_HTML),
))
request.write(AccountValidityRenewServlet.SUCCESS_HTML)
finish_request(request)
defer.returnValue(None)
def register_servlets(hs, http_server):
AccountValidityRenewServlet(hs).register(http_server)

View File

@ -47,6 +47,7 @@ from synapse.federation.transport.client import TransportLayerClient
from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer
from synapse.groups.groups_server import GroupsServerHandler from synapse.groups.groups_server import GroupsServerHandler
from synapse.handlers import Handlers from synapse.handlers import Handlers
from synapse.handlers.account_validity import AccountValidityHandler
from synapse.handlers.acme import AcmeHandler from synapse.handlers.acme import AcmeHandler
from synapse.handlers.appservice import ApplicationServicesHandler from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.handlers.auth import AuthHandler, MacaroonGenerator from synapse.handlers.auth import AuthHandler, MacaroonGenerator
@ -183,6 +184,7 @@ class HomeServer(object):
'room_context_handler', 'room_context_handler',
'sendmail', 'sendmail',
'registration_handler', 'registration_handler',
'account_validity_handler',
] ]
REQUIRED_ON_MASTER_STARTUP = [ REQUIRED_ON_MASTER_STARTUP = [
@ -506,6 +508,9 @@ class HomeServer(object):
def build_registration_handler(self): def build_registration_handler(self):
return RegistrationHandler(self) return RegistrationHandler(self)
def build_account_validity_handler(self):
return AccountValidityHandler(self)
def remove_pusher(self, app_id, push_key, user_id): def remove_pusher(self, app_id, push_key, user_id):
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)

View File

@ -32,6 +32,7 @@ class RegistrationWorkerStore(SQLBaseStore):
super(RegistrationWorkerStore, self).__init__(db_conn, hs) super(RegistrationWorkerStore, self).__init__(db_conn, hs)
self.config = hs.config self.config = hs.config
self.clock = hs.get_clock()
@cached() @cached()
def get_user_by_id(self, user_id): def get_user_by_id(self, user_id):
@ -87,25 +88,156 @@ class RegistrationWorkerStore(SQLBaseStore):
) )
@cachedInlineCallbacks() @cachedInlineCallbacks()
def get_expiration_ts_for_user(self, user): def get_expiration_ts_for_user(self, user_id):
"""Get the expiration timestamp for the account bearing a given user ID. """Get the expiration timestamp for the account bearing a given user ID.
Args: Args:
user (str): The ID of the user. user_id (str): The ID of the user.
Returns: Returns:
defer.Deferred: None, if the account has no expiration timestamp, defer.Deferred: None, if the account has no expiration timestamp,
otherwise int representation of the timestamp (as a number of otherwise int representation of the timestamp (as a number of
milliseconds since epoch). milliseconds since epoch).
""" """
res = yield self._simple_select_one_onecol( res = yield self._simple_select_one_onecol(
table="account_validity", table="account_validity",
keyvalues={"user_id": user.to_string()}, keyvalues={"user_id": user_id},
retcol="expiration_ts_ms", retcol="expiration_ts_ms",
allow_none=True, allow_none=True,
desc="get_expiration_date_for_user", desc="get_expiration_ts_for_user",
) )
defer.returnValue(res) defer.returnValue(res)
@defer.inlineCallbacks
def renew_account_for_user(self, user_id, new_expiration_ts):
"""Updates the account validity table with a new timestamp for a given
user, removes the existing renewal token from this user, and unsets the
flag indicating that an email has been sent for renewing this account.
Args:
user_id (str): ID of the user whose account validity to renew.
new_expiration_ts: New expiration date, as a timestamp in milliseconds
since epoch.
"""
def renew_account_for_user_txn(txn):
self._simple_update_txn(
txn=txn,
table="account_validity",
keyvalues={"user_id": user_id},
updatevalues={
"expiration_ts_ms": new_expiration_ts,
"email_sent": False,
"renewal_token": None,
},
)
self._invalidate_cache_and_stream(
txn, self.get_expiration_ts_for_user, (user_id,),
)
yield self.runInteraction(
"renew_account_for_user",
renew_account_for_user_txn,
)
@defer.inlineCallbacks
def set_renewal_token_for_user(self, user_id, renewal_token):
"""Defines a renewal token for a given user.
Args:
user_id (str): ID of the user to set the renewal token for.
renewal_token (str): Random unique string that will be used to renew the
user's account.
Raises:
StoreError: The provided token is already set for another user.
"""
yield self._simple_update_one(
table="account_validity",
keyvalues={"user_id": user_id},
updatevalues={"renewal_token": renewal_token},
desc="set_renewal_token_for_user",
)
@defer.inlineCallbacks
def get_user_from_renewal_token(self, renewal_token):
"""Get a user ID from a renewal token.
Args:
renewal_token (str): The renewal token to perform the lookup with.
Returns:
defer.Deferred[str]: The ID of the user to which the token belongs.
"""
res = yield self._simple_select_one_onecol(
table="account_validity",
keyvalues={"renewal_token": renewal_token},
retcol="user_id",
desc="get_user_from_renewal_token",
)
defer.returnValue(res)
@defer.inlineCallbacks
def get_renewal_token_for_user(self, user_id):
"""Get the renewal token associated with a given user ID.
Args:
user_id (str): The user ID to lookup a token for.
Returns:
defer.Deferred[str]: The renewal token associated with this user ID.
"""
res = yield self._simple_select_one_onecol(
table="account_validity",
keyvalues={"user_id": user_id},
retcol="renewal_token",
desc="get_renewal_token_for_user",
)
defer.returnValue(res)
@defer.inlineCallbacks
def get_users_expiring_soon(self):
"""Selects users whose account will expire in the [now, now + renew_at] time
window (see configuration for account_validity for information on what renew_at
refers to).
Returns:
Deferred: Resolves to a list[dict[user_id (str), expiration_ts_ms (int)]]
"""
def select_users_txn(txn, now_ms, renew_at):
sql = (
"SELECT user_id, expiration_ts_ms FROM account_validity"
" WHERE email_sent = ? AND (expiration_ts_ms - ?) <= ?"
)
values = [False, now_ms, renew_at]
txn.execute(sql, values)
return self.cursor_to_dict(txn)
res = yield self.runInteraction(
"get_users_expiring_soon",
select_users_txn,
self.clock.time_msec(), self.config.account_validity.renew_at,
)
defer.returnValue(res)
@defer.inlineCallbacks
def set_renewal_mail_status(self, user_id, email_sent):
"""Sets or unsets the flag that indicates whether a renewal email has been sent
to the user (and the user hasn't renewed their account yet).
Args:
user_id (str): ID of the user to set/unset the flag for.
email_sent (bool): Flag which indicates whether a renewal email has been sent
to this user.
"""
yield self._simple_update_one(
table="account_validity",
keyvalues={"user_id": user_id},
updatevalues={"email_sent": email_sent},
desc="set_renewal_mail_status",
)
@defer.inlineCallbacks @defer.inlineCallbacks
def is_server_admin(self, user): def is_server_admin(self, user):
res = yield self._simple_select_one_onecol( res = yield self._simple_select_one_onecol(
@ -584,20 +716,22 @@ class RegistrationStore(
}, },
) )
if self._account_validity.enabled:
now_ms = self.clock.time_msec()
expiration_ts = now_ms + self._account_validity.period
self._simple_insert_txn(
txn,
"account_validity",
values={
"user_id": user_id,
"expiration_ts_ms": expiration_ts,
}
)
except self.database_engine.module.IntegrityError: except self.database_engine.module.IntegrityError:
raise StoreError(400, "User ID already taken.", errcode=Codes.USER_IN_USE) raise StoreError(400, "User ID already taken.", errcode=Codes.USER_IN_USE)
if self._account_validity.enabled:
now_ms = self.clock.time_msec()
expiration_ts = now_ms + self._account_validity.period
self._simple_insert_txn(
txn,
"account_validity",
values={
"user_id": user_id,
"expiration_ts_ms": expiration_ts,
"email_sent": False,
}
)
if token: if token:
# it's possible for this to get a conflict, but only for a single user # it's possible for this to get a conflict, but only for a single user
# since tokens are namespaced based on their user ID # since tokens are namespaced based on their user ID

View File

@ -13,8 +13,15 @@
* limitations under the License. * limitations under the License.
*/ */
DROP TABLE IF EXISTS account_validity;
-- Track what users are in public rooms. -- Track what users are in public rooms.
CREATE TABLE IF NOT EXISTS account_validity ( CREATE TABLE IF NOT EXISTS account_validity (
user_id TEXT PRIMARY KEY, user_id TEXT PRIMARY KEY,
expiration_ts_ms BIGINT NOT NULL expiration_ts_ms BIGINT NOT NULL,
email_sent BOOLEAN NOT NULL,
renewal_token TEXT
); );
CREATE INDEX account_validity_email_sent_idx ON account_validity(email_sent, expiration_ts_ms)
CREATE UNIQUE INDEX account_validity_renewal_string_idx ON account_validity(renewal_token)

View File

@ -1,14 +1,22 @@
import datetime import datetime
import json import json
import os
import pkg_resources
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import Codes from synapse.api.errors import Codes
from synapse.appservice import ApplicationService from synapse.appservice import ApplicationService
from synapse.rest.client.v1 import admin, login from synapse.rest.client.v1 import admin, login
from synapse.rest.client.v2_alpha import register, sync from synapse.rest.client.v2_alpha import account_validity, register, sync
from tests import unittest from tests import unittest
try:
from synapse.push.mailer import load_jinja2_templates
except ImportError:
load_jinja2_templates = None
class RegisterRestServletTestCase(unittest.HomeserverTestCase): class RegisterRestServletTestCase(unittest.HomeserverTestCase):
@ -197,6 +205,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor, clock):
config = self.default_config() config = self.default_config()
# Test for account expiring after a week.
config.enable_registration = True config.enable_registration = True
config.account_validity.enabled = True config.account_validity.enabled = True
config.account_validity.period = 604800000 # Time in ms for 1 week config.account_validity.period = 604800000 # Time in ms for 1 week
@ -228,3 +237,92 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
self.assertEquals( self.assertEquals(
channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result, channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result,
) )
class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
skip = "No Jinja installed" if not load_jinja2_templates else None
servlets = [
register.register_servlets,
admin.register_servlets,
login.register_servlets,
sync.register_servlets,
account_validity.register_servlets,
]
def make_homeserver(self, reactor, clock):
config = self.default_config()
# Test for account expiring after a week and renewal emails being sent 2
# days before expiry.
config.enable_registration = True
config.account_validity.enabled = True
config.account_validity.renew_by_email_enabled = True
config.account_validity.period = 604800000 # Time in ms for 1 week
config.account_validity.renew_at = 172800000 # Time in ms for 2 days
config.account_validity.renew_email_subject = "Renew your account"
# Email config.
self.email_attempts = []
def sendmail(*args, **kwargs):
self.email_attempts.append((args, kwargs))
return
config.email_template_dir = os.path.abspath(
pkg_resources.resource_filename('synapse', 'res/templates')
)
config.email_expiry_template_html = "notice_expiry.html"
config.email_expiry_template_text = "notice_expiry.txt"
config.email_smtp_host = "127.0.0.1"
config.email_smtp_port = 20
config.require_transport_security = False
config.email_smtp_user = None
config.email_smtp_pass = None
config.email_notif_from = "test@example.com"
self.hs = self.setup_test_homeserver(config=config, sendmail=sendmail)
self.store = self.hs.get_datastore()
return self.hs
def test_renewal_email(self):
user_id = self.register_user("kermit", "monkey")
tok = self.login("kermit", "monkey")
# We need to manually add an email address otherwise the handler will do
# nothing.
now = self.hs.clock.time_msec()
self.get_success(self.store.user_add_threepid(
user_id=user_id, medium="email", address="kermit@example.com",
validated_at=now, added_at=now,
))
# The specific endpoint doesn't matter, all we need is an authenticated
# endpoint.
request, channel = self.make_request(
b"GET", "/sync", access_token=tok,
)
self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
# Move 6 days forward. This should trigger a renewal email to be sent.
self.reactor.advance(datetime.timedelta(days=6).total_seconds())
self.assertEqual(len(self.email_attempts), 1)
# Retrieving the URL from the email is too much pain for now, so we
# retrieve the token from the DB.
renewal_token = self.get_success(self.store.get_renewal_token_for_user(user_id))
url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token
request, channel = self.make_request(b"GET", url)
self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
# Move 3 days forward. If the renewal failed, every authed request with
# our access token should be denied from now, otherwise they should
# succeed.
self.reactor.advance(datetime.timedelta(days=3).total_seconds())
request, channel = self.make_request(
b"GET", "/sync", access_token=tok,
)
self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)